diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..7e7ba98c00744c2f2f4d67c511bcf5f0005b10be --- /dev/null +++ b/.clang-format @@ -0,0 +1,209 @@ +--- +# 计算软件平台部C/C++代码格式 +# version: 0.1 +# C +BasedOnStyle: LLVM +# 关闭格式化 +DisableFormat: false +# tab宽度 +TabWidth: 4 +# 使用tab字符: Never, ForIndentation, ForContinuationAndIndentation, Always +UseTab: Never +# 左对齐逃脱换行(使用反斜杠换行)的反斜杠 +AlignEscapedNewlines: Left +# 连续赋值时,对齐所有等号 +AlignConsecutiveAssignments: None +# 连续声明时,对齐所有声明的变量名 +AlignConsecutiveDeclarations: None +# 开括号(开圆括号、开尖括号、开方括号)后的对齐: Align, DontAlign, AlwaysBreak(总是在开括号后换行) +AlignAfterOpenBracket: Align +# 水平对齐二元和三元表达式的操作数 +AlignOperands: Align +# 指针和引用的对齐,可选项:Left, Right, Middle +PointerAlignment: Right +# 对齐连续的尾随的注释 +AlignTrailingComments: true +# 超过行宽,强制将所有参数放在下一行 +AllowAllArgumentsOnNextLine: false +# 超过行宽,函数声明的所有参数强制放在下一行 +AllowAllParametersOfDeclarationOnNextLine: false +# 允许短的块放在同一行 +AllowShortBlocksOnASingleLine: false +# 允许短的case标签放在同一行 +AllowShortCaseLabelsOnASingleLine: false +# 允许短的函数放在同一行: None, InlineOnly(定义在类中), Empty(空函数), Inline(定义在类中,空函数), All +AllowShortFunctionsOnASingleLine: None +# 允许短的if语句保持在同一行 +AllowShortIfStatementsOnASingleLine: Never +# 允许短的循环保持在同一行 +AllowShortLoopsOnASingleLine: false +# 允许短的枚举保持在同一行 +AllowShortEnumsOnASingleLine: false +# 总是在返回类型后换行: None, All, TopLevel(顶级函数,不包括在类中的函数), +# AllDefinitions(所有的定义,不包括声明), TopLevelDefinitions(所有的顶级函数的定义) +AlwaysBreakAfterReturnType: None +# 每行字符的限制,0表示没有限制(规范120,实际一般推荐80) +ColumnLimit: 120 +# false表示函数实参要么都在同一行,要么都各自一行 +BinPackArguments: true +# false表示所有形参要么都在同一行,要么都各自一行 +BinPackParameters: true +# 在大括号前换行: Attach(始终将大括号附加到周围的上下文), Linux(除函数、命名空间和类定义,与Attach类似), +# Mozilla(除枚举、函数、记录定义,与Attach类似), Stroustrup(除函数定义、catch、else,与Attach类似), +# Allman(总是在大括号前换行), GNU(总是在大括号前换行,并对于控制语句的大括号增加额外的缩进), WebKit(在函数前换行), Custom +# 注:这里认为语句块也属于函数 +BreakBeforeBraces: Custom +# 大括号换行,只有当BreakBeforeBraces设置为Custom时才有效 +BraceWrapping: + # class定义后面 + AfterClass: false + # 控制语句后面 + AfterControlStatement: Never + # enum定义后面 + AfterEnum: false + # 函数定义后面 + AfterFunction: true + # 命名空间定义后面 + AfterNamespace: false + # ObjC定义后面 + AfterObjCDeclaration: false + # struct定义后面 + AfterStruct: false + # union定义后面 + AfterUnion: false + # catch之前 + BeforeCatch: false + # else之前 + BeforeElse: false + # 缩进大括号 + IndentBraces: false + AfterExternBlock: false + SplitEmptyFunction: true + SplitEmptyRecord: true + SplitEmptyNamespace: true +# 在二元运算符前换行: None(在操作符后换行), NonAssignment(在非赋值的操作符前换行), All(在操作符前换行) +BreakBeforeBinaryOperators: None +# 在三元运算符前换行 +BreakBeforeTernaryOperators: false +# 超长的字符串分成多行 +BreakStringLiterals: true +# 延续的行的缩进宽度 +ContinuationIndentWidth: 8 +# 将函数的返回类型放到它自己的行的penalty +PenaltyReturnTypeOnItsOwnLine: 80 +PenaltyBreakAssignment: 100 +PenaltyExcessCharacter: 100 +PenaltyBreakBeforeFirstCallParameter: 100 +# 总是在多行string字面量前换行 +AlwaysBreakBeforeMultilineStrings: false +# 头文件相关 +IncludeBlocks: Regroup +SortIncludes: false +IncludeCategories: + - Regex: '^"' + Priority: 1 + - Regex: '^<' + Priority: 2 +# 缩进 Case 代码块 +IndentCaseBlocks: false +# 缩进case标签 +IndentCaseLabels: true +# 缩进宽度 +IndentWidth: 4 +# 保留在块开始处的空行 +KeepEmptyLinesAtTheStartOfBlocks: false +# 预处理指令的缩进级别 +IndentPPDirectives: None +# 连续空行的最大数量 +MaxEmptyLinesToKeep: 1 +# 在C风格类型转换后添加空格 +SpaceAfterCStyleCast: false +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: false +# 在赋值运算符之前添加空格 +SpaceBeforeAssignmentOperators: true +#SpaceBeforeCaseColon: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +# 开圆括号之前添加一个空格: Never, ControlStatements, Always +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceBeforeSquareBrackets: false +# 在空的圆括号中添加空格 +SpaceInEmptyBlock: false +SpaceInEmptyParentheses: false +# 在尾随的评论前添加的空格数(只适用于//) +SpacesBeforeTrailingComments: 2 +# 在尖括号的<后和>前添加空格 +SpacesInAngles: false +# 在容器(ObjC和JavaScript的数组和字典等)字面量中添加空格 +SpacesInContainerLiterals: true +# 在C风格类型转换的括号中添加空格 +SpacesInCStyleCastParentheses: false +SpacesInConditionalStatement: false +# 在圆括号的(后和)前添加空格 +SpacesInParentheses: false +# 在方括号的[后和]前添加空格,Lambda表达式和未指明大小的数组的声明不受影响 +SpacesInSquareBrackets: false +# 需要被解读为foreach循环而不是函数调用的宏(根据项目自定义添加) +ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] +# 需要被解读为Statement而不是函数调用的宏(根据项目自定义添加) +StatementMacros : [] +# 需要被解读为Typename而不是函数调用的宏(根据项目自定义添加) +TypenameMacros: [] +# 宏白名单 +WhitespaceSensitiveMacros: [] + +# CPP +# 标准: Cpp03, Cpp11, Auto +Standard: Cpp11 +# 访问说明符(public、private等)的偏移 +AccessModifierOffset: -4 +# 总是在template声明后换行 +AlwaysBreakTemplateDeclarations: false +# 多重继承的情况下是否在继承列表的逗号之前插入换行符 +BreakBeforeInheritanceComma: false +# 继承最常用的指针和引用的对齐方式 +DerivePointerAlignment: false +# 总是在定义返回类型后换行(deprecated) +AlwaysBreakAfterDefinitionReturnType: None +# 在构造函数的初始化列表的逗号前换行 +BreakConstructorInitializersBeforeComma: false +# 在构造函数的初始化列表的逗号后换行 +BreakConstructorInitializers: BeforeColon +# 描述具有特殊意义的注释的正则表达式,它不应该被分割为多行或以其它方式改变 +CommentPragmas: '^lint' +# 构造函数的初始化列表要么都在同一行,要么都各自一行 +ConstructorInitializerAllOnOneLineOrOnePerLine: true +# 构造函数的初始化列表的缩进宽度 +ConstructorInitializerIndentWidth: 4 +# 去除C++11的列表初始化的大括号{后和}前的空格 +Cpp11BracedListStyle: true +# 自动检测函数的调用和定义是否被格式为每行一个参数(Experimental) +ExperimentalAutoDetectBinPacking: false +# 函数返回类型换行时,缩进函数声明或函数定义的函数名 +IndentWrappedFunctionNames: false +# 开始一个块的宏的正则表达式 +MacroBlockBegin: '' +# 结束一个块的宏的正则表达式 +MacroBlockEnd: '' +# 命名空间的缩进: None, Inner(缩进嵌套的命名空间中的内容), All +NamespaceIndentation: None +# 使用ObjC块时缩进宽度 +ObjCBlockIndentWidth: 4 +# 在ObjC的@property后添加一个空格 +ObjCSpaceAfterProperty: false +# 在ObjC的protocol列表前添加一个空格 +ObjCSpaceBeforeProtocolList: true +# 在一个注释中引入换行的penalty +PenaltyBreakComment: 300 +# 第一次在<<前换行的penalty +PenaltyBreakFirstLessLess: 120 +# 在一个字符串字面量中引入换行的penalty +PenaltyBreakString: 1000 +# 允许重新排版注释 +ReflowComments: false +# 允许合并namespace +CompactNamespaces: false +... diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..89fdd23e7069ef83f68f309d367e8b6bb1642b91 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,234 @@ +cmake_minimum_required(VERSION 3.14.1) +project(hcom CXX C) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_SKIP_BUILD_RPATH ON) + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE release) +endif() +message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") + +if(NOT DEFINED CPU_TYPE) + set(CPU_TYPE "arm_64") +endif() +message(STATUS "CPU_TYPE: ${CPU_TYPE}") + +if (NOT BUILD_SERVICE_VERSION) + set(BUILD_SERVICE_VERSION 1) +endif() +add_definitions(-DBUILD_SERVICE_VERSION=${BUILD_SERVICE_VERSION}) +message(STATUS "BUILD_SERVICE_VERSION: ${BUILD_SERVICE_VERSION}") + +option(BUILD_JAVA_SDK "Build java sdk" OFF) +option(BUILD_TESTS "Build unit tests" OFF) +option(BUILD_WITH_HW_CRC "Build with hardware crc" OFF) +option(BUILD_WITH_SERVICE "Build with SERVICE" ON) +option(BUILD_WITH_RDMA "Build with RDMA" ON) +option(BUILD_WITH_UB "Build with UB" OFF) +option(BUILD_WITH_SHM "Build with SHM" ON) +option(BUILD_WITH_SOCK "Build with SOCK" ON) +option(BUILD_WITH_ALLOCATOR_PROTECTION "Build with allocator protection" OFF) +option(BUILD_WITH_HTRACER "Build with htracer" ON) +option(ENABLE_ARM_KP "enable arm kunpeng" OFF) + +if (ENABLE_ARM_KP) + add_definitions(-DENABLE_ARM_KP) +endif () + +# Default to ld.gold if available. +option(USE_GOLD_LINKER "Use ld.gold to link" ON) +if(USE_GOLD_LINKER) + execute_process( + COMMAND ${CMAKE_C_COMPILER} -fuse-ld=gold -Wl,--version + ERROR_QUIET + OUTPUT_VARIABLE LD_GOLD_VERSION + ) + if("${LD_GOLD_VERSION}" MATCHES "GNU gold") + message(STATUS "ld.gold is available, using it to link") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fuse-ld=gold") + else() + message(WARNING "USE_GOLD_LINKER is set but ld.gold is not available, fallback to ld.bfd") + set(USE_GOLD_LINKER OFF) + endif() +endif() + +# Obtain commit information from Git. +if(NOT HCOM_COMMIT_ID) + execute_process(COMMAND git rev-parse HEAD + OUTPUT_VARIABLE HCOM_COMMIT_ID + OUTPUT_STRIP_TRAILING_WHITESPACE + WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}") + + if(NOT HCOM_COMMIT_ID) + set(HCOM_COMMIT_ID "") + endif() +endif() + +add_definitions(-DHCOM_COMMIT_ID="${HCOM_COMMIT_ID}") +message(STATUS "Build hcom on commit ${HCOM_COMMIT_ID}") + +if (BUILD_TESTS) + enable_testing() + add_definitions(-DBUILD_TESTS) + + # GTest can be found via find_package. + # find_package(GTest CONFIG REQUIRED) + # target_link_libraries(foo GTest::gtest_main) # or GTest::gtest w/o the main function + if(NOT EXISTS ${TEST_TOOL_INSTALL_PATH}) + message(FATAL_ERROR "Specify -DTEST_TOOL_INSTALL_PATH=/path/to/dir in cmake command args.") + endif() + list(APPEND CMAKE_PREFIX_PATH ${TEST_TOOL_INSTALL_PATH}/googletest/lib64/cmake) +endif () + +if (BUILD_WITH_RDMA) + add_definitions(-DRDMA_BUILD_ENABLED) +endif () + +if (BUILD_WITH_SHM) + add_definitions(-DSHM_BUILD_ENABLED) +endif () + +if (BUILD_WITH_SOCK) + add_definitions(-DSOCK_BUILD_ENABLED) +endif () + +if (BUILD_WITH_UB) + add_definitions(-DUB_BUILD_ENABLED) +endif () + +if (BUILD_WITH_ALLOCATOR_PROTECTION) + add_definitions(-DALLOCATOR_PROTECTION_ENABLED) +endif () + +if (NOT HCOM_COMPONENT_VERSION) + add_definitions(-DHCOM_COMPONENT_VERSION="1.0.0") +else () + add_definitions(-DHCOM_COMPONENT_VERSION="${HCOM_COMPONENT_VERSION}") +endif () + +if (${CMAKE_BUILD_TYPE} MATCHES "release") + set(CXX_FLAGS + -g + -pipe + -O3 + -Wall + -fms-extensions + -Wno-unused-parameter + -Wno-unused-function + -w -Wno-address-of-packed-member + -Wunused-variable + -Wunused-value + -Wcast-align + -Wcast-qual + -Winvalid-pch + -Wwrite-strings + -Wsign-compare + -Wfloat-equal + -Wextra + -std=c++11 + -fPIC + -fstack-protector-strong + -Wl,-z,relro,-z,now,-z,noexecstack + ) +else () + set(CXX_FLAGS + -pipe + -g + -rdynamic + -Wall + -fms-extensions + -Wno-unused-parameter + -Wno-unused-function + -w -Wno-address-of-packed-member + -Wunused-variable + -Wunused-value + -Winvalid-pch + -Wcast-align + -Wcast-qual + -Wwrite-strings + -Wsign-compare + -Wfloat-equal + -Wextra + -std=c++11 + #-fsanitize=address + #-fno-omit-frame-pointer + -fstack-protector-strong + #-fPIC will cause failure while mocking system functions + ) + add_definitions(-DDEBUG) + if (NOT BUILD_TESTS) + set(CXX_FLAGS "${CXX_FLAGS} -fPIC") + endif () +endif() + +string(REPLACE ";" " " CMAKE_CXX_FLAGS "${CXX_FLAGS}") + +#set(CMAKE_VERBOSE_MAKEFILE on) + +# enable asan +# add_compile_options(-fsanitize=address) +# add_link_options(-fsanitize=address) + +add_definitions(-DUSE_PROCESS_MONOTONIC) + +if(BUILD_WITH_HW_CRC) + add_definitions(-DUSE_HARDWARE_CRC) + if (${CPU_TYPE} MATCHES "arm_64") + add_definitions(-march=armv8-a+crc) + endif() +endif() +message(STATUS "BUILD_WITH_HW_CRC: ${BUILD_WITH_HW_CRC}") + +#enable object statistics +add_definitions(-DENABLE_OBJ_GLOBAL_STATISTICS) + +set(CMAKE_C_STANDARD 11) + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O2") + +set(HCOM_SRC_DIR ${PROJECT_SOURCE_DIR}) +set(HCOM_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/dist) +set(HCOM_BUILD_PATH ${PROJECT_SOURCE_DIR}/build) + +set(HCOM_STUB_SRC_DIR ${PROJECT_SOURCE_DIR}/test/stub) + +if (NOT DEFINED ENV{JAVA_HOME}) + set(JAVA_HOME /usr/local/jdk) +else () + set(JAVA_HOME $ENV{JAVA_HOME}) +endif () + +include_directories(${HCOM_SRC_DIR}/src/under_api/verbs + ${HCOM_SRC_DIR}/src/under_api/urma + ${HCOM_SRC_DIR}/src/under_api/openssl + ${HCOM_SRC_DIR}/src/under_api/obmm + ${HCOM_SRC_DIR}/src + ${HCOM_SRC_DIR}/src/api + ${HCOM_SRC_DIR}/src/api/capi + ${HCOM_SRC_DIR}/src/api/java_sdk + ${HCOM_SRC_DIR}/src/common + ${HCOM_SRC_DIR}/src/service + ${HCOM_SRC_DIR}/src/service_v2 + ${HCOM_SRC_DIR}/src/service_v2/api + ${HCOM_SRC_DIR}/src/transport + ${HCOM_SRC_DIR}/src/transport/rdma + ${HCOM_SRC_DIR}/src/transport/rdma/verbs + ${HCOM_SRC_DIR}/src/transport/ub + ${HCOM_SRC_DIR}/src/transport/shm + ${HCOM_SRC_DIR}/src/transport/sock + ${HCOM_STUB_SRC_DIR}/ + ${JAVA_HOME}/include + ${JAVA_HOME}/include/linux) + +# Reset the prefix path if user doesn't provide one. +if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) + set_property(CACHE CMAKE_INSTALL_PREFIX PROPERTY VALUE ${PROJECT_BINARY_DIR}) +endif() + +add_subdirectory(src) + +if(BUILD_TESTS) + add_subdirectory(test) +endif() +message(STATUS "BUILD_TESTS: ${BUILD_TESTS}") diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..e46202f955441c97cd4386918a86657a4f9b77e8 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,92 @@ +# HCOM项目贡献 + +这篇文章主要介绍了如何进行`HCOM`的贡献 + +## 加入贡献 + +`HCOM`项目欢迎新成员加入并贡献,新成员需要积极的为`HCOM`项目做出贡献。我们乐于接受提交`Issue`/处理`Issue`、代码、文档、检视意见、测试等类型的贡献。 + +### 1.提交 `Issue` / 处理 `Issue` +`Issue` 是用来记录和追踪开发者的想法、反馈、任务和缺陷。您可以通过提交`Issue`到`HCOM`项目进行贡献。常见的`Issue`场景有: + +a)报告 bug + +b)提交建议 + +c)记录一个待完成任务 + +d)指出文档缺失/安装问题 + +e)答疑交流 + +#### 1.1 找到 `Issue` 列表: + +在`HCOM`的代码仓中,点击工具栏目的 “Issues”,您可以找到其 `Issue` 列表 + +#### 1.2 提交 `Issue`: + +如果您准备上报`Bug`或者提交需求,为`HCOM`贡献自己的意见或建议,可以在`HCOM`仓库上提交`Issue`(请参考 [Issue 提交指南](https://gitee.com/openeuler/community/blob/master/zh/contributors/issue-submit.md) )。为了吸引更广泛的注意,您也可以把`Issue`的链接附在邮件内,通过 [邮件列表](https://www.openeuler.openatom.cn/zh/community/mailing-list/) 发送给所有人。 + +#### 1.3 参与 `Issue` 内的讨论: + +每个 Issue 下面可能有参与者们的交流和讨论,如果您感兴趣,可以在评论框中发表自己的意见。 + +#### 1.4 找到愿意处理的 `Issue`: + +如果您愿意处理其中的一个 `Issue`,可以将它分配给自己。只需要在评论框内输入`/assign`或`/assign @yourself`,机器人就会将`Issue`分配给您,您的名字将显示在负责人列表里。 + +### 2.贡献编码 +#### 2.1 搭建开发环境 + +1.开发环境准备:如果您想参与编码贡献,需要准备`HCOM`的开发环境,请参考`doc`目录下的《HCOM用户指南》搭建并准备开发环境。 + +2.下载和构建软件包:如果您想下载、修改、构建及验证`HCOM`提供的软件包,请参考`README.md`进行编译、构建、用例验证。 + + +#### 2.2 下载代码和拉分支 + +如果要参与代码贡献,您还需要了解如何在 `Gitee` 下载代码,通过 `PR`(`Pull Request`) 合入代码等。`HCOM` 使用 `Gitee` 代码托管平台,想了解具体的指导,请参考 [Gitee Workflow Guide](https://gitee.com/openeuler/community/blob/master/zh/contributors/Gitee-workflow.md) 。该托管平台的使用方法类似`GitHub`,如果您以前使用过`GitHub`,本节的内容您可以大致了解甚至跳过。 + +#### 2.3 修改构建和本地验证 + +在本地分支上完成修改后,进行构建和本地验证,请参考构建软件包。 + +#### 2.4 提交一个 `PR`(`Pull Request`) + +当您提交一个`PR`的时候,就意味您已经开始给社区贡献代码了。请参考 [openEuler 社区 PR 提交指导](https://gitee.com/openeuler/community/blob/master/zh/contributors/pull-request.md) 。为了使您的提交更容易被接受,您需要: + +1.代码要遵循以下几个原则 + +可读性 - 重要代码应充分注释,`API`应具备文档,代码风格应遵循现有规范。 + +优雅性 - 新增功能、类或组件应设计精良。 + +可测试性 - 新增代码的 70% 应被单元测试覆盖。 + +2.准备完善的提交信息 + +3.如果一次提交的代码量较大,建议将大型的内容分解成一系列逻辑上较小的内容,分别进行提交会更便于检视者理解您的想法 + +### 3.检视代码 +`HCOM`非常欢迎所有参与的人都能成为活跃的检视者。可以参考 [社区成员](https://gitee.com/openeuler/community/blob/master/community-membership_cn.md) ,该文档描述了不同贡献者的角色职责。 + +当成为`HCOM`项目的`Committer`或`Maintainer`角色时,便拥有审核代码的责任与权利。强烈建议本着[社区行为准则](https://gitee.com/openeuler/community/blob/master/code-of-conduct.md) ,超越自我,相互尊重和促进协作。在检视其他人的`PR`时,可以重点关注包括: + +1.贡献背后的想法是否合理 + +2.贡献的架构是否正确 + +3.贡献是否完善 + +### 4.测试 +测试是所有贡献者的责任,对于社区版本来说,`sig-QA` 组是负责测试活动的社区官方组织。如果您希望在自己的基础架构上开展测试活动,可以参考:[社区测试体系介绍](https://gitee.com/openeuler/QA/blob/master/%E7%A4%BE%E5%8C%BA%E6%B5%8B%E8%AF%95%E4%BD%93%E7%B3%BB%E4%BB%8B%E7%BB%8D.md) 。 + +为了成功发行一个社区版本,`openEuler` 需要完成多种测试活动。不同的测试活动,测试代码的位置也有所不同,成功运行测试所需的环境细节也会有差异,有关的信息可以参考 [测试指南](https://gitee.com/openeuler/QA/blob/master/%E7%A4%BE%E5%8C%BA%E5%BC%80%E5%8F%91%E8%80%85%E6%B5%8B%E8%AF%95%E8%B4%A1%E7%8C%AE%E6%8C%87%E5%8D%97.md) 。 + +### 5.社区安全问题披露 +[安全处理流程](https://gitee.com/openeuler/security-committee/blob/master/docs/zh/vulnerability-management-process/security-process.md) ——简要描述了处理安全问题的过程。 + +[安全披露信息](https://gitee.com/openeuler/security-committee/blob/master/docs/zh/vulnerability-management-process/security-disclosure.md) ——如果您希望报告安全漏洞,请参考此页面。 + +### 6.遗漏事项 +您如果发现了本指南不足的点,或者您对某些特定步骤感到困惑,请告诉我们!或者您可以选择提交一个`PR`来解决这个问题 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4b0b13e43bf2624e6f95e79ebd87a4e209c02a07 --- /dev/null +++ b/LICENSE @@ -0,0 +1,132 @@ +木兰宽松许可证, 第2版 + +木兰宽松许可证, 第2版 + +2020年1月 http://license.coscl.org.cn/MulanPSL2 + +您对“软件”的复制、使用、修改及分发受木兰宽松许可证,第2版(“本许可证”)的如下条款的约束: + +0. 定义 + +“软件” 是指由“贡献”构成的许可在“本许可证”下的程序和相关文档的集合。 + +“贡献” 是指由任一“贡献者”许可在“本许可证”下的受版权法保护的作品。 + +“贡献者” 是指将受版权法保护的作品许可在“本许可证”下的自然人或“法人实体”。 + +“法人实体” 是指提交贡献的机构及其“关联实体”。 + +“关联实体” 是指,对“本许可证”下的行为方而言,控制、受控制或与其共同受控制的机构,此处的控制是指有受控方或共同受控方至少50%直接或间接的投票权、资金或其他有价证券。 + +1. 授予版权许可 + +每个“贡献者”根据“本许可证”授予您永久性的、全球性的、免费的、非独占的、不可撤销的版权许可,您可以复制、使用、修改、分发其“贡献”,不论修改与否。 + +2. 授予专利许可 + +每个“贡献者”根据“本许可证”授予您永久性的、全球性的、免费的、非独占的、不可撤销的(根据本条规定撤销除外)专利许可,供您制造、委托制造、使用、许诺销售、销售、进口其“贡献”或以其他方式转移其“贡献”。前述专利许可仅限于“贡献者”现在或将来拥有或控制的其“贡献”本身或其“贡献”与许可“贡献”时的“软件”结合而将必然会侵犯的专利权利要求,不包括对“贡献”的修改或包含“贡献”的其他结合。如果您或您的“关联实体”直接或间接地,就“软件”或其中的“贡献”对任何人发起专利侵权诉讼(包括反诉或交叉诉讼)或其他专利维权行动,指控其侵犯专利权,则“本许可证”授予您对“软件”的专利许可自您提起诉讼或发起维权行动之日终止。 + +3. 无商标许可 + +“本许可证”不提供对“贡献者”的商品名称、商标、服务标志或产品名称的商标许可,但您为满足第4条规定的声明义务而必须使用除外。 + +4. 分发限制 + +您可以在任何媒介中将“软件”以源程序形式或可执行形式重新分发,不论修改与否,但您必须向接收者提供“本许可证”的副本,并保留“软件”中的版权、商标、专利及免责声明。 + +5. 免责声明与责任限制 + +“软件”及其中的“贡献”在提供时不带任何明示或默示的担保。在任何情况下,“贡献者”或版权所有者不对任何人因使用“软件”或其中的“贡献”而引发的任何直接或间接损失承担责任,不论因何种原因导致或者基于何种法律理论,即使其曾被建议有此种损失的可能性。 + +6. 语言 + +“本许可证”以中英文双语表述,中英文版本具有同等法律效力。如果中英文版本存在任何冲突不一致,以中文版为准。 + +条款结束 + +如何将木兰宽松许可证,第2版,应用到您的软件 + +如果您希望将木兰宽松许可证,第2版,应用到您的新软件,为了方便接收者查阅,建议您完成如下三步: + +1, 请您补充如下声明中的空白,包括软件名、软件的首次发表年份以及您作为版权人的名字; + +2, 请您在软件包的一级目录下创建以“LICENSE”为名的文件,将整个许可证文本放入该文件中; + +3, 请将如下声明文本放入每个源文件的头部注释中。 + +Copyright (c) [Year] [name of copyright holder] +[Software Name] is licensed under Mulan PSL v2. +You can use this software according to the terms and conditions of the Mulan PSL v2. +You may obtain a copy of Mulan PSL v2 at: + http://license.coscl.org.cn/MulanPSL2 +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +See the Mulan PSL v2 for more details. + + +Mulan Permissive Software License,Version 2 + +Mulan Permissive Software License,Version 2 (Mulan PSL v2) + +January 2020 http://license.coscl.org.cn/MulanPSL2 + +Your reproduction, use, modification and distribution of the Software shall be subject to Mulan PSL v2 (this License) with the following terms and conditions: + +0. Definition + +Software means the program and related documents which are licensed under this License and comprise all Contribution(s). + +Contribution means the copyrightable work licensed by a particular Contributor under this License. + +Contributor means the Individual or Legal Entity who licenses its copyrightable work under this License. + +Legal Entity means the entity making a Contribution and all its Affiliates. + +Affiliates means entities that control, are controlled by, or are under common control with the acting entity under this License, ‘control’ means direct or indirect ownership of at least fifty percent (50%) of the voting power, capital or other securities of controlled or commonly controlled entity. + +1. Grant of Copyright License + +Subject to the terms and conditions of this License, each Contributor hereby grants to you a perpetual, worldwide, royalty-free, non-exclusive, irrevocable copyright license to reproduce, use, modify, or distribute its Contribution, with modification or not. + +2. Grant of Patent License + +Subject to the terms and conditions of this License, each Contributor hereby grants to you a perpetual, worldwide, royalty-free, non-exclusive, irrevocable (except for revocation under this Section) patent license to make, have made, use, offer for sale, sell, import or otherwise transfer its Contribution, where such patent license is only limited to the patent claims owned or controlled by such Contributor now or in future which will be necessarily infringed by its Contribution alone, or by combination of the Contribution with the Software to which the Contribution was contributed. The patent license shall not apply to any modification of the Contribution, and any other combination which includes the Contribution. If you or your Affiliates directly or indirectly institute patent litigation (including a cross claim or counterclaim in a litigation) or other patent enforcement activities against any individual or entity by alleging that the Software or any Contribution in it infringes patents, then any patent license granted to you under this License for the Software shall terminate as of the date such litigation or activity is filed or taken. + +3. No Trademark License + +No trademark license is granted to use the trade names, trademarks, service marks, or product names of Contributor, except as required to fulfill notice requirements in section 4. + +4. Distribution Restriction + +You may distribute the Software in any medium with or without modification, whether in source or executable forms, provided that you provide recipients with a copy of this License and retain copyright, patent, trademark and disclaimer statements in the Software. + +5. Disclaimer of Warranty and Limitation of Liability + +THE SOFTWARE AND CONTRIBUTION IN IT ARE PROVIDED WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED. IN NO EVENT SHALL ANY CONTRIBUTOR OR COPYRIGHT HOLDER BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE SOFTWARE OR THE CONTRIBUTION IN IT, NO MATTER HOW IT’S CAUSED OR BASED ON WHICH LEGAL THEORY, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +6. Language + +THIS LICENSE IS WRITTEN IN BOTH CHINESE AND ENGLISH, AND THE CHINESE VERSION AND ENGLISH VERSION SHALL HAVE THE SAME LEGAL EFFECT. IN THE CASE OF DIVERGENCE BETWEEN THE CHINESE AND ENGLISH VERSIONS, THE CHINESE VERSION SHALL PREVAIL. + +END OF THE TERMS AND CONDITIONS + +How to Apply the Mulan Permissive Software License,Version 2 (Mulan PSL v2) to Your Software + +To apply the Mulan PSL v2 to your work, for easy identification by recipients, you are suggested to complete following three steps: + +i. Fill in the blanks in following statement, including insert your software name, the year of the first publication of your software, and your name identified as the copyright owner; + +ii. Create a file named "LICENSE" which contains the whole context of this License in the first directory of your software package; + +iii. Attach the statement to the appropriate annotated syntax at the beginning of each source file. + +Copyright (c) [Year] [name of copyright holder] +[Software Name] is licensed under Mulan PSL v2. +You can use this software according to the terms and conditions of the Mulan PSL v2. +You may obtain a copy of Mulan PSL v2 at: + http://license.coscl.org.cn/MulanPSL2 +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +See the Mulan PSL v2 for more details. \ No newline at end of file diff --git a/README.md b/README.md index d1c6d2f099ef5cdb669dc7e64641fe1b5cd58937..f4eb97e4fe6b8cad465f261a12c5e656c2101112 100644 --- a/README.md +++ b/README.md @@ -1,37 +1,132 @@ -# ubs-comm +# HCOM -#### 介绍 -Ubs-comm Provides high-performance, high-reliability, and ecosystem-compatible(user-mode socket/verbs over UB) communication protocols based on UB superpods. +`HCOM`是一个适用于C/S架构应用程序的高性能通信库,主要有以下特征: -#### 软件架构 -软件架构说明 +- **高易用性**:`HCOM`底层支持多种网卡硬件及通信协议(如`RDMA`、`TCP`、`SHM`、`UB`),屏蔽了这些硬件或传输协议间的差异,向开发者提供统一的API。此外,`HCOM`还提供了`QoS`能力(如流控、故障检测、消息重传等),认证加密能力等,进一步方便开发者使用。 +- **高性能**:`HCOM`通过软硬件结合,实现极致高性能。针对不同的场景,软件实现了多线程管理、`RNDV`(Rendezvous协议,用于大包场景)、`MultiRail`(多网口聚合,充分利用网络带宽)等加速特性。 +## 1 源码下载 -#### 安装教程 +可以使用如下两种方式下载HCOM源码。 -1. xxxx -2. xxxx -3. xxxx +```shell +# 方法一 +$ git clone +$ git submodule update --init --recursive +# 方法二 +$ git clone --recurse-submodules +``` -#### 使用说明 +## 2 源码目录结构 -1. xxxx -2. xxxx -3. xxxx +`HCOM`源码的主要目录结构如下所示。 -#### 参与贡献 +```shell +. +├── build // 存放项目中使用的脚本文件 +├── doc // 存放项目文档,例如《代码架构设计》 +├── src // 存放项目的功能实现源码,仅该目录参与构建出包 +├── test // 存放项目的ut和dtfuzz等 +└── build.sh // 统一的构建入口 +``` -1. Fork 本仓库 -2. 新建 Feat_xxx 分支 -3. 提交代码 -4. 新建 Pull Request +## 3 用户指南 +`HCOM`提供给开发者的的资料主要有以下几本。 +《UBS-COMM-API-Spec》 +《UBS-COMM-Architecture-Design-Specification》 +《UBS-Comm-Tutorial-Demo》 +《UBS-Comm-Tutorial-UseCase》 -#### 特技 +## 4 编译 + +`HCOM`在代码仓中提供了统一的编译构建脚本(即`build.sh`),可以直接执行该脚本编译构建(该脚本同时用于CI流水线构建出包)。默认无需任何配置项,直接执行即可。 + +```shell +$ ./build.sh +``` + +执行完毕后可以在源码的dist目录中找到一个`xxx.tar.gz`的软件包,其核心内容及介绍如下所示。 + +```shell +$ tree +. +├── include // C&C++头文件 +│   └── hcom +│   ├── capi +│   │   ├── hcom_c.h +│   │   └── hcom_service_c.h +│   ├── hcom.h +│   └── hcom_service.h +└── lib // C&C++动态库和静态库 + ├── libhcom.so + └── libhcom_static.a +``` + +可以通过环境变量,对`build.sh`的编译过程进行控制,如下所示。 + +```shell +$ cat build.sh | head -n 23 +#!/bin/bash +# *********************************************************************** +# Copyright: (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# Script for building HCOM. +# Build options can be configured through environment variables. +# (1) HCOM_BUILD_TYPE(optional, default is release) => set build type.(release/debug) +# (2) HCOM_BUILD_TESTS(optional, default is off) => enable build test or not.(on/off) +# (3) HCOM_BUILD_JAVA_SDK(optional, default is off) => build java sdk or not.(on/off) +# (4) HCOM_BUILD_SERVICE(optional, default is on) => build service level or not.(on/off) +# (5) HCOM_BUILD_RDMA(optional, default is on) => build rdma or not.(on/off) +# (6) HCOM_BUILD_SOCK(optional, default is on) => build sock (tcp/uds) or not.(on/off) +# (7) HCOM_BUILD_SHM(optional, default is on) => build shm or not.(on/off) +# (8) HCOM_BUILD_EXAMPLE(optional, default is off) => build example and perf.(on/off) +# (9) HCOM_ENABLE_ARM_KP(optional, default is on) => check kunpeng or not.(on/off) +# (10) HCOM_TEST_TOOL_PATH(optional) => test tool install path.(mockcpp/gtest/dtfuzz) +# (11) HCOM_CI_WORKSPACE(optional) => ci workspace, for buildInfo.properties file. +# (12) HCOM_BUILD_RPM(optional, default is on) => build rpm.(on/off) +# (13) HCOM_BUILD_TOOLS_PERF(optional, default is off) => build rpm.(on/off) +# (14) HCOM_BUILD_HW_CRC(optional, default is off) => build with hardware based crc.(on/off) + +# version: 1.0.0 +# change log: +# *********************************************************************** +``` + +## 5 编译和执行HCOM性能测试工具 + +HCOM的示例存放在两个目录: + +- test/tools/perf_test目录:存放性能用例,用例链接`HCOM`静态库。 + +考虑门禁构建时间,默认不会编译perf_test用例,请参考以下README文档编译 +``` +lingqu\test\tools\perf_test\README.md +``` +或执行以下命令,开启环境变量后编译 +``` +export HCOM_BUILD_TOOLS_PERF=on +bash build.sh +``` + +## 6 编译和执行UT用例 +可以按照如下方式,手动编译和执行UT用例。 + +```shell +# ut用例中涉及较多mock,mock框架需要知道具体的符号,只能以debug模式编译 +$ export HCOM_BUILD_TYPE=debug +# 构建出包时,默认不编译ut,需手动开启 +$ export HCOM_BUILD_TESTS=on +# 直接执行构建脚本,即可编译 +$ ./build.sh +# 执行UT用例并生成测试报告,耗时较长,结果存放在build目录中 +$ ./build/generate_gtest_report.sh +# 生成UT覆盖率信息,结果存放在build目录中 +$ ./build/generate_lcov_report.sh +``` + +## License +HCOM 采用 Mulan V2 License. + +## 贡献指南 +请阅读 贡献指南 `CONTRIBUTING.md` 以了解如何贡献项目。 -1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md -2. Gitee 官方博客 [blog.gitee.com](https://blog.gitee.com) -3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解 Gitee 上的优秀开源项目 -4. [GVP](https://gitee.com/gvp) 全称是 Gitee 最有价值开源项目,是综合评定出的优秀开源项目 -5. Gitee 官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help) -6. Gitee 封面人物是一档用来展示 Gitee 会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md new file mode 100644 index 0000000000000000000000000000000000000000..4b171d1534c53d81ba3e757a0e50fcf6ebc52ec4 --- /dev/null +++ b/RELEASE-NOTES.md @@ -0,0 +1,8 @@ +1.0.0 +初始版本 + +`HCOM`主要支持以下特性: + +1.`HCOM`底层支持多种网卡硬件及通信协议(如`RDMA`、`TCP`、`SHM`、`UB`),屏蔽了这些硬件或传输协议间的差异,向开发者提供统一的API。此外,`HCOM`还提供了`QoS`能力(如流控、故障检测、消息重传等),认证加密能力等,进一步方便开发者使用。 + +2.`HCOM`通过软硬件结合,实现极致高性能。针对部分常用的网卡型号(如MLX5),使能硬件加速特性;针对不同的场景,软件实现了多线程管理、`RNDV`(Rendezvous协议,用于大包场景)、`MultiRail`(多网口聚合,充分利用网络带宽)等加速特性。 diff --git a/build.sh b/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..b6ca5ce6474e302649082c1d3e203e82c6ef51cd --- /dev/null +++ b/build.sh @@ -0,0 +1,143 @@ +#!/bin/bash +# *********************************************************************** +# Copyright: (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# Script for building HCOM. +# Build options can be configured through environment variables. +# (1) HCOM_BUILD_TYPE(optional, default is release) => set build type.(release/debug) +# (2) HCOM_BUILD_TESTS(optional, default is off) => enable build test or not.(on/off) +# (3) HCOM_BUILD_JAVA_SDK(optional, default is off) => build java sdk or not.(on/off) +# (4) HCOM_BUILD_SERVICE(optional, default is on) => build service level or not.(on/off) +# (5) HCOM_BUILD_RDMA(optional, default is on) => build rdma or not.(on/off) +# (6) HCOM_BUILD_SOCK(optional, default is on) => build sock (tcp/uds) or not.(on/off) +# (7) HCOM_BUILD_SHM(optional, default is on) => build shm or not.(on/off) +# (8) HCOM_BUILD_EXAMPLE(optional, default is off) => build example and perf.(on/off) +# (9) HCOM_ENABLE_ARM_KP(optional, default is on) => check kunpeng or not.(on/off) +# (10) HCOM_TEST_TOOL_PATH(optional) => test tool install path.(mockcpp/gtest/dtfuzz) +# (11) HCOM_CI_WORKSPACE(optional) => ci workspace, for buildInfo.properties file. +# (12) HCOM_BUILD_RPM(optional, default is on) => build rpm.(on/off) +# (13) HCOM_BUILD_TOOLS_PERF(optional, default is off) => build rpm.(on/off) +# (14) HCOM_BUILD_HW_CRC(optional, default is off) => build with hardware based crc.(on/off) + +# version: 1.0.0 +# change log: +# *********************************************************************** +set -eo pipefail + +readonly HCOM_ROOT_DIR=$(cd $(dirname ${0}) && pwd) +readonly HCOM_BUILD_DIR="${HCOM_ROOT_DIR}/tmp_build_dir" +readonly HCOM_LOG_TAG="[$(basename ${0})]" +readonly HCOM_INSTALL_DIR="${HCOM_ROOT_DIR}/dist/hcom" + +echo "HCOM ROOT: ${HCOM_ROOT_DIR}" +echo "HCOM BUILD DIR: ${HCOM_BUILD_DIR}" +echo "HCOM INSTALL DIR: ${HCOM_INSTALL_DIR}" + +HCOM_COMPONENT_VERSION="1.0.0" + +function show_help() { + echo "Usage: $0 [COMMAND] [OPTION]" + echo "Build the project with specified options." + echo "Commands: clean" + echo "Options:" + echo " -t, --type TYPE Set build type. debug/release" +} + +function clean_dir() { + [[ -n "${HCOM_BUILD_DIR}" ]] && rm -rf "${HCOM_BUILD_DIR}" + [[ -n "${HCOM_INSTALL_DIR}" ]] && rm -rf "${HCOM_INSTALL_DIR}" + echo "Cleanup: ${HCOM_BUILD_DIR}, ${HCOM_INSTALL_DIR}" +} + +# 编译类型通过环境变量 HCOM_BUILD_TYPE 和命令行参数 -t 二选一,如果两者都提供了, +# 则优先使用命令行参数。默认编译类型为 Release. +HCOM_BUILD_TYPE="${HCOM_BUILD_TYPE,,}" +HCOM_BUILD_TYPE="${HCOM_BUILD_TYPE:-release}" + +while [[ "$#" -gt 0 ]]; do + case "$1" in + -t|--type) HCOM_BUILD_TYPE="${2,,}"; shift ;; + clean) clean_dir; exit 0 ;; + *) echo "Unknown parameter passed: $1"; show_help; exit 1 ;; + esac + shift +done + +echo "HCOM BUILD TYPE: ${HCOM_BUILD_TYPE}" + +# Hardware CRC is disabled by default +HCOM_BUILD_HW_CRC="${HCOM_BUILD_HW_CRC:-off}" +echo "${HCOM_LOG_TAG} hcom build hw crc: ${HCOM_BUILD_HW_CRC}" + +# check whether build UB, default is off +HCOM_BUILD_UB="${HCOM_BUILD_UB:-off}" +echo "${HCOM_LOG_TAG} hcom build ub: ${HCOM_BUILD_UB}" + +# check whether build service module, default is on +HCOM_BUILD_SERVICE="${HCOM_BUILD_SERVICE:-on}" +echo "${HCOM_LOG_TAG} hcom build service: ${HCOM_BUILD_SERVICE}" + +# check whether build RDMA module, default is on +HCOM_BUILD_RDMA="${HCOM_BUILD_RDMA:-on}" +echo "${HCOM_LOG_TAG} hcom build rdma: ${HCOM_BUILD_RDMA}" + +# check whether build sock(tcp and uds) module, default is on +HCOM_BUILD_SOCK="${HCOM_BUILD_SOCK:-on}" +echo "${HCOM_LOG_TAG} hcom build sock: ${HCOM_BUILD_SOCK}" + +# check whether build shm module, default is on +HCOM_BUILD_SHM="${HCOM_BUILD_SHM:-on}" +echo "${HCOM_LOG_TAG} hcom build shm: ${HCOM_BUILD_SHM}" + +# check whether check kunpeng, default is off +HCOM_ENABLE_ARM_KP="${HCOM_ENABLE_ARM_KP:-off}" +echo "${HCOM_LOG_TAG} hcom enable arm kunpeng check: ${HCOM_ENABLE_ARM_KP}" + +# check whether build java sdk, default is off. +HCOM_BUILD_JAVA_SDK="${HCOM_BUILD_JAVA_SDK:-off}" +echo "${HCOM_LOG_TAG} hcom build java sdk: ${HCOM_BUILD_JAVA_SDK}" + +# check whether enable unittest, default is off. +HCOM_BUILD_TESTS="${HCOM_BUILD_TESTS:-off}" +echo "${HCOM_LOG_TAG} hcom build tests: ${HCOM_BUILD_TESTS}" + +# check whether test tools are installed +if [[ "${HCOM_BUILD_TESTS,,}" == "on" ]]; then + [[ -z "${HCOM_TEST_TOOL_PATH}" ]] && HCOM_TEST_TOOL_PATH="${HCOM_ROOT_DIR}/dist/hcom_test_tools" + echo "${HCOM_LOG_TAG} hcom test tools path: ${HCOM_TEST_TOOL_PATH}" + if [[ ! -d "${HCOM_TEST_TOOL_PATH}" ]]; then + echo "${HCOM_LOG_TAG} hcom test tools are not installed, installing..." + bash "${HCOM_ROOT_DIR}/build/install_test_tools.sh" + fi +fi + +# Fresh build everytime +[[ -n "${HCOM_BUILD_DIR}" ]] && rm -rf "${HCOM_BUILD_DIR}" +[[ -n "${HCOM_INSTALL_DIR}" ]] && rm -rf "${HCOM_INSTALL_DIR}" + +cmake -S"${HCOM_ROOT_DIR}" -B"${HCOM_BUILD_DIR}" -DCMAKE_INSTALL_PREFIX="${HCOM_INSTALL_DIR}" \ + -DCMAKE_BUILD_TYPE=${HCOM_BUILD_TYPE} \ + -DBUILD_TESTS=${HCOM_BUILD_TESTS} \ + -DTEST_TOOL_INSTALL_PATH="${HCOM_TEST_TOOL_PATH}" \ + -DBUILD_JAVA_SDK=${HCOM_BUILD_JAVA_SDK} \ + -DBUILD_WITH_HW_CRC=${HCOM_BUILD_HW_CRC} \ + -DBUILD_WITH_UB=${HCOM_BUILD_UB} \ + -DBUILD_WITH_RDMA=${HCOM_BUILD_RDMA} \ + -DBUILD_WITH_SOCK=${HCOM_BUILD_SOCK} \ + -DBUILD_WITH_SHM=${HCOM_BUILD_SHM} \ + -DENABLE_ARM_KP=${HCOM_ENABLE_ARM_KP} \ + -DHCOM_COMPONENT_VERSION="${HCOM_COMPONENT_VERSION}" + +cmake --build "${HCOM_BUILD_DIR}" -j $(nproc) + +# Install to the specified path +cmake --build "${HCOM_BUILD_DIR}" --target install + +# collect objects and make software package +output=$(HCOM_COMPONENT_VERSION=${HCOM_COMPONENT_VERSION} bash "${HCOM_ROOT_DIR}/build/make_software_package.sh" -t "${HCOM_BUILD_TYPE}") + +# build example and perf +[[ "${HCOM_BUILD_EXAMPLE,,}" == "on" ]] && bash "${HCOM_ROOT_DIR}/build/build_example_perf.sh" + +# 不要删除本行。因 `[[ A ]] && B` 为表达式,其返回值会返回给 shell,一旦 +# HCOM_BUILD_EXAMPLE 不为 on 就会返回 1 导致 CI 构建失败. +echo "${HCOM_LOG_TAG} $0 succeeds" diff --git a/build/adapter_script.sh b/build/adapter_script.sh new file mode 100644 index 0000000000000000000000000000000000000000..5c6aff71c89be4ffe116127c167b51285b186958 --- /dev/null +++ b/build/adapter_script.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# *********************************************************************** +# Copyright: (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# script for adapting gcc 4.8.5 +# version: 1.0.0 +# change log: +# *********************************************************************** +set -e + +echo "hcom adaptation gcc 4.8.5 ..." +GCC_VERSION=$(gcc --version | grep "4.8.5") +if [ -n "$GCC_VERSION" ]; then + sed -i '/Wdate-time/d' ../3rdparty/secure_c/huawei_secure_c/src/Makefile + sed -i '/Wduplicated-branches/d' ../3rdparty/secure_c/huawei_secure_c/src/Makefile + sed -i '/Wduplicated-cond/d' ../3rdparty/secure_c/huawei_secure_c/src/Makefile + sed -i '/Wimplicit-fallthrough/d' ../3rdparty/secure_c/huawei_secure_c/src/Makefile + sed -i '/Wshift-negative-value/d' ../3rdparty/secure_c/huawei_secure_c/src/Makefile + sed -i '/Wshift-overflow/d' ../3rdparty/secure_c/huawei_secure_c/src/Makefile +fi diff --git a/build/build_htracer_cli.sh b/build/build_htracer_cli.sh new file mode 100644 index 0000000000000000000000000000000000000000..785340c6f98ebaa9ae57c50ea11181d8f06b1f87 --- /dev/null +++ b/build/build_htracer_cli.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# *********************************************************************** +# Copyright: (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# script for build hcom example and perf. +# Build options can be configured through environment variables. +# (1) HCOM_BUILD_TYPE(optional, default is release) => set build type.(release/debug) +# (2) HCOM_INSTALL_DIR(optional) => directory where hcom is installed. +# (3) HCOM_BUILD_DIR(optional) => directory for building and outputing example and perf. +# (4) HCOM_BUILD_JAVA_SDK(optional, default is off) => build java example or not.(on/off) +# version: 1.0.0 +# change log: +# *********************************************************************** +set -eo pipefail + +readonly HCOM_LOG_TAG="[$(basename ${0})]" +readonly CURRENT_SCRIPT_DIR=$(cd $(dirname ${0}) && pwd) +readonly HCOM_ROOT_DIR=$(dirname ${CURRENT_SCRIPT_DIR}) +readonly HTRACER_CLI_SRC_DIR="${HCOM_ROOT_DIR}/test/tools/hcom_tracer" +readonly HTRACER_CLI_BUILD_DIR="${HCOM_ROOT_DIR}/test/tools/hcom_tracer/build" + +# **************************************** +# build htracer_cli +# **************************************** +cd ${HTRACER_CLI_SRC_DIR} || { echo "Error: hcom test/tools/hcom_tracer directory not found!"; exit 1; } + +# 如果build目录存在,清理 +if [ -d "build" ]; then + rm -rf build/* +fi + +mkdir -p build +cd build || exit 1 + +cmake .. +make -j8 + +if [ $? -eq 0 ]; then + echo -e "\n\033[32mhtracer_cli compiled successfully!\033[0m" +else + echo -e "\n\033[31mError: Failed to compile htracer_cli\033[0m" + exit 1 +fi \ No newline at end of file diff --git a/build/build_tools_perf.sh b/build/build_tools_perf.sh new file mode 100644 index 0000000000000000000000000000000000000000..c86f9b2db6f8d73da4fd6f22eeae08c50ac502ba --- /dev/null +++ b/build/build_tools_perf.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# *********************************************************************** +# Copyright: (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# script for build hcom example and perf. +# Build options can be configured through environment variables. +# (1) HCOM_TOOLS_BUILD_TYPE(optional, default is release) => set build type.(release/debug) +# (2) HCOM_TOOLS_INCLUDE_DIR(optional) => default dir {HCOM_ROOT_DIR}/dist/hcom/include. +# (3) HCOM_TOOLS_LIB_DIR(optional) => default dir {HCOM_ROOT_DIR}/dist/hcom/lib. +# version: 1.0.0 +# change log: +# *********************************************************************** +set -e + +readonly HCOM_LOG_TAG="[$(basename ${0})]" +readonly CURRENT_SCRIPT_DIR=$(realpath $(dirname ${0})) +readonly HCOM_ROOT_DIR=$(dirname ${CURRENT_SCRIPT_DIR}) +readonly HCOM_TOOLS_PERF_DIR="${HCOM_ROOT_DIR}/test/tools/perf_test/build" + +# default tools build type is release +if [ "${HCOM_TOOLS_BUILD_TYPE,,}" == "debug" ]; then + HCOM_TOOLS_BUILD_TYPE="debug" +else + HCOM_TOOLS_BUILD_TYPE="release" +fi +echo "${HCOM_LOG_TAG} hcom build type: ${HCOM_TOOLS_BUILD_TYPE}" + +# 设置环境变量 +HCOM_TOOLS_INCLUDE_DIR="${HCOM_TOOLS_INCLUDE_DIR:-${HCOM_ROOT_DIR}/dist/hcom/include}" +HCOM_TOOLS_LIB_DIR="${HCOM_TOOLS_LIB_DIR:-${HCOM_ROOT_DIR}/dist/hcom/lib}" + +if [ ! -d "${HCOM_TOOLS_INCLUDE_DIR}" ]; then + echo "Error: HCOM_TOOLS_INCLUDE_DIR does not exist." + exit 1 +fi + +if [ ! -d "${HCOM_TOOLS_LIB_DIR}" ]; then + echo "Error: HCOM_TOOLS_LIB_DIR does not exist." + exit 1 +fi + +# check cpu num for parallel build +CPU_PROCESSOR_NUM=$(grep processor /proc/cpuinfo | wc -l) +echo "${HCOM_LOG_TAG} parallel build job num is ${CPU_PROCESSOR_NUM}" + +# build tools perf +if [ -e "${HCOM_TOOLS_PERF_DIR}" ]; then + # 如果存在,删除该路径(无论是文件还是目录) + rm -rf "${HCOM_TOOLS_PERF_DIR}" +fi + +mkdir -p "${HCOM_TOOLS_PERF_DIR}" + +cd ${HCOM_TOOLS_PERF_DIR} + +cmake -DCMAKE_BUILD_TYPE="${HCOM_TOOLS_BUILD_TYPE}"\ + -DHCOM_INCLUDE_DIR="${HCOM_TOOLS_INCLUDE_DIR}"\ + -DHCOM_LIB_DIR="${HCOM_TOOLS_LIB_DIR}" .. + +if [ "$?" != 0 ]; then + echo "${HCOM_LOG_TAG} hcom tools cmake failed" + exit 1 +fi + +make clean +if [ "$?" != 0 ]; then + echo "${HCOM_LOG_TAG} hcom tools make clean failed" + exit 1 +fi + +make -j"${CPU_PROCESSOR_NUM}" +if [ "$?" != 0 ]; then + echo "${HCOM_LOG_TAG} hcom tools make failed" + exit 1 +fi + +echo "${HCOM_LOG_TAG} hcom tools perf build success" diff --git a/build/generate_gtest_report.sh b/build/generate_gtest_report.sh new file mode 100644 index 0000000000000000000000000000000000000000..737c6e15fb415439df6a85aded66d25aa0d071f8 --- /dev/null +++ b/build/generate_gtest_report.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# *********************************************************************** +# Copyright: (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# script for run gtest and generate gtest report +# version: 1.0.0 +# change log: +# *********************************************************************** +set -e + +readonly HCOM_LOG_TAG="[$(basename ${0})]" +readonly CURRENT_SCRIPT_DIR=$(cd $(dirname ${0}) && pwd) +readonly HCOM_ROOT_DIR=$(dirname ${CURRENT_SCRIPT_DIR}) +readonly HCOM_BUILD_DIR="${HCOM_ROOT_DIR}/tmp_build_dir" +readonly HCOM_GTEST_RESULT="${HCOM_ROOT_DIR}/tmp_build_dir/gtest_report.xml" +readonly HCOM_GTEST_TEMP_DIR="${HCOM_ROOT_DIR}/tmp_build_dir/res_xml" + +cd ${HCOM_BUILD_DIR} + +./hcom_ut --gtest_output=xml:./res_xml/ut_result.xml +./hcom_test --gtest_output=xml:./res_xml/llt_result.xml + +# **************************************** +# combine gtest report +# **************************************** +echo '' > ${HCOM_GTEST_RESULT} +tests_val=$(cat res_xml/* |grep "" >> ${HCOM_GTEST_RESULT} +cat res_xml/* | grep -v testsuites |grep -v "xml version" >> ${HCOM_GTEST_RESULT} +echo '' >> ${HCOM_GTEST_RESULT} diff --git a/build/generate_lcov_report.sh b/build/generate_lcov_report.sh new file mode 100644 index 0000000000000000000000000000000000000000..493edf7ab3ef63e0d93661903b0f0eda48ae0d1f --- /dev/null +++ b/build/generate_lcov_report.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# *********************************************************************** +# Copyright: (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# script for generate line coverage report +# version: 1.0.0 +# change log: +# *********************************************************************** +set -e + +readonly HCOM_LOG_TAG="[$(basename ${0})]" +CURRENT_SCRIPT_DIR=$(cd $(dirname ${0}) && pwd) +HCOM_ROOT_DIR=$(dirname ${CURRENT_SCRIPT_DIR}) + +echo ${CURRENT_SCRIPT_DIR} +echo ${HCOM_ROOT_DIR} + +cd ${HCOM_ROOT_DIR}/build +# get the result of code coverage +lcov --rc lcov_branch_coverage=1 --rc lcov_excl_br_line="LCOV_EXCL_BR_LINE|NN_LOG*" \ + -b ../src/ -d ./test/llt/ -c -o lcov_report_llt.info +lcov --rc lcov_branch_coverage=1 --rc lcov_excl_br_line="LCOV_EXCL_BR_LINE|NN_LOG*" \ + -b ../src/ -d ./test/unit_test/ -c -o lcov_report_ut.info +lcov --rc lcov_branch_coverage=1 -a lcov_report_llt.info -a lcov_report_ut.info -o lcov_report_all.info + +# filter the result, remove useless info +# hcom_c.cpp will significantly lower overall line coverage and code coverage, we'll deal with it later. +lcov --rc lcov_branch_coverage=1 --rc lcov_excl_br_line="LCOV_EXCL_BR_LINE|NN_LOG*" -r lcov_report_all.info \ + '*/googletest/*' '*/mockcpp/*' '/usr/include' '*/gcc/*' '*/c++/*' \ + '*/test/*' '*/rdma-core/*' '*/dist/*' \ + '*/src/api/capi/hcom_c.cpp' '*/src/api/capi/hcom_service_c.cpp' \ + '*/src/service/service_net_driver_manager.*' \ + '*/src/under_api/*' \ + -o lcov_report_filterd.info + +# visualize the result +genhtml --branch-coverage -o gcover_report lcov_report_filterd.info diff --git a/build/install_test_tools.sh b/build/install_test_tools.sh new file mode 100644 index 0000000000000000000000000000000000000000..37d0413a52b0a84fe70edee7ad48885332bc7f11 --- /dev/null +++ b/build/install_test_tools.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# *********************************************************************** +# Copyright: (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# Script for preparing test environment. +# HCOM_TEST_TOOL_PATH is set to "${HCOM_ROOT_DIR}/dist/hcom_test_tools" by default. +# gtest will be installed to "${HCOM_TEST_TOOL_PATH}/googletest". +# mockcpp will be installed to "${HCOM_TEST_TOOL_PATH}/mockcpp". +# secodefuzz will be installed to "${HCOM_TEST_TOOL_PATH}/secodefuzz". +# +# version: 1.0.0 +# change log: +# *********************************************************************** +set -e + +readonly MOCKCPP_PATCH_FILENAME="0001-fix-page-size.patch" +readonly CURRENT_SCRIPT_DIR=$(cd $(dirname ${0}) && pwd) +readonly HCOM_ROOT_DIR=$(dirname ${CURRENT_SCRIPT_DIR}) +readonly HCOM_LOG_TAG="[$(basename ${0})]" +readonly TEST_TOOL_BUILD_DIR="${HCOM_ROOT_DIR}/build/tmp_dir_for_prepare_test" + +if [ -z "${HCOM_TEST_TOOL_PATH}" ]; then + echo "${HCOM_LOG_TAG} HCOM_TEST_TOOL_PATH is empty, set to default value." + HCOM_TEST_TOOL_PATH="${HCOM_ROOT_DIR}/dist/hcom_test_tools" +fi +echo "${HCOM_LOG_TAG} HCOM_TEST_TOOL_PATH: ${HCOM_TEST_TOOL_PATH}" +echo "${HCOM_LOG_TAG} TEST_TOOL_BUILD_DIR: ${TEST_TOOL_BUILD_DIR}" + +GTEST_INSTALL_PATH="${HCOM_TEST_TOOL_PATH}/googletest" +MOCKCPP_INSTALL_PATH="${HCOM_TEST_TOOL_PATH}/mockcpp" +SECODEFUZZ_INSTALL_PATH="${HCOM_TEST_TOOL_PATH}/secodefuzz" +echo "${HCOM_LOG_TAG} GTEST_INSTALL_PATH: ${GTEST_INSTALL_PATH}" +echo "${HCOM_LOG_TAG} MOCKCPP_INSTALL_PATH: ${MOCKCPP_INSTALL_PATH}" +echo "${HCOM_LOG_TAG} SECODEFUZZ_INSTALL_PATH: ${SECODEFUZZ_INSTALL_PATH}" + +# prepare test tool build dir +if [ -d "${TEST_TOOL_BUILD_DIR}" ]; then + rm -rf ${TEST_TOOL_BUILD_DIR} +fi +mkdir -p ${TEST_TOOL_BUILD_DIR} + +# prepare googletest +cd ${TEST_TOOL_BUILD_DIR} +git clone https://github.com/google/googletest.git +cd googletest +git checkout -b release-1.12.1 release-1.12.1 +mkdir build && cd build +cmake -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ \ + -DCMAKE_BUILD_TYPE=Debug -DCMAKE_INSTALL_PREFIX=${GTEST_INSTALL_PATH} -DINSTALL_GTEST=ON .. +make -j8 +make install +echo "${HCOM_LOG_TAG} googletest install to ${GTEST_INSTALL_PATH} success." + +# prepare mockcpp +cd ${TEST_TOOL_BUILD_DIR} +git clone https://github.com/sinojelly/mockcpp.git +cd mockcpp +git checkout -b mockcpp_arm v2.7 +git apply ${HCOM_ROOT_DIR}/test/external_libs/mockcpp_support_arm64.patch +mkdir build && cd build +cmake -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ \ + -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=${MOCKCPP_INSTALL_PATH} .. +make -j8 +make install +echo "${HCOM_LOG_TAG} mockcpp install to ${MOCKCPP_INSTALL_PATH} success." + +# prepare secodefuzz +cd ${TEST_TOOL_BUILD_DIR} +git clone https://codehub-dg-y.huawei.com/software-engineering-research-community/fuzz/secodefuzz.git +cd secodefuzz +git checkout -b v2.4.8 v2.4.8 +bash build.sh +mkdir -p "${SECODEFUZZ_INSTALL_PATH}/lib" +cp ./examples/out-bin-x64/out/* "${SECODEFUZZ_INSTALL_PATH}/lib" +cp ./examples/out-bin-x64/libSecodefuzz.a "${SECODEFUZZ_INSTALL_PATH}/lib" +mkdir -p "${SECODEFUZZ_INSTALL_PATH}/include/secodefuzz" +cp ./Secodefuzz/secodeFuzz.h "${SECODEFUZZ_INSTALL_PATH}/include/secodefuzz" +echo "${HCOM_LOG_TAG} secodefuzz install to ${SECODEFUZZ_INSTALL_PATH} success." + +# clean +rm -rf ${TEST_TOOL_BUILD_DIR} diff --git a/build/install_ubscomm_tar.sh b/build/install_ubscomm_tar.sh new file mode 100644 index 0000000000000000000000000000000000000000..2ce3d4be00dd29b215029abd4e97958b83c29359 --- /dev/null +++ b/build/install_ubscomm_tar.sh @@ -0,0 +1,44 @@ +#!/bin/bash +HCOM_PACKAGE_DIR=/opt/install/package/ +HCOM_PACKAGE_NAME=ubs_comm-hcom-1.1* +HCOM_INCLUDE_DIR=/usr/include/hcom +HCOM_LIB_DIR=/usr/lib64/ + +# 获取匹配的文件列表 +files=$(ls /opt/install/package/ubs_comm-hcom-1.1*.tar.gz 2>/dev/null) + +if [ -z "$files" ]; then + echo -e "\e[31m=== UbsComm安装失败, 没有找到对应的安装包 ===\e[0m" + exit 1 +else + # 只取第一个匹配的文件 + first_file=$(echo "$files" | head -n 1) +fi + +[[ -n "${HCOM_INCLUDE_DIR}" ]] && rm -rf "${HCOM_INCLUDE_DIR}" +[[ -n "${HCOM_PACKAGE_DIR}/${HCOM_PACKAGE_NAME}" ]] && rm -rf "${HCOM_PACKAGE_DIR}/${HCOM_PACKAGE_NAME}" + +if [ -f "${HCOM_LIB_DIR}libhcom.so" ]; then + rm -f "${HCOM_LIB_DIR}libhcom.so" +fi + +if [ -f "${HCOM_LIB_DIR}libhcom_static.a" ]; then + rm -f "${HCOM_LIB_DIR}libhcom_static.a" +fi + +tar -zxvf "$first_file" -C ${HCOM_PACKAGE_DIR}|| { + echo -e "\e[31m=== UbsComm安装失败, 解压失败 ===\e[0m" + exit 1 +} +cp ${HCOM_PACKAGE_DIR}${HCOM_PACKAGE_NAME}/hcom/lib/libhcom.so ${HCOM_LIB_DIR}|| { + echo -e "\e[31m=== UbsComm安装失败, 没有对应的动态库文件 ===\e[0m" + exit 1 +} +cp ${HCOM_PACKAGE_DIR}${HCOM_PACKAGE_NAME}/hcom/lib/libhcom_static.a ${HCOM_LIB_DIR}|| { + echo -e "\e[31m=== UbsComm安装失败, 没有对应的静态库文件 ===\e[0m" + exit 1 +} +cp -r ${HCOM_PACKAGE_DIR}${HCOM_PACKAGE_NAME}/hcom/include/hcom ${HCOM_INCLUDE_DIR}|| { + echo -e "\e[31m=== UbsComm安装失败, 没有对应的头文件 ===\e[0m" + exit 1 +} diff --git a/build/make_software_package.sh b/build/make_software_package.sh new file mode 100644 index 0000000000000000000000000000000000000000..d4f9b21d559ca3758f6d24338fa34363cc343ca2 --- /dev/null +++ b/build/make_software_package.sh @@ -0,0 +1,174 @@ +#!/bin/bash +# *********************************************************************** +# Copyright: (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# script for packaging hcom. +# Build options can be configured through environment variables. +# (1) HCOM_PRODUCT_NAME(optional) => product name(default BeiMing) +# (2) HCOM_PRODUCT_VERSION(optional) => product version(default 24.4) +# (3) HCOM_COMPONENT_VERSION(optional) => hcom version(default 1.0.0) +# (3) HCOM_PACKAGE_PATH(optional) => software package path(default ${HCOM_ROOT_DIR}/dist) +# version: 1.0.0 +# change log: +# *********************************************************************** +set -eo pipefail + +readonly HCOM_LOG_TAG="[$(basename ${0})]" +readonly CURRENT_SCRIPT_DIR=$(cd $(dirname ${0}) && pwd) +readonly HCOM_ROOT_DIR=$(dirname ${CURRENT_SCRIPT_DIR}) +readonly HCOM_INSTALL_DIR="${HCOM_ROOT_DIR}/dist/hcom" +readonly HCOM_INSTALL_TRACER_DIR="${HCOM_ROOT_DIR}/dist/hcom_3rdparty/hcom_tracer" +readonly HCOM_TRACER_TOOL="${HCOM_ROOT_DIR}/test/tools/hcom_tracer/build/htracer_cli" +readonly HCOM_COMPONENT_NAME="hcom" +readonly HCOM_BUILD_TIME=$(date "+%Y-%m-%d %Z") +readonly HCOM_BUILD_OS_TYPE=$(uname -s) +readonly HCOM_BUILD_OS_ARCH=$(uname -m) + +function show_help() { + echo "Usage: $0 [OPTION]" + echo "Build the project with specified options." + echo "Options:" + echo " -t, --type TYPE Set build type. debug/release" +} + +# 编译类型通过环境变量 HCOM_BUILD_TYPE 和命令行参数 -t 二选一,如果两者都提供了, +# 则优先使用命令行参数。 +HCOM_BUILD_TYPE="release" + +while [[ "$#" -gt 0 ]]; do + case "$1" in + -t|--type) HCOM_BUILD_TYPE="${2,,}"; shift ;; + *) echo "Unknown parameter passed: $1"; show_help; exit 1 ;; + esac + shift +done + +[[ ! -d "${HCOM_INSTALL_DIR}" ]] && echo "${HCOM_LOG_TAG} HCOM install directory(${HCOM_INSTALL_DIR}) does not exist." && exit 1 +echo "${HCOM_LOG_TAG} HCOM install directory: ${HCOM_INSTALL_DIR}" + +# **************************************** +# make HCOM software package +# **************************************** +[[ -z "${HCOM_PRODUCT_NAME}" ]] && HCOM_PRODUCT_NAME="BeiMing" +[[ -z "${HCOM_PRODUCT_VERSION}" ]] && HCOM_PRODUCT_VERSION="24.4" +[[ -z "${HCOM_COMPONENT_VERSION}" ]] && HCOM_COMPONENT_VERSION="1.0.0" +[[ -z "${HCOM_PACKAGE_PATH}" ]] && HCOM_PACKAGE_PATH="${HCOM_ROOT_DIR}/dist" +[[ -z "${ARCH}" ]] && ARCH="aarch64" +HCOM_COMPONENT_COMMIT_ID="" +if [ -d "${HCOM_ROOT_DIR}/.git" ] || (cd "${HCOM_ROOT_DIR}" && git rev-parse --is-inside-work-tree >/dev/null 2>&1); then + HCOM_COMPONENT_COMMIT_ID=$(cd "${HCOM_ROOT_DIR}" && git rev-parse HEAD 2>/dev/null) +fi +# prepare HCOM software package directory +# hcom is published by BoostKit +cd "${HCOM_PACKAGE_PATH}" + +if [[ -z "${OS}" || -z "${ARCH}" ]]; then + echo "${HCOM_LOG_TAG} env OS or env ARCH is empty!" + HCOM_PACKAGE_NAME="BoostKit-${HCOM_COMPONENT_NAME}_${HCOM_COMPONENT_VERSION}_${HCOM_BUILD_OS_ARCH}" +else + HCOM_PACKAGE_NAME="BoostKit-${HCOM_COMPONENT_NAME}_${HCOM_COMPONENT_VERSION}_${OS}_${ARCH}" +fi + +[[ -n "${HCOM_PACKAGE_NAME}" ]] && rm -rf "${HCOM_PACKAGE_NAME}" +mkdir -p "${HCOM_PACKAGE_NAME}" + +# drop securec +rm -rf "${HCOM_INSTALL_DIR}/lib/securec" + +# copy HCOM build dist +cp -r "${HCOM_INSTALL_DIR}" "${HCOM_PACKAGE_NAME}" + +# check whether build tools perf only release type, default is OFF +HCOM_BUILD_TOOLS_PERF=${HCOM_BUILD_TOOLS_PERF:-off} +if [[ "${HCOM_BUILD_TOOLS_PERF,,}" == "on" && "${HCOM_BUILD_TYPE,,}" == "release" ]]; then + bash "${HCOM_ROOT_DIR}/build/build_tools_perf.sh" + # copy hcom_perf + cp "${HCOM_ROOT_DIR}/test/tools/perf_test/build/hcom_perf" "${HCOM_PACKAGE_NAME}"/hcom/ + echo "${HCOM_LOG_TAG} hcom build tools perf success: ${HCOM_BUILD_TOOLS_PERF}" +fi + +# check whether enable htracer_cli, default is off. +#HCOM_BUILD_HTRACER_CLI="${HCOM_BUILD_HTRACER_CLI:-off}" +echo "${HCOM_LOG_TAG} hcom build htracer_cli: ${HCOM_BUILD_HTRACER}" +if [[ "${HCOM_BUILD_HTRACER,,}" == "on" ]]; then + bash "${HCOM_ROOT_DIR}/build/build_htracer_cli.sh" + + if [ ! -d "$HCOM_INSTALL_TRACER_DIR" ]; then + mkdir -p "$HCOM_INSTALL_TRACER_DIR" + fi + # copy htracer_cli to dist + cp "${HCOM_TRACER_TOOL}" "${HCOM_INSTALL_TRACER_DIR}"/ + + if [ ! -d "${HCOM_PACKAGE_NAME}/hcom/bin" ]; then + mkdir -p "${HCOM_PACKAGE_NAME}/hcom/bin" + fi + # copy htracer_cli software package + cp -r "${HCOM_TRACER_TOOL}" "${HCOM_PACKAGE_NAME}/hcom/bin" +fi + + +# generate version info +VERSION_FILE="${HCOM_PACKAGE_PATH}/${HCOM_PACKAGE_NAME}/version.property" +echo "# Copyright: (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + +# product info +product_name=@HCOM_PRODUCT_NAME@ +product_version=@HCOM_PRODUCT_VERSION@ + +# component info +component_name=@HCOM_COMPONENT_NAME@ +component_version=@HCOM_COMPONENT_VERSION@ +component_commit_id=@HCOM_COMPONENT_COMMIT_ID@ + +# build info +build_time=@HCOM_BUILD_TIME@ +build_ostype=@HCOM_BUILD_OS_TYPE@ +build_osarch=@HCOM_BUILD_OS_ARCH@ + +" > "${VERSION_FILE}" + +REQUIRED_VARS=("HCOM_PRODUCT_NAME" "HCOM_PRODUCT_VERSION" "HCOM_COMPONENT_NAME" "HCOM_COMPONENT_VERSION" + "HCOM_BUILD_TIME" "HCOM_BUILD_OS_TYPE" "HCOM_BUILD_OS_ARCH") +for var in "${REQUIRED_VARS[@]}"; do + [[ -z "${!var}" ]] && echo "${HCOM_LOG_TAG} missing environment: $var" && exit 1 + sed -i "s/@$var@/${!var}/g" "${VERSION_FILE}" +done +chmod 600 "${VERSION_FILE}" +echo "${HCOM_LOG_TAG} generate HCOM version info done" + +# make HCOM software package +tar -czf "${HCOM_PACKAGE_NAME}.tar.gz" --exclude *.debug* "${HCOM_PACKAGE_NAME}" +echo "${HCOM_LOG_TAG} make HCOM software package done.(${HCOM_PACKAGE_PATH}/${HCOM_PACKAGE_NAME}.tar.gz)" + +# check whether enable build rpm, default is ON. +if [[ "${HCOM_BUILD_RPM,,}" == "off" ]]; then + exit 0 +fi + +mkdir -p ~/rpmbuild/SOURCES/ +cp "${HCOM_PACKAGE_PATH}/${HCOM_PACKAGE_NAME}.tar.gz" ~/rpmbuild/SOURCES/ + +cd "${HCOM_ROOT_DIR}" + +# 定义基础的 rpmbuild 命令和公共参数 +base_rpmbuild_cmd="rpmbuild --define \"package_name ${HCOM_PACKAGE_NAME}\" -bb hcom.spec" + +# 添加特定于 Java SDK 的选项 +[[ "${HCOM_BUILD_JAVA_SDK}" == "ON" ]] && base_rpmbuild_cmd="${base_rpmbuild_cmd} --with java_compile" + +# 添加特定于 htracer_cli 的选项 +[[ "${HCOM_BUILD_HTRACER,,}" == "on" ]] && base_rpmbuild_cmd="${base_rpmbuild_cmd} --define \"_with_htracer_cli 1\"" + +# 根据构建类型添加调试信息选项 +[[ "${HCOM_BUILD_TYPE}" == "debug" ]] && base_rpmbuild_cmd="${base_rpmbuild_cmd} --define \"_build_type debug\"" + +# 根据是否需要性能工具包决定是否包含 hcom_perf +[[ "${HCOM_BUILD_TYPE}" == "release" && "${HCOM_BUILD_TOOLS_PERF}" == "ON" ]] && base_rpmbuild_cmd="${base_rpmbuild_cmd} --define \"_with_hcom_perf 1\"" + +# 执行最终的 rpmbuild 命令 +eval "$base_rpmbuild_cmd" + +if [[ "${HCOM_BUILD_TYPE,,}" == "debug" ]]; then + cp ~/rpmbuild/RPMS/${ARCH}/OCK-CommunicationSuite_HCOM_Debug-2.0.0-B099*.rpm "${HCOM_ROOT_DIR}/dist/OCK-CommunicationSuite_HCOM_Debug_2.0.0_${OS}-${ARCH}.rpm" +else + cp ~/rpmbuild/RPMS/${ARCH}/OCK-CommunicationSuite_HCOM-2.0.0-B099*.rpm "${HCOM_ROOT_DIR}/dist/OCK-CommunicationSuite_HCOM_2.0.0_${OS}-${ARCH}.rpm" +fi diff --git a/build/upload_lcov_report.sh b/build/upload_lcov_report.sh new file mode 100644 index 0000000000000000000000000000000000000000..75d44182fcaba66adc7c74581bdfd133316a8f4f --- /dev/null +++ b/build/upload_lcov_report.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# *********************************************************************** +# Copyright: (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# script for upload line coverage report +# version: 1.0.0 +# change log: +# *********************************************************************** +set -e + +readonly HCOM_LOG_TAG="[$(basename ${0})]" +CURRENT_SCRIPT_DIR=$(cd $(dirname ${0}) && pwd) +HCOM_ROOT_DIR=$(dirname ${CURRENT_SCRIPT_DIR}) + +echo ${CURRENT_SCRIPT_DIR} +echo ${HCOM_ROOT_DIR} + +cd ${HCOM_ROOT_DIR}/build/gcover_report +cp ../gtest_report.xml ./test_detail.xml +cp ../lcov_report_filterd.info ./coverage.info + +zip -r lcov.zip * +artget pull "ock_3rdparty ock3rdparty1.0" -ru software -user p_OckCI \ + -pwd encryption:ETMsDgAAAYgIefwyABFBRVMvR0NNL05vUGFkZGluZwCAABAAEBKGslaG2E1RnzCAiRGoekcAAAAqIwJz1WwrhJUvE4ohzMKYYtHPTBeTa7LlILcfVZJoOuQOYEmRgSMNt85UABQBhk4+/kX90aleLjjXzrA/G5tcGw== \ + -rp "hdfsutil.jar" -ap "./" +java -jar hdfsutil.jar -prod -upload lcov.zip ${upload_path}/lcov.zip diff --git a/config.yml b/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..e818ffcd63d9328c19d7f7b7b8264fd2c9df152e --- /dev/null +++ b/config.yml @@ -0,0 +1 @@ +# 构建规范说要有一个 config 文件... diff --git a/dependence.xml b/dependence.xml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/doc/LICENSE b/doc/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..33bec29d5181dfd57411f387bf7a0e60da1946d7 --- /dev/null +++ b/doc/LICENSE @@ -0,0 +1,427 @@ +Attribution-ShareAlike 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-ShareAlike 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-ShareAlike 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. BY-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + l. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + m. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part; and + + b. produce, reproduce, and Share Adapted Material. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + + including for purposes of Section 3(b); and + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. \ No newline at end of file diff --git a/doc/UBS-COMM-API-Spec.md b/doc/UBS-COMM-API-Spec.md new file mode 100644 index 0000000000000000000000000000000000000000..5cc8783187e071e70b01e7839e05acd3c6c775cd --- /dev/null +++ b/doc/UBS-COMM-API-Spec.md @@ -0,0 +1,9793 @@ +[TABLE] + +[TABLE] + +| 华为技术有限公司 | | +|------------------|---------------------------------------------| +| 地址: | 深圳市龙岗区坂田华为总部办公楼 邮编:518129 | +| 网址: | | +| 客户服务邮箱: | | +| 客户服务电话: | 4008302118 | + +[TABLE] + +# 前言 + +## 概述 + +本文档详细描述了UBS Comm对外提供的API接口信息,包括API接口参数解释和使用样例等内容。 + +## 读者对象 + +本文档主要适用于以下工程师: + +- 技术支持工程师 + +- 二次开发工程师 + +- 维护工程师 + +## 符号约定 + +在本文中可能出现下列标志,它们所代表的含义如下。 + +[TABLE] + +## 修改记录 + +[TABLE] + +# 目 录 + +[前言 [iii](#前言)](#前言) + +[1 介绍 [1](#介绍)](#介绍) + +[2 API 2.0 [2](#api-2.0)](#api-2.0) + +[2.1 基础API参考 [2](#基础api参考)](#基础api参考) + +[2.1.1 C++ API [2](#c-api)](#c-api) + +[2.1.1.1 服务层API [2](#服务层api)](#服务层api) + +[2.1.1.1.1 UBSHcomService::Create [2](#ubshcomservicecreate)](#ubshcomservicecreate) + +[2.1.1.1.2 UBSHcomService::Destroy [3](#ubshcomservicedestroy)](#ubshcomservicedestroy) + +[2.1.1.1.3 UBSHcomService::Bind [3](#ubshcomservicebind)](#ubshcomservicebind) + +[2.1.1.1.4 UBSHcomService::Start [4](#ubshcomservicestart)](#ubshcomservicestart) + +[2.1.1.1.5 UBSHcomService::Connect [5](#ubshcomserviceconnect)](#ubshcomserviceconnect) + +[2.1.1.1.6 UBSHcomService::Disconnect [5](#ubshcomservicedisconnect)](#ubshcomservicedisconnect) + +[2.1.1.1.7 UBSHcomService::RegisterMemoryRegion [6](#ubshcomserviceregistermemoryregion)](#ubshcomserviceregistermemoryregion) + +[2.1.1.1.8 UBSHcomService::DestroyMemoryRegion [7](#ubshcomservicedestroymemoryregion)](#ubshcomservicedestroymemoryregion) + +[2.1.1.1.9 UBSHcomService::RegisterChannelBrokenHandler [7](#ubshcomserviceregisterchannelbrokenhandler)](#ubshcomserviceregisterchannelbrokenhandler) + +[2.1.1.1.10 UBSHcomService::RegisterIdleHandler [8](#ubshcomserviceregisteridlehandler)](#ubshcomserviceregisteridlehandler) + +[2.1.1.1.11 UBSHcomService::RegisterRecvHandler [8](#ubshcomserviceregisterrecvhandler)](#ubshcomserviceregisterrecvhandler) + +[2.1.1.1.12 UBSHcomService::RegisterSendHandler [9](#ubshcomserviceregistersendhandler)](#ubshcomserviceregistersendhandler) + +[2.1.1.1.13 UBSHcomService::RegisterOneSideHandler [10](#ubshcomserviceregisteronesidehandler)](#ubshcomserviceregisteronesidehandler) + +[2.1.1.1.14 UBSHcomChannel::Send [10](#ubshcomchannelsend)](#ubshcomchannelsend) + +[2.1.1.1.15 UBSHcomChannel::Call [11](#ubshcomchannelcall)](#ubshcomchannelcall) + +[2.1.1.1.16 UBSHcomChannel::Reply [13](#ubshcomchannelreply)](#ubshcomchannelreply) + +[2.1.1.1.17 UBSHcomChannel::Get [14](#ubshcomchannelget)](#ubshcomchannelget) + +[2.1.1.1.18 UBSHcomChannel::Put [14](#ubshcomchannelput)](#ubshcomchannelput) + +[2.1.1.1.19 UBSHcomChannel::Recv [15](#ubshcomchannelrecv)](#ubshcomchannelrecv) + +[2.1.1.1.20 UBSHcomChannel::SetFlowControlConfig [16](#ubshcomchannelsetflowcontrolconfig)](#ubshcomchannelsetflowcontrolconfig) + +[2.1.1.1.21 UBSHcomChannel::SetChannelTimeOut [16](#ubshcomchannelsetchanneltimeout)](#ubshcomchannelsetchanneltimeout) + +[2.1.1.1.22 UBSHcomChannel::SetUBSHcomTwoSideThreshold [17](#ubshcomchannelsetubshcomtwosidethreshold)](#ubshcomchannelsetubshcomtwosidethreshold) + +[2.1.1.1.23 UBSHcomChannel::GetId [17](#ubshcomchannelgetid)](#ubshcomchannelgetid) + +[2.1.1.1.24 UBSHcomChannel::GetPeerConnectPayload [18](#ubshcomchannelgetpeerconnectpayload)](#ubshcomchannelgetpeerconnectpayload) + +[2.1.1.1.25 UBSHcomChannel::SetTraceId [18](#ubshcomchannelsettraceid)](#ubshcomchannelsettraceid) + +[2.1.1.1.26 UBSHcomServiceContext::Result [19](#ubshcomservicecontextresult)](#ubshcomservicecontextresult) + +[2.1.1.1.27 UBSHcomServiceContext::Channel [19](#ubshcomservicecontextchannel)](#ubshcomservicecontextchannel) + +[2.1.1.1.28 UBSHcomServiceContext::OpType [19](#ubshcomservicecontextoptype)](#ubshcomservicecontextoptype) + +[2.1.1.1.29 UBSHcomServiceContext::RspCtx [20](#ubshcomservicecontextrspctx)](#ubshcomservicecontextrspctx) + +[2.1.1.1.30 UBSHcomServiceContext::ErrorCode [20](#ubshcomservicecontexterrorcode)](#ubshcomservicecontexterrorcode) + +[2.1.1.1.31 UBSHcomServiceContext::OpCode [20](#ubshcomservicecontextopcode)](#ubshcomservicecontextopcode) + +[2.1.1.1.32 UBSHcomServiceContext::MessageData [21](#ubshcomservicecontextmessagedata)](#ubshcomservicecontextmessagedata) + +[2.1.1.1.33 UBSHcomServiceContext::MessageDataLen [21](#ubshcomservicecontextmessagedatalen)](#ubshcomservicecontextmessagedatalen) + +[2.1.1.1.34 UBSHcomServiceContext::Clone [21](#ubshcomservicecontextclone)](#ubshcomservicecontextclone) + +[2.1.1.1.35 UBSHcomServiceContext::IsTimeout [22](#ubshcomservicecontextistimeout)](#ubshcomservicecontextistimeout) + +[2.1.1.1.36 UBSHcomServiceContext::Invalidate [22](#ubshcomservicecontextinvalidate)](#ubshcomservicecontextinvalidate) + +[2.1.1.1.37 UBSHcomService::SetEnableMrCache [23](#ubshcomservicesetenablemrcache)](#ubshcomservicesetenablemrcache) + +[2.1.1.2 传输层API [23](#传输层api)](#传输层api) + +[2.1.1.2.1 UBSHcomNetDriver::Instance [23](#ubshcomnetdriverinstance)](#ubshcomnetdriverinstance) + +[2.1.1.2.2 UBSHcomNetDriver::DestroyInstance [24](#ubshcomnetdriverdestroyinstance)](#ubshcomnetdriverdestroyinstance) + +[2.1.1.2.3 UBSHcomNetDriver::LocalSupport [25](#ubshcomnetdriverlocalsupport)](#ubshcomnetdriverlocalsupport) + +[2.1.1.2.4 UBSHcomNetDriver::MultiRailGetDevCount [25](#ubshcomnetdrivermultirailgetdevcount)](#ubshcomnetdrivermultirailgetdevcount) + +[2.1.1.2.5 UBSHcomNetDriver::Initialize [26](#ubshcomnetdriverinitialize)](#ubshcomnetdriverinitialize) + +[2.1.1.2.6 UBSHcomNetDriver::UnInitialize [26](#ubshcomnetdriveruninitialize)](#ubshcomnetdriveruninitialize) + +[2.1.1.2.7 UBSHcomNetDriver::Start [27](#ubshcomnetdriverstart)](#ubshcomnetdriverstart) + +[2.1.1.2.8 UBSHcomNetDriver::Stop [27](#ubshcomnetdriverstop)](#ubshcomnetdriverstop) + +[2.1.1.2.9 UBSHcomNetDriver::CreateMemoryRegion [28](#ubshcomnetdrivercreatememoryregion)](#ubshcomnetdrivercreatememoryregion) + +[2.1.1.2.10 UBSHcomNetDriver::DestroyMemoryRegion [28](#ubshcomnetdriverdestroymemoryregion)](#ubshcomnetdriverdestroymemoryregion) + +[2.1.1.2.11 UBSHcomNetDriver::Connect [29](#ubshcomnetdriverconnect)](#ubshcomnetdriverconnect) + +[2.1.1.2.12 UBSHcomNetDriver::DestroyEndpoint [31](#ubshcomnetdriverdestroyendpoint)](#ubshcomnetdriverdestroyendpoint) + +[2.1.1.2.13 UBSHcomNetDriver::OobIpAndPort [31](#ubshcomnetdriveroobipandport)](#ubshcomnetdriveroobipandport) + +[2.1.1.2.14 UBSHcomNetDriver::GetOobIpAndPort [32](#ubshcomnetdrivergetoobipandport)](#ubshcomnetdrivergetoobipandport) + +[2.1.1.2.15 UBSHcomNetDriver::AddOobOptions [32](#ubshcomnetdriveraddooboptions)](#ubshcomnetdriveraddooboptions) + +[2.1.1.2.16 UBSHcomNetDriver::OobUdsName [33](#ubshcomnetdriveroobudsname)](#ubshcomnetdriveroobudsname) + +[2.1.1.2.17 UBSHcomNetDriver::AddOobUdsOptions [34](#ubshcomnetdriveraddoobudsoptions)](#ubshcomnetdriveraddoobudsoptions) + +[2.1.1.2.18 UBSHcomNetDriver::RegisterNewEPHandler [34](#ubshcomnetdriverregisternewephandler)](#ubshcomnetdriverregisternewephandler) + +[2.1.1.2.19 UBSHcomNetDriver::RegisterEPBrokenHandler [35](#ubshcomnetdriverregisterepbrokenhandler)](#ubshcomnetdriverregisterepbrokenhandler) + +[2.1.1.2.20 UBSHcomNetDriver::RegisterNewReqHandler [36](#ubshcomnetdriverregisternewreqhandler)](#ubshcomnetdriverregisternewreqhandler) + +[2.1.1.2.21 UBSHcomNetDriver::RegisterReqPostedHandler [36](#ubshcomnetdriverregisterreqpostedhandler)](#ubshcomnetdriverregisterreqpostedhandler) + +[2.1.1.2.22 UBSHcomNetDriver::RegisterOneSideDoneHandler [37](#ubshcomnetdriverregisteronesidedonehandler)](#ubshcomnetdriverregisteronesidedonehandler) + +[2.1.1.2.23 UBSHcomNetDriver::RegisterIdleHandler [38](#ubshcomnetdriverregisteridlehandler)](#ubshcomnetdriverregisteridlehandler) + +[2.1.1.2.24 UBSHcomNetDriver::Name [39](#ubshcomnetdrivername)](#ubshcomnetdrivername) + +[2.1.1.2.25 UBSHcomNetDriver::GetId [39](#ubshcomnetdrivergetid)](#ubshcomnetdrivergetid) + +[2.1.1.2.26 UBSHcomNetDriver::Protocol [39](#ubshcomnetdriverprotocol)](#ubshcomnetdriverprotocol) + +[2.1.1.2.27 UBSHcomNetDriver::IsStarted [40](#ubshcomnetdriverisstarted)](#ubshcomnetdriverisstarted) + +[2.1.1.2.28 UBSHcomNetDriver::IsInited [40](#ubshcomnetdriverisinited)](#ubshcomnetdriverisinited) + +[2.1.1.2.29 UBSHcomNetDriver::NetUid [41](#ubshcomnetdrivernetuid)](#ubshcomnetdrivernetuid) + +[2.1.1.2.30 UBSHcomNetDriver::DumpObjectStatistics [41](#ubshcomnetdriverdumpobjectstatistics)](#ubshcomnetdriverdumpobjectstatistics) + +[2.1.1.2.31 UBSHcomNetDriver::SetPeerDevId [41](#ubshcomnetdriversetpeerdevid)](#ubshcomnetdriversetpeerdevid) + +[2.1.1.2.32 UBSHcomNetDriver::GetPeerDevId [42](#ubshcomnetdrivergetpeerdevid)](#ubshcomnetdrivergetpeerdevid) + +[2.1.1.2.33 UBSHcomNetDriver::SetDeviceId [42](#ubshcomnetdriversetdeviceid)](#ubshcomnetdriversetdeviceid) + +[2.1.1.2.34 UBSHcomNetDriver::GetDeviceId [43](#ubshcomnetdrivergetdeviceid)](#ubshcomnetdrivergetdeviceid) + +[2.1.1.2.35 UBSHcomNetDriver::GetBandWidth [43](#ubshcomnetdrivergetbandwidth)](#ubshcomnetdrivergetbandwidth) + +[2.1.1.2.36 UBSHcomNetDriver::OobEidAndJettyId [43](#ubshcomnetdriveroobeidandjettyid)](#ubshcomnetdriveroobeidandjettyid) + +[2.1.1.2.37 UBSHcomNetEndpoint::SetEpOption [44](#ubshcomnetendpointsetepoption)](#ubshcomnetendpointsetepoption) + +[2.1.1.2.38 UBSHcomNetEndpoint::GetSendQueueCount [44](#ubshcomnetendpointgetsendqueuecount)](#ubshcomnetendpointgetsendqueuecount) + +[2.1.1.2.39 UBSHcomNetEndpoint::Id [45](#ubshcomnetendpointid)](#ubshcomnetendpointid) + +[2.1.1.2.40 UBSHcomNetEndpoint::WorkerIndex [45](#ubshcomnetendpointworkerindex)](#ubshcomnetendpointworkerindex) + +[2.1.1.2.41 UBSHcomNetEndpoint::IsEstablished [45](#ubshcomnetendpointisestablished)](#ubshcomnetendpointisestablished) + +[2.1.1.2.42 UBSHcomNetEndpoint::UpCtx [46](#ubshcomnetendpointupctx)](#ubshcomnetendpointupctx) + +[2.1.1.2.43 UBSHcomNetEndpoint::UpCtx [46](#ubshcomnetendpointupctx-1)](#ubshcomnetendpointupctx-1) + +[2.1.1.2.44 UBSHcomNetEndpoint::PeerConnectPayload [47](#ubshcomnetendpointpeerconnectpayload)](#ubshcomnetendpointpeerconnectpayload) + +[2.1.1.2.45 UBSHcomNetEndpoint::LocalIp [47](#ubshcomnetendpointlocalip)](#ubshcomnetendpointlocalip) + +[2.1.1.2.46 UBSHcomNetEndpoint::ListenPort [47](#ubshcomnetendpointlistenport)](#ubshcomnetendpointlistenport) + +[2.1.1.2.47 UBSHcomNetEndpoint::Version [48](#ubshcomnetendpointversion)](#ubshcomnetendpointversion) + +[2.1.1.2.48 UBSHcomNetEndpoint::State [48](#ubshcomnetendpointstate)](#ubshcomnetendpointstate) + +[2.1.1.2.49 UBSHcomNetEndpoint::PeerIpAndPort [48](#ubshcomnetendpointpeeripandport)](#ubshcomnetendpointpeeripandport) + +[2.1.1.2.50 UBSHcomNetEndpoint::UdsName [49](#ubshcomnetendpointudsname)](#ubshcomnetendpointudsname) + +[2.1.1.2.51 UBSHcomNetEndpoint::PostSend [49](#ubshcomnetendpointpostsend)](#ubshcomnetendpointpostsend) + +[2.1.1.2.52 UBSHcomNetEndpoint::PostSendRaw [50](#ubshcomnetendpointpostsendraw)](#ubshcomnetendpointpostsendraw) + +[2.1.1.2.53 UBSHcomNetEndpoint::PostRead [51](#ubshcomnetendpointpostread)](#ubshcomnetendpointpostread) + +[2.1.1.2.54 UBSHcomNetEndpoint::PostWrite [52](#ubshcomnetendpointpostwrite)](#ubshcomnetendpointpostwrite) + +[2.1.1.2.55 UBSHcomNetEndpoint::DefaultTimeout [52](#ubshcomnetendpointdefaulttimeout)](#ubshcomnetendpointdefaulttimeout) + +[2.1.1.2.56 UBSHcomNetEndpoint::WaitCompletion [53](#ubshcomnetendpointwaitcompletion)](#ubshcomnetendpointwaitcompletion) + +[2.1.1.2.57 UBSHcomNetEndpoint::Receive [53](#ubshcomnetendpointreceive)](#ubshcomnetendpointreceive) + +[2.1.1.2.58 UBSHcomNetEndpoint::ReceiveRaw [54](#ubshcomnetendpointreceiveraw)](#ubshcomnetendpointreceiveraw) + +[2.1.1.2.59 UBSHcomNetEndpoint::GetRemoteUdsIdInfo [55](#ubshcomnetendpointgetremoteudsidinfo)](#ubshcomnetendpointgetremoteudsidinfo) + +[2.1.1.2.60 UBSHcomNetEndpoint::GetPeerIpPort [55](#ubshcomnetendpointgetpeeripport)](#ubshcomnetendpointgetpeeripport) + +[2.1.1.2.61 UBSHcomNetEndpoint::Close [56](#ubshcomnetendpointclose)](#ubshcomnetendpointclose) + +[2.1.1.2.62 UBSHcomNetEndpoint::GetDevIndex [56](#ubshcomnetendpointgetdevindex)](#ubshcomnetendpointgetdevindex) + +[2.1.1.2.63 UBSHcomNetEndpoint::GetPeerDevIndex [56](#ubshcomnetendpointgetpeerdevindex)](#ubshcomnetendpointgetpeerdevindex) + +[2.1.1.2.64 UBSHcomNetEndpoint::GetBandWidth [57](#ubshcomnetendpointgetbandwidth)](#ubshcomnetendpointgetbandwidth) + +[2.1.1.2.65 UBSHcomNetMessage::DataLen [57](#ubshcomnetmessagedatalen)](#ubshcomnetmessagedatalen) + +[2.1.1.2.66 UBSHcomNetMessage::Data [58](#ubshcomnetmessagedata)](#ubshcomnetmessagedata) + +[2.1.1.2.67 UBSHcomNetRequestContext::EndPoint [58](#ubshcomnetrequestcontextendpoint)](#ubshcomnetrequestcontextendpoint) + +[2.1.1.2.68 UBSHcomNetRequestContext::Result [58](#ubshcomnetrequestcontextresult)](#ubshcomnetrequestcontextresult) + +[2.1.1.2.69 UBSHcomNetRequestContext::Header [59](#ubshcomnetrequestcontextheader)](#ubshcomnetrequestcontextheader) + +[2.1.1.2.70 UBSHcomNetRequestContext::Message [59](#ubshcomnetrequestcontextmessage)](#ubshcomnetrequestcontextmessage) + +[2.1.1.2.71 UBSHcomNetRequestContext::OpType [59](#ubshcomnetrequestcontextoptype)](#ubshcomnetrequestcontextoptype) + +[2.1.1.2.72 UBSHcomNetRequestContext::OriginalRequest [60](#ubshcomnetrequestcontextoriginalrequest)](#ubshcomnetrequestcontextoriginalrequest) + +[2.1.1.2.73 UBSHcomNetRequestContext::OriginalSgeRequest [60](#ubshcomnetrequestcontextoriginalsgerequest)](#ubshcomnetrequestcontextoriginalsgerequest) + +[2.1.1.2.74 UBSHcomNetRequestContext::SafeClone [60](#ubshcomnetrequestcontextsafeclone)](#ubshcomnetrequestcontextsafeclone) + +[2.1.1.2.75 UBSHcomNetResponseContext::Header [61](#ubshcomnetresponsecontextheader)](#ubshcomnetresponsecontextheader) + +[2.1.1.2.76 UBSHcomNetResponseContext::Message [61](#ubshcomnetresponsecontextmessage)](#ubshcomnetresponsecontextmessage) + +[2.1.1.2.77 UBSHcomNetMemoryRegion::GetLKey [62](#ubshcomnetmemoryregiongetlkey)](#ubshcomnetmemoryregiongetlkey) + +[2.1.1.2.78 UBSHcomNetMemoryRegion::GetAddress [62](#ubshcomnetmemoryregiongetaddress)](#ubshcomnetmemoryregiongetaddress) + +[2.1.1.2.79 UBSHcomNetMemoryRegion::Size [62](#ubshcomnetmemoryregionsize)](#ubshcomnetmemoryregionsize) + +[2.1.1.2.80 UBSHcomNetMemoryAllocator::Create [63](#ubshcomnetmemoryallocatorcreate)](#ubshcomnetmemoryallocatorcreate) + +[2.1.1.2.81 UBSHcomNetMemoryAllocator::MrKey [63](#ubshcomnetmemoryallocatormrkey)](#ubshcomnetmemoryallocatormrkey) + +[2.1.1.2.82 UBSHcomNetMemoryAllocator::MrKey [64](#ubshcomnetmemoryallocatormrkey-1)](#ubshcomnetmemoryallocatormrkey-1) + +[2.1.1.2.83 UBSHcomNetMemoryAllocator::MemOffset [64](#ubshcomnetmemoryallocatormemoffset)](#ubshcomnetmemoryallocatormemoffset) + +[2.1.1.2.84 UBSHcomNetMemoryAllocator::FreeSize [65](#ubshcomnetmemoryallocatorfreesize)](#ubshcomnetmemoryallocatorfreesize) + +[2.1.1.2.85 UBSHcomNetMemoryAllocator::Allocate [65](#ubshcomnetmemoryallocatorallocate)](#ubshcomnetmemoryallocatorallocate) + +[2.1.1.2.86 UBSHcomNetMemoryAllocator::Free [66](#ubshcomnetmemoryallocatorfree)](#ubshcomnetmemoryallocatorfree) + +[2.1.1.2.87 UBSHcomNetMemoryAllocator::Destroy [66](#ubshcomnetmemoryallocatordestroy)](#ubshcomnetmemoryallocatordestroy) + +[2.1.1.2.88 UBSHcomNetMemoryAllocator::GetTargetSeg [67](#ubshcomnetmemoryallocatorgettargetseg)](#ubshcomnetmemoryallocatorgettargetseg) + +[2.1.1.2.89 UBSHcomNetMemoryAllocator::SetTargetSeg [67](#ubshcomnetmemoryallocatorsettargetseg)](#ubshcomnetmemoryallocatorsettargetseg) + +[2.1.1.2.90 UBSHcomNetMemoryAllocatorTypeToString [67](#ubshcomnetmemoryallocatortypetostring)](#ubshcomnetmemoryallocatortypetostring) + +[2.1.1.2.91 UBSHcomNetDriverProtocolToString [68](#ubshcomnetdriverprotocoltostring)](#ubshcomnetdriverprotocoltostring) + +[2.1.1.2.92 UBSHcomNetDriverSecTypeToString [68](#ubshcomnetdriversectypetostring)](#ubshcomnetdriversectypetostring) + +[2.1.1.2.93 UBSHcomNetDriverOobTypeToString [69](#ubshcomnetdriveroobtypetostring)](#ubshcomnetdriveroobtypetostring) + +[2.1.1.2.94 UBSHcomNetDriverLBPolicyToString [69](#ubshcomnetdriverlbpolicytostring)](#ubshcomnetdriverlbpolicytostring) + +[2.1.1.2.95 UBSHcomNEPStateToString [70](#ubshcomnepstatetostring)](#ubshcomnepstatetostring) + +[2.1.2 C API [71](#c-api-1)](#c-api-1) + +[2.1.2.1 服务层API [71](#服务层api-1)](#服务层api-1) + +[2.1.2.1.1 ubs_hcom_service_create [71](#ubs_hcom_service_create)](#ubs_hcom_service_create) + +[2.1.2.1.2 ubs_hcom_service_bind [71](#ubs_hcom_service_bind)](#ubs_hcom_service_bind) + +[2.1.2.1.3 ubs_hcom_service_start [72](#ubs_hcom_service_start)](#ubs_hcom_service_start) + +[2.1.2.1.4 ubs_hcom_service_destroy [73](#ubs_hcom_service_destroy)](#ubs_hcom_service_destroy) + +[2.1.2.1.5 ubs_hcom_service_connect [73](#ubs_hcom_service_connect)](#ubs_hcom_service_connect) + +[2.1.2.1.6 ubs_hcom_service_disconnect [74](#ubs_hcom_service_disconnect)](#ubs_hcom_service_disconnect) + +[2.1.2.1.7 ubs_hcom_service_register_memory_region [75](#ubs_hcom_service_register_memory_region)](#ubs_hcom_service_register_memory_region) + +[2.1.2.1.8 ubs_hcom_service_get_memory_region_info [75](#ubs_hcom_service_get_memory_region_info)](#ubs_hcom_service_get_memory_region_info) + +[2.1.2.1.9 ubs_hcom_service_register_assign_memory_region [76](#ubs_hcom_service_register_assign_memory_region)](#ubs_hcom_service_register_assign_memory_region) + +[2.1.2.1.10 ubs_hcom_service_destroy_memory_region [77](#ubs_hcom_service_destroy_memory_region)](#ubs_hcom_service_destroy_memory_region) + +[2.1.2.1.11 ubs_hcom_service_register_broken_handler [77](#ubs_hcom_service_register_broken_handler)](#ubs_hcom_service_register_broken_handler) + +[2.1.2.1.12 ubs_hcom_service_register_idle_handler [78](#ubs_hcom_service_register_idle_handler)](#ubs_hcom_service_register_idle_handler) + +[2.1.2.1.13 ubs_hcom_service_register_handler [79](#ubs_hcom_service_register_handler)](#ubs_hcom_service_register_handler) + +[2.1.2.1.14 ubs_hcom_service_set_enable_mrcache [80](#ubs_hcom_service_set_enable_mrcache)](#ubs_hcom_service_set_enable_mrcache) + +[2.1.2.1.15 ubs_hcom_channel_refer [81](#ubs_hcom_channel_refer)](#ubs_hcom_channel_refer) + +[2.1.2.1.16 ubs_hcom_channel_derefer [81](#ubs_hcom_channel_derefer)](#ubs_hcom_channel_derefer) + +[2.1.2.1.17 ubs_hcom_channel_send [82](#ubs_hcom_channel_send)](#ubs_hcom_channel_send) + +[2.1.2.1.18 ubs_hcom_channel_call [82](#ubs_hcom_channel_call)](#ubs_hcom_channel_call) + +[2.1.2.1.19 ubs_hcom_channel_reply [84](#ubs_hcom_channel_reply)](#ubs_hcom_channel_reply) + +[2.1.2.1.20 ubs_hcom_channel_put [84](#ubs_hcom_channel_put)](#ubs_hcom_channel_put) + +[2.1.2.1.21 ubs_hcom_channel_get [85](#ubs_hcom_channel_get)](#ubs_hcom_channel_get) + +[2.1.2.1.22 ubs_hcom_channel_recv [86](#ubs_hcom_channel_recv)](#ubs_hcom_channel_recv) + +[2.1.2.1.23 ubs_hcom_channel_set_flowctl_cfg [86](#ubs_hcom_channel_set_flowctl_cfg)](#ubs_hcom_channel_set_flowctl_cfg) + +[2.1.2.1.24 ubs_hcom_channel_set_timeout [87](#ubs_hcom_channel_set_timeout)](#ubs_hcom_channel_set_timeout) + +[2.1.2.1.25 ubs_hcom_channel_set_twoside_threshold [87](#ubs_hcom_channel_set_twoside_threshold)](#ubs_hcom_channel_set_twoside_threshold) + +[2.1.2.1.26 Channel_Close [88](#channel_close)](#channel_close) + +[2.1.2.1.27 ubs_hcom_channel_get_id [89](#ubs_hcom_channel_get_id)](#ubs_hcom_channel_get_id) + +[2.1.2.1.28 ubs_hcom_context_get_channel [89](#ubs_hcom_context_get_channel)](#ubs_hcom_context_get_channel) + +[2.1.2.1.29 ubs_hcom_context_get_type [90](#ubs_hcom_context_get_type)](#ubs_hcom_context_get_type) + +[2.1.2.1.30 ubs_hcom_context_get_result [90](#ubs_hcom_context_get_result)](#ubs_hcom_context_get_result) + +[2.1.2.1.31 ubs_hcom_context_get_rspctx [91](#ubs_hcom_context_get_rspctx)](#ubs_hcom_context_get_rspctx) + +[2.1.2.1.32 ubs_hcom_context_get_opcode [91](#ubs_hcom_context_get_opcode)](#ubs_hcom_context_get_opcode) + +[2.1.2.1.33 ubs_hcom_context_get_data [92](#ubs_hcom_context_get_data)](#ubs_hcom_context_get_data) + +[2.1.2.1.34 ubs_hcom_context_get_datalen [92](#ubs_hcom_context_get_datalen)](#ubs_hcom_context_get_datalen) + +[2.1.2.2 传输层API [93](#传输层api-1)](#传输层api-1) + +[2.1.2.2.1 ubs_hcom_driver_create [93](#ubs_hcom_driver_create)](#ubs_hcom_driver_create) + +[2.1.2.2.2 ubs_hcom_driver_set_ipport [94](#ubs_hcom_driver_set_ipport)](#ubs_hcom_driver_set_ipport) + +[2.1.2.2.3 ubs_hcom_driver_get_ipport [94](#ubs_hcom_driver_get_ipport)](#ubs_hcom_driver_get_ipport) + +[2.1.2.2.4 ubs_hcom_driver_set_udsname [95](#ubs_hcom_driver_set_udsname)](#ubs_hcom_driver_set_udsname) + +[2.1.2.2.5 ubs_hcom_driver_add_uds_opt [95](#ubs_hcom_driver_add_uds_opt)](#ubs_hcom_driver_add_uds_opt) + +[2.1.2.2.6 ubs_hcom_driver_add_oob_opt [96](#ubs_hcom_driver_add_oob_opt)](#ubs_hcom_driver_add_oob_opt) + +[2.1.2.2.7 ubs_hcom_driver_initizalize [97](#ubs_hcom_driver_initizalize)](#ubs_hcom_driver_initizalize) + +[2.1.2.2.8 ubs_hcom_driver_start [97](#ubs_hcom_driver_start)](#ubs_hcom_driver_start) + +[2.1.2.2.9 ubs_hcom_driver_connect [98](#ubs_hcom_driver_connect)](#ubs_hcom_driver_connect) + +[2.1.2.2.10 ubs_hcom_driver_stop [99](#ubs_hcom_driver_stop)](#ubs_hcom_driver_stop) + +[2.1.2.2.11 ubs_hcom_driver_uninitialize [99](#ubs_hcom_driver_uninitialize)](#ubs_hcom_driver_uninitialize) + +[2.1.2.2.12 ubs_hcom_driver_destroy [100](#ubs_hcom_driver_destroy)](#ubs_hcom_driver_destroy) + +[2.1.2.2.13 ubs_hcom_driver_register_ep_handler [100](#ubs_hcom_driver_register_ep_handler)](#ubs_hcom_driver_register_ep_handler) + +[2.1.2.2.14 ubs_hcom_driver_register_op_handler [101](#ubs_hcom_driver_register_op_handler)](#ubs_hcom_driver_register_op_handler) + +[2.1.2.2.15 ubs_hcom_driver_register_idle_handler [102](#ubs_hcom_driver_register_idle_handler)](#ubs_hcom_driver_register_idle_handler) + +[2.1.2.2.16 ubs_hcom_driver_register_secinfo_provider [103](#ubs_hcom_driver_register_secinfo_provider)](#ubs_hcom_driver_register_secinfo_provider) + +[2.1.2.2.17 ubs_hcom_driver_register_secinfo_validator [104](#ubs_hcom_driver_register_secinfo_validator)](#ubs_hcom_driver_register_secinfo_validator) + +[2.1.2.2.18 ubs_hcom_driver_unregister_ep_handler [104](#ubs_hcom_driver_unregister_ep_handler)](#ubs_hcom_driver_unregister_ep_handler) + +[2.1.2.2.19 ubs_hcom_driver_unregister_op_handler [105](#ubs_hcom_driver_unregister_op_handler)](#ubs_hcom_driver_unregister_op_handler) + +[2.1.2.2.20 ubs_hcom_driver_unregister_idle_handler [106](#ubs_hcom_driver_unregister_idle_handler)](#ubs_hcom_driver_unregister_idle_handler) + +[2.1.2.2.21 ubs_hcom_driver_create_memory_region [106](#ubs_hcom_driver_create_memory_region)](#ubs_hcom_driver_create_memory_region) + +[2.1.2.2.22 ubs_hcom_driver_create_assign_memory_region [107](#ubs_hcom_driver_create_assign_memory_region)](#ubs_hcom_driver_create_assign_memory_region) + +[2.1.2.2.23 ubs_hcom_driver_destroy_memory_region [107](#ubs_hcom_driver_destroy_memory_region)](#ubs_hcom_driver_destroy_memory_region) + +[2.1.2.2.24 ubs_hcom_driver_get_memory_region_info [108](#ubs_hcom_driver_get_memory_region_info)](#ubs_hcom_driver_get_memory_region_info) + +[2.1.2.2.25 ubs_hcom_ep_set_context [108](#ubs_hcom_ep_set_context)](#ubs_hcom_ep_set_context) + +[2.1.2.2.26 ubs_hcom_ep_get_context [109](#ubs_hcom_ep_get_context)](#ubs_hcom_ep_get_context) + +[2.1.2.2.27 ubs_hcom_ep_get_worker_idx [109](#ubs_hcom_ep_get_worker_idx)](#ubs_hcom_ep_get_worker_idx) + +[2.1.2.2.28 ubs_hcom_ep_get_workergroup_idx [110](#ubs_hcom_ep_get_workergroup_idx)](#ubs_hcom_ep_get_workergroup_idx) + +[2.1.2.2.29 ubs_hcom_ep_get_listen_port [110](#ubs_hcom_ep_get_listen_port)](#ubs_hcom_ep_get_listen_port) + +[2.1.2.2.30 ubs_hcom_ep_version [111](#ubs_hcom_ep_version)](#ubs_hcom_ep_version) + +[2.1.2.2.31 ubs_hcom_ep_set_timeout [111](#ubs_hcom_ep_set_timeout)](#ubs_hcom_ep_set_timeout) + +[2.1.2.2.32 ubs_hcom_ep_post_send [112](#ubs_hcom_ep_post_send)](#ubs_hcom_ep_post_send) + +[2.1.2.2.33 ubs_hcom_ep_post_send_with_opinfo [113](#ubs_hcom_ep_post_send_with_opinfo)](#ubs_hcom_ep_post_send_with_opinfo) + +[2.1.2.2.34 ubs_hcom_ep_post_send_with_seqno [113](#ubs_hcom_ep_post_send_with_seqno)](#ubs_hcom_ep_post_send_with_seqno) + +[2.1.2.2.35 ubs_hcom_ep_post_read [114](#ubs_hcom_ep_post_read)](#ubs_hcom_ep_post_read) + +[2.1.2.2.36 ubs_hcom_ep_post_write [115](#ubs_hcom_ep_post_write)](#ubs_hcom_ep_post_write) + +[2.1.2.2.37 ubs_hcom_ep_wait_completion [115](#ubs_hcom_ep_wait_completion)](#ubs_hcom_ep_wait_completion) + +[2.1.2.2.38 ubs_hcom_ep_receive [116](#ubs_hcom_ep_receive)](#ubs_hcom_ep_receive) + +[2.1.2.2.39 ubs_hcom_ep_refer [116](#ubs_hcom_ep_refer)](#ubs_hcom_ep_refer) + +[2.1.2.2.40 ubs_hcom_ep_close [117](#ubs_hcom_ep_close)](#ubs_hcom_ep_close) + +[2.1.2.2.41 ubs_hcom_ep_destroy [117](#ubs_hcom_ep_destroy)](#ubs_hcom_ep_destroy) + +[2.1.2.2.42 ubs_hcom_err_str [118](#ubs_hcom_err_str)](#ubs_hcom_err_str) + +[2.1.2.2.43 ubs_hcom_mem_allocator_create [119](#ubs_hcom_mem_allocator_create)](#ubs_hcom_mem_allocator_create) + +[2.1.2.2.44 ubs_hcom_mem_allocator_destroy [119](#ubs_hcom_mem_allocator_destroy)](#ubs_hcom_mem_allocator_destroy) + +[2.1.2.2.45 ubs_hcom_mem_allocator_set_mr_key [120](#ubs_hcom_mem_allocator_set_mr_key)](#ubs_hcom_mem_allocator_set_mr_key) + +[2.1.2.2.46 ubs_hcom_mem_allocator_get_offset [120](#ubs_hcom_mem_allocator_get_offset)](#ubs_hcom_mem_allocator_get_offset) + +[2.1.2.2.47 ubs_hcom_mem_allocator_get_free_size [121](#ubs_hcom_mem_allocator_get_free_size)](#ubs_hcom_mem_allocator_get_free_size) + +[2.1.2.2.48 ubs_hcom_mem_allocator_allocate [122](#ubs_hcom_mem_allocator_allocate)](#ubs_hcom_mem_allocator_allocate) + +[2.1.2.2.49 ubs_hcom_mem_allocator_free [122](#ubs_hcom_mem_allocator_free)](#ubs_hcom_mem_allocator_free) + +[2.1.2.2.50 ubs_hcom_set_log_handler [123](#ubs_hcom_set_log_handler)](#ubs_hcom_set_log_handler) + +[2.1.2.2.51 ubs_hcom_check_local_supporr [123](#ubs_hcom_check_local_supporr)](#ubs_hcom_check_local_supporr) + +[2.1.2.2.52 ubs_hcom_get_remote_uds_info [124](#ubs_hcom_get_remote_uds_info)](#ubs_hcom_get_remote_uds_info) + +[2.2 高级API参考 [125](#高级api参考)](#高级api参考) + +[2.2.1 C++API [125](#capi)](#capi) + +[2.2.1.1 服务层 [125](#服务层)](#服务层) + +[2.2.1.1.1 UBSHcomService::AddWorkerGroup [125](#ubshcomserviceaddworkergroup)](#ubshcomserviceaddworkergroup) + +[2.2.1.1.2 UBSHcomService::AddListener [126](#ubshcomserviceaddlistener)](#ubshcomserviceaddlistener) + +[2.2.1.1.3 UBSHcomService::SetConnectLBPolicy [126](#ubshcomservicesetconnectlbpolicy)](#ubshcomservicesetconnectlbpolicy) + +[2.2.1.1.4 UBSHcomService::SetUBSHcomTlsOptions [127](#ubshcomservicesetubshcomtlsoptions)](#ubshcomservicesetubshcomtlsoptions) + +[2.2.1.1.5 UBSHcomService::SetConnSecureOpt [127](#ubshcomservicesetconnsecureopt)](#ubshcomservicesetconnsecureopt) + +[2.2.1.1.6 UBSHcomService::SetTcpUserTimeOutSec [128](#ubshcomservicesettcpusertimeoutsec)](#ubshcomservicesettcpusertimeoutsec) + +[2.2.1.1.7 UBSHcomService::SetTcpSendZCopy [128](#ubshcomservicesettcpsendzcopy)](#ubshcomservicesettcpsendzcopy) + +[2.2.1.1.8 UBSHcomService::SetDeviceIpMask [129](#ubshcomservicesetdeviceipmask)](#ubshcomservicesetdeviceipmask) + +[2.2.1.1.9 UBSHcomService::SetDeviceIpGroups [129](#ubshcomservicesetdeviceipgroups)](#ubshcomservicesetdeviceipgroups) + +[2.2.1.1.10 UBSHcomService::SetCompletionQueueDepth [130](#ubshcomservicesetcompletionqueuedepth)](#ubshcomservicesetcompletionqueuedepth) + +[2.2.1.1.11 UBSHcomService::SetSendQueueSize [130](#ubshcomservicesetsendqueuesize)](#ubshcomservicesetsendqueuesize) + +[2.2.1.1.12 UBSHcomService::SetRecvQueueSize [131](#ubshcomservicesetrecvqueuesize)](#ubshcomservicesetrecvqueuesize) + +[2.2.1.1.13 UBSHcomService::SetPollingBatchSize [131](#ubshcomservicesetpollingbatchsize)](#ubshcomservicesetpollingbatchsize) + +[2.2.1.1.14 UBSHcomService::SetEventPollingTimeOutUs [132](#ubshcomserviceseteventpollingtimeoutus)](#ubshcomserviceseteventpollingtimeoutus) + +[2.2.1.1.15 UBSHcomService::SetTimeOutDetectionThreadNum [132](#ubshcomservicesettimeoutdetectionthreadnum)](#ubshcomservicesettimeoutdetectionthreadnum) + +[2.2.1.1.16 UBSHcomService::SetMaxConnectionCount [133](#ubshcomservicesetmaxconnectioncount)](#ubshcomservicesetmaxconnectioncount) + +[2.2.1.1.17 UBSHcomService::SetUBSHcomHeartBeatOptions [133](#ubshcomservicesetubshcomheartbeatoptions)](#ubshcomservicesetubshcomheartbeatoptions) + +[2.2.1.1.18 UBSHcomService::SetUBSHcomMultiRailOptions [134](#ubshcomservicesetubshcommultirailoptions)](#ubshcomservicesetubshcommultirailoptions) + +[2.2.1.1.19 UBSHcomService::SetQueuePrePostSize [134](#ubshcomservicesetqueueprepostsize)](#ubshcomservicesetqueueprepostsize) + +[2.2.1.1.20 UBSHcomService::SetMaxSendRecvDataCount [135](#ubshcomservicesetmaxsendrecvdatacount)](#ubshcomservicesetmaxsendrecvdatacount) + +[2.2.1.1.21 UBSHcomRegMemoryRegion::GetMemoryKey [135](#ubshcomregmemoryregiongetmemorykey)](#ubshcomregmemoryregiongetmemorykey) + +[2.2.1.1.22 UBSHcomRegMemoryRegion::GetAddress [136](#ubshcomregmemoryregiongetaddress)](#ubshcomregmemoryregiongetaddress) + +[2.2.1.1.23 UBSHcomRegMemoryRegion::GetSize [136](#ubshcomregmemoryregiongetsize)](#ubshcomregmemoryregiongetsize) + +[2.2.1.1.24 UBSHcomRegMemoryRegion::GetHcomMrs [137](#ubshcomregmemoryregiongethcommrs)](#ubshcomregmemoryregiongethcommrs) + +[2.2.1.1.25 UBSHcomNewCallback [137](#ubshcomnewcallback)](#ubshcomnewcallback) + +[2.2.1.2 传输层 [138](#传输层)](#传输层) + +[2.2.1.2.1 UBSHcomNetDriver::RegisterTLSCaCallback [138](#ubshcomnetdriverregistertlscacallback)](#ubshcomnetdriverregistertlscacallback) + +[2.2.1.2.2 UBSHcomNetDriver::RegisterTLSCertificationCallback [140](#ubshcomnetdriverregistertlscertificationcallback)](#ubshcomnetdriverregistertlscertificationcallback) + +[2.2.1.2.3 UBSHcomNetDriver::RegisterTLSPrivateKeyCallback [141](#ubshcomnetdriverregistertlsprivatekeycallback)](#ubshcomnetdriverregistertlsprivatekeycallback) + +[2.2.1.2.4 UBSHcomNetDriver::RegisterPskUseSessionCb [144](#ubshcomnetdriverregisterpskusesessioncb)](#ubshcomnetdriverregisterpskusesessioncb) + +[2.2.1.2.5 UBSHcomNetDriver::RegisterPskFindSessionCb [145](#ubshcomnetdriverregisterpskfindsessioncb)](#ubshcomnetdriverregisterpskfindsessioncb) + +[2.2.1.2.6 UBSHcomNetDriver::RegisterEndpointSecInfoProvider [146](#ubshcomnetdriverregisterendpointsecinfoprovider)](#ubshcomnetdriverregisterendpointsecinfoprovider) + +[2.2.1.2.7 UBSHcomNetDriver::RegisterEndpointSecInfoValidator [147](#ubshcomnetdriverregisterendpointsecinfovalidator)](#ubshcomnetdriverregisterendpointsecinfovalidator) + +[2.2.1.2.8 UBSHcomNetEndpoint::PostSendRawSgl [147](#ubshcomnetendpointpostsendrawsgl)](#ubshcomnetendpointpostsendrawsgl) + +[2.2.1.2.9 UBSHcomNetEndpoint::ReceiveRaw [148](#ubshcomnetendpointreceiveraw-1)](#ubshcomnetendpointreceiveraw-1) + +[2.2.1.2.10 UBSHcomNetEndpoint::EstimatedEncryptLen [149](#ubshcomnetendpointestimatedencryptlen)](#ubshcomnetendpointestimatedencryptlen) + +[2.2.1.2.11 UBSHcomNetEndpoint::Encrypt [149](#ubshcomnetendpointencrypt)](#ubshcomnetendpointencrypt) + +[2.2.1.2.12 UBSHcomNetEndpoint::EstimatedDecryptLen [150](#ubshcomnetendpointestimateddecryptlen)](#ubshcomnetendpointestimateddecryptlen) + +[2.2.1.2.13 UBSHcomNetEndpoint::Decrypt [150](#ubshcomnetendpointdecrypt)](#ubshcomnetendpointdecrypt) + +[2.2.1.2.14 UBSHcomNetEndpoint::SendFds [151](#ubshcomnetendpointsendfds)](#ubshcomnetendpointsendfds) + +[2.2.1.2.15 UBSHcomNetEndpoint::ReceiveFds [151](#ubshcomnetendpointreceivefds)](#ubshcomnetendpointreceivefds) + +[2.2.1.2.16 UBSHcomNetOutLogger::Instance [152](#ubshcomnetoutloggerinstance)](#ubshcomnetoutloggerinstance) + +[2.2.1.2.17 UBSHcomNetOutLogger::SetLogLevel [152](#ubshcomnetoutloggersetloglevel)](#ubshcomnetoutloggersetloglevel) + +[2.2.1.2.18 UBSHcomNetOutLogger::SetExternalLogFunction [153](#ubshcomnetoutloggersetexternallogfunction)](#ubshcomnetoutloggersetexternallogfunction) + +[2.2.1.2.19 UBSHcomNetOutLogger::Print [153](#ubshcomnetoutloggerprint)](#ubshcomnetoutloggerprint) + +[2.2.1.2.20 UBSHcomNetOutLogger::Log [154](#ubshcomnetoutloggerlog)](#ubshcomnetoutloggerlog) + +[2.2.1.2.21 UBSHcomNetOutLogger::GetLogLevel [155](#ubshcomnetoutloggergetloglevel)](#ubshcomnetoutloggergetloglevel) + +[2.2.1.2.22 UBSHcomNetAtomicState::Get [155](#ubshcomnetatomicstateget)](#ubshcomnetatomicstateget) + +[2.2.1.2.23 UBSHcomNetAtomicState::Set [155](#ubshcomnetatomicstateset)](#ubshcomnetatomicstateset) + +[2.2.1.2.24 UBSHcomNetAtomicState::CAS [156](#ubshcomnetatomicstatecas)](#ubshcomnetatomicstatecas) + +[2.2.1.2.25 UBSHcomNetAtomicState::Compare [156](#ubshcomnetatomicstatecompare)](#ubshcomnetatomicstatecompare) + +[2.2.2 C API [157](#c-api-2)](#c-api-2) + +[2.2.2.1 服务层 [157](#服务层-1)](#服务层-1) + +[2.2.2.1.1 ubs_hcom_service_add_workergroup [157](#ubs_hcom_service_add_workergroup)](#ubs_hcom_service_add_workergroup) + +[2.2.2.1.2 ubs_hcom_service_add_listener [158](#ubs_hcom_service_add_listener)](#ubs_hcom_service_add_listener) + +[2.2.2.1.3 ubs_hcom_service_set_lbpolicy [158](#ubs_hcom_service_set_lbpolicy)](#ubs_hcom_service_set_lbpolicy) + +[2.2.2.1.4 ubs_hcom_service_set_tls_opt [159](#ubs_hcom_service_set_tls_opt)](#ubs_hcom_service_set_tls_opt) + +[2.2.2.1.5 ubs_hcom_service_set_secure_opt [160](#ubs_hcom_service_set_secure_opt)](#ubs_hcom_service_set_secure_opt) + +[2.2.2.1.6 ubs_hcom_service_set_tcp_usr_timeout [161](#ubs_hcom_service_set_tcp_usr_timeout)](#ubs_hcom_service_set_tcp_usr_timeout) + +[2.2.2.1.7 ubs_hcom_service_set_tcp_send_zcopy [161](#ubs_hcom_service_set_tcp_send_zcopy)](#ubs_hcom_service_set_tcp_send_zcopy) + +[2.2.2.1.8 ubs_hcom_service_set_ipmask [162](#ubs_hcom_service_set_ipmask)](#ubs_hcom_service_set_ipmask) + +[2.2.2.1.9 ubs_hcom_service_set_ipgroup [162](#ubs_hcom_service_set_ipgroup)](#ubs_hcom_service_set_ipgroup) + +[2.2.2.1.10 ubs_hcom_service_set_cq_depth [163](#ubs_hcom_service_set_cq_depth)](#ubs_hcom_service_set_cq_depth) + +[2.2.2.1.11 ubs_hcom_service_set_sq_size [163](#ubs_hcom_service_set_sq_size)](#ubs_hcom_service_set_sq_size) + +[2.2.2.1.12 ubs_hcom_service_set_rq_size [164](#ubs_hcom_service_set_rq_size)](#ubs_hcom_service_set_rq_size) + +[2.2.2.1.13 ubs_hcom_service_set_polling_batchsize [165](#ubs_hcom_service_set_polling_batchsize)](#ubs_hcom_service_set_polling_batchsize) + +[2.2.2.1.14 ubs_hcom_service_set_polling_timeoutus [165](#ubs_hcom_service_set_polling_timeoutus)](#ubs_hcom_service_set_polling_timeoutus) + +[2.2.2.1.15 ubs_hcom_service_set_timeout_threadnum [166](#ubs_hcom_service_set_timeout_threadnum)](#ubs_hcom_service_set_timeout_threadnum) + +[2.2.2.1.16 ubs_hcom_service_set_max_connection_cnt [166](#ubs_hcom_service_set_max_connection_cnt)](#ubs_hcom_service_set_max_connection_cnt) + +[2.2.2.1.17 ubs_hcom_service_set_heartbeat_opt [167](#ubs_hcom_service_set_heartbeat_opt)](#ubs_hcom_service_set_heartbeat_opt) + +[2.2.2.1.18 ubs_hcom_service_set_multirail_opt [167](#ubs_hcom_service_set_multirail_opt)](#ubs_hcom_service_set_multirail_opt) + +[2.2.2.1.19 ubs_hcom_set_log_handler [168](#ubs_hcom_set_log_handler-1)](#ubs_hcom_set_log_handler-1) + +[2.2.2.2 传输层 [169](#传输层-1)](#传输层-1) + +[2.2.2.2.1 ubs_hcom_driver_register_tls_cb [169](#ubs_hcom_driver_register_tls_cb)](#ubs_hcom_driver_register_tls_cb) + +[2.2.2.2.2 ubs_hcom_ep_post_send_raw [172](#ubs_hcom_ep_post_send_raw)](#ubs_hcom_ep_post_send_raw) + +[2.2.2.2.3 ubs_hcom_ep_post_send_raw_sgl [173](#ubs_hcom_ep_post_send_raw_sgl)](#ubs_hcom_ep_post_send_raw_sgl) + +[2.2.2.2.4 ubs_hcom_ep_post_read_sgl [174](#ubs_hcom_ep_post_read_sgl)](#ubs_hcom_ep_post_read_sgl) + +[2.2.2.2.5 ubs_hcom_ep_post_write_sgl [175](#ubs_hcom_ep_post_write_sgl)](#ubs_hcom_ep_post_write_sgl) + +[2.2.2.2.6 ubs_hcom_ep_receive_raw [175](#ubs_hcom_ep_receive_raw)](#ubs_hcom_ep_receive_raw) + +[2.2.2.2.7 ubs_hcom_ep_receive_raw_sgl [176](#ubs_hcom_ep_receive_raw_sgl)](#ubs_hcom_ep_receive_raw_sgl) + +[2.2.2.2.8 ubs_hcom_estimate_encrypt_len [177](#ubs_hcom_estimate_encrypt_len)](#ubs_hcom_estimate_encrypt_len) + +[2.2.2.2.9 ubs_hcom_encrypt [177](#ubs_hcom_encrypt)](#ubs_hcom_encrypt) + +[2.2.2.2.10 ubs_hcom_estimate_decrypt_len [178](#ubs_hcom_estimate_decrypt_len)](#ubs_hcom_estimate_decrypt_len) + +[2.2.2.2.11 ubs_hcom_decrypt [178](#ubs_hcom_decrypt)](#ubs_hcom_decrypt) + +[2.2.2.2.12 ubs_hcom_send_fds [179](#ubs_hcom_send_fds)](#ubs_hcom_send_fds) + +[2.2.2.2.13 ubs_hcom_receive_fds [180](#ubs_hcom_receive_fds)](#ubs_hcom_receive_fds) + +[2.3 结构体参考 [180](#结构体参考)](#结构体参考) + +[2.3.1 C++结构体 [180](#c结构体)](#c结构体) + +[2.3.1.1 服务层结构体 [180](#服务层结构体)](#服务层结构体) + +[2.3.1.1.1 UBSHcomServiceOptions [180](#ubshcomserviceoptions)](#ubshcomserviceoptions) + +[2.3.1.1.2 UBSHcomConnectOptions [181](#ubshcomconnectoptions)](#ubshcomconnectoptions) + +[2.3.1.1.3 UBSHcomRequest [182](#ubshcomrequest)](#ubshcomrequest) + +[2.3.1.1.4 UBSHcomResponse [182](#ubshcomresponse)](#ubshcomresponse) + +[2.3.1.1.5 UBSHcomReplyContext [182](#ubshcomreplycontext)](#ubshcomreplycontext) + +[2.3.1.1.6 UBSHcomOneSideRequest [182](#ubshcomonesiderequest)](#ubshcomonesiderequest) + +[2.3.1.1.7 UBSHcomFlowCtrlOptions [183](#flowctrloptions)](#flowctrloptions) + +[2.3.1.1.8 UBSHcomTlsOptions [183](#ubshcomtlsoptions)](#ubshcomtlsoptions) + +[2.3.1.1.9 UBSHcomConnSecureOptions [187](#ubshcomconnsecureoptions)](#ubshcomconnsecureoptions) + +[2.3.1.1.10 UBSHcomHeartBeatOptions [187](#ubshcomheartbeatoptions)](#ubshcomheartbeatoptions) + +[2.3.1.1.11 UBSHcomMultiRailOptions [188](#ubshcommultirailoptions)](#ubshcommultirailoptions) + +[2.3.1.1.12 UBSHcomIov [188](#ubshcomiov)](#ubshcomiov) + +[2.3.1.1.13 UBSHcomOneSideSglRequest [188](#ubshcomonesidesglrequest)](#ubshcomonesidesglrequest) + +[2.3.1.1.14 UBSHcomMemoryKey [188](#ubshcommemorykey)](#ubshcommemorykey) + +[2.3.1.1.15 UBSHcomSglRequest [189](#ubshcomsglrequest)](#ubshcomsglrequest) + +[2.3.1.1.16 UBSHcomTwoSideThreshold [189](#ubshcomtwosidethreshold)](#ubshcomtwosidethreshold) + +[2.3.1.2 传输层结构体 [189](#传输层结构体)](#传输层结构体) + +[2.3.1.2.1 UBSHcomNetDriverDeviceInfo [189](#ubshcomnetdriverdeviceinfo)](#ubshcomnetdriverdeviceinfo) + +[2.3.1.2.2 UBSHcomNetDriverOptions [190](#ubshcomnetdriveroptions)](#ubshcomnetdriveroptions) + +[2.3.1.2.3 UBSHcomNetOobListenerOptions [194](#ubshcomnetooblisteneroptions)](#ubshcomnetooblisteneroptions) + +[2.3.1.2.4 UBSHcomNetOobUDSListenerOptions [194](#ubshcomnetoobudslisteneroptions)](#ubshcomnetoobudslisteneroptions) + +[2.3.1.2.5 UBSHcomEpOptions [194](#ubshcomepoptions)](#ubshcomepoptions) + +[2.3.1.2.6 UBSHcomNetTransRequest [195](#ubshcomnettransrequest)](#ubshcomnettransrequest) + +[2.3.1.2.7 UBSHcomNetTransOpInfo [195](#ubshcomnettransopinfo)](#ubshcomnettransopinfo) + +[2.3.1.2.8 UBSHcomNetUdsIdInfo [196](#ubshcomnetudsidinfo)](#ubshcomnetudsidinfo) + +[2.3.1.2.9 UBSHcomNetMemoryAllocatorOptions [196](#ubshcomnetmemoryallocatoroptions)](#ubshcomnetmemoryallocatoroptions) + +[2.3.1.2.10 UBSHcomNetTransSglRequest [197](#ubshcomnettranssglrequest)](#ubshcomnettranssglrequest) + +[2.3.1.2.11 UBSHcomNetTransSgeIov [197](#ubshcomnettranssgeiov)](#ubshcomnettranssgeiov) + +[2.3.1.2.12 UBSHcomWorkerGroupInfo [197](#ubshcomworkergroupinfo)](#ubshcomworkergroupinfo) + +[2.3.1.2.13 UBSHcomNetUdsIdInfo [198](#ubshcomnetudsidinfo-1)](#ubshcomnetudsidinfo-1) + +[2.3.1.2.14 UBSHcomNetTransHeader [198](#ubshcomnettransheader)](#ubshcomnettransheader) + +[2.3.2 C结构体 [199](#c结构体-1)](#c结构体-1) + +[2.3.2.1 服务层结构体 [199](#服务层结构体-1)](#服务层结构体-1) + +[2.3.2.1.1 ubs_hcom_mr_info [199](#ubs_hcom_mr_info)](#ubs_hcom_mr_info) + +[2.3.2.1.2 ubs_hcom_channel_reply_context [199](#ubs_hcom_channel_reply_context)](#ubs_hcom_channel_reply_context) + +[2.3.2.1.3 ubs_hcom_oneside_request [199](#ubs_hcom_oneside_request)](#ubs_hcom_oneside_request) + +[2.3.2.1.4 ubs_hcom_channel_callback [200](#channel_callback)](#channel_callback) + +[2.3.2.1.5 ubs_hcom_flowctl_opts [200](#ubs_hcom_flowctl_opts)](#ubs_hcom_flowctl_opts) + +[2.3.2.1.6 ubs_hcom_service_options [200](#ubs_hcom_service_options)](#ubs_hcom_service_options) + +[2.3.2.1.7 Service_UBSHcomConnectOptions [201](#service_ubshcomconnectoptions)](#service_ubshcomconnectoptions) + +[2.3.2.1.8 ubs_hcom_channel_request [202](#ubs_hcom_channel_request)](#ubs_hcom_channel_request) + +[2.3.2.1.9 ubs_hcom_channel_response [202](#ubs_hcom_channel_response)](#ubs_hcom_channel_response) + +[2.3.2.1.10 Channel_UBSHcomTwoSideThreshold [202](#channel_ubshcomtwosidethreshold)](#channel_ubshcomtwosidethreshold) + +[2.3.2.1.11 ubs_hcom_oneside_key [203](#ubs_hcom_oneside_key)](#ubs_hcom_oneside_key) + +[2.3.2.2 传输层结构体 [203](#传输层结构体-1)](#传输层结构体-1) + +[2.3.2.2.1 ubs_hcom_send_request [203](#ubs_hcom_send_request)](#ubs_hcom_send_request) + +[2.3.2.2.2 ubs_hcom_opinfo [203](#ubs_hcom_opinfo)](#ubs_hcom_opinfo) + +[2.3.2.2.3 ubs_hcom_device_info [203](#ubs_hcom_device_info)](#ubs_hcom_device_info) + +[2.3.2.2.4 ubs_hcom_readwrite_request [204](#ubs_hcom_readwrite_request)](#ubs_hcom_readwrite_request) + +[2.3.2.2.5 ubs_hcom_readwrite_sge [204](#ubs_hcom_readwrite_sge)](#ubs_hcom_readwrite_sge) + +[2.3.2.2.6 ubs_hcom_readwrite_request_sgl [204](#ubs_hcom_readwrite_request_sgl)](#ubs_hcom_readwrite_request_sgl) + +[2.3.2.2.7 ubs_hcom_memory_region_info [205](#ubs_hcom_memory_region_info)](#ubs_hcom_memory_region_info) + +[2.3.2.2.8 ubs_hcom_request_context [205](#ubs_hcom_request_context)](#ubs_hcom_request_context) + +[2.3.2.2.9 ubs_hcom_response_context [206](#ubs_hcom_response_context)](#ubs_hcom_response_context) + +[2.3.2.2.10 ubs_hcom_uds_id_info [206](#ubs_hcom_uds_id_info)](#ubs_hcom_uds_id_info) + +[2.3.2.2.11 ubs_hcom_driver_opts [206](#ubs_hcom_driver_opts)](#ubs_hcom_driver_opts) + +[2.3.2.2.12 ubs_hcom_driver_listen_opts [209](#ubs_hcom_driver_listen_opts)](#ubs_hcom_driver_listen_opts) + +[2.3.2.2.13 ubs_hcom_driver_uds_listen_opts [210](#ubs_hcom_driver_uds_listen_opts)](#ubs_hcom_driver_uds_listen_opts) + +[2.3.2.2.14 ubs_hcom_memory_allocator_options [210](#ubs_hcom_memory_allocator_options)](#ubs_hcom_memory_allocator_options) + +[2.4 枚举值参考 [211](#枚举值参考)](#枚举值参考) + +[2.4.1 C++枚举值 [211](#c枚举值)](#c枚举值) + +[2.4.1.1 服务层枚举值 [211](#服务层枚举值)](#服务层枚举值) + +[2.4.1.1.1 UBSHcomChannelBrokenPolicy [211](#ubshcomchannelbrokenpolicy)](#ubshcomchannelbrokenpolicy) + +[2.4.1.1.2 Operation [211](#operation)](#operation) + +[2.4.1.1.3 UBSHcomClientPollingMode [212](#ubshcomclientpollingmode)](#ubshcomclientpollingmode) + +[2.4.1.1.4 UBSHcomChannelCallBackType [212](#ubshcomchannelcallbacktype)](#ubshcomchannelcallbacktype) + +[2.4.1.1.5 UBSHcomFlowCtrlLevel [212](#ubshcomflowctrllevel)](#ubshcomflowctrllevel) + +[2.4.1.1.6 UBSHcomChannelState [213](#ubshcomchannelstate)](#ubshcomchannelstate) + +[2.4.1.1.7 UBSHcomOobType [213](#ubshcomoobtype)](#ubshcomoobtype) + +[2.4.1.1.8 UBSHcomSecType [213](#hcomsectype)](#hcomsectype) + +[2.4.1.2 传输层枚举值 [214](#传输层枚举值)](#传输层枚举值) + +[2.4.1.2.1 UBSHcomNetEndPointState [214](#ubshcomnetendpointstate-1)](#ubshcomnetendpointstate-1) + +[2.4.1.2.2 UBSHcomNetCipherSuite [214](#ubshcomnetciphersuite)](#ubshcomnetciphersuite) + +[2.4.1.2.3 UBSHcomTlsVersion [215](#ubshcomtlsversion)](#ubshcomtlsversion) + +[2.4.1.2.4 NN_OpType [215](#nn_optype)](#nn_optype) + +[2.4.1.2.5 UBSHcomNetMemoryAllocatorType [216](#ubshcomnetmemoryallocatortype)](#ubshcomnetmemoryallocatortype) + +[2.4.1.2.6 UBSHcomNetMemoryAllocatorCacheTierPolicy [216](#ubshcomnetmemoryallocatorcachetierpolicy)](#ubshcomnetmemoryallocatorcachetierpolicy) + +[2.4.1.2.7 UBSHcomPeerCertVerifyType [216](#ubshcompeercertverifytype)](#ubshcompeercertverifytype) + +[2.4.1.2.8 UBSHcomNetDriverSecType [217](#ubshcomnetdriversectype)](#ubshcomnetdriversectype) + +[2.4.1.2.9 NetDriverOobType [217](#netdriveroobtype)](#netdriveroobtype) + +[2.4.1.2.10 UBSHcomNetDriverWorkingMode [217](#ubshcomnetdriverworkingmode)](#ubshcomnetdriverworkingmode) + +[2.4.1.2.11 UBSHcomNetDriverLBPolicy [218](#ubshcomnetdriverlbpolicy)](#ubshcomnetdriverlbpolicy) + +[2.4.1.2.12 UBSHcomNetDriverProtocol [218](#ubshcomnetdriverprotocol-1)](#ubshcomnetdriverprotocol-1) + +[2.4.1.2.13 UBSHcomUbcMode [219](#ubshcomubcmode)](#ubshcomubcmode) + +[2.4.2 C枚举值 [219](#c枚举值-1)](#c枚举值-1) + +[2.4.2.1 服务层枚举值 [219](#服务层枚举值-1)](#服务层枚举值-1) + +[2.4.2.1.1 ubs_hcom_channel_cb_type [219](#ubs_hcom_channel_cb_type)](#ubs_hcom_channel_cb_type) + +[2.4.2.1.2 ubs_hcom_service_context_type [219](#ubs_hcom_service_context_type)](#ubs_hcom_service_context_type) + +[2.4.2.1.3 ubs_hcom_channel_flowctl_level [220](#ubs_hcom_channel_flowctl_level)](#ubs_hcom_channel_flowctl_level) + +[2.4.2.1.4 ubs_hcom_service_worker_mode [220](#ubs_hcom_service_worker_mode)](#ubs_hcom_service_worker_mode) + +[2.4.2.1.5 ubs_hcom_service_lb_policy [221](#ubs_hcom_service_lb_policy)](#ubs_hcom_service_lb_policy) + +[2.4.2.1.6 ubs_hcom_service_cipher_suite [221](#ubs_hcom_service_cipher_suite)](#ubs_hcom_service_cipher_suite) + +[2.4.2.1.7 ubs_hcom_service_tls_version [221](#ubs_hcom_service_tls_version)](#ubs_hcom_service_tls_version) + +[2.4.2.1.8 ubs_hcom_service_secure_type [222](#ubs_hcom_service_secure_type)](#ubs_hcom_service_secure_type) + +[2.4.2.1.9 ubs_hcom_service_channel_policy [222](#ubs_hcom_service_channel_policy)](#ubs_hcom_service_channel_policy) + +[2.4.2.1.10 ubs_hcom_service_channel_handler_type [223](#ubs_hcom_service_channel_handler_type)](#ubs_hcom_service_channel_handler_type) + +[2.4.2.1.11 ubs_hcom_service_handler_type [223](#ubs_hcom_service_handler_type)](#ubs_hcom_service_handler_type) + +[2.4.2.1.12 ubs_hcom_service_type [223](#ubs_hcom_service_type)](#ubs_hcom_service_type) + +[2.4.2.1.13 ubs_hcom_service_polling_mode [224](#ubs_hcom_service_polling_mode)](#ubs_hcom_service_polling_mode) + +[2.4.2.2 传输层枚举值 [224](#传输层枚举值-1)](#传输层枚举值-1) + +[2.4.2.2.1 ubs_hcom_request_type [224](#ubs_hcom_request_type)](#ubs_hcom_request_type) + +[2.4.2.2.2 ubs_hcom_driver_working_mode [225](#ubs_hcom_driver_working_mode)](#ubs_hcom_driver_working_mode) + +[2.4.2.2.3 ubs_hcom_driver_type [225](#ubs_hcom_driver_type)](#ubs_hcom_driver_type) + +[2.4.2.2.4 ubs_hcom_driver_oob_type [226](#ubs_hcom_driver_oob_type)](#ubs_hcom_driver_oob_type) + +[2.4.2.2.5 ubs_hcom_driver_sec_type [226](#ubs_hcom_driver_sec_type)](#ubs_hcom_driver_sec_type) + +[2.4.2.2.6 ubs_hcom_driver_tls_version [226](#ubs_hcom_driver_tls_version)](#ubs_hcom_driver_tls_version) + +[2.4.2.2.7 ubs_hcom_driver_cipher_suite [227](#ubs_hcom_driver_cipher_suite)](#ubs_hcom_driver_cipher_suite) + +[2.4.2.2.8 ubs_hcom_peer_cert_verify_type [227](#ubs_hcom_peer_cert_verify_type)](#ubs_hcom_peer_cert_verify_type) + +[2.4.2.2.9 ubs_hcom_memory_allocator_cache_tier_policy [227](#ubs_hcom_memory_allocator_cache_tier_policy)](#ubs_hcom_memory_allocator_cache_tier_policy) + +[2.4.2.2.10 ubs_hcom_memory_allocator_type [228](#ubs_hcom_memory_allocator_type)](#ubs_hcom_memory_allocator_type) + +[2.4.2.2.11 ubs_hcom_ep_handler_type [228](#ubs_hcom_ep_handler_type)](#ubs_hcom_ep_handler_type) + +[2.4.2.2.12 ubs_hcom_op_handler_type [228](#ubs_hcom_op_handler_type)](#ubs_hcom_op_handler_type) + +[2.4.2.2.13 ubs_hcom_polling_mode [229](#ubs_hcom_polling_mode)](#ubs_hcom_polling_mode) + +[2.4.2.2.14 ubs_hcom_service_polling_mode [229](#ubs_hcom_service_polling_mode-1)](#ubs_hcom_service_polling_mode-1) + +[3 环境变量参考 [230](#环境变量参考)](#环境变量参考) + +[4 错误码 [233](#错误码)](#错误码) + +[4.1 服务层错误码 [233](#服务层错误码)](#服务层错误码) + +[4.2 传输层错误码 [235](#传输层错误码)](#传输层错误码) + +[4.3 RDMA协议错误码 [238](#rdma协议错误码)](#rdma协议错误码) + +# 介绍 + +本文主要介绍UBS Comm对外提供的API。可以从两个不同的角度,对UBS Comm的API进行分类: + +- 编程语言 + +UBS Comm主体使用C++语言开发,对外提供C++ API。为了方便不同场景的开发者使用,UBS Comm还对C++ API做了一层封装,对外提供C和Java API。 + +- 功能架构 + +考虑性能及易用性,UBS Comm使用了“传输层”和“服务层”两层架构。传输层追求极致性能,服务层追求极致易用性。传输层和服务层均提供API,使用传输层或服务层的API均可以独立完成通信功能。传输层仅提供了高性能的通信基础功能,服务层还提供了链路重连、限流、超时检测等常用的高级功能。 + +由于UBS Comm对外提供的API较多,为方便开发者阅读及理解,本文分为如下几个大章节介绍UBS Comm的API。 + +- 基础API参考 + +介绍应用开发过程中最常用和基础的API,建议使用UBS Comm的开发者对这些API都有所了解。 + +- 高级API参考 + +介绍应用开发过程中不常用的API,开发者可以根据自身场景需要进行查阅。 + +- 环境变量 + +介绍UBS Comm对外提供的环境变量。 + +- 错误码 + +介绍UBS Comm的错误码名称、取值及部分常见错误码的处理方法。 + +# API 2.0 + +[3.1 基础API参考](#基础api参考) + +[3.2 高级API参考](#高级api参考) + +[3.3 结构体参考](#结构体参考) + +[3.4 枚举值参考](#枚举值参考) + +## 基础API参考 + +### C++ API + +#### 服务层API + +##### UBSHcomService::Create + +1. 函数定义 + +根据类型、名字和可选配置项创建一个服务层的NetService对象。 + +2. 实现方法 + +static UBSHcomService\* UBSHcomService::Create(UBSHcomServiceProtocol t, const std::string &name, const UBSHcomServiceOptions &opt = {}); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| t | [UBSHcomNetDriverProtocol](#ubshcomnetdriverprotocol-1) | 入参 | UBSHcomService协议类型。 | +| name | String | 入参 | UBSHcomService的名称。长度范围\[1, 64\],只能包含数字、字母、‘\_’和‘-’。 | +| opt | [UBSHcomServiceOptions](#ubshcomserviceoptions) | 入参 | 可选基础配置项。 | + +4. 返回值 + +成功则返回NetService类型的实例,否则返回空。 + +##### UBSHcomService::Destroy + +1. 函数定义 + +销毁服务,会清理全局map并根据名字销毁对象。 + +2. 实现方法 + +static int32_t UBSHcomService::Destroy(const std::string &name); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| name | String | 入参 | 要删除的服务对象的名称。长度范围\[1, 100\],只能包含数字、字母、‘\_’和‘-’。 | + +4. 返回值 + +表示函数执行结果,返回值为0则表示销毁成功。 + +##### UBSHcomService::Bind + +1. 函数定义 + +服务端绑定监听的url和端口号 + +2. 实现方法 + +int32_t UBSHcomService::Bind(const std::string &listenerUrl, const UBSHcomServiceNewChannelHandler &handler) + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +绑定成功返回0,失败返回对应错误码 + +##### UBSHcomService::Start + +1. 函数定义 + +启动服务 + +2. 实现方法 + +int32_t UBSHcomService::Start() + +3. 参数说明 + +无 + +4. 返回值 + +启动成功返回0,启动失败返回失败错误码。 + +##### UBSHcomService::Connect + +1. 函数定义 + +客户端向服务端发起建链。 + +2. 实现方法 + +int32_t UBSHcomService::Connect(const std::string &serverUrl, UBSHcomChannelPtr &ch, const UBSHcomConnectOptions &opt = {}) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| serverUrl | const std::string | 入参 | 服务端绑定监听的url。 | +| ch | UBSHcomChannelPtr | 出参 | 建链成功返回的channel通道。 | +| opt | const UBSHcomConnectOptions & | 入参 | 建链配置项。 | + +4. 返回值 + +无 + +##### UBSHcomService::Disconnect + +1. 函数定义 + +断开链接。 + +2. 实现方法 + +void UBSHcomService::Disconnect(const UBSHcomChannelPtr &ch) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------------|----------|-----------------------| +| ch | const UBSHcomChannelPtr | 入参 | 要断开的channel通道。 | + +4. 返回值 + +无 + +##### UBSHcomService::RegisterMemoryRegion + +1. 函数定义 + +- 注册一个内存区域,内存将在UBS Comm内部分配。 + +- 将用户申请的内存,注册到UBS Comm中。 + + 1. 实现方法 + +- int32_t UBSHcomService::RegisterMemoryRegion(uint64_t size, UBSHcomRegMemoryRegion &mr) + +- int32_t UBSHcomService::RegisterMemoryRegion(uintptr_t address, uint64_t size, UBSHcomRegMemoryRegion &mr) + + 1. 参数说明 + + 1. 参数说明 + +[TABLE] + +![](media/image8.png) + +若需要放入pgTable管理(通过UBSHcomService::SetEnableMrCache设置为true,默认不放入),则要求首地址(startAddress)和尾地址(startAddress+size)都需要16字节对齐,因此用户申请的size需要能16整除。 + +2. 返回值 + +表示函数执行结果,返回值为0则表示注册成功。 + +##### UBSHcomService::DestroyMemoryRegion + +1. 函数定义 + +销毁一个内存区域。 + +2. 实现方法 + +void UBSHcomService::DestroyMemoryRegion(UBSHcomRegMemoryRegion &mr) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|------------------------|----------|--------------------| +| mr | UBSHcomRegMemoryRegion | 入参 | 要销毁的内存区域。 | + +4. 返回值 + +无 + +##### UBSHcomService::RegisterChannelBrokenHandler + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁Service及相关的资源。 + +1. 函数定义 + +给UBSHcomService注册断链回调函数。 + +2. 实现方法 + +void UBSHcomService::RegisterChannelBrokenHandler(const UBSHcomServiceChannelBrokenHandler &handler, const UBSHcomChannelBrokenPolicy policy) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|------------------------------------|----------|----------------| +| handler | UBSHcomServiceChannelBrokenHandler | 入参 | 断链回调函数。 | +| policy | UBSHcomChannelBrokenPolicy | 入参 | 断链回调策略。 | + +4. 返回值 + +无 + +##### UBSHcomService::RegisterIdleHandler + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁service及相关的资源。 + +1. 函数定义 + +给此UBSHcomService注册worker闲时回调函数 + +2. 实现方法 + +void UBSHcomService::RegisterIdleHandler(const NetServiceIdleHandler &h) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|---------------------------------|----------|------------| +| handler | const UBSHcomServiceIdleHandler | 入参 | 回调函数。 | + +4. 返回值 + +无 + +![](media/image8.png) + +数据类型解释如下: + +using UBSHcomServiceIdleHandler= std::function\. + +##### UBSHcomService::RegisterRecvHandler + +![](media/image9.png) + +- 用户实现的回调函数,内部不能销毁Service及相关的资源。 + +- 用户需要避免在该回调中死等发送完成事件,应添加超时时间,否则会造成死锁。 + +- 用户需要尽量避免在该回调中占用过长时间处理业务,以免影响性能。 + + 1. 函数定义 + +注册回调函数以处理异步通信收到消息事件。 + +2. 实现方法 + +void UBSHcomService::RegisterRecvHandler(const UBSHcomServiceRecvHandler &recvHandler) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| recvHandler | UBSHcomServiceRecvHandler | 入参 | 处理异步通信收数据事件的回调函数句柄。 | + +4. 返回值 + +无 + +##### UBSHcomService::RegisterSendHandler + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁Service及相关的资源。 + +1. 函数定义 + +注册回调函数以处理消息发送完成事件。 + +2. 实现方法 + +void UBSHcomService::RegisterSendHandler(const UBSHcomServiceSendHandler &sendHandler) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| sendHandler | UBSHcomServiceSendHandler | 入参 | 处理发送完成事件的回调函数句柄。 | + +4. 返回值 + +无 + +##### UBSHcomService::RegisterOneSideHandler + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁Service及相关的资源。 + +1. 函数定义 + +注册回调函数以处理单边读/写完成事件。 + +2. 实现方法 + +void UBSHcomService::RegisterOneSideHandler(const UBSHcomServiceOneSideDoneHandler &oneSideDoneHandler) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| oneSideDoneHandler | UBSHcomServiceOneSideDoneHandler | 入参 | 处理单边读/写完成事件的回调函数句柄。 | + +4. 返回值 + +无 + +##### UBSHcomChannel::Send + +![](media/image9.png) + +- 若使用拆包和rndv的功能,需要通过UBSHcomChannel::SetUBSHcomTwoSideThreshold设置拆包和rndv的阈值。 + +- 使用rndv,则需要创建service后将UBSHcomService::SetEnableMrCache设置为true(UBSHcomService::RegisterMemoryRegion函数调用前)。 + + 1. 函数定义 + +  + +- 向对端异步发送一个双边请求消息,并且不等待响应。 + +- 向对端同步发送一个双边请求消息,并且不等待响应。 + + 1. 实现方法 + +- int32_t UBSHcomChannel::Send(const UBSHcomRequest &req, const Callback \*done) + +- int32_t UBSHcomChannel::Send(const UBSHcomRequest &req) + + 1. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +表示函数执行结果,0表示发送成功。 + +##### UBSHcomChannel::Call + +![](media/image9.png) + +- rsp中若address字段填了有效内存地址,则回复会被拷贝到该地址上。 + +- 若address==NULL,则UBS Comm会通过malloc申请内存,但用户需要自行维护该内存的生命周期,在使用完后通过free释放。 + +- 若使用拆包和rndv的功能,需要通过UBSHcomChannel::SetUBSHcomTwoSideThreshold设置拆包和rndv的阈值。 + +- 使用rndv,则需要创建service后将UBSHcomService::SetEnableMrCache设置为true(UBSHcomService::RegisterMemoryRegion函数调用前)。 + + 1. 函数定义 + +  + +- 异步模式下,发送一个UBSHcomRequest消息,并等待对方回复UBSHcomResponse响应消息。 + +- 同步模式下,发送一个UBSHcomRequest消息,并等待对方回复UBSHcomResponse响应消息。 + + 1. 实现方法 + +- int32_t UBSHcomChannel::Call(const UBSHcomRequest &req, UBSHcomResponse &rsp, const Callback \*done) + +- int32_t UBSHcomChannel::Call(const UBSHcomRequest &req, UBSHcomResponse &rsp) + + 1. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +表示函数执行结果,返回值为0则表示发送成功。 + +##### UBSHcomChannel::Reply + +1. 函数定义 + +1\. 异步模式下,向对端回复一个消息,配合Call接口使用 + +2\. 同步模式下,向对端回复一个消息,配合Call接口使用 + +2. 实现方法 + +1\. int32_t UBSHcomChannel::Reply(const UBSHcomReplyContext &ctx, const UBSHcomRequest &req, const Callback \*done) + +2\. int32_t UBSHcomChannel::Reply(const UBSHcomReplyContext &ctx, const UBSHcomRequest &req) + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +表示函数执行结果,0表示发送成功。 + +##### UBSHcomChannel::Get + +1. 函数定义 + +  + +1. 同步模式下,发送一个读请求给对方。 + +2. 异步模式下,发送一个读请求给对方。 + + 1. 实现方法 + +  + +1. int32_t UBSHcomChannel::Get(const UBSHcomOneSideRequest &req, const Callback \*done) + +2. int32_t UBSHcomChannel::Get(const UBSHcomOneSideRequest &req) + + 1. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +表示函数执行结果,返回值为0则表示读请求成功。 + +##### UBSHcomChannel::Put + +1. 函数定义 + +  + +1. 同步模式下,发送一个写请求给对方。 + +2. 异步模式下,发送一个写请求给对方。 + + 1. 实现方法 + +  + +1. int32_t UBSHcomChannel::Put(const UBSHcomOneSideRequest &req, const Callback \*done) + +2. int32_t UBSHcomChannel::Put(const UBSHcomOneSideRequest &req) + + 1. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +表示函数执行结果,返回值为0则表示写请求成功。 + +##### UBSHcomChannel::Recv + +1. 函数定义 + +只用于接收RNDV请求。 + +2. 实现方法 + +int32_t Recv(const UBSHcomServiceContext &context, uintptr_t address, uint32_t size, const Callback \*done = nullptr) + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +表示函数执行结果,返回值为0则表示接收请求成功。 + +##### UBSHcomChannel::SetFlowControlConfig + +1. 函数定义 + +设置限流。 + +2. 实现方法 + +int32_t SetFlowControlConfig(const UBSHcomFlowCtrlOptions &opt) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------------------------|----------|--------------| +| opt | [UBSHcomFlowCtrlOptions](#flowctrloptions) | 入参 | 流控配置项。 | + +4. 返回值 + +返回0表示成功 + +##### UBSHcomChannel::SetChannelTimeOut + +1. 函数定义 + +给该channel设置超时时间。未设置时默认超时时间30s。 + +2. 实现方法 + +void UBSHcomChannel::SetChannelTimeOut(int16_t oneSideTimeout, int16_t twoSideTimeout) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| oneSideTimeout | int16_t | 入参 | 单边超时时间,单位为秒,0为立即超时,负数为永不超时(通常设置为-1)。范围是\[-1, INT16_MAX\]。未设置时默认超时时间30s。 | +| twoSideTimeout | int16_t | 入参 | 双边超时时间,单位为秒,0为立即超时,负数为永不超时(通常设置为-1)。范围是\[-1, INT16_MAX\]。未设置时默认超时时间30s。 | + +4. 返回值 + +无 + +##### UBSHcomChannel::SetUBSHcomTwoSideThreshold + +1. 函数定义 + +设置双边操作阈值。 + +2. 实现方法 + +int32_t SetUBSHcomTwoSideThreshold(const UBSHcomTwoSideThreshold &threshold) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| threshold | [UBSHcomTwoSideThreshold](#ubshcomtwosidethreshold) | 入参 | 双边操作阈值。 | + +4. 返回值 + +返回0表示成功。 + +##### UBSHcomChannel::GetId + +1. 函数定义 + +获得channel ID。 + +2. 实现方法 + +uint64_t UBSHcomChannel::GetId() + +3. 参数说明 + +无 + +4. 返回值 + +uint64_t id信息。 + +##### UBSHcomChannel::GetPeerConnectPayload + +1. 函数定义 + +获得建链的payLoad信息。 + +2. 实现方法 + +std::string UBSHcomChannel::GetPeerConnectPayload() + +3. 参数说明 + +无 + +4. 返回值 + +playLoad信息。 + +##### UBSHcomChannel::SetTraceId + +1. 函数定义 + +设置trace id。 + +2. 实现方法 + +void SetTraceId(const std::string &traceId) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|-------------|----------|--------------------| +| traceId | std::string | 入参 | 要设置的trace ID。 | + +4. 返回值 + +无 + +##### UBSHcomServiceContext::Result + +1. 函数定义 + +获得ctx的结果,表示通信操作的成功与否。 + +2. 实现方法 + +SerResult UBSHcomServiceContext::Result() const + +3. 参数说明 + +无 + +4. 返回值 + +ctx的结果。 + +##### UBSHcomServiceContext::Channel + +1. 函数定义 + +获得ctx的NetChannel,可以用于向对端回复消息。 + +2. 实现方法 + +const UBSHcomChannelPtr &UBSHcomServiceContext::Channel() const + +3. 参数说明 + +无 + +4. 返回值 + +ctx的NetChannel。 + +##### UBSHcomServiceContext::OpType + +1. 函数定义 + +获得ctx的操作类型。 + +2. 实现方法 + +Operation UBSHcomServiceContext::OpType() const + +3. 参数说明 + +无 + +4. 返回值 + +[ctx的操作类型](#ZH-CN_TOPIC_0000002465536418)。 + +##### UBSHcomServiceContext::RspCtx + +1. 函数定义 + +获得ctx的rspCtx,可以用于接收对端发送call消息后回复消息时当作参数使用。 + +2. 实现方法 + +uintptr_t UBSHcomServiceContext::RspCtx() const + +3. 参数说明 + +无 + +4. 返回值 + +ctx的rspCtx。 + +##### UBSHcomServiceContext::ErrorCode + +1. 函数定义 + +获得ctx的errorCode。 + +2. 实现方法 + +const int32_t UBSHcomServiceContext::ErrorCode() + +3. 参数说明 + +无 + +4. 返回值 + +ctx的errorCode + +##### UBSHcomServiceContext::OpCode + +1. 函数定义 + +获得ctx的opCode。 + +2. 实现方法 + +uint16_t UBSHcomServiceContext::OpCode() const + +3. 参数说明 + +无 + +4. 返回值 + +ctx的opCode。 + +##### UBSHcomServiceContext::MessageData + +1. 函数定义 + +获得ctx的消息,为对端发送过来的消息。 + +2. 实现方法 + +void \*UBSHcomServiceContext::MessageData() const + +3. 参数说明 + +无 + +4. 返回值 + +ctx的消息。 + +##### UBSHcomServiceContext::MessageDataLen + +1. 函数定义 + +获得ctx的消息长度。 + +2. 实现方法 + +uint32_t UBSHcomServiceContext::MessageDataLen() const + +3. 参数说明 + +无 + +4. 返回值 + +ctx的消息长度。 + +##### UBSHcomServiceContext::Clone + +1. 函数定义 + +将ctx的内容拷贝。 + +2. 实现方法 + +static SerResult UBSHcomServiceContext::Clone(NetServiceContext &newOne, const NetServiceContext &oldOne, bool copyData = true) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----------|-----------------------------|----------|-----------------| +| newOne | UBSHcomServiceContext | 出参 | 拷贝得到的ctx。 | +| oldOne | const UBSHcomServiceContext | 入参 | 被拷贝的ctx。 | +| copyData | bool | 入参 | 是否拷贝数据。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### UBSHcomServiceContext::IsTimeout + +1. 函数定义 + +获得ctx是否超时,表示此次操作是否超时。 + +2. 实现方法 + +bool UBSHcomServiceContext::IsTimeout() const + +3. 参数说明 + +无 + +4. 返回值 + +ctx是否超时。 + +##### UBSHcomServiceContext::Invalidate + +1. 函数定义 + +将ctx的内容失效。 + +2. 实现方法 + +void UBSHcomServiceContext::Invalidate() + +3. 参数说明 + +无 + +4. 返回值 + +无 + +##### UBSHcomService::SetEnableMrCache + +![](media/image9.png) + +若用户需要使用RNDV,则需要设置为true。 + +1. 函数定义 + +设置RegisterMemoryRegion是否将mr放入pgTable管理。 + +2. 实现说明 + +void SetEnableMrCache(bool enableMrCache); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------------|----------|----------|---------------------------| +| enableMrCache | bool | 入参 | mr放入pgTable管理标志位。 | + +4. 返回值 + +无 + +#### 传输层API + +##### UBSHcomNetDriver::Instance + +1. 函数定义 + +生成UBSHcomNetDriver实例。 + +2. 实现方法 + +static UBSHcomNetDriver \*UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol t, const std::string &name, bool startOobSvr) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| t | [UBSHcomNetDriverProtocol](#ZH-CN_TOPIC_0000002498495129) | 入参 | 设置UBSHcomNetDriver的协议类型。 | +| name | const std::string | 入参 | 设置UBSHcomNetDriver的名字。长度范围\[1, 100\],只能包含数字、字母、‘\_’和‘-’。 | +| startOobSvr | bool | 入参 | Server端设置为true,Client端设置为false。 | + +4. 返回值 + +成功则返回UBSHcomNetDriver类型的实例,否则返回空。 + +##### UBSHcomNetDriver::DestroyInstance + +1. 函数定义 + +销毁UBSHcomNetDriver实例。 + +2. 实现方法 + +static NResult UBSHcomNetDriver::DestroyInstance(const std::string &name) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| name | String | 入参 | 需要销毁UBSHcomNetDriver的名字。长度范围\[1, 100\],只能包含数字、字母、‘\_’和‘-’。 | + +4. 返回值 + +返回值为0则表示成功销毁UBSHcomNetDriver实例。 + +##### UBSHcomNetDriver::LocalSupport + +1. 函数定义 + +通过UBSHcomNetDriver对象,校验本机是否支持所提供协议,若为RDMA协议且支持的情况下,会返回设备信息。 + +2. 实现方法 + +static bool UBSHcomNetDriver::LocalSupport(UBSHcomNetDriverProtocol t, UBSHcomNetDriverDeviceInfo &deviceInfo) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| t | [UBSHcomNetDriverProtocol](#ZH-CN_TOPIC_0000002498495129) | 入参 | 需要校验的协议。 | +| deviceInfo | [UBSHcomNetDriverDeviceInfo](#ZH-CN_TOPIC_0000002465376694) | 出参 | RDMA设备信息。 | + +4. 返回值 + +返回值为true则表示支持此协议。 + +##### UBSHcomNetDriver::MultiRailGetDevCount + +1. 函数定义 + +通过UBSHcomNetDriver对象,校验本机是否支持所提供协议,若为RDMA协议且支持的情况下,会通过ipGroup筛选符合要求的IP,若ipGroup为空,则会用ipMask来筛选。 + +2. 实现方法 + +static bool UBSHcomNetDriver::MultiRailGetDevCount(UBSHcomNetDriverProtocol t, std::string ipMask, uint16_t &enableDevCount, std::string ipGroup) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| t | [UBSHcomNetDriverProtocol](#ZH-CN_TOPIC_0000002498495129) | 入参 | 需要校验的协议。 | +| ipMask | std::string | 入参 | IP掩码。长度范围\[1, 256\]。 | +| enableDevCount | uint16_t | 出参 | 符合要求的RDMA设备个数。 | +| ipGroup | std::string | 入参 | IP组。长度范围\[1, 1024\]。 | + +4. 返回值 + +返回值为true则表示支持此协议。 + +##### UBSHcomNetDriver::Initialize + +1. 函数定义 + +初始化UBSHcomNetDriver。 + +2. 实现方法 + +NResult UBSHcomNetDriver::Initialize(const UBSHcomNetDriverOptions &option) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| option | [UBSHcomNetDriverOptions](#ZH-CN_TOPIC_0000002465376494) | 入参 | 根据Option,初始化UBSHcomNetDriver。 | + +4. 返回值 + +返回值为0则表示初始化成功。 + +##### UBSHcomNetDriver::UnInitialize + +1. 函数定义 + +取消初始化UBSHcomNetDriver。 + +2. 方法实现 + +void UBSHcomNetDriver::UnInitialize() + +3. 参数说明 + +无 + +4. 返回值 + +无 + +##### UBSHcomNetDriver::Start + +1. 函数定义 + +运行UBSHcomNetDriver。 + +2. 实现方法 + +NResult UBSHcomNetDriver::Start() + +3. 参数说明 + +无 + +4. 返回值 + +返回值为0则表示运行UBSHcomNetDriver成功。 + +##### UBSHcomNetDriver::Stop + +1. 函数定义 + +停止UBSHcomNetDriver。 + +2. 实现方法 + +void UBSHcomNetDriver::Stop() + +3. 参数说明 + +无 + +4. 返回值 + +无 + +##### UBSHcomNetDriver::CreateMemoryRegion + +1. 函数定义 + +  + +1. 注册一个内存区域,内存将在UBS Comm内部分配。 + +2. 注册一个内存区域,内存需要外部传入。 + + 1. 实现方法 + +- NResult UBSHcomNetDriver::CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr) + +- NResult UBSHcomNetDriver::CreateMemoryRegion(uintptr_t address, uint64_t size, UBSHcomNetMemoryRegionPtr &mr) + +- NResult CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr, unsigned long memid) + + 1. 参数说明 + + 1. 参数说明 + +[TABLE] + +2. 返回值 + +返回值为0则表示发送消息成功。 + +##### UBSHcomNetDriver::DestroyMemoryRegion + +1. 函数定义 + +注销内存区域。 + +2. 实现方法 + +void UBSHcomNetDriver::DestroyMemoryRegion(UBSHcomNetMemoryRegionPtr &mr) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|---------------------------|----------|----------------------| +| mr | UBSHcomNetMemoryRegionPtr | 入参 | 需要注销的内存区域。 | + +4. 返回值 + +无 + +##### UBSHcomNetDriver::Connect + +![](media/image9.png) + +如果用户实现中需要主动销毁EP,要调用UBSHcomNetEndpoint::Close接口;如果需要减少EP的引用计数,可调用DecreaseRef函数。 + +1. 函数定义 + +- 建立与Server的连接。利用设置IP端口或者UDS名称的方法来选择对端,指定本端和对端的worker group,指定链路类型。 + +- 建立与Server的连接。利用设置IP端口或者UDS名称的方法来选择对端,指定本端和对端的worker group。 + +- 建立与Server的连接。利用设置IP端口或者UDS名称的方法来选择对端,指定链路类型。 + +- 建立与Server的连接。利用设置IP端口或者UDS名称的方法来选择对端。 + +- 建立与Server的连接。自定义地址来选择对端,指定本端和对端的worker group,指定链路类型。 + +- 建立与Server的连接。自定义地址来选择对端,指定本端和对端的worker group,指定链路类型,指定sec校验时回调中的ctx。 + +- 建立与Server的连接。自定义地址来选择对端,指定本端和对端的worker group。 + +- 建立与Server的连接。自定义地址来选择对端,指定链路类型。 + +- 建立与Server的连接。自定义地址来选择对端。 + + 1. 实现方法 + +- NResult UBSHcomNetDriver::Connect( const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo) + +- NResult UBSHcomNetDriver::Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep, uint8_t serverGrpNo, uint8_t clientGrpNo) + +- NResult UBSHcomNetDriver::Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags) + +- NResult UBSHcomNetDriver::Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep) + +- NResult UBSHcomNetDriver::Connect(const std::string &oobIpOrName, uint16_t oobPort, const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo) + +- NResult UBSHcomNetDriver::Connect(const std::string &oobIpOrName, uint16_t oobPort, const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) + +- NResult UBSHcomNetDriver::Connect(const std::string &oobIpOrName, uint16_t oobPort, const std::string &payload, UBSHcomNetEndpointPtr &ep, uint8_t serverGrpNo, uint8_t clientGrpNo) + +- NResult UBSHcomNetDriver::Connect(const std::string &oobIpOrName, uint16_t oobPort, const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags) + +- NResult UBSHcomNetDriver::Connect( const std::string &oobIpOrName, uint16_t oobPort, const std::string &payload, UBSHcomNetEndpointPtr &ep) + +- NResult Connect(const std::string &serverUrl, const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) + + 1. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| oobIpOrName | String | 入参 | 连接时带外链路IP地址或者名字,用TCP则用IP地址,UDS则用名字。 | +| oobPort | uint16_t | 入参 | 当带外链路是TCP的时候,需要设置。范围是\[1024, 65535\]。 | +| payload | String | 入参 | Payload传输给对端,对端通过ep connect的回调中获得。长度范围\[0, 512\]。 | +| ep | UBSHcomNetEndpointPtr & | 出参 | 连接的EP。 | +| flags | uint32_t | 入参 | 可选参数,当创建同步EP时flags设置为Net_EP_SELF_POLLING。 | +| serverGrpNo | uint8_t | 入参 | 选择对端的worker group number。 | +| clientGrpNo | uint8_t | 入参 | 选择本端的worker group number。 | +| ctx | uint64_t | 入参 | secInfo回调时的ctx。 | + +2. 返回值 + +- 返回值为0表示connect成功。 + +- 返回值为其它值则表示建链失败。 + +##### UBSHcomNetDriver::DestroyEndpoint + +1. 函数定义 + +通过UBSHcomNetDriver对象来销毁EP。 + +2. 实现方法 + +void UBSHcomNetDriver::DestroyEndpoint(UBSHcomNetEndpointPtr &ep) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-----------------------|----------|------------------| +| ep | UBSHcomNetEndpointPtr | 入参 | 需要被销毁的EP。 | + +4. 返回值 + +无 + +##### UBSHcomNetDriver::OobIpAndPort + +1. 函数定义 + +给UBSHcomNetDriver对象设置OOB的IP和端口号。当此UBSHcomNetDriver是server时,UBSHcomNetDriver会监听此ipPort,且此方法在可以多次调用,会同时监听多个ipPort组合;当此UBSHcomNetDriver是client时,client在Connect时会默认向此ipPort建链。 + +2. 实现方法 + +void UBSHcomNetDriver::OobIpAndPort(const std::string &ip, uint16_t port) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ip | const std::string | 入参 | IP。该字段内部有系统函数对IP有效性进行校验。 | +| port | uint16_t | 入参 | 端口号。范围值\[1024, 65535\]。 | + +4. 返回值 + +无 + +##### UBSHcomNetDriver::GetOobIpAndPort + +1. 函数定义 + +得到UBSHcomNetDriver对象的OOB的IP和Port。 + +2. 实现方法 + +bool UBSHcomNetDriver::GetOobIpAndPort(std::vector\\> &result) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| result | std::vector\\> | 出参 | IP和Port的数组。 | + +4. 返回值 + +返回值为true则表示成功。 + +##### UBSHcomNetDriver::AddOobOptions + +1. 函数定义 + +设置UBSHcomNetDriver对象的OOB的IP和Port。 + +2. 实现方法 + +void UBSHcomNetDriver::AddOobOptions(const UBSHcomNetOobListenerOptions &option) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| option | const [UBSHcomNetOobListenerOptions](#ZH-CN_TOPIC_0000002498615633) | 入参 | IP和Port的数组。 | + +4. 返回值 + +无 + +##### UBSHcomNetDriver::OobUdsName + +1. 函数定义 + +给UBSHcomNetDriver对象设置的OOB type为UDS时的name。当此UBSHcomNetDriver是server时,UBSHcomNetDriver会监听此name,且此方法在可以多次调用,会同时监听多个name;当此UBSHcomNetDriver是client时,client在Connect时会默认向此name建链。 + +2. 实现方法 + +void UBSHcomNetDriver::OobUdsName(const std::string &name) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------|----------|-------------------------------------| +| name | std::string | 入参 | 需要设置的name。长度范围是(0, 96)。 | + +4. 返回值 + +无 + +##### UBSHcomNetDriver::AddOobUdsOptions + +1. 函数定义 + +给UBSHcomNetDriver对象设置的OOB type为UDS时的name和一些参数。 + +2. 实现方法 + +void UBSHcomNetDriver::AddOobUdsOptions(const UBSHcomNetOobUDSListenerOptions &option) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| option | [UBSHcomNetOobUDSListenerOptions](#ZH-CN_TOPIC_0000002465536522) | 入参 | 需要设置的UDS参数。 | + +4. 返回值 + +无 + +##### UBSHcomNetDriver::RegisterNewEPHandler + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁driver及相关的资源。 + +1. 函数定义 + +为从Client端连接的新链接注册回调,只需要在Server端注册。 + +2. 实现方法 + +void UBSHcomNetDriver::RegisterNewEPHandler(const UBSHcomNetDriverNewEndPointHandler &handler) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|------------------------------------|----------|------------| +| handler | UBSHcomNetDriverNewEndPointHandler | 入参 | 回调函数。 | + +4. 返回值 + +无 + +![](media/image8.png) + +数据类型解释如下: + +using UBSHcomNetDriverNewEndPointHandler =std::function\. + +##### UBSHcomNetDriver::RegisterEPBrokenHandler + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁driver及相关的资源。 + +1. 函数定义 + +给UBSHcomNetDriver对象设置EP断链回调函数。 + +2. 实现方法 + +void UBSHcomNetDriver::RegisterEPBrokenHandler(const UBSHcomNetDriverEndpointBrokenHandler &handler) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|---------------------------------------|----------|----------------| +| handler | UBSHcomNetDriverEndpointBrokenHandler | 入参 | 断链回调函数。 | + +4. 返回值 + +无 + +![](media/image8.png) + +数据类型解释如下: + +using UBSHcomNetDriverEndpointBrokenHandler = std::function\. + +##### UBSHcomNetDriver::RegisterNewReqHandler + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁driver及相关的资源。 + +1. 函数定义 + +注册接收到对方请求的回调。 + +2. 实现方法 + +void UBSHcomNetDriver::RegisterNewReqHandler(const UBSHcomNetDriverReceivedHandler &handler) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|---------------------------------|----------|------------| +| handler | UBSHcomNetDriverReceivedHandler | 入参 | 回调函数。 | + +4. 返回值 + +无 + +![](media/image8.png) + +数据类型解释如下: + +using UBSHcomNetDriverReceivedHandler = std::function\. + +##### UBSHcomNetDriver::RegisterReqPostedHandler + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁driver及相关的资源。 + +1. 函数定义 + +注册将请求发送到对端的回调。 + +2. 实现方法 + +void UBSHcomNetDriver::RegisterReqPostedHandler(const UBSHcomNetDriverSentHandler &handler) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|-----------------------------|----------|------------| +| handler | UBSHcomNetDriverSentHandler | 入参 | 回调函数。 | + +4. 返回值 + +无 + +![](media/image8.png) + +数据类型解释如下: + +using UBSHcomNetDriverSentHandler = std::function\. + +##### UBSHcomNetDriver::RegisterOneSideDoneHandler + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁driver及相关的资源。 + +1. 函数定义 + +注册单边操作完成的回调。 + +2. 实现方法 + +void UBSHcomNetDriver::RegisterOneSideDoneHandler(const UBSHcomNetDriverOneSideDoneHandler &handler) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|------------------------------------|----------|------------| +| handler | UBSHcomNetDriverOneSideDoneHandler | 入参 | 回调函数。 | + +4. 返回值 + +无 + +![](media/image8.png) + +数据类型解释如下: + +using UBSHcomNetDriverOneSideDoneHandler = std::function\. + +##### UBSHcomNetDriver::RegisterIdleHandler + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁driver及相关的资源。 + +1. 函数定义 + +给UBSHcomNetDriver对象设置worker闲时回调函数。 + +2. 实现方法 + +void UBSHcomNetDriver::RegisterIdleHandler(const UBSHcomNetDriverIdleHandler &handler) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|-----------------------------|----------|----------------| +| handler | UBSHcomNetDriverIdleHandler | 入参 | 闲时回调函数。 | + +4. 返回值 + +无 + +![](media/image8.png) + +数据类型解释如下: + +using UBSHcomNetDriverIdleHandler = std::function\. + +##### UBSHcomNetDriver::Name + +1. 函数定义 + +得到UBSHcomNetDriver对象的name。 + +2. 实现方法 + +const std::string &UBSHcomNetDriver::Name() const + +3. 参数说明 + +无 + +4. 返回值 + +返回UBSHcomNetDriver对象的name。 + +##### UBSHcomNetDriver::GetId + +1. 函数定义 + +得到UBSHcomNetDriver对象的index。 + +2. 实现方法 + +uint8_t UBSHcomNetDriver::GetId() const + +3. 参数说明 + +无 + +4. 返回值 + +返回UBSHcomNetDriver对象的index。 + +##### UBSHcomNetDriver::Protocol + +1. 函数定义 + +得到UBSHcomNetDriver对象的通信协议。 + +2. 实现方法 + +UBSHcomNetDriverProtocol UBSHcomNetDriver::Protocol() const + +3. 参数说明 + +无 + +4. 返回值 + +返回UBSHcomNetDriver对象的[通信协议](#ZH-CN_TOPIC_0000002498495129)。 + +##### UBSHcomNetDriver::IsStarted + +1. 函数定义 + +得到UBSHcomNetDriver对象的是否启动。 + +2. 实现方法 + +bool UBSHcomNetDriver::IsStarted() const + +3. 参数说明 + +无 + +4. 返回值 + +返回UBSHcomNetDriver对象的是否启动。 + +##### UBSHcomNetDriver::IsInited + +1. 函数定义 + +得到UBSHcomNetDriver对象的是否初始化。 + +2. 实现方法 + +bool UBSHcomNetDriver::IsInited() const + +3. 参数说明 + +无 + +4. 返回值 + +返回UBSHcomNetDriver对象的是否初始化。 + +##### UBSHcomNetDriver::NetUid + +1. 函数定义 + +通过UBSHcomNetDriver对象获得一个新的UID。 + +2. 实现方法 + +uint64_t UBSHcomNetDriver::NetUid() const + +3. 参数说明 + +无 + +4. 返回值 + +返回一个UID。 + +##### UBSHcomNetDriver::DumpObjectStatistics + +1. 函数定义 + +得到UBSHcomNetDriver对象的各项成员变量的引用计数。 + +2. 实现方法 + +static void UBSHcomNetDriver::DumpObjectStatistics() + +3. 参数说明 + +无 + +4. 返回值 + +返回UBSHcomNetDriver对象的各项成员变量的引用计数。 + +##### UBSHcomNetDriver::SetPeerDevId + +1. 函数定义 + +给UBSHcomNetDriver对象设置对端RDMA设备索引。 + +2. 实现方法 + +void UBSHcomNetDriver::SetPeerDevId(uint8_t index) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|--------------------------------------------| +| index | uint8_t | 入参 | 对端RDMA设备索引。范围是\[0, UINT8_MAX\]。 | + +4. 返回值 + +无 + +##### UBSHcomNetDriver::GetPeerDevId + +1. 函数定义 + +得到UBSHcomNetDriver对象的对端RDMA设备索引。 + +2. 实现方法 + +uint8_t UBSHcomNetDriver::GetPeerDevId() const + +3. 参数说明 + +无 + +4. 返回值 + +返回UBSHcomNetDriver对象的对端RDMA设备索引。 + +##### UBSHcomNetDriver::SetDeviceId + +1. 函数定义 + +给UBSHcomNetDriver对象设置RDMA设备索引。 + +2. 实现方法 + +void UBSHcomNetDriver::SetDeviceId(uint8_t index) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|----------------------------------------| +| index | uint8_t | 入参 | RDMA设备索引。范围是\[0, UINT8_MAX\]。 | + +4. 返回值 + +无 + +##### UBSHcomNetDriver::GetDeviceId + +1. 函数定义 + +得到UBSHcomNetDriver对象的RDMA设备索引。 + +2. 实现方法 + +uint8_t UBSHcomNetDriver::GetDeviceId() const + +3. 参数说明 + +无 + +4. 返回值 + +返回UBSHcomNetDriver对象的RDMA设备索引。 + +##### UBSHcomNetDriver::GetBandWidth + +1. 函数定义 + +得到UBSHcomNetDriver对象的带宽。 + +2. 实现方法 + +uint8_t UBSHcomNetDriver::GetBandWidth() const + +3. 参数说明 + +无 + +4. 返回值 + +返回UBSHcomNetDriver对象的带宽。 + +##### UBSHcomNetDriver::OobEidAndJettyId + +1. 函数定义 + +暂时只支持UBC协议时使用,传入对应的EID用于公知jetty自举建链 + +2. 实现方法 + +void UBSHcomNetDriver::OobEidAndJettyId(const std::string &eid, uint16_t id) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------|----------|-------------| +| eid | const std::string | 入参 | eid字符串。 | +| id | uint16_t | 入参 | port。 | + +4. 返回值 + +无 + +##### UBSHcomNetEndpoint::SetEpOption + +1. 函数定义 + +暂时只支持TCP协议时使用,TCP通信默认为非阻塞通信,可以将此EP设置为阻塞通信。 + +2. 实现方法 + +NResult UBSHcomNetEndpoint::SetEpOption(UBSHcomEpOptions &epOptions) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| epOptions | [UBSHcomEpOptions](#ZH-CN_TOPIC_0000002498496065) | 入参 | 链路设置选项。 | + +4. 返回值 + +返回值为0则表示设置成功。 + +##### UBSHcomNetEndpoint::GetSendQueueCount + +1. 函数定义 + +得到此EP正在使用的发送队列大小。 + +2. 实现方法 + +uint32_t UBSHcomNetEndpoint::GetSendQueueCount() + +3. 参数说明 + +无 + +4. 返回值 + +返回正在使用的发送队列大小。 + +##### UBSHcomNetEndpoint::Id + +1. 函数定义 + +得到此EP的ID。 + +2. 实现方法 + +uint64_t UBSHcomNetEndpoint::Id() const + +3. 参数说明 + +无 + +4. 返回值 + +返回此EP的ID。 + +##### UBSHcomNetEndpoint::WorkerIndex + +1. 函数定义 + +得到此EP所在的worker的索引。 + +2. 实现方法 + +const UBSHcomNetWorkerIndex &UBSHcomNetEndpoint::WorkerIndex() const + +3. 参数说明 + +无 + +4. 返回值 + +返回此EP所在的worker的索引。 + +##### UBSHcomNetEndpoint::IsEstablished + +1. 函数定义 + +得到此EP的状态是否为已创建。 + +2. 实现方法 + +bool UBSHcomNetEndpoint::IsEstablished() + +3. 参数说明 + +无 + +4. 返回值 + +返回此EP的状态是否为已创建。 + +##### UBSHcomNetEndpoint::UpCtx + +1. 函数定义 + +用于设置此EP的上层上下文,其中储存用户所需的数据指针,在回调函数中可以被得到。 + +2. 实现方法 + +void UBSHcomNetEndpoint::UpCtx(uint64_t ctx) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|-----------------------------------------------| +| ctx | uint64_t | 入参 | 用户上下文数据指针。范围是\[0, UINT64_MAX\]。 | + +4. 返回值 + +无 + +##### UBSHcomNetEndpoint::UpCtx + +1. 函数定义 + +用于得到此EP的上层上下文。 + +2. 实现方法 + +uint64_t UBSHcomNetEndpoint::UpCtx() const + +3. 参数说明 + +无 + +4. 返回值 + +返回上层上下文。 + +##### UBSHcomNetEndpoint::PeerConnectPayload + +1. 函数定义 + +用于得到此EP建链时设置的payload信息。 + +2. 实现方法 + +const std::string &UBSHcomNetEndpoint::PeerConnectPayload() const + +3. 参数说明 + +无 + +4. 返回值 + +返回payload。 + +##### UBSHcomNetEndpoint::LocalIp + +1. 函数定义 + +用于得到此EP的本端IP地址。 + +2. 实现方法 + +uint32_t UBSHcomNetEndpoint::LocalIp() const + +3. 参数说明 + +无 + +4. 返回值 + +返回此EP的本端IP地址。 + +##### UBSHcomNetEndpoint::ListenPort + +1. 函数定义 + +用于得到此EP建链时监听所用的端口。 + +2. 实现方法 + +uint16_t UBSHcomNetEndpoint::ListenPort() const + +3. 参数说明 + +无 + +4. 返回值 + +返回此EP建链时监听时所用端口。 + +##### UBSHcomNetEndpoint::Version + +1. 函数定义 + +用于得到此EP所在UBSHcomNetDriver的version。 + +2. 实现方法 + +uint8_t UBSHcomNetEndpoint::Version() const + +3. 参数说明 + +无 + +4. 返回值 + +返回此EP所在UBSHcomNetDriver的version。 + +##### UBSHcomNetEndpoint::State + +1. 函数定义 + +用于得到此EP的状态。 + +2. 实现方法 + +UBSHcomNetAtomicState\ &UBSHcomNetEndpoint::State() + +3. 参数说明 + +无 + +4. 返回值 + +返回此[EP的状态](#ZH-CN_TOPIC_0000002465535858)。 + +##### UBSHcomNetEndpoint::PeerIpAndPort + +1. 函数定义 + +用于得到此EP的对端IP和端口信息。 + +2. 实现方法 + +const std::string &UBSHcomNetEndpoint::PeerIpAndPort() + +3. 参数说明 + +无 + +4. 返回值 + +返回此EP的对端IP和端口信息。 + +##### UBSHcomNetEndpoint::UdsName + +1. 函数定义 + +仅对SHM协议有效,用于得到此EP的UDS name。 + +2. 实现方法 + +const std::string &UBSHcomNetEndpoint::UdsName() + +3. 参数说明 + +无 + +4. 返回值 + +返回此EP的UDS name。 + +##### UBSHcomNetEndpoint::PostSend + +1. 函数定义 + +- 发送一个带有opcode和header的请求给对方,并且自定义seqNo。 + +- 发送一个带有opcode和header的请求给对方,并且可以设置操作参数。 + +- 发送一个带有opcode和header的请求给对方。 + + 1. 实现方法 + +- NResult UBSHcomNetEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNo) + +- NResult UBSHcomNetEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, const UBSHcomNetTransOpInfo &opInfo) + +- NResult UBSHcomNetEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request) + + 1. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| opCode | uint16_t | 入参 | 操作码\[0, 1023\]。 | +| request | [UBSHcomNetTransRequest](#ZH-CN_TOPIC_0000002498615245) | 入参 | 发送请求信息,使用本地内存来存储数据,数据会被复制,调用后可释放本地内存。 | +| seqNo | uint32_t | 入参 | 对方要回复的seqNo必须大于0,对方可以从context.Header().seqNo中获取它;如果seqNo为0,则生成自动递增的数字。在同步发送消息的情况下,请求和响应的seqNo相等。 | +| opInfo | [UBSHcomNetTransOpInfo](#ZH-CN_TOPIC_0000002465377186) | 入参 | 此发送操作相关的参数。 | + +2. 返回值 + +返回值为0则表示发送消息成功。 + +![](media/image8.png) + +- 如果NET_EP_SELF_POLLING未设置,则只发出发送请求,不等待发送请求完成情况。 + +- 如果NET_EP_SELF_POLLING设置,则发出发送请求并等待发送到达对端。 + +##### UBSHcomNetEndpoint::PostSendRaw + +1. 函数定义 + +发送一个不带有header的请求给对方,并且自定义为seqNo,对方将触发新的请求回调,同样不带有header。 + +2. 实现方法 + +NResult UBSHcomNetEndpoint::PostSendRaw(const UBSHcomNetTransRequest &request,uint32_t seqNo) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| request | [UBSHcomNetTransRequest](#ZH-CN_TOPIC_0000002498615245) | 入参 | 发送请求信息,使用本地内存来存储数据,数据会被复制,调用后可释放本地内存。 | +| seqNo | uint32_t | 入参 | 对方要回复的seqNo必须大于0,对方可以从context.Header().seqNo中获取它;如果seqNo为0,则生成自动递增的数字。在同步发送消息的情况下,请求和响应的seqNo相等。 | + +4. 返回值 + +返回值为0则表示发送消息成功。 + +![](media/image8.png) + +- 如果NET_EP_SELF_POLLING未设置,则只发出发送请求,不等待发送请求完成情况。 + +- 如果NET_EP_SELF_POLLING设置,则发出发送请求并等待发送到达对端。 + +##### UBSHcomNetEndpoint::PostRead + +1. 函数定义 + +将单边读请求发送到对端,对端不会触发回调。 + +2. 实现方法 + +NResult UBSHcomNetEndpoint::PostRead(const UBSHcomNetTransRequest &request) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| request | [UBSHcomNetTransRequest](#ZH-CN_TOPIC_0000002498615245) | 入参 | 发送请求信息。 | + +4. 返回值 + +返回值为0则表示发送消息成功。 + +##### UBSHcomNetEndpoint::PostWrite + +1. 函数定义 + +将单边写请求发送到对端,对端不会触发回调。 + +2. 实现方法 + +NResult UBSHcomNetEndpoint::PostWrite(const UBSHcomNetTransRequest &request) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| request | [UBSHcomNetTransRequest](#ZH-CN_TOPIC_0000002498615245) | 入参 | 发送请求信息。 | + +4. 返回值 + +返回值为0则表示发送消息成功。 + +##### UBSHcomNetEndpoint::DefaultTimeout + +1. 函数定义 + +用于设置此EP的超时时间。 + +2. 实现方法 + +void UBSHcomNetEndpoint::DefaultTimeout(int32_t timeout) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| timeout | int32_t | 入参 | 超时时间,单位为秒,0为立即超时,负数为永不超时(一般设置为-1)。设置时小于或等于65536。 | + +4. 返回值 + +无 + +##### UBSHcomNetEndpoint::WaitCompletion + +1. 函数定义 + +- 等待发送/读/写完成,仅对NET_EP_SELF_POLLING设置时使用。 + +- 等待发送/读/写完成,仅对NET_EP_SELF_POLLING设置时使用。使用mDefaultTimeout作为超时时间。 + + 1. 实现方法 + +- NResult UBSHcomNetEndpoint::WaitCompletion(int32_t timeout) + +- NResult UBSHcomNetEndpoint::WaitCompletion() + + 1. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +返回值为0则表示发送消息成功。 + +![](media/image8.png) + +- 此函数用在发送后时,当请求发送到对方时被调用。 + +- 此函数用在读后时,当读完成时被调用。 + +- 此函数用在写后时,当写完成时被调用。 + +##### UBSHcomNetEndpoint::Receive + +1. 函数定义 + +- 接收对端发送的Send消息。 + +- 接收对端发送的Send消息。使用mDefaultTimeout作为超时时间。 + + 1. 实现方法 + +- NResult UBSHcomNetEndpoint::Receive(int32_t timeout, UBSHcomNetResponseContext &ctx) + +- NResult UBSHcomNetEndpoint::Receive(UBSHcomNetResponseContext &ctx) + + 1. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| timeout | int32_t | 入参 | 超时时间,单位为秒,0为立即超时,负数为永不超时(一般设置为-1)。 | +| ctx | UBSHcomNetResponseContext | 入参 | 用来存放收到消息的内容的对象。 | + +2. 返回值 + +返回值为0则表示成功。 + +##### UBSHcomNetEndpoint::ReceiveRaw + +1. 函数定义 + +- 接收对端发送的SendRaw消息。 + +- 接收对端发送的SendRaw消息。使用mDefaultTimeout作为超时时间。 + + 1. 实现方法 + +- NResult UBSHcomNetEndpoint::ReceiveRaw(int32_t timeout, UBSHcomNetResponseContext &ctx) + +- NResult UBSHcomNetEndpoint::ReceiveRaw(UBSHcomNetResponseContext &ctx) + + 1. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| timeout | int32_t | 入参 | 超时时间,单位为秒,0为立即超时,负数为永不超时(一般设置为-1)。 | +| ctx | UBSHcomNetResponseContext | 入参 | 用来存放收到消息的内容的对象。 | + +2. 返回值 + +返回值为0则表示成功。 + +##### UBSHcomNetEndpoint::GetRemoteUdsIdInfo + +1. 函数定义 + +仅支持服务端且OOB type为UDS时,查询此EP的对端UDS ID信息。 + +2. 实现方法 + +NResult UBSHcomNetEndpoint::GetRemoteUdsIdInfo(UBSHcomNetUdsIdInfo &idInfo) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| idInfo | [UBSHcomNetUdsIdInfo](#ZH-CN_TOPIC_0000002498615509) | 出参 | 对端UDS ID信息。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### UBSHcomNetEndpoint::GetPeerIpPort + +1. 函数定义 + +查询此EP对端的IP地址和端口信息。 + +2. 实现方法 + +bool UBSHcomNetEndpoint::GetPeerIpPort(std::string &ip, uint16_t &port) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|--------------| +| ip | String | 出参 | 对端IP地址。 | +| port | uint16_t | 出参 | 对端的端口。 | + +4. 返回值 + +返回值为true则表示成功。 + +##### UBSHcomNetEndpoint::Close + +![](media/image9.png) + +如果用户实现中需要主动销毁EP,要调用UBSHcomNetEndpoint::Close接口;如果需要减少EP的引用计数,可调用DecreaseRef函数。 + +1. 函数定义 + +关闭此EP。 + +2. 实现方法 + +void UBSHcomNetEndpoint::Close() + +3. 参数说明 + +无 + +4. 返回值 + +无 + +##### UBSHcomNetEndpoint::GetDevIndex + +1. 函数定义 + +仅在RDMA协议下,得到此EP的设备索引。 + +2. 实现方法 + +uint8_t UBSHcomNetEndpoint::GetDevIndex() + +3. 参数说明 + +无 + +4. 返回值 + +返回此EP的设备索引。 + +##### UBSHcomNetEndpoint::GetPeerDevIndex + +1. 函数定义 + +仅在RDMA协议下,得到此EP的对端设备索引。 + +2. 实现方法 + +uint8_t UBSHcomNetEndpoint::GetPeerDevIndex() + +3. 参数说明 + +无 + +4. 返回值 + +返回此EP的对端设备索引。 + +##### UBSHcomNetEndpoint::GetBandWidth + +1. 函数定义 + +仅在RDMA协议下,得到此EP的设备带宽。 + +2. 实现方法 + +uint8_t UBSHcomNetEndpoint::GetBandWidth() + +3. 参数说明 + +无 + +4. 返回值 + +返回此EP的设备带宽。 + +##### UBSHcomNetMessage::DataLen + +1. 函数定义 + +得到此UBSHcomNetMessage的大小。 + +2. 实现方法 + +uint32_t UBSHcomNetMessage::DataLen() const + +3. 参数说明 + +无 + +4. 返回值 + +返回此UBSHcomNetMessage的大小。 + +##### UBSHcomNetMessage::Data + +1. 函数定义 + +得到此UBSHcomNetMessage的消息。 + +2. 实现方法 + +void \*UBSHcomNetMessage::Data() const + +3. 参数说明 + +无 + +4. 返回值 + +返回此UBSHcomNetMessage的消息。 + +##### UBSHcomNetRequestContext::EndPoint + +1. 函数定义 + +在回调函数中,通过ctx参数获得此消息所关联的EP。 + +2. 实现方法 + +const UBSHcomNetEndpointPtr &UBSHcomNetRequestContext::EndPoint() const + +3. 参数说明 + +无 + +4. 返回值 + +返回此消息所关联的EP。 + +##### UBSHcomNetRequestContext::Result + +1. 函数定义 + +在回调函数中,通过ctx参数获得此次通信的结果。 + +2. 实现方法 + +NResult UBSHcomNetRequestContext::Result() const + +3. 参数说明 + +无 + +4. 返回值 + +返回此消息所关联的EP。 + +##### UBSHcomNetRequestContext::Header + +1. 函数定义 + +在回调函数中,通过ctx参数获得此次通信对端发送过来的Header。 + +2. 实现方法 + +const UBSHcomNetTransHeader &UBSHcomNetRequestContext::Header() const + +3. 参数说明 + +无 + +4. 返回值 + +返回此次通信对端发送过来的Header。 + +##### UBSHcomNetRequestContext::Message + +1. 函数定义 + +在回调函数中,通过ctx参数获得此次通信对端发送过来的消息信息。 + +2. 实现方法 + +UBSHcomNetMessage \*UBSHcomNetRequestContext::Message() const + +3. 参数说明 + +无 + +4. 返回值 + +返回此次通信对端发送过来的消息信息。 + +##### UBSHcomNetRequestContext::OpType + +1. 函数定义 + +在回调函数中,通过ctx参数获得此次通信的类型。 + +2. 实现方法 + +NN_OpType UBSHcomNetRequestContext::OpType() const + +3. 参数说明 + +无 + +4. 返回值 + +返回此次通信的类型。 + +##### UBSHcomNetRequestContext::OriginalRequest + +1. 函数定义 + +在回调函数中,通过ctx参数获得此次通信的发送消息。 + +2. 实现方法 + +const UBSHcomNetTransRequest &UBSHcomNetRequestContext::OriginalRequest() const + +3. 参数说明 + +无 + +4. 返回值 + +返回此次通信的发送消息。 + +##### UBSHcomNetRequestContext::OriginalSgeRequest + +1. 函数定义 + +在回调函数中,通过ctx参数获得此次通信的SGL发送消息。 + +2. 实现方法 + +const UBSHcomNetTransSglRequest &UBSHcomNetRequestContext::OriginalSgeRequest() const + +3. 参数说明 + +无 + +4. 返回值 + +返回此次通信的SGL发送消息。 + +##### UBSHcomNetRequestContext::SafeClone + +1. 函数定义 + +在回调函数中,将ctx信息进行拷贝。 + +2. 实现方法 + +static bool UBSHcomNetRequestContext::SafeClone(const UBSHcomNetRequestContext &old, const UBSHcomNetRequestContextPtr &newOne) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-----------------------------|----------|-------------------| +| old | UBSHcomNetRequestContext | 入参 | 需要被拷贝的ctx。 | +| newOne | UBSHcomNetRequestContextPtr | 出参 | 拷贝得到的ctx。 | + +4. 返回值 + +返回true则成功,会拷贝其中的EP,mHeader,mOpType信息。 + +##### UBSHcomNetResponseContext::Header + +1. 函数定义 + +得到NetResponseContext中的header。 + +2. 实现方法 + +const UBSHcomNetTransHeader &UBSHcomNetResponseContext::Header() const + +3. 参数说明 + +无 + +4. 返回值 + +返回NetResponseContext中的header。 + +##### UBSHcomNetResponseContext::Message + +1. 函数定义 + +得到NetResponseContext中的消息。 + +2. 实现方法 + +UBSHcomNetMessage \*UBSHcomNetResponseContext::Message() const + +3. 参数说明 + +无 + +4. 返回值 + +返回NetResponseContext中的消息。 + +##### UBSHcomNetMemoryRegion::GetLKey + +1. 函数定义 + +获得MR的local key。 + +2. 实现方法 + +uint32_t UBSHcomNetMemoryRegion::GetLKey() const + +3. 参数说明 + +无 + +4. 返回值 + +返回local key。 + +##### UBSHcomNetMemoryRegion::GetAddress + +1. 函数定义 + +获得MR的内存地址。 + +2. 实现方法 + +uintptr_t UBSHcomNetMemoryRegion::GetAddress() const + +3. 参数说明 + +无 + +4. 返回值 + +返回内存地址。 + +##### UBSHcomNetMemoryRegion::Size + +1. 函数定义 + +获得MR的内存大小。 + +2. 实现方法 + +uint64_t UBSHcomNetMemoryRegion::Size() const + +3. 参数说明 + +无 + +4. 返回值 + +返回内存大小。 + +##### UBSHcomNetMemoryAllocator::Create + +1. 函数定义 + +创建一个内存分配器。 + +2. 实现方法 + +static NResult UBSHcomNetMemoryAllocator::Create(UBSHcomNetMemoryAllocatorType t, const UBSHcomNetMemoryAllocatorOptions &options, UBSHcomNetMemoryAllocatorPtr &out) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| t | [UBSHcomNetMemoryAllocatorType](#ZH-CN_TOPIC_0000002498615653) | 入参 | 分配器类型。 | +| options | [UBSHcomNetMemoryAllocatorOptions](#ZH-CN_TOPIC_0000002498615221) | 入参 | 分配器参数。 | +| out | UBSHcomNetMemoryAllocatorPtr | 出参 | 创建的分配器指针。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### UBSHcomNetMemoryAllocator::MrKey + +1. 函数定义 + +得到分配器的memory region key。 + +2. 实现方法 + +uint32_t UBSHcomNetMemoryAllocator::MrKey() const + +3. 参数说明 + +无 + +4. 返回值 + +返回分配器的memory region key。 + +##### UBSHcomNetMemoryAllocator::MrKey + +1. 函数定义 + +给分配器设置memory region key。 + +2. 实现方法 + +void UBSHcomNetMemoryAllocator::MrKey(uint32_t mrKey) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|---------------------------------------------| +| mrKey | uint32_t | 入参 | memory region key。范围值(0, UINT32_MAX\]。 | + +4. 返回值 + +无 + +##### UBSHcomNetMemoryAllocator::MemOffset + +1. 函数定义 + +得到地址在分配器内存的偏移值。 + +2. 实现方法 + +uintptr_t UBSHcomNetMemoryAllocator::MemOffset(uintptr_t address) const + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|-----------|----------|------------| +| address | uintptr_t | 入参 | 内存地址。 | + +4. 返回值 + +返回偏移值。 + +##### UBSHcomNetMemoryAllocator::FreeSize + +1. 函数定义 + +得到分配器剩余的内存大小。 + +2. 实现方法 + +uint64_t UBSHcomNetMemoryAllocator::FreeSize() const + +3. 参数说明 + +无 + +4. 返回值 + +返回分配器剩余的内存大小。 + +##### UBSHcomNetMemoryAllocator::Allocate + +1. 函数定义 + +从内存分配器中分配出指定大小的内存。 + +2. 实现方法 + +NResult UBSHcomNetMemoryAllocator::Allocate(uint64_t size, uintptr_t &outAddress) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|------------|-----------|----------|----------------------| +| size | uint64_t | 入参 | 需要分配的内存大小。 | +| outAddress | uintptr_t | 出参 | 分配的内存地址。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### UBSHcomNetMemoryAllocator::Free + +1. 函数定义 + +将从内存分配器中分配的内存释放给分配器。 + +2. 实现方法 + +NResult UBSHcomNetMemoryAllocator::Free(uintptr_t address) + +![](media/image8.png) + +使用时防止相同address多次调用该函数。 + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|-----------|----------|----------------------| +| address | uintptr_t | 入参 | 需要释放的内存地址。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### UBSHcomNetMemoryAllocator::Destroy + +1. 函数定义 + +当编译选项BUILD_WITH_ALLOCATOR_PROTECTION为ON时,去掉内存保护。 + +2. 实现方法 + +void UBSHcomNetMemoryAllocator::Destroy() + +3. 参数说明 + +无 + +4. 返回值 + +无 + +##### UBSHcomNetMemoryAllocator::GetTargetSeg + +1. 函数定义 + +获得SEG。 + +2. 实现方法 + +void \*UBSHcomNetMemoryAllocator::GetTargetSeg() + +3. 参数说明 + +无 + +4. 返回值 + +void \* + +##### UBSHcomNetMemoryAllocator::SetTargetSeg + +1. 函数定义 + +设置SEG。 + +2. 实现方法 + +void UBSHcomNetMemoryAllocator::SetTargetSeg(void \*targetSeg) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|----------|----------|-----------------| +| targetSeg | void \* | 入参 | 需要设置的SEG。 | + +4. 返回值 + +无 + +##### UBSHcomNetMemoryAllocatorTypeToString + +1. 函数定义 + +将内存类型转化成字符串。 + +2. 实现方法 + +std::string &UBSHcomNetMemoryAllocatorTypeToString(UBSHcomNetMemoryAllocatorType v) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------------------|----------|------------| +| v | UBSHcomNetMemoryAllocatorType | 入参 | 内存类型。 | + +4. 返回值 + +字符串{"Dynamic size allocator", "Dynamic size allocator with cache","UNKNOWN ALLOCATOR TYPE"} + +##### UBSHcomNetDriverProtocolToString + +1. 函数定义 + +将协议类型转化成字符串。 + +2. 实现方法 + +std::string &UBSHcomNetDriverProtocolToString(UBSHcomNetDriverProtocol v) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|--------------------------|----------|------------| +| v | UBSHcomNetDriverProtocol | 入参 | 协议类型。 | + +4. 返回值 + +字符串{"RDMA", "TCP", "UDS", "SHM", "RDMA_MLX5", "UB", "UBOE", "UBC", "HSHMEM","UNKNOWN PROTOCOL"} + +##### UBSHcomNetDriverSecTypeToString + +1. 函数定义 + +将安全校验类型转化成字符串。 + +2. 实现方法 + +std::string &UBSHcomNetDriverSecTypeToString(UBSHcomNetDriverSecType v) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------------|----------|----------------| +| v | UBSHcomNetDriverSecType | 入参 | 安全校验类型。 | + +4. 返回值 + +字符串{"SecNoValid", "SecValidOneWay", "SecValidTwoWay", "UNKNOWN SEC TYPE"} + +##### UBSHcomNetDriverOobTypeToString + +1. 函数定义 + +oob类型转化成字符串。 + +2. 实现方法 + +std::string &UBSHcomNetDriverOobTypeToString(NetDriverOobType v) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|------------------|----------|----------------------| +| v | NetDriverOobType | 入参 | UBSHcomOobType类型。 | + +4. 返回值 + +字符串{"Tcp", "UDS", "URMA","UNKNOWN OOB TYPE"} + +##### UBSHcomNetDriverLBPolicyToString + +1. 函数定义 + +将负载均衡类型转化成字符串。 + +2. 实现方法 + +std::string &UBSHcomNetDriverLBPolicyToString(UBSHcomNetDriverLBPolicy v) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|--------------------------|----------|----------------| +| v | UBSHcomNetDriverLBPolicy | 入参 | 负载均衡类型。 | + +4. 返回值 + +字符串{"RR", "Hash","UNKNOWN POLICY" } + +##### UBSHcomNEPStateToString + +1. 函数定义 + +将EndPoint状态类型转化成字符串。 + +2. 实现方法 + +std::string &UBSHcomNEPStateToString(UBSHcomNetEndPointState v) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------------|----------|--------------------| +| v | UBSHcomNetEndPointState | 入参 | EndPoint状态类型。 | + +4. 返回值 + +字符串{"new", "established", "broken","UNKNOWN EP STATE"} + +### C API + +#### 服务层API + +##### ubs_hcom_service_create + +1. 函数定义 + +根据类型和名字创建一个服务层的NetService对象。 + +2. 实现方法 + +int ubs_hcom_service_create(ubs_hcom_service_type t, const char \*name, ubs_hcom_service_options options, ubs_hcom_service \*service); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| t | [ubs_hcom_service_type](#ZH-CN_TOPIC_0000002465536674) | 入参 | ubs_hcom_service协议类型。 | +| name | const char \* | 入参 | ubs_hcom_service的名字。长度范围\[1, 64\],只能包含数字、字母、‘\_’和‘-’。 | +| options | ubs_hcom_service_options | 入参 | Service配置项。 | +| service | ubs_hcom_service | 出参 | 表示创建的ubs_hcom_service对象,如果创建失败返回空。 | + +4. 返回值 + +返回值为0则表示发送消息成功。 + +##### ubs_hcom_service_bind + +1. 函数定义 + +根据类型和名字创建一个服务层的NetService对象。 + +2. 实现方法 + +int ubs_hcom_service_bind(ubs_hcom_service service, const char \*listenerUrl, ubs_hcom_service_channel_handler h); + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +返回值为0则表示bind成功。 + +##### ubs_hcom_service_start + +1. 函数定义 + +根据类型和名字创建一个服务层的NetService对象。 + +2. 实现方法 + +int ubs_hcom_service_start(ubs_hcom_service service); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|------------------|----------|----------------------------------| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | + +4. 返回值 + +返回值为0则表示创建成功。 + +##### ubs_hcom_service_destroy + +1. 函数定义 + +销毁服务,会清理全局map根据名字销毁对象。 + +2. 实现方法 + +int ubs_hcom_service_destroy(ubs_hcom_service service, const char \*name); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|---------------|----------|--------------------------------------| +| service | Net_Service | 入参 | 需要销毁的ubs_hcom_service对象。 | +| name | const char \* | 入参 | 需要销毁的ubs_hcom_service对象名字。 | + +##### ubs_hcom_service_connect + +1. 函数定义 + +建立与远程服务器的连接,并返回连接通道。 + +2. 实现方法 + +int ubs_hcom_service_connect(ubs_hcom_service service, const char \*serverUrl, ubs_hcom_channel \*channel, Service_UBSHcomConnectOptions options); + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +返回值为0则表示建链成功。 + +##### ubs_hcom_service_disconnect + +1. 函数定义 + +切断与远程服务器的连接。 + +2. 实现方法 + +int ubs_hcom_service_disconnect(ubs_hcom_service service, ubs_hcom_channel channel); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|-------------|----------|----------------------------------| +| service | Net_Service | 入参 | 之前创建的ubs_hcom_service对象。 | +| channel | Net_Channel | 入参 | 建链生成的连接通道NetChannel。 | + +4. 返回值 + +返回值为0则表示断链成功。 + +##### ubs_hcom_service_register_memory_region + +1. 函数定义 + +注册一个内存区域,内存将在UBS Comm内部分配。 + +2. 实现方法 + +int ubs_hcom_service_register_memory_region(ubs_hcom_service service, uint64_t size, ubs_hcom_memory_region \*mr); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| service | ubs_hcom_service | 入参 | 初始化创建的ubs_hcom_service对象。 | +| size | uint64_t | 入参 | 需要注册的内存大小,单位byte。范围为(0, 107374182400\]。 | +| mr | ubs_hcom_memory_region | 入参 | 内存区域结构,包含key、名字、大小、buf等字段。 | + +![](media/image8.png) + +若需要放入pgTable管理(通过ubs_hcom_service_set_enable_mrcache设置为true,默认不放入),则要求首地址(startAddress)和尾地址(startAddress+size)都需要16字节对齐,因此用户申请的size需要能16整除。 + +4. 返回值 + +表示函数执行结果,返回值为0则表示注册成功。 + +##### ubs_hcom_service_get_memory_region_info + +1. 函数定义 + +获得mr的内容。 + +2. 实现方法 + +int ubs_hcom_service_get_memory_region_info(ubs_hcom_memory_region mr, ubs_hcom_mr_info \*info); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| mr | ubs_hcom_memory_region | 入参 | mr对象。 | +| info | [ubs_hcom_mr_info](#ZH-CN_TOPIC_0000002465377214) | 出参 | mr的信息。 | + +4. 返回值 + +表示函数执行结果,返回值为0则表示注册成功。 + +##### ubs_hcom_service_register_assign_memory_region + +1. 函数定义 + +注册一个内存区域,内存将在UBS Comm外部分配。 + +2. 实现方法 + +int ubs_hcom_service_register_assign_memory_region(ubs_hcom_service service, uintptr_t address, uint64_t size, ubs_hcom_memory_region \*mr); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| service | ubs_hcom_service | 入参 | 初始化创建的ubs_hcom_service对象。 | +| address | uintptr_t | 入参 | 外部申请的内存地址。 | +| size | uint64_t | 入参 | 外部申请的内存大小,单位byte。范围为(0, 1099511627776\]。 | +| mr | ubs_hcom_memory_region | 出参 | 内存区域结构,包含key、名字、大小、buf等字段。 | + +4. 返回值 + +表示函数执行结果,返回值为0则表示注册成功。 + +##### ubs_hcom_service_destroy_memory_region + +1. 函数定义 + +销毁一个内存区域,内存将在UBS Comm内部分配。 + +2. 实现方法 + +int ubs_hcom_service_destroy_memory_region(ubs_hcom_service service, ubs_hcom_memory_region mr); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| service | ubs_hcom_service | 入参 | 初始化创建的ubs_hcom_service对象。 | +| mr | ubs_hcom_memory_region | 入参 | 内存区域结构,包含key、名字、大小、buf等字段。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_register_broken_handler + +Service_RegisterChannelHandler + +![](media/image9.png) + +用户注册的回调函数,不能销毁Service及相关的资源。 + +1. 函数定义 + +注册通道Channel的回调函数,以处理通道建链和断连事件。 + +2. 实现方法 + +void ubs_hcom_service_register_broken_handler(ubs_hcom_service service, ubs_hcom_service_channel_handler h, + +ubs_hcom_service_channel_policy policy, uint64_t usrCtx); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| t | [ubs_hcom_service_handler_type](#ZH-CN_TOPIC_0000002465536462) | 入参 | 句柄的类型。 | +| h | ubs_hcom_service_channel_handler | 入参 | 回调函数句柄。 | +| policy | [ubs_hcom_service_channel_policy](#ZH-CN_TOPIC_0000002465536134) | 入参 | 链路断开时的策略,策略可选。 | +| usrCtx | uint64_t | 入参 | 用户上下文。 | + +4. 返回值 + +uintptr_t,返回内部句柄地址。 + +##### ubs_hcom_service_register_idle_handler + +![](media/image9.png) + +用户注册的回调函数,不能销毁Service及相关的资源。 + +1. 函数定义 + +设置NetService的worker闲时回调函数。 + +2. 实现方法 + +void ubs_hcom_service_register_idle_handler(ubs_hcom_service service, ubs_hcom_service_idle_handler h, uint64_t usrCtx); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| service | ubs_hcom_service | 入参 | ubs_hcom_service对象。 | +| h | ubs_hcom_service_idle_handler | 入参 | worker闲时回调函数。 | +| usrCtx | uint64_t | 入参 | 用户上下文,可以在回调函数中使用。 | + +4. 返回值 + +内部回调函数地址。 + +![](media/image8.png) + +数据类型解释如下: + +typedef void (\*ubs_hcom_service_idle_handler)(uint8_t wkrGrpIdx, uint16_t idxInGrp, uint64_t usrCtx). + +##### ubs_hcom_service_register_handler + +![](media/image9.png) + +- 用户注册的回调函数,不能销毁Service及相关的资源。 + +- 用户需要避免在该回调中死等发送完成事件,应添加超时时间,否则会造成死锁。 + +- 用户需要尽量避免在该回调中占用过长时间处理业务,以免影响性能。 + + 1. 函数定义 + +注册回调函数,以处理通道双边发送完成、单边发送完成、双边收消息事件。 + +2. 实现方法 + +void ubs_hcom_service_register_handler(ubs_hcom_service service, ubs_hcom_service_handler_type t, ubs_hcom_service_request_handler h, + +uint64_t usrCtx); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| t | ubs_hcom_service_handler_type | 入参 | 句柄的类型。 | +| h | ubs_hcom_service_request_handler | 入参 | 回调函数句柄。 | +| usrCtx | uint64_t | 入参 | 用户上下文。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_enable_mrcache + +![](media/image9.png) + +用户需要使用RNDV,则需要设置为true。 + +1. 函数定义 + +设置RegisterMemoryRegion是否将mr放入pgTable管理。 + +2. 实现方法 + +void ubs_hcom_service_set_enable_mrcache(ubs_hcom_service service, bool enableMrCache); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------------|------------------|----------|----------------------------------| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| enableMrCache | bool | 入参 | mr放入pgTable管理标志位。 | + +4. 返回值 + +无 + +##### ubs_hcom_channel_refer + +1. 函数定义 + +将此NetChannel增加一次引用计数。 + +2. 实现方法 + +void ubs_hcom_channel_refer(Net_Channel channel) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|-------------|----------|--------------------------------| +| channel | Net_Channel | 入参 | 需要增加引用计数的NetChannel。 | + +4. 返回值 + +无 + +##### ubs_hcom_channel_derefer + +1. 函数定义 + +将此NetChannel减少一次引用计数。 + +2. 实现方法 + +void ubs_hcom_channel_derefer(Net_Channel channel) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|-------------|----------|--------------------------------| +| channel | Net_Channel | 入参 | 需要减少引用计数的NetChannel。 | + +4. 返回值 + +无 + +##### ubs_hcom_channel_send + +![](media/image9.png) + +- 若使用拆包和rndv的功能,需要通过UBSHcomChannel::SetUBSHcomTwoSideThreshold设置拆包和rndv的阈值。 + +- 使用rndv,则需要创建service后将UBSHcomService::SetEnableMrCache设置为true(UBSHcomService::RegisterMemoryRegion函数调用前)。 + + 1. 函数定义 + +发送双边消息,不需要对端回复。 + +2. 实现方法 + +int ubs_hcom_channel_send(ubs_hcom_channel channel, ubs_hcom_channel_request req, ubs_hcom_channel_callback \*cb); + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_channel_call + +![](media/image9.png) + +- rsp中若address字段填了有效内存地址,则用户回复的信息会被拷贝到该地址上。 + +- 若address==NULL,则UBS Comm会通过malloc申请内存,但用户需要自行维护该内存的生命周期,在使用完后通过free释放。 + +- 若使用拆包和rndv的功能,需要通过UBSHcomChannel::SetUBSHcomTwoSideThreshold设置拆包和rndv的阈值。 + +- 使用rndv,则需要创建service后将UBSHcomService::SetEnableMrCache设置为true(UBSHcomService::RegisterMemoryRegion函数调用前)。 + + 1. 函数定义 + +发送双边消息并等待回复,需要对端配合Reply使用。 + +2. 实现方法 + +int ubs_hcom_channel_call(ubs_hcom_channel channel, ubs_hcom_channel_request req, ubs_hcom_channel_response \*rsp, ubs_hcom_channel_callback \*cb); + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_channel_reply + +1. 函数定义 + +回复双边消息,接收端配合Call使用 + +2. 实现方法 + +int ubs_hcom_channel_reply(ubs_hcom_channel channel, ubs_hcom_channel_request req, ubs_hcom_channel_reply_context ctx, ubs_hcom_channel_callback \*cb); + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_channel_put + +1. 数据定义 + +发送单边写请求。 + +2. 实现方法 + +int ubs_hcom_channel_put(ubs_hcom_channel channel, ubs_hcom_oneside_request req, ubs_hcom_channel_callback \*cb); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| channel | ubs_hcom_channel | 入参 | 创建的channel对象。 | +| req | [ubs_hcom_oneside_request](#ubs_hcom_oneside_request) | 入参 | 单边请求。 | +| cb | [ubs_hcom_channel_callback](#channel_callback) | 入参 | 异步请求回调函数。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_channel_get + +1. 函数定义 + +发送单边读请求。 + +2. 实现方法 + +int ubs_hcom_channel_get(ubs_hcom_channel channel, ubs_hcom_oneside_request req, ubs_hcom_channel_callback \*cb); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| channel | ubs_hcom_channel | 入参 | 创建的channel对象。 | +| req | [ubs_hcom_oneside_request](#ubs_hcom_oneside_request) | 入参 | 单边请求。 | +| cb | [ubs_hcom_channel_callback](#channel_callback) | 入参 | 异步请求回调函数。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_channel_recv + +1. 函数定义 + +只用于接收RNDV请求。 + +2. 实现方法 + +int ubs_hcom_channel_recv(ubs_hcom_channel channel, ubs_hcom_service_context ctx, uintptr_t address, uint32_t size, ubs_hcom_channel_callback \*cb); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|---------------------|----------|---------------------| +| channel | ubs_hcom_channel | 入参 | 创建的channel对象。 | +| ctx | ubs_hcom_service_context | 入参 | 上下文。 | +| address | uintptr_t | 入参 | 接收数据地址。 | +| size | uint32_t | 入参 | 接收数据大小。 | +| cb | ubs_hcom_channel_callback \* | 入参 | 异步请求回调。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_channel_set_flowctl_cfg + +1. 函数定义 + +给此NetChannel设置流控参数。 + +2. 实现方法 + +int Channel_ConfigFlowControl(ubs_hcom_channel channel, ubs_hcom_flowctl_opts options) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| channel | ubs_hcom_channel | 入参 | 通信渠道。 | +| options | [ubs_hcom_flowctl_opts](#ZH-CN_TOPIC_0000002465536490) | 入参 | 流控参数。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_channel_set_timeout + +1. 函数定义 + +设置NetChannel的双边超时时间。 + +2. 实现方法 + +void Channel_SetTwoSideTimeout(Net_Channel channel, int32_t timeout) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| channel | Net_Channel | 入参 | NetChannel。 | +| timeout | int32_t | 入参 | 超时时间,单位为秒,0为立即超时,负数为永不超时(一般设置为-1)。范围是\[-1, INT16_MAX\]。 | + +4. 返回值 + +无 + +##### ubs_hcom_channel_set_twoside_threshold + +1. 函数定义 + +设置拆包和rndv的阈值。 + +2. 实现方法 + +int ubs_hcom_channel_set_twoside_threshold(ubs_hcom_channel channel, Channel_UBSHcomTwoSideThreshold threshold); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|---------------------------------|----------|------------------| +| channel | Net_Channel | 入参 | NetChannel。 | +| threshold | Channel_UBSHcomTwoSideThreshold | 入参 | 拆包和rndv阈值。 | + +4. 返回值 + +无 + +##### Channel_Close + +![](media/image9.png) + +如果用户实现中需要主动销毁channel,要调用Channel_Close和Channel_Destroy接口。 + +1. 函数定义 + +关闭NetChannel。 + +2. 实现方法 + +void Channel_Close(Net_Channel channel) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|-------------|----------|--------------| +| channel | Net_Channel | 入参 | NetChannel。 | + +4. 返回值 + +无 + +##### ubs_hcom_channel_get_id + +1. 函数定义 + +获取channelId。 + +2. 实现方法 + +int ubs_hcom_channel_get_id(ubs_hcom_channel channel); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|------------------|----------|-------------------| +| channel | ubs_hcom_channel | 入参 | 通信channel对象。 | + +4. 返回值 + +返回0为成功。 + +##### ubs_hcom_context_get_channel + +1. 函数定义 + +通过ctx获得NetChannel。 + +2. 实现方法 + +int ubs_hcom_context_get_channel(ubs_hcom_service_context context, ubs_hcom_channel \*channel); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|--------------------------|----------|------------------------------| +| context | ubs_hcom_service_context | 入参 | 回调函数的参数ctx。 | +| channel | ubs_hcom_channel | 出参 | 返回得到的ubs_hcom_channel。 | + +4. 返回值 + +返回0为成功。 + +##### ubs_hcom_context_get_type + +1. 函数定义 + +通过ctx获得操作类型。 + +2. 实现方法 + +int ubs_hcom_context_get_type(ubs_hcom_service_context context, ubs_hcom_service_context_type \*type); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| context | ubs_hcom_service_context | 入参 | 回调函数的参数ctx。 | +| type | [ubs_hcom_service_context_type](#ZH-CN_TOPIC_0000002498615725) | 出参 | 返回得到的NetChannel。 | + +4. 返回值 + +返回0为成功。 + +##### ubs_hcom_context_get_result + +1. 函数定义 + +通过ctx获得操作结果。 + +2. 实现方法 + +int ubs_hcom_context_get_result(ubs_hcom_service_context context, int \*result) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|--------------------------|----------|---------------------| +| context | ubs_hcom_service_context | 入参 | 回调函数的参数ctx。 | +| result | int | 出参 | 操作结果。 | + +4. 返回值 + +返回0为成功。 + +##### ubs_hcom_context_get_rspctx + +1. 函数定义 + +通过ctx获得回复消息所需的rspCtx。 + +2. 实现方法 + +int ubs_hcom_context_get_rspctx(ubs_hcom_service_context context, ubs_hcom_channel_reply_context \*rspCtx); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|--------------------------------|----------|------------------------| +| context | ubs_hcom_service_context | 入参 | 回调函数的参数ctx。 | +| rspCtx | ubs_hcom_channel_reply_context | 出参 | 回复消息接口所需参数。 | + +4. 返回值 + +返回0为成功。 + +##### ubs_hcom_context_get_opcode + +1. 函数定义 + +通过ctx获得OpCode。 + +2. 实现方法 + +uint16_t ubs_hcom_context_get_opcode(ubs_hcom_service_context context); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|--------------------------|----------|---------------------| +| context | ubs_hcom_service_context | 入参 | 回调函数的参数ctx。 | +| 返回值 | uint16_t | 出参 | OpCode。 | + +4. 返回值 + +返回0为成功。 + +##### ubs_hcom_context_get_data + +1. 函数定义 + +通过ctx获得接收到的消息。 + +2. 实现方法 + +void \*ubs_hcom_context_get_data(ubs_hcom_service_context context) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|--------------------------|----------|---------------------| +| context | ubs_hcom_service_context | 入参 | 回调函数的参数ctx。 | + +4. 返回值 + +返回接收到的消息。 + +##### ubs_hcom_context_get_datalen + +1. 函数定义 + +通过ctx获得接收到的消息长度。 + +2. 实现方法 + +uint32_t ubs_hcom_context_get_datalen(ubs_hcom_service_context context) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|--------------------------|----------|---------------------| +| context | ubs_hcom_service_context | 入参 | 回调函数的参数ctx。 | + +4. 返回值 + +返回接收到的消息长度。 + +#### 传输层API + +##### ubs_hcom_driver_create + +1. 函数定义 + +根据类型和名字创建一个传输层的HcomDriver对象。 + +2. 实现方法 + +int ubs_hcom_driver_create(ubs_hcom_driver_type t, const char \*name, uint8_t startOobSvr, ubs_hcom_driver \*driver) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| t | [ubs_hcom_driver_type](#ubs_hcom_driver_type) | 入参 | ubs_hcom_driver协议类型,取值范围详见[ubs_hcom_driver_type](#ubs_hcom_driver_type)。 | +| name | char \* | 入参 | ubs_hcom_driver的名字。长度范围\[1, 100\],只能包含数字、字母、‘\_’和‘-’。 | +| startOobSvr | uint8_t | 入参 | Server端设置为0,Client端设置为1。 | +| driver | ubs_hcom_driver | 出参 | 创建的ubs_hcom_driver实例。 | + +4. 返回值 + +返回值为0则表示创建HcomDriver成功。 + +##### ubs_hcom_driver_set_ipport + +1. 函数定义 + +给HcomDriver对象设置OOB的IP和Port。 + +2. 实现方法 + +void ubs_hcom_driver_set_ipport(ubs_hcom_driver driver, const char \*ip, uint16_t port) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| driver | ubs_hcom_driver | 入参 | ubs_hcom_driver对象。 | +| ip | const char \* | 入参 | IP。该参数内部有系统函数对IP有效性进行校验。 | +| port | uint16_t | 入参 | 端口。范围值\[1024, 65535\]。 | + +4. 返回值 + +无 + +##### ubs_hcom_driver_get_ipport + +1. 函数定义 + +得到HcomDriver对象的OOB的IP和Port。 + +2. 实现方法 + +bool ubs_hcom_driver_get_ipport(ubs_hcom_driver driver, char \*\*\*ipArray, uint16_t \*\*portArray, int \*length) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|-----------------|----------|-----------------------| +| driver | ubs_hcom_driver | 入参 | ubs_hcom_driver对象。 | +| ipArray | char \*\* | 出参 | OOB的IP数组。 | +| portArray | uint16_t \* | 出参 | OOB的端口数组。 | +| length | int | 出参 | 数组长度。 | + +4. 返回值 + +返回值为true则表示成功。 + +![](media/image8.png) + +出参ipArray和portArray为内部分配的内存,用户需要在使用完成之后自行释放此内存。 + +##### ubs_hcom_driver_set_udsname + +1. 函数定义 + +给HcomDriver对象设置的OOB type为UDS时的name。 + +2. 实现方法 + +void ubs_hcom_driver_set_udsname(ubs_hcom_driver driver, const char \*name) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-----------------|----------|-------------------------------------| +| driver | ubs_hcom_driver | 入参 | ubs_hcom_driver对象。 | +| name | const char \* | 入参 | 需要设置的name。长度范围是(0, 96)。 | + +4. 返回值 + +无 + +##### ubs_hcom_driver_add_uds_opt + +1. 函数定义 + +给HcomDriver对象设置的OOB type为UDS时的name和一些参数。 + +2. 实现方法 + +void ubs_hcom_driver_add_uds_opt(ubs_hcom_driver driver, ubs_hcom_driver_uds_listen_opts option) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| driver | ubs_hcom_driver | 入参 | ubs_hcom_driver对象。 | +| option | [ubs_hcom_driver_listen_opts](#ubs_hcom_driver_listen_opts) | 入参 | 需要设置的UDS参数。 | + +4. 返回值 + +无 + +##### ubs_hcom_driver_add_oob_opt + +1. 函数定义 + +设置HcomDriver对象的OOB的IP和Port。 + +2. 实现方法 + +void ubs_hcom_driver_add_oob_opt(ubs_hcom_driver driver, ubs_hcom_driver_listen_opts options) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| driver | ubs_hcom_driver | 入参 | ubs_hcom_driver对象。 | +| options | [ubs_hcom_driver_listen_opts](#ubs_hcom_driver_listen_opts) | 出参 | 需要设置的IP和Port参数。 | + +4. 返回值 + +无 + +##### ubs_hcom_driver_initizalize + +1. 函数定义 + +根据类型和名字创建一个传输层的HcomDriver对象。 + +2. 实现方法 + +int ubs_hcom_driver_initizalize(ubs_hcom_driver driver, ubs_hcom_driver_opts options) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| driver | ubs_hcom_driver | 入参 | 需要初始化的ubs_hcom_driver。 | +| options | [ubs_hcom_driver_opts](#ubs_hcom_driver_opts) | 入参 | 根据Option,初始化ubs_hcom_driver。 | + +4. 返回值 + +返回值为0则表示初始化HcomDriver成功。 + +##### ubs_hcom_driver_start + +1. 函数描述 + +根据类型和名字创建一个传输层的HcomDriver对象。 + +2. 函数定义 + +int ubs_hcom_driver_start(ubs_hcom_driver driver) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-----------------|----------|-----------------------------| +| driver | ubs_hcom_driver | 入参 | 需要开启的ubs_hcom_driver。 | + +4. 返回值 + +返回值为0则表示开启HcomDriver成功。 + +##### ubs_hcom_driver_connect + +![](media/image9.png) + +如果用户实现中需要主动销毁EP,要先调用ubs_hcom_ep_close接口;如果需要减少EP的引用计数,可调用ubs_hcom_ep_destroy函数。 + +1. 函数定义 + +建立与远程服务器的连接,并返回连接创建的EP。 + +2. 实现方法 + +- int ubs_hcom_driver_connect(ubs_hcom_driver driver, const char \*payloadData, ubs_hcom_endpoint \*ep, uint32_t flags) + +- int ubs_hcom_driver_connect_with_grpno(ubs_hcom_driver driver, const char \*payloadData, ubs_hcom_endpoint \*ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo) + +- int ubs_hcom_driver_connect_to_ipport(ubs_hcom_driver driver, const char \*serverIp, uint16_t serverPort, const char \*payloadData, ubs_hcom_endpoint \*ep, uint32_t flags) + +- int ubs_hcom_driver_connect_to_ipport_with_grpno(ubs_hcom_driver driver, const char \*serverIp, uint16_t serverPort, const char \*payloadData, ubs_hcom_endpoint \*ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo) + +- int ubs_hcom_driver_connect_to_ipport_with_ctx(ubs_hcom_driver driver, const char \*serverIp, uint16_t serverPort, const char \*payloadData, ubs_hcom_endpoint \*ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) + + 1. 参数说明 + + 1. 参数说明 + +[TABLE] + +2. 返回值 + +返回值为0则表示连接成功。 + +##### ubs_hcom_driver_stop + +1. 函数定义 + +停止服务和内部启动的线程。 + +2. 实现方法 + +void ubs_hcom_driver_stop(ubs_hcom_driver driver) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-----------------|----------|-----------------------------| +| driver | ubs_hcom_driver | 入参 | 需要停止的ubs_hcom_driver。 | + +4. 返回值 + +无 + +##### ubs_hcom_driver_uninitialize + +1. 函数定义 + +清理服务创建时的相关资源。 + +2. 实现方法 + +void ubs_hcom_driver_uninitialize(ubs_hcom_driver driver) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-----------------|----------|-----------------------------| +| driver | ubs_hcom_driver | 入参 | 需要清理的ubs_hcom_driver。 | + +4. 返回值 + +无 + +##### ubs_hcom_driver_destroy + +1. 函数定义 + +销毁HcomDriver。 + +2. 实现方法 + +void ubs_hcom_driver_destroy(ubs_hcom_driver driver) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-----------------|----------|-----------------------------| +| driver | ubs_hcom_driver | 入参 | 需要销毁的ubs_hcom_driver。 | + +4. 返回值 + +无 + +##### ubs_hcom_driver_register_ep_handler + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁driver及相关的资源。 + +1. 函数定义 + +注册EP的回调函数,以处理EP建链和断连事件。并把回调函数句柄放入全局句柄管理器。 + +2. 实现方法 + +uintptr_t ubs_hcom_driver_register_ep_handler(ubs_hcom_driver driver, ubs_hcom_ep_handler_type t, ubs_hcom_ep_handler h, uint64_t usrCtx) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| driver | ubs_hcom_driver | 入参 | 需要注册回调函数的ubs_hcom_driver。 | +| t | [Net_EPHandlerType](#ZH-CN_TOPIC_0000002498495477) | 入参 | 句柄的类型。 | +| h | ubs_hcom_ep_handler | 入参 | 回调函数的句柄。 | +| usrCtx | uint64 | 入参 | 用户上下文。 | + +4. 返回值 + +uintptr_t类型,返回内部句柄地址。 + +##### ubs_hcom_driver_register_op_handler + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁driver及相关的资源。 + +1. 函数定义 + +注册回调函数,以处理通道双边发送完成、单边发送完成、双边收消息事件。并把回调函数句柄放入全局句柄管理器。 + +2. 实现方法 + +uintptr_t ubs_hcom_driver_register_op_handler(ubs_hcom_driver driver, ubs_hcom_op_handler_type t, ubs_hcom_request_handler h, uint64_t usrCtx) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| driver | ubs_hcom_driver | 入参 | 需要注册回调函数的ubs_hcom_driver。 | +| t | [ubs_hcom_op_handler_type](#ubs_hcom_op_handler_type) | 入参 | 句柄的类型。 | +| h | Net_RequestHandler | 入参 | 回调函数的句柄。 | +| usrCtx | uint64_t | 入参 | 用户上下文。 | + +4. 返回值 + +uintptr_t类型,返回内部句柄地址。 + +##### ubs_hcom_driver_register_idle_handler + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁driver及相关的资源。 + +1. 函数定义 + +给HcomDriver对象设置EP闲时回调函数。并把回调函数句柄放入全局句柄管理器。 + +2. 实现方法 + +uintptr_t ubs_hcom_driver_register_idle_handler(ubs_hcom_driver driver, ubs_hcom_idle_handler h, uint64_t usrCtx) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-----------------------|----------|-----------------------| +| driver | ubs_hcom_driver | 入参 | ubs_hcom_driver对象。 | +| h | ubs_hcom_idle_handler | 入参 | 闲时回调函数。 | +| usrCtx | uint64_t | 入参 | 带到回调函数中的ctx。 | + +4. 返回值 + +内部回调函数地址。 + +![](media/image8.png) + +数据类型解释如下: + +typedef void (\*ubs_hcom_idle_handler)(uint8_t wkrGrpIdx, uint16_t idxInGrp, uint64_t usrCtx) + +##### ubs_hcom_driver_register_secinfo_provider + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁driver及相关的资源。 + +1. 函数定义 + +给HcomDriver对象设置EP安全信息提供函数。 + +2. 实现方法 + +uintptr_t ubs_hcom_driver_register_secinfo_provider(ubs_hcom_driver driver, ubs_hcom_secinfo_provider provider) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----------|---------------------------|----------|-----------------------| +| driver | ubs_hcom_driver | 入参 | ubs_hcom_driver对象。 | +| provider | ubs_hcom_secinfo_provider | 入参 | 安全信息提供函数。 | + +4. 返回值 + +内部回调函数地址。 + +![](media/image8.png) + +数据类型解释如下: + +typedef int (\*ubs_hcom_secinfo_provider)(uint64_t ctx, int64_t \*flag, ubs_hcom_driver_sec_type \*type, char \*\*output, uint32_t \*outLen, int \*needAutoFree) + +##### ubs_hcom_driver_register_secinfo_validator + +![](media/image9.png) + +用户实现的回调函数,内部不能销毁driver及相关的资源。 + +1. 函数定义 + +给HcomDriver对象设置EP安全信息校验函数。 + +2. 实现方法 + +uintptr_t ubs_hcom_driver_register_secinfo_validator(ubs_hcom_driver driver, ubs_hcom_secinfo_validator validator) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|----------------------------|----------|-----------------------| +| driver | ubs_hcom_driver | 入参 | ubs_hcom_driver对象。 | +| validator | ubs_hcom_secinfo_validator | 入参 | 安全信息校验函数。 | + +4. 返回值 + +内部回调函数地址。 + +![](media/image8.png) + +数据类型解释如下: + +typedef int (\*ubs_hcom_secinfo_validator)(uint64_t ctx, int64_t flag, const char \*input, uint32_t inputLen) + +##### ubs_hcom_driver_unregister_ep_handler + +1. 函数定义 + +从全局回调函数句柄管理器中去掉某一个回调函数句柄。 + +2. 实现方法 + +void ubs_hcom_driver_unregister_ep_handler(ubs_hcom_ep_handler_type t, uintptr_t handle) + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +无 + +##### ubs_hcom_driver_unregister_op_handler + +1. 函数定义 + +从全局回调函数句柄管理器中去掉某一个回调函数句柄。 + +2. 实现方法 + +void ubs_hcom_driver_unregister_op_handler(ubs_hcom_op_handler_type t, uintptr_t handle) + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +无 + +##### ubs_hcom_driver_unregister_idle_handler + +1. 函数定义 + +从全局回调函数句柄管理器中去掉某一个回调函数句柄。 + +2. 实现方法 + +void ubs_hcom_driver_unregister_idle_handler(uintptr_t handle) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-----------|----------|----------------| +| handle | uintptr_t | 入参 | 回调函数句柄。 | + +4. 返回值 + +无 + +##### ubs_hcom_driver_create_memory_region + +1. 函数定义 + +通过HcomDriver对象来创建一个Memory region。 + +2. 实现方法 + +int ubs_hcom_driver_create_memory_region(ubs_hcom_driver driver, uint64_t size, ubs_hcom_memory_region \*mr) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|------------------------|----------|--------------------------------------| +| driver | ubs_hcom_driver | 入参 | ubs_hcom_driver对象。 | +| size | uint64_t | 入参 | MR的大小。范围为(0, 107374182400\]。 | +| mr | ubs_hcom_memory_region | 出参 | 创建的MR。 | + +4. 返回值 + +返回0为成功。 + +##### ubs_hcom_driver_create_assign_memory_region + +1. 函数定义 + +通过HcomDriver对象来创建一个Memory region。 + +2. 实现方法 + +int ubs_hcom_driver_create_assign_memory_region(ubs_hcom_driver driver, uintptr_t address, uint64_t size, ubs_hcom_memory_region \*mr) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|------------------------|----------|-----------------------| +| driver | ubs_hcom_driver | 入参 | ubs_hcom_driver对象。 | +| address | uintptr_t | 入参 | 内存地址。 | +| size | uint64_t | 入参 | 内存的大小。 | +| mr | ubs_hcom_memory_region | 出参 | 创建的MR。 | + +4. 返回值 + +返回0为成功。 + +##### ubs_hcom_driver_destroy_memory_region + +1. 函数定义 + +通过HcomDriver对象来销毁一个Memory region。 + +2. 实现方法 + +void ubs_hcom_driver_destroy_memory_region(ubs_hcom_driver driver, ubs_hcom_memory_region mr) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|------------------------|----------|-----------------------| +| driver | ubs_hcom_driver | 入参 | ubs_hcom_driver对象。 | +| mr | ubs_hcom_memory_region | 入参 | 需要销毁的MR。 | + +4. 返回值 + +无 + +##### ubs_hcom_driver_get_memory_region_info + +1. 函数定义 + +获取一个MR的信息。 + +2. 实现方法 + +int ubs_hcom_driver_get_memory_region_info(ubs_hcom_memory_region mr, ubs_hcom_memory_region_info \*info) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| mr | ubs_hcom_memory_region | 入参 | 创建的MR。 | +| info | [ubs_hcom_memory_region_info](#ubs_hcom_memory_region_info) | 出参 | MR相关的信息。 | + +4. 返回值 + +返回0为成功。 + +##### ubs_hcom_ep_set_context + +1. 函数定义 + +给EP设置本端回调函数可使用的ctx。 + +2. 实现方法 + +void ubs_hcom_ep_set_context(ubs_hcom_endpoint ep, uint64_t ctx) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------|----------|-----------------| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | +| ctx | uint64_t | 入参 | 设置的context。 | + +4. 返回值 + +无 + +##### ubs_hcom_ep_get_context + +1. 函数定义 + +获得EP的ctx。 + +2. 实现方法 + +uint64_t ubs_hcom_ep_get_context(ubs_hcom_endpoint ep) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------|----------|------------| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | + +4. 返回值 + +EP的context。 + +##### ubs_hcom_ep_get_worker_idx + +1. 函数定义 + +获取EP所在的worker group的worker索引。 + +2. 实现方法 + +uint16_t ubs_hcom_ep_get_worker_idx(ubs_hcom_endpoint ep) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------|----------|------------| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | + +4. 返回值 + +返回worker group的worker索引。 + +##### ubs_hcom_ep_get_workergroup_idx + +1. 函数定义 + +获取EP所在的worker group索引。 + +2. 实现方法 + +uint8_t ubs_hcom_ep_get_workergroup_idx(ubs_hcom_endpoint ep) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------|----------|------------| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | + +4. 返回值 + +返回worker group索引。 + +##### ubs_hcom_ep_get_listen_port + +1. 函数定义 + +获取EP建链时所监听的端口号。 + +2. 实现方法 + +uint32_t ubs_hcom_ep_get_listen_port(ubs_hcom_endpoint ep) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------|----------|------------| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | + +4. 返回值 + +返回端口号。 + +##### ubs_hcom_ep_version + +1. 函数定义 + +获取EP的版本。 + +2. 实现方法 + +uint8_t ubs_hcom_ep_version(ubs_hcom_endpoint ep) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------|----------|------------| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | + +4. 返回值 + +返回EP的版本。 + +##### ubs_hcom_ep_set_timeout + +1. 函数定义 + +设置EP的超时时间。 + +2. 实现方法 + +void ubs_hcom_ep_set_timeout(ubs_hcom_endpoint ep, int32_t timeout) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | +| timeout | int32_t | 入参 | 超时时间,单位是秒。0为立刻超时,负数为永不超时。 | + +4. 返回值 + +无 + +##### ubs_hcom_ep_post_send + +1. 函数定义 + +向对端发送一个带有op信息的请求。 + +2. 实现方法 + +int ubs_hcom_ep_post_send(ubs_hcom_endpoint ep, uint16_t opcode, ubs_hcom_send_request \*req) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | 建链创建好的EP对象。 | +| opcode | uint16_t | 入参 | 操作码,取值范围\[0, 1023\]。 | +| \*req | [ubs_hcom_send_request](#ubs_hcom_send_request) | 入参 | 发送请求信息,使用本地内存来存储数据,数据会被复制,调用后可释放本地内存。 | + +4. 返回值 + +返回值为0则表示发送成功。 + +##### ubs_hcom_ep_post_send_with_opinfo + +1. 函数定义 + +使用EP发送PostSend消息,带有opInfo。 + +2. 实现方法 + +int ubs_hcom_ep_post_send_with_opinfo(ubs_hcom_endpoint ep, uint16_t opcode, ubs_hcom_send_request \*req, ubs_hcom_opinfo \*opInfo) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | +| opcode | uint16_t | 入参 | 操作编号。 | +| req | [ubs_hcom_send_request](#ubs_hcom_send_request) | 入参 | 需要发送的消息。 | +| opInfo | [ubs_hcom_opinfo](#ubs_hcom_opinfo) | 入参 | 操作信息。 | + +4. 返回值 + +返回0为成功。 + +##### ubs_hcom_ep_post_send_with_seqno + +1. 函数定义 + +使用EP发送PostSend消息,带有seqNo。 + +2. 实现方法 + +int ubs_hcom_ep_post_send_with_seqno(ubs_hcom_endpoint ep, uint16_t opcode, ubs_hcom_send_request \*req, uint32_t replySeqNo) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | +| opcode | uint16_t | 入参 | 操作编号。 | +| req | [ubs_hcom_send_request](#ubs_hcom_send_request) | 入参 | 需要发送的消息。 | +| replySeqNo | uint32_t | 入参 | 序列号。 | + +4. 返回值 + +返回0为成功。 + +##### ubs_hcom_ep_post_read + +1. 函数定义 + +向对端发送一个读请求。 + +2. 实现方法 + +int ubs_hcom_ep_post_read(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request \*req) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | 建链创建好的EP对象。 | +| \*req | [ubs_hcom_readwrite_request](#ubs_hcom_readwrite_request) | 入参 | 读请求信息。 | + +4. 返回值 + +返回值为0则表示读成功。 + +##### ubs_hcom_ep_post_write + +1. 函数定义 + +向对端发送一个写请求。 + +2. 实现方法 + +int ubs_hcom_ep_post_write(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request \*req) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | 建链创建好的EP对象。 | +| \*req | [ubs_hcom_readwrite_request](#ubs_hcom_readwrite_request) | 入参 | 写请求信息。 | + +4. 返回值 + +返回值为0则表示写成功。 + +##### ubs_hcom_ep_wait_completion + +1. 函数定义 + +等待send,read,write消息完成,只有在EP是NET_EP_SELF_POLLING时生效。 + +2. 实现方法 + +int ubs_hcom_ep_wait_completion(ubs_hcom_endpoint ep, int32_t timeout) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | +| timeout | int32_t | 入参 | 超时时间,单位是秒。0为立刻超时,负数为永不超时。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_ep_receive + +1. 函数定义 + +接收对端发送过来的消息。 + +2. 实现方法 + +int ubs_hcom_ep_receive(ubs_hcom_endpoint ep, int32_t timeout, ubs_hcom_response_context \*\*ctx) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | +| timeout | int32_t | 入参 | 超时时间,单位是秒。0为立刻超时,负数为永不超时。 | +| ctx | [ubs_hcom_response_context](#ubs_hcom_response_context) | 出参 | 接收到的消息。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_ep_refer + +1. 函数定义 + +给EP增加一次引用。 + +2. 实现方法 + +void ubs_hcom_ep_refer(ubs_hcom_endpoint ep) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------|----------|------------| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | + +4. 返回值 + +无 + +##### ubs_hcom_ep_close + +![](media/image9.png) + +如果用户实现中需要主动销毁EP,要先调用ubs_hcom_ep_close接口;如果需要减少EP的引用计数,可调用ubs_hcom_ep_destroy函数。 + +1. 函数定义 + +关闭EP。 + +2. 实现方法 + +void ubs_hcom_ep_close(ubs_hcom_endpoint ep) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------|----------|------------| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | + +4. 返回值 + +无 + +##### ubs_hcom_ep_destroy + +![](media/image9.png) + +如果用户实现中需要主动销毁EP,要先调用ubs_hcom_ep_close接口;如果需要减少EP的引用计数,可调用ubs_hcom_ep_destroy函数。 + +1. 函数定义 + +销毁EP。 + +2. 实现方法 + +void ubs_hcom_ep_destroy(ubs_hcom_endpoint ep) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------|----------|------------| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | + +4. 返回值 + +无 + +##### ubs_hcom_err_str + +1. 函数定义 + +得到errorCode的解释。 + +2. 实现方法 + +const char \*ubs_hcom_err_str(int16_t errCode) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|----------|----------|--------------------| +| errCode | int16_t | 入参 | 需要翻译的错误码。 | + +4. 返回值 + +返回错误码翻译。 + +##### ubs_hcom_mem_allocator_create + +1. 函数定义 + +创建一个内存分配器。 + +2. 实现方法 + +int ubs_hcom_mem_allocator_create(ubs_hcom_memory_allocator_type t, ubs_hcom_memory_allocator_options \*options, ubs_hcom_memory_allocator \*allocator) + +3. 参数说明\` + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_mem_allocator_destroy + +1. 函数定义 + +销毁一个内存分配器。 + +2. 实现方法 + +int ubs_hcom_mem_allocator_destroy(ubs_hcom_memory_allocator allocator) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|---------------------------|----------|------------------------| +| allocator | ubs_hcom_memory_allocator | 入参 | 需要销毁的内存分配器。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_mem_allocator_set_mr_key + +1. 函数定义 + +给分配器设置memory region key。 + +2. 实现方法 + +int ubs_hcom_mem_allocator_set_mr_key(ubs_hcom_memory_allocator allocator, uint32_t mrKey) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| allocator | ubs_hcom_memory_allocator | 入参 | 需要被设置的内存分配器。 | +| mrKey | uint32_t | 入参 | memory region key。范围值(0, UINT32_MAX\]。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_mem_allocator_get_offset + +1. 函数定义 + +得到地址在分配器内存的偏移值。 + +2. 实现方法 + +int ubs_hcom_mem_allocator_get_offset(ubs_hcom_memory_allocator allocator, uintptr_t address, uintptr_t \*offset) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|---------------------------|----------|--------------| +| allocator | ubs_hcom_memory_allocator | 入参 | 内存分配器。 | +| address | uintptr_t | 入参 | 内存地址。 | +| offset | uintptr_t | 出参 | 偏移值。 | + +4. 返回值 + +返回0为成功。 + +##### ubs_hcom_mem_allocator_get_free_size + +1. 函数定义 + +得到分配器剩余的内存大小。 + +2. 实现方法 + +int ubs_hcom_mem_allocator_get_free_size(ubs_hcom_memory_allocator allocator, uintptr_t \*size) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|---------------------------|----------|------------------| +| allocator | ubs_hcom_memory_allocator | 入参 | 内存分配器。 | +| size | uintptr_t | 出参 | 剩余的内存大小。 | + +4. 返回值 + +返回0为成功。 + +##### ubs_hcom_mem_allocator_allocate + +1. 函数定义 + +从内存分配器中分配出指定大小的内存。 + +2. 实现方法 + +int ubs_hcom_mem_allocator_allocate(ubs_hcom_memory_allocator allocator, uint64_t size, uintptr_t \*address, uint32_t \*key) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|---------------------------|----------|------------------| +| allocator | ubs_hcom_memory_allocator | 入参 | 内存分配器。 | +| size | uint64_t | 入参 | 分配的内存大小。 | +| address | uintptr_t | 出参 | 分配的内存地址。 | +| key | uint32_t | 出参 | MR Key。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_mem_allocator_free + +1. 函数定义 + +将从内存分配器中分配的内存释放给分配器。 + +2. 实现方法 + +int ubs_hcom_mem_allocator_free(ubs_hcom_memory_allocator allocator, uintptr_t address) + +![](media/image8.png) + +使用时防止相同address多次调用该函数。 + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|---------------------------|----------|----------------------| +| allocator | ubs_hcom_memory_allocator | 入参 | 内存分配器。 | +| address | uintptr_t | 入参 | 需要释放的内存地址。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_set_log_handler + +1. 函数定义 + +设置外部日志。 + +2. 实现方法 + +void ubs_hcom_set_log_handler(ubs_hcom_log_handler h) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------------------|----------|----------------| +| h | ubs_hcom_log_handler | 入参 | 外部日志函数。 | + +4. 返回值 + +无 + +![](media/image8.png) + +数据类型解释如下: + +typedef void (\*ubs_hcom_log_handler)(int level, const char \*msg) + +##### ubs_hcom_check_local_supporr + +1. 函数定义 + +校验本机是否支持所提供协议,若为RDMA协议且支持的情况下,会返回设备信息。 + +2. 实现方法 + +int ubs_hcom_check_local_supporr(ubs_hcom_driver_type t, ubs_hcom_device_info \*info) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| t | ubs_hcom_driver_type | 入参 | 需要校验的协议。 | +| info | [ubs_hcom_device_info](#ubs_hcom_device_info) | 出参 | RDMA设备信息,最大的SGL的iov count。 | + +4. 返回值 + +返回值为1则表示支持此协议。 + +##### ubs_hcom_get_remote_uds_info + +1. 函数定义 + +仅支持服务端且OOB type为UDS时,查询此EP的对端UDS ID信息。 + +2. 实现方法 + +int ubs_hcom_get_remote_uds_info(ubs_hcom_endpoint ep, ubs_hcom_uds_id_info \*idInfo) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | +| idInfo | [ubs_hcom_uds_id_info](#ubs_hcom_uds_id_info) | 出参 | 对端UDS ID信息。 | + +4. 返回值 + +返回值为0则表示成功。 + +## 高级API参考 + +### C++API + +#### 服务层 + +##### UBSHcomService::AddWorkerGroup + +1. 函数定义 + +向Service中增加内存池。 + +2. 实现说明 + +void UBSHcomService::AddWorkerGroup(uint16_t workerGroupId, uint32_t threadCount,const std::pair\ &cpuIdsRange, int8_t priority = 0, uint16_t multirailIdx = 0); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| workerGroupId | uint16_t | 入参 | 增加的workerGroup的编号ID,用于标识不同的workerGroup。 | +| threadCount | uint32_t | 入参 | 该workerGroup中的线程数。 | +| cpuIdsRange | const std::pair\ | 入参 | 该workerGroup绑定的cpu范围,如{0, 10}表示绑定在0到10号CPU上。 | +| priority | int8_t | 入参 | 线程优先级,同线程nice值,范围\[-20, 19\],取值越大优先级越低。 | +| multirailIdx | uint16_t | 入参 | 该workerGroup绑定的MultiRail索引序号。 | + +4. 返回值 + +无 + +##### UBSHcomService::AddListener + +1. 函数定义 + +向Service中增加listener。 + +2. 实现说明 + +void UBSHcomService::AddListener(const std::string &url, uint16_t workerCount = UINT16_MAX); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| url | const std::string | 入参 | 增加的listener监听的url,同bind。 | +| workerCount | uint16_t | 入参 | 从workerGroup中选取workerCount个线程,与该url建立的连接请求通过这workerCount个线程去处理。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetConnectLBPolicy + +1. 函数定义 + +设置建链负载均衡策略 + +2. 实现说明 + +void UBSHcomService::SetConnectLBPolicy(UBSHcomServiceLBPolicy lbPolicy) + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +无 + +##### UBSHcomService::SetUBSHcomTlsOptions + +1. 函数定义 + +设置TLS可选配置项。 + +2. 实现说明 + +void UBSHcomService::SetUBSHcomTlsOptions(const UBSHcomTlsOptions &opt); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-----------------------------------------|----------|-----------------| +| opt | [UBSHcomTlsOptions](#ubshcomtlsoptions) | 入参 | TLS可选配置项。 | + +![](media/image8.png) + +使用UB自举建链时,暂不支持安全认证和安全加密。 + +4. 返回值 + +无 + +##### UBSHcomService::SetConnSecureOpt + +1. 函数定义 + +链接安全配置项 + +2. 实现说明 + +void UBSHcomService::SetConnSecureOpt(const UBSHcomConnSecureOptions &opt); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------------------------------|----------|----------------------| +| opt | const UBSHcomConnSecureOptions & | 入参 | 链接安全可选配置项。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetTcpUserTimeOutSec + +1. 函数定义 + +设置TCP套接字选项TCP_USER_TIMEOUT。 + +2. 实现说明 + +void UBSHcomService::SetTcpUserTimeOutSec(uint16_t timeOutSec); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| timeOutSec | uint16_t | 入参 | 对应TCP_USER_TIMEOUT套接字选项,范围\[0, 1024\],0表示永不超时。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetTcpSendZCopy + +1. 函数定义 + +设置TCP发送是否要做内存拷贝。 + +2. 实现说明 + +void UBSHcomService::SetTcpSendZCopy(bool tcpSendZCopy); + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +无 + +##### UBSHcomService::SetDeviceIpMask + +1. 函数定义 + +设置设备ipMask,用于rdma/ub。 + +2. 实现说明 + +void UBSHcomService::SetDeviceIpMask(const std::vector\ &ipMasks); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|----------------------------------|----------|------------------------| +| ipMasks | const std::vector\ | 入参 | 用于过滤的ipMask集合。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetDeviceIpGroups + +1. 函数定义 + +设置设备ipGroup。 + +2. 实现说明 + +void UBSHcomService::SetDeviceIpGroups(const std::vector\ &ipGroups); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----------|----------------------------------|----------|----------------------| +| ipGroups | const std::vector\ | 入参 | 设备的ipGroups集合。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetCompletionQueueDepth + +1. 函数定义 + +设置cq队列深度。 + +2. 实现说明 + +void UBSHcomService::SetCompletionQueueDepth(uint16_t depth); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|----------------| +| depth | uint16_t | 入参 | 完成队列深度。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetSendQueueSize + +1. 函数定义 + +设置发送队列深度。 + +2. 实现说明 + +void UBSHcomService::SetSendQueueSize(uint32_t sqSize); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|----------------| +| sqSize | uint32_t | 入参 | 发送队列深度。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetRecvQueueSize + +1. 函数定义 + +设置接收队列深度。 + +2. 实现说明 + +void UBSHcomService::SetRecvQueueSize(uint32_t rqSize); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|----------------| +| rqSize | uint32_t | 入参 | 接收队列深度。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetPollingBatchSize + +1. 函数定义 + +设置批量polling的大小。 + +2. 实现说明 + +void UBSHcomService::SetPollingBatchSize(uint16_t pollSize); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----------|----------|----------|---------------------| +| pollSize | uint16_t | 入参 | 批量polling的大小。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetEventPollingTimeOutUs + +1. 函数定义 + +设置event polling的超时时间。 + +2. 实现说明 + +void UBSHcomService::SetEventPollingTimeOutUs(uint16_t pollTimeout); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-------------|----------|----------|-------------------------| +| pollTimeout | uint16_t | 入参 | event polling超时时间。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetTimeOutDetectionThreadNum + +1. 函数定义 + +设置周期任务处理线程数。 + +2. 实现说明 + +void UBSHcomService::SetTimeOutDetectionThreadNum(uint32_t threadNum); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|----------|----------|----------------------| +| threadNum | uint32_t | 入参 | 周期任务处理线程数。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetMaxConnectionCount + +1. 函数定义 + +设置最大链接数。 + +2. 实现说明 + +void UBSHcomService::SetMaxConnectionCount(uint32_t maxConnCount); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------------|----------|----------|--------------| +| maxConnCount | uint32_t | 入参 | 最大链接数。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetUBSHcomHeartBeatOptions + +1. 函数定义 + +设置心跳参数配置项。 + +2. 实现说明 + +void UBSHcomService::SetUBSHcomHeartBeatOptions(const UBSHcomHeartBeatOptions &opt); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| opt | [UBSHcomHeartBeatOptions](#ubshcomheartbeatoptions) | 入参 | 心跳可选参数项。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetUBSHcomMultiRailOptions + +1. 函数定义 + +设置多路径参数配置项。 + +2. 实现说明 + +void UBSHcomService::SetUBSHcomMultiRailOptions(const UBSHcomMultiRailOptions &opt); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| opt | [UBSHcomMultiRailOptions](#ubshcommultirailoptions) | 入参 | 多路径参数配置项。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetQueuePrePostSize + +1. 函数定义 + +设置提前下发wr的数量,不设置的话默认64。 + +2. 实现说明 + +void UBSHcomService::SetQueuePrePostSize(uint32_t prePostSize); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-------------|----------|----------|--------------------| +| prePostSize | uint32_t | 入参 | 预先下发的wr数量。 | + +4. 返回值 + +无 + +##### UBSHcomService::SetMaxSendRecvDataCount + +1. 函数定义 + +设置发送数据块最大数量,不设置的话默认8192。 + +2. 实现说明 + +void SetMaxSendRecvDataCount(uint32_t maxSendRecvDataCount); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----------------------|----------|----------|----------------------| +| maxSendRecvDataCount | uint32_t | 入参 | 发送数据块最大数量。 | + +4. 返回值 + +无 + +##### UBSHcomRegMemoryRegion::GetMemoryKey + +1. 函数定义 + +获得所有内存池的keys。 + +2. 实现说明 + +void UBSHcomRegMemoryRegion::GetMemoryKey(UBSHcomMemoryKey &mrKey); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|------------------|----------|----------------| +| mrKey | UBSHcomMemoryKey | 出参 | 内存池的keys。 | + +4. 返回值 + +无 + +##### UBSHcomRegMemoryRegion::GetAddress + +1. 函数定义 + +获得首个内存池地址。 + +2. 实现说明 + +uintptr_t UBSHcomRegMemoryRegion::GetAddress(); + +3. 参数说明 + +无 + +4. 返回值 + +地址值。 + +##### UBSHcomRegMemoryRegion::GetSize + +1. 函数定义 + +获得首个内存池长度。 + +2. 实现说明 + +uint64_t UBSHcomRegMemoryRegion::GetSize(); + +3. 参数说明 + +无 + +4. 返回值 + +内存池的长度。 + +##### UBSHcomRegMemoryRegion::GetHcomMrs + +1. 函数定义 + +获得内存池组。 + +2. 实现说明 + +std::vector\& UBSHcomRegMemoryRegion::GetHcomMrs(); + +3. 参数说明 + +无 + +4. 返回值 + +返回std::vector\类型的数组。 + +##### UBSHcomNewCallback + +1. 函数定义 + +创建Callback函数。 + +2. 实现说明 + +template \ Callback \*UBSHcomNewCallback(Args... args); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|----------------| +| args | Args | 入参 | 回调函数入参。 | + +4. 返回值 + +返回Callback \*函数。 + +#### 传输层 + +##### UBSHcomNetDriver::RegisterTLSCaCallback + +1. ?.1.接口使用方法 + + 函数定义 + +注册建链双向认证的回调函数,用于获取CA证书。 + +3. 实现方法 + +void UBSHcomNetDriver::RegisterTLSCaCallback(const UBSHcomTLSCaCallback &cb); + +4. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| cb | [UBSHcomTLSCaCallback](#ZH-CN_TOPIC_0000002465536138) | 入参 | 获取CA证书的回调函数。 | + +5. 返回值 + +无 + +6. 代码样例 + +int Verify(void \*x509, const char \*path) +{ +return 0; +} + +bool CACallback(const std::string &name, std::string &caPath, std::string &crlPath, +UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ +caPath = certPath + "/CA/cacert.pem"; +cb = std::bind(&Verify, std::placeholders::\_1, std::placeholders::\_2); +return true; +} + +driver-\>RegisterTLSCaCallback(std::bind(&CACallback, std::placeholders::\_1, std::placeholders::\_2, std::placeholders::\_3, std::placeholders::\_4, std::placeholders::\_5)); + +7. ?.2.UBSHcomTLSCaCallback函数类型 + + 函数定义 + +using UBSHcomTLSCaCallback = std::function\; + +9. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +Bool类型,表示回调函数是否执行成功。 + +- 返回值为true:表示成功。 + +- 返回值为false:表示失败。 + + 1. 代码样例 + +bool CACallback(const std::string &name, std::string &caPath, std::string &crlPath, +UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ +caPath = certPath + "/CA/cacert.pem"; +cb = std::bind(&Verify, std::placeholders::\_1, std::placeholders::\_2); +return true; +} + +2. ?.3.UBSHcomTLSCertVerifyCallback函数类型 + + 函数定义 + +using UBSHcomTLSCertVerifyCallback = std::function\; + +4. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|------------------| +| \- | void\* | 入参 | 加载之后的证书。 | + +5. 返回值 + +表示函数执行结果,返回值为0表示证书验证成功。 + +##### UBSHcomNetDriver::RegisterTLSCertificationCallback + +1. ?.1.接口使用方法 + + 函数定义 + +注册建链双向认证的回调函数,用于获取公钥证书。 + +3. 实现方法 + +void UBSHcomNetDriver::RegisterTLSCertificationCallback(const UBSHcomTLSCertificationCallback &cb); + +4. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| cb | [UBSHcomTLSCertificationCallback](#ZH-CN_TOPIC_0000002498615437) | 入参 | 获取公钥的回调函数。 | + +5. 返回值 + +无 + +6. 代码样例 + +bool CertCallback(const std::string &name, std::string &value) +{ +value = certPath + "/client/cert.pem"; +return true; +} + +driver-\>RegisterTLSCertificationCallback( std::bind(&CertCallback, std::placeholders::\_1, std::placeholders::\_2)); + +7. ?.2.UBSHcomTLSCertificationCallback函数类型 + + 函数定义 + +using UBSHcomTLSCertificationCallback = std::function\; + +9. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|----------------------| +| name | String | 入参 | 句柄名字。 | +| path | String | 出参 | 提供的公钥证书路径。 | + +10. 返回值 + +Bool类型,表示回调函数是否执行成功。 + +- 返回值为true:表示成功。 + +- 返回值为false:表示失败。 + + 1. 代码样例 + +bool CertCallback(const std::string &name, std::string &value) +{ +value = certPath + "/client/cert.pem"; +return true; +} + +##### UBSHcomNetDriver::RegisterTLSPrivateKeyCallback + +1. ?.1.接口使用方法 + + 函数定义 + +注册建链双向认证的回调函数,用户获取私钥证书。 + +3. 实现方法 + +void UBSHcomNetDriver::RegisterTLSPrivateKeyCallback(const UBSHcomTLSPrivateKeyCallback &cb); + +4. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| cb | [UBSHcomTLSPrivateKeyCallback](#ZH-CN_TOPIC_0000002498615317) | 入参 | 私钥回调函数。 | + +5. 返回值 + +无 + +6. 代码样例 + +void Erase(void \*pass, int len) {} + +bool PrivateKeyCallback(const std::string &name, std::string &value, void \*&keyPass, int &len, UBSHcomTLSEraseKeypass &erase) +{ +static char content\[\] = "xxxx"; +keyPass = reinterpret_cast\(content); +len = sizeof(content); +value = certPath + "/client/key.pem"; +erase = std::bind(&Erase, std::placeholders::\_1, std::placeholders::\_2); +return true; +} + +driver-\>RegisterTLSPrivateKeyCallback(std::bind(&PrivateKeyCallback, std::placeholders::\_1, std::placeholders::\_2, std::placeholders::\_3, std::placeholders::\_4, std::placeholders::\_5)); + +7. ?.2.UBSHcomTLSPrivateKeyCallback函数类型 + + 函数定义 + +using UBSHcomTLSPrivateKeyCallback = std::function\; + +9. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| name | String | 入参 | 句柄名字。 | +| path | String | 出参 | 提供的私钥证书路径。 | +| password | void\* | 出参 | 私钥加载的明文密码。 | +| length | int | 出参 | 私钥加载的密码长度。 | +| erase | [UBSHcomTLSEraseKeypass](#ZH-CN_TOPIC_0000002498615289) | 出参 | 擦除私钥密码的回调函数,当加载完私钥的时候调用。 | + +10. 返回值 + +Bool类型,表示回调函数是否执行成功。 + +- 返回值为true表示成功。 + +- 返回值为false表示失败。 + + 1. 代码样例 + +void Erase(void \*pass, int len) {} + +bool PrivateKeyCallback(const std::string &name, std::string &value, void \*&keyPass, int &len, UBSHcomTLSEraseKeypass &erase) +{ +static char content\[\] = "xxxx"; +keyPass = reinterpret_cast\(content); +len = sizeof(content); +value = certPath + "/client/key.pem"; +erase = std::bind(&Erase, std::placeholders::\_1, std::placeholders::\_2); +return true; +} + +2. ?.3.UBSHcomTLSEraseKeypass函数类型 + + 函数定义 + +using UBSHcomTLSEraseKeypass = std::function\; + +4. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----------|----------|----------|----------------------| +| password | void\* | 出参 | 私钥加载的明文密码。 | +| length | int | 出参 | 私钥加载的密码长度。 | + +5. 返回值 + +无 + +##### UBSHcomNetDriver::RegisterPskUseSessionCb + +1. ?.1.接口使用方法 + + 函数定义 + +供Client端注册PSK回调函数。 + +3. 实现方法 + +void UBSHcomNetDriver::RegisterPskUseSessionCb(const UBSHcomPskUseSessionCb &cb) + +4. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|------------------------|----------|----------------------| +| cb | UBSHcomPskUseSessionCb | 入参 | 预共享密钥回调函数。 | + +5. 返回值 + +无。 + +6. ?.2.UBSHcomPskUseSessionCb函数类型 + + 函数定义 + +using UBSHcomPskUseSessionCb = std::function\; + +8. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|--------------------------| +| ssl | void\* | 入参 | SSL连接对象。 | +| md | void\* | 入参 | 摘要算法。 | +| id | char \* | 出参 | 预共享密钥身份标识。 | +| idlen | size_t | 出参 | 预共享密钥身份标识长度。 | +| sess | void\* | 出参 | SSL会话对象。 | + +9. 返回值 + +int类型。 + +- 1:表示回调函数执行成功。 + +- 0:表示回调函数执行失败。 + +##### UBSHcomNetDriver::RegisterPskFindSessionCb + +1. ?.1.接口使用方法 + + 函数定义 + +供Server端注册PSK回调函数。 + +3. 实现方法 + +void UBSHcomNetDriver::RegisterPskFindSessionCb(const UBSHcomPskFindSessionCb &cb) + +4. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------------|----------|----------------------| +| cb | UBSHcomPskFindSessionCb | 入参 | 预共享密钥回调函数。 | + +5. 返回值 + +无。 + +6. ?.2.UBSHcomPskFindSessionCb函数类型 + + 函数定义 + +using UBSHcomPskFindSessionCb = std::function\; + +8. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------------|----------|----------|--------------------------| +| ssl | void\* | 入参 | SSL连接对象。 | +| identity | char \* | 入参 | 预共享密钥身份标识。 | +| identity_len | size_t | 入参 | 预共享密钥身份标识长度。 | +| sess | void\* | 出参 | SSL会话对象。 | + +9. 返回值 + +int类型,1表示回调函数是执行成功,0表示执行失败。 + +##### UBSHcomNetDriver::RegisterEndpointSecInfoProvider + +1. 函数定义 + +给UBSHcomNetDriver对象设置EP安全信息提供函数。 + +2. 实现方法 + +void UBSHcomNetDriver::RegisterEndpointSecInfoProvider(const UBSHcomNetDriverEndpointSecInfoProvider &provider) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| provider | UBSHcomNetDriverEndpointSecInfoProvider | 入参 | 安全信息提供函数。 | + +4. 返回值 + +无 + +![](media/image8.png) + +数据类型解释如下: + +using UBSHcomNetDriverEndpointSecInfoProvider = std::function\; + +其中,outLen的有效范围为(0,2147483646\]。 + +##### UBSHcomNetDriver::RegisterEndpointSecInfoValidator + +1. 函数定义 + +给UBSHcomNetDriver对象设置EP安全信息校验函数。 + +2. 实现方法 + +void UBSHcomNetDriver::RegisterEndpointSecInfoValidator(const UBSHcomNetDriverEndpointSecInfoValidator &validator) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| validator | UBSHcomNetDriverEndpointSecInfoValidator | 入参 | 安全信息校验函数。 | + +4. 返回值 + +无 + +![](media/image8.png) + +数据类型解释如下: + +using UBSHcomNetDriverEndpointSecInfoValidator = std::function\; + +##### UBSHcomNetEndpoint::PostSendRawSgl + +1. 函数定义 + +发送一个不带opcode和header的请求给对方,对方将触发新的请求回调,也不带opcode和header,当客户有自己定义的header时可以使用。 + +2. 实现方法 + +NResult UBSHcomNetEndpoint::PostSendRawSgl(const UBSHcomNetTransSglRequest &request,uint32_t seqNo) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| request | UBSHcomNetTransSglRequest | 入参 | [请求信息](#ZH-CN_TOPIC_0000002465376586),填入本地注册MR,按本地MR顺序发送到同一个远端MR,调用后即可释放,rKey/rAddress不需要赋值。 | +| seqNo | uint32_t | 入参 | 对方要回复的seqNo必须大于0,对方可以从context.Header().seqNo中获取它。如果seqNo为0,则生成自动递增的数字。在同步发送消息的情况下,请求和响应的seqNo相等。 | + +4. 返回值 + +返回值为0则表示发送消息成功。 + +![](media/image8.png) + +- 如果NET_EP_SELF_POLLING未设置,则只发出发送请求,不等待发送请求完成情况。 + +- 如果NET_EP_SELF_POLLING设置,则发出发送请求并等待发送到达对端。 + +##### UBSHcomNetEndpoint::ReceiveRaw + +1. 函数定义 + +获得发送请求应答的响应,不包含header和opCode,默认超时生效。 + +2. 实现方法 + +NResult UBSHcomNetEndpoint::ReceiveRaw(UBSHcomNetResponseContext &ctx) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|--------------------|----------|-----------------| +| ctx | UBSHcomNetResponseContext | 出参 | 响应消息的ctx。 | + +4. 返回值 + +返回值为0则表示发送消息成功。 + +##### UBSHcomNetEndpoint::EstimatedEncryptLen + +1. 函数定义 + +输入原始数据大小的估计加密长度。 + +2. 实现方法 + +uint64_t UBSHcomNetEndpoint::EstimatedEncryptLen(uint64_t rawLen) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|--------------------------------------------------| +| rawLen | uint64_t | 入参 | 原始数据长度。范围是(0, 18446744073709551571\]。 | + +4. 返回值 + +返回uint64_t类型,表示数据加密长度。 + +##### UBSHcomNetEndpoint::Encrypt + +1. 函数定义 + +加密数据。 + +2. 实现方法 + +NResult UBSHcomNetEndpoint::Encrypt(const void \*rawData, uint64_t rawLen, void \*cipher, uint64_t &cipherLen) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|----------|----------|------------------| +| rawData | void \* | 入参 | 原始数据地址。 | +| rawLen | uint64_t | 入参 | 原始数据长度。 | +| cipher | void \* | 出参 | 加密后数据地址。 | +| cipherLen | uint64_t | 出参 | 加密后数据长度。 | + +4. 返回值 + +返回值为0则表示加密成功。 + +##### UBSHcomNetEndpoint::EstimatedDecryptLen + +1. 函数定义 + +输出原始数据大小。 + +2. 实现方法 + +uint64_t UBSHcomNetEndpoint::EstimatedDecryptLen(uint64_t cipherLen) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|----------|----------|--------------------| +| cipherLen | uint64_t | 入参 | 加密后的数据长度。 | + +4. 返回值 + +返回uint64_t类型,表示解密后的原始数据长度。 + +##### UBSHcomNetEndpoint::Decrypt + +1. 函数定义 + +解密数据。 + +2. 实现方法 + +NResult UBSHcomNetEndpoint::Decrypt(const void \* cipher, uint64_t cipherLen, void \*rawData, uint64_t &rawLen) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|----------|----------|------------------------| +| cipher | void \* | 入参 | 待解密的数据地址。 | +| cipherLen | uint64_t | 入参 | 待解密的数据长度。 | +| rawData | void \* | 出参 | 解密后,原始数据地址。 | +| rawLen | uint64_t | 出参 | 解密后,原始数据长度。 | + +4. 返回值 + +返回值为0则表示解密成功。 + +##### UBSHcomNetEndpoint::SendFds + +1. 函数定义 + +发送共享文件的句柄,该接口只支持在SHM协议下使用。 + +2. 实现方法 + +NResult UBSHcomNetEndpoint::SendFds(int fds\[\], uint32_t len) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|----------------------------------| +| fds | int\[\] | 入参 | 需要发送的句柄数组。 | +| len | uint32_t | 入参 | 发送的句柄数量。范围是\[1, 4\]。 | + +4. 返回值 + +返回值为0则表示句柄发送成功。 + +##### UBSHcomNetEndpoint::ReceiveFds + +1. 函数定义 + +接收共享文件的句柄,该接口只支持在SHM协议下使用。 + +2. 实现方法 + +NResult UBSHcomNetEndpoint::ReceiveFds(int fds\[\], uint32_t len) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|----------------------------------| +| fds | int\[\] | 入参 | 需要接收的句柄数组。 | +| len | uint32_t | 入参 | 接收的句柄数量。范围是\[1, 4\]。 | + +4. 返回值 + +返回值为0则表示句柄发送成功。 + +##### UBSHcomNetOutLogger::Instance + +1. 函数定义 + +创建外部日志对象。 + +2. 实现方法 + +static UBSHcomNetOutLogger \*UBSHcomNetOutLogger::Instance() + +3. 参数说明 + +无 + +4. 返回值 + +返回外部日志导入对象。 + +##### UBSHcomNetOutLogger::SetLogLevel + +1. 函数定义 + +- 设置外部日志对象日志等级,设置为环境变量HCOM_SET_LOG_LEVEL。 + +- 设置外部日志对象日志等级,大于等于此等级的日志将被打印。 + +日志等级如下: + +- 0:debug + +- 1:info + +- 2:warn + +- 3:error + + 1. 实现方法 + +  + +- static void UBSHcomNetOutLogger::SetLogLevel() + +- static void UBSHcomNetOutLogger::SetLogLevel(int level) + + 1. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|----------------------------| +| level | int | 入参 | 日志等级。范围是\[0, 3\]。 | + +2. 返回值 + +无 + +##### UBSHcomNetOutLogger::SetExternalLogFunction + +1. 函数定义 + +设置外部日志对象外部日志函数。 + +2. 实现方法 + +void UBSHcomNetOutLogger::SetExternalLogFunction(ExternalLog func) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------|----------|----------------| +| func | ExternalLog | 入参 | 外部日志函数。 | + +4. 返回值 + +无 + +![](media/image8.png) + +数据类型解释如下: + +typedef void (\*ExternalLog)(int level, const char \*msg). + +##### UBSHcomNetOutLogger::Print + +1. 函数定义 + +打印日志。 + +2. 实现方法 + +static inline void UBSHcomNetOutLogger::Print(int level, const char \*msg) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|---------------|----------|----------------------------| +| level | int | 入参 | 日志等级。范围是\[0, 3\]。 | +| msg | const char \* | 入参 | 日志内容。 | + +4. 返回值 + +无 + +##### UBSHcomNetOutLogger::Log + +1. 函数定义 + +打印日志,如果有外部日志函数,则使用外部日志函数。 + +2. 实现方法 + +void UBSHcomNetOutLogger::Log(int level, const std::ostringstream &oss) const + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|--------------------------|----------|----------------------------| +| level | int | 入参 | 日志等级。范围是\[0, 3\]。 | +| oss | const std::ostringstream | 入参 | 日志内容。 | + +4. 返回值 + +无 + +##### UBSHcomNetOutLogger::GetLogLevel + +1. 函数定义 + +获取当日志打印等级。 + +2. 实现方法 + +int UBSHcomNetOutLogger::GetLogLevel() + +3. 参数说明 + +无 + +4. 返回值 + +返回日志等级。 + +##### UBSHcomNetAtomicState::Get + +1. 函数定义 + +获得原子状态。 + +2. 实现方法 + +T Get() const + +3. 参数说明 + +无 + +4. 返回值 + +获得原子状态。 + +##### UBSHcomNetAtomicState::Set + +1. 函数定义 + +设置原子状态。 + +2. 实现方法 + +void Set(T newState) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----------|----------|----------|----------------------| +| newState | T | 入参 | 需要设置的原子状态。 | + +4. 返回值 + +无 + +##### UBSHcomNetAtomicState::CAS + +1. 函数定义 + +原子状态比较并交换。 + +2. 实现方法 + +bool CAS(T oldState, T newState) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----------|----------|----------|--------------| +| oldState | T | 入参 | 旧原子状态。 | +| newState | T | 入参 | 新原子状态。 | + +4. 返回值 + +布尔值。检查mState是否等于oldState。如果是,则将其设置为newState,并返回true;否则不做任何修改,返回false。 + +##### UBSHcomNetAtomicState::Compare + +1. 函数定义 + +原子状态比较。 + +2. 实现方法 + +bool Compare(T state) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|------------| +| state | T | 入参 | 原子状态。 | + +4. 返回值 + +布尔值。检查mState是否等于state,如果是返回true;否则返回false。 + +### C API + +#### 服务层 + +##### ubs_hcom_service_add_workergroup + +1. 函数定义 + +向Service中增加内存池。 + +2. 实现方法 + +void ubs_hcom_service_add_workergroup(ubs_hcom_service service, int8_t priority, uint16_t workerGroupId, uint32_t threadCount, + +const char \*cpuIdsRange); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------------|------------------|----------|----------------------------------| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| priority | int8_t | 入参 | 线程优先级。 | +| workerGroupId | uint16_t | 入参 | 线程组ID。 | +| threadCount | uint32_t | 入参 | 组里的线程数。 | +| cpuIdsRange | const char \* | 入参 | CPU ID范围。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_add_listener + +1. 函数定义 + +添加监听线程。 + +2. 实现方法 + +void ubs_hcom_service_add_listener(ubs_hcom_service service, const char \*url, uint16_t workerCount); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------------|------------------|----------|----------------------------------| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| priority | int8_t | 入参 | 线程优先级。 | +| workerGroupId | uint16_t | 入参 | 线程组ID。 | +| threadCount | uint32_t | 入参 | 组里的线程数。 | +| cpuIdsRange | const char \* | 入参 | CPU ID范围。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_lbpolicy + +1. 函数定义 + +设置负载均衡策略。 + +2. 实现方法 + +void ubs_hcom_service_set_lbpolicy(ubs_hcom_service service, ubs_hcom_service_lb_policy lbPolicy); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| lbPolicy | ubs_hcom_service_lb_policy | 入参 | 负载均衡策略。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_tls_opt + +1. 函数定义 + +设置TLS配置项。 + +2. 实现方法 + +void ubs_hcom_service_set_tls_opt(ubs_hcom_service service, bool enableTls, ubs_hcom_service_tls_version version, + +ubs_hcom_service_cipher_suite cipherSuite, ubs_hcom_tls_get_cert_cb certCb, ubs_hcom_tls_get_pk_cb priKeyCb, ubs_hcom_tls_get_ca_cb caCb); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| enableTls | bool | 入参 | 是否开启TLS。 | +| version | ubs_hcom_service_tls_version | 入参 | TLS版本。 | +| cipherSuite | ubs_hcom_service_cipher_suite | 入参 | 加密方式。 | +| certCb | ubs_hcom_tls_get_cert_cb | 入参 | 获取TLS证书的回调。 | +| priKeyCb | ubs_hcom_tls_get_pk_cb | 入参 | 获取TLS私钥的回调。 | +| caCb | ubs_hcom_tls_get_ca_cb | 入参 | 获取TLS认证的回调。 | + +![](media/image8.png) + +使用UB自举建链时,暂不支持安全认证和安全加密。 + +4. 返回值 + +无 + +##### ubs_hcom_service_set_secure_opt + +1. 函数定义 + +设置安全加密选项。 + +2. 实现方法 + +void ubs_hcom_service_set_secure_opt(ubs_hcom_service service, ubs_hcom_service_secure_type secType, ubs_hcom_secinfo_provider provider, + +ubs_hcom_secinfo_validator validator, uint16_t magic, uint8_t version); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| secType | ubs_hcom_service_secure_type | 入参 | 加密方式。 | +| provider | ubs_hcom_secinfo_provider | 入参 | 密钥提供回调。 | +| validator | ubs_hcom_secinfo_validator | 入参 | 密钥校验回调。 | +| magic | uint16_t | 入参 | 魔数。 | +| version | uint8_t | 入参 | 安全加密版本。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_tcp_usr_timeout + +1. 函数定义 + +设置负载均衡策略。 + +2. 实现方法 + +void ubs_hcom_service_set_tcp_usr_timeout(ubs_hcom_service service, uint16_t timeOutSec); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|------------|------------------|----------|----------------------------------| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| timeOutSec | uint16_t | 入参 | TCP超时时间 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_tcp_send_zcopy + +1. 函数定义 + +设置负载均衡策略。 + +2. 实现方法 + +void ubs_hcom_service_set_tcp_send_zcopy(ubs_hcom_service service, bool tcpSendZCopy); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------------|------------------|----------|----------------------------------| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| tcpSendZCopy | bool | 入参 | TCP是否开启ZCopy。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_ipmask + +1. 函数定义 + +设置要监听的IP。 + +2. 实现方法 + +void ubs_hcom_service_set_ipmask(ubs_hcom_service service, const char \*ipMask); + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +4. 返回值 + +无 + +##### ubs_hcom_service_set_ipgroup + +1. 函数定义 + +设置要监听的IP。 + +2. 实现方法 + +void ubs_hcom_service_set_ipgroup(ubs_hcom_service service, const char \*ipGroup); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| ipGroup | const char \* | 入参 | 要监听的IP,如果明确指定了ipGroup,则直接使用对应的设备。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_cq_depth + +1. 函数定义 + +设置cq队列的深度。 + +2. 实现方法 + +void ubs_hcom_service_set_cq_depth(ubs_hcom_service service, uint16_t depth); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|------------------|----------|----------------------------------| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| depth | uint16_t | 入参 | cq队列的深度。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_sq_size + +1. 函数定义 + +设置SQ队列的大小,默认256。 + +2. 实现方法 + +void ubs_hcom_service_set_sq_size(ubs_hcom_service service, uint32_t sqSize); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|------------------|----------|----------------------------------| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| sqSize | uint32_t | 入参 | SQ队列的大小,默认256。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_rq_size + +1. 函数定义 + +设置RQ队列的大小,默认256。 + +2. 实现方法 + +void ubs_hcom_service_set_rq_size(ubs_hcom_service service, uint32_t rqSize); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|------------------|----------|----------------------------------| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| rqSize | uint32_t | 入参 | RQ队列的大小,默认256。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_polling_batchsize + +1. 函数定义 + +设置传输层worker单次poll的个数。 + +2. 实现方法 + +void ubs_hcom_service_set_polling_batchsize(ubs_hcom_service service, uint16_t pollSize); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----------|------------------|----------|----------------------------------| +| service | ubs_hcom_service | 入参 | 之前创建的ubs_hcom_service对象。 | +| pollSize | uint16_t | 入参 | 单次poll的个数。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_polling_timeoutus + +1. 函数定义 + +设置event polling的超时时间。 + +2. 实现说明 + +void ubs_hcom_service_set_polling_timeoutus(ubs_hcom_service service, uint16_t pollTimeout); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-------------|------------------|----------|------------------------------| +| service | ubs_hcom_service | 入参 | 创建的ubs_hcom_service对象。 | +| pollTimeout | uint16_t | 入参 | event polling超时时间。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_timeout_threadnum + +1. 函数定义 + +设置周期任务处理线程数。 + +2. 实现说明 + +void ubs_hcom_service_set_timeout_threadnum(ubs_hcom_service service, uint32_t threadNum); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|------------------|----------|------------------------------| +| service | ubs_hcom_service | 入参 | 创建的ubs_hcom_service对象。 | +| threadNum | uint32_t | 入参 | 周期任务处理线程数。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_max_connection_cnt + +1. 函数定义 + +设置最大链接数。 + +2. 实现说明 + +void ubs_hcom_service_set_max_connection_cnt(ubs_hcom_service service, uint32_t maxConnCount); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------------|------------------|----------|------------------------------| +| service | ubs_hcom_service | 入参 | 创建的ubs_hcom_service对象。 | +| maxConnCount | uint32_t | 入参 | 最大链接数。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_heartbeat_opt + +1. 函数定义 + +设置心跳参数配置项。 + +2. 实现说明 + +void ubs_hcom_service_set_heartbeat_opt(ubs_hcom_service service, uint16_t idleSec, uint16_t probeTimes, uint16_t intervalSec); + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| service | ubs_hcom_service | 入参 | 创建的ubs_hcom_service对象。 | +| idleSec | uint16_t | 入参 | 发送心跳保活消息间隔时间。 | +| probeTimes | uint16_t | 入参 | 发送心跳探测失败/没收到回复重试次数,超了认为连接已经断开。 | +| intervalSec | uint16_t | 入参 | 发送心跳后再次发送时间。 | + +4. 返回值 + +无 + +##### ubs_hcom_service_set_multirail_opt + +1. 函数定义 + +设置多路径参数配置项。 + +2. 实现说明 + +void ubs_hcom_service_set_multirail_opt(ubs_hcom_service service, bool enable, uint32_t threshold); + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +无 + +##### ubs_hcom_set_log_handler + +1. 函数定义 + +设置外部日志模板。 + +2. 实现方法 + +void ubs_hcom_set_log_handler(ubs_hcom_log_handler h) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------------------|----------|--------------------| +| h | ubs_hcom_log_handler | 入参 | 外部日志回调函数。 | + +4. 返回值 + +无 + +![](media/image8.png) + +数据类型解释如下: + +typedef void (\*ubs_hcom_log_handler)(int level, const char \*msg). + +#### 传输层 + +##### ubs_hcom_driver_register_tls_cb + +1. ?.1.接口使用方法 + + 函数定义 + +注册建链双向认证的回调函数,分别用于获取CA证书、获取公钥证书和私钥证书。 + +3. 实现方法 + +uintptr_t ubs_hcom_driver_register_tls_cb(ubs_hcom_driver driver, ubs_hcom_tls_get_cert_cb certCb, ubs_hcom_tls_get_pk_cb priKeyCb, ubs_hcom_tls_get_ca_cb caCb) + +4. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| driver | ubs_hcom_driver | 入参 | 创建的ubs_hcom_driver对象。 | +| certCb | [Hcom_TlsGetCb](#ZH-CN_TOPIC_0000002465376234) | 入参 | 回调函数句柄。 | +| priKeyCb | [ubs_hcom_tls_get_pk_cb](#ZH-CN_TOPIC_0000002498615949) | 入参 | 回调函数句柄。 | +| caCb | [ubs_hcom_tls_get_ca_cb](#ZH-CN_TOPIC_0000002465535906) | 入参 | 回调函数句柄。 | + +5. 返回值 + +uintptr_t,返回内部句柄地址。 + +6. ?.2.ubs_hcom_tls_get_ca_cb函数类型 + + 函数定义 + +typedef int (\*ubs_hcom_tls_get_ca_cb)(const char \*name, char \*\*caPath, char \*\*crlPath, ubs_hcom_peer_cert_verify_type \*verifyType, ubs_hcom_tls_cert_verify \*verify) + +8. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +返回值为0则表示回调函数执行成功。 + +2. ?.3.ubs_hcom_peer_cert_verify_type函数类型 + + 函数定义 + +typedef int (\*ubs_hcom_tls_cert_verify)(void \*x509, const char \*crlPath) + +4. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|----------|----------|----------------------| +| x509 | void\* | 入参 | 加载之后的x509证书。 | +| crlPath | char \* | 入参 | 提供的吊销列表路径。 | + +5. 返回值 + +表示函数执行结果,返回值为0则表示证书验证成功。 + +6. ?.4.Hcom_TlsGetCb + + 函数定义 + +typedef int (\*ubs_hcom_tls_get_cert_cb)(const char \*name, char \*\*certPath) + +8. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|------------|----------|----------|----------------------| +| name | char \* | 出参 | 句柄名字。 | +| \*certPath | char \* | 出参 | 提供的公钥证书路径。 | + +9. 返回值 + +Bool类型,表示回调函数是否执行成功 + +- 返回值为true:表示成功。 + +- 返回值为false:表示失败。 + + 1. ?.5.ubs_hcom_tls_get_pk_cb + + 函数定义 + +typedef int (\*ubs_hcom_tls_get_pk_cb)(const char \*name, char \*\*priKeyPath, char \*\*keyPass, ubs_hcom_tls_keypass_erase \*erase) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| name | char \* | 出参 | 句柄名字。 | +| \*priKeyPath | char \* | 出参 | 提供的私钥证书路径。 | +| \*keyPass | void\* | 出参 | 私钥加载的明文密码。 | +| erase | [ubs_hcom_tls_keypass_erase](#ZH-CN_TOPIC_0000002498495445) \* | 出参 | 擦除私钥密码的回调函数,当加载完私钥的时候调用。 | + +4. 返回值 + +返回值为0则表示回调函数执行成功。 + +5. ?.6.ubs_hcom_tls_keypass_erase + + 函数定义 + +typedef void (\*ubs_hcom_tls_keypass_erase)(char \*keyPass, int len) + +7. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|---------|----------|----------|----------------------| +| keyPass | char \* | 入参 | 私钥加载的明文密码。 | +| len | int | 入参 | 私钥加载的密码长度。 | + +8. 返回值 + +无 + +##### ubs_hcom_ep_post_send_raw + +1. 函数定义 + +向对端发送一个带有op信息的请求。 + +2. 实现方法 + +int ubs_hcom_ep_post_send_raw(ubs_hcom_endpoint ep, ubs_hcom_send_request \*req, uint32_t seqNo) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | 建链创建好的EP对象。 | +| \*req | [ubs_hcom_send_request](#ubs_hcom_send_request) | 入参 | 发送请求信息,使用本地内存来存储数据,数据会被复制,调用后可释放本地内存。 | +| seqNo | uint32_t | 入参 | 对端用于回复的序列号。 | + +##### ubs_hcom_ep_post_send_raw_sgl + +1. 函数定义 + +向对端发送请求。 + +2. 实现方法 + +int ubs_hcom_ep_post_send_raw_sgl(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request_sgl \*req, uint32_t seqNo) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | 建链创建好的EP对象。 | +| \*req | ubs_hcom_readwrite_request_sgl | 入参 | 请求信息。 | +| seqNo | uint32_t | 入参 | 对方要回复的seqNo必须大于0,对方可以从context.Header().seqNo中获取它;如果seqNo为0,则生成自动递增的数字。在同步发送消息的情况下,请求和响应的seqNo相等。 | + +2. ubs_hcom_readwrite_request_sgl结构体 + +[TABLE] + +4. 返回值 + +返回值为0则表示发送请求成功。 + +##### ubs_hcom_ep_post_read_sgl + +1. 函数定义 + +向对端发送一个读请求。 + +2. 实现方法 + +int ubs_hcom_ep_post_read_sgl(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request_sgl \*req) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | 建链创建好的EP对象。 | +| \*req | [ubs_hcom_readwrite_request_sgl](#ubs_hcom_readwrite_request_sgl) | 入参 | 读请求信息。 | + +4. 返回值 + +返回值为0则表示读成功。 + +##### ubs_hcom_ep_post_write_sgl + +1. 函数定义 + +向对端发送一个写请求。 + +2. 实现方法 + +int ubs_hcom_ep_post_write_sgl(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request_sgl \*req) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | 建链创建好的EP对象。 | +| \*req | [ubs_hcom_readwrite_request_sgl](#ubs_hcom_readwrite_request_sgl) | 入参 | 写请求信息。 | + +4. 返回值 + +返回值为0则表示写成功。 + +##### ubs_hcom_ep_receive_raw + +1. 函数定义 + +接收消息,仅对NET_C_EP_SELF_POLLING设置时使用。 + +2. 实现方法 + +int ubs_hcom_ep_receive_raw(ubs_hcom_endpoint ep, int32_t timeout, ubs_hcom_response_context \*\*ctx) + +3. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +返回值为0则表示接收消息成功。 + +##### ubs_hcom_ep_receive_raw_sgl + +1. 函数定义 + +接收对端发送过来的SGL消息。 + +2. 实现方法 + +int ubs_hcom_ep_receive_raw_sgl(ubs_hcom_endpoint ep, int32_t timeout, ubs_hcom_response_context \*\*ctx) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | EndPoint。 | +| timeout | int32_t | 入参 | 超时时间,单位是秒。0为立刻超时,负数为永不超时。 | +| ctx | [ubs_hcom_response_context](#ubs_hcom_response_context) | 出参 | 接收到的消息。 | + +4. 返回值 + +返回值为0则表示成功。 + +##### ubs_hcom_estimate_encrypt_len + +1. 函数定义 + +输入原始数据大小的估计加密长度。 + +2. 实现方法 + +uint64_t ubs_hcom_estimate_encrypt_len(ubs_hcom_endpoint ep, uint64_t rawLen) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| ep | ubs_hcom_endpoint | 入参 | 建链创建好的EP对象。 | +| rawLen | uint64_t | 入参 | 原始数据长度。范围是(0, 18446744073709551571\]。 | + +4. 返回值 + +返回uint64_t类型,表示数据加密长度。 + +##### ubs_hcom_encrypt + +1. 函数定义 + +加密数据。 + +2. 实现方法 + +int ubs_hcom_encrypt(ubs_hcom_endpoint ep, const void \*rawData, uint64_t rawLen, void \*cipher, uint64_t \*cipherLen) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|-------------------|----------|----------------------| +| ep | ubs_hcom_endpoint | 入参 | 建链创建好的EP对象。 | +| rawData | void \* | 入参 | 原始数据地址。 | +| rawLen | uint64_t | 入参 | 原始数据长度。 | +| cipher | void \* | 出参 | 加密后数据地址。 | +| cipherLen | uint64_t | 出参 | 加密后数据长度。 | + +4. 返回值 + +返回值为0则表示加密成功。 + +##### ubs_hcom_estimate_decrypt_len + +1. 函数定义 + +输出原始数据大小。 + +2. 实现方法 + +uint64_t ubs_hcom_estimate_decrypt_len(ubs_hcom_endpoint ep, uint64_t cipherLen) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|-------------------|----------|----------------------| +| ep | ubs_hcom_endpoint | 入参 | 建链创建好的EP对象。 | +| cipherLen | uint64_t | 入参 | 加密后的数据长度。 | + +4. 返回值 + +返回uint64_t类型,表示解密后的原始数据长度。 + +##### ubs_hcom_decrypt + +1. 函数定义 + +解密数据。 + +2. 实现方法 + +int ubs_hcom_decrypt(ubs_hcom_endpoint ep, const void \*cipher, uint64_t cipherLen, void \*rawData, uint64_t \*rawLen) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|-----------|-------------------|----------|------------------------| +| ep | ubs_hcom_endpoint | 入参 | 建链创建好的EP对象。 | +| cipher | void \* | 入参 | 待解密的数据地址。 | +| cipherLen | uint64_t | 入参 | 待解密的数据长度。 | +| rawData | void \* | 出参 | 解密后,原始数据地址。 | +| rawLen | uint64_t | 出参 | 解密后,原始数据长度。 | + +4. 返回值 + +返回值为0则表示解密成功。 + +##### ubs_hcom_send_fds + +1. 函数定义 + +发送共享文件的句柄,该接口只支持在SHM协议下使用。 + +2. 实现方法 + +int ubs_hcom_send_fds(ubs_hcom_endpoint ep, int fds\[\], uint32_t len) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|-------------------|----------|----------------------------------| +| ep | ubs_hcom_endpoint | 入参 | 建链创建好的EP对象。 | +| fds | int\[\] | 入参 | 需要发送的句柄数组。 | +| len | uint32_t | 入参 | 发送的句柄数量。范围是\[1, 4\]。 | + +4. 返回值 + +返回值为0则表示句柄发送成功。 + +##### ubs_hcom_receive_fds + +1. 函数定义 + +接收共享文件的句柄,该接口只支持在SHM协议下使用。 + +2. 实现方法 + +int ubs_hcom_receive_fds(ubs_hcom_endpoint ep, int fds\[\], uint32_t len, int timeoutSec) + +3. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|------------|-------------------|----------|------------------------------------| +| ep | ubs_hcom_endpoint | 入参 | 建链创建好的EP对象。 | +| fds | int\[\] | 出参 | 需要接收的句柄数组。 | +| len | uint32_t | 入参 | 接收的句柄数量。范围是\[1, 4\]。 | +| timeoutSec | int | 入参 | 设置接收超时时间,-1表示不设超时。 | + +4. 返回值 + +返回值为0则表示句柄发送成功。 + +## 结构体参考 + +### C++结构体 + +#### 服务层结构体 + +##### UBSHcomServiceOptions + +1. 参数说明 + +[TABLE] + +![](media/image8.png) + +双边操作允许发送最大消息的长度,可结合使用场景通过maxSendRecvDataSize来配置。 + +##### UBSHcomConnectOptions + +1. 参数说明 + +| 参数名 | 数据类型 | 默认值 | 描述 | +|----|----|----|----| +| clientGroupId | uint16_t | 0 | 客户端worker线程池ID。 | +| serverGroupId | uint16_t | 0 | 服务端worker线程池ID。 | +| linkCount | uint8_t | 1 | 链接数。 | +| mode | [UBSHcomClientPollingMode](#ubshcomclientpollingmode) | WORKER_POLL | 客户端调用通信接口时poll的模式。 | +| cbType | [UBSHcomChannelCallBackType](#ubshcomchannelcallbacktype) | CHANNEL_FUNC_CB | 回调类型。 | +| payload | std::string | 空 | 建链发送给服务端的payload。 | + +##### UBSHcomRequest + +1. 参数说明 + +| 配置项 | 数据类型 | 默认值 | 说明 | +|---------|----------|---------|--------------------------------------| +| address | void\* | nullptr | 数据指针。该字段在内部有空指针校验。 | +| size | uint32_t | 0 | 数据大小。范围是(0,UINT32_MAX\]。 | +| key | uint64_t | 0 | 数据地址key值。 | +| opcode | uint16_t | 0 | 操作类型。 | + +##### UBSHcomResponse + +1. 参数说明 + +| 配置项 | 类型 | 默认值 | 说明 | +|-----------|----------|---------|--------------| +| address | void\* | nullptr | 数据指针。 | +| size | uint32_t | 0 | 数据大小。 | +| errorCode | int16_t | 0 | 回复错误码。 | + +##### UBSHcomReplyContext + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|-----------|-----------|-------------------------------------| +| rspCtx | uintptr_t | 回复上下文,可从回调context中获取。 | +| errorCode | int16_t | 回复的错误码。 | + +##### UBSHcomOneSideRequest + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|----------|------------------|----------------------------------------------| +| lAddress | uintptr_t | 单边通信本端内存地址。 | +| rAddress | uintptr_t | 单边通信,对端内存地址。 | +| lKey | UBSHcomMemoryKey | 单边通信,本端UBSHcomMemoryKey。 | +| rKey | UBSHcomMemoryKey | 单边通信,对端UBSHcomMemoryKey。 | +| size | uint32_t | 单边通信,数据大小。范围是(0, UINT32_MAX\]。 | + +2. UBSHcomMemoryKey + +| 参数名 | 数据类型 | 描述 | +|----|----|----| +| keys | uint64_t \[4\] | 内存注册后的内存区域key,MultiRail场景下多个设备有多个key(最多4个),非MultiRail场景下只需要1个。 | +| tokens | uint64_t | UBC场景下注册内存区域的token value,MultiRail场景下多个设备有多个token value(最多4个),非MultiRail场景下只需要1个。 | + +##### UBSHcomFlowCtrlOptions + +1. 参数说明 + +[TABLE] + +##### UBSHcomTlsOptions + +1. ?.1.参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 默认值 | 描述 | +|----|----|----|----| +| caCb | UBSHcomTLSCaCallback | nullptr | 建链双向认证的回调函数,用于获取CA证书。 | +| cfCb | UBSHcomTLSCertificationCallback | nullptr | 建链双向认证的回调函数,用于获取公钥证书。 | +| pkCb | UBSHcomTLSPrivateKeyCallback | nullptr | 建链双向认证的回调函数,用于获取私钥证书。 | +| tlsVersion | UBSHcomTlsVersion | UBSHcomTlsVersion::TLS_1_3 | TLS版本,支持TLS1.3,不再支持TLS1.2。 | +| netCipherSuite | UBSHcomCipherSuite | UBSHcomCipherSuite::AES_GCM_128 | 加密算法,取值范围见[UBSHcomNetCipherSuite](#ubshcomnetciphersuite)。 | +| enableTls | bool | true | 是否开启TLS认证。 | + +2. ?.2.UBSHcomTLSCaCallback函数类型 + + 函数定义 + +using UBSHcomTLSCaCallback = std::function\; + +4. 参数说明 + + 1. 参数说明 + +[TABLE] + +1. 返回值 + +Bool类型,表示回调函数是否执行成功。 + +- 返回值为true:表示成功。 + +- 返回值为false:表示失败。 + + 1. 代码样例 + +int Verify(void \*x509, const char \*path) +{ +return 0; +} +bool CACallback(const std::string &name, std::string &caPath, std::string &crlPath, +UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ +caPath = certPath + "/CA/cacert.pem"; +cb = std::bind(&Verify, std::placeholders::\_1, std::placeholders::\_2); +return true; +} + +2. ?.3.UBSHcomTLSCertVerifyCallback函数类型 + + 函数定义 + +using UBSHcomTLSCertificationCallback = std::function\; + +4. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|--------|----------|----------|----------------------| +| name | String | 入参 | 句柄名字。 | +| path | String | 出参 | 提供的公钥证书路径。 | + +5. 返回值 + +Bool类型,表示回调函数是否执行成功。 + +- 返回值为true:表示成功。 + +- 返回值为false:表示失败。 + + 1. 代码样例 + +bool CertCallback(const std::string &name, std::string &value) +{ +value = certPath + "/client/cert.pem"; +return true; +} + +2. ?.4.UBSHcomTLSPrivateKeyCallback函数类型 + + 函数定义 + +using UBSHcomTLSPrivateKeyCallback = std::function\; + +4. 参数说明 + + 1. 参数说明 + +| 参数名 | 数据类型 | 参数类型 | 描述 | +|----|----|----|----| +| name | String | 入参 | 句柄名字。 | +| path | String | 出参 | 提供的私钥证书路径。 | +| password | void\* | 出参 | 私钥加载的明文密码。 | +| length | int | 出参 | 私钥加载的密码长度。 | +| erase | [UBSHcomTLSEraseKeypass](#ZH-CN_TOPIC_0000002498615289) | 出参 | 擦除私钥密码的回调函数,当加载完私钥的时候调用。 | + +5. 返回值 + +Bool类型,表示回调函数是否执行成功。 + +- 返回值为true:表示成功。 + +- 返回值为false:表示失败。 + + 1. 代码样例 + +void Erase(void \*pass, int len) {} + +bool PrivateKeyCallback(const std::string &name, std::string &value, void \*&keyPass, int &len, UBSHcomTLSEraseKeypass &erase) +{ +static char content\[\] = "xxxx"; +keyPass = reinterpret_cast\(content); +len = sizeof(content); +value = certPath + "/client/key.pem"; +erase = std::bind(&Erase, std::placeholders::\_1, std::placeholders::\_2); +return true; +} + +##### UBSHcomConnSecureOptions + +1. 参数说明 + +[TABLE] + +![](media/image8.png) + +数据类型解释如下: + +using UBSHcomDriverSecInfoProvider= std::function\; + +其中,outLen的有效范围为(0, 2147483646\]。 + +using UBSHcomNetDriverEndpointSecInfoValidator = std::function\; + +##### UBSHcomHeartBeatOptions + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | 默认值 | +|----|----|----|----| +| heartBeatIdleSec | uint16_t | 发送心跳保活消息间隔时间。 | 60 | +| heartBeatProbeTimes | uint16_t | 发送心跳探测失败/没收到回复重试次数,超了认为连接已经断开。 | 7 | +| heartBeatProbeIntervalSec | uint16_t | 发送心跳后再次发送时间。 | 2 | + +##### UBSHcomMultiRailOptions + +1. 参数说明 + +[TABLE] + +##### UBSHcomIov + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | 默认值 | +|---------|----------|------------|---------| +| address | void \* | 地址值。 | nullptr | +| size | uint32_t | 数据大小。 | 0 | + +##### UBSHcomOneSideSglRequest + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | 默认值 | +|----------|-----------------------|---------------|---------| +| iov | UBSHcomOneSideRequest | 单边iov数组。 | nullptr | +| iovCount | uint16_t | iov数量。 | 0 | + +##### UBSHcomMemoryKey + +1. 参数说明 + +| 参数名 | 数据类型 | 默认值 | 描述 | +|-------------|----------|--------|----------------------------| +| keys\[4\] | uint64_t | \- | key数组。 | +| tokens\[4\] | uint64_t | \- | UBC场景下的token value数组 | + +##### UBSHcomSglRequest + +1. 参数说明 + +| 参数名 | 数据类型 | 默认值 | 描述 | +|----------|----------------|---------|---------------| +| iov | UBSHcomRequest | nullptr | 双边iov数组。 | +| iovCount | uint16_t | 0 | iov数量。 | + +##### UBSHcomTwoSideThreshold + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | 默认值 | +|----|----|----|----| +| splitThreshold | uint32_t | UBC专用。此值表示拆包发送的阈值,也可以当做拆包发送时每个小包的最大长度(含额外头部),一般将其配置成小于等于SegSize的值。可配置范围为 \[128, maxSendRecvDataSize\],特别的配置成UINT32_MAX会禁用拆包功能。 | UINT32_MAX | +| rndvThreshold | uint32_t | rndv阈值,请求长度大于等于该值,则启用RNDV。 | UINT32_MAX | + +#### 传输层结构体 + +##### UBSHcomNetDriverDeviceInfo + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|--------|----------|--------------------------------| +| maxSge | int | 最大SGL数组元素个数,默认为4。 | + +##### UBSHcomNetDriverOptions + +1. 参数说明 + +[TABLE] + +![](media/image8.png) + +- UBS Comm默认开启TLS认证,关闭认证可能存在安全风险,用户可通过enableTls = false进行关闭。 + +- 双边操作允许发送最大消息的长度,可结合使用场景通过mrSendReceiveSegSize来配置。 + +##### UBSHcomNetOobListenerOptions + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | 默认值 | +|----|----|----|----| +| ip | char\[16\] | 监听的IP。 | \- | +| port | uint16_t | 监听的端口号,默认是9980。范围是\[1024, 65535\]。 | 9980 | +| targetWorkerCount | uint16_t | 可用worker数量,0代表全部,默认是全部。 | UINT16_MAX | + +##### UBSHcomNetOobUDSListenerOptions + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | 默认值 | +|----|----|----|----| +| name | char\[96\] | 监听的UDS name。长度范围是(0, 96)。 | \- | +| perm | uint16_t | 0代表不使用文件,其他情况则使用文件,此参数为其权限,最高为0600。 | 0600 | +| targetWorkerCount | uint16_t | 可用worker数量,0代表全部,默认是全部。 | UINT16_MAX | +| isCheck | bool | 是否校验权限,默认值为true。 | true | + +##### UBSHcomEpOptions + +1. 参数说明 + +[TABLE] + +##### UBSHcomNetTransRequest + +1. 参数说明 + +| 配置项 | 类型 | 默认值 | 说明 | +|----|----|----|----| +| lAddress | uintptr_t | 0 | 本地缓存地址。 | +| rAddress | uintptr_t | 0 | 远程缓存地址。 | +| lKey | uint64_t | 0 | 本地内存区域key。 | +| rKey | uint64_t | 0 | 远程内存区域key。 | +| size | uint32_t | 0 | 缓存大小。有效范围为(0, UINT32_MAX\]。 | +| upCtxSize | uint16_t | 0 | 上下文大小。 | +| upCtxData | char\[64\] | \- | 上下文数据。 | +| srcSeg | void \* | nullptr | 仅UB场景使用,填写发送端的urma_target_seg_t \*指针。 | +| dstSeg | void \* | nullptr | 仅UB场景使用,填写目的端的urma_target_seg_t \*指针。 | + +##### UBSHcomNetTransOpInfo + +1. 参数说明 + +[TABLE] + +##### UBSHcomNetUdsIdInfo + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | 默认值 | +|--------|----------|----------|--------| +| pid | uint32_t | 进程ID。 | 0 | +| uid | uint32_t | 用户ID。 | 0 | +| gid | uint32_t | 组ID。 | 0 | + +##### UBSHcomNetMemoryAllocatorOptions + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | 默认值 | +|----|----|----|----| +| address | uintptr_t | 内存地址。 | 0 | +| size | uint64_t | 内存大小。 | 0 | +| minBlockSize | uint32_t | 分配时最小单位大小(2的倍数)。范围是\[4096, 1073741824\],单位是byte。 | 0 | +| bucketCount | uint32_t | 对齐的前提下,HashMap的桶数。 | 8192 | +| alignedAddress | bool | 是否对齐。 | false | +| cacheTierCount | uint16_t | 缓存器的层数。 | 8 | +| cacheBlockCountPerTier | uint16_t | 每层有多少个内存块。 | 16 | +| cacheTierPolicy | [UBSHcomNetMemoryAllocatorCacheTierPolicy](#ZH-CN_TOPIC_0000002465536242) | 分层策略,0为times,1为power。 | TIER_TIMES | + +##### UBSHcomNetTransSglRequest + +1. 参数说明 + +| 配置项 | 类型 | 默认值 | 说明 | +|----|----|----|----| +| \*iov | [UBSHcomNetTransSgeIov](#ZH-CN_TOPIC_0000002465376330) | Nullptr | 消息数组。该字段在内部有空指针校验。 | +| iovCount | uint16_t | 0 | 数组长度。最大为4。 | +| upCtxSize | uint16_t | 0 | 上下文大小。 | +| upCtxData\[16\] | char | \- | 上下文数据。 | + +##### UBSHcomNetTransSgeIov + +1. 参数说明 + +| 配置项 | 类型 | 默认值 | 说明 | +|----|----|----|----| +| lAddress | uintptr_t | 0 | 本端内存地址。 | +| rAddress | uintptr_t | 0 | 对端内存地址。 | +| lKey | uint64_t | 0 | 本端key。 | +| rKey | uint64_t | 0 | 对端key。 | +| size | uint32_t | 0 | 内存大小。 | +| memid | unsigned long | 0 | 显示Urmah在rndv中使用的obmm内存。 | +| srcSeg | void \* | nullptr | 仅UB场景使用,填写发送端的urma_target_seg_t \*指针。 | +| dstSeg | void \* | nullptr | 仅UB场景使用,填写目的端的urma_target_seg_t \*指针。 | + +##### UBSHcomWorkerGroupInfo + +1. 参数说明 + +| 配置项 | 类型 | 默认值 | 说明 | +|----|----|----|----| +| threadPriority | int8_t | 0 | 线程优先级。范围:\[-20, 19\] | +| threadCount | uint16_t | 1 | 线程总数。 | +| groupId | uint16_t | 0 | worker线程中的组ID。 | +| cpuIdsRange | std:pair\ | \- | 指定worker线程CPU ID。 | + +##### UBSHcomNetUdsIdInfo + +1. 参数说明 + +| 配置项 | 类型 | 默认值 | 说明 | +|--------|----------|--------|----------| +| pid | uint32_t | 0 | 进程ID。 | +| uid | uint32_t | 0 | 用户ID。 | +| gid | uint32_t | 0 | 组ID。 | + +##### UBSHcomNetTransHeader + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|----|----|----| +| headerCrc | uint32_t | crc值。 | +| opCode | int16_t | 用户定义的操作码。传输层范围\[0, 1023\],service层范围\[0, 999\]。 | +| flags | uint16_t | 保留位。 | +| seqNo | uint32_t | 序列号。 | +| timeout | int16_t | 超时时间。 | +| errorCode | int16_t | 错误码。 | +| dataLength | uint32_t | 数据长度。 | +| immData | uint32_t | 立即数。 | +| extHeaderType | UBSHcomExtHeaderType | 传输层payload中是否存在服务层的头部,用户不使用。 | + +1. 结构体函数定义 + +重置opcode、seqNo、errorCode和dataLenagth。 + +2. 实现说明 + +void Invalid(); + +3. 参数说明 + +无 + +4. 返回值 + +无 + +### C结构体 + +#### 服务层结构体 + +##### ubs_hcom_mr_info + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|----------|----------------------|--------------| +| lAddress | uintptr_t | mr内存地址。 | +| lKey | ubs_hcom_oneside_key | mr key。 | +| size | uint32_t | mr内存大小。 | + +##### ubs_hcom_channel_reply_context + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|-----------|----------|--------------------------| +| rspCtx | void \* | 用于回复的RSP上下文。 | +| errorCode | int16_t | 失败场景下回复的错误码。 | + +##### ubs_hcom_oneside_request + +1. 参数说明 + +| 参数名称 | 数据类型 | 描述 | +|----------|----------------------|----------------------------------------| +| lAddress | uintptr_t | 本地的地址。 | +| rAddress | uintptr_t | 远端的地址。 | +| lKey | ubs_hcom_oneside_key | 本地MR的key。 | +| rKey | ubs_hcom_oneside_key | 远端MR的key。 | +| size | uint32_t | 数据大小。有效范围为(0, UINT32_MAX\]。 | + +##### ubs_hcom_channel_callback + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|--------|--------------------------|----------------------| +| cb | ubs_hcom_channel_cb_func | 回调函数。 | +| arg | void \* | 回调函数的参数指针。 | + +##### ubs_hcom_flowctl_opts + +1. 参数说明 + +[TABLE] + +##### ubs_hcom_service_options + +1. 参数说明 + +| 配置项 | 数据类型 | 默认值 | 说明 | +|----|----|----|----| +| mrSendReceiveSegSize | uint32_t | \- | 双边发送消息时,采用bcopy模式,发送端和接收端预留的内存大小。范围为(0, 524288000\],单位byte。 | +| workerGroupId | uint16_t | \- | worker group编号。 | +| workerGroupThreadCount | uint16_t | \- | worker group内worker线程数量。 | +| workerGroupMode | ubs_hcom_service_worker_mode | \- | busy poll/event poll模式。 | +| workerThreadPriority | int8_t | \- | worker线程优先级设置,\[-20, 20\],20为优先级最低,-20为优先级最高,0为不设置优先级。 | +| workerGroupCpuRange | char\[64\] | \- | worker group内worker线程cpu绑核id,例:'0-0',为绑在cpu id 0上。ID为UINT32_MAX即为不绑。 | + +![](media/image8.png) + +UBS Comm默认开启TLS认证,关闭认证可能存在安全风险,用户可通过Service_SetUBSHcomTlsOptions函数进行关闭。 + +##### Service_UBSHcomConnectOptions + +1. 参数说明 + +| 参数名 | 类型 | 默认值 | 说明 | +|----|----|----|----| +| clientGroupId | uint16_t | \- | client端worker group索引。 | +| serverGroupId | uint16_t | \- | server端worker group索引。 | +| linkCount | uint8_t | \- | channel内单个路径的ep数量。多路径场景下实际ep数量为linkCount \* 路径数。 | +| mode | ubs_hcom_service_polling_mode | \- | channel内ep poll模式。 | +| cbType | ubs_hcom_channel_cb_type | \- | cb方式,每次传入或全局同一个cb。 | +| payLoad | char\[512\] | \- | 用户可携带的自定义信息。 | + +##### ubs_hcom_channel_request + +1. 参数说明 + +| 参数名 | 类型 | 默认值 | 说明 | +|---------|----------|--------|----------------------| +| address | void \* | \- | 消息内存首地址。 | +| size | uint32_t | \- | 消息大小。 | +| opcode | uint16_t | \- | 用户自定义的opcode。 | + +##### ubs_hcom_channel_response + +1. 参数说明 + +| 参数名 | 类型 | 默认值 | 说明 | +|-----------|----------|--------|-----------------------------------------------| +| address | void \* | \- | 消息内存首地址。 | +| size | uint32_t | \- | 消息大小。 | +| errorCode | uint16_t | \- | 用户自定义的errorCode,对端回复时用户可填写。 | + +##### Channel_UBSHcomTwoSideThreshold + +1. 参数说明 + +[TABLE] + +##### ubs_hcom_oneside_key + +1. 参数说明 + +| 参数名 | 类型 | 默认值 | 说明 | +|----|----|----|----| +| keys | uint64_t\[4\] | \- | 已注册内存的key。多路径场景每个路径有一个key,单路径场景只使用key\[0\]。 | + +#### 传输层结构体 + +##### ubs_hcom_send_request + +1. 参数说明 + +| 参数名 | 类型 | 默认值 | 说明 | +|-----------|------------|--------|----------------------------| +| data | uintptr_t | 0 | 准备发送给对方的数据地址。 | +| size | uint32_t | 0 | 数据大小。 | +| upCtxSize | uint16_t | 0 | 用户上下文大小。 | +| upCtxData | char\[16\] | \- | 用户上下文。 | + +##### ubs_hcom_opinfo + +1. 参数说明 + +| 参数名参数说明 | 数据类型 | 描述 | +|----|----|----| +| seqNo | uint32_t | 序列号。范围是\[0, 1023\]。 | +| timeout | uint16_t | 超时时间,单位为秒。0为立刻超时,负数为永不超时。范围\[-1, 1200\]。 | +| errorCode | int16_t | 错误码。 | +| flags | uint8_t | 标志位。 | + +##### ubs_hcom_device_info + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|--------|----------|--------------------------------------| +| maxSge | int | RDMA设备信息,最大的SGL的iov count。 | + +##### ubs_hcom_readwrite_request + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|-----------|------------|--------------------| +| lMRA | uintptr_t | 本端MR的地址。 | +| rMRA | uintptr_t | 远端MR的地址。 | +| lKey | uint64_t | 本端MR的密钥。 | +| rKey | uint64_t | 远端MR的密钥。 | +| size | uint32_t | 数据大小。 | +| upCtxSize | uint16_t | 用户上下文的大小。 | +| upCtxData | char\[16\] | 用户上下文。 | + +##### ubs_hcom_readwrite_sge + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|----------|-----------|----------------| +| lAddress | uintptr_t | 本端MR的地址。 | +| rAddress | uintptr_t | 远端MR的地址。 | +| lKey | uint64_t | 本端MR的密钥。 | +| rKey | uint64_t | 远端MR的密钥。 | +| size | uint32_t | 数据大小。 | + +##### ubs_hcom_readwrite_request_sgl + +1. 参数说明 + +| 配置项 | 类型 | 默认值 | 说明 | +|----|----|----|----| +| \*iov | [Net_ReadWriteSge](#ZH-CN_TOPIC_0000002465376958) | \- | 消息数组。 | +| iovCount | uint16_t | \- | 小于max count(NET_SGE_MAX_IOV)。 | +| upCtxSize | uint16_t | \- | 上下文大小。 | +| upCtxData\[16\] | char | \- | 上下文数据。 | + +##### ubs_hcom_memory_region_info + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|----------|-----------|------------| +| lAddress | uintptr_t | MR的地址。 | +| lKey | uint64_t | MR的key。 | +| size | uint32_t | MR的大小。 | + +##### ubs_hcom_request_context + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|----|----|----| +| type | [Net_RequestType](#ZH-CN_TOPIC_0000002498615389) | 请求的操作类型\[0, 8\]。 | +| opCode | uint16_t | 操作码,取值范围\[0, 1023\]。 | +| flags | uint16_t | header中的标志位。 | +| timeout | int16_t | 超时时间。 | +| errorCode | int16_t | 错误码。 | +| result | int | 结果值0代表成功。 | +| msgData | void \* | 数据指针。用于接收操作。 | +| msgSize | uint32_t | 数据大小。用于接收操作。 | +| seqNo | uint32_t | 序列号。用于post send raw。 | +| ep | ubs_hcom_endpoint | 建链创建好的EP对象。 | +| originalSend | [Net_SendRequest](#ZH-CN_TOPIC_0000002465536686) | 用于C_OP_REQUEST_POSTED复制的结构体信息。 | +| originalReq | [Net_ReadWriteRequest](#ZH-CN_TOPIC_0000002465536694) | 用于C_OP_READWRITE_DONE复制的结构体信息。 | +| originalSglReq | [Net_ReadWriteSglRequest](#ZH-CN_TOPIC_0000002498495605) | 用于C_OP_READWRITE_DONE复制的结构体信息。 | + +##### ubs_hcom_response_context + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|---------|----------|----------------| +| opCode | uint16_t | 操作编号。 | +| seqNo | uint32_t | 序列号。 | +| msgData | void \* | 接收到的消息。 | +| msgSize | uint32_t | 消息长度。 | + +##### ubs_hcom_uds_id_info + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|--------|----------|----------| +| pid | uint32_t | 进程ID。 | +| uid | uint32_t | 用户ID。 | +| gid | uint32_t | 组ID。 | + +##### ubs_hcom_driver_opts + +1. 参数说明 + +[TABLE] + +![](media/image8.png) + +UBS Comm默认开启TLS认证,关闭认证可能存在安全风险,用户可通过enableTls = false进行关闭。 + +##### ubs_hcom_driver_listen_opts + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|----|----|----| +| name | char\[16\] | 监听的IP地址。长度范围是(0, 16\]。 | +| port | uint16_t | 监听的端口号。范围是\[1024, 65535\]。 | +| targetWorkerCount | uint16_t | 可用worker数量,0代表全部,默认可使用worker数量为全部。 | + +##### ubs_hcom_driver_uds_listen_opts + +1. 参数说明 + +| 参数名 | 数据类型 | 描述 | +|----|----|----| +| name | char\[96\] | 监听的UDS name。长度范围(0, 96)。 | +| perm | uint16_t | 0代表不使用文件,其他情况则使用文件,此参数为其权限,最高为0600。 | +| targetWorkerCount | uint16_t | 可用worker数量,0代表全部,默认是全部。 | + +##### ubs_hcom_memory_allocator_options + +1. 参数说明 + +[TABLE] + +## 枚举值参考 + +### C++枚举值 + +#### 服务层枚举值 + +##### UBSHcomChannelBrokenPolicy + +1. 枚举说明 + +断链策略,如[表3-307](#d1e92463)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|------------|------|----------------------------------------------------| +| BROKEN_ALL | 0 | 当一个EP断开则断开channel。 | +| RECONNECT | 1 | 当一个EP断开尝试重连,若失败则断开channel。 | +| KEEP_ALIVE | 2 | 当一个EP断开,保持其他EP正常功能,直至所有EP断开。 | + +##### Operation + +1. 枚举说明 + +此NetServiceContext所包含的操作类别,如[表3-308](#d1e92547)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|---------------------|------|-------------------| +| SER_RECEIVED | 0 | 接收到新消息。 | +| SER_RECEIVED_RAW | 1 | 接收到新raw消息。 | +| SER_SENT | 2 | 消息发送完成。 | +| SER_SENT_RAW | 3 | raw消息发送完成。 | +| SER_ONE_SIDE | 4 | 单边操作完成。 | +| SER_INVALID_OP_TYPE | 255 | 非法操作。 | + +##### UBSHcomClientPollingMode + +1. 枚举说明 + +客户端poll模式 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-------------|------|------------------------------| +| WORKER_POLL | 0 | 使用worker线程poll。 | +| SELF_POLL | 1 | 使用调用通信接口的线程poll。 | +| UNKNOWN | 255 | 未知。 | + +##### UBSHcomChannelCallBackType + +1. 枚举说明 + +Channel的回调函数类型,如[表3-310](#d1e92742)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-------------------|------|--------------------------------------------| +| CHANNEL_FUNC_CB | 0 | 会使用用户传入到异步通信方法中的回调函数。 | +| CHANNEL_GLOBAL_CB | 1 | 会使用注册给NetService的回调函数。 | + +##### UBSHcomFlowCtrlLevel + +1. 枚举说明 + +Channel的流控等待策略,如[表3-311](#d1e92816)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|------------------|------|--------------------| +| HIGH_LEVEL_BLOCK | 0 | 忙循环等待。 | +| LOW_LEVEL_BLOCK | 1 | 睡眠指定时长等待。 | + +##### UBSHcomChannelState + +1. 枚举说明 + +Channel状态,如[表3-312](#d1e92890)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|----------------|------|------------| +| CH_NEW | 0 | 新建状态。 | +| CH_ESTABLISHED | 1 | 就绪状态。 | +| CH_CLOSE | 2 | 关闭状态。 | +| CH_DESTROY | 3 | 销毁状态。 | + +##### UBSHcomOobType + +1. 枚举说明 + +建链类型,如[表3-313](#d1e92984)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|--------|------|---------------| +| TCP | 0 | TCP建链方式。 | +| UDS | 1 | UDS建链方式。 | + +##### UBSHcomSecType + +1. 枚举说明 + +建链安全校验,如[表3-314](#d1e93058)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-----------------------|------|------------| +| NET_SEC_DISABLED | 0 | 不校验。 | +| NET_SEC_VALID_ONE_WAY | 1 | 单边校验。 | +| NET_SEC_VALID_TWO_WAY | 2 | 双边校验。 | + +#### 传输层枚举值 + +##### UBSHcomNetEndPointState + +1. 枚举说明 + +描述EP此时所处的状态,如[表3-315](#d1e93169)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-----------------|------|------------| +| NEP_NEW | 0 | 新建状态。 | +| NEP_ESTABLISHED | 1 | 就绪状态。 | +| NEP_BROKEN | 2 | 断开状态。 | +| NEP_BUFF | 3 | \- | + +##### UBSHcomNetCipherSuite + +1. 枚举说明 + +加密算法,如[表3-316](#d1e93263)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-------------------|------|---------------------| +| AES_GCM_128 | 0 | AES_GCM_128。 | +| AES_GCM_256 | 1 | AES_GCM_256。 | +| AES_CCM_128 | 2 | AES_CCM_128。 | +| CHACHA20_POLY1305 | 3 | CHACHA20_POLY1305。 | + +##### UBSHcomTlsVersion + +1. 枚举说明 + +TLS的版本信息,如[表3-317](#d1e93357)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|---------|------|-----------| +| TLS_1_2 | 771 | 1.2版本。 | +| TLS_1_3 | 772 | 1.3版本。 | + +##### NN_OpType + +1. 枚举说明 + +此UBSHcomNetRequestContext所包含的操作类别,如[表3-318](#d1e93431)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|--------------------|------|-------------------| +| NN_SENT | 0 | 消息发送完成。 | +| NN_SENT_RAW | 1 | raw消息发送完成。 | +| NN_SENT_RAW_SGL | 2 | SGL消息发送完成。 | +| NN_RECEIVED | 3 | 接收到新消息。 | +| NN_RECEIVED_RAW | 4 | 接收到新raw消息。 | +| NN_WRITTEN | 5 | 写操作完成。 | +| NN_READ | 6 | 读操作完成。 | +| NN_SGL_WRITTEN | 7 | SGL写操作完成。 | +| NN_SGL_READ | 8 | SGL读操作完成。 | +| NN_INVALID_OP_TYPE | 255 | 非法操作。 | + +##### UBSHcomNetMemoryAllocatorType + +1. 枚举说明 + +内存分配器类型,如[表3-319](#d1e93585)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-------------------------|------|------------------------| +| DYNAMIC_SIZE | 0 | 动态大小。 | +| DYNAMIC_SIZE_WITH_CACHE | 1 | 动态大小,配有缓存器。 | + +##### UBSHcomNetMemoryAllocatorCacheTierPolicy + +1. 枚举说明 + +内存分配器的缓存器分级策略,如[表3-320](#d1e93659)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|------------|------|-------------------------| +| TIER_TIMES | 0 | 基准值的倍数策略。 | +| TIER_POWER | 1 | 基准值乘以2的幂数策略。 | + +##### UBSHcomPeerCertVerifyType + +1. 枚举说明 + +对端校验类型,如[表3-321](#d1e93733)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-----------------------|------|--------------------------------| +| VERIFY_BY_NONE | 0 | 对端不需要校验。 | +| VERIFY_BY_DEFAULT | 1 | 对端使用UBS Comm内部校验方式。 | +| VERIFY_BY_CUSTOM_FUNC | 2 | 对端使用用户定义的校验方式。 | + +##### UBSHcomNetDriverSecType + +1. 枚举说明 + +UBSHcomNetDriver校验类型,如[表3-322](#d1e93817)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-----------------------|------|--------------------------------| +| NET_SEC_DISABLED | 0 | 不需要校验。 | +| NET_SEC_VALID_ONE_WAY | 1 | 单边校验,仅服务端校验客户端。 | +| NET_SEC_VALID_TWO_WAY | 2 | 双边校验。 | + +##### NetDriverOobType + +1. 枚举说明 + +OOB建链时协议,如[表3-323](#d1e93902)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-------------|------|----------------------------------| +| NET_OOB_TCP | 0 | TCP协议。 | +| NET_OOB_UDS | 1 | UDS协议。 | +| NET_OOB_UB | 2 | UBC自举建链,仅支持UBC协议配置。 | + +##### UBSHcomNetDriverWorkingMode + +1. 枚举说明 + +worker线程工作模式,如[表3-324](#d1e93986)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-------------------|------|-----------------| +| NET_BUSY_POLLING | 0 | busy polling。 | +| NET_EVENT_POLLING | 1 | event polling。 | + +##### UBSHcomNetDriverLBPolicy + +1. 枚举说明 + +worker线程分配策略,如[表3-325](#d1e94060)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|------------------|------|------------------------------------| +| NET_ROUND_ROBIN | 0 | 轮询策略。 | +| NET_HASH_IP_PORT | 1 | 根据IP和Port进行取哈希值分配策略。 | + +##### UBSHcomNetDriverProtocol + +1. 枚举说明 + +UBSHcomNetDriver通信时协议。 + +| 枚举名 | 数值 | 描述 | +|--------------|------|----------------------| +| RDMA | 0 | RDMA。 | +| TCP | 1 | TCP。 | +| UDS | 2 | UDS。 | +| SHM | 3 | SHM。 | +| RDMA_MLX5_RC | 4 | 需求MLX5网卡的RDMA。 | +| UBC | 7 | UBC。 | +| HSHMEM | 8 | HSHMEM。 | +| UNKNOWN | 255 | 不支持协议。 | + +##### UBSHcomUbcMode + +1. 枚举说明 + +UBC协议专用能力。UB-C 具有多路径能力,发送时使用多条路径可以增大带宽,对于带宽要求不高、时延敏感型业务又提供单路径直连模式。 + +| 枚举名 | 数值 | 描述 | +|---------------|------|--------------------------------| +| Disabled | -1 | 禁用多路径能力(默认)。 | +| LowLatency | 0 | 低时延模式,使用单路径发送。 | +| HighBandwidth | 1 | 高带宽模式,使用多条路径发送。 | + +### C枚举值 + +#### 服务层枚举值 + +##### ubs_hcom_channel_cb_type + +1. 枚举说明 + +Channel的回调函数类型,如[表3-326](#d1e94396)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|---------------------|------|--------------------------------------------| +| C_CHANNEL_FUNC_CB | 0 | 会使用用户传入到异步通信方法中的回调函数。 | +| C_CHANNEL_GLOBAL_CB | 1 | 会使用注册给NetService的回调函数。 | + +##### ubs_hcom_service_context_type + +1. 枚举说明 + +此NetServiceContext所包含的操作类别,如[表3-327](#d1e94470)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|---------------------|------|-------------------| +| SER_RECEIVED | 0 | 接收到新消息。 | +| SER_RECEIVED_RAW | 1 | 接收到新raw消息。 | +| SER_SENT | 2 | 消息发送完成。 | +| SER_SENT_RAW | 3 | raw消息发送完成。 | +| SER_ONE_SIDE | 4 | 单边操作完成。 | +| SERVICE_RNDV | 5 | rndv请求。 | +| SER_INVALID_OP_TYPE | 255 | 非法操作。 | + +##### ubs_hcom_channel_flowctl_level + +1. 枚举说明 + +Channel的流控等待策略,如[表3-328](#d1e94594)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|------------------|------|--------------------| +| HIGH_LEVEL_BLOCK | 0 | 忙循环等待。 | +| LOW_LEVEL_BLOCK | 1 | 睡眠指定时长等待。 | + +##### ubs_hcom_service_worker_mode + +1. 枚举说明 + +worker线程工作模式,如[表3-329](#d1e94668)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-------------------------|------|-----------------| +| C_SERVICE_BUSY_POLLING | 0 | busy polling。 | +| C_SERVICE_EVENT_POLLING | 1 | event polling。 | + +##### ubs_hcom_service_lb_policy + +1. 枚举说明 + +worker线程分配策略,如[表3-330](#d1e94742)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|----------------------|------|------------------------------------| +| SERVICE_ROUND_ROBIN | 0 | 轮询策略。 | +| SERVICE_HASH_IP_PORT | 1 | 根据IP和Port进行取哈希值分配策略。 | + +##### ubs_hcom_service_cipher_suite + +1. 枚举说明 + +加密算法,如[表3-331](#d1e94816)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-----------------------------|------|---------------------| +| C_SERVICE_AES_GCM_128 | 0 | AES_GCM_128。 | +| C_SERVICE_AES_GCM_256 | 1 | AES_GCM_256。 | +| C_SERVICE_AES_CCM_128 | 2 | AES_CCM_128。 | +| C_SERVICE_CHACHA20_POLY1305 | 3 | CHACHA20_POLY1305。 | + +##### ubs_hcom_service_tls_version + +1. 枚举说明 + +TLS的版本信息,如[表3-332](#d1e94910)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-------------------|------|-----------| +| C_SERVICE_TLS_1_2 | 771 | 1.2版本。 | +| C_SERVICE_TLS_1_3 | 772 | 1.3版本。 | + +##### ubs_hcom_service_secure_type + +1. 枚举说明 + +NetService校验类型,如[表3-333](#d1e94984)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|---------------------------------|------|--------------------------------| +| C_SERVICE_NET_SEC_DISABLED | 0 | 不需要校验。 | +| C_SERVICE_NET_SEC_VALID_ONE_WAY | 1 | 单边校验,仅服务端校验客户端。 | +| C_SERVICE_NET_SEC_VALID_TWO_WAY | 2 | 双边校验。 | + +##### ubs_hcom_service_channel_policy + +1. 枚举说明 + +Channel断链策略,如[表3-334](#d1e95069)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|----|----|----| +| C_CHANNEL_BROKEN_ALL | 0 | 当一个EP断开则断开channel。 | +| C_CHANNEL_RECONNECT | 1 | 当一个EP断开尝试重连,若失败则断开channel。 | +| C_CHANNEL_KEEP_ALIVE | 2 | 当一个EP断开,保持其他EP正常功能,直至所有EP断开。 | + +##### ubs_hcom_service_channel_handler_type + +1. 枚举说明 + +链路相关的回调函数类型,如[表3-335](#d1e95153)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|------------------|------|--------------------| +| C_CHANNEL_NEW | 0 | 新建链的回调函数。 | +| C_CHANNEL_BROKEN | 1 | 断链的回调函数。 | + +##### ubs_hcom_service_handler_type + +1. 枚举说明 + +通信相关的回调函数类型,如[表3-336](#d1e95227)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|----------------------------|------|--------------------------| +| C_SERVICE_REQUEST_RECEIVED | 0 | 接收新消息的回调函数。 | +| C_SERVICE_REQUEST_POSTED | 1 | 消息发送完成的回调函数。 | +| C_SERVICE_READWRITE_DONE | 2 | 读写完成的回调函数。 | + +##### ubs_hcom_service_type + +1. 枚举说明 + +NetService通信时协议,如[表3-337](#d1e95311)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|------------------|------|------------------------------| +| C_SERVICE_RDMA | 0 | RDMA。 | +| C_SERVICE_TCP | 1 | TCP。 | +| C_SERVICE_UDS | 2 | UDS。 | +| C_SERVICE_SHM | 3 | SHM。 | +| C_SERVICE_UBC | 6 | UBC。 | +| C_SERVICE_HSHMEM | 7 | HSHMEM(北冥版本暂不支持)。 | + +##### ubs_hcom_service_polling_mode + +#### 传输层枚举值 + +##### ubs_hcom_request_type + +1. 枚举说明 + +此UBSHcomNetRequestContext所包含的操作类别,如[表3-338](#d1e95479)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|----------------|------|-------------------| +| C_SENT | 0 | 消息发送完成。 | +| C_SENT_RAW | 1 | raw消息发送完成。 | +| C_SENT_RAW_SGL | 2 | SGL消息发送完成。 | +| C_RECEIVED | 3 | 接收到新消息。 | +| C_RECEIVED_RAW | 4 | 接收到新raw消息。 | +| C_WRITTEN | 5 | 写操作完成。 | +| C_READ | 6 | 读操作完成。 | +| C_SGL_WRITTEN | 7 | SGL写操作完成。 | +| C_SGL_READ | 8 | SGL读操作完成。 | + +##### ubs_hcom_driver_working_mode + +1. 枚举说明 + +worker线程工作模式,如[表3-339](#d1e95623)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-----------------|------|-----------------| +| C_BUSY_POLLING | 0 | busy polling。 | +| C_EVENT_POLLING | 1 | event polling。 | + +##### ubs_hcom_driver_type + +1. 枚举说明 + +UBSHcomNetDriver通信时协议,如[表3-340](#d1e95697)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|---------------|------|--------| +| C_DRIVER_RDMA | 0 | RDMA。 | +| C_DRIVER_TCP | 1 | TCP。 | +| C_DRIVER_UDS | 2 | UDS。 | +| C_DRIVER_SHM | 3 | SHM。 | +| C_DRIVER_UBC | 6 | UBC。 | + +##### ubs_hcom_driver_oob_type + +1. 枚举说明 + +OOB建链时协议,如[表3-341](#d1e95801)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|---------------|------|-----------| +| C_NET_OOB_TCP | 0 | TCP协议。 | +| C_NET_OOB_UDS | 1 | UDS协议。 | + +##### ubs_hcom_driver_sec_type + +1. 枚举说明 + +UBSHcomNetDriver校验类型,如[表3-342](#d1e95875)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-------------------------|------|--------------------------------| +| C_NET_SEC_DISABLED | 0 | 不需要校验。 | +| C_NET_SEC_VALID_ONE_WAY | 1 | 单边校验,仅服务端校验客户端。 | +| C_NET_SEC_VALID_TWO_WAY | 2 | 双边校验。 | + +##### ubs_hcom_driver_tls_version + +1. 枚举说明 + +TLS的版本信息,如[表3-343](#d1e95959)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-----------|------|-----------| +| C_TLS_1_2 | 771 | 1.2版本。 | +| C_TLS_1_3 | 772 | 1.3版本。 | + +##### ubs_hcom_driver_cipher_suite + +1. 枚举说明 + +加密算法,如[表3-344](#d1e96033)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|---------------------|------|---------------------| +| C_AES_GCM_128 | 0 | AES_GCM_128。 | +| C_AES_GCM_256 | 1 | AES_GCM_256。 | +| C_AES_CCM_128 | 2 | AES_CCM_128。 | +| C_CHACHA20_POLY1305 | 3 | CHACHA20_POLY1305。 | + +##### ubs_hcom_peer_cert_verify_type + +1. 枚举说明 + +对端校验类型,如[表3-345](#d1e96127)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-------------------------|------|--------------------------------| +| C_VERIFY_BY_NONE | 0 | 对端不需要校验。 | +| C_VERIFY_BY_DEFAULT | 1 | 对端使用UBS Comm内部校验方式。 | +| C_VERIFY_BY_CUSTOM_FUNC | 2 | 对端使用用户定义的校验方式。 | + +##### ubs_hcom_memory_allocator_cache_tier_policy + +1. 枚举说明 + +内存分配器的缓存器分级策略,如[表3-346](#d1e96212)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|--------------|------|-------------------------| +| C_TIER_TIMES | 0 | 基准值的倍数策略。 | +| C_TIER_POWER | 1 | 基准值乘以2的幂数策略。 | + +##### ubs_hcom_memory_allocator_type + +1. 枚举说明 + +内存分配器类型,如[表3-347](#d1e96286)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|---------------------------|------|------------------------| +| C_DYNAMIC_SIZE | 0 | 动态大小。 | +| C_DYNAMIC_SIZE_WITH_CACHE | 1 | 动态大小,配有缓存器。 | + +##### ubs_hcom_ep_handler_type + +1. 枚举说明 + +链路相关的回调函数类型,如[表3-348](#d1e96360)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-------------|------|--------------------| +| C_EP_NEW | 0 | 新建链的回调函数。 | +| C_EP_BROKEN | 1 | 断链的回调函数。 | + +##### ubs_hcom_op_handler_type + +1. 枚举说明 + +通信相关的回调函数类型,如[表3-349](#d1e96434)所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|-----------------------|------|--------------------------| +| C_OP_REQUEST_RECEIVED | 0 | 接收新消息的回调函数。 | +| C_OP_REQUEST_POSTED | 1 | 消息发送完成的回调函数。 | +| C_OP_READWRITE_DONE | 2 | 读写完成的回调函数。 | + +##### ubs_hcom_polling_mode + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|------------------------|------|-----------------| +| NET_C_EP_SELF_POLLING | 0 | self ep模式。 | +| NET_C_EP_EVENT_POLLING | 1 | 非self ep模式。 | + +##### ubs_hcom_service_polling_mode + +1. 枚举说明 + +client poll信息模式,如下表所示。 + +1. 枚举说明 + +| 枚举名 | 数值 | 描述 | +|----------------------|------|--------------------| +| C_CLIENT_WORKER_POLL | 0 | 非self poll 模式。 | +| C_CLIENT_SELF_POLL | 1 | self poll模式。 | + +# 环境变量参考 + +1. 环境变量参数 + +[TABLE] + +环境变量设置示例如下: + +export HCOM_FILE_PATH_PREFIX="/home/uds/socket/file" +export HCOM_OPENSSL_PATH="/home/openssl" +export HCOM_TRACE_LEVEL=0 +export HCOM_QP_TRAFFIC_CLASS=106 +export HCOM_SHM_EXCHANGE_FD_QUEUE_SIZE=10 +export HCOM_CONNECTION_RETRY_TIMES=5 +export HCOM_CONNECTION_RETRY_INTERVAL_SEC=2 +export HCOM_SET_LOG_LEVEL=1 + +# 错误码 + +[5.1 服务层错误码](#服务层错误码) + +[5.2 传输层错误码](#传输层错误码) + +[5.3 RDMA协议错误码](#rdma协议错误码) + +## 服务层错误码 + +1. 服务层错误码 + +| 错误码数 | 错误码 | 含义 | +|----------|-----------------------------------|--------------------------------| +| 0 | SER_OK | 成功。 | +| 500 | SER_ERROR | 内部错误。 | +| 501 | SER_INVALID_PARAM | 无效参数。 | +| 502 | SER_NEW_OBJECT_FAILED | 对象生成失败。 | +| 503 | SER_CREATE_TIMEOUT_THREAD_FAILED | 创建超时处理线程失败。 | +| 504 | SER_NEW_MESSAGE_DATA_FAILED | 生成消息失败。 | +| 505 | SER_NOT_ESTABLISHED | NetChannel未建链。 | +| 506 | SER_STORE_SEQ_DUP | 序列号重复。 | +| 507 | SER_STORE_SEQ_NO_FOUND | 序列号不存在。 | +| 508 | SER_RSP_SIZE_TOO_SMALL | 消息大小不一致。 | +| 509 | SER_TIMEOUT | 超时。 | +| 510 | SER_TIMER_NOT_WORK | 超时处理线程开启失败。 | +| 511 | SER_NOT_ENABLE_RNDV | 开启Rndv失败。 | +| 512 | SER_RNDV_FAILED_BY_PEER | 对端使用Rndv失败。 | +| 513 | SER_CHANNEL_ID_DUP | Channel Id重复。 | +| 514 | SER_EP_NOT_BROKEN_ALL | NetChannel中所有EP未发生断链。 | +| 515 | SER_CHANNEL_NOT_EXIST | NetChannel不存在。 | +| 516 | SER_CHANNEL_RECONNECT_OVER_WINDOW | \- | +| 517 | SER_EP_BROKEN_DURING_CONNECTING | NetChannel中所有EP均断链。 | +| 518 | SER_NOT_SUPPORT_SERVER_RECONNECT | 不支持重建链。 | +| 519 | SER_STOP | 服务停止。 | +| 520 | SER_NULL_INSTANCE | 空指针。 | +| 521 | SER_UNSUPPORTED | 不支持的操作。 | +| 522 | SER_INVALID_IP | 非法IP。 | +| 523 | SER_MALLOC_FAILED | 分配内存失败。 | +| 524 | SER_SPLIT_INVALID_MSG | 拆包消息无效。 | + +## 传输层错误码 + +1. 传输层错误码 + +[TABLE] + +![](media/image8.png) + +部分常见错误码详细说明: + +- 114:在RDMA和TCP协议的双边非SGL通信方式时,为了发送消息的持久化和RDMA特性需求,需要把用户发送的消息内容拷贝到UBS Comm内部预申请的内存中。但是在并发很大的情况下,可能将预申请的内存耗尽,在耗尽的时候如果再发送双边非SGL消息时就会产生此错误码。解决方式可以是调大UBSHcomNetDriverOptions中的mrSendReceiveSegCount参数来扩大预申请内存;如果是对端接收压力过大导致本端发送也可以调整对端接收队列的长度prePostReceiveSizePerQP。 + +- 128:在进行建链的时候客户端建链失败时会返回此错误。请检查服务端是否启动并且启动监听线程,然后检查客户端发起建链的IP地址和端口是否和服务端监听的一致,推荐先启动服务端,再使用客户端去建链。 + +## RDMA协议错误码 + +1. RDMA协议错误码 + +| 错误码数 | 错误码 | 含义 | +|----|----|----| +| 0 | RR_OK | 成功。 | +| 200 | RR_PARAM_INVALID | 参数无效。 | +| 201 | RR_MEMORY_ALLOCATE_FAILED | 分配内存失败。 | +| 202 | RR_NEW_OBJECT_FAILED | 创建对象失败。 | +| 203 | RR_OPEN_FILE_FAILED | 打开文件失败。 | +| 204 | RR_READ_FILE_FAILED | 读取文件失败。 | +| 205 | RR_DEVICE_FAILED_OPEN | 得到RDMA设备失败。 | +| 206 | RR_DEVICE_INDEX_OVERFLOW | RDMA设备序号异常。 | +| 207 | RR_DEVICE_OPEN_FAILED | 打开RDMA设备失败。 | +| 208 | RR_DEVICE_FAILED_GET_IF_ADDRESS | 获得网卡地址失败。 | +| 209 | RR_DEVICE_NO_IF_MATCHED | 获得符合IP地址的网卡地址失败。 | +| 210 | RR_DEVICE_NO_IF_TO_GID_MATCHED | 获得符合IP地址的RDMA设备GID。 | +| 211 | RR_DEVICE_INVALID_IP_MASK | IP地址掩码异常。 | +| 212 | RR_MR_REG_FAILED | Memory Region(MR)注册失败。 | +| 213 | RR_CQ_NOT_INITIALIZED | Completion Queue(CQ)初始化失败。 | +| 214 | RR_CQ_POLLING_FAILED | Poll CQ方法异常。 | +| 215 | RR_CQ_POLLING_TIMEOUT | Poll CQ超时。 | +| 216 | RR_CQ_POLLING_ERROR_RESULT | Poll CQ结果错误。 | +| 217 | RR_CQ_POLLING_UNMATCHED_OPCODE | Poll CQ结果opcode不匹配。 | +| 218 | RR_CQ_EVENT_GET_FAILED | Poll事件失败。 | +| 219 | RR_CQ_EVENT_NOTIFY_FAILED | 通知CQ失败。 | +| 220 | RR_CQ_WC_WRONG | poll CQ后的完成事件的状态异常。 | +| 221 | RR_CQ_EVENT_GET_TIMOUT | poll CQ超时。 | +| 222 | RR_QP_CREATE_FAILED | 创建Queue Pair(QP)失败。 | +| 223 | RR_QP_NOT_INITIALIZED | 初始化QP失败。 | +| 224 | RR_QP_CHANGE_STATE_FAILED | 更新QP状态失败。 | +| 225 | RR_QP_POST_RECEIVE_FAILED | 发起接收请求失败。 | +| 226 | RR_QP_POST_SEND_FAILED | 发起发送请求失败。 | +| 227 | RR_QP_POST_READ_FAILED | 发起读取请求失败。 | +| 228 | RR_QP_POST_WRITE_FAILED | 发起写请求失败。 | +| 229 | RR_QP_RECEIVE_CONFIG_ERR | 收发相关参数设定失败。 | +| 230 | RR_QP_POST_SEND_WR_FULL | 发送队列满。 | +| 231 | RR_QP_ONE_SIDE_WR_FULL | 单边请求队列满。 | +| 232 | RR_QP_CTX_FULL | 上下文耗尽。 | +| 233 | RR_QP_CHANGE_ERR | 更新QP状态至停止失败。 | +| 234 | RR_OOB_LISTEN_SOCKET_ERROR | 带外链路监听开启失败。 | +| 235 | RR_OOB_CONN_SEND_ERROR | 带外链路发送失败。 | +| 236 | RR_OOB_CONN_RECEIVE_ERROR | 带外链路接收失败。 | +| 237 | RR_OOB_CONN_CB_NOT_SET | 带外链路连接回调未设置。 | +| 238 | RR_OOB_CLIENT_SOCKET_ERROR | 带外链路客户端发起连接失败。 | +| 239 | RR_OOB_SSL_INIT_ERROR | 加密初始化失败。 | +| 240 | RR_OOB_SSL_WRITE_ERROR | 加密写失败。 | +| 241 | RR_OOB_SSL_READ_ERROR | 加密读失败。 | +| 242 | RR_EP_NOT_INITIALIZED | EP未初始化。 | +| 243 | RR_WORKER_NOT_INITIALIZED | Worker未初始化。 | +| 244 | RR_WORKER_BIND_CPU_FAILED | Worker线程绑定CPU失败。 | +| 245 | RR_WORKER_REQUEST_HANDLER_NOT_SET | Worker的新消息回调函数未注册。 | +| 246 | RR_WORKER_SEND_POSTED_HANDLER_NOT_SET | Worker的消息发送回调函数未注册。 | +| 247 | RR_WORKER_ONE_SIDE_DONE_HANDLER_NOT_SET | Worker的单边消息回调函数未注册。 | +| 248 | RR_WORKER_FAILED_ADD_QP | Worker线程添加QP失败。 | +| 249 | RR_HEARTBEAT_CREATE_EPOLL_FAILED | 心跳检测创建失败。 | +| 250 | RR_HEARTBEAT_SET_SOCKET_OPT_FAILED | 心跳检测设置失败。 | +| 251 | RR_HEARTBEAT_IP_ALREADY_EXISTED | 心跳检测IP地址已存在。 | +| 252 | RR_HEARTBEAT_IP_ADD_FAILED | 心跳检测IP地址添加失败。 | +| 253 | RR_HEARTBEAT_IP_ADD_EPOLL_FAILED | 心跳检测IP地址添加失败。 | +| 254 | RR_HEARTBEAT_IP_REMOVE_EPOLL_FAILED | 心跳检测IP地址移除失败。 | +| 255 | RR_HEARTBEAT_IP_NO_FOUND | 心跳检测IP地址未匹配。 | + +![](media/image8.png) + +部分常见错误码详细说明: + +230:RDMA的双边请求发起时,有限制长度的发送队列来限制并发,如果并发过大时,可能耗尽队列导致出现此错误。解决方式可以通过调大UBSHcomNetDriverOptions中的prePostReceiveSizePerQP和qpSendQueueSize来扩大队列,这个队列的值是取上述两个参数的较小值。 diff --git a/doc/UBS-COMM-Architecture-Design-Specification.md b/doc/UBS-COMM-Architecture-Design-Specification.md new file mode 100644 index 0000000000000000000000000000000000000000..3664ac69c8cd6c08018fc394ec127fc95b1aa982 --- /dev/null +++ b/doc/UBS-COMM-Architecture-Design-Specification.md @@ -0,0 +1,778 @@ +# Summary + +## 目的 + +本文介绍了UBS COMM 通信子系统的整体架构和设计原则,用于设计人员和开发人员理解系统的架构和设计原则,指导设计人员进行系统和特性设计,指导开发人员进行开发工作。 + +## 范围 + +本文作为MatrixServer通信子系统架构设计,整体包括2个通信组件,HCOM和URMA。主要包括的关键领域特性为:全面支持灵衢2.0网络通信,北向提供Socket生态和纯UB生态两种通信方式,使能应用平滑迁移灵衢架构。本文的架构设计以总-分方式,由通信系统整体设计到各组件设计。 + +## 利益相关人 +| 利益相关人| 关注点与需求| +| ------------ | ------------ | +|通信子系统架构师 |1. 负责定义LingQu BeiMing-TD 2.0通信子系统架构
2. 关注包括安全可信在内的架构DFx属性
3. 关注架构演进与架构中长期竞争力 | +|通信子系统设计师 |1. 负责确定特性级架构与技术方案
2. 关注特性模块设计是否遵循整体架构设计原则 | +|通信子系统软件开发人员 |开发实现系统架构及模块,反馈实现中涉及的问题 | +|产品架构师与设计人员 |1. 关注技术项目合入产品版本后对产品现有架构的影响
2. 关注技术项目架构创新能否提升产品的架构竞争力 | +|解决方案集成人员(含ISV/开发者) |1. 清晰的开发者界面,明确的开发接口、架构与设计原则与约束、文档与指导手册,方便进行解决方案集成
2. 完善的开发工作链 | + +## 对已有架构的借鉴和反思 + +### HCOM vs UCX +在通算数据库场景、虚拟化场景、大数据场景、以及项目内部场景等多个场景都提出了通过一个通讯组件屏蔽下层的多种协议(RDMA、TCP、UDS、SHM、URMA等),这个组件向上提供一组统一的API,这样可以简化上层软件的开发难度;因为RDMA、TCP、UDS、SHM、URMA存在多个方面的差异,概要如下: +1. 功能: RDMA、URMA既有双边通信的能力,又有单边通信的能力;socket只有双边的能力;SHM既没有双边的能力,也没有单边的能力; TCP、UDS拥有自建链的能力,而RDMA、URMA、SHM没有自建链接的能力,需求借助TCP或其他通道来建立链接; +2. 工作模式: RDMA、URMA为proactive的模式(接收方可以不干预,发送方直接操作接收方内存),Socket/UDS为reactive模式(接收方主动recv); +3. API: RDMA、URMA接口近似,TCP Socket/UDS较为相近,SHM与其它又不一样; +4. DFX: 安全方面, RDMA、URMA、SHM软件没有配套的安全能力,TCP配合openssl的TLS能力达成较好的安全能力; 链接存在的检测,TCP有Keepalive机制,URMA、RDMA、SHM没有这样的能力; + +该组件必须保证: +1. 暴露能力的最大集(单边+双边); +2. 简化底层API的复杂性(比如RDMA、URMA); +3. 接口统一; +4. 同一个API行为完成一致; +5. 足够的DFX能力; + +在业界,开源组件UCX(Unified Communication - X Framework) 有这样的架构,它分为UCT和UCP两层,UCP上层提供统一接口、UCT是对不同底层协议的封装,包括TCP\RDMA\SHM等;在设计HCOM之前,我们使用UCX做基础去构建统一封装,在UCP的基础上补齐了多线程的支持、辅助建链、安全相关的能力。 +![](./images/ucx.png) + +通过不断的使用和实践后,发现本质问题,即UCX为HPC的集合通讯而设计,将其使用在client/server这种场景下产生了不可弥补的gap,包括: +1. UCX的TCP、UDS、SHM的语义实现与RDMA不对等,即行为不一致; +2. UCX的单线程友好,但多线程的性能有gap, 因为UCX的主要使用者是MPI, MPI基本是单线程;而且使用在多线程的程序里,有多种意想不到的race condition发生,难以处理的core dump; +3. 建链方面,UCX需要双向建TCP,而不是client向server connect的单向;双向建立tcp会引入端口过多也不必要的问题; +4. 资源释放,UCX的使用者是MPI,这类程序的进程是同生共死的,不是client/server, 导致资源释放不及时等问题; + +由于上述原因,我们没有选择UCX, 也没有选择在UCX的基础上做修改,而且采用开发HCOM以应对client/server的融合通信组件。这样HCOM也有非常明确的定位与发力方向,UCX定位于集合通讯,而HCOM定位于client/server的通信场景;2个组件独立发展,技术互相借鉴。 + +### UCM vs RDMA CM +UCM即UB Communication Manager, 主要负责UB通信前的建链,类似RDMA的CM。 +UB的RC/RM通信之前必须建立两边的信息交换通道。这个通道可以有三种方式: +1. 利用UCM(公知jetty)建链 +2. 利用TCP Socket建链 +3. 利用UNIC/IPoUB建链 + +第1种方式即所谓的带内建链,可直接跑在纯UB环境中,且Bypass了TCP/IP内核协议栈,性能更优。 + +第2种方式依赖网卡,依赖IP/Socket通道, 适用于有网卡的业务场景。 + +第3方式需要MatrixServer管控面集中分配和管理IP,而且依然需要走完整的TCP/IP内核协议栈,在大规模建链场景下性能较差,同时IP无法由用户管理配置,有违用户对传统IP的认知。 + +综上,MatrixServer在没有TCP Socket的情况下采用UCM(公知jetty)的建链方式,同时通信子系统支持TCP Socket建链方式作为可选。 + + +# Usage Example + +## 用例视图 + +### 上下文模型 + +#### 上下文图 + +MatrixServer通信子系统(HCOM、URMA)北向对接分布式计算应用(数据库、大数据、虚机热迁),东西向对接MatrixServer逻辑资源管理模块(MXE), 南向对接UB硬件传输协议。 + +- HCOM上下文模型 + +![](./images/4.1.1_1.png) + + +- URMA上下文模型 + +![](./images/4.1.1_2.png) + +#### 外部接口描述 + +- HCOM外部接口描述 + +表1 HCOM外部关键接口 + +|接口类别 |接口 |类型 |功能 |备注 | +| ------------ | ------------ | ------------ | ------------ | ------------ | +|北向接口|static UBSHcomService* UBSHcomService::Create(UBSHcomServiceProtocol t, const std::string &name, const UBSHcomServiceOptions &opt = {}); |服务层|创建实例对象 | | +|北向接口|static int32_t UBSHcomService::Destroy(const std::string &name); |服务层|销毁实例对象 | | +|北向接口|int32_t UBSHcomService::Start() |服务层|启动实例 | | +|北向接口|int32_t UBSHcomService::Bind(const std::string &listenerUrl, const UBSHcomServiceNewChannelHandler &handler) |服务层|服务端绑定监听的url和端口号 | | +|北向接口|int32_t UBSHcomService::Connect(const std::string &serverUrl, UBSHcomChannelPtr &ch, const UBSHcomConnectOptions &opt = {}) |服务层|建立连接 | | +|北向接口|void UBSHcomService::Disconnect(const UBSHcomChannelPtr &ch) |服务层|断开链接 | | +|北向接口|int32_t UBSHcomService::RegisterMemoryRegion(uint64_t size, UBSHcomRegMemoryRegion &mr)
int32_t UBSHcomService::RegisterMemoryRegion(uintptr_t address, uint64_t size, UBSHcomRegMemoryRegion &mr) |服务层|1.注册一个内存区域,内存将在UBS Comm内部分配。
2.将用户申请的内存,注册到UBS Comm中。| | +|北向接口|int32_t UBSHcomChannel::Send(const UBSHcomRequest &req, const Callback *done)
int32_t UBSHcomChannel::Send(const UBSHcomRequest &req) |服务层|1. 向对端异步发送一个双边请求消息,并且不等待响应。
2. 向对端同步发送一个双边请求消息,并且不等待响应。| | +|北向接口|int32_t UBSHcomChannel::Get(const UBSHcomOneSideRequest &req, const Callback *done)
int32_t UBSHcomChannel::Get(const UBSHcomOneSideRequest &req) |服务层|1.同步模式下,发送一个读请求给对方。
2.异步模式下,发送一个读请求给对方。 | | +|北向接口|int32_t UBSHcomChannel::Put(const UBSHcomOneSideRequest &req, const Callback *done)
int32_t UBSHcomChannel::Put(const UBSHcomOneSideRequest &req) |服务层|1.同步模式下,发送一个写请求给对方。
异步模式下,发送一个写请求给对方。 | | +|北向接口|static UBSHcomNetDriver *UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol t, const std::string &name, bool startOobSvr) |传输层|创建实例对象 | | +|北向接口|static NResult UBSHcomNetDriver::DestroyInstance(const std::string &name) |传输层|销毁实例对象 | | +|北向接口|NResult UBSHcomNetDriver::Initialize(const UBSHcomNetDriverOptions &option) |传输层|初始化实例 | | +|北向接口|void UBSHcomNetDriver::UnInitialize() |传输层|反初始化实例 | | +|北向接口|NResult UBSHcomNetDriver::Start() |传输层|启动实例 | | +|北向接口|void UBSHcomNetDriver::Stop() |传输层|停止实例 | | +|北向接口|void UBSHcomNetDriver::OobIpAndPort(const std::string &ip, uint16_t port) |传输层|配置监听的IP和Port| | +|北向接口|NResult UBSHcomNetDriver::Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo)
NResult UBSHcomNetDriver::Connect(const std::string &serverUrl, const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo) |传输层|建立连接 | | +|北向接口|NResult UBSHcomNetEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNo) |传输层|发送双边通信请求 | | +|北向接口|NResult UBSHcomNetEndpoint::PostRead(const UBSHcomNetTransRequest &request) |传输层|单边读请求 | | +|北向接口|NResult UBSHcomNetEndpoint::PostWrite(const UBSHcomNetTransRequest &request) |传输层|单边写请求 | | + + +【备注】更详细接口在HCOM API使用手册中体现 + +- HCOM服务层常见错误码说明 + +表2 + +|错误码数 |错误码 |含义 |推荐处理方式 | +| ------------ | ------------ | ------------ | ------------ | +| 0 | SER_OK | 成功。 | 无 | +| 500 | SER_ERROR | 内部错误。 | 排查日志或者联系HCOM问题接口人处理 | +| 501 | SER_INVALID_PARAM | 无效参数。 | 检查参数 | +| 502 | SER_NEW_OBJECT_FAILED | 对象生成失败。 | 一般是new失败,检查资源是否充足 | +| 503 | SER_CREATE_TIMEOUT_THREAD_FAILED | 创建超时处理线程失败。 | 按照通用创建线程失败方式进行排查处理,比如检查是否资源不足,线程数是否达到上限等 | +| 504 | SER_NEW_MESSAGE_DATA_FAILED | 生成消息失败。 | 一般是malloc失败,检查资源是否充足 | +| 505 | SER_NOT_ESTABLISHED | NetChannel未建链。 | 链接未建立,需要先通过connect建链 | +| 506 | SER_STORE_SEQ_DUP | 序列号重复。 | 极端情况偶发错误,尝试重试 | +| 507 | SER_STORE_SEQ_NO_FOUND | 序列号不存在。 | 极端情况偶发错误,尝试重试 | +| 508 | SER_RSP_SIZE_TOO_SMALL | 消息大小不一致。 | 检查传入的数据length是否和实际数据大小一致 | +| 509 | SER_TIMEOUT | 超时。 | 尝试重试或者检查网络状况 | +| 510 | SER_TIMER_NOT_WORK | 超时处理线程开启失败。 | 按照通用创建线程失败方式进行排查处理,比如检查是否资源不足,线程数是否达到上限等 | +| 511 | SER_NOT_ENABLE_RNDV | 开启Rndv失败。 | 未开启Rndv,需要在start时的option中设置开启 | +| 512 | SER_RNDV_FAILED_BY_PEER | 对端使用Rndv失败。 | 需要两端都启用rndv | +| 513 | SER_CHANNEL_ID_DUP | Channel Id重复。 | 尝试重试 | +| 514 | SER_EP_NOT_BROKEN_ALL | NetChannel中所有EP未发生断链。 | 尝试重试 | +| 515 | SER_CHANNEL_NOT_EXIST | NetChannel不存在。 | 使用了不存在的channel | +| 517 | SER_EP_BROKEN_DURING_CONNECTING | NetChannel中所有EP均断链。 | channel中所有的ep都断开,尝试重新建链 | +| 518 | SER_NOT_SUPPORT_SERVER_RECONNECT | 不支持重建链。 | 仅客户端可以重新建链 | +| 519 | SER_STOP | 服务停止。 | 服务已停止 | +| 520 | SER_NULL_INSTANCE | 空指针 | 检查参数是否正确 | +| 521 | SER_UNSUPPORTED | 该功能不支持 | 部分功能仅针对特定场景开放,可联系hcom问题接口人 | +| 522 | SER_INVALID_IP | 无效ip | 检查ip是否正确 | +| 523 | SER_MALLOC_FAILED | malloc失败 | 检查系统资源是否充足 | +| 524 | SER_SPLIT_INVALID_MSG | 通信拆包模式下发生错误 | 查看日志获得更详细的报错,检查包大小是否在范围内,空间是否充足,是否有外部攻击者伪造的包等 | + + + +- HCOM传输层常见错误码说明 + +表3 + +|错误码数 |错误码 |错误码 |推荐处理方式 | +| ------------ | ------------ | ------------ | ------------ | +| 0 | NN_OK | 成功。 | 无 | +| 100 | NN_ERROR | 内部错误。 | 排查日志或者联系HCOM问题接口人处理 | +| 101 | NN_INVALID_IP | 无效IP地址。 | 检查IP地址是否正确 | +| 102 | NN_NEW_OBJECT_FAILED | 创建对象失败。 | 一般为new操作失败,检查系统资源是否充足 | +| 103 | NN_INVALID_PARAM | 参数无效。 | 检查参数是否正确 | +| 104 | NN_TWO_SIDE_MESSAGE_TOO_LARGE | 双边消息size过大。 | 调整双边消息大小,或者调整为单边方式 | +| 105 | NN_INVALID_OPCODE | 无效opCode。 | opCode设置错误 | +| 106 | NN_EP_NOT_ESTABLISHED | EP未建链。 | 链接未建立,需先建立链接 | +| 107 | NN_EP_NOT_INITIALIZED | EP未初始化。 | 链接未建立,需先建立链接 | +| 109 | NN_TIMEOUT | 超时。 | 尝试重试或者检查网络状态 | +| 110 | NN_INVALID_OPERATION | 无效操作。 | 未支持操作,请检查协议类型是否支持相应操作 | +| 111 | NN_MALLOC_FAILED | 获得内存失败。 | 检查系统资源是否充足 | +| 113 | NN_NOT_INITIALIZED | NetDriver未初始化。 | 需要先初始化Driver | + + + +【备注】更详细错误码在HCOM API使用手册中体现 + + +- urma外部接口描述 +详细查看:[URMA对外接口](http://platformdoc.huawei.com/hedex/hdx.do?lib=075356218&v=01%20(2023-04-28)&homepage=resources/hedex-homepage.html&productId=2797) + + +### 关键用例模型 + +#### 关键用例 + +![](./images/4.2.1.png) + +#### 交互场景 + +1. 通算场景使用socket: +- 用户使用socket,通过UB Socket进行通信。 + +2. 通算场景直接使用HCOM API / URMA API: +- 用户直接使用HCOM提供的API进行通信,HCOM底层对接URMA、verbs、socket等。 +- 用户直接使用URMA提供的API进行通信。 + + +## 部署模型 + +### 部署节点及规格定义 + +节点规格: +CPU:1620/1630/1650 +芯片互联:HCCS/UB-C +内存:大于64G + +### 部署模型 + +![部署模型 - HCOM](./images/部署模型-HCOM.png) + +![部署模型 - URMA](./images/部署模型-URMA.png) + +## 运行模型 + +![运行模型-HCOM](./images/运行模型-HCOM.png) + + + + +# Movitvation + +## 架构和关键质量属性目标 + +### 架构目标 + +面向灵衢2.0新一代计算架构,通信子系统作为平台软件,通过在通信协议、通信算法关键优化技术方案充分发挥UB-C组网性能,支撑实现下一代芯片多打一战略目标。 + +对于通信子系统而言,灵衢2.0相比当代计算架构的主要变化为: +- 网络互联全面采用UB协议,组网采用超节点架构,超节点内使用UB-C形成组网(1D FullMesh,2D FullMesh等),超节点间使用TCP/IP或RDMA。 + +此外,考虑生态兼容性和架构归一性,通信子系统的架构演进主要为: +- 分布式通信库需要考虑现有生态的兼容性和易用性,如socket生态。 + +#### UB影响下的架构目标 + +MatrixServer框内通信使用UB-C协议进行通信加速,出框通信使用TCP/IP或者RDMA通信。同时网络形态也分为纯UB生态和UB及TCP/IP共存生态,通信子系统在此次架构设计中需要包含这方面的考虑。 +![](./images/2.1.1.png) + +#### 生态兼容性和易用性下的架构目标 + +灵衢2.0时代,网络发展路径逐步由TCP->RDMA->UB。网络协议和互联技术在不断的更新,但客户的应用场景使用的协议往往存在滞后性。同时,不同的场景对通信库有不同的诉求:有的要求最低的延迟、有的要求高带宽、有的要求易用性、有的要求应用不改动、有的要求无TCP时能工作、有的要求一套代码适配多种底层协议、有的要求在内核态、有的要求在用户态等。由于通信软件栈,向下使能高速的硬件互联,向上为应用提供API接口,同时为满足多个场景多样化的诉求,整个通信软件栈采用多层的设计,因为单一的组件无法满足多样化的需求。 + +通过多层多组件架构也达成,满足以下3个方面的要求: +- 能充分发挥硬件的性能(原生通信库接口); +- 能提供较好的易用性; +- 能提供0修改兼容Legacy应用的能力; + +目前主流的应用还是使用socket(TCP/IP),通信子系统在此次架构设计中亦包含这方面考虑。 + +### 关键架构需求 + +|序号 |SR编号 |需求名称 |IR编码 |IR标题 | +| ------------ | ------------ | ------------ | ------------ | ------------ | +|1 |SR20250429448477 |【业务面】【通信子系统】【功能】HCOM适配UB-C下URMA多路径和单路径选择能力 |IR20240528000345 |【业务面】【通信子系统】【交付通算】【HCOM】HCOM支持灵衢2.0,北向生态统一兼容。UB-C场景下,软件时延<=0.5us,带宽达到URMA带宽的96%。 | +|2 |SR20250429448690 |【业务面】【通信子系统】【功能】HCOM支持灵衢2.0 URMA自举建链 |IR20240528000345 |【业务面】【通信子系统】【交付通算】【HCOM】HCOM支持灵衢2.0,北向生态统一兼容。UB-C场景下,软件时延<=0.5us,带宽达到URMA带宽的96%。 | +|3 |SR20240626245974 |【业务面】【通信子系统】【交付通算】【性能】HCOM在灵衢2.0(UB-C)下,8KB 单并发<5us(**依赖硬件8KB@3us达成**),单并发256K<30us(**依赖硬件256KB@23us达成**),提供>50GB/s数据传输带宽 |IR20240528000345 |【业务面】【通信子系统】【交付通算】【HCOM】HCOM支持灵衢2.0,北向生态统一兼容。UB-C场景下,软件时延<=0.5us,带宽达到URMA带宽的96%。 | +|4 |SR20240619101230 |【业务面】【通信子系统】【交付通算】【性能】HCOM在灵衢2.0网络下,软件栈时延<=0.5us,带宽达到URMA带宽的96% |IR20240528000345 |【业务面】【通信子系统】【交付通算】【HCOM】HCOM支持灵衢2.0,北向生态统一兼容。UB-C场景下,软件时延<=0.5us,带宽达到URMA带宽的96%。 | + + +|JDC RR编号 |JDC标题 | +| ------------ | ------------ | +|2025051360791 |【计算 灵衢 2.0】MXE生成MatrixServer内各节点BondingEID并将其与PrimaryEID、CNA的映射关系下发给URMA | +|2025051360745 |【计算 灵衢 2.0】URMA提供BondingEID、PrimaryEID及CNA映射关系下发接口并本地缓存 | +|2025041844935 |【计算 灵衢 2.0】URMA支持MatrixServer公知jetty建链和通信 | +|2025041844928 |【计算 灵衢2.0】URMA基础通讯,支持MatrixServer内UBC通信 | +|2025052669524 |【计算 灵衢 2.0】URMA提供UB链路流量可观测性工具 | + + +### 假设和约束 + +#### 生命周期约束 + +生命周期与整体BeiMing-LingQu 2.0版本生命周期保持一致。 + + +## 架构原则 + +公司级的通用可信架构设计原则,结合产品上下文等特点,涉及的主要架构原则如下: +- HCOM架构原则 + +|维度 |原则 |原则编号 |来源 |解读 |落地方式 | +| ------------ | ------------ | ------------ | ------------ | ------------ | ------------ | +|可信基础 |服务化/组件化原则 |2.1.1 |ICT可信设计原则V1.2 |根据产品业务诉求,合理的采用服务化、组件化架构,使产品具备灵活、按需组合的能力,以更好地适应为了业务、技术和环境等变化。 |HCOM南向对接URMA,功能实现插件化,与RDMA、TCP等平级,可灵活替换。 | +|可信基础 |分层设计原则 |2.1.2 |ICT可信设计原则V1.2 |系统分为多个层次,每个层次有明确的功能定位,层次之间具有明确的、可信的依赖关系。 |HCOM分为service和transport层,每层提供不同层级的API。 | +|可信基础 |可替换性优先原则 |2.1.4 |ICT可信设计原则V1.2 |优先针对可替换性进行设计,而不是可重用性。随着软件技术的急速发展,在进行架构设计时,产品部件被新技术替代的速度加快,可替代性的重要性在很多情况下远远超过了可重用性。 |HCOM不同通信协议支持插件化,符合可替代性优先原则 | +|可信基础 |最小修改原则 |2.1.5 |ICT可信设计原则V1.2 |业务应用层纵向划分优先原则,通过纵向划分,将大型的域分割为“变更孤岛”。避免业务应用层各部件间的复杂依赖关系。新增特性或问题修改涉及的组件/服务应该内聚,修改范围越少越好。 |基于对象语言类职责功能,合理定义各模块类的成员变量和成员函数,进行功能抽象解耦,满足最小修改原则。 | + + + +- URMA架构原则 +具体见:URMA架构设计原则 + + + +# Detailed Design + +## 关键技术方案设计 + +### UBS COMM支持UB关键技术方案 + +![UBS COMM支持UB关键技术方案](./images/UBS-COMM支持UB关键技术方案.png) + +表1 Socket生态 + +|场景 |性能 |兼容生态 |环境要求 | +| ------------ | ------------ | ------------ | ------------ | +|TCP/IP Socket原生协议 |性能低 |应用无感,生态好 |MatrixServer的每个Host配有一张TCP/IP网卡 | +|UBSocket(Socket转UB)
进程级替换
![](./images/进程级替换.png) |性能高|应用无感,需要启动修改脚本,生态较好| MatrixServer的每个Host配有一张TCP/IP网卡 | +|UBSocket(Socket转UB)
修改一行代码
![修改一行代码](./images/修改一行代码.png) |性能高| 应用修改一行代码,生态一般| MatrixServer的每个Host配有一张TCP/IP网卡 | + + +表2 非Socket生态 + +|方式 |易用性 |性能 |高级特性及DFX |环境要求 | +| ------------ | ------------ | ------------ | ------------ | ------------ | +|对接HCOM |1. 易用性高
(1)控制面:一行代码/一个接口完成建链及协议切换,仅需单向建链,UB场景上层应用仅需感知BondingEID,对jetty等底层概念无感。
(2)数据面:提供类rpc和内存语义两种通信接口,提供同步和异步两种通信方式,上层应用通过注册回调方式处理数据收发结果。
2. 框内框外通信接口统一,框内框外通信协议不一致的时候使用HCOM可做到接口统一 | 服务层:性能是传输层的95%以上
传输层:性能高。性能几乎等同于直接使用URMA接口 | 1. 支持多链接管理,提供KEEP_ALIVE,RECONNECT, BROKEN_ALL三种链接管理方式。
2. 支持根据数据包大小自动择优选择通信方式,比如极小包采用inline,大包采用RNDV(双边转单边)。
3. 支持身份认证和消息加密传输功能。
4. 支持链接主动心跳检测,链接故障主动发现上报。
5. 提供性能Trace工具,辅助开发者快速定位故障位置及辅助性能分析。 | 不依赖TCP/IP网卡 | +|对接URMA | 1. 易用性较低
(1)控制面:完整建链需要100行左右代码,需要双向建链,上层应用需要理解jetty 概念,jetty通信方式,公知jetty使用方法等底层概念。
(2) 数据面:仅提供post_send和post_recv等基础接口,上层应用需要针对实际场景进行适配,同时需要设计通信线程模型,内存模型等。但是使用灵活性相对较高。
2. 框内框外通信如果使用不同通信协议的话,需要对接两套编程接口。 | 性能高 | 无 |不依赖TCP/IP网卡 | + + +### 建链关键技术方案 + +UB通信,应用层建链需要先交换本端和对端的jetty信息,因此就有两种建链通道可选:TCP/IP和公知jetty。HCOM建链统一接口可任意选择这两种方式中的一种(方式见编程样例)。URMA API只提供创建公知jetty等基本接口,需要自行实现公知jetty建链或者通过socket接口实现TCP/IP建链。 + +【建链规格1】链接数限制通信子系统不做约束,同硬件规格(当前硬件出口规格是单节点jetty数上限为64K)。 + +【建链规格2】管控面4000条TP=》247MB,单个jetty JFS_WQEBB_SIZE(64B) * 深度。 + +#### 通过TCP/IP建链 + +![](./images/5.2.1.png) + +#### 通过公知jetty建链 + +![](./images/5.2.2.png) + +##### HCOM公知jetty建链和通信编程样例 + +表1 + +|端类型 | 代码样例 | +| ------------ | ------------ | +|客户端 | void HcomClientDemo()
{
UBSHcomServiceOptions options{};
options.maxSendRecvDataSize = 1024;
UBSHcomService *client = UBSHcomService::Create(UBSHcomServiceProtocol::UBC, "client", options);
client->RegisterRecvHandler(ReceivedRequest);
client->RegisterChannelBrokenHandler([](const UBSHcomChannelPtr &channel) {}, UBSHcomChannelBrokenPolicy::BROKEN_ALL);
client->RegisterSendHandler([](const UBSHcomServiceContext &ctx) {return 0;});
client->RegisterOneSideHandler([](const UBSHcomServiceContext &ctx) {return 0;});
service->Start();
UBSHcomChannelPtr channel;
UBSHcomConnectOptions connOpt{};
client->Connect("ubc://" + BondingEID + ":" + std::to_string(JettyID), channel, connOpt);
UBSHcomRequest req(reinterpret_cast(dataAddr), dataSize, 0);
// 同步发送双边消息
channel->Send(req, nullptr);
client->Disconnect(channel);
UBSHcomService::Destroy("client");
} | +|服务端 |// 接收到新建链请求回调函数
int ReceivedRequest(UBSHcomServiceContext &context)
{
// 执行业务逻辑
return 0;
}

int NewChannel(const std::string &ipPort, const UBSHcomChannelPtr &ch, const std::string &payload)
{
// 执行业务逻辑
return 0;
}
void HcomServerDemo()
{
UBSHcomServiceOptions options;
options.maxSendRecvDataSize = 1024;
options.workerGroupMode = ock::hcom::NET_EVENT_POLLING;
UBSHcomService *server = UBSHcomService::Create(UBSHcomServiceProtocol::UBC, "server", options);
uint64_t bondingEid; // 应用本端的bondingEid
uint32_t jettyId; // 应用自己设置公知jettyId
server->RegisterRecvHandler(ReceivedRequest);
server->RegisterChannelBrokenHandler([](const UBSHcomChannelPtr &channel) {}, UBSHcomChannelBrokenPolicy::BROKEN_ALL);
server->RegisterSendHandler([](const UBSHcomServiceContext &ctx) { return 0; });
server->RegisterOneSideHandler([](const UBSHcomServiceContext &ctx) { return 0; });
server->Bind("ubc://" + bondingEid + ":" + std::to_string(jettyId), NewChannel);
server->Start();
// 业务逻辑执行完成后清理资源
UBSHcomService::Destroy("server");
}| + + +##### URMA公知jetty建链和通信编程样例 + +表1 + +|端类型 |代码样例 | +| ------------ | ------------ | +|客户端 |void UrmaClientDemo()
{
// URMA资源初始化
urma_init_attr_t attr{};
urma_status_t status = urma_init(&attr);
urma_eid_t eid;
urma_device_t *device = urma_get_device_by_eid(eid, URMA_TRANSPORT_UB);
urma_context_t *context = urma_create_context(device, 0);

// 创建jfc队列
urma_jfc_cfg_t jfcCfg{};
urma_jfc_t *jfc = urma_create_jfc(context, &jfcCfg);

// 创建数据面jetty
urma_jfs_cfg_t jfsCfg{};
urma_jfr_cfg_t jfrCfg{};
urma_jetty_cfg_t jettyCfg{};
jettyCfg.jfs_cfg = jfsCfg;
jettyCfg.jfr_cfg = &jfrCfg;
urma_jfr_t *jfr = urma_create_jfr(context, &jfrCfg);
urma_jetty_t *jetty = urma_create_jetty(context, &jettyCfg);

// 创建公知jetty
urma_jetty_cfg_t publicJettyCfg{};
publicJettyCfg.id = 100; // 自定义公知jetty号
urma_jetty_t *publicJetty = urma_create_jetty(context, &publicJettyCfg);

// 公知jetty建链
urma_rjetty_t remotePublicJetty{};
remotePublicJetty.jetty_id.eid = 1; // 指定对端公知jetty所在的eid
remotePublicJetty.jetty_id.id = 100; // 指定对端公知jetty的jetty_id
urma_token_t tokenValue{};
urma_target_jetty_t *targetPublicJetty = urma_import_jetty(context, &remotePublicJetty, &tokenValue);

// 通过公知jetty发送本端数据面通信jetty信息
urma_jfs_wr_t wr{}; // wr中填充本端jetty信息
wr.tjetty = targetPublicJetty;
urma_jfs_wr_t *badWr;
status = urma_post_jetty_send_wr(publicJetty, &wr, &badWr);

// 创建线程poll jfc队列,或者通过中断的方式poll jfc队列,本部分较复杂,暂不做代码样例实现
// 下面流程假定已经从jfc队列中poll到了对端返回来的对端数据面jetty信息,假定为remoteJetty
// 与对端数据面jetty建链
urma_rjetty_t remoteJetty{};
urma_target_jetty_t *targetJetty = urma_import_jetty(context, &remoteJetty, &tokenValue);

// 通过数据面jetty进行通信
urma_jfs_wr_t wr2{}; // wr2中填充需要发送的数据等信息
wr2.tjetty = targetJetty;
urma_jfs_wr_t *badWr2;
status = urma_post_jetty_send_wr(publicJetty, &wr, &badWr2);

// 业务发送处理完成后清理回收资源
status = urma_unimport_jetty(targetJetty);
status = urma_unbind_jetty(jetty);
status = urma_unimport_jetty(targetPublicJetty);
status = urma_unbind_jetty(publicJetty);
} | +|服务端 |void UrmaServerDemo()
{
// URMA资源初始化
urma_init_attr_t attr{};
urma_status_t status = urma_init(&attr);
urma_eid_t eid;
urma_device_t *device = urma_get_device_by_eid(eid, URMA_TRANSPORT_UB);
urma_context_t *context = urma_create_context(device, 0);

// 创建jfc队列
urma_jfc_cfg_t jfcCfg{};
urma_jfc_t *jfc = urma_create_jfc(context, &jfcCfg);

// 创建数据面jetty
urma_jfs_cfg_t jfsCfg{};
urma_jfr_cfg_t jfrCfg{};
urma_jetty_cfg_t jettyCfg{};
jettyCfg.jfs_cfg = jfsCfg;
jettyCfg.jfr_cfg = &jfrCfg;
urma_jfr_t *jfr = urma_create_jfr(context, &jfrCfg);
urma_jetty_t *jetty = urma_create_jetty(context, &jettyCfg);

// 创建公知jetty
urma_jetty_cfg_t publicJettyCfg{};
publicJettyCfg.id = 100; // 自定义公知jetty号
urma_jetty_t *publicJetty = urma_create_jetty(context, &publicJettyCfg);

// 创建线程poll jfs队列,或者通过中断方式poll jfc队列,本部分较复杂,暂不做样例实现
// 下面流程假定已经从jfc队列中poll到了对端通过公知jetty通道发过来的对端数据面jetty信息,假定为remoteJetty
// 与对端数据面jetty建链
urma_rjetty_t remoteJetty{};
urma_token_t tokenValue{};
urma_target_jetty_t *targetJetty = urma_import_jetty(context, &remoteJetty, &tokenValue);

// 通过数据面jetty通道,将本端数据面jetty信息发送给对端
urma_jfs_wr_t wr{}; // wr中填充需要发送的数据等信息
wr.tjetty = targetJetty;
urma_jfs_wr_t *badWr;
status = urma_post_jetty_send_wr(publicJetty, &wr, &badWr);

// 业务发送处理完成后清理回收资源
status = urma_unimport_jetty(targetJetty);
status = urma_unbind_jetty(jetty);
status = urma_unbind_jetty(publicJetty);
} | + + +### URMA支持多路径通信关键技术方案 + +![1D_FM](./images/1D_FM.png "1D_FM")![2D_FM](./images/2D_FM.png "2D_FM") + + +URMA多路径会使能两节点间直连路径和全部的跨跳路径(最多一跳)。 + +如上图,8节点1D FM组网两个Host之间存在14条路径,16节点2D FullMesh组网同轴两个Host之间存在6条路径。跨跳路径的时延会劣于直连路径,因此,针对不同应用的诉求,URMA提供两种多路径模式:低时延模式和高带宽模式。 + +低时延模式:URMA发送数据包只使用一条路径,优先使用直连路径,达到极致时延目的。(如:大数据、数据库) +高带宽模式:URMA发送数据包会使用全部6条路径,达到极致带宽目的。(如:虚机热迁) + + +- HCOM多路径接口和数据结构定义如下:(通过服务层的SetUbcMode接口设置多路径模式) + +``` +/** + * @brief 设置 UB-C 多路径模式 + * + * @param ubcMode UB-C 多路径模式 + */ +virtual void SetUbcMode(UBSHcomUbcMode ubcMode) = 0; + +enum class UBSHcomUbcMode : int8_t { + Disabled = -1, ///< 禁用多路径能力(默认) + LowLatency = 0, ///< 低时延模式,使用单路径发送 + HighBandwidth = 1, ///< 高带宽模式,使用多条路径发送 +}; +``` + +- URMA多路径接口和数据结构定义如下:(创建jetty时在jetty_cfg中设置多路径模式) + +``` +urma_jetty_t *urma_create_jetty(urma_context_t *ctx, urma_jetty_cfg_t *jetty_cfg); +typedef struct urma_jetty_cfg { + … + int mode; // 0:低时延模式(只使用一条路径);1:高带宽模式(使用全部路径) +} urma_jetty_cfg_t; + +``` + +### TP/CTP选型方案 + +表1 + +|类型 |限制 | +| ------------ | ------------ | +|TP |1. 无拥塞控制
2. 双边数据大小限制最大64KB,数据大小超出4KB的话进程退出需等待28s(TP必须等待28s确保没有新的REQ到达才能销毁)=>RC模式或者不共享TP的RM模式可以通过先销毁TP在销毁TA的方式规避该问题
3. 有TA层的重传机制 | +|CTP(compacted Transport) |1. 拥塞控制只有TA粒度的,一条路径发送拥塞,会导致整体降速
2. 双边数据大小限制最大4KB
3. 仅仅有链路层重传 | + + +- HCOM对外呈现: +HCOM只支持RC模式+TP,该使用方式基本对齐RDMA方式。满足数据库/大数据等使用HCOM业务对于时延迟和性能的要求。 +- URMA对外呈现: +既支持TP又支持CTP,import时设置。 + + +## 逻辑架构 + +### 逻辑模型 + +1. EID和CNA模型 + +![](./images/EID模型.png)![](./images/6.1.1_2.png) + + +EID(Entity ID):是UB Entity在UB Domain内的唯一标识,EID用于标识参与通信的对象,可唯一的标识UB Domain内的主机和设备。在需要访问某个UB Entity时,需要先知道目标EID,这个属于先验知识。EID是一个UB Domain内全局唯一的128bit值。 + +CNA(Clan Network Address),是UB Clan网络层地址。 + +EID和CNA根据其所属主体不同,有分为Bonding EID,Primary EID,Primary CNA及Port CNA,相互关系和说明如下表: + +表1 + +| |主体 |说明 |归属 | +| ------------ | ------------ | ------------ | ------------ | +| Bonding EID | 节点 | 软件层面对上层应用屏蔽Primary EID,对Primary EID做Bonding,Bonding后的EID即为Bonding EID | 软件 | +| Primary EID | IO Die | IO Die的Entity ID | 硬件 | +| Primary CNA | IO Die | IO Die的网络地址 | 硬件 | +| Port CNA | port | IO Die上port的网络地址 | 硬件 | + + +2. 系统架构模型 + +![系统架构模型](./images/6.1.2.png) + +表2 + +|模块 |职责 |形态 | +| ------------ | ------------ | ------------ | +| 分布式应用(大数据、数据库、虚机热迁等) | 1. 获取BondingEID(URMA把Bonding设置到URMA设备上,可以通过命令行查到)
2. 调用HCOM/URMA建链通信接口 | 用户态进程 | +| HCOM | 1. 北向提供统一建链、通信接口,屏蔽底层网络通信协议差异
2. 灵衢2.0网络中,南向对接URMA,使能UB网络协议 | lib库 | +| URMA | 1. 提供UB建链、通信接口
2. 通过UDMA进行数据收发
3. UVS本地缓存Bonding EID、Primary EID及CNA的映射关系 | 1. urma用户态so
2. urma内核态ko
3. uvs用户态so
4. uvs内核态ko | +| UDMA | 1. 海思硬件驱动,提供网络层数据收发能力 | udma.ko | +| MXE | 1. 逻辑资源管理,负责Bonding EID生成
2. MXE通过dlopen uvs.so的方式调用推送接口将Bonding EID、Primary EID和CNA映射关系下发给UVS,并且更新时进行全量推送 | 用户态进程 | + + +3. 通信链接模型 + +![](./images/通信链接模型.png) + +表3 + +|元素 |说明 |与下一层关系说明 | +| ------------ | ------------ | ------------ | +| Process | 应用进程 | 1对n,一个Process中可以创建n个Hcom_instance(Process每调一次Hcom的HcomService::Create接口创建一个Hcom_instance) | +| Hcom_instance | Hcom实例 | 1对n,一个Hcom实例中可以创建n个Hcom_channel(通过instance指针每调一次connect接口创建一个Hcom_channel) | +| Hcom_channel | Hcom封装的逻辑链接 | 1对1~16,一个Hcom_channel中可以创建1~16个Hcom_ep(通过配置选项中ep数量) | +| Hcom_ep | Hcom通信的endpoint,每个ep是一条URMA链接的一个通信结点 | 1对1,一个Hcom_ep对应一个jetty | +| jetty | urma通信实体 | 通过import_jetty来创建urma链接 | + + +#### 架构模式 + +整体通信子系统架构,采用分层架构的架构模式进行设计。 + +通信服务层:北向提供对接不同场景业务的API,其中包含(部分)业界标准或事实标准的接口 + +通信传输层:南向对接多种硬件、多种协议(含UB)。 + +![通信子系统-L0](./images/通信子系统-L0.png) + + +#### 1层-n层逻辑模型 + +##### HCOM逻辑模型 + +![HCOM](./images/HCOM.png) + +#### 接口设计 + +##### HCOM接口设计 + +HCOM接口设计: + +|接口名称 |接口类型 |职责 |备注 | +| ------------ | ------------ | ------------ | ------------ | +| static UBSHcomService* UBSHcomService::Create(UBSHcomServiceProtocol t, const std::string &name, const UBSHcomServiceOptions &opt = {}) | 服务层 | 创建服务对象 | / | +| static int32_t UBSHcomService::Destroy(const std::string &name) | 服务层 | 销毁服务对象 | / | +| int32_t UBSHcomService::Start() | 服务层 | 启动实例 | / | +| int32_t UBSHcomService::Connect(const std::string &serverUrl, UBSHcomChannelPtr &ch, const UBSHcomConnectOptions &opt = {}) | 服务层 | 建立连接 | / | +| int32_t UBSHcomService::Bind(const std::string &listenerUrl, const UBSHcomServiceNewChannelHandler &handler) | 服务层 | 服务端绑定监听的url和端口号 | / | +| int32_t UBSHcomChannel::Send(const UBSHcomRequest &req, const Callback *done)
int32_t UBSHcomChannel::Send(const UBSHcomRequest &req) | 服务层 |1. 向对端异步发送一个双边请求消息,并且不等待响应。
2. 向对端同步发送一个双边请求消息,并且不等待响应。| / | +| int32_t UBSHcomChannel::Get(const UBSHcomOneSideRequest &req, const Callback *done)
int32_t UBSHcomChannel::Get(const UBSHcomOneSideRequest &req) | 服务层 | 1.同步模式下,发送一个读请求给对方。
异步模式下,发送一个读请求给对方。 | / | +| int32_t UBSHcomChannel::Put(const UBSHcomOneSideRequest &req, const Callback *done)
int32_t UBSHcomChannel::Put(const UBSHcomOneSideRequest &req) | 服务层 |1.同步模式下,发送一个写请求给对方。
2.异步模式下,发送一个写请求给对方。 | / | +| static UBSHcomNetDriver *UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol t, const std::string &name, bool startOobSvr) | 传输层 | 创建实例对象 | / | +| static NResult UBSHcomNetDriver::DestroyInstance(const std::string &name) | 传输层 | 销毁实例对象 | / | +| NResult UBSHcomNetDriver::Initialize(const UBSHcomNetDriverOptions &option) | 传输层 | 初始化实例 | / | +| void UBSHcomNetDriver::UnInitialize() | 传输层 | 反初始化实例 | / | +| NResult UBSHcomNetDriver::Start() | 传输层 | 启动实例 | / | +| void UBSHcomNetDriver::Bind(const std::string &url) | 传输层 | 停止实例 | / | +| NResult UBSHcomNetDriver::Connect( const std::string &serverUrl uint16_t oobPort, const std::string &payload, UBSHcomNetEndpointPtr &ep) | 传输层 | 建立连接 | / | +| NResult UBSHcomNetEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request) | 传输层 | 发送双边通信请求 | / | +| NResult UBSHcomNetEndpoint::PostRead(const UBSHcomNetTransRequest &request) | 传输层 | 单边读请求 | / | +| NResult UBSHcomNetEndpoint::PostWrite(const UBSHcomNetTransRequest &request) | 传输层 | 单边写请求 | / | + + +### 行为模型 + +#### 分布式应用使用HCOM行为模型 + +![](./images/分布式应用使用HCOM行为模型.png) + + +#### 分布式应用使用URMA行为模型 + +![](./images/6.2.2.png) + + +### 数据模型 + +#### 架构模式 + +#### 关键数据设计 + +##### HCOM关键数据设计 + +- UBSHcomNetDriverOptions + +``` +/** + * @brief UBSHcomNetDriver options +*/ +struct UBSHcomNetDriverOptions { + char netDeviceIpMask[NN_NO256] {};// ip masks for devices + bool enableTls = false;// enable ssl + UBSHcomNetCipherSuite cipherSuite= AES_GCM_128;// if tls enabled can set cipher suite, client and server should same +/* worker setting */ + bool dontStartWorkers = false;// start worker or not + UBSHcomNetDriverWorkingMode mode= NET_BUSY_POLLING;// worker polling mode, could busy polling or event polling + char workerGroups[NN_NO64] {};// worker groups, for example 1,3,3 + char workerGroupsCpuSet[NN_NO128] {};// worker groups cpu set, for example 1-1,2-5,na +// worker thread priority [-20,20], 20 is the lowest, -20 is the highest, 0 (default) means do not set priority + int workerThreadPriority = 0; +/* connection attribute */ + NetDriverOobType oobType= NET_OOB_TCP;// oob type, tcp or UDS, UDS cannot accept remote connection + UBSHcomNetDriverLBPolicy lbPolicy= NET_ROUND_ROBIN;// select worker load balance policy, default round-robin + uint16_t magic = NN_NO256;// magic number for c/s connect validation + uint8_t version = 0;// program version used by connect validation +/* heart beat attribute */ + uint16_t heartBeatIdleTime = NN_NO60;// heart beat idle time, in seconds + uint16_t heartBeatProbeTimes = NN_NO7;// heart beat probe times + uint16_t heartBeatProbeInterval = NN_NO2;// heart beat probe interval, in seconds +/* options for only tcp protocol */ +// timeout during io (s), it should be [-1, 1024], -1 means do not set, 0 means never timeout during io + int16_t tcpUserTimeout = -1; + bool tcpEnableNoDelay = true;// tcp TCP_NODELAY option, true in default + bool tcpSendZCopy = false;// tcp whether copy request to inner memory, false in default +/* The buffer sizes will be adjusted automatically when these two variables are 0, and the performance would be + * better */ + uint16_t tcpSendBufSize = 0;// tcp connection send buffer size in kernel, by KB + uint16_t tcpReceiveBufSize = 0;// tcp connection send receive buf size in kernel, by KB +/* options for rdma protocol only */ + uint32_t mrSendReceiveSegCount = NN_NO8192;// memory region segment count for two side operation + uint32_t mrSendReceiveSegSize = NN_NO1024;// data size of memory region segment +/* transmit of 256b data performs better when dmSegSize is 290 */ + uint32_t dmSegSize = NN_NO290;// data size of device memory segment + uint32_t dmSegCount = NN_NO400;// segment count of device memory segment + uint16_t completionQueueDepth = NN_NO2048;// completion queue size of rdma + uint16_t maxPostSendCountPerQP = NN_NO64;// max number request could issue + uint16_t prePostReceiveSizePerQP = NN_NO64;// pre post receive of qp + uint16_t pollingBatchSize = NN_NO4;// polling batch size for worker + uint32_t eventPollingTimeout = NN_NO500;// event polling timeout in ms, max value is 2000000ms + uint32_t qpSendQueueSize = NN_NO256;// max send working request of qp for rdma + uint32_t qpReceiveQueueSize = NN_NO256;// max receive working request of qp for rdma + uint16_t oobConnHandleThreadCount = NN_NO2;// server accept connection thread num + uint32_t oobConnHandleQueueCap = NN_NO4096;// server accept connection queue capability + uint8_t slave = 1;// slave 1 or 2 + char oobPortRange[NN_NO16] {};// port range when enable port auto selection +/* verify the common options of each driver */ + NResultValidateCommonOptions(); + std::stringNetDeviceIpMask() const; + std::stringWorkGroups() const; + std::stringWorkerGroupCpus() const; + /** + * @brief Set the ip mask for net devices, example: 192.168.0.1/24 +*/ + bool SetNetDeviceIpMask(const std::string &mask); + /** + * @brief Set worker groups, example: 1,3,4 + * meaning 3 groups for workers: + * group0 has 1 workers + * group1 has 3 workers + * group2 has 4 workers +*/ + bool SetWorkerGroups(const std::string &groups); + /** + * @brief Set worker groups, example: 10-10,11-13,na + * meaning 3 groups for workers: + * group0 bind to cpu 10 + * group1 bind to cpu 11, 12, 13 + * group2 not bind to cpu +*/ + bool SetWorkerGroupsCpuSet(const std::string &value); + std::stringToString() const; + std::stringToStringForSock() const; +} __attribute__((packed)); + +``` + +- UBSHcomNetTransRequest + +``` +/** + * @brief Transfer request +*/ +struct UBSHcomNetTransRequest { + uintptr_t lAddress = 0;// local buffer address + uintptr_t rAddress = 0;// remote buffer address + uint32_t lKey = 0;// local memory region key, for rdma etc. + uint32_t rKey = 0;// remote memory region key, for rdma etc. + void *srcSeg = nullptr; + void *dstSeg = nullptr; + uint32_t size = 0;// buffer size + uint16_t upCtxSize = 0;// upper context size + char upCtxData[NN_NO64] = {};// upper context data + UBSHcomNetTransRequest() = default; + UBSHcomNetTransRequest(void *data, uint32_t dataSize, uint16_t upContextSize) + :lAddress(reinterpret_cast(data)), size(dataSize), upCtxSize(upContextSize) + {} + UBSHcomNetTransRequest(uintptr_t la, uintptr_t ra, uint32_t lk, uint32_t rk, uint32_t s, uint16_t upCtxSi) + :lAddress(la), rAddress(ra), lKey(lk), rKey(rk), size(s), upCtxSize(upCtxSi) + {} +} __attribute__((packed)); + +``` + +- UBSHcomServiceOptions + +``` +struct UBSHcomServiceOptions { + uint32_t maxSendRecvDataSize = 1024; // 发送数据块最大值 + uint16_t workerGroupId = 0; // group id of the worker group, must increment from 0 and be unique + uint16_t workerGroupThreadCount = 1; // worker线程数,如果设置为0的话,不启动worker线程 + UBSHcomWorkerMode workerGroupMode = NET_BUSY_POLLING; // worker线程工作模式,默认busy_polling + int8_t workerThreadPriority = 0; // 线程优先级[-20,19],19优先级最低,-20优先级最高,同nice值 + std::pair workerGroupCpuIdsRange = {UINT32_MAX, UINT32_MAX}; // default not bind +}; + +``` + +- UBSHcomRequest + +``` +struct UBSHcomRequest { + void *address = nullptr; /* pointer of data */ + uint32_t size = 0; /* size of data */ + uint32_t key = 0; + uint16_t opcode = 0; /* operation code of request */ + + UBSHcomRequest() = default; + UBSHcomRequest(void *addr, uint32_t sz, uint16_t op) : address(addr), size(sz), opcode(op) {} +}; + +``` + +#### 静态数据结构模型 + +#### 数据所有权模型 + + +### 逻辑元素清单 + + +## 实现架构 + +### 技术模型 + +#### 技术选型 + +- HCOM:延续原先架构进行演进,增加对接URMA支持UB协议,扩充HCOM对外API的丰富性和兼容性。 + +- URMA/UVS: 延续Scale Out的架构,增加Scale Up的支持。 + +### 代码模型(HCOM) + +#### 代码模型 + +![](./images/代码模型-HCOM.png) + +#### 代码元素清单 + +|逻辑元素(服务/微服务/组件/模块)|逻辑元素编号| 代码元素名称(目录/代码仓链接) |代码元素编号| +| ------------ | ------------ | ------------ | ------------ | +| 组件 | 1 | api | 1 | +| 组件 | 2 | common | 2 | +| 组件 | 3 | service | 3 | +| 组件 | 4 | transport | 4 | +| 组件 | 5 | under_api | 5 | + + +### 构建模型 + +#### 构建模型 + +![构建模型-HCOM](./images/构建模型-HCOM.png) + + +#### 构建元素清单 + +|构建元素(编译目标文件/执行目标文件)|构建元素编号 |构建过程/工具链|对应的代码元素|代码元素编号 | +| ------------ | ------------ | ------------ | ------------ | ------------ | +| 执行目标文件 | 1 | cmake | libhcom.so | 1 | +| 执行目标文件 | 2 | cmake | libhcom_static.a | 2 | +| 执行目标文件 | 3 | cmake | libhcom_adapter.so | 3 | +| 执行目标文件 | 4 | cmake | libhcom_jni.so | 4 | +| 执行目标文件 | 5 | java构建工具 | hcom-sdk.jar | 5 | +| 执行目标文件 | 6 | cmake | hcom.mod | 6 | + + +### 硬件实现模型 + +不涉及 + +### 交付模型 + +#### 交付模型 + +![HCOM交付模型](./images/HCOM交付模型.png) + +![URMA交付模型](./images/URMA交付模型.png) + + +# Design constraints + +## HCOM异常处理和可定位性(补充章节) + +- 异常处理和可靠性 +HCOM异常及其对应的可靠性保障主要发生在三个阶段:初始化&启动,建链,数据面通信,每个阶段的典型场景和可靠性保障如下表: + +|阶段 |功能|故障模式| 可能的故障原因 |故障影响|可靠性措施 | +| ------------ | ------------ | ------------ | ------------ | ------------ | ------------ | +| 初始化&启动 | 线程模型初始化(包括超时处理periodic线程,worker线程,心跳线程) | 死循环 | 未收到needStop信号 | 程序占用大量CPU导致系统响应慢
停止线程无法退出,卡死导致无法收发消息 | 循环通过neestop发送信号退出线程 | +| 初始化&启动 | 线程模型初始化(包括超时处理periodic线程,worker线程,心跳线程) | 退出 | 错误调用stop,让线程收到needStop信号 | 程序异常终止,可能是由于未捕获的异常用户主动终止程序,线程退出,无法收发消息 | 代码流程保证只有在退出时才调用stop | +| 初始化&启动 | 线程模型初始化(包括超时处理periodic线程,worker线程,心跳线程) | 句柄泄漏 | 连续创建过多句柄,或句柄随线程创建导致过多 | 线程退出句柄释放,确认代码保证句柄不会泄露 | 析构、出错、线程退出时候释放句柄 | +| 初始化&启动 | 线程模型初始化(包括超时处理periodic线程,worker线程,心跳线程) | 栈溢出 | 局部变量过大,或任务堆栈设置过小 | 确认代码保证无栈溢出 | 确认代码保证无栈溢出 | +| 初始化&启动 | 线程模型初始化(包括超时处理periodic线程,worker线程,心跳线程) | 无法启动 | 资源不够无法启动、权限不对 | service启动失败,报错返回 | 服务退出 | +| 初始化&启动 | 内存池创建和初始化 | 内存泄漏 | 内存池资源申请了未释放 | 资源泄漏,导致拒绝服务 | 代码保证调用Allocate后再调用Free对应释放 | +| 建链 | 客户端connect | 消息发送失败 | 网络连接问题、目标主机不可达、软件错误、硬件故障等 | 建链消息发送失败,建链失败 | 设置重试次数和间隔进行重试建链 | +| 建链 | 客户端收rsp | 等待应答超时 | 网络延迟过高、目标主机处理缓慢、网络拥塞、目标主机未响应等 | 建链回复时recv卡死,若无法超时退出,进程卡死 | 通过环境变量设置recv的超时时间,超时退出,建链失败 | +| 建链 | 客户端收rsp | 报文内容损坏 | 网络原因导致报完损坏 | 建链回复状态connectStatus被修改,导致建链成功了但是提示建链失败,再尝试建链,会多次尝试建链 | 通过环境变量设置重试间隔和次数,保证能够有效退出。并且代码里限制最大重试次数为10次,最大重试间隔为60s | +| 建链 | 客户端收rsp | 报文丢失 | 网络拥塞导致丢包、硬件故障、路由错误等 | 建链回复时recv卡死,若无法超时退出,进程卡死 | 通过环境变量设置recv的超时时间,超时退出,建链失败 | +| 建链 | 客户端收rsp | 报文超大 | 网络攻击,网络报文超大 | 只收sizeof(ConnectResp)大小的数据,消息解析失败 | 只收sizeof(ConnectResp)大小的数据,消息解析失败 | +| 建链 | 服务端accept | 等待应答超时 | 网络延迟过高、目标主机处理缓慢、网络拥塞、目标主机未响应等 | 建链回复时recv卡死,若无法超时退出,进程卡死 | 通过环境变量设置recv的超时时间,超时退出,建链失败 | +| 建链 | 服务端accept | 报文内容损坏 | 网络原因导致报完损坏 | 建链回复状态connectStatus被修改,导致建链成功了但是提示建链失败,再测去尝试建链 | 异常退出 | +| 建链 | 服务端accept | 报文丢失 | 网络拥塞导致丢包、硬件故障、路由错误等 | 建链回复时recv卡死,若无法超时退出,进程卡死 | 通过环境变量设置recv的超时时间,超时退出,建链失败 | +| 建链 | 服务端accept | 报文乱序 | 不涉及,阻塞式发送,只发一条消息 | | | +| 建链 | 服务端accept | 报文超大 | 网络攻击,网络报文超大 | 报文被分片,增加了复杂度;部分报文可能丢失或乱序 | 只收固定大小的数据,消息解析失败 | +| 数据面通信 | 双边发送消息send | 拥塞窗口大小不合理 | 拥塞窗口设置不合理过大或者过小 | 窗口过大导致拥塞加剧,窗口过小导致数据传输效率低 | 设置时间窗口和数据量显示,当前时间窗口如果超过了发送数据则不再发送。并且时间窗口和数据窗口都可配置 | +| 数据面通信 | 双边发送消息send | 资源申请失败 | 资源空间不够 | 消息发送一次失败,返回SER_NEW_OBJECT_FAILED,重试发送(当前也有重试机制) | 重试发送,在超时的时间窗内会尝试重发SER_NEW_OBJECT_FAILED错误码会重试,重试时间间隔usleeep(100UL) | + +HCOM全量故障模式分析详见:通信子系统-HCOM-SFMEA故障模式&分析表 .xlsx + +- 可定位性 +HCOM提供CLI工具定位分析链路详情,可以定位每条链路的调用次数/成功次数/失败次数/时延最大值/时延最小值/时延平均值/时延分位值等,使用方式详见:[HCOM性能分析工具使用指导手册](https://idp.huawei.com/idp-designer-war/design?op=edit&locate=newMode/EDIT/205482681522/ZH-CN_BOOKMAP_0000002062055353/ZH-CN_TOPIC_0000002117625377/6) + +![](./images/8.1.png) + + +# Adoption strategy +- 当前的应用/模块如何适配到此模块 + +# Related Documentions +其他与此模块相关的设计文档 + +# SIGs/Maintianers +所属与关联的SIG与相关的maintainer \ No newline at end of file diff --git a/doc/UBS-COMM-Contribution-Guide.md b/doc/UBS-COMM-Contribution-Guide.md new file mode 100644 index 0000000000000000000000000000000000000000..ef8ce7328b7d1e458f07221f4856165ebd1e1c0b --- /dev/null +++ b/doc/UBS-COMM-Contribution-Guide.md @@ -0,0 +1 @@ +[参考OpenEuler社区统一contribution guide](https://www.openeuler.openatom.cn/zh/community/contribution/detail) \ No newline at end of file diff --git a/doc/UBS-Comm-Tutorial-Demo.md b/doc/UBS-Comm-Tutorial-Demo.md new file mode 100644 index 0000000000000000000000000000000000000000..f09949fb86e2f183dc33e35b4c46e22c0d38ac70 --- /dev/null +++ b/doc/UBS-Comm-Tutorial-Demo.md @@ -0,0 +1,3039 @@ +| | | | | | +|:--:|:--:|:--:|:--:|:--:| +| | | | | | +| | **UBS CommTutorial Demo** | | | | +| | **文档版本** | **1** | | | +| | **发布日期** | **2025-09-30** | | | +| ![华为网格系统---方案4-032.png](media/image1.png) | | | | | +| | 华为技术有限公司 | | ![附件1-16K](media/image2.png) | | + +[TABLE] + +| 华为技术有限公司 | | +|------------------|---------------------------------------------| +| 地址: | 深圳市龙岗区坂田华为总部办公楼 邮编:518129 | +| 网址: | | +| 客户服务邮箱: | | +| 客户服务电话: | 4008302118 | + +[TABLE] + +# 前言 + +## 概述 + +本文档详细的描述了UBS Comm的使用指南,包括环境配置、安全管理和库文件链接方法等内容。 + +## 读者对象 + +本文档主要适用于升级的操作人员。操作人员必须具备以下经验和技能: + +- 熟悉当前网络的组网和相关网元的版本信息。 + +- 有该设备维护经验,熟悉设备的操作维护方式。 + +## 符号约定 + +在本文中可能出现下列标志,它们所代表的含义如下。 + +[TABLE] + +# 目 录 + +[前言 [iii](#前言)](#前言) + +[1 介绍 [1](#介绍)](#介绍) + +[2 环境配置 [5](#环境配置)](#环境配置) + +[2.1 组网规划 [5](#组网规划)](#组网规划) + +[2.2 环境要求 [5](#环境要求)](#环境要求) + +[2.3 安装使用 [7](#安装使用)](#安装使用) + +[2.3.1 安装MLNX_OFED驱动 [7](#安装ubs-comm)](#安装ubs-comm) + +[2.3.2 配置服务器侧RDMA网卡无损特性 [10](#rdma场景配置服务器侧rdma网卡无损特性)](#rdma场景配置服务器侧rdma网卡无损特性) + +[2.3.3 安装UBS Comm [11](#安装ubs-comm)](#安装ubs-comm) + +[2.3.4 UBC仿真环境 [12](#_Toc256000009)](#_Toc256000009) + +[3 使用指导 [15](#使用指导)](#使用指导) + +[3.1 服务层 [15](#服务层)](#服务层) + +[3.1.1 说明 [15](#说明)](#说明) + +[3.1.2 服务端 [15](#服务端)](#服务端) + +[3.1.3 客户端 [16](#客户端)](#客户端) + +[3.1.4 服务端与客户端启动后 [16](#服务端与客户端启动后)](#服务端与客户端启动后) + +[3.1.5 服务层编程 [17](#服务层编程)](#服务层编程) + +[3.2 传输层 [30](#传输层)](#传输层) + +[3.2.1 说明 [30](#说明-1)](#说明-1) + +[3.2.2 服务端 [30](#服务端-1)](#服务端-1) + +[3.2.3 客户端 [30](#客户端-1)](#客户端-1) + +[3.2.4 服务端和客户端启动后 [31](#服务端和客户端启动后)](#服务端和客户端启动后) + +[3.2.5 传输层编程 [32](#传输层编程)](#传输层编程) + +[3.3 Java使用指导 [52](#_Toc256000023)](#_Toc256000023) + +[3.3.1 Java服务层编程 [52](#_Toc256000024)](#_Toc256000024) + +[4 安全管理 [67](#ZH-CN_TOPIC_0000002363191544)](#ZH-CN_TOPIC_0000002363191544) + +[5 UBS Comm库文件链接方法参考 [69](#ubs-comm库文件链接方法参考)](#ubs-comm库文件链接方法参考) + +[A 公网地址声明 [70](#公网地址声明)](#公网地址声明) + +[B 术语和缩略语 [71](#术语和缩略语)](#术语和缩略语) + +# 介绍 + +1. 概述 + +UBS Comm(UB service communication)是一个适用于高带宽和低延迟网络C/S(Client/Server)架构应用程序的高性能通信框架。 + +UBS Comm提供一组支持各种协议的高级API(Application Programming Interface),并屏蔽了包括RDMA(Remote Direct Memory Access)、TCP(Transmission Control Protocol)、UDS(Unix Domain Socket)、SHM(Shared Memory)、UBC(Unified bus clan)等低级API的复杂性与差异性,同时尽可能发挥硬件能力,以保证其拥有高性能。 + +2. 整体方案 + +UBS Comm主要分为服务层和传输层。其中,服务层([图1-1](#fig42211835370)中的Service所展示的内容)提供了更易用的API,包含Net Service(服务层对象)、Net Channel(消息收发通道)、同步/异步模型、链路重连和限流、IO超时检测、传输加密等功能。传输层([图1 软件架构](#fig42211835370)的Net Driver所展示的内容)也有单独的API,同时提供多个协议(RDMA/TCP/UDS/SHM/UBC)的同步异步通信、心跳、传输加密等功能。 + +1. 软件架构 + +**![image-20251029102007735](C:\Users\y00835993\AppData\Roaming\Typora\typora-user-images\image-20251029102007735.png)** + +3. 特性介绍 + +- 线程模型 + +UBS Comm会创建3种类型的线程:主线程、Worker线程和心跳线程。 + +- 主线程:每个Client或Server会创建一个主线程进行侦听、建链、收发消息等操作。 + +- Worker线程:同时可以配置多个Worker线程,每条链路EP(End Point)会在建链时选择某个Worker线程,每个Worker线程可能对应多个EP(多条链路)。链路的异步收发回调、断链回调等都会由Worker线程进行处理。 + +用户能够使用参数workerGroups配置线程组以及每个组线程的个数,并通过参数workerGroupsCpuSet配置线程绑核。 + +- 心跳线程:心跳线程会定时监测对端状态,以保证能感知对端服务是否还存在。用户可以通过参数heartBeatIdleTime、heartBeatProbeTime、heartBeatProbeInterval来配置心跳检查时长。 + + 1. RDMA模式下,启动心跳线程,对所有链路发送单边写来判断链路状态。 + + 2. TCP模式下,使用TCP协议的keepalive特性,配置TCP_KEEPIDLE/TCP_KEEPINTVL等字段,保证链路状态正常。 + + 1. 线程模型 + +> ![image-20251029102016672](C:\Users\y00835993\AppData\Roaming\Typora\typora-user-images\image-20251029102016672.png) + +- 双向RPC + +UBS Comm提供双向的RPC通信,每个Client和Server都是对等的,都可以启动监听线程等待对方建链,可以由建立Instance时的第三个bool参数startOobSvr来决定是否启动监听线程。Client和Server彼此之间可以相互建链也可以相互收发消息。 + +- RNDV特性 + +RNDV协议(Rendezvous协议)是MPI通信协议中的一类,会在接收端协调缓存来接收信息,通常适用于发送比较大的消息。为了增加易用性,UBS Comm引入Rendezvous协议提供给用户使用。 + +RNDV协议主要采用单边+双边结合的方式完成。使用双边协议传递控制消息以及回复响应,如单边的MR信息、用户控制头、用户处理结果等;使用单边协议进行数据拉取,并通过回调通知业务处理。 + +- 超时机制 + +UBS Comm可以对每个IO进行超时检测,通过获取每个IO的时间戳标记,然后加入到定时器中,检测标记时间和当前时间,判断该IO是否发生超时,从而及时进行业务回调处理。 + +用户可以通过NetServiceOpInfo结构的timeout字段来配置每个IO的超时时间。 + +- 认证加密 + +UBS Comm提供了加密认证的能力,可以选择AES_128_GCM_SHA256、AES_256_GCM_SHA384、AES_128_CCM_SHA256、TLS_CHACHA20_POLY1305_SHA256四种加密算法进行加密,同时可以选择设置TLS版本,当前默认且仅支持TLS 1.3版本。用户只需要把“enableTls”参数设置为“true”,然后配置“cipherSuite”参数,注册三个TLS相关的回调函数(具体可参见“《UBS-Comm-API-Spec.md》”中tls相关章节),提供CA证书、公钥、私钥信息,即可开启加密的流程。 + + + +- 传输口令,密钥,银行账号等敏感数据、敏感个人数据和批量个人数据时,建议开启TLS能力。 + +- 当用户使用UBS Comm时,应该自己做好三面隔离,如果将UBS Comm使用在登录认证场景时,用户需要自己做好管理接口提供接入认证机制。 + +- 当用户使用TLS加密能力时,建议用户做好证书安全管理,参见[证书安全管理](#section1911412125313)。 + +  + +- RDMA协议加速特性Device Memory + +在发送数据量很小的情况下,RDMA协议提供DM(Device Memory)特性来加速传输效率,DM是存在于硬件网卡上的内存,直接使用该内存可以免去将消息拷贝到网卡的时间从而提升性能。在UBS Comm中可以通过配置选项的dmSegCount和dmSegSize来配置,其中dmSegSize决定使用DM特性的消息最大长度,在小于或等于1024bytes时有明显提升,dmSegCount决定预申请多少个dmSegSize长度的内存。在配置过大时由于硬件内存有限会申请失败,但依旧可以正常运行UBS Comm,只是无法使用DM特性。 + +- RDMA协议加速特性inline + +普通的情况下,消息请求中存放的是需要发送消息的地址,网卡需要去地址处拷贝内容。而当发送数据大小在128bytes及以下时,RDMA提供一种比DM更高效的特性inline,inline可以把需要发送的消息直接存放在消息请求中,可以明显节省拷贝用时。 + +- 兼容性检查 + +UBS Comm版本号区分主次版本,如HCOM1.0,HCOM1.1。其中小数点前数字为主版本,小数点后数字为次版本。客户端服务端主版本要相同,但服务端的次版本一定要大于等于客户端的次版本。 + +- 限制客户端的连接数消减DOS攻击风险 + +UBS Comm服务端支持开启建链TLS认证,但认证过程比较耗时;DOS攻击可以通过伪造大量的客户端发送建链报文对UBS Comm服务端进行攻击,迫使UBS Comm服务端忙于执行TLS认证校验,无法响应合法建链请求。 + +支持通过配置项限定某个客户端IP地址最大允许EP建链数,默认值为250,异常IP发来的请求达到阈值后直接报错,不再执行TLS认证校验,并通过日志告警;通过提高恶意建链成本,提升服务端服务韧性。 + +# 环境配置 + +[2.1 组网规划](#组网规划) + +[2.2 环境要求](#环境要求) + +[2.3 安装使用](#安装使用) + +## 组网规划 + +UBS Comm组网可由2台服务器组成,其中: + +- Server用于等待其他节点建链,也可以主动向其他节点建链,并可以使用链路来向对端发送消息。 + +- Client用于主动向其他节点建链,并可以使用链路来向对端发送消息。![image-20251029101933187](C:\Users\y00835993\AppData\Roaming\Typora\typora-user-images\image-20251029101933187.png) + + + +## 环境要求 + +1. 硬件要求 + + +| 服务器名称 | TaiShan服务器 | +|----|----| +| 处理器 | 鲲鹏处理器 | +| 网卡 | Mellanox CX5 (仅使用RDMA通讯协议时必须,使用其他通讯协议不需要) | +| CPU | 通过系统文件“/sys/devices/system/cpu/cpu0/regs/identification/midr_el1”中获取CPU厂商信息判断,当前配套机型鲲鹏处理器型号为0x48。 | + +2. 软件版本 + +| 软件名称 | 软件版本 | +| --------- | ------------------------------------------------------------ | +| OS | l openEuler 20.03 LTS l openEuler 22.03 LTS l openEuler 24.03 LTS l CentOS 7.6 | +| rdma-core | 42.7 | +| GCC | 7.3.0 | +| CCA | VPP V300R024C10SPC001 | + +3. 获取软件安装包 + +| 名称 | 包名 | 发布类型 | 说明 | 获取地址 | +| -------- | ------------------------------------------ | -------- | ----------------------- | ------------------------------------ | +| UBS Comm | l ubs-hcom-2.0.0-1.oe2403sp1.aarch64.rpm | 开源 | UBS Comm软件rpm安装包。 | 华为技术企业网:Link 鲲鹏社区:Link | + +[TABLE] + +1. 校验软件包完整性 + +为了防止软件包在传递过程或存储期间被恶意篡改,获取软件包时需下载对应的数字签名文件用于完整性验证。 + +1. 参见[获取软件安装包](#section3489574613)获取软件包。 + +  + +1. 获取《OpenPGP签名验证指南》。 + +- 运营商客户:请访问 + +- 企业客户:请访问 + + 1. 根据《OpenPGP签名验证指南》进行软件安装包完整性检查。 + + + +- 如果校验失败,请不要使用该软件包,先联系华为技术支持工程师解决。 + +- 使用软件包安装或升级之前,也需要按上述过程先验证软件包的数字签名,确保软件包未被篡改。 + +----结束 + +1. UBS Comm软件版本可查询 + +UBS Comm支持查询软件版本。 + +1. 参考[获取软件安装包](#section3489574613)获取软件包。 + +2. 参考[安装UBS Comm](#ubc场景安装与规格限制)安装UBS Comm + +  + +1. 查询UBS Comm软件版本。 + +rpm -qi ubs-hcom-2.0.0 + +----结束 + +## 编译构建 + +### 拉取三方库 + +执行以下命令自动拉取 + +`git submodule update --init –recursive` + +或以下命令手动拉取 + +``` +yum install libboundscheck + +mkdir 3rdparty && cd 3rdparty && git clone + +cd .. +``` + +### 编译ubs comm源码 + +执行以下命令编译 + +bash ./build.sh + +执行完毕后可以在源码的dist目录中找到BoostKit-hcom_1.0.0_aarch64.tar.gz的压缩包 + +### 高级编译选项 + +Ubs comm支持编译隔离,可通过环境变量控制部分功能是否编译。 + +具体的环境变量请参考源码中README.md与build.sh文件 + +## 安装使用 + +### 安装UBS Comm + +1. 前提条件 + +安装libboundscheck rpm包 + +> 方式一:`yum install libboundscheck` +> +> 方式二:通过gitee下载源码编译 +> +> https://gitee.com/openeuler/libboundscheck +> +> 【下载发行版本】(最新的release版本是 v1.1.16 两年前) +> https://gitee.com/openeuler/libboundscheck/releases/tag/v1.1.16 +> 【编译】 +> `make CC=gcc` +> +> 编译后根目录下 lib目录中,存在libboundscheck.so + +2. 操作步骤(安装rpm软件包) + + 1. `rpm -ivh ubs-hcom-2.0.0-1.oe2403sp1.aarch64.rpm` +2. 安装完成,安装路径可以通过rpm -qpl ubs-hcom-2.0.0-1.oe2403sp1.aarch64.rpm查看 + 3. 若需使用RDMA场景,请参考2.4.1、2.4.2章节配置RDMA驱动环境 +4. 若需使用UBC场景,请参考2.4.3章节配置UBC驱动环境 + +----结束 + + + +hcom_utils.h和hcom_ref.h文件中的函数为hcom内部使用。 + +3. 操作步骤(安装tar.gz软件包) + + 1. tar -zxvf BoostKit-hcom_1.0.0_aarch64.tar.gz + + 2. 动态库和头文件会解压到当前路径,用户可以按自己的需求将文件拷贝到需要的路径 + + 3. 若需使用RDMA场景,请参考2.4.1、2.4.2章节配置RDMA驱动环境 + + 4. 若需使用UBC场景,请参考2.4.3章节配置UBC驱动环境 + +----结束 + + + +hcom_utils.h和hcom_ref.h文件中的函数为hcom内部使用。 + +### RDMA场景安装MLNX_OFED驱动 + +![注意](media/image13.png) + +使用RDMA通信协议时,请在UBS Comm所有通信节点执行本章节操作。未使用RDMA通信协议,则可跳过本章节。 + +1. 安装步骤 + + 1. 执行以下命令,查询服务器操作系统。 + +uname -a + +返回信息如下所示。 + +Linux 4826-node62 5.10.0-182.0.0.95.oe2203sp3.aarch64 + +1. 执行以下命令,查看Mellanox网卡信息。 + +lspci \|grep Mellanox + +返回信息如下所示。 + +81:00.0 Ethernet controller: Mellanox Technologies MT28800 Family \[ConnectX-5 Ex\] +81:00.1 Ethernet controller: Mellanox Technologies MT28800 Family \[ConnectX-5 Ex\] + +2. 获取与操作系统匹配的MLNX_OFED驱动包至本地。 + +地址为。 + +1. 下载页面 + +![](media/image14.png) + +3. 执行以下命令,新建目录并将操作系统镜像文件挂载至新建目录。 + +mkdir -p */mnt/iso* +mount openEuler-20.03-LTS-aarch64-dvd.iso */mnt/iso* + + + +操作系统镜像名称请根据实际情况进行修改。 + +4. 配置操作系统镜像源,此处以配置本地镜像源为例,配置前请做好镜像源配置文件备份。 + +  + +1. 执行以下命令打开镜像源配置文件。 + +vi /etc/yum.repos.d/openEuler.repo + +2. 按“i”进入编辑模式,只保留以下内容。 + +\[OS\] +name=OS +baseurl=file:///mnt/iso +enabled=1 +gpgcheck=0 + +3. 按“Esc”键,输入**:wq!**,按“Enter”保存并退出编辑。 + + 1. 执行以下命令刷新软件包缓存信息。 + +yum makecache + +2. 上传驱动包至服务器并解压。 + +tar -zxvf *MLNX_OFED_LINUX-5.4-3.7.5.0-openeuler22.03-x86_64.tgz* + +3. 进入压缩包解压文件夹目录下,执行以下命令安装驱动。 + +./mlnxofedinstall –force + +- 若提示内核不匹配,则执行以下命令。 + +./mlnxofedinstall --add-kernel-support + +- 若不想进行固件更新,则执行以下命令。 + +./mlnxofedinstall --without-fw-update + + + +- 安装程序将删除所有之前安装的OFED驱动,并重新安装,系统会提示您确认删除旧包。 + +- **./mlnxofedinstall -h**可查询参数配置,请根据实际情况选择参数。 + + 1. 安装完成后,执行以下命令重启服务器。 + +reboot + +2. 执行以下命令,配置MLNX_OFED驱动安装完成后自启动。 + +chkconfig --add openibd +/etc/init.d/openibd start +chkconfig openibd on + +3. 执行以下命令验证MLNX_OFED驱动是否安装成功。 + +- Server节点请执行以下命令: + +ib_send_bw -d mlx5_1 -a + +- Client节点请执行以下命令: + +ib_send_bw -d mlx5_1 -a *\* + +返回信息如下所示,即为安装成功。 + +--------------------------------------------------------------------------------------- +Send BW Test +Dual-port : OFF Device : mlx5_1 +Number of qps : 1 Transport type : IB +Connection type : RC Using SRQ : OFF +PCIe relax order: ON +ibv_wr\* API : ON +TX depth : 128 +CQ Moderation : 100 +Mtu : 4096\[B\] +Link type : Ethernet +GID index : 3 +Max inline data : 0\[B\] +rdma_cm QPs : OFF +Data ex. method : Ethernet +--------------------------------------------------------------------------------------- +local address: LID 0000 QPN 0x19b8 PSN 0xa3aa02 +GID: 00:00:00:00:00:00:00:00:00:00:255:255:10:10:01:62 +remote address: LID 0000 QPN 0x19b9 PSN 0xf3ab0 +GID: 00:00:00:00:00:00:00:00:00:00:255:255:10:10:01:62 +--------------------------------------------------------------------------------------- +\#bytes \#iterations BW peak\[MB/sec\] BW average\[MB/sec\] MsgRate\[Mpps\] +2 1000 10.60 10.11 5.298300 + + + +当对RDMA通信协议有性能调优需求时,请参见[Performance Tuning for Mellanox Adapters](https://enterprise-support.nvidia.com/s/article/performance-tuning-for-mellanox-adapters)。 + +----结束 + +1. 卸载 + + 1. 进入解压包。 + +   + + 1. 执行以下命令,卸载MLNX_OFED驱动。 + +./uninstall.sh + +----结束 + +### RDMA场景配置服务器侧RDMA网卡无损特性 + +RDMA无损配置可以提高网络传输的性能和效率,确保数据传输的可靠性和一致性,同时减少CPU的负担。 + +![注意](media/image13.png) + +未使用RDMA通信协议时,以下操作步骤可以不执行;否则需要在使用UBS Comm的所有通信节点上执行。 + +1. 登录服务器,执行以下命令查询CX5网卡设备net_card信息,以CX5网卡为例。 + +net_card=\$(ibdev2netdev \| grep *mlx5_1* \| awk '{print \$5}') + +1. 使用[步骤1](#li1919910204286)查询出的net_card作为参数,执行以下命令进行CX5网卡配置。 + +cma_roce_tos -d *mlx5_1* -t 106 +mlnx_qos -i \${net_card} --pfc 0,0,0,1,0,0,0,0 --trust dscp +ifconfig \${net_card} mtu 4500 + + + +服务器每次重启后都需要重新执行当前步骤进行配置。 + +2. 执行以下命令配置网卡的CNP中的DSCP字段。 + +echo 48 \>/sys/class/net/\${net_card}/ecn/roce_np/cnp_dscp + +3. 执行以下命令配置网卡的RoCEv2中的DCQCN拥塞控制机制。 + +echo 1 \>/sys/class/net/\${net_card}/ecn/roce_np/enable/3 +echo 1 \>/sys/class/net/\${net_card}/ecn/roce_rp/enable/3 + +----结束 + +### UBC场景安装与规格限制 + +1. 前提条件 + +安装前置的LCNE、MAMI、UDMA、URMA、UBSE等驱动 + +2. 规格限制 + +- 双边发送数据长度小于等于64KB。 + +- 单边读写数据小于等于16MB。 + +- 单边带宽为0.12 MB/S。 + +3. 操作步骤 + +- rpm -ivh ubs-hcom-2.0.0-1.oe2403sp1.aarch64.rpm + +- 安装完成,动态库的安装路径可以通过rpm -qpl ubs-hcom-2.0.0-1.oe2403sp1.aarch64.rpm查看 + +### 风险声明 + +当前公知Jetty存在内存越权访问风险,token_value泄漏和篡改风险,需要部署在可信的环境中。 + +----结束 + +# 使用指导 + +[3.1 服务层](#服务层) + +## 服务层 + +### 说明 + +本章节将通过基础示例来演示如何使用UBS Comm,开发者可以通过学习此指导来快速上手UBS Comm。UBS Comm向开发者提供了传输层和服务层,因此使用指导也将分别提供一个示例代码来演示如何使用传输层和服务层。 + +### 服务端 + +- 使用NetService::Instance创建一个service的对象。 + +``` +UBSHcomServiceOptions options; + +UBSHcomService *service = UBSHcomService::Create(RDMA, "server1", options); +``` + +- 此处创建了一个使用RDMA协议的名为server1的服务端Driver,支持通过option设置基本配置项。 + +- 设置NetServiceOptions选项,使用service对象注册回调函数,并用service的Bind方法设置需要侦听的IP地址和端口。 + +``` +service->RegisterRecvHandler(ReceivedRequest); + +service->RegisterChannelBrokenHandler([](const UBSHcomChannelPtr &channel) {}, UBSHcomChannelBrokenPolicy::BROKEN_ALL); + +service->RegisterSendHandler([](const UBSHcomServiceContext &ctx) { return 0; }); + +service->RegisterOneSideHandler([](const UBSHcomServiceContext &ctx) { return 0; }); +service->Bind("uds://" + oobIp + ":" + std::to_string(oobPort), NewChannel); +``` + + + +- ServiceOptions的参数,详情请参见《UBS-COMM-API-Spec》的“ServiceOptions”章节。 + +- 注册回调函数,详情请参见《UBS-COMM-API-Spec》的“RegisterRecvHandler”等章节。 + +- Bind用来设置需要侦听的IP地址和端口以及收到建链请求时的回调。 + + 1. 调用service的Start方法,完成服务端的启动。 + +`service->Start();` + +----结束 + +### 客户端 + +- 使用NetService::Instance创建一个service的对象。 + +``` +UBSHcomServiceOptions options; + +UBSHcomService *service = UBSHcomService::Create(RDMA, "client1", options); +``` + +- 此处创建了一个使用RDMA协议的名为client1的服务端Driver,支持通过option设置基本配置项。 + +- 设置NetServiceOptions选项,使用service对象注册回调函数。 + +``` +service->RegisterRecvHandler(ReceivedRequest); + +service->RegisterChannelBrokenHandler([](const UBSHcomChannelPtr &channel) {}, UBSHcomChannelBrokenPolicy::BROKEN_ALL); + +service->RegisterSendHandler([](const UBSHcomServiceContext &ctx) { return 0; }); +``` + +- 调用service的Start方法,完成客户端的启动。 + +`service->Start();` + +----结束 + +### 服务端与客户端启动后 + +1. 当服务端与客户端都完成启动后,客户端的Service可以调用Connect方法来连接服务端。 + +UBSHcomConnectOptions opt; + +UBSHcomChannelPtr channel + +service-\>Connect("tcp://" + oobIp + ":" + std::to_string(oobPort), channel, opt); + + + +- oobIp:需要建链的IP地址。 + +- oobPort:需要建链的Port。 + +- channel:Connect函数的返回值,即为得到的链路的本端,对端的HcomChannelPtr在NewChannel回调函数的第二个参数中获得。 + +- options:设置这条链路的选项。详情请参见《UBS-COMM-API-Spec》的“ConnectOptions”章节。 + + 1. 连接成功后,服务端与客户端都会获得一个HcomChannel对象,服务端与客户端都可以使用该对象来调用各种消息发送接口向对端发送消息。 + +UBSHcomRequest req(reinterpret_cast\(addr), dataSize, 0); +channel-\>Send(req, nullptr) + +详情请参见《UBS-COMM-API-Spec》的“UBSHcomChannel::Send”章节。 + +----结束 + +### 服务层编程 + +此示例仅限帮助开发者具象化理解如何使用UBS Comm,作为实际使用场景的参考,请勿直接复制使用。 + +1. Sever端完整示例代码 + +  + +1. 以下为服务层Server端的完整示例代码,当Server端收到Client端的消息时,会调用初始化时注册的回调函数RequestReceived,可以在回调函数中给Client端回复消息。 + +``` +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * Author: + * Description: UBC通信服务端,接收客户端大量请求并响应 + */ +#include +#include +#include +#include +#include +#include + +#include "hcom/hcom_service.h" +#include "hcom/hcom_service_context.h" +#include "hcom/hcom.h" + +using namespace ock::hcom; + +ock::hcom::UBSHcomService *service = nullptr; + +UBSHcomServiceProtocol driverType = ock::hcom::UBSHcomServiceProtocol::UBC; +std::string oobIp = ""; +uint16_t oobPort = 9981; + +std::string ipSeg = "192.168.100.0/24"; +int32_t dataSize = 1024; +uint32_t workerNum = 1; +int16_t asyncWorkerCpuId = -1; +UBSHcomChannelPtr channel = nullptr; + +uint32_t multiRailEnable = 1; +uint32_t multiRailThresh = 8192; +UBSHcomUbcMode mUbcMode = UBSHcomUbcMode::LowLatency; +uint32_t splitSendThreshold = UINT32_MAX; + +using TestRegMrInfo = struct _reg_sgl_info_test_ { + uintptr_t lAddress = 0; + uint32_t size = 0; + UBSHcomMemoryKey lKey; +}; + +UBSHcomRegMemoryRegion rndvMr; +uintptr_t rndvAddr = 0; +uint32_t rndvSize = 1048576; +uint32_t rndvThreshold = UINT32_MAX; + +TestRegMrInfo localMrInfo; + +int NewChannel(const std::string &ipPort, const UBSHcomChannelPtr &ch, const std::string &payload) +{ + NN_LOG_INFO("new channel " << ch->GetId() << " call from " << ipPort << " payload: " << payload); + channel = ch; + + UBSHcomTwoSideThreshold threshold{}; + threshold.rndvThreshold = rndvThreshold; + + auto result = channel->SetTwoSideThreshold(threshold); + if (result != 0) { + NN_LOG_ERROR("failed to set two side threshold, result " << result); + return false; + } + return 0; +} + +int CallBackReply(UBSHcomServiceContext &context) +{ + Callback *cb = UBSHcomNewCallback([](UBSHcomServiceContext &context) {}, std::placeholders::_1); + if (NN_UNLIKELY(cb == nullptr)) { + NN_LOG_ERROR("new callback is nullptr"); + return -1; + } + if (context.OpCode() == 0) { + NN_LOG_DEBUG("receive msg, channel id " << context.Channel()->GetId() << ", MessageData " << + reinterpret_cast(context.MessageData()) << " MessageDataLen: " << context.MessageDataLen()); + } else if (context.OpCode() == 1) { + UBSHcomRequest req; + req.address = context.MessageData(); + req.size = context.MessageDataLen(); + + UBSHcomReplyContext replyCtx; + replyCtx.errorCode = 200; + replyCtx.rspCtx = context.RspCtx(); + + if (context.Channel()->Reply(replyCtx, req, cb) != 0) { + NN_LOG_ERROR("failed to post message to data to server"); + return -1; + } + } else if (context.OpCode() == 2) { + UBSHcomRequest req((void *)&localMrInfo, sizeof(localMrInfo), 1); + UBSHcomReplyContext replyCtx(context.RspCtx(), 200); + if (context.Channel()->Reply(replyCtx, req, cb) != 0) { + NN_LOG_ERROR("failed to post message to data to server"); + return -1; + } + } + return 0; +} + +int ReceivedRequest(UBSHcomServiceContext &context) +{ + if (context.OpType() == UBSHcomServiceContext::Operation::SER_RNDV) { + uintptr_t contextRsp = context.RspCtx(); + const UBSHcomChannelPtr &rspChannel = context.Channel(); + Callback *newCallback = UBSHcomNewCallback( + [contextRsp, rspChannel](UBSHcomServiceContext &ctx) { + if (NN_UNLIKELY(ctx.Result() != SER_OK)) { + NN_LOG_ERROR("Rndv recv callback failed " << ctx.Result()); + } + + UBSHcomRequest req; + char str[] = "rndv reply!"; + char *ptr = str; + req.address = str; + req.size = strlen(str); + + UBSHcomReplyContext replyCtx; + replyCtx.errorCode = 0; + replyCtx.rspCtx = contextRsp; + Callback *cb = UBSHcomNewCallback([](UBSHcomServiceContext &context) {}, std::placeholders::_1); + if (NN_UNLIKELY(cb == nullptr)) { + NN_LOG_ERROR("new callback is nullptr"); + return; + } + if (rspChannel->Reply(replyCtx, req, cb) != 0) { + NN_LOG_ERROR("failed to post message to data to server"); + } + }, + std::placeholders::_1); + if (context.Channel()->Recv(context, rndvAddr, dataSize, newCallback) != 0) { + NN_LOG_ERROR("failed to rndv recv data to server"); + return -1; + } + return 0; + } + + return CallBackReply(context); +} + +bool CreateService() +{ + if (service != nullptr) { + NN_LOG_ERROR("service already created"); + return false; + } + + UBSHcomServiceOptions options; + options.maxSendRecvDataSize = dataSize + 1024; + options.workerGroupMode = ock::hcom::NET_EVENT_POLLING; + if (asyncWorkerCpuId != -1) { + options.workerGroupCpuIdsRange = { asyncWorkerCpuId, asyncWorkerCpuId }; + } + service = UBSHcomService::Create(driverType, "server1", options); + if (service == nullptr) { + NN_LOG_ERROR("failed to create service already created"); + return false; + } + if (driverType != UBC) { + service->SetDeviceIpMask({ ipSeg }); + } + service->RegisterRecvHandler(ReceivedRequest); + service->RegisterChannelBrokenHandler([](const UBSHcomChannelPtr &channel) {}, UBSHcomChannelBrokenPolicy::BROKEN_ALL); + service->RegisterSendHandler([](const UBSHcomServiceContext &ctx) { return 0; }); + service->RegisterOneSideHandler([](const UBSHcomServiceContext &ctx) { return 0; }); + service->SetUbcMode(mUbcMode); + + if (driverType == SHM) { + service->Bind("uds://" + oobIp + ":" + std::to_string(oobPort), NewChannel); + } else if (driverType == UBC) { + service->Bind("ubc://" + oobIp + ":" + std::to_string(oobPort), NewChannel); + } else { + service->Bind("tcp://" + oobIp + ":" + std::to_string(oobPort), NewChannel); + } + + UBSHcomHeartBeatOptions hbOptions{}; + hbOptions.heartBeatIdleSec = 2; + service->SetHeartBeatOptions(hbOptions); + + int result = 0; + if ((result = service->Start()) != 0) { + NN_LOG_ERROR("failed to initialize service " << result); + return false; + } + NN_LOG_INFO("service initialized"); + return true; +} + + +bool RegSglMem() +{ + UBSHcomRegMemoryRegion mr; + auto result = service->RegisterMemoryRegion(dataSize, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + localMrInfo.lAddress = mr.GetAddress(); + mr.GetMemoryKey(localMrInfo.lKey); + localMrInfo.size = mr.GetSize(); + + strcpy(reinterpret_cast(localMrInfo.lAddress), "aaaaa server"); + return true; +} + + +bool RegRndvMem() +{ + void *address = memalign(4096, rndvSize); + if (address == nullptr) { + NN_LOG_ERROR("Failed to alloc memory, maybe lack of spare memory in system."); + return false; + } + rndvAddr = reinterpret_cast(address); + auto result = service->RegisterMemoryRegion(rndvAddr, rndvSize, rndvMr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + free(address); + return false; + } + memset(address, 'A', rndvSize); + return true; +} + +void SendMemInfo() +{ + UBSHcomRequest req((void *)&localMrInfo, sizeof(localMrInfo), 1); + + if ((channel->Send(req, nullptr)) != 0) { + NN_LOG_ERROR("failed to send message to data to server"); + return; + } + NN_LOG_INFO("SendMemInfo success"); +} + +void SendRequest() +{ + NN_LOG_INFO("input q means quit, d means dump obj static, c means channel close"); + while (true) { + auto tmpChar = getchar(); + switch (tmpChar) { + case 'q': + service->DestroyMemoryRegion(rndvMr); + free(reinterpret_cast(rndvAddr)); + UBSHcomService::Destroy("server1"); + return; + case 'd': + NetObjStatistic::Dump(); + continue; + case 'c': + continue; + case 's': + SendMemInfo(); + default: + NN_LOG_INFO("input q means quit, d means dump obj static, c means channel close"); + continue; + } + } +} + +void exitFunc() +{ + service = nullptr; +} + +void Run() +{ + if (!CreateService()) { + return; + } + + atexit(exitFunc); + + if (!RegSglMem()) { + return; + } + + if (!RegRndvMem()) { + return; + } + + SendRequest(); +} + +int main(int argc, char *argv[]) +{ + struct option options[] = { + {"driver", required_argument, NULL, 'd'}, + {"ip", required_argument, NULL, 'i'}, + {"port", required_argument, NULL, 'p'}, + {"size", required_argument, NULL, 's'}, + {"worker num", required_argument, NULL, 'w'}, + {"cpuId", required_argument, NULL, 'c'}, + {"multiRail", optional_argument, NULL, 'r'}, + {"multiRailThresh", optional_argument, NULL, 'R'}, + {"RndvThreshold", optional_argument, NULL, 'v'}, + {"ubcMode", optional_argument, NULL, 'u'}, + {"splitSendThreshold", optional_argument, NULL, 'S'}, + {NULL, 0, NULL, 0}, + }; + + const char *usage = "usage\n" + " -d, --driver, driver type, 0 for rdma, 1 for tcp, 3 for shm, 7 for UBC\n" + " -i, --ip, server ip mask, e.g. 10.175.118.1; eid for UBC, e.g. " + "4245:4944:0000:0000:0000:0000:0100:0000\n" + " -p, --port, server port, by default 9981; jetty id for UBC, e.g. 998\n" + " -s, --io size , max data size\n" + " -w, --worker num , worker num\n" + " -c, --cpuId, async worker\n" + " -r, --enableMultiRail, enable multiRail\n" + " -R, --multiRailThresh, multiRail threshhold\n" + " -v, --RndvThreshold, Perf case only supports an RNDV threshold of less than 1048576, actual " + "scenario requires a value less than UINT32_MAX\n" + " -u, --ubcMode, UB-C mode, 0 means LowLatency, other value means HighBandwidth\n" + " -S, --splitSendThreshold, the threshold of split send, UINT32_MAX by default\n"; + + + int ret = 0; + int index = 0; + + std::string str = "d:i:p:s:w:c:r:R:v:u:S:"; + while ((ret = getopt_long(argc, argv, str.c_str(), options, &index)) != -1) { + switch (ret) { + case 'd': + driverType = static_cast((uint16_t)strtoul(optarg, NULL, 0)); + if (driverType > UBC) { + printf("invalid driver type %d", driverType); + return -1; + } + break; + case 'i': + oobIp = optarg; + ipSeg = oobIp + "/24"; + break; + case 'p': + oobPort = (uint16_t)strtoul(optarg, NULL, 0); + break; + case 's': + dataSize = (int32_t)strtoul(optarg, NULL, 0); + break; + case 'w': + workerNum = (int32_t)strtoul(optarg, NULL, 0); + break; + case 'c': + asyncWorkerCpuId = strtoul(optarg, nullptr, 0); + break; + case 'r': + multiRailEnable = (uint32_t)strtoul(optarg, nullptr, 0); + break; + case 'R': + multiRailThresh = (uint32_t)strtoul(optarg, nullptr, 0); + break; + case 'v': + rndvThreshold = (uint32_t)strtoul(optarg, NULL, 0); + break; + case 'u': + mUbcMode = std::stoi(optarg) ? UBSHcomUbcMode::HighBandwidth : UBSHcomUbcMode::LowLatency; + break; + case 'S': + splitSendThreshold = (uint32_t)strtoul(optarg, nullptr, 0); + break; + } + } + + Run(); + return 0; +} +``` + +2. 执行以下命令运行代码,启动Server端。 + +`./*pp_service_server_simple* -d 1 -i 127.0.0.1 -p 9980` + +*pp_service_server_simple*:编译后可执行文件名,请根据实际情况进行修改。 + +- -d:配置Driver类型。 + + 1. 0:RDMA + + 2. 1:TCP + + 3. 2:UDS + + 4. 3:SHM + + 5. 7:UBC + +- -i:Server IP地址。 + +- -p: 监听的端口 + +- 其余参数可以根据实际情况自由配置 + + 1. Client端完整示例代码 + +1. 以下为服务层一个完整的Client端示例,入口为main函数,经过参数解析后,进入Run函数。 + +``` +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * Author: + * Description: UBC通信客户端,使用多种不同接口发送大量请求组并计算请求的延迟和吞吐量 + */ +#include +#include +#include +#include +#include +#include "hcom/hcom_service.h" +#include "hcom/hcom_service_context.h" +#include "hcom/hcom.h" + +using namespace ock::hcom; + +UBSHcomService *service = nullptr; +UBSHcomChannelPtr channel = nullptr; + +UBSHcomServiceProtocol driverType = ock::hcom::UBSHcomServiceProtocol::UBC; +std::string oobIp = ""; +uint16_t oobPort = 9981; + +std::string ipSeg = "192.168.100.0/24"; +int32_t pingCount = 1000000; +uint64_t startTime = 0; +uint64_t finishTime = 0; +uint64_t asyncTime = 0; + + +uint64_t mode = 0; +uint64_t periodThreadCnt = 0; +int32_t dataSize = 1024; +int32_t epSize = 1; +int16_t asyncWorkerCpuId = -1; +char *data = nullptr; +char *rcvData = nullptr; +bool start = false; +uint32_t multiRailEnable = 1; +uint32_t multiRailThresh = 8192; +UBSHcomUbcMode mUbcMode = UBSHcomUbcMode::LowLatency; +uint32_t splitSendThreshold = UINT32_MAX; + +using TestRegMrInfo = struct _reg_sgl_info_test_ { + uintptr_t lAddress = 0; + uint32_t size = 0; + UBSHcomMemoryKey lKey; +}; + +TestRegMrInfo localMrInfo; +TestRegMrInfo remoteMrInfo; + +UBSHcomRegMemoryRegion rndvMr; +uintptr_t rndvAddr = 0; +uint32_t rndvSize = 1048576; +uint32_t rndvThreshold = UINT32_MAX; + +int ReceivedRequest(UBSHcomServiceContext &context) +{ + if (context.OpCode() == 1) { + memcpy((void *)&remoteMrInfo, context.MessageData(), sizeof(remoteMrInfo)); + NN_LOG_INFO("remoteMrInfo lAddress is " << remoteMrInfo.lAddress << ", lKey: " << remoteMrInfo.lKey.keys[0] << ", size: " << remoteMrInfo.size); + } + return 0; +} + +bool CreateService() +{ + if (service != nullptr) { + NN_LOG_ERROR("service already created"); + return false; + } + + UBSHcomServiceOptions options{}; + options.maxSendRecvDataSize = dataSize + 1024; + options.workerGroupMode = ock::hcom::NET_EVENT_POLLING; + if (asyncWorkerCpuId != -1) { + options.workerGroupCpuIdsRange = {asyncWorkerCpuId, asyncWorkerCpuId}; + } + service = UBSHcomService::Create(driverType, "client1", options); + if (service == nullptr) { + NN_LOG_ERROR("failed to create service already created"); + return false; + } + + service->SetTimeOutDetectionThreadNum(periodThreadCnt); + service->RegisterRecvHandler(ReceivedRequest); + service->RegisterChannelBrokenHandler([](const UBSHcomChannelPtr &channel) {}, UBSHcomChannelBrokenPolicy::BROKEN_ALL); + service->RegisterSendHandler([](const UBSHcomServiceContext &ctx) {return 0;}); + service->RegisterOneSideHandler([](const UBSHcomServiceContext &ctx) {return 0;}); + service->SetUbcMode(mUbcMode); + + if (driverType != UBC) { + service->SetDeviceIpMask({ipSeg}); + } + + int result = 0; + if ((result = service->Start()) != 0) { + NN_LOG_ERROR("failed to start service " << result); + return false; + } + NN_LOG_ERROR("service started"); + + return true; +} + +bool Connect() +{ + if (service == nullptr) { + NN_LOG_ERROR("service is null"); + return false; + } + + int result = 0; + UBSHcomConnectOptions opt; + opt.linkCount = epSize; + + NN_LOG_INFO("connect mode: " << mode); + if (mode == 1) { + opt.mode = UBSHcomClientPollingMode::SELF_POLL; + } + + if (driverType == SHM) { + result = service->Connect("uds://" + oobIp, channel, opt); + } else if (driverType == UBC) { + result = service->Connect("ubc://" + oobIp + ":" + std::to_string(oobPort), channel, opt); + } else { + result = service->Connect("tcp://" + oobIp + ":" + std::to_string(oobPort), channel, opt); + } + + if (result != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + data = static_cast(malloc(dataSize)); + + UBSHcomTwoSideThreshold threshold{}; + threshold.rndvThreshold = rndvThreshold; + threshold.splitThreshold = splitSendThreshold; + + result = channel->SetTwoSideThreshold(threshold); + if (result != 0) { + NN_LOG_ERROR("failed to set two side threshold, result " << result); + return false; + } + + return true; +} + +void SendRequest() +{ + UBSHcomRequest req(reinterpret_cast(rndvAddr), dataSize, 0); + + if ((channel->Send(req, nullptr)) != 0) { + NN_LOG_ERROR("failed to send message to data to server"); + return; + } +} + +bool CallRequest() +{ + UBSHcomRequest req(reinterpret_cast(rndvAddr), dataSize, 1); + UBSHcomResponse rsp(reinterpret_cast(rndvAddr), dataSize); + if ((channel->Call(req, rsp, nullptr)) != 0) { + NN_LOG_ERROR("failed to call message to data to server"); + return false; + } + return true; +} + +bool AsyncCallRequest() +{ + UBSHcomRequest req(reinterpret_cast(rndvAddr), dataSize, 1); + + rcvData = static_cast(malloc(dataSize)); + // char *rcvData = static_cast(malloc(dataSize)); + UBSHcomResponse rsp(rcvData, dataSize); + + int32_t ret = 0; + sem_t sem; + sem_init(&sem, 0, 0); + Callback *callback = UBSHcomNewCallback( + [&sem, &ret, &rsp](UBSHcomServiceContext &context) { + if (NN_UNLIKELY(context.Result() != SER_OK)) { + NN_LOG_ERROR("Channel Async send callback failed " << context.Result() << ", optype: " << context.OpType()); + ret = -1; + sem_post(&sem); + return; + } + memcpy(rsp.address, context.MessageData(), context.MessageDataLen()); + sem_post(&sem); + }, + std::placeholders::_1); + + + if ((channel->Call(req, rsp, callback)) != 0) { + NN_LOG_ERROR("failed to call message to data to server"); + return false; + } + + sem_wait(&sem); + sem_destroy(&sem); + if (ret != 0) { + NN_LOG_ERROR("failed to async call"); + return false; + } + + return true; +} + +void AsyncSendRequest() +{ + UBSHcomRequest req(reinterpret_cast(rndvAddr), dataSize, 0); + + Callback *callback = UBSHcomNewCallback( + [](UBSHcomServiceContext &context) { + if (NN_UNLIKELY(context.Result() != SER_OK)) { + NN_LOG_ERROR("Channel Async send callback failed " << context.Result() << ", optype: " << + context.OpType()); + } + }, + std::placeholders::_1); + + if ((channel->Send(req, callback)) != 0) { + NN_LOG_ERROR("failed to send message to data to server"); + return; + } +} + +void ReadRequest() +{ + UBSHcomOneSideRequest req {}; + req.lAddress = localMrInfo.lAddress; + req.lKey = localMrInfo.lKey; + req.rAddress = remoteMrInfo.lAddress; + req.rKey = remoteMrInfo.lKey; + req.size = dataSize; + + if ((channel->Get(req, nullptr)) != 0) { + NN_LOG_ERROR("failed to read data to server"); + return; + } +} + +void AsyncReadRequest() +{ + UBSHcomOneSideRequest req {}; + req.lAddress = localMrInfo.lAddress; + req.lKey = localMrInfo.lKey; + req.rAddress = remoteMrInfo.lAddress; + req.rKey = remoteMrInfo.lKey; + req.size = dataSize; + + Callback *callback = UBSHcomNewCallback( + [](UBSHcomServiceContext &context) { + if (NN_UNLIKELY(context.Result() != SER_OK || context.OpType() != 4)) { + NN_LOG_ERROR("Channel Async read callback failed " << context.Result() << ", optype: " << context.OpType()); + } else { + NN_LOG_INFO("Channel Async read callback successful " << context.Result()); + } + }, + std::placeholders::_1); + if ((channel->Get(req, callback)) != 0) { + NN_LOG_ERROR("failed to read data to server"); + return; + } + NN_LOG_INFO("read data from server" << std::string((char *)req.lAddress)); +} + +void WriteRequest() +{ + UBSHcomOneSideRequest req {}; + req.lAddress = localMrInfo.lAddress; + req.lKey = localMrInfo.lKey; + req.rAddress = remoteMrInfo.lAddress; + req.rKey = remoteMrInfo.lKey; + req.size = dataSize; + + if ((channel->Put(req, nullptr)) != 0) { + NN_LOG_ERROR("failed to read data to server"); + return; + } +} + +void AsyncWriteRequest() +{ + UBSHcomOneSideRequest req {}; + req.lAddress = localMrInfo.lAddress; + req.lKey = localMrInfo.lKey; + req.rAddress = remoteMrInfo.lAddress; + req.rKey = remoteMrInfo.lKey; + req.size = dataSize; + + Callback *callback = UBSHcomNewCallback( + [](UBSHcomServiceContext &context) { + if (NN_UNLIKELY(context.Result() != SER_OK || context.OpType() != 4)) { + NN_LOG_ERROR("Channel Async write callback failed " << context.Result() << ", optype: " << context.OpType()); + } else { + NN_LOG_INFO("Channel Async write callback successful " << context.Result()); + } + }, + std::placeholders::_1); + if ((channel->Put(req, callback)) != 0) { + NN_LOG_ERROR("failed to read data to server"); + return; + } +} + +bool RegRndvMem() +{ + void *address = memalign(4096, rndvSize); + if (address == nullptr) { + NN_LOG_ERROR("Failed to alloc memory, maybe lack of spare memory in system."); + return false; + } + rndvAddr = reinterpret_cast(address); + auto result = service->RegisterMemoryRegion(rndvAddr, rndvSize, rndvMr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + free(address); + return false; + } + memset(address, 'B', rndvSize); + return true; +} + +int userChar = 0; + +void RunInThread() +{ + while (!start) { + usleep(1); + } + bool ret; + switch (userChar) { + case '0': + for (int32_t i = 0; i < pingCount; i++) { + SendRequest(); + } + printf("SendRequest finish\n"); + break; + case '1': + for (int32_t i = 0; i < pingCount; i++) { + AsyncSendRequest(); + } + break; + case '2': + for (int32_t i = 0; i < pingCount; i++) { + ret = CallRequest(); + if (!ret) { + return; + } + } + break; + case '3': + for (int32_t i = 0; i < pingCount; i++) { + ret = AsyncCallRequest(); + if (!ret) { + return; + } + } + break; + case '4': + for (int32_t i = 0; i < pingCount; i++) { + ReadRequest(); + } + break; + case '5': + for (int32_t i = 0; i < pingCount; i++) { + AsyncReadRequest(); + } + break; + case '6': + for (int32_t i = 0; i < pingCount; i++) { + WriteRequest(); + } + break; + case '7': + for (int32_t i = 0; i < pingCount; i++) { + AsyncWriteRequest(); + } + break; + default: + return; + } +} + +void Test() +{ + NN_LOG_INFO( + "input 0:send, 1:async send, 2:call, 3:async call, 4:read, 5:async read, 6:write, 7:async write "); + int ret; + while (true) { + userChar = getchar(); + startTime = MONOTONIC_TIME_NS(); + + std::thread threads[epSize]; + + start = false; + for (int i = 0; i < epSize; i++) { + threads[i] = std::thread(RunInThread); + } + + NN_LOG_INFO("Wait for finish"); + start = true; + for (auto &t : threads) { + t.join(); + } + + switch (userChar) { + case '0': + printf("\tType sync send\n"); + break; + case '1': + printf("\tType async send\n"); + break; + case '2': + printf("\tType sync call\n"); + break; + case '3': + printf("\tType async call\n"); + break; + case '4': + printf("\tType sync read\n"); + break; + case '5': + printf("\tType sync read\n"); + break; + case '6': + printf("\tType sync write\n"); + break; + case '7': + printf("\tType async write\n"); + break; + case 'q': + service->DestroyMemoryRegion(rndvMr); + free(reinterpret_cast(rndvAddr)); + UBSHcomService::Destroy("client1"); + return; + case 'd': + NetObjStatistic::Dump(); + break; + case 'c': + printf("\tOperate close\n"); + service->Disconnect(channel); + break; + default: + NN_LOG_INFO("input 0:send, 1:async send, 2:call, 3:async call, 4:read, 5:async read, 6:write, 7:sync " + "write "); + continue; + } + + if (userChar == 'd' || userChar == 'c' || userChar == 'r') { + continue; + } + + finishTime = MONOTONIC_TIME_NS(); + printf("\tPerf summary\n"); + printf("\tPingpong times:\t\t%d\n", pingCount); + printf("\tData size:\t\t%d\n", dataSize); + printf("\tEp size:\t\t%d\n", epSize); + printf("\tThread count:\t\t%d\n", epSize); + printf("\tTotal time(us):\t\t%f\n", (finishTime - startTime) / 1000.0); + printf("\tTotal time(ms):\t\t%f\n", (finishTime - startTime) / 1000000.0); + printf("\tTotal time(s):\t\t%f\n", (finishTime - startTime) / 1000000000.0); + printf("\tLatency(us):\t\t%f\n", (finishTime - startTime) / pingCount / 1000.0); + printf("\tAvg ops:\t\t%f pp/s\n", (pingCount * 1000000000.0) / (finishTime - startTime)); + printf("\tTotal ops:\t\t%f pp/s\n", (pingCount * 1000000000.0) / (finishTime - startTime) * epSize); + printf("\tAvg bw:\t\t\t%f MB/s\n", + (pingCount * 1000000000.0) / (finishTime - startTime) * dataSize / 1024 / 1024); + printf("\tTotal bw:\t\t%f MB/s\n", + (pingCount * 1000000000.0) / (finishTime - startTime) * dataSize / 1024 / 1024 * epSize); + + if (userChar == 'a') { + printf("\tAsync call latency(us):\t%f\n", asyncTime / pingCount / 1000.0 / epSize); + asyncTime = 0; + } + } +} + +bool GetRemoteMr() +{ + UBSHcomRequest req(data, sizeof(data), 2); + UBSHcomResponse rsp((void*)&remoteMrInfo, sizeof(remoteMrInfo)); + + if ((channel->Call(req, rsp, nullptr)) != 0) { + NN_LOG_INFO("failed to call message to data to server"); + return false; + } + NN_LOG_INFO("remoteMrInfo lAddress is " << remoteMrInfo.lAddress << ", lKey: " << remoteMrInfo.lKey.keys[0] << ", size: " << remoteMrInfo.size); + return true; +} + +bool RegSglMem() +{ + UBSHcomRegMemoryRegion mr; + auto result = service->RegisterMemoryRegion(dataSize, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + localMrInfo.lAddress = mr.GetAddress(); + mr.GetMemoryKey(localMrInfo.lKey); + localMrInfo.size = dataSize; + + return true; +} + +void exitFunc() +{ + free(data); + free(rcvData); + service = nullptr; + data = nullptr; + rcvData = nullptr; +} + +void Run() +{ + if (!CreateService()) { + return; + } + + atexit(exitFunc); + + if (!RegSglMem()) { + return; + } + + if (!RegRndvMem()) { + return; + } + + if (!Connect()) { + return; + } + + if (!GetRemoteMr()) { + return; + } + + Test(); +} + +int main(int argc, char *argv[]) +{ + struct option options[] = { + {"driver", required_argument, NULL, 'd'}, + {"ip", required_argument, NULL, 'i'}, + {"port", required_argument, NULL, 'p'}, + {"pingpongtimes", required_argument, NULL, 't'}, + {"size", required_argument, NULL, 's'}, + {"epSize", required_argument, NULL, 'e'}, + {"epMode", required_argument, NULL, 'm'}, + {"timeout thread", required_argument, NULL, 'o'}, + {"cpuId", required_argument, NULL, 'c'}, + {"multiRail", optional_argument, NULL, 'r'}, + {"multiRailThresh", optional_argument, NULL, 'R'}, + {"RndvThreshold", optional_argument, NULL, 'v'}, + {"ubcMode", optional_argument, NULL, 'u'}, + {"splitSendThreshold", optional_argument, NULL, 'S'}, + {"rndvThreshold", optional_argument, NULL, 'v'}, + {NULL, 0, NULL, 0}, + }; + + const char *usage = "usage\n" + " -d, --driver, driver type, 0 for rdma, 1 for tcp, 3 for shm, 7 for UBC\n" + " -i, --ip, coord server ip mask, e.g. 10.175.118.1; remote eid for UBC, e.g. " + "4245:4944:0000:0000:0000:0000:0100:0000 \n" + " -p, --port, coord server port, by default 9981; jetty id for UBC, e.g. 998\n" + " -t, --pingpongtimes, ping pong times\n" + " -s, --size, max data size\n" + " -e, --ep size, connect and run ep size\n" + " -m, --ep mode, connect and run ep mode\n" + " -o, --timeout thread, range [1, 4]\n" + " -c, --cpuId, async worker\n" + " -r, --enableMultiRail, enable multiRail\n" + " -R, --multiRailThresh, multiRail threshhold\n" + " -v, --RndvThreshold, Perf case only supports an RNDV threshold of less than 1048576, actual " + "scenario requires a value less than UINT32_MAX\n" + " -u, --ubcMode, UB-C mode, 0 means LowLatency, other value means HighBandwidth\n" + " -S, --splitSendThreshold, the threshold of split send, UINT32_MAX by default\n"; + + + int ret = 0; + int index = 0; + + std::string str = "d:i:p:t:s:e:m:o:c:r:R:v:u:S:"; + while ((ret = getopt_long(argc, argv, str.c_str(), options, &index)) != -1) { + switch (ret) { + case 'd': + driverType = static_cast((uint16_t)strtoul(optarg, NULL, 0)); + if (driverType > UBC) { + printf("invalid driver type %d", driverType); + return -1; + } + break; + case 'i': + oobIp = optarg; + ipSeg = oobIp + "/24"; + break; + case 'p': + oobPort = (uint16_t)strtoul(optarg, NULL, 0); + break; + case 't': + pingCount = (int32_t)strtoul(optarg, NULL, 0); + break; + case 's': + dataSize = (int32_t)strtoul(optarg, NULL, 0); + break; + case 'e': + epSize = (int32_t)strtoul(optarg, NULL, 0); + break; + case 'm': + mode = (uint64_t)strtoul(optarg, NULL, 0); + break; + case 'o': + periodThreadCnt = (uint64_t)strtoul(optarg, NULL, 0); + break; + case 'c': + asyncWorkerCpuId = strtoul(optarg, nullptr, 0); + break; + case 'r': + multiRailEnable = (uint32_t)strtoul(optarg, nullptr, 0); + break; + case 'R': + multiRailThresh = (uint32_t)strtoul(optarg, nullptr, 0); + break; + case 'v': + rndvThreshold = (uint32_t)strtoul(optarg, NULL, 0); + break; + case 'u': + mUbcMode = std::stoi(optarg) ? UBSHcomUbcMode::HighBandwidth : UBSHcomUbcMode::LowLatency; + break; + case 'S': + splitSendThreshold = (uint32_t)strtoul(optarg, nullptr, 0); + break; + } + } + + Run(); + return 0; +} +``` + +2. 使用以下命令运行代码,启动Client端。 + +`./*pp_service_client_simple* -d 1 -i 127.0.0.1 -p 9980` + +*pp_service_client_simple*:编译后可执行文件名,请根据实际情况进行修改。 + +- -d:配置Driver类型。 + + 1. 0:RDMA + + 2. 1:TCP + + 3. 2:UDS + + 4. 3:SHM + + 5. 7:UBC + +- -i:Server IP地址。 + +- -p: 监听的端口 + +- 其余参数可以根据实际情况自由配置 + +## 传输层 + +### 说明 + +本章节将通过基础示例来演示如何使用UBS Comm,开发者可以通过学习此指导来快速上手UBS Comm。UBS Comm向开发者提供了传输层和服务层,因此使用指导也将分别提供一个示例代码来演示如何使用传输层和服务层。 + +### 服务端 + +1. 使用NetDriver::Instance创建一个Driver的对象。 + +`UBSHcomNetDriver \*driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "server1", true);` + +此处创建了一个使用RDMA协议的名为server1的服务端Driver。true代表启动监听线程,可以接受其他Driver对象的建链请求。 + +1. 设置NetDriverOptions选项,使用Driver对象注册回调函数,并用Driver的OobIpAndPort方法设置需要侦听的IP地址和端口。 + +```cpp +UBSHcomNetDriverOptions options {}; + +driver->RegisterNewEPHandler(std::bind(&NewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); +driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); +driver->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); +driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); +driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + +driver->OobIpAndPort(oobIp, oobPort); +``` + +- NetDriverOptions的参数,详情请参见《UBS-Comm-API-Spec.md》的“UBSHcomNetDriver::Initialize”章节。 + +- 注册回调函数,详情请参见《UBS-Comm-API-Spec.md》的“UBSHcomNetDriver::RegisterTLSCaCallback”章节和“TLSEraseKeypass函数类型”章节。 + +- OobIpAndPort用来设置需要侦听的IP地址和端口。 + + 1. 使用设置好的NetDriverOptions选项作为参数来调用Driver的Initialize方法,然后调用Driver的Start方法,完成服务端的启动。 + +```CPP +driver->Initialize(options); +driver->Start(); +``` + +----结束 + +### 客户端 + +1. 使用NetDriver::Instance创建一个Driver的对象。 + +`UBSHcomNetDriver *driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "client1", false);` + +第三个参数可以为false,因为客户端通常不需要被建链,无需启动监听线程。 + +1. 设置NetDriverOptions选项,使用Driver对象注册回调函数,并用Driver的OobIpAndPort方法设置需要建立连接的IP地址和端口。若不启动监听线程,则RegisterNewEPHandler可以不注册,但其它四个回调函数依旧需要注册。 + +``` +UBSHcomNetDriverOptions options {}; + +driver->RegisterNewEPHandler(std::bind(&NewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); +driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); +driver->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); +driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); +driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + +driver->OobIpAndPort(oobIp, oobPort); +``` + + + +2. 使用设置好的选项NetDriverOptions作为参数来调用Driver的Initialize方法,然后调用Driver的Start方法,完成客户端的启动。 + +``` +driver->Initialize(options); +driver->Start(); +``` + + + +----结束 + +### 服务端和客户端启动后 + +1. 当服务端和客户端都完成启动后,客户端的Driver可以调用Connect方法来连接服务端。 + +driver-\>Connect("hello world", ep, 0); + +- "hello world":连接时本端发送给对端的一条消息,对端可以在NewEndpoint回调函数的第三个参数中获取该消息。 + +- ep:Connect函数的返回值,即为得到链路的本端,服务端的NetEndpoint可以在NewEndpoint回调函数的第二个参数中获得。 + +- 0:链路类型。 + +  + +- 0:表示异步NetEndpoint。 + +- 1:表示同步NetEndpoint。 + +- 2:代表在RDMA协议中,同步'EventPoll的NetEndpoint。 + + 1. 连接完成后,客户端和服务端都会得到一个NetEndpoint对象,服务端和客户端都可以使用该对象来调用各种消息发送接口向对端发送消息。 + +UBSHcomNetTransRequest req((void \*)(data), sizeof(data), 0); +ep-\>PostSend(1, req); + +此处仅以PostSend为例,更多消息发送接口,请参见《UBS-Comm-API-Spec.md》的“UBSHcomNetEndpoint::PostSend”章节、“UBSHcomNetEndpoint::WaitCompletion”章节“UBSHcomNetEndpoint::PostSendRaw”章节和“UBSHcomNetEndpoint::ReceiveRawSgl”章节。 + +- 1:用户指定的opCode,取值范围0 ~ 1023。 + +- req:需要发送内容的结构体,结构体中的data为发送消息体。 + +----结束 + +### 传输层编程 + +此示例仅限帮助开发者具象化理解如何使用UBS Comm,作为实际使用场景的参考,请勿直接复制使用。 + +1. Sever端示例 + +以下为传输层Server端的完整示例代码。 + +1. 当Server端收到Client端的消息时,会调用初始化时注册的回调函数RequestReceived,可以在回调函数中给Client端回复消息。 + +``` +#include +#include +#include "hcom_service.h" + +using namespace ock::hcom; + +UBSHcomNetDriver *driver = nullptr; +UBSHcomNetEndpointPtr ep = nullptr; +using TestRegMrInfo = struct _reg_sgl_info_test_ { + uintptr_t lAddress = 0; + uint32_t lKey = 0; + uint32_t size = 0; +} __attribute__((packed)); +TestRegMrInfo localMrInfo[4]; +TestRegMrInfo remoteMrInfo[4]; + +std::string ipSeg = "192.168.100.0/24"; +std::string oobIp = ""; +uint16_t oobPort = 9980; +int16_t asyncWorkerCpuId = -1; + +UBSHcomNetDriverProtocol driverType = RDMA; +std::string udsName = "SHM_UDS"; +int32_t dataSize = 1024; +int32_t workerMode = 0; +void *data = nullptr; + +int NewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload << " id " << newEP->Id()); + ep = newEP; + return 0; +} + +void EndPointBroken(const UBSHcomNetEndpointPtr &netEp) +{ + NN_LOG_INFO("end point " << netEp->Id() << " is broken"); + if (ep != nullptr && netEp->Id() == ep->Id()) { + ep.Set(nullptr); + } +} + +UBSHcomNetTransSgeIov iovPtr[4]; +int RequestReceived(const UBSHcomNetRequestContext &ctx) +{ + int result = 0; + if (driverType == 1 || driverType == 2) { + if ((ctx.Header().opCode == 0) && (ctx.Header().flags == NTH_TWO_SIDE) && (ctx.Header().immData == 0)) { + goto postSend1; + } else if ((ctx.Header().opCode == 1) && (ctx.Header().flags == NTH_TWO_SIDE) && (ctx.Header().immData == 0)) { + goto postSend2; + } else if ((ctx.Header().seqNo == 1) && (ctx.Header().flags == NTH_TWO_SIDE) && (ctx.Header().immData == 1)) { + goto PostSendRaw; + } else if ((ctx.Header().seqNo == 2) && (ctx.Header().flags == NTH_TWO_SIDE_SGL)) { + goto PostSendRawSgl; + } + } + if (ctx.Header().opCode == 0) { + postSend1: + UBSHcomNetTransRequest rsp((void *)(localMrInfo), sizeof(localMrInfo), 0); + if ((result = ep->PostSend(0, rsp)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + return 0; + } else if (ctx.Header().opCode == 1) { + postSend2: + UBSHcomNetTransRequest req(data, dataSize, 0); + if ((result = ep->PostSend(1, req)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + return 0; + }else if (ctx.Header().seqNo == 1) { + PostSendRaw: + UBSHcomNetTransRequest req(data, dataSize, 0); + if ((result = ep->PostSendRaw(req, ctx.Header().seqNo)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + return 0; + } else if (ctx.Header().seqNo == 2) { + PostSendRawSgl: + UBSHcomNetTransSglRequest req(iovPtr, 4, 0); + if ((result = ep->PostSendRawSgl(req, ctx.Header().seqNo)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + return 0; + } + + return 0; +} + +int RequestPosted(const UBSHcomNetRequestContext &ctx) +{ + if (ctx.Result() != NN_OK) { + NN_LOG_ERROR("Post send err"); + } + NN_LOG_TRACE_INFO("RequestPosted"); + return 0; +} + +int OneSideDone(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("one side done"); + return 0; +} + +bool CreateDriver() +{ + if (driver != nullptr) { + NN_LOG_ERROR("driver already created"); + return false; + } + driver = UBSHcomNetDriver::Instance(driverType, "pp_transport_server", true); + if (driver == nullptr) { + NN_LOG_ERROR("failed to create driver already created"); + return false; + } + + UBSHcomNetDriverOptions options{}; + options.mode = static_cast(workerMode); + options.mrSendReceiveSegSize = dataSize * 4 + 32; + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + if (asyncWorkerCpuId != -1) { + std::string str = std::to_string(asyncWorkerCpuId) + "-" + std::to_string(asyncWorkerCpuId); + options.SetWorkerGroupsCpuSet(str); + NN_LOG_INFO("set cpuId " << options.WorkerGroupCpus()); + } + + if (driverType == ock::hcom::SHM) { + options.oobType = NET_OOB_UDS; + UBSHcomNetOobUDSListenerOptions listenOpt; + listenOpt.Name(udsName); + listenOpt.perm = 0; + driver->AddOobUdsOptions(listenOpt); + } + + options.SetNetDeviceIpMask(ipSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + driver->OobIpAndPort(oobIp, oobPort); + + driver->RegisterNewEPHandler( + std::bind(&NewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); + driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + + int result = 0; + if ((result = driver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize driver " << result); + return false; + } + NN_LOG_INFO("driver initialized"); + + if ((result = driver->Start()) != 0) { + NN_LOG_ERROR("failed to start driver " << result); + return false; + } + NN_LOG_INFO("driver started"); + + return true; +} + +bool RegSglMem() +{ + // write read + for (uint16_t i = 0; i < 4; i++) { + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(dataSize, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + localMrInfo[i].lAddress = mr->GetAddress(); + localMrInfo[i].lKey = mr->GetLKey(); + localMrInfo[i].size = dataSize; + memset(reinterpret_cast(localMrInfo[i].lAddress), 0, dataSize); + } + + // sendsgl + for (uint16_t i = 0; i < 4; i++) { + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(dataSize, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + + iovPtr[i].lAddress = mr->GetAddress(); + iovPtr[i].lKey = mr->GetLKey(); + iovPtr[i].size = dataSize; + memset(reinterpret_cast(iovPtr[i].lAddress), 0, dataSize); + } + + return true; +} + +void SendRequest() +{ + NN_LOG_INFO("input q means quit."); + while (true) { + auto tmpChar = getchar(); + switch (tmpChar) { + case 'q': + return; + default: + NN_LOG_INFO("input q means quit."); + continue; + } + } +} + +void Run() +{ + if (!CreateDriver()) { + return; + } + + if (!RegSglMem()) { + return; + } + + SendRequest(); +} + +int main(int argc, char *argv[]) +{ + struct option options[] = { + {"driver", required_argument, NULL, 'd'}, + {"ip", required_argument, NULL, 'i'}, + {"port", required_argument, NULL, 'p'}, + {"size", required_argument, NULL, 's'}, + {"worker Mode", required_argument, NULL, 'w'}, + {"worker Num", required_argument, NULL, 'n'}, + {"cpuId", required_argument, NULL, 'c'}, + {NULL, 0, NULL, 0}, + }; + + const char *usage = "usage\n" + " -d, --driver, driver type, 0 means rdma, 1 means tcp, 2 means uds, 3 means shm\n" + " -i, --ip, server ip mask, e.g. 10.175.118.1\n" + " -p, --port, server port, by default 9980\n" + " -s, --io size , max data size\n" + " -w, --worker mode, 0 means busy polling, 1 means event polling\n" + " -c, --cpuId, async worker\n"; + + int ret = 0; + int index = 0; + + if (argc != 13) { + printf("invalid param, %s, for example %s -d 0 -i rdma_nic_ip -p 9980 -s 1024 -w 0 -c 5\n", usage, argv[0]); + return -1; + } + + std::string str = "d:i:p:s:w:c:"; + while ((ret = getopt_long(argc, argv, str.c_str(), options, &index)) != -1) { + switch (ret) { + case 'd': + driverType = static_cast((uint16_t)strtoul(optarg, NULL, 0)); + if (driverType > UBC) { + printf("invalid driver type %d", driverType); + return -1; + } + break; + case 'i': + oobIp = optarg; + ipSeg = oobIp + "/24"; + break; + case 'p': + oobPort = (uint16_t)strtoul(optarg, NULL, 0); + break; + case 's': + dataSize = (int32_t)strtoul(optarg, NULL, 0); + break; + case 'w': + workerMode = (int32_t)strtoul(optarg, NULL, 0); + break; + case 'c': + asyncWorkerCpuId = strtoul(optarg, nullptr, 0); + break; + } + } + data = malloc(dataSize); + Run(); + free(data); + return 0; +} +``` + + + +1. 使用以下命令运行代码,启动Server端。 + +`./*pp_server* -i 127.0.0.1 -p 9980 -c -1` + +- *pp_server*:编译后可执行文件名,请根据实际情况进行修改。 + +- -i:Server IP地址。 + +- -p:Server端口号。 + +- -c:Worker绑定CPU,-1表示不绑核。 + +----结束 + +1. Client端示例 + +以下为传输层一个完整的Client端示例。 + +1. 初始化流程和Server端基本一致。入口为main函数,经过参数解析后,进入Run函数。 + +``` +#include +#include +#include "hcom_service.h" +#include "net_monotonic.h" +#include +#include + +using namespace ock::hcom; + +UBSHcomNetDriver *driver = nullptr; +UBSHcomNetEndpointPtr ep = nullptr; +std::string oobIp = ""; +uint16_t oobPort = 9980; +std::string ipSeg = "192.168.100.0/24"; +std::string dumpStr = ""; +std::string udsName = "SHM_UDS"; +UBSHcomNetDriverProtocol driverType = RDMA; +int32_t dataSize = 1024; +int16_t asyncWorkerCpuId = -1; +uint64_t mode = 0; +uint32_t flags = 0; +bool start = false; +uint64_t startTime = 0; +uint64_t finishTime = 0; + +using TestRegMrInfo = struct _reg_sgl_info_test_ { + uintptr_t lAddress = 0; + uint32_t lKey = 0; + uint32_t size = 0; +} __attribute__((packed)); +TestRegMrInfo localMrInfo[4]; +TestRegMrInfo remoteMrInfo[4]; +UBSHcomNetTransRequest iov[4]; +int32_t pingCount = 100000; +int32_t pingCount1 = 100000; +int seqNo = 1; +int workerMode = 0; +sem_t sem; +void* data = nullptr; +void printPerf() +{ + finishTime = MONOTONIC_TIME_NS(); + NN_LOG_INFO("Finished " << pingCount1 << " pingpong"<<" ,startTime:"<Id() << " broken"); + if (ep != nullptr) { + ep.Set(nullptr); + } +} + + +void SendRequest() +{ + int result = 0; + UBSHcomNetTransRequest req(data, dataSize, 0); + + if (pingCount-- == 0) { + printPerf(); + sem_post(&sem); + return; + } + if ((result = ep->PostSend(1, req)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return; + } +} + +void SyncSendRequest() +{ + int result = 0; + UBSHcomNetTransRequest req(data, dataSize, 0); + UBSHcomNetResponseContext respCtx{}; + startTime = MONOTONIC_TIME_NS(); + + uint32_t count = 0; + for (int32_t i = 0; i < pingCount; i++) { + count++; + if ((result = ep->PostSend(1, req)) != 0) { + if (result == 314) { + NN_LOG_ERROR("post message to data to server successfully,but fail to post message to client"); + return; + } + NN_LOG_ERROR("failed to post message to data to server"); + break; + } + if ((result = ep->WaitCompletion(2)) != 0) { + NN_LOG_ERROR("failed to get WaitCompletion, result " << result); + break; + } + + if ((result = ep->Receive(2, respCtx)) != 0) { + NN_LOG_ERROR("failed to get response, result " << result); + break; + } + } + printPerf(); + sem_post(&sem); + return; +} + +void SendRawRequest() +{ + int result = 0; + UBSHcomNetTransRequest req(data, dataSize, 0); + + if (pingCount-- == 0) { + printPerf(); + sem_post(&sem); + return; + } + if ((result = ep->PostSendRaw(req, 1)) != 0) { + NN_LOG_INFO("failed to post message to data to server, result " << result); + return; + } +} + +void SyncSendRawRequest() +{ + int result = 0; + UBSHcomNetTransRequest req(data, dataSize, 0); + UBSHcomNetResponseContext respCtx{}; + startTime = MONOTONIC_TIME_NS(); + + for (int32_t i = 0; i < pingCount; i++) { + if ((result = ep->PostSendRaw(req, 1)) != 0) { + if (result == 314) { + NN_LOG_ERROR("post message to data to server successfully,but fail to post message to client"); + return; + } + NN_LOG_ERROR("failed to post message to data to server"); + break; + } + if ((result = ep->WaitCompletion(2)) != 0) { + NN_LOG_ERROR("failed to get WaitCompletion, result " << result); + break; + } + + if ((result = ep->ReceiveRaw(2, respCtx)) != 0) { + NN_LOG_ERROR("failed to get response, result " << result); + break; + } + } + printPerf(); + sem_post(&sem); + return; +} +UBSHcomNetTransSgeIov iovPtr[4]; +void SendRawSglRequest() +{ + int result = 0; + UBSHcomNetTransSglRequest req(iovPtr, 4, 0); + + if (pingCount-- == 0) { + printPerf(); + sem_post(&sem); + return; + } + if ((result = ep->PostSendRawSgl(req, 2)) != 0) { + NN_LOG_INFO("failed to post message to data to server, result " << result); + return; + } +} + +void SyncSendRawSglRequest() +{ + int result = 0; + UBSHcomNetTransSglRequest req(iovPtr, 4, 0); + UBSHcomNetResponseContext respCtx{}; + startTime = MONOTONIC_TIME_NS(); + + for (int32_t i = 0; i < pingCount1; i++) { + if ((result = ep->PostSendRawSgl(req, 2)) != 0) { + if (result == 314) { + NN_LOG_ERROR("post message to data to server successfully,but fail to post message to client"); + return; + } + NN_LOG_ERROR("failed to post message to data to server"); + break; + } + if ((result = ep->WaitCompletion(2)) != 0) { + NN_LOG_ERROR("failed to get WaitCompletion, result " << result); + break; + } + + if ((result = ep->ReceiveRawSgl(respCtx)) != 0) { + NN_LOG_ERROR("failed to get response, result " << result); + break; + } + } + printPerf(); + sem_post(&sem); + return; +} + +void ReadRequest() +{ + for (int32_t i = 0; i < pingCount1; i++) { + if (ep->PostRead(iov[0]) != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + } +} + +void SyncReadRequest() +{ + for (int32_t i = 0; i < pingCount1; i++) { + if (ep->PostRead(iov[0]) != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + if (ep->WaitCompletion(-1) != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + } + printPerf(); + sem_post(&sem); + return; +} + +void ReadSglRequest() +{ + UBSHcomNetTransSgeIov segIov[4]; + for (uint16_t i = 0; i < 4; i++) { + segIov[i].lAddress = iov[i].lAddress; + segIov[i].rAddress = iov[i].rAddress; + segIov[i].lKey = iov[i].lKey; + segIov[i].rKey = iov[i].rKey; + segIov[i].size = iov[i].size; + } + UBSHcomNetTransSglRequest reqRead(segIov, 4, 0); + for (int i = 0; i < pingCount1; ++i) { + if (ep->PostRead(reqRead) != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + } +} + +void SyncReadSglRequest() +{ + UBSHcomNetTransSgeIov segIov[4]; + for (uint16_t i = 0; i < 4; i++) { + segIov[i].lAddress = iov[i].lAddress; + segIov[i].rAddress = iov[i].rAddress; + segIov[i].lKey = iov[i].lKey; + segIov[i].rKey = iov[i].rKey; + segIov[i].size = iov[i].size; + } + UBSHcomNetTransSglRequest reqRead(segIov, 4, 0); + for (int32_t i = 0; i < pingCount1; i++) { + if (ep->PostRead(reqRead) != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + if (ep->WaitCompletion(-1) != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + } + printPerf(); + sem_post(&sem); + return; +} + +void WriteRequest() +{ + for (int32_t i = 0; i < pingCount1; i++) { + if (ep->PostWrite(iov[0]) != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + } +} + +void SyncWriteRequest() +{ + for (int32_t i = 0; i < pingCount1; i++) { + if (ep->PostWrite(iov[0]) != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + if (ep->WaitCompletion(-1) != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + } + printPerf(); + sem_post(&sem); + return; +} + +void WriteSglRequest() +{ + UBSHcomNetTransSgeIov segIov[4]; + for (uint16_t i = 0; i < 4; i++) { + segIov[i].lAddress = iov[i].lAddress; + segIov[i].rAddress = iov[i].rAddress; + segIov[i].lKey = iov[i].lKey; + segIov[i].rKey = iov[i].rKey; + segIov[i].size = iov[i].size; + } + UBSHcomNetTransSglRequest reqRead(segIov, 4, 0); + for (int i = 0; i < pingCount1; ++i) { + if (ep->PostWrite(reqRead) != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + } +} + +void SyncWriteSglRequest() +{ + UBSHcomNetTransSgeIov segIov[4]; + for (uint16_t i = 0; i < 4; i++) { + segIov[i].lAddress = iov[i].lAddress; + segIov[i].rAddress = iov[i].rAddress; + segIov[i].lKey = iov[i].lKey; + segIov[i].rKey = iov[i].rKey; + segIov[i].size = iov[i].size; + } + UBSHcomNetTransSglRequest reqRead(segIov, 4, 0); + for (int32_t i = 0; i < pingCount1; i++) { + if (ep->PostWrite(reqRead) != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + if (ep->WaitCompletion(-1) != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + } + printPerf(); + sem_post(&sem); + return; +} + +int RequestReceived(const UBSHcomNetRequestContext &ctx) +{ + if (driverType == 1 || driverType == 2) { + if ((ctx.Header().opCode == 0) && (ctx.Header().flags == NTH_TWO_SIDE) && (ctx.Header().immData == 0)) { + goto postSend1; + } else if ((ctx.Header().opCode == 1) && (ctx.Header().flags == NTH_TWO_SIDE) && (ctx.Header().immData == 0)) { + goto postSend2; + } else if ((ctx.Header().seqNo == 1) && (ctx.Header().flags == NTH_TWO_SIDE) && (ctx.Header().immData == 1)) { + goto PostSendRaw; + } else if ((ctx.Header().seqNo == 2) && (ctx.Header().flags == NTH_TWO_SIDE_SGL)) { + goto PostSendRawSgl; + } + } + + if (ctx.Header().opCode == 0) { + postSend1: + memcpy(remoteMrInfo, ctx.Message()->Data(), ctx.Message()->DataLen()); + sem_post(&sem); + return 0; + }else if (ctx.Header().opCode == 1) { + postSend2: + SendRequest(); + return 0; + }else if (ctx.Header().seqNo == 1) { + PostSendRaw: + SendRawRequest(); + return 0; + } else if (ctx.Header().seqNo == 2) { + PostSendRawSgl: + SendRawSglRequest(); + } + return 0; +} + +int RequestPosted(const UBSHcomNetRequestContext &ctx) +{ + return 0; +} + +int OneSideDone(const UBSHcomNetRequestContext &ctx) +{ + if (--pingCount == 0) { + printPerf(); + sem_post(&sem); + } + return 0; +} + +void exitFunc() +{ + driver->Stop(); + driver->UnInitialize(); +} + +bool CreateDriver() +{ + if (driver != nullptr) { + NN_LOG_ERROR("driver already created"); + return false; + } + + driver = UBSHcomNetDriver::Instance(driverType, "transport_pp_client", false); + if (driver == nullptr) { + NN_LOG_ERROR("failed to create driver already created"); + return false; + } + + atexit(exitFunc); + UBSHcomNetDriverOptions options{}; + options.mode = static_cast(workerMode); + options.mrSendReceiveSegSize = dataSize * 4 + 32; + options.mrSendReceiveSegCount = 10; + if (mode == 1) { + options.dontStartWorkers = true; + } + if (driverType == SHM) { + options.oobType = NET_OOB_UDS; + } + if (asyncWorkerCpuId != -1) { + std::string str = std::to_string(asyncWorkerCpuId) + "-" + std::to_string(asyncWorkerCpuId); + options.SetWorkerGroupsCpuSet(str); + NN_LOG_INFO(" set cpuId: " << options.WorkerGroupCpus()); + } + options.SetNetDeviceIpMask(ipSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); + driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + + driver->OobIpAndPort(oobIp, oobPort); + int result = 0; + if ((result = driver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize driver " << result); + return false; + } + NN_LOG_INFO("driver initialized"); + + if ((result = driver->Start()) != 0) { + NN_LOG_ERROR("failed to start driver " << result); + return false; + } + NN_LOG_INFO("driver started"); + sem_init(&sem, 0, 0); + return true; +} + +bool Connect() +{ + if (driver == nullptr) { + NN_LOG_ERROR("driver is null"); + return false; + } + + int result = 0; + if (mode == 1) { + flags = NET_EP_SELF_POLLING; + } + + if (driverType == SHM) { + result = driver->Connect(udsName, 0, "hello server", ep, flags); + } else { + result = driver->Connect(oobIp, oobPort, "hello server", ep, flags); + } + + if (result != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + NN_LOG_INFO("success to connect to server, ep id " << ep->Id()); + return true; +} + +int userChar = 0; +int startTime1=0; +void RunInThread() +{ + while (!start) { + usleep(1); + } + pingCount = pingCount1; + startTime1 = MONOTONIC_TIME_NS(); + switch (userChar) { + case '0': + NN_LOG_INFO("Wait for finish, Type post send:"); + startTime = MONOTONIC_TIME_NS(); + NN_LOG_INFO("******startTime: "<Close(); + break; + default: + NN_LOG_INFO("input 0:send, 1:send raw, 2:send raw sgl, 3:read, 4:read sgl, 5:write " + "6:write sgl, q mean quit, c ep close"); + continue; + } + + if (userChar == 'c') { + continue; + } + } +} + +bool GetRemoteMr() +{ + int result = 0; + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + UBSHcomNetResponseContext respCtx{}; + if ((result = ep->PostSend(0, req)) != 0) { + NN_LOG_INFO("failed to post message to data to server"); + return false; + } + if (mode == 1) { + if ((result = ep->WaitCompletion(2)) != 0) { + NN_LOG_ERROR("failed to get WaitCompletion, result " << result); + return false; + } + + if ((result = ep->Receive(2, respCtx)) != 0) { + NN_LOG_ERROR("failed to get response, result " << result); + return false; + } + memcpy(remoteMrInfo, respCtx.Message()->Data(), respCtx.Message()->DataLen()); + sem_post(&sem); + } + + return true; +} + +bool RegSglMem() +{ + // write read + for (uint16_t i = 0; i < 4; i++) { + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(dataSize, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + localMrInfo[i].lAddress = mr->GetAddress(); + localMrInfo[i].lKey = mr->GetLKey(); + localMrInfo[i].size = dataSize; + memset(reinterpret_cast(localMrInfo[i].lAddress), 0, dataSize); + } + + // sendsgl + for (uint16_t i = 0; i < 4; i++) { + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(dataSize, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + + iovPtr[i].lAddress = mr->GetAddress(); + iovPtr[i].lKey = mr->GetLKey(); + iovPtr[i].size = dataSize; + memset(reinterpret_cast(iovPtr[i].lAddress), 0, dataSize); + } + + return true; +} + + +void Run() +{ + if (!CreateDriver()) { + return; + } + + if (!Connect()) { + return; + } + if (!RegSglMem()) { + return; + } + if (!GetRemoteMr()) { + return; + } + sem_wait(&sem); + + for (int i = 0; i < 4; ++i) { + iov[i].lAddress = localMrInfo[i].lAddress; + iov[i].rAddress = remoteMrInfo[i].lAddress; + iov[i].lKey = localMrInfo[i].lKey; + iov[i].rKey = remoteMrInfo[i].lKey; + iov[i].size = localMrInfo[i].size; + } + + Test(); +} + +int main(int argc, char *argv[]) +{ + struct option options[] = { + {"driver", required_argument, NULL, 'd'}, + {"ip", required_argument, NULL, 'i'}, + {"port", required_argument, NULL, 'p'}, + {"pingpongtimes", required_argument, NULL, 't'}, + {"size", required_argument, NULL, 's'}, + {"epMode", required_argument, NULL, 'm'}, + {"workerMode", required_argument, NULL, 'w'}, + {"cpuId", required_argument, NULL, 'c'}, + {NULL, 0, NULL, 0}, + }; + + const char *usage = "usage\n" + " -d, --driver, driver type, 0 means rdma, 1 means tcp, 2 means uds, 3 means shm\n" + " -i, --ip, coord server ip mask, e.g. 10.175.118.1\n" + " -p, --port, coord server port, by default 9980\n" + " -t, --pingpongtimes, ping pong times\n" + " -s, --size, max data size\n" + " -m, --ep mode, 0 means worker polling(Async), 1 means self polling(Sync)\n" + " -w, --worker mode, 0 means busy polling, 1 means event polling()\n" + " -c, --cpuId, async worker\n"; + + int ret = 0; + int index = 0; + + if (argc != 17) { + printf("invalid param, %s, for example %s -d 0 -i rdma_nic_ip -p 9980 -t 1000000 -s 1024 -m 0 -w 1 -c 5\n", + usage, argv[0]); + return -1; + } + + std::string str = "d:i:p:t:s:m:w:c:"; + while ((ret = getopt_long(argc, argv, str.c_str(), options, &index)) != -1) { + switch (ret) { + case 'd': + driverType = static_cast((uint16_t)strtoul(optarg, NULL, 0)); + if (driverType > UBC) { + printf("invalid driver type %d", driverType); + return -1; + } + break; + case 'i': + oobIp = optarg; + ipSeg = oobIp + "/24"; + break; + case 'p': + oobPort = (uint16_t)strtoul(optarg, NULL, 0); + break; + case 't': + pingCount = (int32_t)strtoul(optarg, NULL, 0); + break; + case 's': + dataSize = (int32_t)strtoul(optarg, NULL, 0); + break; + case 'm': + mode = (uint64_t)strtoul(optarg, NULL, 0); + break; + case 'w': + workerMode = (uint64_t)strtoul(optarg, NULL, 0); + break; + case 'c': + asyncWorkerCpuId = strtoul(optarg, nullptr, 0); + break; + } + } + data = malloc(dataSize); + pingCount1 = pingCount; + Run(); + free(data); + return 0; +} +``` + +使用以下命令运行代码,启动Client端。 + +`./*pp_client* -i 127.0.0.1 -p 9980 -t 10000 -c -1` + +- *pp_client*:编译后可执行文件名,请根据实际情况进行修改。 + +- -i:Server端IP地址,127.0.0.1。 + +- -p:Server端口号 + +- -t:Pingpong次数。 + +- -c:Worker绑定CPU,-1表示不绑核。 + +----结束 + +- + +# 安全管理 + +1. 推荐环境变量配置 + +环境变量配置,请参见《UBS-COMM-API-Spec》的“环境变量参考”章节。 + +2. 防病毒软件例行检查 + +定期开展对集群和UBS Comm组件的防病毒扫描是十分必要的,防病毒软件例行检查会帮助集群免受病毒、恶意代码、间谍软件以及恶意程序侵害,减少系统瘫痪、信息泄露等安全风险。可以使用业界主流的防病毒软件进行防病毒检查。 + +3. 漏洞修复 + +为保证环境安全,降低被攻击的风险,请开启防火墙,并定期修复以下漏洞。 + +- 操作系统漏洞 + +- rdma-core漏洞 + +- OpenSSL漏洞 + +- 其他相关组件漏洞 + + 1. 证书安全管理 + +- 需使用X509v3格式的证书,并使用安全的证书签名算法。 + +- 证书应设置合理的有效期,允许华为设备预置证书的有效期略长于产品生命周期。 + +- 证书的私钥要使用基于口令的加密机制保存,私钥保护口令应满足复杂度要求并加密保存,同时控制私钥文件和证书文件的访问权限。 + +- 必须验证对端证书的有效性,必须验证项包括对端证书是否由受信根CA签发、是否在有效期内、是否已被吊销。 + +- 使用安全随机数生成密钥对,且必须使用至少2048位,推荐使用3072位的RSA密钥对(第三方CA签发证书、与第三方系统对接、兼容老版本等场景可例外)。 + +- 在使用数字证书进行内层软件完整性保护时,必须防止用于验证软件完整性的根证书被篡改。 + +#### 安全申明 + +对于UB通信建链方式,做出如下安全申明: + +- 推荐使用以太网卡+TCP方式建链,默认开启TLS安全认证 + +- 不依赖以太网卡的公知jetty建链方式当前不支持TLS安全认证,后续通过补丁版本支持IPoverURMA解决安全认证问题。 + + 1. 无属主文件安全加固 + +用户可以执行find / -nouser -nogroup命令,查找容器内或物理机上的无属主文件。根据文件的UID和GID创建相应的用户和用户组,或者修改已有用户的UID、用户组的GID来适配,赋予文件属主,避免无属主文件给系统带来安全隐患。 + +# UBS Comm库文件链接方法参考 + +UBS Comm以头文件和库文件的形式提供给开发者集成和使用,开发者可以根据自己的实际项目需要选择使用动态库或静态库。 + +\# 链接动态库 +`gcc -o \<输出文件名称\> \<被链接的文件\> -L\<动态库路径\> -lhcom -lstdc++ -I\ ` +\# 链接静态库 +`gcc -o \<输出文件名称\> \<被链接的文件\> -L\<静态库路径\> -lhcom_static -lm -lstdc++ -I\` + +静态库在编译期就已经被链接到可执行文件中,无需像动态库一样在运行期加载,故执行效率更高。但静态库会增加可执行文件大小,多个程序同时使用同一静态库时,会造成存储资源浪费。另外库文件更新时,使用动态库场景可以仅更新动态库文件,使用静态库场景必须重新编译应用程序。 + +# 公网地址声明 + +以下表格中列出了当前产品中包含的公网地址,不涉及安全风险。 + +| 网址 | 说明 | +|----|----| +| https://gcc.gnu.org/bugs/ | 该网址为开源软件GCC编译引入,无安全风险。 | +| http://license.coscl.org.cn/MulanPSL2 | 该网址为版权声明license,无安全风险 | + +# 术语和缩略语 + +| 缩略语 | 英文全称 | **说明** | +|----|----|----| +| CQ | Completion Queue | 完成队列。 | +| CRC | Cyclic Redundancy Code | 循环冗余码,一种线性检错码,通过多项式除法的余数来生成奇偶校验位。 | +| CNP | Congestion Notification Packet | 拥塞通知报文。 | +| DM | Device Memory | 设备内存。 | +| DSCP | Differentiated Services Code Point | 区分服务编码点,根据Diff-Serv(Differentiated Service)的QoS分类标准,在每个数据包IP头部的服务类别TOS字节中,利用已使用的6比特和未使用的2比特,通过编码值来区分优先级。DSCP是TOS字节中已使用6比特的标识,是“IP优先”和“服务类型”字段的组合。为了利用只支持“IP优先”的旧路由器,会使用DSCP值,因为DSCP值与“IP优先”字段兼容。每一个DSCP编码值都被映射到一个已定义的PHB(Per-Hop-Behavior)标识码。通过键入DSCP值,终端设备可对流量进行标识。 | +| DCQCN | Data Center Quantized Congestion Notification | 数据中心网络的拥塞控制算法。 | +| EP | Endpoint | 端点,数据信源和数据信宿,是运行在物理链路上的虚拟链路。 | +| MR | Memory Region | 内存区域。 | +| QP | Queue Pair | 队列对。 | +| RDMA | Remote Direct Memory Access | 远程直接存储器访问,从一台计算机的存储器直接访问另一台计算机的存储器的技术。它使得网卡能够直接访问应用存储器,支持零拷贝网络通信。 | +| RPC | Remote Procedure Call | 远程过程调用,是一个计算机通信协议。该协议允许运行于一台计算机的程序调用另一台计算机的子程序,而程序员无需额外地为这个交互作用编程。如果涉及的软件采用面向对象编程,那么远程过程调用亦可称作远程调用或远程方法调用。 | +| SHM | Shared Memory | 共享内存,在计算机硬件中,共享内存通常是指大量的无序访问内存。该内存能够被多处理器电脑系统的多个不同的CPU访问。 | +| TCP | Transmission Control Protocol | 传输控制协议,TCP/IP中的协议,用于将数据信息分解成信息包,使之经过IP协议发送;并对利用IP协议接收来的信息包进行校验并将其重新装配成完整的信息。TCP是面向连接的可靠协议,能够确保信息的无误发送,它与ISO/OSI基准模型中的传输层相对应。 | +| UDS | Unix Domain Socket | UNIX域套接字,是一种在同一台计算机上的进程间通信机制。 | +| UBC | Unified Bus Clan | UBC协议。 | diff --git a/doc/UBS-Comm-Tutorial-UseCase.md b/doc/UBS-Comm-Tutorial-UseCase.md new file mode 100644 index 0000000000000000000000000000000000000000..fb690bcba706020670231a546cdd2bd0ed95ed45 --- /dev/null +++ b/doc/UBS-Comm-Tutorial-UseCase.md @@ -0,0 +1,260 @@ +| | | | | | +|:--:|:--:|:--:|:--:|:--:| +| | | | | | +| | **UBS Comm** **UseCase使用场景描述** | | | | +| | **文档版本** | **1** | | | +| | **发布日期** | **2025-09-30** | | | +| ![](media/image1.png) | | | | | +| | 华为技术有限公司 | | ![](media/image2.png) | | + +[TABLE] + +| 华为技术有限公司 | | +|------------------|---------------------------------------------| +| 地址: | 深圳市龙岗区坂田华为总部办公楼 邮编:518129 | +| 网址: | | +| 客户服务邮箱: | | +| 客户服务电话: | 4008302118 | + +[TABLE] + +# 前言 + +## 概述 + +本文档详细地描述了UBS Comm的整体架构、应用场景、关键特性等信息。 + +## 读者对象 + +本文档主要适用于售前技术支持工程师。 + +## 符号约定 + +在本文中可能出现下列标志,它们所代表的含义如下。 + +[TABLE] + +## + +# 目 录 + +[前言 [iii](#前言)](#前言) + +[1 产品定位和亮点 [1](#产品定位和亮点)](#产品定位和亮点) + +[1.1 产品定位 [1](#产品定位)](#产品定位) + +[1.2 产品亮点 [1](#产品亮点)](#产品亮点) + +[1.2.1 说明 [1](#说明)](#说明) + +[1.2.2 高性能 [1](#高性能)](#高性能) + +[1.2.3 易集成 [1](#易集成)](#易集成) + +[1.2.4 可靠性 [1](#可靠性)](#可靠性) + +[2 典型应用场景 [2](#典型应用场景)](#典型应用场景) + +[2.1 数据库场景 [2](#数据库场景)](#数据库场景) + +[2.1.1 场景介绍 [2](#场景介绍)](#场景介绍) + +[2.2 HPC场景 [3](#hpc场景)](#hpc场景) + +[2.2.1 场景介绍 [3](#场景介绍-1)](#场景介绍-1) + +[2.3 对接X交易所场景 [3](#对接x交易所场景)](#对接x交易所场景) + +[2.3.1 场景介绍 [3](#场景介绍-2)](#场景介绍-2) + +[3 特性和功能 [4](#特性和功能)](#特性和功能) + +[3.1 传输层特性 [4](#传输层特性)](#传输层特性) + +[3.1.1 客户价值 [4](#客户价值)](#客户价值) + +[3.1.2 场景举例 [4](#场景举例)](#场景举例) + +[3.1.3 功能说明 [4](#功能说明)](#功能说明) + +[3.2 服务层特性 [4](#服务层特性)](#服务层特性) + +[3.2.1 客户价值 [4](#客户价值-1)](#客户价值-1) + +[3.2.2 场景举例 [5](#场景举例-1)](#场景举例-1) + +[3.2.3 功能说明 [5](#功能说明-1)](#功能说明-1) + +[A 特性规格清单 [6](#特性规格清单)](#特性规格清单) + +[B 术语 [8](#术语)](#术语) + +# 产品定位和亮点 + +[1.1 产品定位](#产品定位) + +[1.2 产品亮点](#产品亮点) + +## 产品定位 + +UBS Comm(UB service communication)是一个适用于高带宽和低延迟网络C/S(Client/Server)架构应用程序的高性能通信框架。 + +## 产品亮点 + +### 说明 + +UBS Comm旨在提供一组支持各种协议的高级API(Application Programming Interface),并屏蔽了包括RDMA(Remote Direct Memory Access)、TCP(Transmission Control Protocol)、UDS(Unix Domain Socket)、SHM(Shared Memory)、UBC(Unified Bus Clan)等低级API的复杂性与差异性,同时尽可能发挥硬件能力,以保证其拥有高性能。 + +### 高性能 + +UBS Comm提供点对点消息Send/Receive、Read/Write的单双边通信接口,可使用UBC、RDMA协议进行高性能通信。具体场景参见[典型应用场景](#典型应用场景)。 + +### 易集成 + +- 支持多语言(C/C++)API。 + +- 支持多种协议通信(RDMA/TCP/UDS/SHM/UBC)。 + +### 可靠性 + +提供高可靠的通信传输能力,支持故障检测消息重传,包括超时检测、等待、重传。 + +# 典型应用场景 + +[3.1 数据库场景](#数据库场景) + +[3.2 HPC场景](#hpc场景) + +[3.3 对接X交易所场景](#对接x交易所场景) + +## 数据库场景 + +### 场景介绍 + +UWAL的Client和Server对接了UBS Comm使用service层接口完成RDMA和TCP协议通信。场景应用如[图3-1](#fig1571175417276)所示。 + +1. 数据库场景典型应用 + +![img](file:///C:/Users/Y00835~1/AppData/Local/Temp/msohtmlclip1/01/clip_image002.jpg) + +在数据库场景中,openGauss中UWAL模块借助UBS Comm极致数据传输能力,TPC-C tmpC性能提升12.8%。 + +## HPC场景 + +### 场景介绍 + +SDK和Daemon进程的Cache组件使用了UBS Comm的SHM协议,MF组件节点内通信使用的UDS协议,MF节点间通信使用的TCP协议。场景应用如[图3-2](#fig768110310817)所示。 + +1. HPC场景应用 + +![img](file:///C:/Users/Y00835~1/AppData/Local/Temp/msohtmlclip1/01/clip_image002.jpg) + +HPC场景中,IO缓存采用UBS Comm读写效率提升30%。 + +## 对接X交易所场景 + +### 场景介绍 + +X交易系统中,使用UBS Comm的Transport层C++接口层进行RDMA通信。 + +1. X交易所对接 + +![img](file:///C:/Users/Y00835~1/AppData/Local/Temp/msohtmlclip1/01/clip_image002.jpg) + +X交易所对接场景中,基于MLX5网卡,使用RDMA协议通信,实现256B小包单向时延不高于1.5us。 + +# 特性和功能 + +[4.1 传输层特性](#传输层特性) + +[4.2 服务层特性](#服务层特性) + +## 传输层特性 + +### 客户价值 + +提供多种协议(RDMA/TCP/UDS/SHM/UBC)点对点消息Send/Receive双边通信接口、Read/Write单边通信接口。 + +### 场景举例 + +对接X交易所场景。 + +### 功能说明 + +传输层特性功能说明如下: + +- 支持多种协议(RDMA/TCP/UDS/SHM/UBC)。 + +- 点对点消息Send/Receive双边通信,Read/Write单边通信。 + +- 支持多种算法的加密通信。 + +- 支持保活功能。 + +## 服务层特性 + +### 客户价值 + +提供双向服务层API接口,提供流量控制、多语言(C/C++)API、MULTIRAIL(多端口)、RNDV等高级功能。 + +### 场景举例 + +数据库场景。 + +### 功能说明 + +- 支持多种协议(RDMA/TCP/UDS/SHM) + +- 支持点对点消息Send/Receive双边通信,Read/Write单边通信。 + +- 支持多种加密算法的认证和加密通信。 + +- 支持保活功能。 + +- 支持流量控制功能。 + +- 支持多语言(C/C++)API。 + +- 支持MULTIRAIL功能。 + +- 支持RNDV功能。 + +# 特性规格清单 + +2. 特性规格清单 + +| 特性 | 子特性/规格 | 特性/规格描述 | +|--------|----------------|-------------------------------------------------| +| 传输层 | RDMA | 支持配置RDMA通信功能,使用RDMA协议通信。 | +| | TCP | 支持配置TCP通信功能,使用TCP协议通信。 | +| | UDS | 支持配置UDS通信功能,使用UDS协议通信。 | +| | SHM | 支持配置SHM通信功能,使用SHM通信。 | +| | UBC | 支持配置UBC通信功能,使用UBC通信。 | +| | 双边通信 | 支持使用双边通信接口,进行双边通信。 | +| | 单边通信 | 支持使用单边通信接口,进行单边通信。 | +| | 加密认证和通信 | 支持使能加密功能,进行加密认证和通信。 | +| | 保活 | 默认开启保活功能。 | +| 服务层 | RDMA | 支持配置RDMA通信功能,使用RDMA协议通信。 | +| | TCP | 支持配置TCP通信功能,使用TCP协议通信。 | +| | UDS | 支持配置UDS通信功能,使用UDS协议通信。 | +| | SHM | 支持配置SHM通信功能,使用SHM通信。 | +| | UBC | 支持配置UBC通信功能,使用UBC通信。 | +| | 双边通信 | 支持使用双边通信接口,进行双边通信。 | +| | 单边通信 | 支持使用单边通信接口,进行单边通信。 | +| | 加密认证和通信 | 支持使能加密功能,进行加密认证和通信。 | +| | 保活 | 默认开启保活功能。 | +| | RNDV | 支持使能RNDV协议,进行单边+双边结合的方式通信。 | +| | MULTIRAIL | 支持使能MULTIRAIL功能,RDMA多网口带宽聚合通信。 | + +# 术语 + +| 缩略语 | 英文全称 | 中文名称 | +|-----------|-------------------------------|--------------------| +| RDMA | Remote direct memory access | 远端内存直接访问。 | +| TCP | Transmission Control Protocol | 传输控制协议。 | +| UDS | Unix Domain Socket | Unix域套接字。 | +| SHM | Shared Memory | 共享内存。 | +| UB-C | Unified bus clan | UB-C协议 | +| RNDV | Rendezvous | Rendezvous协议。 | +| MULTIRAIL | multi rail | 多网口。 | diff --git "a/doc/release\350\257\264\346\230\216_20200930.md" "b/doc/release\350\257\264\346\230\216_20200930.md" new file mode 100644 index 0000000000000000000000000000000000000000..bbef8f975416257599936e68354a8418890c2afe --- /dev/null +++ "b/doc/release\350\257\264\346\230\216_20200930.md" @@ -0,0 +1,26 @@ +## Release notes: +### HCOM 22.0.0 +Date: 2022/09/30 +Summary: HCOM is an easy to use, high performance library for various hardware including RoCE/IB, Eth, UB etc. This is the first RC release for RDMA protocol, which hides all the complexities of RoCE and IB etc, and also provides high performance. +Major features: +- RDMA two side operation with/without opcode +- RDMA one side operation including memory region registering +- RDMA Endpoint establishing with OOB socket +- Multiple threads support including groups +- Busy polling and event polling support +- Cpu binding support +- Self polling endpoint support +- Dynamically load verbs library +- OOB connection support with both TCP and UDS +- Multiple OOB listeners support +- Load balance support including Round-Robin and Hash policy +- Heartbeat support for RDMA +- Connection version support +- Data crypt for two side operation, OOB uses TLS 1.3 and AES_128_GCM for data crypt over RDMA +- Dynamically load openssl library, 1.1.1f and later version +- Two side TLS verification support +- External log function support +- Both C and C++ API support +- Providing includes files, .a and .so library + +Note: This is a limited release to DCS for internal integration only, not suitable for production and PoC externally. \ No newline at end of file diff --git "a/doc/\344\273\243\347\240\201\346\236\266\346\236\204\350\256\276\350\256\241.md" "b/doc/\344\273\243\347\240\201\346\236\266\346\236\204\350\256\276\350\256\241.md" new file mode 100644 index 0000000000000000000000000000000000000000..cfb73b044fdd1bf7f833aba9dcaca3ecd46bdd8c --- /dev/null +++ "b/doc/\344\273\243\347\240\201\346\236\266\346\236\204\350\256\276\350\256\241.md" @@ -0,0 +1,50 @@ +### code structure + +#### 3 layers: +Layer 1: API layer +``` +file all start with hcom + +include cxx api and c api + +cxx api including +1 net related things +2 obj pool +3 ring buffer +4 lockless ring buffer +5 execution service +6 blocking ring buffer + +c api only include +1 net related thing +2 and ring buffer and blocking queue with spinlock + +``` + +Layer 2: logic layer and adaptive layer +``` +which are plug-in for api layer +and also calling wrapper layer + +net_rdma_* is plug-in for rdma + +net_tcp_* is plug-in for tcp +``` + +Layer 3: Wrapper layer +``` +which is just simple wrapper raw api + +for example: + +rdma related wrappers all start with rdma_ + +rdma_verbs_wrapper.* wrappers all rdma related functions + +rdma_worker.* worker with polling thread + +rdma_mr_pool.* memory region related + +rdma_composed_endpoint.* endpoints for async/sync/semi-sync + +``` \ No newline at end of file diff --git "a/doc/\345\256\211\345\205\250\346\240\241\351\252\214\350\241\214\344\270\272.md" "b/doc/\345\256\211\345\205\250\346\240\241\351\252\214\350\241\214\344\270\272.md" new file mode 100644 index 0000000000000000000000000000000000000000..853510826d295a928a5842f5f3def9e2a0709d55 --- /dev/null +++ "b/doc/\345\256\211\345\205\250\346\240\241\351\252\214\350\241\214\344\270\272.md" @@ -0,0 +1,98 @@ +# HCOM +HCOM is a high performance communication library for C/S applications. +* Easy to use, HCOM provides high level APIs instead of difficult low RDMA etc +* High performance, expose hardware capability as much as possible, and also optimization of domain oriented applications +* Various protocol supports, HCOM hidden the complex of low level API, RDMA/TCP/UDS/Shm/URMA etc + +#### 1 how to clone +``` +1 git clone repo +2 git submodule update --init --recursive +``` + +#### 2 how to build +``` +cd ${TOP_DIR}/test +sh build.sh + +cd ${TOP_DIR} +mkdir build +cd build + +#if gcc version 4.8.5 +sh adapter_script.sh + +cmake -DCMAKE_DEPENDS_USE_COMPILER=false -DCMAKE_BUILD_TYPE=release .. +make -j8 + +# more flags +# -DBUILD_TESTS=off for disable building test +# -DBUILD_EXAMPLE=off for disable building example +# -DBUILD_PERF=off for disable building perf +# -DBUILD_JAVA_SDK=off for disable building java code + +ll + +# you will see server and client binary under current dir, include libhcom.so and libhcom_static.a. +# note: if you want to link libhcom_static.a, make sure libsecurec.a exist in your project. + +more cmake options: +1 CMAKE_BUILD_TYPE: release|debug, release binary or debug binary + default: release + +2 NN_LOG_TRACE_INFO_ENABLED: enable trace log + default: disabled + +3 USE_PROCESS_MONOTONIC: use CPU instruction for fast timestamp, need to change it in CMakelists.txt + default: enabled + +4 ENABLE_OBJ_GLOBAL_STATISTICS: enable object statistic, need to change it in CMakelists.txt + default: enabled + +5 CMAKE_INSTALL_PREFIX, CMAKE built-in options, to set make install target folder + default: system + +``` + + +#### 3 how to execute UT cases +``` +export HCOM_BUILD_TYPE=debug +export HCOM_BUILD_TESTS=on +./build.sh +./build/generate_gtest_report.sh +``` + +#### 4 how to run examples +``` + +``` + +#### 5 how to run perf +``` + +``` + +#### 6 secure info behavior table +单向校验行为表: + +| 场景case | OOB client 注册provider否 | OOB client provider返回有效否 | OOB Server 注册validator否 | OOB Server validator返回无效否 | 内部行为发送header内容 | 用户可见行为 | +|:------:|--------------------------|----------------------------|---------------------------|-----------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------| +| 1 | Y | N | 任意组合(4种) |   |   | 1.OOB client终止connect, 返回错误NN_OOB_SEC_PROCESS_ERROR(141) 2.OOB server若有validator,则报错reset by peer, 返回错误码NN_OOB_CONN_RECEIVE_ERROR(126) | +| 2 | Y | Y | N | -(2种) | 1.OOB client发送header(flag, secInfoLen +SEC_VALID_ONE_WAY) + secInfo 2.OOB server接受header + secInfo,不进行校验,直接验证通过并返回respons(OK) 3.OOB client接受response(OK),认证通过 | 1.OOB Server通过, 打印WARNING message | +| 3 | Y | Y | Y | N | 1.OOB client发送header(flag, secInfoLen +SEC_VALID_ONE_WAY) + secInfo 2.OOB server接受header + secInfo,进行校验,校验失败,直接验证通过并返回respons(SEC_VALID_ERROR) (对面被reset by peer了) 3.OOB client接受response(SEC_VALID_ERROR), 认证失败 | 1.OOB Server失败,打印Error日志写明具体validator的错误值 2.OOB Client失败,打印Error日志,Response=SEC_VALID_FAILED(-9) | +| 4 | Y | Y | Y | Y | 1.OOB client发送header(flag, secInfoLen +SEC_VALID_ONE_WAY) + secInfo 2.OOB server接受header + secInfo,校验通过,直接验证通过并返回respons(OK) 3.OOB client接受response(OK),认证通过 | 1.OOB Server通过 2.OOB Client通过 | +| 5 | N | -(2种) | Y | Y | 1.OOB client发送header(0,0,NO_VALID) 2.OOB server收header,校验type失败,返回resp(SEC_VALID_FAILED) 3.OOB Client收的到resp(SEC_VALID_FAILED) | 1.OOB Server失败,打印Error 2.OOB Client失败,打印Error日志,Response=SEC_VALID_ERROR(-9) | +| 6 | N | -(2种) | Y | N | 同上 | 同上 | +| 7 | N | -(2种) | N | -(2种) | 1.OOB client发送header(0,0,NO_VALID) 2.OOB server收header,校验type为NO_VAILD并且没有注册validator,校验通过 | 1.OOB Server通过 | + +双向校验行为表: + +| 场景case | OOB Server 注册 provider否 | OOB Server provider返回有效否 | OOB Client 注册validator否 | OOB Client validator返回无效否 | 内部行为发送header内容 | 用户可见行为 | +|--------|---------------------------|-----------------------------|---------------------------|-----------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------| +| 8 | Y | N | N | - (2种) | 0. 前提条件,client已经注册了provider,否则参考表1(单项场景) 1.OOB client调用provider,获取secType = SEC_VALID_TWO_WAY, 校验没有validator直接打印并返回错误 2.OOB server, reset by peer | OOB client终止connect, 返回错误NN_OOB_SEC_PROCESS_ERROR(141) | +| 9 | Y | N | Y | -(2种) | 1.OOB server provider 返回无效,打印NN_OOB_SEC_PROCESS_ERROR,并返回错误 2.OOB client打印reset by peer,返回错误 | OOB client打印返回错误NN_OOB_SEC_PROCESS_ERROR(141) | +| 10 | Y | Y | N | -(2种) | 0. 前提条件,client已经注册了provider,否则参考表1(单项场景) 1. OOB client调用provider,获取secType = SEC_VALID_TWO_WAY, 校验没有validator直接打印并返回错误 | OOB client终止connect, 返回错误NN_OOB_SEC_PROCESS_ERROR(141) | +| 11 | Y | Y | Y | N | 1.OOB server发送header(flag, secInfoLen +SEC_VALID_TWO_WAY+errCode(0)) + secInfo 2.OOB client接受header,进行validate失败,打印错误并返回 3.OOB server收到reset by peer, 打印错误并返回 | 1.OOB Server失败,打印reset by peer 2.OOB client打印返回错误NN_OOB_SEC_PROCESS_ERROR(141) | +| 12 | Y | Y | Y | Y | 1.OOB server发送header(flag, secInfoLen +SEC_VALID_TWO_WAY+errCode(0)) + secInfo 2.OOB client接受header,进行validate成功,发送Response(OK) 3.OOB server收到response(OK), 校验成功 | 1.OOB Server通过 2.OOB Client通过 | +| 13 | N | -(8种) | | | 1.OOB server 直接返回错误 2.OOB client 端收到reset by peer,打印错误并返回 | 1.OOB Server失败,提示未注册provider, 打印Error 2.OOB client打印返回错误NN_OOB_SEC_PROCESS_ERROR(141) | \ No newline at end of file diff --git "a/doc/\346\234\215\345\212\241\345\261\202\350\256\276\350\256\241\345\217\212\345\256\236\347\216\260FAQ.md" "b/doc/\346\234\215\345\212\241\345\261\202\350\256\276\350\256\241\345\217\212\345\256\236\347\216\260FAQ.md" new file mode 100644 index 0000000000000000000000000000000000000000..c0747891565f9447e440582680bd335f46147517 --- /dev/null +++ "b/doc/\346\234\215\345\212\241\345\261\202\350\256\276\350\256\241\345\217\212\345\256\236\347\216\260FAQ.md" @@ -0,0 +1,94 @@ +### service layer includes +- multiple rails +- RR mode +- QoS +- etc +1 多平面 p0 +- 单性多通道 (比如数据面的通道有n个通道, 时序问题) +- 多性质通道 (控制面消息面 数据消息面) +- 以整体通道channel +- 多卡 (p2) +``` +FAQ1: 时序问题? +- API不保证数据在多通道到上的时序问题,由上层保证;有时序要求的消息,下一命令需要等上一个完成 +- SendAfterWrite() 要特殊考虑 + +FAQ2: 通道选择问题? +- Round-Robin 默认 + +FAQ3: 建链? +- 所有数量成功后才成功 +- 使用过程中部分链路断链后行为, 有3种行为; 由上层配置模式 + - 断开其他的,且通知上层链路断开 + - 重连已断开的,a) 如果重连成功, 打印message, 上层不感知; b) 如果重连失败, 打印error msg, 通知上层, 且关闭其他链接 + - 重连但不成功, 保留剩余的链接; 打印message 且通知上层 + +FAQ4: 多性质多通道? +- 单性质多通道, 为一个channel, 一个channel有多个ep +- 多性质多通道, 为多个channel, 一次创建多个, 返回channel数据; + 如 + NetService::Instance() + NetService::Start()/Stop() + NetService::CreateChannels() + NetChannelPtr chs[2], NetChannelPtr ctrlCh = chs[0]; NetChannelPtr dataCh = chs[1]; + +FAQ5: callback register +- 第1种, channel new/broken, 为service级 +- 第2种, idle callback, 为service级别 +- 第3种, 收到消息 opCode级别, 为service级别; 2种模式, 可以为统一的, 也可以为per opCode, 但二者不能共存 +Notes: 注册函数不可动态修改 + +FAQ6: channel中多ep如何分布在worker上? +- RR策略 +- Hash策略 +- 所有ep在同一个worker上? (P1) + +FAQ7: 不同的channl要不要 workerGroups? +- 创建service和创建channels指定参数 + +``` + +2 ReqResp模式 p0 +- 双向 +- 只有Req +- 同步, CV/sem + self polling +- 异步, 函数级别callback? + +``` +FAQ1: one way +- 暴露一个api + +FAQ2: 异步RR, 什么级别的callback? +- 2种, per call 和 per channel + case1: 如果有per call, 使用per call + case2: 如果没有per call, 有per channel, 使用per channel的 + case3: 都没有报错 + +FAQ3: 同步API有2种模式, +- 非self polling的, CV/SEM, 细节待定 +- self polling的 (p1,部分细节再讨论) + case 1: 多ep per channel + = rdma需要把多个qp放到一个cq + = 需要把多fd放到一个poll + case 2: 1 ep per channel, 直接使用 (p0) + +FAQ4: timeout +``` + +3 QoS: +给上层一个确定的结果,无法恢复了,屏蔽闪断的情况 +超时 p0 +重连 p0 +重发 p0 + +反压 p1, 给上层调用者一个反馈 + +4 IO优先级 (P2) + +5 单/双边操作的内存分配 (p1, transport层) +- 大页注册 +- 较大块内存的不同size的分配 +- 可动态增大 + +6 callback per opcode (p0) \ No newline at end of file diff --git a/hcom.spec b/hcom.spec new file mode 100644 index 0000000000000000000000000000000000000000..9bf0e3658e2e93b4258972a1b9c915c57001abb2 --- /dev/null +++ b/hcom.spec @@ -0,0 +1,97 @@ +# add --with java_compile option, i.e. disable java_compile by default +%bcond_with java_compile + +%global build_type %{?_build_type:%{_build_type}} +# 如果没有提供,则设置默认值 +%if "%{build_type}" == "" + %global build_type release +%endif + +%global with_hcom_perf %{?_with_hcom_perf:%{_with_hcom_perf}} +# 如果没有提供,则设置默认值 +%if "%{with_hcom_perf}" == "" + %global with_hcom_perf 0 +%endif + +%global with_htracer_cli %{?_with_htracer_cli:%{_with_htracer_cli}} +# 如果没有提供,则设置默认值 +%if "%{with_htracer_cli}" == "" + %global with_htracer_cli 0 +%endif + +%if %{undefined rpm_version} + %define rpm_version 2.0.0 +%endif + +%if %{undefined rpm_release} + %define rpm_release B099 +%endif + +%if %{undefined rpm_build_date} + %define rpm_build_date %(date +"%%Y-%%m-%%d-%%H:%%M:%%S") +%endif + +# 根据构建类型决定是否生成 debuginfo 包 +%if "%{build_type}" == "debug" + %global package_suffix OCK-CommunicationSuite_HCOM_Debug + %global _enable_debug_packages 1 +%else + %global package_suffix OCK-CommunicationSuite_HCOM + %global debug_package %{nil} +%endif + +Name: %{package_suffix} +Version : %{rpm_version} +Release : %{rpm_release} +Summary: HCOM +License : Proprietary +Provides : Huawei Technologies Co., Ltd +Source0 : %{package_name}.tar.gz +BuildRoot : %{_buildirootdir}/%{name}_%{version}-build +buildArch : aarch64 x86_64 + +%package debug +Summary: debug info of hcom debug + +%description debug +This package contains debug info of hcom.so +%description +HCOM是一个适用于C/S架构应用程序的高性能通信库 + +%prep +%setup -c -n %{name}_%{version} + +%install +rm -rf %{buildroot} +mkdir -p %{buildroot}/usr/local/lib/hcom +mkdir -p %{buildroot}/usr/local/jars/hcom +mkdir -p %{buildroot}/usr/include/hcom/capi +mkdir -p %{buildroot}/usr/local/bin + +cp %{_builddir}/%{name}_%{version}/%{package_name}/hcom/lib/* %{buildroot}/usr/local/lib/hcom +cp -r %{_builddir}/%{name}_%{version}/%{package_name}/hcom/include/hcom/* %{buildroot}/usr/include/hcom + +%if %{with java_compile} + cp %{_builddir}/%{name}_%{version}/%{package_name}/hcom/jars/* %{buildroot}/usr/local/jars/hcom +%endif + +%if %{with_hcom_perf} + cp -r %{_builddir}/%{name}_%{version}/%{package_name}/hcom/hcom_perf %{buildroot}/usr/local/bin +%endif + +%if %{with_htracer_cli} + cp -r %{_builddir}/%{name}_%{version}/%{package_name}/hcom/bin/htracer_cli %{buildroot}/usr/local/bin +%endif + +%files +%defattr(-,root,root) +%{_prefix}/include/hcom/capi/*.h +%{_prefix}/include/hcom/*.h +%if %{with_hcom_perf} || %{with_htracer_cli} + %{_prefix}/local/bin/* +%endif +%{_prefix}/local/lib/hcom/*.so +%{_prefix}/local/lib/hcom/*.a +%if %{with java_compile} + %{_prefix}/local/jars/hcom/*.jar +%endif \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fbf06802c1b4f2fb29e0b7c6e7a27c7813c46830 --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,105 @@ +file(GLOB_RECURSE HCOM *.cpp *.h) +file(GLOB_RECURSE HCOM_JNI_SRCS "${HCOM_SRC_DIR}/src/api/java_sdk/*.cpp" + "${HCOM_SRC_DIR}/src/api/java_sdk/*.h") +list(REMOVE_ITEM HCOM ${HCOM_JNI_SRCS}) + +file(GLOB_RECURSE HCOM_SERVICE_SRCS + "${HCOM_SRC_DIR}/src/service/*" + "${HCOM_SRC_DIR}/src/api/capi/*") +file(GLOB_RECURSE HCOM_SERVICE_V2_SRCS + "${HCOM_SRC_DIR}/src/service_v2/*" + "${HCOM_SRC_DIR}/src/api/capi_v2/*") + +file(GLOB HCOM_RDMA_COMMON "${HCOM_SRC_DIR}/src/transport/rdma/*") +file(GLOB_RECURSE HCOM_RDMA_SRCS + "${HCOM_SRC_DIR}/src/transport/rdma/verbs/*" + "${HCOM_SRC_DIR}/src/under_api/verbs/verbs*") +file(GLOB_RECURSE HCOM_SHM_SRCS "${HCOM_SRC_DIR}/src/transport/shm/*") +file(GLOB_RECURSE HCOM_SOCK_SRCS "${HCOM_SRC_DIR}/src/transport/sock/*") +file(GLOB_RECURSE HCOM_UB_SRCS "${HCOM_SRC_DIR}/src/transport/ub/*") + +if(NOT BUILD_WITH_RDMA) + list(REMOVE_ITEM HCOM ${HCOM_RDMA_COMMON}) +endif() + +if(NOT BUILD_WITH_RDMA) + list(REMOVE_ITEM HCOM ${HCOM_RDMA_SRCS}) +endif() + +if(NOT BUILD_WITH_SHM) + list(REMOVE_ITEM HCOM ${HCOM_SHM_SRCS}) +endif() + +if(NOT BUILD_WITH_SOCK) + list(REMOVE_ITEM HCOM ${HCOM_SOCK_SRCS}) +endif() + +if(NOT BUILD_WITH_UB) + list(REMOVE_ITEM HCOM ${HCOM_UB_SRCS}) +endif () + + +function(split_debug_symbols target) + # 定义符号文件路径(与二进制同目录,添加.debug后缀) + set(debug_file "$.debug") + + add_custom_command(TARGET ${target} POST_BUILD + # 1. 提取完整调试符号到单独文件 + COMMAND ${CMAKE_OBJCOPY} --only-keep-debug "$" "${debug_file}" + # 2. 从主二进制中剥离调试符号但保留最小关联信息 + COMMAND ${CMAKE_OBJCOPY} --strip-debug --strip-unneeded "$" + # 为可执行文件添加符号表文件的链接 + COMMAND ${CMAKE_OBJCOPY} --add-gnu-debuglink=${debug_file} "$" + COMMENT "Splitting debug symbols for ${target} into ${debug_file}" + ) +endfunction() + +add_library(hcom_static_obj STATIC ${HCOM}) +target_link_libraries(hcom_static_obj + -Wl,--start-group + pthread dl rt boundscheck + -Wl,--end-group) + +add_custom_command( + OUTPUT libhcom_static.a + DEPENDS hcom_static_obj + COMMAND mkdir -p tmp_obj && cd tmp_obj + && ar x $ + && cd .. + && ar cur libhcom_static.a tmp_obj/*.o + && rm -rf tmp_obj +) +add_custom_target(hcom_static ALL DEPENDS libhcom_static.a) + +add_library(hcom SHARED ${HCOM}) +target_compile_options(hcom PUBLIC -Wl,-z,noexecstack -Wl,-z,relro -Wl,-z,now) +target_link_libraries(hcom + -Wl,--start-group + pthread dl rt boundscheck + -Wl,--end-group) + +split_debug_symbols(hcom) +file(GLOB INCLUDE_HEADERS hcom*.h) + +file(GLOB C_INCLUDE_HEADERS api/capi_v2/hcom_c.h + api/capi_v2/hcom_service_c.h + api/capi/hcom_cgo_c.h) + +list(APPEND INCLUDE_HEADERS ${HCOM_SRC_DIR}/src/service_v2/api/hcom_service_channel.h) +list(APPEND INCLUDE_HEADERS ${HCOM_SRC_DIR}/src/service_v2/api/hcom_service_def.h) +list(APPEND INCLUDE_HEADERS ${HCOM_SRC_DIR}/src/service_v2/api/hcom_service_context.h) +list(APPEND INCLUDE_HEADERS ${HCOM_SRC_DIR}/src/service_v2/api/hcom_service.h) + +set(TARGET_INSTALL_INCLUDE ${CMAKE_INSTALL_PREFIX}/include/hcom) +set(TARGET_INSTALL_LIB ${CMAKE_INSTALL_PREFIX}/lib) + +install(FILES ${INCLUDE_HEADERS} DESTINATION ${TARGET_INSTALL_INCLUDE} PERMISSIONS OWNER_WRITE OWNER_READ GROUP_READ WORLD_READ) +install(FILES ${C_INCLUDE_HEADERS} DESTINATION ${TARGET_INSTALL_INCLUDE}/capi PERMISSIONS OWNER_WRITE OWNER_READ GROUP_READ WORLD_READ) +install(FILES ${CMAKE_BINARY_DIR}/src/libhcom_static.a DESTINATION ${TARGET_INSTALL_LIB}/ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE) +install(TARGETS hcom DESTINATION ${TARGET_INSTALL_LIB}/ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE) +install(FILES ${CMAKE_BINARY_DIR}/src/libhcom.so.debug DESTINATION ${TARGET_INSTALL_LIB}/ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE) + +if(${BUILD_JAVA_SDK} MATCHES "ON") + add_subdirectory(api/java_sdk) +endif() +message(STATUS "BUILD_JAVA_SDK: ${BUILD_JAVA_SDK}") diff --git a/src/api/capi_v2/hcom_c.cpp b/src/api/capi_v2/hcom_c.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8436dc8f941b3ee1208a4f867d108a139067209a --- /dev/null +++ b/src/api/capi_v2/hcom_c.cpp @@ -0,0 +1,1296 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "hcom_c.h" + +#include +#include +#include +#include +#include +#include "securec.h" + +#include "hcom_def.h" +#include "hcom_err.h" +#include "hcom_log.h" +#include "hcom_def_inner_c.h" +#include "net_load_balance.h" + +using namespace ock::hcom; + +#define VALIDATE_DRIVER(driver) \ + do { \ + if (NN_UNLIKELY((driver) == 0)) { \ + NN_LOG_ERROR("Invalid param, driver must be correct driver address"); \ + return NN_INVALID_PARAM; \ + } \ + } while (0) + +#define VALIDATE_DRIVER_NO_RET(driver) \ + do { \ + if (NN_UNLIKELY((driver) == 0)) { \ + NN_LOG_ERROR("Invalid param, driver must be correct driver address"); \ + return; \ + } \ + } while (0) + +#define VALIDATE_MR(mr) \ + do { \ + if (NN_UNLIKELY((mr) == 0)) { \ + NN_LOG_ERROR("Invalid param, mr must be correct mr address"); \ + return NN_INVALID_PARAM; \ + } \ + } while (0) + +#define VALIDATE_MR_NO_RET(mr) \ + do { \ + if (NN_UNLIKELY((mr) == 0)) { \ + NN_LOG_ERROR("Invalid param, mr must be correct mr address"); \ + return; \ + } \ + } while (0) + +#define VALIDATE_MR_POINT(mr) \ + do { \ + if (NN_UNLIKELY((mr) == 0)) { \ + NN_LOG_ERROR("Invalid param, mr point must be correct mr address"); \ + return NN_INVALID_PARAM; \ + } \ + } while (0) + +#define VALIDATE_EP(ep) \ + do { \ + if (NN_UNLIKELY((ep) == 0)) { \ + NN_LOG_ERROR("Invalid param, endpoint must be correct address"); \ + return NN_INVALID_PARAM; \ + } \ + } while (0) + +#define VALIDATE_EP_NO_RET(ep) \ + do { \ + if (NN_UNLIKELY((ep) == 0)) { \ + NN_LOG_ERROR("Invalid param, endpoint must be correct address"); \ + return; \ + } \ + } while (0) + +#define VALIDATE_REQ(req) \ + do { \ + if (NN_UNLIKELY((req) == nullptr)) { \ + NN_LOG_ERROR("Invalid param, req is null"); \ + return NN_INVALID_PARAM; \ + } \ + } while (0) + +#define VALIDATE_SEQ(seqNo) \ + do { \ + if (NN_UNLIKELY((seqNo) == 0)) { \ + NN_LOG_ERROR("Invalid param, seqNo is 0"); \ + return NN_INVALID_PARAM; \ + } \ + } while (0) + +#define VALIDATE_NAME_NO_RET(name) \ + do { \ + if (NN_UNLIKELY((name) == nullptr)) { \ + NN_LOG_ERROR("Invalid param, name must be correct address"); \ + return; \ + } \ + } while (0) + +#define VALIDATE_ALLOCATOR(allocator) \ + do { \ + if (NN_UNLIKELY((allocator) == 0)) { \ + NN_LOG_ERROR("Invalid allocator ptr"); \ + return NN_INVALID_PARAM; \ + } \ + } while (0) + +#define VALIDATE_OFFSET(offset) \ + do { \ + if (NN_UNLIKELY((offset) == 0)) { \ + NN_LOG_ERROR("Invalid offset ptr"); \ + return NN_INVALID_PARAM; \ + } \ + } while (0) + +#define VALIDATE_SIZE(size) \ + do { \ + if (NN_UNLIKELY((size) == 0)) { \ + NN_LOG_ERROR("Invalid size ptr"); \ + return NN_INVALID_PARAM; \ + } \ + } while (0) + +#define VALIDATE_NOT_NULL(ptr, errorMsg) \ + do { \ + if (NN_UNLIKELY((ptr) == nullptr)) { \ + NN_LOG_ERROR(errorMsg); \ + return NN_INVALID_PARAM; \ + } \ + } while (0) \ + +#define VALIDATE_DRIVER_GET_OOB_IP_AND_PORT_PARAM(driver, ipArray, portArray, length) \ + do { \ + if (NN_UNLIKELY((driver) == 0)) { \ + NN_LOG_ERROR("Invalid param, driver must be correct driver address"); \ + return false; \ + } \ + if (NN_UNLIKELY((ipArray) == nullptr || (portArray) == nullptr || (length) == nullptr)) { \ + NN_LOG_ERROR("Invalid param, ipArray/portArray/length cann't be nullptr"); \ + return false; \ + } \ + } while (0) \ + +static int ChangeAllocatorType(ubs_hcom_memory_allocator_options *options, UBSHcomNetMemoryAllocatorOptions &out) +{ + if (NN_UNLIKELY(options->address == NN_NO0)) { + NN_LOG_ERROR("Invalid allocator address "); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(options->size == NN_NO0)) { + NN_LOG_ERROR("Invalid allocator memory size " << options->size); + return NN_INVALID_PARAM; + } + + out.size = options->size; + out.address = options->address; + out.minBlockSize = options->minBlockSize != 0 ? options->minBlockSize : NN_NO4096; + out.bucketCount = options->bucketCount != 0 ? options->bucketCount : NN_NO8192; + out.alignedAddress = options->alignedAddress != 0; + out.cacheTierCount = options->cacheTierCount != 0 ? options->cacheTierCount : NN_NO8; + out.cacheBlockCountPerTier = options->cacheBlockCountPerTier != 0 ? options->cacheBlockCountPerTier : NN_NO16; + out.cacheTierPolicy = static_cast(options->cacheTierPolicy); + return SER_OK; +} + +inline UBSHcomNetDriverProtocol ChangeDriverTypeToDriverProto(ubs_hcom_driver_type type) +{ + UBSHcomNetDriverProtocol protocol = UBSHcomNetDriverProtocol::UNKNOWN; + switch (type) { + case C_DRIVER_RDMA: + protocol = UBSHcomNetDriverProtocol::RDMA; + break; + case C_DRIVER_TCP: + protocol = UBSHcomNetDriverProtocol::TCP; + break; + case C_DRIVER_UDS: + protocol = UBSHcomNetDriverProtocol::UDS; + break; + case C_DRIVER_SHM: + protocol = UBSHcomNetDriverProtocol::SHM; + break; + case C_SERVICE_UBC: + protocol = UBSHcomNetDriverProtocol::UBC; + break; + default: + NN_LOG_ERROR("Invalid driver protocol type"); + protocol = UBSHcomNetDriverProtocol::UNKNOWN; + } + return protocol; +} + +int ubs_hcom_mem_allocator_create(ubs_hcom_memory_allocator_type t, ubs_hcom_memory_allocator_options *options, + ubs_hcom_memory_allocator *allocator) +{ + if (NN_UNLIKELY(options == nullptr || allocator == nullptr)) { + NN_LOG_ERROR("Invalid options " << options << " or allocator " << allocator); + return NN_INVALID_PARAM; + } + + auto allocatorType = static_cast(t); + UBSHcomNetMemoryAllocatorOptions allocatorOptions; + UBSHcomNetMemoryAllocatorPtr innerAllocator; + + auto result = ChangeAllocatorType(options, allocatorOptions); + if (NN_UNLIKELY(result != SER_OK)) { + return result; + } + + result = UBSHcomNetMemoryAllocator::Create(allocatorType, allocatorOptions, innerAllocator); + if (NN_UNLIKELY(result != SER_OK)) { + return result; + } + + *allocator = reinterpret_cast(innerAllocator.Get()); + innerAllocator->IncreaseRef(); + return SER_OK; +} + +int ubs_hcom_mem_allocator_destroy(ubs_hcom_memory_allocator allocator) +{ + VALIDATE_ALLOCATOR(allocator); + + auto innerAllocator = reinterpret_cast(allocator); + innerAllocator->DecreaseRef(); + return SER_OK; +} + +int ubs_hcom_mem_allocator_set_mr_key(ubs_hcom_memory_allocator allocator, uint64_t mrKey) +{ + VALIDATE_ALLOCATOR(allocator); + + auto innerAllocator = reinterpret_cast(allocator); + innerAllocator->MrKey(mrKey); + return SER_OK; +} + +int ubs_hcom_mem_allocator_get_offset(ubs_hcom_memory_allocator allocator, uintptr_t address, uintptr_t *offset) +{ + VALIDATE_ALLOCATOR(allocator); + VALIDATE_OFFSET(offset); + + auto innerAllocator = reinterpret_cast(allocator); + *offset = innerAllocator->MemOffset(address); + return SER_OK; +} + +int ubs_hcom_mem_allocator_get_free_size(ubs_hcom_memory_allocator allocator, uintptr_t *size) +{ + VALIDATE_ALLOCATOR(allocator); + VALIDATE_SIZE(size); + + auto innerAllocator = reinterpret_cast(allocator); + *size = innerAllocator->FreeSize(); + return SER_OK; +} + +int ubs_hcom_mem_allocator_allocate(ubs_hcom_memory_allocator allocator, uint64_t size, uintptr_t *address, uint64_t *key) +{ + VALIDATE_ALLOCATOR(allocator); + VALIDATE_NOT_NULL(address, "Invalid out address"); + VALIDATE_NOT_NULL(key, "Invalid key ptr"); + auto innerAllocator = reinterpret_cast(allocator); + *key = innerAllocator->MrKey(); + return innerAllocator->Allocate(size, *address); +} + +int ubs_hcom_mem_allocator_free(ubs_hcom_memory_allocator allocator, uintptr_t address) +{ + VALIDATE_ALLOCATOR(allocator); + + auto innerAllocator = reinterpret_cast(allocator); + return innerAllocator->Free(address); +} + +static HdlMgr g_epHandlerManager; + +void ubs_hcom_set_log_handler(ubs_hcom_log_handler h) +{ + NetLogger::Instance()->SetExternalLogFunction(h); +} + +int ubs_hcom_check_local_support(ubs_hcom_driver_type t, ubs_hcom_device_info *info) +{ + if (NN_UNLIKELY(info == nullptr)) { + NN_LOG_ERROR("Invalid param info"); + return 0; + } + + UBSHcomNetDriverProtocol driverProto = ChangeDriverTypeToDriverProto(t); + if (NN_UNLIKELY(driverProto == UBSHcomNetDriverProtocol::UNKNOWN)) { + NN_LOG_ERROR("Unsupport driver type, type:" << t); + return 0; + } + + UBSHcomNetDriverDeviceInfo deviceInfo; + /* return 1 if support, otherwise return 0 */ + if (UBSHcomNetDriver::LocalSupport(driverProto, deviceInfo)) { + info->maxSge = deviceInfo.maxSge; + return 1; + } + + return 0; +} + +// driver api +int ubs_hcom_driver_create(ubs_hcom_driver_type t, const char *name, uint8_t startOobSvr, ubs_hcom_driver *driver) +{ + if (NN_UNLIKELY((name == nullptr) || (driver == nullptr))) { + NN_LOG_ERROR("Invalid param, name or driver is null"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY((startOobSvr != 0) && (startOobSvr != 1))) { + NN_LOG_ERROR("Invalid param, startOobSvr must be 0 or 1"); + return NN_INVALID_PARAM; + } + + if (strlen(name) > NN_NO64) { + NN_LOG_ERROR("Invalid param, name length is larger than " << NN_NO64); + return NN_INVALID_PARAM; + } + + UBSHcomNetDriverProtocol driverProtocol = ChangeDriverTypeToDriverProto(t); + if (NN_UNLIKELY(driverProtocol == UBSHcomNetDriverProtocol::UNKNOWN)) { + NN_LOG_ERROR("Unsupport driver type, type:" << t); + return NN_INVALID_PARAM; + } + + auto tmpDriver = UBSHcomNetDriver::Instance(driverProtocol, name, startOobSvr == 1); + if (tmpDriver == nullptr) { + return NN_NEW_OBJECT_FAILED; + } + + *driver = reinterpret_cast(tmpDriver); + + return NN_OK; +} + +void ubs_hcom_driver_set_ipport(ubs_hcom_driver driver, const char *ip, uint16_t port) +{ + VALIDATE_DRIVER_NO_RET(driver); + if (ip == nullptr) { + NN_LOG_ERROR("Invalid param, ip is empty"); + return; + } + + reinterpret_cast(driver)->OobIpAndPort(ip, port); +} + +static void ClearIpAndPortArray(char ***ipArray, uint16_t **portArray, int *length) +{ + if (length == nullptr || *length == 0) { + return; + } + + if ((portArray != nullptr) && (*portArray != nullptr)) { + free(*portArray); + *portArray = nullptr; + } + + if ((ipArray == nullptr) || (*ipArray == nullptr)) { + return; + } + + for (int i = 0; i != *length; ++i) { + if (**(ipArray + i) != nullptr) { + free(**(ipArray + i)); + **(ipArray + i) = nullptr; + } + } + free(*ipArray); + *ipArray = nullptr; + *length = 0; + return; +} + +bool ubs_hcom_driver_get_ipport(ubs_hcom_driver driver, char ***ipArray, uint16_t **portArray, int *length) +{ + VALIDATE_DRIVER_GET_OOB_IP_AND_PORT_PARAM(driver, ipArray, portArray, length); + std::vector> res; + if (!reinterpret_cast(driver)->GetOobIpAndPort(res)) { + NN_LOG_ERROR("Invalid param, get oob ip port failed"); + return false; + } + if (res.size() == 0) { + NN_LOG_ERROR("The Working oob ip and port cannot be found"); + return false; + } + // prepare length result + *length = static_cast(res.size()); + // prepare port result + *portArray = static_cast(malloc(res.size() * sizeof(uint16_t))); + if (*portArray == nullptr) { + NN_LOG_ERROR("Failed to malloc portArray"); + ClearIpAndPortArray(ipArray, portArray, length); + return false; + } + for (auto i = 0; i < static_cast(res.size()); ++i) { + **(portArray + i) = res[i].second; + } + // prepare ip result + *ipArray = static_cast(malloc(res.size() * sizeof(char *))); + if (*ipArray == nullptr) { + NN_LOG_ERROR("malloc ipArray failed!"); + ClearIpAndPortArray(ipArray, portArray, length); + return false; + } + bzero(*ipArray, res.size() * sizeof(char *)); + for (int i = 0; i < static_cast(res.size()); ++i) { + auto temp = static_cast(malloc(MAX_IP_LENGTH * sizeof(char))); + if (temp == nullptr) { + NN_LOG_ERROR("malloc ipArray[" << i << "] failed!"); + ClearIpAndPortArray(ipArray, portArray, length); + return false; + } + bzero(temp, MAX_IP_LENGTH * sizeof(char)); + if (memcpy_s(temp, MAX_IP_LENGTH, res[i].first.c_str(), res[i].first.size()) != 0) { + NN_LOG_ERROR("copy ipArray" << i << "] failed!"); + free(temp); + ClearIpAndPortArray(ipArray, portArray, length); + return false; + } + **(ipArray + i) = temp; + } + return true; +} + +void ubs_hcom_driver_set_udsname(ubs_hcom_driver driver, const char *name) +{ + VALIDATE_DRIVER_NO_RET(driver); + VALIDATE_NAME_NO_RET(name); + reinterpret_cast(driver)->OobUdsName(name); +} + +void ubs_hcom_driver_add_uds_opt(ubs_hcom_driver driver, ubs_hcom_driver_uds_listen_opts option) +{ + VALIDATE_DRIVER_NO_RET(driver); + UBSHcomNetOobUDSListenerOptions innerOpt {}; + innerOpt.Name(option.name); + innerOpt.perm = option.perm; + innerOpt.targetWorkerCount = option.targetWorkerCount; + reinterpret_cast(driver)->AddOobUdsOptions(innerOpt); +} + +void ubs_hcom_driver_add_oob_opt(ubs_hcom_driver driver, ubs_hcom_driver_listen_opts options) +{ + VALIDATE_DRIVER_NO_RET(driver); + + UBSHcomNetOobListenerOptions innerOpt {}; + std::string ip = { options.ip, strlen(options.ip) <= sizeof(options.ip) ? strlen(options.ip) : sizeof(options.ip) }; + innerOpt.Set(ip, options.port, options.targetWorkerCount); + reinterpret_cast(driver)->AddOobOptions(innerOpt); +} + +int ubs_hcom_driver_initialize(ubs_hcom_driver driver, ubs_hcom_driver_opts options) +{ + VALIDATE_DRIVER(driver); + + UBSHcomNetDriverOptions driverOps {}; + driverOps.mode = NET_BUSY_POLLING; + if (options.mode == C_EVENT_POLLING) { + driverOps.mode = ock::hcom::NET_EVENT_POLLING; + } + + driverOps.mrSendReceiveSegSize = options.mrSendReceiveSegSize != 0 ? options.mrSendReceiveSegSize : NN_NO1024; + driverOps.mrSendReceiveSegCount = options.mrSendReceiveSegCount != 0 ? options.mrSendReceiveSegCount : NN_NO8192; + driverOps.SetNetDeviceIpMask(options.netDeviceIpMask); + driverOps.SetNetDeviceIpGroup(options.netDeviceIpGroup); + driverOps.completionQueueDepth = options.completionQueueDepth != 0 ? options.completionQueueDepth : NN_NO2048; + driverOps.maxPostSendCountPerQP = options.maxPostSendCountPerQP != 0 ? options.maxPostSendCountPerQP : NN_NO64; + driverOps.prePostReceiveSizePerQP = + options.prePostReceiveSizePerQP != 0 ? options.prePostReceiveSizePerQP : NN_NO64; + driverOps.pollingBatchSize = options.pollingBatchSize != 0 ? options.pollingBatchSize : NN_NO4; + driverOps.qpSendQueueSize = options.qpSendQueueSize != 0 ? options.qpSendQueueSize : NN_NO256; + driverOps.qpReceiveQueueSize = options.qpReceiveQueueSize != 0 ? options.qpReceiveQueueSize : NN_NO256; + driverOps.version = options.version != 0 ? options.version : NN_NO0; + driverOps.dontStartWorkers = (options.dontStartWorkers == 1); + driverOps.tcpSendBufSize = options.tcpSendBufSize != 0 ? NN_NextPower2(options.tcpSendBufSize) : 0; + driverOps.tcpReceiveBufSize = options.tcpReceiveBufSize != 0 ? NN_NextPower2(options.tcpReceiveBufSize) : 0; + + if (NN_UNLIKELY(memcpy_s(driverOps.workerGroups, sizeof(driverOps.workerGroups), options.workerGroups, + sizeof(options.workerGroups)) != NN_OK)) { + NN_LOG_ERROR("Failed to copy worker groups"); + return NN_INVALID_PARAM; + } + if (NN_UNLIKELY(memcpy_s(driverOps.workerGroupsCpuSet, sizeof(driverOps.workerGroupsCpuSet), + options.workerGroupsCpuSet, sizeof(options.workerGroupsCpuSet)) != NN_OK)) { + NN_LOG_ERROR("Failed to copy worker cpu set"); + return NN_INVALID_PARAM; + } + driverOps.workerThreadPriority = options.workerThreadPriority; + driverOps.tcpUserTimeout = options.tcpUserTimeout; + driverOps.tcpEnableNoDelay = options.tcpEnableNoDelay; + driverOps.tcpSendZCopy = options.tcpSendZCopy; + driverOps.heartBeatIdleTime = options.heartBeatIdleTime != 0 ? options.heartBeatIdleTime : NN_NO60; + driverOps.heartBeatProbeTimes = options.heartBeatProbeTimes != 0 ? options.heartBeatProbeTimes : NN_NO7; + driverOps.heartBeatProbeInterval = options.heartBeatProbeInterval != 0 ? options.heartBeatProbeInterval : NN_NO2; + driverOps.enableTls = options.enableTls; + driverOps.tlsVersion = + options.tlsVersion != 0 ? static_cast(options.tlsVersion) : (ock::hcom::TLS_1_3); + driverOps.cipherSuite = static_cast(options.cipherSuite); + driverOps.oobType = NET_OOB_TCP; + driverOps.tcpSendBufSize = options.tcpSendBufSize; + driverOps.tcpReceiveBufSize = options.tcpReceiveBufSize; + driverOps.maxConnectionNum = options.maxConnectionNum != 0 ? options.maxConnectionNum : NN_NO250; + if (options.oobType == C_NET_OOB_UDS) { + driverOps.oobType = ock::hcom::NET_OOB_UDS; + } + if (NN_UNLIKELY(memcpy_s(driverOps.oobPortRange, sizeof(driverOps.oobPortRange), options.oobPortRange, + sizeof(options.oobPortRange)) != 0)) { + NN_LOG_ERROR("Failed to copy oob port range"); + return NN_INVALID_PARAM; + } + return reinterpret_cast(driver)->Initialize(driverOps); +} + +int ubs_hcom_driver_start(ubs_hcom_driver driver) +{ + VALIDATE_DRIVER(driver); + return reinterpret_cast(driver)->Start(); +} + +int ubs_hcom_driver_connect(ubs_hcom_driver driver, const char *payloadData, ubs_hcom_endpoint *ep, uint32_t flags) +{ + return ubs_hcom_driver_connect_with_grpno(driver, payloadData, ep, flags, 0, 0); +} + +int ubs_hcom_driver_connect_with_grpno(ubs_hcom_driver driver, const char *payloadData, ubs_hcom_endpoint *ep, + uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo) +{ + VALIDATE_DRIVER(driver); + VALIDATE_EP(ep); + + std::string payload = payloadData != nullptr ? payloadData : ""; + + UBSHcomNetEndpointPtr realEp; + auto result = reinterpret_cast(driver)->Connect(payload, realEp, flags, serverGrpNo, + clientGrpNo); + if (NN_UNLIKELY(result != NN_OK)) { + return result; + } + + // increase ref, need to call ubs_hcom_ep_destroy() to decrease ref + realEp->IncreaseRef(); + + *ep = reinterpret_cast(realEp.Get()); + + return NN_OK; +} + +int ubs_hcom_driver_connect_to_ipport(ubs_hcom_driver driver, const char *serverIp, uint16_t serverPort, + const char *payloadData, ubs_hcom_endpoint *ep, uint32_t flags) +{ + return ubs_hcom_driver_connect_to_ipport_with_groupno(driver, serverIp, serverPort, payloadData, ep, flags, 0, 0); +} + +int ubs_hcom_driver_connect_to_ipport_with_groupno(ubs_hcom_driver driver, const char *serverIp, uint16_t serverPort, + const char *payloadData, ubs_hcom_endpoint *ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo) +{ + return ubs_hcom_driver_connect_to_ipport_with_ctx(driver, serverIp, serverPort, payloadData, ep, flags, + serverGrpNo, clientGrpNo, 0); +} + +int ubs_hcom_driver_connect_to_ipport_with_ctx(ubs_hcom_driver driver, const char *serverIp, uint16_t serverPort, + const char *payloadData, ubs_hcom_endpoint *ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, + uint64_t ctx) +{ + if (serverIp == nullptr || serverPort == 0) { + NN_LOG_ERROR("Failed to connect as server ip is null or port " << serverPort << " is invalid"); + return NN_INVALID_PARAM; + } + + VALIDATE_DRIVER(driver); + VALIDATE_EP(ep); + + std::string payload = payloadData != nullptr ? payloadData : ""; + + UBSHcomNetEndpointPtr realEp; + auto result = reinterpret_cast(driver)->Connect(serverIp, serverPort, payload, realEp, flags, + serverGrpNo, clientGrpNo, ctx); + if (NN_UNLIKELY(result != NN_OK)) { + return result; + } + + // increase ref, need to call ubs_hcom_ep_destroy() to decrease ref + realEp->IncreaseRef(); + + *ep = reinterpret_cast(realEp.Get()); + return NN_OK; +} + +void ubs_hcom_driver_stop(ubs_hcom_driver driver) +{ + VALIDATE_DRIVER_NO_RET(driver); + reinterpret_cast(driver)->Stop(); +} + +void ubs_hcom_driver_uninitialize(ubs_hcom_driver driver) +{ + VALIDATE_DRIVER_NO_RET(driver); + reinterpret_cast(driver)->UnInitialize(); +} + +int ubs_hcom_driver_destroy(ubs_hcom_driver driver) +{ + VALIDATE_DRIVER(driver); + std::string name = reinterpret_cast(driver)->Name(); + return UBSHcomNetDriver::DestroyInstance(name); +} + +uintptr_t ubs_hcom_driver_register_ep_handler(ubs_hcom_driver driver, ubs_hcom_ep_handler_type t, + ubs_hcom_ep_handler h, uint64_t usrCtx) +{ + if (NN_UNLIKELY(driver == 0)) { + NN_LOG_ERROR("Invalid param, driver must be correct driver address"); + return 0; + } + + if (NN_UNLIKELY(h == nullptr)) { + NN_LOG_ERROR("Invalid param, ubs_hcom_ep_handler is null"); + return 0; + } + + auto tmpHandle = new (std::nothrow) EpHdlAdp(t, h, usrCtx); + if (NN_UNLIKELY(tmpHandle == nullptr)) { + NN_LOG_ERROR("Failed to new Endpoint handler adapter, probably out of memory"); + return 0; + } + + if (t == C_EP_NEW) { + reinterpret_cast(driver)->RegisterNewEPHandler(std::bind(&EpHdlAdp::NewEndPoint, tmpHandle, + std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + } else if (t == C_EP_BROKEN) { + reinterpret_cast(driver)->RegisterEPBrokenHandler( + std::bind(&EpHdlAdp::EndPointBroken, tmpHandle, std::placeholders::_1)); + } else { + NN_LOG_ERROR("Unreachable"); + delete tmpHandle; + return 0; + } + + g_epHandlerManager.AddHdlAdp(reinterpret_cast(tmpHandle)); + + return reinterpret_cast(tmpHandle); +} + +uintptr_t ubs_hcom_driver_register_op_handler(ubs_hcom_driver driver, ubs_hcom_op_handler_type t, + ubs_hcom_request_handler h, uint64_t usrCtx) +{ + if (NN_UNLIKELY(driver == 0)) { + NN_LOG_ERROR("Invalid param, driver must be correct driver address"); + return 0; + } + + if (NN_UNLIKELY(h == nullptr)) { + NN_LOG_ERROR("Invalid param, ubs_hcom_ep_handler is null"); + return 0; + } + + auto tmpHandle = new (std::nothrow) EpOpHdlAdp(h, usrCtx); + if (NN_UNLIKELY(tmpHandle == nullptr)) { + NN_LOG_ERROR("Failed to new Endpoint handler adapter, probably out of memory"); + return 0; + } + + if (t == C_OP_REQUEST_RECEIVED) { + reinterpret_cast(driver)->RegisterNewReqHandler( + std::bind(&EpOpHdlAdp::Requested, tmpHandle, std::placeholders::_1)); + } else if (t == C_OP_REQUEST_POSTED) { + reinterpret_cast(driver)->RegisterReqPostedHandler( + std::bind(&EpOpHdlAdp::Requested, tmpHandle, std::placeholders::_1)); + } else if (t == C_OP_READWRITE_DONE) { + reinterpret_cast(driver)->RegisterOneSideDoneHandler( + std::bind(&EpOpHdlAdp::Requested, tmpHandle, std::placeholders::_1)); + } else { + NN_LOG_ERROR("Unreachable"); + delete tmpHandle; + return 0; + } + + g_epHandlerManager.AddHdlAdp(reinterpret_cast(tmpHandle)); + + return reinterpret_cast(tmpHandle); +} + +uintptr_t ubs_hcom_driver_register_idle_handler(ubs_hcom_driver driver, ubs_hcom_idle_handler h, uint64_t usrCtx) +{ + if (NN_UNLIKELY(driver == 0)) { + NN_LOG_ERROR("Invalid param, driver must be correct driver address"); + return 0; + } + + if (NN_UNLIKELY(h == nullptr)) { + NN_LOG_ERROR("Invalid param, ubs_hcom_ep_handler is null"); + return 0; + } + + auto tmpHandle = new (std::nothrow) EpIdleHdlAdp(h, usrCtx); + if (NN_UNLIKELY(tmpHandle == nullptr)) { + NN_LOG_ERROR("Failed to new Endpoint handler adapter, probably out of memory"); + return 0; + } + + reinterpret_cast(driver)->RegisterIdleHandler( + std::bind(&EpIdleHdlAdp::Idle, tmpHandle, std::placeholders::_1)); + + g_epHandlerManager.AddHdlAdp(reinterpret_cast(tmpHandle)); + + return reinterpret_cast(tmpHandle); +} + +uintptr_t ubs_hcom_driver_register_secinfo_provider(ubs_hcom_driver driver, ubs_hcom_secinfo_provider provider) +{ + if (NN_UNLIKELY(driver == 0)) { + NN_LOG_ERROR("Invalid param, driver must be correct driver address"); + return 0; + } + + if (NN_UNLIKELY(provider == nullptr)) { + NN_LOG_ERROR("Invalid param, ubs_hcom_secinfo_provider is null"); + return 0; + } + + auto tmpHandle = new (std::nothrow) OOBSecInfoProviderAdp(provider); + if (NN_UNLIKELY(tmpHandle == nullptr)) { + NN_LOG_ERROR("Register ubs_hcom_secinfo_provider failed, probably out of memory"); + return 0; + } + + reinterpret_cast(driver)->RegisterEndpointSecInfoProvider( + std::bind(&OOBSecInfoProviderAdp::CreateSecInfo, tmpHandle, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + + return reinterpret_cast(tmpHandle); +} + +uintptr_t ubs_hcom_driver_register_secinfo_validator(ubs_hcom_driver driver, ubs_hcom_secinfo_validator validator) +{ + if (NN_UNLIKELY(driver == 0)) { + NN_LOG_ERROR("Invalid param, driver must be correct driver address"); + return 0; + } + + if (NN_UNLIKELY(validator == nullptr)) { + NN_LOG_ERROR("Invalid param, ubs_hcom_secinfo_validator is null"); + return 0; + } + + auto tmpHandle = new (std::nothrow) OOBSecInfoValidatorAdp(validator); + if (NN_UNLIKELY(tmpHandle == nullptr)) { + NN_LOG_ERROR("Register ubs_hcom_secinfo_validator failed, probably out of memory"); + return 0; + } + + reinterpret_cast(driver)->RegisterEndpointSecInfoValidator( + std::bind(&OOBSecInfoValidatorAdp::SecInfoValidate, tmpHandle, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + + return reinterpret_cast(tmpHandle); +} + +uintptr_t ubs_hcom_driver_register_tls_cb(ubs_hcom_driver driver, ubs_hcom_tls_get_cert_cb certCb, + ubs_hcom_tls_get_pk_cb priKeyCb, ubs_hcom_tls_get_ca_cb caCb) +{ + if (NN_UNLIKELY(driver == 0)) { + NN_LOG_ERROR("Invalid param, driver must be correct driver address"); + return 0; + } + + if (NN_UNLIKELY(certCb == nullptr) || NN_UNLIKELY(priKeyCb == nullptr || NN_LIKELY(caCb == nullptr))) { + NN_LOG_ERROR("Failed to reg driver tls cb by invalid param or handler"); + return 0; + } + + auto tmpHandle = new (std::nothrow) EpTLSHdlAdp(); + if (NN_UNLIKELY(tmpHandle == nullptr)) { + NN_LOG_ERROR("Failed to new driver tls handler adapter, probably out of memory"); + return 0; + } + + tmpHandle->SetTLSCertCb(certCb); + tmpHandle->SetTLSPrivateKeyCb(priKeyCb); + tmpHandle->SetTLSCaCb(caCb); + reinterpret_cast(driver)->RegisterTLSCertificationCallback( + (std::bind(&EpTLSHdlAdp::UBSHcomTLSCertificationCallback, tmpHandle, std::placeholders::_1, + std::placeholders::_2))); + + reinterpret_cast(driver)->RegisterTLSPrivateKeyCallback( + (std::bind(&EpTLSHdlAdp::UBSHcomTLSPrivateKeyCallback, tmpHandle, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5))); + + reinterpret_cast(driver)->RegisterTLSCaCallback( + (std::bind(&EpTLSHdlAdp::UBSHcomTLSCaCallback, tmpHandle, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5))); + + return reinterpret_cast(tmpHandle); +} + +void ubs_hcom_driver_unregister_ep_handler(ubs_hcom_ep_handler_type t, uintptr_t handle) +{ + g_epHandlerManager.RemoveHdlAdp(handle); +} + +void ubs_hcom_driver_unregister_op_handler(ubs_hcom_op_handler_type t, uintptr_t handle) +{ + g_epHandlerManager.RemoveHdlAdp(handle); +} + +void ubs_hcom_driver_unregister_idle_handler(uintptr_t handle) +{ + g_epHandlerManager.RemoveHdlAdp(handle); +} + +int ubs_hcom_driver_create_memory_region(ubs_hcom_driver driver, uint64_t size, ubs_hcom_memory_region *mr) +{ + VALIDATE_DRIVER(driver); + VALIDATE_MR_POINT(mr); + + auto tmpMr = new (std::nothrow) UBSHcomNetMemoryRegionPtr; + if (tmpMr == nullptr) { + NN_LOG_ERROR("Create memory region malloc memory failed"); + return NN_NEW_OBJECT_FAILED; + } + + auto result = reinterpret_cast(driver)->CreateMemoryRegion(size, *tmpMr); + if (result != NN_OK) { + delete tmpMr; + return result; + } + + *mr = reinterpret_cast(tmpMr); + return NN_OK; +} + +int ubs_hcom_driver_create_assign_memory_region(ubs_hcom_driver driver, uintptr_t address, uint64_t size, + ubs_hcom_memory_region *mr) +{ + VALIDATE_DRIVER(driver); + VALIDATE_MR_POINT(mr); + + auto tmpMr = new (std::nothrow) UBSHcomNetMemoryRegionPtr; + if (tmpMr == nullptr) { + NN_LOG_ERROR("Create memory region malloc memory failed"); + return NN_NEW_OBJECT_FAILED; + } + auto result = reinterpret_cast(driver)->CreateMemoryRegion(address, size, *tmpMr); + if (result != NN_OK) { + delete tmpMr; + return result; + } + + *mr = reinterpret_cast(tmpMr); + return NN_OK; +} + +void ubs_hcom_driver_destroy_memory_region(ubs_hcom_driver driver, ubs_hcom_memory_region mr) +{ + VALIDATE_DRIVER_NO_RET(driver); + VALIDATE_MR_NO_RET(mr); + + auto tmpMr = reinterpret_cast(mr); + reinterpret_cast(driver)->DestroyMemoryRegion(*tmpMr); + delete tmpMr; +} + +int ubs_hcom_driver_get_memory_region_info(ubs_hcom_memory_region mr, ubs_hcom_memory_region_info *info) +{ + VALIDATE_MR(mr); + + if (NN_UNLIKELY(info == nullptr)) { + NN_LOG_ERROR("Param info is empty"); + return NN_PARAM_INVALID; + } + + auto tmpMrPtr = reinterpret_cast(mr); + auto tmpMr = tmpMrPtr->ToChild(); + if (NN_UNLIKELY(tmpMr == nullptr)) { + NN_LOG_ERROR("ToChild failed"); + return NN_ERROR; + } + info->lAddress = tmpMr->GetAddress(); + info->lKey = tmpMr->GetLKey(); + info->size = tmpMr->Size(); + return NN_OK; +} + +void ubs_hcom_ep_set_context(ubs_hcom_endpoint ep, uint64_t ctx) +{ + VALIDATE_EP_NO_RET(ep); + reinterpret_cast(ep)->UpCtx(ctx); +} + +uint64_t ubs_hcom_ep_get_context(ubs_hcom_endpoint ep) +{ + VALIDATE_EP(ep); + return reinterpret_cast(ep)->UpCtx(); +} + +uint16_t ubs_hcom_ep_get_worker_idx(ubs_hcom_endpoint ep) +{ + if (NN_UNLIKELY(ep == 0)) { + NN_LOG_ERROR("Invalid param, endpoint must be correct address"); + return NET_INVALID_WORKER_INDEX; + } + + return reinterpret_cast(ep)->WorkerIndex().idxInGrp; +} + +uint8_t ubs_hcom_ep_get_workergroup_idx(ubs_hcom_endpoint ep) +{ + if (NN_UNLIKELY(ep == 0)) { + NN_LOG_ERROR("Invalid param, endpoint must be correct address"); + return NET_INVALID_WORKER_GROUP_INDEX; + } + + return reinterpret_cast(ep)->WorkerIndex().grpIdx; +} + +uint32_t ubs_hcom_ep_get_listen_port(ubs_hcom_endpoint ep) +{ + if (NN_UNLIKELY(ep == 0)) { + NN_LOG_ERROR("Invalid param, endpoint must be correct address"); + return 0; + } + + return reinterpret_cast(ep)->ListenPort(); +} + +uint8_t ubs_hcom_ep_version(ubs_hcom_endpoint ep) +{ + if (NN_UNLIKELY(ep == 0)) { + NN_LOG_ERROR("Invalid param, endpoint must be correct address"); + return 0; + } + + return reinterpret_cast(ep)->Version(); +} + +void ubs_hcom_ep_set_timeout(ubs_hcom_endpoint ep, int32_t timeout) +{ + VALIDATE_EP_NO_RET(ep); + reinterpret_cast(ep)->DefaultTimeout(timeout); +} + +/* Caller have to make sure iov and src is NOT null */ +static inline int CopySglInfo(UBSHcomNetTransSgeIov *iov, uint16_t iovCnt, ubs_hcom_readwrite_request_sgl *src) +{ + if (NN_UNLIKELY(src == nullptr)) { + NN_LOG_ERROR("Invalid param, src is NULL"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(src->iov == nullptr)) { + NN_LOG_ERROR("Invalid param src iov"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(src->iovCount > iovCnt)) { + NN_LOG_ERROR("Invalid iov count, src iovCount: " << src->iovCount << ", iovCnt: " << iovCnt); + return NN_INVALID_PARAM; + } + + for (uint16_t i = 0; i < src->iovCount; i++) { + iov[i].rAddress = src->iov[i].rAddress; + iov[i].rKey = src->iov[i].rKey; + iov[i].lAddress = src->iov[i].lAddress; + iov[i].lKey = src->iov[i].lKey; + iov[i].size = src->iov[i].size; + } + + return NN_OK; +} + +int ubs_hcom_ep_post_send(ubs_hcom_endpoint ep, uint16_t opcode, ubs_hcom_send_request *req) +{ + VALIDATE_EP(ep); + VALIDATE_REQ(req); + + UBSHcomNetTransRequest transReq(reinterpret_cast(req->data), req->size, req->upCtxSize); + if (NN_UNLIKELY(memcpy_s(transReq.upCtxData, sizeof(transReq.upCtxData), req->upCtxData, sizeof(req->upCtxData)) != + NN_OK)) { + NN_LOG_ERROR("Failed to copy up ctx data"); + return NN_INVALID_PARAM; + } + + return reinterpret_cast(ep)->PostSend(opcode, transReq); +} + +int ubs_hcom_ep_post_send_with_opinfo(ubs_hcom_endpoint ep, uint16_t opcode, ubs_hcom_send_request *req, + ubs_hcom_opinfo *opInfo) +{ + VALIDATE_EP(ep); + VALIDATE_REQ(req); + + if (NN_UNLIKELY(opInfo == nullptr)) { + NN_LOG_ERROR("Invalid param opInfo"); + return NN_INVALID_PARAM; + } + + UBSHcomNetTransRequest transReq(reinterpret_cast(req->data), req->size, req->upCtxSize); + if (NN_UNLIKELY(memcpy_s(transReq.upCtxData, sizeof(transReq.upCtxData), req->upCtxData, sizeof(req->upCtxData)) != + NN_OK)) { + NN_LOG_ERROR("Failed to copy up ctx data"); + return NN_INVALID_PARAM; + } + + UBSHcomNetTransOpInfo innerOpInfo(opInfo->seqNo, opInfo->timeout, opInfo->errorCode, opInfo->flags); + + return reinterpret_cast(ep)->PostSend(opcode, transReq, innerOpInfo); +} + +int ubs_hcom_ep_post_send_raw(ubs_hcom_endpoint ep, ubs_hcom_send_request *req, uint32_t seqNo) +{ + VALIDATE_EP(ep); + VALIDATE_REQ(req); + VALIDATE_SEQ(seqNo); + + UBSHcomNetTransRequest transReq(reinterpret_cast(req->data), req->size, req->upCtxSize); + if (NN_UNLIKELY(memcpy_s(transReq.upCtxData, sizeof(transReq.upCtxData), req->upCtxData, sizeof(req->upCtxData)) != + NN_OK)) { + NN_LOG_ERROR("Failed to copy up ctx data"); + return NN_INVALID_PARAM; + } + + return reinterpret_cast(ep)->PostSendRaw(transReq, seqNo); +} + +int ubs_hcom_ep_post_send_raw_sgl(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request_sgl *req, uint32_t seqNo) +{ + VALIDATE_EP(ep); + VALIDATE_REQ(req); + VALIDATE_SEQ(seqNo); + UBSHcomNetTransSglRequest transReq {}; + UBSHcomNetTransSgeIov iov[C_NET_SGE_MAX_IOV]; + bzero(&transReq, sizeof(UBSHcomNetTransSglRequest)); + + if (CopySglInfo(iov, C_NET_SGE_MAX_IOV, req) != NN_OK) { + return NN_INVALID_PARAM; + } + transReq.iov = iov; + transReq.iovCount = req->iovCount; + transReq.upCtxSize = req->upCtxSize; + if (NN_UNLIKELY(memcpy_s(transReq.upCtxData, sizeof(transReq.upCtxData), req->upCtxData, sizeof(req->upCtxData)) != + NN_OK)) { + NN_LOG_ERROR("Failed to copy up ctx data"); + return NN_INVALID_PARAM; + } + + return reinterpret_cast(ep)->PostSendRawSgl(transReq, seqNo); +} + +int ubs_hcom_ep_post_send_with_seqno(ubs_hcom_endpoint ep, uint16_t opcode, ubs_hcom_send_request *req, + uint32_t replySeqNo) +{ + VALIDATE_EP(ep); + VALIDATE_REQ(req); + + UBSHcomNetTransRequest transReq(reinterpret_cast(req->data), req->size, req->upCtxSize); + if (NN_UNLIKELY(memcpy_s(transReq.upCtxData, sizeof(transReq.upCtxData), req->upCtxData, sizeof(req->upCtxData)) != + NN_OK)) { + NN_LOG_ERROR("Failed to copy up ctx data"); + return NN_INVALID_PARAM; + } + + return reinterpret_cast(ep)->PostSend(opcode, transReq, replySeqNo); +} + +int ubs_hcom_ep_post_read(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request *req) +{ + VALIDATE_EP(ep); + VALIDATE_REQ(req); + + UBSHcomNetTransRequest transReq(req->lMRA, req->rMRA, req->lKey, req->rKey, req->size, req->upCtxSize); + if (NN_UNLIKELY(memcpy_s(transReq.upCtxData, sizeof(transReq.upCtxData), req->upCtxData, sizeof(req->upCtxData)) != + NN_OK)) { + NN_LOG_ERROR("Failed to post read by copy up ctx data err"); + return NN_INVALID_PARAM; + } + + return reinterpret_cast(ep)->PostRead(transReq); +} + +int ubs_hcom_ep_post_read_sgl(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request_sgl *req) +{ + VALIDATE_EP(ep); + VALIDATE_REQ(req); + UBSHcomNetTransSgeIov iov[C_NET_SGE_MAX_IOV]; + if (CopySglInfo(iov, C_NET_SGE_MAX_IOV, req) != NN_OK) { + NN_LOG_ERROR("Failed to post read sgl by copy sgl info err"); + return NN_INVALID_PARAM; + } + + UBSHcomNetTransSglRequest transReq { iov, req->iovCount, req->upCtxSize }; + if (NN_UNLIKELY(memcpy_s(transReq.upCtxData, sizeof(transReq.upCtxData), req->upCtxData, sizeof(req->upCtxData)) != + NN_OK)) { + NN_LOG_ERROR("Failed to post read sgl by copy up ctx data err"); + return NN_INVALID_PARAM; + } + + return reinterpret_cast(ep)->PostRead(transReq); +} + +int ubs_hcom_ep_post_write(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request *req) +{ + VALIDATE_EP(ep); + VALIDATE_REQ(req); + + UBSHcomNetTransRequest transReq(req->lMRA, req->rMRA, req->lKey, req->rKey, req->size, req->upCtxSize); + if (NN_UNLIKELY(memcpy_s(transReq.upCtxData, sizeof(transReq.upCtxData), req->upCtxData, sizeof(req->upCtxData)) != + NN_OK)) { + NN_LOG_ERROR("Failed to post write by copy up ctx data err"); + return NN_INVALID_PARAM; + } + + return reinterpret_cast(ep)->PostWrite(transReq); +} + +int ubs_hcom_ep_post_write_sgl(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request_sgl *req) +{ + VALIDATE_EP(ep); + VALIDATE_REQ(req); + + UBSHcomNetTransSgeIov iov[C_NET_SGE_MAX_IOV]; + if (CopySglInfo(iov, C_NET_SGE_MAX_IOV, req) != NN_OK) { + NN_LOG_ERROR("Failed to post write sgl by copy sgl info err"); + return NN_INVALID_PARAM; + } + + UBSHcomNetTransSglRequest transReq { iov, req->iovCount, req->upCtxSize }; + if (NN_UNLIKELY(memcpy_s(transReq.upCtxData, sizeof(transReq.upCtxData), req->upCtxData, sizeof(req->upCtxData)) != + NN_OK)) { + NN_LOG_ERROR("Failed to post write sgl by copy up ctx data err"); + return NN_INVALID_PARAM; + } + + return reinterpret_cast(ep)->PostWrite(transReq); +} + +int ubs_hcom_ep_wait_completion(ubs_hcom_endpoint ep, int32_t timeout) +{ + VALIDATE_EP(ep); + + return reinterpret_cast(ep)->WaitCompletion(timeout); +} + +int ubs_hcom_ep_receive(ubs_hcom_endpoint ep, int32_t timeout, ubs_hcom_response_context **ctx) +{ + if (ctx == nullptr) { + return NN_INVALID_PARAM; + } + + VALIDATE_EP(ep); + + static thread_local ubs_hcom_response_context context; + UBSHcomNetResponseContext rspContext; + auto result = reinterpret_cast(ep)->Receive(timeout, rspContext); + if (NN_LIKELY(result == NN_OK)) { + context.opCode = rspContext.Header().opCode; + context.seqNo = rspContext.Header().seqNo; + context.msgData = rspContext.Message()->Data(); + context.msgSize = rspContext.Header().dataLength; + + *ctx = &context; + } + + return result; +} + +int ubs_hcom_ep_receive_raw(ubs_hcom_endpoint ep, int32_t timeout, ubs_hcom_response_context **ctx) +{ + if (ctx == nullptr) { + return NN_INVALID_PARAM; + } + + VALIDATE_EP(ep); + + static thread_local ubs_hcom_response_context context; + UBSHcomNetResponseContext rspContext; + auto result = reinterpret_cast(ep)->ReceiveRaw(timeout, rspContext); + if (NN_LIKELY(result == NN_OK)) { + context.opCode = rspContext.Header().opCode; + context.seqNo = rspContext.Header().seqNo; + context.msgData = rspContext.Message()->Data(); + context.msgSize = rspContext.Header().dataLength; + + *ctx = &context; + } + + return result; +} + +int ubs_hcom_ep_receive_raw_sgl(ubs_hcom_endpoint ep, int32_t timeout, ubs_hcom_response_context **ctx) +{ + VALIDATE_EP(ep); + return ubs_hcom_ep_receive_raw(ep, timeout, ctx); +} + +void ubs_hcom_ep_refer(ubs_hcom_endpoint ep) +{ + VALIDATE_EP_NO_RET(ep); + reinterpret_cast(ep)->IncreaseRef(); +} + +void ubs_hcom_ep_close(ubs_hcom_endpoint ep) +{ + VALIDATE_EP_NO_RET(ep); + reinterpret_cast(ep)->Close(); +} + +void ubs_hcom_ep_destroy(ubs_hcom_endpoint ep) +{ + VALIDATE_EP_NO_RET(ep); + reinterpret_cast(ep)->DecreaseRef(); +} + +const char *ubs_hcom_err_str(int16_t errCode) +{ + return ock::hcom::UBSHcomNetErrStr(errCode); +} + +uint64_t ubs_hcom_estimate_encrypt_len(ubs_hcom_endpoint ep, uint64_t rawLen) +{ + VALIDATE_EP(ep); + return reinterpret_cast(ep)->EstimatedEncryptLen(rawLen); +} + +int ubs_hcom_encrypt(ubs_hcom_endpoint ep, const void *rawData, uint64_t rawLen, void *cipher, uint64_t *cipherLen) +{ + VALIDATE_EP(ep); + if (NN_UNLIKELY(cipherLen == nullptr)) { + NN_LOG_ERROR("Failed to encrypt as cipherLen is nullptr"); + return SER_INVALID_PARAM; + } + return reinterpret_cast(ep)->Encrypt(rawData, rawLen, cipher, *cipherLen); +} + +uint64_t ubs_hcom_estimate_decrypt_len(ubs_hcom_endpoint ep, uint64_t cipherLen) +{ + VALIDATE_EP(ep); + return reinterpret_cast(ep)->EstimatedDecryptLen(cipherLen); +} + +int ubs_hcom_decrypt(ubs_hcom_endpoint ep, const void *cipher, uint64_t cipherLen, void *rawData, uint64_t *rawLen) +{ + VALIDATE_EP(ep); + if (NN_UNLIKELY(rawLen == nullptr)) { + NN_LOG_ERROR("Failed to descrypt as rawLen is nullptr"); + return SER_INVALID_PARAM; + } + if (NN_UNLIKELY(rawData == nullptr)) { + NN_LOG_ERROR("Failed to descrypt as rawData is nullptr"); + return SER_INVALID_PARAM; + } + return reinterpret_cast(ep)->Decrypt(cipher, cipherLen, rawData, *rawLen); +} + +int ubs_hcom_send_fds(ubs_hcom_endpoint ep, int fds[], uint32_t len) +{ + VALIDATE_EP(ep); + return reinterpret_cast(ep)->SendFds(fds, len); +} + +int ubs_hcom_receive_fds(ubs_hcom_endpoint ep, int fds[], uint32_t len, int timeoutSec) +{ + VALIDATE_EP(ep); + return reinterpret_cast(ep)->ReceiveFds(fds, len, timeoutSec); +} + +int ubs_hcom_get_remote_uds_info(ubs_hcom_endpoint ep, ubs_hcom_uds_id_info *idInfo) +{ + VALIDATE_EP(ep); + + if (NN_UNLIKELY(idInfo == nullptr)) { + return NN_INVALID_PARAM; + } + + UBSHcomNetUdsIdInfo udsIdInfo {}; + auto result = reinterpret_cast(ep)->GetRemoteUdsIdInfo(udsIdInfo); + if (NN_UNLIKELY(result != NN_OK)) { + return result; + } + idInfo->pid = udsIdInfo.pid; + idInfo->uid = udsIdInfo.uid; + idInfo->gid = udsIdInfo.gid; + return result; +} diff --git a/src/api/capi_v2/hcom_c.h b/src/api/capi_v2/hcom_c.h new file mode 100644 index 0000000000000000000000000000000000000000..6fc801887377f6942bfd7fc6b96607978a6af5f9 --- /dev/null +++ b/src/api/capi_v2/hcom_c.h @@ -0,0 +1,1160 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_CAPI_V2_HCOM_C_H_ +#define HCOM_CAPI_V2_HCOM_C_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define C_NET_SGE_MAX_IOV 4 +#define MAX_IP_LENGTH 16 +#define NET_C_FLAGS_BIT(i) (1UL << (i)) + +/* + * @brief Driver, which include oob & rdma communication & callback etc + */ +typedef uintptr_t ubs_hcom_driver; + +/* + * @brief Endpoint, represent one RDMA connection to dual-direction communication + * + * two side operation, ubs_hcom_ep_post_send + * read operation from remote, ubs_hcom_ep_post_read + * write operation from remote, ubs_hcom_ep_post_write + */ +typedef uintptr_t ubs_hcom_endpoint; + +/* + * @brief RegMemoryRegion, which region memory in RDMA Nic for write/read operation + */ +typedef uintptr_t ubs_hcom_memory_region; + +typedef enum { + NET_C_EP_SELF_POLLING = NET_C_FLAGS_BIT(0), + NET_C_EP_EVENT_POLLING = NET_C_FLAGS_BIT(1) +} ubs_hcom_polling_mode; + +/* + * @brief Request type, part of ubs_hcom_request_context + */ +typedef enum { + C_SENT = 0, + C_SENT_RAW = 1, + C_SENT_RAW_SGL = 2, + C_RECEIVED = 3, + C_RECEIVED_RAW = 4, + C_WRITTEN = 5, + C_READ = 6, + C_SGL_WRITTEN = 7, + C_SGL_READ = 8, +} ubs_hcom_request_type; + +/* + * @brief Worker polling type + * 1 For RDMA: + * C_BUSY_POLLING, means cpu 100% polling no matter there is request done, better performance but cost dedicated CPU + * C_EVENT_POLLING, waiting on OS kernel for request done + * 2 For TCP/UDS + * only event pooling is supported + */ +typedef enum { + C_BUSY_POLLING = 0, + C_EVENT_POLLING = 1, +} ubs_hcom_driver_working_mode; + +typedef enum { + C_DRIVER_RDMA = 0, + C_DRIVER_TCP = 1, + C_DRIVER_UDS = 2, + C_DRIVER_SHM = 3, + C_DRIVER_UBC = 6, +} ubs_hcom_driver_type; + +/* + * @brief DriverOobType working mode + */ +typedef enum { + C_NET_OOB_TCP = 0, + C_NET_OOB_UDS = 1, +} ubs_hcom_driver_oob_type; + +/* + * @brief Enum for secure type + */ +typedef enum { + C_NET_SEC_DISABLED = 0, + C_NET_SEC_ONE_WAY = 1, + C_NET_SEC_TWO_WAY = 2, +} ubs_hcom_driver_sec_type; + +typedef enum { + C_TLS_1_2 = 771, + C_TLS_1_3 = 772, +} ubs_hcom_driver_tls_version; + +/* + * @brief DriverCipherSuite mode + */ +typedef enum { + C_AES_GCM_128 = 0, + C_AES_GCM_256 = 1, + C_AES_CCM_128 = 2, + C_CHACHA20_POLY1305 = 3, +} ubs_hcom_driver_cipher_suite; + +/* + * @brief Memory allocator cache tier policy + */ +typedef enum { + C_TIER_TIMES = 0, /* tier by times of min-block-size */ + C_TIER_POWER = 1, /* tier by power of min-block-size */ +} ubs_hcom_memory_allocator_cache_tier_policy; + +/* + * @brief Enum for tls callback, set peer cert verify type + */ +typedef enum { + C_VERIFY_BY_NONE = 0, + C_VERIFY_BY_DEFAULT = 1, + C_VERIFY_BY_CUSTOM_FUNC = 2, +} ubs_hcom_peer_cert_verify_type; + +/* + * @brief Type of allocator + */ +typedef enum { + C_DYNAMIC_SIZE = 0, /* allocate dynamic memory size, there is alignment with X KB */ + C_DYNAMIC_SIZE_WITH_CACHE = 1, /* allocator with dynamic memory size, with pre-allocate cache for performance */ +} ubs_hcom_memory_allocator_type; + +/* + * @brief Enum for callback register [new endpoint connected or endpoint broken] + */ +typedef enum { + C_EP_NEW = 0, + C_EP_BROKEN = 1, +} ubs_hcom_ep_handler_type; + +/* + * @brief Enum for callback register [request received, request posted, read/write done] + */ +typedef enum { + C_OP_REQUEST_RECEIVED = 0, + C_OP_REQUEST_POSTED = 1, + C_OP_READWRITE_DONE = 2, +} ubs_hcom_op_handler_type; + +/* + * @brief Two side RDMA operation (i.e. RDMA send/receive) + * + * @param data, pointer of data need to send to peer (the data will be copied to register RDMA memory region + * the data must be less than mrSendReceiveSegSize of driver + * @param size, size of data + */ +typedef struct { + uintptr_t data; // pointer of data to send to peer + uint32_t size; // size of data + uint16_t upCtxSize; // user context size + char upCtxData[64]; // user context +} ubs_hcom_send_request; + +typedef struct { + uint32_t seqNo; // seq no + int16_t timeout; // timeout + int16_t errorCode; // error code + uint8_t flags; // flags +} ubs_hcom_opinfo; + +/* + * @brief Device information for user + */ +typedef struct { + int maxSge; // max iov count in UBSHcomNetTransSglRequest +} ubs_hcom_device_info; + +/* + * @brief Read/write request for one side rdma operation + */ +typedef struct { + uintptr_t lMRA; // local memory region address + uintptr_t rMRA; // remote memory region address + uint64_t lKey; // local memory region key + uint64_t rKey; // remote memory region key + uint32_t size; // data size + uint16_t upCtxSize; // user context size + char upCtxData[64]; // user context +} ubs_hcom_readwrite_request; + +typedef struct { + uintptr_t lAddress; // local memory region address + uintptr_t rAddress; // remote memory region address + uint64_t lKey; // local memory region key + uint64_t rKey; // remote memory region key + uint32_t size; // data size +} __attribute__((packed)) ubs_hcom_readwrite_sge; + +typedef struct { + ubs_hcom_readwrite_sge *iov; // sgl array + uint16_t iovCount; // max count:NUM_4 + uint16_t upCtxSize; // user context size + char upCtxData[16]; // user context +} __attribute__((packed)) ubs_hcom_readwrite_request_sgl; + +/* + * @brief Read/write mr info for one side rdma operation + */ +typedef struct { + uintptr_t lAddress; // local memory region address + uint64_t lKey; // local memory region key + uint32_t size; // data size +} ubs_hcom_memory_region_info; + +/* + * @brief Callback function context, for received, post done, read/write done + */ +typedef struct { + ubs_hcom_request_type type; + uint16_t opCode; // for post send + uint16_t flags; // flags on the header + int16_t timeout; // timeout + int16_t errorCode; // error code + int result; // return 0 successful + void *msgData; // for receive operation or C_OP_REQUEST_RECEIVED callback + uint32_t msgSize; // for receive operation or C_OP_REQUEST_RECEIVED callback + uint32_t seqNo; // for post send raw + ubs_hcom_endpoint ep; + ubs_hcom_send_request originalSend; // for C_OP_REQUEST_POSTED, copy struct information, not original + // originalSend.data is self rdma address, not original input data address + ubs_hcom_readwrite_request originalReq; // for C_OP_READWRITE_DONE, copy struct information, not original + ubs_hcom_readwrite_request_sgl originalSglReq; // for C_OP_READWRITE_DONE, copy struct information, not original +} ubs_hcom_request_context; + +typedef struct { + uint16_t opCode; + uint32_t seqNo; + void *msgData; + uint32_t msgSize; +} ubs_hcom_response_context; + +typedef struct { + uint32_t pid; + uint32_t uid; + uint32_t gid; +} ubs_hcom_uds_id_info; + +/* + * @brief Options for driver initialization + */ +typedef struct { + ubs_hcom_driver_working_mode mode; // polling mode + uint32_t mrSendReceiveSegCount; // segment count of segment for send/receive + uint32_t mrSendReceiveSegSize; // single segment size of send/receive memory region + char netDeviceIpMask[256]; // device ip mask, for multiple net device cases + char netDeviceIpGroup[1024]; // ip group for devices + uint16_t completionQueueDepth; // rdma completion queue size + uint16_t maxPostSendCountPerQP; // max post send count + uint16_t prePostReceiveSizePerQP; // pre post receive size for one qp + uint16_t pollingBatchSize; // polling wc size on at one time + uint32_t qpSendQueueSize; // qp send queue size, by default is 256 + uint32_t qpReceiveQueueSize; // qp receive queue size, by default is 256 + uint16_t dontStartWorkers; // start worker or not, 1 means don't start, 0 means start + char workerGroups[64]; // worker groups, for example 1,3,3 + char workerGroupsCpuSet[128]; // worker groups cpu set, for example 1-16 + // worker thread priority [-20,20], 20 is the lowest, -20 is the highest, 0 (default) means do not set priority + int workerThreadPriority; + uint16_t heartBeatIdleTime; // heart beat idle time, in seconds + uint16_t heartBeatProbeTimes; // heart beat probe times, in seconds + uint16_t heartBeatProbeInterval; // heart beat probe interval, in seconds + // timeout during io (s), it should be [-1, 1024], -1 means do not set, 0 means never timeout during io + int16_t tcpUserTimeout; + bool tcpEnableNoDelay; // tcp TCP_NODELAY option, true in default + bool tcpSendZCopy; // tcp whether copy request to inner memory, false in default + /* The buffer sizes will be adjusted automatically when these two variables are 0, and the performance would be + * better */ + uint16_t tcpSendBufSize; // tcp connection send buffer size in kernel, in KB + uint16_t tcpReceiveBufSize; // tcp connection send receive buf size in kernel, in KB + uint16_t enableTls; // value only in 0 and 1, value 1 means enable ssl and encrypt, 0 on the contrary + ubs_hcom_driver_sec_type secType; // security type + ubs_hcom_driver_tls_version tlsVersion; // tls version, default TLS1.3 (772) + ubs_hcom_driver_cipher_suite cipherSuite; // if tls enabled can set cipher suite, client and server should same + ubs_hcom_driver_oob_type oobType; // oob type, tcp or UDS, UDS cannot accept remote connection + uint8_t version; // program version used by connect validation + uint32_t maxConnectionNum; // max connection number + char oobPortRange[16]; // port range when enable port auto selection +} ubs_hcom_driver_opts; + +/* + * @brief Options for multiple listeners + */ +typedef struct { + char ip[16]; // ip to be listened + uint16_t port; // port to be listened + uint16_t targetWorkerCount; // the count of workers can be dispatched to, for connections from this listener +} ubs_hcom_driver_listen_opts; + +/* + * @brief Oob uds listening information + */ +typedef struct { + char name[96]; // UDS name for listen or file path + uint16_t perm; // if 0 means not use file, otherwise use file and this perm as file perm + uint16_t targetWorkerCount; // the count of target workers, if >= 1, + // the accepted socket will be attached to sub set to workers, 0 means all +} ubs_hcom_driver_uds_listen_opts; + +/* + * @brief Callback function definition + * 1) new endpoint connected from client, only need to register this at sever side + * 2) endpoint is broken, called when RDMA qp detection error or broken + */ +typedef int (*ubs_hcom_ep_handler)(ubs_hcom_endpoint ep, uint64_t usrCtx, const char *payLoad); + +/* + * @brief Callback function definition + * + * it is called when the following cases happen + * 1) post send done + * 2) read done + * 3) write done + * + * Important notes: + * 1) ctx is a thread local static variable, cannot transform to another thread directly + * 2) msgData need to copy to another space properly + * 3) ep can be transferred to another thread for further reply or other stuff + * in this case, need to call ubs_hcom_ep_refer() to increase reference count + * and call ubs_hcom_ep_destroy() after to decrease the reference count + */ +typedef int (*ubs_hcom_request_handler)(ubs_hcom_request_context *ctx, uint64_t usrCtx); + +/* + * @brief Idle callback function, when worker thread idle, this function will be called + * + * @param wkrGrpIdx [in] worker group index in on net driver + * @param idxInGrp [in] worker index in the group + * @param usrCtx [in] user context + */ +typedef void (*ubs_hcom_idle_handler)(uint8_t wkrGrpIdx, uint16_t idxInGrp, uint64_t usrCtx); + +/* + * @brief Sec callback function, when oob connect build, this function will be called to generate auth info. + * if this function not set secure type is C_NET_SEC_NO_VALID and oob will not send secure info + * + * @param ctx [in] ctx from connect param ctx, and will send in auth process + * @param flag [out] flag to sent in auth process + * @param type [out] secure type, value should set in oob client, and should in [C_NET_SEC_ONE_WAY, + * C_NET_SEC_TWO_WAY] + * @param output [out] secure info created + * @param outLen [out] secure info length + * @param needAutoFree [out] secure info need to auto free in hcom or not + */ +typedef int (*ubs_hcom_secinfo_provider)(uint64_t ctx, int64_t *flag, ubs_hcom_driver_sec_type *type, char **output, + uint32_t *outLen, int *needAutoFree); + +/* + * @brief ValidateSecInfo callback function, when oob connect build, this function will be called to validate auth info + * if this function not set oob will not validate secure info + * + * @param flag [in] flag received in auth process + * @param ctx [in] ctx received in auth process + * @param input [in] secure info received + * @param inputLen [in] secure info length + */ +typedef int (*ubs_hcom_secinfo_validator)(uint64_t ctx, int64_t flag, const char *input, uint32_t inputLen); + +/* + * @brief keyPass [in] erase function + * @param keyPass [in] the memory address of keyPass + */ +typedef void (*ubs_hcom_tls_keypass_erase)(char *keyPass, int len); + +/* + * @brief The cert verify function + * + * @param x509 [in] the x509 object of CA + * @param crlPath [in] the crl file path + * + * @return -1 for failed, and 1 for success + */ +typedef int (*ubs_hcom_tls_cert_verify)(void *x509, const char *crlPath); + +/* + * @brief Get the certificate file of public key + * + * @param name [out] the name + * @param certPath [out] the path of certificate + */ +typedef int (*ubs_hcom_tls_get_cert_cb)(const char *name, char **certPath); + +/* + * @brief Get private key file's path and length, and get the keyPass + * @param name [out] the name + * @param priKeyPath [out] the path of private key + * @param keyPass [out] the keyPass + * @param erase [out] the erase function + */ +typedef int (*ubs_hcom_tls_get_pk_cb) + (const char *name, char **priKeyPath, char **keyPass, ubs_hcom_tls_keypass_erase *erase); + +/* + * @brief Get the CA and verify + * @param name [out] the name + * @param caPath [out] the path of CA file + * @param crlPath [out] the crl file path + * @param verifyType [out] the type of verify in[VERIFY_BY_NONE,VERIFY_BY_DEFAULT, VERIFY_BY_CUSTOM_FUNC] + * @param verify [out] the verify function, only effect in VERIFY_BY_CUSTOM_FUNC mode + */ +typedef int (*ubs_hcom_tls_get_ca_cb)(const char *name, char **caPath, char **crlPath, + ubs_hcom_peer_cert_verify_type *verifyType, ubs_hcom_tls_cert_verify *verify); + +/* + * @brief External log callback function + * + * @param level [in] level, 0/1/2/3 represent debug/info/warn/error + * @param msg [in] message, log message with name:code-line-number + */ +typedef void (*ubs_hcom_log_handler)(int level, const char *msg); + +/* + * @brief Options for Memory Allocator + */ +typedef struct { + uintptr_t address; /* base address of large range of memory for allocator */ + uint64_t size; /* size of large memory chuck */ + uint32_t minBlockSize; /* min size of block, more than 4 KB is required */ + uint32_t bucketCount; /* default size of hash bucket */ + uint16_t alignedAddress; /* force to align the memory block allocated, 0 means not align + 1 means align */ + uint16_t cacheTierCount; /* for DYNAMIC_SIZE_WITH_CACHE only */ + uint16_t cacheBlockCountPerTier; /* for DYNAMIC_SIZE_WITH_CACHE only */ + ubs_hcom_memory_allocator_cache_tier_policy cacheTierPolicy; /* tier policy */ +} ubs_hcom_memory_allocator_options; + +/* + * @brief memory allocator ptr + */ +typedef uintptr_t ubs_hcom_memory_allocator; + +/* + * @brief Memory allocator create + * + * @param t [in] type of allocator + * @param options [in] options + * @param allocator [out] allocator created + */ +int ubs_hcom_mem_allocator_create(ubs_hcom_memory_allocator_type t, ubs_hcom_memory_allocator_options *options, + ubs_hcom_memory_allocator *allocator); + +/* + * @brief destroy the memory allocator + * + * @param allocator [in] memory allocator + * + * @return 0 if successful + */ +int ubs_hcom_mem_allocator_destroy(ubs_hcom_memory_allocator allocator); + +/* + * @brief Set the memory region key + * @param allocator [in] memory allocator + * + * @return 0 if successful + */ +int ubs_hcom_mem_allocator_set_mr_key(ubs_hcom_memory_allocator allocator, uint64_t mrKey); + +/* + * @brief Get the memory offset based on base address + * + * @param allocator [in] memory allocator + * @param address [in] memory address + * @param offset [out] offset comparing to base address + * + * @return 0 if successful + */ +int ubs_hcom_mem_allocator_get_offset(ubs_hcom_memory_allocator allocator, uintptr_t address, uintptr_t *offset); + +/* + * @brief Get free memory size + * + * @param allocator [in] memory allocator + * + * @return 0 if successful + */ +int ubs_hcom_mem_allocator_get_free_size(ubs_hcom_memory_allocator allocator, uintptr_t *size); + +/* + * @brief Allocate memory area + * + * @param allocator [in] memory allocator + * @param size [in] size of memory of demand + * @param address [out] allocated memory address + * @param key [out] allocated memory key + * + * @return 0 if successful + */ +int ubs_hcom_mem_allocator_allocate(ubs_hcom_memory_allocator allocator, uint64_t size, uintptr_t *address, + uint64_t *key); + +/* + * @brief Free the address allocated by #Allocate function + * + * @param allocator [in] memory allocator + * @param address [in] address to be freed + * + * @return 0 if successful + */ +int ubs_hcom_mem_allocator_free(ubs_hcom_memory_allocator allocator, uintptr_t address); + +/* + * @brief Set external logger function + * + * @param h [in] the log function ptr + */ +void ubs_hcom_set_log_handler(ubs_hcom_log_handler h); + +/* + * @brief Check if local host support certain protocol + * + * @param t [in] driver type + * @param info [out] driver info + * + * @return 1 if supported, 0 if not + */ +int ubs_hcom_check_local_support(ubs_hcom_driver_type t, ubs_hcom_device_info *info); +/* + * @brief Create a driver + * + * @param t [in] type of driver + * @param name [in] the name of driver + * @param startOobSvr [in] 0 or 1, 1 to start Oob server, 0 don't start Oob server + * @param driver [out] created driver address + * + * @return 0, if created successfully + */ +int ubs_hcom_driver_create(ubs_hcom_driver_type t, const char *name, uint8_t startOobSvr, ubs_hcom_driver *driver); + +/* + * @brief Set the out of bound ip and port, for endpoint connection + * + * @param driver [in] the address of driver + * @param ip [in] the ip for listen or connect + * @param port [in] the port for listen or connect + */ +void ubs_hcom_driver_set_ipport(ubs_hcom_driver driver, const char *ip, uint16_t port); + +/* + * @brief Get the out of bound ip and port + * + * @param driver [in] the address of driver + * @param ipArray [out] oob ip list + * @param port [out] oob port list + * @param length [out] the length of ipArray and portArray + */ +bool ubs_hcom_driver_get_ipport(ubs_hcom_driver driver, char ***ipArray, uint16_t **portArray, int *length); + +/* + * @brief Set oob listener of uds type + * + * @param name [in] name of uds listener + * + */ +void ubs_hcom_driver_set_udsname(ubs_hcom_driver driver, const char *name); + +/* + * @brief Add multiple oob uds listeners, if there is only one listener just use OobUdsName + * + * @param option [in] option of uds listener option + * + */ +void ubs_hcom_driver_add_uds_opt(ubs_hcom_driver driver, ubs_hcom_driver_uds_listen_opts option); + +/* + * @brief Add listen option if to enable multiple listener, duplicated ip and port cannot be added + * + * @param driver [in] the address of driver + * @param options [in] the options of the listener + * + */ +void ubs_hcom_driver_add_oob_opt(ubs_hcom_driver driver, ubs_hcom_driver_listen_opts options); + +/* + * @brief Initialize the driver + * + * @param driver [in] the address of driver + * @param options [in] options for initialization + * + * @return 0 if successful + */ +int ubs_hcom_driver_initialize(ubs_hcom_driver driver, ubs_hcom_driver_opts options); + +/* + * @brief Start the driver, start oob accept thread (server only) and RDMA polling thread + * + * @param driver [in] the address of driver + * + * @return 0 if successful + */ +int ubs_hcom_driver_start(ubs_hcom_driver driver); + +/* + * @brief, Connect to another driver (which is server) and new endpoint will be created if successful + * + * There is a retry in it in case of sever is quite busy + * + * @param driver [in] the address of driver + * @param payloadData [in] the payloadData, must be ended with \0, i.e. it is a string + * @param ep [out] the new endpoint created after connect to server + * @param flags [in] flags of ep to be created, NET_C_EP_SELF_POLLING for self polling ep, and + * NET_C_EP_EVENT_POLLING is the self polling mode + * + * @return 0 if successful + */ +int ubs_hcom_driver_connect(ubs_hcom_driver driver, const char *payloadData, ubs_hcom_endpoint *ep, uint32_t flags); + +int ubs_hcom_driver_connect_with_grpno(ubs_hcom_driver driver, const char *payloadData, ubs_hcom_endpoint *ep, + uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo); + +/* + * @brief Connect to another driver (which is server) and new endpoint will be created if successful + * + * There is a retry in it in case of sever is quite busy + * + * @param driver [in] the address of driver + * @param serverIp [in] server ip + * @param serverPort [in] server listen port + * @param payloadData [in] the payloadData, must be ended with \0, i.e. it is a string + * @param ep [out] the new endpoint created after connect to server + * @param flags [in] flags of ep to be created, NET_C_EP_SELF_POLLING for self polling ep, and + * NET_C_EP_EVENT_POLLING is the self polling mode + * + * @return 0 if successful + */ +int ubs_hcom_driver_connect_to_ipport(ubs_hcom_driver driver, const char *serverIp, uint16_t serverPort, + const char *payloadData, ubs_hcom_endpoint *ep, uint32_t flags); + +int ubs_hcom_driver_connect_to_ipport_with_groupno(ubs_hcom_driver driver, const char *serverIp, uint16_t serverPort, + const char *payloadData, ubs_hcom_endpoint *ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo); + +int ubs_hcom_driver_connect_to_ipport_with_ctx(ubs_hcom_driver driver, const char *serverIp, uint16_t serverPort, + const char *payloadData, ubs_hcom_endpoint *ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, + uint64_t ctx); +/* + * @brief Stop the driver + * + * @param driver [in] the address of driver + */ +void ubs_hcom_driver_stop(ubs_hcom_driver driver); + +/* + * @brief Un-initialize the driver + * + * @param driver [in] the address of driver + */ +void ubs_hcom_driver_uninitialize(ubs_hcom_driver driver); + +/* + * @brief Destroy the driver + * + * @param driver [in] the address of driver + * + * @return 0 if destroy successful + */ +int ubs_hcom_driver_destroy(ubs_hcom_driver driver); + +/* + * @brief Register callback function for endpoint + * + * @param driver [in] the address of driver + * @param t [in] handle type, could be C_EP_NEW or C_EP_BROKEN + * @param h [in] callback function address + * + * @return an inner handler address, for un-register in case of memory leak + */ +uintptr_t ubs_hcom_driver_register_ep_handler(ubs_hcom_driver driver, ubs_hcom_ep_handler_type t, + ubs_hcom_ep_handler h, uint64_t usrCtx); + +/* + * @brief Register callback function for endpoint operation + * + * @param driver [in] the address of driver + * @param t [in] handle type, could be C_OP_REQUEST_RECEIVED or C_OP_REQUEST_POSTED or + * C_OP_READWRITE_DONE + * @param h [in] callback function address + * + * @return an inner handler address, for un-register in case of memory leak + */ +uintptr_t ubs_hcom_driver_register_op_handler(ubs_hcom_driver driver, ubs_hcom_op_handler_type t, + ubs_hcom_request_handler h, uint64_t usrCtx); + +/* + * @brief Register callback function for worker idle + * + * @param driver [in] the address of driver + * @param t [in] handler + * @param usrCtx [in] user context, passed to callback function + * + * @return an inner handler address, for un-register in case of memory leak + */ +uintptr_t ubs_hcom_driver_register_idle_handler(ubs_hcom_driver driver, ubs_hcom_idle_handler h, uint64_t usrCtx); + +/* + * @brief Register callback function for create secure info + * + * @param driver [in] the address of driver + * @param provider [in] callback function address + * + * @return an inner handler address, for un-register in case of memory leak + */ +uintptr_t ubs_hcom_driver_register_secinfo_provider(ubs_hcom_driver driver, ubs_hcom_secinfo_provider provider); + +/* + * @brief Register callback function for validate secure info from peer + * + * @param driver [in] the address of driver + * @param validator [in] callback function address + * + * @return an inner handler address, for un-register in case of memory leak + */ +uintptr_t ubs_hcom_driver_register_secinfo_validator(ubs_hcom_driver driver, ubs_hcom_secinfo_validator validator); + +/* + * @brief Register callback function for tls enable + * + * @param driver [in] the address of driver + * @param certCb [in] callback to get cert + * @param priKeyCb [in] callback to get private key + * @param caCb [in] callback to get ca + * + * @return an inner handler address, for un-register in case of memory leak + */ +uintptr_t ubs_hcom_driver_register_tls_cb(ubs_hcom_driver driver, ubs_hcom_tls_get_cert_cb certCb, + ubs_hcom_tls_get_pk_cb priKeyCb, ubs_hcom_tls_get_ca_cb caCb); + +/* + * @brief Un-register callback function for endpoint + * + * @param t [in] handle type, could be C_EP_NEW or C_EP_BROKEN + * @param handle [in] callback function address returned when registered + * + */ +void ubs_hcom_driver_unregister_ep_handler(ubs_hcom_ep_handler_type t, uintptr_t handle); + +/* + * @brief Un-register callback function for endpoint operation + * + * @param t [in] handle type, could be C_OP_REQUEST_RECEIVED or C_OP_REQUEST_POSTED or + * C_OP_READWRITE_DONE + * @param handle [in] callback function address returned when registered + * + */ +void ubs_hcom_driver_unregister_op_handler(ubs_hcom_op_handler_type t, uintptr_t handle); + +/* + * @brief Un-register idle callback + * + * @param handle [in] callback function address returned when registered + */ +void ubs_hcom_driver_unregister_idle_handler(uintptr_t handle); + +/* + * @brief Register a memory region, the memory will be allocated internally + * + * @param driver [in] the address of driver + * @param size [in] size of the memory region + * @param mr [out] memory region registered + * + * @return 0 successful + */ +int ubs_hcom_driver_create_memory_region(ubs_hcom_driver driver, uint64_t size, ubs_hcom_memory_region *mr); + +/* + * @brief Register a memory region, the memory need to be passed in + * + * @param driver [in] the address of driver + * @param address [in] the memory point need to be registered + * @param size [in] size of the memory region + * @param mr [out] memory region registered + * + * @return 0 successful + */ +int ubs_hcom_driver_create_assign_memory_region(ubs_hcom_driver driver, uintptr_t address, uint64_t size, + ubs_hcom_memory_region *mr); + +/* + * @brief Unregister the memory region + * + * @param driver [in] the address of driver + * @param mr [in] memory region registered + * + * @return 0 successful + */ +void ubs_hcom_driver_destroy_memory_region(ubs_hcom_driver driver, ubs_hcom_memory_region mr); + +/* + * @brief Parse the memory region, get info + * + * @param mr [in] memory region registered + * @param info [in] memory region info + * + * @return 0 successful + */ +int ubs_hcom_driver_get_memory_region_info(ubs_hcom_memory_region mr, ubs_hcom_memory_region_info *info); + +/* + * @brief User can set a relative object address to endpoint + * this can be used locally only, not send to peer + * + * @param ep [in] address of ep + * @param ctx [in] context value to set + */ +void ubs_hcom_ep_set_context(ubs_hcom_endpoint ep, uint64_t ctx); + +/* + * @brief Get the relative object address of the endpoint + * + * @param ep [in] address of ep + * + * @return the context set by ubs_hcom_ep_set_context + */ +uint64_t ubs_hcom_ep_get_context(ubs_hcom_endpoint ep); + +#define NET_INVALID_WORKER_INDEX 0xffff +#define NET_INVALID_WORKER_GROUP_INDEX 0xff +/* + * @brief Get worker index from ep, 0xffff is invalid + * + * @param ep [in] address of ep + * + * @return Worker index in the worker group + */ +uint16_t ubs_hcom_ep_get_worker_idx(ubs_hcom_endpoint ep); + +/* + * @brief Get worker group index from ep, 0xff is invalid + * + * @param ep [in] address of ep + * + * @return Group index in the net driver + */ +uint8_t ubs_hcom_ep_get_workergroup_idx(ubs_hcom_endpoint ep); + +/* + * @brief Get ep listen port, 0 is invalid + * + * @param ep [in] address of ep + * + * @return Listen port of the ep accept from + */ +uint32_t ubs_hcom_ep_get_listen_port(ubs_hcom_endpoint ep); + +/* + * @brief Get ep version of peer, the version is transferred when connecting + * + * This could be used for version matching for backward compatibility + * + * @param ep [in] address of ep + * + * @return Version transferred from peer + */ +uint8_t ubs_hcom_ep_version(ubs_hcom_endpoint ep); + +/* + * @brief Set default timeout + * + * 1. timeout = 0: return immediately + * 2. timeout < 0: never timeout, usually set to -1 + * 3. timeout > 0: second precision timeout. + */ +void ubs_hcom_ep_set_timeout(ubs_hcom_endpoint ep, int32_t timeout); + +/* + * @brief Two side RDMA operation, send a data to peer + * + * 1) the callback function 'ubs_hcom_request_handler' will be triggered + * 2) after peer successfully received by RDMA driver, 'ubs_hcom_request_handler' + will be trigger as well, i.e. post done + * + * @param ep [in] the endpoint address + * @param opcode [in] opcode to peer + * @param req [in] request wrappers the data and size + * + * @return 0 for successful + */ +int ubs_hcom_ep_post_send(ubs_hcom_endpoint ep, uint16_t opcode, ubs_hcom_send_request *req); + +/* + * @brief Two side RDMA operation, send a data to peer + * + * 1) the callback function 'ubs_hcom_request_handler' will be triggered + * 2) after peer successfully received by RDMA driver, 'ubs_hcom_request_handler' + will be trigger as well, i.e. post done + * + * @param ep [in] the endpoint address + * @param opcode [in] opcode to peer + * @param req [in] request wrappers the data and size + * @param opInfo [in] opInfo to peer + * + * @return 0 for successful + */ +int ubs_hcom_ep_post_send_with_opinfo(ubs_hcom_endpoint ep, uint16_t opcode, ubs_hcom_send_request *req, + ubs_hcom_opinfo *opInfo); +/* + * @brief Two side RDMA operation, send a data to peer + * + * 1) the callback function 'ubs_hcom_request_handler' will be triggered + * 2) after peer successfully received by RDMA driver, 'ubs_hcom_request_handler' + will be trigger as well, i.e. post done + * + * @param ep [in] the endpoint address + * @param opcode [in] opcode to peer + * @param req [in] request wrappers the data and size + * @param replySeqNo [in] + * + * @return 0 for successful + */ +int ubs_hcom_ep_post_send_with_seqno(ubs_hcom_endpoint ep, uint16_t opcode, ubs_hcom_send_request *req, + uint32_t replySeqNo); + +/* + * @brief Post send a request without opcode and header to peer, peer will be trigger new request callback also + * without opcode and header, this could be used when you have self define header + * + * @param ep [in] the endpoint address + * @param req [in] request information, local address and size is used only, the data is copied, you can + * free it after called + * @param seqNo [in] seq no for peer to reply, must be > 0, peer can get it from context.Header().seqNo, + * for sync client it will be matching request and response + * + * Behavior: + * 1 For RDMA, + * case a) if NET_EP_SELF_POLLING is not set, just issue the send request, not wait for sending request finished + * case b) if NET_EP_SELF_POLLING is set, issue the send request and wait for sending arrived to peer + * + * @return 0 if successful + * + */ +int ubs_hcom_ep_post_send_raw(ubs_hcom_endpoint ep, ubs_hcom_send_request *req, uint32_t seqNo); + +/* + * @brief Post send a request without opcode and header to peer, peer will be trigger new request callback also + * without opcode and header, this could be used when you have self define header + * + * @param request [in] request information, fill with local different MRs, send to the same remote MR by local + * MRs sequence, you can free it after called. rKey/rAddress do not need to assign + * @param seqNo [in] seq no for peer to reply, must be > 0, peer can get it from context.Header().seqNo, + * for sync client it will be matching request and response + * + * Behavior: + * 1 For RDMA, + * case a) if NET_EP_SELF_POLLING is not set, just issue the send request, not wait for sending request finished + * case b) if NET_EP_SELF_POLLING is set, issue the send request and wait for sending arrived to peer + * + * @return 0 if successful + * + */ +int ubs_hcom_ep_post_send_raw_sgl(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request_sgl *req, uint32_t seqNo); + +/* + * @brief Read RDMA operation, read from peer + * + * 1) after peer successfully received by RDMA driver, 'ubs_hcom_request_handler' + will be trigger as well, i.e. read done + * + * @param ep [in] the endpoint address + * @param req [in] request wrappers the data and size + * + * @return 0 for successful + */ +int ubs_hcom_ep_post_read(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request *req); +int ubs_hcom_ep_post_read_sgl(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request_sgl *req); + +/* + * @brief Write RDMA operation, read from peer + * + * 1) after peer successfully received by RDMA driver, 'ubs_hcom_request_handler' + will be trigger as well, i.e. read done + * + * @param ep [in] the endpoint address + * @param req [in] request wrappers the data and size + * + * @return 0 for successful + */ +int ubs_hcom_ep_post_write(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request *req); +int ubs_hcom_ep_post_write_sgl(ubs_hcom_endpoint ep, ubs_hcom_readwrite_request_sgl *req); + +/* + * @brief Wait for send/read/write finish, only for NET_EP_SELF_POLLING is set + * + * @param timeout [in] in second + * 1. timeout = 0: return immediately + * 2. timeout < 0: never timeout, usually set to -1 + * 3. timeout > 0: second precision timeout max is 2000s. + * + * Behavior: + * 1 for send, return when request send to peer + * 2 for read, return when read completion + * 3 for write, return when write completion + * + * @return 0 if successful + * + * NN_TIMEOUT if timeout + */ +int ubs_hcom_ep_wait_completion(ubs_hcom_endpoint ep, int32_t timeout); + +/* + * @brief Get the response for send request reply + * + * @param timeout [in] in second + * 1. timeout = 0: return immediately + * 2. timeout < 0: never timeout, usually set to -1 + * 3. timeout > 0: second precision timeout max is 2000s. + * @param ctx [out] ctx for response message, ctx cannot be freed by caller + * + * @return 0 if successful + */ +int ubs_hcom_ep_receive(ubs_hcom_endpoint ep, int32_t timeout, ubs_hcom_response_context **ctx); + +/* + * @brief Get the response for send request reply + * + * @param timeout [in] in second + * 1. timeout = 0: return immediately + * 2. timeout < 0: never timeout, usually set to -1 + * 3. timeout > 0: second precision timeout max is 2000s. + * @param ctx [out] ctx for response message, ctx cannot be freed by caller + * + * @return 0 if successful + */ +int ubs_hcom_ep_receive_raw(ubs_hcom_endpoint ep, int32_t timeout, ubs_hcom_response_context **ctx); + +/* + * @brief Get the response for send request reply + * + * @param timeout [in] in second + * 1. timeout = 0: return immediately + * 2. timeout < 0: never timeout, usually set to -1 + * 3. timeout > 0: second precision timeout max is 2000s. + * @param ctx [out] ctx for response message, ctx cannot be freed by caller + * + * @return 0 if successful + */ +int ubs_hcom_ep_receive_raw_sgl(ubs_hcom_endpoint ep, int32_t timeout, ubs_hcom_response_context **ctx); + +/* + * @brief Increase the internal reference count, need to call this when forwarding the context to another thread to + * process + * + * @param ep, [in] the endpoint address + */ +void ubs_hcom_ep_refer(ubs_hcom_endpoint ep); + +/* + * @brief Close endpoint, then will async call broken function + */ +void ubs_hcom_ep_close(ubs_hcom_endpoint ep); + +/* + * @brief Destroy the end point + * + * @param ep, [in] the endpoint address + */ +void ubs_hcom_ep_destroy(ubs_hcom_endpoint ep); + +const char *ubs_hcom_err_str(int16_t errCode); + +/* + * @brief Estimated Encrypt length for input raw len + * + * @param ep [in] the endpoint address + * @param rawLen [in] raw length before encrypt + * + * @return the length after encrypt + */ +uint64_t ubs_hcom_estimate_encrypt_len(ubs_hcom_endpoint ep, uint64_t rawLen); + +/* + * @brief Encrypt data + * + * @param ep [in] the endpoint address + * @param rawData [in] raw data before encrypt + * @param rawLen [in] raw data length before encrypt + * @param cipher [out] cipher data after encrypt + * @param cipherLen [out] cipher data length after encrypt + * + * @return 0 if success + */ +int ubs_hcom_encrypt(ubs_hcom_endpoint ep, const void *rawData, uint64_t rawLen, void *cipher, uint64_t *cipherLen); + +/* + * @brief Estimate Decrypt length + * + * @param ep [in] the endpoint address + * @param cipherLen [in] cipher len before decrypt + * + * @return the raw length after decrypt + */ +uint64_t ubs_hcom_estimate_decrypt_len(ubs_hcom_endpoint ep, uint64_t cipherLen); + +/* + * @brief Decrypt data + * + * @param ep [in] the endpoint address + * @param cipher [in] cipher data after encrypt + * @param cipherLen [in] cipher data length after encrypt + * @param rawData [out] raw data before encrypt + * @param rawLen [out] raw data length before encrypt + * + * @return 0 if success + */ +int ubs_hcom_decrypt(ubs_hcom_endpoint ep, const void *cipher, uint64_t cipherLen, void *rawData, uint64_t *rawLen); + +/* + * @brief Send shm fds, only shm protocol support + * + * @param ep [in] the endpoint address + * @param fds [in] fds to send + * @param len [in] fds count to send + * + * @return 0 if success + */ +int ubs_hcom_send_fds(ubs_hcom_endpoint ep, int fds[], uint32_t len); + +/* + * @brief Receive shm fds, only shm protocol support + * + * @param ep [in] the endpoint address + * @param fds [out] fds to be received + * @param len [in] fds count to be received + * @param timeoutSec [in] timeout in second for receive. -1 is never timeout + * + * @return 0 if success + */ +int ubs_hcom_receive_fds(ubs_hcom_endpoint ep, int fds[], uint32_t len, int timeoutSec); + +/* + * @brief Get remote uds ids include pid uid gid, only support in oob server and when oob type is uds + * + * @param ep [in] the endpoint address + * @param idInfo [out] remote uds idInfo + * + * @return 0 if success + */ +int ubs_hcom_get_remote_uds_info(ubs_hcom_endpoint ep, ubs_hcom_uds_id_info *idInfo); +#ifdef __cplusplus +} +#endif + +#endif // HCOM_CAPI_V2_HCOM_C_H_ diff --git a/src/api/capi_v2/hcom_def_inner_c.h b/src/api/capi_v2/hcom_def_inner_c.h new file mode 100644 index 0000000000000000000000000000000000000000..03e38d23ecc0e6e0fc1c45915a1e9191bd32514f --- /dev/null +++ b/src/api/capi_v2/hcom_def_inner_c.h @@ -0,0 +1,506 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_CAPI_V2_HCOM_DEF_INNER_C_H_ +#define HCOM_CAPI_V2_HCOM_DEF_INNER_C_H_ + +#include +#include + +#include "hcom_c.h" +#include "service_v2/api/hcom_service.h" +#include "hcom_service_c.h" +#include "securec.h" + +namespace ock { +namespace hcom { +class EpHdlAdp { +public: + EpHdlAdp(ubs_hcom_ep_handler_type t, ubs_hcom_ep_handler h, uint64_t usrCtx) : mHandlerType(t), + mHandler(h), mUsrCtx(usrCtx) {} + ~EpHdlAdp() + { + mHandler = nullptr; + } + + int NewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) + { + if (NN_UNLIKELY(mHandler == nullptr || newEP.Get() == nullptr || mHandlerType != C_EP_NEW)) { + return NN_INVALID_PARAM; + } + newEP->IncreaseRef(); + return mHandler(reinterpret_cast(newEP.Get()), mUsrCtx, payload.c_str()); + } + + void EndPointBroken(const UBSHcomNetEndpointPtr &ep) + { + if (NN_UNLIKELY(mHandler == nullptr || ep.Get() == nullptr || mHandlerType != C_EP_BROKEN)) { + return; + } + + mHandler(reinterpret_cast(ep.Get()), mUsrCtx, ep->PeerConnectPayload().c_str()); + } + +private: + ubs_hcom_ep_handler_type mHandlerType; + ubs_hcom_ep_handler mHandler = nullptr; + uint64_t mUsrCtx = 0; +}; + +class EpOpHdlAdp { +public: + explicit EpOpHdlAdp(ubs_hcom_request_handler handler, uint64_t usrCtx) : mHandler(handler), mUsrCtx(usrCtx) {} + + ~EpOpHdlAdp() + { + mHandler = nullptr; + } + + static void BuildRequestCommonFiled(const UBSHcomNetRequestContext &ctx, ubs_hcom_request_context &localCtx) + { + (localCtx).result = static_cast((ctx).Result()); + (localCtx).type = static_cast((ctx).OpType()); + (localCtx).opCode = (ctx).Header().opCode; + (localCtx).ep = reinterpret_cast((ctx).EndPoint().Get()); + (localCtx).seqNo = (ctx).Header().seqNo; + (localCtx).flags = (ctx).Header().flags; + (localCtx).errorCode = (ctx).Header().errorCode; + (localCtx).timeout = (ctx).Header().timeout; + if ((ctx).Message() != nullptr) { + (localCtx).msgData = (ctx).Message()->Data(); + (localCtx).msgSize = (ctx).Header().dataLength; + } + } + + int Requested(const UBSHcomNetRequestContext &ctx) + { + if (NN_UNLIKELY(mHandler == nullptr)) { + return NN_INVALID_PARAM; + } + + static thread_local ubs_hcom_request_context localCtx {}; + bzero(&localCtx, sizeof(ubs_hcom_request_context)); + BuildRequestCommonFiled(ctx, localCtx); + + if (ctx.OpType() == UBSHcomNetRequestContext::NN_SENT || + ctx.OpType() == UBSHcomNetRequestContext::NN_SENT_RAW) { + localCtx.originalSend.data = 0; + localCtx.originalSend.size = ctx.OriginalRequest().size; + localCtx.originalSend.upCtxSize = ctx.OriginalRequest().upCtxSize; + if (NN_UNLIKELY(memcpy_s(localCtx.originalSend.upCtxData, sizeof(localCtx.originalSend.upCtxData), + ctx.OriginalRequest().upCtxData, sizeof(ctx.OriginalRequest().upCtxData)) != NN_OK)) { + NN_LOG_ERROR("Failed to copy up ctx data"); + return NN_INVALID_PARAM; + } + } else if (ctx.OpType() == UBSHcomNetRequestContext::NN_WRITTEN || + ctx.OpType() == UBSHcomNetRequestContext::NN_READ) { + localCtx.originalReq.lMRA = ctx.OriginalRequest().lAddress; + localCtx.originalReq.rMRA = ctx.OriginalRequest().rAddress; + localCtx.originalReq.lKey = ctx.OriginalRequest().lKey; + localCtx.originalReq.rKey = ctx.OriginalRequest().rKey; + localCtx.originalReq.size = ctx.OriginalRequest().size; + localCtx.originalReq.upCtxSize = ctx.OriginalRequest().upCtxSize; + if (NN_UNLIKELY(memcpy_s(localCtx.originalReq.upCtxData, sizeof(localCtx.originalReq.upCtxData), + ctx.OriginalRequest().upCtxData, sizeof(ctx.OriginalRequest().upCtxData)) != NN_OK)) { + NN_LOG_ERROR("Failed to copy up ctx data"); + return NN_INVALID_PARAM; + } + } else if (ctx.OpType() == UBSHcomNetRequestContext::NN_SGL_WRITTEN || + ctx.OpType() == UBSHcomNetRequestContext::NN_SGL_READ || + ctx.OpType() == UBSHcomNetRequestContext::NN_SENT_RAW_SGL) { + localCtx.originalSglReq.iov = reinterpret_cast(ctx.OriginalSgeRequest().iov); + localCtx.originalSglReq.iovCount = ctx.OriginalSgeRequest().iovCount; + localCtx.originalSglReq.upCtxSize = ctx.OriginalSgeRequest().upCtxSize; + if (NN_UNLIKELY(memcpy_s(localCtx.originalSglReq.upCtxData, sizeof(localCtx.originalSglReq.upCtxData), + ctx.OriginalSgeRequest().upCtxData, sizeof(ctx.OriginalSgeRequest().upCtxData)) != NN_OK)) { + NN_LOG_ERROR("Failed to copy up ctx data"); + return NN_INVALID_PARAM; + } + } + + return mHandler(&localCtx, mUsrCtx); + } + +private: + ubs_hcom_request_handler mHandler = nullptr; + uint64_t mUsrCtx = 0; +}; + +class OOBSecInfoProviderAdp { +public: + explicit OOBSecInfoProviderAdp(ubs_hcom_secinfo_provider provider) : mProvider(provider) {} + + ~OOBSecInfoProviderAdp() + { + mProvider = nullptr; + } + int CreateSecInfo(uint64_t ctx, int64_t &flag, UBSHcomNetDriverSecType &type, char *&output, uint32_t &outLen, + bool &needAutoFree) + { + if (NN_UNLIKELY(mProvider == nullptr)) { + return -1; + } + + auto driSecType = static_cast(0); + int needFree = 0; + auto ret = mProvider(ctx, &flag, &driSecType, &output, &outLen, &needFree); + if (ret != 0) { + return ret; + } + + if (driSecType == C_NET_SEC_DISABLED) { + type = NET_SEC_DISABLED; + } else if (driSecType == C_NET_SEC_ONE_WAY) { + type = NET_SEC_VALID_ONE_WAY; + } else if (driSecType == C_NET_SEC_TWO_WAY) { + type = NET_SEC_VALID_TWO_WAY; + } + + if (needFree) { + needAutoFree = true; + } + + return ret; + } + +private: + ubs_hcom_secinfo_provider mProvider = nullptr; +}; + +class OOBSecInfoValidatorAdp { +public: + explicit OOBSecInfoValidatorAdp(ubs_hcom_secinfo_validator validator) : mValidator(validator) {} + + ~OOBSecInfoValidatorAdp() + { + mValidator = nullptr; + } + + int SecInfoValidate(uint64_t ctx, int64_t flag, const char *input, uint32_t inputLen) + { + if (NN_UNLIKELY(mValidator == nullptr)) { + return -1; + } + return mValidator(ctx, flag, input, inputLen); + } + +private: + ubs_hcom_secinfo_validator mValidator = nullptr; +}; + +class EpIdleHdlAdp { +public: + explicit EpIdleHdlAdp(ubs_hcom_idle_handler handler, uint64_t usrCtx) : mHandler(handler), mUsrCtx(usrCtx) {} + + ~EpIdleHdlAdp() + { + mHandler = nullptr; + } + + void Idle(const UBSHcomNetWorkerIndex &index) + { + if (NN_UNLIKELY(mHandler == nullptr)) { + return; + } + + mHandler(index.grpIdx, index.idxInGrp, mUsrCtx); + } + +private: + ubs_hcom_idle_handler mHandler = nullptr; + uint64_t mUsrCtx = 0; +}; + +class EpTLSHdlAdp { +public: + EpTLSHdlAdp() = default; + + ~EpTLSHdlAdp() = default; + + inline void SetTLSCertCb(ubs_hcom_tls_get_cert_cb h) + { + mGetCert = h; + } + + inline void SetTLSCaCb(ubs_hcom_tls_get_ca_cb h) + { + mGetCA = h; + } + + inline void SetTLSPrivateKeyCb(ubs_hcom_tls_get_pk_cb h) + { + mGetPriKey = h; + } + + bool UBSHcomTLSPrivateKeyCallback(const std::string &name, std::string &path, void *&keyPass, int len, + ::ock::hcom::UBSHcomTLSEraseKeypass &callback) + { + if (NN_UNLIKELY(mGetPriKey == nullptr)) { + return false; + } + + char *privateKeyPath = nullptr; + char *keyPassWd = nullptr; + ubs_hcom_tls_keypass_erase erase; + + mGetPriKey(name.c_str(), &privateKeyPath, &keyPassWd, &erase); + + if (NN_UNLIKELY(privateKeyPath == nullptr) || NN_UNLIKELY(keyPassWd == nullptr) || + NN_UNLIKELY(erase == nullptr)) { + NN_LOG_INFO("Failed to get private key, key pass or erase function from callback"); + return false; + } + + path = privateKeyPath; + keyPass = keyPassWd; + callback = std::bind(&EraseCB, erase, std::placeholders::_1, std::placeholders::_2); + + return true; + } + + bool UBSHcomTLSCertificationCallback(const std::string &name, std::string &path) + { + if (NN_UNLIKELY(mGetCert == nullptr)) { + return false; + } + + char *certPath = nullptr; + mGetCert(name.c_str(), &certPath); + if (NN_UNLIKELY(certPath == nullptr)) { + NN_LOG_INFO("Failed to get cert path from TLS cert callback."); + return false; + } + + path = certPath; + + return true; + } + + static void EraseCB(ubs_hcom_tls_keypass_erase erase, void *pw, int len) + { + erase(reinterpret_cast(pw), len); + } + + bool UBSHcomTLSCaCallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, ::ock::hcom::UBSHcomTLSCertVerifyCallback &callback) + { + if (NN_UNLIKELY(mGetCA == nullptr)) { + return false; + } + + char *caPth = nullptr; + char *crlPth = nullptr; + ubs_hcom_tls_cert_verify verifyCb = nullptr; + ubs_hcom_peer_cert_verify_type verifyType = C_VERIFY_BY_DEFAULT; + + mGetCA(name.c_str(), &caPth, &crlPth, &verifyType, &verifyCb); + + if (NN_UNLIKELY(caPth == nullptr)) { + NN_LOG_INFO("Failed to get CA path from callback"); + return false; + } + + caPath = caPth; + + if (crlPth != nullptr) { + crlPath = crlPth; + } + + callback = verifyCb; + + if (verifyType == C_VERIFY_BY_NONE) { + peerCertVerifyType = VERIFY_BY_NONE; + } else if (verifyType == C_VERIFY_BY_DEFAULT) { + peerCertVerifyType = VERIFY_BY_DEFAULT; + } else if (verifyType == C_VERIFY_BY_CUSTOM_FUNC) { + peerCertVerifyType = VERIFY_BY_CUSTOM_FUNC; + } + return true; + } + +private: + ubs_hcom_tls_get_cert_cb mGetCert = nullptr; + ubs_hcom_tls_get_pk_cb mGetPriKey = nullptr; + ubs_hcom_tls_get_ca_cb mGetCA = nullptr; +}; + +class HdlMgr { +public: + void AddHdlAdp(uintptr_t adp) + { + std::lock_guard guard(mMutex); + auto iterator = mHdlAdpSet.find(adp); + if (iterator != mHdlAdpSet.end()) { + return; + } + + mHdlAdpSet.insert(adp); + } + + template void RemoveHdlAdp(uintptr_t adp) + { + uintptr_t hdlAddr = 0; + { + std::lock_guard guard(mMutex); + auto iterator = mHdlAdpSet.find(adp); + if (iterator == mHdlAdpSet.end()) { + return; + } + + hdlAddr = *iterator; + mHdlAdpSet.erase(iterator); + } + + if (hdlAddr == 0) { + NN_LOG_ERROR("Remove handle not found"); + return; + } + + auto adpPtr = reinterpret_cast(hdlAddr); + delete adpPtr; + } + +private: + std::mutex mMutex; + std::unordered_set mHdlAdpSet; +}; + +class ServiceHdlAdp { +public: + ServiceHdlAdp(ubs_hcom_service_channel_handler_type t, ubs_hcom_service_channel_policy p, + ubs_hcom_service_channel_handler h, uint64_t usrCtx) + : mHandlerType(t), mHandler(h), mUsrCtx(usrCtx) {} + ServiceHdlAdp(ubs_hcom_service_channel_handler_type t, ubs_hcom_service_channel_handler h, uint64_t usrCtx) + : mHandlerType(t), mHandler(h), mUsrCtx(usrCtx) {} + + ~ServiceHdlAdp() + { + mHandler = nullptr; + } + + int NewChannel(const std::string &ipPort, const UBSHcomChannelPtr &newCh, const std::string &payload) + { + if (NN_UNLIKELY(mHandler == nullptr || newCh.Get() == nullptr || mHandlerType != C_CHANNEL_NEW)) { + return NN_INVALID_PARAM; + } + + // increase ref and need call ep_destroy + newCh->IncreaseRef(); + + return mHandler(reinterpret_cast(newCh.Get()), mUsrCtx, payload.c_str()); + } + + void ChannelBroken(const UBSHcomChannelPtr &ch) + { + if (NN_UNLIKELY(mHandler == nullptr || ch.Get() == nullptr || mHandlerType != C_CHANNEL_BROKEN)) { + return; + } + + mHandler(reinterpret_cast(ch.Get()), mUsrCtx, ch->GetPeerConnectPayload().c_str()); + } + +private: + ubs_hcom_service_channel_handler_type mHandlerType; + ubs_hcom_service_channel_handler mHandler = nullptr; + uint64_t mUsrCtx = 0; +}; + +class ServiceIdleHdlAdp { +public: + explicit ServiceIdleHdlAdp(ubs_hcom_idle_handler handler, uint64_t usrCtx) + : mServiceHandler(handler), mUsrCtx(usrCtx) {} + ~ServiceIdleHdlAdp() + { + mServiceHandler = nullptr; + } + void Idle(const UBSHcomNetWorkerIndex &index) + { + if (NN_UNLIKELY(mServiceHandler == nullptr)) { + return; + } + mServiceHandler(index.grpIdx, index.idxInGrp, mUsrCtx); + } + +private: + ubs_hcom_idle_handler mServiceHandler = nullptr; + uint64_t mUsrCtx = 0; +}; + +class ChannelOpHdlAdp { +public: + explicit ChannelOpHdlAdp(ubs_hcom_service_request_handler handler, uint64_t usrCtx) + : mHandler(handler), mUsrCtx(usrCtx) {} + + ~ChannelOpHdlAdp() + { + mHandler = nullptr; + } + int Requested(const UBSHcomServiceContext &ctx) + { + if (NN_UNLIKELY(mHandler == nullptr)) { + return NN_INVALID_PARAM; + } + + return mHandler(reinterpret_cast(&ctx), mUsrCtx); + } + +private: + ubs_hcom_service_request_handler mHandler = nullptr; + uint64_t mUsrCtx = 0; +}; + +template +class ServiceHdlMgr { +public: + // 添加属于特定Service的指针 + void AddHdlAdp(uintptr_t service, uintptr_t adp) + { + std::lock_guard guard(mMutex); + + // 自动为不存在的Service创建条目 + auto& adpSet = mHdlAdpMap[service]; + if (adpSet.find(adp) != adpSet.end()) { + return; // 已存在则不重复添加 + } + adpSet.insert(adp); + } + + void RemoveAll(uintptr_t svc) + { + std::unordered_set adpToDelete; + + { // 临界区开始 + std::lock_guard guard(mMutex); + auto svcIter = mHdlAdpMap.find(svc); + if (svcIter == mHdlAdpMap.end()) { + return; // Service不存在 + } + + // 转移所有权到临时集合 + adpToDelete = std::move(svcIter->second); + mHdlAdpMap.erase(svcIter); // 立即移除Service条目 + } // 临界区结束 + + // 在锁外执行资源释放 + for (auto& adp : adpToDelete) { + if (adp != 0) { // 防御性检查 + delete reinterpret_cast(adp); + } + } + } + +private: + std::mutex mMutex; + std::unordered_map> mHdlAdpMap; +}; +} +} +#endif // HCOM_CAPI_V2_HCOM_DEF_INNER_C_H_ diff --git a/src/api/capi_v2/hcom_service_c.cpp b/src/api/capi_v2/hcom_service_c.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4d44b74a1d9d6e80641a26563227f7317db53de1 --- /dev/null +++ b/src/api/capi_v2/hcom_service_c.cpp @@ -0,0 +1,1075 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "hcom_service_c.h" +#include +#include "hcom_def_inner_c.h" +#include "hcom_err.h" +#include "api/hcom_service_def.h" +#include "api/hcom_service.h" +#include "api/hcom_service_channel.h" +#include "api/hcom_service_context.h" +#include "net_common.h" + +using namespace ock::hcom; + +#define VALIDATE_SERVICE(service) \ + if (NN_UNLIKELY((service) == 0)) { \ + NN_LOG_ERROR("Invalid param, service must be correct address"); \ + return SER_INVALID_PARAM; \ + } + +#define VALIDATE_SERVICE_NO_RET(service) \ + if (NN_UNLIKELY((service) == 0)) { \ + NN_LOG_ERROR("Invalid param, service must be correct address"); \ + return; \ + } + +#define VALIDATE_CHANNEL(channel) \ + if (NN_UNLIKELY((channel) == 0)) { \ + NN_LOG_ERROR("Invalid param, channel must be correct address"); \ + return SER_INVALID_PARAM; \ + } + +#define VALIDATE_HANDLER(h) \ + if (NN_UNLIKELY((h) == nullptr)) { \ + NN_LOG_ERROR("Invalid param, handler must be correct address"); \ + return 0; \ + } + +#define VALIDATE_HANDLER_NO_RET(h) \ + if (NN_UNLIKELY((h) == nullptr)) { \ + NN_LOG_ERROR("Invalid param, handler must be correct address"); \ + return; \ + } + +#define VALIDATE_MR(mr) \ + if (NN_UNLIKELY((mr) == 0)) { \ + NN_LOG_ERROR("Invalid param, mr must be correct mr address"); \ + return SER_INVALID_PARAM; \ + } + +#define VALIDATE_INFO(info) \ + if (NN_UNLIKELY((info) == 0)) { \ + NN_LOG_ERROR("Invalid param, info must be correct mr address"); \ + return SER_INVALID_PARAM; \ + } + +#define VALIDATE_MR_POINT(mr) \ + if (NN_UNLIKELY((mr) == nullptr)) { \ + NN_LOG_ERROR("Invalid param, mr pointer must be correct address"); \ + return SER_INVALID_PARAM; \ + } + +#define VALIDATE_MR_ADDRESS(address) \ + if (NN_UNLIKELY((address) == 0)) { \ + NN_LOG_ERROR("Invalid param, mr address must be correct address"); \ + return SER_INVALID_PARAM; \ + } + +#define VALIDATE_MR_SIZE(size) \ + if (NN_UNLIKELY((size) == 0)) { \ + NN_LOG_ERROR("Invalid param, mr size must be correct size"); \ + return SER_INVALID_PARAM; \ + } + +#define VALIDATE_CHANNEL(channel) \ + if (NN_UNLIKELY((channel) == 0)) { \ + NN_LOG_ERROR("Invalid param, channel must be correct address"); \ + return SER_INVALID_PARAM; \ + } + +#define VALIDATE_CHANNEL_NO_RET(channel) \ + if (NN_UNLIKELY((channel) == 0)) { \ + NN_LOG_ERROR("Invalid param, channel must be correct address"); \ + return; \ + } + +#define VALIDATE_MESSAGE(req) \ + if (NN_UNLIKELY((req) == nullptr)) { \ + NN_LOG_ERROR("Invalid param, message must be correct address"); \ + return SER_INVALID_PARAM; \ + } + +#define COPY_ONESIDE_KEY(input, output) \ + for (uint32_t i = 0; i < NN_NO4; i++) { \ + (output).keys[i] = (input).keys[i]; \ + (output).tokens[i] = (input).tokens[i]; \ + } + +#define VALIDATE_CONTEXT(context) \ + if (NN_UNLIKELY((context) == 0)) { \ + NN_LOG_ERROR("Invalid param, context should be correct address"); \ + return SER_INVALID_PARAM; \ + } + +#define VALIDATE_CONTEXT_RETURN_PTR(context) \ + if (NN_UNLIKELY((context) == 0)) { \ + NN_LOG_ERROR("Invalid param, context should be correct address"); \ + return nullptr; \ + } + +#define VALIDATE_CONTEXT_RETURN_ZERO(context) \ + if (NN_UNLIKELY((context) == 0)) { \ + NN_LOG_ERROR("Invalid param, context should be correct address"); \ + return 0; \ + } + +static ServiceHdlMgr g_serviceHandlerManager; +static ServiceHdlMgr g_serviceIdleHandlerManager; +static ServiceHdlMgr g_channelHandlerManager; +static ServiceHdlMgr g_secProVider; +static ServiceHdlMgr g_secValidator; +static ServiceHdlMgr g_TlsHdl; + +static bool IsNumberStr(const std::string &str) +{ + std::regex pattern("^[0-9]+$"); + return std::regex_match(str, pattern); +} + +static bool ConvertCpuIdsRangeStrToPair(const char *cpuIdsStr, std::pair &cpuIdsPair) +{ + std::string cpuRangeStr = std::string(cpuIdsStr); // 1-2 + if (cpuRangeStr.empty()) { + return true; + } + if (NN_UNLIKELY(NetFunc::NN_ValidateName(cpuRangeStr) != NN_OK)) { + NN_LOG_ERROR("Invalid cpu id"); + return false; + } + std::string::size_type pos = cpuRangeStr.find("-"); + if (pos == std::string::npos) { + NN_LOG_ERROR("Invalid workerGroupCpuRange: " << cpuRangeStr); + return false; + } + std::string beginNumStr = cpuRangeStr.substr(0, pos); + std::string endNumStr = cpuRangeStr.substr(pos + 1); + if (NN_UNLIKELY(!IsNumberStr(beginNumStr) || !IsNumberStr(endNumStr))) { + NN_LOG_ERROR("Invalid workerGroupCpuRange: " << cpuRangeStr); + return false; + } + + long beginId = 0; + long endId = 0; + if (!NetFunc::NN_Stol(beginNumStr, beginId) || !NetFunc::NN_Stol(endNumStr, endId)) { + NN_LOG_ERROR("Invalid begin id " << beginNumStr << " or end id: " << endNumStr); + return false; + } + cpuIdsPair = {static_cast(beginId), static_cast(endId)}; + NN_LOG_DEBUG("Convert Cpu ids pair:" << beginId << "," << endId); + return true; +} + +static bool ConvertServiceOptionsToInnerOptions(const ubs_hcom_service_options &options, + UBSHcomServiceOptions &innerOptions) +{ + innerOptions.maxSendRecvDataSize = + options.maxSendRecvDataSize != 0 ? options.maxSendRecvDataSize : NN_NO1024 ; + innerOptions.workerGroupId = options.workerGroupId; + innerOptions.workerGroupThreadCount = + options.workerGroupThreadCount != 0 ? options.workerGroupThreadCount : NN_NO1; + if (options.workerGroupMode == C_SERVICE_BUSY_POLLING) { + innerOptions.workerGroupMode = UBSHcomWorkerMode::NET_BUSY_POLLING; + } else if (options.workerGroupMode == C_SERVICE_EVENT_POLLING) { + innerOptions.workerGroupMode = UBSHcomWorkerMode::NET_EVENT_POLLING; + } + + std::pair cpuIdsRange = {UINT32_MAX, UINT32_MAX}; + if (NN_UNLIKELY(!ConvertCpuIdsRangeStrToPair(options.workerGroupCpuRange, cpuIdsRange))) { + NN_LOG_ERROR("Invalid cpuIdsRange, for example: 1-2 means cpu 1 to cpu 2"); + return false; + } + innerOptions.workerGroupCpuIdsRange = cpuIdsRange; + return true; +} + +static void ConvertServiceConnectOptionsToInnerOptions(const ubs_hcom_service_connect_options &options, + UBSHcomConnectOptions &innerOptions) +{ + innerOptions.clientGroupId = options.clientGroupId; + innerOptions.serverGroupId = options.serverGroupId; + innerOptions.linkCount = options.linkCount; + if (options.mode == C_CLIENT_WORKER_POLL) { + innerOptions.mode = UBSHcomClientPollingMode::WORKER_POLL; + } else if (options.mode == C_CLIENT_SELF_POLL_BUSY) { + innerOptions.mode = UBSHcomClientPollingMode::SELF_POLL_BUSY; + } else if (options.mode == C_CLIENT_SELF_POLL_EVENT) { + innerOptions.mode = UBSHcomClientPollingMode::SELF_POLL_EVENT; + } + if (options.cbType == C_CHANNEL_FUNC_CB) { + innerOptions.cbType = UBSHcomChannelCallBackType::CHANNEL_FUNC_CB; + } else if (options.cbType == C_CHANNEL_GLOBAL_CB) { + innerOptions.cbType = UBSHcomChannelCallBackType::CHANNEL_GLOBAL_CB; + } + innerOptions.payload = NN_CHAR_ARRAY_TO_STRING(options.payLoad); +} + +static void ConvertServiceTypeToInnerServiceProto(ubs_hcom_service_type t, UBSHcomServiceProtocol &proto) +{ + switch (t) { + case ubs_hcom_service_type::C_SERVICE_RDMA: + proto = UBSHcomServiceProtocol::RDMA; + break; + case ubs_hcom_service_type::C_SERVICE_TCP: + proto = UBSHcomServiceProtocol::TCP; + break; + case ubs_hcom_service_type::C_SERVICE_UDS: + proto = UBSHcomServiceProtocol::UDS; + break; + case ubs_hcom_service_type::C_SERVICE_SHM: + proto = UBSHcomServiceProtocol::SHM; + break; + case ubs_hcom_service_type::C_SERVICE_UBC: + proto = UBSHcomServiceProtocol::UBC; + break; + default: + proto = UBSHcomServiceProtocol::UNKNOWN; + break; + } +} + +void ubs_hcom_channel_refer(ubs_hcom_channel channel) +{ + VALIDATE_CHANNEL_NO_RET(channel) + reinterpret_cast(channel)->IncreaseRef(); +} + +void ubs_hcom_channel_derefer(ubs_hcom_channel channel) +{ + VALIDATE_CHANNEL_NO_RET(channel) + reinterpret_cast(channel)->DecreaseRef(); +} + +int ubs_hcom_channel_send(ubs_hcom_channel channel, ubs_hcom_channel_request req, ubs_hcom_channel_callback *cb) +{ + VALIDATE_CHANNEL(channel) + + UBSHcomRequest request(req.address, req.size, req.opcode); + auto innerChannel = reinterpret_cast(channel); + + if (cb == nullptr) { + return innerChannel->Send(request, nullptr); + } + + ubs_hcom_channel_cb_func cbFunc = cb->cb; + void *arg = cb->arg; + Callback *newCallback = UBSHcomNewCallback( + [cbFunc, arg] + (UBSHcomServiceContext &context) { cbFunc(arg, reinterpret_cast(&context)); }, + std::placeholders::_1); + if (NN_UNLIKELY(newCallback == nullptr)) { + NN_LOG_ERROR("ubs_hcom_channel_send malloc callback failed"); + return SER_NEW_OBJECT_FAILED; + } + + auto result = innerChannel->Send(request, newCallback); + if (NN_UNLIKELY(result != SER_OK)) { + delete newCallback; + return result; + } + + return SER_OK; +} + +int ubs_hcom_channel_call(ubs_hcom_channel channel, ubs_hcom_channel_request req, ubs_hcom_channel_response *rsp, + ubs_hcom_channel_callback *cb) +{ + VALIDATE_CHANNEL(channel) + VALIDATE_MESSAGE(rsp) + UBSHcomRequest request(req.address, req.size, req.opcode); + UBSHcomResponse response(rsp->address, rsp->size); + auto innerChannel = reinterpret_cast(channel); + + SerResult ret = SER_OK; + if (cb == nullptr) { + ret = innerChannel->Call(request, response, nullptr); + rsp->address = response.address; + rsp->size = response.size; + rsp->errorCode = response.errorCode; + return ret; + } + + ubs_hcom_channel_cb_func cbFunc = cb->cb; + void *arg = cb->arg; + Callback *newCallback = UBSHcomNewCallback( + [cbFunc, arg] + (UBSHcomServiceContext &context) { cbFunc(arg, reinterpret_cast(&context)); }, + std::placeholders::_1); + if (NN_UNLIKELY(newCallback == nullptr)) { + NN_LOG_ERROR("ubs_hcom_channel_call malloc callback failed"); + return SER_NEW_OBJECT_FAILED; + } + + auto result = innerChannel->Call(request, response, newCallback); + if (NN_UNLIKELY(result != SER_OK)) { + delete newCallback; + return result; + } + + return SER_OK; +} + +int ubs_hcom_channel_reply(ubs_hcom_channel channel, ubs_hcom_channel_request req, ubs_hcom_channel_reply_context ctx, + ubs_hcom_channel_callback *cb) +{ + VALIDATE_CHANNEL(channel) + + UBSHcomRequest request(req.address, req.size, req.opcode); + UBSHcomReplyContext replyCtx(reinterpret_cast(ctx.rspCtx), ctx.errorCode); + auto innerChannel = reinterpret_cast(channel); + + if (cb == nullptr) { + return innerChannel->Reply(replyCtx, request, nullptr); + } + + ubs_hcom_channel_cb_func cbFunc = cb->cb; + void *arg = cb->arg; + Callback *newCallback = UBSHcomNewCallback( + [cbFunc, arg] + (UBSHcomServiceContext &context) { cbFunc(arg, reinterpret_cast(&context)); }, + std::placeholders::_1); + if (NN_UNLIKELY(newCallback == nullptr)) { + NN_LOG_ERROR("ubs_hcom_channel_reply malloc callback failed"); + return SER_NEW_OBJECT_FAILED; + } + + auto result = innerChannel->Reply(replyCtx, request, newCallback); + if (NN_UNLIKELY(result != SER_OK)) { + delete newCallback; + return result; + } + + return SER_OK; +} + +int ubs_hcom_channel_put(ubs_hcom_channel channel, ubs_hcom_oneside_request req, ubs_hcom_channel_callback *cb) +{ + VALIDATE_CHANNEL(channel) + + auto innerChannel = reinterpret_cast(channel); + UBSHcomOneSideRequest oneSideReq {}; + oneSideReq.lAddress = reinterpret_cast(req.lAddress); + COPY_ONESIDE_KEY(req.lKey, oneSideReq.lKey); + oneSideReq.rAddress = reinterpret_cast(req.rAddress); + COPY_ONESIDE_KEY(req.rKey, oneSideReq.rKey); + oneSideReq.size = req.size; + + if (cb == nullptr) { + return innerChannel->Put(oneSideReq, nullptr); + } + + ubs_hcom_channel_cb_func cbFunc = cb->cb; + void *arg = cb->arg; + Callback *newCallback = UBSHcomNewCallback( + [cbFunc, arg] + (UBSHcomServiceContext &context) { cbFunc(arg, reinterpret_cast(&context)); }, + std::placeholders::_1); + if (NN_UNLIKELY(newCallback == nullptr)) { + NN_LOG_ERROR("ubs_hcom_channel_put malloc callback failed"); + return SER_NEW_OBJECT_FAILED; + } + auto result = innerChannel->Put(oneSideReq, newCallback); + if (NN_UNLIKELY(result != SER_OK)) { + delete newCallback; + return result; + } + + return SER_OK; +} + +int ubs_hcom_channel_get(ubs_hcom_channel channel, ubs_hcom_oneside_request req, ubs_hcom_channel_callback *cb) +{ + VALIDATE_CHANNEL(channel) + + auto innerChannel = reinterpret_cast(channel); + UBSHcomOneSideRequest oneSideReq {}; + oneSideReq.lAddress = reinterpret_cast(req.lAddress); + oneSideReq.rAddress = reinterpret_cast(req.rAddress); + COPY_ONESIDE_KEY(req.lKey, oneSideReq.lKey); + COPY_ONESIDE_KEY(req.rKey, oneSideReq.rKey); + oneSideReq.size = req.size; + + if (cb == nullptr) { + return innerChannel->Get(oneSideReq, nullptr); + } + + ubs_hcom_channel_cb_func cbFunc = cb->cb; + void *arg = cb->arg; + Callback *newCallback = UBSHcomNewCallback( + [cbFunc, arg] + (UBSHcomServiceContext &context) { cbFunc(arg, reinterpret_cast(&context)); }, + std::placeholders::_1); + if (NN_UNLIKELY(newCallback == nullptr)) { + NN_LOG_ERROR("ubs_hcom_channel_get malloc callback failed"); + return SER_NEW_OBJECT_FAILED; + } + auto result = innerChannel->Get(oneSideReq, newCallback); + if (NN_UNLIKELY(result != SER_OK)) { + delete newCallback; + return result; + } + + return SER_OK; +} + +int ubs_hcom_channel_recv(ubs_hcom_channel channel, ubs_hcom_service_context ctx, uintptr_t address, uint32_t size, + ubs_hcom_channel_callback *cb) +{ + VALIDATE_CHANNEL(channel) + VALIDATE_CONTEXT(ctx) + VALIDATE_MR_ADDRESS(address) + VALIDATE_MR_SIZE(size) + auto innerChannel = reinterpret_cast(channel); + auto innerContext = reinterpret_cast(ctx); + + if (cb == nullptr) { + return innerChannel->Recv(*innerContext, address, size, nullptr); + } + + ubs_hcom_channel_cb_func cbFunc = cb->cb; + void *arg = cb->arg; + Callback *newCallback = UBSHcomNewCallback( + [cbFunc, arg] + (UBSHcomServiceContext &context) { cbFunc(arg, reinterpret_cast(&context)); }, + std::placeholders::_1); + if (NN_UNLIKELY(newCallback == nullptr)) { + NN_LOG_ERROR("ubs_hcom_channel_get malloc callback failed"); + return SER_NEW_OBJECT_FAILED; + } + auto result = innerChannel->Recv(*innerContext, address, size, newCallback); + if (NN_UNLIKELY(result != SER_OK)) { + return result; + } + return SER_OK; +} + +int ubs_hcom_channel_send_fds(ubs_hcom_channel channel, int fds[], uint32_t len) +{ + VALIDATE_CHANNEL(channel) + return reinterpret_cast(channel)->SendFds(fds, len); +} + +int ubs_hcom_channel_recv_fds(ubs_hcom_channel channel, int fds[], uint32_t len, int32_t timeoutSec) +{ + VALIDATE_CHANNEL(channel) + return reinterpret_cast(channel)->ReceiveFds(fds, len, timeoutSec); +} + +int ubs_hcom_channel_set_flowctl_cfg(ubs_hcom_channel channel, ubs_hcom_flowctl_opts opt) +{ + VALIDATE_CHANNEL(channel) + auto innerChannel = reinterpret_cast(channel); + + UBSHcomFlowCtrlOptions ctl {}; + ctl.intervalTimeMs = opt.intervalTimeMs; + ctl.thresholdByte = opt.thresholdByte; + ctl.flowCtrlLevel = static_cast(opt.flowCtrlLevel); + return innerChannel->SetFlowControlConfig(ctl); +} + +void ubs_hcom_channel_set_timeout(ubs_hcom_channel channel, int16_t oneSideTimeout, int16_t twoSideTimeout) +{ + VALIDATE_CHANNEL_NO_RET(channel) + auto innerChannel = reinterpret_cast(channel); + + innerChannel->SetChannelTimeOut(oneSideTimeout, twoSideTimeout); +} + +int ubs_hcom_channel_set_twoside_threshold(ubs_hcom_channel channel, ubs_hcom_twoside_threshold threshold) +{ + VALIDATE_CHANNEL(channel) + auto innerChannel = reinterpret_cast(channel); + + UBSHcomTwoSideThreshold twoSideThreshold{}; + twoSideThreshold.splitThreshold = threshold.splitThreshold; + twoSideThreshold.rndvThreshold = threshold.rndvThreshold; + return innerChannel->SetTwoSideThreshold(twoSideThreshold); +} + +uint64_t ubs_hcom_channel_get_id(ubs_hcom_channel channel) +{ + VALIDATE_CHANNEL(channel) + auto innerChannel = reinterpret_cast(channel); + + return innerChannel->GetId(); +} + +int ubs_hcom_context_get_rspctx(ubs_hcom_service_context context, ubs_hcom_channel_reply_context *rspCtx) +{ + VALIDATE_CONTEXT(context) + auto innerContext = reinterpret_cast(context); + rspCtx->rspCtx = reinterpret_cast(innerContext->RspCtx()); + return SER_OK; +} + +int ubs_hcom_context_get_channel(ubs_hcom_service_context context, ubs_hcom_channel *channel) +{ + VALIDATE_CONTEXT(context) + if (NN_UNLIKELY(channel == nullptr)) { + NN_LOG_ERROR("Invalid param, channel must be correct address"); + return SER_INVALID_PARAM; + } + + auto innerContext = reinterpret_cast(context); + *channel = reinterpret_cast(innerContext->Channel().Get()); + return SER_OK; +} + +int ubs_hcom_context_get_type(ubs_hcom_service_context context, ubs_hcom_service_context_type *type) +{ + VALIDATE_CONTEXT(context) + if (NN_UNLIKELY(type == nullptr)) { + NN_LOG_ERROR("Invalid param, type must be correct address"); + return SER_INVALID_PARAM; + } + + auto innerContext = reinterpret_cast(context); + *type = static_cast(innerContext->OpType()); + return SER_OK; +} + +int ubs_hcom_context_get_result(ubs_hcom_service_context context, int *result) +{ + VALIDATE_CONTEXT(context) + if (NN_UNLIKELY(result == nullptr)) { + NN_LOG_ERROR("Invalid param, result must be correct address"); + return SER_INVALID_PARAM; + } + + auto innerContext = reinterpret_cast(context); + *result = innerContext->Result(); + return SER_OK; +} + +uint16_t ubs_hcom_context_get_opcode(ubs_hcom_service_context context) +{ + VALIDATE_CONTEXT(context); + auto innerContext = reinterpret_cast(context); + return innerContext->OpCode(); +} + +void *ubs_hcom_context_get_data(ubs_hcom_service_context context) +{ + VALIDATE_CONTEXT_RETURN_PTR(context) + + auto innerContext = reinterpret_cast(context); + return innerContext->MessageData(); +} + +uint32_t ubs_hcom_context_get_datalen(ubs_hcom_service_context context) +{ + VALIDATE_CONTEXT_RETURN_ZERO(context) + + auto innerContext = reinterpret_cast(context); + return innerContext->MessageDataLen(); +} + +int ubs_hcom_service_create(ubs_hcom_service_type t, const char *name, ubs_hcom_service_options options, + ubs_hcom_service *service) +{ + if (NN_UNLIKELY(name == nullptr || service == nullptr)) { + NN_LOG_ERROR("Invalid param, name or service is nullptr"); + return SER_INVALID_PARAM; + } + + if (strlen(name) > NN_NO64) { + NN_LOG_ERROR("Invalid param, name length must be than " << NN_NO64); + return SER_INVALID_PARAM; + } + + UBSHcomServiceOptions innerOptions; + if (NN_UNLIKELY(!ConvertServiceOptionsToInnerOptions(options, innerOptions))) { + NN_LOG_ERROR("Invalid options"); + return SER_INVALID_PARAM; + } + + UBSHcomServiceProtocol proto = UBSHcomServiceProtocol::RDMA; + ConvertServiceTypeToInnerServiceProto(t, proto); + + auto tmpService = UBSHcomService::Create(proto, name, innerOptions); + if (tmpService == nullptr) { + NN_LOG_ERROR("Failed to create UBSHcomService"); + return SER_NEW_OBJECT_FAILED; + } + *service = reinterpret_cast(tmpService); + return SER_OK; +} + +int ubs_hcom_service_bind(ubs_hcom_service service, const char *listenerUrl, ubs_hcom_service_channel_handler h) +{ + VALIDATE_SERVICE(service); + VALIDATE_HANDLER(h); + if (NN_UNLIKELY(listenerUrl == nullptr)) { + NN_LOG_ERROR("Invalid paraim, listenerUrl is null"); + return SER_INVALID_PARAM; + } + + auto tmpH = new (std::nothrow) ServiceHdlAdp( + ubs_hcom_service_channel_handler_type::C_CHANNEL_NEW, h, 0); + if (NN_UNLIKELY(tmpH == nullptr)) { + NN_LOG_ERROR("Failed to new channel handler adaptor, probably out of memory"); + return SER_NEW_OBJECT_FAILED; + } + g_serviceHandlerManager.AddHdlAdp(service, reinterpret_cast(tmpH)); + + return reinterpret_cast(service)->Bind(listenerUrl, std::bind(&ServiceHdlAdp::NewChannel, tmpH, + std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); +} + +int ubs_hcom_service_start(ubs_hcom_service service) +{ + VALIDATE_SERVICE(service); + return reinterpret_cast(service)->Start(); +} + +int ubs_hcom_service_destroy(ubs_hcom_service service, const char *name) +{ + VALIDATE_SERVICE(service); + if (NN_UNLIKELY(name == nullptr)) { + NN_LOG_ERROR("Failed to destroy as name is nullptr"); + return SER_INVALID_PARAM; + } + g_serviceHandlerManager.RemoveAll(service); + g_serviceIdleHandlerManager.RemoveAll(service); + g_channelHandlerManager.RemoveAll(service); + g_secProVider.RemoveAll(service); + g_secValidator.RemoveAll(service); + g_TlsHdl.RemoveAll(service); + return reinterpret_cast(service)->Destroy(name); +} + +int ubs_hcom_service_connect(ubs_hcom_service service, const char *serverUrl, ubs_hcom_channel *channel, ubs_hcom_service_connect_options options) +{ + VALIDATE_SERVICE(service); + if (NN_UNLIKELY(serverUrl == nullptr)) { + NN_LOG_ERROR("Failed to connect as serverUrl is nullptr"); + return SER_INVALID_PARAM; + } + + if (NN_UNLIKELY(channel == nullptr)) { + NN_LOG_ERROR("Failed to connect as channel is nullptr"); + return SER_INVALID_PARAM; + } + + UBSHcomConnectOptions innerOptions; + ConvertServiceConnectOptionsToInnerOptions(options, innerOptions); + UBSHcomChannelPtr tmpChannel; + auto result = reinterpret_cast(service)->Connect(serverUrl, tmpChannel, innerOptions); + if (NN_UNLIKELY(result != NN_OK)) { + return result; + } + + // increase ref, need to call Channel_Destroy() to decrease ref + tmpChannel->IncreaseRef(); + + *channel = reinterpret_cast(tmpChannel.Get()); + return SER_OK; +} + +int ubs_hcom_service_disconnect(ubs_hcom_service service, ubs_hcom_channel channel) +{ + VALIDATE_SERVICE(service); + VALIDATE_CHANNEL(channel); + auto innerChannel = reinterpret_cast(channel); + reinterpret_cast(service)->Disconnect(innerChannel); + innerChannel->DecreaseRef(); + return SER_OK; +} + +int ubs_hcom_service_register_memory_region(ubs_hcom_service service, uint64_t size, ubs_hcom_memory_region *mr) +{ + VALIDATE_SERVICE(service); + VALIDATE_MR_POINT(mr); + + auto tmpMr = new (std::nothrow) UBSHcomRegMemoryRegion; + if (tmpMr == nullptr) { + NN_LOG_ERROR("Failed to malloc memory"); + return SER_NEW_OBJECT_FAILED; + } + + auto result = reinterpret_cast(service)->RegisterMemoryRegion(size, *tmpMr); + if (NN_UNLIKELY(result != NN_OK)) { + delete tmpMr; + return result; + } + *mr = reinterpret_cast(tmpMr); + return SER_OK; +} + +int ubs_hcom_service_register_assign_memory_region(ubs_hcom_service service, uintptr_t address, uint64_t size, + ubs_hcom_memory_region *mr) +{ + VALIDATE_SERVICE(service); + VALIDATE_MR_POINT(mr); + + auto tmpMr = new (std::nothrow) UBSHcomRegMemoryRegion; + if (tmpMr == nullptr) { + NN_LOG_ERROR("Failed to malloc memory"); + return SER_NEW_OBJECT_FAILED; + } + + auto result = reinterpret_cast(service)->RegisterMemoryRegion(address, size, *tmpMr); + if (NN_UNLIKELY(result != NN_OK)) { + delete tmpMr; + NN_LOG_ERROR("Failed to register memory"); + return result; + } + *mr = reinterpret_cast(tmpMr); + return SER_OK; +} + +int ubs_hcom_service_get_memory_region_info(ubs_hcom_memory_region mr, ubs_hcom_mr_info *info) +{ + VALIDATE_MR(mr); + VALIDATE_INFO(info); + + auto tmp = reinterpret_cast(mr); + if (NN_UNLIKELY(tmp == nullptr)) { + NN_LOG_ERROR("convert to mr failed"); + return SER_ERROR; + } + info->lAddress = tmp->GetAddress(); + UBSHcomMemoryKey mrKey; + tmp->GetMemoryKey(mrKey); + if (memcpy_s(&(info->lKey), sizeof(ubs_hcom_oneside_key), &mrKey, sizeof(UBSHcomMemoryKey)) != 0) { + NN_LOG_ERROR("copy mrkey failed!"); + return SER_ERROR; + } + + info->size = tmp->GetSize(); + return SER_OK; +} + +int ubs_hcom_service_destroy_memory_region(ubs_hcom_service service, ubs_hcom_memory_region mr) +{ + VALIDATE_SERVICE(service); + VALIDATE_MR(mr); + auto tmpMr = reinterpret_cast(mr); + reinterpret_cast(service)->DestroyMemoryRegion(*tmpMr); + delete tmpMr; + return SER_OK; +} + +void ubs_hcom_service_register_broken_handler(ubs_hcom_service service, ubs_hcom_service_channel_handler h, + ubs_hcom_service_channel_policy policy, uint64_t usrCtx) +{ + VALIDATE_SERVICE_NO_RET(service); + VALIDATE_HANDLER_NO_RET(h); + + auto tmpHdl = new (std::nothrow) ServiceHdlAdp( + ubs_hcom_service_channel_handler_type::C_CHANNEL_BROKEN, h, usrCtx); + if (NN_UNLIKELY(tmpHdl == nullptr)) { + NN_LOG_ERROR("Failed to new channel handler adapter, probably out of memory"); + return; + } + + reinterpret_cast(service)->RegisterChannelBrokenHandler( + std::bind(&ServiceHdlAdp::ChannelBroken, tmpHdl, std::placeholders::_1), + static_cast(policy)); + g_serviceHandlerManager.AddHdlAdp(service, reinterpret_cast(tmpHdl)); + return; +} + +void ubs_hcom_service_register_idle_handler(ubs_hcom_service service, ubs_hcom_service_idle_handler h, uint64_t usrCtx) +{ + VALIDATE_SERVICE_NO_RET(service) + VALIDATE_HANDLER_NO_RET(h) + + auto tmpHdl = new (std::nothrow) ServiceIdleHdlAdp(h, usrCtx); + if (NN_UNLIKELY(tmpHdl == nullptr)) { + NN_LOG_ERROR("Failed to new Endpoint handler adapter, probably out of memory"); + return; + } + + reinterpret_cast(service)->RegisterIdleHandler( + std::bind(&ServiceIdleHdlAdp::Idle, tmpHdl, std::placeholders::_1)); + + g_serviceIdleHandlerManager.AddHdlAdp(service, reinterpret_cast(tmpHdl)); + return; +} + +void ubs_hcom_service_register_handler(ubs_hcom_service service, ubs_hcom_service_handler_type t, + ubs_hcom_service_request_handler h, uint64_t usrCtx) +{ + VALIDATE_SERVICE_NO_RET(service) + VALIDATE_HANDLER_NO_RET(h) + + auto tmpHdl = new (std::nothrow) ChannelOpHdlAdp(h, usrCtx); + if (NN_UNLIKELY(tmpHdl == nullptr)) { + NN_LOG_ERROR("Failed to new Endpoint handler adapter, probably out of memory"); + return; + } + + if (t == C_SERVICE_REQUEST_RECEIVED) { + reinterpret_cast(service)->RegisterRecvHandler( + std::bind(&ChannelOpHdlAdp::Requested, tmpHdl, std::placeholders::_1)); + } else if (t == C_SERVICE_REQUEST_POSTED) { + reinterpret_cast(service)->RegisterSendHandler( + std::bind(&ChannelOpHdlAdp::Requested, tmpHdl, std::placeholders::_1)); + } else if (t == C_SERVICE_READWRITE_DONE) { + reinterpret_cast(service)->RegisterOneSideHandler( + std::bind(&ChannelOpHdlAdp::Requested, tmpHdl, std::placeholders::_1)); + } else { + NN_LOG_ERROR("Unreachable"); + delete tmpHdl; + return; + } + g_channelHandlerManager.AddHdlAdp(service, reinterpret_cast(tmpHdl)); + + return; +} + +void ubs_hcom_service_add_workergroup(ubs_hcom_service service, int8_t priority, uint16_t workerGroupId, + uint32_t threadCount, const char *cpuIdsRange) +{ + VALIDATE_SERVICE_NO_RET(service); + if (NN_UNLIKELY(cpuIdsRange == nullptr)) { + NN_LOG_ERROR("Invalid cpuIdsRange, cpuIdsRange is NULL"); + return; + } + + std::pair cpuIdsPair; + if (NN_UNLIKELY(!ConvertCpuIdsRangeStrToPair(cpuIdsRange, cpuIdsPair))) { + NN_LOG_ERROR("Invalid cpuIdsRange, for example: 1-2 means cpu 1 to cpu 2"); + return; + } + reinterpret_cast(service)->AddWorkerGroup(workerGroupId, threadCount, cpuIdsPair, priority); +} + +void ubs_hcom_service_add_listener(ubs_hcom_service service, const char *url, uint16_t workerCount) +{ + VALIDATE_SERVICE_NO_RET(service); + if (NN_UNLIKELY(url == nullptr)) { + NN_LOG_ERROR("Invalid url as url is nullptr"); + return; + } + reinterpret_cast(service)->AddListener(url, workerCount); +} + +void ubs_hcom_service_set_lbpolicy(ubs_hcom_service service, ubs_hcom_service_lb_policy lbPolicy) +{ + VALIDATE_SERVICE_NO_RET(service); + UBSHcomServiceLBPolicy policy = UBSHcomServiceLBPolicy::NET_ROUND_ROBIN; + if (lbPolicy == SERVICE_HASH_IP_PORT) { + policy = UBSHcomServiceLBPolicy::NET_HASH_IP_PORT; + } + reinterpret_cast(service)->SetConnectLBPolicy(policy); +} + +void ubs_hcom_service_set_tls_opt(ubs_hcom_service service, bool enableTls, ubs_hcom_service_tls_version version, + ubs_hcom_service_cipher_suite cipherSuite, ubs_hcom_tls_get_cert_cb certCb, ubs_hcom_tls_get_pk_cb priKeyCb, + ubs_hcom_tls_get_ca_cb caCb) +{ + VALIDATE_SERVICE_NO_RET(service); + + UBSHcomTlsOptions opt; + opt.enableTls = enableTls; + if (!enableTls) { + reinterpret_cast(service)->SetTlsOptions(opt); + return; + } + if (NN_UNLIKELY(certCb == nullptr) || NN_UNLIKELY(priKeyCb == nullptr) || NN_UNLIKELY(caCb == nullptr)) { + NN_LOG_ERROR("Failed to set tls options as cb is nullptr"); + return; + } + opt.tlsVersion = version == C_SERVICE_TLS_1_2 ? UBSHcomTlsVersion::TLS_1_2 : UBSHcomTlsVersion::TLS_1_3; + if (cipherSuite == C_SERVICE_AES_GCM_128) { + opt.netCipherSuite = UBSHcomNetCipherSuite::AES_GCM_128; + } else if (cipherSuite == C_SERVICE_AES_GCM_256) { + opt.netCipherSuite = UBSHcomNetCipherSuite::AES_GCM_256; + } else if (cipherSuite == C_SERVICE_AES_CCM_128) { + opt.netCipherSuite = UBSHcomNetCipherSuite::AES_CCM_128; + } else if (cipherSuite == C_SERVICE_CHACHA20_POLY1305) { + opt.netCipherSuite = UBSHcomNetCipherSuite::CHACHA20_POLY1305; + } + + auto tmpH = new (std::nothrow) EpTLSHdlAdp(); + if (NN_UNLIKELY(tmpH == nullptr)) { + NN_LOG_ERROR("Failed to new service tls handler adapter, probably out of memory"); + return; + } + tmpH->SetTLSCertCb(certCb); + tmpH->SetTLSPrivateKeyCb(priKeyCb); + tmpH->SetTLSCaCb(caCb); + opt.caCb = (std::bind(&EpTLSHdlAdp::UBSHcomTLSCaCallback, tmpH, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + opt.cfCb = (std::bind(&EpTLSHdlAdp::UBSHcomTLSCertificationCallback, tmpH, std::placeholders::_1, + std::placeholders::_2)); + opt.pkCb = (std::bind(&EpTLSHdlAdp::UBSHcomTLSPrivateKeyCallback, tmpH, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + g_TlsHdl.AddHdlAdp(service, reinterpret_cast(tmpH)); + reinterpret_cast(service)->SetTlsOptions(opt); +} + +void ubs_hcom_service_set_secure_opt(ubs_hcom_service service, ubs_hcom_service_secure_type secType, + ubs_hcom_secinfo_provider provider, ubs_hcom_secinfo_validator validator, uint16_t magic, uint8_t version) +{ + VALIDATE_SERVICE_NO_RET(service); + if (NN_UNLIKELY(provider == nullptr) || NN_UNLIKELY(validator == nullptr)) { + NN_LOG_ERROR("Failed to SetSecureOptions as provider or validator is nullptr"); + return; + } + + UBSHcomConnSecureOptions opt; + if (secType == C_SERVICE_NET_SEC_DISABLED) { + opt.secType = UBSHcomNetDriverSecType::NET_SEC_DISABLED; + } else if (secType == C_SERVICE_NET_SEC_ONE_WAY) { + opt.secType = UBSHcomNetDriverSecType::NET_SEC_VALID_ONE_WAY; + } else if (secType == C_SERVICE_NET_SEC_TWO_WAY) { + opt.secType = UBSHcomNetDriverSecType::NET_SEC_VALID_TWO_WAY; + } + opt.version = version; + opt.magic = magic; + + auto providerTmpH = new (std::nothrow) OOBSecInfoProviderAdp(provider); + if (NN_UNLIKELY(providerTmpH == nullptr)) { + NN_LOG_ERROR("Register Service_SecInfoProvider failed, probably out of memory"); + return; + } + opt.provider = std::bind(&OOBSecInfoProviderAdp::CreateSecInfo, providerTmpH, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, + std::placeholders::_6); + + auto validatorTmpH = new (std::nothrow) OOBSecInfoValidatorAdp(validator); + if (NN_UNLIKELY(validatorTmpH == nullptr)) { + NN_LOG_ERROR("Register Service_SecInfoValidator failed, probably out of memory"); + opt.provider = nullptr; + delete providerTmpH; + return; + } + opt.validator = std::bind(&OOBSecInfoValidatorAdp::SecInfoValidate, validatorTmpH, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); + + g_secProVider.AddHdlAdp(service, reinterpret_cast(providerTmpH)); + g_secValidator.AddHdlAdp(service, reinterpret_cast(validatorTmpH)); + reinterpret_cast(service)->SetConnSecureOpt(opt); +} + +void ubs_hcom_service_set_tcp_usr_timeout(ubs_hcom_service service, uint16_t timeOutSec) +{ + VALIDATE_SERVICE_NO_RET(service); + reinterpret_cast(service)->SetTcpUserTimeOutSec(timeOutSec); +} + +void ubs_hcom_service_set_tcp_send_zcopy(ubs_hcom_service service, bool tcpSendZCopy) +{ + VALIDATE_SERVICE_NO_RET(service); + reinterpret_cast(service)->SetTcpSendZCopy(tcpSendZCopy); +} + +void ubs_hcom_service_set_ipmask(ubs_hcom_service service, const char *ipMask) +{ + VALIDATE_SERVICE_NO_RET(service); + if (NN_UNLIKELY(ipMask == nullptr)) { + NN_LOG_ERROR("Failed to set as ipMask is nullptr"); + return; + } + std::vector ipMasks; + NetFunc::NN_SplitStr(ipMask, ",", ipMasks); + reinterpret_cast(service)->SetDeviceIpMask(ipMasks); +} + +void ubs_hcom_service_set_ipgroup(ubs_hcom_service service, const char *ipGroup) +{ + VALIDATE_SERVICE_NO_RET(service); + if (NN_UNLIKELY(ipGroup == nullptr)) { + NN_LOG_ERROR("Failed to set as ipGroup is nullptr"); + return; + } + std::vector ipGroups; + NetFunc::NN_SplitStr(ipGroup, ",", ipGroups); + reinterpret_cast(service)->SetDeviceIpGroups(ipGroups); +} + +void ubs_hcom_service_set_cq_depth(ubs_hcom_service service, uint16_t depth) +{ + VALIDATE_SERVICE_NO_RET(service); + reinterpret_cast(service)->SetCompletionQueueDepth(depth); +} + +void ubs_hcom_service_set_sq_size(ubs_hcom_service service, uint32_t sqSize) +{ + VALIDATE_SERVICE_NO_RET(service); + reinterpret_cast(service)->SetSendQueueSize(sqSize); +} + +void ubs_hcom_service_set_rq_size(ubs_hcom_service service, uint32_t rqSize) +{ + VALIDATE_SERVICE_NO_RET(service); + reinterpret_cast(service)->SetRecvQueueSize(rqSize); +} + +void ubs_hcom_service_set_prepost_size(ubs_hcom_service service, uint32_t prePostSize) +{ + VALIDATE_SERVICE_NO_RET(service); + reinterpret_cast(service)->SetQueuePrePostSize(prePostSize); +} + +void ubs_hcom_service_set_polling_batchsize(ubs_hcom_service service, uint16_t pollSize) +{ + VALIDATE_SERVICE_NO_RET(service); + reinterpret_cast(service)->SetPollingBatchSize(pollSize); +} + +void ubs_hcom_service_set_polling_timeoutus(ubs_hcom_service service, uint16_t pollTimeout) +{ + VALIDATE_SERVICE_NO_RET(service); + reinterpret_cast(service)->SetEventPollingTimeOutUs(pollTimeout); +} + +void ubs_hcom_service_set_timeout_threadnum(ubs_hcom_service service, uint32_t threadNum) +{ + VALIDATE_SERVICE_NO_RET(service); + reinterpret_cast(service)->SetTimeOutDetectionThreadNum(threadNum); +} + +void ubs_hcom_service_set_max_connection_cnt(ubs_hcom_service service, uint32_t maxConnCount) +{ + VALIDATE_SERVICE_NO_RET(service); + reinterpret_cast(service)->SetMaxConnectionCount(maxConnCount); +} + +void ubs_hcom_service_set_heartbeat_opt(ubs_hcom_service service, uint16_t idleSec, uint16_t probeTimes, + uint16_t intervalSec) +{ + VALIDATE_SERVICE_NO_RET(service); + UBSHcomHeartBeatOptions opt; + opt.heartBeatIdleSec = idleSec; + opt.heartBeatProbeTimes = probeTimes; + opt.heartBeatProbeIntervalSec = intervalSec; + reinterpret_cast(service)->SetHeartBeatOptions(opt); +} + +void ubs_hcom_service_set_multirail_opt(ubs_hcom_service service, bool enable, uint32_t threshold) +{ + VALIDATE_SERVICE_NO_RET(service); + UBSHcomMultiRailOptions opt; + opt.enable = enable; + opt.threshold = threshold; + reinterpret_cast(service)->SetMultiRailOptions(opt); +} + +void ubs_hcom_service_set_enable_mrcache(ubs_hcom_service service, bool enableMrCache) +{ + VALIDATE_SERVICE_NO_RET(service); + reinterpret_cast(service)->SetEnableMrCache(enableMrCache); +} + +void ubs_hcom_service_set_ubcmode(ubs_hcom_service service, ubs_hcom_service_ubc_mode ubcMode) +{ + VALIDATE_SERVICE_NO_RET(service); + UBSHcomUbcMode tmpUbcMode = UBSHcomUbcMode::LowLatency; + if (ubcMode == C_SERVICE_HIGHBANDWIDTH) { + tmpUbcMode = UBSHcomUbcMode::HighBandwidth; + } + reinterpret_cast(service)->SetUbcMode(tmpUbcMode); +} diff --git a/src/api/capi_v2/hcom_service_c.h b/src/api/capi_v2/hcom_service_c.h new file mode 100644 index 0000000000000000000000000000000000000000..65514ac597c584b0329f146d9c77480e99fd2319 --- /dev/null +++ b/src/api/capi_v2/hcom_service_c.h @@ -0,0 +1,435 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_CAPI_V2_HCOM_SERVICE_C_V2_H_ +#define HCOM_CAPI_V2_HCOM_SERVICE_C_V2_H_ + +#include +#include +#include "hcom_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef uintptr_t ubs_hcom_channel; +/* + * @brief Service context, which used for callback as param + */ +typedef uintptr_t ubs_hcom_service_context; + +/* + * @brief service, which include oob & multi protocols(TCP/RDMA/SHM) workers & callback etc + */ +typedef uintptr_t ubs_hcom_service; + +/* + * @brief Channel, represent multi connections(EPs) of one protocol + * + * two side operation, Hcom_ChannelSend + * read operation from remote, Hcom_ChannelRead + * write operation from remote, Hcom_ChannelWrite + */ +typedef uintptr_t ubs_hcom_channel; + +typedef uintptr_t ubs_hcom_memory_region; + +/* + * @brief Service context, which used for callback as param + */ +typedef uintptr_t ubs_hcom_service_context; + +/* + * Callback function which will be invoked by async use mode + */ +typedef void (*ubs_hcom_channel_cb_func)(void *arg, ubs_hcom_service_context context); +/* + * @brief Callback function definition + * 1) new endpoint connected from client, only need to register this at sever side + * 2) endpoint is broken, called when RDMA qp detection error or broken + */ +typedef int (*ubs_hcom_service_channel_handler)(ubs_hcom_channel channel, uint64_t usrCtx, const char *payLoad); +typedef void (*ubs_hcom_service_idle_handler)(uint8_t wkrGrpIdx, uint16_t idxInGrp, uint64_t usrCtx); +typedef int (*ubs_hcom_service_request_handler)(ubs_hcom_service_context ctx, uint64_t usrCtx); + +/* + * @brief keyPass [in] erase function + * @param keyPass [in] the memory address of keyPass + */ +typedef void (*ubs_hcom_tls_keypass_erase)(char *keyPass, int len); + +/* + * @brief The cert verify function + * + * @param x509 [in] the x509 object of CA + * @param crlPath [in] the crl file path + * + * @return -1 for failed, and 1 for success + */ +typedef int (*ubs_hcom_tls_cert_verify)(void *x509, const char *crlPath); + +/* + * @brief Get the certificate file of public key + * + * @param name [out] the name + * @param certPath [out] the path of certificate + */ +typedef int (*ubs_hcom_tls_get_cert_cb)(const char *name, char **certPath); + +/* + * @brief Get private key file's path and length, and get the keyPass + * @param name [out] the name + * @param priKeyPath [out] the path of private key + * @param keyPass [out] the keyPass + * @param erase [out] the erase function + */ +typedef int (*ubs_hcom_tls_get_pk_cb)( + const char *name, char **priKeyPath, char **keyPass, ubs_hcom_tls_keypass_erase *erase); + +/* + * @brief Get the CA and verify + * @param name [out] the name + * @param caPath [out] the path of CA file + * @param crlPath [out] the crl file path + * @param verifyType [out] the type of verify in[VERIFY_BY_NONE,VERIFY_BY_DEFAULT, VERIFY_BY_CUSTOM_FUNC] + * @param verify [out] the verify function, only effect in VERIFY_BY_CUSTOM_FUNC mode + */ +typedef int (*ubs_hcom_tls_get_ca_cb)(const char *name, char **caPath, char **crlPath, + ubs_hcom_peer_cert_verify_type *verifyType, ubs_hcom_tls_cert_verify *verify); + +/* + * @brief Sec callback function, when oob connect build, this function will be called to generate auth info. + * if this function not set secure type is C_NET_SEC_NO_VALID and oob will not send secure info + * + * @param ctx [in] ctx from connect param ctx, and will send in auth process + * @param flag [out] flag to sent in auth process + * @param type [out] secure type, value should set in oob client, and should in [C_NET_SEC_ONE_WAY, + * C_NET_SEC_TWO_WAY] + * @param output [out] secure info created + * @param outLen [out] secure info length + * @param needAutoFree [out] secure info need to auto free in hcom or not + */ +typedef int (*ubs_hcom_secinfo_provider)(uint64_t ctx, int64_t *flag, ubs_hcom_driver_sec_type *type, char **output, + uint32_t *outLen, int *needAutoFree); + +/* + * @brief ValidateSecInfo callback function, when oob connect build, this function will be called to validate auth info + * if this function not set oob will not validate secure info + * + * @param flag [in] flag received in auth process + * @param ctx [in] ctx received in auth process + * @param input [in] secure info received + * @param inputLen [in] secure info length + */ +typedef int (*ubs_hcom_secinfo_validator)(uint64_t ctx, int64_t flag, const char *input, uint32_t inputLen); + +/* + * @brief External log callback function + * + * @param level [in] level, 0/1/2/3 represent debug/info/warn/error + * @param msg [in] message, log message with name:code-line-number + */ +typedef void (*ubs_hcom_log_handler)(int level, const char *msg); + +/* + * @brief Worker polling type + * 1 For RDMA: + * C_BUSY_POLLING, means cpu 100% polling no matter there is request cb, better performance but cost dedicated CPU + * C_EVENT_POLLING, waiting on OS kernel for request cb + * 2 For TCP/UDS + * only event pooling is supported + */ +typedef enum { + C_SERVICE_BUSY_POLLING = 0, + C_SERVICE_EVENT_POLLING = 1, +} ubs_hcom_worker_mode; + +typedef enum { + C_CLIENT_WORKER_POLL = 0, + C_CLIENT_SELF_POLL_BUSY = 1, + C_CLIENT_SELF_POLL_EVENT = 2, +} ubs_hcom_service_polling_mode; + +typedef enum { + C_CHANNEL_FUNC_CB = 0, // use channel function param (const NetCallback *cb) + C_CHANNEL_GLOBAL_CB = 1, // use service RegisterOpHandler +} ubs_hcom_channel_cb_type; + +typedef enum { + HIGH_LEVEL_BLOCK, /* spin-wait by busy loop */ + LOW_LEVEL_BLOCK, /* full sleep */ +} ubs_hcom_channel_flowctl_level; + +typedef enum { + C_SERVICE_RDMA = 0, + C_SERVICE_TCP = 1, + C_SERVICE_UDS = 2, + C_SERVICE_SHM = 3, + C_SERVICE_UBC = 6, +} ubs_hcom_service_type; + +typedef enum { + C_CHANNEL_BROKEN_ALL = 0, /* when one ep broken, all eps broken */ + C_CHANNEL_RECONNECT = 1, /* when one ep broken, try re-connect first. If re-connect fail, broken all eps */ + C_CHANNEL_KEEP_ALIVE = 2, /* when one ep broken, keep left eps alive until all eps broken */ +} ubs_hcom_service_channel_policy; + +/* + * @brief Enum for callback register [new endpoint connected or endpoint broken] + */ +typedef enum { + C_CHANNEL_NEW = 0, + C_CHANNEL_BROKEN = 1, +} ubs_hcom_service_channel_handler_type; + +typedef enum { + C_SERVICE_REQUEST_RECEIVED = 0, + C_SERVICE_REQUEST_POSTED = 1, + C_SERVICE_READWRITE_DONE = 2, +} ubs_hcom_service_handler_type; + +typedef enum { + SERVICE_ROUND_ROBIN = 0, + SERVICE_HASH_IP_PORT = 1, +} ubs_hcom_service_lb_policy; + +typedef enum { + C_SERVICE_TLS_1_2 = 771, + C_SERVICE_TLS_1_3 = 772, +} ubs_hcom_service_tls_version; + +typedef enum { + C_SERVICE_AES_GCM_128 = 0, + C_SERVICE_AES_GCM_256 = 1, + C_SERVICE_AES_CCM_128 = 2, + C_SERVICE_CHACHA20_POLY1305 = 3, +} ubs_hcom_service_cipher_suite; + +typedef enum { + C_SERVICE_NET_SEC_DISABLED = 0, + C_SERVICE_NET_SEC_ONE_WAY = 1, + C_SERVICE_NET_SEC_TWO_WAY = 2, +} ubs_hcom_service_secure_type; + +/* + * @brief Enum for UBC mode + */ +typedef enum { + C_SERVICE_LOWLATENCY = 0, + C_SERVICE_HIGHBANDWIDTH = 1, +} ubs_hcom_service_ubc_mode; + +/* + * @brief Context type, part of ubs_hcom_service_context, sync mode is not aware most of them + */ +typedef enum { + SERVICE_RECEIVED = 0, /* support invoke all functions */ + SERVICE_RECEIVED_RAW = 1, /* support invoke most functions except Service_GetOpInfo() */ + SERVICE_SENT = 2, /* support invoke basic functions except + Service_GetMessage() * 3、Service_GetRspCtx()、 */ + SERVICE_SENT_RAW = 3, /* support invoke basic functions except + Service_GetMessage() * 3、、Service_GetRspCtx()、Service_GetOpInfo() */ + SERVICE_ONE_SIDE = 4, /* support invoke basic functions except + Service_GetMessage() * 3、、Service_GetRspCtx()、Service_GetOpInfo() */ + SERVICE_RNDV = 5, + SERVICE_INVALID_OP_TYPE = 255, +} ubs_hcom_service_context_type; + +typedef struct { + uint32_t maxSendRecvDataSize; + uint16_t workerGroupId; + uint16_t workerGroupThreadCount; + ubs_hcom_worker_mode workerGroupMode; + int8_t workerThreadPriority; + char workerGroupCpuRange[64]; // worker group cpu range, for example 6-10 +} ubs_hcom_service_options; + +typedef struct { + ubs_hcom_channel_cb_func cb; // User callback function + void *arg; // Argument of callback +} ubs_hcom_channel_callback; + + +typedef struct { + uint16_t clientGroupId; // worker group id of client + uint16_t serverGroupId; // worker group id of server + uint8_t linkCount; // actual link count of the channel + ubs_hcom_service_polling_mode mode; + ubs_hcom_channel_cb_type cbType; + char payLoad[512]; +} ubs_hcom_service_connect_options; + +typedef struct { + void *address; /* pointer of data */ + uint32_t size; /* size of data */ + uint16_t opcode; +} ubs_hcom_channel_request; + +typedef struct { + void *address; /* pointer of data */ + uint32_t size; /* size of data */ + int16_t errorCode; /* error code of response */ +} ubs_hcom_channel_response; + +typedef struct { + void *rspCtx; + int16_t errorCode; +} ubs_hcom_channel_reply_context; + +typedef struct { + uint64_t keys[4]; + uint64_t tokens[4]; +} ubs_hcom_oneside_key; + +/* + * @brief Read/write mr info for one side rdma operation + */ +typedef struct { + uintptr_t lAddress; // local memory region address + ubs_hcom_oneside_key lKey; // local memory region key + uint64_t size; // data size +} ubs_hcom_mr_info; + +typedef struct { + void *lAddress; + void *rAddress; + ubs_hcom_oneside_key lKey; + ubs_hcom_oneside_key rKey; + uint32_t size; +} ubs_hcom_oneside_request; + +typedef struct { + uint16_t intervalTimeMs; + uint64_t thresholdByte; + ubs_hcom_channel_flowctl_level flowCtrlLevel; +} ubs_hcom_flowctl_opts; + +typedef struct { + uint32_t splitThreshold; + uint32_t rndvThreshold; +} ubs_hcom_twoside_threshold; + +int ubs_hcom_service_create(ubs_hcom_service_type t, const char *name, ubs_hcom_service_options options, + ubs_hcom_service *service); + +int ubs_hcom_service_bind(ubs_hcom_service service, const char *listenerUrl, ubs_hcom_service_channel_handler h); + +int ubs_hcom_service_start(ubs_hcom_service service); + +int ubs_hcom_service_destroy(ubs_hcom_service service, const char *name); + +int ubs_hcom_service_connect(ubs_hcom_service service, const char *serverUrl, ubs_hcom_channel *channel, + ubs_hcom_service_connect_options options); + +int ubs_hcom_service_disconnect(ubs_hcom_service service, ubs_hcom_channel channel); + +int ubs_hcom_service_register_memory_region(ubs_hcom_service service, uint64_t size, ubs_hcom_memory_region *mr); + +int ubs_hcom_service_get_memory_region_info(ubs_hcom_memory_region mr, ubs_hcom_mr_info *info); + +int ubs_hcom_service_register_assign_memory_region( + ubs_hcom_service service, uintptr_t address, uint64_t size, ubs_hcom_memory_region *mr); + +int ubs_hcom_service_destroy_memory_region(ubs_hcom_service service, ubs_hcom_memory_region mr); + +void ubs_hcom_service_register_broken_handler(ubs_hcom_service service, ubs_hcom_service_channel_handler h, + ubs_hcom_service_channel_policy policy, uint64_t usrCtx); + +void ubs_hcom_service_register_idle_handler(ubs_hcom_service service, ubs_hcom_service_idle_handler h, uint64_t usrCtx); + +void ubs_hcom_service_register_handler(ubs_hcom_service service, ubs_hcom_service_handler_type t, + ubs_hcom_service_request_handler h, uint64_t usrCtx); + +void ubs_hcom_service_add_workergroup(ubs_hcom_service service, int8_t priority, uint16_t workerGroupId, + uint32_t threadCount, const char *cpuIdsRange); + +void ubs_hcom_service_add_listener(ubs_hcom_service service, const char *url, uint16_t workerCount); + +void ubs_hcom_service_set_lbpolicy(ubs_hcom_service service, ubs_hcom_service_lb_policy lbPolicy); + +void ubs_hcom_service_set_tls_opt(ubs_hcom_service service, bool enableTls, ubs_hcom_service_tls_version version, + ubs_hcom_service_cipher_suite cipherSuite, ubs_hcom_tls_get_cert_cb certCb, ubs_hcom_tls_get_pk_cb priKeyCb, + ubs_hcom_tls_get_ca_cb caCb); + +void ubs_hcom_service_set_secure_opt(ubs_hcom_service service, ubs_hcom_service_secure_type secType, + ubs_hcom_secinfo_provider provider, ubs_hcom_secinfo_validator validator, uint16_t magic, uint8_t version); + +void ubs_hcom_service_set_tcp_usr_timeout(ubs_hcom_service service, uint16_t timeOutSec); + +void ubs_hcom_service_set_tcp_send_zcopy(ubs_hcom_service service, bool tcpSendZCopy); + +void ubs_hcom_service_set_ipmask(ubs_hcom_service service, const char *ipMask); + +void ubs_hcom_service_set_ipgroup(ubs_hcom_service service, const char *ipGroup); + +void ubs_hcom_service_set_cq_depth(ubs_hcom_service service, uint16_t depth); + +void ubs_hcom_service_set_sq_size(ubs_hcom_service service, uint32_t sqSize); + +void ubs_hcom_service_set_rq_size(ubs_hcom_service service, uint32_t rqSize); + +void ubs_hcom_service_set_prepost_size(ubs_hcom_service service, uint32_t prePostSize); + +void ubs_hcom_service_set_polling_batchsize(ubs_hcom_service service, uint16_t pollSize); + +void ubs_hcom_service_set_polling_timeoutus(ubs_hcom_service service, uint16_t pollTimeout); + +void ubs_hcom_service_set_timeout_threadnum(ubs_hcom_service service, uint32_t threadNum); + +void ubs_hcom_service_set_max_connection_cnt(ubs_hcom_service service, uint32_t maxConnCount); + +void ubs_hcom_service_set_heartbeat_opt(ubs_hcom_service service, uint16_t idleSec, uint16_t probeTimes, + uint16_t intervalSec); + +void ubs_hcom_service_set_multirail_opt(ubs_hcom_service service, bool enable, uint32_t threshold); + +void ubs_hcom_service_set_ubcmode(ubs_hcom_service service, ubs_hcom_service_ubc_mode ubcMode); + +void ubs_hcom_service_set_enable_mrcache(ubs_hcom_service service, bool enableMrCache); + +void ubs_hcom_channel_refer(ubs_hcom_channel channel); +void ubs_hcom_channel_derefer(ubs_hcom_channel channel); +int ubs_hcom_channel_send(ubs_hcom_channel channel, ubs_hcom_channel_request req, ubs_hcom_channel_callback *cb); +int ubs_hcom_channel_call(ubs_hcom_channel channel, ubs_hcom_channel_request req, ubs_hcom_channel_response *rsp, + ubs_hcom_channel_callback *cb); +int ubs_hcom_channel_reply(ubs_hcom_channel channel, ubs_hcom_channel_request req, ubs_hcom_channel_reply_context ctx, + ubs_hcom_channel_callback *cb); +int ubs_hcom_channel_put(ubs_hcom_channel channel, ubs_hcom_oneside_request req, ubs_hcom_channel_callback *cb); +int ubs_hcom_channel_get(ubs_hcom_channel channel, ubs_hcom_oneside_request req, ubs_hcom_channel_callback *cb); +int ubs_hcom_channel_recv(ubs_hcom_channel channel, ubs_hcom_service_context ctx, uintptr_t address, uint32_t size, + ubs_hcom_channel_callback *cb); +int ubs_hcom_channel_send_fds(ubs_hcom_channel channel, int fds[], uint32_t len); +int ubs_hcom_channel_recv_fds(ubs_hcom_channel channel, int fds[], uint32_t len, int32_t timeoutSec); +int ubs_hcom_channel_set_flowctl_cfg(ubs_hcom_channel channel, ubs_hcom_flowctl_opts opt); +void ubs_hcom_channel_set_timeout(ubs_hcom_channel channel, int16_t oneSideTimeout, int16_t twoSideTimeout); +int ubs_hcom_channel_set_twoside_threshold(ubs_hcom_channel channel, ubs_hcom_twoside_threshold threshold); +uint64_t ubs_hcom_channel_get_id(ubs_hcom_channel channel); + +int ubs_hcom_context_get_rspctx(ubs_hcom_service_context context, ubs_hcom_channel_reply_context *rspCtx); +int ubs_hcom_context_get_channel(ubs_hcom_service_context context, ubs_hcom_channel *channel); +int ubs_hcom_context_get_type(ubs_hcom_service_context context, ubs_hcom_service_context_type *type); +int ubs_hcom_context_get_result(ubs_hcom_service_context context, int *result); +uint16_t ubs_hcom_context_get_opcode(ubs_hcom_service_context context); +void *ubs_hcom_context_get_data(ubs_hcom_service_context context); +uint32_t ubs_hcom_context_get_datalen(ubs_hcom_service_context context); + +/* + * @brief Set external logger function + * + * @param h [in] the log function ptr + */ +void ubs_hcom_set_log_handler(ubs_hcom_log_handler h); + +#ifdef __cplusplus +} +#endif + +#endif // HCOM_HCOM_SERVICE_C_V2_H diff --git a/src/common/code_msg.h b/src/common/code_msg.h new file mode 100644 index 0000000000000000000000000000000000000000..8feda274807446c83a46fc16da92c62b0d19cf1f --- /dev/null +++ b/src/common/code_msg.h @@ -0,0 +1,199 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_CODE_MSG_H +#define HCOM_CODE_MSG_H + +namespace ock { +namespace hcom { +static const char *NNCodeArray[] = { + "net error", + "net invalid ip", + "net new object generate failed", + "net invalid parameter", + "net message is too large for two side", + "net invalid opcode", + "net endpoint is not established", + "net endpoint is not initialized", + "net semaphore initialize failed for block queue", + "net timeout", + "net invalid operation", + "net malloc failed", + "net seqNo is not matched", + "net not initialized", + "net get buffer failed", + "net message timeout", + "net message canceled", + "net message error", + "net connection is refused", + "net connection protocol is mismatched", + "net invalid lKey", + "net endpoint is broken", + "net endpoint is closed", + "net invalid param", + "net oob listen socket error", + "net send error in oob connection", + "net receive error in oob connection", + "net oob connection callback is not set", + "net oob client socket error", + "net oob ssl initialize error", + "net oob ssl write error", + "net oob ssl read error", + "net create epoll failed in heartbeat manager", + "net set socket option failed in heartbeat manager", + "net ip is already existed in heartbeat manager", + "net failed to add ip into heartbeat manager", + "net failed to add ip into heartbeat manager as epoll add failed", + "net failed to remove ip from epoll handle", + "net ip is not found in heartbeat manager", + "net encrypt failed", + "net decrypt failed", + "net oob secure process error", + "net not support to exchange fd", + "net validate header failed", +}; + +static const char *RRCodeArray[] = { + "rdma invalid param", + "rdma memory allocate failed", + "rdma new object generate failed", + "rdma open file failed", + "rdma read file failed", + "rdma device open failed", + "rdma device index overflow", + "rdma device open failed", + "rdma device get interface address failed", + "rdma device interface address is mismatched", + "rdma device get gid failed by address", + "rdma device invalid ip mask", + "rdma memory region register failed", + "rdma completion queue is not initialized", + "rdma completion queue is polling failed", + "rdma completion queue is polling timeout", + "rdma completion queue is polling error result", + "rdma completion queue is polling unmatched opcode", + "rdma completion queue get event failed", + "rdma completion queue notify event failed", + "rdma completion queue is polled failed", + "rdma completion queue get event timeout", + "rdma create queue pair failed", + "rdma queue pair is not initialized", + "rdma queue pair state change failed", + "rdma queue pair post receive failed", + "rdma queue pair post send failed", + "rdma queue pair post read failed", + "rdma queue pair post write failed", + "rdma queue pair receive configuration error", + "rdma queue pair work request of post send is full", + "rdma queue pair one side work request is full", + "rdma queue pair context is full", + "rdma queue pair change error", + "rdma oob listen socket error", + "rdma send error in oob connection", + "rdma receive error in oob connection", + "rdma oob connection callback is not set", + "rdma oob client socket error", + "rdma oob ssl initialize error", + "rdma oob ssl write error", + "rdma oob ssl read error", + "rdma endpoint is no initialized", + "rdma worker is no initialized", + "rdma worker binds cpu failed", + "rdma request handler is not set in worker", + "rdma send request posted handler is not set in worker", + "rdma one side done handler not set in worker", + "rdma worker adds queue pair failed", + "rdma create epoll failed in heartbeat manager", + "rdma set socket option failed in heartbeat manager", + "rdma ip is already existed in heartbeat manager", + "rdma failed to add ip into heartbeat manager", + "rdma failed to add ip into heartbeat manager as epoll add failed", + "rdma failed to remove ip from epoll handle", + "rdma ip is not found in heartbeat manager", +}; + +static const char *ShCodeArray[] = { + "shm error", + "shm invalid parameter", + "shm memory allocate failed", + "shm new object generate failed", + "shm file operation failed", + "shm not initialized", + "shm timeout", + "shm context pool is used up", + "shm channel broken", + "shm create epoll failed for channel keeper", + "shm duplicated channel in channel keeper", + "shm add channel into channel keeper failed", + "shm remove channel from channel keeper failed", + "shm request queue space failed", + "shm send completion callback is failed", + "shm fd queue is full", + "shm peer fd is destroyed", + "shm op context is failed to remove", +}; + +static const char *SCodeArray[] = { + "socket general error", + "socket invalid parameter", + "socket memory allocate failed", + "socket new object generate failed", + "socket listen failed", + "socket create failed", + "socket data size is unmatched", + "socket epoll operation failed", + "socket send failed", + "socket connect failed", + "socket set option failed", + "socket get option failed", + "socket create epoll failed in worker", + "socket retry", + "socket eagain in nonblocking mode", + "socket send queue is full", + "socket context pool is used up", + "socket ssl write is failed", + "socket ssl read is failed", + "socket reset by peer", + "socket ssl read failed", + "socket timeout", +}; + +static const char *SevCodeArray[] = { + "service general error", + "service invalid parameter", + "service new object generate failed", + "service create timeout thread is failed", + "service malloc data memory failed", + "service channel is not established", + "service store seq no duplicated", + "service seq no is not found", + "service response size is small than data length", + "service timeout", + "service failed to start periodic manager", + "service is not configure enable RNDV, failed to start RNDV", + "service RNDV operate failed by peer", + "service store channel id duplicated", + "service reconnect find ep not broken", + "service find channel not exist", + "service reconnect over user set window", + "service connect failed by some ep broken", + "service do not support server invoke reconnect", + "service stop by user", +}; + +static int32_t NNCodeArrayLength = sizeof(NNCodeArray) / sizeof(NNCodeArray[0]); +static int32_t RRCodeArrayLength = sizeof(RRCodeArray) / sizeof(RRCodeArray[0]); +static int32_t ShCodeArrayLength = sizeof(ShCodeArray) / sizeof(ShCodeArray[0]); +static int32_t SCodeArrayLength = sizeof(SCodeArray) / sizeof(SCodeArray[0]); +static int32_t SevCodeArrayLength = sizeof(SevCodeArray) / sizeof(SevCodeArray[0]); +} +} +#endif // HCOM_CODE_MSG_H diff --git a/src/common/hcom_env.h b/src/common/hcom_env.h new file mode 100644 index 0000000000000000000000000000000000000000..1e69e08f7b1a991b1a0368f30b9284b830b01c4a --- /dev/null +++ b/src/common/hcom_env.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_ENV_H +#define OCK_HCOM_ENV_H + +#include "net_common.h" + +namespace ock { +namespace hcom { + +class HcomEnv { +public: + // 双边inline阈值,inline上限和网卡有关(一般256Bytes),超出网卡上限的话,创建QP时会报错,如果设置太大在创建QP时提醒用户 + static inline uint32_t InlineThreshold() + { + static long threshold = [] () { + auto value = NetFunc::NN_GetLongEnv("HCOM_INLINE_THRESHOLD", 0, UINT32_MAX, 0); + NN_LOG_INFO("Inline threshold is: " << value); + return static_cast(value); + }(); + return threshold; + } + + // 双边rndv阈值,默认是UINT32_MAX,用户不设置默认不开启 + static inline uint32_t RndvThreshold() + { + static long threshold = [] () { + auto value = NetFunc::NN_GetLongEnv("HCOM_RNDV_THRESHOLD", 0, UINT32_MAX, UINT32_MAX); + NN_LOG_INFO("Rndv Threshold is: " << value); + return static_cast(value); + }(); + return threshold; + } +}; + +} +} + +#endif \ No newline at end of file diff --git a/src/common/net_addr_size_map.h b/src/common/net_addr_size_map.h new file mode 100644 index 0000000000000000000000000000000000000000..b1f6d82a98bb61e90fbcd20a154062dcf8d39b11 --- /dev/null +++ b/src/common/net_addr_size_map.h @@ -0,0 +1,483 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_NET_ADDR_SIZE_MAP_H +#define HCOM_NET_ADDR_SIZE_MAP_H + +#include "hcom.h" + +namespace ock { +namespace hcom { +constexpr uint64_t gAddressMask = 0xFFFFFFFFFFFF; /* mask for key */ + +/* + * @brief Spin lock entry in bucket + * used for alloc overflowed buckets + */ +struct NetHashLockEntry { + uint64_t lock = 0; + + /* + * @brief Spin lock + */ + void Lock() + { + while (!__sync_bool_compare_and_swap(&lock, 0, NN_NO1)) { + } + } + + /* + * @brief Unlock + */ + void Unlock() + { + __atomic_store_n(&lock, 0, __ATOMIC_SEQ_CST); + } +} __attribute__((packed)); + +/* + * @brief Store the key/value into a linked array with 6 items, + * because 64bytes is one cache line + */ +struct NetHashBucket { + /* + * @brief Make entry with address and times of base size + * first 16bits: timesOfBaseSize + * second 48bits: address + */ + static inline uint64_t MakeEntry(uint64_t address, uint64_t timesOfBaseSize) + { + return (timesOfBaseSize << NN_NO48) | address; + } + + /* + * @brief Get times of base size from entry + * first 16bits: timesOfBaseSize + * second 48bits: address + */ + static inline uint64_t GetSize(uint64_t entry) + { + return entry >> NN_NO48; + } + + uint64_t subBuck[NN_NO6] {}; + NetHashBucket *next = nullptr; + NetHashLockEntry spinLock {}; + + bool Put(uint64_t address, uint64_t timesOfBaseSize) + { + /* + * There are three pre-conditions, as this is used for memory allocator + * 1 it is NOT possible that put and remove the same address at same time + * 2 it is NOT possible that put two same key at same time + * 3 there is no duplicated address + * + * these pre-conditions make logic much simpler, two steps need: + * 1 loop and find an empty place in the bucket + * 2 if no free in bucket expand a new one + */ + + /* don't put them into loop, flat code is faster than loop */ + auto newEntry = MakeEntry(address, timesOfBaseSize); + if (subBuck[NN_NO0] == 0 && __sync_bool_compare_and_swap(&subBuck[NN_NO0], 0, newEntry)) { + return true; + } + + if (subBuck[NN_NO1] == 0 && __sync_bool_compare_and_swap(&subBuck[NN_NO1], 0, newEntry)) { + return true; + } + + if (subBuck[NN_NO2] == 0 && __sync_bool_compare_and_swap(&subBuck[NN_NO2], 0, newEntry)) { + return true; + } + + if (subBuck[NN_NO3] == 0 && __sync_bool_compare_and_swap(&subBuck[NN_NO3], 0, newEntry)) { + return true; + } + + if (subBuck[NN_NO4] == 0 && __sync_bool_compare_and_swap(&subBuck[NN_NO4], 0, newEntry)) { + return true; + } + + if (subBuck[NN_NO5] == 0 && __sync_bool_compare_and_swap(&subBuck[NN_NO5], 0, newEntry)) { + return true; + } + + return false; + } + + /* + * @brief Remove the address from the bucket and get size + */ + bool Remove(uint64_t address, uint32_t ×OfBaseSize) + { + /* + * expand the loop, instead of put them into a for/while loop for performance + */ + uint64_t oldValue = subBuck[NN_NO0]; + if ((oldValue & gAddressMask) == address) { + __sync_bool_compare_and_swap(&subBuck[NN_NO0], oldValue, 0); + timesOfBaseSize = GetSize(oldValue); + return true; + } + + oldValue = subBuck[NN_NO1]; + if ((oldValue & gAddressMask) == address) { + __sync_bool_compare_and_swap(&subBuck[NN_NO1], oldValue, 0); + timesOfBaseSize = GetSize(oldValue); + return true; + } + + oldValue = subBuck[NN_NO2]; + if ((oldValue & gAddressMask) == address) { + __sync_bool_compare_and_swap(&subBuck[NN_NO2], oldValue, 0); + timesOfBaseSize = GetSize(oldValue); + return true; + } + + oldValue = subBuck[NN_NO3]; + if ((oldValue & gAddressMask) == address) { + __sync_bool_compare_and_swap(&subBuck[NN_NO3], oldValue, 0); + timesOfBaseSize = GetSize(oldValue); + return true; + } + + oldValue = subBuck[NN_NO4]; + if ((oldValue & gAddressMask) == address) { + __sync_bool_compare_and_swap(&subBuck[NN_NO4], oldValue, 0); + timesOfBaseSize = GetSize(oldValue); + return true; + } + + oldValue = subBuck[NN_NO5]; + if ((oldValue & gAddressMask) == address) { + __sync_bool_compare_and_swap(&subBuck[NN_NO5], oldValue, 0); + timesOfBaseSize = GetSize(oldValue); + return true; + } + + return false; + } +}; + +/* + * @brief Allocator template, for extend memory allocation for overflowed buckets + */ +class NetHeapAllocator { +public: + void *Allocate(uint32_t size) + { + return calloc(NN_NO1, size); + } + + void Free(void *p) + { + if (NN_LIKELY(p != nullptr)) { + free(p); + p = nullptr; + } + } +}; + +/* + * A high performance lockless hash map to store address and size(i.e. key=address, value=size), + * the unique things are following: + * 1 split one hash bucket array into sub 7 bucket arrays + * 2 store key and value into uint64_t, to minimize the memory occupation and cache miss + * 3 instead of store key/value into linked list, we store key/value into linked array + * 4 using CAS instead of mutex + */ +template class NetAddress2SizeHashMap { +public: + DEFINE_RDMA_REF_COUNT_FUNCTIONS + + NetAddress2SizeHashMap() = default; + ~NetAddress2SizeHashMap() + { + UnInitialize(); + } + + NResult Initialize(uint32_t reserve) + { + /* already initialized */ + if (mOverflowEntryAlloc != nullptr) { + return NN_OK; + } + + /* get proper bucket count */ + uint32_t bucketCount = reserve < NN_NO128 ? NN_NO128 : reserve; + if (bucketCount > gPrimes[NN_NO165]) { + bucketCount = gPrimes[NN_NO165]; + } else { + uint32_t i = 0; + while (i < gPrimesCount - 1 && gPrimes[i] < bucketCount) { + i++; + } + bucketCount = gPrimes[i]; + } + + /* allocate buckets for sub-maps */ + for (uint16_t i = 0; i < gSubMapCount; i++) { + auto tmp = new (std::nothrow) NetHashBucket[bucketCount]; + if (NN_UNLIKELY(tmp == nullptr)) { + for (uint16_t j = i; j < gSubMapCount; j++) { + mSubMaps[j] = nullptr; + } + FreeSubMaps(); + NN_LOG_ERROR("Failed to new hash bucket, probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + /* make physical page and set to zero */ + bzero(tmp, sizeof(NetHashBucket) * bucketCount); + + mSubMaps[i] = tmp; + } + + /* create overflow entry allocator */ + mOverflowEntryAlloc = new (std::nothrow) Alloc(); + if (NN_UNLIKELY(mOverflowEntryAlloc == nullptr)) { + FreeSubMaps(); + NN_LOG_ERROR("Failed to new overflow entry allocator, probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + /* set bucket count */ + mBucketCount = bucketCount; + mCount = 0; + + NN_LOG_INFO("Initialized NetAddress2SizeHashMap with " << gSubMapCount << " sub-maps, each contains " << + mBucketCount << " buckets, count of items " << mCount << ", occupied memory by buckets is " << + (sizeof(NetHashBucket) * bucketCount * gSubMapCount) << " bytes"); + + return NN_OK; + } + + void UnInitialize() + { + if (mOverflowEntryAlloc == nullptr) { + return; + } + /* free overflowed entries firstly */ + FreeOverFlowedEntries(); + + /* free sub map secondly */ + FreeSubMaps(); + + /* free overflow entry at last */ + delete mOverflowEntryAlloc; + mOverflowEntryAlloc = nullptr; + + mBucketCount = 0; + mCount = 0; + } + + /* + * @brief Put address with size in + * + */ + NResult Put(uintptr_t address, uint32_t timesOfBaseSize) + { + if (NN_UNLIKELY(address == 0)) { + return NN_INVALID_PARAM; + } + + /* get bucket */ + auto buck = &(mSubMaps[address % gSubMapCount][address % mBucketCount]); + + /* try 8192 times */ + for (uint16_t i = 0; i < NN_NO8192; i++) { + /* loop all buckets linked */ + while (buck != nullptr) { + /* if there is an entry to put, just break */ + if (buck->Put(address, timesOfBaseSize)) { + /* increase count of items */ + __sync_add_and_fetch(&mCount, 1); + return NN_OK; + } + + /* + * if no next bucket exist, just for break, + * else move to next bucket linked + */ + if (buck->next == nullptr) { + break; + } else { + buck = buck->next; + } + } + + /* + * if not put successfully in existing buckets, allocate a new one + * + * NOTES: just allocate memory, don't access new bucket in the spin lock scope, + * if access new bucket, which could trigger physical memory allocation which + * could trigger page fault, that is quite slow. In this case, spin lock + * could occupy too much CPU + */ + auto &lock = buck->spinLock; + lock.Lock(); + /* if other thread allocated new buck already, unlock and continue */ + if (buck->next != nullptr) { + buck = buck->next; + lock.Unlock(); + continue; + } + + /* firstly entered thread allocate new bucket */ + auto newBuck = static_cast(mOverflowEntryAlloc->Allocate(sizeof(NetHashBucket))); + if (NN_UNLIKELY(newBuck == nullptr)) { + lock.Unlock(); + NN_LOG_ERROR("Failed to alloc new overflowed bucket from allocator"); + return NN_MALLOC_FAILED; + } + + /* link to current buck, set buck to new buck */ + buck->next = newBuck; + buck = newBuck; + + /* unlock */ + lock.Unlock(); + } + + NN_LOG_ERROR("Failed to put key/size with " << NN_NO8192 * NN_NO6 << " times try"); + return NN_ERROR; + } + + /* + * @brief Remove and get size + */ + NResult Remove(uintptr_t address, uint32_t ×OfBaseSize) + { + if (NN_UNLIKELY(address == 0)) { + return NN_INVALID_PARAM; + } + + /* get bucket */ + auto buck = &(mSubMaps[address % gSubMapCount][address % mBucketCount]); + + /* loop all buckets linked */ + while (buck != nullptr) { + if (buck->Remove(address, timesOfBaseSize)) { + __sync_sub_and_fetch(&mCount, 1); + return NN_OK; + } + + buck = buck->next; + } + + NN_LOG_TRACE_INFO("Not found address in address2size map, which should not happen"); + return NN_ERROR; + } + + /* + * @brief Get size of item in hash map + */ + inline uint32_t Size() const + { + return mCount; + } + +private: + void FreeSubMaps() + { + /* free all sub maps */ + for (uint16_t i = 0; i < gSubMapCount; i++) { + auto &tmp = mSubMaps[i]; + if (tmp != nullptr) { + delete[] tmp; + mSubMaps[i] = nullptr; + } + } + } + + void FreeOverFlowedEntries() + { + for (uint16_t i = 0; i < gSubMapCount; i++) { + auto &tmp = mSubMaps[i]; + if (tmp == nullptr) { + continue; + } + + /* free overflow entries in one sub map */ + for (uint32_t buckIndex = 0; buckIndex < mBucketCount; ++buckIndex) { + auto curBuck = mSubMaps[i][buckIndex].next; + NetHashBucket *nextOverflowEntryBuck = nullptr; + + /* exit loop when curBuck is null */ + while (curBuck != nullptr) { + /* assign next overflow buck to tmp variable */ + nextOverflowEntryBuck = curBuck->next; + + /* free this overflow bucket */ + mOverflowEntryAlloc->Free(curBuck); + + /* assign next to current */ + curBuck = nextOverflowEntryBuck; + } + } + } + } + +private: + static constexpr uint16_t gSubMapCount = NN_NO5; /* count of sub map */ + static constexpr uint32_t gPrimesCount = NN_NO256; + +private: + /* make sure the size of this class is 64 bytes, fit into one cache line */ + Alloc *mOverflowEntryAlloc = nullptr; /* allocate overflowed entry in one bucket */ + NetHashBucket *mSubMaps[gSubMapCount] {}; /* sub map */ + uint32_t mBucketCount = 0; /* bucket count of each sub map */ + uint32_t mCount = 0; /* bucket count of each sub map */ + uint32_t mBaseSize = NN_NO4096; /* base size */ + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + const uint32_t gPrimes[gPrimesCount] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, + 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, + 97, 103, 109, 113, 127, 137, 139, 149, 157, 167, + 179, 193, 199, 211, 227, 241, 257, 277, 293, 313, + 337, 359, 383, 409, 439, 467, 503, 541, 577, 619, + 661, 709, 761, 823, 887, 953, 1031, 1109, 1193, 1289, + 1381, 1493, 1613, 1741, 1879, 2029, 2179, 2357, 2549, + 2753, 2971, 3209, 3469, 3739, 4027, 4349, 4703, 5087, + 5503, 5953, 6427, 6949, 7517, 8123, 8783, 9497, 10273, + 11113, 12011, 12983, 14033, 15173, 16411, 17749, 19183, + 20753, 22447, 24281, 26267, 28411, 30727, 33223, 35933, + 38873, 42043, 45481, 49201, 53201, 57557, 62233, 67307, + 72817, 78779, 85229, 92203, 99733, 107897, 116731, 126271, + 136607, 147793, 159871, 172933, 187091, 202409, 218971, 236897, + 256279, 277261, 299951, 324503, 351061, 379787, 410857, 444487, + 480881, 520241, 562841, 608903, 658753, 712697, 771049, 834181, + 902483, 976369, 1056323, 1142821, 1236397, 1337629, 1447153, + 1565659, 1693859, 1832561, 1982627, 2144977, 2320627, 2510653, + 2716249, 2938679, 3179303, 3439651, 3721303, 4026031, 4355707, + 4712381, 5098259, 5515729, 5967347, 6456007, 6984629, 7556579, + 8175383, 8844859, 9569143, 10352717, 11200489, 12117689, + 13109983, 14183539, 15345007, 16601593, 17961079, 19431899, + 21023161, 22744717, 24607243, 26622317, 28802401, 31160981, + 33712729, 36473443, 39460231, 42691603, 46187573, 49969847, + 54061849, 58488943, 63278561, 68460391, 74066549, 80131819, + 86693767, 93793069, 101473717, 109783337, 118773397, 128499677, + 139022417, 150406843, 162723577, 176048909, 190465427, + 206062531, 222936881, 241193053, 260944219, 282312799, + 305431229, 330442829, 357502601, 386778277, 418451333, + 452718089, 489790921, 529899637, 573292817, 620239453, + 671030513, 725980837, 785430967, 849749479, 919334987, + 994618837, 1076067617, 1164186217, 1259520799, 1362662261, + 1474249943, 1594975441, 1725587117, 1866894511, 2019773507, + 2185171673, 2364114217, 2557710269, 2767159799, 2993761039, + 3238918481, 3504151727, 3791104843, 4101556399, 4294967291}; +}; +} +} + +#endif // HCOM_NET_ADDR_SIZE_MAP_H diff --git a/src/common/net_bucket_linked_list.h b/src/common/net_bucket_linked_list.h new file mode 100644 index 0000000000000000000000000000000000000000..0a9a6ae666638b1de542a50b4ac811bcab9c5d31 --- /dev/null +++ b/src/common/net_bucket_linked_list.h @@ -0,0 +1,125 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_LINKED_LIST_H +#define OCK_HCOM_NET_LINKED_LIST_H + +#include "hcom.h" + +namespace ock { +namespace hcom { +/* + * Node info for linked list + */ +struct NetLLNode { + struct NetLLNode *next = nullptr; /* point to next node, which is memory segment */ +}; + +/* + * The meta info for one linked list, + */ +struct NetBucketLinkedListMeta { + NetLLNode *next = nullptr; /* point to the real memory segment */ + NetSpinLock lock {}; /* spin lock for insertion & deletion of memory */ + uint32_t count = 0; /* the count of current linked list */ +}; + +/* + * A thread safe linked list with buckets for multiple threads cases, + * used for MR segment allocation. + * + * This linked list doesn't allocate extract memory for linked node, + * the linked node info stores on the start place of free memory segment. + * The linked node info needs to clean after allocated, since these memory segments + * are allocated to end user possibly. + */ +#define BUCKET_COUNT 64 +class NetBucketLinkedList { +public: + NetBucketLinkedList() = default; + ~NetBucketLinkedList() = default; + + /* + * @brief Push one item to linked list + * + * @param item [in] the address of memory to added to list + */ + inline void PushFront(uintptr_t item) + { + auto *newNode = reinterpret_cast(item); + if (NN_UNLIKELY(newNode == nullptr)) { + return; + } + NetBucketLinkedListMeta *buckets = &mBuckets[__sync_fetch_and_add(&mPushRRIdx, 1) % BUCKET_COUNT]; + buckets->lock.Lock(); + newNode->next = buckets->next; + buckets->next = newNode; + buckets->count++; + buckets->lock.Unlock(); + } + + inline bool Pop(uintptr_t &item) + { + uint16_t leftBucketsCount = BUCKET_COUNT; + do { + NetBucketLinkedListMeta *buckets = &mBuckets[__sync_fetch_and_add(&mPopRRIdx, 1) % BUCKET_COUNT]; + + buckets->lock.Lock(); + if (NN_UNLIKELY(buckets->count == NN_NO0)) { + buckets->lock.Unlock(); + continue; + } + + item = reinterpret_cast(buckets->next); + + buckets->next = buckets->next->next; + buckets->count--; + buckets->lock.Unlock(); + return true; + } while (--leftBucketsCount > 0); + + return false; + } + + inline bool PopN(uintptr_t *&items, uint32_t n) + { + if (NN_UNLIKELY(items == nullptr)) { + return false; + } + + /* traverse every bucket for balance */ + for (uint32_t i = NN_NO0; i < n; i++) { + if (NN_UNLIKELY(!Pop(items[i]))) { + for (uint32_t j = NN_NO0; j < i; j++) { + PushFront(items[j]); + } + return false; + } + } + + return true; + } + + NetBucketLinkedList(const NetBucketLinkedList &) = delete; + NetBucketLinkedList(NetBucketLinkedList &&) = delete; + NetBucketLinkedList &operator = (const NetBucketLinkedList &) = delete; + NetBucketLinkedList &operator = (NetBucketLinkedList &&) = delete; + +private: + /* NOTE: to make sure the size of this class is same with one cache line of CPU */ + uint32_t mPopRRIdx = 0; /* round-robin index for pop */ + uint32_t mPushRRIdx = 0; /* round-robin index for push */ + NetBucketLinkedListMeta mBuckets[BUCKET_COUNT] {}; /* buckets linked list */ +}; +} +} + +#endif // OCK_HCOM_NET_LINKED_LIST_H diff --git a/src/common/net_common.h b/src/common/net_common.h new file mode 100644 index 0000000000000000000000000000000000000000..80d74cdb208482eb9b0982e62dfd1aef92a446fb --- /dev/null +++ b/src/common/net_common.h @@ -0,0 +1,858 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_NET_COMMON_123424434341233_H +#define OCK_NET_COMMON_123424434341233_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "net_crc32.h" +#include "net_trace.h" +#include "net_util.h" +#include "securec.h" + +namespace ock { +namespace hcom { +constexpr int INVALID_FD = -1; +constexpr int16_t MAX_OPCODE = NN_NO1200; +constexpr uint32_t OOB_DEFAULT_LISTEN_PORT = 9980; +constexpr uint32_t OOB_DEFAULT_LISTEN_BACKLOG = 65535; +constexpr uint32_t MR_FIXED_POOL_DEFAULT_SEG_SIZE = 8192; +constexpr uint32_t MR_FIXED_POOL_DEFAULT_SEG_COUNT = 1024; + +#ifndef KERNEL_VERSION +#define KERNEL_VERSION(a, b, c) (((a) << 16) + ((b) << 8) + (c)) +#endif + +/** + * Get struct pointer from member pointer + */ +#define GetStructRoot(_memberPtr, _type, _field) ((_type*)((char*)(_memberPtr) - offsetof(_type, _field))) + +enum class NetProtocol { + NET_TCP, + NET_UDS, + NET_UBC, +}; + +class NetFunc { +public: + static inline uint32_t CalcHeaderCrc32(UBSHcomNetTransHeader *header) + { + static const uint32_t LENGTH = sizeof(UBSHcomNetTransHeader) - sizeof(uint32_t); + return NetCrc32::CalcCrc32(reinterpret_cast(header) + sizeof(uint32_t), LENGTH); + } + + static inline uint32_t CalcHeaderCrc32(UBSHcomNetTransHeader &header) + { + static const uint32_t LENGTH = sizeof(UBSHcomNetTransHeader) - sizeof(uint32_t); + return NetCrc32::CalcCrc32(reinterpret_cast(&header) + sizeof(uint32_t), LENGTH); + } + + static inline bool ValidateHeaderCrc32(UBSHcomNetTransHeader *header) + { + if (NN_UNLIKELY(header == nullptr)) { + NN_LOG_ERROR("Invalid param, header must be correct address"); + return false; + } + return header->headerCrc == CalcHeaderCrc32(header); + } + + static inline bool ValidateHeaderCrc32(UBSHcomNetTransHeader &header) + { + return header.headerCrc == CalcHeaderCrc32(header); + } + + static inline NResult ValidateSeqNo(UBSHcomNetTransHeader &header, uint32_t lastSendSeqNo) + { + if (NN_UNLIKELY(header.seqNo != lastSendSeqNo)) { + NN_LOG_ERROR("Received un-matched seq no " << header.seqNo << ", demand seq no " + << lastSendSeqNo); + return NN_SEQ_NO_NOT_MATCHED; + } + + return NN_OK; + } + + static inline NResult ValidateHeader(UBSHcomNetTransHeader &header) + { + if (header.dataLength == 0 || header.dataLength > NET_SGE_MAX_SIZE) { + NN_LOG_ERROR("Failed to validate header dataLength " << header.dataLength << " received"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(!ValidateHeaderCrc32(header))) { + NN_LOG_ERROR("Failed to validate received header crc " << header.headerCrc); + return NN_VALIDATE_HEADER_CRC_INVALID; + } + + return NN_OK; + } + + static inline NResult ValidateHeaderWithDataSize(UBSHcomNetTransHeader &header, uint32_t dataSize) + { + if (header.dataLength != (dataSize - sizeof(UBSHcomNetTransHeader))) { + NN_LOG_ERROR("Failed to validate received dataLength " << header.dataLength << " in header, with dataSize " + << dataSize); + return NN_INVALID_PARAM; + } + + return ValidateHeader(header); + } + + static inline NResult ValidateHeaderWithSeqNo(UBSHcomNetTransHeader &header, uint32_t dataSize, + uint32_t lastSendSeqNo) + { + NResult ret = ValidateSeqNo(header, lastSendSeqNo); + if (ret != NN_OK) { + return ret; + } + + return ValidateHeaderWithDataSize(header, dataSize); + } + /* + * @brief Safely close fd, if the fd is less than 0, no action + * if fd >= 0, close it and assign to -1 atomically + * + * @param fd [in] fd to be closed + */ + static inline void NN_SafeCloseFd(int &fd) + { + if (NN_UNLIKELY(fd < 0)) { + return; + } + + auto tmpFd = fd; + if (__sync_bool_compare_and_swap(&fd, tmpFd, INVALID_FD)) { + close(tmpFd); + } + } + + static inline uint32_t GetIpByFd(int fd) + { + struct sockaddr_in addressIn {}; + addressIn.sin_addr.s_addr = INVALID_IP; + socklen_t len = sizeof(addressIn); + getsockname(fd, reinterpret_cast(&addressIn), &len); // UDS return INVALID_IP + return addressIn.sin_addr.s_addr; + } + + static inline uint32_t IpStringToUint32(const std::string &ip) + { + struct in_addr addr; + if (inet_pton(AF_INET, ip.c_str(), &addr) == 1) { + return addr.s_addr; + } + NN_LOG_ERROR("Fail to change ip string to uint32_t, ip " << ip); + return 0; + } + + /* + * @brief Round up one number to another + */ + static inline uint64_t NN_RoundUpTo(uint64_t value, uint64_t align) + { + return ((value + align - 1) / align) * align; + } + + /* + * @brief Get next power of N, its index, instead of the final value + */ + static inline uint64_t NN_PowerOfNIndex(uint64_t value, uint64_t align) + { + uint64_t tmp = (value + align - 1) / align; + return tmp < NN_NO2 ? 0 : (NN_NO32 - __builtin_clz(tmp - 1)); + } + + /* NN_SplitStr */ + static void NN_SplitStr(const std::string &str, const std::string &separator, std::vector &result) + { + result.clear(); + std::string::size_type pos1 = 0; + std::string::size_type pos2 = str.find(separator); + + std::string tmpStr; + while (pos2 != std::string::npos) { + tmpStr = str.substr(pos1, pos2 - pos1); + result.emplace_back(tmpStr); + pos1 = pos2 + separator.size(); + pos2 = str.find(separator, pos1); + } + + if (pos1 != str.length()) { + tmpStr = str.substr(pos1); + result.emplace_back(tmpStr); + } + } + + static void NN_VecStrToStr(const std::vector &vec, const std::string &linkStr, std::string &result) + { + result.clear(); + for (const auto &item : vec) { + if (NN_UNLIKELY(result.empty())) { + result = item; + } else { + result += (linkStr + item); + } + } + } + + static bool NN_Stol(const std::string &str, long &value) + { + char *remain = nullptr; + errno = 0; + value = std::strtol(str.c_str(), &remain, 10); // 10 is decimal digits + if (remain == nullptr || strlen(remain) > 0 || ((value == LONG_MAX || value == LONG_MIN) && errno == ERANGE)) { + return false; + } else if (value == 0 && str != "0") { + return false; + } + return true; + } + + static bool NN_Stof(const std::string &str, float &value) + { + constexpr float EPSINON = 0.000001; + char *remain = nullptr; + errno = 0; + value = std::strtof(str.c_str(), &remain); + if (remain == nullptr || strlen(remain) > 0 || + ((value - HUGE_VALF) >= -EPSINON && (value - HUGE_VALF) <= EPSINON && errno == ERANGE)) { + return false; + } else if ((value >= -EPSINON && value <= EPSINON) && (str != "0.0")) { + return false; + } + return true; + } + + static bool NN_CovertIpMask(const std::string &maskString, in_addr_t &ipByMask, in_addr_t &mask) + { + std::vector ipMaskVec; + NN_SplitStr(maskString, "/", ipMaskVec); + if (ipMaskVec.size() != NN_NO2) { + return false; + } + + long maskWidth = 0; + if (!NN_Stol(ipMaskVec[1], maskWidth)) { + return false; + } + + long maskOffset = NN_NO32 - maskWidth; + if (maskOffset < 0 || maskOffset >= NN_NO32) { + return false; + } + mask = static_cast(0xFFFFFFFF >> maskOffset); + + auto tmp = inet_addr(ipMaskVec[0].c_str()); + if (tmp == INADDR_NONE) { + return false; + } + + ipByMask = tmp & mask; + return true; + } + + static bool NN_CovertIpWithoutPort(const std::string &ipPort, uint32_t &ip) + { + std::vector ipPortVec; + NN_SplitStr(ipPort, ":", ipPortVec); + if (ipPortVec.size() != NN_NO2) { + return false; + } + + auto tmp = inet_addr(ipPortVec[0].c_str()); + if (tmp == INADDR_NONE) { + return false; + } + + ip = tmp; + return true; + } + + static NResult NN_ValidateUrl(const std::string &name) + { + if (NN_UNLIKELY(name.length() > NN_NO100 || name.length() < NN_NO1)) { + NN_LOG_WARN("Url length should be in 1-100"); + return NN_INVALID_PARAM; + } + for (char n : name) { + if (NN_UNLIKELY((!std::isalnum(n)) && + (n != '_' && n != '-' && n != '/' && n != '.' && n != ':'))) { + NN_LOG_WARN("Url cannot contain illegal characters, only could contain alphabet, " + "number, -, _, ., :, /"); + return NN_INVALID_PARAM; + } + } + return NN_OK; + } + + static bool NN_ConvertIpAndPort(const std::string &url, std::string &ip, uint16_t &port) + { + if (NN_UNLIKELY(NN_ValidateUrl(url) != NN_OK)) { + NN_LOG_ERROR("Invalid url"); + return false; + } + std::string separator(":"); + std::string::size_type pos = url.find(separator); + if (NN_UNLIKELY(pos == std::string::npos)) { + NN_LOG_ERROR("invalid url: " << url << ", must be like 127.0.0.1:9981"); + return false; + } + + ip = url.substr(0, pos); + port = std::strtoul(url.substr(pos + 1).c_str(), nullptr, NN_NO10); + if (NN_UNLIKELY(port == NN_NO0)) { + NN_LOG_ERROR("Invalid port, url:" << url); + return false; + } + return true; + } + + static bool NN_ConvertEidAndJettyId(const std::string &url, std::string &eid, uint16_t &jettyId) + { + std::vector idVec; + NN_SplitStr(url, ":", idVec); + if (idVec.size() != NN_NO9 || url.length() <= NN_NO40) { + return false; + } + + eid = url.substr(0, NN_NO39); + auto tmpId = std::strtoul(url.substr(NN_NO40, url.length()).c_str(), nullptr, NN_NO10); + if (tmpId < NN_NO4 || tmpId > NN_NO1023) { + NN_LOG_ERROR("Ensure the jetty id in range 4~1023"); + return false; + } + + jettyId = tmpId; + return true; + } + + static bool NN_ConvertNameAndPerm(const std::string &url, std::string &name, uint16_t &perm) + { + if (NN_UNLIKELY(NN_ValidateUrl(url) != NN_OK)) { + NN_LOG_ERROR("Invalid url"); + return false; + } + std::string separator(":"); + std::string::size_type pos = url.find(separator); + if (NN_LIKELY(pos == std::string::npos)) { + name = url; + perm = 0; + return true; + } + name = url.substr(0, pos); + perm = std::strtoul(url.substr(pos + 1).c_str(), nullptr, NN_NO10); + if (NN_UNLIKELY(perm == NN_NO0) || NN_UNLIKELY(perm == UINT16_MAX)) { + NN_LOG_ERROR("Invalid perm, url:" << url); + return false; + } + return true; + } + + // protocal://url + static bool NN_SplitProtoUrl(const std::string &protoUrl, NetProtocol &protocal, std::string &url) + { + std::string separator("://"); + std::string::size_type pos = protoUrl.find(separator); + if (NN_UNLIKELY(pos == std::string::npos)) { + NN_LOG_ERROR("Invalid url, must be like tcp://127.0.0.1:9981 or uds://name or ubc://eid:jettyId"); + return false; + } + + std::string protoStr = protoUrl.substr(0, pos); + if (protoStr == "tcp") { + protocal = NetProtocol::NET_TCP; + } else if (protoStr == "uds") { + protocal = NetProtocol::NET_UDS; + } else if (protoStr == "ubc") { + protocal = NetProtocol::NET_UBC; + } else { + NN_LOG_ERROR("Unsupport url protocal"); + return false; + } + url = protoUrl.substr(pos + separator.size()); + return true; + } + + static bool NN_CheckFilePrefix(std::string fileName) + { + char *envFilePrefixPath = ::getenv("HCOM_FILE_PATH_PREFIX"); + if (NN_UNLIKELY(envFilePrefixPath == nullptr)) { + NN_LOG_ERROR("Check file prefix failed as env HCOM_FILE_PATH_PREFIX is not set"); + return false; + } + std::string filenamePrefix = fileName.substr(0, strlen(envFilePrefixPath)); + if (NN_UNLIKELY(filenamePrefix != envFilePrefixPath)) { + NN_LOG_ERROR("Check file prefix failed as prefix does not match HCOM_FILE_PATH_PREFIX"); + return false; + } + return true; + } + /* + * @brief Parse worker string to vector + * + * @param workerStr [in] string format of workers + * @param workerGroups [out] vector of groups + * + * @return true if workStr is valid and convert successfully + * + * for example, workerStr is 1,3,3 + * output workerGroups will be [1,3,3] vector + * + * validations: + * 1 the total count of group be less than 128 + * 2 if workerStr is empty, [1] vector will be the output + * 3 each element is workerStr must be a digital + * 4 each element is workerStr must be 1 to 128 + */ + static bool NN_ParseWorkersGroups(const std::string &workerStr, std::vector &workerGroups) + { + std::vector extractStrings; + NN_SplitStr(workerStr, ",", extractStrings); + + NN_LOG_TRACE_INFO("worker str '" << workerStr << "', extract vector size " << extractStrings.size()); +#ifdef NN_LOG_TRACE_INFO_ENABLED + for (auto &item : extractStrings) { + NN_LOG_TRACE_INFO("extracted item " << item); + } +#endif + + /* if empty, make it to default, i.e. 1 group with 1 worker */ + if (workerStr.empty() || extractStrings.empty()) { + workerGroups.clear(); + workerGroups.emplace_back(1); + return true; + } else if (extractStrings.size() > NN_NO128) { + NN_LOG_ERROR("Invalid worker group setting '" << workerStr << + "', example '1,3,3' meaning that there are 3 groups, 1 worker in group0, 3 workers in group1 and 3 " + "workers in group2. group size must be 1-128"); + return false; + } + + /* validate worker config */ + long tmpCount = 0; + workerGroups.reserve(extractStrings.size()); + for (auto &item : extractStrings) { + if (NN_Stol(item, tmpCount) && tmpCount > 0 && tmpCount <= NN_NO128) { + workerGroups.emplace_back(tmpCount); + continue; + } + + /* if invalid config group */ + NN_LOG_ERROR("Invalid worker group setting '" << workerStr << + "', example '1,3,3' meaning that there are 3 groups, 1 worker in group0, 3 workers in group1 and 3 " + "workers in group2. worker size in each group must be 1-128"); + return false; + } + + return true; + } + + /* + * @brief Parse cpu binding str to vector + * + * @param workerGroupCpusStr [in] cpu binding setting in string + * @param workerGroupCpus [out] output vector + * + * for example + * - input: na,10-12,13-16 + * - output: [[128,0],[10,3],[13,4]] + * - first is start cpu id + * - second is cpu count + * + * validations: + * 1 the total count of group be less than 128 + * 2 if workerGroupCpusStr is empty, [] vector will be the output + * 3 each element is workerGroupCpusStr na/NA/digital-range + */ + static bool NN_ParseWorkerGroupsCpus(const std::string &workerGroupCpusStr, + std::vector> &workerGroupCpus) + { + std::vector extractStrings; + NN_SplitStr(workerGroupCpusStr, ",", extractStrings); + + NN_LOG_TRACE_INFO("worker str '" << workerGroupCpusStr << "', extract vector size " << extractStrings.size()); +#ifdef NN_LOG_TRACE_INFO_ENABLED + for (auto &item : extractStrings) { + NN_LOG_TRACE_INFO("extracted item " << item); + } +#endif + + /* if empty */ + if (workerGroupCpusStr.empty() || extractStrings.empty()) { + workerGroupCpus.clear(); + return true; + } else if (extractStrings.size() > NN_NO128) { + NN_LOG_ERROR("Invalid cpu id setting '" << workerGroupCpusStr << + "' for worker groups, example '10-10,11-13,na' meaning that 10 for group0, 11/12/13 for group1, no " + "need to group2, each number must be 0-127, total group must less or equal to 128"); + return false; + } + + /* validate */ + long tmpCpuIdStart = 0; + long tmpCpuIdEnd = 0; + std::vector extractedCpuIds; + extractedCpuIds.reserve(NN_NO4); + workerGroupCpus.reserve(extractStrings.size()); + for (auto &item : extractStrings) { + if (item == "na" || item == "NA") { + workerGroupCpus.emplace_back(NN_NO128, 0); + continue; + } + + bool badConf = false; + NN_SplitStr(item, "-", extractedCpuIds); + + /* size un-matched and invalid digital */ + if (extractedCpuIds.size() != NN_NO2) { + badConf = true; + } else if (!NN_Stol(extractedCpuIds[0], tmpCpuIdStart) || !NN_Stol(extractedCpuIds[1], tmpCpuIdEnd)) { + badConf = true; + } else if (tmpCpuIdStart < 0 || tmpCpuIdStart >= NN_NO256 || tmpCpuIdEnd < 0 || tmpCpuIdEnd >= NN_NO256) { + badConf = true; + } else if (tmpCpuIdStart > tmpCpuIdEnd) { + badConf = true; + } + + if (badConf) { + NN_LOG_ERROR("Invalid cpu id setting '" << item << "' in '" << workerGroupCpusStr << + "' for worker groups, example '10-10,11-13,na' meaning that 10 for group0, 11/12/13 for group1, no " + "need to group2, each number must be 0-127, total group must less or equal to 128"); + return false; + } + + /* push the start index and count */ + workerGroupCpus.emplace_back(tmpCpuIdStart, tmpCpuIdEnd - tmpCpuIdStart + 1); + } + + return true; + } + + /* + * @brief Finalize cpu binding setting + * + * @param workerGroups [in] worker groups vector, for example [1,3,3] + * @param workerGroupCpus [in] cpu binding for worker groups, [[10,1], [11,3], [14,3]] + * @param allowDuplicatedCpuIds [in] allow duplicated cpus id, for rdma busy polling is not allowed + * @param flatWorkersCpus [out] flat cpu id for workers + * + * @return true if ok + */ + static bool NN_FinalizeWorkerGroupCpus(const std::vector &workerGroups, + const std::vector> &workerGroupCpus, bool allowDuplicatedCpuIds, + std::vector &flatWorkersCpus) + { + if (workerGroups.empty() || workerGroups.size() < workerGroupCpus.size()) { + NN_LOG_ERROR("Invalid worker groups which is empty or size of worker groups < cpu groups"); + return false; + } + + /* count total workers */ + uint16_t totalWorkers = 0; + for (auto item : workerGroups) { + totalWorkers += item; + } + + /* reserve and set to default -1 */ + flatWorkersCpus.reserve(totalWorkers); + for (uint16_t i = 0; i < totalWorkers; ++i) { + flatWorkersCpus.push_back(-1); + } + + /* match and set cpus */ + uint16_t flatWorkerCpuIndex = 0; + for (uint32_t i = 0; i < workerGroupCpus.size(); ++i) { + auto &cpuPair = workerGroupCpus[i]; + auto workersInGroup = workerGroups[i]; + + /* no need cpu bind */ + if (cpuPair.first == NN_NO128) { + flatWorkerCpuIndex += workersInGroup; + continue; + } + + /* invalid size */ + if (cpuPair.second > workersInGroup || (!allowDuplicatedCpuIds && cpuPair.second != workersInGroup)) { + NN_LOG_ERROR("Invalid cpus group '" << cpuPair.first << ":" << cpuPair.second << "', the count " << + cpuPair.second << " is larger than or not equal to workers number " << workersInGroup << + " of group " << i); + return false; + } + + /* set */ + for (uint16_t j = 0; j < workersInGroup; j++) { + flatWorkersCpus[flatWorkerCpuIndex + j] = static_cast(cpuPair.first + j % cpuPair.second); + } + + /* move the index */ + flatWorkerCpuIndex += workersInGroup; + } + + return true; + } + + /* + * @brief Parse worker group thread priority string to vector + * + * @param threadPriorityStr [in] string format of workers thread priority + * @param threadPriority [out] vector of thread priority groups + * + * @return true if threadPriorityStr is valid and convert successfully + * + * for example, threadPriorityStr is -1,-10,na,9 + * output threadPriority will be [-1,-10,0,9] vector + * + * validations: + * 1 the total count of thread priority be equal worker group count + * 2 if threadPriorityStr is empty, null vector will be the output + * 3 each element is threadPriorityStr must be a digital + * 4 each element is threadPriorityStr must be -20 to 20 + */ + static bool NN_ParseWorkersGroupsThreadPriority(const std::string &threadPriorityStr, + std::vector &threadPriority, int groupNum) + { + std::vector extractStrings; + NN_SplitStr(threadPriorityStr, ",", extractStrings); + + NN_LOG_TRACE_INFO("Worker group thread priority string '" << threadPriorityStr << "', extract vector size " << + threadPriority.size()); +#ifdef NN_LOG_TRACE_INFO_ENABLED + for (auto &item : extractStrings) { + NN_LOG_TRACE_INFO("extracted item " << item); + } +#endif + + /* if empty, make it to default, i.e. 1 group with 1 worker */ + if (threadPriorityStr.empty() || extractStrings.empty()) { + threadPriority.clear(); + return true; + } else if (static_cast(extractStrings.size()) != groupNum) { + NN_LOG_ERROR("Invalid worker group thread priority setting '" << threadPriorityStr << + "'. group size must be equal worker group number " << groupNum); + return false; + } + + /* validate worker config */ + long tmpCount = 0; + threadPriority.reserve(extractStrings.size()); + for (auto &item : extractStrings) { + if (strcmp(item.c_str(), "na") == 0) { + threadPriority.emplace_back(0); + continue; + } + if (NN_Stol(item, tmpCount) && tmpCount < NN_NO20 && tmpCount >= NN_NOF20) { + threadPriority.emplace_back(tmpCount); + continue; + } + + /* if invalid config group */ + NN_LOG_ERROR("Invalid worker group thread priority setting '" << threadPriorityStr << + "', example '1,3,na,10' meaning that there are 4 groups, group0 set thread priority 1 , group1 set " + "thread priority 3,group2 not set thread priority and group3 set thread priority 10" + ". thread priority in each group must be -20~19"); + return false; + } + + return true; + } + + static long NN_GetLongEnv(const char *env, long min, long max, long defaultNum) + { + if (env == nullptr) { + return defaultNum; + } + auto envString = getenv(env); + auto result = defaultNum; + if (envString != nullptr) { + long tmp = 0; + if (NetFunc::NN_Stol(envString, tmp) && tmp >= min && tmp <= max) { + result = tmp; + } + } + return result; + } + + static char *NN_GetStrError(int errNum, char *buf, size_t bufSize) + { +#if defined(_XOPEN_SOURCE) && defined(_POSIX_C_SOURCE) && defined(_GNU_SOURCE) && \ + (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600) && !_GNU_SOURCE + strerror_r(errNum, buf, bufSize - 1); + return buf; +#else + return strerror_r(errNum, buf, bufSize - 1); +#endif + } + + static NResult NN_ValidateName(const std::string &name) + { + if (NN_UNLIKELY(name.length() > NN_NO100 || name.length() < NN_NO1)) { + NN_LOG_WARN("Service or Driver name length should be in 1-100"); + return NN_INVALID_PARAM; + } + for (char n : name) { + if (NN_UNLIKELY((!std::isalnum(n)) && (n != '_' && n != '-'))) { + NN_LOG_WARN("Service or Driver name cannot contain illegal characters, only could contain alphabet, " + "number, -, _"); + return NN_INVALID_PARAM; + } + } + return NN_OK; + } +}; + +class MemoryRegionChecker { +public: + MemoryRegionChecker() = default; + explicit MemoryRegionChecker(bool lockWhenOperates) : mLockWhenOperates(lockWhenOperates) {} + + inline NResult Validate(uint64_t key, uintptr_t address, uint64_t size) + { + if (NN_UNLIKELY(size == 0)) { + NN_LOG_ERROR("size is 0"); + return NN_ERROR; + } + + if (NN_UNLIKELY(mLockWhenOperates)) { + pthread_rwlock_rdlock(&mRwlock); + } + if (NN_UNLIKELY(mRangeCache.count(key) == 0)) { + NN_LOG_ERROR("LKey is Wrong " << key); + if (NN_UNLIKELY(mLockWhenOperates)) { + pthread_rwlock_unlock(&mRwlock); + } + return NN_ERROR; + } else { + auto range = mRangeCache[key]; + if (NN_UNLIKELY(mLockWhenOperates)) { + pthread_rwlock_unlock(&mRwlock); + } + if (address >= range.first && address + size <= range.second) { + return NN_OK; + } + NN_LOG_ERROR("Address does not match lKey, size:" << size); + return NN_ERROR; + } + } + + inline bool Contains(uint64_t key) + { + pthread_rwlock_rdlock(&mRwlock); + if (NN_UNLIKELY(mRangeCache.count(key) == 0)) { + pthread_rwlock_unlock(&mRwlock); + return false; + } else { + pthread_rwlock_unlock(&mRwlock); + return true; + } + } + + inline NResult Register(uint64_t key, uintptr_t address, uint64_t size) + { + pthread_rwlock_wrlock(&mRwlock); + if (NN_UNLIKELY(mRangeCache.count(key) > 0)) { + pthread_rwlock_unlock(&mRwlock); + return NN_ERROR; + } + mRangeCache[key] = {address, address + size}; + pthread_rwlock_unlock(&mRwlock); + return NN_OK; + } + + inline void UnRegister(uint64_t key) + { + pthread_rwlock_wrlock(&mRwlock); + mRangeCache.erase(key); + pthread_rwlock_unlock(&mRwlock); + } + + inline void SetLockWhenOperates(bool shouldLock) + { + mLockWhenOperates = shouldLock; + } + + inline void Reserve(uint32_t size) + { + mRangeCache.reserve(size); + } + +private: + std::unordered_map> mRangeCache; + ::pthread_rwlock_t mRwlock {}; + bool mLockWhenOperates = false; +}; + +inline NResult FilterIp(const std::string &ipMask, std::vector &outIps) +{ + in_addr_t mask = 0; + in_addr_t inputIpByMask = 0; + if (!NetFunc::NN_CovertIpMask(ipMask, inputIpByMask, mask)) { + NN_LOG_ERROR("Ip mask is invalid " << ipMask << + ", should be something like '192.168.2.1/24', 24 means the left 24 bits will be " + "the condition to compare"); + return NN_ERROR; + } + + struct ifaddrs *addresses = nullptr; + if (getifaddrs(&addresses) != 0) { + NN_LOG_ERROR("Failed to get interface addresses"); + return NN_ERROR; + } + + struct ifaddrs *iter = addresses; + while (iter != nullptr) { + if (iter->ifa_addr == nullptr || + iter->ifa_addr->sa_family != AF_INET || + ((reinterpret_cast(iter->ifa_addr))->sin_addr.s_addr & mask) != inputIpByMask) { + iter = iter->ifa_next; + continue; + } + + char ipStr[INET_ADDRSTRLEN] = {0}; + inet_ntop(AF_INET, &((reinterpret_cast(iter->ifa_addr))->sin_addr), ipStr, + INET_ADDRSTRLEN); + outIps.emplace_back(ipStr); + + iter = iter->ifa_next; + } + freeifaddrs(addresses); + return NN_OK; +} + +inline bool ValidateArrayOptions(const char *src, uint32_t srcLen) +{ + if (NN_UNLIKELY(src == nullptr) || NN_UNLIKELY(srcLen <= NN_NO0)) { + return false; + } + for (uint32_t i = 0; i < srcLen; ++i) { + if (src[i] == '\0') { + return true; + } + } + NN_LOG_ERROR("The array length is too long, it must less or equal to " << srcLen); + return false; +} +} +} + +#endif diff --git a/src/common/net_crc32.cpp b/src/common/net_crc32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..997e55b4b16ade8d2c9ccff8efc5536ab11d1a3c --- /dev/null +++ b/src/common/net_crc32.cpp @@ -0,0 +1,293 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "net_crc32.h" + +namespace ock { +namespace hcom { +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +#ifdef USE_HARDWARE_CRC +#ifdef __aarch64__ +static inline uint32_t CalcCrc64Bit(const uint8_t **buffer, uint32_t len, uint32_t crc) +{ + auto tmpBuf = reinterpret_cast(*buffer); + auto crc64 = (uint64_t)crc; + + for (uint32_t index = 0; index < len; index++) { + __asm__ volatile("crc32cx %w[c], %w[c], %x[v]" : [ c ] "+r"(crc64) : [ v ] "r"(*tmpBuf++)); + } + + *buffer = reinterpret_cast(tmpBuf); + return (uint32_t)crc64; +} + +static inline uint32_t CalcCrc32Bit(const uint8_t **buffer, uint32_t len, uint32_t crc) +{ + auto tmpBuf = reinterpret_cast(*buffer); + for (uint32_t index = 0; index < len; index++) { + __asm__ volatile("crc32cw %w[c], %w[c], %w[v]" : [ c ] "+r"(crc) : [ v ] "r"(*tmpBuf++)); + } + + *buffer = reinterpret_cast(tmpBuf); + return crc; +} + +static inline uint32_t CalcCrc16Bit(const uint8_t **buffer, uint32_t len, uint32_t crc) +{ + auto tmpBuf = reinterpret_cast(*buffer); + for (uint32_t index = 0; index < len; index++) { + __asm__ volatile("crc32ch %w[c], %w[c], %w[v]" : [ c ] "+r"(crc) : [ v ] "r"(*tmpBuf++)); + } + + *buffer = reinterpret_cast(tmpBuf); + return crc; +} + +static inline uint32_t CalcCrc8Bit(const uint8_t **buffer, uint32_t len, uint32_t crc) +{ + auto tmpBuf = *buffer; + + for (uint32_t index = 0; index < len; index++) { + __asm__ volatile("crc32cb %w[c], %w[c], %w[v]" : [ c ] "+r"(crc) : [ v ] "r"(*tmpBuf++)); + } + + *buffer = tmpBuf; + return crc; +} + +uint32_t NetCrc32C(const void *buffer, uint32_t length) +{ + uint32_t crc = 0xffffffff; + uint32_t len = length; + + if (NN_UNLIKELY(buffer == nullptr || len == 0)) { + return 0; + } + + auto tmpBuf = reinterpret_cast(buffer); + + // 首地址可能不对齐,考虑到 UBSHcomNetTransHeader 头部较小,计算时仅需 24 字节,所 + // 以采用 crc32c 4B 硬件指令而非 8B 以带来较好的 throughput、较低的 latency. + constexpr uint32_t STEP = 4; + const uint32_t toAlign = static_cast(-(uintptr_t)tmpBuf & (STEP - 1)); + switch (toAlign) { + case NN_NO3: + __asm__ volatile("crc32cb %w[c], %w[c], %w[v]" : [ c ] "+r"(crc) : [ v ] "r"(*tmpBuf++)); + // fallthrough + case NN_NO2: + __asm__ volatile("crc32cb %w[c], %w[c], %w[v]" : [ c ] "+r"(crc) : [ v ] "r"(*tmpBuf++)); + // fallthrough + case NN_NO1: + __asm__ volatile("crc32cb %w[c], %w[c], %w[v]" : [ c ] "+r"(crc) : [ v ] "r"(*tmpBuf++)); + // fallthrough + case NN_NO0: + break; + } + + length -= toAlign; + for (; length >= STEP; length -= STEP, tmpBuf += STEP) { + __asm__ volatile("crc32cw %w[c], %w[c], %w[v]" : [ c ] "+r"(crc) : [ v ] "r"(*(uint32_t*)tmpBuf)); + } + + for (; length; --length, ++tmpBuf) { + __asm__ volatile("crc32cb %w[c], %w[c], %w[v]" : [ c ] "+r"(crc) : [ v ] "r"(*tmpBuf)); + } + return ~crc; +} + +#else /* for x86 */ + +#if defined(__x86_64__) /* for x86_64 */ +static inline uint32_t CalcCrc64Bit(const uint8_t **buffer, uint32_t len, uint32_t crc) +{ + const uint8_t *tmpBuf = *buffer; + auto crc64 = (uint64_t)crc; + + for (uint32_t index = 0; index < len; index++) { + __asm__ volatile("crc32q\t" + "(%1), %0" + : "+r"(crc64) + : "r"(tmpBuf), "m"(*tmpBuf)); + tmpBuf += sizeof(uint64_t); + } + + *buffer = tmpBuf; + return (uint32_t)crc64; +} +#endif + +static inline uint32_t CalcCrc32Bit(const uint8_t **buffer, uint32_t len, uint32_t crc) +{ + const uint8_t *tmpBuf = *buffer; + + for (uint32_t index = 0; index < len; index++) { + __asm__ volatile("crc32l\t" + "(%1), %0" + : "+r"(crc) + : "r"(tmpBuf), "m"(*tmpBuf)); + tmpBuf += sizeof(uint32_t); + } + + *buffer = tmpBuf; + return crc; +} + +static inline uint32_t CalcCrc16Bit(const uint8_t **buffer, uint32_t len, uint32_t crc) +{ + const uint8_t *tmpBuf = *buffer; + + for (uint32_t index = 0; index < len; index++) { + __asm__ volatile("crc32w\t" + "(%1), %0" + : "+r"(crc) + : "r"(tmpBuf), "m"(*tmpBuf)); + tmpBuf += sizeof(uint16_t); + } + + *buffer = tmpBuf; + return crc; +} + +static inline uint32_t CalcCrc8Bit(const uint8_t **buffer, uint32_t len, uint32_t crc) +{ + const uint8_t *tmpBuf = *buffer; + + for (uint32_t index = 0; index < len; index++) { + __asm__ volatile("crc32b\t" + "(%1), %0" + : "+r"(crc) + : "r"(tmpBuf), "m"(*tmpBuf)); + tmpBuf += sizeof(uint8_t); + } + + *buffer = tmpBuf; + return crc; +} + +uint32_t NetCrc32C(const void *buffer, uint32_t length) +{ + uint32_t crc = 0xffffffff; + uint32_t len = length; + + if (NN_UNLIKELY(len == 0 || buffer == nullptr)) { + return 0; + } + + auto tmpBuf = reinterpret_cast(buffer); + + /* The performance of the crc32 command is better than that of the __mm_crc32/__built-in _ia32_crc32 command. */ +#if defined(__x86_64__) + crc = CalcCrc64Bit(&tmpBuf, len / sizeof(uint64_t), crc); + len &= sizeof(uint64_t) - 1; +#endif + crc = CalcCrc32Bit(&tmpBuf, len / sizeof(uint32_t), crc); + len &= sizeof(uint32_t) - 1; + + crc = CalcCrc16Bit(&tmpBuf, len / sizeof(uint16_t), crc); + len &= sizeof(uint16_t) - 1; + + crc = CalcCrc8Bit(&tmpBuf, len / sizeof(uint8_t), crc); + + return ~crc; +} +#endif + +#else // USE_SOFTWARE_CRC +// CRC-32C Poly:1EDC6F41 +static uint32_t gCrcTable[NN_NO256] = { + 0x00000000, 0xF26B8303, 0xE13B70F7, 0x1350F3F4, + 0xC79A971F, 0x35F1141C, 0x26A1E7E8, 0xD4CA64EB, + 0x8AD958CF, 0x78B2DBCC, 0x6BE22838, 0x9989AB3B, + 0x4D43CFD0, 0xBF284CD3, 0xAC78BF27, 0x5E133C24, + 0x105EC76F, 0xE235446C, 0xF165B798, 0x030E349B, + 0xD7C45070, 0x25AFD373, 0x36FF2087, 0xC494A384, + 0x9A879FA0, 0x68EC1CA3, 0x7BBCEF57, 0x89D76C54, + 0x5D1D08BF, 0xAF768BBC, 0xBC267848, 0x4E4DFB4B, + 0x20BD8EDE, 0xD2D60DDD, 0xC186FE29, 0x33ED7D2A, + 0xE72719C1, 0x154C9AC2, 0x061C6936, 0xF477EA35, + 0xAA64D611, 0x580F5512, 0x4B5FA6E6, 0xB93425E5, + 0x6DFE410E, 0x9F95C20D, 0x8CC531F9, 0x7EAEB2FA, + 0x30E349B1, 0xC288CAB2, 0xD1D83946, 0x23B3BA45, + 0xF779DEAE, 0x05125DAD, 0x1642AE59, 0xE4292D5A, + 0xBA3A117E, 0x4851927D, 0x5B016189, 0xA96AE28A, + 0x7DA08661, 0x8FCB0562, 0x9C9BF696, 0x6EF07595, + 0x417B1DBC, 0xB3109EBF, 0xA0406D4B, 0x522BEE48, + 0x86E18AA3, 0x748A09A0, 0x67DAFA54, 0x95B17957, + 0xCBA24573, 0x39C9C670, 0x2A993584, 0xD8F2B687, + 0x0C38D26C, 0xFE53516F, 0xED03A29B, 0x1F682198, + 0x5125DAD3, 0xA34E59D0, 0xB01EAA24, 0x42752927, + 0x96BF4DCC, 0x64D4CECF, 0x77843D3B, 0x85EFBE38, + 0xDBFC821C, 0x2997011F, 0x3AC7F2EB, 0xC8AC71E8, + 0x1C661503, 0xEE0D9600, 0xFD5D65F4, 0x0F36E6F7, + 0x61C69362, 0x93AD1061, 0x80FDE395, 0x72966096, + 0xA65C047D, 0x5437877E, 0x4767748A, 0xB50CF789, + 0xEB1FCBAD, 0x197448AE, 0x0A24BB5A, 0xF84F3859, + 0x2C855CB2, 0xDEEEDFB1, 0xCDBE2C45, 0x3FD5AF46, + 0x7198540D, 0x83F3D70E, 0x90A324FA, 0x62C8A7F9, + 0xB602C312, 0x44694011, 0x5739B3E5, 0xA55230E6, + 0xFB410CC2, 0x092A8FC1, 0x1A7A7C35, 0xE811FF36, + 0x3CDB9BDD, 0xCEB018DE, 0xDDE0EB2A, 0x2F8B6829, + 0x82F63B78, 0x709DB87B, 0x63CD4B8F, 0x91A6C88C, + 0x456CAC67, 0xB7072F64, 0xA457DC90, 0x563C5F93, + 0x082F63B7, 0xFA44E0B4, 0xE9141340, 0x1B7F9043, + 0xCFB5F4A8, 0x3DDE77AB, 0x2E8E845F, 0xDCE5075C, + 0x92A8FC17, 0x60C37F14, 0x73938CE0, 0x81F80FE3, + 0x55326B08, 0xA759E80B, 0xB4091BFF, 0x466298FC, + 0x1871A4D8, 0xEA1A27DB, 0xF94AD42F, 0x0B21572C, + 0xDFEB33C7, 0x2D80B0C4, 0x3ED04330, 0xCCBBC033, + 0xA24BB5A6, 0x502036A5, 0x4370C551, 0xB11B4652, + 0x65D122B9, 0x97BAA1BA, 0x84EA524E, 0x7681D14D, + 0x2892ED69, 0xDAF96E6A, 0xC9A99D9E, 0x3BC21E9D, + 0xEF087A76, 0x1D63F975, 0x0E330A81, 0xFC588982, + 0xB21572C9, 0x407EF1CA, 0x532E023E, 0xA145813D, + 0x758FE5D6, 0x87E466D5, 0x94B49521, 0x66DF1622, + 0x38CC2A06, 0xCAA7A905, 0xD9F75AF1, 0x2B9CD9F2, + 0xFF56BD19, 0x0D3D3E1A, 0x1E6DCDEE, 0xEC064EED, + 0xC38D26C4, 0x31E6A5C7, 0x22B65633, 0xD0DDD530, + 0x0417B1DB, 0xF67C32D8, 0xE52CC12C, 0x1747422F, + 0x49547E0B, 0xBB3FFD08, 0xA86F0EFC, 0x5A048DFF, + 0x8ECEE914, 0x7CA56A17, 0x6FF599E3, 0x9D9E1AE0, + 0xD3D3E1AB, 0x21B862A8, 0x32E8915C, 0xC083125F, + 0x144976B4, 0xE622F5B7, 0xF5720643, 0x07198540, + 0x590AB964, 0xAB613A67, 0xB831C993, 0x4A5A4A90, + 0x9E902E7B, 0x6CFBAD78, 0x7FAB5E8C, 0x8DC0DD8F, + 0xE330A81A, 0x115B2B19, 0x020BD8ED, 0xF0605BEE, + 0x24AA3F05, 0xD6C1BC06, 0xC5914FF2, 0x37FACCF1, + 0x69E9F0D5, 0x9B8273D6, 0x88D28022, 0x7AB90321, + 0xAE7367CA, 0x5C18E4C9, 0x4F48173D, 0xBD23943E, + 0xF36E6F75, 0x0105EC76, 0x12551F82, 0xE03E9C81, + 0x34F4F86A, 0xC69F7B69, 0xD5CF889D, 0x27A40B9E, + 0x79B737BA, 0x8BDCB4B9, 0x988C474D, 0x6AE7C44E, + 0xBE2DA0A5, 0x4C4623A6, 0x5F16D052, 0xAD7D5351L +}; + +uint32_t NetCrc32C(const void *buffer, uint32_t length) +{ + auto p = reinterpret_cast(buffer); + uint32_t crc = 0xFFFFFFFF; + for (uint32_t i = 0; i < length; i++) { + crc = ((crc >> NN_NO8) ^ gCrcTable[((*p) ^ crc) & 0xFF]); + p++; + } + + return ~crc; +} +#endif + +#ifdef __cplusplus +} +#endif /* __cplusplus */ +Crc32Function NetCrc32::gCrc32Func = NetCrc32C; +} +} diff --git a/src/common/net_crc32.h b/src/common/net_crc32.h new file mode 100644 index 0000000000000000000000000000000000000000..435635cb99abfd3ed0b2a41a251cda76de655617 --- /dev/null +++ b/src/common/net_crc32.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_CRC32_H +#define HCOM_CRC32_H + +#include "hcom_def.h" +#include "hcom_log.h" + +namespace ock { +namespace hcom { +/* + * @brief brief calculate crc32 + * + * @param buffer [in] which is to be calculated. + * @param length [in] calculate buff length. + * + * @return crc32 value. + * + */ +using Crc32Function = uint32_t (*)(const void *buffer, uint32_t length); + +class NetCrc32 { +public: + /* + * @brief brief calculate crc32 + * + * @param buffer [in] which is to be calculated. + * @param length [in] calculate buff length. + * + * @return crc32 value. + * + */ + static inline uint32_t CalcCrc32(const void *buffer, uint32_t length) + { + return gCrc32Func(buffer, length); + } + +private: + static Crc32Function gCrc32Func; +}; +} +} +#endif // HCOM_CRC32_H \ No newline at end of file diff --git a/src/common/net_execution_service.cpp b/src/common/net_execution_service.cpp new file mode 100644 index 0000000000000000000000000000000000000000..054e158a2d54bb38df3e190a6fe3df8ae0739c0a --- /dev/null +++ b/src/common/net_execution_service.cpp @@ -0,0 +1,147 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include + +#include "net_execution_service.h" + +namespace ock { +namespace hcom { +bool NetExecutorService::Start() +{ + if (mStarted) { + return true; + } + + /* init ring buffer blocking queue */ + auto result = mRunnableQueue.Initialize(); + if (result != 0) { + NN_LOG_ERROR("Failed to initialize queue, result " << result); + return false; + } + + for (uint16_t i = 0; i < mThreadNum; i++) { + auto cpuId = mCpuSetStartIdx < 0 ? -1 : mCpuSetStartIdx + i; + auto *thr = new (std::nothrow) std::thread(&NetExecutorService::RunInThread, this, cpuId); + if (thr == nullptr) { + ForceStop(); + NN_LOG_ERROR("Failed to create executor thread " << i); + return false; + } + + mThreads.push_back(thr); + } + + while (mStartedThreadNum < mThreadNum) { + usleep(1); + } + + mStarted = true; + mStopped = false; + return true; +} + +void NetExecutorService::ForceStop() +{ + for (uint32_t i = 0; i < mThreads.size(); ++i) { + NetRunnablePtr stopTask = new (std::nothrow) NetRunnable(); + if (stopTask == nullptr) { + NN_LOG_ERROR("Failed to new stop task, probably out of memory"); + break; + } + stopTask->Type(NetRunnableType::STOP); + + NetRunnable *tmp = stopTask.Get(); + tmp->IncreaseRef(); + if (!mRunnableQueue.EnqueueFirst(tmp)) { + continue; + } + } + + for (auto &thr : mThreads) { + if (thr != nullptr) { + thr->join(); + } + } + + mRunnableQueue.UnInitialize(); + + while (!mThreads.empty()) { + delete (mThreads.back()); + mThreads.pop_back(); + } + mStopped = true; + mStarted = false; +} + +void NetExecutorService::Stop() +{ + if (!mStarted || mStopped) { + return; + } + ForceStop(); +} + +void NetExecutorService::DoRunnable(bool &flag) +{ + try { + NetRunnable *task = nullptr; + mRunnableQueue.Dequeue(task); + if (task != nullptr) { + /* the ref count of `task` was manually increased when enqueue, and it will be automatically increased again + when assignning to `runnable`, so it should be decreased explicitly after assignment to make the + ref count = 1. */ + NetRunnablePtr runnable = task; + task->DecreaseRef(); + if (runnable->Type() == NetRunnableType::NORMAL) { + runnable->Run(); + } else if (runnable->Type() == NetRunnableType::STOP) { + flag = false; + } else { + NN_LOG_ERROR("Un-reachable path"); + } + } else { + NN_LOG_ERROR("Task is null"); + } + } catch (std::runtime_error &ex) { + NN_LOG_ERROR("Caught error " << ex.what() << " when execute a task, continue"); + } catch (...) { + NN_LOG_ERROR("Caught unknown error when execute a task, continue"); + } +} + +void NetExecutorService::RunInThread(int16_t cpuId) +{ + bool runFlag = true; + uint16_t threadIndex = mStartedThreadNum++; + + auto threadName = mThreadName.empty() ? "executor" : mThreadName; + threadName += std::to_string(threadIndex); + if (cpuId != -1) { + cpu_set_t cpuSet; + CPU_ZERO(&cpuSet); + CPU_SET(cpuId, &cpuSet); + if (pthread_setaffinity_np(pthread_self(), sizeof(cpuSet), &cpuSet) != 0) { + NN_LOG_WARN("Invalid to bind executor thread" << threadName << " << to cpu " << cpuId); + } + } + + pthread_setname_np(pthread_self(), threadName.c_str()); + NN_LOG_INFO("Thread is started for executor service <" << threadName << "> cpuId " << cpuId); + + while (runFlag) { + DoRunnable(runFlag); + } + + NN_LOG_INFO("Thread for executor service <" << threadName << "> cpuId " << cpuId << " exiting"); +} +} +} diff --git a/src/common/net_execution_service.h b/src/common/net_execution_service.h new file mode 100644 index 0000000000000000000000000000000000000000..b42bcd0b80b4cf1bf53bba1f7e89df0c0633163e --- /dev/null +++ b/src/common/net_execution_service.h @@ -0,0 +1,181 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_NET_EXECUTION_SERVICE_H +#define HCOM_NET_EXECUTION_SERVICE_H + +#include + +#include "hcom.h" + +namespace ock { +namespace hcom { +enum NetRunnableType { + NORMAL = 0, + STOP = 1, +}; + +/* + * @brief Base class of runnable task + */ +class NetRunnable { +public: + virtual ~NetRunnable() = default; + + virtual void Run() {} + + DEFINE_RDMA_REF_COUNT_FUNCTIONS +private: + inline void Type(NetRunnableType type) + { + mType = type; + } + + inline NetRunnableType Type() const + { + return mType; + } + +private: + NetRunnableType mType = NetRunnableType::NORMAL; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend class NetExecutorService; +}; +using NetRunnablePtr = NetRef; + +constexpr uint32_t ES_MAX_THR_NUM = 256; + +class NetExecutorService; +using NetExecutorServicePtr = NetRef; + + +/* + * @brief Execution service is fixed thread pool to task execution + */ +class NetExecutorService { +public: + /* + * @brief Create an execution service with fixed number of threads + * + * @param threadNum [in] number of threads + * @param queueCapacity [in] capacity of inner queue to store tasks + * + * @return executor ptr if successfully, otherwise return null + */ + static NetExecutorServicePtr Create(uint16_t threadNum, uint32_t queueCapacity = 10000) + { + if (threadNum > ES_MAX_THR_NUM || threadNum == 0) { + NN_LOG_ERROR("The num of thread must 1-" << ES_MAX_THR_NUM); + return nullptr; + } + + return new (std::nothrow) NetExecutorService(threadNum, queueCapacity); + } + +public: + ~NetExecutorService() + { + if (!mStopped) { + Stop(); + } + } + + /* + * @brief Start the execution service, wait for all threads started + * + * @return true if successfully + */ + bool Start(); + + /* + * @brief Stop the execution service, wait for all threads exited + */ + void Stop(); + + /* + * @brief Enqueue a task to thread pool, need to ensure this has been started + * + * The ref count of runnable will be increased and will be decreased after executed + * + * @return true if enqueue successfully, otherwise the queue is full + */ + inline bool Execute(const NetRunnablePtr &runnable) + { + auto tmp = runnable.Get(); + if (NN_UNLIKELY(tmp == nullptr)) { + return false; + } + + tmp->IncreaseRef(); + return mRunnableQueue.Enqueue(tmp); + } + + /* + * @brief Set the thread name prefix + * + * @param name [in] prefix name of execute service working thread + */ + inline void SetThreadName(const std::string &name) + { + mThreadName = name; + } + + /* + * @brief Bind the cpu for working threads + * + * @param idx [in] starting index cpu id to bind to working threads + */ + inline void SetCpuSetStartIndex(int16_t idx) + { + mCpuSetStartIdx = idx; + } + + inline bool IsStart() + { + return mStarted.load(); + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + NetExecutorService(uint16_t threadNum, uint32_t queueCapacity) + : mRunnableQueue(queueCapacity), + mThreadNum(threadNum), + mThreads(0), + mStarted(false), + mStopped(false), + mStartedThreadNum(0) + {} + + void RunInThread(int16_t cpuId); + void DoRunnable(bool &flag); + void ForceStop(); + +private: + NetBlockingQueue mRunnableQueue; + uint16_t mThreadNum = 0; + int16_t mCpuSetStartIdx = -1; + std::vector mThreads; + + std::atomic mStarted; + std::atomic mStopped; + std::atomic mStartedThreadNum; + + std::string mThreadName; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; +} +} + +#endif // HCOM_NET_EXECUTION_SERVICE_H diff --git a/src/common/net_linked_list.h b/src/common/net_linked_list.h new file mode 100644 index 0000000000000000000000000000000000000000..1648661765d2745b74db900ef6a260678e694477 --- /dev/null +++ b/src/common/net_linked_list.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_LINKEDLIST_H +#define HCOM_LINKEDLIST_H + +namespace ock { +namespace hcom { +template struct NetLinkedListNode { + T data; + NetLinkedListNode *prev; + NetLinkedListNode *next; + + NetLinkedListNode() + { + data = T(); + ReLinkSelf(); + } + + void ReLinkSelf() + { + prev = this; + next = this; + } + + /* + * @brief Insert node between prev and next + */ + void InsertBetween(NetLinkedListNode *prevNode, NetLinkedListNode *nextNode) + { + if (NN_UNLIKELY(prevNode == nullptr || nextNode == nullptr)) { + NN_LOG_ERROR("Invalid prevNode or nextNode"); + return; + } + nextNode->prev = this; + this->next = nextNode; + this->prev = prevNode; + prevNode->next = this; + } + + /* + * @brief Remove self from linked list + */ + void RemoveSelf() + { + if (next != nullptr) { + next->prev = prev; + } + if (prev != nullptr) { + prev->next = next; + } + } +}; + +template struct NetLinkedList { + NetLinkedListNode head; + + NetLinkedList() = default; + + /* + * @brief Link node to list's tail + */ + void Append(NetLinkedListNode *node) + { + if (NN_UNLIKELY(node == nullptr)) { + NN_LOG_ERROR("Invalid node"); + return; + } + node->InsertBetween(head.prev, &head); + } + + bool IsEmpty() + { + return head.next == &head; + } +}; +} +} +#endif // HCOM_LINKEDLIST_H diff --git a/src/common/net_mem_allocator.cpp b/src/common/net_mem_allocator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bc0289ffb979d3d4bba775a2d16628192d78c149 --- /dev/null +++ b/src/common/net_mem_allocator.cpp @@ -0,0 +1,540 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "net_common.h" +#include "net_mem_allocator.h" + +#define MEM_ALLOCATOR_ATOMIC_INC(x) __sync_add_and_fetch((x), 1) +#define MEM_ALLOCATOR_ATOMIC_DEC(x) __sync_sub_and_fetch((x), 1) + +/* MetaInfo layout || indicates memory node separate, | indicates information + * part separate + * ---- || -----|#c_a(4 bytes canary)|4096(4 bytes length of next memory block) + * || X(allocated memory)-------- || ------ + */ + +#define CAST_TO_LIST_NODE(any) (reinterpret_cast> *>(any)) +#define CAST_TO_LIST(any) (reinterpret_cast> *>(any)) +#define CAST_TO_RBNODE(any) (reinterpret_cast *>(any)) +#define NODE_SIZE sizeof(NetLinkedList>) + +namespace ock { +namespace hcom { +const uint64_t MEM_ALLOCATOR_BASE_SHIFT = NN_NO12; /* 1 << 12 == 4096 */ + +/* + * @brief init member MemoryRegion + */ +void NetMemAllocator::MemoryRegionInitial() +{ + mMemRegionMgr.mRoot.ref = nullptr; + + for (auto &index : mMemRegionMgr.freeCnt) { + index = 0LL; + } + + mMemRegionMgr.totalSize = 0UL; + mMemRegionMgr.freeSize = 0UL; +} + +#define VALIDATE_MEM_AREA(ma) \ + if (NN_UNLIKELY((ma) == nullptr)) { \ + NN_LOG_ERROR("Invalid param, ma must be correct address"); \ + return; \ + } + +#define VALIDATE_NEW_MEM_AREA(newMa) \ + if (NN_UNLIKELY((newMa) == nullptr)) { \ + NN_LOG_ERROR("Invalid param, newMa must be correct address"); \ + return; \ + } + +#define VALIDATE_MEDIA_DESC(mediaDesc) \ + if (NN_UNLIKELY((mediaDesc) == nullptr)) { \ + NN_LOG_ERROR("Invalid param, mediaDesc must be correct address"); \ + return NN_INVALID_PARAM; \ + } + +#define VALIDATE_ADDRESS(address) \ + if (NN_UNLIKELY((address) == nullptr)) { \ + NN_LOG_ERROR("Invalid param, address must be correct address"); \ + return NN_INVALID_PARAM; \ + } +/* + * @brief when recycled memory is adjacent behind some memories block, we should + * merge the two block to larger block for potential large block demand in the + * future, as we said, self-scaling + * + * @param newMa recycled memory block, which should be merged + * @param ma memory block adjacent ahead newMa, which should be merged + * @param newNode + */ +void MemoryRegion::MemoryAreaInsertPre(NetRbNode *newMa, NetRbNode *ma) +{ + VALIDATE_NEW_MEM_AREA(newMa) + VALIDATE_MEM_AREA(ma) + + auto root = &mRoot; + NetRbNode *neighbNode = nullptr; + MemoryAreaRawPtr neighMa; + uint32_t index = 0; + newMa->data.endAddress = ma->data.endAddress; + newMa->data.length += ma->data.length; + neighbNode = ma->Prev(); + /* + * check whether the memory block to be extended forward has a prev node, + * which may also adjacent to recycled memory, as its end address is recycled + * start address, so, the merge process change from new + ma to ma.prev + + * new + ma + */ + if (neighbNode != nullptr) { + neighMa = &neighbNode->data; + if (neighMa->endAddress == newMa->data.startAddress) { + neighMa->endAddress = newMa->data.endAddress; + neighMa->length += newMa->data.length; + CAST_TO_LIST_NODE(&ma->data)->RemoveSelf(); + MEM_ALLOCATOR_ATOMIC_DEC(&freeCnt[ma->data.index]); + CAST_TO_LIST_NODE(neighMa)->RemoveSelf(); + MEM_ALLOCATOR_ATOMIC_DEC(&freeCnt[neighMa->index]); + index = neighMa->length >> MEM_ALLOCATOR_BASE_SHIFT; + index = (index >= FREE_LIST_NUM) ? (FREE_LIST_NUM - 1) : (index - 1); + neighMa->index = index; + CAST_TO_LIST(&freeHead[index])->Append(CAST_TO_LIST_NODE(neighMa)); + MEM_ALLOCATOR_ATOMIC_INC(&freeCnt[index]); + root->Erase(ma); + return; + } + } + + CAST_TO_LIST_NODE(&ma->data)->RemoveSelf(); + MEM_ALLOCATOR_ATOMIC_DEC(&freeCnt[ma->data.index]); + index = newMa->data.length >> MEM_ALLOCATOR_BASE_SHIFT; + index = (index >= FREE_LIST_NUM) ? (FREE_LIST_NUM - 1) : (index - 1); + newMa->data.index = index; + CAST_TO_LIST(&freeHead[index])->Append(CAST_TO_LIST_NODE(newMa)); + MEM_ALLOCATOR_ATOMIC_INC(&freeCnt[index]); + root->Replace(ma, newMa); +} + +/* + * @brief same as MemoryAreaInsertPre + */ +void MemoryRegion::MemoryAreaInsertNext(NetRbNode *newMa, NetRbNode *ma) +{ + VALIDATE_NEW_MEM_AREA(newMa) + VALIDATE_MEM_AREA(ma) + + auto root = &mRoot; + NetRbNode *neighbNode; + MemoryAreaRawPtr neighMa = nullptr; + uint32_t index = 0; + ma->data.endAddress = newMa->data.endAddress; + ma->data.length += newMa->data.length; + neighbNode = ma->Next(); + if (neighbNode != nullptr) { + neighMa = &neighbNode->data; + if (neighMa->startAddress == ma->data.endAddress) { + ma->data.endAddress = neighMa->endAddress; + ma->data.length += neighMa->length; + CAST_TO_LIST_NODE(ma)->RemoveSelf(); + MEM_ALLOCATOR_ATOMIC_DEC(&freeCnt[ma->data.index]); + CAST_TO_LIST_NODE(neighMa)->RemoveSelf(); + MEM_ALLOCATOR_ATOMIC_DEC(&freeCnt[neighMa->index]); + index = ma->data.length >> MEM_ALLOCATOR_BASE_SHIFT; + index = (index >= FREE_LIST_NUM) ? (FREE_LIST_NUM - 1) : (index - 1); + ma->data.index = index; + CAST_TO_LIST(&freeHead[index])->Append(CAST_TO_LIST_NODE(ma)); + MEM_ALLOCATOR_ATOMIC_INC(&freeCnt[index]); + root->Erase(neighbNode); + return; + } + } + + CAST_TO_LIST_NODE(ma)->RemoveSelf(); + + MEM_ALLOCATOR_ATOMIC_DEC(&freeCnt[ma->data.index]); + index = ma->data.length >> MEM_ALLOCATOR_BASE_SHIFT; + index = (index >= FREE_LIST_NUM) ? (FREE_LIST_NUM - 1) : (index - 1); + ma->data.index = index; + CAST_TO_LIST(&freeHead[index])->Append(CAST_TO_LIST_NODE(ma)); + MEM_ALLOCATOR_ATOMIC_INC(&freeCnt[index]); +} + +/* + * @brief expand pool, just a wrapper of MemoryAreaInsert with some initial work + * now it's private,only invoked by MemoryRegionInit, may be opened later for + * dynamic expanding feature + */ +NResult NetMemAllocator::MemoryRegionJoin(MediaDescribe *mediaDesc) +{ + VALIDATE_MEDIA_DESC(mediaDesc) + + int32_t ret = NN_OK; + + { + std::lock_guard lock(mMemRegionLock); + MemoryRegionInitial(); + + uint64_t dataLength = mediaDesc->endAddress - mediaDesc->startAddress; + mMemRegionMgr.totalSize += dataLength; + ret = mMemRegionMgr.MemoryAreaInsert(mediaDesc->startAddress, dataLength); + } + + if (ret != NN_OK) { + NN_LOG_ERROR("Region join failed, new address is not overlapped with the existed"); + return ret; + } + NN_LOG_INFO("Region join succeed"); + return NN_OK; +} + +NResult NetMemAllocator::MemoryRegionInit(MediaDescribe &media) +{ + NResult hr = NN_OK; + hr = MemoryRegionJoin(&media); + if (hr != NN_OK) { + NN_LOG_ERROR("Memory region init failed, ret " << hr); + return hr; + } + + NN_LOG_INFO("Memory region init succeed."); + return NN_OK; +} + +/* + * @brief take memory from pool + * @param length rounded request memory length + * @param deltaLength rounded length - request length + * + * "#c_a" no reserve canary,aligned size minus applied size enough to store + * "#c_r" reserve canary + */ +NResult NetMemAllocator::RegionMalloc(uint64_t &startAddress, uint64_t length, uint64_t deltaLength) +{ + MemoryRegionRawPtr memoryRegion = &mMemRegionMgr; + startAddress = 0; + int32_t ret = NN_OK; + uint64_t lengthWithMeta = length + MA_META_DATA_RESERVE_LEN; + auto needReserve = deltaLength < MA_META_DATA_RESERVE_LEN; + + { + std::lock_guard lock(mMemRegionLock); + + if (!needReserve || NN_UNLIKELY(length == mMemRegionMgr.freeSize && length == mMemRegionMgr.totalSize)) { + lengthWithMeta -= MA_META_DATA_RESERVE_LEN; + } + + ret = memoryRegion->MemoryAreaRemove(&startAddress, lengthWithMeta, mMinBlockSize); + if (NN_UNLIKELY(ret != NN_OK)) { + NN_LOG_WARN("Areas scan invalid, length " << lengthWithMeta << " remain " << memoryRegion->freeSize); + return NN_ERROR; + } + + if (startAddress == mMRAddress) { + mFirstAllocLength = lengthWithMeta; + mFirstReqLength = length - deltaLength; + } else { + if (startAddress < MA_META_DATA_RESERVE_LEN) { + NN_LOG_WARN("startAddress don't have enough space for meta data"); + return NN_ERROR; + } + auto metaBaseAddr = startAddress - MA_META_DATA_RESERVE_LEN; + auto metaCanaryAddress = reinterpret_cast(metaBaseAddr); + auto metaLenAddress = reinterpret_cast(metaBaseAddr + MA_CANARY_LEN); + auto metaReqLenAddress = reinterpret_cast(metaBaseAddr + MA_CANARY_LEN + MA_LEN_LEN); + + if (NN_UNLIKELY(memcpy_s(metaCanaryAddress, MA_CANARY_LEN, needReserve ? "#c_r" : "#c_a", MA_CANARY_LEN) != + NN_OK)) { + NN_LOG_WARN("Invalid operation to memcpy_s in RegionMalloc"); + return NN_ERROR; + } + *metaLenAddress = length / mMinBlockSize; + *metaReqLenAddress = length - deltaLength; + } + } + return NN_OK; +} + +NResult NetMemAllocator::RegionMallocWithMap(uint64_t &startAddress, uint64_t length) +{ + MemoryRegionRawPtr memoryRegion = &mMemRegionMgr; + startAddress = 0; + int32_t ret = NN_OK; + { + std::lock_guard lock(mMemRegionLock); +#ifdef ALLOCATOR_PROTECTION_ENABLED + if (NN_UNLIKELY(!CheckNodes())) { + NN_LOG_ERROR("Allocator corrupted, Allocate failed"); + return NN_ERROR; + } + UnProtectAllMem(); +#endif + ret = memoryRegion->MemoryAreaRemove(&startAddress, length, mMinBlockSize, false); + if (NN_UNLIKELY(ret != NN_OK)) { + NN_LOG_WARN("Areas scan invalid, length " << length << " remain " << memoryRegion->freeSize); +#ifdef ALLOCATOR_PROTECTION_ENABLED + ProtectFreeMem(); +#endif + return NN_ERROR; + } + + mAddrLenMap[startAddress] = length; +#ifdef ALLOCATOR_PROTECTION_ENABLED + ProtectFreeMem(); +#endif + } + return NN_OK; +} + +NResult NetMemAllocator::RegionFreeWithMap(uint64_t startAddress) +{ + std::lock_guard lock(mMemRegionLock); +#ifdef ALLOCATOR_PROTECTION_ENABLED + if (NN_UNLIKELY(!CheckNodes())) { + NN_LOG_ERROR("Allocator corrupted, Allocate failed"); + return NN_ERROR; + } + UnProtectAllMem(); +#endif + if (mAddrLenMap.count(startAddress) == 0) { + NN_LOG_ERROR("Areas scan failed, address not malloc!"); +#ifdef ALLOCATOR_PROTECTION_ENABLED + ProtectFreeMem(); +#endif + return NN_ERROR; + } + auto length = mAddrLenMap[startAddress]; + + auto ret = mMemRegionMgr.MemoryAreaInsert(startAddress, length); + if (ret == NN_OK) { + mAddrLenMap.erase(startAddress); + } +#ifdef ALLOCATOR_PROTECTION_ENABLED + ProtectFreeMem(); +#endif + + return ret; +} + +NResult NetMemAllocator::RegionFree(uint64_t startAddress) +{ + uint64_t length = 0; + int32_t ret = NN_OK; + bool freeingFirstBlock = false; + + { + std::lock_guard lock(mMemRegionLock); + + if (startAddress == mMRAddress) { + if (NN_UNLIKELY(mFirstAllocLength == 0)) { + NN_LOG_ERROR("Address Invalid"); + return NN_ERROR; + } + length = mFirstAllocLength; + freeingFirstBlock = true; + } else { + if (startAddress < MA_META_DATA_RESERVE_LEN) { + NN_LOG_WARN("startAddress don't have enough space for meta data"); + return NN_ERROR; + } + auto metaCanaryAddress = reinterpret_cast(startAddress - MA_META_DATA_RESERVE_LEN); + auto metaLenAddress = reinterpret_cast(startAddress - MA_META_DATA_RESERVE_LEN + MA_CANARY_LEN); + + if (memcmp(metaCanaryAddress, "#c_r", MA_CANARY_LEN) == 0) { + length = mMinBlockSize * (*metaLenAddress) + MA_META_DATA_RESERVE_LEN; + } else if (memcmp(metaCanaryAddress, "#c_a", MA_CANARY_LEN) == 0) { + length = mMinBlockSize * (*metaLenAddress); + } else { + NN_LOG_ERROR("Address Invalid"); + return NN_ERROR; + } + } + + ret = mMemRegionMgr.MemoryAreaInsert(startAddress, length); + if (ret == NN_OK && freeingFirstBlock) { + mFirstAllocLength = 0; + mFirstReqLength = 0; + } + } + + if (NN_UNLIKELY(ret != NN_OK)) { + NN_LOG_ERROR("Areas scan failed, length " << length); + return NN_ERROR; + } + + NN_LOG_TRACE_INFO("Mem free success, length " << length); + return NN_OK; +} + +/* + * @brief take operate,traverse freeHead to find available memory block, + * then remove the related node from freeHead lists and red-black tree + */ +NResult MemoryRegion::MemoryAreaRemove(uint64_t *startAddress, uint64_t length, uint32_t minBlockSize, bool metaStored) +{ + VALIDATE_ADDRESS(startAddress) + + auto root = &mRoot; + NetRbNode *ma = nullptr; + NetRbNode *newMa = nullptr; + uint32_t index = 0; + uint32_t areaIndex = 0; + uint32_t nIndex = 0; + + /* + * calculate start index to traverse free memory list's array + */ + index = length >> MEM_ALLOCATOR_BASE_SHIFT; + index = (index >= FREE_LIST_NUM) ? (FREE_LIST_NUM - 1) : (index - 1); + + /* + * traverse memory array until find a memory block has enough + * size, then take all or part from it + */ + for (areaIndex = index; areaIndex < FREE_LIST_NUM; areaIndex++) { + if (freeHead[areaIndex].IsEmpty()) { + continue; + } + NetLinkedListNode> *areaNode = freeHead[areaIndex].head.next; + while (areaNode != &freeHead[areaIndex].head) { + ma = CAST_TO_RBNODE(areaNode); + + /* + * if current checked memory block has enough size for request length, we + * take three condition in consider: + * 1. equal to length, we just take address and remove the node + * 2. greater than length, but not enough for record a new node, we just + * take address and remove the node + * 3. enough for length and new node data, thus we take address and record + * a new node in spare memory + * + */ + bool firstCond = ma->data.length == length; + if (metaStored) { + firstCond = ma->data.length >= length && ma->data.length < length + NODE_SIZE + minBlockSize; + } + if (firstCond) { + *startAddress = ma->data.startAddress; + + CAST_TO_LIST_NODE(ma)->RemoveSelf(); + MEM_ALLOCATOR_ATOMIC_DEC(&freeCnt[ma->data.index]); + + root->Erase(ma); + if (freeSize < length) { + NN_LOG_ERROR("the length " << length << " is bigger than the size remaining " << freeSize); + return NN_ERROR; + } + freeSize -= length; + return NN_OK; + } else if (ma->data.length > length) { + newMa = CAST_TO_RBNODE(ma->data.startAddress + length); + newMa->data.startAddress = ma->data.startAddress + length; + newMa->data.endAddress = ma->data.endAddress; + newMa->data.length = ma->data.length - length; +#ifdef ALLOCATOR_PROTECTION_ENABLED + newMa->data.Initialize(); +#endif + *startAddress = ma->data.startAddress; + + CAST_TO_LIST_NODE(ma)->RemoveSelf(); + MEM_ALLOCATOR_ATOMIC_DEC(&freeCnt[ma->data.index]); + nIndex = newMa->data.length >> MEM_ALLOCATOR_BASE_SHIFT; + nIndex = (nIndex >= FREE_LIST_NUM) ? (FREE_LIST_NUM - 1) : nIndex; + newMa->data.index = nIndex; + + freeHead[nIndex].Append((CAST_TO_LIST_NODE(newMa))); + MEM_ALLOCATOR_ATOMIC_INC(&freeCnt[nIndex]); + + root->Replace(ma, newMa); + if (freeSize < length) { + NN_LOG_ERROR("the length " << length << " is bigger than the size remaining " << freeSize); + return NN_ERROR; + } + freeSize -= length; + return NN_OK; + } else { + areaNode = areaNode->next; + continue; + } + } + } + + return NN_ERROR; +} + +/* + * @brief putback operate,traverse red-black tree to find the correct position + * to insert memory block, then insert the related node to red-black tree and + * freeHead lists + */ +NResult MemoryRegion::MemoryAreaInsert(uint64_t startAddress, uint64_t length) +{ + auto root = &mRoot; + auto newMa = reinterpret_cast(startAddress); + uint32_t index = 0; +#ifdef ALLOCATOR_PROTECTION_ENABLED + newMa->Initialize(); +#endif + newMa->startAddress = startAddress; + newMa->endAddress = startAddress + length; + newMa->length = length; + + CAST_TO_LIST_NODE(newMa)->ReLinkSelf(); + + NetRbNode **newNode = &(root->ref); + NetRbNode *parentNode = nullptr; + + /* + * traverse red-black tree to find the right place to insert returned memory + * block + */ + MemoryAreaRawPtr ma = nullptr; + + while (*newNode) { + ma = &(*newNode)->data; + parentNode = *newNode; + + if (newMa->endAddress == ma->startAddress) { + MemoryAreaInsertPre(CAST_TO_RBNODE(newMa), CAST_TO_RBNODE(ma)); + freeSize += length; + return NN_OK; + } else if (newMa->endAddress < ma->startAddress) { + newNode = &((*newNode)->left); + } else if (newMa->startAddress == ma->endAddress) { + MemoryAreaInsertNext(CAST_TO_RBNODE(newMa), CAST_TO_RBNODE(ma)); + freeSize += length; + return NN_OK; + } else if (newMa->startAddress > ma->endAddress) { + newNode = &((*newNode)->right); + } else { + NN_LOG_ERROR("Areas overlapped failed"); + return NN_ERROR; + } + } + + /* + * free memory block has been inserted to red-black tree, then we sync + * freeHead, which speed up taking operation + */ + index = newMa->length >> MEM_ALLOCATOR_BASE_SHIFT; + index = (index >= FREE_LIST_NUM) ? (FREE_LIST_NUM - 1) : (index - 1); + newMa->index = index; + freeHead[index].Append(CAST_TO_LIST_NODE(newMa)); + MEM_ALLOCATOR_ATOMIC_INC(&freeCnt[index]); + + CAST_TO_RBNODE(newMa)->Link(parentNode, newNode); + root->Insert(CAST_TO_RBNODE(newMa)); + freeSize += length; + + return NN_OK; +} +} // namespace hcom +} // namespace ock diff --git a/src/common/net_mem_allocator.h b/src/common/net_mem_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..7cd27333814f9c18eae5e8cd35983cebc36cbd53 --- /dev/null +++ b/src/common/net_mem_allocator.h @@ -0,0 +1,397 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef COMMUNICATION_NET_MEM_ALLOCATOR_H +#define COMMUNICATION_NET_MEM_ALLOCATOR_H + +#include +#ifdef ALLOCATOR_PROTECTION_ENABLED +#include +#endif + +#include "hcom.h" +#include "hcom_def.h" +#include "hcom_log.h" +#include "hcom_utils.h" +#include "net_linked_list.h" +#include "net_rb_tree.h" + +#ifdef ALLOCATOR_PROTECTION_ENABLED +#define MA_MAGIC 666666 /* magic number to check whether memory node was corrupted */ +#endif + +#define MEM_ALLOCATOR_ROUND_UP_TO(x, align) \ + (((x) + (align)-1) / ((align)) * ((align))) /* round x to align, etc round(3,4) = 4 */ + +#define MA_LEN_LEN (NN_NO4) +#define MA_CANARY_LEN (NN_NO4) +#define MA_META_DATA_RESERVE_LEN \ + (NN_NO16) /* total reserve length, for memcpy's performance, 16 bytes alignment \ + is best */ + +namespace ock { +namespace hcom { +const uint64_t FREE_LIST_NUM = NN_NO1024; + +class NetMemAllocator; + +struct MediaDescribe { + uint64_t startAddress = 0; + uint64_t endAddress = 0; +}; + +struct MemoryArea { +#ifdef ALLOCATOR_PROTECTION_ENABLED + uint64_t magic = MA_MAGIC; +#endif + uint64_t startAddress = 0; + uint64_t endAddress = 0; + uint64_t length = 0; + uint64_t index = 0; + + MemoryArea() = default; + +#ifdef ALLOCATOR_PROTECTION_ENABLED + /* + * @brief setup memory area's magic number, since all memory areas are + * always cast from dirty memory + */ + inline void Initialize() + { + magic = MA_MAGIC; + } +#endif +}; + +using MemoryAreaRawPtr = MemoryArea *; + +/* + * MemoryRegion is an inner management class for a continuous memory, it offers + * remove and insert functions, which will sync mRoot and freeHead inside + * routine. This class should not be used separately, it's part of + * NetMemAllocator + * + * @property mRoot stores free memory blocks in red-black tree + * @property freeHead is an array of free memory block list,list with greater + * array index contains bigger memory block + * + */ +class MemoryRegion { +public: + NetRbTree mRoot; + NetLinkedList> freeHead[FREE_LIST_NUM]; + int64_t freeCnt[FREE_LIST_NUM] {}; + uint64_t totalSize = 0; + uint64_t freeSize = 0; + +private: + NResult MemoryAreaRemove(uint64_t *startAddress, uint64_t length, uint32_t minBlockSize, bool metaStored = true); + + NResult MemoryAreaInsert(uint64_t startAddress, uint64_t length); + + void MemoryAreaInsertPre(NetRbNode *newMa, NetRbNode *ma); + + void MemoryAreaInsertNext(NetRbNode *newMa, NetRbNode *ma); + + friend NetMemAllocator; +}; + +using MemoryRegionRawPtr = MemoryRegion *; + +using NetMemAllocatorPtr = NetRef; + +/* + * NetMemAllocator is a self-scaling memory pool, supporting take and put memory + * with different size from it within pool capacity. You should bind an already + * allocated continuous memory to NetMemAllocator, this class will not + * alloc/malloc any memory inside + * + * @property mMemRegionLock mutex lock for thread safe + * @property mMemRegionMgr inner memory manager support base operation for take + * and put memory + * @property mAddrMap stores taken memory with address-length pair,which help to + * reduce one length parameter for putback api, also it prevents unmatched + * length with some address when putback + */ +class NetMemAllocator : public UBSHcomNetMemoryAllocator { +public: + inline NResult Initialize(uintptr_t mrAddress, uint64_t mrSize, uint32_t minBlockSize, bool alignAddress) + { + if (NN_UNLIKELY(mInited)) { + if (mrAddress == mMRAddress && mrSize == mMRSize && minBlockSize == mMinBlockSize) { + return NN_OK; + } + NN_LOG_ERROR("Already initialized,can not be initialized again with different parameters"); + return NN_ERROR; + } + + if (mrAddress == 0) { + NN_LOG_ERROR("address can not be null"); + return NN_INVALID_PARAM; + } + + if (!POWER_OF_2(minBlockSize)) { + NN_LOG_ERROR("minBlockSize must be power of 2"); + return NN_INVALID_PARAM; + } + + if (minBlockSize < NN_NO4096 || minBlockSize > NN_NO1024 * NN_NO1024 * NN_NO1024) { + NN_LOG_ERROR("minBlockSize must be at least 4096 byte and not greater than 1 gigabyte"); + return NN_INVALID_PARAM; + } + + if (mrSize < minBlockSize) { + NN_LOG_ERROR("mrSize must be greater than minBlockSize"); + return NN_INVALID_PARAM; + } + + mInited = true; + mAlignAddress = alignAddress; + mMRAddress = mrAddress; + mMRSize = mrSize; + mMinBlockSize = minBlockSize; + + MediaDescribe media = {}; + media.startAddress = mMRAddress; + media.endAddress = mMRAddress + mMRSize; + + if (media.endAddress <= media.startAddress) { + NN_LOG_ERROR("mrSize must be legal"); + return NN_INVALID_PARAM; + } + + auto hr = MemoryRegionInit(media); + if (NN_UNLIKELY(hr != NN_OK)) { + NN_LOG_ERROR("Init mem region mgr failed " << hr); + return hr; + } + + NN_LOG_INFO("Init mem region mgr success, mr size " << mMRSize); + return NN_OK; + } + + inline void Destroy() override + { +#ifdef ALLOCATOR_PROTECTION_ENABLED + UnProtectAllMem(); +#endif + } + + inline uintptr_t MemOffset(uintptr_t address) const override + { + if (address < mMRAddress) { + NN_LOG_ERROR("invalid address in MemOffset"); + } + return address - mMRAddress; + } + + inline uint64_t FreeSize() const override + { + return mMemRegionMgr.freeSize; + } + + inline NResult Allocate(uint64_t size, uintptr_t &mrAddress) override + { + if (NN_UNLIKELY(!mInited)) { + NN_LOG_ERROR("Allocator not initialized, Allocate failed"); + return NN_NOT_INITIALIZED; + } + + uint64_t alignedSize = MEM_ALLOCATOR_ROUND_UP_TO(size, mMinBlockSize); + uint64_t address = 0; + + if (mAlignAddress) { + if (NN_UNLIKELY(NN_OK != RegionMallocWithMap(address, alignedSize))) { + NN_LOG_ERROR("Mem allocate failed"); + return NN_ERROR; + } + } else { + if (NN_UNLIKELY(NN_OK != RegionMalloc(address, alignedSize, alignedSize - size))) { + NN_LOG_ERROR("Mem allocate failed"); + return NN_ERROR; + } + } + + mrAddress = static_cast(address); + + NN_LOG_TRACE_INFO("Mem allocate success, addr size " << size << " alignSize " << + alignedSize); + return NN_OK; + } + + inline NResult Free(uintptr_t mrAddress) override + { + if (NN_UNLIKELY(!mInited)) { + NN_LOG_ERROR("Allocator not initialized, Free failed"); + return NN_NOT_INITIALIZED; + } + + if (NN_UNLIKELY(mrAddress == 0)) { + NN_LOG_WARN("mrAddress is zero, directly back"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY((mrAddress < mMRAddress) || (mrAddress > mMRAddress + mMRSize))) { + NN_LOG_ERROR("Mem free failed, because address is not overlapped"); + return NN_ERROR; + } + + NResult hr = NN_OK; + if (mAlignAddress) { + hr = RegionFreeWithMap(static_cast(mrAddress)); + } else { + hr = RegionFree(static_cast(mrAddress)); + } + + if (NN_UNLIKELY(hr != NN_OK)) { + NN_LOG_ERROR("Mem free failed"); + return hr; + } + + NN_LOG_TRACE_INFO("Mem free success"); + return NN_OK; + } + + inline NResult GetSizeByAddressNoAlign(uint64_t startAddress, uint64_t &length) + { + if (startAddress != mMRAddress) { + if (startAddress < MA_META_DATA_RESERVE_LEN) { + NN_LOG_WARN("address don't have enough space for meta data"); + return NN_ERROR; + } + if (startAddress < mMRAddress || startAddress >= mMRAddress + mMRSize) { + NN_LOG_WARN("address is illegal"); + return NN_ERROR; + } + auto metaReqLenAddress = + reinterpret_cast(startAddress - MA_META_DATA_RESERVE_LEN + MA_CANARY_LEN + MA_LEN_LEN); + length = *metaReqLenAddress; + } else { + if (NN_UNLIKELY(mFirstReqLength == 0)) { + NN_LOG_ERROR("Address Invalid in GetSizeByAddressNoAlign"); + return NN_ERROR; + } + length = mFirstReqLength; + } + + return NN_OK; + } + + inline NResult GetSizeByAddressAlign(uint64_t startAddress, uint64_t &length) + { + std::lock_guard lock(mMemRegionLock); + auto iter = mAddrLenMap.find(startAddress); + if (NN_LIKELY(iter != mAddrLenMap.end())) { + length = iter->second; + return NN_OK; + } + + return NN_ERROR; + } + + inline uint32_t MinBlockSize() const + { + return mMinBlockSize; + } + +protected: + uintptr_t mMRAddress = 0; + uint64_t mMRSize = 0; + +private: + NResult RegionFree(uint64_t startAddress); + + NResult RegionFreeWithMap(uint64_t startAddress); + + NResult RegionMalloc(uint64_t &startAddress, uint64_t length, uint64_t deltaLength); + + NResult RegionMallocWithMap(uint64_t &startAddress, uint64_t length); + + NResult MemoryRegionJoin(MediaDescribe *mediaDesc); + + void MemoryRegionInitial(); + + NResult MemoryRegionInit(MediaDescribe &media); + +#ifdef ALLOCATOR_PROTECTION_ENABLED + /* + * @brief traverse all (free) memory area nodes to confirm no ma node was corrupted, + * this function works when mprotect unavaliable, such as shared memory based allocator, + * where memory user is a different process to allocator owner, thus they have different + * pagetables, mprotect doesn't work + * @return false if corrupted node was found(usually caused by uaf or overwrite), + * otherwise return true, + */ + inline bool CheckNodes() + { + for (auto &i : mMemRegionMgr.freeHead) { + auto cur = &i.head; + do { + if (cur->data.data.magic != MA_MAGIC) { + NN_LOG_ERROR("free memory node " << (uint64_t)cur << + " was corrupted, usually caused by use after free."); + return false; + } + cur = cur->next; + } while (cur != &i.head); + } + + return true; + } + + /* + * @brief call mprotect on all (free) memory area nodes, which let only read operation + * allowed for those memories, once ProtectFreeMem called, if free memories were written, + * SIGSEGV will rise + * + * mprotect() changes the access protections for the calling + * process's memory pages containing any part of the address range + * in the interval [addr, addr+len-1]. addr must be aligned to a + * page boundary. + * + * If the calling process tries to access memory in a manner that + * violates the protections, then the kernel generates a SIGSEGV + * signal for the process. + */ + inline void ProtectFreeMem() + { + for (auto &i : mMemRegionMgr.freeHead) { + auto cur = &i.head; + do { + mprotect(cur, cur->data.data.length, PROT_READ); + cur = cur->next; + } while (cur != &i.head); + } + } + + /* + * @brief remove write protection added by ProtectFreeMem, this is only called before + * internal manipulation on allocator, to improve performance, just change all memory + * managed to read write is safe + */ + inline void UnProtectAllMem() + { + mprotect((void *)mMRAddress, mMRSize, PROT_READ | PROT_WRITE); + } +#endif + + std::mutex mMemRegionLock; + MemoryRegion mMemRegionMgr; + uint64_t mFirstAllocLength = 0; + uint64_t mFirstReqLength = 0; + uint32_t mMinBlockSize = NN_NO4096; + bool mInited = false; + bool mAlignAddress = false; + std::unordered_map mAddrLenMap; +}; +} // namespace hcom +} // namespace ock +#endif // COMMUNICATION_NET_MEM_ALLOCATOR_H diff --git a/src/common/net_mem_allocator_cache.cpp b/src/common/net_mem_allocator_cache.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bd3bb66bf6c11da1173c6f0a96fa7dafa9369c36 --- /dev/null +++ b/src/common/net_mem_allocator_cache.cpp @@ -0,0 +1,140 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "net_common.h" +#include "net_mem_allocator_cache.h" + +namespace ock { +namespace hcom { +uint16_t NetCacheTierFuncTimes(uint64_t size, uint64_t baseSize) +{ + /* + * 1 firstly make the times of base size + * 2 secondly get tier index + */ + return NetFunc::NN_RoundUpTo(size, baseSize) / baseSize - 1; +} + +uint16_t NetCacheTierFuncPower(uint64_t size, uint64_t baseSize) +{ + return NetFunc::NN_PowerOfNIndex(size, baseSize); +} + +NResult NetAllocatorCache::Initialize(const UBSHcomNetMemoryAllocatorOptions &options) +{ + if (mTieredBlockHead != nullptr) { + return NN_OK; + } + + mAligned = options.alignedAddress; + + /* validate cache tier count */ + if (options.cacheTierCount == NN_NO0 || options.cacheTierCount > NN_NO8192) { + NN_LOG_ERROR("Invalid cacheTierCount " << options.cacheTierCount << " for allocator cache, which should <= " << + NN_NO8192 << " and != " << NN_NO0); + return NN_INVALID_PARAM; + } + + /* validate cache block count in each tier and assign */ + if (options.cacheBlockCountPerTier < NN_NO4 || options.cacheBlockCountPerTier > NN_NO8192) { + NN_LOG_ERROR("Invalid cacheBlockCountPerTier " << options.cacheBlockCountPerTier << + " for allocate cache, which should between " << NN_NO4 << "~" << NN_NO8192); + return NN_INVALID_PARAM; + } + mBlockCacheCountPerTier = options.cacheBlockCountPerTier; + + /* ref major allocator */ + if (mMajorAllocator == nullptr) { + NN_LOG_ERROR("Failed to allocator cache as major allocator is null"); + return NN_INVALID_PARAM; + } + + mBaseBlockSize = mMajorAllocator->MinBlockSize(); + + /* block tier count */ + mBlockTierCount = options.cacheTierCount; + if (options.cacheTierPolicy == TIER_TIMES) { + mTierChooseFunc = &NetCacheTierFuncTimes; + mMaxCacheBlockSize = mBaseBlockSize * mBlockTierCount; + } else if (options.cacheTierPolicy == TIER_POWER) { + if (mBlockTierCount > NN_NO31) { + NN_LOG_ERROR("Invalid cacheTierCount " << options.cacheTierCount << + " for allocator cache, since the cacheTierPolicy is TIER_POWER, then it should <= " << NN_NO31 << + " and != " << NN_NO0); + return NN_INVALID_PARAM; + } + mTierChooseFunc = &NetCacheTierFuncPower; + uint64_t timesOfPower2 = 1 << (mBlockTierCount - 1); + mMaxCacheBlockSize = mBaseBlockSize * timesOfPower2; + } else { + NN_ASSERT_LOG_RETURN(false, NN_INVALID_PARAM); + } + + /* allocate address to size map */ + if (options.alignedAddress) { + mAddress2SizeMap = new (std::nothrow) NetAddressSizeMap(); + if (NN_UNLIKELY(mAddress2SizeMap == nullptr)) { + NN_LOG_ERROR("Failed to new address to size map for allocator cache"); + return NN_NEW_OBJECT_FAILED; + } + if (NN_UNLIKELY(mAddress2SizeMap->Initialize(options.bucketCount))) { + delete mAddress2SizeMap; + mAddress2SizeMap = nullptr; + NN_LOG_ERROR("Failed to initialize address to size map for allocator cache"); + return NN_NEW_OBJECT_FAILED; + } + } + + /* allocate tier head and get physical memory */ + mTieredBlockHead = new (std::nothrow) NetMemAllocCacheLinkNode[mBlockTierCount]; + if (mTieredBlockHead == nullptr) { + UnInitialize(); + NN_LOG_ERROR("Failed to new tier buckets head for allocator cache"); + return NN_NEW_OBJECT_FAILED; + } + bzero(mTieredBlockHead, sizeof(NetMemAllocCacheLinkNode) * mBlockTierCount); + + /* set tier size for buckets */ + for (uint16_t i = 0; i < mBlockTierCount; i++) { + if (options.cacheTierPolicy == TIER_TIMES) { + mTieredBlockHead[i].blockSizeInKB = (mBaseBlockSize / NN_NO1024) * (i + 1); + } else if (options.cacheTierPolicy == TIER_POWER) { + mTieredBlockHead[i].blockSizeInKB = (mBaseBlockSize / NN_NO1024) * (1 << i); + } + } + + NN_LOG_INFO("Initialized allocator cache, aligned " << mAligned << ", tierCount " << mBlockTierCount << + ", blockCountPerTier " << mBlockCacheCountPerTier << ", minBlockSize " << mBaseBlockSize << + ", maxCacheBlockSize " << mMaxCacheBlockSize << ", tier bucket heads occupied memory " << + (sizeof(NetMemAllocCacheLinkNode) * mBlockTierCount)); + + return NN_OK; +} + +void NetAllocatorCache::UnInitialize() +{ + if (mMajorAllocator != nullptr) { + mMajorAllocator->DecreaseRef(); + mMajorAllocator = nullptr; + } + + if (mAddress2SizeMap != nullptr) { + delete mAddress2SizeMap; + mAddress2SizeMap = nullptr; + } + + if (mTieredBlockHead != nullptr) { + delete[] mTieredBlockHead; + mTieredBlockHead = nullptr; + } +} +} +} \ No newline at end of file diff --git a/src/common/net_mem_allocator_cache.h b/src/common/net_mem_allocator_cache.h new file mode 100644 index 0000000000000000000000000000000000000000..a7a9a1276cfbf3bb1d4078968a6560c191145073 --- /dev/null +++ b/src/common/net_mem_allocator_cache.h @@ -0,0 +1,260 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. +change * Author: bao, lu + */ +#ifndef HCOM_NET_MEM_ALLOCATOR_CACHE_H +#define HCOM_NET_MEM_ALLOCATOR_CACHE_H + +#include "hcom.h" +#include "net_addr_size_map.h" +#include "net_mem_allocator.h" + +namespace ock { +namespace hcom { +/* + * @brief Link node of blocks which are same size of block + * + * NOTE: make sure all memory is aligned + */ +struct NetMemAllocCacheLinkNode { + NetMemAllocCacheLinkNode *next = nullptr; /* link node for block allocated from major allocator */ + uint32_t lock = 0; + uint32_t accessCount = 0; /* access count, only for head */ + uint32_t blockSizeInKB = NN_NO4; /* size of the block in KB, only for head */ + uint16_t currentBlocks = 0; /* current cached count of blocks */ + /* + * @brief Spin lock + */ + void Lock() + { + while (!__sync_bool_compare_and_swap(&lock, 0, NN_NO1)) { + } + } + + /* + * @brief Unlock + */ + void Unlock() + { + __atomic_store_n(&lock, 0, __ATOMIC_SEQ_CST); + } +}; + +/* + * @brief Tier choose interface and functions + */ +using NetCacheTierFunc = uint16_t(uint64_t size, uint64_t baseSize); + +/* + * @brief This is a lockless allocator cache, which allocates several size of block in batch from major allocator + */ +class NetAllocatorCache : public UBSHcomNetMemoryAllocator { +public: + explicit NetAllocatorCache(NetMemAllocator *majorAllocator) : mMajorAllocator(majorAllocator) + { + if (mMajorAllocator != nullptr) { + mMajorAllocator->IncreaseRef(); + } + } + + ~NetAllocatorCache() override + { + UnInitialize(); + } + + NResult Initialize(const UBSHcomNetMemoryAllocatorOptions &options); + void UnInitialize(); + + NResult Allocate(uint64_t size, uintptr_t &address) override + { + NN_ASSERT_LOG_RETURN(mMajorAllocator != nullptr, NN_ERROR) + + /* if the size is larger than max cache block size */ + if (NN_UNLIKELY(size > mMaxCacheBlockSize)) { + return mMajorAllocator->Allocate(size, address); + } + + /* get tier index, increase access count */ + auto tierIndex = mTierChooseFunc(size, mBaseBlockSize); + NN_ASSERT_LOG_RETURN(tierIndex < mBlockTierCount, NN_INVALID_PARAM); + + auto &oneList = mTieredBlockHead[tierIndex]; + oneList.accessCount++; + + /* allocated from tiered block linked list */ + oneList.Lock(); + if (oneList.next != nullptr) { + address = reinterpret_cast(oneList.next); + oneList.next = oneList.next->next; + --oneList.currentBlocks; + mTotalCacheSizeKB -= oneList.blockSizeInKB; + oneList.Unlock(); + /* it is aligned need to record size with address into hashmap */ + if (mAligned) { + uint32_t timesOfBaseSize = oneList.blockSizeInKB * NN_NO1024 / mBaseBlockSize; + mAddress2SizeMap->Put(address, timesOfBaseSize); + } + return NN_OK; + } + + /* if not allocated then allocate from major pool */ + uintptr_t majorAddress = 0; + NetMemAllocCacheLinkNode *newNode = nullptr; + for (uint16_t i = 0; i < mBlockCacheCountPerTier; i++) { + if (mMajorAllocator->Allocate(oneList.blockSizeInKB * NN_NO1024, majorAddress) != NN_OK) { + break; + } + + /* added to one list, if it is first allocate from major + * and remember this as newNode for link next allocated memory block + */ + ++oneList.currentBlocks; + if (newNode == nullptr) { + newNode = reinterpret_cast(majorAddress); + newNode->next = nullptr; + oneList.next = newNode; + } else { + newNode->next = reinterpret_cast(majorAddress); + newNode = newNode->next; + newNode->next = nullptr; + } + + mTotalCacheSizeKB += oneList.blockSizeInKB; + } + + /* allocate from cache again, it happens when allocated from major */ + if (oneList.next != nullptr) { + address = reinterpret_cast(oneList.next); + oneList.next = oneList.next->next; + --oneList.currentBlocks; + mTotalCacheSizeKB -= oneList.blockSizeInKB; + oneList.Unlock(); + /* it is aligned need to remember size with address into hashmap */ + if (mAligned) { + uint32_t timesOfBaseSize = oneList.blockSizeInKB * NN_NO1024 / mBaseBlockSize; + mAddress2SizeMap->Put(address, timesOfBaseSize); + } + return NN_OK; + } + + /* do later, free some bigger from cache */ + + /* unlock */ + oneList.Unlock(); + + return NN_ERROR; + } + + inline NResult Free(uintptr_t address) override + { + NN_ASSERT_LOG_RETURN(mMajorAllocator != nullptr, NN_ERROR); + + uint64_t size = 0; + /* firstly get size */ + if (!mAligned) { + auto result = mMajorAllocator->GetSizeByAddressNoAlign(address, size); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_WARN("Try to free invalid address in allocator cache"); + return result; + } + + if (size > mMaxCacheBlockSize) { + return mMajorAllocator->Free(address); + } + + goto FREE_TO_CACHE; + } else { + /* find from address to size map for cache */ + uint32_t timesOfBaseSize = 0; + if (mAddress2SizeMap->Remove(address, timesOfBaseSize) == NN_OK) { + size = timesOfBaseSize * mBaseBlockSize; + goto FREE_TO_CACHE; + } + + /* if not found, try to find from major allocator */ + auto result = mMajorAllocator->GetSizeByAddressAlign(address, size); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_WARN("Try to free invalid address in allocator cache"); + return result; + } + + /* if found in major allocator, free it */ + return mMajorAllocator->Free(address); + } + + FREE_TO_CACHE: + /* + * attach to linked list for this + * step1: get tiered index + */ + auto tierIndex = mTierChooseFunc(size, mBaseBlockSize); + NN_ASSERT_LOG_RETURN(tierIndex < mBlockTierCount, NN_INVALID_PARAM); + + /* step2: attach */ + auto &oneList = mTieredBlockHead[tierIndex]; + oneList.Lock(); + auto tmp = reinterpret_cast(address); + tmp->next = oneList.next; + oneList.next = tmp; + mTotalCacheSizeKB += oneList.blockSizeInKB; + /* if not two times cached, return */ + if (++oneList.currentBlocks < mBlockCacheCountPerTier * NN_NO2) { + oneList.Unlock(); + return NN_OK; + } + + NetMemAllocCacheLinkNode *returnHead = oneList.next; + for (uint16_t i = 0; i < mBlockCacheCountPerTier - NN_NO2; ++i) { + returnHead = returnHead->next; + } + + /* set last node will not be returned */ + NetMemAllocCacheLinkNode *lastRetained = returnHead->next; + returnHead = lastRetained->next; + lastRetained->next = nullptr; + mTotalCacheSizeKB -= oneList.blockSizeInKB * mBlockCacheCountPerTier; + oneList.currentBlocks -= mBlockCacheCountPerTier; + /* unlock */ + oneList.Unlock(); + + NetMemAllocCacheLinkNode *next = nullptr; + /* free to major allocator */ + while (returnHead != nullptr) { + next = returnHead->next; + (void)mMajorAllocator->Free(reinterpret_cast(returnHead)); + returnHead = next; + } + + return NN_OK; + } + + inline uintptr_t MemOffset(uintptr_t address) const override + { + NN_ASSERT_LOG_RETURN(mMajorAllocator != nullptr, 0) + return mMajorAllocator->MemOffset(address); + } + + uint64_t FreeSize() const override + { + NN_ASSERT_LOG_RETURN(mMajorAllocator != nullptr, 0) + return mTotalCacheSizeKB * NN_NO1024 + mMajorAllocator->FreeSize(); + } + +private: + using NetAddressSizeMap = NetAddress2SizeHashMap; + + NetMemAllocCacheLinkNode *mTieredBlockHead = nullptr; /* tiered buckets */ + NetMemAllocator *mMajorAllocator = nullptr; /* major allocator */ + NetAddressSizeMap *mAddress2SizeMap = nullptr; /* hash map for address to size, for unaligned allocator */ + NetCacheTierFunc *mTierChooseFunc = nullptr; /* tier choose function */ + uint64_t mMaxCacheBlockSize = NN_NO4096; /* max block size, if >= this, allocate from major */ + uint64_t mTotalCacheSizeKB = 0; /* total cache size in KB */ + uint64_t mBaseBlockSize = NN_NO4096; /* min block size */ + uint16_t mBlockCacheCountPerTier = NN_NO16; /* cached block count in each tier */ + uint16_t mBlockTierCount = 0; /* timer count */ + bool mAligned = true; /* address is aligned or not */ +}; +} +} + +#endif // HCOM_NET_MEM_ALLOCATOR_CACHE_H diff --git a/src/common/net_mem_pool_fixed.cpp b/src/common/net_mem_pool_fixed.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e07bc57073f3ba6d74b47fbbec044f54d32a37cd --- /dev/null +++ b/src/common/net_mem_pool_fixed.cpp @@ -0,0 +1,257 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "net_mem_pool_fixed.h" +#include "net_monotonic.h" + +namespace ock { +namespace hcom { +NetMemPoolFixed::NetMemPoolFixed(const std::string &name, const NetMemPoolFixedOptions &options) + : mOptions(options), mName(name) +{ + OBJ_GC_INCREASE(NetMemPoolFixed); +} + +NResult NetMemPoolFixed::Initialize() +{ + std::lock_guard guard(mMutex); + if (mInited) { + return NN_OK; + } + + /* validate options */ + NResult result = Validate(); + if (result != NN_OK) { + return result; + } + + /* reserve to avoid reallocate memory when expanding vector */ + mSuperBlocks.reserve(NN_NO1024); + + /* expand one super block from os */ + if ((result = ExpandFromOs(false)) != NN_OK) { + DoUnInitialize(); + return result; + } + + mInited = true; + return NN_OK; +} + +void NetMemPoolFixed::DoUnInitialize() +{ + for (auto &iter : mSuperBlocks) { + free(iter.buffer); + iter.buffer = nullptr; + } + + mSuperBlocks.clear(); + mTotalSuperBlkSize = 0; + mFreeCount = 0; +} + +NResult NetMemPoolFixed::Validate() +{ + /* validate super block size, which must between 1 and 256 MB including 256 MB */ + if (mOptions.superBlkSizeMB == 0 || mOptions.superBlkSizeMB > NN_NO256) { + NN_LOG_ERROR("Invalid superBlkSizeMB " << mOptions.superBlkSizeMB << " in mem pool " << mName << + ", which be 1~" << NN_NO256 << ", reset to " << NN_NO4); + return NN_INVALID_PARAM; + } + + /* validate thread cache expand and shrink steps, which between 8 and 256 MB */ + if (mOptions.tcExpandBlkCnt < NN_NO8 || mOptions.tcExpandBlkCnt > NN_NO256) { + NN_LOG_ERROR("Invalid tcExpandBlkCnt " << mOptions.tcExpandBlkCnt << " in mem pool " << mName << + ", which be " << NN_NO8 << "~" << NN_NO256 << ", reset to " << NN_NO128); + return NN_INVALID_PARAM; + } + + /* validate size of min block */ + if (mOptions.minBlkSize < sizeof(NetMemPoolMinBlock)) { + NN_LOG_ERROR("Invalid minBlkSize " << mOptions.minBlkSize << " in mem pool " << mName << + ", which be larger than " << sizeof(NetMemPoolMinBlock)); + return NN_INVALID_PARAM; + } + + /* validate relation of min block size and super block size */ + uint64_t tmp = mOptions.minBlkSize * mOptions.tcExpandBlkCnt * NN_NO16; + tmp = tmp / NN_NO1024 / NN_NO1024; /* in MB */ + if (tmp > mOptions.superBlkSizeMB) { + NN_LOG_ERROR("Invalid minBlkSize " << mOptions.minBlkSize << " in mem pool " << mName); + return NN_INVALID_PARAM; + } + + uint64_t superBlkSize = mOptions.superBlkSizeMB * NN_NO1024 * NN_NO1024; + if (superBlkSize % (mOptions.minBlkSize * mOptions.tcExpandBlkCnt)) { + NN_LOG_ERROR("Invalid minBlkSize " << mOptions.minBlkSize << " or tcExpandBlkCnt " << mOptions.tcExpandBlkCnt << + " in mem pool " << mName << ", super block size is not times of " << + mOptions.minBlkSize * mOptions.tcExpandBlkCnt); + return NN_INVALID_PARAM; + } + + return NN_OK; +} + +NResult NetMemPoolFixed::ExpandFromOs(bool holdFreeListLock) +{ + uint64_t startTime = NetMonotonic::TimeNs(); + /* allocate memory */ + auto superBlkSize = (mTotalSuperBlkSize == 0) ? + (mOptions.superBlkSizeMB * NN_NO1024 * NN_NO1024) : mTotalSuperBlkSize; + auto mem = memalign(NN_NO4096, superBlkSize); + if (mem == nullptr) { + NN_LOG_ERROR("Failed to malloc memory for supper block in mem pool " << mName); + return NN_MALLOC_FAILED; + } + + /* get physical memory */ + bzero(mem, superBlkSize); + + /* insert to super block list */ + NetMemPoolSuperBlock blk(mem, superBlkSize); + mSuperBlocks.emplace_back(blk); + + /* add size and count */ + mTotalSuperBlkSize += superBlkSize; + + /* make free linked list */ + NetMemPoolMinBlock *head = nullptr; + NetMemPoolMinBlock *tail = nullptr; + uint32_t count = 0; + NResult result = NN_OK; + if ((result = MakeFreeList(blk, head, tail, count)) != NN_OK) { + return result; + } + + /* attach free linked list */ + if (holdFreeListLock) { + mTcMutex.Lock(); + } + + if ((result = AttacheToFreeList(head, tail, count)) != NN_OK) { + if (holdFreeListLock) { + mTcMutex.Unlock(); + } + return result; + } + + if (holdFreeListLock) { + mTcMutex.Unlock(); + } + + NN_LOG_INFO("Fixed size memory pool " << mName << " allocated " << mOptions.superBlkSizeMB << + "MB memory from os, total block size " << mTotalSuperBlkSize << " and split to " << count << + " min block with size " << mOptions.minBlkSize << " which took " << + (NetMonotonic::TimeNs() - startTime) / NN_NO1000 << "us, current free min block is " << mFreeCount); + + return NN_OK; +} + +NResult NetMemPoolFixed::TCAlloc(NetMemPoolMinBlock &head) +{ + NN_ASSERT_LOG_RETURN(mInited, NN_NOT_INITIALIZED) + bool flag = true; + do { + /* step 1: allocate from free list */ + mTcMutex.Lock(); + if (mFreeCount > 0) { + head.next = mFreeMinBlkList.next; + mFreeCount -= mFreeMinBlkList.next->count; + mFreeMinBlkList.next = head.next->nextN->next; + head.next->nextN->next = nullptr; + mTcMutex.Unlock(); + flag = false; + return NN_OK; + } + mTcMutex.Unlock(); + + /* step 2: if there is no free in list, allocate from OS */ + { + std::unique_lock locker(mMutex); + /* wait if already allocating from os by another thread */ + mCondForOs.wait(locker, [&]() { return !mAllocatingFromOs; }); + if (mFreeCount > 0) { + continue; + } + mAllocatingFromOs = true; + + NResult result = ExpandFromOs(true); + mAllocatingFromOs = false; + mCondForOs.notify_all(); + if (result != NN_OK) { + flag = false; + return result; + } + } + } while (flag); + + NN_LOG_ERROR("Unreachable code path"); + return NN_ERROR; +} + +std::string NetMemPoolFixed::ToString() +{ + std::ostringstream oss; + oss << "fixed-size-memory-pool [name: " << mName << ", options: [" << mOptions.ToString() << + "], super-block-count: " << mSuperBlocks.size() << ", super-block-size: " << + mTotalSuperBlkSize / NN_NO1024 / NN_NO1024 << "MB, free-min-block-count: " << mFreeCount; + + uint32_t blkIndex = 0; + oss << " super-blocks: ["; + for (auto &iter : mSuperBlocks) { + oss << "superBlk" << blkIndex++ << ":" << iter.size << "," << iter.buffer << " "; + } + + oss << "] min-blocks: ["; + auto iter = mFreeMinBlkList.next; + blkIndex = 0; + while (iter != nullptr) { + if (iter->nextN != nullptr) { + oss << "** " << blkIndex++ << ":" << iter << "," << iter->nextN << "," << iter->count << " "; + } else { + oss << blkIndex++ << ":" << iter << " "; + } + iter = iter->next; + } + oss << "]]"; + + return oss.str(); +} + +/* NetTCacheFixed */ +NetTCacheFixed::NetTCacheFixed(NetMemPoolFixed *sharePool) : mSharedPool(sharePool) +{ + if (NN_UNLIKELY(mSharedPool == nullptr)) { + return; + } + + mSharedPool->IncreaseRef(); + + mFreeSteps = mSharedPool->mOptions.tcExpandBlkCnt; +} + +std::string NetTCacheFixed::ToString() +{ + std::ostringstream oss; + oss << "net-thread-cache: [free-steps: " << mFreeSteps << ", current-free: " << mCurrentFree; + + auto iter = mHead.next; + uint16_t minBlkIndex = 0; + oss << " mini-blocks: ["; + while (iter != nullptr) { + oss << "blk" << minBlkIndex++ << ":" << iter << " "; + iter = iter->next; + } + oss << "]]"; + return oss.str(); +} +} +} \ No newline at end of file diff --git a/src/common/net_mem_pool_fixed.h b/src/common/net_mem_pool_fixed.h new file mode 100644 index 0000000000000000000000000000000000000000..ed77422b69bd80566f627ef6344551f8c02b3152 --- /dev/null +++ b/src/common/net_mem_pool_fixed.h @@ -0,0 +1,411 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_MEM_POOL_H +#define OCK_HCOM_NET_MEM_POOL_H + +#include +#include +#include + +#include "hcom.h" + +namespace ock { +namespace hcom { +/* + * There are two levels of blocks for fixed size of memory pool + * 1 Super block: which allocated from OS + * 2 Mini block: which allocated from thread cache to end users + * + * So: + * a) size of super block is multiple times of size of mini block + * b) size of mini block is min size of allocate unit to user + * + * Super block allocated from OS, node list + */ +struct NetMemPoolSuperBlock { + void *buffer = nullptr; /* point to real buffer */ + uint64_t size = 0; /* memory size of the super block */ + + NetMemPoolSuperBlock(void *buf, uint64_t s) : buffer(buf), size(s) {} +}; + +/* + * Mini block allocated to end user + */ +struct NetMemPoolMinBlock { + NetMemPoolMinBlock *next = nullptr; /* link to next min block */ + NetMemPoolMinBlock *nextN = nullptr; /* link to next N min block */ + uint32_t count = 0; /* current link count */ +}; + +/* + * Options of fixed size memory pool + */ +struct NetMemPoolFixedOptions { + uint16_t superBlkSizeMB = NN_NO4; /* size of each super block, by default 4 MB */ + uint16_t minBlkSize = NN_NO64; /* size of min block by default is 64 bytes */ + uint16_t tcExpandBlkCnt = NN_NO128; /* count of min block to expand from shared pool at one time */ + + std::string ToString() const + { + std::ostringstream oss; + oss << "super-blk-size-mb: " << superBlkSizeMB << ", min-blk-size: " << minBlkSize << + ", thread-cache-blk-count: " << tcExpandBlkCnt; + return oss.str(); + } +} __attribute__((packed)); + +/* + * Mem pool for fixed size objects + * + * NetMemPoolFixed is shared mem pool for all threads + * NetTCacheFixed is thread local cache + * + */ +class NetTCacheFixed; +class NetMemPoolFixed; + +using NetMemPoolFixedPtr = NetRef; + +/* + * Memory pool allocated from OS and shared by thread + */ +class NetMemPoolFixed { +public: + NetMemPoolFixed(const std::string &name, const NetMemPoolFixedOptions &options); + ~NetMemPoolFixed() + { + UnInitialize(); + OBJ_GC_DECREASE(NetMemPoolFixed); + } + + NResult Initialize(); + + void UnInitialize() + { + std::lock_guard guard(mMutex); + if (!mInited) { + return; + } + + DoUnInitialize(); + mInited = false; + } + + /* + * @brief Allocate batch of min block from pool, which called by thread cache + */ + NResult TCAlloc(NetMemPoolMinBlock &head); + + /* + * @brief Free batch of min block to pool, which called by thread cache + */ + NResult TCFree(NetMemPoolMinBlock *head) + { + NN_ASSERT_LOG_RETURN(head != nullptr, NN_NOT_INITIALIZED) + NN_ASSERT_LOG_RETURN(head->nextN != nullptr, NN_NOT_INITIALIZED) + + mTcMutex.Lock(); + (void)AttacheToFreeList(head, head->nextN, head->count); + mTcMutex.Unlock(); + + return NN_OK; + } + + std::string ToString(); + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + void DoUnInitialize(); + + NResult Validate(); + + /* + * @brief Expand one super block from OS + */ + NResult ExpandFromOs(bool holdFreeListLock); + + /* + * @brief Make the super block to min block with linked list + */ + NResult MakeFreeList(const NetMemPoolSuperBlock &superBlk, NetMemPoolMinBlock *&head, NetMemPoolMinBlock *&tail, + uint32_t &count) const + { + count = superBlk.size / mOptions.minBlkSize; + auto address = reinterpret_cast(superBlk.buffer); + head = reinterpret_cast(address); + auto iter = head; + + NetMemPoolMinBlock *batchHeader = nullptr; + + /* make free linked list */ + for (uint32_t i = 1; i <= count; ++i) { + auto mod = i % mOptions.tcExpandBlkCnt; + if (mod == 1) { + /* first min block, remember this as header */ + batchHeader = iter; + } else if (mod == 0 && batchHeader != nullptr) { + /* N block, set the nextN of header to this and set count */ + batchHeader->nextN = iter; + batchHeader->count = mOptions.tcExpandBlkCnt; + } + + /* set next */ + address += mOptions.minBlkSize; + iter->next = reinterpret_cast(address); + + /* move to next and skip the last one */ + if (i != count) { + iter = reinterpret_cast(address); + } + } + + tail = iter; + tail->next = nullptr; + tail->nextN = nullptr; + tail->count = 0; + + return NN_OK; + } + + /* + * @brief Attach linked min block to free list + */ + inline NResult AttacheToFreeList(NetMemPoolMinBlock *head, NetMemPoolMinBlock *tail, uint32_t count) + { + NN_ASSERT_LOG_RETURN(head != nullptr, NN_INVALID_PARAM); + NN_ASSERT_LOG_RETURN(tail != nullptr, NN_INVALID_PARAM); + NN_ASSERT_LOG_RETURN(count != 0, NN_INVALID_PARAM); + + tail->next = mFreeMinBlkList.next; + mFreeMinBlkList.next = head; + mFreeCount += count; + return NN_OK; + } + +private: + NetSpinLock mTcMutex; + NetMemPoolMinBlock mFreeMinBlkList {}; + uint64_t mFreeCount = 0; + + NetMemPoolFixedOptions mOptions {}; + std::mutex mMutex; + std::condition_variable mCondForOs; + bool mAllocatingFromOs = false; + std::vector mSuperBlocks; + uint64_t mTotalSuperBlkSize = 0; + + std::string mName; + bool mInited = false; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend class NetTCacheFixed; +}; + +/* + * Thread cache for fixed size of memory pool, usually for object + */ +class NetTCacheFixed { +public: + explicit NetTCacheFixed(NetMemPoolFixed *sharePool); + ~NetTCacheFixed() + { + FreeAllToPool(); + + if (mSharedPool != nullptr) { + mSharedPool->DecreaseRef(); + mSharedPool = nullptr; + } + } + + /* + * @brief Allocate one from thread cache, this is not thread safe + */ + template T *Allocate() + { + if (NN_LIKELY(mHead.next != nullptr)) { + /* allocate from head */ + auto tmp = mHead.next; + mHead.next = tmp->next; + /* assign tail to null if it is empty */ + --mCurrentFree; + return reinterpret_cast(tmp); + } + + /* allocate from shared pool */ + NN_ASSERT_LOG_RETURN(mSharedPool != nullptr, nullptr); + + if (mSharedPool->TCAlloc(mHead) != NN_OK) { + return nullptr; + } + + /* set current free */ + mCurrentFree = mHead.next->count - 1; + + NN_LOG_TRACE_INFO(this->ToString()); + + /* move head to next and return first */ + auto tmp = mHead.next; + mHead.next = mHead.next->next; + return reinterpret_cast(tmp); + } + + /* + * @brief Free one to thread cache, this is not thread safe + */ + template void Free(T *value) + { + if (NN_LIKELY(value == nullptr)) { + return; + } + + /* insert into first */ + auto tmp = reinterpret_cast(value); + tmp->next = mHead.next; + mHead.next = tmp; + + ++mCurrentFree; + /* judge is current free count is 2 times larger than free steps + * 1 no:just return + * 2 yes: return many to shared pool, which means return in batch to reduce the cost of mutex in shared pool + */ + if ((mCurrentFree >> 1) < mFreeSteps) { + return; + } + + /* step 1: get first */ + auto head = mHead.next; + + /* step 2: move head forward mFreeSteps and get tail */ + const uint16_t returnCount = mFreeSteps - 1; + for (uint16_t i = 0; i < returnCount; ++i) { + mHead.next = mHead.next->next; + } + head->nextN = mHead.next; + + /* step 3: move head one more */ + mHead.next = mHead.next->next; + + head->nextN->next = nullptr; + head->count = mFreeSteps; + + /* step 4: decrease current */ + mCurrentFree -= mFreeSteps; + + NN_ASSERT_LOG_RETURN_VOID(mSharedPool != nullptr); + + /* step 5: return to share pool */ + mSharedPool->TCFree(head); + } + + std::string ToString(); + +private: + /* + * Free all to pool + */ + void FreeAllToPool() + { + NN_ASSERT_LOG_RETURN_VOID(mSharedPool != nullptr); + + if (mCurrentFree == 0) { + return; + } + + /* step 1: get first */ + auto head = mHead.next; + + /* step 2: move head forward mFreeSteps and get tail */ + const uint16_t returnCount = mCurrentFree - 1; + for (uint16_t i = 0; i < returnCount; ++i) { + mHead.next = mHead.next->next; + } + head->nextN = mHead.next; + + /* step 3: reset */ + mHead.next = nullptr; + + /* step 4: free */ + head->nextN->next = nullptr; + head->count = mCurrentFree; + mSharedPool->TCFree(head); + + NN_LOG_TRACE_INFO("Thread cache for fixed size memory pool is deconstructing, returned " << mCurrentFree << + " to global pool " << mSharedPool->mName); + mCurrentFree = 0; + } + +private: + NetMemPoolMinBlock mHead {}; + NetMemPoolFixed *mSharedPool = nullptr; + uint16_t mCurrentFree = 0; + uint16_t mFreeSteps = 0; + + friend class NetMemPoolFixed; + template friend class KeyedThreadLocalCache; +}; + +/// NetTCacheFixed 通常与 thread_local 一起使用,即使上层传递的 mempool 是不同的,thread_local 对象仅会初始化一次。在同 +/// 一线程下,用户使用 x = Alloc(mempool1) 与 y = Alloc(mempool2), 实际两次分配的对象都会先尝试从 thread_local +/// NetTCacheFixed 的 freelist 中获取,如果没有则从一开始初始化的 mempool1 中分配。而如果后续两者在不同线程内被归还,那 +/// 么就会出现 Free(mempool2, y) 被还至 mempool2 但是它实际归属于 mempool1. +/// +/// 为了解决跨线程归还的问题,每个 mempool 需要各自提供一个 key 来找到对应的 NetTCacheFixed 对象,保证在每次申请/归还内 +/// 存时都使用同一内存池。 +/// +/// \seealso NetServiceCtxStore::GetOrReturn +/// \seealso HcomServiceCtxStore::GetOrReturn +template class KeyedThreadLocalCache { +public: + KeyedThreadLocalCache() = default; + ~KeyedThreadLocalCache() = default; + + template T *Allocate(uint8_t key) + { + if (key > KeyMax) { + return nullptr; + } + + return mTCacheFixeds[key] ? mTCacheFixeds[key]->template Allocate() : nullptr; + } + + template void Free(uint8_t key, T *ctx) + { + if (key > KeyMax) { + return; + } + + if (mTCacheFixeds[key]) { + mTCacheFixeds[key]->template Free(ctx); + } + } + + void UpdateIf(uint8_t key, NetMemPoolFixed *mempool) + { + if (key > KeyMax) { + return; + } + + if (!mTCacheFixeds[key] || mTCacheFixeds[key]->mSharedPool != mempool) { + mTCacheFixeds[key].reset(new (std::nothrow) NetTCacheFixed(mempool)); + } + } + +private: + std::array, KeyMax + 1> mTCacheFixeds; +}; +} // namespace hcom +} // namespace ock + +#endif // OCK_HCOM_NET_MEM_POOL_H diff --git a/src/common/net_monotonic.h b/src/common/net_monotonic.h new file mode 100644 index 0000000000000000000000000000000000000000..e1c19a47e61a86646220bc545cc7577a71861bac --- /dev/null +++ b/src/common/net_monotonic.h @@ -0,0 +1,247 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_MONOTONIC_H +#define OCK_HCOM_NET_MONOTONIC_H + +#include +#ifdef __x86_64__ +#include +#endif + +#include "hcom.h" +#include "net_common.h" +#include "net_util.h" + +namespace ock { +namespace hcom { +constexpr int32_t INIT_FAILURE_RET = NN_NO1; + +class NetMonotonic { +#ifdef USE_PROCESS_MONOTONIC +public: +#ifdef __aarch64__ + /* + * @brief init tick for us + * + */ + template static int32_t InitTickUs() + { + /* get frequ */ + uint64_t tmpFreq = 0; + __asm__ volatile("mrs %0, cntfrq_el0" : "=r"(tmpFreq)); + auto freq = static_cast(tmpFreq); + + /* calculate */ + freq = freq / 1000L / 1000L; + if (freq == 0) { + NN_LOG_ERROR("Failed to get tick as freq is " << freq); + return FAILURE_RET; + } + + return freq; + } + + /* + * @brief Get monotonic time in ns, is not absolution time + */ + static inline uint64_t TimeNs() + { + const static int32_t TICK_PER_US = InitTickUs(); + uint64_t timeValue = 0; + __asm__ volatile("mrs %0, cntvct_el0" : "=r"(timeValue)); + return timeValue * 1000L / TICK_PER_US; + } + + /* + * @brief Get monotonic time in us, is not absolution time + */ + static inline uint64_t TimeUs() + { + const static int32_t TICK_PER_US = InitTickUs(); + uint64_t timeValue = 0; + __asm__ volatile("mrs %0, cntvct_el0" : "=r"(timeValue)); + return timeValue / TICK_PER_US; + } + + /* + * @brief Get monotonic time in ms, is not absolution time + */ + static inline uint64_t TimeMs() + { + const static int32_t TICK_PER_US = InitTickUs(); + uint64_t timeValue = 0; + __asm__ volatile("mrs %0, cntvct_el0" : "=r"(timeValue)); + return timeValue / (TICK_PER_US * 1000L); + } + + /* + * @brief Get monotonic time in sec, is not absolution time + */ + static inline uint64_t TimeSec() + { + const static int32_t TICK_PER_US = InitTickUs(); + uint64_t timeValue = 0; + __asm__ volatile("mrs %0, cntvct_el0" : "=r"(timeValue)); + return timeValue / (TICK_PER_US * 1000000L); + } + +#elif __x86_64__ + template static int32_t InitTickUs() + { + const std::string path = "/proc/cpuinfo"; + const std::string prefix = "model name"; + const std::string gHZ = "GHz"; + + std::ifstream inConfFile(path); + if (!inConfFile) { + NN_LOG_ERROR("Failed to get tick as failed to open " << path); + return FAILURE_RET; + } + + bool found = false; + std::string strLine; + while (getline(inConfFile, strLine)) { + if (strLine.compare(0, prefix.size(), prefix) == 0) { + found = true; + break; + } + } + + if (!found) { + NN_LOG_ERROR("Failed to get tick as failed to find " << prefix); + return FAILURE_RET; + } + + std::vector splitVec; + NetFunc::NN_SplitStr(strLine, " ", splitVec); + if (splitVec.empty()) { + NN_LOG_ERROR("Failed to get tick as failed to get line " << prefix); + return FAILURE_RET; + } + + std::string lastWord = splitVec[splitVec.size() - 1]; + auto index = lastWord.find(gHZ); + if (index == std::string::npos) { + NN_LOG_ERROR("Failed to get tick as failed to get " << gHZ); + return FAILURE_RET; + } + + auto strGhz = lastWord.substr(0, index); + float fhz = 0.0f; + if (!NetFunc::NN_Stof(strGhz, fhz)) { + NN_LOG_ERROR("Failed to get tick as failed to convert " << strGhz << " to float"); + return FAILURE_RET; + } + + NN_LOG_TRACE_INFO("ghz " << strGhz << ", " << fhz << ", " << static_cast(fhz * 1000L)); + + return static_cast(fhz * 1000L); + } + + /* + * @brief Get monotonic time in ns, is not absolution time + */ + static inline uint64_t TimeNs() + { + const static int32_t TICK_PER_US = InitTickUs(); + if (TICK_PER_US == INIT_FAILURE_RET) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (static_cast(ts.tv_sec)) * 1000000000L + ts.tv_nsec; + } + return __rdtsc() * 1000L / TICK_PER_US; + } + + /* + * @brief Get monotonic time in us, is not absolution time + */ + static inline uint64_t TimeUs() + { + const static int32_t TICK_PER_US = InitTickUs(); + if (TICK_PER_US == INIT_FAILURE_RET) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (static_cast(ts.tv_sec)) * 1000000L + ts.tv_nsec / 1000L; + } + return __rdtsc() / TICK_PER_US; + } + + /* + * @brief Get monotonic time in ms, is not absolution time + */ + static inline uint64_t TimeMs() + { + const static int32_t TICK_PER_US = InitTickUs(); + if (TICK_PER_US == INIT_FAILURE_RET) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (static_cast(ts.tv_sec)) * 1000L + ts.tv_nsec / 1000000L; + } + return __rdtsc() / (TICK_PER_US * 1000L); + } + + /* + * @brief Get monotonic time in sec, is not absolution time + */ + static inline uint64_t TimeSec() + { + const static int32_t TICK_PER_US = InitTickUs(); + if (TICK_PER_US == INIT_FAILURE_RET) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (static_cast(ts.tv_sec)) + ts.tv_nsec / 1000000000L; + } + return __rdtsc() / (TICK_PER_US * 1000000L); + } + +#endif /* __x86_64__ || __aarch64__ */ + +#else /* USE_PROCESS_MONOTONIC */ +public: + template static int32_t InitTickUs() + { + return NN_OK; + } + + static inline uint64_t TimeNs() + { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (static_cast(ts.tv_sec)) * 1000000000L + ts.tv_nsec; + } + + static inline uint64_t TimeUs() + { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (static_cast(ts.tv_sec)) * 1000000L + ts.tv_nsec / 1000L; + } + + static inline uint64_t TimeMs() + { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (static_cast(ts.tv_sec)) * 1000L + ts.tv_nsec / 1000000L; + } + + static inline uint64_t TimeSec() + { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (static_cast(ts.tv_sec)) + ts.tv_nsec / 1000000000L; + } +#endif /* USE_PROCESS_MONOTONIC */ +}; +} +} + +#endif // OCK_HCOM_NET_MONOTONIC_H diff --git a/src/common/net_obj_pool.h b/src/common/net_obj_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..c926224059fd5b6c623de7cd40e2b8f3759b8ddf --- /dev/null +++ b/src/common/net_obj_pool.h @@ -0,0 +1,117 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_COMM_OBJ_POOL_H_23444 +#define OCK_COMM_OBJ_POOL_H_23444 + +#include + +#include "hcom_def.h" +#include "hcom_log.h" +#include "hcom_utils.h" + +namespace ock { +namespace hcom { +template class NetObjPool { +public: + explicit NetObjPool(const std::string &name, uint32_t capacity) : mName(name), mObjRB(capacity) {} + ~NetObjPool() + { + UnInitialize(); + } + + NResult Initialize() + { + std::lock_guard locker(mInitMutex); + if (mObjs != nullptr) { + NN_LOG_INFO("Obj pool already initialized"); + return NN_OK; + } + + mObjs = static_cast(malloc(sizeof(T) * mObjRB.Capacity())); + if (mObjs == nullptr) { + NN_LOG_ERROR("Failed to new objects for pool, probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + auto result = mObjRB.Initialize(); + if (result != NN_OK) { + NN_LOG_ERROR("Failed to initialize ring buffer, result " << result); + free(mObjs); + mObjs = nullptr; + return result; + } + + for (uint32_t i = 0; i < mObjRB.Capacity(); i++) { + mObjRB.PushBack(&(mObjs[i])); + } + + mObjsEnd = &(mObjs[mObjRB.Capacity() - 1]); + return NN_OK; + } + + void UnInitialize() + { + std::lock_guard locker(mInitMutex); + if (mObjs != nullptr) { + free(mObjs); + mObjs = nullptr; + } + + mObjRB.UnInitialize(); + } + + inline bool Dequeue(T *&item) + { + if (NN_LIKELY(mObjRB.PopFront(item))) { + return true; + } + + // new one + NN_LOG_INFO("Create new object from malloc lib for pool " << mName << " as pool is fully"); + item = static_cast(malloc(sizeof(T))); + if (NN_UNLIKELY(item == nullptr)) { + NN_LOG_INFO("Create new object from malloc lib for pool " << mName << ", probably out of memory"); + return false; + } + return true; + } + + inline void Enqueue(T *item) + { + if (NN_LIKELY(item >= mObjs && item <= mObjsEnd)) { + mObjRB.PushFront(item); + } else { + if (NN_UNLIKELY(item != nullptr)) { + free(item); + } + } + } + + std::string ToString() + { + std::ostringstream oss; + oss << "obj pool " << mName << ", capacity " << mObjRB.Capacity() << ", size " << mObjRB.Size() << + ", addresses " << mObjs; + return oss.str(); + } + +private: + std::mutex mInitMutex; + T *mObjs = nullptr; + T *mObjsEnd = nullptr; + std::string mName; + NetRingBuffer mObjRB; +}; +} +} + +#endif // OCK_COMM_OBJ_POOL_H_23444 diff --git a/src/common/net_pgtable.cpp b/src/common/net_pgtable.cpp new file mode 100644 index 0000000000000000000000000000000000000000..87540a65744ac3cc2fe2e01ced431a1c9e8dbd48 --- /dev/null +++ b/src/common/net_pgtable.cpp @@ -0,0 +1,728 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include "hcom_def.h" +#include "hcom_err.h" +#include "hcom_log.h" +#include "securec.h" +#include "net_pgtable.h" + +namespace ock { +namespace hcom { +constexpr uint64_t BitMask(uint32_t i) +{ + return (i >= NN_NO64) ? 0 : (1UL << i); +} + +constexpr uint64_t OrderMask(uint32_t i) +{ + return (i >= NN_NO64) ? ~0 : (BitMask(i) - 1); +} + +constexpr size_t AlignDownPow2(size_t n, size_t alignment) +{ + return n & ~(alignment - 1); +} + +constexpr size_t AlignUpPow2(size_t n, size_t alignment) +{ + return AlignDownPow2(n + alignment - 1, alignment); +} + +constexpr bool IsAddrAligned(size_t addr) +{ + return (addr & (PAGE_ADDR_ALIGN_MIN - 1)) == 0; +} +/** + * @brief Returns the position of the most significant set bit in a uint64_t. + * @param n Input value, must be greater than 0. + * @return The bit position (0-indexed). For example: 1 -> 0, 8 (1000b) -> 3. + */ +static inline unsigned HighestBitPosition(uint64_t n) +{ + return NN_NO63 - static_cast(__builtin_clzll(n)); +} +/** + * @brief Returns the position of the least significant set bit in a uint64_t. + * @param n Input value, must be greater than 0. + * @return The bit position (0-indexed). For example: 8 (1000b) -> 3, 6 (110b) -> 1. + */ +static inline unsigned LowestBitPosition(uint64_t n) +{ + return static_cast(__builtin_ctzll(n)); +} + +static inline bool IsValidPtePointer(void *ptr) +{ + return !(reinterpret_cast(ptr) & (PGT_ENTRY_MIN_ALIGN - 1)); +} + +static inline void AdvanceAddrByOrder(PgtAddress &address, uint32_t order) +{ + if (order >= NN_NO64) { + NN_LOG_ERROR("Failed pgt address advance order is >= 64"); + return; + } + + address += 1uL << order; +} + +PgtDir *PgTable::PgtDirAlloc() +{ + if (mPgdAllocCb == nullptr || mPgdReleaseCb == nullptr) { + NN_LOG_ERROR("Failed to allocate page table directory, as allocate or release callback is null"); + return nullptr; + } + auto pgd = mPgdAllocCb(*this); + if (pgd == nullptr) { + NN_LOG_ERROR("Failed to allocate page table directory, as pgd ptr is null"); + return nullptr; + } + if (!IsValidPtePointer(pgd)) { + NN_LOG_ERROR("Failed to allocate page table directory, as pgd ptr is not align"); + mPgdReleaseCb(*this, pgd); + return nullptr; + } + + if (memset_s(pgd, sizeof(PgtDir), 0, sizeof(PgtDir)) != 0) { + NN_LOG_ERROR("Failed to allocate page table directory, as memset_s pgd failed"); + mPgdReleaseCb(*this, pgd); + return nullptr; + } + + return pgd; +} + +void PgTable::PgtDirRelease(PgtDir *pgd) +{ + if (pgd == nullptr) { + NN_LOG_ERROR("Failed to release page table directory, as dir is null"); + return; + } + if (mPgdReleaseCb == nullptr) { + NN_LOG_ERROR("Failed to release page table directory, as release callback is null"); + return; + } + mPgdReleaseCb(*this, pgd); + pgd = nullptr; +} + +void PgTable::PgtDumpSubtree(uint32_t indent, const PgtEntry &pgtEntry, uint32_t pteIndex, PgtAddress base, + PgtAddress mask, uint32_t shift) +{ + if (pgtEntry.HasFlag(EntryFlags::REGION)) { + auto region = pgtEntry.GetRegion(); + if (region == nullptr) { + return; + } + NN_LOG_DEBUG("indent: " << indent << " pte_index:" << pteIndex << " shift:" << shift << " is region"); + } else if (pgtEntry.HasFlag(EntryFlags::DIR)) { + auto pgd = pgtEntry.GetDir(); + if (pgd == nullptr) { + return; + } + NN_LOG_DEBUG("indent: " << indent << " pte_index:" << pteIndex << " dir count:" + << pgd->count << " shift:" << shift); + shift -= PTE_SHIFT_PER_DIR; + mask |= PTE_INDEX_MASK << shift; + for (uint32_t i = 0; i < PTE_ENTRY_NUM_PER_DIR; ++i) { + PgtDumpSubtree(indent + NN_NO2, pgd->entries[i], i, base | (i << shift), mask, shift); + ++base; + } + } else { + NN_LOG_DEBUG("indent: " << indent << " pte_index:" << pteIndex << " not present"); + } +} + +void PgTable::Dump() +{ + NN_LOG_INFO("pgtable dump, shift:" << mIndexShift << ", count:" << mRegionCount); + PgtDumpSubtree(0, mRootEntry, 0, mVirBaseAddr, mSpaceMask, mIndexShift); +} + +void PgTable::PgTableReset() +{ + mVirBaseAddr = 0; + mSpaceMask = (static_cast(-1)) << PAGE_SHIFT_MIN; + mIndexShift = PAGE_SHIFT_MIN; +} + +void PgTable::PgtEnsureCapacity(uint32_t order, PgtAddress address) +{ + // Ensure the page table is deep enough to support the address order + while (mIndexShift < order) { + if (!PgtExpand()) { + return; + } + } + + if (!mRootEntry.IsPresent()) { + mVirBaseAddr = address & mSpaceMask; + NN_LOG_INFO("pgtable initialize, shift:" << mIndexShift << ", count:" << mRegionCount); + } else { + // Ensure the target address falls within the current pgtable mVirBaseAddr address range + while ((address & mSpaceMask) != mVirBaseAddr) { + if (!PgtExpand()) { + return; + } + } + } +} + +bool PgTable::PgtExpand() +{ + // shift为地址最高位,最大值为[PGT_ADDR_ORDER_MAX - PTE_SHIFT_PER_DIR] + if (mIndexShift > (PGT_ADDR_ORDER_MAX - PTE_SHIFT_PER_DIR)) { + NN_LOG_ERROR("failed to expand pgtable, shift is over max " << PGT_ADDR_ORDER_MAX - PTE_SHIFT_PER_DIR); + return false; + } + + // 如果根节点已存在,将其下沉为子目录 + if (mRootEntry.IsPresent()) { + PgtDir *pgd = PgtDirAlloc(); + if (pgd == nullptr) { + NN_LOG_ERROR("failed to expand pgtable, allocate pgt dir error"); + return false; + } + pgd->entries[(mVirBaseAddr >> mIndexShift) & PTE_INDEX_MASK] = mRootEntry; + pgd->count = 1; + if (!mRootEntry.SetDir(*pgd)) { + PgtDirRelease(pgd); + return false; + } + } + + mIndexShift += PTE_SHIFT_PER_DIR; + mSpaceMask <<= PTE_SHIFT_PER_DIR; // example 0xF0 -> 0xFF0 + mVirBaseAddr &= mSpaceMask; + + NN_LOG_INFO("pgtable expand success, shift:" << mIndexShift << ", count:" << mRegionCount); + return true; +} + +bool PgTable::PgtShrink() +{ + if (!mRootEntry.IsPresent()) { + PgTableReset(); + NN_LOG_INFO("pgtable shrink, shift:" << mIndexShift << ", count:" << mRegionCount); + return false; + } + if (!mRootEntry.HasFlag(EntryFlags::DIR)) { + return false; + } + + auto pgd = mRootEntry.GetDir(); + if (pgd == nullptr || pgd->count != 1) { + return false; + } + + PgtEntry *pgtEntry = nullptr; + uint32_t idx = 0; + + // 当页表某层目录只有一个有效entry时,找到这个entry,并移除这一层 + for (uint32_t i = 0; i < PTE_ENTRY_NUM_PER_DIR; ++i) { + if (pgd->entries[i].IsPresent()) { + pgtEntry = &pgd->entries[i]; + idx = i; + break; // 因为 count == 1,最多只有一个 + } + } + + if (pgtEntry == nullptr) { + NN_LOG_ERROR("pgtable shrink failed, pgd entry is null"); + PgtDirRelease(pgd); + return false; + } + + if (mIndexShift < PTE_SHIFT_PER_DIR) { + NN_LOG_ERROR("pgtable shrink failed, invalid shift:" << mIndexShift); + PgtDirRelease(pgd); + return false; + } + + mIndexShift -= PTE_SHIFT_PER_DIR; + mVirBaseAddr |= static_cast(idx) << mIndexShift; // 将idx的偏移 + mSpaceMask |= PTE_INDEX_MASK << mIndexShift; // 缩小mask不再覆盖最高的PTE_SHIFT_PER_DIR位置 + mRootEntry = *pgtEntry; + NN_LOG_INFO("pgtable shrink, shift:" << mIndexShift << ", count:" << mRegionCount); + PgtDirRelease(pgd); + return true; +} + +static NResult ValidatePage(PgtAddress address, uint32_t order) +{ + // 检查起始地址是否与页大小对齐 + if ((address & ((1uL << order) - 1)) != 0) { + NN_LOG_ERROR("failed to check address, is not align with page order"); + return NN_INVALID_PARAM; + } + // 检查order是否为页表层级结构允许的 必须为[PAGE_SHIFT_MIN + k * PTE_SHIFT_PER_DIR] + // 例如:起始阶 PAGE_SHIFT_MIN=4, 阶差 PTE_SHIFT_PER_DIR=4 → 对齐合法阶为 4 - 8 - 12 - 16 - 20... + if (((order - PAGE_SHIFT_MIN) % PTE_SHIFT_PER_DIR) != 0) { + NN_LOG_ERROR("failed to check order " << order); + return NN_INVALID_PARAM; + } + return NN_OK; +} + +NResult PgTable::PgtCheckEntryDir(PgtEntry &pgtEntry, uint32_t shift, uint32_t order) +{ + if (pgtEntry.HasFlag(EntryFlags::REGION)) { + NN_LOG_ERROR("Failed to insert entry, order is not equal to shift but pgtEntry is region."); + return NN_ERROR; + } + + if (shift < PTE_SHIFT_PER_DIR + order) { + NN_LOG_ERROR("shift is less than PTE_SHIFT_PER_DIR + order"); + return NN_ERROR; + } + return NN_OK; +} + +/** + * @brief Returns the smallest page table level order that can cover the range [start, end). + * If start == 0 and end == 0, it represents the entire address space. + * @param start The start address of the range (inclusive) + * @param end The end address of the range (exclusive) + * @return The page order (order), or -1 on failure + */ +static uint32_t GetNextPageOrder(PgtAddress start, PgtAddress end) +{ + if (!IsAddrAligned(start) || !IsAddrAligned(end)) { + NN_LOG_ERROR("failed to get next page order, start or end address is not aligned"); + return -1; + } + + uint32_t maxOrder = 0; + if ((end == 0) && (start == 0)) { + // entire address space + maxOrder = PGT_ADDR_ORDER_MAX; + } else if (end == start) { + // min page size + maxOrder = PAGE_SHIFT_MIN; + } else { + maxOrder = HighestBitPosition(end - start); + // The lowest set bit in the start address determines its alignment order. + // For example: start = 0x1000 (binary ...0001 0000 0000 0000), LowestBitPosition = 12 → 4KB aligned. + // This means a page larger than 4KB cannot be used, otherwise it would cross an alignment boundary. + if (start) { + maxOrder = std::min(LowestBitPosition(start), maxOrder); + } + } + + if ((maxOrder < PAGE_SHIFT_MIN) || (maxOrder > PGT_ADDR_ORDER_MAX)) { + NN_LOG_ERROR("failed to get next page order, log2Len is invalid"); + return -1; + } + + // aligned down maxOrder to the nearest valid page table level. + uint32_t alignedOrder = ((maxOrder - PAGE_SHIFT_MIN) / PTE_SHIFT_PER_DIR) * PTE_SHIFT_PER_DIR + PAGE_SHIFT_MIN; + NN_LOG_DEBUG("Calculate max order is " << maxOrder << " alignedOrder is " << alignedOrder); + return alignedOrder; +} + +/** + * Insert a variable-size page to the page table. + * + * @param address address to insert + * @param order page size to insert - should be k*PTE_SHIFT for a certain k + * @param region region to insert + */ +NResult PgTable::InsertPage(PgtAddress address, uint32_t order, PgtRegion ®ion) +{ + NN_LOG_DEBUG("begin to insert page, order " << order << " region " << region.key); + + if (ValidatePage(address, order) != NN_OK) { + return NN_INVALID_PARAM; + } + + PgtEnsureCapacity(order, address); + + PgtDir dummyPgd = {}; + PgtDir *currentDir = &dummyPgd; + uint32_t currentShift = mIndexShift; + PgtEntry *pgtEntry = &mRootEntry; + while (order != currentShift) { + if (PgtCheckEntryDir(*pgtEntry, currentShift, order) != NN_OK) { + goto ROLLBACK; + } + + if (!pgtEntry->IsPresent()) { + ++currentDir->count; + auto dir = PgtDirAlloc(); + if (dir == nullptr) { + goto ROLLBACK; + } + if (!pgtEntry->SetDir(*dir)) { + PgtDirRelease(dir); + goto ROLLBACK; + } + } + + currentDir = pgtEntry->GetDir(); + if (currentDir == nullptr) { + goto ROLLBACK; + } + currentShift -= PTE_SHIFT_PER_DIR; + uint32_t index = (address >> currentShift) & PTE_INDEX_MASK; + if (index >= PTE_ENTRY_NUM_PER_DIR) { + goto ROLLBACK; + } + pgtEntry = ¤tDir->entries[index]; + } + + if (pgtEntry->IsPresent() || !pgtEntry->SetRegion(region)) { + NN_LOG_ERROR("Failed to insert entry, entry already exist or not set region flag."); + goto ROLLBACK; + } + + if (currentDir) { + ++currentDir->count; + } + + NN_LOG_DEBUG("insert page success, order " << order << " region " << region.key); + return NN_OK; +ROLLBACK: + while (PgtShrink()) {} + return NN_ERROR; +} + +NResult PgTable::UnlinkRegion(PgtAddress address, uint32_t order, PgtDir &pgd, PgtEntry &pgtEntry, uint32_t shift, + PgtRegion ®ion) +{ + if (pgtEntry.HasFlag(EntryFlags::REGION)) { + if (shift != order) { + return NN_ERROR; + } + if (pgtEntry.GetRegion() != ®ion) { + return NN_ERROR; + } + + --pgd.count; + pgtEntry.Clear(); + return NN_OK; + } else if (pgtEntry.HasFlag(EntryFlags::DIR)) { + auto nextDir = pgtEntry.GetDir(); + if (nextDir == nullptr) { + return NN_ERROR; + } + uint32_t nextShift = shift - PTE_SHIFT_PER_DIR; + uint32_t index = (address >> nextShift) & PTE_INDEX_MASK; + if (index >= PTE_ENTRY_NUM_PER_DIR) { + return NN_ERROR; + } + auto nextPte = &nextDir->entries[index]; + + auto ret = UnlinkRegion(address, order, *nextDir, *nextPte, nextShift, region); + if (ret != NN_OK) { + return ret; + } + + if (nextDir->count == 0) { + pgtEntry.Clear(); + --pgd.count; + if (mPgdReleaseCb != nullptr) { + mPgdReleaseCb(*this, nextDir); + } else { + NN_LOG_WARN("unable to call dir release cb, which is nullptr"); + } + } + return NN_OK; + } + return NN_ERROR; +} + +NResult PgTable::RemovePage(PgtAddress address, uint32_t order, PgtRegion ®ion) +{ + if (ValidatePage(address, order) != NN_OK) { + return NN_INVALID_PARAM; + } + + if ((address & mSpaceMask) != mVirBaseAddr) { + NN_LOG_ERROR("no elem in address, as address mVirBaseAddr is not pgtable base"); + return NN_ERROR; + } + + PgtDir pgd = {}; + auto ret = UnlinkRegion(address, order, pgd, mRootEntry, mIndexShift, region); + if (ret != NN_OK) { + return ret; + } + + while (PgtShrink()) {} + return NN_OK; +} + +NResult PgTable::Insert(PgtRegion ®ion) +{ + NN_LOG_DEBUG("begin to add region " << region.key); + + uint32_t order = 0; + PgtAddress address = region.start; + PgtAddress end = region.end; + if ((address >= end) || !IsAddrAligned(address) || !IsAddrAligned(end)) { + NN_LOG_ERROR("failed to add region maybe region start > end, or address is not 16-byte aligned"); + return NN_INVALID_PARAM; + } + + while (address < end) { + order = GetNextPageOrder(address, end); + if (order < 0 || order >= NN_NO64) { + NN_LOG_ERROR("Failed to add region, get next page order is less than 0 or over 64"); + goto ROLLBACK; + } + if (InsertPage(address, order, region) != NN_OK) { + NN_LOG_ERROR("failed to insert page."); + goto ROLLBACK; + } + + AdvanceAddrByOrder(address, order); + } + ++mRegionCount; + + NN_LOG_INFO("pgtable insert success, shift:" << mIndexShift << ", count:" << mRegionCount); + return NN_OK; + +ROLLBACK: + /* Revert all pages we've inserted by now */ + end = address; + address = region.start; + while (address < end) { + order = GetNextPageOrder(address, end); + RemovePage(address, order, region); + AdvanceAddrByOrder(address, order); + } + return NN_ERROR; +} + +NResult PgTable::Remove(PgtRegion ®ion) +{ + NN_LOG_DEBUG("begin to remove region " << region.key); + + PgtAddress address = region.start; + PgtAddress end = region.end; + if ((address >= end) || !IsAddrAligned(address) || !IsAddrAligned(end)) { + NN_LOG_ERROR("failed to remove region no element with this param."); + return NN_ERROR; + } + + while (address < end) { + uint32_t order = GetNextPageOrder(address, end); + if (order >= NN_NO64) { + NN_LOG_ERROR("Failed pgt table get next page order is >= 64"); + return NN_ERROR; + } + auto ret = RemovePage(address, order, region); + if (ret != NN_OK) { + /* Cannot be partially removed */ + if (address != region.start) { + return NN_ERROR; + } + return ret; + } + + AdvanceAddrByOrder(address, order); + } + + if (mRegionCount > 0) { + --mRegionCount; + } + + NN_LOG_INFO("pgtable remove success, shift:" << mIndexShift << ", count:" << mRegionCount); + return NN_OK; +} + +PgtRegion *PgTable::Lookup(PgtAddress address) const +{ + NN_LOG_DEBUG("begin to lookup pgtable"); + + if ((address & mSpaceMask) != mVirBaseAddr) { + NN_LOG_ERROR("failed to lookup pgtable, as address is not mapped by the page table"); + return nullptr; + } + if (!mRootEntry.IsPresent()) { + NN_LOG_ERROR("failed to lookup pgtable, mRootEntry is nullptr"); + return nullptr; + } + const PgtEntry *currentEntry = &mRootEntry; + uint32_t currentShift = mIndexShift; + // Descend dir level by level until a Region is found + while (true) { + if (currentEntry->HasFlag(EntryFlags::REGION)) { + auto region = currentEntry->GetRegion(); + if (region == nullptr) { + NN_LOG_ERROR("failed to lookup pgtable, as region is null"); + return nullptr; + } + if ((address < region->start) || (address >= region->end)) { + NN_LOG_ERROR("failed to lookup pgtable as address is not in region"); + return nullptr; + } + return region; + } + if (currentEntry->HasFlag(EntryFlags::DIR)) { + auto dir = currentEntry->GetDir(); + if (dir == nullptr) { + NN_LOG_ERROR("failed to lookup pgtable, as dir is null"); + return nullptr; + } + currentShift -= PTE_SHIFT_PER_DIR; + uint32_t index = (address >> currentShift) & PTE_INDEX_MASK; + if (index >= PTE_ENTRY_NUM_PER_DIR) { + NN_LOG_ERROR("failed to lookup entry index is over entry array bound"); + return nullptr; + } + currentEntry = &dir->entries[index]; + continue; + } + NN_LOG_DEBUG("Lookup failed: entry is invalid no REGION or DIR flag"); + return nullptr; + } +} + +void PgTable::SearchSubtree(PgtAddress address, uint32_t order, const PgtEntry &pgtEntry, uint32_t currentShift, + PgtSearchCb cb, void *arg, PgtRegion *&lastRegion) +{ + NN_LOG_DEBUG("Begin to search subtree, order " << order << " currentShift " << currentShift << + " entry is region " << pgtEntry.HasFlag(EntryFlags::REGION)); + if (pgtEntry.HasFlag(EntryFlags::REGION)) { + auto region = pgtEntry.GetRegion(); + if (region == nullptr || lastRegion == region) { + return; + } + if (lastRegion != nullptr && region->start < lastRegion->end) { + NN_LOG_ERROR("Failed to search, as regions is overlap, now region start is less than previous region end"); + return; + } + lastRegion = region; + + // ensure region is not overlaps with address [address, address + 2^order - 1] + if (std::max(region->start, address) > std::min(region->end - 1, address + OrderMask(order))) { + NN_LOG_ERROR("Failed to search, region start end is not overlaps with the address"); + return; + } + if (cb == nullptr) { + NN_LOG_WARN("Unable to call the search cb, as cb is null"); + return; + } + cb(*this, *region, arg); + } else if (pgtEntry.HasFlag(EntryFlags::DIR)) { + auto dir = pgtEntry.GetDir(); + if (dir == nullptr) { + NN_LOG_ERROR("Failed to search, current dir is nullptr."); + return; + } + if (currentShift < PTE_SHIFT_PER_DIR) { + NN_LOG_ERROR("Failed to search, current shift " << currentShift + << " is less than entry mIndexShift per level " << PTE_SHIFT_PER_DIR); + return; + } + + uint32_t nextShift = currentShift - PTE_SHIFT_PER_DIR; + if (order < currentShift) { + // search region is less than current dir span, it can only in dir sub entry + uint32_t index = (address >> nextShift) & PTE_INDEX_MASK; + if (index >= PTE_ENTRY_NUM_PER_DIR) { + NN_LOG_ERROR("Failed to search, index is over dir entry size."); + return; + } + auto nextPte = &dir->entries[index]; + SearchSubtree(address, order, *nextPte, nextShift, cb, arg, lastRegion); + } else { + // search region covers the range of current dir, need to search all entries. + for (const auto& nextPte : dir->entries) { + SearchSubtree(address, order, nextPte, nextShift, cb, arg, lastRegion); + } + } + } +} + +void PgTable::SearchRange(PgtAddress from, PgtAddress to, PgtSearchCb cb, void *arg) +{ + // 确保搜索操作在页对齐的边界上进行 + PgtAddress address = AlignDownPow2(from, PAGE_ADDR_ALIGN_MIN); + PgtAddress end = AlignUpPow2(to, PAGE_ADDR_ALIGN_MIN); + + // 与页表实际管理的范围 [base, base + 2^shift) 进行交集操作,确保搜索不会超出页表的边界 + if (mIndexShift < (sizeof(uint64_t) * NN_NO8)) { + address = std::max(address, mVirBaseAddr); + end = std::min(end, mVirBaseAddr + BitMask(mIndexShift)); + } else { + if (mVirBaseAddr != 0) { + NN_LOG_ERROR("Failed to search range,shift is whole address base should be 0"); + return; + } + } + + PgtRegion *lastRegion = nullptr; + while (address <= to) { + uint32_t order = GetNextPageOrder(address, end); + if ((address & mSpaceMask) == mVirBaseAddr) { + SearchSubtree(address, order, mRootEntry, mIndexShift, cb, arg, lastRegion); + } + + if (order >= PGT_ADDR_ORDER_MAX) { + break; + } + + AdvanceAddrByOrder(address, order); + } +} + +static void PgtCleanupCallback(const PgTable &pgtable, PgtRegion ®ion, void *arg) +{ + if (arg == nullptr) { + NN_LOG_ERROR("Failed to call the page table purge callback, arg is nullptr"); + return; + } + auto *regionVector = static_cast *>(arg); + regionVector->push_back(®ion); + NN_LOG_DEBUG("call the clean cb success, push region to vector region " << region.key); +} + +void PgTable::Cleanup() +{ + NN_LOG_INFO("begin to cleanup pgtable numRegions " << mRegionCount); + if (mRegionCount == 0) { + NN_LOG_INFO("page table is empty, nothing to cleanup."); + return; + } + std::vector cleanRegions; + cleanRegions.reserve(mRegionCount); + + PgtAddress from = mVirBaseAddr; + PgtAddress to = mVirBaseAddr + (BitMask(mIndexShift) & mSpaceMask) - 1; + SearchRange(from, to, PgtCleanupCallback, &cleanRegions); + if (cleanRegions.size() != mRegionCount) { + NN_LOG_ERROR("Found size " << cleanRegions.size() << " regions, expected size" << mRegionCount); + return; + } + + for (auto region : cleanRegions) { + auto ret = Remove(*region); + if (ret != NN_OK) { + NN_LOG_ERROR("failed to remove region during cleanup"); + } + } + + // 最终检查状态 + if (mRootEntry.IsPresent()) { + NN_LOG_WARN("unable to purge pgtable, entry already exist after clean"); + } + if (mIndexShift != PAGE_SHIFT_MIN || mVirBaseAddr != 0 || mRegionCount != 0) { + NN_LOG_WARN("unable to purge pgtable, patable mIndexShift:" << mIndexShift << " base:" + << mVirBaseAddr << " num regions " << mRegionCount); + } +} +} +} diff --git a/src/common/net_pgtable.h b/src/common/net_pgtable.h new file mode 100644 index 0000000000000000000000000000000000000000..37ead97c14e61ce4028595f6bad64eb1df806701 --- /dev/null +++ b/src/common/net_pgtable.h @@ -0,0 +1,366 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_NET_PGTABLE_H_ +#define HCOM_NET_PGTABLE_H_ + +#include "hcom_def.h" +#include "hcom_err.h" +#include "hcom_log.h" + +namespace ock { +namespace hcom { +/* Define the address type */ +using PgtAddress = uintptr_t; + +enum class EntryFlags : PgtAddress { + REGION = 1uL << 0, + DIR = 1uL << 1 +}; + +/* Address alignment requirements */ +constexpr uint32_t PAGE_SHIFT_MIN = 4; +constexpr PgtAddress PAGE_ADDR_ALIGN_MIN = (1uL << PAGE_SHIFT_MIN); +constexpr uint32_t PGT_ADDR_ORDER_MAX = (sizeof(PgtAddress) * 8); // Total number of bits in the PgtAddress type +constexpr PgtAddress PGT_ADDR_MAX = (static_cast(-1)); // maximum addressable space of PgtAddress + +/* Page table entry/directory constants */ +constexpr uint32_t PTE_SHIFT_PER_DIR = 4; +constexpr uint32_t PTE_ENTRY_NUM_PER_DIR = (1uL << (PTE_SHIFT_PER_DIR)); +constexpr PgtAddress PTE_INDEX_MASK = (PTE_ENTRY_NUM_PER_DIR - 1); + +/* Page table pointers constants and flags */ +constexpr PgtAddress PGT_ENTRY_FLAGS_MASK = + (static_cast(EntryFlags::REGION) | static_cast(EntryFlags::DIR)); + +constexpr PgtAddress PGT_ENTRY_PTR_MASK = (~PGT_ENTRY_FLAGS_MASK); +constexpr PgtAddress PGT_ENTRY_MIN_ALIGN = (PGT_ENTRY_FLAGS_MASK + 1); + +constexpr int Log2Static(uint64_t n) +{ + return (n <= 1) ? 0 : 1 + Log2Static(n >> 1); +} + +constexpr bool IsPowerOfTwoOrZero(uint64_t n) +{ + return (n & (n - 1)) == 0; +} + +constexpr bool IsPowerOfTwo(uint64_t n) +{ + return n > 0 && (n & (n - 1)) == 0; +} + +using PgTable = struct NetPgTable; +using PgtDir = struct PgtDir; + +/** + * Memory region in the page table. + * The structure itself, and the pointers in it, must be aligned to 2^PTR_SHIFT. + */ +struct PgtRegion { + PgtAddress start; /* *< Region start address */ + PgtAddress end; /* *< Region end address */ + uint64_t key; + uint64_t token; +}; + +/** + * Page table entry: + * + * +--------------------+---+---+ + * | pointer (MSB) | d | r | + * +--------------------+---+---+ + * | | | | + * 64 2 1 0 + * + */ +class PgtEntry { +public: + PgtEntry() : mValue(0) {} + + ~PgtEntry() + { + mValue = 0; + } + + PgtRegion *GetRegion() const + { + if (!HasFlag(EntryFlags::REGION)) { + NN_LOG_ERROR("Failed to get region, value is not set region flag"); + return nullptr; + } + return reinterpret_cast(mValue & PGT_ENTRY_PTR_MASK); + } + + bool SetRegion(PgtRegion ®ion) + { + if (!CheckPtrValueAlign(®ion)) { + NN_LOG_ERROR("Failed to check region, value is not align"); + return false; + } + + SetPointerAndFlags(®ion, EntryFlags::REGION); + return true; + } + + PgtDir *GetDir() const + { + if (!HasFlag(EntryFlags::DIR)) { + NN_LOG_ERROR("Failed to get directory, value is not set dir flag"); + return nullptr; + } + return reinterpret_cast(mValue & PGT_ENTRY_PTR_MASK); + } + + bool SetDir(PgtDir &dir) + { + if (!CheckPtrValueAlign(&dir)) { + NN_LOG_ERROR("Failed to check dir, value is not align"); + return false; + } + SetPointerAndFlags(&dir, EntryFlags::DIR); + return true; + } + + bool HasFlag(EntryFlags flag) const + { + return (mValue & static_cast(flag)) != 0; + } + + void SetFlag(EntryFlags flag) + { + mValue |= static_cast(flag); + } + + void ClearFlag(EntryFlags flag) + { + mValue &= ~static_cast(flag); + } + + bool IsPresent() const + { + constexpr PgtAddress PRESENT_MASK = + static_cast(EntryFlags::REGION) | static_cast(EntryFlags::DIR); + return (mValue & PRESENT_MASK) != 0; + } + + void Clear() + { + mValue = 0; + } + +private: + bool CheckPtrValueAlign(void *ptr) + { + return !(reinterpret_cast(ptr) & (PGT_ENTRY_MIN_ALIGN - 1)); + } + + void SetPointerAndFlags(void *ptr, EntryFlags flag) + { + mValue = (reinterpret_cast(ptr) & PGT_ENTRY_PTR_MASK) | static_cast(flag); + } + + PgtAddress mValue = 0; +}; + +/** + * Page table directory. + * Each directory contains a fixed number of page table entries (PTEs) and tracks + * the count of valid entries. + */ +struct PgtDir { + PgtEntry entries[PTE_ENTRY_NUM_PER_DIR]; // Array of page table entries + uint32_t count; // Number of valid (present) entries in this directory +}; + +/** + * Callback type: Allocates a page table directory. + * + * This function is responsible for allocating memory for a new PgtDir. + * + * @param pgtable [in] Reference to the page table requesting allocation. + * @return Pointer to the newly allocated PgtDir, or nullptr on failure. + * The returned pointer must be aligned to PGT_ENTRY_ALIGN bytes. + */ +using PgDirAllocCb = PgtDir *(*)(const PgTable &pgtable); + +/** + * Callback type: Releases a page table directory. + * + * Frees memory associated with a previously allocated PgtDir. + * + * @param pgtable [in] Reference to the page table that owns the directory. + * @param pgdir [in] Pointer to the directory to release. May be nullptr. + */ +using PgDirReleaseCb = void (*)(const PgTable &pgtable, PgtDir *pgdir); + +/** + * Callback type: Invoked when a valid memory region is found during traversal. + * + * Used in search or walk operations to process matching regions. + * + * @param pgtable [in] Reference to the current page table. + * @param region [in] The memory region that was found (contains base, size, attrs). + * @param arg [in] User-provided context or data (e.g., accumulator, flag). + */ +using PgtSearchCb = void (*)(const PgTable &pgtable, PgtRegion ®ion, void *arg); + +/** + * The page table data structure organizes non-overlapping memory regions + * using an efficient radix tree, optimized for large and/or naturally aligned regions. + * + * Each page table entry (PTE) can be in one of three states: + * - Points to a memory region (indicated by PGT_PTE_FLAG_REGION) + * - Points to a child directory (indicated by PGT_PTE_FLAG_DIR) + * - Empty (null) (if neither flag is set) + * + * Entries are mutually exclusive: a PTE cannot be both a region and a directory. + * This ensures a clear hierarchical structure and prevents ambiguity during traversal. + */ +class NetPgTable { +public: + /** + * Constructor. + * Initializes the page table with allocation and release callbacks. + * + * @param [in] allocCb Callback for allocating page directories. + * @param [in] releaseCb Callback for releasing page directories. + */ + explicit NetPgTable(PgDirAllocCb allocCb, PgDirReleaseCb releaseCb) + : mRegionCount(0), + mRootEntry {}, + mVirBaseAddr(0), + mSpaceMask(0), + mIndexShift(PAGE_SHIFT_MIN) + { + static_assert(IsPowerOfTwo(PGT_ENTRY_MIN_ALIGN)); + + static_assert(IsPowerOfTwoOrZero(PGT_ADDR_MAX + 1)); + // We must cover all bits of the address up to ADDR_MAX + static_assert(((Log2Static(PGT_ADDR_MAX) + 1 - PAGE_SHIFT_MIN) % PTE_SHIFT_PER_DIR) == 0); + + if (allocCb == nullptr || releaseCb == nullptr) { + throw std::invalid_argument("invalid param, directory allocate or release callback is null"); + } + + mSpaceMask = (static_cast(-1)) << PAGE_SHIFT_MIN; + + mPgdAllocCb = allocCb; + mPgdReleaseCb = releaseCb; + } + + /** + * DeConstructor. + */ + ~NetPgTable() + { + if (mRootEntry.IsPresent()) { + try { + Cleanup(); + } catch (const std::exception& ex) { + NN_LOG_ERROR("NetPgTable DeConstructor caught exception in Cleanup: " << ex.what()); + } + } + mPgdAllocCb = nullptr; + mPgdReleaseCb = nullptr; + } + + /** + * Add a memory region to the page table. + * + * @param [in] region Memory region to insert. The region must remain valid + * and unchanged as long as it's in the page table. + * + * @return NN_OK - region was added. + * NN_INVALID_PARAM - memory region address is invalid (misaligned or empty) + */ + NResult Insert(PgtRegion ®ion); + + /** + * Remove a memory region from the page table. + * + * @param [in] region Memory region to remove. This must be the same pointer passed to Insert. + * @return NN_OK - region was removed. + * NN_INVALID_PARAM - memory region address is invalid (misaligned or empty) + */ + NResult Remove(PgtRegion ®ion); + + /** + * Find a region which contains the given address. + * + * @param [in] address Address to search. + * @return Pointer to the region which contains 'address', or nullptr if not found. + */ + PgtRegion *Lookup(PgtAddress address) const; + + /** + * Search for all regions overlapping with a given address range. + * + * @param [in] from Lower bound of the range. + * @param [in] to Upper bound of the range (inclusive). + * @param [in] cb Callback to be called for every region found. + * The callback must not modify the page table. + * @param [in] arg User-defined argument to the callback. + */ + void SearchRange(PgtAddress from, PgtAddress to, PgtSearchCb cb, void *arg) const; + + /** + * Remove all regions from the page table and call the provided callback for each. + */ + void Cleanup(); + + /** + * Dump page table to log. + */ + void Dump(); + + DEFINE_RDMA_REF_COUNT_FUNCTIONS; + +private: + NResult InsertPage(PgtAddress address, uint32_t order, PgtRegion ®ion); + + NResult RemovePage(PgtAddress address, uint32_t order, PgtRegion ®ion); + NResult UnlinkRegion(PgtAddress address, uint32_t order, PgtDir &pgd, PgtEntry &pte, uint32_t shift, + PgtRegion ®ion); + + void SearchRange(PgtAddress from, PgtAddress to, PgtSearchCb cb, void *arg); + void SearchSubtree(PgtAddress address, uint32_t order, const PgtEntry &pte, uint32_t shift, PgtSearchCb cb, + void *arg, PgtRegion *&lastRegion); + + void PgtEnsureCapacity(uint32_t order, PgtAddress address); + bool PgtExpand(); + bool PgtShrink(); + + NResult PgtCheckEntryDir(PgtEntry &pte, uint32_t shift, uint32_t order); + void PgTableReset(); + + PgtDir *PgtDirAlloc(); + void PgtDirRelease(PgtDir *pgd); + + void PgtDumpSubtree(uint32_t indent, const PgtEntry &pte, uint32_t pteIndex, PgtAddress base, PgtAddress mask, + uint32_t shift); + + PgtEntry mRootEntry {}; + PgtAddress mVirBaseAddr = 0; + PgtAddress mSpaceMask = 0; + uint32_t mIndexShift = 0; + uint32_t mRegionCount = 0; + + PgDirAllocCb mPgdAllocCb = nullptr; + PgDirReleaseCb mPgdReleaseCb = nullptr; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; +} +} + +#endif diff --git a/src/common/net_rb_tree.h b/src/common/net_rb_tree.h new file mode 100644 index 0000000000000000000000000000000000000000..5ab301092a43c05a23285be9ac8775d94ee7ba40 --- /dev/null +++ b/src/common/net_rb_tree.h @@ -0,0 +1,597 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef LOCKFREES_RBTREEWRAPPER_H +#define LOCKFREES_RBTREEWRAPPER_H + +namespace ock { +namespace hcom { +enum NetRbColor : uint8_t { + RB_RED, + RB_BLACK +}; + +/* + * @brief red-black tree node + * @tparam T data type stored in tree node + */ +template struct NetRbNode { + T data; + + uint64_t rbParentColor = 0; + NetRbNode *left = nullptr; + NetRbNode *right = nullptr; + + NetRbNode() + { + data = T(); + } + + explicit NetRbNode(T _data) : data(_data) + { + ClearParent(); + } + + inline NetRbNode *GetParent() const + { + return reinterpret_cast *>(rbParentColor & ~0x3); + } + + inline NetRbColor GetColor() const + { + return static_cast(rbParentColor & 1); + } + + inline bool IsBlack() const + { + return GetColor() == RB_BLACK; + } + + inline bool IsRed() const + { + return GetColor() == RB_RED; + } + + inline void SetBlack() + { + rbParentColor |= RB_BLACK; + } + + inline void SetRed() + { + rbParentColor &= ~1; + } + + inline void SetParent(NetRbNode *parent) + { + rbParentColor = (rbParentColor & 3UL) | reinterpret_cast(parent); + } + + inline void SetColor(uint64_t color) + { + rbParentColor = (rbParentColor & ~1) | color; + } + + inline bool IsOrphan() const + { + return GetParent() == this; + } + + inline void ClearParent() + { + SetParent(this); + } + + /* + * @brief Get prev node according to inorder traverse + */ + inline NetRbNode *Prev() + { + if (this->GetParent() == this) { + return nullptr; + } + + auto node = this; + + if (node->left) { + node = node->left; + + while (node->right) { + node = node->right; + } + } + + return node; + } + + /* + * @brief Get successor node according to inorder traverse + */ + inline NetRbNode *Next() + { + NetRbNode *parent = nullptr; + + if (IsOrphan()) { + return nullptr; + } + + auto node = this; + + if (node->right) { + node = node->right; + + while (node->left) { + node = node->left; + } + + return node; + } + + while ((parent = node->GetParent()) && node == parent->right) { + node = parent; + } + + return parent; + } + + /* + * @brief Link current node to parent + */ + inline void Link(NetRbNode *parent, NetRbNode **link) + { + if (NN_UNLIKELY(link == nullptr)) { + return; + } + rbParentColor = reinterpret_cast(parent); + left = nullptr; + right = nullptr; + *link = this; + } +} __attribute__((aligned(sizeof(long)))); + +template struct NetRbTree { + NetRbNode *ref = nullptr; + + inline void RotateLeft(NetRbNode *node) + { + if (NN_UNLIKELY(node == nullptr)) { + return; + } + auto right = node->right; + auto parent = node->GetParent(); + + if ((node->right = right->left)) { + right->left->SetParent(node); + } + + right->left = node; + right->SetParent(parent); + + if (parent) { + if (node == parent->left) { + parent->left = right; + } else { + parent->right = right; + } + } else { + ref = right; + } + node->SetParent(right); + } + + inline void RotateRight(NetRbNode *node) + { + if (NN_UNLIKELY(node == nullptr)) { + return; + } + auto left = node->left; + auto parent = node->GetParent(); + + if ((node->left = left->right)) { + left->right->SetParent(node); + } + + left->right = node; + left->SetParent(parent); + + if (parent) { + if (node == parent->right) { + parent->right = left; + } else { + parent->left = left; + } + } else { + ref = left; + } + + node->SetParent(left); + } + + /* + * @brief subroutine of Insert when the node to insert is on left side of parent + */ + inline bool InsertLeft(NetRbNode *&node, NetRbNode *&parent, NetRbNode *&gparent) + { + if (NN_UNLIKELY(parent == nullptr || gparent == nullptr)) { + return false; + } + NetRbNode *uncle = gparent->right; + + /* if both parent and uncle is red, grandparent must be blacked,we just + * transfer grandparent's black to parent and uncle, then go upwards to + * check grandparent further more */ + if (uncle && uncle->IsRed()) { + uncle->SetBlack(); + parent->SetBlack(); + gparent->SetRed(); + node = gparent; + return true; + } + + /* uncle is black,only recoloring will cause uncle lose 1 bh, which means bh imbalance, + * we try to make node's position to parent is same as parent's position to grandparent + * in current branch, parent is left child,node is right child,we call this LR type, + * we left-rotate parent to exchange node and parent, which result in + * both node and parent is their parents' left child, which called LL type + * LL type(same as RR type in symmetry situation) structure is easier to rebalance */ + if (parent->right == node) { + NetRbNode *tmp = nullptr; + RotateLeft(parent); + tmp = parent; + parent = node; + node = tmp; + } + + /* now the type is LL, we push down the grandparent's black to parent and rotate parent to be + * grandparent, balancing finished */ + parent->SetBlack(); + gparent->SetRed(); + RotateRight(gparent); + return false; + } + + inline bool InsertRight(NetRbNode *&node, NetRbNode *&parent, NetRbNode *&gparent) + { + if (NN_UNLIKELY(parent == nullptr || gparent == nullptr)) { + return false; + } + NetRbNode *uncle = gparent->left; + + if (uncle && uncle->IsRed()) { + uncle->SetBlack(); + parent->SetBlack(); + gparent->SetRed(); + node = gparent; + return true; + } + + if (parent->left == node) { + NetRbNode *tmp = nullptr; + RotateRight(parent); + tmp = parent; + parent = node; + node = tmp; + } + + parent->SetBlack(); + gparent->SetRed(); + RotateLeft(gparent); + return false; + } + + /* + * @brief insert node to rbt,this routine does not include searching process,before calling this routine, + * caller needs to search the correct parent, and call Link to link node to its parent + * + * @param node + */ + inline void Insert(NetRbNode *node) + { + if (NN_UNLIKELY(node == nullptr)) { + return; + } + NetRbNode *parent = nullptr; + NetRbNode *gparent = nullptr; + + /* go upwards until there is no continuous red child & parent, since parent + * is red, grandparent must be black */ + while ((parent = node->GetParent()) && parent->IsRed()) { + gparent = parent->GetParent(); + if (parent == gparent->left) { + if (InsertLeft(node, parent, gparent)) { + continue; + } + } else { + if (InsertRight(node, parent, gparent)) { + continue; + } + } + } + if (ref == nullptr) { + return; + } + ref->SetBlack(); + } + + /* + * @brief subroutine of EraseColor when the node is on left side of parent + */ + inline bool EraseColorLeft(NetRbNode *&node, NetRbNode *&parent) + { + if (NN_UNLIKELY(parent == nullptr)) { + return false; + } + auto other = parent->right; + + if (NN_UNLIKELY(other == nullptr)) { + return false; + } + + /* as sibling is red, we get Rr__ type,which we can convert to Rb__, + * just black sibling, red parent and left rotate parent */ + if (other->IsRed()) { + other->SetBlack(); + parent->SetRed(); + RotateLeft(parent); + other = parent->right; + } + + /* now, the type must be Rb__, then we determine nephew type furtherly, + * no children or only has black children,the final type is Rb,just red sibling and go upwards */ + if ((!other->left || other->left->IsBlack()) && (!other->right || other->right->IsBlack())) { + other->SetRed(); + node = parent; + parent = node->GetParent(); + } else { + /* no red right nephew, type is RbLr,do red sibling,black left nephew and + * left rotate sibling, then update sibling, thus type converted to RbRr */ + if (!other->right || other->right->IsBlack()) { + other->left->SetBlack(); + other->SetRed(); + RotateRight(other); + other = parent->right; + } + + /* type RbRr, we recompense deleted bh 1 to current left path by blacking and rotating parent to + * left side, which may result in lacking bh 1 for right path when parent is already blacked, so + * we also color sibling by parent's color,till now,the visited subtree has same bh and root + * color as before this routine,no need for further upward checking,just break */ + other->SetColor(parent->GetColor()); + parent->SetBlack(); + if (NN_UNLIKELY(!other->right)) { + return false; + } + other->right->SetBlack(); + RotateLeft(parent); + node = ref; + return true; + } + return false; + } + + inline bool EraseColorRight(NetRbNode *&node, NetRbNode *&parent) + { + if (NN_UNLIKELY(parent == nullptr)) { + return false; + } + auto other = parent->left; + + if (NN_UNLIKELY(other == nullptr)) { + return false; + } + + if (other->IsRed()) { + other->SetBlack(); + parent->SetRed(); + RotateRight(parent); + other = parent->left; + } + + if ((!other->left || other->left->IsBlack()) && (!other->right || other->right->IsBlack())) { + other->SetRed(); + node = parent; + parent = node->GetParent(); + } else { + if (!other->left || other->left->IsBlack()) { + other->right->SetBlack(); + other->SetRed(); + RotateLeft(other); + other = parent->left; + } + + other->SetColor(parent->GetColor()); + parent->SetBlack(); + if (NN_UNLIKELY(!other->left)) { + return false; + } + other->left->SetBlack(); + RotateRight(parent); + node = ref; + return true; + } + return false; + } + + /* + * @brief rebalancing red-black tree after deleting black node + * @param node deleted node + * @param parent parent of deleted node + * + * Only take deleting left son in account + * For the rebalancing logic is symmetrically identical to another side + * Def. Only 4 types satify rbt laws and affect rebalancing process + * _1_2_3_4 _1 stands for sibling side,_2 stands for sibling color,_3 stands for nephew side,_4 stands for nephew + * color Notice! absent of _3 and _4 means sibling has no children or every child is black The 4 types are + * Rb/Rr/RbLr/RbRr,Rr(__) can be converted to Rb(__),which is simpler to rebalance, the same as converting RbLr to + * RbRr + */ + inline void EraseColor(NetRbNode *node, NetRbNode *parent) + { + /* rebalancing upwards until node reached root or became red */ + while ((node == nullptr || node->IsBlack()) && node != ref) { + if (NN_UNLIKELY(parent == nullptr)) { + return; + } + if (parent->left == node) { + if (EraseColorLeft(node, parent)) { + break; + } + } else { + if (EraseColorRight(node, parent)) { + break; + } + } + } + + /* In few cases we can get an node to receive the deleted black color, + * then we just black it + * this operation may be invalid when node is already blacked, + * which means we reached root node,also means we complete rebalance, so + * this invalid blacking do no harm */ + if (node) { + node->SetBlack(); + } + } + + /* + * @brief subroutine of Erase when node has two children + */ + inline void EraseWithTwoChildren(NetRbNode *&node, NetRbNode *&parent, NetRbNode *&child, + NetRbColor &color) + { + NetRbNode *old = node; + NetRbNode *left = nullptr; + + node = node->right; + while ((left = node->left) != nullptr) { + node = left; + } + + parent = old->GetParent(); + if (parent) { + if (parent->left == old) { + parent->left = node; + } else { + parent->right = node; + } + } else { + ref = node; + } + + child = node->right; + parent = node->GetParent(); + color = node->GetColor(); + + if (parent == old) { + parent = node; + } else { + if (child) { + child->SetParent(parent); + } + + parent->left = child; + node->right = old->right; + old->right->SetParent(node); + } + + node->rbParentColor = old->rbParentColor; + node->left = old->left; + old->left->SetParent(node); + } + + /* + * @brief delete node from rbt,the process is same as binary search tree deleting, + * except we may do some recoloring and rotation later + * + * @param node the node to delete + */ + inline void Erase(NetRbNode *node) + { + if (NN_UNLIKELY(node == nullptr)) { + return; + } + NetRbNode *child = nullptr; + NetRbNode *parent = nullptr; + NetRbColor color; + + if (!node->left) { + child = node->right; + } else if (!node->right) { + child = node->left; + } else { + EraseWithTwoChildren(node, parent, child, color); + + // fix: coloring label is at bottom of the routine + if (color == RB_BLACK) { + EraseColor(child, parent); + } + return; + } + + parent = node->GetParent(); + color = node->GetColor(); + + if (child) { + child->SetParent(parent); + } + + if (parent) { + if (parent->left == node) { + parent->left = child; + } else { + parent->right = child; + } + } else { + ref = child; + } + + /* only delete black node corrupt rbt,we need do rebalancing */ + if (color == RB_BLACK) { + EraseColor(child, parent); + } + } + + /* + * @param victim the node to be replaced + * @param newNode the node to replace victim + */ + inline void Replace(NetRbNode *victim, NetRbNode *newNode) + { + if (NN_UNLIKELY(victim == nullptr || newNode == nullptr)) { + return; + } + auto parent = victim->GetParent(); + if (parent) { + if (victim == parent->left) { + parent->left = newNode; + } else { + parent->right = newNode; + } + } else { + ref = newNode; + } + + if (victim->left) { + victim->left->SetParent(newNode); + } + + if (victim->right) { + victim->right->SetParent(newNode); + } + + newNode->rbParentColor = victim->rbParentColor; + newNode->left = victim->left; + newNode->right = victim->right; + } +}; +} +} +#endif // LOCKFREES_RBTREEWRAPPER_H diff --git a/src/common/net_security_alg.cpp b/src/common/net_security_alg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a75c29b060a3fb834aba1ad88ac19287c124c33f --- /dev/null +++ b/src/common/net_security_alg.cpp @@ -0,0 +1,235 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "net_security_alg.h" +#include "net_security_rand.h" +#include "openssl_api_wrapper.h" + +namespace ock { +namespace hcom { +const EVP_CIPHER *GetEvpCipherSuite(UBSHcomNetCipherSuite mCipherSuite) +{ + const EVP_CIPHER *cipherSuite; + switch (mCipherSuite) { + case AES_GCM_128: + cipherSuite = HcomSsl::EvpAes128Gcm(); + break; + case AES_GCM_256: + cipherSuite = HcomSsl::EvpAes256Gcm(); + break; + case AES_CCM_128: + cipherSuite = HcomSsl::EvpAes128Ccm(); + break; + case CHACHA20_POLY1305: + cipherSuite = HcomSsl::EvpChacha20Poly1305(); + break; + default: + cipherSuite = HcomSsl::EvpAes128Gcm(); + } + return cipherSuite; +} + +NResult AesGcm128::SetEncryptInfo(EVP_CIPHER_CTX *ctx, const unsigned char *key, unsigned char *cipher) +{ + if (HcomSsl::EvpEncryptInitEx(ctx, GetEvpCipherSuite(mCipherSuite), nullptr, nullptr, nullptr) <= 0) { + NN_LOG_ERROR("EvpEncryptInitEx() failed"); + return NN_ENCRYPT_FAILED; + } + + /* Put IV */ + if (!SecurityRandGenerator::SslRand(cipher + mIVOffset, mIVLen)) { + NN_LOG_ERROR("Generate IV failed"); + return NN_ENCRYPT_FAILED; + } + + if (HcomSsl::EvpCipherCtxCtrl(ctx, HcomSsl::EVP_CTRL_AEAD_SET_IVLEN, mIVLen, nullptr) <= 0) { + NN_LOG_ERROR("EvpCipherCtxCtrl() failed"); + return NN_ENCRYPT_FAILED; + } + + if (mCipherSuite == AES_CCM_128) { + /* if CipherSuite is AES_CCM_128, need set tag */ + if (HcomSsl::EvpCipherCtxCtrl(ctx, HcomSsl::EVP_CTRL_AEAD_SET_TAG, mTagLen, nullptr) <= 0) { + NN_LOG_ERROR("Set TAG failed"); + return NN_ENCRYPT_FAILED; + } + } + + if (HcomSsl::EvpEncryptInitEx(ctx, nullptr, nullptr, key, cipher + mIVOffset) <= 0) { + NN_LOG_ERROR("EvpEncryptInitEx() failed"); + return NN_ENCRYPT_FAILED; + } + return NN_OK; +} + +NResult AesGcm128::EncryptInner(const unsigned char *key, const unsigned char *aad, const unsigned char *rawData, + uint32_t rawLen, unsigned char *cipher, uint32_t &cipherLen) +{ + EVP_CIPHER_CTX *ctx = HcomSsl::EvpCipherCtxNew(); + if (ctx == nullptr) { + NN_LOG_ERROR("EvpCipherCtxNew() alloc memory failed!"); + return NN_ENCRYPT_FAILED; + } + + if (SetEncryptInfo(ctx, key, cipher) != NN_OK) { + HcomSsl::EvpCipherCtxFree(ctx); + return NN_ENCRYPT_FAILED; + } + + int outLen = 0; + if (mCipherSuite == AES_CCM_128) { + /* if CipherSuite is AES_CCM, set plaintext length: only needed if AAD is used */ + /* AES_GCM and CHACHA20_POLY1305 automatically handle the plaintext length through their internal mechanisms */ + if (HcomSsl::EvpEncryptUpdate(ctx, nullptr, &outLen, nullptr, rawLen) <= 0) { + NN_LOG_ERROR("EVP_EncryptUpdate() set plaintext length failed"); + goto ERROR_FREE; + } + } + + if (memcpy_s(cipher + mAADOffset, mAADLen, aad, mAADLen) != 0) { + NN_LOG_ERROR("Failed to copy request to mrBufAddress"); + goto ERROR_FREE; + } + + if (HcomSsl::EvpEncryptUpdate(ctx, nullptr, &outLen, cipher + mAADOffset, mAADLen) <= 0) { + NN_LOG_ERROR("EvpEncryptUpdate() AAD failed"); + goto ERROR_FREE; + } + + if (HcomSsl::EvpEncryptUpdate(ctx, cipher + mCipherOffset, &outLen, rawData, rawLen) <= 0) { + NN_LOG_ERROR("EvpEncryptUpdate() raw data failed"); + goto ERROR_FREE; + } + cipherLen = static_cast(outLen); + + if (HcomSsl::EvpEncryptFinalEx(ctx, cipher + mCipherOffset + outLen, &outLen) <= 0) { + NN_LOG_ERROR("EvpEncryptFinalEx() raw data failed"); + goto ERROR_FREE; + } + + /* Final should make outLen to zero */ + if (outLen != 0) { + NN_LOG_ERROR("EvpEncryptFinalEx() raw data failed as out len should be zero"); + goto ERROR_FREE; + } + + /* Add the prefix data in cipher format because cipher is same with plain */ + if (HcomSsl::EvpCipherCtxCtrl(ctx, HcomSsl::EVP_CTRL_AEAD_GET_TAG, mTagLen, cipher + mTagOffset) <= 0) { + NN_LOG_ERROR("Generate TAG failed"); + goto ERROR_FREE; + } + + cipherLen += static_cast(mCipherOffset); + HcomSsl::EvpCipherCtxFree(ctx); + NN_LOG_TRACE_INFO("Encrypt data rawLen :" << rawLen << " cipherLen: " << cipherLen); + return NN_OK; +ERROR_FREE: + HcomSsl::EvpCipherCtxFree(ctx); + return NN_ENCRYPT_FAILED; +} + +NResult AesGcm128::SetDecryptInfo(EVP_CIPHER_CTX *ctx, const unsigned char *key, const unsigned char *cipher) +{ + if (HcomSsl::EvpDecryptInitEx(ctx, GetEvpCipherSuite(mCipherSuite), nullptr, nullptr, nullptr) <= 0) { + NN_LOG_ERROR("EvpDecryptInitEx() failed"); + HcomSsl::EvpCipherCtxFree(ctx); + return NN_ENCRYPT_FAILED; + } + + if (HcomSsl::EvpCipherCtxCtrl(ctx, HcomSsl::EVP_CTRL_AEAD_SET_IVLEN, mIVLen, nullptr) <= 0) { + NN_LOG_ERROR("Set IV length failed"); + HcomSsl::EvpCipherCtxFree(ctx); + return NN_ENCRYPT_FAILED; + } + + if (mCipherSuite == AES_CCM_128) { + /* if CipherSuite is AES_CCM_128, need set tag */ + if (HcomSsl::EvpCipherCtxCtrl(ctx, HcomSsl::EVP_CTRL_AEAD_SET_TAG, mTagLen, + const_cast(cipher + mTagOffset)) <= 0) { + NN_LOG_ERROR("Set TAG failed"); + HcomSsl::EvpCipherCtxFree(ctx); + return NN_ENCRYPT_FAILED; + } + } + + if (HcomSsl::EvpDecryptInitEx(ctx, nullptr, nullptr, key, cipher + mIVOffset) <= 0) { + NN_LOG_ERROR("EvpDecryptInitEx() failed"); + HcomSsl::EvpCipherCtxFree(ctx); + return NN_ENCRYPT_FAILED; + } + return NN_OK; +} + +NResult AesGcm128::DecryptInner(const unsigned char *key, const unsigned char *cipher, uint32_t cipherLen, + unsigned char *rawData, uint32_t &rawLen) +{ + EVP_CIPHER_CTX *ctx = HcomSsl::EvpCipherCtxNew(); + if (ctx == nullptr) { + NN_LOG_ERROR("EvpCipherCtxNew() alloc memory failed!"); + return NN_DECRYPT_FAILED; + } + + if (SetDecryptInfo(ctx, key, cipher) != NN_OK) { + HcomSsl::EvpCipherCtxFree(ctx); + return NN_DECRYPT_FAILED; + } + + int outLen = 0; + if (mCipherSuite == AES_CCM_128) { + /* if CipherSuite is AES_CCM_128, set cipher length: only needed if AAD is used */ + if (HcomSsl::EvpDecryptUpdate(ctx, nullptr, &outLen, nullptr, cipherLen - mCipherOffset) <= 0) { + NN_LOG_ERROR("EvpDecryptUpdate() set cipher length failed"); + HcomSsl::EvpCipherCtxFree(ctx); + return NN_DECRYPT_FAILED; + } + } + + if (HcomSsl::EvpDecryptUpdate(ctx, nullptr, &outLen, cipher + mAADOffset, mAADLen) <= 0) { + NN_LOG_ERROR("EvpDecryptUpdate() AAD failed"); + HcomSsl::EvpCipherCtxFree(ctx); + return NN_DECRYPT_FAILED; + } + + if (HcomSsl::EvpDecryptUpdate(ctx, rawData, &outLen, cipher + mCipherOffset, cipherLen - mCipherOffset) <= 0) { + NN_LOG_ERROR("EvpDecryptUpdate() cipher data failed"); + HcomSsl::EvpCipherCtxFree(ctx); + return NN_DECRYPT_FAILED; + } + rawLen = static_cast(outLen); + + if (mCipherSuite != AES_CCM_128) { + if (HcomSsl::EvpCipherCtxCtrl(ctx, HcomSsl::EVP_CTRL_AEAD_SET_TAG, mTagLen, + const_cast(cipher + mTagOffset)) <= 0) { + NN_LOG_ERROR("Set TAG failed"); + HcomSsl::EvpCipherCtxFree(ctx); + return NN_DECRYPT_FAILED; + } + } + + /* If don't check TAG, the EvpDecryptFinalEx() will always return 0, ignore that error */ + if (HcomSsl::EvpDecryptFinalEx(ctx, rawData, &outLen) <= 0) { + NN_LOG_WARN("EvpDecryptFinalEx() cipher data unfinished"); + } + + /* DecryptFinal should make it to zero */ + if (outLen != 0) { + NN_LOG_WARN("EvpDecryptFinalEx() cipher data failed as outLen is zero"); + HcomSsl::EvpCipherCtxFree(ctx); + return NN_DECRYPT_FAILED; + } + + HcomSsl::EvpCipherCtxFree(ctx); + NN_LOG_TRACE_INFO("Decrypt data rawLen :" << rawLen << " cipherLen: " << cipherLen); + return NN_OK; +} +} +} \ No newline at end of file diff --git a/src/common/net_security_alg.h b/src/common/net_security_alg.h new file mode 100644 index 0000000000000000000000000000000000000000..b84f25210df09c46a1c38ecceb4983496fc9a7a6 --- /dev/null +++ b/src/common/net_security_alg.h @@ -0,0 +1,176 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef OCK_HCOM_SECURITY_ALG_H +#define OCK_HCOM_SECURITY_ALG_H + +#include "net_common.h" +#include "net_security_rand.h" +#include "openssl_api_wrapper.h" + +namespace ock { +namespace hcom { +class AesGcm128 { +public: + /* + * @brief Encrypt @b rawData to @b cipher, default use AES-GCM-128 + * + * @param secret : the secret for encrypt + * @param rawData : the raw data for encrypt + * @param rawLen : the length of raw data + * @param cipher : the cipher data + * @param cipherLen : the length of cipher data + * + * @return true for success + * @note : user can call EstimatedEncryptLen() for @b cipher memory length allocating, + * and should make sure @b cipher is enough for write + */ + bool Encrypt(NetSecrets &secrets, const void *rawData, uint32_t rawLen, void *cipher, uint32_t &cipherLen) + { + if (NN_UNLIKELY(secrets.GetKeySecret() == nullptr) || NN_UNLIKELY(rawData == nullptr) || + NN_UNLIKELY(rawLen == 0) || NN_UNLIKELY(cipher == nullptr)) { + NN_LOG_ERROR("Failed to encrypt as invalid params."); + return false; + } + + int32_t ret; + uint32_t estimateLen = EstimatedEncryptLen(rawLen); + cipherLen = estimateLen; + + // openssl api the type of len is int + if (NN_UNLIKELY(rawLen > INT_MAX)) { + NN_LOG_ERROR("invalid rawLen " << rawLen); + return false; + } + ret = EncryptInner(static_cast(secrets.GetKeySecret()), + static_cast(secrets.GetAADSecret()), static_cast(rawData), + rawLen, static_cast(cipher), cipherLen); + if (NN_UNLIKELY(ret != 0) || NN_UNLIKELY(cipherLen != estimateLen)) { + NN_LOG_ERROR("Failed to encrypt as ret:" << ret << " cipher length:" << cipherLen << + " estimateLen length:" << estimateLen); + return false; + } + return true; + } + + NResult EncryptInner(const unsigned char *key, const unsigned char *aad, const unsigned char *rawData, + uint32_t rawLen, unsigned char *cipher, uint32_t &cipherLen); + + NResult SetEncryptInfo(EVP_CIPHER_CTX *ctx, const unsigned char *key, unsigned char *cipher); + /* + * @brief Decrypt @b cipher to @b raw, default use AES-GCM-128 + * + * @param key : the private-key for decrypt, length is @b mKeyLen + * @param cipher : the cipher data + * @param cipherLen : the length of cipher data + * @param rawData : the raw data for encrypt + * @param rawLen : the length of raw data + * + * @return true for success + * @note : user should make sure @b rawData is enough for write + */ + bool Decrypt(NetSecrets &secrets, const void *cipher, uint32_t cipherLen, void *rawData, uint32_t &rawLen) + { + if (NN_UNLIKELY(secrets.GetKeySecret() == nullptr) || NN_UNLIKELY(cipher == nullptr) || + NN_UNLIKELY(cipherLen == 0) || NN_UNLIKELY(rawData == nullptr)) { + NN_LOG_ERROR("Invalid params"); + return false; + } + + int32_t ret; + uint32_t estimateLen = GetRawLen(cipherLen); + rawLen = estimateLen; + + // openssl api the type of len is int + if (NN_UNLIKELY(cipherLen <= mCipherOffset) || NN_UNLIKELY(cipherLen - mCipherOffset > INT_MAX)) { + NN_LOG_ERROR("invalid cipherLen " << cipherLen); + return false; + } + ret = DecryptInner(static_cast(secrets.GetKeySecret()), + static_cast(cipher), cipherLen, static_cast(rawData), rawLen); + if (NN_UNLIKELY(ret != 0) || NN_UNLIKELY(rawLen != estimateLen)) { + NN_LOG_ERROR("Failed to decrypt as ret:" << ret << " raw length:" << rawLen << + " estimateLen length:" << estimateLen); + return false; + } + return true; + } + + NResult DecryptInner(const unsigned char *key, const unsigned char *cipher, uint32_t cipherLen, + unsigned char *rawData, uint32_t &rawLen); + + NResult SetDecryptInfo(EVP_CIPHER_CTX *ctx, const unsigned char *key, const unsigned char *cipher); + + /* + * @brief Estimated the cipher length, it will be greater than or equal to real cipher length + */ + inline uint32_t EstimatedEncryptLen(uint32_t rawLen) const + { + auto cipherLen = static_cast(mCipherOffset); + if (NN_UNLIKELY(rawLen == 0 || rawLen > UINT32_MAX - cipherLen)) { + NN_LOG_ERROR("Failed to estimate ep encrypt raw length invalid"); + return 0; + } + cipherLen += rawLen; + return cipherLen; + } + + inline uint32_t GetRawLen(uint32_t cipherLen) const + { + if (cipherLen <= static_cast(mCipherOffset)) { + return 0; + } + return cipherLen - static_cast(mCipherOffset); + } + + inline void SetEncryptOptions(UBSHcomNetCipherSuite cipherSuite) + { + mCipherSuite = cipherSuite; + if (cipherSuite == AES_GCM_128 || cipherSuite == AES_CCM_128) { + mKeyLen = NN_NO16; + } else if (cipherSuite == AES_GCM_256 || cipherSuite == CHACHA20_POLY1305) { + mKeyLen = NN_NO32; + } else { + NN_LOG_WARN("Invalid to set encrypt options, because unknown cipher suite, use default one."); + } + } + +private: + UBSHcomNetCipherSuite mCipherSuite = AES_GCM_128; + /* + * cipher data format :|IV||AAD (opt)|TAG(opt)|CIPHER DATA| + * ------------ Bytes :|12 | 16 | 16 | ? | + * ------------------- |-> mIVOffset |-> mCipherOffset + * ----------------------- |-> mAADOffset + * --------------------------------- |-> mTagOffset + * + * default dont use AAD and TAG, so it is same + */ + + /* Key use 32Bytes cipher len 256, 16Bytes cipher len 128 */ + int mKeyLen = NN_NO16; + /* IV use 12Bytes(96bits) */ + const int mIVLen = NN_NO12; + /* AAD use 16Bytes(128bits) */ + const int mAADLen = NN_NO16; + /* Tag use 16Bytes(128bits) */ + const int mTagLen = NN_NO16; + + off_t mIVOffset = 0; + off_t mAADOffset = mIVOffset + mIVLen; + off_t mTagOffset = mAADOffset + mAADLen; + off_t mCipherOffset = mTagOffset + mTagLen; +}; +} +} + +#endif \ No newline at end of file diff --git a/src/common/net_security_rand.h b/src/common/net_security_rand.h new file mode 100644 index 0000000000000000000000000000000000000000..3b9ce1fb79e274c578d540dea241a5fe5d0f4cd6 --- /dev/null +++ b/src/common/net_security_rand.h @@ -0,0 +1,242 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef OCK_HCOM_SECURITY_RAND_H +#define OCK_HCOM_SECURITY_RAND_H + +#include +#include +#include +#include +#include +#include + +#include "securec.h" +#include "openssl_api_wrapper.h" + +namespace ock { +namespace hcom { +/* + * Because OpenSSL will reseed by itself, so there + * we don't reseed it by manual + */ +class SecurityRandGenerator { +public: + /* + * @brief generate @b length random data to @b out + */ + static bool SslRand(void *out, size_t length) + { + int ret; + auto outBuf = static_cast(out); + const uint16_t maxTryCount = NN_NO1000; + uint16_t tryCount = 0; + + if (out == nullptr || length == 0) { + return false; + } + + while (tryCount <= maxTryCount) { + Poll(); + tryCount++; + + ret = HcomSsl::RandPrivBytes(outBuf, length); + if (ret <= 0) { + NN_LOG_TRACE_INFO("Failed to generate secure rand, result " << ret); + continue; + } else { + NN_LOG_TRACE_INFO("Successfully generate secure random num, length: " << length); + return true; + } + } + + NN_LOG_ERROR("Failed to generate secure rand after tried " << tryCount << " times"); + return false; + } + +private: + /* + * @brief RandPoll() to get enough rand + */ + static void Poll() + { + const uint16_t maxPollCount = NN_NO1000; + uint16_t pollCount = 0; + + while (HcomSsl::RandStatus() <= 0 && pollCount < maxPollCount) { + pollCount++; + NN_LOG_TRACE_INFO("Rand start to poll"); + HcomSsl::RandPoll(); + } + } +}; + +class NetSecrets { +public: + NetSecrets() = default; + ~NetSecrets() + { + bzero(mKeySecret, NN_NO32 * sizeof(char)); + bzero(mAADSecret, NN_NO32 * sizeof(char)); + bzero(mIVSecret, NN_NO32 * sizeof(char)); + }; + + inline const void *GetKeySecret() const + { + return mKeySecret; + } + + inline const void *GetAADSecret() const + { + return mAADSecret; + } + + inline const void *GetIVSecret() const + { + return mIVSecret; + } + + bool Init(UBSHcomNetCipherSuite cipherSuite) + { + if (cipherSuite == AES_GCM_128 || cipherSuite == AES_CCM_128) { + mKeySecretLen = NN_NO16; + } else if (cipherSuite == AES_GCM_256 || cipherSuite == CHACHA20_POLY1305) { + mKeySecretLen = NN_NO32; + } else { + NN_LOG_ERROR("Failed to init secret, because unknown cipher suite."); + return false; + } + + return InitSSLRandSecret(); + } + /* + * @brief update the secret, generate new sn and new secert + * + * because now RAND is use static method, so it call SslRand() iternal + * + * @return true for success, false for failed + */ + bool InitSSLRandSecret() + { + mAADSecretLen = NN_NO16; + mIVSecretLen = NN_NO12; + + if (!SecurityRandGenerator::SslRand(mKeySecret, mKeySecretLen)) { + NN_LOG_WARN("Update keySecret failed"); + return false; + } + + if (!SecurityRandGenerator::SslRand(mAADSecret, mAADSecretLen)) { + NN_LOG_WARN("Update mAADSecret failed"); + return false; + } + + if (!SecurityRandGenerator::SslRand(mIVSecret, mIVSecretLen)) { + NN_LOG_WARN("Update mIVSecret failed"); + return false; + } + return true; + } + + /* + * @brief Format:|SN| |KeySecret| |AADSecret| |IVSecret| + * + * Bytes: mSN + * mKeySecretLen + * mAADSecrete + * mIVSecret + */ + inline size_t GetSerializeLen() const + { + return sizeof(uint8_t) + mKeySecretLen + mAADSecretLen + mIVSecretLen; + } + + inline bool Serialize(char *dest, size_t len) const + { + if (NN_UNLIKELY(dest == nullptr) || NN_UNLIKELY(len != GetSerializeLen())) { + NN_LOG_ERROR("Invalid param secret is null or length:" << len << " is not equal to serialized len:" << + GetSerializeLen()); + return false; + } + + if (memcpy_s(dest, sizeof(uint8_t), &mSN, sizeof(uint8_t)) != EOK) { + NN_LOG_ERROR("memcpy_s sn failed."); + return false; + } + + if (memcpy_s(dest + sizeof(uint8_t), mKeySecretLen, mKeySecret, mKeySecretLen) != EOK) { + NN_LOG_ERROR("memcpy_s key failed."); + return false; + } + + if (memcpy_s(dest + sizeof(uint8_t) + mKeySecretLen, mAADSecretLen, mAADSecret, mAADSecretLen) != EOK) { + NN_LOG_ERROR("memcpy_s aad failed."); + return false; + } + + if (memcpy_s(dest + sizeof(uint8_t) + mKeySecretLen + mAADSecretLen, mIVSecretLen, mIVSecret, mIVSecretLen) != + EOK) { + NN_LOG_ERROR("memcpy_s iv failed."); + return false; + } + + return true; + } + + inline bool Deserialize(const char *secret, size_t len) + { + if (NN_UNLIKELY(secret == nullptr) || NN_UNLIKELY(len != GetSerializeLen())) { + NN_LOG_ERROR("Invalid param secret is null or length:" << len << " is not equal to serialized len:" << + GetSerializeLen()); + return false; + } + + if (memcpy_s(mKeySecret, mKeySecretLen, secret + sizeof(uint8_t), mKeySecretLen) != EOK) { + NN_LOG_ERROR("memcpy_s key failed."); + return false; + } + + if (memcpy_s(mAADSecret, mAADSecretLen, secret + sizeof(uint8_t) + mKeySecretLen, mAADSecretLen) != EOK) { + NN_LOG_ERROR("memcpy_s aad failed."); + return false; + } + + if (memcpy_s(mIVSecret, mIVSecretLen, secret + sizeof(uint8_t) + mKeySecretLen + mAADSecretLen, mIVSecretLen) != + EOK) { + NN_LOG_ERROR("memcpy_s iv failed."); + return false; + } + + return true; + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + // reserve for secret time out + uint8_t mSN = { 0 }; + + char mKeySecret[NN_NO32] = { 0 } ; + char mAADSecret[NN_NO32] = { 0 } ; + char mIVSecret[NN_NO32] = { 0 } ; + + /* the real secret length */ + size_t mKeySecretLen = NN_NO0; + size_t mAADSecretLen = NN_NO0; + size_t mIVSecretLen = NN_NO0; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; +} +} + +#endif diff --git a/src/common/net_trace.cpp b/src/common/net_trace.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9d7e3346ee41d76034655daf9081c64e09f653ef --- /dev/null +++ b/src/common/net_trace.cpp @@ -0,0 +1,244 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "net_trace.h" +#include "net_common.h" + +namespace ock { +namespace hcom { +static const char *g_idToString[] = { + "SERVICE_INSTANCE", + "SERVICE_START", + "SERVICE_STOP", + "SERVICE_CONNECT_DO", + "SERVICE_RECONNECT_DO", + "SERVICE_RECONNECT_COMFIRM", + "SERVICE_REG_MR", + "SERVICE_REG_MR_WITH_PTR", + "SERVICE_DESTROY_MR", + "SERVICE_OP_HANDLE_RNDV", + "SERVICE_OP_HANDLE_RNDV_SGL", + "SERVICE_OP_HANDLE_RECONNECT", + "SERVICE_CB_REQUEST_RECEIVED", + "SERVICE_CB_REQUEST_POSTED", + "SERVICE_CB_ONESIDE_DONE", + "SERVICE_CB_NEW_CHANNEL", + "SERVICE_CB_BROKEN_CHANNEL", + "SERVICE_THREAD_PERIODIC", + + "CHANNEL_SEND", + "CHANNEL_SEND_RAW", + "CHANNEL_SEND_RAW_SGL", + "CHANNEL_SYNC_CALL", + "CHANNEL_ASYNC_CALL", + "CHANNEL_SYNC_CALL_RAW", + "CHANNEL_ASYNC_CALL_RAW", + "CHANNEL_SYNC_CALL_RAW_SGL", + "CHANNEL_ASYNC_CALL_RAW_SGL", + "CHANNEL_SYNC_RNDV_CALL", + "CHANNEL_ASYNC_RNDV_CALL", + "CHANNEL_SYNC_RNDV_SGL_CALL", + "CHANNEL_ASYNC_RNDV_SGL_CALL", + "CHANNEL_READ", + "CHANNEL_READ_SGL", + "CHANNEL_WRITE", + "CHANNEL_WRITE_SGL", + "CHANNEL_SEND_FD", + "CHANNEL_RECEIVE_FD", + + "RDMA_DRIVER_INIT", + "RDMA_DRIVER_UNINIT", + "RDMA_DRIVER_START", + "RDMA_DRIVER_STOP", + "RDMA_DRIVER_CONNECT_EP", + "RDMA_DRIVER_DESTROY_EP", + "RDMA_THREAD_HEARTBEAT", + "RDMA_THREAD_ASYNC_EVENT", + "RDMA_WORKER_BUSY_POLLING", + "RDMA_WORKER_EVENT_POLLING", + "RDMA_EP_ASYNC_POST_SEND", + "RDMA_EP_ASYNC_POST_SEND_RAW", + "RDMA_EP_ASYNC_POST_SEND_RAW_SGL", + "RDMA_EP_ASYNC_POST_READ", + "RDMA_EP_ASYNC_POST_READ_SGL", + "RDMA_EP_ASYNC_POST_WRITE", + "RDMA_EP_ASYNC_POST_WRITE_SGL", + "RDMA_EP_SYNC_POST_SEND", + "RDMA_EP_SYNC_POST_SEND_RAW", + "RDMA_EP_SYNC_POST_SEND_RAW_SGL", + "RDMA_EP_SYNC_POST_READ", + "RDMA_EP_SYNC_POST_READ_SGL", + "RDMA_EP_SYNC_POST_WRITE", + "RDMA_EP_SYNC_POST_WRITE_SGL", + "RDMA_EP_SYNC_RECEIVE", + "RDMA_EP_SYNC_WAIT_COMPLETION", + + "SOCK_DRIVER_CONNECT", + "SOCK_DRIVER_HANDLE_CONNECT", + "SOCK_DRIVER_INITIALIZE", + "SOCK_DRIVER_START", + "SOCK_DRIVER_CREATE_WORKER_RESOURCE", + "SOCK_DRIVER_CREATE_WORKERS", + "SOCK_DRIVER_CREATE_CLIENT_LB", + "SOCK_DRIVER_CREATE_LISTENERS", + "SOCK_DRIVER_WORKER_START", + "SOCK_DRIVER_START_LISTENERS", + "SOCK_WORKER_EPOLL_WAIT", + "SOCK_WORKER_HANDLE_EVENT", + "SOCK_WORKER_HANDLE_EPOLLIN_EVENT", + "SOCK_WORKER_HANDLE_EPOLL_OUT_EVENT", + "SOCK_WORKER_HANDLE_EPOLL_WRNORM_EVENT", + "SOCK_WORKER_IDLE_HANDLER", + "SOCK_EP_BLOCK_POST_SEND", + "SOCK_EP_ASYNC_POST_SEND", + "SOCK_EP_ASYNC_POST_SEND_RAW", + "SOCK_EP_ASYNC_POST_SEND_RAW_SGL", + "SOCK_EP_ASYNC_POST_READ", + "SOCK_EP_ASYNC_POST_READ_SGL", + "SOCK_EP_ASYNC_POST_WRITE", + "SOCK_EP_ASYNC_POST_WRITE_SGL", + "SOCK_EP_SYNC_POST_SEND", + "SOCK_EP_SYNC_POST_SEND_RAW", + "SOCK_EP_SYNC_POST_SEND_RAW_SGL", + "SOCK_EP_SYNC_POST_READ", + "SOCK_EP_SYNC_POST_READ_SGL", + "SOCK_EP_SYNC_POST_WRITE", + "SOCK_EP_SYNC_POST_WRITE_SGL", + "SOCK_EP_SYNC_RECEIVE", + "SOCK_EP_SYNC_WAIT_COMPLETION", + + "SHM_DRIVER_INIT", + "SHM_DRIVER_UNINIT", + "SHM_DRIVER_START", + "SHM_DRIVER_STOP", + "SHM_DRIVER_CONNECT", + "SHM_DRIVER_CREATE_MEMORY_REGION", + "SHM_DRIVER_DESTORY_MEMORY_REGION", + "SHM_WORKER_BUSY_POLLING", + "SHM_WORKER_EVENT_POLLING", + "SHM_THREAD_CHANNEL_KEEPER", + "SHM_EP_ASYNC_POST_SEND", + "SHM_EP_ASYNC_POST_SEND_RAW", + "SHM_EP_ASYNC_POST_SEND_RAW_SGL", + "SHM_EP_ASYNC_POST_READ", + "SHM_EP_ASYNC_POST_READ_SGL", + "SHM_EP_ASYNC_POST_WRITE", + "SHM_EP_ASYNC_POST_WRITE_SGL", + "SHM_EP_ASYNC_SEND_FDS", + "SHM_EP_ASYNC_RECEIVE_FDS", + "SHM_EP_SYNC_POST_SEND", + "SHM_EP_SYNC_POST_SEND_RAW", + "SHM_EP_SYNC_POST_SEND_RAW_SGL", + "SHM_EP_SYNC_POST_READ", + "SHM_EP_SYNC_POST_READ_SGL", + "SHM_EP_SYNC_POST_WRITE", + "SHM_EP_SYNC_POST_WRITE_SGL", + "SHM_EP_SYNC_WAIT_COMPLETION", + "SHM_EP_SYNC_RECEIVE", + "SHM_EP_SYNC_RECEIVE_RAW", + + "UB_WORKER_BUSY_POLLING", + "UB_WORKER_EVENT_POLLING", + "UB_EP_ASYNC_POST_SEND", + "UB_EP_ASYNC_POST_SEND_RAW", + "UB_EP_ASYNC_POST_SEND_RAW_SGL", + "UB_EP_ASYNC_POST_READ", + "UB_EP_ASYNC_POST_READ_SGL", + "UB_EP_ASYNC_POST_WRITE", + "UB_EP_ASYNC_POST_WRITE_SGL", + "UB_EP_SYNC_POST_SEND", + "UB_EP_SYNC_POST_SEND_RAW", + "UB_EP_SYNC_POST_SEND_RAW_SGL", + "UB_EP_SYNC_POST_READ", + "UB_EP_SYNC_POST_READ_SGL", + "UB_EP_SYNC_POST_WRITE", + "UB_EP_SYNC_POST_WRITE_SGL", + + "OOB_START", + "OOB_STOP", + "OOB_CONN_SEND", + "OOB_CONN_RECEIVE", + "OOB_CONN_SEND_MSG", + "OOB_CONN_RECEIVE_MSG", + "OOB_ACCREPT_SOCKET", + "OOB_CONNECT_SOCKET", + "OOB_EXEC_CONN_TASK", + "OOB_SECINFO_PROVIDER", + "OOB_SECINFO_VALIDATOR", + + "SERVICE_IO_BROKEN_CALLBACK", + "SERVICE_POSTED_OR_DONE_CALLBACK", + "SERVICE_CALL_DONE_CALLBACK", + "SERVICE_RUN_CALLBACK", + "TIMEOUT_RUN_CALLBACK", +}; + +void NetTrace::Initialize() +{ + uint32_t stringArraySize = sizeof(g_idToString) / sizeof(g_idToString[0]); + if (NN_UNLIKELY(stringArraySize != MAX_MODULE_ID_INNER)) { + NN_LOG_WARN("Id to string table size " << stringArraySize << " different from trace size " << + MAX_MODULE_ID_INNER); + } + for (uint32_t traceId = 0; traceId < MAX_MODULE_ID_INNER; traceId++) { + if (traceId < stringArraySize) { + mPointProperty[traceId].name = g_idToString[traceId]; + } + } + + auto envString = getenv("HCOM_TRACE_LEVEL"); + if (envString != nullptr) { + long tmp = 0; + if (NetFunc::NN_Stol(envString, tmp) && tmp >= LEVEL0 && tmp <= LEVEL3) { + mEnableLevel = static_cast(tmp); + } + } else { + NN_LOG_INFO("Default trace level " << mEnableLevel); + } +} + +NetTrace *NetTrace::gTraceInst = nullptr; +std::mutex NetTrace::gTraceLock; + +void NetTrace::Instance() +{ + if (gTraceInst == nullptr) { + std::lock_guard locker(gTraceLock); + if (gTraceInst == nullptr) { + // double check nullptr + gTraceInst = new (std::nothrow) NetTrace(); + if (NN_UNLIKELY(gTraceInst == nullptr)) { + return; + } + gTraceInst->Initialize(); + } + } +} + +bool NetTrace::gEnableHtrace = false; +void NetTrace::HtraceInit(const std::string &name) +{ + HTracerInit(name); + auto envString = getenv("HCOM_ENABLE_TRACE"); + if (envString != nullptr) { + long tmp = 0; + if (NetFunc::NN_Stol(envString, tmp) && tmp >= LEVEL0 && tmp <= LEVEL1) { + gEnableHtrace = static_cast(tmp); + } else { + NN_LOG_WARN("Set env 'HCOM_ENABLE_TRACE' error, the value can be 0 or 1. "); + } + } else { + NN_LOG_INFO("Default diseable trace. "); + } + EnableHtrace(gEnableHtrace); +} +} +} diff --git a/src/common/net_trace.h b/src/common/net_trace.h new file mode 100644 index 0000000000000000000000000000000000000000..81c3b7016d49eaf5919fd3f51526cfed2725fdcc --- /dev/null +++ b/src/common/net_trace.h @@ -0,0 +1,537 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TRACE_H +#define HCOM_TRACE_H + +#include +#include +#include +#include +#include "hcom_def.h" +#include "hcom_utils.h" +#include "trace/htracer.h" + +namespace ock { +namespace hcom { +#define GENERATE_TRACE_ID(level, innerId) ((level) << NN_NO24 | ((innerId)&0xFFFFFF)) + +enum NetTraceLevel : uint8_t { + LEVEL0 = 0, // disable + LEVEL1, // High priority , LEVEL1 enable + LEVEL2, // Middle priority, LEVEL1 + LEVEL2 enable + LEVEL3, // Low priority , LEVEL1 + LEVEL2+LEVEL3 enable +}; + +enum NetTracePointId : uint32_t { + SERVICE_INSTANCE_INNER = 0, + SERVICE_START_INNER, + SERVICE_STOP_INNER, + SERVICE_CONNECT_DO_INNER, + SERVICE_RECONNECT_DO_INNER, + SERVICE_RECONNECT_COMFIRM_INNER, + SERVICE_REG_MR_INNER, + SERVICE_REG_MR_WITH_PTR_INNER, + SERVICE_DESTROY_MR_INNER, + SERVICE_OP_HANDLE_RNDV_INNER, + SERVICE_OP_HANDLE_RNDV_SGL_INNER, + SERVICE_OP_HANDLE_RECONNECT_INNER, + SERVICE_CB_REQUEST_RECEIVED_INNER, + SERVICE_CB_REQUEST_POSTED_INNER, + SERVICE_CB_ONESIDE_DONE_INNER, + SERVICE_CB_NEW_CHANNEL_INNER, + SERVICE_CB_BROKEN_CHANNEL_INNER, + SERVICE_THREAD_PERIODIC_INNER, + + CHANNEL_SEND_INNER, + CHANNEL_SEND_RAW_INNER, + CHANNEL_SEND_RAW_SGL_INNER, + CHANNEL_SYNC_CALL_INNER, + CHANNEL_ASYNC_CALL_INNER, + CHANNEL_SYNC_CALL_RAW_INNER, + CHANNEL_ASYNC_CALL_RAW_INNER, + CHANNEL_SYNC_CALL_RAW_SGL_INNER, + CHANNEL_ASYNC_CALL_RAW_SGL_INNER, + CHANNEL_SYNC_RNDV_CALL_INNER, + CHANNEL_ASYNC_RNDV_CALL_INNER, + CHANNEL_SYNC_RNDV_SGL_CALL_INNER, + CHANNEL_ASYNC_RNDV_SGL_CALL_INNER, + CHANNEL_READ_INNER, + CHANNEL_READ_SGL_INNER, + CHANNEL_WRITE_INNER, + CHANNEL_WRITE_SGL_INNER, + CHANNEL_SEND_FD_INNER, + CHANNEL_RECEIVE_FD_INNER, + + RDMA_DRIVER_INIT_INNER, + RDMA_DRIVER_UNINIT_INNER, + RDMA_DRIVER_START_INNER, + RDMA_DRIVER_STOP_INNER, + RDMA_DRIVER_CONNECT_EP_INNER, + RDMA_DRIVER_DESTROY_EP_INNER, + RDMA_THREAD_HEARTBEAT_INNER, + RDMA_THREAD_ASYNC_EVENT_INNER, + RDMA_WORKER_BUSY_POLLING_INNER, + RDMA_WORKER_EVENT_POLLING_INNER, + RDMA_EP_ASYNC_POST_SEND_INNER, + RDMA_EP_ASYNC_POST_SEND_RAW_INNER, + RDMA_EP_ASYNC_POST_SEND_RAW_SGL_INNER, + RDMA_EP_ASYNC_POST_READ_INNER, + RDMA_EP_ASYNC_POST_READ_SGL_INNER, + RDMA_EP_ASYNC_POST_WRITE_INNER, + RDMA_EP_ASYNC_POST_WRITE_SGL_INNER, + RDMA_EP_SYNC_POST_SEND_INNER, + RDMA_EP_SYNC_POST_SEND_RAW_INNER, + RDMA_EP_SYNC_POST_SEND_RAW_SGL_INNER, + RDMA_EP_SYNC_POST_READ_INNER, + RDMA_EP_SYNC_POST_READ_SGL_INNER, + RDMA_EP_SYNC_POST_WRITE_INNER, + RDMA_EP_SYNC_POST_WRITE_SGL_INNER, + RDMA_EP_SYNC_RECEIVE_INNER, + RDMA_EP_SYNC_WAIT_COMPLETION_INNER, + + SOCK_DRIVER_CONNECT_INNER, + SOCK_DRIVER_HANDLE_CONNECT_INNER, + SOCK_DRIVER_INITIALIZE_INNER, + SOCK_DRIVER_START_INNER, + SOCK_DRIVER_CREATE_WORKER_RESOURCE_INNER, + SOCK_DRIVER_CREATE_WORKERS_INNER, + SOCK_DRIVER_CREATE_CLIENT_LB_INNER, + SOCK_DRIVER_CREATE_LISTENERS_INNER, + SOCK_DRIVER_WORKER_START_INNER, + SOCK_DRIVER_START_LISTENERS_INNER, + SOCK_WORKER_EPOLL_WAIT_INNER, + SOCK_WORKER_HANDLE_EVENT_INNER, + SOCK_WORKER_HANDLE_EPOLLIN_EVENT_INNER, + SOCK_WORKER_HANDLE_EPOLL_OUT_EVENT_INNER, + SOCK_WORKER_HANDLE_EPOLL_WRNORM_EVENT_INNER, + SOCK_WORKER_IDLE_HANDLER_INNER, + SOCK_EP_BLOCK_POST_SEND_INNER, + SOCK_EP_ASYNC_POST_SEND_INNER, + SOCK_EP_ASYNC_POST_SEND_RAW_INNER, + SOCK_EP_ASYNC_POST_SEND_RAW_SGL_INNER, + SOCK_EP_ASYNC_POST_READ_INNER, + SOCK_EP_ASYNC_POST_READ_SGL_INNER, + SOCK_EP_ASYNC_POST_WRITE_INNER, + SOCK_EP_ASYNC_POST_WRITE_SGL_INNER, + SOCK_EP_SYNC_POST_SEND_INNER, + SOCK_EP_SYNC_POST_SEND_RAW_INNER, + SOCK_EP_SYNC_POST_SEND_RAW_SGL_INNER, + SOCK_EP_SYNC_POST_READ_INNER, + SOCK_EP_SYNC_POST_READ_SGL_INNER, + SOCK_EP_SYNC_POST_WRITE_INNER, + SOCK_EP_SYNC_POST_WRITE_SGL_INNER, + SOCK_EP_SYNC_RECEIVE_INNER, + SOCK_EP_SYNC_WAIT_COMPLETION_INNER, + + SHM_DRIVER_INIT_INNER, + SHM_DRIVER_UNINIT_INNER, + SHM_DRIVER_START_INNER, + SHM_DRIVER_STOP_INNER, + SHM_DRIVER_CONNECT_INNER, + SHM_DRIVER_CREATE_MEMORY_REGION_INNER, + SHM_DRIVER_DESTORY_MEMORY_REGION_INNER, + SHM_WORKER_BUSY_POLLING_INNER, + SHM_WORKER_EVENT_POLLING_INNER, + SHM_THREAD_CHANNEL_KEEPER_INNER, + SHM_EP_ASYNC_POST_SEND_INNER, + SHM_EP_ASYNC_POST_SEND_RAW_INNER, + SHM_EP_ASYNC_POST_SEND_RAW_SGL_INNER, + SHM_EP_ASYNC_POST_READ_INNER, + SHM_EP_ASYNC_POST_READ_SGL_INNER, + SHM_EP_ASYNC_POST_WRITE_INNER, + SHM_EP_ASYNC_POST_WRITE_SGL_INNER, + SHM_EP_ASYNC_SEND_FDS_INNER, + SHM_EP_ASYNC_RECEIVE_FDS_INNER, + SHM_EP_SYNC_POST_SEND_INNER, + SHM_EP_SYNC_POST_SEND_RAW_INNER, + SHM_EP_SYNC_POST_SEND_RAW_SGL_INNER, + SHM_EP_SYNC_POST_READ_INNER, + SHM_EP_SYNC_POST_READ_SGL_INNER, + SHM_EP_SYNC_POST_WRITE_INNER, + SHM_EP_SYNC_POST_WRITE_SGL_INNER, + SHM_EP_SYNC_WAIT_COMPLETION_INNER, + SHM_EP_SYNC_RECEIVE_INNER, + SHM_EP_SYNC_RECEIVE_RAW_INNER, + + UB_WORKER_BUSY_POLLING_INNER, + UB_WORKER_EVENT_POLLING_INNER, + UB_EP_ASYNC_POST_SEND_INNER, + UB_EP_ASYNC_POST_SEND_RAW_INNER, + UB_EP_ASYNC_POST_SEND_RAW_SGL_INNER, + UB_EP_ASYNC_POST_READ_INNER, + UB_EP_ASYNC_POST_READ_SGL_INNER, + UB_EP_ASYNC_POST_WRITE_INNER, + UB_EP_ASYNC_POST_WRITE_SGL_INNER, + UB_EP_SYNC_POST_SEND_INNER, + UB_EP_SYNC_POST_SEND_RAW_INNER, + UB_EP_SYNC_POST_SEND_RAW_SGL_INNER, + UB_EP_SYNC_POST_READ_INNER, + UB_EP_SYNC_POST_READ_SGL_INNER, + UB_EP_SYNC_POST_WRITE_INNER, + UB_EP_SYNC_POST_WRITE_SGL_INNER, + + OOB_START_INNER, + OOB_STOP_INNER, + OOB_CONN_SEND_INNER, + OOB_CONN_RECEIVE_INNER, + OOB_CONN_SEND_MSG_INNER, + OOB_CONN_RECEIVE_MSG_INNER, + OOB_ACCREPT_SOCKET_INNER, + OOB_CONNECT_SOCKET_INNER, + OOB_EXEC_CONN_TASK_INNER, + OOB_SECINFO_PROVIDER_INNER, + OOB_SECINFO_VALIDATOR_INNER, + + SERVICE_IO_BROKEN_CALLBACK_INNER, + SERVICE_POSTED_OR_DONE_CALLBACK_INNER, + SERVICE_CALL_DONE_CALLBACK_INNER, + SERVICE_RUN_CALLBACK_INNER, + TIMEOUT_RUN_CALLBACK_INNER, + + MAX_MODULE_ID_INNER, +}; + +enum NetTracePointIdWithLevel : uint32_t { + SERVICE_INSTANCE = TRACE_ID(SERVICE_INSTANCE_INNER, LEVEL1), + SERVICE_START = TRACE_ID(SERVICE_START_INNER, LEVEL1), + SERVICE_STOP = TRACE_ID(SERVICE_STOP_INNER, LEVEL1), + SERVICE_CONNECT_DO = TRACE_ID(SERVICE_CONNECT_DO_INNER, LEVEL1), + SERVICE_RECONNECT_DO = TRACE_ID(SERVICE_RECONNECT_DO_INNER, LEVEL1), + SERVICE_RECONNECT_COMFIRM = TRACE_ID(SERVICE_RECONNECT_COMFIRM_INNER, LEVEL1), + SERVICE_REG_MR = TRACE_ID(SERVICE_REG_MR_INNER, LEVEL1), + SERVICE_REG_MR_WITH_PTR = TRACE_ID(SERVICE_REG_MR_WITH_PTR_INNER, LEVEL1), + SERVICE_DESTROY_MR = TRACE_ID(SERVICE_DESTROY_MR_INNER, LEVEL1), + SERVICE_OP_HANDLE_RNDV = TRACE_ID(SERVICE_OP_HANDLE_RNDV_INNER, LEVEL1), + SERVICE_OP_HANDLE_RNDV_SGL = TRACE_ID(SERVICE_OP_HANDLE_RNDV_SGL_INNER, LEVEL1), + SERVICE_OP_HANDLE_RECONNECT = TRACE_ID(SERVICE_OP_HANDLE_RECONNECT_INNER, LEVEL1), + SERVICE_CB_REQUEST_RECEIVED = TRACE_ID(SERVICE_CB_REQUEST_RECEIVED_INNER, LEVEL1), + SERVICE_CB_REQUEST_POSTED = TRACE_ID(SERVICE_CB_REQUEST_POSTED_INNER, LEVEL1), + SERVICE_CB_ONESIDE_DONE = TRACE_ID(SERVICE_CB_ONESIDE_DONE_INNER, LEVEL1), + SERVICE_CB_NEW_CHANNEL = TRACE_ID(SERVICE_CB_NEW_CHANNEL_INNER, LEVEL1), + SERVICE_CB_BROKEN_CHANNEL = TRACE_ID(SERVICE_CB_BROKEN_CHANNEL_INNER, LEVEL1), + SERVICE_THREAD_PERIODIC = TRACE_ID(SERVICE_THREAD_PERIODIC_INNER, LEVEL1), + + // CHANNEL_SEND = TRACE_ID(CHANNEL_SEND_INNER, LEVEL1), + CHANNEL_SEND = TRACE_ID(CHANNEL_SEND_INNER, LEVEL1), + CHANNEL_SEND_RAW = TRACE_ID(CHANNEL_SEND_RAW_INNER, LEVEL1), + CHANNEL_SEND_RAW_SGL = TRACE_ID(CHANNEL_SEND_RAW_SGL_INNER, LEVEL1), + CHANNEL_SYNC_CALL = TRACE_ID(CHANNEL_SYNC_CALL_INNER, LEVEL1), + CHANNEL_ASYNC_CALL = TRACE_ID(CHANNEL_ASYNC_CALL_INNER, LEVEL1), + CHANNEL_SYNC_CALL_RAW = TRACE_ID(CHANNEL_SYNC_CALL_RAW_INNER, LEVEL1), + CHANNEL_ASYNC_CALL_RAW = TRACE_ID(CHANNEL_ASYNC_CALL_RAW_INNER, LEVEL1), + CHANNEL_SYNC_CALL_RAW_SGL = TRACE_ID(CHANNEL_SYNC_CALL_RAW_SGL_INNER, LEVEL1), + CHANNEL_ASYNC_CALL_RAW_SGL = TRACE_ID(CHANNEL_ASYNC_CALL_RAW_SGL_INNER, LEVEL1), + CHANNEL_SYNC_RNDV_CALL = TRACE_ID(CHANNEL_SYNC_RNDV_CALL_INNER, LEVEL1), + CHANNEL_ASYNC_RNDV_CALL = TRACE_ID(CHANNEL_ASYNC_RNDV_CALL_INNER, LEVEL1), + CHANNEL_SYNC_RNDV_SGL_CALL = TRACE_ID(CHANNEL_SYNC_RNDV_SGL_CALL_INNER, LEVEL1), + CHANNEL_ASYNC_RNDV_SGL_CALL = TRACE_ID(CHANNEL_ASYNC_RNDV_SGL_CALL_INNER, LEVEL1), + CHANNEL_READ = TRACE_ID(CHANNEL_READ_INNER, LEVEL1), + CHANNEL_READ_SGL = TRACE_ID(CHANNEL_READ_SGL_INNER, LEVEL1), + CHANNEL_WRITE = TRACE_ID(CHANNEL_WRITE_INNER, LEVEL1), + CHANNEL_WRITE_SGL = TRACE_ID(CHANNEL_WRITE_SGL_INNER, LEVEL1), + CHANNEL_SEND_FD = TRACE_ID(CHANNEL_SEND_FD_INNER, LEVEL1), + CHANNEL_RECEIVE_FD = TRACE_ID(CHANNEL_RECEIVE_FD_INNER, LEVEL1), + + RDMA_DRIVER_INIT = TRACE_ID(RDMA_DRIVER_INIT_INNER, LEVEL1), + RDMA_DRIVER_UNINIT = TRACE_ID(RDMA_DRIVER_UNINIT_INNER, LEVEL1), + RDMA_DRIVER_START = TRACE_ID(RDMA_DRIVER_START_INNER, LEVEL1), + RDMA_DRIVER_STOP = TRACE_ID(RDMA_DRIVER_STOP_INNER, LEVEL1), + RDMA_DRIVER_CONNECT_EP = TRACE_ID(RDMA_DRIVER_CONNECT_EP_INNER, LEVEL1), + RDMA_DRIVER_DESTROY_EP = TRACE_ID(RDMA_DRIVER_DESTROY_EP_INNER, LEVEL1), + RDMA_THREAD_HEARTBEAT = TRACE_ID(RDMA_THREAD_HEARTBEAT_INNER, LEVEL1), + RDMA_THREAD_ASYNC_EVENT = TRACE_ID(RDMA_THREAD_ASYNC_EVENT_INNER, LEVEL1), + RDMA_WORKER_BUSY_POLLING = TRACE_ID(RDMA_WORKER_BUSY_POLLING_INNER, LEVEL1), + RDMA_WORKER_EVENT_POLLING = TRACE_ID(RDMA_WORKER_EVENT_POLLING_INNER, LEVEL1), + RDMA_EP_ASYNC_POST_SEND = TRACE_ID(RDMA_EP_ASYNC_POST_SEND_INNER, LEVEL1), + RDMA_EP_ASYNC_POST_SEND_RAW = TRACE_ID(RDMA_EP_ASYNC_POST_SEND_RAW_INNER, LEVEL1), + RDMA_EP_ASYNC_POST_SEND_RAW_SGL = TRACE_ID(RDMA_EP_ASYNC_POST_SEND_RAW_SGL_INNER, LEVEL1), + RDMA_EP_ASYNC_POST_READ = TRACE_ID(RDMA_EP_ASYNC_POST_READ_INNER, LEVEL1), + RDMA_EP_ASYNC_POST_READ_SGL = TRACE_ID(RDMA_EP_ASYNC_POST_READ_SGL_INNER, LEVEL1), + RDMA_EP_ASYNC_POST_WRITE = TRACE_ID(RDMA_EP_ASYNC_POST_WRITE_INNER, LEVEL1), + RDMA_EP_ASYNC_POST_WRITE_SGL = TRACE_ID(RDMA_EP_ASYNC_POST_WRITE_SGL_INNER, LEVEL1), + RDMA_EP_SYNC_POST_SEND = TRACE_ID(RDMA_EP_SYNC_POST_SEND_INNER, LEVEL1), + RDMA_EP_SYNC_POST_SEND_RAW = TRACE_ID(RDMA_EP_SYNC_POST_SEND_RAW_INNER, LEVEL1), + RDMA_EP_SYNC_POST_SEND_RAW_SGL = TRACE_ID(RDMA_EP_SYNC_POST_SEND_RAW_SGL_INNER, LEVEL1), + RDMA_EP_SYNC_POST_READ = TRACE_ID(RDMA_EP_SYNC_POST_READ_INNER, LEVEL1), + RDMA_EP_SYNC_POST_READ_SGL = TRACE_ID(RDMA_EP_SYNC_POST_READ_SGL_INNER, LEVEL1), + RDMA_EP_SYNC_POST_WRITE = TRACE_ID(RDMA_EP_SYNC_POST_WRITE_INNER, LEVEL1), + RDMA_EP_SYNC_POST_WRITE_SGL = TRACE_ID(RDMA_EP_SYNC_POST_WRITE_SGL_INNER, LEVEL1), + RDMA_EP_SYNC_RECEIVE = TRACE_ID(RDMA_EP_SYNC_RECEIVE_INNER, LEVEL1), + RDMA_EP_SYNC_WAIT_COMPLETION = TRACE_ID(RDMA_EP_SYNC_WAIT_COMPLETION_INNER, LEVEL1), + + SOCK_DRIVER_CONNECT = TRACE_ID(SOCK_DRIVER_CONNECT_INNER, LEVEL1), + SOCK_DRIVER_HANDLE_CONNECT = TRACE_ID(SOCK_DRIVER_HANDLE_CONNECT_INNER, LEVEL1), + SOCK_DRIVER_INITIALIZE = TRACE_ID(SOCK_DRIVER_INITIALIZE_INNER, LEVEL1), + SOCK_DRIVER_START = TRACE_ID(SOCK_DRIVER_START_INNER, LEVEL1), + SOCK_DRIVER_CREATE_WORKER_RESOURCE = TRACE_ID(SOCK_DRIVER_CREATE_WORKER_RESOURCE_INNER, LEVEL1), + SOCK_DRIVER_CREATE_WORKERS = TRACE_ID(SOCK_DRIVER_CREATE_WORKERS_INNER, LEVEL1), + SOCK_DRIVER_CREATE_CLIENT_LB = TRACE_ID(SOCK_DRIVER_CREATE_CLIENT_LB_INNER, LEVEL1), + SOCK_DRIVER_CREATE_LISTENERS = TRACE_ID(SOCK_DRIVER_CREATE_LISTENERS_INNER, LEVEL1), + SOCK_DRIVER_WORKER_START = TRACE_ID(SOCK_DRIVER_WORKER_START_INNER, LEVEL1), + SOCK_DRIVER_START_LISTENERS = TRACE_ID(SOCK_DRIVER_START_LISTENERS_INNER, LEVEL1), + SOCK_WORKER_EPOLL_WAIT = TRACE_ID(SOCK_WORKER_EPOLL_WAIT_INNER, LEVEL1), + SOCK_WORKER_HANDLE_EVENT = TRACE_ID(SOCK_WORKER_HANDLE_EVENT_INNER, LEVEL1), + SOCK_WORKER_HANDLE_EPOLLIN_EVENT = TRACE_ID(SOCK_WORKER_HANDLE_EPOLLIN_EVENT_INNER, LEVEL1), + SOCK_WORKER_HANDLE_EPOLL_OUT_EVENT = TRACE_ID(SOCK_WORKER_HANDLE_EPOLL_OUT_EVENT_INNER, LEVEL1), + SOCK_WORKER_HANDLE_EPOLL_WRNORM_EVENT = TRACE_ID(SOCK_WORKER_HANDLE_EPOLL_WRNORM_EVENT_INNER, LEVEL1), + SOCK_WORKER_IDLE_HANDLER = TRACE_ID(SOCK_WORKER_IDLE_HANDLER_INNER, LEVEL1), + SOCK_EP_BLOCK_POST_SEND = TRACE_ID(SOCK_EP_BLOCK_POST_SEND_INNER, LEVEL1), + SOCK_EP_ASYNC_POST_SEND = TRACE_ID(SOCK_EP_ASYNC_POST_SEND_INNER, LEVEL1), + SOCK_EP_ASYNC_POST_SEND_RAW = TRACE_ID(SOCK_EP_ASYNC_POST_SEND_RAW_INNER, LEVEL1), + SOCK_EP_ASYNC_POST_SEND_RAW_SGL = TRACE_ID(SOCK_EP_ASYNC_POST_SEND_RAW_SGL_INNER, LEVEL1), + SOCK_EP_ASYNC_POST_READ = TRACE_ID(SOCK_EP_ASYNC_POST_READ_INNER, LEVEL1), + SOCK_EP_ASYNC_POST_READ_SGL = TRACE_ID(SOCK_EP_ASYNC_POST_READ_SGL_INNER, LEVEL1), + SOCK_EP_ASYNC_POST_WRITE = TRACE_ID(SOCK_EP_ASYNC_POST_WRITE_INNER, LEVEL1), + SOCK_EP_ASYNC_POST_WRITE_SGL = TRACE_ID(SOCK_EP_ASYNC_POST_WRITE_SGL_INNER, LEVEL1), + SOCK_EP_SYNC_POST_SEND = TRACE_ID(SOCK_EP_SYNC_POST_SEND_INNER, LEVEL1), + SOCK_EP_SYNC_POST_SEND_RAW = TRACE_ID(SOCK_EP_SYNC_POST_SEND_RAW_INNER, LEVEL1), + SOCK_EP_SYNC_POST_SEND_RAW_SGL = TRACE_ID(SOCK_EP_SYNC_POST_SEND_RAW_SGL_INNER, LEVEL1), + SOCK_EP_SYNC_POST_READ = TRACE_ID(SOCK_EP_SYNC_POST_READ_INNER, LEVEL1), + SOCK_EP_SYNC_POST_READ_SGL = TRACE_ID(SOCK_EP_SYNC_POST_READ_SGL_INNER, LEVEL1), + SOCK_EP_SYNC_POST_WRITE = TRACE_ID(SOCK_EP_SYNC_POST_WRITE_INNER, LEVEL1), + SOCK_EP_SYNC_POST_WRITE_SGL = TRACE_ID(SOCK_EP_SYNC_POST_WRITE_SGL_INNER, LEVEL1), + SOCK_EP_SYNC_RECEIVE = TRACE_ID(SOCK_EP_SYNC_RECEIVE_INNER, LEVEL1), + SOCK_EP_SYNC_WAIT_COMPLETION = TRACE_ID(SOCK_EP_SYNC_WAIT_COMPLETION_INNER, LEVEL1), + + SHM_DRIVER_INIT = TRACE_ID(SHM_DRIVER_INIT_INNER, LEVEL1), + SHM_DRIVER_UNINIT = TRACE_ID(SHM_DRIVER_UNINIT_INNER, LEVEL1), + SHM_DRIVER_START = TRACE_ID(SHM_DRIVER_START_INNER, LEVEL1), + SHM_DRIVER_STOP = TRACE_ID(SHM_DRIVER_STOP_INNER, LEVEL1), + SHM_DRIVER_CONNECT = TRACE_ID(SHM_DRIVER_CONNECT_INNER, LEVEL1), + SHM_DRIVER_CREATE_MEMORY_REGION = TRACE_ID(SHM_DRIVER_CREATE_MEMORY_REGION_INNER, LEVEL1), + SHM_DRIVER_DESTORY_MEMORY_REGION = TRACE_ID(SHM_DRIVER_DESTORY_MEMORY_REGION_INNER, LEVEL1), + SHM_WORKER_BUSY_POLLING = TRACE_ID(SHM_WORKER_BUSY_POLLING_INNER, LEVEL1), + SHM_WORKER_EVENT_POLLING = TRACE_ID(SHM_WORKER_EVENT_POLLING_INNER, LEVEL1), + SHM_THREAD_CHANNEL_KEEPER = TRACE_ID(SHM_THREAD_CHANNEL_KEEPER_INNER, LEVEL1), + SHM_EP_ASYNC_POST_SEND = TRACE_ID(SHM_EP_ASYNC_POST_SEND_INNER, LEVEL1), + SHM_EP_ASYNC_POST_SEND_RAW = TRACE_ID(SHM_EP_ASYNC_POST_SEND_RAW_INNER, LEVEL1), + SHM_EP_ASYNC_POST_SEND_RAW_SGL = TRACE_ID(SHM_EP_ASYNC_POST_SEND_RAW_SGL_INNER, LEVEL1), + SHM_EP_ASYNC_POST_READ = TRACE_ID(SHM_EP_ASYNC_POST_READ_INNER, LEVEL1), + SHM_EP_ASYNC_POST_READ_SGL = TRACE_ID(SHM_EP_ASYNC_POST_READ_SGL_INNER, LEVEL1), + SHM_EP_ASYNC_POST_WRITE = TRACE_ID(SHM_EP_ASYNC_POST_WRITE_INNER, LEVEL1), + SHM_EP_ASYNC_POST_WRITE_SGL = TRACE_ID(SHM_EP_ASYNC_POST_WRITE_SGL_INNER, LEVEL1), + SHM_EP_ASYNC_SEND_FDS = TRACE_ID(SHM_EP_ASYNC_SEND_FDS_INNER, LEVEL1), + SHM_EP_ASYNC_RECEIVE_FDS = TRACE_ID(SHM_EP_ASYNC_RECEIVE_FDS_INNER, LEVEL1), + SHM_EP_SYNC_POST_SEND = TRACE_ID(SHM_EP_SYNC_POST_SEND_INNER, LEVEL1), + SHM_EP_SYNC_POST_SEND_RAW = TRACE_ID(SHM_EP_SYNC_POST_SEND_RAW_INNER, LEVEL1), + SHM_EP_SYNC_POST_SEND_RAW_SGL = TRACE_ID(SHM_EP_SYNC_POST_SEND_RAW_SGL_INNER, LEVEL1), + SHM_EP_SYNC_POST_READ = TRACE_ID(SHM_EP_SYNC_POST_READ_INNER, LEVEL1), + SHM_EP_SYNC_POST_READ_SGL = TRACE_ID(SHM_EP_SYNC_POST_READ_SGL_INNER, LEVEL1), + SHM_EP_SYNC_POST_WRITE = TRACE_ID(SHM_EP_SYNC_POST_WRITE_INNER, LEVEL1), + SHM_EP_SYNC_POST_WRITE_SGL = TRACE_ID(SHM_EP_SYNC_POST_WRITE_SGL_INNER, LEVEL1), + SHM_EP_SYNC_WAIT_COMPLETION = TRACE_ID(SHM_EP_SYNC_WAIT_COMPLETION_INNER, LEVEL1), + SHM_EP_SYNC_RECEIVE = TRACE_ID(SHM_EP_SYNC_RECEIVE_INNER, LEVEL1), + SHM_EP_SYNC_RECEIVE_RAW = TRACE_ID(SHM_EP_SYNC_RECEIVE_RAW_INNER, LEVEL1), + + UB_WORKER_BUSY_POLLING = TRACE_ID(UB_WORKER_BUSY_POLLING_INNER, LEVEL1), + UB_WORKER_EVENT_POLLING = TRACE_ID(UB_WORKER_EVENT_POLLING_INNER, LEVEL1), + UB_EP_ASYNC_POST_SEND = TRACE_ID(UB_EP_ASYNC_POST_SEND_INNER, LEVEL1), + UB_EP_ASYNC_POST_SEND_RAW = TRACE_ID(UB_EP_ASYNC_POST_SEND_RAW_INNER, LEVEL1), + UB_EP_ASYNC_POST_SEND_RAW_SGL = TRACE_ID(UB_EP_ASYNC_POST_SEND_RAW_SGL_INNER, LEVEL1), + UB_EP_ASYNC_POST_READ = TRACE_ID(UB_EP_ASYNC_POST_READ_INNER, LEVEL1), + UB_EP_ASYNC_POST_READ_SGL = TRACE_ID(UB_EP_ASYNC_POST_READ_SGL_INNER, LEVEL1), + UB_EP_ASYNC_POST_WRITE = TRACE_ID(UB_EP_ASYNC_POST_WRITE_INNER, LEVEL1), + UB_EP_ASYNC_POST_WRITE_SGL = TRACE_ID(UB_EP_ASYNC_POST_WRITE_SGL_INNER, LEVEL1), + UB_EP_SYNC_POST_SEND = TRACE_ID(UB_EP_SYNC_POST_SEND_INNER, LEVEL1), + UB_EP_SYNC_POST_SEND_RAW = TRACE_ID(UB_EP_SYNC_POST_SEND_RAW_INNER, LEVEL1), + UB_EP_SYNC_POST_SEND_RAW_SGL = TRACE_ID(UB_EP_SYNC_POST_SEND_RAW_SGL_INNER, LEVEL1), + UB_EP_SYNC_POST_READ = TRACE_ID(UB_EP_SYNC_POST_READ_INNER, LEVEL1), + UB_EP_SYNC_POST_READ_SGL = TRACE_ID(UB_EP_SYNC_POST_READ_SGL_INNER, LEVEL1), + UB_EP_SYNC_POST_WRITE = TRACE_ID(UB_EP_SYNC_POST_WRITE_INNER, LEVEL1), + UB_EP_SYNC_POST_WRITE_SGL = TRACE_ID(UB_EP_SYNC_POST_WRITE_SGL_INNER, LEVEL1), + + OOB_START = TRACE_ID(OOB_START_INNER, LEVEL1), + OOB_STOP = TRACE_ID(OOB_STOP_INNER, LEVEL1), + OOB_CONN_SEND = TRACE_ID(OOB_CONN_SEND_INNER, LEVEL1), + OOB_CONN_RECEIVE = TRACE_ID(OOB_CONN_RECEIVE_INNER, LEVEL1), + OOB_CONN_SEND_MSG = TRACE_ID(OOB_CONN_SEND_MSG_INNER, LEVEL1), + OOB_CONN_RECEIVE_MSG = TRACE_ID(OOB_CONN_RECEIVE_MSG_INNER, LEVEL1), + OOB_ACCREPT_SOCKET = TRACE_ID(OOB_ACCREPT_SOCKET_INNER, LEVEL1), + OOB_CONNECT_SOCKET = TRACE_ID(OOB_CONNECT_SOCKET_INNER, LEVEL1), + OOB_EXEC_CONN_TASK = TRACE_ID(OOB_EXEC_CONN_TASK_INNER, LEVEL1), + OOB_SECINFO_PROVIDER = TRACE_ID(OOB_SECINFO_PROVIDER_INNER, LEVEL1), + OOB_SECINFO_VALIDATOR = TRACE_ID(OOB_SECINFO_VALIDATOR_INNER, LEVEL1), + + SERVICE_IO_BROKEN_CALLBACK = TRACE_ID(SERVICE_IO_BROKEN_CALLBACK_INNER, LEVEL1), + SERVICE_POSTED_OR_DONE_CALLBACK = TRACE_ID(SERVICE_POSTED_OR_DONE_CALLBACK_INNER, LEVEL1), + SERVICE_CALL_DONE_CALLBACK = TRACE_ID(SERVICE_CALL_DONE_CALLBACK_INNER, LEVEL1), + SERVICE_RUN_CALLBACK = TRACE_ID(SERVICE_RUN_CALLBACK_INNER, LEVEL1), + TIMEOUT_RUN_CALLBACK = TRACE_ID(TIMEOUT_RUN_CALLBACK_INNER, LEVEL1), +}; + +struct NetTraceItem { + uint64_t count = 0; + uint64_t success = 0; + uint64_t fail = 0; + uint64_t latency = 0; + uint64_t maxLatency = 0; + uint64_t minLatency = 0; + uint64_t rsv1 = 0; + uint64_t rsv2 = 0; +}; + +constexpr uint32_t BUCKET_MAX_SIZE = 32; +struct NetTracePoint { + NetTraceItem mItem[BUCKET_MAX_SIZE] = {}; +}; + +struct NetTracePointProperty { + std::string name; /* point name */ + uint64_t bucketIdx = 0; /* for different threads generate index by FAA */ +}; + +class NetTrace { +public: + static void HtraceInit(const std::string &name); + + static void Instance(); + + NetTrace() = default; + ~NetTrace() = default; + + void Initialize(); + + __always_inline static void TraceBegin(uint32_t id) + { + if (NN_UNLIKELY(gTraceInst == nullptr)) { + return; + } + + auto level = Level(id); + if (NN_UNLIKELY(level > gTraceInst->mEnableLevel)) { + return; + } + + auto item = gTraceInst->GetTraceItem(InnerId(id)); + __sync_fetch_and_add(&item->count, 1); + } + + __always_inline static void TraceEnd(uint32_t id, int ret) + { + if (NN_UNLIKELY(gTraceInst == nullptr)) { + return; + } + + auto level = Level(id); + if (NN_UNLIKELY(level > gTraceInst->mEnableLevel)) { + return; + } + + auto item = gTraceInst->GetTraceItem(InnerId(id)); + if (NN_LIKELY(ret == 0)) { + __sync_fetch_and_add(&item->success, 1); + } else { + __sync_fetch_and_add(&item->fail, 1); + } + } + + static std::string TraceDump() + { + std::ostringstream ossDump; + ossDump << std::endl; + ossDump << std::setw(NN_NO48) << std::setiosflags(std::ios::left) << "Name"; + ossDump << std::setw(NN_NO32) << std::setiosflags(std::ios::left) << "Begin"; + ossDump << std::setw(NN_NO32) << std::setiosflags(std::ios::left) << "EndGood"; + ossDump << std::setw(NN_NO32) << std::setiosflags(std::ios::left) << "EndBad"; + ossDump << std::endl; + + if (NN_UNLIKELY(gTraceInst == nullptr)) { + return ossDump.str(); + } + + for (uint32_t i = 0; i < MAX_MODULE_ID_INNER; i++) { + struct NetTraceItem recordItem {}; + bzero(&recordItem, sizeof(recordItem)); + for (auto &j : gTraceInst->mPoint[i].mItem) { + recordItem.count += __sync_fetch_and_add(&j.count, 0); + recordItem.success += __sync_fetch_and_add(&j.success, 0); + recordItem.fail += __sync_fetch_and_add(&j.fail, 0); + } + + if (recordItem.count == 0) { + /* ignore empty point */ + continue; + } + + ossDump << std::setw(NN_NO48) << std::setiosflags(std::ios::left) << gTraceInst->mPointProperty[i].name; + ossDump << std::setw(NN_NO32) << std::setiosflags(std::ios::left) << recordItem.count; + ossDump << std::setw(NN_NO32) << std::setiosflags(std::ios::left) << recordItem.success; + ossDump << std::setw(NN_NO32) << std::setiosflags(std::ios::left) << recordItem.fail; + ossDump << std::endl; + } + return ossDump.str(); + } + +private: + __always_inline NetTraceItem *GetTraceItem(uint16_t innerId) + { + static thread_local auto itemIndex = + __sync_fetch_and_add(&mPointProperty[innerId].bucketIdx, 1) % BUCKET_MAX_SIZE; + + return &mPoint[innerId].mItem[itemIndex]; + } + + static inline uint16_t Level(uint32_t id) + { + return (id >> NN_NO24) & 0xFF; + } + + static inline uint16_t InnerId(uint32_t id) + { + return id & 0xFFFFFF; + } + + NetTraceLevel mEnableLevel = LEVEL0; /* level switch for enable or not */ + NetTracePoint mPoint[MAX_MODULE_ID_INNER]; /* trace point */ + NetTracePointProperty mPointProperty[MAX_MODULE_ID_INNER]; /* trace point property include name */ + +private: + static NetTrace *gTraceInst; + static std::mutex gTraceLock; + static bool gEnableHtrace; +}; + +class TraceDefer { +public: + TraceDefer(std::function beginFunc, std::function endFunc) + { + mBeginFunc = beginFunc; + mEndFunc = endFunc; + if (mBeginFunc) { + mBeginFunc(); + } + } + ~TraceDefer() + { + if (mEndFunc) { + mEndFunc(); + } + } + + // Disable copy and assignment + TraceDefer(const TraceDefer &) = delete; + TraceDefer &operator = (const TraceDefer &) = delete; + +private: + std::function mBeginFunc; + std::function mEndFunc; +}; + +#define DO_TRACE(tracePoint, result) \ + uint64_t tpBegin; \ + NetTracePointIdWithLevel level = tracePoint; \ + std::string name = #tracePoint; \ + TraceDefer defer([level, name, &tpBegin]() { TRACE_DELAY_DEFER_BEGIN(level, name.c_str(), tpBegin); }, \ + [level, &result, &tpBegin]() { TRACE_DELAY_DEFER_END(level, (result), tpBegin); }) +} +} +#endif // HCOM_TRACE_H diff --git a/src/common/net_util.cpp b/src/common/net_util.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0615742512ecb82125eed2e73554a1fed6b00b50 --- /dev/null +++ b/src/common/net_util.cpp @@ -0,0 +1,82 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved. +// Author: bao + +#include +#include "net_common.h" +#include "net_util.h" + +namespace ock { +namespace hcom { +bool BuffToHexString(void *buff, uint32_t buffSize, std::string &out) +{ + static const std::string hex = "0123456789ABCDEF"; + + if (NN_UNLIKELY(buff == nullptr)) { + NN_LOG_ERROR("Invalid buff ptr for serialize as buff is nullptr"); + return false; + } + + if (NN_UNLIKELY(buffSize > UINT32_MAX / NN_NO2)) { + NN_LOG_ERROR("Invalid buff size as is over half of UINT32_MAX"); + return false; + } + + auto tmpBuff = reinterpret_cast(buff); + out.clear(); + out.reserve(buffSize * NN_NO2); + + for (uint32_t i = 0; i < buffSize; i++) { + // push back high 4 bit + out.push_back(hex[static_cast(tmpBuff[i]) >> NN_NO4]); + // push back low 4 bit + out.push_back(hex[static_cast(tmpBuff[i]) & 0xF]); + } + + return true; +} + +bool HexStringToBuff(const std::string &str, uint32_t buffSize, void *buff) +{ + if (NN_UNLIKELY(buff == nullptr)) { + NN_LOG_ERROR("Invalid buff ptr for serialize as buff is nullptr"); + return false; + } + + if (NN_UNLIKELY(buffSize > UINT32_MAX / NN_NO2) || NN_UNLIKELY(str.size() < buffSize * NN_NO2)) { + NN_LOG_ERROR("Invalid str or buff size is over half of UINT32_MAX"); + return false; + } + + auto tmpBuff = reinterpret_cast(buff); + for (uint32_t i = 0; i < buffSize * NN_NO2; i += NN_NO2) { + std::string byte = str.substr(i, NN_NO2); + char *remain = nullptr; + long value = strtol(byte.c_str(), &remain, NN_NO16); + if (NN_UNLIKELY(remain == nullptr || strlen(remain) > 0 || value > NN_NO255)) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to get value as " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return false; + } + tmpBuff[i / NN_NO2] = value; + } + + return true; +} + +uint32_t GenerateSecureRandomUint32() +{ + uint32_t rand = 0; + std::ifstream urandom("/dev/urandom", std::ios::in | std::ios::binary); + if (!urandom.is_open()) { + NN_LOG_ERROR("Failed to open urandom"); + } + urandom.read(reinterpret_cast(&rand), sizeof(uint32_t)); + if (!urandom) { + urandom.close(); + NN_LOG_ERROR("Failed to read from urandom"); + } + urandom.close(); + return rand; +} +} // namespace hcom +} // namespace ock diff --git a/src/common/net_util.h b/src/common/net_util.h new file mode 100644 index 0000000000000000000000000000000000000000..6603ce37974af3cbe46f4c4f0da1f9a556e72d22 --- /dev/null +++ b/src/common/net_util.h @@ -0,0 +1,180 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved. +// Author: bao + +#ifndef OCK_RDMA_UTIL_1233432457233_H +#define OCK_RDMA_UTIL_1233432457233_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hcom_err.h" +#include "hcom_num_def.h" +#include "hcom_def.h" +#include "hcom_log.h" + +namespace ock { +namespace hcom { +/* NetLocalAutoDecreasePtr */ +template class NetLocalAutoDecreasePtr { +public: + explicit NetLocalAutoDecreasePtr(T *obj) + { + if (obj != nullptr) { + mObj = obj; + mObj->IncreaseRef(); + } + } + + ~NetLocalAutoDecreasePtr() + { + if (mObj != nullptr) { + mObj->DecreaseRef(); + mObj = nullptr; + } + } + + inline T *Get() const + { + return mObj; + } + + /* + * @brief Set inner obj to null, after the inner obj will be free during de-constructor + */ + inline void SetNull() + { + mObj = nullptr; + } + + NetLocalAutoDecreasePtr() = delete; + NetLocalAutoDecreasePtr(const NetLocalAutoDecreasePtr &) = delete; + NetLocalAutoDecreasePtr(NetLocalAutoDecreasePtr &&) = delete; + // operator = + NetLocalAutoDecreasePtr &operator = (T *newObj) = delete; + NetLocalAutoDecreasePtr &operator = (const NetLocalAutoDecreasePtr &other) = delete; + NetLocalAutoDecreasePtr &operator = (NetLocalAutoDecreasePtr &&other) = delete; + +private: + T *mObj = nullptr; +}; + +/* NetLocalAutoFreePtr */ +template class NetLocalAutoFreePtr { +public: + explicit NetLocalAutoFreePtr(T *obj, bool isArray = false) : mObj(obj), mIsArray(isArray) {} + + ~NetLocalAutoFreePtr() + { + if (mObj == nullptr) { + return; + } + + if (mIsArray) { + delete[] mObj; + mObj = nullptr; + } else { + delete mObj; + mObj = nullptr; + } + } + + /* + * @brief Set inner obj to null, after the inner obj will be free during de-constructor + */ + inline void SetNull() + { + mObj = nullptr; + } + + /* + * @brief Get the inner obj ptr + */ + inline T *Get() const + { + return mObj; + } + + NetLocalAutoFreePtr() = delete; + NetLocalAutoFreePtr(const NetLocalAutoFreePtr &) = delete; + NetLocalAutoFreePtr(NetLocalAutoFreePtr &&) = delete; + // operator = + NetLocalAutoFreePtr &operator = (T *newObj) = delete; + NetLocalAutoFreePtr &operator = (const NetLocalAutoFreePtr &other) = delete; + NetLocalAutoFreePtr &operator = (NetLocalAutoFreePtr &&other) = delete; + +private: + T *mObj = nullptr; + bool mIsArray = false; +}; + +/// ScopeExit 主要功能为作用域退出时执行一些动作,常用于清理. +template class ScopeExit { +public: + ScopeExit(F f, bool active) : mHolder(std::move(f), active) + { + } + + ScopeExit(ScopeExit &&rhs) noexcept : mHolder(std::move(rhs.mHolder)) + { + rhs.Deactivate(); + } + + ScopeExit(const ScopeExit &) = delete; + ScopeExit &operator=(const ScopeExit &) = delete; + ScopeExit &operator=(ScopeExit &&) = delete; + + void Deactivate() + { + mHolder.mActive = false; + } + + bool Active() const + { + return mHolder.mActive; + } + + ~ScopeExit() + { + if (Active()) { + mHolder(); + } + } + +private: + struct FuncHolder : public F { + FuncHolder(F f, bool active) : F(std::move(f)), mActive(active) + { + } + + FuncHolder(FuncHolder &&rhs) noexcept : F(static_cast(rhs)), mActive(rhs.mActive) + { + rhs.mActive = false; + } + + FuncHolder& operator=(FuncHolder&&) = delete; + + bool mActive; + }; + + FuncHolder mHolder; +}; + +template auto MakeScopeExit(F f, bool active = true) -> ScopeExit +{ + return ScopeExit(std::move(f), active); +} + +bool BuffToHexString(void *buff, uint32_t buffSize, std::string &out); +bool HexStringToBuff(const std::string &str, uint32_t buffSize, void *buff); +uint32_t GenerateSecureRandomUint32(); +} // namespace hcom +} +#endif // OCK_RDMA_UTIL_1233432457233_H diff --git a/src/common/trace/htracer.cpp b/src/common/trace/htracer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4fad33b5efba6e920de76bc93af5b050b8d1fe57 --- /dev/null +++ b/src/common/trace/htracer.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "trace/htracer.h" +#include "htracer_manager.h" +#include "htracer_service.h" + +namespace ock { +namespace hcom { + +static HTracerService *traceService = nullptr; +#ifdef HTRACER_ENABLED +bool TraceManager::mEnable = true; +#else +bool TraceManager::mEnable = false; +#endif + +bool TraceManager::mLatencyQuantileEnable = false; + +std::string TraceManager::mDumpDir = ""; +std::string TraceManager::mDefaultDir = "/tmp/htrace/log"; +bool TraceManager::mDumpEnable = false; +static bool HtraceEnable(); + +HTRACE_INTF g_htraceIntf = {HtraceEnable, NULL, NULL, NULL, NULL}; +static bool g_htraceInit = false; + +static bool HtraceEnable() +{ + return TraceManager::IsEnable() && g_htraceInit; +} + +static void HtracerRegisterInterface(void) +{ + g_htraceIntf.DelayBegin = &TraceManager::DelayBegin; + g_htraceIntf.AsyncDelayBegin = &TraceManager::AsyncDelayBegin; + g_htraceIntf.DelayEnd = &TraceManager::DelayEnd; + g_htraceIntf.GetCurrentTimeNs = &TraceManager::GetTimeNs; +} + +int32_t HTracerInit(const std::string &serverName) +{ + if (traceService != nullptr) { + return SER_OK; + } + + HtracerRegisterInterface(); + + traceService = new (std::nothrow) HTracerService(); + if (traceService == nullptr) { + NN_LOG_WARN("[HTRACER] failed to malloc traceService"); + return SER_ERROR; + } + traceService->StartUp(serverName); + + auto ins = TraceManager::Instance(); + if (ins == nullptr) { + NN_LOG_WARN("[HTRACER] init trace manager instance failed"); + return SER_ERROR; + } + g_htraceInit = true; + return SER_OK; +} + +void HTracerExit(void) +{ + if (traceService != nullptr) { + traceService->ShutDown(); + delete traceService; + traceService = nullptr; + } + g_htraceInit = false; +} + +void EnableHtrace(bool enableTrace) +{ + TraceManager::SetEnable(enableTrace); +} + +} +} \ No newline at end of file diff --git a/src/common/trace/htracer.h b/src/common/trace/htracer.h new file mode 100644 index 0000000000000000000000000000000000000000..42509ffd21d048deb702a72bffed5bc6e7cdc4a0 --- /dev/null +++ b/src/common/trace/htracer.h @@ -0,0 +1,127 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HTRACER_H +#define HTRACER_H + +#include +#include +#include +namespace ock { +namespace hcom { + +#define INVALID_PORT (0xFFFF) +#define TRACE_ID(SERVICE_ID_, INNER_ID_) ((SERVICE_ID_) << 16 | ((INNER_ID_) & 0xFFFF)) + +using HTRACE_INTF = struct HTRACE_INTF_S { + bool (*IsEnable)(); + void (*DelayBegin)(uint32_t tpId, const char *tpName); + struct timespec (*AsyncDelayBegin)(uint32_t tpId, const char *tpName); + void (*DelayEnd)(uint32_t tpId, const uint64_t diff, int32_t retCode); + uint64_t (*GetCurrentTimeNs)(); +}; + +extern int32_t HTracerInit(const std::string &serverName); +extern void HTracerExit(void); +extern void EnableHtrace(bool); +extern HTRACE_INTF g_htraceIntf; + +#define TRACE_DELAY_BEGIN(TP_ID) \ + uint64_t tpBegin##TP_ID = 0; \ + if (g_htraceIntf.IsEnable()) { \ + g_htraceIntf.DelayBegin(TP_ID, #TP_ID); \ + tpBegin##TP_ID = g_htraceIntf.GetCurrentTimeNs(); \ + } + +#define TRACE_DELAY_END(TP_ID, RET_CODE) \ + if (g_htraceIntf.IsEnable()) { \ + g_htraceIntf.DelayEnd(TP_ID, g_htraceIntf.GetCurrentTimeNs() - tpBegin##TP_ID, RET_CODE); \ + } + +#define TRACE_DELAY_DEFER_BEGIN(TP_ID, TP_NAME, TP_BEGIN_TIME) \ + if (g_htraceIntf.IsEnable()) { \ + g_htraceIntf.DelayBegin(TP_ID, TP_NAME); \ + TP_BEGIN_TIME = g_htraceIntf.GetCurrentTimeNs(); \ + } + +#define TRACE_DELAY_DEFER_END(TP_ID, RET_CODE, TP_BEGIN_TIME) \ + if (g_htraceIntf.IsEnable()) { \ + g_htraceIntf.DelayEnd(TP_ID, g_htraceIntf.GetCurrentTimeNs() - (TP_BEGIN_TIME), RET_CODE); \ + } + +#define TRACE_DELAY_BEGIN_ASYNC(TP_ID, BEGINTIME) \ + uint64_t tpBegin##TP_ID = 0; \ + if (g_htraceIntf.IsEnable()) { \ + g_htraceIntf.DelayBegin(TP_ID, #TP_ID); \ + tpBegin##TP_ID = BEGINTIME; \ + } + +#define GET_TIME_NS() \ + ({ \ + struct timespec tpDelay = { 0, 0 }; \ + clock_gettime(CLOCK_MONOTONIC, &tpDelay); \ + (uint64_t)(tpDelay.tv_nsec + tpDelay.tv_sec * 1000000000ULL); \ + }) + +// NOTICE: will be deprecated, use TRACE_V2_DELAY_BEGIN +#define ASYNC_TRACE_DELAY_BEGIN(TP_ID) g_htraceIntf.AsyncDelayBegin(TP_ID, #TP_ID) + +// NOTICE: will be deprecated, use TRACE_V2_DELAY_END +#define ASYNC_TRACE_DELAY_END(TP_ID, RET_CODE, STARTTIME) \ + struct timespec tpEnd##TP_ID = { 0, 0 }; \ + bool traceEnabled##TP_ID = g_htraceIntf.IsEnable(); \ + if (traceEnabled##TP_ID) { \ + clock_gettime(CLOCK_MONOTONIC, &tpEnd##TP_ID); \ + long tpDiff##TP_ID; \ + long tpDiffSec##TP_ID = tpEnd##TP_ID.tv_sec - (STARTTIME).tv_sec; \ + if (tpDiffSec##TP_ID == 0) { \ + tpDiff##TP_ID = tpEnd##TP_ID.tv_nsec - (STARTTIME).tv_nsec; \ + } else { \ + tpDiff##TP_ID = tpDiffSec##TP_ID * 1000000000 + tpEnd##TP_ID.tv_nsec - (STARTTIME).tv_nsec; \ + } \ + g_htraceIntf.DelayEnd(TP_ID, tpDiff##TP_ID, RET_CODE); \ + } + +#define TRACE_V2_DELAY_BEGIN(TP_ID, P_U64_TIME_NS) \ + if (g_htraceIntf.IsEnable()) { \ + g_htraceIntf.DelayBegin(TP_ID, #TP_ID); \ + (*(P_U64_TIME_NS)) = g_htraceIntf.GetCurrentTimeNs(); \ + } + +#define TRACE_V2_DELAY_END(TP_ID, U64_TIME_NS, RET_CODE) \ + if (g_htraceIntf.IsEnable()) { \ + g_htraceIntf.DelayEnd(TP_ID, (g_htraceIntf.GetCurrentTimeNs() - (U64_TIME_NS)), RET_CODE); \ + } + +#define TRACE_CURRENT_TIME_NS g_htraceIntf.GetCurrentTimeNs() + +#define TRACE_RECORD_DELAY(TP_ID, U64_DIFF_TIME_NS, RET_CODE) \ + if (g_htraceIntf.IsEnable()) { \ + g_htraceIntf.DelayBegin(TP_ID, #TP_ID); \ + g_htraceIntf.DelayEnd(TP_ID, (U64_DIFF_TIME_NS), RET_CODE); \ + } + +#define TRACE_IOSIZE_BEGIN(TP_ID) \ + if (g_htraceIntf.IsEnable()) { \ + g_htraceIntf.DelayBegin(TP_ID, #TP_ID); \ + } + + +#define TRACE_IOSIZE_END(TP_ID, IOSIZE, RET_CODE) \ + if (g_htraceIntf.IsEnable()) { \ + g_htraceIntf.DelayEnd(TP_ID, ((IOSIZE)*1000ULL), RET_CODE); \ + } +} +} +#endif + +// HTRACER_H \ No newline at end of file diff --git a/src/common/trace/htracer_info.h b/src/common/trace/htracer_info.h new file mode 100644 index 0000000000000000000000000000000000000000..d859e2e8b16c528c626efb61c579ad078b13b2ff --- /dev/null +++ b/src/common/trace/htracer_info.h @@ -0,0 +1,202 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HTRACER_INFO_H +#define HTRACER_INFO_H + +#include +#include +#include +#include +#include +#include "htracer_utils.h" +#include "htracer_tdigest.h" +#include + +#define SERVICE_ID(TP_ID_) (((TP_ID_) >> 16) & 0xFFFF) +#define INNER_ID(TP_ID_) ((TP_ID_) & 0xFFFF) +#define INVALID_SERVICE_ID (0xFFFF) +#define MAX_SERVICE_NUM (256) +#define MAX_INNER_ID_NUM (2) + +namespace ock { +namespace hcom { + +class TraceInfo { +public: + __always_inline void DelayBegin(const char *tpName) + { + if (!isSetName) { + std::lock_guard lock(traceLock); + if (!isSetName) { + name = tpName; + isSetName = true; + } + } + + begin++; + } + + __always_inline void DelayEnd(uint64_t diff, int32_t retCode, bool lateQuantileEnable) + { + if (retCode != 0) { + badEnd++; + return; + } + if (diff < min) { + min = diff; + } + if (diff > max) { + max = diff; + } + + if (diff < periodMin) { + periodMin = diff; + } + if (diff > periodMax) { + periodMax = diff; + } + + if (lateQuantileEnable) { + tdigest.Insert(diff); + } + + total += diff; + goodEnd++; + } + + __always_inline void Reset() + { + begin = 0; + goodEnd = 0; + badEnd = 0; + min = UINT64_MAX; + max = 0; + total = 0; + tdigest.Reset(); + + latestBegin = 0; + latestGoodEnd = 0; + latestBadEnd = 0; + latestTotal = 0; + periodMin = UINT64_MAX; + periodMax = 0; + } + + __always_inline void RecordLatest() + { + latestBegin = begin; + latestGoodEnd = goodEnd; + latestBadEnd = badEnd; + latestTotal = total; + + periodMin = UINT64_MAX; + periodMax = 0; + } + + __always_inline const std::string GetName() const + { + return name; + } + + __always_inline void SetName(const std::string &name) + { + this->name = name; + } + + __always_inline uint64_t GetBegin() const + { + return begin; + } + + __always_inline uint64_t GetGoodEnd() const + { + return goodEnd; + } + + __always_inline uint64_t GetBadEnd() const + { + return badEnd; + } + + __always_inline uint64_t GetMin() const + { + return min; + } + + __always_inline uint64_t GetMax() const + { + return max; + } + + __always_inline uint64_t GetTotal() const + { + return total; + } + + __always_inline Tdigest GetTdigest() const + { + return tdigest; + } + + __always_inline bool Valid() const + { + return isSetName; + } + + __always_inline bool ValidPeriod() const + { + return (begin - latestBegin) > 0; + } + + std::string ToString() + { + return HTracerUtils::FormatString(name, begin, goodEnd, badEnd, min, max, total); + } + + std::string ToPeriodString() + { + uint64_t interBegin = begin - latestBegin; + uint64_t interGoodEnd = goodEnd - latestGoodEnd; + uint64_t interBadEnd = badEnd - latestBadEnd; + uint64_t interTotal = total - latestTotal; + uint64_t interMin = periodMin; + uint64_t interMax = periodMax; + RecordLatest(); + return HTracerUtils::FormatString(name, interBegin, interGoodEnd, interBadEnd, interMin, interMax, interTotal); + } + +private: + std::string name = ""; + volatile bool isSetName = false; + std::mutex traceLock; + + std::atomic begin = {0}; + std::atomic goodEnd = {0}; + std::atomic badEnd = {0}; + std::atomic min = {UINT64_MAX}; + std::atomic max = {0}; + std::atomic total = {0}; + Tdigest tdigest = Tdigest(20); + + uint64_t latestBegin = 0; + uint64_t latestGoodEnd = 0; + uint64_t latestBadEnd = 0; + uint64_t latestTotal = 0; + uint64_t periodMin = UINT64_MAX; + uint64_t periodMax = 0; +}; + +} +} + +#endif \ No newline at end of file diff --git a/src/common/trace/htracer_manager.h b/src/common/trace/htracer_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..6b600350edbb88f390f9ea6f80ae8fed4d0268e8 --- /dev/null +++ b/src/common/trace/htracer_manager.h @@ -0,0 +1,211 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HTRACER_MANAGER_H +#define HTRACER_MANAGER_H + +#include +#include +#include +#include +#include +#include "trace/htracer.h" +#include "htracer_info.h" +#include "hcom_err.h" +#include "hcom_log.h" + +namespace ock { +namespace hcom { +class TraceManager { +public: + static __always_inline TraceInfo **Instance() + { + static TraceInfo **tracePoints = CreateInstance(); + return tracePoints; + } + + static __always_inline void DelayBegin(uint32_t tpId, const char *tpName) + { + auto instance = Instance(); + uint16_t serviceId = SERVICE_ID(tpId); + uint16_t innerId = INNER_ID(tpId); + if (serviceId >= MAX_SERVICE_NUM || innerId >= MAX_INNER_ID_NUM) { + return; + } + instance[serviceId][innerId].DelayBegin(tpName); + } + + static __always_inline void DelayEnd(uint32_t tpId, const uint64_t diff, int32_t retCode) + { + auto instance = Instance(); + uint16_t serviceId = SERVICE_ID(tpId); + uint16_t innerId = INNER_ID(tpId); + if (serviceId >= MAX_SERVICE_NUM || innerId >= MAX_INNER_ID_NUM) { + return; + } + if (mDumpEnable) { + DumpTraceSplitInfo(instance[serviceId][innerId].GetName(), diff, retCode); + } + instance[serviceId][innerId].DelayEnd(diff, retCode, mLatencyQuantileEnable); + } + + static __always_inline timespec AsyncDelayBegin(uint32_t tpId, const char *tpName) + { + DelayBegin(tpId, tpName); + struct timespec tpDelay = {0, 0}; + clock_gettime(CLOCK_MONOTONIC, &tpDelay); + return tpDelay; + } + + static __always_inline uint64_t GetTimeNs() + { + struct timespec tpDelay = {0, 0}; + clock_gettime(CLOCK_MONOTONIC, &tpDelay); + return tpDelay.tv_sec * 1000000000ULL + tpDelay.tv_nsec; + } + + static __always_inline void SetEnable(bool enable) + { + mEnable = enable; + } + + static __always_inline bool IsEnable() + { + return mEnable; + } + + static __always_inline void SetLatencyQuantileEnable(bool enable) + { + mLatencyQuantileEnable = enable; + } + + static __always_inline bool IsLatencyQuantileEnable() + { + return mLatencyQuantileEnable; + } + + static __always_inline void SetEnableLog(bool enable, std::string &logPath) + { + if (!logPath.empty() && !mDumpEnable) { + GetLogPath(logPath); + } + if (enable && mDumpDir.empty()) { + int32_t ret = HTracerUtils::CreateDirectory(mDefaultDir); + if (ret != SER_OK) { + NN_LOG_WARN("[htracer], prepare dum dir failed, disable dump feature!"); + } + mDumpDir = mDefaultDir; + } + mDumpEnable = enable; + } + + static __always_inline bool IsEnableLog() + { + return mDumpEnable; + } + +private: + static __always_inline TraceInfo **CreateInstance() + { + TraceInfo **instance = new (std::nothrow) TraceInfo *[MAX_SERVICE_NUM]; + if (instance == nullptr) { + return nullptr; + } else { + auto ret = memset_s(instance, sizeof(TraceInfo *) * MAX_SERVICE_NUM, + 0x0, sizeof(TraceInfo *) * MAX_SERVICE_NUM); + if (ret != 0) { + NN_LOG_WARN("[HTRACER] Failed to memset_s to instance."); + delete[] instance; + instance = nullptr; + return nullptr; + } + } + + int ret = 0; + uint16_t i = 0; + for (i = 0; i < MAX_SERVICE_NUM; ++i) { + instance[i] = new (std::nothrow) TraceInfo[MAX_INNER_ID_NUM]; + if (instance[i] == nullptr) { + ret = -1; + break; + } + } + + if (ret != 0) { + for (uint16_t j = 0; j < i; ++j) { + delete[] instance[j]; + instance[j] = nullptr; + } + delete[] instance; + instance = nullptr; + return nullptr; + } + return instance; + } + + static __always_inline void DumpTraceSplitInfo(std::string tpName, const uint64_t diff, int32_t retCode) + { + std::stringstream ss; + std::string currentTime = HTracerUtils::CurrentTime(); + ss << currentTime << "|" << tpName << "|" << retCode << "|" << diff << "(ns)" << std::endl; + + std::string dumpPath = mDumpDir + "/htrace_" + std::to_string(getpid()) + ".log"; + // 创建文件并设置权限 0640 (rw-r-----) + int fd = open(dumpPath.c_str(), O_WRONLY | O_CREAT | O_APPEND, 0640); + if (fd == -1) { + return; + } + + std::ofstream dump; + dump.open(dumpPath, std::ios::out | std::ios::app); + if (!dump.is_open()) { + close(fd); + return; + } + + dump << ss.str(); + dump.flush(); + dump.close(); + close(fd); + } + + static void GetLogPath(std::string &path) + { + if (!HTracerUtils::CanonicalPath(path)) { + NN_LOG_WARN("[HTRACER] Log directory is invalid, use default path. "); + /* path is error, use old path */ + if (!mDumpDir.empty()) { + return; + } + /* if path error and old path is empty, use default path */ + int32_t ret = HTracerUtils::CreateDirectory(mDefaultDir); + if (ret != SER_OK) { + NN_LOG_WARN("[HTRACER] prepare dump dir failed, disable dump feature!"); + } + mDumpDir = mDefaultDir; + return; + } + mDumpDir = path; + return; + } + +private: + static bool mEnable; + static bool mLatencyQuantileEnable; + static bool mDumpEnable; + static std::string mDumpDir; + static std::string mDefaultDir; +}; +} +} + +#endif \ No newline at end of file diff --git a/src/common/trace/htracer_msg.h b/src/common/trace/htracer_msg.h new file mode 100644 index 0000000000000000000000000000000000000000..a6617447bedf338ad3badd4b464e49a3fa906583 --- /dev/null +++ b/src/common/trace/htracer_msg.h @@ -0,0 +1,284 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HTRACE_MSG_H +#define HTRACE_MSG_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include "rpc_msg.h" +#include "htracer_info.h" +#include "securec.h" +#include "htracer_tdigest.h" +#include "hcom_err.h" +#include "hcom_log.h" + +#define TRACE_INFO_MAX_LEN 63 + +namespace ock { +namespace hcom { + +constexpr uint32_t LOG_PATH_LENGTH = 260; + +enum MessageOpcode { + TRACE_OP_PING = 0, + TRACE_OP_QUERY = 1, + TRACE_OP_ENABLE_TRACE = 2, + TRACE_OP_RESET = 3 +}; + +struct TTraceInfo { + char name[TRACE_INFO_MAX_LEN + 1] = {0}; + uint64_t begin = 0; + uint64_t goodEnd = 0; + uint64_t badEnd = 0; + uint64_t min = UINT64_MAX; + uint64_t max = 0; + uint64_t total = 0; + double latencyQuentile = 0.0; + + explicit TTraceInfo(const char *name) + { + errno_t ret = strncpy_s(this->name, sizeof(this->name), name, + std::min(strlen(name), static_cast(TRACE_INFO_MAX_LEN))); + if (ret != EOK) { + NN_LOG_ERROR("[HTRACER] Failed to strncpy name, err: " << ret); + this->name[0] = '\0'; + } + } + + void operator += (const TTraceInfo &other) + { + begin += other.begin; + goodEnd += other.goodEnd; + badEnd += other.badEnd; + if (min >= other.min) { + min = other.min; + } + if (max <= other.max) { + max = other.max; + } + total += other.total; + } + + TTraceInfo(const TraceInfo &info, double quantile, bool enableTp) + { + errno_t ret = strncpy_s(this->name, sizeof(this->name), info.GetName().c_str(), + std::min(strlen(info.GetName().c_str()), static_cast(TRACE_INFO_MAX_LEN))); + if (ret != EOK) { + NN_LOG_ERROR("[HTRACER] Failed to strncpy name, err: " << ret); + this->name[0] = '\0'; + return; + } + begin = info.GetBegin(); + goodEnd = info.GetGoodEnd(); + badEnd = info.GetBadEnd(); + min = info.GetMin(); + max = info.GetMax(); + total = info.GetTotal(); + // get latency quantile + if (enableTp) { + latencyQuentile = -1.0; + if (quantile > 0 && quantile < NN_NO100) { + Tdigest tdigest = info.GetTdigest(); + tdigest.Merge(); + // "/1000" ns -> us + latencyQuentile = tdigest.Quantile(quantile)/NN_NO1000; + } + } + } + + enum TracePointTimeUnit { + NANO_SECOND, + MICRO_SECOND, + MILLI_SECOND, + SECOND, + TP_TIME_UNIT + }; + + std::string ToString(TracePointTimeUnit unit = MICRO_SECOND) const + { + static uint64_t TIME_UNIT_STEP[TP_TIME_UNIT] = { + 1, + NN_NO1000, + NN_NO1000000, + NN_NO1000000000 + }; + + static std::string TIME_UNIT_NAME[TP_TIME_UNIT] = { + "ns", + "us", + "ms", + "s" + }; + std::string str; + std::ostringstream os(str); + os.flags(std::ios::fixed); + os.precision(NN_NO3); + auto unitStep = TIME_UNIT_STEP[unit]; + auto unitName = TIME_UNIT_NAME[unit]; + os << "[" << std::left << std::setw(NN_NO50) << name << "]" + << "\t" << std::left << std::setw(NN_NO15) << begin << "\t" + << std::left << std::setw(NN_NO15) << goodEnd << "\t" + << std::left << std::setw(NN_NO15) << badEnd << "\t" + << std::left << std::setw(NN_NO15) + << ((begin > goodEnd - badEnd) ? (begin - goodEnd - badEnd) : 0) + << "\t" << std::left << std::setw(NN_NO15) + << (min == UINT64_MAX ? 0 : ((double)min / unitStep)) + << "\t" << std::left << std::setw(NN_NO15) + << (double)max / unitStep << "\t" << std::left << std::setw(NN_NO15) + << (goodEnd == 0 ? 0 : (double)total / goodEnd / unitStep) << "\t" + << std::left << std::setw(NN_NO15) + << (double)total / unitStep << "\t" << std::left << std::setw(NN_NO15) + << (latencyQuentile > 0 ? std::to_string(latencyQuentile) : "OFF"); + return os.str(); + } + + static std::string HeaderString() + { + std::stringstream ss; + ss << "\t[" << std::left << std::setw(NN_NO50) << "TP_NAME" + << "]" + << "\t" << std::left << std::setw(NN_NO15) << "TOTAL" + << "\t" << std::left << std::setw(NN_NO15) << "SUCCESS" + << "\t" << std::left << std::setw(NN_NO15) << "FAILURE" + << "\t" << std::left << std::setw(NN_NO15) << "UNFINISHED" + << "\t" << std::left << std::setw(NN_NO15) << "MIN(us)" + << "\t" << std::left << std::setw(NN_NO15) << "MAX(us)" + << "\t" << std::left << std::setw(NN_NO15) << "AVG(us)" + << "\t" << std::left << std::setw(NN_NO15) << "TOTAL(us)" + << "\t" << std::left << std::setw(NN_NO15) << "TPX(us)"; + return ss.str(); + } +}; + +struct ResetTraceInfoRequest : public MessageHeader { + ResetTraceInfoRequest() : MessageHeader(TRACE_OP_RESET) {} +}; + +struct ResetTraceInfoResponse : public MessageHeader { + ResetTraceInfoResponse() : MessageHeader(TRACE_OP_RESET) {} + + static SerCode BuildMessage(Message &message) + { + uint32_t messageSize = sizeof(ResetTraceInfoResponse); + void *messageData = malloc(messageSize); + if (messageData == nullptr) { + NN_LOG_WARN("[HTRACER] failed to malloc message data, size:" << messageSize); + return SER_ERROR; + } + bzero(messageData, messageSize); + + // fill message header. + auto queryResponse = reinterpret_cast(messageData); + queryResponse->version = VERSION; + queryResponse->magicCode = MAGIC_CODE; + queryResponse->crc = 0; + queryResponse->opcode = TRACE_OP_RESET; + queryResponse->bodySize = 0; + + message.SetMsg(messageData, messageSize); + + return SER_OK; + } +}; + +struct EnableTraceRequest : public MessageHeader { + bool enable = false; + bool enableTp = false; + bool enableLog = false; + char reserved[1] = {0}; + char logPath[LOG_PATH_LENGTH] = {0}; + EnableTraceRequest() : MessageHeader(TRACE_OP_ENABLE_TRACE) {} +}; + +struct EnableTraceResponse : public MessageHeader { + EnableTraceResponse() : MessageHeader(TRACE_OP_ENABLE_TRACE) {} + + static SerCode BuildMessage(Message &message) + { + uint32_t messageSize = sizeof(EnableTraceResponse); + void *messageData = malloc(messageSize); + if (messageData == nullptr) { + NN_LOG_WARN("[HTRACER] failed to malloc message data, size:" << messageSize); + return SER_ERROR; + } + bzero(messageData, messageSize); + + // fill message header. + auto queryResponse = reinterpret_cast(messageData); + queryResponse->version = VERSION; + queryResponse->magicCode = MAGIC_CODE; + queryResponse->crc = 0; + queryResponse->opcode = TRACE_OP_ENABLE_TRACE; + queryResponse->bodySize = 0; + + message.SetMsg(messageData, messageSize); + + return SER_OK; + } +}; + +struct QueryTraceInfoRequest : public MessageHeader { + uint16_t serviceId = INVALID_SERVICE_ID; + double quantile = 0; + QueryTraceInfoRequest() : MessageHeader(TRACE_OP_QUERY) {} +}; + +struct QueryTraceInfoResponse : public MessageHeader { + uint32_t traceInfoNum = 0; + pid_t pid = 0; + TTraceInfo traceInfo[0]; + + QueryTraceInfoResponse() : MessageHeader(TRACE_OP_QUERY) {} + + static SerCode BuildMessage(const std::vector &tTranceInfos, Message &message) + { + uint32_t bodySize = sizeof(uint32_t) + sizeof(pid_t) + sizeof(TTraceInfo) * tTranceInfos.size(); + uint32_t messageSize = sizeof(MessageHeader) + bodySize; + void *messageData = malloc(messageSize); + if (messageData == nullptr) { + NN_LOG_WARN("[HTRACER] failed to malloc message data, size:" << messageSize); + return SER_ERROR; + } + bzero(messageData, messageSize); + + // fill message header. + auto queryResponse = reinterpret_cast(messageData); + queryResponse->version = VERSION; + queryResponse->magicCode = MAGIC_CODE; + queryResponse->crc = 0; + queryResponse->opcode = TRACE_OP_QUERY; + queryResponse->bodySize = bodySize; + queryResponse->pid = getpid(); + queryResponse->traceInfoNum = tTranceInfos.size(); + + // file message body. + int i = 0; + for (const auto &info : tTranceInfos) { + queryResponse->traceInfo[i++] = info; + } + message.SetMsg(messageData, messageSize); + return SER_OK; + } +}; + +} +} + +#endif // HTRACE_MSG_H diff --git a/src/common/trace/htracer_service.cpp b/src/common/trace/htracer_service.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3161fd48c2da4f100fa4f129771d0b044efc6219 --- /dev/null +++ b/src/common/trace/htracer_service.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include +#include +#include +#include +#include "htracer_service.h" +#include "htracer_msg.h" +#include "htracer_service_helper.h" +#include "htracer_utils.h" +#include "htracer_manager.h" + +namespace ock { +namespace hcom { + +// todo 支持命令修改 +#ifdef HTRACER_DUMP_ENABLED +#define TRACE_DUMP_PERIOD 5 +bool HTracerService::mDumpEnable = true; +#else +#define TRACE_DUMP_PERIOD 60 +bool HTracerService::mDumpEnable = false; +#endif +int HTracerService::mDumpPeriod = TRACE_DUMP_PERIOD; +std::string HTracerService::dumpDir = "/tmp/htrace"; + +int32_t HTracerService::StartUp(const std::string &serverName) +{ + if (mRpcServer != nullptr) { + return SER_OK; + } + + std::unique_ptr mRpcServerTmp(new RpcServer()); // compatible with c++11 + mRpcServer = std::move(mRpcServerTmp); + if (mRpcServer == nullptr) { + NN_LOG_WARN("[HTRACER] failed to create rpc server"); + return SER_ERROR; + } + + mRpcServer->RegisterRequestHandler( + std::bind(&HTracerService::HandleRequest, this, std::placeholders::_1, std::placeholders::_2)); + + mRpcServer->RegisterSentResponse( + std::bind(&HTracerService::SentResponse, this, std::placeholders::_1, std::placeholders::_2)); + + if (mRpcServer->Start(serverName) != SER_OK) { + NN_LOG_WARN("[HTRACER] failed to start rpc server"); + return SER_ERROR; + } + + int32_t ret = HTracerUtils::CreateDirectory(dumpDir); + if (ret != SER_OK) { + NN_LOG_WARN("[HTRACER] prepare dump dir failed, disable dump feature!"); + mDumpEnable = false; + } + + mIsRunning = true; + return SER_OK; +} + +void HTracerService::ShutDown() +{ + mIsRunning = false; + if (mRpcServer != nullptr) { + mRpcServer->Stop(); + } +} + +SerCode HTracerService::HandleRequest(const Message &request, Message &response) +{ + auto header = request.GetHeader(); + if (header == nullptr) { + NN_LOG_WARN("[HTRACER] header is nullptr"); + return SER_ERROR; + } + + switch (header->opcode) { + case TRACE_OP_QUERY: { + auto queryRequest = reinterpret_cast(request.GetData()); + auto tTranceInfos = TracerServiceHelper::GetTraceInfos(queryRequest->serviceId, queryRequest->quantile, + TraceManager::IsLatencyQuantileEnable()); + SerCode ret = QueryTraceInfoResponse::BuildMessage(tTranceInfos, response); + if (ret != SER_OK) { + NN_LOG_WARN("[HTRACER] failed to build response message"); + return ret; + } + break; + } + case TRACE_OP_RESET: { + TracerServiceHelper::ResetTraceInfos(); + SerCode ret = ResetTraceInfoResponse::BuildMessage(response); + if (ret != SER_OK) { + NN_LOG_WARN("[HTRACER] failed to build response message"); + return ret; + } + break; + } + case TRACE_OP_ENABLE_TRACE: { + auto enableRequest = reinterpret_cast(request.GetData()); + TraceManager::SetEnable(enableRequest->enable); + TraceManager::SetLatencyQuantileEnable(enableRequest->enableTp); + std::string logPath(enableRequest->logPath); + TraceManager::SetEnableLog(enableRequest->enableLog, logPath); + SerCode ret = EnableTraceResponse::BuildMessage(response); + if (ret != SER_OK) { + NN_LOG_WARN("[HTRACER] failed to build response message"); + return ret; + } + break; + } + default: + break; + } + return SER_OK; +} + +void HTracerService::SentResponse(SerCode result, Message &response) +{ + void *data = response.GetData(); + if (data != nullptr) { + free(data); + data = nullptr; + } +} +} +} \ No newline at end of file diff --git a/src/common/trace/htracer_service.h b/src/common/trace/htracer_service.h new file mode 100644 index 0000000000000000000000000000000000000000..390b6925a6d87efc7bd81ca7c14ee99fd82805b4 --- /dev/null +++ b/src/common/trace/htracer_service.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HTRACE_SERVICE_H +#define HTRACE_SERVICE_H + +#include +#include +#include +#include +#include +#include "rpc_server.h" +#include "hcom_err.h" + +namespace ock { +namespace hcom { + +/*! + * trace_service + * 1. trace by service support + */ +class HTracerService { +public: + int32_t StartUp(const std::string &serverName); + + void ShutDown(); +private: + SerCode HandleRequest(const Message &request, Message &response); + void SentResponse(SerCode result, Message &response); + +private: + std::unique_ptr mRpcServer = nullptr; + std::condition_variable mDumpCond; + std::mutex mDumpLock; + volatile bool mIsRunning = false; + static int mDumpPeriod; + static std::string dumpDir; + static bool mDumpEnable; +}; + +} +} +#endif // HTRACE_SERVICE_H diff --git a/src/common/trace/htracer_service_helper.h b/src/common/trace/htracer_service_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..a50fcdf4473a8a4ccb00da7f5dea09835072eb6e --- /dev/null +++ b/src/common/trace/htracer_service_helper.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HTRACE_HELPER_H +#define HTRACE_HELPER_H + +#include +#include "htracer_manager.h" +#include "htracer_msg.h" + +namespace ock { +namespace hcom { + +class TracerServiceHelper { +public: + static std::vector GetTraceInfos(uint16_t serviceId, double quantile, bool enableTp) + { + std::vector retTraceInfos; + auto traceManager = TraceManager::Instance(); + if (serviceId == INVALID_SERVICE_ID) { + for (int i = 0; i < MAX_SERVICE_NUM; ++i) { + for (int j = 0; j < MAX_INNER_ID_NUM; ++j) { + auto &traceInfo = traceManager[i][j]; + if (traceInfo.Valid()) { + retTraceInfos.emplace_back(TTraceInfo(traceManager[i][j], quantile, enableTp)); + } + } + } + return retTraceInfos; + } + + if (serviceId > MAX_SERVICE_NUM) { + return retTraceInfos; + } + + for (int i = 0; i < MAX_INNER_ID_NUM; ++i) { + auto &traceInfo = traceManager[serviceId][i]; + if (traceInfo.Valid()) { + retTraceInfos.emplace_back(TTraceInfo(traceInfo, quantile, enableTp)); + } + } + return retTraceInfos; + } + + static void ResetTraceInfos() + { + std::vector retTraceInfos; + auto traceManager = TraceManager::Instance(); + for (int i = 0; i < MAX_SERVICE_NUM; ++i) { + for (int j = 0; j < MAX_INNER_ID_NUM; ++j) { + auto &traceInfo = traceManager[i][j]; + if (traceInfo.Valid()) { + traceManager[i][j].Reset(); + } + } + } + } +}; + +} +} + +#endif // HTRACE_HELPER_H diff --git a/src/common/trace/htracer_tdigest.h b/src/common/trace/htracer_tdigest.h new file mode 100644 index 0000000000000000000000000000000000000000..3ab80066bd7047cf6f4c3ef8f6cbae35a4d72054 --- /dev/null +++ b/src/common/trace/htracer_tdigest.h @@ -0,0 +1,365 @@ +/** +* Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. +* t-digest algorithm used to implement tpX(e.g tp99 tp 90) +*/ +#ifndef _HTRACER_3RDPARTY_T_DIGEST_H +#define _HTRACER_3RDPARTY_T_DIGEST_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "hcom_num_def.h" + +namespace ock { +namespace hcom { + +enum class InsertResultCode { + NO_NEED_COMPERSS, + NEED_COMPERSS +}; + +class Centroid { +public: + Centroid(double newMean, uint32_t newWeight) : mean(newMean), weight(newWeight) {} + bool operator < (const Centroid ¢roid) const + { + return this->mean < centroid.mean; + } + bool operator > (const Centroid ¢roid) const + { + return this->mean > centroid.mean; + } + double GetMean() const + { + return mean; + } + uint32_t GetWeight() const + { + return weight; + } + +private: + double mean; + uint32_t weight; +}; + +class CentroidList { +public: + explicit CentroidList(size_t size) + { + centroids.reserve(size); + totalWeight = 0; + } + + InsertResultCode Insert(double mean, uint32_t weight) + { + if (mean < 0 || mean > UINT32_MAX) { + return InsertResultCode::NO_NEED_COMPERSS; + } + Centroid centroid(mean, weight); + centroids.emplace_back(centroid); + totalWeight += weight; + if (centroids.size() < centroids.capacity()) { + return InsertResultCode::NO_NEED_COMPERSS; + } + return InsertResultCode::NEED_COMPERSS; + } + + void Reset() + { + centroids.clear(); + totalWeight = 0; + } + + size_t GetCentroidCount() const + { + return centroids.size(); + } + + size_t GetTotalWeight() const + { + return totalWeight; + } + + std::vector& GetAndSetCentroids() + { + return centroids; + } + +private: + std::vector centroids; + uint32_t totalWeight; +}; + +// k_3 尺度函数 +inline double ComputeNormalizer(double compression, double n) +{ + const uint32_t NN_NO21 = 21; + return compression / (NN_NO4 * std::log(n / compression) + NN_NO21); +} + +inline double QuantileToScale(double q, double normalizer) +{ + const double qMin = 1e-15; + const double qMax = 1 - qMin; + const double qMid = 0.5; + + if (q < qMin) { + return (NN_NO2 * QuantileToScale(qMin, normalizer)); + } else if (q > qMax) { + return (NN_NO2 * QuantileToScale(qMax, normalizer)); + } + + if (q <= qMid) { + return log(NN_NO2 * q) * normalizer; + } else { + return -log(NN_NO2 * (1 - q)) * normalizer; + } +} + +inline double ScaleToQuantile(double k, double normalizer) +{ + if (k <= 0) { + return exp(k / normalizer) / NN_NO2; + } else { + return 1 - exp(-k / normalizer) / NN_NO2; + } +} + +inline double Lerp(double a, double b, double t) noexcept +{ + return a + t * (b - a); +} + +struct CompressionState { + double k1 = 0; + double nextQLimitWeight = 0; + double weightSoFar = 0; + double weightToAdd = 0; + double meanToAdd = 0; + const uint32_t newTotalWeight = 0; + const double normalizer = 0; + + CompressionState(uint32_t totalWeight, double norm) + : newTotalWeight(totalWeight), normalizer(norm) + { + k1 = QuantileToScale(NN_NO0, normalizer); + nextQLimitWeight = newTotalWeight * ScaleToQuantile(k1 + NN_NO1, normalizer); + weightSoFar = 0; + weightToAdd = 0; + meanToAdd = 0; + } + + void InitializeFirstCentroid(const Centroid& first) + { + weightToAdd = first.GetWeight(); + meanToAdd = first.GetMean(); + } + + void UpdateQuantileLimit() + { + double camelBack = static_cast(weightSoFar) / static_cast(newTotalWeight); + k1 = QuantileToScale(camelBack, normalizer); + nextQLimitWeight = newTotalWeight * QuantileToScale(k1 + 1, normalizer); + } +}; + +class Tdigest { +public: + explicit Tdigest(size_t size) + : one(size), two(size), buffer(NN_NO2 * size), active(&one), + minValue(std::numeric_limits::max()), + maxValue(std::numeric_limits::lowest()) {} + + void Insert(double value, uint32_t weight = 1) + { + auto insert_result = buffer.Insert(value, weight); + if (insert_result == InsertResultCode::NEED_COMPERSS) { + Merge(); + } + } + + void Reset() + { + one.Reset(); + two.Reset(); + buffer.Reset(); + active = &one; + minValue = std::numeric_limits::max(); + maxValue = std::numeric_limits::lowest(); + } + + void Merge() + { + std::vector &input = buffer.GetAndSetCentroids(); + if (input.empty() || NULL == active) { + return; + } + // 准备和排序数据 + PrepareAndSortData(input); + // 更新最小值和最大值 + UpdateMinMaxValues(input); + // 压缩合并数据 + auto &inactive = (&one == active) ? two : one; + CompressData(input, inactive); + // 清理和准备下一轮 + CleanUpAndPrepareNextRound(input, inactive); + } + + double Quantile(double p) const + { + if (p < 0 || p > NN_NO100) { + return 0.0; + } + if (nullptr == active) { + return 0.0; + } + if (active->GetAndSetCentroids().empty()) { + return 0.0; + } + if (active->GetAndSetCentroids().size() == 1) { + return (active->GetAndSetCentroids().front().GetMean()); + } + uint32_t index = (p / NN_NO100) * active->GetTotalWeight(); + // 处理边界情况 + if (index < NN_NO1) { + return minValue; + } + if (index > active->GetTotalWeight() - NN_NO1) { + return maxValue; + } + // 处理首位质心特殊插值 + const auto &first = active->GetAndSetCentroids().front(); + if (first.GetWeight() > NN_NO1 && index < (first.GetWeight() / NN_NO2)) { + return (Lerp(minValue, first.GetMean(), + static_cast(index - NN_NO1) / (first.GetWeight() / NN_NO2 - NN_NO1))); + } + + const auto &last = active->GetAndSetCentroids().back(); + if (last.GetWeight() > NN_NO1 && active->GetTotalWeight() - index <= last.GetWeight() / NN_NO2) { + return (maxValue - static_cast(active->GetTotalWeight() - index - NN_NO1) / + (last.GetWeight() / NN_NO2 - NN_NO1) * (maxValue - last.GetMean())); + } + + // 在质心对中查找中位数 + uint32_t currentWeight = first.GetWeight() / 2; + for (size_t i = 0; i < active->GetAndSetCentroids().size() - 1; i++) { + const auto &left = active->GetAndSetCentroids()[i]; + const auto &right = active->GetAndSetCentroids()[i + 1]; + uint32_t segmentWeight = (left.GetWeight() + right.GetWeight()) / 2; + if (currentWeight + segmentWeight > index) { + uint32_t lower = index - currentWeight; + uint32_t upper = currentWeight + segmentWeight - index; + return (left.GetMean() * upper + right.GetMean() *lower) / (lower + upper); + } + currentWeight += segmentWeight; + } + return active->GetAndSetCentroids().back().GetMean(); + } + +private: + void PrepareAndSortData(std::vector &input) + { + if (forward) { + std::sort(input.begin(), input.end(), std::less()); + } else { + std::sort(input.begin(), input.end(), std::greater()); + } + } + void UpdateMinMaxValues(const std::vector &input) + { + if (forward) { + UpdateMinMaxForward(input); + } else { + UpdateMinMaxBackward(input); + } + } + void UpdateMinMaxForward(const std::vector &input) + { + minValue = std::min(minValue, + input.front().GetWeight() == NN_NO1 ? input.front().GetMean() : + std::numeric_limits::max()); + maxValue = std::max(maxValue, + input.back().GetWeight() == NN_NO1 ? input.back().GetMean() : + std::numeric_limits::min()); + } + void UpdateMinMaxBackward(const std::vector &input) + { + minValue = std::min(minValue, + input.back().GetWeight() == NN_NO1 ? input.back().GetMean() : + std::numeric_limits::max()); + maxValue = std::max(maxValue, + input.front().GetWeight() == NN_NO1 ? input.front().GetMean() : + std::numeric_limits::min()); + } + void CompressData(const std::vector &input, CentroidList &inactive) + { + const uint32_t newTotalWeight = buffer.GetTotalWeight() + active->GetTotalWeight(); + const double normalizer = ComputeNormalizer(inactive.GetAndSetCentroids().capacity(), newTotalWeight); + CompressionState state(newTotalWeight, normalizer); + // 初始化第一个质心 + state.InitializeFirstCentroid(input.front()); + // 处理剩余质心 + for (auto it = input.begin() + 1; it != input.end(); ++it) { + ProcessCentroid(*it, state, inactive); + } + // 处理最后一个质心 + if (state.weightToAdd != 0) { + if (std::is_integral::value) { + state.meanToAdd = std::round(state.meanToAdd); + } + inactive.Insert(state.meanToAdd, state.weightToAdd); + } + } + void CleanUpAndPrepareNextRound(std::vector &input, CentroidList &inactive) + { + if (!forward) { + std::sort(inactive.GetAndSetCentroids().begin(), inactive.GetAndSetCentroids().end()); + } + forward = !forward; + buffer.Reset(); + input.assign(inactive.GetAndSetCentroids().begin(), inactive.GetAndSetCentroids().end()); + auto newInactive = active; + active = &inactive; + newInactive->Reset(); + } + void ProcessCentroid(const Centroid ¤t, CompressionState &state, CentroidList &inactive) + { + if ((state.weightSoFar + state.weightToAdd + current.GetWeight()) <= state.nextQLimitWeight) { + // 合并到当前质心 + state.weightToAdd += current.GetWeight(); + state.meanToAdd = state.meanToAdd + + (current.GetMean() - state.meanToAdd) * current.GetWeight() / state.weightToAdd; + } else { + // 开始新的质心 + state.weightSoFar += state.weightToAdd; + state.UpdateQuantileLimit(); + if (std::is_integral::value) { + state.meanToAdd = std::round(state.meanToAdd); + } + inactive.Insert(state.meanToAdd, state.weightToAdd); + state.meanToAdd = current.GetMean(); + state.weightToAdd = current.GetWeight(); + } + } + +private: + CentroidList one; + CentroidList two; + CentroidList buffer; + CentroidList *active; + double minValue; + double maxValue; + bool forward = true; +}; + +} +} + +#endif // _HTRACER_3RDPARTY_T_DIGEST_H \ No newline at end of file diff --git a/src/common/trace/htracer_utils.h b/src/common/trace/htracer_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3eceff9064fb331993f51c26de466376a7bb46a4 --- /dev/null +++ b/src/common/trace/htracer_utils.h @@ -0,0 +1,126 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HTRACE_UTILS_H +#define HTRACE_UTILS_H + +#include "securec.h" +#include "hcom_num_def.h" +#include "htracer_tdigest.h" +#include "net_common.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ock { +namespace hcom { +class HTracerUtils { +public: + static std::string &StrTrim(std::string &str) + { + if (str.empty()) { + return str; + } + str.erase(0, str.find_first_not_of(" ")); + str.erase(str.find_last_not_of(" ") + 1); + return str; + } + + static std::string CurrentTime() + { + time_t rawTime; + time(&rawTime); + auto tmInfo = localtime(&rawTime); + std::stringstream ss; + ss << std::setfill('0') << std::setw(NN_NO4) << std::right << (NN_NO1900 + tmInfo->tm_year) << "-" << + std::setfill('0') << std::setw(NN_NO2) << std::right << (NN_NO1 + tmInfo->tm_mon) << "-" << + std::setfill('0') << std::setw(NN_NO2) << std::right << tmInfo->tm_mday << " " << + std::setfill('0') << std::setw(NN_NO2) << std::right << tmInfo->tm_hour << ":" << + std::setfill('0') << std::setw(NN_NO2) << std::right << tmInfo->tm_min << ":" << + std::setfill('0') << std::setw(NN_NO2) << std::right << tmInfo->tm_sec; + return ss.str(); + } + + static std::string FormatString(std::string &name, uint64_t begin, uint64_t goodEnd, uint64_t badEnd, uint64_t min, + uint64_t max, uint64_t total) + { + std::string str; + std::ostringstream os(str); + os.flags(std::ios::fixed); + os.precision(NN_NO3); + auto unitStep = NN_NO1000; + os << std::left << std::setw(NN_NO50) << name + << "\t" << std::left << std::setw(NN_NO15) << begin + << "\t" << std::left << std::setw(NN_NO15) << goodEnd + << "\t" << std::left << std::setw(NN_NO15) << badEnd + << "\t" << std::left << std::setw(NN_NO15) << ((begin > goodEnd - badEnd) ? (begin - goodEnd - badEnd) : 0) + << "\t" << std::left << std::setw(NN_NO15) << (min == UINT64_MAX ? 0 : ((double)min / unitStep)) + << "\t" << std::left << std::setw(NN_NO15) << (double)max / unitStep + << "\t" << std::left << std::setw(NN_NO15) << (goodEnd == 0 ? 0 : (double)total / goodEnd / unitStep) + << "\t" << std::left << std::setw(NN_NO15) << (double)total / unitStep; + return os.str(); + } + + static int CreateDirectory(const std::string &name) + { + std::vector paths; + NetFunc::NN_SplitStr(name, "/", paths); + int32_t ret = 0; + std::string pathTmp; + for (auto &item : paths) { + if (item.empty()) { + continue; + } + + pathTmp += "/" + item; + if (access(pathTmp.c_str(), F_OK) != 0) { + mode_t old_mask = umask(0); + ret = mkdir(pathTmp.c_str(), S_IRWXU | S_IRGRP | S_IXGRP); + umask(old_mask); + if (ret != 0 && errno != EEXIST) { + break; + } + } + } + return ret; + } + + /* * + * @brief Check whether the path is canonical, and canonical it. + */ + static bool CanonicalPath(std::string &path) + { + if (path.empty() || path.size() > PATH_MAX) { + return false; + } + + /* It will allocate memory to store path */ + char *realPath = realpath(path.c_str(), nullptr); + if (realPath == nullptr) { + return false; + } + + path = realPath; + free(realPath); + realPath = nullptr; + return true; + } +}; +} +} +#endif // HTRACE_UTILS_H diff --git a/src/common/trace/rpc_msg.h b/src/common/trace/rpc_msg.h new file mode 100644 index 0000000000000000000000000000000000000000..bcde1ac96bab2985779fce7067935e76c5336f2a --- /dev/null +++ b/src/common/trace/rpc_msg.h @@ -0,0 +1,91 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef RPC_MSG_H +#define RPC_MSG_H +#include + +#define VERSION 1 +#define MAGIC_CODE 0xABABABAB +#define INVALID_OPCODE 0xFFFFFFFF + +namespace ock { +namespace hcom { + +struct MessageHeader { + uint32_t version = VERSION; + uint32_t magicCode = MAGIC_CODE; + uint32_t crc = 0; + uint32_t opcode = INVALID_OPCODE; + uint32_t bodySize = 0; + uint32_t reserved = 0; + explicit MessageHeader(uint32_t opcode) : opcode(opcode) {} +}; + +class Message { +public: + Message(void *data, uint32_t dataSize) : mData(data), mSize(dataSize) {} + Message() : Message(nullptr, 0) {} + + void *GetData() const + { + return mData; + } + + void SetMsg(void *data, uint32_t size) + { + mData = data; + mSize = size; + } + + uint32_t GetSize() const + { + return mSize; + } + + const MessageHeader *GetHeader() const + { + if (mData == nullptr) { + return nullptr; + } + return reinterpret_cast(mData); + } + +private: + void *mData = nullptr; + uint32_t mSize = 0; +}; + +class MessageValidator { +public: + static bool Validate(const Message &message) + { + void *messageData = message.GetData(); + uint32_t messageSize = message.GetSize(); + if (messageData == nullptr || messageSize == 0) { + return false; + } + + MessageHeader *header = reinterpret_cast(messageData); + if (header->version != VERSION || + header->magicCode != MAGIC_CODE + || header->bodySize + sizeof(MessageHeader) > messageSize) { + return false; + } + return true; + } +}; + +} +} + +#endif // RPC_MSG_H \ No newline at end of file diff --git a/src/common/trace/rpc_server.cpp b/src/common/trace/rpc_server.cpp new file mode 100644 index 0000000000000000000000000000000000000000..adaab09f1d458f4356d92db6e0cd383077b400a4 --- /dev/null +++ b/src/common/trace/rpc_server.cpp @@ -0,0 +1,150 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "rpc_server.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "rpc_msg.h" +#include "securec.h" +#include "hcom_log.h" + +#define MAX_CONNECT_NUM (2) + +namespace ock { +namespace hcom { +SerCode RpcServer::Start(const std::string &serverName) +{ + // create listen socket; + mSockFd = ::socket(AF_UNIX, SOCK_STREAM, 0); + if (mSockFd == -1) { + NN_LOG_WARN("[HTRACER] failed to create sock"); + return SER_ERROR; + } + + std::string abstractSockName(1, '\0'); + abstractSockName += serverName; + + struct sockaddr_un un; + auto ret = memset_s(&un, sizeof(un), 0, sizeof(un)); + if (ret != 0) { + NN_LOG_WARN("[HTRACER] failed to memset_s sockaddr un"); + close(mSockFd); + return SER_ERROR; + } + un.sun_family = AF_UNIX; + ret = memcpy_s(un.sun_path, abstractSockName.length() + 1, abstractSockName.c_str(), abstractSockName.length() + 1); + if (ret != 0) { + NN_LOG_WARN("[HTRACER] failed to memcpy_s to sun_path"); + close(mSockFd); + return SER_ERROR; + } + + if (bind(mSockFd, reinterpret_cast(&un), sizeof(un)) < 0) { + NN_LOG_WARN("[HTRACER] failed to bind socket"); + close(mSockFd); + return SER_ERROR; + } + + if (listen(mSockFd, MAX_CONNECT_NUM) < 0) { + NN_LOG_WARN("[HTRACER] listen failed"); + std::cout<<"failed to listen"< +#include +#include +#include "rpc_msg.h" +#include "hcom_err.h" + +namespace ock { +namespace hcom { + +using RequestHandler = std::function; +using SentResponse = std::function; + +class RpcServer { +public: + RpcServer() {} + + void RegisterRequestHandler(const RequestHandler requestHandler) + { + mRequestHandler = requestHandler; + } + + void RegisterSentResponse(const SentResponse sentResponse) + { + mSentResponse = sentResponse; + } + + SerCode Start(const std::string &serverName); + + void Stop(); + + uint16_t GetPort() + { + return mPort; + } + +private: + RequestHandler mRequestHandler = nullptr; + SentResponse mSentResponse = nullptr; + int32_t mSockFd = -1; + bool mRunning = true; + uint16_t mPort = 0xFFFF; +}; + +} +} +#endif // RPC_SERVER_H diff --git a/src/hcom.cpp b/src/hcom.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a8ff1cfb301a0d6c41e7ac7371feb5663a5df604 --- /dev/null +++ b/src/hcom.cpp @@ -0,0 +1,1614 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "hcom.h" + +#include +#include +#include + +#include "hcom_def.h" +#include "hcom_log.h" +#include "securec.h" +#include "code_msg.h" +#include "common/net_common.h" +#include "net_mem_allocator.h" +#include "net_mem_allocator_cache.h" +#include "net_oob.h" +#include "net_oob_ssl.h" +#include "hcom_obj_statistics.h" +#include "net_rdma_driver_oob.h" +#include "net_trace.h" +#include "net_sock_driver_oob.h" +#include "net_shm_driver_oob.h" +#include "verbs_api_wrapper.h" +#include "trace/htracer.h" + +#ifdef UB_BUILD_ENABLED +#include "net_ub_driver_oob.h" +#include "obmm_api_wrapper.h" +#endif + +namespace ock { +namespace hcom { +constexpr const int KP_ID = 0x48; +namespace { +// SIGPIPE will be triggered when sending data to a closed socket +struct HcomInit { + HcomInit() noexcept + { + std::signal(SIGPIPE, SIG_IGN); + } +} g_hcomInitializer; +} // namespace + +const UBSHcomNetWorkerIndex &UBSHcomNetEndpoint::WorkerIndex() const +{ + return mWorkerIndex; +} + +bool UBSHcomNetEndpoint::IsEstablished() +{ + return mState.Compare(NEP_ESTABLISHED); +} + +const std::string &UBSHcomNetEndpoint::PeerConnectPayload() const +{ + return mPayload; +} + +uint32_t UBSHcomNetEndpoint::LocalIp() const +{ + return mLocalIp; +} + +uint16_t UBSHcomNetEndpoint::ListenPort() const +{ + return mListenPort; +} + +uint8_t UBSHcomNetEndpoint::Version() const +{ + return mVersion; +} + +void UBSHcomNetEndpoint::DefaultTimeout(int32_t timeout) +{ + if (NN_UNLIKELY(timeout > static_cast(NN_NO65536))) { + NN_LOG_WARN("Invalid operation to set timeout, the time is less than 65536."); + return; + } + mDefaultTimeout = timeout; +} + +void UBSHcomNetEndpoint::StoreConnInfo(uint32_t localIp, uint16_t listenPort, uint8_t version, + const std::string &payload) +{ + mLocalIp = localIp; + mListenPort = listenPort; + mVersion = version; + mPayload = payload; +} + +void UBSHcomNetEndpoint::Payload(const std::string &payload) +{ + mPayload = payload; +} + +void UBSHcomNetEndpoint::RemoteUdsIdInfo(uint32_t pid, uint32_t uid, uint32_t gid) +{ + mRemoteUdsIdInfo = UBSHcomNetUdsIdInfo(pid, uid, gid); +} + +NResult UBSHcomNetMemoryAllocator::Create(UBSHcomNetMemoryAllocatorType t, + const UBSHcomNetMemoryAllocatorOptions &options, UBSHcomNetMemoryAllocatorPtr &out) +{ + if (t == DYNAMIC_SIZE) { + NetLocalAutoDecreasePtr alloc(new (std::nothrow) NetMemAllocator()); + if (alloc.Get() == nullptr) { + NN_LOG_ERROR("Failed to new memory allocator obj with type '" << + UBSHcomNetMemoryAllocatorTypeToString(t) << "'"); + return NN_NEW_OBJECT_FAILED; + } + + auto ret = alloc.Get()->Initialize(options.address, options.size, options.minBlockSize, options.alignedAddress); + if (ret != NN_OK) { + NN_LOG_ERROR("Failed to initialize allocator obj with type '" << + UBSHcomNetMemoryAllocatorTypeToString(t) << "'"); + return NN_ERROR; + } + + out.Set(alloc.Get()); + + return NN_OK; + } else if (t == DYNAMIC_SIZE_WITH_CACHE) { + NetLocalAutoDecreasePtr alloc(new (std::nothrow) NetMemAllocator()); + if (alloc.Get() == nullptr) { + NN_LOG_ERROR("Failed to new memory allocator with type '" << + UBSHcomNetMemoryAllocatorTypeToString(t) << "'"); + return NN_NEW_OBJECT_FAILED; + } + + auto ret = alloc.Get()->Initialize(options.address, options.size, options.minBlockSize, options.alignedAddress); + if (ret != NN_OK) { + NN_LOG_ERROR("Failed to initialize allocator with type '" << + UBSHcomNetMemoryAllocatorTypeToString(t) << "'"); + return NN_ERROR; + } + + NetLocalAutoDecreasePtr cache(new (std::nothrow) NetAllocatorCache(alloc.Get())); + if (cache.Get() == nullptr) { + NN_LOG_ERROR("Failed to new memory allocator cache with type '" << + UBSHcomNetMemoryAllocatorTypeToString(t) << "'"); + return NN_NEW_OBJECT_FAILED; + } + + ret = cache.Get()->Initialize(options); + if (ret != NN_OK) { + NN_LOG_ERROR("Failed to initialize allocator cache with type '" << + UBSHcomNetMemoryAllocatorTypeToString(t) << + "'"); + return NN_ERROR; + } + + out.Set(cache.Get()); + + return NN_OK; + } + + NN_LOG_ERROR("Invalid net memory allocator type " << t); + return NN_ERROR; +} + +bool UBSHcomNetOobListenerOptions::SetEid(const std::string &eid, uint16_t id, uint16_t twc) +{ + port = id; + targetWorkerCount = twc; + return HexStringToBuff(eid, NN_NO16, ip); +} + +bool UBSHcomNetOobListenerOptions::Set(const std::string &pIp, uint16_t pp, uint16_t twc) +{ + if (NN_UNLIKELY(Ip(pIp) == NN_ERROR)) { + return false; + } + port = pp; + targetWorkerCount = twc; + return true; +} + +bool UBSHcomNetOobListenerOptions::Set(const std::string &pIp, uint16_t pp) +{ + return Set(pIp, pp, UINT16_MAX); +} + +bool UBSHcomNetOobListenerOptions::Set(uint16_t pp, uint16_t twc) +{ + port = pp; + targetWorkerCount = twc; + return true; +} + +NResult UBSHcomNetOobListenerOptions::Ip(const std::string &value) +{ + if (NN_LIKELY(UBSHcomNetCloneStringToArray(ip, sizeof(ip), value))) { + return NN_OK; + } + + return NN_ERROR; +} + +std::string UBSHcomNetOobListenerOptions::Ip() const +{ + return NN_CHAR_ARRAY_TO_STRING(ip); +} + +bool UBSHcomNetOobUDSListenerOptions::Set(const std::string &pName, uint16_t twc) +{ + if (NN_UNLIKELY(!Name(pName))) { + return false; + } + targetWorkerCount = twc; + return true; +} + +bool UBSHcomNetOobUDSListenerOptions::Name(const std::string &value) +{ + NN_SET_CHAR_ARRAY_FROM_STRING(name, value); +} + +std::string UBSHcomNetOobUDSListenerOptions::Name() const +{ + return NN_CHAR_ARRAY_TO_STRING(name); +} + +uint32_t UBSHcomNetDriver::gMaxListenPort = NN_NO16; +uint8_t UBSHcomNetDriver::gDriverIndex = 0; +std::mutex UBSHcomNetDriver::gDriverMapMutex; +std::map UBSHcomNetDriver::gDriverMap; +int32_t UBSHcomNetDriver::gOSMaxFdCount = -1; + +NResult UBSHcomNetDriver::ValidateKunpeng() +{ + std::ifstream file; + file.open("/sys/devices/system/cpu/cpu0/regs/identification/midr_el1"); + if (!file) { + NN_LOG_ERROR("Failed to new driver, sys file cannot be open"); + return NN_ERROR; + } + std::string line; + getline(file, line); + int machineID = 0; + try { + machineID = std::stoi(line, nullptr, NN_NO16) >> NN_NO24; + } catch (...) { + NN_LOG_ERROR("Failed to new driver, as stoi failed"); + } + file.close(); + if (machineID != KP_ID) { + NN_LOG_ERROR("Failed to new driver, CPU company id is invalid"); + return NN_ERROR; + } + + return NN_OK; +} + +UBSHcomNetDriver *UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol t, const std::string &name, bool startOobSvr) +{ +#ifdef ENABLE_ARM_KP + if (NN_UNLIKELY(ValidateKunpeng() != NN_OK)) { + return nullptr; + } +#endif + if (NN_UNLIKELY(NetFunc::NN_ValidateName(name) != NN_OK)) { + return nullptr; + } + + UBSHcomNetDriver *driver = nullptr; + + auto envString = getenv("HCOM_ENABLE_TRACE"); + long level = 0; + if (envString != nullptr && NetFunc::NN_Stol(envString, level) && level > LEVEL0) { + NetTrace::Instance(); + NetTrace::HtraceInit(name); + } + + std::lock_guard locker(gDriverMapMutex); + auto iter = gDriverMap.find(name); + if (iter != gDriverMap.end()) { + NN_LOG_WARN("Driver named " << name << " is already existed, the existed one will be returned"); + return iter->second; + } + + switch (t) { + case UBSHcomNetDriverProtocol::RDMA: +#ifdef RDMA_BUILD_ENABLED + if (HcomIbv::Load() != 0) { + NN_LOG_ERROR("Failed to load verbs API"); + return nullptr; + } + + driver = new (std::nothrow) NetDriverRDMAWithOob(name, startOobSvr, t); + break; +#else + NN_LOG_ERROR("Failed to new driver, RDMA not enabled"); + return nullptr; +#endif + case UBSHcomNetDriverProtocol::UBC: +#ifdef UB_BUILD_ENABLED + if (HcomUrma::Load() != 0) { + NN_LOG_ERROR("Failed to load urma API"); + return nullptr; + } + driver = new (std::nothrow) NetDriverUBWithOob(name, startOobSvr, t); + break; +#else + NN_LOG_ERROR("Failed to new driver, UB not enabled"); + return nullptr; +#endif + + case UBSHcomNetDriverProtocol::TCP: +#ifdef SOCK_BUILD_ENABLED + driver = new (std::nothrow) NetDriverSockWithOOB(name, startOobSvr, t, SockType::SOCK_TCP); + break; +#else + NN_LOG_ERROR("Failed to new driver, TCP not enabled"); + return nullptr; +#endif + case UBSHcomNetDriverProtocol::UDS: +#ifdef SOCK_BUILD_ENABLED + driver = new (std::nothrow) NetDriverSockWithOOB(name, startOobSvr, t, SockType::SOCK_UDS); + break; +#else + NN_LOG_ERROR("Failed to new driver, UDS not enabled"); + return nullptr; +#endif + case UBSHcomNetDriverProtocol::SHM: +#ifdef SHM_BUILD_ENABLED + driver = new (std::nothrow) NetDriverShmWithOOB(name, startOobSvr, t); + break; +#else + NN_LOG_ERROR("Failed to new driver, SHM not enabled"); + return nullptr; +#endif + default: + NN_LOG_ERROR("Failed to new driver " << name << " for " << UBSHcomNetDriverProtocolToString(t) << + ", not implemented yet"); + break; + } + + if (driver != nullptr) { + driver->IncreaseRef(); + driver->mIndex = gDriverIndex++; + std::tie(iter, std::ignore) = gDriverMap.emplace(name, driver); + } else { + NN_LOG_ERROR("Failed to new driver " << name << " for " << UBSHcomNetDriverProtocolToString(t) << + ", probably out of memory"); + return nullptr; + } + +#ifdef HCOM_COMMIT_ID + NN_LOG_INFO("hcom build commit: " << HCOM_COMMIT_ID); +#endif + +#ifdef HCOM_COMPONENT_VERSION + NN_LOG_INFO("Hcom version :" << HCOM_COMPONENT_VERSION); + std::string ComponentVersion = HCOM_COMPONENT_VERSION; + std::vector versions; + NetFunc::NN_SplitStr(ComponentVersion, ".", versions); + + if (versions.size() < NN_NO2) { + NN_LOG_ERROR("parsing version failed!"); + gDriverMap.erase(iter); + delete driver; + driver = nullptr; + return nullptr; + } + + long version; + if (NetFunc::NN_Stol(versions[0], version)) { + driver->mMajorVersion = version; + } + + if (NetFunc::NN_Stol(versions[1], version)) { + driver->mMinorVersion = version; + } +#endif + + gOSMaxFdCount = static_cast(sysconf(_SC_OPEN_MAX)); + if (NN_UNLIKELY(gOSMaxFdCount == -1)) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_WARN("Unable to get limit of open files, errno: " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + } else { + NN_LOG_INFO("Limit of open files is " << gOSMaxFdCount << ", please check if it is big enough"); + } + + return driver; +} + +NResult UBSHcomNetDriver::DestroyInstance(const std::string &name) +{ + if (NN_UNLIKELY(NetFunc::NN_ValidateName(name) != NN_OK)) { + return NN_ERROR; + } + UBSHcomNetDriver *driver = nullptr; + { + std::lock_guard locker(gDriverMapMutex); + auto iter = gDriverMap.find(name); + if (NN_UNLIKELY((iter == gDriverMap.end()) || (iter->second == nullptr))) { + NN_LOG_ERROR("Failed to destroy driver, because " << name << "driver was not found or does not exist"); + return NN_ERROR; + } + + driver = iter->second; + if (NN_UNLIKELY(driver->IsInited() || driver->IsStarted())) { + NN_LOG_ERROR("Please stop or unInitialize the driver " << name << + " first, the current driver status cannot be destroyed"); + return NN_ERROR; + } + gDriverMap.erase(iter); + } + driver->DecreaseRef(); + HTracerExit(); + return NN_OK; +} + +bool UBSHcomNetDriver::LocalSupport(UBSHcomNetDriverProtocol t, UBSHcomNetDriverDeviceInfo &deviceInfo) +{ + UBSHcomNetDriverDeviceInfo tmpInfo {}; +#ifdef RDMA_BUILD_ENABLED + std::vector enabledDevice; + uint16_t devCount = 0; +#endif + std::lock_guard locker(gDriverMapMutex); + switch (t) { + case UBSHcomNetDriverProtocol::RDMA: +#ifdef RDMA_BUILD_ENABLED + if (HcomIbv::Load() != 0) { + NN_LOG_WARN("Unable to load verbs API, therefore cannot run RDMA app"); + return false; + } + + if (RDMADeviceHelper::GetDeviceCount(devCount, enabledDevice) != NN_OK || enabledDevice.empty()) { + NN_LOG_WARN("Unable to get RDMA devices or no active device found, therefore cannot run RDMA app"); + return false; + } + + for (auto &iter : enabledDevice) { + tmpInfo.maxSge = iter.deviceInfo.maxSge < tmpInfo.maxSge ? iter.deviceInfo.maxSge : tmpInfo.maxSge; + } + NN_LOG_TRACE_INFO("device count " << devCount << ", active devices count " << enabledDevice.size()); + + return true; +#else + NN_LOG_WARN("Unable to get RDMA devices or no active device found, rdma compilation not enabled"); + return false; +#endif + + case UBSHcomNetDriverProtocol::TCP: + case UBSHcomNetDriverProtocol::UDS: + case UBSHcomNetDriverProtocol::SHM: + return true; + default: + NN_LOG_WARN("Un-supported protocol"); + break; + } + + deviceInfo = tmpInfo; + return false; +} + +bool UBSHcomNetDriver::MultiRailGetDevCount(UBSHcomNetDriverProtocol t, std::string ipMask, uint16_t &enableDevCount, + std::string ipGroup) +{ +#if defined(RDMA_BUILD_ENABLED) || defined(UB_BUILD_ENABLED) + uint16_t devCount = 0; + std::vector enableIps; +#endif + std::lock_guard locker(gDriverMapMutex); + switch (t) { + case UBSHcomNetDriverProtocol::RDMA: +#ifdef RDMA_BUILD_ENABLED + if (HcomIbv::Load() != 0) { + NN_LOG_WARN("Unable to load verbs API, therefore cannot run RDMA app"); + return false; + } + + if (RDMADeviceHelper::GetEnableDeviceCount(ipMask, devCount, enableIps, ipGroup) != NN_OK || + devCount == 0) { + NN_LOG_WARN("Unable to get RDMA devices or no active device found, therefore cannot run RDMA app"); + return false; + } + enableDevCount = devCount; + + return true; +#else + NN_LOG_WARN("Unable to get RDMA devices or no active device found, rdma compilation not enabled"); + return false; +#endif + + case UBSHcomNetDriverProtocol::TCP: + case UBSHcomNetDriverProtocol::UDS: + case UBSHcomNetDriverProtocol::SHM: + return true; + case UBSHcomNetDriverProtocol::UBC: +#ifdef UB_BUILD_ENABLED + if (HcomUrma::Load() != 0) { + NN_LOG_WARN("Failed to load verbs API, unable to run RDMA app"); + return false; + } + + if (UBDeviceHelper::GetEnableDeviceCount(ipMask, devCount, enableIps, ipGroup) != UB_OK || devCount == 0) { + NN_LOG_WARN("Failed to get URMA devices or no active device found, unable to run URMA app"); + return false; + } + enableDevCount = devCount; + return true; +#endif + NN_LOG_WARN("Failed to get URMA devices or no active device found, URMA compilation not enabled"); + return false; + + default: + NN_LOG_WARN("Un-supported protocol"); + break; + } + + return false; +} + +/* + * @brief Create listeners, must be called after workers created and need to set new conn handler * + */ +NResult UBSHcomNetDriver::CreateListeners(bool enableMultiRail) +{ + if (enableMultiRail) { + return CreateServerLB(); + } + if (mOptions.oobType != NET_OOB_UDS && mOptions.oobType != NET_OOB_TCP) { + NN_LOG_ERROR("Un-supported oob type " << mOptions.oobType << " is set in driver " << mName); + return NN_INVALID_PARAM; + } else if (mOptions.oobType == NET_OOB_UDS) { + return CreateUdsListeners(); + } + + if (mOobListenOptions.empty()) { + NN_LOG_ERROR("No listen info is set for oob type " << UBSHcomNetDriverOobTypeToString(mOptions.oobType) << + " in driver " << mName); + return NN_INVALID_PARAM; + } + + uint16_t oobIndex = 0; + for (auto &lOpt : mOobListenOptions) { + NetOOBServerPtr oobServer = nullptr; + /* create oob server */ + if (mEnableTls) { + auto oobSSLServer = new (std::nothrow) OOBSSLServer(mOptions.oobType, lOpt.Ip(), lOpt.port, + mTlsPrivateKeyCB, mTlsCertCB, mTlsCaCallback); + NN_ASSERT_LOG_RETURN(oobSSLServer != nullptr, NN_NEW_OBJECT_FAILED) + oobSSLServer->SetTlsOptions(mOptions.cipherSuite, mOptions.tlsVersion); + oobSSLServer->SetPSKCallback(mPskFindSessionCb, mPskUseSessionCb); + oobServer = oobSSLServer; + } else { + oobServer = new (std::nothrow) OOBTCPServer(mOptions.oobType, lOpt.Ip(), lOpt.port); + NN_ASSERT_LOG_RETURN(oobServer.Get() != nullptr, NN_NEW_OBJECT_FAILED) + } + + if (lOpt.port == 0) { + if (oobServer->EnableAutoPortSelection(mPortRange.first, mPortRange.second)) { + return NN_INVALID_PARAM; + } + } + + NN_LOG_TRACE_INFO(lOpt.second.Ip()); + + oobServer->Index({ mIndex, oobIndex++ }); + oobServer->SetMaxConntionNum(mOptions.maxConnectionNum); + + /* create load balancer for each oob server */ + auto twc = lOpt.targetWorkerCount == 0 ? UINT16_MAX : lOpt.targetWorkerCount; + NetWorkerLBPtr lb = new (std::nothrow) NetWorkerLB(mName, mOptions.lbPolicy, twc); + if (NN_UNLIKELY(lb == nullptr)) { + NN_LOG_ERROR("Failed to new oob load balancer in driver " << mName); + return NN_NEW_OBJECT_FAILED; + } + + /* attach lb to oob server in case of leak */ + oobServer->SetWorkerLb(lb.Get()); + + /* add worker groups to lb */ + if (NN_UNLIKELY(lb->AddWorkerGroups(mWorkerGroups) != NN_OK)) { + NN_LOG_ERROR("Failed to added worker groups into load balancer in driver " << mName); + return NN_NEW_OBJECT_FAILED; + } + + oobServer->IncreaseRef(); + mOobServers.emplace_back(oobServer.Get()); + } + + if (mOobListenOptions.size() != mOobServers.size()) { + NN_LOG_ERROR("Created oob server count " << mOobServers.size() << " is not equal to listener options size " << + mOobListenOptions.size() << " in driver " << mName); + return NN_ERROR; + } + + return NN_OK; +} + +NResult UBSHcomNetDriver::CreateUdsListeners() +{ + if (mOobUdsListenOptions.empty()) { + NN_LOG_ERROR("No listen info is set in driver " << mName); + return NN_INVALID_PARAM; + } + + uint16_t oobIndex = 0; + for (auto &lOpt : mOobUdsListenOptions) { + NetOOBServerPtr oobServer = nullptr; + /* create oob server */ + if (mEnableTls) { + auto oobSSLServer = new (std::nothrow) OOBSSLServer(mOptions.oobType, lOpt.second.Name(), lOpt.second.perm, + lOpt.second.isCheck, mTlsPrivateKeyCB, mTlsCertCB, mTlsCaCallback); + NN_ASSERT_LOG_RETURN(oobSSLServer != nullptr, NN_NEW_OBJECT_FAILED) + oobSSLServer->SetTlsOptions(mOptions.cipherSuite, mOptions.tlsVersion); + oobSSLServer->SetPSKCallback(mPskFindSessionCb, mPskUseSessionCb); + oobServer = oobSSLServer; + } else { + oobServer = new (std::nothrow) + OOBTCPServer(mOptions.oobType, lOpt.second.Name(), lOpt.second.perm, lOpt.second.isCheck); + NN_ASSERT_LOG_RETURN(oobServer.Get() != nullptr, NN_NEW_OBJECT_FAILED) + } + + NN_LOG_TRACE_INFO(lOpt.second.Name()); + + oobServer->Index({ mIndex, oobIndex++ }); + oobServer->SetMaxConntionNum(mOptions.maxConnectionNum); + + /* create load balancer ptr for each oob server */ + auto twc = lOpt.second.targetWorkerCount == 0 ? UINT16_MAX : lOpt.second.targetWorkerCount; + NetWorkerLBPtr lb = new (std::nothrow) NetWorkerLB(mName, mOptions.lbPolicy, twc); + if (NN_UNLIKELY(lb == nullptr)) { + NN_LOG_ERROR("Failed to new oob load balancer in uds driver " << mName); + return NN_NEW_OBJECT_FAILED; + } + + /* attach lb to oob server in case of leak */ + oobServer->SetWorkerLb(lb.Get()); + + /* add worker groups to lb */ + if (NN_UNLIKELY(lb->AddWorkerGroups(mWorkerGroups) != NN_OK)) { + NN_LOG_ERROR("Failed to added worker groups into load balancer in uds driver " << mName); + return NN_NEW_OBJECT_FAILED; + } + + oobServer->IncreaseRef(); + mOobServers.emplace_back(oobServer.Get()); + } + + if (mOobUdsListenOptions.size() != mOobServers.size()) { + NN_LOG_ERROR("Created oob server count " << mOobServers.size() << " is not equal to listener options size " << + mOobUdsListenOptions.size() << " in uds driver " << mName); + return NN_ERROR; + } + + return NN_OK; +} + +NResult UBSHcomNetDriver::CreateServerLB() +{ + /* create load balancer for each oob server */ + NetWorkerLBPtr lb = new (std::nothrow) NetWorkerLB(mName, mOptions.lbPolicy, UINT16_MAX); + if (NN_UNLIKELY(lb == nullptr)) { + NN_LOG_ERROR("Failed to new oob load balancer in driver " << mName); + return NN_NEW_OBJECT_FAILED; + } + + /* add worker groups to lb */ + if (NN_UNLIKELY(lb->AddWorkerGroups(mWorkerGroups) != NN_OK)) { + NN_LOG_ERROR("Failed to added worker groups into load balancer in driver " << mName); + return NN_NEW_OBJECT_FAILED; + } + + lb->IncreaseRef(); + mServerLb = lb.Get(); + + return NN_OK; +} + +NResult UBSHcomNetDriver::StartListeners() +{ + NResult result = NN_OK; + for (uint64_t i = 0; i < mOobServers.size(); i++) { + if (NN_UNLIKELY(mOobServers[i] == nullptr)) { + NN_LOG_WARN("index " << i << "of oobServer is null"); + continue; + } + if ((result = mOobServers[i]->Start()) != NN_OK) { + for (uint64_t j = 0; j < i; j++) { + mOobServers[j]->Stop(); + } + return result; + } + } + + // get auto selected listen port + for (uint64_t i = 0; i < mOobListenOptions.size(); i++) { + if (mOobListenOptions[i].port == 0) { + uint16_t port = 0; + // for tcp oob, mOobServers.size() must be equal to mOobListenOptions.size() + if (mOobServers[i]->GetListenPort(port) == NN_OK) { + mOobListenOptions[i].port = port; + } else { + NN_LOG_WARN("Invalid to get real listen port for " << mOobListenOptions[i].Ip() << ":" << + mOobListenOptions[i].port); + } + } + } + + return NN_OK; +} + +NResult UBSHcomNetDriver::StopListeners(bool clear) +{ + for (auto &item : mOobServers) { + item->Stop(); + if (clear) { + item->DecreaseRef(); + } + } + + if (clear) { + mOobServers.clear(); + } + + return NN_OK; +} + +NResult UBSHcomNetDriver::CreateClientLB() +{ + NResult result = NN_OK; + NetWorkerLBPtr lb = new (std::nothrow) NetWorkerLB(mName, mOptions.lbPolicy, UINT16_MAX); + if (NN_UNLIKELY(lb.Get() == nullptr)) { + NN_LOG_ERROR("Failed to new lb object in driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + if (NN_UNLIKELY((result = lb->AddWorkerGroups(mWorkerGroups)) != NN_OK)) { + NN_LOG_ERROR("Failed to add worker into load balancer result " << result << " in driver " << mName); + return result; + } + + lb->IncreaseRef(); + mClientLb = lb.Get(); + return NN_OK; +} + +void UBSHcomNetDriver::DestroyClientLB() +{ + if (mClientLb != nullptr) { + mClientLb->DecreaseRef(); + mClientLb = nullptr; + } +} + +void UBSHcomNetDriver::DumpObjectStatistics() +{ + NetObjStatistic::Dump(); +} + +void UBSHcomNetDriver::OobIpAndPort(const std::string &ip, uint16_t port) +{ + if (mStartOobSvr) { + if (inet_addr(ip.c_str()) == 0) { + NN_LOG_ERROR("SetOobIpAndPort failed, ip addr is 0.0.0.0"); + return; + } + + UBSHcomNetOobListenerOptions opt{}; + if (NN_UNLIKELY(!opt.Set(ip, port, UINT16_MAX))) { + NN_LOG_ERROR("set UBSHcomNetOobListenerOptions failed"); + return; + } + AddOobOptions(opt); + return; + } + + mOobIp = ip; + mOobPort = port; +} + +void UBSHcomNetDriver::OobEidAndJettyId(const std::string &eid, uint16_t id) +{ + std::string s; + std::remove_copy(eid.begin(), eid.end(), std::back_inserter(s), ':'); + if (s.length() != NN_NO32) { + NN_LOG_ERROR("Ensure the eid is of 128b size after erasing the colon sign"); + return; + } + if (id < NN_NO2 || id > NN_NO1023) { + NN_LOG_ERROR("Ensure the jetty id in range 2~1023"); + return; + } + if (mStartOobSvr) { + UBSHcomNetOobListenerOptions opt{}; + if (NN_UNLIKELY(!opt.SetEid(s, id, UINT16_MAX))) { + NN_LOG_ERROR("set UBSHcomNetOobListenerOptions failed"); + return; + } + AddOobOptions(opt); + return; + } + + mOobIp = eid; + mOobPort = id; +} + +bool UBSHcomNetDriver::GetOobIpAndPort(std::vector> &result) +{ + if (!mStartOobSvr) { + NN_LOG_ERROR("GetOobIpAndPort failed, it is not server"); + return false; + } + + if (!mStarted) { + NN_LOG_ERROR("GetOobIpAndPort failed, net driver is not started"); + return false; + } + + result.clear(); + for (const auto& item : mOobListenOptions) { + result.emplace_back(item.Ip(), item.port); + } + return true; +} + +NResult UBSHcomNetDriver::ValidateAndParseOobPortRange(const char* oobPortRange) +{ + if (oobPortRange == nullptr || oobPortRange[0] == '\0') { + return NN_OK; + } + + std::vector portStr; + std::string strPortRange(oobPortRange); + NetFunc::NN_SplitStr(oobPortRange, "-", portStr); + + const int portSize = 2; + if (portStr.size() != portSize) { + NN_LOG_ERROR("oobPortRange is invalid, oobPortRange consists of two numbers connected by '-'"); + return NN_ERROR; + } + + long lowerLimit = 0; + if (!NetFunc::NN_Stol(portStr[0], lowerLimit)) { + NN_LOG_ERROR("parse lower limit of oobPortRange(" << portStr[0] << ") failed"); + return NN_ERROR; + } + if (lowerLimit < NN_NO1024 || lowerLimit > NN_NO65535) { + NN_LOG_ERROR("lower limit of oobPortRange invalid, port number must be in the range 1024-65535"); + return NN_ERROR; + } + + long upperLimit = 0; + if (!NetFunc::NN_Stol(portStr[1], upperLimit)) { + NN_LOG_ERROR("parse upper limit of oobPortRange(" << portStr[1] << ") failed"); + return NN_ERROR; + } + if (upperLimit < NN_NO1024 || upperLimit > NN_NO65535) { + NN_LOG_ERROR("upper limit of oobPortRange invalid, port number must be in the range 1024-65535"); + return NN_ERROR; + } + + if (lowerLimit > upperLimit) { + NN_LOG_ERROR("lower limit of oobPortRange is bigger than the upper limit"); + return NN_ERROR; + } + + mPortRange.first = static_cast(lowerLimit); + mPortRange.second = static_cast(upperLimit); + return NN_OK; +} + +NResult UBSHcomNetDriver::ParseUrl(const std::string &url, NetDriverOobType &type, std::string &ip, uint16_t &port) +{ + NetProtocol protocal; + std::string urlSuffix; + if (NN_UNLIKELY(!NetFunc::NN_SplitProtoUrl(url, protocal, urlSuffix))) { + NN_LOG_ERROR("Invalid url: "<< url <<" should be like tcp://127.0.0.1:9981 or uds://name or ubc://eid:jettyId"); + return NN_PARAM_INVALID; + } + + if (protocal == NetProtocol::NET_UBC) { + type = NetDriverOobType::NET_OOB_UB; + + if (NN_UNLIKELY(!NetFunc::NN_ConvertEidAndJettyId(urlSuffix, ip, port))) { + NN_LOG_ERROR("Invalid url: " << url << " should be like 1111:1111:0000:0000:0000:0000:4444:0000:888"); + return NN_PARAM_INVALID; + } + return SER_OK; + } + + if (protocal == NetProtocol::NET_UDS) { + type = NetDriverOobType::NET_OOB_UDS; + ip = urlSuffix; + return SER_OK; + } + + type = NetDriverOobType::NET_OOB_TCP; + if (NN_UNLIKELY(!NetFunc::NN_ConvertIpAndPort(urlSuffix, ip, port))) { + NN_LOG_ERROR("Invalid url: " << url <<" should be like 127.0.0.1:9981"); + return NN_PARAM_INVALID; + } + + return SER_OK; +} + +void UBSHcomNetDriver::AddOobOptions(const UBSHcomNetOobListenerOptions &option) +{ + { + std::lock_guard guard(mInitMutex); + if (NN_UNLIKELY(mOobListenOptions.size() >= gMaxListenPort)) { + NN_LOG_ERROR("Only " << gMaxListenPort << " listeners is allowed in driver"); + return; + } + + // The same port number cannot be used for two identical IP addresses + // The same port number can be used for two different IP addresses + for (const auto& opt : mOobListenOptions) { + if (opt.Ip() == option.Ip() && opt.port == option.port && opt.port != 0) { + NN_LOG_WARN("Duplicated listen '" << option.Ip() << ":" << option.port << "' adding to driver " << + mName << ", ignored"); + return; + } + } + + mOobListenOptions.emplace_back(option); + } +} + +void UBSHcomNetDriver::OobUdsName(const std::string &name) +{ + if (name.length() >= sizeof(UBSHcomNetOobUDSListenerOptions::name)) { + NN_LOG_ERROR("Uds name is too long for driver " << mName); + return; + } + + if (mStartOobSvr) { + UBSHcomNetOobUDSListenerOptions opt{}; + if (NN_UNLIKELY(!opt.Set(name, UINT16_MAX))) { + NN_LOG_ERROR("set UBSHcomNetOobUDSListenerOptions failed"); + return; + } + AddOobUdsOptions(opt); + return; + } + + mUdsName = name; +} + +void UBSHcomNetDriver::AddOobUdsOptions(const UBSHcomNetOobUDSListenerOptions &option) +{ + std::lock_guard guard(mInitMutex); + if (NN_UNLIKELY(mOobUdsListenOptions.size() >= gMaxListenPort)) { + NN_LOG_ERROR("Only " << gMaxListenPort << " listeners is allowed in driver"); + return; + } + if (NN_UNLIKELY(NetFunc::NN_ValidateUrl(option.Name()) != NN_OK)) { + NN_LOG_ERROR("Invalid uds name"); + return; + } + + auto iter = mOobUdsListenOptions.find(option.Name()); + if (NN_UNLIKELY(iter != mOobUdsListenOptions.end())) { + NN_LOG_WARN("Duplicated listen name '" << option.Name() << "' adding to driver " << mName << ", ignored"); + return; + } + + mOobUdsListenOptions[option.Name()] = option; +} + +NResult UBSHcomNetDriver::ValidateHandlesCheck() +{ + if (mReceivedRequestHandler == nullptr) { + NN_LOG_ERROR("Failed to do start in Driver " << mName << ", as receivedRequestHandler is null"); + return NN_INVALID_PARAM; + } + + if (mRequestPostedHandler == nullptr) { + NN_LOG_ERROR("Failed to do start in Driver " << mName << ", as requestPostedHandler is null"); + return NN_INVALID_PARAM; + } + + if (mOneSideDoneHandler == nullptr) { + NN_LOG_ERROR("Failed to do start in Driver " << mName << ", as oneSideDoneHandler is null"); + return NN_INVALID_PARAM; + } + // SHM self polling mode not register ep handler + if (mProtocol != SHM && mEndPointBrokenHandler == nullptr) { + NN_LOG_ERROR("Failed to do start in Driver " << mName << ", as endPointBrokenHandler is null"); + return NN_INVALID_PARAM; + } + return NN_OK; +} + +NResult UBSHcomNetDriver::ValidateOptionsOobType() +{ + if (mProtocol != UBC && mOptions.oobType == NET_OOB_UB) { + NN_LOG_ERROR("Failed to do start in Driver " << mName << ", only the UBC protocol can be set NET_OOB_UB."); + return NN_INVALID_PARAM; + } + if (mOptions.oobType == NET_OOB_UB && mOptions.enableTls) { + NN_LOG_ERROR("Failed to do start in Driver " << mName << ", as oobType NET_OOB_UB does not support enableTls."); + return NN_INVALID_PARAM; + } + return NN_OK; +} + +void UBSHcomNetDriver::RegisterNewEPHandler(const UBSHcomNetDriverNewEndPointHandler &handler) +{ + mNewEndPointHandler = handler; +} +void UBSHcomNetDriver::RegisterEPBrokenHandler(const UBSHcomNetDriverEndpointBrokenHandler &handler) +{ + mEndPointBrokenHandler = handler; +} + +void UBSHcomNetDriver::RegisterNewReqHandler(const UBSHcomNetDriverReceivedHandler &handler) +{ + mReceivedRequestHandler = handler; +} + +void UBSHcomNetDriver::RegisterReqPostedHandler(const UBSHcomNetDriverSentHandler &handler) +{ + mRequestPostedHandler = handler; +} + +void UBSHcomNetDriver::RegisterOneSideDoneHandler(const UBSHcomNetDriverOneSideDoneHandler &handler) +{ + mOneSideDoneHandler = handler; +} + +void UBSHcomNetDriver::RegisterIdleHandler(const UBSHcomNetDriverIdleHandler &handler) +{ + mIdleHandler = handler; +} + +void UBSHcomNetDriver::RegisterTLSCaCallback(const UBSHcomTLSCaCallback &cb) +{ + mTlsCaCallback = cb; +} + +void UBSHcomNetDriver::RegisterTLSCertificationCallback(const UBSHcomTLSCertificationCallback &cb) +{ + mTlsCertCB = cb; +} + +void UBSHcomNetDriver::RegisterTLSPrivateKeyCallback(const UBSHcomTLSPrivateKeyCallback &cb) +{ + mTlsPrivateKeyCB = cb; +} + +void UBSHcomNetDriver::RegisterEndpointSecInfoProvider(const UBSHcomNetDriverEndpointSecInfoProvider &provider) +{ + mSecInfoProvider = provider; +} + +void UBSHcomNetDriver::RegisterEndpointSecInfoValidator(const UBSHcomNetDriverEndpointSecInfoValidator &validator) +{ + mSecInfoValidator = validator; +} + +void UBSHcomNetDriver::RegisterPskUseSessionCb(const UBSHcomPskUseSessionCb &cb) +{ + mPskUseSessionCb = cb; +} + +void UBSHcomNetDriver::RegisterPskFindSessionCb(const UBSHcomPskFindSessionCb &cb) +{ + mPskFindSessionCb = cb; +} + +constexpr int16_t ERROR_CODE_100 = 100; +constexpr int16_t ERROR_CODE_200 = 200; +constexpr int16_t ERROR_CODE_300 = 300; +constexpr int16_t ERROR_CODE_400 = 400; +constexpr int16_t ERROR_CODE_500 = 500; +constexpr int16_t ERROR_CODE_600 = 600; + +const char *UBSHcomNetErrStr(int16_t errCode) +{ + if (errCode == 0) { + return "OK"; + } + int32_t index = 0; + if (errCode >= ERROR_CODE_100 && errCode < ERROR_CODE_200) { + index = errCode - ERROR_CODE_100; + if (index < NNCodeArrayLength) { + return NNCodeArray[index]; + } else { + return "ILLEGAL_CODE"; + } + } + + if (errCode >= ERROR_CODE_200 && errCode < ERROR_CODE_300) { + index = errCode - ERROR_CODE_200; + if (index < RRCodeArrayLength) { + return RRCodeArray[index]; + } else { + return "ILLEGAL_CODE"; + } + } + + if (errCode >= ERROR_CODE_300 && errCode < ERROR_CODE_400) { + index = errCode - ERROR_CODE_300; + if (index < ShCodeArrayLength) { + return ShCodeArray[index]; + } else { + return "ILLEGAL_CODE"; + } + } + + if (errCode >= ERROR_CODE_400 && errCode < ERROR_CODE_500) { + index = errCode - ERROR_CODE_400; + if (index < SCodeArrayLength) { + return SCodeArray[index]; + } else { + return "ILLEGAL_CODE"; + } + } + + if (errCode >= ERROR_CODE_500 && errCode < ERROR_CODE_600) { + index = errCode - ERROR_CODE_500; + if (index < SevCodeArrayLength) { + return SevCodeArray[index]; + } else { + return "ILLEGAL_CODE"; + } + } + + return "ILLEGAL_CODE"; +} + +std::string &UBSHcomNEPStateToString(UBSHcomNetEndPointState v) +{ + static std::string nepStateString[NEP_BUFF] = {"new", "established", "broken"}; + static std::string unknown = "UNKNOWN EP STATE"; + if (v != NEP_NEW && v != NEP_ESTABLISHED && v != NEP_BROKEN) { + return unknown; + } + return nepStateString[v]; +} + +std::string &UBSHcomRequestStatusToString(UBSHcomNetRequestStatus status) +{ + static std::string requestStatus[NN_NO5] = {"Called", "In HCOM", "In URMA", "Polled", "Success"}; + static std::string invalid = "INVALID STATUS"; + if (status > UBSHcomNetRequestStatus::SUCCESS) { + return invalid; + } + int value = static_cast(status); + return requestStatus[value]; +} + +void SetTraceIdInner(const std::string &traceId) +{ +#ifdef UB_BUILD_ENABLED + if (HcomUrma::IsLoaded()) { + HcomUrma::LogSetThreadTag(traceId.c_str()); + return; + } +#endif + NN_LOG_WARN("failed to set trace id, urma api is not loaded"); +} + +std::string &UBSHcomNetMemoryAllocatorTypeToString(UBSHcomNetMemoryAllocatorType v) +{ + static std::string allocatorType[NN_NO2] = {"Dynamic size allocator", "Dynamic size allocator with cache"}; + static std::string unknown = "UNKNOWN ALLOCATOR TYPE"; + if (v != DYNAMIC_SIZE && v != DYNAMIC_SIZE_WITH_CACHE) { + return unknown; + } + return allocatorType[v]; +} + +std::string UBSHcomNetMemoryAllocatorOptions::ToString() const +{ + std::ostringstream oss; + oss << "address " << address << ", size " << size << ", minBlockSize " << minBlockSize << ", alignedAddress " << + alignedAddress << ", cacheTierCount " << cacheTierCount << ", cacheBlockCountPerTier " << + cacheBlockCountPerTier << ", cacheTierPolicy " << cacheTierPolicy; + return oss.str(); +} + +std::string &UBSHcomNetDriverOobTypeToString(NetDriverOobType v) +{ + static std::string oobType[NN_NO3] = {"Tcp", "UDS", "URMA"}; + static std::string unknown = "UNKNOWN OOB TYPE"; + if (v != NET_OOB_TCP && v != NET_OOB_UDS && v != NET_OOB_UB) { + return unknown; + } + return oobType[v]; +} + +std::string &UBSHcomNetDriverSecTypeToString(UBSHcomNetDriverSecType v) +{ + static std::string secType[NN_NO3] = {"SecNoValid", "SecValidOneWay", "SecValidTwoWay", }; + static std::string unknown = "UNKNOWN SEC TYPE"; + if (v != NET_SEC_VALID_ONE_WAY && v != NET_SEC_VALID_TWO_WAY) { + return unknown; + } + return secType[v]; +} + +std::string &UBSHcomNetDriverLBPolicyToString(UBSHcomNetDriverLBPolicy v) +{ + static std::string driverLB[NN_NO2] = {"RR", "Hash", }; + static std::string unknown = "UNKNOWN POLICY"; + if (v != NET_ROUND_ROBIN && v != NET_HASH_IP_PORT) { + return unknown; + } + return driverLB[v]; +} + +std::string &UBSHcomNetDriverProtocolToString(UBSHcomNetDriverProtocol v) +{ + static std::string driverProtocol[NN_NO6] = {"RDMA", "TCP", "UDS", "SHM", "UBC", + "UNKNOWN PROTOCOL"}; + static std::string unknown = "UNKNOWN PROTOCOL"; + if (v >= NN_NO6) { + return unknown; + } + return driverProtocol[v]; +} + +bool UBSHcomNetCloneStringToArray(char *dest, size_t destMax, const std::string &src) +{ + if (NN_LIKELY(src.length() < destMax)) { + int ret = strcpy_s(dest, destMax, src.c_str()); + if (NN_UNLIKELY(ret != EOK)) { + NN_LOG_ERROR("copy string failed, ret " << ret); + return false; + } + return true; + } + + NN_LOG_ERROR("Invalid src length " << src.length() + NN_NO1 << " clone to dest length" << destMax); + return false; +} + +NResult ValidateWorkerOptions(UBSHcomNetDriverWorkingMode mode, char *workerGroups, char *workerGroupsCpuSet, + UBSHcomNetDriverLBPolicy lbPolicy, int workerThreadPriority) +{ + /* validate param related to poll mode for RDMA, Sock and SHM */ + if (NN_UNLIKELY(mode != NET_BUSY_POLLING && mode != NET_EVENT_POLLING)) { + NN_LOG_ERROR("Option 'mode' is invalid, " << mode << + " is set in driver, valid value is NET_BUSY_POLLING(0) or NET_EVENT_POLLING(1)"); + return NN_INVALID_PARAM; + } + + /* validate params related to worker group for RDMA, Sock and SHM */ + if (NN_UNLIKELY(!ValidateArrayOptions(workerGroups, NN_NO64))) { + NN_LOG_ERROR("Option 'workerGroups' is invalid, the Array max length is 64."); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(!ValidateArrayOptions(workerGroupsCpuSet, NN_NO128))) { + NN_LOG_ERROR("Option 'workerGroupsCpuSet' is invalid, the Array max length is 128."); + return NN_INVALID_PARAM; + } + + /* validate param related to load balance policy for RDMA, Sock and SHM */ + if (NN_UNLIKELY(lbPolicy != NET_ROUND_ROBIN && lbPolicy != NET_HASH_IP_PORT)) { + NN_LOG_ERROR("Option 'oobType' is invalid, " << lbPolicy << + " is set in driver, valid value is NET_ROUND_ROBIN(0) or NET_HASH_IP_PORT(1)"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(workerThreadPriority > static_cast(NN_NO20) || + workerThreadPriority < -static_cast(NN_NO20))) { + NN_LOG_ERROR("Option 'workerThreadPriority' is invalid, it should be set from -20 to 20 closed, 0 means do not " + "set priority"); + return NN_INVALID_PARAM; + } + + return NN_OK; +} + +NResult ValidateOobOptions(NetDriverOobType oobType) +{ + /* validate param related to net driver oobType for RDMA, Sock and SHM */ + if (NN_UNLIKELY(oobType > NET_OOB_UB)) { + NN_LOG_ERROR("Option 'oobType' is invalid, " << oobType << + " is set in driver, valid value is NET_OOB_TCP(0) or NET_OOB_UDS(1) or NET_OOB_UB(2)"); + return NN_INVALID_PARAM; + } + return NN_OK; +} + +NResult ValidateHeartbeatOptions(uint16_t heartBeatIdleTime, uint16_t heartBeatProbeTimes, + uint16_t heartBeatProbeInterval) +{ + if (NN_UNLIKELY(heartBeatIdleTime == 0 || heartBeatIdleTime > NN_NO10000)) { + NN_LOG_ERROR("Option 'heartBeatIdleTime' is invalid, " << heartBeatIdleTime << + " is set in driver, the valid value range is 1s ~ 10000s"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(heartBeatProbeTimes == 0 || heartBeatProbeTimes > NN_NO1024)) { + NN_LOG_ERROR("Option 'heartBeatProbeTime' is invalid, " << heartBeatProbeTimes << + " is set in driver, the valid value range is 1s ~ 1024s"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(heartBeatProbeInterval > NN_NO1024)) { + NN_LOG_ERROR("Option 'heartBeatProbeInterval' is invalid, " << heartBeatProbeInterval << + " is set in driver, the valid value range is 1s ~ 1024s"); + return NN_INVALID_PARAM; + } + return NN_OK; +} + +NResult ValidateQueueOptions(uint32_t qpSendQueueSize, uint32_t qpReceiveQueueSize, uint16_t completionQueueDepth) +{ + /* validate params related to send queue and receive queue size for RDMA and Sock */ + if (NN_UNLIKELY(qpSendQueueSize < NN_NO16 || qpSendQueueSize > NN_NO65535)) { + NN_LOG_ERROR("Option 'qpSendQueueSize' is invalid, " << qpSendQueueSize << + " is set in driver, the valid value range is 16 ~ 65535"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(qpReceiveQueueSize < NN_NO16 || qpReceiveQueueSize > NN_NO65535)) { + NN_LOG_ERROR("Option 'qpReceiveQueueSize' is invalid " << qpReceiveQueueSize << + " is set in driver, the valid value range is 16 ~ 65535"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(completionQueueDepth == NN_NO0 || completionQueueDepth > NN_NO8192)) { + NN_LOG_ERROR("Option 'completionQueueDepth' is invalid " << completionQueueDepth << + " is set in driver, the valid value range is 1 ~ 8192"); + return NN_INVALID_PARAM; + } + return NN_OK; +} + +NResult ValidatePollingOptions(uint16_t pollingBatchSize, uint32_t eventPollingTimeout) +{ + /* validate params related to poll for RDMA, Sock and SHM */ + if (NN_UNLIKELY(pollingBatchSize == 0 || pollingBatchSize > NN_NO1024)) { + NN_LOG_ERROR("Option 'pollingBatchSize' is invalid, " << pollingBatchSize << + " is set in driver, the valid value range is 1 ~ 1024"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(eventPollingTimeout == 0 || eventPollingTimeout > NN_NO2000000)) { + NN_LOG_ERROR("Option 'eventPollingTimeout' is invalid, " << eventPollingTimeout << + " is set in driver, the valid value range is 1ms ~ 2000000ms"); + return NN_INVALID_PARAM; + } + return NN_OK; +} + +NResult ValidateSegOptions(uint32_t mrSendReceiveSegSize, uint32_t mrSendReceiveSegCount) +{ + if (mrSendReceiveSegSize < NN_NO1 || mrSendReceiveSegSize > NET_SGE_MAX_SIZE) { + NN_LOG_ERROR("Option 'mrSendReceiveSegSize' is invalid, " << mrSendReceiveSegSize << + " is set in driver, the valid value range is 1 byte ~ 524288000 byte"); + return NN_INVALID_PARAM; + } + + if (mrSendReceiveSegCount < NN_NO1 || mrSendReceiveSegCount > NN_NO65535) { + NN_LOG_ERROR("Option 'mrSendReceiveSegCount' is invalid, " << mrSendReceiveSegCount << + " is set in driver, the valid value range is 1 ~ 65535"); + return NN_INVALID_PARAM; + } + return NN_OK; +} + +NResult ValidateCipherOptions(bool enableTls, UBSHcomTlsVersion tlsVersion, UBSHcomNetCipherSuite cipherSuite) +{ + if (!enableTls) { + return NN_OK; + } + + if ((cipherSuite < AES_GCM_128) || (cipherSuite > CHACHA20_POLY1305)) { + NN_LOG_ERROR("Option 'cipherSuite' is invalid, " << cipherSuite << + " is set in driver, the valid value range is AES_GCM_128:" << AES_GCM_128 << " and CHACHA20_POLY1305:" << + CHACHA20_POLY1305); + return NN_INVALID_PARAM; + } + + if ((tlsVersion != TLS_1_3)) { + NN_LOG_ERROR("Currently only supports TLS 1.3 version"); + return NN_INVALID_PARAM; + } + + return NN_OK; +} + +NResult ValidateMaxConnectionOptions(uint32_t maxConnectionNum) +{ + if (maxConnectionNum == NN_NO0) { + NN_LOG_ERROR("Option 'maxConnectionNum' is invalid, " << maxConnectionNum << + " is set in driver, the valid value range is > 0"); + return NN_INVALID_PARAM; + } + return NN_OK; +} + +NResult UBSHcomNetDriverOptions::ValidateCommonOptions() +{ + /* validate params related to heart beat for RDMA, Sock and SHM */ + if (NN_UNLIKELY(ValidateWorkerOptions(mode, workerGroups, workerGroupsCpuSet, lbPolicy, workerThreadPriority) != + NN_OK)) { + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(ValidatePollingOptions(pollingBatchSize, eventPollingTimeout) != NN_OK)) { + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(ValidateOobOptions(oobType) != NN_OK)) { + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(ValidateHeartbeatOptions(heartBeatIdleTime, heartBeatProbeTimes, heartBeatProbeInterval) != + NN_OK)) { + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(ValidateQueueOptions(qpSendQueueSize, qpReceiveQueueSize, completionQueueDepth) != NN_OK)) { + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(ValidateSegOptions(mrSendReceiveSegSize, mrSendReceiveSegCount) != NN_OK)) { + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(ValidateCipherOptions(enableTls, tlsVersion, cipherSuite) != NN_OK)) { + return NN_INVALID_PARAM; + } + + if (!POWER_OF_2(tcpSendBufSize)) { + tcpSendBufSize = NN_NextPower2(tcpSendBufSize); + } + + if (!POWER_OF_2(tcpReceiveBufSize)) { + tcpReceiveBufSize = NN_NextPower2(tcpReceiveBufSize); + } + + if (!POWER_OF_2(qpSendQueueSize)) { + qpSendQueueSize = NN_NextPower2(qpSendQueueSize); + } + + if (!POWER_OF_2(qpReceiveQueueSize)) { + qpReceiveQueueSize = NN_NextPower2(qpReceiveQueueSize); + } + + if (NN_UNLIKELY(ValidateMaxConnectionOptions(maxConnectionNum) != NN_OK)) { + return NN_INVALID_PARAM; + } + + return NN_OK; +} + +std::string UBSHcomNetDriverOptions::NetDeviceIpMask() const +{ + return NN_CHAR_ARRAY_TO_STRING(netDeviceIpMask); +} + +std::string UBSHcomNetDriverOptions::NetDeviceIpGroup() const +{ + return NN_CHAR_ARRAY_TO_STRING(netDeviceIpGroup); +} + +std::string UBSHcomNetDriverOptions::WorkGroups() const +{ + return NN_CHAR_ARRAY_TO_STRING(workerGroups); +} + +std::string UBSHcomNetDriverOptions::WorkerGroupCpus() const +{ + return NN_CHAR_ARRAY_TO_STRING(workerGroupsCpuSet); +} + +std::string UBSHcomNetDriverOptions::WorkerGroupThreadPriority() const +{ + return NN_CHAR_ARRAY_TO_STRING(workerGroupsThreadPriority); +} + +bool UBSHcomNetDriverOptions::SetNetDeviceIpMask(const std::string &mask) +{ + NN_SET_CHAR_ARRAY_FROM_STRING(netDeviceIpMask, mask); +} + +bool UBSHcomNetDriverOptions::SetNetDeviceIpMask(const std::vector &mask) +{ + std::string ipMasksStr; + NetFunc::NN_VecStrToStr(mask, ",", ipMasksStr); + NN_SET_CHAR_ARRAY_FROM_STRING(netDeviceIpMask, ipMasksStr); +} + +bool UBSHcomNetDriverOptions::SetNetDeviceEid(const std::string &eid) +{ + std::string s; + std::remove_copy(eid.begin(), eid.end(), std::back_inserter(s), ':'); + if (s.length() != NN_NO32) { + NN_LOG_ERROR("Ensure the eid is of 128b size after erasing the colon sign"); + return false; + } + + return HexStringToBuff(s, NN_NO16, netDeviceEid); +} + +bool UBSHcomNetDriverOptions::SetNetDeviceIpGroup(const std::string &ipGroup) +{ + NN_SET_CHAR_ARRAY_FROM_STRING(netDeviceIpGroup, ipGroup); +} + +bool UBSHcomNetDriverOptions::SetNetDeviceIpGroup(const std::vector &ipGroup) +{ + std::string ipGroupStr; + NetFunc::NN_VecStrToStr(ipGroup, ";", ipGroupStr); + NN_SET_CHAR_ARRAY_FROM_STRING(netDeviceIpGroup, ipGroupStr); +} + +bool UBSHcomNetDriverOptions::SetWorkerGroups(const std::string &groups) +{ + NN_SET_CHAR_ARRAY_FROM_STRING(workerGroups, groups); +} + +bool UBSHcomNetDriverOptions::SetWorkerGroupsCpuSet(const std::string &value) +{ + NN_SET_CHAR_ARRAY_FROM_STRING(workerGroupsCpuSet, value); +} + +bool UBSHcomNetDriverOptions::SetWorkerGroupThreadPriority(const std::string &value) +{ + NN_SET_CHAR_ARRAY_FROM_STRING(workerGroupsThreadPriority, value); +} + +std::string UBSHcomNetDriverOptions::ToString() const +{ + std::ostringstream oss; + oss << "UBSHcomNetDriverOptions mode: " << static_cast(mode) << ", send/receive-mr-seg-count: " << + mrSendReceiveSegCount << ", send/receive-mr-seg-size: " << mrSendReceiveSegSize << ", device-mask: " << + NetDeviceIpMask() << ", cq-size " << completionQueueDepth << ", max-post-send: " << maxPostSendCountPerQP << + ", pre-post-receive-count: " << prePostReceiveSizePerQP << ", polling-batch-size: " << pollingBatchSize << + ", qp-send-queue-size: " << qpSendQueueSize << ", qp-receive-queue-size: " << qpReceiveQueueSize << + ", worker-groups: " << WorkGroups() << ", worker-groups-cpu-set: " << WorkerGroupCpus() << + ", start-workers: " << dontStartWorkers << ", tls-enabled: " << enableTls << ", oob-type: " << + UBSHcomNetDriverOobTypeToString(oobType) << ", lb-policy: " << UBSHcomNetDriverLBPolicyToString(lbPolicy); + return oss.str(); +} + +std::string UBSHcomNetDriverOptions::ToStringForSock() const +{ + std::ostringstream oss; + oss << "UBSHcomNetDriverOptions mode: " << static_cast(mode) << ", send/receive-mr-seg-count: " << + mrSendReceiveSegCount << ", send/receive-mr-seg-size: " << mrSendReceiveSegSize << ", device-mask: " << + NetDeviceIpMask() << ", cq-size " << completionQueueDepth << ", max-post-send: " << maxPostSendCountPerQP << + ", pre-post-receive-count: " << prePostReceiveSizePerQP << ", polling-batch-size: " << pollingBatchSize << + ", qp-send-queue-size: " << qpSendQueueSize << ", qp-receive-queue-size: " << qpReceiveQueueSize << + ", worker-groups: " << WorkGroups() << ", worker-groups-cpu-set: " << WorkerGroupCpus() << + ", start-workers: " << dontStartWorkers << ", tls-enabled: " << enableTls << ", oob-type: " << + UBSHcomNetDriverOobTypeToString(oobType) << ", lb-policy: " << UBSHcomNetDriverLBPolicyToString(lbPolicy) << + ", tcp-keepalive-idle-time: " << heartBeatIdleTime << " seconds, tcp-keepalive-probe-times: " << + heartBeatProbeTimes << ", tcp-keepalive-probe-interval: " << heartBeatProbeInterval << + " seconds, tcp-send-buffer-size: " << tcpSendBufSize << ", tcp-receive-buffer-size: " << tcpReceiveBufSize; + return oss.str(); +} + +void UnParseWorkerGroups(const std::vector &workerGroups, std::string &strRes) +{ + strRes.clear(); + for (const auto &workerGroup : workerGroups) { + if (NN_UNLIKELY(strRes.empty())) { + strRes += std::to_string(workerGroup.threadCount); + } else { + strRes += ("," + std::to_string(workerGroup.threadCount)); + } + } +} + +void UnParseWorkerGroupsCpus(const std::vector &workerGroups, std::string &strRes) +{ + strRes.clear(); + for (const auto &workerGroup : workerGroups) { + std::string item = "na"; + if (NN_UNLIKELY(workerGroup.cpuIdsRange.first != UINT32_MAX)) { + item = std::to_string(workerGroup.cpuIdsRange.first) + "-" + + std::to_string(workerGroup.cpuIdsRange.second); + } + if (NN_UNLIKELY(strRes.empty())) { + strRes += item; + } else { + strRes += ("," + item); + } + } +} + +bool UBSHcomNetDriverOptions::SetWorkerGroupsInfo(const std::vector &workerGroupInfos) +{ + if (NN_UNLIKELY(workerGroupInfos.empty())) { + NN_LOG_ERROR("SetWorkerGroupsInfo failed, workerGroups is empty"); + return false; + } + workerThreadPriority = workerGroupInfos[0].threadPriority; + std::string wGsStr; + std::string wGsCpuSetStr; + UnParseWorkerGroups(workerGroupInfos, wGsStr); + UnParseWorkerGroupsCpus(workerGroupInfos, wGsCpuSetStr); + NN_SET_CHAR_ARRAY_FROM_STRING_VOID(workerGroups, wGsStr); + NN_SET_CHAR_ARRAY_FROM_STRING_VOID(workerGroupsCpuSet, wGsCpuSetStr); + return true; +} + +} +} diff --git a/src/hcom.h b/src/hcom.h new file mode 100644 index 0000000000000000000000000000000000000000..dd79d9d21725678bededefef1ed2dbaee3d74a40 --- /dev/null +++ b/src/hcom.h @@ -0,0 +1,2389 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_CPP_H_34562 +#define OCK_HCOM_CPP_H_34562 + +#include +#include +#include +#include +#include +#include +#include + +#include "hcom_err.h" +#include "hcom_ref.h" +#include "hcom_def.h" +#include "hcom_log.h" +#include "hcom_utils.h" +#include "hcom_obj_statistics.h" + +namespace ock { +namespace hcom { +/* flag */ +#define NET_EP_SELF_POLLING NET_FLAGS_BIT(0) +#define NET_EP_EVENT_POLLING NET_FLAGS_BIT(1) + +class UBSHcomNetEndpoint; +class UBSHcomNetMessage; +class UBSHcomNetRequestContext; +class UBSHcomNetResponseContext; +class UBSHcomNetMemoryRegion; +class UBSHcomNetMemoryAllocator; +class UBSHcomNetDriver; +class NetWorkerLB; +class NetSecrets; +class OOBTCPServer; +class NetServiceGlobalObject; +class OOBTCPConnection; + +using UBSHcomNetEndpointPtr = NetRef; +using UBSHcomNetRequestContextPtr = NetRef; +using UBSHcomNetMemoryRegionPtr = NetRef; +using UBSHcomNetMemoryAllocatorPtr = NetRef; +using NetOOBServer = OOBTCPServer; + +using NetLogger = UBSHcomNetOutLogger; + +/* ****************************************************************************************** */ +enum UBSHcomNetEndPointState { + NEP_NEW = 0, + NEP_ESTABLISHED = 1, + NEP_BROKEN = 2, + + NEP_BUFF +}; + +std::string &UBSHcomNEPStateToString(UBSHcomNetEndPointState v); + +const char *UBSHcomNetErrStr(int16_t errCode); + +bool UBSHcomNetCloneStringToArray(char *dest, size_t destMax, const std::string &src); + +#define NN_SET_CHAR_ARRAY_FROM_STRING(CHAR_ARRAY, VALUE) \ + do { \ + return UBSHcomNetCloneStringToArray(CHAR_ARRAY, sizeof(CHAR_ARRAY), VALUE); \ + } while (0) + +#define NN_SET_CHAR_ARRAY_FROM_STRING_VOID(CHAR_ARRAY, VALUE) \ + do { \ + UBSHcomNetCloneStringToArray(CHAR_ARRAY, sizeof(CHAR_ARRAY), VALUE); \ + } while (0) + +#define NN_CHAR_ARRAY_TO_STRING(CHAR_ARRAY) \ + { \ + CHAR_ARRAY, strlen(CHAR_ARRAY) <= sizeof(CHAR_ARRAY) ? \ + strlen(CHAR_ARRAY) : \ + sizeof(CHAR_ARRAY) \ + } + +enum class UBSHcomNetRequestStatus { + CALLED = 0, + IN_HCOM, + IN_URMA, + POLLED, + SUCCESS +}; + +std::string &UBSHcomRequestStatusToString(UBSHcomNetRequestStatus status); + +void SetTraceIdInner(const std::string &traceId); + +/// 传输层请求 +/// 它有以下几种典型用法: +/// - 双边 bcopy, 上层应用只需要填充 `lAddress`, `size`和 `upCtxData`. 然后调用 hcom 的 +/// `PostSend()` 接口将 `[lAddress, lAddress + size)` 区间拷贝到一块注册过的内存上,随 +/// 后自动调整 `lAddress`, `lkey` 和 `size` (有额外头部) +/// - 单边 RDMA, 需要填充 `lAddress`, `rAddress`, `lkey`, `rkey`, `size` 和`upCtxData`. 随 +/// 后调用 `PostWrite()` 时会直接使用这些参数,所以要求 `lAddress`, `rAddress` 都提前注册 +/// 好了。 +/// - 单边 UBC 场景与 RDMA 保持一致。 +struct UBSHcomNetTransRequest { + uintptr_t lAddress = 0; ///< 本地读取地址 + uintptr_t rAddress = 0; ///< 远端写入地址 + uint64_t lKey = 0; ///< 本地 lkey, 适用于 RDMA/UB/TCP/SHM + uint64_t rKey = 0; ///< 远端 rkey, 适用于 RDMA/UB/TCP/SHM + void *srcSeg = nullptr; ///< 仅适用于 UB + void *dstSeg = nullptr; ///< 仅适用于 UB + uint32_t size = 0; ///< 写入字节数 + uint16_t upCtxSize = 0; ///< 上层 ctx 大小。默认为 0 代表 upCtxData 无效 + char upCtxData[NN_NO64] = {}; ///< 可用于存储上层 ctx + + UBSHcomNetTransRequest() = default; + + UBSHcomNetTransRequest(void *data, uint32_t dataSize, uint16_t upContextSize) + : lAddress(reinterpret_cast(data)), + size(dataSize), + upCtxSize(upContextSize) + { + } + + UBSHcomNetTransRequest(uintptr_t la, uintptr_t ra, uint64_t lk, uint64_t rk, + uint32_t s, uint16_t upCtxSi) + : lAddress(la), + rAddress(ra), + lKey(lk), + rKey(rk), + size(s), + upCtxSize(upCtxSi) + { + } + + UBSHcomNetTransRequest(uintptr_t la, uintptr_t ra, uint64_t lk, uint64_t rk, + uint32_t s, uint16_t upCtxSi, void *sSeg, void *dSeg) + : lAddress(la), + rAddress(ra), + lKey(lk), + rKey(rk), + size(s), + upCtxSize(upCtxSi) + { + // avoid cleancode check + srcSeg = sSeg; + dstSeg = dSeg; + } +} __attribute__((packed)); + +struct UBSHcomNetTransSglRequest { + UBSHcomNetTransSgeIov *iov = nullptr; // array + uint16_t iovCount = 0; // max count: NET_SGE_MAX_IOV + uint16_t upCtxSize = 0; // upper context size + char upCtxData[NN_NO16] = {}; // upper context data + + UBSHcomNetTransSglRequest() = default; + + UBSHcomNetTransSglRequest(UBSHcomNetTransSgeIov *iovPtr, uint16_t cnt, uint16_t upCtxSi) + : iov(iovPtr), iovCount(cnt), upCtxSize(upCtxSi) + {} +} __attribute__((packed)); + +struct UBSHcomNetTransOpInfo { + uint32_t seqNo = 0; // seq no + int16_t timeout = 0; // timeout + int16_t errorCode = 0; // error code + uint8_t flags = 0; // flags in user case + + UBSHcomNetTransOpInfo() = default; + + UBSHcomNetTransOpInfo(uint32_t seqNo, int16_t timeout, int16_t errorCode, uint8_t flags) + : seqNo(seqNo), timeout(timeout), errorCode(errorCode), flags(flags) + {} + UBSHcomNetTransOpInfo(uint32_t seqNo, int16_t timeout) : seqNo(seqNo), timeout(timeout) {} +} __attribute__((packed)); + +struct UBSHcomNetUdsIdInfo { + uint32_t pid = 0; // process id + uint32_t uid = 0; // user id + uint32_t gid = 0; // group id + + UBSHcomNetUdsIdInfo() = default; + + UBSHcomNetUdsIdInfo(uint32_t pid, uint32_t uid, uint32_t gid) + : pid(pid), uid(uid), gid(gid){}; +} __attribute__((packed)); + +union UBSHcomEpOptions { + struct { + bool tcpBlockingIo; + bool cbByWorkerInBlocking; + int32_t sendTimeout; // send timeout in blocking mode in second + }; + + void Set(bool tcpBI, bool cb, int32_t st) + { + tcpBlockingIo = tcpBI; + cbByWorkerInBlocking = cb; + sendTimeout = st; + } + + UBSHcomEpOptions() + { + tcpBlockingIo = false; + cbByWorkerInBlocking = false; + sendTimeout = -1; + } +}; + +/** + * @brief Cipher suite ids + */ +enum UBSHcomNetCipherSuite { + AES_GCM_128 = 0, + AES_GCM_256 = 1, + AES_CCM_128 = 2, + CHACHA20_POLY1305 = 3, +}; + +enum UBSHcomTlsVersion : uint32_t { + TLS_1_2 = NN_NO771, + TLS_1_3 = NN_NO772, +}; + +/** + * @brief Endpoint for data transfer, representing a connection + */ +class UBSHcomNetEndpoint { +public: + virtual ~UBSHcomNetEndpoint() + { + OBJ_GC_DECREASE(UBSHcomNetEndpoint); + } + + /** + * @brief Only support TCP now, and TCP is nonblocking in default, only could be set from nonblocking to blocking + * if set as blocking, there might occur function problems in some conditions. + */ + virtual NResult SetEpOption(UBSHcomEpOptions &epOptions) = 0; + + /** + * @brief Get using count in sending queue + */ + virtual uint32_t GetSendQueueCount() = 0; + + /** + * @brief Get the id of the endpoint + */ + inline uint64_t Id() const + { + return mId; + } + + /** + * @brief Get the worker index of the endpoint + */ + const UBSHcomNetWorkerIndex &WorkerIndex() const; + + /** + * @brief Check if ep is in established state + */ + bool IsEstablished(); + + /** + * @brief Set the upper context, which could be used store user data pointer and read it when handler called + */ + void UpCtx(uint64_t ctx); + + /** + * @brief get the upper context + */ + uint64_t UpCtx() const; + + /** + * @brief Get the payload + */ + const std::string &PeerConnectPayload() const; + + /** + * @brief Get the local ip + */ + uint32_t LocalIp() const; + + /** + * @brief Get the listen port + */ + uint16_t ListenPort() const; + + /** + * @brief Get the driver version + */ + uint8_t Version() const; + + /** + * @brief Get state, don't change it which could leading to uncertain behavior + */ + inline UBSHcomNetAtomicState &State() + { + return mState; + } + + /** + * @brief Get the peer ip and port of oob tcp connection, which used to identify where peer comes from + */ + virtual const std::string &PeerIpAndPort() = 0; + + virtual const std::string &UdsName() = 0; + + /** + * @brief Post send a request with opcode and header to peer, peer will be trigger new request callback also with + * opcode and header + * + * @param opCode [in] operation code, 0~1023 + * @param request [in] request information, local address and size is used only, the data is copied, you can + * free it after called + * @param seqNo [in] seq number for peer to reply, must be > 0, peer can get it from context.Header().seqNo; + * if it is 0, an auto increased number is generated, for sync client it will be matching request and response + * + * Behavior: + * 1 For RDMA, + * case a) if NET_EP_SELF_POLLING is not set, just issue the send request, not wait for sending request finished + * case b) if NET_EP_SELF_POLLING is set, issue the send request and wait for sending arrived to peer + * + * @return 0 if successful + * + */ + virtual NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + uint32_t seqNo) = 0; + + virtual NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) = 0; + + /** + * @brief Post send a request with opcode and header to peer, peer will be trigger new request callback also with + * opcode and header + * + * @param opCode [in] operation code, 0~1023 + * @param request [in] request information, local address and size is used only, the data is copied, you can + * free it after called + * + * Behavior: + * 1 For RDMA, + * case a) if NET_EP_SELF_POLLING is not set, just issue the send request, not wait for sending request finished + * case b) if NET_EP_SELF_POLLING is set, issue the send request and wait for sending arrived to peer + * + * @return 0 if successful + * + */ + inline NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request) + { + return PostSend(opCode, request, 0); + } + + /** + * @brief Post send a request without opcode and header to peer, peer will be trigger new request callback also + * without opcode and header, this could be used when you have self define header + * + * @param request [in] request information, local address and size is used only, the data is copied, you can + * free it after called + * @param seqNo [in] seq no for peer to reply, must be > 0, peer can get it from context.Header().seqNo, + * for sync client it will be matching request and response + * + * Behavior: + * 1 For RDMA, + * case a) if NET_EP_SELF_POLLING is not set, just issue the send request, not wait for sending request finished + * case b) if NET_EP_SELF_POLLING is set, issue the send request and wait for sending arrived to peer + * + * @return 0 if successful + * + */ + virtual NResult + PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNo) = 0; + + /** + * @brief Post send a request without opcode and header to peer, peer will be trigger new request callback also + * without opcode and header, this could be used when you have self define header + * + * @param request [in] request information, fill with local different MRs, send to the same remote MR by local + * MRs sequence, you can free it after called. rKey/rAddress do not need to assign + * @param seqNo [in] seq no for peer to reply, must be > 0, peer can get it from context.Header().seqNo, + * for sync client it will be matching request and response + * + * Behavior: + * 1 For RDMA, + * case a) if NET_EP_SELF_POLLING is not set, just issue the send request, not wait for sending request finished + * case b) if NET_EP_SELF_POLLING is set, issue the send request and wait for sending arrived to peer + * + * @return 0 if successful + * + */ + virtual NResult + PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) = 0; + + /** + * @brief Post a single side read request to peer, no callback at peer will be triggered + * + * @param request [in] request information, including 5 important variables, local/remote address/key and size + * also an upper context for user context, which could store 16 bytes + * + * Behavior: + * just issue the read request, not wait for reading request finished + * + * @return 0 if successful + * + */ + virtual NResult PostRead(const UBSHcomNetTransRequest &request) = 0; + + virtual NResult PostRead(const UBSHcomNetTransSglRequest &request) = 0; + + /** + * @brief Post a single side write request to peer, no callback at peer will be triggered + * + * @param request [in] request information, including 5 important variables, local/remote address/key and size + * also an upper context for user context, which could store 16 bytes + * + * Behavior: + * just issue the write request, not wait for writing request finished + * + * @return 0 if successful + * + */ + virtual NResult PostWrite(const UBSHcomNetTransRequest &request) = 0; + + virtual NResult PostWrite(const UBSHcomNetTransSglRequest &request) = 0; + + /** + * @brief Set default timeout + * + * 1. timeout = 0: return immediately + * 2. timeout < 0: never timeout, usually set to -1 + * 3. timeout > 0: second precision timeout. + */ + void DefaultTimeout(int32_t timeout); + + /** + * @brief Wait for send/read/write finish, only for NET_EP_SELF_POLLING is set + * + * @param timeout [in] in second + * 1. timeout = 0: return immediately + * 2. timeout < 0: never timeout, usually set to -1 + * 3. timeout > 0: second precision timeout max is 2000s. + * + * Behavior: + * 1 for send, return when request send to peer + * 2 for read, return when read completion + * 3 for write, return when write completion + * + * @return 0 if successful + * + * NN_TIMEOUT if timeout + * + */ + virtual NResult WaitCompletion(int32_t timeout) = 0; + + /** + * @brief Wait for send/read/write finish, only for NET_EP_SELF_POLLING is set + * + * Behavior: + * 1 for send, return when request send to peer + * 2 for read, return when read completion + * 3 for write, return when write completion + * + * Default timeout will be used + * + * @return 0 if successful + * + * NN_TIMEOUT if timeout + * + */ + inline NResult WaitCompletion() + { + return WaitCompletion(mDefaultTimeout); + } + + /** + * @brief Get the response for send request reply + * + * @param timeout [in] in second + * 1. timeout = 0: return immediately + * 2. timeout < 0: never timeout, usually set to -1 + * 3. timeout > 0: second precision timeout max is 2000s. + * @param ctx [out] ctx for response message + * + * @return 0 if successful + */ + virtual NResult Receive(int32_t timeout, UBSHcomNetResponseContext &ctx) = 0; + + /** + * @brief Get the response for send request reply + * Default timeout will be used + * + * @param ctx [out] ctx for response message + * + * @return 0 if successful + */ + inline NResult Receive(UBSHcomNetResponseContext &ctx) + { + return Receive(mDefaultTimeout, ctx); + } + + /** + * @brief Get the response for send request reply, without header and opCode etc + * + * @param timeout [in] in second + * 1. timeout = 0: return immediately + * 2. timeout < 0: never timeout, usually set to -1 + * 3. timeout > 0: second precision timeout max is 2000s. + * @param ctx [out] ctx for response message, + * + * @return 0 if successful + */ + virtual NResult ReceiveRaw(int32_t timeout, UBSHcomNetResponseContext &ctx) = 0; + + /** + * @brief Get the response for send request reply, without header and opCode etc + * Default timeout will be used + * + * @param ctx [out] ctx for response message + * + * @return 0 if successful + */ + inline NResult ReceiveRaw(UBSHcomNetResponseContext &ctx) + { + return ReceiveRaw(mDefaultTimeout, ctx); + } + + /** + * @brief Get the response for send request reply, without header and opCode etc + * Default timeout will be used + * + * @param ctx [out] ctx for response message + * + * @return 0 if successful + */ + inline NResult ReceiveRawSgl(UBSHcomNetResponseContext &ctx) + { + return ReceiveRaw(mDefaultTimeout, ctx); + } + + /** + * @brief Estimated Encrypt length for input raw len + * + * @param rawLen [in] raw length before encrypt + * + * @return the length after encrypt + */ + virtual uint64_t EstimatedEncryptLen(uint64_t rawLen) + { + return 0; + } + + /** + * @brief Encrypt data + * + * @param rawData [in] raw data before encrypt + * @param rawLen [in] raw data length before encrypt + * @param cipher [out] cipher data after encrypt + * @param cipherLen [out] cipher data length after encrypt + * + * @return 0 if success + */ + virtual NResult Encrypt(const void *rawData, uint64_t rawLen, void *cipher, + uint64_t &cipherLen) + { + return 0; + } + + /** + * @brief Estimate Decrypt length + * + * @param cipherLen [in] cipher len before decrypt + * + * @return the raw length after decrypt + */ + virtual uint64_t EstimatedDecryptLen(uint64_t cipherLen) + { + return 0; + } + + /** + * @brief Decrypt data + * + * @param cipher [in] cipher data after encrypt + * @param cipherLen [in] cipher data length after encrypt + * @param rawData [out] raw data before encrypt + * @param rawLen [out] raw data length before encrypt + * + * @return 0 if success + */ + virtual NResult Decrypt(const void *cipher, uint64_t cipherLen, + void *rawData, uint64_t &rawLen) + { + return 0; + } + + /** + * @brief Send shm fds, only shm protocol support + * + * @param fds [in] fds to send + * @param len [in] fds count to send + * + * @return 0 if success + */ + virtual NResult SendFds(int fds[], uint32_t len) + { + return NN_EXCHANGE_FD_NOT_SUPPORT; + } + + /** + * @brief Receive shm fds, only shm protocol support + * + * @param fds [out] fds to be received + * @param len [in] fds count to be received + * @param timeoutSec [in] timeout in second for receive. -1 is never timeout + * + * @return 0 if success + */ + virtual NResult ReceiveFds(int fds[], uint32_t len, int32_t timeoutSec) + { + return NN_EXCHANGE_FD_NOT_SUPPORT; + } + + /** + * @brief Get remote uds ids include pid uid gid, only support in oob server and when oob type is uds + * + * @param idInfo [out] remote uds ids + */ + virtual NResult GetRemoteUdsIdInfo(UBSHcomNetUdsIdInfo &idInfo) + { + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + + /** + * @brief Get ip and port of peer + */ + virtual bool GetPeerIpPort(std::string &ip, uint16_t &port) = 0; + + /** + * @brief Close endpoint, then will async call broken function + */ + virtual void Close() {} + + inline uint8_t GetDevIndex() const + { + return mDevIndex; + } + + inline uint8_t GetPeerDevIndex() const + { + return mPeerDevIndex; + } + + inline uint8_t GetBandWidth() const + { + return mBandWidth; + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +protected: + explicit UBSHcomNetEndpoint(uint64_t id, const UBSHcomNetWorkerIndex &workerWholeIndex) + : mId(id) + { + OBJ_GC_INCREASE(UBSHcomNetEndpoint); + mWorkerIndex = workerWholeIndex; + } + + inline uint32_t NextSeq() + { + return __sync_fetch_and_add(&mSeqIndex, 1); + } + + /** + * To later, change this to private and using friend to access this + */ + inline std::atomic_bool &EPBrokenProcessed() + { + return mEPBrokenProcessed; + } + + bool IsNeedSendHb() const; + + virtual NResult PostSendSglInline(uint16_t opCode, const UBSHcomNetTransRequest &request, const UBSHcomNetTransOpInfo &opInfo) + { + return PostSend(opCode, request, opInfo); + } + +protected: + uint64_t mUpCtx = 0; + uint32_t mSeqIndex = 1; + uint32_t mSegSize = 0; + /// mAllowedSize 通常表示为除 UBSHcomNetTransHeader 外可允许发送消息的大小。但是有 + /// 时候服务层可能会有 ExtHeader 需要发送,需注意。通常它的检查是在 + /// POST_SEND_VALIDATION 中,当涉及到 ExtHeader 时可直接减去 extHeaderSize. + /// \see NetAsyncEndpoint::PostSend + /// \see NetSyncEndpoint::PostSend + /// \see NetUBAsyncEndpoint::PostSend + /// \see NetUBSyncEndpoint::PostSend + uint32_t mAllowedSize = 0; + int32_t mDefaultTimeout = -1; + + UBSHcomNetWorkerIndex mWorkerIndex{}; + UBSHcomNetAtomicState mState{NEP_NEW}; + + bool mIsNeedSendHb = false; + std::atomic_bool mEPBrokenProcessed{false}; + + uint64_t mId = 0; + UBSHcomNetUdsIdInfo mRemoteUdsIdInfo{}; + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend class NetHeartbeat; + + // 服务层拆包专用,上层用户在调用时应当保证 extHeaderType != RAW + virtual NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, const UBSHcomNetTransOpInfo &opInfo, + const UBSHcomExtHeaderType extHeaderType, const void *extHeader, uint32_t extHeaderSize) + { + NN_LOG_WARN("PostSend with header unimplemented yet!!!"); + return NN_ERROR; + } + + friend class NetChannel; + friend class HcomChannelImp; + +private: + /** + * @brief Set the connect info + */ + void StoreConnInfo(uint32_t localIp, uint16_t listenPort, uint8_t version, + const std::string &payload); + + /** + * @brief Set the payload + */ + void Payload(const std::string &payload); + + /** + * @brief Set remote uds id info + */ + void RemoteUdsIdInfo(uint32_t pid, uint32_t uid, uint32_t gid); + + uint32_t mLocalIp = INVALID_IP; + uint16_t mListenPort = 0; + uint8_t mVersion = 0; + std::string mPayload; + uint8_t mDevIndex = 0; + uint8_t mPeerDevIndex = 0; + uint8_t mBandWidth = 0; + + friend class NetDriverRDMAWithOob; + friend class NetDriverSockWithOOB; + friend class NetDriverShmWithOOB; +#ifdef UB_BUILD_ENABLED + friend class NetDriverUBWithOob; +#endif +}; + +inline void UBSHcomNetEndpoint::UpCtx(uint64_t ctx) +{ + mUpCtx = ctx; +} + +inline uint64_t UBSHcomNetEndpoint::UpCtx() const +{ + return mUpCtx; +} + +inline bool UBSHcomNetEndpoint::IsNeedSendHb() const +{ + return mIsNeedSendHb; +} + +/* ****************************************************************************************** */ +class UBSHcomNetMessage { +public: + inline uint32_t DataLen() const + { + return mDataLen; + } + + inline void *Data() const + { + return mBuf; + } + + uint32_t GetBufLen() const + { + return mBufLen; + } + +protected: + UBSHcomNetMessage() + { + OBJ_GC_INCREASE(UBSHcomNetMessage); + } + + ~UBSHcomNetMessage() + { + if (mBuf != nullptr) { + free(mBuf); + mBuf = nullptr; + } + + OBJ_GC_DECREASE(UBSHcomNetMessage); + } + + inline bool AllocateIfNeed(uint32_t newSize) + { + if (NN_UNLIKELY(newSize == NN_NO0)) { + NN_LOG_ERROR("Invalid msg size " << newSize << ", alloc failed"); + return false; + } + if (newSize > mBufLen) { + if (mBuf != nullptr) { + free(mBuf); + } + + if ((mBuf = malloc(newSize)) != nullptr) { + mBufLen = newSize; + return true; + } + mBuf = nullptr; + mBufLen = NN_NO0; + return false; + } + + return true; + } + + inline void SetBuf(void *buf, uint32_t len) + { + mBuf = buf; + mBufLen = len; + } + + UBSHcomNetMessage(const UBSHcomNetMessage &) = delete; + UBSHcomNetMessage(UBSHcomNetMessage &&) = delete; + UBSHcomNetMessage &operator=(const UBSHcomNetMessage &) = delete; + UBSHcomNetMessage &operator=(UBSHcomNetMessage &&) = delete; + +private: + uint32_t mBufLen = 0; + uint32_t mDataLen = 0; + void *mBuf = nullptr; + + friend class NetAsyncEndpoint; + friend class NetSyncEndpoint; + friend class NetAsyncEndpointSock; + friend class NetSyncEndpointSock; + friend class NetSyncEndpointShm; + friend class NetDriverSockWithOOB; + friend class NetDriverRDMAWithOob; + friend class NetDriverShmWithOOB; + friend class NetAsyncEndpointShm; + friend class NetServiceDefaultImp; + +#ifdef UB_BUILD_ENABLED + friend class NetUBAsyncEndpoint; + friend class NetUBSyncEndpoint; + friend class NetDriverUBWithOob; +#endif +}; + +/* ****************************************************************************************** */ +/** + * @brief UBSHcomNetRequestContext + */ +class UBSHcomNetRequestContext { +public: + enum NN_OpType : uint8_t { + NN_SENT = 0, + NN_SENT_RAW = 1, + NN_SENT_RAW_SGL = 2, + NN_RECEIVED = 3, + NN_RECEIVED_RAW = 4, + NN_WRITTEN = 5, + NN_READ = 6, + NN_SGL_WRITTEN = 7, + NN_SGL_READ = 8, + NN_RNDV = 9, + NN_SENT_SGL_INLINE = 10, + + NN_INVALID_OP_TYPE = 255, + }; + + /** + * @brief Get the endpoint of context + */ + const UBSHcomNetEndpointPtr &EndPoint() const; + + /** + * @brief Get result of all operation + */ + NResult Result() const; + + /** + * @brief Get header of two side operation + */ + const UBSHcomNetTransHeader &Header() const; + + /** + * @brief Get the message received + */ + UBSHcomNetMessage *Message() const; + + /** + * @brief Get the operation type, send/receive/read/write + */ + NN_OpType OpType() const; + + /** + * @brief Get the original request + */ + const UBSHcomNetTransRequest &OriginalRequest() const; + + /** + * @brief Get the original sgl request + */ + const UBSHcomNetTransSglRequest &OriginalSgeRequest() const; + + // the passed context cannot be copy directly need to use SafeClone() + // if needing to transfer to thread to process in async + static bool + SafeClone(const UBSHcomNetRequestContext &old, const UBSHcomNetRequestContextPtr &newOne) + { + if (NN_UNLIKELY(newOne.Get() == nullptr)) { + return false; + } + + newOne->mEp = old.mEp; + newOne->mHeader = old.mHeader; + newOne->mOpType = old.mOpType; + return true; + } + + UBSHcomNetRequestContext() : mMessage(nullptr) + { + OBJ_GC_INCREASE(UBSHcomNetRequestContext); + } + + ~UBSHcomNetRequestContext() + { + OBJ_GC_DECREASE(UBSHcomNetRequestContext); + } + + UBSHcomNetRequestContext(const UBSHcomNetRequestContext &) = delete; + UBSHcomNetRequestContext(UBSHcomNetRequestContext &&) = delete; + UBSHcomNetRequestContext &operator=(const UBSHcomNetRequestContext &) = delete; + UBSHcomNetRequestContext &operator=(UBSHcomNetRequestContext &&) = delete; + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + UBSHcomNetEndpointPtr mEp = nullptr; + NResult mResult = NN_OK; + UBSHcomNetTransHeader mHeader{}; + NN_OpType mOpType = NN_RECEIVED; + UBSHcomNetMessage *mMessage = nullptr; + UBSHcomNetTransRequest mOriginalReq{}; // copy information, not original address + + UBSHcomNetTransSgeIov iov[NET_SGE_MAX_IOV]; + UBSHcomNetTransSglRequest + mOriginalSglReq{}; // copy information, not original address + + UBSHcomExtHeaderType extHeaderType = UBSHcomExtHeaderType::RAW; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend class NetAsyncEndpoint; + friend class NetSyncEndpoint; + friend class NetAsyncEndpointSock; + friend class NetSyncEndpointSock; + friend class NetSyncEndpointShm; + friend class NetDriverSockWithOOB; + friend class NetDriverRDMAWithOob; + friend class NetDriverShmWithOOB; + friend class NetServiceGlobalObject; + friend class NetServiceDefaultImp; + friend class HcomServiceImp; + +#ifdef UB_BUILD_ENABLED + friend class NetUBAsyncEndpoint; + friend class NetUBSyncEndpoint; + friend class NetDriverUBWithOob; +#endif +}; + +inline const UBSHcomNetEndpointPtr &UBSHcomNetRequestContext::EndPoint() const +{ + return mEp; +} + +inline NResult UBSHcomNetRequestContext::Result() const +{ + return mResult; +} + +inline const UBSHcomNetTransHeader &UBSHcomNetRequestContext::Header() const +{ + return mHeader; +} + +inline UBSHcomNetMessage *UBSHcomNetRequestContext::Message() const +{ + return mMessage; +} + +inline UBSHcomNetRequestContext::NN_OpType UBSHcomNetRequestContext::OpType() const +{ + return mOpType; +} + +inline const UBSHcomNetTransRequest &UBSHcomNetRequestContext::OriginalRequest() const +{ + return mOriginalReq; +} + +inline const UBSHcomNetTransSglRequest &UBSHcomNetRequestContext::OriginalSgeRequest() const +{ + return mOriginalSglReq; +} + +/** + * @brief Response context for sync call + */ +class UBSHcomNetResponseContext { +public: + /** + * @brief Get header of response + */ + const UBSHcomNetTransHeader &Header() const; + + UBSHcomNetMessage *Message() const; + + UBSHcomNetResponseContext() : mMessage(nullptr) + { + OBJ_GC_INCREASE(UBSHcomNetResponseContext); + } + + ~UBSHcomNetResponseContext() + { + OBJ_GC_DECREASE(UBSHcomNetResponseContext); + } + + UBSHcomNetResponseContext(const UBSHcomNetRequestContext &) = delete; + UBSHcomNetResponseContext(UBSHcomNetRequestContext &&) = delete; + UBSHcomNetResponseContext &operator=(const UBSHcomNetRequestContext &) = delete; + UBSHcomNetResponseContext &operator=(UBSHcomNetRequestContext &&) = delete; + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + UBSHcomNetTransHeader mHeader{}; + UBSHcomNetMessage *mMessage = nullptr; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend class NetAsyncEndpoint; + friend class NetSyncEndpoint; + friend class NetAsyncEndpointSock; + friend class NetSyncEndpointSock; + friend class NetSyncEndpointShm; + friend class NetDriverSockWithOOB; + friend class NetDriverRDMAWithOob; + +#ifdef UB_BUILD_ENABLED + friend class NetUBAsyncEndpoint; + friend class NetUBSyncEndpoint; + friend class NetDriverUBWithOob; +#endif +}; + +inline const UBSHcomNetTransHeader &UBSHcomNetResponseContext::Header() const +{ + return mHeader; +} + +inline UBSHcomNetMessage *UBSHcomNetResponseContext::Message() const +{ + return mMessage; +} + +/* ****************************************************************************************** */ +/** + * @brief Memory region for one side operation + */ +class UBSHcomNetMemoryRegion { +public: + /** + * @brief Initialize memory region, lkey can be got after + * + * Behavior + * 1) RDMA, physical memory will be allocated and registered to hardware, will be pinned + * 2) TCP/UDS, physical memory will be allocated + * + * @return 0 successful + */ + virtual NResult Initialize() = 0; + + /** + * @brief Get local key + */ + inline uint64_t GetLKey() const + { + return mLKey; + } + + /** + * @brief Get address + */ + inline uintptr_t GetAddress() const + { + return mBuf; + } + + /** + * @brief Get size of memory size + */ + inline uint64_t Size() const + { + return mSize; + } + + virtual void *GetMemorySeg() = 0; + + virtual void GetVa(uint64_t &va, uint64_t &va_len, uint32_t &token_id) = 0; + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +protected: + UBSHcomNetMemoryRegion(const std::string &name, bool extMem, uintptr_t buf, + uint64_t size) + : mName(name), mExternalMemory(extMem), mSize(size), mBuf(buf) + {} + + /** + * @brief UnInitialize + */ + virtual void UnInitialize() = 0; + + virtual ~UBSHcomNetMemoryRegion() = default; + +protected: + std::string mName; + bool mExternalMemory = false; + uint64_t mSize = 0; + + uintptr_t mBuf = 0; + bool mGetBufWithMapping = false; + uint64_t mLKey = 0; + uintptr_t mPgRegion = 0; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend class NetDriverRDMA; + friend class NetDriverSockWithOOB; + friend class NetDriverShmWithOOB; +#ifdef UB_BUILD_ENABLED + friend class NetDriverUB; + friend class UBJetty; +#endif + friend class HcomServiceImp; +}; + +/** + * @brief Type of allocator + */ +enum UBSHcomNetMemoryAllocatorType { + DYNAMIC_SIZE = + 0, /* allocate dynamic memory size, there is alignment with X KB */ + DYNAMIC_SIZE_WITH_CACHE = + 1, /* allocator with dynamic memory size, with pre-allocate cache for performance */ +}; + +/** + * @brief Covert UBSHcomNetMemoryAllocatorType to string + * + * @param v [in] value to type to be converted + * + * @return string coverted + */ +std::string &UBSHcomNetMemoryAllocatorTypeToString(UBSHcomNetMemoryAllocatorType v); + +/** + * @brief Allocator cache tier policy + */ +enum UBSHcomNetMemoryAllocatorCacheTierPolicy : int16_t { + TIER_TIMES = 0, /* tier by times of min-block-size */ + TIER_POWER = 1, /* tier by power of min-block-size */ +}; + +/** + * @brief Allocator options + */ +struct UBSHcomNetMemoryAllocatorOptions { + uintptr_t address = 0; /* base address of large range of memory for allocator */ + uint64_t size = 0; /* size of large memory chuck */ + uint32_t minBlockSize = 0; /* min size of block can be allocated from allocator */ + uint32_t bucketCount = NN_NO8192; /* default size of hash bucket */ + bool alignedAddress = false; /* force to align the memory block allocated */ + uint16_t cacheTierCount = NN_NO8; /* for DYNAMIC_SIZE_WITH_CACHE only */ + uint16_t cacheBlockCountPerTier = NN_NO16; /* for DYNAMIC_SIZE_WITH_CACHE only */ + UBSHcomNetMemoryAllocatorCacheTierPolicy cacheTierPolicy = TIER_TIMES; /* tier policy */ + + std::string ToString() const; +}; + +/** + * @brief Allocator to alloc memory area from a large mount of memory. + * + * For example, we have RDMA memory region, which already registered to NIC, + * and we need to reuse memory on this region, so we need to alloc sub part + * of memory from the large memory region, use it and return it. + */ +class UBSHcomNetMemoryAllocator { +public: + /** + * @brief Create a memory allocator + * + * @param t [in] type of allocator + * @param options [in] options + * @param allocator [out] allocator created + */ + static NResult Create(UBSHcomNetMemoryAllocatorType t, + const UBSHcomNetMemoryAllocatorOptions &options, + UBSHcomNetMemoryAllocatorPtr &out); + +public: + virtual ~UBSHcomNetMemoryAllocator() = default; + + /** + * @brief Get the memory region key + * + * @return key + */ + uint64_t MrKey() const; + + /** + * @brief Set the memory region key + */ + void MrKey(uint64_t mrKey); + + void *GetTargetSeg() const; + + void SetTargetSeg(void *targetSeg); + + /** + * @brief Get the memory offset based on base address + * + * @param address [in] memory address + * + * @return offset comparing to base address + */ + virtual uintptr_t MemOffset(uintptr_t address) const = 0; + + /** + * @brief Get free memory size + * + * @return Free memory size + */ + virtual uint64_t FreeSize() const = 0; + + /** + * @brief Allocate memory area + * + * @param size [in] size of memory of demand + * @param outAddress [out] allocated memory address + * + * @return 0 if successful + */ + virtual NResult Allocate(uint64_t size, uintptr_t &outAddress) = 0; + + /** + * @brief Free the address allocated by #Allocate function + * + * @param address [in] address to be freed + * + * @param 0 if successful + */ + virtual NResult Free(uintptr_t address) = 0; + + /** + * @brief function should be called before managed memory freeing + * + * Remove memory protection if enabled(cmake -DBUILD_WITH_ALLOCATOR_PROTECTION=ON), + * should be called before freeing the memory passed in, otherwise sigsegv will raise by free(), + * It's suggested to be called even if you are not using memory protection currently, + * in case you may miss this once you turn memory protection on in the future. + */ + virtual void Destroy(){}; + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + uint64_t mMrKey = 0; + void *mTargetSeg = nullptr; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; + +inline uint64_t UBSHcomNetMemoryAllocator::MrKey() const +{ + return mMrKey; +} + +inline void UBSHcomNetMemoryAllocator::MrKey(uint64_t mrKey) +{ + mMrKey = mrKey; +} + +inline void *UBSHcomNetMemoryAllocator::GetTargetSeg() const +{ + return mTargetSeg; +} + +inline void UBSHcomNetMemoryAllocator::SetTargetSeg(void *targetSeg) +{ + mTargetSeg = targetSeg; +} + +/** + * @brief Oob listening information for multiple listen port + */ +struct UBSHcomNetOobListenerOptions { + char ip[NN_NO16]{}; /* listening ip */ + uint16_t port = 9980; /* listening port */ + uint16_t targetWorkerCount = + UINT16_MAX; /* the count of target workers, if >= 1, the + accepted socket will be attached to sub set to workers, 0 means all */ + + /** + * @brief Set ip/port/targetWorkerCount + * + * @param pIp [in] ip to set + * @param pp [in] port to set + * @param twc [in] target worker count to set + */ + bool Set(const std::string &pIp, uint16_t pp, uint16_t twc); + + /** + * @brief Set ip/port/targetWorkerCount + * + * @param eid [in] public jetty eid to set + * @param id [in] public jetty id to set + * @param twc [in] target worker count to set + */ + bool SetEid(const std::string &eid, uint16_t id, uint16_t twc); + + /** + * @brief Set ip/port, targetWorkerCount will be set to uint16_max + * + * @param pIp [in] ip to set + * @param pp [in] port to set + */ + bool Set(const std::string &pIp, uint16_t pp); + + /** + * @brief Set port/targetWorkerCount + * + * @param pp [in] port to set + * @param twc [in] target worker count to set + */ + bool Set(uint16_t pp, uint16_t twc); + + /** + * @brief Set the listen ip + * + * @param value [in] the ip to set + * + * @return 0 if successful, otherwise it could be the length of value is too large + */ + NResult Ip(const std::string &value); + + /** + * @brief Get ip of listening + */ + std::string Ip() const; +} __attribute__((packed)); + +/** + * @brief Oob listening information for multiple listen file + */ +struct UBSHcomNetOobUDSListenerOptions { + char name[NN_NO96] {}; /* UDS name for listen or file path */ + uint16_t perm = 0600; /* if 0 means not use file, otherwise use file and this perm as file perm, max is 0600 */ + uint16_t targetWorkerCount = UINT16_MAX; /* the count of target workers, if >= 1, the + accepted socket will be attached to sub set to workers, 0 means all */ + bool isCheck = true; /* whether to verify the permission on the UDS file */ + + /** + * @brief Set name/targetWorkerCount + * + * @param name [in] name or file path to set + * @param twc [in] target worker count to set + */ + bool Set(const std::string &pName, uint16_t twc); + + /** + * @brief Set name for uds oob + * + * @param value [in] the name or file path to set, less than 32 + * + * @return 0 if successful, otherwise it could be the length of value is too large + */ + bool Name(const std::string &value); + + /** + * @brief Get name or file path of listening + */ + std::string Name() const; +} __attribute__((packed)); + +/* ****************************************************************************************** */ +using UBSHcomNetDriverNewEndPointHandler = + std::function; +using UBSHcomNetDriverEndpointBrokenHandler = std::function; + +// the passed context cannot be copy directly need to use SafeClone() +// if needing to transfer to thread to process in async +using UBSHcomNetDriverReceivedHandler = std::function; +using UBSHcomNetDriverSentHandler = std::function; +using UBSHcomNetDriverOneSideDoneHandler = std::function; +using UBSHcomNetDriverIdleHandler = std::function; + +/** + * @brief During establish TLS connection, we can verify peer cert. There are three types of behaviors: + * a) don't verify peer certification + * b) verify peer certification by what hcom provided + * c) verify peer certification using caller's + */ +enum UBSHcomPeerCertVerifyType : uint8_t { + VERIFY_BY_NONE = 0, /* don't verify peer certification */ + VERIFY_BY_DEFAULT = 1, /* verify peer certification by what hcom provided, crl check and cert check */ + VERIFY_BY_CUSTOM_FUNC = 2, /* verify peer certification using caller's */ +}; + +/** + * @brief Callback function to erase key pass after used it, for huawei's security policy + * that "don't store plaintext in memory" + * + * @param void* [in] the address where store key passwd + * @param int [in] the length key passwd + */ +using UBSHcomTLSEraseKeypass = std::function; + +/** + * @brief Callback function to get certification path + * + * @param name [in] a name for logging + * @param path [out] cert file path + */ +using UBSHcomTLSCertificationCallback = std::function; + +/** + * @brief Callback function to get TLS private key and related things, when establishing a connection + * + * @param name [in] a name for logging + * @param path [out] path of cert file + * @param password [out] key passwd of private key + * @param length [out] length of key passwd + * @param erase [out] callback function to erase key passwd in memory, which is called just after key + * passwd is used + */ +using UBSHcomTLSPrivateKeyCallback = std::function; + +/** + * @brief Customize callback function of verify cert, which is used in UBSHcomTLSCaCallback + */ +using UBSHcomTLSCertVerifyCallback = std::function; + +/** + * @brief Callback function of certification check + * + * @param name [in] a name for logging + * @param capath [out] path of ca files, could be multiple files + * @param crlPath [out] path of crl file + * @param verifyPeerCert [out] cert verification type, none | default_by_hcom | customized, if customized, cb need + * to be specified + * @param cb [out] callback function of customized function + */ +using UBSHcomTLSCaCallback = std::function; + +/** + * @brief UBSHcomNetDriver secure mode + */ +enum UBSHcomNetDriverSecType : uint8_t { + NET_SEC_DISABLED = 0, + NET_SEC_VALID_ONE_WAY = 1, + NET_SEC_VALID_TWO_WAY = 2, +}; + +/** + * @brief Sec callback function, when oob connect build, this function will be called to generate auth info. + * if this function not set secure type is C_NET_SEC_NO_VALID and oob will not send secure info + * + * @param ctx [in] ctx from connect param ctx, and will send in auth process + * @param flag [out] flag to send in auth process + * @param type [out] secure type, value should set in oob client, and should in [C_NET_SEC_ONE_WAY, + * C_NET_SEC_TWO_WAY] + * @param output [out] secure info created + * @param outLen [out] secure info length + * @param needAutoFree [out] secure info need to auto free in hcom or not + */ +using UBSHcomNetDriverEndpointSecInfoProvider = std::function; + +/** + * @brief ValidateSecInfo callback function, when oob connect build, this function will be called to validate auth info + * if this function not set oob will not validate secure info + * + * @param ctx [in] ctx received in auth process + * @param flag [in] flag received in auth process + * @param input [in] secure info received + * @param inputLen [in] secure info length + */ +using UBSHcomNetDriverEndpointSecInfoValidator = + std::function; + +/** + * @brief Callback function of PSK check, set for client + * + * @param ssl [in] SSL connection pointer + * @param md [in] digest algorithm + * @param id [out] the identity that the client gives to server uses to find the psk + * @param idlen [out] the id length + * @param sess [out] SSL session + * + * @return int 1 on success or 0 on failure + */ +using UBSHcomPskUseSessionCb = + std::function; + +/** + * @brief Callback function of PSK check, set for server + * + * @param ssl [in] SSL connection pointer + * @param identity [in] Client's identity (provided by the client) + * @param identity_len [in] Length of the client's identity + * @param sess [out] SSL session + * + * @return int 1 on success or 0 on failure + */ +using UBSHcomPskFindSessionCb = + std::function; + +std::string &UBSHcomNetDriverSecTypeToString(UBSHcomNetDriverSecType v); + +/** + * @brief UBSHcomNetDriver working mode + */ +enum NetDriverOobType : uint8_t { + NET_OOB_TCP = 0, + NET_OOB_UDS = 1, + NET_OOB_UB = 2, +}; + +std::string &UBSHcomNetDriverOobTypeToString(NetDriverOobType v); + +/** + * @brief UBSHcomNetDriver working mode + */ +enum UBSHcomNetDriverWorkingMode : uint8_t { + NET_BUSY_POLLING = 0, + NET_EVENT_POLLING = 1, +}; + +/** + * @brief UBSHcomNetDriver load balance policy + */ +enum UBSHcomNetDriverLBPolicy : uint8_t { + NET_ROUND_ROBIN = 0, + NET_HASH_IP_PORT = 1, +}; + +std::string &UBSHcomNetDriverLBPolicyToString(UBSHcomNetDriverLBPolicy v); + +/// UB-C 专用: UB-C 具有多路径能力,发送时使用多条路径可以增大带宽,对于带宽要求 +/// 不高、时延敏感型业务又提供单路径直连模式。 +enum class UBSHcomUbcMode : int8_t { + LowLatency = 0, ///< 低时延模式,使用单路径发送 + HighBandwidth = 1, ///< 高带宽模式,使用多条路径发送 +}; + +struct UBSHcomWorkerGroupInfo { + int8_t threadPriority = 0; // [-20, 19], 19 is the lowest, -20 is the highest + uint16_t threadCount = 1; // total number of threads in the worker group + uint16_t groupId = 0; // group id of the worker group + std::pair cpuIdsRange; // worker groups cpu ids range +}; + +/** + * @brief UBSHcomNetDriver options + */ +struct UBSHcomNetDriverOptions { + union { + char netDeviceIpMask[NN_NO256]{}; // IP 掩码。非 UBC 多路径场景通过此掩码可查找得到实际设备的 IP + uint8_t netDeviceEid[NN_NO16]; // UB EID (128b). 多路径聚合设备为非 IP 设备,需用户显式指定 + } __attribute__((packed)); + char netDeviceIpGroup[NN_NO1024]{}; // ip group for devices + bool enableTls = true; // enable ssl + UBSHcomNetDriverSecType secType = NET_SEC_DISABLED; // security type + UBSHcomTlsVersion tlsVersion = TLS_1_3; // tls version, default TLS1.3 (772) + UBSHcomNetCipherSuite cipherSuite = + AES_GCM_128; // if tls enabled can set cipher suite, client and server should same + /* worker setting */ + bool dontStartWorkers = false; // start worker or not + UBSHcomNetDriverWorkingMode mode = + NET_BUSY_POLLING; // worker polling mode, could busy polling or event polling + char workerGroups[NN_NO64]{}; // worker groups, for example 1,3,3 + char workerGroupsCpuSet + [NN_NO128]{}; // worker groups cpu set, for example 1-1,2-5,na + char workerGroupsThreadPriority[NN_NO64] {}; // worker groups thread priority, for example -10,na,9 + // worker thread priority [-20,19], 19 is the lowest, -20 is the highest, 0 (default) means do not set priority + int workerThreadPriority = 0; + /* connection attribute */ + NetDriverOobType oobType = + NET_OOB_TCP; // oob type, tcp or UDS, UDS cannot accept remote connection + UBSHcomNetDriverLBPolicy lbPolicy = + NET_ROUND_ROBIN; // select worker load balance policy, default round-robin + uint16_t magic = NN_NO256; // magic number for c/s connect validation + uint8_t version = 0; // program version used by connect validation + /* heart beat attribute */ + uint16_t heartBeatIdleTime = NN_NO60; // heart beat idle time, in seconds + uint16_t heartBeatProbeTimes = NN_NO7; // heart beat probe times + uint16_t heartBeatProbeInterval = + NN_NO2; // heart beat probe interval, in seconds + /* options for only tcp protocol */ + // timeout during io (s), it should be [-1, 1024], -1 means do not set, 0 means never timeout during io + int16_t tcpUserTimeout = -1; + bool tcpEnableNoDelay = true; // tcp TCP_NODELAY option, true in default + bool tcpSendZCopy = + false; // tcp whether copy request to inner memory, false in default + /* The buffer sizes will be adjusted automatically when these two variables are 0, and the performance would be + * better */ + uint16_t tcpSendBufSize = + 0; // tcp connection send buffer size in kernel, by KB + uint16_t tcpReceiveBufSize = + 0; // tcp connection send receive buf size in kernel, by KB + /* options for rdma protocol only */ + uint32_t mrSendReceiveSegCount = + NN_NO8192; // memory region segment count for two side operation + uint32_t mrSendReceiveSegSize = + NN_NO1024; // data size of memory region segment + /* transmit of 256b data performs better when dmSegSize is 290 */ + uint32_t dmSegSize = NN_NO290; // data size of device memory segment + uint32_t dmSegCount = NN_NO400; // segment count of device memory segment + uint16_t completionQueueDepth = NN_NO2048; // completion queue size of rdma + uint16_t maxPostSendCountPerQP = NN_NO64; // max number request could issue + uint16_t prePostReceiveSizePerQP = NN_NO64; // pre post receive of qp + uint16_t pollingBatchSize = NN_NO4; // polling batch size for worker + uint32_t eventPollingTimeout = + NN_NO500; // event polling timeout in ms, max value is 2000000ms + uint32_t qpSendQueueSize = + NN_NO256; // max send working request of qp for rdma + uint32_t qpReceiveQueueSize = + NN_NO256; // max receive working request of qp for rdma + uint16_t oobConnHandleThreadCount = + NN_NO2; // server accept connection thread num + uint32_t oobConnHandleQueueCap = + NN_NO4096; // server accept connection queue capability + uint32_t maxConnectionNum = NN_NO250; // max connection number + bool enableMultiRail = false; // enable multi rail + uint8_t slave = 1; // slave 1 or 2 + + char oobPortRange[NN_NO16]{}; // port range when enable port auto selection + + UBSHcomUbcMode ubcMode = UBSHcomUbcMode::LowLatency; + + /* verify the common options of each driver */ + NResult ValidateCommonOptions(); + + std::string NetDeviceIpMask() const; + + std::string NetDeviceIpGroup() const; + + std::string WorkGroups() const; + + std::string WorkerGroupCpus() const; + + std::string WorkerGroupThreadPriority() const; + + /// 设置设备 IP 掩码以辅助查找得到真实通信设备 IP,与 `SetNetDeviceEid()` 冲突, + /// 不可同时使用。当前仅支持 IPv4, 格式如下 `192.168.0.1/24`. + bool SetNetDeviceIpMask(const std::string &mask); + + /// 设置 UB 设备 EID 以辅助查找得到真实通信设备, 与 `SetNetDeviceIpMask()` 冲突,不可 + /// 同时使用。EID 格式类似 IPv6, 为如下格式 `0000:0000:0000:0000:0000:xxxx:0x0x:0x0x`. + /// 要求在去除冒号后为 16 字节,不可省略每个数字前的前导 0. 通常用户只需要复制 + /// `urma_admin show` 的输出即可。 + bool SetNetDeviceEid(const std::string &eid); + + /** + * @brief Set the ip mask for net devices + * + * @param mask Each element in the mask vector represent an ipmask. e.g. mask = {192.168.0.1/24. 192.168.1.1/24} + * @return true set success + * @return false set failed + */ + bool SetNetDeviceIpMask(const std::vector &mask); + + /** + * @brief Set the ip group for net devices, example: 192.168.0.1;192.168.0.2 + */ + bool SetNetDeviceIpGroup(const std::string &ipGroup); + + /** + * @brief Set the ip group for net devices + * + * @param ipGroup Each element in the ipGroup represent an ip. e.g. ipGroup = {192.168.0.1;192.168.0.2} + * @return true set success + * @return false set failed + */ + bool SetNetDeviceIpGroup(const std::vector &ipGroup); + + /** + * @brief Set worker groups, example: 1,3,4 + * meaning 3 groups for workers: + * group0 has 1 workers + * group1 has 3 workers + * group2 has 4 workers + */ + bool SetWorkerGroups(const std::string &groups); + + /** + * @brief Set worker groups, example: 10-10,11-13,na + * meaning 3 groups for workers: + * group0 bind to cpu 10 + * group1 bind to cpu 11, 12, 13 + * group2 not bind to cpu + */ + bool SetWorkerGroupsCpuSet(const std::string &value); + + /** + * @brief Set worker groups thread priority, example: 10,na,15 + * meaning 3 groups for workers: + * group0 thread priority 10 + * group1 not set thread priority + * group2 thread priority 15 + */ + bool SetWorkerGroupThreadPriority(const std::string &value); + + void SetUbcMode(UBSHcomUbcMode m) + { + ubcMode = m; + } + /** + * @brief Set the Worker Groups Info by UBSHcomWorkerGroupInfo vector + * + * @param workerGroups vector of UBSHcomWorkerGroupInfo, each element represent a worker group config + * @return true set success + * @return false set fail + */ + bool SetWorkerGroupsInfo(const std::vector &workerGroupInfos); + + std::string ToString() const; + + std::string ToStringForSock() const; +} __attribute__((packed)); + +/** + * @brief The protocol of driver + */ +enum UBSHcomNetDriverProtocol { + RDMA = 0, + TCP = 1, + UDS = 2, + SHM = 3, + UBC = 7, + + UNKNOWN = 255, +}; + +/** + * @brief Protocol to string + */ +std::string &UBSHcomNetDriverProtocolToString(UBSHcomNetDriverProtocol v); + +/** + * @brief UBSHcomNetDriver + */ +class UBSHcomNetDriver { +public: + /** + * @brief Get a driver instance by name + * + * @param t [in] protocol of this driver + * @param name [in] name of driver to be created + * @param startOobSvr [in] start oob server or not + * + * @return Driver instance is OK, otherwise return nullptr + */ + static UBSHcomNetDriver *Instance(UBSHcomNetDriverProtocol t, const std::string &name, bool startOobSvr); + + /** + * @brief Destroy driver instance by name + * + * @param name [in] name of driver to be created + * + * @return Destroy driver instance is OK, otherwise return error + */ + static NResult DestroyInstance(const std::string &name); + + /** + * @brief Check if local host support certain protocol + * + * @param t [in] protocol + * @param t [out] device info + * + * @return true is support + */ + static bool LocalSupport(UBSHcomNetDriverProtocol t, UBSHcomNetDriverDeviceInfo &deviceInfo); + + static bool MultiRailGetDevCount(UBSHcomNetDriverProtocol t, std::string ipMask, uint16_t &enableDevCount, + std::string ipGroup); + +public: + virtual ~UBSHcomNetDriver() + { + OBJ_GC_DECREASE(UBSHcomNetDriver); + } + + /** + * @brief Initialize the net driver + * + * @param option [in] option for initialize + * + * @return 0 if successful + */ + virtual NResult Initialize(const UBSHcomNetDriverOptions &option) = 0; + + /** + * @brief UnInitialize the net driver + */ + virtual void UnInitialize() = 0; + + /** + * @brief Start the net driver + * + * @return 0 if successful + */ + virtual NResult Start() = 0; + + /** + * @brief Stop the net driver + */ + virtual void Stop() = 0; + + /** + * @brief Register a memory region, the memory will be allocated internally + * + * @param size [in] size of the memory region + * @param mr [out] memory region registered + * + * @return 0 successful + */ + virtual NResult + CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr) = 0; + + /** + * @brief Register a memory region, the memory need to be passed in + * + * @param address [in] the memory point need to be registered + * @param size [in] size of the memory region + * @param mr [out] memory region registered + * + * @return 0 successful + */ + virtual NResult CreateMemoryRegion(uintptr_t address, uint64_t size, + UBSHcomNetMemoryRegionPtr &mr) = 0; + + virtual NResult CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr, + unsigned long memid) = 0; + + /** + * @brief Unregister the memory region + * + * @param mr [in] memory region registered + * + * @return 0 successful + */ + virtual void DestroyMemoryRegion(UBSHcomNetMemoryRegionPtr &mr) = 0; + + /** + * @brief Connect to server with driver's oob ip or uds name + * + * @param payload [in] payload transferred to peer, could be got EP Connected callback at server + * @param ep [out] connected end point + * @param flags [in] flags + * @param serverGrpNo [in] indicates which client worker group to connect + * @param clientGrpNo [in] indicates which server worker group to connect to + * + * @return 0 successful + */ + virtual NResult Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep, + uint32_t flags, uint8_t serverGrpNo, + uint8_t clientGrpNo) = 0; + + /** + * @brief Connect to server with driver's oob ip or uds name + * + * @param payload [in] payload transferred to peer, could be got EP Connected callback at server + * @param ep [out] connected end point + * @param serverGrpNo [in] indicates which client worker group to connect + * @param clientGrpNo [in] indicates which server worker group to connect to + * + * @return 0 successful + */ + virtual NResult Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep, + uint8_t serverGrpNo, uint8_t clientGrpNo) + { + return Connect(payload, ep, 0, serverGrpNo, clientGrpNo); + } + + /** + * @brief Connect to server with driver's oob ip or uds name + * + * @param payload [in] payload transferred to peer, could be got EP Connected callback at server + * @param ep [out] connected end point + * @param flags [in] flags + * + * @return 0 successful + */ + virtual NResult + Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags) + { + return Connect(payload, ep, flags, 0, 0); + } + + /** + * @brief Connect to server with driver's oob ip or uds name + * + * @param payload [in] payload transferred to peer, could be got EP Connected callback at server + * @param ep [out] connected end point + * + * @return 0 successful + */ + virtual NResult Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep) + { + return Connect(payload, ep, 0, 0, 0); + } + + /** + * @brief Connect to server + * + * @param oobIpOrName [in] oob ip or name to connect, set ip for tcp and name for uds + * @param oobPort [in] only need to set when tcp oob + * @param payload [in] payload transferred to peer, could be got EP Connected callback at server + * @param ep [out] connected end point + * @param flags [in] flags + * @param serverGrpNo [in] indicates which client worker group to connect + * @param clientGrpNo [in] indicates which server worker group to connect to + * + * @return 0 successful + */ + NResult Connect(const std::string &oobIpOrName, uint16_t oobPort, + const std::string &payload, UBSHcomNetEndpointPtr &ep, + uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo) + { + return Connect(oobIpOrName, oobPort, payload, ep, flags, serverGrpNo, + clientGrpNo, 0); + }; + + /** + * @brief Connect to server + * + * @param serverUrl [in] oob url, e.g. tcp://127.0.0.1:9981 or uds://udsName + * @param payload [in] payload transferred to peer, could be got EP Connected callback at server + * @param ep [out] connected end point + * @param flags [in] flags + * @param serverGrpNo [in] indicates which client worker group to connect + * @param clientGrpNo [in] indicates which server worker group to connect to + * @param ctx [in] ctx in upstream + * + * @return 0 successful + */ + virtual NResult Connect(const std::string &serverUrl, const std::string &payload, + UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) = 0; + + /** + * @brief Connect to server + * + * @param oobIpOrName [in] oob ip or name to connect, set ip for tcp and name for uds + * @param oobPort [in] only need to set when tcp oob + * @param payload [in] payload transferred to peer, could be got EP Connected callback at server + * @param ep [out] connected end point + * @param flags [in] flags + * @param serverGrpNo [in] indicates which client worker group to connect + * @param clientGrpNo [in] indicates which server worker group to connect to + * @param ctx [in] ctx in upstream + * + * @return 0 successful + */ + virtual NResult Connect(const std::string &oobIpOrName, uint16_t oobPort, + const std::string &payload, UBSHcomNetEndpointPtr &ep, + uint32_t flags, uint8_t serverGrpNo, + uint8_t clientGrpNo, uint64_t ctx) = 0; + + /** + * @brief Connect to server + * + * @param oobIpOrName [in] oob ip or name to connect, set ip for tcp and name for uds + * @param oobPort [in] only need to set when tcp oob + * @param payload [in] payload transferred to peer, could be got EP Connected callback at server + * @param ep [out] connected end point + * @param serverGrpNo [in] indicates which client worker group to connect + * @param clientGrpNo [in] indicates which server worker group to connect to + * + * @return 0 successful + */ + inline NResult Connect(const std::string &oobIpOrName, uint16_t oobPort, + const std::string &payload, UBSHcomNetEndpointPtr &ep, + uint8_t serverGrpNo, uint8_t clientGrpNo) + { + return Connect(oobIpOrName, oobPort, payload, ep, 0, serverGrpNo, + clientGrpNo); + } + + /** + * @brief Connect to server + * + * @param oobIpOrName [in] oob ip or name to connect, set ip for tcp and name for uds + * @param oobPort [in] only need to set when tcp oob + * @param payload [in] payload transferred to peer, could be got EP Connected callback at server + * @param ep [out] connected end point + * @param flags [in] flags + * + * @return 0 successful + */ + virtual NResult Connect(const std::string &oobIpOrName, uint16_t oobPort, + const std::string &payload, UBSHcomNetEndpointPtr &ep, + uint32_t flags) + { + return Connect(oobIpOrName, oobPort, payload, ep, flags, 0, 0); + } + + /** + * @brief Connect to server + * + * @param oobIpOrName [in] oob ip or name to connect, set ip for tcp and name for uds + * @param oobPort [in] only need to set when tcp oob + * @param payload [in] payload transferred to peer, could be got EP Connected callback at server + * @param ep [out] connected end point + * + * @return 0 successful + */ + virtual NResult Connect(const std::string &oobIpOrName, uint16_t oobPort, + const std::string &payload, UBSHcomNetEndpointPtr &ep) + { + return Connect(oobIpOrName, oobPort, payload, ep, 0, 0, 0); + } + + virtual NResult MultiRailNewConnection(OOBTCPConnection &conn) = 0; + + /** + * @brief Destroy the endpoint + * + * @param ep [in] the end point to destroy + */ + virtual void DestroyEndpoint(UBSHcomNetEndpointPtr &ep) = 0; + + /** + * @brief Set out of bound ip and port + * + * @param ip [in] ip address + * @param port [out] port + */ + void OobIpAndPort(const std::string &ip, uint16_t port); + + /** + * @brief Set out of bound eid and jetty id used in public jetty + * + * @param eid [in] public jetty eid + * @param id [in] public jetty id + */ + void OobEidAndJettyId(const std::string &eid, uint16_t id); + + /** + * @brief Get out of bound ip and port + */ + bool GetOobIpAndPort(std::vector> &result); + + /** + * @brief Add multiple oob listeners, if there is only one listener just use OobIpAndPort + * + * @param option [in] listen options + */ + void AddOobOptions(const UBSHcomNetOobListenerOptions &option); + + /** + * @brief Set oob listener of uds type + * + * @param name [in] name of uds listener + * + */ + void OobUdsName(const std::string &name); + + /** + * @brief Add multiple oob uds listeners, if there is only one listener just use OobUdsName + * + * @param option [in] option of uds listener option + * + */ + void AddOobUdsOptions(const UBSHcomNetOobUDSListenerOptions &option); + + /** + * @brief Register callback for new end point connected from client, only need to register at server side + * + * @param handler [in] handler function + */ + void RegisterNewEPHandler(const UBSHcomNetDriverNewEndPointHandler &handler); + + /** + * @brief Register callback for end point broken + * + * @param handler [in] handler function + */ + void RegisterEPBrokenHandler(const UBSHcomNetDriverEndpointBrokenHandler &handler); + + /** + * @brief Register callback for new request from peer + * + * @param handler [in] handler function + */ + void RegisterNewReqHandler(const UBSHcomNetDriverReceivedHandler &handler); + + /** + * @brief Register callback for request posted to peer (send/read/write etc) + * + * @param handler [in] handler function + */ + void RegisterReqPostedHandler(const UBSHcomNetDriverSentHandler &handler); + + /** + * @brief Register callback for one side operation done + * + * @param handler [in] handler function + */ + void RegisterOneSideDoneHandler(const UBSHcomNetDriverOneSideDoneHandler &handler); + + /** + * @brief Register callback for idle + * + * @param handler [in] handler function + */ + void RegisterIdleHandler(const UBSHcomNetDriverIdleHandler &handler); + + /** + * @brief Register callback for idle + * + * @param handler [in] handler function + */ + void RegisterTLSCaCallback(const UBSHcomTLSCaCallback &cb); + + /** + * @brief Register callback for idle + * + * @param handler [in] handler function + */ + void RegisterTLSCertificationCallback(const UBSHcomTLSCertificationCallback &cb); + + /** + * @brief Register callback for idle + * + * @param handler [in] handler function + */ + void RegisterTLSPrivateKeyCallback(const UBSHcomTLSPrivateKeyCallback &cb); + + /** + * @brief Register callback for create secure info + * + * @param provider [in] provider function + */ + void RegisterEndpointSecInfoProvider( + const UBSHcomNetDriverEndpointSecInfoProvider &provider); + + /** + * @brief Register callback for validate secure info from peer + * + * @param validator [in] validator function + */ + void RegisterEndpointSecInfoValidator( + const UBSHcomNetDriverEndpointSecInfoValidator &validator); + + /** + * @brief Register psk callback for client + * + * @param cb [in] psk use session callback + */ + void RegisterPskUseSessionCb(const UBSHcomPskUseSessionCb &cb); + + /** + * @brief Register psk callback for server + * + * @param cb [in] psk find session callback + */ + void RegisterPskFindSessionCb(const UBSHcomPskFindSessionCb &cb); + + /** + * @brief Get the name of driver + */ + const std::string &Name() const; + + uint8_t GetId() const; + + /** + * @brief Get the protocol of driver + */ + UBSHcomNetDriverProtocol Protocol() const; + + /** + * @brief Get the result indicates whether driver is stopped + */ + bool IsStarted() const; + + /** + * @brief Get the result indicates whether driver is inited + */ + bool IsInited() const; + + virtual void *MapAndRegVaForUB(unsigned long memid, uint64_t &va) = 0; + + virtual NResult UnmapVaForUB(uint64_t &va) = 0; + + static void DumpObjectStatistics(); + + void SetPeerDevId(uint8_t index); + + uint8_t GetPeerDevId() const; + + inline void SetDeviceId(uint8_t index) + { + mDevIndex = index; + } + + inline uint8_t GetDeviceId() const + { + return mDevIndex; + } + + uint8_t GetBandWidth() const; + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +protected: + UBSHcomNetDriver(const std::string &name, bool startOobSvr, + UBSHcomNetDriverProtocol protocol) + : mName(name), mStartOobSvr(startOobSvr), mProtocol(protocol) + { + OBJ_GC_INCREASE(UBSHcomNetDriver); + } + +protected: + NResult CreateListeners(bool enableMultiRail = false); + NResult CreateUdsListeners(); + NResult CreateServerLB(); + NResult StartListeners(); + NResult StopListeners(bool clear = true); + + NResult CreateClientLB(); + void DestroyClientLB(); + + NResult ValidateAndParseOobPortRange(const char *oobPortRange); + // tcp://127.0.0.1:9981 or uds://name + NResult ParseUrl(const std::string &url, NetDriverOobType &type, std::string &ip, uint16_t &port); + + static NResult ValidateKunpeng(); + + NResult ValidateHandlesCheck(); + + NResult ValidateOptionsOobType(); + +protected: + std::mutex mInitMutex; + + bool mStarted = false; + UBSHcomNetDriverOptions mOptions; + + std::string mOobIp; + uint16_t mOobPort = 0; + std::string mUdsName; + uint8_t mIndex = 0; + uint8_t mPeerDevIndex = 0; + uint16_t mDevIndex = 0; + uint8_t mBandWidth = 0; + std::pair mPortRange{0, 0}; + + // hot used variables for start + std::string mName; + bool mStartOobSvr = true; + UBSHcomNetDriverProtocol mProtocol = UBSHcomNetDriverProtocol::RDMA; + bool mEnableTls = true; + uint32_t mMajorVersion = NN_NO1; + uint32_t mMinorVersion = 0; + std::atomic_bool mInited{false}; + + UBSHcomNetDriverReceivedHandler mReceivedRequestHandler = nullptr; + UBSHcomNetDriverSentHandler mRequestPostedHandler = nullptr; + UBSHcomNetDriverOneSideDoneHandler mOneSideDoneHandler = nullptr; + + UBSHcomNetDriverIdleHandler mIdleHandler = nullptr; + + UBSHcomNetDriverNewEndPointHandler mNewEndPointHandler = nullptr; + UBSHcomNetDriverEndpointBrokenHandler mEndPointBrokenHandler = nullptr; + + std::mutex mEndPointsMutex; + std::unordered_map mEndPoints; + + NetWorkerLB *mClientLb = nullptr; + NetWorkerLB *mServerLb = nullptr; + + std::vector mOobServers; + std::vector> mWorkerGroups; + + UBSHcomTLSPrivateKeyCallback mTlsPrivateKeyCB = nullptr; + UBSHcomTLSCertificationCallback mTlsCertCB = nullptr; + UBSHcomTLSCaCallback mTlsCaCallback = nullptr; + + UBSHcomNetDriverEndpointSecInfoProvider mSecInfoProvider = nullptr; + UBSHcomNetDriverEndpointSecInfoValidator mSecInfoValidator = nullptr; + + UBSHcomPskFindSessionCb mPskFindSessionCb = nullptr; + UBSHcomPskUseSessionCb mPskUseSessionCb = nullptr; + + std::vector mOobListenOptions; + std::unordered_map + mOobUdsListenOptions; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + +private: + static uint32_t gMaxListenPort; + static uint8_t gDriverIndex; + static std::mutex gDriverMapMutex; + static std::map gDriverMap; + static int32_t gOSMaxFdCount; // number of file descriptors that can be opened by each user process + friend class NetHeartbeat; +}; + +inline const std::string &UBSHcomNetDriver::Name() const +{ + return mName; +} + +inline uint8_t UBSHcomNetDriver::GetId() const +{ + return mIndex; +} + +inline UBSHcomNetDriverProtocol UBSHcomNetDriver::Protocol() const +{ + return mProtocol; +} + +inline bool UBSHcomNetDriver::IsStarted() const +{ + return mStarted; +} + +inline bool UBSHcomNetDriver::IsInited() const +{ + return mInited; +} + +inline void UBSHcomNetDriver::SetPeerDevId(uint8_t index) +{ + mPeerDevIndex = index; +} + +inline uint8_t UBSHcomNetDriver::GetPeerDevId() const +{ + return mPeerDevIndex; +} + +inline uint8_t UBSHcomNetDriver::GetBandWidth() const +{ + return mBandWidth; +} +} // namespace hcom +} // namespace ock + +#endif // OCK_HCOM_CPP_H_34562 diff --git a/src/hcom_def.h b/src/hcom_def.h new file mode 100644 index 0000000000000000000000000000000000000000..c9abc7395b24eacbfe53c622240b3c3af447d636 --- /dev/null +++ b/src/hcom_def.h @@ -0,0 +1,279 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_CPP_UTIL_H_34523 +#define OCK_HCOM_CPP_UTIL_H_34523 + +#include +#include +#include + +#include "hcom_num_def.h" + +namespace ock { +namespace hcom { +#define NET_FLAGS_BIT(i) (1UL << (i)) + +#define SLAVE1_PHYSICAL_ADDRESS 0x286f00000000 // hard-coded PA on slave1 +#define SLAVE2_PHYSICAL_ADDRESS 0x686f00000000 // hard-coded PA on slave2 +#define OBMM_SIZE 1 << 27 // 128M + +constexpr const uint32_t NET_SGE_MAX_SIZE = 524288000; +constexpr const uint32_t NET_STR_ERROR_BUF_SIZE = 128; +constexpr const uint32_t NET_SGE_MAX_IOV = 4; + +// enum num should less than 128 +enum NET_FLAGS { + NTH_TWO_SIDE = 0, + NTH_TWO_SIDE_SGL = 1, + NTH_REPLY_REQUIRED = 2, + NTH_READ = 3, + NTH_READ_ACK = 4, + NTH_WRITE = 5, + NTH_WRITE_ACK = 6, + NTH_READ_SGL = 7, + NTH_READ_SGL_ACK = 8, + NTH_WRITE_SGL = 9, + NTH_WRITE_SGL_ACK = 10 +}; + +/* opcode specifically reserved for some operations */ +enum NetPrivateOpCode { + HB_SEND_OP = 1024, + HB_RECV_OP = 1025, + MR_INFO_OP = 1026, +}; + +#if __GNUC__ == 4 && __GNUC_MINOR__ == 8 && __GNUC_PATCHLEVEL__ == 5 +template T exchangeHcom(T &obj, U &&new_value) +{ + T old_value = std::move(obj); + obj = std::forward(new_value); + return old_value; +} +#endif + +template class UBSHcomNetAtomicState { +public: + explicit UBSHcomNetAtomicState(T state) : mState(state) {} + + ~UBSHcomNetAtomicState() = default; + + inline void Set(T newState) + { + __sync_lock_test_and_set(&mState, newState); + } + + inline bool CAS(T oldState, T newState) + { + return __sync_bool_compare_and_swap(&mState, oldState, newState); + } + + inline bool Compare(T state) const + { + return mState == state; + } + + inline T Get() const + { + return mState; + } + + UBSHcomNetAtomicState() = default; + + UBSHcomNetAtomicState(const UBSHcomNetAtomicState &) = delete; + + UBSHcomNetAtomicState(UBSHcomNetAtomicState &&) = delete; + + UBSHcomNetAtomicState &operator = (const UBSHcomNetAtomicState &) = delete; + + UBSHcomNetAtomicState &operator = (UBSHcomNetAtomicState &&) = delete; + +private: + T mState; +}; + +/** + * @brief const variables + */ +const std::string CONST_EMPTY_STRING; + +using NResult = int32_t; + +constexpr uint16_t INVALID_WORKER_INDEX = 0xFFFF; +constexpr uint8_t INVALID_WORKER_GROUP_INDEX = 0xFF; +constexpr uint32_t INVALID_IP = 0xFFFFFFFF; + +struct UBSHcomNetTransSgeIov { + uintptr_t lAddress = 0; + uintptr_t rAddress = 0; + uint64_t lKey = 0; + uint64_t rKey = 0; + uint32_t size = 0; + unsigned long memid = 0; // indicate obmm memory used by urma in rndv + void *srcSeg; // ptr to description of src mem seg for urma + void *dstSeg; // ptr to description of dst mem seg for urma + + UBSHcomNetTransSgeIov() = default; + UBSHcomNetTransSgeIov(uintptr_t lAddress, uintptr_t rAddress, uint64_t lKey, uint64_t rKey, uint32_t size) + : lAddress(lAddress), rAddress(rAddress), lKey(lKey), rKey(rKey), size(size), + srcSeg(nullptr), dstSeg(nullptr) {} + UBSHcomNetTransSgeIov(uintptr_t lAddr, uint64_t lK, uint32_t s) + : lAddress(lAddr), lKey(lK), size(s), srcSeg(nullptr), dstSeg(nullptr) {} +} __attribute__((packed)); + +struct UBSHcomNetTransDataIov { + uintptr_t address = 0; + uint64_t key = 0; + uint32_t size = 0; + UBSHcomNetTransDataIov() = default; + UBSHcomNetTransDataIov(uintptr_t address, uint64_t key, uint32_t size) + : address(address), key(key), size(size) {} +}; + +/// 本类型主要用于描述存在于传输层 payload 中可能存在的额外头部,以指导服务层如 +/// 何处理 payload. +enum class UBSHcomExtHeaderType : uint32_t { + RAW = 0, ///< 裸 payload + FRAGMENT, ///< SplitSend 专用的分片头,对应下方的 UBSHcomFragmentHeader +}; + +struct UBSHcomNetTransHeader { + uint32_t headerCrc = 0; // header crc code + int16_t opCode = 0; // user define op code, it can be 0~1023 + uint16_t flags = 0; // flags on the header, the upper 8 bits are reserved for the user + uint32_t seqNo = 0; // seq no + int16_t timeout = 0; // timeout from client + int16_t errorCode = 0; // error code for response + uint32_t dataLength = 0; // body length + uint32_t immData = 0; // immData + UBSHcomExtHeaderType extHeaderType = UBSHcomExtHeaderType::RAW; // 传输层 payload 中是否存在服务层的头部 + + UBSHcomNetTransHeader() = default; + + inline void Invalid() + { + opCode = NN_NO1024; + seqNo = NN_NO0; + errorCode = NN_NO0; + dataLength = NN_NO0; + } +} __attribute__((packed)); + +/// 服务层分片头所属消息 ID +struct UBSHcomFragmentMessageId { + uint64_t epId; ///< endpoint ID,用来区分不同连接发送的消息 + uint64_t seqNo; ///< endpoint 自増 seqNo,用来区分同一连接发送的不同消息 + + bool operator==(const UBSHcomFragmentMessageId &rhs) const + { + return epId == rhs.epId && seqNo == rhs.seqNo; + } + + bool operator<(const UBSHcomFragmentMessageId &rhs) const + { + return epId != rhs.epId ? epId < rhs.epId : seqNo < rhs.seqNo; + } + + friend std::ostream &operator<<(std::ostream &os, const UBSHcomFragmentMessageId &rhs) + { + return os << "(" << rhs.epId << " " << rhs.seqNo << ")"; + } +}; + +/// 服务层分片头,当服务层一次性发送数据大小超过单个 MemoryRegion 容量时,需要记 +/// 录单个 fragment 信息以供接收端恢复. +struct UBSHcomFragmentHeader { + UBSHcomFragmentMessageId msgId; ///< 分片所属原消息 ID + uint32_t totalLength; ///< 原消息总大小,接收端在收到分片时会首先分配 totalLength 大小的内存 + uint32_t offset; ///< 分片在原消息中的偏移 +} __attribute__((packed)); + +/** + * @brief Worker index + */ +union UBSHcomNetWorkerIndex { + struct { + uint32_t idxInGrp : 16; /* index in one worker group */ + uint32_t grpIdx : 8; /* index of the group in the net driver */ + uint32_t driverIdx : 8; /* index of the net driver in one process */ + }; + uint32_t wholeIdx = 0; + + void Set(uint32_t idx, uint32_t gIdx, uint16_t dIdx) + { + idxInGrp = idx; + grpIdx = gIdx; + driverIdx = dIdx; + } + + std::string ToString() const + { + std::ostringstream oss; + oss << driverIdx << "-" << grpIdx << "-" << idxInGrp; + return oss.str(); + } +}; + +/** + * @brief Device information for user + */ +struct UBSHcomNetDriverDeviceInfo { + int maxSge = NN_NO4; // max iov count in UBSHcomNetTransSglRequest +} __attribute__((packed)); + +/** + * @brief macros + */ +// macro for reference count +#ifndef DEFINE_RDMA_REF_COUNT_FUNCTIONS +#define DEFINE_RDMA_REF_COUNT_FUNCTIONS \ +public: \ + inline void IncreaseRef() \ + { \ + __sync_fetch_and_add(&mRefCount, 1); \ + } \ + \ + inline void DecreaseRef() \ + { \ + int32_t tmp = __sync_sub_and_fetch(&mRefCount, 1); \ + if (tmp == 0) { \ + delete this; \ + } \ + } \ + \ + inline int32_t GetRef() \ + { \ + return __sync_sub_and_fetch(&mRefCount, 0); \ + } +#endif + +#ifndef DEFINE_RDMA_REF_COUNT_VARIABLE +#define DEFINE_RDMA_REF_COUNT_VARIABLE \ +private: \ + int32_t mRefCount = 0 +#endif + +// macro for gcc optimization for prediction of if/else +#ifndef NN_LIKELY +#define NN_LIKELY(x) (__builtin_expect(!!(x), 1) != 0) +#endif + +#ifndef NN_UNLIKELY +#define NN_UNLIKELY(x) (__builtin_expect(!!(x), 0) != 0) +#endif + +// macro for flag +#define NN_FLAGS_BIT(i) (1UL << (i)) +} +} + +#endif // OCK_HCOM_CPP_UTIL_H_34523 diff --git a/src/hcom_err.h b/src/hcom_err.h new file mode 100644 index 0000000000000000000000000000000000000000..6815421f81ecbacf8a01fa415773fa4f1718915e --- /dev/null +++ b/src/hcom_err.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_CPP_ERR_H +#define OCK_HCOM_CPP_ERR_H + +namespace ock { +namespace hcom { + +enum NNCode { + NN_OK = 0, + NN_ERROR = 100, + NN_INVALID_IP = 101, + NN_NEW_OBJECT_FAILED = 102, + NN_INVALID_PARAM = 103, + NN_TWO_SIDE_MESSAGE_TOO_LARGE = 104, + NN_INVALID_OPCODE = 105, + NN_EP_NOT_ESTABLISHED = 106, + NN_EP_NOT_INITIALIZED = 107, + NN_BLOCK_QUEUE_SEM_INIT_FAILED = 108, + NN_TIMEOUT = 109, + NN_INVALID_OPERATION = 110, + NN_MALLOC_FAILED = 111, + NN_SEQ_NO_NOT_MATCHED = 112, + NN_NOT_INITIALIZED = 113, + NN_GET_BUFF_FAILED = 114, + NN_MSG_TIMEOUT = 115, + NN_MSG_CANCELED = 116, + NN_MSG_ERROR = 117, + NN_CONNECT_REFUSED = 118, + NN_CONNECT_PROTOCOL_MISMATCH = 119, + NN_INVALID_LKEY = 120, + NN_EP_BROKEN = 121, + NN_EP_CLOSE = 122, + NN_PARAM_INVALID = 123, + NN_OOB_LISTEN_SOCKET_ERROR = 124, + NN_OOB_CONN_SEND_ERROR = 125, + NN_OOB_CONN_RECEIVE_ERROR = 126, + NN_OOB_CONN_CB_NOT_SET = 127, + NN_OOB_CLIENT_SOCKET_ERROR = 128, + NN_OOB_SSL_INIT_ERROR = 129, + NN_OOB_SSL_WRITE_ERROR = 130, + NN_OOB_SSL_READ_ERROR = 131, + NN_HEARTBEAT_CREATE_EPOLL_FAILED = 132, + NN_HEARTBEAT_SET_SOCKET_OPT_FAILED = 133, + NN_HEARTBEAT_IP_ALREADY_EXISTED = 134, + NN_HEARTBEAT_IP_ADD_FAILED = 135, + NN_HEARTBEAT_IP_ADD_EPOLL_FAILED = 136, + NN_HEARTBEAT_IP_REMOVE_EPOLL_FAILED = 137, + NN_HEARTBEAT_IP_NO_FOUND = 138, + NN_ENCRYPT_FAILED = 139, + NN_DECRYPT_FAILED = 140, + NN_OOB_SEC_PROCESS_ERROR = 141, + NN_EXCHANGE_FD_NOT_SUPPORT = 142, + NN_VALIDATE_HEADER_CRC_INVALID = 143, + NN_UDS_ID_INFO_NOT_SUPPORT = 144, + NN_GET_UDS_ID_INFO_FAILED = 145, + NN_VERSION_CHECK_FAILED = 146, + NN_URMA_ACCESS_ABRT = 147, + NN_URMA_ACK_TIMEOUT = 148, +}; + +enum SerCode { + SER_OK = 0, + SER_ERROR = 500, + SER_INVALID_PARAM = 501, + SER_NEW_OBJECT_FAILED = 502, + SER_CREATE_TIMEOUT_THREAD_FAILED = 503, + SER_NEW_MESSAGE_DATA_FAILED = 504, + SER_NOT_ESTABLISHED = 505, + SER_STORE_SEQ_DUP = 506, + SER_STORE_SEQ_NO_FOUND = 507, + SER_RSP_SIZE_TOO_SMALL = 508, + SER_TIMEOUT = 509, + SER_TIMER_NOT_WORK = 510, + SER_NOT_ENABLE_RNDV = 511, + SER_RNDV_FAILED_BY_PEER = 512, + SER_CHANNEL_ID_DUP = 513, + SER_EP_NOT_BROKEN_ALL = 514, + SER_CHANNEL_NOT_EXIST = 515, + SER_CHANNEL_RECONNECT_OVER_WINDOW = 516, + SER_EP_BROKEN_DURING_CONNECTING = 517, + SER_NOT_SUPPORT_SERVER_RECONNECT = 518, + SER_STOP = 519, + SER_NULL_INSTANCE = 520, + SER_UNSUPPORTED = 521, + SER_INVALID_IP = 522, + SER_MALLOC_FAILED = 523, + SER_SPLIT_INVALID_MSG = 524, +}; +} +} + +#endif // OCK_HCOM_CPP_ERR_H diff --git a/src/hcom_log.cpp b/src/hcom_log.cpp new file mode 100644 index 0000000000000000000000000000000000000000..765a434b8f9e38978ba3b802f0abd158356f71c8 --- /dev/null +++ b/src/hcom_log.cpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "hcom_log.h" + +namespace ock { +namespace hcom { +UBSHcomNetOutLogger *UBSHcomNetOutLogger::gLogger = nullptr; +std::mutex UBSHcomNetOutLogger::gMutex; +int UBSHcomNetOutLogger::logLevel = NN_NO1; +} +} \ No newline at end of file diff --git a/src/hcom_log.h b/src/hcom_log.h new file mode 100644 index 0000000000000000000000000000000000000000..daa0efefa13f9bd7c834f3523ef73bbabe1b4dfb --- /dev/null +++ b/src/hcom_log.h @@ -0,0 +1,191 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_COMM_LOG_12456845341233_H +#define OCK_COMM_LOG_12456845341233_H + +#include +#include +#include +#include +#include +#include +#include "hcom_def.h" + +namespace ock { +namespace hcom { +using ExternalLog = void (*)(int level, const char *msg); + +class UBSHcomNetOutLogger { +public: + static UBSHcomNetOutLogger *Instance() + { + if (NN_UNLIKELY(gLogger == nullptr)) { + std::lock_guard lock(gMutex); + if (gLogger == nullptr) { + gLogger = new (std::nothrow) UBSHcomNetOutLogger(); + if (gLogger == nullptr) { + std::cout << "Failed to new UBSHcomNetOutLogger, probably out of memory" << std::endl; + } + SetLogLevel(); + } + } + + return gLogger; + } + + static void SetLogLevel() + { + /* set one of 0,1,2,3 */ + char *envSize = ::getenv("HCOM_SET_LOG_LEVEL"); + if (envSize != nullptr) { + long value = 0; + if (!SetStrStol(envSize, value)) { + std::cout << "Invalid setting 'HCOM_SET_LOG_LEVEL', should set one of 0,1,2,3 " << std::endl; + return; + } + logLevel = value; + } + } + + inline static void SetLogLevel(int level) + { + if (level >= static_cast(NN_NO0) && level <= static_cast(NN_NO3)) { + logLevel = level; + } + } + + static bool SetStrStol(const std::string &str, long &value) + { + char *remain = nullptr; + errno = 0; + value = std::strtol(str.c_str(), &remain, 10); // 10 is decimal digits + if (remain == nullptr || strlen(remain) > 0 || value < NN_NO0 || value > NN_NO3 || errno == ERANGE) { + return false; + } else if (value == 0 && str != "0") { + return false; + } + + return true; + } + + inline void SetExternalLogFunction(ExternalLog func) + { + mLogFunc = func; + } + + static inline void Print(int level, const char *msg) + { + // See NN_LOG_DEBUG, NN_LOG_INFO, NN_LOG_WARN and NN_LOG_ERROR + const char *levelStr[] = {"DEBUG", "INFO", "WARN", "ERROR"}; + + struct timeval tv{}; + char strTime[24]; + + int ret = gettimeofday(&tv, nullptr); + if (ret != 0) { + std::cout << "Fail to get the current system time, " << ret << ".\n"; + } + time_t timeStamp = tv.tv_sec; + struct tm localTime{}; + struct tm *resultTime = localtime_r(&timeStamp, &localTime); + if ((resultTime != nullptr) && + (strftime(strTime, sizeof strTime, "%Y-%m-%d %H:%M:%S.", resultTime) != NN_NO0)) { + std::cout << strTime << tv.tv_usec << " " << levelStr[level] << " " << msg << '\n'; + } else { + std::cout << "Invalid time trace " << tv.tv_usec << " " << levelStr[level] << " " << msg << '\n'; + } + } + + inline void Log(int level, const std::ostringstream &oss) const + { + if (NN_LIKELY(mLogFunc != nullptr)) { + mLogFunc(level, oss.str().c_str()); + } else { + Print(level, oss.str().c_str()); + } + } + + UBSHcomNetOutLogger(const UBSHcomNetOutLogger &) = delete; + UBSHcomNetOutLogger &operator = (const UBSHcomNetOutLogger &) = delete; + UBSHcomNetOutLogger(UBSHcomNetOutLogger &&) = delete; + UBSHcomNetOutLogger &operator = (UBSHcomNetOutLogger &&) = delete; + + ~UBSHcomNetOutLogger() + { + mLogFunc = nullptr; + } + + inline int GetLogLevel() const + { + return logLevel; + } + +private: + UBSHcomNetOutLogger() = default; + +private: + static UBSHcomNetOutLogger *gLogger; + static std::mutex gMutex; + static int logLevel; + + ExternalLog mLogFunc = nullptr; +}; + +// macro for log +#ifndef NN_LOG_FILENAME +#define NN_LOG_FILENAME (strrchr(__FILE__, '/') ? strrchr(__FILE__, '/') + 1 : __FILE__) +#endif + +#define NN_LOG(level, args) \ + do { \ + if ((level) >= (UBSHcomNetOutLogger::Instance()->GetLogLevel())) { \ + std::ostringstream oss; \ + oss << "[HCOM " << NN_LOG_FILENAME << ":" << __LINE__ << "] " << args; \ + UBSHcomNetOutLogger::Instance()->Log(level, oss); \ + } \ + } while (0) + +#define NN_LOG_PRINT(level, args) \ + do { \ + if ((level) >= (UBSHcomNetOutLogger::Instance()->GetLogLevel())) { \ + std::ostringstream oss; \ + oss << "[HCOM " << NN_LOG_FILENAME << ":" << __LINE__ << "] " << (args); \ + UBSHcomNetOutLogger::Instance()->Print(level, oss.str().c_str()); \ + } \ + } while (0) + +#define NN_LOG_DEBUG(args) NN_LOG(0, args) +#define NN_LOG_INFO(args) NN_LOG(1, args) +#define NN_LOG_WARN(args) NN_LOG(2, args) +#define NN_LOG_ERROR(args) NN_LOG(3, args) + +#define NN_ASSERT_LOG_RETURN(args, RET) \ + if (NN_UNLIKELY(!(args))) { \ + NN_LOG_ERROR("Assert " << #args); \ + return RET; \ + } + +#define NN_ASSERT_LOG_RETURN_VOID(args) \ + if (NN_UNLIKELY(!(args))) { \ + NN_LOG_ERROR("Assert " << #args); \ + return; \ + } + +#ifdef NN_LOG_TRACE_INFO_ENABLED +#define NN_LOG_TRACE_INFO(args) NN_LOG_INFO(args) +#else +#define NN_LOG_TRACE_INFO(x) +#endif +} +} + +#endif // OCK_COMM_LOG_12456845341233_H diff --git a/src/hcom_num_def.h b/src/hcom_num_def.h new file mode 100644 index 0000000000000000000000000000000000000000..acb00c13f02f61c5cc27743dc7a95ce91e8a5d35 --- /dev/null +++ b/src/hcom_num_def.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_CPP_NUM_DEF_H +#define OCK_HCOM_CPP_NUM_DEF_H + +#include +#include + +namespace ock { +namespace hcom { + +constexpr int32_t NN_NOF1 = -1; +constexpr int32_t NN_NOF20 = -20; +constexpr uint32_t NN_NO0 = 0; +constexpr uint32_t NN_NO1 = 1; +constexpr uint32_t NN_NO2 = 2; +constexpr uint32_t NN_NO3 = 3; +constexpr uint32_t NN_NO4 = 4; +constexpr uint32_t NN_NO5 = 5; +constexpr uint32_t NN_NO6 = 6; +constexpr uint32_t NN_NO7 = 7; +constexpr uint32_t NN_NO8 = 8; +constexpr uint32_t NN_NO9 = 9; +constexpr uint32_t NN_NO10 = 10; +constexpr uint32_t NN_NO12 = 12; +constexpr uint32_t NN_NO14 = 14; +constexpr uint32_t NN_NO15 = 15; +constexpr uint32_t NN_NO16 = 16; +constexpr uint32_t NN_NO17 = 17; +constexpr uint8_t NN_NO19 = 19; +constexpr uint32_t NN_NO20 = 20; +constexpr uint32_t NN_NO24 = 24; +constexpr uint32_t NN_NO26 = 26; +constexpr uint32_t NN_NO28 = 28; +constexpr uint32_t NN_NO29 = 29; +constexpr uint32_t NN_NO31 = 31; +constexpr uint32_t NN_NO32 = 32; +constexpr uint32_t NN_NO39 = 39; +constexpr uint32_t NN_NO40 = 40; +constexpr uint32_t NN_NO48 = 48; +constexpr uint32_t NN_NO50 = 50; +constexpr uint32_t NN_NO56 = 56; +constexpr uint32_t NN_NO58 = 58; +constexpr uint32_t NN_NO60 = 60; +constexpr uint32_t NN_NO63 = 63; +constexpr uint32_t NN_NO64 = 64; +constexpr uint32_t NN_NO90 = 90; +constexpr uint32_t NN_NO96 = 96; +constexpr uint32_t NN_NO100 = 100; +constexpr uint32_t NN_NO108 = 108; +constexpr uint32_t NN_NO120 = 120; +constexpr uint32_t NN_NO123 = 123; +constexpr uint32_t NN_NO124 = 124; +constexpr uint32_t NN_NO128 = 128; +constexpr uint32_t NN_NO165 = 165; +constexpr uint32_t NN_NO180 = 180; +constexpr uint32_t NN_NO200 = 200; +constexpr uint32_t NN_NO250 = 250; +constexpr uint32_t NN_NO256 = 256; +constexpr uint32_t NN_NO255 = 255; +constexpr uint32_t NN_NO260 = 260; +constexpr uint32_t NN_NO290 = 290; +constexpr uint32_t NN_NO300 = 300; +constexpr uint32_t NN_NO400 = 400; +constexpr uint32_t NN_NO500 = 500; +constexpr uint32_t NN_NO0600 = 0600; +constexpr uint32_t NN_NO512 = 512; +constexpr uint32_t NN_NO770 = 770; +constexpr uint32_t NN_NO771 = 771; +constexpr uint32_t NN_NO772 = 772; +constexpr uint32_t NN_NO1000 = 1000; +constexpr uint32_t NN_NO1023 = 1023; +constexpr uint32_t NN_NO1024 = 1024; +constexpr uint32_t NN_NO1200 = 1200; +constexpr uint32_t NN_NO1900 = 1900; +constexpr uint32_t NN_NO2048 = 2048; +constexpr uint32_t NN_NO2000 = 2000; +constexpr uint32_t NN_NO4096 = 4096; +constexpr uint16_t NN_NO7200 = 7200; +constexpr uint16_t NN_NO8192 = 8192; +constexpr uint32_t NN_NO10000 = 10000; +constexpr uint32_t NN_NO32768 = 32768; +constexpr uint16_t NN_NO65535 = 65535; +constexpr uint32_t NN_NO65536 = 65536; +constexpr uint32_t NN_NO100000 = 100000; +constexpr uint32_t NN_NO262144 = 262144; +constexpr uint32_t NN_NO500000 = 500000; +constexpr uint32_t NN_NO1000000 = 1000000; +constexpr uint32_t NN_NO1048576 = 1048576; +constexpr uint32_t NN_NO2000000 = 2000000; +constexpr uint32_t NN_NO2097152 = 2097152; +constexpr uint32_t NN_NO8388608 = 8388608; +constexpr uint32_t NN_NO16777216 = 16777216; +constexpr uint32_t NN_NO1073741824 = 1073741824; +constexpr uint32_t NN_NO2147483646 = 2147483646; +constexpr uint32_t NN_NO1000000000 = 1000000000; +constexpr uint64_t NN_NO107374182400 = 107374182400; +constexpr uint64_t NN_NO1099511627776 = 1099511627776; // 1TB + +} +} + +#endif // OCK_HCOM_CPP_NUM_DEF_H diff --git a/src/hcom_obj_statistics.cpp b/src/hcom_obj_statistics.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d7abab26b4ee55f215810324e82e72f6877e00dd --- /dev/null +++ b/src/hcom_obj_statistics.cpp @@ -0,0 +1,156 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "hcom.h" +#include "hcom_obj_statistics.h" + +namespace ock { +namespace hcom { +DEFINE_OBJ_GC(NetService); +DEFINE_OBJ_GC(UBSHcomService); +DEFINE_OBJ_GC(HcomServiceImp); +DEFINE_OBJ_GC(NetServiceDefaultImp); +DEFINE_OBJ_GC(NetChannel); +DEFINE_OBJ_GC(UBSHcomChannel); +DEFINE_OBJ_GC(HcomChannelImp); +DEFINE_OBJ_GC(NetPeriodicManager); +DEFINE_OBJ_GC(HcomPeriodicManager); +DEFINE_OBJ_GC(NetMemPoolFixed); +DEFINE_OBJ_GC(NetServiceCtxStore); +DEFINE_OBJ_GC(HcomServiceCtxStore); +DEFINE_OBJ_GC(NetServiceTimer); +DEFINE_OBJ_GC(HcomServiceTimer); + +DEFINE_OBJ_GC(UBSHcomNetDriver); +DEFINE_OBJ_GC(UBSHcomNetEndpoint); +DEFINE_OBJ_GC(UBSHcomNetMessage); +DEFINE_OBJ_GC(UBSHcomNetRequestContext); +DEFINE_OBJ_GC(UBSHcomNetResponseContext); + +DEFINE_OBJ_GC(RDMAWorker); +DEFINE_OBJ_GC(RDMAEndpoint); +DEFINE_OBJ_GC(RDMACq); +DEFINE_OBJ_GC(RDMAContext); +DEFINE_OBJ_GC(RDMAQp); +DEFINE_OBJ_GC(RDMAAsyncEndPoint); +DEFINE_OBJ_GC(RDMASyncEndpoint); +DEFINE_OBJ_GC(RDMAMemoryRegion); +DEFINE_OBJ_GC(RDMAMemoryRegionFixedBuffer); +DEFINE_OBJ_GC(NetDriverRDMA); +DEFINE_OBJ_GC(NetDriverRDMAWithOob); +DEFINE_OBJ_GC(NetAsyncEndpoint); +DEFINE_OBJ_GC(NetSyncEndpoint); + +#ifdef UB_BUILD_ENABLED +DEFINE_OBJ_GC(UBContext); +DEFINE_OBJ_GC(UBWorker); +DEFINE_OBJ_GC(NetDriverUB); +DEFINE_OBJ_GC(NetDriverUBWithOob); +DEFINE_OBJ_GC(NetUBAsyncEndpoint); +DEFINE_OBJ_GC(NetUBSyncEndpoint); +DEFINE_OBJ_GC(UBJfc); +DEFINE_OBJ_GC(UBJetty); +DEFINE_OBJ_GC(UBPublicJetty); +DEFINE_OBJ_GC(UBMemoryRegion); +DEFINE_OBJ_GC(UBMemoryRegionFixedBuffer); +#endif + +DEFINE_OBJ_GC(NetDriverSockWithOOB); +DEFINE_OBJ_GC(NetAsyncEndpointSock); +DEFINE_OBJ_GC(NetSyncEndpointSock); +DEFINE_OBJ_GC(SockWorker); +DEFINE_OBJ_GC(SockBuff); +DEFINE_OBJ_GC(Sock); + +DEFINE_OBJ_GC(NetDriverShmWithOOB); +DEFINE_OBJ_GC(NetAsyncEndpointShm); +DEFINE_OBJ_GC(NetSyncEndpointShm); +DEFINE_OBJ_GC(ShmChannel); +DEFINE_OBJ_GC(ShmChannelKeeper); +DEFINE_OBJ_GC(ShmDataChannel); +DEFINE_OBJ_GC(ShmHandle); +DEFINE_OBJ_GC(ShmMemoryRegion); +DEFINE_OBJ_GC(ShmQueue); +DEFINE_OBJ_GC(ShmWorker); +DEFINE_OBJ_GC(ShmSyncEndpoint); + +void NetObjStatistic::Dump() +{ + std::ostringstream ossDump; + ossDump << "Object global count:\n"; +#ifdef ENABLE_OBJ_GLOBAL_STATISTICS + OBJ_GC_DUMP(NetService); + OBJ_GC_DUMP(NetServiceDefaultImp); + OBJ_GC_DUMP(NetChannel); + OBJ_GC_DUMP(NetPeriodicManager); + OBJ_GC_DUMP(NetMemPoolFixed); + OBJ_GC_DUMP(NetServiceCtxStore); + OBJ_GC_DUMP(NetServiceTimer); + + OBJ_GC_DUMP(UBSHcomNetDriver); + OBJ_GC_DUMP(UBSHcomNetEndpoint); + OBJ_GC_DUMP(UBSHcomNetMessage); + OBJ_GC_DUMP(UBSHcomNetRequestContext); + OBJ_GC_DUMP(UBSHcomNetResponseContext); + + OBJ_GC_DUMP(RDMAWorker); + OBJ_GC_DUMP(RDMAEndpoint); + OBJ_GC_DUMP(RDMACq); + OBJ_GC_DUMP(RDMAContext); + OBJ_GC_DUMP(RDMAQp); + OBJ_GC_DUMP(RDMAAsyncEndPoint); + OBJ_GC_DUMP(RDMASyncEndpoint); + OBJ_GC_DUMP(RDMAMemoryRegion); + OBJ_GC_DUMP(RDMAMemoryRegionFixedBuffer); + OBJ_GC_DUMP(NetDriverRDMA); + OBJ_GC_DUMP(NetDriverRDMAWithOob); + OBJ_GC_DUMP(NetAsyncEndpoint); + OBJ_GC_DUMP(NetSyncEndpoint); + +#ifdef UB_BUILD_ENABLED + OBJ_GC_DUMP(UBContext); + OBJ_GC_DUMP(UBWorker); + OBJ_GC_DUMP(NetDriverUB); + OBJ_GC_DUMP(NetDriverUBWithOob); + OBJ_GC_DUMP(NetUBAsyncEndpoint); + OBJ_GC_DUMP(NetUBSyncEndpoint); + OBJ_GC_DUMP(UBJfc); + OBJ_GC_DUMP(UBJetty); + OBJ_GC_DUMP(UBPublicJetty); + OBJ_GC_DUMP(UBMemoryRegion); + OBJ_GC_DUMP(UBMemoryRegionFixedBuffer); +#endif + + OBJ_GC_DUMP(NetDriverSockWithOOB); + OBJ_GC_DUMP(NetAsyncEndpointSock); + OBJ_GC_DUMP(NetSyncEndpointSock); + OBJ_GC_DUMP(SockWorker); + OBJ_GC_DUMP(SockBuff); + OBJ_GC_DUMP(Sock); + + OBJ_GC_DUMP(NetDriverShmWithOOB); + OBJ_GC_DUMP(NetAsyncEndpointShm); + OBJ_GC_DUMP(NetSyncEndpointShm); + OBJ_GC_DUMP(ShmChannel); + OBJ_GC_DUMP(ShmChannelKeeper); + OBJ_GC_DUMP(ShmDataChannel); + OBJ_GC_DUMP(ShmHandle); + OBJ_GC_DUMP(ShmMemoryRegion); + OBJ_GC_DUMP(ShmQueue); + OBJ_GC_DUMP(ShmWorker); + OBJ_GC_DUMP(ShmSyncEndpoint); +#else + ossDump << "\tDisabled"; +#endif + NN_LOG_INFO(ossDump.str()); +} +} +} diff --git a/src/hcom_obj_statistics.h b/src/hcom_obj_statistics.h new file mode 100644 index 0000000000000000000000000000000000000000..4d0243d615c00f1e9449e511ce1ddc5c177f8d44 --- /dev/null +++ b/src/hcom_obj_statistics.h @@ -0,0 +1,116 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_OBJ_STATISTICS_H +#define OCK_HCOM_NET_OBJ_STATISTICS_H + +// macro object statistic +#ifdef ENABLE_OBJ_GLOBAL_STATISTICS +/* declare variable in NetObjStatistic */ +#define DECLARE_OBJ_GC(CLASSNAME) static int32_t GC##CLASSNAME +/* initialize the variable define in NetObjStatistic */ +#define DEFINE_OBJ_GC(CLASSNAME) int32_t NetObjStatistic::GC##CLASSNAME = 0 +/* increase the object count, which should be added into constructor */ +#define OBJ_GC_INCREASE(CLASSNAME) __sync_fetch_and_add(&NetObjStatistic::GC##CLASSNAME, 1) +/* decrease the object count, which should be added into destructor */ +#define OBJ_GC_DECREASE(CLASSNAME) __sync_sub_and_fetch(&NetObjStatistic::GC##CLASSNAME, 1) +/* dump object count, which should be in dump function */ +#define OBJ_GC_DUMP(CLASSNAME) ossDump << "\t" << #CLASSNAME << ": " << GC##CLASSNAME << "\n" +#else +#define DECLARE_OBJ_GC(CLASSNAME) +#define DEFINE_OBJ_GC(CLASSNAME) +#define OBJ_GC_INCREASE(CLASSNAME) +#define OBJ_GC_DECREASE(CLASSNAME) +#define OBJ_GC_DUMP(CLASSNAME) +#endif + +namespace ock { +namespace hcom { +class NetObjStatistic { +public: + DECLARE_OBJ_GC(NetService); + DECLARE_OBJ_GC(UBSHcomService); + DECLARE_OBJ_GC(HcomServiceImp); + DECLARE_OBJ_GC(NetServiceDefaultImp); + DECLARE_OBJ_GC(NetServiceMultiRailImp); + DECLARE_OBJ_GC(ServiceNetDriverManager); + DECLARE_OBJ_GC(ServiceDriverManagerOob); + DECLARE_OBJ_GC(NetChannel); + DECLARE_OBJ_GC(UBSHcomChannel); + DECLARE_OBJ_GC(HcomChannelImp); + DECLARE_OBJ_GC(MultiRailNetChannel); + DECLARE_OBJ_GC(NetPeriodicManager); + DECLARE_OBJ_GC(HcomPeriodicManager); + DECLARE_OBJ_GC(NetMemPoolFixed); + DECLARE_OBJ_GC(NetServiceCtxStore); + DECLARE_OBJ_GC(HcomServiceCtxStore); + DECLARE_OBJ_GC(NetServiceTimer); + DECLARE_OBJ_GC(HcomServiceTimer); + + DECLARE_OBJ_GC(UBSHcomNetDriver); + DECLARE_OBJ_GC(UBSHcomNetEndpoint); + DECLARE_OBJ_GC(UBSHcomNetMessage); + DECLARE_OBJ_GC(UBSHcomNetRequestContext); + DECLARE_OBJ_GC(UBSHcomNetResponseContext); + + DECLARE_OBJ_GC(RDMAWorker); + DECLARE_OBJ_GC(RDMAEndpoint); + DECLARE_OBJ_GC(RDMACq); + DECLARE_OBJ_GC(RDMAContext); + DECLARE_OBJ_GC(RDMAQp); + DECLARE_OBJ_GC(RDMAAsyncEndPoint); + DECLARE_OBJ_GC(RDMASyncEndpoint); + DECLARE_OBJ_GC(RDMAMemoryRegion); + DECLARE_OBJ_GC(RDMAMemoryRegionFixedBuffer); + DECLARE_OBJ_GC(NetDriverRDMA); + DECLARE_OBJ_GC(NetDriverRDMAWithOob); + DECLARE_OBJ_GC(NetAsyncEndpoint); + DECLARE_OBJ_GC(NetSyncEndpoint); + +#ifdef UB_BUILD_ENABLED + DECLARE_OBJ_GC(UBContext); + DECLARE_OBJ_GC(UBWorker); + DECLARE_OBJ_GC(NetDriverUB); + DECLARE_OBJ_GC(NetDriverUBWithOob); + DECLARE_OBJ_GC(NetUBAsyncEndpoint); + DECLARE_OBJ_GC(NetUBSyncEndpoint); + DECLARE_OBJ_GC(UBJfc); + DECLARE_OBJ_GC(UBJetty); + DECLARE_OBJ_GC(UBPublicJetty); + DECLARE_OBJ_GC(UBMemoryRegion); + DECLARE_OBJ_GC(UBMemoryRegionFixedBuffer); +#endif + + DECLARE_OBJ_GC(NetDriverSockWithOOB); + DECLARE_OBJ_GC(NetAsyncEndpointSock); + DECLARE_OBJ_GC(NetSyncEndpointSock); + DECLARE_OBJ_GC(SockWorker); + DECLARE_OBJ_GC(SockBuff); + DECLARE_OBJ_GC(Sock); + + DECLARE_OBJ_GC(NetDriverShmWithOOB); + DECLARE_OBJ_GC(NetAsyncEndpointShm); + DECLARE_OBJ_GC(NetSyncEndpointShm); + DECLARE_OBJ_GC(ShmChannel); + DECLARE_OBJ_GC(ShmChannelKeeper); + DECLARE_OBJ_GC(ShmDataChannel); + DECLARE_OBJ_GC(ShmHandle); + DECLARE_OBJ_GC(ShmMemoryRegion); + DECLARE_OBJ_GC(ShmQueue); + DECLARE_OBJ_GC(ShmWorker); + DECLARE_OBJ_GC(ShmSyncEndpoint); + + static void Dump(); +}; +} +} + +#endif // OCK_HCOM_NET_OBJ_STATISTICS_H diff --git a/src/hcom_ref.h b/src/hcom_ref.h new file mode 100644 index 0000000000000000000000000000000000000000..05dd83b5bb0d72c2b00da2f724ab5441b91b86e6 --- /dev/null +++ b/src/hcom_ref.h @@ -0,0 +1,162 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_CPP_REF_H +#define OCK_HCOM_CPP_REF_H + +#include + +namespace ock { +namespace hcom { + +/** + * @brief Smart pointer object + */ +template class NetRef { +public: + // constructor + NetRef() noexcept = default; + + // fix: can't be explicit + NetRef(T *newObj) noexcept + { + // if new obj is not null, increase reference count and assign to mObj + // else nothing need to do as mObj is nullptr by default + if (newObj != nullptr) { + newObj->IncreaseRef(); + mObj = newObj; + } + } + + NetRef(const NetRef &other) noexcept + { + // if other's obj is not null, increase reference count and assign to mObj + // else nothing need to do as mObj is nullptr by default + if (other.mObj != nullptr) { + other.mObj->IncreaseRef(); + mObj = other.mObj; + } + } + +#if __GNUC__ == 4 && __GNUC_MINOR__ == 8 && __GNUC_PATCHLEVEL__ == 5 + NetRef(NetRef &&other) noexcept : mObj(exchangeHcom(other.mObj, nullptr)) +#else + NetRef(NetRef &&other) noexcept : mObj(std::__exchange(other.mObj, nullptr)) +#endif + { + // move constructor + // since this mObj is null, just exchange + } + + // de-constructor + ~NetRef() + { + if (mObj != nullptr) { + mObj->DecreaseRef(); + } + } + + // operator = + inline NetRef &operator = (T *newObj) + { + this->Set(newObj); + return *this; + } + + inline NetRef &operator = (const NetRef &other) + { + if (this != &other) { + this->Set(other.mObj); + } + return *this; + } + + NetRef &operator = (NetRef &&other) noexcept + { + if (this != &other) { + auto tmp = mObj; +#if __GNUC__ == 4 && __GNUC_MINOR__ == 8 && __GNUC_PATCHLEVEL__ == 5 + mObj = exchangeHcom(other.mObj, nullptr); +#else + mObj = std::__exchange(other.mObj, nullptr); +#endif + if (tmp != nullptr) { + tmp->DecreaseRef(); + } + } + return *this; + } + + // equal operator + inline bool operator == (const NetRef &other) const + { + return mObj == other.mObj; + } + + inline bool operator == (T *other) const + { + return mObj == other; + } + + inline bool operator != (const NetRef &other) const + { + return mObj != other.mObj; + } + + inline bool operator != (T *other) const + { + return mObj != other; + } + + // get operator and set + inline T *operator->() const + { + return mObj; + } + + inline T *Get() const + { + return mObj; + } + + inline void Set(T *newObj) + { + if (newObj == mObj) { + return; + } + + if (newObj != nullptr) { + newObj->IncreaseRef(); + } + + if (mObj != nullptr) { + mObj->DecreaseRef(); + } + + mObj = newObj; + } + + template C *ToChild() + { + if (mObj != nullptr) { + return dynamic_cast(mObj); + } + return nullptr; + } + +private: + T *mObj = nullptr; +}; + +} +} + +#endif // OCK_HCOM_CPP_REF_H diff --git a/src/hcom_split.cpp b/src/hcom_split.cpp new file mode 100644 index 0000000000000000000000000000000000000000..570e6a7114814667499b4dabf36785e47e412f32 --- /dev/null +++ b/src/hcom_split.cpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "hcom_split.h" +#include "securec.h" + +namespace ock { +namespace hcom { + +/// SyncCallWithSelfPoll 拼包专用。 +/// - 在拼包时会尝试接收一次,如果发现首包为 RAW 包则直接返回并设置出参 data、 +/// dataLen,此时出参 acc 无实际作用。 +/// - 如果发现需要拼包,则会在出参 acc 中分配内存,当拼包完成时设置出参 data、 +/// dataLen 表明实际收到的数据、数据长度。在拼包阶段发生任意错误都会提前返回。 +/// \seealso SpliceMessage +SerResult SyncSpliceMessage(UBSHcomNetResponseContext &ctx, UBSHcomNetEndpoint *ep, int32_t timeout, + std::string &acc, void *&data, uint32_t &dataLen) +{ + SerResult result = ep->Receive(timeout, ctx); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync call receive failed " << result << " ep id " << ep->Id()); + return result; + } + + // 是否为拆包的一部分 + switch (ctx.Header().extHeaderType) { + case UBSHcomExtHeaderType::RAW: + data = ctx.Message()->Data(); + dataLen = ctx.Message()->DataLen(); + return result; + + case UBSHcomExtHeaderType::FRAGMENT: + break; + } + + while (true) { + const uintptr_t msgAddr = reinterpret_cast(ctx.Message()->Data()); + const uint32_t msgSize = ctx.Message()->DataLen(); + if (msgSize < sizeof(UBSHcomFragmentHeader)) { + NN_LOG_ERROR("SyncSpliceMessage: message size is invalid!"); + return SER_ERROR; + } + + const UBSHcomFragmentHeader *serviceHeader = reinterpret_cast(msgAddr); + const void *payload = reinterpret_cast(msgAddr + sizeof(UBSHcomFragmentHeader)); + const uint64_t payloadLen = msgSize - sizeof(UBSHcomFragmentHeader); + const auto msgId = serviceHeader->msgId; + const uint32_t totalLength = serviceHeader->totalLength; + const uint32_t offset = serviceHeader->offset; + + NN_LOG_DEBUG("SyncSpliceMessage: id " << msgId << ", totalLength " << totalLength << ", offset " << offset + << ", size " << payloadLen); + + // 避免因数据在网络中被篡改而造成高内存占用 + if (totalLength >= SERVICE_MAX_TOTAL_LENGTH) { + NN_LOG_ERROR("SyncSpliceMessage: totalLength (" << totalLength << ") is larger than the maximum (" + << SERVICE_MAX_TOTAL_LENGTH << ")"); + return SER_SPLIT_INVALID_MSG; + } + + // 首包分配足够大的内存 + if (offset == 0) { + acc.resize(totalLength); + } + + // | msg1 | ... | last | + // | msg2 | ... | ... | last | + // + // 可能 msg1 的尾部消息丢失同时 msg2 的消息头部也丢失;或者是 msg1 的消 + // 息头部丢失导致 acc 未分配足够空间。 + if (NN_UNLIKELY(offset > acc.size())) { + NN_LOG_ERROR("SyncSpliceMessage: the fragment is from another msg, or the first fragment is lost. offset = " + << offset << ", totalLength = " << acc.size()); + return SER_SPLIT_INVALID_MSG; + } + + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(reinterpret_cast(acc.data()) + offset), + acc.size() - offset, payload, payloadLen) != EOK)) { + NN_LOG_ERROR("SyncSpliceMessage: the payload is too large."); + return SER_SPLIT_INVALID_MSG; + } + + // 拼包完成,如果数据在网络层面发现异常, SyncEp 可能会一直阻塞. + if (offset + payloadLen == totalLength) { + NN_LOG_DEBUG("SyncSpliceMessage: complete! id " << msgId); + + data = const_cast(acc.data()); + dataLen = acc.size(); + break; + } + + result = ep->Receive(timeout, ctx); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync call receive failed " << result << " ep id " << ep->Id()); + return result; + } + + // 在大包还未拼成、接收小包过程中出现 RAW 包,可能是尾部小包丢失。只能将 + // 此 RAW 包丢弃,用户状态机可能会发生错误。 + switch (ctx.Header().extHeaderType) { + case UBSHcomExtHeaderType::RAW: + NN_LOG_ERROR( + "SyncSpliceMessage: a RAW type msg is received during SpliceMessage, it will be discarded."); + return SER_ERROR; + + case UBSHcomExtHeaderType::FRAGMENT: + break; + } + } + + return SER_OK; +} + +} // namespace hcom +} // namespace ock diff --git a/src/hcom_split.h b/src/hcom_split.h new file mode 100644 index 0000000000000000000000000000000000000000..1900b92f10cc5a12bc595c62d314c0fa8c31b730 --- /dev/null +++ b/src/hcom_split.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_CPP_HCOM_SPLIT_H +#define OCK_HCOM_CPP_HCOM_SPLIT_H + +#include "hcom.h" + +namespace ock { +namespace hcom { +using SerResult = int; + +/// SplitSend 专用: 2G 为最大可发送包大小 +const uint32_t SERVICE_MAX_TOTAL_LENGTH = 2U * 1024 * 1024 * 1024; + +SerResult SyncSpliceMessage(UBSHcomNetResponseContext &ctx, UBSHcomNetEndpoint *ep, int32_t timeout, + std::string &acc, void *&data, uint32_t &dataLen); + +enum class SpliceMessageResultType { + OK, + ERROR, + INDETERMINATE, +}; + +} // namespace hcom +} // namespace ock + +#endif // OCK_HCOM_CPP_HCOM_SPLIT_H diff --git a/src/hcom_utils.cpp b/src/hcom_utils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5e509cb668005a4bcc25fc38c00b27005e9ffb0a --- /dev/null +++ b/src/hcom_utils.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include +#include + +#include "hcom_utils.h" + +namespace ock { +namespace hcom { + +NetSpinLock NetUuid::gLock; +uint32_t NetUuid::gSeqNo = 0; + +uint64_t NetUuid::GenerateUuid(const std::string& ip) +{ + struct Uuid { + union { + struct { + uint64_t ip : 8; + uint64_t pid : 20; + uint64_t tid : 20; + uint64_t seqNo : 16; + }; + uint64_t value; + }; + }; + + if (ip.empty()) { + return GenerateUuid(); + } + + // 仅使用最低位ip地址,如"xx1.xx2.xx3.xx4"中的"xx4" + uint32_t ipNum = inet_addr(ip.c_str()); + ipNum >>= 0x18; + + int32_t pid = getpid(); + auto tid = pthread_self(); + + struct Uuid res; + res.ip = ipNum; + res.tid = tid; + res.pid = static_cast(pid); + + gLock.Lock(); + gSeqNo += 1; + res.seqNo = gSeqNo; + gLock.Unlock(); + + return res.value; +} + +} +} diff --git a/src/hcom_utils.h b/src/hcom_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..b3c43ab7a6df2c15ae34589409698b3d03393911 --- /dev/null +++ b/src/hcom_utils.h @@ -0,0 +1,549 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_UTIL_H_54434 +#define OCK_HCOM_NET_UTIL_H_54434 + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hcom_err.h" +#include "hcom_def.h" + +namespace ock { +namespace hcom { +inline timespec MONOTONIC_TIME() +{ + struct timespec now { + 0, 0 + }; + clock_gettime(CLOCK_MONOTONIC, &now); + return now; +} + +inline uint64_t MONOTONIC_TIME_INTERVAL_NS(const timespec &start, const timespec &end) +{ + return (end.tv_sec - start.tv_sec) * NN_NO1000000000 + + (end.tv_nsec - start.tv_nsec); +} + +inline uint64_t MONOTONIC_TIME_INTERVAL_US(const timespec &start, const timespec &end) +{ + return (end.tv_sec - start.tv_sec) * NN_NO1000000 + + (end.tv_nsec - start.tv_nsec) / NN_NO1000; +} + +inline uint64_t MONOTONIC_TIME_INTERVAL_SEC(const timespec &start, const timespec &end) +{ + return (end.tv_sec - start.tv_sec) + + (end.tv_nsec - start.tv_nsec) / NN_NO1000000000; +} + +inline uint64_t MONOTONIC_TIME_NS() +{ + struct timespec now { + 0, 0 + }; + clock_gettime(CLOCK_MONOTONIC, &now); + return now.tv_nsec + now.tv_sec * NN_NO1000000000; +} + +inline uint64_t MONOTONIC_TIME_SECOND() +{ + return MONOTONIC_TIME_NS() / NN_NO1000000000; +} + +inline int32_t TimeSecToMs(const int32_t &timeInSec) +{ + if (NN_UNLIKELY(timeInSec < 0)) { + return -1; + } + if (NN_UNLIKELY(timeInSec > static_cast(NN_NO2000))) { + return NN_NO2000 * NN_NO1000; + } + return timeInSec * static_cast(NN_NO1000); +} + +/** + * @brief Check whether the path is canonical, and canonical it. + */ +inline bool CanonicalPath(std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return false; + } + + /* It will allocate memory to store path */ + char *realPath = realpath(path.c_str(), nullptr); + if (realPath == nullptr) { + return false; + } + + path = realPath; + free(realPath); + realPath = nullptr; + return true; +} +/* ****************************************************************************************** */ +class NetReadWriteLock { +public: + NetReadWriteLock() + { + pthread_rwlock_init(&mLock, nullptr); + } + ~NetReadWriteLock() + { + pthread_rwlock_destroy(&mLock); + } + + NetReadWriteLock(const NetReadWriteLock &) = delete; + NetReadWriteLock &operator=(const NetReadWriteLock &) = delete; + NetReadWriteLock(NetReadWriteLock &&) = delete; + NetReadWriteLock &operator=(NetReadWriteLock &&) = delete; + + inline void LockRead() + { + pthread_rwlock_rdlock(&mLock); + } + + inline void LockWrite() + { + pthread_rwlock_wrlock(&mLock); + } + + inline void UnLock() + { + pthread_rwlock_unlock(&mLock); + } + +private: + pthread_rwlock_t mLock{}; +}; + +/* ****************************************************************************************** */ +class NetSpinLock { +public: + NetSpinLock() = default; + ~NetSpinLock() = default; + + NetSpinLock(const NetSpinLock &) = delete; + NetSpinLock &operator=(const NetSpinLock &) = delete; + NetSpinLock(NetSpinLock &&) = delete; + NetSpinLock &operator=(NetSpinLock &&) = delete; + + inline bool TryLock() + { + return mFlag.test_and_set(std::memory_order_acquire); + } + + inline void Lock() + { + while (mFlag.test_and_set(std::memory_order_acquire)) { + } + } + + inline void Unlock() + { + mFlag.clear(std::memory_order_release); + } + +private: + std::atomic_flag mFlag = ATOMIC_FLAG_INIT; +}; + +/* ****************************************************************************************** */ + +template class NetRingBuffer { +public: + explicit NetRingBuffer(uint32_t capacity) : mCapacity(capacity) + { + } + + ~NetRingBuffer() + { + UnInitialize(); + } + + inline uint32_t Capacity() const + { + return mCapacity; + } + + NResult Initialize() + { + if (mCapacity == 0) { + return NN_INVALID_PARAM; + } + + if (mRingBuf != nullptr) { + return NN_OK; + } + + mRingBuf = new (std::nothrow) T[mCapacity]; + if (NN_UNLIKELY(mRingBuf == nullptr)) { + return NN_NEW_OBJECT_FAILED; + } + mCount = 0; + mHead = 0; + mTail = 0; + + return NN_OK; + } + + inline void UnInitialize() + { + if (mRingBuf == nullptr) { + return; + } + + delete[] mRingBuf; + mRingBuf = nullptr; + } + + inline bool PushBack(const T &item) + { + mLock.Lock(); + if (mCapacity <= mCount) { + mLock.Unlock(); + return false; + } + + // mRinBuf will not be null after init, this func is performance-sensitive, there is no need to check null + mRingBuf[mTail] = item; + if (mTail != mCapacity - 1) { + ++mTail; + } else { + mTail = 0; + } + ++mCount; + mLock.Unlock(); + return true; + } + + inline bool InterruptablePushBack(const T &item, bool &isInterrupted) + { + mLock.Lock(); + if (mCapacity <= mCount) { + mLock.Unlock(); + return false; + } + + if (mInterrupt) { + isInterrupted = true; + mLock.Unlock(); + return false; + } + + // mRinBuf will not be null after init, this func is performance-sensitive, there is no need to check null + mRingBuf[mTail] = item; + if (mTail != mCapacity - 1) { + ++mTail; + } else { + mTail = 0; + } + ++mCount; + mLock.Unlock(); + return true; + } + + inline bool PushFront(const T &item) + { + mLock.Lock(); + if (mCapacity <= mCount) { + mLock.Unlock(); + return false; + } + + // move to tail + if (mHead == 0) { + mHead = mCapacity - 1; + } else { + mHead--; + } + + // mRinBuf will not be null after init, this func is performance-sensitive, there is no need to check null + mRingBuf[mHead] = item; + ++mCount; + + mLock.Unlock(); + return true; + } + + inline bool PopFront(T &item) + { + mLock.Lock(); + if (mCount == 0) { + mLock.Unlock(); + return false; + } + + // mRinBuf will not be null after init, this func is performance-sensitive, there is no need to check null + item = mRingBuf[mHead]; + if (mHead != mCapacity - 1) { + ++mHead; + } else { + mHead = 0; + } + --mCount; + mLock.Unlock(); + return true; + } + + inline bool GetFront(T &item) + { + mLock.Lock(); + if (mCount == 0) { + mLock.Unlock(); + return false; + } + item = mRingBuf[mHead]; + mLock.Unlock(); + return true; + } + + inline bool PopFrontN(T *items, uint32_t n) + { + mLock.Lock(); + if (mCount < n) { + mLock.Unlock(); + return false; + } + + // mRinBuf will not be null after init, this func is performance-sensitive, there is no need to check null + for (uint32_t i = 0; i < n; ++i) { + items[i] = mRingBuf[mHead]; + if (mHead != mCapacity - 1) { + ++mHead; + } else { + mHead = 0; + } + } + + mCount -= n; + + mLock.Unlock(); + return true; + } + + inline bool IsFull() + { + mLock.Lock(); + auto full = mCount >= mCapacity; + mLock.Unlock(); + return full; + } + + inline uint32_t Size() + { + mLock.Lock(); + auto temp = mCount; + mLock.Unlock(); + return temp; + } + + inline void Interrupt() + { + mLock.Lock(); + mInterrupt = true; + mLock.Unlock(); + } + + NetRingBuffer(const NetRingBuffer &) = delete; + NetRingBuffer(NetRingBuffer &&) = delete; + NetRingBuffer &operator=(const NetRingBuffer &) = delete; + NetRingBuffer &operator=(NetRingBuffer &&) = delete; + +private: + T *mRingBuf = nullptr; + NetSpinLock mLock; + uint32_t mCapacity = 0; + uint32_t mCount = 0; + uint32_t mHead = 0; + uint32_t mTail = 0; + bool mInterrupt = false; +}; + +template class NetBlockingQueue { +public: + explicit NetBlockingQueue(uint32_t capacity) : mRingBuffer(capacity) + { + } + ~NetBlockingQueue() + { + UnInitialize(); + } + + inline NResult Initialize() + { + if (sem_init(&mSem, 0, 0) != 0) { + return NN_BLOCK_QUEUE_SEM_INIT_FAILED; + } + + return mRingBuffer.Initialize(); + } + + inline void UnInitialize() + { + mRingBuffer.UnInitialize(); + sem_destroy(&mSem); + } + + inline bool Enqueue(T &item) + { + auto result = mRingBuffer.PushBack(item); + if (result) { + sem_post(&mSem); + } + return result; + } + + inline bool EnqueueFirst(T &item) + { + auto result = mRingBuffer.PushFront(item); + if (result) { + sem_post(&mSem); + } + return result; + } + + /* tip Dequeue and Interrupt cannot be used at the same time */ + inline bool Dequeue(T &item) + { + while (true) { + auto result = mRingBuffer.PopFront(item); + if (!result) { + sem_wait(&mSem); + } else { + // result always true + return result; + } + } + } + + inline bool InterruptableEnqueue(const T &item, bool &isInterrupted) + { + auto result = mRingBuffer.InterruptablePushBack(item, isInterrupted); + if (result) { + sem_post(&mSem); + } + return result; + } + + /* tip Dequeue and InterruptableDequeue cannot be + * used at the same time */ + inline bool InterruptableDequeue(T &item, bool &isInterrupt) + { + isInterrupt = false; + while (true) { + auto result = mRingBuffer.PopFront(item); + if (!result) { + sem_wait(&mSem); + if (NN_UNLIKELY(mInterrupt)) { + isInterrupt = true; + mInterrupt.store(false); + return false; + } + } else { + // result always true + return result; + } + } + } + + inline uint32_t Size() + { + return mRingBuffer.Size(); + } + + /* tip Interrupt only be used for InterruptableDequeue */ + inline void Interrupt() + { + mInterrupt.store(true); + sem_post(&mSem); + mRingBuffer.Interrupt(); + } + +private: + NetRingBuffer mRingBuffer; + sem_t mSem{}; + std::atomic mInterrupt{false}; +}; + +/* ****************************************************************************************** */ +class NetUuid { +public: + static inline uint64_t GenerateUuid() + { + // 高32位:时间戳(ns级) + uint64_t timestamp = static_cast(std::chrono::system_clock::now().time_since_epoch().count()); + + gLock.Lock(); + uint32_t seqNo = gSeqNo++; + gLock.Unlock(); + + return (timestamp << NN_NO32) | seqNo; + } + + static uint64_t GenerateUuid(const std::string& ip); +private: + static uint32_t gSeqNo; + static NetSpinLock gLock; +}; + +/* ****************************************************************************************** */ +// const variables +constexpr uint32_t PAGE_ALIGN_H = NN_NO4096; + +// defines +#define POWER_OF_2(x) ((((x) - 1) & (x)) == 0) + +#define H_LIKELY(e) (__builtin_expect(!!(e), 1) != 0) +#define H_UNLIKELY(e) (__builtin_expect(!!(e), 0) != 0) + +#define H_CAS(ptr, o, n) \ + __atomic_compare_exchange_n(ptr, &(o), n, 0, __ATOMIC_RELEASE, \ + __ATOMIC_RELAXED) +#define H_WMB() __atomic_thread_fence(__ATOMIC_RELEASE) +#define H_RMB() __atomic_thread_fence(__ATOMIC_ACQUIRE) +#define H_MB() __atomic_thread_fence(__ATOMIC_SEQ_CST) + +#define H_ATOMIC_LOAD(n) __atomic_load_n(&(n), __ATOMIC_RELAXED) +#define H_ATOMIC_FAA(n, num) __atomic_fetch_add(&(n), (num), __ATOMIC_RELAXED) +#define H_ATOMIC_STORE(n, num) __atomic_store_n(&(n), (num), __ATOMIC_RELAXED) + +inline void H_Pause() +{ +#ifdef __x86_64__ + asm volatile("pause" ::: "memory"); +#elif defined(__aarch64__) + asm volatile("yield" ::: "memory"); +#endif +} + +inline uint32_t NN_NextPower2(uint32_t value) +{ + if (value < NN_NO2) { + return NN_NO2; + } + return 1UL << (NN_NO32 - __builtin_clz(value - 1)); +} + +} // namespace hcom +} // namespace ock + +#endif // OCK_HCOM_NET_UTIL_H_54434 diff --git a/src/service_v2/api/hcom_service.h b/src/service_v2/api/hcom_service.h new file mode 100644 index 0000000000000000000000000000000000000000..41b5e9ea0a8cd4f4612475f3322dfd37c3030024 --- /dev/null +++ b/src/service_v2/api/hcom_service.h @@ -0,0 +1,322 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_API_HCOM_SERVICE_H_ +#define HCOM_API_HCOM_SERVICE_H_ + +#include +#include +#include +#include "hcom.h" +#include "hcom_service_channel.h" +#include "hcom_obj_statistics.h" +#include "hcom_def.h" +#include "hcom_ref.h" + +namespace ock { +namespace hcom { + +using UBSHcomChannelPtr = NetRef; + +using UBSHcomServiceNewChannelHandler = + std::function; +using UBSHcomServiceChannelBrokenHandler = std::function; +using UBSHcomServiceRecvHandler = std::function; +using UBSHcomServiceSendHandler = std::function; +using UBSHcomServiceOneSideDoneHandler = std::function; +using UBSHcomServiceIdleHandler = UBSHcomNetDriverIdleHandler; +using UBSHcomServiceProtocol = UBSHcomNetDriverProtocol; +using UBSHcomServiceLBPolicy = UBSHcomNetDriverLBPolicy; + +class UBSHcomService { +public: + /** + * @brief service创建 + * + * @param t service对应底层通信协议 + * @param name service名称 + * @param opt service创建需要的配置项 + * @return UBSHcomService* 返回创建好的service + */ + static UBSHcomService* Create(UBSHcomServiceProtocol t, const std::string &name, const UBSHcomServiceOptions &opt = {}); + + /** + * @brief 销毁service + * + * @param name 要销毁的实例名称 + * @return int32_t 成功:0;失败:错误码 + */ + static int32_t Destroy(const std::string &name); + + /** + * @brief 绑定监听url,指定监听的类型及url,客户端可以不调用Bind。 + * + * @param listenerUrl 监听url,对于tcp来说:tcp://127.0.0.1:9981 + * 对于uds来说:uds://file:perm(如果有:perm则使用真实文件,perm格式如:0600,没有则使用抽象文件) + * 对于ubc来说:ubc://eid:jettyId + * @param handler 收到建链请求后的回调函数 + * @return int32_t 成功:0;失败:错误码 + */ + virtual int32_t Bind(const std::string &listenerUrl, const UBSHcomServiceNewChannelHandler &handler) = 0; + + /** + * @brief 开启服务,如果调用过Bind,则同时开启监听,否则不进行监听 + * + * @return int32_t 成功:0;失败:错误码 + */ + virtual int32_t Start() = 0; + + /** + * @brief 建立链接 + * + * @param serverUrl 建连服务端url,对于tcp来说:tcp://127.0.0.1:9981 + * 对于uds来说:uds://file(文件名/抽象命名空间) + * 对于ubc来说:ubc://eid:jettyId + * @param ch 出参,建链成功返回的channel + * @param opt 建链配置项 + * @return int32_t 成功:0;失败:错误码 + */ + virtual int32_t Connect(const std::string &serverUrl, UBSHcomChannelPtr &ch, const UBSHcomConnectOptions &opt = {}) = 0; + + /** + * @brief 断开链接 + * + * @param ch 要断开的channel + */ + virtual void Disconnect(const UBSHcomChannelPtr &ch) = 0; + + /** + * @brief 注册memory region,内存会在内部进行分配 + * + * @param size memory region的大小 + * @param mr 注册好的memoryRegion + * @return int32_t 成功:0;失败:错误码 + */ + virtual int32_t RegisterMemoryRegion(uint64_t size, UBSHcomRegMemoryRegion &mr) = 0; + + /** + * @brief 注册memory region,分配的内存需要传入进来 + * + * @param address 需要被注册为MR的内存起始地址 + * @param size memory region的大小 + * @param mr 注册好的memoryRegion + * @return int32_t 成功:0;失败:错误码 + */ + virtual int32_t RegisterMemoryRegion(uintptr_t address, uint64_t size, UBSHcomRegMemoryRegion &mr) = 0; + + /** + * @brief memory region取消注册 + * + * @param mr 取消的mr + */ + virtual void DestroyMemoryRegion(UBSHcomRegMemoryRegion &mr) = 0; + + /** + * @brief 设置RegisterMemoryRegion是否将mr信息放入pgTable管理 + * 若用户需要使用RNDV,则需要设置为true + * + * @param enableMrCache true表示放入pgTable,false表示不放入;默认是false。 + */ + virtual void SetEnableMrCache(bool enableMrCache) = 0; + + /** + * @brief 注册断链回调 + * + * @param handler 断链回调函数 + * @param policy 断链回调策略 + */ + virtual void RegisterChannelBrokenHandler(const UBSHcomServiceChannelBrokenHandler &handler, + const UBSHcomChannelBrokenPolicy policy) = 0; + + /** + * @brief 注册pollCq、epoll_wait超时等回调 + * + * @param handler 回调函数 + */ + virtual void RegisterIdleHandler(const UBSHcomServiceIdleHandler &handler) = 0; + + /** + * @brief 注册接收receive操作回调 + * + * @param rcvHandler 回调函数 + */ + virtual void RegisterRecvHandler(const UBSHcomServiceRecvHandler &recvHandler) = 0; + + /** + * @brief 注册发送send操作回调 + * + * @param sentHandler 回调函数 + */ + virtual void RegisterSendHandler(const UBSHcomServiceSendHandler &sendHandler) = 0; + + /** + * @brief 注册单边操作回调 + * + * @param oneSideDoneHandler 回调函数 + */ + virtual void RegisterOneSideHandler(const UBSHcomServiceOneSideDoneHandler &oneSideDoneHandler) = 0; + + // 高级配置选项及特性配置选项 + + /** + * @brief 增加workerGroup + * + * @param workerGroupId workerGroup的id + * @param threadCount 该workerGroup的线程数 + * @param cpuIdsRange 该workerGroup绑定的cpuId范围 + * @param priority 同线程nice值,范围[-20,19],-20优先级最高,19优先级最低 + * @param multirailIdx 该workerGroup绑定的rail + */ + virtual void AddWorkerGroup(uint16_t workerGroupId, uint32_t threadCount, + const std::pair &cpuIdsRange, int8_t priority = 0, uint16_t multirailIdx = 0) = 0; + + /** + * @brief 增加监听器,支持监听多个url + * + * @param url 监听url,tcp协议:tcp://127.0.0.1:9981;uds协议:uds://file(文件名/抽象命名空间) + * @param workerCount 该listener监听到链接请求后,会从对应的workerGroup中选择workerCount个线程按照lbPolicy的策略去选择线程绑定到asyncEp上 + */ + virtual void AddListener(const std::string &url, uint16_t workerCount = UINT16_MAX) = 0; + + /** + * @brief 设置建链负载均衡策略,主动/被动建链时需要选择一个worker线程去完成,lbPolicy则代表选择worker线程的策略 + * + * @param lbPolicy NET_ROUND_ROBIN:轮询,NET_HASH_IP_PORT:根据ip和port做hash + */ + virtual void SetConnectLBPolicy(UBSHcomServiceLBPolicy lbPolicy) = 0; + + /** + * @brief TLS相关配置项,如果不配置的话默认不开启 + * + * @param opt + */ + virtual void SetTlsOptions(const UBSHcomTlsOptions &opt) = 0; + + virtual void SetConnSecureOpt(const UBSHcomConnSecureOptions &opt) = 0; + + /** + * @brief 设置TCP_USER_TIMEOUT套接字选项,tcp超时时间,[0, 1024],0表示永不超时 + * + * @param timeOutSec + */ + virtual void SetTcpUserTimeOutSec(uint16_t timeOutSec) = 0; + + /** + * @brief 设置TCP发送是否要做内存拷贝(hcom内部内存) + * + * @param tcpSendZCopy 是否要做数据拷贝 + */ + virtual void SetTcpSendZCopy(bool tcpSendZCopy) = 0; + + /** + * @brief 设置设备ipMask,用于rdma/ub,根据ipMask获取该网段的GID和UBEId + * + * @param ipMasks 用于过滤的ipMask集合 + */ + virtual void SetDeviceIpMask(const std::vector &ipMasks) = 0; + + /** + * @brief 设置设备的ipGroup,如果明确制定了ipGroup,则直接使用对应的设备 + * + * @param ipGroups ipGroups集合 + */ + virtual void SetDeviceIpGroups(const std::vector &ipGroups) = 0; + + /** + * @brief 设置cq队列的深度 + * + * @param depth cq队列深度 + */ + virtual void SetCompletionQueueDepth(uint16_t depth) = 0; + + /** + * @brief 设置SQ队列的大小,默认256 + * + * @param sqSize 队列大小 + */ + virtual void SetSendQueueSize(uint32_t sqSize) = 0; + + /** + * @brief 设置RQ队列的大小,默认256 + * + * @param rqSize 队列大小 + */ + virtual void SetRecvQueueSize(uint32_t rqSize) = 0; + + /** + * @brief 设置提前下发wr的数量,不设置的话默认64 + * @param prePostSize 预先下发的wr数量 + */ + virtual void SetQueuePrePostSize(uint32_t prePostSize) = 0; + + /** + * @brief 设置批量polling的大小,默认是4 + * + * @param pollSize 每批大小 + */ + virtual void SetPollingBatchSize(uint16_t pollSize) = 0; + + /** + * @brief 设置polling的超时时间,单位us,默认500 + * + * @param pollTimeout 超时时间 + */ + virtual void SetEventPollingTimeOutUs(uint16_t pollTimeout) = 0; + + /** + * @brief 设置周期任务处理线程数,主要用在内部异步检查超时等场景,不设置的话默认1个线程 + * + * @param threadNum 线程数 + */ + virtual void SetTimeOutDetectionThreadNum(uint32_t threadNum) = 0; + + /** + * @brief 设置最大连接数,不设置的话默认250 + * + * @param maxConnCount 最大连接数 + */ + virtual void SetMaxConnectionCount(uint32_t maxConnCount) = 0; + + /** + * @brief 设置心跳选项 + * + * @param opt 心跳设置选项 + * @return int32_t + */ + virtual void SetHeartBeatOptions(const UBSHcomHeartBeatOptions &opt) = 0; + + /** + * @brief Set the Multi Rail Options object + * + * @param opt multi rail option + */ + virtual void SetMultiRailOptions(const UBSHcomMultiRailOptions &opt) = 0; + + /** + * @brief 设置 UB-C 多路径模式 + * + * @param ubcMode UB-C 多路径模式 + */ + virtual void SetUbcMode(UBSHcomUbcMode ubcMode) = 0; + + /** + * @brief 设置发送数据块最大数量 + * + * @param maxSendRecvDataCount 发送数据块最大数量 + */ + virtual void SetMaxSendRecvDataCount(uint32_t maxSendRecvDataCount) = 0; + + virtual ~UBSHcomService() {} + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + virtual int32_t DoDestroy(const std::string &name) = 0; + +private: + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; + +} +} +#endif // HCOM_SERVICE_H \ No newline at end of file diff --git a/src/service_v2/api/hcom_service_channel.h b/src/service_v2/api/hcom_service_channel.h new file mode 100644 index 0000000000000000000000000000000000000000..c9f9cdb447c82e7648c1c3bb26536b087a71e30d --- /dev/null +++ b/src/service_v2/api/hcom_service_channel.h @@ -0,0 +1,279 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_API_HCOM_CHANNEL_H_ +#define HCOM_API_HCOM_CHANNEL_H_ + +#include +#include +#include "hcom.h" +#include "hcom_def.h" +#include "hcom_service_def.h" +#include "hcom_split.h" + +namespace ock { +namespace hcom { +class UBSHcomServiceContext; +class UBSHcomChannel; +class HcomServiceCtxStore; + +using UBSHcomEndpointPtr = NetRef; +using UBSHcomChannelPtr = NetRef; +using UBSHcomServiceChannelBrokenHandler = std::function; + +enum UBSHcomChannelState : uint16_t { + CH_NEW, + CH_ESTABLISHED, + CH_CLOSE, + CH_DESTROY, +}; + +class Callback { +public: + Callback() = default; + virtual ~Callback() = default; + virtual void Run(UBSHcomServiceContext &context) = 0; + virtual void SetTime(uint64_t time) = 0; + virtual uint64_t GetTime() = 0; +}; + +/** + * @brief 内部使用,请使用NewCallback生成回调用 + * + * @param ClosureFunction + */ +template class InnerClosureCallback : public Callback { +public: + explicit InnerClosureCallback(ClosureFunction &&function, bool deleteSelf) + : mFunction(std::move(function)), mDeleteSelf(deleteSelf) {} + + ~InnerClosureCallback() override = default; + + void Run(UBSHcomServiceContext &context) override + { + bool doDeleteSelf = false; + if (mDeleteSelf) { + mDeleteSelf = false; + doDeleteSelf = true; + } + mFunction(context); + if (doDeleteSelf) { + delete this; + } + } + +private: + uint64_t GetTime() override + { + return mStartTime; + } + + void SetTime(uint64_t time) override + { + mStartTime = time; + } + +private: + ClosureFunction mFunction = nullptr; + bool mDeleteSelf = true; + uint64_t mStartTime = 0; +}; + +/** + * @brief Generate a self-deleting Callback object. + * + * @param Args + * @param args + * @return Callback* + * @note At present, asynchronous operation is not a hot spot. In order to simplify + * coding, std::bind is used to implement closure. If the cost of std::bind + * is found to be high, then optimize it. + */ +template Callback *UBSHcomNewCallback(Args... args) +{ + auto closure = std::bind(args...); + return new (std::nothrow) InnerClosureCallback(std::move(closure), true); +} + +class UBSHcomChannel { +public: + /** + * @brief 发送双边消息,不需要回复 + * + * @param req 发送双边消息请求 + * @param done nullptr:同步发送;非nullptr:异步发送,发送完成后回调函数 + * @return int32_t 0:成功;非0:失败错误码 + */ + virtual int32_t Send(const UBSHcomRequest &req, const Callback *done) = 0; + int32_t Send(const UBSHcomRequest &req); + + /** + * @brief 发送双边消息,需要回复 + * + * @param req 发送双边消息请求 + * @param rsp 出参,发送双边消息请求后对端回复 + * @param done nullptr:同步发送;非nullptr:异步发送,发送完成后回调函数 + * @return int32_t 0:成功;非0:失败错误码 + */ + virtual int32_t Call(const UBSHcomRequest &req, UBSHcomResponse &rsp, const Callback *done) = 0; + int32_t Call(const UBSHcomRequest &req, UBSHcomResponse &rsp); + + /** + * @brief 回复双边消息,接收端配合Call使用 + * + * @param ctx 回复上下文 + * @param req 回复数据 + * @param done nullptr:同步发送;非nullptr:异步发送,发送完成后回调函数 + * @return int32_t 0:成功;非0:失败错误码 + */ + virtual int32_t Reply(const UBSHcomReplyContext &ctx, const UBSHcomRequest &req, const Callback *done) = 0; + int32_t Reply(const UBSHcomReplyContext &ctx, const UBSHcomRequest &req); + + /** + * @brief 发送单边写请求 + * + * @param req 单边写请求 + * @param done nullptr:同步发送单边请求;非nullptr:异步发送单边请求,发送完成后回调函数 + * @return int32_t 0:成功;非0:失败错误码 + */ + virtual int32_t Put(const UBSHcomOneSideRequest &req, const Callback *done) = 0; + int32_t Put(const UBSHcomOneSideRequest &req); + + /** + * @brief 发送单边读请求 + * + * @param req 单边读请求 + * @param done nullptr:同步发送单边读请求;非nullptr:异步发送单边读请求,发送完成后回调函数 + * @return int32_t 0:成功;非0:失败错误码 + */ + virtual int32_t Get(const UBSHcomOneSideRequest &req, const Callback *done) = 0; + int32_t Get(const UBSHcomOneSideRequest &req); + + /** + * @brief 只接收RNDV请求时使用,且RNDV请求接收后必须reply + * + * @param context: 接收到的service Context + * @param address: recv的数据地址 + * @param size: recv的数据长度 + * @param done nullptr: 同步接受数据需要切换线程使用;非nullptr:异步收到数据请求,接收完成后执行回调函数 + * @return int32_t 0:成功;非0:失败错误码 + */ + virtual int32_t Recv(const UBSHcomServiceContext &context, uintptr_t address, uint32_t size, + const Callback *done = nullptr) = 0; + + /** + * @brief 流控设置 + * + * @param opt 流控设置选项 + * @return int32_t 0:成功;非0:失败错误码 + */ + virtual int32_t SetFlowControlConfig(const UBSHcomFlowCtrlOptions &opt) = 0; + + /** + * @brief 超时设置 + * + * @param oneSideTimeout 单边请求超时时间 + * @param twoSideTimeout 双边请求超时时间 + */ + virtual void SetChannelTimeOut(int16_t oneSideTimeout, int16_t twoSideTimeout) = 0; + + /** + * @brief 设置双边操作阈值 + * + * @param threshold 双边操作阈值 + * @return int32_t 0:成功;非0:失败错误码 + */ + virtual int32_t SetTwoSideThreshold(const UBSHcomTwoSideThreshold &threshold) = 0; + + /** + * @brief 设置trace id + * + * @param traceId trace id + */ + virtual void SetTraceId(const std::string &traceId) = 0; + + virtual uint64_t GetId() = 0; + virtual std::string GetPeerConnectPayload() = 0; + virtual int32_t GetRemoteUdsIdInfo(UBSHcomNetUdsIdInfo &idInfo) = 0; + virtual int32_t SendFds(int fds[], uint32_t len) = 0; + virtual int32_t ReceiveFds(int fds[], uint32_t len, int32_t timeoutSec) = 0; + + virtual ~UBSHcomChannel() {} + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +protected: + virtual auto SpliceMessage(const UBSHcomNetRequestContext &ctx, bool isResp) + -> std::tuple = 0; + + uint32_t mUserSplitSendThreshold = UINT32_MAX; // 用户 payload 拆包阈值,已去除额外头部大小 +private: + virtual SerResult Initialize(std::vector &ep, uintptr_t ctxMemPool, uintptr_t periodicMgr, + uintptr_t pgTable) = 0; + virtual void UnInitialize() = 0; + virtual std::string ToString() = 0; + + virtual void SetUuid(const std::string &uuid) = 0; + virtual void SetPayload(const std::string &payLoad) = 0; + virtual void SetBrokenInfo(UBSHcomChannelBrokenPolicy policy, const UBSHcomServiceChannelBrokenHandler &broken) = 0; + virtual void SetEpBroken(uint32_t index) = 0; + virtual void SetChannelState(UBSHcomChannelState state) = 0; + virtual void SetMultiRail(bool multiRail, uint32_t threshold) = 0; + virtual void SetDriverNum(uint16_t driverNum) = 0; + virtual void SetTotalBandWidth(uint32_t bandWidth) = 0; + virtual void SetEnableMrCache(bool enableMrCache) = 0; + + virtual bool AllEpBroken() = 0; + virtual bool NeedProcessBroken() = 0; + virtual void ProcessIoInBroken() = 0; + virtual void InvokeChannelBrokenCb(UBSHcomChannelPtr &channel) = 0; + + virtual std::string GetUuid() = 0; + virtual uintptr_t GetTimerList() = 0; + virtual uint32_t GetLocalIp() = 0; + virtual uint16_t GetDelayEraseTime() = 0; + virtual HcomServiceCtxStore *GetCtxStore() = 0; + virtual UBSHcomChannelCallBackType GetCallBackType() = 0; + +private: + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend class HcomServiceImp; + friend class HcomServiceTimer; + friend class HcomPeriodicManager; +}; + +inline int32_t UBSHcomChannel::Send(const UBSHcomRequest &req) +{ + return this->Send(req, nullptr); +} + +inline int32_t UBSHcomChannel::Call(const UBSHcomRequest &req, UBSHcomResponse &rsp) +{ + return this->Call(req, rsp, nullptr); +} + +inline int32_t UBSHcomChannel::Reply(const UBSHcomReplyContext &ctx, const UBSHcomRequest &req) +{ + return this->Reply(ctx, req, nullptr); +} + +inline int32_t UBSHcomChannel::Put(const UBSHcomOneSideRequest &req) +{ + return this->Put(req, nullptr); +} + +inline int32_t UBSHcomChannel::Get(const UBSHcomOneSideRequest &req) +{ + return this->Get(req, nullptr); +} +} +} +#endif // HCOM_API_HCOM_CHANNEL_H_ \ No newline at end of file diff --git a/src/service_v2/api/hcom_service_context.h b/src/service_v2/api/hcom_service_context.h new file mode 100644 index 0000000000000000000000000000000000000000..1edde38df2e80f6dc9d14647887d7c9bccc33840 --- /dev/null +++ b/src/service_v2/api/hcom_service_context.h @@ -0,0 +1,241 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_API_HCOM_CONTEXT_H_ +#define HCOM_API_HCOM_CONTEXT_H_ +#include +#include "hcom_service_def.h" +#include "hcom_service_channel.h" + +namespace ock { +namespace hcom { + +using UBSHcomChannelPtr = NetRef; +using UBSHcomRequestContext = UBSHcomNetRequestContext; + +constexpr uint16_t EPIDX_BIT_NUM = 32; + +enum class Operation : uint8_t { + SER_RECEIVED = 0, /* support invoke all functions */ + SER_RECEIVED_RAW = 1, /* support invoke most functions except OpInfo() */ + SER_SENT = 2, /* support invoke basic functions except + Message()、MessageData()、MessageDataLen()、RspCtx()、all ReplySend*() */ + SER_SENT_RAW = 3, /* support invoke basic functions except + Message()、MessageData()、MessageDataLen()、RspCtx()、OpInfo()、all ReplySend*() */ + SER_ONE_SIDE = 4, /* support invoke basic functions except + Message()、MessageData()、MessageDataLen()、RspCtx()、OpInfo()、all ReplySend*() */ + SER_RNDV = 5, /* support invoke all functions */ + SER_RNDV_SGL = 6, /* support invoke all functions */ + SER_MULTIRAIL_RNDV_RAW = 7, /* support invoke all functions */ + SER_INVALID_OP_TYPE = 255, /* support invoke all functions */ +}; + +/* * + * @brief Context of request received, operationInfo/message/channel can be got from it, + * and reply message with it + */ +class UBSHcomServiceContext { +public: + enum Operation : uint8_t { + SER_RECEIVED = 0, /* support invoke all functions */ + SER_RECEIVED_RAW = 1, /* support invoke most functions except OpInfo() */ + SER_SENT = 2, /* support invoke basic functions except + Message()、MessageData()、MessageDataLen()、RspCtx()、all ReplySend*() */ + SER_SENT_RAW = 3, /* support invoke basic functions except + Message()、MessageData()、MessageDataLen()、RspCtx()、OpInfo()、all ReplySend*() */ + SER_ONE_SIDE = 4, /* support invoke basic functions except + Message()、MessageData()、MessageDataLen()、RspCtx()、OpInfo()、all ReplySend*() */ + SER_RNDV = 5, + SER_INVALID_OP_TYPE = 255, + }; + + /** + * @brief Get result of the operation + */ + SerResult Result() const; + + /** + * @brief Get the channel ptr + */ + const UBSHcomChannelPtr &Channel() const; + + /** + * @brief Get the operation type + * @return SER_INVALID_OP_TYPE if failed + */ + Operation OpType() const; + + /** + * @brief Get response context for send rsp message in other thread + * note: only support SER_RECEIVED/SER_RECEIVED_RAW invoke + * @return ture if success, false if failed + */ + uintptr_t RspCtx() const; + + /** + * @brief Get op code by user input + * note: only support SER_SENT/SER_RECEIVED invoke + * @return 0~999 if success, others if failed + */ + uint16_t OpCode() const; + + int32_t ErrorCode() const; + + /** + * @brief Get the message data received which valid in callback lifetime + * note1: only support SER_RECEIVED/SER_RECEIVED_RAW invoke + * note2: if user want to use message in other thread, need to copy message.data by self or invoke clone() + * @return valid address if success, nullptr if failed + */ + void *MessageData() const; + + /** + * @brief Get the message data received which valid in callback lifetime + * note1: only support SER_RECEIVED/SER_RECEIVED_RAW invoke + * @return valid length if success, 0 if failed + */ + uint32_t MessageDataLen() const; + + /** + * @brief clone service context + * @param copyData : true means malloc and copy receive data. + * @return SER_OK if success, others if failed + */ + static SerResult Clone(UBSHcomServiceContext &newOne, + const UBSHcomServiceContext &oldOne, + bool copyData = true); + + /** + * @brief check current context timeout or not + * @return true if timeout, false if not timeout + */ + bool IsTimeout() const; + + void Invalidate(); + + ~UBSHcomServiceContext() + { + Invalidate(); + } + + UBSHcomServiceContext() = default; + +private: + UBSHcomServiceContext(const UBSHcomRequestContext &ctx, UBSHcomChannel *ch); + + enum DataType : uint8_t { + OUTER_DATA = 0, /* assign by UBSHcomRequestContext.Message()->Data() */ + MEM_POOL_DATA = 1, /* alloc from channel mem pool */ + + INVALID_DATA = 255, + }; + + SerResult CopyData(void *data, uint32_t dataLen); + + UBSHcomChannelPtr mCh; /* channel ptr */ + uint64_t mTimeoutTraceMs = 0; /* record timeout time trace, 0 means never timeout */ + void *mData = nullptr; /* received/received raw message data address */ + uint32_t mDataLen = 0; /* received/received raw message data len */ + int32_t mErrorCode; + int32_t mResult = 0; /* context result */ + uint32_t mEpIdxInCh = 0; /* for response */ + uint32_t mSeqNo = 0; /* for response */ + uint32_t mReadCount = 0; + uint16_t mOpCode; + UBSHcomRequestContext::NN_OpType mOpType = + UBSHcomRequestContext::NN_INVALID_OP_TYPE; /* operate original type */ + DataType mDataType = DataType::INVALID_DATA; /* type of mData */ + // 64B cache rsv 12 Bytes + + uint8_t rsv[12]; + + friend class HcomServiceImp; + friend class UBSHcomChannel; + friend class HcomServiceGlobalObject; + friend class HcomPeriodicManager; + friend class HcomChannelImp; +}; + +inline SerResult UBSHcomServiceContext::Result() const +{ + return mResult; +} + +inline const UBSHcomChannelPtr &UBSHcomServiceContext::Channel() const +{ + return mCh; +} + +inline UBSHcomServiceContext::Operation UBSHcomServiceContext::OpType() const +{ + switch (mOpType) { + case UBSHcomRequestContext::NN_SENT: + return Operation::SER_SENT; + case UBSHcomRequestContext::NN_SENT_RAW: + case UBSHcomRequestContext::NN_SENT_RAW_SGL: + return Operation::SER_SENT_RAW; + case UBSHcomRequestContext::NN_RECEIVED: + return Operation::SER_RECEIVED; + case UBSHcomRequestContext::NN_RECEIVED_RAW: + return Operation::SER_RECEIVED_RAW; + case UBSHcomRequestContext::NN_WRITTEN: + case UBSHcomRequestContext::NN_SGL_WRITTEN: + case UBSHcomRequestContext::NN_READ: + case UBSHcomRequestContext::NN_SGL_READ: + return Operation::SER_ONE_SIDE; + case UBSHcomRequestContext::NN_RNDV: + return Operation::SER_RNDV; + default: + return Operation::SER_INVALID_OP_TYPE; + } +} + +inline uintptr_t UBSHcomServiceContext::RspCtx() const +{ + /* 8byte: low 4 byte for seq no, high 4 byte for epIndex */ + uintptr_t rsp = mEpIdxInCh; + rsp = (rsp << EPIDX_BIT_NUM) | mSeqNo; + return rsp; +} + +inline int32_t UBSHcomServiceContext::ErrorCode() const +{ + return mErrorCode; +} + +inline uint16_t UBSHcomServiceContext::OpCode() const +{ + return mOpCode; +} + +inline void *UBSHcomServiceContext::MessageData() const +{ + return mData; +} + +inline uint32_t UBSHcomServiceContext::MessageDataLen() const +{ + return mDataLen; +} + +inline void UBSHcomServiceContext::Invalidate() +{ + if (mDataType == MEM_POOL_DATA && mData != nullptr) { + free(mData); + mData = nullptr; + mDataType = DataType::INVALID_DATA; + } + mCh.Set(nullptr); +} + +} +} +#endif // HCOM_CONTEXT_H \ No newline at end of file diff --git a/src/service_v2/api/hcom_service_def.h b/src/service_v2/api/hcom_service_def.h new file mode 100644 index 0000000000000000000000000000000000000000..df21c7aef6fe5e4f3db2740232d41da059a9692c --- /dev/null +++ b/src/service_v2/api/hcom_service_def.h @@ -0,0 +1,223 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_API_HCOM_CLASSES_H_ +#define HCOM_API_HCOM_CLASSES_H_ + +#include +#include +#include +#include "hcom.h" +#include "hcom_def.h" +#include "hcom_num_def.h" + +namespace ock { +namespace hcom { + +constexpr uint16_t MAX_MULTI_RAIL_NUM = 4; + +using SerResult = int; +using UBSHcomDriverSecInfoProvider = std::function; +using UBSHcomDriverSecInfoValidator = std::function; +using UBSHcomWorkerMode = UBSHcomNetDriverWorkingMode; +using UBSHcomChTypeIndex = uint16_t; +using UBSHcomCipherSuite = UBSHcomNetCipherSuite; +using UBSHcomMemoryRegionPtr = UBSHcomNetMemoryRegionPtr; + +enum class UBSHcomChannelBrokenPolicy : uint8_t { + BROKEN_ALL, /* when one ep broken, all eps broken */ + RECONNECT, /* when one ep broken, try re-connect first. If re-connect fail, broken all eps */ + KEEP_ALIVE, /* when one ep broken, keep left eps alive until all eps broken */ +}; + +enum class UBSHcomClientPollingMode : uint8_t { + WORKER_POLL = 0, + SELF_POLL_BUSY = 1, + SELF_POLL_EVENT = 2, + UNKNOWN = 255, +}; + +enum class UBSHcomChannelCallBackType : uint8_t { + CHANNEL_FUNC_CB, + CHANNEL_GLOBAL_CB, +}; + +enum class UBSHcomOobType : uint8_t { + TCP, + UDS, +}; + +enum class UBSHcomSecType : uint8_t { + NET_SEC_DISABLED, + NET_SEC_VALID_ONE_WAY, + NET_SEC_VALID_TWO_WAY, +}; +struct UBSHcomRequest { + void *address = nullptr; /* pointer of data */ + uint32_t size = 0; /* size of data */ + uint64_t key = 0; + uint16_t opcode = 0; /* operation code of request */ + + UBSHcomRequest() = default; + UBSHcomRequest(void *addr, uint32_t sz, uint16_t op) : address(addr), size(sz), opcode(op) {} +}; + +struct UBSHcomResponse { + void *address = nullptr; /* pointer of data */ + uint32_t size = 0; /* size of data */ + int16_t errorCode = 0; /* error code of response */ + + UBSHcomResponse() = default; + UBSHcomResponse(void *addr, uint32_t sz) : address(addr), size(sz) {} +}; + +struct UBSHcomSglRequest { + UBSHcomRequest *iov = nullptr; + uint16_t iovCount = 0; +}; + +struct UBSHcomMemoryKey { + uint64_t keys[4]; + uint64_t tokens[4]; +}; + +struct UBSHcomOneSideRequest { + uintptr_t lAddress = 0; + uintptr_t rAddress = 0; + UBSHcomMemoryKey lKey; + UBSHcomMemoryKey rKey; + uint32_t size = 0; +}; + +struct UBSHcomOneSideSglRequest { + UBSHcomOneSideRequest *iov = nullptr; + uint16_t iovCount = 0; +}; + +struct UBSHcomReplyContext { + uintptr_t rspCtx = 0; + int16_t errorCode = 0; + UBSHcomReplyContext() = default; + UBSHcomReplyContext(uintptr_t ctx, int16_t errCode) : rspCtx(ctx), errorCode(errCode) {} +}; + +struct UBSHcomIov { + void *address = nullptr; + uint32_t size = 0; +}; + +struct UBSHcomServiceOptions { + uint32_t maxSendRecvDataSize = 1024; // 发送数据块最大值 + uint16_t workerGroupId = 0; // group id of the worker group, must increment from 0 and be unique + uint16_t workerGroupThreadCount = 1; // worker线程数,如果设置为0的话,不启动worker线程 + UBSHcomWorkerMode workerGroupMode = NET_BUSY_POLLING; // worker线程工作模式,默认busy_polling + int8_t workerThreadPriority = 0; // 线程优先级[-20,19],19优先级最低,-20优先级最高,同nice值 + std::pair workerGroupCpuIdsRange = {UINT32_MAX, UINT32_MAX}; // default not bind +}; + +struct UBSHcomConnectOptions { + uint16_t clientGroupId = 0; // worker group id of client + uint16_t serverGroupId = 0; // worker group id of server + uint8_t linkCount = 1; // actual link count of the channel + UBSHcomClientPollingMode mode = UBSHcomClientPollingMode::WORKER_POLL; + UBSHcomChannelCallBackType cbType = UBSHcomChannelCallBackType::CHANNEL_FUNC_CB; + std::string payload; +}; + +struct UBSHcomMultiRailOptions { + uint32_t threshold = 8192; // threshold of multirail + bool enable = true; // multi switch +}; + +struct UBSHcomTlsOptions { + UBSHcomTLSCaCallback caCb = nullptr; + UBSHcomTLSCertificationCallback cfCb = nullptr; + UBSHcomTLSPrivateKeyCallback pkCb = nullptr; + UBSHcomPskUseSessionCb pskUseCb = nullptr; + UBSHcomPskFindSessionCb pskFindCb = nullptr; + UBSHcomTlsVersion tlsVersion = UBSHcomTlsVersion::TLS_1_3; + UBSHcomCipherSuite netCipherSuite = UBSHcomCipherSuite::AES_GCM_128; + bool enableTls = true; +}; + +struct UBSHcomConnSecureOptions { + UBSHcomDriverSecInfoProvider provider = nullptr; + UBSHcomDriverSecInfoValidator validator = nullptr; + uint16_t magic = 256; + uint8_t version = 0; + UBSHcomNetDriverSecType secType = UBSHcomNetDriverSecType::NET_SEC_DISABLED; +}; + +struct UBSHcomHeartBeatOptions { + uint16_t heartBeatIdleSec = 60; // 发送心跳保活消息间隔时间 + uint16_t heartBeatProbeTimes = 7; // 发送心跳探测失败/没收到回复重试次数,超了认为连接已经断开 + uint16_t heartBeatProbeIntervalSec = 2; // 发送心跳后再次发送时间 +}; + +enum class UBSHcomFlowCtrlLevel : uint8_t { + HIGH_LEVEL_BLOCK, /* spin-wait by busy loop */ + LOW_LEVEL_BLOCK, /* full sleep */ +}; + +struct UBSHcomFlowCtrlOptions { + uint16_t intervalTimeMs = 0; + uint64_t thresholdByte = 0; + UBSHcomFlowCtrlLevel flowCtrlLevel = UBSHcomFlowCtrlLevel::LOW_LEVEL_BLOCK; +}; + +struct UBSHcomTwoSideThreshold { + uint32_t splitThreshold = UINT32_MAX; // UBC 专用。此值表示拆包发送的阈值,也可以当做拆包发送时每个小包的 + // 最大长度(含额外头部). 一般将其配置成小于等于 SegSize 的值,可配置范围 + // 为 [128, maxSendRecvDataSize]. 特别的配置成 UINT32_MAX 会禁用拆包功能。 + uint32_t rndvThreshold = UINT32_MAX; // rndv阈值,请求长度大于等于该值,则启用RNDV。 +}; + +class UBSHcomRegMemoryRegion { +public: + inline void GetMemoryKey(UBSHcomMemoryKey &mrKey) + { + for (uint32_t i = 0; i < mHcomMrs.size(); i++) { + if (i >= MAX_MULTI_RAIL_NUM) { + break; + } + mrKey.keys[i] = mHcomMrs[i]->GetLKey(); + mrKey.tokens[i] = reinterpret_cast(mHcomMrs[i]->GetMemorySeg()); + } + } + + inline uintptr_t GetAddress() + { + if (mHcomMrs.empty() || mHcomMrs[0] == nullptr) { + return 0; + } + return mHcomMrs[0]->GetAddress(); + } + + inline uint64_t GetSize() + { + if (mHcomMrs.empty() || mHcomMrs[0] == nullptr) { + return 0; + } + return mHcomMrs[0]->Size(); + } + + inline std::vector& GetHcomMrs() + { + return mHcomMrs; + } + +private: + std::vector mHcomMrs; +}; +} +} +#endif // HCOM_API_HCOM_CLASSES_H_ \ No newline at end of file diff --git a/src/service_v2/net_param_validator.h b/src/service_v2/net_param_validator.h new file mode 100644 index 0000000000000000000000000000000000000000..c05002197644d25facd0cfdd13d1db8d3c7345f2 --- /dev/null +++ b/src/service_v2/net_param_validator.h @@ -0,0 +1,164 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_PARAM_VAL_H +#define HCOM_PARAM_VAL_H + +#include +#include +#include +#include "hcom.h" +#include "hcom_def.h" +#include "hcom_log.h" +#include "hcom_num_def.h" +#include "hcom_service.h" +#include "net_common.h" +#include "service_common.h" + +namespace ock { +namespace hcom { + +#define VALIDATE_PARAM_RET(funcName, ...) \ + do { \ + if (NN_UNLIKELY(!funcName##Check(__VA_ARGS__))) { \ + NN_LOG_ERROR("Invalid parameter!"); \ + return SER_INVALID_PARAM; \ + } \ + } while (0) \ + +#define VALIDATE_PARAM(funcName, ...) (funcName##Check(__VA_ARGS__)) + +inline bool BindCheck(const std::string &url, const UBSHcomServiceNewChannelHandler &handler) +{ + if (NN_UNLIKELY(url.empty())) { + NN_LOG_ERROR("Invalid url: " << url); + return false; + } + if (NN_UNLIKELY(handler == nullptr)) { + NN_LOG_ERROR("UBSHcomServiceNewChannelHandler is nullptr"); + return false; + } + return true; +} + +inline bool TlsOptionsCheck(const UBSHcomTlsOptions &opt) +{ + if (NN_UNLIKELY(!opt.enableTls)) { + return true; + } + + if (NN_UNLIKELY(opt.caCb == nullptr)) { + NN_LOG_ERROR("UBSHcomTLSCaCallback of UBSHcomTlsOptions is nullptr"); + return false; + } + + if (NN_UNLIKELY(opt.cfCb == nullptr)) { + NN_LOG_ERROR("UBSHcomTLSCertificationCallback of UBSHcomTlsOptions is nullptr"); + return false; + } + + if (NN_UNLIKELY(opt.pkCb == nullptr)) { + NN_LOG_ERROR("UBSHcomTLSPrivateKeyCallback of UBSHcomTlsOptions is nullptr"); + return false; + } + + return true; +} + +inline bool SerConnInfoCheck(const SerConnInfo &connInfo) +{ + if (NN_UNLIKELY(connInfo.totalLinkCount == NN_NO0 || connInfo.totalLinkCount > NN_NO64)) { + NN_LOG_ERROR("Invalid total link count " << static_cast(connInfo.totalLinkCount) << + " for connect, make sure range in [1, 64]"); + return false; + } + + if (NN_UNLIKELY(connInfo.options.linkCount == NN_NO0 || connInfo.options.linkCount > NN_NO16)) { + NN_LOG_ERROR("Invalid link count " << connInfo.options.linkCount << + " for connect, make sure range in [1, 16]"); + return false; + } + + if (NN_UNLIKELY(connInfo.index >= connInfo.totalLinkCount)) { + NN_LOG_ERROR("Invalid conn index " << connInfo.index << ", total ep size " << connInfo.totalLinkCount << + " for connecting"); + return false; + } + + return true; +} + +inline bool ConnectOptionsCheck(const UBSHcomConnectOptions &opt) +{ + if (NN_UNLIKELY(opt.linkCount == NN_NO0 || opt.linkCount > NN_NO16)) { + NN_LOG_ERROR("Invalid link count " << static_cast(opt.linkCount) << + " for connect, make sure range in [1, 16]"); + return false; + } + if (NN_UNLIKELY(opt.mode != UBSHcomClientPollingMode::WORKER_POLL && + opt.mode != UBSHcomClientPollingMode::SELF_POLL_BUSY && + opt.mode != UBSHcomClientPollingMode::SELF_POLL_EVENT)) { + NN_LOG_ERROR("Invalid polling mode " << static_cast(opt.mode)); + return false; + } + return true; +} + +inline bool RequestCheck(const UBSHcomRequest &req) +{ + if (NN_UNLIKELY(req.address == nullptr)) { + NN_LOG_ERROR("Invalid request as address of request is nullptr"); + return false; + } + if (NN_UNLIKELY(req.size == NN_NO0)) { + NN_LOG_ERROR("Invalid request as size of request is zeri"); + return false; + } + return true; +} + +inline bool OneSideRequestCheck(const UBSHcomOneSideRequest &req) +{ + if (NN_UNLIKELY(req.size == NN_NO0)) { + NN_LOG_ERROR("NetServiceRequest.size is invalid"); + return false; + } + if (NN_UNLIKELY(req.lAddress == 0)) { + NN_LOG_ERROR("NetServiceRequest.lAddress is invalid"); + return false; + } + return true; +} + +inline bool ReplyCheck(const UBSHcomReplyContext &ctx, const UBSHcomRequest &req, bool selfPoll) +{ + if (NN_UNLIKELY(selfPoll)) { + NN_LOG_ERROR("Self poll is not support reply"); + return false; + } + + if (NN_UNLIKELY(ctx.rspCtx == 0)) { + NN_LOG_ERROR("Invalid reply param as rspCtx is 0"); + return false; + } + if (NN_UNLIKELY(req.address == nullptr)) { + NN_LOG_ERROR("Invalid reply param as address of req is nullptr"); + return false; + } + if (NN_UNLIKELY(req.size <= 0)) { + NN_LOG_ERROR("Invalid reply param as size of req is negative"); + return false; + } + return true; +} +} +} +#endif // HCOM_PARAM_VAL_H \ No newline at end of file diff --git a/src/service_v2/service.cpp b/src/service_v2/service.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d3cde58f4d738a5bcb6a84cfc71c45623f9e426 --- /dev/null +++ b/src/service_v2/service.cpp @@ -0,0 +1,116 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "hcom_service.h" + +#include +#include +#include +#include + +#include "hcom.h" +#include "hcom_log.h" +#include "hcom_def.h" +#include "hcom_err.h" +#include "hcom_service_def.h" +#include "service_imp.h" +#include "net_param_validator.h" + +namespace ock { +namespace hcom { + +using namespace ock::hcom; + +static std::map g_serviceMap; +static std::mutex g_mutex; + +static inline bool HcomServiceCreateCheck(const UBSHcomServiceOptions &opt) +{ + if (NN_UNLIKELY(opt.maxSendRecvDataSize == 0)) { + NN_LOG_ERROR("Invalid maxSendDataSize: " << opt.maxSendRecvDataSize); + return false; + } + if (NN_UNLIKELY(opt.workerGroupMode != NET_BUSY_POLLING + && opt.workerGroupMode != NET_EVENT_POLLING)) { + NN_LOG_ERROR("Invalid workerGroupMode: " << static_cast(opt.workerGroupMode)); + return false; + } + if (NN_UNLIKELY(opt.workerThreadPriority < NN_NOF20 || opt.workerThreadPriority > NN_NO19)) { + NN_LOG_ERROR("Invalid workerThreadPriority: " << opt.workerThreadPriority << ", must be [-20, 19]"); + return false; + } + return true; +} + +UBSHcomService *UBSHcomService::Create(UBSHcomServiceProtocol t, const std::string &name, + const UBSHcomServiceOptions &opt) +{ + if (NN_UNLIKELY(!HcomServiceCreateCheck(opt))) { + NN_LOG_ERROR("invalid options for service create"); + return nullptr; + } + + if (name.length() > NN_NO64) { + NN_LOG_ERROR("Invalid param, name length must be less than " << NN_NO64); + return nullptr; + } + std::lock_guard locker(g_mutex); + auto iter = g_serviceMap.find(name); + if (iter != g_serviceMap.end()) { + return iter->second; + } + + UBSHcomService *service = new (std::nothrow) HcomServiceImp(t, name, opt); + if (service == nullptr) { + NN_LOG_ERROR("failed to create netServiceImp for service"); + return nullptr; + } + + SerResult result = HcomServiceGlobalObject::Initialize(); + if (NN_UNLIKELY(result != SER_OK)) { + delete service; + service = nullptr; + NN_LOG_ERROR("Failed to create serviceNetServiceGlobalObject initialize "); + return nullptr; + } + + g_serviceMap.emplace(name, service); + service->IncreaseRef(); + return service; +} + +int32_t UBSHcomService::Destroy(const std::string &name) +{ + std::lock_guard locker(g_mutex); + auto iter = g_serviceMap.find(name); + if (NN_UNLIKELY(iter == g_serviceMap.end())) { + NN_LOG_ERROR("Failed to destroy service, because service is not found or does not exist"); + return SER_ERROR; + } + + UBSHcomService *service = iter->second; + if (service == nullptr) { + NN_LOG_ERROR("Failed to destroy service, because service empty"); + return SER_ERROR; + } + int32_t res = service->DoDestroy(name); + if (NN_UNLIKELY(res != SER_OK)) { + NN_LOG_ERROR("Failed to destroy service, DoDestroy failed"); + return res; + } + HcomServiceGlobalObject::UnInitialize(); + g_serviceMap.erase(iter); + service->DecreaseRef(); + return SER_OK; +} + +} +} \ No newline at end of file diff --git a/src/service_v2/service_callback.h b/src/service_v2/service_callback.h new file mode 100644 index 0000000000000000000000000000000000000000..b2ba1f13b14eddc1501aa2735011c5d06d3c5c8d --- /dev/null +++ b/src/service_v2/service_callback.h @@ -0,0 +1,347 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_SERVICE_V2_SERVICE_CALLBACK_H_ +#define HCOM_SERVICE_V2_SERVICE_CALLBACK_H_ + +#include "api/hcom_service_channel.h" +#include "service_ctx_store.h" +#include "service_common.h" +#include "net_monotonic.h" + +namespace ock { +namespace hcom { +class SerTimerListHeader; + +enum class HcomAsyncCBState : uint8_t { + CBS_INIT = 0, + CBS_FINISHED = 1, + CBS_TIMEOUT = 2, +}; + +enum class HcomAsyncCBType : uint8_t { + CBS_IO = 0, + CBS_CHANNEL_BROKEN = 1, +}; + +class HcomServiceTimer { +public: + UBSHcomChannel *mChannel = nullptr; /* used for build UBSHcomServiceContext */ + HcomServiceCtxStore *mCtxStore = nullptr; /* manager memory and seqNo */ + uint64_t mTimeout = 0; /* absolute timeout compare to current system time */ + uintptr_t mCallback = 0; /* callback obj address */ + uint32_t mSeqNo = 0; /* seq no for find query map */ + HcomAsyncCBType mType = HcomAsyncCBType::CBS_IO; /* callback type */ + HcomAsyncCBState mState = HcomAsyncCBState::CBS_INIT; /* atomic status to handle the trace condition between timeout + * handle thread and polling thread */ +public: + inline uint32_t SeqNo() const + { + return mSeqNo; + } + + inline void SeqNo(uint32_t seqNo) + { + mSeqNo = seqNo; + } + + inline HcomAsyncCBState State() const + { + return mState; + } + + inline uint64_t Timeout() const + { + return mTimeout; + } + + inline uintptr_t Callback() const + { + return mCallback; + } + + inline void TimeoutDump() const + { + if (mType == HcomAsyncCBType::CBS_IO) { + if (mChannel == nullptr) { + NN_LOG_WARN("IO timeout, seq no " << mSeqNo); + } else { + NN_LOG_WARN("IO timeout, seq no " << mSeqNo << " in channel id " << mChannel->GetId()); + } + } + } + + inline void EraseSeqNo() const + { + NN_ASSERT_LOG_RETURN_VOID(mCtxStore != nullptr); + + class HcomServiceTimer *timer = nullptr; + if (NN_UNLIKELY(mCtxStore->GetSeqNoAndRemove(mSeqNo, timer) != SER_OK)) { + HcomSeqNo dumpSeq(mSeqNo); + NN_LOG_ERROR("Failed to erase " << dumpSeq.ToString()); + return; + } + + if (NN_UNLIKELY(timer != this)) { + HcomSeqNo dumpSeq(mSeqNo); + NN_LOG_ERROR(dumpSeq.ToString() << " erase wrong timer"); + return; + } + } + + inline bool EraseSeqNoWithRet() const + { + NN_ASSERT_LOG_RETURN(mCtxStore != nullptr, false); + + class HcomServiceTimer *timer = nullptr; + if (NN_UNLIKELY(mCtxStore->GetSeqNoAndRemove(mSeqNo, timer) != SER_OK)) { + return false; + } + + /* first time: before CAS, flat buff = valid address(this); after CAS, flat buff = 0, timer = valid address + second time: before CAS, flat buff = 0; after CAS, flat buff = 0, timer = 0 */ + if (NN_UNLIKELY(timer != this)) { + return false; + } + + return true; + } + + inline bool IsFinished() const + { + return mState == HcomAsyncCBState::CBS_FINISHED; + } + + /* + * @brief Mark the CB wrapper to finished, which should be called by polling thread or user caller thread + * + * @return true if mark the state from init to FINISHED state + * otherwise it is timeout + */ + inline void MarkFinished() + { + mState = HcomAsyncCBState::CBS_FINISHED; + } + + inline void RunCallBack(UBSHcomServiceContext &ctx) + { + if (mCallback != 0) { + auto callback = reinterpret_cast(mCallback); + mCallback = 0; + callback->Run(ctx); + } + } + + inline void DeleteCallBack() + { + if (mCallback != 0) { + auto callback = reinterpret_cast(mCallback); + mCallback = 0; + delete callback; + } + } + + /* + * @brief Mark the CB wrapper to timeout, which should be called by timeout thread + * + * @return true if mark the state from init to FINISHED state + * otherwise it is timeout + */ + inline void MarkTimeout() + { + mState = HcomAsyncCBState::CBS_TIMEOUT; + } + + bool IsTimeOut() const + { + // if mTimeout is 0, this timer will never timeout + if (mTimeout == 0) { + return false; + } + if (NetMonotonic::TimeSec() > mTimeout) { + return true; + } + return false; + } + + HcomServiceTimer(UBSHcomChannel *ch, HcomServiceCtxStore *ctxStore, int32_t t, uintptr_t cb, HcomAsyncCBType type) + : mChannel(ch), mCtxStore(ctxStore), mCallback(cb), mType(type) + { + // if t < 0, it means never timeout, so leave mTimeout as 0 + if (t >= 0) { + mTimeout = NetMonotonic::TimeSec() + static_cast(t); + } + + if (mChannel != nullptr) { + mChannel->IncreaseRef(); + } + + OBJ_GC_INCREASE(HcomServiceTimer); + } + + HcomServiceTimer() + { + OBJ_GC_INCREASE(HcomServiceTimer); + } + + ~HcomServiceTimer() {} + +public: + inline void IncreaseRef() + { + __sync_fetch_and_add(&mRefCount, 1); + } + + inline void DecreaseRef() + { + int32_t tmpCnt = __sync_sub_and_fetch(&mRefCount, 1); + if (tmpCnt == 0) { + if (mChannel != nullptr) { + mChannel->DecreaseRef(); + } + + if (mCtxStore != nullptr) { + mCtxStore->Return(this); + } + + OBJ_GC_DECREASE(HcomServiceTimer); + } + } + + inline int32_t GetRef() + { + return __sync_sub_and_fetch(&mRefCount, 0); + } + + friend class SerTimerListHeader; + +private: + int32_t mRefCount = 0; + class HcomServiceTimer *mPrev = nullptr; + class HcomServiceTimer *mNext = nullptr; +}; + +class SerTimerListHeader { +public: + SerTimerListHeader() = default; + + /* + * @brief add timer ctx in linked list + * @note increase ref + */ + inline void AddTimerCtx(HcomServiceTimer *timer) + { + if (NN_LIKELY(timer != nullptr)) { + // bi-direction linked list, 4 step to insert to head + timer->mPrev = &mTimerCtx; + mLock.Lock(); + // head -><- first -><- second -><- third -> nullptr + // insert into the head place + timer->mNext = mTimerCtx.mNext; + if (mTimerCtx.mNext != nullptr) { + mTimerCtx.mNext->mPrev = timer; + } + mTimerCtx.mNext = timer; + ++mCtxCount; + mLock.Unlock(); + timer->IncreaseRef(); + } + } + + /* + * @brief remove timer ctx in linked list + * @note if remove success, decrease ref + */ + inline void RemoveTimerCtx(HcomServiceTimer *timer) + { + if (NN_LIKELY(timer != nullptr)) { + // bi-direction linked list, 4 step to remove one + mLock.Lock(); + + // repeat remove + if (timer->mPrev == nullptr) { + mLock.Unlock(); + return; + } + + // head-><- first -><- second -><- third -> nullptr + timer->mPrev->mNext = timer->mNext; + + if (timer->mNext != nullptr) { + timer->mNext->mPrev = timer->mPrev; + } + --mCtxCount; + + timer->mPrev = nullptr; + timer->mNext = nullptr; + mLock.Unlock(); + timer->DecreaseRef(); + } + } + + /* + * @brief get timer ctx in linked list + * @note outside need decrease ref + */ + inline void GetTimerCtx(std::vector &remainCtx) + { + HcomServiceTimer *timer = nullptr; + HcomServiceTimer *next = nullptr; + remainCtx.clear(); + remainCtx.reserve(mCtxCount); + + mLock.Lock(); + // head -> first -><- second -><- third -> nullptr + timer = mTimerCtx.mNext; + mTimerCtx.mNext = nullptr; + mCtxCount = 0; + + while (timer != nullptr) { + next = timer->mNext; + timer->mNext = nullptr; + timer->mPrev = nullptr; + remainCtx.emplace_back(timer); + timer = next; + } + mLock.Unlock(); + } + + inline uint32_t GetCtxCount() + { + mLock.Lock(); + auto tmpCnt = mCtxCount; + mLock.Unlock(); + return tmpCnt; + } + +public: + HcomServiceTimer mTimerCtx {}; + NetSpinLock mLock; + uint32_t mCtxCount = 0; +}; + +class HcomServiceTimerCompare { +public: + bool operator () (HcomServiceTimer *&a, HcomServiceTimer *&b) const + { + if (a->Timeout() > b->Timeout()) { + return true; + } else if (a->Timeout() == b->Timeout()) { + return a->SeqNo() > b->SeqNo(); + } else { + return false; + } + } +}; + +} +} +#endif // HCOM_SERVICE_V2_SERVICE_CALLBACK_H_ \ No newline at end of file diff --git a/src/service_v2/service_channel_imp.cpp b/src/service_v2/service_channel_imp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0bcfacac0a9e30b1d551652ef19605a31718000a --- /dev/null +++ b/src/service_v2/service_channel_imp.cpp @@ -0,0 +1,2359 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "service_channel_imp.h" + +#include + +#include "hcom_log.h" +#include "hcom_err.h" +#include "hcom_num_def.h" +#include "hcom_service_channel.h" +#include "net_param_validator.h" + +namespace ock { +namespace hcom { + +constexpr uint16_t RECON_DELAY_ERASE_TIME = 60; +constexpr uint16_t DEFAULT_DELAY_ERASE_TIME = 1; + + +SerResult HcomChannelImp::Initialize(std::vector &ep, uintptr_t ctxMemPool, + uintptr_t periodicMgr, uintptr_t pgTable) +{ + std::lock_guard locker(mMgrMutex); + if (!mChState.Compare(UBSHcomChannelState::CH_NEW)) { + return SER_OK; + } + + if (NN_UNLIKELY(ep.size() == 0) || NN_UNLIKELY(ep.size() > NN_NO64)) { + NN_LOG_ERROR("Invalid ep vector, size is " << ep.size() << " should in [1-64]."); + return SER_INVALID_PARAM; + } + + auto header = new (std::nothrow) SerTimerListHeader; + if (header == nullptr) { + NN_LOG_ERROR("Failed to create timer list header"); + return SER_NEW_OBJECT_FAILED; + } + mTimerList = reinterpret_cast(header); + + auto ctxMemPoolPtr = reinterpret_cast(ctxMemPool); + if (NN_UNLIKELY(ctxMemPoolPtr == nullptr)) { + NN_LOG_ERROR("Invalid ctx store ptr " << ctxMemPool); + ForceUnInitialize(); + return SER_INVALID_PARAM; + } + ctxMemPoolPtr->IncreaseRef(); + mCtxMemPool = ctxMemPool; + + HcomServiceCtxStore *ctxStore = new (std::nothrow) HcomServiceCtxStore(NN_NO2097152, ctxMemPoolPtr, mProtocol); + if (NN_UNLIKELY(ctxStore == nullptr)) { + NN_LOG_ERROR("Create ctx store failed"); + ForceUnInitialize(); + return SER_NEW_OBJECT_FAILED; + } + + SerResult ret = ctxStore->Initialize(); + if (NN_UNLIKELY(ret != SER_OK)) { + NN_LOG_ERROR("Init ctx store failed " << ret); + delete ctxStore; + ctxStore = nullptr; + ForceUnInitialize(); + return SER_NEW_OBJECT_FAILED; + } + ctxStore->IncreaseRef(); + mCtxStore = ctxStore; + + auto periodicMgrPtr = reinterpret_cast(periodicMgr); + if (NN_UNLIKELY(periodicMgrPtr == nullptr)) { + NN_LOG_ERROR("Invalid periodic mgr ptr " << periodicMgr); + ForceUnInitialize(); + return SER_INVALID_PARAM; + } + periodicMgrPtr->IncreaseRef(); + mPeriodicMgr = periodicMgr; + + auto NetPgTablePtr = reinterpret_cast(pgTable); + if (NN_UNLIKELY(NetPgTablePtr == nullptr)) { + NN_LOG_ERROR("Invalid pgTable ptr is null"); + ForceUnInitialize(); + return SER_INVALID_PARAM; + } + NetPgTablePtr->IncreaseRef(); + mPgtable = pgTable; + + ret = InitializeEp(ep); + if (NN_UNLIKELY(ret != SER_OK)) { + ForceUnInitialize(); + return ret; + } + + CheckAndUpdateThreshold(); + mChState.Set(UBSHcomChannelState::CH_ESTABLISHED); + return SER_OK; +} + +void HcomChannelImp::CheckAndUpdateThreshold() +{ + auto rndvThreshold = HcomEnv::RndvThreshold(); + // 环境变量 HCOM_ENABLE_SPLIT_SEND 仅能够为 0 或者 1, 其他情况都认定为 0. + const long enabled = NetFunc::NN_GetLongEnv("HCOM_ENABLE_SPLIT_SEND", 0, 1, 0); + if (!enabled) { + if (!mEnableMrCache) { + NN_LOG_WARN("Unable to set rndv threshold because mEnableMrCache is false, SplitSend threshold " << + mUserSplitSendThreshold << ", Rndv Threshold is: " << mRndvThreshold); + return; + } + mRndvThreshold = rndvThreshold; + NN_LOG_INFO("SplitSend (UBC only) enabled with threshold " << UINT32_MAX + << ", Rndv Threshold is: " << mRndvThreshold); + return; + } + + if (rndvThreshold < NN_NO65536) { + NN_LOG_WARN("The threshold of split send cannot be greater than the threshold of rndv! Split send threshold: " + << NN_NO65536 << " Rndv threshold: " << rndvThreshold); + return; + } + + mUserSplitSendThreshold = NN_NO65536 - sizeof(UBSHcomNetTransHeader) - sizeof(UBSHcomFragmentHeader); + if (!mEnableMrCache) { + NN_LOG_WARN("Unable to set rndv threshold because mEnableMrCache is false "); + } else { + mRndvThreshold = rndvThreshold; + } + NN_LOG_INFO("SplitSend (UBC only) enabled with threshold " << NN_NO65536 << ", Rndv Threshold is: " << + mRndvThreshold); +} + +SerResult HcomChannelImp::InitializeEp(std::vector &ep) +{ + if (ep.empty() || ep[0] == nullptr) { + NN_LOG_WARN("try to initialize empty ep"); + return SER_OK; + } + + mLocalIp = ep[0]->LocalIp(); + mEpInfo = new (std::nothrow) EpInfo; + if (mEpInfo == nullptr) { + NN_LOG_ERROR("Create ep info failed"); + return SER_NEW_OBJECT_FAILED; + } + + mEpInfo->epSize = ep.size(); + for (int i = 0; i < mEpInfo->epSize; i++) { + mEpInfo->epArr[i] = ep.at(i).Get(); + mEpInfo->epArr[i]->IncreaseRef(); + + ServiceEpState state = SER_EP_ESTABLISHED; + if (mOptions.selfPoll) { + state = SER_EP_ESTABLISHED_UNOCCUPIED; + } + mEpInfo->epState[i].Set(state); + } + SetEpUpCtx(); + + if (NN_UNLIKELY(!AllEpEstablished())) { + UnSetEpUpCtx(); + NN_LOG_ERROR("Failed to check ep state, some of them are broken during connecting, channel id " << mOptions.id); + return SER_EP_BROKEN_DURING_CONNECTING; + } + + return SER_OK; +} + +void HcomChannelImp::UnInitialize() +{ + std::lock_guard locker(mMgrMutex); + if (mChState.Compare(UBSHcomChannelState::CH_DESTROY)) { + return; + } + + if (NN_UNLIKELY(mEpInfo == nullptr)) { + mChState.Set(CH_CLOSE); + return; + } + + for (uint16_t idx = 0; idx < mEpInfo->epSize; idx++) { + if (mEpInfo->epArr[idx]->State().Compare(NEP_ESTABLISHED)) { + mEpInfo->epArr[idx]->Close(); + } + } + + mChState.Set(CH_CLOSE); +} + +void HcomChannelImp::ForceUnInitialize() +{ + if (mOptions.rateLimit != 0) { + auto rateLimit = reinterpret_cast(mOptions.rateLimit); + delete rateLimit; + mOptions.rateLimit = 0; + } + + if (NN_LIKELY(mCtxStore != nullptr)) { + mCtxStore->DecreaseRef(); + mCtxStore = nullptr; + } + + auto ctxMemPool = reinterpret_cast(mCtxMemPool); + if (NN_LIKELY(ctxMemPool != nullptr)) { + ctxMemPool->DecreaseRef(); + mCtxMemPool = 0; + } + + auto periodicMgrPtr = reinterpret_cast(mPeriodicMgr); + if (NN_LIKELY(periodicMgrPtr != nullptr)) { + periodicMgrPtr->DecreaseRef(); + mPeriodicMgr = 0; + } + + NetPgTable *pgTable = reinterpret_cast(mPgtable); + if (NN_LIKELY(pgTable != nullptr)) { + pgTable->DecreaseRef(); + mPgtable = 0; + } + + auto timeHeader = reinterpret_cast(mTimerList); + if (NN_LIKELY(timeHeader != nullptr)) { + delete timeHeader; + mTimerList = 0; + } + + if (mEpInfo != nullptr) { + // unset up ctx first, avoid race condition for ep broken during connecting. + UnSetEpUpCtx(); + for (uint16_t i = 0; i < mEpInfo->epSize; i++) { + if (mEpInfo->epArr[i] != nullptr) { + mEpInfo->epArr[i]->DecreaseRef(); + } + } + delete mEpInfo; + mEpInfo = nullptr; + } + mChState.Set(UBSHcomChannelState::CH_DESTROY); +} + +std::string HcomChannelImp::ToString() +{ + std::ostringstream oss; + oss << "Connect channel id " << mOptions.id; + + if (mEpInfo == nullptr) { + oss << " error, mEpInfo is nullptr"; + return oss.str(); + } + + oss << " with " << mEpInfo->epSize << " eps :["; + for (uint16_t i = 0; i < mEpInfo->epSize; i++) { + if (mEpInfo->epArr[i] == nullptr) { + continue; + } + oss << mEpInfo->epArr[i]->Id(); + if (i != (mEpInfo->epSize - 1)) { + oss << ", "; + } + } + oss << "]"; + return oss.str(); +} + +constexpr uint32_t VALID_SEQ_NO = 0xFFFFFFFF; /* low 32 bit */ +inline void MarkOpCodeBySeqNo(uint32_t &seqNo, uintptr_t rspCtx) +{ + HcomSeqNo netSeqNo(seqNo); + netSeqNo.isResp = rspCtx == 0 ? 0 : 1; + seqNo = netSeqNo.wholeSeq; +} + +inline void MarkOpCodeBySeqNo(uint32_t &seqNo, uintptr_t rspCtx, bool originalSeqNo) +{ + if (rspCtx == 0) { + HcomSeqNo netSeqNo(seqNo); + netSeqNo.isResp = 0; + seqNo = netSeqNo.wholeSeq; + } else if (originalSeqNo) { + seqNo = rspCtx & VALID_SEQ_NO; + } else { + HcomSeqNo netSeqNo(rspCtx & VALID_SEQ_NO); + netSeqNo.isResp = 1; + seqNo = netSeqNo.wholeSeq; + } +} + +inline SerResult HcomChannelImp::AcquireSelfPollEp(UBSHcomNetEndpoint *&ep, uint32_t &index, int16_t timeout, + uint16_t dvrIdx) +{ + if (NN_UNLIKELY(!mChState.Compare(UBSHcomChannelState::CH_ESTABLISHED))) { + NN_LOG_ERROR("Channel state is not established " << static_cast(mChState.Get())); + return SER_NOT_ESTABLISHED; + } + NN_ASSERT_LOG_RETURN(mEpInfo != nullptr, NN_ERROR) + uint64_t startTimeSecond = NetMonotonic::TimeSec(); + uint64_t endTimeSecond = 0; + if (timeout > 0) { + endTimeSecond = startTimeSecond + timeout; + } else { + endTimeSecond = startTimeSecond + NN_NO8; + } + + index = __sync_fetch_and_add(&mEpChoosingIdx[dvrIdx], 1) % (mEpInfo->epSize / mDriverNum) + + dvrIdx * (mEpInfo->epSize / mDriverNum); + uint32_t count = 0; + while (!mEpInfo->epState[index].CAS(SER_EP_ESTABLISHED_UNOCCUPIED, SER_EP_ESTABLISHED_OCCUPIED)) { + index = __sync_fetch_and_add(&mEpChoosingIdx[dvrIdx], 1) % (mEpInfo->epSize / mDriverNum) + + dvrIdx * (mEpInfo->epSize / mDriverNum); + if ((++count % (mEpInfo->epSize / mDriverNum)) == 0) { + if (NN_UNLIKELY(!mChState.Compare(UBSHcomChannelState::CH_ESTABLISHED))) { + NN_LOG_ERROR("Channel is not established " << static_cast(mChState.Get())); + return SER_NOT_ESTABLISHED; + } + if (NetMonotonic::TimeSec() > endTimeSecond) { + NN_LOG_ERROR("Acquire self poll ep timeout for " << endTimeSecond - startTimeSecond << + " seconds, maybe all endpoints broken / users too much / remote side not response"); + return SER_TIMEOUT; + } + } + } + + ep = mEpInfo->epArr[index]; + if (NN_UNLIKELY(ep == nullptr)) { + NN_LOG_ERROR("Channel Id " << mOptions.id << " ep invalid"); + return SER_NOT_ESTABLISHED; + } + return SER_OK; +} + +inline void HcomChannelImp::ReleaseSelfPollEp(uint32_t index) +{ + if (NN_UNLIKELY(index >= mEpInfo->epSize)) { + NN_LOG_ERROR("Invalid index to release self poll ep in channel " + << mOptions.id); + return; + } + + if (!mEpInfo->epState[index].CAS(SER_EP_ESTABLISHED_OCCUPIED, + SER_EP_ESTABLISHED_UNOCCUPIED)) { + NN_LOG_ERROR("Channel id " << mOptions.id + << " failed to release self poll ep, state " + << mEpInfo->epState[index].Get()); + } +} + +inline SerResult HcomChannelImp::NextWorkerPollEp(UBSHcomNetEndpoint *&ep, uint16_t dvrIdx) +{ + if (NN_UNLIKELY(!mChState.Compare(UBSHcomChannelState::CH_ESTABLISHED))) { + NN_LOG_ERROR("Channel state not established " << static_cast(mChState.Get())); + return SER_NOT_ESTABLISHED; + } + + uint16_t tmpIndex = __sync_fetch_and_add(&mEpChoosingIdx[dvrIdx], 1) % (mEpInfo->epSize / mDriverNum) + + dvrIdx * (mEpInfo->epSize / mDriverNum); + uint16_t count = 0; + + while (mEpInfo->epState[tmpIndex].Compare(SER_EP_BROKEN) && + count < (mEpInfo->epSize / mDriverNum)) { + tmpIndex = (tmpIndex + 1) % (mEpInfo->epSize / mDriverNum) + dvrIdx * (mEpInfo->epSize / mDriverNum); + count++; + } + + if (NN_UNLIKELY(count > mEpInfo->epSize)) { + NN_LOG_ERROR("Channel Id " << mOptions.id << " all ep broken"); + return SER_NOT_ESTABLISHED; + } + + ep = mEpInfo->epArr[tmpIndex]; + if (NN_UNLIKELY(ep == nullptr)) { + NN_LOG_ERROR("Channel Id " << mOptions.id << " ep invalid"); + return SER_NOT_ESTABLISHED; + } + return SER_OK; +} + +inline SerResult HcomChannelImp::ResponseWorkerPollEp(uintptr_t rspCtx, UBSHcomNetEndpoint *&ep) +{ + if (NN_UNLIKELY(!mChState.Compare(CH_ESTABLISHED))) { + NN_LOG_ERROR("Channel state not established " << mChState.Get()); + return SER_NOT_ESTABLISHED; + } + + uint32_t epIndex = rspCtx >> 32; + if (NN_UNLIKELY(epIndex >= mEpInfo->epSize)) { + NN_LOG_ERROR("Invalid ep index " << epIndex << " over ep size " + << mEpInfo->epSize); + return SER_INVALID_PARAM; + } + + if (NN_UNLIKELY(mEpInfo->epState[epIndex].Compare(SER_EP_BROKEN))) { + NN_LOG_ERROR("Ep broken of channel id " + << mOptions.id << " , select response ep fail"); + return SER_NOT_ESTABLISHED; + } + + ep = mEpInfo->epArr[epIndex]; + if (NN_UNLIKELY(ep == nullptr)) { + NN_LOG_ERROR("Channel Id " << mOptions.id << " ep invalid"); + return SER_NOT_ESTABLISHED; + } + return SER_OK; +} + +SerResult HcomChannelImp::PrepareTimerContext(Callback *cb, int16_t timeout, TimerCtx &context) +{ + auto timerPtr = mCtxStore->GetCtxObj(); + if (NN_UNLIKELY(timerPtr == nullptr)) { + NN_LOG_ERROR("Failed to get context object from memory pool."); + return SER_NEW_OBJECT_FAILED; + } + + context.timer = new (timerPtr)HcomServiceTimer(this, mCtxStore, + timeout, reinterpret_cast(cb), HcomAsyncCBType::CBS_IO); + NResult ret = mCtxStore->PutAndGetSeqNo(context.timer, context.seqNo); + if (NN_UNLIKELY(ret != SER_OK)) { + NN_LOG_ERROR("Failed to generate seqNo by context store pool."); + mCtxStore->Return(timerPtr); + return SER_NEW_OBJECT_FAILED; + } + + context.timer->IncreaseRef(); + // timer seqNo is invalid, here need update by EmplaceContext() build seqNo. + context.timer->SeqNo(context.seqNo); + + HcomPeriodicManagerPtr periodicMgrPtr = reinterpret_cast(mPeriodicMgr); + ret = periodicMgrPtr->AddTimer(context.timer); + if (NN_UNLIKELY(ret != SER_OK)) { + NN_LOG_ERROR("Failed to add timer in for timeout control."); + context.timer->EraseSeqNo(); + mCtxStore->Return(timerPtr); + return ret; + } + context.timer->IncreaseRef(); + return SER_OK; +} + +void HcomChannelImp::DestroyTimerContext(TimerCtx &context) +{ + // 主动清理 TimerContext,当且仅当发送失败时才会被调用。此时仅标记它为 finished,由于它已经被放 + // 入了超时队列当中,后面由超时线程自动将其删除、回收。如果很不巧,在发生超时时此线程未被调度到, + // 那么定时器会被标记为超时,清理完全由超时线程处理。 + // + // `DeleteCallBack()` 必须要被保护起来,否则可能会发生超时线程先被调度到,之后运行定时器关联的 + // callback 的同时将 callback 删除的极限情况。这时就可能会出现运行时错误了。 + if (NN_LIKELY(context.timer->EraseSeqNoWithRet())) { + context.timer->DeleteCallBack(); + context.timer->MarkFinished(); + context.timer->DecreaseRef(); + } +} + +int32_t HcomChannelImp::Send(const UBSHcomRequest &req, const Callback *done) +{ + NN_LOG_DEBUG("[Request Send] ------ API = HcomChannelImp::Send" << ", channel id = " << mOptions.id << + ", status = " << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::CALLED)); + VALIDATE_PARAM_RET(Request, req); + SerResult result = SER_OK; + uint64_t timestamp = mOptions.twoSideTimeout < 0 ? UINT64_MAX : mOptions.twoSideTimeout + NetMonotonic::TimeSec(); + do { + result = FlowControl(req.size, mOptions.twoSideTimeout, timestamp); + if (NN_UNLIKELY(SER_OK != result)) { + return result; + } + + NetTrace::TraceBegin(CHANNEL_SEND); + result = SendInner(req, done); + NetTrace::TraceEnd(CHANNEL_SEND, result); + if (NN_LIKELY(result == SER_OK)) { + return SER_OK; + } else if (result == SER_NEW_OBJECT_FAILED) { // do later::add retry result code + usleep(100UL); + continue; + } else { + break; + } + } while (NetMonotonic::TimeSec() < timestamp); + + NN_LOG_WARN("Failed to Send, error code: " << result); + return result; +} + +SerResult HcomChannelImp::SendInner(const UBSHcomRequest &req, const Callback *done) +{ + if (done == nullptr) { + return SyncSendInner(req); + } + return AsyncSendInner(req, done); +} + +static void SyncSendCbForWorkerPoll(UBSHcomServiceContext &context, HcomServiceSelfSyncParam *syncParam) +{ + if (NN_UNLIKELY(syncParam == nullptr)) { + NN_LOG_ERROR("Failed to call SyncCallback syncParam is null"); + return; + } + if (NN_UNLIKELY(context.Result() != SER_OK)) { + NN_LOG_ERROR("Channel sync send inner callback failed " << context.Result()); + } + syncParam->Result(context.Result()); + syncParam->Signal(); +} + +SerResult HcomChannelImp::SyncSendInner(const UBSHcomRequest &req) +{ + if (mOptions.selfPoll) { + return SyncSendWithSelfPoll(req); + } + + UBSHcomNetEndpoint *ep = nullptr; + SerResult result = NextWorkerPollEp(ep); + if (NN_UNLIKELY(SER_OK != result)) { + return result; + } + + const uint32_t fragmentNum = EstimateFragmentNum(req.size, true); + if (fragmentNum > 1) { + return SyncSendSplitWithWorkerPoll(ep, req, fragmentNum); + } + + HcomServiceSelfSyncParam syncParam {}; + Callback *callback = UBSHcomNewCallback(SyncSendCbForWorkerPoll, std::placeholders::_1, &syncParam); + if (NN_UNLIKELY(callback == nullptr)) { + NN_LOG_ERROR("Sync send callback is nullptr"); + return SER_NEW_OBJECT_FAILED; + } + + TimerCtx timerContext {}; + result = PrepareTimerContext(callback, mOptions.twoSideTimeout, timerContext); + if (result != SER_OK) { + delete callback; + return result; + } + + UBSHcomNetTransRequest transReq(req.address, req.size, sizeof(SerTransContext)); + SetServiceTransCtx(transReq.upCtxData, timerContext.seqNo); + uint32_t userSeqNo = timerContext.seqNo; + MarkOpCodeBySeqNo(userSeqNo, NN_NO0, mRespOriginalSeqNo); + UBSHcomNetTransOpInfo transOp(userSeqNo, mOptions.twoSideTimeout); + if (NN_LIKELY(transReq.size >= mRndvThreshold)) { + result = RndvInner(ep, req, transOp, false); + } else { + NN_LOG_DEBUG("[Request Send] ------ channel id=" << mOptions.id << ", ep id=" << ep->Id() << ", seqNo=" << + transOp.seqNo << ", status=" << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_HCOM)); + result = ep->PostSend(req.opcode, transReq, transOp); + } + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync send failed " << result << " ep id " << ep->Id()); + DestroyTimerContext(timerContext); + return result; + } + + syncParam.Wait(); + return syncParam.Result(); +} + +SerResult HcomChannelImp::SyncSendSplitWithWorkerPoll(UBSHcomNetEndpoint *&ep, const UBSHcomRequest &req, + uint32_t fragmentNum) +{ + UBSHcomFragmentHeader extHeader; + extHeader.msgId = {ep->Id(), ep->NextSeq()}; + extHeader.totalLength = req.size; + extHeader.offset = 0; + + HcomServiceSelfSyncParam syncParam{}; + for (uint32_t segIndex = 0; segIndex < fragmentNum; ++segIndex) { + const uint32_t segOffset = segIndex * mUserSplitSendThreshold; + const uint64_t segSize = std::min(mUserSplitSendThreshold, req.size - segOffset); + const uintptr_t segAddr = reinterpret_cast(req.address) + segOffset; + extHeader.offset = segOffset; + + Callback *newCallback = UBSHcomNewCallback( + [segIndex, fragmentNum, &syncParam](UBSHcomServiceContext &context) { + if (NN_UNLIKELY(context.Result() != SER_OK)) { + syncParam.Result(context.Result()); + NN_LOG_ERROR("Channel sync send inner callback failed " << context.Result() << " when sending [" + << (segIndex + 1) << "/" << fragmentNum + << "]"); + } + + if (segIndex == fragmentNum - 1) { + syncParam.Signal(); + } + }, + std::placeholders::_1); + + if (NN_UNLIKELY(!newCallback)) { + NN_LOG_ERROR("Sync send malloc callback failed"); + return SER_NEW_OBJECT_FAILED; + } + + TimerCtx timerContext{}; + SerResult result = PrepareTimerContext(newCallback, mOptions.twoSideTimeout, timerContext); + if (result != SER_OK) { + delete newCallback; + return result; + } + + UBSHcomNetTransRequest transReq(reinterpret_cast(segAddr), segSize, sizeof(SerTransContext)); + SetServiceTransCtx(transReq.upCtxData, timerContext.seqNo); + uint32_t userSeqNo = timerContext.seqNo; + MarkOpCodeBySeqNo(userSeqNo, NN_NO0, mRespOriginalSeqNo); + UBSHcomNetTransOpInfo transOp(userSeqNo, mOptions.twoSideTimeout); + NN_LOG_DEBUG("SyncSendSplitWithWorkerPoll fragment [" + << (segIndex + 1) << "/" << fragmentNum << "] begin; ep id=" << ep->Id() + << ", seqNo=" << transOp.seqNo << ", status=" + << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_HCOM) + << " req.opcode: " << req.opcode); + result = ep->PostSend(req.opcode, transReq, transOp, UBSHcomExtHeaderType::FRAGMENT, &extHeader, + sizeof(extHeader)); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync send failed " << result << " ep id " << ep->Id()); + DestroyTimerContext(timerContext); + return result; + } + + NN_LOG_DEBUG("SyncSendSplitWithWorkerPoll fragment [" << (segIndex + 1) << "/" << fragmentNum << "] end"); + } + + syncParam.Wait(); + return syncParam.Result(); +} + +auto HcomChannelImp::SpliceMessage(const UBSHcomNetRequestContext &ctx, bool isResp) + -> std::tuple +{ + const uintptr_t msgAddr = reinterpret_cast(ctx.Message()->Data()); + const uint32_t msgSize = ctx.Message()->DataLen(); + if (msgSize < sizeof(UBSHcomFragmentHeader)) { + NN_LOG_ERROR("SpliceMessage: message size is invalid, actual: " << msgSize << ", wanted at least " + << sizeof(UBSHcomFragmentHeader)); + return std::make_tuple(SpliceMessageResultType::ERROR, SER_SPLIT_INVALID_MSG, ""); + } + + // 需要拼包情况下必须包含 UBSHcomFragmentHeader 头部,此时的内存布局为: + // | UBSHcomFragmentHeader | payload | + const UBSHcomFragmentHeader *serviceHeader = reinterpret_cast(msgAddr); + const void *payload = reinterpret_cast(msgAddr + sizeof(UBSHcomFragmentHeader)); + const uint64_t payloadLen = msgSize - sizeof(UBSHcomFragmentHeader); + const UBSHcomFragmentMessageId msgId = serviceHeader->msgId; + const uint32_t totalLength = serviceHeader->totalLength; + const uint32_t offset = serviceHeader->offset; + + NN_LOG_DEBUG("SpliceMessage: msgId " << msgId << ", totalLength " << totalLength + << ", offset " << offset << ", size " << payloadLen); + + // 避免因数据在网络中被篡改而造成高内存占用 + if (totalLength >= SERVICE_MAX_TOTAL_LENGTH) { + NN_LOG_ERROR("SpliceMessage: totalLength (" << totalLength << ") is larger than the maximum (" + << SERVICE_MAX_TOTAL_LENGTH << ")"); + return std::make_tuple(SpliceMessageResultType::ERROR, SER_SPLIT_INVALID_MSG, ""); + } + + std::shared_ptr> incompleteMsg; + auto iter = mMsgReceived.end(); + + if (offset == 0) { + // 如果在短时间内 msgId 出现重复,且之前的消息还未超时仍旧存在,那么就会 + // 失败。需要修正一下 msgId 的生成算法。 + bool isInserted = false; + { + std::lock_guard lock(mMsgReceivedMutex); + std::tie(iter, isInserted) = mMsgReceived.emplace(msgId, + std::make_shared>()); + if (NN_LIKELY(isInserted)) { + incompleteMsg = iter->second; + } else { + NN_LOG_WARN("SpliceMessage: duplicate id " << msgId << ", nothing to do."); + } + } + + if (isInserted) { + // 为防止分片无限堆积,如果在有限时间内无法完成拼包则将分片全部丢弃。 + Callback *cb = UBSHcomNewCallback( + [iter](UBSHcomServiceContext &context, NetRef ch) { + // 超时由单独的超时线程处理,可能会并发地进行 iter 删除(此处)与对 iter->second + // 的复制(Worker 线程接收到另一个分片)。为了避免 Worker 线程访问无效内存,使用 + // std::shared_ptr 增加引用计数延长接收 buffer 的生命周期。 + NN_LOG_WARN("SpliceMessage: Timed-out. message can't be spliced in time."); + std::lock_guard lock(ch->mMsgReceivedMutex); + ch->mMsgReceived.erase(iter); + }, + std::placeholders::_1, NetRef{this}); + if (!cb) { + NN_LOG_ERROR("SpliceMessage malloc callback failed"); + return std::make_tuple(SpliceMessageResultType::ERROR, SER_NEW_OBJECT_FAILED, ""); + } + + // 由于默认 NetServiceOpInfo.timeout 初始化为 -1, 用户在创建 OpInfo 时可能会忘记修改,最终会导致定 + // 时器永不超时,与预期不符;同时 NetServiceOpInfo 的另一个构造函数的 timeout 参数默认值为 0, 只要 + // 定时器线程被 OS 调度并处理定时器就会立即超时。这两种情况都会导致此处创建的定时器不起作用。 + if (ctx.Header().timeout <= 0) { + NN_LOG_WARN("SpliceMessage: the timer will not work correctly! Check NetServiceOpInfo.timeout field, " + "current value: " + << ctx.Header().timeout); + } + + TimerCtx context{}; + auto result = PrepareTimerContext(cb, ctx.Header().timeout, context); + if (result != SER_OK) { + NN_LOG_ERROR("Prepare timer context failed when creating timer for SpliceMessage"); + delete cb; + return std::make_tuple(SpliceMessageResultType::ERROR, result, ""); + } + + // 在拼包完成后,通过 seqNo 索引到对应定时器 + incompleteMsg->first = context.seqNo; + // 首包,分配足够大的内存 + incompleteMsg->second.resize(totalLength); + } + } else { + std::lock_guard lock(mMsgReceivedMutex); + + iter = mMsgReceived.find(msgId); + if (NN_LIKELY(iter != mMsgReceived.end())) { + incompleteMsg = iter->second; + } else { + NN_LOG_WARN("SpliceMessage: the first fragment is lost/timed-out? msgId " << msgId); + } + } + + if (!incompleteMsg) { + return std::make_tuple(SpliceMessageResultType::ERROR, SER_ERROR, ""); + } + + auto *pmsg = &incompleteMsg->second; + + // 极小概率出现 msg2 的首包丢失,同时又正好 msgId 相同: + // | msg1 first | last | + // | msg2 first | ... | last | + if (NN_UNLIKELY(offset > pmsg->size())) { + NN_LOG_ERROR("SpliceMessage: the fragment is from another msg."); + return std::make_tuple(SpliceMessageResultType::ERROR, SER_SPLIT_INVALID_MSG, ""); + } + + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(reinterpret_cast(pmsg->data()) + offset), + pmsg->size() - offset, payload, payloadLen) != EOK)) { + NN_LOG_ERROR("SpliceMessage: the payload is too large."); + return std::make_tuple(SpliceMessageResultType::ERROR, SER_SPLIT_INVALID_MSG, ""); + } + + // 最后一个包,拼包结束. 如果 totalLength 在网络传输中被修改,可能永远无法取 + // 等,依赖超时机制将这种异常 Msg 清除。 + if (offset + payloadLen == totalLength) { + NN_LOG_DEBUG("SpliceMessage: complete! id " << msgId); + + // 由 std::shared_ptr 保证,pmsg 一定有效 + std::string msg = std::move(*pmsg); + + HcomServiceTimer* timer = nullptr; + if (NN_UNLIKELY(mCtxStore->GetSeqNoAndRemove(incompleteMsg->first, timer) == SER_OK)) { + timer->MarkFinished(); + timer->DeleteCallBack(); + timer->DecreaseRef(); + + std::lock_guard lock(mMsgReceivedMutex); + mMsgReceived.erase(iter); + } + + return std::make_tuple(SpliceMessageResultType::OK, SER_OK, std::move(msg)); + } + return std::make_tuple(SpliceMessageResultType::INDETERMINATE, SER_OK, ""); +} + +SerResult HcomChannelImp::AsyncSendSplitWithWorkerPoll(UBSHcomNetEndpoint *&ep, const UBSHcomRequest &req, + uint32_t fragmentNum, const Callback *done) +{ + UBSHcomFragmentHeader extHeader; + extHeader.msgId = {ep->Id(), ep->NextSeq()}; + extHeader.totalLength = req.size; + extHeader.offset = 0; + + for (uint32_t segIndex = 0; segIndex < fragmentNum; ++segIndex) { + const uint32_t segOffset = segIndex * mUserSplitSendThreshold; + const uint64_t segSize = std::min(mUserSplitSendThreshold, req.size - segOffset); + const uintptr_t segAddr = reinterpret_cast(req.address) + segOffset; + + extHeader.offset = segOffset; + + UBSHcomNetTransRequest transReq(reinterpret_cast(segAddr), segSize, sizeof(SerTransContext)); + + // 用户回调函数 done 只在最后被调用一次。由于每次都创建了一个新的 + // callback,需要在 PostSend 失败时将 callback 删除。 + Callback *callback = UBSHcomNewCallback( + [segIndex, fragmentNum, done](UBSHcomServiceContext &context) { + NN_LOG_DEBUG("Run CB [" << (segIndex + 1) << "/" << fragmentNum << "], result " << context.Result()); + if (segIndex == fragmentNum - 1) { + const_cast(done)->Run(context); + } + }, + std::placeholders::_1); + if (!callback) { + NN_LOG_ERROR("Async send malloc callback failed"); + return SER_NEW_OBJECT_FAILED; + } + + TimerCtx context{}; + auto result = PrepareTimerContext(callback, mOptions.twoSideTimeout, context); + if (result != SER_OK) { + NN_LOG_ERROR("Prepare timer context failed when sending [" << (segIndex + 1) << "/" << fragmentNum << "]"); + delete callback; + return result; + } + + SetServiceTransCtx(transReq.upCtxData, context.seqNo); + uint32_t newSeqNo = context.seqNo; + MarkOpCodeBySeqNo(newSeqNo, 0); + + UBSHcomNetTransOpInfo transOp(newSeqNo, mOptions.twoSideTimeout); + NN_LOG_DEBUG("AsyncSendSplitWithWorkerPoll fragment [" + << (segIndex + 1) << "/" << fragmentNum << "] begin; ep id=" << ep->Id() + << ", seqNo=" << transOp.seqNo << ", status=" + << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_HCOM)); + result = ep->PostSend(req.opcode, transReq, transOp, UBSHcomExtHeaderType::FRAGMENT, &extHeader, + sizeof(extHeader)); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("AsyncSendSplitWithWorkerPoll Send fragment [" << (segIndex + 1) << "/" << fragmentNum + << "] failed"); + + DestroyTimerContext(context); + return result; + } + NN_LOG_DEBUG("AsyncSendSplitWithWorkerPoll fragment [" << (segIndex + 1) << "/" << fragmentNum << "] end"); + } + + return SER_OK; +} + +SerResult HcomChannelImp::AsyncSendInner(const UBSHcomRequest &req, const Callback *done) +{ + if (mOptions.selfPoll) { + NN_LOG_ERROR("Failed to invoke async send with self poll, not support"); + return SER_INVALID_PARAM; + } + + UBSHcomNetEndpoint *ep = nullptr; + SerResult result = NextWorkerPollEp(ep); + if (NN_UNLIKELY(SER_OK != result)) { + return result; + } + + const uint32_t fragmentNum = EstimateFragmentNum(req.size, true); + if (fragmentNum > 1) { + return AsyncSendSplitWithWorkerPoll(ep, req, fragmentNum, done); + } + + UBSHcomNetTransRequest transReq(req.address, req.size, sizeof(SerTransContext)); + uint32_t newSeqNo = 0; + TimerCtx context {}; + result = PrepareTimerContext(const_cast(done), mOptions.twoSideTimeout, context); + if (result != SER_OK) { + return result; + } + SetServiceTransCtx(transReq.upCtxData, context.seqNo); + + // if rspCtx is valid, seqNo is changed now by mark + newSeqNo = context.seqNo; + MarkOpCodeBySeqNo(newSeqNo, 0); + UBSHcomNetTransOpInfo transOp(newSeqNo, mOptions.twoSideTimeout); + if (NN_LIKELY(transReq.size >= mRndvThreshold)) { + result = RndvInner(ep, req, transOp, false); + } else { + NN_LOG_DEBUG("[Request Send] ------ channel id=" << mOptions.id << ", ep id=" << ep->Id() << ", seqNo=" << + transOp.seqNo << ", status=" << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_HCOM)); + result = ep->PostSend(req.opcode, transReq, transOp); + } + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel async send failed " << result << " ep id " << ep->Id()); + DestroyTimerContext(context); + return result; + } + return result; +} + +SerResult HcomChannelImp::SyncSendSplitWithSelfPoll(UBSHcomNetEndpoint *&ep, const UBSHcomRequest &req, + uint32_t fragmentNum, uint32_t index) +{ + UBSHcomFragmentHeader extHeader; + extHeader.msgId = {ep->Id(), ep->NextSeq()}; + extHeader.totalLength = req.size; + extHeader.offset = 0; + + for (uint32_t segIndex = 0; segIndex < fragmentNum; ++segIndex) { + const uint32_t segOffset = segIndex * mUserSplitSendThreshold; + const uint64_t segSize = std::min(mUserSplitSendThreshold, req.size - segOffset); + const uintptr_t segAddr = reinterpret_cast(req.address) + segOffset; + extHeader.offset = segOffset; + + UBSHcomNetTransRequest transReq(reinterpret_cast(segAddr), segSize, 0); + UBSHcomNetTransOpInfo transOp(SelfPollNextSeqNo(), mOptions.twoSideTimeout); + NN_LOG_DEBUG("SyncSendSplitWithSelfPoll fragment [" + << (segIndex + 1) << "/" << fragmentNum << "] begin; ep id=" << ep->Id() + << ", seqNo=" << transOp.seqNo << ", status=" + << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_HCOM)); + auto result = ep->PostSend(req.opcode, transReq, transOp, UBSHcomExtHeaderType::FRAGMENT, &extHeader, + sizeof(extHeader)); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel SyncSendSplitWithSelfPoll failed " << result << " ep id " << ep->Id() << ", [" + << (segIndex + 1) << "/" << fragmentNum << "]"); + ReleaseSelfPollEp(index); + return result; + } + NN_LOG_DEBUG("SyncSendSplitWithSelfPoll fragment [" << (segIndex + 1) << "/" << fragmentNum << "] end"); + + int32_t timeout = (mOptions.twoSideTimeout == 0 ? -1 : static_cast(mOptions.twoSideTimeout)); + result = ep->WaitCompletion(timeout); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync send split wait complete failed " << result << " ep id " << ep->Id()); + ReleaseSelfPollEp(index); + return result; + } + } + + ReleaseSelfPollEp(index); + NN_LOG_DEBUG("[Request Send] ------ ep id=" << ep->Id() << ", multiple seq no, status=" + << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::SUCCESS)); + return SER_OK; +} + +SerResult HcomChannelImp::SyncSendWithSelfPoll(const UBSHcomRequest &req) +{ + UBSHcomNetEndpoint *ep = nullptr; + uint32_t index = 0; + auto result = AcquireSelfPollEp(ep, index, mOptions.twoSideTimeout); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync send acquire ep failed " << result << " channel id " << mOptions.id); + return result; + } + + const uint32_t fragmentNum = EstimateFragmentNum(req.size); + if (fragmentNum > 1) { + return SyncSendSplitWithSelfPoll(ep, req, fragmentNum, index); + } + + UBSHcomNetTransRequest transReq(req.address, req.size, 0); + UBSHcomNetTransOpInfo transOp(SelfPollNextSeqNo(), mOptions.twoSideTimeout); + NN_LOG_DEBUG("[Request Send] ------ channel id=" << mOptions.id << ", ep id=" << ep->Id() << ", seqNo=" << + transOp.seqNo << ", status=" << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_HCOM)); + result = ep->PostSend(req.opcode, transReq, transOp); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync send failed " << result << " ep id " << ep->Id()); + ReleaseSelfPollEp(index); + return result; + } + + /* timeout = 0 will poll cq empty in self polling */ + int32_t timeout = (mOptions.twoSideTimeout == 0 ? -1 : static_cast(mOptions.twoSideTimeout)); + result = ep->WaitCompletion(timeout); + ReleaseSelfPollEp(index); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync send wait complete failed " << result << " ep id " << ep->Id()); + return result; + } + + return SER_OK; +} + +int32_t HcomChannelImp::Call(const UBSHcomRequest &req, UBSHcomResponse &rsp, const Callback *done) +{ + NN_LOG_DEBUG("[Request Send] ------ API = HcomChannelImp::Call" << ", channel id = " << mOptions.id << + ", status = " << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::CALLED)); + VALIDATE_PARAM(Request, req); + SerResult result = SER_OK; + uint64_t timestamp = mOptions.twoSideTimeout < 0 ? UINT64_MAX : mOptions.twoSideTimeout + NetMonotonic::TimeSec(); + do { + result = FlowControl(req.size, mOptions.twoSideTimeout, timestamp); + if (NN_UNLIKELY(result != SER_OK)) { + return result; + } + result = CallInner(req, rsp, done); + if (NN_LIKELY(result == SER_OK)) { + return SER_OK; + } else if (result == SER_NEW_OBJECT_FAILED) { + usleep(100UL); + continue; + } else { + break; + } + } while (NetMonotonic::TimeSec() < timestamp); + + NN_LOG_ERROR("Failed to sync call " << result); + return result; +} + +SerResult HcomChannelImp::CallInner(const UBSHcomRequest &req, UBSHcomResponse &rsp, const Callback *done) +{ + if (done == nullptr) { + return SyncCallInner(req, rsp); + } + return AsyncCallInner(req, done); +} + +NResult HcomChannelImp::SendFds(int fds[], uint32_t len) +{ + NN_ASSERT_LOG_RETURN(mEpInfo != nullptr, SER_ERROR) + NN_ASSERT_LOG_RETURN(mEpInfo->epArr[0] != nullptr, SER_ERROR) + return mEpInfo->epArr[0]->SendFds(fds, len); +} + +NResult HcomChannelImp::ReceiveFds(int fds[], uint32_t len, int32_t timeoutSec) +{ + NN_ASSERT_LOG_RETURN(mEpInfo != nullptr, SER_ERROR) + NN_ASSERT_LOG_RETURN(mEpInfo->epArr[0] != nullptr, SER_ERROR) + return mEpInfo->epArr[0]->ReceiveFds(fds, len, timeoutSec); +} + +static void SyncCallCbForWorkerPoll(UBSHcomServiceContext &context, UBSHcomResponse *rsp, + HcomServiceSelfSyncParam *syncParam) +{ + if (NN_UNLIKELY(rsp == nullptr || syncParam == nullptr)) { + NN_LOG_ERROR("Failed to call SyncCallback as rspOpInfo, rsp or syncParam is null"); + return; + } + HcomServiceMessage message(context.MessageData(), context.MessageDataLen()); + syncParam->Result(SER_OK); + + do { + rsp->errorCode = context.ErrorCode(); + if (NN_UNLIKELY(context.Result() != SER_OK)) { + NN_LOG_ERROR("Sync call result " << context.Result() << " error"); + syncParam->Result(context.Result()); + break; + } + + if (rsp->address != nullptr) { + if (NN_UNLIKELY(message.size > rsp->size)) { + NN_LOG_ERROR("Sync call check user prepare size " << rsp->size << " less than receive size " << + message.size); + syncParam->Result(SER_RSP_SIZE_TOO_SMALL); + break; + } + if (NN_UNLIKELY(memcpy_s(rsp->address, rsp->size, message.data, message.size) != SER_OK)) { + NN_LOG_ERROR("Sync call failed to copy data"); + syncParam->Result(SER_INVALID_PARAM); + break; + } + } else { + rsp->address = malloc(message.size); + if (rsp->address == nullptr) { + NN_LOG_ERROR("Sync call malloc data size " << message.size << " failed"); + syncParam->Result(SER_NEW_MESSAGE_DATA_FAILED); + break; + } + if (NN_UNLIKELY(memcpy_s(rsp->address, message.size, message.data, message.size) != SER_OK)) { + free(rsp->address); + rsp->address = nullptr; + NN_LOG_ERROR("Sync call failed to copy data"); + syncParam->Result(SER_INVALID_PARAM); + break; + } + } + rsp->size = message.size; + } while (false); + + syncParam->Signal(); +} + +SerResult HcomChannelImp::SyncCallInner(const UBSHcomRequest &req, UBSHcomResponse &rsp, uint32_t timeOut) +{ + if (mOptions.selfPoll) { + return SyncCallWithSelfPoll(req, rsp); + } + + UBSHcomNetEndpoint *ep = nullptr; + auto result = NextWorkerPollEp(ep); + if (NN_UNLIKELY(result != SER_OK)) { + return result; + } + + const uint32_t fragmentNum = EstimateFragmentNum(req.size, true); + if (fragmentNum > 1) { + return SyncCallSplitWithWorkerPoll(ep, req, fragmentNum, rsp); + } + + /* worker poll mode */ + HcomServiceSelfSyncParam syncParam {}; + Callback *newCallback = UBSHcomNewCallback(SyncCallCbForWorkerPoll, std::placeholders::_1, &rsp, &syncParam); + if (NN_UNLIKELY(newCallback == nullptr)) { + NN_LOG_ERROR("Sync call malloc callback failed"); + return SER_NEW_OBJECT_FAILED; + } + + TimerCtx context {}; + result = PrepareTimerContext(newCallback, timeOut == 0 ? mOptions.twoSideTimeout : timeOut, context); + if (result != SER_OK) { + delete newCallback; + return result; + } + + UBSHcomNetTransRequest transReq(req.address, req.size, sizeof(SerTransContext)); + SetServiceTransCtx(transReq.upCtxData, context.seqNo, false); + + MarkOpCodeBySeqNo(context.seqNo, 0); + UBSHcomNetTransOpInfo transOp(context.seqNo, mOptions.twoSideTimeout); + if (NN_LIKELY(transReq.size >= mRndvThreshold)) { + result = RndvInner(ep, req, transOp, true); + } else { + result = ep->PostSend(req.opcode, transReq, transOp); + } + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync call send failed " << result << " ep id " << ep->Id()); + DestroyTimerContext(context); + return result; + } + + syncParam.Wait(); + return syncParam.Result(); +} + +SerResult HcomChannelImp::SyncCallSplitWithWorkerPoll(UBSHcomNetEndpoint *&ep, const UBSHcomRequest &req, + uint32_t fragmentNum, UBSHcomResponse &rsp) +{ + UBSHcomFragmentHeader extHeader; + extHeader.msgId = {ep->Id(), ep->NextSeq()}; + extHeader.totalLength = req.size; + extHeader.offset = 0; + + HcomServiceSelfSyncParam syncParam{}; + for (uint32_t segIndex = 0; segIndex < fragmentNum; ++segIndex) { + const uint32_t segOffset = segIndex * mUserSplitSendThreshold; + const uint64_t segSize = std::min(mUserSplitSendThreshold, req.size - segOffset); + const uintptr_t segAddr = reinterpret_cast(req.address) + segOffset; + extHeader.offset = segOffset; + + Callback *newCallback = UBSHcomNewCallback(SyncCallCbForWorkerPoll, std::placeholders::_1, &rsp, &syncParam); + if (NN_UNLIKELY(newCallback == nullptr)) { + NN_LOG_ERROR("Sync call split malloc callback failed"); + return SER_NEW_OBJECT_FAILED; + } + + Callback *cb = UBSHcomNewCallback( + [segIndex, fragmentNum, newCallback](UBSHcomServiceContext &context) { + NN_LOG_DEBUG("Run CB [" << (segIndex + 1) << "/" << fragmentNum << "], result " << context.Result()); + if (segIndex == fragmentNum - 1) { + const_cast(newCallback)->Run(context); + } + }, + std::placeholders::_1); + if (!cb) { + NN_LOG_ERROR("Sync call split malloc callback failed"); + return SER_NEW_OBJECT_FAILED; + } + + TimerCtx context{}; + auto result = PrepareTimerContext(cb, mOptions.twoSideTimeout, context); + if (result != SER_OK) { + NN_LOG_ERROR("Prepare timer context failed when sending [" << (segIndex + 1) << "/" << fragmentNum << "]"); + delete cb; + return result; + } + + UBSHcomNetTransRequest transReq(reinterpret_cast(segAddr), segSize, sizeof(SerTransContext)); + SetServiceTransCtx(transReq.upCtxData, context.seqNo, segIndex != fragmentNum - 1); + + uint32_t newSeqNo = context.seqNo; + MarkOpCodeBySeqNo(newSeqNo, 0); + UBSHcomNetTransOpInfo transOp(newSeqNo, mOptions.twoSideTimeout); + NN_LOG_DEBUG("SyncCallSplitWithWorkerPoll fragment [" + << (segIndex + 1) << "/" << fragmentNum << "] begin;ep id=" << ep->Id() + << ", seqNo=" << transOp.seqNo << " ,opCode = " << req.opcode + << ", status=" << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_HCOM)); + result = ep->PostSend(req.opcode, transReq, transOp, UBSHcomExtHeaderType::FRAGMENT, &extHeader, + sizeof(extHeader)); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("SyncCallSplitWithWorkerPoll Send fragment [" << (segIndex + 1) << "/" << fragmentNum + << "] failed"); + DestroyTimerContext(context); + return result; + } + NN_LOG_DEBUG("SyncCallSplitWithWorkerPoll fragment [" << (segIndex + 1) << "/" << fragmentNum << "] end"); + } + + syncParam.Wait(); + return syncParam.Result(); +} + +SerResult HcomChannelImp::RndvInner(UBSHcomNetEndpoint *ep, const UBSHcomRequest &req, + UBSHcomNetTransOpInfo &transOp, bool isCall) +{ + SerResult result = SER_OK; + PgTable *pgTable = reinterpret_cast(mPgtable); + // pgTable 根据地址查询start addr和end addr + PgtAddress add = reinterpret_cast(req.address); + PgtAddress reqEndAdd = reinterpret_cast(req.address) + req.size - NN_NO1; + + PgtRegion *pgtRegion = pgTable->Lookup(add); + if (pgtRegion == nullptr || !(pgtRegion->start <= add && reqEndAdd < pgtRegion->end)) { + NN_LOG_WARN("Unable to lookUp address in pgTable or req address is out of range, so not use rndv send "); + UBSHcomNetTransRequest transReq(req.address, req.size, sizeof(SerTransContext)); + SetServiceTransCtx(transReq.upCtxData, transOp.seqNo, !isCall); + result = ep->PostSend(req.opcode, transReq, transOp); + } else { + UBSHcomRequest newReq{}; + newReq.address = req.address; + newReq.size = req.size; + newReq.opcode = req.opcode; + // 根据起始地址查找lKey + newReq.key = pgtRegion->key; + + HcomServiceRndvMessage rndvMessage(mConnectTimestamp.GetRemoteTimestamp(mOptions.twoSideTimeout), req); + UBSHcomNetTransRequest transReq(const_cast(reinterpret_cast(&rndvMessage)), + sizeof(HcomServiceRndvMessage), sizeof(SerTransContext)); + // RNDV请求 对端必须reply 不区分send和Call + SetServiceTransCtx(transReq.upCtxData, transOp.seqNo, false); + result = ep->PostSend(ServiceV2PrivateOpcode::RNDV_CALL_OP_V2, transReq, transOp); + } + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync call Rndv send failed " << result << " ep id " << ep->Id()); + } + return result; +} + +static SerResult SyncCallCbForSelfPoll(UBSHcomNetResponseContext &ctx, UBSHcomResponse &rsp) +{ + void *data = ctx.Message()->Data(); + uint32_t dataLength = ctx.Message()->DataLen(); + + UBSHcomNetTransHeader header = ctx.Header(); + rsp.errorCode = header.errorCode; + + if (rsp.address != nullptr) { + if (dataLength <= rsp.size) { + if (NN_UNLIKELY(memcpy_s(rsp.address, rsp.size, data, dataLength) != SER_OK)) { + NN_LOG_ERROR("Failed to copy data"); + return SER_INVALID_PARAM; + } + } else { + NN_LOG_ERROR("Sync call self poll check user prepare size " << rsp.size << " less than receive size " << + dataLength); + return SER_RSP_SIZE_TOO_SMALL; + } + } else { + rsp.address = malloc(dataLength); + if (rsp.address == nullptr) { + NN_LOG_ERROR("Sync call self poll malloc data size " << dataLength << " failed"); + return SER_NEW_MESSAGE_DATA_FAILED; + } + if (NN_UNLIKELY(memcpy_s(rsp.address, dataLength, data, dataLength) != SER_OK)) { + free(rsp.address); + rsp.address = nullptr; + NN_LOG_ERROR("Failed to sync callback by copy data err"); + return SER_INVALID_PARAM; + } + } + rsp.size = dataLength; + return SER_OK; +} + +SerResult HcomChannelImp::SyncCallWithSelfPoll(const UBSHcomRequest &req, UBSHcomResponse &rsp) +{ + UBSHcomNetEndpoint *ep = nullptr; + uint32_t index = 0; + auto ret = AcquireSelfPollEp(ep, index, mOptions.twoSideTimeout); + if (NN_UNLIKELY(ret != SER_OK)) { + NN_LOG_ERROR("Channel sync call acquire ep failed " << ret << " channel id " << mOptions.id); + return ret; + } + + const uint32_t fragmentNum = EstimateFragmentNum(req.size); + if (fragmentNum > 1) { + return SyncCallSplitWithSelfPoll(ep, req, fragmentNum, index, rsp); + } + + UBSHcomNetTransRequest transReq(req.address, req.size, 0); + UBSHcomNetTransOpInfo transOp(SelfPollNextSeqNo(), mOptions.twoSideTimeout); + ret = ep->PostSend(req.opcode, transReq, transOp); + if (NN_UNLIKELY(ret != SER_OK)) { + NN_LOG_ERROR("Channel sync call failed " << ret << " ep id " << ep->Id()); + ReleaseSelfPollEp(index); + return ret; + } + + /* timeout = 0 will poll cq empty in self polling */ + int32_t timeout = (mOptions.twoSideTimeout == 0 ? -1 : static_cast(mOptions.twoSideTimeout)); + ret = ep->WaitCompletion(timeout); + if (NN_UNLIKELY(ret != SER_OK)) { + NN_LOG_ERROR("Channel sync call wait complete failed " << ret << " ep id " << ep->Id()); + ReleaseSelfPollEp(index); + return ret; + } + + UBSHcomNetResponseContext ctx; + ret = ep->Receive(timeout, ctx); + if (NN_UNLIKELY(ret != SER_OK)) { + NN_LOG_ERROR("Channel sync call receive failed " << ret << " ep id " << ep->Id()); + ReleaseSelfPollEp(index); + return ret; + } + + ret = SyncCallCbForSelfPoll(ctx, rsp); + ReleaseSelfPollEp(index); + if (NN_UNLIKELY(ret != SER_OK)) { + return ret; + } + + return SER_OK; +} + +static SerResult SyncCallbackWithSelfPoll(void *data, uint32_t dataLen, const UBSHcomNetTransHeader &header, + UBSHcomResponse &rsp) +{ + rsp.errorCode = header.errorCode; + if (rsp.address != nullptr) { + if (dataLen <= rsp.size) { + if (NN_UNLIKELY(memcpy_s(rsp.address, rsp.size, data, dataLen) != SER_OK)) { + NN_LOG_ERROR("Failed to copy data"); + return SER_INVALID_PARAM; + } + } else { + NN_LOG_ERROR("Sync call self poll check user prepare size " << rsp.size << " less than receive size " << + dataLen); + return SER_RSP_SIZE_TOO_SMALL; + } + } else { + rsp.address = malloc(dataLen); + if (rsp.address == nullptr) { + NN_LOG_ERROR("Sync call self poll malloc data size " << dataLen << " failed"); + return SER_NEW_MESSAGE_DATA_FAILED; + } + if (NN_UNLIKELY(memcpy_s(rsp.address, dataLen, data, dataLen) != SER_OK)) { + free(rsp.address); + rsp.address = nullptr; + NN_LOG_ERROR("Failed to sync callback by copy data err"); + return SER_INVALID_PARAM; + } + } + rsp.size = dataLen; + return SER_OK; +} + +SerResult HcomChannelImp::SyncCallSplitWithSelfPoll(UBSHcomNetEndpoint *&ep, const UBSHcomRequest &req, + uint32_t fragmentNum, uint32_t index, UBSHcomResponse &rsp) +{ + UBSHcomFragmentHeader extHeader; + extHeader.msgId = {ep->Id(), ep->NextSeq()}; + extHeader.totalLength = req.size; + extHeader.offset = 0; + + const int32_t timeout = (mOptions.twoSideTimeout == 0 ? -1 : static_cast(mOptions.twoSideTimeout)); + for (uint32_t segIndex = 0; segIndex < fragmentNum; ++segIndex) { + const uint32_t segOffset = segIndex * mUserSplitSendThreshold; + const uint64_t segSize = std::min(mUserSplitSendThreshold, req.size - segOffset); + const uintptr_t segAddr = reinterpret_cast(req.address) + segOffset; + extHeader.offset = segOffset; + + UBSHcomNetTransRequest msg(reinterpret_cast(segAddr), segSize, 0); + UBSHcomNetTransOpInfo transOp(SelfPollNextSeqNo(), mOptions.twoSideTimeout); + NN_LOG_DEBUG("SyncCallSplitWithSelfPoll fragment [" + << (segIndex + 1) << "/" << fragmentNum << "] begin;ep id=" << ep->Id() + << ", seqNo=" << transOp.seqNo << ", opCode = " << req.opcode + << ", status=" << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_HCOM)); + auto result = ep->PostSend(req.opcode, msg, transOp, UBSHcomExtHeaderType::FRAGMENT, &extHeader, + sizeof(extHeader)); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel SyncCallSplitWithSelfPoll failed " << result << " ep id " << ep->Id() << ", [" + << (segIndex + 1) << "/" << fragmentNum << "]"); + return result; + } + NN_LOG_DEBUG("SyncCallSplitWithSelfPoll fragment [" << (segIndex + 1) << "/" << fragmentNum << "] end"); + + result = ep->WaitCompletion(timeout); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync call split wait complete failed " << result << " ep id " << ep->Id()); + return result; + } + } + + // 同步方式拼包 + std::string acc; + void *data; + uint32_t dataLen; + UBSHcomNetResponseContext ctx; + auto result = SyncSpliceMessage(ctx, ep, timeout, acc, data, dataLen); + if (NN_UNLIKELY(result != SER_OK)) { + return result; + } + + // 对端回复的每一个小包,它们的 Header() 都是相同的. + result = SyncCallbackWithSelfPoll(data, dataLen, ctx.Header(), rsp); + ReleaseSelfPollEp(index); + if (NN_UNLIKELY(result != SER_OK)) { + return result; + } + + return SER_OK; +} + +SerResult HcomChannelImp::AsyncCallInner(const UBSHcomRequest &req, const Callback *done) +{ + if (mOptions.selfPoll) { + NN_LOG_ERROR("Failed to invoke async call with self poll, not support"); + return SER_INVALID_PARAM; + } + + UBSHcomNetEndpoint *ep = nullptr; + auto result = NextWorkerPollEp(ep); + if (NN_UNLIKELY(result != SER_OK)) { + return result; + } + + const uint32_t fragmentNum = EstimateFragmentNum(req.size, true); + if (fragmentNum > 1) { + return AsyncCallSplitWithWorkerPoll(ep, req, fragmentNum, done); + } + + TimerCtx context {}; + result = PrepareTimerContext(const_cast(done), mOptions.twoSideTimeout, context); + if (result != SER_OK) { + return result; + } + + UBSHcomNetTransRequest transReq(req.address, req.size, sizeof(SerTransContext)); + SetServiceTransCtx(transReq.upCtxData, context.seqNo, false); + + MarkOpCodeBySeqNo(context.seqNo, 0); + UBSHcomNetTransOpInfo transOp(context.seqNo, mOptions.twoSideTimeout); + if (NN_LIKELY(transReq.size >= mRndvThreshold)) { + result = RndvInner(ep, req, transOp, true); + } else { + result = ep->PostSend(req.opcode, transReq, transOp); + } + + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel async call send failed " << result << " ep id " << ep->Id()); + DestroyTimerContext(context); + return result; + } + return SER_OK; +} + +SerResult HcomChannelImp::AsyncCallSplitWithWorkerPoll(UBSHcomNetEndpoint *&ep, const UBSHcomRequest &req, + uint32_t fragmentNum, const Callback *done) +{ + UBSHcomFragmentHeader extHeader; + extHeader.msgId = {ep->Id(), ep->NextSeq()}; + extHeader.totalLength = req.size; + extHeader.offset = 0; + + for (uint32_t segIndex = 0; segIndex < fragmentNum; ++segIndex) { + const uint32_t segOffset = segIndex * mUserSplitSendThreshold; + const uint64_t segSize = std::min(mUserSplitSendThreshold, req.size - segOffset); + const uintptr_t segAddr = reinterpret_cast(req.address) + segOffset; + extHeader.offset = segOffset; + + Callback *cb = UBSHcomNewCallback( + [segIndex, fragmentNum, done](UBSHcomServiceContext &context) { + NN_LOG_DEBUG("Run CB [" << (segIndex + 1) << "/" << fragmentNum << "], result " + << context.Result()); + if (segIndex == fragmentNum - 1) { + const_cast(done)->Run(context); + } + }, + std::placeholders::_1); + if (!cb) { + NN_LOG_ERROR("AsyncCallInner malloc callback failed"); + return SER_NEW_OBJECT_FAILED; + } + + TimerCtx context{}; + auto result = PrepareTimerContext(cb, mOptions.twoSideTimeout, context); + if (result != SER_OK) { + NN_LOG_ERROR("Prepare timer context failed when sending [" << (segIndex + 1) << "/" << fragmentNum << "]"); + delete cb; + return result; + } + + UBSHcomNetTransRequest transReq(reinterpret_cast(segAddr), segSize, sizeof(SerTransContext)); + SetServiceTransCtx(transReq.upCtxData, context.seqNo, segIndex != fragmentNum - 1); + + uint32_t newSeqNo = context.seqNo; + MarkOpCodeBySeqNo(newSeqNo, 0); + UBSHcomNetTransOpInfo transOp(newSeqNo, mOptions.twoSideTimeout); + NN_LOG_DEBUG("AsyncCallSplitWithWorkerPoll fragment [" + << (segIndex + 1) << "/" << fragmentNum << "] begin;ep id=" << ep->Id() + << ", seqNo=" << transOp.seqNo << " ,opCode = " << req.opcode + << ", status=" << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_HCOM)); + result = ep->PostSend(req.opcode, transReq, transOp, UBSHcomExtHeaderType::FRAGMENT, &extHeader, + sizeof(extHeader)); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("AsyncCallSplitWithWorkerPoll Send fragment [" << (segIndex + 1) << "/" << fragmentNum + << "] failed"); + DestroyTimerContext(context); + return result; + } + NN_LOG_DEBUG("AsyncCallSplitWithWorkerPoll fragment [" << (segIndex + 1) << "/" << fragmentNum << "] end"); + } + + return SER_OK; +} + +int32_t HcomChannelImp::Reply(const UBSHcomReplyContext &ctx, const UBSHcomRequest &req, const Callback *done) +{ + NN_LOG_DEBUG("[Request Send] ------ API = HcomChannelImp::Reply" << ", channel id = " << mOptions.id << + ", status = " << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::CALLED)); + VALIDATE_PARAM(Reply, ctx, req, mOptions.selfPoll); + SerResult ret = SER_OK; + uint64_t timestamp = mOptions.twoSideTimeout < 0 ? UINT64_MAX : mOptions.twoSideTimeout + NetMonotonic::TimeSec(); + do { + ret = FlowControl(req.size, mOptions.twoSideTimeout, timestamp); + if (NN_UNLIKELY(ret != SER_OK)) { + return ret; + } + ret = ReplyInner(ctx, req, done); + if (NN_LIKELY(ret == SER_OK)) { + return SER_OK; + } else if (ret == SER_NEW_OBJECT_FAILED) { // do later::add retry result code + usleep(100UL); + continue; + } else { + break; + } + } while (NetMonotonic::TimeSec() < timestamp); + NN_LOG_WARN("Failed to reply, error code: " << ret); + return ret; +} + +SerResult HcomChannelImp::ReplyInner(const UBSHcomReplyContext &ctx, const UBSHcomRequest &req, const Callback *done) +{ + if (done == nullptr) { + return SyncReplyInner(ctx, req); + } + return AsyncReplyInner(ctx, req, done); +} + +SerResult HcomChannelImp::SyncReplyInner(const UBSHcomReplyContext &ctx, const UBSHcomRequest &req) +{ + SerResult res = SER_OK; + UBSHcomNetEndpoint *ep = nullptr; + res = ResponseWorkerPollEp(ctx.rspCtx, ep); + if (NN_UNLIKELY(res != SER_OK)) { + NN_LOG_ERROR("Failed to select ep " << res); + return res; + } + + const uint32_t fragmentNum = EstimateFragmentNum(req.size); + if (fragmentNum > 1) { + return SyncReplySplitWithWorkerPoll(ctx, ep, req, fragmentNum); + } + + HcomServiceSelfSyncParam syncParam {}; + Callback *newCallback = UBSHcomNewCallback( + [&syncParam](UBSHcomServiceContext &context) { + if (NN_UNLIKELY(context.Result() != SER_OK)) { + NN_LOG_WARN("Channel sync reply inner callback failed " << context.Result()); + } + syncParam.Result(context.Result()); + syncParam.Signal(); + }, + std::placeholders::_1); + if (NN_UNLIKELY(newCallback == nullptr)) { + NN_LOG_ERROR("Sync send callback is nullptr"); + return SER_NEW_OBJECT_FAILED; + } + + TimerCtx context {}; + auto result = PrepareTimerContext(newCallback, mOptions.twoSideTimeout, context); + if (result != SER_OK) { + delete newCallback; + return result; + } + + UBSHcomNetTransRequest transReq(req.address, req.size, sizeof(SerTransContext)); + SetServiceTransCtx(transReq.upCtxData, context.seqNo); + + uint32_t userSeqNo = context.seqNo; + MarkOpCodeBySeqNo(userSeqNo, ctx.rspCtx, mRespOriginalSeqNo); + UBSHcomNetTransOpInfo transOp(userSeqNo, mOptions.twoSideTimeout, ctx.errorCode, 0); + result = ep->PostSend(req.opcode, transReq, transOp); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync send failed " << result << " ep id " << ep->Id()); + DestroyTimerContext(context); + return result; + } + + syncParam.Wait(); + return syncParam.Result(); +} + +SerResult HcomChannelImp::SyncReplySplitWithWorkerPoll(const UBSHcomReplyContext &ctx, UBSHcomNetEndpoint *&ep, + const UBSHcomRequest &req, uint32_t fragmentNum) +{ + UBSHcomFragmentHeader extHeader; + extHeader.msgId = {ep->Id(), ep->NextSeq()}; + extHeader.totalLength = req.size; + extHeader.offset = 0; + + HcomServiceSelfSyncParam syncParam{}; + for (uint32_t segIndex = 0; segIndex < fragmentNum; ++segIndex) { + const uint32_t segOffset = segIndex * mUserSplitSendThreshold; + const uint64_t segSize = std::min(mUserSplitSendThreshold, req.size - segOffset); + const uintptr_t segAddr = reinterpret_cast(req.address) + segOffset; + extHeader.offset = segOffset; + + Callback *newCallback = UBSHcomNewCallback( + [segIndex, fragmentNum, &syncParam](UBSHcomServiceContext &context) { + if (NN_UNLIKELY(context.Result() != SER_OK)) { + syncParam.Result(context.Result()); + NN_LOG_ERROR("Channel sync reply inner callback failed " << context.Result() << " when sending [" + << (segIndex + 1) << "/" << fragmentNum + << "]"); + } + + if (segIndex == fragmentNum - 1) { + syncParam.Signal(); + } + }, + std::placeholders::_1); + + if (NN_UNLIKELY(!newCallback)) { + NN_LOG_ERROR("Sync reply malloc callback failed"); + return SER_NEW_OBJECT_FAILED; + } + + TimerCtx context{}; + SerResult result = PrepareTimerContext(newCallback, mOptions.twoSideTimeout, context); + if (result != SER_OK) { + delete newCallback; + return result; + } + + UBSHcomNetTransRequest transReq(reinterpret_cast(segAddr), segSize, sizeof(SerTransContext)); + SetServiceTransCtx(transReq.upCtxData, context.seqNo); + uint32_t userSeqNo = context.seqNo; + MarkOpCodeBySeqNo(userSeqNo, ctx.rspCtx, mRespOriginalSeqNo); + UBSHcomNetTransOpInfo transOp(userSeqNo, mOptions.twoSideTimeout, ctx.errorCode, 0); + NN_LOG_DEBUG("SyncReplySplitWithWorkerPoll fragment [" + << (segIndex + 1) << "/" << fragmentNum << "] begin; ep id=" << ep->Id() + << ", seqNo=" << transOp.seqNo << ", status=" + << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_HCOM) + << " req.opcode: " << req.opcode); + result = ep->PostSend(req.opcode, transReq, transOp, UBSHcomExtHeaderType::FRAGMENT, &extHeader, + sizeof(extHeader)); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync reply failed " << result << " ep id " << ep->Id()); + DestroyTimerContext(context); + return result; + } + + NN_LOG_DEBUG("SyncReplySplitWithWorkerPoll fragment [" << (segIndex + 1) << "/" << fragmentNum << "] end"); + } + + syncParam.Wait(); + return syncParam.Result(); +} + +SerResult HcomChannelImp::AsyncReplyInner(const UBSHcomReplyContext &ctx, const UBSHcomRequest &req, + const Callback *done) +{ + SerResult res = SER_OK; + UBSHcomNetEndpoint *ep = nullptr; + res = ResponseWorkerPollEp(ctx.rspCtx, ep); + if (NN_UNLIKELY(res != SER_OK)) { + NN_LOG_ERROR("Failed to select ep " << res); + return res; + } + + const uint32_t fragmentNum = EstimateFragmentNum(req.size); + if (fragmentNum > 1) { + return AsyncReplySplitWithWorkerPoll(ctx, ep, req, fragmentNum, done); + } + + UBSHcomNetTransRequest transReq(req.address, req.size, sizeof(SerTransContext)); + uint32_t newSeqNo = 0; + SetServiceTransCtx(transReq.upCtxData, const_cast(done)); + MarkOpCodeBySeqNo(newSeqNo, ctx.rspCtx, mRespOriginalSeqNo); + UBSHcomNetTransOpInfo transOp(newSeqNo, mOptions.twoSideTimeout, ctx.errorCode, 0); + return ep->PostSend(req.opcode, transReq, transOp); +} + +SerResult HcomChannelImp::AsyncReplySplitWithWorkerPoll(const UBSHcomReplyContext &ctx, UBSHcomNetEndpoint *&ep, + const UBSHcomRequest &req, uint32_t fragmentNum, const Callback *done) +{ + UBSHcomFragmentHeader extHeader; + extHeader.msgId = {ep->Id(), ep->NextSeq()}; + extHeader.totalLength = req.size; + extHeader.offset = 0; + + for (uint32_t segIndex = 0; segIndex < fragmentNum; ++segIndex) { + const uint32_t segOffset = segIndex * mUserSplitSendThreshold; + const uint64_t segSize = std::min(mUserSplitSendThreshold, req.size - segOffset); + const uintptr_t segAddr = reinterpret_cast(req.address) + segOffset; + extHeader.offset = segOffset; + + Callback *cb = UBSHcomNewCallback( + [segIndex, fragmentNum, done](UBSHcomServiceContext &context) { + NN_LOG_DEBUG("Run CB [" << (segIndex + 1) << "/" << fragmentNum << "], result " << context.Result()); + if (segIndex == fragmentNum - 1) { + const_cast(done)->Run(context); + } + }, + std::placeholders::_1); + if (!cb) { + NN_LOG_ERROR("Async send malloc callback failed"); + return SER_NEW_OBJECT_FAILED; + } + + TimerCtx context{}; + auto result = PrepareTimerContext(cb, mOptions.twoSideTimeout, context); + if (result != SER_OK) { + NN_LOG_ERROR("Prepare timer context failed when sending [" << (segIndex + 1) << "/" << fragmentNum << "]"); + delete cb; + return result; + } + + UBSHcomNetTransRequest transReq(reinterpret_cast(segAddr), segSize, sizeof(SerTransContext)); + SetServiceTransCtx(transReq.upCtxData, context.seqNo); + uint32_t newSeqNo = 0; + MarkOpCodeBySeqNo(newSeqNo, ctx.rspCtx, mRespOriginalSeqNo); + UBSHcomNetTransOpInfo transOp(newSeqNo, mOptions.twoSideTimeout, ctx.errorCode, 0); + NN_LOG_DEBUG("AsyncReplySplitWithWorkerPoll fragment [" + << (segIndex + 1) << "/" << fragmentNum << "] begin; ep id=" << ep->Id() + << ", seqNo=" << transOp.seqNo << ", status=" + << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_HCOM)); + result = ep->PostSend(req.opcode, transReq, transOp, UBSHcomExtHeaderType::FRAGMENT, &extHeader, + sizeof(extHeader)); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("AsyncReplySplitWithWorkerPoll Send fragment [" << (segIndex + 1) << "/" << fragmentNum + << "] failed"); + + DestroyTimerContext(context); + return result; + } + NN_LOG_DEBUG("AsyncReplySplitWithWorkerPoll fragment [" << (segIndex + 1) << "/" << fragmentNum << "] end"); + } + + return SER_OK; +} + +SerResult HcomChannelImp::OneSideSyncWithSelfPoll(const UBSHcomOneSideRequest &request, bool isWrite) +{ + SerResult result = SER_OK; + uint32_t size = request.size; + uint32_t offset = 0; + uint32_t remain = request.size; + uint16_t multiNum = (mOptions.enableMultiRail && request.size > mOptions.multiRailThresh) ? mDriverNum : 1; + if (mOptions.enableMultiRail && request.size > mOptions.multiRailThresh) { + NN_LOG_INFO("Multirail not supported in oneside sync with self poll, using single rail."); + } + for (uint32_t i = 0; i < multiNum; i++) { + UBSHcomNetEndpoint *ep = nullptr; + uint32_t index = 0; + result = AcquireSelfPollEp(ep, index, mOptions.oneSideTimeout, i); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync read acquire ep failed " << result << " channel id " << mOptions.id << + " in rail " << i); + return result; + } + + CalculateOffsetAndSize(request, ep, remain, offset, size); + UBSHcomNetTransRequest req(request.lAddress + offset, request.rAddress + offset, + request.lKey.keys[ep->GetDevIndex()], request.rKey.keys[ep->GetPeerDevIndex()], size, 0); + req.srcSeg = reinterpret_cast(request.lKey.tokens[ep->GetDevIndex()]); + if (isWrite) { + result = ep->PostWrite(req); + } else { + result = ep->PostRead(req); + } + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync read failed " << result << " ep id " << ep->Id() << " in rail " << i); + ReleaseSelfPollEp(index); + return result; + } + /* The PostRead operation uses a thread-local variable to record the RDMA context for the current thread. + Thus, the next PostRead operation must be performed after executing WaitCompletion. */ + result = ep->WaitCompletion(mOptions.oneSideTimeout == 0 ? -1 : mOptions.oneSideTimeout); + ReleaseSelfPollEp(index); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync read wait complete failed " << result << " ep id " << ep->Id()); + return result; + } + } + return SER_OK; +} + +SerResult HcomChannelImp::PrepareCallback(HcomServiceSelfSyncParam& syncParam, TimerCtx &syncContext) +{ + Callback *newCallback = UBSHcomNewCallback([&syncParam](UBSHcomServiceContext &context) { + if (NN_UNLIKELY(context.Result() != SER_OK)) { + NN_LOG_ERROR("Prepare callback failed " << context.Result()); + } + syncParam.Result(context.Result()); + syncParam.Signal(); + }, + std::placeholders::_1); + if (NN_UNLIKELY(newCallback == nullptr)) { + NN_LOG_ERROR("Sync read callback is nullptr"); + return SER_NEW_OBJECT_FAILED; + } + + SerResult result = PrepareTimerContext(newCallback, mOptions.oneSideTimeout, syncContext); + if (result != SER_OK) { + delete newCallback; + return result; + } + return SER_OK; +} + + +SerResult HcomChannelImp::OneSideSyncWithWorkerPoll(const UBSHcomOneSideRequest &request, bool isWrite) +{ + SerResult ret = SER_OK; + uint32_t size = request.size; + uint32_t offset = 0; + uint32_t remain = request.size; + uint16_t multiNum = (mOptions.enableMultiRail && request.size > mOptions.multiRailThresh) ? mDriverNum : 1; + std::vector paramVec(multiNum, HcomServiceSelfSyncParam()); + uint32_t idx = 0; + NN_LOG_DEBUG("Multirail enabled: " << (multiNum != 1) << ", rail num: " << multiNum); + do { + UBSHcomNetEndpoint *ep = nullptr; + auto result = NextWorkerPollEp(ep, idx); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("NextWorkerPollEp failed, result:" << result <<", idx: "<< idx); + ret = result; + break; + } + TimerCtx syncContext {}; + result = PrepareCallback(paramVec[idx], syncContext); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("PrepareCallback failed, result:" << result << ", idx:" << idx); + ret = result; + break; + } + + CalculateOffsetAndSize(request, ep, remain, offset, size); + UBSHcomNetTransRequest req(request.lAddress + offset, request.rAddress + offset, + request.lKey.keys[ep->GetDevIndex()], request.rKey.keys[ep->GetPeerDevIndex()], size, + sizeof(SerTransContext)); + req.srcSeg = reinterpret_cast(request.lKey.tokens[ep->GetDevIndex()]); + SetServiceTransCtx(req.upCtxData, syncContext.seqNo); + + if (isWrite) { + result = ep->PostWrite(req); + } else { + result = ep->PostRead(req); + } + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel sync oneside failed " << result << " ep id " << ep->Id() << " in rail " << idx); + DestroyTimerContext(syncContext); + NN_LOG_ERROR("ep Post failed: " << result << ", idx:" << idx); + ret = result; + break; + } + ++idx; + } while (idx < multiNum); + + SerResult cbRet = SER_OK; + for (uint32_t i = 0; i < idx; i++) { + paramVec[i].Wait(); + if (NN_UNLIKELY((paramVec[i]).Result() != NN_OK)) { + cbRet = (paramVec[i]).Result(); + } + } + + // if ret is not ok, can not return before sem_wait because callback need paramVec, it can not free + if (cbRet != SER_OK || ret != SER_OK) { + NN_LOG_ERROR("callback or multi error, cbRet:" << cbRet << ",ret:" << ret); + return cbRet != SER_OK ? cbRet : ret; + } + return SER_OK; +} + +Callback *HcomChannelImp::GetAsyncCB(uint16_t multiNum, const Callback *done) +{ + if (multiNum > NN_NO1) { + Callback *newCallback = new (std::nothrow) AsyncClosureCallback(const_cast(done), multiNum); + if (newCallback == nullptr) { + NN_LOG_ERROR("Failed to create new callback"); + return nullptr; + } + return newCallback; + } else { + return const_cast(done); + } +} + +void HcomChannelImp::ProcessRemainCallback(Callback *cb, uint32_t remainNums) +{ + if (NN_UNLIKELY(cb == nullptr)) { + return; + } + + UBSHcomServiceContext context{}; + context.mCh.Set(nullptr); + context.mResult = SER_ERROR; + context.mEpIdxInCh = 0; + context.mSeqNo = 0; + context.mDataType = UBSHcomServiceContext::INVALID_DATA; + context.mDataLen = 0; + context.mData = nullptr; + context.mOpType = UBSHcomRequestContext::NN_INVALID_OP_TYPE; + context.mOpCode = NN_NO1024; + + for (uint32_t i = 0; i < remainNums; i++) { + cb->Run(context); + } +} + +SerResult HcomChannelImp::OneSideAsyncWithWorkerPoll(const UBSHcomOneSideRequest &request, const Callback *done, + bool isWrite) +{ + uint32_t size = request.size; + uint32_t offset = 0; + uint32_t remain = request.size; + uint16_t multiNum = (mOptions.enableMultiRail && request.size > mOptions.multiRailThresh) ? mDriverNum : 1; + + Callback *cb = GetAsyncCB(multiNum, done); + if (NN_UNLIKELY(cb == nullptr)) { + NN_LOG_ERROR("Get OneSideCB failed "); + return SER_NEW_OBJECT_FAILED; + } + + NN_LOG_DEBUG("Multirail enabled: " << (multiNum != 1) << ", rail num: " << multiNum); + for (uint32_t i = 0; i < multiNum; i++) { + UBSHcomNetEndpoint *ep = nullptr; + auto result = NextWorkerPollEp(ep, i); + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Get Ep failed " << result); + ProcessRemainCallback(cb, multiNum - i); + return result; + } + + TimerCtx readContext {}; + result = PrepareTimerContext(cb, mOptions.oneSideTimeout, readContext); + if (result != SER_OK) { + NN_LOG_ERROR("PrepareTimerContext failed " << result << " in rail " << i); + ProcessRemainCallback(cb, multiNum - i); + return result; + } + + CalculateOffsetAndSize(request, ep, remain, offset, size); + UBSHcomNetTransRequest req(request.lAddress + offset, request.rAddress + offset, + request.lKey.keys[ep->GetDevIndex()], request.rKey.keys[ep->GetPeerDevIndex()], size, + sizeof(SerTransContext)); + req.srcSeg = reinterpret_cast(request.lKey.tokens[ep->GetDevIndex()]); + SetServiceTransCtx(req.upCtxData, readContext.seqNo); + + if (isWrite) { + result = ep->PostWrite(req); + } else { + result = ep->PostRead(req); + } + if (NN_UNLIKELY(result != SER_OK)) { + NN_LOG_ERROR("Channel async read failed " << result << " ep id " << ep->Id() << " in rail " << i); + DestroyTimerContext(readContext); + return result; + } + } + return SER_OK; +} + +SerResult HcomChannelImp::OneSideInner(const UBSHcomOneSideRequest &request, const Callback *done, bool isWrite) +{ + if (mOptions.selfPoll) { + if (done == nullptr) { + return OneSideSyncWithSelfPoll(request, isWrite); + } else { + NN_LOG_ERROR("Failed to invoke async one side op with self poll, not supported"); + return SER_INVALID_PARAM; + } + } else { + if (done == nullptr) { + return OneSideSyncWithWorkerPoll(request, isWrite); + } else { + return OneSideAsyncWithWorkerPoll(request, done, isWrite); + } + } + return SER_INVALID_PARAM; +} + +int32_t HcomChannelImp::Put(const UBSHcomOneSideRequest &req, const Callback *done) +{ + NN_LOG_DEBUG("[Request Send] ------ API = HcomChannelImp::Put" << ", channel id = " << mOptions.id << + ", status = " << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::CALLED)); + VALIDATE_PARAM(OneSideRequest, req); + SerResult ret = SER_OK; + uint64_t timestamp = mOptions.oneSideTimeout < 0 ? UINT64_MAX : mOptions.oneSideTimeout + NetMonotonic::TimeSec(); + do { + ret = FlowControl(req.size, mOptions.oneSideTimeout, timestamp); + if (NN_UNLIKELY(ret != SER_OK)) { + return ret; + } + + NetTrace::TraceBegin(CHANNEL_WRITE); + ret = OneSideInner(req, done, true); + NetTrace::TraceEnd(CHANNEL_WRITE, ret); + if (NN_LIKELY(ret == SER_OK)) { + return SER_OK; + } else if (ret == SER_NEW_OBJECT_FAILED) { // do later::add retry result code + usleep(100UL); + continue; + } else { + break; + } + } while (NetMonotonic::TimeSec() < timestamp); + + NN_LOG_ERROR("Failed to write " << ret); + return ret; +} + +int32_t HcomChannelImp::Get(const UBSHcomOneSideRequest &req, const Callback *done) +{ + NN_LOG_DEBUG("[Request Send] ------ API = HcomChannelImp::Get" << ", channel id = " << mOptions.id << + ", status = " << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::CALLED)); + VALIDATE_PARAM(OneSideRequest, req); + SerResult ret = SER_OK; + uint64_t timestamp = mOptions.oneSideTimeout < 0 ? UINT64_MAX : mOptions.oneSideTimeout + NetMonotonic::TimeSec(); + do { + ret = FlowControl(req.size, mOptions.oneSideTimeout, timestamp); + if (NN_UNLIKELY(ret != SER_OK)) { + return ret; + } + + NetTrace::TraceBegin(CHANNEL_READ); + ret = OneSideInner(req, done, false); + NetTrace::TraceEnd(CHANNEL_READ, ret); + if (NN_LIKELY(ret == SER_OK)) { + return SER_OK; + } else if (ret == SER_NEW_OBJECT_FAILED) { // do later::add retry result code + usleep(100UL); + continue; + } else { + break; + } + } while (NetMonotonic::TimeSec() < timestamp); + + NN_LOG_ERROR("Failed to read " << ret); + return ret; +} + +int32_t HcomChannelImp::Recv(const UBSHcomServiceContext &context, uintptr_t address, uint32_t size, + const Callback *done) +{ + if (context.mDataLen != sizeof(HcomServiceRndvMessage)) { + NN_LOG_ERROR(" Received RNDV data size is incorrect, actual size " << context.mDataLen << ", expected size " << + sizeof(UBSHcomRequest)); + return SER_ERROR; + } + HcomServiceRndvMessage *rndvMessage = static_cast(context.mData); + if (rndvMessage == nullptr || rndvMessage->request.size != size) { + NN_LOG_ERROR(" Fail to get Request data or Request size " << size << " and processing size " << + (rndvMessage == nullptr ? 0 : rndvMessage->request.size) << " mismatch"); + return SER_ERROR; + } + + if (rndvMessage->IsTimeout()) { + NN_LOG_ERROR(" Fail to recv request data due to timeout"); + return SER_TIMEOUT; + } + + // 在pgTable上查询address 是否被注册 + PgTable *pgTable = reinterpret_cast(mPgtable); + uintptr_t endAddr = address + size - NN_NO1; + PgtRegion *pgtRegion = pgTable->Lookup(address); + if (pgtRegion == nullptr || !(pgtRegion->start <= address && pgtRegion->end > endAddr)) { + NN_LOG_ERROR(" Fail to lookUp address in pgTable or req address is out of range"); + return SER_ERROR; + } + + UBSHcomOneSideRequest oneSideRequest{}; + oneSideRequest.lAddress = address; + oneSideRequest.lKey.keys[0] = pgtRegion->key; + oneSideRequest.lKey.tokens[0] = pgtRegion->token; + oneSideRequest.rAddress = reinterpret_cast(rndvMessage->request.address); + oneSideRequest.rKey.keys[0] = rndvMessage->request.key; + oneSideRequest.size = size; + SerResult ret = Get(oneSideRequest, done); + if (ret != SER_OK) { + NN_LOG_ERROR("Fail to rndv read data " << ret); + return ret; + } + return SER_OK; +} + +SerResult HcomChannelImp::FlowControl(uint64_t size, int16_t timeout, uint64_t timestamp) +{ + if (mOptions.rateLimit == 0) { + return SER_OK; + } + + auto rateLimiter = reinterpret_cast(mOptions.rateLimit); + uint64_t timeoutSecond = timeout > 0 ? timestamp : NetMonotonic::TimeSec() + NN_NO10; + while (true) { + while (rateLimiter->AcquireQuota(size)) { + uint64_t newByte = rateLimiter->windowPassedByte + size; + uint64_t oldByte = rateLimiter->windowPassedByte; + if (__sync_bool_compare_and_swap(&rateLimiter->windowPassedByte, oldByte, newByte)) { + NN_LOG_TRACE_INFO("Success passed flow control size " << size << ", tid " << pthread_self()); + return SER_OK; + } + } + + if (NN_UNLIKELY(rateLimiter->InvalidateSize(size))) { + NN_LOG_ERROR("Failed to flow control by user size " << size << " over configure thresholdByte " << + rateLimiter->thresholdByte); + return SER_INVALID_PARAM; + } + + NN_LOG_TRACE_INFO("Wait start flow control size " << size << ", tid " << pthread_self()); + rateLimiter->WaitUntilNextWindow(); + NN_LOG_TRACE_INFO("Wait finish flow control size " << size << ", tid " << pthread_self()); + rateLimiter->BuildNextWindow(); + + if (NN_UNLIKELY(NetMonotonic::TimeSec() > timeoutSecond)) { + NN_LOG_ERROR("Flow control timeout, channel id " << mOptions.id << " size " << size); + return SER_TIMEOUT; + } + } + + return SER_OK; +} + +int32_t HcomChannelImp::SetFlowControlConfig(const UBSHcomFlowCtrlOptions &opt) +{ + std::lock_guard locker(mMgrMutex); + if (!mChState.Compare(UBSHcomChannelState::CH_ESTABLISHED)) { + NN_LOG_ERROR("Config flow control failed, as channel state invalid " << static_cast(mChState.Get())); + return SER_NOT_ESTABLISHED; + } + + auto rateLimit = reinterpret_cast(mOptions.rateLimit); + if (mOptions.rateLimit == 0) { + rateLimit = new (std::nothrow) RateLimiter; + if (NN_UNLIKELY(rateLimit == nullptr)) { + NN_LOG_ERROR("Failed to create rate limiter"); + return SER_INVALID_PARAM; + } + + rateLimit->level = opt.flowCtrlLevel; + rateLimit->intervalTimeMs = opt.intervalTimeMs; + rateLimit->thresholdByte = opt.thresholdByte; + rateLimit->windowEndTimeMs = NetMonotonic::TimeMs() + rateLimit->intervalTimeMs; + mOptions.rateLimit = reinterpret_cast(rateLimit); + return SER_OK; + } + + /* require:support repeat config */ + rateLimit->level = opt.flowCtrlLevel; + rateLimit->intervalTimeMs = opt.intervalTimeMs; + rateLimit->thresholdByte = opt.thresholdByte; + + return SER_OK; +} + +void HcomChannelImp::SetChannelTimeOut(int16_t oneSideTimeout, int16_t twoSideTimeout) +{ + if (oneSideTimeout < -1 || twoSideTimeout < -1) { + NN_LOG_WARN("Timeout range must be greater than or equal to -1, default value is -1"); + return; + } + mOptions.oneSideTimeout = oneSideTimeout; + mOptions.twoSideTimeout = twoSideTimeout; +} + +void HcomChannelImp::SetEpUpCtx() +{ + for (uint16_t i = 0; i < mEpInfo->epSize; i++) { + Ep2ChanUpCtx ctx(1, reinterpret_cast(this), i); + mEpInfo->epArr[i]->UpCtx(ctx.wholeUpCtx); + } +} + +void HcomChannelImp::UnSetEpUpCtx() +{ + for (uint16_t i = 0; i < mEpInfo->epSize; i++) { + mEpInfo->epArr[i]->UpCtx(0); + } +} + +bool HcomChannelImp::AllEpEstablished() +{ + for (uint16_t i = 0; i < mEpInfo->epSize; i++) { + if (mEpInfo->epState[i].Compare(SER_EP_BROKEN) || + mEpInfo->epArr[i]->State().Compare(NEP_BROKEN)) { + return false; + } + } + + return true; +} + +void HcomChannelImp::SetUuid(const std::string &uuid) +{ + mUuid = uuid; +} + +void HcomChannelImp::SetPayload(const std::string &payload) +{ + mPayload = payload; +} + +void HcomChannelImp::SetBrokenInfo(UBSHcomChannelBrokenPolicy policy, const UBSHcomServiceChannelBrokenHandler &broken) +{ + mOptions.brokenPolicy = policy; + mOptions.brokenHandler = broken; +} + +void HcomChannelImp::SetEpBroken(uint32_t index) +{ + if (mEpInfo == nullptr || index >= mEpInfo->epSize) { + return; + } + mEpInfo->epState[index].Set(SER_EP_BROKEN); +} + +void HcomChannelImp::SetChannelState(UBSHcomChannelState state) +{ + mChState.Set(state); +} + +bool HcomChannelImp::AllEpBroken() +{ + for (uint16_t i = 0; i < mEpInfo->epSize; i++) { + if (!mEpInfo->epState[i].Compare(SER_EP_BROKEN) || !mEpInfo->epArr[i]->State().Compare(NEP_BROKEN)) { + return false; + } + } + return true; +} + +bool HcomChannelImp::NeedProcessBroken() +{ + bool process = false; + if (NN_UNLIKELY(!mBrokenProcessed.compare_exchange_strong(process, true))) { + return false; + } + return true; +} + +void HcomChannelImp::ProcessIoInBroken() +{ + auto header = reinterpret_cast(mTimerList); + std::vector remainCtx; + + header->GetTimerCtx(remainCtx); + if (!remainCtx.empty()) { + NN_LOG_INFO("Channel id " << mOptions.id << " process io broken, size " << remainCtx.size()); + PROCESS_IO(remainCtx); + } + + /* try again to handle new add io during process */ + header->GetTimerCtx(remainCtx); + if (!remainCtx.empty()) { + NN_LOG_INFO("Channel id " << mOptions.id << " process io broken, size " << remainCtx.size()); + PROCESS_IO(remainCtx); + } +} + +void HcomChannelImp::InvokeChannelBrokenCb(UBSHcomChannelPtr &channel) +{ + if (mOptions.brokenHandler == nullptr) { + NN_LOG_WARN("Empty ChannelBrokenCb"); + return; + } + mOptions.brokenHandler(channel); +} + +uint64_t HcomChannelImp::GetId() +{ + return mOptions.id; +} +std::string HcomChannelImp::GetUuid() +{ + return mUuid; +} +uintptr_t HcomChannelImp::GetTimerList() +{ + return mTimerList; +} +uint32_t HcomChannelImp::GetLocalIp() +{ + return mLocalIp; +} +std::string HcomChannelImp::GetPeerConnectPayload() +{ + return mPayload; +} +uint16_t HcomChannelImp::GetDelayEraseTime() +{ + if (mOptions.brokenPolicy == UBSHcomChannelBrokenPolicy::RECONNECT) { + return RECON_DELAY_ERASE_TIME; + } else { + return DEFAULT_DELAY_ERASE_TIME; + } +} +HcomServiceCtxStore *HcomChannelImp::GetCtxStore() +{ + return mCtxStore; +} +UBSHcomChannelCallBackType HcomChannelImp::GetCallBackType() +{ + return mOptions.cbType; +} + +SerResult HcomChannelImp::GetRemoteUdsIdInfo(UBSHcomNetUdsIdInfo &idInfo) +{ + NN_ASSERT_LOG_RETURN(mEpInfo != nullptr, SER_ERROR) + NN_ASSERT_LOG_RETURN(mEpInfo->epArr[0] != nullptr, SER_ERROR) + return mEpInfo->epArr[0]->GetRemoteUdsIdInfo(idInfo); +} + +int32_t HcomChannelImp::SetTwoSideThreshold(const UBSHcomTwoSideThreshold &threshold) +{ + if (threshold.splitThreshold == UINT32_MAX) { + mUserSplitSendThreshold = UINT32_MAX; + // 如果mEnableMrCache为false且设置rndv阈值生效的情况(splitThreshold不涉及),给用户返回报错,让用户先将mEnableMrCache设置为true + if ((!mEnableMrCache) && (threshold.rndvThreshold != UINT32_MAX)) { + NN_LOG_ERROR("Fail to set Threshold, because need set enableMrCache true first "); + return SER_INVALID_PARAM; + } + mRndvThreshold = threshold.rndvThreshold; + NN_LOG_INFO("SplitSend (UBC only) enabled with threshold " << mUserSplitSendThreshold + << ", Rndv Threshold is: " << mRndvThreshold); + return SER_OK; + } + + if (threshold.splitThreshold < NN_NO128) { + NN_LOG_ERROR("The split threshold (" << threshold.splitThreshold + << ") is less than 128, SplitSend may not work properly"); + return SER_INVALID_PARAM; + } + + if (threshold.splitThreshold > mMaxSendRecvDataSize) { + NN_LOG_ERROR("The split threshold (" << threshold.splitThreshold << ") is larger than SegSize (" + << mMaxSendRecvDataSize << "), SplitSend will fail to post request"); + return SER_INVALID_PARAM; + } + + if (threshold.splitThreshold > threshold.rndvThreshold) { + NN_LOG_ERROR("The threshold of split send cannot be greater than the threshold of rndv! Split send threshold: " + << threshold.splitThreshold << " Rndv threshold: " << threshold.rndvThreshold); + return SER_INVALID_PARAM; + } + + // 如果mEnableMrCache为false且设置rndv阈值生效的情况(splitThreshold不涉及),给用户返回报错,让用户先将mEnableMrCache设置为true + if ((!mEnableMrCache) && (threshold.rndvThreshold != UINT32_MAX)) { + NN_LOG_ERROR("Fail to set Threshold, because need set enableMrCache true first "); + return SER_INVALID_PARAM; + } + + // 拆包阈值只有在小于rndv阈值时才有效 + if (threshold.splitThreshold < threshold.rndvThreshold) { + mUserSplitSendThreshold = + threshold.splitThreshold - sizeof(UBSHcomNetTransHeader) - sizeof(UBSHcomFragmentHeader); + } + + mRndvThreshold = threshold.rndvThreshold; + + NN_LOG_INFO("SplitSend (UBC only) enabled with threshold " << threshold.splitThreshold + << ", Rndv Threshold is: " << mRndvThreshold); + return SER_OK; +} +} +} diff --git a/src/service_v2/service_channel_imp.h b/src/service_v2/service_channel_imp.h new file mode 100644 index 0000000000000000000000000000000000000000..cbd7e24c6883ecd8009d1279717a75d7f58730be --- /dev/null +++ b/src/service_v2/service_channel_imp.h @@ -0,0 +1,315 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_SERVICE_V2_HCOM_CHANNEL_IMP_H_ +#define HCOM_SERVICE_V2_HCOM_CHANNEL_IMP_H_ + +#include +#include +#include +#include "hcom_def.h" +#include "hcom_service_def.h" +#include "hcom_service_channel.h" +#include "hcom_obj_statistics.h" +#include "service_imp.h" +#include "service_common.h" +#include "service_callback.h" +#include "hcom_env.h" + +namespace ock { +namespace hcom { + +struct HcomChannelImpOptions { + uint64_t id = 0; + uintptr_t rateLimit = 0; + UBSHcomServiceChannelBrokenHandler brokenHandler = nullptr; + uint32_t multiRailThresh = 8192; + int16_t oneSideTimeout = 30; + int16_t twoSideTimeout = 30; + UBSHcomChannelCallBackType cbType = UBSHcomChannelCallBackType::CHANNEL_FUNC_CB; + UBSHcomChannelBrokenPolicy brokenPolicy = UBSHcomChannelBrokenPolicy::BROKEN_ALL; + bool enableMultiRail = false; + bool selfPoll = false; +}; + +enum ServiceEpState : uint16_t { + SER_EP_ESTABLISHED = 0, + SER_EP_BROKEN = 1, + SER_EP_ESTABLISHED_OCCUPIED = 2, + SER_EP_ESTABLISHED_UNOCCUPIED = 3, +}; + +struct EpInfo { + UBSHcomNetAtomicState epState[CHANNEL_EP_MAX_NUM]{}; /* state of eps */ + UBSHcomNetEndpoint *epArr[CHANNEL_EP_MAX_NUM]{}; /* endpoints for data transfer */ + uint16_t epSize = 0; + EpInfo() = default; +}; + +#define PROCESS_IO(remainCtx) \ + do { \ + UBSHcomServiceContext brokenCtx{}; \ + HcomServiceGlobalObject::BuildBrokenCtx(brokenCtx); \ + for (auto ctx : (remainCtx)) { \ + if (ctx->EraseSeqNoWithRet()) { \ + ctx->TimeoutDump(); \ + ctx->MarkFinished(); \ + brokenCtx.mCh = ctx->mChannel; \ + ctx->RunCallBack(brokenCtx); \ + brokenCtx.mCh.Set(nullptr); \ + ctx->DecreaseRef(); \ + } \ + /* decrease linked list ref */ \ + ctx->DecreaseRef(); \ + } \ + } while (0) + +struct TimerCtx { + uint32_t seqNo = 0; + HcomServiceTimer *timer = nullptr; + + TimerCtx() = default; +}; + +class HcomChannelImp : public UBSHcomChannel { +public: + int32_t Send(const UBSHcomRequest &req, const Callback *done = nullptr) override; + int32_t Call(const UBSHcomRequest &req, UBSHcomResponse &rsp, const Callback *done = nullptr) override; + int32_t Reply(const UBSHcomReplyContext &ctx, const UBSHcomRequest &req, const Callback *done = nullptr) override; + int32_t Put(const UBSHcomOneSideRequest &req, const Callback *done = nullptr) override; + int32_t Get(const UBSHcomOneSideRequest &req, const Callback *done = nullptr) override; + int32_t SendFds(int fds[], uint32_t len) override; + int32_t ReceiveFds(int fds[], uint32_t len, int32_t timeoutSec) override; + int32_t Recv(const UBSHcomServiceContext &context, uintptr_t address, uint32_t size, + const Callback *done = nullptr) override; + + int32_t SetFlowControlConfig(const UBSHcomFlowCtrlOptions &opt) override; + void SetChannelTimeOut(int16_t oneSideTimeout, int16_t twoSideTimeout) override; + int32_t GetRemoteUdsIdInfo(UBSHcomNetUdsIdInfo &idInfo) override; + int32_t SetTwoSideThreshold(const UBSHcomTwoSideThreshold &threshold) override; + + inline void SetTraceId(const std::string &traceId) override + { + SetTraceIdInner(traceId); + } + +protected: + /// 当发送端发送大包时,一个消息会被分为多个 fragment 发送,接收端在识别后需要将 fragment 拼接。 + /// 当返回 SpliceMessageResultType::ERROR 时,同时返回的 SerResult 是实际的错误码, + /// std::string 无效;当返回 SpliceMessageResultType::OK 时,同时返回的 SerResult 必 + /// 定为 SER_OK,同时 std::string 为拼完后的完整消息;当返回 + /// SpliceMessageResultType::INDETERMINATE 时,同时返回的 SerResult 必定为 SER_OK, + /// std::string 无效。 + auto SpliceMessage(const UBSHcomNetRequestContext &ctx, bool isResp) + -> std::tuple override; + + std::mutex mMsgReceivedMutex; + std::map>> mMsgReceived; + +private: + HcomChannelImp(uint64_t id, bool selfPoll, InnerConnectOptions &opt, + UBSHcomServiceProtocol protocol = UBSHcomServiceProtocol::UNKNOWN, + uint32_t maxSendRecvDataSize = 1024) + : mProtocol(protocol), mMaxSendRecvDataSize(maxSendRecvDataSize) + { + mOptions.id = id; + mOptions.selfPoll = selfPoll; + mOptions.cbType = opt.cbType; + if (opt.mode == UBSHcomClientPollingMode::SELF_POLL_BUSY) { + mRespOriginalSeqNo = true; + } + mChState.Set(UBSHcomChannelState::CH_NEW); + OBJ_GC_INCREASE(HcomChannelImp); + } + + ~HcomChannelImp() override + { + UnInitialize(); + ForceUnInitialize(); + OBJ_GC_DECREASE(HcomChannelImp); + } + + SerResult Initialize(std::vector &ep, uintptr_t ctxMemPool, uintptr_t periodicMgr, + uintptr_t pgTable) override; + void UnInitialize() override; + void ForceUnInitialize(); + std::string ToString() override; + + SerResult InitializeEp(std::vector &ep); + SerResult SendInner(const UBSHcomRequest &req, const Callback *done); + SerResult SyncSendInner(const UBSHcomRequest &req); + SerResult AsyncSendInner(const UBSHcomRequest &req, const Callback *done); + SerResult SyncSendWithSelfPoll(const UBSHcomRequest &req); + SerResult CallInner(const UBSHcomRequest &req, UBSHcomResponse &rsp, const Callback *done); + SerResult SyncCallInner(const UBSHcomRequest &req, UBSHcomResponse &rsp, uint32_t timeOut = NN_NO0); + SerResult RndvInner(UBSHcomNetEndpoint *ep, const UBSHcomRequest &req, UBSHcomNetTransOpInfo &transOp, bool isCall); + SerResult AsyncCallInner(const UBSHcomRequest &req, const Callback *done); + SerResult ReplyInner(const UBSHcomReplyContext &ctx, const UBSHcomRequest &req, const Callback *done); + SerResult SyncReplyInner(const UBSHcomReplyContext &ctx, const UBSHcomRequest &req); + SerResult AsyncReplyInner(const UBSHcomReplyContext &ctx, const UBSHcomRequest &req, const Callback *done); + SerResult SyncCallWithSelfPoll(const UBSHcomRequest &req, UBSHcomResponse &rsp); + SerResult FlowControl(uint64_t size, int16_t timeout, uint64_t timestamp); + SerResult SyncSendSplitWithWorkerPoll(UBSHcomNetEndpoint *&ep, const UBSHcomRequest &req, uint32_t fragmentNum); + SerResult SyncSendSplitWithSelfPoll(UBSHcomNetEndpoint *&ep, const UBSHcomRequest &req, uint32_t fragmentNum, + uint32_t index); + SerResult AsyncSendSplitWithWorkerPoll(UBSHcomNetEndpoint *&ep, const UBSHcomRequest &req, uint32_t fragmentNum, + const Callback *done); + SerResult AsyncReplySplitWithWorkerPoll(const UBSHcomReplyContext &ctx, UBSHcomNetEndpoint *&ep, + const UBSHcomRequest &req, uint32_t fragmentNum, const Callback *done); + SerResult SyncReplySplitWithWorkerPoll(const UBSHcomReplyContext &ctx, UBSHcomNetEndpoint *&ep, + const UBSHcomRequest &req, uint32_t fragmentNum); + + SerResult SyncCallSplitWithWorkerPoll(UBSHcomNetEndpoint *&ep, const UBSHcomRequest &req, uint32_t fragmentNum, + UBSHcomResponse &rsp); + SerResult AsyncCallSplitWithWorkerPoll(UBSHcomNetEndpoint *&ep, const UBSHcomRequest &req, uint32_t fragmentNum, + const Callback *done); + SerResult SyncCallSplitWithSelfPoll(UBSHcomNetEndpoint *&ep, const UBSHcomRequest &req, uint32_t fragmentNum, + uint32_t index, UBSHcomResponse &rsp); + + void SetUuid(const std::string &uuid) override; + void SetPayload(const std::string &payload) override; + void SetBrokenInfo(UBSHcomChannelBrokenPolicy policy, const UBSHcomServiceChannelBrokenHandler &broken) override; + void SetEpBroken(uint32_t index) override; + void SetChannelState(UBSHcomChannelState state) override; + void SetEpUpCtx(); + bool AllEpEstablished(); + void UnSetEpUpCtx(); + inline void SetMultiRail(bool multiRail, uint32_t threshold) override + { + mOptions.enableMultiRail = multiRail; + mOptions.multiRailThresh = threshold; + } + inline void SetDriverNum(uint16_t driverNum) override + { + mDriverNum = driverNum; + } + inline void SetTotalBandWidth(uint32_t bandWidth) override + { + mTotalBandWidth = bandWidth; + } + + inline void SetEnableMrCache(bool enableMrCache) override + { + mEnableMrCache = enableMrCache; + } + + bool AllEpBroken() override; + bool NeedProcessBroken() override; + void ProcessIoInBroken() override; + void InvokeChannelBrokenCb(UBSHcomChannelPtr &channel) override; + + uint64_t GetId() override; + std::string GetUuid() override; + uintptr_t GetTimerList() override; + uint32_t GetLocalIp() override; + std::string GetPeerConnectPayload() override; + uint16_t GetDelayEraseTime() override; + HcomServiceCtxStore *GetCtxStore() override; + UBSHcomChannelCallBackType GetCallBackType() override; + + SerResult AcquireSelfPollEp(UBSHcomNetEndpoint *&ep, uint32_t &index, int16_t timeout, uint16_t dvrIdx = 0); + void ReleaseSelfPollEp(uint32_t index); + SerResult NextWorkerPollEp(UBSHcomNetEndpoint *&ep, uint16_t dvrIdx = 0); + SerResult ResponseWorkerPollEp(uintptr_t rspCtx, UBSHcomNetEndpoint *&ep); + + SerResult PrepareTimerContext(Callback *cb, int16_t timeout, TimerCtx &context); + void DestroyTimerContext(TimerCtx &context); + + Callback *GetAsyncCB(uint16_t multiNum, const Callback *done); + SerResult OneSideInner(const UBSHcomOneSideRequest &request, const Callback *done, bool isWrite); + SerResult OneSideSyncWithSelfPoll(const UBSHcomOneSideRequest &request, bool isWrite); + SerResult OneSideSyncWithWorkerPoll(const UBSHcomOneSideRequest &request, bool isWrite); + SerResult OneSideAsyncWithWorkerPoll(const UBSHcomOneSideRequest &request, const Callback *done, bool isWrite); + SerResult PrepareCallback(HcomServiceSelfSyncParam& syncParam, TimerCtx &syncContext); + inline void CalculateOffsetAndSize(const UBSHcomOneSideRequest &request, UBSHcomNetEndpoint *ep, + uint32_t &remain, uint32_t &offset, uint32_t &size) + { + if (mOptions.enableMultiRail && mDriverNum > 1 && request.size > mOptions.multiRailThresh) { + offset = request.size - remain; + uint32_t transferSize = + static_cast(ceilf(request.size * (ep->GetBandWidth() / static_cast(mTotalBandWidth)))); + size = (transferSize > remain) ? remain : transferSize; + remain -= size; + } + } + + inline uint32_t SelfPollNextSeqNo() + { + /* reserve 1 bit for mark send/rsp */ + uint32_t tmpSeqNo = __sync_fetch_and_add(&mSelfPollSeqNo, 1); + /* In order to make sure the netSeqNo.wholeSeq is not zero, and since only the lower 24 bits will be assigned to + netSeqNo.realSeq, tmpSeqNo need to be ensured that lower 24 bits are not zero. */ + if (NN_UNLIKELY((tmpSeqNo & 0x00FFFFFF) == 0)) { + tmpSeqNo = __sync_fetch_and_add(&mSelfPollSeqNo, 1); + } + + /* sell poll just set realSeq */ + HcomSeqNo netSeqNo(0); + netSeqNo.realSeq = tmpSeqNo; + return netSeqNo.wholeSeq; + } + + void ProcessRemainCallback(Callback *cb, uint32_t remainNums); + + /// 估算 fragment 个数,如果为单个 fragment 则不拆包,service 层无额外头部。否则 + /// 添加 UBSHcomFragmentHeader 头部以指示共被分成多少块 fragment、总大小为多少。 + /// 如果满足以下任意条件,则始终使用一个 fragment 来发送: + /// 1、split send 特性未启用; + /// 2、protocol 非 UBC; + /// 3、同时设置了rndv阈值,并且发送的数据size大于等于rndv阈值; + /// 调用前应当保证 size 不为 0. + /// \see SplitSendInner + inline uint32_t EstimateFragmentNum(uint32_t size, bool withRndv = false) + { + if (mUserSplitSendThreshold == UINT32_MAX || (withRndv && size >= mRndvThreshold)) { + return 1; + } + + return (mProtocol != UBSHcomServiceProtocol::UBC) ? + 1 : + (static_cast(size) + mUserSplitSendThreshold - 1) / mUserSplitSendThreshold; + } + + void CheckAndUpdateThreshold(); + +private: + HcomChannelImpOptions mOptions; + EpInfo *mEpInfo = nullptr; + HcomServiceCtxStore *mCtxStore = nullptr; + uintptr_t mCtxMemPool = 0; + uintptr_t mPeriodicMgr = 0; /* timeout periodic manager */ + uintptr_t mPgtable = 0; + uint32_t mRndvThreshold = UINT32_MAX; + uint32_t mSelfPollSeqNo = 1; /* for self polling simplified usage */ + bool mRespOriginalSeqNo = false; + + uintptr_t mTimerList = 0; + uint32_t mLocalIp = 0; + uint16_t mDriverNum = 1; + uint32_t mTotalBandWidth = 0; + uint16_t mEpChoosingIdx[4] = {0}; /* index for choosing to which ep to transfer */ + std::string mUuid; + std::atomic_bool mBrokenProcessed{false}; + std::mutex mMgrMutex; + UBSHcomNetAtomicState mChState; // channel state + std::string mPayload; + HcomConnectTimestamp mConnectTimestamp {}; + + UBSHcomServiceProtocol mProtocol = UBSHcomServiceProtocol::UNKNOWN; + uint32_t mMaxSendRecvDataSize = 1024; + bool mEnableMrCache = false; // mr into pgTable for management + + friend class HcomServiceImp; +}; + +} +} +#endif // HCOM_SERVICE_V2_HCOM_CHANNEL_IMP_H_ \ No newline at end of file diff --git a/src/service_v2/service_common.cpp b/src/service_v2/service_common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fad90b6fd919a19db7def06a707302fb9ab70f7e --- /dev/null +++ b/src/service_v2/service_common.cpp @@ -0,0 +1,213 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "service_common.h" + +#include +#include "securec.h" +#include "hcom_def.h" +#include "net_common.h" +#include "hcom_service_def.h" +namespace ock { +namespace hcom { + +Callback* HcomServiceGlobalObject::gEmptyCallback = nullptr; +bool HcomServiceGlobalObject::gInited = false; + +SerResult SerConnInfo::Serialize(SerConnInfo &connInfo, const std::string &payload, std::string &out) +{ + std::string strInfo; + if (!connInfo.ToString(strInfo)) { + NN_LOG_ERROR("Failed to generate connect info to string"); + return SER_ERROR; + } + + out.clear(); + out.reserve(strInfo.size() + payload.size()); + out.append(strInfo); + out.append(payload); + return SER_OK; +} + +SerResult SerConnInfo::Deserialize(const std::string &payload, SerConnInfo &connInfo, std::string &userPayload) +{ + if (NN_UNLIKELY(!HexStringToBuff(payload, sizeof(SerConnInfo), &connInfo))) { + NN_LOG_ERROR("Failed to parse connection info"); + return SER_INVALID_PARAM; + } + + if (NN_UNLIKELY(!connInfo.Validate())) { + NN_LOG_ERROR("Failed to validate connection info"); + return SER_INVALID_PARAM; + } + + uint32_t connInfoStrSize = sizeof(SerConnInfo) * 2; + userPayload.clear(); + userPayload.append(payload.substr(connInfoStrSize, payload.size() - connInfoStrSize)); + + return NN_OK; +} + +SerResult HcomServiceGlobalObject::Initialize() +{ + if (gInited) { + return SER_OK; + } + + gEmptyCallback = NewPermanentCallback([](UBSHcomServiceContext &context) {}, std::placeholders::_1); + if (NN_UNLIKELY(gEmptyCallback == nullptr)) { + NN_LOG_ERROR("Build empty callback failed"); + return SER_NEW_OBJECT_FAILED; + } + + gInited = true; + return SER_OK; +} + +void HcomServiceGlobalObject::UnInitialize() +{ + if (!gInited) { + return; + } + if (gEmptyCallback != nullptr) { + delete gEmptyCallback; + gEmptyCallback = nullptr; + } + gInited = false; +} + +void HcomServiceGlobalObject::BuildTimeOutCtx(UBSHcomServiceContext &ctx) +{ + ctx.mCh.Set(nullptr); + ctx.mResult = SER_TIMEOUT; + ctx.mEpIdxInCh = 0; + ctx.mSeqNo = 0; + ctx.mDataType = UBSHcomServiceContext::INVALID_DATA; + ctx.mDataLen = 0; + ctx.mData = nullptr; + ctx.mOpType = UBSHcomRequestContext::NN_INVALID_OP_TYPE; + ctx.mOpCode = NN_NO1024; +} + +void HcomServiceGlobalObject::BuildBrokenCtx(UBSHcomServiceContext &ctx) +{ + ctx.mCh.Set(nullptr); + ctx.mResult = SER_NOT_ESTABLISHED; + ctx.mEpIdxInCh = 0; + ctx.mSeqNo = 0; + ctx.mDataType = UBSHcomServiceContext::INVALID_DATA; + ctx.mDataLen = 0; + ctx.mData = nullptr; + ctx.mOpType = UBSHcomRequestContext::NN_INVALID_OP_TYPE; + ctx.mOpCode = NN_NO1024; +} + +bool HcomConnectingEpInfo::AllEPBroken(uint16_t index) +{ + std::lock_guard lockerEp(mLock); + + if (NN_UNLIKELY(index >= mEpVector.size()) || NN_UNLIKELY(index >= CHANNEL_EP_MAX_NUM)) { + NN_LOG_ERROR("Invalid ep index " << index << ", ep size is " << mEpVector.size()); + return false; + } + + mEpState[index].Set(NEP_BROKEN); + + for (uint64_t i = 0; i < mEpVector.size(); i++) { + if (NN_UNLIKELY(mEpVector[i] == nullptr)) { + continue; + } + if (!mEpState[i].Compare(NEP_BROKEN)) { + NN_LOG_WARN("Failed to check all ep state broken, ep id " << mEpVector[i]->Id()); + return false; + } + } + + auto ret = mConnState.CAS(ConnectingEpState::NEW_EP, ConnectingEpState::EP_BROKEN); + if (NN_UNLIKELY(!ret)) { + NN_LOG_ERROR("Failed to validate ep state by generate channel, state " << + static_cast(mConnState.Get())); + } + + return ret; +} + +bool HcomConnectingEpInfo::Compare(const SerConnInfo &info) const +{ + if (NN_UNLIKELY(mConnInfo.version != info.version)) { + NN_LOG_ERROR("New connect version " << info.version << " is different from stored version " << info.version); + return false; + } + + if (NN_UNLIKELY(mConnInfo.channelId != info.channelId)) { + NN_LOG_ERROR("New connect channelId " << info.channelId << " is different from stored channelId " << + info.channelId); + return false; + } + + if (NN_UNLIKELY(mConnInfo.policy != info.policy)) { + NN_LOG_ERROR("New connect policy " << static_cast(mConnInfo.policy) + << " different from stored policy " << static_cast(info.policy)); + return false; + } + + if (NN_UNLIKELY(info.index != mEpVector.size())) { + NN_LOG_ERROR("Failed to validate sequence, connect index " << info.index << " , already ep size " << + mEpVector.size()); + return false; + } + + if (NN_UNLIKELY(mConnInfo.options.linkCount != info.options.linkCount)) { + NN_LOG_ERROR("New connect linkCount " << mConnInfo.options.linkCount <<" is different from stored connect " + << info.options.linkCount); + return false; + } + + if (NN_UNLIKELY(mConnInfo.options.cbType != info.options.cbType)) { + NN_LOG_ERROR("New connect cbType " << static_cast(mConnInfo.options.cbType) << + "is different from stored connect " << static_cast(info.options.cbType)); + return false; + } + + if (NN_UNLIKELY(mConnInfo.options.clientGroupId != info.options.clientGroupId)) { + NN_LOG_ERROR("New connect clientGroupId " << static_cast(mConnInfo.options.clientGroupId) + << "is different from stored connect " << static_cast(info.options.clientGroupId)); + return false; + } + + if (NN_UNLIKELY(mConnInfo.options.serverGroupId != info.options.serverGroupId)) { + NN_LOG_ERROR("New connect serverGroupId " << static_cast(mConnInfo.options.serverGroupId) + << "is different from stored connect " << static_cast(info.options.serverGroupId)); + return false; + } + return true; +} + +uint64_t HcomConnectTimestamp::GetRemoteTimestamp(int16_t timeOutSecond) const +{ + if (timeOutSecond <= 0) { + return 0; + } + // V2 rndv用户在recv的时候就判断超时,所以需要用deltaTimUs/2 去计算remote的超时时间 + uint64_t remoteCurTime = NetMonotonic::TimeUs() - localTimeUs + remoteTimeUs - (deltaTimeUs / NN_NO2); + return remoteCurTime + timeOutSecond * NN_NO1000000; +} + +bool HcomServiceRndvMessage::IsTimeout() const +{ + if (timestamp == 0) { + return false; + } + return NetMonotonic::TimeUs() >= timestamp; +} +} +} \ No newline at end of file diff --git a/src/service_v2/service_common.h b/src/service_v2/service_common.h new file mode 100644 index 0000000000000000000000000000000000000000..05cb8aa269511120b72381d5098ad199b304ea22 --- /dev/null +++ b/src/service_v2/service_common.h @@ -0,0 +1,563 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_SERVICE_V2_SERVICE_COMMON_H_ +#define HCOM_SERVICE_V2_SERVICE_COMMON_H_ +#include +#include "securec.h" + +#include "hcom_service_channel.h" +#include "hcom_service_def.h" +#include "hcom_service_context.h" +#include "net_common.h" +#include "net_crc32.h" +#include "net_trace.h" +#include "net_monotonic.h" + +namespace ock { +namespace hcom { +constexpr uint32_t CHANNEL_EP_MAX_NUM = 64; + +enum ServiceV2PrivateOpcode : uint16_t { + RNDV_CALL_OP_V2 = 1001, + EXCHANGE_TIMESTAMP_OP = 1002, +}; + +struct SerTransContext { + uint32_t seqNo = 0; + bool invokeCallback = true; // call() message is no need to invoke callback. + + Callback *callback = nullptr; /* record for response message quick handle */ + SerTransContext() = default; +} __attribute__((packed)); + +inline void SetServiceTransCtx(char *ctxData, uint32_t seqNo) +{ + auto ctx = reinterpret_cast(ctxData); + ctx->seqNo = seqNo; + ctx->invokeCallback = true; +} + +inline void SetServiceTransCtx(char *ctxData, Callback *callback) +{ + auto ctx = reinterpret_cast(ctxData); + ctx->callback = callback; + ctx->invokeCallback = true; +} + +inline void SetServiceTransCtx(char *ctxData, uint32_t seqNo, bool invokeCallback) +{ + auto ctx = reinterpret_cast(ctxData); + ctx->seqNo = seqNo; + ctx->invokeCallback = invokeCallback; +} + +inline Callback *GetServiceTransCb(char *ctxData) +{ + auto ctx = reinterpret_cast(ctxData); + return ctx->callback; +} + +inline uint32_t GetServiceTransSeqNo(char *ctxData) +{ + auto context = reinterpret_cast(ctxData); + return context->seqNo; +} + +inline bool GetServiceTransNeedPostedCall(char *ctxData) +{ + auto context = reinterpret_cast(ctxData); + return context->invokeCallback; +} + +/* if result is OK, there is no need to invoke callback in Call() method, waiting for remote rsp */ +inline bool IsNeedInvokeCallback(const UBSHcomRequestContext &ctx) +{ + if (ctx.Result() != NN_OK) { + return true; + } + + if (ctx.OpType() == UBSHcomRequestContext::NN_SENT || + ctx.OpType() == UBSHcomRequestContext::NN_SENT_RAW || + ctx.OpType() == UBSHcomRequestContext::NN_WRITTEN || + ctx.OpType() == UBSHcomRequestContext::NN_READ) { + return GetServiceTransNeedPostedCall(const_cast(ctx.OriginalRequest().upCtxData)); + } else if (ctx.OpType() == UBSHcomRequestContext::NN_SENT_RAW_SGL || + ctx.OpType() == UBSHcomRequestContext::NN_SGL_WRITTEN || + ctx.OpType() == UBSHcomRequestContext::NN_SGL_READ) { + return GetServiceTransNeedPostedCall(const_cast(ctx.OriginalSgeRequest().upCtxData)); + } else { + NN_LOG_ERROR("Invalid op type " << ctx.OpType() << " for request posted"); + return false; + } +} + +/** + * @brief Generate a permanent callback object. + * + * @param Args + * @param args + * @return Callback* + * @note see @ref NewCallback. + */ +template Callback *NewPermanentCallback(Args... args) +{ + auto closure = std::bind(args...); + return new (std::nothrow) InnerClosureCallback(std::move(closure), false); +} + +class AsyncClosureCallback : public Callback { +public: + explicit AsyncClosureCallback(Callback *function, uint16_t mFinishCnt) + : mFunction(function), mTotalTime(mFinishCnt) + {} + + void Run(UBSHcomServiceContext &context) override + { + __sync_fetch_and_add(&mRunTime, 1); + if (mRunTime < mTotalTime) { + return; + } + if (mFunction != nullptr) { + mFunction->Run(context); + } + delete this; + } + +private: + uint64_t GetTime() override + { + return mStartTime; + } + + void SetTime(uint64_t time) override + { + mStartTime = time; + } + +private: + Callback *mFunction = nullptr; + uint16_t mRunTime = 0; + uint16_t mTotalTime = 0; + uint64_t mStartTime = 0; +}; + +enum class ConnectingEpState { + NEW_EP, + NEW_CHANNEL, + EP_BROKEN, +}; + +struct InnerConnectOptions { + uint16_t clientGroupId; + uint16_t serverGroupId; + uint8_t linkCount; + UBSHcomClientPollingMode mode; + UBSHcomChannelCallBackType cbType; +}; + +struct SerConnInfo { + uint32_t crc = 0; + uint32_t version = 0; + uint64_t channelId = 0; + uint64_t multiRailId = 0; + uint16_t index = 0; + uint16_t driverIndex = 0; + uint16_t driverSize = 0; + uint16_t totalLinkCount = 0; + UBSHcomChannelBrokenPolicy policy = UBSHcomChannelBrokenPolicy::BROKEN_ALL; + InnerConnectOptions options; + + SerConnInfo() = default; + SerConnInfo(uint32_t v, uint64_t id, UBSHcomChannelBrokenPolicy p, const UBSHcomConnectOptions &opt) + : version(v), channelId(id), policy(p) + { + options.clientGroupId = opt.clientGroupId; + options.serverGroupId = opt.serverGroupId; + options.linkCount = opt.linkCount; + totalLinkCount = opt.linkCount; + options.mode = opt.mode; + options.cbType = opt.cbType; + } + SerConnInfo(uint32_t v, uint64_t id, uint16_t driverSize, UBSHcomChannelBrokenPolicy p, + const UBSHcomConnectOptions &opt) + : version(v), + channelId(id), + driverSize(driverSize), + policy(p) + { + options.clientGroupId = opt.clientGroupId; + options.serverGroupId = opt.serverGroupId; + options.linkCount = opt.linkCount; + totalLinkCount = opt.linkCount; + options.mode = opt.mode; + options.cbType = opt.cbType; + multiRailId = id; + } + + inline void SetCrc32() + { + auto crcAddress = reinterpret_cast(this) + sizeof(uint32_t); + crc = NetCrc32::CalcCrc32(crcAddress, sizeof(SerConnInfo) - sizeof(uint32_t)); + } + + inline bool Validate() + { + auto crcAddress = reinterpret_cast(this) + sizeof(uint32_t); + uint32_t newCrc = NetCrc32::CalcCrc32(crcAddress, sizeof(SerConnInfo) - sizeof(uint32_t)); + + return crc == newCrc; + } + + inline bool ToString(std::string &out) + { + return BuffToHexString(this, sizeof(SerConnInfo), out); + } + + static SerResult Deserialize(const std::string &payload, SerConnInfo &connInfo, std::string &userPayLoad); + static SerResult Serialize(SerConnInfo &connInfo, const std::string &payload, std::string &out); +}; + +class ConnectingSecInfo { +public: + int64_t flag = 0; + uint32_t secContentLen = 0; + char *secContent = nullptr; + UBSHcomNetDriverSecType type = NET_SEC_VALID_ONE_WAY; + bool needAutoFree = false; + bool firstCallProvider = true; + bool firstCallValidator = true; + + ConnectingSecInfo() = default; + + void Initialize(int64_t flg, UBSHcomNetDriverSecType secType, char *output, uint32_t len, bool autoFree) + { + flag = flg; + type = secType; + secContent = output; + secContentLen = len; + needAutoFree = autoFree; + firstCallProvider = false; + } +}; + +class HcomConnectingEpInfo { +public: + HcomConnectingEpInfo() = default; + HcomConnectingEpInfo(std::string &id, const UBSHcomNetEndpointPtr &ep, SerConnInfo &info) + : mConnInfo(info), mUuid(id) + { + std::lock_guard lockerEp(mLock); + mEpState[0].Set(NEP_ESTABLISHED); + mEpVector.emplace_back(ep.Get()); + mConnState.Set(ConnectingEpState::NEW_EP); + } + + inline bool AddEp(const UBSHcomNetEndpointPtr &ep) + { + std::lock_guard lockerEp(mLock); + + if (NN_UNLIKELY(mEpVector.size() >= NN_NO64)) { + NN_LOG_ERROR("Ep vector is full, ep size now is " << mEpVector.size()); + return false; + } + + mEpState[mEpVector.size()].Set(NEP_ESTABLISHED); + mEpVector.emplace_back(ep.Get()); + return mConnState.CAS(ConnectingEpState::NEW_EP, ConnectingEpState::NEW_EP); + } + + bool AllEPBroken(uint16_t index); + + bool Compare(const SerConnInfo &info) const; + + std::mutex mLock; + SerConnInfo mConnInfo {}; + UBSHcomNetAtomicState mConnState {}; + UBSHcomNetAtomicState mEpState[CHANNEL_EP_MAX_NUM] {}; + std::string mUuid {}; + std::vector mEpVector {}; + +public: + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; + +class HcomServiceGlobalObject { +public: + static SerResult Initialize(); + static void UnInitialize(); + static void BuildTimeOutCtx(UBSHcomServiceContext &ctx); + static void BuildBrokenCtx(UBSHcomServiceContext &ctx); + +public: + static Callback *gEmptyCallback; + static bool gInited; +}; + +/** + * Endpoint to channel upCtx, store ch pointer and endpoint index into uint64_t + */ +union Ep2ChanUpCtx { + struct { + uint64_t connected : 1; /* flag for connecting or connected, store different type ptr */ + uint64_t epIdx : 5; /* endpoint index, range [0, 31] */ + uint64_t ptr : 58; /* pointer to connecting mgr or net channel */ + }; + uint64_t wholeUpCtx = 0; /* whole */ + + Ep2ChanUpCtx() = default; + explicit Ep2ChanUpCtx(uint64_t w) : wholeUpCtx(w) {} + Ep2ChanUpCtx(uint64_t connect, uint64_t p, uint64_t i) : connected(connect), epIdx(i), ptr(p) {} + + inline UBSHcomChannel *Channel() const + { + if (NN_UNLIKELY(connected != 1)) { + NN_LOG_ERROR("Failed to get channel by not connected"); + return nullptr; + } + return reinterpret_cast(ptr); + } + + inline uint32_t EpIdx() const + { + return static_cast(epIdx); + } + + inline uint64_t Ptr() const + { + return ptr; + } + + std::string ToString() const + { + std::ostringstream oss; + oss << "chPtr " << ptr << ", epIdx " << epIdx << ", whole " + << wholeUpCtx; + return oss.str(); + } +}; + +struct RateLimiter { + bool triggering = false; /* trigger next window flag */ + UBSHcomFlowCtrlLevel level = UBSHcomFlowCtrlLevel::LOW_LEVEL_BLOCK; /* wait level */ + uint16_t intervalTimeMs = 0; /* user config interval time ms, range in [1, 1000] */ + uint64_t thresholdByte = 0; /* user config threshold byte */ + + uint64_t windowEndTimeMs = 0; /* in interval time window, end time trace */ + uint64_t windowPassedByte = 0; /* in interval time window, passed byte */ + + std::mutex nextWindowMutex; /* mutex for build next window information */ + std::condition_variable nextWindowCond; /* condition variable for build next window information */ + + RateLimiter() = default; + + inline bool InvalidateSize(uint32_t size) const + { + return size > thresholdByte; + } + + inline bool AcquireQuota(uint32_t size) const + { + if (size > UINT64_MAX - windowPassedByte) { + return false; + } + return (windowPassedByte + size) <= thresholdByte; + } + + /* wait until next window */ + inline void WaitUntilNextWindow() const + { + uint64_t currentTime = NetMonotonic::TimeMs(); + uint64_t endTime = windowEndTimeMs; + if (currentTime >= endTime) { + return; + } + + if (level == UBSHcomFlowCtrlLevel::HIGH_LEVEL_BLOCK) { + while (NetMonotonic::TimeMs() < endTime) { + } + } else { + usleep((endTime - currentTime) * NN_NO1000); + } + } + + inline void BuildNextWindow() + { + std::unique_lock locker(nextWindowMutex); + nextWindowCond.wait(locker, [&]() { + return !triggering; + }); + + if (NetMonotonic::TimeMs() < windowEndTimeMs) { + return; + } + + triggering = true; + windowEndTimeMs = NetMonotonic::TimeMs() + intervalTimeMs; + triggering = false; + windowPassedByte = 0; + + nextWindowCond.notify_all(); + } +}; + +struct HcomServiceSelfSyncParam { + sem_t sem {}; + int result = NN_OK; + + HcomServiceSelfSyncParam() + { + sem_init(&sem, 0, 0); + } + + ~HcomServiceSelfSyncParam() + { + sem_destroy(&sem); + } + + inline void Wait() + { + int result; + do { + result = sem_wait(&sem); + } while (result == -1 && errno == EINTR); + if (result != 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Sem wait failed with result " << result << ", errno " << errno << ", reason " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + } + } + + inline void Signal() + { + auto result = sem_post(&sem); + if (NN_UNLIKELY(result != 0)) { + NN_LOG_ERROR("Sem post failed " << result); + } + } + + int inline Result() const + { + return result; + } + + inline void Result(int ret) + { + result = ret; + } +}; + +struct HcomServiceMessage { + void *data = nullptr; /* pointer of data */ + uint32_t size = 0; /* size of data */ + bool transferOwner = false; /* reserved, transfer data ownership to hcom, it will be freed after transferred */ + + HcomServiceMessage() = default; + + /** + * @brief Constructor + * @param d [in] pointer of data to be sent or received + * @param s [in] size of data to be sent or received + */ + HcomServiceMessage(void *d, uint32_t s) : data(d), size(s) {} +} __attribute__((packed)); + +/** + * @brief Seq number + */ +union HcomSeqNo { + struct { + /* low address */ + uint32_t realSeq : 24; /* real seq no */ + uint32_t version : 6; /* request version */ + uint32_t fromFlat : 1; /* allocated from flat or hash map */ + uint32_t isResp : 1; /* request or reply, 0 for request, 1 for reply */ + /* high address */ + }; + uint32_t wholeSeq = 0; + + explicit HcomSeqNo(uint32_t whole) : wholeSeq(whole) + { + } + + inline void SetValue(uint32_t flat, uint32_t ver, uint32_t seq) + { + fromFlat = flat; + version = ver; + realSeq = seq; + } + + std::string ToString() const + { + std::ostringstream oss; + oss << "HcomSeqNo info=[wholeSeq: " << wholeSeq << ", isResp: " << isResp + << ", fromFlat: " << fromFlat << ", version: " << version + << ", realSeq: " << realSeq << "]"; + return oss.str(); + } + + inline bool IsResp() const + { + return isResp == 1; + } +}; + +struct SerUuid { + uint32_t ip = 0; + uint64_t channelId = 0; + + SerUuid() = default; + SerUuid(uint32_t ipAddress, uint64_t id) : ip(ipAddress), channelId(id) {} + + inline bool ToString(std::string &out) + { + return BuffToHexString(this, sizeof(SerUuid), out); + } +} __attribute__((packed)); + +struct HcomConnectTimestamp { + uint64_t localTimeUs = 0; /* local time trace when connecting */ + uint64_t remoteTimeUs = 0; /* remote time trace when connecting */ + uint64_t deltaTimeUs = 0; /* delta time for exchange info */ + + HcomConnectTimestamp() = default; + HcomConnectTimestamp(uint64_t lTime, uint64_t rTime, uint64_t dTime) + : localTimeUs(lTime), remoteTimeUs(rTime), deltaTimeUs(dTime) + {} + + uint64_t GetRemoteTimestamp(int16_t timeOutSecond) const; +}; + +struct HcomExchangeTimestamp { + uint64_t timestamp = 0; + uint64_t deltaTimeStamp = 0; + + HcomExchangeTimestamp() = default; +}; + +struct HcomServiceRndvMessage { + uint64_t timestamp = 0; + UBSHcomRequest request{}; + + HcomServiceRndvMessage() = default; + HcomServiceRndvMessage(uint64_t ts, const UBSHcomRequest &req) + : timestamp(ts), request(req) + {} + + bool IsTimeout() const; +}; + +} +} +#endif // HCOM_SERVICE_V2_SERVICE_COMMON_H_ \ No newline at end of file diff --git a/src/service_v2/service_context.cpp b/src/service_v2/service_context.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1b316befc25d613e49de2849edc50325abce7510 --- /dev/null +++ b/src/service_v2/service_context.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "hcom_service_context.h" + +#include "net_monotonic.h" +#include "service_common.h" + +namespace ock { +namespace hcom { + +constexpr uint16_t S_TO_MS_MULTIPLIER = 1000; + +UBSHcomServiceContext::UBSHcomServiceContext(const UBSHcomRequestContext &ctx, UBSHcomChannel *ch) : mCh(ch) +{ + mOpType = ctx.OpType(); + mResult = ctx.Result(); + Ep2ChanUpCtx epCtx(ctx.EndPoint()->UpCtx()); + mEpIdxInCh = epCtx.EpIdx(); + mSeqNo = ctx.Header().seqNo; + + if (ctx.Message() != nullptr) { + mDataType = OUTER_DATA; + mDataLen = ctx.Message()->DataLen(); + mData = ctx.Message()->Data(); + } else { + mDataType = INVALID_DATA; + mDataLen = 0; + mData = nullptr; + } + + mTimeoutTraceMs = 0; + if (ctx.Header().timeout > 0) { + uint64_t timoutMs = static_cast(ctx.Header().timeout) * S_TO_MS_MULTIPLIER; + // NetMonotonic::TimeMs() + timoutMs overflow needs system to run for more than 500 million years + mTimeoutTraceMs = NetMonotonic::TimeMs() + timoutMs; + } + mOpCode = ctx.Header().opCode; + mErrorCode = ctx.Header().errorCode; +} + +} +} \ No newline at end of file diff --git a/src/service_v2/service_ctx_store.h b/src/service_v2/service_ctx_store.h new file mode 100644 index 0000000000000000000000000000000000000000..6b98f164bf9ab46d724aa730813e49c078ccf614 --- /dev/null +++ b/src/service_v2/service_ctx_store.h @@ -0,0 +1,335 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_SERVICE_V2_SERVICE_CTX_STORE_H_ +#define HCOM_SERVICE_V2_SERVICE_CTX_STORE_H_ + +#include "hcom_def.h" +#include "hcom_ref.h" +#include "service_common.h" +#include "common/net_mem_pool_fixed.h" + +namespace ock { +namespace hcom { + +constexpr int32_t MIN_FLAT_CAPACITY = 128; +constexpr int32_t MAX_FLAT_CAPACITY = 16 * 1024 * 1024; +constexpr int32_t HASH_BUCKET_SIZE = 1024; +constexpr int32_t VERSION_SHIFT = 58; +constexpr int32_t BITS_PER_INT = 32; + +class HcomServiceCtxStore { +public: + HcomServiceCtxStore(uint32_t flatCapacity, const NetMemPoolFixedPtr &ctxPool, UBSHcomNetDriverProtocol protocol) + : mFlatCapacity(flatCapacity), mCtxMemPool(ctxPool), mProtocol(protocol) + { + OBJ_GC_INCREASE(HcomServiceCtxStore); + } + + ~HcomServiceCtxStore() + { + UnInitialize(); + OBJ_GC_DECREASE(HcomServiceCtxStore); + } + + /* + * @brief Initialize the ctx store + * + * @return 0 return if successful + */ + NResult Initialize() + { + if (mCtxMemPool.Get() == nullptr) { + NN_LOG_ERROR("Failed to initialize as mem pool for service context store is null"); + return SER_INVALID_PARAM; + } + + /* validate the capacity */ + if (mFlatCapacity < MIN_FLAT_CAPACITY) { + mFlatCapacity = MIN_FLAT_CAPACITY; + } else if (mFlatCapacity > MAX_FLAT_CAPACITY) { + mFlatCapacity = MAX_FLAT_CAPACITY; /* each bucket is an uint64_t, 128MB is occupied */ + } + + /* get aligned capacity */ + mFlatCapacity = 1 << (BITS_PER_INT - __builtin_clz(mFlatCapacity) - 1); + /* get seqNo mask */ + mSeqNoMask = mFlatCapacity - 1; + /* get version shift for move right */ + mVersionShift = __builtin_popcount(mSeqNoMask); + /* get version and seqNo mask, as version occupied 6 bits */ + mSeqNoAndVersionMask = (1 << (mVersionShift + VERSION_BIT_WIDTH)) - 1; + + mFlatCtxBucks = new (std::nothrow) uint64_t[mFlatCapacity]; + if (mFlatCtxBucks == nullptr) { + NN_LOG_ERROR("Failed to new service flat context buckets, probably out of memory"); + return SER_NEW_OBJECT_FAILED; + } + + /* make physical memory allocated and set them to 0 */ + bzero(mFlatCtxBucks, sizeof(uint64_t) * mFlatCapacity); + + /* reserved hash bucket for unordered map */ + for (auto &i : mHashCtxMap) { + i.reserve(HASH_BUCKET_SIZE); + } + + NN_LOG_INFO("Initialized context store, flatten capacity " + << mFlatCapacity << ", versionAndSeqMask " << mSeqNoAndVersionMask << ", seqNoMask " << mSeqNoMask + << ", seqNoAndVersionIndex " << mSeqNoAndVersionIndex); + + return SER_OK; + } + + void UnInitialize() + { + if (mFlatCtxBucks != nullptr) { + delete[] mFlatCtxBucks; + mFlatCtxBucks = nullptr; + } + } + + /* + * @brief Create a seq no, and store it + * + * @param ctx [in] ctx ptr to store + * @param output [out] seqNo created + * + * @return SER_OK if successful + * SER_INVALID_PARAM if param is invalid + * SER_STORE_SEQ_DUP if seq is duplicated in map + */ + template + NResult PutAndGetSeqNo(T *ctx, uint32_t &output) + { + if (NN_UNLIKELY(ctx == nullptr)) { + return SER_INVALID_PARAM; + } + + auto value = reinterpret_cast(ctx); + /* pre-defined variables because of goto */ + HcomSeqNo sn(0); + uint32_t mapIndex = 0; + + /* + * Try to get empty flat bucket 3 times, + * if got emtpy bucket, store it in that flat bucket, + * if not got, store it into hash map + * + * Note: don't do this in a loop (i.e. while), expanded code has better performance than loop + * + * step1: first time to get free flat bucket according to index + */ + + /* get the seqNo with increasing and mask. If the seqNo is 0, increase again */ + auto newSeqAndVersion = __sync_fetch_and_add(&mSeqNoAndVersionIndex, 1); + if (NN_UNLIKELY(newSeqAndVersion & mSeqNoMask) == 0) { + newSeqAndVersion = __sync_fetch_and_add(&mSeqNoAndVersionIndex, 1); + } + + /* get seqNo and version, and mixed value with version and ctx ptr for CAS */ + auto seqNo = newSeqAndVersion & mSeqNoMask; + uint64_t version = (newSeqAndVersion >> mVersionShift) & VERSION_MASK; + value = (version << VERSION_SHIFT) | value; + if (__sync_bool_compare_and_swap(&mFlatCtxBucks[seqNo], 0, value)) { + goto STORE_IN_FLAT; + } + + /* + * step2: second time to get free flat bucket according to index. + */ + newSeqAndVersion = __sync_fetch_and_add(&mSeqNoAndVersionIndex, 1); + if (NN_UNLIKELY(newSeqAndVersion & mSeqNoMask) == 0) { + newSeqAndVersion = __sync_fetch_and_add(&mSeqNoAndVersionIndex, 1); + } + seqNo = newSeqAndVersion & mSeqNoMask; + version = (newSeqAndVersion >> mVersionShift) & VERSION_MASK; + value = (version << VERSION_SHIFT) | value; + if (__sync_bool_compare_and_swap(&mFlatCtxBucks[seqNo], 0, value)) { + goto STORE_IN_FLAT; + } + + /* + * step3: third time to get free flat bucket according to index. + */ + newSeqAndVersion = __sync_fetch_and_add(&mSeqNoAndVersionIndex, 1); + if (NN_UNLIKELY(newSeqAndVersion & mSeqNoMask) == 0) { + newSeqAndVersion = __sync_fetch_and_add(&mSeqNoAndVersionIndex, 1); + } + seqNo = newSeqAndVersion & mSeqNoMask; + version = (newSeqAndVersion >> mVersionShift) & VERSION_MASK; + value = (version << VERSION_SHIFT) | value; + if (__sync_bool_compare_and_swap(&mFlatCtxBucks[seqNo], 0, value)) { + goto STORE_IN_FLAT; + } + + /* step 4: tried 3 times no luck to get an empty bucket, store in hash map. */ + mapIndex = seqNo % HASH_COUNT; + sn.SetValue(0, static_cast(version), seqNo); + output = sn.wholeSeq; + { + std::lock_guard guard(mHashCtxMutex[mapIndex]); + return mHashCtxMap[mapIndex].emplace(sn.wholeSeq, value).second ? SER_OK : SER_STORE_SEQ_DUP; + } + + /* if occupied one flat bucket within 3 times try. */ + STORE_IN_FLAT: + sn.SetValue(1, static_cast(version), seqNo); + output = sn.wholeSeq; + return SER_OK; + } + + /* + * @brief Get the pointer of ctx with seqNo and clean it + * + * @param seqNo [in] seqNo, which whole got from response and timer + * @param out [out] ctx ptr + * + * @return SER_OK if successful + * SER_INVALID_PARAM if param is invalid + * SER_STORE_SEQ_NO_FOUND if seq is not existed, probably removed already + * + */ + template + NResult GetSeqNoAndRemove(uint32_t seqNo, T *&out) + { + HcomSeqNo no(0); + no.wholeSeq = seqNo; + + if (NN_LIKELY(no.fromFlat == 1)) { + /* create the old pointer and */ + if (NN_UNLIKELY(no.realSeq >= mFlatCapacity)) { + return SER_STORE_SEQ_NO_FOUND; + } + uint64_t value = mFlatCtxBucks[no.realSeq] & PTR_MASK; + uint64_t tmpVersion = no.version; + + /* if timeout thread already get seq no, next time will + 1、CAS OK, but get value is 0 + 2、CAS ERR by version++ */ + // 因为ptr是从内存池拿出来的,所以重复的可能性很大,需要加个version验证一下 + if (__sync_bool_compare_and_swap(&mFlatCtxBucks[no.realSeq], (tmpVersion << VERSION_SHIFT) | value, 0)) { + if (NN_UNLIKELY(value == 0)) { + return SER_STORE_SEQ_NO_FOUND; + } + + out = reinterpret_cast(value); + return SER_OK; + } + + return SER_STORE_SEQ_NO_FOUND; + } + + uint32_t mapIndex = no.realSeq % HASH_COUNT; + no.isResp = 0; + { + std::lock_guard guard(mHashCtxMutex[mapIndex]); + auto iter = mHashCtxMap[mapIndex].find(no.wholeSeq); + if (NN_LIKELY(iter != mHashCtxMap[mapIndex].end())) { + out = reinterpret_cast(iter->second & PTR_MASK); + mHashCtxMap[mapIndex].erase(iter); + return SER_OK; + } + } + + return SER_STORE_SEQ_NO_FOUND; + } + + inline void RemoveSeqNo(uint32_t seqNo) + { + uintptr_t *outPtr = nullptr; + if (NN_UNLIKELY(GetSeqNoAndRemove(seqNo, outPtr) != SER_OK)) { + HcomSeqNo dumpSeq(seqNo); + NN_LOG_ERROR("Failed to remove ctx with seqNo " << dumpSeq.ToString() << "as not found"); + return; + } + } + + /* + * @brief Get ctx obj from mem pool + * + * @return ptr of obj if successful + * nullptr if failure + */ + template + inline T *GetCtxObj() + { + return GetOrReturn(nullptr); + } + + /* + * @brief Return ctx obj to mem pool + * + * @param obj [in] ptr of obj get from pool + */ + template + inline void Return(T *obj) + { + /* no need to check obj is nullptr, because is checked in inner function */ + (void)GetOrReturn(obj, false); + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + /* alloc/free in the same function to make sure use the same thread_local variable */ + template + inline T *GetOrReturn(T *returnCtx, bool get = true) + { + static thread_local KeyedThreadLocalCache threadCache; + // 有 2 种场景需要更新: + // - 第一次运行,初始值为 `nullptr`, 需要更新成当前在用的内存池 + // - 主线程不退出,开始时先启动了 Service1, 主线程中进行 Send 会使用 Service1 的内存池;而后 Service1 退出、内存 + // 池回收,主线程中的 `thread_local` cache 仍保存的是 Service1 的内存池地址。在新启动 Service2 后,如果在主线 + // 程中进行 Send 会更新 `thread_local` cache 指向的内存池。此时原有 Service1 的内存池才会真正被归还至 OS. + // + // 注意:上层应当**禁止同时创建同种协议的 2 个不同 Service 实例**,否则此处仍旧会出现 Service2 引用 Service1 内 + // 存池中的地址。 + threadCache.UpdateIf(mProtocol, mCtxMemPool.Get()); + + if (get) { + return threadCache.Allocate(mProtocol); + } else { + threadCache.Free(mProtocol, returnCtx); + return nullptr; + } + } + +private: + static constexpr uint32_t VERSION_MASK = 0x3F; /* mask to reverse version */ + static constexpr uint32_t VERSION_BIT_WIDTH = 6; /* mask to reverse version */ + static constexpr uint32_t HASH_COUNT = 4; /* hash map count */ + static constexpr uint64_t PTR_MASK = 0x03FFFFFFFFFFFFFF; /* ptr mask */ + +private: + /* Note: + * 1 make sure those frequently accessed variables are at first place + * 2 make sure those variables are aligned + * 3 make sure total size of those variables are less than the size of 1 cache line + */ + uint32_t mSeqNoAndVersionIndex = 1; /* atomic increase seqNo and version */ + uint32_t mSeqNoAndVersionMask = 0; /* mask to reverse the seqNo and version */ + uint32_t mSeqNoMask = 0; /* mask to reverse the seqNo */ + uint32_t mVersionShift = 0; /* move right shift num to get version */ + uint32_t mFlatCapacity = 8192; /* flat array capacity */ + uint64_t *mFlatCtxBucks = nullptr; /* actually array to store the ptr */ + NetMemPoolFixedPtr mCtxMemPool = nullptr; /* memory pool of context */ + + std::mutex mHashCtxMutex[HASH_COUNT]; /* mutex to guard unordered_map */ + std::unordered_map mHashCtxMap[HASH_COUNT]; /* unordered_map to store un-flat */ + UBSHcomNetDriverProtocol mProtocol = UBSHcomNetDriverProtocol::UNKNOWN; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; +} // namespace hcom +} // namespace ock + +#endif // HCOM_SERVICE_V2_SERVICE_CTX_STORE_H_ diff --git a/src/service_v2/service_imp.cpp b/src/service_v2/service_imp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a651518a78e50dfa149cdab366039d1153f5e6e2 --- /dev/null +++ b/src/service_v2/service_imp.cpp @@ -0,0 +1,1821 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "service_imp.h" + +#include +#include +#include +#include "securec.h" + +#include "hcom_def.h" +#include "hcom_err.h" +#include "hcom_log.h" +#include "hcom_num_def.h" +#include "api/hcom_service_def.h" +#include "api/hcom_service.h" + +#include "net_common.h" +#include "net_load_balance.h" +#include "net_mem_pool_fixed.h" +#include "net_param_validator.h" + +#include "service_common.h" +#include "service_callback.h" +#include "service_periodic_manager.h" +#include "service_channel_imp.h" + +namespace ock { +namespace hcom { + +constexpr uint16_t MAX_ENABLE_DEVCOUNT = 4; +constexpr uint16_t MAX_TIME_OUT_DETECT_THREAD_NUM = 4; +constexpr uint16_t MAX_USER_OPCODE = 1000; +constexpr uint16_t MAX_SYS_OPCODE = 1024; + +int32_t HcomServiceImp::Bind(const std::string &listenerUrl, const UBSHcomServiceNewChannelHandler &handler) +{ + VALIDATE_PARAM_RET(Bind, listenerUrl, handler); + mOptions.chNewHandler = handler; + NetProtocol protocal; + std::string url; + if (NN_UNLIKELY(!NetFunc::NN_SplitProtoUrl(listenerUrl, protocal, url))) { + NN_LOG_ERROR("Invalid url, should be like tcp://127.0.0.1:9981 or uds://name or ubc://eid:jettyId"); + return SER_INVALID_PARAM; + } + + mOptions.startOobSvr = true; + if (NetProtocol::NET_TCP == protocal) { + mOptions.oobType = NET_OOB_TCP; + return AddTcpOobListener(url); + } else if (NetProtocol::NET_UDS == protocal) { + mOptions.oobType = NET_OOB_UDS; + return AddUdsOobListener(url); + } else if (NetProtocol::NET_UBC == protocal) { + mOptions.oobType = NET_OOB_UB; + std::string eid; + uint16_t jettyId = 0; + if (NN_UNLIKELY(!NetFunc::NN_ConvertEidAndJettyId(url, eid, jettyId))) { + NN_LOG_ERROR("Invalid url: " << url << " should be like 1111:1111:0000:0000:0000:0000:4444:0000:888"); + return NN_PARAM_INVALID; + } + + mOptions.eid = eid; + mOptions.jettyId = jettyId; + + return SER_OK; + } + + NN_LOG_ERROR("Invalid protocal, only support tcp and uds and ubc, url should be like tcp://127.0.0.1:9981 or " + "uds://name or ubc://eid:jettyId"); + return SER_INVALID_PARAM; +} + +SerResult HcomServiceImp::AddTcpOobListener(const std::string &url, uint16_t workerCount) +{ + std::string ip; + uint16_t port; + if (NN_UNLIKELY(!NetFunc::NN_ConvertIpAndPort(url, ip, port))) { + NN_LOG_ERROR("Invalid url, should be like 127.0.0.1:9981"); + return SER_INVALID_PARAM; + } + + UBSHcomNetOobListenerOptions option; + if (NN_UNLIKELY(!option.Set(ip, port, workerCount))) { + NN_LOG_ERROR("Oob Tcp listener set failed"); + return SER_INVALID_PARAM; + } + + if (NN_UNLIKELY(mOptions.oobOption.find(url) != mOptions.oobOption.end())) { + NN_LOG_WARN("Duplicated listen ip/port adding to driver Manager " << + mOptions.name << ", ignored"); + return SER_INVALID_PARAM; + } + + mOptions.oobOption[url] = option; + return SER_OK; +} + +SerResult HcomServiceImp::AddUdsOobListener(const std::string &url, uint16_t workerCount) +{ + std::string name; + uint16_t perm = 0; + if (NN_UNLIKELY(!NetFunc::NN_ConvertNameAndPerm(url, name, perm))) { + NN_LOG_ERROR("Convert url to name and perm failed"); + return SER_INVALID_PARAM; + } + + UBSHcomNetOobUDSListenerOptions option; + option.perm = perm; + if (NN_UNLIKELY(!option.Set(name, workerCount))) { + NN_LOG_ERROR("Oob Uds listener set failed"); + return SER_INVALID_PARAM; + } + + if (NN_UNLIKELY(mOptions.udsOobOption.find(name) != mOptions.udsOobOption.end())) { + NN_LOG_WARN("Duplicated listen url adding to driver " << mOptions.name << + ", ignored"); + return SER_INVALID_PARAM; + } + + mOptions.udsOobOption[name] = option; + return SER_OK; +} + +int32_t HcomServiceImp::Start() +{ + std::lock_guard locker(mStartMutex); + int32_t result = SER_OK; + if (mStarted) { + return SER_OK; + } + + if (NN_UNLIKELY((result = ValidateServiceOption()) != SER_OK)) { + NN_LOG_ERROR("Invalid service info, res:" << result); + return result; + } + + if (NN_UNLIKELY((result = CreateResource()) != SER_OK)) { + NN_LOG_ERROR("CreateResource failed, res:" << result); + return result; + } + if (NN_UNLIKELY((result = InitDriver()) != SER_OK)) { + NN_LOG_ERROR("Driver start failed, res:" << result); + return result; + } + + if (NN_UNLIKELY((result = StartDriver()) != SER_OK)) { + NN_LOG_ERROR("Driver start failed, res:" << result); + return result; + } + + NetPgTable *pgtable = new (std::nothrow) NetPgTable(pgdAlloc, pgdFree); + if (NN_UNLIKELY(pgtable == nullptr)) { + NN_LOG_ERROR("Fail to create pgTable "); + return SER_ERROR; + } + mPgtable = pgtable; + + if (NN_LIKELY(mOptions.protocol != UBSHcomNetDriverProtocol::SHM + && mOptions.protocol != UBSHcomNetDriverProtocol::UDS && !mOptions.ipMasks.empty())) { + mOobIp = GetFilteredDeviceIP(mOptions.ipMasks[0]); + } + + mStarted = true; + return result; +} + +SerResult HcomServiceImp::DoInitDriver() +{ + SerResult res = SER_OK; + UBSHcomNetDriverOptions driverOpt; + ConvertHcomSerImpOptsToHcomDriOpts(mOptions, driverOpt); + RegisterDriverCb(); + uint16_t driverIdx = 0; + for (auto &driver : mDriverPtrs) { + if (driverIdx >= mOptions.workerGroupInfos.size()) { + driverOpt.SetWorkerGroupsInfo(mOptions.workerGroupInfos[0]); + } else { + driverOpt.SetWorkerGroupsInfo(mOptions.workerGroupInfos[driverIdx]); + ++driverIdx; + } + driver->RegisterNewEPHandler(std::bind(&HcomServiceImp::ServiceHandleNewEndPoint, this, + std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + driver->RegisterEPBrokenHandler(std::bind(&HcomServiceImp::ServiceEndPointBroken, this, std::placeholders::_1)); + driver->RegisterNewReqHandler(std::bind(&HcomServiceImp::ServiceRequestReceived, this, std::placeholders::_1)); + driver->RegisterReqPostedHandler(std::bind(&HcomServiceImp::ServiceRequestPosted, this, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&HcomServiceImp::ServiceOneSideDone, this, std::placeholders::_1)); + + if (mOptions.connSecOption.provider != nullptr) { + driver->RegisterEndpointSecInfoProvider(std::bind(&HcomServiceImp::ServiceSecInfoProvider, this, + std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, + std::placeholders::_5, std::placeholders::_6)); + } + + if (mOptions.connSecOption.validator != nullptr) { + driver->RegisterEndpointSecInfoValidator(std::bind(&HcomServiceImp::ServiceSecInfoValidator, this, + std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4)); + } + + res = driver->Initialize(driverOpt); + if (NN_UNLIKELY(res != SER_OK)) { + ForceStop(); + return res; + } + } + + if (mOptions.startOobSvr && mOptions.enableMultiRail && mOptions.protocol == UBSHcomServiceProtocol::RDMA) { + res = CreateOobListeners(driverOpt); + if (NN_UNLIKELY(res != SER_OK)) { + ForceStop(); + return res; + } + + /* set cb for listeners */ + for (auto &oobServer : mOobServers) { + oobServer->SetNewConnCB(std::bind(&HcomServiceImp::NewConnectionCB, this, std::placeholders::_1)); + oobServer->SetNewConnCbThreadNum(driverOpt.oobConnHandleThreadCount); + oobServer->SetNewConnCbQueueCap(driverOpt.oobConnHandleQueueCap); + } + } + return SER_OK; +} + +SerResult HcomServiceImp::CreateOobUdsListeners(const UBSHcomNetDriverOptions &driverOpt) +{ + if (mOptions.udsOobOption.empty()) { + NN_LOG_ERROR("No listen info is set in driver " << mOptions.name); + return SER_INVALID_PARAM; + } + + if (mOptions.udsOobOption.size() > NN_NO65535) { + NN_LOG_ERROR("udsOobOption size is over 65535 in driver " << mOptions.name); + return SER_INVALID_PARAM; + } + + uint16_t oobIndex = 0; + for (auto &lOpt : mOptions.udsOobOption) { + NetOOBServerPtr oobServer = nullptr; + /* create oob server */ + if (driverOpt.enableTls) { // to check + auto oobSSLServer = new (std::nothrow) OOBSSLServer(driverOpt.oobType, lOpt.second.Name(), + lOpt.second.perm, mOptions.tlsOption.pkCb, mOptions.tlsOption.cfCb, mOptions.tlsOption.caCb); + NN_ASSERT_LOG_RETURN(oobSSLServer != nullptr, NN_NEW_OBJECT_FAILED) + oobSSLServer->SetTlsOptions(mOptions.tlsOption.netCipherSuite, mOptions.tlsOption.tlsVersion); + oobSSLServer->SetPSKCallback(mOptions.tlsOption.pskFindCb, mOptions.tlsOption.pskUseCb); + oobServer = oobSSLServer; + } else { + oobServer = new (std::nothrow) + OOBTCPServer(driverOpt.oobType, lOpt.second.Name(), lOpt.second.perm, lOpt.second.isCheck); + NN_ASSERT_LOG_RETURN(oobServer.Get() != nullptr, NN_NEW_OBJECT_FAILED) + } + oobServer->Index({ 0, oobIndex++ }); + oobServer->SetMaxConntionNum(driverOpt.maxConnectionNum); + oobServer->SetMultiRail(driverOpt.enableMultiRail); + oobServer->IncreaseRef(); + mOobServers.emplace_back(oobServer.Get()); + } + + if (mOptions.udsOobOption.size() != mOobServers.size()) { + NN_LOG_ERROR("Created oob server count " << mOobServers.size() << " is not equal to listener options size " << + mOptions.udsOobOption.size() << " in uds driver " << mOptions.name); + return SER_ERROR; + } + + return SER_OK; +} + +SerResult HcomServiceImp::CreateOobListeners(const UBSHcomNetDriverOptions &driverOpt) +{ + if (driverOpt.oobType != NET_OOB_UDS && driverOpt.oobType != NET_OOB_TCP) { + NN_LOG_ERROR("Un-supported oob type " << driverOpt.oobType << " is set in driver Manager " << + mOptions.name); + return SER_INVALID_PARAM; + } else if (driverOpt.oobType == NET_OOB_UDS) { + return CreateOobUdsListeners(driverOpt); + } + + if (mOptions.oobOption.empty()) { + NN_LOG_ERROR("No listen info is set for oob type " << UBSHcomNetDriverOobTypeToString(driverOpt.oobType) << + " in driver " << mOptions.name); + return SER_INVALID_PARAM; + } + + if (mOptions.oobOption.size() > NN_NO65535) { + NN_LOG_ERROR("OobOption size is over 65535 in driver " << mOptions.name); + return SER_INVALID_PARAM; + } + + uint16_t oobIndex = 0; + for (auto &lOpt : mOptions.oobOption) { + NetOOBServerPtr oobServer = nullptr; + if (driverOpt.enableTls) { + auto oobSSLServer = new (std::nothrow) OOBSSLServer(driverOpt.oobType, lOpt.second.Ip(), + lOpt.second.port, mOptions.tlsOption.pkCb, mOptions.tlsOption.cfCb, mOptions.tlsOption.caCb); + NN_ASSERT_LOG_RETURN(oobSSLServer != nullptr, NN_NEW_OBJECT_FAILED) + oobSSLServer->SetTlsOptions(mOptions.tlsOption.netCipherSuite, mOptions.tlsOption.tlsVersion); + oobSSLServer->SetPSKCallback(mOptions.tlsOption.pskFindCb, mOptions.tlsOption.pskUseCb); + oobServer = oobSSLServer; + } else { + oobServer = new (std::nothrow) OOBTCPServer(driverOpt.oobType, lOpt.second.Ip(), lOpt.second.port); + NN_ASSERT_LOG_RETURN(oobServer.Get() != nullptr, NN_NEW_OBJECT_FAILED) + } + + NN_LOG_TRACE_INFO(lOpt.second.Ip()); + oobServer->Index({ 0, oobIndex++ }); + oobServer->SetMaxConntionNum(driverOpt.maxConnectionNum); + oobServer->SetMultiRail(driverOpt.enableMultiRail); + oobServer->IncreaseRef(); + mOobServers.emplace_back(oobServer.Get()); + } + + if (mOptions.oobOption.size() != mOobServers.size()) { + NN_LOG_ERROR("Created oob server count " << mOobServers.size() << " is not equal to listener options size " << + mOptions.oobOption.size() << " in driver " << mOptions.name); + return SER_ERROR; + } + + return SER_OK; +} + +SerResult HcomServiceImp::InitDriver() +{ + if (mOptions.enableMultiRail && mOptions.protocol == UBSHcomServiceProtocol::RDMA) { + if (CreateMultiRailDriver() != SER_OK) { + NN_LOG_ERROR("failed to create driver for service " << mOptions.name); + return SER_ERROR; + } + return DoInitDriver(); + } + + UBSHcomNetDriver *driver = UBSHcomNetDriver::Instance(mOptions.protocol, mOptions.name, mOptions.startOobSvr); + if (driver == nullptr) { + NN_LOG_ERROR("failed to create driver for service " << mOptions.name); + return SER_ERROR; + } + mDriverPtrs.emplace_back(driver); + + if (mOptions.startOobSvr) { + for (auto &option : mOptions.oobOption) { + driver->AddOobOptions(option.second); + } + for (auto &option : mOptions.udsOobOption) { + driver->AddOobUdsOptions(option.second); + } + + if (mOptions.oobType == NET_OOB_UB) { + driver->OobEidAndJettyId(mOptions.eid, mOptions.jettyId); + } + } + return DoInitDriver(); +} + + +SerResult HcomServiceImp::CreateMultiRailDriver() +{ + uint16_t enableDevCount = 0; + std::string ipMasksStr; + NetFunc::NN_VecStrToStr(mOptions.ipMasks, ",", ipMasksStr); + std::string ipGroupsStr; + NetFunc::NN_VecStrToStr(mOptions.ipGroups, ";", ipGroupsStr); + + if (NN_UNLIKELY(!UBSHcomNetDriver::MultiRailGetDevCount(mOptions.protocol, ipMasksStr, enableDevCount, + ipGroupsStr))) { + NN_LOG_ERROR("Failed to new multi rail service, because not get active RDMA devices. "); + return SER_ERROR; + } + + if (NN_UNLIKELY((enableDevCount == 0) || (enableDevCount > MAX_ENABLE_DEVCOUNT))) { + NN_LOG_ERROR("The number of available devices is " << enableDevCount << ", only 1~" << MAX_ENABLE_DEVCOUNT << + " driver is allowed in MultiRail Service."); + return SER_ERROR; + } + mDriverPtrs.reserve(enableDevCount); + for (uint16_t i = 0; i < enableDevCount; i++) { + UBSHcomNetDriver *driver = UBSHcomNetDriver::Instance(mOptions.protocol, + mOptions.name + "_" + std::to_string(i), mOptions.startOobSvr); + if (NN_UNLIKELY(driver == nullptr)) { + NN_LOG_WARN("Failed to new driver in devIndex " << i << "for " << RDMA); + continue; + } + driver->SetDeviceId(i); + NN_LOG_INFO("create driver " << driver->Name()); + mDriverPtrs.emplace_back(driver); + driver->IncreaseRef(); + } + return SER_OK; +} + +SerResult HcomServiceImp::StartDriver() +{ + SerResult result = SER_OK; + for (auto &driver : mDriverPtrs) { + result = driver->Start(); + if (NN_UNLIKELY(result != SER_OK)) { + ForceStop(); + return result; + } + } + + if (!mOptions.startOobSvr) { + return SER_OK; + } + + for (uint32_t i = 0; i < mOobServers.size(); i++) { + if (NN_UNLIKELY(mOobServers[i] == nullptr)) { + for (uint32_t j = 0; j < i; j++) { + mOobServers[j]->Stop(); + } + return result; + } + if ((result = mOobServers[i]->Start()) != SER_OK) { + for (uint32_t j = 0; j < i; j++) { + mOobServers[j]->Stop(); + } + return result; + } + } + return SER_OK; +} + +void HcomServiceImp::ForceStop() +{ + for (auto &server : mOobServers) { + server->Stop(); + server->DecreaseRef(); + } + mOobServers.clear(); + + for (auto &driver : mDriverPtrs) { + driver->Stop(); + driver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(driver->Name()); + } + mDriverPtrs.clear(); + + if (mPeriodicMgr.Get() != nullptr) { + mPeriodicMgr->Stop(); + mPeriodicMgr.Set(nullptr); + } + + if (mContextMemPool.Get() != nullptr) { + mContextMemPool.Set(nullptr); + } + + if (mPgtable.Get() != nullptr) { + mPgtable->Cleanup(); + mPgtable.Set(nullptr); + } + mStarted = false; +} + +void HcomServiceImp::RegisterDriverCb() +{ + if (NN_UNLIKELY(mOptions.tlsOption.enableTls)) { + for (auto &driver : mDriverPtrs) { + driver->RegisterTLSCaCallback(mOptions.tlsOption.caCb); + driver->RegisterTLSCertificationCallback(mOptions.tlsOption.cfCb); + driver->RegisterTLSPrivateKeyCallback(mOptions.tlsOption.pkCb); + driver->RegisterPskFindSessionCb(mOptions.tlsOption.pskFindCb); + driver->RegisterPskUseSessionCb(mOptions.tlsOption.pskUseCb); + } + } +} + +SerResult HcomServiceImp::ValidateServiceOption() +{ + if (NN_UNLIKELY(mOptions.timeOutDetectThreadNum == 0 + || mOptions.timeOutDetectThreadNum > MAX_TIME_OUT_DETECT_THREAD_NUM)) { + NN_LOG_ERROR("Invalid time out detect thread num " << mOptions.timeOutDetectThreadNum << ", must range [1, 4]"); + return SER_INVALID_PARAM; + } + + if (NN_UNLIKELY(mOptions.chBrokenHandler == nullptr)) { + NN_LOG_ERROR("Invoke RegisterChannelBrokenHandler to register callback first"); + return SER_INVALID_PARAM; + } + + if (NN_UNLIKELY(mOptions.recvHandler == nullptr)) { + NN_LOG_ERROR("Invoke RegisterRecvHandler to register callback first"); + return SER_INVALID_PARAM; + } + + if (NN_UNLIKELY(mOptions.sendHandler == nullptr)) { + NN_LOG_ERROR("Invoke RegisterSendHandler to register callback first"); + return SER_INVALID_PARAM; + } + + if (NN_UNLIKELY(mOptions.oneSideDoneHandler == nullptr)) { + NN_LOG_ERROR("Invoke RegisterOneSideHandler to register callback first"); + return SER_INVALID_PARAM; + } + + return SER_OK; +} + +SerResult HcomServiceImp::CreateResource() +{ + SerResult res = SER_OK; + if ((res = CreatePeriodicMgr()) != SER_OK) { + NN_LOG_ERROR("CreatePeriodicMgr failed"); + return res; + } + + if ((res = CreateCtxMemPool()) != SER_OK) { + NN_LOG_ERROR("CreateCtxStore failed"); + return res; + } + return res; +} + +SerResult HcomServiceImp::CreatePeriodicMgr() +{ + HcomPeriodicManagerPtr periodicMgr + = new (std::nothrow) HcomPeriodicManager(mOptions.timeOutDetectThreadNum, mOptions.name); + if (NN_UNLIKELY(periodicMgr.Get() == nullptr)) { + NN_LOG_ERROR("Create periodic manager failed"); + return SER_NEW_OBJECT_FAILED; + } + if (NN_UNLIKELY(periodicMgr->Start() != SER_OK)) { + NN_LOG_ERROR("Start periodic manager failed"); + return SER_TIMER_NOT_WORK; + } + mPeriodicMgr = periodicMgr; + return SER_OK; +} + +SerResult HcomServiceImp::CreateCtxMemPool() +{ + NetMemPoolFixedOptions options = {}; + options.superBlkSizeMB = NN_NO1; + options.minBlkSize = NN_NO64; + if (mOptions.enableRndv) { + options.minBlkSize = NN_NO64 * NN_NO4; + } + options.tcExpandBlkCnt = NN_NO256; + NetMemPoolFixedPtr contextMemPool = + new (std::nothrow) NetMemPoolFixed("ServiceContextTimer-" + mOptions.name, options); + if (NN_UNLIKELY(contextMemPool.Get() == nullptr)) { + NN_LOG_ERROR("Create mem pool failed"); + return SER_NEW_OBJECT_FAILED; + } + + auto ret = contextMemPool->Initialize(); + if (NN_UNLIKELY(ret != SER_OK)) { + NN_LOG_ERROR("Init mem pool failed"); + return SER_NEW_OBJECT_FAILED; + } + + mContextMemPool = contextMemPool; + return SER_OK; +} + +int32_t HcomServiceImp::DoDestroy(const std::string &name) +{ + if (mStarted) { + ForceStop(); + } + return SER_OK; +} + +int32_t HcomServiceImp::Connect(const std::string &serverUrl, UBSHcomChannelPtr &ch, const UBSHcomConnectOptions &opt) +{ + if (!mStarted) { + NN_LOG_ERROR("Failed to validate state as service is not started"); + return SER_STOP; + } + + VALIDATE_PARAM_RET(ConnectOptions, opt); + SerResult res = SER_OK; + + UBSHcomChannelPtr tmpChannel; + const uint32_t version = 0; + SerConnInfo connInfo(version, NetUuid::GenerateUuid(mOobIp), mDriverPtrs.size(), + mOptions.chBrokenPolicy, opt); + + res = DoConnect(serverUrl, connInfo, opt.payload, tmpChannel); + if (NN_UNLIKELY(res != SER_OK)) { + NN_LOG_ERROR("Failed to DoConnect, result: " << res); + return res; + } + + res = ExchangeTimestamp(tmpChannel.Get()); + if (NN_UNLIKELY(res != SER_OK)) { + NN_LOG_ERROR("Failed to exchange timestamp in service connect"); + Disconnect(tmpChannel); + return res; + } + + std::string uuid; + if (NN_UNLIKELY(GenerateUuid(tmpChannel->GetLocalIp(), tmpChannel->GetId(), uuid) != SER_OK)) { + res = SER_INVALID_PARAM; + NN_LOG_ERROR("Failed to Generate uuid"); + Disconnect(tmpChannel); + return res; + } + + tmpChannel->SetUuid(uuid); + tmpChannel->SetBrokenInfo(connInfo.policy, mOptions.chBrokenHandler); + if (NN_UNLIKELY(EmplaceChannelUuid(tmpChannel) != SER_OK)) { + res = SER_CHANNEL_ID_DUP; + NN_LOG_ERROR("Failed to Emplace uuid"); + Disconnect(tmpChannel); + return res; + } + tmpChannel->SetMultiRail(mOptions.enableMultiRail, mOptions.multiRailThresh); + tmpChannel->SetDriverNum(mDriverPtrs.size()); + tmpChannel->SetPayload(opt.payload); + ch = tmpChannel; + return res; +} + +SerResult HcomServiceImp::DoConnectInner(const std::string &serverUrl, SerConnInfo &opt, const std::string &payLoad, + std::vector &epVector, uint32_t &totalBandWidth) +{ + SerResult res = SER_OK; + opt.totalLinkCount = opt.options.linkCount * mDriverPtrs.size(); + for (int j = 0; j < static_cast(mDriverPtrs.size()); ++j) { + opt.driverIndex = mDriverPtrs[j]->GetDeviceId(); + for (uint8_t i = 0; i < opt.options.linkCount; i++) { + opt.index = i + j * opt.options.linkCount; + opt.SetCrc32(); + std::string serializeConnInfo; + if (SerConnInfo::Serialize(opt, payLoad, serializeConnInfo) != SER_OK) { + NN_LOG_ERROR("Failed to serializable payload for connect"); + return SER_INVALID_PARAM; + } + + UBSHcomNetEndpointPtr ep; + auto result = mDriverPtrs[j]->Connect(serverUrl, serializeConnInfo, ep, + static_cast(opt.options.mode), opt.options.serverGroupId, opt.options.clientGroupId, + opt.channelId); + if (NN_LIKELY(result == SER_OK)) { + epVector.emplace_back(ep); + continue; + } + // 失败处理 + for (auto &iter : epVector) { + iter->Close(); + } + { + std::lock_guard lockerEp(mNewEpMutex); + mSecInfoMap.erase(opt.channelId); + NN_LOG_ERROR("Failed to connect " << result); + return result; + } + } + totalBandWidth += mDriverPtrs[j]->GetBandWidth(); + } + return res; +} + +SerResult HcomServiceImp::DoConnect(const std::string &serverUrl, SerConnInfo &opt, const std::string &payLoad, + UBSHcomChannelPtr &channel) +{ + SerResult res = SER_OK; + std::vector epVector; + epVector.reserve(opt.options.linkCount * mDriverPtrs.size()); + uint32_t totalBandWidth = 0; + res = DoConnectInner(serverUrl, opt, payLoad, epVector, totalBandWidth); + if (NN_UNLIKELY(res != SER_OK)) { + NN_LOG_ERROR("Failed to connect , as " << res); + return res; + } + + bool selfPoll = (opt.options.mode == UBSHcomClientPollingMode::SELF_POLL_BUSY || + opt.options.mode == UBSHcomClientPollingMode::SELF_POLL_EVENT); + + UBSHcomChannelPtr tmpChannel = new (std::nothrow) + HcomChannelImp(opt.channelId, selfPoll, opt.options, Protocol(), mOptions.maxSendRecvDataSize); + if (NN_UNLIKELY(tmpChannel == nullptr)) { + NN_LOG_ERROR("Failed to new channel obj"); + for (auto &iter : epVector) { + iter->Close(); + } + std::lock_guard lockerEp(mNewEpMutex); + mSecInfoMap.erase(opt.channelId); + return SER_NEW_OBJECT_FAILED; + } + + tmpChannel->SetEnableMrCache(mEnableMrCache); + + if (NN_UNLIKELY(tmpChannel->Initialize(epVector, reinterpret_cast(mContextMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())))) { + for (auto &iter : epVector) { + if (iter != nullptr) { + iter->Close(); + } + } + std::lock_guard lockerEp(mNewEpMutex); + mSecInfoMap.erase(opt.channelId); + return SER_NEW_OBJECT_FAILED; + } + tmpChannel->SetTotalBandWidth(totalBandWidth); + NN_LOG_INFO(tmpChannel->ToString()); + channel = tmpChannel; + std::lock_guard lockerEp(mNewEpMutex); + mSecInfoMap.erase(opt.channelId); + return SER_OK; +} + +void HcomServiceImp::DoChooseDriver(uint8_t devInex, uint8_t bandWidth, + int8_t &selectDevIndex, uint8_t &selectBandWidth, UBSHcomNetDriver *&driver) +{ + bool isUsed = false; + for (auto it = mDriverPair.begin(); it != mDriverPair.end(); ++it) { + if (it->second == devInex) { + // The peer driver has established, which is not the first time + isUsed = true; + selectDevIndex = it->first; + break; + } + } + if (isUsed) { + for (auto driverPtr : mDriverPtrs) { + if (driverPtr->GetDeviceId() == selectDevIndex) { + driver = driverPtr.Get(); + } + } + return; + } + // 1. find driver + for (uint16_t i = 0; i < static_cast(mDriverPtrs.size()); ++i) { + selectBandWidth = mDriverPtrs[i]->GetBandWidth(); + selectDevIndex = mDriverPtrs[i]->GetDeviceId(); + + // 1.1 find the driver maximum bandwidth + if (bandWidth > selectBandWidth) { + continue; + } + // 1.2 Used or Not + bool isFound = false; + for (int j = 0; j < static_cast(mUseId.size()); ++j) { + if (mUseId[j] == selectDevIndex) { + isFound = true; + break; + } + } + + if (isFound) { + // 1.2.1 Already used + continue; + } + // 1.2.2 find + mUseId.emplace_back(selectDevIndex); + mDriverPair.emplace(selectDevIndex, devInex); + break; + } + + // 2. if no found, random select a driver. + auto innerIdx = __sync_fetch_and_add(&mDriverIndex, 1) % mDriverPtrs.size(); + driver = mDriverPtrs[innerIdx].Get(); + selectDevIndex = driver->GetDeviceId(); + mDriverPair.emplace(selectDevIndex, devInex); +} + +SerResult HcomServiceImp::ChooseDriver(OOBTCPConnection &conn, UBSHcomNetDriver *&driver) +{ + ConnectHeader header{}; + void *receiveBuf = &header; + auto result = conn.Receive(receiveBuf, sizeof(ConnectHeader)); + if (result != 0) { + NN_LOG_ERROR("Failed to receive specified device info , result " << result); + return result; + } + uint8_t bandWidth = header.bandWidth; + uint8_t devInex = header.devIndex; + uint8_t selectBandWidth = 0; + int8_t selectDevIndex = -1; + DoChooseDriver(devInex, bandWidth, selectDevIndex, selectBandWidth, driver); + + if (NN_UNLIKELY(driver == nullptr)) { + NN_LOG_ERROR("Failed to select driver when peer connect. "); + return SER_ERROR; + } + selectBandWidth = driver->GetBandWidth(); + driver->SetPeerDevId(devInex); + ConnectHeader driverHeader; + SetDriverConnHeader(driverHeader, selectBandWidth, static_cast(selectDevIndex)); + result = conn.Send(&driverHeader, sizeof(ConnectHeader)); + if (result != 0) { + NN_LOG_ERROR("Send driver info to client failed " << driver->Name() << ", result " << result); + } + return result; +} + +SerResult HcomServiceImp::NewConnectionCB(OOBTCPConnection &conn) +{ + // choose driver + UBSHcomNetDriver *driver = nullptr; + auto result = ChooseDriver(conn, driver); + if (NN_UNLIKELY(result != SER_OK)) { + return result; + } + + return driver->MultiRailNewConnection(conn); +} + +void HcomServiceImp::Disconnect(const UBSHcomChannelPtr &ch) +{ + if (ch.Get() != nullptr) { + ch->UnInitialize(); + } +} +// create one mr for each driver +int32_t HcomServiceImp::RegisterMemoryRegion(uint64_t size, UBSHcomRegMemoryRegion &mr) +{ + if (mDriverPtrs.size() == 0) { + NN_LOG_ERROR("RegisterMemoryRegion failed, as driverPtr not created"); + return NN_ERROR; + } + int32_t res = 0; + auto &netMrs = mr.GetHcomMrs(); + uint32_t driverSize = mDriverPtrs.size(); + netMrs.reserve(driverSize); + uint32_t i = 0; + for (; i < driverSize; i++) { + auto driverPtr = mDriverPtrs[i].Get(); + if (driverPtr == nullptr) { + NN_LOG_ERROR("CreateMemoryRegion failed because driverPtr empty"); + break; + } + UBSHcomMemoryRegionPtr netMr; + res = driverPtr->CreateMemoryRegion(size, netMr); + if (res != 0) { + NN_LOG_ERROR("CreateMemoryRegion failed, res:" << res); + break; + } + if (mEnableMrCache) { + res = InsertPgTable(netMr); + if (res != SER_OK) { + break; + } + } + netMrs.emplace_back(netMr); + } + + if (i < driverSize) { + DestroyNetMrs(netMrs, 0, i); + } + + return res; +} + +int32_t HcomServiceImp::RegisterMemoryRegion(uintptr_t address, uint64_t size, UBSHcomRegMemoryRegion &mr) +{ + if (mDriverPtrs.size() == 0) { + NN_LOG_ERROR("RegisterMemoryRegion failed, as driver not created"); + return NN_ERROR; + } + int32_t res = 0; + auto &netMrs = mr.GetHcomMrs(); + uint32_t driverSize = mDriverPtrs.size(); + netMrs.reserve(driverSize); + uint32_t i = 0; + for (; i < driverSize; i++) { + auto driver = mDriverPtrs[i].Get(); + if (driver == nullptr) { + NN_LOG_ERROR("CreateMemoryRegion failed because driver empty"); + break; + } + UBSHcomMemoryRegionPtr netMr; + res = driver->CreateMemoryRegion(address, size, netMr); + if (res != 0) { + NN_LOG_ERROR("CreateMemoryRegion failed, res:" << res); + break; + } + if (mEnableMrCache) { + res = InsertPgTable(netMr); + if (res != SER_OK) { + break; + } + } + netMrs.emplace_back(netMr); + } + + if (i < driverSize) { + DestroyNetMrs(netMrs, 0, i); + } + + return res; +} + +SerResult HcomServiceImp::InsertPgTable(UBSHcomNetMemoryRegionPtr &mr) +{ + SerResult res = SER_OK; + PgtRegion *pgtRegion = new (std::nothrow) PgtRegion(); + if (pgtRegion == nullptr) { + res = NN_ERROR; + NN_LOG_ERROR("Fail to new PgtRegion, res:" << res); + return res; + } + // pgtRegion [start,end) 使用pgTable首地址和(尾地址+1)需要16字节对齐,若使用UBC协议则UBC硬件限制需要支持4096字节对齐 + pgtRegion->start = mr->GetAddress(); + pgtRegion->end = mr->GetAddress() + mr->Size(); + pgtRegion->key = mr->GetLKey(); + pgtRegion->token = reinterpret_cast(mr->GetMemorySeg()); + res = mPgtable->Insert(*pgtRegion); + if (res != NN_OK) { + NN_LOG_ERROR("CreateMemoryRegion insert pgTable fail, res:" << res); + delete pgtRegion; + return res; + } + mr->mPgRegion = reinterpret_cast(pgtRegion); + return res; +} + +void HcomServiceImp::DestroyMemoryRegion(UBSHcomRegMemoryRegion &mr) +{ + auto &netMrs = mr.GetHcomMrs(); + if (NN_UNLIKELY(netMrs.empty())) { + NN_LOG_WARN("No need to destroy as UBSHcomMemoryRegionPtr is empty"); + return; + } + + if (NN_UNLIKELY(netMrs.size() != mDriverPtrs.size())) { + NN_LOG_WARN("Size of UBSHcomMemoryRegionPtr is not equal to dirvers, mr size:" << netMrs.size() << + ", driver size:" << mDriverPtrs.size()); + return; + } + + DestroyNetMrs(netMrs, 0, netMrs.size()); +} + +void HcomServiceImp::SetEnableMrCache(bool enableMrCache) +{ + mEnableMrCache = enableMrCache; +} + +void HcomServiceImp::DestroyNetMrs(std::vector &netMrs, uint32_t start, uint32_t end) +{ + for (uint32_t i = start; i < end; i++) { + uintptr_t delPgRegion = netMrs[i]->mPgRegion; + mDriverPtrs[i]->DestroyMemoryRegion(netMrs[i]); + if (!mEnableMrCache) { + continue; + } + PgtRegion *pgtRegion = reinterpret_cast(delPgRegion); + if (pgtRegion != nullptr) { + SerResult res = mPgtable->Remove(*pgtRegion); + if (res != 0) { + NN_LOG_WARN("Unable to Remove PgTable in destroyMemoryRegion, res:" << res); + } + delete pgtRegion; + netMrs[i]->mPgRegion = 0; + } + } + netMrs.clear(); +} + +void HcomServiceImp::RegisterChannelBrokenHandler(const UBSHcomServiceChannelBrokenHandler &handler, + const UBSHcomChannelBrokenPolicy policy) +{ + mOptions.chBrokenHandler = handler; + mOptions.chBrokenPolicy = policy; +} + +void HcomServiceImp::RegisterIdleHandler(const UBSHcomServiceIdleHandler &handler) +{ + mOptions.idleHandler = handler; +} + +void HcomServiceImp::RegisterRecvHandler(const UBSHcomServiceRecvHandler &recvHandler) +{ + mOptions.recvHandler = recvHandler; +} + +void HcomServiceImp::RegisterSendHandler(const UBSHcomServiceSendHandler &sendHandler) +{ + mOptions.sendHandler = sendHandler; +} + +void HcomServiceImp::RegisterOneSideHandler(const UBSHcomServiceOneSideDoneHandler &oneSideDoneHandler) +{ + mOptions.oneSideDoneHandler = oneSideDoneHandler; +} + +void HcomServiceImp::AddWorkerGroup(uint16_t workerGroupId, uint32_t threadCount, + const std::pair &cpuIdsRange, int8_t priority, uint16_t multirailIdx) +{ + if (multirailIdx >= MAX_MULTI_RAIL_NUM) { + NN_LOG_ERROR("Invalid multirailIdx, should be in range [0, 3]"); + return; + } + UBSHcomWorkerGroupInfo groupInfo; + groupInfo.threadPriority = priority; + groupInfo.groupId = workerGroupId; + groupInfo.threadCount = threadCount; + groupInfo.cpuIdsRange = cpuIdsRange; + { + std::lock_guard locker(mOptionsMutex); + if (mOptions.workerGroupInfos.size() <= multirailIdx) { + mOptions.workerGroupInfos.resize(multirailIdx + 1); + } + mOptions.workerGroupInfos[multirailIdx].emplace_back(groupInfo); + } +} + +void HcomServiceImp::AddListener(const std::string &url, uint16_t workerCount) +{ + NetProtocol protocal; + std::string urlSuffix; + if (NN_UNLIKELY(!NetFunc::NN_SplitProtoUrl(url, protocal, urlSuffix))) { + NN_LOG_ERROR("Invalid url, should be like tcp://127.0.0.1:9981 or uds://name"); + return; + } + + if (NetProtocol::NET_TCP == protocal) { + AddTcpOobListener(urlSuffix, workerCount); + } else if (NetProtocol::NET_UDS == protocal) { + AddUdsOobListener(urlSuffix, workerCount); + } +} + +void HcomServiceImp::SetConnectLBPolicy(UBSHcomServiceLBPolicy lbPolicy) +{ + mOptions.lbPolicy = lbPolicy; +} + +void HcomServiceImp::SetTlsOptions(const UBSHcomTlsOptions &opt) +{ + mOptions.tlsOption = opt; +} + +void HcomServiceImp::SetConnSecureOpt(const UBSHcomConnSecureOptions &opt) +{ + mOptions.connSecOption = opt; +} + +void HcomServiceImp::SetTcpUserTimeOutSec(uint16_t timeOutSec) +{ + mOptions.tcpTimeOutSec = timeOutSec; +} + +void HcomServiceImp::SetTcpSendZCopy(bool tcpSendZCopy) +{ + mOptions.tcpSendZCopy = tcpSendZCopy; +} + +void HcomServiceImp::SetDeviceIpMask(const std::vector &ipMasks) +{ + mOptions.ipMasks = ipMasks; +} + +void HcomServiceImp::SetDeviceIpGroups(const std::vector &ipGroups) +{ + mOptions.ipGroups = ipGroups; +} + +void HcomServiceImp::SetCompletionQueueDepth(uint16_t depth) +{ + mOptions.completionQueueDepth = depth; +} + +void HcomServiceImp::SetSendQueueSize(uint32_t sqSize) +{ + mOptions.qpSendQueueSize = sqSize; +} + +void HcomServiceImp::SetRecvQueueSize(uint32_t rqSize) +{ + mOptions.qpRecvQueueSize = rqSize; +} + +void HcomServiceImp::SetQueuePrePostSize(uint32_t prePostSize) +{ + mOptions.qpPrePostSize = prePostSize; +} + +void HcomServiceImp::SetPollingBatchSize(uint16_t pollSize) +{ + mOptions.pollingBatchSize = pollSize; +} + +void HcomServiceImp::SetEventPollingTimeOutUs(uint16_t pollTimeout) +{ + mOptions.eventPollingTimeOutUs = pollTimeout; +} + +void HcomServiceImp::SetTimeOutDetectionThreadNum(uint32_t threadNum) +{ + mOptions.timeOutDetectThreadNum = threadNum; +} + +void HcomServiceImp::SetMaxConnectionCount(uint32_t maxConnCount) +{ + mOptions.maxConnCount = maxConnCount; +} + +void HcomServiceImp::SetHeartBeatOptions(const UBSHcomHeartBeatOptions &opt) +{ + mOptions.heartBeatOption = opt; +} + +void HcomServiceImp::SetMultiRailOptions(const UBSHcomMultiRailOptions &opt) +{ + mOptions.enableMultiRail = opt.enable; + mOptions.multiRailThresh = opt.threshold; +} + +void HcomServiceImp::SetUbcMode(UBSHcomUbcMode ubcMode) +{ + mOptions.ubcMode = ubcMode; +} + +void HcomServiceImp::SetMaxSendRecvDataCount(uint32_t maxSendRecvDataCount) +{ + mOptions.maxSendRecvDataCount = maxSendRecvDataCount; +} + +SerResult HcomServiceImp::GenerateUuid(const std::string &ipInfo, uint64_t channelId, std::string &uuid) +{ + uint32_t ip = 0; + SerResult ret = GetIpAddressByIpPort(ipInfo, ip); + if (NN_UNLIKELY(ret != SER_OK)) { + NN_LOG_ERROR("Failed to get ip address " << ipInfo << ", channel id " << channelId); + return SER_INVALID_PARAM; + } + + SerUuid tmpUuid(ip, channelId); + + if (NN_UNLIKELY(!tmpUuid.ToString(uuid))) { + NN_LOG_ERROR("Failed to generate uuid"); + return SER_ERROR; + } + NN_LOG_TRACE_INFO("###### uuid " << uuid << ", ip port " << ipInfo << ", ip " << ip << ", channel id " << + channelId); + return SER_OK; +} + +SerResult HcomServiceImp::GenerateUuid(uint32_t ip, uint64_t channelId, std::string &uuid) +{ + SerUuid tmpUuid(ip, channelId); + if (NN_UNLIKELY(!tmpUuid.ToString(uuid))) { + NN_LOG_ERROR("Failed to generate uuid"); + return SER_ERROR; + } + NN_LOG_TRACE_INFO("###### uuid " << uuid << ", ip " << ip << ", channel id " << channelId); + return SER_OK; +} + +SerResult HcomServiceImp::EmplaceNewEndpoint(const UBSHcomNetEndpointPtr &newEp, ConnectingEpInfoPtr &epInfo, + SerConnInfo &connInfo, std::string &uuid) +{ + std::lock_guard lockerEp(mNewEpMutex); + auto iter = mNewEpMap.find(uuid); + if (iter == mNewEpMap.end()) { + if (NN_UNLIKELY(!VALIDATE_PARAM(SerConnInfo, connInfo))) { + NN_LOG_ERROR("UBSHcomService Failed to verify connection info"); + return SER_INVALID_PARAM; + } + + epInfo = new (std::nothrow) HcomConnectingEpInfo(uuid, newEp, connInfo); + if (NN_UNLIKELY(epInfo == nullptr)) { + NN_LOG_ERROR("UBSHcomService Failed to new ep info"); + return SER_NEW_OBJECT_FAILED; + } + mNewEpMap.emplace(uuid, epInfo); + } else { + epInfo = iter->second; + if (NN_UNLIKELY(epInfo == nullptr)) { + NN_LOG_ERROR("NetService Failed as epInfo empty"); + return SER_INVALID_PARAM; + } + + if (NN_UNLIKELY(!epInfo->Compare(connInfo))) { + NN_LOG_ERROR("UBSHcomService Failed to validate connect info"); + return SER_INVALID_PARAM; + } + + if (NN_UNLIKELY(!epInfo->AddEp(newEp))) { + NN_LOG_ERROR("UBSHcomService Failed to add ep by broken"); + return SER_EP_BROKEN_DURING_CONNECTING; + } + } + return SER_OK; +} + +int32_t HcomServiceImp::ServiceHandleNewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEp, + const std::string &payload) +{ + if (NN_UNLIKELY(newEp == nullptr)) { + NN_LOG_ERROR("Invalid newEp, newEp is nullptr"); + return SER_INVALID_PARAM; + } + + SerConnInfo connInfo; + std::string userPayLoad; + if (NN_UNLIKELY(SerConnInfo::Deserialize(payload, connInfo, userPayLoad) != SER_OK)) { + NN_LOG_ERROR("Failed to call ServiceHandlerNewEndPoint as deserialize conn info failed"); + return SER_INVALID_PARAM; + } + + std::string uuid; + if (NN_UNLIKELY(GenerateUuid(ipPort, connInfo.channelId, uuid) != SER_OK)) { + NN_LOG_ERROR("Failed to generate uuid"); + return SER_INVALID_PARAM; + } + + ConnectingEpInfoPtr epInfo = nullptr; + if (NN_UNLIKELY(EmplaceNewEndpoint(newEp, epInfo, connInfo, uuid) != SER_OK)) { + NN_LOG_ERROR("Failed to emplace new ep"); + return SER_INVALID_PARAM; + } + + Ep2ChanUpCtx ctx(0, reinterpret_cast(epInfo.Get()), connInfo.index); + newEp->UpCtx(ctx.wholeUpCtx); + + if (epInfo->mEpVector.size() < connInfo.totalLinkCount) { + // not last one + return SER_OK; + } + // last one + if (NN_UNLIKELY(!epInfo->mConnState.CAS(ConnectingEpState::NEW_EP, ConnectingEpState::NEW_CHANNEL))) { + NN_LOG_ERROR("Failed to validate ep state, maybe some eps has broken"); + return SER_EP_BROKEN_DURING_CONNECTING; + } + auto result = ServiceNewChannel(ipPort, connInfo, userPayLoad, epInfo->mEpVector); + + std::lock_guard lockerEp(mNewEpMutex); + mNewEpMap.erase(uuid); + mSecInfoMap.erase(connInfo.channelId); + return result; +} + +int32_t HcomServiceImp::ServiceNewChannel(const std::string &ipPort, SerConnInfo &connInfo, + const std::string &userPayLoad, std::vector &ep) +{ + SerResult res = SER_OK; + UBSHcomChannelPtr channel = new (std::nothrow) + HcomChannelImp(connInfo.channelId, false, connInfo.options, Protocol(), mOptions.maxSendRecvDataSize); + if (NN_UNLIKELY(channel == nullptr)) { + NN_LOG_ERROR("Failed to new channel obj"); + return SER_NEW_OBJECT_FAILED; + } + channel->SetEnableMrCache(mEnableMrCache); + if (NN_UNLIKELY(channel->Initialize(ep, reinterpret_cast(mContextMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())))) { + NN_LOG_ERROR("Failed to initialize channel"); + return SER_NEW_OBJECT_FAILED; + } + + NN_LOG_INFO(channel->ToString()); + std::string uuid; + if (NN_UNLIKELY(GenerateUuid(ipPort, channel->GetId(), uuid) != SER_OK)) { + channel->UnInitialize(); + return SER_INVALID_PARAM; + } + + channel->SetUuid(uuid); + if (NN_UNLIKELY(EmplaceChannelUuid(channel) != SER_OK)) { + NN_LOG_ERROR("Failed to emplace uuid "); + channel->UnInitialize(); + return SER_CHANNEL_ID_DUP; + } + + channel->SetBrokenInfo(static_cast(connInfo.policy), mOptions.chBrokenHandler); + if (NN_UNLIKELY(mOptions.chNewHandler == nullptr)) { + NN_LOG_ERROR("Failed to invoke user cb as handler is nullptr"); + EraseChannel(reinterpret_cast(channel.Get())); + channel->UnInitialize(); + return SER_INVALID_PARAM; + } + + channel->SetPayload(userPayLoad); + res = mOptions.chNewHandler(ipPort, channel, userPayLoad); + if (NN_UNLIKELY(res != SER_OK)) { + NN_LOG_ERROR("Failed to invoke user cb " << res); + EraseChannel(reinterpret_cast(channel.Get())); + channel->UnInitialize(); + return res; + } + + return res; +} + +SerResult HcomServiceImp::DelayEraseChannel(UBSHcomChannelPtr &ch, uint16_t delayTime) +{ + auto chPtr = reinterpret_cast(ch.Get()); + Callback *newCallback = UBSHcomNewCallback(&HcomServiceImp::EraseChannel, this, chPtr); + if (newCallback == nullptr) { + NN_LOG_ERROR("Failed to new callback obj."); + return SER_NEW_OBJECT_FAILED; + } + HcomServiceCtxStore *ctxStore = ch->GetCtxStore(); + if (ctxStore == nullptr) { + NN_LOG_ERROR("Failed to get ctx store."); + delete newCallback; + return SER_NEW_OBJECT_FAILED; + } + + auto timerPtr = ctxStore->GetCtxObj(); + if (NN_UNLIKELY(timerPtr == nullptr)) { + NN_LOG_ERROR("Failed to get context object from memory pool."); + delete newCallback; + return SER_NEW_OBJECT_FAILED; + } + + auto timer = new (timerPtr)HcomServiceTimer(ch.Get(), ctxStore, delayTime, reinterpret_cast(newCallback), + HcomAsyncCBType::CBS_CHANNEL_BROKEN); + uint32_t seqNo = 0; + auto ret = ctxStore->PutAndGetSeqNo(timer, seqNo); + if (NN_UNLIKELY(ret != SER_OK)) { + NN_LOG_ERROR("Failed to generate seqNo by context store pool."); + ctxStore->Return(timerPtr); + delete newCallback; + return SER_NEW_OBJECT_FAILED; + } + + timer->IncreaseRef(); + timer->SeqNo(seqNo); + + ret = mPeriodicMgr->AddTimer(timer); + if (NN_UNLIKELY(ret != SER_OK)) { + NN_LOG_ERROR("Failed to add timer in for timeout control."); + timer->EraseSeqNo(); + ctxStore->Return(timerPtr); + delete newCallback; + return ret; + } + timer->IncreaseRef(); + return SER_OK; +} + +void HcomServiceImp::EraseChannel(uintptr_t chPtr) +{ + UBSHcomChannelPtr channel = reinterpret_cast(chPtr); + std::lock_guard lockerChannel(mChannelMutex); + mChannelMap.erase(channel->GetUuid()); +} + +void HcomServiceImp::ServiceEndPointBroken(const UBSHcomNetEndpointPtr &netEp) +{ + if (NN_UNLIKELY(netEp == nullptr)) { + NN_LOG_ERROR("Failed to call ServiceEndPointBroken as netEp is null"); + return; + } + + Ep2ChanUpCtx ctx(netEp->UpCtx()); + if (NN_UNLIKELY(ctx.wholeUpCtx == 0)) { + NN_LOG_ERROR("Up ctx is nullptr, maybe some errors occurs during connecting"); + return; + } + + if (ctx.connected == 0) { + ConnectingEpInfoPtr epInfo = reinterpret_cast(ctx.Ptr()); + if (NN_UNLIKELY(!epInfo->AllEPBroken(ctx.EpIdx()))) { + return; + } + + std::lock_guard lockerEp(mNewEpMutex); + mNewEpMap.erase(epInfo->mUuid); + mSecInfoMap.erase(epInfo->mConnInfo.channelId); + return; + } + + // channel already generate + UBSHcomChannelPtr channel = ctx.Channel(); + if (NN_UNLIKELY(channel == nullptr)) { + NN_LOG_ERROR("Up ctx channel is nullptr, maybe some errors occurs during connecting"); + return; + } + + channel->SetEpBroken(ctx.EpIdx()); + if (!channel->AllEpBroken()) { + NN_LOG_INFO("channel is not all broken"); + return; + } + + if (!channel->NeedProcessBroken()) { + return; + } + + channel->SetChannelState(UBSHcomChannelState::CH_CLOSE); + usleep(NN_NO100); + channel->ProcessIoInBroken(); + channel->InvokeChannelBrokenCb(channel); + + uint16_t delayEraseTime = channel->GetDelayEraseTime(); + // default: try delay erase channel + if (NN_UNLIKELY(DelayEraseChannel(channel, delayEraseTime) == SER_OK)) { + return; + } else { + NN_LOG_WARN("Failed to delay erase channel, now direct erase channel id " << channel->GetId()); + EraseChannel(reinterpret_cast(channel.Get())); + } +} + +int32_t HcomServiceImp::ServiceRequestReceived(const UBSHcomRequestContext &ctx) +{ + Ep2ChanUpCtx epCtx(ctx.EndPoint()->UpCtx()); + auto ch = epCtx.Channel(); + if (NN_UNLIKELY(ch == nullptr)) { + NN_LOG_ERROR("UBSHcomService Up context invalid, maybe broken then handle, ep Id " << ctx.EndPoint()->Id()); + return SER_ERROR; + } + HcomSeqNo netSeqNo(ctx.Header().seqNo); + bool isResp = netSeqNo.IsResp(); + UBSHcomServiceContext context(ctx, ch); + + // 如果服务层消息存在头部信息... + std::string msg; + if (ctx.extHeaderType == UBSHcomExtHeaderType::RAW) { + // 无服务层扩展头 + } else if (ctx.extHeaderType == UBSHcomExtHeaderType::FRAGMENT) { + int error = 0; + SpliceMessageResultType result = SpliceMessageResultType::INDETERMINATE; + + std::tie(result, error, msg) = ch->SpliceMessage(ctx, isResp); + switch (result) { + case SpliceMessageResultType::OK: + context.mData = &msg[0]; + context.mDataLen = msg.size(); + break; + + case SpliceMessageResultType::INDETERMINATE: + case SpliceMessageResultType::ERROR: + return error; + } + } + + if (!isResp) { + if (context.OpCode() == EXCHANGE_TIMESTAMP_OP) { + return ServiceExchangeTimeStampHandle(context); + } + if (context.OpCode() == RNDV_CALL_OP_V2) { + ServicePrivateOpHandle(context); + } + if (context.OpCode() < MAX_USER_OPCODE || ctx.OpType() == UBSHcomRequestContext::NN_RECEIVED_RAW) { + auto &userHandler = mOptions.recvHandler; + int ret = SER_OK; + NetTrace::TraceBegin(SERVICE_CB_REQUEST_RECEIVED); + ret = userHandler(context); + NetTrace::TraceEnd(SERVICE_CB_REQUEST_RECEIVED, ret); + return ret; + } else { + NN_LOG_ERROR("UBSHcomService Invalid op code " << context.OpCode() << ", ignore message"); + return SER_ERROR; + } + } else { + uintptr_t *tmp = nullptr; + auto ctxStorePtr = ch->GetCtxStore(); + if (NN_UNLIKELY(ctxStorePtr->GetSeqNoAndRemove(ctx.Header().seqNo, tmp) != SER_OK)) { + HcomSeqNo dumpSeq(ctx.Header().seqNo); + NN_LOG_ERROR("UBSHcomService Channel " << ch->GetId() << " fetch " << dumpSeq.ToString() << " context failed"); + return SER_ERROR; + } + + auto timer = reinterpret_cast(tmp); + timer->RunCallBack(context); + timer->MarkFinished(); + timer->DecreaseRef(); + return SER_OK; + } +} + +int32_t HcomServiceImp::ServicePrivateOpHandle(UBSHcomServiceContext &ctx) +{ + // 将context opType设置成rndv context, 回调中用户根据opType判断是否是rndv消息 + ctx.mOpType = UBSHcomRequestContext::NN_OpType::NN_RNDV; + HcomServiceRndvMessage *rndvMessage = static_cast(ctx.mData); + if (rndvMessage == nullptr) { + NN_LOG_ERROR("Failed to get data in service privateOpHandle "); + return SER_ERROR; + } + // opCode 设置成发送端的请求配置 + ctx.mOpCode = rndvMessage->request.opcode; + return SER_OK; +} + +bool HcomServiceImp::RunRequestCallback(UBSHcomChannel *channel, const UBSHcomRequestContext &ctx, + UBSHcomServiceContext &context) +{ + char *upCtx = nullptr; + if (ctx.OpType() == UBSHcomRequestContext::NN_SENT || ctx.OpType() == UBSHcomRequestContext::NN_SENT_RAW || + ctx.OpType() == UBSHcomRequestContext::NN_READ|| ctx.OpType() == UBSHcomRequestContext::NN_WRITTEN) { + upCtx = const_cast(ctx.OriginalRequest().upCtxData); + } else if (ctx.OpType() == UBSHcomRequestContext::NN_SENT_RAW_SGL || + ctx.OpType() == UBSHcomRequestContext::NN_SGL_WRITTEN || ctx.OpType() == UBSHcomRequestContext::NN_SGL_READ) { + upCtx = const_cast(ctx.OriginalSgeRequest().upCtxData); + } else { + NN_LOG_ERROR("Invalid op type " << ctx.OpType() << " for request posted"); + return false; + } + + /* try to get callback from ctx, usually is response message type */ + Callback *done = GetServiceTransCb(upCtx); + if (done != nullptr) { + done->Run(context); + return true; + } + + uint32_t seqNo = GetServiceTransSeqNo(upCtx); + uintptr_t *tmp = nullptr; + auto ctxStorePtr = channel->GetCtxStore(); + if (NN_UNLIKELY(ctxStorePtr->GetSeqNoAndRemove(seqNo, tmp) != SER_OK)) { + HcomSeqNo dumpSeq(seqNo); + NN_LOG_ERROR("Channel " << channel->GetId() << " fetch " << dumpSeq.ToString() << " context failed"); + return false; + } + + auto timer = reinterpret_cast(tmp); + timer->RunCallBack(context); + timer->MarkFinished(); + timer->DecreaseRef(); + return true; +} + +int32_t HcomServiceImp::ServiceRequestPosted(const UBSHcomRequestContext &ctx) +{ + Ep2ChanUpCtx epCtx(ctx.EndPoint()->UpCtx()); + auto ch = epCtx.Channel(); + if (NN_UNLIKELY(ch == nullptr)) { + NN_LOG_ERROR("Up context invalid, maybe broken then handle, ep Id " << ctx.EndPoint()->Id() << " result " + << ctx.Result()); + return SER_ERROR; + } + UBSHcomServiceContext context(ctx, ch); + + if (ch->GetCallBackType() == UBSHcomChannelCallBackType::CHANNEL_FUNC_CB) { + if (!IsNeedInvokeCallback(ctx)) { + return SER_OK; + } + + NetTrace::TraceBegin(SERVICE_CB_REQUEST_POSTED); + if (NN_UNLIKELY(!RunRequestCallback(ch, ctx, context))) { + NN_LOG_ERROR("Failed to get user callback for call request posted cb"); + NetTrace::TraceEnd(SERVICE_CB_REQUEST_POSTED, SER_ERROR); + return SER_ERROR; + } + + NetTrace::TraceEnd(SERVICE_CB_REQUEST_POSTED, SER_OK); + return SER_OK; + } else if (ch->GetCallBackType() == UBSHcomChannelCallBackType::CHANNEL_GLOBAL_CB) { + if (mOptions.sendHandler == nullptr) { + NN_LOG_ERROR("global callback channel is nullptr"); + return SER_ERROR; + } + return mOptions.sendHandler(context); + } else { + NN_LOG_ERROR("Invalid callback type " << static_cast(ch->GetCallBackType()) << + " for call request posted cb"); + return SER_ERROR; + } +} + +int32_t HcomServiceImp::ServiceOneSideDone(const UBSHcomRequestContext &ctx) +{ + Ep2ChanUpCtx epCtx(ctx.EndPoint()->UpCtx()); + auto ch = epCtx.Channel(); + if (NN_UNLIKELY(ch == nullptr)) { + NN_LOG_ERROR("Default imp up context invalid, maybe broken then handle, ep Id " << ctx.EndPoint()->Id() + << " result " << ctx.Result()); + return SER_ERROR; + } + + UBSHcomServiceContext context(ctx, ch); + + if (ch->GetCallBackType() == UBSHcomChannelCallBackType::CHANNEL_FUNC_CB) { + NetTrace::TraceBegin(SERVICE_CB_ONESIDE_DONE); + if (NN_UNLIKELY(!RunRequestCallback(ch, ctx, context))) { + NN_LOG_ERROR("Default imp failed to get user callback for call one side done cb"); + NetTrace::TraceEnd(SERVICE_CB_ONESIDE_DONE, SER_ERROR); + return SER_ERROR; + } + + NetTrace::TraceEnd(SERVICE_CB_ONESIDE_DONE, SER_OK); + return SER_OK; + } else if (ch->GetCallBackType() == UBSHcomChannelCallBackType::CHANNEL_GLOBAL_CB) { + UBSHcomServiceOneSideDoneHandler &handler = mOptions.oneSideDoneHandler; + if (handler == nullptr) { + NN_LOG_ERROR("handle is null"); + return SER_ERROR; + } + return handler(context); + } else { + NN_LOG_ERROR("Default imp invalid callback type " << static_cast(ch->GetCallBackType()) << + " for call one side done cb"); + return SER_ERROR; + } +} + +int32_t HcomServiceImp::ServiceSecInfoProvider(uint64_t chId, int64_t &flag, UBSHcomNetDriverSecType &type, + char *&output, uint32_t &outLen, bool &needAutoFree) +{ + bool infoExist = false; + ConnectingSecInfo info {}; + { + std::lock_guard lockerEp(mNewEpMutex); + auto iter = mSecInfoMap.find(chId); + if ((infoExist = (iter != mSecInfoMap.end()))) { + info = iter->second; + } + } + // not first call provider + if (!info.firstCallProvider) { + flag = info.flag; + type = info.type; + output = info.secContent; + outLen = info.secContentLen; + needAutoFree = info.needAutoFree; + return 0; + } + if (NN_UNLIKELY(mOptions.connSecOption.provider == nullptr)) { + NN_LOG_ERROR("Failed to provide secInfo as handler is nullptr"); + return SER_ERROR; + } + // first call provider + auto result = mOptions.connSecOption.provider(chId, flag, type, output, outLen, needAutoFree); + info.Initialize(flag, type, output, outLen, needAutoFree); + + std::lock_guard lockerEp(mNewEpMutex); + if (!infoExist) { + // case1: one-way or two-way case client first call provider, secInfo is not in map + mSecInfoMap.emplace(chId, info); + return result; + } + // case2: two-way case server first call provider, secInfo has already added to map when first call validator + mSecInfoMap[chId] = info; + return result; +} + +int32_t HcomServiceImp::ServiceSecInfoValidator(uint64_t ctx, int64_t flag, const char *input, uint32_t inputLen) +{ + ConnectingSecInfo info {}; + bool infoExist = false; + { + std::lock_guard lockerEp(mNewEpMutex); + auto iter = mSecInfoMap.find(ctx); + if ((infoExist = (iter != mSecInfoMap.end()))) { + info = iter->second; + } + } + + if (!info.firstCallValidator) { + return 0; + } + if (NN_UNLIKELY(mOptions.connSecOption.validator == nullptr)) { + NN_LOG_ERROR("Failed to validate secInfo as handler is nullptr"); + return SER_ERROR; + } + // first call validator + auto result = mOptions.connSecOption.validator(ctx, flag, input, inputLen); + info.firstCallValidator = false; + + std::lock_guard lockerEp(mNewEpMutex); + if (!infoExist) { + // case1: one-way two-way case server first call validator, and add secInfo to map + mSecInfoMap.emplace(ctx, info); + return result; + } + // case2: two-way case client first call validator, secInfo has already added to map when first call provider + mSecInfoMap[ctx] = info; + return result; +} + +std::string HcomServiceImp::GetFilteredDeviceIP(const std::string& ipMask) +{ + std::string res; + std::vector filterVec; + NetFunc::NN_SplitStr(ipMask, ",", filterVec); + if (filterVec.empty()) { + NN_LOG_WARN("Invalid ip mask " << ipMask); + return res; + } + + std::vector filteredIp; + for (auto &mask : filterVec) { + FilterIp(mask, filteredIp); + } + + if (filteredIp.empty()) { + NN_LOG_WARN("No matched ip found with " << ipMask); + return res; + } + + res = filteredIp[0]; + return res; +} + +void HcomServiceImp::ConvertHcomSerImpOptsToHcomDriOpts(const HcomServiceImpOptions &serviceOpt, + ock::hcom::UBSHcomNetDriverOptions &driverOpt) +{ + driverOpt.SetNetDeviceIpMask(serviceOpt.ipMasks); + driverOpt.SetNetDeviceIpGroup(serviceOpt.ipGroups); + driverOpt.enableTls = serviceOpt.tlsOption.enableTls; + driverOpt.secType = serviceOpt.connSecOption.secType; + driverOpt.cipherSuite = serviceOpt.tlsOption.netCipherSuite; + driverOpt.tlsVersion = serviceOpt.tlsOption.tlsVersion; + driverOpt.dontStartWorkers = serviceOpt.workerGroupInfos[0].empty(); + driverOpt.mode = serviceOpt.workerGroupMode; + driverOpt.oobType = serviceOpt.oobType; + driverOpt.lbPolicy = serviceOpt.lbPolicy; + driverOpt.magic = serviceOpt.connSecOption.magic; + driverOpt.version = serviceOpt.connSecOption.version; + driverOpt.heartBeatIdleTime = serviceOpt.heartBeatOption.heartBeatIdleSec; + driverOpt.heartBeatProbeTimes = serviceOpt.heartBeatOption.heartBeatProbeTimes; + driverOpt.heartBeatProbeInterval = serviceOpt.heartBeatOption.heartBeatProbeIntervalSec; + driverOpt.tcpUserTimeout = serviceOpt.tcpTimeOutSec; + driverOpt.tcpSendZCopy = serviceOpt.tcpSendZCopy; + + driverOpt.mrSendReceiveSegSize = serviceOpt.maxSendRecvDataSize; + driverOpt.completionQueueDepth = serviceOpt.completionQueueDepth; + driverOpt.pollingBatchSize = serviceOpt.pollingBatchSize; + driverOpt.eventPollingTimeout = serviceOpt.eventPollingTimeOutUs; + driverOpt.qpSendQueueSize = serviceOpt.qpSendQueueSize; + driverOpt.qpReceiveQueueSize = serviceOpt.qpRecvQueueSize; + driverOpt.prePostReceiveSizePerQP = serviceOpt.qpPrePostSize; + driverOpt.maxConnectionNum = serviceOpt.maxConnCount; + driverOpt.enableMultiRail = serviceOpt.enableMultiRail; + driverOpt.mrSendReceiveSegCount = serviceOpt.maxSendRecvDataCount; + driverOpt.ubcMode = serviceOpt.ubcMode; +} + +SerResult HcomServiceImp::ExchangeTimestamp(UBSHcomChannel *channel) +{ + HcomChannelImp *ch = dynamic_cast(channel); + + if (ch == nullptr) { + NN_LOG_ERROR("Failed to exchange timestamp, ch is null "); + return SER_ERROR; + } + + HcomExchangeTimestamp reqTimestamp {}; + UBSHcomRequest req(&reqTimestamp, sizeof(reqTimestamp), EXCHANGE_TIMESTAMP_OP); + HcomExchangeTimestamp rspTimestamp {}; + UBSHcomResponse rsp(&rspTimestamp, sizeof(rspTimestamp)); + + reqTimestamp.deltaTimeStamp = NN_NO100; + // deltaTimeStamp:预估的网络RTT,初始为100us,测算方式如下: + // 首次不符合预期快速更新:delta time = call时间 * 1.2 + // 后续不符合预期指数退避:delta time = delta time * 2 + uint32_t i = 0; + for (; i <= NN_NO16; i++) { + reqTimestamp.timestamp = NetMonotonic::TimeUs(); + auto result = ch->SyncCallInner(req, rsp, NN_NO64); + if (result == SER_OK) { + uint64_t coastTime = NetMonotonic::TimeUs() - reqTimestamp.timestamp; + if (reqTimestamp.deltaTimeStamp > coastTime) { + break; + } + if (i == 0) { + // 首次测算RTT失败,快速更新delta time + reqTimestamp.deltaTimeStamp = coastTime + coastTime / NN_NO5; + } else { + reqTimestamp.deltaTimeStamp *= NN_NO2; + } + // if sync call operation which spend time more than delta time, try next delta time + NN_LOG_TRACE_INFO("Delta time " << reqTimestamp.deltaTimeStamp << ", coast time " << coastTime); + continue; + } else { + NN_LOG_ERROR("Failed to exchange timestamp " << result); + return result; + } + } + + if (NN_UNLIKELY(i > NN_NO16)) { + NN_LOG_ERROR("Failed to exchange timestamp"); + return SER_TIMEOUT; + } + + ch->mConnectTimestamp.localTimeUs = reqTimestamp.timestamp; + ch->mConnectTimestamp.remoteTimeUs = rspTimestamp.timestamp; + ch->mConnectTimestamp.deltaTimeUs = reqTimestamp.deltaTimeStamp; + NN_LOG_INFO("Exchange timestamp success, ch id " << ch->GetId() << ", local " << reqTimestamp.timestamp << + "us, remote " << rspTimestamp.timestamp << "us, delta " << reqTimestamp.deltaTimeStamp << "us"); + return SER_OK; +} + +int HcomServiceImp::ServiceExchangeTimeStampHandle(UBSHcomServiceContext &ctx) +{ + if (NN_UNLIKELY(ctx.Result() != SER_OK)) { + NN_LOG_ERROR("Exchange timestamp failed " << ctx.Result()); + return ctx.Result(); + } + + if (NN_UNLIKELY(ctx.MessageDataLen() != sizeof(HcomExchangeTimestamp))) { + NN_LOG_ERROR("Exchange timestamp receive invalid message "); + return SER_INVALID_PARAM; + } + + auto timestamp = reinterpret_cast(ctx.MessageData()); + if (NN_UNLIKELY(timestamp->deltaTimeStamp == NN_NO0)) { + NN_LOG_ERROR("Exchange timestamp receive invalid delta " << timestamp->deltaTimeStamp); + return SER_INVALID_PARAM; + } + + if (ctx.Channel().Get() == nullptr) { + NN_LOG_ERROR("Exchange timestamp receive invalid channel "); + return SER_INVALID_PARAM; + } + + HcomChannelImp *ch = dynamic_cast(ctx.Channel().Get()); + if (ch == nullptr) { + NN_LOG_ERROR("Fail to dynamic_cast channel "); + return SER_ERROR; + } + + ch->mConnectTimestamp.localTimeUs = NetMonotonic::TimeUs(); + ch->mConnectTimestamp.remoteTimeUs = timestamp->timestamp; + ch->mConnectTimestamp.deltaTimeUs = timestamp->deltaTimeStamp; + + NN_LOG_INFO("Exchange timestamp success, ch id " << ch->GetId() << ", local " << + ch->mConnectTimestamp.localTimeUs << "us, remote " << ch->mConnectTimestamp.remoteTimeUs << "us, delta " << + ch->mConnectTimestamp.deltaTimeUs << "us"); + + timestamp->timestamp = ch->mConnectTimestamp.localTimeUs; + UBSHcomRequest req(timestamp, sizeof(HcomExchangeTimestamp), EXCHANGE_TIMESTAMP_OP); + UBSHcomReplyContext replyCtx(ctx.RspCtx(), NN_NO0); + return ctx.Channel()->Reply(replyCtx, req, HcomServiceGlobalObject::gEmptyCallback); +} +} +} + diff --git a/src/service_v2/service_imp.h b/src/service_v2/service_imp.h new file mode 100644 index 0000000000000000000000000000000000000000..46a4283fe028f6148f1a1253bff0aa72b0d1fe08 --- /dev/null +++ b/src/service_v2/service_imp.h @@ -0,0 +1,501 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_SERVICE_V2_HCOM_SERVICE_IMP_H_ +#define HCOM_SERVICE_V2_HCOM_SERVICE_IMP_H_ + +#include +#include +#include + +#include "hcom_def.h" +#include "hcom_log.h" +#include "api/hcom_service_def.h" +#include "api/hcom_service.h" +#include "api/hcom_service_channel.h" +#include "hcom_obj_statistics.h" +#include "service_periodic_manager.h" +#include "net_common.h" +#include "net_load_balance.h" +#include "net_oob.h" +#include "net_oob_ssl.h" +#include "net_pgtable.h" + +namespace ock { +namespace hcom { + +using NetDriverPtr = NetRef; +using HcomPeriodicManagerPtr = NetRef; +using HcomServiceCtxStorePtr = NetRef; +using ConnectingEpInfoPtr = NetRef; +using NetPgTablePtr = NetRef; + +struct HcomServiceImpOptions { + UBSHcomTlsOptions tlsOption; + UBSHcomConnSecureOptions connSecOption; + UBSHcomHeartBeatOptions heartBeatOption; + UBSHcomServiceIdleHandler idleHandler = nullptr; + UBSHcomServiceRecvHandler recvHandler = nullptr; + UBSHcomServiceSendHandler sendHandler = nullptr; + UBSHcomServiceOneSideDoneHandler oneSideDoneHandler = nullptr; + UBSHcomServiceNewChannelHandler chNewHandler = nullptr; + UBSHcomServiceChannelBrokenHandler chBrokenHandler = nullptr; + uint32_t qpSendQueueSize = 256; + uint32_t qpRecvQueueSize = 256; + uint32_t qpPrePostSize = 64; + uint32_t maxSendRecvDataSize = 1024; + uint32_t maxSendRecvDataCount = 8192; + uint32_t maxConnCount = 250; + uint32_t multiRailThresh = 8192; + uint32_t timeOutDetectThreadNum = 1; + uint16_t pollingBatchSize = 4; + uint16_t eventPollingTimeOutUs = 500; + uint16_t completionQueueDepth = 2048; + uint16_t tcpTimeOutSec = -1; + uint16_t jettyId = 0; + UBSHcomServiceProtocol protocol; + std::string name; + std::string eid; + bool enableRndv = false; + bool tcpSendZCopy = false; + bool startOobSvr = false; + bool enableMultiRail = false; + NetDriverOobType oobType = NET_OOB_TCP; + UBSHcomServiceLBPolicy lbPolicy = NET_ROUND_ROBIN; + UBSHcomWorkerMode workerGroupMode = NET_BUSY_POLLING; + UBSHcomChannelBrokenPolicy chBrokenPolicy = UBSHcomChannelBrokenPolicy::BROKEN_ALL; + UBSHcomUbcMode ubcMode = UBSHcomUbcMode::LowLatency; + std::vector> workerGroupInfos; + std::vector ipMasks; + std::vector ipGroups; + std::unordered_map oobOption; + std::unordered_map udsOobOption; +}; + +class HcomServiceImp : public UBSHcomService { +public: + HcomServiceImp(UBSHcomServiceProtocol t, const std::string &name, const UBSHcomServiceOptions &opt) + { + mOptions.protocol = t; + mOptions.name = name; + mOptions.maxSendRecvDataSize = opt.maxSendRecvDataSize; + mOptions.workerGroupMode = opt.workerGroupMode; + if (NN_LIKELY(opt.workerGroupThreadCount != 0)) { + UBSHcomWorkerGroupInfo groupInfo; + groupInfo.threadPriority = opt.workerThreadPriority; + groupInfo.threadCount = opt.workerGroupThreadCount; + groupInfo.groupId = opt.workerGroupId; + groupInfo.cpuIdsRange = opt.workerGroupCpuIdsRange; + std::vector workerInfoVec {}; + workerInfoVec.emplace_back(groupInfo); + mOptions.workerGroupInfos.emplace_back(workerInfoVec); // UBSHcomService::Instance lock + } + OBJ_GC_INCREASE(HcomServiceImp); + } + + ~HcomServiceImp() override + { + OBJ_GC_DECREASE(HcomServiceImp); + } + + /** + * @brief 绑定监听url,指定监听的类型及url,客户端可以不调用Bind。 + * + * @param listenerUrl 监听url,对于tcp来说:tcp://127.0.0.1:9981 + * 对于uds来说:uds://file:perm(如果有:perm则使用真实文件,perm格式如:0600,没有则使用抽象文件) + * 对于ubc来说:ubc://eid:jettyId + * @param handler 收到建链请求后的回调函数 + * @return int32_t 成功:0;失败:错误码 + */ + int32_t Bind(const std::string &listenerUrl, const UBSHcomServiceNewChannelHandler &handler) override; + + /** + * @brief 开启服务,如果调用过Bind,则同时开启监听,否则不进行监听 + * + * @return int32_t 成功:0;失败:错误码 + */ + int32_t Start() override; + + /** + * @brief 建立链接 + * + * @param serverUrl 建连服务端url,对于tcp来说:tcp://127.0.0.1:9981 + * 对于uds来说:uds://file(文件名/抽象命名空间) + * 对于ubc来说:ubc://eid:jettyId + * @param ch 出参,建链成功返回的channel + * @return int32_t 成功:0;失败:错误码 + */ + int32_t Connect(const std::string &serverUrl, UBSHcomChannelPtr &ch, const UBSHcomConnectOptions &opt) override; + + /** + * @brief 断开链接 + * + * @param ch 要断开的channel + */ + void Disconnect(const UBSHcomChannelPtr &ch) override; + + /** + * @brief 注册memory region,内存会在内部进行分配 + * + * @param size memory region的大小 + * @param mr 注册好的memoryRegion + * @return int32_t 成功:0;失败:错误码 + */ + int32_t RegisterMemoryRegion(uint64_t size, UBSHcomRegMemoryRegion &mr) override; + + /** + * @brief 注册memory region,分配的内存需要传入进来 + * + * @param address 需要被注册为MR的内存起始地址 + * @param size memory region的大小 + * @param mr 注册好的memoryRegion + * @return int32_t 成功:0;失败:错误码 + */ + int32_t RegisterMemoryRegion(uintptr_t address, uint64_t size, UBSHcomRegMemoryRegion &mr) override; + + /** + * @brief memory region取消注册 + * + * @param mr 取消的mr + */ + void DestroyMemoryRegion(UBSHcomRegMemoryRegion &mr) override; + + /** + * @brief 设置RegisterMemoryRegion是否将mr放入pgTable管理 + * 若用户需要使用RNDV,则需要设置为true + * + * @param enableMrCache true表示放入pgTable,false表示不放入;默认是false。 + */ + void SetEnableMrCache(bool enableMrCache) override; + + /** + * @brief 注册断链回调 + * + * @param handler 断链回调函数 + * @param policy 断链回调策略 + */ + void RegisterChannelBrokenHandler(const UBSHcomServiceChannelBrokenHandler &handler, + const UBSHcomChannelBrokenPolicy policy) override; + + /** + * @brief 注册pollCq、epoll_wait超时等回调 + * + * @param handler 回调函数 + */ + void RegisterIdleHandler(const UBSHcomServiceIdleHandler &handler) override; + + /** + * @brief 注册接收receive操作回调 + * + * @param rcvHandler 回调函数 + */ + void RegisterRecvHandler(const UBSHcomServiceRecvHandler &recvHandler) override; + + /** + * @brief 注册发送send操作回调 + * + * @param sentHandler 回调函数 + */ + void RegisterSendHandler(const UBSHcomServiceSendHandler &sendHandler) override; + + /** + * @brief 注册单边操作回调 + * + * @param channelTypeIdx 允许为不同channel设置不同回调,channelTypeIdx对应channel类型的下标 + * @param oneSideDoneHandler 回调函数 + */ + void RegisterOneSideHandler(const UBSHcomServiceOneSideDoneHandler &oneSideDoneHandler) override; + + // 高级配置选项及特性配置选项 + + /** + * @brief 增加workerGroup + * + * @param workerGroupId workerGroup的id + * @param threadCount 该workerGroup的线程数 + * @param cpuIdsRange 该workerGroup绑定的cpuId范围 + * @param priority 同线程nice值,范围[-20,19],-20优先级最高,19优先级最低 + * @param multirailIdx 该workerGroup绑定的rail + */ + void AddWorkerGroup(uint16_t workerGroupId, uint32_t threadCount, + const std::pair &cpuIdsRange, int8_t priority = 0, uint16_t multirailIdx = 0) override; + + /** + * @brief 增加监听器,支持监听多个url + * + * @param url 监听url,tcp协议:tcp://127.0.0.1:9981;uds协议:uds://file(文件名/抽象命名空间) + * @param workerCount 监听到链接请求后,会从对应的workerGroup中选择workerCount个线程按照lbPolicy的策略去选择线程绑定到ep + * @return int32_t + */ + void AddListener(const std::string &url, uint16_t workerCount = UINT16_MAX) override; + + /** + * @brief 设置建链负载均衡策略,主动/被动建链时需要选择一个worker线程去完成,lbPolicy则代表选择worker线程的策略 + * + * @param lbPolicy NET_ROUND_ROBIN:轮询,NET_HASH_IP_PORT:根据ip和port做hash + */ + void SetConnectLBPolicy(UBSHcomServiceLBPolicy lbPolicy) override; + + /** + * @brief TLS相关配置项,如果不配置的话默认不开启 + * + * @param opt + */ + void SetTlsOptions(const UBSHcomTlsOptions &opt) override; + + void SetConnSecureOpt(const UBSHcomConnSecureOptions &opt) override; + + /** + * @brief 设置TCP_USER_TIMEOUT套接字选项,tcp超时时间,[0, 1024],0表示永不超时 + * + * @param timeOutSec + */ + void SetTcpUserTimeOutSec(uint16_t timeOutSec) override; + + /** + * @brief 设置TCP发送是否要做内存拷贝(hcom内部内存) + * + * @param tcpSendZCopy 是否要做数据拷贝 + */ + void SetTcpSendZCopy(bool tcpSendZCopy) override; + + /** + * @brief 设置设备ipMask,用于rdma/ub,根据ipMask获取该网段的GID和UBEId + * + * @param ipMasks 用于过滤的ipMask集合 + */ + void SetDeviceIpMask(const std::vector &ipMasks) override; + + /** + * @brief 设置设备的ipGroup,如果明确制定了ipGroup,则直接使用对应的设备 + * + * @param ipGroups ipGroups集合 + */ + void SetDeviceIpGroups(const std::vector &ipGroups) override; + + /** + * @brief 设置cq队列的深度 + * + * @param depth cq队列深度 + */ + void SetCompletionQueueDepth(uint16_t depth) override; + + /** + * @brief 设置SQ队列的大小,默认256 + * + * @param sqSize 队列大小 + */ + void SetSendQueueSize(uint32_t sqSize) override; + + /** + * @brief 设置RQ队列的大小,默认256 + * + * @param rqSize 队列大小 + */ + void SetRecvQueueSize(uint32_t rqSize) override; + + /** + * @brief 设置提前下发wr的数量,不设置的话默认64 + * + * @param prePostSize 预先下发的wr数量 + */ + void SetQueuePrePostSize(uint32_t prePostSize) override; + + /** + * @brief 设置批量polling的大小,默认是4 + * + * @param pollSize 每批大小 + */ + void SetPollingBatchSize(uint16_t pollSize) override; + + /** + * @brief 设置polling的超时时间,单位us,默认500 + * + * @param pollTimeout 超时时间 + */ + void SetEventPollingTimeOutUs(uint16_t pollTimeout) override; + + /** + * @brief 设置周期任务处理线程数,主要用在内部异步检查超时等场景,不设置的话默认1个线程 + * + * @param threadNum 线程数 + */ + void SetTimeOutDetectionThreadNum(uint32_t threadNum) override; + + /** + * @brief 设置最大连接数,不设置的话默认250 + * + * @param maxConnCount 最大连接数 + */ + void SetMaxConnectionCount(uint32_t maxConnCount) override; + + /** + * @brief 设置心跳选项 + * + * @param opt 心跳设置选项 + * @return int32_t + */ + void SetHeartBeatOptions(const UBSHcomHeartBeatOptions &opt) override; + + /** + * @brief Set the Multi Rail Options object + * + * @param opt multi rail option + */ + void SetMultiRailOptions(const UBSHcomMultiRailOptions &opt) override; + + /** + * @brief 设置 UB-C 多路径模式 + * + * @param ubcMode UB-C 多路径模式 + */ + void SetUbcMode(UBSHcomUbcMode ubcMode) override; + + /** + * @brief 设置发送数据块最大数量 + * + * @param maxSendRecvDataCount 发送数据块最大数量 + */ + void SetMaxSendRecvDataCount(uint32_t maxSendRecvDataCount) override; + +private: + SerResult ValidateServiceOption(); + SerResult CreateResource(); + SerResult InitDriver(); + SerResult DoInitDriver(); + SerResult CreateMultiRailDriver(); + SerResult CreateOobUdsListeners(const UBSHcomNetDriverOptions &driverOpt); + SerResult CreateOobListeners(const UBSHcomNetDriverOptions &driverOpt); + + SerResult StartDriver(); + SerResult CreatePeriodicMgr(); + SerResult CreateCtxMemPool(); + SerResult DoConnect(const std::string &serverUrl, SerConnInfo &opt, const std::string &payLoad, + UBSHcomChannelPtr &tmpChannel); + SerResult DoConnectInner(const std::string &serverUrl, SerConnInfo &opt, const std::string &payLoad, + std::vector &epVector, uint32_t &totalBandWidth); + SerResult ChooseDriver(OOBTCPConnection &conn, UBSHcomNetDriver *&driver); + void DoChooseDriver(uint8_t devInex, uint8_t bandWidth, + int8_t &selectDevIndex, uint8_t &selectBandWidth, UBSHcomNetDriver *&driver); + + void ConvertHcomSerImpOptsToHcomDriOpts(const HcomServiceImpOptions &serviceOpt, + UBSHcomNetDriverOptions &driverOpt); + void RegisterDriverCb(); + bool RunRequestCallback(UBSHcomChannel *channel, const UBSHcomRequestContext &ctx, UBSHcomServiceContext &context); + + SerResult DoDestroy(const std::string &name) override; + void ForceStop(); + SerResult AddTcpOobListener(const std::string &url, uint16_t workerCount = UINT16_MAX); + SerResult AddUdsOobListener(const std::string &url, uint16_t workerCount = UINT16_MAX); + + SerResult DelayEraseChannel(UBSHcomChannelPtr &ch, uint16_t delayTime); + void EraseChannel(uintptr_t chPtr); + + SerResult GenerateUuid(const std::string &ipInfo, uint64_t channelId, std::string &uuid); + SerResult GenerateUuid(uint32_t ip, uint64_t channelId, std::string &uuid); + int32_t ServiceNewChannel(const std::string &ipPort, SerConnInfo &connInfo, const std::string &userPayLoad, + std::vector &ep); + int32_t ServiceHandleNewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEp, + const std::string &payload); + SerResult EmplaceNewEndpoint(const UBSHcomNetEndpointPtr &newEp, ConnectingEpInfoPtr &epInfo, + SerConnInfo &connInfo, std::string &uuid); + void ServiceEndPointBroken(const UBSHcomNetEndpointPtr &netEp); + int32_t ServiceRequestReceived(const UBSHcomRequestContext &ctx); + int32_t ServiceRequestPosted(const UBSHcomRequestContext &ctx); + int32_t ServiceOneSideDone(const UBSHcomRequestContext &ctx); + int32_t ServiceSecInfoProvider(uint64_t ctx, int64_t &flag, UBSHcomNetDriverSecType &type, char *&output, + uint32_t &outLen, bool &needAutoFree); + int32_t ServiceSecInfoValidator(uint64_t ctx, int64_t flag, const char *input, uint32_t inputLen); + SerResult ExchangeTimestamp(UBSHcomChannel *channel); + int ServiceExchangeTimeStampHandle(UBSHcomServiceContext &ctx); + std::string GetFilteredDeviceIP(const std::string& ipMask); + /** + * @brief MultiRail模式下注册的Connection事件处理函数 + * + * @param conn + * @return int32_t + */ + SerResult NewConnectionCB(OOBTCPConnection &conn); + + inline UBSHcomServiceProtocol Protocol() const + { + return mOptions.protocol; + } + + inline SerResult GetIpAddressByIpPort(const std::string &oobIpPort, uint32_t &ipAddress) const + { + if (Protocol() == SHM) { + ipAddress = 0xffffffff; + } else { + if (NN_UNLIKELY(!NetFunc::NN_CovertIpWithoutPort(oobIpPort, ipAddress))) { + NN_LOG_ERROR("Default imp Failed to covert ip by " << oobIpPort); + return SER_INVALID_PARAM; + } + } + return SER_OK; + } + + inline SerResult EmplaceChannelUuid(UBSHcomChannelPtr &channel) + { + std::lock_guard lockerChannel(mChannelMutex); + auto ret = mChannelMap.emplace(channel.Get()->GetUuid(), channel); + if (NN_UNLIKELY(!ret.second)) { + NN_LOG_ERROR("Failed to emplace channel " << channel.Get()->GetId() << ", already exist"); + return SER_ERROR; + } + return SER_OK; + } + + int32_t ServicePrivateOpHandle(UBSHcomServiceContext &ctx); + + static PgtDir *pgdAlloc(const PgTable &pgtable) + { + return new PgtDir; + } + + static void pgdFree(const PgTable &pgtable, PgtDir *pgdir) + { + delete pgdir; + } + + void DestroyNetMrs(std::vector &netMrs, uint32_t start, uint32_t end); + + SerResult InsertPgTable(UBSHcomNetMemoryRegionPtr &mr); + +private: + HcomServiceImpOptions mOptions; + std::string mOobIp; + std::vector mDriverPtrs; + HcomPeriodicManagerPtr mPeriodicMgr = nullptr; + NetMemPoolFixedPtr mContextMemPool = nullptr; + bool mStarted = false; + + std::mutex mStartMutex; + std::mutex mOptionsMutex; + std::mutex mNewEpMutex; + std::mutex mChannelMutex; + + std::map mNewEpMap; // temporary storage eps until create channel + std::map mSecInfoMap; // temporary storage secInfo + std::map mChannelMap; + std::vector mOobServers; // oob server need to be configed when enable multirail + std::map mDriverPair; /* local driver Index and remote driver Index map have been connected + to each other */ + std::vector mUseId; /* local driver Index has been used */ + uint32_t mDriverIndex = 0; + NetPgTablePtr mPgtable = nullptr; + bool mEnableMrCache = false; // mr into pgTable for management +}; + +} +} +#endif // HCOM_SERVICE_V2_HCOM_SERVICE_IMP_H_ \ No newline at end of file diff --git a/src/service_v2/service_periodic_manager.cpp b/src/service_v2/service_periodic_manager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7c88e98300a1f12b8b34987983f12d7e38cca2d0 --- /dev/null +++ b/src/service_v2/service_periodic_manager.cpp @@ -0,0 +1,192 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +#include "hcom_service_context.h" +#include "net_trace.h" +#include "service_common.h" +#include "service_periodic_manager.h" + +namespace ock { +namespace hcom { + +SerResult HcomPeriodicManager::Start() +{ + std::lock_guard guard(mMutex); + if (mStarted) { + return SER_OK; + } + + if (mThreadCount > M_MAX_THREAD_NUM) { + NN_LOG_ERROR("Invalid thread count " << mThreadCount); + return SER_INVALID_PARAM; + } + + mNeedStop = false; + /* create periodicManager threads */ + for (uint16_t i = 0; i < mThreadCount; i++) { + std::thread tmpThread(&HcomPeriodicManager::RunInThread, this, i); + if (!tmpThread.native_handle()) { + StopInner(); + return SER_CREATE_TIMEOUT_THREAD_FAILED; + } + + /* set thread name */ + if (pthread_setname_np(tmpThread.native_handle(), ("HcomPerMgr" + std::to_string(i)).c_str()) != 0) { + NN_LOG_WARN("Unable to set thread name of periodic manager"); + } + mWorkingThreads[i] = std::move(tmpThread); + } + + while (mStartedWorkingThreads.load() != mThreadCount) { + usleep(NN_NO10); + } + + mStarted = true; + return SER_OK; +} + +void HcomPeriodicManager::Stop() +{ + std::lock_guard guard(mMutex); + if (!mStarted) { + return; + } + + StopInner(); + mStarted = false; +} + +void HcomPeriodicManager::StopInner() +{ + mNeedStop = true; + for (uint16_t i = 0; i < mThreadCount; i++) { + if (mWorkingThreads[i].joinable()) { + mWorkingThreads[i].join(); + } + + ProcessCleanUp(i); + } +} + +void HcomPeriodicManager::ProcessCleanUp(uint16_t tId) +{ + if (NN_UNLIKELY(tId >= M_MAX_THREAD_NUM)) { + NN_LOG_WARN("tId is invalid"); + return; + } + UBSHcomServiceContext timeoutCtx{}; + HcomServiceGlobalObject::BuildTimeOutCtx(timeoutCtx); + timeoutCtx.mResult = SER_STOP; + for (uint32_t i = 0; i < M_MAX_BATCH_NUM; i++) { + auto currentQueue = &(mQueue[tId].queue[i]); + std::lock_guard guard(mQueue[tId].lock[i]); + while (!currentQueue->empty()) { + NN_LOG_TRACE_INFO("Process clean up seq no " << currentQueue->top()->SeqNo() << " timeout " << + currentQueue->top()->mTimeout << ", current time " << NetMonotonic::TimeSec()); + if (currentQueue->top()->EraseSeqNoWithRet()) { + currentQueue->top()->TimeoutDump(); + currentQueue->top()->MarkTimeout(); + auto callback = reinterpret_cast(currentQueue->top()->Callback()); + timeoutCtx.mCh = currentQueue->top()->mChannel; + callback->Run(timeoutCtx); + currentQueue->top()->DecreaseRef(); + } + RemoveLinkedList(currentQueue->top()); + currentQueue->top()->DecreaseRef(); + currentQueue->pop(); + timeoutCtx.mCh.Set(nullptr); + } + } +} + +void HcomPeriodicManager::ProcessTimeOut(uint16_t tId) +{ + if (tId >= M_MAX_THREAD_NUM) { + NN_LOG_WARN("tId is invalid"); + return; + } + mHandleQueue[tId].clear(); + for (int32_t i = M_MAX_BATCH_NUM - 1; i >= 0; i--) { + auto currentQueue = &(mQueue[tId].queue[i]); + std::lock_guard guard(mQueue[tId].lock[i]); + while (!currentQueue->empty()) { + NN_LOG_TRACE_INFO("Process time out seq no " << currentQueue->top()->SeqNo() << " timeout " << + currentQueue->top()->mTimeout << ", current time " << NetMonotonic::TimeSec()); + if (currentQueue->top()->IsFinished() || currentQueue->top()->IsTimeOut()) { + mHandleQueue[tId].emplace_back(currentQueue->top()); + currentQueue->pop(); + continue; + } + + break; + } + } + + UBSHcomServiceContext timeoutCtx{}; + HcomServiceGlobalObject::BuildTimeOutCtx(timeoutCtx); + for (auto &i : mHandleQueue[tId]) { + if (i->EraseSeqNoWithRet()) { + i->TimeoutDump(); + i->MarkTimeout(); + auto callback = reinterpret_cast(i->Callback()); + timeoutCtx.mCh = i->mChannel; + callback->Run(timeoutCtx); + i->DecreaseRef(); + } + RemoveLinkedList(i); /* if remove success, decrease linked list ref auto */ + i->DecreaseRef(); /* decrease periodic thread ref */ + timeoutCtx.mCh.Set(nullptr); + } +} + +void HcomPeriodicManager::RunInThread(int16_t tId) +{ + mHandleQueue[tId].reserve(NN_NO8192); + mStartedWorkingThreads.fetch_add(1); + + if (tId >= mThreadCount) { + NN_LOG_ERROR("Invalid tId " << tId << " to run PeriodicManager"); + return; + } + + int eFd = epoll_create(1); + if (eFd < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("HcomPeriodic manager failed to create epoll by " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return; + } + + NN_LOG_INFO("PeriodicManager for timeout [name: " << mName << ", index: " << tId << "] working thread start"); + while (!mNeedStop) { + auto startTime = NetMonotonic::TimeMs(); + ProcessTimeOut(tId); + auto duration = NetMonotonic::TimeMs() - startTime; + + struct epoll_event ev {}; + int waitTimeMs = 0; // wait for 500ms + if (duration >= gMaxTimeout) { + continue; + } else { + waitTimeMs = static_cast(gMaxTimeout - duration); + } + + epoll_wait(eFd, &ev, 1, waitTimeMs); + } + + NetFunc::NN_SafeCloseFd(eFd); + NN_LOG_INFO("PeriodicManager for timeout [name: " << mName << ", index: " << tId << "] working thread exit"); +} +} +} \ No newline at end of file diff --git a/src/service_v2/service_periodic_manager.h b/src/service_v2/service_periodic_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..fff1d1768cbe324b4bfc1320ea9bcce73b107c29 --- /dev/null +++ b/src/service_v2/service_periodic_manager.h @@ -0,0 +1,145 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_SERVICE_V2_SERVICE_PERIODIC_MANAGER_H_ +#define HCOM_SERVICE_V2_SERVICE_PERIODIC_MANAGER_H_ +#include +#include +#include + +#include "service_common.h" +#include "hcom_utils.h" +#include "service_callback.h" + +namespace ock { +namespace hcom { + +const int M_MAX_THREAD_NUM = 4; +const int M_MAX_BATCH_NUM = 16; + +class HcomPeriodicManager { +public: + HcomPeriodicManager(uint16_t threadCount, const std::string &name) : mThreadCount(threadCount), mName(name) + { + OBJ_GC_INCREASE(HcomPeriodicManager); + } + + ~HcomPeriodicManager() + { + Stop(); + OBJ_GC_DECREASE(HcomPeriodicManager); + } + +#define VALIDATE(timer) \ + do { \ + if (NN_UNLIKELY((timer) == nullptr)) { \ + NN_LOG_ERROR("Failed to add timeout, because timer is null"); \ + return SER_INVALID_PARAM; \ + } \ + \ + if (NN_UNLIKELY(mNeedStop)) { \ + NN_LOG_ERROR("Failed to add timeout seq no " << (timer)->SeqNo() << " because stop service"); \ + return SER_STOP; \ + } \ + \ + if (NN_UNLIKELY((timer)->SeqNo() == 0 || (timer)->Callback() == 0)) { \ + NN_LOG_ERROR("Add timeout invalid seq no " << (timer)->SeqNo() << " or callback " << (timer)->Callback()); \ + return SER_INVALID_PARAM; \ + } \ + } while (false) + + /* + * @brief Add the cb for timeout with seqNo + */ + inline SerResult AddTimer(HcomServiceTimer *&timer) + { + VALIDATE(timer); + uint32_t tId = timer->SeqNo() % mThreadCount; + uint32_t index = mQueue[tId].NextIndex(); + + AddLinkedList(timer); + std::lock_guard guard(mQueue[tId].lock[index]); + mQueue[tId].queue[index].push(timer); + return SER_OK; + } + + SerResult Start(); + void Stop(); + +private: + inline void AddLinkedList(HcomServiceTimer *timer) + { + if (NN_UNLIKELY(timer == nullptr || timer->mType != HcomAsyncCBType::CBS_IO)) { + return; + } + + if (NN_UNLIKELY(timer->mChannel == nullptr || timer->mChannel->GetTimerList() == 0)) { + return; + } + + auto header = reinterpret_cast(timer->mChannel->GetTimerList()); + header->AddTimerCtx(timer); + } + + inline void RemoveLinkedList(HcomServiceTimer *timer) + { + if (NN_UNLIKELY(timer == nullptr || timer->mType != HcomAsyncCBType::CBS_IO)) { + return; + } + + if (NN_UNLIKELY(timer->mChannel == nullptr || timer->mChannel->GetTimerList() == 0)) { + return; + } + + auto header = reinterpret_cast(timer->mChannel->GetTimerList()); + header->RemoveTimerCtx(timer); + } + void StopInner(); + void RunInThread(int16_t tId); + void ProcessTimeOut(uint16_t tId); + void ProcessCleanUp(uint16_t tId); + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + static constexpr uint64_t gMaxTimeout = 500L; + + struct QueueManager { + std::mutex lock[M_MAX_BATCH_NUM]; + std::priority_queue, HcomServiceTimerCompare> + queue[M_MAX_BATCH_NUM]; + uint32_t nextIndex = 0; + QueueManager() = default; + + inline uint32_t NextIndex() + { + return __sync_fetch_and_add(&nextIndex, 1) % M_MAX_BATCH_NUM; + } + }; + +private: + QueueManager mQueue[M_MAX_THREAD_NUM]; + std::vector mHandleQueue[M_MAX_THREAD_NUM]; + + std::thread mWorkingThreads[M_MAX_THREAD_NUM]; + std::atomic mStartedWorkingThreads = { 0 }; + uint16_t mThreadCount = 1; + + std::mutex mMutex; + bool mStarted = false; + bool mNeedStop = true; + + std::string mName; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; +} +} +#endif // HCOM_SERVICE_V2_SERVICE_PERIODIC_MANAGER_H_ \ No newline at end of file diff --git a/src/transport/net_ctx_info_pool.h b/src/transport/net_ctx_info_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..8c0278758d4a1a20dbff969af88f51638b2968e2 --- /dev/null +++ b/src/transport/net_ctx_info_pool.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef COMMUNICATION_NET_CTX_INFO_POOL_H +#define COMMUNICATION_NET_CTX_INFO_POOL_H + +#include "hcom_def.h" +#include "net_mem_pool_fixed.h" + +namespace ock { +namespace hcom { +template class OpContextInfoPool { +public: + inline NResult Initialize(const NetMemPoolFixedPtr &opCtxMemPool) + { + mOpCtxMemPool = opCtxMemPool; + return NN_OK; + } + + inline NResult Initialize(const NetMemPoolFixedPtr &opCtxMemPool, const UBSHcomNetDriverProtocol t) + { + mOpCtxMemPool = opCtxMemPool; + mProtocol = t; + return NN_OK; + } + + inline NResult UnInitialize() + { + mOpCtxMemPool.Set(nullptr); + return NN_OK; + } + + inline T *Get() + { + return GetOrReturn(nullptr); + } + + inline void Return(T *info) + { + (void)GetOrReturn(info, false); + } + +private: + /* alloc/free in the same function to make sure use the same thread_local variable */ + inline T *GetOrReturn(T *returnCtx, bool get = true) + { + if (mProtocol == UBSHcomNetDriverProtocol::UDS) { + static thread_local NetTCacheFixed udsThreadCache(mOpCtxMemPool.Get()); + if (get) { + return udsThreadCache.Allocate(); + } else { + udsThreadCache.Free(returnCtx); + return nullptr; + } + } + + static thread_local NetTCacheFixed threadCache(mOpCtxMemPool.Get()); + if (get) { + return threadCache.Allocate(); + } else { + threadCache.Free(returnCtx); + return nullptr; + } + } + + NetMemPoolFixedPtr mOpCtxMemPool; + UBSHcomNetDriverProtocol mProtocol; +}; +} +} + +#endif // COMMUNICATION_NET_CTX_INFO_POOL_H \ No newline at end of file diff --git a/src/transport/net_delay_release_timer.cpp b/src/transport/net_delay_release_timer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ba7742de5795e4d561ccff23c6b68e003d31cc49 --- /dev/null +++ b/src/transport/net_delay_release_timer.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +#include "net_delay_release_timer.h" + +namespace ock { +namespace hcom { +constexpr uint32_t EPOLL_WAIT_TIMEOUT = NN_NO1000 * NN_NO10; // 10 second +NResult NetDelayReleaseTimer::Start() +{ + std::lock_guard guard(mMutex); + if (mStarted) { + return NN_OK; + } + + std::thread delayReleaseThread(&NetDelayReleaseTimer::RunDelayReleaseThread, this); + mThread = std::move(delayReleaseThread); + std::string treadName = "DelayRelease" + std::to_string(mDriverIndex); + if (pthread_setname_np(mThread.native_handle(), treadName.c_str()) != 0) { + NN_LOG_WARN("Invalid to set name of NetDelayReleaseTimer working thread to " << treadName); + } + + while (!mThreadStarted.load()) { + usleep(NN_NO10); + } + + mStarted = true; + return NN_OK; +} + +void NetDelayReleaseTimer::Stop() +{ + std::lock_guard guard(mMutex); + if (!mStarted) { + NN_LOG_WARN("NetDelayReleaseTimer " << mName << " has not been started"); + return; + } + + StopInner(); + + mStarted = false; +} + +void NetDelayReleaseTimer::StopInner() +{ + mNeedStop = true; + + if (mThread.native_handle()) { + mThread.join(); + } +} + +void NetDelayReleaseTimer::RunDelayReleaseThread() +{ + mThreadStarted.store(true); + NN_LOG_INFO("NetDelayReleaseTimer " << mName << " working thread started"); + + int eFd = epoll_create(1); + if (eFd < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("NetDelayReleaseTimer thread failed to create epoll by " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return; + } + + while (!mNeedStop) { + auto startTime = NetMonotonic::TimeMs(); + DequeueDelayRelease(); + auto spendTime = NetMonotonic::TimeMs() - startTime; + + struct epoll_event ev {}; + int waitTimeMs = 0; // wait for 1000ms + if (spendTime >= EPOLL_WAIT_TIMEOUT) { + continue; + } else { + waitTimeMs = static_cast(EPOLL_WAIT_TIMEOUT - spendTime); + } + + epoll_wait(eFd, &ev, NN_NO1, waitTimeMs); + } + + NetFunc::NN_SafeCloseFd(eFd); + NN_LOG_INFO("NetDelayReleaseTimer " << mName << " working thread exiting"); +} + +void NetDelayReleaseTimer::DequeueDelayRelease() +{ + std::lock_guard gard(mDelayReleaseMutex); + while (!mDelayReleaseQueue.empty()) { + auto epRes = mDelayReleaseQueue.front(); + if (epRes.IsTimeOut()) { + if (NN_UNLIKELY(epRes.mEp != nullptr)) { + NN_LOG_DEBUG("Destroy Ep " << epRes.mEp->Id() << ", delayed release time has come"); + epRes.mEp.Set(nullptr); + } + mDelayReleaseQueue.pop(); + continue; + } + // if the first one is not timeout ,others is not timeout too + break; + } +} + +void NetDelayReleaseTimer::EnqueueDelayRelease(UBSHcomNetEndpointPtr &ep) +{ + std::lock_guard gard(mDelayReleaseMutex); + auto epRes = NetDelayReleaseResource(ep, NN_NO20); + mDelayReleaseQueue.push(epRes); +} +} +} \ No newline at end of file diff --git a/src/transport/net_delay_release_timer.h b/src/transport/net_delay_release_timer.h new file mode 100644 index 0000000000000000000000000000000000000000..1a7443a14cc876dfde886d913b8ebdd0f42f9b79 --- /dev/null +++ b/src/transport/net_delay_release_timer.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_NET_DELAY_RELEASE_TIMER_H +#define HCOM_NET_DELAY_RELEASE_TIMER_H + +#include + +#include "hcom.h" +#include "hcom_def.h" +#include "net_monotonic.h" + +namespace ock { +namespace hcom { +/* + * Delay Release of queue struct + */ +class NetDelayReleaseResource { +public: + NetDelayReleaseResource(UBSHcomNetEndpointPtr &ep, uint64_t delayTimeSec) + { + mEp = ep; + mTimeout = NetMonotonic::TimeSec() + delayTimeSec; + } + + ~NetDelayReleaseResource() = default; + + bool IsTimeOut() const + { + if (NetMonotonic::TimeSec() > mTimeout) { + return true; + } + return false; + } + +public: + UBSHcomNetEndpointPtr mEp = nullptr; /* manager ep time out */ + uint64_t mTimeout = 0; /* absolute timeout compare to current system time */ +}; + +/* + * Delay Release Timer + */ +class NetDelayReleaseTimer { +public: + NetDelayReleaseTimer(const std::string &name, uint16_t driverIndex) + : mDriverIndex(driverIndex), mName(name + std::to_string(driverIndex)) {}; + + ~NetDelayReleaseTimer() = default; + + NResult Start(); + void Stop(); + + void EnqueueDelayRelease(UBSHcomNetEndpointPtr &ep); + +private: + void StopInner(); + void RunDelayReleaseThread(); + void DequeueDelayRelease(); + + DEFINE_RDMA_REF_COUNT_FUNCTIONS +private: + // hot used variables for start + std::queue mDelayReleaseQueue; + std::mutex mDelayReleaseMutex; + bool mNeedStop = false; + + uint16_t mDriverIndex = 0; + std::string mName; + std::mutex mMutex; + bool mStarted = false; + std::thread mThread; + std::atomic_bool mThreadStarted { false }; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; +} +} +#endif // HCOM_NET_DELAY_RELEASE_TIMER_H \ No newline at end of file diff --git a/src/transport/net_endpoint_impl.cpp b/src/transport/net_endpoint_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4e974615cd22b5da4bda92e3a368eb3ce59e06b3 --- /dev/null +++ b/src/transport/net_endpoint_impl.cpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "transport/net_endpoint_impl.h" + +namespace ock { +namespace hcom { + +uint64_t NetEndpointImpl::EstimatedEncryptLen(uint64_t rawLen) +{ + if (NN_UNLIKELY(rawLen == 0)) { + NN_LOG_ERROR("Failed to estimate encrypt length as input length is 0"); + return 0; + } + + if (NN_UNLIKELY(!mIsNeedEncrypt)) { + NN_LOG_ERROR("Failed to estimate encrypt length as options of encrypt is not enabled"); + return 0; + } + + return mAes.EstimatedEncryptLen(rawLen); +} + +NResult NetEndpointImpl::Encrypt(const void *rawData, uint64_t rawLen, void *cipher, uint64_t &cipherLen) +{ + if (NN_UNLIKELY(!mIsNeedEncrypt) || NN_UNLIKELY(rawLen > UINT32_MAX) || NN_UNLIKELY(cipherLen > UINT32_MAX)) { + NN_LOG_ERROR("Failed to encrypt, options of encrypt is not enabled or len over uint32_max"); + return NN_ERROR; + } + + if (NN_UNLIKELY(!mAes.Encrypt(mSecrets, rawData, rawLen, cipher, reinterpret_cast(cipherLen)))) { + return NN_ERROR; + } + return NN_OK; +} + +uint64_t NetEndpointImpl::EstimatedDecryptLen(uint64_t verbsCipherLen) +{ + if (NN_UNLIKELY(!mIsNeedEncrypt)) { + NN_LOG_ERROR("Failed to estimate decrypt length as options of encrypt is not enabled"); + return 0; + } + + return mAes.GetRawLen(verbsCipherLen); +} + +NResult NetEndpointImpl::Decrypt(const void *cipher, uint64_t cipherLen, void *rawData, uint64_t &rawLen) +{ + if (NN_UNLIKELY(!mIsNeedEncrypt) || NN_UNLIKELY(rawLen > UINT32_MAX) || NN_UNLIKELY(cipherLen > UINT32_MAX)) { + NN_LOG_ERROR("Failed to decrypt, options of decrypt not enabled or len over uint32_max"); + return NN_ERROR; + } + + if (NN_UNLIKELY(!mAes.Decrypt(mSecrets, cipher, cipherLen, rawData, reinterpret_cast(rawLen)))) { + return NN_ERROR; + } + return NN_OK; +} + +} +} diff --git a/src/transport/net_endpoint_impl.h b/src/transport/net_endpoint_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..8129b880100b0f033c1d7521ae0810856c12a5a7 --- /dev/null +++ b/src/transport/net_endpoint_impl.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_NET_ENDPOINT_IMPL_H +#define HCOM_NET_ENDPOINT_IMPL_H + +#include "hcom.h" +#include "net_security_alg.h" + +namespace ock { +namespace hcom { + +class NetEndpointImpl : public UBSHcomNetEndpoint { +public: + uint64_t EstimatedEncryptLen(uint64_t rawLen) override; + NResult Encrypt(const void *rawData, uint64_t rawLen, void *cipher, uint64_t &cipherLen) override; + uint64_t EstimatedDecryptLen(uint64_t verbsCipherLen) override; + NResult Decrypt(const void *cipher, uint64_t cipherLen, void *rawData, uint64_t &rawLen) override; + +public: + inline void EnableEncrypt(UBSHcomNetDriverOptions options) + { + mIsNeedEncrypt = true; + mAes.SetEncryptOptions(options.cipherSuite); + } + + inline void SetSecrets(NetSecrets &verbsSecrets) + { + mSecrets = verbsSecrets; + } + +protected: + NetEndpointImpl(uint64_t id, const UBSHcomNetWorkerIndex &workerWholeIndex) + : UBSHcomNetEndpoint(id, workerWholeIndex) {} + +protected: + bool mIsNeedEncrypt = false; + AesGcm128 mAes; + NetSecrets mSecrets; +}; + +} +} + +#endif // HCOM_NET_ENDPOINT_IMPL_H diff --git a/src/transport/net_heartbeat.cpp b/src/transport/net_heartbeat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6713b9c7f445c7fbb3b8d2bd7556b27646c27468 --- /dev/null +++ b/src/transport/net_heartbeat.cpp @@ -0,0 +1,340 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "net_heartbeat.h" +#ifdef RDMA_BUILD_ENABLED +#include "net_rdma_async_endpoint.h" +#endif + +#ifdef UB_BUILD_ENABLED +#include "net_ub_endpoint.h" +#include "ub_worker.h" +#endif + +namespace ock { +namespace hcom { +NetHeartbeat::NetHeartbeat(UBSHcomNetDriver *driver, uint16_t heartBeatIdleTime, uint16_t heartBeatProbeInterval) + : mDriver(driver), + mHeartBeatIdleTime(heartBeatIdleTime), + mHeartBeatProbeInterval(heartBeatProbeInterval * NN_NO1000000) +{ + if (mDriver != nullptr) { + mDriver->IncreaseRef(); + } + + if (mHeartBeatProbeInterval == 0) { + // If the user sets mHeartBeatProbeInterval to 0, change it to 5000 instead(5ms). + mHeartBeatProbeInterval = 5000; + } +} +NetHeartbeat::~NetHeartbeat() +{ + if (mDriver != nullptr) { + mDriver->DecreaseRef(); + mDriver = nullptr; + } +} + +NResult NetHeartbeat::Start() +{ + if (mDriver == nullptr) { + NN_LOG_ERROR("Failed to start because driver is null"); + return NN_INVALID_PARAM; + } + NResult result = NN_OK; + if ((result = mDriver->CreateMemoryRegion(NN_NO64 * NN_NO1024, mHBLocalOpMr)) != NN_OK) { + NN_LOG_ERROR("Failed to create mr for local HB in driver " << mDriver->Name() << ", result " << result); + return result; + } + + if ((result = mDriver->CreateMemoryRegion(NN_NO64 * NN_NO1024, mHBRemoteOpMr)) != NN_OK) { + NN_LOG_ERROR("Failed to create mr for remote HB in driver " << mDriver->Name() << ", result " << result); + mDriver->DestroyMemoryRegion(mHBLocalOpMr); + return result; + } + + mNeedStopHb = false; + std::thread tmpThread(&NetHeartbeat::RunInHbThread, this); + mHbThread = std::move(tmpThread); + + while (!mHBStarted.load()) { + usleep(NN_NO10); + } + return NN_OK; +} + +void NetHeartbeat::Stop() +{ + mNeedStopHb = true; + if (mHbThread.native_handle()) { + mHbThread.join(); + } + + if (mDriver == nullptr) { + return; + } + + if (mHBLocalOpMr != nullptr) { + mDriver->DestroyMemoryRegion(mHBLocalOpMr); + mHBLocalOpMr.Set(nullptr); + } + + if (mHBRemoteOpMr != nullptr) { + mDriver->DestroyMemoryRegion(mHBRemoteOpMr); + mHBRemoteOpMr.Set(nullptr); + } +} + +void NetHeartbeat::RunInHbThread() +{ + mHBStarted.store(true); + NN_LOG_INFO("Heartbeat thread for driver " << mDriver->Name() << ", HCOMHb" << std::to_string(mDriver->GetId()) << + " started, idle time " << mHeartBeatIdleTime); + + /* set thread name */ + pthread_setname_np(pthread_self(), ("HCOMHb" + std::to_string(mDriver->GetId())).c_str()); + + mTarSec = NetMonotonic::TimeSec() + mHeartBeatIdleTime; + while (!mNeedStopHb) { + mCurrentSec = NetMonotonic::TimeSec(); + while (mCurrentSec > mTarSec) { + mTarSec = mCurrentSec + mHeartBeatProbeInterval / NN_NO1000000; + DetectHbState(); + } + usleep(mHeartBeatProbeInterval); + } + NN_LOG_INFO("Heartbeat thread for driver " << mDriver->Name() << ", HCOMHb" << std::to_string(mDriver->GetId()) << + " exiting"); + mHBStarted.store(false); +} + +void NetHeartbeat::DetectHbState() +{ + if (mHBLocalOpMr.Get() == nullptr) { + NN_LOG_ERROR("Failed to heart beat detection as related memory region is null in driver " << mDriver->Name()); + return; + } + + UBSHcomNetTransRequest request = {}; + request.lAddress = GetNextLocalOpHBAddress(); + request.lKey = GetLocalOpHBKey(); + request.size = GetLocalOpHBMrSize(); + + static thread_local std::unordered_map endPointsCopy; + endPointsCopy.reserve(NN_NO8192); + endPointsCopy.clear(); + { + std::lock_guard locker(mDriver->mEndPointsMutex); + for (auto &endPoint : mDriver->mEndPoints) { + auto ep = endPoint.second.Get(); + if (ep != nullptr && ep->IsNeedSendHb()) { + endPointsCopy.emplace(endPoint.first, endPoint.second); + } + } + } + + for (auto &endPoint : endPointsCopy) { + DetectSingleEpHbState(request, endPoint.second.Get()); + } + + endPointsCopy.clear(); +} + +NResult NetHeartbeat::SendTwoSideHeartBeat(UBSHcomNetEndpoint *endPoint) +{ + NResult result = NN_OK; + char data; + UBSHcomNetTransRequest req((void *)(&data), sizeof(data), 0); + if (NN_UNLIKELY((result = endPoint->PostSend(HB_SEND_OP, req, 0)) != NN_OK)) { + NN_LOG_ERROR("Endpoint " << endPoint->mId << " failed to post send request, result " << result); + return result; + } + return NN_OK; +} + +template +NResult NetHeartbeat::SendHeartBeat(T *ep, T1 *driver, UBSHcomNetTransRequest &request, T2 opType) +{ + if (NN_UNLIKELY(!ep->mState.Compare(NEP_ESTABLISHED))) { + NN_LOG_ERROR("Endpoint " << ep->mId << " is not established, state is " << + UBSHcomNEPStateToString(ep->mState.Get())); + return NN_EP_NOT_ESTABLISHED; + } + + if (NN_UNLIKELY(ep->GetQp() == nullptr)) { + NN_LOG_ERROR("Endpoint " << ep->mId << " invalid endpoint"); + return NN_ERROR; + } +#ifdef UB_BUILD_ENABLED + if (driver->Protocol() == UBSHcomNetDriverProtocol::UBC) { + auto ubcEp = dynamic_cast(ep); + if (NN_UNLIKELY(ubcEp == nullptr)) { + NN_LOG_ERROR("Invalid operation to dynamic cast"); + return NN_ERROR; + } + auto jetty = ubcEp->GetQp(); + if (jetty->mHBLocalMr == nullptr) { + NN_LOG_WARN("Endpoint " << ep->mId << " HB mr freed already"); + return NN_ERROR; + } + request.lAddress = jetty->GetNextLocalHBAddress(); + request.lKey = jetty->GetLocalHBKey(); + request.srcSeg = jetty->mHBLocalMr->GetMemorySeg(); + } +#endif + request.rAddress = ep->mRemoteHbAddress; + request.upCtxSize = 0; + request.rKey = ep->mRemoteHbKey; + + if (driver->ValidateMemoryRegion(request.lKey, request.lAddress, request.size) != NN_OK) { + NN_LOG_ERROR("Endpoint " << ep->mId << " Invalid MemoryRegion or lkey"); + return NN_INVALID_LKEY; + } + auto worker = ep->GetWorker(); + if (NN_UNLIKELY(worker == nullptr)) { + NN_LOG_ERROR("Endpoint " << ep->mId << " failed to get worker from group in PostWrite "); + return NN_ERROR; + } + + NResult result = NN_OK; + if (NN_UNLIKELY((result = worker->PostWrite(ep->GetQp(), request, opType)) != NN_OK)) { + NN_LOG_ERROR("Endpoint " << ep->mId << " failed to post write request, result " << result); + return result; + } + return result; +} + +template +void NetHeartbeat::DetectSingleEpHbState(T *ep, T1 *driver, UBSHcomNetTransRequest &request, T2 opType) +{ + if (NN_UNLIKELY(ep == nullptr || driver == nullptr)) { + NN_LOG_WARN("Invalid operation to dynamic cast"); + return; + } + + NResult result = NN_OK; + /* check if reach ep target hb time */ + if (!ep->checkTargetHbTime(mCurrentSec)) { + return; + } + if (ep->HbCheckStateNormal()) { + result = SendHeartBeat(ep, driver, request, opType); + if (result == NN_OK) { + return; + } + NN_LOG_WARN("Detect Ep id " << ep->Id() << " cannot send Hb, result " << result); + } + if (ep->HbBrokenEp()) { + /* delay handle broken ep to prevent race condition with work polling cq thread */ + NN_LOG_WARN("Detect Ep id " << ep->Id() << " Hb state abnormal, call broken handle"); + driver->ProcessEpError(reinterpret_cast(ep)); + } else { + /* set hb broken ep when detected first time */ + NN_LOG_WARN("Detect Ep id " << ep->Id() << " Hb state abnormal, set qp err and wait next probe to handle"); + ep->State().Set(NEP_BROKEN); + ep->SetHbBrokenEp(); + if (NN_UNLIKELY(ep->GetQp() == nullptr)) { + NN_LOG_ERROR("Endpoint " << ep->Id() << " failed to get qp"); + return; + } + ep->GetQp()->Stop(); + } +} + +template +void NetHeartbeat::DetectSingleEpHbState(NetUBAsyncEndpoint *ep, NetDriverUBWithOob *driver, + UBSHcomNetTransRequest &request, T opType) +{ + if (NN_UNLIKELY(ep == nullptr || driver == nullptr)) { + NN_LOG_WARN("Invalid operation to dynamic cast"); + return; + } + + /* check if reach ep target hb time */ + if (!ep->checkTargetHbTime(mCurrentSec)) { + return; + } + + // EP 上的心跳包可能会因对端机器重启而产生超时事件。心跳事件任何非SUCCESS的状态码都将认为心跳异常从而断链。另外需 + // 要注意的是,心跳没有在一个周期内完成也判定为异常,这种情况可能发生于EP上存在大量的用户数据包,心跳包位于用户数 + // 据包后。因在处理用户数据包时就耗费大量时间,在轮到心跳包处理时可能已过了一个心跳周期。 + // \see NetDriverUBWithOob::OneSideDoneCB + if (ep->HbCheckStateNormal()) { + NResult result = SendHeartBeat(ep, driver, request, opType); + if (result != NN_OK) { + NN_LOG_WARN("Detect Ep id " << ep->Id() << " cannot send Hb, result " << result); + driver->ProcessEpError(reinterpret_cast(ep)); + } + } else { + driver->ProcessEpError(reinterpret_cast(ep)); + } +} + +template void NetHeartbeat::DetectSingleEpHbState(T *ep, T1 *driver) +{ + if (NN_UNLIKELY(ep == nullptr || driver == nullptr)) { + NN_LOG_WARN("Invalid operation to dynamic cast"); + return; + } + + NResult result = NN_OK; + if (ep->HbCheckStateNormal()) { + result = SendTwoSideHeartBeat(ep); + if (result == NN_OK) { + return; + } + NN_LOG_WARN("Detect Ep id " << ep->Id() << " cannot send Hb, result " << result); + } + if (ep->HbBrokenEp()) { + /* delay handle broken ep to prevent race condition with work polling cq thread */ + NN_LOG_WARN("Detect Ep id " << ep->Id() << " Hb state abnormal, call broken handle"); + driver->ProcessEpError(reinterpret_cast(ep)); + } else { + /* set hb broken ep when detected first time */ + NN_LOG_WARN("Detect Ep id " << ep->Id() << " Hb state abnormal, set qp err and wait next probe to handle"); + if (ep->State().Compare(NEP_ESTABLISHED)) { + ep->State().Set(NEP_BROKEN); + } + ep->SetHbBrokenEp(); + // free resources in ProcessEpError + } +} + +void NetHeartbeat::DetectSingleEpHbState(UBSHcomNetTransRequest &request, UBSHcomNetEndpoint *endPoint) +{ + if (NN_UNLIKELY(endPoint == nullptr || mDriver == nullptr)) { + NN_LOG_ERROR("Endpoint or driver is null"); + return; + } + + switch (mDriver->Protocol()) { +#ifdef RDMA_BUILD_ENABLED + case UBSHcomNetDriverProtocol::RDMA: + return DetectSingleEpHbState(dynamic_cast(endPoint), + dynamic_cast(mDriver), request, RDMAOpContextInfo::HB_WRITE); +#endif + +#ifdef UB_BUILD_ENABLED + case UBSHcomNetDriverProtocol::UBC: + return DetectSingleEpHbState(dynamic_cast(endPoint), + dynamic_cast(mDriver), request, UBOpContextInfo::HB_WRITE); +#endif + + default: + NN_LOG_ERROR("Invalid protocol " << UBSHcomNetDriverProtocolToString(mDriver->Protocol()) << + " to send heartbeat"); + return; + } + return; +} +} +} diff --git a/src/transport/net_heartbeat.h b/src/transport/net_heartbeat.h new file mode 100644 index 0000000000000000000000000000000000000000..54dc7de4600e570d6967291ce9480990311c25d2 --- /dev/null +++ b/src/transport/net_heartbeat.h @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_NET_HEARTBEAT_H +#define OCK_NET_HEARTBEAT_H + +#include "hcom.h" +#include "net_monotonic.h" + +#include + +namespace ock { +namespace hcom { +class NetUBAsyncEndpoint; +class NetDriverUBWithOob; + +class NetHeartbeat { +public: + NetHeartbeat(UBSHcomNetDriver *driver, uint16_t heartBeatIdleTime, uint16_t heartBeatProbeInterval); + ~NetHeartbeat(); + NResult Start(); + void Stop(); + + template + void GetRemoteHbInfo(T &info) + { + uint64_t nextOffset = __sync_fetch_and_add(&mRemoteNextOffset, NN_NO4) % mHBRemoteOpMr->Size(); + info.hbAddress = mHBRemoteOpMr->GetAddress() + nextOffset; + info.hbKey = mHBRemoteOpMr->GetLKey(); + info.hbMrSize = NN_NO4; + } + + uint16_t GetHbIdleTime() + { + return mHeartBeatIdleTime; + } + +private: + void RunInHbThread(); + void DetectHbState(); + + template + void DetectSingleEpHbState(T *ep, T1 *driver, UBSHcomNetTransRequest &request, T2 opType); + + /// UBC 专用 + template + void DetectSingleEpHbState(NetUBAsyncEndpoint *ep, NetDriverUBWithOob *driver, + UBSHcomNetTransRequest &request, T opType); + + /// 使用双边心跳,目前hshmem专用 + template + void DetectSingleEpHbState(T *ep, T1 *driver); + + void DetectSingleEpHbState(UBSHcomNetTransRequest &request, UBSHcomNetEndpoint *endPoint); + + template + NResult SendHeartBeat(T *ep, T1 *driver, UBSHcomNetTransRequest &request, T2 opType); + NResult SendTwoSideHeartBeat(UBSHcomNetEndpoint *endPoint); + + inline uintptr_t GetNextLocalOpHBAddress() + { + uint64_t nextOffset = __sync_fetch_and_add(&mLocalNextOffset, NN_NO4) % mHBLocalOpMr->Size(); + return mHBLocalOpMr->GetAddress() + nextOffset; + } + + inline uint64_t GetLocalOpHBKey() const + { + return mHBLocalOpMr->GetLKey(); + } + + static inline uint64_t GetLocalOpHBMrSize() + { + return NN_NO4; + } + + UBSHcomNetDriver *mDriver = nullptr; + + bool mNeedStopHb = false; + std::thread mHbThread; + std::atomic mHBStarted { false }; + UBSHcomNetMemoryRegionPtr mHBLocalOpMr; + uint64_t mLocalNextOffset = 0; + UBSHcomNetMemoryRegionPtr mHBRemoteOpMr; + uint64_t mRemoteNextOffset = 0; + uint16_t mHeartBeatIdleTime = NN_NO60; + uint32_t mHeartBeatProbeInterval = NN_NO2000000; // 2s + uint64_t mTarSec = 0; + uint64_t mCurrentSec = 0; +}; +} +} + +#endif // OCK_NET_HEARTBEAT_H diff --git a/src/transport/net_load_balance.h b/src/transport/net_load_balance.h new file mode 100644 index 0000000000000000000000000000000000000000..4ed8493a38b7909213b7d738544e785a824cf00a --- /dev/null +++ b/src/transport/net_load_balance.h @@ -0,0 +1,172 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_LOAD_BALANCE_H +#define OCK_HCOM_NET_LOAD_BALANCE_H + +#include "hcom.h" + +namespace ock { +namespace hcom { +struct NetWorkerGroupLbInfo { + uint16_t wrkCntInGrp = 0; /* number of worker in this group */ + uint16_t wrkOffsetInAll = 0; /* the offset of first workers in total flat workers, as all workers in stored flat */ + uint16_t grpRRIdx = 0; /* index counter for round-robin in one group */ + uint16_t wrkCntLimited = 0; /* only limited number of worker can be chosen */ + std::vector wrkLimited; /* limited workers, offset in this group */ +}; + +class NetWorkerLB { +public: + NetWorkerLB(const std::string &name, UBSHcomNetDriverLBPolicy policy, uint16_t wrkLimitedCnt) + : mName(name), mPolicy(policy), mWorkerLimitedCnt(wrkLimitedCnt) + {} + + ~NetWorkerLB() = default; + + /* + * @brief Add worker groups + */ + inline NResult AddWorkerGroups(const std::vector> &groups) + { + NN_ASSERT_LOG_RETURN(!groups.empty(), NN_INVALID_PARAM); + time_t currentTime = time(nullptr); + if (currentTime == -1) { + NN_LOG_ERROR("Failed to get current time when adding worker groups"); + return NN_ERROR; + } + /* this srand value is not used for security related thing */ + srand(static_cast(currentTime)); + for (const auto &item : groups) { + if (NN_LIKELY(item.second == 0)) { + return NN_INVALID_PARAM; + } + + AddWorkerGroup(item.first, item.second); + } + + return NN_OK; + } + + /* + * @brief Choose a worker + * + * @param grpIdx [in] group index transferred from client + * @param peerIpPort [in] ip and port new connection + * @param flatWrkIdx [out] index of the worker in the flatted workers + * + * @return true if chosen + */ + inline bool ChooseWorker(uint16_t grpIdx, const std::string &peerIpPort, uint16_t &flatWrkIdx) + { + const uint16_t groupCount = mWrkGroups.size(); + if (NN_UNLIKELY(grpIdx >= groupCount)) { + NN_LOG_ERROR("Invalid group no " << grpIdx << " from client " << peerIpPort << " in lb " << mName); + return false; + } + + /* if worker count is not equal to limited worker count */ + if (NN_UNLIKELY(mWrkGroups[grpIdx].wrkCntLimited != mWrkGroups[grpIdx].wrkCntInGrp)) { + return ChooseWorkerLimited(grpIdx, peerIpPort, flatWrkIdx); + } + + if (mPolicy == NET_ROUND_ROBIN) { + auto innerIdx = __sync_fetch_and_add(&(mWrkGroups[grpIdx].grpRRIdx), 1) % mWrkGroups[grpIdx].wrkCntInGrp; + flatWrkIdx = mWrkGroups[grpIdx].wrkOffsetInAll + innerIdx; + return true; + } else if (mPolicy == NET_HASH_IP_PORT) { + auto innerIdx = std::hash {}(peerIpPort) % mWrkGroups[grpIdx].wrkCntInGrp; + flatWrkIdx = mWrkGroups[grpIdx].wrkOffsetInAll + innerIdx; + return true; + } + + NN_LOG_ERROR("Un-supported load balance policy"); + return false; + } + + std::string ToString() const + { + std::ostringstream oss; + oss << "name: " << mName << ", policy: " << UBSHcomNetDriverLBPolicyToString(mPolicy) << + ", choose-able-count: " << (mWorkerLimitedCnt == UINT16_MAX ? "all" : std::to_string(mWorkerLimitedCnt)) << + ", worker-groups: " << mWrkGroups.size(); + return oss.str(); + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + inline bool ChooseWorkerLimited(uint16_t grpIdx, const std::string &peerIpPort, uint16_t &flatWrkIdx) + { + const uint16_t groupCount = mWrkGroups.size(); + if (NN_UNLIKELY(grpIdx >= groupCount)) { + NN_LOG_ERROR("Invalid group Idx"); + return false; + } + + if (mPolicy == NET_ROUND_ROBIN) { + auto innerIdx = __sync_fetch_and_add(&(mWrkGroups[grpIdx].grpRRIdx), 1) % mWrkGroups[grpIdx].wrkCntLimited; + flatWrkIdx = mWrkGroups[grpIdx].wrkOffsetInAll + mWrkGroups[grpIdx].wrkLimited[innerIdx]; + return true; + } else if (mPolicy == NET_HASH_IP_PORT) { + auto innerIdx = std::hash {}(peerIpPort) % mWrkGroups[grpIdx].wrkCntLimited; + flatWrkIdx = mWrkGroups[grpIdx].wrkOffsetInAll + mWrkGroups[grpIdx].wrkLimited[innerIdx]; + return true; + } + + return false; + } + + /* + * @brief Added one worker group + * + * @param offsetWorker [in] offset of the first worker's offset in flat workers + * @param wrkCntInGrp [in] the worker count in this group + * + */ + inline void AddWorkerGroup(uint16_t offsetWorker, uint16_t wrkCntInGrp) + { + NetWorkerGroupLbInfo info {}; + info.wrkCntInGrp = wrkCntInGrp; + info.wrkOffsetInAll = offsetWorker; + info.grpRRIdx = 0; + info.wrkCntLimited = wrkCntInGrp; + + /* + * worker number in group is large than limited number of worker + * generate random offset + */ + if (wrkCntInGrp > mWorkerLimitedCnt) { + info.wrkCntLimited = mWorkerLimitedCnt; + /* this rand value is not used for security related thing */ + auto randIndex = rand(); + for (uint16_t i = 0; i < mWorkerLimitedCnt; i++) { + info.wrkLimited.emplace_back((randIndex + i) % wrkCntInGrp); + } + } + + mWrkGroups.push_back(info); + } + +private: + std::string mName; + UBSHcomNetDriverLBPolicy mPolicy = NET_ROUND_ROBIN; /* policy */ + uint16_t mWorkerLimitedCnt = 0; /* means this can be less than total workers in one group */ + std::vector mWrkGroups; /* worker group info */ + std::vector mEpCntPerWorkers; /* for even distributed policy */ + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; +using NetWorkerLBPtr = NetRef; +} +} + +#endif // OCK_HCOM_NET_LOAD_BALANCE_H diff --git a/src/transport/net_memory_region.cpp b/src/transport/net_memory_region.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8080dc851f33744e18b2fe6b5d0d276299ceefaa --- /dev/null +++ b/src/transport/net_memory_region.cpp @@ -0,0 +1,146 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "net_memory_region.h" + +namespace ock { +namespace hcom { +std::atomic NormalMemoryRegion::KEY_ID(0); +std::atomic NormalMemoryRegion::LOCAL_KEY_INDEX(0); + +NResult NormalMemoryRegion::Create(const std::string &name, uint64_t size, NormalMemoryRegion *&buf) +{ + if (NN_UNLIKELY(size == 0)) { + NN_LOG_ERROR("Failed to create normal memory region as size is zero"); + return NN_INVALID_PARAM; + } + + auto tmpBuf = new (std::nothrow) NormalMemoryRegion(name, false, 0, size); + if ((NN_UNLIKELY(tmpBuf == nullptr))) { + NN_LOG_ERROR("Failed to create normal memory region"); + return NN_NEW_OBJECT_FAILED; + } + + buf = tmpBuf; + + return NN_OK; +} + +NResult NormalMemoryRegion::Create(const std::string &name, uintptr_t address, uint64_t size, NormalMemoryRegion *&buf) +{ + if (NN_UNLIKELY(address == 0 || size == 0)) { + NN_LOG_ERROR("Failed to create normal memory region as address or size is zero"); + return NN_INVALID_PARAM; + } + + auto tmpBuf = new (std::nothrow) NormalMemoryRegion(name, true, address, size); + if ((NN_UNLIKELY(tmpBuf == nullptr))) { + NN_LOG_ERROR("Failed to create normal memory region"); + return NN_NEW_OBJECT_FAILED; + } + + buf = tmpBuf; + + return NN_OK; +} + +NResult NormalMemoryRegion::Initialize() +{ + std::lock_guard guard(mMutex); + if (mInited) { + return NN_OK; + } + + if (mExternalMemory) { + if ((mBuf == 0 || mSize == 0)) { + NN_LOG_ERROR("Invalid external memory address or size for normal memory region " << mName); + return NN_INVALID_PARAM; + } + + mLKey = LOCAL_KEY_INDEX.fetch_add(1); + mInited = true; + + /* don't do bzero to external memory, because this may clean user's data */ + return NN_OK; + } + + /* allocate memory */ + auto tmpBuf = memalign(NN_NO4096, mSize); + if (tmpBuf == nullptr) { + NN_LOG_ERROR("Failed to allocate memory for normal memory region " << mName << " with size " << mSize); + return NN_MALLOC_FAILED; + } + + bzero(tmpBuf, mSize); + mBuf = reinterpret_cast(tmpBuf); + mLKey = LOCAL_KEY_INDEX.fetch_add(1); + mInited = true; + return NN_OK; +} + +void NormalMemoryRegion::UnInitialize() +{ + std::lock_guard guard(mMutex); + if (!mInited) { + return; + } + + if (!mExternalMemory) { + free(reinterpret_cast(mBuf)); + mBuf = 0; + } + mInited = false; +} + +/* NormalMemoryRegionFixedBuffer */ +NResult NormalMemoryRegionFixedBuffer::Create(const std::string &name, uint32_t singleSegSize, uint32_t segCount, + NormalMemoryRegionFixedBuffer *&buf) +{ + auto tmp = new (std::nothrow) NormalMemoryRegionFixedBuffer(name, singleSegSize, segCount); + if (tmp == nullptr) { + return NN_NEW_OBJECT_FAILED; + } + + buf = tmp; + return NN_OK; +} + + +NResult NormalMemoryRegionFixedBuffer::Initialize() +{ + NResult result = NN_OK; + if ((result = NormalMemoryRegion::Initialize()) != NN_OK) { + return result; + } + + /* init unAllocated container */ + if ((result = mUnAllocated.Initialize()) != NN_OK) { + NN_LOG_ERROR("Failed to initialize un-allocated ring buffer in NormalMemoryRegionFixedBuffer " << mName); + return result; + } + + /* init un-allocated */ + uintptr_t address = mBuf; + for (uint32_t i = 0; i < mSegCount; i++) { + mUnAllocated.PushBack(address); + address += mSingleSegSize; + } + + return NN_OK; +} + +void NormalMemoryRegionFixedBuffer::UnInitialize() +{ + mUnAllocated.UnInitialize(); + NormalMemoryRegion::UnInitialize(); +} +} +} \ No newline at end of file diff --git a/src/transport/net_memory_region.h b/src/transport/net_memory_region.h new file mode 100644 index 0000000000000000000000000000000000000000..2c2a34f8a7951e97ca6468f93e64ef84b9e8ac7c --- /dev/null +++ b/src/transport/net_memory_region.h @@ -0,0 +1,120 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_MEMORY_REGION_H_23234 +#define OCK_HCOM_NET_MEMORY_REGION_H_23234 + +#include + +#include "hcom.h" + +namespace ock { +namespace hcom { +class NormalMemoryRegion : public UBSHcomNetMemoryRegion { +public: + static NResult Create(const std::string &name, uint64_t size, NormalMemoryRegion *&buf); + static NResult Create(const std::string &name, uintptr_t address, uint64_t size, NormalMemoryRegion *&buf); + +public: + NormalMemoryRegion(const std::string &name, bool extMem, uintptr_t extMemAddress, uint64_t size) + : UBSHcomNetMemoryRegion(name, extMem, extMemAddress, size) + {} + + NResult Initialize() override; + void UnInitialize() override; + + void *GetMemorySeg() override + { + return nullptr; + } + + void GetVa(uint64_t &va, uint64_t &va_len, uint32_t &token_id) override + { + return; + } + +private: + std::mutex mMutex; + bool mInited = false; + + static std::atomic KEY_ID; + static std::atomic LOCAL_KEY_INDEX; +}; + +/* ***************************************************************************************************** */ +class NormalMemoryRegionFixedBuffer : public NormalMemoryRegion { +public: + static NResult Create(const std::string &name, uint32_t singleSegSize, uint32_t segCount, + NormalMemoryRegionFixedBuffer *&buf); + +public: + NormalMemoryRegionFixedBuffer(const std::string &name, uint32_t singleSegSize, uint32_t segCount) + : NormalMemoryRegion(name, false, 0, static_cast(singleSegSize) * static_cast(segCount)), + mSingleSegSize(singleSegSize), + mSegCount(segCount), + mUnAllocated(segCount) + {} + + ~NormalMemoryRegionFixedBuffer() override + { + UnInitialize(); + } + + NResult Initialize() override; + void UnInitialize() override; + + inline uint32_t GetFreeBufferCount() + { + return mUnAllocated.Size(); + } + + inline bool GetFreeBuffer(uintptr_t &item) + { + return mUnAllocated.PopFront(item); + } + + inline bool GetFreeBufferN(uintptr_t *items, uint32_t n) + { + if (NN_UNLIKELY(items == nullptr)) { + return false; + } + return mUnAllocated.PopFrontN(items, n); + } + + inline bool ReturnBuffer(uintptr_t value) + { + return mUnAllocated.PushFront(value); + } + + std::string ToString() + { + std::ostringstream oss; + oss << "NormalMemoryRegionFixedBuffer info: mBuf " << mBuf << ", mSingleSegSize " << mSingleSegSize << + ", mSegCount " << mSegCount << ", unAllocatedSize " << mUnAllocated.Size() << ", total buf size " << mSize; + return oss.str(); + } + + inline uint32_t GetSingleSegSize() const + { + return mSingleSegSize; + } + +private: + uint32_t mSingleSegSize = 0; + uint32_t mSegCount = 0; + + // uintptr_p store the start address of each mr segment + NetRingBuffer mUnAllocated; +}; +} +} + +#endif // OCK_HCOM_NET_MEMORY_REGION_H_23234 diff --git a/src/transport/net_oob.cpp b/src/transport/net_oob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dbbbb113f8391f273f351fe6c4126788d3a6f6af --- /dev/null +++ b/src/transport/net_oob.cpp @@ -0,0 +1,939 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hcom_def.h" +#include "hcom_log.h" +#include "net_oob.h" + +namespace ock { +namespace hcom { +NResult OOBTCPServer::EnableAutoPortSelection(uint16_t minPort, uint16_t maxPort) +{ + if (mStarted) { + NN_LOG_ERROR("Failed to enable auto port selection! oob server already start."); + return NN_ERROR; + } + + if (mOobType != NET_OOB_TCP) { + NN_LOG_ERROR("Failed to enable auto port selection! OOB_TYPE is not TCP."); + return NN_ERROR; + } + + if (minPort == 0 || maxPort == 0) { + NN_LOG_ERROR("Failed to enable auto port selection!, port range is invalid!"); + return NN_ERROR; + } + + if (minPort < NN_NO1024) { + NN_LOG_ERROR("Failed to enable auto port selection! minPort is less than 1024."); + return NN_ERROR; + } + + if (maxPort < NN_NO1024) { + NN_LOG_ERROR("Failed to enable auto port selection! maxPort is less than 1024."); + return NN_ERROR; + } + + if (minPort > maxPort) { + NN_LOG_ERROR("Failed to enable auto port selection! minPort is bigger than maxPort."); + return NN_ERROR; + } + + if (mListenPort != 0) { + NN_LOG_WARN("oobPort will be selected automatically!"); + } + + mMinListenPort = minPort; + mMaxListenPort = maxPort; + mListenPort = mMinListenPort; + mIsAutoPortSelectionEnabled = true; + return NN_OK; +} + +NResult OOBTCPServer::GetListenPort(uint16_t &port) +{ + if (!mStarted) { + NN_LOG_ERROR("Failed to get listen port, oob server is not start"); + return NN_ERROR; + } + + port = mListenPort; + return NN_OK; +} + +NResult OOBTCPServer::GetListenIp(std::string &ip) +{ + if (!mStarted) { + NN_LOG_ERROR("Failed to get listen ip, oob server is not start"); + return NN_ERROR; + } + + ip = mListenIP; + return NN_OK; +} + +NResult OOBTCPServer::GetUdsName(std::string &udsName) +{ + if (!mStarted) { + NN_LOG_ERROR("Failed to get uds name, oob server is not start"); + return NN_ERROR; + } + + if (mOobType != NET_OOB_UDS) { + NN_LOG_ERROR("Failed to get uds name, oob server is not uds"); + return NN_ERROR; + } + + udsName = mUdsName; + return NN_OK; +} + +NResult OOBTCPServer::BindAndListenCommon(int socketFD) +{ + struct sockaddr_in addr {}; + bzero(&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = inet_addr(mListenIP.c_str()); + addr.sin_port = htons(mListenPort); + auto ret = ::bind(socketFD, reinterpret_cast(&addr), sizeof(addr)); + if (NN_UNLIKELY(ret < 0)) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to bind on " << mListenIP << ":" << mListenPort << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + NetFunc::NN_SafeCloseFd(socketFD); + return NN_OOB_LISTEN_SOCKET_ERROR; + } + + // listen + if (NN_UNLIKELY(::listen(socketFD, OOB_DEFAULT_LISTEN_BACKLOG) < 0)) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to listen on " << mListenIP << ":" << mListenPort << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + NetFunc::NN_SafeCloseFd(socketFD); + return NN_OOB_LISTEN_SOCKET_ERROR; + } + return NN_OK; +} + +NResult OOBTCPServer::BindAndListenAuto(int &socketFD) +{ + struct sockaddr_in addr {}; + bool isBindAndListenSuccess = false; + // mListenPort is set to mMinListenPort in EnableAutoPortSelection() + auto tmpPort = mListenPort; + while (tmpPort <= mMaxListenPort) { + bzero(&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = inet_addr(mListenIP.c_str()); + addr.sin_port = htons(tmpPort); + auto ret = ::bind(socketFD, reinterpret_cast(&addr), sizeof(addr)); + if (NN_UNLIKELY(ret < 0)) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_DEBUG("Try to bind on " << mListenIP << ":" << tmpPort << " failed, error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + ++tmpPort; + continue; + } + + ret = ::listen(socketFD, OOB_DEFAULT_LISTEN_BACKLOG); + if (NN_LIKELY(ret == 0)) { + isBindAndListenSuccess = true; + break; + } + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_DEBUG("Try to listen on " << mListenIP << ":" << tmpPort << " failed, error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + // bind success but listen failed, reuse socketFD will case invalid argument error(22) + NetFunc::NN_SafeCloseFd(socketFD); + ret = CreateAndConfigSocket(socketFD); + if (NN_UNLIKELY(ret != NN_OK)) { + NN_LOG_ERROR("Recreate socket fd failed"); + return ret; + } + ++tmpPort; + } + + if (!isBindAndListenSuccess) { + NN_LOG_ERROR("Failed to bind and listen on port range [" << mMinListenPort << ", " << mMaxListenPort << "]."); + NetFunc::NN_SafeCloseFd(socketFD); + return NN_OOB_LISTEN_SOCKET_ERROR; + } + mListenPort = tmpPort; + return NN_OK; +} + +NResult OOBTCPServer::CreateAndConfigSocket(int &socketFD) +{ + auto tmpFD = ::socket(AF_INET, SOCK_STREAM, 0); + if (tmpFD < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create listen socket, error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE) << + ", please check if running of fd limit"); + return NN_OOB_LISTEN_SOCKET_ERROR; + } + /* set no-blocking */ + int value = 1; + if (NN_UNLIKELY((value = fcntl(tmpFD, F_GETFL, 0)) == -1)) { + NetFunc::NN_SafeCloseFd(tmpFD); + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to get control value for sock " << mIndex.oobSvrIdx << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_OOB_LISTEN_SOCKET_ERROR; + } + + if (NN_UNLIKELY((value = fcntl(tmpFD, F_SETFL, uint32_t(value) | O_NONBLOCK)) == -1)) { + NetFunc::NN_SafeCloseFd(tmpFD); + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to set control value for sock " << mIndex.oobSvrIdx << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_OOB_LISTEN_SOCKET_ERROR; + } + + // set option + int flags = 1; + int ret = ::setsockopt(tmpFD, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&flags), sizeof(flags)); + if (NN_UNLIKELY(ret < 0)) { + NetFunc::NN_SafeCloseFd(tmpFD); + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to set option, error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_OOB_LISTEN_SOCKET_ERROR; + } + socketFD = tmpFD; + return NN_OK; +} + +NResult OOBTCPServer::CreateAndStartSocket() +{ + int socketFD = 0; + int ret = NN_OK; + + ret = CreateAndConfigSocket(socketFD); + if (NN_UNLIKELY(ret != NN_OK)) { + return ret; + } + + if (mIsAutoPortSelectionEnabled) { + ret = BindAndListenAuto(socketFD); + } else { + ret = BindAndListenCommon(socketFD); + } + + if (NN_LIKELY(ret == NN_OK)) { + mListenFD = socketFD; + } + return ret; +} + +NResult OOBTCPServer::Start() +{ + if (mStarted) { + return NN_OK; + } + + // check new connection cb + if (mNewConnectionHandler == nullptr) { + NN_LOG_ERROR("Failed to start oob server as new connection callback is not set"); + return NN_OOB_CONN_CB_NOT_SET; + } + + // check lb + if ((!enableMultiRail) && (mWorkerLb == nullptr)) { + NN_LOG_ERROR("Failed to start oob server as load balancer is not set"); + return NN_INVALID_PARAM; + } + + if (mOobType == NET_OOB_UDS) { + return StartForUds(); + } + + if (mOobType != NET_OOB_TCP || mListenIP.empty() || mListenPort < NN_NO1024) { + NN_LOG_ERROR("Failed to start oob server as invalid type or listen ip " << mListenIP << " or port " << + mListenPort << ", port range is 1024 ~ 65535)"); + return NN_INVALID_PARAM; + } + + auto ret = CreateAndStartSocket(); + if (NN_UNLIKELY(ret != NN_OK)) { + NN_LOG_ERROR("Failed to create and start oob tcp socket"); + return ret; + } + + // start oob connection cb thread + mEs = NetExecutorService::Create(mNewConnCbThreadNum, mNewConnCbQueueCap); + if (NN_UNLIKELY(mEs == nullptr)) { + NetFunc::NN_SafeCloseFd(mListenFD); + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create oob connection cb thread in oob server, as " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_ERROR; + } + mEs->SetThreadName("OOBTcpConnHdl"); + if (NN_UNLIKELY(!mEs->Start())) { + NetFunc::NN_SafeCloseFd(mListenFD); + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to start oob connection cb thread in oob server, as " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_ERROR; + } + + mThreadStarted.store(false); + + // start oob accept thread + std::thread tmpThread(&OOBTCPServer::RunInThread, this); + mAcceptThread = std::move(tmpThread); + std::string thrName = "OOBTcpSvr" + mIndex.ToString(); + if (pthread_setname_np(mAcceptThread.native_handle(), thrName.c_str()) != 0) { + NN_LOG_WARN("Invalid to set thread name of oob tcp server"); + } + + while (!mThreadStarted.load()) { + usleep(NN_NO128); + } + + mStarted = true; + return NN_OK; +} + +NResult OOBTCPServer::Stop() +{ + if (!mStarted) { + return NN_OK; + } + + mNeedStop = true; + + if (mAcceptThread.joinable()) { + mAcceptThread.join(); + } + + if (mOobType == NET_OOB_UDS && mUdsPerm != 0) { + if (!CanonicalPath(mUdsName)) { + NN_LOG_ERROR("Uds oob file path is invalid"); + return NN_INVALID_PARAM; + } + + if (!NetFunc::NN_CheckFilePrefix(mUdsName)) { + NN_LOG_ERROR("Uds oob file path is invalid as prefix invalid"); + return NN_INVALID_PARAM; + } + unlink(mUdsName.c_str()); + } + + NetFunc::NN_SafeCloseFd(mListenFD); + + mStarted = false; + return NN_OK; +} + +NResult OOBTCPServer::AssignUdsAddress(sockaddr_un &address, socklen_t &addressLen) +{ + if (mUdsPerm == 0) { + address.sun_path[0] = '\0'; /* use abstract namespace */ + if (strcpy_s(address.sun_path + 1, sizeof(address.sun_path) - 1, mUdsName.c_str()) != EOK) { + NN_LOG_ERROR("strcpy_s uds name error."); + return NN_ERROR; + } + addressLen = sizeof(address.sun_family) + 1 + mUdsName.length(); + } else { + size_t index = mUdsName.find_last_of('/'); + if (NN_UNLIKELY(index == std::string::npos)) { + NN_LOG_ERROR("Uds oob file path is invalid"); + return NN_INVALID_PARAM; + } + + std::string udsFilePrefix = mUdsName.substr(0, index + 1); + std::string udsFileName = mUdsName.substr(index + 1, mUdsName.length()); + + if (!NetFunc::NN_CheckFilePrefix(udsFilePrefix)) { + NN_LOG_ERROR("Uds oob file path is invalid as prefix invalid"); + return NN_INVALID_PARAM; + } + if (!CanonicalPath(udsFilePrefix)) { + NN_LOG_ERROR("Uds oob file path is invalid"); + return NN_INVALID_PARAM; + } + + mUdsName = udsFilePrefix + "/" + udsFileName; + + if (::access(mUdsName.c_str(), 0) == 0) { + if (unlink(mUdsName.c_str()) == -1) { + NN_LOG_ERROR("Failed to unlink uds oob file"); + return NN_INVALID_PARAM; + } + } + + int result = 0; + if ((result = strcpy_s(address.sun_path, sizeof(address.sun_path), mUdsName.c_str())) != EOK) { + NN_LOG_ERROR("strcpy_s uds name error. result :" << result); + return NN_ERROR; + } + addressLen = sizeof(address); + } + return NN_OK; +} + +NResult OOBTCPServer::StartForUds() +{ + if (mUdsName.empty()) { + NN_LOG_ERROR("Failed to start oob server as invalid UDS file path"); + return NN_INVALID_PARAM; + } + + if (mUdsName[0] == '/' && mUdsPerm == 0) { + NN_LOG_ERROR( + "Failed to start oob server as invalid UDS file path, first char cannot be '/' for abstract namespace"); + return NN_INVALID_PARAM; + } + + struct sockaddr_un address {}; + socklen_t addressLen = 0; + NN_ASSERT_LOG_RETURN(sizeof(address.sun_path) - 1 > mUdsName.length(), NN_INVALID_PARAM); + bzero(&address, sizeof(address)); + address.sun_family = AF_UNIX; + + auto result = AssignUdsAddress(address, addressLen); + if (NN_UNLIKELY(result != NN_OK)) { + return result; + } + + auto listenFd = ::socket(AF_UNIX, SOCK_STREAM, 0); + if (listenFd < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create listen socket, error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE) << + ", please check if fd is out of limit"); + return NN_OOB_LISTEN_SOCKET_ERROR; + } + + if (::bind(listenFd, reinterpret_cast(&address), addressLen) < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to bind uds, error " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + NetFunc::NN_SafeCloseFd(listenFd); + return NN_OOB_LISTEN_SOCKET_ERROR; + } + + /* To support communication between different users in two containers */ + if (NN_UNLIKELY(mCheckUdsPerm && (mUdsPerm != NN_NO0600) && (mUdsPerm != NN_NO0))) { + NN_LOG_WARN("File permission is incorrect, The file permission must be set to 0600."); + mUdsPerm = NN_NO0600; + } + + chmod(mUdsName.c_str(), mUdsPerm); + + if (::listen(listenFd, NN_NO1024) < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to listen uds, error " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + NetFunc::NN_SafeCloseFd(listenFd); + return NN_OOB_LISTEN_SOCKET_ERROR; + } + + mListenFD = listenFd; + mThreadStarted.store(false); + + // start oob accept thread + std::thread tmpThread(&OOBTCPServer::RunInThread, this); + mAcceptThread = std::move(tmpThread); + std::string thrName = "OOBUdsSvr" + mIndex.ToString(); + if (pthread_setname_np(mAcceptThread.native_handle(), thrName.c_str()) != 0) { + NN_LOG_WARN("Invalid to set thread name of oob uds server"); + } + + while (!mThreadStarted.load()) { + usleep(NN_NO128); + } + + // start oob connection cb thread + mEs = NetExecutorService::Create(mNewConnCbThreadNum, mNewConnCbQueueCap); + if (NN_UNLIKELY(mEs == nullptr)) { + return NN_ERROR; + } + mEs->SetThreadName("OOBUdsConnHdl"); + if (NN_UNLIKELY(!mEs->Start())) { + return NN_ERROR; + } + + mStarted = true; + return NN_OK; +} + +void OOBTCPServer::DealConnectInThread(int fd, struct sockaddr_in addressIn) +{ + ConnectResp resp = ConnectResp::OK; + + char ipStr[INET_ADDRSTRLEN] = {0}; + auto newConnTask = new (std::nothrow) ConnectCbTask(mNewConnectionHandler, fd, mWorkerLb); + if (NN_UNLIKELY(newConnTask == nullptr)) { + resp = ConnectResp::CONN_ACCEPT_NEW_TASK_FAIL; + } else { + if (inet_ntop(AF_INET, &(addressIn.sin_addr), ipStr, INET_ADDRSTRLEN) == nullptr) { + NN_LOG_ERROR("Failed to convert ip number to string"); + delete newConnTask; + resp = SERVER_INTERNAL_ERROR; + } else { + newConnTask->SetIpPort(std::string(ipStr), ntohs(addressIn.sin_port), mListenPort); + if (mOobType == NET_OOB_UDS) { + newConnTask->SetUdsName(mUdsName); + } + if (NN_UNLIKELY(!mEs->Execute(newConnTask))) { + delete newConnTask; + resp = ConnectResp::CONN_ACCEPT_QUEUE_FULL; + NN_LOG_WARN("Failed to execute task may be queue is full, please retry it"); + } + } + } + + if (resp != ConnectResp::OK) { + // if accept success but execute task failed, should notify client connect fail and client will retry + if (::send(fd, &resp, sizeof(ConnectResp), 0) <= 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to send connect resp to peer on oob @ " << std::string(ipStr) << ":" << + ntohs(addressIn.sin_port) << ", as errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + } + } +} + +void OOBTCPServer::RunInThread() +{ + if (mOobType == NET_OOB_TCP) { + NN_LOG_INFO("OOB server accept thread for " << mListenIP << ":" << mListenPort << " started, load balancer " << + (mWorkerLb == nullptr ? "null" : mWorkerLb->ToString())); + } else if (mOobType == NET_OOB_UDS) { + NN_LOG_TRACE_INFO("OOB server accept thread for " << mUdsName << " started, load balancer " << + (mWorkerLb == nullptr ? "null" : mWorkerLb->ToString())); + } else { + NN_LOG_ERROR("Un-reachable path"); + } + + mThreadStarted.store(true); + + struct sockaddr_in addressIn {}; + socklen_t len = sizeof(addressIn); + + int flags = 1; + + auto maxRecvTimeout = NetFunc::NN_GetLongEnv("HCOM_CONNECTION_RECV_TIMEOUT_SEC", NN_NO1, NN_NO7200, NN_NO0); + auto maxSendTimeout = NetFunc::NN_GetLongEnv("HCOM_CONNECTION_SEND_TIMEOUT_SEC", NN_NO1, NN_NO7200, NN_NO0); + + while (NN_UNLIKELY(mEs == nullptr || !mEs->IsStart())) { + usleep(NN_NO100); + } + while (true) { + try { + if (NN_UNLIKELY(mNeedStop)) { + NN_LOG_INFO("Got stop signal, stop listening"); + break; + } + + struct pollfd pollEventFd = {}; + pollEventFd.fd = mListenFD; + pollEventFd.events = POLLIN; + pollEventFd.revents = 0; + + int rc = poll(&pollEventFd, 1, NN_NO500); + if (rc < 0 && errno != EINTR) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Get poll event failed , errno " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + break; + } + + if (rc == 0) { + continue; + } + + bzero(&addressIn, sizeof(struct sockaddr_in)); + auto fd = ::accept(mListenFD, reinterpret_cast(&addressIn), &len); + if (fd < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_WARN("Invalid to accept on new socket with " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE) << ", ignore and continue"); + continue; + } + + // set no delay + setsockopt(fd, SOL_TCP, TCP_NODELAY, reinterpret_cast(&flags), sizeof(flags)); + + /* set recv or send timeout */ + if (maxRecvTimeout != NN_NO0) { + struct timeval recvTimeout = { maxRecvTimeout, 0 }; + setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &recvTimeout, sizeof(timeval)); + } + if (maxSendTimeout != NN_NO0) { + struct timeval sendTimeout = { maxSendTimeout, 0 }; + setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &sendTimeout, sizeof(timeval)); + } + + DealConnectInThread(fd, addressIn); + } catch (std::exception &ex) { + NN_LOG_WARN("Got exception in OOBTCPServer::RunInThread, exception " << ex.what() << + ", ignore and continue"); + } catch (...) { + NN_LOG_WARN("Got unknown exception in OOBTCPServer::RunInThread, ignore and continue"); + } + } + + NN_LOG_INFO("Working thread for OOBTCPServer at " << mListenIP << ":" << mListenPort << " exiting"); +} + + +/* OOBTCPConnection */ +OOBTCPConnection::~OOBTCPConnection() +{ + NetFunc::NN_SafeCloseFd(mFD); +} + +NResult OOBTCPConnection::Send(void *buf, uint32_t size) const +{ + if (NN_UNLIKELY(buf == nullptr)) { + NN_LOG_ERROR("Failed to send as buf is nullptr"); + return NN_PARAM_INVALID; + } + + const unsigned char *p = static_cast(buf); + while (size > 0) { + const ssize_t result = ::send(mFD, p, size, 0); + if (result == -1) { + if (errno == EINTR) { + continue; + } else { + // Since mFD is blocking, EAGAIN/EWOULDBLOCK won't be there. + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR( + "Failed to send data to peer on oob @ " + << mIpAndPort << ", as errno:" << errno << " error:" + << NetFunc::NN_GetStrError(errno, errBuf, + NET_STR_ERROR_BUF_SIZE)); + return NN_OOB_CONN_SEND_ERROR; + } + } else if (result == 0) { + NN_LOG_ERROR("Failed to send data to peer on oob @ " + << mIpAndPort << ", reset by peer"); + return NN_OOB_CONN_SEND_ERROR; + } + + p += result; + size -= static_cast(result); + } + + return NN_OK; +} + +NResult OOBTCPConnection::Receive(void *buf, uint32_t size) const +{ + if (NN_UNLIKELY(buf == nullptr)) { + NN_LOG_ERROR("Failed to recv as buf is nullptr"); + return NN_PARAM_INVALID; + } + + unsigned char *p = static_cast(buf); + while (size > 0) { + const ssize_t result = ::recv(mFD, p, size, 0); + if (result == -1) { + if (errno == EINTR) { + continue; + } else { + // Since mFD is blocking, EAGAIN/EWOULDBLOCK won't be there. + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR( + "Failed to receive data from peer on oob @ " + << mIpAndPort << ", as errno:" << errno << " error:" + << NetFunc::NN_GetStrError(errno, errBuf, + NET_STR_ERROR_BUF_SIZE)); + return NN_OOB_CONN_RECEIVE_ERROR; + } + } else if (result == 0) { + NN_LOG_ERROR("Failed to receive data from peer on oob @ " << mIpAndPort << ", peer fd closed"); + return NN_OOB_CONN_RECEIVE_ERROR; + } + + p += result; + size -= static_cast(result); + } + + return NN_OK; +} + +NResult OOBTCPConnection::SendMsg(msghdr msg, uint32_t size) const +{ + auto result = ::sendmsg(mFD, &msg, 0); + if (NN_LIKELY(result == size)) { + return NN_OK; + } else if (result <= 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to send msg to peer " << mIpAndPort << " result:" << result << ", as " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_ERROR; + } else { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed send msg to pee, the size is un-matched required size " << sizeof(msg) << ", send size " << + result << ", or connection error, errno " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_ERROR; + } + return NN_OK; +} + + +NResult OOBTCPConnection::ReceiveMsg(msghdr msg, uint32_t size) const +{ + auto result = ::recvmsg(mFD, &msg, 0); + if (NN_LIKELY(result == size)) { + return NN_OK; + } else if (result <= 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to receive msg from peer on oob" << mIpAndPort << ", as " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_ERROR; + } else { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to receive data from peer, the size is un-matched required size " << sizeof(msg) << + ", recv size " << result << ", or connection error, errno " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_ERROR; + } + return NN_OK; +} + +/* OOBTCPClient */ +NResult OOBTCPClient::Connect(const std::string &ip, uint32_t port, OOBTCPConnection *&conn) +{ + int fd = -1; + auto result = ConnectWithFd(ip, port, fd); + if (result != NN_OK) { + return result; + } + + conn = new (std::nothrow) OOBTCPConnection(fd); + if (NN_UNLIKELY(conn == nullptr)) { + NN_LOG_ERROR("Failed to new oob connection, probably out of memory"); + NetFunc::NN_SafeCloseFd(fd); + return NN_NEW_OBJECT_FAILED; + } + + conn->ListenPort(port); + return NN_OK; +} + +NResult OOBTCPClient::ConnectWithFd(const std::string &ip, uint32_t port, int &fd) +{ + auto tmpFD = ::socket(AF_INET, SOCK_STREAM, 0); + if (tmpFD < 0) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create listen socket, errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE) << ", please check if fd is out of limit"); + return NN_OOB_CLIENT_SOCKET_ERROR; + } + + int flags = 1; + setsockopt(tmpFD, SOL_TCP, TCP_NODELAY, reinterpret_cast(&flags), sizeof(flags)); + int synCnt = 1; /* Set connect() retry time for quick connect */ + setsockopt(tmpFD, IPPROTO_TCP, TCP_SYNCNT, &synCnt, sizeof(synCnt)); + + auto ipAddr = inet_addr(ip.c_str()); + if (ipAddr == INADDR_NONE) { + NN_LOG_ERROR("Failed to connect because ip is error. "); + NetFunc::NN_SafeCloseFd(tmpFD); + return NN_INVALID_IP; + } + + struct sockaddr_in addr {}; + bzero(&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = inet_addr(ip.c_str()); + addr.sin_port = htons(port); + + uint32_t timesRetried = 0; + long maxConnRetryTimes = NN_NO5; + long maxConnRetryInterval = NN_NO20; + ConfigureSocketTimeouts(tmpFD, maxConnRetryTimes, maxConnRetryInterval); + + ssize_t result = -1; + ConnectState state = ConnectState::DISCONNECTED; + ConnectResp connectStatus = ConnectResp::OK; + while (timesRetried < maxConnRetryTimes) { + switch (state) { + case ConnectState::DISCONNECTED: + NN_LOG_INFO("Trying to connect to " << ip << ":" << port); + + // 指数回退, nop, 2s, 4s, 8s, ... + if (timesRetried != 0) { + sleep((1 << timesRetried) > maxConnRetryInterval ? maxConnRetryInterval : (1 << timesRetried)); + } + + if (::connect(tmpFD, reinterpret_cast(&addr), sizeof(addr)) == 0) { + state = ConnectState::CONNECTED; + continue; + } else { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Trying to connect to " + << ip << ":" << port << " errno:" << errno + << " error:" << NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE) + << " retry times:" << timesRetried); + } + break; + + case ConnectState::CONNECTED: + result = ::recv(tmpFD, &connectStatus, sizeof(ConnectResp), 0); + if (result <= 0 || connectStatus != ConnectResp::OK) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to receive connection status from peer on oob, as result:" + << result << " errno:" << errno + << " error:" << NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE) + << " connTaskStatus:" << connectStatus); + } else { + fd = tmpFD; + NN_LOG_INFO("Connect to " << ip << ":" << port << " successfully"); + return NN_OK; + } + break; + } + + timesRetried++; + } + + NetFunc::NN_SafeCloseFd(tmpFD); + NN_LOG_ERROR("Failed to connect to " << ip << ":" << port << " after tried " << timesRetried << " times"); + return NN_OOB_CLIENT_SOCKET_ERROR; +} + +NResult OOBTCPClient::Connect(const std::string &udsName, OOBTCPConnection *&conn) +{ + int fd = -1; + auto result = ConnectWithFd(udsName, fd); + if (result != NN_OK) { + return result; + } + + conn = new (std::nothrow) OOBTCPConnection(fd); + if (NN_UNLIKELY(conn == nullptr)) { + NN_LOG_ERROR("Failed to new oob connection, probably out of memory"); + NetFunc::NN_SafeCloseFd(fd); + return NN_NEW_OBJECT_FAILED; + } + + conn->mIsUds = true; + + return NN_OK; +} + +NResult OOBTCPClient::ConnectWithFd(const std::string &filename, int &fd) +{ + if (filename.empty()) { + NN_LOG_ERROR("Invalid name or file path to connect for uds, which is empty"); + return NN_OOB_CLIENT_SOCKET_ERROR; + } + + struct sockaddr_un address {}; + socklen_t addressLen = 0; + NN_ASSERT_LOG_RETURN(sizeof(address.sun_path) - 1 > filename.length(), NN_INVALID_PARAM); + + bzero(&address, sizeof(address)); + address.sun_family = AF_UNIX; + + bool abstractNs = (filename[0] != '/'); + if (abstractNs) { + address.sun_path[0] = '\0'; /* use abstract namespace */ + if (strcpy_s(address.sun_path + 1, sizeof(address.sun_path) - 1, filename.c_str()) != EOK) { + NN_LOG_ERROR("strcpy_s filename error."); + return NN_ERROR; + } + addressLen = sizeof(address.sun_family) + 1 + filename.length(); + } else { + if (!CanonicalPath(const_cast(filename))) { + NN_LOG_ERROR("Uds oob file path is invalid"); + return NN_INVALID_PARAM; + } + + if (strcpy_s(address.sun_path, sizeof(address.sun_path), filename.c_str()) != EOK) { + NN_LOG_ERROR("strcpy_s filename error."); + return NN_ERROR; + } + addressLen = sizeof(address); + } + + auto tmpFD = ::socket(AF_UNIX, SOCK_STREAM, 0); + if (tmpFD < 0) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create listen socket, errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE) << ", please check if fd is out of limit"); + return NN_OOB_CLIENT_SOCKET_ERROR; + } + int synCnt = 1; /* Set connect() retry time for quick connect */ + setsockopt(tmpFD, IPPROTO_TCP, TCP_SYNCNT, &synCnt, sizeof(synCnt)); + + uint32_t timesRetried = 0; + long maxConnRetryTimes = NN_NO5; + long maxConnRetryInterval = NN_NO20; + ConfigureSocketTimeouts(tmpFD, maxConnRetryTimes, maxConnRetryInterval); + + while (timesRetried < maxConnRetryTimes) { + NN_LOG_INFO("Trying to connect to " << filename); + if (::connect(tmpFD, reinterpret_cast(&address), addressLen) == 0) { + ConnectResp connectStatus = ConnectResp::OK; + ssize_t result = ::recv(tmpFD, &connectStatus, sizeof(ConnectResp), 0); + if (result <= 0 || connectStatus != ConnectResp::OK) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to receive connection status from peer on oob, as result:" << result << + " errno:" << errno << " error:" << NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE) << + " connTaskStatus:" << connectStatus); + } else { + fd = tmpFD; + NN_LOG_INFO("Connect to " << filename << " successfully"); + return NN_OK; + } + } + + if (errno == EINTR) { + continue; + } + + sleep(1 << timesRetried > maxConnRetryInterval ? maxConnRetryInterval : 1 << timesRetried); + timesRetried++; + + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Trying to connect to " << filename << " errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE) << " retry times:" << timesRetried); + } + + NetFunc::NN_SafeCloseFd(tmpFD); + NN_LOG_ERROR("Failed to connect to " << filename << " after tried " << timesRetried << " times"); + return NN_OOB_CLIENT_SOCKET_ERROR; +} + +void OOBTCPClient::ConfigureSocketTimeouts(int &tmpFD, long &maxConnRetryTimes, long &maxConnRetryInterval) +{ + maxConnRetryTimes = NetFunc::NN_GetLongEnv("HCOM_CONNECTION_RETRY_TIMES", NN_NO1, NN_NO10, NN_NO5); + maxConnRetryInterval = NetFunc::NN_GetLongEnv("HCOM_CONNECTION_RETRY_INTERVAL_SEC", NN_NO1, NN_NO60, NN_NO20); + auto maxRecvTimeout = NetFunc::NN_GetLongEnv("HCOM_CONNECTION_RECV_TIMEOUT_SEC", NN_NO1, NN_NO7200, NN_NO0); + auto maxSendTimeout = NetFunc::NN_GetLongEnv("HCOM_CONNECTION_SEND_TIMEOUT_SEC", NN_NO1, NN_NO7200, NN_NO0); + if (maxRecvTimeout != NN_NO0) { + struct timeval recvTimeout = { maxRecvTimeout, 0 }; + setsockopt(tmpFD, SOL_SOCKET, SO_RCVTIMEO, &recvTimeout, sizeof(timeval)); + } + if (maxSendTimeout != NN_NO0) { + struct timeval sendTimeout = { maxSendTimeout, 0 }; + setsockopt(tmpFD, SOL_SOCKET, SO_SNDTIMEO, &sendTimeout, sizeof(timeval)); + } +} +} +} diff --git a/src/transport/net_oob.h b/src/transport/net_oob.h new file mode 100644 index 0000000000000000000000000000000000000000..31d7e3b2159a1c6d8822507745a0db61606f3093 --- /dev/null +++ b/src/transport/net_oob.h @@ -0,0 +1,566 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_OOB_1233432457233_H +#define OCK_HCOM_OOB_1233432457233_H + +#include +#include +#include +#include +#include +#include + +#include "hcom.h" +#include "hcom_def.h" +#include "net_common.h" +#include "net_execution_service.h" +#include "net_load_balance.h" +#include "net_monotonic.h" +#include "net_util.h" +#include "rdma_verbs_wrapper_qp.h" +#include "securec.h" + +namespace ock { +namespace hcom { +constexpr uint64_t MAX_CB_TIME_US = NN_NO1000000; // 1s +constexpr uint32_t DEFAULT_CONN_THREAD_NUM = NN_NO2; +constexpr uint32_t DEFAULT_CONN_THREAD_QUEUE_CAP = NN_NO4096; +union ConnectHeader { + struct { + uint64_t magic : 16; + uint64_t version : 8; + uint64_t groupIndex : 8; + uint64_t protocol : 8; + uint64_t bandWidth : 8; + uint64_t devIndex : 8; + uint64_t majorVersion : 8; + uint64_t minorVersion : 8; + uint64_t tlsVersion : 16; + uint64_t reserve : 40; + }; + uint64_t wholeHeader[2] = {0}; +}; + +inline void SetConnHeader(ConnectHeader &h, uint32_t magic, uint32_t version, uint32_t groupIndex, uint32_t protocol, + uint32_t majorVersion, uint32_t minorVersion, uint32_t tlsVersion) +{ + h.magic = magic; + h.version = version; + h.groupIndex = groupIndex; + h.protocol = protocol; + h.majorVersion = majorVersion; + h.minorVersion = minorVersion; + h.tlsVersion = tlsVersion; +} + + +inline void SetDriverConnHeader(ConnectHeader &h, uint8_t bandWidth, uint8_t devIndex) +{ + h.bandWidth = bandWidth; + h.devIndex = devIndex; +} + +enum class ConnectState : int8_t { + DISCONNECTED, + CONNECTED, +}; + +/* + * @brief oob connection response + * + * =0 means no error + * >1 means no error and use this protocol for further processing + * <0 means error + */ +enum ConnectResp : int16_t { + OK_PROTOCOL_TCP = 2, /* tell client using tcp socket to connect real worker */ + OK_PROTOCOL_UDS = 1, /* tell client using uds to connect real worker */ + OK = 0, + MAGIC_MISMATCH = -1, + VERSION_MISMATCH = -2, + WORKER_GRPNO_MISMATCH = -3, + WORKER_NOT_STARTED = -4, + PROTOCOL_MISMATCH = -5, + SERVER_INTERNAL_ERROR = -6, + CONN_ACCEPT_NEW_TASK_FAIL = -7, + CONN_ACCEPT_QUEUE_FULL = -8, + SEC_VALID_FAILED = -9, + TLS_VERSION_MISMATCH = -10, +}; + +struct ConnRespWithUId { + ConnectResp connResp = OK; + uint64_t epId = 0; + + ConnRespWithUId() = default; + + ConnRespWithUId(ConnectResp resp, uint64_t uid) : connResp(resp), epId(uid) {} + + std::string ToString() const + { + std::ostringstream oss; + oss << "connResp = " << std::to_string(connResp) << ", epId = " << epId; + return oss.str(); + } +} __attribute__((packed)); + +struct ConnSecHeader { + int64_t flag = 0; + uint64_t ctx = 0; + uint32_t secInfoLen = 0; + uint8_t type = 0; + + ConnSecHeader() = default; + ConnSecHeader(int64_t flag, uint64_t ctx, uint32_t len, uint8_t type) + : flag(flag), ctx(ctx), secInfoLen(len), type(type){}; +}; + +struct OOBServerIndex { + uint8_t driverIdx = 0; + uint16_t oobSvrIdx = 0; + + OOBServerIndex() = default; + + OOBServerIndex(uint8_t driverIndex, uint16_t oobIndex) : driverIdx(driverIndex), oobSvrIdx(oobIndex) {} + + std::string ToString() const + { + std::ostringstream oss; + oss << std::to_string(driverIdx) << "-" << oobSvrIdx; + return oss.str(); + } +}; + +class OOBTCPConnection; + +class OOBTCPServer { +public: + using NewConnectionHandler = std::function; + + OOBTCPServer(const std::string &ip, uint16_t port) : OOBTCPServer(NET_OOB_TCP, ip, port) {} + + OOBTCPServer(NetDriverOobType t, const std::string &ipOrName, uint16_t portOrPerm) : mOobType(t) + { + if (mOobType == NET_OOB_TCP) { + mListenIP = ipOrName; + mListenPort = portOrPerm; + } else if (mOobType == NET_OOB_UDS) { + mUdsName = ipOrName; + mUdsPerm = portOrPerm; + } + } + + OOBTCPServer(NetDriverOobType t, const std::string &ipOrName, uint16_t portOrPerm, bool isCheck) : mOobType(t) + { + if (mOobType == NET_OOB_TCP) { + mListenIP = ipOrName; + mListenPort = portOrPerm; + } else if (mOobType == NET_OOB_UDS) { + mUdsName = ipOrName; + mUdsPerm = portOrPerm; + mCheckUdsPerm = isCheck; + } + } + + virtual ~OOBTCPServer() + { + (void)Stop(); + + if (mEs != nullptr) { + mEs->Stop(); + } + + if (mWorkerLb != nullptr) { + mWorkerLb->DecreaseRef(); + mWorkerLb = nullptr; + } + } + + inline void SetNewConnCB(const NewConnectionHandler &handler) + { + mNewConnectionHandler = handler; + } + + inline void SetWorkerLb(NetWorkerLB *lb) + { + if (lb != nullptr) { + mWorkerLb = lb; + mWorkerLb->IncreaseRef(); + } + } + + inline void SetMultiRail(bool flags) + { + enableMultiRail = flags; + } + + inline void SetNewConnCbThreadNum(uint16_t threadNum) + { + mNewConnCbThreadNum = threadNum; + } + + inline void SetNewConnCbQueueCap(uint32_t queueCap) + { + mNewConnCbQueueCap = queueCap; + } + + NResult EnableAutoPortSelection(uint16_t minPort, uint16_t maxPort); + NResult GetListenPort(uint16_t &port); + NResult GetListenIp(std::string &ip); + NResult GetUdsName(std::string &udsName); + + NResult Start(); + NResult Stop(); + + inline NetDriverOobType OobType() const + { + return mOobType; + } + + inline void Index(const OOBServerIndex &value) + { + mIndex = value; + } + + inline NResult CompareEpNum(const std::string &ip) + { + auto iter = mIpEpNumberMap.find(ip); + if (iter == mIpEpNumberMap.end()) { + return NN_OK; + } + + if (iter->second >= mMaxConnectionNum) { + return NN_ERROR; + } + + return NN_OK; + } + + inline void AddEpNum(const std::string &ip) + { + std::lock_guard guard(mEpNumMutex); + auto iter = mIpEpNumberMap.find(ip); + if (iter == mIpEpNumberMap.end()) { + mIpEpNumberMap[ip] = 1; + } else { + mIpEpNumberMap[ip] = mIpEpNumberMap[ip] + 1; + } + } + + inline void DelEpNum(const std::string &ip) + { + std::lock_guard guard(mEpNumMutex); + auto iter = mIpEpNumberMap.find(ip); + if (iter != mIpEpNumberMap.end()) { + mIpEpNumberMap[ip] = mIpEpNumberMap[ip] - 1; + + if (iter->second == 0) { + mIpEpNumberMap.erase(ip); + } + } + } + + inline void SetMaxConntionNum(uint32_t maxConnectionNum) + { + mMaxConnectionNum = maxConnectionNum; + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +protected: + virtual void RunInThread(); + + NResult StartForUds(); + + virtual void DealConnectInThread(int fd, struct sockaddr_in addressIn); + +protected: + NetDriverOobType mOobType = NET_OOB_TCP; /* listen type TCP or UDS */ + std::string mListenIP; /* listen ip for tcp listener */ + uint16_t mListenPort = OOB_DEFAULT_LISTEN_PORT; /* listen port for tcp listener */ + bool mIsAutoPortSelectionEnabled = false; /* whether auto port selection is enabled or not, for tcp only */ + uint16_t mMinListenPort = 0; /* min port number when enable auto port selection */ + uint16_t mMaxListenPort = 0; /* max port number when enable auto port selection */ + std::string mUdsName; /* listen name of UDS listener */ + uint16_t mUdsPerm = 0; /* perm of uds file, if 0 means don't use file */ + bool mCheckUdsPerm = true; /* whether to verify the permission on the UDS file */ + + OOBServerIndex mIndex {}; + std::thread mAcceptThread; + bool mStarted = false; + std::atomic mThreadStarted { false }; + volatile bool mNeedStop = false; + int mListenFD = -1; + + NewConnectionHandler mNewConnectionHandler = nullptr; + NetWorkerLB *mWorkerLb = nullptr; + NetExecutorServicePtr mEs; + uint16_t mNewConnCbThreadNum = DEFAULT_CONN_THREAD_NUM; + uint32_t mNewConnCbQueueCap = DEFAULT_CONN_THREAD_QUEUE_CAP; + uint32_t mMaxConnectionNum = NN_NO250; + std::mutex mEpNumMutex; + std::map mIpEpNumberMap; + bool enableMultiRail = false; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + +private: + NResult AssignUdsAddress(sockaddr_un &address, socklen_t &addressLen); + NResult CreateAndStartSocket(); + NResult CreateAndConfigSocket(int &socketFD); + NResult BindAndListenCommon(int socketFD); + NResult BindAndListenAuto(int &socketFD); +}; +using NetOOBServerPtr = NetRef; + +class OOBTCPClient { +public: + OOBTCPClient(const std::string &ip, uint32_t port) : OOBTCPClient(NET_OOB_TCP, ip, port) {} + + OOBTCPClient(NetDriverOobType t, const std::string &ipOrName, uint32_t port) : mOobType(t) + { + if (mOobType == NET_OOB_TCP) { + mServerIP = ipOrName; + mServerPort = port; + } else if (mOobType == NET_OOB_UDS) { + mServerUdsName = ipOrName; + } + } + + virtual ~OOBTCPClient() = default; + + virtual inline NResult Connect(OOBTCPConnection *&conn) + { + if (mOobType == NET_OOB_TCP) { + return Connect(mServerIP, mServerPort, conn); + } else if (mOobType == NET_OOB_UDS) { + return Connect(mServerUdsName, conn); + } + + return NN_ERROR; + } + + inline const std::string& GetServerIp() const + { + return mServerIP; + } + + inline uint32_t GetServerPort() const + { + return mServerPort; + } + + inline const std::string& GetServerUdsName() const + { + return mServerUdsName; + } + + inline NetDriverOobType GetOobType() const + { + return mOobType; + } + + /* + * @brief for tcp + */ + virtual NResult Connect(const std::string &ip, uint32_t port, OOBTCPConnection *&conn); + static NResult ConnectWithFd(const std::string &ip, uint32_t port, int &fd); + + /* + * @brief for uds + */ + virtual NResult Connect(const std::string &udsName, OOBTCPConnection *&); + static NResult ConnectWithFd(const std::string &filename, int &fd); + + DEFINE_RDMA_REF_COUNT_FUNCTIONS +protected: + NetDriverOobType mOobType = NET_OOB_TCP; + std::string mServerIP; + uint32_t mServerPort = OOB_DEFAULT_LISTEN_PORT; + std::string mServerUdsName; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +private: + static void ConfigureSocketTimeouts(int &tmpFD, long &maxConnRetryTimes, long &maxConnRetryInterval); +}; +using OOBTCPClientPtr = NetRef; + +class OOBTCPConnection { +public: + explicit OOBTCPConnection(int fd) : mFD(fd) {} + virtual ~OOBTCPConnection(); + + virtual NResult Send(void *buf, uint32_t size) const; + virtual NResult Receive(void *buf, uint32_t size) const; + + NResult SendMsg(msghdr msg, uint32_t size) const; + NResult ReceiveMsg(msghdr msg, uint32_t size) const; + + inline void SetIpAndPort(const std::string &ip, uint32_t port) + { + mIpAndPort = ip + ":" + std::to_string(port); + } + + inline const std::string &GetIpAndPort() const + { + return mIpAndPort; + } + + inline void ListenPort(uint32_t port) + { + mListenPort = port; + } + + inline uint32_t ListenPort() const + { + return mListenPort; + } + + inline void SetUdsName(std::string udsName) + { + mUdsName = udsName; + } + + inline const std::string &GetUdsName() const + { + return mUdsName; + } + + inline void LoadBalancer(const NetWorkerLBPtr &lb) + { + mLb = lb; + } + + inline const NetWorkerLBPtr &LoadBalancer() const + { + return mLb; + } + + inline bool IsUDS() const + { + return mIsUds; + } + + /* + * @brief transfer this oob tcp connection to real connection + */ + inline int TransferFd() + { + auto tmp = mFD; + mFD = -1; + return tmp; + } + + inline int GetFd() const + { + return mFD; + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +protected: + int mFD = -1; + uint32_t mListenPort = 0; + std::string mIpAndPort; + NetWorkerLBPtr mLb = nullptr; + bool mIsUds = false; + std::string mUdsName; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend class OOBTCPClient; + friend class OOBTCPServer; + friend class OOBSSLServer; + friend class OOBSSLClient; + friend class ConnectCbTask; + friend class TlsConnectCbTask; +}; + +class ConnectCbTask : public NetRunnable { +public: + using NewConnectionHandler = std::function; + + ConnectCbTask(const NewConnectionHandler &cb, int fd, const NetWorkerLBPtr &workerLb) + : mNewConnectionHandler(cb), mFd(fd), mWorkerLb(workerLb) + {} + + void SetIpPort(const std::string &clientIp, uint32_t clientPort, uint32_t serverPort) + { + mClientIP = clientIp; + mClientPort = clientPort; + mListenPort = serverPort; + } + + void SetUdsName(const std::string &udsName) + { + mUdsName = udsName; + } + + ~ConnectCbTask() override + { + NetFunc::NN_SafeCloseFd(mFd); + } + + void Run() override + { + ConnectResp resp = ConnectResp::OK; + if (::send(mFd, &resp, sizeof(ConnectResp), 0) <= 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to send connect status to peer on oob @ " << mClientIP << ":" << mClientPort << + ", as " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return; + } + + // ConnectCbTask holds and is responsible for closing fd. + // At the end of the execution, OOBTCPConnection returns fd to ConnectCbTask. + OOBTCPConnection conn(mFd); + conn.SetIpAndPort(mClientIP, mClientPort); + conn.ListenPort(mListenPort); + conn.LoadBalancer(mWorkerLb); + conn.SetUdsName(mUdsName); + + if (NN_UNLIKELY(mNewConnectionHandler == nullptr)) { + NN_LOG_ERROR("Failed to handshake and exchange address as new connection handler is null"); + return; + } + + auto startConnCb = NetMonotonic::TimeUs(); + auto result = mNewConnectionHandler(conn); + if (result != 0) { + mFd = conn.TransferFd(); + NN_LOG_ERROR("Failed to handshake and exchange address with client " << conn.GetIpAndPort() << ",result:" << + result << " continue to accept future connection"); + return; + } + auto endConnCb = NetMonotonic::TimeUs(); + auto cbTime = endConnCb - startConnCb; + if (NN_UNLIKELY(cbTime > MAX_CB_TIME_US)) { + NN_LOG_WARN("Call new Connection Cb time is too long: " << cbTime << " us."); + } + /* the socket could be transfer to real connection when type is socket */ + mFd = conn.TransferFd(); + } + +protected: + NewConnectionHandler mNewConnectionHandler = nullptr; /* new connection handler */ + int mFd = -1; /* new oob connection file descriptor */ + std::string mClientIP; /* ip of connector */ + uint32_t mClientPort = 0; /* port of connector */ + uint32_t mListenPort = 0; /* listener port */ + std::string mUdsName; + NetWorkerLBPtr mWorkerLb = nullptr; /* load balancer of worker */ +}; + +} +} + +#endif // OCK_HCOM_OOB_1233432457233_H diff --git a/src/transport/net_oob_openssl.cpp b/src/transport/net_oob_openssl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7688025fa5cd0a2b9cf8cc9f7594fad2c1c533fa --- /dev/null +++ b/src/transport/net_oob_openssl.cpp @@ -0,0 +1,386 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openssl_api_wrapper.h" +#include "net_oob_ssl.h" +#include "net_oob_openssl.h" + +namespace ock { +namespace hcom { +#define OOB_SSL_LAYER_CHECK_RET(_condition, _msg) \ + do { \ + if (_condition) { \ + NN_LOG_ERROR(_msg); \ + return NN_OOB_SSL_INIT_ERROR; \ + } \ + } while (0) + +#define OOB_SSL_LAYER_CHECK_RET_ERASE_RET(_cond, _msg) \ + do { \ + if (_cond) { \ + NN_LOG_ERROR(_msg); \ + if (erase) { \ + erase(keyPass, passLen); \ + } \ + return NN_OOB_SSL_INIT_ERROR; \ + } \ + } while (0) + +UBSHcomPskFindSessionCb OOBOpenSSLConnection::mOpenSslPskFindSessionCb = nullptr; +UBSHcomPskUseSessionCb OOBOpenSSLConnection::mOpenSslPskUseSessionCb = nullptr; + +OOBOpenSSLConnection::~OOBOpenSSLConnection() +{ + if (mSsl != nullptr) { + HcomSsl::SslShutdown(mSsl); + HcomSsl::SslFree(mSsl); + mSsl = nullptr; + } + if (mSslCtx != nullptr) { + HcomSsl::SslCtxFree(mSslCtx); + mSslCtx = nullptr; + } +} + +SSL *OOBOpenSSLConnection::TransferSsl() +{ + auto tmp = mSsl; + mSsl = nullptr; + return tmp; +} + +/* OOBOpenSSLConnection */ +NResult OOBOpenSSLConnection::Send(void *buf, uint32_t size) const +{ + if (NN_UNLIKELY(buf == nullptr) || NN_UNLIKELY(size == 0) || NN_UNLIKELY(mSsl == nullptr) || + NN_UNLIKELY(size > INT_MAX)) { + NN_LOG_ERROR("Invalid param for TLS send"); + return NN_PARAM_INVALID; + } + + int len = static_cast(size); + while (len > 0) { + int ret = HcomSsl::SslWrite(mSsl, reinterpret_cast(reinterpret_cast(buf) + size - len), len); + if (ret <= 0) { + int sslErrCode = HcomSsl::SslGetError(mSsl, ret); + NN_LOG_ERROR("Failed to write data to TLS channel, ret: " << ret << ", errno: " << sslErrCode << + " write Len: " << size); + return NN_OOB_SSL_WRITE_ERROR; + } + len -= ret; + } + return NN_OK; +} + +NResult OOBOpenSSLConnection::Receive(void *buf, uint32_t size) const +{ + if (NN_UNLIKELY(buf == nullptr) || NN_UNLIKELY(size == 0) || NN_UNLIKELY(mSsl == nullptr) || + NN_UNLIKELY(size > INT_MAX)) { + NN_LOG_ERROR("Invalid param for TLS receive"); + return NN_PARAM_INVALID; + } + + int len = static_cast(size); + while (len > 0) { + int ret = HcomSsl::SslRead(mSsl, reinterpret_cast(reinterpret_cast(buf) + size - len), len); + if (ret <= 0) { + int sslErrCode = HcomSsl::SslGetError(mSsl, ret); + NN_LOG_ERROR("Failed to read data from TLS channel, ret: " << ret << ", errno: " << sslErrCode << + ", read Len: " << len); + return NN_OOB_SSL_READ_ERROR; + } + len -= ret; + } + return NN_OK; +} + +int OOBOpenSSLConnection::DefaultSslCertVerify(X509_STORE_CTX *x509ctx, const char *arg) +{ + auto crlPath = arg; + const int checkSuccess = 1; + const int checkFailed = -1; + + if (crlPath != nullptr && strlen(crlPath) != 0) { + X509_CRL *crl = LoadCertRevokeListFile(crlPath); + if (crl == nullptr) { + NN_LOG_ERROR("Failed to load cert revocation list"); + return checkFailed; + } + X509_STORE *x509Store = HcomSsl::X509StoreCtxGet0Store(x509ctx); + HcomSsl::X509StoreCtxSetFlags(x509ctx, (unsigned long)HcomSsl::X509_V_FLAG_CRL_CHECK); + auto result = HcomSsl::X509StoreAddCrl(x509Store, crl); + if (result != NN_NO1) { + NN_LOG_INFO("Store add crl failed ret:" << result); + HcomSsl::X509CrlFree(crl); + return checkFailed; + } + HcomSsl::X509CrlFree(crl); + } + + auto verifyResult = HcomSsl::X509VerifyCert(x509ctx); + if (verifyResult != NN_NO1) { + NN_LOG_INFO("Verify failed in callback" + << " error: " << HcomSsl::X509VerifyCertErrorString(HcomSsl::X509StoreCtxGetError(x509ctx))); + return checkFailed; + } + return checkSuccess; +} + +int OOBOpenSSLConnection::CaCallbackWrapper(X509_STORE_CTX *x509ctx, void *arg) +{ + if (x509ctx == nullptr || arg == nullptr) { + return 0; + } + const int checkSuccess = 1; + const int checkFailed = -1; + + auto conn = reinterpret_cast(arg); + int ret = -1; + if (conn->mPeerCertVerifyType == VERIFY_BY_CUSTOM_FUNC) { + if (conn->mCertVerifyCallback == nullptr) { + NN_LOG_ERROR("Cert verification failed for cert verify in callback is null."); + return checkFailed; + } + ret = conn->mCertVerifyCallback(x509ctx, conn->mCrlPath.c_str()); + } else { + ret = conn->DefaultSslCertVerify(x509ctx, conn->mCrlPath.c_str()); + } + if (ret < 0) { + NN_LOG_ERROR("Cert verification failed, please check or set valid cert."); + return checkFailed; + } + return checkSuccess; +} + +int OOBOpenSSLConnection::PskFindCallbackWrapper(SSL *ssl, const unsigned char *identity, size_t identity_len, + SSL_SESSION **sess) +{ + return mOpenSslPskFindSessionCb(ssl, identity, identity_len, reinterpret_cast(sess)); +} + +int OOBOpenSSLConnection::PskUseCallbackWrapper(SSL *ssl, const EVP_MD *md, const unsigned char **id, size_t *idlen, + SSL_SESSION **sess) +{ + return mOpenSslPskUseSessionCb(ssl, md, id, idlen, reinterpret_cast(sess)); +} + +NResult OOBOpenSSLConnection::SetPSKCallback(bool isServer) +{ + if (isServer) { + if (mPskFindSessionCb == nullptr) { + NN_LOG_WARN("Callback for psk find session is not set at server"); + return NN_OK; + } + mOpenSslPskFindSessionCb = mPskFindSessionCb; + HcomSsl::SslCtxSetPskFindSessionCallback(mSslCtx, &PskFindCallbackWrapper); + } else { + if (mPskUseSessionCb == nullptr) { + NN_LOG_WARN("Callback for psk use session is not set at client"); + return NN_OK; + } + mOpenSslPskUseSessionCb = mPskUseSessionCb; + HcomSsl::SslCtxSetPskUseSessionCallback(mSslCtx, &PskUseCallbackWrapper); + } + return NN_OK; +} + +NResult OOBOpenSSLConnection::InitSSL(bool isServer) +{ + /* SSL_library_init() */ + auto ret = HcomSsl::OpensslInitSsl(0, nullptr); + OOB_SSL_LAYER_CHECK_RET((ret <= 0), "Failed to load openssl library"); + + /* SSL_load_error_strings() */ + ret = HcomSsl::OpensslInitSsl(HcomSsl::OPENSSL_INIT_LOAD_SSL_STRINGS | HcomSsl::OPENSSL_INIT_LOAD_CRYPTO_STRINGS, + nullptr); + OOB_SSL_LAYER_CHECK_RET((ret <= 0), "Failed to initialize openssl library"); + + if (isServer) { + OOB_SSL_LAYER_CHECK_RET((mCertCallback == nullptr || mKeyCallback == nullptr) && mPskFindSessionCb == nullptr, + "Both callback for cert and callback for find psk is not set at server"); + mSslCtx = HcomSsl::SslCtxNew(HcomSsl::TlsServerMethod()); + } else { + OOB_SSL_LAYER_CHECK_RET(mCaCallback == nullptr && mPskUseSessionCb == nullptr, + "Both callback for cert and callback for find psk is not set at client"); + mSslCtx = HcomSsl::SslCtxNew(HcomSsl::TlsClientMethod()); + } + OOB_SSL_LAYER_CHECK_RET(mSslCtx == nullptr, "SslCtxNew() failed"); + + HcomSsl::SslCtxCtrl(mSslCtx, HcomSsl::SSL_CTRL_SET_MAX_PROTO_VERSION, GetTLSVersion(), nullptr); + if (GetTLSVersion() == TLS_1_2) { + ret = HcomSsl::SslCtxSetOption(mSslCtx, HcomSsl::SSL_NO_TLS1_2_RENEGOTIATION); + OOB_SSL_LAYER_CHECK_RET(ret <= 0, "Failed to set renegotiation"); + } + + switch (GetCipherSuite()) { + case AES_GCM_128: + ret = HcomSsl::SslCtxSetCipherSuites(mSslCtx, "TLS_AES_128_GCM_SHA256"); + break; + case AES_GCM_256: + ret = HcomSsl::SslCtxSetCipherSuites(mSslCtx, "TLS_AES_256_GCM_SHA384"); + break; + case AES_CCM_128: + ret = HcomSsl::SslCtxSetCipherSuites(mSslCtx, "TLS_AES_128_CCM_SHA256"); + break; + case CHACHA20_POLY1305: + ret = HcomSsl::SslCtxSetCipherSuites(mSslCtx, "TLS_CHACHA20_POLY1305_SHA256"); + break; + default: + ret = NN_NO0; + } + OOB_SSL_LAYER_CHECK_RET(ret <= 0, "Failed to set cipher suites to TLS context"); + + bool success = (CommLoad(isServer) == NN_OK); + if (isServer) { + OOB_SSL_LAYER_CHECK_RET(!success, "Failed to initialize TLS context for encryption at server"); + NN_LOG_INFO("SSL Server accept one SSL client [" << HcomSsl::SslGetVersion(mSsl) << "]"); + } else { + OOB_SSL_LAYER_CHECK_RET(!success, "Failed to initialize TLS context for encryption at client"); + } + return NN_OK; +} + +NResult OOBOpenSSLConnection::VerifyCA(bool isServer) +{ + std::string caPath; + bool result = mCaCallback(mIpAndPort, caPath, mCrlPath, mPeerCertVerifyType, mCertVerifyCallback); + OOB_SSL_LAYER_CHECK_RET(!result, "Failed to get CA cert, UBSHcomTLSCaCallback return false"); + + if (mPeerCertVerifyType == VERIFY_BY_NONE && !isServer) { + HcomSsl::SslCtxSetVerify(mSslCtx, HcomSsl::SSL_VERIFY_NONE, nullptr); + } else { + HcomSsl::SslCtxSetVerify(mSslCtx, HcomSsl::SSL_VERIFY_PEER | HcomSsl::SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + HcomSsl::SslCtxSetCertVerifyCallback(mSslCtx, &this->CaCallbackWrapper, this); + } + OOB_SSL_LAYER_CHECK_RET(caPath.empty(), "Failed to get valid CA cert path via callback, it is empty."); + + std::vector caFileList; + NetFunc::NN_SplitStr(caPath, ":", caFileList); + + for (auto &caFile : caFileList) { + OOB_SSL_LAYER_CHECK_RET(!CanonicalPath(caFile), "Failed to get valid CA cert path via callback"); + auto ret = HcomSsl::SslCtxLoadVerifyLocations(mSslCtx, caFile.c_str(), nullptr); + OOB_SSL_LAYER_CHECK_RET(ret <= 0, "TLS load verify file : " << caFile << "failed"); + } + return NN_OK; +} + +NResult OOBOpenSSLConnection::CommLoad(bool isServer) +{ + /* Check the peer's CA */ + if (mCaCallback != nullptr) { + auto result = VerifyCA(isServer); + if (NN_UNLIKELY(result != NN_OK)) { + return result; + } + } + + /* Set private key and check */ + int passLen = 0; + void *keyPass = nullptr; + UBSHcomTLSEraseKeypass erase = nullptr; + if (mCertCallback != nullptr && mKeyCallback != nullptr) { + std::string certPath; + bool result = mCertCallback(mIpAndPort, certPath); + OOB_SSL_LAYER_CHECK_RET(!result, "TLS callback get CERT path failed"); + OOB_SSL_LAYER_CHECK_RET(!CanonicalPath(certPath), "get invalid cert path"); + + std::string keyPath; + result = mKeyCallback(mIpAndPort, keyPath, keyPass, passLen, erase); + OOB_SSL_LAYER_CHECK_RET(!result, "TLS callback get private-key path failed"); + OOB_SSL_LAYER_CHECK_RET_ERASE_RET(!CanonicalPath(keyPath), "get invalid keyPath"); + /* load cert chain */ + auto ret = HcomSsl::SslCtxUseCertificateChainFile(mSslCtx, certPath.c_str()); + OOB_SSL_LAYER_CHECK_RET_ERASE_RET(ret <= 0, "TLS use certification file chain failed"); + HcomSsl::SslCtxSetDefaultPasswdCbUserdata(mSslCtx, keyPass); + /* load private key */ + ret = HcomSsl::SslCtxUsePrivateKeyFile(mSslCtx, keyPath.c_str(), HcomSsl::SSL_FILETYPE_PEM); + OOB_SSL_LAYER_CHECK_RET_ERASE_RET(ret <= 0, "TLS use private-key file failed"); + /* check private key */ + ret = HcomSsl::SslCtxCheckPrivateKey(mSslCtx); + OOB_SSL_LAYER_CHECK_RET_ERASE_RET(ret <= 0, "TLS check private-key failed"); + } + + /* set psk callback */ + if (mPskFindSessionCb != nullptr || mPskUseSessionCb != nullptr) { + auto ret = SetPSKCallback(isServer); + OOB_SSL_LAYER_CHECK_RET_ERASE_RET(ret != NN_OK, "Failed to set psk callback"); + } + + mSsl = HcomSsl::SslNew(mSslCtx); + OOB_SSL_LAYER_CHECK_RET_ERASE_RET(mSsl == nullptr, "Failed to new TLS, probably out of memory"); + + auto ret = HcomSsl::SslSetFd(mSsl, mFD); + OOB_SSL_LAYER_CHECK_RET_ERASE_RET(ret <= 0, "Failed to set fd to TLS, result " << ret); + + /* Server will accept and Client will connect */ + ret = isServer ? HcomSsl::SslAccept(mSsl) : HcomSsl::SslConnect(mSsl); + if (isServer) { + OOB_SSL_LAYER_CHECK_RET_ERASE_RET(ret <= 0, + "TLS Failed to accept new TLS connection, result " << ret << " failed"); + } else { + OOB_SSL_LAYER_CHECK_RET_ERASE_RET(ret <= 0, "TLS Failed to connect to TLS server, result " << ret << " failed"); + } + + if (erase != nullptr) { + erase(keyPass, passLen); + } + return NN_OK; +} + +X509_CRL *OOBOpenSSLConnection::LoadCertRevokeListFile(const char *crlFile) +{ + // check whether file is exist + char *realCrlPath = realpath(crlFile, nullptr); + if (realCrlPath == nullptr) { + return nullptr; + } + + // load crl file + BIO *in = HcomSsl::BioNew(HcomSsl::BioSFile()); + if (in == nullptr) { + free(realCrlPath); + realCrlPath = nullptr; + return nullptr; + } + + int result = HcomSsl::BioCtrl(in, HcomSsl::BIO_C_SET_FILENAME, HcomSsl::BIO_CLOSE | HcomSsl::BIO_FP_READ, + const_cast(realCrlPath)); + if (result <= 0) { + (void)HcomSsl::BioFree(in); + free(realCrlPath); + return nullptr; + } + + X509_CRL *crl = HcomSsl::PemReadBioX509Crl(in, nullptr, nullptr, nullptr); + if (crl == nullptr) { + (void)HcomSsl::BioFree(in); + free(realCrlPath); + realCrlPath = nullptr; + return nullptr; + } + + (void)HcomSsl::BioFree(in); + free(realCrlPath); + realCrlPath = nullptr; + return crl; +} +} +} diff --git a/src/transport/net_oob_openssl.h b/src/transport/net_oob_openssl.h new file mode 100644 index 0000000000000000000000000000000000000000..f7da4c13f9ae15591fc60e5557b7e9fb5b660ee6 --- /dev/null +++ b/src/transport/net_oob_openssl.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_OOB_OPENSSL_12334324233_H +#define OCK_HCOM_OOB_OPENSSL_12334324233_H + +#include +#include + +#include "net_oob_ssl.h" +#include "net_security_rand.h" +#include "net_util.h" +#include "openssl_api_wrapper.h" +#include "rdma_verbs_wrapper_qp.h" + +namespace ock { +namespace hcom { +class OOBOpenSSLConnection : public OOBSSLConnection { +public: + explicit OOBOpenSSLConnection(int fd) : OOBSSLConnection(fd) {} + ~OOBOpenSSLConnection() override; + + NResult Send(void *buf, uint32_t size) const override; + NResult Receive(void *buf, uint32_t size) const override; + + NResult InitSSL(bool server) override; + + SSL *TransferSsl(); + +private: + NResult CommLoad(bool server) override; + NResult VerifyCA(bool server) override; + + static int CaCallbackWrapper(X509_STORE_CTX *ctx, void *arg); + static X509_CRL *LoadCertRevokeListFile(const char *crlFile); + static int DefaultSslCertVerify(X509_STORE_CTX *x509ctx, const char *arg); + static int PskFindCallbackWrapper(SSL *ssl, const unsigned char *identity, size_t identity_len, SSL_SESSION **sess); + static int PskUseCallbackWrapper(SSL *ssl, const EVP_MD *md, const unsigned char **id, size_t *idlen, + SSL_SESSION **sess); + static UBSHcomPskFindSessionCb mOpenSslPskFindSessionCb; + static UBSHcomPskUseSessionCb mOpenSslPskUseSessionCb; + NResult SetPSKCallback(bool isServer); + + SSL *mSsl = nullptr; + SSL_CTX *mSslCtx = nullptr; +}; +} +} + +#endif // OCK_HCOM_OOB_OPENSSL_12334324233_H diff --git a/src/transport/net_oob_secure.cpp b/src/transport/net_oob_secure.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1efa73210d35ec2c72cb82c46cf97b324cbd8e2e --- /dev/null +++ b/src/transport/net_oob_secure.cpp @@ -0,0 +1,452 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "net_oob_secure.h" + +namespace ock { +namespace hcom { +NResult OOBSecureProcess::SecProcessCompareEpNum(uint32_t localIpAddr, uint32_t listenPort, + const std::string &mIpAndPort, const std::vector &oobServers) +{ + struct sockaddr_in addr {}; + bzero(&addr, sizeof(addr)); + addr.sin_addr.s_addr = localIpAddr; + char ipStr[INET_ADDRSTRLEN] = {0}; + if (inet_ntop(AF_INET, &(addr.sin_addr), ipStr, INET_ADDRSTRLEN) == nullptr) { + NN_LOG_ERROR("Failed to convert ip number to string"); + return NN_INVALID_IP; + } + std::string localIP(ipStr); + std::string ip; + uint32_t result; + uint16_t port; + + for (auto &oobServer : oobServers) { + result = static_cast(oobServer->GetListenIp(ip)); + result |= static_cast(oobServer->GetListenPort(port)); + if (result != NN_OK) { + continue; + } + if (ip == localIP || port == listenPort) { + size_t pos = mIpAndPort.find(':'); + std::string remoteIp = mIpAndPort.substr(0, pos); + return oobServer->CompareEpNum(remoteIp); + } + } + + return NN_OK; +} + +NResult OOBSecureProcess::SecProcessCompareEpNum(const std::string &localUdsName, const std::string &mIpAndPort, + const std::vector &oobServers) +{ + std::string udsName; + int result; + + for (auto &oobServer : oobServers) { + result = oobServer->GetUdsName(udsName); + if (result != NN_OK) { + continue; + } + if (udsName == localUdsName) { + size_t pos = mIpAndPort.find(':'); + std::string remoteIp = mIpAndPort.substr(0, pos); + return oobServer->CompareEpNum(remoteIp); + } + } + + return NN_OK; +} + +void OOBSecureProcess::SecProcessAddEpNum(uint32_t localIpAddr, uint32_t listenPort, const std::string &mIpAndPort, + const std::vector &oobServers) +{ + struct sockaddr_in addr {}; + bzero(&addr, sizeof(addr)); + addr.sin_addr.s_addr = localIpAddr; + char ipStr[INET_ADDRSTRLEN] = {0}; + if (inet_ntop(AF_INET, &(addr.sin_addr), ipStr, INET_ADDRSTRLEN) == nullptr) { + NN_LOG_ERROR("Failed to convert ip number to string"); + return; + } + std::string localIP(ipStr); + std::string ip; + uint32_t result; + uint16_t port; + + for (auto &oobServer : oobServers) { + result = static_cast(oobServer->GetListenIp(ip)); + result |= static_cast(oobServer->GetListenPort(port)); + if (result != NN_OK) { + continue; + } + if (ip == localIP || port == listenPort) { + size_t pos = mIpAndPort.find(':'); + std::string remoteIp = mIpAndPort.substr(0, pos); + oobServer->AddEpNum(remoteIp); + break; + } + } +} + +void OOBSecureProcess::SecProcessAddEpNum(const std::string &localUdsName, const std::string &mIpAndPort, + const std::vector &oobServers) +{ + std::string udsName; + int result; + + for (auto &oobServer : oobServers) { + result = oobServer->GetUdsName(udsName); + if (result != NN_OK) { + continue; + } + if (udsName == localUdsName) { + size_t pos = mIpAndPort.find(':'); + std::string remoteIp = mIpAndPort.substr(0, pos); + oobServer->AddEpNum(remoteIp); + break; + } + } +} + +void OOBSecureProcess::SecProcessDelEpNum(uint32_t localIpAddr, uint32_t listenPort, const std::string &mIpAndPort, + const std::vector &oobServers) +{ + struct sockaddr_in addr {}; + bzero(&addr, sizeof(addr)); + addr.sin_addr.s_addr = localIpAddr; + char ipStr[INET_ADDRSTRLEN] = {0}; + if (inet_ntop(AF_INET, &(addr.sin_addr), ipStr, INET_ADDRSTRLEN) == nullptr) { + NN_LOG_ERROR("Failed to convert ip number to string"); + return; + } + std::string localIP(ipStr); + std::string ip; + uint16_t port; + uint32_t result; + + for (auto &oobServer : oobServers) { + result = static_cast(oobServer->GetListenIp(ip)); + result |= static_cast(oobServer->GetListenPort(port)); + if (result != NN_OK) { + continue; + } + if (ip == localIP || port == listenPort) { + size_t pos = mIpAndPort.find(':'); + std::string remoteIp = mIpAndPort.substr(0, pos); + oobServer->DelEpNum(remoteIp); + break; + } + } +} + +void OOBSecureProcess::SecProcessDelEpNum(const std::string &localUdsName, const std::string &mIpAndPort, + const std::vector &oobServers) +{ + std::string udsName; + int result; + + for (auto &oobServer : oobServers) { + result = oobServer->GetUdsName(udsName); + if (result != NN_OK) { + continue; + } + if (udsName == localUdsName) { + size_t pos = mIpAndPort.find(':'); + std::string remoteIp = mIpAndPort.substr(0, pos); + oobServer->DelEpNum(remoteIp); + break; + } + } +} + +NResult OOBSecureProcess::SecProcessInOOBServer(const UBSHcomNetDriverEndpointSecInfoProvider &secInfoProvider, + const UBSHcomNetDriverEndpointSecInfoValidator &secInfoValidator, OOBTCPConnection &conn, + const std::string &driverName, UBSHcomNetDriverSecType sType) +{ + int result = 0; + auto secType = static_cast(0); + ConnectResp resp = ConnectResp::OK; + + uint64_t ctx = 0; + // validate secure info in oob client + if (NN_UNLIKELY(ValidateSecInfo(secInfoValidator, conn, driverName, secType, ctx, sType) != NN_OK)) { + resp = ConnectResp::SEC_VALID_FAILED; + if (NN_UNLIKELY((result = conn.Send(&resp, sizeof(ConnectResp))) != NN_OK)) { + NN_LOG_ERROR("Failed to send secure validate result to " << conn.GetIpAndPort() << " in driver " << + driverName); + return NN_OOB_SEC_PROCESS_ERROR; + } + return NN_OOB_SEC_PROCESS_ERROR; + } + + // send validate result to oob client + if (NN_UNLIKELY((result = conn.Send(&resp, sizeof(ConnectResp))) != NN_OK)) { + NN_LOG_ERROR("Failed to send secure validate result to " << conn.GetIpAndPort() << " in driver " << driverName); + return NN_OOB_SEC_PROCESS_ERROR; + } + + // two-way case server need to send secure info to client + // and no need to receive resp, client validate ok will return directly + if (secType == NET_SEC_VALID_TWO_WAY) { + if (NN_UNLIKELY(SendSecInfo(secInfoProvider, secInfoValidator, &conn, driverName, secType, ctx) != NN_OK)) { + return NN_OOB_SEC_PROCESS_ERROR; + } + } + + return NN_OK; +} + +NResult OOBSecureProcess::SecProcessInOOBClient(const UBSHcomNetDriverEndpointSecInfoProvider &secInfoProvider, + const UBSHcomNetDriverEndpointSecInfoValidator &secInfoValidator, OOBTCPConnection *conn, + const std::string &driverName, uint64_t ctx, UBSHcomNetDriverSecType sType) +{ + // create and send secure info + int result = 0; + auto secType = static_cast(0); + if (sType == NET_SEC_DISABLED) { + // send header no valid to server (case 5) + ConnSecHeader header(0, 0, 0, secType); + if (NN_UNLIKELY((result = conn->Send(&header, sizeof(ConnSecHeader))) != NN_OK)) { + NN_LOG_ERROR("Failed to send conn secure header to oob server " << conn->GetIpAndPort() << " in driver " << + driverName); + return NN_OOB_SEC_PROCESS_ERROR; + } + } + + // create and send secure info to oob server + if (sType != NET_SEC_DISABLED) { + if (NN_UNLIKELY(SendSecInfo(secInfoProvider, secInfoValidator, conn, driverName, secType, ctx) != NN_OK)) { + return NN_OOB_SEC_PROCESS_ERROR; + } + NN_LOG_TRACE_INFO("Secure info send to peer oob " << conn->GetIpAndPort() << " successfully, in driver " << + driverName); + } + + // receive oob server validate result + ConnectResp resp = {}; + void *tmpRsp = &resp; + if (NN_UNLIKELY((result = conn->Receive(tmpRsp, sizeof(ConnectResp))) != NN_OK)) { + return result; + } + if (resp != ConnectResp::OK) { + NN_LOG_ERROR("Received failed response:" << resp << " for validate secure info from " << conn->GetIpAndPort() << + " in driver " << driverName); + return NN_OOB_SEC_PROCESS_ERROR; + } + + // two-way case oob client should validate secure info from oob server + // oob client validate result response no need to send to oob server, server will not receive + if (secType == NET_SEC_VALID_TWO_WAY) { + uint64_t ctxBack = 0; + if (NN_UNLIKELY(ValidateSecInfo(secInfoValidator, *conn, driverName, secType, ctxBack, sType) != NN_OK)) { + return NN_OOB_SEC_PROCESS_ERROR; + } + } + + NN_LOG_TRACE_INFO("The verification is successful"); + return NN_OK; +} + +NResult OOBSecureProcess::SendSecInfo(const UBSHcomNetDriverEndpointSecInfoProvider &secInfoProvider, + const UBSHcomNetDriverEndpointSecInfoValidator &secInfoValidator, OOBTCPConnection *conn, + const std::string &driverName, UBSHcomNetDriverSecType &secType, uint64_t ctx) +{ + int result = 0; + // two-way case server provider not set return error (case 13) + if (NN_UNLIKELY(secInfoProvider == nullptr)) { + NN_LOG_ERROR("Failed to send secure info as secure info provider is null and secure type is " << + UBSHcomNetDriverSecTypeToString(secType) << " in driver " << driverName); + return NN_OOB_SEC_PROCESS_ERROR; + } + + if (NN_UNLIKELY(conn == nullptr)) { + NN_LOG_ERROR("Failed to send secure info as conn is null"); + return NN_OOB_SEC_PROCESS_ERROR; + } + + char *output = nullptr; + uint32_t outLen = 0; + int64_t flag = 0; + secType = static_cast(0); + bool needAutoFree = false; + result = secInfoProvider(ctx, flag, secType, output, outLen, needAutoFree); + if (NN_UNLIKELY(outLen > NN_NO2147483646)) { + NN_LOG_ERROR("The outLen value cannot be greater than 2147483646 in driver " << driverName); + return NN_OOB_SEC_PROCESS_ERROR; + } + // client provider registered but call provider failed, return error (case 1) + // or server provider registered but call provider failed, return error (case 9) + if (NN_UNLIKELY(result != 0)) { + NN_LOG_ERROR("Failed to create secure info in driver " << driverName << " as do provider callback result is:" << + result); + return NN_OOB_SEC_PROCESS_ERROR; + } + + NetLocalAutoFreePtr secInfoAutoFree(output, true); + if (!needAutoFree) { + secInfoAutoFree.SetNull(); + } + + if (secType != NET_SEC_VALID_ONE_WAY && secType != NET_SEC_VALID_TWO_WAY) { + NN_LOG_ERROR("Failed to create secure info in driver " << driverName << ", as secure type:" << + UBSHcomNetDriverSecTypeToString(secType) << " in provider is invalid"); + return NN_OOB_SEC_PROCESS_ERROR; + } + + // two-way case client should register validator (case 8/10) + if (secType == NET_SEC_VALID_TWO_WAY && secInfoValidator == nullptr) { + NN_LOG_ERROR("Failed to create secure info in driver " << driverName << ", as secure type is:" << + UBSHcomNetDriverSecTypeToString(secType) << " but validator callback not set"); + return NN_OOB_SEC_PROCESS_ERROR; + } + + NN_LOG_TRACE_INFO("Secure info should send to server:" << output << " len:" << outLen << " flag:" << flag << + " ctx:" << ctx << " sec type:" << UBSHcomNetDriverSecTypeToString(secType)); + + ConnSecHeader header(flag, ctx, outLen, secType); + if (NN_UNLIKELY((result = conn->Send(&header, sizeof(ConnSecHeader))) != NN_OK)) { + NN_LOG_ERROR("Failed to send conn secure header to oob server " << conn->GetIpAndPort() << " in driver " << + driverName); + return NN_OOB_SEC_PROCESS_ERROR; + } + + if (NN_UNLIKELY((result = conn->Send(output, outLen)) != NN_OK)) { + NN_LOG_ERROR("Failed to send conn secure info to oob server " << conn->GetIpAndPort() << " in driver " << + driverName); + return NN_OOB_SEC_PROCESS_ERROR; + } + + return NN_OK; +} + +NResult OOBSecureProcess::ValidateSecInfo(const UBSHcomNetDriverEndpointSecInfoValidator &secInfoValidator, + OOBTCPConnection &conn, const std::string &driverName, UBSHcomNetDriverSecType &secType, uint64_t &ctx, + UBSHcomNetDriverSecType sType) +{ + int result = 0; + ConnSecHeader header {}; + void *headerBuf = &header; + if (NN_UNLIKELY((result = conn.Receive(headerBuf, sizeof(ConnSecHeader))) != 0)) { + NN_LOG_ERROR("Failed to read secure header from " << conn.GetIpAndPort() << " in driver " << driverName << + ", result " << result); + return NN_OOB_SEC_PROCESS_ERROR; + } + + ctx = header.ctx; + if (header.type > NET_SEC_VALID_TWO_WAY) { + NN_LOG_ERROR("Failed to validate header as secure type is invalid"); + return NN_OOB_SEC_PROCESS_ERROR; + } + + // oob client not register provider will send type 0 (case 5/6/7) + if (header.type == 0) { + // oob server not register validator, validate success (case 7) + if (sType == NET_SEC_DISABLED) { + return NN_OK; + } + + // oob server register validator but client not register provider (case 5/6) + NN_LOG_ERROR("Failed to validate header as secure type is 0, oob " << conn.GetIpAndPort() << + " may not set provider"); + return NN_OOB_SEC_PROCESS_ERROR; + } + + secType = static_cast(header.type); + + NN_LOG_TRACE_INFO("Secure header flag:" << header.flag << " ctx:" << header.ctx << " len:" << header.secInfoLen << + " sec type:" << UBSHcomNetDriverSecTypeToString(secType)); + if (NN_UNLIKELY(header.secInfoLen > NN_NO2147483646)) { + NN_LOG_ERROR("Receive secInfoLen greater than 2147483646 in " << driverName); + return NN_OOB_SEC_PROCESS_ERROR; + } + char *secInfo = new (std::nothrow) char[header.secInfoLen + NN_NO1]; + if (NN_UNLIKELY(secInfo == nullptr)) { + NN_LOG_ERROR("Failed to new buffer for sec info from peer, probably out of memory"); + return NN_OOB_SEC_PROCESS_ERROR; + } + NetLocalAutoFreePtr secInfoAutoFree(secInfo, true); + void *secBuf = static_cast(secInfo); + if (NN_UNLIKELY((result = conn.Receive(secBuf, header.secInfoLen)) != 0)) { + NN_LOG_ERROR("Failed to read secure info from " << conn.GetIpAndPort() << " in driver " << driverName << + ", result " << result); + return NN_OOB_SEC_PROCESS_ERROR; + } + secInfo[header.secInfoLen] = '\0'; + + // client provider registered but server validator not registered, validate pass (case 2) + int validateResult = 0; + if (NN_UNLIKELY(secInfoValidator == nullptr)) { + NN_LOG_WARN("Validator is null and secure type is:" << UBSHcomNetDriverSecTypeToString(secType) << + " in driver " << driverName << " , skip secure info validate"); + return NN_OK; + } + + validateResult = secInfoValidator(header.ctx, header.flag, secInfo, header.secInfoLen); + // client provider and server validator registered, but server validator validate failed (case 3) + // or two-way case server provider and client validator registered, client validator but validate failed (case 11) + if (validateResult != 0) { + NN_LOG_ERROR("Failed to validate secure info received from " << conn.GetIpAndPort() << " in driver " << + driverName << ", validate result is:" << validateResult); + return NN_OOB_SEC_PROCESS_ERROR; + } + // client provider and server validator registered and validate success, pass (case 4) + // or two-way case server provider and client validator registered and validate success, pass (case 12) + NN_LOG_INFO("Validate secure info from peer oob " << conn.GetIpAndPort() << " successfully, in driver " << + driverName); + + return NN_OK; +} + +NResult OOBSecureProcess::SecCheckConnectionHeader(const ConnectHeader &header, const UBSHcomNetDriverOptions &option, + const bool &enableTls, const UBSHcomNetDriverProtocol &protocol, const uint32_t &majorVersion, + const uint32_t &minorVersion, ConnRespWithUId &respWithUId) +{ + if (header.magic != option.magic) { + NN_LOG_ERROR("Failed to match magic number from client, connection refused header.magic"); + respWithUId.connResp = MAGIC_MISMATCH; + return NN_ERROR; + } + + if (header.protocol != protocol) { + NN_LOG_ERROR("Failed to match protocol " << protocol << " from client " << header.protocol << + ", connection refused"); + respWithUId.connResp = PROTOCOL_MISMATCH; + return NN_ERROR; + } + + if (header.majorVersion != majorVersion) { + NN_LOG_ERROR("Failed to match majorVersion " << majorVersion << " from client " << + header.majorVersion << ", connection refused"); + respWithUId.connResp = VERSION_MISMATCH; + return VERSION_MISMATCH; + } + + if (header.minorVersion > minorVersion) { + NN_LOG_ERROR("Failed to match minorVersion " << minorVersion << " from client " << + header.minorVersion << ", connection refused"); + respWithUId.connResp = VERSION_MISMATCH; + return VERSION_MISMATCH; + } + + if (enableTls) { + if (header.tlsVersion < TLS_1_2 || header.tlsVersion > TLS_1_3) { + NN_LOG_ERROR("Failed to match tls version from client " << header.tlsVersion << + ", connection refused"); + respWithUId.connResp = TLS_VERSION_MISMATCH; + return NN_ERROR; + } + } + + return NN_OK; +} +} +} \ No newline at end of file diff --git a/src/transport/net_oob_secure.h b/src/transport/net_oob_secure.h new file mode 100644 index 0000000000000000000000000000000000000000..3842a1f7516af898692ea461a06ea8287769b32e --- /dev/null +++ b/src/transport/net_oob_secure.h @@ -0,0 +1,86 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_OOB_SECURE_PROCESS_H +#define OCK_HCOM_OOB_SECURE_PROCESS_H + +#include "net_oob.h" + +namespace ock { +namespace hcom { + +class OOBSecureProcess { +public: + /* + * There are 32 cases for sec info process, please refer readme of this project + */ + static NResult SecProcessInOOBClient(const UBSHcomNetDriverEndpointSecInfoProvider &secInfoProvider, + const UBSHcomNetDriverEndpointSecInfoValidator &secInfoValidator, OOBTCPConnection *conn, + const std::string &driverName, uint64_t ctx, UBSHcomNetDriverSecType secType); + + static NResult SecProcessInOOBServer(const UBSHcomNetDriverEndpointSecInfoProvider &secInfoProvider, + const UBSHcomNetDriverEndpointSecInfoValidator &secInfoValidator, OOBTCPConnection &conn, + const std::string &driverName, UBSHcomNetDriverSecType sType); + + static NResult SecProcessCompareEpNum(uint32_t localIpAddr, uint32_t listenPort, const std::string &mIpAndPort, + const std::vector &oobServers); + + static void SecProcessAddEpNum(uint32_t localIpAddr, uint32_t listenPort, const std::string &mIpAndPort, + const std::vector &oobServers); + + static void SecProcessDelEpNum(uint32_t localIpAddr, uint32_t listenPort, const std::string &mIpAndPort, + const std::vector &oobServers); + + static NResult SecProcessCompareEpNum(const std::string &localUdsName, const std::string &mIpAndPort, + const std::vector &oobServers); + + static void SecProcessAddEpNum(const std::string &localUdsName, const std::string &mIpAndPort, + const std::vector &oobServers); + + static void SecProcessDelEpNum(const std::string &localUdsName, const std::string &mIpAndPort, + const std::vector &oobServers); + + static NResult SecCheckConnectionHeader(const ConnectHeader &header, const UBSHcomNetDriverOptions &option, + const bool &enableTls, const UBSHcomNetDriverProtocol &protocol, const uint32_t &majorVersion, + const uint32_t &minorVersion, ConnRespWithUId &respWithUId); + +private: + /* + * Send sec info to peer via oob connection + * step1: call sec info provider to create sec info + * step2: send header to peer, always send, no matter sec info validate is enabled or not + * step3: send sec info to peer + * + * In 1 way authentication case: only oob client calls this + * In 2 ways authentications case: both oob client and oob sever calls this + */ + static NResult SendSecInfo(const UBSHcomNetDriverEndpointSecInfoProvider &secInfoProvider, + const UBSHcomNetDriverEndpointSecInfoValidator &secInfoValidator, OOBTCPConnection *conn, + const std::string &driverName, UBSHcomNetDriverSecType &secType, uint64_t ctx); + + /* + * Validate sec info from peer via oob connection + * step1: receive head from peer, always receive, no matter sec info validate is enabled or not + * step2: receive sec info + * step3: call sec info validator to validate sec info + * + * In 1 way authentication case: only oob server calls this + * In 2 ways authentications case: both oob server and oob client calls this + */ + static NResult ValidateSecInfo(const UBSHcomNetDriverEndpointSecInfoValidator &secInfoValidator, + OOBTCPConnection &conn, const std::string &driverName, UBSHcomNetDriverSecType &secType, + uint64_t &ctx, UBSHcomNetDriverSecType sType); +}; + +} +} + +#endif \ No newline at end of file diff --git a/src/transport/net_oob_ssl.cpp b/src/transport/net_oob_ssl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..902ae3862f77c0887d26a63295e3c3642eb79a54 --- /dev/null +++ b/src/transport/net_oob_ssl.cpp @@ -0,0 +1,379 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "net_oob.h" +#include "net_oob_openssl.h" +#include "openssl_api_wrapper.h" +#include "net_oob_ssl.h" + +namespace ock { +namespace hcom { +void OOBSSLServer::DealConnectInThread(int fd, struct sockaddr_in addressIn) +{ + ConnectResp resp = ConnectResp::OK; + char ipStr[INET_ADDRSTRLEN] = {0}; + if (inet_ntop(AF_INET, &addressIn.sin_addr, ipStr, INET_ADDRSTRLEN) == nullptr) { + NN_LOG_ERROR("Failed to convert ip number to string"); + resp = SERVER_INTERNAL_ERROR; + } + auto tlsConnectCbTask = new (std::nothrow) TlsConnectCbTask(mNewConnectionHandler, fd, mWorkerLb); + if (NN_UNLIKELY(tlsConnectCbTask == nullptr)) { + resp = ConnectResp::CONN_ACCEPT_NEW_TASK_FAIL; + } else { + tlsConnectCbTask->SetIpPort(std::string(ipStr), ntohs(addressIn.sin_port), mListenPort); + tlsConnectCbTask->SetTlsCb(mTlsCertCb, mTlsPrivateKeyCb, mTlsCaCallback); + tlsConnectCbTask->SetTlsOptions(mCipherSuite, mTlsVersion); + tlsConnectCbTask->SetPSKCallback(mPskFindSessionCb, mPskUseSessionCb); + if (mOobType == NET_OOB_UDS) { + tlsConnectCbTask->SetUdsName(mUdsName); + } + if (NN_UNLIKELY(!mEs->Execute(tlsConnectCbTask))) { + delete tlsConnectCbTask; + resp = ConnectResp::CONN_ACCEPT_QUEUE_FULL; + NN_LOG_WARN("Invalid to execute task may be queue is full please retry it"); + } + } + + if (resp != ConnectResp::OK) { + // if accept success but execute task failed, should notify client connect fail and client will retry + if (::send(fd, &resp, sizeof(ConnectResp), 0) <= 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to send connect status to peer on oob @ " << ipStr << ":" << + ntohs(addressIn.sin_port) << ", as " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + } + } +} + +void OOBSSLServer::RunInThread() +{ + if (mOobType == NET_OOB_TCP) { + NN_LOG_INFO("OOB ssl server accept thread for " << mListenIP << ":" << mListenPort << + " started, load balancer " << (mWorkerLb == nullptr ? "null" : mWorkerLb->ToString())); + } else if (mOobType == NET_OOB_UDS) { + NN_LOG_TRACE_INFO("OOB ssl server accept thread for " << mUdsName << " started, load balancer " << + (mWorkerLb == nullptr ? "null" : mWorkerLb->ToString())); + } else { + NN_LOG_ERROR("Un-reachable"); + } + + mThreadStarted.store(true); + struct sockaddr_in addressIn {}; + socklen_t len = sizeof(addressIn); + + int flags = 1; + + auto maxRecvTimeout = NetFunc::NN_GetLongEnv("HCOM_CONNECTION_RECV_TIMEOUT_SEC", NN_NO1, NN_NO7200, NN_NO0); + auto maxSendTimeout = NetFunc::NN_GetLongEnv("HCOM_CONNECTION_SEND_TIMEOUT_SEC", NN_NO1, NN_NO7200, NN_NO0); + + while (NN_UNLIKELY(mEs == nullptr || !mEs->IsStart())) { + usleep(NN_NO100); + } + while (true) { + try { + if (NN_UNLIKELY(mNeedStop)) { + NN_LOG_INFO("Got stop signal, stop listening in oob ssl server"); + break; + } + + struct pollfd pollEventFd = {}; + pollEventFd.fd = mListenFD; + pollEventFd.events = POLLIN; + pollEventFd.revents = 0; + + int rc = poll(&pollEventFd, 1, NN_NO500); + if (rc < 0 && errno != EINTR) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Get poll event failed in oob ssl server, errno " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + break; + } + + if (rc == 0) { + continue; + } + + bzero(&addressIn, sizeof(struct sockaddr_in)); + auto fd = ::accept(mListenFD, reinterpret_cast(&addressIn), &len); + if (fd < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_WARN("Invalid to accept in oob ssl server on new socket with " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE) << ", ignore and continue"); + continue; + } + + /* set no delay */ + setsockopt(fd, SOL_TCP, TCP_NODELAY, reinterpret_cast(&flags), sizeof(flags)); + + /* set recv or send timeout */ + if (maxRecvTimeout != NN_NO0) { + struct timeval recvTimeout = { maxRecvTimeout, 0 }; + setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &recvTimeout, sizeof(timeval)); + } + if (maxSendTimeout != NN_NO0) { + struct timeval sendTimeout = { maxSendTimeout, 0 }; + setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &sendTimeout, sizeof(timeval)); + } + + DealConnectInThread(fd, addressIn); + } catch (std::exception &ex) { + NN_LOG_WARN("Got exception in OOBSSLServer::RunInThread, exception " << ex.what() << + ", ignore and continue"); + } catch (...) { + NN_LOG_WARN("Got unknown error in OOBSSLServer::RunInThread, ignore and continue"); + } + } + + NN_LOG_INFO("Working thread for OOBSSLServer exiting"); +} + + +/* OOBSSLConnection */ +OOBSSLConnection::~OOBSSLConnection() +{ + NetFunc::NN_SafeCloseFd(mFD); +} + +NResult OOBSSLConnection::SendSecret() +{ + if (NN_UNLIKELY(!mSecret.Init(mCipherSuite))) { + NN_LOG_ERROR("Failed to init secret"); + return NN_ERROR; + } + + size_t len = mSecret.GetSerializeLen(); + char *serializedData = new (std::nothrow) char[len + NN_NO1]; + if (NN_UNLIKELY(serializedData == nullptr)) { + NN_LOG_ERROR("Failed to new a serializedData array, probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + serializedData[len] = '\0'; + NetLocalAutoFreePtr autoFreeData(serializedData, true); + + bool ret = mSecret.Serialize(serializedData, len); + if (!ret) { + NN_LOG_ERROR("Failed to serialize TLS exchange info"); + return NN_OOB_SSL_INIT_ERROR; + } + + NN_LOG_TRACE_INFO("Server update the secrets Len: " << len); + + auto result = Send(serializedData, len); + if (result != NN_OK) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to send info for TLS peer, error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return result; + } + + return NN_OK; +} + +NResult OOBSSLConnection::RecvSecret() +{ + if (NN_UNLIKELY(!mSecret.Init(mCipherSuite))) { + NN_LOG_ERROR("Failed to init secret"); + return NN_ERROR; + } + + size_t len = mSecret.GetSerializeLen(); + void *serializedData = malloc(len); + if (serializedData == nullptr) { + return NN_MALLOC_FAILED; + } + + auto result = Receive(serializedData, len); + if (result != NN_OK) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to receive info for TLS from peer, error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + free(serializedData); + serializedData = nullptr; + return result; + } + + bool ret = mSecret.Deserialize(static_cast(serializedData), len); + if (!ret) { + NN_LOG_ERROR("Failed to deserialize TLS exchange info"); + free(serializedData); + serializedData = nullptr; + return NN_OOB_SSL_INIT_ERROR; + } + + NN_LOG_TRACE_INFO("Client update the secrets."); + free(serializedData); + serializedData = nullptr; + return NN_OK; +} + +NResult OOBSSLConnection::SSLClientRecvHandler(int tmpFD) +{ + if (tmpFD <= 0) { + return NN_ERROR; + } + + return RecvSecret(); +} + +/* OOBSSLClient */ +NResult OOBSSLClient::Connect(const std::string &ip, uint32_t port, OOBTCPConnection *&conn) +{ + int fd = -1; + auto result = ConnectWithFd(ip, port, fd); + if (result != NN_OK) { + return result; + } + + mOobConn = new (std::nothrow) OOBOpenSSLConnection(fd); + if (NN_UNLIKELY(mOobConn == nullptr)) { + NN_LOG_ERROR("Failed to new oob connection, probably out of memory"); + NetFunc::NN_SafeCloseFd(fd); + return NN_NEW_OBJECT_FAILED; + } + + mOobConn->SetTlsOptions(mCipherSuite, mTlsVersion); + mOobConn->SetTLSCallback(mTlsCertCb, mTlsPrivateKeyCb, mTlsCaCallback); + mOobConn->SetPSKCallback(mPskFindSessionCb, mPskUseSessionCb); + + if (mOobConn->InitSSL(false) != NN_OK) { + delete mOobConn; + mOobConn = nullptr; + return NN_OOB_CLIENT_SOCKET_ERROR; + } + + if (mOobConn->SSLClientRecvHandler(fd) != NN_OK) { + NN_LOG_ERROR("Failed to receive secret from server to TLS"); + delete mOobConn; + mOobConn = nullptr; + return NN_OOB_CLIENT_SOCKET_ERROR; + } + + mOobConn->ListenPort(port); + conn = mOobConn; + + return NN_OK; +} + +NResult OOBSSLClient::Connect(const std::string &udsName, OOBTCPConnection *&conn) +{ + NN_LOG_INFO("SSL CONNECT"); + int fd = -1; + auto result = ConnectWithFd(udsName, fd); + if (result != NN_OK) { + return result; + } + + mOobConn = new (std::nothrow) OOBOpenSSLConnection(fd); + if (NN_UNLIKELY(mOobConn == nullptr)) { + NN_LOG_ERROR("Failed to new oob uds connection, probably out of memory"); + NetFunc::NN_SafeCloseFd(fd); + return NN_NEW_OBJECT_FAILED; + } + + mOobConn->SetTLSCallback(mTlsCertCb, mTlsPrivateKeyCb, mTlsCaCallback); + mOobConn->SetTlsOptions(mCipherSuite, mTlsVersion); + mOobConn->SetPSKCallback(mPskFindSessionCb, mPskUseSessionCb); + + if (mOobConn->InitSSL(false) != NN_OK) { + delete mOobConn; + mOobConn = nullptr; + return NN_OOB_CLIENT_SOCKET_ERROR; + } + + if (mOobConn->SSLClientRecvHandler(fd) != NN_OK) { + NN_LOG_ERROR("Failed to receive secret from uds server to TLS"); + delete mOobConn; + mOobConn = nullptr; + return NN_OOB_CLIENT_SOCKET_ERROR; + } + + conn = mOobConn; + conn->mIsUds = true; + + return NN_OK; +} + +void TlsConnectCbTask::Run() +{ + ConnectResp resp = ConnectResp::OK; + if (::send(mFd, &resp, sizeof(ConnectResp), 0) <= 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to send connect status to peer on oob @ " << mClientIP << ":" << mClientIP << ", as " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return; + } + + OOBSSLConnection *conn = nullptr; + conn = new (std::nothrow) OOBOpenSSLConnection(mFd); + if (NN_UNLIKELY(conn == nullptr)) { + NN_LOG_ERROR("Failed to new connection"); + return; + } + + conn->SetIpAndPort(mClientIP, mClientPort); + conn->ListenPort(mListenPort); + conn->LoadBalancer(mWorkerLb); + conn->SetTLSCallback(mTlsCertCb, mTlsPrivateKeyCb, mTlsCaCallback); + conn->SetTlsOptions(mCipherSuite, mTlsVersion); + conn->SetPSKCallback(mPskFindSessionCb, mPskUseSessionCb); + conn->SetUdsName(mUdsName); + + if (NN_UNLIKELY(mNewConnectionHandler == nullptr)) { + NN_LOG_ERROR("Failed to handshake and exchange address as new connection handler is null"); + delete conn; + conn = nullptr; + return; + } + + if (conn->InitSSL(true) != NN_OK) { + NN_LOG_ERROR("Failed to initialize TLS context for new connection from " << conn->GetIpAndPort()); + delete conn; + conn = nullptr; + return; + } + + /* Update the secret first */ + if (conn->SendSecret() != NN_OK) { + NN_LOG_ERROR("Failed to send TLS info to send new connection from " << conn->GetIpAndPort()); + delete conn; + conn = nullptr; + return; + } + + auto startConnCb = NetMonotonic::TimeUs(); + if (mNewConnectionHandler(*conn) != 0) { + NN_LOG_ERROR("Failed to handshake and exchange address with client " << conn->GetIpAndPort() << + ", continue to accept future connection"); + mFd = conn->TransferFd(); + delete conn; + conn = nullptr; + return; + } + auto endConnCb = NetMonotonic::TimeUs(); + auto cbTime = endConnCb - startConnCb; + if (NN_UNLIKELY(cbTime > MAX_CB_TIME_US)) { + NN_LOG_WARN("Call new Connection Cb time is too long: " << cbTime << " us."); + } + /* the socket could be transfer to real connection when type is socket */ + mFd = conn->TransferFd(); + delete conn; + conn = nullptr; +} +} +} \ No newline at end of file diff --git a/src/transport/net_oob_ssl.h b/src/transport/net_oob_ssl.h new file mode 100644 index 0000000000000000000000000000000000000000..49e15941944c9a8cdd9bbbaaf69cfeb74db3d3f6 --- /dev/null +++ b/src/transport/net_oob_ssl.h @@ -0,0 +1,255 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_OOB_SSL_12334324233_H +#define OCK_HCOM_OOB_SSL_12334324233_H + +#include +#include + +#include "net_oob.h" +#include "net_security_rand.h" + +namespace ock { +namespace hcom { +class OOBSSLConnection : public OOBTCPConnection { +public: + explicit OOBSSLConnection(int fd) : OOBTCPConnection(fd) {} + + ~OOBSSLConnection() override; + + /* + * @brief Initialize the SSL lib, and build the TLS, openssl make sure + * it can be call multi-times, but really do once. + * @return true for success + */ + virtual NResult InitSSL(bool server) {}; + + /* + * @brief server send the secret. + * @return true for success + */ + NResult SendSecret(); + + /* + * @brief client recv and update secrets. + * @return true for success + */ + NResult RecvSecret(); + + NResult SSLClientRecvHandler(int fd); + + void SetTLSCallback(const UBSHcomTLSCertificationCallback &certCB, const UBSHcomTLSPrivateKeyCallback &keyCB, + const UBSHcomTLSCaCallback &caCB) + { + mCertCallback = certCB; + mKeyCallback = keyCB; + mCaCallback = caCB; + } + + void SetPSKCallback(const UBSHcomPskFindSessionCb &pskFindSessionCb, const UBSHcomPskUseSessionCb &pskUseSessionCb) + { + mPskFindSessionCb = pskFindSessionCb; + mPskUseSessionCb = pskUseSessionCb; + } + + NetSecrets &Secret() + { + return mSecret; + } + + inline UBSHcomNetCipherSuite GetCipherSuite() const + { + return mCipherSuite; + } + + inline void SetTlsOptions(UBSHcomNetCipherSuite cipherSuite, UBSHcomTlsVersion tlsVersion) + { + mCipherSuite = cipherSuite; + mTlsVersion = tlsVersion; + } + + inline uint32_t GetTLSVersion() const + { + return mTlsVersion; + } + +protected: + /* Server and Client build the TLS */ + virtual NResult CommLoad(bool server) {}; + virtual NResult VerifyCA(bool server) {}; + + NetSecrets mSecret; + UBSHcomTLSCertificationCallback mCertCallback = nullptr; + UBSHcomTLSPrivateKeyCallback mKeyCallback = nullptr; + UBSHcomTLSCertVerifyCallback mCertVerifyCallback = nullptr; + UBSHcomTLSCaCallback mCaCallback = nullptr; + std::string mCrlPath; + UBSHcomPeerCertVerifyType mPeerCertVerifyType = VERIFY_BY_DEFAULT; + UBSHcomNetCipherSuite mCipherSuite = AES_GCM_128; + UBSHcomTlsVersion mTlsVersion = TLS_1_3; + + UBSHcomPskFindSessionCb mPskFindSessionCb = nullptr; + UBSHcomPskUseSessionCb mPskUseSessionCb = nullptr; +}; + +class OOBSSLServer : public OOBTCPServer { +public: + OOBSSLServer(NetDriverOobType t, const std::string &ipOrName, uint16_t portOrPerm, + UBSHcomTLSPrivateKeyCallback &keyCB, UBSHcomTLSCertificationCallback &certCB, UBSHcomTLSCaCallback &caCB) + : OOBTCPServer(t, ipOrName, portOrPerm), mTlsPrivateKeyCb(keyCB), mTlsCertCb(certCB), mTlsCaCallback(caCB) + {} + + OOBSSLServer(NetDriverOobType t, const std::string &ipOrName, uint16_t portOrPerm, bool isCheck, + UBSHcomTLSPrivateKeyCallback &keyCB, UBSHcomTLSCertificationCallback &certCB, UBSHcomTLSCaCallback &caCB) + : OOBTCPServer(t, ipOrName, portOrPerm, isCheck), + mTlsPrivateKeyCb(keyCB), + mTlsCertCb(certCB), + mTlsCaCallback(caCB) + {} + + ~OOBSSLServer() override = default; + + void RunInThread() override; + + void DealConnectInThread(int fd, struct sockaddr_in addressIn) override; + + inline UBSHcomNetCipherSuite GetCipherSuite() const + { + return mCipherSuite; + } + + inline void SetTlsOptions(UBSHcomNetCipherSuite cipherSuite, UBSHcomTlsVersion tlsVersion) + { + mCipherSuite = cipherSuite; + mTlsVersion = tlsVersion; + } + + inline void SetPSKCallback(const UBSHcomPskFindSessionCb &pskFindSessionCb, + const UBSHcomPskUseSessionCb &pskUseSessionCb) + { + mPskFindSessionCb = pskFindSessionCb; + mPskUseSessionCb = pskUseSessionCb; + } + + inline uint32_t GetTLSVersion() + { + return mTlsVersion; + } + +private: + UBSHcomTLSPrivateKeyCallback mTlsPrivateKeyCb = nullptr; + UBSHcomTLSCertificationCallback mTlsCertCb = nullptr; + UBSHcomTLSCaCallback mTlsCaCallback = nullptr; + UBSHcomNetCipherSuite mCipherSuite = AES_GCM_128; + UBSHcomTlsVersion mTlsVersion = TLS_1_3; + + UBSHcomPskFindSessionCb mPskFindSessionCb = nullptr; + UBSHcomPskUseSessionCb mPskUseSessionCb = nullptr; +}; + +class OOBSSLClient : public OOBTCPClient { +public: + OOBSSLClient(NetDriverOobType t, std::string serverIpOrName, uint16_t serverPort, + UBSHcomTLSPrivateKeyCallback &keyCB, UBSHcomTLSCertificationCallback &certCB, UBSHcomTLSCaCallback &caCB) + : OOBTCPClient(t, serverIpOrName, serverPort), mTlsCaCallback(caCB), mTlsCertCb(certCB), mTlsPrivateKeyCb(keyCB) + {} + + ~OOBSSLClient() override = default; + + /* for tcp */ + inline NResult Connect(OOBTCPConnection *&conn) override + { + if (mOobType == NET_OOB_TCP) { + return Connect(mServerIP, mServerPort, conn); + } else if (mOobType == NET_OOB_UDS) { + return Connect(mServerUdsName, conn); + } + + return NN_ERROR; + } + + NResult Connect(const std::string &ip, uint32_t port, OOBTCPConnection *&conn) override; + + NResult Connect(const std::string &udsName, OOBTCPConnection *&conn) override; + + inline void SetTlsOptions(UBSHcomNetDriverOptions options) + { + mCipherSuite = options.cipherSuite; + mTlsVersion = options.tlsVersion; + } + + inline void SetPSKCallback(const UBSHcomPskFindSessionCb &pskFindSessionCb, + const UBSHcomPskUseSessionCb &pskUseSessionCb) + { + mPskFindSessionCb = pskFindSessionCb; + mPskUseSessionCb = pskUseSessionCb; + } + +private: + UBSHcomNetCipherSuite mCipherSuite = AES_GCM_128; + UBSHcomTlsVersion mTlsVersion = TLS_1_3; + UBSHcomTLSCaCallback mTlsCaCallback = nullptr; + UBSHcomTLSCertificationCallback mTlsCertCb = nullptr; + UBSHcomTLSPrivateKeyCallback mTlsPrivateKeyCb = nullptr; + + UBSHcomPskFindSessionCb mPskFindSessionCb = nullptr; + UBSHcomPskUseSessionCb mPskUseSessionCb = nullptr; + OOBSSLConnection *mOobConn {}; +}; + +class TlsConnectCbTask : public ConnectCbTask { +public: + using NewConnectionHandler = std::function; + + TlsConnectCbTask(NewConnectionHandler cb, int fd, NetWorkerLBPtr workerLb) : ConnectCbTask(cb, fd, workerLb) {} + + ~TlsConnectCbTask() override + { + NetFunc::NN_SafeCloseFd(mFd); + } + + void SetTlsCb(UBSHcomTLSCertificationCallback certCb, UBSHcomTLSPrivateKeyCallback privateKeyCb, + UBSHcomTLSCaCallback caCb) + { + mTlsCertCb = certCb; + mTlsPrivateKeyCb = privateKeyCb; + mTlsCaCallback = caCb; + }; + + void SetTlsOptions(UBSHcomNetCipherSuite cipherSuite, UBSHcomTlsVersion tlsVersion) + { + mCipherSuite = cipherSuite; + mTlsVersion = tlsVersion; + } + + void SetPSKCallback(const UBSHcomPskFindSessionCb &pskFindSessionCb, const UBSHcomPskUseSessionCb &pskUseSessionCb) + { + mPskFindSessionCb = pskFindSessionCb; + mPskUseSessionCb = pskUseSessionCb; + } + + void Run() override; + +private: + UBSHcomTLSPrivateKeyCallback mTlsPrivateKeyCb = nullptr; + UBSHcomTLSCertificationCallback mTlsCertCb = nullptr; + UBSHcomTLSCaCallback mTlsCaCallback = nullptr; + UBSHcomNetCipherSuite mCipherSuite = AES_GCM_128; + UBSHcomTlsVersion mTlsVersion = TLS_1_3; + + UBSHcomPskFindSessionCb mPskFindSessionCb = nullptr; + UBSHcomPskUseSessionCb mPskUseSessionCb = nullptr; +}; +} +} + +#endif // OCK_HCOM_OOB_SSL_12334324233_H diff --git a/src/transport/rdma/rdma_common.h b/src/transport/rdma/rdma_common.h new file mode 100644 index 0000000000000000000000000000000000000000..2af5a2adaa643753ca8fa60c468cc866a4080038 --- /dev/null +++ b/src/transport/rdma/rdma_common.h @@ -0,0 +1,306 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_RDMA_COMMON_1234234341233_H +#define OCK_RDMA_COMMON_1234234341233_H +#ifdef RDMA_BUILD_ENABLED + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hcom_def.h" +#include "hcom_log.h" +#include "net_common.h" +#include "net_obj_pool.h" +#include "verbs_api_wrapper.h" + +namespace ock { +namespace hcom { +/* + * return type + */ +using RResult = int; + +enum RRCode { + RR_OK = 0, + RR_PARAM_INVALID = 200, + RR_MEMORY_ALLOCATE_FAILED = 201, + RR_NEW_OBJECT_FAILED = 202, + RR_OPEN_FILE_FAILED = 203, + RR_READ_FILE_FAILED = 204, + RR_DEVICE_FAILED_OPEN = 205, + RR_DEVICE_INDEX_OVERFLOW = 206, + RR_DEVICE_OPEN_FAILED = 207, + RR_DEVICE_FAILED_GET_IF_ADDRESS = 208, + RR_DEVICE_NO_IF_MATCHED = 209, + RR_DEVICE_NO_IF_TO_GID_MATCHED = 210, + RR_DEVICE_INVALID_IP_MASK = 211, + RR_MR_REG_FAILED = 212, + RR_CQ_NOT_INITIALIZED = 213, + RR_CQ_POLLING_FAILED = 214, + RR_CQ_POLLING_TIMEOUT = 215, + RR_CQ_POLLING_ERROR_RESULT = 216, + RR_CQ_POLLING_UNMATCHED_OPCODE = 217, + RR_CQ_EVENT_GET_FAILED = 218, + RR_CQ_EVENT_NOTIFY_FAILED = 219, + RR_CQ_WC_WRONG = 220, + RR_CQ_EVENT_GET_TIMOUT = 221, + RR_QP_CREATE_FAILED = 222, + RR_QP_NOT_INITIALIZED = 223, + RR_QP_CHANGE_STATE_FAILED = 224, + RR_QP_POST_RECEIVE_FAILED = 225, + RR_QP_POST_SEND_FAILED = 226, + RR_QP_POST_READ_FAILED = 227, + RR_QP_POST_WRITE_FAILED = 228, + RR_QP_RECEIVE_CONFIG_ERR = 229, + RR_QP_POST_SEND_WR_FULL = 230, + RR_QP_ONE_SIDE_WR_FULL = 231, + RR_QP_CTX_FULL = 232, + RR_QP_CHANGE_ERR = 233, + RR_OOB_LISTEN_SOCKET_ERROR = 234, + RR_OOB_CONN_SEND_ERROR = 235, + RR_OOB_CONN_RECEIVE_ERROR = 236, + RR_OOB_CONN_CB_NOT_SET = 237, + RR_OOB_CLIENT_SOCKET_ERROR = 238, + RR_OOB_SSL_INIT_ERROR = 239, + RR_OOB_SSL_WRITE_ERROR = 240, + RR_OOB_SSL_READ_ERROR = 241, + RR_EP_NOT_INITIALIZED = 242, + RR_WORKER_NOT_INITIALIZED = 243, + RR_WORKER_BIND_CPU_FAILED = 244, + RR_WORKER_REQUEST_HANDLER_NOT_SET = 245, + RR_WORKER_SEND_POSTED_HANDLER_NOT_SET = 246, + RR_WORKER_ONE_SIDE_DONE_HANDLER_NOT_SET = 247, + RR_WORKER_FAILED_ADD_QP = 248, + RR_HEARTBEAT_CREATE_EPOLL_FAILED = 249, + RR_HEARTBEAT_SET_SOCKET_OPT_FAILED = 250, + RR_HEARTBEAT_IP_ALREADY_EXISTED = 251, + RR_HEARTBEAT_IP_ADD_FAILED = 252, + RR_HEARTBEAT_IP_ADD_EPOLL_FAILED = 253, + RR_HEARTBEAT_IP_REMOVE_EPOLL_FAILED = 254, + RR_HEARTBEAT_IP_NO_FOUND = 255, + RR_WORKER_START_ERROR = 256, +}; + +// constant variable +constexpr uint32_t QP_MAX_SEND_WR = 256; +constexpr uint32_t QP_MAX_RECV_WR = 256; +constexpr uint32_t QP_MIN_RNR_TIMER = 12; +constexpr uint32_t QP_TIMEOUT = 14; +constexpr uint32_t QP_RETRY_COUNT = 7; +constexpr uint32_t QP_RNR_RETRY = 7; +constexpr uint32_t CQ_COUNT = 1024; + +const std::string RDMA_EMPTY_STRING; + +/* + * class forward declaration + */ + +class RDMAMemoryRegionFixedBuffer; + +// verbs wrappers +class RDMADeviceHelper; +class RDMAContext; +class RDMAQp; +class RDMACq; +class RDMAMemoryRegion; + +// logic part +class RDMAWorker; + +// oob for qp setup +class OOBTCPConnection; +class OOBTCPServer; +class OOBTCPClient; + +// the size of RDMAOpContextInfo is 64 bytes which fit to single CPU cache line +struct RDMAOpContextInfo { + enum OpType : uint8_t { + SEND = 0, + SEND_RAW = 1, + SEND_RAW_SGL = 2, + RECEIVE = 3, + RECEIVE_RAW = 4, + WRITE = 5, + READ = 6, + SGL_WRITE = 7, + SGL_READ = 8, + HB_WRITE = 9, + SEND_SGL_INLINE = 10, + }; + + enum OpResultType : uint8_t { + SUCCESS = 0, + ERR_TIMEOUT = 1, + ERR_CANCELED = 2, + ERR_IO_ERROR = 3, + ERR_EP_BROKEN = 4, + ERR_EP_CLOSE = 5, + + INVALID_MAGIC = 0xFF, + }; + + enum MrType : uint8_t { + MR = 2 + }; + + RDMAQp *qp = nullptr; /* pointer to qp */ + struct RDMAOpContextInfo *prev = nullptr; /* link to prev context */ + struct RDMAOpContextInfo *next = nullptr; /* link to next context */ + + union { + uintptr_t whole = 0; + struct { + /* low address */ + /* address of the buffer, the uintptr_t has 64 bits, only the low 48 bits would be used for address */ + uintptr_t mrMemAddr : 56; + /* high address */ + MrType mrType : 8; + }; + } __attribute__((packed)); + uint32_t lKey = 0; /* local key */ + uint32_t dataSize = 0; /* actual data size */ + uint32_t qpNum = 0; /* qp ID */ + OpType opType = RECEIVE; /* op type */ + OpResultType opResultType = OpResultType::SUCCESS; /* op result */ + uint16_t upCtxSize = 0; /* up context size stored in upCtx[] */ + char upCtx[NN_NO16] = {}; /* 16 bytes for upper context */ + + static inline OpResultType OpResult(struct ibv_wc &result) + { + // any status except success indicating the qp is abnormal + switch (result.status) { + case IBV_WC_SUCCESS: + return OpResultType::SUCCESS; + case IBV_WC_RETRY_EXC_ERR: + case IBV_WC_RNR_RETRY_EXC_ERR: + return OpResultType::ERR_TIMEOUT; + case IBV_WC_WR_FLUSH_ERR: + return OpResultType::ERR_CANCELED; + default: + return OpResultType::ERR_IO_ERROR; + } + } + + static inline NResult GetNResult(OpResultType opResult) + { + switch (opResult) { + case OpResultType::SUCCESS: + return NN_OK; + case OpResultType::ERR_TIMEOUT: + return NN_MSG_TIMEOUT; + case OpResultType::ERR_CANCELED: + return NN_MSG_CANCELED; + case OpResultType::ERR_EP_BROKEN: + return NN_EP_BROKEN; + case OpResultType::ERR_EP_CLOSE: + return NN_EP_CLOSE; + default: + return NN_MSG_ERROR; + } + } +} __attribute__((packed)); + +struct RDMASglContextInfo { + RDMAQp *qp = nullptr; // the qp pointer which posted from + UBSHcomNetTransSgeIov iov[NET_SGE_MAX_IOV] = {}; + NResult result = NN_OK; + uint32_t reserve1 = 0; + uint16_t refCount = 0; // equal to iovCount + uint16_t iovCount = 0; // max count:NN_NO16 + uint16_t upCtxSize = 0; + uint16_t reserve2 = 0; + char upCtx[NN_NO16] = {}; // 16 bytes for upper context +} __attribute__((packed)); + +struct RDMASgeCtxInfo { + RDMASglContextInfo *ctx = nullptr; + uint16_t idx = 0; + + RDMASgeCtxInfo() = default; + explicit RDMASgeCtxInfo(RDMASglContextInfo *sglCtx) : ctx(sglCtx) {} +} __attribute__((packed)); + +enum RDMAPollingMode : uint8_t { + BUSY_POLLING = 0, + EVENT_POLLING = 1, +}; + +struct QpOptions { + uint32_t maxSendWr = QP_MAX_SEND_WR; + uint32_t maxReceiveWr = QP_MAX_RECV_WR; + uint32_t mrSegSize = NN_NO1024; + uint32_t mrSegCount = NN_NO64; + + QpOptions() = default; + + QpOptions(uint32_t maxSendWrNum, uint32_t maxReceiveWrNum, uint32_t segSize, uint32_t segCount) + : maxSendWr(maxSendWrNum), maxReceiveWr(maxReceiveWrNum), mrSegSize(segSize), mrSegCount(segCount) + {} +} __attribute__((packed)); + +inline RResult ReadRoCEVersionFromFile(const std::string &deviceName, uint32_t portNumber, uint32_t gid, + std::string &version) +{ + std::ostringstream oSStream; + char filePath[PATH_MAX] = {0}; + oSStream << "/sys/class/infiniband/" << deviceName.c_str() << "/ports/" << portNumber << "/gid_attrs/types/" << gid; + if (oSStream.str().length() > IBV_SYSFS_PATH_MAX || nullptr == realpath(oSStream.str().c_str(), filePath)) { + NN_LOG_ERROR("The file name is incorrect"); + return RR_PARAM_INVALID; + } + + char fileContent[64] = {0}; + int fd = open(filePath, O_RDONLY); + if (fd < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to open file " << oSStream.str() << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return RR_OPEN_FILE_FAILED; + } + + auto len = read(fd, fileContent, 15); + if (len < 0) { + close(fd); + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to read content file " << oSStream.str() << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return RR_READ_FILE_FAILED; + } + + if (len > 1 && fileContent[len - 1] == '\n') { + version = std::string(fileContent, len - 1); + } else { + version = std::string(fileContent, len); + } + + close(fd); + return RR_OK; +} +} +} +#endif +#endif // OCK_RDMA_COMMON_1234234341233_H diff --git a/src/transport/rdma/rdma_heartbeat.cpp b/src/transport/rdma/rdma_heartbeat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ae6b78189d573d7ec3f105fe720dd73de3cf098c --- /dev/null +++ b/src/transport/rdma/rdma_heartbeat.cpp @@ -0,0 +1,311 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include + +#include "rdma_heartbeat.h" +namespace ock { +namespace hcom { +constexpr uint32_t MAX_EPOLL_SIZE = 4096 * 4; // 4096 hosts, 4 card per host +constexpr uint32_t MAX_EPOLL_WAIT_EVENTS = 16; +constexpr uint32_t EPOLL_WAIT_TIMEOUT = 1000; // 1 second + +RIPDeviceHeartbeatManager::RIPDeviceHeartbeatManager(const std::string &name) : mName(name) {} + +NResult RIPDeviceHeartbeatManager::Initialize() +{ + if (mEpollHandle > 0) { + return NN_OK; + } + + if (mConnBrokenCheckHandler == nullptr) { + NN_LOG_ERROR("ConnBrokenCheckHandler is not set in RIPDeviceHeartbeatManager " << mName); + return NN_PARAM_INVALID; + } + + if (mConnBrokenPostHandler == nullptr) { + NN_LOG_ERROR("ConnBrokenPostHandler is not set in RIPDeviceHeartbeatManager " << mName); + return NN_PARAM_INVALID; + } + + int epollHandle = epoll_create(MAX_EPOLL_SIZE); + if (epollHandle < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create epoll in RIPDeviceHeartbeatManager " << mName << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_HEARTBEAT_CREATE_EPOLL_FAILED; + } + + mEpollHandle = epollHandle; + mStarted.store(false); + return NN_OK; +} + +void RIPDeviceHeartbeatManager::UnInitialize() +{ + if (mEpollHandle == -1) { + return; + } + + { + std::lock_guard guard(mMutex); + mIpFdMap.clear(); + mFdIpMap.clear(); + } + close(mEpollHandle); + mEpollHandle = -1; +} + +NResult RIPDeviceHeartbeatManager::Start() +{ + std::lock_guard guard(mMutex); + if (mStarted.load()) { + NN_LOG_INFO("RIPDeviceHeartbeatManager " << mName << " already started"); + return NN_OK; + } + + std::thread tmpThread(&RIPDeviceHeartbeatManager::RunInThread, this); + mWorkingThread = std::move(tmpThread); + std::string threadName = "IpHeartbeat"; + if (pthread_setname_np(mWorkingThread.native_handle(), threadName.c_str()) != 0) { + NN_LOG_WARN("Failed to set name of RIPDeviceHeartbeatManager working thread"); + } + + while (!mStarted.load()) { + usleep(NN_NO10); + } + + return NN_OK; +} + +void RIPDeviceHeartbeatManager::Stop() +{ + mNeedStop = true; + if (mWorkingThread.native_handle()) { + mWorkingThread.join(); + } +} + +void RIPDeviceHeartbeatManager::RunInThread() +{ + mStarted.store(true); + NN_LOG_INFO("RIPDeviceHeartbeatManager " << mName << " working thread started"); + struct epoll_event ev[MAX_EPOLL_WAIT_EVENTS]; + while (!mNeedStop) { + try { + // do epoll wait + int count = epoll_wait(mEpollHandle, ev, MAX_EPOLL_WAIT_EVENTS, EPOLL_WAIT_TIMEOUT); + if (count <= 0) { + continue; + } + + HandleEpollEvent(count, ev); + } catch (std::runtime_error &ex) { + NN_LOG_WARN("Got runtime error in RIPDeviceHeartbeatManager::RunInThread '" << ex.what() << + "', ignore and continue"); + } catch (...) { + NN_LOG_WARN("Got unknown error in RIPDeviceHeartbeatManager::RunInThread, ignore and continue"); + } + } + NN_LOG_INFO("RIPDeviceHeartbeatManager " << mName << " working thread exiting"); +} + +void RIPDeviceHeartbeatManager::HandleEpollEvent(uint32_t eventCount, struct epoll_event *events) +{ + if (events == nullptr) { + return; + } + + std::unordered_set fds; + fds.reserve(eventCount); + for (uint32_t i = 0; i < eventCount; i++) { + if (!(events[i].events & EPOLLIN)) { + continue; + } + + try { + if (!mConnBrokenCheckHandler(events[i].data.fd)) { + fds.emplace(static_cast(events[i].data.fd)); + } + } catch (std::runtime_error &ex) { + NN_LOG_WARN("Got runtime error in mConnBrokenCheckHandler " << ex.what() << ", ignored"); + } catch (...) { + NN_LOG_WARN("Got unknown error in mConnBrokenCheckHandler , ignored"); + } + } + + if (fds.empty()) { + return; + } + + // remove related fd and ip from maps + for (auto item : fds) { + RemoveByFD(item); + } + + // call post handler function + for (auto item : fds) { + try { + mConnBrokenPostHandler(item); + } catch (std::runtime_error &ex) { + NN_LOG_WARN("Got runtime error in mConnBrokenPostHandler " << ex.what() << ", ignored"); + } catch (...) { + NN_LOG_WARN("Got unknown error in mConnBrokenPostHandler , ignored"); + } + } +} + +NResult RIPDeviceHeartbeatManager::AddNewIP(const std::string &ip, int fd) +{ + if (fd < 0) { + NN_LOG_ERROR("Failed to add new IP and fd to RIPDeviceHeartbeatManager " << mName << " as fd is invalid"); + return NN_PARAM_INVALID; + } + + // set keep alive params + int value = 1; + RKeepaliveConfig tmpConfig = mKeepaliveConfig; + size_t optSize = sizeof(tmpConfig.probeTimes); + if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &value, sizeof(value)) < 0 || + setsockopt(fd, IPPROTO_TCP, TCP_KEEPIDLE, &tmpConfig.idleTime, optSize) < 0 || + setsockopt(fd, IPPROTO_TCP, TCP_KEEPINTVL, &tmpConfig.probeInterval, optSize) < 0 || + setsockopt(fd, IPPROTO_TCP, TCP_KEEPCNT, &tmpConfig.probeTimes, optSize) < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to set keepalive option for " << ip << "-" << fd << " in RIPDeviceHeartbeatManager " << + mName << ", errno:" << errno << " error:" << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_HEARTBEAT_SET_SOCKET_OPT_FAILED; + } + + { + std::lock_guard guard(mMutex); + auto iter = mIpFdMap.find(ip); + if (iter != mIpFdMap.end()) { + NN_LOG_ERROR("Failed to add " << ip << " into RIPDeviceHeartbeatManager " << mName << + " as already existed, remove it firstly."); + return NN_HEARTBEAT_IP_ALREADY_EXISTED; + } + + struct epoll_event ev {}; + ev.events = EPOLLIN; + ev.data.fd = fd; + if (epoll_ctl(mEpollHandle, EPOLL_CTL_ADD, fd, &ev) != 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to add " << ip << " into RIPDeviceHeartbeatManager " << mName << + " as epoll add failed, error " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_HEARTBEAT_IP_ADD_EPOLL_FAILED; + } + + if (!mIpFdMap.emplace(ip, fd).second || !mFdIpMap.emplace(fd, ip).second) { + NN_LOG_ERROR("Failed to add " << ip << " into RIPDeviceHeartbeatManager " << mName); + return NN_HEARTBEAT_IP_ADD_FAILED; + } + } + + return NN_OK; +} + +NResult RIPDeviceHeartbeatManager::GetFdByIP(const std::string &ip, int &fd) +{ + std::lock_guard guard(mMutex); + auto iter = mIpFdMap.find(ip); + if (iter == mIpFdMap.end()) { + NN_LOG_ERROR("No ip " << ip << " found from RIPDeviceHeartbeatManager " << mName); + return NN_HEARTBEAT_IP_NO_FOUND; + } + + fd = iter->second; + return NN_OK; +} + +NResult RIPDeviceHeartbeatManager::RemoveIP(const std::string &ip) +{ + int fd = -1; + { + std::lock_guard guard(mMutex); + auto iter = mIpFdMap.find(ip); + if (iter == mIpFdMap.end()) { + NN_LOG_ERROR("No ip " << ip << " found from RIPDeviceHeartbeatManager " << mName); + return NN_HEARTBEAT_IP_NO_FOUND; + } + + fd = iter->second; + mIpFdMap.erase(iter); + mFdIpMap.erase(fd); + } + + if (epoll_ctl(mEpollHandle, EPOLL_CTL_DEL, fd, nullptr) != 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to delete from epoll handle for " << ip << "-" << fd << " in RIPDeviceHeartbeatManager " << + mName << ", error " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_HEARTBEAT_IP_REMOVE_EPOLL_FAILED; + } + return NN_OK; +} + +NResult RIPDeviceHeartbeatManager::RemoveByFD(int fd) +{ + std::string ip; + { + std::lock_guard guard(mMutex); + auto iter = mFdIpMap.find(fd); + if (iter == mFdIpMap.end()) { + NN_LOG_ERROR("No fd " << fd << " found from RIPDeviceHeartbeatManager " << mName); + return NN_HEARTBEAT_IP_NO_FOUND; + } + + ip = iter->second; + mFdIpMap.erase(iter); + mIpFdMap.erase(ip); + } + + if (epoll_ctl(mEpollHandle, EPOLL_CTL_DEL, fd, nullptr) != 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to delete from epoll handle for " << ip << "-" << fd << " in RIPDeviceHeartbeatManager " << + mName << ", error " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_HEARTBEAT_IP_REMOVE_EPOLL_FAILED; + } + + return NN_OK; +} + +bool RIPDeviceHeartbeatManager::DefaultConnBrokenCheckCB(int fd) +{ + char data[1]; + auto result = recv(fd, data, 1, MSG_DONTWAIT); + if (result < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + // connection is still ok + return true; + } + + // connection is wrong + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_INFO("DefaultConnBrokenCheckCB connection is wrong, fd " << fd << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return false; + } else if (result == 0) { + NN_LOG_INFO("DefaultConnBrokenCheckCB connection is broken, fd " << fd); + return false; // connection really broken + } else { + return true; + } +} + +void RIPDeviceHeartbeatManager::DefaultConnBrokenPostCB(int fd) +{ + NN_LOG_INFO("DefaultConnBrokenPostCB close fd"); + close(fd); +} +} +} \ No newline at end of file diff --git a/src/transport/rdma/rdma_heartbeat.h b/src/transport/rdma/rdma_heartbeat.h new file mode 100644 index 0000000000000000000000000000000000000000..2464603e07b8fae865b5c52e4efb2e92c9278bc1 --- /dev/null +++ b/src/transport/rdma/rdma_heartbeat.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_RDMA_HEARTBEAT_1245609845341233_H +#define OCK_RDMA_HEARTBEAT_1245609845341233_H + +#include +#include +#include "common/net_common.h" + +namespace ock { +namespace hcom { +using RKeepaliveConfig = struct RKeepaliveConfigStruct { + uint32_t idleTime = 5; // idle 5 seconds to start to probe + uint32_t probeTimes = 7; // probe times + uint32_t probeInterval = 2; // probe interval +} __attribute__((packed)); + +using RIPConnBrokenCheckHandler = std::function; +using RIPConnBrokenPostHandler = std::function; + +class RIPDeviceHeartbeatManager { +public: + explicit RIPDeviceHeartbeatManager(const std::string &name); + ~RIPDeviceHeartbeatManager() + { + UnInitialize(); + } + + inline void SetKeepaliveConfig(uint32_t idleTime, uint32_t probeTimes, uint32_t probeInterval) + { + mKeepaliveConfig.idleTime = idleTime; + mKeepaliveConfig.probeTimes = probeTimes; + mKeepaliveConfig.probeInterval = probeInterval; + } + + inline void SetConnBrokenCheckHandler(const RIPConnBrokenCheckHandler &value) + { + mConnBrokenCheckHandler = value; + } + + inline void SetConnBrokenPostHandler(const RIPConnBrokenPostHandler &value) + { + mConnBrokenPostHandler = value; + } + + NResult Initialize(); + + void UnInitialize(); + + NResult Start(); + void Stop(); + + NResult AddNewIP(const std::string &ip, int fd); + NResult GetFdByIP(const std::string &ip, int &fd); + NResult RemoveIP(const std::string &ip); + NResult RemoveByFD(int fd); + + static bool DefaultConnBrokenCheckCB(int fd); + static void DefaultConnBrokenPostCB(int fd); + +private: + void RunInThread(); + void HandleEpollEvent(uint32_t eventCount, struct epoll_event *events); + +private: + std::string mName; + std::map mIpFdMap; + std::map mFdIpMap; + std::mutex mMutex; + RKeepaliveConfig mKeepaliveConfig; + RIPConnBrokenCheckHandler mConnBrokenCheckHandler; + RIPConnBrokenPostHandler mConnBrokenPostHandler; + + int mEpollHandle = -1; + + std::thread mWorkingThread; + std::atomic mStarted; + bool mNeedStop = false; +}; +} +} + +#endif // OCK_RDMA_HEARTBEAT_1245609845341233_H diff --git a/src/transport/rdma/rdma_mr_dm_buf.cpp b/src/transport/rdma/rdma_mr_dm_buf.cpp new file mode 100644 index 0000000000000000000000000000000000000000..80ca5a0783397474334a3ff97a8a9cf6bd47694a --- /dev/null +++ b/src/transport/rdma/rdma_mr_dm_buf.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#ifdef RDMA_CX5_BUILD_ENABLED + +#include "rdma_mr_dm_buf.h" + +namespace ock { +namespace hcom { + +RResult RDMAMemoryRegionDmBuffer::Create(const std::string &name, RDMAContext *ctx, uint32_t singleSegSize, + uint32_t segCount, RDMAMemoryRegionDmBuffer *&buf) +{ + auto tmp = new (std::nothrow) RDMAMemoryRegionDmBuffer(name, ctx, singleSegSize, segCount); + if (tmp == nullptr) { + NN_LOG_ERROR("Failed to create rdma mr dm buffer"); + return RR_NEW_OBJECT_FAILED; + } + buf = tmp; + return RR_OK; +} + +RResult RDMAMemoryRegionDmBuffer::Initialize() +{ + RResult result = RR_OK; + if ((result = RDMAMemoryRegion::InitializeForDm()) != RR_OK) { + return result; + } + // init un-allocated + mBuf = reinterpret_cast(memalign(PAGE_ALIGN_H, sizeof(RDMAMemoryRegionDMMgr) * mSegCount)); + if (mBuf == 0) { + NN_LOG_ERROR("Failed to allocate memory for RDMAMemoryRegionDmBuffer " << mName); + return RR_MEMORY_ALLOCATE_FAILED; + } + uintptr_t address = mBuf; + for (uint32_t i = 0; i < mSegCount; i++) { + auto tmpDm = reinterpret_cast(address); + tmpDm->offset = mSingleSegSize * i; + mLinkList.PushFront(address); + address += sizeof(RDMAMemoryRegionDMMgr); + } + + return RR_OK; +} + +void RDMAMemoryRegionDmBuffer::UnInitialize() +{ + RDMAMemoryRegion::UnInitializeForDm(); +} +} +} +#endif +#endif \ No newline at end of file diff --git a/src/transport/rdma/rdma_mr_dm_buf.h b/src/transport/rdma/rdma_mr_dm_buf.h new file mode 100644 index 0000000000000000000000000000000000000000..b72d7d2f349202794a011fefcf5bd42f6f91b57c --- /dev/null +++ b/src/transport/rdma/rdma_mr_dm_buf.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_RDMA_MR_DM_BUF_H +#define HCOM_RDMA_MR_DM_BUF_H +#ifdef RDMA_BUILD_ENABLED +#ifdef RDMA_CX5_BUILD_ENABLED + +#include "rdma_mr_pool.h" + +namespace ock { +namespace hcom { + +struct RDMAMemoryRegionDMMgr { + uintptr_t next; + uint64_t offset; +}; + +class RDMAMemoryRegionDmBuffer : public RDMAMemoryRegion { +public: + static RResult Create(const std::string &name, RDMAContext *ctx, uint32_t singleSegSize, uint32_t segCount, + RDMAMemoryRegionDmBuffer *&buf); + +public: + RDMAMemoryRegionDmBuffer(const std::string &name, RDMAContext *ctx, uint32_t singleSegSize, uint32_t segCount) + : RDMAMemoryRegion(name, ctx, static_cast(singleSegSize) * static_cast(segCount)), + mSingleSegSize(singleSegSize), + mSegCount(segCount) + { + OBJ_GC_INCREASE(RDMAMemoryRegionDmBuffer); + } + + ~RDMAMemoryRegionDmBuffer() override + { + UnInitialize(); + OBJ_GC_DECREASE(RDMAMemoryRegionDmBuffer); + } + + RResult Initialize() override; + + inline bool GetFreeBuffer(uintptr_t &item) + { + return mLinkList.Pop(item); + } + + inline bool ReturnBuffer(uintptr_t value) + { + mLinkList.PushFront(value); + return true; + } + +protected: + void UnInitialize() override; + +private: + uint32_t mSingleSegSize = MR_DM_BUFFER_DEFAULT_SEG_SIZE; + uint32_t mSegCount = MR_DM_BUFFER_DEFAULT_SEG_COUNT; + + // uintptr_p store the start address of each mr segment + NetBucketLinkedList mLinkList; +}; +} +} +#endif +#endif +#endif // HCOM_RDMA_MR_DM_BUF_H \ No newline at end of file diff --git a/src/transport/rdma/rdma_mr_fixed_buf.cpp b/src/transport/rdma/rdma_mr_fixed_buf.cpp new file mode 100644 index 0000000000000000000000000000000000000000..893160c36fb01c584d99fcfc4c49a0f339b98e4c --- /dev/null +++ b/src/transport/rdma/rdma_mr_fixed_buf.cpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED + +#include "rdma_mr_fixed_buf.h" + +namespace ock { +namespace hcom { + +RResult RDMAMemoryRegionFixedBuffer::Create(const std::string &name, RDMAContext *ctx, uint32_t singleSegSize, + uint32_t segCount, RDMAMemoryRegionFixedBuffer *&buf) +{ + auto tmp = new (std::nothrow) RDMAMemoryRegionFixedBuffer(name, ctx, singleSegSize, segCount); + if (tmp == nullptr) { + NN_LOG_ERROR("Failed to create rdma mr fixed buffer"); + return RR_NEW_OBJECT_FAILED; + } + buf = tmp; + return RR_OK; +} + +RResult RDMAMemoryRegionFixedBuffer::Initialize() +{ + RResult result = RR_OK; + if ((result = RDMAMemoryRegion::Initialize()) != RR_OK) { + NN_LOG_ERROR("Failed to initialize rdma mr res = " << result); + return result; + } + + // init un-allocated + uintptr_t address = mBuf; + for (uint32_t i = 0; i < mSegCount; i++) { + mLinkList.PushFront(address); + address += mSingleSegSize; + } + + return RR_OK; +} + +void RDMAMemoryRegionFixedBuffer::UnInitialize() +{ + RDMAMemoryRegion::UnInitialize(); +} +} +} +#endif \ No newline at end of file diff --git a/src/transport/rdma/rdma_mr_fixed_buf.h b/src/transport/rdma/rdma_mr_fixed_buf.h new file mode 100644 index 0000000000000000000000000000000000000000..16e6dc7899afaebdb567e6ab9cdc6306ecca9055 --- /dev/null +++ b/src/transport/rdma/rdma_mr_fixed_buf.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_RDMA_MR_FIXED_BUF_H +#define HCOM_RDMA_MR_FIXED_BUF_H +#ifdef RDMA_BUILD_ENABLED + +#include "rdma_mr_pool.h" + +namespace ock { +namespace hcom { + +class RDMAMemoryRegionFixedBuffer : public RDMAMemoryRegion { +public: + static RResult Create(const std::string &name, RDMAContext *ctx, uint32_t singleSegSize, uint32_t segCount, + RDMAMemoryRegionFixedBuffer *&buf); + +public: + RDMAMemoryRegionFixedBuffer(const std::string &name, RDMAContext *ctx, uint32_t singleSegSize, uint32_t segCount) + : RDMAMemoryRegion(name, ctx, static_cast(singleSegSize) * static_cast(segCount)), + mSingleSegSize(singleSegSize), + mSegCount(segCount) + { + OBJ_GC_INCREASE(RDMAMemoryRegionFixedBuffer); + } + + ~RDMAMemoryRegionFixedBuffer() override + { + UnInitialize(); + OBJ_GC_DECREASE(RDMAMemoryRegionFixedBuffer); + } + + RResult Initialize() override; + + inline bool GetFreeBuffer(uintptr_t &item) + { + return mLinkList.Pop(item); + } + + inline bool GetFreeBufferN(uintptr_t *&items, uint32_t n) + { + return mLinkList.PopN(items, n); + } + + inline bool ReturnBuffer(uintptr_t value) + { + mLinkList.PushFront(value); + return true; + } + + std::string ToString() + { + std::ostringstream oss; + oss << "buf-address " << mBuf << ", mSingleSegSize " << mSingleSegSize << ", mSegCount " << mSegCount << + ", total buf size " << mSize; + if (mMemReg != nullptr) { + oss << ", mrLKey " << mMemReg->lkey << ", mrRKey " << mMemReg->rkey << ", mrSize " << mMemReg->length; + } + return oss.str(); + } + + inline uint32_t GetSingleSegSize() const + { + return mSingleSegSize; + } + +protected: + void UnInitialize() override; + +private: + uint32_t mSingleSegSize = MR_FIXED_POOL_DEFAULT_SEG_SIZE; + uint32_t mSegCount = MR_FIXED_POOL_DEFAULT_SEG_COUNT; + + // uintptr_p store the start address of each mr segment + NetBucketLinkedList mLinkList; +}; +} +} +#endif +#endif // HCOM_RDMA_MR_FIXED_BUF_H \ No newline at end of file diff --git a/src/transport/rdma/rdma_mr_pool.cpp b/src/transport/rdma/rdma_mr_pool.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c53cc7b03284c9e44ab825ab32fed01f6c02400 --- /dev/null +++ b/src/transport/rdma/rdma_mr_pool.cpp @@ -0,0 +1,135 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#include + +#include "verbs_api_wrapper.h" +#include "rdma_mr_pool.h" + +namespace ock { +namespace hcom { +uint64_t RDMAMemoryRegion::gPageSize = sysconf(_SC_PAGESIZE); + +RResult RDMAMemoryRegion::Create(const std::string &name, RDMAContext *ctx, uint64_t size, RDMAMemoryRegion *&buf) +{ + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Create rdma mr param invalid"); + return RR_PARAM_INVALID; + } + + auto tmpBuf = new (std::nothrow) RDMAMemoryRegion(name, ctx, size); + if ((NN_UNLIKELY(tmpBuf == nullptr))) { + NN_LOG_ERROR("Failed to create rdma mr"); + return RR_NEW_OBJECT_FAILED; + } + + buf = tmpBuf; + + return RR_OK; +} + +RResult RDMAMemoryRegion::Create(const std::string &name, RDMAContext *ctx, uintptr_t address, uint64_t size, + RDMAMemoryRegion *&buf) +{ + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Create rdma mr param invalid"); + return RR_PARAM_INVALID; + } + + auto tmpBuf = new (std::nothrow) RDMAMemoryRegion(name, ctx, address, size); + if ((NN_UNLIKELY(tmpBuf == nullptr))) { + NN_LOG_ERROR("Failed to create rdma mr"); + return RR_NEW_OBJECT_FAILED; + } + + buf = tmpBuf; + + return RR_OK; +} + +RResult RDMAMemoryRegion::Initialize() +{ + if (mMemReg != nullptr) { + return RR_OK; + } + + if (mRDMAContext == nullptr || mRDMAContext->mProtectDomain == nullptr) { + NN_LOG_ERROR("Failed to initialize RDMAMemoryRegion as rdma context or pd is null"); + return RR_PARAM_INVALID; + } + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + ibv_mr *tmpMR = nullptr; + if (mExternalMemory) { + // the memory is allocated externally + // register mr directly + auto tmpBuf = reinterpret_cast(mBuf); + tmpMR = HcomIbv::RegMr(mRDMAContext->mProtectDomain, tmpBuf, mSize, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE); + if (tmpMR == nullptr) { + NN_LOG_ERROR("Failed to register external memory for RDMAMemoryRegion " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE) << ", buffer " << tmpBuf); + return RR_MR_REG_FAILED; + } + } else { + // allocate memory + if (gPageSize <= 0) { + NN_LOG_ERROR("Failed to get system page size, page size: " << gPageSize); + return RR_PARAM_INVALID; + } + auto tmpBuf = memalign(gPageSize, mSize); + if (tmpBuf == nullptr) { + NN_LOG_ERROR("Failed to allocate memory for RDMAMemoryRegion " << mName << " with size " << mSize); + return RR_MEMORY_ALLOCATE_FAILED; + } + + // register memory region to card + tmpMR = HcomIbv::RegMr(mRDMAContext->mProtectDomain, tmpBuf, mSize, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE); + if (tmpMR == nullptr) { + free(tmpBuf); + tmpBuf = nullptr; + NN_LOG_ERROR("Failed to register memory for RDMAMemoryRegion " << mName << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return RR_MR_REG_FAILED; + } + + mBuf = reinterpret_cast(tmpBuf); + } + + mMemReg = tmpMR; + mLKey = mMemReg->lkey; + + return RR_OK; +} + +void RDMAMemoryRegion::UnInitialize() +{ + if (mMemReg == nullptr) { + return; + } + + HcomIbv::DeregMr(mMemReg); + if (!mExternalMemory) { + if (mBuf != 0) { + free(reinterpret_cast(mBuf)); + } + } + mRDMAContext->DecreaseRef(); + + mMemReg = nullptr; + mBuf = 0; + mRDMAContext = nullptr; +} + +} +} +#endif \ No newline at end of file diff --git a/src/transport/rdma/rdma_mr_pool.h b/src/transport/rdma/rdma_mr_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..48a771df95821c1386be08ffa2a65adca7084b06 --- /dev/null +++ b/src/transport/rdma/rdma_mr_pool.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_RDMA_MR_12342341232433_H +#define OCK_RDMA_MR_12342341232433_H +#ifdef RDMA_BUILD_ENABLED + +#include "hcom.h" + +#include "net_bucket_linked_list.h" +#include "rdma_verbs_wrapper_ctx.h" +#include "net_util.h" + +namespace ock { +namespace hcom { +class RDMAMemoryRegion : public UBSHcomNetMemoryRegion { +public: + static RResult Create(const std::string &name, RDMAContext *ctx, uint64_t size, RDMAMemoryRegion *&buf); + static RResult Create(const std::string &name, RDMAContext *ctx, uintptr_t address, uint64_t size, + RDMAMemoryRegion *&buf); + + RDMAMemoryRegion() = delete; + RDMAMemoryRegion(const RDMAMemoryRegion &other) = delete; + RDMAMemoryRegion(RDMAMemoryRegion &&other) = delete; + RDMAMemoryRegion &operator = (const RDMAMemoryRegion &) = delete; + RDMAMemoryRegion &operator = (RDMAMemoryRegion &&) = delete; + + ~RDMAMemoryRegion() override + { + OBJ_GC_DECREASE(RDMAMemoryRegion); + } + + RResult Initialize() override; + void UnInitialize() override; + + void *GetMemorySeg() override + { + return nullptr; + } + + void GetVa(uint64_t &va, uint64_t &va_len, uint32_t &token_id) override + { + return; + } + +public: + RDMAContext *mRDMAContext = nullptr; + +protected: + RDMAMemoryRegion(const std::string &name, RDMAContext *ctx, uint64_t size) + : UBSHcomNetMemoryRegion(name, false, 0, size), mRDMAContext(ctx) + { + // increase the reference count of context + if (ctx != nullptr) { + ctx->IncreaseRef(); + } + + OBJ_GC_INCREASE(RDMAMemoryRegion); + } + + RDMAMemoryRegion(const std::string &name, RDMAContext *ctx, uintptr_t address, uint64_t size) + : UBSHcomNetMemoryRegion(name, true, address, size), mRDMAContext(ctx) + { + // increase the reference count of context + if (ctx != nullptr) { + ctx->IncreaseRef(); + } + + OBJ_GC_INCREASE(RDMAMemoryRegion); + } + +protected: + ibv_mr *mMemReg = nullptr; + + static uint64_t gPageSize; +}; +} +} +#endif +#endif // _OCK_RDMA_MR_12342341232433_H diff --git a/src/transport/rdma/verbs/net_rdma_async_endpoint.cpp b/src/transport/rdma/verbs/net_rdma_async_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8e5780262e7dc5ccf5411b44db3e0c4eca8b2e10 --- /dev/null +++ b/src/transport/rdma/verbs/net_rdma_async_endpoint.cpp @@ -0,0 +1,591 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#include "net_common.h" +#include "net_rdma_driver_oob.h" +#include "net_security_rand.h" +#include "rdma_validation.h" +#include "net_rdma_async_endpoint.h" + +namespace ock { +namespace hcom { +NetAsyncEndpoint::NetAsyncEndpoint(uint64_t id, RDMAAsyncEndPoint *ep, NetDriverRDMAWithOob *driver, + const UBSHcomNetWorkerIndex &workerIndex) + : NetEndpointImpl(id, workerIndex), mEp(ep), mDriver(driver) +{ + if (mDriver != nullptr) { + mDriver->IncreaseRef(); + } + + if (mEp != nullptr) { + mEp->IncreaseRef(); + } + + if (mEp != nullptr && mDriver != nullptr) { + mSegSize = mDriver->mOptions.mrSendReceiveSegSize < mEp->Qp()->PostSendMaxSize() ? + mDriver->mOptions.mrSendReceiveSegSize : + mEp->Qp()->PostSendMaxSize(); + mAllowedSize = mSegSize - sizeof(UBSHcomNetTransHeader); + } + + mIsNeedSendHb = true; + if (mDriver != nullptr) { + mHeartBeatIdleTime = mDriver->GetHbIdleTime(); + UpdateTargetHbTime(); + } + + OBJ_GC_INCREASE(NetAsyncEndpoint); +} + +NetAsyncEndpoint::~NetAsyncEndpoint() +{ + if (mEp != nullptr) { + mEp->DecreaseRef(); + mEp = nullptr; + } + + if (mDriver != nullptr) { + mDriver->DecreaseRef(); + mDriver = nullptr; + } + + OBJ_GC_DECREASE(NetAsyncEndpoint); +} + +uint32_t NetAsyncEndpoint::GetSendQueueCount() +{ + return mEp->Qp()->GetSendQueueSize(); +} + +NResult NetAsyncEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNO) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendValidation(mState, mId, mDriver, opCode, request, mAllowedSize, + mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to async post send as validate fail"); + return result; + } + + // get mr from pool + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to async post send with seqNo as failed to get mr buffer from pool"); + return NN_GET_BUFF_FAILED; + } + + auto *header = reinterpret_cast(mrBufAddress); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->opCode = opCode; + header->seqNo = seqNO == 0 ? NextSeq() : seqNO; + header->flags = NTH_TWO_SIDE; + + if (mIsNeedEncrypt) { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, reinterpret_cast(request.lAddress), request.size, + reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), cipherLen)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("RDMA Failed to async post send with seq no as encryption failure"); + return NN_ENCRYPT_FAILED; + } + header->dataLength = cipherLen; + } else { + header->dataLength = request.size; + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), + mDriver->mDriverSendMR->GetSingleSegSize() - sizeof(UBSHcomNetTransHeader), + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("RDMA Failed to copy request to mrBufAddress"); + return NN_INVALID_PARAM; + } + } + + /* finally fill header crc */ + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + + // change lAddress to mrAddress and set lKey + auto worker = reinterpret_cast(mEp->Qp()->UpContext1()); + + UBSHcomNetTransRequest rdmaReq = request; + rdmaReq.lAddress = mrBufAddress; + rdmaReq.lKey = mDriver->mDriverSendMR->GetLKey(); + rdmaReq.size = sizeof(UBSHcomNetTransHeader) + header->dataLength; + + auto sendFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_ASYNC_POST_SEND); + do { + result = worker->PostSend(mEp->Qp(), rdmaReq); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + sendFlag = false; + } while (sendFlag); + + NN_LOG_ERROR("Failed to async post send with seq no, result " << result); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_SEND, result); + return result; +} + +NResult NetAsyncEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo, const UBSHcomExtHeaderType extHeaderType, const void *extHeader, + uint32_t extHeaderSize) +{ + if (NN_UNLIKELY(extHeaderType == UBSHcomExtHeaderType::RAW)) { + NN_LOG_ERROR("Shouldn't use RAW type when extHeader is given."); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(!extHeader)) { + NN_LOG_ERROR("The ExtHeader is invalid."); + return NN_INVALID_PARAM; + } + + // 保证 extHeaderSize + request.size <= mAllowedSize. + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendValidation(mState, mId, mDriver, opCode, request, mAllowedSize - extHeaderSize, + mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to async post send as validate fail"); + return result; + } + + // get mr from pool + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("RDMA failed to async post send with op info as failed to get mr buffer from pool"); + return NN_GET_BUFF_FAILED; + } + + auto *header = reinterpret_cast(mrBufAddress); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->opCode = opCode; + header->seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + header->flags = ((uint16_t)opInfo.flags << NN_NO8) | (uint64_t)NTH_TWO_SIDE; + header->timeout = opInfo.timeout; + header->errorCode = opInfo.errorCode; + header->dataLength = request.size + extHeaderSize; + header->extHeaderType = extHeaderType; + + if (mIsNeedEncrypt) { + NN_LOG_WARN("postsend encrypt is not supported now!"); + } + + // 拷贝上层指定的 header,此时将要发送的结构为 + // | UBSHcomNetTransHeader | extHeader | request body | + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), + mDriver->mDriverSendMR->GetSingleSegSize() - sizeof(UBSHcomNetTransHeader), extHeader, + extHeaderSize) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to mrBufAddress"); + return NN_INVALID_PARAM; + } + + // 拷贝消息主体 + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader) + extHeaderSize), + mDriver->mDriverSendMR->GetSingleSegSize() - sizeof(UBSHcomNetTransHeader) - extHeaderSize, + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to mrBufAddress"); + return NN_INVALID_PARAM; + } + + /* finally fill header crc */ + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + + // change lAddress to mrAddress and set lKey + UBSHcomNetTransRequest rdmaReq = request; + rdmaReq.lAddress = mrBufAddress; + rdmaReq.lKey = mDriver->mDriverSendMR->GetLKey(); + rdmaReq.size = sizeof(UBSHcomNetTransHeader) + header->dataLength; + auto worker = reinterpret_cast(mEp->Qp()->UpContext1()); + + auto sendOpFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_ASYNC_POST_SEND); + do { + result = worker->PostSend(mEp->Qp(), rdmaReq); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_SEND, result); + return RR_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + sendOpFlag = false; + } while (sendOpFlag); + + NN_LOG_ERROR("Failed to async post send with op info, result " << result); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_SEND, result); + return result; +} + +NResult NetAsyncEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) +{ + NResult res = NN_OK; + if (NN_UNLIKELY((res = PostSendValidation(mState, mId, mDriver, opCode, request, mAllowedSize, + mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to async post send as validate fail"); + return res; + } + + // get mr from pool + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to async post send with opInfo as failed to get mr buffer from pool"); + return NN_GET_BUFF_FAILED; + } + + auto *header = reinterpret_cast(mrBufAddress); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->opCode = opCode; + header->seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + header->flags = ((uint16_t)opInfo.flags << NN_NO8) | (uint64_t)NTH_TWO_SIDE; + header->timeout = opInfo.timeout; + header->errorCode = opInfo.errorCode; + + if (mIsNeedEncrypt) { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, reinterpret_cast(request.lAddress), request.size, + reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), cipherLen)) { + NN_LOG_ERROR("Mlx5 Failed to async post send with op info as encryption failure"); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + return NN_ENCRYPT_FAILED; + } + header->dataLength = cipherLen; + } else { + header->dataLength = request.size; + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), + mDriver->mDriverSendMR->GetSingleSegSize() - sizeof(UBSHcomNetTransHeader), + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to mrBufAddress"); + return NN_INVALID_PARAM; + } + } + /* finally fill header crc */ + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + + // change lAddress to mrAddress and set lKey + UBSHcomNetTransRequest rdmaReq = request; + rdmaReq.lAddress = mrBufAddress; + rdmaReq.lKey = mDriver->mDriverSendMR->GetLKey(); + rdmaReq.size = sizeof(UBSHcomNetTransHeader) + header->dataLength; + auto worker = reinterpret_cast(mEp->Qp()->UpContext1()); + + auto sendOpFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_ASYNC_POST_SEND); + do { + res = worker->PostSend(mEp->Qp(), rdmaReq); + if (res == RR_OK) { + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_SEND, res); + return NN_OK; + } else if (NeedRetry(res) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + sendOpFlag = false; + } while (sendOpFlag); + + NN_LOG_ERROR("Failed to async post send with op info, result " << res); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_SEND, res); + return res; +} + +NResult NetAsyncEndpoint::PostSendSglInline( + uint16_t opCode, const UBSHcomNetTransRequest &request, const UBSHcomNetTransOpInfo &opInfo) +{ + // 需要加密必定会涉及到内存拷贝,仍然走非inline方式 + if (mIsNeedEncrypt) { + return PostSend(opCode, request, opInfo); + } + + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendValidation(mState, mId, mDriver, opCode, request, mAllowedSize, + mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to async post send as validate fail"); + return result; + } + + UBSHcomNetTransHeader header; + header.opCode = opCode; + header.seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + header.flags = ((uint16_t)opInfo.flags << NN_NO8) | (uint64_t)NTH_TWO_SIDE; + header.timeout = opInfo.timeout; + header.errorCode = opInfo.errorCode; + header.dataLength = request.size; + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + + auto worker = reinterpret_cast(mEp->Qp()->UpContext1()); + bool flag = true; + uint64_t finishTime = GetFinishTime(); + do { + result = worker->PostSendSglInline(mEp->Qp(), header, request); + if (result == RR_OK) { + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + return result; +} + +NResult NetAsyncEndpoint::PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNo) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendRawValidation(mState, mId, mDriver, seqNo, request, mSegSize, + mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to async post send raw as validate fail"); + return result; + } + + /* get mr from pool */ + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to post message as failed to get mr buffer from pool from driver " << mDriver->Name()); + return NN_GET_BUFF_FAILED; + } + + size_t msgSize = 0; + if (!mIsNeedEncrypt) { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress), mDriver->mDriverSendMR->GetSingleSegSize(), + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to mrBufAddress"); + return NN_INVALID_PARAM; + } + msgSize = request.size; + } else { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, reinterpret_cast(request.lAddress), request.size, + reinterpret_cast(mrBufAddress), cipherLen)) { + NN_LOG_ERROR("Failed send message as encryption failure in rdma"); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + return NN_ENCRYPT_FAILED; + } + msgSize = cipherLen; + } + + UBSHcomNetTransRequest rdmaReq = request; + rdmaReq.lAddress = mrBufAddress; + rdmaReq.lKey = mDriver->mDriverSendMR->GetLKey(); + rdmaReq.size = msgSize; + + auto worker = reinterpret_cast(mEp->Qp()->UpContext1()); + auto sendRawAsyncFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_ASYNC_POST_SEND_RAW); + do { + result = worker->PostSend(mEp->Qp(), rdmaReq, seqNo); + if (NN_LIKELY(result == RR_OK)) { + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_SEND_RAW, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(NN_NO128); + continue; + } + // no retry result or timeout = 0 + sendRawAsyncFlag = false; + } while (sendRawAsyncFlag); + + NN_LOG_ERROR("Failed to post send raw request, result " << result); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_SEND_RAW, result); + return result; +} + +NResult NetAsyncEndpoint::PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) +{ + size_t size = 0; + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendSglValidation(mState, mId, mDriver, seqNo, request, mSegSize, size, + mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to async post send raw sgl as validate fail"); + return result; + } + + UBSHcomNetTransRequest tlsReq {}; + uintptr_t mrBufAddress = 0; + if (mIsNeedEncrypt) { + if (NN_UNLIKELY(EncryptRawSgl(tlsReq, mrBufAddress, size, mAes, mDriver, request, mSecrets) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to async post send raw sgl as encrypt fail"); + return NN_ENCRYPT_FAILED; + } + } + + auto worker = reinterpret_cast(mEp->Qp()->UpContext1()); + auto flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_ASYNC_POST_SEND_RAW_SGL); + do { + result = worker->PostSendSgl(mEp->Qp(), request, tlsReq, seqNo, mIsNeedEncrypt); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_SEND_RAW_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep眠 + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + if (mIsNeedEncrypt) { + (void)mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + } + + NN_LOG_ERROR("RDMA Failed to post send raw sgl request, result " << result); + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_SEND_RAW_SGL, result); + return result; +} + +NResult NetAsyncEndpoint::PostRead(const UBSHcomNetTransRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = ReadWriteValidation(mState, mId, mDriver, request)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to async post read as validate fail"); + return result; + } + + auto worker = reinterpret_cast(mEp->Qp()->UpContext1()); + auto asyncReadFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_ASYNC_POST_READ); + do { + result = worker->PostRead(mEp->Qp(), request); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_READ, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + asyncReadFlag = false; + } while (asyncReadFlag); + + NN_LOG_ERROR("Failed to post read request, result " << result); + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_READ, result); + return result; +} + +NResult NetAsyncEndpoint::PostRead(const UBSHcomNetTransSglRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = ReadWriteSglValidation(mState, mId, mDriver, request)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to async post read sgl as validate fail"); + return result; + } + + auto worker = reinterpret_cast(mEp->Qp()->UpContext1()); + auto flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_ASYNC_POST_READ_SGL); + do { + result = worker->PostOneSideSgl(mEp->Qp(), request); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_READ_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post read sgl request, result " << result); + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_READ_SGL, result); + return result; +} + +NResult NetAsyncEndpoint::PostWrite(const UBSHcomNetTransRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = ReadWriteValidation(mState, mId, mDriver, request)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to async post write as validate fail"); + return result; + } + + auto worker = reinterpret_cast(mEp->Qp()->UpContext1()); + + auto asyncWriteFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_ASYNC_POST_WRITE); + do { + result = worker->PostWrite(mEp->Qp(), request); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_WRITE, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + asyncWriteFlag = false; + } while (asyncWriteFlag); + + NN_LOG_ERROR("Failed to post write request, result " << result); + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_WRITE, result); + return result; +} + +NResult NetAsyncEndpoint::PostWrite(const UBSHcomNetTransSglRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = ReadWriteSglValidation(mState, mId, mDriver, request)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to async post write sgl as validate fail"); + return result; + } + + auto worker = reinterpret_cast(mEp->Qp()->UpContext1()); + auto flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_ASYNC_POST_WRITE_SGL); + do { + result = worker->PostOneSideSgl(mEp->Qp(), request, false); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_WRITE_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post write sgl request, result " << result); + TRACE_DELAY_END(RDMA_EP_ASYNC_POST_WRITE_SGL, result); + return result; +} + +void NetAsyncEndpoint::UpdateTargetHbTime() +{ + mTargetHbTime = NetMonotonic::TimeSec() + mHeartBeatIdleTime; +} +} +} +#endif diff --git a/src/transport/rdma/verbs/net_rdma_async_endpoint.h b/src/transport/rdma/verbs/net_rdma_async_endpoint.h new file mode 100644 index 0000000000000000000000000000000000000000..ab91eda863b65051c5f1b99709591ed3ba0a9a4e --- /dev/null +++ b/src/transport/rdma/verbs/net_rdma_async_endpoint.h @@ -0,0 +1,268 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_ASYNC_ENDPOINT_RDMA_H +#define OCK_HCOM_NET_ASYNC_ENDPOINT_RDMA_H +#ifdef RDMA_BUILD_ENABLED + +#include "hcom.h" +#include "transport/net_endpoint_impl.h" +#include "rdma_composed_endpoint.h" +#include "net_monotonic.h" +#include "net_rdma_driver_oob.h" +#include "net_security_alg.h" +#include "hcom_utils.h" + +namespace ock { +namespace hcom { +class NetAsyncEndpoint : public NetEndpointImpl { +public: + NetAsyncEndpoint(uint64_t id, RDMAAsyncEndPoint *ep, NetDriverRDMAWithOob *driver, + const UBSHcomNetWorkerIndex &workerIndex); + ~NetAsyncEndpoint() override; + + NResult SetEpOption(UBSHcomEpOptions &epOptions) override + { + NN_LOG_WARN("Empty function for now"); + return NN_OK; + } + + const std::string &PeerIpAndPort() override + { + if (mEp != nullptr) { + return mEp->PeerIpAndPort(); + } + + return CONST_EMPTY_STRING; + } + + uint32_t GetSendQueueCount() override; + + const std::string &UdsName() override + { + NN_LOG_WARN("Empty function for now"); + return CONST_EMPTY_STRING; + } + + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNO) override; + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) override; + + NResult PostSendSglInline(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) override; + + /* + * @brief raw data to peer without opcode + */ + NResult PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNO) override; + NResult PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) override; + + NResult PostRead(const UBSHcomNetTransSglRequest &request) override; + NResult PostRead(const UBSHcomNetTransRequest &request) override; + NResult PostWrite(const UBSHcomNetTransRequest &request) override; + NResult PostWrite(const UBSHcomNetTransSglRequest &request) override; + void UpdateTargetHbTime(); + + bool checkTargetHbTime(uint64_t currTime) + { + if (mTargetHbTime < currTime) { + mTargetHbTime = currTime + mHeartBeatIdleTime; + return true; + } + return false; + } + + NResult WaitCompletion(int32_t timeout) override + { + NN_LOG_WARN("Invalid operation, wait completion is not supported by NetAsyncEndpoint"); + return NN_INVALID_OPERATION; + } + + NResult Receive(int32_t timeout, UBSHcomNetResponseContext &ctx) override + { + NN_LOG_WARN("Invalid operation, wait completion is not supported by NetAsyncEndpoint"); + return NN_INVALID_OPERATION; + } + + NResult ReceiveRaw(int32_t timeout, UBSHcomNetResponseContext &ctx) override + { + NN_LOG_WARN("Invalid operation, wait completion is not supported by NetAsyncEndpoint"); + return NN_INVALID_OPERATION; + } + + inline void HbRecordCount() + { + __sync_add_and_fetch(&mHbCount, 1); + } + + inline bool HbCheckStateNormal() + { + if (mHbCount > mHbLastCount) { + mHbLastCount = mHbCount; + return true; + } + + return false; + } + + inline void SetRemoteHbInfo(uintptr_t address, uint32_t key, uint64_t size) + { + mRemoteHbAddress = address; + mRemoteHbKey = key; + mHbMrSize = size; + } + + inline void SetHbBrokenEp() + { + mHbBrokenEp = true; + } + + inline bool HbBrokenEp() const + { + return mHbBrokenEp; + } + + inline RDMAAsyncEndPoint *GetRdmaEp() + { + return mEp; + } + + NResult GetRemoteUdsIdInfo(UBSHcomNetUdsIdInfo &verbsIdInfo) override + { + if (!mState.Compare(NEP_ESTABLISHED)) { + NN_LOG_ERROR("[RDMA AsyncEp] EP is not established"); + return NN_EP_NOT_ESTABLISHED; + } + + if (!mDriver->mStartOobSvr) { + NN_LOG_ERROR("[RDMA AsyncEp] oob server is not start"); + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + + if (mDriver->mOptions.oobType != NET_OOB_UDS) { + NN_LOG_ERROR("[RDMA AsyncEp] oob type is not uds"); + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + + verbsIdInfo = mRemoteUdsIdInfo; + return NN_OK; + } + + bool GetPeerIpPort(std::string &ip, uint16_t &port) override + { + if (NN_UNLIKELY(mEp == nullptr)) { + return false; + } + + auto ipAndPort = mEp->PeerIpAndPort(); + if (NN_UNLIKELY(ipAndPort.empty())) { + NN_LOG_ERROR("[RDMA AsyncEp] ip and port of peer is empty"); + return false; + } + + std::vector ipPortVec; + NetFunc::NN_SplitStr(ipAndPort, ":", ipPortVec); + if (NN_UNLIKELY(ipPortVec.size() != NN_NO2)) { + NN_LOG_ERROR("[RDMA AsyncEp] ip and port of peer is invalid"); + return false; + } + + try { + port = std::stoi(ipPortVec[1]); + } catch (...) { + NN_LOG_ERROR("[RDMA AsyncEp] port of peer is invalid"); + return false; + } + if (port == 0) { + NN_LOG_ERROR("[RDMA AsyncEp] oob type is uds, does not have peer ip and port msg"); + return false; + } + ip = ipPortVec[0]; + + return true; + } + + void Close() override + { + NN_LOG_INFO("Close ep id " << mId << " by user"); + auto qp = GetRdmaEp()->Qp(); + qp->Stop(); + } + +protected: + // extHeader 可以认为是服务层头部,它可以是 UBSHcomFragmentHeader 也可以是通用的服 + // 务器头部(暂未实现)。 + // Q:服务层的头部也可以看做是 request 的一部分,为什么不把它放进 request 中? + // A:request 是用户直接传递进来的,而 extHeader 可能是服务层自己生成的,它 + // 们两个内存不连续,强行令它们归一在 request 中需要另外额外的 memcpy. 比较 + // 好的方式是通过 iov 的方式发送。 + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, const UBSHcomNetTransOpInfo &opInfo, + const UBSHcomExtHeaderType extHeaderType, const void *extHeader, uint32_t extHeaderSize) override; + +private: + uint64_t inline GetFinishTime() + { + if (mDefaultTimeout > 0) { + return NetMonotonic::TimeNs() + static_cast(mDefaultTimeout) * 1000000000UL; + } else if (mDefaultTimeout < 0) { + return UINT64_MAX; + } + + return 0; + } + + bool inline NeedRetry(RResult &result) + { + if (!State().Compare(NEP_ESTABLISHED)) { + result = NN_EP_NOT_ESTABLISHED; + return false; + } + + if (result == RR_QP_POST_SEND_WR_FULL || result == RR_QP_ONE_SIDE_WR_FULL || result == RR_QP_CTX_FULL) { + return true; + } + + return false; + } + + inline RDMAQp *GetQp() const + { + if (NN_UNLIKELY(mEp == nullptr)) { + return nullptr; + } + return mEp->Qp(); + } + + inline RDMAWorker *GetWorker() const + { + return reinterpret_cast(mEp->Qp()->UpContext1()); + } + + RDMAAsyncEndPoint *mEp = nullptr; + NetDriverRDMAWithOob *mDriver = nullptr; + + bool mHbBrokenEp = false; + uint64_t mHbCount = 1; + uint64_t mHbLastCount = 0; + uintptr_t mRemoteHbAddress = 0; + uint32_t mRemoteHbKey = 0; + uint64_t mHbMrSize = 0; + uint64_t mTargetHbTime = 0; + uint16_t mHeartBeatIdleTime = NN_NO60; + + friend class NetDriverRDMAWithOob; + friend class NetHeartbeat; +}; +} +} + +#endif +#endif // OCK_HCOM_NET_ASYNC_ENDPOINT_RDMA_H diff --git a/src/transport/rdma/verbs/net_rdma_driver.cpp b/src/transport/rdma/verbs/net_rdma_driver.cpp new file mode 100644 index 0000000000000000000000000000000000000000..64da83fd24e95ed74242cee0b7149774fe731a2c --- /dev/null +++ b/src/transport/rdma/verbs/net_rdma_driver.cpp @@ -0,0 +1,621 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "hcom_def.h" +#ifdef RDMA_BUILD_ENABLED +#include "net_rdma_driver.h" +#include "net_rdma_sync_endpoint.h" +#include "net_rdma_async_endpoint.h" +#include "openssl_api_wrapper.h" +#include "rdma_common.h" +#include "rdma_mr_dm_buf.h" +#include "rdma_mr_fixed_buf.h" + +namespace ock { +namespace hcom { +NResult NetDriverRDMA::Initialize(const UBSHcomNetDriverOptions &option) +{ + std::lock_guard lock(mInitMutex); + if (mInited) { + return NN_OK; + } + + mOptions = option; + + if (NN_UNLIKELY(UBSHcomNetOutLogger::Instance() == nullptr)) { + return NN_NOT_INITIALIZED; + } + + NResult verbsRes = NN_OK; + if (NN_UNLIKELY((verbsRes = mOptions.ValidateCommonOptions()) != NN_OK)) { + return verbsRes; + } + + if (NN_UNLIKELY((verbsRes = ValidateOptions()) != NN_OK)) { + return verbsRes; + } + + NN_LOG_INFO("RDMA driver try to initialize with " << mOptions.ToString()); + + if (option.enableTls) { + if (HcomSsl::Load() != 0) { + NN_LOG_ERROR("Failed to load openssl API"); + return NN_NOT_INITIALIZED; + } + } + mEnableTls = option.enableTls; + mHeartBeatIdleTime = mOptions.heartBeatIdleTime; + mHeartBeatProbeInterval = mOptions.heartBeatProbeInterval; + + // create context and initialize + if (((verbsRes = CreateContext()) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to create ctx"); + UnInitializeInner(); + return verbsRes; + } + + if (((verbsRes = mContext->Initialize()) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to initialize ctx"); + UnInitializeInner(); + return verbsRes; + } + + if ((verbsRes = CreateWorkerResource()) != NN_OK) { + NN_LOG_ERROR("RDMA failed to create worker resource"); + UnInitializeInner(); + return verbsRes; + } + + if ((verbsRes = CreateWorkers()) != NN_OK) { + NN_LOG_ERROR("RDMA failed to create workers"); + UnInitializeInner(); + return verbsRes; + } + + /* create lb for client */ + if ((verbsRes = CreateClientLB()) != NN_OK) { + NN_LOG_ERROR("RDMA failed to create client lb"); + UnInitializeInner(); + return verbsRes; + } + + if ((verbsRes = DoInitialize()) != NN_OK) { + NN_LOG_ERROR("RDMA failed to do Initialize"); + UnInitializeInner(); + return verbsRes; + } + + mMrChecker.Reserve(NN_NO128); + mMrChecker.SetLockWhenOperates(false); + + mInited = true; + return NN_OK; +} + +NResult NetDriverRDMA::ValidateOptions() +{ + /* validate param related to device IpMask for RDMA and Sock */ + if (NN_UNLIKELY(!ValidateArrayOptions(mOptions.netDeviceIpMask, NN_NO256))) { + NN_LOG_ERROR("Option 'netDeviceIpMask' is invalid, " << mOptions.netDeviceIpMask << + " is set in driver,the Array max length is 256."); + return NN_INVALID_PARAM; + } + + if (mOptions.prePostReceiveSizePerQP == 0) { + NN_LOG_ERROR("Invalid option prePostReceiveSizePerQP " << mOptions.prePostReceiveSizePerQP << + ", should not be zero"); + return NN_INVALID_PARAM; + } + + if (mOptions.prePostReceiveSizePerQP > NN_NO1024) { + NN_LOG_WARN("Invalid option prePostReceiveSizePerQP " << mOptions.prePostReceiveSizePerQP << + ", should be <= " << NN_NO1024 << ", set to " << NN_NO1024); + mOptions.prePostReceiveSizePerQP = NN_NO1024; + } + + if (mOptions.maxPostSendCountPerQP == 0) { + NN_LOG_ERROR("Invalid option maxPostSendCountPerQP " << mOptions.maxPostSendCountPerQP << + ", should not be zero"); + return NN_INVALID_PARAM; + } + + if (mOptions.maxPostSendCountPerQP > NN_NO1024) { + NN_LOG_WARN("Invalid option maxPostSendCountPerQP " << mOptions.maxPostSendCountPerQP << ", should be <= " << + NN_NO1024 << ", set to " << NN_NO1024); + mOptions.maxPostSendCountPerQP = NN_NO1024; + } + + if (mOptions.maxPostSendCountPerQP > mOptions.prePostReceiveSizePerQP) { + NN_LOG_WARN("Invalid option maxPostSendCountPerQP " << mOptions.maxPostSendCountPerQP << + ", more than prePostReceiveSizePerQP " << mOptions.prePostReceiveSizePerQP << " , change to equal"); + mOptions.maxPostSendCountPerQP = mOptions.prePostReceiveSizePerQP; + } + + if (NN_UNLIKELY(ValidateAndParseOobPortRange(mOptions.oobPortRange) != NN_OK)) { + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(ValidateOptionsOobType() != NN_OK)) { + return NN_INVALID_PARAM; + } + + return NN_OK; +} + +NResult NetDriverRDMA::CreateContext() +{ + if (mContext != nullptr) { + return NN_OK; + } + int result = 0; + if (mOptions.enableMultiRail) { + uint16_t enableCount = 0; + std::vector enableIps; + result = RDMADeviceHelper::GetEnableDeviceCount(mOptions.NetDeviceIpMask(), enableCount, enableIps, + mOptions.NetDeviceIpGroup()); + if (result != NN_OK) { + return result; + } + mMatchIp = enableIps[mDevIndex]; + } else { + // filter ip by mask + std::vector matchIps; + if ((result = MatchIpByMask(matchIps)) != 0) { + return result; + } + // init RoCE devices + if ((result = RDMADeviceHelper::Initialize()) != 0) { + NN_LOG_ERROR("Failed to init devices"); + return result; + } + + NN_LOG_INFO(RDMADeviceHelper::DeviceInfo()); + + // choose the first matched ip + mMatchIp = matchIps[0]; + } + RDMAGId tmpGid {}; + if ((result = RDMADeviceHelper::GetDeviceByIp(mMatchIp, tmpGid)) != 0) { + RDMADeviceHelper::UnInitialize(); + NN_LOG_ERROR("Failed to get device by ip"); + return result; + } + + NN_LOG_DEBUG("gid found devIndex " << tmpGid.devIndex << ", gidIndex " << tmpGid.gid << ", RoCEVersion " << + RDMADeviceHelper::RoCEVersionToStr(tmpGid.RoCEVersion)); + mBandWidth = tmpGid.bandWidth; + // create context + if ((result = RDMAContext::Create(mName, false, tmpGid, mContext)) != 0) { + RDMADeviceHelper::UnInitialize(); + NN_LOG_ERROR("Failed to new ctx, result " << result); + return result; + } + + NN_ASSERT_LOG_RETURN(mContext != nullptr, NN_ERROR); + mContext->IncreaseRef(); + + return NN_OK; +} + +NResult NetDriverRDMA::MatchIpByMask(std::vector &matchIps) +{ + std::vector filters; + NetFunc::NN_SplitStr(mOptions.NetDeviceIpMask(), ",", filters); + if (filters.empty()) { + NN_LOG_ERROR("Invalid ip mask '" << mOptions.netDeviceIpMask << "' by set, example '192.168.0.0/24'"); + return NN_INVALID_IP; + } + + for (auto &mask : filters) { + FilterIp(mask, matchIps); + } + + if (matchIps.empty()) { + NN_LOG_ERROR("No matched ip found with '" << mOptions.netDeviceIpMask << "', example '192.168.0.0/24'"); + return NN_INVALID_IP; + } + return NN_OK; +} + +NResult NetDriverRDMA::CreateSendMr() +{ + int result = 0; + // create mr pool for send/receive and initialize + if ((result = RDMAMemoryRegionFixedBuffer::Create(mName, mContext, mOptions.mrSendReceiveSegSize, + mOptions.mrSendReceiveSegCount, mDriverSendMR)) != 0) { + NN_LOG_ERROR("Failed to create mr for send/receive in NetDriverRDMA " << mName << ", result " << result); + return result; + } + mDriverSendMR->IncreaseRef(); + if ((result = mDriverSendMR->Initialize()) != 0) { + NN_LOG_ERROR("Failed to initialize mr for send/receive in NetDriverRDMA " << mName << ", result " << result); + return result; + } + + return NN_OK; +} + +NResult NetDriverRDMA::CreateOpCtxMemPool() +{ + NetMemPoolFixedOptions options = {}; + options.superBlkSizeMB = NN_NO1; + options.minBlkSize = NN_NextPower2(sizeof(RDMAOpContextInfo)); + options.tcExpandBlkCnt = NN_NO64; + mOpCtxMemPool = new (std::nothrow) NetMemPoolFixed(mName, options); + if (mOpCtxMemPool.Get() == nullptr) { + NN_LOG_ERROR("Failed to create memory pool for rdma op context pool " << mName << ", probably out of memory"); + return NN_INVALID_PARAM; + } + + auto result = mOpCtxMemPool->Initialize(); + if (result != NN_OK) { + mOpCtxMemPool.Set(nullptr); + NN_LOG_ERROR("Failed to initialize memory pool for rdma op context pool " << mName << + ", probably out of memory"); + return result; + } + + return NN_OK; +} + +NResult NetDriverRDMA::CreateSglCtxMemPool() +{ + NetMemPoolFixedOptions options = {}; + options.superBlkSizeMB = NN_NO1; + options.minBlkSize = NN_NextPower2(sizeof(RDMASglContextInfo)); + options.tcExpandBlkCnt = NN_NO64; + mSglCtxMemPool = new (std::nothrow) NetMemPoolFixed(mName, options); + if (mSglCtxMemPool.Get() == nullptr) { + NN_LOG_ERROR("Failed to create memory pool for rdma sgl op context in driver " << mName << + ", probably out of memory"); + return NN_INVALID_PARAM; + } + + auto result = mSglCtxMemPool->Initialize(); + if (result != NN_OK) { + mSglCtxMemPool.Set(nullptr); + NN_LOG_ERROR("Failed to initialize memory pool for rdma sgl op context in driver " << mName << + ", probably out of memory"); + return result; + } + + return NN_OK; +} + +NResult NetDriverRDMA::CreateWorkerResource() +{ + auto result = CreateSendMr(); + if (result != NN_OK) { + NN_LOG_ERROR("RDMA falied to create send mr"); + return result; + } + + result = CreateOpCtxMemPool(); + if (result != NN_OK) { + NN_LOG_ERROR("RDMA failed to create op ctx memory pool"); + return result; + } + + result = CreateSglCtxMemPool(); + if (result != NN_OK) { + NN_LOG_ERROR("RDMA failed to create Sgl ctx memory pool"); + return result; + } + + return NN_OK; +} + +void NetDriverRDMA::ClearWorkers() +{ + mWorkerGroups.clear(); + for (auto worker : mWorkers) { + worker->DecreaseRef(); + } + mWorkers.clear(); +} + +void NetDriverRDMA::DestroyEndpoint(UBSHcomNetEndpointPtr &ep) +{ + if (ep.Get() == nullptr) { + NN_LOG_WARN("The verbs ep is null already."); + return; + } + + NN_LOG_INFO("Verbs Destroy endpoint id " << ep->Id()); + mEndPointsMutex.lock(); + auto result = mEndPoints.erase(ep->Id()); + mEndPointsMutex.unlock(); + + if (result == 0) { + NN_LOG_WARN("Verbs unable to destroy endpoint as ep " << ep->Id() << " doesn't exist, maybe cleaned already"); + return; + } + + ep.Set(nullptr); +} + +NResult NetDriverRDMA::CreateWorkers() +{ + NResult result = NN_OK; + + std::vector workerGroups; + std::vector flatWorkerCpus; + std::vector workerThreadPriority; + std::vector> workerGroupCpus; + + /* parse */ + if (!(NetFunc::NN_ParseWorkersGroups(mOptions.WorkGroups(), workerGroups)) || + !(NetFunc::NN_ParseWorkerGroupsCpus(mOptions.WorkerGroupCpus(), workerGroupCpus)) || + !(NetFunc::NN_FinalizeWorkerGroupCpus(workerGroups, workerGroupCpus, mOptions.mode != NET_BUSY_POLLING, + flatWorkerCpus)) || + !(NetFunc::NN_ParseWorkersGroupsThreadPriority(mOptions.WorkerGroupThreadPriority(), workerThreadPriority, + workerGroups.size()))) { + NN_LOG_ERROR("Failed to parse worker or cpu groups"); + return NN_INVALID_PARAM; + } + + RDMAWorkerOptions options; + options.SetValue(mOptions); + if ((mOptions.workerThreadPriority != 0) && (!workerThreadPriority.empty())) { + NN_LOG_WARN("Driver options 'workerThreadPriority' and 'workerGroupsThreadPriority' set all, preferential use " + "'workerGroupsThreadPriority'."); + } + + /* create workers */ + mWorkers.reserve(flatWorkerCpus.size()); + uint32_t groupIndex = 0; + UBSHcomNetWorkerIndex workerIndex{}; + uint16_t totalWorkerIndex = 0; + for (auto item : workerGroups) { + NN_LOG_TRACE_INFO("Add worker " << groupIndex << ", item " << item); + /* The left of mWorkerGroups is the index of each group's first worker in the mWorkers */ + mWorkerGroups.emplace_back(totalWorkerIndex, item); + for (uint16_t i = 0; i < item; ++i) { + options.cpuId = flatWorkerCpus.at(totalWorkerIndex++); + if (!workerThreadPriority.empty()) { + options.threadPriority = workerThreadPriority[groupIndex]; + } + RDMAWorker *worker = nullptr; + if (NN_UNLIKELY( + (result = RDMAWorker::Create(mName, mContext, options, mOpCtxMemPool, mSglCtxMemPool, worker)) != 0)) { + return result; + } + + workerIndex.Set(i, groupIndex, mIndex); + worker->SetIndex(workerIndex); + + if (NN_UNLIKELY((result = worker->Initialize()) != NN_OK)) { + delete worker; + NN_LOG_ERROR("Failed to initialize rdma worker in driver " << mName << ", result " << result); + return NN_NEW_OBJECT_FAILED; + } + + worker->IncreaseRef(); + mWorkers.push_back(worker); + } + ++groupIndex; + } + + std::ostringstream groupInfo; + groupInfo << "Worker group info : "; + for (auto item : mWorkerGroups) { + groupInfo << "[" << item.first << " : " << item.second << "] "; + } + NN_LOG_TRACE_INFO(groupInfo.str()); + return NN_OK; +} + +void NetDriverRDMA::UnInitialize() +{ + std::lock_guard locker(mInitMutex); + if (!mInited) { + return; + } + if (mStarted) { + NN_LOG_WARN("Invalid to UnInitialize driver " << mName << " which is not stopped"); + return; + } + + DoUnInitialize(); + + UnInitializeInner(); + mInited = false; +} + +void NetDriverRDMA::UnInitializeInner() +{ + if (mContext != nullptr) { + mContext->DecreaseRef(); + mContext = nullptr; + } + + if (mDriverSendMR != nullptr) { + mDriverSendMR->DecreaseRef(); + mDriverSendMR = nullptr; + } + + if (mOpCtxMemPool != nullptr) { + mOpCtxMemPool.Set(nullptr); + } + + DestroyClientLB(); + ClearWorkers(); +} + +NResult NetDriverRDMA::Start() +{ + std::lock_guard locker(mInitMutex); + if (mStarted) { + return NN_OK; + } + + if (!mInited) { + NN_LOG_ERROR("Failed to start NetDriverRDMA " << mName << ", as isn't initialized"); + return NN_ERROR; + } + + NResult result = NN_OK; + if (!mOptions.dontStartWorkers) { + if (NN_UNLIKELY(result = ValidateHandlesCheck()) != NN_OK) { + return result; + } + for (uint64_t i = 0; i < mWorkers.size(); i++) { + if (NN_LIKELY((result = mWorkers[i]->Start()) == NN_OK)) { + continue; + } + NN_LOG_ERROR("Failed to start RDMA driver " << mName << " as failed to start worker"); + for (uint64_t j = 0; j < i; j++) { + mWorkers[j]->Stop(); + } + return result; + } + } else { + NN_LOG_INFO("Workers in driver " << mName << " will not be started as option dontStartWorkers is true"); + } + + if (NN_UNLIKELY(result = DoStart()) != NN_OK) { + NN_LOG_ERROR("Failed to do start NetDriverRDMA " << mName << ", result " << result); + for (auto worker : mWorkers) { + worker->Stop(); + } + return result; + } + mStarted = true; + return NN_OK; +} + +void NetDriverRDMA::Stop() +{ + std::lock_guard locker(mInitMutex); + if (!mStarted) { + return; + } + + DoStop(); + + for (auto worker : mWorkers) { + worker->Stop(); + } + + mStarted = false; +} + +NResult NetDriverRDMA::CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr) +{ + if (NN_UNLIKELY(size == 0 || size > NN_NO107374182400)) { + NN_LOG_ERROR("Failed to create mem region as size is 0 or greater than 100 GB"); + return NN_INVALID_PARAM; + } + + if (!mInited) { + NN_LOG_ERROR("Failed to create Memory region in NetDriverRDMA " << mName << ", as not initialized"); + return NN_EP_NOT_INITIALIZED; + } + + RDMAMemoryRegion *tmp = nullptr; + auto result = RDMAMemoryRegion::Create(mName, mContext, size, tmp); + if (NN_UNLIKELY(result != RR_OK)) { + NN_LOG_ERROR("Failed to create Memory region in NetDriverRDMA " << mName << ", probably out of memory"); + return result; + } + + if ((result = tmp->Initialize()) != RR_OK) { + delete tmp; + return result; + } + + if ((result = mMrChecker.Register(tmp->GetLKey(), tmp->GetAddress(), size)) != NN_OK) { + NN_LOG_ERROR("Failed to add rdma memory region to range checker in driver" << mName << " for duplicate keys"); + delete tmp; + return result; + } + + mr.Set(static_cast(tmp)); + + return NN_OK; +} + +NResult NetDriverRDMA::CreateMemoryRegion(uintptr_t address, uint64_t size, UBSHcomNetMemoryRegionPtr &mr) +{ + if (NN_UNLIKELY(size == 0 || size > NN_NO1099511627776)) { + NN_LOG_ERROR("RDMA Failed to create mem region as size is 0 or greater than 1 TB"); + return NN_INVALID_PARAM; + } + + if (!mInited) { + NN_LOG_ERROR("Failed to create Memory region with ptr in NetDriverRDMA " << mName << ", as not initialized"); + return NN_EP_NOT_INITIALIZED; + } + + if (address == 0) { + NN_LOG_ERROR("Failed to create Memory region with ptr in NetDriverRDMA " << mName << ", as address is 0"); + return NN_INVALID_PARAM; + } + + RDMAMemoryRegion *tmp = nullptr; + auto result = RDMAMemoryRegion::Create(mName, mContext, address, size, tmp); + if (NN_UNLIKELY(result != RR_OK)) { + NN_LOG_ERROR("Failed to create Memory region with ptr in NetDriverRDMA " << mName << + ", probably out of memory"); + return result; + } + + if ((result = tmp->Initialize()) != RR_OK) { + delete tmp; + return result; + } + + if ((result = mMrChecker.Register(tmp->GetLKey(), tmp->GetAddress(), size)) != NN_OK) { + NN_LOG_ERROR("Failed to add memory region with ptr to range checker in driver" << mName << + " for duplicate keys"); + delete tmp; + return result; + } + + mr.Set(static_cast(tmp)); + + return NN_OK; +} + +NResult NetDriverRDMA::CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr, unsigned long memid) +{ + NN_LOG_ERROR("operation is not supported in rdma"); + return NN_ERROR; +} + +void NetDriverRDMA::DestroyMemoryRegion(UBSHcomNetMemoryRegionPtr &mr) +{ + if (mr.Get() == nullptr) { + NN_LOG_WARN("Try to destroy null memory region in rdma driver " << mName); + return; + } + if (!mMrChecker.Contains(mr->GetLKey())) { + NN_LOG_WARN("Try to destroy unowned memory region in rdma driver " << mName); + return; + } + mMrChecker.UnRegister(mr->GetLKey()); + mr->UnInitialize(); +} + +void *NetDriverRDMA::MapAndRegVaForUB(unsigned long memid, uint64_t &va) +{ + NN_LOG_ERROR("operation is not supported in rdma"); + return nullptr; +} + +NResult NetDriverRDMA::UnmapVaForUB(uint64_t &va) +{ + NN_LOG_ERROR("operation is not supported in rdma"); + return NN_ERROR; +} +} // namespace hcom +} +#endif \ No newline at end of file diff --git a/src/transport/rdma/verbs/net_rdma_driver.h b/src/transport/rdma/verbs/net_rdma_driver.h new file mode 100644 index 0000000000000000000000000000000000000000..2b5b4ce62a3e52d10db1fc1b5fd1727725ed9238 --- /dev/null +++ b/src/transport/rdma/verbs/net_rdma_driver.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_NET_DRIVER_RDMA_123423434341233_H +#define OCK_NET_DRIVER_RDMA_123423434341233_H +#ifdef RDMA_BUILD_ENABLED + +#include +#include + +#include "hcom.h" +#include "net_common.h" +#include "rdma_worker.h" + +namespace ock { +namespace hcom { +class NetDriverRDMA : public UBSHcomNetDriver { +public: + NetDriverRDMA(const std::string &name, bool isServer, UBSHcomNetDriverProtocol protocol) + : UBSHcomNetDriver(name, isServer, protocol) + { + OBJ_GC_INCREASE(NetDriverRDMA); + } + + ~NetDriverRDMA() override + { + OBJ_GC_DECREASE(NetDriverRDMA); + } + + NResult Initialize(const UBSHcomNetDriverOptions &option) override; + + void UnInitialize() override; + + NResult Start() override; + void Stop() override; + + NResult CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr) override; + NResult CreateMemoryRegion(uintptr_t address, uint64_t size, UBSHcomNetMemoryRegionPtr &mr) override; + NResult CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr, unsigned long memid) override; + void DestroyMemoryRegion(UBSHcomNetMemoryRegionPtr &mr) override; + + inline NResult ValidateMemoryRegion(uint64_t lKey, uintptr_t address, uint64_t size) + { + return mMrChecker.Validate(lKey, address, size); + } + + void DestroyEndpoint(UBSHcomNetEndpointPtr &ep) override; + + void *MapAndRegVaForUB(unsigned long memid, uint64_t &va) override; + + NResult UnmapVaForUB(uint64_t &va) override; + + inline RDMAMemoryRegionFixedBuffer *GetDriverSendMr() const + { + return mDriverSendMR; + } + +protected: + NResult ValidateOptions(); + NResult CreateContext(); + NResult CreateWorkers(); + void ClearWorkers(); + void UnInitializeInner(); + virtual NResult DoInitialize() + { + return NN_OK; + } + + virtual void DoUnInitialize() {} + + virtual NResult DoStart() + { + return NN_OK; + } + + virtual void DoStop() {} + +protected: + std::string mMatchIp; + RDMAContext *mContext = nullptr; + std::vector mWorkers; + RDMAMemoryRegionFixedBuffer *mDriverSendMR = nullptr; + MemoryRegionChecker mMrChecker; + uint32_t mHeartBeatIdleTime = NN_NO8; + uint32_t mHeartBeatProbeInterval = NN_NO1; + +private: + NResult CreateSendMr(); + NResult CreateOpCtxMemPool(); + NResult CreateSglCtxMemPool(); + NResult CreateWorkerResource(); + NResult MatchIpByMask(std::vector &matchIps); + NetMemPoolFixedPtr mOpCtxMemPool = nullptr; + NetMemPoolFixedPtr mSglCtxMemPool = nullptr; +}; +} +} + +#endif +#endif // _OCK_NET_DRIVER_RDMA_123423434341233_H diff --git a/src/transport/rdma/verbs/net_rdma_driver_oob.cpp b/src/transport/rdma/verbs/net_rdma_driver_oob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..34da3de866c08768bba0f1fb983d077de9e237a1 --- /dev/null +++ b/src/transport/rdma/verbs/net_rdma_driver_oob.cpp @@ -0,0 +1,1869 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#include +#include + +#include "net_monotonic.h" +#include "net_oob_ssl.h" +#include "net_rdma_sync_endpoint.h" +#include "net_rdma_async_endpoint.h" +#include "rdma_mr_dm_buf.h" +#include "rdma_mr_fixed_buf.h" +#include "net_rdma_driver_oob.h" +#include "net_oob_secure.h" + +namespace ock { +namespace hcom { +constexpr uint64_t MAX_OP_TIME_US = NN_NO500000; // 500 ms + +NResult NetDriverRDMAWithOob::DoInitialize() +{ + if (mWorkers.empty()) { + NN_LOG_ERROR("Failed to do initialize in Driver " << mName << ", as mWorkers is empty"); + } + + for (auto worker : mWorkers) { + worker->RegisterPostedHandler(std::bind(&NetDriverRDMAWithOob::SendFinished, this, std::placeholders::_1)); + worker->RegisterNewRequestHandler(std::bind(&NetDriverRDMAWithOob::NewRequest, this, std::placeholders::_1)); + worker->RegisterOneSideDoneHandler(std::bind(&NetDriverRDMAWithOob::OneSideDone, this, std::placeholders::_1)); + if (mIdleHandler != nullptr) { + worker->RegisterIdleHandler(mIdleHandler); + } + } + + NResult result = NN_OK; + // create oob + if (mStartOobSvr) { + if ((result = CreateListeners(mOptions.enableMultiRail)) != NN_OK) { + NN_LOG_ERROR("RDMA failed to create listeners"); + return result; + } + } + + mEndPoints.reserve(NN_NO1024); + + return NN_OK; +} + +void NetDriverRDMAWithOob::DoUnInitialize() +{ + if (mStarted) { + NN_LOG_WARN("Invalid to UnInitialize driver " << mName << " which is not stopped"); + return; + } + + if (!mOobServers.empty()) { + mOobServers.clear(); + } +} + +NResult NetDriverRDMAWithOob::DoStart() +{ + if (mStartOobSvr) { + if (mNewEndPointHandler == nullptr) { + NN_LOG_ERROR("Failed to do start in Driver " << mName << ", as newEndPointerHandler is null"); + return NN_INVALID_PARAM; + } + + if (!mOptions.enableMultiRail) { + /* set cb for listeners */ + for (auto &oobServer : mOobServers) { + oobServer->SetNewConnCB(std::bind(&NetDriverRDMAWithOob::NewConnectionCB, this, std::placeholders::_1)); + oobServer->SetNewConnCbThreadNum(mOptions.oobConnHandleThreadCount); + oobServer->SetNewConnCbQueueCap(mOptions.oobConnHandleQueueCap); + } + + NResult result = StartListeners(); + if (result != NN_OK) { + NN_LOG_ERROR("RDMA failed to start listeners"); + return result; + } + } + } + + mHeartBeat = new (std::nothrow) NetHeartbeat(this, mOptions.heartBeatIdleTime, mOptions.heartBeatProbeInterval); + if (mHeartBeat == nullptr) { + NN_LOG_ERROR("Failed to do start in Driver " << mName << ", as new heartbeat failed"); + StopListeners(); + return NN_ERROR; + } + + NResult result = mHeartBeat->Start(); + if (result != NN_OK) { + delete mHeartBeat; + mHeartBeat = nullptr; + StopListeners(); + return result; + } + + mNeedStopEvent = false; + std::thread tmpEventThread(&NetDriverRDMAWithOob::RunInRdmaEventThread, this); + mRdmaEventThread = std::move(tmpEventThread); + + while (!mEventStarted.load()) { + usleep(NN_NO10); + } + + return NN_OK; +} + +void NetDriverRDMAWithOob::DoStop() +{ + if (mHeartBeat != nullptr) { + mHeartBeat->Stop(); + delete mHeartBeat; + mHeartBeat = nullptr; + } + + mNeedStopEvent = true; + if (mRdmaEventThread.native_handle()) { + mRdmaEventThread.join(); + } + + StopListeners(); +} + +void NetDriverRDMAWithOob::DestroyEpByPortNum(int portNum) +{ + static thread_local std::vector endPointsCopy; + endPointsCopy.reserve(NN_NO8192); + endPointsCopy.clear(); + { + std::lock_guard locker(mEndPointsMutex); + for (auto iter = mEndPoints.begin(); iter != mEndPoints.end();) { + auto asyncEp = iter->second.ToChild(); + if (asyncEp != nullptr && asyncEp->GetRdmaEp()->Qp()->PortNum() == portNum) { + endPointsCopy.emplace_back(iter->second); + iter = mEndPoints.erase(iter); + } else { + ++iter; + } + } + } + + for (auto &endPoint : endPointsCopy) { + NN_LOG_WARN("Detect port down event, handle Ep id " << endPoint->Id() << " of driver " << mName); + ProcessEpError(reinterpret_cast(endPoint.Get())); + } + + NN_LOG_INFO("Destroyed all endpoints count " << endPointsCopy.size() << " by port down of driver " << mName); + endPointsCopy.clear(); +} + +void NetDriverRDMAWithOob::HandlePortDown(int portNum) +{ + for (auto worker : mWorkers) { + if (worker->PortNum() == portNum) { + worker->Stop(); + } + } + + DestroyEpByPortNum(portNum); +} + +void NetDriverRDMAWithOob::HandlePortActive(int portNum) +{ + for (auto worker : mWorkers) { + if (worker->PortNum() == portNum) { + worker->Start(); + } + } +} + +void NetDriverRDMAWithOob::DestroyEpInWorker(RDMAWorker *worker) +{ + static thread_local std::vector endPointsCopy; + endPointsCopy.reserve(NN_NO8192); + endPointsCopy.clear(); + { + std::lock_guard locker(mEndPointsMutex); + for (auto iter = mEndPoints.begin(); iter != mEndPoints.end();) { + auto asyncEp = iter->second.ToChild(); + if (asyncEp != nullptr && asyncEp->mEp->mWorker == worker) { + endPointsCopy.emplace_back(iter->second); + iter = mEndPoints.erase(iter); + } else { + ++iter; + } + } + } + + for (auto &endPoint : endPointsCopy) { + NN_LOG_WARN("Detect CQ incorrect event, handle Ep id " << endPoint->Id() << " of driver " << mName); + ProcessEpError(reinterpret_cast(endPoint.Get())); + } + + NN_LOG_INFO("Destroyed all endpoints count " << endPointsCopy.size() << " in RDMA worker " << + worker->DetailName() << " of driver " << mName); + endPointsCopy.clear(); +} + +void NetDriverRDMAWithOob::HandleCqEvent(struct ibv_async_event *event) +{ + /* when sync mode connecting, there is no worker */ + if (event->element.cq->cq_context == nullptr) { + NN_LOG_ERROR("CQ error for CQ of driver " << mName); + } else { + auto worker = reinterpret_cast(event->element.cq->cq_context); + NN_LOG_ERROR("CQ error for CQ with handle " << event->element.cq << " in RDMA worker " << + worker->DetailName() << " of driver " << mName); + if (worker->Stop() != NN_OK) { + NN_LOG_ERROR("Handle Cq event stop error in RDMA worker " << worker->DetailName() << " of driver " << + mName); + return; + } + + DestroyEpInWorker(worker); + if (worker->ReInitializeCQ() != NN_OK) { + NN_LOG_ERROR("Handle Cq event ReInitializeCQ error in RDMA worker " << worker->DetailName() << + " of driver " << mName); + return; + } + if (worker->Start() != NN_OK) { + NN_LOG_ERROR("Handle Cq event start error in RDMA worker " << worker->DetailName() << " of driver " << + mName); + return; + } + } +} + +static inline std::string QpDetailInfo(void *qpContext) +{ + auto qp = reinterpret_cast(qpContext); + std::ostringstream oss; + oss << "[Qp name:" << qp->Name() << ", id:" << qp->Id() << "]"; + return oss.str(); +} + +void NetDriverRDMAWithOob::HandleAsyncEvent(struct ibv_async_event *event) +{ + switch (event->event_type) { + /* QP events */ + case IBV_EVENT_QP_FATAL: + NN_LOG_ERROR("QP fatal event for " << QpDetailInfo(event->element.qp->qp_context) << " of driver " << + mName); + break; + case IBV_EVENT_QP_REQ_ERR: + NN_LOG_ERROR("QP Requester error for " << QpDetailInfo(event->element.qp->qp_context) << " of driver " << + mName); + break; + case IBV_EVENT_QP_ACCESS_ERR: + NN_LOG_ERROR("QP access error event for " << QpDetailInfo(event->element.qp->qp_context) << " of driver " << + mName); + break; + case IBV_EVENT_COMM_EST: + NN_LOG_ERROR("QP communication established event for " << QpDetailInfo(event->element.qp->qp_context) << + " of driver " << mName); + break; + case IBV_EVENT_SQ_DRAINED: + NN_LOG_ERROR("QP Send Queue drained event for " << QpDetailInfo(event->element.qp->qp_context) << + " of driver " << mName); + break; + case IBV_EVENT_PATH_MIG: + NN_LOG_ERROR("QP Path migration loaded event for " << QpDetailInfo(event->element.qp->qp_context) << + " of driver " << mName); + break; + case IBV_EVENT_PATH_MIG_ERR: + NN_LOG_ERROR("QP Path migration error event for " << QpDetailInfo(event->element.qp->qp_context) << + " of driver " << mName); + break; + case IBV_EVENT_QP_LAST_WQE_REACHED: + NN_LOG_ERROR("QP last WQE reached event for " << QpDetailInfo(event->element.qp->qp_context) << + " of driver " << mName); + break; + + /* CQ events */ + case IBV_EVENT_CQ_ERR: + HandleCqEvent(event); + break; + + /* SRQ events */ + case IBV_EVENT_SRQ_ERR: + NN_LOG_ERROR("SRQ error for SRQ of driver " << mName); + break; + case IBV_EVENT_SRQ_LIMIT_REACHED: + NN_LOG_ERROR("SRQ limit reached event for SRQ of driver " << mName); + break; + + /* Port events */ + case IBV_EVENT_PORT_ACTIVE: + NN_LOG_ERROR("Port active event for port number " << event->element.port_num << " of driver " << mName); + HandlePortActive(event->element.port_num); + break; + case IBV_EVENT_PORT_ERR: + NN_LOG_ERROR("Port error event for port number " << event->element.port_num << " of driver " << mName); + /* case1: The physical link is disconnected */ + /* case2: ifconfig down + 1) QP can work normal before the event happened in CX5 + 2) QP report err in 182x */ + HandlePortDown(event->element.port_num); + break; + case IBV_EVENT_LID_CHANGE: + NN_LOG_ERROR("LID change event for port number " << event->element.port_num << " of driver " << mName); + break; + case IBV_EVENT_PKEY_CHANGE: + NN_LOG_ERROR("P_Key table change event for port number " << event->element.port_num << " of driver " << + mName); + break; + case IBV_EVENT_GID_CHANGE: + NN_LOG_ERROR("GID table change event for port number " << event->element.port_num << " of driver " << + mName); + mContext->UpdateGid(mMatchIp); + break; + case IBV_EVENT_SM_CHANGE: + NN_LOG_ERROR("SM change event for port number " << event->element.port_num << " of driver " << mName); + break; + case IBV_EVENT_CLIENT_REREGISTER: + NN_LOG_ERROR("Client reregister event for port number " << event->element.port_num << " of driver " << + mName); + break; + + /* RDMA device events */ + case IBV_EVENT_DEVICE_FATAL: + NN_LOG_ERROR("Fatal error event for device of driver " << mName); + break; + + default: + NN_LOG_ERROR("Unknown event " << event->event_type << " of driver " << mName); + } +} + +void NetDriverRDMAWithOob::RunInRdmaEventThread() +{ + mEventStarted.store(true); + NN_LOG_INFO("Rdma event monitor thread for driver " << mName << " started"); + + /* set thread name */ + pthread_setname_np(pthread_self(), ("RDMAEvent" + std::to_string(mIndex)).c_str()); + + /* set nonblock */ + int flags = fcntl(mContext->Context()->async_fd, F_GETFL); + int ret = fcntl(mContext->Context()->async_fd, F_SETFL, (static_cast(flags)) | O_NONBLOCK); + if (ret < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to change event fd of RDMA context for driver " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return; + } + + struct ibv_async_event event {}; + while (!mNeedStopEvent) { + struct pollfd fd {}; + int timeoutMs = NN_NO100; + fd.fd = mContext->Context()->async_fd; + fd.events = POLLIN; + fd.revents = 0; + do { + ret = poll(&fd, 1, timeoutMs); + if (ret > 0) { + break; + } else if (ret < 0 && errno != EINTR) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to poll event fd of RDMA context for driver " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + break; + } + // rc == 0 + } while (!mNeedStopEvent); + + ret = HcomIbv::GetAsyncEvent(mContext->Context(), &event); + if (ret < 0) { + /* nothing happen when nonblock mode */ + continue; + } + + /* ack the event, otherwise destroy cq will block */ + HcomIbv::AckAsyncEvent(&event); + + /* when fatal event happened, need stop worker first, then call ep broken to prevent race condition + with poll cq thread */ + HandleAsyncEvent(&event); + } + NN_LOG_INFO("Rdma event monitor thread for driver " << mName << " exiting"); + mEventStarted.store(false); +} + +NResult NetDriverRDMAWithOob::MultiRailNewConnection(OOBTCPConnection &conn) +{ + return NewConnectionCB(conn); +} + +int NetDriverRDMAWithOob::NewConnectionCB(OOBTCPConnection &conn) +{ + if (NN_UNLIKELY(OOBSecureProcess::SecProcessInOOBServer(mSecInfoProvider, mSecInfoValidator, conn, mName, + mOptions.secType)) != NN_OK) { + return NN_OOB_SEC_PROCESS_ERROR; + } + + uint32_t ip = NetFunc::GetIpByFd(conn.GetFd()); + if (NN_UNLIKELY(OOBSecureProcess::SecProcessCompareEpNum(ip, conn.ListenPort(), conn.GetIpAndPort(), + mOobServers)) != NN_OK) { + NN_LOG_ERROR("Rdma connection num exceeds maximum"); + return NN_OOB_SEC_PROCESS_ERROR; + } + + int result = 0; + + // receive server worker grpno + auto startRecvWG = NetMonotonic::TimeUs(); + ConnectHeader header {}; + void *grpnobuf = &header; + if ((result = conn.Receive(grpnobuf, sizeof(ConnectHeader))) != 0) { + NN_LOG_ERROR("Failed to receive specified server worker grpno from client " << mName << ", result " << result); + return NN_ERROR; + } + + ConnRespWithUId respWithUId{ OK, 0 }; + result = OOBSecureProcess::SecCheckConnectionHeader(header, mOptions, mEnableTls, Protocol(), mMajorVersion, + mMinorVersion, respWithUId); + if (result != NN_OK) { + conn.Send(&respWithUId, sizeof(ConnRespWithUId)); + return NN_ERROR; + } + + auto endRecvWG = NetMonotonic::TimeUs(); + auto recvWGtime = endRecvWG - startRecvWG; + if (NN_UNLIKELY(recvWGtime > MAX_OP_TIME_US)) { + NN_LOG_WARN("Receive group num time is too long :" << recvWGtime << " us."); + } + + /* choose worker */ + NetWorkerLBPtr lb = nullptr; + if (mOptions.enableMultiRail) { + lb = mServerLb; + } else { + lb = conn.LoadBalancer(); + } + + NN_ASSERT_LOG_RETURN(lb.Get() != nullptr, NN_ERROR) + uint16_t workerIndex = 0; + if (NN_UNLIKELY(!lb->ChooseWorker(header.groupIndex, conn.GetIpAndPort(), workerIndex)) || + workerIndex >= mWorkers.size()) { + NN_LOG_ERROR("Failed to find worker fit grpno " << header.groupIndex << " in " << mName << " , result " << + result); + ConnRespWithUId respWithUId { WORKER_GRPNO_MISMATCH, 0 }; + conn.Send(&respWithUId, sizeof(ConnRespWithUId)); + return NN_ERROR; + } + + NN_LOG_TRACE_INFO("Worker " << workerIndex << " is chosen in driver " << mName); + auto worker = mWorkers[workerIndex]; + NN_ASSERT_LOG_RETURN(worker != nullptr, NN_ERROR); + + if (!worker->IsWorkStarted()) { + NN_LOG_ERROR("Failed to connect worker group no " << header.groupIndex << " in " << mName); + ConnRespWithUId respWithUId { WORKER_NOT_STARTED, 0 }; + conn.Send(&respWithUId, sizeof(ConnRespWithUId)); + return NN_ERROR; + } + + // create qp + auto startCreateQp = NetMonotonic::TimeUs(); + NN_LOG_TRACE_INFO("create and initialize qp"); + RDMAAsyncEndPoint *rep = nullptr; + + if ((result = RDMAAsyncEndPoint::Create(mName, worker, rep)) != 0) { + NN_LOG_ERROR("Failed to create ep for new connection in Driver " << mName << " , result " << result); + ConnRespWithUId respWithUId { SERVER_INTERNAL_ERROR, 0 }; + conn.Send(&respWithUId, sizeof(ConnRespWithUId)); + return NN_ERROR; + } + NetLocalAutoDecreasePtr repAutoDecPtr(rep); + if ((result = rep->Initialize()) != 0) { + NN_LOG_ERROR("Failed to initialize ep for new connection in Driver " << mName << " , result " << result); + ConnRespWithUId respWithUId { SERVER_INTERNAL_ERROR, 0 }; + conn.Send(&respWithUId, sizeof(ConnRespWithUId)); + return NN_ERROR; + } + + auto id = NetUuid::GenerateUuid(); + NN_LOG_TRACE_INFO("new ep id will be set as " << id << " in driver " << mName); + + respWithUId.connResp = OK; + respWithUId.epId = id; + conn.Send(&respWithUId, sizeof(ConnRespWithUId)); + auto endCreateQp = NetMonotonic::TimeUs(); + auto createQpTime = endCreateQp - startCreateQp; + if (NN_UNLIKELY(createQpTime > MAX_OP_TIME_US)) { + NN_LOG_WARN("Create qp time is too long :" << createQpTime << " us."); + } + + // exchange info + NN_LOG_TRACE_INFO("get and send exchange info of ep"); + auto startExchInfo = NetMonotonic::TimeUs(); + auto prePostCount = mOptions.prePostReceiveSizePerQP; + RDMAQpExchangeInfo info {}; + if (mHeartBeat != nullptr) { + mHeartBeat->GetRemoteHbInfo(info); + } + info.maxSendWr = mOptions.qpSendQueueSize; + info.maxReceiveWr = mOptions.qpReceiveQueueSize; + info.receiveSegSize = mOptions.mrSendReceiveSegSize; + info.receiveSegCount = mOptions.prePostReceiveSizePerQP; + if (((result = rep->GetExchangeInfo(info)) != 0)) { + NN_LOG_ERROR("Failed to get ep exchange info in Driver " << mName << ", result " << result); + return NN_ERROR; + } + if (((result = conn.Send(&info, sizeof(RDMAQpExchangeInfo))) != 0)) { + NN_LOG_ERROR("Failed to get or send ep exchange info in Driver " << mName << ", result " << result); + return NN_ERROR; + } + NN_LOG_TRACE_INFO("Send exchange info success in Server " << mName); + NN_LOG_TRACE_INFO("local ep ex info lid " << info.lid << ", qpn " << info.qpn << ", gid interface " << + info.gid.global.interface_id); + void *tmp = static_cast(&info); + if ((result = conn.Receive(tmp, sizeof(RDMAQpExchangeInfo))) != 0) { + NN_LOG_ERROR("Failed to receive ep exchange info in Driver " << mName << ", result " << result); + return NN_ERROR; + } + NN_LOG_TRACE_INFO("Recv exchange info success in Server " << mName); + + // receive payload length + uint32_t payloadLen = 0; + auto tmpPayloadLen = reinterpret_cast(&payloadLen); + if ((result = conn.Receive(tmpPayloadLen, sizeof(uint32_t))) != 0) { + NN_LOG_ERROR("Failed to receive connection payload length in Driver " << mName << ", result " << result); + return NN_ERROR; + } + + if (payloadLen == 0 || payloadLen > NN_NO1024) { + NN_LOG_ERROR("Invalid payload length " << payloadLen << ", it should be 1 ~ 1024"); + return NN_ERROR; + } + + // receive payload + std::string payload; + if (payloadLen > 0) { + auto payChars = new (std::nothrow) char[payloadLen + NN_NO1]; + if (payChars == nullptr) { + NN_LOG_ERROR("Failed to new payload char array in Driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + NetLocalAutoFreePtr autoFreePayChars(payChars, true); + + void *tmpChars = static_cast(payChars); + if ((result = conn.Receive(tmpChars, payloadLen)) != 0) { + NN_LOG_ERROR("Failed to receive connection payload in Driver " << mName << ", result " << result); + return NN_ERROR; + } + + payChars[payloadLen] = '\0'; + payload = std::string(payChars, payloadLen); + } + + NN_LOG_TRACE_INFO("remote qp ex info lid " << info.lid << ", qpn " << info.qpn << ", gid interface " << + info.gid.global.interface_id << ", pre-post-receive-count " << info.receiveSegCount); + if ((result = rep->ChangeToReady(info)) != 0) { + NN_LOG_ERROR("Failed to change ep to ready in Driver " << mName << ", result " << result); + return result; + } + + auto *mrSegs = new (std::nothrow) uintptr_t[prePostCount]; + if (mrSegs == nullptr) { + NN_LOG_ERROR("Failed to create mr address array in Driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + NetLocalAutoFreePtr segAutoDelete(mrSegs, true); + + if (!rep->GetFreeBufferN(mrSegs, prePostCount)) { + NN_LOG_ERROR("Failed to get free mr from pool, mr is not enough"); + return NN_MALLOC_FAILED; + } + + uint16_t i = 0; + for (; i < prePostCount; i++) { + if ((result = worker->PostReceive(rep->Qp(), mrSegs[i], mOptions.mrSendReceiveSegSize, rep->GetLKey())) != 0) { + break; + } + } + + for (; i < prePostCount; i++) { + rep->ReturnBuffer(mrSegs[i]); + } + + rep->PeerIpAndPort(conn.GetIpAndPort()); + auto endExchInfo = NetMonotonic::TimeUs(); + auto exchInfoTime = endExchInfo - startExchInfo; + if (NN_UNLIKELY(exchInfoTime > MAX_OP_TIME_US)) { + NN_LOG_WARN("Exchange info time too long :" << exchInfoTime << " us."); + } + + // create endpoint + auto startCreateEp = NetMonotonic::TimeUs(); + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(id, rep, this, worker->Index()); + if (ep.Get() == nullptr) { + NN_LOG_ERROR("Failed to create UBSHcomNetEndpoint in Driver " << mName << ", probably out of memory"); + // do later, remove mr for prepost + return NN_NEW_OBJECT_FAILED; + } + + if (mOptions.oobType == NET_OOB_UDS) { + struct ucred remoteIds {}; + socklen_t len = sizeof(struct ucred); + if (NN_UNLIKELY(getsockopt(conn.GetFd(), SOL_SOCKET, SO_PEERCRED, &remoteIds, &len) != 0)) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to get uds ids in driver " << mName << " errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_GET_UDS_ID_INFO_FAILED; + } + ep->RemoteUdsIdInfo(remoteIds.pid, remoteIds.uid, remoteIds.gid); + } + + ep->StoreConnInfo(NetFunc::GetIpByFd(conn.GetFd()), conn.ListenPort(), header.version, payload); + auto asyncEp = ep.ToChild(); + if (NN_UNLIKELY(asyncEp == nullptr)) { + NN_LOG_ERROR("To child Failed"); + return NN_ERROR; + } + + asyncEp->SetRemoteHbInfo(info.hbAddress, info.hbKey, info.hbMrSize); + if (mEnableTls) { + auto oobSslConn = dynamic_cast(&conn); + if (NN_UNLIKELY(oobSslConn == nullptr)) { + NN_LOG_ERROR("dynamic cast error"); + return NN_OOB_SEC_PROCESS_ERROR; + } + asyncEp->EnableEncrypt(mOptions); + asyncEp->SetSecrets(oobSslConn->Secret()); + } + rep->Qp()->UpContext(reinterpret_cast(ep.Get())); + ep->mDevIndex = mDevIndex; + ep->mPeerDevIndex = mPeerDevIndex; + ep->mBandWidth = mBandWidth; + ep->State().Set(NEP_ESTABLISHED); + result = mNewEndPointHandler(conn.GetIpAndPort(), ep, payload); + if (NN_UNLIKELY(result != RR_OK)) { + NN_LOG_ERROR("Called new end point handler failed, result " << result); + // do later, remove mr for prepost + return NN_ERROR; + } + + int8_t ready = 1; + if ((result = conn.Send(&ready, sizeof(int8_t))) != RR_OK) { + NN_LOG_ERROR("Failed to send ready signal to client, result " << result); + // do later, remove mr for prepost + return NN_ERROR; + } + + { + std::lock_guard locker(mEndPointsMutex); + auto ret = mEndPoints.emplace(ep->Id(), ep); + if (!ret.second) { + NN_LOG_ERROR("Failed to emplace ep, ep Id " << ep->Id()); + return NN_ERROR; + } + } + reinterpret_cast(ep.Get())->GetRdmaEp()->Qp()->UpId(ep->Id()); + auto endCreateEp = NetMonotonic::TimeUs(); + auto createEpTime = endCreateEp - startCreateEp; + if (NN_UNLIKELY(createEpTime > MAX_OP_TIME_US)) { + NN_LOG_WARN("Create endpoint time too long :" << createEpTime << " us."); + } + + OOBSecureProcess::SecProcessAddEpNum(ip, conn.ListenPort(), conn.GetIpAndPort(), mOobServers); + + NN_LOG_INFO("New connection from " << conn.GetIpAndPort() << " established, async ep id " << ep->Id() << + " worker info " << worker->DetailName()); + return NN_OK; +} + +NResult NetDriverRDMAWithOob::Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, + uint8_t serverGrpNo, uint8_t clientGrpNo) +{ + if (mOptions.oobType == NET_OOB_TCP) { + return Connect(mOobIp, mOobPort, payload, ep, flags, serverGrpNo, clientGrpNo); + } else if (mOptions.oobType == NET_OOB_UDS) { + return Connect(mUdsName, 0, payload, ep, flags, serverGrpNo, clientGrpNo); + } + return NN_ERROR; +} + +NResult NetDriverRDMAWithOob::Connect(const std::string &serverUrl, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) +{ + if (NN_UNLIKELY(!mInited.load())) { + NN_LOG_ERROR("Verbs Driver " << mName << " is not initialized"); + return NN_NOT_INITIALIZED; + } + + if (NN_UNLIKELY(!mStarted)) { + NN_LOG_ERROR("Verbs Failed to connect on driver " << mName << " as it is not started"); + return NN_ERROR; + } + + if (payload.size() > NN_NO1024) { + NN_LOG_ERROR("Verbs Failed to connect server via payload size " << payload.size() << " over limit"); + return NN_INVALID_PARAM; + } + if (NN_UNLIKELY(NetFunc::NN_ValidateUrl(serverUrl) != NN_OK)) { + NN_LOG_ERROR("Invalid url"); + return NN_PARAM_INVALID; + } + + NetDriverOobType type; + std::string ip; + uint16_t port = 0; + if (NN_UNLIKELY(ParseUrl(serverUrl, type, ip, port) != NN_OK)) { + NN_LOG_WARN("Invalid url, url:" << serverUrl); + return NN_INVALID_PARAM; + } + + OOBTCPClientPtr client; + if (mEnableTls) { + auto oobSSLClient = new (std::nothrow) + OOBSSLClient(type, ip, port, mTlsPrivateKeyCB, mTlsCertCB, mTlsCaCallback); + NN_ASSERT_LOG_RETURN(oobSSLClient != nullptr, NN_NEW_OBJECT_FAILED) + oobSSLClient->SetTlsOptions(mOptions); + oobSSLClient->SetPSKCallback(mPskFindSessionCb, mPskUseSessionCb); + client = oobSSLClient; + } else { + client = new (std::nothrow) OOBTCPClient(mOptions.oobType, ip, port); + NN_ASSERT_LOG_RETURN(client.Get() != nullptr, NN_NEW_OBJECT_FAILED) + } + + /* all kind of drivers can connect to peer to get an ep */ + if ((flags & NET_EP_SELF_POLLING) || (flags & NET_EP_EVENT_POLLING)) { + return ConnectSyncEp(client, payload, outEp, flags, serverGrpNo, ctx); + } + + return Connect(client, payload, outEp, serverGrpNo, clientGrpNo, ctx); +} + +NResult NetDriverRDMAWithOob::Connect(const std::string &oobIp, uint16_t oobPort, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) +{ + if (NN_UNLIKELY(!mInited.load())) { + NN_LOG_ERROR("Verbs Driver " << mName << " is not initialized"); + return NN_NOT_INITIALIZED; + } + + if (NN_UNLIKELY(!mStarted)) { + NN_LOG_ERROR("Verbs Failed to connect on driver " << mName << " as it is not started"); + return NN_ERROR; + } + + if (payload.size() > NN_NO1024) { + NN_LOG_ERROR("Verbs Failed to connect server via payload size " << payload.size() << " over limit"); + return NN_INVALID_PARAM; + } + + OOBTCPClientPtr client; + if (mEnableTls) { + auto oobSSLClient = new (std::nothrow) + OOBSSLClient(mOptions.oobType, oobIp, oobPort, mTlsPrivateKeyCB, mTlsCertCB, mTlsCaCallback); + NN_ASSERT_LOG_RETURN(oobSSLClient != nullptr, NN_NEW_OBJECT_FAILED) + oobSSLClient->SetTlsOptions(mOptions); + oobSSLClient->SetPSKCallback(mPskFindSessionCb, mPskUseSessionCb); + client = oobSSLClient; + } else { + client = new (std::nothrow) OOBTCPClient(mOptions.oobType, oobIp, oobPort); + NN_ASSERT_LOG_RETURN(client.Get() != nullptr, NN_NEW_OBJECT_FAILED) + } + + /* all kind of drivers can connect to peer to get an ep */ + if ((flags & NET_EP_SELF_POLLING) || (flags & NET_EP_EVENT_POLLING)) { + return ConnectSyncEp(client, payload, outEp, flags, serverGrpNo, ctx); + } + + return Connect(client, payload, outEp, serverGrpNo, clientGrpNo, ctx); +} + +NResult NetDriverRDMAWithOob::ConnectSyncEp(const OOBTCPClientPtr &client, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint32_t flags, uint8_t serverGrpNo, uint64_t ctx) +{ + /* try to connect to oob server */ + OOBTCPConnection *conn = nullptr; + NResult result = NN_OK; + if (NN_UNLIKELY((result = client->Connect(conn)) != 0)) { + NN_LOG_ERROR("Verbs Failed to connect server via oob,result " << result); + return result; + } + + NetLocalAutoDecreasePtr autoDecPtr(conn); + if (client->GetOobType() == NET_OOB_TCP) { + conn->SetIpAndPort(client->GetServerIp(), client->GetServerPort()); + } else { + conn->SetIpAndPort(client->GetServerUdsName(), 0); + } + + if (mOptions.enableMultiRail) { + ConnectHeader driverHeader; + SetDriverConnHeader(driverHeader, mBandWidth, mDevIndex); + if (NN_UNLIKELY((result = conn->Send(&driverHeader, sizeof(ConnectHeader))) != 0)) { + NN_LOG_ERROR("Failed to send driver info " << mName << ", Result " << result); + return result; + } + + ConnectHeader header{}; + void *grpnobuf = static_cast(&header); + result = conn->Receive(grpnobuf, sizeof(ConnectHeader)); + if (NN_UNLIKELY(result != 0)) { + NN_LOG_ERROR("Failed to receive specified device info for server, result " << result); + return result; + } + + if (header.devIndex >= NN_NO4) { + NN_LOG_ERROR("Invalid devIndex " << header.devIndex << " in header, which should be in 0 ~ 3"); + return NN_ERROR; + } + mPeerDevIndex = header.devIndex; + } + + if (NN_UNLIKELY(OOBSecureProcess::SecProcessInOOBClient(mSecInfoProvider, mSecInfoValidator, conn, mName, ctx, + mOptions.secType))) { + return NN_OOB_SEC_PROCESS_ERROR; + } + + RDMAPollingMode pollMode = ((flags & NET_EP_EVENT_POLLING)) ? EVENT_POLLING : BUSY_POLLING; + + uint16_t prePostCount = mOptions.prePostReceiveSizePerQP; + + // create + RDMASyncEndpoint *rep = nullptr; + QpOptions qpOptions(mOptions.qpSendQueueSize, mOptions.qpReceiveQueueSize, mOptions.mrSendReceiveSegSize, + mOptions.prePostReceiveSizePerQP); + if (NN_UNLIKELY((result = + RDMASyncEndpoint::Create(mName, mContext, pollMode, prePostCount + NN_NO4, qpOptions, rep)) != 0)) { + NN_LOG_ERROR("Failed to create sync ep for new connection in Driver " << mName << " , result " << result); + return result; + } + + NetLocalAutoDecreasePtr repAutoDecPtr(rep); + if (NN_UNLIKELY((result = rep->Initialize()) != 0)) { + NN_LOG_ERROR("Failed to initialize ep for new connection in Driver " << mName << " , result " << result); + return result; + } + + /* send connection header */ + ConnectHeader header; + SetConnHeader(header, mOptions.magic, mOptions.version, serverGrpNo, Protocol(), mMajorVersion, + mMinorVersion, mOptions.tlsVersion); + + if (NN_UNLIKELY((result = conn->Send(&header, sizeof(ConnectHeader))) != 0)) { + NN_LOG_ERROR("Failed to send server worker grpno in Driver " << mName << ", result " << result); + return result; + } + + /* receive connect response and peer ep id */ + ConnRespWithUId respWithUId {}; + void *ackBuf = static_cast(&respWithUId); + if (NN_UNLIKELY((result = conn->Receive(ackBuf, sizeof(ConnRespWithUId))) != 0)) { + NN_LOG_ERROR("Failed receive ServerAck in Driver " << mName << ", result " << result); + return result; + } + + /* connect response */ + auto serverAck = respWithUId.connResp; + switch (serverAck) { + case MAGIC_MISMATCH: + NN_LOG_ERROR("Verbs Failed to pass server magic validation " << mName << ", result " << serverAck); + return NN_CONNECT_REFUSED; + case WORKER_GRPNO_MISMATCH: + case WORKER_NOT_STARTED: + NN_LOG_ERROR("Verbs Failed to choose worker or not started " << mName << ", result " << serverAck); + return NN_CONNECT_REFUSED; + case PROTOCOL_MISMATCH: + NN_LOG_ERROR("Verbs Failed to pass server protocol validation " << mName << ", result " << serverAck); + return NN_CONNECT_PROTOCOL_MISMATCH; + case SERVER_INTERNAL_ERROR: + NN_LOG_ERROR("Verbs Server error happened, connection refused " << mName << ", result " << serverAck); + return NN_ERROR; + case VERSION_MISMATCH: + NN_LOG_ERROR("Verbs Failed to pass server version validation " << mName << ", result " << serverAck); + return NN_CONNECT_REFUSED; + case TLS_VERSION_MISMATCH: + NN_LOG_ERROR("Verbs Failed to pass server tls version validation " << mName << ", result " << serverAck); + return NN_CONNECT_REFUSED; + case OK: + break; + default: + NN_LOG_ERROR("Verbs Server error happened, connection refused " << mName << ", result " << serverAck); + return NN_ERROR; + } + + /* peer ep id */ + auto id = respWithUId.epId; + NN_LOG_TRACE_INFO("new ep id will be set as " << id << " in driver " << mName); + + // exchange info + NN_LOG_TRACE_INFO("get and send exchange info of ep"); + RDMAQpExchangeInfo info {}; + if (mHeartBeat != nullptr) { + mHeartBeat->GetRemoteHbInfo(info); + } + info.maxSendWr = mOptions.qpSendQueueSize; + info.maxReceiveWr = mOptions.qpReceiveQueueSize; + info.receiveSegSize = mOptions.mrSendReceiveSegSize; + info.receiveSegCount = mOptions.prePostReceiveSizePerQP; + if (NN_UNLIKELY((result = rep->GetExchangeInfo(info)) != 0)) { + NN_LOG_ERROR("Failed to get ep exchange info in Driver " << mName << ", result " << result); + return result; + } + if (NN_UNLIKELY((result = conn->Send(&info, sizeof(RDMAQpExchangeInfo))) != 0)) { + NN_LOG_ERROR("Failed to send ep exchange info in Driver " << mName << ", result " << result); + return result; + } + + // send payload len + uint32_t payloadLen = payload.length(); + if (NN_UNLIKELY((result = conn->Send(&payloadLen, sizeof(uint32_t))) != 0)) { + NN_LOG_ERROR("Failed to send connection payload length in Driver " << mName << ", result " << result); + return result; + } + + // send payload + if (payloadLen > 0) { + auto payloadPtr = reinterpret_cast(const_cast(payload.c_str())); + if ((result = conn->Send(payloadPtr, payloadLen)) != 0) { + NN_LOG_ERROR("Failed to send connection payload in Driver " << mName << ", result " << result); + return result; + } + } + + void *tmp = static_cast(&info); + if (NN_UNLIKELY((result = conn->Receive(tmp, sizeof(RDMAQpExchangeInfo))) != 0)) { + NN_LOG_ERROR("Failed to receive ep exchange info in Driver " << mName << ", result " << result); + return result; + } + + NN_LOG_TRACE_INFO("remote qp ex info lid " << info.lid << ", qpn " << info.qpn << ", gid interface " << + info.gid.global.interface_id << ", pre-post-receive-count " << info.receiveSegCount); + if (NN_UNLIKELY((result = rep->ChangeToReady(info)) != 0)) { + NN_LOG_ERROR("Failed to change ep to ready in Driver " << mName << ", result " << result); + return result; + } + + rep->PeerIpAndPort(conn->GetIpAndPort()); + + auto *mrSegs = new (std::nothrow) uintptr_t[prePostCount]; + if (mrSegs == nullptr) { + NN_LOG_ERROR("Failed to create mr address array in Driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + NetLocalAutoFreePtr segAutoDelete(mrSegs, true); + + if (NN_UNLIKELY(!rep->GetFreeBufferN(mrSegs, prePostCount))) { + NN_LOG_ERROR("Failed to get free mr from pool, result " << result); + return NN_ERROR; + } + + uint16_t i = 0; + for (; i < prePostCount; i++) { + if ((result = rep->PostReceive(mrSegs[i], mOptions.mrSendReceiveSegSize, rep->GetLKey())) != 0) { + // do later if failure, qp should broken at this time + break; + } + } + + for (; i < prePostCount; i++) { + rep->ReturnBuffer(mrSegs[i]); + } + + // create endpoint + static UBSHcomNetWorkerIndex workerIndex; + workerIndex.driverIdx = mIndex; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetSyncEndpoint(id, rep, this, workerIndex); + if (NN_UNLIKELY(ep.Get() == nullptr)) { + NN_LOG_ERROR("Failed to create UBSHcomNetEndpoint in Driver " << mName << ", probably out of memory"); + // do later: handle pre post-ed mr + return NN_NEW_OBJECT_FAILED; + } + if (mEnableTls) { + auto asyncEp = ep.ToChild(); + auto oobSslConn = dynamic_cast(conn); + if (NN_UNLIKELY(asyncEp == nullptr || oobSslConn == nullptr)) { + NN_LOG_ERROR("dynamic cast error"); + return NN_OOB_SEC_PROCESS_ERROR; + } + asyncEp->EnableEncrypt(mOptions); + asyncEp->SetSecrets(oobSslConn->Secret()); + } + ep->StoreConnInfo(NetFunc::GetIpByFd(conn->GetFd()), conn->ListenPort(), header.version, payload); + + // receive server ready signal + int8_t ready = -1; + tmp = static_cast(&ready); + result = conn->Receive(tmp, sizeof(int8_t)); + if (result != 0 || ready != 1) { + NN_LOG_ERROR("Failed to connect to server as server not respond or return not ready, result " << result); + // do later: handle pre post-ed mr + return NN_ERROR; + } + + rep->Qp()->UpContext(reinterpret_cast(ep.Get())); + { + std::lock_guard locker(mEndPointsMutex); + mEndPoints.emplace(id, ep); + } + + ep->State().Set(NEP_ESTABLISHED); + + NN_LOG_INFO("New connect to tcp:" << client->GetServerIp() << ":" << client->GetServerPort() <<" or uds: " << + client->GetServerUdsName() << " established, sync ep id " << ep->Id()); + + ep->mDevIndex = mDevIndex; + ep->mPeerDevIndex = mPeerDevIndex; + ep->mBandWidth = mBandWidth; + outEp = ep; + reinterpret_cast(ep.Get())->GetRdmaEp()->Qp()->UpId(ep->Id()); + return NN_OK; +} + +NResult NetDriverRDMAWithOob::Connect(const OOBTCPClientPtr &client, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) +{ + /* try to connect to oob server */ + OOBTCPConnection *conn = nullptr; + NResult result = NN_OK; + if ((result = client->Connect(conn)) != 0) { + NN_LOG_ERROR("Verbs Failed to connect server via oob, Result " << result); + return result; + } + + NetLocalAutoDecreasePtr autoDecPtr(conn); + if (client->GetOobType() == NET_OOB_TCP) { + conn->SetIpAndPort(client->GetServerIp(), client->GetServerPort()); + } else { + conn->SetIpAndPort(client->GetServerUdsName(), 0); + } + + if (mOptions.enableMultiRail) { + ConnectHeader driverHeader; + SetDriverConnHeader(driverHeader, mBandWidth, mDevIndex); + if ((result = conn->Send(&driverHeader, sizeof(ConnectHeader))) != 0) { + NN_LOG_ERROR("Failed to send driver info " << mName << ", Result " << result); + return result; + } + + ConnectHeader header{}; + void *grpnobuf = &header; + result = conn->Receive(grpnobuf, sizeof(ConnectHeader)); + if (result != 0) { + NN_LOG_ERROR("Failed to receive specified device info for server, Result " << result); + return result; + } + + if (header.devIndex >= NN_NO4) { + NN_LOG_ERROR("Invalid devIndex " << header.devIndex << " in header, which should be in 0 ~ 3"); + return NN_ERROR; + } + mPeerDevIndex = header.devIndex; + } + + if (NN_UNLIKELY(OOBSecureProcess::SecProcessInOOBClient(mSecInfoProvider, mSecInfoValidator, conn, mName, ctx, + mOptions.secType))) { + return NN_OOB_SEC_PROCESS_ERROR; + } + + /* send connection header & grpNo */ + auto startSendGrpNo = NetMonotonic::TimeUs(); + ConnectHeader header; + SetConnHeader(header, mOptions.magic, mOptions.version, serverGrpNo, Protocol(), mMajorVersion, + mMinorVersion, mOptions.tlsVersion); + if ((result = conn->Send(&header, sizeof(ConnectHeader))) != 0) { + NN_LOG_ERROR("Verbs Failed to send server worker grpno in Driver " << mName << ", result " << result); + return result; + } + + /* receive connect response and peer ep id */ + ConnRespWithUId respWithUId {}; + void *ackBuf = static_cast(&respWithUId); + if ((result = conn->Receive(ackBuf, sizeof(ConnRespWithUId))) != 0) { + NN_LOG_ERROR("Verbs Failed receive ServerAck in Driver " << mName << ", result " << result); + return result; + } + + /* connect response */ + auto serverAck = respWithUId.connResp; + switch (serverAck) { + case MAGIC_MISMATCH: + NN_LOG_ERROR("Verbs Failed to pass server magic validation " << mName << ", Result " << serverAck); + return NN_CONNECT_REFUSED; + case WORKER_GRPNO_MISMATCH: + case WORKER_NOT_STARTED: + NN_LOG_ERROR("Verbs Failed to choose worker or not started " << mName << ", Result " << serverAck); + return NN_CONNECT_REFUSED; + case PROTOCOL_MISMATCH: + NN_LOG_ERROR("Verbs Failed to pass server protocol validation " << mName << ", Result " << serverAck); + return NN_CONNECT_PROTOCOL_MISMATCH; + case SERVER_INTERNAL_ERROR: + NN_LOG_ERROR("Verbs Server error happened, connection refused " << mName << ", Result " << serverAck); + return NN_ERROR; + case VERSION_MISMATCH: + NN_LOG_ERROR("Verbs Failed to pass server version validation " << mName << ", Result " << serverAck); + return NN_CONNECT_REFUSED; + case TLS_VERSION_MISMATCH: + NN_LOG_ERROR("Verbs Failed to pass server tls version validation " << mName << ", Result " << serverAck); + return NN_CONNECT_REFUSED; + case OK: + break; + default: + NN_LOG_ERROR("Verbs Server error happened, connection refused " << mName << ", Result " << serverAck); + return NN_ERROR; + } + + auto endSendGrpNo = NetMonotonic::TimeUs(); + auto sendGrpNoTime = endSendGrpNo - startSendGrpNo; + if (NN_UNLIKELY(sendGrpNoTime > MAX_OP_TIME_US)) { + NN_LOG_WARN("Verbs Send groupNo time too long: " << sendGrpNoTime << " us."); + } + + /* peer ep id */ + auto id = respWithUId.epId; + NN_LOG_TRACE_INFO("new ep id will be set as " << id << " in driver " << mName); + + /* create rdma ep */ + RDMAAsyncEndPoint *rep = nullptr; + uint16_t workerIndex = 0; + if (NN_UNLIKELY(!mClientLb->ChooseWorker(clientGrpNo, std::to_string(id), workerIndex)) || + workerIndex >= mWorkers.size()) { + NN_LOG_ERROR("Failed to choose worker during connect in driver " << mName); + return NN_ERROR; + } + + NN_ASSERT_LOG_RETURN(workerIndex < mWorkers.size(), NN_ERROR) + auto *worker = mWorkers[workerIndex]; + + if (!worker->IsWorkStarted()) { + NN_LOG_ERROR("Failed to connect worker group no " << clientGrpNo << " in " << mName); + return NN_ERROR; + } + + if ((result = RDMAAsyncEndPoint::Create(mName, worker, rep)) != 0) { + NN_LOG_ERROR("Failed to create ep for new connection in Driver " << mName << " , result " << result); + return result; + } + + NetLocalAutoDecreasePtr repAutoDecPtr(rep); + if ((result = rep->Initialize()) != 0) { + NN_LOG_ERROR("Failed to initialize ep for new connection in Driver " << mName << " , result " << result); + return result; + } + + /* fill and send exchange info */ + auto startExchInfo = NetMonotonic::TimeUs(); + NN_LOG_TRACE_INFO("get and send exchange info of ep"); + RDMAQpExchangeInfo info {}; + if (mHeartBeat != nullptr) { + mHeartBeat->GetRemoteHbInfo(info); + } + info.maxSendWr = mOptions.qpSendQueueSize; + info.maxReceiveWr = mOptions.qpReceiveQueueSize; + info.receiveSegCount = mOptions.prePostReceiveSizePerQP; + info.receiveSegSize = mOptions.mrSendReceiveSegSize; + + if (((result = rep->GetExchangeInfo(info)) != 0)) { + NN_LOG_ERROR("Failed to get ep exchange info in Driver " << mName << ", result " << result); + return result; + } + + if (((result = conn->Send(&info, sizeof(RDMAQpExchangeInfo))) != 0)) { + NN_LOG_ERROR("Failed to send ep exchange info in Driver " << mName << ", result " << result); + return result; + } + + auto prePostCount = mOptions.prePostReceiveSizePerQP; + // send payload len + uint32_t payloadLen = payload.length(); + if ((result = conn->Send(&payloadLen, sizeof(uint32_t))) != 0) { + NN_LOG_ERROR("Failed to send connection payload length in Driver " << mName << ", result " << result); + return result; + } + + // send payload + if (payloadLen > 0) { + auto payloadPtr = reinterpret_cast(const_cast(payload.c_str())); + if ((result = conn->Send(payloadPtr, payloadLen)) != 0) { + NN_LOG_ERROR("Failed to send connection payload in Driver " << mName << ", Result " << result); + return result; + } + } + + /* receive exchange info */ + void *tmp = static_cast(&info); + if ((result = conn->Receive(tmp, sizeof(RDMAQpExchangeInfo))) != 0) { + NN_LOG_ERROR("Failed to receive ep exchange info in Driver " << mName << ", result " << result << + ". check your header"); + return result; + } + + /* change to ready */ + NN_LOG_TRACE_INFO("remote qp ex info lid " << info.lid << ", qpn " << info.qpn << ", gid interface " << + info.gid.global.interface_id << ", pre-post-receive-count " << info.receiveSegCount); + if ((result = rep->ChangeToReady(info)) != 0) { + NN_LOG_ERROR("Verbs Failed to change ep to ready in Driver " << mName << ", result " << result); + return result; + } + + rep->PeerIpAndPort(conn->GetIpAndPort()); + + auto *mrSegs = new (std::nothrow) uintptr_t[prePostCount]; + if (mrSegs == nullptr) { + NN_LOG_ERROR("Verbs Failed to create mr address array in Driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + NetLocalAutoFreePtr segAutoDelete(mrSegs, true); + + if (!rep->GetFreeBufferN(mrSegs, prePostCount)) { + NN_LOG_ERROR("Failed to get free mr from pool, result " << result); + return NN_ERROR; + } + + uint16_t i = 0; + for (; i < prePostCount; i++) { + if ((result = worker->PostReceive(rep->Qp(), mrSegs[i], mOptions.mrSendReceiveSegSize, rep->GetLKey())) != 0) { + break; + } + } + + for (; i < prePostCount; i++) { + rep->ReturnBuffer(mrSegs[i]); + } + auto endExchInfo = NetMonotonic::TimeUs(); + auto exchInfoTime = endExchInfo - startExchInfo; + if (NN_UNLIKELY(exchInfoTime > MAX_OP_TIME_US)) { + NN_LOG_WARN("Exchange Info time too long: " << exchInfoTime << " us."); + } + + // create endpoint + auto startCreateEp = NetMonotonic::TimeUs(); + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(id, rep, this, worker->Index()); + if (ep.Get() == nullptr) { + NN_LOG_ERROR("Failed to create UBSHcomNetEndpoint in Driver " << mName << ", probably out of memory"); + // do later: handle pre post-ed mr + return NN_NEW_OBJECT_FAILED; + } + + auto asyncEp = ep.ToChild(); + if (NN_UNLIKELY(asyncEp == nullptr)) { + NN_LOG_ERROR("To Child Failed"); + return NN_ERROR; + } + + if (mEnableTls) { + auto oobSslConn = dynamic_cast(conn); + if (NN_UNLIKELY(oobSslConn == nullptr)) { + NN_LOG_ERROR("dynamic cast error"); + return NN_OOB_SEC_PROCESS_ERROR; + } + asyncEp->EnableEncrypt(mOptions); + asyncEp->SetSecrets(oobSslConn->Secret()); + } + + rep->Qp()->UpContext(reinterpret_cast(ep.Get())); + ep->StoreConnInfo(NetFunc::GetIpByFd(conn->GetFd()), conn->ListenPort(), header.version, payload); + asyncEp->SetRemoteHbInfo(info.hbAddress, info.hbKey, info.hbMrSize); + + // receive server ready signal + int8_t ready = -1; + tmp = static_cast(&ready); + result = conn->Receive(tmp, sizeof(int8_t)); + if (result != 0 || ready != 1) { + NN_LOG_ERROR("Failed to connect to server as server not responses or return not ready, result " << result); + // do later: handle pre post-ed mr + return NN_ERROR; + } + + ep->State().Set(NEP_ESTABLISHED); + { + std::lock_guard locker(mEndPointsMutex); + auto ret = mEndPoints.emplace(ep->Id(), ep); + if (!ret.second) { + NN_LOG_ERROR("Failed to emplace ep, ep Id " << ep->Id()); + return NN_ERROR; + } + } + + NN_LOG_INFO("New connect to tcp:" << client->GetServerIp() << ":" << client->GetServerPort() <<" or uds: " << + client->GetServerUdsName() << " established, async ep id " << ep->Id() << " worker info " << + worker->DetailName()); + ep->mDevIndex = mDevIndex; + ep->mPeerDevIndex = mPeerDevIndex; + ep->mBandWidth = mBandWidth; + outEp = ep; + reinterpret_cast(ep.Get())->GetRdmaEp()->Qp()->UpId(ep->Id()); + auto endCreateEp = NetMonotonic::TimeUs(); + auto createEpTime = endCreateEp - startCreateEp; + if (NN_UNLIKELY(createEpTime > MAX_OP_TIME_US)) { + NN_LOG_WARN("Create endpoint time too long: " << createEpTime << " us."); + } + return NN_OK; +} + +void NetDriverRDMAWithOob::ProcessErrorNewRequest(RDMAOpContextInfo *ctx) +{ + if (NN_UNLIKELY(ctx == nullptr || ctx->qp == nullptr || ctx->qp->UpContext1() == 0)) { + NN_LOG_ERROR("Ctx or QP or Worker is null of RequestReceived in Driver " << mName << ""); + return; + } + + if (ctx->opType == RDMAOpContextInfo::RECEIVE) { + ctx->qp->ReturnBuffer(ctx->mrMemAddr); + auto worker = reinterpret_cast(ctx->qp->UpContext1()); + worker->ReturnOpContextInfo(ctx); + // not receive remote data, do not call user callback + } else { + NN_LOG_WARN("Unreachable path"); + } +} + +NResult NetDriverRDMAWithOob::SendRequestFinishedCB(RDMAOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, + RDMAWorker *worker) +{ + int result = 0; + if (ctx->opType == RDMAOpContextInfo::SEND) { + if (NN_UNLIKELY(memcpy_s(&(netCtx.mHeader), sizeof(UBSHcomNetTransHeader), + reinterpret_cast(ctx->mrMemAddr), sizeof(UBSHcomNetTransHeader)) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + } else { + netCtx.mHeader.Invalid(); + } + netCtx.mResult = RDMAOpContextInfo::GetNResult(ctx->opResultType); + netCtx.mEp.Set(reinterpret_cast(ctx->qp->UpContext())); + netCtx.mMessage = nullptr; + netCtx.mOpType = ctx->opType == RDMAOpContextInfo::SEND ? UBSHcomNetRequestContext::NN_SENT : + UBSHcomNetRequestContext::NN_SENT_RAW; + netCtx.mOriginalReq = {}; + // if PostSend implement with one side memory, the lAddress should be valued with ctx->mrMemAddr. + netCtx.mOriginalReq.lAddress = 0; + netCtx.mOriginalReq.size = ctx->dataSize; + netCtx.mOriginalReq.upCtxSize = ctx->upCtxSize; + + if (netCtx.mOriginalReq.upCtxSize > 0 && + netCtx.mOriginalReq.upCtxSize <= sizeof(RDMASendReadWriteRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(netCtx.mOriginalReq.upCtxData, NN_NO16, ctx->upCtx, ctx->upCtxSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + } + + if (NN_UNLIKELY(!mDriverSendMR->ReturnBuffer(ctx->mrMemAddr))) { + NN_LOG_ERROR("Failed to return mr segment back in Driver " << mName); + } + // return context to worker, and ctx is set null, not usable anymore + worker->ReturnOpContextInfo(ctx); + // call to callback + if (NN_UNLIKELY((result = mRequestPostedHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call requestPostedHandler in Driver " << mName << + " return non-zero for receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << + netCtx.mHeader.dataLength << "]"); + } + netCtx.mEp.Set(nullptr); + return NN_OK; +} + +NResult NetDriverRDMAWithOob::SendRawSglFinishedCB(RDMAOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, + RDMAWorker *worker) +{ + int result = 0; + auto sgeCtx = reinterpret_cast(ctx->upCtx); + auto sglCtx = sgeCtx->ctx; + result = RDMAOpContextInfo::GetNResult(ctx->opResultType); + // set context + netCtx.mEp.Set(reinterpret_cast(ctx->qp->UpContext())); + netCtx.mResult = sglCtx->result < result ? result : sglCtx->result; + netCtx.mOpType = UBSHcomNetRequestContext::NN_SENT_RAW_SGL; + netCtx.mHeader.Invalid(); + netCtx.mMessage = nullptr; + if (NN_UNLIKELY(memcpy_s(netCtx.iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, sglCtx->iov, + sizeof(UBSHcomNetTransSgeIov) * sglCtx->iovCount) != NN_OK)) { + NN_LOG_ERROR("Failed to copy request to sglCtx"); + return NN_INVALID_PARAM; + } + netCtx.mOriginalSglReq.iov = netCtx.iov; + netCtx.mOriginalSglReq.iovCount = sglCtx->iovCount; + netCtx.mOriginalSglReq.upCtxSize = sglCtx->upCtxSize; + if (netCtx.mOriginalSglReq.upCtxSize > 0 && + netCtx.mOriginalSglReq.upCtxSize <= sizeof(UBSHcomNetTransSglRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(netCtx.mOriginalSglReq.upCtxData, NN_NO16, sglCtx->upCtx, sglCtx->upCtxSize) != + NN_OK)) { + NN_LOG_ERROR("Failed to copy request to sglCtx"); + return NN_INVALID_PARAM; + } + } + worker->ReturnSglContextInfo(sglCtx); + // called to callback + if (NN_UNLIKELY((result = mRequestPostedHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call requestPostedHandler in Driver " << mName << " return non-zero for sgl type " << + ctx->opType << " done"); + } + netCtx.mEp.Set(nullptr); + + // buffer should return when encrypt + if (mEnableTls) { + (void)mDriverSendMR->ReturnBuffer(ctx->mrMemAddr); + } + + worker->ReturnOpContextInfo(ctx); + return NN_OK; +} + +NResult NetDriverRDMAWithOob::SendSglInlineFinishedCB(RDMAOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, + RDMAWorker *worker) +{ + int result = 0; + + netCtx.mResult = RDMAOpContextInfo::GetNResult(ctx->opResultType); + netCtx.mEp.Set(reinterpret_cast(ctx->qp->UpContext())); + netCtx.mMessage = nullptr; + netCtx.mOpType = UBSHcomNetRequestContext::NN_SENT_SGL_INLINE; + netCtx.mHeader.Invalid(); + netCtx.mOriginalReq = {}; + netCtx.mOriginalReq.lAddress = ctx->mrMemAddr; + netCtx.mOriginalReq.size = ctx->dataSize; + netCtx.mOriginalReq.upCtxSize = ctx->upCtxSize; + + if (netCtx.mOriginalReq.upCtxSize > 0 && + netCtx.mOriginalReq.upCtxSize <= sizeof(RDMASendReadWriteRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(netCtx.mOriginalReq.upCtxData, NN_NO16, ctx->upCtx, ctx->upCtxSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + } + // return context to worker, and ctx is set null, not usable anymore + worker->ReturnOpContextInfo(ctx); + // call to callback + if (NN_UNLIKELY((result = mRequestPostedHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call requestPostedHandler in Driver " << mName << + " return non-zero for receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << + netCtx.mHeader.dataLength << "]"); + } + netCtx.mEp.Set(nullptr); + return NN_OK; +} + +int NetDriverRDMAWithOob::SendFinishedCB(RDMAOpContextInfo *ctx) +{ + static thread_local UBSHcomNetRequestContext netCtx {}; + ctx->qp->ReturnPostSendWr(); + auto worker = reinterpret_cast(ctx->qp->UpContext1()); + + if (ctx->opType == RDMAOpContextInfo::SEND || ctx->opType == RDMAOpContextInfo::SEND_RAW) { + return SendRequestFinishedCB(ctx, netCtx, worker); + } else if (ctx->opType == RDMAOpContextInfo::SEND_RAW_SGL) { + return SendRawSglFinishedCB(ctx, netCtx, worker); + } else if (ctx->opType == RDMAOpContextInfo::SEND_SGL_INLINE) { + return SendSglInlineFinishedCB(ctx, netCtx, worker); + } else { + NN_LOG_WARN("Unreachable path"); + } + + return NN_OK; +} + +void NetDriverRDMAWithOob::ProcessErrorSendFinished(RDMAOpContextInfo *ctx) +{ + if (NN_UNLIKELY(ctx == nullptr || ctx->qp == nullptr || ctx->qp->UpContext1() == 0)) { + NN_LOG_ERROR("Ctx or QP or Worker is null of RequestReceived in Driver " << mName << ""); + return; + } + + SendFinishedCB(ctx); +} + +int NetDriverRDMAWithOob::OneSideDoneCB(RDMAOpContextInfo *ctx) +{ + int result = 0; + static thread_local UBSHcomNetRequestContext netCtx {}; + auto worker = reinterpret_cast(ctx->qp->UpContext1()); + ctx->qp->ReturnOneSideWr(); + if (ctx->opType == RDMAOpContextInfo::WRITE || ctx->opType == RDMAOpContextInfo::READ) { + // set context + netCtx.mResult = RDMAOpContextInfo::GetNResult(ctx->opResultType); + netCtx.mEp.Set(reinterpret_cast(ctx->qp->UpContext())); + netCtx.mOpType = + ctx->opType == RDMAOpContextInfo::WRITE ? UBSHcomNetRequestContext::NN_WRITTEN : + UBSHcomNetRequestContext::NN_READ; + netCtx.mHeader.Invalid(); + netCtx.mMessage = nullptr; + netCtx.mOriginalReq.lAddress = ctx->mrMemAddr; + netCtx.mOriginalReq.lKey = ctx->lKey; + netCtx.mOriginalReq.size = ctx->dataSize; + netCtx.mOriginalReq.upCtxSize = ctx->upCtxSize; + + if (netCtx.mOriginalReq.upCtxSize > 0 && + netCtx.mOriginalReq.upCtxSize <= sizeof(RDMASendReadWriteRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(netCtx.mOriginalReq.upCtxData, NN_NO16, ctx->upCtx, ctx->upCtxSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + } + + // return context to worker and ctx is not usable anymore + worker->ReturnOpContextInfo(ctx); + + // called to callback + if (NN_UNLIKELY((result = mOneSideDoneHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call oneSideDoneHandler in Driver " << mName << " done"); + } + netCtx.mEp.Set(nullptr); + } else if (ctx->opType == RDMAOpContextInfo::SGL_WRITE || ctx->opType == RDMAOpContextInfo::SGL_READ) { + auto sgeCtx = reinterpret_cast(ctx->upCtx); + auto sglCtx = sgeCtx->ctx; + result = RDMAOpContextInfo::GetNResult(ctx->opResultType); + sglCtx->result = sglCtx->result < result ? result : sglCtx->result; + auto refCount = __sync_add_and_fetch(&(sglCtx->refCount), 1); + if (refCount != sglCtx->iovCount) { + worker->ReturnOpContextInfo(ctx); + return NN_OK; + } + // set context + netCtx.mEp.Set(reinterpret_cast(ctx->qp->UpContext())); + netCtx.mResult = sglCtx->result; + netCtx.mOpType = ctx->opType == RDMAOpContextInfo::SGL_WRITE ? UBSHcomNetRequestContext::NN_SGL_WRITTEN : + UBSHcomNetRequestContext::NN_SGL_READ; + netCtx.mHeader.Invalid(); + netCtx.mMessage = nullptr; + if (NN_UNLIKELY(memcpy_s(netCtx.iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, sglCtx->iov, + sizeof(UBSHcomNetTransSgeIov) * sglCtx->iovCount) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + netCtx.mOriginalSglReq.iov = netCtx.iov; + netCtx.mOriginalSglReq.iovCount = sglCtx->iovCount; + netCtx.mOriginalSglReq.upCtxSize = sglCtx->upCtxSize; + if (netCtx.mOriginalSglReq.upCtxSize > 0 && + netCtx.mOriginalSglReq.upCtxSize <= sizeof(UBSHcomNetTransSglRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(netCtx.mOriginalSglReq.upCtxData, NN_NO16, sglCtx->upCtx, sglCtx->upCtxSize) != + NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + } + worker->ReturnSglContextInfo(sglCtx); + // called to callback + if (NN_UNLIKELY((result = mOneSideDoneHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call oneSideDoneHandler in Driver " << mName << " return non-zero for sgl type " << + ctx->opType << " done"); + } + netCtx.mEp.Set(nullptr); + worker->ReturnOpContextInfo(ctx); + } else if (ctx->opType == RDMAOpContextInfo::HB_WRITE) { + auto ep = reinterpret_cast(ctx->qp->UpContext()); + if (ctx->opResultType == RDMAOpContextInfo::SUCCESS) { + ep->HbRecordCount(); + } + + worker->ReturnOpContextInfo(ctx); + } else { + NN_LOG_WARN("Unreachable path"); + } + + return NN_OK; +} + +void NetDriverRDMAWithOob::ProcessErrorOneSideDone(RDMAOpContextInfo *ctx) +{ + if (NN_UNLIKELY(ctx == nullptr || ctx->qp == nullptr || ctx->qp->UpContext1() == 0)) { + NN_LOG_ERROR("Ctx or QP or Worker is null of RequestReceived in Driver " << mName << ""); + return; + } + + OneSideDoneCB(ctx); +} + +void NetDriverRDMAWithOob::ProcessEpError(uintptr_t ep) +{ + auto epPtr = reinterpret_cast(ep); + + bool process = false; + if (NN_UNLIKELY(!epPtr->EPBrokenProcessed().compare_exchange_strong(process, true))) { + NN_LOG_WARN("Ep id " << epPtr->Id() << " broken handled by other thread"); + return; + } + + if (epPtr->State().Compare(NEP_ESTABLISHED)) { + epPtr->State().Set(NEP_BROKEN); + } + + auto qp = epPtr->GetRdmaEp()->Qp(); + qp->Stop(); + + RDMAOpContextInfo *remainingOpCtx = nullptr; + RDMAOpContextInfo *nextOpCtx = nullptr; + qp->GetCtxPosted(remainingOpCtx); + while (remainingOpCtx != nullptr) { + ProcessErrorContext(nextOpCtx, remainingOpCtx, epPtr); + } + + // when ep set broken, there maybe some new context add + while (qp->GetPostedCount() != 0) { + NN_LOG_INFO("Process remain op ctx, qp " << qp->Name()); + qp->GetCtxPosted(remainingOpCtx); + while (remainingOpCtx != nullptr) { + ProcessErrorContext(nextOpCtx, remainingOpCtx, epPtr); + } + } + + NN_LOG_WARN("Handle Ep state " << UBSHcomNEPStateToString(epPtr->State().Get()) << ", Ep id " << epPtr->Id() << + " , try call Ep broken handle"); + UBSHcomNetEndpointPtr netEp = reinterpret_cast(epPtr); + OOBSecureProcess::SecProcessDelEpNum(epPtr->LocalIp(), epPtr->ListenPort(), epPtr->PeerIpAndPort(), + mOobServers); + mEndPointBrokenHandler(netEp); + DestroyEndpoint(netEp); +} + +void NetDriverRDMAWithOob::ProcessQPError(RDMAOpContextInfo *ctx) +{ + if (NN_UNLIKELY(!ValidateRequestContext(ctx))) { + return; + } + + // get ep + auto epPtr = reinterpret_cast(ctx->qp->UpContext()); + ProcessEpError(reinterpret_cast(epPtr)); +} + +int NetDriverRDMAWithOob::NewRequest(RDMAOpContextInfo *ctx) +{ + if (NN_UNLIKELY(!ValidateRequestContext(ctx))) { + return NN_ERROR; + } + + if (NN_UNLIKELY(ctx->opResultType != RDMAOpContextInfo::SUCCESS)) { + ProcessQPError(ctx); + return NN_OK; + } + + static thread_local UBSHcomNetRequestContext netCtx {}; + static thread_local UBSHcomNetMessage msg; + auto worker = reinterpret_cast(ctx->qp->UpContext1()); + uint32_t immData = *reinterpret_cast(ctx->upCtx); + + if (ctx->opType == RDMAOpContextInfo::RECEIVE && immData == 0) { + return NewReceivedRequest(ctx, netCtx, msg, worker); + } else if (ctx->opType == RDMAOpContextInfo::RECEIVE && immData != 0) { + return NewReceivedRawRequest(ctx, netCtx, msg, worker, immData); + } else { + NN_LOG_WARN("Unreachable path"); + } + + return NN_OK; +} + +NResult NetDriverRDMAWithOob::NewReceivedRawRequest(RDMAOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, + UBSHcomNetMessage &msg, RDMAWorker *worker, uint32_t immData) const +{ /* for raw message */ + bool messageReady = true; + + auto qpUpContext = ctx->qp->UpContext(); + UBSHcomNetEndpointPtr ep = reinterpret_cast(qpUpContext); + auto asyncEp = ep.ToChild(); + if (NN_UNLIKELY(asyncEp == nullptr)) { + NN_LOG_ERROR("ToChild failed"); + return NN_ERROR; + } + auto tmpDataAddress = reinterpret_cast(ctx->mrMemAddr); + if (asyncEp->mIsNeedEncrypt) { + size_t decryptRawLen = asyncEp->mAes.GetRawLen(ctx->dataSize); + messageReady = msg.AllocateIfNeed(decryptRawLen); + if (NN_LIKELY(messageReady)) { + uint32_t decryptLen = 0; + if (!asyncEp->mAes.Decrypt(asyncEp->mSecrets, tmpDataAddress, ctx->dataSize, msg.mBuf, decryptLen) || + decryptLen != decryptRawLen) { + NN_LOG_ERROR("Failed to decrypt data"); + (void)worker->RePostReceive(ctx); + return NN_DECRYPT_FAILED; + } + msg.mDataLen = decryptRawLen; + } + } else { + messageReady = msg.AllocateIfNeed(ctx->dataSize); + if (NN_LIKELY(messageReady)) { + if (NN_UNLIKELY(memcpy_s(msg.mBuf, msg.GetBufLen(), tmpDataAddress, ctx->dataSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + msg.mDataLen = ctx->dataSize; + } + } + + int result = 0; + + // after repost the ctx cannot be used anymore + if (NN_UNLIKELY((result = worker->RePostReceive(ctx)) != 0)) { + NN_LOG_ERROR("Failed to repost receive in Driver " << mName << ", result " << result); + } + + if (NN_UNLIKELY(!messageReady)) { + NN_LOG_ERROR("Failed to build UBSHcomNetRequestContext or message in Driver " << mName << + ", receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << msg.mDataLen << + "] will be dropped"); + return NN_OK; + } + + netCtx.mEp.Set(reinterpret_cast(qpUpContext)); + netCtx.mMessage = &msg; + netCtx.mOpType = UBSHcomNetRequestContext::NN_RECEIVED_RAW; + netCtx.mOriginalReq = {}; + netCtx.mHeader.Invalid(); + netCtx.mHeader.dataLength = msg.mDataLen; + netCtx.mHeader.seqNo = immData; + + // call to callback + if (NN_UNLIKELY((result = mReceivedRequestHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call receivedRequestHandler in Driver " << mName << + " return non-zero for receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << + netCtx.mHeader.dataLength << "]"); + } + + netCtx.mEp.Set(nullptr); + + return NN_OK; +} + +NResult NetDriverRDMAWithOob::NewReceivedRequest(RDMAOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, + UBSHcomNetMessage &msg, RDMAWorker *worker) const +{ + bool messageReady = true; + auto *tmpHeader = reinterpret_cast(ctx->mrMemAddr); + auto qpUpContext = ctx->qp->UpContext(); + auto tmpDataAddress = reinterpret_cast(ctx->mrMemAddr + sizeof(UBSHcomNetTransHeader)); + + UBSHcomNetEndpointPtr ep = reinterpret_cast(qpUpContext); + auto asyncEp = ep.ToChild(); + if (NN_UNLIKELY(asyncEp == nullptr)) { + NN_LOG_ERROR("ToChild failed"); + return NN_ERROR; + } + + auto rst = NetFunc::ValidateHeaderWithDataSize(*tmpHeader, ctx->dataSize); + if (NN_UNLIKELY(rst != NN_OK)) { + worker->RePostReceive(ctx); + return rst; + } + + // 非加密场景可以免拷贝 + if (!asyncEp->mIsNeedEncrypt) { + return NewReceivedRequestWithoutCopy(ctx, netCtx, msg, worker, tmpDataAddress, tmpHeader); + } + + uint32_t decryptRawLen = asyncEp->mAes.GetRawLen(tmpHeader->dataLength); + messageReady = msg.AllocateIfNeed(decryptRawLen); + if (NN_LIKELY(messageReady)) { + uint32_t decryptLen = 0; + if (!asyncEp->mAes.Decrypt(asyncEp->mSecrets, tmpDataAddress, tmpHeader->dataLength, msg.mBuf, + decryptLen)) { + NN_LOG_ERROR("Verbs Failed to decrypt data"); + (void)worker->RePostReceive(ctx); + return NN_DECRYPT_FAILED; + } + if (NN_UNLIKELY(memcpy_s(&(netCtx.mHeader), sizeof(UBSHcomNetTransHeader), tmpHeader, + sizeof(UBSHcomNetTransHeader)) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + msg.mDataLen = decryptRawLen; + } + + int result = 0; + if (NN_UNLIKELY((result = worker->RePostReceive(ctx)) != 0)) { + NN_LOG_ERROR("Verbs Failed to repost receive in Driver " << mName << ", result " << result); + } + + if (NN_UNLIKELY(!messageReady)) { + NN_LOG_ERROR("Verbs Failed to build UBSHcomNetRequestContext or message in Driver " << mName << + ", receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << msg.mDataLen << + "] will be dropped"); + return NN_OK; + } + + netCtx.mEp.Set(reinterpret_cast(qpUpContext)); + netCtx.mOpType = UBSHcomNetRequestContext::NN_RECEIVED; + netCtx.mMessage = &msg; + netCtx.mOriginalReq = {}; + netCtx.mHeader.dataLength = msg.mDataLen; + netCtx.extHeaderType = tmpHeader->extHeaderType; // 指导服务层处理 + + // call to callback + if (NN_UNLIKELY((result = mReceivedRequestHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Verbs Call receivedRequestHandler in Driver " << mName << + " return non-zero for receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << + netCtx.mHeader.dataLength << "]"); + } + + netCtx.mEp.Set(nullptr); + return NN_OK; +} + +NResult NetDriverRDMAWithOob::NewReceivedRequestWithoutCopy(RDMAOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, + UBSHcomNetMessage &msg, RDMAWorker *worker, void *dataAddress, UBSHcomNetTransHeader *header) const +{ + if (NN_UNLIKELY(memcpy_s(&(netCtx.mHeader), sizeof(UBSHcomNetTransHeader), header, sizeof(UBSHcomNetTransHeader)) != + NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + msg.SetBuf(dataAddress, header->dataLength); + msg.mDataLen = header->dataLength; + + netCtx.mEp.Set(reinterpret_cast(ctx->qp->UpContext())); + netCtx.mOpType = UBSHcomNetRequestContext::NN_RECEIVED; + netCtx.mMessage = &msg; + netCtx.mOriginalReq = {}; + netCtx.mHeader.dataLength = msg.mDataLen; + netCtx.extHeaderType = header->extHeaderType; // 指导服务层处理 + int result = 0; + // call to callback + if (NN_UNLIKELY((result = mReceivedRequestHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Verbs Call receivedRequestHandler in Driver " << mName << + " return non-zero for receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << + netCtx.mHeader.dataLength << "]"); + } + msg.SetBuf(nullptr, 0); + netCtx.mMessage = nullptr; + netCtx.mEp.Set(nullptr); + + if (NN_UNLIKELY((result = worker->RePostReceive(ctx)) != 0)) { + NN_LOG_ERROR("Verbs Failed to repost receive in Driver " << mName << ", result " << result); + } + + return NN_OK; +} + +int NetDriverRDMAWithOob::SendFinished(RDMAOpContextInfo *ctx) +{ + if (NN_UNLIKELY(!ValidateRequestContext(ctx))) { + return NN_ERROR; + } + + if (NN_UNLIKELY(ctx->opResultType != RDMAOpContextInfo::SUCCESS)) { + ProcessQPError(ctx); + return NN_OK; + } + + return SendFinishedCB(ctx); +} + +int NetDriverRDMAWithOob::OneSideDone(RDMAOpContextInfo *ctx) +{ + if (NN_UNLIKELY(!ValidateRequestContext(ctx))) { + return NN_ERROR; + } + + if (NN_UNLIKELY(ctx->opResultType != RDMAOpContextInfo::SUCCESS)) { + ProcessQPError(ctx); + return NN_OK; + } + + return OneSideDoneCB(ctx); +} +} +} +#endif diff --git a/src/transport/rdma/verbs/net_rdma_driver_oob.h b/src/transport/rdma/verbs/net_rdma_driver_oob.h new file mode 100644 index 0000000000000000000000000000000000000000..f3682559511420bf22a1107a4f53d9f246a1eebb --- /dev/null +++ b/src/transport/rdma/verbs/net_rdma_driver_oob.h @@ -0,0 +1,161 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_NET_CLIENT_SERVER_RDMA_1234244441233_H +#define OCK_NET_CLIENT_SERVER_RDMA_1234244441233_H +#ifdef RDMA_BUILD_ENABLED +#include + +#include "hcom.h" + +#include "net_oob.h" +#include "net_rdma_driver.h" +#include "net_util.h" +#include "net_heartbeat.h" +#include "rdma_common.h" +#include "rdma_mr_dm_buf.h" +#include "rdma_mr_fixed_buf.h" + +namespace ock { +namespace hcom { +/* **************************************************************************************** */ +class NetDriverRDMAWithOob : public NetDriverRDMA { +public: + NetDriverRDMAWithOob(const std::string &name, bool startOobSvr, UBSHcomNetDriverProtocol protocol) + : NetDriverRDMA(name, startOobSvr, protocol) + { + OBJ_GC_INCREASE(NetDriverRDMAWithOob); + } + + ~NetDriverRDMAWithOob() override + { + OBJ_GC_DECREASE(NetDriverRDMAWithOob); + } + + NResult Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo = 0, + uint8_t clientGrpNo = 0) override; + NResult Connect(const std::string &oobIp, uint16_t oobPort, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint32_t flags, uint8_t serverGrpNo = 0, uint8_t clientGrpNo = 0, + uint64_t ctx = 0) override; + NResult Connect(const std::string &serverUrl, const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, + uint8_t serverGrpNo = 0, uint8_t clientGrpNo = 0, uint64_t ctx = 0) override; + + NResult MultiRailNewConnection(OOBTCPConnection &conn); + uint16_t GetHbIdleTime() + { + if (mHeartBeat == nullptr) { + NN_LOG_ERROR("mHeartBeat is nullpttr"); + return 0; + } + return mHeartBeat->GetHbIdleTime(); + } + +protected: + int NewConnectionCB(OOBTCPConnection &conn); + int SendFinished(RDMAOpContextInfo *ctx); + int NewRequest(RDMAOpContextInfo *ctx); + int OneSideDone(RDMAOpContextInfo *ctx); + + NResult DoInitialize() override; + void DoUnInitialize() override; + + NResult DoStart() override; + void DoStop() override; + + int OneSideDoneCB(RDMAOpContextInfo *ctx); + int SendFinishedCB(RDMAOpContextInfo *ctx); + + void ProcessEpError(uintptr_t ep); + void ProcessQPError(RDMAOpContextInfo *ctx); + void ProcessErrorNewRequest(RDMAOpContextInfo *ctx); + void ProcessErrorOneSideDone(RDMAOpContextInfo *ctx); + void ProcessErrorSendFinished(RDMAOpContextInfo *ctx); + +private: + friend class NetAsyncEndpoint; + friend class NetSyncEndpoint; + + NResult NewReceivedRequest(RDMAOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, UBSHcomNetMessage &msg, + RDMAWorker *worker) const; + + NResult NewReceivedRawRequest(RDMAOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, UBSHcomNetMessage &msg, + RDMAWorker *worker, uint32_t immData) const; + + NResult NewReceivedRequestWithoutCopy(RDMAOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, + UBSHcomNetMessage &msg, RDMAWorker *worker, void *dataAddress, UBSHcomNetTransHeader *header) const; + + NResult SendRequestFinishedCB(RDMAOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, RDMAWorker *worker); + NResult SendRawSglFinishedCB(RDMAOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, RDMAWorker *worker); + NResult SendSglInlineFinishedCB(RDMAOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, RDMAWorker *worker); + + NResult Connect(const OOBTCPClientPtr &client, const std::string &payload, UBSHcomNetEndpointPtr &outEp, + uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx); + NResult ConnectSyncEp(const OOBTCPClientPtr &client, const std::string &payload, UBSHcomNetEndpointPtr &outEp, + uint32_t flags, uint8_t serverGrpNo, uint64_t ctx); + + void DestroyEpInWorker(RDMAWorker *worker); + void DestroyEpByPortNum(int portNum); + void HandleCqEvent(struct ibv_async_event *event); + void HandlePortDown(int portNum); + void HandlePortActive(int portNum); + void HandleAsyncEvent(struct ibv_async_event *event); + void RunInRdmaEventThread(); + inline bool ValidateRequestContext(RDMAOpContextInfo *ctx) + { + if (NN_UNLIKELY(ctx == nullptr || ctx->qp == nullptr || ctx->qp->UpContext1() == 0 || + ctx->qp->UpContext() == 0)) { + NN_LOG_ERROR("Ctx or QP or Worker is null of RequestReceived in Driver " << mName << ""); + return false; + } + return true; + } + __always_inline void ProcessErrorContext(RDMAOpContextInfo *&nextOpCtx, RDMAOpContextInfo *&remainingOpCtx, + UBSHcomNetEndpoint *epPtr) + { + nextOpCtx = remainingOpCtx->next; + if (remainingOpCtx->opResultType != RDMAOpContextInfo::INVALID_MAGIC) { + remainingOpCtx->opResultType = epPtr->State().Compare(NEP_BROKEN) ? RDMAOpContextInfo::ERR_EP_BROKEN : + RDMAOpContextInfo::ERR_EP_CLOSE; + switch (remainingOpCtx->opType) { + case (RDMAOpContextInfo::OpType::SEND): + case (RDMAOpContextInfo::OpType::SEND_RAW): + case (RDMAOpContextInfo::OpType::SEND_RAW_SGL): + ProcessErrorSendFinished(remainingOpCtx); + break; + case (RDMAOpContextInfo::OpType::RECEIVE): + ProcessErrorNewRequest(remainingOpCtx); + break; + case (RDMAOpContextInfo::OpType::WRITE): + case (RDMAOpContextInfo::OpType::SGL_WRITE): + case (RDMAOpContextInfo::OpType::HB_WRITE): + case (RDMAOpContextInfo::OpType::READ): + case (RDMAOpContextInfo::OpType::SGL_READ): + ProcessErrorOneSideDone(remainingOpCtx); + break; + default: + NN_LOG_ERROR("Poll cq invalid OpType " << remainingOpCtx->opType); + break; + } + } + remainingOpCtx->qpNum = 0xffffffff; + remainingOpCtx->opResultType = RDMAOpContextInfo::INVALID_MAGIC; + remainingOpCtx = nextOpCtx; + } + bool mNeedStopEvent = false; + std::thread mRdmaEventThread; + std::atomic mEventStarted { false }; + NetHeartbeat *mHeartBeat = nullptr; + friend class NetHeartbeat; +}; +} +} +#endif +#endif // _OCK_NET_CLIENT_SERVER_RDMA_1234244441233_H \ No newline at end of file diff --git a/src/transport/rdma/verbs/net_rdma_sync_endpoint.cpp b/src/transport/rdma/verbs/net_rdma_sync_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8d6ccf179d416c60bc4e1fee00fc716d67b3e390 --- /dev/null +++ b/src/transport/rdma/verbs/net_rdma_sync_endpoint.cpp @@ -0,0 +1,795 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#include "net_common.h" +#include "net_rdma_driver_oob.h" +#include "net_security_rand.h" +#include "rdma_validation.h" +#include "net_rdma_sync_endpoint.h" + +namespace ock { +namespace hcom { +NetSyncEndpoint::NetSyncEndpoint(uint64_t id, RDMASyncEndpoint *ep, NetDriverRDMAWithOob *driver, + const UBSHcomNetWorkerIndex &workerIndex) + : NetEndpointImpl(id, workerIndex), mEp(ep), mDriver(driver) +{ + if (mEp != nullptr) { + mEp->IncreaseRef(); + } + + if (mDriver != nullptr) { + mDriver->IncreaseRef(); + } + + if (mEp != nullptr && mDriver != nullptr) { + mSegSize = mDriver->mOptions.mrSendReceiveSegSize < mEp->Qp()->PostSendMaxSize() ? + mDriver->mOptions.mrSendReceiveSegSize : + mEp->Qp()->PostSendMaxSize(); + mAllowedSize = mSegSize - sizeof(UBSHcomNetTransHeader); + } + + /* set worker index and group index to 0xFFFF */ + mWorkerIndex.idxInGrp = INVALID_WORKER_INDEX; + mWorkerIndex.grpIdx = INVALID_WORKER_GROUP_INDEX; + + OBJ_GC_INCREASE(NetSyncEndpoint); +} + +NetSyncEndpoint::~NetSyncEndpoint() +{ + if (mEp != nullptr) { + mEp->DecreaseRef(); + mEp = nullptr; + } + + if (mDriver != nullptr) { + mDriver->DecreaseRef(); + mDriver = nullptr; + } + + OBJ_GC_DECREASE(NetSyncEndpoint); +} + +NResult NetSyncEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNO) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendValidation(mState, mId, mDriver, opCode, request, mAllowedSize, + mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to sync post send as validate fail"); + return result; + } + + // get mr from pool + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Verbs Failed to sync post send with seq no as failed to get mr buffer from pool"); + return NN_GET_BUFF_FAILED; + } + + // copy message + auto *verbsHeader = reinterpret_cast(mrBufAddress); + bzero(verbsHeader, sizeof(UBSHcomNetTransHeader)); + verbsHeader->seqNo = seqNO == 0 ? NextSeq() : seqNO; + verbsHeader->opCode = opCode; + verbsHeader->flags = NTH_TWO_SIDE; + verbsHeader->dataLength = request.size; + + mLastSendSeqNo = verbsHeader->seqNo; + if (mIsNeedEncrypt) { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, reinterpret_cast(request.lAddress), request.size, + reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), cipherLen)) { + NN_LOG_ERROR("RDMA Failed to sync post send with seq no as encryption failure"); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + return NN_ENCRYPT_FAILED; + } + verbsHeader->dataLength = cipherLen; + } else { + // copy message + verbsHeader->dataLength = request.size; + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), + mDriver->mDriverSendMR->GetSingleSegSize() - sizeof(UBSHcomNetTransHeader), + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("RDMA Failed to copy request to mrBufAddress"); + return NN_INVALID_PARAM; + } + } + + /* finally fill header crc */ + verbsHeader->headerCrc = NetFunc::CalcHeaderCrc32(verbsHeader); + mDemandPollingOpType = RDMAOpContextInfo::SEND; + + // post request + // change lAddress to mrAddress and set lKey + UBSHcomNetTransRequest rdmaReq = request; + rdmaReq.lAddress = mrBufAddress; + rdmaReq.lKey = mDriver->mDriverSendMR->GetLKey(); + rdmaReq.size = sizeof(UBSHcomNetTransHeader) + verbsHeader->dataLength; + + auto syncSendFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_SYNC_POST_SEND); + do { + result = mEp->PostSend(rdmaReq); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_SYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + syncSendFlag = false; + } while (syncSendFlag); + + NN_LOG_ERROR("Failed to sync post send with seqNo, result " << result); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(RDMA_EP_SYNC_POST_SEND, result); + return result; +} + +NResult NetSyncEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendValidation(mState, mId, mDriver, opCode, request, mAllowedSize, + mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to sync post send as validate fail"); + return result; + } + + // get mr from pool + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("RDMA Failed to sync post send with op info as failed to get mr buffer from pool"); + return NN_GET_BUFF_FAILED; + } + + // copy message + auto *verbsHeader = reinterpret_cast(mrBufAddress); + bzero(verbsHeader, sizeof(UBSHcomNetTransHeader)); + verbsHeader->opCode = opCode; + verbsHeader->seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + verbsHeader->flags = ((uint16_t)opInfo.flags << NN_NO8) | (uint16_t)NTH_TWO_SIDE; + verbsHeader->timeout = opInfo.timeout; + verbsHeader->errorCode = opInfo.errorCode; + verbsHeader->dataLength = request.size; + + mLastSendSeqNo = verbsHeader->seqNo; + if (mIsNeedEncrypt) { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, reinterpret_cast(request.lAddress), request.size, + reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), cipherLen)) { + NN_LOG_ERROR("RDMA Failed to sync post send with op info as encryption failure"); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + return NN_ENCRYPT_FAILED; + } + verbsHeader->dataLength = cipherLen; + } else { + // copy message + verbsHeader->dataLength = request.size; + + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), + mDriver->mDriverSendMR->GetSingleSegSize() - sizeof(UBSHcomNetTransHeader), + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to mrBufAddress"); + return NN_INVALID_PARAM; + } + } + + /* finally fill verbsHeader crc */ + verbsHeader->headerCrc = NetFunc::CalcHeaderCrc32(verbsHeader); + mDemandPollingOpType = RDMAOpContextInfo::SEND; + + // post request + // change lAddress to mrAddress and set lKey + UBSHcomNetTransRequest rdmaReq = request; + rdmaReq.lAddress = mrBufAddress; + rdmaReq.lKey = mDriver->mDriverSendMR->GetLKey(); + rdmaReq.size = sizeof(UBSHcomNetTransHeader) + verbsHeader->dataLength; + + auto syncSendOpFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_SYNC_POST_SEND); + do { + result = mEp->PostSend(rdmaReq); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_SYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + syncSendOpFlag = false; + } while (syncSendOpFlag); + + NN_LOG_ERROR("Failed to sync post send with op info, result " << result); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(RDMA_EP_SYNC_POST_SEND, result); + return result; +} + +NResult NetSyncEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo, const UBSHcomExtHeaderType extHeaderType, const void *extHeader, + uint32_t extHeaderSize) +{ + if (NN_UNLIKELY(extHeaderType == UBSHcomExtHeaderType::RAW)) { + NN_LOG_ERROR("You shouldn't use RAW type when extHeader is given in sync ep"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(!extHeader)) { + NN_LOG_ERROR("The ExtHeader is invalid."); + return NN_INVALID_PARAM; + } + + // 保证 extHeaderSize + request.size <= mAllowedSize. + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendValidation(mState, mId, mDriver, opCode, request, mAllowedSize - extHeaderSize, + mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to sync post send as validate fail"); + return result; + } + + // get mr from pool + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to sync post send with op info as failed to get mr buffer from pool"); + return NN_GET_BUFF_FAILED; + } + + auto *header = reinterpret_cast(mrBufAddress); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->opCode = opCode; + header->seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + header->flags = ((uint16_t)opInfo.flags << NN_NO8) | (uint64_t)NTH_TWO_SIDE; + header->timeout = opInfo.timeout; + header->errorCode = opInfo.errorCode; + header->dataLength = request.size + extHeaderSize; + header->extHeaderType = extHeaderType; + + mLastSendSeqNo = header->seqNo; + if (mIsNeedEncrypt) { + NN_LOG_WARN("postsent encrypt is not supported now!"); + } + + // 拷贝上层指定的 header,此时将要发送的结构为 + // | UBSHcomNetTransHeader | extHeader | request body | + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), + mDriver->mDriverSendMR->GetSingleSegSize() - sizeof(UBSHcomNetTransHeader), extHeader, + extHeaderSize) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to mrBufAddress"); + return NN_INVALID_PARAM; + } + + // 拷贝消息主体 + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader) + extHeaderSize), + mDriver->mDriverSendMR->GetSingleSegSize() - sizeof(UBSHcomNetTransHeader) - extHeaderSize, + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to mrBufAddress"); + return NN_INVALID_PARAM; + } + + /* finally fill header crc */ + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + mDemandPollingOpType = RDMAOpContextInfo::SEND; + + // change lAddress to mrAddress and set lKey + UBSHcomNetTransRequest rdmaReq = request; + rdmaReq.lAddress = mrBufAddress; + rdmaReq.lKey = mDriver->mDriverSendMR->GetLKey(); + rdmaReq.size = sizeof(UBSHcomNetTransHeader) + header->dataLength; + + auto syncSendOpFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_SYNC_POST_SEND); + do { + result = mEp->PostSend(rdmaReq); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_SYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + syncSendOpFlag = false; + } while (syncSendOpFlag); + + NN_LOG_ERROR("Failed to sync post send with op info, result " << result); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(RDMA_EP_SYNC_POST_SEND, result); + return result; +} + +NResult NetSyncEndpoint::PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNo) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendRawValidation(mState, mId, mDriver, seqNo, request, mSegSize, + mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to sync post send raw as validate fail"); + return result; + } + + /* get mr from pool */ + uintptr_t mrBufAddress = 0; + size_t msgSize = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to post message as failed to get mr buffer from pool from driver " << mDriver->Name()); + return NN_GET_BUFF_FAILED; + } + + if (!mIsNeedEncrypt) { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress), mDriver->mDriverSendMR->GetSingleSegSize(), + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to memcpy request to mrBufAddress"); + return NN_INVALID_PARAM; + } + msgSize = request.size; + } else { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, reinterpret_cast(request.lAddress), request.size, + reinterpret_cast(mrBufAddress), cipherLen)) { + NN_LOG_ERROR("Failed send raw message as encryption failure"); + return NN_ENCRYPT_FAILED; + } + msgSize = cipherLen; + } + + UBSHcomNetTransRequest rdmaReq = request; + rdmaReq.lAddress = mrBufAddress; + rdmaReq.lKey = mDriver->mDriverSendMR->GetLKey(); + rdmaReq.size = msgSize; + + /* still use send */ + mDemandPollingOpType = RDMAOpContextInfo::SEND; + + mLastSendSeqNo = seqNo; + + auto flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_SYNC_POST_SEND_RAW); + do { + result = mEp->PostSend(rdmaReq, seqNo); + if (NN_LIKELY(result == RR_OK)) { + TRACE_DELAY_END(RDMA_EP_SYNC_POST_SEND_RAW, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post send request, result " << result); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(RDMA_EP_SYNC_POST_SEND_RAW, result); + return result; +} + +NResult NetSyncEndpoint::PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) +{ + size_t size = 0; + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendSglValidation(mState, mId, mDriver, seqNo, request, mSegSize, size, + mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to sync post send raw sgl as validate fail"); + return result; + } + + UBSHcomNetTransRequest tlsReq {}; + uintptr_t mrBufAddress = 0; + if (mIsNeedEncrypt) { + if (NN_UNLIKELY(EncryptRawSgl(tlsReq, mrBufAddress, size, mAes, mDriver, request, mSecrets) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to sync post send raw sgl as encrypt fail"); + return NN_ENCRYPT_FAILED; + } + } + + mDemandPollingOpType = RDMAOpContextInfo::SEND_RAW_SGL; + mLastSendSeqNo = seqNo; + auto flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_SYNC_POST_SEND_RAW_SGL); + do { + result = mEp->PostSendSgl(request, tlsReq, seqNo, mIsNeedEncrypt); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_SYNC_POST_SEND_RAW_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + if (mIsNeedEncrypt) { + (void)mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + } + NN_LOG_ERROR("Failed to post send raw sgl request, result " << result); + TRACE_DELAY_END(RDMA_EP_SYNC_POST_SEND_RAW_SGL, result); + return result; +} + +NResult NetSyncEndpoint::PostRead(const UBSHcomNetTransRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = ReadWriteValidation(mState, mId, mDriver, request)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to sync post read as validate fail"); + return result; + } + + mDemandPollingOpType = RDMAOpContextInfo::READ; + auto readFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_SYNC_POST_READ); + do { + result = mEp->PostRead(request); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_SYNC_POST_READ, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + readFlag = false; + } while (readFlag); + + NN_LOG_ERROR("Failed to post read request, result " << result); + TRACE_DELAY_END(RDMA_EP_SYNC_POST_READ, result); + return result; +} + +NResult NetSyncEndpoint::PostRead(const UBSHcomNetTransSglRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = ReadWriteSglValidation(mState, mId, mDriver, request)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to sync post read sgl as validate fail"); + return result; + } + + mDemandPollingOpType = RDMAOpContextInfo::SGL_READ; + auto readSglFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_SYNC_POST_READ_SGL); + do { + result = mEp->PostOneSideSgl(request, true); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_SYNC_POST_READ_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + readSglFlag = false; + } while (readSglFlag); + + NN_LOG_ERROR("Failed to post read sgl request, result " << result); + TRACE_DELAY_END(RDMA_EP_SYNC_POST_READ_SGL, result); + return result; +} + +NResult NetSyncEndpoint::PostWrite(const UBSHcomNetTransRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = ReadWriteValidation(mState, mId, mDriver, request)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to sync post write as validate fail"); + return result; + } + + mDemandPollingOpType = RDMAOpContextInfo::WRITE; + auto writeFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_SYNC_POST_WRITE); + do { + result = mEp->PostWrite(request); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_SYNC_POST_WRITE, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + writeFlag = false; + } while (writeFlag); + + NN_LOG_ERROR("Failed to post write request, result " << result); + TRACE_DELAY_END(RDMA_EP_SYNC_POST_WRITE, result); + return result; +} + +NResult NetSyncEndpoint::PostWrite(const UBSHcomNetTransSglRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = ReadWriteSglValidation(mState, mId, mDriver, request)) != NN_OK)) { + NN_LOG_ERROR("RDMA failed to sync post write sgl as validate fail"); + return result; + } + + mDemandPollingOpType = RDMAOpContextInfo::SGL_WRITE; + auto writeSglFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(RDMA_EP_SYNC_POST_WRITE_SGL); + do { + result = mEp->PostOneSideSgl(request, false); + if (result == RR_OK) { + TRACE_DELAY_END(RDMA_EP_SYNC_POST_WRITE_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + writeSglFlag = false; + } while (writeSglFlag); + + NN_LOG_ERROR("Failed to post write sgl request, result " << result); + TRACE_DELAY_END(RDMA_EP_SYNC_POST_WRITE_SGL, result); + return result; +} + +NResult NetSyncEndpoint::WaitCompletion(int32_t timeout) +{ + NN_LOG_TRACE_INFO("wait completion mDemandPollingOpType " << mDemandPollingOpType); + RDMAOpContextInfo *opCtx = nullptr; + NResult result = NN_OK; + uint32_t immData = 0; + +POLL_CQ: + if (NN_UNLIKELY(result = mEp->PollingCompletion(opCtx, timeout, immData)) != NN_OK) { + // do later + return result; + } + + if (NN_UNLIKELY(opCtx->opType != mDemandPollingOpType)) { + // repost if receive opType + if (opCtx->opType == RDMAOpContextInfo::RECEIVE) { + if (mDelayHandleReceiveCtx == nullptr) { + mDelayHandleReceiveCtx = opCtx; + goto POLL_CQ; + } else { + NN_LOG_ERROR("Receive operation type has double received, prev context is not process"); + } + } + NN_LOG_WARN("Got un-demand operation type " << opCtx->opType << ", ignored by ep id " << mId); + } + + opCtx->qp->DecreaseRef(); + if (opCtx->opType == RDMAOpContextInfo::SEND) { + (void)mDriver->mDriverSendMR->ReturnBuffer(opCtx->mrMemAddr); + } + + if (mIsNeedEncrypt && opCtx->opType == RDMAOpContextInfo::SEND_RAW_SGL) { + // buffer should return when encrypt + (void)mDriver->mDriverSendMR->ReturnBuffer(opCtx->mrMemAddr); + } + + if (opCtx->opType == RDMAOpContextInfo::SGL_WRITE || opCtx->opType == RDMAOpContextInfo::SGL_READ) { + auto sgeCtx = reinterpret_cast(opCtx->upCtx); + auto sglCtx = sgeCtx->ctx; + result = RDMAOpContextInfo::GetNResult(opCtx->opResultType); + sglCtx->result = sglCtx->result < result ? sglCtx->result : result; + auto refCount = __sync_add_and_fetch(&(sglCtx->refCount), 1); + if (sglCtx->iovCount == refCount) { + return sglCtx->result; + } + goto POLL_CQ; + } + + return NN_OK; +} + +NResult NetSyncEndpoint::Receive(int32_t timeout, UBSHcomNetResponseContext &ctx) +{ + NResult result = NN_OK; + RDMAOpContextInfo *opCtx = nullptr; + uint32_t immData = 0; + + mDemandPollingOpType = RDMAOpContextInfo::RECEIVE; + NN_LOG_TRACE_INFO("Verbs receive mDemandPollingOpType " << mDemandPollingOpType); + if (NN_UNLIKELY(mDelayHandleReceiveCtx != nullptr)) { + opCtx = mDelayHandleReceiveCtx; + mDelayHandleReceiveCtx = nullptr; + } else if (NN_UNLIKELY(result = mEp->PollingCompletion(opCtx, timeout, immData)) != NN_OK) { + // do later + return result; + } + size_t realDataSize = 0; + do { + if (NN_UNLIKELY(opCtx->opType != mDemandPollingOpType)) { + NN_LOG_ERROR("Verbs Got un-demand operation type " << opCtx->opType << ", ignored"); + result = NN_ERROR; + break; + } + + auto *tmpHeader = reinterpret_cast(opCtx->mrMemAddr); + // 可能会收到多个小包,小包的 SeqNo 在对端回复时由定时器机制生成,与 + // SyncEp 本地记录的 SeqNo 不一致,所以不再检验 SeqNo. + result = NetFunc::ValidateHeaderWithDataSize(*tmpHeader, opCtx->dataSize); + if (NN_UNLIKELY(result != NN_OK)) { + break; + } + + realDataSize = tmpHeader->dataLength; + if (mIsNeedEncrypt) { + realDataSize = mAes.GetRawLen(tmpHeader->dataLength); + } + auto msgReady = mRespMessage.AllocateIfNeed(realDataSize); + if (NN_UNLIKELY(!msgReady)) { + NN_LOG_ERROR("Verbs Failed to allocate memory for response size " << realDataSize << + ", probably out of memory"); + result = NN_MALLOC_FAILED; + break; + } + + if (NN_UNLIKELY(memcpy_s(&(mRespCtx.mHeader), sizeof(UBSHcomNetTransHeader), tmpHeader, + sizeof(UBSHcomNetTransHeader)) != NN_OK)) { + NN_LOG_WARN("Invalid operation to memcpy_s in Receive"); + return NN_ERROR; + } + auto tmpDataAddress = reinterpret_cast(opCtx->mrMemAddr + sizeof(UBSHcomNetTransHeader)); + + if (mIsNeedEncrypt) { + uint32_t decryptLen = 0; + if (!mAes.Decrypt(mSecrets, tmpDataAddress, tmpHeader->dataLength, mRespMessage.mBuf, decryptLen)) { + NN_LOG_ERROR("Verbs Failed to decrypt data"); + result = NN_DECRYPT_FAILED; + break; + } + mRespMessage.mDataLen = decryptLen; + } else { + if (NN_UNLIKELY(memcpy_s(mRespMessage.mBuf, mRespMessage.GetBufLen(), tmpDataAddress, + tmpHeader->dataLength) != NN_OK)) { + NN_LOG_ERROR("Failed to copy tmpDataAddress to mRespMessage"); + return NN_INVALID_PARAM; + } + mRespMessage.mDataLen = tmpHeader->dataLength; + } + } while (false); + + auto receiveFlag = true; + uint64_t finishTime = GetFinishTime(); + RResult rePostResult = RR_OK; + uintptr_t mrMemAddr = opCtx->mrMemAddr; + do { + rePostResult = mEp->RePostReceive(opCtx); + if (rePostResult == RR_OK) { + break; + } + if (NeedRetry(rePostResult) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry rePostResult or timeout = 0 + receiveFlag = false; + } while (receiveFlag); + + if (NN_UNLIKELY(rePostResult != RR_OK)) { + NN_LOG_ERROR("Failed to repost receive, result " << rePostResult); + mEp->ReturnBuffer(mrMemAddr); + return rePostResult; + } + + if (NN_LIKELY(result == NN_OK)) { + mRespMessage.mDataLen = realDataSize; + mRespCtx.mHeader.dataLength = realDataSize; + mRespCtx.mMessage = &mRespMessage; + ctx.mHeader = mRespCtx.mHeader; + ctx.mMessage = mRespCtx.mMessage; + } + + return result; +} + +NResult NetSyncEndpoint::ReceiveRaw(int32_t timeout, UBSHcomNetResponseContext &ctx) +{ + RDMAOpContextInfo *opCtx = nullptr; + NResult verbsResult = NN_OK; + uint32_t immData = 0; + + mDemandPollingOpType = RDMAOpContextInfo::RECEIVE; + + NN_LOG_TRACE_INFO("receive mDemandPollingOpType " << mDemandPollingOpType); + if (NN_UNLIKELY(mDelayHandleReceiveCtx != nullptr)) { + opCtx = mDelayHandleReceiveCtx; + mDelayHandleReceiveCtx = nullptr; + } else if (NN_UNLIKELY(verbsResult = mEp->PollingCompletion(opCtx, timeout, immData)) != NN_OK) { + // do later + return verbsResult; + } + + do { + if (NN_UNLIKELY(opCtx->opType != mDemandPollingOpType)) { + NN_LOG_ERROR("Got un-demand operation type " << opCtx->opType << " in ReceiveRaw, ignored"); + verbsResult = NN_ERROR; + break; + } + + if (NN_UNLIKELY(immData != mLastSendSeqNo)) { + NN_LOG_ERROR("Received un-matched seq no " << immData << ", demand seq no " << mLastSendSeqNo); + verbsResult = NN_SEQ_NO_NOT_MATCHED; + break; + } + + auto dataSize = opCtx->dataSize; + auto msgReady = mRespMessage.AllocateIfNeed(dataSize); + if (NN_UNLIKELY(!msgReady)) { + NN_LOG_ERROR("Failed to allocate memory for response size " << opCtx->dataSize << + ", probably out of memory"); + verbsResult = NN_MALLOC_FAILED; + break; + } + + auto tmpDataAddress = reinterpret_cast(opCtx->mrMemAddr); + if (mIsNeedEncrypt) { + uint32_t decryptLen = 0; + if (!mAes.Decrypt(mSecrets, tmpDataAddress, dataSize, mRespMessage.mBuf, decryptLen)) { + NN_LOG_ERROR("Failed to decrypt data"); + verbsResult = NN_DECRYPT_FAILED; + break; + } + mRespMessage.mDataLen = decryptLen; + } else { + if (NN_UNLIKELY(memcpy_s(mRespMessage.mBuf, mRespMessage.GetBufLen(), tmpDataAddress, dataSize) != NN_OK)) { + NN_LOG_ERROR("Failed to tmpDataAddress req to mRespMessage"); + return NN_INVALID_PARAM; + } + mRespMessage.mDataLen = dataSize; + } + } while (false); + + RResult rePostResult = RR_OK; + auto receiveRawFlag = true; + uint64_t finishTime = GetFinishTime(); + uintptr_t mrMemAddr = opCtx->mrMemAddr; + do { + rePostResult = mEp->RePostReceive(opCtx); + if (rePostResult == RR_OK) { + break; + } else if (NeedRetry(rePostResult) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry verbsResult or timeout = 0 + receiveRawFlag = false; + } while (receiveRawFlag); + + if (NN_UNLIKELY(rePostResult != RR_OK)) { + NN_LOG_ERROR("Failed to repost receive raw, result " << rePostResult); + mEp->ReturnBuffer(mrMemAddr); + return rePostResult; + } + + if (NN_LIKELY(verbsResult == NN_OK)) { + mRespCtx.mMessage = &mRespMessage; + ctx.mHeader = {}; + ctx.mHeader.opCode = -1; + ctx.mHeader.seqNo = immData; + ctx.mMessage = mRespCtx.mMessage; + } + + return verbsResult; +} +} +} +#endif diff --git a/src/transport/rdma/verbs/net_rdma_sync_endpoint.h b/src/transport/rdma/verbs/net_rdma_sync_endpoint.h new file mode 100644 index 0000000000000000000000000000000000000000..6c5daffbe1e34a176b43a388861b4ab55e4e2a43 --- /dev/null +++ b/src/transport/rdma/verbs/net_rdma_sync_endpoint.h @@ -0,0 +1,186 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_SYNC_ENDPOINT_RDMA_H +#define OCK_HCOM_NET_SYNC_ENDPOINT_RDMA_H +#ifdef RDMA_BUILD_ENABLED + +#include "hcom.h" +#include "transport/net_endpoint_impl.h" +#include "rdma_composed_endpoint.h" +#include "net_monotonic.h" +#include "net_rdma_driver_oob.h" +#include "net_security_alg.h" +#include "hcom_utils.h" + +namespace ock { +namespace hcom { +class NetSyncEndpoint : public NetEndpointImpl { +public: + NetSyncEndpoint(uint64_t id, RDMASyncEndpoint *ep, NetDriverRDMAWithOob *driver, + const UBSHcomNetWorkerIndex &workerIndex); + ~NetSyncEndpoint() override; + + NResult SetEpOption(UBSHcomEpOptions &epOptions) override + { + NN_LOG_WARN("[RDMA SyncEp] Empty function for now"); + return NN_OK; + } + + uint32_t GetSendQueueCount() override + { + NN_LOG_WARN("[RDMA SyncEp] Empty function for now"); + return 0; + } + + inline void PollingMode(RDMAPollingMode m) + { + mPollingMode = m; + } + + const std::string &PeerIpAndPort() override + { + if (mEp != nullptr) { + return mEp->PeerIpAndPort(); + } + + return CONST_EMPTY_STRING; + } + + const std::string &UdsName() override + { + NN_LOG_WARN("[RDMA SyncEp] Empty function for now"); + return CONST_EMPTY_STRING; + } + + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNO) override; + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) override; + + NResult PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNO) override; + NResult PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo = 0) override; + + NResult WaitCompletion(int32_t timeout) override; + NResult PostRead(const UBSHcomNetTransRequest &request) override; + NResult PostRead(const UBSHcomNetTransSglRequest &request) override; + NResult PostWrite(const UBSHcomNetTransRequest &request) override; + NResult PostWrite(const UBSHcomNetTransSglRequest &request) override; + + NResult Receive(int32_t timeout, UBSHcomNetResponseContext &ctx) override; + NResult ReceiveRaw(int32_t timeout, UBSHcomNetResponseContext &ctx) override; + + inline RDMASyncEndpoint *GetRdmaEp() + { + return mEp; + } + + NResult GetRemoteUdsIdInfo(UBSHcomNetUdsIdInfo &verbsIdInfo) override + { + if (!mState.Compare(NEP_ESTABLISHED)) { + NN_LOG_ERROR("[RDMA SyncEp] EP is not established"); + return NN_EP_NOT_ESTABLISHED; + } + + if (!mDriver->mStartOobSvr) { + NN_LOG_ERROR("[RDMA SyncEp] oob server is not start"); + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + + if (mDriver->mOptions.oobType != NET_OOB_UDS) { + NN_LOG_ERROR("[RDMA SyncEp] oob type is not uds"); + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + + verbsIdInfo = mRemoteUdsIdInfo; + return NN_OK; + } + + bool GetPeerIpPort(std::string &ip, uint16_t &port) override + { + if (NN_UNLIKELY(mEp == nullptr)) { + return false; + } + + auto ipPort = mEp->PeerIpAndPort(); + if (NN_UNLIKELY(ipPort.empty())) { + NN_LOG_ERROR("[RDMA] ip and port of peer is empty"); + return false; + } + + std::vector ipPortVec; + NetFunc::NN_SplitStr(ipPort, ":", ipPortVec); + if (NN_UNLIKELY(ipPortVec.size() != NN_NO2)) { + NN_LOG_ERROR("[RDMA] ip and port of peer is invalid"); + return false; + } + + try { + port = std::stoi(ipPortVec[1]); + } catch (...) { + NN_LOG_ERROR("[RDMA] port of peer is invalid"); + return false; + } + if (port == 0) { + NN_LOG_ERROR("[RDMA] oob type is uds, does not have peer ip and port msg"); + return false; + } + ip = ipPortVec[0]; + + return true; + } + + void Close() override + { + auto qp = GetRdmaEp()->Qp(); + qp->Stop(); + } + +protected: + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, const UBSHcomNetTransOpInfo &opInfo, + const UBSHcomExtHeaderType extHeaderType, const void *extHeader, uint32_t extHeaderSize) override; + +private: + static inline bool NeedRetry(RResult result) + { + if (result == RR_QP_POST_SEND_WR_FULL || result == RR_QP_ONE_SIDE_WR_FULL || result == RR_QP_CTX_FULL) { + return true; + } + + return false; + } + + inline uint64_t GetFinishTime() + { + if (mDefaultTimeout > 0) { + return NetMonotonic::TimeNs() + static_cast(mDefaultTimeout) * 1000000000UL; + } else if (mDefaultTimeout < 0) { + return UINT64_MAX; + } + + return 0; + } + + RDMASyncEndpoint *mEp = nullptr; + NetDriverRDMAWithOob *mDriver = nullptr; + RDMAPollingMode mPollingMode = RDMAPollingMode::BUSY_POLLING; + uint32_t mLastSendSeqNo = 0; + RDMAOpContextInfo::OpType mDemandPollingOpType = RDMAOpContextInfo::SEND; + UBSHcomNetResponseContext mRespCtx; + UBSHcomNetMessage mRespMessage; + RDMAOpContextInfo *mDelayHandleReceiveCtx = nullptr; + + friend class NetDriverRDMAWithOob; +}; +} +} + +#endif +#endif // OCK_HCOM_NET_SYNC_ENDPOINT_RDMA_H diff --git a/src/transport/rdma/verbs/rdma_composed_endpoint.cpp b/src/transport/rdma/verbs/rdma_composed_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a6a59ae15ce6b02a3ae8778446bfb188413c9a2c --- /dev/null +++ b/src/transport/rdma/verbs/rdma_composed_endpoint.cpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#include "rdma_composed_endpoint.h" +namespace ock { +namespace hcom { +RResult RDMAAsyncEndPoint::Create(const std::string &name, RDMAWorker *worker, RDMAAsyncEndPoint *&ep) +{ + if (worker == nullptr || name.empty()) { + return RR_PARAM_INVALID; + } + + RDMAQp *tmpQP = nullptr; + RResult result = worker->CreateQP(tmpQP); + if (result != RR_OK) { + return result; + } + + auto tmpEP = new (std::nothrow) RDMAAsyncEndPoint(name, worker, tmpQP); + if (tmpEP == nullptr) { + delete tmpQP; + NN_LOG_ERROR("Failed to create RDMAAsyncEndPoint, probably out of memory"); + return RR_NEW_OBJECT_FAILED; + } + + tmpQP->Name(name); + ep = tmpEP; + return RR_OK; +} + +RResult RDMASyncEndpoint::Create(const std::string &name, RDMAContext *ctx, RDMAPollingMode pollMode, + uint32_t rdmaOpCtxPoolSize, const QpOptions &options, RDMASyncEndpoint *&ep) +{ + if (ctx == nullptr || name.empty()) { + return RR_PARAM_INVALID; + } + + auto tmpCQ = new (std::nothrow) RDMACq(name, ctx, pollMode == EVENT_POLLING); + if (tmpCQ == nullptr) { + NN_LOG_ERROR("Failed to create RDMACq, probably out of memory"); + return RR_NEW_OBJECT_FAILED; + } + + auto tmpQP = new (std::nothrow) RDMAQp(name, RDMAQp::NewId(), ctx, tmpCQ, options); + if (tmpQP == nullptr) { + NN_LOG_ERROR("Failed to create RDMAQp, probably out of memory"); + delete tmpCQ; + return RR_NEW_OBJECT_FAILED; + } + + auto tmpEP = new (std::nothrow) RDMASyncEndpoint(name, ctx, pollMode, tmpCQ, tmpQP, rdmaOpCtxPoolSize); + if (tmpEP == nullptr) { + NN_LOG_ERROR("Failed to create RDMASyncClientEndPoint, probably out of memory"); + delete tmpCQ; + delete tmpQP; + return RR_NEW_OBJECT_FAILED; + } + + ep = tmpEP; + return RR_OK; +} +} +} +#endif diff --git a/src/transport/rdma/verbs/rdma_composed_endpoint.h b/src/transport/rdma/verbs/rdma_composed_endpoint.h new file mode 100644 index 0000000000000000000000000000000000000000..3c5c909c4d069f52c1ea12b1c0b1f76871807d83 --- /dev/null +++ b/src/transport/rdma/verbs/rdma_composed_endpoint.h @@ -0,0 +1,642 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_RDMA_COMPOSED_ENDPOINT_12342437333_H +#define OCK_RDMA_COMPOSED_ENDPOINT_12342437333_H +#ifdef RDMA_BUILD_ENABLED + +#include "net_common.h" +#include "net_monotonic.h" +#include "rdma_common.h" +#include "rdma_verbs_wrapper_qp.h" +#include "rdma_worker.h" + +namespace ock { +namespace hcom { +/* *********************************************************************************** */ +class RDMAEndpoint { +public: + RDMAEndpoint(const std::string &name, RDMAQp *qp) : mName(name), mQP(qp) + { + if (mQP != nullptr) { + mQP->IncreaseRef(); + } + OBJ_GC_INCREASE(RDMAEndpoint); + } + + virtual ~RDMAEndpoint() + { + if (mQP != nullptr) { + mQP->DecreaseRef(); + mQP = nullptr; + } + OBJ_GC_DECREASE(RDMAEndpoint); + } + + virtual RResult Initialize() = 0; + virtual void UnInitialize() = 0; + + /* + * @brief, get the name of the ep + */ + inline const std::string &Name() const + { + return mName; + } + + /* + * @brief Get qp exchange info + * + * @param info [out] the exchange into + * + * @return 0 is successful + */ + inline RResult GetExchangeInfo(RDMAQpExchangeInfo &info) + { + if (NN_UNLIKELY(mQP == nullptr)) { + return RR_QP_NOT_INITIALIZED; + } + + return mQP->GetExchangeInfo(info); + } + + /* + * @brief Change the QP to RTR & RTS + * + * @param info [in] the exchange from peer + * + * @return 0 is successful + */ + inline RResult ChangeToReady(RDMAQpExchangeInfo &info) + { + if (NN_UNLIKELY(mQP == nullptr)) { + return RR_EP_NOT_INITIALIZED; + } + + return mQP->ChangeToReady(info); + } + + /* + * @brief Get peer ip and port + * + * @return peer ip and port + */ + inline const std::string &PeerIpAndPort() + { + if (NN_UNLIKELY(mQP != nullptr)) { + return mQP->PeerIpAndPort(); + } + return CONST_EMPTY_STRING; + } + + /* + * @brief Set peer ip and port + * + * @param value [in] ip and port + */ + inline void PeerIpAndPort(const std::string &value) + { + if (NN_UNLIKELY(mQP != nullptr)) { + mQP->PeerIpAndPort(value); + } + } + + /* + * @brief Get the qp object + */ + inline RDMAQp *Qp() const + { + return mQP; + } + + inline bool GetFreeBuffer(uintptr_t &item) + { + return mQP->GetFreeBuff(item); + } + + inline bool GetFreeBufferN(uintptr_t *&items, uint32_t n) + { + return mQP->GetFreeBufferN(items, n); + } + + inline bool ReturnBuffer(uintptr_t item) + { + return mQP->ReturnBuffer(item); + } + + inline uint32_t GetLKey() const + { + return mQP->GetLKey(); + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +protected: + std::string mName; + RDMAQp *mQP = nullptr; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; + +/* *********************************************************************************** */ +/* + * @brief, both send cq and receive cq are in worker + */ +class RDMAAsyncEndPoint : public RDMAEndpoint { +public: + static RResult Create(const std::string &name, RDMAWorker *worker, RDMAAsyncEndPoint *&ep); + +public: + RDMAAsyncEndPoint(const std::string &name, RDMAWorker *worker, RDMAQp *qp) : RDMAEndpoint(name, qp), mWorker(worker) + { + if (mWorker != nullptr) { + mWorker->IncreaseRef(); + } + + OBJ_GC_INCREASE(RDMAAsyncEndPoint); + } + + ~RDMAAsyncEndPoint() override + { + if (mWorker != nullptr) { + mWorker->DecreaseRef(); + mWorker = nullptr; + } + + OBJ_GC_DECREASE(RDMAAsyncEndPoint); + } + + RResult Initialize() override + { + if (NN_UNLIKELY(mQP == nullptr)) { + return RR_EP_NOT_INITIALIZED; + } + + RResult result = RR_OK; + // initialize QP + if ((result = mQP->Initialize()) != RR_OK) { + return result; + } + + return result; + } + + void UnInitialize() override {} + + RDMAAsyncEndPoint() = delete; + RDMAAsyncEndPoint(const RDMAAsyncEndPoint &) = delete; + RDMAAsyncEndPoint &operator = (const RDMAAsyncEndPoint &) = delete; + RDMAAsyncEndPoint(RDMAAsyncEndPoint &&) = delete; + RDMAAsyncEndPoint &operator = (RDMAAsyncEndPoint &&) = delete; + +private: + RDMAWorker *mWorker = nullptr; + + friend class NetDriverRDMAWithOob; +}; + +/* *********************************************************************************** */ +/* + * @brief, both send cq and receive cq in its + */ +class RDMASyncEndpoint : public RDMAEndpoint { +public: + static RResult Create(const std::string &name, RDMAContext *ctx, RDMAPollingMode pollMode, + uint32_t rdmaOpCtxPoolSize, const QpOptions &options, RDMASyncEndpoint *&ep); + +public: + RDMASyncEndpoint(const std::string &name, RDMAContext *ctx, RDMAPollingMode pollMode, RDMACq *cq, RDMAQp *qp, + uint32_t rdmaOpCtxPoolSize) + : RDMAEndpoint(name, qp), mContext(ctx), mPollingMode(pollMode), mCq(cq), mCtxPool(name, rdmaOpCtxPoolSize) + { + if (mContext != nullptr) { + mContext->IncreaseRef(); + } + + if (mCq != nullptr) { + mCq->IncreaseRef(); + } + + OBJ_GC_INCREASE(RDMASyncEndpoint); + } + + ~RDMASyncEndpoint() override + { + if (mContext != nullptr) { + mContext->DecreaseRef(); + mContext = nullptr; + } + + if (mCq != nullptr) { + mCq->DecreaseRef(); + mCq = nullptr; + } + + OBJ_GC_DECREASE(RDMASyncEndpoint); + } + + RResult Initialize() override + { + if (NN_UNLIKELY(mQP == nullptr)) { + return RR_EP_NOT_INITIALIZED; + } + + if (NN_UNLIKELY(mCq == nullptr)) { + return RR_EP_NOT_INITIALIZED; + } + + RResult result = RR_OK; + // initialize cq + if ((result = mCq->Initialize()) != RR_OK) { + return result; + } + + if ((result = mQP->Initialize()) != RR_OK) { + return result; + } + + if ((result = mCtxPool.Initialize()) != RR_OK) { + return result; + } + + return result; + } + + void UnInitialize() override + { + if (mQP != nullptr) { + mQP->UnInitialize(); + } + + if (mCq != nullptr) { + mCq->UnInitialize(); + } + } + + inline RResult PostReceive(uintptr_t bufAddress, uint32_t bufSize, uint32_t localKey) + { + if (NN_UNLIKELY(mQP == nullptr)) { + NN_LOG_ERROR("Failed to PostReceive with RDMASyncEndpoint " << mName << " as qp is null"); + return RR_PARAM_INVALID; + } + + RDMAOpContextInfo *ctx = nullptr; + if (NN_UNLIKELY(!mCtxPool.Dequeue(ctx))) { + NN_LOG_ERROR("Failed to PostReceive with RDMASyncEndpoint " << mName << " as no ctx left"); + return RR_PARAM_INVALID; + } + + ctx->qp = mQP; + ctx->mrMemAddr = bufAddress; + ctx->dataSize = bufSize; + ctx->qpNum = mQP->QpNum(); + ctx->lKey = localKey; + ctx->opType = RDMAOpContextInfo::RECEIVE; + ctx->opResultType = RDMAOpContextInfo::SUCCESS; + mQP->IncreaseRef(); + + // attach context to qp firstly, because post could be finished very fast + // if posted failed, need to remove + mQP->AddOpCtxInfo(ctx); + + auto result = mQP->PostReceive(bufAddress, bufSize, localKey, reinterpret_cast(ctx)); + if (NN_UNLIKELY(result != RR_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + mQP->DecreaseRef(); + mQP->RemoveOpCtxInfo(ctx); + mCtxPool.Enqueue(ctx); + } + + // ctx could not be used if post successfully + return result; + } + + inline RResult RePostReceive(RDMAOpContextInfo *ctx) + { + if (NN_UNLIKELY(ctx == nullptr || ctx->qp == nullptr)) { + NN_LOG_ERROR("Failed to RePostReceive with RDMASyncEndpoint " << mName << " as ctx or its qp is null"); + return RR_PARAM_INVALID; + } + + auto result = + ctx->qp->PostReceive(ctx->mrMemAddr, mQP->PostRegMrSize(), ctx->lKey, reinterpret_cast(ctx)); + if (NN_UNLIKELY(result != RR_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + ctx->qp->DecreaseRef(); + mQP->RemoveOpCtxInfo(ctx); + mCtxPool.Enqueue(ctx); + } + + // ctx could not be used if post successfully + return result; + } + + inline RResult PostSend(const RDMASendReadWriteRequest &req, uint32_t immData = 0) + { + if (NN_UNLIKELY(mQP == nullptr)) { + NN_LOG_ERROR("Failed to PostSend with RDMASyncEndpoint " << mName << " as qp is null"); + return RR_PARAM_INVALID; + } + + static thread_local RDMAOpContextInfo ctx {}; + ctx.qp = mQP; + ctx.mrMemAddr = req.lAddress; + ctx.dataSize = req.size; + ctx.qpNum = mQP->QpNum(); + // Prevent integer truncation, safely converts uint64_t to uint32_t + if (NN_UNLIKELY(req.lKey > UINT32_MAX)) { + NN_LOG_ERROR("Failed to PostSend with RDMASyncEndpoint as lKey is larger than uint32max, lkey" << req.lKey); + return RR_PARAM_INVALID; + } + ctx.lKey = static_cast(req.lKey); + ctx.opType = RDMAOpContextInfo::SEND; + ctx.opResultType = RDMAOpContextInfo::SUCCESS; + ctx.upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(ctx.upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != RR_OK)) { + NN_LOG_ERROR("Failed to copy req to ctx"); + return RR_PARAM_INVALID; + } + } + mQP->IncreaseRef(); + + auto result = mQP->PostSend(req.lAddress, req.size, static_cast(req.lKey), + reinterpret_cast(&ctx), immData); + if (NN_UNLIKELY(result != RR_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + mQP->DecreaseRef(); + } + + // ctx could not be used if post successfully + return result; + } + + RResult PostSendSgl(const RDMASendSglRWRequest &req, const RDMASendReadWriteRequest &tlsReq, uint32_t immData, + bool isEncrypted = false) + { + if (NN_UNLIKELY(mQP == nullptr)) { + NN_LOG_ERROR("Failed to PostSendSgl with RDMAWorker " << mName << " as qp is null"); + return RR_PARAM_INVALID; + } + + static thread_local RDMASglContextInfo sglCtx; + sglCtx.qp = mQP; + sglCtx.result = RR_OK; + if (NN_UNLIKELY(memcpy_s(sglCtx.iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, req.iov, + sizeof(UBSHcomNetTransSgeIov) * req.iovCount) != RR_OK)) { + NN_LOG_ERROR("Failed to copy request to sglCtx"); + return RR_PARAM_INVALID; + } + sglCtx.iovCount = req.iovCount; + sglCtx.upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(sglCtx.upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != RR_OK)) { + NN_LOG_ERROR("Failed to copy request to sglCtx"); + return RR_PARAM_INVALID; + } + } + + static thread_local RDMAOpContextInfo ctx; + + // if not encrypt reqTls lAddress\size\lKey is 0 + ctx.mrMemAddr = tlsReq.lAddress; + ctx.dataSize = tlsReq.size; + // Prevent integer truncation, safely converts uint64_t to uint32_t + if (NN_UNLIKELY(tlsReq.lKey > UINT32_MAX)) { + NN_LOG_ERROR("Failed to PostSendSgl with RDMASyncEp as lKey is larger than uint32max, lkey" << tlsReq.lKey); + return RR_PARAM_INVALID; + } + ctx.lKey = static_cast(tlsReq.lKey); + ctx.qp = mQP; + ctx.qpNum = mQP->QpNum(); + ctx.opType = RDMAOpContextInfo::SEND_RAW_SGL; + ctx.opResultType = RDMAOpContextInfo::SUCCESS; + ctx.upCtxSize = static_cast(sizeof(RDMASgeCtxInfo)); + auto upCtx = reinterpret_cast(&ctx.upCtx); + upCtx->ctx = &sglCtx; + mQP->IncreaseRef(); + + RResult result = RR_OK; + if (isEncrypted) { + result = + mQP->PostSend(tlsReq.lAddress, tlsReq.size, static_cast(tlsReq.lKey), + reinterpret_cast(&ctx), immData); + } else { + result = mQP->PostSendSgl(req.iov, req.iovCount, reinterpret_cast(&ctx), immData); + } + + if (NN_UNLIKELY(result != RR_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + mQP->DecreaseRef(); + } + + return result; + } + + inline RResult PostRead(const RDMASendReadWriteRequest &req) + { + if (NN_UNLIKELY(mQP == nullptr)) { + NN_LOG_ERROR("Failed to PostRead with RDMASyncEndpoint " << mName << " as qp is null"); + return RR_PARAM_INVALID; + } + + static thread_local RDMAOpContextInfo ctx {}; + ctx.qp = mQP; + ctx.mrMemAddr = req.lAddress; + ctx.dataSize = req.size; + ctx.qpNum = mQP->QpNum(); + ctx.lKey = req.lKey; + ctx.opType = RDMAOpContextInfo::READ; + ctx.opResultType = RDMAOpContextInfo::SUCCESS; + ctx.upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(ctx.upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != RR_OK)) { + NN_LOG_ERROR("Failed to copy request to sglCtx"); + return RR_PARAM_INVALID; + } + } + mQP->IncreaseRef(); + + auto result = + mQP->PostRead(req.lAddress, req.lKey, req.rAddress, req.rKey, req.size, reinterpret_cast(&ctx)); + if (NN_UNLIKELY(result != RR_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + mQP->DecreaseRef(); + } + + // ctx could not be used if post successfully + return result; + } + + inline RResult PostWrite(const RDMASendReadWriteRequest &req) + { + if (NN_UNLIKELY(mQP == nullptr)) { + NN_LOG_ERROR("Failed to PostWrite with RDMASyncEndpoint " << mName << " as qp is null"); + return RR_PARAM_INVALID; + } + + static thread_local RDMAOpContextInfo ctx {}; + ctx.qp = mQP; + ctx.mrMemAddr = req.lAddress; + ctx.dataSize = req.size; + ctx.qpNum = mQP->QpNum(); + ctx.lKey = req.lKey; + ctx.opType = RDMAOpContextInfo::WRITE; + ctx.opResultType = RDMAOpContextInfo::SUCCESS; + ctx.upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(ctx.upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != RR_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return RR_PARAM_INVALID; + } + } + mQP->IncreaseRef(); + + auto result = + mQP->PostWrite(req.lAddress, req.lKey, req.rAddress, req.rKey, req.size, reinterpret_cast(&ctx)); + if (NN_UNLIKELY(result != RR_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + mQP->DecreaseRef(); + } + + // ctx could not be used if post successfully + return result; + } + + RResult PostOneSideSgl(const RDMASendSglRWRequest &req, bool isRead = true) + { + if (NN_UNLIKELY(mQP == nullptr)) { + NN_LOG_ERROR("Failed to oneSide operation with RDMAWorker " << mName << " as qp is null"); + return RR_PARAM_INVALID; + } + + static thread_local RDMASglContextInfo sglCtx; + sglCtx.result = RR_OK; + sglCtx.qp = mQP; + if (NN_UNLIKELY(memcpy_s(sglCtx.iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, req.iov, + sizeof(UBSHcomNetTransSgeIov) * req.iovCount) != RR_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return RR_PARAM_INVALID; + } + sglCtx.iovCount = req.iovCount; + sglCtx.upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(sglCtx.upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != RR_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return RR_PARAM_INVALID; + } + } + sglCtx.refCount = 0; + RDMASgeCtxInfo sgeInfo(&sglCtx); + uint64_t ctxArr[NET_SGE_MAX_IOV]; + RResult result = CreateOneSideCtx(sgeInfo, req.iov, req.iovCount, ctxArr, isRead); + if (result != RR_OK) { + NN_LOG_ERROR("Failed to create one side ctx."); + return result; + } + + result = mQP->PostOneSideSgl(req.iov, req.iovCount, ctxArr, isRead); + if (NN_UNLIKELY(result != RR_OK)) { + for (int i = 0; i < req.iovCount; ++i) { + mQP->DecreaseRef(); + } + } + + return result; + } + + RResult CreateOneSideCtx(RDMASgeCtxInfo &sgeInfo, UBSHcomNetTransSgeIov *iov, uint32_t iovCount, + uint64_t (&ctxArr)[NET_SGE_MAX_IOV], bool isRead) + { + if (iov == nullptr || iovCount == NN_NO0 || iovCount > NN_NO4 || ctxArr == nullptr) { + NN_LOG_ERROR("Failed to create oneSide operation ctx because param invalid"); + return RR_PARAM_INVALID; + } + static thread_local RDMAOpContextInfo ctx[NN_NO4] = {}; + for (uint32_t i = 0; i < iovCount; ++i) { + ctx[i].qp = mQP; + ctx[i].mrMemAddr = iov[i].lAddress; + ctx[i].dataSize = iov[i].size; + ctx[i].qpNum = mQP->QpNum(); + ctx[i].lKey = iov[i].lKey; + ctx[i].opType = isRead ? RDMAOpContextInfo::SGL_READ : RDMAOpContextInfo::SGL_WRITE; + ctx[i].opResultType = RDMAOpContextInfo::SUCCESS; + ctx[i].upCtxSize = static_cast(sizeof(RDMASgeCtxInfo)); + auto upCtx = static_cast((void *)&(ctx[i].upCtx)); + upCtx->ctx = sgeInfo.ctx; + upCtx->idx = i; + mQP->IncreaseRef(); + + ctxArr[i] = reinterpret_cast(&ctx[i]); + } + return RR_OK; + } + + inline RResult PollingCompletion(RDMAOpContextInfo *&ctx, int32_t timeout, uint32_t &immData) + { + if (NN_UNLIKELY(mCq == nullptr)) { + NN_LOG_ERROR("Failed to polling completion with RDMASyncEndpoint " << mName << " as cq is null"); + return RR_EP_NOT_INITIALIZED; + } + + int32_t timeoutInMs = TimeSecToMs(timeout); + ibv_wc wc {}; + int pollCount = 1; + RResult result = RR_OK; + if (mPollingMode == BUSY_POLLING) { + auto start = NetMonotonic::TimeMs(); + int64_t pollTime = 0; + do { + pollCount = 1; + result = mCq->ProgressV(&wc, pollCount); + + pollTime = (int64_t)(NetMonotonic::TimeMs() - start); + if (pollCount == 0 && timeoutInMs >= 0 && pollTime > timeoutInMs) { + return RR_CQ_EVENT_GET_TIMOUT; + } + } while (result == RR_OK && pollCount == 0); + } else if (mPollingMode == EVENT_POLLING) { + result = mCq->EventProgressV(&wc, pollCount, timeoutInMs); + } + + if (NN_UNLIKELY(result != RR_OK)) { + return result; + } + + auto *contextInfo = reinterpret_cast(wc.wr_id); + if (contextInfo == nullptr) { + NN_LOG_ERROR("Failed to polling completion with RDMASyncEndpoint " << mName << " as contextInfo is null"); + return RR_CQ_WC_WRONG; + } + contextInfo->dataSize = wc.byte_len; + contextInfo->opResultType = RDMAOpContextInfo::OpResult(wc); + ctx = contextInfo; + if (NN_UNLIKELY(wc.status != IBV_WC_SUCCESS)) { + NN_LOG_ERROR("Poll cq failed in RDMASyncEndpoint " << mName << ", wcStatus " << wc.status << ", opType " << + contextInfo->opType); + return RR_CQ_WC_WRONG; + } + immData = wc.imm_data; + + return RR_OK; + } + +private: + RDMAContext *mContext = nullptr; + RDMAPollingMode mPollingMode = RDMAPollingMode::EVENT_POLLING; + RDMACq *mCq = nullptr; + NetObjPool mCtxPool; +}; +} +} + +#endif +#endif // OCK_RDMA_COMPOSED_ENDPOINT_12342437333_H \ No newline at end of file diff --git a/src/transport/rdma/verbs/rdma_device_helper.cpp b/src/transport/rdma/verbs/rdma_device_helper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..034d4e2aebc3880633170359f05d12015b6fd558 --- /dev/null +++ b/src/transport/rdma/verbs/rdma_device_helper.cpp @@ -0,0 +1,379 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + + +#ifdef RDMA_BUILD_ENABLED + +#include "rdma_device_helper.h" + +namespace ock { +namespace hcom { + +static const char* RDMARoCEVersionStrTable[] = { + "Unknown", + "IB/RoCE v1", + "RoCE v1.5", + "RoCE v2", +}; + +bool RDMADeviceHelper::G_Inited = false; +std::unordered_map RDMADeviceHelper::G_RDMADevMap; +std::unordered_map> RDMADeviceHelper::G_RDMADevGidTable; +std::mutex RDMADeviceHelper::G_Mutex; +uint32_t RDMADeviceHelper::PORT_NUMBER = 1; + +RResult RDMADeviceHelper::Initialize() +{ + if (!G_Inited) { + std::lock_guard guard(G_Mutex); + if (!G_Inited) { + // double check + return DoInitialize(); + } + } + + return RR_OK; +} + +void RDMADeviceHelper::UnInitialize() +{ + std::lock_guard guard(G_Mutex); + G_RDMADevMap.clear(); + G_RDMADevGidTable.clear(); + G_Inited = false; +} + +RResult RDMADeviceHelper::DoInitialize() +{ + auto ret = DoUpdate(); + if (NN_UNLIKELY(ret != RR_OK)) { + return ret; + } + + G_Inited = true; + return RR_OK; +} + +RResult RDMADeviceHelper::DoUpdate() +{ + HcomIbv::ForkInit(); + G_RDMADevMap.clear(); + G_RDMADevGidTable.clear(); + + struct ibv_device **devList = nullptr; + int devCount = 0; + devList = HcomIbv::GetDevList(&devCount); + + NN_LOG_TRACE_INFO("RDMA Device count:" << devCount); + if (devList == nullptr) { + NN_LOG_ERROR("Failed to call get ibv device list, errno " << errno); + return RR_DEVICE_FAILED_OPEN; + } + auto guard = MakeScopeExit([&devList]() { HcomIbv::FreeDevList(devList); }); + G_RDMADevMap.reserve(devCount); + G_RDMADevGidTable.reserve(devCount); + + struct ibv_port_attr portAttr {}; + for (int i = 0; i < devCount; i++) { + if (devList[i] == nullptr) { // should not happen + NN_LOG_WARN("RDMA Device " << i << " is null"); + continue; + } + + RDMADeviceSimpleInfo info; + info.devIndex = i; + if (NN_UNLIKELY(strcpy_s(info.devName, IBV_SYSFS_NAME_MAX, reinterpret_cast(devList[i]->name)) != + RR_OK)) { + NN_LOG_ERROR("Failed to copy devName in initializing device"); + return RR_PARAM_INVALID; + } + NN_LOG_TRACE_INFO("RDMA Device " << i << " name " << devList[i]->name); + std::vector gidVec; + gidVec.reserve(NN_NO16); + + auto ctx = HcomIbv::OpenDevice(devList[i]); + if (ctx != nullptr && HcomIbv::QueryPort(ctx, PORT_NUMBER, &portAttr) == 0) { + info.active = (portAttr.state == IBV_PORT_ACTIVE); + GetGidVec(ctx, info.devName, i, portAttr.active_speed, portAttr.gid_tbl_len, gidVec); + } + + struct ibv_device_attr attr {}; + if (info.active && HcomIbv::QueryDevice(ctx, &attr) == 0) { + info.deviceInfo.maxSge = attr.max_sge; + } + + G_RDMADevMap.emplace(i, info); + G_RDMADevGidTable.emplace(info.devName, gidVec); + if (ctx != nullptr) { + HcomIbv::CloseDev(ctx); + } + } + return RR_OK; +} + +RResult RDMADeviceHelper::Update() +{ + std::lock_guard guard(G_Mutex); + return DoUpdate(); +} + +void RDMADeviceHelper::GetGidVec(ibv_context *context, const std::string &devName, uint16_t devIndex, uint8_t bandWidth, + uint32_t gidTableLen, std::vector &outGidVec) +{ + if (context == nullptr) { + return; + } + + union ibv_gid tmpIbvGid {}; + std::string RoCEVersion; + RDMAGId gid {}; + for (uint32_t i = 0; i < gidTableLen; i++) { + if (HcomIbv::QueryGid(context, PORT_NUMBER, i, &tmpIbvGid) != 0) { + continue; + } + + if (tmpIbvGid.global.interface_id == 0) { + continue; + } + + if (ReadRoCEVersionFromFile(devName, PORT_NUMBER, i, RoCEVersion) != RR_OK) { + continue; + } + + gid.devIndex = devIndex; + gid.gid = i; + gid.ibvGid = tmpIbvGid; + gid.RoCEVersion = StrToRoCEVersion(RoCEVersion); + gid.bandWidth = bandWidth; + outGidVec.push_back(gid); + } +} + +RResult RDMADeviceHelper::GetDeviceCount(uint16_t &deviceCount, std::vector &enabledDevices) +{ + RResult result = RR_OK; + if ((result = Initialize()) != RR_OK) { + return result; + } + + { + std::lock_guard guard(G_Mutex); + deviceCount = G_RDMADevMap.size(); + for (auto &item : G_RDMADevMap) { + if (item.second.active) { + enabledDevices.push_back(item.second); + } + } + } + + return RR_OK; +} + +RResult RDMADeviceHelper::GetEnableDeviceCount(std::string ipMask, uint16_t &enableDevCount, + std::vector &enableIps, std::string ipGroup) +{ + /* ipMask and ipGroup may be null */ + if (ipMask.size() > NN_NO256 || ipGroup.size() > NN_NO1024) { + NN_LOG_ERROR("[RDMA] ip mask size cannot exceed 256, ip group size cannot exceed 1024. "); + return NN_INVALID_IP; + } + RResult result = RR_OK; + std::vector matchIps; + // filter ip by mask + NetFunc::NN_SplitStr(ipGroup, ";", matchIps); + if (matchIps.empty()) { + std::vector filters; + NetFunc::NN_SplitStr(ipMask, ",", filters); + if (filters.empty()) { + NN_LOG_ERROR("[RDMA] Invalid ip mask '" << ipMask << "' by set, example '192.168.0.0/24'"); + return NN_INVALID_IP; + } + for (auto &mask : filters) { + result = FilterIp(mask, matchIps); + } + if (matchIps.empty()) { + NN_LOG_ERROR("[RDMA] No matched ip found with ipGroup or ipMask."); + return NN_INVALID_IP; + } + } + // init RoCE devices + if ((result = Initialize()) != 0) { + NN_LOG_ERROR("[RDMA] Failed to init devices"); + return result; + } + + NN_LOG_INFO(DeviceInfo()); + + uint16_t enableCount = 0; + std::vector findIps; + // choose the matched ip and port active + for (uint16_t i = 0; i < static_cast(matchIps.size()); ++i) { + RDMAGId tmpGid{}; + if ((GetDeviceByIp(matchIps[i], tmpGid)) != 0) { + NN_LOG_WARN("[RDMA] Unable to get device by ip " << matchIps[i]); + continue; + } + // active or not + if (G_RDMADevMap[tmpGid.devIndex].active) { + enableCount++; + findIps.emplace_back(matchIps[i]); + } + NN_LOG_DEBUG("gid found devIndex " << tmpGid.devIndex << ", gidIndex " << tmpGid.gid << ", RoCEVersion " << + RoCEVersionToStr(tmpGid.RoCEVersion)); + } + enableDevCount = enableCount; + enableIps = findIps; + return result; +} + +RResult RDMADeviceHelper::GetDeviceByIp(const std::string &ip, RDMAGId &gid) +{ + RResult result = RR_OK; + struct sockaddr_in address {}; + if ((result = GetIfAddressByIp(ip, address)) != RR_OK) { + return result; + } + + return GetDeviceByAddress(ip, address, gid); +} + +RResult RDMADeviceHelper::GetIfAddressByIp(const std::string &ip, struct sockaddr_in &address) +{ + struct ifaddrs *addresses = nullptr; + if (getifaddrs(&addresses) != 0) { + NN_LOG_ERROR("Failed to get interface addresses"); + return RR_DEVICE_FAILED_GET_IF_ADDRESS; + } + + char ipStr[INET_ADDRSTRLEN] = {0}; + bool found = false; + + struct ifaddrs *iter = addresses; + while (iter != nullptr) { + if (iter->ifa_addr != nullptr && iter->ifa_addr->sa_family == AF_INET) { + inet_ntop(AF_INET, &((reinterpret_cast(iter->ifa_addr))->sin_addr), ipStr, + INET_ADDRSTRLEN); + if (ip == std::string(ipStr)) { + address = *(reinterpret_cast(iter->ifa_addr)); + found = true; + break; + } + } + iter = iter->ifa_next; + } + freeifaddrs(addresses); + + if (!found) { + NN_LOG_ERROR("Failed to get interface address for ip " << ip); + return RR_DEVICE_NO_IF_MATCHED; + } + + return RR_OK; +} + + +RResult RDMADeviceHelper::GetDeviceByAddress(const std::string &ip, struct sockaddr_in &address, RDMAGId &gid) +{ + RResult result = RR_OK; + if ((result = Initialize()) != RR_OK) { + return result; + } + + RDMAGId tmpGid {}; + bool found = false; + + std::lock_guard lock(G_Mutex); + for (auto &item : G_RDMADevGidTable) { + for (auto &gItem : item.second) { + auto devI6Address = reinterpret_cast(gItem.ibvGid.raw); + auto targetAddress = address.sin_addr.s_addr; + + auto judge1 = ((devI6Address->s6_addr32[NN_NO0] | devI6Address->s6_addr32[NN_NO1]) | + (devI6Address->s6_addr32[NN_NO2] ^ htonl(0x0000ffff))) == 0UL; + /* IPv4 encoded multicast addresses */ + auto judge2 = devI6Address->s6_addr32[NN_NO0] == htonl(0xff0e0000) && + ((devI6Address->s6_addr32[NN_NO1] | (devI6Address->s6_addr32[NN_NO2] ^ htonl(0x0000ffff))) == 0UL); + if (!((judge1 || judge2) && devI6Address->s6_addr32[NN_NO3] == targetAddress)) { + // doesn't match + continue; + } + + // match + if (!found) { // first found + tmpGid = gItem; + found = true; + } else if (gItem.RoCEVersion > tmpGid.RoCEVersion) { + // found new one then compare the version, higher version is better + tmpGid = gItem; + } + } + } + + if (!found) { + NN_LOG_ERROR("Failed to get proper gid by address for ip " << ip); + return RR_DEVICE_NO_IF_TO_GID_MATCHED; + } + + gid = tmpGid; + return RR_OK; +} + +RDMARoCEVersion RDMADeviceHelper::StrToRoCEVersion(const std::string &value) +{ + if (value == "IB/RoCE v1") { + return RoCE_V1; + } else if (value == "RoCE v2") { + return RoCE_V2; + } + + // rare case + if (value.length() > 1 && value.at(value.length() - 1) == '5') { + return RoCE_V15; + } + + return RoCE_UNKNOWN; +} + +const char *RDMADeviceHelper::RoCEVersionToStr(RDMARoCEVersion v) +{ + return RDMARoCEVersionStrTable[v]; +} + +std::string RDMADeviceHelper::DeviceInfo() +{ + std::ostringstream oss; + std::lock_guard guard(G_Mutex); + if (!G_Inited) { + oss << "RDMADeviceHelper has not been initialized"; + return oss.str(); + } + + // dump device info + oss << "RDMADeviceHelper device info, devices: count " << G_RDMADevMap.size() << ", "; + for (auto &item : G_RDMADevMap) { + oss << "[" << item.second.devIndex << "," << item.second.devName << "," << item.second.active << "] "; + } + + oss << ", gidTable: count " << G_RDMADevGidTable.size() << ", "; + for (auto &item : G_RDMADevGidTable) { + oss << "[deviceName " << item.first << ", "; + for (auto &gid : item.second) { + oss << "[" << gid.devIndex << "," << gid.gid << "," << RoCEVersionToStr(gid.RoCEVersion) << "] "; + } + oss << "] "; + } + + return oss.str(); +} +} +} +#endif \ No newline at end of file diff --git a/src/transport/rdma/verbs/rdma_device_helper.h b/src/transport/rdma/verbs/rdma_device_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..be1825f0ea4032f1d79747686da4369538bd3b30 --- /dev/null +++ b/src/transport/rdma/verbs/rdma_device_helper.h @@ -0,0 +1,85 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_RDMA_DEVICE_HELPER_H +#define HCOM_RDMA_DEVICE_HELPER_H +#ifdef RDMA_BUILD_ENABLED + +#include "rdma_common.h" + +namespace ock { +namespace hcom { + +struct RDMADeviceSimpleInfo { + uint16_t devIndex = 0; + char devName[IBV_SYSFS_NAME_MAX] {}; + bool active = false; + UBSHcomNetDriverDeviceInfo deviceInfo; +}; + +enum RDMARoCEVersion { + RoCE_UNKNOWN = 0, + RoCE_V1 = 1, + RoCE_V15 = 2, + RoCE_V2 = 3, +}; + +struct RDMAGId { + uint16_t devIndex = 0; + uint16_t gid = 0; + union ibv_gid ibvGid {}; + RDMARoCEVersion RoCEVersion = RDMARoCEVersion::RoCE_UNKNOWN; + uint8_t bandWidth = 0; +} __attribute__((packed)); + +class RDMADeviceHelper { +public: + /* + * @brief, loop all device, and gid table + */ + static RResult Initialize(); + static void UnInitialize(); + static RResult Update(); + + static RResult GetDeviceCount(uint16_t &deviceCount, std::vector &enabledDevices); + + static RResult GetDeviceByIp(const std::string &ip, RDMAGId &gid); + + static const char *RoCEVersionToStr(RDMARoCEVersion v); + static RDMARoCEVersion StrToRoCEVersion(const std::string &value); + + static std::string DeviceInfo(); + + static RResult GetEnableDeviceCount(std::string ipMask, uint16_t &enableDevCount, + std::vector &enableIps, std::string ipGroup); + +private: + static RResult DoInitialize(); + static RResult DoUpdate(); + static void GetGidVec(ibv_context *context, const std::string &devName, uint16_t devIndex, uint8_t bandWidth, + uint32_t gidTableLen, std::vector &outGidVec); + + static RResult GetIfAddressByIp(const std::string &ip, struct sockaddr_in &address); + static RResult GetDeviceByAddress(const std::string &ip, struct sockaddr_in &address, RDMAGId &gid); + +private: + static std::unordered_map G_RDMADevMap; + static std::unordered_map> G_RDMADevGidTable; + static std::mutex G_Mutex; + static bool G_Inited; + + static uint32_t PORT_NUMBER; +}; +} +} +#endif +#endif // HCOM_RDMA_DEVICE_HELPER_H \ No newline at end of file diff --git a/src/transport/rdma/verbs/rdma_validation.h b/src/transport/rdma/verbs/rdma_validation.h new file mode 100644 index 0000000000000000000000000000000000000000..ce49f0a357d435b028bdcd9f1cd1426367f5a4ff --- /dev/null +++ b/src/transport/rdma/verbs/rdma_validation.h @@ -0,0 +1,259 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_RDMA_VALIDATION_H +#define OCK_HCOM_NET_RDMA_VALIDATION_H +#ifdef RDMA_BUILD_ENABLED + +#include "hcom.h" + +#include "rdma_composed_endpoint.h" +#include "net_monotonic.h" +#include "net_rdma_driver_oob.h" +#include "net_security_alg.h" +#include "hcom_utils.h" + +namespace ock { +namespace hcom { +static __always_inline NResult StateValidate(UBSHcomNetAtomicState &state, uint64_t id, + NetDriverRDMAWithOob *driver) +{ + if (NN_UNLIKELY(!state.Compare(NEP_ESTABLISHED))) { + NN_LOG_ERROR("Endpoint " << id << " is not established, state is " << UBSHcomNEPStateToString(state.Get())); + return NN_EP_NOT_ESTABLISHED; + } + + if (NN_UNLIKELY(!driver->IsStarted())) { + NN_LOG_ERROR("Verbs Failed to validate state as driver " << driver << " is not started"); + return NN_ERROR; + } + return NN_OK; +} + +static __always_inline NResult LocalRequestValidate(const UBSHcomNetTransRequest &request) +{ + if (NN_UNLIKELY(request.lAddress == 0 || request.size == 0)) { + NN_LOG_ERROR("Failed to validate request as source data is null or size is zero"); + return NN_PARAM_INVALID; + } + if (NN_UNLIKELY(request.upCtxSize > sizeof(RDMAOpContextInfo::upCtx))) { + NN_LOG_ERROR("Failed to validate request as up ctx size invalid " << request.upCtxSize); + return NN_PARAM_INVALID; + } + return NN_OK; +} + +static __always_inline NResult SizeValidate(const UBSHcomNetTransRequest &request, uint32_t allowedSize, + bool mIsNeedEncrypt, AesGcm128 mAes) +{ + size_t compareSize = request.size; + if (mIsNeedEncrypt) { + compareSize = mAes.EstimatedEncryptLen(request.size); + } + if (NN_UNLIKELY(compareSize > allowedSize)) { + NN_LOG_ERROR("Failed to post message as message size " << request.size << + " is too large, use one side post"); + return NN_TWO_SIDE_MESSAGE_TOO_LARGE; + } + return NN_OK; +} + +static __always_inline NResult PostSendValidation(UBSHcomNetAtomicState &state, uint64_t id, + NetDriverRDMAWithOob *driver, uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t allowedSize, + bool mIsNeedEncrypt, AesGcm128 mAes) +{ + NResult result = NN_OK; + if (NN_UNLIKELY(result = StateValidate(state, id, driver)) != NN_OK) { + return result; + } + if (NN_UNLIKELY(result = LocalRequestValidate(request)) != NN_OK) { + return result; + } + if (NN_UNLIKELY(result = SizeValidate(request, allowedSize, mIsNeedEncrypt, mAes)) != NN_OK) { + NN_LOG_INFO("res: " << result); + return result; + } + if (NN_UNLIKELY(opCode >= MAX_OPCODE)) { + NN_LOG_ERROR("Failed to post message as opcode is invalid, which should with the range 0~" << (MAX_OPCODE - 1)); + return NN_INVALID_OPCODE; + } + return NN_OK; +} + +static __always_inline NResult PostSendRawValidation(UBSHcomNetAtomicState &state, uint64_t id, + NetDriverRDMAWithOob *driver, uint32_t seqNo, const UBSHcomNetTransRequest &request, uint32_t allowedSize, + bool mIsNeedEncrypt, AesGcm128 mAes) +{ + NResult result = NN_OK; + if (NN_UNLIKELY(result = StateValidate(state, id, driver)) != NN_OK) { + return result; + } + if (NN_UNLIKELY(result = LocalRequestValidate(request)) != NN_OK) { + return result; + } + if (NN_UNLIKELY(result = SizeValidate(request, allowedSize, mIsNeedEncrypt, mAes)) != NN_OK) { + return result; + } + if (NN_UNLIKELY(seqNo == 0)) { + NN_LOG_ERROR("Verbs Failed to post raw message as seqNo must > 0"); + return NN_PARAM_INVALID; + } + return NN_OK; +} + +static __always_inline NResult ReadWriteValidation(UBSHcomNetAtomicState &state, uint64_t id, + NetDriverRDMAWithOob *driver, const UBSHcomNetTransRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY(result = StateValidate(state, id, driver)) != NN_OK) { + return result; + } + if (NN_UNLIKELY(result = LocalRequestValidate(request)) != NN_OK) { + return result; + } + if (NN_UNLIKELY(request.size > NET_SGE_MAX_SIZE)) { + NN_LOG_ERROR("Failed to validate request size " << request.size << " as over limit"); + return NN_PARAM_INVALID; + } + if (NN_UNLIKELY(request.rAddress == 0)) { + NN_LOG_ERROR("Failed to validate request as remote data is null"); + return NN_PARAM_INVALID; + } + if (NN_OK != driver->ValidateMemoryRegion(request.lKey, request.lAddress, request.size)) { + NN_LOG_ERROR("Invalid MemoryRegion or local key"); + return NN_INVALID_LKEY; + } + return NN_OK; +} + +static __always_inline NResult SglValidation(const UBSHcomNetTransSglRequest &request, size_t &totalSize, + NetDriverRDMAWithOob *driver) +{ + if (NN_UNLIKELY(request.iov == nullptr || request.iovCount > NET_SGE_MAX_IOV || request.iovCount == 0)) { + NN_LOG_ERROR("Invalid iov ptr:" << request.iov << " or iov cnt:" << request.iovCount); + return NN_PARAM_INVALID; + } + if (NN_UNLIKELY(request.upCtxSize > sizeof(RDMAOpContextInfo::upCtx))) { + NN_LOG_ERROR("Failed to validate request as up ctx size invalid " << request.upCtxSize); + return NN_PARAM_INVALID; + } + for (uint16_t i = 0; i < request.iovCount; ++i) { + if (NN_UNLIKELY(request.iov[i].size > NET_SGE_MAX_SIZE)) { + NN_LOG_ERROR("Failed to validate request size " << request.iov[i].size << " as over limit"); + return NN_PARAM_INVALID; + } + + if (NN_OK != driver->ValidateMemoryRegion(request.iov[i].lKey, request.iov[i].lAddress, + request.iov[i].size)) { + NN_LOG_ERROR("Failed to validate as invalid MemoryRegion or lKey in iov"); + return NN_INVALID_LKEY; + } + totalSize += request.iov[i].size; + } + return NN_OK; +} + +static __always_inline NResult ReadWriteSglValidation(UBSHcomNetAtomicState &state, + uint64_t id, NetDriverRDMAWithOob *driver, const UBSHcomNetTransSglRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY(result = StateValidate(state, id, driver)) != NN_OK) { + return result; + } + size_t tmpTotalSize = 0; + if (NN_UNLIKELY(result = SglValidation(request, tmpTotalSize, driver)) != NN_OK) { + return result; + } + for (uint16_t i = 0; i < request.iovCount; ++i) { + if (NN_UNLIKELY(request.iov[i].rAddress == NN_NO0)) { + NN_LOG_ERROR("Failed to validate request as remote data is null, index " << i); + return NN_PARAM_INVALID; + } + } + return NN_OK; +} + +static __always_inline NResult PostSendSglValidation(UBSHcomNetAtomicState &state, uint64_t id, + NetDriverRDMAWithOob *driver, uint32_t seqNo, const UBSHcomNetTransSglRequest &request, uint32_t allowedSize, + size_t &totalSize, bool mIsNeedEncrypt, AesGcm128 mAes) +{ + NResult ret = NN_OK; + if (NN_UNLIKELY(ret = StateValidate(state, id, driver)) != NN_OK) { + return ret; + } + if (NN_UNLIKELY(seqNo == 0)) { + NN_LOG_ERROR("Failed to post raw message as seqNo must > 0"); + return NN_PARAM_INVALID; + } + if (NN_UNLIKELY(ret = SglValidation(request, totalSize, driver)) != NN_OK) { + return ret; + } + size_t compareSize = totalSize; + if (mIsNeedEncrypt) { + compareSize = mAes.EstimatedEncryptLen(totalSize); + } + if (NN_UNLIKELY(compareSize > allowedSize)) { + NN_LOG_ERROR("Failed to post send raw sgl as message size " << compareSize << + " is too large, use one side post"); + return NN_TWO_SIDE_MESSAGE_TOO_LARGE; + } + return NN_OK; +} + +static __always_inline NResult EncryptRawSgl(UBSHcomNetTransRequest &tlsReq, uintptr_t &mrBufAddress, size_t &size, + AesGcm128 mAes, NetDriverRDMAWithOob *driver, const UBSHcomNetTransSglRequest &request, NetSecrets &mSecrets) +{ + uintptr_t tmpBuffer = 0; + if (NN_UNLIKELY(!driver->GetDriverSendMr()->GetFreeBuffer(tmpBuffer))) { + NN_LOG_ERROR("Failed to post message as failed to get tmp mr buffer from pool from driver " << driver->Name()); + return NN_GET_BUFF_FAILED; + } + + uint32_t iovOffset = 0; + for (uint16_t i = 0; i < request.iovCount; i++) { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(tmpBuffer + iovOffset), + driver->GetDriverSendMr()->GetSingleSegSize() - iovOffset, + reinterpret_cast(request.iov[i].lAddress), request.iov[i].size) != NN_OK)) { + (void)driver->GetDriverSendMr()->ReturnBuffer(tmpBuffer); + NN_LOG_ERROR("Failed to copy request to mrBufAddress"); + return NN_INVALID_PARAM; + } + iovOffset += request.iov[i].size; + } + + if (NN_UNLIKELY(!driver->GetDriverSendMr()->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to post message as failed to get mr buffer from pool from driver " << driver->Name()); + (void)driver->GetDriverSendMr()->ReturnBuffer(tmpBuffer); + return NN_GET_BUFF_FAILED; + } + + uint32_t cipherLen = 0; + if (!(mAes).Encrypt(mSecrets, reinterpret_cast(tmpBuffer), size, reinterpret_cast(mrBufAddress), + cipherLen)) { + NN_LOG_ERROR("Failed to post send message as encryption failure"); + (void)driver->GetDriverSendMr()->ReturnBuffer(tmpBuffer); + (void)driver->GetDriverSendMr()->ReturnBuffer(mrBufAddress); + return NN_ENCRYPT_FAILED; + } + + tlsReq.lAddress = mrBufAddress; + tlsReq.lKey = driver->GetDriverSendMr()->GetLKey(); + tlsReq.size = cipherLen; + size = cipherLen; + + (void)driver->GetDriverSendMr()->ReturnBuffer(tmpBuffer); + return NN_OK; +} +} +} + +#endif +#endif // OCK_HCOM_NET_RDMA_VALIDATION_H \ No newline at end of file diff --git a/src/transport/rdma/verbs/rdma_verbs_wrapper_cq.cpp b/src/transport/rdma/verbs/rdma_verbs_wrapper_cq.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6f97fd1cc2e1c6dfc3b949bb247d15672613b95f --- /dev/null +++ b/src/transport/rdma/verbs/rdma_verbs_wrapper_cq.cpp @@ -0,0 +1,224 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#include +#include +#include +#include +#include +#include + +#include "net_monotonic.h" +#include "rdma_verbs_wrapper_cq.h" + +namespace ock { +namespace hcom { + +RResult RDMACq::CreatePollingCq() +{ + auto tmpCQ = HcomIbv::CreateCq(mRDMAContext->mContext, static_cast(mCQCount), reinterpret_cast(mWork), + nullptr, 0); + if (tmpCQ == nullptr) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create completion queue for RDMACq " << mName << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return RR_NEW_OBJECT_FAILED; + } + + mCompletionQueue = tmpCQ; + return RR_OK; +} + +RResult RDMACq::CreateEventCq() +{ + ibv_comp_channel *tmpCC = HcomIbv::CreateCompChannel(mRDMAContext->mContext); + if (tmpCC == nullptr) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create completion channel for RDMACq " << mName << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return RR_NEW_OBJECT_FAILED; + } + + auto tmpCQ = HcomIbv::CreateCq(mRDMAContext->mContext, static_cast(mCQCount), reinterpret_cast(mWork), + tmpCC, 0); + if (tmpCQ == nullptr) { + HcomIbv::DestroyCompChannel(tmpCC); + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create completion queue for RDMACq " << mName << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return RR_NEW_OBJECT_FAILED; + } + + if (ibv_req_notify_cq(tmpCQ, 0) != 0) { + HcomIbv::DestroyCompChannel(tmpCC); + HcomIbv::DestroyCq(tmpCQ); + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create completion queue for RDMACq " << mName << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return RR_NEW_OBJECT_FAILED; + } + + int flags = fcntl(tmpCC->fd, F_GETFL); + if (fcntl(tmpCC->fd, F_SETFL, static_cast(flags) | O_NONBLOCK) < 0) { + HcomIbv::DestroyCompChannel(tmpCC); + HcomIbv::DestroyCq(tmpCQ); + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to set no blocking for RDMACq " << mName << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return RR_NEW_OBJECT_FAILED; + } + + mCompletionChannel = tmpCC; + mCompletionQueue = tmpCQ; + return RR_OK; +} + +RResult RDMACq::Initialize() +{ + NN_LOG_TRACE_INFO("RDMACq::Initialize"); + if (mCompletionQueue != nullptr) { + return RR_OK; + } + + if (mRDMAContext == nullptr || mRDMAContext->mContext == nullptr) { + NN_LOG_ERROR("Failed to initialize RDMACq as rdma context is null"); + return RR_PARAM_INVALID; + } + + if (mCreateCompletionChannel) { + return CreateEventCq(); + } else { + return CreatePollingCq(); + } + + NN_LOG_TRACE_INFO("RDMACq::Initialized"); + return RR_OK; +} + +RResult RDMACq::UnInitialize() +{ + if (mCompletionQueue == nullptr) { + return RR_OK; + } + + HcomIbv::DestroyCq(mCompletionQueue); + mCompletionQueue = nullptr; + + if (mCompletionChannel != nullptr) { + HcomIbv::DestroyCompChannel(mCompletionChannel); + mCompletionChannel = nullptr; + } + + if (mRDMAContext != nullptr) { + mRDMAContext->DecreaseRef(); + mRDMAContext = nullptr; + } + return RR_OK; +} + +RResult RDMACq::ProgressV(struct ibv_wc *wc, int &countInOut) +{ + if (NN_UNLIKELY(mCompletionQueue == nullptr || wc == nullptr)) { + return RR_CQ_NOT_INITIALIZED; + } + + uint16_t times = 0; + + while (true) { + auto n = ibv_poll_cq(mCompletionQueue, countInOut, wc); + if (NN_UNLIKELY(n < 0)) { + NN_LOG_ERROR("Poll cq failed in RDMACq " << mName << ", errno " << errno); + return RR_CQ_POLLING_FAILED; + } + if (n == 0) { + times++; + if (times < NN_NO10) { + continue; + } + } + + countInOut = n; + break; + } + + return RR_OK; +} + +RResult RDMACq::EventProgressV(struct ibv_wc *wc, int &countInOut, int32_t timeoutInMs) +{ + if (NN_UNLIKELY(mCompletionQueue == nullptr || mCompletionChannel == nullptr || wc == nullptr)) { + return RR_CQ_NOT_INITIALIZED; + } + +POLL_CQ: + auto n = ibv_poll_cq(mCompletionQueue, countInOut, wc); + if (n < 0) { + NN_LOG_ERROR("Poll cq failed in RDMACq " << mName << ", errno " << errno); + return RR_CQ_POLLING_FAILED; + } else if (n > 0) { + countInOut = n; + return RR_OK; + } + + struct pollfd pollEventFd = {}; + pollEventFd.fd = mCompletionChannel->fd; + pollEventFd.events = POLLIN; + pollEventFd.revents = 0; + int rc = 0; + + auto startTime = NetMonotonic::TimeMs(); + int64_t pollTime = 0; + while (true) { + rc = poll(&pollEventFd, 1, timeoutInMs); + if (rc > 0) { + break; + } + + auto endTime = NetMonotonic::TimeMs(); + pollTime = static_cast(endTime - startTime); + + if (timeoutInMs >= 0 && pollTime > timeoutInMs) { + return RR_CQ_EVENT_GET_TIMOUT; + } + + if (rc < 0 && errno != EINTR) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Get poll event failed in RDMA Cq " << mName << ", errno " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return RR_CQ_EVENT_GET_FAILED; + } + } + + // wait request if n == 0 + void *cqContext = nullptr; + struct ibv_cq *cqEvent = nullptr; + + // Wait for the completion event + if (HcomIbv::GetCqEvent(mCompletionChannel, &cqEvent, &cqContext)) { + NN_LOG_ERROR("Get cq event failed in RDMACq " << mName << ", errno " << errno); + return RR_CQ_EVENT_GET_FAILED; + } + + // Ack the event + HcomIbv::AckCqEvents(cqEvent, 1); + + // Request notification upon the next completion event + if (cqEvent != nullptr && ibv_req_notify_cq(cqEvent, 0) != 0) { + NN_LOG_ERROR("Notify cq event failed in RDMACq " << mName << ", errno " << errno); + return RR_CQ_EVENT_NOTIFY_FAILED; + } + + goto POLL_CQ; +} +} +} +#endif \ No newline at end of file diff --git a/src/transport/rdma/verbs/rdma_verbs_wrapper_cq.h b/src/transport/rdma/verbs/rdma_verbs_wrapper_cq.h new file mode 100644 index 0000000000000000000000000000000000000000..de2ae6b9ba2e5fa6cdfdd980ff7b0772245594e2 --- /dev/null +++ b/src/transport/rdma/verbs/rdma_verbs_wrapper_cq.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_RDMA_VERBS_WRAPPER_CQ_H +#define HCOM_RDMA_VERBS_WRAPPER_CQ_H +#ifdef RDMA_BUILD_ENABLED + +#include "rdma_verbs_wrapper_ctx.h" + +namespace ock { +namespace hcom { + +struct RDMACqPollResult { + uint64_t context = 0; + uint32_t dataSize = 0; + enum ibv_wc_status status = IBV_WC_SUCCESS; +}; + +class RDMACq { +public: + RDMACq(const std::string &name, RDMAContext *ctx, bool createCompletionChannel = false, uintptr_t work = 0) + : mName(name), mCreateCompletionChannel(createCompletionChannel), mWork(work), mRDMAContext(ctx) + { + if (mRDMAContext != nullptr) { + mRDMAContext->IncreaseRef(); + } + + OBJ_GC_INCREASE(RDMACq); + } + + ~RDMACq() + { + UnInitialize(); + OBJ_GC_DECREASE(RDMACq); + } + + inline void SetCQCount(uint32_t value) + { + mCQCount = (value < NN_NO1024) ? NN_NO1024 : value; + } + + inline uint32_t GetCQCount() + { + return mCQCount; + } + + RResult Initialize(); + RResult UnInitialize(); + + RResult ProgressV(struct ibv_wc *wc, int &countInOut); + RResult EventProgressV(struct ibv_wc *wc, int &countInOut, int32_t timeoutInMs = NN_NO500); + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + RResult CreatePollingCq(); + RResult CreateEventCq(); + std::string mName; + uint32_t mCQCount = CQ_COUNT; + bool mCreateCompletionChannel = false; + uintptr_t mWork = 0; + RDMAContext *mRDMAContext = nullptr; + ibv_cq *mCompletionQueue = nullptr; + ibv_comp_channel *mCompletionChannel = nullptr; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend class RDMAQp; +}; +} +} +#endif +#endif // HCOM_RDMA_VERBS_WRAPPER_CQ_H \ No newline at end of file diff --git a/src/transport/rdma/verbs/rdma_verbs_wrapper_ctx.cpp b/src/transport/rdma/verbs/rdma_verbs_wrapper_ctx.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ffa59ada6791b9e19044dfc854dbfdab0e64dd1f --- /dev/null +++ b/src/transport/rdma/verbs/rdma_verbs_wrapper_ctx.cpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#include "rdma_verbs_wrapper_ctx.h" + +namespace ock { +namespace hcom { + +RResult RDMAContext::Create(const std::string &name, bool useDevX, const RDMAGId &gid, RDMAContext *&ctx) +{ + auto tmpCtx = new (std::nothrow) RDMAContext(name, useDevX, gid); + if (tmpCtx == nullptr) { + return RR_NEW_OBJECT_FAILED; + } + + ctx = tmpCtx; + return RR_OK; +} + +RResult RDMAContext::Initialize() +{ + if (mContext != nullptr) { + NN_LOG_INFO("RDMAContext " << mName << " already initialized"); + return RR_OK; + } + + HcomIbv::ForkInit(); + + struct ibv_device **devList = nullptr; + int devCount = 0; + devList = HcomIbv::GetDevList(&devCount); + if (devList == nullptr) { + NN_LOG_ERROR("Failed to call get ibv device list for RDMAContext " << mName << ", errno " << errno); + return RR_DEVICE_FAILED_OPEN; + } + + if (mDevIndex >= devCount) { + NN_LOG_ERROR("Invalid device index is set for RDMAContext " << mName); + HcomIbv::FreeDevList(devList); + + return RR_DEVICE_INDEX_OVERFLOW; + } + + ibv_context *tmpCtx = nullptr; + if ((tmpCtx = HcomIbv::OpenDevice(devList[mDevIndex])) == nullptr) { + NN_LOG_ERROR("Invalid device index is set for RDMAContext " << mName << ", errno " << errno); + HcomIbv::FreeDevList(devList); + return RR_DEVICE_OPEN_FAILED; + } + + struct ibv_device_attr attr {}; + if (HcomIbv::QueryDevice(tmpCtx, &attr) != 0) { + NN_LOG_ERROR("Failed to query device info"); + HcomIbv::CloseDev(tmpCtx); + HcomIbv::FreeDevList(devList); + return RR_DEVICE_OPEN_FAILED; + } + + mMaxSge = attr.max_sge < mMaxSge ? attr.max_sge : mMaxSge; + NN_LOG_INFO("Device info: fw_ver " << attr.fw_ver << " ,max_qp " << attr.max_qp << " ,max_qp_wr " << + attr.max_qp_wr << " ,max_sge " << attr.max_sge << " ,adapter max_cqe " << mMaxSge << " ,max_cq " << + attr.max_cq << " ,max_cqe " << attr.max_cqe); + + ibv_pd *tmpPD = nullptr; + if ((tmpPD = HcomIbv::AllocPd(tmpCtx)) == nullptr) { + NN_LOG_ERROR("Invalid device index is set for RDMAContext " << mName << ", errno " << errno); + HcomIbv::CloseDev(tmpCtx); + HcomIbv::FreeDevList(devList); + return RR_DEVICE_OPEN_FAILED; + } + + if (HcomIbv::QueryPort(tmpCtx, mPortNumber, &mPortAttr) != 0 || mPortAttr.state != IBV_PORT_ACTIVE) { + NN_LOG_ERROR("Failed to query port for RDMAContext " << mName << ", errno " << errno << + " or port state invalid " << mPortAttr.state); + HcomIbv::CloseDev(tmpCtx); + HcomIbv::FreeDevList(devList); + HcomIbv::DeallocPd(tmpPD); + return RR_DEVICE_OPEN_FAILED; + } + + HcomIbv::FreeDevList(devList); + + mProtectDomain = tmpPD; + mContext = tmpCtx; + return RR_OK; +} + +void RDMAContext::UpdateGid(const std::string &matchIp) +{ + auto ret = RDMADeviceHelper::Update(); + if (NN_UNLIKELY(ret != RR_OK)) { + return; + } + + RDMAGId tmpGid {}; + if ((RDMADeviceHelper::GetDeviceByIp(matchIp, tmpGid)) != 0) { + NN_LOG_ERROR("Failed to get device by ip " << matchIp); + return; + } + + NN_LOG_INFO("gid found devIndex " << tmpGid.devIndex << ", gidIndex " << tmpGid.gid << ", RoCEVersion " << + RDMADeviceHelper::RoCEVersionToStr(tmpGid.RoCEVersion)); + mBestGid = tmpGid; +} + +RResult RDMAContext::UnInitialize() +{ + if (mContext == nullptr) { + return RR_OK; + } + + HcomIbv::DeallocPd(mProtectDomain); + HcomIbv::CloseDev(mContext); + mProtectDomain = nullptr; + mContext = nullptr; + + return RR_OK; +} +} +} +#endif \ No newline at end of file diff --git a/src/transport/rdma/verbs/rdma_verbs_wrapper_ctx.h b/src/transport/rdma/verbs/rdma_verbs_wrapper_ctx.h new file mode 100644 index 0000000000000000000000000000000000000000..aaa04233d485485894bac54c830f322c78076eba --- /dev/null +++ b/src/transport/rdma/verbs/rdma_verbs_wrapper_ctx.h @@ -0,0 +1,129 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_RDMA_VERBS_WRAPPER_CTX_H +#define HCOM_RDMA_VERBS_WRAPPER_CTX_H +#ifdef RDMA_BUILD_ENABLED + +#include "hcom_utils.h" +#include "hcom_obj_statistics.h" +#include "rdma_common.h" +#include "rdma_device_helper.h" + +namespace ock { +namespace hcom { + +inline uint8_t GetTrafficClass() +{ + uint8_t tc = 106; + auto env = getenv("HCOM_QP_TRAFFIC_CLASS"); + if (env == nullptr) { + return tc; + } + + long tmp = 0; + NetFunc::NN_Stol(env, tmp); + if (tmp >= 0 && tmp < NN_NO256) { + return tmp; + } + + return tc; +} + +class RDMAContext { +public: + static RResult Create(const std::string &name, bool useDevX, const RDMAGId &gid, RDMAContext *&ctx); + +public: + RDMAContext(const std::string &name, bool useDevX, const RDMAGId &gid) + : mName(name), mDevIndex(gid.devIndex), mBestGid(gid), mUseDevX(useDevX) + { + OBJ_GC_INCREASE(RDMAContext); + } + + ~RDMAContext() + { + UnInitialize(); + OBJ_GC_DECREASE(RDMAContext); + } + + RResult Initialize(); + RResult UnInitialize(); + + void UpdateGid(const std::string &matchIp); + + RDMAContext() = delete; + RDMAContext(const RDMAContext &) = delete; + RDMAContext &operator = (const RDMAContext &) = delete; + RDMAContext(RDMAContext &&) = delete; + RDMAContext &operator = (RDMAContext &&) = delete; + + std::string ToString() + { + std::ostringstream oss; + oss << "RDMAContext info: mName " << mName << ", use DevX " << mUseDevX << ", mContext " << mContext << + ", mProtectDomain " << mProtectDomain << ", mDevIndex " << mDevIndex << ", mPortAttr " << + HcomIbv::PortStateStr(mPortAttr.state) << "|" << mPortAttr.lid << "|" << mPortAttr.max_mtu << + ", mBestGid " << mBestGid.devIndex << "|" << mBestGid.gid << "|" << mBestGid.ibvGid.global.interface_id << + "|" << mBestGid.RoCEVersion; + return oss.str(); + } + + inline ibv_context *Context() + { + return mContext; + } + +#ifdef RDMA_CX5_BUILD_ENABLED + inline ibv_dm *DeviceMemory() + { + return mDeviceMemory; + } +#endif + + inline bool UseDevX() const + { + return mUseDevX; + } + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + std::string mName; + ibv_context *mContext = nullptr; + ibv_pd *mProtectDomain = nullptr; +#ifdef RDMA_CX5_BUILD_ENABLED + ibv_dm *mDeviceMemory = nullptr; +#endif + struct ibv_port_attr mPortAttr {}; + uint8_t mPortNumber = 1; + uint16_t mDevIndex = 0; + int mMaxSge = NN_NO16; + RDMAGId mBestGid {}; + bool mUseDevX = false; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend RDMAQp; + friend RDMACq; + friend RDMAMemoryRegion; + friend RDMAWorker; + +#ifdef RDMA_CX5_BUILD_ENABLED + friend RDMAMlx5Qp; + friend RDMAMlx5Cq; + friend RDMAMlx5Worker; +#endif +}; +} +} +#endif +#endif // HCOM_RDMA_VERBS_WRAPPER_CTX_H \ No newline at end of file diff --git a/src/transport/rdma/verbs/rdma_verbs_wrapper_qp.cpp b/src/transport/rdma/verbs/rdma_verbs_wrapper_qp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4b098ed43661b8fb09141d05f7871ed65c1afb93 --- /dev/null +++ b/src/transport/rdma/verbs/rdma_verbs_wrapper_qp.cpp @@ -0,0 +1,305 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#include +#include +#include +#include +#include +#include + +#include "rdma_verbs_wrapper_qp.h" + +namespace ock { +namespace hcom { + +uint32_t RDMAQp::G_INDEX = 1; + +RResult RDMAQp::CreateIbvQp() +{ + NN_LOG_TRACE_INFO("RDMAQp::Initialize"); + if (mRDMAContext == nullptr || mRDMAContext->mContext == nullptr || mSendCQ == nullptr || + mSendCQ->mCompletionQueue == nullptr || mRecvCQ == nullptr || mRecvCQ->mCompletionQueue == nullptr) { + return RR_PARAM_INVALID; + } + + mCtxPosted.next = nullptr; + mCtxPosted.prev = nullptr; + + struct ibv_qp_init_attr initAttr {}; + bzero(&initAttr, sizeof(ibv_qp_init_attr)); + initAttr.qp_context = this; + initAttr.send_cq = mSendCQ->mCompletionQueue; + initAttr.recv_cq = mRecvCQ->mCompletionQueue; + initAttr.qp_type = IBV_QPT_RC; + mQpOptions.maxSendWr = (mQpOptions.maxSendWr < QP_MAX_SEND_WR) ? QP_MAX_SEND_WR : mQpOptions.maxSendWr; + mQpOptions.maxReceiveWr = (mQpOptions.maxReceiveWr < QP_MAX_RECV_WR) ? QP_MAX_RECV_WR : mQpOptions.maxReceiveWr; + // NN_NO8 is the window size Preventing mPostSendMaxWr exceeds ibv max_send_wr caused by mPostSendMaxWr increasing + // during waiting for mPostSendMaxWr,which will cause ibv_post_send return error 12 + initAttr.cap.max_send_wr = mQpOptions.maxSendWr + NN_NO8; + initAttr.cap.max_recv_wr = mQpOptions.maxReceiveWr + NN_NO8; + initAttr.cap.max_recv_sge = static_cast(mRDMAContext->mMaxSge); + initAttr.cap.max_send_sge = static_cast(mRDMAContext->mMaxSge); + initAttr.cap.max_inline_data = HcomEnv::InlineThreshold(); + + auto tmpQP = HcomIbv::CreateQp(mRDMAContext->mProtectDomain, &initAttr); + if (tmpQP == nullptr) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create ibv qp RDMAQp " << mName << ", errno " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return RR_QP_CREATE_FAILED; + } + + mQP = tmpQP; + NN_LOG_TRACE_INFO("RDMAQp::Initialized"); + return RR_OK; +} + +RResult RDMAQp::CreateQpMr() +{ + NResult result = NN_OK; + // create mr pool for send/receive and initialize + if ((result = RDMAMemoryRegionFixedBuffer::Create(mName, mRDMAContext, mQpOptions.mrSegSize, mQpOptions.mrSegCount, + mQpMr)) != 0) { + NN_LOG_ERROR("Failed to create mr for send/receive in qp " << mName << ", result " << result); + return result; + } + if ((result = mQpMr->Initialize()) != 0) { + NN_LOG_ERROR("Failed to initialize mr for send/receive in qp " << mName << ", result " << result); + return result; + } + mQpMr->IncreaseRef(); + return RR_OK; +} + +bool RDMAQp::GetFreeBuff(uintptr_t &item) +{ + return mQpMr->GetFreeBuffer(item); +} + +bool RDMAQp::GetFreeBufferN(uintptr_t *&items, uint32_t n) +{ + return mQpMr->GetFreeBufferN(items, n); +} + +bool RDMAQp::ReturnBuffer(uintptr_t value) +{ + return mQpMr->ReturnBuffer(value); +} + +uint32_t RDMAQp::GetLKey() +{ + return static_cast(mQpMr->GetLKey()); +} + +RResult RDMAQp::Initialize() +{ + auto result = CreateIbvQp(); + if (result != RR_OK) { + NN_LOG_ERROR("RDMA failed to create ibv qp"); + return result; + } + + result = CreateQpMr(); + if (result != RR_OK) { + NN_LOG_ERROR("RDMA failed to create qp mr"); + HcomIbv::DestroyQp(mQP); + mQP = nullptr; + return result; + } + + return RR_OK; +} + +RResult RDMAQp::UnInitialize() +{ + if (mQP != nullptr) { + HcomIbv::DestroyQp(mQP); + mQP = nullptr; + } + + Stop(); + + if (mQpMr != nullptr) { + mQpMr->DecreaseRef(); + mQpMr = nullptr; + } + + if (mSendCQ != nullptr) { + mSendCQ->DecreaseRef(); + } + + if (mRecvCQ != nullptr && mRecvCQ != mSendCQ) { + mRecvCQ->DecreaseRef(); + } + mSendCQ = nullptr; + mRecvCQ = nullptr; + + if (mRDMAContext != nullptr) { + mRDMAContext->DecreaseRef(); + mRDMAContext = nullptr; + } + + return RR_OK; +} + +RResult RDMAQp::ChangeToInit(struct ibv_qp_attr &attr) +{ + attr.qp_state = IBV_QPS_INIT; + attr.pkey_index = 0; + attr.port_num = mRDMAContext->mPortNumber; + attr.qp_access_flags = + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC | IBV_ACCESS_REMOTE_WRITE; + + if (HcomIbv::ModifyQp(mQP, &attr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to change qp " << mName << " state to INIT modify failed, errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return RR_QP_CHANGE_STATE_FAILED; + } + + return RR_OK; +} + +RResult RDMAQp::ChangeToReceive(RDMAQpExchangeInfo &exInfo, struct ibv_qp_attr &attr) +{ + RResult ret = 0; + static uint8_t tc = GetTrafficClass(); + + attr.qp_state = IBV_QPS_RTR; + // path_mtu should be smaller than the network mtu + attr.path_mtu = IBV_MTU_1024; + attr.dest_qp_num = exInfo.qpn; + attr.rq_psn = 0; + attr.max_dest_rd_atomic = 1; + attr.min_rnr_timer = QP_MIN_RNR_TIMER; + attr.ah_attr.is_global = 0; + attr.ah_attr.dlid = exInfo.lid; + attr.ah_attr.sl = 0; + attr.ah_attr.src_path_bits = 0; + attr.ah_attr.port_num = 1; + if (exInfo.gid.global.interface_id) { + attr.ah_attr.is_global = 1; + attr.ah_attr.grh.hop_limit = 1; + attr.ah_attr.grh.dgid = exInfo.gid; + attr.ah_attr.grh.sgid_index = mRDMAContext->mBestGid.gid; + attr.ah_attr.grh.traffic_class = tc; + } + + if ((ret = HcomIbv::ModifyQp(mQP, &attr, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | + IBV_QP_MIN_RNR_TIMER)) != 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to change qp " << mName << " state to READY-TO-RECEIVE modify failed result " << ret << + ", errno:" << errno << " error:" << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return RR_QP_CHANGE_STATE_FAILED; + } + + return RR_OK; +} + +RResult RDMAQp::ChangeToSend(struct ibv_qp_attr &attr) +{ + attr.qp_state = IBV_QPS_RTS; + attr.timeout = QP_TIMEOUT; // 2^14 * 4.096 us = 67108.86 us + attr.retry_cnt = QP_RETRY_COUNT; + attr.rnr_retry = QP_RNR_RETRY; // do later + attr.sq_psn = 0; + attr.max_rd_atomic = 1; + + if (HcomIbv::ModifyQp(mQP, &attr, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | + IBV_QP_MAX_QP_RD_ATOMIC) != 0) { + NN_LOG_ERROR("Failed to change qp " << mName << " state to READY-TO-SEND modify failed, errno " << errno); + return RR_QP_CHANGE_STATE_FAILED; + } + + return RR_OK; +} + +RResult RDMAQp::SetMaxSendWrConfig(RDMAQpExchangeInfo &exInfo) +{ + NN_LOG_TRACE_INFO("Remote qpId " << mId << " info: send wr " << exInfo.maxSendWr << ", receive wr " << + exInfo.maxReceiveWr << ", receive seg size " << exInfo.receiveSegSize << ", receive seg count " << + exInfo.receiveSegCount); + NN_LOG_TRACE_INFO("Local qpId " << mId << " info: send wr " << mQpOptions.maxSendWr << ", receive wr " << + mQpOptions.maxReceiveWr << ", receive seg size " << mQpOptions.mrSegSize << ", receive seg count " << + mQpOptions.mrSegCount); + + int32_t maxWr = std::min(mQpOptions.maxSendWr, exInfo.maxReceiveWr); + int32_t maxPostSendWr = std::min(mQpOptions.maxSendWr, exInfo.receiveSegCount); + if (maxWr < maxPostSendWr) { + NN_LOG_ERROR("Qp " << mId << " max wr " << maxWr << " is less than max post send wr" << maxPostSendWr); + return RR_QP_RECEIVE_CONFIG_ERR; + } + // one side operation do not consume remote receive queue element + mOneSideMaxWr = maxWr - maxPostSendWr; + mOneSideRef = mOneSideMaxWr; + mPostSendMaxWr = maxPostSendWr; + mPostSendRef = mPostSendMaxWr; + mPostSendMaxSize = exInfo.receiveSegSize; + NN_LOG_TRACE_INFO("Qp id " << mId << " one side max wr " << mOneSideMaxWr << ", post send max wr " << + mPostSendMaxWr << ", post send max size " << mPostSendMaxSize); + return RR_OK; +} + +RResult RDMAQp::ChangeToReady(RDMAQpExchangeInfo &exInfo) +{ + if (NN_UNLIKELY(mQP == nullptr)) { + NN_LOG_ERROR("Failed to change qp " << mName << " state to READY-TO-SEND as qp is not created."); + return RR_QP_CHANGE_STATE_FAILED; + } + + RResult ret = 0; + ret = SetMaxSendWrConfig(exInfo); + if (ret != RR_OK) { + return ret; + } + + struct ibv_qp_attr attr {}; + ret = ChangeToInit(attr); + if (ret != RR_OK) { + return ret; + } + + ret = ChangeToReceive(exInfo, attr); + if (ret != RR_OK) { + return ret; + } + + ret = ChangeToSend(attr); + if (ret != RR_OK) { + return ret; + } + + NN_LOG_INFO("RDMA qp " << mId << " attr send queue size " << mQpOptions.maxSendWr << ", receive queue size " << + mQpOptions.maxReceiveWr << ", tc " << std::to_string(attr.ah_attr.grh.traffic_class) << ", gid-n-n " << + (exInfo.gid.global.interface_id != 0)); + + isStarted = true; + return RR_OK; +} + +RResult RDMAQp::GetExchangeInfo(RDMAQpExchangeInfo &exInfo) +{ + if (mQP == nullptr || mRDMAContext == nullptr) { + return RR_QP_NOT_INITIALIZED; + } + + exInfo.qpn = mQP->qp_num; + exInfo.lid = mRDMAContext->mPortAttr.lid; + exInfo.gid = mRDMAContext->mBestGid.ibvGid; + return RR_OK; +} +} +} +#endif \ No newline at end of file diff --git a/src/transport/rdma/verbs/rdma_verbs_wrapper_qp.h b/src/transport/rdma/verbs/rdma_verbs_wrapper_qp.h new file mode 100644 index 0000000000000000000000000000000000000000..b0797e68236fd0e7645f324efcd33dc87d223062 --- /dev/null +++ b/src/transport/rdma/verbs/rdma_verbs_wrapper_qp.h @@ -0,0 +1,591 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_RDMA_VERBS_WRAPPER_QP_H +#define HCOM_RDMA_VERBS_WRAPPER_QP_H +#ifdef RDMA_BUILD_ENABLED +#include +#include + +#include "hcom_env.h" +#include "rdma_mr_fixed_buf.h" +#include "rdma_verbs_wrapper_cq.h" + +namespace ock { +namespace hcom { + +struct RDMAQpExchangeInfo { + uint32_t lid = 0; + uint32_t qpn = 0; + union ibv_gid gid {}; + uintptr_t hbAddress = 0; + uint32_t hbKey = 0; + uint64_t hbMrSize = 0; + uint32_t maxSendWr = QP_MAX_SEND_WR; + uint32_t maxReceiveWr = QP_MAX_RECV_WR; + uint32_t receiveSegSize = NN_NO1024; + uint32_t receiveSegCount = NN_NO64; +} __attribute__((packed)); + +class RDMAQp { +public: + RDMAQp(const std::string &name, uint32_t id, RDMAContext *ctx, RDMACq *cq, QpOptions qpOptions = {}) + : mName(name), mId(id), mRDMAContext(ctx), mSendCQ(cq), mRecvCQ(cq), mQpOptions(qpOptions) + { + if (mRDMAContext != nullptr) { + mRDMAContext->IncreaseRef(); + } + + if (mSendCQ != nullptr) { + mSendCQ->IncreaseRef(); + } + OBJ_GC_INCREASE(RDMAQp); + } + + RDMAQp(uint32_t id, RDMAContext *ctx, RDMACq *sendCq, RDMACq *receiveCq, QpOptions qpOptions) + : mId(id), mRDMAContext(ctx), mSendCQ(sendCq), mRecvCQ(receiveCq), mQpOptions(qpOptions) + { + if (mRDMAContext != nullptr) { + mRDMAContext->IncreaseRef(); + } + + if (mSendCQ != nullptr) { + mSendCQ->IncreaseRef(); + } + + if (mRecvCQ != nullptr && mRecvCQ != mSendCQ) { + mRecvCQ->IncreaseRef(); + } + OBJ_GC_INCREASE(RDMAQp); + } + + virtual ~RDMAQp() + { + UnInitialize(); + OBJ_GC_DECREASE(RDMAQp); + } + + RResult CreateIbvQp(); + RResult CreateQpMr(); + + /* call ibv_create_qp to create real QP */ + RResult Initialize(); + RResult UnInitialize(); + + /* + exchange information needs to be transformed by other channel (e.g. tcp connection) + 1 firstly do the initialization + 2 got qp exchange info from peer + 3 call this function to change qp state to ready state (INIT & RTS & RTR) + */ + RResult ChangeToReady(RDMAQpExchangeInfo &exInfo); + + /* after qp initialized, retrieve the qp qp_num for exchange */ + RResult GetExchangeInfo(RDMAQpExchangeInfo &exInfo); + + inline RResult PostReceive(uintptr_t bufAddr, uint32_t bufSize, uint32_t localKey, uint64_t context) + { + if (NN_UNLIKELY(mQP == nullptr)) { + return RR_QP_NOT_INITIALIZED; + } + + struct ibv_recv_wr *badWR; + struct ibv_sge list { + bufAddr, bufSize, localKey + }; + + struct ibv_recv_wr wr {}; + wr.wr_id = context; + wr.sg_list = &list; + wr.num_sge = 1; + wr.next = nullptr; + + auto result = ibv_post_recv(mQP, &wr, &badWR); + if (NN_UNLIKELY(result != 0)) { + NN_LOG_ERROR("Failed to post receive request to qp " << mName << ", result " << result); + return RR_QP_POST_RECEIVE_FAILED; + } + return RR_OK; + } + + inline RResult PostSend(uintptr_t bufAddr, uint32_t bufSize, uint32_t localKey, uint64_t context, + uint32_t immData = 0) + { + NN_LOG_TRACE_INFO("Post send addr " << bufAddr << ", size " << bufSize << ", lkey " << localKey << + ", context " << context); + if (NN_UNLIKELY(mQP == nullptr)) { + return RR_QP_NOT_INITIALIZED; + } + + struct ibv_send_wr *badWR; + struct ibv_sge list { + bufAddr, bufSize, localKey + }; + + struct ibv_send_wr wr {}; + wr.sg_list = &list; + wr.wr_id = context; + wr.num_sge = 1; + wr.next = nullptr; + wr.send_flags = IBV_SEND_SIGNALED; + wr.opcode = IBV_WR_SEND_WITH_IMM; + /* + * case 1: immData == 0, send header then user's data + * case 2: immData != 0, send user's data only + */ + wr.imm_data = immData; + + auto result = ibv_post_send(mQP, &wr, &badWR); + if (NN_UNLIKELY(result != 0)) { + NN_LOG_ERROR("Failed to post send request to qp " << mName << ", result " << result); + return RR_QP_POST_SEND_FAILED; + } + + return RR_OK; + } + + inline RResult PostSendSglInline(UBSHcomNetTransDataIov *iov, uint32_t iovCount, uint64_t context, + uint32_t immData = 0) + { + if (NN_UNLIKELY(mQP == nullptr)) { + return RR_QP_NOT_INITIALIZED; + } + + struct ibv_send_wr *badWR; + struct ibv_sge list[NN_NO4] = {}; + for (uint32_t i = 0; i < iovCount; i++) { + list[i].addr = iov[i].address; + list[i].length = iov[i].size; + list[i].lkey = static_cast(iov[i].key); + } + + struct ibv_send_wr wr {}; + wr.wr_id = context; + wr.sg_list = list; + wr.num_sge = static_cast(iovCount); + wr.next = nullptr; + wr.opcode = IBV_WR_SEND_WITH_IMM; + wr.send_flags = IBV_SEND_INLINE | IBV_SEND_SIGNALED; + /* + * case 1: immData == 0, send header then user's data + * case 2: immData != 0, send user's data only + */ + wr.imm_data = immData; + + auto result = ibv_post_send(mQP, &wr, &badWR); + if (NN_UNLIKELY(result != 0)) { + NN_LOG_ERROR("Failed to post send request to qp " << mName << ", result " << result); + return RR_QP_POST_SEND_FAILED; + } + + return RR_OK; + } + + inline RResult PostSendSgl(UBSHcomNetTransSgeIov *iov, uint32_t iovCount, uint64_t context, uint32_t immData = 0) + { + if (NN_UNLIKELY(mQP == nullptr)) { + return RR_QP_NOT_INITIALIZED; + } + + struct ibv_send_wr *badWR; + struct ibv_sge list[NN_NO4] = {}; + for (uint32_t i = 0; i < iovCount; i++) { + list[i].addr = iov[i].lAddress; + list[i].length = iov[i].size; + list[i].lkey = static_cast(iov[i].lKey); + } + + struct ibv_send_wr wr {}; + wr.wr_id = context; + wr.sg_list = list; + wr.num_sge = static_cast(iovCount); + wr.next = nullptr; + wr.opcode = IBV_WR_SEND_WITH_IMM; + wr.send_flags = IBV_SEND_SIGNALED; + /* + * case 1: immData == 0, send header then user's data + * case 2: immData != 0, send user's data only + */ + wr.imm_data = immData; + + auto result = ibv_post_send(mQP, &wr, &badWR); + if (NN_UNLIKELY(result != 0)) { + NN_LOG_ERROR("Failed to post send request to qp " << mName << ", result " << result); + return RR_QP_POST_SEND_FAILED; + } + + return RR_OK; + } + + inline RResult PostOneSideSgl(UBSHcomNetTransSgeIov *iov, uint32_t iovCount, + uint64_t (&context)[NET_SGE_MAX_IOV], bool isRead) + { + if (NN_UNLIKELY(mQP == nullptr)) { + return RR_QP_NOT_INITIALIZED; + } + + struct ibv_send_wr *badWR; + struct ibv_send_wr wrList[NET_SGE_MAX_IOV] = {}; + struct ibv_sge list[NN_NO4] = {}; + for (uint32_t i = 0; i < iovCount; i++) { + list[i].addr = iov[i].lAddress; + list[i].length = iov[i].size; + list[i].lkey = static_cast(iov[i].lKey); + + auto &wr = wrList[i]; + wr.wr_id = context[i]; + wr.num_sge = 1; + wr.sg_list = &list[i]; + wr.send_flags = IBV_SEND_SIGNALED; + wr.opcode = isRead ? IBV_WR_RDMA_READ : IBV_WR_RDMA_WRITE; + wr.imm_data = 0; + wr.next = (i + 1 == iovCount) ? nullptr : &wrList[i + 1]; + wr.wr.rdma.remote_addr = iov[i].rAddress; + wr.wr.rdma.rkey = static_cast(iov[i].rKey); + } + + auto result = ibv_post_send(mQP, wrList, &badWR); + if (NN_UNLIKELY(result != 0)) { + NN_LOG_ERROR("Failed to post oneSide request to qp " << mName << ", result " << result); + return isRead ? RR_QP_POST_READ_FAILED : RR_QP_POST_WRITE_FAILED; + } + + return RR_OK; + } + + inline RResult PostRead(uintptr_t bufAddr, uint32_t localKey, uintptr_t remoteBufAddr, uint32_t remoteKey, + uint32_t bufSize, uint64_t context) + { + if (NN_UNLIKELY(mQP == nullptr)) { + return RR_QP_NOT_INITIALIZED; + } + + struct ibv_send_wr *badWR; + struct ibv_sge list { + bufAddr, bufSize, localKey + }; + + struct ibv_send_wr wr {}; + wr.sg_list = &list; + wr.num_sge = 1; + wr.next = nullptr; + wr.wr_id = context; + wr.opcode = IBV_WR_RDMA_READ; + wr.send_flags = IBV_SEND_SIGNALED; + wr.wr.rdma.remote_addr = remoteBufAddr; + wr.wr.rdma.rkey = remoteKey; + + auto result = ibv_post_send(mQP, &wr, &badWR); + if (NN_UNLIKELY(result != 0)) { + NN_LOG_ERROR("Failed to post read request to qp " << mName << ", result " << result); + return RR_QP_POST_READ_FAILED; + } + + return RR_OK; + } + + inline RResult PostWrite(uintptr_t bufAddr, uint32_t localKey, uintptr_t remoteBufAddr, uint32_t remoteKey, + uint32_t bufSize, uint64_t context) + { + if (NN_UNLIKELY(mQP == nullptr)) { + return RR_QP_NOT_INITIALIZED; + } + + struct ibv_send_wr *badWR; + struct ibv_sge list { + bufAddr, bufSize, localKey + }; + + struct ibv_send_wr wr {}; + wr.wr_id = context; + wr.sg_list = &list; + wr.num_sge = 1; + wr.next = nullptr; + wr.opcode = IBV_WR_RDMA_WRITE; + wr.send_flags = IBV_SEND_SIGNALED; + wr.wr.rdma.remote_addr = remoteBufAddr; + wr.wr.rdma.rkey = remoteKey; + + auto result = ibv_post_send(mQP, &wr, &badWR); + if (NN_UNLIKELY(result != 0)) { + NN_LOG_ERROR("Failed to post write request to qp " << mName << ", result " << result); + return RR_QP_POST_WRITE_FAILED; + } + + return RR_OK; + } + + inline uint32_t Id() const + { + return mId; + } + + inline void UpId(uint64_t id) + { + mUpId = id; + } + + inline uint64_t UpId() const + { + return mUpId; + } + + inline const std::string &Name() const + { + return mName; + } + + inline void Name(const std::string &value) + { + mName = value; + } + + inline const std::string &PeerIpAndPort() const + { + return mPeerIpPort; + } + + inline void PeerIpAndPort(const std::string &value) + { + mPeerIpPort = value; + } + + inline uint32_t PostSendMaxSize() const + { + return mPostSendMaxSize; + } + + inline uint8_t PortNum() const + { + return mRDMAContext->mPortNumber; + } + + inline void UpContext(uintptr_t ctx) + { + mUpContext = ctx; + } + + inline uintptr_t UpContext() const + { + return mUpContext; + } + + inline void UpContext1(uintptr_t ctx) + { + mUpContext1 = ctx; + } + + inline uintptr_t UpContext1() const + { + return mUpContext1; + } + + bool GetFreeBuff(uintptr_t &item); + bool ReturnBuffer(uintptr_t value); + bool GetFreeBufferN(uintptr_t *&items, uint32_t n); + uint32_t GetLKey(); + + inline void AddOpCtxInfo(RDMAOpContextInfo *verbsCtxInfo) + { + if (NN_LIKELY(verbsCtxInfo != nullptr)) { + // bi-direction linked list, 4 step to insert to head + verbsCtxInfo->prev = &mCtxPosted; + mLock.Lock(); + // head -><- first -><- second -><- third -> nullptr + // insert into the head place + verbsCtxInfo->next = mCtxPosted.next; + if (mCtxPosted.next != nullptr) { + mCtxPosted.next->prev = verbsCtxInfo; + } + mCtxPosted.next = verbsCtxInfo; + ++mCtxPostedCount; + mLock.Unlock(); + } + } + + inline void RemoveOpCtxInfo(RDMAOpContextInfo *ctxInfo) + { + if (NN_LIKELY(ctxInfo != nullptr)) { + // bi-direction linked list, 4 step to remove one + mLock.Lock(); + + // repeat remove + if (ctxInfo->prev == nullptr) { + mLock.Unlock(); + return; + } + + // head-><- first -><- second -><- third -> nullptr + ctxInfo->prev->next = ctxInfo->next; + if (ctxInfo->next != nullptr) { + ctxInfo->next->prev = ctxInfo->prev; + } + --mCtxPostedCount; + + ctxInfo->prev = nullptr; + ctxInfo->next = nullptr; + mLock.Unlock(); + } + } + + // need to call this when qp broken, to get these contexts to return mrs + inline void GetCtxPosted(RDMAOpContextInfo *&remaining) + { + mLock.Lock(); + // head -> first -><- second -><- third -> nullptr + remaining = mCtxPosted.next; + mCtxPosted.next = nullptr; + mCtxPostedCount = 0; + mLock.Unlock(); + } + + /// 获取所有提交至 QP 队列中的任务个数,总数为 PostReceive + PostSend 族函数 + /// 的和。因为 RDMA 有 prePostReceive 机制,所以它的值一般会大于等于 + /// prePostReceiveSizePerQP 的值。 + /// \see prePostReceiveSizePerQP + inline uint32_t GetPostedCount() + { + mLock.Lock(); + auto tmp = mCtxPostedCount; + mLock.Unlock(); + return tmp; + } + + /// 获取 QP 发送队列的长度。 + inline uint32_t GetSendQueueSize() + { + int32_t ref = __sync_fetch_and_add(&mPostSendRef, 0); + ref = std::max(0, std::min(ref, mPostSendMaxWr)); + return static_cast(mPostSendMaxWr - ref); + } + + inline bool GetPostSendWr(uint32_t times = NN_NO8, uint32_t sleepUs = NN_NO64) + { + while (times-- > 0) { + if (NN_LIKELY(__sync_sub_and_fetch(&mPostSendRef, 1) >= 0)) { + return true; + } + __sync_add_and_fetch(&mPostSendRef, 1); + usleep(sleepUs); + } + return false; + } + + inline void ReturnPostSendWr() + { + int32_t ref = __sync_add_and_fetch(&mPostSendRef, 1); + if (ref > mPostSendMaxWr) { + NN_LOG_WARN("[RDMA] Posted send requests " << ref << " over capacity " << mPostSendMaxWr); + } + } + + inline bool GetOneSideWr(uint32_t times = NN_NO8, uint32_t sleepUs = NN_NO64) + { + while (times-- > 0) { + if (NN_LIKELY(__sync_sub_and_fetch(&mOneSideRef, 1) >= 0)) { + return true; + } + __sync_add_and_fetch(&mOneSideRef, 1); + usleep(sleepUs); + } + return false; + } + + inline void ReturnOneSideWr() + { + int32_t ref = __sync_add_and_fetch(&mOneSideRef, 1); + if (ref > mOneSideMaxWr) { + NN_LOG_WARN("[RDMA] Posted one side requests " << ref << " over capacity " << mOneSideMaxWr); + } + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +public: + static uint32_t NewId() + { + return __sync_fetch_and_add(&G_INDEX, 1); + } + + inline uint32_t QpNum() const + { + if (NN_UNLIKELY(mQP == nullptr)) { + return 0xffffffff; + } + + return mQP->qp_num; + } + + inline uint32_t PostRegMrSize() const + { + return mQpOptions.mrSegSize; + } + + inline RResult Stop() + { + if (!isStarted || mQP == nullptr) { + return RR_OK; + } + + struct ibv_qp_attr attr = {}; + attr.qp_state = IBV_QPS_ERR; + auto result = HcomIbv::ModifyQp(mQP, &attr, IBV_QP_STATE); + if (result != 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to modify QP state to ERR " << result << ", as " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return RR_QP_CHANGE_ERR; + } + + isStarted = false; + return RR_OK; + } + +private: + RResult ChangeToInit(struct ibv_qp_attr &attr); + RResult ChangeToReceive(RDMAQpExchangeInfo &exInfo, struct ibv_qp_attr &attr); + RResult ChangeToSend(struct ibv_qp_attr &attr); + RResult SetMaxSendWrConfig(RDMAQpExchangeInfo &exInfo); + + std::string mName; + std::string mPeerIpPort; + uint32_t mId = 0; + uint64_t mUpId = 0; + bool isStarted = false; + + RDMAContext *mRDMAContext = nullptr; + RDMACq *mSendCQ = nullptr; + RDMACq *mRecvCQ = nullptr; + QpOptions mQpOptions {}; + ibv_qp *mQP = nullptr; + uintptr_t mUpContext = 0; + uintptr_t mUpContext1 = 0; + NetSpinLock mLock; + RDMAOpContextInfo mCtxPosted {}; + uint32_t mCtxPostedCount { 0 }; + RDMAMemoryRegionFixedBuffer *mQpMr = nullptr; + + int32_t mOneSideMaxWr = QP_MAX_SEND_WR - NN_NO64; + int32_t mOneSideRef = QP_MAX_SEND_WR - NN_NO64; + int32_t mPostSendMaxWr = NN_NO64; + uint32_t mPostSendMaxSize = NN_NO1024; + int32_t mPostSendRef = NN_NO64; + DEFINE_RDMA_REF_COUNT_VARIABLE; + + static uint32_t G_INDEX; + + friend class RDMAWorker; +}; +} +} +#endif +#endif // HCOM_RDMA_VERBS_WRAPPER_QP_H diff --git a/src/transport/rdma/verbs/rdma_worker.h b/src/transport/rdma/verbs/rdma_worker.h new file mode 100644 index 0000000000000000000000000000000000000000..f98e1e9d645e8bc23383f296b5cea628ea29b3a6 --- /dev/null +++ b/src/transport/rdma/verbs/rdma_worker.h @@ -0,0 +1,275 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_RDMA_WORKER_1234341456433_H +#define OCK_RDMA_WORKER_1234341456433_H +#ifdef RDMA_BUILD_ENABLED + +#include +#include +#include +#include +#include +#include + +#include "hcom.h" +#include "net_ctx_info_pool.h" +#include "net_mem_pool_fixed.h" +#include "rdma_verbs_wrapper_qp.h" + +namespace ock { +namespace hcom { +using RDMANewReqHandler = std::function; +using RDMAPostedHandler = std::function; +using RDMAOneSideDoneHandler = std::function; + +// when there is no request from cq, call this +using RDMAIdleHandler = UBSHcomNetDriverIdleHandler; + +using RDMASendSglInlineHeader = UBSHcomNetTransHeader; +using RDMASendReadWriteRequest = UBSHcomNetTransRequest; +using RDMASendSglRWRequest = UBSHcomNetTransSglRequest; + +using RDMAOpContextInfoPool = OpContextInfoPool; +using RDMASglContextInfoPool = OpContextInfoPool; + +enum RDMAWorkerType : uint8_t { + SENDER = 0, + RECEIVER = 1, + SENDER_RECEIVER = 2, +}; + +std::string &WorkerTypeToString(RDMAWorkerType tp); +std::string &PollingModeToString(RDMAPollingMode m); + +using RDMAWorkerOptions = struct RDMAWorkerOptionsStruct { + RDMAWorkerType workerType = RDMAWorkerType::RECEIVER; + RDMAPollingMode workerMode = RDMAPollingMode::BUSY_POLLING; + uint16_t completionQueueDepth = NN_NO2048; + uint16_t maxPostSendCountPerQP = NN_NO64; + uint16_t prePostReceiveSizePerQP = NN_NO64; + uint16_t pollingBatchSize = NN_NO4; + int16_t cpuId = -1; + uint32_t qpSendQueueSize = NN_NO256; + uint32_t qpReceiveQueueSize = NN_NO256; + uint32_t qpMrSegSize = NN_NO1024; + uint32_t qpMrSegCount = NN_NO64; + uint32_t eventPollingTimeout = NN_NO500; + bool dontStartWorkers = false; + /* worker thread priority [-20,20], 20 is the lowest, -20 is the highest, 0 (default) means do not set priority */ + int threadPriority = 0; + + std::string ToString() const + { + std::ostringstream oss; + oss << "options type: " << WorkerTypeToString(workerType) << ", mode: " << PollingModeToString(workerMode) << + ", cq size: " << completionQueueDepth << ", max post send: " << maxPostSendCountPerQP << + ", pre-post receive size: " << prePostReceiveSizePerQP << ", poll batch size " << pollingBatchSize << + ", cpu id: " << cpuId << ", qp send queue: " << qpSendQueueSize << ", qp receive queue: " << + qpReceiveQueueSize << ", dontStartWorkers: " << dontStartWorkers; + return oss.str(); + } + + void SetValue(const UBSHcomNetDriverOptions& opt) + { + workerType = RDMAWorkerType::SENDER_RECEIVER; + completionQueueDepth = opt.completionQueueDepth; + maxPostSendCountPerQP = opt.maxPostSendCountPerQP; + prePostReceiveSizePerQP = opt.prePostReceiveSizePerQP; + pollingBatchSize = opt.pollingBatchSize; + if (opt.mode == NET_EVENT_POLLING) { + workerMode = RDMAPollingMode::EVENT_POLLING; + } else if (opt.mode == NET_BUSY_POLLING) { + workerMode = RDMAPollingMode::BUSY_POLLING; + } + qpSendQueueSize = opt.qpSendQueueSize; + qpReceiveQueueSize = opt.qpReceiveQueueSize; + qpMrSegSize = opt.mrSendReceiveSegSize; + qpMrSegCount = opt.prePostReceiveSizePerQP; + eventPollingTimeout = opt.eventPollingTimeout; + dontStartWorkers = opt.dontStartWorkers; + threadPriority = opt.workerThreadPriority; + } +}; + +class RDMAWorker { +public: + RDMAWorker(const std::string &name, RDMAContext *ctx, const RDMAWorkerOptions &options, + const NetMemPoolFixedPtr &memPool, const NetMemPoolFixedPtr &sglMemPool); + + virtual ~RDMAWorker() + { + UnInitialize(); + OBJ_GC_DECREASE(RDMAWorker); + } + + RResult Initialize(); + RResult UnInitialize(); + RResult ReInitializeCQ(); + + RResult Start(); + RResult Stop(); + + inline bool IsWorkStarted(uint32_t timeOutSecond = NN_NO8) + { + uint64_t count = static_cast(timeOutSecond) * NN_NO1000000 / NN_NO100; + while (--count > 0 && !mProgressThreadStarted.load()) { + usleep(NN_NO100); + } + + if (count > 0) { + return true; + } else { + return false; + } + } + + inline const UBSHcomNetWorkerIndex &Index() const + { + return mIndex; + } + + inline void SetIndex(const UBSHcomNetWorkerIndex &value) + { + mIndex = value; + } + + RResult CreateQP(RDMAQp *&qp); + + RResult PostReceive(RDMAQp *qp, uintptr_t bufAddress, uint32_t bufSize, uint32_t localKey); + RResult PostSend(RDMAQp *qp, const RDMASendReadWriteRequest &req, uint32_t immData = 0); + RResult PostSendSglInline( + RDMAQp *qp, const RDMASendSglInlineHeader &header, const RDMASendReadWriteRequest &req, uint32_t immData = 0); + + RResult PostSendSgl(RDMAQp *qp, const RDMASendSglRWRequest &req, const RDMASendReadWriteRequest &tlsReq, + uint32_t immData = 0, bool isEncrypted = false); + RResult PostRead(RDMAQp *qp, const RDMASendReadWriteRequest &req); + RResult PostOneSideSgl(RDMAQp *qp, const RDMASendSglRWRequest &req, bool isRead = true); + RResult PostWrite(RDMAQp *qp, const RDMASendReadWriteRequest &req, + RDMAOpContextInfo::OpType type = RDMAOpContextInfo::WRITE); + RResult CreateOneSideCtx(RDMASgeCtxInfo &sgeInfo, UBSHcomNetTransSgeIov *iov, uint32_t iovCount, + uint64_t (&ctxArr)[NET_SGE_MAX_IOV], bool isRead); + RResult RePostReceive(RDMAOpContextInfo *ctx); + + inline RDMAOpContextInfo *GetOpContextInfo() + { + return mOpCtxInfoPool.Get(); + } + + inline void ReturnOpContextInfo(RDMAOpContextInfo *&ctx) + { + if (NN_LIKELY(ctx != nullptr)) { + if (NN_LIKELY(ctx->qp != nullptr)) { + ctx->qp->DecreaseRef(); + } + mOpCtxInfoPool.Return(ctx); + ctx = nullptr; + } + } + + inline void RegisterNewRequestHandler(const RDMANewReqHandler &handler) + { + mNewRequestHandler = handler; + } + + inline void ReturnSglContextInfo(RDMASglContextInfo *&ctx) + { + if (NN_LIKELY(ctx != nullptr)) { + mSglCtxInfoPool.Return(ctx); + ctx = nullptr; + } + } + + inline void RegisterPostedHandler(const RDMAPostedHandler &handler) + { + mSendPostedHandler = handler; + } + + inline void RegisterIdleHandler(const RDMAIdleHandler &handler) + { + mIdleHandler = handler; + } + + inline void RegisterOneSideDoneHandler(const RDMAOneSideDoneHandler &handler) + { + mOneSideDoneHandler = handler; + } + + inline const std::string &Name() const + { + return mName; + } + + inline uint8_t PortNum() const + { + return mRDMAContext->mPortNumber; + } + + std::string DetailName() const + { + std::ostringstream oss; + oss << "[name: " << mName << ", index: " << mIndex.ToString() << "]"; + return oss.str(); + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS +public: + static RResult Create(const std::string &name, RDMAContext *ctx, const RDMAWorkerOptions &options, + NetMemPoolFixedPtr memPool, NetMemPoolFixedPtr sglMemPool, RDMAWorker *&outWorker); + +protected: + void RunInThread(); + void DoWithBusyPolling(); + void DoWithCQEventPolling(); + +protected: + std::string mName; + UBSHcomNetWorkerIndex mIndex {}; + RDMAContext *mRDMAContext = nullptr; + RDMACq *mRDMACq = nullptr; + NetMemPoolFixedPtr mOpCtxMemPool = nullptr; + NetMemPoolFixedPtr mSglCtxMemPool = nullptr; + bool mInited = false; + + RDMAWorkerOptions mOptions {}; + + // variable for thread + std::thread mProgressThread; + std::atomic_bool mProgressThreadStarted; + std::atomic_bool mThreadStop; + int16_t mProgressCpuId = -1; + bool mNeedStop = false; + + RDMAOpContextInfoPool mOpCtxInfoPool; + RDMASglContextInfoPool mSglCtxInfoPool; + + // request process related + RDMANewReqHandler mNewRequestHandler = nullptr; + + // send request posted process related + RDMAPostedHandler mSendPostedHandler = nullptr; + + // one side done related + RDMAOneSideDoneHandler mOneSideDoneHandler = nullptr; + + // no request will this + RDMAIdleHandler mIdleHandler = nullptr; + + int mProgressBatchSize = NN_NO4; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend class RDMAQp; +}; +} +} +#endif +#endif // OCK_RDMA_WORKER_1234341456433_H \ No newline at end of file diff --git a/src/transport/rdma/verbs/rdma_worker_core.cpp b/src/transport/rdma/verbs/rdma_worker_core.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d02ef9938823982310a31164dde8716cff51d466 --- /dev/null +++ b/src/transport/rdma/verbs/rdma_worker_core.cpp @@ -0,0 +1,451 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#include +#include + +#include "hcom_utils.h" +#include "net_common.h" +#include "rdma_worker.h" +#include "net_rdma_async_endpoint.h" + +namespace ock { +namespace hcom { +std::string &WorkerTypeToString(RDMAWorkerType tp) +{ + static std::string workerTypeString[3] = {"sender", "receiver", "sender&receiver"}; + static std::string unknownWorkerType = "unknown worker type"; + if (tp != SENDER && tp != RECEIVER && tp != SENDER_RECEIVER) { + return unknownWorkerType; + } + return workerTypeString[tp]; +} + +std::string &PollingModeToString(RDMAPollingMode m) +{ + static std::string workerModeString[2] = {"busy_polling", "cq_event_polling"}; + static std::string unknownWorkerMode = "unknown worker mode"; + if (m != BUSY_POLLING && m != EVENT_POLLING) { + return unknownWorkerMode; + } + return workerModeString[m]; +} + +RDMAWorker::RDMAWorker(const std::string &name, RDMAContext *ctx, const RDMAWorkerOptions &options, + const NetMemPoolFixedPtr &memPool, const NetMemPoolFixedPtr &sglMemPool) + : mName(name), + mRDMAContext(ctx), + mOpCtxMemPool(memPool), + mSglCtxMemPool(sglMemPool), + mOptions(options), + mProgressThreadStarted(false) +{ + if (mRDMAContext != nullptr) { + mRDMAContext->IncreaseRef(); + } + mThreadStop.store(false); + mProgressCpuId = options.cpuId; + mProgressBatchSize = options.pollingBatchSize; + OBJ_GC_INCREASE(RDMAWorker); +} + +RResult RDMAWorker::Initialize() +{ + if (mInited) { + return RR_OK; + } + + if (mRDMAContext == nullptr || mRDMAContext->mContext == nullptr) { + NN_LOG_ERROR("RDMA Context is null, probably not initialized"); + return RR_PARAM_INVALID; + } + + // create and init CQ + auto tmpCQ = new (std::nothrow) + RDMACq(DetailName(), mRDMAContext, mOptions.workerMode == EVENT_POLLING, reinterpret_cast(this)); + if (tmpCQ == nullptr) { + NN_LOG_ERROR("Verbs Failed to new RDMACq in RDMAWorker " << DetailName() << ", probably out of memory"); + return RR_NEW_OBJECT_FAILED; + } + + tmpCQ->SetCQCount(mOptions.completionQueueDepth); + + RResult result = RR_OK; + if ((result = tmpCQ->Initialize()) != RR_OK) { + NN_LOG_ERROR("Verbs Failed to initialize RDMACq in RDMAWorker " << DetailName() << ", result " << result); + delete tmpCQ; + return result; + } + + if ((result = mOpCtxInfoPool.Initialize(mOpCtxMemPool)) != RR_OK) { + NN_LOG_ERROR("Verbs Failed to initialize operation context info pool in RDMAWorker " << DetailName()); + delete tmpCQ; + return result; + } + + if ((result = mSglCtxInfoPool.Initialize(mSglCtxMemPool)) != RR_OK) { + NN_LOG_ERROR("Verbs Failed to initialize sgl context info pool in RDMAWorker " << DetailName()); + delete tmpCQ; + return result; + } + + mRDMACq = tmpCQ; + mRDMACq->IncreaseRef(); + mInited = true; + + return RR_OK; +} + +RResult RDMAWorker::UnInitialize() +{ + if (!mInited) { + return RR_OK; + } + + if (mRDMACq != nullptr) { + mRDMACq->DecreaseRef(); + mRDMACq = nullptr; + } + + if (mRDMAContext != nullptr) { + mRDMAContext->DecreaseRef(); + mRDMAContext = nullptr; + } + + if (mOpCtxMemPool != nullptr) { + mOpCtxMemPool.Set(nullptr); + } + + mOpCtxInfoPool.UnInitialize(); + + mInited = false; + return RR_OK; +} + +RResult RDMAWorker::ReInitializeCQ() +{ + if (!mInited) { + return RR_OK; + } + + if (mRDMACq != nullptr) { + mRDMACq->DecreaseRef(); + mRDMACq = nullptr; + } + + // create and init CQ + auto tmpCQ = new (std::nothrow) + RDMACq(DetailName(), mRDMAContext, mOptions.workerMode == EVENT_POLLING, reinterpret_cast(this)); + if (tmpCQ == nullptr) { + NN_LOG_ERROR("Failed to new RDMACq in RDMAWorker " << DetailName() << ", probably out of memory"); + return RR_NEW_OBJECT_FAILED; + } + + tmpCQ->SetCQCount(mOptions.completionQueueDepth); + + RResult result = RR_OK; + if ((result = tmpCQ->Initialize()) != RR_OK) { + delete tmpCQ; + tmpCQ = nullptr; + NN_LOG_ERROR("Failed to initialize RDMACq in RDMAWorker " << DetailName() << ", result " << result); + return result; + } + + mRDMACq = tmpCQ; + mRDMACq->IncreaseRef(); + + return RR_OK; +} + +RResult RDMAWorker::Start() +{ + if (!mInited) { + NN_LOG_ERROR("Failed to start RDMAWorker " << DetailName() << " as not initialized"); + return RR_WORKER_NOT_INITIALIZED; + } + + if (mThreadStop.load()) { + NN_LOG_ERROR("Failed to start RDMAWorker " << DetailName() << "worker thread not stop"); + return RR_WORKER_START_ERROR; + } + + if (mOptions.dontStartWorkers) { + NN_LOG_INFO("Do not start workers " << DetailName()); + return RR_OK; + } + + if ((mOptions.workerType == RECEIVER || mOptions.workerType == SENDER_RECEIVER) && mNewRequestHandler == nullptr) { + NN_LOG_ERROR("New request handler is not registered yet in RDMAWorker " << DetailName()); + return RR_WORKER_REQUEST_HANDLER_NOT_SET; + } + + if ((mOptions.workerType == SENDER || mOptions.workerType == SENDER_RECEIVER) && mSendPostedHandler == nullptr) { + NN_LOG_ERROR("Send request posted handler is not registered yet in RDMAWorker " << DetailName()); + return RR_WORKER_SEND_POSTED_HANDLER_NOT_SET; + } + + if (mOneSideDoneHandler == nullptr) { + NN_LOG_WARN("One side done handler is not registered yet in RDMAWorker " << DetailName()); + } + + mNeedStop = false; + std::thread tmpThread(&RDMAWorker::RunInThread, this); + mProgressThread = std::move(tmpThread); + std::string threadName = "RDMAWkr" + mIndex.ToString(); + if (pthread_setname_np(mProgressThread.native_handle(), threadName.c_str()) != 0) { + NN_LOG_WARN("Unable to set name of RDMAWorker progress thread"); + } + + if (mProgressCpuId != -1) { + cpu_set_t cpuSet; + CPU_ZERO(&cpuSet); + CPU_SET(mProgressCpuId, &cpuSet); + if (pthread_setaffinity_np(mProgressThread.native_handle(), sizeof(cpuSet), &cpuSet) != 0) { + NN_LOG_WARN("Unable to bind RDMAWorker" << mIndex.ToString() << " << to cpu " << mProgressCpuId); + } + } + + while (!mProgressThreadStarted.load()) { + usleep(NN_NO10); + } + mThreadStop.store(true); + return RR_OK; +} + +RResult RDMAWorker::Stop() +{ + if (!mThreadStop.load()) { + return RR_OK; + } + mNeedStop = true; + if (mProgressThread.native_handle()) { + mProgressThread.join(); + } + mThreadStop.store(false); + return RR_OK; +} + +#define BUSY_POLLING() \ + if (NN_UNLIKELY(mRDMACq->ProgressV(wc, pollCount) != RR_OK)) { \ + /* timeout return 0, count = 0, will invoke PROCESS_POLLING_RESULT() idle */ \ + /* do later */ \ + continue; \ + } + +#define CQ_EVENT_POLLING() \ + if (NN_UNLIKELY(mRDMACq->EventProgressV(wc, pollCount, pollTimeOut) != RR_OK)) { \ + /* timeout need invoke idle */ \ + if (mIdleHandler != nullptr) { \ + mIdleHandler(mIndex); \ + } \ + /* do later */ \ + continue; \ + } + +#define PROCESS_POLLING_RESULT(pollCount, contextInfo, lastBrokenQp) \ + do { \ + for (int i = 0; i < (pollCount); i++) { \ + (contextInfo) = reinterpret_cast(wc[i].wr_id); \ + if ((contextInfo)->qpNum != wc[i].qp_num || \ + (contextInfo)->opResultType == RDMAOpContextInfo::INVALID_MAGIC) { \ + continue; \ + } \ + (contextInfo)->opResultType = RDMAOpContextInfo::OpResult(wc[i]); \ + if (NN_LIKELY(wc[i].status == IBV_WC_SUCCESS)) { \ + /* detach the context */ \ + (contextInfo)->qp->RemoveOpCtxInfo(contextInfo); \ + } else { \ + if ((contextInfo)->opType == RDMAOpContextInfo::HB_WRITE) { \ + (lastBrokenQp) = (contextInfo)->qp; \ + NN_LOG_INFO("HB poll cq receive wcStatus " << wc[i].status << ", maybe remote ep " << \ + (contextInfo)->qp->UpId() << " closed"); \ + } else if (((contextInfo)->qp->isStarted) && ((lastBrokenQp) != (contextInfo)->qp)) { \ + (lastBrokenQp) = (contextInfo)->qp; \ + NN_LOG_ERROR("Poll cq failed in RDMAWorker " << DetailName() << ", wcStatus " << wc[i].status << \ + ", opType " << (contextInfo)->opType << ", ep id " << (contextInfo)->qp->UpId()); \ + } else if (((contextInfo)->qp->isStarted) && (lastErrorWcStatus != wc[i].status)) { \ + lastErrorWcStatus = wc[i].status; \ + NN_LOG_ERROR("Poll cq failed in RDMAWorker " << DetailName() << ", wc Status " << wc[i].status << \ + ", opType " << (contextInfo)->opType << ", ep id " << (contextInfo)->qp->UpId()); \ + } \ + } \ + \ + auto asyncEp = reinterpret_cast(contextInfo->qp->UpContext()); \ + asyncEp->UpdateTargetHbTime(); \ + switch ((contextInfo)->opType) { \ + case (RDMAOpContextInfo::OpType::SEND): \ + case (RDMAOpContextInfo::OpType::SEND_RAW): \ + case (RDMAOpContextInfo::OpType::SEND_RAW_SGL): \ + case (RDMAOpContextInfo::OpType::SEND_SGL_INLINE): \ + mSendPostedHandler(contextInfo); \ + break; \ + case (RDMAOpContextInfo::OpType::RECEIVE): \ + /* NOTE, up context is store imm data */ \ + (contextInfo)->dataSize = wc[i].byte_len; \ + *((int32_t *)(void *)&((contextInfo)->upCtx)) = wc[i].imm_data; \ + mNewRequestHandler(contextInfo); \ + break; \ + case (RDMAOpContextInfo::OpType::WRITE): \ + case (RDMAOpContextInfo::OpType::SGL_WRITE): \ + case (RDMAOpContextInfo::OpType::HB_WRITE): \ + case (RDMAOpContextInfo::OpType::READ): \ + case (RDMAOpContextInfo::OpType::SGL_READ): \ + mOneSideDoneHandler(contextInfo); \ + break; \ + default: \ + NN_LOG_ERROR("Poll cq invalid OpType " << (contextInfo)->opType); \ + break; \ + } \ + } \ + \ + /* if there is no coming request, call up idle function */ \ + if (mIdleHandler != nullptr && (pollCount) == 0) { \ + mIdleHandler(mIndex); \ + } \ + } while (0) + + +void RDMAWorker::DoWithBusyPolling() +{ + // allocate wc vector + auto *wc = static_cast(calloc(mProgressBatchSize, sizeof(struct ibv_wc))); + if (wc == nullptr) { + NN_LOG_ERROR("Failed to allocate wc in RDMAWorker " << DetailName() << ", thread exiting"); + return; + } + + int pollCount = 0; + RDMAOpContextInfo *contextInfo = nullptr; + + RDMAQp *lastBrokenQp = nullptr; + enum ibv_wc_status lastErrorWcStatus = IBV_WC_SUCCESS; + + while (!mNeedStop) { + try { + pollCount = mProgressBatchSize; + BUSY_POLLING() + TRACE_DELAY_BEGIN(RDMA_WORKER_EVENT_POLLING); + PROCESS_POLLING_RESULT(pollCount, contextInfo, lastBrokenQp); + TRACE_DELAY_END(RDMA_WORKER_EVENT_POLLING, 0); + } catch (std::runtime_error &ex) { + NN_LOG_WARN("Verbs Got runtime incorrect signal in RDMAWorker::RunInThread '" << ex.what() << + "', ignore and continue"); + } catch (...) { + NN_LOG_WARN("Verbs Got unknown signal in RDMAWorker::RunInThread, ignore and continue"); + } + } + + free(wc); + wc = nullptr; +} + +void RDMAWorker::DoWithCQEventPolling() +{ + // allocate wc vector + auto *wc = static_cast(calloc(mProgressBatchSize, sizeof(struct ibv_wc))); + if (wc == nullptr) { + NN_LOG_ERROR("Failed to allocate wc in RDMAWorker " << DetailName() << ", thread exiting"); + return; + } + + int pollCount = 0; + uint32_t pollTimeOut = 0; + RDMAOpContextInfo *contextInfo = nullptr; + + RDMAQp *lastBrokenQp = nullptr; + enum ibv_wc_status lastErrorWcStatus = IBV_WC_SUCCESS; + + while (!mNeedStop) { + try { + pollTimeOut = mOptions.eventPollingTimeout; + pollCount = mProgressBatchSize; + CQ_EVENT_POLLING() + TRACE_DELAY_BEGIN(RDMA_WORKER_EVENT_POLLING); + PROCESS_POLLING_RESULT(pollCount, contextInfo, lastBrokenQp); + TRACE_DELAY_END(RDMA_WORKER_EVENT_POLLING, 0); + } catch (std::runtime_error &ex) { + NN_LOG_WARN("Got runtime incorrect signal in RDMAWorker::RunInThread '" << ex.what() << + "', ignore and continue"); + } catch (...) { + NN_LOG_WARN("Got unknown signal in RDMAWorker::RunInThread, ignore and continue"); + } + } + + free(wc); + wc = nullptr; +} + +void RDMAWorker::RunInThread() +{ + if (mOptions.threadPriority != 0) { + if (NN_UNLIKELY(setpriority(PRIO_PROCESS, 0, mOptions.threadPriority) != 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_WARN("Unable to set worker thread priority in rdma worker " << mName << ", errno:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + } + } + + mProgressThreadStarted.store(true); + NN_LOG_INFO("RDMAWorker " << DetailName() << ", cpuId: " << mProgressCpuId << ", cq count: " << + ((mRDMACq != nullptr) ? mRDMACq->GetCQCount() : 0) << ", polling batch size: " << mProgressBatchSize << + ", more " << mOptions.ToString() << "] working thread started"); + + if (mOptions.workerMode == BUSY_POLLING) { + DoWithBusyPolling(); + } else if (mOptions.workerMode == EVENT_POLLING) { + DoWithCQEventPolling(); + } else { + NN_LOG_ERROR("Un-reachable"); + } + + NN_LOG_INFO("RDMAWorker " << DetailName() << " working thread exiting"); +} + +RResult RDMAWorker::Create(const std::string &name, RDMAContext *ctx, const RDMAWorkerOptions &options, + NetMemPoolFixedPtr memPool, NetMemPoolFixedPtr sglMemPool, RDMAWorker *&outWorker) +{ + if (ctx == nullptr || name.empty()) { + NN_LOG_ERROR("Create worker param invalid"); + return RR_PARAM_INVALID; + } + + auto tmp = new (std::nothrow) RDMAWorker(name, ctx, options, std::move(memPool), std::move(sglMemPool)); + if (tmp == nullptr) { + NN_LOG_ERROR("Failed to create RDMAWorker, probably out of memory"); + return RR_NEW_OBJECT_FAILED; + } + + outWorker = tmp; + return RR_OK; +} + +RResult RDMAWorker::CreateQP(RDMAQp *&qp) +{ + if (NN_UNLIKELY(!mInited)) { + NN_LOG_ERROR("Failed to create qp with RDMAWorker " << DetailName() << " as not initialized"); + return RR_WORKER_NOT_INITIALIZED; + } + + QpOptions qpOptions(mOptions.qpSendQueueSize, mOptions.qpReceiveQueueSize, mOptions.qpMrSegSize, + mOptions.qpMrSegCount); + qp = new (std::nothrow) RDMAQp(DetailName(), RDMAQp::NewId(), mRDMAContext, mRDMACq, qpOptions); + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Failed to create qp with RDMAWorker " << DetailName() << ", probably out of memory"); + return RR_NEW_OBJECT_FAILED; + } + + qp->UpContext1(reinterpret_cast(this)); + return RR_OK; +} +} +} +#endif \ No newline at end of file diff --git a/src/transport/rdma/verbs/rdma_worker_io.cpp b/src/transport/rdma/verbs/rdma_worker_io.cpp new file mode 100644 index 0000000000000000000000000000000000000000..940d264c5821dc8c79e28ec2b7cf456d7ea03ff4 --- /dev/null +++ b/src/transport/rdma/verbs/rdma_worker_io.cpp @@ -0,0 +1,505 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#include +#include + +#include "hcom_utils.h" +#include "net_common.h" +#include "rdma_worker.h" +#include "net_rdma_async_endpoint.h" + +namespace ock { +namespace hcom { + +RResult RDMAWorker::PostReceive(RDMAQp *qp, uintptr_t bufAddress, uint32_t bufSize, uint32_t localKey) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Failed to PostReceive with RDMAWorker " << DetailName() << " as qp is null"); + return RR_PARAM_INVALID; + } + + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to PostReceive with RDMAWorker " << DetailName() << " as no ctx left"); + return RR_QP_CTX_FULL; + } + + /* set to all 0 */ + bzero(ctx, sizeof(RDMAOpContextInfo)); + ctx->qp = qp; + ctx->mrMemAddr = bufAddress; + ctx->dataSize = bufSize; + ctx->qpNum = qp->QpNum(); + ctx->lKey = localKey; + ctx->opType = RDMAOpContextInfo::RECEIVE; + ctx->opResultType = RDMAOpContextInfo::SUCCESS; + qp->IncreaseRef(); + + // attach context to qp firstly, because post could be finished very fast + // if posted failed, need to remove + qp->AddOpCtxInfo(ctx); + + auto res = qp->PostReceive(bufAddress, bufSize, localKey, reinterpret_cast(ctx)); + if (NN_UNLIKELY(res != RR_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + qp->DecreaseRef(); + qp->RemoveOpCtxInfo(ctx); + mOpCtxInfoPool.Return(ctx); + } + + // ctx could not be used if post successfully + return res; +} + +RResult RDMAWorker::RePostReceive(RDMAOpContextInfo *ctx) +{ + if (NN_UNLIKELY(ctx == nullptr || ctx->qp == nullptr)) { + NN_LOG_ERROR("Failed to RePostReceive with RDMAWorker " << DetailName() << " as ctx or its qp is null"); + return RR_PARAM_INVALID; + } + + // attach context to qp firstly, because post could be finished very fast + // if posted failed, need to remove + ctx->qp->AddOpCtxInfo(ctx); + + auto result = + ctx->qp->PostReceive(ctx->mrMemAddr, mOptions.qpMrSegSize, ctx->lKey, reinterpret_cast(ctx)); + if (NN_UNLIKELY(result != RR_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + ctx->qp->DecreaseRef(); + ctx->qp->RemoveOpCtxInfo(ctx); + mOpCtxInfoPool.Return(ctx); + } + + // ctx could not be used if post successfully + return result; +} + +RResult RDMAWorker::PostSend(RDMAQp *qp, const RDMASendReadWriteRequest &req, uint32_t immData) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Verbs Failed to PostSend with RDMAWorker " << DetailName() << " as qp is null"); + return RR_PARAM_INVALID; + } + + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Verbs Failed to PostSend with RDMAWorker " << DetailName() << " as no reqInfo left"); + return RR_QP_CTX_FULL; + } + + if (NN_UNLIKELY(!qp->GetPostSendWr())) { + NN_LOG_ERROR("Verbs Failed to PostSend with RDMAWorker " << DetailName() << " as no post send wr left"); + mOpCtxInfoPool.Return(ctx); + return RR_QP_POST_SEND_WR_FULL; + } + ctx->qp = qp; + ctx->mrMemAddr = req.lAddress; + ctx->dataSize = req.size; + ctx->qpNum = qp->QpNum(); + // Prevent integer truncation, safely converts uint64_t to uint32_t + if (NN_UNLIKELY(req.lKey > UINT32_MAX)) { + NN_LOG_ERROR("Failed to PostSend with RDMAWorker as lKey is larger than uint32max, lkey" << req.lKey); + return RR_PARAM_INVALID; + } + ctx->lKey = static_cast(req.lKey); + ctx->opType = immData == 0 ? RDMAOpContextInfo::SEND : RDMAOpContextInfo::SEND_RAW; + ctx->opResultType = RDMAOpContextInfo::SUCCESS; + ctx->upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0 && NN_UNLIKELY(memcpy_s(ctx->upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != RR_OK)) { + NN_LOG_ERROR("Failed to copy req to ctx"); + return RR_PARAM_INVALID; + } + qp->IncreaseRef(); + + // attach context to qp firstly, because post could be finished very fast + // if posted failed, need to remove + qp->AddOpCtxInfo(ctx); + + auto result = qp->PostSend(req.lAddress, req.size, static_cast(req.lKey), + reinterpret_cast(ctx), immData); + if (NN_UNLIKELY(result != RR_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + qp->ReturnPostSendWr(); + qp->DecreaseRef(); + qp->RemoveOpCtxInfo(ctx); + mOpCtxInfoPool.Return(ctx); + } + + // ctx could not be used if post successfully + return result; +} + +RResult RDMAWorker::PostSendSglInline( + RDMAQp *qp, const RDMASendSglInlineHeader &header, const RDMASendReadWriteRequest &req, uint32_t immData) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("RDMA Failed to PostSendSgl with RDMAWorker " << DetailName() << " as qp is null"); + return RR_PARAM_INVALID; + } + + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("RDMA Failed to PostSendSgl with RDMAWorker " << DetailName() << " as no reqInfo left"); + return RR_QP_CTX_FULL; + } + + if (NN_UNLIKELY(!qp->GetPostSendWr())) { + NN_LOG_ERROR("RDMA Failed to PostSendSgl with RDMAWorker " << DetailName() << " as no post send wr left"); + mOpCtxInfoPool.Return(ctx); + return RR_QP_POST_SEND_WR_FULL; + } + ctx->qp = qp; + ctx->mrMemAddr = req.lAddress; + ctx->dataSize = req.size; + ctx->qpNum = qp->QpNum(); + ctx->lKey = req.lKey; + ctx->opType = RDMAOpContextInfo::SEND_SGL_INLINE; + ctx->opResultType = RDMAOpContextInfo::SUCCESS; + ctx->upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0 && NN_UNLIKELY(memcpy_s(ctx->upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != RR_OK)) { + NN_LOG_ERROR("Failed to copy request to ctx"); + return RR_PARAM_INVALID; + } + qp->IncreaseRef(); + + // attach context to qp firstly, because post could be finished very fast + // if posted failed, need to remove + qp->AddOpCtxInfo(ctx); + + UBSHcomNetTransDataIov netTransDataIov[NN_NO2]; + netTransDataIov[NN_NO0].address = reinterpret_cast(&header); + netTransDataIov[NN_NO0].size = sizeof(RDMASendSglInlineHeader); + netTransDataIov[NN_NO1].address = req.lAddress; + netTransDataIov[NN_NO1].size = req.size; + + auto result = qp->PostSendSglInline( + netTransDataIov, NN_NO2, reinterpret_cast(ctx), immData); + if (NN_UNLIKELY(result != RR_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + qp->ReturnPostSendWr(); + qp->DecreaseRef(); + qp->RemoveOpCtxInfo(ctx); + mOpCtxInfoPool.Return(ctx); + } + + // ctx could not be used if post successfully + return result; +} + +RResult RDMAWorker::PostSendSgl(RDMAQp *qp, const RDMASendSglRWRequest &req, const RDMASendReadWriteRequest &tlsReq, + uint32_t immData, bool isEncrypted) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Failed to PostRead with RDMAWorker " << DetailName() << " as qp is null"); + return RR_PARAM_INVALID; + } + + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to PostRead with RDMAWorker " << DetailName() << " as no ctx left"); + return RR_PARAM_INVALID; + } + + sglCtx->qp = qp; + sglCtx->result = RR_OK; + if (NN_UNLIKELY(memcpy_s(sglCtx->iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, req.iov, + sizeof(UBSHcomNetTransSgeIov) * req.iovCount) != RR_OK)) { + NN_LOG_ERROR("Failed to copy request to sglCtx"); + return RR_PARAM_INVALID; + } + sglCtx->iovCount = req.iovCount; + sglCtx->upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(sglCtx->upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != RR_OK)) { + NN_LOG_ERROR("Failed to copy request to sglCtx"); + return RR_PARAM_INVALID; + } + } + + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to PostSend with RDMAWorker " << DetailName() << " as no reqInfo left"); + return RR_QP_CTX_FULL; + } + + if (NN_UNLIKELY(!qp->GetPostSendWr())) { + NN_LOG_ERROR("Failed to PostSend with RDMAWorker " << DetailName() << " as no post send wr left"); + mOpCtxInfoPool.Return(ctx); + return RR_QP_POST_SEND_WR_FULL; + } + ctx->qp = qp; + + // if not encrypt reqTls lAddress\size\lKey is 0 + ctx->mrMemAddr = tlsReq.lAddress; + ctx->dataSize = tlsReq.size; + // Prevent integer truncation, safely converts uint64_t to uint32_t + if (NN_UNLIKELY(tlsReq.lKey > UINT32_MAX)) { + NN_LOG_ERROR("Failed to PostSendSgl with RDMAWorker as lKey is larger than uint32max, lkey" << tlsReq.lKey); + return RR_PARAM_INVALID; + } + ctx->lKey = static_cast(tlsReq.lKey); + ctx->qpNum = qp->QpNum(); + ctx->opType = RDMAOpContextInfo::SEND_RAW_SGL; + ctx->opResultType = RDMAOpContextInfo::SUCCESS; + ctx->upCtxSize = static_cast(sizeof(RDMASgeCtxInfo)); + auto upCtx = static_cast((void *)&(ctx->upCtx)); + upCtx->ctx = sglCtx; + qp->IncreaseRef(); + + // attach context to qp firstly, because post could be finished very fast + // if posted failed, need to remove + qp->AddOpCtxInfo(ctx); + + RResult result = RR_OK; + if (isEncrypted != 0) { + result = qp->PostSend(tlsReq.lAddress, tlsReq.size, static_cast(tlsReq.lKey), + reinterpret_cast(ctx), immData); + } else { + result = qp->PostSendSgl(req.iov, req.iovCount, reinterpret_cast(ctx), immData); + } + + if (NN_UNLIKELY(result != RR_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + qp->ReturnPostSendWr(); + qp->RemoveOpCtxInfo(ctx); + qp->DecreaseRef(); + mOpCtxInfoPool.Return(ctx); + mSglCtxInfoPool.Return(sglCtx); + } + + return result; +} + +RResult RDMAWorker::PostRead(RDMAQp *qp, const RDMASendReadWriteRequest &req) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Failed to PostRead with RDMAWorker " << DetailName() << " as qp is null"); + return RR_PARAM_INVALID; + } + + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to PostRead with RDMAWorker " << DetailName() << " as no reqInfo left"); + return RR_QP_CTX_FULL; + } + + if (NN_UNLIKELY(!qp->GetOneSideWr())) { + NN_LOG_ERROR("Failed to PostSend with RDMAWorker " << DetailName() << " as no one side wr left"); + mOpCtxInfoPool.Return(ctx); + return RR_QP_ONE_SIDE_WR_FULL; + } + ctx->mrMemAddr = req.lAddress; + ctx->qp = qp; + ctx->dataSize = req.size; + ctx->qpNum = qp->QpNum(); + // Prevent integer truncation, safely converts uint64_t to uint32_t + if (NN_UNLIKELY(req.lKey > UINT32_MAX || req.rKey > UINT32_MAX)) { + NN_LOG_ERROR("Failed to PostRead with RDMAWorker as Key is larger than uint32max, lkey" << + req.lKey << " rKey " << req.rKey); + return RR_PARAM_INVALID; + } + ctx->lKey = static_cast(req.lKey); + ctx->opType = RDMAOpContextInfo::READ; + ctx->opResultType = RDMAOpContextInfo::SUCCESS; + ctx->upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(ctx->upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != RR_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return RR_PARAM_INVALID; + } + } + qp->IncreaseRef(); + + // attach context to qp firstly, because post could be finished very fast + // if posted failed, need to remove + qp->AddOpCtxInfo(ctx); + + auto result = qp->PostRead(req.lAddress, static_cast(req.lKey), req.rAddress, + static_cast(req.rKey), req.size, reinterpret_cast(ctx)); + if (NN_UNLIKELY(result != RR_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + qp->ReturnOneSideWr(); + qp->DecreaseRef(); + qp->RemoveOpCtxInfo(ctx); + mOpCtxInfoPool.Return(ctx); + } + + // ctx could not be used if post successfully + return result; +} + +RResult RDMAWorker::PostOneSideSgl(RDMAQp *qp, const RDMASendSglRWRequest &req, bool isRead) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Failed to oneSide operation with RDMAWorker " << DetailName() << " as qp is null"); + return RR_PARAM_INVALID; + } + + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to oneSide operation with RDMAWorker " << DetailName() << " as no ctx left"); + return RR_PARAM_INVALID; + } + + sglCtx->result = RR_OK; + sglCtx->qp = qp; + if (NN_UNLIKELY(memcpy_s(sglCtx->iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, req.iov, + sizeof(UBSHcomNetTransSgeIov) * req.iovCount) != RR_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + mSglCtxInfoPool.Return(sglCtx); + return RR_PARAM_INVALID; + } + sglCtx->upCtxSize = req.upCtxSize; + sglCtx->iovCount = req.iovCount; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(sglCtx->upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != RR_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + mSglCtxInfoPool.Return(sglCtx); + return RR_PARAM_INVALID; + } + } + sglCtx->refCount = 0; + RDMASgeCtxInfo sgeInfo(sglCtx); + uint64_t ctxArr[NET_SGE_MAX_IOV]; + RResult result = CreateOneSideCtx(sgeInfo, req.iov, req.iovCount, ctxArr, isRead); + if (result != RR_OK) { + NN_LOG_ERROR("Failed to create one side ctx."); + mSglCtxInfoPool.Return(sglCtx); + return result; + } + + result = qp->PostOneSideSgl(req.iov, req.iovCount, ctxArr, isRead); + if (NN_UNLIKELY(result != RR_OK)) { + for (int i = 0; i < req.iovCount; ++i) { + qp->ReturnOneSideWr(); + qp->RemoveOpCtxInfo(reinterpret_cast(ctxArr[i])); + qp->DecreaseRef(); + mOpCtxInfoPool.Return(reinterpret_cast(ctxArr[i])); + } + mSglCtxInfoPool.Return(sglCtx); + } + + return result; +} + +RResult RDMAWorker::CreateOneSideCtx(RDMASgeCtxInfo &sgeInfo, UBSHcomNetTransSgeIov *iov, uint32_t iovCount, + uint64_t (&ctxArr)[NET_SGE_MAX_IOV], bool isRead) +{ + if (iov == nullptr || iovCount == NN_NO0 || iovCount > NN_NO4 || ctxArr == nullptr) { + NN_LOG_ERROR("Failed to create oneSide operation ctx because param invalid"); + return RR_PARAM_INVALID; + } + for (uint32_t i = 0; i < iovCount; ++i) { + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Verbs failed to oneSide operation with RDMAWorker " << DetailName() << " as no ctx left"); + for (uint32_t j = 0; j < i; ++j) { + sgeInfo.ctx->qp->ReturnOneSideWr(); + sgeInfo.ctx->qp->RemoveOpCtxInfo(reinterpret_cast(ctxArr[j])); + sgeInfo.ctx->qp->DecreaseRef(); + mOpCtxInfoPool.Return(reinterpret_cast(ctxArr[j])); + } + return RR_QP_CTX_FULL; + } + + if (NN_UNLIKELY(!sgeInfo.ctx->qp->GetOneSideWr())) { + NN_LOG_ERROR("Verbs failed to oneSide operation with RDMAWorker " << DetailName() << + " as no one side wr left"); + mOpCtxInfoPool.Return(ctx); + for (uint32_t j = 0; j < i; ++j) { + sgeInfo.ctx->qp->ReturnOneSideWr(); + sgeInfo.ctx->qp->RemoveOpCtxInfo(reinterpret_cast(ctxArr[j])); + sgeInfo.ctx->qp->DecreaseRef(); + mOpCtxInfoPool.Return(reinterpret_cast(ctxArr[j])); + } + return RR_QP_ONE_SIDE_WR_FULL; + } + ctx->qp = sgeInfo.ctx->qp; + ctx->mrMemAddr = iov[i].lAddress; + ctx->dataSize = iov[i].size; + ctx->qpNum = sgeInfo.ctx->qp->QpNum(); + ctx->lKey = static_cast(iov[i].lKey); + ctx->opType = isRead ? RDMAOpContextInfo::SGL_READ : RDMAOpContextInfo::SGL_WRITE; + ctx->opResultType = RDMAOpContextInfo::SUCCESS; + ctx->upCtxSize = static_cast(sizeof(RDMASgeCtxInfo)); + auto upCtx = static_cast((void *)&(ctx->upCtx)); + upCtx->ctx = sgeInfo.ctx; + upCtx->idx = i; + + sgeInfo.ctx->qp->IncreaseRef(); + sgeInfo.ctx->qp->AddOpCtxInfo(ctx); + ctxArr[i] = reinterpret_cast(ctx); + } + return RR_OK; +} + +RResult RDMAWorker::PostWrite(RDMAQp *qp, const RDMASendReadWriteRequest &req, RDMAOpContextInfo::OpType type) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Failed to PostWrite with RDMAWorker " << DetailName() << " as qp is null"); + return RR_PARAM_INVALID; + } + + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to PostWrite with RDMAWorker " << DetailName() << " as no ctx left"); + return RR_QP_CTX_FULL; + } + if (NN_UNLIKELY(!qp->GetOneSideWr())) { + NN_LOG_ERROR("Failed to PostWrite with RDMAWorker " << DetailName() << " as no one side wr left"); + mOpCtxInfoPool.Return(ctx); + return RR_QP_ONE_SIDE_WR_FULL; + } + ctx->qp = qp; + ctx->mrMemAddr = req.lAddress; + ctx->dataSize = req.size; + ctx->qpNum = qp->QpNum(); + // Prevent integer truncation, safely converts uint64_t to uint32_t + if (NN_UNLIKELY(req.lKey > UINT32_MAX || req.rKey > UINT32_MAX)) { + NN_LOG_ERROR("Failed to PostWrite with RDMAWorker as Key is larger than uint32max, lkey" << + req.lKey << " rKey " << req.rKey); + return RR_PARAM_INVALID; + } + ctx->lKey = static_cast(req.lKey); + ctx->opType = type; + ctx->upCtxSize = req.upCtxSize; + ctx->opResultType = RDMAOpContextInfo::SUCCESS; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(ctx->upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != RR_OK)) { + NN_LOG_ERROR("Failed to copy request to ctx"); + return RR_PARAM_INVALID; + } + } + qp->IncreaseRef(); + + // attach context to qp firstly, because post could be finished very fast + // if posted failed, need to remove + qp->AddOpCtxInfo(ctx); + + auto result = qp->PostWrite(req.lAddress, static_cast(req.lKey), req.rAddress, + static_cast(req.rKey), req.size, reinterpret_cast(ctx)); + if (NN_UNLIKELY(result != RR_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + qp->ReturnOneSideWr(); + qp->DecreaseRef(); + qp->RemoveOpCtxInfo(ctx); + mOpCtxInfoPool.Return(ctx); + } + + // ctx could not be used if post successfully + return result; +} +} +} +#endif \ No newline at end of file diff --git a/src/transport/shm/net_shm_async_endpoint.cpp b/src/transport/shm/net_shm_async_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..044e46a63746b806049188b2a82cd1295cb96c90 --- /dev/null +++ b/src/transport/shm/net_shm_async_endpoint.cpp @@ -0,0 +1,516 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "shm_validation.h" +#include "hcom_log.h" +#include "net_shm_async_endpoint.h" + +namespace ock { +namespace hcom { +NetAsyncEndpointShm::NetAsyncEndpointShm(uint64_t id, ShmChannel *ch, ShmWorker *worker, NetDriverShmWithOOB *driver, + const UBSHcomNetWorkerIndex &workerIndex, ShmMRHandleMap &handleMap) + : NetEndpointImpl(id, workerIndex), + mShmCh(ch), + mWorker(worker), + mDriver(driver), + mrHandleMap(handleMap) +{ + if (mShmCh != nullptr) { + mShmCh->IncreaseRef(); + } + + if (mWorker != nullptr) { + mWorker->IncreaseRef(); + } + + if (mDriver != nullptr) { + mDriver->IncreaseRef(); + } + + if (mDriver != nullptr && mShmCh != nullptr) { + mSegSize = mDriver->GetOptions().mrSendReceiveSegSize; + mAllowedSize = mSegSize - sizeof(UBSHcomNetTransHeader); + } + + OBJ_GC_INCREASE(NetAsyncEndpointShm); +} + +NetAsyncEndpointShm::~NetAsyncEndpointShm() +{ + if (mShmCh != nullptr) { + mShmCh->DecreaseRef(); + } + + if (mWorker != nullptr) { + mWorker->DecreaseRef(); + } + + if (mDriver != nullptr) { + mDriver->DecreaseRef(); + } + + OBJ_GC_DECREASE(NetAsyncEndpointShm); +} + +NResult NetAsyncEndpointShm::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNO) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendValidation(mState, mId, opCode, request)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to async post send as validate fail"); + return result; + } + + if (NN_UNLIKELY((result = PostSendValidationMaxSize(request, mAllowedSize, mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to async post send as validate size fail"); + return result; + } + + /* get free buffer from channel */ + uintptr_t address = 0; + uint64_t offset = 0; + result = mShmCh->DCGetFreeBuck(address, offset, NN_NO100, mDefaultTimeout); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("Shm Failed to get free buck from shm channel " << mShmCh->Id() << ", result " << result); + return result; + } + + /* copy header */ + auto *header = reinterpret_cast(address); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->seqNo = seqNO == 0 ? NextSeq() : seqNO; + header->opCode = opCode; + header->flags = NTH_TWO_SIDE; + + /* copy message */ + if (mIsNeedEncrypt) { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, reinterpret_cast(request.lAddress), request.size, + reinterpret_cast(address + sizeof(UBSHcomNetTransHeader)), cipherLen)) { + NN_LOG_ERROR("Shm Failed to post send message as encryption failure"); + (void)mShmCh->DCMarkBuckFree(address); + return NN_ENCRYPT_FAILED; + } + header->dataLength = cipherLen; + } else { + header->dataLength = request.size; + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(address + sizeof(UBSHcomNetTransHeader)), + mShmCh->GetSendDCBuckSize() - sizeof(UBSHcomNetTransHeader), + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + (void)mShmCh->DCMarkBuckFree(address); + NN_LOG_ERROR("Failed to copy the request to address"); + return NN_INVALID_PARAM; + } + } + + /* finally fill header crc */ + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + + UBSHcomNetTransRequest innerReq = request; + innerReq.size = sizeof(UBSHcomNetTransHeader) + header->dataLength; + innerReq.lAddress = address; + + uint64_t finishTime = GetFinishTime(); + bool flag = true; + TRACE_DELAY_BEGIN(SHM_EP_ASYNC_POST_SEND); + do { + result = mWorker->PostSend(mShmCh, innerReq, offset, 0, mDefaultTimeout); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_ASYNC_POST_SEND, result); + return NN_OK; + } else if (NetMonotonic::TimeNs() < finishTime && NeedRetry(result) && mDefaultTimeout != 0) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + if (result == SH_SEND_COMPLETION_CALLBACK_FAILURE) { + NN_LOG_WARN("Post send successfully but unable to enqueue sent callback request, result " << result); + return result; + } + + /* mark data buck free if failed to send */ + (void)mShmCh->DCMarkBuckFree(address); + TRACE_DELAY_END(SHM_EP_ASYNC_POST_SEND, result); + NN_LOG_ERROR("Failed to post send request, result " << result); + return result; +} + +NResult NetAsyncEndpointShm::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendValidation(mState, mId, opCode, request)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to async post send as validation fail"); + return result; + } + + if (NN_UNLIKELY((result = PostSendValidationMaxSize(request, mAllowedSize, mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to async post send as validate size fail"); + return result; + } + + /* get free buffer from channel */ + uint64_t offset = 0; + uintptr_t address = 0; + result = mShmCh->DCGetFreeBuck(address, offset, NN_NO100, mDefaultTimeout); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("Shm Failed to get free buck from shm channel " << mShmCh->Id() << "," << "result " << result); + return result; + } + + /* copy header */ + auto *header = reinterpret_cast(address); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->flags = ((uint16_t)opInfo.flags << NN_NO8) | (uint64_t)NTH_TWO_SIDE; + header->timeout = opInfo.timeout; + header->errorCode = opInfo.errorCode; + header->opCode = opCode; + header->seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + + /* copy message */ + if (mIsNeedEncrypt) { + uint32_t cipherLen = mAes.EstimatedEncryptLen(request.size); + if (!mAes.Encrypt(mSecrets, reinterpret_cast(request.lAddress), request.size, + reinterpret_cast(address + sizeof(UBSHcomNetTransHeader)), cipherLen)) { + NN_LOG_ERROR("Failed to post send message as encryption failure"); + (void)mShmCh->DCMarkBuckFree(address); + return NN_ENCRYPT_FAILED; + } + header->dataLength = cipherLen; + } else { + header->dataLength = request.size; + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(address + sizeof(UBSHcomNetTransHeader)), + mShmCh->GetSendDCBuckSize() - sizeof(UBSHcomNetTransHeader), + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + (void)mShmCh->DCMarkBuckFree(address); + NN_LOG_ERROR("Failed to copy request to address"); + return NN_INVALID_PARAM; + } + } + + /* finally fill header crc */ + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + + UBSHcomNetTransRequest innerReq = request; + innerReq.size = sizeof(UBSHcomNetTransHeader) + header->dataLength; + innerReq.lAddress = address; + + // if result is timeout, need to retry + bool flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(SHM_EP_ASYNC_POST_SEND); + do { + result = mWorker->PostSend(mShmCh, innerReq, offset, 0, mDefaultTimeout); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_ASYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + if (result == SH_SEND_COMPLETION_CALLBACK_FAILURE) { + NN_LOG_ERROR("Post send request successfully, failed to send completion callback to owner result " << result); + return result; + } + NN_LOG_ERROR("Failed to post send request, result is " << result); + (void)mShmCh->DCMarkBuckFree(address); + TRACE_DELAY_END(SHM_EP_ASYNC_POST_SEND, result); + return result; +} + +NResult NetAsyncEndpointShm::PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNO) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendRawValidation(mState, mId, request)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to async post send raw as validate fail"); + return result; + } + + if (NN_UNLIKELY((result = PostSendValidationMaxSize(request, mSegSize, mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to async post send raw as validate size fail"); + return result; + } + + /* get free buffer from channel */ + uintptr_t address = 0; + uint64_t offset = 0; + result = mShmCh->DCGetFreeBuck(address, offset, NN_NO100, mDefaultTimeout); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("Shm Failed to get free buck from shm channel " << mShmCh->Id() << ", result " << result); + return result; + } + + UBSHcomNetTransRequest innerReq = request; + innerReq.lAddress = address; + + if (mIsNeedEncrypt) { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, reinterpret_cast(request.lAddress), request.size, + reinterpret_cast(address), cipherLen)) { + NN_LOG_ERROR("Shm Failed to post send message as encryption failure"); + (void)mShmCh->DCMarkBuckFree(address); + return NN_ENCRYPT_FAILED; + } + innerReq.size = cipherLen; + } else { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(address), mShmCh->GetSendDCBuckSize(), + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + (void)mShmCh->DCMarkBuckFree(address); + NN_LOG_ERROR("Shm Failed to copy the request to address"); + return NN_INVALID_PARAM; + } + innerReq.size = request.size; + } + + // if result is timeout, need to retry + uint64_t finishTime = GetFinishTime(); + bool flag = true; + TRACE_DELAY_BEGIN(SHM_EP_ASYNC_POST_SEND_RAW); + do { + result = mWorker->PostSend(mShmCh, innerReq, offset, seqNO, mDefaultTimeout); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_ASYNC_POST_SEND_RAW, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + if (result == SH_SEND_COMPLETION_CALLBACK_FAILURE) { + NN_LOG_ERROR("Post send request successfully, failed to send completion callback to owner,result: " << result); + return result; + } + NN_LOG_ERROR("Failed to post send request, result " << result); + (void)mShmCh->DCMarkBuckFree(address); + TRACE_DELAY_END(SHM_EP_ASYNC_POST_SEND_RAW, result); + return result; +} + +NResult NetAsyncEndpointShm::PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendSglValidation(mState, mId, mDriver, seqNo, request, mSegSize, + mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to async post send raw sgl as validate fail"); + return result; + } + + /* get free buffer from channel */ + uintptr_t address = 0; + uint64_t offset = 0; + result = mShmCh->DCGetFreeBuck(address, offset, NN_NO100, mDefaultTimeout); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("Shm Failed to get free buck from shm channel " << mShmCh->Id() << ", result " << result); + return result; + } + + uint32_t dataLen = 0; + uint32_t iovOffset = 0; + + UBSHcomNetTransRequest innerReq = {}; + innerReq.lAddress = address; + + if (mIsNeedEncrypt) { + for (uint16_t i = 0; i < request.iovCount; i++) { + dataLen += request.iov[i].size; + } + + UBSHcomNetMessage tmpMsg {}; + bool messageReady = tmpMsg.AllocateIfNeed(dataLen); + if (NN_UNLIKELY(!messageReady)) { + NN_LOG_ERROR("Shm Failed to allocate net msg buffer failed"); + (void)mShmCh->DCMarkBuckFree(address); + return NN_MALLOC_FAILED; + } + for (uint16_t i = 0; i < request.iovCount; i++) { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(reinterpret_cast(tmpMsg.mBuf) + iovOffset), + tmpMsg.GetBufLen() - iovOffset, reinterpret_cast(request.iov[i].lAddress), + request.iov[i].size) != NN_OK)) { + NN_LOG_ERROR("Failed to copy request to tmpMsg"); + mShmCh->DCMarkBuckFree(address); + return NN_INVALID_PARAM; + } + iovOffset += request.iov[i].size; + } + + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, tmpMsg.mBuf, dataLen, reinterpret_cast(address), cipherLen)) { + NN_LOG_ERROR("Shm Failed to post send message as encryption failure"); + (void)mShmCh->DCMarkBuckFree(address); + return NN_ENCRYPT_FAILED; + } + + innerReq.size = cipherLen; + } else { + for (uint16_t i = 0; i < request.iovCount; i++) { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(address + iovOffset), + mShmCh->GetSendDCBuckSize() - iovOffset, reinterpret_cast(request.iov[i].lAddress), + request.iov[i].size) != NN_OK)) { + (void)mShmCh->DCMarkBuckFree(address); + NN_LOG_ERROR("Failed to copy request to address"); + return NN_INVALID_PARAM; + } + dataLen += request.iov[i].size; + iovOffset += request.iov[i].size; + } + innerReq.size = dataLen; + } + + uint64_t finishTime = GetFinishTime(); + bool flag = true; + TRACE_DELAY_BEGIN(SHM_EP_ASYNC_POST_SEND_RAW_SGL); + do { + result = mWorker->PostSendRawSgl(mShmCh, innerReq, request, offset, seqNo, mDefaultTimeout); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_ASYNC_POST_SEND_RAW_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + if (result == SH_SEND_COMPLETION_CALLBACK_FAILURE) { + NN_LOG_ERROR("Post send request successfully, failed to send completion callback to owner,result " << result); + return result; + } + NN_LOG_ERROR("Failed to post send request, result " << result); + (void)mShmCh->DCMarkBuckFree(address); + TRACE_DELAY_END(SHM_EP_ASYNC_POST_SEND_RAW_SGL, result); + return result; +} + +NResult NetAsyncEndpointShm::PostRead(const UBSHcomNetTransRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = ReadWriteValidation(mState, mId, mDriver, mShmCh, request)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to async post read as validate fail"); + return result; + } + + uint64_t finishTime = GetFinishTime(); + + bool flag = true; + TRACE_DELAY_BEGIN(SHM_EP_ASYNC_POST_READ); + do { + result = mWorker->PostRead(mShmCh, request, mrHandleMap); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_ASYNC_POST_READ, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + TRACE_DELAY_END(SHM_EP_ASYNC_POST_READ, result); + return result; +} + +NResult NetAsyncEndpointShm::PostRead(const UBSHcomNetTransSglRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostReadWriteSglValidation(mState, mId, mDriver, mShmCh, request)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to async post read sgl as validate fail"); + return result; + } + + uint64_t finishTime = GetFinishTime(); + bool flag = true; + TRACE_DELAY_BEGIN(SHM_EP_ASYNC_POST_READ_SGL); + do { + result = mWorker->PostReadSgl(mShmCh, request, mrHandleMap); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_ASYNC_POST_READ_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + TRACE_DELAY_END(SHM_EP_ASYNC_POST_READ_SGL, result); + return result; +} + +NResult NetAsyncEndpointShm::PostWrite(const UBSHcomNetTransRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = ReadWriteValidation(mState, mId, mDriver, mShmCh, request)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to async post write as validate fail"); + return result; + } + + uint64_t finishTime = GetFinishTime(); + + bool flag = true; + TRACE_DELAY_BEGIN(SHM_EP_ASYNC_POST_WRITE); + do { + result = mWorker->PostWrite(mShmCh, request, mrHandleMap); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_ASYNC_POST_WRITE, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + TRACE_DELAY_END(SHM_EP_ASYNC_POST_WRITE, result); + return result; +} + +NResult NetAsyncEndpointShm::PostWrite(const UBSHcomNetTransSglRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostReadWriteSglValidation(mState, mId, mDriver, mShmCh, request)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to async post write sgl as validate fail"); + return result; + } + + uint64_t finishTime = GetFinishTime(); + + bool flag = true; + TRACE_DELAY_BEGIN(SHM_EP_ASYNC_POST_WRITE_SGL); + do { + result = mWorker->PostWriteSgl(mShmCh, request, mrHandleMap); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_ASYNC_POST_WRITE_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + TRACE_DELAY_END(SHM_EP_ASYNC_POST_WRITE_SGL, result); + return result; +} +} +} \ No newline at end of file diff --git a/src/transport/shm/net_shm_async_endpoint.h b/src/transport/shm/net_shm_async_endpoint.h new file mode 100644 index 0000000000000000000000000000000000000000..f2a3e7ab285c22e466df22365035459f2a70a403 --- /dev/null +++ b/src/transport/shm/net_shm_async_endpoint.h @@ -0,0 +1,217 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_SHM_ASYNC_ENDPOINT_H +#define OCK_HCOM_NET_SHM_ASYNC_ENDPOINT_H + +#include "hcom.h" +#include "transport/net_endpoint_impl.h" +#include "hcom_utils.h" +#include "net_common.h" +#include "net_monotonic.h" +#include "net_security_alg.h" +#include "net_shm_common.h" +#include "net_shm_driver_oob.h" +#include "shm_composed_endpoint.h" +#include "shm_handle_fds.h" + +namespace ock { +namespace hcom { +class NetAsyncEndpointShm : public NetEndpointImpl { +public: + NetAsyncEndpointShm(uint64_t id, ShmChannel *ch, ShmWorker *worker, NetDriverShmWithOOB *driver, + const UBSHcomNetWorkerIndex &workerIndex, ShmMRHandleMap &handleMap); + ~NetAsyncEndpointShm() override; + + NResult SetEpOption(UBSHcomEpOptions &epOptions) override + { + NN_LOG_WARN("[SHM AsyncEp] Empty function for now"); + return NN_OK; + } + + uint32_t GetSendQueueCount() override + { + NN_LOG_WARN("[SHM AsyncEp] Empty function for now"); + return 0; + } + + const std::string &UdsName() override + { + if (NN_LIKELY(mShmCh != nullptr)) { + return mShmCh->UdsName(); + } + + return CONST_EMPTY_STRING; + } + + inline const std::string &PeerIpAndPort() override + { + if (NN_LIKELY(mShmCh != nullptr)) { + return mShmCh->PeerIpPort(); + } + + return CONST_EMPTY_STRING; + } + + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNO) override; + + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) override; + + NResult PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) override; + + NResult PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNO) override; + + NResult PostRead(const UBSHcomNetTransSglRequest &request) override; + + NResult PostRead(const UBSHcomNetTransRequest &request) override; + + NResult PostWrite(const UBSHcomNetTransSglRequest &request) override; + + NResult PostWrite(const UBSHcomNetTransRequest &request) override; + + NResult WaitCompletion(int32_t timeout) override + { + NN_LOG_WARN("Shm Invalid operation, wait completion is not supported by NetAsyncEndpoint"); + return NN_INVALID_OPERATION; + } + + NResult Receive(int32_t timeout, UBSHcomNetResponseContext &ctx) override + { + NN_LOG_WARN("Shm Invalid operation, wait completion is not supported by NetAsyncEndpoint"); + return NN_INVALID_OPERATION; + } + + NResult ReceiveRaw(int32_t timeout, UBSHcomNetResponseContext &ctx) override + { + NN_LOG_WARN("Shm Invalid operation, wait completion is not supported by NetAsyncEndpoint"); + return NN_INVALID_OPERATION; + } + + NResult SendFds(int fds[], uint32_t len) override + { + if (NN_UNLIKELY(len < NN_NO1 || len > NN_NO4)) { + NN_LOG_ERROR("Shm Failed to send fds in shm async ep as length should more than 0 and less than 4."); + return NN_PARAM_INVALID; + } + + if (NN_UNLIKELY(!mState.Compare(NEP_ESTABLISHED))) { + NN_LOG_ERROR("Shm Failed to send fds in shm async ep as endpoint " << mId << + " is not established, state is " << UBSHcomNEPStateToString(mState.Get())); + return NN_EP_NOT_ESTABLISHED; + } + + int innerFds[NN_NO4] = {0}; + for (uint32_t i = 0; i < len; i++) { + innerFds[i] = fds[i]; + if (fds[i] <= 0) { + NN_LOG_ERROR("Shm Failed to send fds in shm async ep, as invalid fds index:" << i); + return NN_INVALID_PARAM; + } + } + + std::lock_guard guard(mShmCh->mFdMutex); + ShmChKeeperMsgHeader header {}; + header.msgType = ShmChKeeperMsgType::EXCHANGE_USER_FD; + header.dataSize = len; + if (NN_UNLIKELY(::send(mShmCh->UdsFD(), &header, sizeof(ShmChKeeperMsgHeader), MSG_NOSIGNAL) <= 0)) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Shm Failed to send header info of exchange external fd to peer, error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_ERROR; + } + + return ShmHandleFds::SendMsgFds(mShmCh->UdsFD(), innerFds, NN_NO4); + } + + NResult ReceiveFds(int fds[], uint32_t len, int32_t timeoutSec) override + { + if (NN_UNLIKELY(len < NN_NO1 || len > NN_NO4)) { + NN_LOG_ERROR("Shm Failed to receive fds in shm async ep as length should more than 0 and less than 4."); + return NN_PARAM_INVALID; + } + + if (NN_UNLIKELY(!mState.Compare(NEP_ESTABLISHED))) { + NN_LOG_ERROR("Shm Failed to receive fds in shm async ep as endpoint " << mId << + " is not established, state is " << UBSHcomNEPStateToString(mState.Get())); + return NN_EP_NOT_ESTABLISHED; + } + + return mShmCh->RemoveUserFds(fds, len, timeoutSec); + } + + NResult GetRemoteUdsIdInfo(UBSHcomNetUdsIdInfo &idInfo) override + { + if (!mState.Compare(NEP_ESTABLISHED)) { + NN_LOG_ERROR("[SHM AsyncEp] EP is not established"); + return NN_EP_NOT_ESTABLISHED; + } + + if (!mDriver->mStartOobSvr) { + NN_LOG_ERROR("[SHM AsyncEp] oob server is not start"); + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + + idInfo = mRemoteUdsIdInfo; + return NN_OK; + } + + void Close() override + { + if (NN_UNLIKELY(mShmCh != nullptr)) { + mShmCh->Close(); + } + } + + bool GetPeerIpPort(std::string &ip, uint16_t &port) override + { + NN_LOG_WARN("Shm Invalid operation for shm, shm does not have ip and port"); + return false; + } + +private: + bool inline NeedRetry(HResult &result) + { + if (NN_UNLIKELY(!State().Compare(NEP_ESTABLISHED))) { + result = NN_EP_NOT_ESTABLISHED; + return false; + } + + if (result == SH_OP_CTX_FULL || result == SH_RETRY_FULL) { + return true; + } + + return false; + } + + uint64_t inline GetFinishTime() + { + if (mDefaultTimeout > 0) { + return NetMonotonic::TimeNs() + static_cast(mDefaultTimeout) * 1000000000UL; + } else if (mDefaultTimeout < 0) { + return UINT64_MAX; + } + + return 0; + } + + ShmChannel *mShmCh = nullptr; + ShmWorker *mWorker = nullptr; + NetDriverShmWithOOB *mDriver = nullptr; + uint32_t mAllowedSize = 0; + ShmMRHandleMap &mrHandleMap; + + friend class NetDriverShmWithOOB; +}; +} +} + +#endif // OCK_HCOM_NET_SHM_ASYNC_ENDPOINT_H diff --git a/src/transport/shm/net_shm_common.h b/src/transport/shm/net_shm_common.h new file mode 100644 index 0000000000000000000000000000000000000000..cabd81544274f6986bb15c467032cd0cc0dd66f6 --- /dev/null +++ b/src/transport/shm/net_shm_common.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_NET_SHM_COMMON_H +#define HCOM_NET_SHM_COMMON_H + +#include "hcom.h" +#include "net_common.h" +#include "net_memory_region.h" +#include "net_mem_pool_fixed.h" +#include "net_oob.h" +#include "net_oob_ssl.h" +#include "shm_common.h" +#include "shm_channel.h" +#include "shm_worker.h" + +namespace ock { +namespace hcom { +class NetAsyncEndpointShm; +class NetSyncEndpointShm; + +class NetDriverShmWithOOB; + +} +} + +#endif // HCOM_NET_SHM_COMMON_H diff --git a/src/transport/shm/net_shm_driver_oob.cpp b/src/transport/shm/net_shm_driver_oob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dc6e664b2036560eaf6a67696904ba9b6cd5c467 --- /dev/null +++ b/src/transport/shm/net_shm_driver_oob.cpp @@ -0,0 +1,1630 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +#include "hcom_def.h" +#include "hcom_log.h" +#include "net_shm_sync_endpoint.h" +#include "net_shm_async_endpoint.h" +#include "net_oob_secure.h" +#include "net_oob_ssl.h" +#include "shm_composed_endpoint.h" +#include "shm_validation.h" +#include "shm_handle_fds.h" +#include "net_shm_driver_oob.h" + +namespace ock { +namespace hcom { +constexpr const char *CHANNEL_KEEPER_NAME = "channel_keeper"; +constexpr const char *DELAY_RELEASE_TIMER_NAME = "delay_release_timer"; +constexpr const char *SHM_FILE_DIR_PATH = "/dev/shm"; +constexpr const char *SHM_FILE_PREFIX = "hcom-"; + +NResult NetDriverShmWithOOB::Initialize(const UBSHcomNetDriverOptions &option) +{ + std::lock_guard guard(mInitMutex); + if (mInited) { + return NN_OK; + } + + mOptions = option; + + NResult shmRes = NN_OK; + if (NN_UNLIKELY((shmRes = mOptions.ValidateCommonOptions()) != NN_OK)) { + return shmRes; + } + + if (NN_UNLIKELY(ValidateOptions() != NN_OK)) { + return shmRes; + } + + if (NN_UNLIKELY(UBSHcomNetOutLogger::Instance() == nullptr)) { + return NN_NOT_INITIALIZED; + } + +#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 17, 0) + std::thread clearThread(&NetDriverShmWithOOB::ClearShmLeftFile, this); + mClearThread = std::move(clearThread); + std::string treadName = "clearShmFile" + std::to_string(mIndex); + if (pthread_setname_np(mClearThread.native_handle(), treadName.c_str()) != 0) { + NN_LOG_WARN("Unable to set name of NetDriverShmWithOOB clearThread working thread to " << treadName); + } + + while (!mClearThreadStarted.load()) { + usleep(NN_NO10); + } +#endif + + if (option.enableTls) { + if (HcomSsl::Load() != 0) { + NN_LOG_ERROR("Failed to load openssl API"); + return NN_NOT_INITIALIZED; + } + } + mEnableTls = option.enableTls; + NN_LOG_INFO("Try to initialize driver '" << mName << "' with " << mOptions.ToString()); + + if ((shmRes = CreateWorkerResource()) != NN_OK) { + NN_LOG_ERROR("Shm failed to create worker resource"); + UnInitializeInner(); + return shmRes; + } + + /* create workers */ + if ((shmRes = CreateWorkers()) != NN_OK) { + NN_LOG_ERROR("Shm failed to create workers"); + UnInitializeInner(); + return shmRes; + } + + /* create lb for client */ + if ((shmRes = CreateClientLB()) != NN_OK) { + NN_LOG_ERROR("Shm failed to create client lb"); + UnInitializeInner(); + return shmRes; + } + + /* create oob */ + if (mStartOobSvr) { + if ((shmRes = CreateListeners()) != NN_OK) { + NN_LOG_ERROR("Shm failed to create listeners"); + UnInitializeInner(); + return shmRes; + } + } + + auto channelKeeper = new (std::nothrow) ShmChannelKeeper(CHANNEL_KEEPER_NAME, mIndex); + if (NN_UNLIKELY(channelKeeper == nullptr)) { + NN_LOG_ERROR("Failed to create shm channel keeper in Driver " << mName); + UnInitializeInner(); + return NN_ERROR; + } + mChannelKeeper.Set(channelKeeper); + + auto delayReleaseTimer = new (std::nothrow) NetDelayReleaseTimer(DELAY_RELEASE_TIMER_NAME, mIndex); + if (NN_UNLIKELY(delayReleaseTimer == nullptr)) { + NN_LOG_ERROR("Failed to create shm channel delayReleaseTimer in Driver " << mName); + UnInitializeInner(); + return NN_ERROR; + } + mDelayReleaseTimer.Set(delayReleaseTimer); + + mInited = true; + return NN_OK; +} + +void NetDriverShmWithOOB::UnInitialize() +{ + std::lock_guard guard(mInitMutex); + if (!mInited) { + return; + } + + if (mStarted) { + NN_LOG_WARN("Unable to unInitialize shm driver " << mName << " which is not stopped"); + return; + } + + UnInitializeInner(); + mInited = false; +} + +NResult NetDriverShmWithOOB::ValidateOptions() +{ + if (NN_UNLIKELY(ValidateAndParseOobPortRange(mOptions.oobPortRange) != NN_OK)) { + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(ValidateOptionsOobType() != NN_OK)) { + return NN_INVALID_PARAM; + } + + return NN_OK; // do later +} + +void NetDriverShmWithOOB::UnInitializeInner() +{ + ClearWorkers(); + DestroyClientLB(); + + if (mChannelKeeper != nullptr) { + mChannelKeeper.Set(nullptr); + } + + if (mDelayReleaseTimer != nullptr) { + mDelayReleaseTimer.Set(nullptr); + } + + if (!mOobServers.empty()) { + mOobServers.clear(); + } + + if (mClearThread.native_handle()) { + mClearThread.join(); + } + + mOpCtxMemPool = nullptr; + mOpCompMemPool = nullptr; + mSglCompMemPool = nullptr; + + std::lock_guard guard(mEndPointsMutex); + if (!mEndPoints.empty()) { + mEndPoints.clear(); + } +} + +NResult NetDriverShmWithOOB::CreateWorkerResource() +{ + NetMemPoolFixedOptions options = {}; + options.superBlkSizeMB = NN_NO1; + options.minBlkSize = NN_NextPower2(sizeof(ShmOpCompInfo)); + options.tcExpandBlkCnt = NN_NO64; + mOpCompMemPool = new (std::nothrow) NetMemPoolFixed(mName, options); + if (mOpCompMemPool.Get() == nullptr) { + NN_LOG_ERROR("Failed to create memory pool for op completion info pool in driver " << mName << + ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + auto result = mOpCompMemPool->Initialize(); + if (result != NN_OK) { + NN_LOG_ERROR("Failed to initialize memory pool for op completion info in driver " << mName << ", result " << + result); + return result; + } + + mOpCtxMemPool = new (std::nothrow) NetMemPoolFixed(mName, options); + if (mOpCtxMemPool.Get() == nullptr) { + NN_LOG_ERROR("Failed to create memory pool for op ctx info pool in driver " << mName << + ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + result = mOpCtxMemPool->Initialize(); + if (result != NN_OK) { + NN_LOG_ERROR("Failed to initialize memory pool for op ctx info in driver " << mName << ", result " << result); + return result; + } + + options = {}; + options.superBlkSizeMB = NN_NO1; + options.minBlkSize = NN_NO512; // the sgl context is 448, not power of 2, set to the closest num 512 + options.tcExpandBlkCnt = NN_NO64; + mSglCompMemPool = new (std::nothrow) NetMemPoolFixed(mName, options); + if (mSglCompMemPool.Get() == nullptr) { + NN_LOG_ERROR("Failed to create memory pool for sgl op context in driver " << mName << + ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + result = mSglCompMemPool->Initialize(); + if (result != NN_OK) { + NN_LOG_ERROR("Failed to initialize memory pool for sgl op context in driver " << mName << ", result " << + result); + return result; + } + + return NN_OK; +} + +NResult NetDriverShmWithOOB::CreateWorkers() +{ + NResult result = NN_OK; + + std::vector workerGroups; + std::vector> workerGroupCpus; + std::vector flatWorkerCpus; + std::vector workerThreadPriority; + + /* parse */ + if (!(NetFunc::NN_ParseWorkersGroups(mOptions.WorkGroups(), workerGroups)) || + !(NetFunc::NN_ParseWorkerGroupsCpus(mOptions.WorkerGroupCpus(), workerGroupCpus)) || + !(NetFunc::NN_FinalizeWorkerGroupCpus(workerGroups, workerGroupCpus, true, flatWorkerCpus)) || + !(NetFunc::NN_ParseWorkersGroupsThreadPriority(mOptions.WorkerGroupThreadPriority(), + workerThreadPriority, workerGroups.size()))) { + NN_LOG_ERROR("[SHM] Failed to parse worker or cpu groups"); + return NN_INVALID_PARAM; + } + + ShmWorkerOptions options; + options.mode = mOptions.mode == NET_EVENT_POLLING ? SHM_EVENT_POLLING : SHM_BUSY_POLLING; + options.eventQueueLength = mOptions.completionQueueDepth; + options.pollingTimeoutMs = mOptions.eventPollingTimeout; + options.pollingBatchSize = mOptions.pollingBatchSize; + options.threadPriority = mOptions.workerThreadPriority; + if ((mOptions.workerThreadPriority != 0) && (!workerThreadPriority.empty())) { + NN_LOG_WARN("Driver options 'workerThreadPriority' and 'workerGroupsThreadPriority' set all, preferential use " + "'workerGroupsThreadPriority'."); + } + + /* create workers */ + mWorkers.reserve(flatWorkerCpus.size()); + uint32_t groupIndex = 0; + UBSHcomNetWorkerIndex workerIndex {}; + uint16_t totalWorkerIndex = 0; + for (auto item : workerGroups) { + /* The left of mWorkerGroups is the index of each group's first worker in the mWorkers */ + mWorkerGroups.emplace_back(totalWorkerIndex, item); + for (uint32_t i = 0; i < item; ++i) { + options.cpuId = flatWorkerCpus.at(totalWorkerIndex++); + if (!workerThreadPriority.empty()) { + options.threadPriority = workerThreadPriority[groupIndex]; + } + workerIndex.Set(i, groupIndex, mIndex); + auto *worker = new (std::nothrow) + ShmWorker(mName, workerIndex, options, mOpCompMemPool, mOpCtxMemPool, mSglCompMemPool); + if (NN_UNLIKELY(worker == nullptr)) { + NN_LOG_ERROR("Failed to create shm worker in driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + if (NN_UNLIKELY((result = worker->Initialize()) != NN_OK)) { + delete worker; + NN_LOG_ERROR("Failed to initialize shm worker in driver " << mName << ", result " << result); + return result; + } + + worker->IncreaseRef(); + mWorkers.push_back(worker); + } + ++groupIndex; + } + + return NN_OK; +} + +void NetDriverShmWithOOB::ClearWorkers() +{ + mWorkerGroups.clear(); + for (auto worker : mWorkers) { + worker->DecreaseRef(); + } + mWorkers.clear(); +} + +NResult NetDriverShmWithOOB::Start() +{ + std::lock_guard guard(mInitMutex); + if (!mInited) { + NN_LOG_ERROR("Failed to start driver " << mName << " as it is not initialized"); + return NN_ERROR; + } + + if (NN_UNLIKELY(mChannelKeeper == nullptr)) { + NN_LOG_ERROR("Failed to start driver " << mName << " as mChannelKeeper is null"); + return NN_ERROR; + } + + if (NN_UNLIKELY(mDelayReleaseTimer == nullptr)) { + NN_LOG_ERROR("Failed to start driver " << mName << " as mDelayReleaseTimer is null"); + return NN_ERROR; + } + + NResult result = NN_OK; + if (mOptions.dontStartWorkers) { + // self polling should register channel keeper and start + mChannelKeeper->RegisterMsgHandler( + std::bind(&NetDriverShmWithOOB::HandleChanelKeeperMsg, this, std::placeholders::_1, std::placeholders::_2)); + if ((result = mChannelKeeper->Start()) != NN_OK) { + return result; + } + mStarted = true; + } + + if (mStarted) { + return NN_OK; + } + if (NN_UNLIKELY(result = ValidateHandlesCheck()) != NN_OK) { + ClearWorkers(); + return result; + } + for (auto &item : mWorkers) { + if (NN_UNLIKELY(item == nullptr)) { + NN_LOG_ERROR("[SHM] Failed to start worker " << mName << " as it is null"); + ClearWorkers(); + return result; + } + + item->RegisterNewReqHandler( + std::bind(&NetDriverShmWithOOB::HandleNewRequest, this, std::placeholders::_1, std::placeholders::_2)); + item->RegisterReqPostedHandler(std::bind(&NetDriverShmWithOOB::HandleReqPosted, this, std::placeholders::_1)); + item->RegisterOneSideHandler(std::bind(&NetDriverShmWithOOB::OneSideDone, this, std::placeholders::_1)); + if (mIdleHandler != nullptr) { + item->RegisterIdleHandler(mIdleHandler); + } + + if ((result = item->Start()) != NN_OK) { + NN_LOG_ERROR("Failed to start worker " << item->Name() << " in driver " << mName << ", result " << result); + ClearWorkers(); + return result; + } + } + + if (mStartOobSvr) { + if (mNewEndPointHandler == nullptr) { + NN_LOG_ERROR("SHM failed to do start in Driver " << mName << ", as newEndPointerHandler is null"); + return NN_INVALID_PARAM; + } + for (auto &oobServer : mOobServers) { + oobServer->SetNewConnCB(std::bind(&NetDriverShmWithOOB::HandleNewOobConn, this, std::placeholders::_1)); + } + + /* start oob server */ + if ((result = StartListeners()) != NN_OK) { + ClearWorkers(); + return result; + } + } + + mChannelKeeper->RegisterMsgHandler( + std::bind(&NetDriverShmWithOOB::HandleChanelKeeperMsg, this, std::placeholders::_1, std::placeholders::_2)); + if ((result = mChannelKeeper->Start()) != NN_OK) { + ClearWorkers(); + return result; + } + + if ((result = mDelayReleaseTimer->Start()) != NN_OK) { + ClearWorkers(); + mChannelKeeper->Stop(); + return result; + } + + mStarted = true; + return NN_OK; +} + +void NetDriverShmWithOOB::ClearShmLeftFile() +{ + mClearThreadStarted.store(true); + NN_LOG_INFO("NetDriverShmWithOOB clearThread " << mName << " working thread started"); + DIR *dir = nullptr; + struct dirent *ent = nullptr; + // do later consider open dir/file and delete dir/file secure problems + if (NN_UNLIKELY((dir = opendir(SHM_FILE_DIR_PATH)) == nullptr)) { + NN_LOG_TRACE_INFO("Failed to open directory SHM_FILE_DIR_PATH"); + return; + } + + while ((ent = readdir(dir)) != nullptr) { + if (strncmp(ent->d_name, SHM_FILE_PREFIX, strlen(SHM_FILE_PREFIX)) != 0) { + continue; + } + + auto tmpFd = shm_open(ent->d_name, O_CREAT | O_RDWR, NN_NO400); + if (NN_UNLIKELY(tmpFd < 0)) { + continue; + } + if (flock(tmpFd, LOCK_EX | LOCK_NB) == 0) { + if (NN_UNLIKELY(shm_unlink(ent->d_name) != 0)) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_TRACE_INFO("Failed to remove file:" << ent->d_name << " error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + (void)buf; + NetFunc::NN_SafeCloseFd(tmpFd); + continue; + } + NN_LOG_TRACE_INFO("Success to delete shm file:" << ent->d_name << + " which is not used now, may be left last time"); + } + NetFunc::NN_SafeCloseFd(tmpFd); + } + closedir(dir); + NN_LOG_INFO("NetDriverShmWithOOB clearThread " << mName << " working thread exit"); +} + +void NetDriverShmWithOOB::Stop() +{ + std::lock_guard locker(mInitMutex); + if (!mStarted) { + return; + } + + StopInner(); + + mStarted = false; +} + +void NetDriverShmWithOOB::StopInner() +{ + for (auto worker : mWorkers) { + worker->Stop(); + } + + if (NN_LIKELY(mChannelKeeper != nullptr)) { + mChannelKeeper->Stop(); + } + + if (NN_LIKELY(mDelayReleaseTimer != nullptr)) { + mDelayReleaseTimer->Stop(); + } + + StopListeners(true); +} + +NResult NetDriverShmWithOOB::HandleNewRequest(ShmOpContextInfo &ctx, uint32_t immData) +{ + NN_ASSERT_LOG_RETURN(ctx.channel != nullptr, NN_ERROR) + NN_ASSERT_LOG_RETURN(ctx.channel->UpContext() != 0, NN_ERROR) + NN_ASSERT_LOG_RETURN(ctx.dataAddress != 0, NN_ERROR) + NN_ASSERT_LOG_RETURN(ctx.dataSize <= NET_SGE_MAX_SIZE, NN_ERROR) + NResult result = NN_OK; + + if (NN_UNLIKELY(ctx.channel->State().Compare(CH_BROKEN))) { + NN_LOG_WARN("Got invalid ctx in new request handler, as channel " << ctx.channel->Id() << " is broken drop it"); + return result; + } + + static thread_local UBSHcomNetRequestContext netCtx {}; + static thread_local UBSHcomNetMessage netMsg {}; + if (ctx.opType == ShmOpContextInfo::ShmOpType::SH_RECEIVE && immData == 0) { + /* get header */ + netCtx.mEp.Set(reinterpret_cast(ctx.channel->UpContext())); + auto asyncEp = netCtx.mEp.ToChild(); + if (NN_UNLIKELY(asyncEp == nullptr)) { + NN_LOG_ERROR("dynamic cast failed"); + ctx.channel->DCMarkPeerBuckFree(ctx.dataAddress); + return NN_PARAM_INVALID; + } + auto header = reinterpret_cast(ctx.dataAddress); + + result = NetFunc::ValidateHeaderWithDataSize(*header, ctx.dataSize); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("Failed to validate received header param, ep " << asyncEp->Id()); + ctx.channel->DCMarkPeerBuckFree(ctx.dataAddress); + return result; + } + + netCtx.mHeader = *header; + netCtx.mOpType = UBSHcomNetRequestContext::NN_RECEIVED; + netCtx.mMessage = &netMsg; + + size_t realDataSize = 0; + if (asyncEp->mIsNeedEncrypt) { + const void *cipherData = reinterpret_cast(ctx.dataAddress + sizeof(UBSHcomNetTransHeader)); + auto aesLen = header->dataLength; + realDataSize = asyncEp->mAes.GetRawLen(aesLen); + uint32_t decryptLen = 0; + bool messageReady = netMsg.AllocateIfNeed(realDataSize); + if (NN_UNLIKELY(!messageReady)) { + NN_LOG_ERROR("Failed to allocate memory for response size " << realDataSize << + ", probably out of memory"); + ctx.channel->DCMarkPeerBuckFree(ctx.dataAddress); + return NN_MALLOC_FAILED; + } + + if (!asyncEp->mAes.Decrypt(asyncEp->mSecrets, cipherData, aesLen, netMsg.mBuf, decryptLen)) { + NN_LOG_ERROR("Failed to decrypt data"); + ctx.channel->DCMarkPeerBuckFree(ctx.dataAddress); + return NN_DECRYPT_FAILED; + } + VALIDATE_DECRYPT_LENGTH(decryptLen, realDataSize, ctx) + } else { + realDataSize = ctx.dataSize - sizeof(UBSHcomNetTransHeader); + bool messageReady = netMsg.AllocateIfNeed(realDataSize); + if (NN_UNLIKELY(!messageReady)) { + NN_LOG_ERROR("Failed to allocate net msg buffer failed"); + ctx.channel->DCMarkPeerBuckFree(ctx.dataAddress); + return NN_MALLOC_FAILED; + } + if (NN_UNLIKELY(memcpy_s(netMsg.mBuf, netMsg.GetBufLen(), reinterpret_cast(ctx.dataAddress + + sizeof(UBSHcomNetTransHeader)), realDataSize) != NN_OK)) { + ctx.channel->DCMarkPeerBuckFree(ctx.dataAddress); + NN_LOG_ERROR("Failed to copy ctx to netMsg"); + return NN_INVALID_PARAM; + } + } + + netMsg.mDataLen = realDataSize; + netCtx.mHeader.dataLength = realDataSize; + + /* call upper handler */ + result = mReceivedRequestHandler(netCtx); + + /* mark buck free */ + ctx.channel->DCMarkPeerBuckFree(ctx.dataAddress); + netCtx.mEp.Set(nullptr); + + return result; + } else if (ctx.opType == ShmOpContextInfo::ShmOpType::SH_RECEIVE && immData != 0) { + netCtx.mEp.Set(reinterpret_cast(ctx.channel->UpContext())); + netCtx.mOpType = UBSHcomNetRequestContext::NN_RECEIVED_RAW; + netCtx.mMessage = &netMsg; + + auto asyncEp = netCtx.mEp.ToChild(); + if (NN_UNLIKELY(asyncEp == nullptr)) { + NN_LOG_ERROR("dynamic cast failed"); + ctx.channel->DCMarkPeerBuckFree(ctx.dataAddress); + return NN_PARAM_INVALID; + } + if (asyncEp->mIsNeedEncrypt) { + const void *cipherData = reinterpret_cast(ctx.dataAddress); + auto aesLen = ctx.dataSize; + size_t decryptRawLen = asyncEp->mAes.GetRawLen(aesLen); + uint32_t decryptLen = 0; + bool messageReady = netMsg.AllocateIfNeed(decryptRawLen); + if (NN_UNLIKELY(!messageReady)) { + NN_LOG_ERROR("Failed to allocate net msg buffer failed"); + ctx.channel->DCMarkPeerBuckFree(ctx.dataAddress); + return NN_MALLOC_FAILED; + } + + if (!asyncEp->mAes.Decrypt(asyncEp->mSecrets, cipherData, aesLen, netMsg.mBuf, decryptLen)) { + NN_LOG_ERROR("Failed to decrypt data"); + ctx.channel->DCMarkPeerBuckFree(ctx.dataAddress); + return NN_DECRYPT_FAILED; + } + NN_ASSERT_LOG_RETURN(decryptLen == decryptRawLen, NN_DECRYPT_FAILED) + netMsg.mDataLen = decryptRawLen; + } else { + netMsg.mDataLen = ctx.dataSize; + bool messageReady = netMsg.AllocateIfNeed(ctx.dataSize); + if (NN_UNLIKELY(!messageReady)) { + NN_LOG_ERROR("Failed to allocate net msg buffer failed"); + ctx.channel->DCMarkPeerBuckFree(ctx.dataAddress); + return NN_MALLOC_FAILED; + } + if (NN_UNLIKELY(memcpy_s(netMsg.mBuf, netMsg.GetBufLen(), reinterpret_cast(ctx.dataAddress), + ctx.dataSize) != NN_OK)) { + ctx.channel->DCMarkPeerBuckFree(ctx.dataAddress); + NN_LOG_ERROR("Failed to copy dataAddress to netMsg"); + return NN_INVALID_PARAM; + } + } + + netCtx.mHeader.Invalid(); + netCtx.mHeader.dataLength = netMsg.mDataLen; + netCtx.mHeader.seqNo = immData; + /* call upper handler */ + result = mReceivedRequestHandler(netCtx); + + /* mark buck free */ + ctx.channel->DCMarkPeerBuckFree(ctx.dataAddress); + netCtx.mEp.Set(nullptr); + + return result; + } else { + NN_LOG_WARN("Un-reachable path"); + return NN_OK; + } +} + +NResult NetDriverShmWithOOB::HandleReqPosted(ShmOpCompInfo &ctx) +{ + NN_ASSERT_LOG_RETURN(ctx.channel != nullptr, NN_ERROR) + NN_ASSERT_LOG_RETURN(ctx.channel->UpContext() != 0, NN_ERROR) + NN_ASSERT_LOG_RETURN(ctx.channel->UpContext1() != 0, NN_ERROR) + NResult result = NN_OK; + + static thread_local UBSHcomNetRequestContext netCtx {}; + + netCtx.mResult = NN_OK; + netCtx.mEp.Set(reinterpret_cast(ctx.channel->UpContext())); + netCtx.mOriginalReq = ctx.request; + netCtx.mMessage = nullptr; + + auto shmWorker = reinterpret_cast(ctx.channel->UpContext1()); + auto sgeCtx = reinterpret_cast(ctx.upCtx); + auto sglCtx = sgeCtx->ctx; + + switch (ctx.opType) { + case ShmOpContextInfo::SH_SEND: + netCtx.mHeader = ctx.header; + netCtx.mOpType = UBSHcomNetRequestContext::NN_SENT; + shmWorker->ReturnOpCompInfo(&ctx); + break; + case ShmOpContextInfo::SH_SEND_RAW: + netCtx.mHeader.Invalid(); + netCtx.mOpType = UBSHcomNetRequestContext::NN_SENT_RAW; + shmWorker->ReturnOpCompInfo(&ctx); + break; + case ShmOpContextInfo::SH_SEND_RAW_SGL: + netCtx.mHeader.Invalid(); + netCtx.mOpType = UBSHcomNetRequestContext::NN_SENT_RAW_SGL; + if (sglCtx->iovCount <= NET_SGE_MAX_IOV) { + if (NN_UNLIKELY(memcpy_s(netCtx.iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, sglCtx->iov, + sizeof(UBSHcomNetTransSgeIov) * sglCtx->iovCount) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + } + netCtx.mOriginalSglReq.iov = netCtx.iov; + netCtx.mOriginalSglReq.iovCount = sglCtx->iovCount; + netCtx.mOriginalSglReq.upCtxSize = sglCtx->upCtxSize; + if (netCtx.mOriginalSglReq.upCtxSize > 0 && + netCtx.mOriginalSglReq.upCtxSize <= sizeof(UBSHcomNetTransSglRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(netCtx.mOriginalSglReq.upCtxData, NN_NO16, sglCtx->upCtx, sglCtx->upCtxSize) != + NN_OK)) { + NN_LOG_ERROR("Failed to copy request to sglCtx"); + return NN_INVALID_PARAM; + } + } + shmWorker->ReturnOpCompInfo(&ctx); + shmWorker->ReturnSglContextInfo(sglCtx); + break; + default: + NN_LOG_WARN("Un-reachable path"); + break; + } + + /* call upper handler */ + if (NN_UNLIKELY((result = mRequestPostedHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call requestPostedHandler in Driver " << mName << " return non-zero for type " << ctx.opType << + " done"); + } + netCtx.mEp.Set(nullptr); + return result; +} + +NResult NetDriverShmWithOOB::OneSideDone(ShmOpContextInfo *ctxIn) +{ + NN_ASSERT_LOG_RETURN(ctxIn != nullptr, NN_ERROR) + ShmOpContextInfo ctx = *ctxIn; + if (NN_UNLIKELY(ctx.channel == nullptr || ctx.channel->UpContext1() == 0 || ctx.channel->UpContext() == 0)) { + NN_LOG_ERROR("Ctx or channel is null of OneSideDone in Driver " << mName); + return NN_ERROR; + } + + int result = 0; + auto worker = reinterpret_cast(ctx.channel->UpContext1()); + static thread_local UBSHcomNetRequestContext netCtx {}; + + if (ctx.opType == ShmOpContextInfo::SH_WRITE || ctx.opType == ShmOpContextInfo::SH_READ) { + // set context + netCtx.mResult = ShmOpContextInfo::GetNResult(ctx.errType); + netCtx.mEp.Set(reinterpret_cast(ctx.channel->UpContext())); + netCtx.mOpType = + ctx.opType == ShmOpContextInfo::SH_WRITE ? UBSHcomNetRequestContext::NN_WRITTEN : + UBSHcomNetRequestContext::NN_READ; + netCtx.mHeader.Invalid(); + netCtx.mMessage = nullptr; + netCtx.mOriginalReq.lAddress = ctx.mrMemAddr; + netCtx.mOriginalReq.lKey = ctx.lKey; + netCtx.mOriginalReq.size = ctx.dataSize; + netCtx.mOriginalReq.upCtxSize = ctx.upCtxSize; + + if (ctx.upCtxSize > 0 && ctx.upCtxSize <= sizeof(UBSHcomNetTransRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(netCtx.mOriginalReq.upCtxData, NN_NO16, ctx.upCtx, ctx.upCtxSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + } + + // called to callback + if (NN_UNLIKELY((result = mOneSideDoneHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call oneSideDoneHandler in Driver " << mName << " return non-zero for type " << ctx.opType << + " done"); + } + worker->ReturnOpContextInfo(ctxIn); + netCtx.mEp.Set(nullptr); + } else if (ctx.opType == ShmOpContextInfo::SH_SGL_WRITE || ctx.opType == ShmOpContextInfo::SH_SGL_READ) { + auto upCtx = reinterpret_cast(ctx.upCtx); + auto sglCtx = upCtx->ctx; + + netCtx.mResult = ShmOpContextInfo::GetNResult(ctx.errType); + netCtx.mEp.Set(reinterpret_cast(ctx.channel->UpContext())); + netCtx.mOpType = ctx.opType == ShmOpContextInfo::SH_SGL_WRITE ? UBSHcomNetRequestContext::NN_SGL_WRITTEN : + UBSHcomNetRequestContext::NN_SGL_READ; + netCtx.mHeader.Invalid(); + netCtx.mMessage = nullptr; + if (sglCtx->iovCount <= NET_SGE_MAX_IOV) { + if (NN_UNLIKELY(memcpy_s(netCtx.iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, sglCtx->iov, + sizeof(UBSHcomNetTransSgeIov) * sglCtx->iovCount) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + } + netCtx.mOriginalSglReq.iov = netCtx.iov; + netCtx.mOriginalSglReq.iovCount = sglCtx->iovCount; + netCtx.mOriginalSglReq.upCtxSize = sglCtx->upCtxSize; + if (netCtx.mOriginalSglReq.upCtxSize > 0 && + netCtx.mOriginalSglReq.upCtxSize <= sizeof(UBSHcomNetTransSglRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(netCtx.mOriginalSglReq.upCtxData, NN_NO16, sglCtx->upCtx, sglCtx->upCtxSize) != + NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + } + + // called to callback + if (NN_UNLIKELY((result = mOneSideDoneHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call oneSideDoneHandler in Driver " << mName << " return non-zero for type " << ctx.opType << + " done"); + } + worker->ReturnOpContextInfo(ctxIn); + worker->ReturnSglContextInfo(sglCtx); + netCtx.mEp.Set(nullptr); + } else { + NN_LOG_WARN("Unreachable path"); + } + + return result; +} + +inline void NetDriverShmWithOOB::HandleChanelKeeperMsg(const ShmChKeeperMsgHeader &header, + const ShmChannelPtr &channelPtr) +{ + if (NN_UNLIKELY(channelPtr == nullptr)) { + return; + } + if (header.msgType == ShmChKeeperMsgType::RESET_BY_PEER) { + channelPtr->Close(); + if (NN_UNLIKELY(!channelPtr->State().CAS(CH_NEW, CH_BROKEN))) { + NN_LOG_ERROR("Channel id " << channelPtr->Id() << " failed set state " << CH_BROKEN); + } + ProcessEpError(channelPtr); + } else if (header.msgType == ShmChKeeperMsgType::GET_MR_FD) { + HandleKeeperMsgGetMrFd(header, channelPtr); + } +} + +void NetDriverShmWithOOB::HandleKeeperMsgGetMrFd(const ShmChKeeperMsgHeader &header, const ShmChannelPtr &channelPtr) +{ + uint32_t lKey = -1; + if (NN_UNLIKELY(header.dataSize != sizeof(uint32_t))) { + NN_LOG_ERROR("Failed to receive lkey from peer, as dataSize in header is invalid"); + return; + } + ssize_t result = ::recv(channelPtr->UdsFD(), &lKey, header.dataSize, 0); + if (NN_UNLIKELY(result <= 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to receive data from peer as errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return; + } + + ShmHandlePtr shmHandle = ShmMRHandleMap::GetInstance().GetFromLocalMap(lKey); + if (shmHandle.Get() == nullptr) { + NN_LOG_ERROR("Get shmHandle from local map failed"); + return; + } + int mrFd = shmHandle->Fd(); + if (NN_UNLIKELY(mrFd <= 0)) { + NN_LOG_ERROR("Get Fd from local map failed"); + return; + } + + std::lock_guard guard(channelPtr->mFdMutex); + ShmChKeeperMsgHeader exchangeHeader{}; + exchangeHeader.msgType = ShmChKeeperMsgType::SEND_MR_FD; + exchangeHeader.dataSize = sizeof(int); + if (NN_UNLIKELY(::send(channelPtr->UdsFD(), &exchangeHeader, sizeof(ShmChKeeperMsgHeader), MSG_NOSIGNAL) <= 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to send header info of exchanging user fd to peer, errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return; + } + + int fds[NN_NO4] = {0}; + fds[0] = mrFd; + if (NN_UNLIKELY(ShmHandleFds::SendMsgFds(channelPtr->UdsFD(), fds, NN_NO4) != NN_OK)) { + NN_LOG_ERROR("Failed to send mr to peer"); + return; + } +} + +void NetDriverShmWithOOB::ProcessEpError(const ShmChannelPtr &channelPtr) +{ + UBSHcomNetEndpointPtr epPtr = reinterpret_cast(channelPtr->UpContext()); + if (NN_UNLIKELY(epPtr == nullptr)) { + return; + } + + bool process = false; + if (NN_UNLIKELY(!epPtr->EPBrokenProcessed().compare_exchange_strong(process, true))) { + NN_LOG_WARN("Ep id " << epPtr->Id() << " broken handled by other thread"); + return; + } + + if (epPtr->State().Compare(NEP_ESTABLISHED)) { + epPtr->State().Set(NEP_BROKEN); + } + + // two side remaining + ShmOpCompInfo *remainingCompCtx = nullptr; + ShmOpCompInfo *nextCompCtx = nullptr; + channelPtr->GetCompPosted(remainingCompCtx); + while (remainingCompCtx != nullptr) { + nextCompCtx = remainingCompCtx->next; + remainingCompCtx->errType = ShmOpContextInfo::ShmErrorType::SH_RESET_BY_PEER; + (void)HandleReqPosted(*remainingCompCtx); + remainingCompCtx = nextCompCtx; + } + + // one side remaining + ShmOpContextInfo *remainingOpCtx = nullptr; + ShmOpContextInfo *nextOpCtx = nullptr; + channelPtr->GetCtxPosted(remainingOpCtx); + while (remainingOpCtx != nullptr) { + nextOpCtx = remainingOpCtx->next; + remainingOpCtx->errType = ShmOpContextInfo::ShmErrorType::SH_RESET_BY_PEER; + (void)OneSideDone(remainingOpCtx); + remainingOpCtx = nextOpCtx; + } + + NN_LOG_WARN("Handle Ep state " << UBSHcomNEPStateToString(epPtr->State().Get()) << ", Ep id " << epPtr->Id() << + " , try call Ep broken handle"); + + OOBSecureProcess::SecProcessDelEpNum(epPtr->UdsName(), epPtr->PeerIpAndPort(), + mOobServers); + + if (mEndPointBrokenHandler != nullptr) { + // self polling mode not register ep handler + mEndPointBrokenHandler(epPtr); + } + DestroyEndpoint(epPtr); +} + +NResult NetDriverShmWithOOB::ConnectSyncEp(const std::string &oobIp, uint16_t oobPort, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint8_t serverGrpNo, uint64_t ctx) +{ + NResult result = NN_OK; + auto eventQueueLength = mOptions.completionQueueDepth; + ShmPollingMode pollMode = (mOptions.mode == NET_EVENT_POLLING) ? SHM_EVENT_POLLING : SHM_BUSY_POLLING; + + ShmSyncEndpointPtr shmEp; + if ((result = ShmSyncEndpoint::Create(mName, eventQueueLength, pollMode, shmEp)) != 0) { + NN_LOG_ERROR("Failed to create sync ep for new connection in Driver " << mName << " , result " << result); + return result; + } + + OOBTCPClientPtr clt; + if (mEnableTls) { + auto oobSSLClt = new OOBSSLClient(NET_OOB_UDS, oobIp, oobPort, + mTlsPrivateKeyCB, mTlsCertCB, mTlsCaCallback); + NN_ASSERT_LOG_RETURN(oobSSLClt != nullptr, NN_NEW_OBJECT_FAILED) + oobSSLClt->SetTlsOptions(mOptions); + oobSSLClt->SetPSKCallback(mPskFindSessionCb, mPskUseSessionCb); + clt = oobSSLClt; + } else { + clt = new OOBTCPClient(NET_OOB_UDS, oobIp, oobPort); + NN_ASSERT_LOG_RETURN(clt.Get() != nullptr, NN_NEW_OBJECT_FAILED) + } + + /* try to connect to oob server */ + OOBTCPConnection *conn = nullptr; + if ((result = clt->Connect(conn)) != 0) { + NN_LOG_ERROR("Shm Failed to connect server via oob, result" << " " << result); + return result; + } + + const auto &peerIpPort = conn->GetIpAndPort(); + NetLocalAutoDecreasePtr autoDecPtr(conn); + conn->SetIpAndPort(oobIp, oobPort); + + if (NN_UNLIKELY(OOBSecureProcess::SecProcessInOOBClient(mSecInfoProvider, mSecInfoValidator, conn, mName, ctx, + mOptions.secType))) { + return NN_OOB_SEC_PROCESS_ERROR; + } + + /* send connection header */ + ConnectHeader header {}; + SetConnHeader(header, mOptions.magic, mOptions.version, serverGrpNo, Protocol(), mMajorVersion, + mMinorVersion, mOptions.tlsVersion); + if (NN_UNLIKELY((result = conn->Send(&header, sizeof(ConnectHeader))) != NN_OK)) { + NN_LOG_ERROR("Failed to send conn header to oob server " << oobIp << ":" << oobPort << " in Driver " << mName); + return NN_ERROR; + } + + /* receive connect response and peer ep id */ + ConnRespWithUId rspWithUid {}; + void *tmpBuf = &rspWithUid; + if (NN_UNLIKELY((result = conn->Receive(tmpBuf, sizeof(ConnRespWithUId))) != NN_OK)) { + return result; + } + + /* connect response */ + auto resp = rspWithUid.connResp; + if (NN_UNLIKELY(resp != OK)) { + NN_LOG_ERROR("Shm Failed to pass server validation in driver " << mName << ", result " << resp); + return NN_CONNECT_REFUSED; + } + + /* peer ep id */ + auto newId = rspWithUid.epId; + NN_LOG_TRACE_INFO("new ep id will be set as " << newId << " in driver " << mName); + + /* create shm and init channel */ + ShmChannelPtr ch; + result = ShmChannel::CreateAndInit(mName, newId, mOptions.mrSendReceiveSegSize, mOptions.qpSendQueueSize, ch); + if (NN_UNLIKELY(result != NN_OK)) { + return result; + } + + /* fill exchange info */ + ShmConnExchangeInfo exInfo {}; + NN_ASSERT_LOG_RETURN(shmEp->FillQueueExchangeInfo(exInfo), NN_ERROR) + NN_ASSERT_LOG_RETURN(ch->FillExchangeInfo(exInfo), NN_ERROR) + exInfo.payLoadSize = payload.length(); + + /* send exchange info */ + if (NN_UNLIKELY((result = SendExchangeInfo(*conn, exInfo)) != NN_OK)) { + NN_LOG_ERROR("Shm Failed to send channel exchange info to oob server " << oobIp << ":" << oobPort << + " in driver " << mName); + return NN_ERROR; + } + + /* send payload */ + if (NN_UNLIKELY((result = conn->Send(const_cast(payload.c_str()), payload.length())) != NN_OK)) { + NN_LOG_ERROR("Shm Failed to send payload to peer at " << peerIpPort << " in driver " << mName); + return result; + } + + /* receive exchange info */ + NN_LOG_TRACE_INFO("Shm Try to receive exchange info from peer, " << sizeof(ShmConnExchangeInfo)); + if (NN_UNLIKELY((result = ReceiveExchangeInfo(*conn, exInfo)) != NN_OK)) { + return result; + } + + /* change to ready */ + if ((result = ch->ChangeToReady(exInfo)) != NN_OK) { + return result; + } + + /* receive ready signal */ + int8_t ready = -1; + tmpBuf = static_cast(&ready); + result = conn->Receive(tmpBuf, sizeof(int8_t)); + if (result != NN_OK || ready != 1) { + NN_LOG_ERROR("Failed to receive ready from " << peerIpPort << " in Driver " << mName << ", Result " << result); + return result; + } + + /* create ep */ + const UBSHcomNetWorkerIndex netWorkerIndex {}; + UBSHcomNetEndpointPtr newEp = new (std::nothrow) + NetSyncEndpointShm(ch->Id(), ch.Get(), this, netWorkerIndex, shmEp.Get(), ShmMRHandleMap::GetInstance()); + if (NN_UNLIKELY(newEp.Get() == nullptr)) { + NN_LOG_ERROR("Failed to new async shm ep in driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + if (mEnableTls) { + auto chiEp = newEp.ToChild(); + auto tmp = dynamic_cast(conn); + if (NN_UNLIKELY(chiEp == nullptr || tmp == nullptr)) { + NN_LOG_ERROR("dynamic cast error"); + return NN_OOB_SEC_PROCESS_ERROR; + } + chiEp->EnableEncrypt(mOptions); + chiEp->SetSecrets(tmp->Secret()); + } + + /* 1 transfer fd, 2 set upCtx, 3 set payload, 4 add ep into map, 5 set state */ + ch->UdsFD(conn->TransferFd()); + ch->UpContext(reinterpret_cast(newEp.Get())); + newEp->StoreConnInfo(NetFunc::GetIpByFd(ch->UdsFD()), conn->ListenPort(), header.version, payload); + AddEp(newEp); + newEp->State().Set(NEP_ESTABLISHED); + + outEp.Set(newEp.Get()); + if (mChannelKeeper == nullptr) { + NN_LOG_INFO("New connection failed as mChannelKeeper is null"); + return NN_ERROR; + } + if ((result = mChannelKeeper->AddShmChannel(ch)) != NN_OK) { + NN_LOG_ERROR("Adding Shm Channel failed, result: " << result); + return result; + } + + NN_LOG_INFO("New connection to " << oobIp << ":" << oobPort << " established, sync ep id " << outEp->Id()); + return result; +} + +#define VALIDATE_DRIVER_INIT() \ + if (NN_UNLIKELY(!mInited.load())) { \ + NN_LOG_ERROR("Driver " << mName << " is not initialized"); \ + return NN_NOT_INITIALIZED; \ + } + +#define VALIDATE_PAYLOAD(payloadSize) \ + if ((payloadSize) > NN_NO1024) { \ + NN_LOG_ERROR("Failed to connect server via payload size " << (payloadSize) << " over limit"); \ + return NN_INVALID_PARAM; \ + } + +#define VALIDATE_OOBTYPE() \ + if (mOptions.oobType == NET_OOB_TCP) { \ + NN_LOG_WARN("The current oobType is not supported"); \ + return NN_INVALID_PARAM; \ + } + +NResult NetDriverShmWithOOB::CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr) +{ + if (NN_UNLIKELY(size == 0 || size > NN_NO107374182400)) { + NN_LOG_ERROR("Failed to create mem region as size is 0 or greater than 100 GB"); + return NN_INVALID_PARAM; + } + + if (!mInited) { + NN_LOG_ERROR("Failed to create Memory region in NetDriverShm " << mName << ", as not initialized"); + return NN_EP_NOT_INITIALIZED; + } + + ShmMemoryRegion *tmp = nullptr; + auto result = ShmMemoryRegion::Create(mName, size, tmp); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("Failed to create Memory region in NetDriverShm " << mName << ", probably out of memory"); + return result; + } + + if ((result = tmp->Initialize()) != NN_OK) { + delete tmp; + return result; + } + + if ((result = mMrChecker.Register(tmp->GetLKey(), tmp->GetAddress(), size)) != NN_OK) { + NN_LOG_ERROR("Failed to add memory region to range checker in driver" << mName << " for duplicate keys"); + delete tmp; + return result; + } + + // Prevent integer truncation, safely converts uint64_t to uint32_t + if (NN_UNLIKELY(tmp->mLKey > UINT32_MAX)) { + NN_LOG_ERROR("Failed to create Memory region in NetDriverShm as lKey is larger than uint32max, lkey" << + tmp->mLKey); + delete tmp; + return NN_INVALID_PARAM; + } + + ShmMRHandleMap::GetInstance().AddToLocalMap(static_cast(tmp->mLKey), tmp->GetMrHandle()); + + mr.Set(static_cast(tmp)); + + return NN_OK; +} +NResult NetDriverShmWithOOB::CreateMemoryRegion(uintptr_t address, uint64_t size, UBSHcomNetMemoryRegionPtr &mr) +{ + NN_LOG_WARN("Invalid operation, create memoryRegion is not supported by NetDriverShmWithOOB"); + return NN_INVALID_OPERATION; +} + +NResult NetDriverShmWithOOB::CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr, unsigned long memid) +{ + NN_LOG_ERROR("operation is not supported in shm"); + return NN_ERROR; +} + +NResult NetDriverShmWithOOB::Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, + uint8_t serverGrpNo, uint8_t clientGrpNo) +{ + if (mOptions.oobType == NET_OOB_TCP) { + NN_LOG_WARN("The current oobType is not supported"); + return NN_INVALID_PARAM; + } else if (mOptions.oobType == NET_OOB_UDS) { + return Connect(mUdsName, 0, payload, ep, flags, serverGrpNo, clientGrpNo, 0); + } + return NN_ERROR; +} + +NResult NetDriverShmWithOOB::SendExchangeInfo(OOBTCPConnection &conn, ShmConnExchangeInfo &exInfo) +{ + // create iov for general exchange message + struct iovec iov = { + .iov_base = &exInfo, + .iov_len = sizeof(ShmConnExchangeInfo) + }; + // fds, event queue fd and share mem fd + int fds[NN_NO2]; + fds[0] = exInfo.queueFd; + fds[1] = exInfo.channelFd; + char buf[CMSG_SPACE(sizeof(fds))]; + bzero(buf, sizeof(buf)); + + struct msghdr msg {}; + msg.msg_iov = &iov; + msg.msg_iovlen = NN_NO1; + msg.msg_control = buf; + msg.msg_controllen = sizeof(buf); + + struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); + if (NN_UNLIKELY(cmsg == nullptr)) { + NN_LOG_ERROR("CMSG_FIRSTHDR get empty msg"); + return NN_ERROR; + } + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + cmsg->cmsg_len = CMSG_LEN(sizeof(fds)); + + if (NN_UNLIKELY(memcpy_s((char *)CMSG_DATA(cmsg), sizeof(fds), fds, sizeof(fds)) != NN_OK)) { + NN_LOG_ERROR("Failed to copy fds to cmsg"); + return NN_INVALID_PARAM; + } + + return conn.SendMsg(msg, sizeof(ShmConnExchangeInfo)); +} + +NResult NetDriverShmWithOOB::ReceiveExchangeInfo(OOBTCPConnection &conn, ShmConnExchangeInfo &exInfo) +{ + // create iov for general exchange message + struct iovec iov = { + .iov_base = &exInfo, + .iov_len = sizeof(ShmConnExchangeInfo) + }; + + // fds, event queue fd and share mem fd + int fds[NN_NO2]; + fds[0] = -1; + fds[1] = -1; + char buf[CMSG_SPACE(sizeof(fds))]; + bzero(buf, sizeof(buf)); + + struct msghdr msg {}; + msg.msg_iov = &iov; + msg.msg_iovlen = NN_NO1; + msg.msg_control = buf; + msg.msg_controllen = sizeof(buf); + + auto result = conn.ReceiveMsg(msg, sizeof(ShmConnExchangeInfo)); + if (result != NN_OK) { + return result; + } + + struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); + if (NN_UNLIKELY(cmsg == nullptr)) { + NN_LOG_ERROR("CMSG_FIRSTHDR get empty msg"); + return NN_ERROR; + } + if (NN_UNLIKELY(memcpy_s(fds, sizeof(fds), (char *)CMSG_DATA(cmsg), sizeof(fds)) != NN_OK)) { + NN_LOG_ERROR("Failed to copy cmsg to fds"); + return NN_INVALID_PARAM; + } + exInfo.queueFd = fds[0]; + exInfo.channelFd = fds[1]; + + return NN_OK; +} + +NResult NetDriverShmWithOOB::HandleNewOobConn(OOBTCPConnection &conn) +{ + if (NN_UNLIKELY(OOBSecureProcess::SecProcessInOOBServer(mSecInfoProvider, mSecInfoValidator, conn, mName, + mOptions.secType)) != NN_OK) { + return NN_OOB_SEC_PROCESS_ERROR; + } + + if (NN_UNLIKELY(OOBSecureProcess::SecProcessCompareEpNum(conn.GetUdsName(), conn.GetIpAndPort(), + mOobServers)) != NN_OK) { + NN_LOG_ERROR("Shm connection num exceeds maximum"); + return NN_OOB_SEC_PROCESS_ERROR; + } + + NResult result = NN_OK; + const auto &peerIpPort = conn.GetIpAndPort(); + /* receive header and verify */ + ConnectHeader header {}; + void *tmpBuf = &header; + if (NN_UNLIKELY((result = conn.Receive(tmpBuf, sizeof(ConnectHeader))) != 0)) { + NN_LOG_ERROR("OOB from " << peerIpPort << " dropped as read data or invalid data in driver " << mName << + ", result " << result); + return result; + } + + ConnRespWithUId respWithUId{ OK, 0 }; + result = OOBSecureProcess::SecCheckConnectionHeader(header, mOptions, mEnableTls, Protocol(), mMajorVersion, + mMinorVersion, respWithUId); + if (result != NN_OK) { + conn.Send(&respWithUId, sizeof(ConnRespWithUId)); + return NN_ERROR; + } + + uint64_t newId = NetUuid::GenerateUuid(); + NN_LOG_TRACE_INFO("new ep id will be set as " << newId << " in driver " << mName); + + respWithUId.connResp = OK; + respWithUId.epId = newId; + if (NN_UNLIKELY((result = conn.Send(&respWithUId, sizeof(ConnRespWithUId))) != NN_OK)) { + NN_LOG_ERROR("Failed to send resp to " << peerIpPort << " in driver " << mName << ", result " << result); + return result; + } + + ShmConnExchangeInfo peerExInfo {}; /* fill exchange info */ + if (NN_UNLIKELY((result = ReceiveExchangeInfo(conn, peerExInfo)) != NN_OK)) { + NN_LOG_ERROR("Failed to read ex from " << peerIpPort << " in driver " << mName << ", result " << result); + return result; + } + + if (NN_UNLIKELY(peerExInfo.payLoadSize > NN_NO1024)) { + NN_LOG_ERROR("OOB from " << peerIpPort << " dropped as payload is too big in driver " << mName); + return NN_INVALID_PARAM; + } + + /* choose worker */ + uint16_t workerIndex = 0; + if (NN_UNLIKELY(!mClientLb->ChooseWorker(header.groupIndex, std::to_string(newId), workerIndex)) || + workerIndex >= mWorkers.size()) { + NN_LOG_ERROR("OOB from " << peerIpPort << " dropped as invalid group index in driver " << mName); + return NN_ERROR; + } + + NN_LOG_TRACE_INFO("Worker " << workerIndex << " is chosen in driver " << mName); + + auto worker = mWorkers[workerIndex]; + NN_ASSERT_LOG_RETURN(worker != nullptr, NN_ERROR) + + /* create shm and init channel */ + ShmChannelPtr ch; + result = ShmChannel::CreateAndInit(mName, newId, mOptions.mrSendReceiveSegSize, mOptions.qpSendQueueSize, ch); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("OOB from " << peerIpPort << " dropped as create channel failure in driver " << mName); + return result; + } + + /* fill exchange info */ + ShmConnExchangeInfo exInfo {}; + NN_ASSERT_LOG_RETURN(worker->FillQueueExchangeInfo(exInfo), NN_ERROR) + NN_ASSERT_LOG_RETURN(ch->FillExchangeInfo(exInfo), NN_ERROR) + + /* send exchange info */ + if (NN_UNLIKELY((result = SendExchangeInfo(conn, exInfo)) != NN_OK)) { + NN_LOG_ERROR("Failed to send ex to OOB from " << peerIpPort << " in driver " << mName << ", result " << result); + return result; + } + + if (NN_UNLIKELY((result = ch->ChangeToReady(peerExInfo)) != NN_OK)) { + NN_LOG_ERROR("OOB from " << peerIpPort << " dropped as failed to change channel to ready in driver " << mName); + return result; + } + + /* receive payload if needed */ + char payChars[NN_NO1024 + NN_NO1] {}; + if (peerExInfo.payLoadSize != 0) { + tmpBuf = &payChars; + if (NN_UNLIKELY((result = conn.Receive(tmpBuf, peerExInfo.payLoadSize)) != 0)) { + NN_LOG_ERROR("Failed to read payload from " << peerIpPort << " in driver " << mName << ", result " << + result); + return result; + } + } + payChars[NN_NO1024] = '\0'; + + /* create ep */ + UBSHcomNetEndpointPtr newEp = new (std::nothrow) + NetAsyncEndpointShm(ch->Id(), ch.Get(), worker, this, worker->Index(), ShmMRHandleMap::GetInstance()); + if (NN_UNLIKELY(newEp.Get() == nullptr)) { + NN_LOG_ERROR("Failed to new async shm ep in driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + struct ucred remoteIds {}; + socklen_t len = sizeof(struct ucred); + if (NN_UNLIKELY(getsockopt(conn.GetFd(), SOL_SOCKET, SO_PEERCRED, &remoteIds, &len) != 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to get uds ids in driver " << mName << " errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return NN_GET_UDS_ID_INFO_FAILED; + } + newEp->RemoteUdsIdInfo(remoteIds.pid, remoteIds.uid, remoteIds.gid); + + if (mEnableTls) { + auto childEp = newEp.ToChild(); + auto tmp = dynamic_cast(&conn); + if (NN_UNLIKELY(childEp == nullptr || tmp == nullptr)) { + NN_LOG_ERROR("dynamic cast error"); + return NN_OOB_SEC_PROCESS_ERROR; + } + childEp->EnableEncrypt(mOptions); + childEp->SetSecrets(tmp->Secret()); + } + + std::string payload = std::string(payChars, peerExInfo.payLoadSize); + /* call user handler new endpoint handler */ + if (NN_UNLIKELY((result = mNewEndPointHandler(peerIpPort, newEp, payload)) != NN_OK)) { + NN_LOG_ERROR("Calling new endpoint handler failed in driver " << mName << ", result " << result); + return result; + } + + ch->UpContext1(reinterpret_cast(worker)); + ch->UpContext(reinterpret_cast(newEp.Get())); + ch->UdsFD(conn.GetFd()); + ch->PeerIpAndPort(conn.GetIpAndPort()); + ch->UdsName(conn.GetUdsName()); + newEp->StoreConnInfo(NetFunc::GetIpByFd(ch->UdsFD()), conn.ListenPort(), header.version, payload); + newEp->State().Set(NEP_ESTABLISHED); + + /* send ready signal to oob */ + int8_t ready = 1; + if (NN_UNLIKELY((result = conn.Send(&ready, sizeof(int8_t))) != NN_OK)) { + NN_LOG_ERROR("Failed to send ready to " << peerIpPort << " in driver " << mName << ", result " << result); + return NN_ERROR; + } + + /* 1 transfer fd, 2 add ep into map */ + conn.TransferFd(); + AddEp(newEp); + NN_ASSERT_LOG_RETURN(mChannelKeeper != nullptr, NN_ERROR); + if ((result = mChannelKeeper->AddShmChannel(ch)) != NN_OK) { + NN_LOG_ERROR("Adding Shm Channel failed, result: " << result); + return result; + } + + OOBSecureProcess::SecProcessAddEpNum(conn.GetUdsName(), conn.GetIpAndPort(), mOobServers); + NN_LOG_INFO("New connection from " << peerIpPort << " established, async ep id " << newEp->Id() << + " worker info " << worker->Name()); + + return NN_OK; +} + +NResult NetDriverShmWithOOB::Connect(const std::string &serverUrl, const std::string &payload, + UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) +{ + VALIDATE_DRIVER_INIT() + VALIDATE_PAYLOAD(payload.size()) + + NetDriverOobType type; + std::string ip; + uint16_t port = 0; + if (NN_UNLIKELY(NetFunc::NN_ValidateUrl(serverUrl) != NN_OK)) { + NN_LOG_ERROR("Invalid url"); + return NN_PARAM_INVALID; + } + if (NN_UNLIKELY(ParseUrl(serverUrl, type, ip, port) != NN_OK)) { + NN_LOG_WARN("Invalid url, url:" << serverUrl); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(NetDriverOobType::NET_OOB_UDS != type)) { + NN_LOG_WARN("The current oobType is not supported, url:" << serverUrl); + return NN_INVALID_PARAM; + } + return Connect(ip, port, payload, ep, flags, serverGrpNo, clientGrpNo, ctx); +} + +NResult NetDriverShmWithOOB::Connect(const std::string &oobIp, uint16_t oobPort, const std::string &payload, + UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) +{ + VALIDATE_DRIVER_INIT() + VALIDATE_PAYLOAD(payload.size()) + + if (NN_UNLIKELY(!mStarted)) { + NN_LOG_ERROR("Failed to connect on driver " << mName << " as it is not started"); + return NN_ERROR; + } + + if (flags & NET_EP_SELF_POLLING) { + return ConnectSyncEp(oobIp, oobPort, payload, ep, serverGrpNo, ctx); + } + + if (NN_UNLIKELY(clientGrpNo >= mWorkerGroups.size())) { + NN_LOG_ERROR("Invalid clientGrpNo " << clientGrpNo << " as it is large than existed groups"); + return NN_ERROR; + } + + NResult result = NN_OK; + OOBTCPClientPtr client; + if (mEnableTls) { + auto oobSSLClient = new OOBSSLClient(mOptions.oobType, oobIp, oobPort, + mTlsPrivateKeyCB, mTlsCertCB, mTlsCaCallback); + NN_ASSERT_LOG_RETURN(oobSSLClient != nullptr, NN_NEW_OBJECT_FAILED) + oobSSLClient->SetTlsOptions(mOptions); + oobSSLClient->SetPSKCallback(mPskFindSessionCb, mPskUseSessionCb); + client = oobSSLClient; + } else { + client = new OOBTCPClient(NET_OOB_UDS, oobIp, oobPort); + NN_ASSERT_LOG_RETURN(client.Get() != nullptr, NN_NEW_OBJECT_FAILED) + } + + /* try to connect to oob server */ + OOBTCPConnection *conn = nullptr; + if ((result = client->Connect(conn)) != 0) { + NN_LOG_ERROR("Failed to connect server via oob, result " << result); + return result; + } + + const auto &peerIpPort = conn->GetIpAndPort(); + NetLocalAutoDecreasePtr autoDecPtr(conn); + conn->SetIpAndPort(oobIp, oobPort); + + if (NN_UNLIKELY(OOBSecureProcess::SecProcessInOOBClient(mSecInfoProvider, mSecInfoValidator, conn, mName, ctx, + mOptions.secType))) { + return NN_OOB_SEC_PROCESS_ERROR; + } + + /* send connection header */ + ConnectHeader header {}; + SetConnHeader(header, mOptions.magic, mOptions.version, serverGrpNo, Protocol(), mMajorVersion, + mMinorVersion, mOptions.tlsVersion); + if (NN_UNLIKELY((result = conn->Send(&header, sizeof(ConnectHeader))) != NN_OK)) { + NN_LOG_ERROR("Failed to send conn header to oob server " << oobIp << ":" << oobPort << " in driver " << mName); + return NN_ERROR; + } + + /* receive connect response and peer ep id */ + ConnRespWithUId respWithUId {}; + void *tmpBuf = &respWithUId; + if (NN_UNLIKELY((result = conn->Receive(tmpBuf, sizeof(ConnRespWithUId))) != NN_OK)) { + return result; + } + + /* connect response */ + auto resp = respWithUId.connResp; + if (NN_UNLIKELY(resp != OK)) { + NN_LOG_ERROR("Shm Failed to pass server validation in driver " << mName << ", result " << resp); + return NN_CONNECT_REFUSED; + } + + /* peer ep id */ + auto newId = respWithUId.epId; + NN_LOG_TRACE_INFO("new ep id will be set as " << newId << " in driver " << mName); + + /* create shm and init channel */ + ShmChannelPtr ch; + result = ShmChannel::CreateAndInit(mName, newId, mOptions.mrSendReceiveSegSize, mOptions.qpSendQueueSize, ch); + if (NN_UNLIKELY(result != NN_OK)) { + return result; + } + + /* choose worker */ + uint16_t workerIndex = 0; + if (NN_UNLIKELY(!mClientLb->ChooseWorker(clientGrpNo, std::to_string(newId), workerIndex)) || + workerIndex >= mWorkers.size()) { + NN_LOG_ERROR("Failed to choose worker during connect in driver " << mName); + return NN_ERROR; + } + + NN_LOG_TRACE_INFO("Worker " << workerIndex << " is chosen in driver " << mName); + + auto worker = mWorkers[workerIndex]; + NN_ASSERT_LOG_RETURN(worker != nullptr, NN_ERROR) + + /* fill exchange info */ + ShmConnExchangeInfo exInfo {}; + NN_ASSERT_LOG_RETURN(worker->FillQueueExchangeInfo(exInfo), NN_ERROR) + NN_ASSERT_LOG_RETURN(ch->FillExchangeInfo(exInfo), NN_ERROR) + exInfo.payLoadSize = payload.length(); + + /* send exchange info */ + if (NN_UNLIKELY((result = SendExchangeInfo(*conn, exInfo)) != NN_OK)) { + NN_LOG_ERROR("Failed to send channel exchange info to oob server " << oobIp << ":" << oobPort << + " in driver " << mName); + return NN_ERROR; + } + + /* send payload */ + if (NN_UNLIKELY((result = conn->Send(const_cast(payload.c_str()), payload.length())) != NN_OK)) { + NN_LOG_ERROR("Failed to send payload to peer at " << peerIpPort << " in driver " << mName); + return result; + } + + /* receive exchange info */ + NN_LOG_INFO("Try to receive exchange info from peer, " << sizeof(ShmConnExchangeInfo)); + if (NN_UNLIKELY((result = ReceiveExchangeInfo(*conn, exInfo)) != NN_OK)) { + return result; + } + + /* change to ready */ + if ((result = ch->ChangeToReady(exInfo)) != NN_OK) { + return result; + } + + /* receive ready signal */ + int8_t ready = -1; + tmpBuf = static_cast(&ready); + result = conn->Receive(tmpBuf, sizeof(int8_t)); + if (result != NN_OK || ready != 1) { + NN_LOG_ERROR("Failed to receive ready from " << peerIpPort << " in driver " << mName << ", result " << result); + return result; + } + + /* create ep */ + UBSHcomNetEndpointPtr newEp = new (std::nothrow) + NetAsyncEndpointShm(ch->Id(), ch.Get(), worker, this, worker->Index(), ShmMRHandleMap::GetInstance()); + if (NN_UNLIKELY(newEp.Get() == nullptr)) { + NN_LOG_ERROR("Failed to new async shm ep in driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + if (mEnableTls) { + auto childEp = newEp.ToChild(); + auto tmp = dynamic_cast(conn); + if (NN_UNLIKELY(childEp == nullptr || tmp == nullptr)) { + NN_LOG_ERROR("dynamic cast error"); + return NN_OOB_SEC_PROCESS_ERROR; + } + childEp->EnableEncrypt(mOptions); + childEp->SetSecrets(tmp->Secret()); + } + + /* 1 transfer fd, 2 set upCtx, 3 set payload, 4 add ep into map, 5 set state */ + ch->UdsFD(conn->TransferFd()); + ch->UpContext1(reinterpret_cast(worker)); + ch->UpContext(reinterpret_cast(newEp.Get())); + newEp->StoreConnInfo(NetFunc::GetIpByFd(ch->UdsFD()), conn->ListenPort(), header.version, payload); + AddEp(newEp); + newEp->State().Set(NEP_ESTABLISHED); + + ep.Set(newEp.Get()); + + NN_ASSERT_LOG_RETURN(mChannelKeeper != nullptr, NN_ERROR); + if ((result = mChannelKeeper->AddShmChannel(ch)) != NN_OK) { + NN_LOG_ERROR("Adding Shm Channel failed, result: " << result); + return result; + } + + NN_LOG_INFO("New connection to " << oobIp << ":" << oobPort << " established, async ep id " << ep->Id() << + " worker info " << worker->Name()); + return NN_OK; +} + +NResult NetDriverShmWithOOB::MultiRailNewConnection(OOBTCPConnection &conn) +{ + NN_LOG_ERROR("Invalid operation, SHM is not supported by MultiRail"); + return NN_ERROR; +} + +void NetDriverShmWithOOB::DestroyEndpoint(UBSHcomNetEndpointPtr &ep) +{ + if (NN_UNLIKELY(ep.Get() == nullptr)) { + NN_LOG_WARN("The shm ep is null already."); + return; + } + + NN_LOG_INFO("Destroy endpoint id " << ep->Id()); + if (NN_LIKELY(mDelayReleaseTimer != nullptr)) { + mDelayReleaseTimer->EnqueueDelayRelease(ep); + } + + auto result = Remove(ep->Id()); + if (result == 0) { + NN_LOG_WARN("Unable to destroy shm endpoint as ep " << ep->Id() << " doesn't exist, maybe cleaned already"); + return; + } + + ep.Set(nullptr); +} + +void NetDriverShmWithOOB::DestroyMemoryRegion(UBSHcomNetMemoryRegionPtr &mr) +{ + if (NN_UNLIKELY(mr.Get() == nullptr)) { + NN_LOG_WARN("Try to destroy null memory region in shm driver " << mName); + return; + } + if (!mMrChecker.Contains(mr->GetLKey())) { + NN_LOG_WARN("Try to destroy unowned memory region in shm driver " << mName); + return; + } + mMrChecker.UnRegister(mr->GetLKey()); + mr->UnInitialize(); +} + +void *NetDriverShmWithOOB::MapAndRegVaForUB(unsigned long memid, uint64_t &va) +{ + NN_LOG_ERROR("operation is not supported in shm"); + return nullptr; +} + +NResult NetDriverShmWithOOB::UnmapVaForUB(uint64_t &va) +{ + NN_LOG_ERROR("operation is not supported in shm"); + return NN_ERROR; +} +} +} \ No newline at end of file diff --git a/src/transport/shm/net_shm_driver_oob.h b/src/transport/shm/net_shm_driver_oob.h new file mode 100644 index 0000000000000000000000000000000000000000..0c90a644ac82d5d04c1972f2e8c2c9ccf3ddcbb5 --- /dev/null +++ b/src/transport/shm/net_shm_driver_oob.h @@ -0,0 +1,150 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_SHM_DRIVER_OOB_H +#define OCK_HCOM_NET_SHM_DRIVER_OOB_H + +#include "hcom.h" + +#include "net_common.h" +#include "net_delay_release_timer.h" +#include "net_oob.h" +#include "net_shm_common.h" +#include "securec.h" +#include "shm_channel_keeper.h" +#include "shm_handle.h" +#include "shm_mr_pool.h" + +namespace ock { +namespace hcom { +class NetDriverShmWithOOB : public UBSHcomNetDriver { +public: + NetDriverShmWithOOB(const std::string &name, bool startOob, UBSHcomNetDriverProtocol protocol) + : UBSHcomNetDriver(name, startOob, protocol) + { + OBJ_GC_INCREASE(NetDriverShmWithOOB); + } + + ~NetDriverShmWithOOB() override + { + OBJ_GC_DECREASE(NetDriverShmWithOOB); + } + + NResult Initialize(const UBSHcomNetDriverOptions &option) override; + + void UnInitialize() override; + + NResult Start() override; + void Stop() override; + + NResult CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr) override; + NResult CreateMemoryRegion(uintptr_t address, uint64_t size, UBSHcomNetMemoryRegionPtr &mr) override; + NResult CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr, unsigned long memid) override; + + void DestroyMemoryRegion(UBSHcomNetMemoryRegionPtr &mr) override; + + NResult Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo, + uint8_t clientGrpNo) override; + + NResult Connect(const std::string &oobIp, uint16_t oobPort, const std::string &payload, UBSHcomNetEndpointPtr &ep, + uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) override; + + NResult Connect(const std::string &serverUrl, const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, + uint8_t serverGrpNo = 0, uint8_t clientGrpNo = 0, uint64_t ctx = 0) override; + + NResult MultiRailNewConnection(OOBTCPConnection &conn); + + void DestroyEndpoint(UBSHcomNetEndpointPtr &ep) override; + + NResult SendExchangeInfo(OOBTCPConnection &conn, ShmConnExchangeInfo &exInfo); + NResult ReceiveExchangeInfo(OOBTCPConnection &conn, ShmConnExchangeInfo &exInfo); + + void *MapAndRegVaForUB(unsigned long memid, uint64_t &va) override; + + NResult UnmapVaForUB(uint64_t &va) override; + + inline NResult ValidateMemoryRegion(uint64_t lKey, uintptr_t address, uint64_t size) + { + return mMrChecker.Validate(lKey, address, size); + } + + inline UBSHcomNetDriverOptions GetOptions() + { + return mOptions; + } + +protected: + NResult ValidateOptions(); + NResult CreateWorkerResource(); + NResult CreateWorkers(); + void ClearWorkers(); + void UnInitializeInner(); + void StopInner(); + + NResult HandleNewOobConn(OOBTCPConnection &conn); + NResult HandleNewRequest(ShmOpContextInfo &ctx, uint32_t immData); + NResult HandleReqPosted(ShmOpCompInfo &ctx); + NResult OneSideDone(ShmOpContextInfo *ctx); + + void HandleChanelKeeperMsg(const ShmChKeeperMsgHeader &header, const ShmChannelPtr &channelPtr); + void ProcessEpError(const ShmChannelPtr &channelPtr); + + NResult ConnectSyncEp(const std::string &oobIp, uint16_t oobPort, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint8_t serverGrpNo, uint64_t ctx); + + inline void AddEp(const UBSHcomNetEndpointPtr &newEp) + { + /* added into map */ + if (NN_LIKELY(newEp != nullptr)) { + std::lock_guard guard(mEndPointsMutex); + mEndPoints.emplace(newEp->Id(), newEp); + } + } + + inline bool Remove(uint64_t id) + { + std::lock_guard guard(mEndPointsMutex); + return (mEndPoints.erase(id) > 0); + } + + inline const std::string &ChooseListenIp() + { + if (NN_UNLIKELY(mFilteredIps.empty())) { + return CONST_EMPTY_STRING; + } + + return mFilteredIps[0]; + } + + void ClearShmLeftFile(); + + void HandleKeeperMsgGetMrFd(const ShmChKeeperMsgHeader &header, const ShmChannelPtr &channelPtr); + +protected: + std::vector mWorkers; + std::vector mFilteredIps; + NetMemPoolFixedPtr mOpCompMemPool = nullptr; + NetMemPoolFixedPtr mOpCtxMemPool = nullptr; + NetMemPoolFixedPtr mSglCompMemPool = nullptr; + ShmChannelKeeperPtr mChannelKeeper = nullptr; + DelayReleaseTimerPtr mDelayReleaseTimer = nullptr; + MemoryRegionChecker mMrChecker; + std::thread mClearThread; + std::atomic_bool mClearThreadStarted { false }; + +private: + friend class NetAsyncEndpointShm; + friend class NetSyncEndpointShm; +}; +} +} + +#endif // OCK_HCOM_NET_SHM_DRIVER_OOB_H diff --git a/src/transport/shm/net_shm_sync_endpoint.cpp b/src/transport/shm/net_shm_sync_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0139967e5fd004a81018ef358872f8a9a5314042 --- /dev/null +++ b/src/transport/shm/net_shm_sync_endpoint.cpp @@ -0,0 +1,711 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "shm_validation.h" +#include "hcom_log.h" +#include "net_shm_sync_endpoint.h" + +namespace ock { +namespace hcom { +NetSyncEndpointShm::NetSyncEndpointShm(uint64_t id, ShmChannel *ch, NetDriverShmWithOOB *driver, + const UBSHcomNetWorkerIndex &workerIndex, ShmSyncEndpoint *shmEp, ShmMRHandleMap &handleMap) + : NetEndpointImpl(id, workerIndex), mShmCh(ch), mDriver(driver), mShmEp(shmEp), mrHandleMap(handleMap) +{ + if (mShmCh != nullptr) { + mShmCh->IncreaseRef(); + } + if (mShmEp != nullptr) { + mShmEp->IncreaseRef(); + } + + if (mDriver != nullptr) { + mDriver->IncreaseRef(); + } + + if (mShmCh != nullptr && mDriver != nullptr) { + mSegSize = mDriver->GetOptions().mrSendReceiveSegSize; + mAllowedSize = mSegSize - sizeof(UBSHcomNetTransHeader); + } + + OBJ_GC_INCREASE(NetSyncEndpointShm); +} + +NetSyncEndpointShm::~NetSyncEndpointShm() +{ + if (mShmCh != nullptr) { + mShmCh->DecreaseRef(); + } + if (mShmEp != nullptr) { + mShmEp->DecreaseRef(); + } + + if (mDriver != nullptr) { + mDriver->DecreaseRef(); + } + + OBJ_GC_DECREASE(NetSyncEndpointShm); +} + +NResult NetSyncEndpointShm::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNO) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendValidation(mState, mId, opCode, request)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to sync post send as validate fail"); + return result; + } + + if (NN_UNLIKELY((result = PostSendValidationMaxSize(request, mAllowedSize, mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to sync post send as validate size fail"); + return result; + } + + /* get free buffer from channel */ + uintptr_t address = 0; + uint64_t offset = 0; + result = mShmCh->DCGetFreeBuck(address, offset, NN_NO100, mDefaultTimeout); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("Shm Failed to get free buck from Shm Channel " << mShmCh->Id() << ", result " << result); + return result; + } + + /* copy header */ + auto *header = reinterpret_cast(address); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->opCode = opCode; + header->seqNo = seqNO == 0 ? NextSeq() : seqNO; + header->flags = NTH_TWO_SIDE; + + mLastSendSeqNo = header->seqNo; + + /* copy message */ + if (mIsNeedEncrypt) { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, reinterpret_cast(request.lAddress), request.size, + reinterpret_cast(address + sizeof(UBSHcomNetTransHeader)), cipherLen)) { + NN_LOG_ERROR("Shm Failed to post send message as encryption failed"); + mShmCh->DCMarkBuckFree(address); + return NN_ENCRYPT_FAILED; + } + header->dataLength = cipherLen; + } else { + header->dataLength = request.size; + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(address + sizeof(UBSHcomNetTransHeader)), + mShmCh->GetSendDCBuckSize() - sizeof(UBSHcomNetTransHeader), + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mShmCh->DCMarkBuckFree(address); + NN_LOG_ERROR("Failed to copy request to address"); + return NN_INVALID_PARAM; + } + } + + /* finally fill header crc */ + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + UBSHcomNetTransRequest innerReq = request; + innerReq.size = sizeof(UBSHcomNetTransHeader) + header->dataLength; + innerReq.lAddress = address; + + uint64_t finishTime = GetFinishTime(); + + bool flag = true; + TRACE_DELAY_BEGIN(SHM_EP_SYNC_POST_SEND); + do { + result = mShmEp->PostSend(mShmCh, innerReq, offset, 0, mDefaultTimeout); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_SYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post send request, result: " << result); + mShmCh->DCMarkBuckFree(address); + TRACE_DELAY_END(SHM_EP_SYNC_POST_SEND, result); + return result; +} + +NResult NetSyncEndpointShm::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendValidation(mState, mId, opCode, request)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to sync post send as validation fail"); + return result; + } + + if (NN_UNLIKELY((result = PostSendValidationMaxSize(request, mAllowedSize, mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to sync post send as validate size failed"); + return result; + } + + /* get free buffer from channel */ + uintptr_t address = 0; + uint64_t offset = 0; + result = mShmCh->DCGetFreeBuck(address, offset, NN_NO100, mDefaultTimeout); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("Failed to get free buck from shm channel " << mShmCh->Id() << ", result " << result); + return result; + } + + /* copy header */ + auto *header = reinterpret_cast(address); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->opCode = opCode; + header->seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + header->flags = ((uint16_t)opInfo.flags << NN_NO8) | (uint64_t)NTH_TWO_SIDE; + header->timeout = opInfo.timeout; + header->errorCode = opInfo.errorCode; + mLastSendSeqNo = header->seqNo; + + /* copy message */ + if (mIsNeedEncrypt) { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, reinterpret_cast(request.lAddress), request.size, + reinterpret_cast(address + sizeof(UBSHcomNetTransHeader)), cipherLen)) { + NN_LOG_ERROR("Failed to post send message as encryption failure"); + mShmCh->DCMarkBuckFree(address); + return NN_ENCRYPT_FAILED; + } + header->dataLength = cipherLen; + } else { + header->dataLength = request.size; + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(address + sizeof(UBSHcomNetTransHeader)), + mShmCh->GetSendDCBuckSize() - sizeof(UBSHcomNetTransHeader), + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mShmCh->DCMarkBuckFree(address); + NN_LOG_ERROR("Failed to copy request to address"); + return NN_INVALID_PARAM; + } + } + + /* finally fill header crc */ + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + + UBSHcomNetTransRequest innerReq = request; + innerReq.lAddress = address; + innerReq.size = sizeof(UBSHcomNetTransHeader) + header->dataLength; + + uint64_t finishTime = GetFinishTime(); + bool flag = true; + TRACE_DELAY_BEGIN(SHM_EP_SYNC_POST_SEND); + do { + result = mShmEp->PostSend(mShmCh, innerReq, offset, 0, mDefaultTimeout); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_SYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post send request, result " << result); + mShmCh->DCMarkBuckFree(address); + TRACE_DELAY_END(SHM_EP_SYNC_POST_SEND, result); + return result; +} + +NResult NetSyncEndpointShm::PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNO) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendRawValidation(mState, mId, request)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to sync post send raw as validate fail"); + return result; + } + + if (NN_UNLIKELY((result = PostSendValidationMaxSize(request, mSegSize, mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to sync post send raw as validate size fail"); + return result; + } + + /* get free buffer from channel */ + uintptr_t address = 0; + uint64_t offset = 0; + result = mShmCh->DCGetFreeBuck(address, offset, NN_NO100, mDefaultTimeout); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("Failed to get free buck from shm channel " << mShmCh->Id() << ", result " << result); + return result; + } + + UBSHcomNetTransRequest innerReq = request; + innerReq.lAddress = address; + mLastSendSeqNo = seqNO; + + /* copy message */ + if (mIsNeedEncrypt) { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, reinterpret_cast(request.lAddress), request.size, + reinterpret_cast(address), cipherLen)) { + NN_LOG_ERROR("Failed to post send message as encryption failure"); + mShmCh->DCMarkBuckFree(address); + return NN_ENCRYPT_FAILED; + } + innerReq.size = cipherLen; + } else { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(address), mShmCh->GetSendDCBuckSize(), + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + NN_LOG_ERROR("Failed to copy request to address"); + mShmCh->DCMarkBuckFree(address); + return NN_INVALID_PARAM; + } + innerReq.size = request.size; + } + + /* if result is timeout, need to retry */ + uint64_t finishTime = GetFinishTime(); + bool flag = true; + TRACE_DELAY_BEGIN(SHM_EP_SYNC_POST_SEND_RAW); + do { + result = mShmEp->PostSend(mShmCh, innerReq, offset, seqNO, mDefaultTimeout); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_SYNC_POST_SEND_RAW, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post send request, result " << result); + mShmCh->DCMarkBuckFree(address); + TRACE_DELAY_END(SHM_EP_SYNC_POST_SEND_RAW, result); + return result; +} + +NResult NetSyncEndpointShm::PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostSendSglValidation(mState, mId, mDriver, seqNo, request, mSegSize, + mIsNeedEncrypt, mAes)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to sync post send raw sgl as validate fail"); + return result; + } + + /* get free buffer from channel */ + uintptr_t address = 0; + uint64_t offset = 0; + result = mShmCh->DCGetFreeBuck(address, offset, NN_NO100, mDefaultTimeout); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("Failed to get free buck from shm channel " << mShmCh->Id() << ", result " << result); + return result; + } + + uint32_t dataLen = 0; + uint32_t iovOffset = 0; + + UBSHcomNetTransRequest innerReq = {}; + innerReq.lAddress = address; + mLastSendSeqNo = seqNo; + + /* copy message */ + if (mIsNeedEncrypt) { + for (uint16_t i = 0; i < request.iovCount; i++) { + dataLen += request.iov[i].size; + } + + UBSHcomNetMessage tmpMsg {}; + bool messageReady = tmpMsg.AllocateIfNeed(dataLen); + if (NN_UNLIKELY(!messageReady)) { + NN_LOG_ERROR("Failed to allocate net msg buffer failed"); + mShmCh->DCMarkBuckFree(address); + return NN_MALLOC_FAILED; + } + for (uint16_t i = 0; i < request.iovCount; i++) { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(reinterpret_cast(tmpMsg.mBuf) + iovOffset), + request.iov[i].size, reinterpret_cast(request.iov[i].lAddress), + request.iov[i].size) != NN_OK)) { + mShmCh->DCMarkBuckFree(address); + NN_LOG_WARN("Invalid operation to memcpy_s in shm encrypt PostSendRawSgl"); + return NN_ERROR; + } + iovOffset += request.iov[i].size; + } + + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, tmpMsg.mBuf, dataLen, reinterpret_cast(address), cipherLen)) { + NN_LOG_ERROR("Failed to post send message as encryption failure"); + mShmCh->DCMarkBuckFree(address); + return NN_ENCRYPT_FAILED; + } + + innerReq.size = cipherLen; + } else { + for (uint16_t i = 0; i < request.iovCount; i++) { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(address + iovOffset), request.iov[i].size, + reinterpret_cast(request.iov[i].lAddress), request.iov[i].size) != NN_OK)) { + mShmCh->DCMarkBuckFree(address); + NN_LOG_WARN("Invalid operation to memcpy_s in shm PostSendRawSgl"); + return NN_ERROR; + } + dataLen += request.iov[i].size; + iovOffset += request.iov[i].size; + } + innerReq.size = dataLen; + } + + uint64_t finishTime = GetFinishTime(); + bool flag = true; + TRACE_DELAY_BEGIN(SHM_EP_SYNC_POST_SEND_RAW_SGL); + do { + result = mShmEp->PostSendRawSgl(mShmCh, innerReq, request, offset, seqNo, mDefaultTimeout); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_SYNC_POST_SEND_RAW_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post send request, result " << result); + mShmCh->DCMarkBuckFree(address); + TRACE_DELAY_END(SHM_EP_SYNC_POST_SEND_RAW_SGL, result); + return result; +} + +NResult NetSyncEndpointShm::PostRead(const UBSHcomNetTransRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = ReadWriteValidation(mState, mId, mDriver, mShmCh, request)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to sync post read as validate fail"); + return result; + } + + auto flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(SHM_EP_SYNC_POST_READ); + do { + result = mShmEp->PostRead(mShmCh, request, mrHandleMap); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_SYNC_POST_READ, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + flag = false; + } while (flag); + + TRACE_DELAY_END(SHM_EP_SYNC_POST_READ, result); + return result; +} + +NResult NetSyncEndpointShm::PostRead(const UBSHcomNetTransSglRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostReadWriteSglValidation(mState, mId, mDriver, mShmCh, request)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to sync post read sgl as validate fail"); + return result; + } + + auto flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(SHM_EP_SYNC_POST_READ_SGL); + do { + result = mShmEp->PostRead(mShmCh, request, mrHandleMap); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_SYNC_POST_READ_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + flag = false; + } while (flag); + + TRACE_DELAY_END(SHM_EP_SYNC_POST_READ_SGL, result); + return result; +} + +NResult NetSyncEndpointShm::PostWrite(const UBSHcomNetTransRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = ReadWriteValidation(mState, mId, mDriver, mShmCh, request)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to sync post write as validate fail"); + return result; + } + + auto flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(SHM_EP_SYNC_POST_WRITE); + do { + result = mShmEp->PostWrite(mShmCh, request, mrHandleMap); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_SYNC_POST_WRITE, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + flag = false; + } while (flag); + + TRACE_DELAY_END(SHM_EP_SYNC_POST_WRITE, result); + return result; +} + +NResult NetSyncEndpointShm::PostWrite(const UBSHcomNetTransSglRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = PostReadWriteSglValidation(mState, mId, mDriver, mShmCh, request)) != NN_OK)) { + NN_LOG_ERROR("Shm failed to sync post write sgl as validate fail"); + return result; + } + + auto flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(SHM_EP_SYNC_POST_WRITE_SGL); + do { + result = mShmEp->PostWrite(mShmCh, request, mrHandleMap); + if (result == SH_OK) { + TRACE_DELAY_END(SHM_EP_SYNC_POST_WRITE_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + flag = false; + } while (flag); + + TRACE_DELAY_END(SHM_EP_SYNC_POST_WRITE_SGL, result); + return result; +} + +NResult NetSyncEndpointShm::Receive(int32_t timeout, UBSHcomNetResponseContext &ctx) +{ + ShmOpContextInfo opCtx {}; + NResult result = NN_OK; + mDemandPollingOpType = ShmOpContextInfo::SH_RECEIVE; + uint32_t immData = 0; + + if (NN_UNLIKELY(mExistDelayEvent)) { + mExistDelayEvent = false; + + auto *ch = reinterpret_cast(mDelayHandleReceiveEvent.peerChannelAddress); + if (NN_UNLIKELY(ch == nullptr)) { + NN_LOG_ERROR("Shm Got invalid event in " << mShmEp->GetName() << ", dropped it"); + return NN_ERROR; + } + + uintptr_t address = 0; + if (NN_UNLIKELY((result = ch->GetPeerDataAddressByOffset(mDelayHandleReceiveEvent.dataOffset, address)) != + SH_OK)) { + NN_LOG_ERROR("Shm Got invalid event " << mShmEp->GetName() << " as get data address failed, dropped it"); + return result; + } + + opCtx = ShmOpContextInfo(ch, address, mDelayHandleReceiveEvent.dataSize, + static_cast(mDelayHandleReceiveEvent.opType), + ShmOpContextInfo::ShmErrorType::SH_NO_ERROR); + } else if (NN_UNLIKELY((result = mShmEp->Receive(timeout, opCtx, immData)) != NN_OK)) { + NN_LOG_ERROR("Shm Failed to receive response from peer, result " << result); + return result; + } + + if (NN_UNLIKELY(opCtx.opType != mDemandPollingOpType)) { + NN_LOG_ERROR("Shm Got un-demand operation type " << opCtx.opType << ", ignored"); + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + return NN_ERROR; + } + + auto *tmpHeader = reinterpret_cast(opCtx.dataAddress); + result = NetFunc::ValidateHeaderWithSeqNo(*tmpHeader, opCtx.dataSize, mLastSendSeqNo); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("Shm Failed to validate received header param, ep " << Id()); + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + return result; + } + + size_t realDataSize = 0; + if (mIsNeedEncrypt) { + const void *cipherData = reinterpret_cast(opCtx.dataAddress + sizeof(UBSHcomNetTransHeader)); + realDataSize = mAes.GetRawLen(tmpHeader->dataLength); + uint32_t decryptLen = 0; + bool msgReady = mRespMessage.AllocateIfNeed(realDataSize); + if (NN_UNLIKELY(!msgReady)) { + NN_LOG_ERROR("Shm Failed to allocate memory for response size " << opCtx.dataSize << + ", probably out of memory"); + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + return NN_MALLOC_FAILED; + } + if (!mAes.Decrypt(mSecrets, cipherData, tmpHeader->dataLength, mRespMessage.mBuf, decryptLen)) { + NN_LOG_ERROR("Shm Failed to decrypt data"); + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + return NN_DECRYPT_FAILED; + } + } else { + realDataSize = tmpHeader->dataLength; + auto msgReady = mRespMessage.AllocateIfNeed(realDataSize); + if (NN_UNLIKELY(!msgReady)) { + NN_LOG_ERROR("Failed to allocate memory for response size " << realDataSize << ", probably out of memory"); + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + return NN_MALLOC_FAILED; + } + + auto tmpDataAddress = reinterpret_cast(opCtx.dataAddress + sizeof(UBSHcomNetTransHeader)); + if (NN_UNLIKELY(memcpy_s(mRespMessage.mBuf, mRespMessage.GetBufLen(), tmpDataAddress, realDataSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy tmpDataAddress to mRespMessage"); + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + return NN_INVALID_PARAM; + } + } + + if (NN_UNLIKELY(memcpy_s(&(mRespCtx.mHeader), sizeof(UBSHcomNetTransHeader), tmpHeader, + sizeof(UBSHcomNetTransHeader)) != NN_OK)) { + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + NN_LOG_ERROR("Failed to copy tmpHeader to mRespCtx"); + return NN_INVALID_PARAM; + } + mRespMessage.mDataLen = realDataSize; + mRespCtx.mHeader.dataLength = realDataSize; + mRespCtx.mMessage = &mRespMessage; + ctx.mHeader = mRespCtx.mHeader; + ctx.mMessage = mRespCtx.mMessage; + + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + return result; +} + +NResult NetSyncEndpointShm::ReceiveRaw(int32_t timeout, UBSHcomNetResponseContext &ctx) +{ + ShmOpContextInfo opCtx {}; + NResult result = NN_OK; + mDemandPollingOpType = ShmOpContextInfo::SH_RECEIVE; + uint32_t immData = 0; + if (NN_UNLIKELY(mExistDelayEvent)) { + mExistDelayEvent = false; + + auto *ch = reinterpret_cast(mDelayHandleReceiveEvent.peerChannelAddress); + if (NN_UNLIKELY(ch == nullptr)) { + NN_LOG_ERROR("Got invalid event in " << mShmEp->GetName() << ", dropped it"); + return NN_ERROR; + } + uintptr_t address = 0; + if (NN_UNLIKELY((result = ch->GetPeerDataAddressByOffset(mDelayHandleReceiveEvent.dataOffset, address)) != + SH_OK)) { + NN_LOG_ERROR("Got invalid event " << mShmEp->GetName() << " as get data address failed, dropped it"); + return result; + } + opCtx = ShmOpContextInfo(ch, address, mDelayHandleReceiveEvent.dataSize, + static_cast(mDelayHandleReceiveEvent.opType), + ShmOpContextInfo::ShmErrorType::SH_NO_ERROR); + } else if (NN_UNLIKELY((result = mShmEp->Receive(timeout, opCtx, immData)) != NN_OK)) { + NN_LOG_ERROR("Failed to get operation,time out"); + return result; + } + + if (NN_UNLIKELY(opCtx.opType != mDemandPollingOpType)) { + NN_LOG_ERROR("Got un-demand operation type " << opCtx.opType << ", ignored"); + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + return NN_ERROR; + } + if (NN_UNLIKELY(immData != mLastSendSeqNo)) { + NN_LOG_ERROR("Received un-matched seq no " << immData << ", demand seq no " << mLastSendSeqNo); + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + return NN_SEQ_NO_NOT_MATCHED; + } + + size_t realDataSize = 0; + if (mIsNeedEncrypt) { + const void *cipherData = reinterpret_cast(opCtx.dataAddress); + realDataSize = mAes.GetRawLen(opCtx.dataSize); + uint32_t decryptLen = 0; + bool msgReady = mRespMessage.AllocateIfNeed(realDataSize); + if (NN_UNLIKELY(!msgReady)) { + NN_LOG_ERROR("Failed to allocate memory for response size " << opCtx.dataSize << + ", probably out of memory"); + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + return NN_MALLOC_FAILED; + } + if (!mAes.Decrypt(mSecrets, cipherData, opCtx.dataSize, mRespMessage.mBuf, decryptLen)) { + NN_LOG_ERROR("Failed to decrypt data"); + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + return NN_DECRYPT_FAILED; + } + } else { + realDataSize = opCtx.dataSize; + auto msgReady = mRespMessage.AllocateIfNeed(realDataSize); + if (NN_UNLIKELY(!msgReady)) { + NN_LOG_ERROR("Failed to allocate memory for response size " << realDataSize << ", probably out of memory"); + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + return NN_MALLOC_FAILED; + } + + auto tmpDataAddress = reinterpret_cast(opCtx.dataAddress); + if (NN_UNLIKELY(memcpy_s(mRespMessage.mBuf, mRespMessage.GetBufLen(), tmpDataAddress, realDataSize) != NN_OK)) { + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + NN_LOG_ERROR("Failed to copy tmpDataAddress to mRespMessage"); + return NN_INVALID_PARAM; + } + } + + mRespMessage.mDataLen = realDataSize; + mRespCtx.mMessage = &mRespMessage; + + ctx.mHeader = {}; + ctx.mHeader.opCode = -1; + ctx.mHeader.seqNo = immData; + ctx.mHeader.dataLength = realDataSize; + ctx.mMessage = mRespCtx.mMessage; + + opCtx.channel->DCMarkPeerBuckFree(opCtx.dataAddress); + return result; +} + +NResult NetSyncEndpointShm::WaitCompletion(int32_t timeout) +{ + ShmEvent event {}; + NResult result = NN_OK; + +POLL_EVENT: + if (NN_UNLIKELY(result = mShmEp->DequeueEvent(timeout, event)) != NN_OK) { + return result; + } + + // repost if receive opType + if (event.opType == ShmOpContextInfo::SH_RECEIVE) { + if (!mExistDelayEvent) { + mDelayHandleReceiveEvent = event; + mExistDelayEvent = true; + goto POLL_EVENT; + } else { + NN_LOG_ERROR("Receive operation type has double received, prev context is not process"); + return SH_ERROR; + } + } + + if (event.opType == ShmOpContextInfo::ShmOpType::SH_SEND) { + auto compEvent = reinterpret_cast(event.peerChannelAddress); + if (compEvent != nullptr && compEvent->channel != nullptr) { + compEvent->channel->DecreaseRef(); + } + return result; + } + + if (event.opType == ShmOpContextInfo::SH_READ || event.opType == ShmOpContextInfo::SH_WRITE || + event.opType == ShmOpContextInfo::SH_SGL_READ || event.opType == ShmOpContextInfo::SH_SGL_WRITE) { + auto opContextInfo = reinterpret_cast(event.peerChannelAddress); + if (opContextInfo != nullptr && opContextInfo->channel != nullptr) { + opContextInfo->channel->DecreaseRef(); + } + return result; + } + + NN_LOG_ERROR("Got un-demand operation type " << event.opType << ", ignored"); + return SH_ERROR; +} +} +} \ No newline at end of file diff --git a/src/transport/shm/net_shm_sync_endpoint.h b/src/transport/shm/net_shm_sync_endpoint.h new file mode 100644 index 0000000000000000000000000000000000000000..ec043e741af6d6561833f290ba0a9474955a05b5 --- /dev/null +++ b/src/transport/shm/net_shm_sync_endpoint.h @@ -0,0 +1,195 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_SHM_ENDPOINT_H +#define OCK_HCOM_NET_SHM_ENDPOINT_H + +#include "hcom.h" +#include "transport/net_endpoint_impl.h" +#include "hcom_utils.h" +#include "net_common.h" +#include "net_monotonic.h" +#include "net_security_alg.h" +#include "net_shm_common.h" +#include "net_shm_driver_oob.h" +#include "shm_composed_endpoint.h" +#include "shm_handle_fds.h" + +namespace ock { +namespace hcom { +class NetSyncEndpointShm : public NetEndpointImpl { +public: + NetSyncEndpointShm(uint64_t id, ShmChannel *ch, NetDriverShmWithOOB *driver, + const UBSHcomNetWorkerIndex &workerIndex, ShmSyncEndpoint *shmEp, ShmMRHandleMap &handleMap); + ~NetSyncEndpointShm() override; + + NResult SetEpOption(UBSHcomEpOptions &epOptions) override + { + NN_LOG_WARN("[SHM SyncEp] Empty function for now"); + return NN_OK; + } + + uint32_t GetSendQueueCount() override + { + NN_LOG_WARN("[SHM SyncEp] Empty function for now"); + return 0; + } + + const std::string &PeerIpAndPort() override + { + if (NN_LIKELY(mShmCh != nullptr)) { + return mShmCh->PeerIpPort(); + } + + return CONST_EMPTY_STRING; + } + + const std::string &UdsName() override + { + if (NN_LIKELY(mShmCh != nullptr)) { + return mShmCh->UdsName(); + } + + return CONST_EMPTY_STRING; + } + + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNO) override; + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) override; + NResult PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNO) override; + NResult PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) override; + NResult PostRead(const UBSHcomNetTransRequest &request) override; + NResult PostRead(const UBSHcomNetTransSglRequest &request) override; + NResult PostWrite(const UBSHcomNetTransRequest &request) override; + NResult PostWrite(const UBSHcomNetTransSglRequest &request) override; + NResult WaitCompletion(int32_t timeout) override; + NResult Receive(int32_t timeout, UBSHcomNetResponseContext &ctx) override; + NResult ReceiveRaw(int32_t timeout, UBSHcomNetResponseContext &ctx) override; + + NResult GetRemoteUdsIdInfo(UBSHcomNetUdsIdInfo &idInfo) override + { + if (!mState.Compare(NEP_ESTABLISHED)) { + NN_LOG_ERROR("[SHM SyncEp] EP is not established"); + return NN_EP_NOT_ESTABLISHED; + } + + if (!mDriver->mStartOobSvr) { + NN_LOG_ERROR("[SHM SyncEp] oob server is not start"); + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + + idInfo = mRemoteUdsIdInfo; + return NN_OK; + } + + bool GetPeerIpPort(std::string &ip, uint16_t &port) override + { + NN_LOG_WARN("Invalid operation for shm, shm does not have ip and port"); + return false; + } + + NResult SendFds(int fds[], uint32_t len) override + { + if (NN_UNLIKELY(len < NN_NO1 || len > NN_NO4)) { + NN_LOG_ERROR("Failed to send fds in shm async ep as length should more than 0 and less than 4."); + return NN_PARAM_INVALID; + } + + if (NN_UNLIKELY(!mState.Compare(NEP_ESTABLISHED))) { + NN_LOG_ERROR("Failed to send fds in shm async ep as endpoint " << mId << " is not established, state is " << + UBSHcomNEPStateToString(mState.Get())); + return NN_EP_NOT_ESTABLISHED; + } + + int innerFds[NN_NO4] = {0}; + for (uint32_t i = 0; i < len; i++) { + innerFds[i] = fds[i]; + if (fds[i] <= 0) { + NN_LOG_ERROR("Failed to send fds in shm async ep, as invalid fds index:" << i); + return NN_INVALID_PARAM; + } + } + + std::lock_guard guard(mShmCh->mFdMutex); + ShmChKeeperMsgHeader header {}; + header.msgType = ShmChKeeperMsgType::EXCHANGE_USER_FD; + header.dataSize = len; + if (NN_UNLIKELY(::send(mShmCh->UdsFD(), &header, sizeof(ShmChKeeperMsgHeader), MSG_NOSIGNAL) <= 0)) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to send header info of exchange external fd to peer, error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_ERROR; + } + + return ShmHandleFds::SendMsgFds(mShmCh->UdsFD(), innerFds, NN_NO4); + } + + NResult ReceiveFds(int fds[], uint32_t len, int32_t timeoutSec) override + { + if (NN_UNLIKELY(len < NN_NO1 || len > NN_NO4)) { + NN_LOG_ERROR("Failed to receive fds in shm async ep as length should more than 0 and less than 4."); + return NN_PARAM_INVALID; + } + + if (NN_UNLIKELY(!mState.Compare(NEP_ESTABLISHED))) { + NN_LOG_ERROR("Failed to receive fds in shm async ep as endpoint " << mId << + " is not established, state is " << UBSHcomNEPStateToString(mState.Get())); + return NN_EP_NOT_ESTABLISHED; + } + + return mShmCh->RemoveUserFds(fds, len, timeoutSec); + } + + void Close() override + { + if (NN_UNLIKELY(mShmCh != nullptr)) { + mShmCh->Close(); + } + } + +private: + static bool inline NeedRetry(HResult res) + { + if (res == SH_OP_CTX_FULL || res == SH_RETRY_FULL) { + return true; + } + + return false; + } + + uint64_t inline GetFinishTime() + { + if (mDefaultTimeout > 0) { + return NetMonotonic::TimeNs() + static_cast(mDefaultTimeout) * 1000000000UL; + } else if (mDefaultTimeout < 0) { + return UINT64_MAX; + } + + return 0; + } + + ShmChannel *mShmCh = nullptr; + NetDriverShmWithOOB *mDriver = nullptr; + ShmSyncEndpoint *mShmEp = nullptr; + uint32_t mAllowedSize = 0; + uint32_t mLastSendSeqNo = 0; + ShmOpContextInfo::ShmOpType mDemandPollingOpType = ShmOpContextInfo::SH_SEND; + UBSHcomNetMessage mRespMessage; + UBSHcomNetResponseContext mRespCtx; + ShmMRHandleMap &mrHandleMap; + + bool mExistDelayEvent = false; + ShmEvent mDelayHandleReceiveEvent; +}; +} +} + +#endif // OCK_HCOM_NET_SHM_ENDPOINT_H diff --git a/src/transport/shm/shm_channel.cpp b/src/transport/shm/shm_channel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8ca3d528a507d843e713870e55104df75fe2ca3c --- /dev/null +++ b/src/transport/shm/shm_channel.cpp @@ -0,0 +1,164 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "shm_channel.h" + +namespace ock { +namespace hcom { +/* get channel fd queue size */ +uint32_t ShmChannel::gQueueSizeCap = GetQueueCap(); + +HResult ShmChannel::Initialize() +{ + /* create data channel */ + ShmDataChannelOptions opt(mId, mSendDCBuckSize, mSendDCBuckCount, true); + ShmDataChannelPtr dc = new (std::nothrow) ShmDataChannel(mName, opt, &mState); + if (NN_UNLIKELY(dc == nullptr)) { + NN_LOG_ERROR("Failed to new ShmDataChannel " << mName << ", probably out of memory"); + return SH_NEW_OBJECT_FAILED; + } + + /* initialize channel */ + auto result = dc->Initialize(); + if (NN_UNLIKELY(result != SH_OK)) { + NN_LOG_ERROR("Failed to init ShmDataChannel " << mName << ", result " << result); + return result; + } + + dc->IncreaseRef(); + mDataChannel = dc.Get(); + NN_LOG_INFO("shm channel " << mName << "," << mId << " initialized "); + return SH_OK; +} + +void ShmChannel::UnInitialize() +{ + mState.CAS(CH_NEW, CH_BROKEN); + + if (mDataChannel != nullptr) { + mDataChannel->DecreaseRef(); + mDataChannel = nullptr; + } + + if (mPeerDataChannel != nullptr) { + mPeerDataChannel->DecreaseRef(); + mPeerDataChannel = nullptr; + } + + if (mPeerEventQueue != nullptr) { + mPeerEventQueue->DecreaseRef(); + mPeerEventQueue = nullptr; + } + NetFunc::NN_SafeCloseFd(mFd); +} + +HResult ShmChannel::ValidateExchangeInfo(const ShmConnExchangeInfo &info) +{ + if (NN_UNLIKELY(info.qCapacity == 0 || info.qCapacity > NN_NO8192 || info.queueFd <= 0)) { + NN_LOG_ERROR("Failed to change ShmChannel" << mName << ":" << mId << + " to ready as invalid queue capacity or fd from peer"); + return SH_PARAM_INVALID; + } + + if (NN_UNLIKELY(info.GetQueueName().empty())) { + NN_LOG_ERROR("Failed to change ShmChannel" << mName << ":" << mId << + " to ready as invalid queue name from peer"); + return SH_PARAM_INVALID; + } + + if (NN_UNLIKELY(info.dcBuckCount == 0 || info.dcBuckSize == 0 || info.dcBuckCount > NN_NO65535 || + info.dcBuckSize > NET_SGE_MAX_SIZE)) { + NN_LOG_ERROR("Failed to change ShmChannel" << mName << ":" << mId << + " to ready as invalid buck size or count from peer"); + return SH_PARAM_INVALID; + } + + if (NN_UNLIKELY(info.GetDCName().empty())) { + NN_LOG_ERROR("Failed to change ShmChannel" << mName << ":" << mId << + " to ready as invalid data channel name from peer"); + return SH_PARAM_INVALID; + } + + if (NN_UNLIKELY(info.channelId == 0 || info.channelAddress == 0 || info.channelFd <= 0)) { + NN_LOG_ERROR("Failed to change ShmChannel" << mName << ":" << mId << + " to ready as invalid data channel id, address or fd from peer"); + return SH_PARAM_INVALID; + } + + return SH_OK; +} + +HResult ShmChannel::ChangeToReady(const ShmConnExchangeInfo &info) +{ + NN_LOG_INFO("Try to change shm channel " << mName << ":" << mId << " to ready with ex info " << info.ToString()); + HResult result = SH_OK; + if (NN_UNLIKELY((result = ValidateExchangeInfo(info)) != SH_OK)) { + return result; + } + + /* new eq handle for send msg to peer event queue of worker */ + ShmHandlePtr peerEqHandle = new (std::nothrow) + ShmHandle(mName, info.GetQueueName(), mId, ShmEventQueue::MemSize(info.qCapacity), info.queueFd, false); + if (NN_UNLIKELY(peerEqHandle.Get() == nullptr)) { + NN_LOG_ERROR("Failed to new shmHandle in ShmChannel " << mName << ", probably out of memory"); + return SH_NEW_OBJECT_FAILED; + } + + /* initialize event queue handle without ownership */ + if (NN_UNLIKELY((result = peerEqHandle->Initialize()) != SH_OK)) { + NN_LOG_ERROR("Failed to change ShmChannel " << mName << ":" << mId << " to ready as result " << result); + return result; + } + + /* new event queue object */ + ShmEventQueuePtr queue = new (std::nothrow) ShmEventQueue(mName, info.qCapacity, peerEqHandle); + if (NN_UNLIKELY(queue.Get() == nullptr)) { + NN_LOG_ERROR("Failed to new event queue in ShmChannel " << mName << ", probably out of memory"); + return SH_NEW_OBJECT_FAILED; + } + + /* initialize event queue */ + if (NN_UNLIKELY((result = queue->Initialize()) != SH_OK)) { + NN_LOG_ERROR("Failed to change ShmChannel " << mName << ":" << mId << " to ready as result " << result); + return result; + } + + /* create peer data channel */ + ShmDataChannelOptions opt(info.channelId, info.dcBuckSize, info.dcBuckCount, info.channelFd, false); + (void)opt.SetFileName(info.GetDCName()); + ShmDataChannelPtr peerDC = new (std::nothrow) ShmDataChannel(mName + ":peer", opt, &mState); + if (NN_UNLIKELY(peerDC.Get() == nullptr)) { + NN_LOG_ERROR("Failed to new data channel of peer in ShmChannel " << mName << ", probably out of memory"); + return SH_NEW_OBJECT_FAILED; + } + + /* initialize dc */ + if (NN_UNLIKELY((result = peerDC->Initialize()) != SH_OK)) { + NN_LOG_ERROR("Failed to change ShmChannel " << mName << ":" << mId << " to ready as result " << result); + return result; + } + + mPeerEventQueue = queue.Get(); + mPeerEventQueue->IncreaseRef(); + + mPeerDataChannel = peerDC.Get(); + mPeerDataChannel->IncreaseRef(); + + mPeerChId = info.channelId; + + mPeerEventPooling = info.mode == ShmPollingMode::SHM_EVENT_POLLING; + + mPeerChAddress = info.channelAddress; + + return SH_OK; +} +} +} \ No newline at end of file diff --git a/src/transport/shm/shm_channel.h b/src/transport/shm/shm_channel.h new file mode 100644 index 0000000000000000000000000000000000000000..0358629edb96479b83a7ee0e0b8f4f153533e384 --- /dev/null +++ b/src/transport/shm/shm_channel.h @@ -0,0 +1,565 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_SHM_CHANNEL_H +#define OCK_HCOM_SHM_CHANNEL_H + +#include +#include +#include + +#include "net_common.h" +#include "shm_channel_keeper.h" +#include "shm_common.h" +#include "shm_data_channel.h" +#include "shm_queue.h" + +namespace ock { +namespace hcom { +constexpr uint32_t QUEUE_TIMEOUT_US = NN_NO5 * NN_NO1000000; // 5s + +class ShmChannel { +public: + static uint32_t gQueueSizeCap; + +public: + static HResult CreateAndInit(const std::string &name, uint64_t id, uint32_t dcBuckSize, uint16_t dcBuckCount, + ShmChannelPtr &out) + { + ShmChannelPtr tmp = new (std::nothrow) ShmChannel(name, id, dcBuckSize, dcBuckCount); + if (NN_UNLIKELY(tmp.Get() == nullptr)) { + NN_LOG_ERROR("Failed to new ShmChannel " << name << ", probably out of memory"); + return SH_NEW_OBJECT_FAILED; + } + + auto result = tmp->Initialize(); + if (NN_UNLIKELY(result != SH_OK)) { + return result; + } + + out.Set(tmp.Get()); + return result; + } + +public: + ShmChannel(const std::string &name, uint64_t id, uint32_t dcBuckSize, uint16_t dcBuckCount) + : mId(id), mName(name), mSendDCBuckSize(dcBuckSize), mSendDCBuckCount(dcBuckCount) + { + OBJ_GC_INCREASE(ShmChannel); + } + + ~ShmChannel() + { + UnInitialize(); + OBJ_GC_DECREASE(ShmChannel); + } + + HResult Initialize(); + void UnInitialize(); + + inline const std::string &PeerIpPort() const + { + return mPeerIpPort; + } + + inline void PeerIpAndPort(const std::string &value) + { + mPeerIpPort = value; + } + + /* + * @brief Set a context by caller + */ + inline uint64_t UpContext() const + { + return mUpCtx; + } + + /* + * @brief Set a context by caller + */ + inline void UpContext(uint64_t value) + { + mUpCtx = value; + } + + /* + * @brief Set a context by caller + */ + inline uint64_t UpContext1() const + { + return mUpCtx1; + } + + /* + * @brief Set a context by caller + */ + inline void UpContext1(uint64_t value) + { + mUpCtx1 = value; + } + + /* + * @brief Get the file description of uds + */ + inline int UdsFD() const + { + return mFd; + } + + inline void UdsFD(int fd) + { + mFd = fd; + } + + inline const std::string &UdsName() const + { + return mUdsName; + } + + inline void UdsName(std::string udsName) + { + mUdsName = udsName; + } + + inline uint64_t Id() const + { + return mId; + } + + inline void Close() + { + NetFunc::NN_SafeCloseFd(mFd); + } + + bool FillExchangeInfo(ShmConnExchangeInfo &info) const; + + HResult ChangeToReady(const ShmConnExchangeInfo &info); + + inline HResult DCGetFreeBuck(uintptr_t &address, uint64_t &offsetToBase, uint16_t waitPeriodUs = NN_NO100, + int32_t timeoutSecond = -1) + { + NN_ASSERT_LOG_RETURN(mDataChannel != nullptr, SH_NOT_INITIALIZED) + return mDataChannel->TryOccupyWithWait(address, offsetToBase, waitPeriodUs, timeoutSecond); + } + + inline void DCMarkBuckFree(uintptr_t address) + { + if (NN_UNLIKELY(mDataChannel == nullptr)) { + NN_LOG_WARN("data channel is null in DCMarkBuckFree"); + return; + } + mDataChannel->MarkFree(address); + } + + inline void DCMarkPeerBuckFree(uintptr_t address) + { + if (NN_UNLIKELY(mPeerDataChannel == nullptr)) { + NN_LOG_WARN("data channel is null in DCMarkBuckFree"); + return; + } + mPeerDataChannel->MarkFree(address); + } + + HResult EQEventEnqueue(ShmEvent &event); + + HResult GetRemoteMrFds(uint32_t remoteKey, int &rfd); + HResult GetRemoteMrHandle(uint32_t remoteKey, uint32_t bufSize, ShmMRHandleMap &mrHandleMap); + void AddOpCtxInfo(ShmOpContextInfo *shmCtxInfo); + void AddOpCompInfo(ShmOpCompInfo *compInfo); + + HResult RemoveOpCtxInfo(ShmOpContextInfo *ctxInfo); + HResult RemoveOpCompInfo(ShmOpCompInfo *compInfo); + + // need to call this when qp broken, to get these contexts to return mrs + void GetCtxPosted(ShmOpContextInfo *&remaining); + void GetCompPosted(ShmOpCompInfo *&remaining); + + inline uint64_t PeerChannelId() const + { + return mPeerChId; + } + + inline uintptr_t PeerChannelAddress() const + { + return mPeerChAddress; + } + + inline UBSHcomNetAtomicState &State() + { + return mState; + } + + HResult GetPeerDataAddressByOffset(uint64_t offset, uintptr_t &address); + + HResult AddMrFd(int fd) + { + std::unique_lock guard(mMrFdQueueMutex); + if (mMrFdQueue.size() >= gQueueSizeCap) { + NN_LOG_ERROR("Failed to add fd in the queue, the queue size is exceeded in channel " << mName << " " << + mId); + return SH_FDS_QUEUE_FULL; + } + + mMrFdQueue.push(fd); + return SH_OK; + } + + HResult RemoveMrFd(int &fd) + { + bool flag = true; + auto start = NetMonotonic::TimeUs(); + do { + { + std::lock_guard guard(mMrFdQueueMutex); + if (!mMrFdQueue.empty()) { + fd = mMrFdQueue.front(); + mMrFdQueue.pop(); + return SH_OK; + } + } + + auto end = NetMonotonic::TimeUs(); + auto pollTime = end - start; + if (QUEUE_TIMEOUT_US < pollTime) { + NN_LOG_ERROR("Within a limited time, failed to get remote mr fds as queue empty in channel " << mName << + " " << mId); + flag = false; + break; + } + + usleep(NN_NO128); + } while (flag); + + return SH_TIME_OUT; + } + + inline static uint32_t GetQueueCap() noexcept + { + /* set fd queue size */ + char *envSize = ::getenv("HCOM_SHM_EXCHANGE_FD_QUEUE_SIZE"); + + if (envSize != nullptr) { + long value = 0; + if (NetFunc::NN_Stol(envSize, value) && value >= NN_NO10 && value <= NN_NO256) { + NN_LOG_INFO("Successfully to set the fd exchange queue capacity to " << value); + return value; + } + NN_LOG_ERROR("Invalid setting 'HCOM_SHM_EXCHANGE_FD_QUEUE_SIZE' which should be 10~256, restored fd " + "exchange queue capacity to default value 10"); + } + + return NN_NO10; + } + + HResult AddUserFds(int fds[], uint32_t len) + { + std::unique_lock guard(mUserFdQueueMutex); + if (mUserFdQueue.size() + len > gQueueSizeCap) { + NN_LOG_ERROR("Failed to add fd in the queue, the queue size is exceeded in channel " << mName << " " << + mId); + return SH_FDS_QUEUE_FULL; + } + + for (uint32_t i = 0; i < len; i++) { + mUserFdQueue.push(fds[i]); + } + + return SH_OK; + } + + HResult RemoveUserFds(int fds[], uint32_t len, int32_t timeoutSec) + { + uint32_t timeoutUs = QUEUE_TIMEOUT_US; + if (timeoutSec > 0) { + timeoutUs = static_cast(timeoutSec) * NN_NO1000000; + } + bool flag = true; + uint32_t index = 0; + auto start = NetMonotonic::TimeUs(); + do { + { + std::lock_guard guard(mUserFdQueueMutex); + while (!mUserFdQueue.empty() && index < len) { + fds[index] = mUserFdQueue.front(); + mUserFdQueue.pop(); + index++; + } + if (index == len) { + return SH_OK; + } + } + + auto end = NetMonotonic::TimeUs(); + auto pollTime = end - start; + if (timeoutUs < pollTime) { + NN_LOG_ERROR("Failed to remove user fds in queue of channel " << mName << " " << mId << + " as timeout " << timeoutUs << " us is exceeded"); + flag = false; + break; + } + + usleep(NN_NO128); + } while (flag); + + return SH_TIME_OUT; + } + + inline uint32_t GetSendDCBuckSize() const + { + return mSendDCBuckSize; + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +public: + // 1.ensure the order of send header firstly and send fds secondly in multi thread + // 2.add the same lock in GET_MR_FD\SEND_MR_FD\EXCHANGE_USER_FD process to ensure order of header + + // fds in diff thread, and then keeper thread receive order is true + std::mutex mFdMutex; + +private: + HResult ValidateExchangeInfo(const ShmConnExchangeInfo &info); + +private: + ShmEventQueue *mPeerEventQueue = nullptr; /* event queue of peer worker */ + ShmDataChannel *mDataChannel = nullptr; /* channel for data transfer */ + ShmDataChannel *mPeerDataChannel = nullptr; /* peer data channel for reading data */ + uint64_t mPeerChId = 0; /* peer channel id */ + uint64_t mPeerChAddress = 0; /* peer channel address */ + uint64_t mId = 0; /* id of this channel */ + uint64_t mUpCtx = 0; /* up context */ + uint64_t mUpCtx1 = 0; /* up context 1 */ + bool mPeerEventPooling = true; /* peer is event pooling or not */ + NetSpinLock mLock; /* spin lock of post ctx */ + ShmOpContextInfo mCtxPosted {}; /* one side done ctx double linked list */ + ShmOpCompInfo mCompPosted {}; /* two side complete post ctx double linked list */ + uint32_t mCtxPostedCount { 0 }; /* one side done ctx count */ + uint32_t mCompPostedCount { 0 }; /* two side complete post ctx count */ + + int mFd = -1; /* uds fd to transfer file descriptor of shm files, between client and server */ + std::string mUdsName; + + std::string mName; /* name of channel */ + std::string mPeerIpPort; /* peer ip port */ + uint32_t mSendDCBuckSize = NN_NO256; /* buck size of data channel for send */ + uint16_t mSendDCBuckCount = NN_NO16; /* buck count of data channel for send */ + std::mutex mMrFdQueueMutex; /* lock for add/remove in exchange mr fd queue */ + std::queue mMrFdQueue; /* exchange one side mr fd queue */ + std::mutex mUserFdQueueMutex; /* lock for add/remove in exchange user fd queue */ + std::queue mUserFdQueue; /* exchange user fd queue */ + UBSHcomNetAtomicState mState { CH_NEW }; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; + +inline void ShmChannel::AddOpCompInfo(ShmOpCompInfo *compInfo) +{ + if (NN_LIKELY(compInfo != nullptr)) { + // bi-direction linked list, 4 step to insert to head + compInfo->prev = &mCompPosted; + mLock.Lock(); + // head -><- first -><- second -><- third -> nullptr + // insert into the head place + compInfo->next = mCompPosted.next; + if (mCompPosted.next != nullptr) { + mCompPosted.next->prev = compInfo; + } + mCompPosted.next = compInfo; + ++mCompPostedCount; + mLock.Unlock(); + } +} + +inline HResult ShmChannel::RemoveOpCompInfo(ShmOpCompInfo *compInfo) +{ + mLock.Lock(); + if (mCompPostedCount == 0) { + mLock.Unlock(); + return SH_OP_CTX_REMOVED; + } + + if (NN_LIKELY(compInfo != nullptr)) { + // bi-direction linked list, 4 step to remove one + // repeat remove + if (compInfo->prev == nullptr) { + mLock.Unlock(); + return SH_OP_CTX_REMOVED; + } + + // head-><- first -><- second -><- third -> nullptr + compInfo->prev->next = compInfo->next; + if (compInfo->next != nullptr) { + compInfo->next->prev = compInfo->prev; + } + --mCompPostedCount; + compInfo->prev = nullptr; + compInfo->next = nullptr; + } + mLock.Unlock(); + return SH_OK; +} + +inline void ShmChannel::AddOpCtxInfo(ShmOpContextInfo *shmCtxInfo) +{ + if (NN_LIKELY(shmCtxInfo != nullptr)) { + // bi-direction linked list, 4 step to insert to head + shmCtxInfo->prev = &mCtxPosted; + mLock.Lock(); + // head -><- first -><- second -><- third -> nullptr + // insert into the head place + shmCtxInfo->next = mCtxPosted.next; + if (mCtxPosted.next != nullptr) { + mCtxPosted.next->prev = shmCtxInfo; + } + mCtxPosted.next = shmCtxInfo; + ++mCtxPostedCount; + mLock.Unlock(); + } +} + +inline HResult ShmChannel::RemoveOpCtxInfo(ShmOpContextInfo *ctxInfo) +{ + mLock.Lock(); + if (mCtxPostedCount == 0) { + mLock.Unlock(); + return SH_OP_CTX_REMOVED; + } + + if (NN_LIKELY(ctxInfo != nullptr)) { + // bi-direction linked list, 4 step to remove one + // repeat remove + if (ctxInfo->prev == nullptr) { + mLock.Unlock(); + return SH_OP_CTX_REMOVED; + } + + // head-><- first -><- second -><- third -> nullptr + ctxInfo->prev->next = ctxInfo->next; + if (ctxInfo->next != nullptr) { + ctxInfo->next->prev = ctxInfo->prev; + } + --mCtxPostedCount; + + ctxInfo->prev = nullptr; + ctxInfo->next = nullptr; + } + mLock.Unlock(); + return SH_OK; +} + +inline void ShmChannel::GetCtxPosted(ShmOpContextInfo *&remaining) +{ + mLock.Lock(); + // head -> first -><- second -><- third -> nullptr + remaining = mCtxPosted.next; + mCtxPosted.next = nullptr; + mCtxPostedCount = 0; + mLock.Unlock(); +} + +inline void ShmChannel::GetCompPosted(ShmOpCompInfo *&remaining) +{ + mLock.Lock(); + // head -> first -><- second -><- third -> nullptr + remaining = mCompPosted.next; + mCompPosted.next = nullptr; + mCompPostedCount = 0; + mLock.Unlock(); +} + +inline HResult ShmChannel::EQEventEnqueue(ShmEvent &event) +{ + NN_ASSERT_LOG_RETURN(mPeerEventQueue != nullptr, SH_NOT_INITIALIZED) + + if (mPeerEventPooling) { + return mPeerEventQueue->EnqueueAndNotify(event); + } + + return mPeerEventQueue->Enqueue(event); +} + +inline HResult ShmChannel::GetRemoteMrFds(uint32_t remoteKey, int &rfd) +{ + ShmChKeeperMsgHeader header {}; + header.msgType = ShmChKeeperMsgType::GET_MR_FD; + header.dataSize = sizeof(remoteKey); + + std::lock_guard guard(mFdMutex); + if (NN_UNLIKELY(::send(UdsFD(), &header, sizeof(ShmChKeeperMsgHeader), MSG_NOSIGNAL) <= 0)) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to notify exchange mr fd info, as" + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return SH_ERROR; + } + + if (NN_UNLIKELY(::send(UdsFD(), &remoteKey, sizeof(remoteKey), MSG_NOSIGNAL) <= 0)) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to get remote mr fds for key:" << remoteKey << " as" + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return SH_ERROR; + } + + return RemoveMrFd(rfd); +} + +inline HResult ShmChannel::GetRemoteMrHandle(uint32_t remoteKey, uint32_t bufSize, ShmMRHandleMap &mrHandleMap) +{ + int rfd = 0; + auto result = GetRemoteMrFds(remoteKey, rfd); + if (NN_UNLIKELY(result != SH_OK)) { + NN_LOG_INFO("Get remote mr fd failed, result is:" << result); + return result; + } + + std::string tmpName = "tmp_mr"; + if (mrHandleMap.GetFromRemoteMap(rfd) == nullptr) { + auto remoteHandle = new (std::nothrow) ShmHandle(mName, tmpName, rfd, bufSize, rfd, false); + if (remoteHandle == nullptr) { + NN_LOG_ERROR("Failed to new remote shm handle for shm data channel " << mName << + ", probably out of memory"); + return SH_NEW_OBJECT_FAILED; + } + + result = remoteHandle->Initialize(); + if (NN_UNLIKELY(result != NN_OK)) { + delete remoteHandle; + return result; + } + mrHandleMap.AddToRemoteMap(remoteKey, remoteHandle); + } + + return NN_OK; +} + +inline bool ShmChannel::FillExchangeInfo(ShmConnExchangeInfo &info) const +{ + if (NN_LIKELY(mDataChannel != nullptr)) { + info.channelId = mId; + info.dcBuckSize = mDataChannel->BuckSize(); + info.dcBuckCount = mDataChannel->BuckCount(); + info.channelAddress = reinterpret_cast(this); + info.channelFd = mDataChannel->GetShmHandle()->Fd(); + return info.SetDCName(mDataChannel->Filepath()); + } + + return false; +} + +inline HResult ShmChannel::GetPeerDataAddressByOffset(uint64_t offset, uintptr_t &address) +{ + NN_ASSERT_LOG_RETURN(mPeerDataChannel != nullptr, SH_NOT_INITIALIZED) + return mPeerDataChannel->GetAddressByOffset(offset, address); +} +} +} + +#endif // OCK_HCOM_SHM_CHANNEL_H diff --git a/src/transport/shm/shm_channel_keeper.cpp b/src/transport/shm/shm_channel_keeper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2b67f268947be0738e3c11695bb98aecbb9349e7 --- /dev/null +++ b/src/transport/shm/shm_channel_keeper.cpp @@ -0,0 +1,271 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include + +#include "shm_channel.h" +#include "shm_handle_fds.h" +#include "shm_channel_keeper.h" + +namespace ock { +namespace hcom { +constexpr uint32_t MAX_EPOLL_SIZE = 4096 * 4; // 4096 hosts, 4 card per host +constexpr uint32_t MAX_EPOLL_WAIT_EVENTS = 16; +constexpr uint32_t EPOLL_WAIT_TIMEOUT = 1000; // 1 second + +HResult ShmChannelKeeper::Start() +{ + std::lock_guard guard(mMutex); + if (mStarted) { + return SH_OK; + } + + if (mMsgHandler == nullptr) { + NN_LOG_ERROR("Message handler is not set in ShmChannelKeeper " << mName); + return SH_PARAM_INVALID; + } + + mEpollHandle = epoll_create(MAX_EPOLL_SIZE); + if (mEpollHandle < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create epoll in ShmChannelKeeper " << mName << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return SH_CREATE_KEEPER_EPOLL_FAILURE; + } + + std::thread tmpThread(&ShmChannelKeeper::RunInThread, this); + mEPollThread = std::move(tmpThread); + std::string threadName = "ShmChKeeper" + std::to_string(mDriverIndex); + if (pthread_setname_np(mEPollThread.native_handle(), threadName.c_str()) != 0) { + NN_LOG_WARN("Failed to set name of ShmChannelKeeper working thread to " << threadName); + } + + while (!mThreadStarted.load()) { + usleep(NN_NO10); + } + + mStarted = true; + return SH_OK; +} + +void ShmChannelKeeper::Stop() +{ + std::lock_guard guard(mMutex); + if (!mStarted) { + NN_LOG_WARN("ShmChannelKeeper " << mName << " has not been started"); + return; + } + + StopInner(); + + mStarted = false; +} + +void ShmChannelKeeper::StopInner() +{ + mNeedStop = true; + if (mEPollThread.native_handle()) { + mEPollThread.join(); + } + + if (mEpollHandle != -1) { + NetFunc::NN_SafeCloseFd(mEpollHandle); + mEpollHandle = -1; + } +} + +HResult ShmChannelKeeper::AddShmChannel(const ShmChannelPtr &ch) +{ + NN_ASSERT_LOG_RETURN(ch.Get() != nullptr, SH_PARAM_INVALID) + NN_ASSERT_LOG_RETURN(ch->UdsFD() != -1, SH_PARAM_INVALID) + + std::lock_guard guard(mChMapMutex); + auto iter = mShmChannels.find(ch->Id()); + if (iter != mShmChannels.end()) { + NN_LOG_ERROR("Failed to add channel " << ch->Id() << " into ShmChannelKeeper " << mName << + " as already existed, remove it firstly."); + return SH_DUP_CH_IN_KEEPER; + } + + struct epoll_event ev {}; + ev.events = EPOLLIN; + ev.data.ptr = ch.Get(); + if (epoll_ctl(mEpollHandle, EPOLL_CTL_ADD, ch->UdsFD(), &ev) != 0) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to add channel " << ch->Id() << " into ShmChannelKeeper " << mName << + " as epoll add failed, errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SH_CH_ADD_FAILURE_IN_KEEPER; + } + + if (!mShmChannels.emplace(ch->Id(), ch).second) { + NN_LOG_ERROR("Failed to add channel " << ch->Id() << " into ShmChannelKeeper " << mName); + return SH_CH_ADD_FAILURE_IN_KEEPER; + } + + return SH_OK; +} + +HResult ShmChannelKeeper::RemoveShmChannel(uint64_t id) +{ + std::lock_guard guard(mChMapMutex); + auto iter = mShmChannels.find(id); + if (iter == mShmChannels.end()) { + NN_LOG_ERROR("No channel with " << id << " found in ShmChannelKeeper " << mName); + return SH_CH_REMOVE_FAILURE_IN_KEEPER; + } + + if (epoll_ctl(mEpollHandle, EPOLL_CTL_DEL, iter->second->UdsFD(), nullptr) != 0) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to delete from epoll handle for channel " << id << " in ShmChannelKeeper " << mName << + ", errno:" << errno << " error:" << NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SH_CH_REMOVE_FAILURE_IN_KEEPER; + } + + iter->second->State().CAS(CH_NEW, CH_BROKEN); + mShmChannels.erase(iter); + + return SH_OK; +} + +void ShmChannelKeeper::RunInThread() +{ + mThreadStarted.store(true); + NN_LOG_INFO("Shm channelKeeper " << mName << " working thread started"); + + struct epoll_event ev[MAX_EPOLL_WAIT_EVENTS]; + while (!mNeedStop) { + try { + // do epoll wait + int count = epoll_wait(mEpollHandle, ev, MAX_EPOLL_WAIT_EVENTS, EPOLL_WAIT_TIMEOUT); + if (count > 0) { + /* there are events, handle it */ + TRACE_DELAY_BEGIN(SHM_THREAD_CHANNEL_KEEPER); + HandleEpollEvent(count, ev); + TRACE_DELAY_END(SHM_THREAD_CHANNEL_KEEPER, 0); + } else if (count == 0) { + continue; + } else if (errno == EINTR) { + NN_LOG_WARN("Got errno EINTR in channelKeeper " << mName); + continue; + } else { + /* error happens */ + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to do epoll_wait in channelKeeper " << mName << ", errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + continue; + } + } catch (std::runtime_error &ex) { + NN_LOG_WARN("Got runtime error in ShmChannelKeeper::RunInThread '" << ex.what() << + "', ignore and continue"); + } catch (...) { + NN_LOG_WARN("Got unknown error in ShmChannelKeeper::RunInThread, ignore and continue"); + } + } + + NN_LOG_INFO("Shm channelKeeper " << mName << " working thread exiting"); +} + +HResult ShmChannelKeeper::ExchangeFdProcess(ShmChKeeperMsgHeader &header, const ShmChannelPtr &ch) +{ + HResult result = SH_OK; + if (header.msgType == SEND_MR_FD) { + int fds[NN_NO4] = {0}; + result = ShmHandleFds::ReceiveMsgFds(ch->UdsFD(), fds, NN_NO4); + if (NN_UNLIKELY(result != SH_OK)) { + NN_LOG_ERROR("Failed to receive the peer fd from the channel" << ch->Id()); + return result; + } + + result = ch->AddMrFd(fds[NN_NO0]); + if (NN_UNLIKELY(result != SH_OK)) { + NN_LOG_ERROR("Successfully received mr to peer fd:" << fds[NN_NO0] << ", but the channel " << ch->Id() << + " cannot add peer fd to the fd queue, result is " << result); + return result; + } + } else if (header.msgType == EXCHANGE_USER_FD) { + int fds[NN_NO4] = {0}; + result = ShmHandleFds::ReceiveMsgFds(ch->UdsFD(), fds, NN_NO4); + if (NN_UNLIKELY(result != SH_OK)) { + NN_LOG_ERROR("Failed to receive the peer fds from the channel" << ch->Id()); + return result; + } + + if (header.dataSize > NN_NO4) { + NN_LOG_ERROR("Fd length " << header.dataSize << " is invalid "); + return SH_PARAM_INVALID; + } + + result = ch->AddUserFds(fds, header.dataSize); + if (NN_UNLIKELY(result != SH_OK)) { + NN_LOG_ERROR("Failed to add fds to channel " << ch->Id() << " fd queue, result is " << result); + return result; + } + } + return result; +} + +void ShmChannelKeeper::HandleEpollEvent(uint32_t eventCount, struct epoll_event *events) +{ + if (NN_UNLIKELY(events == nullptr)) { + return; + } + + ShmChKeeperMsgHeader header {}; + ShmChannelPtr ch; + + for (uint32_t i = 0; i < eventCount; ++i) { + if (!(events[i].events & EPOLLIN)) { + continue; + } + + ch = static_cast(events[i].data.ptr); + /* read header in blocking */ + auto result = ::read(ch->UdsFD(), &header, sizeof(ShmChKeeperMsgHeader)); + if (result <= 0) { + /* reset by peer */ + header.msgType = ShmChKeeperMsgType::RESET_BY_PEER; + header.dataSize = 0; + + (void)RemoveShmChannel(ch->Id()); + } else if (static_cast(result) != sizeof(ShmChKeeperMsgHeader)) { + NN_LOG_WARN("Un-reachable path"); + continue; + } else { + if (header.msgType < GET_MR_FD || header.msgType > EXCHANGE_USER_FD) { + NN_LOG_WARN("Un-reachable path, msgType is incorrect"); + continue; + } + } + + if (header.msgType == SEND_MR_FD || header.msgType == EXCHANGE_USER_FD) { + if (NN_LIKELY(ExchangeFdProcess(header, ch) != SH_PEER_FD_ERROR)) { + continue; + } + // peer fd error should process ep error + header.msgType = ShmChKeeperMsgType::RESET_BY_PEER; + header.dataSize = 0; + (void)RemoveShmChannel(ch->Id()); + } + + try { + mMsgHandler(header, ch); + } catch (std::runtime_error &ex) { + NN_LOG_WARN("Got runtime incorrect signal in mMsgHandler " << ex.what() << " in , ignored"); + } catch (...) { + NN_LOG_WARN("Got unknown signal in mMsgHandler , ignored"); + } + } +} +} +} \ No newline at end of file diff --git a/src/transport/shm/shm_channel_keeper.h b/src/transport/shm/shm_channel_keeper.h new file mode 100644 index 0000000000000000000000000000000000000000..d8de3da299f764a6d654574e9cdea554eccb767c --- /dev/null +++ b/src/transport/shm/shm_channel_keeper.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_SHM_CHANNEL_KEEPER_H +#define HCOM_SHM_CHANNEL_KEEPER_H + +#include +#include + +#include "hcom_def.h" +#include "shm_common.h" + +namespace ock { +namespace hcom { +/* + * Keeper message types + */ +enum ShmChKeeperMsgType : uint16_t { + RESET_BY_PEER = 0, + ACTIVE_CLOSE_CH = 1, + EXCHANGE_CH_FD = 2, + GET_MR_FD = 3, + SEND_MR_FD = 4, + EXCHANGE_USER_FD = 5, +}; + +/* + * Keeper message header + */ +struct ShmChKeeperMsgHeader { + uint16_t msgType = 0; /* message type */ + uint16_t dataSize = 0; /* data size */ +} __attribute__((packed)); + +/* + * Callback function when received message + */ +using NewKeeperMsgHandler = std::function; + +/* + * ShmChannelKeeper is for polling uds event including: + * 1 reset by peer, for example peer process crashed + * 2 close by peer actively + * 3 exchange fd for shm files + */ +class ShmChannelKeeper { +public: + ShmChannelKeeper(const std::string &name, uint16_t driverIndex) + : mDriverIndex(driverIndex), mName(name + std::to_string(driverIndex)) + { + OBJ_GC_INCREASE(ShmChannelKeeper); + } + + ~ShmChannelKeeper() + { + OBJ_GC_DECREASE(ShmChannelKeeper); + } + + HResult Start(); + void Stop(); + + HResult AddShmChannel(const ShmChannelPtr &ch); + HResult RemoveShmChannel(uint64_t id); + + inline void RegisterMsgHandler(const NewKeeperMsgHandler &handler) + { + mMsgHandler = handler; + } + +private: + void StopInner(); + void RunInThread(); + void HandleEpollEvent(uint32_t eventCount, struct epoll_event *events); + HResult ExchangeFdProcess(ShmChKeeperMsgHeader &header, const ShmChannelPtr &ch); + + DEFINE_RDMA_REF_COUNT_FUNCTIONS +private: + int mEpollHandle = -1; + NewKeeperMsgHandler mMsgHandler = nullptr; + std::map mShmChannels; + + std::mutex mMutex; + std::mutex mChMapMutex; + bool mStarted = false; + bool mNeedStop = false; + uint16_t mDriverIndex = 0; + std::thread mEPollThread; + std::atomic_bool mThreadStarted { false }; + std::string mName; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; +} +} + +#endif // HCOM_SHM_CHANNEL_KEEPER_H diff --git a/src/transport/shm/shm_common.h b/src/transport/shm/shm_common.h new file mode 100644 index 0000000000000000000000000000000000000000..4668dc20b1bdb51b279e37a17f0bb0560634d09c --- /dev/null +++ b/src/transport/shm/shm_common.h @@ -0,0 +1,274 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_SHM_COMMON_H +#define OCK_HCOM_SHM_COMMON_H + +#include +#include +#include +#include + +#include "securec.h" +#include "hcom.h" +#include "net_ctx_info_pool.h" +#include "net_delay_release_timer.h" +#include "net_mem_pool_fixed.h" +#include "shm_lock_guard.h" +#include "shm_mr_handle_map.h" + +namespace ock { +namespace hcom { +class ShmHandle; +class ShmDataChannel; +template class ShmQueue; +class ShmChannel; +class ShmWorker; +class ShmChannelKeeper; +class ShmSyncEndpoint; +class ShmHandleFds; + +using ShmHandlePtr = NetRef; +using ShmDataChannelPtr = NetRef; +using ShmChannelPtr = NetRef; +using ShmChannelKeeperPtr = NetRef; +using DelayReleaseTimerPtr = NetRef; +using ShmSyncEndpointPtr = NetRef; + +enum ShmPollingMode : uint8_t { + SHM_EVENT_POLLING = 0, + SHM_BUSY_POLLING = 1, +}; + +inline std::string &ShmPollingModeToStr(ShmPollingMode v) +{ + static std::string STRINGS[NN_NO2] = {"event", "busy"}; + return STRINGS[v]; +} + +/* + * @brief exchange info for uds + */ +struct ShmConnExchangeInfo { + char qName[NN_NO32] {}; + char dcName[NN_NO64] {}; + uint64_t channelId = 0; + int channelFd = 0; + uintptr_t channelAddress = 0; + uint32_t qCapacity = 0; + int queueFd = 0; + uint32_t dcBuckSize = 0; + uint32_t dcBuckCount = 0; + uint16_t payLoadSize = 0; + ShmPollingMode mode = SHM_EVENT_POLLING; + + inline bool SetQueueName(const std::string &v) + { + NN_SET_CHAR_ARRAY_FROM_STRING(qName, v); + } + + inline std::string GetQueueName() const + { + return NN_CHAR_ARRAY_TO_STRING(qName); + } + + inline bool SetDCName(const std::string &v) + { + NN_SET_CHAR_ARRAY_FROM_STRING(dcName, v); + } + + inline std::string GetDCName() const + { + return NN_CHAR_ARRAY_TO_STRING(dcName); + } + + inline std::string ToString() const + { + std::ostringstream oss; + oss << "qName: " << GetQueueName() << ", dcName " << dcName << ", chId " << channelId << ", qCap: " << + qCapacity << ", dcBuckSize: " << dcBuckSize << ", dcBuckCnt: " << dcBuckCount; + return oss.str(); + } +} __attribute__((packed)); + +using ShmIdleHandler = UBSHcomNetDriverIdleHandler; + +/* + * @brief shm operation context + * make sure it is 64bits which equal to one cache line of CPU + */ +struct ShmOpContextInfo { + enum ShmOpType : uint8_t { + SH_SEND = 0, + SH_RECEIVE = 1, + SH_WRITE = 2, + SH_READ = 3, + SH_SGL_WRITE = 4, + SH_SGL_READ = 5, + SH_SEND_RAW = 6, + SH_RECEIVE_RAW = 7, + SH_SEND_RAW_SGL = 8, + }; + + enum ShmErrorType : uint8_t { + SH_NO_ERROR = 0, + SH_OPERATE_FAILURE = 1, + SH_RESET_BY_PEER = 2, + SH_OUT_OF_MEM = 3, + SH_TIMEOUT = 4, + }; + + ShmOpContextInfo *prev = nullptr; /* previous one for bi-direct link */ + ShmOpContextInfo *next = nullptr; /* next one for bi-direct link */ + ShmChannel *channel = nullptr; /* shm channel */ + uintptr_t dataAddress = 0; /* data address */ + uint32_t dataSize = 0; /* data size */ + uint32_t lKey = 0; /* lKey of read write MR */ + uintptr_t mrMemAddr = 0; /* address of read write MR */ + ShmOpType opType = SH_RECEIVE; /* receive by default */ + ShmErrorType errType = SH_NO_ERROR; /* by default no error */ + uint16_t upCtxSize = 0; /* up context size */ + char upCtx[NN_NO16] = {}; /* 16 bytes for upper context */ + + ShmOpContextInfo() = default; + + ShmOpContextInfo(ShmChannel *ch, uintptr_t da, uint32_t ds, ShmOpType op, ShmErrorType et) + : channel(ch), dataAddress(da), dataSize(ds), opType(op), errType(et) + {} + + static inline NResult GetNResult(ShmErrorType opResult) + { + switch (opResult) { + case ShmErrorType::SH_NO_ERROR: + return NN_OK; + case ShmErrorType::SH_TIMEOUT: + return NN_MSG_TIMEOUT; + case ShmErrorType::SH_OUT_OF_MEM: + return NN_MALLOC_FAILED; + default: + return NN_MSG_ERROR; + } + } +} __attribute__((packed)); + +struct ShmSglOpContextInfo { + UBSHcomNetTransSgeIov iov[NET_SGE_MAX_IOV] = {}; + uint16_t iovCount = 0; // max count:NN_NO4 + uint16_t upCtxSize = 0; + char upCtx[NN_NO16] = {}; // 16 bytes for upper context + NResult result = NN_OK; +} __attribute__((packed)); + +struct ShmOpCompInfo { + UBSHcomNetTransHeader header {}; + ShmChannel *channel = nullptr; /* shm channel */ + UBSHcomNetTransRequest request {}; + uint16_t upCtxSize = 0; /* up context size */ + char upCtx[NN_NO16] = {}; /* 16 bytes for upper context */ + ShmOpCompInfo *prev = nullptr; /* previous one for bi-direct link */ + ShmOpCompInfo *next = nullptr; /* next one for bi-direct link */ + ShmOpContextInfo::ShmOpType opType = ShmOpContextInfo::ShmOpType::SH_SEND; + ShmOpContextInfo::ShmErrorType errType = ShmOpContextInfo::ShmErrorType::SH_NO_ERROR; /* by default no error */ +} __attribute__((packed)); + +struct ShmSglOpCompInfo { + ShmSglOpContextInfo *ctx = nullptr; + + ShmSglOpCompInfo() = default; + explicit ShmSglOpCompInfo(ShmSglOpContextInfo *sglCtx) : ctx(sglCtx) {} +} __attribute__((packed)); + +/* + * @brief Shm event struct + */ +struct ShmEvent { + uint32_t immData = 0; /* imm data */ + uint32_t dataSize = 0; /* size of data */ + uint64_t dataOffset = 0; /* offset of the data based address of sender */ + uint64_t channelId = 0; /* sender channel id */ + uint64_t peerChannelId = 0; /* peer channel id, i.e. receiver channel id */ + uintptr_t peerChannelAddress = 0; /* channel address */ + ShmChannel *shmChannel = nullptr; /* sender ch address */ + ShmOpContextInfo::ShmOpType opType = ShmOpContextInfo::ShmOpType::SH_SEND; /* op type */ + + ShmEvent() = default; + + ShmEvent(uint32_t s, uint32_t ds, uint64_t o, uint64_t myId, uint64_t pId, uintptr_t pa, uint8_t op) + : immData(s), + dataSize(ds), + dataOffset(o), + channelId(myId), + peerChannelId(pId), + peerChannelAddress(pa), + opType(static_cast(op)) + {} + + ShmEvent(uintptr_t pa, uint8_t op) : peerChannelAddress(pa), opType(static_cast(op)) {} + + void SetChannel(ShmChannel *channel) + { + shmChannel = channel; + } + + std::string ToString() const + { + std::ostringstream oss; + oss << "imm-data " << immData << ", ch-id " << channelId << ", peer-ch-id " << peerChannelId << + ", peer-channel-address: " << peerChannelAddress << ", data-offset " << dataOffset << ", data-size: " << + dataSize << ", opType: " << opType; + return oss.str(); + } +}; + +enum ShmChannelState : uint8_t { + CH_NEW = 0, + CH_BROKEN = 1, +}; + +/* + * @brief Event queue for both busy polling and event polling + */ +using ShmEventQueue = ShmQueue; +using ShmEventQueuePtr = NetRef; + +const std::string SHM_F_EVENT_QUEUE_PREFIX = "hcom-eq"; +const std::string SHM_F_DC_PREFIX = "hcom-dc"; + +using HResult = int32_t; +enum ShCode { + SH_OK = 0, + SH_ERROR = 300, + SH_PARAM_INVALID = 301, + SH_MEMORY_ALLOCATE_FAILED = 302, + SH_NEW_OBJECT_FAILED = 303, + SH_FILE_OP_FAILED = 304, + SH_NOT_INITIALIZED = 305, + SH_TIME_OUT = 306, + SH_OP_CTX_FULL = 307, + SH_CH_BROKEN = 308, + SH_CREATE_KEEPER_EPOLL_FAILURE = 309, + SH_DUP_CH_IN_KEEPER = 310, + SH_CH_ADD_FAILURE_IN_KEEPER = 311, + SH_CH_REMOVE_FAILURE_IN_KEEPER = 312, + SH_RETRY_FULL = 313, + SH_SEND_COMPLETION_CALLBACK_FAILURE = 314, + SH_FDS_QUEUE_FULL = 315, + SH_PEER_FD_ERROR = 316, + SH_OP_CTX_REMOVED = 317, +}; + +using ShmOpCompInfoPool = OpContextInfoPool; +using ShmOpContextInfoPool = OpContextInfoPool; +using ShmSglContextInfoPool = OpContextInfoPool; +} +} + +#endif // OCK_HCOM_SHM_COMMON_H diff --git a/src/transport/shm/shm_composed_endpoint.cpp b/src/transport/shm/shm_composed_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b394fc64588b1d779969da47ebc3f8e525c1a90b --- /dev/null +++ b/src/transport/shm/shm_composed_endpoint.cpp @@ -0,0 +1,438 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "shm_composed_endpoint.h" +namespace ock { +namespace hcom { +HResult ShmSyncEndpoint::Create(const std::string &name, uint16_t eventQueueLength, ShmPollingMode mode, + ShmSyncEndpointPtr &ep) +{ + ShmSyncEndpointPtr tmpEp = new (std::nothrow) ShmSyncEndpoint(name, eventQueueLength, mode); + if (NN_UNLIKELY(tmpEp.Get() == nullptr)) { + NN_LOG_ERROR("Failed to create RDMASyncClientEndPoint, probably out of memory"); + return SH_NEW_OBJECT_FAILED; + } + + auto result = tmpEp->CreateEventQueue(); + if (NN_UNLIKELY(result != SH_OK)) { + return result; + } + + ep.Set(tmpEp.Get()); + + return result; +} + +HResult ShmSyncEndpoint::CreateEventQueue() +{ + /* get data size */ + uint64_t dataSize = ShmEventQueue::MemSize(mEventQueueLength); + + /* create handle for event queue */ + HResult result = SH_OK; + ShmHandlePtr tmpHandle = new (std::nothrow) ShmHandle(mName, SHM_F_EVENT_QUEUE_PREFIX, 1, dataSize, true); + if (NN_UNLIKELY(tmpHandle.Get() == nullptr)) { + NN_LOG_ERROR("Failed to new shm handle for sync ep " << mName << ", probably out of memory"); + return SH_NEW_OBJECT_FAILED; + } + + /* create and initialize event queue */ + ShmEventQueuePtr tmpQueue = new (std::nothrow) ShmEventQueue(mName, mEventQueueLength, tmpHandle); + if (NN_UNLIKELY(tmpQueue.Get() == nullptr)) { + NN_LOG_ERROR("Failed to new shm event queue for sync ep " << mName); + return SH_NEW_OBJECT_FAILED; + } + + if ((result = tmpQueue->Initialize()) != SH_OK) { + NN_LOG_ERROR("Failed to initialize shm event queue"); + return result; + } + + /* assign member variables */ + mHandleEventQueue.Set(tmpHandle.Get()); + mEventQueue = tmpQueue.Get(); + mEventQueue->IncreaseRef(); + + return result; +} + +HResult ShmSyncEndpoint::PostSend(ShmChannel *ch, const UBSHcomNetTransRequest &req, uint64_t offset, uint32_t immData, + int32_t defaultTimeout) +{ + if (NN_UNLIKELY(req.upCtxSize > sizeof(ShmOpContextInfo::upCtx))) { + NN_LOG_ERROR("Failed to PostSend with ShmWorker " << mName << " as upCtxSize > " << + sizeof(ShmOpContextInfo::upCtx)); + return SH_PARAM_INVALID; + } + mDefaultTimeout = defaultTimeout; + + /* get op completion ctx */ + static thread_local ShmOpCompInfo ctx {}; + if (immData == 0) { + ctx.header = *(reinterpret_cast(req.lAddress)); + } + ctx.header.immData = immData; + ctx.channel = ch; + ctx.request = req; + ctx.opType = immData == 0 ? ShmOpContextInfo::ShmOpType::SH_SEND : ShmOpContextInfo::ShmOpType::SH_SEND_RAW; + ch->IncreaseRef(); + + ShmEvent event(immData, req.size, offset, ch->Id(), ch->PeerChannelId(), ch->PeerChannelAddress(), + ShmOpContextInfo::ShmOpType::SH_RECEIVE); + auto result = ch->EQEventEnqueue(event); + if (NN_UNLIKELY(result != SH_OK)) { + if (result == ShmEventQueue::SHM_QUEUE_FULL) { + result = SH_RETRY_FULL; + } + ch->DecreaseRef(); + return result; + } + + /* send local event for send Waitcompletion */ + ShmEvent eventSent(reinterpret_cast(&ctx), ShmOpContextInfo::ShmOpType::SH_SEND); + uint64_t finishTime = GetFinishTime(); + bool flag = true; + do { + result = mEventQueue->EnqueueAndNotify(eventSent); + if (result == SH_OK) { + flag = false; + return SH_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + if (result == ShmEventQueue::SHM_QUEUE_FULL) { + result = SH_SEND_COMPLETION_CALLBACK_FAILURE; + } + + flag = false; + } while (flag); + + ch->DecreaseRef(); + return result; +} + +HResult ShmSyncEndpoint::FillSglCtx(ShmSglOpContextInfo *sglCtx, const UBSHcomNetTransSglRequest &sglReq) +{ + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Shm Failed to PostSendRawSgl with ShmWorker as no ctx left"); + return SH_PARAM_INVALID; + } + + sglCtx->result = SH_OK; + if (NN_UNLIKELY(memcpy_s(sglCtx->iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, sglReq.iov, + sizeof(UBSHcomNetTransSgeIov) * sglReq.iovCount) != SH_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return SH_PARAM_INVALID; + } + sglCtx->upCtxSize = sglReq.upCtxSize; + sglCtx->iovCount = sglReq.iovCount; + if (sglReq.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(sglCtx->upCtx, NN_NO16, sglReq.upCtxData, sglReq.upCtxSize) != SH_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return SH_PARAM_INVALID; + } + } + + return SH_OK; +} + +HResult ShmSyncEndpoint::PostSendRawSgl(ShmChannel *ch, const UBSHcomNetTransRequest &req, + const UBSHcomNetTransSglRequest &sglReq, uint64_t offset, uint32_t immData, int32_t defaultTimeout) +{ + if (NN_UNLIKELY(sglReq.upCtxSize > sizeof(ShmOpContextInfo::upCtx))) { + NN_LOG_ERROR("Failed to PostSend with sync endpoint " << mName << " as upCtxSize > " << + sizeof(ShmOpContextInfo::upCtx)); + return SH_PARAM_INVALID; + } + mDefaultTimeout = defaultTimeout; + + thread_local ShmSglOpContextInfo sglCtx {}; + auto result = FillSglCtx(&sglCtx, sglReq); + if (NN_UNLIKELY(result != SH_OK)) { + return result; + } + + /* get op completion ctx */ + thread_local ShmOpCompInfo ctx {}; + if (immData == 0) { + ctx.header = *(reinterpret_cast(req.lAddress)); + } + ctx.header.immData = immData; + ctx.channel = ch; + ctx.request = req; + ctx.opType = ShmOpContextInfo::ShmOpType::SH_SEND_RAW_SGL; + ctx.upCtxSize = sizeof(ShmSglOpCompInfo); + auto upCtx = reinterpret_cast(&ctx.upCtx); + upCtx->ctx = &sglCtx; + ch->IncreaseRef(); + + ShmEvent event(immData, req.size, offset, ch->Id(), ch->PeerChannelId(), ch->PeerChannelAddress(), + ShmOpContextInfo::ShmOpType::SH_RECEIVE); + result = ch->EQEventEnqueue(event); + if (NN_UNLIKELY(result != SH_OK)) { + if (result == ShmEventQueue::SHM_QUEUE_FULL) { + result = SH_RETRY_FULL; + } + ch->DecreaseRef(); + return result; + } + + /* send local event for send Waitcompletion */ + ShmEvent eventSent(reinterpret_cast(&ctx), ShmOpContextInfo::ShmOpType::SH_SEND); + uint64_t finishTime = GetFinishTime(); + bool flag = true; + do { + result = mEventQueue->EnqueueAndNotify(eventSent); + if (result == SH_OK) { + flag = false; + return SH_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + if (result == ShmEventQueue::SHM_QUEUE_FULL) { + result = SH_SEND_COMPLETION_CALLBACK_FAILURE; + } + + flag = false; + } while (flag); + + ch->DecreaseRef(); + return result; +} + +/* Single-side transport process */ +static inline HResult SyncReadWriteProcess(UBSHcomNetTransSgeIov &iov, ShmMRHandleMap &mrHandleMap, ShmChannel *ch, + ShmOpContextInfo::ShmOpType type) +{ + auto localMrHandle = mrHandleMap.GetFromLocalMap(static_cast(iov.lKey)); + if (NN_UNLIKELY(localMrHandle == nullptr)) { + NN_LOG_ERROR("Local mr handle is nullptr"); + return SH_ERROR; + } + + auto remoteMrHandle = mrHandleMap.GetFromRemoteMap(static_cast(iov.rKey)); + if (remoteMrHandle == nullptr) { + /* remote address not exist in local map, exchange mr fd and mmap before copy */ + auto result = ch->GetRemoteMrHandle(static_cast(iov.rKey), iov.size, mrHandleMap); + if (NN_UNLIKELY(result != NN_OK)) { + return result; + } + remoteMrHandle = mrHandleMap.GetFromRemoteMap(static_cast(iov.rKey)); + } + + /* address has mmap already, copy directly */ + if (type == ShmOpContextInfo::ShmOpType::SH_READ || type == ShmOpContextInfo::ShmOpType::SH_SGL_READ) { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(localMrHandle->ShmAddress()), localMrHandle->DataSize(), + reinterpret_cast(remoteMrHandle->ShmAddress()), iov.size) != SH_OK)) { + NN_LOG_ERROR("Failed to copy remoteMrHandle to localMrHandle"); + return SH_PARAM_INVALID; + } + } else if (type == ShmOpContextInfo::ShmOpType::SH_WRITE || type == ShmOpContextInfo::ShmOpType::SH_SGL_WRITE) { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(remoteMrHandle->ShmAddress()), remoteMrHandle->DataSize(), + reinterpret_cast(localMrHandle->ShmAddress()), iov.size) != SH_OK)) { + NN_LOG_ERROR("Failed to copy localMrHandle to remoteMrHandle"); + return SH_PARAM_INVALID; + } + } else { + NN_LOG_ERROR("Failed to PostReadWrite unreachable path"); + return SH_ERROR; + } + + return SH_OK; +} + +HResult ShmSyncEndpoint::SendLocalEventForOneSideDone(ShmOpContextInfo *ctx, ShmOpContextInfo::ShmOpType type) +{ + /* send local event for send WaitCompletion */ + ShmEvent eventSent(reinterpret_cast(ctx), type); + + uint64_t finishTime = GetFinishTime(); + bool flag = true; + HResult result = SH_OK; + do { + result = mEventQueue->EnqueueAndNotify(eventSent); + if (result == SH_OK) { + flag = false; + return SH_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + if (result == ShmEventQueue::SHM_QUEUE_FULL) { + flag = false; + return SH_SEND_COMPLETION_CALLBACK_FAILURE; + } + + flag = false; + } while (flag); + + return result; +} + +HResult ShmSyncEndpoint::PostReadWrite(ShmChannel *ch, const UBSHcomNetTransRequest &req, ShmMRHandleMap &mrHandleMap, + ShmOpContextInfo::ShmOpType type) +{ + /* upper caller need to make sure ch is not null */ + if (NN_UNLIKELY(req.upCtxSize > sizeof(ShmOpContextInfo::upCtx))) { + NN_LOG_ERROR("Failed to PostSend with ShmWorker " << mName << " as upCtxSize > " << + sizeof(ShmOpContextInfo::upCtx)); + return SH_PARAM_INVALID; + } + + UBSHcomNetTransSgeIov iov {}; + iov.lKey = req.lKey; + iov.rKey = req.rKey; + iov.size = req.size; + + // Prevent integer truncation, safely converts uint64_t to uint32_t + if (NN_UNLIKELY(iov.lKey > UINT32_MAX || iov.rKey > UINT32_MAX)) { + NN_LOG_ERROR("Shm failed to PostReadWrite with RDMAWorker as Key is larger than uint32max, lkey" << + iov.lKey << " rKey " << iov.rKey); + return SH_PARAM_INVALID; + } + + HResult result = SH_OK; + if (NN_UNLIKELY((result = SyncReadWriteProcess(iov, mrHandleMap, ch, type)) != SH_OK)) { + NN_LOG_ERROR("Failed to read/write data to/from server"); + return result; + } + + /* get op ctx */ + thread_local ShmOpContextInfo ctx {}; + ctx.channel = ch; + ctx.mrMemAddr = req.lAddress; + ctx.lKey = static_cast(req.lKey); + ctx.dataSize = req.size; + ctx.opType = type; + ctx.upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(ctx.upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != SH_OK)) { + NN_LOG_ERROR("Failed to copy req to ctx"); + return SH_PARAM_INVALID; + } + } + ch->IncreaseRef(); + + /* send local event for one side completion */ + result = SendLocalEventForOneSideDone(&ctx, type); + if (NN_UNLIKELY(result != SH_OK && result != SH_CH_BROKEN)) { + ch->DecreaseRef(); + } + + return result; +} + +HResult ShmSyncEndpoint::PostReadWriteSgl(ShmChannel *ch, const UBSHcomNetTransSglRequest &req, + ShmMRHandleMap &mrHandleMap, ShmOpContextInfo::ShmOpType type) +{ + /* upper caller need to make sure ch is not null */ + if (NN_UNLIKELY(req.upCtxSize > sizeof(ShmOpContextInfo::upCtx))) { + NN_LOG_ERROR("Failed to PostReadWriteSgl type:" << type << " with ShmWorker " << mName << " as upCtxSize > " << + sizeof(ShmOpContextInfo::upCtx)); + return SH_PARAM_INVALID; + } + + HResult result = SH_OK; + for (auto i = 0; i < req.iovCount; i++) { + if (NN_UNLIKELY((result = SyncReadWriteProcess(req.iov[i], mrHandleMap, ch, type)) != SH_OK)) { + NN_LOG_ERROR("Failed to read/write sgl data to/from server"); + return result; + } + } + + /* get op ctx */ + thread_local ShmOpContextInfo ctx {}; + thread_local ShmSglOpContextInfo sglCtx {}; + result = FillSglCtx(&sglCtx, req); + if (NN_UNLIKELY(result != SH_OK)) { + return result; + } + + ctx.channel = ch; + ctx.mrMemAddr = 0; + ctx.lKey = 0; + ctx.dataSize = sizeof(UBSHcomNetTransSgeIov) * req.iovCount; + ctx.opType = type; + ctx.upCtxSize = sizeof(ShmSglOpCompInfo); + auto upCtx = reinterpret_cast(&ctx.upCtx); + upCtx->ctx = &sglCtx; + ch->IncreaseRef(); + + /* send local event for one side completion */ + result = SendLocalEventForOneSideDone(&ctx, type); + if (NN_UNLIKELY(result != SH_OK && result != SH_CH_BROKEN)) { + ch->DecreaseRef(); + } + + return result; +} + +HResult ShmSyncEndpoint::Receive(int32_t timeout, ShmOpContextInfo &opCtx, uint32_t &immData) +{ + HResult result = SH_OK; + ShmEvent event {}; + if (NN_UNLIKELY((result = DequeueEvent(timeout, event)) != SH_OK)) { + NN_LOG_ERROR("Failed to dequeue event"); + return result; + } + + auto *ch = reinterpret_cast(event.peerChannelAddress); + if (NN_UNLIKELY(ch == nullptr)) { + NN_LOG_ERROR("Got invalid event in EP " << mName << ", dropped it"); + return SH_ERROR; + } + + uintptr_t address = 0; + if (NN_UNLIKELY((result = ch->GetPeerDataAddressByOffset(event.dataOffset, address)) != SH_OK)) { + NN_LOG_ERROR("Got invalid event in worker " << mName << " as get data address failed, dropped it"); + return result; + } + + ShmOpContextInfo ctx(ch, address, event.dataSize, static_cast(event.opType), + ShmOpContextInfo::ShmErrorType::SH_NO_ERROR); + opCtx = ctx; + immData = event.immData; + + return result; +} + +HResult ShmSyncEndpoint::DequeueEvent(int32_t timeout, ShmEvent &opEvent) +{ + int32_t timeoutInMs = TimeSecToMs(timeout); + HResult result = SH_OK; + ShmEvent event {}; + + if (mShmMode == SHM_BUSY_POLLING) { + auto start = NetMonotonic::TimeMs(); + do { + result = mEventQueue->Dequeue(event); + auto end = NetMonotonic::TimeMs(); + auto pollTime = end - start; + if (result == ShmEventQueue::SHM_QUEUE_EMPTY && timeoutInMs >= 0 && pollTime > (uint64_t)timeoutInMs) { + return SH_TIME_OUT; + } + } while (result == ShmEventQueue::SHM_QUEUE_EMPTY); + } else if (mShmMode == SHM_EVENT_POLLING) { + // stopping param is for worker polling case, it is not used in self polling case, just a placeholder + bool stopping = false; + result = mEventQueue->DequeueOrWait(event, stopping, timeoutInMs); + if (NN_UNLIKELY(result != SH_OK)) { + return result; + } + } + /* get event */ + opEvent = event; + return result; +} +} +} \ No newline at end of file diff --git a/src/transport/shm/shm_composed_endpoint.h b/src/transport/shm/shm_composed_endpoint.h new file mode 100644 index 0000000000000000000000000000000000000000..db9d6929311f4f8222f5f7da14eef2f57aacb13e --- /dev/null +++ b/src/transport/shm/shm_composed_endpoint.h @@ -0,0 +1,147 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_SHM_COMPOSED_ENDPOINT_H +#define HCOM_SHM_COMPOSED_ENDPOINT_H + +#include "hcom.h" +#include "hcom_def.h" +#include "hcom_obj_statistics.h" +#include "net_monotonic.h" +#include "shm_common.h" +#include "shm_handle.h" +#include "shm_queue.h" +#include "shm_channel.h" +namespace ock { +namespace hcom { +class ShmSyncEndpoint { +public: + static HResult Create(const std::string &name, uint16_t eventQueueLength, ShmPollingMode mode, + ShmSyncEndpointPtr &ep); + +public: + ShmSyncEndpoint(const std::string &name, uint16_t eventQueueLength, ShmPollingMode mode) + : mName(name), mEventQueueLength(eventQueueLength), mShmMode(mode) + { + OBJ_GC_INCREASE(ShmSyncEndpoint); + } + + ~ShmSyncEndpoint() + { + if (mEventQueue != nullptr) { + mEventQueue->DecreaseRef(); + mEventQueue = nullptr; + } + OBJ_GC_DECREASE(ShmSyncEndpoint); + } + + HResult PostSend(ShmChannel *ch, const UBSHcomNetTransRequest &req, uint64_t offset, uint32_t immData, + int32_t defaultTimeout); + + HResult PostSendRawSgl(ShmChannel *ch, const UBSHcomNetTransRequest &req, const UBSHcomNetTransSglRequest &sglReq, + uint64_t offset, uint32_t immData, int32_t defaultTimeout); + HResult PostRead(ShmChannel *ch, const UBSHcomNetTransRequest &req, ShmMRHandleMap &mrHandleMap); + HResult PostRead(ShmChannel *ch, const UBSHcomNetTransSglRequest &req, ShmMRHandleMap &mrHandleMap); + HResult PostWrite(ShmChannel *ch, const UBSHcomNetTransRequest &req, ShmMRHandleMap &mrHandleMap); + HResult PostWrite(ShmChannel *ch, const UBSHcomNetTransSglRequest &req, ShmMRHandleMap &mrHandleMap); + HResult Receive(int32_t timeout, ShmOpContextInfo &opCtx, uint32_t &immData); + HResult DequeueEvent(int32_t timeout, ShmEvent &opEvent); + + inline bool FillQueueExchangeInfo(ShmConnExchangeInfo &info) + { + if (NN_UNLIKELY(mEventQueue != nullptr)) { + info.qCapacity = mEventQueue->Capacity(); + } + + if (NN_LIKELY(mHandleEventQueue.Get() != nullptr)) { + info.queueFd = mHandleEventQueue->Fd(); + return info.SetQueueName(mHandleEventQueue->FullPath()); + } + + info.mode = mShmMode; + return false; + } + + inline std::string GetName() + { + return mName; + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + HResult CreateEventQueue(); + HResult FillSglCtx(ShmSglOpContextInfo *sglCtx, const UBSHcomNetTransSglRequest &sglReq); + HResult SendLocalEventForOneSideDone(ShmOpContextInfo *ctx, ShmOpContextInfo::ShmOpType type); + HResult PostReadWriteSgl(ShmChannel *ch, const UBSHcomNetTransSglRequest &req, ShmMRHandleMap &mrHandleMap, + ShmOpContextInfo::ShmOpType type); + + HResult PostReadWrite(ShmChannel *ch, const UBSHcomNetTransRequest &req, ShmMRHandleMap &mrHandleMap, + ShmOpContextInfo::ShmOpType type); + + uint64_t inline GetFinishTime() + { + if (mDefaultTimeout > 0) { + return NetMonotonic::TimeNs() + static_cast(mDefaultTimeout) * 1000000000UL; + } else if (mDefaultTimeout < 0) { + return UINT64_MAX; + } + + return 0; + } + + static bool inline NeedRetry(HResult result) + { + if (NN_UNLIKELY(result == ShmEventQueue::SHM_QUEUE_FULL)) { + return true; + } + + return false; + } + +private: + std::string mName; + uint16_t mEventQueueLength = NN_NO2048; + ShmEventQueue *mEventQueue = nullptr; + ShmHandlePtr mHandleEventQueue; /* handle of event queue */ + ShmPollingMode mShmMode = SHM_EVENT_POLLING; + int32_t mDefaultTimeout = -1; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; + +inline HResult ShmSyncEndpoint::PostRead(ShmChannel *ch, const UBSHcomNetTransRequest &req, + ShmMRHandleMap &mrHandleMap) +{ + return PostReadWrite(ch, req, mrHandleMap, ShmOpContextInfo::ShmOpType::SH_READ); +} + +inline HResult ShmSyncEndpoint::PostRead(ShmChannel *ch, const UBSHcomNetTransSglRequest &req, + ShmMRHandleMap &mrHandleMap) +{ + return PostReadWriteSgl(ch, req, mrHandleMap, ShmOpContextInfo::ShmOpType::SH_SGL_READ); +} + +inline HResult ShmSyncEndpoint::PostWrite(ShmChannel *ch, const UBSHcomNetTransRequest &req, + ShmMRHandleMap &mrHandleMap) +{ + return PostReadWrite(ch, req, mrHandleMap, ShmOpContextInfo::ShmOpType::SH_WRITE); +} + +inline HResult ShmSyncEndpoint::PostWrite(ShmChannel *ch, const UBSHcomNetTransSglRequest &req, + ShmMRHandleMap &mrHandleMap) +{ + return PostReadWriteSgl(ch, req, mrHandleMap, ShmOpContextInfo::ShmOpType::SH_SGL_WRITE); +} +} +} +#endif // HCOM_SHM_COMPOSED_ENDPOINT_H diff --git a/src/transport/shm/shm_data_channel.h b/src/transport/shm/shm_data_channel.h new file mode 100644 index 0000000000000000000000000000000000000000..35c1912848e414655db4196195f3371e6083b0c4 --- /dev/null +++ b/src/transport/shm/shm_data_channel.h @@ -0,0 +1,344 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_SHM_DATA_CHANNEL_H +#define OCK_HCOM_SHM_DATA_CHANNEL_H + +#include "shm_common.h" +#include "shm_handle.h" + +namespace ock { +namespace hcom { +/* + * @brief buck state + */ +enum ShmBuckState : uint8_t { + HCOM_SHM_FREE = 0, + HCOM_SHM_OCCUPIED = 1, +}; + +/* + * @brief Shm buck meta + */ +struct ShmBuckMeta { + ShmBuckState state = HCOM_SHM_FREE; + + /* + * @brief Occupy this buck if its free using CAS + */ + inline bool OccupyIfFree() + { + return __sync_bool_compare_and_swap(&state, HCOM_SHM_FREE, HCOM_SHM_OCCUPIED); + } + + /* + * @brief Mark to free if occupied using CAS + */ + inline bool Free() + { + return __sync_bool_compare_and_swap(&state, HCOM_SHM_OCCUPIED, HCOM_SHM_FREE); + } +} __attribute__((packed)); + +struct ShmDataChannelOptions { + uint32_t buckSize = 0; + int mFd = 0; + uint16_t buckCount = 0; + bool isOwner = true; + uint64_t id = 0; + char fileName[NN_NO64]{}; + + inline bool SetFileName(const std::string &v) + { + NN_SET_CHAR_ARRAY_FROM_STRING(fileName, v); + } + + inline std::string GetFileName() const + { + return NN_CHAR_ARRAY_TO_STRING(fileName); + } + + ShmDataChannelOptions() = default; + ShmDataChannelOptions(uint64_t i, uint32_t buckS, uint16_t bCnt, bool own) + : buckSize(buckS), buckCount(bCnt), isOwner(own), id(i) + {} + ShmDataChannelOptions(uint64_t i, uint32_t buckS, uint16_t bCnt, int fd, bool own) + : buckSize(buckS), mFd(fd), buckCount(bCnt), isOwner(own), id(i) + {} + + std::string ToString() const + { + std::ostringstream oss; + oss << "buck-size: " << buckSize << ", buck-count: " << buckCount << ", is-owner: " << isOwner; + return oss.str(); + } +}; + +/* + * @brief Data channel for shm communication + * + * 1 hold shm memory, owner hold it; with fixed size of buck + * 2 get a free buck, for writer + * 3 mark the buck to free, for read + */ +class ShmDataChannel { +public: + static inline uint64_t MemSize(uint32_t buckMemSize, uint32_t buckCnt) + { + // buckMemSize max value is NET_SGE_MAX_SIZE(900MB), buckCnt max is 65535, will not over 2^64 + return static_cast(buckMemSize) * buckCnt + buckCnt * sizeof(ShmBuckMeta); + } + +public: + ShmDataChannel(const std::string &name, const ShmDataChannelOptions &opt, + UBSHcomNetAtomicState *state) + : mOptions(opt), mName(name), mState(state) + { + OBJ_GC_INCREASE(ShmDataChannel); + } + + ~ShmDataChannel() + { + UnInitialize(); + OBJ_GC_DECREASE(ShmDataChannel); + } + + const std::string &Filepath() const + { + if (mHandle.Get() != nullptr) { + return mHandle->FullPath(); + } + + return CONST_EMPTY_STRING; + } + + HResult ValidateOptions() + { + if (mName.empty()) { + NN_LOG_ERROR("Name of shm data channel is empty"); + return SH_PARAM_INVALID; + } + + if (NN_UNLIKELY(mState == nullptr)) { + NN_LOG_ERROR("State of shm data state is empty"); + return SH_PARAM_INVALID; + } + + if (mOptions.buckSize == 0 || mOptions.buckCount == 0) { + NN_LOG_ERROR("Buck mem size or buck count is 0 for shm data channel " << mName); + return SH_PARAM_INVALID; + } + + /* do later check buck size and buck count */ + + return SH_OK; + } + + /* + * @brief Initialize data channel + */ + HResult Initialize() + { + if (mInited) { + return SH_OK; + } + + HResult result = SH_OK; + if ((result = ValidateOptions()) != SH_OK) { + return result; + } + + /* create shm file handle */ + uint64_t desired = MemSize(mOptions.buckSize, mOptions.buckCount); + std::string fileName = mOptions.isOwner ? SHM_F_DC_PREFIX : mOptions.GetFileName(); + mHandle = new (std::nothrow) ShmHandle(mName, fileName, mOptions.id, desired, mOptions.mFd, mOptions.isOwner); + if (NN_UNLIKELY(mHandle == nullptr)) { + NN_LOG_ERROR("Failed to new shm handle for shm data channel " << mName << ", probably out of memory"); + return SH_NEW_OBJECT_FAILED; + } + + if ((result = mHandle->Initialize()) != SH_OK) { + return result; + } + + mBuckBaseAddress = mHandle->ShmAddress() + sizeof(ShmBuckMeta) * mOptions.buckCount; + mBuckEndAddress = mBuckBaseAddress + static_cast(mOptions.buckSize) * mOptions.buckCount; + mMeta = reinterpret_cast(mHandle->ShmAddress()); + if (mOptions.isOwner) { + for (uint16_t i = 0; i < mOptions.buckCount; ++i) { + mMeta[i].state = HCOM_SHM_FREE; + } + } + + NN_LOG_INFO("Data channel " << mName << " at " << mHandle->FullPath() << " initialized, size meta " + << sizeof(ShmBuckMeta) << " with options " << mOptions.ToString()); + + mInited = true; + return SH_OK; + } + + void UnInitialize() + { + if (!mInited) { + return; + } + + mHandle->UnInitialize(); + + mInited = false; + } + + /* + * @brief Try to occupy one, if one is available then wait some time and try again + * + * @param address [out] the address occupied + * @param offset [out] offset to base address + * @param waitPeriodMs [in] sleep period in us + * @param timeoutSecond [in] timeout in seconds, -1 means wait infinity, 0 don't wait, >0 wait n second + * + * @return 0 if occupied + * + */ + inline HResult TryOccupyWithWait( + uintptr_t &address, uint64_t &offset, uint16_t waitPeriodUs = NN_NO100, int32_t timeoutSecond = -1) + { + if (NN_UNLIKELY(!mInited)) { + NN_LOG_ERROR("Failed to occupy one buck from shm data channel " << mName << ", as not initialized"); + return SH_NOT_INITIALIZED; + } + + const int64_t timeCountUs = static_cast(timeoutSecond) * NN_NO1000 * NN_NO1000; + int64_t timePassedUs = 0; + + const uint16_t buckCnt = mOptions.buckCount; + + while (true) { + for (uint16_t i = 0; i < buckCnt; ++i) { + /* to find */ + if (mMeta[i].OccupyIfFree()) { + /* found one and return */ + offset = mOptions.buckSize * i; + address = offset + mBuckBaseAddress; + return SH_OK; + } + } + + /* check if needing to re-find */ + if (timeCountUs < 0) { + /* <0 means wait infinity */ + usleep(waitPeriodUs); + continue; + } else if (timeCountUs == 0) { + /* 0 means don't wait */ + return SH_TIME_OUT; + } + + /* > 0 means wait sometime */ + if (timePassedUs >= timeCountUs) { + return SH_TIME_OUT; + } + + if (NN_UNLIKELY(mState->Compare(CH_BROKEN))) { + NN_LOG_ERROR("Failed to occupy one buck from shm data channel " << mName << ", as ch state is broken"); + return SH_CH_BROKEN; + } + + usleep(waitPeriodUs); + timePassedUs += waitPeriodUs; + } + } + + /* + * @brief Try to mark to free + * + * @param address [in] the address to be marked + * + * @return 0 if marked + */ + inline void MarkFree(uintptr_t address) + { + if (NN_UNLIKELY(!mInited)) { + NN_LOG_WARN("Unable to mark one buck free from shm data channel " << mName << " as not initialized"); + } + + if (NN_UNLIKELY(address >= mBuckEndAddress || address < mBuckBaseAddress)) { + NN_LOG_WARN("Unable to mark one buck free from shm data channel " << mName << " as address is invalid"); + } + + uint64_t tmpIndex = (address - mBuckBaseAddress) / mOptions.buckSize; + if ((tmpIndex * mOptions.buckSize + mBuckBaseAddress) != address) { + NN_LOG_WARN("Unable to mark one buck free from shm data channel " << mName << " as address is invalid"); + } + + auto &stateData = mMeta[tmpIndex]; + if (NN_UNLIKELY(!stateData.Free())) { + NN_LOG_WARN("Unable to mark free as is not occupied in shm data channel " << mName); + } + } + + inline uint32_t BuckSize() const + { + return mOptions.buckSize; + } + + inline uint32_t BuckCount() const + { + return mOptions.buckCount; + } + + inline HResult GetAddressByOffset(uint64_t offset, uintptr_t &address) const + { + if (NN_UNLIKELY(!mInited)) { + NN_LOG_ERROR("Failed to translate address by shm data channel " << mName << " as not initialized"); + return SH_NOT_INITIALIZED; + } + + uint64_t tmpIndex = offset / mOptions.buckSize; + if (tmpIndex >= mOptions.buckCount) { + NN_LOG_ERROR("Failed to translate address by shm data channel " << mName << " as address is invalid"); + return SH_PARAM_INVALID; + } + + address = mBuckBaseAddress + offset; + + if ((tmpIndex * mOptions.buckSize + mBuckBaseAddress) != address) { + NN_LOG_ERROR("Failed to translate address by shm data channel " << mName << " as address is invalid"); + return SH_PARAM_INVALID; + } + + return SH_OK; + } + + inline ShmHandlePtr &GetShmHandle() + { + return mHandle; + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + /* put the hot variables at head part of this class */ + ShmBuckMeta *mMeta = nullptr; + uintptr_t mBuckBaseAddress = 0; + uintptr_t mBuckEndAddress = 0; + ShmHandlePtr mHandle = nullptr; + bool mInited = false; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + ShmDataChannelOptions mOptions{}; + + std::string mName; + UBSHcomNetAtomicState *mState = nullptr; +}; +} // namespace hcom +} // namespace ock +#endif // OCK_HCOM_SHM_DATA_CHANNEL_H diff --git a/src/transport/shm/shm_handle.h b/src/transport/shm/shm_handle.h new file mode 100644 index 0000000000000000000000000000000000000000..7061541ca374ee38852b3aadee1d910a6c274f3b --- /dev/null +++ b/src/transport/shm/shm_handle.h @@ -0,0 +1,222 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_SHM_HANDLE_H +#define HCOM_SHM_HANDLE_H + +#include +#include +#include +#include +#include "net_common.h" +#include "shm_common.h" + +namespace ock { +namespace hcom { +class ShmHandle { +public: + ShmHandle(const std::string &name, const std::string &filePrefix, uint64_t id, uint64_t dataSize, bool isOwner) + : mName(name), mFilePrefix(filePrefix), mId(id), mDataSize(dataSize), mIsOwner(isOwner) + { + OBJ_GC_INCREASE(ShmHandle); + } + + ShmHandle(const std::string &name, const std::string &filePrefix, uint64_t id, uint64_t dataSize, int fd, + bool isOwner) + : mName(name), mFilePrefix(filePrefix), mId(id), mDataSize(dataSize), mFd(fd), mIsOwner(isOwner) + { + OBJ_GC_INCREASE(ShmHandle); + } + + ~ShmHandle() + { + UnInitialize(); + OBJ_GC_DECREASE(ShmHandle); + } + + HResult Initialize() + { + if (mInited) { + return SH_OK; + } + + if (mFilePrefix.empty()) { + NN_LOG_ERROR("File prefix is empty in shm handle " << mName); + return SH_PARAM_INVALID; + } + + int32_t pid = getpid(); + if (pid < 0) { + NN_LOG_ERROR("Get PID is incorrect in shm handle " << mName); + return SH_ERROR; + } + mPId = mIsOwner ? static_cast(pid) : 0; + + /* get file name */ + mFullPath = mIsOwner ? GetFileName() : mFilePrefix; + + if (mIsOwner) { + /* create file */ +#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 17, 0) + auto tmpFd = shm_open(mFullPath.c_str(), O_CREAT | O_RDWR | O_EXCL | O_CLOEXEC, mPermission); +#else + int tmpFd = syscall(SYS_memfd_create, mFullPath.c_str(), 0); +#endif + if (tmpFd < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create shm file for " << mName << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE) << + ", please check if fd is out of limit"); + return SH_FILE_OP_FAILED; + } + /* truncate */ + if (ftruncate(tmpFd, mDataSize) != 0) { + NetFunc::NN_SafeCloseFd(tmpFd); + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to truncate file for " << mName << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return SH_FILE_OP_FAILED; + } + mFd = tmpFd; + } + +#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 17, 0) + /* lock file Make other processes aware that the file is in use */ + if (NN_UNLIKELY(flock(mFd, LOCK_EX | LOCK_NB) != 0)) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to lock file for " << mName << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + NetFunc::NN_SafeCloseFd(mFd); + return SH_FILE_OP_FAILED; + } +#endif + + /* mmap */ + auto mappedAddress = mmap(nullptr, mDataSize, PROT_READ | PROT_WRITE, MAP_SHARED, mFd, 0); + if (mappedAddress == MAP_FAILED) { + NetFunc::NN_SafeCloseFd(mFd); + + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to mmap file for " << mName << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return SH_FILE_OP_FAILED; + } + + /* owner set 1B per 4K, make sure physical page */ + if (mIsOwner) { + auto pos = reinterpret_cast(mappedAddress); + uint64_t setLength = 0; + // if directly *pos=0 may be call bus error + uint8_t zero = 0; + while (setLength < mDataSize) { + *pos = zero; + setLength += NN_NO4096; + pos += NN_NO4096; + } + + pos = reinterpret_cast(mappedAddress) + (mDataSize - NN_NO1); + *pos = zero; + } + + mAddress = reinterpret_cast(mappedAddress); + mInited = true; + return SH_OK; + } + + void UnInitialize() + { + if (!mInited) { + return; + } + + if (munmap(reinterpret_cast(mAddress), mDataSize) != 0) { + NN_LOG_ERROR("Failed to munmap address in shm handle " << mName); + } + + NetFunc::NN_SafeCloseFd(mFd); + +#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 17, 0) + if (mIsOwner && shm_unlink(mFullPath.c_str()) != 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to remove file for " << mName << " error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + } +#endif + mInited = false; + } + + std::string ToString() const + { + std::ostringstream oss; + oss << "name: " << mName << ", id: " << mId << ", data-size: " << mDataSize << ", address: " << mAddress << + ", fd: " << mFd << ", is-owner: " << mIsOwner << ", inited: " << mInited << ", full-path: " << mFullPath; + return oss.str(); + } + + inline uint64_t Id() const + { + return mId; + } + + inline uint64_t DataSize() const + { + return mDataSize; + } + + inline uintptr_t ShmAddress() const + { + return mAddress; + } + + inline bool IsOwner() const + { + return mIsOwner; + } + + inline int Fd() const + { + return mFd; + } + + inline const std::string &FullPath() const + { + return mFullPath; + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + std::string GetFileName() const + { + std::string filePath = mFilePrefix + "-" + std::to_string(mPId) + "-" + std::to_string(mId); + return filePath; + } + +private: + std::string mFullPath; + std::string mName; + std::string mFilePrefix; + uint64_t mId = 0; + uint64_t mDataSize = 0; + uintptr_t mAddress = 0; + int mFd = -1; + int mPermission = NN_NO400; + uint32_t mPId = 0; + + bool mIsOwner = true; + bool mInited = false; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; +} +} + +#endif // HCOM_SHM_HANDLE_H diff --git a/src/transport/shm/shm_handle_fds.h b/src/transport/shm/shm_handle_fds.h new file mode 100644 index 0000000000000000000000000000000000000000..0505b94459424ff6bbaa8963a2378c61c21edb13 --- /dev/null +++ b/src/transport/shm/shm_handle_fds.h @@ -0,0 +1,128 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_SHM_HANDLE_FDS_H +#define OCK_HCOM_SHM_HANDLE_FDS_H + +#include "shm_common.h" + +namespace ock { +namespace hcom { + +class ShmHandleFds { +public: + static HResult SendMsgFds(int udsFd, int fds[], uint32_t len) + { + if (NN_UNLIKELY(len != NN_NO4)) { + NN_LOG_ERROR("Failed to send fds as len of fds should be 4"); + return SH_ERROR; + } + + // create iov for msg_iov param + struct iovec iov = { + .iov_base = &len, + .iov_len = sizeof(uint32_t) + }; + + uint32_t fdsSize = sizeof(int) * NN_NO4; + char buf[CMSG_SPACE(fdsSize)]; + bzero(buf, fdsSize); + + struct msghdr msg {}; + msg.msg_iov = &iov; + msg.msg_control = buf; + msg.msg_iovlen = 1; + msg.msg_controllen = sizeof(buf); + + struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); + if (NN_UNLIKELY(cmsg == nullptr)) { + NN_LOG_ERROR("CMSG_FIRSTHDR get empty msg"); + return SH_ERROR; + } + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + cmsg->cmsg_len = CMSG_LEN(fdsSize); + + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(CMSG_DATA(cmsg)), fdsSize, fds, fdsSize) != SH_OK)) { + NN_LOG_ERROR("Failed to copy fds to cmsg"); + return SH_PARAM_INVALID; + } + + auto result = ::sendmsg(udsFd, &msg, 0); + if (NN_UNLIKELY(result <= 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to send fds msg to peer result:" << result << ", as errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SH_ERROR; + } + + return NN_OK; + } + + static HResult ReceiveMsgFds(int udsFd, int fds[], uint32_t len) + { + if (NN_UNLIKELY(len != NN_NO4)) { + NN_LOG_ERROR("Failed to receive fds as len of fds should be 4"); + return SH_ERROR; + } + + // create iov for msg_iov param + uint32_t recvLen = 0; + struct iovec iov = { + .iov_base = &recvLen, + .iov_len = sizeof(uint32_t) + }; + + uint32_t fdsSize = sizeof(int) * NN_NO4; + + char buf[CMSG_SPACE(fdsSize)]; + bzero(buf, fdsSize); + + struct msghdr msg {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = buf; + msg.msg_controllen = sizeof(buf); + + auto result = ::recvmsg(udsFd, &msg, 0); + if (NN_UNLIKELY((result == 0) && (errno == EXIT_SUCCESS))) { + NN_LOG_ERROR("Failed to receive fds msg from peer, as channel fd has been destroyed "); + return SH_PEER_FD_ERROR; + } + + if (NN_UNLIKELY(result <= 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to receive fds msg from peer, as errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SH_ERROR; + } + + if (NN_UNLIKELY(recvLen != len)) { + NN_LOG_ERROR("Failed to receive fds as receive Len:" << recvLen << " is not equal to len:" << len); + return SH_ERROR; + } + + struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); + if (NN_UNLIKELY(cmsg == nullptr)) { + NN_LOG_ERROR("CMSG_FIRSTHDR get empty msg"); + return SH_ERROR; + } + if (NN_UNLIKELY(memcpy_s(fds, fdsSize, reinterpret_cast(CMSG_DATA(cmsg)), fdsSize) != SH_OK)) { + NN_LOG_ERROR("Failed to copy cmsg to fds"); + return SH_PARAM_INVALID; + } + + return SH_OK; + } +}; +} +} +#endif \ No newline at end of file diff --git a/src/transport/shm/shm_lock_guard.h b/src/transport/shm/shm_lock_guard.h new file mode 100644 index 0000000000000000000000000000000000000000..bd7ecfdfbb2fc33e882409c5af2fd56a3429496e --- /dev/null +++ b/src/transport/shm/shm_lock_guard.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_SHM_LOCK_GUARD_H +#define OCK_HCOM_SHM_LOCK_GUARD_H +#include +#include "hcom.h" + +namespace ock { +namespace hcom { + +class RWLockGuard { +public: + RWLockGuard(RWLockGuard &) = delete; + RWLockGuard &operator = (RWLockGuard &) = delete; + + explicit RWLockGuard(NetReadWriteLock &lock) : mRwLock(lock) {} + + inline void LockRead() + { + mRwLock.LockRead(); + } + + inline void LockWrite() + { + mRwLock.LockWrite(); + } + + ~RWLockGuard() + { + mRwLock.UnLock(); + } + +private: + NetReadWriteLock &mRwLock; +}; +} +} +#endif \ No newline at end of file diff --git a/src/transport/shm/shm_mr_handle_map.h b/src/transport/shm/shm_mr_handle_map.h new file mode 100644 index 0000000000000000000000000000000000000000..4ee8aa441308b919fd451428a567aa73dc6ae7f1 --- /dev/null +++ b/src/transport/shm/shm_mr_handle_map.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef OCK_HCOM_SHM_HANDLE_MAP_H +#define OCK_HCOM_SHM_HANDLE_MAP_H +#include "shm_lock_guard.h" + +namespace ock { +namespace hcom { +class ShmHandle; +using ShmHandlePtr = NetRef; + +class ShmMRHandleMap { +public: + ShmMRHandleMap(const ShmMRHandleMap &) = delete; + ShmMRHandleMap &operator = (const ShmMRHandleMap &) = delete; + + static ShmMRHandleMap &GetInstance() + { + static ShmMRHandleMap shmMrHandleMap; + return shmMrHandleMap; + } + + inline NResult AddToLocalMap(uint32_t key, const ShmHandlePtr &shmHandle) + { + RWLockGuard(mLRwLock).LockWrite(); + mMrLKeyFdMap.emplace(key, shmHandle); + return NN_OK; + } + + inline void ClearLocalMap() + { + RWLockGuard(mLRwLock).LockWrite(); + if (!mMrLKeyFdMap.empty()) { + mMrLKeyFdMap.clear(); + } + } + + inline ShmHandlePtr GetFromLocalMap(uint32_t key) + { + RWLockGuard(mLRwLock).LockRead(); + auto iter = mMrLKeyFdMap.find(key); + if (iter == mMrLKeyFdMap.end()) { + return nullptr; + } + return iter->second; + } + + inline NResult AddToRemoteMap(uint32_t key, const ShmHandlePtr &shmHandle) + { + RWLockGuard(mRRwLock).LockWrite(); + mMrRKeyFdMap.emplace(key, shmHandle); + return NN_OK; + } + + inline void ClearRemoteMap() + { + RWLockGuard(mRRwLock).LockWrite(); + if (!mMrRKeyFdMap.empty()) { + mMrRKeyFdMap.clear(); + } + } + + inline ShmHandlePtr GetFromRemoteMap(uint32_t key) + { + RWLockGuard(mRRwLock).LockRead(); + auto iter = mMrRKeyFdMap.find(key); + if (iter == mMrRKeyFdMap.end()) { + return nullptr; + } + return iter->second; + } + +private: + ShmMRHandleMap() = default; + ~ShmMRHandleMap() + { + ClearLocalMap(); + ClearRemoteMap(); + } + +private: + NetReadWriteLock mLRwLock; + NetReadWriteLock mRRwLock; + std::map mMrLKeyFdMap; + std::map mMrRKeyFdMap; +}; +} +} +#endif \ No newline at end of file diff --git a/src/transport/shm/shm_mr_pool.cpp b/src/transport/shm/shm_mr_pool.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b3f23dd1d668cb2ac0d73d3f9eef131240bf7173 --- /dev/null +++ b/src/transport/shm/shm_mr_pool.cpp @@ -0,0 +1,129 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "shm_mr_pool.h" +#include "shm_handle.h" + +namespace ock { +namespace hcom { +std::atomic ShmMemoryRegion::LOCAL_KEY_INDEX(0); + +NResult ShmMemoryRegion::Create(const std::string &name, uint64_t size, ShmMemoryRegion *&buf) +{ + if (NN_UNLIKELY(size == 0)) { + NN_LOG_ERROR("Failed to create shm memory region as size is zero"); + return NN_INVALID_PARAM; + } + + auto tmpBuf = new (std::nothrow) ShmMemoryRegion(name, false, 0, size); + if ((NN_UNLIKELY(tmpBuf == nullptr))) { + return NN_NEW_OBJECT_FAILED; + } + + buf = tmpBuf; + + return NN_OK; +} + +NResult ShmMemoryRegion::Create(const std::string &name, uintptr_t address, uint64_t size, ShmMemoryRegion *&buf) +{ + if (NN_UNLIKELY(address == 0 || size == 0)) { + NN_LOG_ERROR("Failed to create shm memory region as size or address is zero"); + return NN_INVALID_PARAM; + } + + auto tmpBuf = new (std::nothrow) ShmMemoryRegion(name, true, address, size); + if ((NN_UNLIKELY(tmpBuf == nullptr))) { + return NN_NEW_OBJECT_FAILED; + } + + buf = tmpBuf; + + return NN_OK; +} + +NResult ShmMemoryRegion::Initialize() +{ + std::lock_guard guard(mMutex); + if (mInited) { + return NN_OK; + } + + if (mExternalMemory) { + if ((mBuf == 0 || mSize == 0)) { + NN_LOG_ERROR("Invalid external memory address or size for Shm memory region " << mName); + return NN_INVALID_PARAM; + } + + mLKey = GenerateKey(); + mInited = true; + + /* don't do bzero to external memory, because this may clean user's data */ + return NN_OK; + } + + /* allocate memory */ + uint64_t newId = NetUuid::GenerateUuid(); + ShmHandlePtr mrHandle = new (std::nothrow) ShmHandle(mName, "mr_" + mName, newId, mSize, true); + if (NN_UNLIKELY(mrHandle == nullptr)) { + NN_LOG_ERROR("Failed to create shm handle for shm memory region " << mName); + return NN_NEW_OBJECT_FAILED; + } + + if (mrHandle->Initialize() != NN_OK) { + NN_LOG_ERROR("Failed to initialize shm handle for shm memory region " << mName); + return NN_NOT_INITIALIZED; + } + + auto tmpBuf = mrHandle->ShmAddress(); + if (tmpBuf == 0) { + NN_LOG_ERROR("Failed to allocate memory for Shm memory region " << mName << " with size " << mSize); + return NN_MALLOC_FAILED; + } + + mBuf = tmpBuf; + mLKey = GenerateKey(); + mMrHandle = mrHandle; + mInited = true; + return NN_OK; +} + +void ShmMemoryRegion::UnInitialize() +{ + std::lock_guard guard(mMutex); + if (!mInited) { + return; + } + + if (!mExternalMemory && mMrHandle != nullptr) { + mMrHandle->UnInitialize(); + mMrHandle = nullptr; + mBuf = 0; + } + + mInited = false; +} + +inline uint32_t ShmMemoryRegion::GenerateKey() +{ + // 获取完整PID并使用哈希混合 + uint32_t pid = static_cast(getpid()); + std::hash hashCount; + + // 混合PID、索引和时间哈希 + uint32_t mix = + hashCount(pid) ^ hashCount(LOCAL_KEY_INDEX.fetch_add(1)) ^ (static_cast(time(nullptr)) & 0xFFFF); + + // 二次混合确保均匀分布 + return (mix ^ (mix >> NN_NO16)) * 0x45d9f3b; +} +} +} \ No newline at end of file diff --git a/src/transport/shm/shm_mr_pool.h b/src/transport/shm/shm_mr_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..909b0e98168cc81c0efc37a548dbdd379d4a89fd --- /dev/null +++ b/src/transport/shm/shm_mr_pool.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_SHM_MEMORY_REGION_H_23234 +#define OCK_HCOM_SHM_MEMORY_REGION_H_23234 + +#include + +#include "hcom.h" +#include "shm_channel.h" + +namespace ock { +namespace hcom { +class ShmMemoryRegion : public UBSHcomNetMemoryRegion { +public: + static NResult Create(const std::string &name, uint64_t size, ShmMemoryRegion *&buf); + static NResult Create(const std::string &name, uintptr_t address, uint64_t size, ShmMemoryRegion *&buf); + + void *GetMemorySeg() override + { + return nullptr; + } + + void GetVa(uint64_t &va, uint64_t &va_len, uint32_t &token_id) override + { + return; + } + +public: + ShmMemoryRegion(const std::string &name, bool extMem, uintptr_t extMemAddress, uint64_t size) + : UBSHcomNetMemoryRegion(name, extMem, extMemAddress, size) + { + OBJ_GC_INCREASE(ShmMemoryRegion); + } + + ~ShmMemoryRegion() override + { + OBJ_GC_DECREASE(ShmMemoryRegion); + } + + NResult Initialize() override; + void UnInitialize() override; + + virtual ShmHandlePtr GetMrHandle() + { + return mMrHandle; + } + +private: + uint32_t GenerateKey(); + +private: + std::mutex mMutex; + bool mInited = false; + ShmHandlePtr mMrHandle = nullptr; + static std::atomic LOCAL_KEY_INDEX; +}; +} +} + +#endif // OCK_HCOM_SHM_MEMORY_REGION_H_23234 diff --git a/src/transport/shm/shm_queue.h b/src/transport/shm/shm_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..9783d35cd934266f6a78c9d97bd526ac3bfa2ef7 --- /dev/null +++ b/src/transport/shm/shm_queue.h @@ -0,0 +1,466 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_SHM_WRAPPER_H +#define OCK_HCOM_SHM_WRAPPER_H + +#include "net_monotonic.h" +#include "shm_common.h" +#include "shm_handle.h" + +namespace ock { +namespace hcom { +struct ShmQueueHeadTail { + volatile uint32_t head = 0; + volatile uint32_t tail = 0; +}; + +struct ShmQueueMeta { + uint32_t capacity = 0; /* capacity of the queue */ + uint32_t mask = 0; /* mask of queue */ + + ShmQueueHeadTail prod {}; /* producer info */ + ShmQueueHeadTail cons {}; /* consumer info */ + + sem_t sem {}; /* sem info */ + + std::string ToString() const + { + std::ostringstream oss; + oss << "capacity " << capacity << ", mask " << mask << ", prod: " << prod.head << "-" << prod.tail << + ", cons: " << cons.head << "-" << cons.tail; + return oss.str(); + } +}; + +template class ShmQueue { +public: + static const HResult SHM_QUEUE_FULL = -1; + static const HResult SHM_QUEUE_EMPTY = -2; + static const HResult SHM_QUEUE_NOT_INIT = -3; + +public: + ShmQueue(const std::string &name, uint32_t capacity, const ShmHandlePtr &shmHandle) + : mShmHandle(shmHandle.Get()), mName(name), mCapacity(capacity) + { + OBJ_GC_INCREASE(ShmQueue); + } + + ~ShmQueue() + { + UnInitialize(); + OBJ_GC_DECREASE(ShmQueue); + } + + static inline uint32_t MemSize(uint32_t capacity) + { + auto tmp = capacity; + if (!POWER_OF_2(capacity)) { + tmp = NN_NextPower2(capacity); + } + + return sizeof(ShmQueueMeta) + sizeof(T) * tmp; + } + + /* + * @brief Initialize + */ + HResult Initialize() + { + if (mInited) { + return SH_OK; + } + + if (mShmHandle.Get() == nullptr || mShmHandle->Initialize() != NN_OK || mCapacity == 0) { + NN_LOG_ERROR("Failed to initialize shm queue " << mName); + return SH_PARAM_INVALID; + } + + /* check if capacity is power of 2 */ + if (!POWER_OF_2(mCapacity)) { + mCapacity = NN_NextPower2(mCapacity); + } + + if (mShmHandle->DataSize() != MemSize(mCapacity)) { + NN_LOG_ERROR("Failed to initialize shm queue " << mName << " as size not matched, " << + mShmHandle->DataSize() << "!=" << MemSize(mCapacity)); + return SH_PARAM_INVALID; + } + + /* + * for example capacity is 4 [100], then mask is 3 [011] + * for tail/head fast reverse + */ + mMask = mCapacity - 1; + + mQueueMeta = reinterpret_cast(mShmHandle->ShmAddress()); + + NN_LOG_TRACE_INFO("shm mem base info, sizeof(ShmQueueMeta) " << sizeof(ShmQueueMeta) << ", meta " << + mQueueMeta->ToString()); + + if (mShmHandle->IsOwner()) { + /* set meta */ + bzero(reinterpret_cast(mShmHandle->ShmAddress()), mShmHandle->DataSize()); + mQueueMeta->capacity = mCapacity; + mQueueMeta->mask = mMask; + mQueueMeta->prod.head = 0; + mQueueMeta->prod.tail = 0; + mQueueMeta->cons.head = 0; + mQueueMeta->cons.tail = 0; + + auto result = sem_init(&mQueueMeta->sem, 1, 0); + if (result != 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed initialize shm sem for queue " << mName << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return result; + } + } else { + /* get capacity and mask */ + mCapacity = mQueueMeta->capacity; + mMask = mQueueMeta->mask; + } + + mQueueData = reinterpret_cast(mShmHandle->ShmAddress() + sizeof(ShmQueueMeta)); + mMaxFailedTime = static_cast(NetFunc::NN_GetLongEnv("HCOM_SHM_MAX_ENQUEUE_STUCK_TIME", + NN_NO1, NN_NO7200, NN_NO10)); + mMaxEnqueueTimeout = static_cast(NetFunc::NN_GetLongEnv("HCOM_SHM_ENQUEUE_TIMEOUT", + NN_NO1, NN_NO7200, NN_NO20)); + NN_LOG_DEBUG("SHM: mMaxFailedTime " << mMaxFailedTime << ", mMaxEnqueueTimeout " << mMaxEnqueueTimeout); + mInited = true; + return SH_OK; + } + + void UnInitialize() + { + if (!mInited) { + return; + } + + sem_destroy(&mQueueMeta->sem); + mQueueMeta = nullptr; + mQueueData = nullptr; + mCapacity = 0; + mMask = 0; + mInited = false; + mShmHandle.Set(nullptr); + } + + inline HResult Enqueue(T &item) + { + if (NN_UNLIKELY(!mInited)) { + return SHM_QUEUE_NOT_INIT; + } + + uint32_t oldHead = 0; + uint32_t newHead = 0; + + uint32_t remainCapacity = 0; + /* request space, if no space return error */ + if (NN_UNLIKELY(RequestProduceSpace(oldHead, newHead, 1, remainCapacity) == 0)) { + return SHM_QUEUE_FULL; + } + + /* write item to the right place */ + EnqueueItem(oldHead, item); + /* update tail */ + if (NN_UNLIKELY(!UpdateProdTail(true, oldHead, newHead))) { + return SHM_QUEUE_FULL; + } + + return SH_OK; + } + + inline HResult Dequeue(T &item) + { + if (NN_UNLIKELY(!mInited)) { + return SHM_QUEUE_NOT_INIT; + } + + uint32_t oldHead = 0; + uint32_t newHead = 0; + + uint32_t dequeueSize = 0; + /* request 1 item for dequeue */ + if (RequestDequeue(oldHead, newHead, 1, dequeueSize) == 0) { + return SHM_QUEUE_EMPTY; + } + + /* if will consume the failed one, jump to the next. Since low probability to happen, handle it once */ + if (oldHead == mFailedProd) { + NN_LOG_WARN("Skip the failed prod " << mFailedProd); + /* update the tail */ + UpdateConsTail(false, oldHead, newHead); + /* request 1 item for dequeue */ + if (RequestDequeue(oldHead, newHead, 1, dequeueSize) == 0) { + return SHM_QUEUE_EMPTY; + } + } + + /* read the item from right place */ + DequeueItem(oldHead, item); + + /* update the tail */ + UpdateConsTail(false, oldHead, newHead); + + return SH_OK; + } + + inline HResult EnqueueAndNotify(T &item) + { + auto result = Enqueue(item); + if (NN_UNLIKELY(result != SH_OK)) { + return result; + } + + return sem_post(&mQueueMeta->sem); + } + + inline void LocalStopAndNotify() + { + mStop = true; + sem_post(&mQueueMeta->sem); + } + + inline bool CompTime(const struct timespec &a, const struct timespec &b) + { + if (a.tv_sec != b.tv_sec) { + return (a.tv_sec > b.tv_sec); + } + return (a.tv_nsec > b.tv_nsec); + } + + inline void CheckAndMarkProducerState() + { + if (mQueueMeta->prod.head > mQueueMeta->prod.tail) { + if (mTempProdIdx != mQueueMeta->prod.tail) { + mTempProdIdx = mQueueMeta->prod.tail; + struct timespec timeOutTime = MONOTONIC_TIME(); + timeOutTime.tv_sec += static_cast(mMaxFailedTime); + mFailedTime = timeOutTime; + return; + } + } else { + mTempProdIdx = UINT64_MAX; + return; + } + + struct timespec nowTime = MONOTONIC_TIME(); + if (mTempProdIdx != UINT64_MAX && CompTime(nowTime, mFailedTime)) { + mFailedProd = mTempProdIdx; + mTempProdIdx = UINT64_MAX; + mQueueMeta->prod.tail++; + NN_LOG_WARN("Dectected enqueue stuck, skip idx: " << mFailedProd); + } + } + + inline HResult DequeueOrWait(T &item, bool &stopping, int32_t timeoutInMs) + { + auto start = NetMonotonic::TimeMs(); + while (true) { + /* stopping */ + if (NN_UNLIKELY(mStop)) { + stopping = true; + return SH_OK; + } + // check if any producer stuck in enqueue. If stuck, kick it out + CheckAndMarkProducerState(); + + auto pollTime = NetMonotonic::TimeMs() - start; + if (timeoutInMs >= 0 && pollTime >= static_cast(timeoutInMs)) { + return SH_TIME_OUT; + } + + struct timespec semTimeout {}; + if (timeoutInMs < 0) { + // set 0 means never timeout + semTimeout.tv_sec = 0; + semTimeout.tv_nsec = 0; + } else { + clock_gettime(CLOCK_REALTIME, &semTimeout); + semTimeout.tv_nsec += + static_cast(static_cast(static_cast(timeoutInMs)) * NN_NO1000000); + if (semTimeout.tv_nsec >= static_cast(NN_NO1000000000)) { + semTimeout.tv_sec += semTimeout.tv_nsec / NN_NO1000000000; + semTimeout.tv_nsec %= NN_NO1000000000; + } + } + + if (sem_timedwait(&mQueueMeta->sem, &semTimeout) != 0) { + continue; + } + + /* dequeue */ + if (NN_LIKELY(Dequeue(item) == SH_OK)) { + return SH_OK; + } + } + } + + inline uint32_t Capacity() const + { + return mCapacity; + } + + std::string ToString() const + { + std::ostringstream oss; + oss << "name: " << mName << ", capacity: " << mCapacity; + return oss.str(); + } + +private: + inline uint32_t RequestProduceSpace(uint32_t &oHead, uint32_t &nHead, uint32_t reqSize, uint32_t &freeSpace) + { + const uint32_t capacity = mCapacity; + bool successful = false; + uint32_t tmpReqCount = reqSize; + + do { + /* major 3 steps: + * step 1: assign, global variable to local one + * step 2: calculate target variables + * --- a) check if there are enough spaces, return if no space left, + * --- b) calculate target + * step 3: commit using atomic operations, if failed try + */ + reqSize = tmpReqCount; + + oHead = mQueueMeta->prod.head; + + /* read barrier avoid order */ + H_RMB(); + + freeSpace = (capacity + mQueueMeta->cons.tail - oHead); + + /* no free space */ + if (H_UNLIKELY(reqSize > freeSpace)) { + return 0; + } + + nHead = oHead + reqSize; + + /* commit */ + successful = H_CAS(&mQueueMeta->prod.head, oHead, nHead); + } while (H_UNLIKELY(!successful)); + + return reqSize; + } + + inline uint32_t RequestDequeue(uint32_t &oHead, uint32_t &nHead, uint32_t reqSize, uint32_t &dequeueCount) + { + bool successful = false; + do { + /* major 3 steps: + * step 1: assign, global variable to local one + * step 2: calculate target variables + * --- a) check if there are enough items, if no just we have + * --- b) calculate target + * step 3: commit using atomic operations, if failed try + */ + oHead = mQueueMeta->cons.head; + + /* read barrier avoid order */ + H_RMB(); + + dequeueCount = mQueueMeta->prod.tail - oHead; + if (dequeueCount == 0) { + return 0; + } else if (dequeueCount > reqSize) { + dequeueCount = reqSize; + } + + nHead = oHead + dequeueCount; + + successful = H_CAS(&mQueueMeta->cons.head, oHead, nHead); + } while (H_UNLIKELY(!successful)); + + return dequeueCount; + } + + inline void EnqueueItem(uint32_t oHead, T &item) + { + mQueueData[oHead & mMask] = item; + } + + inline void DequeueItem(uint32_t oHead, T &item) + { + item = mQueueData[oHead & mMask]; + } + + /* + * @brief update produce tail + */ + inline bool UpdateProdTail(bool enqueue, uint32_t oldTail, uint32_t newTail) + { + if (enqueue) { + H_WMB(); + } else { + H_RMB(); + } + + uint64_t endTimeSecond = NetMonotonic::TimeSec() + mMaxEnqueueTimeout; + // if others is enqueue/dequeue in progress, wait + uint32_t cmpTail = oldTail; + while (H_UNLIKELY(!H_CAS(&mQueueMeta->prod.tail, cmpTail, newTail))) { + cmpTail = oldTail; + if (NetMonotonic::TimeSec() > endTimeSecond) { + NN_LOG_ERROR("Update Prod tail failed, timeout."); + return false; + } + } + return true; + } + + /* + * @brief update consume tail + */ + inline void UpdateConsTail(bool enqueue, uint32_t oldTail, uint32_t newTail) + { + if (enqueue) { + H_WMB(); + } else { + H_RMB(); + } + + /* if others is enqueue/dequeue in progress, wait */ + while (H_UNLIKELY(mQueueMeta->cons.tail != oldTail)) { + H_Pause(); + } + + mQueueMeta->cons.tail = newTail; + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + ShmQueueMeta *mQueueMeta = nullptr; + T *mQueueData = nullptr; + bool mStop = false; + ShmHandlePtr mShmHandle = nullptr; + + std::string mName; + bool mInited = false; + uint32_t mCapacity = 0; + uint32_t mMask = 0; + + uint32_t mMaxFailedTime = 10; + uint32_t mMaxEnqueueTimeout = 20; + uint64_t mTempProdIdx = UINT64_MAX; + uint64_t mFailedProd = UINT64_MAX; + struct timespec mFailedTime {}; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; +} +} + +#endif // OCK_HCOM_SHM_WRAPPER_H diff --git a/src/transport/shm/shm_validation.h b/src/transport/shm/shm_validation.h new file mode 100644 index 0000000000000000000000000000000000000000..3fec7b8cf0bdd842a8fe1f89a6092c5dae935bbc --- /dev/null +++ b/src/transport/shm/shm_validation.h @@ -0,0 +1,191 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_SHM_VALIDATION_H +#define OCK_HCOM_NET_SHM_VALIDATION_H +#ifdef SHM_BUILD_ENABLED + +#include "hcom.h" +#include "hcom_utils.h" +#include "net_common.h" +#include "net_monotonic.h" +#include "net_security_alg.h" +#include "net_shm_common.h" +#include "net_shm_driver_oob.h" + + +namespace ock { +namespace hcom { +#define VALIDATE_ENCRYPT_LENGTH(encryptLen, calLen, mShmCh, address) \ + if (NN_UNLIKELY((encryptLen) != (calLen))) { \ + NN_LOG_ERROR("Failed to encrypt data as encrypt length " << (encryptLen) << " is not equal to cal length " << \ + (calLen)); \ + (mShmCh)->DCMarkBuckFree((address)); \ + return NN_ENCRYPT_FAILED; \ + } + +#define VALIDATE_DECRYPT_LENGTH(decryptLen, calLen, opCtx) \ + if (NN_UNLIKELY((decryptLen) != (calLen))) { \ + NN_LOG_ERROR("Failed to decrypt data as decrypt length " << (decryptLen) << " is not equal to cal length " << \ + (calLen)); \ + (opCtx).channel->DCMarkPeerBuckFree((opCtx).dataAddress); \ + return NN_DECRYPT_FAILED; \ + } + +static __always_inline NResult PostSendValidation(UBSHcomNetAtomicState &state, uint64_t id, + uint16_t opCode, const UBSHcomNetTransRequest &request) +{ + if (NN_UNLIKELY(!state.Compare(NEP_ESTABLISHED))) { + NN_LOG_ERROR("Endpoint " << id << " is not established, state is " << UBSHcomNEPStateToString(state.Get())); + return NN_EP_NOT_ESTABLISHED; + } + if (NN_UNLIKELY(opCode >= MAX_OPCODE)) { + NN_LOG_ERROR("Failed to post message as opcode is invalid, which should with the range 0~" << (MAX_OPCODE - 1)); + return NN_INVALID_OPCODE; + } + if (NN_UNLIKELY(request.lAddress == 0 || request.size == 0)) { + NN_LOG_ERROR("Failed to post message as source data is null or size is zero"); + return NN_INVALID_PARAM; + } + return NN_OK; +} + +static __always_inline NResult PostSendRawValidation(UBSHcomNetAtomicState &state, uint64_t id, + const UBSHcomNetTransRequest &request) +{ + if (NN_UNLIKELY(!state.Compare(NEP_ESTABLISHED))) { + NN_LOG_ERROR("Endpoint " << id << " is not established, state is " << UBSHcomNEPStateToString(state.Get())); + return NN_EP_NOT_ESTABLISHED; + } + + if (NN_UNLIKELY(request.lAddress == 0 || request.size == 0)) { + NN_LOG_ERROR("Failed to post message as source data is null or size is zero"); + return NN_INVALID_PARAM; + } + return NN_OK; +} + +static __always_inline NResult PostSendValidationMaxSize(const UBSHcomNetTransRequest &request, uint32_t allowedSize, + bool mIsNeedEncrypt, AesGcm128 mAes) +{ + size_t size = request.size; + if (mIsNeedEncrypt) { + size = mAes.EstimatedEncryptLen(request.size); + } + if (NN_UNLIKELY(size > allowedSize)) { + NN_LOG_ERROR("Failed to post message as message size " << size << " is too large, use one side post"); + return NN_TWO_SIDE_MESSAGE_TOO_LARGE; + } + return NN_OK; +} + +static __always_inline NResult PostSendSglValidationInner(uint64_t &size, const UBSHcomNetTransSglRequest &request, + NetDriverShmWithOOB *driver, uint32_t allowedSize, bool mIsNeedEncrypt, AesGcm128 mAes) +{ + for (uint16_t i = 0; i < request.iovCount; ++i) { + auto &&iov = request.iov[i]; + if (NN_OK != driver->ValidateMemoryRegion(iov.lKey, iov.lAddress, iov.size)) { + NN_LOG_ERROR("Invalid MemoryRegion or lkey in iov"); + return NN_INVALID_LKEY; + } + size += iov.size; + } + + if (mIsNeedEncrypt) { + size = mAes.EstimatedEncryptLen(size); + NN_LOG_INFO("size after encrypt is " << size << " allowedSize is " << allowedSize); + } + + if (NN_UNLIKELY(size > allowedSize)) { + NN_LOG_ERROR("Failed to post raw sgl message as size " << size << " is too large, use one side instead"); + return NN_TWO_SIDE_MESSAGE_TOO_LARGE; + } + return NN_OK; +} + +static __always_inline NResult PostSendSglValidation(UBSHcomNetAtomicState &state, uint64_t id, + NetDriverShmWithOOB *driver, uint32_t seqNo, const UBSHcomNetTransSglRequest &request, uint32_t allowedSize, + bool mIsNeedEncrypt, AesGcm128 mAes) +{ + if (NN_UNLIKELY(request.iov == nullptr || request.iovCount > NET_SGE_MAX_IOV || request.iovCount == 0)) { + NN_LOG_ERROR("Invalid iov ptr:" << request.iov << " or iov cnt:" << request.iovCount); + return NN_PARAM_INVALID; + } + + if (NN_UNLIKELY(!state.Compare(NEP_ESTABLISHED))) { + NN_LOG_ERROR("Endpoint " << id << " is not established, state is " << UBSHcomNEPStateToString(state.Get())); + return NN_EP_NOT_ESTABLISHED; + } + + if (NN_UNLIKELY(seqNo == 0)) { + NN_LOG_ERROR("Failed to post raw sgl message as seqNo must > 0"); + return NN_INVALID_PARAM; + } + + uint64_t size = 0; + if (NN_UNLIKELY(PostSendSglValidationInner(size, request, driver, allowedSize, mIsNeedEncrypt, mAes) != NN_OK)) { + return NN_INVALID_PARAM; + } + return NN_OK; +} + +static __always_inline NResult ReadWriteValidation(UBSHcomNetAtomicState &state, uint64_t id, + NetDriverShmWithOOB *driver, ShmChannel *shmCh, const UBSHcomNetTransRequest &request) +{ + if (NN_UNLIKELY(!state.Compare(NEP_ESTABLISHED))) { + NN_LOG_ERROR("Endpoint " << id << " is not established, state is " << UBSHcomNEPStateToString(state.Get())); + return NN_EP_NOT_ESTABLISHED; + } + + if (NN_UNLIKELY(shmCh == nullptr || driver == nullptr)) { + NN_LOG_ERROR("Invalid endpoint"); + return NN_ERROR; + } + + if (NN_OK != driver->ValidateMemoryRegion(request.lKey, request.lAddress, request.size)) { + NN_LOG_ERROR("Invalid MemoryRegion or lkey"); + return NN_INVALID_LKEY; + } + return NN_OK; +} + +static __always_inline NResult PostReadWriteSglValidation(UBSHcomNetAtomicState &state, + uint32_t id, NetDriverShmWithOOB *driver, ShmChannel *shmCh, const UBSHcomNetTransSglRequest &request) +{ + if (NN_UNLIKELY(request.iov == nullptr || request.iovCount > NET_SGE_MAX_IOV || request.iovCount == 0)) { + NN_LOG_ERROR("Invalid iov ptr: " << request.iov << " or iov cnt: " << request.iovCount); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(!state.Compare(NEP_ESTABLISHED))) { + NN_LOG_ERROR("Endpoint " << id << " is not established, state is " << UBSHcomNEPStateToString(state.Get())); + return NN_EP_NOT_ESTABLISHED; + } + + if (NN_UNLIKELY(shmCh == nullptr || driver == nullptr)) { + NN_LOG_ERROR("Invalid endpoint"); + return NN_ERROR; + } + + auto iovCount = request.iovCount; + for (auto i = 0; i < iovCount; i++) { + auto iov = request.iov[i]; + if (NN_OK != driver->ValidateMemoryRegion(iov.lKey, iov.lAddress, iov.size)) { + NN_LOG_ERROR("Invalid MemoryRegion or lkey"); + return NN_INVALID_LKEY; + } + } + return NN_OK; +} +} +} +#endif +#endif // OCK_HCOM_NET_SHM_VALIDATION_H diff --git a/src/transport/shm/shm_worker.cpp b/src/transport/shm/shm_worker.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b695c1e7ee4ba117a4651203c3b6716d767d265a --- /dev/null +++ b/src/transport/shm/shm_worker.cpp @@ -0,0 +1,364 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "shm_worker.h" +#include "shm_handle.h" +#include "shm_queue.h" + +namespace ock { +namespace hcom { +std::atomic ShmWorker::GLOBAL_WORKER_INDEX(0); + +ShmWorker::ShmWorker(const std::string &name, const UBSHcomNetWorkerIndex &index, const ShmWorkerOptions &options, + const NetMemPoolFixedPtr &opMemPool, const NetMemPoolFixedPtr &opCtxMemPool, const NetMemPoolFixedPtr &sglOpMemPool) + : mName(name + index.ToString()), mIndex(index), mOptions(options) +{ + if (mOpCompInfoPool.Initialize(opMemPool) != NN_OK) { + NN_LOG_ERROR("Failed to initialize op complete pool for worker " << mName); + } + + if (mOpCtxInfoPool.Initialize(opCtxMemPool) != NN_OK) { + NN_LOG_ERROR("Failed to initialize op ctx pool for worker " << mName); + } + + if (mSglCtxInfoPool.Initialize(sglOpMemPool) != NN_OK) { + NN_LOG_ERROR("Failed to initialize sgl op ctx pool for worker " << mName); + } + + OBJ_GC_INCREASE(ShmWorker); +} + +HResult ShmWorker::Initialize() +{ + std::lock_guard locker(mMutex); + if (mInited) { + return SH_OK; + } + + /* validate */ + HResult result = SH_OK; + if ((result = Validate()) != SH_OK) { + NN_LOG_ERROR("Failed to validate in shm worker initialize"); + return result; + } + + /* create event queue */ + if ((result = CreateEventQueue()) != SH_OK) { + NN_LOG_ERROR("Failed to create event queue in shm worker initialize"); + return result; + } + + mInited = true; + return SH_OK; +} + +void ShmWorker::UnInitialize() +{ + std::lock_guard locker(mMutex); + if (!mInited) { + return; + } + + if (mEventQueue != nullptr) { + mEventQueue->DecreaseRef(); + mEventQueue = nullptr; + } +} + +HResult ShmWorker::Validate() +{ + // do later + return SH_OK; +} + +HResult ShmWorker::CreateEventQueue() +{ + if (mEventQueue != nullptr) { + NN_LOG_ERROR("Event queue is already created in shm worker " << mName); + return SH_ERROR; + } + + /* get id and data size */ + auto id = GLOBAL_WORKER_INDEX++; + uint64_t dataSize = ShmEventQueue::MemSize(mOptions.eventQueueLength); + + /* create handle for event queue */ + HResult result = SH_OK; + ShmHandlePtr tmpHandle = new (std::nothrow) ShmHandle(mName, SHM_F_EVENT_QUEUE_PREFIX, id, dataSize, true); + if (NN_UNLIKELY(tmpHandle.Get() == nullptr)) { + NN_LOG_ERROR("Failed to new shm handle for worker " << mName << ", probably out of memory"); + return SH_NEW_OBJECT_FAILED; + } + + /* create and initialize event queue */ + ShmEventQueuePtr tmpQueue = new (std::nothrow) ShmEventQueue(mName, mOptions.eventQueueLength, tmpHandle); + if (NN_UNLIKELY(tmpQueue.Get() == nullptr)) { + NN_LOG_ERROR("Failed to new shm event queue for worker " << mName); + return SH_NEW_OBJECT_FAILED; + } + + if ((result = tmpQueue->Initialize()) != SH_OK) { + return result; + } + + /* assign member variables */ + mHandleEventQueue.Set(tmpHandle.Get()); + mEventQueue = tmpQueue.Get(); + mEventQueue->IncreaseRef(); + + return SH_OK; +} + +void SetThreadNameAndAffinity(const std::string& name, int16_t cpuId) +{ + pthread_setname_np(pthread_self(), name.c_str()); + + if ((cpuId) != -1) { + cpu_set_t cpuSet; + CPU_ZERO(&cpuSet); + CPU_SET(cpuId, &cpuSet); + if (pthread_setaffinity_np(pthread_self(), sizeof(cpuSet), &cpuSet) != 0) { + NN_LOG_WARN("Unable to bind shm worker " << name << " << to cpu " << cpuId); + } + } +} + +void ShmWorker::RunInThread(int16_t cpuId) +{ + SetThreadNameAndAffinity("shmWorker" + mIndex.ToString(), cpuId); + + if (mOptions.threadPriority != 0) { + if (NN_UNLIKELY(setpriority(PRIO_PROCESS, 0, mOptions.threadPriority) != 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_WARN("Unable to set worker thread priority in shm worker " << mName << ", errno:" << errno << + " error:" << NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + } + } + + mProgressThrStarted.store(true); + + if (mOptions.mode == SHM_EVENT_POLLING) { + DoEventPolling(); + } else if (mOptions.mode == SHM_BUSY_POLLING) { + DoBusyPolling(); + } + + NN_LOG_INFO("Shm worker " << mName << " progress thread exiting"); +} + +#define HANDLE_OP(event) \ + do { \ + if ((event).opType == ShmOpContextInfo::ShmOpType::SH_SEND || \ + (event).opType == ShmOpContextInfo::ShmOpType::SH_SEND_RAW_SGL) { \ + if (NN_UNLIKELY((event).shmChannel == nullptr)) { \ + NN_LOG_WARN("Got invalid event " << (event).opType << ", in worker " << mName << " as ch is null"); \ + /* if state is broken ch has already decreased\remove\return in keeper thread */ \ + continue; \ + } \ + if (NN_UNLIKELY((event).shmChannel->State().Compare(CH_BROKEN))) { \ + (event).shmChannel->DecreaseRef(); \ + NN_LOG_WARN("Got invalid event " << (event).opType << " in worker " << mName << " as ch is broken"); \ + /* if state is broken ch has already decreased\remove\return in keeper thread */ \ + continue; \ + } \ + \ + auto compEvent = reinterpret_cast((event).peerChannelAddress); \ + if (NN_UNLIKELY(compEvent == nullptr)) { \ + NN_LOG_WARN("Got invalid event " << (event).opType << " as ctx is null"); \ + (event).shmChannel->DecreaseRef(); \ + continue; \ + } \ + \ + /* if state is broken in this line, ch has already decreased\remove\return in keeper thread, res is 0 */ \ + if (NN_UNLIKELY(compEvent->channel->RemoveOpCompInfo(compEvent) != SH_OK)) { \ + NN_LOG_WARN("Got invalid event " << (event).opType << " as ctx may be removed by keeper thread"); \ + (event).shmChannel->DecreaseRef(); \ + continue; \ + } \ + \ + /* call upper completion function */ \ + /* decrease channel ref and return comp info in upper call */ \ + mSendPostedHandler(*compEvent); \ + (event).shmChannel->DecreaseRef(); \ + } else if ((event).opType == ShmOpContextInfo::ShmOpType::SH_RECEIVE) { \ + auto *ch = reinterpret_cast((event).peerChannelAddress); \ + /* if reset by peer ch ref decrease to 0, ch will be 0 */ \ + if (NN_UNLIKELY(ch == nullptr)) { \ + NN_LOG_WARN("Got invalid event " << (event).opType << " in worker " << mName << ", dropped it"); \ + continue; \ + } \ + \ + if (NN_UNLIKELY(!ch->State().Compare(CH_NEW))) { \ + NN_LOG_WARN("Got invalid event " << (event).opType << " in worker " << mName << " as ch is broken"); \ + continue; \ + } \ + \ + uintptr_t address = 0; \ + if (NN_UNLIKELY((ch->GetPeerDataAddressByOffset((event).dataOffset, address)) != SH_OK)) { \ + NN_LOG_WARN("Got invalid event in worker " << mName << " as get data address failed, dropped it"); \ + continue; \ + } \ + \ + ShmOpContextInfo ctx(ch, address, (event).dataSize, \ + static_cast((event).opType), \ + ShmOpContextInfo::ShmErrorType::SH_NO_ERROR); \ + mNewRequestHandler(ctx, (event).immData); \ + } else if ((event).opType == ShmOpContextInfo::ShmOpType::SH_READ || \ + (event).opType == ShmOpContextInfo::ShmOpType::SH_WRITE || \ + (event).opType == ShmOpContextInfo::ShmOpType::SH_SGL_READ || \ + (event).opType == ShmOpContextInfo::ShmOpType::SH_SGL_WRITE) { \ + if (NN_UNLIKELY((event).shmChannel == nullptr)) { \ + NN_LOG_WARN("Got invalid event " << (event).opType << " in worker " << mName << " as ch is null"); \ + /* if state is broken ch has already decreased\remove\return in keeper thread */ \ + continue; \ + } \ + if (NN_UNLIKELY((event).shmChannel->State().Compare(CH_BROKEN))) { \ + (event).shmChannel->DecreaseRef(); \ + NN_LOG_WARN("Got invalid event " << (event).opType << " in worker " << mName << " as ch is broken"); \ + /* if state is broken ch has already decreased\remove\return in keeper thread */ \ + continue; \ + } \ + \ + auto ctx = reinterpret_cast((event).peerChannelAddress); \ + if (NN_UNLIKELY(ctx == nullptr)) { \ + NN_LOG_WARN("Got invalid event " << (event).opType << " as ctx is null, ch may be broken"); \ + (event).shmChannel->DecreaseRef(); \ + continue; \ + } \ + /* if state is broken in this line, ch has already decreased\remove\return in keeper thread, res is 0 */ \ + if (NN_UNLIKELY((event).shmChannel->RemoveOpCtxInfo(ctx) != SH_OK)) { \ + NN_LOG_WARN("Got invalid event " << (event).opType << " as ctx may be removed by keeper thread"); \ + (event).shmChannel->DecreaseRef(); \ + continue; \ + } \ + /* call upper completion function */ \ + /* decrease channel ref and return comp info in upper call */ \ + mOneSideDoneHandler(ctx); \ + (event).shmChannel->DecreaseRef(); \ + } \ + } while (0) + +void ShmWorker::DoEventPolling() +{ + ShmEvent event {}; + bool stopping = false; + + HResult result; + while (!mNeedToStop) { + if (NN_UNLIKELY((result = mEventQueue->DequeueOrWait(event, stopping, mOptions.pollingTimeoutMs)) != SH_OK)) { + /* timeout need invoke idle */ + if (mIdleHandler != nullptr) { + mIdleHandler(mIndex); + } + continue; + } + + if (NN_UNLIKELY(stopping)) { + NN_LOG_INFO("Get stop sign in shm worker " << mName << ", stopping"); + break; + } + + TRACE_DELAY_BEGIN(SHM_WORKER_EVENT_POLLING); + HANDLE_OP(event); + TRACE_DELAY_END(SHM_WORKER_EVENT_POLLING, 0); + + NN_LOG_TRACE_INFO("got event " << event.ToString() << ", result " << result); + } +} + +void ShmWorker::DoBusyPolling() +{ + ShmEvent event {}; + + while (!mNeedToStop) { + auto result = mEventQueue->Dequeue(event); + if (result == ShmEventQueue::SHM_QUEUE_EMPTY) { + // check if any producer stuck in enqueue. If stuck, kick it out + mEventQueue->CheckAndMarkProducerState(); + /* if there is no coming request, call up idle function */ + if (mIdleHandler != nullptr) { + mIdleHandler(mIndex); + } + continue; + } + + TRACE_DELAY_BEGIN(SHM_WORKER_BUSY_POLLING); + HANDLE_OP(event); + TRACE_DELAY_END(SHM_WORKER_BUSY_POLLING, 0); + + NN_LOG_TRACE_INFO("got event " << event.ToString() << ", result " << result); + } +} + +HResult ShmWorker::Start() +{ + std::lock_guard guard(mMutex); + if (!mInited) { + NN_LOG_ERROR("Failed to start shm worker " << mName << " as it is not initialized"); + return SH_ERROR; + } + + if (mStarted) { + NN_LOG_WARN("Unable to start shm worker " << mName << " as it is already started"); + return SH_OK; + } + + /* validate handler */ + if (mNewRequestHandler == nullptr) { + NN_LOG_ERROR("Failed to start shm worker " << mName << " as new request handler is null"); + return SH_PARAM_INVALID; + } + + if (mSendPostedHandler == nullptr) { + NN_LOG_ERROR("Failed to start shm worker " << mName << " as request posted handler is null"); + return SH_PARAM_INVALID; + } + + if (mOneSideDoneHandler == nullptr) { + NN_LOG_ERROR("Failed to start shm worker " << mName << " as one side done handler is null"); + return SH_PARAM_INVALID; + } + + mNeedToStop = false; + std::thread tmpThread(&ShmWorker::RunInThread, this, mOptions.cpuId); + mProgressThr = std::move(tmpThread); + + while (!mProgressThrStarted.load()) { + usleep(NN_NO10); + } + + mProgressThrStarted = false; + + mStarted = true; + return SH_OK; +} + +void ShmWorker::Stop() +{ + std::lock_guard guard(mMutex); + if (!mStarted) { + return; + } + + mNeedToStop = true; + + if (mOptions.mode == SHM_EVENT_POLLING) { + mEventQueue->LocalStopAndNotify(); + } + + if (mProgressThr.joinable()) { + mProgressThr.join(); + } + + mStarted = false; +} +} +} \ No newline at end of file diff --git a/src/transport/shm/shm_worker.h b/src/transport/shm/shm_worker.h new file mode 100644 index 0000000000000000000000000000000000000000..2958fec7b9532b6a96ea44d450fd715d713d8503 --- /dev/null +++ b/src/transport/shm/shm_worker.h @@ -0,0 +1,312 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_SHM_WORKER_H +#define OCK_HCOM_SHM_WORKER_H + +#include + +#include "net_monotonic.h" +#include "shm_common.h" +#include "shm_handle.h" +#include "shm_queue.h" +#include "shm_channel.h" + +namespace ock { +namespace hcom { +using ShmNewReqHandler = std::function; +using ShmPostedHandler = std::function; +using ShmOneSideHandler = std::function; + +struct ShmWorkerOptions { + ShmPollingMode mode = SHM_BUSY_POLLING; + uint16_t eventQueueLength = NN_NO8192; + int16_t cpuId = -1; + uint32_t pollingTimeoutMs = NN_NO500; /* epoll or poll timeout */ + uint16_t pollingBatchSize = NN_NO8; + /* worker thread priority [-20,20], 20 is the lowest, -20 is the highest, 0 (default) means do not set priority */ + int threadPriority = 0; + + std::string ToShortString() const + { + std::ostringstream oss; + oss << "mode: " << ShmPollingModeToStr(mode) << ", poll-timeout: " << pollingTimeoutMs << "us, event-q-cap: " << + eventQueueLength; + return oss.str(); + } +}; + +class ShmWorker { +public: + ShmWorker(const std::string &name, const UBSHcomNetWorkerIndex &index, const ShmWorkerOptions &options, + const NetMemPoolFixedPtr &opMemPool, const NetMemPoolFixedPtr &opCtxMemPool, + const NetMemPoolFixedPtr &sglOpMemPool); + + ~ShmWorker() + { + Stop(); + UnInitialize(); + + OBJ_GC_DECREASE(ShmWorker); + } + + HResult Initialize(); + void UnInitialize(); + + HResult Start(); + void Stop(); + + inline void RegisterNewReqHandler(const ShmNewReqHandler &h) + { + mNewRequestHandler = h; + } + + inline void RegisterReqPostedHandler(const ShmPostedHandler &h) + { + mSendPostedHandler = h; + } + + inline void RegisterOneSideHandler(const ShmOneSideHandler &h) + { + mOneSideDoneHandler = h; + } + + inline void RegisterIdleHandler(const ShmIdleHandler &h) + { + mIdleHandler = h; + } + + inline const std::string &Name() const + { + return mName; + } + + inline bool FillQueueExchangeInfo(ShmConnExchangeInfo &info) + { + if (NN_UNLIKELY(mEventQueue != nullptr)) { + info.qCapacity = mEventQueue->Capacity(); + } + + if (NN_LIKELY(mHandleEventQueue.Get() != nullptr)) { + info.queueFd = mHandleEventQueue->Fd(); + return info.SetQueueName(mHandleEventQueue->FullPath()); + } + + info.mode = mOptions.mode; + return false; + } + + inline const UBSHcomNetWorkerIndex &Index() const + { + return mIndex; + } + + inline void ReturnOpContextInfo(ShmOpContextInfo *ctx) + { + if (NN_LIKELY(ctx != nullptr)) { + if (NN_LIKELY(ctx->channel != nullptr)) { + ctx->channel->DecreaseRef(); + } + mOpCtxInfoPool.Return(ctx); + ctx = nullptr; + } + } + + inline void ReturnOpCompInfo(ShmOpCompInfo *ctx) + { + if (NN_LIKELY(ctx != nullptr)) { + if (NN_LIKELY(ctx->channel != nullptr)) { + ctx->channel->DecreaseRef(); + } + mOpCompInfoPool.Return(ctx); + ctx = nullptr; + } + } + + inline void ReturnSglContextInfo(ShmSglOpContextInfo *&ctx) + { + if (NN_LIKELY(ctx != nullptr)) { + mSglCtxInfoPool.Return(ctx); + ctx = nullptr; + } + } + + HResult PostSend(ShmChannel *ch, const UBSHcomNetTransRequest &req, uint64_t offset, uint32_t immData, + int32_t defaultTimeout); + HResult PostSendRawSgl(ShmChannel *ch, const UBSHcomNetTransRequest &req, const UBSHcomNetTransSglRequest &sglReq, + uint64_t offset, uint32_t immData, int32_t defaultTimeout); + HResult PostRead(ShmChannel *ch, const UBSHcomNetTransRequest &req, ShmMRHandleMap &mrHandleMap); + HResult PostReadSgl(ShmChannel *ch, const UBSHcomNetTransSglRequest &req, ShmMRHandleMap &mrHandleMap); + HResult PostWrite(ShmChannel *ch, const UBSHcomNetTransRequest &req, ShmMRHandleMap &mrHandleMap); + HResult PostWriteSgl(ShmChannel *ch, const UBSHcomNetTransSglRequest &req, ShmMRHandleMap &mrHandleMap); + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + HResult Validate(); + HResult CreateEventQueue(); + void RunInThread(int16_t cpuId); + + void DoEventPolling(); + void DoBusyPolling(); + + HResult PostReadWrite(ShmChannel *ch, const UBSHcomNetTransRequest &req, ShmMRHandleMap &mrHandleMap, + ShmOpContextInfo::ShmOpType type); + + HResult PostReadWriteSgl(ShmChannel *ch, const UBSHcomNetTransSglRequest &req, ShmMRHandleMap &mrHandleMap, + ShmOpContextInfo::ShmOpType type); + + uint64_t inline GetFinishTime() + { + if (mDefaultTimeout > 0) { + return NetMonotonic::TimeNs() + static_cast(mDefaultTimeout) * 1000000000UL; + } else if (mDefaultTimeout < 0) { + return UINT64_MAX; + } + + return 0; + } + static bool inline NeedRetry(HResult &result, ShmChannel *ch) + { + if (NN_UNLIKELY(ch->State().Compare(CH_BROKEN))) { + result = SH_CH_BROKEN; + return false; + } + + if (result == ShmEventQueue::SHM_QUEUE_FULL) { + return true; + } + + return false; + } + + HResult FillSglCtx(ShmSglOpContextInfo *sglCtx, const UBSHcomNetTransSglRequest &sglReq); + HResult SendLocalEvent(uintptr_t ctx, ShmChannel *ch, ShmOpContextInfo::ShmOpType type); + +private: + static std::atomic GLOBAL_WORKER_INDEX; + +private: + std::string mName; + std::mutex mMutex; + UBSHcomNetWorkerIndex mIndex {}; + bool mInited = false; + int32_t mDefaultTimeout = -1; + + ShmWorkerOptions mOptions {}; + + /* variable for thread */ + std::thread mProgressThr; /* thread object of progress */ + bool mStarted = false; /* thread already started or not */ + std::atomic_bool mProgressThrStarted { false }; /* started flag */ + volatile bool mNeedToStop = false; /* flag to be stopped */ + + ShmOpCompInfoPool mOpCompInfoPool; /* op completion pool */ + ShmOpContextInfoPool mOpCtxInfoPool; /* op completion pool */ + ShmSglContextInfoPool mSglCtxInfoPool; /* sgl op context pool */ + + ShmEventQueue *mEventQueue = nullptr; /* event queue for polling with both event and busy mode */ + ShmNewReqHandler mNewRequestHandler = nullptr; /* request process related */ + ShmPostedHandler mSendPostedHandler = nullptr; /* send request posted process related */ + ShmOneSideHandler mOneSideDoneHandler = nullptr; /* one side done will call this */ + ShmIdleHandler mIdleHandler = nullptr; /* no request will call this */ + + ShmHandlePtr mHandleEventQueue; /* handle of event queue */ + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; + +inline HResult ShmWorker::PostSend(ShmChannel *ch, const UBSHcomNetTransRequest &req, uint64_t offset, uint32_t immData, + int32_t defaultTimeout = -1) +{ + /* upper caller need to make sure ch is not null */ + if (NN_UNLIKELY(req.upCtxSize > sizeof(ShmOpContextInfo::upCtx))) { + NN_LOG_ERROR("Failed to PostSend with ShmWorker " << mName << " as upCtxSize > " << + sizeof(ShmOpContextInfo::upCtx)); + return SH_PARAM_INVALID; + } + + if (NN_UNLIKELY(ch->State().Compare(CH_BROKEN))) { + NN_LOG_ERROR("Failed to PostSend with ShmWorker " << mName << " as ch status is broken"); + return SH_CH_BROKEN; + } + + mDefaultTimeout = defaultTimeout; + + ShmEvent event(immData, req.size, offset, ch->Id(), ch->PeerChannelId(), ch->PeerChannelAddress(), + ShmOpContextInfo::ShmOpType::SH_RECEIVE); + auto result = ch->EQEventEnqueue(event); + if (NN_UNLIKELY(result != SH_OK)) { + if (result == ShmEventQueue::SHM_QUEUE_FULL) { + NN_LOG_ERROR("Failed to PostSend with ShmWorker " << mName << " as event queue is full"); + result = SH_RETRY_FULL; + } + return result; + } + + /* get op completion ctx */ + auto ctx = mOpCompInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to PostSend with ShmWorker " << mName << " as no opCtx left"); + return SH_OP_CTX_FULL; + } + + bzero(ctx, sizeof(ShmOpCompInfo)); + if (immData == 0) { + ctx->header = *(reinterpret_cast(req.lAddress)); + } + ctx->header.immData = immData; + ctx->channel = ch; + ctx->request = req; + ctx->opType = immData == 0 ? ShmOpContextInfo::ShmOpType::SH_SEND : ShmOpContextInfo::ShmOpType::SH_SEND_RAW; + ch->IncreaseRef(); + ch->AddOpCompInfo(ctx); + + /* send local event for send completion callback */ + result = SendLocalEvent(reinterpret_cast(ctx), ch, ShmOpContextInfo::ShmOpType::SH_SEND); + if (NN_UNLIKELY(result != SH_OK && result != SH_CH_BROKEN)) { + /* if state is broken ch of ctx has already decreased\remove\return in keeper thread, ensure not deal twice */ + /* if state is ok ch of ctx has already decreased in worker thread */ + if (NN_UNLIKELY(ch->RemoveOpCompInfo(ctx) != SH_OK)) { + return result; + } + ch->DecreaseRef(); + mOpCompInfoPool.Return(ctx); + } + + return result; +} + +inline HResult ShmWorker::PostRead(ShmChannel *ch, const UBSHcomNetTransRequest &req, ShmMRHandleMap &mrHandleMap) +{ + return PostReadWrite(ch, req, mrHandleMap, ShmOpContextInfo::ShmOpType::SH_READ); +} + +inline HResult ShmWorker::PostReadSgl(ShmChannel *ch, const UBSHcomNetTransSglRequest &req, + ShmMRHandleMap &mrHandleMap) +{ + return PostReadWriteSgl(ch, req, mrHandleMap, ShmOpContextInfo::ShmOpType::SH_SGL_READ); +} + +inline HResult ShmWorker::PostWrite(ShmChannel *ch, const UBSHcomNetTransRequest &req, ShmMRHandleMap &mrHandleMap) +{ + return PostReadWrite(ch, req, mrHandleMap, ShmOpContextInfo::ShmOpType::SH_WRITE); +} + +inline HResult ShmWorker::PostWriteSgl(ShmChannel *ch, const UBSHcomNetTransSglRequest &req, + ShmMRHandleMap &mrHandleMap) +{ + return PostReadWriteSgl(ch, req, mrHandleMap, ShmOpContextInfo::ShmOpType::SH_SGL_WRITE); +} +} +} + +#endif // OCK_HCOM_SHM_WORKER_H diff --git a/src/transport/shm/shm_worker_io.cpp b/src/transport/shm/shm_worker_io.cpp new file mode 100644 index 0000000000000000000000000000000000000000..66d1d6052262cfe4669681221f2897dfa6390b66 --- /dev/null +++ b/src/transport/shm/shm_worker_io.cpp @@ -0,0 +1,344 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "shm_worker.h" +#include "shm_handle.h" +#include "shm_queue.h" + +namespace ock { +namespace hcom { + + +HResult ShmWorker::FillSglCtx(ShmSglOpContextInfo *sglCtx, const UBSHcomNetTransSglRequest &sglReq) +{ + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to PostSendRawSgl with ShmWorker as no ctx left"); + return SH_PARAM_INVALID; + } + + sglCtx->result = SH_OK; + if (NN_UNLIKELY(memcpy_s(sglCtx->iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, + sglReq.iov, sizeof(UBSHcomNetTransSgeIov) * sglReq.iovCount) != SH_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return SH_PARAM_INVALID; + } + sglCtx->iovCount = sglReq.iovCount; + sglCtx->upCtxSize = sglReq.upCtxSize; + if (sglReq.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(sglCtx->upCtx, NN_NO16, sglReq.upCtxData, sglReq.upCtxSize) != SH_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return SH_PARAM_INVALID; + } + } + + return SH_OK; +} + +HResult ShmWorker::SendLocalEvent(uintptr_t ctx, ShmChannel *ch, ShmOpContextInfo::ShmOpType type) +{ + ShmEvent eventSent(ctx, type); + eventSent.SetChannel(ch); + /* if failed decrease in this thread, if success or broken decrease worker thread */ + eventSent.shmChannel->IncreaseRef(); + + uint64_t finishTime = GetFinishTime(); + bool flag = true; + HResult result = SH_OK; + do { + if (mOptions.mode == SHM_EVENT_POLLING) { + result = mEventQueue->EnqueueAndNotify(eventSent); + } else if (mOptions.mode == SHM_BUSY_POLLING) { + result = mEventQueue->Enqueue(eventSent); + } + if (result == SH_OK) { + return SH_OK; + } else if (NeedRetry(result, ch) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + if (result == ShmEventQueue::SHM_QUEUE_FULL) { + eventSent.shmChannel->DecreaseRef(); + return SH_SEND_COMPLETION_CALLBACK_FAILURE; + } + + flag = false; + } while (flag); + + eventSent.shmChannel->DecreaseRef(); + return result; +} + +HResult ShmWorker::PostSendRawSgl(ShmChannel *ch, const UBSHcomNetTransRequest &req, + const UBSHcomNetTransSglRequest &sglReq, uint64_t offset, uint32_t immData, int32_t defaultTimeout = -1) +{ + /* upper caller need to make sure ch is not null */ + if (NN_UNLIKELY(sglReq.upCtxSize > sizeof(ShmOpContextInfo::upCtx))) { + NN_LOG_ERROR("Shm Failed to PostSend with ShmWorker " << mName << " as upCtxSize > " << + sizeof(ShmOpContextInfo::upCtx)); + return SH_PARAM_INVALID; + } + + if (NN_UNLIKELY(ch->State().Compare(CH_BROKEN))) { + NN_LOG_ERROR("Shm Failed to PostSend with ShmWorker " << mName << " as ch status is broken"); + return SH_CH_BROKEN; + } + + mDefaultTimeout = defaultTimeout; + + ShmEvent event(immData, req.size, offset, ch->Id(), ch->PeerChannelId(), ch->PeerChannelAddress(), + ShmOpContextInfo::ShmOpType::SH_RECEIVE); + auto result = ch->EQEventEnqueue(event); + if (NN_UNLIKELY(result != SH_OK)) { + if (result == ShmEventQueue::SHM_QUEUE_FULL) { + result = SH_RETRY_FULL; + } + return result; + } + + /* get op sgl ctx */ + auto sglCtx = mSglCtxInfoPool.Get(); + result = FillSglCtx(sglCtx, sglReq); + if (NN_UNLIKELY(result != SH_OK)) { + return result; + } + + /* get op completion ctx */ + auto ctx = mOpCompInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Shm Failed to PostSend with ShmWorker " << mName << " as no opCtx left"); + return SH_OP_CTX_FULL; + } + + bzero(ctx, sizeof(ShmOpCompInfo)); + if (immData == 0) { + ctx->header = *(reinterpret_cast(req.lAddress)); + } + ctx->header.immData = immData; + ctx->channel = ch; + ctx->request = req; + ctx->opType = ShmOpContextInfo::ShmOpType::SH_SEND_RAW_SGL; + ctx->upCtxSize = sizeof(ShmSglOpCompInfo); + auto upCtx = static_cast((void *)&(ctx->upCtx)); + upCtx->ctx = sglCtx; + ch->IncreaseRef(); + ch->AddOpCompInfo(ctx); + + /* send local event for send completion callback */ + result = SendLocalEvent(reinterpret_cast(ctx), ch, ShmOpContextInfo::ShmOpType::SH_SEND_RAW_SGL); + if (NN_UNLIKELY(result != SH_OK && result != SH_CH_BROKEN)) { + /* if state is broken ch of ctx has already decreased\remove\return in keeper thread, ensure not deal twice */ + /* if state is ok ch of ctx has already decreased in worker thread */ + if (NN_UNLIKELY(ch->RemoveOpCompInfo(ctx) != SH_OK)) { + return result; + } + ch->DecreaseRef(); + mOpCompInfoPool.Return(ctx); + mSglCtxInfoPool.Return(sglCtx); + } + + return result; +} + +static inline HResult ReadWriteProcess(UBSHcomNetTransSgeIov iov, ShmMRHandleMap &mrHandleMap, ShmChannel *ch, + ShmOpContextInfo::ShmOpType type) +{ + auto localMemHandle = mrHandleMap.GetFromLocalMap(static_cast(iov.lKey)); + if (NN_UNLIKELY(localMemHandle == nullptr)) { + return SH_ERROR; + } + + auto remoteMemHandle = mrHandleMap.GetFromRemoteMap(static_cast(iov.rKey)); + if (remoteMemHandle == nullptr) { + /* remote address not exist in local map, exchange mr fd and mmap before copy */ + auto result = ch->GetRemoteMrHandle(static_cast(iov.rKey), iov.size, mrHandleMap); + if (NN_UNLIKELY(result != NN_OK)) { + return result; + } + remoteMemHandle = mrHandleMap.GetFromRemoteMap(static_cast(iov.rKey)); + } + + /* address has mmap already, copy directly */ + if (type == ShmOpContextInfo::ShmOpType::SH_READ || type == ShmOpContextInfo::ShmOpType::SH_SGL_READ) { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(localMemHandle->ShmAddress()), localMemHandle->DataSize(), + reinterpret_cast(remoteMemHandle->ShmAddress()), iov.size) != SH_OK)) { + NN_LOG_ERROR("Failed to copy remoteMemHandle to localMemHandle"); + return SH_PARAM_INVALID; + } + } else if (type == ShmOpContextInfo::ShmOpType::SH_WRITE || type == ShmOpContextInfo::ShmOpType::SH_SGL_WRITE) { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(remoteMemHandle->ShmAddress()), remoteMemHandle->DataSize(), + reinterpret_cast(localMemHandle->ShmAddress()), iov.size) != SH_OK)) { + NN_LOG_ERROR("Failed to copy localMemHandle to remoteMemHandle"); + return SH_PARAM_INVALID; + } + } else { + NN_LOG_INFO("Failed to PostReadWrite unreachable path"); + return SH_ERROR; + } + + return SH_OK; +} + +HResult ShmWorker::PostReadWrite(ShmChannel *ch, const UBSHcomNetTransRequest &req, ShmMRHandleMap &mrHandleMap, + ShmOpContextInfo::ShmOpType type) +{ + /* upper caller need to make sure ch is not null */ + if (NN_UNLIKELY(req.upCtxSize > sizeof(ShmOpContextInfo::upCtx))) { + NN_LOG_ERROR("Failed to PostReadWrite type:" << type << " with ShmWorker " << mName << " as upCtxSize > " << + sizeof(ShmOpContextInfo::upCtx)); + return SH_PARAM_INVALID; + } + + if (NN_UNLIKELY(ch->State().Compare(CH_BROKEN))) { + NN_LOG_ERROR("Failed to PostSend with ShmWorker " << mName << " as ch status is broken"); + return SH_CH_BROKEN; + } + + UBSHcomNetTransSgeIov iov {}; + iov.lKey = req.lKey; + iov.rKey = req.rKey; + iov.size = req.size; + + HResult result = SH_OK; + if (NN_UNLIKELY((result = ReadWriteProcess(iov, mrHandleMap, ch, type)) != SH_OK)) { + return result; + } + + /* get op ctx */ + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to PostReadWrite type:" << type << " with ShmWorker " << mName << " as no opCtx left"); + return SH_OP_CTX_FULL; + } + + bzero(ctx, sizeof(ShmOpContextInfo)); + ctx->channel = ch; + ctx->mrMemAddr = req.lAddress; + ctx->lKey = static_cast(req.lKey); + ctx->dataSize = req.size; + ctx->opType = type; + ctx->upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(ctx->upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + } + ch->IncreaseRef(); + ch->AddOpCtxInfo(ctx); + + /* send local event for one side done callback */ + result = SendLocalEvent(reinterpret_cast(ctx), ch, type); + if (NN_UNLIKELY(result != SH_OK && result != SH_CH_BROKEN)) { + /* if state is broken ch of ctx has already decreased\remove\return in keeper thread, ensure not deal twice */ + /* if state is ok ch of ctx has already decreased in worker thread */ + if (NN_UNLIKELY(ch->RemoveOpCtxInfo(ctx) != SH_OK)) { + return result; + } + ch->DecreaseRef(); + mOpCtxInfoPool.Return(ctx); + } + + return result; +} + + +static inline void FillReadWriteSglCtx(ShmChannel *ch, const UBSHcomNetTransSglRequest &req, + ShmOpContextInfo::ShmOpType type, ShmOpContextInfo *ctx, ShmSglOpContextInfo *sglCtx) +{ + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(sglCtx->iov), sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, + reinterpret_cast(req.iov), sizeof(UBSHcomNetTransSgeIov) * req.iovCount) != SH_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return; + } + sglCtx->iovCount = req.iovCount; + sglCtx->upCtxSize = req.upCtxSize; + if (sglCtx->upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(sglCtx->upCtx, NN_NO16, req.upCtxData, sglCtx->upCtxSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return; + } + } + bzero(ctx, sizeof(ShmOpContextInfo)); + ctx->channel = ch; + ctx->mrMemAddr = 0; + ctx->lKey = 0; + ctx->dataSize = sizeof(UBSHcomNetTransSgeIov) * req.iovCount; + ctx->opType = type; + ctx->upCtxSize = sizeof(ShmSglOpCompInfo); + + auto upCtx = reinterpret_cast(&ctx->upCtx); + upCtx->ctx = sglCtx; +} + +HResult ShmWorker::PostReadWriteSgl(ShmChannel *ch, const UBSHcomNetTransSglRequest &req, ShmMRHandleMap &mrHandleMap, + ShmOpContextInfo::ShmOpType type) +{ + /* upper caller need to make sure ch is not null */ + if (NN_UNLIKELY(req.upCtxSize > sizeof(ShmOpContextInfo::upCtx))) { + NN_LOG_ERROR("Failed to PostReadWriteSgl type:" << type << " with ShmWorker " << mName << " as upCtxSize > " << + sizeof(ShmOpContextInfo::upCtx)); + return SH_PARAM_INVALID; + } + + if (NN_UNLIKELY(ch->State().Compare(CH_BROKEN))) { + NN_LOG_ERROR("Failed to PostSend with ShmWorker " << mName << " as ch status is broken"); + return SH_CH_BROKEN; + } + + /* get op ctx */ + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to PostReadWriteSgl type:" << type << " with ShmWorker " << mName << " as no ctx left"); + return SH_OP_CTX_FULL; + } + + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to PostReadWriteSgl with ShmWorker " << mName << " as no sglCtx left"); + mOpCtxInfoPool.Return(ctx); + return SH_PARAM_INVALID; + } + + FillReadWriteSglCtx(ch, req, type, ctx, sglCtx); + + ch->IncreaseRef(); + ch->AddOpCtxInfo(ctx); + + HResult result = SH_OK; + for (auto i = 0; i < req.iovCount; i++) { + if (NN_UNLIKELY((result = ReadWriteProcess(req.iov[i], mrHandleMap, ch, type)) != SH_OK)) { + if (NN_UNLIKELY(ch->RemoveOpCtxInfo(ctx) != SH_OK)) { + return result; + } + ch->DecreaseRef(); + mOpCtxInfoPool.Return(ctx); + mSglCtxInfoPool.Return(sglCtx); + return result; + } + } + + /* send local event for one side done callback */ + result = SendLocalEvent(reinterpret_cast(ctx), ch, type); + if (NN_UNLIKELY(result != SH_OK && result != SH_CH_BROKEN)) { + /* if state is broken ch of ctx has already decreased\remove\return in keeper thread, ensure not deal twice */ + /* if state is ok ch of ctx has already decreased in worker thread */ + if (NN_UNLIKELY(ch->RemoveOpCtxInfo(ctx) != SH_OK)) { + return result; + } + ch->DecreaseRef(); + mOpCtxInfoPool.Return(ctx); + mSglCtxInfoPool.Return(sglCtx); + } + + return result; +} +} +} \ No newline at end of file diff --git a/src/transport/sock/net_sock_async_endpoint.cpp b/src/transport/sock/net_sock_async_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7a7d2e0d558920cac48b778112e5f7c7d2fd9e14 --- /dev/null +++ b/src/transport/sock/net_sock_async_endpoint.cpp @@ -0,0 +1,555 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "net_sock_driver_oob.h" +#include "sock_validation.h" +#include "net_sock_async_endpoint.h" + +namespace ock { +namespace hcom { +NetAsyncEndpointSock::NetAsyncEndpointSock(uint64_t id, Sock *sock, NetDriverSockWithOOB *driver, + const UBSHcomNetWorkerIndex &workerIndex) + : NetEndpointImpl(id, workerIndex), mSock(sock), mDriver(driver) +{ + if (mSock != nullptr) { + mSock->IncreaseRef(); + mWorker = reinterpret_cast(mSock->UpContext1()); + } + + if (mWorker != nullptr) { + mWorker->IncreaseRef(); + } + + if (mDriver != nullptr) { + mSegSize = mDriver->mOptions.mrSendReceiveSegSize; + mAllowedSize = mSegSize - sizeof(SockTransHeader); + mDriver->IncreaseRef(); + } + + OBJ_GC_INCREASE(NetAsyncEndpointSock); +} + +NetAsyncEndpointSock::~NetAsyncEndpointSock() +{ + if (mWorker != nullptr && mSock != nullptr) { + mWorker->RemoveFromEpoll(mSock); + } + + if (mSock != nullptr) { + mSock->Close(); + mSock->DecreaseRef(); + } + + if (mWorker != nullptr) { + mWorker->DecreaseRef(); + } + + if (mDriver != nullptr) { + mDriver->DecreaseRef(); + } + + OBJ_GC_DECREASE(NetAsyncEndpointSock); + // do later +} + +NResult NetAsyncEndpointSock::SetEpOption(UBSHcomEpOptions &epOptions) +{ + if (!epOptions.tcpBlockingIo) { + NN_LOG_WARN("Tcp is nonblocking in default, there is no need to set it again"); + return NN_OK; + } + + if (mDefaultTimeout > 0 && epOptions.sendTimeout > mDefaultTimeout) { + NN_LOG_WARN("send timeout should not longer than mDefaultTimeout " << mDefaultTimeout); + return NN_ERROR; + } + + if (NN_UNLIKELY(mSock->SetBlockingIo(epOptions) != SS_OK)) { + NN_LOG_WARN("Unable to set sock " << mSock->Name() << " blocking io mode."); + return NN_ERROR; + } + + return NN_OK; +} + +uint32_t NetAsyncEndpointSock::GetSendQueueCount() +{ + return mSock->GetSendQueueCount(); +} + +NResult NetAsyncEndpointSock::PostSendZCopy(int16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) +{ + REQ_SIZE_VALIDATION_ZERO_COPY(); + + UBSHcomNetTransHeader header{}; + if (opCode == -1) { + header.immData = 1; + } else { + header.opCode = opCode; + } + header.seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + header.flags = NTH_TWO_SIDE; + header.timeout = opInfo.timeout; + header.errorCode = opInfo.errorCode; + header.dataLength = request.size; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + auto worker = reinterpret_cast(mSock->UpContext1()); + + NResult result = NN_OK; + uint64_t finishTimeSend = GetFinishTime(); + TRACE_DELAY_BEGIN(SOCK_EP_ASYNC_POST_SEND); + do { + result = worker->PostSend(mSock, header, request); + if (result == SS_OK) { + NN_LOG_TRACE_INFO("Sock Post send ep id " << mId << ", flag " << header.flags << ", seqNo " << + header.seqNo << ", size " << request.size); + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTimeSend) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + break; + } while (true); + + NN_LOG_ERROR("Failed to async post send request, result " << result); + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_SEND, result); + return result; +} + +NResult NetAsyncEndpointSock::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNo) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post send as state validation failed."); + return result; + } + + if (NN_UNLIKELY((result = BuffValidation(request)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post send as buff validation failed."); + return result; + } + OPCODE_VALIDATION(); + + if (mSendZCopy) { + UBSHcomNetTransOpInfo opInfo(seqNo, 0, 0, 0); + return PostSendZCopy(opCode, request, opInfo); + } + + REQ_SIZE_VALIDATION(); + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mSockDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to async post send message as failed to get mr buffer from pool"); + return NN_GET_BUFF_FAILED; + } + auto *header = reinterpret_cast(mrBufAddress); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->opCode = opCode; + header->seqNo = seqNo == 0 ? NextSeq() : seqNo; + header->flags = NTH_TWO_SIDE; + header->dataLength = request.size; + auto dataAddress = mrBufAddress + sizeof(SockTransHeader); // req data start address + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(dataAddress), mDriver->mSockDriverSendMR->GetSingleSegSize() - + sizeof(SockTransHeader), reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mSockDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to dataAddress"); + return NN_INVALID_PARAM; + } + + /* finally fill header crc */ + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + auto worker = reinterpret_cast(mSock->UpContext1()); + + uint64_t finishTimeSend = GetFinishTime(); + TRACE_DELAY_BEGIN(SOCK_EP_ASYNC_POST_SEND); + do { + result = worker->PostSend(mSock, *header, request); + if (result == SS_OK) { + NN_LOG_TRACE_INFO("Post send ep id " << mId << ", flag " << header->flags << ", seqNo " << header->seqNo << + ", size " << request.size); + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTimeSend) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + break; + } while (true); + + mDriver->mSockDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to async post send request, result " << result); + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_SEND, result); + return result; +} + +NResult NetAsyncEndpointSock::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post send as state validation failed"); + return result; + } + + if (NN_UNLIKELY((result = BuffValidation(request)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post send as buff validation failed"); + return result; + } + OPCODE_VALIDATION(); + + if (mSendZCopy) { + return PostSendZCopy(opCode, request, opInfo); + } + + REQ_SIZE_VALIDATION(); + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mSockDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to async post send message with opInfo as failed to get mr buffer from pool"); + return NN_GET_BUFF_FAILED; + } + auto *sockHeader = reinterpret_cast(mrBufAddress); + bzero(sockHeader, sizeof(UBSHcomNetTransHeader)); + sockHeader->opCode = opCode; + sockHeader->seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + sockHeader->flags = ((uint16_t)opInfo.flags << NN_NO8) | (uint16_t)NTH_TWO_SIDE; + sockHeader->timeout = opInfo.timeout; + sockHeader->errorCode = opInfo.errorCode; + sockHeader->dataLength = request.size; + auto dataAddress = mrBufAddress + sizeof(SockTransHeader); // req data start address + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(dataAddress), mDriver->mSockDriverSendMR->GetSingleSegSize() - + sizeof(SockTransHeader), reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mSockDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to dataAddress"); + return NN_INVALID_PARAM; + } + + /* finally fill sockHeader crc */ + sockHeader->headerCrc = NetFunc::CalcHeaderCrc32(sockHeader); + auto worker = reinterpret_cast(mSock->UpContext1()); + + uint64_t finishTimeOpSend = GetFinishTime(); + TRACE_DELAY_BEGIN(SOCK_EP_ASYNC_POST_SEND); + do { + result = worker->PostSend(mSock, *sockHeader, request); + if (result == SS_OK) { + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTimeOpSend) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + break; + } while (true); + + mDriver->mSockDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to async post send request with opInfo, result " << result); + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_SEND, result); + return result; +} + +NResult NetAsyncEndpointSock::PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNo) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post send raw as state validation failed"); + return result; + } + + if (NN_UNLIKELY((result = BuffValidation(request)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post send raw as buff validation failed"); + return result; + } + + if (mSendZCopy) { + UBSHcomNetTransOpInfo opInfo(seqNo, 0, 0, 0); + return PostSendZCopy(-1, request, opInfo); + } + + REQ_SIZE_VALIDATION(); + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mSockDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to post async message as failed to get mr buffer from pool"); + return NN_GET_BUFF_FAILED; + } + auto *header = reinterpret_cast(mrBufAddress); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->immData = 1; + header->seqNo = seqNo == 0 ? NextSeq() : seqNo; + header->flags = NTH_TWO_SIDE; + header->dataLength = request.size; + auto dataAddress = mrBufAddress + sizeof(SockTransHeader); // req data start address + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(dataAddress), mDriver->mSockDriverSendMR->GetSingleSegSize(), + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mSockDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to dataAddress"); + return NN_INVALID_PARAM; + } + + /* finally fill header crc */ + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + auto worker = reinterpret_cast(mSock->UpContext1()); + + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(SOCK_EP_ASYNC_POST_SEND_RAW); + do { + result = worker->PostSend(mSock, *header, request); + if (result == SS_OK) { + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_SEND_RAW, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + break; + } while (true); + + mDriver->mSockDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to async post send raw request, result " << result); + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_SEND_RAW, result); + return result; +} + +NResult NetAsyncEndpointSock::PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) +{ + size_t totalSize = 0; + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post send raw sgl as state validation failed"); + return result; + } + + if (NN_UNLIKELY((result = TwoSideSglValidation(request, mDriver, mSegSize, totalSize)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post send raw sgl as sgl validation failed"); + return result; + } + + UBSHcomNetTransHeader header {}; + header.seqNo = seqNo == 0 ? NextSeq() : seqNo; + header.immData = 1; + header.flags = NTH_TWO_SIDE_SGL; + header.dataLength = totalSize; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + auto worker = reinterpret_cast(mSock->UpContext1()); + + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(SOCK_EP_ASYNC_POST_SEND_RAW_SGL); + do { + result = worker->PostSendRawSgl(mSock, header, request); + if (result == SS_OK) { + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_SEND_RAW_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + break; + } while (true); + + NN_LOG_ERROR("Failed to post send request, result " << result); + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_SEND_RAW_SGL, result); + return result; +} + +NResult NetAsyncEndpointSock::PostRead(const UBSHcomNetTransRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post read as state validation failed"); + return result; + } + + if (NN_UNLIKELY((result = BuffValidation(request)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post read as buff validation failed"); + return result; + } + + if (NN_UNLIKELY((result = OneSideValidation(request, mDriver)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post read as one side validation failed"); + return result; + } + + UBSHcomNetTransHeader header {}; + header.seqNo = mSock->OneSideNextSeq(); // do later change to NextReq() + header.flags = NTH_READ; + header.dataLength = sizeof(UBSHcomNetTransSgeIov); + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + auto worker = reinterpret_cast(mSock->UpContext1()); + uint64_t finishTime = GetFinishTime(); + bool flag = true; + TRACE_DELAY_BEGIN(SOCK_EP_ASYNC_POST_READ); + do { + result = worker->PostRead(mSock, header, request); + if (result == SS_OK) { + NN_LOG_TRACE_INFO("Post read ep id " << mId << ", flag " << header.flags << ", seqNo " << header.seqNo << + ", size " << request.size); + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_READ, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post read request, result " << result); + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_READ, result); + return result; +} + +NResult NetAsyncEndpointSock::PostRead(const UBSHcomNetTransSglRequest &request) +{ + size_t totalSize = 0; + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post read sgl as state validation failed"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY((result = OneSideSglValidation(request, mDriver, totalSize)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post read sgl as sgl validation failed"); + return NN_INVALID_PARAM; + } + + UBSHcomNetTransHeader header {}; + header.seqNo = mSock->OneSideNextSeq(); + header.flags = NTH_READ_SGL; + header.dataLength = sizeof(request.iovCount) + sizeof(UBSHcomNetTransSgeIov) * request.iovCount; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + auto worker = reinterpret_cast(mSock->UpContext1()); + uint64_t finishTime = GetFinishTime(); + bool flag = true; + TRACE_DELAY_BEGIN(SOCK_EP_ASYNC_POST_READ_SGL); + do { + result = worker->PostRead(mSock, header, request); + if (result == SS_OK) { + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_READ_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post read sgl request, result " << result); + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_READ_SGL, result); + return result; +} + +NResult NetAsyncEndpointSock::PostWrite(const UBSHcomNetTransRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post write as state validation failed"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY((result = BuffValidation(request)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post write as buff validation failed"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY((result = OneSideValidation(request, mDriver)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post write as one side validation failed"); + return NN_INVALID_PARAM; + } + + UBSHcomNetTransHeader header {}; + header.seqNo = mSock->OneSideNextSeq(); + header.flags = NTH_WRITE; + header.dataLength = sizeof(UBSHcomNetTransSgeIov) + request.size; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + auto worker = reinterpret_cast(mSock->UpContext1()); + uint64_t finishTime = GetFinishTime(); + bool flag = true; + TRACE_DELAY_BEGIN(SOCK_EP_ASYNC_POST_WRITE); + do { + result = worker->PostWrite(mSock, header, request); + if (result == SS_OK) { + NN_LOG_TRACE_INFO("Post write ep id " << mId << ", flag " << header.flags << ", seqNo " << header.seqNo << + ", size " << request.size); + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_WRITE, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post write request, result " << result); + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_WRITE, result); + return result; +} + +NResult NetAsyncEndpointSock::PostWrite(const UBSHcomNetTransSglRequest &request) +{ + size_t totalSize = 0; + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post write sgl as state validation failed"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY((result = OneSideSglValidation(request, mDriver, totalSize)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to async post write sgl as sgl validation failed"); + return NN_INVALID_PARAM; + } + + UBSHcomNetTransHeader header {}; + header.seqNo = mSock->OneSideNextSeq(); + header.flags = NTH_WRITE_SGL; + header.dataLength = sizeof(request.iovCount) + sizeof(UBSHcomNetTransSgeIov) * request.iovCount + totalSize; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + auto worker = reinterpret_cast(mSock->UpContext1()); + uint64_t finishTime = GetFinishTime(); + bool flag = true; + TRACE_DELAY_BEGIN(SOCK_EP_ASYNC_POST_WRITE_SGL); + do { + result = worker->PostWrite(mSock, header, request); + if (result == SS_OK) { + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_WRITE_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post write sgl request, result " << result); + TRACE_DELAY_END(SOCK_EP_ASYNC_POST_WRITE_SGL, result); + return result; +} +} +} \ No newline at end of file diff --git a/src/transport/sock/net_sock_async_endpoint.h b/src/transport/sock/net_sock_async_endpoint.h new file mode 100644 index 0000000000000000000000000000000000000000..836cefccdef7b916b310a235eb9e5b19714c92de --- /dev/null +++ b/src/transport/sock/net_sock_async_endpoint.h @@ -0,0 +1,192 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_SOCK_ASYNC_ENDPOINT_H +#define OCK_HCOM_NET_SOCK_ASYNC_ENDPOINT_H + +#include "transport/net_endpoint_impl.h" +#include "net_monotonic.h" +#include "net_security_alg.h" +#include "net_sock_common.h" +#include "sock_common.h" + +namespace ock { +namespace hcom { +class NetAsyncEndpointSock : public NetEndpointImpl { +public: + NetAsyncEndpointSock(uint64_t id, Sock *sock, NetDriverSockWithOOB *driver, + const UBSHcomNetWorkerIndex &workerIndex); + ~NetAsyncEndpointSock() override; + + NResult SetEpOption(UBSHcomEpOptions &epOptions) override; + + uint32_t GetSendQueueCount() override; + + NResult PostSendZCopy(int16_t opCode, const UBSHcomNetTransRequest &request, const UBSHcomNetTransOpInfo &opInfo); + + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNo) override; + + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) override; + + NResult PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNo) override; + + NResult PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) override; + + NResult PostRead(const UBSHcomNetTransRequest &request) override; + + NResult PostRead(const UBSHcomNetTransSglRequest &request) override; + + NResult PostWrite(const UBSHcomNetTransRequest &request) override; + + NResult PostWrite(const UBSHcomNetTransSglRequest &request) override; + + const std::string &PeerIpAndPort() override + { + if (NN_LIKELY(mSock != nullptr)) { + return mSock->PeerIpPort(); + } + + return CONST_EMPTY_STRING; + } + + const std::string &UdsName() override + { + NN_LOG_WARN("[Sock AsyncEp] Empty function for now"); + return CONST_EMPTY_STRING; + } + + inline NResult WaitCompletion(int32_t timeout) override + { + NN_LOG_WARN("Invalid operation, wait completion is not supported by NetAsyncEndpointSock"); + return NN_INVALID_OPERATION; + } + + inline NResult Receive(int32_t timeout, UBSHcomNetResponseContext &ctx) override + { + NN_LOG_WARN("Invalid operation, wait completion is not supported by NetAsyncEndpointSock"); + return NN_INVALID_OPERATION; + } + + inline NResult ReceiveRaw(int32_t timeout, UBSHcomNetResponseContext &ctx) override + { + NN_LOG_WARN("Invalid operation, wait completion is not supported by NetAsyncEndpointSock"); + return NN_INVALID_OPERATION; + } + + void Close() override + { + if (mState.Compare(NEP_ESTABLISHED)) { + mState.Set(NEP_BROKEN); + } else { + return; + } + NN_LOG_INFO("Close tcp ep id " << mId << " by user"); + mWorker->EpCloseByUser(mSock); + } + + bool GetPeerIpPort(std::string &ip, uint16_t &port) override + { + if (NN_UNLIKELY(mSock == nullptr)) { + return false; + } + + auto ipPort = mSock->PeerIpPort(); + if (NN_UNLIKELY(ipPort.empty())) { + NN_LOG_ERROR("[Sock AsyncEp] ip and port of peer is empty"); + return false; + } + + std::vector ipPortVec; + NetFunc::NN_SplitStr(ipPort, ":", ipPortVec); + if (NN_UNLIKELY(ipPortVec.size() != NN_NO2)) { + NN_LOG_ERROR("[Sock AsyncEp] ip and port of peer is invalid"); + return false; + } + + try { + port = std::stoi(ipPortVec[1]); + } catch (...) { + NN_LOG_ERROR("[Sock AsyncEp] port of peer is invalid"); + return false; + } + if (port == 0) { + NN_LOG_ERROR("[Sock AsyncEp] oob type is uds, does not have peer ip and port msg"); + return false; + } + ip = ipPortVec[0]; + + return true; + } + + NResult GetRemoteUdsIdInfo(UBSHcomNetUdsIdInfo &sockIdInfo) override + { + // 用户可能在建链回调中使用该函数,此时ep状态并未设置成NEP_ESTABLISHED + if (!mState.Compare(NEP_ESTABLISHED)) { + NN_LOG_WARN("[Sock AsyncEp] EP status is " << mState.Get() << + " now, use ep after the connection established."); + } + + if (!mDriver->mStartOobSvr) { + NN_LOG_ERROR("[Sock AsyncEp] oob server is not start"); + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + + if (mDriver->mOptions.oobType != NET_OOB_UDS) { + NN_LOG_ERROR("[Sock AsyncEp] oob type is not uds"); + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + // 通过mRemoteUdsIdInfo值判断是否可以返回给用户 + if (mRemoteUdsIdInfo.gid == 0 && mRemoteUdsIdInfo.pid == 0 && mRemoteUdsIdInfo.uid == 0) { + NN_LOG_ERROR("[Sock AsyncEp] RemoteUdsIdInfo has not been set."); + return NN_ERROR; + } + sockIdInfo = mRemoteUdsIdInfo; + return NN_OK; + } + + inline void EnableSendZCopy() + { + mSendZCopy = true; + } + +private: + static bool inline NeedRetry(NResult sockResult) + { + if (sockResult == SS_TCP_RETRY || sockResult == SS_SOCK_ADD_QUEUE_FAILED) { + return true; + } + + return false; + } + + uint64_t inline GetFinishTime() + { + if (mDefaultTimeout > 0) { + return NetMonotonic::TimeNs() + static_cast(mDefaultTimeout) * 1000000000UL; + } else if (mDefaultTimeout < 0) { + return UINT64_MAX; + } + + return 0; + } + Sock *mSock = nullptr; + SockWorker *mWorker = nullptr; + NetDriverSockWithOOB *mDriver = nullptr; + + bool mSendZCopy = false; + + friend class NetDriverSockWithOOB; +}; +} +} + +#endif // OCK_HCOM_NET_SOCK_ASYNC_ENDPOINT_H diff --git a/src/transport/sock/net_sock_common.h b/src/transport/sock/net_sock_common.h new file mode 100644 index 0000000000000000000000000000000000000000..dacdd9dc8bad4f77d877beaecec49dffc158fa40 --- /dev/null +++ b/src/transport/sock/net_sock_common.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_SOCK_COMMON_H_234234 +#define OCK_HCOM_NET_SOCK_COMMON_H_234234 + +#include +#include + +#include "hcom.h" +#include "net_common.h" +#include "net_memory_region.h" +#include "net_oob.h" +#include "sock_worker.h" + +namespace ock { +namespace hcom { +class NetAsyncEndpointSock; +class NetSyncEndpointSock; +class NetDriverSockWithOOB; + +enum SockExchangeOp : int16_t { + REAL_CONNECT = -1, +}; + +} +} + +#endif // OCK_HCOM_NET_SOCK_COMMON_H_234234 diff --git a/src/transport/sock/net_sock_driver_oob.cpp b/src/transport/sock/net_sock_driver_oob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..053be3db88807d8d317cd7db8e0bb5385b170afd --- /dev/null +++ b/src/transport/sock/net_sock_driver_oob.cpp @@ -0,0 +1,1608 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "net_sock_driver_oob.h" +#include "hcom_def.h" +#include "hcom_err.h" +#include "hcom_log.h" +#include "net_oob.h" +#include "net_oob_ssl.h" +#include "net_sock_sync_endpoint.h" +#include "net_sock_async_endpoint.h" +#include "net_sock_common.h" +#include "net_oob_secure.h" + +namespace ock { +namespace hcom { +NResult NetDriverSockWithOOB::Initialize(const UBSHcomNetDriverOptions &option) +{ + std::lock_guard guard(mInitMutex); + if (mInited) { + return NN_OK; + } + + mOptions = option; + + NResult sockRes = NN_OK; + if (NN_UNLIKELY((sockRes = mOptions.ValidateCommonOptions()) != NN_OK)) { + return sockRes; + } + + if (NN_UNLIKELY((sockRes = ValidateOptions()) != NN_OK)) { + return sockRes; + } + + if (NN_UNLIKELY(UBSHcomNetOutLogger::Instance() == nullptr)) { + return NN_NOT_INITIALIZED; + } + + if (option.enableTls) { + if (HcomSsl::Load() != 0) { + NN_LOG_ERROR("[Sock] Failed to load openssl API"); + return NN_NOT_INITIALIZED; + } + } + mEnableTls = option.enableTls; + NN_LOG_INFO("Try to initialize driver '" << mName << "' with " << mOptions.ToStringForSock()); + + if ((sockRes = CreateWorkerResource()) != NN_OK) { + NN_LOG_ERROR("[Sock] failed to create worker resource"); + UnInitializeInner(); + return sockRes; + } + + /* create workers */ + if ((sockRes = CreateWorkers()) != NN_OK) { + NN_LOG_ERROR("[Sock] failed to create workers"); + UnInitializeInner(); + return sockRes; + } + + /* create lb for client */ + if ((sockRes = CreateClientLB()) != NN_OK) { + NN_LOG_ERROR("[Sock] failed to create client lb"); + UnInitializeInner(); + return sockRes; + } + + /* create oob */ + if (mStartOobSvr) { + if ((sockRes = CreateListeners()) != NN_OK) { + NN_LOG_ERROR("[Sock] failed to create listeners"); + UnInitializeInner(); + return sockRes; + } + } + + mMrChecker.Reserve(NN_NO128); + mMrChecker.SetLockWhenOperates(false); + + mInited = true; + return NN_OK; +} + +void NetDriverSockWithOOB::UnInitialize() +{ + std::lock_guard guard(mInitMutex); + if (!mInited) { + return; + } + if (mStarted) { + NN_LOG_WARN("Unable to unInitialize sock driver" << " " << mName << " which is not stopped"); + return; + } + + UnInitializeInner(); + mInited = false; +} + +void NetDriverSockWithOOB::UnInitializeInner() +{ + if (mOpCtxMemPool != nullptr) { + mOpCtxMemPool.Set(nullptr); + } + + if (mSglCtxMemPool != nullptr) { + mSglCtxMemPool.Set(nullptr); + } + + if (mHeaderReqMemPool != nullptr) { + mHeaderReqMemPool.Set(nullptr); + } + + if (mSockDriverSendMR != nullptr) { + mSockDriverSendMR->DecreaseRef(); + mSockDriverSendMR = nullptr; + } + + for (auto oobServer : mOobServers) { + oobServer->DecreaseRef(); + } + mOobServers.clear(); + if (!mEndPoints.empty()) { + mEndPoints.clear(); + } + ClearWorkers(); + DestroyClientLB(); +} + +NResult NetDriverSockWithOOB::ValidateOptions() +{ + /* validate param related to device IpMask for RDMA and Sock */ + if (NN_UNLIKELY(!ValidateArrayOptions(mOptions.netDeviceIpMask, NN_NO256))) { + NN_LOG_ERROR("Option 'netDeviceIpMask' is invalid, " << mOptions.netDeviceIpMask << + " is set in driver,the Array max length is 256."); + return NN_INVALID_PARAM; + } + + /* validate params related to tcp connection send and receive buffer size in kernel for Sock */ + if (NN_UNLIKELY(mOptions.tcpSendBufSize > NN_NO4096)) { + NN_LOG_ERROR("Option 'tcpSendBufSize is invalid, " << mOptions.tcpSendBufSize << + " is set in driver, the valid value range is 0 ~ 4MB"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(mOptions.tcpReceiveBufSize > NN_NO4096)) { + NN_LOG_ERROR("Option 'tcpReceiveBufSize is invalid, " << mOptions.tcpReceiveBufSize << + " is set in driver, the valid value range is 0 ~ 4MB"); + return NN_INVALID_PARAM; + } + + if (mSockType == SOCK_TCP || mSockType == SOCK_UDS_TCP) { + std::vector filters; + NetFunc::NN_SplitStr(mOptions.NetDeviceIpMask(), ",", filters); + if (filters.empty()) { + NN_LOG_ERROR("Invalid ip mask '" << mOptions.netDeviceIpMask << "' is set, example '192.168.100.0/24'"); + return NN_INVALID_IP; + } + + std::vector matchIps; + for (auto &mask : filters) { + FilterIp(mask, matchIps); + } + + if (matchIps.empty()) { + NN_LOG_ERROR("No matched ip found with '" << mOptions.netDeviceIpMask << "', example '192.168.100.0/24'"); + return NN_INVALID_IP; + } + + mFilteredIps.swap(matchIps); + + if (mStartOobSvr && mOobListenOptions.empty()) { + NN_LOG_ERROR("No listening ip and port is set in driver " << mName); + return NN_INVALID_PARAM; + } + } + + /* validate options */ + if (mOptions.mode == NET_BUSY_POLLING) { + mOptions.mode = NET_EVENT_POLLING; + NN_LOG_WARN("Busy polling is not supported in TCP/UDS driver, changed to event mode in driver " << mName); + } + + if (NN_UNLIKELY(ValidateAndParseOobPortRange(mOptions.oobPortRange) != NN_OK)) { + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(ValidateOptionsOobType() != NN_OK)) { + return NN_INVALID_PARAM; + } + + return NN_OK; +} + +NResult NetDriverSockWithOOB::CreateWorkers() +{ + NResult result = NN_OK; + + std::vector workerGroups; + std::vector> workerGroupCpus; + std::vector flatWorkerCpus; + std::vector workerThreadPriority; + + /* parse */ + if (!(NetFunc::NN_ParseWorkersGroups(mOptions.WorkGroups(), workerGroups)) || + !(NetFunc::NN_ParseWorkerGroupsCpus(mOptions.WorkerGroupCpus(), workerGroupCpus)) || + !(NetFunc::NN_FinalizeWorkerGroupCpus(workerGroups, workerGroupCpus, true, flatWorkerCpus)) || + !(NetFunc::NN_ParseWorkersGroupsThreadPriority(mOptions.WorkerGroupThreadPriority(), + workerThreadPriority, workerGroups.size()))) { + NN_LOG_ERROR("[Sock] Failed to parse worker or cpu groups"); + return NN_INVALID_PARAM; + } + + SockWorkerOptions options; + options.SetValue(mOptions, mStartOobSvr); + if ((mOptions.workerThreadPriority != 0) && (!workerThreadPriority.empty())) { + NN_LOG_WARN("Driver options 'workerThreadPriority' and 'workerGroupsThreadPriority' set all, preferential use " + "'workerGroupsThreadPriority'"); + } + /* create workers */ + mWorkers.reserve(flatWorkerCpus.size()); + uint32_t groupIndex = 0; + uint16_t totalWorkerIndex = 0; + UBSHcomNetWorkerIndex workerIndex {}; + for (auto item : workerGroups) { + /* The left of mWorkerGroups is the index of each group's first worker in the mWorkers */ + mWorkerGroups.emplace_back(totalWorkerIndex, item); + for (uint32_t i = 0; i < item; ++i) { + options.cpuId = flatWorkerCpus.at(totalWorkerIndex++); + if (!workerThreadPriority.empty()) { + options.threadPriority = workerThreadPriority[groupIndex]; + } + auto *worker = new (std::nothrow) + SockWorker(mSockType, mName, workerIndex, mOpCtxMemPool, mSglCtxMemPool, mHeaderReqMemPool, options); + if (NN_UNLIKELY(worker == nullptr)) { + NN_LOG_ERROR("Failed to create sock worker in driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + workerIndex.Set(i, groupIndex, mIndex); + worker->SetIndex(workerIndex); + if (NN_UNLIKELY((result = worker->Initialize()) != NN_OK)) { + delete worker; + NN_LOG_ERROR("Failed to initialize sock worker in driver " << mName << ", result " << result); + return NN_NEW_OBJECT_FAILED; + } + + worker->IncreaseRef(); + mWorkers.push_back(worker); + } + ++groupIndex; + } + + return NN_OK; +} + +NResult NetDriverSockWithOOB::CreateWorkerResource() +{ + NResult result; + if (((result = CreateOpCtxMemPool()) != NN_OK)) { + NN_LOG_ERROR("Sock failed to create op ctx memory pool"); + return result; + } + + if (((result = CreateSglCtxMemPool()) != NN_OK)) { + NN_LOG_ERROR("Sock failed to create Sgl ctx memory pool"); + return result; + } + + if (mOptions.tcpSendZCopy) { + if (((result = CreateHeaderReqMemPool()) != NN_OK)) { + NN_LOG_ERROR("Sock failed to create header request memory pool"); + return result; + } + } else { + if (((result = CreateSendMr()) != NN_OK)) { + NN_LOG_ERROR("Sock falied to create send mr"); + return result; + } + } + return NN_OK; +} + +NResult NetDriverSockWithOOB::CreateOpCtxMemPool() +{ + NetMemPoolFixedOptions options = {}; + options.superBlkSizeMB = NN_NO1; + options.minBlkSize = sizeof(SockOpContextInfo); + options.tcExpandBlkCnt = NN_NO64; + + mOpCtxMemPool = new (std::nothrow) NetMemPoolFixed(mName, options); + if (mOpCtxMemPool.Get() == nullptr) { + NN_LOG_ERROR("Failed to create memory pool for sock op context pool " << mName << ", probably out of memory"); + return NN_INVALID_PARAM; + } + + auto result = mOpCtxMemPool->Initialize(); + if (result != NN_OK) { + mOpCtxMemPool.Set(nullptr); + NN_LOG_ERROR("Failed to initialize memory pool for sock op context pool " << mName << + ", probably out of memory"); + return result; + } + + return NN_OK; +} + +NResult NetDriverSockWithOOB::CreateSglCtxMemPool() +{ + NetMemPoolFixedOptions options = {}; + options.superBlkSizeMB = NN_NO1; + options.minBlkSize = NN_NO512; // the sgl context is 468, not power of 2, set to the closest num 512 + options.tcExpandBlkCnt = NN_NO64; + mSglCtxMemPool = new (std::nothrow) NetMemPoolFixed(mName, options); + if (mSglCtxMemPool.Get() == nullptr) { + NN_LOG_ERROR("Failed to create memory pool for sgl op context in driver " << mName << + ", probably out of memory"); + return NN_INVALID_PARAM; + } + + auto result = mSglCtxMemPool->Initialize(); + if (result != NN_OK) { + mSglCtxMemPool.Set(nullptr); + NN_LOG_ERROR("Failed to initialize memory pool for sgl op context in driver " << mName << + ", probably out of memory"); + return result; + } + + return NN_OK; +} + +NResult NetDriverSockWithOOB::CreateHeaderReqMemPool() +{ + NetMemPoolFixedOptions options = {}; + options.superBlkSizeMB = NN_NO1; + options.minBlkSize = NN_NextPower2(sizeof(SockHeaderReqInfo)); + options.tcExpandBlkCnt = NN_NO64; + mHeaderReqMemPool = new (std::nothrow) NetMemPoolFixed(mName, options); + if (mHeaderReqMemPool.Get() == nullptr) { + NN_LOG_ERROR("Failed to create memory pool for header request context in driver " << mName << + ", probably out of memory"); + return NN_INVALID_PARAM; + } + + auto result = mHeaderReqMemPool->Initialize(); + if (result != NN_OK) { + mHeaderReqMemPool.Set(nullptr); + NN_LOG_ERROR("Failed to initialize memory pool for header request context in driver " << mName << + ", probably out of memory"); + return result; + } + + return NN_OK; +} + +NResult NetDriverSockWithOOB::CreateSendMr() +{ + NResult result = NN_OK; + // create mr pool for send/receive and initialize + if (NN_UNLIKELY((result = NormalMemoryRegionFixedBuffer::Create(mName, mOptions.mrSendReceiveSegSize, + mOptions.mrSendReceiveSegCount, mSockDriverSendMR)) != NN_OK)) { + NN_LOG_ERROR("Failed to create mr for send/receive in NetDriverSock " << mName << ", result " << result); + return result; + } + mSockDriverSendMR->IncreaseRef(); + + if (NN_UNLIKELY((result = mSockDriverSendMR->Initialize()) != NN_OK)) { + NN_LOG_ERROR("Failed to initialize mr for send/receive in NetDriverSock " << mName << ", result " << result); + mSockDriverSendMR->DecreaseRef(); + return result; + } + + return NN_OK; +} + +void NetDriverSockWithOOB::ClearWorkers() +{ + mWorkerGroups.clear(); + for (auto worker : mWorkers) { + worker->Stop(); + worker->DecreaseRef(); + } + mWorkers.clear(); +} + +NResult NetDriverSockWithOOB::Start() +{ + std::lock_guard guard(mInitMutex); + if (!mInited) { + NN_LOG_ERROR("Failed to start driver " << mName << " as it is not initialized"); + return NN_ERROR; + } + + if (mOptions.dontStartWorkers) { + mStarted = true; + return NN_OK; + } + + if (mStarted) { + return NN_OK; + } + + NResult result = NN_OK; + if (NN_UNLIKELY(result = ValidateHandlesCheck()) != NN_OK) { + ClearWorkers(); + return result; + } + for (auto &item : mWorkers) { + if (NN_UNLIKELY(item == nullptr)) { + NN_LOG_ERROR("[Sock] Failed to start worker " << mName << " as it is null"); + ClearWorkers(); + return result; + } + + item->RegisterNewReqHandler(std::bind(&NetDriverSockWithOOB::HandleNewRequest, this, std::placeholders::_1)); + item->RegisterReqPostedHandler(std::bind(&NetDriverSockWithOOB::HandleReqPosted, this, std::placeholders::_1)); + item->RegisterOneSideHandler(std::bind(&NetDriverSockWithOOB::OneSideDone, this, std::placeholders::_1)); + item->RegisterEpCloseHandler(std::bind(&NetDriverSockWithOOB::HandleEpClose, this, std::placeholders::_1)); + if (mIdleHandler) { + item->RegisterIdleHandler(mIdleHandler); + } + + if ((result = item->Start()) != NN_OK) { + NN_LOG_ERROR("Failed to start worker " << mName << ", result " << result); + ClearWorkers(); + return result; + } + } + + if (mStartOobSvr) { + if (mNewEndPointHandler == nullptr) { + NN_LOG_ERROR("Sock failed to do start in Driver " << mName << ", as newEndPointerHandler is null"); + return NN_INVALID_PARAM; + } + for (auto &oobServer : mOobServers) { + oobServer->SetNewConnCB(std::bind(&NetDriverSockWithOOB::HandleNewOobConn, this, std::placeholders::_1)); + } + + /* start oob server */ + if ((result = StartListeners()) != NN_OK) { + ClearWorkers(); + return result; + } + } + + mStarted = true; + return NN_OK; +} + +void NetDriverSockWithOOB::Stop() +{ + std::lock_guard guard(mInitMutex); + if (!mStarted) { + return; + } + + for (auto worker : mWorkers) { + worker->Stop(); + } + + StopListeners(); + + mStarted = false; +} + +NResult NetDriverSockWithOOB::CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr) +{ + if (NN_UNLIKELY(size == 0 || size > NN_NO107374182400)) { + NN_LOG_ERROR("Sock Failed to create mem region as size is 0 or greater than 100 GB"); + return NN_INVALID_PARAM; + } + + if (!mInited) { + NN_LOG_ERROR("Sock Failed to create memory region in driver " << mName << ", as not initialized"); + return NN_NOT_INITIALIZED; + } + + NormalMemoryRegion *tmp = nullptr; + auto result = NormalMemoryRegion::Create(mName, size, tmp); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("Sock Failed to create memory region in driver " << mName << ", probably out of memory"); + return result; + } + + if ((result = tmp->Initialize()) != NN_OK) { + delete tmp; + return result; + } + + if ((result = mMrChecker.Register(tmp->GetLKey(), tmp->GetAddress(), size)) != NN_OK) { + NN_LOG_INFO("Sock Failed to add memory region to range checker in driver" << mName << " for duplicate keys"); + delete tmp; + return result; + } + + mr.Set(static_cast(tmp)); + + return NN_OK; +} +NResult NetDriverSockWithOOB::CreateMemoryRegion(uintptr_t address, uint64_t size, UBSHcomNetMemoryRegionPtr &mr) +{ + if (!mInited) { + NN_LOG_ERROR("Failed to create memory region in driver " << mName << ", as not initialized"); + return NN_NOT_INITIALIZED; + } + + if (address == 0) { + NN_LOG_ERROR("Failed to create memory region in driver " << mName << ", as address is 0"); + return NN_INVALID_PARAM; + } + + NormalMemoryRegion *tmp = nullptr; + auto res = NormalMemoryRegion::Create(mName, address, size, tmp); + if (NN_UNLIKELY(res != NN_OK)) { + NN_LOG_ERROR("Failed to create memory region in driver " << mName << ", probably out of memory"); + return res; + } + + if ((res = tmp->Initialize()) != NN_OK) { + delete tmp; + return res; + } + + if ((res = mMrChecker.Register(tmp->GetLKey(), tmp->GetAddress(), size)) != NN_OK) { + NN_LOG_ERROR("Failed to add memory region to range checker in driver" << mName << " for duplicate keys"); + delete tmp; + return res; + } + + mr.Set(static_cast(tmp)); + + return NN_OK; +} + +NResult NetDriverSockWithOOB::CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr, unsigned long memid) +{ + NN_LOG_ERROR("operation is not supported in tcp"); + return NN_ERROR; +} + +void NetDriverSockWithOOB::DestroyMemoryRegion(UBSHcomNetMemoryRegionPtr &mr) +{ + if (mr.Get() == nullptr) { + NN_LOG_WARN("Try to destroy null memory region in sock driver " << mName); + return; + } + + if (!mMrChecker.Contains(mr->GetLKey())) { + NN_LOG_WARN("Try to destroy unowned memory region in driver " << mName); + return; + } + mMrChecker.UnRegister(mr->GetLKey()); + + auto tmp = mr.ToChild(); + if (NN_UNLIKELY(tmp == nullptr)) { + NN_LOG_WARN("Invalid operation to dynamic cast"); + return; + } + tmp->UnInitialize(); +} + +NResult NetDriverSockWithOOB::Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, + uint8_t serverGrpNo, uint8_t clientGrpNo) +{ + if (mOptions.oobType == NET_OOB_TCP) { + return Connect(mOobIp, mOobPort, payload, ep, flags, serverGrpNo, clientGrpNo, 0); + } else if (mOptions.oobType == NET_OOB_UDS) { + return Connect(mUdsName, 0, payload, ep, flags, serverGrpNo, clientGrpNo, 0); + } + return NN_ERROR; +} + +NResult NetDriverSockWithOOB::Connect(const std::string &serverUrl, const std::string &payload, + UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) +{ + if (NN_UNLIKELY(!mInited.load())) { + NN_LOG_ERROR("[Sock] Driver " << mName << " is not initialized"); + return NN_NOT_INITIALIZED; + } + + if (NN_UNLIKELY(!mStarted)) { + NN_LOG_ERROR("[Sock] Failed to connect on driver " << mName << " as it is not started"); + return NN_ERROR; + } + + if (payload.size() > NN_NO1024) { + NN_LOG_ERROR("[Sock] Failed to connect server as payload size " << payload.size() << " over limit"); + return NN_INVALID_PARAM; + } + + NetDriverOobType type; + std::string ip; + uint16_t port = 0; + if (NN_UNLIKELY(NetFunc::NN_ValidateUrl(serverUrl) != NN_OK)) { + NN_LOG_ERROR("Invalid url"); + return NN_PARAM_INVALID; + } + if (NN_UNLIKELY(ParseUrl(serverUrl, type, ip, port) != NN_OK)) { + NN_LOG_ERROR("Invalid url, url:" << serverUrl); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(!mInited.load() || clientGrpNo >= mWorkerGroups.size())) { + NN_LOG_ERROR("Invalid clientGrpNo " << clientGrpNo << ", or driver " << mName << " is not initialized"); + return NN_ERROR; + } + + OOBTCPClientPtr clt; + if (mEnableTls) { + auto oobSSLClient = new OOBSSLClient(type, ip, port, + mTlsPrivateKeyCB, mTlsCertCB, mTlsCaCallback); + NN_ASSERT_LOG_RETURN(oobSSLClient != nullptr, NN_NEW_OBJECT_FAILED) + oobSSLClient->SetTlsOptions(mOptions); + oobSSLClient->SetPSKCallback(mPskFindSessionCb, mPskUseSessionCb); + clt = oobSSLClient; + } else { + clt = new OOBTCPClient(type, ip, port); + NN_ASSERT_LOG_RETURN(clt.Get() != nullptr, NN_NEW_OBJECT_FAILED) + } + + if (flags & NET_EP_SELF_POLLING) { + return ConnectSyncEp(clt, payload, ep, serverGrpNo, ctx); + } + return Connect(clt, payload, ep, serverGrpNo, clientGrpNo, ctx); +} + +NResult NetDriverSockWithOOB::Connect(const std::string &oobIp, uint16_t oobPort, const std::string &payload, + UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) +{ + if (NN_UNLIKELY(!mInited.load())) { + NN_LOG_ERROR("Sock Driver " << mName << " is not initialized"); + return NN_NOT_INITIALIZED; + } + + if (NN_UNLIKELY(!mStarted)) { + NN_LOG_ERROR("Sock Failed to connect on driver " << mName << " as it is not started"); + return NN_ERROR; + } + + if (payload.size() > NN_NO1024) { + NN_LOG_ERROR("Sock Failed to connect server via payload size " << payload.size() << " over limit"); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(!mInited || clientGrpNo >= mWorkerGroups.size())) { + NN_LOG_ERROR("Invalid clientGrpNo " << clientGrpNo << ", or driver " << mName << " is not initialized"); + return NN_ERROR; + } + + OOBTCPClientPtr clt; + if (mEnableTls) { + auto oobSSLClient = + new OOBSSLClient(mOptions.oobType, oobIp, oobPort, mTlsPrivateKeyCB, mTlsCertCB, mTlsCaCallback); + NN_ASSERT_LOG_RETURN(oobSSLClient != nullptr, NN_NEW_OBJECT_FAILED) + oobSSLClient->SetTlsOptions(mOptions); + oobSSLClient->SetPSKCallback(mPskFindSessionCb, mPskUseSessionCb); + clt = oobSSLClient; + } else { + clt = new OOBTCPClient(mOptions.oobType, oobIp, oobPort); + NN_ASSERT_LOG_RETURN(clt.Get() != nullptr, NN_NEW_OBJECT_FAILED) + } + + if (flags & NET_EP_SELF_POLLING) { + return ConnectSyncEp(clt, payload, ep, serverGrpNo, ctx); + } + return Connect(clt, payload, ep, serverGrpNo, clientGrpNo, ctx); +} + +NResult NetDriverSockWithOOB::Connect(const OOBTCPClientPtr &client, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) +{ + /* try to connect to oob server */ + OOBTCPConnection *conn = nullptr; + NResult result = NN_OK; + if ((result = client->Connect(conn)) != 0) { + NN_LOG_ERROR("Sock Failed to connect server via oob, result " << result); + return result; + } + + NetLocalAutoDecreasePtr autoDecPtr(conn); + if (client->GetOobType() == NET_OOB_TCP) { + conn->SetIpAndPort(client->GetServerIp(), client->GetServerPort()); + } else { + conn->SetIpAndPort(client->GetServerUdsName(), 0); + } + + if (NN_UNLIKELY(OOBSecureProcess::SecProcessInOOBClient(mSecInfoProvider, mSecInfoValidator, conn, mName, ctx, + mOptions.secType))) { + return NN_OOB_SEC_PROCESS_ERROR; + } + + /* send connection header */ + ConnectHeader header {}; + SetConnHeader(header, mOptions.magic, mOptions.version, serverGrpNo, Protocol(), mMajorVersion, + mMinorVersion, mOptions.tlsVersion); + if (NN_UNLIKELY((result = conn->Send(&header, sizeof(ConnectHeader))) != NN_OK)) { + NN_LOG_ERROR("Sock Failed to send conn header to oob server " << client->GetServerIp() << ":" <<\ + client->GetServerPort() << " in driver " << mName); + return NN_ERROR; + } + + /* receive connect response and peer sock id */ + ConnRespWithUId respWithUId {}; + void *tmpBuff = &respWithUId; + if (NN_UNLIKELY((result = conn->Receive(tmpBuff, sizeof(ConnRespWithUId))) != NN_OK)) { + return result; + } + + /* connect response */ + auto resp = respWithUId.connResp; + switch (resp) { + case MAGIC_MISMATCH: + NN_LOG_ERROR("Sock Failed to pass server magic validation " << mName << ", result " << NN_CONNECT_REFUSED); + return NN_CONNECT_REFUSED; + case PROTOCOL_MISMATCH: + NN_LOG_ERROR("Sock Failed to pass server protocol validation " << mName << ", result " << + NN_CONNECT_PROTOCOL_MISMATCH); + return NN_CONNECT_PROTOCOL_MISMATCH; + case SERVER_INTERNAL_ERROR: + NN_LOG_ERROR("Sock Server error happened, connection refused " << mName << ", result " << resp); + return NN_ERROR; + case VERSION_MISMATCH: + NN_LOG_ERROR("Sock Failed to pass server version validation " << mName << ", result " << + NN_CONNECT_REFUSED); + return NN_CONNECT_REFUSED; + case TLS_VERSION_MISMATCH: + NN_LOG_ERROR("Sock Failed to pass server tls version validation " << mName << ", result " << + NN_CONNECT_REFUSED); + return NN_CONNECT_REFUSED; + case OK: + case OK_PROTOCOL_TCP: + case OK_PROTOCOL_UDS: + break; + default: + NN_LOG_ERROR("Sock Server error happened, connection refused " << mName << ", result: " << resp); + return NN_ERROR; + } + + /* peer ep id */ + auto newSockId = respWithUId.epId; + NN_LOG_TRACE_INFO("Sock new ep id will be set as" << " " << newSockId << " in driver " << mName); + + /* choose worker */ + uint16_t workerIndex = 0; + if (NN_UNLIKELY(!mClientLb->ChooseWorker(clientGrpNo, std::to_string(newSockId), workerIndex)) || + workerIndex >= mWorkers.size()) { + NN_LOG_ERROR("Sock Failed to choose worker during connect in driver " << mName); + return NN_ERROR; + } + + NN_LOG_TRACE_INFO("Worker " << workerIndex << " is chosen in driver " << mName); + + SockWorker *worker = mWorkers[workerIndex]; + NN_ASSERT_LOG_RETURN(worker != nullptr, NN_ERROR); + + /* create sock and initialize */ + SockOptions options {}; + options.sendQueueSize = mOptions.qpSendQueueSize; + Sock *sock; + int fdConn = conn->TransferFd(); + if (mEnableTls) { + sock = new (std::nothrow) Sock(mSockType, mName, newSockId, fdConn, options, conn); + } else { + sock = new (std::nothrow) Sock(mSockType, mName, newSockId, fdConn, options); + } + if (NN_UNLIKELY(sock == nullptr)) { + NN_LOG_ERROR("Failed to new async sock in driver " << mName << ", probably out of memory"); + NetFunc::NN_SafeCloseFd(fdConn); + return NN_NEW_OBJECT_FAILED; + } + + sock->PeerIpPort(conn->GetIpAndPort()); + NetLocalAutoDecreasePtr autoDecSock(sock); + + if (NN_UNLIKELY((result = sock->Initialize(worker->Options())))) { + NN_LOG_ERROR("Failed to initialize sock " << sock->Id() << " in driver " << mName << ", result " << result); + return NN_NEW_OBJECT_FAILED; + } + + /* send real head and payload */ + UBSHcomNetTransHeader workerFirstReq {}; + workerFirstReq.flags = NTH_TWO_SIDE; + workerFirstReq.opCode = SockExchangeOp::REAL_CONNECT; + workerFirstReq.dataLength = payload.length(); + workerFirstReq.seqNo = header.wholeHeader[0]; /* use reqNo */ + /* finally fill header crc */ + workerFirstReq.headerCrc = NetFunc::CalcHeaderCrc32(workerFirstReq); + if (NN_UNLIKELY((result = sock->SendRealConnHeader(fdConn, &workerFirstReq, + sizeof(UBSHcomNetTransHeader))) != NN_OK)) { + NN_LOG_ERROR("Failed to send payload header to peer at " << conn->GetIpAndPort() << " in driver " << mName); + return result; + } + + if (!payload.empty()) { + if ((result = sock->Send(payload.c_str(), payload.length())) != NN_OK) { + NN_LOG_ERROR("Failed to send payload to peer at " << conn->GetIpAndPort() << " in driver " << mName << + ", errno " << result); + return result; + } + } + + /* added worker as up context */ + sock->UpContext1(reinterpret_cast(worker)); + + /* create ep */ + UBSHcomNetEndpointPtr newEp = new (std::nothrow) NetAsyncEndpointSock(sock->Id(), sock, this, worker->Index()); + if (NN_UNLIKELY(newEp.Get() == nullptr)) { + NN_LOG_ERROR("Failed to new async sock ep in driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + if (mEnableTls) { + auto childEp = newEp.ToChild(); + auto tmp = dynamic_cast(conn); + if (NN_UNLIKELY(childEp == nullptr || tmp == nullptr)) { + NN_LOG_ERROR("dynamic cast error"); + return NN_OOB_SEC_PROCESS_ERROR; + } + childEp->EnableEncrypt(mOptions); + childEp->SetSecrets(tmp->Secret()); + } + if (mOptions.tcpSendZCopy) { + auto childEp = newEp.ToChild(); + if (NN_UNLIKELY(childEp == nullptr)) { + NN_LOG_ERROR("dynamic cast error"); + return NN_ERROR; + } + childEp->EnableSendZCopy(); + } + /* set ep as sock up context and add ep into map */ + sock->UpContext(reinterpret_cast(newEp.Get())); + + /* if non-blocking set postedHandler and ctxInfoPool to sock, used in sock func ProcessQueueReq */ + sock->SetSockPostedHandler(worker->GetSockPostedHandler()); + sock->SetSockOneSideHandler(worker->GetSockOneSideHandler()); + sock->SetSockOpContextInfoPool(worker->GetSockOpContextInfoPool()); + sock->SetSockSglContextInfoPool(worker->GetSockSglContextInfoPool()); + sock->SetSockHeaderReqInfoPool(worker->GetSockHeaderReqInfoPool()); + sock->SetSockDriverSendMR(mSockDriverSendMR); + sock->SetMrChecker(&mMrChecker); + + NN_LOG_TRACE_INFO("Sock created " << sock->ToString() << " in driver " << mName); + newEp->StoreConnInfo(NetFunc::GetIpByFd(fdConn), conn->ListenPort(), header.version, payload); + + // receive server ready signal + int8_t ready = -1; + tmpBuff = static_cast(&ready); + result = sock->Receive(tmpBuff, sizeof(int8_t)); + if (result != 0 || ready != 1) { + NN_LOG_ERROR("Sock Failed to connect to server as server not responses or return not ready, result " << result); + return NN_ERROR; + } + + if (sock->SetNonBlockingIo() != SS_OK) { + NN_LOG_ERROR("Failed to set sock " << sock->Name() << " nonblocking io mode."); + return NN_ERROR; + } + + AddEp(newEp); + + /* add to worker epoll */ + if (NN_UNLIKELY(worker->AddToEpoll(sock, EPOLLIN) != SS_OK)) { + NN_LOG_ERROR("Failed to add sock " << sock->Name() << " to the epoll handle."); + return NN_ERROR; + } + + newEp->State().Set(NEP_ESTABLISHED); + outEp.Set(newEp.Get()); + + NN_LOG_INFO("New connection to " << client->GetServerIp() << ":" << client->GetServerPort() << + " established, async ep id " << outEp->Id() << " worker info " << worker->DetailName()); + return NN_OK; +} + +NResult NetDriverSockWithOOB::ConnectSyncEp(const OOBTCPClientPtr &client, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint8_t serverGrpNo, uint64_t ctx) +{ + if (NN_UNLIKELY(!mInited)) { + NN_LOG_ERROR("Driver " << mName << " is not initialized"); + return NN_ERROR; + } + + /* try to connect to oob server */ + OOBTCPConnection *conn = nullptr; + NResult result = NN_OK; + if ((result = client->Connect(conn)) != 0) { + NN_LOG_ERROR("Sock Failed to connect server via oob,result " << result); + return result; + } + + NetLocalAutoDecreasePtr autoDecPtr(conn); + if (client->GetOobType() == NET_OOB_TCP) { + conn->SetIpAndPort(client->GetServerIp(), client->GetServerPort()); + } else { + conn->SetIpAndPort(client->GetServerUdsName(), 0); + } + + if (NN_UNLIKELY(OOBSecureProcess::SecProcessInOOBClient(mSecInfoProvider, mSecInfoValidator, conn, mName, ctx, + mOptions.secType))) { + return NN_OOB_SEC_PROCESS_ERROR; + } + /* send connection header */ + ConnectHeader header {}; + + SetConnHeader(header, mOptions.magic, mOptions.version, serverGrpNo, Protocol(), mMajorVersion, mMinorVersion, + mOptions.tlsVersion); + if (NN_UNLIKELY((result = conn->Send(&header, sizeof(ConnectHeader))) != NN_OK)) { + NN_LOG_ERROR("Sock Failed to send conn header to oob server " << client->GetServerIp() << ":" + << client->GetServerPort() << " in Driver " << mName); + return NN_ERROR; + } + + /* receive connect response and peer ep id */ + ConnRespWithUId respWithUId {}; + void *tmpBuf = &respWithUId; + if (NN_UNLIKELY((result = conn->Receive(tmpBuf, sizeof(ConnRespWithUId))) != NN_OK)) { + return result; + } + + /* connect response */ + auto resp = respWithUId.connResp; + switch (resp) { + case MAGIC_MISMATCH: + NN_LOG_ERROR("Failed to pass server magic validation " << mName << ", result " << NN_CONNECT_REFUSED); + return NN_CONNECT_REFUSED; + case PROTOCOL_MISMATCH: + NN_LOG_ERROR("Failed to pass server magic validation " << mName << ", result " << NN_CONNECT_REFUSED); + return NN_CONNECT_PROTOCOL_MISMATCH; + case SERVER_INTERNAL_ERROR: + NN_LOG_ERROR("Server error happened, connection refused " << mName << ", result " << resp); + return NN_ERROR; + case VERSION_MISMATCH: + NN_LOG_ERROR("Failed to pass server version validation " << mName << ", result " << NN_CONNECT_REFUSED); + return NN_CONNECT_REFUSED; + case TLS_VERSION_MISMATCH: + NN_LOG_ERROR("Failed to pass server tls version validation " << mName << ", result " << NN_CONNECT_REFUSED); + return NN_CONNECT_REFUSED; + case OK: + case OK_PROTOCOL_TCP: + case OK_PROTOCOL_UDS: + break; + default: + NN_LOG_ERROR("Sock Server error happened, connection refused " << mName << ", result " << resp); + return NN_ERROR; + } + + /* peer ep id */ + auto newSockId = respWithUId.epId; + NN_LOG_TRACE_INFO("new ep id will be set as " << newSockId << " in driver " << mName); + + int fdConn = conn->TransferFd(); + + /* create sock and initialize */ + SockOptions option {}; + option.sendQueueSize = mOptions.qpSendQueueSize; + Sock *sock; + if (mEnableTls) { + sock = new (std::nothrow) Sock(mSockType, mName, newSockId, fdConn, option, conn); + } else { + sock = new (std::nothrow) Sock(mSockType, mName, newSockId, fdConn, option); + } + if (NN_UNLIKELY(sock == nullptr)) { + NN_LOG_ERROR("Failed to new async sock in driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + sock->PeerIpPort(conn->GetIpAndPort()); + NetLocalAutoDecreasePtr autoDecSock(sock); + + SockWorkerOptions options; + options.SetValue(mOptions, mStartOobSvr); + + if (NN_UNLIKELY((result = sock->Initialize(options)))) { + NN_LOG_ERROR("Failed to initialize sock " << sock->Id() << " in driver " << mName << " result " << result); + return NN_NEW_OBJECT_FAILED; + } + + /* send real head and payload */ + UBSHcomNetTransHeader workerFirstReq {}; + workerFirstReq.opCode = SockExchangeOp::REAL_CONNECT; + workerFirstReq.flags = NTH_TWO_SIDE; + workerFirstReq.dataLength = payload.length(); + workerFirstReq.seqNo = header.wholeHeader[0]; /* use reqNo */ + + /* finally fill header crc */ + workerFirstReq.headerCrc = NetFunc::CalcHeaderCrc32(workerFirstReq); + + if (NN_UNLIKELY((result = sock->SendRealConnHeader(fdConn, &workerFirstReq, + sizeof(UBSHcomNetTransHeader))) != NN_OK)) { + NN_LOG_ERROR("Failed to send payload header to peer at " << conn->GetIpAndPort() << " in driver " << mName); + NetFunc::NN_SafeCloseFd(fdConn); + return result; + } + + if (!payload.empty()) { + if ((result = sock->Send(payload.c_str(), payload.length())) != NN_OK) { + NN_LOG_ERROR("Failed to send payload to peer at " << conn->GetIpAndPort() << " in driver " << mName << + ", errno " << result); + NetFunc::NN_SafeCloseFd(fdConn); + return result; + } + } + + /* create ep */ + const UBSHcomNetWorkerIndex netWorkerIndex {}; + UBSHcomNetEndpointPtr newEp = new (std::nothrow) NetSyncEndpointSock(sock->Id(), sock, this, netWorkerIndex); + if (NN_UNLIKELY(newEp.Get() == nullptr)) { + NN_LOG_ERROR("Failed to new sync sock ep in driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + if (mEnableTls) { + auto childEp = newEp.ToChild(); + auto tmp = dynamic_cast(conn); + if (NN_UNLIKELY(childEp == nullptr || tmp == nullptr)) { + NN_LOG_ERROR("dynamic cast error"); + return NN_OOB_SEC_PROCESS_ERROR; + } + childEp->EnableEncrypt(mOptions); + childEp->SetSecrets(tmp->Secret()); + } + /* set ep as sock up context and add ep into map */ + sock->UpContext(reinterpret_cast(newEp.Get())); + sock->SetMrChecker(&mMrChecker); + NN_LOG_TRACE_INFO("Sock created " << sock->ToString() << " in driver " << mName); + + newEp->StoreConnInfo(NetFunc::GetIpByFd(fdConn), conn->ListenPort(), header.version, payload); + + // receive server ready signal + int8_t ready = -1; + tmpBuf = static_cast(&ready); + result = sock->Receive(tmpBuf, sizeof(int8_t)); + if (result != 0 || ready != 1) { + NN_LOG_ERROR("Failed to connect to server as server not responses or return not ready, result " << result); + // do later: handle pre post-ed mr + return NN_ERROR; + } + + AddEp(newEp); + + newEp->State().Set(NEP_ESTABLISHED); + outEp.Set(newEp.Get()); + + NN_LOG_INFO("New connect to " << client->GetServerIp() << ":" << client->GetServerPort() << + " established, sync ep id " << outEp->Id()); + return NN_OK; +} + +void NetDriverSockWithOOB::DestroyEndpoint(UBSHcomNetEndpointPtr &ep) +{ + if (NN_UNLIKELY(ep.Get() == nullptr)) { + NN_LOG_WARN("The sock ep is null already."); + return; + } + + NN_LOG_INFO("Destroy endpoint id " << ep->Id()); + if (!Remove(ep->Id())) { + NN_LOG_WARN("Unable to destroy sock endpoint as ep " << ep->Id() << " doesn't exist, maybe cleaned already"); + return; + } + + ep.Set(nullptr); +} + +void NetDriverSockWithOOB::DestroyEndpointById(uint64_t id) +{ + std::lock_guard guard(mEndPointsMutex); + auto it = mEndPoints.find(id); + if (NN_UNLIKELY(it == mEndPoints.end())) { + NN_LOG_WARN("the id is not in the ep map"); + return; + } + + NN_LOG_INFO("Destroy endpoint id " << id); + if (NN_UNLIKELY(mEndPoints.erase(id) <= 0)) { + NN_LOG_WARN("Unable to destroy sock endpoint as ep " << id << " doesn't exist, maybe cleaned already"); + return; + } + + mEndPoints[id].Set(nullptr); +} + +NResult NetDriverSockWithOOB::HandleNewOobConn(OOBTCPConnection &conn) +{ + if (NN_UNLIKELY(OOBSecureProcess::SecProcessInOOBServer(mSecInfoProvider, mSecInfoValidator, conn, mName, + mOptions.secType)) != NN_OK) { + return NN_OOB_SEC_PROCESS_ERROR; + } + + uint32_t ip = NetFunc::GetIpByFd(conn.GetFd()); + if (NN_UNLIKELY(OOBSecureProcess::SecProcessCompareEpNum(ip, conn.ListenPort(), conn.GetIpAndPort(), + mOobServers)) != NN_OK) { + NN_LOG_ERROR("Sock connection num exceeds maximum"); + return NN_OOB_SEC_PROCESS_ERROR; + } + + NResult result = 0; + /* receive header and verify */ + ConnectHeader header {}; + void *headerBuf = &header; + if (NN_UNLIKELY((result = conn.Receive(headerBuf, sizeof(ConnectHeader))) != 0)) { + NN_LOG_ERROR("Failed to read header from " << conn.GetIpAndPort() << " for driver " << mName << ", result " << + result); + return result; + } + + ConnRespWithUId respWithUId{ OK, 0 }; + result = OOBSecureProcess::SecCheckConnectionHeader(header, mOptions, mEnableTls, Protocol(), mMajorVersion, + mMinorVersion, respWithUId); + if (result != NN_OK) { + conn.Send(&respWithUId, sizeof(ConnRespWithUId)); + return NN_ERROR; + } + + /* choose worker */ + const NetWorkerLBPtr &lb = conn.LoadBalancer(); + NN_ASSERT_LOG_RETURN(lb.Get() != nullptr, NN_ERROR) + uint16_t workerIndex = 0; + if (NN_UNLIKELY(!lb->ChooseWorker(header.groupIndex, conn.GetIpAndPort(), workerIndex)) || + workerIndex >= mWorkers.size()) { + NN_LOG_ERROR("Failed to choose worker during connect in driver " << mName); + return NN_ERROR; + } + + ConnectResp resp = GetConnResp(mSockType); + uint64_t newSockId = NetUuid::GenerateUuid(); + { + std::lock_guard guard(mEndPointsMutex); + while (mEndPoints.count(newSockId) != 0) { + NN_LOG_WARN("Duplicate generate ep id " << newSockId << " for connection to " + << conn.GetIpAndPort() << " for driver " << mName << ", regenereate"); + newSockId = NetUuid::GenerateUuid(); + } + } + + NN_LOG_TRACE_INFO("new sock id will be set as " << newSockId << " in driver " << mName); + + respWithUId.connResp = resp; + respWithUId.epId = newSockId; + if (NN_UNLIKELY((result = conn.Send(&respWithUId, sizeof(ConnRespWithUId))) != NN_OK)) { + NN_LOG_ERROR("Failed to send connect response to " << conn.GetIpAndPort() << " for driver " << mName); + return NN_ERROR; + } + + NN_LOG_TRACE_INFO("Worker " << workerIndex << " is chosen in driver " << mName); + SockWorker *worker = mWorkers[workerIndex]; + NN_ASSERT_LOG_RETURN(worker != nullptr, NN_ERROR); + /* send worker exchange info to oob client */ + if (mSockType == SOCK_TCP || mSockType == SOCK_UDS) { + int fdConn = conn.TransferFd(); + + /* create sock and initialize */ + SockOptions options {}; + Sock *sock; + if (mEnableTls) { + sock = new (std::nothrow) Sock(mSockType, mName, newSockId, fdConn, options, &conn); + } else { + sock = new (std::nothrow) Sock(mSockType, mName, newSockId, fdConn, options); + } + + if (NN_UNLIKELY(sock == nullptr)) { + NN_LOG_ERROR("Failed to new sock in driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + NetLocalAutoDecreasePtr autoDecSock(sock); + + if (mEnableTls) { + auto tmp = dynamic_cast(&conn); + if (NN_UNLIKELY(tmp == nullptr)) { + NN_LOG_ERROR("dynamic cast error"); + return NN_OOB_SEC_PROCESS_ERROR; + } + sock->Secret(tmp->Secret()); + } + + if (NN_UNLIKELY((result = sock->Initialize(worker->Options())))) { + NN_LOG_ERROR("Failed to initialize sock " << sock->Id() << " in driver " << mName << " result " << result); + return NN_NEW_OBJECT_FAILED; + } + + sock->SetSockPostedHandler(worker->GetSockPostedHandler()); + sock->SetSockOneSideHandler(worker->GetSockOneSideHandler()); + sock->SetSockOpContextInfoPool(worker->GetSockOpContextInfoPool()); + sock->SetSockSglContextInfoPool(worker->GetSockSglContextInfoPool()); + sock->SetSockHeaderReqInfoPool(worker->GetSockHeaderReqInfoPool()); + sock->SetSockDriverSendMR(mSockDriverSendMR); + + sock->PeerIpPort(conn.GetIpAndPort()); + sock->StoreConnInfo(NetFunc::GetIpByFd(fdConn), conn.ListenPort(), header.version); + + /* added worker as up context */ + sock->UpContext1(reinterpret_cast(worker)); + + /* add to worker epoll */ + if (NN_UNLIKELY(worker->AddToEpoll(sock, EPOLLIN) != NN_OK)) { + NN_LOG_ERROR("Failed to add sock " << sock->Name() << " to the epoll handle."); + return NN_ERROR; + } + } else { + NN_ASSERT_LOG_RETURN(false, NN_ERROR); + } + + OOBSecureProcess::SecProcessAddEpNum(ip, conn.ListenPort(), conn.GetIpAndPort(), mOobServers); + + return NN_OK; +} + +NResult NetDriverSockWithOOB::HandleSockError(Sock *sock) +{ + /* sock is failure and close sock */ + NN_LOG_TRACE_INFO("Sock error " << (sock)->ToString()); + + /* remove fd */ + auto worker = reinterpret_cast(sock->UpContext1()); + NN_ASSERT_LOG_RETURN(worker != nullptr, NN_ERROR); + + /* sock will DecreaseRef at worker RemoveFromEpoll, if in real connect process it will be destroyed, + * so sock IncreaseRef here to make sure destroyed after get UpContext */ + sock->IncreaseRef(); + worker->RemoveFromEpoll(sock); + sock->DealCbWithFailure(); + sock->Close(); + OOBSecureProcess::SecProcessDelEpNum(sock->mLocalIp, sock->mListenPort, sock->PeerIpPort(), + mOobServers); + /* remove ep */ + UBSHcomNetEndpointPtr brokenEp = reinterpret_cast(sock->UpContext()); + sock->DecreaseRef(); + NN_ASSERT_LOG_RETURN(brokenEp.Get() != nullptr, NN_ERROR); + brokenEp->mState.Set(NEP_BROKEN); + /* call upper function */ + mEndPointBrokenHandler(brokenEp); + DestroyEndpoint(brokenEp); + return NN_EP_CLOSE; +} + +NResult NetDriverSockWithOOB::HandleSockRealConnect(SockOpContextInfo &ctx) +{ + { + std::lock_guard guard(mEndPointsMutex); + if (mEndPoints.count(ctx.sock->Id())) { + NN_LOG_WARN("Duplicate real connect for driver " << mName << " sock id " << ctx.sock->Id()); + return NN_ERROR; + } + } + + NetLocalAutoDecreasePtr autoDecSock((ctx).sock); + NN_ASSERT_LOG_RETURN(ctx.sock->UpContext1() != 0, NN_ERROR) + ConnectHeader header {}; + SockWorker *worker = nullptr; + UBSHcomNetEndpointPtr ep = nullptr; + static thread_local std::string payload; + NResult result = NN_EP_CLOSE; + + worker = reinterpret_cast((ctx).sock->UpContext1()); + if (NN_UNLIKELY(worker == nullptr)) { + ctx.sock->Close(); + OOBSecureProcess::SecProcessDelEpNum(ctx.sock->mLocalIp, ctx.sock->mListenPort, ctx.sock->PeerIpPort(), + mOobServers); + NN_LOG_ERROR("Invalid worker for driver " << mName); + return NN_EP_CLOSE; + } + /* handle real connection */ + header.wholeHeader[0] = (ctx).header->seqNo; + do { + if (NN_UNLIKELY(header.magic != mOptions.magic)) { + NN_LOG_ERROR("Invalid client request for driver " << mName << ", wrong mgc"); + break; + } + + /* create ep */ + ep = new (std::nothrow) NetAsyncEndpointSock(ctx.sock->Id(), ctx.sock, this, worker->Index()); + if (NN_UNLIKELY(ep == nullptr)) { + NN_LOG_ERROR("Failed to new async sock ep in driver " << mName << ", probably out of memory"); + break; + } + + if (ctx.sock->mType == SOCK_UDS) { + struct ucred remoteIds {}; + socklen_t len = static_cast(sizeof(struct ucred)); + if (NN_UNLIKELY(getsockopt(ctx.sock->FD(), SOL_SOCKET, SO_PEERCRED, &remoteIds, &len) != 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to get uds ids in driver " << mName << ", errno:" << errno << + " error:" << NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + break; + } + ep->RemoteUdsIdInfo(remoteIds.pid, remoteIds.uid, remoteIds.gid); + } + + auto childEp = ep.ToChild(); + if (NN_UNLIKELY(childEp == nullptr)) { + NN_LOG_ERROR("ToChild failed"); + break; + } + if (mEnableTls) { + childEp->EnableEncrypt(mOptions); + childEp->SetSecrets((ctx).sock->mSecret); + } + if (mOptions.tcpSendZCopy) { + childEp->EnableSendZCopy(); + } + /* set payload */ + uint32_t payloadLen = ctx.header->dataLength; + if (payloadLen == 0 || payloadLen > NN_NO1024) { + NN_LOG_ERROR("Invalid payload length " << payloadLen << ", it should be 1 ~ 1024"); + break; + } + + if (payloadLen > 0) { + payload.resize(ctx.header->dataLength + NN_NO1); + payload = { reinterpret_cast(ctx.dataAddress), ctx.header->dataLength }; + payload[ctx.header->dataLength] = '\0'; + } else { + payload.clear(); + } + ep->Payload(payload); + ep->StoreConnInfo(ctx.sock->mLocalIp, ctx.sock->mListenPort, ctx.sock->mVersion, payload); + ctx.sock->SetMrChecker(&mMrChecker); + /* do callback */ + if (NN_UNLIKELY((result = mNewEndPointHandler(ctx.sock->PeerIpPort(), ep, payload)) != NN_OK)) { + NN_LOG_ERROR("Got " << result << " from new ep callback, this new connection from " << + ctx.sock->PeerIpPort() << " will be dropped"); + break; + } + int8_t ready = 1; + if ((result = ctx.sock->Send(&ready, sizeof(int8_t))) != NN_OK) { + NN_LOG_ERROR("Failed to send ready signal to client, result " << result); + break; + } + if (ctx.sock->SetNonBlockingIo() != SS_OK) { + NN_LOG_WARN("Unable to set sock " << ctx.sock->Name() << " nonblocking io mode."); + break; + } + /* set to established */ + ep->State().Set(NEP_ESTABLISHED); + /* set ep as sock up context and add ep into map */ + ctx.sock->UpContext(reinterpret_cast(ep.Get())); + AddEp(ep); + result = NN_OK; + NN_LOG_INFO("New connection from " << ctx.sock->PeerIpPort() << " established, async ep id " << + ep->Id() << " worker info " << worker->DetailName()); + } while (0); + + if (result != NN_OK) { + worker->RemoveFromEpoll(ctx.sock); + ctx.sock->Close(); + OOBSecureProcess::SecProcessDelEpNum(ctx.sock->mLocalIp, ctx.sock->mListenPort, ctx.sock->PeerIpPort(), + mOobServers); + result = NN_EP_CLOSE; + } + + return result; +} + +NResult NetDriverSockWithOOB::HandleNewRequest(SockOpContextInfo &ctx) +{ + NN_ASSERT_LOG_RETURN(ctx.sock != nullptr, NN_ERROR) + NResult result = NN_OK; + + if (NN_UNLIKELY(ctx.errType != SockOpContextInfo::SS_NO_ERROR)) { + NN_LOG_WARN("sock " << ctx.sock->mName << " received an incorrect request and it is causing ep destroy"); + return HandleSockError(ctx.sock); + } + + NN_ASSERT_LOG_RETURN(ctx.header != nullptr, NN_ERROR); + /* user op code */ + if (NN_LIKELY(ctx.header->opCode >= 0)) { + static thread_local UBSHcomNetRequestContext netCtx {}; + static thread_local UBSHcomNetMessage netMsg {}; + + /* set net context */ + NN_ASSERT_LOG_RETURN(ctx.sock->UpContext() != 0, NN_ERROR) + netCtx.mEp.Set(reinterpret_cast(ctx.sock->UpContext())); + netCtx.mHeader = *(ctx.header); + netCtx.mOpType = UBSHcomNetRequestContext::NN_RECEIVED; + if (ctx.header->immData != NN_NO0) { + netCtx.mOpType = UBSHcomNetRequestContext::NN_RECEIVED_RAW; + } + netCtx.mMessage = &netMsg; + + netMsg.mBuf = ctx.sock->ReceiveData().Data(); + netMsg.mDataLen = ctx.sock->ReceiveData().ActualDataSize(); + + /* call upper handler */ + result = mReceivedRequestHandler(netCtx); + netCtx.mEp.Set(nullptr); + netMsg.mBuf = nullptr; + return result; + } else if (ctx.header->opCode == SockExchangeOp::REAL_CONNECT) { + return HandleSockRealConnect(ctx); + } + + return NN_OK; +} + +NResult NetDriverSockWithOOB::HandleReqPosted(SockOpContextInfo *ctx) +{ + NN_ASSERT_LOG_RETURN(ctx != nullptr, NN_ERROR) + NN_ASSERT_LOG_RETURN(ctx->sock != nullptr, NN_ERROR) + NN_ASSERT_LOG_RETURN(ctx->sock->UpContext() != 0, NN_ERROR) + NResult result = NN_OK; + + static thread_local UBSHcomNetRequestContext netCtx {}; + if (ctx->opType == SockOpContextInfo::SS_SEND || ctx->opType == SockOpContextInfo::SS_SEND_RAW) { + if (ctx->opType == SockOpContextInfo::SS_SEND) { + if (mOptions.tcpSendZCopy) { + if (NN_UNLIKELY(memcpy_s(&(netCtx.mHeader), sizeof(UBSHcomNetTransHeader), + &ctx->headerRequest->sendHeader, sizeof(UBSHcomNetTransHeader)) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + } else { + if (NN_UNLIKELY(memcpy_s(&(netCtx.mHeader), sizeof(UBSHcomNetTransHeader), + reinterpret_cast(ctx->sendBuff), + sizeof(UBSHcomNetTransHeader)) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + } + } else { + netCtx.mHeader.Invalid(); + } + + netCtx.mEp.Set(reinterpret_cast(ctx->sock->UpContext())); + netCtx.mResult = SockOpContextInfo::GetNResult(ctx->errType); + netCtx.mMessage = nullptr; + netCtx.mOpType = + ctx->opType == SockOpContextInfo::SS_SEND ? UBSHcomNetRequestContext::NN_SENT : + UBSHcomNetRequestContext::NN_SENT_RAW; + + netCtx.mOriginalReq = {}; + netCtx.mOriginalReq.upCtxSize = ctx->upCtxSize; + + if (netCtx.mOriginalReq.upCtxSize > 0 && + netCtx.mOriginalReq.upCtxSize <= sizeof(UBSHcomNetTransRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(netCtx.mOriginalReq.upCtxData, NN_NO16, ctx->upCtx, ctx->upCtxSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy ctx to netCtx"); + return NN_INVALID_PARAM; + } + } + + // call to callback + if (NN_UNLIKELY((result = mRequestPostedHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call requestPostedHandler in Driver " << mName << + " return non-zero for receive message [opCode: " << netCtx.mHeader.opCode + << ", dataSize " << netCtx.mHeader.dataLength << "]"); + } + netCtx.mEp.Set(nullptr); + if (!mOptions.tcpSendZCopy && ctx->sendBuff != nullptr) { + ctx->sock->mSockDriverSendMR->ReturnBuffer(reinterpret_cast(ctx->sendBuff)); + } + } else if (ctx->opType == SockOpContextInfo::SS_SEND_RAW_SGL) { + auto sglCtx = ctx->sendCtx; + + // set context + netCtx.mOpType = UBSHcomNetRequestContext::NN_SENT_RAW_SGL; + netCtx.mEp.Set(reinterpret_cast(ctx->sock->UpContext())); + netCtx.mResult = SockOpContextInfo::GetNResult(ctx->errType); + netCtx.mHeader.Invalid(); + netCtx.mMessage = nullptr; + if (NN_UNLIKELY(memcpy_s(netCtx.iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, sglCtx->iov, + sizeof(UBSHcomNetTransSgeIov) * sglCtx->iovCount) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + + netCtx.mOriginalSglReq.iov = netCtx.iov; + netCtx.mOriginalSglReq.upCtxSize = ctx->upCtxSize; + netCtx.mOriginalSglReq.iovCount = sglCtx->iovCount; + if (netCtx.mOriginalSglReq.upCtxSize > 0 && + netCtx.mOriginalSglReq.upCtxSize <= sizeof(UBSHcomNetTransSglRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(netCtx.mOriginalSglReq.upCtxData, NN_NO16, ctx->upCtx, ctx->upCtxSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy ctx to netCtx"); + return NN_INVALID_PARAM; + } + } + + // call to callback + if (NN_UNLIKELY((result = mRequestPostedHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call requestPostedHandler in Driver " << mName << + " return non-zero for receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << + netCtx.mHeader.dataLength << "]"); + } + netCtx.mEp.Set(nullptr); + ctx->sock->mSglCtxInfoPool.Return(ctx->sendCtx); + ctx->sendCtx = nullptr; + } else { + NN_LOG_WARN("Unreachable path"); + } + if (mOptions.tcpSendZCopy) { + ctx->sock->mHeaderReqInfoPool.Return(ctx->headerRequest); + ctx->headerRequest = nullptr; + } + ctx = nullptr; + return NN_OK; +} + +NResult NetDriverSockWithOOB::OneSideDone(SockOpContextInfo *ctx) +{ + NN_ASSERT_LOG_RETURN(ctx != nullptr, NN_ERROR) + NN_ASSERT_LOG_RETURN(ctx->sock != nullptr, NN_ERROR) + NN_ASSERT_LOG_RETURN(ctx->sock->UpContext() != 0, NN_ERROR) + NN_ASSERT_LOG_RETURN(ctx->sock->UpContext1() != 0, NN_ERROR) + NResult result = NN_OK; + + auto worker = reinterpret_cast(ctx->sock->UpContext1()); + static thread_local UBSHcomNetRequestContext netCtx {}; + if (ctx->opType == SockOpContextInfo::SS_WRITE || ctx->opType == SockOpContextInfo::SS_READ) { + // set context + netCtx.mEp.Set(reinterpret_cast(ctx->sock->UpContext())); + netCtx.mResult = SockOpContextInfo::GetNResult(ctx->errType); + netCtx.mOpType = + ctx->opType == SockOpContextInfo::SS_WRITE ? UBSHcomNetRequestContext::NN_WRITTEN : + UBSHcomNetRequestContext::NN_READ; + netCtx.mHeader.Invalid(); + netCtx.mMessage = nullptr; + netCtx.mOriginalReq.lAddress = ctx->sendCtx->iov[0].lAddress; + netCtx.mOriginalReq.lKey = ctx->sendCtx->iov[0].lKey; + netCtx.mOriginalReq.size = ctx->sendCtx->iov[0].size; + netCtx.mOriginalReq.upCtxSize = ctx->upCtxSize; + + if (netCtx.mOriginalReq.upCtxSize > 0 && + netCtx.mOriginalReq.upCtxSize <= sizeof(UBSHcomNetTransRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(netCtx.mOriginalReq.upCtxData, NN_NO16, ctx->upCtx, ctx->upCtxSize) != NN_OK)) { + NN_LOG_ERROR("failed to copy ctx to upCtxData"); + return NN_INVALID_PARAM; + } + } + + // return context to worker and ctx is not usable anymore + worker->ReturnSglContextInfo(ctx->sendCtx); + worker->ReturnOpContextInfo(ctx); + + // called to callback + if (NN_UNLIKELY((result = mOneSideDoneHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call oneSideDoneHandler in Driver " << mName << " return non-zero for buff done"); + } + netCtx.mEp.Set(nullptr); + } else if (ctx->opType == SockOpContextInfo::SS_SGL_WRITE || ctx->opType == SockOpContextInfo::SS_SGL_READ) { + auto sglCtx = ctx->sendCtx; + // set context + netCtx.mEp.Set(reinterpret_cast(ctx->sock->UpContext())); + netCtx.mResult = SockOpContextInfo::GetNResult(ctx->errType); + netCtx.mOpType = ctx->opType == SockOpContextInfo::SS_SGL_WRITE ? UBSHcomNetRequestContext::NN_SGL_WRITTEN : + UBSHcomNetRequestContext::NN_SGL_READ; + netCtx.mHeader.Invalid(); + netCtx.mMessage = nullptr; + if (NN_UNLIKELY(memcpy_s(netCtx.iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, + sglCtx->iov, sizeof(UBSHcomNetTransSgeIov) * sglCtx->iovCount) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + netCtx.mOriginalSglReq.iov = netCtx.iov; + netCtx.mOriginalSglReq.iovCount = sglCtx->iovCount; + netCtx.mOriginalSglReq.upCtxSize = ctx->upCtxSize; + if (netCtx.mOriginalSglReq.upCtxSize > 0 && + netCtx.mOriginalSglReq.upCtxSize <= sizeof(UBSHcomNetTransSglRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(netCtx.mOriginalSglReq.upCtxData, NN_NO16, ctx->upCtx, ctx->upCtxSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy ctx to netCtx"); + return NN_INVALID_PARAM; + } + } + worker->ReturnSglContextInfo(sglCtx); + worker->ReturnOpContextInfo(ctx); + // called to callback + if (NN_UNLIKELY((result = mOneSideDoneHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call oneSideDoneHandler in Driver " << mName << " return non-zero for sgl type done"); + } + netCtx.mEp.Set(nullptr); + } else { + NN_LOG_WARN("Unreachable path"); + } + + return NN_OK; +} + +NResult NetDriverSockWithOOB::HandleEpClose(Sock *sock) +{ + NN_ASSERT_LOG_RETURN(sock != nullptr, NN_ERROR); + + NN_LOG_WARN("sock " << sock->mName << " received the incorrect event and it is causing ep destroy."); + return HandleSockError(sock); +} + +NResult NetDriverSockWithOOB::MultiRailNewConnection(OOBTCPConnection &conn) +{ + NN_LOG_ERROR("Invalid operation, TCP is not supported by MultiRail"); + return NN_ERROR; +} + +void *NetDriverSockWithOOB::MapAndRegVaForUB(unsigned long memid, uint64_t &va) +{ + NN_LOG_ERROR("operation is not supported in tcp"); + return nullptr; +} + +NResult NetDriverSockWithOOB::UnmapVaForUB(uint64_t &va) +{ + NN_LOG_ERROR("operation is not supported in tcp"); + return NN_ERROR; +} +} +} diff --git a/src/transport/sock/net_sock_driver_oob.h b/src/transport/sock/net_sock_driver_oob.h new file mode 100644 index 0000000000000000000000000000000000000000..f80329ad4f67d6da5d896732360cc2be2118c852 --- /dev/null +++ b/src/transport/sock/net_sock_driver_oob.h @@ -0,0 +1,145 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_SOCK_DRIVER_OOB_H_234234 +#define OCK_HCOM_NET_SOCK_DRIVER_OOB_H_234234 + +#include "net_sock_common.h" + +namespace ock { +namespace hcom { +class NetDriverSockWithOOB : public UBSHcomNetDriver { +public: + NetDriverSockWithOOB(const std::string &name, bool startOobSvr, UBSHcomNetDriverProtocol protocol, SockType t) + : UBSHcomNetDriver(name, startOobSvr, protocol), mSockType(t) + { + OBJ_GC_INCREASE(NetDriverSockWithOOB); + } + + ~NetDriverSockWithOOB() override + { + OBJ_GC_DECREASE(NetDriverSockWithOOB); + } + + NResult Initialize(const UBSHcomNetDriverOptions &option) override; + + void UnInitialize() override; + + NResult Start() override; + void Stop() override; + + NResult CreateMemoryRegion(uintptr_t address, uint64_t size, UBSHcomNetMemoryRegionPtr &mr) override; + NResult CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr) override; + NResult CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr, unsigned long memid) override; + void DestroyMemoryRegion(UBSHcomNetMemoryRegionPtr &mr) override; + + inline NResult ValidateMemoryRegion(uint64_t lKey, uintptr_t address, uint64_t size) + { + return mMrChecker.Validate(lKey, address, size); + } + + NResult Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo, + uint8_t clientGrpNo) override; + + NResult Connect(const std::string &oobIp, uint16_t oobPort, const std::string &payload, UBSHcomNetEndpointPtr &ep, + uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) override; + + NResult Connect(const std::string &serverUrl, const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, + uint8_t serverGrpNo = 0, uint8_t clientGrpNo = 0, uint64_t ctx = 0) override; + + NResult MultiRailNewConnection(OOBTCPConnection &conn); + + void *MapAndRegVaForUB(unsigned long memid, uint64_t &va) override; + + NResult UnmapVaForUB(uint64_t &va) override; + + void DestroyEndpoint(UBSHcomNetEndpointPtr &ep) override; + void DestroyEndpointById(uint64_t id); + inline NetMemPoolFixedPtr GetOpCtxMemPool() + { + return mOpCtxMemPool; + } + inline NetMemPoolFixedPtr GetSglCtxMemPool() + { + return mSglCtxMemPool; + } + +protected: + NResult ValidateOptions(); + NResult CreateWorkers(); + void ClearWorkers(); + void UnInitializeInner(); + NResult HandleSockError(Sock *sock); + + NResult HandleNewOobConn(OOBTCPConnection &conn); + NResult HandleNewRequest(SockOpContextInfo &ctx); + NResult HandleReqPosted(SockOpContextInfo *ctx); + NResult OneSideDone(SockOpContextInfo *ctx); + NResult HandleEpClose(Sock *sock); + + NResult Connect(const OOBTCPClientPtr &client, const std::string &payload, UBSHcomNetEndpointPtr &outEp, + uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx); + NResult ConnectSyncEp(const OOBTCPClientPtr &client, const std::string &payload, UBSHcomNetEndpointPtr &outEp, + uint8_t serverGrpNo, uint64_t ctx); + + inline bool Remove(uint64_t id) + { + std::lock_guard guard(mEndPointsMutex); + return (mEndPoints.erase(id) > 0); + } + + inline void AddEp(const UBSHcomNetEndpointPtr &newEp) + { + /* added into map */ + if (NN_LIKELY(newEp != nullptr)) { + std::lock_guard guard(mEndPointsMutex); + mEndPoints.emplace(newEp->Id(), newEp); + } + } + + static inline ConnectResp GetConnResp(SockType t) + { + switch (t) { + case SOCK_TCP: + return OK_PROTOCOL_TCP; + case SOCK_UDS: + return OK_PROTOCOL_UDS; + default: + return OK; + } + } + +protected: + SockType mSockType = SockType::SOCK_TCP; + std::vector mWorkers; + std::vector mFilteredIps; + MemoryRegionChecker mMrChecker; + NormalMemoryRegionFixedBuffer *mSockDriverSendMR = nullptr; + + NResult CreateWorkerResource(); + NResult CreateOpCtxMemPool(); + NResult CreateSglCtxMemPool(); + NResult CreateHeaderReqMemPool(); + NResult CreateSendMr(); + + NResult HandleSockRealConnect(SockOpContextInfo &ctx); + + NetMemPoolFixedPtr mOpCtxMemPool = nullptr; + NetMemPoolFixedPtr mSglCtxMemPool = nullptr; + NetMemPoolFixedPtr mHeaderReqMemPool = nullptr; + + friend class NetAsyncEndpointSock; + friend class NetSyncEndpointSock; +}; +} +} + +#endif // OCK_HCOM_NET_SOCK_DRIVER_OOB_H_234234 diff --git a/src/transport/sock/net_sock_sync_endpoint.cpp b/src/transport/sock/net_sock_sync_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..752669826007626cf104822cc4c8c7c57faa2c88 --- /dev/null +++ b/src/transport/sock/net_sock_sync_endpoint.cpp @@ -0,0 +1,748 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "net_sock_driver_oob.h" +#include "sock_validation.h" +#include "net_sock_sync_endpoint.h" + +namespace ock { +namespace hcom { +NetSyncEndpointSock::NetSyncEndpointSock(uint64_t id, Sock *sock, NetDriverSockWithOOB *driver, + const UBSHcomNetWorkerIndex &workerIndex) + : NetEndpointImpl(id, workerIndex), mSock(sock), mDriver(driver) +{ + if (mSock != nullptr) { + mSock->IncreaseRef(); + } + + if (mDriver != nullptr) { + mSegSize = mDriver->mOptions.mrSendReceiveSegSize; + mAllowedSize = mSegSize - sizeof(SockTransHeader); + mDriver->IncreaseRef(); + + mOpCtxInfoPool.Initialize(mDriver->GetOpCtxMemPool()); + mSglCtxInfoPool.Initialize(mDriver->GetSglCtxMemPool()); + } + + OBJ_GC_INCREASE(NetSyncEndpointSock); +} + +NetSyncEndpointSock::~NetSyncEndpointSock() +{ + if (mSock != nullptr) { + mSock->Close(); + mSock->DecreaseRef(); + } + + if (mDriver != nullptr) { + mDriver->DecreaseRef(); + } + + OBJ_GC_DECREASE(NetSyncEndpointSock); + // do later +} + +NResult NetSyncEndpointSock::SetEpOption(UBSHcomEpOptions &epOptions) +{ + if (mDefaultTimeout > 0 && epOptions.sendTimeout > mDefaultTimeout) { + NN_LOG_WARN("send timeout should not longer than mDefaultTimeout " << mDefaultTimeout); + return NN_ERROR; + } + + if (NN_UNLIKELY(mSock->SetBlockingSendTimeout(epOptions.sendTimeout) != SS_OK)) { + NN_LOG_WARN("Unable to set sock " << mSock->Name() << " timeout options"); + return NN_ERROR; + } + + return NN_OK; +} + +#define TIMEOUT_PROCESS() \ + do { \ + mSock->Close(); \ + mState.Set(NEP_BROKEN); \ + mDriver->DestroyEndpointById(mId); \ + } while (0) + +NResult NetSyncEndpointSock::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNo) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post send with seqNo as state validation failed"); + return result; + } + + if (NN_UNLIKELY((result = BuffValidation(request)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post send with seqNo as buff validation failed"); + return result; + } + REQ_SIZE_VALIDATION(); + OPCODE_VALIDATION(); + + UBSHcomNetTransHeader header {}; + header.dataLength = request.size; + header.seqNo = seqNo == 0 ? NextSeq() : seqNo; + header.flags = NTH_TWO_SIDE; + header.opCode = opCode; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + mLastFlag = header.flags; + mLastSendSeqNo = header.seqNo; + + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(SOCK_EP_SYNC_POST_SEND); + do { + result = mSock->PostSend(header, request); + if (result == SS_OK) { + TRACE_DELAY_END(SOCK_EP_SYNC_POST_SEND, result); + return NN_OK; + } else if (NetMonotonic::TimeNs() < finishTime && NeedRetry(result) && mDefaultTimeout != 0) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + if (result == SS_TIMEOUT) { + TIMEOUT_PROCESS(); + } + // no retry result or timeout = 0 + break; + } while (true); + + NN_LOG_ERROR("Failed to sync post send request, result " << result); + TRACE_DELAY_END(SOCK_EP_SYNC_POST_SEND, result); + return result; +} + +NResult NetSyncEndpointSock::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post send with opInfo as state validation failed"); + return result; + } + + if (NN_UNLIKELY((result = BuffValidation(request)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post send with opInfo as buff validation failed"); + return result; + } + REQ_SIZE_VALIDATION(); + OPCODE_VALIDATION(); + + UBSHcomNetTransHeader header {}; + header.opCode = opCode; + header.seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + header.flags = ((uint16_t)opInfo.flags << NN_NO8) | ((uint16_t)NTH_TWO_SIDE); + header.errorCode = opInfo.errorCode; + header.timeout = opInfo.timeout; + header.dataLength = request.size; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + mLastSendSeqNo = header.seqNo; + mLastFlag = header.flags; + + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(SOCK_EP_SYNC_POST_SEND); + do { + result = mSock->PostSend(header, request); + if (result == SS_OK) { + TRACE_DELAY_END(SOCK_EP_SYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } else if (result == SS_TIMEOUT) { + TIMEOUT_PROCESS(); + } + // no retry result or timeout = 0 + break; + } while (true); + + NN_LOG_ERROR("Failed to sync post send request with opInfo, result is " << result); + TRACE_DELAY_END(SOCK_EP_SYNC_POST_SEND, result); + return result; +} + +NResult NetSyncEndpointSock::PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNo) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post send raw as state validation failed"); + return result; + } + + if (NN_UNLIKELY((result = BuffValidation(request)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post send raw as buff validation failed"); + return result; + } + REQ_SIZE_VALIDATION(); + + UBSHcomNetTransHeader header {}; + header.seqNo = seqNo == 0 ? NextSeq() : seqNo; + header.immData = 1; + header.flags = NTH_TWO_SIDE; + header.dataLength = request.size; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + mLastSendSeqNo = header.seqNo; + mLastFlag = header.flags; + + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(SOCK_EP_SYNC_POST_SEND_RAW); + do { + result = mSock->PostSend(header, request); + if (result == SS_OK) { + TRACE_DELAY_END(SOCK_EP_SYNC_POST_SEND_RAW, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // sleep is not suitable for scenes like LWT + continue; + } else if (result == SS_TIMEOUT) { + TIMEOUT_PROCESS(); + } + // no retry result or timeout = 0 + break; + } while (true); + + NN_LOG_ERROR("Failed to sync post send raw request, result " << result); + TRACE_DELAY_END(SOCK_EP_SYNC_POST_SEND_RAW, result); + return result; +} + +NResult NetSyncEndpointSock::PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) +{ + size_t totalSize = 0; + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post send raw sgl as state validation failed"); + return result; + } + + if (NN_UNLIKELY((result = TwoSideSglValidation(request, mDriver, mSegSize, totalSize)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post send raw sgl as sgl validation failed"); + return result; + } + + UBSHcomNetTransHeader header {}; + header.flags = NTH_TWO_SIDE_SGL; + header.immData = 1; + header.dataLength = totalSize; + header.seqNo = seqNo == 0 ? NextSeq() : seqNo; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + mLastSendSeqNo = header.seqNo; + mLastFlag = header.flags; + + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(SOCK_EP_SYNC_POST_SEND_RAW_SGL); + do { + result = mSock->PostSendSgl(header, request); + if (result == SS_OK) { + TRACE_DELAY_END(SOCK_EP_SYNC_POST_SEND_RAW_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // sleep is not suitable for scenes like LWT + continue; + } else if (result == SS_TIMEOUT) { + TIMEOUT_PROCESS(); + } + // no retry result or timeout = 0 + break; + } while (true); + + NN_LOG_ERROR("Failed to post send request, result " << result); + TRACE_DELAY_END(SOCK_EP_SYNC_POST_SEND_RAW_SGL, result); + return result; +} + +#define RETURN_RESOURCES(opCtx) \ + do { \ + (void)mSock->RemoveOpCtx((opCtx)->sendCtx->sendHeader.seqNo); \ + mSock->DecreaseRef(); \ + mSglCtxInfoPool.Return((opCtx)->sendCtx); \ + mOpCtxInfoPool.Return((opCtx)); \ + } while (0) + +NResult NetSyncEndpointSock::PostRead(const UBSHcomNetTransRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post read as state validation failed"); + return result; + } + + if (NN_UNLIKELY((result = BuffValidation(request)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post read as buff validation failed"); + return result; + } + + if (NN_UNLIKELY((result = OneSideValidation(request, mDriver)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post read as one side validation failed"); + return result; + } + + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to post read with sock " << mSock->Name() << " as no ctx left"); + return SS_CTX_FULL; + } + + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to PostRead with sock " << mSock->Name() << " as no sglCtx left"); + mOpCtxInfoPool.Return(ctx); + return SS_CTX_FULL; + } + + UBSHcomNetTransHeader header {}; + header.seqNo = mSock->OneSideNextSeq(); + header.flags = NTH_READ; + header.dataLength = sizeof(UBSHcomNetTransSgeIov); + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + mLastFlag = header.flags; + + SockOpContextInfo::SockOpType opType = SockOpContextInfo::SockOpType::SS_READ; + if (NN_UNLIKELY(FillReadWriteCtx(ctx, sglCtx, request, opType, header) != NN_OK)) { + NN_LOG_ERROR("Failed to fill read ctx"); + mSglCtxInfoPool.Return(sglCtx); + mOpCtxInfoPool.Return(ctx); + return NN_INVALID_PARAM; + } + + mSock->AddOpCtx(header.seqNo, ctx); + mSock->IncreaseRef(); + + uint64_t finishTime = GetFinishTime(); + bool flag = true; + TRACE_DELAY_BEGIN(SOCK_EP_SYNC_POST_READ); + do { + result = mSock->PostRead(ctx); + if (result == SS_OK) { + mSglCtxInfoPool.Return(sglCtx); + mOpCtxInfoPool.Return(ctx); + TRACE_DELAY_END(SOCK_EP_SYNC_POST_READ, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); + continue; + } else if (result == SS_TIMEOUT) { + TIMEOUT_PROCESS(); + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post read request, result " << result); + RETURN_RESOURCES(ctx); + TRACE_DELAY_END(SOCK_EP_SYNC_POST_READ, result); + return result; +} + +NResult NetSyncEndpointSock::PostRead(const UBSHcomNetTransSglRequest &request) +{ + size_t totalSize = 0; + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post read sgl as state validation failed"); + return result; + } + + if (NN_UNLIKELY((result = OneSideSglValidation(request, mDriver, totalSize)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post read sgl as sgl validation failed"); + return result; + } + + auto opCtx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(opCtx == nullptr)) { + NN_LOG_ERROR("Failed to post read sgl with sock " << mSock->Name() << " as no op ctx left"); + return SS_CTX_FULL; + } + + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to post read sgl with sock " << mSock->Name() << " as no op ctx left"); + mOpCtxInfoPool.Return(opCtx); + return SS_CTX_FULL; + } + + UBSHcomNetTransHeader header {}; + header.seqNo = mSock->OneSideNextSeq(); + header.flags = NTH_READ_SGL; + header.dataLength = sizeof(request.iovCount) + sizeof(UBSHcomNetTransSgeIov) * request.iovCount; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + mLastFlag = header.flags; + + uint64_t finishTime = GetFinishTime(); + + SockOpContextInfo::SockOpType opType = SockOpContextInfo::SockOpType::SS_SGL_READ; + if (NN_UNLIKELY(FillReadWriteSglCtx(opCtx, sglCtx, request, opType, header) != NN_OK)) { + NN_LOG_ERROR("Failed to fill read sgl ctx"); + mSglCtxInfoPool.Return(sglCtx); + mOpCtxInfoPool.Return(opCtx); + return NN_INVALID_PARAM; + } + + mSock->AddOpCtx(header.seqNo, opCtx); + mSock->IncreaseRef(); + bool readSglFlag = true; + TRACE_DELAY_BEGIN(SOCK_EP_SYNC_POST_READ_SGL); + do { + result = mSock->PostReadSgl(opCtx); + if (result == SS_OK) { + mSglCtxInfoPool.Return(sglCtx); + mOpCtxInfoPool.Return(opCtx); + TRACE_DELAY_END(SOCK_EP_SYNC_POST_READ_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // sleep is not suitable for scenes like LWT + continue; + } else if (result == SS_TIMEOUT) { + TIMEOUT_PROCESS(); + } + // no retry result or timeout = 0 + readSglFlag = false; + } while (readSglFlag); + + NN_LOG_ERROR("Failed to post read sgl request, result " << result); + RETURN_RESOURCES(opCtx); + TRACE_DELAY_END(SOCK_EP_SYNC_POST_READ_SGL, result); + return result; +} + +NResult NetSyncEndpointSock::PostWrite(const UBSHcomNetTransRequest &request) +{ + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post write as state validation failed"); + return result; + } + + if (NN_UNLIKELY((result = BuffValidation(request)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post write as buff validation failed"); + return result; + } + + if (NN_UNLIKELY((result = OneSideValidation(request, mDriver)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post write as one side validation failed"); + return result; + } + + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to PostWrite with sock " << mSock->Name() << " as no reqInfo left"); + return SS_CTX_FULL; + } + + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to PostWrite with sock " << mSock->Name() << " as no sglCtx left"); + mOpCtxInfoPool.Return(ctx); + return SS_CTX_FULL; + } + + UBSHcomNetTransHeader header {}; + header.seqNo = mSock->OneSideNextSeq(); + header.flags = NTH_WRITE; + header.dataLength = sizeof(UBSHcomNetTransSgeIov) + request.size; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + mLastFlag = header.flags; + + uint64_t finishTime = GetFinishTime(); + + SockOpContextInfo::SockOpType opType = SockOpContextInfo::SockOpType::SS_WRITE; + if (NN_UNLIKELY(FillReadWriteCtx(ctx, sglCtx, request, opType, header) != NN_OK)) { + NN_LOG_ERROR("Failed to fill write ctx"); + mSglCtxInfoPool.Return(sglCtx); + mOpCtxInfoPool.Return(ctx); + return NN_INVALID_PARAM; + } + + mSock->AddOpCtx(header.seqNo, ctx); + mSock->IncreaseRef(); + + bool flag = true; + TRACE_DELAY_BEGIN(SOCK_EP_SYNC_POST_WRITE); + do { + result = mSock->PostWrite(ctx); + if (result == SS_OK) { + mSglCtxInfoPool.Return(sglCtx); + mOpCtxInfoPool.Return(ctx); + TRACE_DELAY_END(SOCK_EP_SYNC_POST_WRITE, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } else if (result == SS_TIMEOUT) { + TIMEOUT_PROCESS(); + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post write request, result " << result); + RETURN_RESOURCES(ctx); + TRACE_DELAY_END(SOCK_EP_SYNC_POST_WRITE, result); + return result; +} + +NResult NetSyncEndpointSock::PostWrite(const UBSHcomNetTransSglRequest &request) +{ + size_t totalSize = 0; + NResult result = NN_OK; + if (NN_UNLIKELY((result = StateValidation(mState, mId, mDriver, mSock)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post write sgl as state validation failed"); + return result; + } + + if (NN_UNLIKELY((result = OneSideSglValidation(request, mDriver, totalSize)) != NN_OK)) { + NN_LOG_ERROR("Sock failed to sync post write sgl as sgl validation failed"); + return result; + } + + auto opCtx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(opCtx == nullptr)) { + NN_LOG_ERROR("Failed to post write sgl with sock " << mSock->Name() << " as no op ctx left"); + return SS_PARAM_INVALID; + } + + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to post write sgl with sock " << mSock->Name() << " as no op ctx left"); + mOpCtxInfoPool.Return(opCtx); + return SS_PARAM_INVALID; + } + + UBSHcomNetTransHeader header {}; + header.seqNo = mSock->OneSideNextSeq(); + header.flags = NTH_WRITE_SGL; + header.dataLength = sizeof(request.iovCount) + sizeof(UBSHcomNetTransSgeIov) * request.iovCount + totalSize; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + mLastFlag = header.flags; + + uint64_t finishTime = GetFinishTime(); + + SockOpContextInfo::SockOpType opType = SockOpContextInfo::SockOpType::SS_SGL_WRITE; + if (NN_UNLIKELY(FillReadWriteSglCtx(opCtx, sglCtx, request, opType, header) != NN_OK)) { + NN_LOG_ERROR("Failed to fill write sgl ctx"); + mSglCtxInfoPool.Return(sglCtx); + mOpCtxInfoPool.Return(opCtx); + return NN_INVALID_PARAM; + } + + mSock->AddOpCtx(header.seqNo, opCtx); + mSock->IncreaseRef(); + bool writeSglFlag = true; + TRACE_DELAY_BEGIN(SOCK_EP_SYNC_POST_WRITE_SGL); + do { + result = mSock->PostWriteSgl(opCtx); + if (result == SS_OK) { + mSglCtxInfoPool.Return(sglCtx); + mOpCtxInfoPool.Return(opCtx); + TRACE_DELAY_END(SOCK_EP_SYNC_POST_WRITE_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } else if (result == SS_TIMEOUT) { + TIMEOUT_PROCESS(); + } + // no retry result or timeout = 0 + writeSglFlag = false; + } while (writeSglFlag); + + NN_LOG_ERROR("Failed to post write sgl request, result " << result); + RETURN_RESOURCES(opCtx); + TRACE_DELAY_END(SOCK_EP_SYNC_POST_WRITE_SGL, result); + return result; +} + +static inline NResult WriteData(Sock *sock, SockTransHeader &header, SockOpContextInfo *originalCtx, void *buf) +{ + if (header.flags == NTH_READ_ACK) { + if (originalCtx->sendCtx->iov[0].size != header.dataLength) { + NN_LOG_ERROR("Failed to check sock with sock " << sock->Name() << " as size different."); + return SS_PARAM_INVALID; + } + + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(originalCtx->sendCtx->iov[0].lAddress), + originalCtx->sendCtx->iov[0].size, buf, originalCtx->sendCtx->iov[0].size) != NN_OK)) { + NN_LOG_ERROR("Failed to copy buf to sendCtx"); + return NN_INVALID_PARAM; + } + } else if (header.flags == NTH_READ_SGL_ACK) { + /* write data */ + if (header.dataLength < + (sizeof(UBSHcomNetTransSglRequest::iovCount) + sizeof(UBSHcomNetTransSgeIov) * + originalCtx->sendCtx->iovCount)) { + NN_LOG_ERROR("Failed to ReadSglAck as data size " << header.dataLength << " is less than iov size"); + return SS_PARAM_INVALID; + } + auto iovCount = reinterpret_cast(buf); + if (*iovCount == 0 || *iovCount > NN_NO4 || *iovCount != originalCtx->sendCtx->iovCount) { + NN_LOG_ERROR("Failed to check sock with sock " << sock->Name() << " as iov count is illegal."); + return SS_PARAM_INVALID; + } + auto sgeIov = + reinterpret_cast(reinterpret_cast(buf) + + sizeof(UBSHcomNetTransSglRequest::iovCount)); + auto data = reinterpret_cast(reinterpret_cast(buf) + + sizeof(UBSHcomNetTransSglRequest::iovCount) + sizeof(UBSHcomNetTransSgeIov) * (*iovCount)); + + uint32_t dataSize = 0; + for (uint16_t i = 0; i < *iovCount; i++) { + dataSize += sgeIov->size; + } + + if (originalCtx->sendCtx->sendHeader.dataLength + dataSize != header.dataLength) { + NN_LOG_ERROR("Failed to check sock with sock " << sock->Name() << " as size different."); + return SS_PARAM_INVALID; + } + // do later check + uint32_t copyOffset = 0; + for (uint16_t i = 0; i < *iovCount; i++) { + UBSHcomNetTransSgeIov iov = originalCtx->sendCtx->iov[i]; + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(iov.lAddress), iov.size, data + copyOffset, + iov.size) != NN_OK)) { + NN_LOG_ERROR("Failed to copy data to iov"); + return NN_INVALID_PARAM; + } + copyOffset += iov.size; + } + } + return NN_OK; +} + +NResult NetSyncEndpointSock::WaitCompletion(int32_t timeout) +{ + if (mLastFlag == NTH_TWO_SIDE || mLastFlag == NTH_TWO_SIDE_SGL) { + return NN_OK; + } + + auto result = mSock->PostReceiveHeader(mRespCtx.mHeader, timeout); + if (result != SS_OK) { + if (result == SS_TIMEOUT) { + TIMEOUT_PROCESS(); + } + NN_LOG_ERROR("Failed to post receive header, result " << result); + return result; + } + if (mRespCtx.mHeader.flags == NTH_READ_ACK || mRespCtx.mHeader.flags == NTH_READ_SGL_ACK) { + auto msgReady = mRespMessage.AllocateIfNeed(mRespCtx.mHeader.dataLength); + if (NN_UNLIKELY(!msgReady)) { + NN_LOG_ERROR("Failed to allocate memory for response size " << mRespCtx.mHeader.dataLength << + ", probably out of memory"); + return NN_MALLOC_FAILED; + } + result = mSock->PostReceiveBody(mRespMessage.mBuf, mRespCtx.mHeader.dataLength, true); + if (result != SS_OK) { + NN_LOG_ERROR("Failed to receive body, result " << result << ", seqNo " << mRespCtx.mHeader.seqNo); + return result; + } + NN_LOG_TRACE_INFO("Receive body successfully: sock " << mSock->Id() << ", head imm data " << + mRespCtx.mHeader.immData << ", flags " << mRespCtx.mHeader.flags << ", seqNo " << mRespCtx.mHeader.seqNo << + ", data len " << mRespCtx.mHeader.dataLength); + + auto originalReadCtx = mSock->RemoveOpCtx(mRespCtx.mHeader.seqNo); + if (originalReadCtx == nullptr) { + NN_LOG_ERROR("Failed to handle ack with sock " << mSock->Name() << " as invalid seqNo " << + mRespCtx.mHeader.seqNo); + return SS_PARAM_INVALID; + } + if (originalReadCtx->sock != mSock) { + NN_LOG_ERROR("Failed to check with sock " << mSock->Name() << " as sock different."); + return SS_PARAM_INVALID; + } + + return WriteData(mSock, mRespCtx.mHeader, originalReadCtx, mRespMessage.mBuf); + } else if (mRespCtx.mHeader.flags == NTH_WRITE_ACK || mRespCtx.mHeader.flags == NTH_WRITE_SGL_ACK) { + NN_LOG_TRACE_INFO("Post receive header successfully: sock " << mSock->Id() << ", head imm data " << + mRespCtx.mHeader.immData << ", flags " << mRespCtx.mHeader.flags << ", seqNo " << mRespCtx.mHeader.seqNo); + + auto originalWriteCtx = mSock->RemoveOpCtx(mRespCtx.mHeader.seqNo); + if (originalWriteCtx == nullptr) { + NN_LOG_ERROR("Failed to handle ack with sock " << mSock->Name() << " as invalid seqNo " << + mRespCtx.mHeader.seqNo); + return SS_PARAM_INVALID; + } + if (originalWriteCtx->sock != mSock) { + NN_LOG_ERROR("Failed to check with sock " << mSock->Name() << " as sock different."); + return SS_PARAM_INVALID; + } + if (originalWriteCtx->sendCtx->sendHeader.dataLength != mRespCtx.mHeader.dataLength) { + NN_LOG_ERROR("Failed to check sock with sock" << mSock->Name() << " as size different."); + return SS_PARAM_INVALID; + } + } + + return NN_OK; +} + +NResult NetSyncEndpointSock::Receive(int32_t timeout, UBSHcomNetResponseContext &ctx) +{ + auto result = mSock->PostReceiveHeader(mRespCtx.mHeader, timeout); + if (result != SS_OK) { + if (result == SS_TIMEOUT) { + TIMEOUT_PROCESS(); + } + NN_LOG_ERROR("Failed to post receive header, result " << result); + return result; + } + NN_LOG_TRACE_INFO("Post receive header successfully: sock " << mSock->Id() << ", head imm data " << + mRespCtx.mHeader.immData << ", flags " << mRespCtx.mHeader.flags << ", seqNo " << mRespCtx.mHeader.seqNo); + + if (NN_UNLIKELY(mRespCtx.mHeader.seqNo != mLastSendSeqNo)) { + NN_LOG_ERROR("Received un-matched seq no " << mRespCtx.mHeader.seqNo << ", demand seq no " << mLastSendSeqNo); + return NN_SEQ_NO_NOT_MATCHED; + } + + auto msgReady = mRespMessage.AllocateIfNeed(mRespCtx.mHeader.dataLength); + if (NN_UNLIKELY(!msgReady)) { + NN_LOG_ERROR("Failed to allocate memory for response size " << mRespCtx.mHeader.dataLength << + ", probably out of memory"); + return NN_MALLOC_FAILED; + } + + result = mSock->PostReceiveBody(mRespMessage.mBuf, mRespCtx.mHeader.dataLength, false); + if (result != SS_OK) { + if (result == SS_TIMEOUT) { + TIMEOUT_PROCESS(); + } + NN_LOG_ERROR("Failed to receive body, result " << result << ", seqNo " << mRespCtx.mHeader.seqNo); + return result; + } + NN_LOG_TRACE_INFO("Post receive body successfully: sock " << mSock->Id() << ", head imm data " << + mRespCtx.mHeader.immData << ", flags " << mRespCtx.mHeader.flags << ", seqNo " << mRespCtx.mHeader.seqNo << + ", data len " << mRespCtx.mHeader.dataLength); + + mRespMessage.mDataLen = mRespCtx.mHeader.dataLength; + mRespCtx.mMessage = &mRespMessage; + ctx.mHeader = mRespCtx.mHeader; + ctx.mMessage = mRespCtx.mMessage; + return NN_OK; +} + +NResult NetSyncEndpointSock::ReceiveRaw(int32_t timeout, UBSHcomNetResponseContext &ctx) +{ + return Receive(timeout, ctx); +} +} +} \ No newline at end of file diff --git a/src/transport/sock/net_sock_sync_endpoint.h b/src/transport/sock/net_sock_sync_endpoint.h new file mode 100644 index 0000000000000000000000000000000000000000..c7dc129a11c9d44826ae5a5ae8ad42df2fdc9b88 --- /dev/null +++ b/src/transport/sock/net_sock_sync_endpoint.h @@ -0,0 +1,225 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_NET_SOCK_SYNC_ENDPOINT_H +#define OCK_HCOM_NET_SOCK_SYNC_ENDPOINT_H + +#include "transport/net_endpoint_impl.h" +#include "net_monotonic.h" +#include "net_security_alg.h" +#include "net_sock_common.h" +#include "sock_common.h" + +namespace ock { +namespace hcom { +using SockOpContextInfoPool = OpContextInfoPool; +using SockSglContextInfoPool = OpContextInfoPool; +class NetSyncEndpointSock : public NetEndpointImpl { +public: + NetSyncEndpointSock(uint64_t id, Sock *sock, NetDriverSockWithOOB *driver, + const UBSHcomNetWorkerIndex &workerIndex); + ~NetSyncEndpointSock() override; + + NResult SetEpOption(UBSHcomEpOptions &epOptions) override; + + uint32_t GetSendQueueCount() override + { + NN_LOG_WARN("[Sock SyncEp] Invalid function, sync strategy does not have queue."); + return 0; + } + + const std::string &PeerIpAndPort() override + { + if (NN_LIKELY(mSock != nullptr)) { + return mSock->PeerIpPort(); + } + + return CONST_EMPTY_STRING; + } + + const std::string &UdsName() override + { + NN_LOG_WARN("[Sock SyncEp] Empty function for now"); + return CONST_EMPTY_STRING; + } + + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNo) override; + + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) override; + + NResult PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNo) override; + + NResult PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) override; + + NResult PostRead(const UBSHcomNetTransSglRequest &request) override; + + NResult PostRead(const UBSHcomNetTransRequest &request) override; + + NResult PostWrite(const UBSHcomNetTransSglRequest &request) override; + + NResult PostWrite(const UBSHcomNetTransRequest &request) override; + + NResult WaitCompletion(int32_t timeout) override; + + NResult ReceiveRaw(int32_t timeout, UBSHcomNetResponseContext &ctx) override; + + NResult Receive(int32_t timeout, UBSHcomNetResponseContext &ctx) override; + + NResult GetRemoteUdsIdInfo(UBSHcomNetUdsIdInfo &sockIdInfo) override + { + // 用户可能在建链回调中使用该函数,此时ep状态并未设置成NEP_ESTABLISHED + if (!mState.Compare(NEP_ESTABLISHED)) { + NN_LOG_WARN("[Sock SyncEp] EP status is " << mState.Get() << + " now, use ep after the connection established."); + } + + if (!mDriver->mStartOobSvr) { + NN_LOG_ERROR("[Sock SyncEp] oob server is not start"); + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + + if (mDriver->mOptions.oobType != NET_OOB_UDS) { + NN_LOG_ERROR("[Sock SyncEp] oob type is not uds"); + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + // 通过mRemoteUdsIdInfo值判断是否可以返回给用户 + if (mRemoteUdsIdInfo.gid == 0 && mRemoteUdsIdInfo.pid == 0 && mRemoteUdsIdInfo.uid == 0) { + NN_LOG_ERROR("[Sock SyncEp] RemoteUdsIdInfo has not been set."); + return NN_ERROR; + } + sockIdInfo = mRemoteUdsIdInfo; + return NN_OK; + } + + bool GetPeerIpPort(std::string &ip, uint16_t &port) override + { + if (NN_UNLIKELY(mSock == nullptr)) { + return false; + } + + auto ipPort = mSock->PeerIpPort(); + if (NN_UNLIKELY(ipPort.empty())) { + NN_LOG_ERROR("[Sock SyncEp] ip and port of peer is empty"); + return false; + } + + std::vector ipPortVec; + NetFunc::NN_SplitStr(ipPort, ":", ipPortVec); + if (NN_UNLIKELY(ipPortVec.size() != NN_NO2)) { + NN_LOG_ERROR("[Sock SyncEp] ip and port of peer is invalid"); + return false; + } + + try { + port = std::stoi(ipPortVec[1]); + } catch (...) { + NN_LOG_ERROR("[Sock SyncEp] port of peer is invalid"); + return false; + } + + // port will only be 0 when the connection is on uds + if (port == 0) { + NN_LOG_ERROR("[Sock SyncEp] oob type is uds, does not have peer ip and port msg"); + return false; + } + ip = ipPortVec[0]; + + return true; + } + + void Close() override + { + if (mState.Compare(NEP_ESTABLISHED)) { + mState.Set(NEP_BROKEN); + } else { + return; + } + NN_LOG_INFO("Close tcp sync ep id " << mId << " by user"); + mSock->Close(); + } + +private: + uint64_t inline GetFinishTime() + { + if (mDefaultTimeout > 0) { + return NetMonotonic::TimeNs() + static_cast(mDefaultTimeout) * 1000000000UL; + } else if (mDefaultTimeout < 0) { + return UINT64_MAX; + } + + return 0; + } + + static bool inline NeedRetry(NResult sockResult) + { + if (sockResult == SS_TCP_RETRY) { + return true; + } + + return false; + } + + __always_inline NResult FillReadWriteCtx(SockOpContextInfo *ctx, SockSglContextInfo *sglCtx, + const UBSHcomNetTransRequest &request, SockOpContextInfo::SockOpType opType, UBSHcomNetTransHeader header) + { + ctx->sock = mSock; + ctx->opType = opType; + ctx->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + ctx->upCtxSize = request.upCtxSize; + if (ctx->upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(ctx->upCtx, NN_NO16, request.upCtxData, ctx->upCtxSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy request to sglCtx"); + return NN_INVALID_PARAM; + } + } + + UBSHcomNetTransSgeIov iov(request.lAddress, request.rAddress, request.lKey, request.rKey, request.size); + sglCtx->Clone(header, &iov, NN_NO1); + ctx->sendCtx = sglCtx; + return NN_OK; + } + + __always_inline NResult FillReadWriteSglCtx(SockOpContextInfo *ctx, SockSglContextInfo *sglCtx, + const UBSHcomNetTransSglRequest &request, SockOpContextInfo::SockOpType opType, UBSHcomNetTransHeader header) + { + ctx->sock = mSock; + ctx->opType = opType; + ctx->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + ctx->upCtxSize = request.upCtxSize; + if (ctx->upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(ctx->upCtx, NN_NO16, request.upCtxData, ctx->upCtxSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + } + + sglCtx->Clone(header, request.iov, request.iovCount); + ctx->sendCtx = sglCtx; + return NN_OK; + } + + Sock *mSock = nullptr; + NetDriverSockWithOOB *mDriver = nullptr; + uint32_t mLastSendSeqNo = 0; + uint16_t mLastFlag = 0; + UBSHcomNetResponseContext mRespCtx; + UBSHcomNetMessage mRespMessage; + + SockOpContextInfoPool mOpCtxInfoPool; + SockSglContextInfoPool mSglCtxInfoPool; + + friend class NetDriverSockWithOOB; +}; +} +} + +#endif // OCK_HCOM_NET_SOCK_SYNC_ENDPOINT_H diff --git a/src/transport/sock/sock_buff.h b/src/transport/sock/sock_buff.h new file mode 100644 index 0000000000000000000000000000000000000000..2103392f1dff46d08d53063ef868e08f7bbe466b --- /dev/null +++ b/src/transport/sock/sock_buff.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef OCK_HCOM_SOCK_BUFF_H +#define OCK_HCOM_SOCK_BUFF_H + +#include +#include + +#include "net_common.h" + +namespace ock { +namespace hcom { + +/* ***************************************************************************************************** */ +/* + * @brief sock buffer for receive + */ +class SockBuff { +public: + SockBuff() + { + OBJ_GC_INCREASE(SockBuff); + } + + ~SockBuff() + { + if (NN_LIKELY(mBuf != nullptr)) { + free(mBuf); + mBuf = nullptr; + } + + OBJ_GC_DECREASE(SockBuff); + } + + inline bool ExpandIfNeed(uint32_t size) + { + if (NN_UNLIKELY(size == NN_NO0)) { + NN_LOG_ERROR("Invalid size 0"); + return false; + } + + if (NN_UNLIKELY(size > mSize)) { + /* + * 1 free the previous allocated memory + * 2 allocate new one + * 3 set mSize to new size + */ + if (mBuf != nullptr) { + free(mBuf); + } + + mBuf = malloc(size); + if (NN_LIKELY(mBuf != nullptr)) { + mSize = size; + return true; + } + + /* return false, if not allocated */ + mBuf = nullptr; + mSize = NN_NO0; + return false; + } + + return true; + } + + inline void *Data() const + { + return mBuf; + } + + inline uintptr_t DataIntPtr() const + { + return reinterpret_cast(mBuf); + } + + inline void ActualDataSize(uint32_t size) + { + mActualDataSize = size; + } + + inline uint32_t ActualDataSize() const + { + return mActualDataSize; + } + + inline uint32_t Size() const + { + return mSize; + } + +private: + void *mBuf = nullptr; + uint32_t mSize = 0; + uint32_t mActualDataSize = 0; +}; +} +} +#endif \ No newline at end of file diff --git a/src/transport/sock/sock_common.h b/src/transport/sock/sock_common.h new file mode 100644 index 0000000000000000000000000000000000000000..0d0ec45e6765e28c708580ee4038e315764c781c --- /dev/null +++ b/src/transport/sock/sock_common.h @@ -0,0 +1,229 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_SOCK_COMMON_H_2344 +#define OCK_HCOM_SOCK_COMMON_H_2344 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "hcom.h" +#include "hcom_def.h" +#include "hcom_log.h" +#include "hcom_utils.h" + +namespace ock { +namespace hcom { +/* pre-declare classes */ +class SockBuff; +class SockBuffList; +class Sock; +class SockWorker; + +using SockPtr = NetRef; + +using SockTransHeader = UBSHcomNetTransHeader; + +enum SockType : uint8_t { + SOCK_UDS = 0, /* uds as transfer protocol */ + SOCK_TCP = 1, /* tcp as transfer protocol */ + SOCK_UDS_TCP = 2, /* both tcp and uds, if local host use uds, otherwise use tcp */ +}; + +inline const std::string &SockTypeToString(SockType v) +{ + static const std::string STRINGS[NN_NO3] = {"UDS", "TCP", "UDS&TCP"}; + return STRINGS[v]; +} + +struct SockOptions { + uint16_t receiveBufSizeKB = NN_NO1; /* default receive buffer size, 1KB */ + uint16_t sendBufSizeKB = 0; /* default send buffer size, 0KB */ + uint16_t sendQueueSize = NN_NO256; /* send queue size */ + bool sendZCopy = false; /* whether copy send request */ +}; + +struct SockWorkerOptions { + uint32_t pollingTimeoutMs = NN_NO500; /* epoll or poll timeout */ + uint16_t pollingBatchSize = NN_NO16; /* epoll or poll batch size */ + int16_t cpuId = -1; /* cpu id to bind */ + bool isServer = false; /* is serve or not */ + uint16_t sockReceiveBufKB = 0; /* socket receive buffer in kernel */ + uint16_t sockSendBufKB = 0; /* socket send buffer in kernel */ + uint32_t keepaliveIdleTime = NN_NO64; /* idle 5 seconds to start to probe */ + uint32_t keepaliveProbeTimes = NN_NO7; /* probe times */ + uint32_t keepaliveProbeInterval = NN_NO2; /* probe interval, in second */ + uint32_t sendQueueSize = NN_NO256; + /* worker thread priority [-20,20], 20 is the lowest, -20 is the highest, 0 (default) means do not set priority */ + int threadPriority = 0; + /* timeout during io (s), it should be [-1, 1024], -1 means do not set */ + int tcpUserTimeout = -1; + bool tcpEnableNoDelay = true; + bool tcpSendZCopy = false; + + inline std::string ToString() const + { + std::ostringstream oss; + oss << "options polling-timeout-us: " << pollingTimeoutMs << "us, polling-batch-size: " << pollingBatchSize << + ", is-server: " << isServer << ", recv-buf-size: " << sockReceiveBufKB << "KB, send-buf-size: " << + sockSendBufKB << "KB, keepalive-idle-time: " << keepaliveIdleTime << "s, keepalive-probe-times: " << + keepaliveProbeTimes << ", keepalive-probe-interval: " << keepaliveProbeInterval << "s"; + return oss.str(); + } + + inline std::string ToShortString() const + { + std::ostringstream oss; + oss << "options polling-timeout-us: " << pollingTimeoutMs << ", polling-batch-size: " << pollingBatchSize; + return oss.str(); + } + + void SetValue(const UBSHcomNetDriverOptions& opt, bool isStartOobServer) + { + pollingTimeoutMs = opt.eventPollingTimeout; + pollingBatchSize = opt.pollingBatchSize; + isServer = isStartOobServer; + sockSendBufKB = opt.tcpSendBufSize; + sockReceiveBufKB = opt.tcpReceiveBufSize; + keepaliveIdleTime = opt.heartBeatIdleTime; + keepaliveProbeTimes = opt.heartBeatProbeTimes; + keepaliveProbeInterval = opt.heartBeatProbeInterval; + sendQueueSize = opt.qpSendQueueSize; + threadPriority = opt.workerThreadPriority; + tcpUserTimeout = opt.tcpUserTimeout; + tcpEnableNoDelay = opt.tcpEnableNoDelay; + tcpSendZCopy = opt.tcpSendZCopy; + } +}; + +struct SockSglContextInfo { + SockTransHeader sendHeader {}; // record header for raw/raw sgl/read/write/ + uint16_t iovCount = 0; // max count:NET_SGE_MAX_IOV + UBSHcomNetTransSgeIov iov[NET_SGE_MAX_IOV] = {}; + + inline void Clone(SockTransHeader newHeader, UBSHcomNetTransSgeIov *newIov, uint16_t newIovCnt) + { + sendHeader = newHeader; + iovCount = newIovCnt; + for (uint16_t i = 0; i < iovCount; i++) { + iov[i] = newIov[i]; + } + } +} __attribute__((packed)); + +struct SockHeaderReqInfo { + SockTransHeader sendHeader{}; + void *request = nullptr; +} __attribute__((packed)); + +struct SockOpContextInfo { + enum SockOpType : uint8_t { + SS_SEND = 0, + SS_SEND_RAW = 1, + SS_SEND_RAW_SGL = 2, + SS_RECEIVE = 3, + SS_WRITE = 4, + SS_READ = 5, + SS_SGL_WRITE = 6, + SS_SGL_READ = 7, + SS_WRITE_ACK = 8, + SS_READ_ACK = 9, + SS_SGL_WRITE_ACK = 10, + SS_SGL_READ_ACK = 11, + }; + + enum SockErrorType : uint8_t { + SS_NO_ERROR = 0, + SS_OPERATE_FAILURE = 1, + SS_RESET_BY_PEER = 2, + SS_OUT_OF_MEM = 3, + SS_TIMEOUT = 4, + }; + + SockTransHeader *header = nullptr; /* receive header operation */ + Sock *sock = nullptr; /* sock */ + uintptr_t dataAddress = 0; /* receive data address */ + uint32_t dataSize = 0; /* receive data size */ + SockOpType opType = SS_RECEIVE; /* receive by default */ + SockErrorType errType = SS_NO_ERROR; /* by default no error */ + uint16_t upCtxSize = 0; /* up context size */ + char upCtx[NN_NO16] = {}; /* 16 bytes for upper context */ + union { + void *sendBuff = nullptr; /* send or send raw: header + req */ + SockSglContextInfo *sendCtx; /* send sgl, read or write */ + SockHeaderReqInfo *headerRequest; /* send, without copy */ + }; + bool isSent = false; /* record the sendMsg is sent or not */ + char rsv[NN_NO7] = {}; /* reserve */ + + static inline NResult GetNResult(SockErrorType errorType) + { + switch (errorType) { + case SockErrorType::SS_NO_ERROR: + return NN_OK; + case SockErrorType::SS_RESET_BY_PEER: + return NN_EP_CLOSE; + case SockErrorType::SS_OUT_OF_MEM: + return NN_MALLOC_FAILED; + case SockErrorType::SS_TIMEOUT: + return NN_MSG_TIMEOUT; + default: + return NN_MSG_ERROR; + } + } +}; + +using SResult = int32_t; + +enum SCode { + SS_OK = 0, + SS_ERROR = 400, /* general error */ + SS_PARAM_INVALID = 401, + SS_MEMORY_ALLOCATE_FAILED = 402, + SS_NEW_OBJECT_FAILED = 403, + SS_SOCK_LISTEN_FAILED = 404, + SS_SOCK_CREATE_FAILED = 405, + SS_SOCK_DATA_SIZE_UN_MATCHED = 406, + SS_SOCK_EPOLL_OP_FAILED = 407, + SS_SOCK_SEND_FAILED = 408, + SS_SOCK_CONNECT_FAILED = 409, + SS_TCP_SET_OPTION_FAILED = 410, + SS_TCP_GET_OPTION_FAILED = 411, + SS_WORKER_EPOLL_FAILED = 412, + SS_TCP_RETRY = 413, + SS_SOCK_SEND_EAGAIN = 414, + SS_SOCK_ADD_QUEUE_FAILED = 415, + SS_CTX_FULL = 416, + SS_OOB_SSL_WRITE_ERROR = 417, + SS_OOB_SSL_READ_ERROR = 418, + SS_RESET_BY_PEER = 419, + SS_SSL_READ_FAILED = 420, + SS_TIMEOUT = 421, +}; + +constexpr uint32_t SOCK_CTX_MAP_RESERVATION = 8192; +} +} + +#endif // OCK_HCOM_SOCK_COMMON_H_2344 diff --git a/src/transport/sock/sock_validation.h b/src/transport/sock/sock_validation.h new file mode 100644 index 0000000000000000000000000000000000000000..a6152f8a8d3f500c7db871a5ee55691c9b3b858e --- /dev/null +++ b/src/transport/sock/sock_validation.h @@ -0,0 +1,145 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_SOCK_VALIDATION_H +#define OCK_HCOM_SOCK_VALIDATION_H +#ifdef SOCK_BUILD_ENABLED + +#include "net_monotonic.h" +#include "net_sock_common.h" +#include "sock_common.h" + +namespace ock { +namespace hcom { +#define OPCODE_VALIDATION() \ + do { \ + if (NN_UNLIKELY(opCode >= MAX_OPCODE)) { \ + NN_LOG_ERROR("Failed to post message as opcode is invalid, which should with the range 0~" << \ + (MAX_OPCODE - 1)); \ + return NN_INVALID_OPCODE; \ + } \ + } while (0) + +#define REQ_SIZE_VALIDATION() \ + do { \ + if (NN_UNLIKELY(request.size > mAllowedSize)) { \ + NN_LOG_ERROR("Failed to post message in sock as size " << request.size << " is too large"); \ + return NN_INVALID_PARAM; \ + } \ + } while (0) + +#define REQ_SIZE_VALIDATION_ZERO_COPY() \ + do { \ + if (NN_UNLIKELY(request.size > NET_SGE_MAX_SIZE)) { \ + NN_LOG_ERROR("Failed to post message in sock as size " << request.size << " is too large"); \ + return NN_INVALID_PARAM; \ + } \ + } while (0) + +static __always_inline NResult StateValidation(UBSHcomNetAtomicState &state, uint64_t id, + NetDriverSockWithOOB *driver, Sock *sock) +{ + if (NN_UNLIKELY(!state.Compare(NEP_ESTABLISHED))) { + NN_LOG_ERROR("Endpoint " << id << " is not established, state is " << UBSHcomNEPStateToString(state.Get())); + return NN_EP_NOT_ESTABLISHED; + } + + if (NN_UNLIKELY(sock == nullptr || driver == nullptr)) { + NN_LOG_ERROR("Invalid endpoint"); + return NN_ERROR; + } + return NN_OK; +} + +static __always_inline NResult BuffValidation(const UBSHcomNetTransRequest &request) +{ + if (NN_UNLIKELY(request.upCtxSize > sizeof(SockOpContextInfo::upCtx))) { + NN_LOG_ERROR("Failed to post message as upCtxSize > " << sizeof(SockOpContextInfo::upCtx)); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(request.lAddress == 0 || request.size == 0)) { + NN_LOG_ERROR("Failed to post message as source data is null or size is zero"); + return NN_INVALID_PARAM; + } + return NN_OK; +} + +static __always_inline NResult TwoSideSglValidation(const UBSHcomNetTransSglRequest &request, + NetDriverSockWithOOB *driver, uint32_t segSize, size_t &totalSize) +{ + if (NN_UNLIKELY(request.upCtxSize > sizeof(SockOpContextInfo::upCtx))) { + NN_LOG_ERROR("Sock failed to post message as upCtxSize > " << sizeof(SockOpContextInfo::upCtx)); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(request.iov == nullptr || request.iovCount == 0 || request.iovCount > NET_SGE_MAX_IOV)) { + NN_LOG_ERROR("Sock failed to post message as iov is null or cnt is invalid " << request.iovCount); + return NN_INVALID_PARAM; + } + + for (uint16_t i = 0; i < request.iovCount; i++) { + if (NN_OK != driver->ValidateMemoryRegion(request.iov[i].lKey, request.iov[i].lAddress, request.iov[i].size)) { + NN_LOG_ERROR("Sock invalid MemoryRegion or lKey in iov of sgl request"); + return NN_INVALID_LKEY; + } + totalSize += request.iov[i].size; + } + + if (NN_UNLIKELY(totalSize < NN_NO1 || totalSize > segSize)) { + NN_LOG_ERROR("Sock Failed to post raw sgl message as size " << totalSize << + " is too large, use one side instead"); + return NN_INVALID_PARAM; + } + return NN_OK; +} + +static __always_inline NResult OneSideValidation(const UBSHcomNetTransRequest &request, NetDriverSockWithOOB *driver) +{ + if (NN_UNLIKELY(request.upCtxSize > sizeof(SockOpContextInfo::upCtx))) { + NN_LOG_ERROR("Failed to post message as upCtxSize > " << sizeof(SockOpContextInfo::upCtx)); + return NN_INVALID_PARAM; + } + + if (NN_OK != driver->ValidateMemoryRegion(request.lKey, request.lAddress, request.size)) { + NN_LOG_ERROR("Invalid MemoryRegion or lKey in request"); + return NN_INVALID_LKEY; + } + return NN_OK; +} + +static __always_inline NResult OneSideSglValidation(const UBSHcomNetTransSglRequest &request, + NetDriverSockWithOOB *driver, size_t &totalSize) +{ + if (NN_UNLIKELY(request.upCtxSize > sizeof(SockOpContextInfo::upCtx))) { + NN_LOG_ERROR("Failed to post message as upCtxSize > " << sizeof(SockOpContextInfo::upCtx)); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(request.iov == nullptr || request.iovCount == 0 || request.iovCount > NET_SGE_MAX_IOV)) { + NN_LOG_ERROR("Failed to post message as iov is null or cnt is invalid " << request.iovCount); + return NN_INVALID_PARAM; + } + + for (uint16_t i = 0; i < request.iovCount; i++) { + if (NN_UNLIKELY(NN_OK != driver->ValidateMemoryRegion(request.iov[i].lKey, request.iov[i].lAddress, + request.iov[i].size))) { + NN_LOG_ERROR("Invalid MemoryRegion or lKey in iov of sgl request"); + return NN_INVALID_LKEY; + } + totalSize += request.iov[i].size; + } + return NN_OK; +} +} +} +#endif +#endif // OCK_HCOM_SOCK_VALIDATION_H diff --git a/src/transport/sock/sock_worker.h b/src/transport/sock/sock_worker.h new file mode 100644 index 0000000000000000000000000000000000000000..90a76b106407420f3b0478934c10259755e4ba69 --- /dev/null +++ b/src/transport/sock/sock_worker.h @@ -0,0 +1,330 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_SOCK_WORKER_H_234214 +#define OCK_HCOM_SOCK_WORKER_H_234214 + +#include + +#include "hcom.h" +#include "net_ctx_info_pool.h" +#include "net_mem_pool_fixed.h" +#include "sock_common.h" +#include "sock_wrapper.h" + +namespace ock { +namespace hcom { +using SockNewReqHandler = std::function; +using SockPostedHandler = std::function; +using SockOneSideHandler = std::function; +using SockEpCloseHandler = std::function; + +using SockOpContextInfoPool = OpContextInfoPool; +using SockSglContextInfoPool = OpContextInfoPool; + +// when there is no request from cq, call this +using SockIdleHandler = UBSHcomNetDriverIdleHandler; + +class SockWorker { +public: + SockWorker(SockType t, const std::string &name, const UBSHcomNetWorkerIndex &index, + const NetMemPoolFixedPtr &opCtxMemPool, const NetMemPoolFixedPtr &sglCtxMemPool, + const NetMemPoolFixedPtr &headerReqMemPool, const SockWorkerOptions &options) + : mType(t), + mName(name), + mIndex(index), + mOpCtxMemPool(opCtxMemPool), + mSglCtxMemPool(sglCtxMemPool), + mHeaderReqMemPool(headerReqMemPool), + mOptions(options) + { + OBJ_GC_INCREASE(SockWorker); + } + + ~SockWorker() + { + UnInitialize(); + + OBJ_GC_DECREASE(SockWorker); + } + + SResult Initialize(); + void UnInitialize(); + + SResult Start(); + void Stop(); + + inline void ReturnResources(Sock *sock, SockOpContextInfo *ctx) + { + sock->ReturnQueueSpace(NN_NO1); + mSglCtxInfoPool.Return(ctx->sendCtx); + ctx->sendCtx = nullptr; + mOpCtxInfoPool.Return(ctx); + ctx = nullptr; + } + + void ReturnResources(Sock *sock, SockOpContextInfo *ctx, SockSglContextInfo *sglCtx); + + inline const UBSHcomNetWorkerIndex &Index() const + { + return mIndex; + } + + inline void SetIndex(const UBSHcomNetWorkerIndex &value) + { + mIndex = value; + } + + std::string DetailName() const + { + std::ostringstream oss; + oss << "[name: " << mName << ", index: " << mIndex.ToString() << "]"; + return oss.str(); + } + + inline void RegisterNewReqHandler(const SockNewReqHandler &h) + { + mNewRequestHandler = h; + } + + inline void RegisterReqPostedHandler(const SockPostedHandler &h) + { + mSendPostedHandler = h; + } + + inline void RegisterOneSideHandler(const SockOneSideHandler &h) + { + mOneSideDoneHandler = h; + } + + inline void RegisterEpCloseHandler(const SockEpCloseHandler &h) + { + mEpCloseHandler = h; + } + + inline void RegisterIdleHandler(const SockIdleHandler &h) + { + mIdleHandler = h; + } + + inline const SockWorkerOptions &Options() const + { + return mOptions; + } + + inline void ReturnOpContextInfo(SockOpContextInfo *&ctx) + { + if (NN_LIKELY(ctx != nullptr)) { + if (NN_LIKELY(ctx->sock != nullptr)) { + ctx->sock->DecreaseRef(); + } + mOpCtxInfoPool.Return(ctx); + ctx = nullptr; + } + } + + inline void ReturnSglContextInfo(SockSglContextInfo *&ctx) + { + if (NN_LIKELY(ctx != nullptr)) { + mSglCtxInfoPool.Return(ctx); + ctx = nullptr; + } + } + + inline SockOpContextInfoPool GetSockOpContextInfoPool() const + { + return mOpCtxInfoPool; + } + + inline SockSglContextInfoPool GetSockSglContextInfoPool() const + { + return mSglCtxInfoPool; + } + + inline SockHeaderReqInfoPool GetSockHeaderReqInfoPool() const + { + return mHeaderReqInfoPool; + } + + inline SockPostedHandler GetSockPostedHandler() const + { + return mSendPostedHandler; + } + + inline SockOneSideHandler GetSockOneSideHandler() const + { + return mOneSideDoneHandler; + } + + SResult PostSend(Sock *sock, SockTransHeader &header, const UBSHcomNetTransRequest &req); + SResult PostSendRawSgl(Sock *sock, SockTransHeader &header, const UBSHcomNetTransSglRequest &req); + SResult PostRead(Sock *sock, SockTransHeader &header, const UBSHcomNetTransRequest &request); + SResult PostRead(Sock *sock, SockTransHeader &header, const UBSHcomNetTransSglRequest &request); + SResult PostWrite(Sock *sock, SockTransHeader &header, const UBSHcomNetTransRequest &request); + SResult PostWrite(Sock *sock, SockTransHeader &header, const UBSHcomNetTransSglRequest &request); + +#define SET_EPOLL_EVENT(selfSock, events, evNewFd) \ + do { \ + (evNewFd).data.ptr = selfSock; \ + (evNewFd).events = events; \ + } while (0) + + inline SResult AddToEpoll(Sock *sock, uint32_t events) + { + NN_ASSERT_LOG_RETURN(sock != nullptr, SS_PARAM_INVALID) + + if (sock->FD() == INVALID_FD) { + return SS_PARAM_INVALID; + } + + struct epoll_event evNewFd {}; + SET_EPOLL_EVENT(sock, events, evNewFd); + NN_LOG_TRACE_INFO("Adding sock " << sock->Id() << " address " << sock << " fd " << sock->FD() << + " into sock worker " << mName); + + if (NN_UNLIKELY(epoll_ctl(mEpollHandle, EPOLL_CTL_ADD, sock->FD(), &evNewFd) != 0)) { + NN_LOG_ERROR("Failed to add fd " << sock->FD() << " into epoll for sock worker " << mName << + ", errno " << errno); + return SS_SOCK_EPOLL_OP_FAILED; + } + + sock->IncreaseRef(); + return SS_OK; + } + + inline SResult ModifyInEpoll(Sock *sock, uint32_t events) + { + NN_ASSERT_LOG_RETURN(sock != nullptr, SS_PARAM_INVALID) + + if (sock->FD() == INVALID_FD) { + return SS_PARAM_INVALID; + } + + NN_LOG_TRACE_INFO("Modifying sock " << sock->Id() << " fd " << sock->FD() << " in sock worker " << mName); + + struct epoll_event evNewFd {}; + SET_EPOLL_EVENT(sock, events, evNewFd); + + if (NN_UNLIKELY(epoll_ctl(mEpollHandle, EPOLL_CTL_MOD, sock->FD(), &evNewFd) != 0)) { + if (errno == ENOENT) { + NN_LOG_ERROR("fd in epoll for worker " << mName << " is not found or has been removed from epoll"); + return SS_SOCK_EPOLL_OP_FAILED; + } + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to modify fd in epoll for sock worker " << mName << ", errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SS_SOCK_EPOLL_OP_FAILED; + } + + return SS_OK; + } + + inline SResult RemoveFromEpoll(Sock *sock) + { + NN_ASSERT_LOG_RETURN(sock != nullptr, SS_PARAM_INVALID) + + if (sock->FD() == INVALID_FD) { + return SS_PARAM_INVALID; + } + + NN_LOG_TRACE_INFO("Deleting sock " << sock->Id() << " fd " << sock->FD() << " from sock worker " << mName); + + if (NN_UNLIKELY(epoll_ctl(mEpollHandle, EPOLL_CTL_DEL, sock->FD(), nullptr) != 0)) { + if (errno == ENOENT) { + NN_LOG_ERROR("Sock " << sock->Id() << " fd " << sock->FD() << + " has been removed from epoll in worker " << mName); + return SS_OK; + } + NN_LOG_ERROR("Failed to remove from epoll for sock worker " << mName << ", errno " << errno); + return SS_SOCK_EPOLL_OP_FAILED; + } + + sock->DecreaseRef(); + return SS_OK; + } + + void EpCloseByUser(Sock *sock); + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + SResult Validate(); + void RunInThread(int16_t cpuId); + void UnInitializeInner(); + SResult HandleReceiveCtx(SockOpContextInfo &opCtx); + SResult PostReadAck(SockOpContextInfo &opCtx); + SResult PostReadAckHandle(SockOpContextInfo &opCtx); + SResult PostWriteAck(SockOpContextInfo &opCtx); + SResult PostWriteAckHandle(SockOpContextInfo &opCtx); + SResult PostWriteSglAck(SockOpContextInfo &opCtx); + SResult PostWriteSglAckHandle(SockOpContextInfo &opCtx); + SResult PostReadSglAck(SockOpContextInfo &opCtx); + SResult PostReadSglAckHandle(SockOpContextInfo &opCtx); + SResult GenerateReadSglAckOpCtxInfo(SockOpContextInfo *&opCtxInfo, SockOpContextInfo &opCtx, + UBSHcomNetTransSgeIov *&rawIov, uint16_t iovCount, uint32_t dataSize); + SResult GenerateWriteSglAckOpCtxInfo(SockOpContextInfo *&opCtxInfo, SockOpContextInfo &opCtx); + inline SResult CheckIovLen(SockOpContextInfo &opCtx, uint16_t &iovCount); + SResult GenerateWriteAckOpCtxInfo(SockOpContextInfo *&opCtxInfo, SockOpContextInfo &opCtx); + SResult InitContextInfoPool(); + __always_inline void BindCpuSetPthreadName(int16_t cpuId) + { + // set cpu id + cpu_set_t cpuSet; + if (cpuId != -1) { + CPU_ZERO(&cpuSet); + CPU_SET(cpuId, &cpuSet); + if (pthread_setaffinity_np(pthread_self(), sizeof(cpuSet), &cpuSet) != 0) { + NN_LOG_WARN("Unable to bind sock worker " << mName << mIndex.ToString() << " to cpu " << cpuId); + } + } + // set thread name + std::string workerName = mType == SOCK_TCP ? "SockWkr" : "UDSWkr"; + workerName += mIndex.ToString(); + pthread_setname_np(pthread_self(), workerName.c_str()); + NN_LOG_INFO("SockWorker [name: " << mName << ", index: " << mIndex.ToString() << ", cpuId: " << cpuId << + ", more " << mOptions.ToShortString() << "] working thread started"); + } + +private: + SockType mType = SOCK_TCP; + std::string mName; + std::mutex mMutex; + UBSHcomNetWorkerIndex mIndex {}; + bool mInited = false; + NetMemPoolFixedPtr mOpCtxMemPool = nullptr; + NetMemPoolFixedPtr mSglCtxMemPool = nullptr; + NetMemPoolFixedPtr mHeaderReqMemPool = nullptr; + + SockWorkerOptions mOptions {}; + + /* variable for thread */ + std::thread mProgressThr; /* thread object of progress */ + bool mStarted = false; /* thread already started or not */ + std::atomic_bool mProgressThrStarted { false }; /* started flag */ + volatile bool mNeedToStop = false; /* flag to be stopped */ + + SockNewReqHandler mNewRequestHandler = nullptr; /* request process related */ + SockPostedHandler mSendPostedHandler = nullptr; /* send request posted process related */ + SockOneSideHandler mOneSideDoneHandler = nullptr; /* one side done will call this */ + SockEpCloseHandler mEpCloseHandler = nullptr; /* ep closing will call this */ + SockIdleHandler mIdleHandler = nullptr; /* no request will call this */ + + int mEpollHandle = -1; /* event polling handle */ + SockOpContextInfoPool mOpCtxInfoPool; + SockSglContextInfoPool mSglCtxInfoPool; + SockHeaderReqInfoPool mHeaderReqInfoPool; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; +} +} + +#endif // OCK_HCOM_SOCK_WORKER_H_234214 diff --git a/src/transport/sock/sock_worker_core.cpp b/src/transport/sock/sock_worker_core.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6b65072fecc2e0570a6b8c328234ae7c1055e31d --- /dev/null +++ b/src/transport/sock/sock_worker_core.cpp @@ -0,0 +1,822 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "sock_worker.h" + +namespace ock { +namespace hcom { +constexpr uint32_t MAX_EPOLL_SIZE = NN_NO8192; + +SResult SockWorker::Validate() +{ + /* do later */ + return SS_OK; +} + +SResult SockWorker::Initialize() +{ + std::lock_guard guard(mMutex); + if (mInited) { + return SS_OK; + } + + SResult result = SS_OK; + if (NN_UNLIKELY((result = Validate()) != SS_OK)) { + NN_LOG_ERROR("Failed to validate in sock worker initialize"); + return result; + } + + NN_LOG_INFO("Try to initialize sock worker '" << mName << "' with " << mOptions.ToString()); + + /* create epoll */ + if (NN_UNLIKELY((mEpollHandle = epoll_create(MAX_EPOLL_SIZE)) < 0)) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create epoll in sock worker " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return SS_WORKER_EPOLL_FAILED; + } + + if (NN_UNLIKELY((result = InitContextInfoPool()) != SS_OK)) { + return result; + } + + mInited = true; + return SS_OK; +} + +SResult SockWorker::InitContextInfoPool() +{ + SResult result = SS_OK; + if (mType == SOCK_UDS) { + if ((result = mOpCtxInfoPool.Initialize(mOpCtxMemPool, UBSHcomNetDriverProtocol::UDS)) != SS_OK) { + NN_LOG_ERROR("Failed to initialize operation context info pool in SockWorker " << DetailName()); + return result; + } + + if ((result = mSglCtxInfoPool.Initialize(mSglCtxMemPool, UBSHcomNetDriverProtocol::UDS)) != SS_OK) { + NN_LOG_ERROR("Failed to initialize sgl context info pool in SockWorker " << DetailName()); + return result; + } + + if (mOptions.tcpSendZCopy) { + if ((result = mHeaderReqInfoPool.Initialize(mHeaderReqMemPool, UBSHcomNetDriverProtocol::UDS)) != SS_OK) { + NN_LOG_ERROR("Failed to initialize header request info pool in SockWorker " << DetailName()); + return result; + } + } + } else { + if ((result = mOpCtxInfoPool.Initialize(mOpCtxMemPool)) != SS_OK) { + NN_LOG_ERROR("Failed to initialize operation context info pool in SockWorker " << DetailName()); + return result; + } + + if ((result = mSglCtxInfoPool.Initialize(mSglCtxMemPool)) != SS_OK) { + NN_LOG_ERROR("Failed to initialize sgl context info pool in SockWorker " << DetailName()); + return result; + } + + if (mOptions.tcpSendZCopy) { + if ((result = mHeaderReqInfoPool.Initialize(mHeaderReqMemPool)) != SS_OK) { + NN_LOG_ERROR("Failed to initialize header request info pool in SockWorker " << DetailName()); + return result; + } + } + } + return result; +} + +void SockWorker::UnInitialize() +{ + std::lock_guard guard(mMutex); + if (!mInited) { + return; + } + + UnInitializeInner(); + mOpCtxInfoPool.UnInitialize(); + mSglCtxInfoPool.UnInitialize(); + mHeaderReqInfoPool.UnInitialize(); +} + +void SockWorker::UnInitializeInner() +{ + if (mEpollHandle != -1) { + NetFunc::NN_SafeCloseFd(mEpollHandle); + } + + // do later +} + +SResult SockWorker::HandleReceiveCtx(SockOpContextInfo &opCtx) +{ + switch ((opCtx.header->flags & 0xff)) { + case NTH_TWO_SIDE: + case NTH_TWO_SIDE_SGL: + NN_LOG_TRACE_INFO("Receive new request " << opCtx.sock->Id() << " head imm data " << + opCtx.header->immData << ", flags " << opCtx.header->flags << ", seqNo " << opCtx.header->seqNo << + ", data len " << opCtx.header->dataLength); + return mNewRequestHandler(opCtx); + case NTH_READ: + return PostReadAck(opCtx); + case NTH_READ_ACK: + return PostReadAckHandle(opCtx); + case NTH_READ_SGL: + return PostReadSglAck(opCtx); + case NTH_READ_SGL_ACK: + return PostReadSglAckHandle(opCtx); + case NTH_WRITE: + return PostWriteAck(opCtx); + case NTH_WRITE_ACK: + return PostWriteAckHandle(opCtx); + case NTH_WRITE_SGL: + return PostWriteSglAck(opCtx); + case NTH_WRITE_SGL_ACK: + return PostWriteSglAckHandle(opCtx); + default: + NN_LOG_ERROR("Receive head invalid flags " << opCtx.header->flags); + return SS_PARAM_INVALID; + } +} + +#define HANDLE_SOCK_EVENT(sockOpResult, doUpperCall) \ + Sock *sock = static_cast(oneEv.data.ptr); \ + if (NN_UNLIKELY(sock == nullptr)) { \ + NN_LOG_ERROR("Sock is null in polled event for sock worker " << mName); \ + continue; \ + } \ + if (NN_UNLIKELY(fcntl(sock->FD(), F_GETFD) == -1 && errno == EBADF)) { \ + NN_LOG_ERROR("Receive bad fd " << sock->FD() << " in sock worker " << mName); \ + continue; \ + } \ + \ + static thread_local SockOpContextInfo opCtx {}; \ + if (oneEv.events & EPOLLIN) { \ + if (NN_LIKELY(((sockOpResult) = sock->HandleIn((doUpperCall)))) == SockOpContextInfo::SS_NO_ERROR) { \ + /* if fully receive a request */ \ + \ + if (doUpperCall) { \ + /* set context */ \ + opCtx.header = sock->GetHeaderAddress(); \ + opCtx.sock = sock; \ + opCtx.dataAddress = sock->mReceiveBuff.DataIntPtr(); \ + opCtx.dataSize = sock->mReceiveBuff.ActualDataSize(); \ + opCtx.opType = SockOpContextInfo::SS_RECEIVE; \ + opCtx.errType = SockOpContextInfo::SS_NO_ERROR; \ + \ + /* handle by type */ \ + if (NN_UNLIKELY(HandleReceiveCtx(opCtx) == NN_EP_CLOSE)) { \ + /* fd is already removed from epoll, cannot be modified again */ \ + continue; \ + } \ + } \ + if (NN_UNLIKELY(ModifyInEpoll(sock, EPOLLIN | EPOLLOUT | EPOLLET) != SS_OK)) { \ + NN_LOG_WARN("Unable to modify sock " << sock->Id() << " in epoll in"); \ + } \ + /* not fully received, continue to process next event */ \ + continue; \ + } \ + NN_LOG_TRACE_INFO("Got error " << (sockOpResult) << " on sock " << sock->Id() << " with peer " << \ + sock->PeerIpPort() << " in sock worker " << mName); \ + \ + /* do sock conn broken process */ \ + bzero(&opCtx, sizeof(SockOpContextInfo)); \ + opCtx.sock = sock; \ + opCtx.opType = SockOpContextInfo::SS_RECEIVE; \ + opCtx.errType = sockOpResult; \ + \ + /* do upper call */ \ + mNewRequestHandler(opCtx); \ + /* continue to process next event */ \ + continue; \ + } else if (oneEv.events & EPOLLOUT) { \ + auto result = sock->ProcessQueueReq(); \ + if (result == SS_SOCK_SEND_EAGAIN) { \ + if (NN_UNLIKELY(ModifyInEpoll(sock, EPOLLIN | EPOLLOUT | EPOLLET) != SS_OK)) { \ + NN_LOG_WARN("Unable to modify sock " << sock->Id() << " in epoll out"); \ + } \ + } \ + if (result == SS_RESET_BY_PEER || result == SS_SOCK_SEND_FAILED) { \ + if (NN_UNLIKELY(ModifyInEpoll(sock, EPOLLWRNORM) != SS_OK)) { \ + NN_LOG_WARN("Unable to modify sock " << sock->Id() << " when EPOLLWRNORM in epoll out"); \ + } \ + } \ + continue; \ + } else if (oneEv.events & EPOLLWRNORM) { \ + mEpCloseHandler(sock); \ + continue; \ + } \ + \ + NN_LOG_TRACE_INFO("Receive sock " << sock->Id() << " event " << oneEv.events); \ + /* continue to process next event */ \ + continue + + +#define HANDLE_EVENTS(count, sockOpResult, doUpperCall, ev) \ + for (uint16_t i = 0; i < static_cast(count); ++i) { \ + struct epoll_event &oneEv = (ev)[i]; \ + HANDLE_SOCK_EVENT(sockOpResult, doUpperCall); \ + } + +void SockWorker::RunInThread(int16_t cpuId) +{ + BindCpuSetPthreadName(cpuId); + + if (mOptions.threadPriority != 0) { + if (NN_UNLIKELY(setpriority(PRIO_PROCESS, 0, mOptions.threadPriority) != 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_WARN("Unable to set worker thread priority in sock worker " << mName << ", errno:" << errno << + " error:" << NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + } + } + + mProgressThrStarted.store(true); + + const uint16_t pollBatchSize = mOptions.pollingBatchSize; + const uint32_t timeout = mOptions.pollingTimeoutMs; + + struct epoll_event ev[pollBatchSize]; + + /* for new accept sock */ + bool doUpperCall = false; + SockOpContextInfo::SockErrorType sockOpResult = SockOpContextInfo::SS_NO_ERROR; + + /* start epoll */ + while (!mNeedToStop) { + /* do epoll wait */ + int count = epoll_wait(mEpollHandle, ev, pollBatchSize, timeout); + if (count > 0) { + /* there are events, handle it */ + NN_LOG_TRACE_INFO("Got " << count << " in sock worker " << mName); + TRACE_DELAY_BEGIN(SOCK_WORKER_EPOLL_WAIT); + HANDLE_EVENTS(count, sockOpResult, doUpperCall, ev) + TRACE_DELAY_END(SOCK_WORKER_EPOLL_WAIT, 0); + } else if (count == 0) { + NN_LOG_TRACE_INFO("Got " << count << " in sock worker " << mName); + /* if io request, call idle */ + if (mIdleHandler != nullptr) { + mIdleHandler(mIndex); + } + } else if (errno == EINTR) { + NN_LOG_TRACE_INFO("Got error no EINTR in sock worker " << mName); + continue; + } else { + /* error happens */ + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to do epoll_wait in sock worker " << mName << ", errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + break; + } + } + + NN_LOG_INFO("Sock worker " << mName << ":" << mIndex.ToString() << " progress thread exiting"); +} + +SResult SockWorker::Start() +{ + std::lock_guard guard(mMutex); + if (!mInited) { + NN_LOG_ERROR("Failed to start sock worker " << mName << " as it is not initialized"); + return SS_ERROR; + } + + if (mStarted) { + NN_LOG_WARN("Unable to start sock worker " << mName << " as it is already started"); + return SS_OK; + } + + /* validate handler */ + if (mNewRequestHandler == nullptr) { + NN_LOG_ERROR("Failed to start sock worker " << mName << " as new request handler is null"); + return SS_PARAM_INVALID; + } + + if (mSendPostedHandler == nullptr) { + NN_LOG_ERROR("Failed to start sock worker " << mName << " as request posted handler is null"); + return SS_PARAM_INVALID; + } + + if (mOneSideDoneHandler == nullptr) { + NN_LOG_ERROR("Failed to start sock worker " << mName << " as one side done handler is null"); + return SS_PARAM_INVALID; + } + mNeedToStop = false; + std::thread tmpThread(&SockWorker::RunInThread, this, mOptions.cpuId); + mProgressThr = std::move(tmpThread); + + while (!mProgressThrStarted.load()) { + usleep(NN_NO10); + } + + mProgressThrStarted = false; + + mStarted = true; + return SS_OK; +} + +void SockWorker::Stop() +{ + std::lock_guard guard(mMutex); + if (!mStarted) { + return; + } + + mNeedToStop = true; + if (mProgressThr.joinable()) { + mProgressThr.join(); + } + + mStarted = false; +} + +SResult SockWorker::PostReadAck(SockOpContextInfo &opCtx) +{ + NN_ASSERT_LOG_RETURN(opCtx.sock->UpContext() != 0, SS_ERROR) + while (NN_UNLIKELY(!opCtx.sock->GetQueueSpace())) { + (void)opCtx.sock->ProcessQueueReq(); + } + if (NN_UNLIKELY(opCtx.dataSize < sizeof(UBSHcomNetTransSgeIov))) { + NN_LOG_ERROR("Failed to PostReadAck as data size " << opCtx.dataSize << " is less than iov size"); + opCtx.sock->ReturnQueueSpace(NN_NO1); + return SS_PARAM_INVALID; + } + auto rawIov = reinterpret_cast(opCtx.dataAddress); + if (NN_UNLIKELY(NN_OK != opCtx.sock->mMrChecker->Validate(rawIov->rKey, rawIov->rAddress, rawIov->size))) { + NN_LOG_ERROR("Invalid memory region or local key"); + opCtx.sock->ReturnQueueSpace(NN_NO1); + return NN_INVALID_LKEY; + } + + auto opCtxInfo = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(opCtxInfo == nullptr)) { + opCtx.sock->ReturnQueueSpace(NN_NO1); + NN_LOG_ERROR("Failed to PostReadAck with SockWorker " << DetailName() << " as no ctx left"); + return SS_CTX_FULL; + } + opCtxInfo->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to PostReadAck with SockWorker " << DetailName() << " as no sglCtx left"); + opCtx.sock->ReturnQueueSpace(NN_NO1); + mOpCtxInfoPool.Return(opCtxInfo); + return SS_CTX_FULL; + } + + opCtxInfo->sock = opCtx.sock; + opCtxInfo->opType = SockOpContextInfo::SockOpType::SS_READ_ACK; + + SockTransHeader header = {}; + header.flags = NTH_READ_ACK; + header.seqNo = opCtx.header->seqNo; + header.dataLength = rawIov->size; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + + UBSHcomNetTransSgeIov iov(rawIov->rAddress, 0, rawIov->size); + UBSHcomNetTransSglRequest req(&iov, NN_NO1, 0); + + sglCtx->Clone(header, req.iov, req.iovCount); + opCtxInfo->sendCtx = sglCtx; + + auto result = opCtx.sock->PostSendSgl(opCtxInfo); + // blocking post send need call upper handle + if (result == SS_SOCK_SEND_EAGAIN) { + return ModifyInEpoll(opCtx.sock, EPOLLIN | EPOLLOUT | EPOLLET); + } else if (result != SS_OK) { + auto res = ModifyInEpoll(opCtx.sock, EPOLLWRNORM); + result = res == SS_OK ? result : res; + } + ReturnResources(opCtx.sock, opCtxInfo); + return result; +} + +SResult SockWorker::PostReadAckHandle(SockOpContextInfo &opCtx) +{ + NN_ASSERT_LOG_RETURN(opCtx.sock->UpContext() != 0, SS_ERROR) + auto originalCtx = opCtx.sock->RemoveOpCtx(opCtx.header->seqNo); + if (originalCtx == nullptr) { + NN_LOG_ERROR("Failed to PostReadAckHandle with sock worker " << DetailName() << " as invalid seqNo " << + opCtx.header->seqNo); + return SS_PARAM_INVALID; + } + if (originalCtx->sock != opCtx.sock) { + NN_LOG_ERROR("Failed to check with sock worker " << DetailName() << " as sock different."); + return SS_PARAM_INVALID; + } + // only the first iov is used, the mr info is recorded in this iov + if (originalCtx->sendCtx->iov[0].size != opCtx.dataSize) { + NN_LOG_ERROR("Failed to check sock with sock worker " << DetailName() << " as size different."); + return SS_PARAM_INVALID; + } + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(originalCtx->sendCtx->iov[0].lAddress), + originalCtx->sendCtx->iov[0].size, reinterpret_cast(opCtx.dataAddress), + originalCtx->sendCtx->iov[0].size) != SS_OK)) { + NN_LOG_ERROR("Failed to copy opCtx to iov"); + return SS_PARAM_INVALID; + } + NN_LOG_TRACE_INFO("PostReadAckHandle " << opCtx.sock->Id() << " head imm data " << opCtx.header->immData << + ", flags " << opCtx.header->flags << ", seqNo " << opCtx.header->seqNo << ", data len " << + opCtx.header->dataLength); + return mOneSideDoneHandler(originalCtx); +} + +SResult SockWorker::GenerateReadSglAckOpCtxInfo(SockOpContextInfo *&opCtxInfo, SockOpContextInfo &opCtx, + UBSHcomNetTransSgeIov *&rawIov, uint16_t iovCount, uint32_t dataSize) +{ + opCtxInfo->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to PostReadSglAck with sock worker " << DetailName() << " as no sglCtx left"); + return SS_CTX_FULL; + } + + opCtxInfo->sock = opCtx.sock; + opCtxInfo->opType = SockOpContextInfo::SockOpType::SS_SGL_READ_ACK; + + SockTransHeader header = {}; + header.flags = NTH_READ_SGL_ACK; + header.seqNo = opCtx.header->seqNo; + header.dataLength = opCtx.header->dataLength + dataSize; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + + sglCtx->Clone(header, rawIov, iovCount); + opCtxInfo->sendCtx = sglCtx; + return SS_OK; +} + +inline SResult SockWorker::CheckIovLen(SockOpContextInfo &opCtx, uint16_t &iovCount) +{ + if (NN_UNLIKELY(opCtx.dataSize < sizeof(UBSHcomNetTransSglRequest::iovCount))) { + NN_LOG_ERROR("Failed to PostReadAck as data size " << opCtx.dataSize << " is less than iovCount size"); + return false; + } + auto count = reinterpret_cast(opCtx.dataAddress); + + if (*count == 0 || *count > NN_NO4) { + NN_LOG_ERROR("Failed to check sock with sock worker " << mName << " as iov count is illegal."); + return false; + } + if (NN_UNLIKELY(opCtx.dataSize < (sizeof(UBSHcomNetTransSglRequest::iovCount) + + sizeof(UBSHcomNetTransSgeIov) * (*count)))) { + NN_LOG_ERROR("Failed to PostReadAck as data size " << opCtx.dataSize << " is less than iov size"); + return false; + } + iovCount = *count; + return true; +} + +SResult SockWorker::PostReadSglAck(SockOpContextInfo &opCtx) +{ + NN_ASSERT_LOG_RETURN(opCtx.sock->UpContext() != 0, SS_ERROR) + while (NN_UNLIKELY(!opCtx.sock->GetQueueSpace())) { + (void)opCtx.sock->ProcessQueueReq(); + } + uint16_t iovCount = 0; + if (NN_UNLIKELY(!CheckIovLen(opCtx, iovCount))) { + opCtx.sock->ReturnQueueSpace(NN_NO1); + return SS_PARAM_INVALID; + } + auto rawIov = reinterpret_cast(opCtx.dataAddress + + sizeof(UBSHcomNetTransSglRequest::iovCount)); + uint32_t dataSize = 0; + for (uint16_t i = 0; i < iovCount; i++) { + if (NN_UNLIKELY(NN_OK != + opCtx.sock->mMrChecker->Validate(rawIov[i].rKey, rawIov[i].rAddress, rawIov[i].size))) { + NN_LOG_ERROR("Invalid memory region or local key"); + opCtx.sock->ReturnQueueSpace(NN_NO1); + return NN_INVALID_LKEY; + } + dataSize += rawIov[i].size; + } + + auto opCtxInfo = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(opCtxInfo == nullptr)) { + NN_LOG_ERROR("Failed to PostReadSglAck with sock worker " << DetailName() << " as no ctx left"); + opCtx.sock->ReturnQueueSpace(NN_NO1); + return SS_CTX_FULL; + } + + if (NN_UNLIKELY(GenerateReadSglAckOpCtxInfo(opCtxInfo, opCtx, rawIov, iovCount, dataSize) != SS_OK)) { + opCtx.sock->ReturnQueueSpace(NN_NO1); + mOpCtxInfoPool.Return(opCtxInfo); + return SS_CTX_FULL; + } + + auto result = opCtx.sock->PostReadSglAck(opCtxInfo); + // blocking post send need call upper handle + if (result == SS_SOCK_SEND_EAGAIN) { + return ModifyInEpoll(opCtx.sock, EPOLLIN | EPOLLOUT | EPOLLET); + } else if (result != SS_OK) { + auto res = ModifyInEpoll(opCtx.sock, EPOLLWRNORM); + result = res == SS_OK ? result : res; + } + ReturnResources(opCtx.sock, opCtxInfo); + + return result; +} + +SResult SockWorker::PostReadSglAckHandle(SockOpContextInfo &opCtx) +{ + NN_ASSERT_LOG_RETURN(opCtx.sock->UpContext() != 0, SS_ERROR) + auto originalCtx = opCtx.sock->RemoveOpCtx(opCtx.header->seqNo); + if (originalCtx == nullptr) { + NN_LOG_ERROR("Failed to handle ack with sock worker " << mName << " as invalid seqNo " << opCtx.header->seqNo); + return SS_PARAM_INVALID; + } + + if (originalCtx->sock != opCtx.sock) { + NN_LOG_ERROR("Failed to check read sgl sock ptr with sock worker " << mName << " as sock different."); + return SS_PARAM_INVALID; + } + + if (NN_UNLIKELY(opCtx.dataSize < sizeof(UBSHcomNetTransSglRequest::iovCount))) { + NN_LOG_ERROR("Failed to PostReadAck as data size " << opCtx.dataSize << " is less than iovCount size"); + return SS_PARAM_INVALID; + } + /* write data */ + auto iovCount = reinterpret_cast(opCtx.dataAddress); + if (*iovCount == 0 || *iovCount > NN_NO4 || *iovCount != originalCtx->sendCtx->iovCount) { + NN_LOG_ERROR("Failed to check sock with sock worker " << mName << " as iov count is illegal."); + return SS_PARAM_INVALID; + } + if (NN_UNLIKELY(opCtx.dataSize < (sizeof(UBSHcomNetTransSglRequest::iovCount) + + sizeof(UBSHcomNetTransSgeIov) * (*iovCount)))) { + NN_LOG_ERROR("Failed to PostReadAck as data size " << opCtx.dataSize << " is less than iov size"); + return SS_PARAM_INVALID; + } + auto sgeIov = reinterpret_cast(opCtx.dataAddress + + sizeof(UBSHcomNetTransSglRequest::iovCount)); + auto data = reinterpret_cast(opCtx.dataAddress + sizeof(UBSHcomNetTransSglRequest::iovCount) + + sizeof(UBSHcomNetTransSgeIov) * (*iovCount)); + + uint32_t dataSize = 0; + for (uint16_t i = 0; i < *iovCount; i++) { + dataSize += sgeIov[i].size; + } + + if (originalCtx->sendCtx->sendHeader.dataLength + dataSize != opCtx.header->dataLength) { + NN_LOG_ERROR("Failed to check sock with sock worker " << mName << " as size different."); + return SS_PARAM_INVALID; + } + + uint32_t copyOffset = 0; + for (uint16_t i = 0; i < *iovCount; i++) { + UBSHcomNetTransSgeIov iov = originalCtx->sendCtx->iov[i]; + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(iov.lAddress), iov.size, data + copyOffset, iov.size) != + NN_OK)) { + NN_LOG_ERROR("Failed to copy data to iov"); + return NN_INVALID_PARAM; + } + copyOffset += iov.size; + } + + NN_LOG_TRACE_INFO("PostReadSglAckHandle " << opCtx.sock->Id() << " head imm data " << opCtx.header->immData << + ", flags " << opCtx.header->flags << ", seqNo " << opCtx.header->seqNo << ", data len " << + opCtx.header->dataLength); + return mOneSideDoneHandler(originalCtx); +} + +SResult SockWorker::GenerateWriteAckOpCtxInfo(SockOpContextInfo *&opCtxInfo, SockOpContextInfo &opCtx) +{ + opCtxInfo->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to PostWriteAck with sock worker " << DetailName() << " as no sglCtx left"); + return SS_CTX_FULL; + } + + SockTransHeader header = {}; + header.flags = NTH_WRITE_ACK; + header.seqNo = opCtx.header->seqNo; + header.dataLength = opCtx.header->dataLength; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + + sglCtx->sendHeader = header; + opCtxInfo->sendCtx = sglCtx; + opCtxInfo->opType = SockOpContextInfo::SockOpType::SS_WRITE_ACK; + return SS_OK; +} + +SResult SockWorker::PostWriteAck(SockOpContextInfo &opCtx) +{ + NN_ASSERT_LOG_RETURN(opCtx.sock->UpContext() != 0, SS_ERROR) + /* send ack */ + while (NN_UNLIKELY(!opCtx.sock->GetQueueSpace())) { + (void)opCtx.sock->ProcessQueueReq(); + } + if (NN_UNLIKELY(opCtx.dataSize < sizeof(UBSHcomNetTransSgeIov))) { + NN_LOG_ERROR("Failed to PostWriteAck as data size " << opCtx.dataSize << " is less than iov size"); + opCtx.sock->ReturnQueueSpace(NN_NO1); + return SS_PARAM_INVALID; + } + auto rawIov = reinterpret_cast(opCtx.dataAddress); + if (NN_UNLIKELY(NN_OK != opCtx.sock->mMrChecker->Validate(rawIov->rKey, rawIov->rAddress, rawIov->size))) { + NN_LOG_ERROR("Invalid memory region or local key"); + opCtx.sock->ReturnQueueSpace(NN_NO1); + return NN_INVALID_LKEY; + } + /* write data */ + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(rawIov->rAddress), rawIov->size, + reinterpret_cast(opCtx.dataAddress + sizeof(UBSHcomNetTransSgeIov)), rawIov->size) != NN_OK)) { + opCtx.sock->ReturnQueueSpace(NN_NO1); + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return SS_PARAM_INVALID; + } + auto opCtxInfo = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(opCtxInfo == nullptr)) { + NN_LOG_ERROR("Failed to PostWriteAck with sock worker " << DetailName() << " as no ctx left"); + opCtx.sock->ReturnQueueSpace(NN_NO1); + return SS_PARAM_INVALID; + } + if (NN_UNLIKELY(GenerateWriteAckOpCtxInfo(opCtxInfo, opCtx))) { + opCtx.sock->ReturnQueueSpace(NN_NO1); + mOpCtxInfoPool.Return(opCtxInfo); + return SS_CTX_FULL; + } + auto result = opCtx.sock->PostSendHead(opCtxInfo); + if (result == SS_SOCK_SEND_EAGAIN) { + return ModifyInEpoll(opCtx.sock, EPOLLIN | EPOLLOUT | EPOLLET); + } else if (result != SS_OK) { + auto res = ModifyInEpoll(opCtx.sock, EPOLLWRNORM); + result = res == SS_OK ? result : res; + } + ReturnResources(opCtx.sock, opCtxInfo); + + return result; +} + +SResult SockWorker::PostWriteAckHandle(SockOpContextInfo &opCtx) +{ + NN_ASSERT_LOG_RETURN(opCtx.sock->UpContext() != 0, SS_ERROR) + auto originalCtx = opCtx.sock->RemoveOpCtx(opCtx.header->seqNo); + if (originalCtx == nullptr) { + NN_LOG_ERROR("Failed to handle ack with sock worker " << mName << " as invalid seqNo " << opCtx.header->seqNo); + return SS_PARAM_INVALID; + } + if (originalCtx->sock != opCtx.sock) { + NN_LOG_ERROR("Failed to check write sock ptr with sock worker " << mName << " as sock different."); + return SS_PARAM_INVALID; + } + if (originalCtx->sendCtx->sendHeader.dataLength != opCtx.header->dataLength) { + NN_LOG_ERROR("Failed to check sock with sock worker " << mName << " as size different."); + return SS_PARAM_INVALID; + } + + NN_LOG_TRACE_INFO("PostWriteAckHandle " << opCtx.sock->Id() << " head imm data " << opCtx.header->immData << + ", flags " << opCtx.header->flags << ", seqNo " << opCtx.header->seqNo << ", data len " << + opCtx.header->dataLength); + return mOneSideDoneHandler(originalCtx); +} + +SResult SockWorker::GenerateWriteSglAckOpCtxInfo(SockOpContextInfo *&opCtxInfo, SockOpContextInfo &opCtx) +{ + opCtxInfo->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to PostWriteSglAck with sock worker " << DetailName() << " as no sglCtx left"); + return SS_CTX_FULL; + } + + /* send ack */ + SockTransHeader header = {}; + header.flags = NTH_WRITE_SGL_ACK; + header.seqNo = opCtx.header->seqNo; + header.dataLength = opCtx.header->dataLength; + + /* finally fill header crc */ + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + + sglCtx->sendHeader = header; + opCtxInfo->sendCtx = sglCtx; + opCtxInfo->opType = SockOpContextInfo::SockOpType::SS_SGL_WRITE_ACK; + return SS_OK; +} + +SResult SockWorker::PostWriteSglAck(SockOpContextInfo &opCtx) +{ + NN_ASSERT_LOG_RETURN(opCtx.sock->UpContext() != 0, SS_ERROR) + while (NN_UNLIKELY(!opCtx.sock->GetQueueSpace())) { + (void)opCtx.sock->ProcessQueueReq(); + } + + uint16_t iovCount = 0; + if (NN_UNLIKELY(!CheckIovLen(opCtx, iovCount))) { + opCtx.sock->ReturnQueueSpace(NN_NO1); + return SS_PARAM_INVALID; + } + auto iov = reinterpret_cast(opCtx.dataAddress + + sizeof(UBSHcomNetTransSglRequest::iovCount)); + uint32_t dataSize = 0; + for (uint16_t i = 0; i < iovCount; i++) { + dataSize += iov[i].size; + } + if (NN_UNLIKELY(opCtx.dataSize < + (sizeof(UBSHcomNetTransSglRequest::iovCount) + sizeof(UBSHcomNetTransSgeIov) * iovCount + dataSize))) { + NN_LOG_ERROR("Failed to PostReadAck as data size " << opCtx.dataSize << " is less than iov data size"); + opCtx.sock->ReturnQueueSpace(NN_NO1); + return SS_PARAM_INVALID; + } + + auto data = reinterpret_cast(opCtx.dataAddress + sizeof(UBSHcomNetTransSglRequest::iovCount) + + sizeof(UBSHcomNetTransSgeIov) * iovCount); + + uint32_t copyOffset = 0; + for (uint16_t i = 0; i < iovCount; i++) { + if (NN_UNLIKELY(NN_OK != opCtx.sock->mMrChecker->Validate(iov[i].rKey, iov[i].rAddress, iov[i].size))) { + NN_LOG_ERROR("Invalid memory region or local key"); + opCtx.sock->ReturnQueueSpace(NN_NO1); + return NN_INVALID_LKEY; + } + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(iov[i].rAddress), iov[i].size, data + copyOffset, + iov[i].size) != NN_OK)) { + opCtx.sock->ReturnQueueSpace(NN_NO1); + NN_LOG_ERROR("Failed to copy data to iov"); + return NN_INVALID_PARAM; + } + copyOffset += iov[i].size; + } + + auto opCtxInfo = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(opCtxInfo == nullptr)) { + NN_LOG_ERROR("Failed to PostWriteSglAck with sock worker " << DetailName() << " as no ctx left"); + opCtx.sock->ReturnQueueSpace(NN_NO1); + return SS_PARAM_INVALID; + } + + if (GenerateWriteSglAckOpCtxInfo(opCtxInfo, opCtx) != SS_OK) { + NN_LOG_ERROR("Failed to PostWriteSglAck with sock worker " << DetailName() << " as no Sglctx left"); + opCtx.sock->ReturnQueueSpace(NN_NO1); + mOpCtxInfoPool.Return(opCtxInfo); + return SS_PARAM_INVALID; + } + + auto result = opCtx.sock->PostSendHead(opCtxInfo); + if (result == SS_SOCK_SEND_EAGAIN) { + return ModifyInEpoll(opCtx.sock, EPOLLIN | EPOLLOUT | EPOLLET); + } else if (result != SS_OK) { + auto res = ModifyInEpoll(opCtx.sock, EPOLLWRNORM); + result = res == SS_OK ? result : res; + } + ReturnResources(opCtx.sock, opCtxInfo); + + return result; +} + +SResult SockWorker::PostWriteSglAckHandle(SockOpContextInfo &opCtx) +{ + NN_ASSERT_LOG_RETURN(opCtx.sock->UpContext() != 0, SS_ERROR) + + auto originalCtx = opCtx.sock->RemoveOpCtx(opCtx.header->seqNo); + if (originalCtx == nullptr) { + NN_LOG_ERROR("Failed to handle ack with sock worker " << mName << " as invalid seqNo " << opCtx.header->seqNo); + return SS_PARAM_INVALID; + } + if (originalCtx->sock != opCtx.sock) { + NN_LOG_ERROR("Failed to check write sgl sock ptr with sock worker " << mName << " as sock different."); + return SS_PARAM_INVALID; + } + if (originalCtx->sendCtx->sendHeader.dataLength != opCtx.header->dataLength) { + NN_LOG_ERROR("Failed to check sock with sock worker " << mName << " as data length different."); + return SS_PARAM_INVALID; + } + + NN_LOG_TRACE_INFO("PostWriteSglAckHandle " << opCtx.sock->Id() << " head imm data " << opCtx.header->immData << + ", flags " << opCtx.header->flags << ", seqNo " << opCtx.header->seqNo << ", data len " << + opCtx.header->dataLength); + return mOneSideDoneHandler(originalCtx); +} + +void SockWorker::ReturnResources(Sock *sock, SockOpContextInfo *ctx, SockSglContextInfo *sglCtx) +{ + sock->ReturnQueueSpace(NN_NO1); + if (sglCtx != nullptr) { + mSglCtxInfoPool.Return(sglCtx); + sglCtx = nullptr; + } + if (ctx != nullptr) { + mOpCtxInfoPool.Return(ctx); + ctx = nullptr; + } +} + +void SockWorker::EpCloseByUser(Sock *sock) +{ + if (mEpCloseHandler == nullptr) { + NN_LOG_WARN("Worker ep close handler is null, worker name is" << mName); + } + mEpCloseHandler(sock); +} +} +} \ No newline at end of file diff --git a/src/transport/sock/sock_worker_io.cpp b/src/transport/sock/sock_worker_io.cpp new file mode 100644 index 0000000000000000000000000000000000000000..abf45690c86b1c82b0b5201aa94e6faefa810a52 --- /dev/null +++ b/src/transport/sock/sock_worker_io.cpp @@ -0,0 +1,367 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "sock_worker.h" + +namespace ock { +namespace hcom { + + +/* async endpoint operation */ +SResult SockWorker::PostSend(Sock *sock, SockTransHeader &header, const UBSHcomNetTransRequest &req) +{ + if (NN_UNLIKELY(!sock->GetQueueSpace())) { + return SS_SOCK_ADD_QUEUE_FAILED; + } + + auto opCtxInfo = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(opCtxInfo == nullptr)) { + NN_LOG_ERROR("Failed to PostSend with sock worker " << DetailName() << " as no ctx left"); + sock->ReturnQueueSpace(NN_NO1); + return SS_CTX_FULL; + } + opCtxInfo->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + opCtxInfo->sock = sock; + opCtxInfo->opType = + header.immData == 0 ? SockOpContextInfo::SockOpType::SS_SEND : SockOpContextInfo::SockOpType::SS_SEND_RAW; + opCtxInfo->upCtxSize = req.upCtxSize; + if (opCtxInfo->upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(opCtxInfo->upCtx, NN_NO16, req.upCtxData, opCtxInfo->upCtxSize) != NN_OK)) { + ReturnResources(sock, opCtxInfo, nullptr); + NN_LOG_ERROR("Failed to copy req to opCtxInfo"); + return SS_PARAM_INVALID; + } + } + if (mOptions.tcpSendZCopy) { + auto headerReqInfo = mHeaderReqInfoPool.Get(); + if (NN_UNLIKELY(headerReqInfo == nullptr)) { + NN_LOG_ERROR("Failed to PostSend with sock worker " << DetailName() << " as no ctx left"); + ReturnResources(sock, opCtxInfo, nullptr); + return SS_CTX_FULL; + } + headerReqInfo->sendHeader = header; + headerReqInfo->request = reinterpret_cast(req.lAddress); + opCtxInfo->headerRequest = headerReqInfo; + } else { + opCtxInfo->sendBuff = &header; + } + opCtxInfo->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + + auto result = sock->PostSend(opCtxInfo); + // blocking post send need call upper handle + if (result == SS_OK) { + sock->ReturnQueueSpace(NN_NO1); + mSendPostedHandler(opCtxInfo); + mOpCtxInfoPool.Return(opCtxInfo); + NN_LOG_TRACE_INFO("PostSend cb sock " << sock->Id() << " head imm data " << header.immData << ", flags " << + header.flags << ", seqNo " << header.seqNo << ", data len " << header.dataLength); + return result; + } else if (result == SS_SOCK_SEND_EAGAIN) { + return ModifyInEpoll(sock, EPOLLIN | EPOLLOUT | EPOLLET); + } else if (result != SS_TCP_RETRY) { + auto res = ModifyInEpoll(sock, EPOLLWRNORM); + result = res == SS_OK ? result : res; + } + sock->ReturnQueueSpace(NN_NO1); + if (mOptions.tcpSendZCopy) { + mHeaderReqInfoPool.Return(opCtxInfo->headerRequest); + opCtxInfo->headerRequest = nullptr; + } + mOpCtxInfoPool.Return(opCtxInfo); + opCtxInfo = nullptr; + + return result; +} + +SResult SockWorker::PostSendRawSgl(Sock *sock, SockTransHeader &header, const UBSHcomNetTransSglRequest &req) +{ + if (NN_UNLIKELY(!sock->GetQueueSpace())) { + return SS_SOCK_ADD_QUEUE_FAILED; + } + + auto opCtxInfo = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(opCtxInfo == nullptr)) { + NN_LOG_ERROR("Failed to PostSendRawSgl with sock worker " << DetailName() << " as no ctx left"); + sock->ReturnQueueSpace(NN_NO1); + return SS_CTX_FULL; + } + opCtxInfo->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to PostSendRawSgl with sock worker " << DetailName() << " as no sglCtx left"); + sock->ReturnQueueSpace(NN_NO1); + mOpCtxInfoPool.Return(opCtxInfo); + return SS_CTX_FULL; + } + + opCtxInfo->sock = sock; + opCtxInfo->opType = SockOpContextInfo::SockOpType::SS_SEND_RAW_SGL; + opCtxInfo->upCtxSize = req.upCtxSize; + if (opCtxInfo->upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(opCtxInfo->upCtx, NN_NO16, req.upCtxData, opCtxInfo->upCtxSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to opCtxInfo"); + ReturnResources(sock, opCtxInfo, sglCtx); + return SS_PARAM_INVALID; + } + } + + sglCtx->Clone(header, req.iov, req.iovCount); + opCtxInfo->sendCtx = sglCtx; + + auto result = sock->PostSendSgl(opCtxInfo); + // blocking post send need call upper handle + if (result == SS_OK) { + sock->ReturnQueueSpace(NN_NO1); + mSendPostedHandler(opCtxInfo); + mOpCtxInfoPool.Return(opCtxInfo); + NN_LOG_TRACE_INFO("PostSendRawSgl cb sock " << sock->Id() << " head imm data " << header.immData << + ", flags " << header.flags << ", seqNo " << header.seqNo << ", data len " << header.dataLength); + return result; + } else if (result == SS_SOCK_SEND_EAGAIN) { + return ModifyInEpoll(sock, EPOLLIN | EPOLLOUT | EPOLLET); + } else if (result != SS_TCP_RETRY) { + auto res = ModifyInEpoll(sock, EPOLLWRNORM); + result = res == SS_OK ? result : res; + } + ReturnResources(sock, opCtxInfo); + + return result; +} + +SResult SockWorker::PostRead(Sock *sock, SockTransHeader &header, const UBSHcomNetTransRequest &req) +{ + if (NN_UNLIKELY(!sock->GetQueueSpace())) { + return SS_SOCK_ADD_QUEUE_FAILED; + } + + auto opCtxInfo = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(opCtxInfo == nullptr)) { + NN_LOG_ERROR("Failed to PostRead with sock worker " << DetailName() << " as no ctx left"); + sock->ReturnQueueSpace(NN_NO1); + return SS_PARAM_INVALID; + } + opCtxInfo->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to PostRead with sock worker " << DetailName() << " as no sgl ctx left"); + sock->ReturnQueueSpace(NN_NO1); + mOpCtxInfoPool.Return(opCtxInfo); + return SS_CTX_FULL; + } + + opCtxInfo->sock = sock; + opCtxInfo->opType = SockOpContextInfo::SockOpType::SS_READ; + opCtxInfo->upCtxSize = req.upCtxSize; + opCtxInfo->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + if (opCtxInfo->upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(opCtxInfo->upCtx, NN_NO16, req.upCtxData, opCtxInfo->upCtxSize) != SS_OK)) { + ReturnResources(sock, opCtxInfo, sglCtx); + NN_LOG_ERROR("Failed to copy request to opCtxInfo"); + return SS_PARAM_INVALID; + } + } + + UBSHcomNetTransSgeIov iov(req.lAddress, req.rAddress, req.lKey, req.rKey, req.size); + sglCtx->Clone(header, &iov, NN_NO1); + opCtxInfo->sendCtx = sglCtx; + + sock->AddOpCtx(header.seqNo, opCtxInfo); + sock->IncreaseRef(); + + auto result = sock->PostRead(opCtxInfo); + if (result == SS_OK) { + sock->ReturnQueueSpace(NN_NO1); + return result; + } else if (result == SS_SOCK_SEND_EAGAIN) { + return ModifyInEpoll(sock, EPOLLIN | EPOLLOUT | EPOLLET); + } else if (result != SS_TCP_RETRY) { + auto res = ModifyInEpoll(sock, EPOLLWRNORM); + result = res == SS_OK ? result : res; + } + (void)sock->RemoveOpCtx(header.seqNo); + sock->DecreaseRef(); + ReturnResources(sock, opCtxInfo); + + return result; +} + +SResult SockWorker::PostRead(Sock *sock, SockTransHeader &header, const UBSHcomNetTransSglRequest &req) +{ + if (NN_UNLIKELY(!sock->GetQueueSpace())) { + return SS_SOCK_ADD_QUEUE_FAILED; + } + + auto opCtxInfo = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(opCtxInfo == nullptr)) { + NN_LOG_ERROR("Failed to PostRead with sock worker " << DetailName() << " as no ctx left"); + sock->ReturnQueueSpace(NN_NO1); + return SS_CTX_FULL; + } + opCtxInfo->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to PostRead with sock worker " << DetailName() << " as no sglCtx left"); + sock->ReturnQueueSpace(NN_NO1); + mOpCtxInfoPool.Return(opCtxInfo); + return SS_CTX_FULL; + } + + opCtxInfo->sock = sock; + opCtxInfo->opType = SockOpContextInfo::SockOpType::SS_SGL_READ; + opCtxInfo->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + opCtxInfo->upCtxSize = req.upCtxSize; + if (opCtxInfo->upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(opCtxInfo->upCtx, NN_NO16, req.upCtxData, opCtxInfo->upCtxSize) != SS_OK)) { + ReturnResources(sock, opCtxInfo, sglCtx); + NN_LOG_ERROR("Failed to copy request to opCtxInfo"); + return SS_PARAM_INVALID; + } + } + + sglCtx->Clone(header, req.iov, req.iovCount); + opCtxInfo->sendCtx = sglCtx; + + sock->AddOpCtx(header.seqNo, opCtxInfo); + sock->IncreaseRef(); + + auto result = sock->PostReadSgl(opCtxInfo); + // blocking post send need call upper handle + if (result == SS_OK) { + sock->ReturnQueueSpace(NN_NO1); + return result; + } else if (result == SS_SOCK_SEND_EAGAIN) { + return ModifyInEpoll(sock, EPOLLIN | EPOLLOUT | EPOLLET); + } else if (result != SS_TCP_RETRY) { + auto res = ModifyInEpoll(sock, EPOLLWRNORM); + result = res == SS_OK ? result : res; + } + (void)sock->RemoveOpCtx(header.seqNo); + sock->DecreaseRef(); + ReturnResources(sock, opCtxInfo); + + return result; +} + +SResult SockWorker::PostWrite(Sock *sock, SockTransHeader &header, const UBSHcomNetTransRequest &req) +{ + if (NN_UNLIKELY(!sock->GetQueueSpace())) { + return SS_SOCK_ADD_QUEUE_FAILED; + } + + auto opCtxInfo = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(opCtxInfo == nullptr)) { + NN_LOG_ERROR("Failed to PostWrite with sock worker " << DetailName() << " as no ctx left"); + sock->ReturnQueueSpace(NN_NO1); + return SS_CTX_FULL; + } + opCtxInfo->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to PostWrite with sock worker " << DetailName() << " as no sglCtx left"); + sock->ReturnQueueSpace(NN_NO1); + mOpCtxInfoPool.Return(opCtxInfo); + return SS_CTX_FULL; + } + + opCtxInfo->sock = sock; + opCtxInfo->opType = SockOpContextInfo::SockOpType::SS_WRITE; + opCtxInfo->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + opCtxInfo->upCtxSize = req.upCtxSize; + if (opCtxInfo->upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(opCtxInfo->upCtx, NN_NO16, req.upCtxData, opCtxInfo->upCtxSize) != SS_OK)) { + ReturnResources(sock, opCtxInfo, sglCtx); + NN_LOG_ERROR("Failed to copy req to opCtxInfo"); + return SS_PARAM_INVALID; + } + } + + UBSHcomNetTransSgeIov iov(req.lAddress, req.rAddress, req.lKey, req.rKey, req.size); + sglCtx->Clone(header, &iov, NN_NO1); + opCtxInfo->sendCtx = sglCtx; + + sock->AddOpCtx(header.seqNo, opCtxInfo); + sock->IncreaseRef(); + + auto result = sock->PostWrite(opCtxInfo); + // blocking post send need call upper handle + if (result == SS_OK) { + sock->ReturnQueueSpace(NN_NO1); + return result; + } else if (result == SS_SOCK_SEND_EAGAIN) { + return ModifyInEpoll(sock, EPOLLIN | EPOLLOUT | EPOLLET); + } else if (result != SS_TCP_RETRY) { + auto res = ModifyInEpoll(sock, EPOLLWRNORM); + result = res == SS_OK ? result : res; + } + (void)sock->RemoveOpCtx(header.seqNo); + sock->DecreaseRef(); + ReturnResources(sock, opCtxInfo); + + return result; +} + +SResult SockWorker::PostWrite(Sock *sock, SockTransHeader &header, const UBSHcomNetTransSglRequest &req) +{ + if (NN_UNLIKELY(!sock->GetQueueSpace())) { + return SS_SOCK_ADD_QUEUE_FAILED; + } + + auto opCtxInfo = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(opCtxInfo == nullptr)) { + NN_LOG_ERROR("Failed to post write sgl with sock worker " << DetailName() << " as no ctx left"); + sock->ReturnQueueSpace(NN_NO1); + return SS_CTX_FULL; + } + opCtxInfo->errType = SockOpContextInfo::SockErrorType::SS_NO_ERROR; + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to post write sgl with sock worker " << DetailName() << " as no sglCtx left"); + sock->ReturnQueueSpace(NN_NO1); + mOpCtxInfoPool.Return(opCtxInfo); + return SS_CTX_FULL; + } + + opCtxInfo->sock = sock; + opCtxInfo->opType = SockOpContextInfo::SockOpType::SS_SGL_WRITE; + opCtxInfo->upCtxSize = req.upCtxSize; + if (opCtxInfo->upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(opCtxInfo->upCtx, NN_NO16, req.upCtxData, opCtxInfo->upCtxSize) != SS_OK)) { + ReturnResources(sock, opCtxInfo, sglCtx); + NN_LOG_ERROR("Failed to copy req to opCtxInfo"); + return SS_PARAM_INVALID; + } + } + + sglCtx->Clone(header, req.iov, req.iovCount); + opCtxInfo->sendCtx = sglCtx; + + sock->AddOpCtx(header.seqNo, opCtxInfo); + sock->IncreaseRef(); + + auto result = sock->PostWriteSgl(opCtxInfo); + // blocking post send need call upper handle + if (result == SS_OK) { + sock->ReturnQueueSpace(NN_NO1); + return result; + } else if (result == SS_SOCK_SEND_EAGAIN) { + return ModifyInEpoll(sock, EPOLLIN | EPOLLOUT | EPOLLET); + } else if (result != SS_TCP_RETRY) { + auto res = ModifyInEpoll(sock, EPOLLWRNORM); + result = res == SS_OK ? result : res; + } + (void)sock->RemoveOpCtx(header.seqNo); + sock->DecreaseRef(); + ReturnResources(sock, opCtxInfo); + + return result; +} +} +} \ No newline at end of file diff --git a/src/transport/sock/sock_wrapper.cpp b/src/transport/sock/sock_wrapper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..516d5316abb45211572890da61917da90fe62648 --- /dev/null +++ b/src/transport/sock/sock_wrapper.cpp @@ -0,0 +1,266 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "sock_wrapper.h" + +namespace ock { +namespace hcom { +SResult Sock::Initialize(const SockWorkerOptions &workerOptions) +{ + std::lock_guard guard(mInitMutex); + if (mInited) { + return SS_OK; + } + + if (NN_UNLIKELY(mType != SOCK_UDS && mType != SOCK_TCP)) { + NN_LOG_ERROR("Failed to initialize sock as type is invalid"); + return SS_PARAM_INVALID; + } + + SResult result = SS_OK; + /* validate options */ + if (NN_UNLIKELY((result = ValidateOptions()) != SS_OK)) { + NN_LOG_ERROR("Failed to validate options in sock initialize"); + return result; + } + + /* validate fd and set sock option */ + if (NN_UNLIKELY((result = SetSockOption(workerOptions)) != SS_OK)) { + NN_LOG_ERROR("Failed to set sock options in sock initialize"); + return result; + } + + /* allocate receive buf and send list */ + if (NN_UNLIKELY(!mReceiveBuff.ExpandIfNeed(mOptions.receiveBufSizeKB * NN_NO1024))) { + NN_LOG_ERROR("Failed to allocate receive buffer for sock " << mId << ", probably of memory"); + return SS_MEMORY_ALLOCATE_FAILED; + } + + mReceiveState.ResetHeader(); + mCtxMap.reserve(SOCK_CTX_MAP_RESERVATION); + mSendQueue.Initialize(); + mInited = true; + return SS_OK; +} + +void Sock::UnInitialize() +{ + std::lock_guard guard(mInitMutex); + if (mSsl != nullptr) { + HcomSsl::SslShutdown(mSsl); + HcomSsl::SslFree(mSsl); + mSsl = nullptr; + } + NetFunc::NN_SafeCloseFd(mFd); + mCtxMap.clear(); +} + +void Sock::Close() +{ + std::lock_guard guard(mInitMutex); + if (mSsl != nullptr) { + HcomSsl::SslShutdown(mSsl); + HcomSsl::SslFree(mSsl); + mSsl = nullptr; + } + NetFunc::NN_SafeCloseFd(mFd); +} + +SResult Sock::ValidateOptions() +{ + if (NN_UNLIKELY(mOptions.receiveBufSizeKB == 0)) { + mOptions.receiveBufSizeKB = 1; /* min 1kB */ + } + return SS_OK; +} + +SResult Sock::SetSockOption(const SockWorkerOptions &workerOptions) +{ + /* fd is invalid */ + if (NN_UNLIKELY(mFd == -1)) { + NN_LOG_ERROR("Failed to initialize sock " << mId << " as mFd is invalid"); + return SS_PARAM_INVALID; + } + + mOptions.sendQueueSize = workerOptions.sendQueueSize; + mOptions.sendZCopy = workerOptions.tcpSendZCopy; + + if (workerOptions.sockReceiveBufKB > 0) { + /* set the size of receive buffer, which would compromise the performance of tcp */ + mOptions.receiveBufSizeKB = workerOptions.sockReceiveBufKB; + auto value = workerOptions.sockReceiveBufKB * NN_NO1024; + if (NN_UNLIKELY(setsockopt(mFd, SOL_SOCKET, SO_RCVBUF, &value, sizeof(value)) < 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to set receive buffer for sock " << mId << ", errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SS_TCP_SET_OPTION_FAILED; + } + } + + if (workerOptions.sockSendBufKB > 0) { + /* set the size of send buffer, which would compromise the performance of tcp */ + mOptions.sendBufSizeKB = workerOptions.sockSendBufKB; + auto value = workerOptions.sockSendBufKB * NN_NO1024; + if (NN_UNLIKELY(setsockopt(mFd, SOL_SOCKET, SO_SNDBUF, &value, sizeof(value)) < 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to set send buffer for sock " << mId << ", errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SS_TCP_SET_OPTION_FAILED; + } + } else { + int sendBufSize = 0; + socklen_t buffTypeSize = sizeof(sendBufSize); + if (NN_UNLIKELY(getsockopt(mFd, SOL_SOCKET, SO_SNDBUF, &sendBufSize, &buffTypeSize) < 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to get send buffer for sock " << mId << ", errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SS_TCP_GET_OPTION_FAILED; + } + if (NN_UNLIKELY(sendBufSize <= 0)) { + NN_LOG_ERROR("send buffer size should be greater than 0 for sock" << mId); + return SS_TCP_GET_OPTION_FAILED; + } + mOptions.sendBufSizeKB = sendBufSize / NN_NO1024; + } + + /* stop here if uds */ + if (mType == SockType::SOCK_UDS) { + return SS_OK; + } + + /* following only for tcp */ + /* set keep alive */ + int value = 1; + auto optSize = sizeof(workerOptions.keepaliveIdleTime); + if (NN_UNLIKELY(setsockopt(mFd, SOL_SOCKET, SO_KEEPALIVE, &value, sizeof(value)) < 0 || + setsockopt(mFd, IPPROTO_TCP, TCP_KEEPIDLE, &workerOptions.keepaliveIdleTime, optSize) < 0 || + setsockopt(mFd, IPPROTO_TCP, TCP_KEEPINTVL, &workerOptions.keepaliveProbeInterval, optSize) < 0 || + setsockopt(mFd, IPPROTO_TCP, TCP_KEEPCNT, &workerOptions.keepaliveProbeTimes, optSize) < 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to set keepalive options for sock " << mId << ", errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SS_TCP_SET_OPTION_FAILED; + } + + /* set no delay */ + if (workerOptions.tcpEnableNoDelay && + NN_UNLIKELY(setsockopt(mFd, SOL_TCP, TCP_NODELAY, reinterpret_cast(&value), sizeof(value)) != 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to set no-delay for sock " << mId << ", errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SS_TCP_SET_OPTION_FAILED; + } + + if (workerOptions.tcpUserTimeout < -1 || workerOptions.tcpUserTimeout > static_cast(NN_NO1024)) { + NN_LOG_ERROR( + "tcpUserTimeout is invalid, it should be [-1, 1024], -1 means do not set, 0 means never timeout during io"); + return SS_PARAM_INVALID; + } + + /* set timeout during io (ms) */ + if (workerOptions.tcpUserTimeout >= 0) { + auto timeout = workerOptions.tcpUserTimeout * NN_NO1000; + if (NN_UNLIKELY(setsockopt(mFd, SOL_TCP, TCP_USER_TIMEOUT, &timeout, sizeof(timeout)) != 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to set timeout for sock " << mId << ", errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SS_TCP_SET_OPTION_FAILED; + } + } + return SS_OK; +} + +/* in order to have same function as other protocols which take -1 as blocking forever and 0 as return immediately, + * while tcp acts contrarily, we switch it manually when setting. */ +SResult Sock::SetBlockingSendTimeout(int32_t sendTimeout) +{ + if (sendTimeout == mSendTimeoutSecond) { + return SS_OK; + } + mSendTimeoutSecond = sendTimeout; + + sendTimeout = sendTimeout > 0 ? sendTimeout : sendTimeout == 0 ? -1 : 0; + struct timeval timeval { + sendTimeout, 0 + }; + if (NN_UNLIKELY(setsockopt(mFd, SOL_SOCKET, SO_SNDTIMEO, &timeval, sizeof(timeval)) < 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to set send timeout for sock " << mId << ", errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SS_TCP_SET_OPTION_FAILED; + } + + return SS_OK; +} + +SResult Sock::SetBlockingIo() +{ + int32_t value = NN_NO1; + /* set blocking, fcntl result is 0 or -1 */ + if (NN_UNLIKELY((value = fcntl(mFd, F_GETFL, 0)) == -1)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to get control value for sock " << mId << ", errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SS_TCP_SET_OPTION_FAILED; + } + + if (NN_UNLIKELY((value = fcntl(mFd, F_SETFL, static_cast(value) & ~O_NONBLOCK)) == -1)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to set control value for sock " << mId << ", errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SS_TCP_SET_OPTION_FAILED; + } + mTcpBlockingMode = true; + + return SS_OK; +} + +SResult Sock::SetNonBlockingIo() +{ + int value = 1; + /* set no-blocking */ + if (NN_UNLIKELY((value = fcntl(mFd, F_GETFL, 0)) == -1)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to get control value for sock " << mId << ", errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SS_TCP_SET_OPTION_FAILED; + } + + if (NN_UNLIKELY((value = fcntl(mFd, F_SETFL, static_cast(value) | O_NONBLOCK)) == -1)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to set control value for sock " << mId << ", errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + return SS_TCP_SET_OPTION_FAILED; + } + + mTcpBlockingMode = false; + + return SS_OK; +} + +SResult Sock::SetBlockingIo(UBSHcomEpOptions &epOptions) +{ + if (NN_UNLIKELY(SetBlockingIo() != SS_OK)) { + return SS_TCP_SET_OPTION_FAILED; + } + mCbByWorkerInBlocking = epOptions.cbByWorkerInBlocking; + if (NN_UNLIKELY(SetBlockingSendTimeout(epOptions.sendTimeout) != SS_OK)) { + return SS_TCP_SET_OPTION_FAILED; + } + return SS_OK; +} + +uint32_t Sock::GetSendQueueCount() +{ + return mSendQueue.Size(); +} +} +} diff --git a/src/transport/sock/sock_wrapper.h b/src/transport/sock/sock_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..7c62dd3fde98eeab72c052b769da21f1864ac027 --- /dev/null +++ b/src/transport/sock/sock_wrapper.h @@ -0,0 +1,1538 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef OCK_HCOM_SOCK_WRAPPER_H_234234 +#define OCK_HCOM_SOCK_WRAPPER_H_234234 + +#include +#include + +#include "net_common.h" +#include "net_ctx_info_pool.h" +#include "net_memory_region.h" +#include "net_oob.h" +#include "net_oob_ssl.h" +#include "net_security_rand.h" +#include "openssl_api_wrapper.h" +#include "securec.h" +#include "sock_common.h" +#include "net_oob_openssl.h" +#include "sock_buff.h" + +namespace ock { +namespace hcom { +struct SockReceiveState { + uint16_t headerLen = sizeof(SockTransHeader); + uint16_t headerToBeReceived = 0; + ssize_t bodyToBeReceived = -1; + + inline bool ShouldReceiveHeader() const + { + return bodyToBeReceived == -1; + } + + inline uint16_t ReceivedHeaderLen() const + { + return headerLen - headerToBeReceived; + } + + inline void ResetHeader() + { + headerToBeReceived = sizeof(SockTransHeader); + bodyToBeReceived = -1; + } + + inline bool BodySatisfied(ssize_t newReceivedSize) + { + bodyToBeReceived -= newReceivedSize; + return bodyToBeReceived == 0; + } + + inline bool HeaderSatisfied(uint16_t newReceivedHeader) + { + headerToBeReceived -= newReceivedHeader; + return headerToBeReceived == 0; + } +} __attribute__((packed)); + +struct SendingQueueRequest { + uint64_t remainSize = 0; + struct iovec iov[NN_NO7] {}; + uint16_t iovCount = 0; + bool isTwoSideMode = false; +}; + +/* ***************************************************************************************************** */ +/* + * @brief sock buffer list for send pending + */ +class SockBuffList { +public: +private: +}; + +/* ***************************************************************************************************** */ +/* + * @brief sock wrapper + */ +using SockOpContextInfoPool = OpContextInfoPool; +using SockSglContextInfoPool = OpContextInfoPool; +using SockHeaderReqInfoPool = OpContextInfoPool; +using SockPostedHandler = std::function; +using SockOneSideHandler = std::function; +class Sock { +public: + inline void BufferStatus() const + { + int buffer = 0; + socklen_t buffLen = sizeof buffer; + getsockopt(mFd, SOL_SOCKET, SO_SNDBUF, &buffer, &buffLen); + NN_LOG_INFO("send buffer size in total:" << buffer); + ioctl(mFd, SIOCOUTQ, &buffer); + NN_LOG_INFO("send buffer size in using:" << buffer); + getsockopt(mFd, SOL_SOCKET, SO_RCVBUF, &buffer, &buffLen); + NN_LOG_INFO("receive buffer size in total:" << buffer); + ioctl(mFd, SIOCINQ, &buffer); + NN_LOG_INFO("receive buffer size in using:" << buffer); + } + + static inline SResult SendRealConnHeader(int fd, void *buf, uint32_t size) + { + if (NN_UNLIKELY(fd == -1 || buf == nullptr)) { + return SS_PARAM_INVALID; + } + + if (NN_UNLIKELY(::send(fd, buf, size, 0) <= 0)) { + NN_LOG_ERROR("Failed to send real connection header, with errno is " << errno); + return SS_SOCK_SEND_FAILED; + } + + return SS_OK; + } + + inline SResult Send(const void *buf, uint32_t size) + { + if (NN_UNLIKELY(mFd == -1 || buf == nullptr)) { + return SS_PARAM_INVALID; + } + + ssize_t result = 0; + + if (mEnableTls) { + uint32_t writeLen = 0; + return SSLSend(buf, size, writeLen); + } else { + result = ::send(mFd, buf, size, 0); + if (result <= 0) { + NN_LOG_ERROR("Failed to send data, ret: " << result << ", errno: " << errno); + return errno; + } + } + + if (NN_UNLIKELY(result != size)) { + NN_LOG_ERROR("Failed to send data, expected size: " << size << ", actual size: " << result); + return SS_SOCK_DATA_SIZE_UN_MATCHED; + } + + return SS_OK; + } + + inline SResult Receive(void *&buf, uint32_t size) + { + if (NN_UNLIKELY(mFd == -1 || buf == nullptr)) { + return SS_PARAM_INVALID; + } + + ssize_t result = 0; + + if (mEnableTls) { + return SSLRead(buf, size, reinterpret_cast(result)); + } else { + result = ::recv(mFd, buf, size, 0); + if (result <= 0) { + NN_LOG_ERROR("Failed to recv data, ret: " << result << ", errno: " << errno); + return errno; + } + } + + if (NN_UNLIKELY(result != size)) { + NN_LOG_ERROR("Failed to recv data, expected size: " << size << ", actual size: " << result); + return SS_SOCK_DATA_SIZE_UN_MATCHED; + } + + return SS_OK; + } + +public: + Sock(SockType type, const std::string &name, uint64_t id, int fd, SockOptions &options) + : mFd(fd), + mQueueVacantSize(options.sendQueueSize), + mQueueSize(options.sendQueueSize), + mName(name), + mId(id), + mType(type), + mSendQueue(options.sendQueueSize) + { + mEnableTls = false; + OBJ_GC_INCREASE(Sock); + } + + Sock(SockType type, const std::string &name, uint64_t id, int fd, SockOptions &options, OOBTCPConnection *conn) + : mFd(fd), + mQueueVacantSize(options.sendQueueSize), + mQueueSize(options.sendQueueSize), + mName(name), + mId(id), + mType(type), + mSendQueue(options.sendQueueSize) + { + mEnableTls = true; + mSsl = reinterpret_cast(conn)->TransferSsl(); + + OBJ_GC_INCREASE(Sock); + } + + ~Sock() + { + UnInitialize(); + + OBJ_GC_DECREASE(Sock); + } + + SResult Initialize(const SockWorkerOptions &workerOptions); + void UnInitialize(); + void Close(); + SResult SetBlockingSendTimeout(int32_t sendTimeout); + SResult SetBlockingIo(UBSHcomEpOptions &epOptions); + uint32_t GetSendQueueCount(); + + /* + * @brief Get name + */ + inline const std::string &Name() const + { + return mName; + } + + /* + * @brief Get ip and port of peer + */ + inline const std::string &PeerIpPort() const + { + return mPeerIpPort; + } + + /* + * @brief Get sock id + */ + inline uint64_t Id() const + { + return mId; + } + + /* + * @brief Set a context by caller + */ + inline void UpContext(uint64_t value) + { + mUpCtx = value; + } + + /* + * @brief Get a context by caller + */ + inline uint64_t UpContext() const + { + return mUpCtx; + } + + /* + * @brief Set a context by caller + */ + inline void UpContext1(uint64_t value) + { + mUpCtx1 = value; + } + + /* + * @brief Get a context by caller + */ + inline uint64_t UpContext1() const + { + return mUpCtx1; + } + + /* + * @brief Get a secret by caller + */ + inline NetSecrets &Secret() + { + return mSecret; + } + + /* + * @brief Set a secret by caller + */ + inline void Secret(NetSecrets &secret) + { + mSecret = secret; + } + + /* + * @brief Get receive data + */ + inline SockBuff &ReceiveData() + { + return mReceiveBuff; + } + + /* + * @brief Get the file description of socket + */ + inline int FD() const + { + return mFd; + } + + inline bool CbByWorkerInBlocking() const + { + return mCbByWorkerInBlocking; + } + + inline void SetCbByWorkerInBlocking(bool cbByWorkerInBlocking) + { + mCbByWorkerInBlocking = cbByWorkerInBlocking; + } + + inline uint32_t OneSideNextSeq() + { + return __sync_fetch_and_add(&mSeqIndex, 1); + } + + inline void AddOpCtx(uint32_t id, SockOpContextInfo *opCtx) + { + std::lock_guard guard(mCtxMutex); + + mCtxMap.emplace(id, opCtx); + } + + inline SockOpContextInfo *RemoveOpCtx(uint32_t id) + { + std::lock_guard guard(mCtxMutex); + + auto iter = mCtxMap.find(id); + if (NN_UNLIKELY(iter == mCtxMap.end())) { + return nullptr; + } + SockOpContextInfo *ctxInfo = iter->second; + mCtxMap.erase(iter); + return ctxInfo; + } + /* + * @brief Get header address + */ + inline SockTransHeader *GetHeaderAddress() + { + return &mHeader; + } + +#define COMPOSE_REQUEST(mSendingQueueRequest, reqInQueue) \ + if ((mSendingQueueRequest).remainSize == 0) { \ + if ((reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_SEND || \ + (reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_SEND_RAW) { \ + uint32_t reqSize = 0; \ + if (mOptions.sendZCopy) { \ + (mSendingQueueRequest).iov[NN_NO0].iov_base = &(reqInQueue)->headerRequest->sendHeader; \ + (mSendingQueueRequest).iov[NN_NO1].iov_base = (reqInQueue)->headerRequest->request; \ + reqSize = (reqInQueue)->headerRequest->sendHeader.dataLength; \ + } else { \ + (mSendingQueueRequest).iov[NN_NO0].iov_base = (reqInQueue)->sendBuff; \ + (mSendingQueueRequest).iov[NN_NO1].iov_base = reinterpret_cast( \ + reinterpret_cast((reqInQueue)->sendBuff) + sizeof(SockTransHeader)); \ + reqSize = reinterpret_cast((reqInQueue)->sendBuff)->dataLength; \ + } \ + (mSendingQueueRequest).iov[NN_NO0].iov_len = sizeof(SockTransHeader); \ + (mSendingQueueRequest).iov[NN_NO1].iov_len = reqSize; \ + (mSendingQueueRequest).iovCount = NN_NO2; \ + (mSendingQueueRequest).remainSize = reqSize + sizeof(SockTransHeader); \ + } else if ((reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_SEND_RAW_SGL || \ + (reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_READ_ACK) { \ + auto sendCtx = (reqInQueue)->sendCtx; \ + (mSendingQueueRequest).iov[NN_NO0].iov_base = reinterpret_cast(&sendCtx->sendHeader); \ + (mSendingQueueRequest).iov[NN_NO0].iov_len = sizeof(SockTransHeader); \ + for (uint16_t i = 0; i < sendCtx->iovCount; i++) { \ + (mSendingQueueRequest).iov[i + NN_NO1].iov_base = reinterpret_cast(sendCtx->iov[i].lAddress); \ + (mSendingQueueRequest).iov[i + NN_NO1].iov_len = sendCtx->iov[i].size; \ + } \ + (mSendingQueueRequest).iovCount = sendCtx->iovCount + NN_NO1; \ + (mSendingQueueRequest).remainSize = sendCtx->sendHeader.dataLength + sizeof(SockTransHeader); \ + } else if ((reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_WRITE) { \ + auto sendCtx = (reqInQueue)->sendCtx; \ + (mSendingQueueRequest).iov[NN_NO0].iov_base = reinterpret_cast(&sendCtx->sendHeader); \ + (mSendingQueueRequest).iov[NN_NO0].iov_len = sizeof(SockTransHeader); \ + (mSendingQueueRequest).iov[NN_NO1].iov_base = reinterpret_cast(&sendCtx->iov[0]); \ + (mSendingQueueRequest).iov[NN_NO1].iov_len = sizeof(UBSHcomNetTransSgeIov); \ + (mSendingQueueRequest).iov[NN_NO2].iov_base = reinterpret_cast(sendCtx->iov[0].lAddress); \ + (mSendingQueueRequest).iov[NN_NO2].iov_len = sendCtx->iov[0].size; \ + (mSendingQueueRequest).iovCount = NN_NO3; \ + (mSendingQueueRequest).remainSize = \ + sizeof(SockTransHeader) + sizeof(UBSHcomNetTransSgeIov) + sendCtx->iov[0].size; \ + } else if ((reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_READ) { \ + auto sendCtx = (reqInQueue)->sendCtx; \ + (mSendingQueueRequest).iov[NN_NO0].iov_base = reinterpret_cast(&sendCtx->sendHeader); \ + (mSendingQueueRequest).iov[NN_NO0].iov_len = sizeof(SockTransHeader); \ + (mSendingQueueRequest).iov[NN_NO1].iov_base = reinterpret_cast(&sendCtx->iov[0]); \ + (mSendingQueueRequest).iov[NN_NO1].iov_len = sizeof(UBSHcomNetTransSgeIov); \ + (mSendingQueueRequest).iovCount = NN_NO2; \ + (mSendingQueueRequest).remainSize = sizeof(SockTransHeader) + sizeof(UBSHcomNetTransSgeIov); \ + } else if ((reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_WRITE_ACK || \ + (reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_SGL_WRITE_ACK) { \ + (mSendingQueueRequest).iov[NN_NO0].iov_base = \ + reinterpret_cast(&(reqInQueue)->sendCtx->sendHeader); \ + (mSendingQueueRequest).iov[NN_NO0].iov_len = sizeof(SockTransHeader); \ + (mSendingQueueRequest).iovCount = NN_NO1; \ + (mSendingQueueRequest).remainSize = sizeof(SockTransHeader); \ + } else if ((reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_SGL_READ) { \ + auto sendCtx = (reqInQueue)->sendCtx; \ + (mSendingQueueRequest).iov[NN_NO0].iov_base = reinterpret_cast(&sendCtx->sendHeader); \ + (mSendingQueueRequest).iov[NN_NO0].iov_len = sizeof(SockTransHeader); \ + (mSendingQueueRequest).iov[NN_NO1].iov_base = reinterpret_cast(&sendCtx->iovCount); \ + (mSendingQueueRequest).iov[NN_NO1].iov_len = sizeof(UBSHcomNetTransSglRequest::iovCount); \ + (mSendingQueueRequest).iov[NN_NO2].iov_base = reinterpret_cast(sendCtx->iov); \ + (mSendingQueueRequest).iov[NN_NO2].iov_len = sizeof(UBSHcomNetTransSgeIov) * sendCtx->iovCount; \ + (mSendingQueueRequest).iovCount = NN_NO3; \ + (mSendingQueueRequest).remainSize = sendCtx->sendHeader.dataLength + sizeof(SockTransHeader); \ + } else if ((reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_SGL_WRITE || \ + (reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_SGL_READ_ACK) { \ + auto sendCtx = (reqInQueue)->sendCtx; \ + (mSendingQueueRequest).iov[NN_NO0].iov_base = reinterpret_cast(&sendCtx->sendHeader); \ + (mSendingQueueRequest).iov[NN_NO0].iov_len = sizeof(SockTransHeader); \ + (mSendingQueueRequest).iov[NN_NO1].iov_base = reinterpret_cast(&sendCtx->iovCount); \ + (mSendingQueueRequest).iov[NN_NO1].iov_len = sizeof(UBSHcomNetTransSglRequest::iovCount); \ + (mSendingQueueRequest).iov[NN_NO2].iov_base = reinterpret_cast(sendCtx->iov); \ + (mSendingQueueRequest).iov[NN_NO2].iov_len = sizeof(UBSHcomNetTransSgeIov) * sendCtx->iovCount; \ + for (uint16_t i = 0; i < sendCtx->iovCount; i++) { \ + if ((reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_SGL_WRITE) { \ + (mSendingQueueRequest).iov[i + NN_NO3].iov_base = \ + reinterpret_cast(sendCtx->iov[i].lAddress); \ + } else { \ + (mSendingQueueRequest).iov[i + NN_NO3].iov_base = \ + reinterpret_cast(sendCtx->iov[i].rAddress); \ + } \ + (mSendingQueueRequest).iov[i + NN_NO3].iov_len = sendCtx->iov[i].size; \ + } \ + (mSendingQueueRequest).iovCount = NN_NO3 + sendCtx->iovCount; \ + (mSendingQueueRequest).remainSize = sendCtx->sendHeader.dataLength + sizeof(SockTransHeader); \ + } \ + \ + if ((reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_SEND || \ + (reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_SEND_RAW || \ + (reqInQueue)->opType == SockOpContextInfo::SockOpType::SS_SEND_RAW_SGL) { \ + (mSendingQueueRequest).isTwoSideMode = true; \ + } else { \ + (mSendingQueueRequest).isTwoSideMode = false; \ + } \ + } + +#define PROCESS_REQUEST(mSendingQueueRequest, reqInQueue) \ + do { \ + std::lock_guard guard(mIoMutex); \ + ssize_t result = 0; \ + if ((mSendingQueueRequest).isTwoSideMode && mEnableTls) { \ + auto iov = (mSendingQueueRequest).iov; \ + for (uint16_t i = 0; i < (mSendingQueueRequest).iovCount; i++) { \ + if (iov[i].iov_len == 0) { \ + continue; \ + } \ + int ret = 0; \ + SResult writeRet = SS_OK; \ + if (i == NN_NO0) { \ + ret = writev(mFd, &iov[i], NN_NO1); \ + } else { \ + writeRet = SSLSend(iov[i].iov_base, iov[i].iov_len, reinterpret_cast(ret)); \ + } \ + if (ret <= 0 || writeRet != SS_OK) { \ + if (errno != ECONNRESET && errno != EAGAIN) { \ + NN_LOG_ERROR("Failed to send msg with tls to peer in sock " << mId << " name " << mName << \ + " result:" << result << " iov_len:" << iov[i].iov_len); \ + return SS_SOCK_SEND_FAILED; \ + } \ + break; \ + } \ + result += ret; \ + if (static_cast(ret) != iov[i].iov_len) { \ + break; \ + } \ + } \ + } else { \ + result = writev(mFd, reinterpret_cast(&(mSendingQueueRequest).iov), \ + (mSendingQueueRequest).iovCount); \ + } \ + if (result <= 0) { \ + if (errno == ECONNRESET) { \ + NN_LOG_ERROR("Failed to send msg to peer in sock " << mId << " name " << mName << ", reset by peer, " \ + << " result:" << result); \ + return SS_RESET_BY_PEER; \ + } \ + if (errno == EAGAIN) { \ + /* send buff is full not send */ \ + return SS_SOCK_SEND_EAGAIN; \ + } \ + NN_LOG_ERROR("Failed to send msg to peer in sock " << mId << " name " << mName << ", errno " << errno << \ + " error code:" << errno << " result:" << result); \ + return SS_SOCK_SEND_FAILED; \ + } \ + \ + NN_LOG_TRACE_INFO("Receive sock " << Id() << " event EPOLLOUT," \ + << "queue size:" << mSendQueue.Size() << " deque and write result: " << \ + result << " req size:" << (mSendingQueueRequest).remainSize << " max send size:" << maxSendSize); \ + if (static_cast(result) < (mSendingQueueRequest).remainSize) { \ + for (uint32_t i = 0; i < (mSendingQueueRequest).iovCount; i++) { \ + auto iovLen = static_cast((mSendingQueueRequest).iov[i].iov_len); \ + if (result < iovLen) { \ + (mSendingQueueRequest).iov[i].iov_base = \ + reinterpret_cast((mSendingQueueRequest).iov[i].iov_base) + result; \ + (mSendingQueueRequest).iov[i].iov_len -= static_cast(result); \ + (mSendingQueueRequest).remainSize -= static_cast(result); \ + break; \ + } else if (result > iovLen) { \ + result -= iovLen; \ + (mSendingQueueRequest).remainSize -= static_cast(iovLen); \ + (mSendingQueueRequest).iov[i].iov_len = 0; \ + } else { \ + (mSendingQueueRequest).remainSize -= static_cast(iovLen); \ + (mSendingQueueRequest).iov[i].iov_len = 0; \ + break; \ + } \ + } \ + return SS_SOCK_SEND_EAGAIN; \ + } else { \ + (mSendingQueueRequest).remainSize = 0; \ + } \ + } while (0) + +#define POST_PROCESS(popReq) \ + do { \ + ReturnQueueSpace(NN_NO1); \ + if ((popReq)->opType == SockOpContextInfo::SockOpType::SS_SEND || \ + (popReq)->opType == SockOpContextInfo::SockOpType::SS_SEND_RAW || \ + (popReq)->opType == SockOpContextInfo::SockOpType::SS_SEND_RAW_SGL) { \ + mSendPostedHandler((popReq)); \ + } \ + \ + if ((popReq)->opType == SockOpContextInfo::SockOpType::SS_WRITE_ACK || \ + (popReq)->opType == SockOpContextInfo::SockOpType::SS_READ_ACK || \ + (popReq)->opType == SockOpContextInfo::SockOpType::SS_SGL_WRITE_ACK || \ + (popReq)->opType == SockOpContextInfo::SockOpType::SS_SGL_READ_ACK) { \ + mSglCtxInfoPool.Return((popReq)->sendCtx); \ + (popReq)->sendCtx = nullptr; \ + mOpCtxInfoPool.Return((popReq)); \ + (popReq) = nullptr; \ + } \ + } while (0) + + inline void DealCbWithFailure() + { + SockOpContextInfo *popReq = {}; + while (mSendQueue.GetFront(popReq)) { + if (!mSendQueue.PopFront(popReq)) { + break; + } + popReq->errType = SockOpContextInfo::SS_OPERATE_FAILURE; + POST_PROCESS(popReq); + } + + for (auto &it : mCtxMap) { + it.second->errType = SockOpContextInfo::SS_OPERATE_FAILURE; + mOneSideDoneHandler(it.second); + } + mCtxMap.clear(); + } + + inline SResult ProcessQueueReq() + { + int64_t maxSendSize = static_cast(mOptions.sendBufSizeKB) * NN_NO1024; + bool isGetSuccess = true; + + while (maxSendSize > 0) { + SockOpContextInfo *reqInQueue = nullptr; + isGetSuccess = mSendQueue.GetFront(reqInQueue); + if (!isGetSuccess) { + return SS_OK; + } + + if (!reqInQueue->isSent) { + COMPOSE_REQUEST(mSendingQueueRequest, reqInQueue); + + auto sentSize = mSendingQueueRequest.remainSize; + + PROCESS_REQUEST(mSendingQueueRequest, reqInQueue); + + maxSendSize -= static_cast(sentSize); + } + SockOpContextInfo *popReq = {}; + isGetSuccess = mSendQueue.PopFront(popReq); + if (!isGetSuccess) { + return SS_OK; + } + POST_PROCESS(popReq); + } + + return SS_SOCK_SEND_EAGAIN; + } + +#define POST_SEND(iov, requestSize, seqNo) \ + do { \ + ssize_t ret = 0; \ + if (!mEnableTls) { \ + if ((ret = writev(mFd, reinterpret_cast(&(iov)), NN_NO2)) < \ + static_cast((requestSize) + sizeof(SockTransHeader))) { \ + if (ret == 0) { \ + return SS_TCP_RETRY; \ + } \ + if (errno == 0) { \ + NN_LOG_ERROR("Failed to PostSend to peer in sock " << mId << " name " << mName << " with " << \ + mSendTimeoutSecond << " second timeout, " << ret << " is sent"); \ + return SS_TIMEOUT; \ + } \ + NN_LOG_ERROR("Failed to PostSend to peer in sock " << mId << " name " << mName << ", errno " << \ + errno << ", seqNo " << (seqNo)); \ + return SS_SOCK_SEND_FAILED; \ + } \ + } else { \ + if ((ret = writev(mFd, &(iov)[NN_NO0], NN_NO1)) < static_cast(sizeof(SockTransHeader))) { \ + if (ret == 0) { \ + return SS_TCP_RETRY; \ + } \ + if (errno == 0) { \ + NN_LOG_ERROR("(TLS)Failed to PostSend header to peer in sock " << mId << " name " << mName << \ + " with " << mSendTimeoutSecond << " second timeout, " << ret << " is sent"); \ + return SS_TIMEOUT; \ + } \ + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; \ + NN_LOG_ERROR("(TLS)Failed to PostSend header to peer in sock " << mId << " name " << mName << \ + ", error " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE) << ", seqNo " << \ + (seqNo)); \ + return SS_SOCK_SEND_FAILED; \ + } \ + uint32_t writeLen = 0; \ + ret = SSLSend((iov)[NN_NO1].iov_base, (iov)[NN_NO1].iov_len, writeLen); \ + if (ret != SS_OK || writeLen != (iov)[NN_NO1].iov_len) { \ + if (ret == SS_TIMEOUT) { \ + NN_LOG_ERROR("(TLS)Failed to PostSendSgl body to peer in sock " << mId << " name " << mName << \ + ", error is timeout with " << mSendTimeoutSecond << " second, seqNo " << (seqNo) << \ + ", the failed iov data len " << (iov)[NN_NO1].iov_len); \ + return SS_TIMEOUT; \ + } \ + NN_LOG_ERROR("(TLS)Failed to PostSend body to peer in sock " << mId << " name " << mName << \ + ", seqNo " << (seqNo) << ", the failed iov data len " << (iov)[NN_NO1].iov_len); \ + return SS_SOCK_SEND_FAILED; \ + } \ + } \ + } while (0) + + inline SResult PostSend(SockOpContextInfo *ctx) + { + if (NN_UNLIKELY(ctx == nullptr)) { + return SS_PARAM_INVALID; + } + + if (mTcpBlockingMode) { + struct iovec iov[NN_NO2]; + uint32_t requestSize = 0; + if (mOptions.sendZCopy) { + iov[NN_NO0].iov_base = &ctx->headerRequest->sendHeader; + iov[NN_NO0].iov_len = sizeof(SockTransHeader); + iov[NN_NO1].iov_base = ctx->headerRequest->request; + requestSize = reinterpret_cast(&ctx->headerRequest->sendHeader)->dataLength; + iov[NN_NO1].iov_len = requestSize; + } else { + iov[NN_NO0].iov_base = ctx->sendBuff; + iov[NN_NO0].iov_len = sizeof(SockTransHeader); + iov[NN_NO1].iov_base = + reinterpret_cast(reinterpret_cast(ctx->sendBuff) + sizeof(SockTransHeader)); + requestSize = reinterpret_cast(ctx->sendBuff)->dataLength; + iov[NN_NO1].iov_len = requestSize; + } + std::lock_guard guard(mIoMutex); + POST_SEND(iov, requestSize, reinterpret_cast(ctx->sendBuff)->seqNo); + + if (mCbByWorkerInBlocking) { + ctx->isSent = true; + mSendQueue.PushBack(ctx); + return SS_SOCK_SEND_EAGAIN; + } + NN_LOG_TRACE_INFO("Post send request successfully : sock " << mId << ", head imm data " << + reinterpret_cast(ctx->sendBuff)->immData << ", flags " << + reinterpret_cast(ctx->sendBuff)->flags << ", seqNo " << + reinterpret_cast(ctx->sendBuff)->seqNo << ", data len " << + reinterpret_cast(ctx->sendBuff)->dataLength); + return SS_OK; + } else { + mSendQueue.PushBack(ctx); + return SS_SOCK_SEND_EAGAIN; + } + } + + inline SResult PostSend(SockTransHeader &header, const UBSHcomNetTransRequest &req) + { + struct iovec iov[NN_NO2]; + iov[NN_NO0].iov_base = reinterpret_cast(&header); + iov[NN_NO0].iov_len = sizeof(SockTransHeader); + iov[NN_NO1].iov_base = reinterpret_cast(req.lAddress); + iov[NN_NO1].iov_len = req.size; + + std::lock_guard guard(mIoMutex); + POST_SEND(iov, req.size, header.seqNo); + + NN_LOG_TRACE_INFO("PostSend request successfully : sock " << mId << ", head imm data " << header.immData << + ", flags " << header.flags << ", seqNo " << header.seqNo << ", data len " << header.dataLength); + return SS_OK; + } + + inline SResult PostSendSglSsl(SockOpContextInfo *ctx, struct iovec *iov, uint32_t iovLen = NN_NO5) + { + auto sendCtx = ctx->sendCtx; + ssize_t ret = 0; + if ((ret = writev(mFd, &iov[NN_NO0], NN_NO1)) < static_cast(sizeof(SockTransHeader))) { + if (ret == 0) { + return SS_TCP_RETRY; + } + if (errno == 0) { + NN_LOG_ERROR("(TLS)Failed to PostSendSgl header to peer in sock " << mId << " name " << mName << + " with " << mSendTimeoutSecond << " second timeout, " << ret << " is sent"); + return SS_TIMEOUT; + } + + NN_LOG_ERROR("(TLS)Failed to PostSendSgl header to peer in sock " << mId << " name " << mName << + ", errno " << errno << ", seqNo " << reinterpret_cast(ctx->sendBuff)->seqNo); + return SS_SOCK_SEND_FAILED; + } + + for (uint32_t i = 1; i < NN_NO1 + sendCtx->iovCount; i++) { + uint32_t writeLen = 0; + ret = SSLSend(iov[i].iov_base, iov[i].iov_len, writeLen); + if (ret == SS_TIMEOUT) { + NN_LOG_ERROR("(TLS)Failed to PostSendSgl body to peer in sock " << mId << " name " << + mName << ", error is timeout with " << mSendTimeoutSecond << " second, seqNo " << + reinterpret_cast(ctx->sendBuff)->seqNo << + ", the failed iov data len " << iov[NN_NO1].iov_len); + return SS_TIMEOUT; + } + if (ret != SS_OK || writeLen != static_cast(iov[i].iov_len)) { + NN_LOG_ERROR("(TLS)Failed to PostSendSgl body to peer in sock " << mId << " name " << mName << + ", seqNo " << reinterpret_cast(ctx->sendBuff)->seqNo << + ", the failed iov data len " << iov[i].iov_len); + return SS_SOCK_SEND_FAILED; + } + } + return SS_OK; + } + + inline SResult PostSendSgl(SockOpContextInfo *ctx) + { + if (NN_UNLIKELY(ctx == nullptr)) { + return SS_PARAM_INVALID; + } + + if (mTcpBlockingMode) { + auto sendCtx = ctx->sendCtx; + struct iovec iov[NN_NO5]; + iov[NN_NO0].iov_base = reinterpret_cast(&sendCtx->sendHeader); + iov[NN_NO0].iov_len = sizeof(SockTransHeader); + + NN_LOG_TRACE_INFO("PostSendSgl in sock iov count " << sendCtx->iovCount << ", head size " << + iov[NN_NO0].iov_len); + size_t requestSize = 0; + for (uint16_t i = 0; i < sendCtx->iovCount; i++) { + iov[i + NN_NO1].iov_base = reinterpret_cast(sendCtx->iov[i].lAddress); + iov[i + NN_NO1].iov_len = sendCtx->iov[i].size; + NN_LOG_TRACE_INFO("iov index " << i + NN_NO1 << ", length " << iov[i + NN_NO1].iov_len); + requestSize += iov[i + NN_NO1].iov_len; + } + + std::lock_guard guard(mIoMutex); + ssize_t ret = 0; + if (mEnableTls && ctx->opType == SockOpContextInfo::SockOpType::SS_SEND_RAW_SGL) { + if ((ret = PostSendSglSsl(ctx, iov, NN_NO5) != SS_OK)) { + NN_LOG_ERROR("PostSendSglSsl failed, ret: " << ret); + return ret; + } + } else { + if ((ret = writev(mFd, reinterpret_cast(&iov), NN_NO1 + sendCtx->iovCount)) < + static_cast(requestSize + sizeof(SockTransHeader))) { + if (ret == 0) { + return SS_TCP_RETRY; + } + if (errno == 0) { + NN_LOG_ERROR("Failed to PostSendSgl to peer in sock: " << mId << " name: " << mName << " with " + << mSendTimeoutSecond << " second timeout, " << ret << " is sent"); + return SS_TIMEOUT; + } + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to PostSendSgl to peer in sock " << mId << " name " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return SS_SOCK_SEND_FAILED; + } + } + + if (mCbByWorkerInBlocking && ctx->opType == SockOpContextInfo::SockOpType::SS_SEND_RAW_SGL) { + ctx->isSent = true; + mSendQueue.PushBack(ctx); + return SS_SOCK_SEND_EAGAIN; + } + NN_LOG_TRACE_INFO("Post send request successfully : sock " << mId << ", head imm data " << + sendCtx->sendHeader.immData << ", flags " << sendCtx->sendHeader.flags << ", seqNo " << + sendCtx->sendHeader.seqNo << ", data len " << sendCtx->sendHeader.dataLength); + + return SS_OK; + } else { + mSendQueue.PushBack(ctx); + return SS_SOCK_SEND_EAGAIN; + } + } + + inline SResult PostSendSgl(SockTransHeader &header, const UBSHcomNetTransSglRequest &req) + { + struct iovec iov[NN_NO5]; + iov[NN_NO0].iov_base = reinterpret_cast(&header); + iov[NN_NO0].iov_len = sizeof(SockTransHeader); + + NN_LOG_TRACE_INFO("Send raw sgl in sock iov count " << req.iovCount << ", head size " << iov[NN_NO0].iov_len); + size_t requestSize = 0; + for (uint16_t i = 0; i < req.iovCount; i++) { + iov[i + NN_NO1].iov_base = reinterpret_cast(req.iov[i].lAddress); + iov[i + NN_NO1].iov_len = req.iov[i].size; + NN_LOG_TRACE_INFO("iov index " << i + NN_NO1 << ", length " << req.iov[i].size); + requestSize += iov[i + NN_NO1].iov_len; + } + + std::lock_guard guard(mIoMutex); + ssize_t ret = 0; + if (!mEnableTls) { + if ((ret = writev(mFd, reinterpret_cast(&iov), NN_NO1 + req.iovCount)) < + static_cast(requestSize + sizeof(SockTransHeader))) { + if (ret == 0) { + return SS_TCP_RETRY; + } + if (errno == 0) { + NN_LOG_ERROR("Failed to PostSendSgl to peer in sock " << mId << " name " << mName << " with " << + mSendTimeoutSecond << " second timeout, " << ret << " is sent"); + return SS_TIMEOUT; + } + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to PostSendSgl to peer in sock " << mId << " name " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return SS_SOCK_SEND_FAILED; + } + } else { + if ((ret = writev(mFd, &iov[NN_NO0], NN_NO1)) < static_cast(sizeof(SockTransHeader))) { + if (ret == 0) { + return SS_TCP_RETRY; + } + if (errno == 0) { + NN_LOG_ERROR("(TLS)Failed to PostSendSgl header to peer in sock " << mId << " name " << mName << + " with " << mSendTimeoutSecond << " second timeout, " << ret << " is sent"); + return SS_TIMEOUT; + } + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("(TLS)Failed to PostSendSgl header to peer in sock " << mId << " name " << mName << + ", error " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE) << ", seqNo " << + header.seqNo); + return SS_SOCK_SEND_FAILED; + } + for (uint16_t i = 1; i < NN_NO1 + req.iovCount; i++) { + uint32_t writeLen = 0; + ret = SSLSend(iov[i].iov_base, iov[i].iov_len, writeLen); + if (ret == SS_TIMEOUT) { + NN_LOG_ERROR("(TLS)Failed to PostSendSgl body to peer in sock " << mId << " name " << mName << + ", error is timeout with " << mSendTimeoutSecond << " second, seqNo " << header.seqNo << + ", the failed iov data len " << iov[i].iov_len); + return SS_TIMEOUT; + } + if (ret != SS_OK || writeLen != iov[i].iov_len) { + NN_LOG_ERROR("(TLS)Failed to PostSendSgl body to peer in sock " << mId << " name " << mName << + ", seqNo " << header.seqNo << ", the failed iov data len " << iov[i].iov_len); + return SS_SOCK_SEND_FAILED; + } + } + } + NN_LOG_TRACE_INFO("PostSendSgl request successfully: sock " << mId << ", head imm data " << header.immData << + ", flags " << header.flags << ", seqNo " << header.seqNo << ", data len " << header.dataLength); + return SS_OK; + } + + inline SResult PostSendHead(SockOpContextInfo *ctx) + { + if (NN_UNLIKELY(ctx == nullptr)) { + return SS_PARAM_INVALID; + } + + if (mTcpBlockingMode) { + std::lock_guard guard(mIoMutex); + ssize_t ret = 0; + if ((ret = ::send(mFd, reinterpret_cast(&ctx->sendCtx->sendHeader), sizeof(SockTransHeader), 0)) < + static_cast(sizeof(SockTransHeader))) { + if (ret == 0) { + return SS_TCP_RETRY; + } + if (errno == 0) { + NN_LOG_ERROR("Failed to PostSendHead to peer in sock " << mId << " name " << mName << " with " << + mSendTimeoutSecond << " second, " << ret << " is sent"); + return SS_TIMEOUT; + } + NN_LOG_ERROR("Failed to PostSendHead to peer in sock " << mId << " name " << mName << ", errno " << + errno); + return SS_SOCK_SEND_FAILED; + } + NN_LOG_TRACE_INFO("Post send head successfully: sock " << mId << ", head imm data " << + ctx->sendCtx->sendHeader.immData << ", flags " << ctx->sendCtx->sendHeader.flags << ", seqNo " << + ctx->sendCtx->sendHeader.seqNo); + return SS_OK; + } else { + mSendQueue.PushBack(ctx); + return SS_SOCK_SEND_EAGAIN; + } + } + + inline SResult PostRead(SockOpContextInfo *ctx) + { + if (NN_UNLIKELY(ctx == nullptr)) { + return SS_PARAM_INVALID; + } + + if (mTcpBlockingMode) { + auto sendCtx = ctx->sendCtx; + struct iovec iov[NN_NO2]; + iov[NN_NO0].iov_base = reinterpret_cast(&sendCtx->sendHeader); + iov[NN_NO0].iov_len = sizeof(SockTransHeader); + iov[NN_NO1].iov_base = reinterpret_cast(&sendCtx->iov[0]); + iov[NN_NO1].iov_len = sizeof(UBSHcomNetTransSgeIov); + + auto length = iov[NN_NO0].iov_len + iov[NN_NO1].iov_len; + std::lock_guard guard(mIoMutex); + ssize_t ret = 0; + if ((ret = writev(mFd, reinterpret_cast(&iov), NN_NO2)) < + static_cast(length)) { + if (ret == 0) { + return SS_TCP_RETRY; + } + + if (errno == 0) { + NN_LOG_ERROR("Failed to PostRead to peer in sock " << mId << " name " << mName << " with " << + mSendTimeoutSecond << " second timeout, " << ret << " is sent"); + return SS_TIMEOUT; + } + + NN_LOG_ERROR("Failed to PostRead to peer in sock " << mId << " name " << mName << ", errno " << errno << + ", seqNo " << sendCtx->sendHeader.seqNo); + return SS_SOCK_SEND_FAILED; + } + NN_LOG_TRACE_INFO("PostRead successfully: sock " << mId << ", head imm data " << + sendCtx->sendHeader.immData << ", flags " << sendCtx->sendHeader.flags << ", seqNo " << + sendCtx->sendHeader.seqNo << ", data len " << sendCtx->sendHeader.dataLength); + return SS_OK; + } else { + mSendQueue.PushBack(ctx); + return SS_SOCK_SEND_EAGAIN; + } + } + + inline SResult PostWrite(SockOpContextInfo *ctx) + { + if (NN_UNLIKELY(ctx == nullptr)) { + return SS_PARAM_INVALID; + } + + if (mTcpBlockingMode) { + auto sendCtx = ctx->sendCtx; + struct iovec iov[NN_NO3]; + iov[NN_NO0].iov_base = reinterpret_cast(&sendCtx->sendHeader); + iov[NN_NO0].iov_len = sizeof(SockTransHeader); + iov[NN_NO1].iov_base = reinterpret_cast(&sendCtx->iov[0]); + iov[NN_NO1].iov_len = sizeof(UBSHcomNetTransSgeIov); + iov[NN_NO2].iov_base = reinterpret_cast(sendCtx->iov[0].lAddress); + iov[NN_NO2].iov_len = sendCtx->iov[0].size; + + auto length = iov[NN_NO0].iov_len + iov[NN_NO1].iov_len + iov[NN_NO2].iov_len; + std::lock_guard guard(mIoMutex); + ssize_t ret = 0; + if ((ret = writev(mFd, reinterpret_cast(&iov), NN_NO3)) < + static_cast(length)) { + if (ret == 0) { + return SS_TCP_RETRY; + } + if (errno == 0) { + NN_LOG_ERROR("Failed to PostSendSgl to peer in sock " << mId << " name " << mName << " with " << + mSendTimeoutSecond << " second timeout, " << ret << " is sent"); + return SS_TIMEOUT; + } + NN_LOG_ERROR("Failed to PostSendSgl to peer in sock " << mId << " name " << mName << ", errno " << + errno); + return SS_SOCK_SEND_FAILED; + } + + NN_LOG_TRACE_INFO("PostWrite request successfully : sock " << mId << ", head imm data " << + sendCtx->sendHeader.immData << ", flags " << sendCtx->sendHeader.flags << ", seqNo " << + sendCtx->sendHeader.seqNo << ", data len " << sendCtx->sendHeader.dataLength); + + return SS_OK; + } else { + mSendQueue.PushBack(ctx); + return SS_SOCK_SEND_EAGAIN; + } + } + + inline SResult PostReceiveHeader(SockTransHeader &header, int32_t timeoutSecond = 0) + { + if (NN_UNLIKELY(mRevTimeoutSecond != timeoutSecond)) { + mRevTimeoutSecond = timeoutSecond; + timeoutSecond = timeoutSecond > 0 ? timeoutSecond : timeoutSecond == 0 ? -1 : 0; + struct timeval timeout = { timeoutSecond, 0 }; + if (NN_UNLIKELY( + setsockopt(mFd, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&timeout), sizeof(timeval)) < 0)) { + return SS_TCP_SET_OPTION_FAILED; + } + } + { + std::lock_guard guard(mIoMutex); + ssize_t ret = 0; + uint32_t result = 0; + auto buff = reinterpret_cast(&header); + size_t remainingSize = sizeof(SockTransHeader); + while (result < sizeof(SockTransHeader)) { + ret = ::recv(mFd, buff, remainingSize, 0); + if (errno == EAGAIN) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to PostReceiveHeader from peer in sock " << mId << " name " << mName << + ", error " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE) << + " due to timeout with " << mRevTimeoutSecond << " second, " << ret << " is received"); + return SS_TIMEOUT; + } + if (ret <= 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to PostReceiveHeader from peer in sock " << mId << " name " << mName << + ", error " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return SS_SOCK_SEND_FAILED; + } + + buff = reinterpret_cast(reinterpret_cast(buff) + ret); + result += static_cast(ret); + remainingSize -= static_cast(ret); + } + } + + auto result = NetFunc::ValidateHeader(header); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("Failed to validate received header, ep " << Id()); + return result; + } + NN_LOG_TRACE_INFO("PostReceiveHeader from peer successfully: sock " << mId << ", head imm data " << + header.immData << ", flags " << header.flags << ", seqNo " << header.seqNo); + return SS_OK; + } + + inline SResult PostReceiveBody(void *buff, uint32_t dataLength, bool isOneSide) + { + if (NN_UNLIKELY(buff == nullptr || dataLength == 0)) { + return SS_PARAM_INVALID; + } + + std::lock_guard guard(mIoMutex); + ssize_t ret = 0; + uint32_t result = 0; + size_t remainingSize = static_cast(dataLength); + while (result < dataLength) { + if (!mEnableTls || isOneSide) { + ret = ::recv(mFd, buff, remainingSize, 0); + if (errno == EAGAIN) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to PostReceiveBody from peer in sock " << mId << " name " << mName << + ", error " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE) << + " due to timeout with " << mRevTimeoutSecond << " second, " << ret << " is received"); + return SS_TIMEOUT; + } + if (ret <= 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to PostReceiveBody from peer in sock " << mId << " name " << mName << + ", error " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return SS_SOCK_SEND_FAILED; + } + } else { + auto readResult = SSLRead(buff, remainingSize, reinterpret_cast(ret)); + if (readResult == SS_TIMEOUT) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("(TLS)Failed to PostReceiveBody from peer in sock " << mId << " name " << mName << + ", error " << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE) << + " due to timeout with " << mRevTimeoutSecond << " second, " << ret << " is received"); + return SS_TIMEOUT; + } + + if (readResult != SS_OK) { + NN_LOG_ERROR("(TLS)Failed to PostReceiveBody from peer in sock " << mId << " name " << mName); + return SS_SSL_READ_FAILED; + } + } + + buff = reinterpret_cast(reinterpret_cast(buff) + ret); + result += static_cast(ret); + remainingSize -= static_cast(ret); + } + + return SS_OK; + } + + inline SResult PostReadSgl(SockOpContextInfo *ctx) + { + if (NN_UNLIKELY(ctx == nullptr)) { + return SS_PARAM_INVALID; + } + + if (mTcpBlockingMode) { + auto sendCtx = ctx->sendCtx; + struct iovec iov[NN_NO3]; + iov[NN_NO0].iov_base = reinterpret_cast(&sendCtx->sendHeader); + iov[NN_NO0].iov_len = sizeof(SockTransHeader); + iov[NN_NO1].iov_base = reinterpret_cast(&sendCtx->iovCount); + iov[NN_NO1].iov_len = sizeof(UBSHcomNetTransSglRequest::iovCount); + iov[NN_NO2].iov_base = reinterpret_cast(sendCtx->iov); + iov[NN_NO2].iov_len = sizeof(UBSHcomNetTransSgeIov) * sendCtx->iovCount; + + auto length = iov[NN_NO0].iov_len + iov[NN_NO1].iov_len + iov[NN_NO2].iov_len; + std::lock_guard guard(mIoMutex); + ssize_t ret = 0; + if ((ret = writev(mFd, reinterpret_cast(&iov), NN_NO3)) < + static_cast(length)) { + if (ret == 0) { + return SS_TCP_RETRY; + } + + if (errno == 0) { + NN_LOG_ERROR("Failed to PostReadSgl to peer in sock " << mId << " name " << mName << " with " << + mSendTimeoutSecond << " second timeout, " << ret << " is sent"); + return SS_TIMEOUT; + } + + NN_LOG_ERROR("Failed to PostReadSgl to peer in sock " << mId << " name " << mName << ", errno " << + errno << ", seqNo " << sendCtx->sendHeader.seqNo); + return SS_SOCK_SEND_FAILED; + } + NN_LOG_TRACE_INFO("PostReadSgl successfully: sock " << mId << ", head imm data " << + sendCtx->sendHeader.immData << ", flags " << sendCtx->sendHeader.flags << ", seqNo " << + sendCtx->sendHeader.seqNo << ", data len " << sendCtx->sendHeader.dataLength); + return SS_OK; + } else { + mSendQueue.PushBack(ctx); + return SS_SOCK_SEND_EAGAIN; + } + } + + inline SResult PostReadSglAck(SockOpContextInfo *ctx) + { + if (NN_UNLIKELY(ctx == nullptr)) { + return SS_PARAM_INVALID; + } + + if (mTcpBlockingMode) { + auto sendCtx = ctx->sendCtx; + struct iovec iov[NN_NO7]; + iov[NN_NO0].iov_base = reinterpret_cast(&sendCtx->sendHeader); + iov[NN_NO0].iov_len = sizeof(SockTransHeader); + iov[NN_NO1].iov_base = reinterpret_cast(&sendCtx->iovCount); + iov[NN_NO1].iov_len = sizeof(UBSHcomNetTransSglRequest::iovCount); + iov[NN_NO2].iov_base = reinterpret_cast(sendCtx->iov); + iov[NN_NO2].iov_len = sizeof(UBSHcomNetTransSgeIov) * sendCtx->iovCount; + + auto length = iov[NN_NO0].iov_len + iov[NN_NO1].iov_len + iov[NN_NO2].iov_len; + for (uint16_t i = 0; i < sendCtx->iovCount; i++) { + iov[i + NN_NO3].iov_base = reinterpret_cast(sendCtx->iov[i].rAddress); + iov[i + NN_NO3].iov_len = sendCtx->iov[i].size; + length += iov[i + NN_NO3].iov_len; + } + + std::lock_guard guard(mIoMutex); + ssize_t ret = 0; + if ((ret = writev(mFd, reinterpret_cast(&iov), NN_NO3 + sendCtx->iovCount)) < + static_cast(length)) { + if (ret == 0) { + return SS_TCP_RETRY; + } + + if (errno == 0) { + NN_LOG_ERROR("Failed to PostReadSglAck to peer in sock " << mId << " name " << mName << " with " << + mSendTimeoutSecond << " second timeout, " << ret << " is sent"); + return SS_TIMEOUT; + } + + NN_LOG_ERROR("Failed to PostReadSglAck to peer in sock " << mId << " name " << mName << ", errno " << + errno << ", seqNo " << sendCtx->sendHeader.seqNo); + return SS_SOCK_SEND_FAILED; + } + + NN_LOG_TRACE_INFO("PostReadSglAck successfully: sock " << mId << ", head imm data " << + sendCtx->sendHeader.immData << ", flags " << sendCtx->sendHeader.flags << ", seqNo " << + sendCtx->sendHeader.seqNo << ", data len " << sendCtx->sendHeader.dataLength); + return SS_OK; + } else { + mSendQueue.PushBack(ctx); + return SS_SOCK_SEND_EAGAIN; + } + } + + inline SResult PostWriteSgl(SockOpContextInfo *ctx) + { + if (NN_UNLIKELY(ctx == nullptr)) { + return SS_PARAM_INVALID; + } + + if (mTcpBlockingMode) { + auto sendCtx = ctx->sendCtx; + struct iovec iov[NN_NO7]; + iov[NN_NO0].iov_base = reinterpret_cast(&sendCtx->sendHeader); + iov[NN_NO0].iov_len = sizeof(SockTransHeader); + iov[NN_NO1].iov_base = reinterpret_cast(&sendCtx->iovCount); + iov[NN_NO1].iov_len = sizeof(UBSHcomNetTransSglRequest::iovCount); + iov[NN_NO2].iov_base = reinterpret_cast(sendCtx->iov); + iov[NN_NO2].iov_len = sizeof(UBSHcomNetTransSgeIov) * sendCtx->iovCount; + + for (uint16_t i = 0; i < sendCtx->iovCount; i++) { + iov[i + NN_NO3].iov_base = reinterpret_cast(sendCtx->iov[i].lAddress); + iov[i + NN_NO3].iov_len = sendCtx->iov[i].size; + } + + std::lock_guard guard(mIoMutex); + ssize_t ret = 0; + if ((ret = writev(mFd, reinterpret_cast(&iov), NN_NO3 + sendCtx->iovCount)) < + static_cast(sendCtx->sendHeader.dataLength + sizeof(SockTransHeader))) { + if (ret == 0) { + return SS_TCP_RETRY; + } + + if (errno == 0) { + NN_LOG_ERROR("Failed to PostWriteSgl to peer in sock " << mId << " name " << mName << " with " << + mSendTimeoutSecond << " second timeout, " << ret << " is sent"); + return SS_TIMEOUT; + } + + NN_LOG_ERROR("Failed to PostWriteSgl to peer in sock " << mId << " name " << mName << ", errno " << + errno << ", seqNo " << sendCtx->sendHeader.seqNo); + return SS_SOCK_SEND_FAILED; + } + + NN_LOG_TRACE_INFO("PostWriteSgl successfully: sock " << mId << ", head imm data " << + sendCtx->sendHeader.immData << ", flags " << sendCtx->sendHeader.flags << ", seqNo " << + sendCtx->sendHeader.seqNo << ", data len " << sendCtx->sendHeader.dataLength); + return SS_OK; + } else { + mSendQueue.PushBack(ctx); + return SS_SOCK_SEND_EAGAIN; + } + } + +#define NO_BODY_FLAG(flag) ((flag) == NTH_WRITE_ACK || (flag) == NTH_WRITE_SGL_ACK) + +#define RECEIVE_HEADER(result, fullReceived) \ + do { \ + result = ::recv(mFd, reinterpret_cast(headDataPtr + mReceiveState.ReceivedHeaderLen()), \ + mReceiveState.headerToBeReceived, 0); \ + if (NN_LIKELY((result) > 0)) { \ + /* header is full */ \ + if (mReceiveState.HeaderSatisfied(result)) { \ + if (NN_UNLIKELY(NetFunc::ValidateHeader(mHeader) != NN_OK)) { \ + NN_LOG_ERROR("Failed to validate received header param, sock " << mId); \ + return SockOpContextInfo::SS_OPERATE_FAILURE; \ + } \ + /* set body len to be received to the value in header */ \ + mReceiveState.bodyToBeReceived = static_cast(mHeader.dataLength); \ + /* expand memory size */ \ + if (NN_UNLIKELY(!mReceiveBuff.ExpandIfNeed(mHeader.dataLength))) { \ + NN_LOG_ERROR("Failed to expand receive buffer to " << mHeader.dataLength << \ + ", probably out of memory"); \ + return SockOpContextInfo::SS_OUT_OF_MEM; \ + } \ + \ + /* set actually body data to 0 */ \ + mReceiveBuff.ActualDataSize(0); \ + /* if head only message do upper callback directly */ \ + fullReceived = (mHeader.dataLength == 0); \ + NN_LOG_TRACE_INFO("Receive sock " << mId << " head imm data " << mHeader.immData << ", flags " << \ + mHeader.flags << ", seqNo " << mHeader.seqNo << ", data len " << mHeader.dataLength); \ + if ((fullReceived) == true || NO_BODY_FLAG(mHeader.flags)) { \ + mReceiveState.ResetHeader(); \ + fullReceived = true; \ + return SockOpContextInfo::SS_NO_ERROR; \ + } \ + } else { \ + return SockOpContextInfo::SS_NO_ERROR; /* header is not fully received, continue to receive */ \ + } \ + } else { \ + /* ECONNRESET is broken during io, SUCCESS is broken during idle time. */ \ + if (errno == ECONNRESET || errno == 0) { \ + NN_LOG_WARN("Sock " << mId << " does not receive data header, connection " \ + << " reset by peer, errno " << errno); \ + return SockOpContextInfo::SS_RESET_BY_PEER; /* socket is closed by peer, socket is error */ \ + } \ + /* if errno is eagain is normal, need to continue to receive */ \ + /* else meaning failed to read from socket, socket is error */ \ + if (errno != EAGAIN) { \ + NN_LOG_ERROR("sock " << mId << " receive header failed, errno " << errno); \ + } \ + return (errno == EAGAIN ? SockOpContextInfo::SS_NO_ERROR : SockOpContextInfo::SS_OPERATE_FAILURE); \ + } \ + } while (0) + +#define RECEIVE_BODY(result, fullReceived) \ + do { \ + /* receive body */ \ + auto dataPtr = \ + mReceiveBuff.DataIntPtr() + (mHeader.dataLength - static_cast(mReceiveState.bodyToBeReceived)); \ + if (mEnableTls && ((mHeader.flags & 0xff) == NTH_TWO_SIDE || (mHeader.flags & 0xff) == NTH_TWO_SIDE_SGL)) { \ + auto readRet = SSLRead(reinterpret_cast(dataPtr), mReceiveState.bodyToBeReceived, \ + reinterpret_cast(result)); \ + if (readRet != SS_OK) { \ + result = -1; \ + } \ + } else { \ + result = ::recv(mFd, reinterpret_cast(dataPtr), mReceiveState.bodyToBeReceived, 0); \ + } \ + if (NN_LIKELY((result) > 0)) { \ + /* body is full */ \ + if (mReceiveState.BodySatisfied(result)) { \ + mReceiveState.ResetHeader(); \ + mReceiveBuff.ActualDataSize(mHeader.dataLength); \ + fullReceived = true; \ + NN_LOG_TRACE_INFO("Receive sock " << mId << " full body size " << mHeader.dataLength); \ + return SockOpContextInfo::SS_NO_ERROR; \ + } \ + \ + NN_LOG_TRACE_INFO("Receive sock " << mId << " not full body size " << mReceiveState.bodyToBeReceived); \ + /* body is not fully received, continue to receive */ \ + return SockOpContextInfo::SS_NO_ERROR; \ + } else { \ + /* ECONNRESET is broken during io, SUCCESS is broken during idle time. */ \ + if (errno == ECONNRESET || errno == 0) { \ + NN_LOG_WARN("Sock " << mId << " does not receive data body, connection " \ + << " reset by peer, errno " << errno); \ + return SockOpContextInfo::SS_RESET_BY_PEER; /* socket is closed by peer, socket is error */ \ + } \ + /* if errno is eagain is normal, need to continue to receive */ \ + /* else meaning failed to read from socket, socket is error */ \ + if (errno != EAGAIN) { \ + NN_LOG_ERROR("sock " << mId << " receive body failed, errno " << errno); \ + } \ + return (errno == EAGAIN ? SockOpContextInfo::SS_NO_ERROR : SockOpContextInfo::SS_OPERATE_FAILURE); \ + } \ + } while (0) + + /* + * @brief Receive data when data is received + * + * @param fullReceived [out] if header and body are both received, can do upper call + * + * @param return true if socket is ok, otherwise the socket is broken then need to do connection broken process + * + */ + inline SockOpContextInfo::SockErrorType HandleIn(bool &fullReceived) + { + const auto headDataPtr = reinterpret_cast(&mHeader); + + fullReceived = false; + + /* receive header */ + ssize_t result = 0; + if (mReceiveState.ShouldReceiveHeader()) { + RECEIVE_HEADER(result, fullReceived); + } + + RECEIVE_BODY(result, fullReceived); + } + + inline bool HandleOut() + { + return true; + } + + inline void SetSockOpContextInfoPool(SockOpContextInfoPool opCtxInfoPool) + { + mOpCtxInfoPool = opCtxInfoPool; + } + + inline void SetSockSglContextInfoPool(SockSglContextInfoPool sglCtxInfoPool) + { + mSglCtxInfoPool = sglCtxInfoPool; + } + + inline void SetSockHeaderReqInfoPool(SockHeaderReqInfoPool headerReqInfoPool) + { + mHeaderReqInfoPool = headerReqInfoPool; + } + + inline void SetSockDriverSendMR(NormalMemoryRegionFixedBuffer *sockDriverSendMR) + { + mSockDriverSendMR = sockDriverSendMR; + } + + inline void SetSockPostedHandler(SockPostedHandler sockPostedHandler) + { + mSendPostedHandler = sockPostedHandler; + } + + inline void SetSockOneSideHandler(SockOneSideHandler oneSideDoneHandler) + { + mOneSideDoneHandler = oneSideDoneHandler; + } + + inline void SetMrChecker(MemoryRegionChecker *checker) + { + mMrChecker = checker; + } + + inline bool GetQueueSpace(uint32_t times = NN_NO8, uint32_t sleepUs = NN_NO64) + { + while (times-- > 0) { + if (NN_LIKELY(__sync_sub_and_fetch(&mQueueVacantSize, NN_NO1) >= 0)) { + return true; + } + __sync_add_and_fetch(&mQueueVacantSize, NN_NO1); + usleep(sleepUs); + } + return false; + } + + inline void ReturnQueueSpace(uint16_t size) + { + int32_t ref = __sync_add_and_fetch(&mQueueVacantSize, size); + if (ref > mQueueSize) { + NN_LOG_WARN("Queue size " << ref << " over capacity " << mQueueSize); + } + } + + std::string ToString() + { + std::ostringstream oss; + oss << "info [type " << SockTypeToString(mType) << ", name " << mName << ", id " << mId << ", peer-ip-port " << + mPeerIpPort << ", up-ctx: " << mUpCtx << ", up-ctx1: " << mUpCtx1 << ", rev-buff-size: " << + mReceiveBuff.Size(); + return oss.str(); + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +protected: + SResult SetSockOption(const SockWorkerOptions &workerOptions); + SResult ValidateOptions(); + SResult SetBlockingIo(); + SResult SetNonBlockingIo(); + + /* + * @brief Set ip and port of peer + */ + inline void PeerIpPort(const std::string &value) + { + mPeerIpPort = value; + } + + /* + * @brief store connect info + */ + inline void StoreConnInfo(uint32_t localIp, uint16_t listenPort, uint8_t version) + { + mLocalIp = localIp; + mListenPort = listenPort; + mVersion = version; + } + + int SSLSend(const void *buf, uint32_t size, uint32_t &writeLen) + { + int ret = HcomSsl::SslWrite(mSsl, buf, size); + if (ret <= 0) { + int sslErrCode = HcomSsl::SslGetError(mSsl, ret); + if (sslErrCode == HcomSsl::SSL_ERROR_WANT_WRITE) { + return SS_TIMEOUT; + } + NN_LOG_ERROR("Failed to write data to TLS channel, ret: " << ret << ", errno: " << sslErrCode << + " write Len: " << size); + return SS_OOB_SSL_WRITE_ERROR; + } + writeLen = static_cast(ret); + return SS_OK; + } + + SResult SSLRead(void *buff, size_t dataLength, uint32_t &readLen) + { + auto ret = HcomSsl::SslRead(mSsl, buff, dataLength); + if (ret <= 0) { + int sslErrCode = HcomSsl::SslGetError(mSsl, ret); + if (sslErrCode == HcomSsl::SSL_ERROR_WANT_READ) { + return SS_TIMEOUT; + } + NN_LOG_ERROR("SSL read failed sock id " << mId << " name " << mName << ", error " << sslErrCode); + return SS_SSL_READ_FAILED; + } + readLen = static_cast(ret); + return SS_OK; + } + +protected: + int mFd = -1; /* socket fd */ + uint64_t mUpCtx = 0; /* up context */ + uint64_t mUpCtx1 = 0; /* up context 1 */ + SockBuff mReceiveBuff; /* one extendable receive buffer */ + SockTransHeader mHeader; /* sock command header */ + SockReceiveState mReceiveState {}; /* receive data status */ + bool mCbByWorkerInBlocking = false; /* worker call send post cb for blocking io */ + bool mTcpBlockingMode = true; /* tcp mode: nonblocking in default */ + int64_t mQueueVacantSize = 0; + int64_t mQueueSize = 0; + std::mutex mInitMutex; + SockOptions mOptions; /* sock options */ + SSL *mSsl = nullptr; + NetSecrets mSecret; + bool mEnableTls = true; + std::string mName; /* name of sock */ + std::string mPeerIpPort; /* peer ip and port */ + uint32_t mLocalIp = INVALID_IP; + uint16_t mListenPort = 0; + uint8_t mVersion = 0; + uint64_t mId = 0; /* uid */ + SockType mType = SOCK_TCP; /* sock type */ + bool mInited = false; /* inited or not */ + + std::mutex mIoMutex; + + uint32_t mSeqIndex = 1; + std::mutex mCtxMutex; /* op context mutex */ + std::unordered_map mCtxMap; /* op context map */ + + NetRingBuffer mSendQueue; + SendingQueueRequest mSendingQueueRequest; + + SockOpContextInfoPool mOpCtxInfoPool; + SockSglContextInfoPool mSglCtxInfoPool; + SockHeaderReqInfoPool mHeaderReqInfoPool; + SockPostedHandler mSendPostedHandler = nullptr; + SockOneSideHandler mOneSideDoneHandler = nullptr; + NormalMemoryRegionFixedBuffer *mSockDriverSendMR = nullptr; + MemoryRegionChecker *mMrChecker = nullptr; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend class SockWorker; + friend class NetDriverSockWithOOB; + +private: + int32_t mSendTimeoutSecond = -1; + int32_t mRevTimeoutSecond = -1; +}; +} +} + +#endif // OCK_HCOM_SOCK_WRAPPER_H_234234 diff --git a/src/transport/ub/net_ub_driver.cpp b/src/transport/ub/net_ub_driver.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bce734100a84a38eac0215a65db0f3e3de67a7e3 --- /dev/null +++ b/src/transport/ub/net_ub_driver.cpp @@ -0,0 +1,735 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED + +#include "hcom_def.h" +#include "net_ub_driver.h" +#include "net_ub_endpoint.h" +#include "openssl_api_wrapper.h" +#include "under_api/obmm/obmm_api_wrapper.h" +#include "ub_common.h" +#include "ub_mr_fixed_buf.h" +#include "ub_worker.h" + +namespace ock { +namespace hcom { +NResult NetDriverUB::Initialize(const UBSHcomNetDriverOptions &option) +{ + std::lock_guard lock(mInitMutex); + if (mInited) { + return NN_OK; + } + + mOptions = option; + + if (NN_UNLIKELY(UBSHcomNetOutLogger::Instance() == nullptr)) { + return NN_NOT_INITIALIZED; + } + + NResult result = NN_OK; + if (NN_UNLIKELY((result = mOptions.ValidateCommonOptions()) != NN_OK)) { + return result; + } + + if (NN_UNLIKELY((result = ValidateOptions()) != NN_OK)) { + return result; + } + + NN_LOG_INFO("Try to initialize with " << mOptions.ToString()); + + if (option.enableTls) { + if (HcomSsl::Load() != 0) { + NN_LOG_ERROR("Failed to load openssl API"); + return NN_NOT_INITIALIZED; + } + } + mEnableTls = option.enableTls; + mHeartBeatIdleTime = mOptions.heartBeatIdleTime; + mHeartBeatProbeInterval = mOptions.heartBeatProbeInterval; + + // create context and initialize + if (((result = CreateContext()) != NN_OK)) { + NN_LOG_ERROR("UB failed to create ctx"); + UnInitializeInner(); + return result; + } + + if (((result = mContext->Initialize()) != 0)) { + NN_LOG_ERROR("UB failed to initialize ctx"); + UnInitializeInner(); + return result; + } + + if ((result = ValidaQpQueueSizeOptions()) != NN_OK) { + NN_LOG_ERROR("UB failed to validate qp queue size options"); + UnInitializeInner(); + return result; + } + + if ((result = CreateWorkerResource()) != NN_OK) { + NN_LOG_ERROR("UB failed to create worker resource"); + UnInitializeInner(); + return result; + } + + if ((result = CreateWorkers()) != NN_OK) { + NN_LOG_ERROR("UB failed to create workers"); + UnInitializeInner(); + return result; + } + + /* create lb for client */ + if ((result = CreateClientLB()) != NN_OK) { + NN_LOG_ERROR("UB failed to create client lb"); + UnInitializeInner(); + return result; + } + + if ((result = DoInitialize()) != NN_OK) { + NN_LOG_ERROR("UB failed to do initialize"); + UnInitializeInner(); + return result; + } + + mMrChecker.Reserve(NN_NO128); + mMrChecker.SetLockWhenOperates(false); + + mInited = true; + return NN_OK; +} + +NResult NetDriverUB::ValidateOptions() +{ + /* validate param related to device IpMask for UB and Sock */ + if (NN_UNLIKELY(!ValidateArrayOptions(mOptions.netDeviceIpMask, NN_NO256))) { + NN_LOG_ERROR("Option 'netDeviceIpMask' is invalid, " << mOptions.netDeviceIpMask << + " is set in driver,the Array max length is 256."); + return NN_INVALID_PARAM; + } + + uint64_t sendRecvMrSize = static_cast(mOptions.mrSendReceiveSegCount) * mOptions.mrSendReceiveSegSize; + + if (mOptions.prePostReceiveSizePerQP == 0) { + NN_LOG_ERROR("Invalid option prePostReceiveSizePerQP " << mOptions.prePostReceiveSizePerQP << + ", should not be zero"); + return NN_INVALID_PARAM; + } + + // 32K 为硬件 max_jfr_depth 上限 + if (mOptions.prePostReceiveSizePerQP > NN_NO32768) { + NN_LOG_WARN("Invalid option prePostReceiveSizePerQP " << mOptions.prePostReceiveSizePerQP << + ", should be <= " << NN_NO32768 << ", set to " << NN_NO32768); + mOptions.prePostReceiveSizePerQP = NN_NO32768; + } + + if (mOptions.maxPostSendCountPerQP == 0) { + NN_LOG_ERROR("Invalid option maxPostSendCountPerQP " << mOptions.maxPostSendCountPerQP << + ", should not be zero"); + return NN_INVALID_PARAM; + } + + if (mOptions.maxPostSendCountPerQP > NN_NO32768) { + NN_LOG_WARN("Invalid option maxPostSendCountPerQP " << mOptions.maxPostSendCountPerQP << ", should be <= " << + NN_NO32768 << ", set to " << NN_NO32768); + mOptions.maxPostSendCountPerQP = NN_NO32768; + } + + if (mOptions.maxPostSendCountPerQP > mOptions.prePostReceiveSizePerQP) { + NN_LOG_WARN("Invalid option maxPostSendCountPerQP " << mOptions.maxPostSendCountPerQP << + ", over than prePostReceiveSizePerQP " << mOptions.prePostReceiveSizePerQP << " , change to equal"); + mOptions.maxPostSendCountPerQP = mOptions.prePostReceiveSizePerQP; + } + + if (NN_UNLIKELY(ValidateAndParseOobPortRange(mOptions.oobPortRange) != NN_OK)) { + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(ValidateOptionsOobType() != NN_OK)) { + return NN_INVALID_PARAM; + } + + return NN_OK; +} + +NResult NetDriverUB::ValidaQpQueueSizeOptions() +{ + if (mContext == nullptr) { + NN_LOG_WARN("Unable to get system max jfs and max jfr, cannot compare with Option 'qpSendQueueSize' and " + "'qpReceiveQueueSize'."); + return NN_OK; + } + uint32_t maxJfs = mContext->GetMaxJfs(); + uint32_t maxJfr = mContext->GetMaxJfr(); + if (NN_UNLIKELY(maxJfs < NN_NO8 || maxJfr < NN_NO8)) { + NN_LOG_ERROR("Urma max Jfs and max jfr less than or equal to 8. "); + return NN_ERROR; + } + uint32_t needJfs = maxJfs - NN_NO8; + uint32_t needJfr = maxJfr - NN_NO8; + if (mOptions.qpSendQueueSize > needJfs) { + NN_LOG_WARN("Urma max Jfs is " << maxJfs << " , urma option 'qpSendQueueSize' range is 16~" << needJfs << + " ,change 'qpSendQueueSize' to " << needJfs); + mOptions.qpSendQueueSize = needJfs; + } + if (mOptions.qpReceiveQueueSize > needJfr) { + NN_LOG_WARN("Urma max Jfr is " << maxJfr << " , urma option 'qpReceiveQueueSize' range is 16~" << needJfr << + " ,change 'qpReceiveQueueSize' to " << needJfr); + mOptions.qpReceiveQueueSize = needJfr; + } + return NN_OK; +} + +NResult NetDriverUB::GetDeviceByIp(UBEId &tmpEid) +{ + int result = 0; + if (mOptions.enableMultiRail) { + uint16_t enableCount = 0; + std::vector enableIps; + result = UBDeviceHelper::GetEnableDeviceCount(mOptions.NetDeviceIpMask(), enableCount, enableIps, + mOptions.NetDeviceIpGroup()); + if (result != NN_OK) { + return result; + } + mMatchIp = enableIps[mDevIndex]; + } else { + // filter ip by mask + std::vector filters; + NetFunc::NN_SplitStr(mOptions.NetDeviceIpMask(), ",", filters); + if (filters.empty()) { + NN_LOG_ERROR("Invalid ip mask '" << mOptions.netDeviceIpMask << "' by set, example '192.168.0.0/24'"); + return NN_INVALID_IP; + } + + std::vector matchIps; + for (auto &mask : filters) { + FilterIp(mask, matchIps); + } + + if (matchIps.empty()) { + NN_LOG_ERROR("No matched ip found with '" << mOptions.netDeviceIpMask << "', example '192.168.0.0/24'"); + return NN_INVALID_IP; + } + // init urma devices + if ((result = UBDeviceHelper::Initialize()) != 0) { + NN_LOG_ERROR("Failed to init devices"); + return result; + } + + NN_LOG_INFO(UBDeviceHelper::DeviceInfo()); + + // choose the first matched ip + mMatchIp = matchIps[0]; + } + + if ((result = UBDeviceHelper::GetDeviceByIp(mMatchIp, tmpEid)) != 0) { + UBDeviceHelper::UnInitialize(); + NN_LOG_ERROR("Failed to get device by ip"); + return result; + } + return NN_OK; +} + +NResult NetDriverUB::GetDeviceByEid(UBEId &tmpEid) +{ + int result = 0; + if (Protocol() != UBSHcomNetDriverProtocol::UBC) { + NN_LOG_ERROR("UBSHcomUbcMode should only be enabled on UBC protocol"); + return NN_ERROR; + } + + // init urma devices + if ((result = UBDeviceHelper::Initialize()) != 0) { + NN_LOG_ERROR("Failed to init devices"); + return result; + } + + NN_LOG_INFO(UBDeviceHelper::DeviceInfo()); + + if ((result = UBDeviceHelper::GetDeviceByEid(mOptions.netDeviceEid, tmpEid)) != 0) { + UBDeviceHelper::UnInitialize(); + NN_LOG_ERROR("Failed to get device by eid"); + return result; + } + return NN_OK; +} + +NResult NetDriverUB::GetDeviceByName(UBEId &tmpEid) +{ + int result = 0; + // init urma devices + if ((result = UBDeviceHelper::Initialize()) != 0) { + NN_LOG_ERROR("Failed to init devices"); + return result; + } + + NN_LOG_INFO(UBDeviceHelper::DeviceInfo()); + // hard code + char name[] = "bonding"; + uint8_t len = strlen(name); + if ((result = UBDeviceHelper::GetDeviceByName(name, len, tmpEid)) != 0) { + UBDeviceHelper::UnInitialize(); + NN_LOG_ERROR("Failed to get device by name"); + return result; + } + + return NN_OK; +} + +NResult NetDriverUB::CreateContext() +{ + if (mContext != nullptr) { + return NN_OK; + } + + int result = 0; + UBEId tmpEid{}; + if (Protocol() == UBSHcomNetDriverProtocol::UBC) { + if (GetDeviceByName(tmpEid) != 0) { + NN_LOG_ERROR("Failed to get device by name"); + return NN_ERROR; + } + } else if (mOptions.oobType == NET_OOB_UB) { + if (GetDeviceByEid(tmpEid) != 0) { + NN_LOG_ERROR("Failed to get device by eid"); + return NN_ERROR; + } + } else { + if (GetDeviceByIp(tmpEid) != 0) { + NN_LOG_ERROR("Failed to get device by ip"); + return NN_ERROR; + } + } + + mBandWidth = tmpEid.bandWidth; + NN_LOG_INFO("eid found devIndex " << tmpEid.devIndex << ", eidIndex " << tmpEid.eidIndex); + + // create context + if ((result = UBContext::Create(mName, tmpEid, mContext)) != 0) { + UBDeviceHelper::UnInitialize(); + NN_LOG_ERROR("Failed to new ctx, result " << result); + return result; + } + + NN_ASSERT_LOG_RETURN(mContext != nullptr, NN_ERROR); + + mContext->IncreaseRef(); + mContext->protocol = Protocol(); + return NN_OK; +} + +NResult NetDriverUB::CreateSendMr(uint8_t slave) +{ + int result = 0; + // create mr pool for send/receive and initialize + if ((result = UBMemoryRegionFixedBuffer::Create(mName, mContext, mOptions.mrSendReceiveSegSize, + mOptions.mrSendReceiveSegCount, slave, mDriverSendMR)) != 0) { + NN_LOG_ERROR("Failed to create mr for send/receive in NetDriverUB " << mName << ", result " << result); + return result; + } + mDriverSendMR->IncreaseRef(); + if ((result = mDriverSendMR->Initialize()) != 0) { + NN_LOG_ERROR("Failed to initialize mr for send/receive in NetDriverUB " << mName << ", result " << result); + return result; + } + + return NN_OK; +} + +NResult NetDriverUB::CreateOpCtxMemPool() +{ + NetMemPoolFixedOptions options = {}; + options.superBlkSizeMB = NN_NO1; + options.minBlkSize = NN_NextPower2(sizeof(UBOpContextInfo)); + options.tcExpandBlkCnt = NN_NO64; + mOpCtxMemPool = new (std::nothrow) NetMemPoolFixed(mName, options); + if (mOpCtxMemPool.Get() == nullptr) { + NN_LOG_ERROR("Failed to create memory pool for UB op context pool " << mName << ", probably out of memory"); + return NN_INVALID_PARAM; + } + + auto result = mOpCtxMemPool->Initialize(); + if (result != NN_OK) { + mOpCtxMemPool.Set(nullptr); + NN_LOG_ERROR("Failed to initialize memory pool for UB op context pool " << mName << ", probably out of memory"); + return result; + } + + return NN_OK; +} + +NResult NetDriverUB::CreateSglCtxMemPool() +{ + NetMemPoolFixedOptions options = {}; + options.superBlkSizeMB = NN_NO1; + options.minBlkSize = NN_NextPower2(sizeof(UBSglContextInfo)); + options.tcExpandBlkCnt = NN_NO64; + mSglCtxMemPool = new (std::nothrow) NetMemPoolFixed(mName, options); + if (mSglCtxMemPool.Get() == nullptr) { + NN_LOG_ERROR("Failed to create memory pool for UB sgl op context in driver " << mName << + ", probably out of memory"); + return NN_INVALID_PARAM; + } + + auto result = mSglCtxMemPool->Initialize(); + if (result != NN_OK) { + mSglCtxMemPool.Set(nullptr); + NN_LOG_ERROR("Failed to initialize memory pool for UB sgl op context in driver " << mName << + ", probably out of memory"); + return result; + } + + return NN_OK; +} + +NResult NetDriverUB::CreateWorkerResource() +{ + auto result = CreateSendMr(mOptions.slave); + if (result != NN_OK) { + NN_LOG_ERROR("UB falied to create send mr"); + return result; + } + + result = CreateOpCtxMemPool(); + if (result != NN_OK) { + NN_LOG_ERROR("UB Failed to create op ctx memory pool"); + return result; + } + + result = CreateSglCtxMemPool(); + if (NN_UNLIKELY(result != NN_OK)) { + NN_LOG_ERROR("UB failed to create Sgl ctx memory pool"); + } + + return NN_OK; +} + +void NetDriverUB::ClearWorkers() +{ + mWorkerGroups.clear(); + for (auto worker : mWorkers) { + worker->DecreaseRef(); + } + mWorkers.clear(); +} + +void NetDriverUB::DestroyEndpoint(UBSHcomNetEndpointPtr &ep) +{ + if (ep.Get() == nullptr) { + NN_LOG_WARN("The ub ep is null already."); + return; + } + + NN_LOG_INFO("Destroy endpoint id " << ep->Id()); + mEndPointsMutex.lock(); + auto result = mEndPoints.erase(ep->Id()); + mEndPointsMutex.unlock(); + + if (result == 0) { + NN_LOG_WARN("Unable to destroy ub endpoint as ep " << ep->Id() << " doesn't exist, maybe cleaned already"); + return; + } + + ep.Set(nullptr); +} + +NResult NetDriverUB::CreateWorkers() +{ + NResult result = NN_OK; + + std::vector workerGroups; + std::vector> workerGroupCpus; + std::vector flatWorkerCpus; + std::vector workerThreadPriority; + + /* parse */ + if (!(NetFunc::NN_ParseWorkersGroups(mOptions.WorkGroups(), workerGroups)) || + !(NetFunc::NN_ParseWorkerGroupsCpus(mOptions.WorkerGroupCpus(), workerGroupCpus)) || + !(NetFunc::NN_FinalizeWorkerGroupCpus(workerGroups, workerGroupCpus, mOptions.mode != NET_BUSY_POLLING, + flatWorkerCpus)) || + !(NetFunc::NN_ParseWorkersGroupsThreadPriority(mOptions.WorkerGroupThreadPriority(), + workerThreadPriority, workerGroups.size()))) { + NN_LOG_ERROR("Failed to parse worker or cpu groups"); + return NN_INVALID_PARAM; + } + + UBWorkerOptions options; + options.SetValue(mOptions); + if ((mOptions.workerThreadPriority != 0) && (!workerThreadPriority.empty())) { + NN_LOG_WARN("Driver options 'workerThreadPriority' and 'workerGroupsThreadPriority' set all, preferential use " + "'workerGroupsThreadPriority'."); + } + + /* create workers */ + mWorkers.reserve(flatWorkerCpus.size()); + uint32_t groupIndex = 0; + UBSHcomNetWorkerIndex workerIndex{}; + uint16_t totalWorkerIndex = 0; + for (auto item : workerGroups) { + NN_LOG_TRACE_INFO("add worker " << groupIndex << ", item " << item); + /* The left of mWorkerGroups is the index of each group's first worker in the mWorkers */ + mWorkerGroups.emplace_back(totalWorkerIndex, item); + for (uint16_t i = 0; i < item; ++i) { + options.cpuId = flatWorkerCpus.at(totalWorkerIndex++); + if (!workerThreadPriority.empty()) { + options.threadPriority = workerThreadPriority[groupIndex]; + } + UBWorker *worker = nullptr; + if (NN_UNLIKELY( + (result = UBWorker::Create(mName, mContext, options, mOpCtxMemPool, mSglCtxMemPool, worker)) != 0)) { + return result; + } + + workerIndex.Set(i, groupIndex, mIndex); + worker->SetIndex(workerIndex); + + if (NN_UNLIKELY((result = worker->Initialize()) != NN_OK)) { + delete worker; + NN_LOG_ERROR("Failed to initialize UB worker in driver " << mName << ", result " << result); + return NN_NEW_OBJECT_FAILED; + } + + worker->IncreaseRef(); + mWorkers.push_back(worker); + } + ++groupIndex; + } + + std::ostringstream groupInfo; + groupInfo << "Worker group info :"; + for (auto item : mWorkerGroups) { + groupInfo << " [" << item.first << " : " << item.second << "] "; + } + NN_LOG_TRACE_INFO(groupInfo.str()); + return NN_OK; +} + +void NetDriverUB::UnInitialize() +{ + std::lock_guard locker(mInitMutex); + if (!mInited) { + return; + } + if (mStarted) { + NN_LOG_WARN("Unable to unInitialize ub driver " << mName << " which is not stopped"); + return; + } + + DoUnInitialize(); + + UnInitializeInner(); + mInited = false; +} + +void NetDriverUB::UnInitializeInner() +{ + if (mContext != nullptr) { + mContext->DecreaseRef(); + mContext = nullptr; + } + + if (mDriverSendMR != nullptr) { + mDriverSendMR->DecreaseRef(); + mDriverSendMR = nullptr; + } + + if (mOpCtxMemPool != nullptr) { + mOpCtxMemPool.Set(nullptr); + } + + DestroyClientLB(); + ClearWorkers(); + if (!mMapVaSgeForUB.empty()) { + for (const auto &pair : mMapVaSgeForUB) { + uint64_t va = pair.first; + UnmapVaForUB(va); + } + mMapVaSgeForUB.clear(); + } +} + +#define DRIVER_CHECK_HANDLES() \ + do { \ + if (mReceivedRequestHandler == nullptr) { \ + NN_LOG_ERROR("Failed to do start in Driver " << mName << ", as receivedRequestHandler is null"); \ + return NN_INVALID_PARAM; \ + } \ + \ + if (mRequestPostedHandler == nullptr) { \ + NN_LOG_ERROR("Failed to do start in Driver " << mName << ", as requestPostedHandler is null"); \ + return NN_INVALID_PARAM; \ + } \ + \ + if (mOneSideDoneHandler == nullptr) { \ + NN_LOG_ERROR("Failed to do start in Driver " << mName << ", as oneSideDoneHandler is null"); \ + return NN_INVALID_PARAM; \ + } \ + \ + if (mEndPointBrokenHandler == nullptr) { \ + NN_LOG_ERROR("Failed to do start in Driver " << mName << ", as endPointBrokenHandler is null"); \ + return NN_INVALID_PARAM; \ + } \ + } while (0) + +NResult NetDriverUB::Start() +{ + std::lock_guard locker(mInitMutex); + if (mStarted) { + return NN_OK; + } + + if (!mInited) { + NN_LOG_ERROR("Failed to start NetDriverUB " << mName << ", as isn't initialized"); + return NN_ERROR; + } + + NResult result = NN_OK; + if (!mOptions.dontStartWorkers) { + DRIVER_CHECK_HANDLES(); + for (uint64_t i = 0; i < mWorkers.size(); i++) { + if (NN_LIKELY((result = mWorkers[i]->Start()) == NN_OK)) { + continue; + } + NN_LOG_ERROR("Failed to start driver " << mName << " as failed to start worker"); + for (uint64_t j = 0; j < i; j++) { + mWorkers[j]->Stop(); + } + return result; + } + } else { + NN_LOG_INFO("Workers in driver " << mName << " will not be started as dontStartWorkers is true"); + } + + if (NN_UNLIKELY(result = DoStart()) != NN_OK) { + NN_LOG_ERROR("Failed to do start NetDriverUB " << mName << ", result " << result); + for (auto worker : mWorkers) { + worker->Stop(); + } + return result; + } + mStarted = true; + return NN_OK; +} + +void NetDriverUB::Stop() +{ + std::lock_guard locker(mInitMutex); + if (!mStarted) { + return; + } + + DoStop(); + + for (auto worker : mWorkers) { + worker->Stop(); + } + + mStarted = false; +} + +NResult NetDriverUB::CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr) +{ + if (NN_UNLIKELY(size == 0 || size > NN_NO107374182400)) { + NN_LOG_ERROR("Failed to create mem region as size is 0 or greater than 100 GB"); + return NN_INVALID_PARAM; + } + + if (!mInited) { + NN_LOG_ERROR("Failed to create Memory region in NetDriverUB " << mName << ", as not initialized"); + return NN_EP_NOT_INITIALIZED; + } + + UBMemoryRegion *tmp = nullptr; + auto res = UBMemoryRegion::Create(mName, mContext, size, tmp); + if (NN_UNLIKELY(res != UB_OK)) { + NN_LOG_ERROR("Failed to create Memory region in NetDriverUB " << mName << ", probably out of memory"); + return res; + } + + if ((res = tmp->InitializeForOneSide()) != UB_OK) { + delete tmp; + return res; + } + + mr.Set(static_cast(tmp)); + std::lock_guard locker(mLockTseg); + mMapTseg.emplace(mr->GetLKey(), static_cast(tmp->GetMemorySeg())); + + return NN_OK; +} + +NResult NetDriverUB::CreateMemoryRegion(uintptr_t address, uint64_t size, UBSHcomNetMemoryRegionPtr &mr) +{ + if (!mInited) { + NN_LOG_ERROR("Failed to create Memory region with ptr in NetDriverUB " << mName << ", as not initialized"); + return NN_EP_NOT_INITIALIZED; + } + + if (address == 0) { + NN_LOG_ERROR("Failed to create Memory region with ptr in NetDriverUB " << mName << ", as address is 0"); + return NN_INVALID_PARAM; + } + + UBMemoryRegion *tmp = nullptr; + auto result = UBMemoryRegion::Create(mName, mContext, address, size, tmp); + if (NN_UNLIKELY(result != UB_OK)) { + NN_LOG_ERROR("Failed to create Memory region with ptr in NetDriverUB " << mName << ", probably out of memory"); + return result; + } + + if ((result = tmp->InitializeForOneSide()) != UB_OK) { + delete tmp; + return result; + } + + mr.Set(static_cast(tmp)); + std::lock_guard locker(mLockTseg); + mMapTseg.emplace(mr->GetLKey(), static_cast(tmp->GetMemorySeg())); + + return NN_OK; +} + +NResult NetDriverUB::CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr, unsigned long memid) +{ + NN_LOG_ERROR("operation not supported in non-hccs NetDriverUB "); + return NN_ERROR; +} + +void NetDriverUB::DestroyMemoryRegion(UBSHcomNetMemoryRegionPtr &mr) +{ + if (mr.Get() == nullptr) { + NN_LOG_WARN("Try to destroy null memory region in UB driver " << mName); + return; + } + + std::lock_guard locker(mLockTseg); + auto result = mMapTseg.erase(mr->GetLKey()); + if (result == 0) { + NN_LOG_WARN("Unable to erase mr from driver as " << mr->GetLKey() << " doesn't exist, maybe cleaned already"); + } + + mr->UnInitialize(); +} + +void *NetDriverUB::MapAndRegVaForUB(unsigned long memid, uint64_t &va) +{ + NN_LOG_ERROR("operation not supported in non-hccs NetDriverUB "); + return nullptr; +} + +NResult NetDriverUB::UnmapVaForUB(uint64_t &va) +{ + NN_LOG_ERROR("operation not supported in non-hccs NetDriverUB "); + return NN_ERROR; +} +} +} +#endif diff --git a/src/transport/ub/net_ub_driver.h b/src/transport/ub/net_ub_driver.h new file mode 100644 index 0000000000000000000000000000000000000000..dd9c726264cab05e14bee7e740588d582816b139 --- /dev/null +++ b/src/transport/ub/net_ub_driver.h @@ -0,0 +1,124 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_NET_UB_DRIVER_H +#define HCOM_NET_UB_DRIVER_H +#ifdef UB_BUILD_ENABLED + +#include +#include + +#include "hcom.h" +#include "net_common.h" +#include "ub_device_helper.h" +#include "net_mem_pool_fixed.h" + +namespace ock { +namespace hcom { +class UBWorker; +class NetDriverUB : public UBSHcomNetDriver { +public: + NetDriverUB(const std::string &name, bool isServer, UBSHcomNetDriverProtocol protocol) + : UBSHcomNetDriver(name, isServer, protocol) + { + OBJ_GC_INCREASE(NetDriverUB); + } + + ~NetDriverUB() override + { + OBJ_GC_DECREASE(NetDriverUB); + } + + NResult Initialize(const UBSHcomNetDriverOptions &option) override; + void UnInitialize() override; + + NResult Start() override; + void Stop() override; + + NResult CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr) override; + NResult CreateMemoryRegion(uintptr_t address, uint64_t size, UBSHcomNetMemoryRegionPtr &mr) override; + NResult CreateMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr, unsigned long memid) override; + void DestroyMemoryRegion(UBSHcomNetMemoryRegionPtr &mr) override; + + void *MapAndRegVaForUB(unsigned long memid, uint64_t &va) override; + + NResult UnmapVaForUB(uint64_t &va) override; + + inline NResult ValidateMemoryRegion(uint64_t lKey, uintptr_t address, uint64_t size) + { + return NN_OK; + } + + inline NResult GetTseg(uint64_t lKey, urma_target_seg_t *&tseg) + { + std::lock_guard locker(mLockTseg); + auto it = mMapTseg.find(lKey); + if (it == mMapTseg.end()) { + NN_LOG_ERROR("Failed to get tseg by lkey: " << lKey); + return UB_PARAM_INVALID; + } + + tseg = it->second; + return NN_OK; + } + + void DestroyEndpoint(UBSHcomNetEndpointPtr &ep) override; + +protected: + NResult ValidateOptions(); + NResult ValidaQpQueueSizeOptions(); + NResult CreateContext(); + NResult CreateWorkers(); + NResult GetDeviceByIp(UBEId &tmpEid); + NResult GetDeviceByEid(UBEId &tmpEid); + NResult GetDeviceByName(UBEId &tmpEid); + void ClearWorkers(); + void UnInitializeInner(); + virtual NResult DoInitialize() + { + return NN_OK; + } + + virtual void DoUnInitialize() {} + + virtual NResult DoStart() + { + return NN_OK; + } + + virtual void DoStop() {} + +protected: + std::string mMatchIp; + UBContext *mContext = nullptr; + std::vector mWorkers; + UBMemoryRegionFixedBuffer *mDriverSendMR = nullptr; + MemoryRegionChecker mMrChecker; + uint32_t mHeartBeatIdleTime = NN_NO8; + uint32_t mHeartBeatProbeInterval = NN_NO1; + std::map mMapVaSgeForUB; +private: + NResult CreateSendMr(uint8_t slave); + NResult ImportRemotePA(unsigned long memid); + NResult CreateOpCtxMemPool(); + NResult CreateSglCtxMemPool(); + NResult CreateWorkerResource(); + NetMemPoolFixedPtr mOpCtxMemPool = nullptr; + NetMemPoolFixedPtr mSglCtxMemPool = nullptr; + std::map mMapTseg; + std::mutex mLockTseg; +}; +} +} + +#endif +#endif // HCOM_NET_UB_DRIVER_H diff --git a/src/transport/ub/net_ub_driver_oob.cpp b/src/transport/ub/net_ub_driver_oob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..18f229c4505959a515d16b9f0e4b5dd71d5d489e --- /dev/null +++ b/src/transport/ub/net_ub_driver_oob.cpp @@ -0,0 +1,1973 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifdef UB_BUILD_ENABLED +#include +#include + +#include "net_monotonic.h" +#include "net_oob_ssl.h" +#include "net_ub_endpoint.h" +#include "net_ub_driver_oob.h" +#include "net_oob_secure.h" +#include "ub_worker.h" + +namespace ock { +namespace hcom { +constexpr uint64_t MAX_OP_TIME_US = NN_NO500000; // 500 ms +uint64_t g_connection_count = 0; + +NResult NetDriverUBWithOob::DoInitialize() +{ + if (mWorkers.empty()) { + NN_LOG_ERROR("Failed to do initialize in Driver " << mName << ", as mWorkers is empty"); + } + + for (auto worker : mWorkers) { + worker->RegisterPostedHandler(std::bind(&NetDriverUBWithOob::SendFinished, this, std::placeholders::_1)); + worker->RegisterNewRequestHandler(std::bind(&NetDriverUBWithOob::NewRequest, this, std::placeholders::_1)); + worker->RegisterOneSideDoneHandler(std::bind(&NetDriverUBWithOob::OneSideDone, this, std::placeholders::_1)); + if (mIdleHandler != nullptr) { + worker->RegisterIdleHandler(mIdleHandler); + } + } + + NResult result = NN_OK; + // create oob + if (mStartOobSvr) { + if (mOptions.oobType != NET_OOB_UB) { + result = CreateListeners(mOptions.enableMultiRail); + } else { + result = CreateUrmaListeners(mPublicJetty); + } + if (result != NN_OK) { + NN_LOG_ERROR("Failed to create listeners"); + return NN_ERROR; + } + } + mEndPoints.reserve(NN_NO1024); + + return NN_OK; +} + +void NetDriverUBWithOob::DoUnInitialize() +{ + if (mStarted) { + NN_LOG_WARN("Unable to uninitialize ub driver " << mName << " which is not stopped"); + return; + } + + if (!mOobServers.empty()) { + mOobServers.clear(); + } +} + +NResult NetDriverUBWithOob::DoStart() +{ + NResult result = NN_OK; + if (mStartOobSvr) { + if (mOptions.oobType != NET_OOB_UB) { + if (mNewEndPointHandler == nullptr) { + NN_LOG_ERROR("Failed to do start in Driver " << mName << ", as newEndPointerHandler is null"); + return NN_INVALID_PARAM; + } + + /* set cb for listeners */ + for (auto &oobServer : mOobServers) { + oobServer->SetNewConnCB(std::bind(&NetDriverUBWithOob::NewConnectionCB, this, std::placeholders::_1)); + oobServer->SetNewConnCbThreadNum(mOptions.oobConnHandleThreadCount); + oobServer->SetNewConnCbQueueCap(mOptions.oobConnHandleQueueCap); + } + + result = StartListeners(); + if (result != NN_OK) { + NN_LOG_ERROR("Failed to start listeners for driver " << mName << ", result " << result); + return result; + } + } else { + mPublicJetty->SetNewConnCB( + std::bind(&NetDriverUBWithOob::PublicJettyNewConnectionCB, this, std::placeholders::_1)); + result = mPublicJetty->StartPublicJetty(); + if (result != NN_OK) { + NN_LOG_ERROR("Failed to start public jetty for driver " << mName << ", result " << result); + return result; + } + } + } + + mHeartBeat = new (std::nothrow) NetHeartbeat(this, mOptions.heartBeatIdleTime, mOptions.heartBeatProbeInterval); + if (mHeartBeat == nullptr) { + NN_LOG_ERROR("Failed to do start in Driver " << mName << ", as new heartbeat failed"); + return NN_ERROR; + } + + result = mHeartBeat->Start(); + if (result != NN_OK) { + StopListeners(); + return result; + } + + mNeedStopEvent = false; + std::thread tmpEventThread(&NetDriverUBWithOob::RunInUbEventThread, this); + mUBEventThread = std::move(tmpEventThread); + + while (!mEventStarted.load()) { + usleep(NN_NO10); + } + + return NN_OK; +} + +void NetDriverUBWithOob::DoStop() +{ + if (mHeartBeat != nullptr) { + mHeartBeat->Stop(); + delete mHeartBeat; + mHeartBeat = nullptr; + } + + mNeedStopEvent = true; + if (mUBEventThread.native_handle()) { + mUBEventThread.join(); + } + if (mPublicJetty != nullptr) { + mPublicJetty->Stop(); + } + StopListeners(); +} + +NResult NetDriverUBWithOob::MultiRailNewConnection(OOBTCPConnection &conn) +{ + return NewConnectionCB(conn); +} + +void NetDriverUBWithOob::DestroyEpByPortNum(int portNum) +{ + static thread_local std::vector endPointsCopy; + endPointsCopy.reserve(NN_NO8192); + endPointsCopy.clear(); + { + std::lock_guard locker(mEndPointsMutex); + for (auto iter = mEndPoints.begin(); iter != mEndPoints.end();) { + auto asyncEp = iter->second.ToChild(); + if (asyncEp != nullptr && asyncEp->GetQp()->GetPortNum() == portNum) { + endPointsCopy.emplace_back(iter->second); + iter = mEndPoints.erase(iter); + } else { + ++iter; + } + } + } + + for (auto &endPoint : endPointsCopy) { + NN_LOG_WARN("Detect port down event, handle Ep id " << endPoint->Id() << " of driver " << mName); + ProcessEpError(reinterpret_cast(endPoint.Get())); + } + + NN_LOG_INFO("Destroyed all endpoints count " << endPointsCopy.size() << " by port down of driver " << mName); + endPointsCopy.clear(); +} + +void NetDriverUBWithOob::HandlePortDown(int portNum) +{ + for (auto &worker : mWorkers) { + if (worker->PortNum() == portNum) { + worker->Stop(); + } + } + + DestroyEpByPortNum(portNum); +} + +void NetDriverUBWithOob::HandlePortActive(int portNum) +{ + for (auto &worker : mWorkers) { + if (worker->PortNum() == portNum) { + worker->Start(); + } + } +} + +void NetDriverUBWithOob::DestroyEpInWorker(UBWorker *worker) +{ + static thread_local std::vector endPointsCopy; + endPointsCopy.reserve(NN_NO8192); + endPointsCopy.clear(); + { + std::lock_guard locker(mEndPointsMutex); + for (auto iter = mEndPoints.begin(); iter != mEndPoints.end();) { + auto asyncEp = iter->second.ToChild(); + if (asyncEp != nullptr && asyncEp->mWorker == worker) { + endPointsCopy.emplace_back(iter->second); + iter = mEndPoints.erase(iter); + } else { + ++iter; + } + } + } + + for (auto &endPoint : endPointsCopy) { + NN_LOG_WARN("Detect CQ incorrect event, handle Ep id " << endPoint->Id() << " of driver " << mName); + ProcessEpError(reinterpret_cast(endPoint.Get())); + } + + NN_LOG_INFO("Destroyed all endpoints count " << endPointsCopy.size() << " in UB worker " << worker->DetailName() << + " of driver " << mName); + endPointsCopy.clear(); +} + +void NetDriverUBWithOob::HandleCqEvent(urma_async_event_t *event) +{ + /* when sync mode connecting, there is no worker */ + if (NN_UNLIKELY(event->element.jfc == nullptr || event->element.jfc->jfc_cfg.user_ctx == 0)) { + NN_LOG_ERROR("CQ error for CQ of driver " << mName); + return; + } + + auto worker = reinterpret_cast(event->element.jfc->jfc_cfg.user_ctx); + NN_LOG_ERROR("CQ error for CQ in UB worker " << worker->DetailName() << " of driver " << mName); + if (worker->Stop() != UB_OK) { + NN_LOG_ERROR("Handle Cq event stop error in UB worker " << worker->DetailName() << " of driver " << mName); + return; + } + + DestroyEpInWorker(worker); + if (worker->ReInitializeCQ() != UB_OK) { + NN_LOG_ERROR("Handle Cq event ReInitializeCQ error in UB worker " << worker->DetailName() << " of driver " << + mName); + return; + } + if (worker->Start() != UB_OK) { + NN_LOG_ERROR("Handle Cq event start error in UB worker " << worker->DetailName() << " of driver " << mName); + return; + } +} + +static inline std::string QpDetailInfo(void *qpContext) +{ + auto qp = reinterpret_cast(qpContext); + std::ostringstream oss; + oss << "[Qp name:" << qp->GetName() << ", id:" << qp->GetId() << "]"; + return oss.str(); +} + +void NetDriverUBWithOob::HandleAsyncEvent(urma_async_event_t *event) +{ + switch (event->event_type) { + case URMA_EVENT_JFC_ERR: + HandleCqEvent(event); + NN_LOG_ERROR("jfc error of driver " << mName); + return; + case URMA_EVENT_JFS_ERR: + NN_LOG_ERROR("jfs error of driver " << mName); + return; + case URMA_EVENT_JFR_ERR: + NN_LOG_ERROR("jfr error of driver " << mName); + return; + case URMA_EVENT_JFR_LIMIT: + NN_LOG_ERROR("jfr limit of driver " << mName); + return; + case URMA_EVENT_JETTY_ERR: + NN_LOG_ERROR("jetty error of driver " << mName); + return; + case URMA_EVENT_JETTY_LIMIT: + NN_LOG_ERROR("jetty limit of driver " << mName); + return; + case URMA_EVENT_JETTY_GRP_ERR: + NN_LOG_ERROR("jetty grp error of driver " << mName); + return; + case URMA_EVENT_PORT_ACTIVE: + NN_LOG_ERROR("port active of driver " << mName); + HandlePortActive(event->element.port_id); + return; + case URMA_EVENT_PORT_DOWN: + NN_LOG_ERROR("port down of driver " << mName); + HandlePortDown(event->element.port_id); + return; + case URMA_EVENT_DEV_FATAL: + NN_LOG_ERROR("dev fatal of driver " << mName); + return; + case URMA_EVENT_EID_CHANGE: + NN_LOG_ERROR("eid change of driver " << mName); + mContext->UpdateGid(mMatchIp); + return; + case URMA_EVENT_ELR_ERR: + NN_LOG_ERROR("elr error of driver " << mName); + return; + case URMA_EVENT_ELR_DONE: + NN_LOG_ERROR("elr done of driver " << mName); + return; + default: + NN_LOG_ERROR("Unknown event " << event->event_type << " of driver " << mName); + } +} + +void NetDriverUBWithOob::RunInUbEventThread() +{ + mEventStarted.store(true); + NN_LOG_INFO("Ub event monitor thread for driver " << mName << " started"); + + /* set thread name */ + pthread_setname_np(pthread_self(), ("UBEvent" + std::to_string(mIndex)).c_str()); + + /* set nonblock */ + urma_context_t *urmaContext = mContext->GetContext(); + if (urmaContext == nullptr) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to get urma context for driver " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + mEventStarted.store(false); + return; + } + int flags = fcntl(urmaContext->async_fd, F_GETFL); + int ret = fcntl(urmaContext->async_fd, F_SETFL, (static_cast(flags)) | O_NONBLOCK); + if (ret < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to change event fd of ub context for driver " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + mEventStarted.store(false); + return; + } + + urma_async_event_t event{}; + while (!mNeedStopEvent) { + struct pollfd fd {}; + int timeoutMs = NN_NO100; + fd.fd = urmaContext->async_fd; + fd.events = POLLIN; + fd.revents = 0; + do { + ret = poll(&fd, 1, timeoutMs); + if (ret > 0) { + break; + } else if (ret < 0 && errno != EINTR) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to poll event fd of ub context for driver " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + break; + } + // rc == 0 + } while (!mNeedStopEvent); + + if (mNeedStopEvent) { + break; + } + ret = HcomUrma::GetAsyncEvent(urmaContext, &event); + if (ret != 0) { + /* nothing happen when nonblock mode */ + continue; + } + + /* when fatal event happened, need stop worker first, then call ep broken to prevent race condition + with poll cq thread */ + HandleAsyncEvent(&event); + + /* ack the event, otherwise destroy cq will block */ + HcomUrma::AckAsyncEvent(&event); + } + NN_LOG_INFO("UB event monitor thread for driver " << mName << " exiting"); + mEventStarted.store(false); +} + +int NetDriverUBWithOob::NewConnectionCB(OOBTCPConnection &conn) +{ + if (NN_UNLIKELY(OOBSecureProcess::SecProcessInOOBServer(mSecInfoProvider, mSecInfoValidator, conn, mName, + mOptions.secType)) != NN_OK) { + return NN_OOB_SEC_PROCESS_ERROR; + } + + int ret = 0; + + // receive server worker grpno + auto startRecvWG = NetMonotonic::TimeUs(); + ConnectHeader header{}; + void *grpnobuf = &header; + if ((ret = conn.Receive(grpnobuf, sizeof(ConnectHeader))) != 0) { + NN_LOG_ERROR("Failed to receive specified server worker grpno from client " << mName << ", ret " << ret); + return NN_ERROR; + } + + ConnRespWithUId respWithUId{ OK, 0 }; + ret = OOBSecureProcess::SecCheckConnectionHeader(header, mOptions, mEnableTls, Protocol(), mMajorVersion, + mMinorVersion, respWithUId); + if (ret != NN_OK) { + conn.Send(&respWithUId, sizeof(ConnRespWithUId)); + return NN_ERROR; + } + + auto endRecvWG = NetMonotonic::TimeUs(); + auto recvWGtime = endRecvWG - startRecvWG; + if (NN_UNLIKELY(recvWGtime > MAX_OP_TIME_US)) { + NN_LOG_WARN("Receive group num time is too long :" << recvWGtime << " us."); + } + + /* choose worker */ + NetWorkerLBPtr lb = nullptr; + if (mOptions.enableMultiRail) { + lb = mServerLb; + } else { + lb = conn.LoadBalancer(); + } + NN_ASSERT_LOG_RETURN(lb.Get() != nullptr, NN_ERROR) + uint16_t wkrIdx = 0; + if (NN_UNLIKELY(!lb->ChooseWorker(header.groupIndex, conn.GetIpAndPort(), wkrIdx)) || + wkrIdx >= mWorkers.size()) { + NN_LOG_ERROR("Failed to find worker fit grpno " << header.groupIndex << " in " << mName << " , ret " << + ret); + ConnRespWithUId respWithUId{ WORKER_GRPNO_MISMATCH, 0 }; + conn.Send(&respWithUId, sizeof(ConnRespWithUId)); + return NN_ERROR; + } + + NN_LOG_TRACE_INFO("Worker " << wkrIdx << " is chosen in driver " << mName); + auto worker = mWorkers[wkrIdx]; + NN_ASSERT_LOG_RETURN(worker != nullptr, NN_ERROR); + + if (!worker->IsWorkStarted()) { + NN_LOG_ERROR("Failed to connect worker group no " << header.groupIndex << " in " << mName); + ConnRespWithUId respWithUId{ WORKER_NOT_STARTED, 0 }; + conn.Send(&respWithUId, sizeof(ConnRespWithUId)); + return NN_ERROR; + } + + // create qp + auto startCreateQp = NetMonotonic::TimeUs(); + UBJetty *qp = nullptr; + if ((ret = worker->CreateQP(qp)) != 0) { + NN_LOG_ERROR("Failed to create qp for new connection in Driver " << mName << " , ret " << ret); + ConnRespWithUId respWithUId{ SERVER_INTERNAL_ERROR, 0 }; + conn.Send(&respWithUId, sizeof(ConnRespWithUId)); + return NN_ERROR; + } + qp->SetName(mName); + NetLocalAutoDecreasePtr qpAutoDecPtr(qp); + uint32_t token = GenerateSecureRandomUint32(); + UBJettyExchangeInfo info{}; + info.token = token; + if ((ret = qp->Initialize(mOptions.mrSendReceiveSegCount, 0, token)) != 0) { + NN_LOG_ERROR("Failed to initialize qp for new connection in Driver " << mName << " , ret " << ret); + ConnRespWithUId respWithUId{ SERVER_INTERNAL_ERROR, 0 }; + conn.Send(&respWithUId, sizeof(ConnRespWithUId)); + return NN_ERROR; + } + std::string ipPort = conn.GetIpAndPort(); + qp->SetPeerIpAndPort(ipPort); + + g_connection_count++; + auto id = NetUuid::GenerateUuid(); + NN_LOG_TRACE_INFO("new ep id will be set as " << id << " in driver " << mName); + respWithUId.connResp = OK; + respWithUId.epId = id; + conn.Send(&respWithUId, sizeof(ConnRespWithUId)); + auto endCreateQp = NetMonotonic::TimeUs(); + auto createQpTime = endCreateQp - startCreateQp; + if (NN_UNLIKELY(createQpTime > MAX_OP_TIME_US)) { + NN_LOG_WARN("Create qp time is too long :" << createQpTime << " us."); + } + + // exchange info + NN_LOG_TRACE_INFO("Get and send exchange info of ep"); + auto startExchInfo = NetMonotonic::TimeUs(); + auto prePostCount = mOptions.prePostReceiveSizePerQP; + if (mHeartBeat != nullptr) { + if ((ret = qp->CreateHBMemoryRegion(NN_NO128, qp->mHBLocalMr)) != NN_OK) { + NN_LOG_ERROR("Failed to create mr for local HB, ret: " << ret); + return ret; + } + if ((ret = qp->CreateHBMemoryRegion(NN_NO128, qp->mHBRemoteMr)) != NN_OK) { + NN_LOG_ERROR("Failed to create mr for remote HB, ret: " << ret); + qp->DestroyHBMemoryRegion(qp->mHBLocalMr); + return ret; + } + qp->GetRemoteHbInfo(info); + } + info.receiveSegSize = mOptions.mrSendReceiveSegSize; + info.receiveSegCount = mOptions.prePostReceiveSizePerQP; + info.maxSendWr = mOptions.qpSendQueueSize; + info.maxReceiveWr = mOptions.qpReceiveQueueSize; + if (((ret = qp->FillExchangeInfo(info)) != 0)) { + NN_LOG_ERROR("Failed to get ep exchange info in Driver " << mName << ", ret " << ret); + return NN_ERROR; + } + if (((ret = conn.Send(&info, sizeof(UBJettyExchangeInfo))) != 0)) { + NN_LOG_ERROR("Failed to send ep exchange info in Driver " << mName << ", ret " << ret); + return NN_ERROR; + } + NN_LOG_TRACE_INFO("Send exchange info success in Server " << mName); + NN_LOG_TRACE_INFO("local ep ex info lid " << info.lid << ", qpn " << info.qpn << ", gid interface " << + info.gid.global.interface_id); + + std::unique_ptr peerExInfo(new (std::nothrow) UBJettyExchangeInfo); + if (!peerExInfo) { + NN_LOG_ERROR("Failed to alloc UBJettyExchangeInfo in Driver " << mName); + return NN_MALLOC_FAILED; + } + + if ((ret = conn.Receive(peerExInfo.get(), sizeof(UBJettyExchangeInfo))) != 0) { + NN_LOG_ERROR("Failed to receive ep exchange info in Driver " << mName << ", ret " << ret); + return NN_ERROR; + } + NN_LOG_TRACE_INFO("Recv exchange info success in Server " << mName); + qp->StoreExchangeInfo(peerExInfo.release()); + + // receive payload length + uint32_t payloadLen = 0; + auto tmpPayloadLen = reinterpret_cast(&payloadLen); + if ((ret = conn.Receive(tmpPayloadLen, sizeof(uint32_t))) != 0) { + NN_LOG_ERROR("Failed to receive connection payload length in Driver " << mName << ", ret " << ret); + return NN_ERROR; + } + + if (payloadLen == 0 || payloadLen > NN_NO1024) { + NN_LOG_ERROR("Invalid payload length " << payloadLen << ", it should be 1 ~ 1024"); + return NN_ERROR; + } + + // receive payload + std::string payload; + if (payloadLen > 0) { + auto payloadChars = new (std::nothrow) char[payloadLen + NN_NO1]; + if (payloadChars == nullptr) { + NN_LOG_ERROR("Failed to new payload char array in Driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + NetLocalAutoFreePtr autoFreePayChars(payloadChars, true); + + void *tmpChars = static_cast(payloadChars); + if ((ret = conn.Receive(tmpChars, payloadLen)) != 0) { + NN_LOG_ERROR("Failed to receive connection payload in Driver " << mName << ", ret " << ret); + return NN_ERROR; + } + + payloadChars[payloadLen] = '\0'; + payload = std::string(payloadChars, payloadLen); + } + + NN_LOG_TRACE_INFO("Remote qp ex info lid " << info.lid << ", qpn " << info.qpn << ", gid interface " << + info.gid.global.interface_id << ", pre-post-receive-count " << info.receiveSegCount); + if ((ret = qp->ChangeToReady(qp->GetExchangeInfo())) != 0) { + NN_LOG_ERROR("Failed to change qp to ready in Driver " << mName << ", ret " << ret); + return ret; + } + + auto *mrSegs = new (std::nothrow) uintptr_t[prePostCount]; + if (mrSegs == nullptr) { + NN_LOG_ERROR("Failed to create mr address array in Driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + NetLocalAutoFreePtr segAutoDelete(mrSegs, true); + + if (!qp->GetFreeBufferN(mrSegs, prePostCount)) { + NN_LOG_ERROR("Failed to get free mr from pool, mr is not enough"); + return NN_MALLOC_FAILED; + } + + uint16_t i = 0; + for (; i < prePostCount; i++) { + if ((ret = worker->PostReceive(qp, mrSegs[i], mOptions.mrSendReceiveSegSize, + reinterpret_cast(qp->GetMemorySeg()))) != 0) { + ClearJettyResource(qp); + return ret; + } + } + + for (; i < prePostCount; i++) { + qp->ReturnBuffer(mrSegs[i]); + } + auto endExchInfo = NetMonotonic::TimeUs(); + auto exchInfoTime = endExchInfo - startExchInfo; + if (NN_UNLIKELY(exchInfoTime > MAX_OP_TIME_US)) { + NN_LOG_WARN("Exchange info time too long :" << exchInfoTime << " us."); + } + + // create endpoint + auto startCreateEp = NetMonotonic::TimeUs(); + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetUBAsyncEndpoint(id, qp, this, worker); + if (ep.Get() == nullptr) { + NN_LOG_ERROR("Failed to create UBSHcomNetEndpoint in Driver " << mName << ", probably out of memory"); + ClearJettyResource(qp); + return NN_NEW_OBJECT_FAILED; + } + + if (mOptions.oobType == NET_OOB_UDS) { + struct ucred remoteIds {}; + socklen_t len = sizeof(struct ucred); + if (NN_UNLIKELY(getsockopt(conn.GetFd(), SOL_SOCKET, SO_PEERCRED, &remoteIds, &len) != 0)) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to get uds ids in driver " << mName << " errno:" << errno << " error:" << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return NN_GET_UDS_ID_INFO_FAILED; + } + ep->RemoteUdsIdInfo(remoteIds.pid, remoteIds.uid, remoteIds.gid); + } + + ep->StoreConnInfo(NetFunc::GetIpByFd(conn.GetFd()), conn.ListenPort(), header.version, payload); + ep.ToChild()->SetRemoteHbInfo(qp->GetExchangeInfo().hbAddress, qp->GetExchangeInfo().hbKey, + qp->GetExchangeInfo().hbMrSize); + if (mEnableTls) { + auto childEp = ep.ToChild(); + if (NN_UNLIKELY(childEp == nullptr)) { + NN_LOG_ERROR("To child Failed"); + return NN_ERROR; + } + auto tmp = dynamic_cast(&conn); + if (NN_UNLIKELY(tmp == nullptr)) { + NN_LOG_ERROR("dynamic cast error"); + return NN_OOB_SEC_PROCESS_ERROR; + } + childEp->EnableEncrypt(mOptions); + childEp->SetSecrets(tmp->Secret()); + } + ep->mDevIndex = mDevIndex; + ep->mPeerDevIndex = mPeerDevIndex; + ep->mBandWidth = mBandWidth; + + ret = mNewEndPointHandler(conn.GetIpAndPort(), ep, payload); + if (NN_UNLIKELY(ret != UB_OK)) { + NN_LOG_ERROR("Called new end point handler failed, ret " << ret); + ClearJettyResource(qp); + return NN_ERROR; + } + + // 在EP创建失败时,保证 jetty 不会索引到无效EP,必须位于set ep to ESTABLISHED 上方。因历史原因,用户可能会在 + // NewEndpointHandler 中发送信息,而如果状态为ESTABLISHED但是Jetty无法索引到EP,在UBWorker poll到对于jetty + // 上发生事件时可能无法找到源 EP. + // \see NetDriverUBWithOob::ProcessPollingResult + // \see NetServiceDefaultImp::ServiceRequestPosted + // \see HcomServiceImp::ServiceReuqestPosted + qp->SetUpContext(reinterpret_cast(ep.Get())); + qp->SetUpId(ep->Id()); + ep.ToChild()->State().Set(NEP_ESTABLISHED); + + // ready 同步信令发送后客户端可能会立即发包,会在 UBWorker 中被动触发事件。如果 jetty 无法通过 UpContext() 索引到 + // ep,此 ep 上产生的事件无法被 UBWorker 转发至回调。因此发送 ready 同步信令必须位于 qp->SetUpContext(...)之后。 + int8_t ready = (ret == UB_OK) ? 1 : 0; + if ((ret = conn.Send(&ready, sizeof(int8_t))) != UB_OK) { + NN_LOG_ERROR("Failed to send ready signal to client, ret " << ret); + ClearJettyResource(qp); + return NN_ERROR; + } + + // EP被安全创建完毕 + { + std::lock_guard locker(mEndPointsMutex); + mEndPoints.emplace(id, ep); + } + + auto endCreateEp = NetMonotonic::TimeUs(); + auto createEpTime = endCreateEp - startCreateEp; + if (NN_UNLIKELY(createEpTime > MAX_OP_TIME_US)) { + NN_LOG_WARN("Create endpoint time too long :" << createEpTime << " us."); + } + + NN_LOG_INFO("New connection from " << conn.GetIpAndPort() << " established, async ep id " << ep->Id() + << ", jetty id: " << qp->QpNum() << ", worker info " << worker->DetailName()); + return NN_OK; +} + +NResult NetDriverUBWithOob::Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, + uint8_t serverGrpNo, uint8_t clientGrpNo) +{ + if (mOptions.oobType == NET_OOB_TCP) { + return Connect(mOobIp, mOobPort, payload, ep, flags, serverGrpNo, clientGrpNo); + } else if (mOptions.oobType == NET_OOB_UDS) { + return Connect(mUdsName, 0, payload, ep, flags, serverGrpNo, clientGrpNo); + } else if (mOptions.oobType == NET_OOB_UB) { + return ConnectByPublicJetty(mOobIp, mOobPort, payload, ep, flags, serverGrpNo, clientGrpNo); + } + return NN_ERROR; +} + +NResult NetDriverUBWithOob::Connect(const std::string &oobIp, uint16_t oobPort, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) +{ + if (ClientCheckState(payload) != 0) { + NN_LOG_ERROR("Failed to connect as driver not start or payload oversize"); + return NN_ERROR; + } + /* all kind of drivers can connect to peer to get an ep */ + if (mOptions.oobType == NET_OOB_UB) { + return ConnectByPublicJetty(oobIp, oobPort, payload, outEp, flags, serverGrpNo, clientGrpNo, ctx); + } + + if ((flags & NET_EP_SELF_POLLING) || (flags & NET_EP_EVENT_POLLING)) { + return ConnectSyncEp(oobIp, oobPort, payload, outEp, flags, serverGrpNo, ctx); + } + + OOBTCPClientPtr tcpClient; + if (mEnableTls) { + tcpClient = new (std::nothrow) + OOBSSLClient(mOptions.oobType, oobIp, oobPort, mTlsPrivateKeyCB, mTlsCertCB, mTlsCaCallback); + NN_ASSERT_LOG_RETURN(tcpClient.Get() != nullptr, NN_NEW_OBJECT_FAILED) + tcpClient.ToChild()->SetTlsOptions(mOptions); + tcpClient.ToChild()->SetPSKCallback(mPskFindSessionCb, mPskUseSessionCb); + } else { + tcpClient = new (std::nothrow) OOBTCPClient(mOptions.oobType, oobIp, oobPort); + NN_ASSERT_LOG_RETURN(tcpClient.Get() != nullptr, NN_NEW_OBJECT_FAILED) + } + + /* try to connect to oob server */ + OOBTCPConnection *conn = nullptr; + NResult result = NN_OK; + if ((result = tcpClient->Connect(conn)) != 0) { + NN_LOG_ERROR("Failed to connect server via oob, result " << result); + return result; + } + + NetLocalAutoDecreasePtr autoDecPtr(conn); + conn->SetIpAndPort(oobIp, oobPort); + + if (mOptions.enableMultiRail) { + ConnectHeader driverHeader; + SetDriverConnHeader(driverHeader, mBandWidth, mDevIndex); + if ((result = conn->Send(&driverHeader, sizeof(ConnectHeader))) != 0) { + NN_LOG_ERROR("Failed to send driver info " << mName << ", Result " << result); + return result; + } + + ConnectHeader header{}; + void *grpnobuf = static_cast(&header); + auto result = conn->Receive(grpnobuf, sizeof(ConnectHeader)); + if (result != 0) { + NN_LOG_ERROR("Failed to receive specified device info for server, Result " << result); + return result; + } + + if (header.devIndex >= NN_NO4) { + NN_LOG_ERROR("Invalid devIndex " << header.devIndex << " in header, which should be in 0 ~ 3"); + return NN_ERROR; + } + mPeerDevIndex = header.devIndex; + } + + if (NN_UNLIKELY(OOBSecureProcess::SecProcessInOOBClient(mSecInfoProvider, mSecInfoValidator, conn, mName, ctx, + mOptions.secType))) { + return NN_OOB_SEC_PROCESS_ERROR; + } + + /* send connection header & grpNo */ + auto startSendGrpNo = NetMonotonic::TimeUs(); + ConnectHeader header; + SetConnHeader(header, mOptions.magic, mOptions.version, serverGrpNo, Protocol(), mMajorVersion, mMinorVersion, + mOptions.tlsVersion); + header.reserve = ctx; + if ((result = conn->Send(&header, sizeof(ConnectHeader))) != 0) { + NN_LOG_ERROR("Failed to send server worker grpno in Driver " << mName << ", result " << result); + return result; + } + + /* receive connect response and peer ep id */ + ConnRespWithUId respWithUId{}; + void *ackBuf = static_cast(&respWithUId); + if ((result = conn->Receive(ackBuf, sizeof(ConnRespWithUId))) != 0) { + NN_LOG_ERROR("Failed receive ServerAck in Driver " << mName << ", result " << result); + return result; + } + + /* connect response */ + auto serverRsp = respWithUId.connResp; + if (serverRsp == MAGIC_MISMATCH) { + NN_LOG_ERROR("Failed to pass server magic validation " << mName << ", result " << serverRsp); + return NN_CONNECT_REFUSED; + } + + if (serverRsp == WORKER_GRPNO_MISMATCH || serverRsp == WORKER_NOT_STARTED) { + NN_LOG_ERROR("Failed to choose worker or not started " << mName << ", result " << serverRsp); + return NN_CONNECT_REFUSED; + } + + if (serverRsp == PROTOCOL_MISMATCH) { + NN_LOG_ERROR("Failed to pass server protocol validation " << mName << ", result " << serverRsp); + return NN_CONNECT_PROTOCOL_MISMATCH; + } + + if (serverRsp == SERVER_INTERNAL_ERROR) { + NN_LOG_ERROR("Server error happened, connection refused " << mName << ", result " << serverRsp); + return NN_ERROR; + } + + if (serverRsp != OK) { + NN_LOG_ERROR("Server error happened, connection refused " << mName << ", result " << serverRsp); + return NN_ERROR; + } + auto endSendGrpNo = NetMonotonic::TimeUs(); + auto sendGrpNoTime = endSendGrpNo - startSendGrpNo; + if (NN_UNLIKELY(sendGrpNoTime > MAX_OP_TIME_US)) { + NN_LOG_WARN("Send groupNo time too long: " << sendGrpNoTime << " us."); + } + + /* peer ep id */ + auto id = respWithUId.epId; + NN_LOG_TRACE_INFO("new ep id will be set as " << id << " in driver " << mName); + + /* Choose worker */ + uint16_t workerIndex = 0; + if (NN_UNLIKELY(!mClientLb->ChooseWorker(clientGrpNo, std::to_string(id), workerIndex)) || + workerIndex >= mWorkers.size()) { + NN_LOG_ERROR("Failed to choose worker during connect in driver " << mName); + return NN_ERROR; + } + + NN_ASSERT_LOG_RETURN(workerIndex < mWorkers.size(), NN_ERROR) + auto *worker = mWorkers[workerIndex]; + + if (!worker->IsWorkStarted()) { + NN_LOG_ERROR("Failed to connect worker group no " << clientGrpNo << " in " << mName); + return NN_ERROR; + } + + /* Create Qp */ + UBJetty *jetty = nullptr; + if (worker->CreateQP(jetty) != UB_OK) { + NN_LOG_ERROR("Fail to create jetty"); + return NN_ERROR; + } + jetty->SetName(mName); + NetLocalAutoDecreasePtr qpAutoDecPtr(jetty); + + uint32_t token = GenerateSecureRandomUint32(); + UBJettyExchangeInfo info{}; + info.token = token; + if (jetty->Initialize(mOptions.mrSendReceiveSegCount, 0, token) != 0) { + NN_LOG_ERROR("Failed to initialize jetty for new connection in Driver " << mName); + return NN_ERROR; + } + /* fill and send exchange info */ + auto startExchInfo = NetMonotonic::TimeUs(); + NN_LOG_TRACE_INFO("get and send exchange info of ep"); + if (mHeartBeat != nullptr) { + if ((result = jetty->CreateHBMemoryRegion(NN_NO128, jetty->mHBLocalMr)) != NN_OK) { + NN_LOG_ERROR("Failed to create mr for local HB, result " << result); + return result; + } + if ((result = jetty->CreateHBMemoryRegion(NN_NO128, jetty->mHBRemoteMr)) != NN_OK) { + NN_LOG_ERROR("Failed to create mr for remote HB, result " << result); + jetty->DestroyHBMemoryRegion(jetty->mHBLocalMr); + return result; + } + jetty->GetRemoteHbInfo(info); + } + info.maxSendWr = mOptions.qpSendQueueSize; + info.maxReceiveWr = mOptions.qpReceiveQueueSize; + info.receiveSegSize = mOptions.mrSendReceiveSegSize; + info.receiveSegCount = mOptions.prePostReceiveSizePerQP; + + if (((result = jetty->FillExchangeInfo(info)) != 0)) { + NN_LOG_ERROR("Failed to get ep exchange info in Driver " << mName << ", result " << result); + return result; + } + if (((result = conn->Send(&info, sizeof(UBJettyExchangeInfo))) != 0)) { + NN_LOG_ERROR("Failed to send ep exchange info in Driver " << mName << ", result " << result); + return result; + } + + auto prePostCount = mOptions.prePostReceiveSizePerQP; + + // send payload len + uint32_t payloadLength = payload.length(); + if ((result = conn->Send(&payloadLength, sizeof(uint32_t))) != 0) { + NN_LOG_ERROR("Failed to send connection payload length in Driver " << mName << ", result " << result); + return result; + } + + // send payload + if (payloadLength > 0) { + auto payloadPtr = reinterpret_cast(const_cast(payload.c_str())); + if ((result = conn->Send(payloadPtr, payloadLength)) != 0) { + NN_LOG_ERROR("Failed to send connection payload in Driver " << mName << ", result " << result); + return result; + } + } + + // receive exchange info + std::unique_ptr peerExInfo(new (std::nothrow) UBJettyExchangeInfo); + if (!peerExInfo) { + NN_LOG_ERROR("Failed to alloc UBJettyExchangeInfo in Driver " << mName); + return NN_MALLOC_FAILED; + } + + if ((result = conn->Receive(peerExInfo.get(), sizeof(UBJettyExchangeInfo))) != 0) { + NN_LOG_ERROR("Failed to receive ep exchange info in Driver " << mName << ", result " << result); + return NN_ERROR; + } + jetty->StoreExchangeInfo(peerExInfo.release()); + + /* change to ready */ + NN_LOG_TRACE_INFO("remote jetty ex info lid " << info.lid << ", qpn " << info.qpn << ", gid interface " << + info.gid.global.interface_id << ", pre-post-receive-count " << info.receiveSegCount); + if ((result = jetty->ChangeToReady(jetty->GetExchangeInfo())) != 0) { + NN_LOG_ERROR("Failed to change jetty to ready in Driver " << mName << ", result " << result); + return result; + } + + jetty->SetPeerIpAndPort(conn->GetIpAndPort()); + + auto *mrSegs = new (std::nothrow) uintptr_t[prePostCount]; + if (mrSegs == nullptr) { + NN_LOG_ERROR("Failed to create array of mr address in Driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + NetLocalAutoFreePtr segAutoDelete(mrSegs, true); + + if (!jetty->GetFreeBufferN(mrSegs, prePostCount)) { + NN_LOG_ERROR("Failed to get free mr from pool, result " << result); + return NN_ERROR; + } + + uint16_t i = 0; + for (; i < prePostCount; i++) { + if ((result = worker->PostReceive(jetty, mrSegs[i], mOptions.mrSendReceiveSegSize, + reinterpret_cast(jetty->GetMemorySeg()))) != 0) { + ClearJettyResource(jetty); + return result; + } + } + + for (; i < prePostCount; i++) { + jetty->ReturnBuffer(mrSegs[i]); + } + + auto endExchInfo = NetMonotonic::TimeUs(); + auto exchInfoTime = endExchInfo - startExchInfo; + if (NN_UNLIKELY(exchInfoTime > MAX_OP_TIME_US)) { + NN_LOG_WARN("Exchange Info time too long: " << exchInfoTime << " us."); + } + + /* Create endpoint */ + auto startCreateEp = NetMonotonic::TimeUs(); + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetUBAsyncEndpoint(id, jetty, this, worker); + if (ep.Get() == nullptr) { + NN_LOG_ERROR("Failed to create UBSHcomNetEndpoint in Driver " << mName << ", probably out of memory"); + ClearJettyResource(jetty); + return NN_NEW_OBJECT_FAILED; + } + + if (mEnableTls) { + auto childEp = ep.ToChild(); + if (NN_UNLIKELY(childEp == nullptr)) { + NN_LOG_ERROR("To child Failed"); + return NN_ERROR; + } + auto tmp = dynamic_cast(conn); + if (NN_UNLIKELY(tmp == nullptr)) { + NN_LOG_ERROR("dynamic cast error"); + return NN_OOB_SEC_PROCESS_ERROR; + } + childEp->EnableEncrypt(mOptions); + childEp->SetSecrets(tmp->Secret()); + } + + ep->StoreConnInfo(NetFunc::GetIpByFd(conn->GetFd()), conn->ListenPort(), header.version, payload); + ep.ToChild()->SetRemoteHbInfo(jetty->GetExchangeInfo().hbAddress, + jetty->GetExchangeInfo().hbKey, jetty->GetExchangeInfo().hbMrSize); + + // receive server ready signal + int8_t ready = -1; + result = conn->Receive(&ready, sizeof(int8_t)); + if (result != 0 || ready != 1) { + NN_LOG_ERROR("Failed to connect to server as server not responses or return not ready, result " << result); + ClearJettyResource(jetty); + return NN_ERROR; + } + + // \see NetDriverUBWithOob::NewConnectionCB + jetty->SetUpContext(reinterpret_cast(ep.Get())); + ep->State().Set(NEP_ESTABLISHED); + { + std::lock_guard locker(mEndPointsMutex); + mEndPoints.emplace(ep->Id(), ep); + } + + NN_LOG_INFO("New connect to " << oobIp << ":" << oobPort << " established, async ep id: " << ep->Id() + << ", jetty id: " << jetty->QpNum() << ", worker info " << worker->DetailName()); + outEp = ep; + reinterpret_cast(ep.Get())->GetQp()->SetUpId(ep->Id()); + auto endCreateEp = NetMonotonic::TimeUs(); + auto createEpTime = endCreateEp - startCreateEp; + if (NN_UNLIKELY(createEpTime > MAX_OP_TIME_US)) { + NN_LOG_WARN("Create endpoint time too long: " << createEpTime << " us."); + } + return NN_OK; +} + +NResult NetDriverUBWithOob::ConnectSyncEp(const std::string &oobIp, uint16_t oobPort, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint32_t flags, uint8_t serverGrpNo, uint64_t ctx) +{ + OOBTCPClientPtr client; + if (mEnableTls) { + client = new (std::nothrow) + OOBSSLClient(mOptions.oobType, oobIp, oobPort, mTlsPrivateKeyCB, mTlsCertCB, mTlsCaCallback); + NN_ASSERT_LOG_RETURN(client.Get() != nullptr, NN_NEW_OBJECT_FAILED) + client.ToChild()->SetTlsOptions(mOptions); + client.ToChild()->SetPSKCallback(mPskFindSessionCb, mPskUseSessionCb); + } else { + client = new (std::nothrow) OOBTCPClient(mOptions.oobType, oobIp, oobPort); + NN_ASSERT_LOG_RETURN(client.Get() != nullptr, NN_NEW_OBJECT_FAILED) + } + + /* try to connect to oob server */ + OOBTCPConnection *conn = nullptr; + NResult result = NN_OK; + if ((result = client->Connect(conn)) != 0) { + NN_LOG_ERROR("Failed to connect server via oob, result " << result); + return result; + } + + NetLocalAutoDecreasePtr autoDecPtr(conn); + conn->SetIpAndPort(oobIp, oobPort); + + if (NN_UNLIKELY(OOBSecureProcess::SecProcessInOOBClient(mSecInfoProvider, mSecInfoValidator, conn, mName, ctx, + mOptions.secType))) { + return NN_OOB_SEC_PROCESS_ERROR; + } + + UBPollingMode pollMode = ((flags & NET_EP_EVENT_POLLING)) ? UB_EVENT_POLLING : UB_BUSY_POLLING; + + auto prePostCount = mOptions.prePostReceiveSizePerQP; + + // create qp and cq + UBJetty *qp = nullptr; + UBJfc *cq = nullptr; + JettyOptions qpOptions(mOptions.qpSendQueueSize, mOptions.qpReceiveQueueSize, mOptions.mrSendReceiveSegSize, + mOptions.prePostReceiveSizePerQP, mOptions.slave, mOptions.ubcMode); + if ((result = NetUBSyncEndpoint::CreateResources(mName, mContext, pollMode, qpOptions, qp, cq)) != 0) { + NN_LOG_ERROR("Failed to create qp and cq, result " << result); + return result; + } + qp->SetName(mName); + NetLocalAutoDecreasePtr qpAutoDecPtr(qp); + NetLocalAutoDecreasePtr cqAutoDecPtr(cq); + + if (cq->Initialize() != 0) { + NN_LOG_ERROR("Failed to initialize cq for new connection in Driver " << mName); + return NN_ERROR; + } + uint32_t token = GenerateSecureRandomUint32(); + UBJettyExchangeInfo info{}; + info.token = token; + if (qp->Initialize(mOptions.mrSendReceiveSegCount, 0, token) != 0) { + NN_LOG_ERROR("Failed to initialize qp for new connection in Driver " << mName); + return NN_ERROR; + } + + /* send connection header */ + ConnectHeader header; + SetConnHeader(header, mOptions.magic, mOptions.version, serverGrpNo, Protocol(), mMajorVersion, mMinorVersion, + mOptions.tlsVersion); + + if ((result = conn->Send(&header, sizeof(ConnectHeader))) != 0) { + NN_LOG_ERROR("Failed to send server worker grpno in Driver " << mName << ", result " << result); + return result; + } + + /* receive connect response and peer ep id */ + ConnRespWithUId respWithUId{}; + void *ackBuf = static_cast(&respWithUId); + if ((result = conn->Receive(ackBuf, sizeof(ConnRespWithUId))) != 0) { + NN_LOG_ERROR("Failed receive ServerAck in Driver " << mName << ", result " << result); + return result; + } + + /* connect response */ + auto serverAck = respWithUId.connResp; + if (serverAck == MAGIC_MISMATCH) { + NN_LOG_ERROR("Failed to pass server magic validation " << mName << ",magic " << header.magic << ", result " << + serverAck); + return NN_CONNECT_REFUSED; + } + + if (serverAck == WORKER_GRPNO_MISMATCH || serverAck == WORKER_NOT_STARTED) { + NN_LOG_ERROR("Failed to choose worker or not started " << mName << ", result " << serverAck); + return NN_CONNECT_REFUSED; + } + + if (serverAck == PROTOCOL_MISMATCH) { + NN_LOG_ERROR("Failed to pass server protocol validation " << mName << ", result " << serverAck); + return NN_CONNECT_PROTOCOL_MISMATCH; + } + + if (serverAck == SERVER_INTERNAL_ERROR) { + NN_LOG_ERROR("Server error happened, connection refused " << mName << ", result " << serverAck); + return NN_ERROR; + } + + if (serverAck != OK) { + NN_LOG_ERROR("Server error happened, connection refused " << mName << ", result " << serverAck); + return NN_ERROR; + } + + /* peer ep id */ + auto id = respWithUId.epId; + NN_LOG_TRACE_INFO("new ep id will be set as " << id << " in driver " << mName); + // exchange info + NN_LOG_TRACE_INFO("get and send exchange info of ep"); + if (mHeartBeat != nullptr) { + mHeartBeat->GetRemoteHbInfo(info); + } + info.maxSendWr = mOptions.qpSendQueueSize; + info.maxReceiveWr = mOptions.qpReceiveQueueSize; + info.receiveSegSize = mOptions.mrSendReceiveSegSize; + info.receiveSegCount = mOptions.prePostReceiveSizePerQP; + if (((result = qp->FillExchangeInfo(info)) != 0)) { + NN_LOG_ERROR("Failed to get ep exchange info in Driver " << mName << ", result " << result); + return result; + } + if (((result = conn->Send(&info, sizeof(UBJettyExchangeInfo))) != 0)) { + NN_LOG_ERROR("Failed to send ep exchange info in Driver " << mName << ", result " << result); + return result; + } + + // send payload len + uint32_t payloadLen = payload.length(); + if ((result = conn->Send(&payloadLen, sizeof(uint32_t))) != 0) { + NN_LOG_ERROR("Failed to send connection payload length in Driver " << mName << ", result " << result); + return result; + } + + // send payload + if (payloadLen > 0) { + auto payloadPtr = reinterpret_cast(const_cast(payload.c_str())); + if ((result = conn->Send(payloadPtr, payloadLen)) != 0) { + NN_LOG_ERROR("Failed to send connection payload in Driver " << mName << ", result " << result); + return result; + } + } + + std::unique_ptr peerExInfo(new (std::nothrow) UBJettyExchangeInfo); + if (!peerExInfo) { + NN_LOG_ERROR("Failed to alloc UBJettyExchangeInfo in Driver " << mName); + return NN_MALLOC_FAILED; + } + + if ((result = conn->Receive(peerExInfo.get(), sizeof(UBJettyExchangeInfo))) != 0) { + NN_LOG_ERROR("Failed to receive ep exchange info in Driver " << mName << ", result " << result); + return NN_ERROR; + } + qp->StoreExchangeInfo(peerExInfo.release()); + + NN_LOG_TRACE_INFO("remote qp ex info lid " << info.lid << ", qpn " << info.qpn << ", gid interface " << + info.gid.global.interface_id << ", pre-post-receive-count " << info.receiveSegCount); + if ((result = qp->ChangeToReady(qp->GetExchangeInfo())) != 0) { + NN_LOG_ERROR("Failed to change ep to ready in Driver " << mName << ", result " << result); + return result; + } + + qp->SetPeerIpAndPort(conn->GetIpAndPort()); + + auto *mrSegs = new (std::nothrow) uintptr_t[prePostCount]; + if (mrSegs == nullptr) { + NN_LOG_ERROR("Failed to create mr address array in Driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + NetLocalAutoFreePtr segAutoDelete(mrSegs, true); + + if (!qp->GetFreeBufferN(mrSegs, prePostCount)) { + NN_LOG_ERROR("Failed to get free mr from pool, result " << result); + return NN_ERROR; + } + + // create endpoint + static UBSHcomNetWorkerIndex workerIndex; + workerIndex.driverIdx = mIndex; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetUBSyncEndpoint(id, qp, cq, prePostCount + NN_NO4, this, + workerIndex); + if (ep.Get() == nullptr) { + NN_LOG_ERROR("Failed to create UBSHcomNetEndpoint in Driver " << mName << ", probably out of memory"); + // do later: handle pre post-ed mr + return NN_NEW_OBJECT_FAILED; + } + NN_LOG_INFO("Create sync ep success, ep id: " << ep->Id() << ", with jetty id: " << qp->QpNum()); + + if (reinterpret_cast(ep.Get())->mCtxPool.Initialize() != UB_OK) { + NN_LOG_ERROR("Fail to initialize mCtxPool"); + } + + for (int i = 0; i < prePostCount; i++) { + result = reinterpret_cast(ep.Get())->PostReceive(mrSegs[i], mOptions.mrSendReceiveSegSize, + reinterpret_cast(qp->GetMemorySeg())); + if (result != 0) { + // do later if failure, qp should break at this time + return result; + } + } + + if (mEnableTls) { + auto childEp = ep.ToChild(); + if (NN_UNLIKELY(childEp == nullptr)) { + NN_LOG_ERROR("To child Failed"); + return NN_ERROR; + } + auto tmp = dynamic_cast(conn); + if (NN_UNLIKELY(tmp == nullptr)) { + NN_LOG_ERROR("dynamic cast error"); + return NN_OOB_SEC_PROCESS_ERROR; + } + childEp->EnableEncrypt(mOptions); + childEp->SetSecrets(tmp->Secret()); + } + ep->StoreConnInfo(NetFunc::GetIpByFd(conn->GetFd()), conn->ListenPort(), header.version, payload); + + // receive server ready signal + int8_t ready = -1; + result = conn->Receive(&ready, sizeof(int8_t)); + if (result != 0 || ready != 1) { + NN_LOG_ERROR("Failed to connect to server as server not respond or return not ready, result " << result); + // do later: handle pre post-ed mr + return NN_ERROR; + } + + // SyncEP 不会在 UBWorker中处理事件,为保持一致性与AsyncEP采用一样顺序。 + // \see NetDriverUBWithOob::NewConnectionCB + qp->SetUpContext(reinterpret_cast(ep.Get())); + ep->State().Set(NEP_ESTABLISHED); + { + std::lock_guard locker(mEndPointsMutex); + mEndPoints.emplace(id, ep); + } + + NN_LOG_INFO("New connect to " << oobIp << ":" << oobPort << " established, sync ep id " << ep->Id()); + outEp = ep; + reinterpret_cast(ep.Get())->mJetty->SetUpId(ep->Id()); + return NN_OK; +} + +void NetDriverUBWithOob::ProcessErrorNewRequest(UBOpContextInfo *ctx) +{ + if (NN_UNLIKELY(ctx == nullptr || ctx->ubJetty == nullptr || ctx->ubJetty->GetUpContext1() == 0)) { + NN_LOG_ERROR("Ctx or QP or Worker is null of NewRequest in Driver " << mName << ""); + return; + } + + if (ctx->opType == UBOpContextInfo::RECEIVE) { + ctx->ubJetty->ReturnBuffer(ctx->mrMemAddr); + auto worker = reinterpret_cast(ctx->ubJetty->GetUpContext1()); + worker->ReturnOpContextInfo(ctx); + // not receive remote data, do not call user callback + } else { + NN_LOG_WARN("Unreachable path"); + } +} + +int NetDriverUBWithOob::SendRawSglFinishedCB(UBOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx) +{ + int result = 0; + + auto worker = reinterpret_cast(ctx->ubJetty->GetUpContext1()); + auto sgeCtx = reinterpret_cast(ctx->upCtx); + auto sglCtx = sgeCtx->ctx; + result = UBOpContextInfo::GetNResult(ctx->opResultType); + // set context + netCtx.mEp.Set(reinterpret_cast(ctx->ubJetty->GetUpContext())); + netCtx.mResult = sglCtx->result < result ? result : sglCtx->result; + netCtx.mOpType = UBSHcomNetRequestContext::NN_SENT_RAW_SGL; + netCtx.mHeader.Invalid(); + netCtx.mMessage = nullptr; + if (NN_UNLIKELY(memcpy_s(netCtx.iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, sglCtx->iov, + sizeof(UBSHcomNetTransSgeIov) * sglCtx->iovCount) != UB_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return UB_PARAM_INVALID; + } + netCtx.mOriginalSglReq.iov = netCtx.iov; + netCtx.mOriginalSglReq.iovCount = sglCtx->iovCount; + netCtx.mOriginalSglReq.upCtxSize = sglCtx->upCtxSize; + if (netCtx.mOriginalSglReq.upCtxSize > 0 && + netCtx.mOriginalSglReq.upCtxSize <= sizeof(UBSHcomNetTransSglRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(netCtx.mOriginalSglReq.upCtxData, NN_NO16, sglCtx->upCtx, sglCtx->upCtxSize) != + UB_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return UB_PARAM_INVALID; + } + } + worker->ReturnSglContextInfo(sglCtx); + // called to callback + if (NN_UNLIKELY((result = mRequestPostedHandler(netCtx)) != UB_OK)) { + NN_LOG_ERROR("Call requestPostedHandler in Driver " << mName << " return non-zero for sgl type " << + ctx->opType << " done"); + } + netCtx.mEp.Set(nullptr); + + // buffer should return when encrypt + if (mEnableTls) { + (void)mDriverSendMR->ReturnBuffer(ctx->mrMemAddr); + } + + worker->ReturnOpContextInfo(ctx); + + return NN_OK; +} + +void PrintSendFinishDebug(UBSHcomNetTransHeader &header, UBOpContextInfo *ctx) +{ + UBSHcomNetEndpointPtr debugEp = reinterpret_cast(ctx->ubJetty->GetUpContext()); + uint64_t epId = debugEp->Id(); + if (ctx->opType == UBOpContextInfo::SEND) { + NN_LOG_DEBUG("[Request Send] ------ ep id = " << epId << ", headerCrc = " << header.headerCrc + << ", opCode = " << header.opCode << ", flags = " << header.flags << ", seqNo = " << header.seqNo + << ",timeout = " << header.timeout << ", errCode = " << header.errorCode << ", dataLength = " + << header.dataLength << ", status = " << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::POLLED)); + } else { + NN_LOG_DEBUG("[Request Send] ------ raw request, ep id = " << epId << "dataLength = " << ctx->dataSize << + ", status = " << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::POLLED)); + } +} + +int NetDriverUBWithOob::SendSglInlineFinishedCB(UBOpContextInfo *ctx, UBSHcomNetRequestContext &requestCtx, + UBWorker *worker) +{ + int result = 0; + requestCtx.mHeader.Invalid(); + requestCtx.mResult = UBOpContextInfo::GetNResult(ctx->opResultType); + requestCtx.mEp.Set(reinterpret_cast(ctx->ubJetty->GetUpContext())); + requestCtx.mMessage = nullptr; + requestCtx.mOpType = UBSHcomNetRequestContext::NN_SENT_SGL_INLINE; + requestCtx.mOriginalReq = {}; + requestCtx.mOriginalReq.lAddress = 0; + requestCtx.mOriginalReq.size = ctx->dataSize; + requestCtx.mOriginalReq.upCtxSize = ctx->upCtxSize; + + if (requestCtx.mOriginalReq.upCtxSize > 0 && + requestCtx.mOriginalReq.upCtxSize <= sizeof(UBSendReadWriteRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(requestCtx.mOriginalReq.upCtxData, ctx->upCtxSize, ctx->upCtx, ctx->upCtxSize) != + UB_OK)) { + NN_LOG_ERROR("Failed to copy req to ctx"); + return UB_PARAM_INVALID; + } + } + // return context to worker, and ctx is set null, not usable anymore + worker->ReturnOpContextInfo(ctx); + // call to callback + if (NN_UNLIKELY((result = mRequestPostedHandler(requestCtx)) != UB_OK)) { + NN_LOG_ERROR("Call requestPostedHandler in Driver " << mName << + " return non-zero for receive message [dataSize " << requestCtx.mHeader.dataLength << "]"); + } + requestCtx.mEp.Set(nullptr); + return NN_OK; +} + +int NetDriverUBWithOob::SendFinishedCB(UBOpContextInfo *ctx) +{ + using NRC = UBSHcomNetRequestContext; + int result = 0; + static thread_local UBSHcomNetRequestContext requestCtx{}; + ctx->ubJetty->ReturnPostSendWr(); + auto worker = reinterpret_cast(ctx->ubJetty->GetUpContext1()); + if (ctx->opType == UBOpContextInfo::SEND || ctx->opType == UBOpContextInfo::SEND_RAW) { + if (ctx->opType == UBOpContextInfo::SEND) { + (void)memcpy_s(&(requestCtx.mHeader), sizeof(UBSHcomNetTransHeader), + reinterpret_cast(ctx->mrMemAddr), sizeof(UBSHcomNetTransHeader)); + } else { + requestCtx.mHeader.Invalid(); + } + PrintSendFinishDebug(requestCtx.mHeader, ctx); + requestCtx.mResult = UBOpContextInfo::GetNResult(ctx->opResultType); + requestCtx.mEp.Set(reinterpret_cast(ctx->ubJetty->GetUpContext())); + requestCtx.mMessage = nullptr; + requestCtx.mOpType = ctx->opType == UBOpContextInfo::SEND ? NRC::NN_SENT : NRC::NN_SENT_RAW; + requestCtx.mOriginalReq = {}; + // if PostSend implement with one side memory, the lAddress should be valued with ctx->mrMemAddr. + requestCtx.mOriginalReq.lAddress = 0; + requestCtx.mOriginalReq.size = ctx->dataSize; + requestCtx.mOriginalReq.upCtxSize = ctx->upCtxSize; + + if (requestCtx.mOriginalReq.upCtxSize > 0 && + requestCtx.mOriginalReq.upCtxSize <= sizeof(UBSendReadWriteRequest::upCtxData)) { + (void)memcpy_s(requestCtx.mOriginalReq.upCtxData, ctx->upCtxSize, ctx->upCtx, ctx->upCtxSize); + } + + if (NN_UNLIKELY(!mDriverSendMR->ReturnBuffer(ctx->mrMemAddr))) { + NN_LOG_ERROR("Failed to return mr segment back in Driver " << mName); + } + + // return context to worker, and ctx is set null, not usable anymore + worker->ReturnOpContextInfo(ctx); + if (requestCtx.mHeader.opCode == HB_SEND_OP || requestCtx.mHeader.opCode == HB_RECV_OP) { + return NN_OK; + } + // call to callback + if (NN_UNLIKELY((result = mRequestPostedHandler(requestCtx)) != UB_OK)) { + NN_LOG_ERROR("Call requestPostedHandler in Driver " << mName + << " return non-zero for receive message [opCode: " << requestCtx.mHeader.opCode << ", dataSize " + << requestCtx.mHeader.dataLength << "]"); + } + requestCtx.mEp.Set(nullptr); + } else if (ctx->opType == UBOpContextInfo::SEND_RAW_SGL) { + return SendRawSglFinishedCB(ctx, requestCtx); + } else if (ctx->opType == UBOpContextInfo::SEND_SGL_INLINE) { + return SendSglInlineFinishedCB(ctx, requestCtx); + } else { + NN_LOG_WARN("Unreachable path"); + } + + return NN_OK; +} + +void NetDriverUBWithOob::ProcessErrorSendFinished(UBOpContextInfo *ctx) +{ + if (NN_UNLIKELY(ctx == nullptr || ctx->ubJetty == nullptr || ctx->ubJetty->GetUpContext1() == 0)) { + NN_LOG_ERROR("Ctx or QP or Worker is null of SendFinished in Driver " << mName << ""); + return; + } + + SendFinishedCB(ctx); +} + +int NetDriverUBWithOob::RWSglOneSideDoneCB(UBOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx) +{ + int result = 0; + auto worker = reinterpret_cast(ctx->ubJetty->GetUpContext1()); + auto sgeCtx = reinterpret_cast(ctx->upCtx); + auto sglContext = sgeCtx->ctx; + result = UBOpContextInfo::GetNResult(ctx->opResultType); + sglContext->result = sglContext->result < result ? result : sglContext->result; + auto refCount = __sync_add_and_fetch(&(sglContext->refCount), 1); + if (refCount != sglContext->iovCount) { + worker->ReturnOpContextInfo(ctx); + return NN_OK; + } + // set context + netCtx.mEp.Set(reinterpret_cast(ctx->ubJetty->GetUpContext())); + + NN_LOG_DEBUG("[Request RWSglOneSideDoneCB] ------ ep id = " << netCtx.mEp->Id() << ", opType = " << + static_cast(ctx->opType) << ", lKey = " << ctx->lKey << ", size = " << ctx->dataSize << + ", header opcode = " << netCtx.mHeader.opCode << ", seqNo = " << netCtx.mHeader.seqNo << ", status = " << + UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::POLLED)); + + netCtx.mResult = sglContext->result; + netCtx.mOpType = + ctx->opType == UBOpContextInfo::SGL_WRITE ? UBSHcomNetRequestContext::NN_SGL_WRITTEN : + UBSHcomNetRequestContext::NN_SGL_READ; + netCtx.mHeader.Invalid(); + netCtx.mMessage = nullptr; + if (NN_UNLIKELY(memcpy_s(netCtx.iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, sglContext->iov, + sizeof(UBSHcomNetTransSgeIov) * sglContext->iovCount) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + netCtx.mOriginalSglReq.iov = netCtx.iov; + netCtx.mOriginalSglReq.iovCount = sglContext->iovCount; + netCtx.mOriginalSglReq.upCtxSize = sglContext->upCtxSize; + if (netCtx.mOriginalSglReq.upCtxSize > 0 && + netCtx.mOriginalSglReq.upCtxSize <= sizeof(UBSHcomNetTransSglRequest::upCtxData)) { + if (NN_UNLIKELY(memcpy_s(netCtx.mOriginalSglReq.upCtxData, NN_NO16, + sglContext->upCtx, sglContext->upCtxSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + } + worker->ReturnSglContextInfo(sglContext); + // called to callback + if (NN_UNLIKELY((result = mOneSideDoneHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call oneSideDoneHandler in Driver " << mName << " return non-zero for sgl type " << ctx->opType << + " done"); + } + netCtx.mEp.Set(nullptr); + worker->ReturnOpContextInfo(ctx); + + return NN_OK; +} + +int NetDriverUBWithOob::OneSideDoneCB(UBOpContextInfo *ctx) +{ + int result = 0; + static thread_local UBSHcomNetRequestContext netCtx{}; + auto worker = reinterpret_cast(ctx->ubJetty->GetUpContext1()); + ctx->ubJetty->ReturnOneSideWr(); + if (ctx->opType == UBOpContextInfo::WRITE || ctx->opType == UBOpContextInfo::READ) { + // set context + netCtx.mResult = UBOpContextInfo::GetNResult(ctx->opResultType); + netCtx.mEp.Set(reinterpret_cast(ctx->ubJetty->GetUpContext())); + NN_LOG_DEBUG("[Request oneSideDown] ------ ep id = " << netCtx.mEp->Id() << ", opType = " << + static_cast(ctx->opType) << ", lKey = " << ctx->lKey << ", size = " << ctx->dataSize << + ", status = " << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::POLLED)); + netCtx.mOpType = + ctx->opType == UBOpContextInfo::WRITE ? UBSHcomNetRequestContext::NN_WRITTEN : + UBSHcomNetRequestContext::NN_READ; + netCtx.mHeader.Invalid(); + netCtx.mMessage = nullptr; + netCtx.mOriginalReq.lAddress = ctx->mrMemAddr; + netCtx.mOriginalReq.lKey = ctx->lKey; + netCtx.mOriginalReq.size = ctx->dataSize; + netCtx.mOriginalReq.upCtxSize = ctx->upCtxSize; + + if (netCtx.mOriginalReq.upCtxSize > 0 && + netCtx.mOriginalReq.upCtxSize <= sizeof(UBSendReadWriteRequest::upCtxData)) { + (void)memcpy_s(netCtx.mOriginalReq.upCtxData, ctx->upCtxSize, ctx->upCtx, ctx->upCtxSize); + } + + // return context to worker and ctx is not usable anymore + worker->ReturnOpContextInfo(ctx); + + // called to callback + if (NN_UNLIKELY((result = mOneSideDoneHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call oneSideDoneHandler in Driver " << mName << " failed."); + } + netCtx.mEp.Set(nullptr); + } else if (ctx->opType == UBOpContextInfo::SGL_WRITE || ctx->opType == UBOpContextInfo::SGL_READ) { + return RWSglOneSideDoneCB(ctx, netCtx); + } else if (ctx->opType == UBOpContextInfo::HB_WRITE) { + auto ep = reinterpret_cast(ctx->ubJetty->GetUpContext()); + if (ctx->opResultType == UBOpContextInfo::SUCCESS) { + ep->HbRecordCount(); + } + + worker->ReturnOpContextInfo(ctx); + } else { + NN_LOG_WARN("Unreachable path"); + } + + return NN_OK; +} + +void NetDriverUBWithOob::ProcessErrorOneSideDone(UBOpContextInfo *ctx) +{ + if (NN_UNLIKELY(ctx == nullptr || ctx->ubJetty == nullptr || ctx->ubJetty->GetUpContext1() == 0)) { + NN_LOG_ERROR("Ctx or QP or Worker is null of OneSidedone in Driver " << mName << ""); + return; + } + + OneSideDoneCB(ctx); +} + +void NetDriverUBWithOob::ProcessEpError(uintptr_t ep) +{ + auto endpointPtr = reinterpret_cast(ep); + + // UBWorker 线程与心跳线程只会有一个成功 + bool process = false; + if (NN_UNLIKELY(!endpointPtr->EPBrokenProcessed().compare_exchange_strong(process, true))) { + NN_LOG_WARN("Ep id " << endpointPtr->Id() << " broken handled by other thread"); + return; + } + + if (endpointPtr->State().Compare(NEP_ESTABLISHED)) { + endpointPtr->State().Set(NEP_BROKEN); + } + + // 这里存在大段注释,解释了一些极限情况下的异常可能性 + auto qp = endpointPtr->GetQp(); + qp->Stop(); + + NN_LOG_WARN("Handle Ep state " << NEPStateToString(endpointPtr->State().Get()) << ", Ep id " << endpointPtr->Id() << + " , try call Ep broken handle"); + mEndPointBrokenHandler(endpointPtr); +} + +void NetDriverUBWithOob::ProcessQPError(UBOpContextInfo *ctx) +{ + if (NN_UNLIKELY(ctx == nullptr || ctx->ubJetty == nullptr || ctx->ubJetty->GetUpContext1() == 0)) { + NN_LOG_ERROR("Ctx or QP or Worker is null of ProcessQPError in Driver " << mName << ""); + return; + } + + // get ep + auto epPtr = reinterpret_cast(ctx->ubJetty->GetUpContext()); + ProcessEpError(reinterpret_cast(epPtr)); +} + +void NetDriverUBWithOob::ProcessTwoSideHeartbeat(UBOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx) +{ + auto tmpEp = reinterpret_cast(ctx->ubJetty->GetUpContext()); + if (netCtx.mHeader.opCode == HB_SEND_OP) { + char data; + UBSHcomNetTransRequest req((void *)(&data), sizeof(data), 0); + tmpEp->PostSend(HB_RECV_OP, req, 0); + netCtx.mEp.Set(nullptr); + return; + } + if (netCtx.mHeader.opCode == HB_RECV_OP) { + tmpEp->HbRecordCount(); + netCtx.mEp.Set(nullptr); + return; + } +} + +NResult NetDriverUBWithOob::NewRequestOnEncryption(UBOpContextInfo *ctx, UBSHcomNetMessage &msg, bool &messageReady, + UBSHcomNetRequestContext &netCtx) +{ + if (NN_UNLIKELY(ctx == nullptr || ctx->ubJetty == nullptr || ctx->ubJetty->GetUpContext1() == 0 || + ctx->ubJetty->GetUpContext() == 0)) { + NN_LOG_ERROR("Ctx or QP or Worker or ep is null of RequestReceived in Driver " << mName); + return NN_INVALID_PARAM; + } + auto ubWorker = reinterpret_cast(ctx->ubJetty->GetUpContext1()); + auto qpUpContext = ctx->ubJetty->GetUpContext(); + auto *tmpHeader = reinterpret_cast(ctx->mrMemAddr); + UBSHcomNetEndpointPtr epPtr = reinterpret_cast(qpUpContext); + auto asyncEp = epPtr.ToChild(); + if (asyncEp == nullptr) { + NN_LOG_ERROR("Failed to get async ep"); + ubWorker->RePostReceive(ctx); + return NN_ERROR; + } + if (!asyncEp->mIsNeedEncrypt) { + NN_LOG_ERROR("Failed to validate encrypt by driver support but ep not."); + ubWorker->RePostReceive(ctx); + return NN_INVALID_PARAM; + } + size_t decryptRawLen = asyncEp->mAes.GetRawLen(tmpHeader->dataLength); + messageReady = msg.AllocateIfNeed(decryptRawLen); + if (NN_LIKELY(messageReady)) { + uint32_t decryptLen = 0; + if (!asyncEp->mAes.Decrypt(asyncEp->mSecrets, reinterpret_cast(ctx->mrMemAddr + + sizeof(UBSHcomNetTransHeader)), tmpHeader->dataLength, msg.mBuf, decryptLen)) { + NN_LOG_ERROR("Failed to decrypt data"); + (void)ubWorker->RePostReceive(ctx); + return NN_DECRYPT_FAILED; + } + if (memcpy_s(&(netCtx.mHeader), sizeof(UBSHcomNetTransHeader), tmpHeader, + sizeof(UBSHcomNetTransHeader)) != NN_OK) { + NN_LOG_ERROR("Failed to memcpy to netCtx header"); + return NN_ERROR; + } + msg.mDataLen = decryptRawLen; + } + return NN_OK; +} + +int NetDriverUBWithOob::NewRequest(UBOpContextInfo *ctx) +{ + if (NN_UNLIKELY(!ValidateRequestContext(ctx))) { + return NN_ERROR; + } + + if (NN_UNLIKELY(ctx->opResultType != UBOpContextInfo::SUCCESS)) { + ProcessQPError(ctx); + return NN_OK; + } + + static thread_local UBSHcomNetRequestContext netCtx{}; + auto worker = reinterpret_cast(ctx->ubJetty->GetUpContext1()); + uint32_t immData = *reinterpret_cast(ctx->upCtx); + bool messageReady = true; + auto qpUpContext = ctx->ubJetty->GetUpContext(); + if (ctx->opType == UBOpContextInfo::RECEIVE && immData == 0) { + static thread_local UBSHcomNetMessage msg; + auto *tmpHeader = reinterpret_cast(ctx->mrMemAddr); + + UBSHcomNetEndpointPtr debugEp = reinterpret_cast(qpUpContext); + uint64_t epId = debugEp->Id(); + auto tmpAsyncEp = debugEp.ToChild(); + UBSHcomNetTransHeader *header = (UBSHcomNetTransHeader *)tmpHeader; + NN_LOG_DEBUG("[Request Recv] ------ common request, ep id = " << epId << ", headerCrc = " + << header->headerCrc << ", opCode = " << header->opCode << ", flags=" << header->flags << ", seqNo=" + << header->seqNo << ",timeout=" << header->timeout << ", errCode=" << header->errorCode << ", dataLength=" + << header->dataLength << " dataSize = " << ctx->dataSize << ", status = " << + UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::POLLED)); + + if (NN_UNLIKELY(NetFunc::ValidateHeaderWithDataSize(*tmpHeader, ctx->dataSize) != NN_OK)) { + NN_LOG_ERROR("Failed to validate received header " << tmpHeader->headerCrc); + worker->RePostReceive(ctx); + return NN_VALIDATE_HEADER_CRC_INVALID; + } + if (NN_LIKELY(!mOptions.enableTls)) { + messageReady = msg.AllocateIfNeed(tmpHeader->dataLength); + if (NN_LIKELY(messageReady)) { + (void)memcpy_s(&(netCtx.mHeader), sizeof(UBSHcomNetTransHeader), tmpHeader, + sizeof(UBSHcomNetTransHeader)); + (void)memcpy_s(msg.mBuf, tmpHeader->dataLength, + reinterpret_cast(ctx->mrMemAddr + sizeof(UBSHcomNetTransHeader)), tmpHeader->dataLength); + msg.mDataLen = tmpHeader->dataLength; + } + } else { + if (NewRequestOnEncryption(ctx, msg, messageReady, netCtx) != NN_OK) { + NN_LOG_ERROR("Failed to decrypt new request"); + return NN_ERROR; + } + } + + int result = 0; + if (NN_UNLIKELY((result = worker->RePostReceive(ctx)) != 0)) { + NN_LOG_ERROR("Failed to repost receive in Driver " << mName << ", result " << result); + } + + if (NN_UNLIKELY(!messageReady)) { + NN_LOG_ERROR("Failed to build UBSHcomNetRequestContext or message in Driver " << mName << + ", receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << msg.mDataLen << + "] will be dropped"); + return NN_OK; + } + + netCtx.mEp.Set(reinterpret_cast(qpUpContext)); + netCtx.mMessage = &msg; + netCtx.mOpType = UBSHcomNetRequestContext::NN_RECEIVED; + netCtx.mOriginalReq = {}; + netCtx.mHeader.dataLength = msg.mDataLen; + netCtx.extHeaderType = tmpHeader->extHeaderType; // 指导服务层处理 + + // call to callback + if (NN_UNLIKELY((result = mReceivedRequestHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call receivedRequestHandler in Driver " << mName << + " return non-zero for receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << + netCtx.mHeader.dataLength << "]"); + } + netCtx.mEp.Set(nullptr); + } else if (ctx->opType == UBOpContextInfo::RECEIVE && immData != 0) { + static thread_local UBSHcomNetMessage msg; + + UBSHcomNetEndpointPtr debugEp = reinterpret_cast(qpUpContext); + uint64_t epId = debugEp->Id(); + auto tmpAsyncEp = debugEp.ToChild(); + NN_LOG_DEBUG("[Request Recv] ------ raw request, ep id = " << epId << ", seqNo = " << immData + << ", dataSize = " << msg.DataLen() << ", status = " << + UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::POLLED)); + + return NewReceivedRawRequest(ctx, netCtx, msg, worker, immData); + } else { + NN_LOG_WARN("Unreachable path"); + } + + return NN_OK; +} + +NResult NetDriverUBWithOob::NewReceivedRawRequest(UBOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, + UBSHcomNetMessage &msg, UBWorker *worker, uint32_t immData) const +{ /* for raw message */ + bool messageReady = true; + auto qpUpContext = ctx->ubJetty->GetUpContext(); + if (NN_LIKELY(!mOptions.enableTls)) { + messageReady = msg.AllocateIfNeed(ctx->dataSize); + if (NN_LIKELY(messageReady)) { + (void)memcpy_s(msg.mBuf, ctx->dataSize, reinterpret_cast(ctx->mrMemAddr), ctx->dataSize); + msg.mDataLen = ctx->dataSize; + } + } else { + UBSHcomNetEndpointPtr endpointPtr = reinterpret_cast(qpUpContext); + auto childEp = endpointPtr.ToChild(); + if (childEp == nullptr) { + NN_LOG_ERROR("Failed to get async ep"); + worker->RePostReceive(ctx); + return NN_ERROR; + } + if (!childEp->mIsNeedEncrypt) { + NN_LOG_ERROR("Failed to validate encrypt by driver support but ep not."); + worker->RePostReceive(ctx); + return NN_INVALID_PARAM; + } + size_t decryptRawLen = childEp->mAes.GetRawLen(ctx->dataSize); + messageReady = msg.AllocateIfNeed(decryptRawLen); + if (NN_LIKELY(messageReady)) { + uint32_t decryptLen = 0; + if (!childEp->mAes.Decrypt(childEp->mSecrets, reinterpret_cast(ctx->mrMemAddr), ctx->dataSize, + msg.mBuf, decryptLen)) { + NN_LOG_ERROR("Failed to decrypt data"); + (void)worker->RePostReceive(ctx); + return NN_DECRYPT_FAILED; + } + msg.mDataLen = decryptRawLen; + } + } + + int ret = 0; + + // after repost the ctx cannot be used anymore + if (NN_UNLIKELY((ret = worker->RePostReceive(ctx)) != 0)) { + NN_LOG_ERROR("Failed to repost receive in Driver " << mName << ", ret " << ret); + } + + if (NN_UNLIKELY(!messageReady)) { + NN_LOG_ERROR("Failed to build UBSHcomNetRequestContext or message in Driver " << mName << + ", receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << msg.mDataLen << + "] will be dropped"); + return NN_OK; + } + + netCtx.mMessage = &msg; + netCtx.mEp.Set(reinterpret_cast(qpUpContext)); + netCtx.mHeader.Invalid(); + netCtx.mHeader.dataLength = msg.mDataLen; + netCtx.mHeader.seqNo = immData; + netCtx.mOpType = UBSHcomNetRequestContext::NN_RECEIVED_RAW; + netCtx.mOriginalReq = {}; + + // call to callback + if (NN_UNLIKELY((ret = mReceivedRequestHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call receivedRequestHandler in Driver " << mName << + " return non-zero for receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << + netCtx.mHeader.dataLength << "]"); + } + + netCtx.mEp.Set(nullptr); + + return NN_OK; +} + +NResult NetDriverUBWithOob::NewReceivedRequest(UBOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, + UBSHcomNetMessage &msg, UBWorker *worker) const +{ + bool messageReady = true; + auto *tmpHeader = reinterpret_cast(ctx->mrMemAddr); + auto qpUpContext = ctx->ubJetty->GetUpContext(); + + auto rst = NetFunc::ValidateHeaderWithDataSize(*tmpHeader, ctx->dataSize); + if (NN_UNLIKELY(rst != NN_OK)) { + worker->RePostReceive(ctx); + return rst; + } + + // 非加密场景可以免拷贝 + if (NN_LIKELY(!mOptions.enableTls)) { + auto tmpDataAddress = reinterpret_cast(ctx->mrMemAddr + sizeof(UBSHcomNetTransHeader)); + return NewReceivedRequestWithoutCopy(ctx, netCtx, msg, worker, tmpDataAddress, tmpHeader); + } + + UBSHcomNetEndpointPtr ep = reinterpret_cast(qpUpContext); + auto asyncEp = ep.ToChild(); + if (asyncEp == nullptr) { + NN_LOG_ERROR("Failed to get async ep"); + worker->RePostReceive(ctx); + return NN_ERROR; + } + if (!asyncEp->mIsNeedEncrypt) { + NN_LOG_ERROR("Failed to validate encrypt by driver support but ep not."); + worker->RePostReceive(ctx); + return NN_INVALID_PARAM; + } + uint32_t decryptRawLen = asyncEp->mAes.GetRawLen(tmpHeader->dataLength); + messageReady = msg.AllocateIfNeed(decryptRawLen); + if (NN_LIKELY(messageReady)) { + uint32_t decryptLen = 0; + if (!asyncEp->mAes.Decrypt(asyncEp->mSecrets, reinterpret_cast(ctx->mrMemAddr + + sizeof(UBSHcomNetTransHeader)), tmpHeader->dataLength, msg.mBuf, decryptLen)) { + NN_LOG_ERROR("Failed to decrypt data"); + (void)worker->RePostReceive(ctx); + return NN_DECRYPT_FAILED; + } + if (NN_UNLIKELY(memcpy_s(&(netCtx.mHeader), sizeof(UBSHcomNetTransHeader), tmpHeader, + sizeof(UBSHcomNetTransHeader)) != UB_OK)) { + NN_LOG_ERROR("Failed to copy header to netCtx"); + worker->RePostReceive(ctx); + return NN_INVALID_PARAM; + } + msg.mDataLen = decryptRawLen; + } + + int result = 0; + if (NN_UNLIKELY((result = worker->RePostReceive(ctx)) != 0)) { + NN_LOG_ERROR("Failed to repost receive in Driver " << mName << ", result " << result); + } + + if (NN_UNLIKELY(!messageReady)) { + NN_LOG_ERROR("Failed to build UBSHcomNetRequestContext or message in Driver " << mName << + ", receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << msg.mDataLen << + "] will be dropped"); + return NN_OK; + } + + netCtx.mEp.Set(reinterpret_cast(qpUpContext)); + netCtx.mMessage = &msg; + netCtx.mOpType = UBSHcomNetRequestContext::NN_RECEIVED; + netCtx.mOriginalReq = {}; + netCtx.mHeader.dataLength = msg.mDataLen; + + // call to callback + if (NN_UNLIKELY((result = mReceivedRequestHandler(netCtx)) != NN_OK)) { + NN_LOG_ERROR("Call receivedRequestHandler in Driver " << mName << + " return non-zero for receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << + netCtx.mHeader.dataLength << "]"); + } + + netCtx.mEp.Set(nullptr); + + return NN_OK; +} + +NResult NetDriverUBWithOob::NewReceivedRequestWithoutCopy(UBOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, + UBSHcomNetMessage &msg, UBWorker *worker, void *dataAddress, UBSHcomNetTransHeader *header) const +{ + if (NN_UNLIKELY(memcpy_s(&(netCtx.mHeader), sizeof(UBSHcomNetTransHeader), header, + sizeof(UBSHcomNetTransHeader)) != NN_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + return NN_INVALID_PARAM; + } + msg.SetBuf(dataAddress, header->dataLength); + msg.mDataLen = header->dataLength; + + netCtx.mEp.Set(reinterpret_cast(ctx->ubJetty->GetUpContext())); + netCtx.mOpType = UBSHcomNetRequestContext::NN_RECEIVED; + netCtx.mMessage = &msg; + netCtx.mOriginalReq = {}; + netCtx.mHeader.dataLength = msg.mDataLen; + netCtx.extHeaderType = header->extHeaderType; // 指导服务层处理 + int result = 0; + // call to callback + if (NN_UNLIKELY((result = mReceivedRequestHandler(netCtx)) != NN_OK)) { + NN_LOG_WARN("Verbs Call receivedRequestHandler in Driver " << mName << + " return non-zero for receive message [opCode: " << netCtx.mHeader.opCode << ", dataSize " << + netCtx.mHeader.dataLength << "]"); + } + msg.SetBuf(nullptr, 0); + netCtx.mMessage = nullptr; + netCtx.mEp.Set(nullptr); + + if (NN_UNLIKELY((result = worker->RePostReceive(ctx)) != 0)) { + NN_LOG_WARN("Verbs Failed to repost receive in Driver " << mName << ", result " << result); + } + + return NN_OK; +} + +int NetDriverUBWithOob::SendFinished(UBOpContextInfo *ctx) +{ + if (NN_UNLIKELY(!ValidateRequestContext(ctx))) { + return NN_ERROR; + } + + if (NN_UNLIKELY(ctx->HasInternalError())) { + ProcessQPError(ctx); + return NN_OK; + } + + return SendFinishedCB(ctx); +} + +int NetDriverUBWithOob::OneSideDone(UBOpContextInfo *ctx) +{ + if (NN_UNLIKELY(!ValidateRequestContext(ctx))) { + return NN_ERROR; + } + + if (NN_UNLIKELY(ctx->HasInternalError())) { + ProcessQPError(ctx); + return NN_OK; + } + + return OneSideDoneCB(ctx); +} + +NResult NetDriverUBWithOob::Connect(const std::string &serverUrl, const std::string &payload, UBSHcomNetEndpointPtr &ep, + uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) +{ + NetDriverOobType type; + std::string ip; + uint16_t port = 0; + if (NN_UNLIKELY(NetFunc::NN_ValidateUrl(serverUrl) != NN_OK)) { + NN_LOG_ERROR("Invalid url"); + return NN_PARAM_INVALID; + } + if (NN_UNLIKELY(ParseUrl(serverUrl, type, ip, port) != NN_OK)) { + NN_LOG_WARN("Invalid url, url:" << serverUrl); + return NN_INVALID_PARAM; + } + + if (type == NetDriverOobType::NET_OOB_UB) { + OobEidAndJettyId(ip, port); + mOptions.oobType = NetDriverOobType::NET_OOB_UB; + } + + return Connect(ip, port, payload, ep, flags, serverGrpNo, clientGrpNo, ctx); +} +} +} +#endif diff --git a/src/transport/ub/net_ub_driver_oob.h b/src/transport/ub/net_ub_driver_oob.h new file mode 100644 index 0000000000000000000000000000000000000000..91615db715e51740d264628df8d8eee87487f34f --- /dev/null +++ b/src/transport/ub/net_ub_driver_oob.h @@ -0,0 +1,180 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_NET_UB_DRIVER_OOB_H +#define HCOM_NET_UB_DRIVER_OOB_H +#ifdef UB_BUILD_ENABLED + +#include + +#include "net_oob.h" +#include "net_ub_driver.h" +#include "net_util.h" +#include "net_heartbeat.h" +#include "ub_common.h" +#include "ub_mr_fixed_buf.h" +#include "ub_urma_wrapper_public_jetty.h" + +namespace ock { +namespace hcom { +#define PUBLIC_JETTY_NUM_MIN (4) +#define PUBLIC_JETTY_NUM_MAX (1023) + +class NetDriverUBWithOob : public NetDriverUB { +public: + NetDriverUBWithOob(const std::string &name, bool startOobSvr, UBSHcomNetDriverProtocol protocol) + : NetDriverUB(name, startOobSvr, protocol) + { + OBJ_GC_INCREASE(NetDriverUBWithOob); + } + + ~NetDriverUBWithOob() override + { + OBJ_GC_DECREASE(NetDriverUBWithOob); + if (mPublicJetty != nullptr) { + mPublicJetty->DecreaseRef(); + } + } + + NResult Connect(const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, uint8_t serverGrpNo = 0, + uint8_t clientGrpNo = 0) override; + NResult Connect(const std::string &oobIp, uint16_t oobPort, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint32_t flags, uint8_t serverGrpNo = 0, uint8_t clientGrpNo = 0, + uint64_t ctx = 0) override; + NResult Connect(const std::string &serverUrl, const std::string &payload, UBSHcomNetEndpointPtr &ep, uint32_t flags, + uint8_t serverGrpNo = 0, uint8_t clientGrpNo = 0, uint64_t ctx = 0) override; + NResult MultiRailNewConnection(OOBTCPConnection &mConn); + uint16_t GetHbIdleTime() + { + if (mHeartBeat == nullptr) { + NN_LOG_ERROR("mHeartBeat is nullpttr"); + return 0; + } + return mHeartBeat->GetHbIdleTime(); + } + +protected: + int NewConnectionCB(OOBTCPConnection &mConn); + int NewRequest(UBOpContextInfo *ctx); + NResult NewRequestOnEncryption(UBOpContextInfo *ctx, UBSHcomNetMessage &msg, bool &messageReady, + UBSHcomNetRequestContext &netCtx); + int SendFinished(UBOpContextInfo *ctx); + int OneSideDone(UBOpContextInfo *ctx); + + NResult DoInitialize() override; + void DoUnInitialize() override; + + NResult DoStart() override; + void DoStop() override; + + void RunInUbEventThread(); + int SendFinishedCB(UBOpContextInfo *ctx); + int OneSideDoneCB(UBOpContextInfo *ctx); + int SendRawSglFinishedCB(UBOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx); + int RWSglOneSideDoneCB(UBOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx); + int SendSglInlineFinishedCB(UBOpContextInfo *ctx, UBSHcomNetRequestContext &requestCtx, UBWorker *worker); + + void ProcessEpError(uintptr_t ep); + void ProcessQPError(UBOpContextInfo *ctx); + void ProcessErrorNewRequest(UBOpContextInfo *ctx); + void ProcessErrorSendFinished(UBOpContextInfo *ctx); + void ProcessErrorOneSideDone(UBOpContextInfo *ctx); + void ProcessTwoSideHeartbeat(UBOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx); + + NResult ConnectSyncEp(const std::string &oobIp, uint16_t oobPort, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint32_t flags, uint8_t serverGrpNo, uint64_t ctx); + + NResult CreatePublicJetty(UBPublicJetty *&publicJetty, uint32_t id, bool isServer = false); + NResult CreateUrmaListeners(UBPublicJetty *&publicJetty); + NResult PublicJettyNewConnectionCB(UBOpContextInfo *ctx); + NResult ConnectByPublicJetty(const std::string &oobIp, uint16_t oobPort, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint32_t flags, uint8_t serverGrpNo = 0, uint8_t clientGrpNo = 0, + uint64_t ctx = 0); + NResult ConnectSyncEpByPublicJetty(const std::string &oobIp, uint16_t oobPort, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint32_t flags, uint8_t serverGrpNo = 0, uint8_t clientGrpNo = 0, + uint64_t ctx = 0); + NResult ConnectASyncEpByPublicJetty(const std::string &oobIp, uint16_t oobPort, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint32_t flags, uint8_t serverGrpNo = 0, uint8_t clientGrpNo = 0, + uint64_t ctx = 0); + NResult CheckServerACK(JettyConnResp &exchangeMsg); + NResult PrePostReceiveOnConnection(UBJetty *qp, UBWorker *worker); + NResult FillExchMsg(JettyConnHeader *exchangeInfo, UBJetty *qp, const std::string &payload, + uint8_t serverGrpNo, UBPublicJetty *clientPublicJetty); + NResult ServerCreateEp(UBJettyExchangeInfo &info, UBJetty *qp, UBWorker *worker, JettyConnHeader *exchangeInfo, + UBPublicJetty *serverControlJetty); + NResult ServerReplyMsg(UBJetty *qp, JettyConnResp &exchangeMsg, UBPublicJetty *serverControlJetty, + uint32_t token = 0); + NResult ServerCreateJetty(UBJetty *&qp, UBWorker *worker, JettyConnResp &exchangeMsg, JettyConnHeader *exchangeInfo, + UBPublicJetty *serverControlJetty, uint32_t token = 0); + NResult ServerSelectWorker(UBWorker *&worker, JettyConnResp &exchangeMsg, uint8_t groupIndex, + UBPublicJetty *serverControlJetty); + NResult CheckMagicAndProtocol(JettyConnResp &exchangeMsg, JettyConnHeader *exchangeInfo, + UBPublicJetty *serverControlJetty); + NResult ClientCheckState(const std::string &payload); + NResult PublicJettyConnect(const std::string &oobIp, uint16_t oobPort, UBPublicJetty *&clientPublicJetty); + NResult ClientSelectWorker(UBWorker *&worker, uint8_t clientGrpNo, uint64_t id); + NResult ClientCreateJetty(UBJetty *&qp, UBWorker *worker, uint32_t token = 0); + NResult ClientSendConnReq(const std::string payload, uint64_t id, uint8_t serverGrpNo, + UBPublicJetty *clientPublicJetty, UBJetty *qp, UBPublicJetty *clientControlJetty, uint32_t token = 0); + NResult ClientEstablishConnOnReply(UBPublicJetty *clientControlJetty, UBJetty *qp, UBJettyExchangeInfo &info); + NResult ClientCreateEp(UBSHcomNetEndpointPtr &outEp, uint64_t id, UBJetty *qp, UBWorker *worker, + UBJettyExchangeInfo &info, UBPublicJetty *clientControlJetty); + NResult ClientSyncEpCreateJetty(UBJetty *&qp, UBJfc *&cq, UBPollingMode pollMode, uint32_t token = 0); + NResult PrePostReceiveOnSyncEp(UBSHcomNetEndpointPtr ep, uint16_t prePostCount, UBJetty *qp); + void ClientSyncEpSetInfo(UBSHcomNetEndpointPtr ep, UBJetty *qp, UBSHcomNetEndpointPtr &outEp); + NResult ServerEstablishCtrlConn(JettyConnHeader *exchangeInfo, UBPublicJetty *serverControlJetty); + NResult CreateSyncEp(UBJetty *qp, UBJfc *cq, uint64_t id, UBSHcomNetEndpointPtr &outEp, + UBPublicJetty *clientControlJetty); + +private: + friend class NetUBAsyncEndpoint; + friend class NetUBSyncEndpoint; + friend class UBJetty; + + NResult NewReceivedRequest(UBOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, UBSHcomNetMessage &msg, + UBWorker *worker) const; + + NResult NewReceivedRawRequest(UBOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, UBSHcomNetMessage &msg, + UBWorker *worker, uint32_t immData) const; + + NResult NewReceivedRequestWithoutCopy(UBOpContextInfo *ctx, UBSHcomNetRequestContext &netCtx, + UBSHcomNetMessage &msg, UBWorker *worker, void *dataAddress, UBSHcomNetTransHeader *header) const; + + void DestroyEpInWorker(UBWorker *worker); + void DestroyEpByPortNum(int portNum); + void HandleCqEvent(urma_async_event_t *event); + void HandlePortDown(int portNum); + void HandlePortActive(int portNum); + void HandleAsyncEvent(urma_async_event_t *event); + void RunInUBEventThread(); + void ClearJettyResource(UBJetty *qp); + inline bool ValidateRequestContext(UBOpContextInfo *ctx) + { + if (NN_UNLIKELY(ctx == nullptr || ctx->ubJetty == nullptr || ctx->ubJetty->GetUpContext() == 0 || + ctx->ubJetty->GetUpContext1() == 0)) { + NN_LOG_ERROR("Ctx or QP or Worker or ep is null of RequestReceived in Driver " << mName << ""); + return false; + } + return true; + } + bool mNeedStopEvent = false; + std::thread mUBEventThread; + std::atomic mEventStarted{ false }; + NetHeartbeat *mHeartBeat = nullptr; + UBPublicJetty *mPublicJetty = nullptr; + friend class NetHeartbeat; +}; +} +} + +#endif +#endif // HCOM_NET_UB_DRIVER_OOB_H diff --git a/src/transport/ub/net_ub_driver_oob_public_jetty.cpp b/src/transport/ub/net_ub_driver_oob_public_jetty.cpp new file mode 100644 index 0000000000000000000000000000000000000000..356196f8ebda97c07cef0a09462dcfcb82603de5 --- /dev/null +++ b/src/transport/ub/net_ub_driver_oob_public_jetty.cpp @@ -0,0 +1,1002 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifdef UB_BUILD_ENABLED +#include +#include + +#include "net_monotonic.h" +#include "net_oob_ssl.h" +#include "net_ub_endpoint.h" +#include "net_ub_driver_oob.h" +#include "net_oob_secure.h" +#include "ub_mr_fixed_buf.h" +#include "ub_urma_wrapper_public_jetty.h" +#include "ub_worker.h" + +namespace ock { +namespace hcom { + +NResult NetDriverUBWithOob::CreateUrmaListeners(UBPublicJetty *&publicJetty) +{ + if (mOobListenOptions.empty()) { + NN_LOG_ERROR("No listen info is set for oob type " << UBSHcomNetDriverOobTypeToString(mOptions.oobType) << + " in driver " << mName); + return NN_INVALID_PARAM; + } + for (auto &lOpt : mOobListenOptions) { + uint32_t jettyId = lOpt.port; + if (jettyId < PUBLIC_JETTY_NUM_MIN || jettyId > PUBLIC_JETTY_NUM_MAX) { + NN_LOG_ERROR("Invalid public jetty id " << jettyId << " should in 4~1023"); + return NN_ERROR; + } + if ((CreatePublicJetty(mPublicJetty, jettyId, true)) != NN_OK) { + NN_LOG_ERROR("Failed to create oob public jetty"); + return NN_ERROR; + } + mPublicJetty->IncreaseRef(); + auto twc = lOpt.targetWorkerCount == 0 ? UINT16_MAX : lOpt.targetWorkerCount; + NetWorkerLBPtr lb = new (std::nothrow) NetWorkerLB(mName, mOptions.lbPolicy, twc); + if (NN_UNLIKELY(lb == nullptr)) { + NN_LOG_ERROR("Failed to new oob load balancer in driver " << mName); + return NN_NEW_OBJECT_FAILED; + } + mPublicJetty->SetWorkerLb(lb.Get()); + + /* add worker groups to lb */ + if (NN_UNLIKELY(lb->AddWorkerGroups(mWorkerGroups) != NN_OK)) { + NN_LOG_ERROR("Failed to added worker groups into load balancer in driver " << mName); + return NN_ERROR; + } + } + + return NN_OK; +} + +NResult NetDriverUBWithOob::ConnectByPublicJetty(const std::string &oobIp, uint16_t oobPort, const std::string &payload, + UBSHcomNetEndpointPtr &outEp, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, uint64_t ctx) +{ + if (ClientCheckState(payload) != 0) { + NN_LOG_ERROR("Failed to connect as driver not start or payload oversize"); + return NN_ERROR; + } + if (oobPort < PUBLIC_JETTY_NUM_MIN || oobPort > PUBLIC_JETTY_NUM_MAX) { + NN_LOG_ERROR("Invalid public jetty id " << oobPort << " should in 4~1023"); + return NN_ERROR; + } + if ((flags & NET_EP_SELF_POLLING) || (flags & NET_EP_EVENT_POLLING)) { + return ConnectSyncEpByPublicJetty(oobIp, oobPort, payload, outEp, flags, serverGrpNo, clientGrpNo, ctx); + } + + return ConnectASyncEpByPublicJetty(oobIp, oobPort, payload, outEp, flags, serverGrpNo, clientGrpNo, ctx); +} + +NResult NetDriverUBWithOob::ConnectASyncEpByPublicJetty(const std::string &oobIp, uint16_t oobPort, + const std::string &payload, UBSHcomNetEndpointPtr &outEp, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, + uint64_t ctx) +{ + UBPublicJetty *clientPublicJetty = nullptr; + if (PublicJettyConnect(oobIp, oobPort, clientPublicJetty) != 0) { + NN_LOG_ERROR("Failed to connect to server public jetty"); + if (clientPublicJetty != nullptr) { + delete clientPublicJetty; + clientPublicJetty = nullptr; + } + return NN_ERROR; + } + NetLocalAutoDecreasePtr publicJettyAutoDecPtr(clientPublicJetty); + UBPublicJetty *clientControlJetty = nullptr; + if (CreatePublicJetty(clientControlJetty, 0, false) != NN_OK) { + NN_LOG_ERROR("Failed to create control jetty in client"); + return NN_ERROR; + } + NetLocalAutoDecreasePtr clientControlJettyAutoDecPtr(clientControlJetty); + if (clientControlJetty->StartPublicJetty() != NN_OK) { + NN_LOG_ERROR("Failed to start control jetty in client"); + return NN_ERROR; + } + // choose worker + auto id = NetUuid::GenerateUuid(); + UBWorker *worker = nullptr; + if (ClientSelectWorker(worker, clientGrpNo, id) != 0) { + NN_LOG_ERROR("Failed to select worker in connection"); + return NN_ERROR; + } + // create rc jetty + UBJetty *qp = nullptr; + uint32_t token = GenerateSecureRandomUint32(); + if (ClientCreateJetty(qp, worker, token) != 0) { + NN_LOG_ERROR("Failed to create jetty in client"); + if (qp != nullptr) { + delete qp; + qp = nullptr; + } + return NN_ERROR; + } + NetLocalAutoDecreasePtr qpAutoDecPtr(qp); + // fill exchange info + if (ClientSendConnReq(payload, id, serverGrpNo, clientPublicJetty, qp, clientControlJetty, token) != 0) { + NN_LOG_ERROR("Failed to send connect request to server"); + return NN_ERROR; + } + // recv exchange info from server + UBJettyExchangeInfo info{}; + if (ClientEstablishConnOnReply(clientControlJetty, qp, info) != 0) { + NN_LOG_ERROR("Failed to establish connection on ack in client"); + return NN_ERROR; + } + if (PrePostReceiveOnConnection(qp, worker) != 0) { + NN_LOG_ERROR("Failed to pre postrecv in public client connections"); + ClearJettyResource(qp); + return NN_ERROR; + } + /* Create endpoint */ + if (ClientCreateEp(outEp, id, qp, worker, info, clientControlJetty) != 0) { + NN_LOG_ERROR("Failed to create ep in public client connection"); + ClearJettyResource(qp); + return NN_ERROR; + } + + return NN_OK; +} + +NResult NetDriverUBWithOob::ClientCheckState(const std::string &payload) +{ + if (NN_UNLIKELY(!mInited.load())) { + NN_LOG_ERROR("Driver " << mName << " is not initialized"); + return NN_NOT_INITIALIZED; + } + + if (NN_UNLIKELY(!mStarted)) { + NN_LOG_ERROR("Failed to connect on driver " << mName << " as it is not started"); + return NN_ERROR; + } + + if (payload.size() >= NN_NO1024) { + NN_LOG_ERROR("Failed to connect to server via payload size " << payload.size() << + " over limit size " << NN_NO1024); + return NN_INVALID_PARAM; + } + return NN_OK; +} + +NResult NetDriverUBWithOob::CreatePublicJetty(UBPublicJetty *&publicJetty, uint32_t id, bool isServer) +{ + NResult result = NN_OK; + auto tmpJfc = new (std::nothrow) UBJfc(mName, mContext, false); + if (tmpJfc == nullptr) { + NN_LOG_ERROR("Failed to create jfc in public jetty"); + return NN_ERROR; + } + result = tmpJfc->Initialize(); + if (result != UB_OK) { + NN_LOG_ERROR("Jfc initialize failed in create public jetty " << result); + delete(tmpJfc); + return result; + } + publicJetty = new (std::nothrow) UBPublicJetty(mName, id, mContext, tmpJfc, isServer); + if (publicJetty == nullptr) { + NN_LOG_ERROR("Failed to create public jetty"); + delete(tmpJfc); + return NN_ERROR; + } + if ((publicJetty->InitializePublicJetty(id)) != NN_OK) { + NN_LOG_ERROR("Failed to initialize public jetty"); + delete(publicJetty); + publicJetty = nullptr; + return NN_ERROR; + } + return NN_OK; +} + +NResult NetDriverUBWithOob::PublicJettyConnect(const std::string &oobIp, uint16_t oobPort, + UBPublicJetty *&clientPublicJetty) +{ + urma_eid_t remoteEid{}; + if (CreatePublicJetty(clientPublicJetty, 0, false) != NN_OK) { + NN_LOG_ERROR("Failed to create public jetty in client"); + goto ERROR_FREE; + } + + if (clientPublicJetty->StartPublicJetty() != NN_OK) { + NN_LOG_ERROR("Failed to start public jetty in client"); + goto ERROR_FREE; + } + + if (HcomUrma::StrToEid(oobIp.c_str(), &remoteEid) != 0) { + NN_LOG_ERROR("Failed to convert to eid as eid illegal"); + goto ERROR_FREE; + } + + if (clientPublicJetty->ImportPublicJetty(remoteEid, oobPort) != 0) { + NN_LOG_ERROR("Failed to import remote public jetty in client"); + goto ERROR_FREE; + } + return NN_OK; + +ERROR_FREE: + if (clientPublicJetty != nullptr) { + delete clientPublicJetty; + clientPublicJetty = nullptr; + } + return NN_ERROR; +} + +NResult NetDriverUBWithOob::ClientSelectWorker(UBWorker *&worker, uint8_t clientGrpNo, uint64_t id) +{ + uint16_t workerIndex = 0; + if (NN_UNLIKELY(!mClientLb->ChooseWorker(clientGrpNo, std::to_string(id), workerIndex)) || + workerIndex >= mWorkers.size()) { + NN_LOG_ERROR("Failed to choose worker during connect in driver " << mName); + return NN_ERROR; + } + + NN_ASSERT_LOG_RETURN(workerIndex < mWorkers.size(), NN_ERROR) + worker = mWorkers[workerIndex]; + NN_ASSERT_LOG_RETURN(worker != nullptr, NN_ERROR) + + return NN_OK; +} + +NResult NetDriverUBWithOob::ClientSendConnReq(const std::string payload, uint64_t id, uint8_t serverGrpNo, + UBPublicJetty *clientPublicJetty, UBJetty *qp, UBPublicJetty *clientControlJetty, uint32_t token) +{ + if (NN_UNLIKELY(clientPublicJetty == nullptr)) { + NN_LOG_ERROR("Failed to send connection request as clientPublicJetty is nullptr"); + return UB_PARAM_INVALID; + } + uint32_t msgSize = sizeof(JettyConnHeader) - 1024 + payload.size() + 1; + JettyConnHeader exchangeInfo; + exchangeInfo.epId = id; + exchangeInfo.info.token = token; + if (FillExchMsg(&exchangeInfo, qp, payload, serverGrpNo, clientControlJetty) != 0) { + NN_LOG_ERROR("Failed to fill exchange message in client public jetty"); + return NN_ERROR; + } + NN_LOG_INFO("Client send exchangeInfo clientControlJettyId = " << exchangeInfo.controlJettyId << " jettyId = " + << exchangeInfo.info.jettyId.id); + // send to server + if (clientPublicJetty->SendByPublicJetty(&exchangeInfo, msgSize) != 0) { + NN_LOG_ERROR("Failed to send data to server public jetty"); + return NN_ERROR; + } + if (clientPublicJetty->PollingCompletion() != 0) { + NN_LOG_ERROR("Failed to poll completion in client public jetty"); + return NN_ERROR; + } + return NN_OK; +} + +NResult NetDriverUBWithOob::CheckServerACK(JettyConnResp &exchangeMsg) +{ + auto serverAcks = exchangeMsg.connResp; + switch (serverAcks) { + case MAGIC_MISMATCH: + NN_LOG_ERROR("Failed to pass server magic validation " << mName << ", result " << serverAcks); + return NN_CONNECT_REFUSED; + case WORKER_GRPNO_MISMATCH: + case WORKER_NOT_STARTED: + NN_LOG_ERROR("Failed to choose worker or not started " << mName << ", result " << serverAcks); + return NN_CONNECT_REFUSED; + case PROTOCOL_MISMATCH: + NN_LOG_ERROR("Failed to pass server protocol validation " << mName << ", result " << serverAcks); + return NN_CONNECT_PROTOCOL_MISMATCH; + case SERVER_INTERNAL_ERROR: + NN_LOG_ERROR("Server error happened, connection refused " << mName << ", result " << serverAcks); + return NN_ERROR; + case VERSION_MISMATCH: + NN_LOG_ERROR("Failed to pass server version validation " << mName << ", result " << serverAcks); + return NN_CONNECT_REFUSED; + case TLS_VERSION_MISMATCH: + NN_LOG_ERROR("Failed to pass server tls version validation " << mName << ", result " << serverAcks); + return NN_CONNECT_REFUSED; + case OK: + break; + default: + NN_LOG_ERROR("Server error happened, connection refused " << mName << ", result " << serverAcks); + return NN_ERROR; + } + return NN_OK; +} + +NResult NetDriverUBWithOob::ClientEstablishConnOnReply(UBPublicJetty *clientControlJetty, UBJetty *qp, + UBJettyExchangeInfo &info) +{ + if (NN_UNLIKELY(qp == nullptr || clientControlJetty == nullptr)) { + NN_LOG_ERROR("Failed to establish connection on reply as qp or clientControlJetty is nullptr"); + return UB_PARAM_INVALID; + } + JettyConnResp exchangeMsg{}; + if (clientControlJetty->Receive(&exchangeMsg, sizeof(JettyConnResp)) != 0) { + NN_LOG_ERROR("Failed to receive exchange message"); + return NN_ERROR; + } + NN_LOG_INFO("Client recv exchangeMsg serverControlJetty id = " << exchangeMsg.serverCtrlJettyId << " jettyId = " + << exchangeMsg.info.jettyId.id); + if (CheckServerACK(exchangeMsg) != 0) { + NN_LOG_ERROR("Failed to check server ack in client public jetty"); + return NN_ERROR; + } + if (clientControlJetty->ImportPublicJetty(exchangeMsg.serverCtrlEid, exchangeMsg.serverCtrlJettyId) != 0) { + NN_LOG_ERROR("Failed to import client jetty in public server"); + return NN_ERROR; + } + info = exchangeMsg.info; + std::unique_ptr peerExInfo(new (std::nothrow) UBJettyExchangeInfo(info)); + if (!peerExInfo) { + NN_LOG_ERROR("Failed to alloc UBJettyExchangeInfo in Driver " << mName); + return NN_MALLOC_FAILED; + } + qp->StoreExchangeInfo(peerExInfo.release()); + + // import and bind rc jetty + if (qp->ChangeToReady(info) != 0) { + NN_LOG_ERROR("Failed to change qp to ready in Driver " << mName); + int8_t clientAck = -1; + clientControlJetty->SendByPublicJetty(&clientAck, sizeof(int8_t)); + clientControlJetty->PollingCompletion(); + return NN_ERROR; + } + return NN_OK; +} + +NResult NetDriverUBWithOob::ClientCreateJetty(UBJetty *&qp, UBWorker *worker, uint32_t token) +{ + if (NN_UNLIKELY(worker == nullptr)) { + NN_LOG_ERROR("Failed to create jetty in client as worker is nullptr"); + return NN_PARAM_INVALID; + } + int result = 0; + if ((result = worker->CreateQP(qp)) != 0) { + NN_LOG_ERROR("Failed to create qp for new connection in Driver " << mName << " , result " << result); + goto ERROR_FREE; + } + qp->SetName(mName); + if ((result = qp->Initialize(mOptions.mrSendReceiveSegCount, 0, token)) != 0) { + NN_LOG_ERROR("Failed to initialize qp for new connection in Driver " << mName << " , result " << result); + goto ERROR_FREE; + } + return NN_OK; + +ERROR_FREE: + if (qp != nullptr) { + delete qp; + qp = nullptr; + } + return NN_ERROR; +} + +NResult NetDriverUBWithOob::ClientCreateEp(UBSHcomNetEndpointPtr &outEp, uint64_t id, UBJetty *qp, UBWorker *worker, + UBJettyExchangeInfo &info, UBPublicJetty *clientControlJetty) +{ + if (NN_UNLIKELY(qp == nullptr || worker == nullptr)) { + NN_LOG_ERROR("Failed to create ep in client as qp or worker is nullptr"); + return NN_PARAM_INVALID; + } + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetUBAsyncEndpoint(id, qp, this, worker); + if (ep.Get() == nullptr) { + NN_LOG_ERROR("Failed to create UBSHcomNetEndpoint in Driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + ep.ToChild()->SetRemoteHbInfo(info.hbAddress, info.hbKey, info.hbMrSize); + int8_t clientAck = 1; + NN_LOG_INFO("clientControlJetty send clientAck jetty id: " << clientControlJetty->GetJettyId()); + if (clientControlJetty->SendByPublicJetty(&clientAck, sizeof(int8_t)) != 0) { + NN_LOG_ERROR("Failed to send ready signal in public client jetty id: " << clientControlJetty->GetJettyId()); + return NN_ERROR; + } + if (clientControlJetty->PollingCompletion() != 0) { + NN_LOG_ERROR("Failed to poll completion in clientControlJetty jetty id: " << clientControlJetty->GetJettyId()); + return NN_ERROR; + } + int8_t serverAck = -1; + if (clientControlJetty->Receive(&serverAck, sizeof(int8_t)) != 0) { + NN_LOG_ERROR("Failed to receive serverAck signal from server jetty id: " << clientControlJetty->GetJettyId()); + return NN_ERROR; + } + if (serverAck != 1) { + NN_LOG_ERROR("Failed to check serverAck signal from server jetty id: " << clientControlJetty->GetJettyId()); + return NN_ERROR; + } + + // \see NetDriverUBWithOob::NewConnectionCB + qp->SetUpContext(reinterpret_cast(ep.Get())); + ep->State().Set(NEP_ESTABLISHED); + { + std::lock_guard locker(mEndPointsMutex); + mEndPoints.emplace(ep->Id(), ep); + } + NN_LOG_INFO("New connection established via public jetty, async ep id " << ep->Id() << ", jetty id: " + << qp->QpNum() << ", worker info " << worker->DetailName()); + outEp = ep; + reinterpret_cast(ep.Get())->GetQp()->SetUpId(ep->Id()); + return NN_OK; +} + + +NResult NetDriverUBWithOob::ServerEstablishCtrlConn(JettyConnHeader *exchangeInfo, UBPublicJetty *serverControlJetty) +{ + if (NN_UNLIKELY(exchangeInfo == nullptr || serverControlJetty == nullptr)) { + NN_LOG_ERROR("Failed to establish control connection as exchangeInfo or serverControlJetty is nullptr"); + return NN_PARAM_INVALID; + } + NN_LOG_INFO("Server recv exchangeInfo clientControlJettyId = " << exchangeInfo->controlJettyId << " jettyId = " + << exchangeInfo->info.jettyId.id); + urma_eid_t remoteEid = exchangeInfo->info.eid; + if (serverControlJetty->StartPublicJetty() != NN_OK) { + NN_LOG_ERROR("Failed to start public jetty in client"); + return NN_ERROR; + } + if (serverControlJetty->ImportPublicJetty(remoteEid, exchangeInfo->controlJettyId) != 0) { + NN_LOG_ERROR("Failed to import client jetty in public server"); + return NN_ERROR; + } + return NN_OK; +} + +NResult NetDriverUBWithOob::PublicJettyNewConnectionCB(UBOpContextInfo *ctx) +{ + auto exchangeInfo = reinterpret_cast(ctx->mrMemAddr); + NN_ASSERT_LOG_RETURN(exchangeInfo != nullptr, NN_ERROR) + // connect to client public jetty + UBPublicJetty *serverControlJetty = nullptr; + if (CreatePublicJetty(serverControlJetty, 0, false) != NN_OK) { + NN_LOG_ERROR("Failed to create public jetty in client"); + return NN_ERROR; + } + NetLocalAutoDecreasePtr serverControlJettyAutoDecPtr(serverControlJetty); + // check connect info + JettyConnResp exchangeMsg{}; + exchangeMsg.msgType = UrmaConnectMsgType::EXCHANGE_MSG; + exchangeMsg.connResp = ConnectResp::OK; + if (ServerEstablishCtrlConn(exchangeInfo, serverControlJetty) != NN_OK) { + NN_LOG_ERROR("Failed to establish control connection in server"); + return NN_ERROR; + } + if (CheckMagicAndProtocol(exchangeMsg, exchangeInfo, serverControlJetty) != 0) { + NN_LOG_ERROR("Failed to check magic number or protocol"); + return NN_ERROR; + } + // choose worker + UBWorker *worker = nullptr; + if (ServerSelectWorker(worker, exchangeMsg, exchangeInfo->ConnectHeader.groupIndex, serverControlJetty) != 0) { + NN_LOG_ERROR("Failed to select in public server"); + return NN_ERROR; + } + // Create RC Jetty + UBJetty *qp = nullptr; + uint32_t token = GenerateSecureRandomUint32(); + if (ServerCreateJetty(qp, worker, exchangeMsg, exchangeInfo, serverControlJetty, token) != 0) { + NN_LOG_ERROR("Failed to create jetty in new connection callback"); + if (qp != nullptr) { + delete qp; + qp = nullptr; + } + return NN_ERROR; + } + NetLocalAutoDecreasePtr qpAutoDecPtr(qp); + // send exchange info back to client + if (ServerReplyMsg(qp, exchangeMsg, serverControlJetty, token) != 0) { + NN_LOG_ERROR("Failed to reply message to client"); + return NN_ERROR; + } + if (PrePostReceiveOnConnection(qp, worker) != 0) { + NN_LOG_ERROR("Failed to pre postrecv in public server connection cb"); + ClearJettyResource(qp); + return NN_ERROR; + } + // Create endpoint + if (ServerCreateEp(exchangeInfo->info, qp, worker, exchangeInfo, serverControlJetty) != 0) { + NN_LOG_ERROR("Failed to create ep in public server connection cb"); + ClearJettyResource(qp); + return NN_ERROR; + } + + return NN_OK; +} + +void NetDriverUBWithOob::ClearJettyResource(UBJetty *qp) +{ + if (qp == nullptr) { + NN_LOG_WARN("Failed to clear jetty resource as jetty is nullptr"); + return; + } + + // 建链失败时 EP 会先于 jetty 析构,此种情况下需要保证在触发 FLUSH_ERR_DONE 时 jetty 无法索引到 已析构的 EP,清理工 + // 作全权由本函数 ClearJettyResource 负责。 + qp->SetUpContext(0); + qp->Stop(); + + UBOpContextInfo *it = nullptr; + UBOpContextInfo *next = nullptr; + qp->GetCtxPosted(it); + while (it != nullptr) { + next = it->next; + if (it->opType != UBOpContextInfo::OpType::RECEIVE) { + NN_LOG_ERROR("Failed to clear jetty resource as invalid type"); + } + ProcessErrorNewRequest(it); + + // 至此,it指向的内存可能会归还给 mempool,再修改it指向的内存可能会引起并发冲突 + it = next; + } + return; +} + +NResult NetDriverUBWithOob::ServerCreateEp(UBJettyExchangeInfo &info, UBJetty *qp, UBWorker *worker, + JettyConnHeader *exchangeInfo, UBPublicJetty *serverControlJetty) +{ + if (NN_UNLIKELY(qp == nullptr || worker == nullptr)) { + NN_LOG_ERROR("Failed to create ep in server as qp or worker is nullptr"); + return NN_PARAM_INVALID; + } + + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetUBAsyncEndpoint(exchangeInfo->epId, qp, this, worker); + if (ep.Get() == nullptr) { + NN_LOG_ERROR("Failed to create UBSHcomNetEndpoint in Driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + ep.ToChild()->SetRemoteHbInfo(info.hbAddress, info.hbKey, info.hbMrSize); + ep->mDevIndex = mDevIndex; + ep->mPeerDevIndex = mPeerDevIndex; + ep->mBandWidth = mBandWidth; + + std::string payload; + auto payloadLen = exchangeInfo->payloadLen; + if (payloadLen > 0) { + exchangeInfo->payload[payloadLen] = '\0'; + payload = std::string(exchangeInfo->payload, payloadLen); + } + struct in_addr ipAddr{}; + char ipStr[INET_ADDRSTRLEN]{}; + ipAddr.s_addr = exchangeInfo->info.eid.in4.addr; + if (inet_ntop(AF_INET, &ipAddr, ipStr, INET_ADDRSTRLEN) == NULL) { + NN_LOG_ERROR("Failed to convert ip num to string"); + return NN_ERROR; + } + auto listenJettyId = exchangeInfo->controlJettyId; + std::string eidAndPort = std::string(ipStr) + ":" + std::to_string(listenJettyId); + ep->StoreConnInfo(exchangeInfo->info.eid.in4.addr, listenJettyId, exchangeInfo->ConnectHeader.version, payload); + + // client server + // -----------client ack----------> + // NewEndpointHandler() + // <----------server ack----------- + // 客户端 EP创建完毕 + int8_t clientAck = -1; + if (serverControlJetty->Receive(&clientAck, sizeof(int8_t)) != 0) { + NN_LOG_ERROR("Failed to receive clientAck signal from client jetty id: " << serverControlJetty->GetJettyId()); + return NN_ERROR; + } + if (clientAck != 1) { + NN_LOG_ERROR("Failed to check clientAck signal from client jetty id: " << serverControlJetty->GetJettyId()); + return NN_ERROR; + } + + NResult result = NN_OK; + NN_LOG_INFO("ServerControlJetty send serverAck jetty id: " << serverControlJetty->GetJettyId()); + if (NN_UNLIKELY(mNewEndPointHandler(eidAndPort, ep, payload) != UB_OK)) { + NN_LOG_ERROR("Called new end point handler failed jetty id: " << serverControlJetty->GetJettyId()); + result = NN_ERROR; + } + + // \see NetDriverUBWithOob::NewConnectionCB + qp->SetUpContext(reinterpret_cast(ep.Get())); + qp->SetUpId(ep->Id()); + ep->State().Set(NEP_ESTABLISHED); + + // serverAck 同步信令发送后客户端可能会立即发包,会在UBWorker中被动触发事件。如果 jetty 无法通过 UpContext() 索引 + // 到ep, 此 ep上产生的事件无法被 UBWorker 转发至回调。因此发送 serverAck 同步信令必须位于 qp->SetUpContext(...)之 + // 后。 + // \see NetDriverUBWithOob::NewConnectionCB + int8_t serverAck = (result == NN_OK) ? 1 : 0; + if (serverControlJetty->SendByPublicJetty(&serverAck, sizeof(int8_t)) != 0) { + NN_LOG_ERROR("Failed to send serverAck signal in public server jetty id: " << serverControlJetty->GetJettyId()); + return NN_ERROR; + } + if (serverControlJetty->PollingCompletion() != 0) { + NN_LOG_ERROR("Failed to poll completion in client serverControlJetty jetty id: " << + serverControlJetty->GetJettyId()); + return NN_ERROR; + } + NN_LOG_INFO("serverControlJetty end ServerHandshake jetty id: " << serverControlJetty->GetJettyId()); + + // EP 被安全创建完毕 + { + std::lock_guard locker(mEndPointsMutex); + mEndPoints.emplace(ep->Id(), ep); + } + + NN_LOG_INFO("New connection build via public jetty, ep id " << ep->Id() << ", jetty id: " << qp->QpNum() + << ", worker info " << worker->DetailName()); + return NN_OK; +} + +NResult NetDriverUBWithOob::CheckMagicAndProtocol(JettyConnResp &exchangeMsg, JettyConnHeader *exchangeInfo, + UBPublicJetty *serverControlJetty) +{ + auto header = exchangeInfo->ConnectHeader; + if (header.magic != mOptions.magic) { + NN_LOG_ERROR("Failed to match magic number from client, connection refused header.magic = " << header.magic << + ", mOptions.magic = " << mOptions.magic); + exchangeMsg.connResp = MAGIC_MISMATCH; + serverControlJetty->SendByPublicJetty(&exchangeMsg, sizeof(JettyConnResp)); + return NN_ERROR; + } + if (header.protocol != Protocol()) { + NN_LOG_ERROR("Failed to match protocol " << Protocol() << " vs " << header.protocol << " connection refused"); + exchangeMsg.connResp = PROTOCOL_MISMATCH; + serverControlJetty->SendByPublicJetty(&exchangeMsg, sizeof(JettyConnResp)); + return NN_ERROR; + } + return NN_OK; +} + +NResult NetDriverUBWithOob::FillExchMsg(JettyConnHeader *exchangeInfo, UBJetty *qp, + const std::string &payload, uint8_t serverGrpNo, UBPublicJetty *clientControlJetty) +{ + int result = 0; + exchangeInfo->msgType = CONNECT_REQ; + exchangeInfo->controlJettyId = clientControlJetty->GetJettyId(); + exchangeInfo->SetConnHeader(mOptions.magic, mOptions.version, serverGrpNo, Protocol(), mMajorVersion, mMinorVersion, + mOptions.tlsVersion); + + exchangeInfo->info.maxSendWr = mOptions.qpSendQueueSize; + exchangeInfo->info.maxReceiveWr = mOptions.qpReceiveQueueSize; + exchangeInfo->info.receiveSegSize = mOptions.mrSendReceiveSegSize; + exchangeInfo->info.receiveSegCount = mOptions.prePostReceiveSizePerQP; + if (mHeartBeat != nullptr) { + if ((result = qp->CreateHBMemoryRegion(NN_NO128, qp->mHBLocalMr)) != NN_OK) { + NN_LOG_ERROR("Failed to create mr for local HB in client, result " << result); + return result; + } + if ((result = qp->CreateHBMemoryRegion(NN_NO128, qp->mHBRemoteMr)) != NN_OK) { + NN_LOG_ERROR("Failed to create mr for remote HB, result " << result); + qp->DestroyHBMemoryRegion(qp->mHBLocalMr); + return result; + } + qp->GetRemoteHbInfo(exchangeInfo->info); + exchangeInfo->info.isNeedSendHb = true; + } + if ((result = qp->FillExchangeInfo(exchangeInfo->info)) != 0) { + NN_LOG_ERROR("Failed to get or send ep exchange info in Driver " << mName << ", result " << result); + return result; + } + if (NN_UNLIKELY(memcpy_s(exchangeInfo->payload, NN_NO1024, payload.c_str(), payload.size()) + != NN_OK)) { + NN_LOG_ERROR("Failed to copy data"); + return NN_ERROR; + } + exchangeInfo->payloadLen = payload.size(); + exchangeInfo->payload[exchangeInfo->payloadLen] = '\0'; + return NN_OK; +} + +NResult NetDriverUBWithOob::PrePostReceiveOnConnection(UBJetty *qp, UBWorker *worker) +{ + int result = 0; + if (NN_UNLIKELY(qp == nullptr || worker == nullptr)) { + NN_LOG_ERROR("Failed to pre postrecv as qp is nullptr"); + return UB_PARAM_INVALID; + } + auto prePostCount = mOptions.prePostReceiveSizePerQP; + auto *mrSegs = new (std::nothrow) uintptr_t[prePostCount]; + if (mrSegs == nullptr) { + NN_LOG_ERROR("Failed to create mr address array in Driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + + NetLocalAutoFreePtr segAutoDelete(mrSegs, true); + + if (!qp->GetFreeBufferN(mrSegs, prePostCount)) { + NN_LOG_ERROR("Failed to get free mr from pool"); + return NN_ERROR; + } + + uint16_t i = 0; + for (; i < prePostCount; i++) { + if ((result = worker->PostReceive(qp, mrSegs[i], mOptions.mrSendReceiveSegSize, + reinterpret_cast(qp->GetMemorySeg()))) != 0) { + break; + } + } + + for (; i < prePostCount; i++) { + qp->ReturnBuffer(mrSegs[i]); + } + + return result; +} + +NResult NetDriverUBWithOob::ServerSelectWorker(UBWorker *&worker, JettyConnResp &exchangeMsg, + uint8_t groupIndex, UBPublicJetty *serverControlJetty) +{ + uint16_t workerIndex = 0; + NetWorkerLBPtr lb = mPublicJetty->LoadBalancer(); + NN_ASSERT_LOG_RETURN(lb.Get() != nullptr, NN_ERROR) + if (NN_UNLIKELY(!lb->ChooseWorker(groupIndex, mOobIp, workerIndex)) || + workerIndex >= mWorkers.size()) { + exchangeMsg.connResp = WORKER_GRPNO_MISMATCH; + serverControlJetty->SendByPublicJetty(&exchangeMsg, sizeof(JettyConnResp)); + return NN_ERROR; + } + worker = mWorkers[workerIndex]; + NN_ASSERT_LOG_RETURN(worker != nullptr, NN_ERROR) + if (!worker->IsWorkStarted()) { + NN_LOG_ERROR("Failed to connect worker group no " << groupIndex << " in " << mName); + exchangeMsg.connResp = WORKER_NOT_STARTED; + serverControlJetty->SendByPublicJetty(&exchangeMsg, sizeof(JettyConnResp)); + return NN_ERROR; + } + return NN_OK; +} + +NResult NetDriverUBWithOob::ServerCreateJetty(UBJetty *&qp, UBWorker *worker, JettyConnResp &exchangeMsg, + JettyConnHeader *exchangeInfo, UBPublicJetty *serverControlJetty, uint32_t token) +{ + int result = 0; + uint64_t epId = exchangeInfo->epId; + if ((result = worker->CreateQP(qp)) != 0) { + NN_LOG_ERROR("Failed to create qp for new connection in Driver " << mName << " , result " << result); + exchangeMsg.connResp = SERVER_INTERNAL_ERROR; + serverControlJetty->SendByPublicJetty(&exchangeMsg, sizeof(JettyConnResp)); + return NN_ERROR; + } + qp->SetName(mName); + if ((result = qp->Initialize(mOptions.mrSendReceiveSegCount, 0, token)) != 0) { + NN_LOG_ERROR("Failed to initialize qp for new connection in Driver " << mName << " , result " << result); + exchangeMsg.connResp = SERVER_INTERNAL_ERROR; + serverControlJetty->SendByPublicJetty(&exchangeMsg, sizeof(JettyConnResp)); + delete qp; + qp = nullptr; + return NN_ERROR; + } + UBJettyExchangeInfo info = exchangeInfo->info; + std::unique_ptr peerExInfo(new (std::nothrow) UBJettyExchangeInfo(info)); + if (!peerExInfo) { + NN_LOG_ERROR("Failed to alloc UBJettyExchangeInfo in Driver " << mName); + delete qp; + qp = nullptr; + return NN_MALLOC_FAILED; + } + qp->StoreExchangeInfo(peerExInfo.release()); + + if ((result = qp->ChangeToReady(info)) != 0) { + NN_LOG_ERROR("Failed to change qp to ready in Driver " << mName << ", result " << result); + exchangeMsg.connResp = SERVER_INTERNAL_ERROR; + serverControlJetty->SendByPublicJetty(&exchangeMsg, sizeof(JettyConnResp)); + delete qp; + qp = nullptr; + return result; + } + return NN_OK; +} + +NResult NetDriverUBWithOob::ServerReplyMsg(UBJetty *qp, JettyConnResp &exchangeMsg, UBPublicJetty *serverControlJetty, + uint32_t token) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Failed to reply message as qp is nullptr"); + return UB_PARAM_INVALID; + } + int result = 0; + exchangeMsg.info.maxSendWr = mOptions.qpSendQueueSize; + exchangeMsg.info.maxReceiveWr = mOptions.qpReceiveQueueSize; + exchangeMsg.info.receiveSegSize = mOptions.mrSendReceiveSegSize; + exchangeMsg.info.receiveSegCount = mOptions.prePostReceiveSizePerQP; + exchangeMsg.serverCtrlJettyId = serverControlJetty->GetJettyId(); + exchangeMsg.serverCtrlEid = serverControlJetty->GetEid(); + exchangeMsg.info.token = token; + if (mHeartBeat != nullptr) { + if ((result = qp->CreateHBMemoryRegion(NN_NO128, qp->mHBLocalMr)) != NN_OK) { + NN_LOG_ERROR("Failed to create mr for local HB in server, result " << result); + return result; + } + + if ((result = qp->CreateHBMemoryRegion(NN_NO128, qp->mHBRemoteMr)) != NN_OK) { + NN_LOG_ERROR("Failed to create mr for remote HB, result " << result); + qp->DestroyHBMemoryRegion(qp->mHBLocalMr); + return result; + } + + qp->GetRemoteHbInfo(exchangeMsg.info); + exchangeMsg.info.isNeedSendHb = true; + } + + if ((result = qp->FillExchangeInfo(exchangeMsg.info)) != 0) { + NN_LOG_ERROR("Failed to get or send ep exchange info in Driver " << mName << ", result " << result); + return result; + } + NN_LOG_INFO("Server send exchangeMsg serverControlJetty = " << exchangeMsg.serverCtrlJettyId << " jettyId = " + << exchangeMsg.info.jettyId.id); + if (serverControlJetty->SendByPublicJetty(&exchangeMsg, sizeof(JettyConnResp)) != 0) { + NN_LOG_ERROR("Failed to send data in public server"); + return NN_ERROR; + } + if (serverControlJetty->PollingCompletion() != 0) { + NN_LOG_ERROR("Failed to poll completion in server serverControlJetty"); + return NN_ERROR; + } + return NN_OK; +} + +NResult NetDriverUBWithOob::CreateSyncEp(UBJetty *qp, UBJfc *cq, uint64_t id, UBSHcomNetEndpointPtr &outEp, + UBPublicJetty *clientControlJetty) +{ + auto prePostCount = mOptions.prePostReceiveSizePerQP; + static UBSHcomNetWorkerIndex workerIndex; + workerIndex.driverIdx = mIndex; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetUBSyncEndpoint(id, qp, cq, prePostCount + NN_NO4, this, + workerIndex); + if (ep.Get() == nullptr) { + NN_LOG_ERROR("Failed to create UB sync endpoint in Driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + NN_LOG_INFO("Create sync ep success, ep id: " << ep->Id() << ", with jetty id: " << qp->QpNum()); + + if (reinterpret_cast(ep.Get())->mCtxPool.Initialize() != UB_OK) { + NN_LOG_ERROR("Fail to initialize mCtxPool"); + } + if (PrePostReceiveOnSyncEp(ep, prePostCount, qp) != 0) { + NN_LOG_ERROR("Failed to pre post recv in client sync ep"); + return NN_ERROR; + } + + int8_t clientAck = 1; + NN_LOG_INFO("clientControlJetty send clientAck jetty id: " << clientControlJetty->GetJettyId()); + if (clientControlJetty->SendByPublicJetty(&clientAck, sizeof(clientAck)) != 0) { + NN_LOG_ERROR("Failed to send ready signal in public client jetty id: " << clientControlJetty->GetJettyId()); + return NN_ERROR; + } + if (clientControlJetty->PollingCompletion() != 0) { + NN_LOG_ERROR("Failed to poll completion in clientControlJetty jetty id: " << clientControlJetty->GetJettyId()); + return NN_ERROR; + } + + int8_t serverAck = 1; + if (clientControlJetty->Receive(&serverAck, sizeof(serverAck)) != 0) { + NN_LOG_ERROR("Failed to receive serverAck signal from server jetty id: " << clientControlJetty->GetJettyId()); + return NN_ERROR; + } + if (serverAck != 1) { + NN_LOG_ERROR("Failed to check serverAck signal from server jetty id: " << clientControlJetty->GetJettyId()); + return NN_ERROR; + } + + ClientSyncEpSetInfo(ep, qp, outEp); + return NN_OK; +} + +NResult NetDriverUBWithOob::ConnectSyncEpByPublicJetty(const std::string &oobIp, uint16_t oobPort, + const std::string &payload, UBSHcomNetEndpointPtr &outEp, uint32_t flags, uint8_t serverGrpNo, uint8_t clientGrpNo, + uint64_t ctx) +{ + // create public jetty and connect + UBPublicJetty *clientPublicJetty = nullptr; + if (PublicJettyConnect(oobIp, oobPort, clientPublicJetty) != 0) { + NN_LOG_ERROR("Failed to connect to server public jetty"); + return NN_ERROR; + } + NetLocalAutoDecreasePtr publicJettyAutoDecPtr(clientPublicJetty); + UBPublicJetty *clientControlJetty = nullptr; + if (CreatePublicJetty(clientControlJetty, 0, false) != NN_OK) { + NN_LOG_ERROR("Failed to create public jetty in client"); + return NN_ERROR; + } + NetLocalAutoDecreasePtr clientControlJettyAutoDecPtr(clientControlJetty); + if (clientControlJetty->StartPublicJetty() != NN_OK) { + NN_LOG_ERROR("Failed to start public jetty in client"); + return NN_ERROR; + } + UBPollingMode pollMode = ((flags & NET_EP_EVENT_POLLING)) ? UB_EVENT_POLLING : UB_BUSY_POLLING; + UBJetty *qp = nullptr; + UBJfc *cq = nullptr; + uint32_t token = GenerateSecureRandomUint32(); + if (ClientSyncEpCreateJetty(qp, cq, pollMode, token) != 0) { + NN_LOG_ERROR("Failed to create jetty in client sycn ep"); + return NN_ERROR; + } + NetLocalAutoDecreasePtr qpAutoDecPtr(qp); + NetLocalAutoDecreasePtr cqAutoDecPtr(cq); + // fill exchange info + auto id = NetUuid::GenerateUuid(); + if (ClientSendConnReq(payload, id, serverGrpNo, clientPublicJetty, qp, clientControlJetty, token) != 0) { + NN_LOG_ERROR("Failed to send connect request to server"); + return NN_ERROR; + } + // recv exchange info from server + UBJettyExchangeInfo info{}; + if (ClientEstablishConnOnReply(clientControlJetty, qp, info) != 0) { + NN_LOG_ERROR("Failed to establish connection on ack in client"); + return NN_ERROR; + } + if (CreateSyncEp(qp, cq, id, outEp, clientControlJetty) != 0) { + NN_LOG_ERROR("Failed to create sync ep in client"); + return NN_ERROR; + } + + return NN_OK; +} + +void NetDriverUBWithOob::ClientSyncEpSetInfo(UBSHcomNetEndpointPtr ep, UBJetty *qp, UBSHcomNetEndpointPtr &outEp) +{ + auto id = ep->Id(); + + // SyncEp 不会在 UBWorker 中处理事件,为保持一致性与 AsyncEp 采用一样顺序。 + // \see NetDriverUBWithOob::NewConnectionCB + qp->SetUpContext(reinterpret_cast(ep.Get())); + ep->State().Set(NEP_ESTABLISHED); + { + std::lock_guard locker(mEndPointsMutex); + mEndPoints.emplace(id, ep); + } + outEp = ep; + reinterpret_cast(ep.Get())->mJetty->SetUpId(id); + NN_LOG_INFO("New connection established via public jetty, sync ep id " << id); +} + +NResult NetDriverUBWithOob::PrePostReceiveOnSyncEp(UBSHcomNetEndpointPtr ep, uint16_t prePostCount, UBJetty *qp) +{ + int result = 0; + auto *mrSegs = new (std::nothrow) uintptr_t[prePostCount]; + if (mrSegs == nullptr) { + NN_LOG_ERROR("Failed to create mr address array in driver " << mName << ", probably out of memory"); + return NN_NEW_OBJECT_FAILED; + } + NetLocalAutoFreePtr segAutoDelete(mrSegs, true); + if (!qp->GetFreeBufferN(mrSegs, prePostCount)) { + NN_LOG_ERROR("Failed to get free mr from pool, result " << result); + return NN_ERROR; + } + int i = 0; + for (; i < prePostCount; i++) { + if (result = reinterpret_cast(ep.Get())->PostReceive(mrSegs[i], + mOptions.mrSendReceiveSegSize, reinterpret_cast(qp->GetMemorySeg())) != 0) { + break; + } + } + for (; i < prePostCount; i++) { + qp->ReturnBuffer(mrSegs[i]); + } + return result; +} + +NResult NetDriverUBWithOob::ClientSyncEpCreateJetty(UBJetty *&qp, UBJfc *&cq, UBPollingMode pollMode, uint32_t token) +{ + int result = 0; + JettyOptions qpOptions(mOptions.qpSendQueueSize, mOptions.qpReceiveQueueSize, mOptions.mrSendReceiveSegSize, + mOptions.prePostReceiveSizePerQP, mOptions.slave, mOptions.ubcMode); + if ((result = NetUBSyncEndpoint::CreateResources(mName, mContext, pollMode, qpOptions, qp, cq)) != 0) { + NN_LOG_ERROR("Failed to create qp and cq, result " << result); + return result; + } + qp->SetName(mName); + if (cq->Initialize() != 0) { + NN_LOG_ERROR("Failed to initialize cq for new connection in Driver " << mName); + delete cq; + delete qp; + return NN_ERROR; + } + + if (qp->Initialize(mOptions.mrSendReceiveSegCount, 0, token) != 0) { + NN_LOG_ERROR("Failed to initialize qp for new connection in Driver " << mName); + cq->UnInitialize(); + delete cq; + delete qp; + return NN_ERROR; + } + return NN_OK; +} +} +} +#endif diff --git a/src/transport/ub/net_ub_endpoint.cpp b/src/transport/ub/net_ub_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..19b6c5101469152427a444061ffef28958272d0c --- /dev/null +++ b/src/transport/ub/net_ub_endpoint.cpp @@ -0,0 +1,1889 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifdef UB_BUILD_ENABLED + +#include "net_ub_endpoint.h" +#include "ub_worker.h" + +namespace ock { +namespace hcom { +#define STATE_VALIDATION(state, id, driver) \ + do { \ + if (NN_UNLIKELY(!(state).Compare(NEP_ESTABLISHED))) { \ + NN_LOG_ERROR("Endpoint " << (id) << " is not established, state is " << \ + UBSHcomNEPStateToString((state).Get())); \ + return NN_EP_NOT_ESTABLISHED; \ + } \ + \ + if (NN_UNLIKELY(!(driver)->IsStarted())) { \ + NN_LOG_ERROR("Failed to validate state as driver " << (driver) << " is not started"); \ + return NN_ERROR; \ + } \ + } while (0) + +#define LOCAL_REQUEST_VALIDATION(request) \ + do { \ + if (NN_UNLIKELY((request).lAddress == 0 || (request).size == 0)) { \ + NN_LOG_ERROR("Failed to validate request as source data is null or size is zero"); \ + return UB_PARAM_INVALID; \ + } \ + if (NN_UNLIKELY((request).upCtxSize > sizeof(UBOpContextInfo::upCtx))) { \ + NN_LOG_ERROR("Failed to validate request as up ctx size invalid " << (request).upCtxSize); \ + return UB_PARAM_INVALID; \ + } \ + } while (0) + +#define SIZE_VALIDATION(request, allowedSize) \ + do { \ + size_t compareSize = (request).size; \ + if (mIsNeedEncrypt) { \ + compareSize = mAes.EstimatedEncryptLen((request).size); \ + } \ + \ + if (NN_UNLIKELY(compareSize > (allowedSize))) { \ + NN_LOG_ERROR("Failed to post message as message size " << ((request).size) << \ + " is too large, use one side post"); \ + return NN_TWO_SIDE_MESSAGE_TOO_LARGE; \ + } \ + } while (0) + +#define POST_SEND_VALIDATION(state, id, driver, opCode, request, allowedSize) \ + do { \ + STATE_VALIDATION(state, id, driver); \ + LOCAL_REQUEST_VALIDATION(request); \ + SIZE_VALIDATION(request, allowedSize); \ + if (NN_UNLIKELY((opCode) >= MAX_OPCODE)) { \ + NN_LOG_ERROR("Failed to post message as opcode is invalid, which should with the range 0~" << \ + (MAX_OPCODE - 1)); \ + return NN_INVALID_OPCODE; \ + } \ + } while (0) + +#define POST_SEND_RAW_VALIDATION(state, id, driver, seqNo, request, allowedSize) \ + do { \ + STATE_VALIDATION(state, id, driver); \ + LOCAL_REQUEST_VALIDATION(request); \ + SIZE_VALIDATION(request, allowedSize); \ + if (NN_UNLIKELY((seqNo) == 0)) { \ + NN_LOG_ERROR("Failed to post raw message as seqNo must > 0"); \ + return UB_PARAM_INVALID; \ + } \ + } while (0) + +#define READ_WRITE_VALIDATION(state, id, driver, request) \ + do { \ + STATE_VALIDATION(state, id, driver); \ + LOCAL_REQUEST_VALIDATION(request); \ + if (NN_UNLIKELY((request).rAddress == 0)) { \ + NN_LOG_ERROR("Failed to validate request as remote data is null"); \ + return UB_PARAM_INVALID; \ + } \ + if (NN_OK != (driver)->ValidateMemoryRegion((request).lKey, (request).lAddress, (request).size)) { \ + NN_LOG_ERROR("Invalid MemoryRegion or local key"); \ + return NN_INVALID_LKEY; \ + } \ + } while (0) + +#define SGL_VALIDATION(request, totalSize) \ + do { \ + if (NN_UNLIKELY((request).iov == nullptr || (request).iovCount > NET_SGE_MAX_IOV || \ + (request).iovCount == 0)) { \ + NN_LOG_ERROR("Invalid iov ptr:" << (request).iov << " or iov cnt:" << (request).iovCount); \ + return UB_PARAM_INVALID; \ + } \ + if (NN_UNLIKELY((request).upCtxSize > sizeof(UBOpContextInfo::upCtx))) { \ + NN_LOG_ERROR("Failed to validate request as up ctx size invalid " << (request).upCtxSize); \ + return UB_PARAM_INVALID; \ + } \ + for (int i = 0; i < (request).iovCount; ++i) { \ + if (NN_OK != mDriver->ValidateMemoryRegion((request).iov[i].lKey, (request).iov[i].lAddress, \ + (request).iov[i].size)) { \ + NN_LOG_ERROR("Invalid MemoryRegion or lKey in iov in async PostWrite"); \ + return NN_INVALID_LKEY; \ + } \ + (totalSize) += (request).iov[i].size; \ + } \ + } while (0) + +#define READ_WRITE_SGL_VALIDATION(state, id, driver, request) \ + do { \ + STATE_VALIDATION(state, id, driver); \ + size_t tmpTotalSize = 0; \ + SGL_VALIDATION(request, tmpTotalSize); \ + for (int i = 0; i < (request).iovCount; ++i) { \ + if (NN_UNLIKELY((request).iov[i].rAddress == NN_NO0)) { \ + NN_LOG_ERROR("Failed to validate request as remote data is null, index " << i); \ + return UB_PARAM_INVALID; \ + } \ + } \ + } while (0) + +#define POST_SEND_SGL_VALIDATION(state, id, driver, seqNo, request, allowedSize, totalSize) \ + do { \ + STATE_VALIDATION(state, id, driver); \ + if (NN_UNLIKELY((seqNo) == 0)) { \ + NN_LOG_ERROR("Failed to post raw message as seqNo must > 0"); \ + return UB_PARAM_INVALID; \ + } \ + \ + SGL_VALIDATION(request, (totalSize)); \ + size_t compareSize = (totalSize); \ + if (mIsNeedEncrypt) { \ + compareSize = mAes.EstimatedEncryptLen((totalSize)); \ + } \ + \ + if (NN_UNLIKELY(compareSize > (allowedSize))) { \ + NN_LOG_ERROR("Failed to post send raw sgl as message size " << compareSize << \ + " is too large, use one side post"); \ + return NN_TWO_SIDE_MESSAGE_TOO_LARGE; \ + } \ + } while (0) + +#define ENCRYPT_RAW_SGL(tlsReq, mrBufAddress, size, mAes, mDriver) \ + do { \ + uintptr_t tmpBuff = 0; \ + if (NN_UNLIKELY(!(mDriver)->mDriverSendMR->GetFreeBuffer(tmpBuff))) { \ + NN_LOG_ERROR("Failed to post message as failed to get tmp mr buffer from pool from driver " << \ + (mDriver)->Name()); \ + return NN_GET_BUFF_FAILED; \ + } \ + \ + uint32_t iovOffset = 0; \ + for (int i = 0; i < request.iovCount; i++) { \ + (void)memcpy_s(reinterpret_cast(tmpBuff + iovOffset), request.iov[i].size, \ + reinterpret_cast(request.iov[i].lAddress), request.iov[i].size); \ + iovOffset += request.iov[i].size; \ + } \ + \ + if (NN_UNLIKELY(!(mDriver)->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { \ + NN_LOG_ERROR("Failed to post message as failed to get mr buffer from pool from driver " << \ + (mDriver)->Name()); \ + (void)(mDriver)->mDriverSendMR->ReturnBuffer(tmpBuff); \ + return NN_GET_BUFF_FAILED; \ + } \ + \ + uint32_t cipherLen = 0; \ + if (!(mAes).Encrypt(mSecrets, reinterpret_cast(tmpBuff), size, \ + reinterpret_cast(mrBufAddress), cipherLen)) { \ + NN_LOG_ERROR("Failed to post send message as encryption failure"); \ + (void)(mDriver)->mDriverSendMR->ReturnBuffer(tmpBuff); \ + (void)(mDriver)->mDriverSendMR->ReturnBuffer(mrBufAddress); \ + return NN_ENCRYPT_FAILED; \ + } \ + \ + (tlsReq).lAddress = mrBufAddress; \ + (tlsReq).lKey = (mDriver)->mDriverSendMR->GetLKey(); \ + (tlsReq).srcSeg = (mDriver)->mDriverSendMR->GetMemorySeg(); \ + (tlsReq).size = cipherLen; \ + (size) = cipherLen; \ + \ + (void)(mDriver)->mDriverSendMR->ReturnBuffer(tmpBuff); \ + } while (0) + +static inline GetSglTseg(NetDriverUBWithOob *driver, UBSHcomNetTransSglRequest &sglReq) +{ + for (uint16_t i = 0; i < sglReq.iovCount; i++) { + urma_target_seg_t *tseg = nullptr; + if (driver->GetTseg(sglReq.iov[i].lKey, tseg) != NN_OK) { + NN_LOG_ERROR("Failed to post read request, as get tseg failed"); + return UB_PARAM_INVALID; + } + sglReq.iov[i].srcSeg = static_cast(tseg); + } + return NN_OK; +} + +NetUBAsyncEndpoint::NetUBAsyncEndpoint(uint64_t id, UBJetty *qp, NetDriverUBWithOob *driver, UBWorker *worker) + : NetEndpointImpl(id, worker != nullptr ? worker->Index() : UBSHcomNetWorkerIndex{}), + mJetty(qp), mWorker(worker), mDriver(driver) +{ + if (mDriver != nullptr) { + mDriver->IncreaseRef(); + } + + if (mWorker != nullptr) { + mWorker->IncreaseRef(); + } + + if (mJetty != nullptr) { + mJetty->IncreaseRef(); + mIsNeedSendHb = mJetty->GetExchangeInfo().isNeedSendHb; + } + + if (mJetty != nullptr && mDriver != nullptr) { + mSegSize = mDriver->mOptions.mrSendReceiveSegSize < mJetty->GetPostSendMaxSize() ? + mDriver->mOptions.mrSendReceiveSegSize : + mJetty->GetPostSendMaxSize(); + mAllowedSize = mSegSize - sizeof(UBSHcomNetTransHeader); + mDmSize = mDriver->mOptions.dmSegSize; + } + + if (mIsNeedSendHb && mDriver != nullptr) { + mHeartBeatIdleTime = mDriver->GetHbIdleTime(); + UpdateTargetHbTime(); + } + + OBJ_GC_INCREASE(NetUBAsyncEndpoint); +} + +NetUBAsyncEndpoint::~NetUBAsyncEndpoint() +{ + // jetty 析构时要求 worker、driver都存活 + if (mJetty != nullptr) { + // 当 EP 析构时,说明它不再被用户使用、已经从全局 EP 表中被删除、上层的 channel 也被 DelayEraseChannel 真正删除。 + // 如果存在 UBJetty 的 PostedCount > 0,说明存在过在 FLUSH_ERR_DONE 之后用户绕过了 EP 和 jetty 的状态检查进行 + // post 的情况。 + // + // \see NetDriverUBWithOob::ProcessEpError + if (mJetty->GetPostedCount() > 0) { + NN_LOG_WARN("There are OPs posted though jetty is in error state, flushing..."); + mJetty->Flush(); + } + + mJetty->DecreaseRef(); + mJetty = nullptr; + } + + // worker 会使用 driver 层注册的函数 + if (mWorker != nullptr) { + mWorker->DecreaseRef(); + mWorker = nullptr; + } + + if (mDriver != nullptr) { + mDriver->DecreaseRef(); + mDriver = nullptr; + } + OBJ_GC_DECREASE(NetUBAsyncEndpoint); +} + +NResult NetUBAsyncEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNO) +{ + POST_SEND_VALIDATION(mState, mId, mDriver, opCode, request, mAllowedSize); + // get mr from pool + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to async post send with seq no as failed to get mr buffer from pool"); + return NN_GET_BUFF_FAILED; + } + + auto *header = reinterpret_cast(mrBufAddress); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->opCode = opCode; + header->seqNo = seqNO == 0 ? NextSeq() : seqNO; + header->flags = NTH_TWO_SIDE; + + if (mIsNeedEncrypt) { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, + (void *)request.lAddress, request.size, reinterpret_cast(mrBufAddress + + sizeof(UBSHcomNetTransHeader)), cipherLen)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to async post send with seq no as encryption failure"); + return NN_ENCRYPT_FAILED; + } + header->dataLength = cipherLen; + } else { + header->dataLength = request.size; + (void)memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), request.size, + reinterpret_cast(request.lAddress), request.size); + } + + /* finally fill header crc */ + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + + // change lAddress to mrAddress and set lKey + auto worker = reinterpret_cast(mJetty->GetUpContext1()); + + UBSHcomNetTransRequest ubReq = request; + ubReq.lAddress = mrBufAddress; + ubReq.lKey = mDriver->mDriverSendMR->GetLKey(); + ubReq.size = sizeof(UBSHcomNetTransHeader) + header->dataLength; + + auto sendFlag = true; + uint64_t finishTime = GetFinishTime(); + NResult result = NN_OK; + TRACE_DELAY_BEGIN(UB_EP_ASYNC_POST_SEND); + do { + result = worker->PostSend(mJetty, ubReq, + reinterpret_cast(mDriver->mDriverSendMR->GetMemorySeg())); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_ASYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + sendFlag = false; + } while (sendFlag); + + NN_LOG_ERROR("Failed to async post send with seq no, result " << result); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(UB_EP_ASYNC_POST_SEND, result); + return result; +} + +NResult NetUBAsyncEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) +{ + POST_SEND_VALIDATION(mState, mId, mDriver, opCode, request, mAllowedSize); + // get mr from pool + NResult res = NN_OK; + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to async post send with op info as failed to get mr buffer from pool"); + return NN_GET_BUFF_FAILED; + } + + auto *header = reinterpret_cast(mrBufAddress); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + header->flags = ((uint16_t)opInfo.flags << NN_NO8) | (uint64_t)NTH_TWO_SIDE; + header->opCode = opCode; + header->timeout = opInfo.timeout; + header->errorCode = opInfo.errorCode; + + if (mIsNeedEncrypt) { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, + (void *)request.lAddress, request.size, reinterpret_cast(mrBufAddress + + sizeof(UBSHcomNetTransHeader)), cipherLen)) { + NN_LOG_ERROR("Failed to async post send with op info as encryption failure"); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + return NN_ENCRYPT_FAILED; + } + header->dataLength = cipherLen; + } else { + header->dataLength = request.size; + (void)memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), request.size, + reinterpret_cast(request.lAddress), request.size); + } + /* finally fill header crc */ + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + + // change lAddress to mrAddress and set lKey + UBSHcomNetTransRequest ubReq = request; + ubReq.lAddress = mrBufAddress; + ubReq.lKey = mDriver->mDriverSendMR->GetLKey(); + ubReq.size = sizeof(UBSHcomNetTransHeader) + header->dataLength; + auto worker = reinterpret_cast(mJetty->GetUpContext1()); + + auto sendOpFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_ASYNC_POST_SEND); + do { + res = worker->PostSend(mJetty, ubReq, + reinterpret_cast(mDriver->mDriverSendMR->GetMemorySeg())); + if (res == UB_OK) { + TRACE_DELAY_END(UB_EP_ASYNC_POST_SEND, res); + return NN_OK; + } else if (NeedRetry(res) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry res or timeout = 0 + sendOpFlag = false; + } while (sendOpFlag); + + NN_LOG_ERROR("Failed to async post send with op info, result " << res); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(UB_EP_ASYNC_POST_SEND, res); + return res; +} + +NResult NetUBAsyncEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo, const UBSHcomExtHeaderType extHeaderType, const void *extHeader, + uint32_t extHeaderSize) +{ + if (NN_UNLIKELY(extHeaderType == UBSHcomExtHeaderType::RAW)) { + NN_LOG_ERROR("Should not use RAW type when extHeader is given."); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(!extHeader)) { + NN_LOG_ERROR("The extHeader is invalid."); + return NN_INVALID_PARAM; + } + + // 保证 extHeaderSize + request.size <= mAllowedSize. + POST_SEND_VALIDATION(mState, mId, mDriver, opCode, request, mAllowedSize - extHeaderSize); + + // get mr from pool + NResult result = NN_OK; + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to async post send with opInfo as failed to get mr buffer from pool"); + return NN_GET_BUFF_FAILED; + } + + auto *header = reinterpret_cast(mrBufAddress); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->opCode = opCode; + header->timeout = opInfo.timeout; + header->seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + header->flags = ((uint16_t)opInfo.flags << NN_NO8) | (uint64_t)NTH_TWO_SIDE; + header->errorCode = opInfo.errorCode; + header->dataLength = request.size + extHeaderSize; + header->extHeaderType = extHeaderType; + + if (mIsNeedEncrypt) { + NN_LOG_WARN("postsent encrypt is not supported now!"); + } + + // 拷贝上层指定的 header,此时将要发送的结构为 + // | UBSHcomNetTransHeader | extHeader | request body | + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), + mDriver->mDriverSendMR->GetSingleSegSize() - sizeof(UBSHcomNetTransHeader), extHeader, + extHeaderSize) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to mrBufAddress"); + return NN_INVALID_PARAM; + } + + // 拷贝消息主体 + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader) + extHeaderSize), + mDriver->mDriverSendMR->GetSingleSegSize() - sizeof(UBSHcomNetTransHeader) - extHeaderSize, + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to mrBufAddress in async ep"); + return NN_INVALID_PARAM; + } + + // 头部全部写入完毕后才生成 crc32 + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + + // lAddress -> mrAddress + UBSHcomNetTransRequest ubReq = request; + ubReq.lAddress = mrBufAddress; + ubReq.lKey = mDriver->mDriverSendMR->GetLKey(); + ubReq.size = sizeof(UBSHcomNetTransHeader) + header->dataLength; + auto worker = reinterpret_cast(mJetty->GetUpContext1()); + + auto sendOpFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_ASYNC_POST_SEND); + do { + result = worker->PostSend(mJetty, ubReq, + reinterpret_cast(mDriver->mDriverSendMR->GetMemorySeg())); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_ASYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + sendOpFlag = false; + } while (sendOpFlag); + + NN_LOG_ERROR("Failed to async post send with op info, result " << result); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(UB_EP_ASYNC_POST_SEND, result); + return result; +} + +NResult NetUBAsyncEndpoint::PostSendSglInline( + uint16_t opCode, const UBSHcomNetTransRequest &request, const UBSHcomNetTransOpInfo &opInfo) +{ + // 仅支持UBC,同时需要加密必定会涉及到内存拷贝,仍然走非inline方式 + if (mIsNeedEncrypt || mJetty->GetProtocol() != UBSHcomNetDriverProtocol::UBC) { + return PostSend(opCode, request, opInfo); + } + + POST_SEND_VALIDATION(mState, mId, mDriver, opCode, request, mAllowedSize); + + NResult result = NN_OK; + UBSHcomNetTransHeader header; + header.opCode = opCode; + header.seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + header.flags = ((uint16_t)opInfo.flags << NN_NO8) | (uint64_t)NTH_TWO_SIDE; + header.timeout = opInfo.timeout; + header.errorCode = opInfo.errorCode; + header.dataLength = request.size; + header.headerCrc = NetFunc::CalcHeaderCrc32(header); + + auto worker = reinterpret_cast(mJetty->GetUpContext1()); + bool sendOpFlag = true; + uint64_t finishTime = GetFinishTime(); + do { + result = worker->PostSendSglInline(mJetty, header, request); + if (result == UB_OK) { + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + sendOpFlag = false; + } while (sendOpFlag); + return result; +} + +NResult NetUBAsyncEndpoint::PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNo) +{ + POST_SEND_RAW_VALIDATION(mState, mId, mDriver, seqNo, request, mSegSize); + + /* get mr from pool */ + NResult result = UB_OK; + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to post message as failed to get mr buffer from pool from driver " << mDriver->Name()); + return NN_GET_BUFF_FAILED; + } + + size_t msgSize = 0; + if (!mIsNeedEncrypt) { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress), request.size, + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to send mr"); + return UB_PARAM_INVALID; + } + msgSize = request.size; + } else { + uint32_t cipherLen = 0; + result = mAes.Encrypt(mSecrets, + (void *)request.lAddress, request.size, reinterpret_cast(mrBufAddress), cipherLen); + if (!result) { + NN_LOG_ERROR("Failed to send raw message as encryption failure"); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + return NN_ENCRYPT_FAILED; + } + msgSize = cipherLen; + } + + UBSHcomNetTransRequest ubReq = request; + ubReq.lAddress = mrBufAddress; + ubReq.lKey = mDriver->mDriverSendMR->GetLKey(); + ubReq.size = msgSize; + + auto worker = reinterpret_cast(mJetty->GetUpContext1()); + auto sendRawAsyncFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_ASYNC_POST_SEND_RAW); + do { + result = worker->PostSend(mJetty, ubReq, + reinterpret_cast(mDriver->mDriverSendMR->GetMemorySeg()), seqNo); + if (NN_LIKELY(result == UB_OK)) { + TRACE_DELAY_END(UB_EP_ASYNC_POST_SEND_RAW, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(NN_NO128); + continue; + } + // no retry result or timeout = 0 + sendRawAsyncFlag = false; + } while (sendRawAsyncFlag); + + NN_LOG_ERROR("Failed to post send raw request, result " << result); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(UB_EP_ASYNC_POST_SEND_RAW, result); + return result; +} + +NResult NetUBAsyncEndpoint::PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) +{ + size_t size = 0; + POST_SEND_SGL_VALIDATION(mState, mId, mDriver, seqNo, request, mSegSize, size); + UBSHcomNetTransSglRequest sglReq = request; + if (GetSglTseg(mDriver, sglReq) != NN_OK) { + NN_LOG_ERROR("GetSglTseg failed"); + return UB_PARAM_INVALID; + } + + UBSHcomNetTransRequest tlsReq {}; // used in encryption, to do... + uintptr_t mrBufAddress = 0; + if (mIsNeedEncrypt) { + ENCRYPT_RAW_SGL(tlsReq, mrBufAddress, size, mAes, mDriver); + } + + auto worker = reinterpret_cast(mJetty->GetUpContext1()); + NResult result = NN_OK; + auto flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_ASYNC_POST_SEND_RAW_SGL); + do { + result = worker->PostSendSgl(mJetty, request, tlsReq, seqNo, mIsNeedEncrypt); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_ASYNC_POST_SEND_RAW_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep眠 + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + if (mIsNeedEncrypt) { + (void)mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + } + + NN_LOG_ERROR("Failed to post send raw sgl request, result: " << result); + TRACE_DELAY_END(UB_EP_ASYNC_POST_SEND_RAW_SGL, result); + return result; +} + +NResult NetUBAsyncEndpoint::PostRead(const UBSHcomNetTransRequest &request) +{ + READ_WRITE_VALIDATION(mState, mId, mDriver, request); + UBSHcomNetTransRequest reqInner = request; + urma_target_seg_t *tseg = nullptr; + if (mDriver->GetTseg(request.lKey, tseg) != NN_OK) { + NN_LOG_ERROR("Failed to post read request, as get tseg failed."); + return UB_PARAM_INVALID; + } + reqInner.srcSeg = static_cast(tseg); + auto worker = reinterpret_cast(mJetty->GetUpContext1()); + NResult result = NN_OK; + auto asyncReadFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_ASYNC_POST_READ); + do { + result = worker->PostRead(mJetty, reqInner); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_ASYNC_POST_READ, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + asyncReadFlag = false; + } while (asyncReadFlag); + + NN_LOG_ERROR("Failed to post read request, result " << result); + TRACE_DELAY_END(UB_EP_ASYNC_POST_READ, result); + return result; +} + +NResult NetUBAsyncEndpoint::PostRead(const UBSHcomNetTransSglRequest &request) +{ + READ_WRITE_SGL_VALIDATION(mState, mId, mDriver, request); + + UBSHcomNetTransSglRequest sglReq = request; + if (GetSglTseg(mDriver, sglReq) != NN_OK) { + NN_LOG_ERROR("Failed to get sgl tseg"); + return UB_PARAM_INVALID; + } + + auto worker = reinterpret_cast(mJetty->GetUpContext1()); + NResult result = UB_OK; + auto flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_ASYNC_POST_READ_SGL); + do { + result = worker->PostOneSideSgl(mJetty, sglReq); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_ASYNC_POST_READ_SGL, result); + return UB_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post read sgl request, result " << result); + TRACE_DELAY_END(UB_EP_ASYNC_POST_READ_SGL, result); + return result; +} + +NResult NetUBAsyncEndpoint::PostWrite(const UBSHcomNetTransRequest &request) +{ + READ_WRITE_VALIDATION(mState, mId, mDriver, request); + UBSHcomNetTransRequest reqInner = request; + urma_target_seg_t *tseg = nullptr; + if (mDriver->GetTseg(request.lKey, tseg) != NN_OK) { + NN_LOG_ERROR("Failed to post read request, as get tseg failed"); + return UB_PARAM_INVALID; + } + reqInner.srcSeg = static_cast(tseg); + auto worker = reinterpret_cast(mJetty->GetUpContext1()); + + NResult result = NN_OK; + auto asyncWriteFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_ASYNC_POST_WRITE); + do { + result = worker->PostWrite(mJetty, reqInner); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_ASYNC_POST_WRITE, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + asyncWriteFlag = false; + } while (asyncWriteFlag); + + NN_LOG_ERROR("Failed to post write request, result " << result); + TRACE_DELAY_END(UB_EP_ASYNC_POST_WRITE, result); + return result; +} + +NResult NetUBAsyncEndpoint::PostWrite(const UBSHcomNetTransSglRequest &request) +{ + READ_WRITE_SGL_VALIDATION(mState, mId, mDriver, request); + + UBSHcomNetTransSglRequest sglReq = request; + if (GetSglTseg(mDriver, sglReq) != NN_OK) { + NN_LOG_ERROR("GetSglTseg failed"); + return UB_PARAM_INVALID; + } + + auto worker = reinterpret_cast(mJetty->GetUpContext1()); + NResult result = UB_OK; + auto flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_ASYNC_POST_WRITE_SGL); + do { + result = worker->PostOneSideSgl(mJetty, sglReq, false); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_ASYNC_POST_WRITE_SGL, result); + return UB_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post write sgl request, result " << result); + TRACE_DELAY_END(UB_EP_ASYNC_POST_WRITE_SGL, result); + return result; +} + +void NetUBAsyncEndpoint::UpdateTargetHbTime() +{ + mTargetHbTime = NetMonotonic::TimeSec() + mHeartBeatIdleTime; +} + +NetUBSyncEndpoint::NetUBSyncEndpoint(uint64_t id, UBJetty *qp, UBJfc *cq, uint32_t ubOpCtxPoolSize, + NetDriverUBWithOob *driver, const UBSHcomNetWorkerIndex &workerIndex) + : NetEndpointImpl(id, workerIndex), mJetty(qp), mJfc(cq), mCtxPool("ctxPool", ubOpCtxPoolSize), mDriver(driver) +{ + if (mJetty != nullptr) { + mJetty->IncreaseRef(); + } + + if (mJfc != nullptr) { + mJfc->IncreaseRef(); + } + + if (mDriver != nullptr) { + mDriver->IncreaseRef(); + } + + if (mJetty != nullptr && mDriver != nullptr) { + mSegSize = mDriver->mOptions.mrSendReceiveSegSize < mJetty->GetPostSendMaxSize() ? + mDriver->mOptions.mrSendReceiveSegSize : + mJetty->GetPostSendMaxSize(); + mAllowedSize = mSegSize - sizeof(UBSHcomNetTransHeader); + mDmSize = mDriver->mOptions.dmSegSize; + } + + /* set worker index and group index to 0xFFFF */ + mWorkerIndex.idxInGrp = INVALID_WORKER_INDEX; + mWorkerIndex.grpIdx = INVALID_WORKER_GROUP_INDEX; + + OBJ_GC_INCREASE(NetUBSyncEndpoint); +} + +NetUBSyncEndpoint::~NetUBSyncEndpoint() +{ + if (mJetty != nullptr) { + mJetty->DecreaseRef(); + mJetty = nullptr; + } + + if (mDriver != nullptr) { + mDriver->DecreaseRef(); + mDriver = nullptr; + } + + OBJ_GC_DECREASE(NetUBSyncEndpoint); +} + +NResult NetUBSyncEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNO) +{ + POST_SEND_VALIDATION(mState, mId, mDriver, opCode, request, mAllowedSize); + + // get mr from pool + NResult result = NN_OK; + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to sync post send with seq no as failed to get mr buffer from pool"); + return NN_GET_BUFF_FAILED; + } + + // copy message + auto *header = reinterpret_cast(mrBufAddress); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->seqNo = seqNO == 0 ? NextSeq() : seqNO; + header->opCode = opCode; + header->flags = NTH_TWO_SIDE; + header->dataLength = request.size; + + mLastSendSeqNo = header->seqNo; + if (mIsNeedEncrypt) { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, + (void *)request.lAddress, request.size, reinterpret_cast(mrBufAddress + + sizeof(UBSHcomNetTransHeader)), cipherLen)) { + NN_LOG_ERROR("Failed to sync post send with seq no as encryption failure"); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + return NN_ENCRYPT_FAILED; + } + header->dataLength = cipherLen; + } else { + // copy message + header->dataLength = request.size; + + (void)memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), request.size, + reinterpret_cast(request.lAddress), request.size); + } + + /* finally fill header crc */ + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + mDemandPollingOpType = UBOpContextInfo::SEND; + + // post request + // change lAddress to mrAddress and set lKey + UBSHcomNetTransRequest ubReq = request; + ubReq.lAddress = mrBufAddress; + ubReq.lKey = mDriver->mDriverSendMR->GetLKey(); + ubReq.size = sizeof(UBSHcomNetTransHeader) + header->dataLength; + + auto syncSendFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_SYNC_POST_SEND); + do { + result = InnerPostSend(ubReq, reinterpret_cast(mDriver->mDriverSendMR->GetMemorySeg())); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_SYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + syncSendFlag = false; + } while (syncSendFlag); + + NN_LOG_ERROR("Failed to sync post send with seq no, result " << result); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(UB_EP_SYNC_POST_SEND, result); + return result; +} + +NResult NetUBSyncEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) +{ + POST_SEND_VALIDATION(mState, mId, mDriver, opCode, request, mAllowedSize); + + // get mr from pool + NResult result = NN_OK; + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to sync post send with opInfo as failed to get mr buffer from pool"); + return NN_GET_BUFF_FAILED; + } + + // copy message + auto *header = reinterpret_cast(mrBufAddress); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->opCode = opCode; + header->seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + header->flags = ((uint16_t)opInfo.flags << NN_NO8) | (uint16_t)NTH_TWO_SIDE; + header->timeout = opInfo.timeout; + header->dataLength = request.size; + header->errorCode = opInfo.errorCode; + + mLastSendSeqNo = header->seqNo; + if (mIsNeedEncrypt) { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, + (void *)request.lAddress, request.size, reinterpret_cast(mrBufAddress + + sizeof(UBSHcomNetTransHeader)), cipherLen)) { + NN_LOG_ERROR("Failed to sync post send with op info as encryption failure"); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + return NN_ENCRYPT_FAILED; + } + header->dataLength = cipherLen; + } else { + // copy message + header->dataLength = request.size; + + (void)memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), request.size, + reinterpret_cast(request.lAddress), request.size); + } + + /* finally fill header crc */ + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + mDemandPollingOpType = UBOpContextInfo::SEND; + + // post request + // change lAddress to mrAddress and set lKey + UBSHcomNetTransRequest ubReq = request; + ubReq.lAddress = mrBufAddress; + ubReq.lKey = mDriver->mDriverSendMR->GetLKey(); + ubReq.size = sizeof(UBSHcomNetTransHeader) + header->dataLength; + + auto syncSendOpFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_SYNC_POST_SEND); + do { + result = InnerPostSend(ubReq, reinterpret_cast(mDriver->mDriverSendMR->GetMemorySeg())); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_SYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + syncSendOpFlag = false; + } while (syncSendOpFlag); + + NN_LOG_ERROR("Failed to sync post send with op info, result " << result); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(UB_EP_SYNC_POST_SEND, result); + return result; +} + +NResult NetUBSyncEndpoint::PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo, const UBSHcomExtHeaderType extHeaderType, const void *extHeader, + uint32_t extHeaderSize) +{ + if (NN_UNLIKELY(extHeaderType == UBSHcomExtHeaderType::RAW)) { + NN_LOG_ERROR("RAW type should not be used when extHeader is given."); + return NN_INVALID_PARAM; + } + + if (NN_UNLIKELY(!extHeader)) { + NN_LOG_ERROR("The ExtHeader is invalid."); + return NN_INVALID_PARAM; + } + + // 保证 extHeaderSize + request.size <= mAllowedSize. + POST_SEND_VALIDATION(mState, mId, mDriver, opCode, request, mAllowedSize - extHeaderSize); + + // get mr from pool + NResult result = NN_OK; + uintptr_t mrBufAddress = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to async post send with op info as get mr buffer from send mr pool failed"); + return NN_GET_BUFF_FAILED; + } + + auto *header = reinterpret_cast(mrBufAddress); + bzero(header, sizeof(UBSHcomNetTransHeader)); + header->opCode = opCode; + header->seqNo = opInfo.seqNo == 0 ? NextSeq() : opInfo.seqNo; + header->flags = ((uint16_t)opInfo.flags << NN_NO8) | (uint64_t)NTH_TWO_SIDE; + header->timeout = opInfo.timeout; + header->errorCode = opInfo.errorCode; + header->extHeaderType = extHeaderType; + header->dataLength = request.size + extHeaderSize; + + mLastSendSeqNo = header->seqNo; + if (mIsNeedEncrypt) { + NN_LOG_WARN("postsent encrypt is not supported now."); + } + + // 拷贝上层指定的 header,此时将要发送的结构为 + // | UBSHcomNetTransHeader | extHeader | request body | + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader)), + mDriver->mDriverSendMR->GetSingleSegSize() - sizeof(UBSHcomNetTransHeader), extHeader, + extHeaderSize) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to mrBufAddress"); + return NN_INVALID_PARAM; + } + + // 拷贝消息主体 + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress + sizeof(UBSHcomNetTransHeader) + extHeaderSize), + mDriver->mDriverSendMR->GetSingleSegSize() - sizeof(UBSHcomNetTransHeader) - extHeaderSize, + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to mrBufAddress"); + return NN_INVALID_PARAM; + } + + // 头部全部写入完毕后才生成 crc32 + header->headerCrc = NetFunc::CalcHeaderCrc32(header); + mDemandPollingOpType = UBOpContextInfo::SEND; + + // lAddress -> mrAddress + UBSHcomNetTransRequest ubReq = request; + ubReq.lAddress = mrBufAddress; + ubReq.lKey = mDriver->mDriverSendMR->GetLKey(); + ubReq.size = sizeof(UBSHcomNetTransHeader) + header->dataLength; + + auto syncSendOpFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_SYNC_POST_SEND); + do { + result = InnerPostSend(ubReq, reinterpret_cast(mDriver->mDriverSendMR->GetMemorySeg())); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_SYNC_POST_SEND, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + syncSendOpFlag = false; + } while (syncSendOpFlag); + + NN_LOG_ERROR("Failed to async post send with op info, result " << result); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(UB_EP_SYNC_POST_SEND, result); + return result; +} + +NResult NetUBSyncEndpoint::PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNo) +{ + POST_SEND_RAW_VALIDATION(mState, mId, mDriver, seqNo, request, mSegSize); + + /* get mr from pool */ + NResult result = UB_OK; + uintptr_t mrBufAddress = 0; + size_t msgSize = 0; + if (NN_UNLIKELY(!mDriver->mDriverSendMR->GetFreeBuffer(mrBufAddress))) { + NN_LOG_ERROR("Failed to post raw message as failed to get mr buffer from pool from driver " << mDriver->Name()); + return UB_MEMORY_ALLOCATE_FAILED; + } + + if (!mIsNeedEncrypt) { + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress), request.size, + reinterpret_cast(request.lAddress), request.size) != NN_OK)) { + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + NN_LOG_ERROR("Failed to copy request to mrBufAddress"); + return UB_PARAM_INVALID; + } + msgSize = request.size; + } else { + uint32_t cipherLen = 0; + if (!mAes.Encrypt(mSecrets, + (void *)request.lAddress, request.size, reinterpret_cast(mrBufAddress), cipherLen)) { + NN_LOG_ERROR("Failed send message as encryption failure"); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + return NN_ENCRYPT_FAILED; + } + msgSize = cipherLen; + } + + UBSHcomNetTransRequest ubReq = request; + ubReq.lAddress = mrBufAddress; + ubReq.lKey = mDriver->mDriverSendMR->GetLKey(); + ubReq.size = msgSize; + + // 在 SEND_RAW 下,seqNo 不能为 0, 即表明 InnerPostSend 将会使用 `UBOpContextInfo::SEND_RAW` + mDemandPollingOpType = UBOpContextInfo::SEND_RAW; + + mLastSendSeqNo = seqNo; + + auto flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_SYNC_POST_SEND_RAW); + do { + result = InnerPostSend(ubReq, reinterpret_cast(mDriver->mDriverSendMR->GetMemorySeg()), + seqNo); + if (NN_LIKELY(result == UB_OK)) { + TRACE_DELAY_END(UB_EP_SYNC_POST_SEND_RAW, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + NN_LOG_ERROR("Failed to post raw send request, result " << result); + mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + TRACE_DELAY_END(UB_EP_SYNC_POST_SEND_RAW, result); + return result; +} + +NResult NetUBSyncEndpoint::InnerPostSendSgl(const UBSendSglRWRequest &req, const UBSendReadWriteRequest &tlsReq, + uint32_t immData) +{ + if (NN_UNLIKELY(mJetty == nullptr)) { + NN_LOG_ERROR("Failed to InnerPostSendSgl with NetUBSyncEndpoint as jetty is null"); + return UB_PARAM_INVALID; + } + + static thread_local UBSglContextInfo sglCtx; + sglCtx.result = UB_OK; + sglCtx.qp = mJetty; + if (NN_UNLIKELY(memcpy_s(sglCtx.iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, + req.iov, sizeof(UBSHcomNetTransSgeIov) * req.iovCount) != UB_OK)) { + NN_LOG_ERROR("InnerPostSendSgl failed to copy the UBSHcomNetTransSgeIov to sglCtx"); + return UB_PARAM_INVALID; + } + sglCtx.iovCount = req.iovCount; + sglCtx.upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(sglCtx.upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != UB_OK)) { + NN_LOG_ERROR("InnerPostSendSgl Failed to copy req to sglCtx"); + return UB_PARAM_INVALID; + } + } + static thread_local UBOpContextInfo ctx; + // if not encrypt reqTls lAddress\size\lKey is 0 + ctx.dataSize = tlsReq.size; + ctx.mrMemAddr = tlsReq.lAddress; + ctx.ubJetty = mJetty; + ctx.qpNum = mJetty->QpNum(); + ctx.opType = UBOpContextInfo::SEND_RAW_SGL; + ctx.opResultType = UBOpContextInfo::SUCCESS; + ctx.upCtxSize = static_cast(sizeof(UBSgeCtxInfo)); + auto upCtx = reinterpret_cast(&ctx.upCtx); + upCtx->ctx = &sglCtx; + UBSHcomNetTransSglRequest sglReq = req; + if (GetSglTseg(mDriver, sglReq) != NN_OK) { + NN_LOG_ERROR("GetSglTseg failed"); + return UB_PARAM_INVALID; + } + mJetty->IncreaseRef(); + + auto result = mJetty->PostSendSgl(sglReq.iov, sglReq.iovCount, reinterpret_cast(&ctx), immData); + if (NN_UNLIKELY(result != UB_OK)) { + mJetty->DecreaseRef(); + } + + return result; +} + +NResult NetUBSyncEndpoint::PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) +{ + size_t size = 0; + POST_SEND_SGL_VALIDATION(mState, mId, mDriver, seqNo, request, mSegSize, size); + + UBSHcomNetTransRequest tlsReq{}; + uintptr_t mrBufAddress = 0; + if (mIsNeedEncrypt) { + ENCRYPT_RAW_SGL(tlsReq, mrBufAddress, size, mAes, mDriver); + } + + mDemandPollingOpType = UBOpContextInfo::SEND_RAW_SGL; + NResult result = UB_OK; + mLastSendSeqNo = seqNo; + auto flag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_SYNC_POST_SEND_RAW_SGL); + do { + result = InnerPostSendSgl(request, tlsReq, seqNo); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_SYNC_POST_SEND_RAW_SGL, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + if (mIsNeedEncrypt) { + (void)mDriver->mDriverSendMR->ReturnBuffer(mrBufAddress); + } + + NN_LOG_ERROR("NetUBSyncEndpoint Failed to post send raw sgl request, result " << result); + TRACE_DELAY_END(UB_EP_SYNC_POST_SEND_RAW_SGL, result); + return result; +} + +NResult NetUBSyncEndpoint::PostRead(const UBSHcomNetTransRequest &request) +{ + READ_WRITE_VALIDATION(mState, mId, mDriver, request); + + mDemandPollingOpType = UBOpContextInfo::READ; + NResult result = NN_OK; + auto readFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_SYNC_POST_READ); + do { + result = InnerPostRead(request); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_SYNC_POST_READ, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + readFlag = false; + } while (readFlag); + + NN_LOG_ERROR("Failed to post read request, result " << result); + TRACE_DELAY_END(UB_EP_SYNC_POST_READ, result); + return result; +} + +NResult NetUBSyncEndpoint::PostRead(const UBSHcomNetTransSglRequest &request) +{ + READ_WRITE_SGL_VALIDATION(mState, mId, mDriver, request); + + mDemandPollingOpType = UBOpContextInfo::SGL_READ; + NResult result = UB_OK; + auto readSglFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_SYNC_POST_READ_SGL); + do { + result = PostOneSideSgl(request); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_SYNC_POST_READ_SGL, result); + return UB_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + readSglFlag = false; + } while (readSglFlag); + + NN_LOG_ERROR("Failed to post read sgl request, result " << result); + TRACE_DELAY_END(UB_EP_SYNC_POST_READ_SGL, result); + return result; +} + +NResult NetUBSyncEndpoint::PostWrite(const UBSHcomNetTransRequest &request) +{ + READ_WRITE_VALIDATION(mState, mId, mDriver, request); + + mDemandPollingOpType = UBOpContextInfo::WRITE; + NResult result = NN_OK; + auto writeFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_SYNC_POST_WRITE); + do { + result = InnerPostWrite(request); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_SYNC_POST_WRITE, result); + return NN_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + writeFlag = false; + } while (writeFlag); + + NN_LOG_ERROR("Failed to post write request, result " << result); + TRACE_DELAY_END(UB_EP_SYNC_POST_WRITE, result); + return result; +} + +NResult NetUBSyncEndpoint::PostWrite(const UBSHcomNetTransSglRequest &request) +{ + READ_WRITE_SGL_VALIDATION(mState, mId, mDriver, request); + + mDemandPollingOpType = UBOpContextInfo::SGL_WRITE; + NResult result = UB_OK; + auto writeSglFlag = true; + uint64_t finishTime = GetFinishTime(); + TRACE_DELAY_BEGIN(UB_EP_SYNC_POST_WRITE_SGL); + do { + result = PostOneSideSgl(request, false); + if (result == UB_OK) { + TRACE_DELAY_END(UB_EP_SYNC_POST_WRITE_SGL, result); + return UB_OK; + } else if (NeedRetry(result) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + writeSglFlag = false; + } while (writeSglFlag); + + NN_LOG_ERROR("Failed to post write sgl request, result " << result); + TRACE_DELAY_END(UB_EP_SYNC_POST_WRITE_SGL, result); + return result; +} + +NResult NetUBSyncEndpoint::WaitCompletion(int32_t timeout) +{ + NN_LOG_TRACE_INFO("wait completion mDemandPollingOpType " << mDemandPollingOpType); + UBOpContextInfo *opCtx = nullptr; + NResult result = NN_OK; + uint32_t immData = 0; + +POLL_CQ: + if (NN_UNLIKELY(result = PollingCompletion(opCtx, timeout, immData)) != NN_OK) { + // do later + return result; + } + + /* If opCtx->opType doesn't match with mDemandingPollingOpType, that means wrong cqe was polled. + * Store opCtx and immData, and handle them later. Do polling cq again. */ + if (NN_UNLIKELY(opCtx->opType != mDemandPollingOpType)) { + // repost if receive opType + if (opCtx->opType == UBOpContextInfo::RECEIVE) { + if (mDelayHandleReceiveCtx == nullptr) { + mDelayHandleReceiveCtx = opCtx; + mDelayHandleReceiveImmData = immData; + goto POLL_CQ; + } else { + NN_LOG_ERROR("Receive operation type has double received, prev context is not process"); + } + } + NN_LOG_WARN("Got un-demand operation type: " << opCtx->opType << ", ignored by ep id: " << mId); + } + + opCtx->ubJetty->DecreaseRef(); + if (opCtx->opType == UBOpContextInfo::SEND && !(mDriver->mDriverSendMR->ReturnBuffer(opCtx->mrMemAddr))) { + NN_LOG_ERROR("Failed to return mr segment back in Driver " << mDriver->mName); + } + + if (opCtx->opType == UBOpContextInfo::SEND_RAW_SGL && mIsNeedEncrypt) { + // buffer should return when encrypt + (void)mDriver->mDriverSendMR->ReturnBuffer(opCtx->mrMemAddr); + } + + if (opCtx->opType == UBOpContextInfo::SGL_WRITE || opCtx->opType == UBOpContextInfo::SGL_READ) { + auto sgeCtx = reinterpret_cast(opCtx->upCtx); + auto sglCtx = sgeCtx->ctx; + result = UBOpContextInfo::GetNResult(opCtx->opResultType); + sglCtx->result = sglCtx->result < result ? sglCtx->result : result; + auto refCount = __sync_add_and_fetch(&(sglCtx->refCount), 1); + if (refCount == sglCtx->iovCount) { + return sglCtx->result; + } + goto POLL_CQ; + } + + return NN_OK; +} + +NResult NetUBSyncEndpoint::Receive(int32_t timeout, UBSHcomNetResponseContext &ctx) +{ + NResult result = NN_OK; + UBOpContextInfo *opCtx = nullptr; + uint32_t immData = 0; + + mDemandPollingOpType = UBOpContextInfo::RECEIVE; + NN_LOG_TRACE_INFO("receive mDemandPollingOpType " << mDemandPollingOpType); + + /* Handle ctx from incorrect polling */ + if (NN_UNLIKELY(mDelayHandleReceiveCtx != nullptr)) { + opCtx = mDelayHandleReceiveCtx; + mDelayHandleReceiveCtx = nullptr; + } else if (NN_UNLIKELY(result = PollingCompletion(opCtx, timeout, immData)) != NN_OK) { + // do later + return result; + } + + do { + if (NN_UNLIKELY(opCtx->opType != mDemandPollingOpType)) { + NN_LOG_ERROR("Got a cqe with un-demand operation type " << opCtx->opType << ", ignored"); + result = NN_ERROR; + break; + } + + auto *tmpHeader = reinterpret_cast(opCtx->mrMemAddr); + result = NetFunc::ValidateHeaderWithDataSize(*tmpHeader, opCtx->dataSize); + if (NN_UNLIKELY(result != NN_OK)) { + break; + } + + auto tmpDataAddress = reinterpret_cast(opCtx->mrMemAddr + sizeof(UBSHcomNetTransHeader)); + size_t realDataSize = 0; + if (mIsNeedEncrypt) { + uint32_t decryptLen = 0; + realDataSize = mAes.GetRawLen(tmpHeader->dataLength); + auto msgReady = mRespMessage.AllocateIfNeed(realDataSize); + if (NN_UNLIKELY(!msgReady)) { + NN_LOG_ERROR("Failed to allocate memory for response size: " << realDataSize << + ", probably out of memory"); + result = NN_MALLOC_FAILED; + break; + } + + if (!mAes.Decrypt(mSecrets, tmpDataAddress, tmpHeader->dataLength, mRespMessage.mBuf, + decryptLen)) { + NN_LOG_ERROR("Failed to decrypt message"); + result = NN_DECRYPT_FAILED; + break; + } + mRespMessage.mDataLen = decryptLen; + } else { + realDataSize = tmpHeader->dataLength; + auto msgReady = mRespMessage.AllocateIfNeed(realDataSize); + if (NN_UNLIKELY(!msgReady)) { + NN_LOG_ERROR("Failed to allocate memory for response size: " << realDataSize << + ", probably out of memory"); + result = NN_MALLOC_FAILED; + break; + } + if (NN_UNLIKELY(memcpy_s(mRespMessage.mBuf, + mRespMessage.GetBufLen(), tmpDataAddress, realDataSize) != UB_OK)) { + NN_LOG_ERROR("Failed to copy tmpDataAddress to mRespMessage"); + result = NN_INVALID_PARAM; + break; + } + mRespMessage.mDataLen = realDataSize; + } + + if (NN_UNLIKELY(memcpy_s(&(mRespCtx.mHeader), + sizeof(UBSHcomNetTransHeader), tmpHeader, sizeof(UBSHcomNetTransHeader)) != UB_OK)) { + NN_LOG_ERROR("Failed to copy tmpHeader to mRespCtx"); + result = NN_INVALID_PARAM; + break; + } + } while (false); + + auto receiveFlag = true; + uint64_t finishTime = GetFinishTime(); + NResult rePostResult = UB_OK; + do { + rePostResult = RePostReceive(opCtx); + if (rePostResult == UB_OK) { + break; + } else if (NeedRetry(rePostResult) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry rePostResult or timeout = 0 + receiveFlag = false; + } while (receiveFlag); + + if (NN_UNLIKELY(rePostResult != UB_OK)) { + NN_LOG_ERROR("Failed to repost receive, result " << rePostResult); + mJetty->ReturnBuffer(opCtx->mrMemAddr); + return rePostResult; + } + + if (NN_LIKELY(result == NN_OK)) { + mRespCtx.mMessage = &mRespMessage; + ctx.mHeader = mRespCtx.mHeader; + ctx.mMessage = mRespCtx.mMessage; + } + + return result; +} + +void NetUBSyncEndpoint::ReceiveRawHandle(UBOpContextInfo *opCtx, uint32_t immData, NResult &result) +{ + if (NN_UNLIKELY(opCtx->opType != mDemandPollingOpType)) { + NN_LOG_ERROR("Got un-demand operation type " << opCtx->opType << " in ReceiveRaw, ignored"); + result = NN_ERROR; + return; + } + + if (NN_UNLIKELY(immData != mLastSendSeqNo)) { + NN_LOG_ERROR("Received un-matched seq no " << immData << ", demand seq no " << mLastSendSeqNo); + result = NN_SEQ_NO_NOT_MATCHED; + return; + } + + auto dataSize = opCtx->dataSize; + auto msgReady = mRespMessage.AllocateIfNeed(dataSize); + if (NN_UNLIKELY(!msgReady)) { + NN_LOG_ERROR("Failed to allocate memory for response size " << opCtx->dataSize << + ", probably out of memory"); + result = NN_MALLOC_FAILED; + return; + } + + auto tmpDataAddress = reinterpret_cast(opCtx->mrMemAddr); + if (mIsNeedEncrypt) { + uint32_t decryptLen = 0; + if (!mAes.Decrypt(mSecrets, tmpDataAddress, dataSize, mRespMessage.mBuf, decryptLen)) { + NN_LOG_ERROR("Failed to decrypt data"); + result = NN_DECRYPT_FAILED; + return; + } + mRespMessage.mDataLen = decryptLen; + } else { + if (NN_UNLIKELY(memcpy_s(mRespMessage.mBuf, mRespMessage.GetBufLen(), tmpDataAddress, dataSize) != NN_OK)) { + NN_LOG_ERROR("Failed to copy tmpDataAddress to mRespMessage"); + result = NN_INVALID_PARAM; + return; + } + mRespMessage.mDataLen = dataSize; + } +} + +UResult NetUBSyncEndpoint::ReceiveRaw(int32_t timeout, UBSHcomNetResponseContext &ctx) +{ + UBOpContextInfo *opCtx = nullptr; + NResult result = NN_OK; + uint32_t immData = 0; + + mDemandPollingOpType = UBOpContextInfo::RECEIVE; + + NN_LOG_TRACE_INFO("receive mDemandPollingOpType " << mDemandPollingOpType); + + /* Handle ctx and immData from incorrect polling */ + if (NN_UNLIKELY(mDelayHandleReceiveCtx != nullptr && mDelayHandleReceiveImmData != 0)) { + opCtx = mDelayHandleReceiveCtx; + immData = mDelayHandleReceiveImmData; + mDelayHandleReceiveCtx = nullptr; + mDelayHandleReceiveImmData = 0; + } else if (NN_UNLIKELY(result = PollingCompletion(opCtx, timeout, immData)) != NN_OK) { + // do later + return result; + } + ReceiveRawHandle(opCtx, immData, result); + auto flag = true; + uint64_t finishTime = GetFinishTime(); + UResult rePostResult = NN_OK; + uintptr_t mrMemAddr = opCtx->mrMemAddr; + do { + rePostResult = RePostReceive(opCtx); + if (rePostResult == NN_OK) { + break; + } else if (NeedRetry(rePostResult) && mDefaultTimeout != 0 && NetMonotonic::TimeNs() < finishTime) { + usleep(100UL); // LWT situation is not suitable for calling system sleep + continue; + } + // no retry result or timeout = 0 + flag = false; + } while (flag); + + if (NN_UNLIKELY(rePostResult != NN_OK)) { + NN_LOG_ERROR("NetUBSyncEndpoint Failed to repost receive raw, result " << rePostResult); + mJetty->ReturnBuffer(mrMemAddr); + return rePostResult; + } + + if (NN_LIKELY(result == NN_OK)) { + mRespCtx.mMessage = &mRespMessage; + ctx.mHeader = {}; + ctx.mHeader.opCode = -1; + ctx.mHeader.seqNo = immData; + ctx.mMessage = mRespCtx.mMessage; + } + + return result; +} + +NResult NetUBSyncEndpoint::PollingCompletion(UBOpContextInfo *&ctx, int32_t timeout, uint32_t &immData) +{ + if (NN_UNLIKELY(mJfc == nullptr)) { + NN_LOG_ERROR("Failed to polling completion with UBSyncEndpoint as cq is null"); + return UB_EP_NOT_INITIALIZED; + } + + int32_t timeoutInMs = TimeSecToMs(timeout); + urma_cr_t wc{}; + uint32_t pollCount = 1; + NResult result = UB_OK; + if (mPollingMode == UB_BUSY_POLLING) { + auto start = NetMonotonic::TimeMs(); + int64_t pollTime = 0; + do { + pollCount = 1; + result = mJfc->ProgressV(&wc, pollCount); + + pollTime = (int64_t)(NetMonotonic::TimeMs() - start); + if (pollCount == 0 && timeoutInMs >= 0 && pollTime > timeoutInMs) { + return UB_CQ_EVENT_GET_TIMOUT; + } + } while (result == UB_OK && pollCount == 0); + } else if (mPollingMode == UB_EVENT_POLLING) { + result = mJfc->EventProgressV(&wc, pollCount, timeoutInMs); + } + + if (NN_UNLIKELY(result != UB_OK)) { + return result; + } + + auto *contextInfo = reinterpret_cast(wc.user_ctx); + contextInfo->dataSize = wc.completion_len; + contextInfo->opResultType = UBOpContextInfo::OpResult(wc); + ctx = contextInfo; + if (NN_UNLIKELY(wc.status != URMA_CR_SUCCESS)) { + NN_LOG_ERROR("Poll cq failed in UBSyncEndpoint, wcStatus " << wc.status << ", opType " << contextInfo->opType); + return UB_CQ_WC_WRONG; + } + immData = wc.imm_data; + + return UB_OK; +} + +NResult NetUBSyncEndpoint::PostReceive(uintptr_t bufAddress, uint32_t bufSize, urma_target_seg_t *localSeg) +{ + if (NN_UNLIKELY(mJetty == nullptr)) { + NN_LOG_ERROR("Failed to PostReceive with NetUBSyncEndpoint as qp is null"); + return UB_PARAM_INVALID; + } + + UBOpContextInfo *ctx = nullptr; + if (NN_UNLIKELY(!mCtxPool.Dequeue(ctx))) { + NN_LOG_ERROR("Failed to PostReceive with NetUBSyncEndpoint as no ctx left"); + return UB_PARAM_INVALID; + } + + ctx->ubJetty = mJetty; + ctx->mrMemAddr = bufAddress; + ctx->localSeg = localSeg; + ctx->dataSize = bufSize; + ctx->qpNum = mJetty->QpNum(); + ctx->opType = UBOpContextInfo::RECEIVE; + ctx->opResultType = UBOpContextInfo::SUCCESS; + mJetty->IncreaseRef(); + + // attach context to qp firstly, because post cloud be finished very fast + // if posted failed, need to remove + mJetty->AddOpCtxInfo(ctx); + + auto result = mJetty->PostReceive(bufAddress, bufSize, localSeg, reinterpret_cast(ctx)); + if (NN_UNLIKELY(result != UB_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + mJetty->DecreaseRef(); + mJetty->RemoveOpCtxInfo(ctx); + mCtxPool.Enqueue(ctx); + } + + // ctx could not be used if post successfully + return result; +} + +NResult NetUBSyncEndpoint::RePostReceive(UBOpContextInfo *ctx) +{ + if (NN_UNLIKELY(ctx == nullptr || ctx->ubJetty == nullptr)) { + NN_LOG_ERROR("Failed to RePostReceive with UBSyncEndpoint as ctx or its qp is null"); + return UB_PARAM_INVALID; + } + + auto result = ctx->ubJetty->PostReceive(ctx->mrMemAddr, mJetty->PostRegMrSize(), ctx->localSeg, + reinterpret_cast(ctx)); + if (NN_UNLIKELY(result != UB_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + ctx->ubJetty->DecreaseRef(); + mJetty->RemoveOpCtxInfo(ctx); + mCtxPool.Enqueue(ctx); + } + + // ctx could not be used if post successfully + return result; +} + +NResult NetUBSyncEndpoint::CreateResources(const std::string &name, UBContext *ctx, UBPollingMode pollMode, + const JettyOptions &options, UBJetty *&qp, UBJfc *&cq) +{ + if (ctx == nullptr || name.empty()) { + return UB_PARAM_INVALID; + } + + auto tmpCQ = new (std::nothrow) UBJfc(name, ctx, pollMode == UB_EVENT_POLLING); + if (tmpCQ == nullptr) { + NN_LOG_ERROR("Failed to create UBJfc, probably out of memory"); + return UB_NEW_OBJECT_FAILED; + } + + auto tmpQP = new (std::nothrow) UBJetty(name, UBJetty::NewId(), ctx, tmpCQ, options); + if (tmpQP == nullptr) { + NN_LOG_ERROR("Failed to create UBJetty, probably out of memory"); + delete tmpCQ; + return UB_NEW_OBJECT_FAILED; + } + + qp = tmpQP; + cq = tmpCQ; + + return UB_OK; +} + +NResult NetUBSyncEndpoint::InnerPostSend(const UBSendReadWriteRequest &req, urma_target_seg_t *localSeg, + uint32_t immData) +{ + if (NN_UNLIKELY(mJetty == nullptr)) { + NN_LOG_ERROR("Failed to PostSend with UBSyncEndpoint as qp is null"); + return UB_PARAM_INVALID; + } + + static thread_local UBOpContextInfo ctx{}; + ctx.ubJetty = mJetty; + ctx.mrMemAddr = req.lAddress; + ctx.dataSize = req.size; + ctx.qpNum = mJetty->QpNum(); + ctx.lKey = req.lKey; + ctx.opType = immData == 0 ? UBOpContextInfo::SEND : UBOpContextInfo::SEND_RAW; + ctx.opResultType = UBOpContextInfo::SUCCESS; + ctx.upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + (void)memcpy_s(ctx.upCtx, req.upCtxSize, req.upCtxData, req.upCtxSize); + } + mJetty->IncreaseRef(); + + auto result = mJetty->PostSend(req.lAddress, req.size, localSeg, &ctx, immData); + if (NN_UNLIKELY(result != UB_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + mJetty->DecreaseRef(); + } + + // ctx could not be used if post successfully + return result; +} + +NResult NetUBSyncEndpoint::InnerPostRead(const UBSendReadWriteRequest &req) +{ + if (NN_UNLIKELY(mJetty == nullptr)) { + NN_LOG_ERROR("Failed to PostRead with UBSyncEndpoint as qp is null"); + return UB_PARAM_INVALID; + } + + urma_target_seg_t *tseg = nullptr; + if (mDriver->GetTseg(req.lKey, tseg) != NN_OK) { + NN_LOG_ERROR("Failed to post read request as failed to get tseg"); + return UB_PARAM_INVALID; + } + + static thread_local UBOpContextInfo ctx{}; + ctx.ubJetty = mJetty; + ctx.mrMemAddr = req.lAddress; + ctx.dataSize = req.size; + ctx.qpNum = mJetty->QpNum(); + ctx.lKey = req.lKey; + ctx.opType = UBOpContextInfo::READ; + ctx.opResultType = UBOpContextInfo::SUCCESS; + ctx.upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + (void)memcpy_s(ctx.upCtx, req.upCtxSize, req.upCtxData, req.upCtxSize); + } + mJetty->IncreaseRef(); + + UResult result = UB_OK; + result = mJetty->PostRead(req.lAddress, tseg, req.rAddress, req.rKey, req.size, reinterpret_cast(&ctx)); + if (NN_UNLIKELY(result != UB_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + mJetty->DecreaseRef(); + } + + // ctx could not be used if post successfully + return result; +} + +NResult NetUBSyncEndpoint::InnerPostWrite(const UBSendReadWriteRequest &req) +{ + if (NN_UNLIKELY(mJetty == nullptr)) { + NN_LOG_ERROR("Failed to PostWrite with UBSyncEndpoint as qp is null"); + return UB_PARAM_INVALID; + } + + urma_target_seg_t *tseg = nullptr; + if (mDriver->GetTseg(req.lKey, tseg) != NN_OK) { + NN_LOG_ERROR("Failed to post read request, as get tseg failed"); + return UB_PARAM_INVALID; + } + + static thread_local UBOpContextInfo ctx{}; + ctx.ubJetty = mJetty; + ctx.mrMemAddr = req.lAddress; + ctx.dataSize = req.size; + ctx.qpNum = mJetty->QpNum(); + ctx.lKey = req.lKey; + ctx.opType = UBOpContextInfo::WRITE; + ctx.opResultType = UBOpContextInfo::SUCCESS; + ctx.upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + (void)memcpy_s(ctx.upCtx, req.upCtxSize, req.upCtxData, req.upCtxSize); + } + mJetty->IncreaseRef(); + + UResult result = UB_OK; + result = mJetty->PostWrite(req.lAddress, tseg, req.rAddress, req.rKey, req.size, reinterpret_cast(&ctx)); + if (NN_UNLIKELY(result != UB_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + mJetty->DecreaseRef(); + } + + // ctx could not be used if post successfully + return result; +} + +UResult NetUBSyncEndpoint::CreateOneSideCtx(const UBSgeCtxInfo &sgeInfo, const UBSHcomNetTransSgeIov *iov, + uint32_t iovCount, uint64_t (&ctxArr)[NET_SGE_MAX_IOV], bool isRead) +{ + if (iov == nullptr || iovCount == NN_NO0 || iovCount > NN_NO4 || ctxArr == nullptr) { + NN_LOG_ERROR("Urma failed to create oneSide operation ctx because param invalid"); + return UB_PARAM_INVALID; + } + static thread_local UBOpContextInfo ctx[NN_NO4] = {}; + for (uint32_t i = 0; i < iovCount; ++i) { + ctx[i].ubJetty = mJetty; + ctx[i].mrMemAddr = iov[i].lAddress; + ctx[i].dataSize = iov[i].size; + ctx[i].qpNum = mJetty->QpNum(); + ctx[i].lKey = iov[i].lKey; + ctx[i].opType = isRead ? UBOpContextInfo::SGL_READ : UBOpContextInfo::SGL_WRITE; + ctx[i].opResultType = UBOpContextInfo::SUCCESS; + ctx[i].upCtxSize = static_cast(sizeof(UBSgeCtxInfo)); + auto upCtx = static_cast((void *)&(ctx[i].upCtx)); + upCtx->ctx = sgeInfo.ctx; + upCtx->idx = i; + + mJetty->IncreaseRef(); + ctxArr[i] = reinterpret_cast(&ctx[i]); + } + return UB_OK; +} + +UResult NetUBSyncEndpoint::PostOneSideSgl(const UBSendSglRWRequest &req, bool isRead) +{ + if (NN_UNLIKELY(mJetty == nullptr)) { + NN_LOG_ERROR("Urma failed to Post oneSide with UBWorker as qp is null."); + return UB_PARAM_INVALID; + } + + static thread_local UBSglContextInfo sglCtx; + sglCtx.qp = mJetty; + sglCtx.result = UB_OK; + if (NN_UNLIKELY(memcpy_s(sglCtx.iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, req.iov, + sizeof(UBSHcomNetTransSgeIov) * req.iovCount) != UB_OK)) { + NN_LOG_ERROR("Urma post oneSide failed to copy UBSHcomNetTransSgeIov to sglCtx"); + return UB_PARAM_INVALID; + } + sglCtx.iovCount = req.iovCount; + sglCtx.upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(sglCtx.upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != UB_OK)) { + NN_LOG_ERROR("Urma failed to copy upCtx to sglCtx"); + return UB_PARAM_INVALID; + } + } + + UBSgeCtxInfo sgeInfo(&sglCtx); + sglCtx.refCount = 0; + uint64_t ctxArr[NET_SGE_MAX_IOV]; + UResult result = CreateOneSideCtx(sgeInfo, req.iov, req.iovCount, ctxArr, isRead); + if (result != UB_OK) { + NN_LOG_ERROR("Urma failed to create one side ctx."); + return result; + } + UBSHcomNetTransSglRequest sglReq = req; + if (GetSglTseg(mDriver, sglReq) != NN_OK) { + NN_LOG_ERROR("GetSglTseg failed"); + return UB_PARAM_INVALID; + } + result = mJetty->PostOneSideSgl(sglReq.iov, sglReq.iovCount, ctxArr, isRead, NET_SGE_MAX_IOV); + if (NN_UNLIKELY(result != UB_OK)) { + for (int i = 0; i < req.iovCount; ++i) { + mJetty->DecreaseRef(); + } + } + return result; +} +} +} +#endif diff --git a/src/transport/ub/net_ub_endpoint.h b/src/transport/ub/net_ub_endpoint.h new file mode 100644 index 0000000000000000000000000000000000000000..10ee2002f9d8881d82e7e15dc86a40cc8be7afbc --- /dev/null +++ b/src/transport/ub/net_ub_endpoint.h @@ -0,0 +1,428 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_NET_UB_ENDPOINT_H +#define HCOM_NET_UB_ENDPOINT_H +#ifdef UB_BUILD_ENABLED + +#include "transport/net_endpoint_impl.h" +#include "net_monotonic.h" +#include "net_security_alg.h" +#include "hcom_utils.h" +#include "ub_urma_wrapper_jetty.h" +#include "net_ub_driver_oob.h" + +namespace ock { +namespace hcom { +class NetUBAsyncEndpoint : public NetEndpointImpl { +public: + NetUBAsyncEndpoint(uint64_t id, UBJetty *qp, NetDriverUBWithOob *driver, UBWorker *worker); + ~NetUBAsyncEndpoint() override; + + NResult SetEpOption(UBSHcomEpOptions &epOptions) override + { + NN_LOG_WARN("[UB AsyncEp] Empty function for now"); + return NN_OK; + } + + const std::string &PeerIpAndPort() override + { + if (mJetty != nullptr) { + return mJetty->GetPeerIpAndPort(); + } + + return CONST_EMPTY_STRING; + } + + uint32_t GetSendQueueCount() override + { + return mJetty->GetSendQueueSize(); + } + + const std::string &UdsName() override + { + NN_LOG_WARN("[UB AsyncEp] Empty function for now"); + return CONST_EMPTY_STRING; + } + + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNO) override; + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) override; + + NResult PostSendSglInline(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) override; + + NResult PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNO) override; + NResult PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo) override; + + NResult PostRead(const UBSHcomNetTransRequest &request) override; + NResult PostRead(const UBSHcomNetTransSglRequest &request) override; + NResult PostWrite(const UBSHcomNetTransRequest &request) override; + NResult PostWrite(const UBSHcomNetTransSglRequest &request) override; + void UpdateTargetHbTime(); + + bool checkTargetHbTime(uint64_t currTime) + { + if (mTargetHbTime < currTime) { + mTargetHbTime = currTime + mHeartBeatIdleTime; + return true; + } + return false; + } + + NResult WaitCompletion(int32_t timeout) override + { + NN_LOG_WARN("Invalid operation, wait completion is not supported by NetUBAsyncEndpoint"); + return NN_INVALID_OPERATION; + } + + NResult Receive(int32_t timeout, UBSHcomNetResponseContext &ctx) override + { + NN_LOG_WARN("Invalid operation, wait completion is not supported by NetUBAsyncEndpoint"); + return NN_INVALID_OPERATION; + } + + NResult ReceiveRaw(int32_t timeout, UBSHcomNetResponseContext &ctx) override + { + NN_LOG_WARN("Invalid operation, wait completion is not supported by NetUBAsyncEndpoint"); + return NN_INVALID_OPERATION; + } + + inline bool HbCheckStateNormal() + { + if (mHbCount > mHbLastCount) { + mHbLastCount = mHbCount; + return true; + } + + return false; + } + + inline void HbRecordCount() + { + __sync_add_and_fetch(&mHbCount, 1); + } + + inline void SetRemoteHbInfo(uintptr_t address, uint64_t key, uint64_t size) + { + mRemoteHbAddress = address; + mRemoteHbKey = key; + mHbMrSize = size; + } + + inline UBJetty *GetQp() const + { + return mJetty; + } + + NResult GetRemoteUdsIdInfo(UBSHcomNetUdsIdInfo &idInfo) override + { + if (!mState.Compare(NEP_ESTABLISHED)) { + NN_LOG_ERROR("[UB AsyncEp] EP is not established"); + return NN_EP_NOT_ESTABLISHED; + } + + if (!mDriver->mStartOobSvr) { + NN_LOG_ERROR("[UB AsyncEp] oob server is not start"); + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + + if (mDriver->mOptions.oobType != NET_OOB_UDS) { + NN_LOG_ERROR("[UB AsyncEp] oob type is not uds"); + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + + idInfo = mRemoteUdsIdInfo; + return NN_OK; + } + + bool GetPeerIpPort(std::string &ip, uint16_t &port) override + { + if (NN_UNLIKELY(mJetty == nullptr)) { + return false; + } + + auto ipPort = mJetty->GetPeerIpAndPort(); + if (NN_UNLIKELY(ipPort.empty())) { + NN_LOG_ERROR("[UB AsyncEp] ip and port of peer is empty"); + return false; + } + + std::vector ipPortVec; + NetFunc::NN_SplitStr(ipPort, ":", ipPortVec); + if (NN_UNLIKELY(ipPortVec.size() != NN_NO2)) { + NN_LOG_ERROR("[UB AsyncEp] ip and port of peer is invalid"); + return false; + } + + try { + port = std::stoi(ipPortVec[1]); + } catch (...) { + NN_LOG_ERROR("[UB AsyncEp] port of peer is invalid"); + return false; + } + if (port == 0) { + NN_LOG_ERROR("[UB AsyncEp] oob type is uds, does not have peer ip and port msg"); + return false; + } + ip = ipPortVec[0]; + + return true; + } + + void Close() override + { + NN_LOG_INFO("[UB AsyncEp] Close ep id " << mId << " by user"); + mJetty->Stop(); + } + +protected: + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, const UBSHcomNetTransOpInfo &opInfo, + const UBSHcomExtHeaderType extHeaderType, const void *extHeader, uint32_t extHeaderSize) override; + +private: + uint64_t inline GetFinishTime() + { + if (mDefaultTimeout > 0) { + return NetMonotonic::TimeNs() + static_cast(mDefaultTimeout) * 1000000000UL; + } else if (mDefaultTimeout < 0) { + return UINT64_MAX; + } + + return 0; + } + + bool inline NeedRetry(NResult &result) + { + if (!State().Compare(NEP_ESTABLISHED)) { + result = NN_EP_NOT_ESTABLISHED; + return false; + } + + if (result == UB_QP_POST_SEND_WR_FULL || result == UB_QP_ONE_SIDE_WR_FULL || result == UB_QP_CTX_FULL) { + return true; + } + + return false; + } + + inline UBWorker *GetWorker() const + { + return mWorker; + } + + NetDriverUBWithOob *GetDriver() const + { + return mDriver; + } + + UBJetty *mJetty = nullptr; + UBWorker *mWorker = nullptr; + NetDriverUBWithOob *mDriver = nullptr; + + uint64_t mHbCount = 1; + uint64_t mHbLastCount = 0; + uintptr_t mRemoteHbAddress = 0; + uint64_t mRemoteHbKey = 0; + uint64_t mHbMrSize = 0; + uint32_t mDmSize = 0; + uint64_t mTargetHbTime = 0; + uint16_t mHeartBeatIdleTime = NN_NO60; + + friend class NetDriverUBWithOob; + friend class NetHeartbeat; + friend class UBJetty; + friend class UBWorker; // 依赖GetDriver +}; + +/* *********************************************************************************** */ +class NetUBSyncEndpoint : public NetEndpointImpl { +public: + NetUBSyncEndpoint(uint64_t id, UBJetty *qp, UBJfc *cq, uint32_t ubOpCtxPoolSize, NetDriverUBWithOob *driver, + const UBSHcomNetWorkerIndex &workerIndex); + ~NetUBSyncEndpoint() override; + + NResult SetEpOption(UBSHcomEpOptions &epOptions) override + { + NN_LOG_WARN("[UB SyncEp] Empty function for now"); + return NN_OK; + } + + uint32_t GetSendQueueCount() override + { + NN_LOG_WARN("[UB SyncEp] Empty function for now"); + return NN_OK; + } + + inline void PollingMode(UBPollingMode m) + { + mPollingMode = m; + } + + const std::string &PeerIpAndPort() override + { + if (mJetty != nullptr) { + return mJetty->GetPeerIpAndPort(); + } + + return CONST_EMPTY_STRING; + } + + const std::string &UdsName() override + { + NN_LOG_WARN("[UB SyncEp] Empty function for now"); + return CONST_EMPTY_STRING; + } + + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, uint32_t seqNO) override; + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, + const UBSHcomNetTransOpInfo &opInfo) override; + + NResult PostSendRaw(const UBSHcomNetTransRequest &request, uint32_t seqNo) override; + NResult PostSendRawSgl(const UBSHcomNetTransSglRequest &request, uint32_t seqNo = 0) override; + + NResult PostRead(const UBSHcomNetTransRequest &request) override; + NResult PostRead(const UBSHcomNetTransSglRequest &request) override; + NResult PostWrite(const UBSHcomNetTransRequest &request) override; + NResult PostWrite(const UBSHcomNetTransSglRequest &request) override; + NResult WaitCompletion(int32_t timeout) override; + + NResult Receive(int32_t timeout, UBSHcomNetResponseContext &ctx) override; + NResult ReceiveRaw(int32_t timeout, UBSHcomNetResponseContext &ctx) override; + void ReceiveRawHandle(UBOpContextInfo *opCtx, uint32_t immData, NResult &result); + + NResult InnerPostSend(const UBSendReadWriteRequest &req, urma_target_seg_t *localSeg, uint32_t immData = 0); + NResult InnerPostSendSgl(const UBSendSglRWRequest &req, const UBSendReadWriteRequest &tlsReq, uint32_t immData); + NResult InnerPostRead(const UBSendReadWriteRequest &req); + NResult InnerPostWrite(const UBSendReadWriteRequest &req); + UResult PostOneSideSgl(const UBSendSglRWRequest &req, bool isRead = true); + UResult CreateOneSideCtx(const UBSgeCtxInfo &sgeInfo, const UBSHcomNetTransSgeIov *iov, uint32_t iovCount, + uint64_t (&ctxArr)[NET_SGE_MAX_IOV], bool isRead); + + NResult PollingCompletion(UBOpContextInfo *&ctx, int32_t timeout, uint32_t &immData); + NResult PostReceive(uintptr_t bufAddress, uint32_t bufSize, urma_target_seg_t *localSeg); + NResult RePostReceive(UBOpContextInfo *ctx); + static NResult CreateResources(const std::string &name, UBContext *ctx, UBPollingMode pollMode, + const JettyOptions &options, UBJetty *&qp, UBJfc *&cq); + + inline UBJetty *GetQp() const + { + return mJetty; + } + + NResult GetRemoteUdsIdInfo(UBSHcomNetUdsIdInfo &idInfo) override + { + if (!mState.Compare(NEP_ESTABLISHED)) { + NN_LOG_ERROR("[UB SyncEp] EP is not established"); + return NN_EP_NOT_ESTABLISHED; + } + + if (!mDriver->mStartOobSvr) { + NN_LOG_ERROR("[UB SyncEp] oob server is not start"); + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + + if (mDriver->mOptions.oobType != NET_OOB_UDS) { + NN_LOG_ERROR("[UB SyncEp] oob type is not uds"); + return NN_UDS_ID_INFO_NOT_SUPPORT; + } + + idInfo = mRemoteUdsIdInfo; + return NN_OK; + } + + bool GetPeerIpPort(std::string &ip, uint16_t &port) override + { + if (NN_UNLIKELY(mJetty == nullptr)) { + return false; + } + + auto ipPort = mJetty->GetPeerIpAndPort(); + if (NN_UNLIKELY(ipPort.empty())) { + NN_LOG_ERROR("ip and port of peer is empty"); + return false; + } + + std::vector ipPortVec; + NetFunc::NN_SplitStr(ipPort, ":", ipPortVec); + if (NN_UNLIKELY(ipPortVec.size() != NN_NO2)) { + NN_LOG_ERROR("ip and port of peer is invalid"); + return false; + } + + try { + port = std::stoi(ipPortVec[1]); + } catch (...) { + NN_LOG_ERROR("port of peer is invalid"); + return false; + } + if (port == 0) { + NN_LOG_ERROR("oob type is uds, does not have peer ip and port msg"); + return false; + } + ip = ipPortVec[0]; + + return true; + } + + void Close() override + { + mJetty->Stop(); + } + +protected: + NResult PostSend(uint16_t opCode, const UBSHcomNetTransRequest &request, const UBSHcomNetTransOpInfo &opInfo, + const UBSHcomExtHeaderType extHeaderType, const void *extHeader, uint32_t extHeaderSize) override; + +private: + inline uint64_t GetFinishTime() + { + if (mDefaultTimeout > 0) { + return NetMonotonic::TimeNs() + static_cast(mDefaultTimeout) * 1000000000UL; + } else if (mDefaultTimeout < 0) { + return UINT64_MAX; + } + + return 0; + } + + static inline bool NeedRetry(NResult result) + { + if (result == UB_QP_POST_SEND_WR_FULL || result == UB_QP_ONE_SIDE_WR_FULL || result == UB_QP_CTX_FULL) { + return true; + } + + return false; + } + + UBJetty *mJetty = nullptr; + UBJfc *mJfc = nullptr; + NetObjPool mCtxPool; + + NetDriverUBWithOob *mDriver = nullptr; + UBPollingMode mPollingMode = UBPollingMode::UB_BUSY_POLLING; + uint32_t mLastSendSeqNo = 0; + UBOpContextInfo::OpType mDemandPollingOpType = UBOpContextInfo::SEND; + UBSHcomNetResponseContext mRespCtx; + UBSHcomNetMessage mRespMessage; + UBOpContextInfo *mDelayHandleReceiveCtx = nullptr; + uint32_t mDelayHandleReceiveImmData = 0; + uint32_t mDmSize = 0; + + friend class NetDriverUBWithOob; +}; +} +} + + +#endif +#endif // HCOM_NET_UB_ENDPOINT_H diff --git a/src/transport/ub/ub_common.h b/src/transport/ub/ub_common.h new file mode 100644 index 0000000000000000000000000000000000000000..0e92318c77a6b6f7e24cf4b6b709447bb6148028 --- /dev/null +++ b/src/transport/ub/ub_common.h @@ -0,0 +1,307 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_UB_COMMON_H +#define HCOM_UB_COMMON_H +#ifdef UB_BUILD_ENABLED + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hcom.h" +#include "hcom_def.h" +#include "hcom_num_def.h" +#include "hcom_log.h" +#include "net_common.h" +#include "net_obj_pool.h" +#include "under_api/urma/urma_api_wrapper.h" + +namespace ock { +namespace hcom { +/* + * return type + */ +using UResult = int; + +enum UBCode { + UB_OK = 0, + UB_PARAM_INVALID = 200, + UB_MEMORY_ALLOCATE_FAILED = 201, + UB_NEW_OBJECT_FAILED = 202, + UB_OPEN_FILE_FAILED = 203, + UB_READ_FILE_FAILED = 204, + UB_DEVICE_FAILED_OPEN = 205, + UB_DEVICE_INDEX_OVERFLOW = 206, + UB_DEVICE_OPEN_FAILED = 207, + UB_DEVICE_FAILED_GET_IP_ADDRESS = 208, + UB_DEVICE_NO_IP_MATCHED = 209, + UB_DEVICE_NO_IP_TO_GID_MATCHED = 210, + UB_DEVICE_INVALID_IP_MASK = 211, + UB_MR_REG_FAILED = 212, + UB_CQ_NOT_INITIALIZED = 213, + UB_CQ_POLLING_FAILED = 214, + UB_CQ_POLLING_TIMEOUT = 215, + UB_CQ_POLLING_ERROR_RESULT = 216, + UB_CQ_POLLING_UNMATCHED_OPCODE = 217, + UB_CQ_EVENT_GET_FAILED = 218, + UB_CQ_EVENT_NOTIFY_FAILED = 219, + UB_CQ_WC_WRONG = 220, + UB_CQ_EVENT_GET_TIMOUT = 221, + UB_QP_CREATE_FAILED = 222, + UB_QP_NOT_INITIALIZED = 223, + UB_QP_CHANGE_STATE_FAILED = 224, + UB_QP_POST_RECEIVE_FAILED = 225, + UB_QP_POST_SEND_FAILED = 226, + UB_QP_POST_READ_FAILED = 227, + UB_QP_POST_WRITE_FAILED = 228, + UB_QP_RECEIVE_CONFIG_ERR = 229, + UB_QP_POST_SEND_WR_FULL = 230, + UB_QP_ONE_SIDE_WR_FULL = 231, + UB_QP_CTX_FULL = 232, + UB_QP_CHANGE_ERR = 233, + UB_QP_IMPORT_FAILED = 234, + UB_QP_BIND_FAILED = 235, + UB_EP_NOT_INITIALIZED = 236, + UB_WORKER_NOT_INITIALIZED = 237, + UB_WORKER_BIND_CPU_FAILED = 238, + UB_WORKER_REQUEST_HANDLER_NOT_SET = 239, + UB_WORKER_SEND_POSTED_HANDLER_NOT_SET = 240, + UB_WORKER_ONE_SIDE_DONE_HANDLER_NOT_SET = 241, + UB_WORKER_FAILED_ADD_QP = 242, + UB_ERROR = 243, +}; + +/* constant variable */ +constexpr uint32_t TARGET_JETTY_ID_OFFSET = NN_NO10000; +constexpr uint32_t JETTY_MAX_SEND_WR = NN_NO256; +constexpr uint32_t JETTY_MAX_RECV_WR = NN_NO256; +constexpr uint32_t JETTY_MIN_RNR_TIMER = NN_NO12; +constexpr uint32_t JETTY_TIMEOUT = NN_NO14; +constexpr uint32_t JETTY_RETRY_COUNT = NN_NO7; +constexpr uint32_t JETTY_RNR_RETRY = NN_NO7; +constexpr uint32_t JFC_COUNT = NN_NO1024; + +/* + * class forward declaration + */ +class UBMemoryRegionFixedBuffer; +class NetDriverUB; +class UBFixedMemPool; + +// verbs wrappers +class UBDeviceHelper; +class UBContext; +class UBJetty; +class UBJfc; +class UBMemoryRegion; + +// logic part +class UBWorker; + +// oob for qp setup +class OOBTCPConnection; +class OOBTCPServer; +class OOBTCPClient; + +using UBSendSglInlineHeader = UBSHcomNetTransHeader; +using UBSendReadWriteRequest = UBSHcomNetTransRequest; +using UBSendSglRWRequest = UBSHcomNetTransSglRequest; + +// the size of UBOpContextInfo is 64 bytes which fit to single CPU cache line +struct UBOpContextInfo { + enum OpType : uint8_t { + SEND = 0, + SEND_RAW = 1, + SEND_RAW_SGL = 2, + RECEIVE = 3, + RECEIVE_RAW = 4, + WRITE = 5, + READ = 6, + SGL_WRITE = 7, + SGL_READ = 8, + HB_WRITE = 9, + SEND_SGL_INLINE = 10, + }; + + enum OpResultType : uint8_t { + SUCCESS = 0, + ERR_TIMEOUT = 1, + ERR_CANCELED = 2, + ERR_IO_ERROR = 3, + ERR_EP_BROKEN = 4, + ERR_EP_CLOSE = 5, + ERR_ACCESS_ABRT = 6, + ERR_ACK_TIMEOUT = 7, + + INVALID_MAGIC = 0xFF, + }; + + struct UBOpContextInfo *prev = nullptr; /* link to prev context */ + struct UBOpContextInfo *next = nullptr; /* link to next context */ + UBJetty *ubJetty = nullptr; /* pointer to qp */ + uintptr_t mrMemAddr = 0; /* address of the buffer */ + urma_target_seg_t *localSeg; /* local target segment */ + uint64_t lKey = 0; /* local key */ + uint32_t dataSize = 0; /* actual data size */ + uint32_t qpNum = 0; /* qp ID */ + OpType opType = RECEIVE; /* op type */ + OpResultType opResultType = OpResultType::SUCCESS; /* op result */ + uint16_t upCtxSize = 0; /* up context size stored in upCtx[] */ + char upCtx[NN_NO16] = {}; /* 16 bytes for upper context */ + + bool HasInternalError() const + { + switch (opResultType) { + // 成功不是一个内部错误 + case OpResultType::SUCCESS: + return false; + + // 超时对用户不可见,例如 TPACK 超时 + case OpResultType::ERR_TIMEOUT: + return true; + + // 内部错误,hcom 自治 + case OpResultType::ERR_CANCELED: + case OpResultType::ERR_IO_ERROR: + return true; + + // 这两个错误码仅当在判定 EP 出错、处理时才会设置,正常通过 CQE 上报的不会是此状态 + case OpResultType::ERR_EP_BROKEN: + case OpResultType::ERR_EP_CLOSE: + return true; + + // 远端内存访问失败,可能是远端 RQE 没有准备好 + case OpResultType::ERR_ACCESS_ABRT: + return false; + + // AckTimeout,可能是远端 RQE 用尽了 + case OpResultType::ERR_ACK_TIMEOUT: + return false; + + default: + return true; + } + } + + static inline NResult GetNResult(OpResultType opResult) + { + switch (opResult) { + case OpResultType::SUCCESS: + return NN_OK; + case OpResultType::ERR_TIMEOUT: + return NN_MSG_TIMEOUT; + case OpResultType::ERR_CANCELED: + return NN_MSG_CANCELED; + case OpResultType::ERR_EP_BROKEN: + return NN_EP_BROKEN; + case OpResultType::ERR_EP_CLOSE: + return NN_EP_CLOSE; + case OpResultType::ERR_ACCESS_ABRT: + return NN_URMA_ACCESS_ABRT; + case OpResultType::ERR_ACK_TIMEOUT: + return NN_URMA_ACK_TIMEOUT; + default: + return NN_MSG_ERROR; + } + } + + static inline OpResultType OpResult(urma_cr_t &result) + { + switch (result.status) { + case URMA_CR_SUCCESS: + return OpResultType::SUCCESS; + case URMA_CR_ACK_TIMEOUT_ERR: + return OpResultType::ERR_ACK_TIMEOUT; + case URMA_CR_RNR_RETRY_CNT_EXC_ERR: + return OpResultType::ERR_TIMEOUT; + case URMA_CR_WR_FLUSH_ERR: + case URMA_CR_WR_FLUSH_ERR_DONE: + case URMA_CR_WR_SUSPEND_DONE: + return OpResultType::ERR_CANCELED; + case URMA_CR_REM_ACCESS_ABORT_ERR: + return OpResultType::ERR_ACCESS_ABRT; + default: + NN_LOG_ERROR("Operation result: " << static_cast(result.status)); + return OpResultType::ERR_IO_ERROR; + } + } +} __attribute__((packed)); + +struct UBSglContextInfo { + UBJetty *qp = nullptr; // the qp pointer which posted from + UBSHcomNetTransSgeIov iov[NET_SGE_MAX_IOV] = {}; + NResult result = NN_OK; + uint32_t reserve1 = 0; + uint16_t refCount = 0; // equal to iovCount + uint16_t iovCount = 0; // max count:NN_NO16 + uint16_t upCtxSize = 0; + uint16_t reserve2 = 0; + char upCtx[NN_NO16] = {}; // 16 bytes for upper context +} __attribute__((packed)); + +struct UBSgeCtxInfo { + UBSglContextInfo *ctx = nullptr; + uint16_t idx = 0; + + UBSgeCtxInfo() = default; + explicit UBSgeCtxInfo(UBSglContextInfo *sglCtx) : ctx(sglCtx) {} +} __attribute__((packed)); + +enum UBPollingMode : uint8_t { + UB_BUSY_POLLING = 0, + UB_EVENT_POLLING = 1, +}; + +struct JettyOptions { + uint32_t maxSendWr = JETTY_MAX_SEND_WR; + uint32_t maxReceiveWr = JETTY_MAX_RECV_WR; + uint32_t mrSegSize = NN_NO1024; + uint32_t mrSegCount = NN_NO64; + uint8_t slave = 1; + UBSHcomUbcMode ubcMode = UBSHcomUbcMode::LowLatency; + + JettyOptions() = default; + + JettyOptions(uint32_t maxSendWrNum, uint32_t maxReceiveWrNum, uint32_t segSize, uint32_t segCount, uint8_t slave, + UBSHcomUbcMode mode = UBSHcomUbcMode::LowLatency) + : maxSendWr(maxSendWrNum), + maxReceiveWr(maxReceiveWrNum), + mrSegSize(segSize), + mrSegCount(segCount), + slave(slave), + ubcMode(mode) + { + } +} __attribute__((packed)); + +struct UBVaSge { + uint64_t va; + int fd; + urma_target_seg_t *targetSeg = nullptr; + urma_target_seg_t *dstSeg = nullptr; +}; +} +} +#endif +#endif // HCOM_UB_COMMON_H diff --git a/src/transport/ub/ub_device_helper.cpp b/src/transport/ub/ub_device_helper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dd35d63bda1de10ef417df600f524e45df9133ca --- /dev/null +++ b/src/transport/ub/ub_device_helper.cpp @@ -0,0 +1,414 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED + +#include "ub_device_helper.h" + +namespace ock { +namespace hcom { + + +uint32_t UBDeviceHelper::G_InitRef = 0; +std::unordered_map UBDeviceHelper::G_UBDevMap; +std::unordered_map> UBDeviceHelper::G_UBDevEidTable; +std::unordered_map UBDeviceHelper::G_UBDevBWTable; +std::mutex UBDeviceHelper::G_Mutex; +uint32_t UBDeviceHelper::PORT_NUMBER = 1; + +UResult UBDeviceHelper::Initialize() +{ + UResult ret = UB_OK; + std::lock_guard guard(G_Mutex); + if (G_InitRef != 0) { + // 第二次进来直接加引用计数,防止mUBContext析构的时候调用UnInitialize时把资源直接释放 + G_InitRef++; + return ret; + } + ret = DoInitialize(); + return ret; +} + +void UBDeviceHelper::UnInitialize() +{ + std::lock_guard guard(G_Mutex); + G_InitRef--; + if (G_InitRef != 0) { + return; + } + // HcomUrma::Uninit() 每个进程只能调用一次,防止一个进程多个service多次调用 + HcomUrma::Uninit(); + G_UBDevMap.clear(); + G_UBDevEidTable.clear(); + G_UBDevBWTable.clear(); +} + +UResult UBDeviceHelper::DoInitialize() +{ + // 后续HCOM重构时重新定义此处数值换算,目前为了不修改头文件中uint8_t bandWidth(范围0~2555)的定义,只做大致比例换算。 + G_UBDevBWTable = { { URMA_SP_10M, 1 }, { URMA_SP_100M, 1 }, { URMA_SP_1G, 1 }, { URMA_SP_2_5G, 3 }, + { URMA_SP_5G, 5 }, { URMA_SP_10G, 10 }, { URMA_SP_14G, 14 }, { URMA_SP_25G, 25 }, + { URMA_SP_40G, 40 }, { URMA_SP_50G, 50 }, { URMA_SP_100G, 100 }, { URMA_SP_200G, 200 }, + { URMA_SP_400G, 255 }, { URMA_SP_800G, 255 } }; + auto ret = DoUpdate(); + if (NN_UNLIKELY(ret != UB_OK)) { + G_UBDevBWTable.clear(); + return ret; + } + // 第一次成功DoInitialize增加引用计数 + G_InitRef++; + return UB_OK; +} + +UResult UBDeviceHelper::DoUpdate() +{ + UResult ret = UB_OK; + bool isFindDevice = false; + urma_init_attr_t initAttr{}; + ret = HcomUrma::Init(&initAttr); + if (ret != URMA_SUCCESS && ret != URMA_EEXIST) { + NN_LOG_ERROR("Failed to initialize urma environment"); + return ret; + } + G_UBDevMap.clear(); + G_UBDevEidTable.clear(); + + urma_device_t **devList = nullptr; + int devCount = 0; + devList = HcomUrma::GetDeviceList(&devCount); + NN_LOG_TRACE_INFO("UB Device count:" << devCount); + if (devList == nullptr) { + NN_LOG_ERROR("Failed to call get urma device list, errno " << errno); + return UB_DEVICE_FAILED_OPEN; + } + auto guard = MakeScopeExit([&devList]() { HcomUrma::FreeDeviceList(devList); }); + G_UBDevMap.reserve(devCount); + G_UBDevEidTable.reserve(devCount); + + urma_device_attr_t *devAttr = reinterpret_cast(malloc(sizeof(urma_device_attr_t))); + if (devAttr == nullptr) { + NN_LOG_ERROR("Failed to malloc get urma device attr."); + return UB_NEW_OBJECT_FAILED; + } + auto guard1 = MakeScopeExit([&devAttr]() { free(devAttr); }); + for (int i = 0; i < devCount; i++) { + if (devList[i] == nullptr) { // should not happen + NN_LOG_TRACE_INFO("UB Device " << i << " is null"); + continue; + } + + UBDeviceSimpleInfo info; + info.devIndex = i; + if (NN_UNLIKELY(strcpy_s(info.devName, URMA_MAX_NAME, reinterpret_cast(devList[i]->name)) != + UB_OK)) { + NN_LOG_ERROR("Failed to copy device name when initializing device"); + return UB_PARAM_INVALID; + } + NN_LOG_TRACE_INFO("UB Device " << i << " name " << devList[i]->name); + + urma_context_t *ctx = HcomUrma::CreateContext(devList[i], 0); + if (ctx == nullptr) { + NN_LOG_WARN("Unable to create urma context"); + continue; + } + uint32_t eidCnt = 0; + urma_eid_info_t *eidInfoList = HcomUrma::GetEidList(devList[i], &eidCnt); + if (eidInfoList == nullptr) { + NN_LOG_ERROR("Failed to get eid list"); + HcomUrma::DeleteContext(ctx); + return UB_PARAM_INVALID; + } + + // Query and process device info + if ((ret = HcomUrma::QueryDevice(devList[i], devAttr)) != 0) { + NN_LOG_ERROR("Failed to query urma device"); + HcomUrma::FreeEidList(eidInfoList); + HcomUrma::DeleteContext(ctx); + return ret; + } + + info.active = devAttr->port_attr[0].state == URMA_PORT_ACTIVE; + auto it = G_UBDevBWTable.find(devAttr->port_attr[0].active_speed); + if (it == G_UBDevBWTable.end()) { + NN_LOG_ERROR("UB failed to query urma device bandwidth."); + HcomUrma::FreeEidList(eidInfoList); + HcomUrma::DeleteContext(ctx); + return UB_PARAM_INVALID; + } + + uint32_t bandWidth = it->second; + std::vector eidVec; + eidVec.reserve(eidCnt); + GetEidVec(info.devName, i, eidCnt, eidInfoList, eidVec, bandWidth); + + info.deviceInfo.maxSge = std::min(devAttr->dev_cap.max_jfs_sge, devAttr->dev_cap.max_jfr_sge); + + G_UBDevMap.emplace(i, info); + G_UBDevEidTable.emplace(info.devName, eidVec); + isFindDevice = true; + HcomUrma::FreeEidList(eidInfoList); + HcomUrma::DeleteContext(ctx); + } + if (!isFindDevice) { + NN_LOG_ERROR("Failed to get urma device."); + return UB_PARAM_INVALID; + } + return UB_OK; +} + +UResult UBDeviceHelper::Update() +{ + std::lock_guard guard(G_Mutex); + return DoUpdate(); +} + +void UBDeviceHelper::GetEidVec(const std::string &devName, uint16_t devIndex, uint32_t eidCnt, + urma_eid_info_t *eidInfoList, std::vector &outGidVec, uint8_t bandWidth) +{ + UBEId eid{}; + for (uint32_t i = 0; i < eidCnt; i++) { + if (eidInfoList[i].eid.in6.interface_id == 0) { + continue; + } + eid.devIndex = devIndex; + eid.eidIndex = eidInfoList[i].eid_index; + eid.urmaEid = eidInfoList[i].eid; + eid.bandWidth = bandWidth; + outGidVec.push_back(eid); + } +} + +UResult UBDeviceHelper::GetDeviceCount(uint16_t &deviceCount, std::vector &enabledDevices) +{ + UResult ret = UB_OK; + if ((ret = Initialize()) != UB_OK) { + return ret; + } + + { + std::lock_guard guard(G_Mutex); + deviceCount = G_UBDevMap.size(); + for (auto &item : G_UBDevMap) { + if (item.second.active) { + enabledDevices.push_back(item.second); + } + } + } + + return UB_OK; +} + +UResult UBDeviceHelper::GetEnableDeviceCount(std::string ipMask, uint16_t &enableDevCount, + std::vector &enableIps, std::string ipGroup) +{ + UResult result = UB_OK; + std::vector matchIps; + // filter ip by mask + NetFunc::NN_SplitStr(ipGroup, ";", matchIps); + if (matchIps.empty()) { + std::vector filters; + NetFunc::NN_SplitStr(ipMask, ",", filters); + if (filters.empty()) { + NN_LOG_ERROR("[UB] Invalid ip mask '" << ipMask << "' by set, example '192.168.0.0/24'"); + return NN_INVALID_IP; + } + for (auto &mask : filters) { + result = FilterIp(mask, matchIps); + } + if (matchIps.empty()) { + NN_LOG_ERROR("[UB] No matched ip found with ipGroup or ipMask."); + return UB_DEVICE_NO_IP_MATCHED; + } + } + // init urma devices + if ((result = Initialize()) != 0) { + NN_LOG_ERROR("[UB] Failed to init devices"); + return result; + } + + NN_LOG_INFO(DeviceInfo()); + + uint16_t enableCount = 0; + std::vector findIps; + // choose the matched ip and port active + for (uint16_t i = 0; i < static_cast(matchIps.size()); ++i) { + UBEId tmpEid{}; + if ((GetDeviceByIp(matchIps[i], tmpEid)) != 0) { + NN_LOG_WARN("[UB] Failed to get device by ip " << matchIps[i]); + continue; + } + // active or not + if (G_UBDevMap[tmpEid.devIndex].active) { + enableCount++; + findIps.emplace_back(matchIps[i]); + } + NN_LOG_DEBUG("gid found devIndex " << tmpEid.devIndex << ", gidIndex " << tmpEid.eidIndex); + } + enableDevCount = enableCount; + enableIps = findIps; + return result; +} + +UResult UBDeviceHelper::GetDeviceByIp(const std::string &ip, UBEId &gid) +{ + UResult ret = UB_OK; + struct sockaddr_in address {}; + if ((ret = GetIfAddressByIp(ip, address)) != UB_OK) { + return ret; + } + + return GetDeviceByAddress(ip, address, gid); +} + +UResult UBDeviceHelper::GetDeviceByEid(const uint8_t eid[], UBEId &gid) +{ + std::lock_guard guard(G_Mutex); + for (auto &item : G_UBDevEidTable) { + for (auto &gItem : item.second) { + if (std::memcmp(eid, gItem.urmaEid.raw, URMA_EID_SIZE) == 0) { + gid = gItem; + return UB_OK; + } + } + } + + NN_LOG_ERROR("Failed to get proper gid by eid " << eid); + return UB_DEVICE_NO_IP_TO_GID_MATCHED; +} + +UResult UBDeviceHelper::GetDeviceByName(const char name[], uint8_t len, UBEId &gid) +{ + std::lock_guard guard(G_Mutex); + for (auto &item : G_UBDevEidTable) { + if (strncmp(name, item.first.c_str(), len) == 0) { + gid = item.second[0]; + return UB_OK; + } + } + + NN_LOG_ERROR("Failed to get proper gid by name " << name); + return UB_DEVICE_NO_IP_TO_GID_MATCHED; +} + +UResult UBDeviceHelper::GetIfAddressByIp(const std::string &ip, struct sockaddr_in &address) +{ + struct ifaddrs *addresses = nullptr; + if (getifaddrs(&addresses) != 0) { + NN_LOG_ERROR("Failed to get interface addresses"); + return UB_DEVICE_FAILED_GET_IP_ADDRESS; + } + + char ipStr[INET_ADDRSTRLEN] = {0}; + bool found = false; + + struct ifaddrs *iter = addresses; + while (iter != nullptr) { + if (iter->ifa_addr != nullptr && iter->ifa_addr->sa_family == AF_INET) { + inet_ntop(AF_INET, &((reinterpret_cast(iter->ifa_addr))->sin_addr), ipStr, + INET_ADDRSTRLEN); + if (ip == std::string(ipStr)) { + address = *(reinterpret_cast(iter->ifa_addr)); + found = true; + break; + } + } + iter = iter->ifa_next; + } + freeifaddrs(addresses); + + if (!found) { + NN_LOG_ERROR("Failed to get interface address for ip " << ip); + return UB_DEVICE_NO_IP_MATCHED; + } + + return UB_OK; +} + +UResult UBDeviceHelper::GetDeviceByAddress(const std::string &ip, struct sockaddr_in &address, UBEId &eid) +{ + UResult result = UB_OK; + if ((result = Initialize()) != UB_OK) { + return result; + } + + UBEId tmpEid{}; + bool found = false; + + std::lock_guard lock(G_Mutex); + for (auto &item : G_UBDevEidTable) { + for (auto &gItem : item.second) { + auto devI6Address = reinterpret_cast(gItem.urmaEid.raw); + auto targetAddress = address.sin_addr.s_addr; + + auto judge1 = ((devI6Address->s6_addr32[NN_NO0] | devI6Address->s6_addr32[NN_NO1]) | + (devI6Address->s6_addr32[NN_NO2] ^ htonl(0x0000ffff))) == 0UL; + /* IPv4 encoded multicast addresses */ + auto judge2 = devI6Address->s6_addr32[NN_NO0] == htonl(0xff0e0000) && + ((devI6Address->s6_addr32[NN_NO1] | (devI6Address->s6_addr32[NN_NO2] ^ htonl(0x0000ffff))) == 0UL); + if (!((judge1 || judge2) && devI6Address->s6_addr32[NN_NO3] == targetAddress)) { + // doesn't match + continue; + } + + // match + if (!found) { // first found + tmpEid = gItem; + found = true; + } else { + // found new one then compare the version, higher version is better + tmpEid = gItem; + } + } + } + + if (!found) { + NN_LOG_ERROR("Failed to get proper gid by address for ip " << ip); + return UB_DEVICE_NO_IP_TO_GID_MATCHED; + } + + eid = tmpEid; + return UB_OK; +} + +std::string UBDeviceHelper::DeviceInfo() +{ + std::ostringstream oss; + std::lock_guard guard(G_Mutex); + if (!G_InitRef) { + oss << "UBDeviceHelper has not been initialized"; + return oss.str(); + } + + // dump device info + oss << "UBDeviceHelper device info, devices: count " << G_UBDevMap.size() << ", "; + for (auto &item : G_UBDevMap) { + oss << "[" << item.second.devIndex << "," << item.second.devName << "," << item.second.active << "] "; + } + + oss << ", gidTable: count " << G_UBDevEidTable.size() << ", "; + for (auto &item : G_UBDevEidTable) { + oss << "[deviceName " << item.first << ", "; + for (auto &eid : item.second) { + oss << "[" << eid.devIndex << "," << eid.eidIndex << "] "; + } + oss << "] "; + } + + return oss.str(); +} + +uint32_t UBDeviceHelper::GetPortNumber() +{ + return PORT_NUMBER; +} +} +} +#endif \ No newline at end of file diff --git a/src/transport/ub/ub_device_helper.h b/src/transport/ub/ub_device_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..4804bf1f3495b5f1410bd7e6b3b49b4031898a43 --- /dev/null +++ b/src/transport/ub/ub_device_helper.h @@ -0,0 +1,76 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_UB_WRAPPER_DEVICE_HELPER_H +#define HCOM_UB_WRAPPER_DEVICE_HELPER_H +#ifdef UB_BUILD_ENABLED + +#include "ub_common.h" + +namespace ock { +namespace hcom { + +struct UBDeviceSimpleInfo { + uint16_t devIndex = 0; + char devName[URMA_MAX_NAME]{}; + bool active = false; + UBSHcomNetDriverDeviceInfo deviceInfo; +}; + +struct UBEId { + uint16_t devIndex = 0; + uint16_t eidIndex = 0; + urma_eid_t urmaEid; + uint8_t bandWidth = 0; +} __attribute__((packed)); + +class UBDeviceHelper { +public: + static UResult Initialize(); + static void UnInitialize(); + static UResult Update(); + + static UResult GetDeviceCount(uint16_t &deviceCount, std::vector &enabledDevices); + + static UResult GetDeviceByIp(const std::string &ip, UBEId &gid); + static UResult GetDeviceByEid(const uint8_t eid[], UBEId &gid); + static UResult GetDeviceByName(const char name[], uint8_t len, UBEId &gid); + + static uint32_t GetPortNumber(); + + static std::string DeviceInfo(); + + static UResult GetEnableDeviceCount(std::string ipMask, uint16_t &enableDevCount, + std::vector &enableIps, std::string ipGroup); + +private: + static UResult DoInitialize(); + static UResult DoUpdate(); + static void GetEidVec(const std::string &devName, uint16_t devIndex, uint32_t eidCnt, urma_eid_info_t *eidInfoList, + std::vector &outGidVec, uint8_t bandWidth); + + static UResult GetIfAddressByIp(const std::string &ip, struct sockaddr_in &address); + static UResult GetDeviceByAddress(const std::string &ip, struct sockaddr_in &address, UBEId &gid); + +private: + static std::unordered_map G_UBDevMap; + static std::unordered_map> G_UBDevEidTable; + static std::mutex G_Mutex; + static uint32_t G_InitRef; + + static uint32_t PORT_NUMBER; + static std::unordered_map G_UBDevBWTable; +}; +} +} +#endif +#endif // HCOM_UB_WRAPPER_DEVICE_HELPER_H \ No newline at end of file diff --git a/src/transport/ub/ub_fixed_mem_pool.cpp b/src/transport/ub/ub_fixed_mem_pool.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c24e6e905f356a85f22d226fd3d1c08da7eca51 --- /dev/null +++ b/src/transport/ub/ub_fixed_mem_pool.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED +#include "ub_fixed_mem_pool.h" +namespace ock { +namespace hcom { + +UResult UBFixedMemPool::Initialize() +{ + auto tmpBuf = memalign(NN_NO4096, mTotalSize); + if (tmpBuf == nullptr) { + NN_LOG_ERROR("Failed to allocate memory for UBFixedMemPool with size " << mTotalSize); + return UB_MEMORY_ALLOCATE_FAILED; + } + mBuf = reinterpret_cast(tmpBuf); + if (MakeFreeList() != UB_OK) { + NN_LOG_ERROR("Failed to make free list"); + UnInitialize(); + return UB_ERROR; + } + NN_LOG_INFO("UB mempool initialized total size = " << mTotalSize << " blk size = " << mBlkSize << " blk cnt = " + << mBlkCnt); + return UB_OK; +} + +UResult UBFixedMemPool::MakeFreeList() +{ + if (mBuf == 0 || mBlkSize * mBlkCnt != mTotalSize) { + NN_LOG_ERROR("Failed to make free list as invalid parameter"); + return UB_PARAM_INVALID; + } + auto address = mBuf; + auto iter = reinterpret_cast(address); + mHead.next = iter; + for (uint32_t i = 1; i < mBlkCnt; ++i) { + address += mBlkSize; + iter->next = reinterpret_cast(address); + iter = reinterpret_cast(address); + } + iter->next = nullptr; + return UB_OK; +} + +bool UBFixedMemPool::GetFreeBuffer(uintptr_t &buf) +{ + std::lock_guard guard(mMutex); + if (NN_UNLIKELY(mHead.next == nullptr)) { + NN_LOG_ERROR("Failed to get buffer as no free buffer"); + return false; + } + auto tmp = mHead.next; + mHead.next = tmp->next; + buf = reinterpret_cast(tmp); + return true; +} + +bool UBFixedMemPool::ReturnBuffer(uintptr_t buf) +{ + std::lock_guard guard(mMutex); + if (NN_UNLIKELY(!((buf >= mBuf) && (buf - mBuf < mTotalSize) && ((buf - mBuf) % mBlkSize == 0)))) { + NN_LOG_ERROR("Failed to return buffer as invalid address"); + return false; + } + auto tmp = reinterpret_cast(buf); + tmp->next = mHead.next; + mHead.next = tmp; + return true; +} +} +} +#endif \ No newline at end of file diff --git a/src/transport/ub/ub_fixed_mem_pool.h b/src/transport/ub/ub_fixed_mem_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..7a58fea4e383b3ad1ced427fcf0467cc5ca8f485 --- /dev/null +++ b/src/transport/ub/ub_fixed_mem_pool.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_UB_FIXED_MEM_POOL_H +#define HCOM_UB_FIXED_MEM_POOL_H +#ifdef UB_BUILD_ENABLED + +#include "ub_common.h" +#include "net_bucket_linked_list.h" +#include "net_util.h" + +namespace ock { +namespace hcom { +/* + * Mini block allocated to end user + */ +struct UBMemPoolMinBlock { + UBMemPoolMinBlock *next = nullptr; /* link to next min block */ +}; + +/* + * Mem pool for fixed block size + * UBFixedMemPool is now used only by public jetty + */ +class UBFixedMemPool { +public: + UBFixedMemPool(uint16_t blkSize, uint16_t blkCnt = NN_NO64): mBlkSize(blkSize), mBlkCnt(blkCnt) + { + mTotalSize = mBlkSize * mBlkCnt; + } + ~UBFixedMemPool() + { + UnInitialize(); + } + UResult Initialize(); + UResult MakeFreeList(); + bool GetFreeBuffer(uintptr_t &buf); + bool ReturnBuffer(uintptr_t buf); + void UnInitialize() + { + if (mBuf != 0) { + free(reinterpret_cast(mBuf)); + mBuf = 0; + } + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS +private: + uint16_t mBlkSize = NN_NO128; + uint16_t mBlkCnt = NN_NO64; + uint64_t mTotalSize = NN_NO8192; + uintptr_t mBuf = 0; + UBMemPoolMinBlock mHead {}; + std::mutex mMutex; + + DEFINE_RDMA_REF_COUNT_VARIABLE; +}; +} +} +#endif +#endif // HCOM_UB_FIXED_MEM_POOL_H \ No newline at end of file diff --git a/src/transport/ub/ub_jetty_ptr_map.h b/src/transport/ub/ub_jetty_ptr_map.h new file mode 100644 index 0000000000000000000000000000000000000000..75e35f435407904114e3740d68033570ae405360 --- /dev/null +++ b/src/transport/ub/ub_jetty_ptr_map.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_UB_JETTY_PTR_MAP_H +#define HCOM_UB_JETTY_PTR_MAP_H + +#include "ub_common.h" +#include "ub_urma_wrapper_jetty.h" + +#include + +namespace ock { +namespace hcom { +/// JettyPtrMap 支持通过 urma jetty id 查找 `UBJetty*` +class JettyPtrMap { +public: + JettyPtrMap() = default; + + JettyPtrMap(const JettyPtrMap &) = delete; + JettyPtrMap(JettyPtrMap &&rhs) noexcept = delete; + JettyPtrMap &operator=(const JettyPtrMap &) = delete; + JettyPtrMap &operator=(JettyPtrMap &&rhs) noexcept = delete; + + ~JettyPtrMap() + { + if (mId2Jetty) { + munmap(mId2Jetty, mId2JettySize * sizeof(UBJetty *)); + mId2Jetty = nullptr; + mId2JettySize = 0; + } + } + + UResult Initialize() + { + const size_t jettyIdMax = NN_NO65536; + void *id2Jetty = mmap(nullptr, jettyIdMax * sizeof(UBJetty *), PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, 0, 0); + if (id2Jetty == MAP_FAILED) { + NN_LOG_ERROR("Unable to mmap with size: " << (jettyIdMax * sizeof(UBJetty *))); + return UB_MEMORY_ALLOCATE_FAILED; + } + + mId2Jetty = reinterpret_cast(id2Jetty); + mId2JettySize = static_cast(jettyIdMax); + return UB_OK; + } + + UBJetty *Lookup(uint32_t jettyId) const + { + if (NN_UNLIKELY(jettyId >= mId2JettySize)) { + NN_LOG_ERROR("The given jetty id " << jettyId << " exceeds the size " << mId2JettySize + << " when looking-up"); + return nullptr; + } + + return mId2Jetty[jettyId]; + } + + UResult Emplace(uint32_t jettyId, UBJetty *jetty) + { + if (NN_UNLIKELY(jettyId >= mId2JettySize)) { + NN_LOG_ERROR("The given jetty id " << jettyId << " exceeds the size " << mId2JettySize + << " when inserting"); + return UB_ERROR; + } + + mId2Jetty[jettyId] = jetty; + return UB_OK; + } + + UResult Modify(uint32_t jettyId, UBJetty *jetty) + { + if (NN_UNLIKELY(jettyId >= mId2JettySize)) { + NN_LOG_ERROR("The given jetty id " << jettyId << " exceeds the size " << mId2JettySize + << " when modifying"); + return UB_ERROR; + } + + mId2Jetty[jettyId] = jetty; + return UB_OK; + } + + UResult Clear(uint32_t jettyId) + { + return Modify(jettyId, nullptr); + } + +private: + UBJetty **mId2Jetty = nullptr; ///< UBJetty::mUrmaJettyId -> UBJetty* 映射表 + uint32_t mId2JettySize = 0; ///< mId2Jetty 映射表大小 +}; + +} // namespace hcom +} // namespace ock + +#endif // HCOM_UB_JETTY_PTR_MAP_H \ No newline at end of file diff --git a/src/transport/ub/ub_mr_fixed_buf.cpp b/src/transport/ub/ub_mr_fixed_buf.cpp new file mode 100644 index 0000000000000000000000000000000000000000..84ceadd7c8baf85f1e3730241149c0ca96989618 --- /dev/null +++ b/src/transport/ub/ub_mr_fixed_buf.cpp @@ -0,0 +1,53 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED + +#include "ub_mr_fixed_buf.h" + +namespace ock { +namespace hcom { + +UResult UBMemoryRegionFixedBuffer::Create(const std::string &name, UBContext *ctx, uint32_t singleSegSize, + uint32_t segCount, unsigned long memid, UBMemoryRegionFixedBuffer *&buf) +{ + auto tmp = new (std::nothrow) UBMemoryRegionFixedBuffer(name, ctx, memid, singleSegSize, segCount); + if (tmp == nullptr) { + return UB_NEW_OBJECT_FAILED; + } + buf = tmp; + return UB_OK; +} + +UResult UBMemoryRegionFixedBuffer::Initialize() +{ + UResult result = UB_OK; + if ((result = UBMemoryRegion::Initialize()) != UB_OK) { + return result; + } + + // init un-allocated + uintptr_t address = mBuf; + for (uint32_t i = 0; i < mSegCount; i++) { + mLinkList.PushFront(address); + address += mSingleSegSize; + } + + return UB_OK; +} + +void UBMemoryRegionFixedBuffer::UnInitialize() +{ + UBMemoryRegion::UnInitialize(); +} +} +} +#endif \ No newline at end of file diff --git a/src/transport/ub/ub_mr_fixed_buf.h b/src/transport/ub/ub_mr_fixed_buf.h new file mode 100644 index 0000000000000000000000000000000000000000..97daed22675b47420ad456f172d3ef885a20a639 --- /dev/null +++ b/src/transport/ub/ub_mr_fixed_buf.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_UB_MR_FIXED_BUF_H +#define HCOM_UB_MR_FIXED_BUF_H +#ifdef UB_BUILD_ENABLED + +#include "ub_mr_pool.h" + +namespace ock { +namespace hcom { + +class UBMemoryRegionFixedBuffer : public UBMemoryRegion { +public: + static UResult Create(const std::string &name, UBContext *ctx, uint32_t singleSegSize, uint32_t segCount, + unsigned long memid, UBMemoryRegionFixedBuffer *&buf); + +public: + UBMemoryRegionFixedBuffer(const std::string &name, UBContext *ctx, unsigned long memid, uint32_t singleSegSize, + uint32_t segCount) + : UBMemoryRegion(name, ctx, static_cast(singleSegSize) * static_cast(segCount), memid, 0), + mSingleSegSize(singleSegSize), + mSegCount(segCount) + { + OBJ_GC_INCREASE(UBMemoryRegionFixedBuffer); + } + + ~UBMemoryRegionFixedBuffer() override + { + UnInitialize(); + OBJ_GC_DECREASE(UBMemoryRegionFixedBuffer); + } + + UResult Initialize() override; + + inline bool GetFreeBuffer(uintptr_t &item) + { + return mLinkList.Pop(item); + } + + inline bool ReturnBuffer(uintptr_t value) + { + mLinkList.PushFront(value); + return true; + } + + inline bool GetFreeBufferN(uintptr_t *&items, uint32_t n) + { + return mLinkList.PopN(items, n); + } + + inline uint32_t GetSingleSegSize() const + { + return mSingleSegSize; + } + + std::string ToString() + { + std::ostringstream oss; + oss << "buf-address " << mBuf << ", mSingleSegSize " << mSingleSegSize << ", mSegCount " << mSegCount << + ", total buf size " << mSize; + return oss.str(); + } + +protected: + void UnInitialize() override; + +private: + uint32_t mSingleSegSize = MR_FIXED_POOL_DEFAULT_SEG_SIZE; + uint32_t mSegCount = MR_FIXED_POOL_DEFAULT_SEG_COUNT; + + // uintptr_p store the start address of each mr segment + NetBucketLinkedList mLinkList; +}; +} +} +#endif +#endif // HCOM_UB_MR_FIXED_BUF_H \ No newline at end of file diff --git a/src/transport/ub/ub_mr_pool.cpp b/src/transport/ub/ub_mr_pool.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c221ace1c44048e586678ce59ffd833c7fe0bc92 --- /dev/null +++ b/src/transport/ub/ub_mr_pool.cpp @@ -0,0 +1,233 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED +#include + +#include "ub_common.h" +#include "ub_mr_pool.h" +#include "under_api/urma/urma_api_wrapper.h" +#include "under_api/obmm/obmm_api_wrapper.h" + +namespace ock { +namespace hcom { +uint64_t UBMemoryRegion::gPageSize = sysconf(_SC_PAGESIZE); + +UResult UBMemoryRegion::Create(const std::string &name, UBContext *ctx, uint64_t size, UBMemoryRegion *&buf) +{ + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to create ub memory region as ctx is nullptr"); + return UB_PARAM_INVALID; + } + + auto tmpBuf = new (std::nothrow) UBMemoryRegion(name, ctx, size, 0, 0); + if ((NN_UNLIKELY(tmpBuf == nullptr))) { + NN_LOG_ERROR("Failed to create ub memory region"); + return UB_NEW_OBJECT_FAILED; + } + + buf = tmpBuf; + + return UB_OK; +} + +UResult UBMemoryRegion::Create(const std::string &name, UBContext *ctx, uintptr_t address, uint64_t size, + UBMemoryRegion *&buf) +{ + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to create ub memory region as ctx is nullptr"); + return UB_PARAM_INVALID; + } + + auto tmpBuf = new (std::nothrow) UBMemoryRegion(name, ctx, address, size); + if ((NN_UNLIKELY(tmpBuf == nullptr))) { + NN_LOG_ERROR("Failed to create ub memory region"); + return UB_NEW_OBJECT_FAILED; + } + + buf = tmpBuf; + + return UB_OK; +} + +UResult UBMemoryRegion::Initialize() +{ + if (mMemSeg != nullptr) { + return UB_OK; + } + + if (mUBContext == nullptr) { + NN_LOG_ERROR("Failed to initialize UBMemoryRegion as ctx is nullptr"); + return UB_PARAM_INVALID; + } + + urma_target_seg_t *tmpMR = nullptr; + + urma_reg_seg_flag_t flag{}; + flag.bs.access = URMA_ACCESS_READ | URMA_ACCESS_WRITE; + + urma_seg_cfg_t seg_cfg{}; + seg_cfg.len = mSize; + seg_cfg.flag = flag; + + if (mExternalMemory) { + // the memory is allocated externally + // register mr directly + NN_LOG_WARN("externally allocated memory"); + auto tmpBuf = reinterpret_cast(mBuf); + seg_cfg.va = tmpBuf; + tmpMR = HcomUrma::RegisterSeg(mUBContext->mUrmaContext, &seg_cfg); + if (tmpMR == nullptr) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to register ex mem for UBMemoryRegion " << mName << " error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return UB_MR_REG_FAILED; + } + } else { + // allocate memory + if (gPageSize <= 0) { + NN_LOG_ERROR("Failed to get page size from system, page size: " << gPageSize); + return UB_PARAM_INVALID; + } + auto tmpBuf = memalign(gPageSize, mSize); + if (tmpBuf == nullptr) { + NN_LOG_ERROR("Failed to allocate memory for UBMemoryRegion " << mName << " with size " << mSize); + return UB_MEMORY_ALLOCATE_FAILED; + } + + seg_cfg.va = reinterpret_cast(tmpBuf); + // register memory region to card + tmpMR = HcomUrma::RegisterSeg(mUBContext->mUrmaContext, &seg_cfg); + if (tmpMR == nullptr) { + free(tmpBuf); + tmpBuf = nullptr; + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to register memory for UBMemoryRegion " << mName << " error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return UB_MR_REG_FAILED; + } + mBuf = reinterpret_cast(tmpBuf); + } + + mMemSeg = tmpMR; + + return UB_OK; +} + +UResult UBMemoryRegion::InitializeForOneSide() +{ + if (mMemSeg != nullptr) { + return UB_OK; + } + + if (mUBContext == nullptr) { + NN_LOG_ERROR("Failed to initialize UBMemoryRegion as UBContex is null"); + return UB_PARAM_INVALID; + } + + urma_target_seg_t *tmpMR = nullptr; + + urma_reg_seg_flag_t flag{}; + flag.bs.access = URMA_ACCESS_READ | URMA_ACCESS_WRITE; + flag.bs.token_policy = URMA_TOKEN_PLAIN_TEXT; + + uint32_t tokenValue = GenerateSecureRandomUint32(); + urma_seg_cfg_t seg_cfg{}; + seg_cfg.len = mSize; + seg_cfg.token_value = {tokenValue}; + seg_cfg.flag = flag; + + if (mExternalMemory) { + // the memory is allocated externally + // register mr directly + auto tmpBuf = reinterpret_cast(mBuf); + seg_cfg.va = tmpBuf; + tmpMR = HcomUrma::RegisterSeg(mUBContext->mUrmaContext, &seg_cfg); + if (tmpMR == nullptr) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to register external memory for UBMemoryRegion " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE) << + ", buffer " << tmpBuf << " with size " << mSize); + return UB_MR_REG_FAILED; + } + } else { + // allocate memory + if (gPageSize <= 0) { + NN_LOG_ERROR("Failed to get system page size, page size: " << gPageSize); + return UB_PARAM_INVALID; + } + auto tmpBuf = memalign(gPageSize, mSize); + if (tmpBuf == nullptr) { + NN_LOG_ERROR("Failed to allocate memory for UBMemoryRegion " << mName << " with size " << mSize); + return UB_MEMORY_ALLOCATE_FAILED; + } + seg_cfg.va = reinterpret_cast(tmpBuf); + // register memory region to card + tmpMR = HcomUrma::RegisterSeg(mUBContext->mUrmaContext, &seg_cfg); + if (tmpMR == nullptr) { + free(tmpBuf); + tmpBuf = nullptr; + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to register memory for UBMemoryRegion " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE) << " with size " << mSize); + return UB_MR_REG_FAILED; + } + + mBuf = reinterpret_cast(tmpBuf); + } + + mMemSeg = tmpMR; + mLKey = (static_cast(tokenValue) << NN_NO32) | (tmpMR->seg.token_id); + + return UB_OK; +} + +UResult UBMemoryRegion::InitializeWithPA(unsigned long memid) +{ + int fd_e = HcomObmm::ObmmOpen(memid); + if (fd_e < 0) { + NN_LOG_ERROR("Failed to get fd with memid " << memid << " errno " << errno); + return UB_MEMORY_ALLOCATE_FAILED; + } + mMemFd = fd_e; + auto tmpBuf = mmap(NULL, mSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd_e, 0); + if (tmpBuf == MAP_FAILED) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("mmap error: " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + close(fd_e); + return UB_MEMORY_ALLOCATE_FAILED; + } + + mBuf = reinterpret_cast(tmpBuf); + mGetBufWithMapping = true; + return UB_OK; +} + +void UBMemoryRegion::UnInitialize() +{ + if (mMemSeg != nullptr) { + HcomUrma::UnregisterSeg(mMemSeg); + } + + if (!mExternalMemory && mBuf != 0) { + free(reinterpret_cast(mBuf)); + } + mUBContext->DecreaseRef(); + + mMemSeg = nullptr; + mBuf = 0; + mUBContext = nullptr; +} +} +} + +#endif diff --git a/src/transport/ub/ub_mr_pool.h b/src/transport/ub/ub_mr_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..413f21a44e460376a7cdaf1b7e78c832ccf1ac5a --- /dev/null +++ b/src/transport/ub/ub_mr_pool.h @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_UB_MR_POOL_H +#define HCOM_UB_MR_POOL_H +#ifdef UB_BUILD_ENABLED +#include +#include "libobmm.h" + +#include "hcom.h" +#include "ub_common.h" +#include "ub_urma_wrapper_ctx.h" +#include "net_bucket_linked_list.h" +#include "under_api/urma/urma_api_wrapper.h" +#include "net_util.h" + +namespace ock { +namespace hcom { +class UBMemoryRegion : public UBSHcomNetMemoryRegion { +public: + static UResult Create(const std::string &name, UBContext *ctx, uint64_t size, UBMemoryRegion *&buf); + static UResult Create(const std::string &name, UBContext *ctx, uintptr_t address, uint64_t size, + UBMemoryRegion *&buf); + void *GetMemorySeg() override + { + return mMemSeg; + } + + void GetVa(uint64_t &va, uint64_t &va_len, uint32_t &token_id) override + { + va = mMemSeg->seg.ubva.va; + va_len = mMemSeg->seg.len; + token_id = mMemSeg->seg.token_id; + } + +public: + UBMemoryRegion() = delete; + UBMemoryRegion(const UBMemoryRegion &other) = delete; + UBMemoryRegion(UBMemoryRegion &&other) = delete; + UBMemoryRegion &operator = (const UBMemoryRegion &) = delete; + UBMemoryRegion &operator = (UBMemoryRegion &&) = delete; + + ~UBMemoryRegion() override + { + OBJ_GC_DECREASE(UBMemoryRegion); + } + + UResult Initialize() override; + UResult InitializeForOneSide(); + UResult InitializeWithPA(unsigned long memid); + void UnInitialize() override; + inline UBSHcomNetDriverProtocol GetProtocol() + { + return mUBContext->protocol; + } + +public: + UBContext *mUBContext = nullptr; + +protected: + UBMemoryRegion(const std::string &name, UBContext *ctx, uint64_t size, unsigned long memid, int flag) + : UBSHcomNetMemoryRegion(name, false, 0, size), mUBContext(ctx), mMemid(memid) + { + // increase the reference count of context + if (ctx != nullptr) { + ctx->IncreaseRef(); + } + + OBJ_GC_INCREASE(UBMemoryRegion); + } + + UBMemoryRegion(const std::string &name, UBContext *ctx, uintptr_t address, uint64_t size) + : UBSHcomNetMemoryRegion(name, true, address, size), mUBContext(ctx) + { + // increase the reference count of context + if (ctx != nullptr) { + ctx->IncreaseRef(); + } + + OBJ_GC_INCREASE(UBMemoryRegion); + } + +protected: + urma_target_seg_t *mMemSeg = nullptr; + unsigned long mMemid = 0; + int mMemFd = 0; // InitializeWithPA obmm open fd + static uint64_t gPageSize; +}; +} +} +#endif +#endif // HCOM_UB_MR_POOL_H diff --git a/src/transport/ub/ub_thread_pool.cpp b/src/transport/ub/ub_thread_pool.cpp new file mode 100644 index 0000000000000000000000000000000000000000..403182d10b2d5521287971716c87a55bb99f72fe --- /dev/null +++ b/src/transport/ub/ub_thread_pool.cpp @@ -0,0 +1,97 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifdef UB_BUILD_ENABLED + +#include "ub_thread_pool.h" + +namespace ock { +namespace hcom { + +void UBThreadPool::Initialize() +{ + if (mIsRunning) { + return; + } + + mIsRunning = true; + mThreads.reserve(mThreadCount); + + for (int i = 0; i < mThreadCount; ++i) { + mThreads.emplace_back(&UBThreadPool::RunInThread, this); + } + + NN_LOG_INFO("UB threadpool initialized with " << mThreadCount << " threads"); +} + +void UBThreadPool::Stop() +{ + NN_LOG_INFO("UB threadpool begin to stop"); + if (!mIsRunning) { + NN_LOG_INFO("UB threadpool is not running"); + return; + } + + mIsRunning = false; + mCondition.notify_all(); + for (auto& thread : mThreads) { + if (thread.joinable()) { + thread.join(); + } + } + mThreads.clear(); + std::lock_guard lock(mMutex); + while (!mTasks.empty()) { + mTasks.pop(); + } + NN_LOG_INFO("UB threadpool has been stopped"); +} + +void UBThreadPool::Submit(std::function task) +{ + if (!mIsRunning) { + NN_LOG_ERROR("UB threadpool is not running"); + return; + } + + std::lock_guard lock(mMutex); + mTasks.emplace(std::move(task)); + mCondition.notify_one(); +} + +void UBThreadPool::RunInThread() +{ + while (mIsRunning) { + std::function task; + { + std::unique_lock lock(mMutex); + mCondition.wait(lock, [this]() { + return !mIsRunning || !mTasks.empty(); + }); + if (!mIsRunning) { + return; + } + task = std::move(mTasks.front()); + mTasks.pop(); + } + try { + task(); + } catch (const std::exception& e) { + NN_LOG_ERROR("Caught error " << e.what() << " when execute a task, continue"); + } catch (...) { + NN_LOG_ERROR("Caught unknown error when execute a task, continue"); + } + } +} +} +} +#endif \ No newline at end of file diff --git a/src/transport/ub/ub_thread_pool.h b/src/transport/ub/ub_thread_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..8335b4f2296390c40f7ca1035b84444aad31d717 --- /dev/null +++ b/src/transport/ub/ub_thread_pool.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_UB_THREAD_POOL_H +#define HCOM_UB_THREAD_POOL_H +#ifdef UB_BUILD_ENABLED +#include +#include +#include +#include +#include +#include +#include + +#include "hcom.h" +#include "ub_common.h" + +namespace ock { +namespace hcom { + +class UBThreadPool { +public: + explicit UBThreadPool(uint16_t threadCount = NN_NO16) : mThreadCount(threadCount), mIsRunning(false) {} + ~UBThreadPool() + { + Stop(); + } + UBThreadPool(const UBThreadPool&) = delete; + UBThreadPool& operator=(const UBThreadPool&) = delete; + void Initialize(); + void Stop(); + void Submit(std::function task); + +private: + void RunInThread(); + + uint16_t mThreadCount; + std::vector mThreads; + std::queue> mTasks; + std::mutex mMutex; + std::condition_variable mCondition; + std::atomic mIsRunning; +}; + +} +} +#endif +#endif // HCOM_UB_THREAD_POOL_H \ No newline at end of file diff --git a/src/transport/ub/ub_urma_wrapper_ctx.cpp b/src/transport/ub/ub_urma_wrapper_ctx.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ddce481d13dc04504eb0112a0840413fac7e398f --- /dev/null +++ b/src/transport/ub/ub_urma_wrapper_ctx.cpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED + +#include "ub_urma_wrapper_ctx.h" + +namespace ock { +namespace hcom { + +UResult UBContext::Create(const std::string &name, const UBEId &eid, UBContext *&ctx) +{ + auto tmpCtx = new (std::nothrow) UBContext(name, eid); + if (tmpCtx == nullptr) { + return UB_NEW_OBJECT_FAILED; + } + + ctx = tmpCtx; + return UB_OK; +} + +UResult UBContext::Initialize() +{ + if (mUrmaContext != nullptr) { + NN_LOG_INFO("UBContext " << mName << " already initialized"); + return UB_OK; + } + + UResult ret = UB_OK; + + urma_device_t **devList = nullptr; + int devCount = 0; + devList = HcomUrma::GetDeviceList(&devCount); + if (devList == nullptr) { + NN_LOG_ERROR("Failed to call get urma device list for UBContext " << mName << ", errno " << errno); + return UB_DEVICE_FAILED_OPEN; + } + auto guard = MakeScopeExit([&devList]() { HcomUrma::FreeDeviceList(devList); }); + if (mDevIndex >= devCount) { + NN_LOG_ERROR("Invalid device index is set for UBContext " << mName); + return UB_DEVICE_INDEX_OVERFLOW; + } + + urma_context_t *tmpCtx = nullptr; + if ((tmpCtx = HcomUrma::CreateContext(devList[mDevIndex], mEidIndex)) == nullptr) { + NN_LOG_ERROR("Invalid device index is set for UBContext " << mName << ", errno " << errno); + return UB_DEVICE_OPEN_FAILED; + } + + mDevAttr = reinterpret_cast(malloc(sizeof(urma_device_attr_t))); + if (mDevAttr == nullptr) { + HcomUrma::DeleteContext(tmpCtx); + NN_LOG_ERROR("Failed to malloc for urma device attr"); + return UB_MEMORY_ALLOCATE_FAILED; + } + if ((ret = HcomUrma::QueryDevice(devList[mDevIndex], mDevAttr)) != 0) { + NN_LOG_ERROR("Failed to query urma device"); + free(mDevAttr); + mDevAttr = nullptr; + HcomUrma::DeleteContext(tmpCtx); + return ret; + } + int tmpMaxSge = std::min(mDevAttr->dev_cap.max_jfs_sge, mDevAttr->dev_cap.max_jfr_sge); + mMaxSge = tmpMaxSge < mMaxSge ? tmpMaxSge : mMaxSge; + + NN_LOG_INFO("Device info: max_qp " << mDevAttr->dev_cap.max_jetty << " ,max_qp_wr " << + mDevAttr->dev_cap.max_jfs_depth << " ,max_sge " << tmpMaxSge << " ,adapter max_cqe " << mMaxSge << + " ,max_cq " << mDevAttr->dev_cap.max_jfc << " ,max_cqe " << mDevAttr->dev_cap.max_jfc_depth); + + mMaxJfr = mDevAttr->dev_cap.max_jfr_depth; + mMaxJfs = mDevAttr->dev_cap.max_jfs_depth; + + mUrmaContext = tmpCtx; + return UB_OK; +} + +UResult UBContext::UnInitialize() +{ + if (mUrmaContext != nullptr) { + int res = 0; + if ((res = HcomUrma::DeleteContext(mUrmaContext)) != 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_WARN("Unable to delete UB Context " << res << ", as errno " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + } + mUrmaContext = nullptr; + } + + if (mDevAttr != nullptr) { + free(mDevAttr); + mDevAttr = nullptr; + } + UBDeviceHelper::UnInitialize(); + return UB_OK; +} + +void UBContext::UpdateGid(const std::string &matchIp) +{ + auto ret = UBDeviceHelper::Update(); + if (NN_UNLIKELY(ret != UB_OK)) { + NN_LOG_ERROR("Failed to do update"); + return; + } + + UBEId tmpEid{}; + if ((UBDeviceHelper::GetDeviceByIp(matchIp, tmpEid)) != 0) { + NN_LOG_ERROR("Failed to get device by ip " << matchIp); + return; + } + + NN_LOG_INFO("gid found devIndex " << tmpEid.devIndex << ", gidIndex " << tmpEid.eidIndex); + mBestEid = tmpEid; +} +} // namespace hcom +} +#endif \ No newline at end of file diff --git a/src/transport/ub/ub_urma_wrapper_ctx.h b/src/transport/ub/ub_urma_wrapper_ctx.h new file mode 100644 index 0000000000000000000000000000000000000000..b3c4025b8604cdedb8aae336634cbca653543292 --- /dev/null +++ b/src/transport/ub/ub_urma_wrapper_ctx.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_UB_URMA_WRAPPER_CTX_H +#define HCOM_UB_URMA_WRAPPER_CTX_H +#ifdef UB_BUILD_ENABLED + +#include "ub_common.h" +#include "ub_device_helper.h" + +namespace ock { +namespace hcom { + +extern std::atomic g_jetty_id; +extern uint64_t g_connection_count; + +class UBContext { +public: + static UResult Create(const std::string &name, const UBEId &eid, UBContext *&ctx); + +public: + UBContext(const std::string &name, const UBEId &eid) : mName(name), mDevIndex(eid.devIndex), + mEidIndex(eid.eidIndex), mBestEid(eid) + { + OBJ_GC_INCREASE(UBContext); + } + + ~UBContext() + { + UnInitialize(); + OBJ_GC_DECREASE(UBContext); + } + + UResult Initialize(); + UResult UnInitialize(); + + void UpdateGid(const std::string &matchIp); + + UBContext() = delete; + UBContext(const UBContext &) = delete; + UBContext &operator = (const UBContext &) = delete; + UBContext(UBContext &&) = delete; + UBContext &operator = (UBContext &&) = delete; + + std::string ToString(); + + inline urma_context_t *GetContext() + { + return mUrmaContext; + } + + inline uint32_t GetMaxJfs() + { + return mMaxJfs; + } + + inline uint32_t GetMaxJfr() + { + return mMaxJfr; + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + + UBSHcomNetDriverProtocol protocol = UBSHcomNetDriverProtocol::UBC; + +private: + std::string mName; + urma_context_t *mUrmaContext = nullptr; + urma_device_attr_t *mDevAttr = nullptr; + uint8_t mPortNumber = 1; + uint16_t mDevIndex = 0; + uint16_t mEidIndex = 0; + uint32_t mMaxJfs = 0; + uint32_t mMaxJfr = 0; + int mMaxSge = NN_NO16; + UBEId mBestEid{}; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend UBJetty; + friend UBJfc; + friend UBMemoryRegion; + friend UBWorker; + friend NetDriverUB; + friend class UBPublicJetty; +}; +} +} +#endif +#endif // HCOM_UB_URMA_WRAPPER_CTX_H \ No newline at end of file diff --git a/src/transport/ub/ub_urma_wrapper_jetty.cpp b/src/transport/ub/ub_urma_wrapper_jetty.cpp new file mode 100644 index 0000000000000000000000000000000000000000..88c934430391e3e351caf99524413f36fd702ff9 --- /dev/null +++ b/src/transport/ub/ub_urma_wrapper_jetty.cpp @@ -0,0 +1,516 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED + +#include "hcom_env.h" +#include "ub_urma_wrapper_jetty.h" +#include "ub_worker.h" +#include "under_api/obmm/obmm_api_wrapper.h" + +namespace ock { +namespace hcom { + +std::atomic g_jetty_id(1); + +/* ******************************************************************************************** */ +/* ******************************************************************************************** */ +uint32_t UBJetty::G_INDEX = 1; + +UResult UBJetty::CreateUrmaJetty(uintptr_t seg_pa, uint32_t seg_len, uint32_t seg_count, uint32_t token) +{ + if (mUBContext == nullptr || mUBContext->mUrmaContext == nullptr || mSendJfc == nullptr || + mSendJfc->mUrmaJfc == nullptr) { + NN_LOG_ERROR("Invalid parameter for jetty creating"); + return UB_PARAM_INVALID; + } + + mCtxPosted.next = nullptr; + mCtxPosted.prev = nullptr; + + mJettyOptions.maxSendWr = + (mJettyOptions.maxSendWr < JETTY_MAX_SEND_WR) ? JETTY_MAX_SEND_WR : mJettyOptions.maxSendWr; + mJettyOptions.maxReceiveWr = + (mJettyOptions.maxReceiveWr < JETTY_MAX_RECV_WR) ? JETTY_MAX_RECV_WR : mJettyOptions.maxReceiveWr; + + urma_jfs_cfg_t jfs_cfg{}; + FillJfsCfg(&jfs_cfg); + urma_jfr_cfg_t jfr_cfg{}; + FillJfrCfg(&jfr_cfg, token); + + urma_jetty_flag_t jetty_flag{}; + jetty_flag.bs.share_jfr = 1; + + urma_jetty_cfg_t jetty_cfg{}; + jetty_cfg.id = 0; + jetty_cfg.flag = jetty_flag; + jetty_cfg.jfs_cfg = jfs_cfg; + jetty_cfg.jfr_cfg = &jfr_cfg; + + urma_jetty_t *tmpJetty = nullptr; + + mJfr = HcomUrma::CreateJfr(mUBContext->mUrmaContext, &jfr_cfg); + if (mJfr == nullptr) { + NN_LOG_ERROR("urma create jfr failed"); + return UB_PARAM_INVALID; + } + jetty_cfg.shared.jfc = mRecvJfc->mUrmaJfc; + jetty_cfg.shared.jfr = mJfr; + tmpJetty = HcomUrma::CreateJetty(mUBContext->mUrmaContext, &jetty_cfg); + if (tmpJetty == nullptr) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create urma jetty for UBJetty " << mName << ", errno " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + if (mJfr != nullptr) { + HcomUrma::DeleteJfr(mJfr); + mJfr = nullptr; + } + return UB_QP_CREATE_FAILED; + } + mUrmaJetty = tmpJetty; + mUrmaJettyId = mUrmaJetty->jetty_id.id; + + NN_LOG_INFO("Create jetty success, jetty id: " << mUrmaJettyId << ", jfr id: " << (mJfr ? mJfr->jfr_id.id : -1) << + ", jfc id: " << mRecvJfc->mUrmaJfc->jfc_id.id); + return UB_OK; +} + +UResult UBJetty::Stop() +{ + std::lock_guard lock(mStopMutex); + if (mUrmaJetty == nullptr) { + return UB_OK; + } + + // 标记为 ERROR,准备进入待清理流程 + // + // 用户可能人工调用 EP->Close() 进入,与ProcessEpError (心跳线程或者是 UBWorker)产生竞争,避免因多次调用 modify + // jetty error 产生多次 FLUSH_ERROR_DONE 的情况。 + UBJettyState expState = UBJettyState::READY; + if (!mState.compare_exchange_strong(expState, UBJettyState::ERROR)) { + return UB_OK; + } + + // 仅 AsyncEp. 在 modify jetty error 后可能会立即触发 FLUSH_ERROR_DONE,需要提前准备好。 + auto *worker = reinterpret_cast(GetUpContext1()); + if (worker != nullptr) { + worker->mJettyPtrMap.Emplace(mUrmaJettyId, this); + } + + //jfr 为 error 后不会再收到对端发来的数据。 + if (mJfr != nullptr) { + urma_jfr_attr_t jfr_attr = {}; + jfr_attr.mask = JFR_STATE; + jfr_attr.state = URMA_JFR_STATE_ERROR; + result = HcomUrma::ModifyJfr(mJfr, &jfr_attr); + if (result != UB_OK) { + NN_LOG_ERROR("Fail to modify jfr to URMA_JFR_STATE_ERROR, urma result is " << result); + return result; + } + } + + // 大段注释,主要说明了需要先将jfr置为errror再ModifyJetty,否则会导致问题。 + struct urma_jetty_attr attr = {}; + attr.mask = JETTY_STATE; + attr.state = URMA_JETTY_STATE_ERROR; + int result = HcomUrma::ModifyJetty(mUrmaJetty, &attr); + if (result != UB_OK) { + NN_LOG_ERROR("Failed to modify jetty to URMA_JETTY_STATE_ERROR, urma result is " << result); + return result; + } + + if (mHBLocalMr != nullptr) { + DestroyHBMemoryRegion(mHBLocalMr); + mHBLocalMr.Set(nullptr); + } + + if (mHBRemoteMr != nullptr) { + DestroyHBMemoryRegion(mHBRemoteMr); + mHBRemoteMr.Set(nullptr); + } + + NN_LOG_INFO("Stop Jetty " << mName << ", jetty id " << mUrmaJetty->jetty_id.id << ", Ep Id " << mUpId); + return result; +} + +void UBJetty::FillJfsCfg(urma_jfs_cfg_t *jfs_cfg) +{ + jfs_cfg->user_ctx = reinterpret_cast(this); + jfs_cfg->jfc = mSendJfc->mUrmaJfc; + jfs_cfg->trans_mode = URMA_TM_RC; + jfs_cfg->depth = mJettyOptions.maxSendWr + NN_NO8; + jfs_cfg->max_sge = static_cast(mUBContext->mMaxSge); + jfs_cfg->flag.value = 0; + jfs_cfg->max_inline_data = HcomEnv::InlineThreshold(); + jfs_cfg->err_timeout = NN_NO8; + jfs_cfg->rnr_retry = NN_NO7; + // HighBandwidth RM mode, LowLatency RC mode + jfs_cfg->trans_mode = (mJettyOptions.ubcMode == UBSHcomUbcMode::HighBandwidth) ? URMA_TM_RM : URMA_TM_RC; + jfs_cfg->flag.bs.multi_path = (mJettyOptions.ubcMode == UBSHcomUbcMode::HighBandwidth) ? 1 : 0; +} + +void UBJetty::FillJfrCfg(urma_jfr_cfg_t *jfr_cfg, uint32_t token) +{ + jfr_cfg->user_ctx = reinterpret_cast(this); + jfr_cfg->jfc = mRecvJfc->mUrmaJfc; + jfr_cfg->trans_mode = URMA_TM_RC; + jfr_cfg->depth = mJettyOptions.maxReceiveWr + NN_NO8; + jfr_cfg->max_sge = static_cast(mUBContext->mMaxSge); + jfr_cfg->token_value = {token}; + jfr_cfg->flag.bs.token_policy = URMA_TOKEN_PLAIN_TEXT; + jfr_cfg->id = 0; + jfr_cfg->flag.bs.tag_matching = URMA_NO_TAG_MATCHING; + // HighBandwidth RM mode, LowLatency RC mode + jfr_cfg->trans_mode = (mJettyOptions.ubcMode == UBSHcomUbcMode::HighBandwidth) ? URMA_TM_RM : URMA_TM_RC; +} + +UResult UBJetty::CreateJettyMr() +{ + NResult result = NN_OK; + // create mr pool for send/receive and initialize + if ((result = UBMemoryRegionFixedBuffer::Create(mName, mUBContext, mJettyOptions.mrSegSize, + mJettyOptions.mrSegCount, mJettyOptions.slave, mJettyMr)) != 0) { + NN_LOG_ERROR("Failed to create mr for send/receive in jetty " << mName << ", result " << result); + return result; + } + mJettyMr->IncreaseRef(); + if ((result = mJettyMr->Initialize()) != 0) { + NN_LOG_ERROR("Failed to initialize mr for send/receive in jetty " << mName << ", result " << result); + return result; + } + + return UB_OK; +} + +UBMemoryRegionFixedBuffer *UBJetty::GetJettyMr() +{ + return mJettyMr; +} + +bool UBJetty::GetFreeBuff(uintptr_t &item) +{ + return mJettyMr->GetFreeBuffer(item); +} + +bool UBJetty::GetFreeBufferN(uintptr_t *&items, uint32_t n) +{ + return mJettyMr->GetFreeBufferN(items, n); +} + +bool UBJetty::ReturnBuffer(uintptr_t value) +{ + return mJettyMr->ReturnBuffer(value); +} + +uint64_t UBJetty::GetLKey() +{ + return mJettyMr->GetLKey(); +} + +urma_target_seg_t *UBJetty::GetMemorySeg() +{ + return reinterpret_cast(mJettyMr->GetMemorySeg()); +} + +UResult UBJetty::Initialize(uint32_t seg_count, unsigned long memid, uint32_t token) +{ + auto result = CreateJettyMr(); + if (result != UB_OK) { + return result; + } + result = CreateUrmaJetty(0, 0, 0, token); + if (result != UB_OK) { + NN_LOG_ERROR("Failed to create urma jetty"); + return result; + } + return UB_OK; +} + +UResult UBJetty::UnInitialize() +{ + int result = 0; + if (mUrmaJetty != nullptr) { + if (mJettyOptions.ubcMode == UBSHcomUbcMode::LowLatency) { + result = HcomUrma::UnbindJetty(mUrmaJetty); + } + if (mTargetJetty != nullptr) { + HcomUrma::UnimportJetty(mTargetJetty); + mTargetJetty = nullptr; + } + if ((result = HcomUrma::DeleteJetty(mUrmaJetty)) != 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_WARN("Unable to delete jetty id " << mUrmaJettyId << ", result " << result << ", errno " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + } else { + NN_LOG_INFO("Delete jetty success, jetty id: " << mUrmaJettyId); + } + mUrmaJetty = nullptr; + } + + if (mJfr != nullptr) { + if ((result = HcomUrma::DeleteJfr(mJfr)) != 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_WARN("Unable to delete jfr " << result << ", as errno " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + } + mJfr = nullptr; + } + + if (mSendJfc != nullptr) { + mSendJfc->DecreaseRef(); + } + + if (mRecvJfc != nullptr && mRecvJfc != mSendJfc) { + mRecvJfc->DecreaseRef(); + } + mSendJfc = nullptr; + mRecvJfc = nullptr; + + if (mUBContext != nullptr) { + mUBContext->DecreaseRef(); + mUBContext = nullptr; + } + + // 销毁jfr才能停止接收;先销毁jfr,再销毁接收缓冲区 + if (mJettyMr != nullptr) { + mJettyMr->DecreaseRef(); + mJettyMr = nullptr; + } + + if (mHBLocalMr != nullptr) { + DestroyHBMemoryRegion(mHBLocalMr); + mHBLocalMr.Set(nullptr); + } + + if (mHBRemoteMr != nullptr) { + DestroyHBMemoryRegion(mHBRemoteMr); + mHBRemoteMr.Set(nullptr); + } + NN_LOG_INFO("Uninitialize jetty success, jetty id: " << mUrmaJettyId); + return UB_OK; +} + +void UBJetty::Cleanup() +{ + // 仅 AsyncEp 需要清理 op ctx list. SyncEp 使用的是 thread_local 级别的 op ctx, 无需清理 + auto *worker = reinterpret_cast(GetUpContext1()); + if (worker == nullptr) { + return; + } + + // 如果在创建 EP 过程中失败,则 UBJetty 无对应EP,依赖 ClearJettyResource 清理 PostReceive 的资源。 + // \see ClearJettyResource + auto *ep = reinterpret_cast(GetUpContext()); + if (ep == nullptr) { + return; + } + + // EP 析构时先析构 jetty,再析构 worker、driver. + auto *driver = ep->GetDriver(); + + UBOpContextInfo *it = nullptr; + GetCtxPosted(it); + while (it != nullptr) { + UBOpContextInfo *next = it->next; + + // 剩余的 op ctx 都未被硬件处理,无法获得CQE,需要 hcom 人工清理 + it->opResultType = UBOpContextInfo::ERR_EP_BROKEN; + switch (it->opType) { + case UBOpContextInfo::SEND: + case UBOpContextInfo::SEND_RAW: + case UBOpContextInfo::SEND_RAW_SGL: + driver->ProcessErrorSendFinished(it); + break; + case UBOpContextInfo::RECEIVE: + case UBOpContextInfo::RECEIVE_RAW: + driver->ProcessErrorNewRequest(it); + break; + case UBOpContextInfo::WRITE: + case UBOpContextInfo::READ: + case UBOpContextInfo::SGL_WRITE: + case UBOpContextInfo::SGL_READ: + case UBOpContextInfo::HB_WRITE: + driver->ProcessErrorOneSideDone(it); + break; + } + + // 至此,it 指向的内存可能会归还给 mempool,再修改 it 指向的内存可能会引起并发冲突 + it = next; + } +} + +UResult UBJetty::ChangeToInit(urma_jetty_attr_t &attr) +{ + return UB_OK; +} + +UResult UBJetty::ChangeToReceive(ock::hcom::UBJettyExchangeInfo &exInfo, urma_jetty_attr_t &attr) +{ + return UB_OK; +} + +UResult UBJetty::ChangeToSend(urma_jetty_attr_t &attr) +{ + return UB_OK; +} + +UResult UBJetty::ChangeToReady(ock::hcom::UBJettyExchangeInfo &exInfo) +{ + if (NN_UNLIKELY(mUrmaJetty == nullptr)) { + NN_LOG_ERROR("Failed to change jetty " << mName << " state to READY as urma jetty is not created."); + return UB_QP_CHANGE_STATE_FAILED; + } + + UResult ret = 0; + ret = SetMaxSendWrConfig(exInfo); + if (ret != UB_OK) { + return ret; + } + + ret = ImportAndBindJetty(exInfo.token); + if (ret != UB_OK) { + return ret; + } + + NN_LOG_INFO("UB jetty " << mId << " attr send queue size " << mJettyOptions.maxSendWr << ", receive queue size " << + mJettyOptions.maxReceiveWr << ", eid-n-n " << (exInfo.eid.in6.interface_id != 0)); + + mState = UBJettyState::READY; + return UB_OK; +} + +UResult UBJetty::SetMaxSendWrConfig(UBJettyExchangeInfo &exInfo) +{ + NN_LOG_TRACE_INFO("Remote qpId " << mId << " info: send wr " << exInfo.maxSendWr << ", receive wr " << + exInfo.maxReceiveWr << ", receive seg size " << exInfo.receiveSegSize << ", receive seg count " << + exInfo.receiveSegCount); + NN_LOG_TRACE_INFO("Local qpId " << mId << " info: send wr " << mJettyOptions.maxSendWr << ", receive wr " << + mJettyOptions.maxReceiveWr << ", receive seg size " << mJettyOptions.mrSegSize << ", receive seg count " << + mJettyOptions.mrSegCount); + + int32_t maxWr = std::min(mJettyOptions.maxSendWr, exInfo.maxReceiveWr); + int32_t maxPostSendWr = std::min(mJettyOptions.maxSendWr, exInfo.receiveSegCount); + if (maxWr < maxPostSendWr) { + NN_LOG_ERROR("Qp " << mId << " max wr " << maxWr << " is less than max post send wr" << maxPostSendWr); + return UB_QP_RECEIVE_CONFIG_ERR; + } + // one side operation do not consume remote receive queue element + mOneSideMaxWr = maxWr - maxPostSendWr; + mOneSideRef = mOneSideMaxWr; + mPostSendMaxWr = maxPostSendWr; + mPostSendRef = mPostSendMaxWr; + mPostSendMaxSize = exInfo.receiveSegSize; + NN_LOG_TRACE_INFO("Qp id " << mId << " one side max wr " << mOneSideMaxWr << ", post send max wr " << + mPostSendMaxWr << ", post send max size " << mPostSendMaxSize); + return UB_OK; +} + +UResult UBJetty::FillExchangeInfo(UBJettyExchangeInfo &exInfo) +{ + if (mUrmaJetty == nullptr || mUBContext == nullptr) { + return UB_QP_NOT_INITIALIZED; + } + + exInfo.jettyId = mUrmaJetty->jetty_id; + exInfo.eid = mUBContext->mBestEid.urmaEid; + + return UB_OK; +} + +void UBJetty::StoreExchangeInfo(UBJettyExchangeInfo *exInfo) +{ + mRemoteJettyInfo.reset(exInfo); +} + +UBJettyExchangeInfo &UBJetty::GetExchangeInfo() +{ + return *mRemoteJettyInfo; +} + +void UBJetty::SetPeerIpAndPort(const std::string &value) +{ + mPeerIpPort = value; +} + +uint32_t UBJetty::GetPostSendMaxSize() const +{ + return mPostSendMaxSize; +} + +UResult UBJetty::ImportAndBindJetty(uint32_t token) +{ + // import/bind remote jetty + urma_rjetty_t remoteJetty{}; // remote jetty on the other side + remoteJetty.jetty_id = mRemoteJettyInfo->jettyId; + remoteJetty.trans_mode = URMA_TM_RC; + remoteJetty.type = URMA_JETTY; + remoteJetty.trans_mode = (mJettyOptions.ubcMode == UBSHcomUbcMode::HighBandwidth) ? URMA_TM_RM : URMA_TM_RC; + remoteJetty.flag.bs.token_policy = URMA_TOKEN_PLAIN_TEXT; + urma_token_t tokenValue{token}; + mTargetJetty = HcomUrma::ImportJetty(mUBContext->mUrmaContext, &remoteJetty, &tokenValue); + if (mTargetJetty == nullptr) { + NN_LOG_ERROR("Failed to import jetty"); + return UB_QP_IMPORT_FAILED; + } + + NN_LOG_INFO("Local jetty id: " << mUrmaJetty->jetty_id.id); + NN_LOG_INFO("Remote jetty id: " << mRemoteJettyInfo->jettyId.id); + + if (mJettyOptions.ubcMode == UBSHcomUbcMode::LowLatency) { + int ret = HcomUrma::BindJetty(mUrmaJetty, mTargetJetty); + if (ret != URMA_SUCCESS && ret != URMA_EEXIST) { + NN_LOG_ERROR("Failed to bind local jetty, result: " << ret); + HcomUrma::UnimportJetty(mTargetJetty); + mTargetJetty = nullptr; + return UB_QP_BIND_FAILED; + } + } + + return UB_OK; +} + +NResult UBJetty::CreateHBMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr) +{ + if (NN_UNLIKELY(size == 0 || size > NN_NO65536)) { + NN_LOG_ERROR("Failed to create heartbeat mem region as size is 0 or greater than 64 KB"); + return NN_INVALID_PARAM; + } + + UBMemoryRegion *tmp = nullptr; + auto result = UBMemoryRegion::Create(mName, mUBContext, size, tmp); + if (NN_UNLIKELY(result != UB_OK)) { + NN_LOG_ERROR("Failed to create heartbeat mem region, result " << result); + return result; + } + + if ((result = tmp->InitializeForOneSide()) != UB_OK) { + delete tmp; + return result; + } + + mr.Set(static_cast(tmp)); + + return UB_OK; +} + +void UBJetty::DestroyHBMemoryRegion(UBSHcomNetMemoryRegionPtr &mr) +{ + if (mr.Get() == nullptr) { + NN_LOG_WARN("Try to destroy null memory region"); + return; + } + + mr->UnInitialize(); +} + +} // namespace hcom +} // namespace ock +#endif diff --git a/src/transport/ub/ub_urma_wrapper_jetty.h b/src/transport/ub/ub_urma_wrapper_jetty.h new file mode 100644 index 0000000000000000000000000000000000000000..57eaf3dd4adc7f62a4fbf0bdfb80387c9ce16b71 --- /dev/null +++ b/src/transport/ub/ub_urma_wrapper_jetty.h @@ -0,0 +1,1022 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_UB_URMA_WRAPPER_JETTY_H +#define HCOM_UB_URMA_WRAPPER_JETTY_H +#ifdef UB_BUILD_ENABLED + +#include + +#include +#include "ub/umdk/urma/urma_ubagg.h" + +#include "hcom_utils.h" +#include "net_util.h" +#include "net_oob.h" +#include "ub_common.h" +#include "net_load_balance.h" +#include "net_ctx_info_pool.h" +#include "net_mem_pool_fixed.h" +#include "hcom_obj_statistics.h" +#include "under_api/urma/urma_api_wrapper.h" +#include "ub_mr_fixed_buf.h" +#include "ub_urma_wrapper_jfc.h" + +namespace ock { +namespace hcom { + +struct UBJettyExchangeInfo { + urma_jetty_id_t jettyId{}; + urma_eid_t eid{}; + uint32_t token = 0; + uintptr_t hbAddress = 0; + uint64_t hbKey = 0; + uint64_t hbMrSize = 0; + uint32_t maxSendWr = JETTY_MAX_SEND_WR; + uint32_t maxReceiveWr = JETTY_MAX_RECV_WR; + uint32_t receiveSegSize = NN_NO1024; + uint32_t receiveSegCount = NN_NO64; + bool isNeedSendHb = true; +} __attribute__((packed)); + +enum class UBJettyState : uint8_t { + RESET, ///< 初始状态 + READY, ///< 可收发数据 + ERROR, ///< 调用 modify jetty error之后 +}; + +class UBJetty { +public: + UBJetty(const std::string &name, uint32_t id, UBContext *ctx, UBJfc *jfc, JettyOptions jettyOptions = {}) + : mName(name), mId(id), mUBContext(ctx), mSendJfc(jfc), mRecvJfc(jfc), mJettyOptions(jettyOptions) + { + if (mUBContext != nullptr) { + mUBContext->IncreaseRef(); + } + + if (mSendJfc != nullptr) { + mSendJfc->IncreaseRef(); + } + + OBJ_GC_INCREASE(UBJetty); + } + + ~UBJetty() + { + UnInitialize(); + OBJ_GC_DECREASE(UBJetty); + } + + /* call urma_create_jetty to create real jetty */ + UResult CreateUrmaJetty(uintptr_t seg_pa, uint32_t seg_len, uint32_t seg_count, uint32_t token = 0); + UResult CreateJettyMr(); + UBMemoryRegionFixedBuffer *GetJettyMr(); + + UResult Initialize(uint32_t seg_count, unsigned long memid, uint32_t token = 0); + UResult UnInitialize(); + + /// 清理 op ctx 等资源 + void Cleanup(); + + /// 清理在 FLUSH_ERR_DONE 之后被 post 的 op ctx等资源 + void Flush() + { + Cleanup(); + } + /* + exchange information needs to be transformed by other channel (e.g. tcp connection) + 1 firstly do the initialization + 2 got qp exchange info from peer + 3 call this function to change qp state to ready state (INIT & RTS & RTR) + */ + UResult ChangeToReady(UBJettyExchangeInfo &exInfo); + + /* after qp initialized, retrieve the qp qp_num for exchange */ + UResult FillExchangeInfo(UBJettyExchangeInfo &exInfo); + void StoreExchangeInfo(UBJettyExchangeInfo *exInfo); + UBJettyExchangeInfo &GetExchangeInfo(); + UResult ImportAndBindJetty(uint32_t token = 0); + + inline urma_target_seg_t *ImportSeg(uintptr_t addr, uint32_t bufSize, uint64_t token) + { + uint32_t tokenId = static_cast(token); + urma_token_t tokenValue = {static_cast(token >> NN_NO32)}; + urma_seg_t remoteSeg{}; + remoteSeg.len = bufSize; + remoteSeg.ubva.va = addr; + remoteSeg.token_id = tokenId; + remoteSeg.ubva.eid = mRemoteJettyInfo->eid; + remoteSeg.attr.bs.token_policy = URMA_TOKEN_PLAIN_TEXT; + + urma_import_seg_flag_t flag{}; + flag.bs.cacheable = URMA_NON_CACHEABLE; + flag.bs.access = URMA_ACCESS_READ | URMA_ACCESS_WRITE; + flag.bs.mapping = URMA_SEG_NOMAP; + + return HcomUrma::ImportSeg(mUBContext->mUrmaContext, &remoteSeg, &tokenValue, 0, flag); + } + + inline UBSHcomNetDriverProtocol GetProtocol() + { + return mUBContext->protocol; + } + + inline UResult PostReceive(uintptr_t bufAddr, uint32_t bufSize, urma_target_seg_t *localSeg, uint64_t context) + { + if (NN_UNLIKELY(mUrmaJetty == nullptr || mState != UBJettyState::READY)) { + return UB_QP_NOT_INITIALIZED; + } + + urma_jfr_wr_t *bad_wr; + + urma_sge_t local_sge{}; + local_sge.addr = bufAddr; + local_sge.len = bufSize; + local_sge.tseg = localSeg; + + urma_jfr_wr_t wr{}; + wr.user_ctx = context; + wr.src.sge = &local_sge; + wr.src.num_sge = 1; + wr.next = nullptr; + + NN_LOG_DEBUG("[Post Buffer] ------ urma_post_jetty_recv_wr1, jetty id: " << mUrmaJetty->jetty_id.id + << ", jfc id: " << mRecvJfc->mUrmaJfc->jfc_id.id); + auto ret = HcomUrma::PostJettyRecvWr(mUrmaJetty, &wr, &bad_wr); + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_ERROR("Failed to post receive request to jetty " << mName << ", result " << ret); + return UB_QP_POST_RECEIVE_FAILED; + } + + return UB_OK; + } + + inline UResult PostSend(uintptr_t bufAddr, uint32_t bufSize, urma_target_seg_t *localSeg, UBOpContextInfo *context, + uint32_t immData = 0) + { + if (NN_UNLIKELY(mUrmaJetty == nullptr || mState != UBJettyState::READY)) { + return UB_QP_NOT_INITIALIZED; + } + + auto qpUpContext = context->ubJetty->GetUpContext(); + UBSHcomNetEndpointPtr ep = reinterpret_cast(qpUpContext); + UBSHcomNetTransHeader *header = (UBSHcomNetTransHeader *)bufAddr; + uint64_t epId = ep->Id(); + + // 如果是普通send,存在header,header中的seqNo是序号 + // 如果是send_raw,不存在header,immData是序号 + if (context->opType == UBOpContextInfo::SEND) { + NN_LOG_DEBUG("[Request Send] ------ ep id = " << epId << ", headerCrc = " + << header->headerCrc << ", opCode = " << header->opCode << ", flags = " << header->flags << ", seqNo = " + << header->seqNo << ",timeout = " << header->timeout << ", errCode = " << header->errorCode + << ", dataLength = " << header->dataLength << ", status = " << + UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_URMA)); + } else { + NN_LOG_DEBUG("[Request Send] ------ ep id = " << epId << ", seqNo = " << immData + << ", bufSize = " << bufSize << ", bufhead = " << *(reinterpret_cast(bufAddr)) + << ", status = " << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_URMA)); + } + + urma_jfs_wr_t *bad_wr; + urma_sge_t local_sge{}; + local_sge.addr = bufAddr; + local_sge.len = bufSize; + local_sge.tseg = localSeg; + + urma_jfs_wr_t wr{}; + wr.user_ctx = reinterpret_cast(context); + wr.send.src.sge = &local_sge; + wr.send.src.num_sge = 1; + wr.send.imm_data = immData; + wr.next = nullptr; + wr.opcode = URMA_OPC_SEND_IMM; + wr.flag.bs.complete_enable = 1; + wr.tjetty = mTargetJetty; + + auto ret = HcomUrma::PostJettySendWr(mUrmaJetty, &wr, &bad_wr); + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_ERROR("Failed to post send request to jetty " << mName << ", result " << ret); + return UB_QP_POST_SEND_FAILED; + } + + return UB_OK; + } + + inline UResult PostSendSglInline(UBSHcomNetTransDataIov *iov, uint32_t iovCount, uint64_t context, + uint32_t immData = 0) + { + if (NN_UNLIKELY(mUrmaJetty == nullptr || mState != UBJettyState::READY)) { + return UB_QP_NOT_INITIALIZED; + } + + urma_jfs_wr_t *badWR; + urma_sge_t list[NN_NO4] = {}; + urma_target_seg_t srcSeg[NET_SGE_MAX_IOV] = {}; + for (uint32_t i = 0; i < iovCount; i++) { + list[i].addr = iov[i].address; + list[i].len = iov[i].size; + } + + urma_jfs_wr_t wr {}; + wr.user_ctx = reinterpret_cast(context); + wr.send.src.sge = list; + wr.send.src.num_sge = iovCount; + wr.send.imm_data = immData; + wr.next = nullptr; + wr.opcode = URMA_OPC_SEND_IMM; + wr.flag.bs.complete_enable = 1; + wr.flag.bs.inline_flag = 1; + wr.tjetty = mTargetJetty; + + auto result = HcomUrma::PostJettySendWr(mUrmaJetty, &wr, &badWR); + if (NN_UNLIKELY(result != 0)) { + NN_LOG_ERROR("Failed to post send request to jetty " << mName << ", result " << result); + return UB_QP_POST_SEND_FAILED; + } + + return UB_OK; + } + + inline UResult PostSendSglInlineUbc(UBSHcomNetTransDataIov *iov, uint32_t iovCount, uint64_t context, + urma_target_seg_t **tseg, uint32_t immData = 0) + { + if (NN_UNLIKELY(mUrmaJetty == nullptr || mState != UBJettyState::READY)) { + return UB_QP_NOT_INITIALIZED; + } + + urma_jfs_wr_t *badWR; + urma_sge_t list[NN_NO4] = {}; + for (uint32_t i = 0; i < iovCount; i++) { + list[i].addr = iov[i].address; + list[i].len = iov[i].size; + list[i].tseg = tseg[i]; + } + + urma_jfs_wr_t wr {}; + wr.user_ctx = reinterpret_cast(context); + wr.send.src.sge = list; + wr.send.src.num_sge = iovCount; + wr.send.imm_data = immData; + wr.next = nullptr; + wr.opcode = URMA_OPC_SEND_IMM; + wr.flag.bs.inline_flag = 1; + wr.flag.bs.complete_enable = 1; + + auto result = HcomUrma::PostJettySendWr(mUrmaJetty, &wr, &badWR); + if (NN_UNLIKELY(result != 0)) { + NN_LOG_ERROR("Failed to post send sgl request to jetty " << mName << ", result " << result); + return UB_QP_POST_SEND_FAILED; + } + + return UB_OK; + } + + /// @brief 发送 send 请求,使用 sgl 方式 + /// @param iov [in] 将要发送的向量。仅需填充 lAddress, lkey 和 size. + /// @param iovCount [in] 向量长度 + /// @param context [in] Service 层上下文 + /// @param immData [in] 附带的立即数 + UResult PostSendSgl(UBSHcomNetTransSgeIov *iov, uint32_t iovCount, uint64_t context, uint32_t immData = 0) + { + if (NN_UNLIKELY(mUrmaJetty == nullptr || mState != UBJettyState::READY)) { + return UB_QP_NOT_INITIALIZED; + } + + urma_jfs_wr_t *bad_wr; + urma_sge_t list[NET_SGE_MAX_IOV]; + urma_target_seg_t srcSeg[NET_SGE_MAX_IOV] = {}; + for (uint32_t i = 0; i < iovCount; i++) { + list[i].addr = iov[i].lAddress; + list[i].len = iov[i].size; + list[i].tseg = reinterpret_cast(iov[i].srcSeg); + } + + urma_jfs_wr_t wr{}; + wr.user_ctx = reinterpret_cast(context); + wr.send.src.sge = list; + wr.send.src.num_sge = iovCount; + wr.send.imm_data = immData; + wr.next = nullptr; + wr.opcode = URMA_OPC_SEND_IMM; + wr.flag.bs.complete_enable = 1; + wr.tjetty = mTargetJetty; + + auto ret = HcomUrma::PostJettySendWr(mUrmaJetty, &wr, &bad_wr); // urma_post_jetty_send_wr + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_ERROR("Failed to post send sgl request to jetty " << mName << ", result " << ret); + return UB_QP_POST_SEND_FAILED; + } + + return UB_OK; + } + + inline UResult PostRead(uintptr_t bufAddr, urma_target_seg_t *srcSeg, uintptr_t dstBufAddr, + urma_target_seg_t *dstSeg, uint32_t bufSize, uint64_t context) + { + if (NN_UNLIKELY(mUrmaJetty == nullptr || mState != UBJettyState::READY)) { + return UB_QP_NOT_INITIALIZED; + } + + int ret = 0; + + urma_jfs_wr_t *bad_wr; + urma_sge_t src_sge{}; + src_sge.addr = bufAddr; + src_sge.len = bufSize; + src_sge.tseg = srcSeg; + + urma_sge_t dst_sge{}; + dst_sge.addr = dstBufAddr; + dst_sge.len = bufSize; + dst_sge.tseg = dstSeg; + + urma_jfs_wr_t wr{}; + wr.user_ctx = context; + wr.rw.src.sge = &dst_sge; + wr.rw.src.num_sge = 1; + wr.rw.dst.sge = &src_sge; + wr.rw.dst.num_sge = 1; + wr.next = nullptr; + wr.flag.bs.complete_enable = 1; + wr.tjetty = mTargetJetty; + wr.opcode = URMA_OPC_READ; + + ret = HcomUrma::PostJettySendWr(mUrmaJetty, &wr, &bad_wr); + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_ERROR("Failed to post read request to jetty " << mName << ", result " << ret); + return UB_QP_POST_WRITE_FAILED; + } + + return UB_OK; + } + + inline void FillUrmaTargetSeg(urma_target_seg_t &tseg, uintptr_t addr, uint32_t bufSize, uint32_t token) + { + tseg.seg.ubva.va = addr; + tseg.seg.len = bufSize; + tseg.seg.token_id = token; + tseg.urma_ctx = mUBContext->mUrmaContext; + } + + /// @brief 发送 Read 请求,从对端读取数据至本端 + /// @param bufAddr [in] 本地 MR 目标地址 + /// @param ltoken [in] 本地 MR 访问 key + /// @param dstBufAddr [in] 对端 MR 地址 + /// @param rtoken [in] 对端 MR 远程访问 key + /// @param bufSize [in] 读取大小 + /// @param context [in] `UBOpContextInfo`, 事件完成时通过此 context 找回对应 Jetty 以及上层 Service 层的上下文 + UResult PostRead(uintptr_t bufAddr, uint64_t ltoken, uintptr_t dstBufAddr, uint64_t rtoken, uint32_t bufSize, + uint64_t context) + { + if (NN_UNLIKELY(mUrmaJetty == nullptr || mState != UBJettyState::READY)) { + return UB_QP_NOT_INITIALIZED; + } + + NN_LOG_DEBUG("[Request Read] ------ ep id = " << mUpId << ", lKey = " << ltoken << ", rKey = " << rtoken << + ",size = " << bufSize << ", status = " << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_URMA)); + + urma_target_seg_t *dstSeg = ImportSeg(dstBufAddr, bufSize, rtoken); + if (dstSeg == nullptr) { + NN_LOG_ERROR("Failed to import dstSeg"); + return UB_QP_POST_READ_FAILED; + } + + urma_target_seg_t srcSeg{}; + FillUrmaTargetSeg(srcSeg, bufAddr, bufSize, static_cast(ltoken)); + + urma_jfs_wr_t *bad_wr; + urma_sge_t src_sge{}; + src_sge.addr = bufAddr; + src_sge.tseg = &srcSeg; + src_sge.len = bufSize; + + urma_sge_t dst_sge{}; + dst_sge.addr = dstBufAddr; + dst_sge.tseg = dstSeg; + dst_sge.len = bufSize; + + urma_jfs_wr_t wr{}; + wr.rw.src.sge = &dst_sge; + wr.rw.src.num_sge = 1; + wr.rw.dst.sge = &src_sge; + wr.rw.dst.num_sge = 1; + wr.next = nullptr; + wr.user_ctx = context; + wr.opcode = URMA_OPC_READ; + wr.flag.bs.complete_enable = 1; + wr.tjetty = mTargetJetty; + + auto ret = HcomUrma::PostJettySendWr(mUrmaJetty, &wr, &bad_wr); + if (NN_UNLIKELY(ret != 0)) { + ret = HcomUrma::UnimportSeg(dstSeg); + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_WARN("Unable to unImport Seg " << mName << ", result: " << ret); + } + NN_LOG_ERROR("Failed to post read request to jetty " << mName << ", result: " << ret); + return UB_QP_POST_READ_FAILED; + } + ret = HcomUrma::UnimportSeg(dstSeg); + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_WARN("Unable to unImport Seg " << mName << ", result: " << ret); + } + return UB_OK; + } + + /// @brief 发送 Read 请求,从对端读取数据至本端 + /// @param bufAddr [in] 本地 MR 目标地址 + /// @param ltoken [in] 本地 MR 访问 key + /// @param dstBufAddr [in] 对端 MR 地址 + /// @param rtoken [in] 对端 MR 远程访问 key + /// @param bufSize [in] 读取大小 + /// @param context [in] `UBOpContextInfo`, 事件完成时通过此 context 找回对应 Jetty 以及上层 Service 层的上下文 + UResult PostRead(uintptr_t bufAddr, urma_target_seg_t *ltseg, uintptr_t dstBufAddr, uint64_t rtoken, + uint32_t bufSize, uint64_t context) + { + if (NN_UNLIKELY(mUrmaJetty == nullptr || mState != UBJettyState::READY)) { + return UB_QP_NOT_INITIALIZED; + } + urma_target_seg_t *dstSeg = ImportSeg(dstBufAddr, bufSize, rtoken); + if (dstSeg == nullptr) { + NN_LOG_ERROR("Failed to import dstSeg"); + return UB_QP_POST_READ_FAILED; + } + + urma_jfs_wr_t *bad_wr; + urma_sge_t src_sge{}; + src_sge.addr = bufAddr; + src_sge.tseg = ltseg; + src_sge.len = bufSize; + + urma_sge_t dst_sge{}; + dst_sge.addr = dstBufAddr; + dst_sge.tseg = dstSeg; + dst_sge.len = bufSize; + + urma_jfs_wr_t wr{}; + wr.user_ctx = context; + wr.rw.src.sge = &dst_sge; + wr.rw.src.num_sge = 1; + wr.rw.dst.sge = &src_sge; + wr.rw.dst.num_sge = 1; + wr.next = nullptr; + wr.opcode = URMA_OPC_READ; + wr.flag.bs.complete_enable = 1; + wr.tjetty = mTargetJetty; + + auto ret = HcomUrma::PostJettySendWr(mUrmaJetty, &wr, &bad_wr); + if (NN_UNLIKELY(ret != 0)) { + ret = HcomUrma::UnimportSeg(dstSeg); + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_WARN("Unable to unImport Seg " << mName << ", result " << ret); + } + NN_LOG_ERROR("Failed to post read request to jetty " << mName << ", result " << ret); + return UB_QP_POST_READ_FAILED; + } + ret = HcomUrma::UnimportSeg(dstSeg); + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_WARN("Unable to unImport Seg " << mName << ", result " << ret); + } + return UB_OK; + } + + inline UResult PostWrite(uintptr_t bufAddr, urma_target_seg_t *srcSeg, uintptr_t dstBufAddr, + urma_target_seg_t *dstSeg, uint32_t bufSize, uint64_t context) + { + if (NN_UNLIKELY(mUrmaJetty == nullptr || mState != UBJettyState::READY)) { + return UB_QP_NOT_INITIALIZED; + } + + int ret = 0; + + urma_jfs_wr_t *bad_wr; + urma_sge_t src_sge{}; + src_sge.addr = bufAddr; + src_sge.len = bufSize; + src_sge.tseg = srcSeg; + + urma_sge_t dst_sge{}; + dst_sge.addr = dstBufAddr; + dst_sge.len = bufSize; + dst_sge.tseg = dstSeg; + + urma_jfs_wr_t wr{}; + wr.user_ctx = context; + wr.rw.src.sge = &src_sge; + wr.rw.src.num_sge = 1; + wr.rw.dst.sge = &dst_sge; + wr.rw.dst.num_sge = 1; + wr.next = nullptr; + wr.opcode = URMA_OPC_WRITE; + wr.flag.bs.complete_enable = 1; + wr.tjetty = mTargetJetty; + + ret = HcomUrma::PostJettySendWr(mUrmaJetty, &wr, &bad_wr); + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_ERROR("Failed to post write request to jetty " << mName << ", result " << ret); + return UB_QP_POST_WRITE_FAILED; + } + + return UB_OK; + } + + /// @brief 发送 Write 请求,将数据从本端写入至对端 + /// @param bufAddr [in] 本地 MR 源地址 + /// @param ltoken [in] 本地 MR 访问 key + /// @param dstBufAddr [in] 对端 MR 地址 + /// @param rtoken [in] 对端 MR 远程访问 key + /// @param bufSize [in] 写入大小 + /// @param context [in] `UBOpContextInfo`, 事件完成时通过此 context 找回对应 Jetty 以及上层 Service 层的上下文 + UResult PostWrite(uintptr_t bufAddr, uint64_t ltoken, uintptr_t dstBufAddr, uint64_t rtoken, uint32_t bufSize, + uint64_t context) + { + if (NN_UNLIKELY(mUrmaJetty == nullptr || mState != UBJettyState::READY)) { + return UB_QP_NOT_INITIALIZED; + } + urma_target_seg_t *dstSeg = ImportSeg(dstBufAddr, bufSize, rtoken); + if (dstSeg == nullptr) { + NN_LOG_ERROR("Failed to import dstSeg"); + return UB_QP_POST_WRITE_FAILED; + } + + urma_target_seg_t srcSeg{}; + FillUrmaTargetSeg(srcSeg, bufAddr, bufSize, static_cast(ltoken)); + + NN_LOG_DEBUG("[Request Write] ------ ep id = " << mUpId << ", lKey = " << ltoken << ", rKey = " << rtoken << + ",size = " << bufSize << ", status = " << UBSHcomRequestStatusToString(UBSHcomNetRequestStatus::IN_URMA)); + + urma_jfs_wr_t *bad_wr; + urma_sge_t src_sge{}; + src_sge.addr = bufAddr; + src_sge.tseg = &srcSeg; + src_sge.len = bufSize; + + urma_sge_t dst_sge{}; + dst_sge.addr = dstBufAddr; + dst_sge.tseg = dstSeg; + dst_sge.len = bufSize; + + urma_jfs_wr_t wr{}; + wr.rw.dst.sge = &dst_sge; + wr.rw.dst.num_sge = 1; + wr.rw.src.sge = &src_sge; + wr.rw.src.num_sge = 1; + wr.user_ctx = context; + wr.next = nullptr; + wr.opcode = URMA_OPC_WRITE; + wr.flag.bs.complete_enable = 1; + wr.tjetty = mTargetJetty; + + auto ret = HcomUrma::PostJettySendWr(mUrmaJetty, &wr, &bad_wr); + if (NN_UNLIKELY(ret != 0)) { + ret = HcomUrma::UnimportSeg(dstSeg); + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_WARN("Unable to unImport Seg " << mName << ", result: " << ret); + } + NN_LOG_ERROR("Failed to post write request to jetty " << mName << ", result: " << ret); + return UB_QP_POST_WRITE_FAILED; + } + + ret = HcomUrma::UnimportSeg(dstSeg); + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_WARN("Unable to unImport Seg " << mName << ", result: " << ret); + } + return UB_OK; + } + + /// @brief 发送 Write 请求,将数据从本端写入至对端 + /// @param bufAddr [in] 本地 MR 源地址 + /// @param ltoken [in] 本地 MR 访问 key + /// @param dstBufAddr [in] 对端 MR 地址 + /// @param rtoken [in] 对端 MR 远程访问 key + /// @param bufSize [in] 写入大小 + /// @param context [in] `UBOpContextInfo`, 事件完成时通过此 context 找回对应 Jetty 以及上层 Service 层的上下文 + UResult PostWrite(uintptr_t bufAddr, urma_target_seg_t *ltseg, uintptr_t dstBufAddr, uint64_t rtoken, + uint32_t bufSize, uint64_t context) + { + if (NN_UNLIKELY(mUrmaJetty == nullptr || mState != UBJettyState::READY)) { + return UB_QP_NOT_INITIALIZED; + } + urma_target_seg_t *dstSeg = ImportSeg(dstBufAddr, bufSize, rtoken); + if (dstSeg == nullptr) { + NN_LOG_ERROR("Failed to import dstSeg"); + return UB_QP_POST_WRITE_FAILED; + } + + urma_jfs_wr_t *bad_wr; + urma_sge_t src_sge{}; + src_sge.addr = bufAddr; + src_sge.tseg = ltseg; + src_sge.len = bufSize; + + urma_sge_t dst_sge{}; + dst_sge.addr = dstBufAddr; + dst_sge.tseg = dstSeg; + dst_sge.len = bufSize; + + urma_jfs_wr_t wr{}; + wr.user_ctx = context; + wr.rw.dst.sge = &dst_sge; + wr.rw.dst.num_sge = 1; + wr.rw.src.sge = &src_sge; + wr.rw.src.num_sge = 1; + wr.next = nullptr; + wr.opcode = URMA_OPC_WRITE; + wr.flag.bs.complete_enable = 1; + wr.tjetty = mTargetJetty; + + auto ret = HcomUrma::PostJettySendWr(mUrmaJetty, &wr, &bad_wr); + if (NN_UNLIKELY(ret != 0)) { + ret = HcomUrma::UnimportSeg(dstSeg); + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_WARN("Unable to unImport Seg " << mName << ", result " << ret); + } + NN_LOG_ERROR("Failed to post write request to jetty " << mName << ", result " << ret); + return UB_QP_POST_WRITE_FAILED; + } + + ret = HcomUrma::UnimportSeg(dstSeg); + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_WARN("Unable to unImport Seg " << mName << ", result " << ret); + } + return UB_OK; + } + + /// @brief 发送单边 read/write 请求,采用 sgl 方式 + /// @param iov [in] 将要发送的向量。仅需填充 lAddress, lkey, rAddress, rkey 和 size. + /// @param iovCount [in] 向量长度 + /// @param context [in] Service 层上下文 + /// @param isRead [in] 是否选择发送 read 请求,当为 false 时发送 write 请求。 + /// @param ctxLen [in] context 数组长度 + UResult PostOneSideSgl(UBSHcomNetTransSgeIov *iov, uint32_t iovCount, uint64_t *context, + bool isRead, uint8_t ctxLen) + { + if (NN_UNLIKELY(mUrmaJetty == nullptr || mState != UBJettyState::READY)) { + return UB_QP_NOT_INITIALIZED; + } + + urma_jfs_wr_t *badWR; + urma_jfs_wr_t wrList[NET_SGE_MAX_IOV] = {}; + urma_target_seg_t srcSeg[NET_SGE_MAX_IOV] = {}; + urma_sge_t src_sge[NET_SGE_MAX_IOV] = {}; + urma_target_seg_t *dstSeg[NET_SGE_MAX_IOV] = {}; + urma_sge_t dst_sge[NET_SGE_MAX_IOV] = {}; + UResult ret = UB_OK; + uint32_t i = 0; + for (; i < iovCount; i++) { + FillUrmaTargetSeg(srcSeg[i], iov[i].lAddress, iov[i].size, iov[i].lKey); + src_sge[i].addr = iov[i].lAddress; + src_sge[i].len = iov[i].size; + src_sge[i].tseg = static_cast(iov[i].srcSeg); + + dstSeg[i] = ImportSeg(iov[i].rAddress, iov[i].size, iov[i].rKey); + if (dstSeg[i] == nullptr) { + NN_LOG_ERROR("Failed to import dstSeg"); + ret = isRead ? UB_QP_POST_READ_FAILED : UB_QP_POST_WRITE_FAILED; + break; + } + + dst_sge[i].addr = iov[i].rAddress; + dst_sge[i].len = iov[i].size; + dst_sge[i].tseg = dstSeg[i]; + + auto &wr = wrList[i]; + wr.user_ctx = context[i]; + wr.rw.src.num_sge = 1; + wr.rw.dst.num_sge = 1; + if (isRead) { + wr.opcode = URMA_OPC_READ; + wr.rw.src.sge = &dst_sge[i]; + wr.rw.dst.sge = &src_sge[i]; + } else { + wr.opcode = URMA_OPC_WRITE; + wr.rw.src.sge = &src_sge[i]; + wr.rw.dst.sge = &dst_sge[i]; + } + wr.next = (i + 1 == iovCount) ? nullptr : &wrList[i + 1]; + wr.tjetty = mTargetJetty; + wr.flag.bs.complete_enable = 1; + } + + if (ret == UB_OK) { + auto result = HcomUrma::PostJettySendWr(mUrmaJetty, wrList, NET_SGE_MAX_IOV, &badWR); + if (NN_UNLIKELY(result != 0)) { + NN_LOG_ERROR("Urma failed to post oneSide request to jetty " << mName << ", result " << result); + ret = isRead ? UB_QP_POST_READ_FAILED : UB_QP_POST_WRITE_FAILED; + } + } + + for (uint32_t index = 0; index < i; ++index) { + auto result = HcomUrma::UnimportSeg(dstSeg[index]); + if (NN_UNLIKELY(result != 0)) { + NN_LOG_WARN("Unable to unImport Seg " << mName << ", result " << result); + } + } + return ret; + } + + inline uint32_t GetId() const + { + return mId; + } + + inline void SetUpId(uint64_t id) + { + mUpId = id; + } + + inline uint64_t GetUpId() const + { + return mUpId; + } + + inline const std::string &GetName() const + { + return mName; + } + + inline void SetName(const std::string &value) + { + mName = value; + } + + inline const std::string &GetPeerIpAndPort() const + { + return mPeerIpPort; + } + + void SetPeerIpAndPort(const std::string &value); + + uint32_t GetPostSendMaxSize() const; + + inline uint8_t GetPortNum() const + { + return mUBContext->mPortNumber; + } + + inline void SetUpContext(uintptr_t ctx) + { + mUpContext = ctx; + } + + inline uintptr_t GetUpContext() const + { + return mUpContext; + } + + inline void SetUpContext1(uintptr_t ctx) + { + mUpContext1 = ctx; + } + + inline uintptr_t GetUpContext1() const + { + return mUpContext1; + } + + inline uint32_t GetJettyId() + { + if (mUrmaJetty != nullptr) { + return mUrmaJetty->jetty_id.id; + } + return 0; + } + + bool GetFreeBuff(uintptr_t &item); + bool ReturnBuffer(uintptr_t value); + bool GetFreeBufferN(uintptr_t *&items, uint32_t n); + uint64_t GetLKey(); + urma_target_seg_t *GetMemorySeg(); + + inline void AddOpCtxInfo(UBOpContextInfo *ctxInfo) + { + if (NN_LIKELY(ctxInfo != nullptr)) { + // bi-direction linked list, 4 step to insert to head + ctxInfo->prev = &mCtxPosted; + mLock.Lock(); + // head -><- first -><- second -><- third -> nullptr + // insert into the head place + ctxInfo->next = mCtxPosted.next; + if (mCtxPosted.next != nullptr) { + mCtxPosted.next->prev = ctxInfo; + } + mCtxPosted.next = ctxInfo; + ++mCtxPostedCount; + mLock.Unlock(); + } + } + + inline void RemoveOpCtxInfo(UBOpContextInfo *ctxInfo) + { + if (NN_LIKELY(ctxInfo != nullptr)) { + // bi-direction linked list, 4 step to remove one + mLock.Lock(); + + // repeat remove + if (ctxInfo->prev == nullptr) { + mLock.Unlock(); + return; + } + + // head-><- first -><- second -><- third -> nullptr + ctxInfo->prev->next = ctxInfo->next; + if (ctxInfo->next != nullptr) { + ctxInfo->next->prev = ctxInfo->prev; + } + --mCtxPostedCount; + + ctxInfo->prev = nullptr; + ctxInfo->next = nullptr; + mLock.Unlock(); + } + } + + // need to call this when qp broken, to get these contexts to return mrs + inline void GetCtxPosted(UBOpContextInfo *&remaining) + { + mLock.Lock(); + // head -> first -><- second -><- third -> nullptr + remaining = mCtxPosted.next; + mCtxPosted.next = nullptr; + mCtxPostedCount = 0; + mLock.Unlock(); + } + + /// 获取 Jetty 发送队列的长度。 + inline uint32_t GetSendQueueSize() + { + int32_t ref = __sync_fetch_and_add(&mPostSendRef, 0); + ref = std::max(0, std::min(ref, mPostSendMaxWr)); + return static_cast(mPostSendMaxWr - ref); + } + + /// 获取所有提交至 Jetty 队列中的任务个数,总数为 PostReceive + PostSend 族 + /// 函数的和。因为 RDMA 有 prePostReceive 机制,所以它的值一般会大于等于 + /// prePostReceiveSizePerQP 的值。 + /// \see prePostReceiveSizePerQP + inline uint32_t GetPostedCount() + { + mLock.Lock(); + auto tmp = mCtxPostedCount; + mLock.Unlock(); + return tmp; + } + + inline bool GetPostSendWr(uint32_t times = NN_NO8, uint32_t sleepUs = NN_NO64) + { + while (times-- > 0) { + if (NN_LIKELY(__sync_sub_and_fetch(&mPostSendRef, 1) >= 0)) { + return true; + } + __sync_add_and_fetch(&mPostSendRef, 1); + usleep(sleepUs); + } + return false; + } + + inline void ReturnPostSendWr() + { + int32_t ref = __sync_add_and_fetch(&mPostSendRef, 1); + if (ref > mPostSendMaxWr) { + NN_LOG_WARN("[UB] Posted send requests " << ref << " over capacity " << mPostSendMaxWr); + } + } + + inline bool GetOneSideWr(uint32_t times = NN_NO8, uint32_t sleepUs = NN_NO64) + { + while (times-- > 0) { + if (NN_LIKELY(__sync_sub_and_fetch(&mOneSideRef, 1) >= 0)) { + return true; + } + __sync_add_and_fetch(&mOneSideRef, 1); + usleep(sleepUs); + } + return false; + } + + inline void ReturnOneSideWr() + { + int32_t ref = __sync_add_and_fetch(&mOneSideRef, 1); + if (ref > mOneSideMaxWr) { + NN_LOG_WARN("[UB] Posted one side requests " << ref << " over capacity " << mOneSideMaxWr); + } + } + + // UBC Heartbeat + NResult CreateHBMemoryRegion(uint64_t size, UBSHcomNetMemoryRegionPtr &mr); + void DestroyHBMemoryRegion(UBSHcomNetMemoryRegionPtr &mr); + + inline uintptr_t GetNextLocalHBAddress() + { + uint64_t nextOffset = __sync_fetch_and_add(&mLocalNextOffset, NN_NO4) % mHBLocalMr->Size(); + return mHBLocalMr->GetAddress() + nextOffset; + } + + inline uint64_t GetLocalHBKey() const + { + return mHBLocalMr->GetLKey(); + } + + void GetRemoteHbInfo(UBJettyExchangeInfo &info) + { + uint64_t nextOffset = __sync_fetch_and_add(&mRemoteNextOffset, NN_NO4) % mHBRemoteMr->Size(); + info.hbAddress = mHBRemoteMr->GetAddress() + nextOffset; + info.hbKey = mHBRemoteMr->GetLKey(); + info.hbMrSize = NN_NO4; + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +public: + static uint32_t NewId() + { + return __sync_fetch_and_add(&G_INDEX, 1); + } + + inline uint32_t QpNum() const + { + if (NN_UNLIKELY(mUrmaJetty == nullptr)) { + return 0xffffffff; + } + + return mUrmaJettyId; + } + + inline uint32_t PostRegMrSize() const + { + return mJettyOptions.mrSegSize; + } + + UBJettyState State() const + { + return mState; + } + + // stop jetty + UResult Stop(); + +private: + void FillJettyCfg(urma_jetty_cfg_t &jetty_cfg, uintptr_t seg_pa, uintptr_t seg_va, uint32_t seg_len, + uint32_t seg_count); + void FillJfsCfg(urma_jfs_cfg_t *jfs_cfg); + void FillJfrCfg(urma_jfr_cfg_t *jfr_cfg, uint32_t token = 0); + UResult ChangeToInit(urma_jetty_attr_t &attr); + UResult ChangeToReceive(UBJettyExchangeInfo &exInfo, urma_jetty_attr_t &attr); + UResult ChangeToSend(urma_jetty_attr_t &attr); + UResult SetMaxSendWrConfig(UBJettyExchangeInfo &exInfo); + +private: + std::string mName; + std::string mPeerIpPort; + uint32_t mId = 0; + uint64_t mUpId = 0; + std::atomic mState{UBJettyState::RESET}; + std::mutex mStopMutex; + + UBContext *mUBContext = nullptr; + UBJfc *mSendJfc = nullptr; + UBJfc *mRecvJfc = nullptr; + urma_jfr_t *mJfr = nullptr; + JettyOptions mJettyOptions{}; + uint32_t mUrmaJettyId = 0; // mUrmaJetty->jetty_id.id + urma_jetty_t *mUrmaJetty = nullptr; + urma_target_jetty_t *mTargetJetty = nullptr; + std::unique_ptr mRemoteJettyInfo; // 对端建链时交换信息 + uintptr_t mUpContext = 0; + uintptr_t mUpContext1 = 0; + NetSpinLock mLock; + UBOpContextInfo mCtxPosted{}; + uint32_t mCtxPostedCount{ 0 }; + UBMemoryRegionFixedBuffer *mJettyMr = nullptr; + + int32_t mOneSideMaxWr = JETTY_MAX_SEND_WR - NN_NO64; + int32_t mOneSideRef = JETTY_MAX_SEND_WR - NN_NO64; + int32_t mPostSendMaxWr = NN_NO64; + uint32_t mPostSendMaxSize = NN_NO1024; + int32_t mPostSendRef = NN_NO64; + + UBSHcomNetMemoryRegionPtr mHBLocalMr = nullptr; + UBSHcomNetMemoryRegionPtr mHBRemoteMr = nullptr; + uint64_t mLocalNextOffset = 0; + uint64_t mRemoteNextOffset = 0; + + friend class NetDriverUBWithOob; + friend class NetHeartbeat; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + static uint32_t G_INDEX; +}; +} // namespace hcom +} // namespace ock + +#endif +#endif // HCOM_UB_URMA_WRAPPER_JETTY_H diff --git a/src/transport/ub/ub_urma_wrapper_jfc.cpp b/src/transport/ub/ub_urma_wrapper_jfc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c4d5a6353f339812e1aa8cfc8a839f6ed16127c --- /dev/null +++ b/src/transport/ub/ub_urma_wrapper_jfc.cpp @@ -0,0 +1,202 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED + +#include "ub_urma_wrapper_jfc.h" + +namespace ock { +namespace hcom { + +UResult UBJfc::CreatePollingCq() +{ + urma_jfc_cfg_t jfc_cfg{}; + jfc_cfg.depth = mJfcCount; + jfc_cfg.flag.value = 0; + jfc_cfg.jfce = nullptr; + jfc_cfg.user_ctx = mWork; + + urma_jfc_t *tmpJfc = HcomUrma::CreateJfc(mUBContext->mUrmaContext, &jfc_cfg); + if (tmpJfc == nullptr) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create completion queue for UBJfc " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return UB_NEW_OBJECT_FAILED; + } + + mUrmaJfc = tmpJfc; + return UB_OK; +} + +UResult UBJfc::CreateEventCq() +{ + // create jfce + urma_jfce_t *tmpJfce = HcomUrma::CreateJfce(mUBContext->mUrmaContext); + if (tmpJfce == nullptr) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create JFCE for UBJfc " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return UB_NEW_OBJECT_FAILED; + } + + // create jfc + urma_jfc_cfg_t jfc_cfg{}; + jfc_cfg.depth = mJfcCount; + jfc_cfg.flag.value = 0; + jfc_cfg.jfce = tmpJfce; + jfc_cfg.user_ctx = mWork; + + urma_jfc_t *tmpJfc = HcomUrma::CreateJfc(mUBContext->mUrmaContext, &jfc_cfg); + if (tmpJfc == nullptr) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create completion queue for UBJfc " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + HcomUrma::DeleteJfce(tmpJfce); + return UB_NEW_OBJECT_FAILED; + } + + if (HcomUrma::RearmJfc(tmpJfc, 0) != 0) { + HcomUrma::DeleteJfc(tmpJfc); + HcomUrma::DeleteJfce(tmpJfce); + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create completion queue for UBJfc " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return UB_NEW_OBJECT_FAILED; + } + + int flags = fcntl(tmpJfce->fd, F_GETFL); + if (fcntl(tmpJfce->fd, F_SETFL, static_cast(flags) | O_NONBLOCK) < 0) { + HcomUrma::DeleteJfc(tmpJfc); + HcomUrma::DeleteJfce(tmpJfce); + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to set no blocking for UBJfc " << mName << ", error " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return UB_NEW_OBJECT_FAILED; + } + + mUrmaJfcEvent = tmpJfce; + mUrmaJfc = tmpJfc; + return UB_OK; +} + +UResult UBJfc::Initialize() +{ + NN_LOG_TRACE_INFO("UBJfc::Initialize"); + if (mUrmaJfc != nullptr) { + return UB_OK; + } + + if (mUBContext == nullptr || mUBContext->mUrmaContext == nullptr) { + NN_LOG_ERROR("Failed to initialize UBJfc as ub context is null"); + return UB_PARAM_INVALID; + } + if (mCreateCompletionChannel) { + return CreateEventCq(); + } else { + return CreatePollingCq(); + } + + NN_LOG_TRACE_INFO("UBJfc::Initialized"); + return UB_OK; +} + +UResult UBJfc::UnInitialize() +{ + int res = 0; + if (mUrmaJfc != nullptr) { + if ((res = HcomUrma::DeleteJfc(mUrmaJfc)) != 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_WARN("Unable to delete jfc " << res << ", as errno " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + } + mUrmaJfc = nullptr; + } + + if (mUrmaJfcEvent != nullptr) { + HcomUrma::DeleteJfce(mUrmaJfcEvent); + mUrmaJfcEvent = nullptr; + } + + if (mUBContext != nullptr) { + mUBContext->DecreaseRef(); + mUBContext = nullptr; + } + return UB_OK; +} + +UResult UBJfc::ProgressV(urma_cr_t *cr, uint32_t &countInOut) +{ + if (NN_UNLIKELY(mUrmaJfc == nullptr || cr == nullptr)) { + return UB_CQ_NOT_INITIALIZED; + } + + uint16_t times = 0; + + while (true) { + auto n = HcomUrma::PollJfc(mUrmaJfc, countInOut, cr); + if (NN_UNLIKELY(n < 0)) { + NN_LOG_ERROR("Poll jfc failed in UBJfc " << mName << ", errno " << errno << " n = " << n); + return UB_CQ_POLLING_FAILED; + } + if (n == 0) { + times++; + if (times < NN_NO10) { + continue; + } + } + + countInOut = static_cast(n); + break; + } + + return UB_OK; +} + +UResult UBJfc::EventProgressV(urma_cr_t *cr, uint32_t &countInOut, int32_t timeoutInMs) +{ + if (NN_UNLIKELY(mUrmaJfc == nullptr || mUrmaJfcEvent == nullptr || cr == nullptr)) { + return UB_CQ_NOT_INITIALIZED; + } + + // wait request if n == 0 + urma_jfc_t *jfc = nullptr; + + // Wait for the completion event + int result = HcomUrma::WaitJfc(mUrmaJfcEvent, 1, timeoutInMs, &jfc); + if (result < 0) { + NN_LOG_ERROR("urma_wait_jfc failed, jfc id: " << mUrmaJfc->jfc_id.id << ", errno " << errno); + return UB_CQ_EVENT_GET_FAILED; + } + + int cqeCnt = HcomUrma::PollJfc(mUrmaJfc, countInOut, cr); + if (cqeCnt < 0) { + NN_LOG_ERROR("Poll jfc failed in UBJfc " << mName << ", errno " << errno); + return UB_CQ_POLLING_FAILED; + } + countInOut = static_cast(cqeCnt); + + if (jfc != nullptr) { + // Ack the event + uint32_t ackCnt = 1; + HcomUrma::AckJfc(&jfc, &ackCnt, 1); + } + + // Request notification upon the next completion event + if (HcomUrma::RearmJfc(mUrmaJfc, false) != 0) { + NN_LOG_ERROR("Notify cq event failed in UBJfc " << mName << ", errno " << errno); + return UB_CQ_EVENT_NOTIFY_FAILED; + } + + return UB_OK; +} +} // namespace hcom +} +#endif \ No newline at end of file diff --git a/src/transport/ub/ub_urma_wrapper_jfc.h b/src/transport/ub/ub_urma_wrapper_jfc.h new file mode 100644 index 0000000000000000000000000000000000000000..6595763fc84c876515a9d0c37b01e9353b63332c --- /dev/null +++ b/src/transport/ub/ub_urma_wrapper_jfc.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_UB_URMA_WRAPPER_JFC_H +#define HCOM_UB_URMA_WRAPPER_JFC_H +#ifdef UB_BUILD_ENABLED + +#include "ub_common.h" +#include "ub_urma_wrapper_ctx.h" + +namespace ock { +namespace hcom { + +class UBJfc { +public: + UBJfc(const std::string &name, UBContext *ctx, bool createCompletionChannel = false, uintptr_t work = 0) + : mName(name), mCreateCompletionChannel(createCompletionChannel), mWork(work), mUBContext(ctx) + { + if (mUBContext != nullptr) { + mUBContext->IncreaseRef(); + } + + OBJ_GC_INCREASE(UBJfc); + } + + ~UBJfc() + { + UnInitialize(); + OBJ_GC_DECREASE(UBJfc); + } + + inline void SetJfcCount(uint32_t value) + { + mJfcCount = (value < NN_NO1024) ? NN_NO1024 : value; + } + + inline uint32_t GetCQCount() + { + return mJfcCount; + } + + UResult Initialize(); + UResult UnInitialize(); + + UResult ProgressV(urma_cr_t *cr, uint32_t &countInOut); + UResult EventProgressV(urma_cr_t *cr, uint32_t &countInOut, int32_t timeoutInMs = NN_NO500); + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +private: + UResult CreatePollingCq(); + UResult CreateEventCq(); + std::string mName; + uint32_t mJfcCount = JFC_COUNT; + bool mCreateCompletionChannel = false; + uintptr_t mWork = 0; + UBContext *mUBContext = nullptr; + urma_jfc_t *mUrmaJfc = nullptr; + urma_jfce_t *mUrmaJfcEvent = nullptr; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend class UBJetty; + friend class UBPublicJetty; +}; +} +} +#endif +#endif // HCOM_UB_URMA_WRAPPER_JFC_H \ No newline at end of file diff --git a/src/transport/ub/ub_urma_wrapper_public_jetty.cpp b/src/transport/ub/ub_urma_wrapper_public_jetty.cpp new file mode 100644 index 0000000000000000000000000000000000000000..32a14e932663b821aef65ec2045ddfe6e29adac2 --- /dev/null +++ b/src/transport/ub/ub_urma_wrapper_public_jetty.cpp @@ -0,0 +1,698 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED +#include "ub_urma_wrapper_public_jetty.h" + +namespace ock { +namespace hcom { + +std::string EidToStr(urma_eid_t urmaEid) +{ + std::string str = ""; + for (int i = 0; i < URMA_EID_SIZE; i++) { + str += std::to_string(urmaEid.raw[i]); + } + return str; +} + + +uint32_t UBPublicJetty::G_INDEX = 1; + +// public jetty import remote jetty +UResult UBPublicJetty::ImportPublicJetty(const urma_eid_t &remoteEid, uint32_t jettyId) +{ + // import remote jetty + urma_rjetty_t remoteJetty{}; + remoteJetty.jetty_id.id = jettyId; + remoteJetty.jetty_id.eid = remoteEid; + remoteJetty.trans_mode = URMA_TM_RM; + remoteJetty.type = URMA_JETTY; + urma_token_t token{0}; + + mTargetJetty = HcomUrma::ImportJetty(mUBContext->mUrmaContext, &remoteJetty, &token); + if (mTargetJetty == nullptr) { + NN_LOG_ERROR("Failed to import public jetty"); + return UB_QP_IMPORT_FAILED; + } + NN_LOG_INFO("Local public jetty id: " << mUrmaJetty->jetty_id.id << ", local eid: " << + EidToStr(mUrmaJetty->jetty_id.eid) << "; Remote public jetty id: " << mTargetJetty->id.id << ", remote eid: " << + EidToStr(mTargetJetty->id.eid)); + + return UB_OK; +} + +void UBPublicJetty::FillJfsCfg(urma_jfs_cfg_t *jfs_cfg) +{ + jfs_cfg->user_ctx = reinterpret_cast(this); + jfs_cfg->jfc = mSendJfc->mUrmaJfc; + jfs_cfg->trans_mode = URMA_TM_RM; + jfs_cfg->depth = JETTY_MAX_SEND_WR; + jfs_cfg->max_sge = static_cast(mUBContext->mMaxSge); + jfs_cfg->flag.value = 0; + jfs_cfg->flag.bs.multi_path = 1; +} + +void UBPublicJetty::FillJfrCfg(urma_jfr_cfg_t *jfr_cfg) +{ + jfr_cfg->user_ctx = reinterpret_cast(this); + jfr_cfg->jfc = mRecvJfc->mUrmaJfc; + jfr_cfg->trans_mode = URMA_TM_RM; + jfr_cfg->depth = JETTY_MAX_RECV_WR; + jfr_cfg->max_sge = static_cast(mUBContext->mMaxSge); + jfr_cfg->id = 0; + jfr_cfg->flag.bs.tag_matching = URMA_NO_TAG_MATCHING; +} + +// create a public jetty +UResult UBPublicJetty::CreateUrmaPublicJetty(uint32_t id) +{ + if (mUBContext == nullptr || mUBContext->mUrmaContext == nullptr || mSendJfc == nullptr || + mSendJfc->mUrmaJfc == nullptr) { + NN_LOG_ERROR("Invalid parameter for jetty creating"); + return UB_PARAM_INVALID; + } + + // jfs cfg + urma_jfs_cfg_t jfs_cfg{}; + FillJfsCfg(&jfs_cfg); + // jfr cfg + urma_jfr_cfg_t jfr_cfg{}; + FillJfrCfg(&jfr_cfg); + // jetty flag + urma_jetty_flag_t jetty_flag{}; + jetty_flag.bs.share_jfr = 1; + // jetty cfg + urma_jetty_cfg_t jetty_cfg{}; + jetty_cfg.id = id; // 非0,公知jetty + jetty_cfg.flag = jetty_flag; + jetty_cfg.jfs_cfg = jfs_cfg; + jetty_cfg.jfr_cfg = &jfr_cfg; + // create jetty + urma_jetty_t *tmpJetty = nullptr; + mJfr = HcomUrma::CreateJfr(mUBContext->mUrmaContext, &jfr_cfg); + if (mJfr == nullptr) { + NN_LOG_ERROR("urma create jfr failed"); + return UB_PARAM_INVALID; + } + + jetty_cfg.shared.jfc = mRecvJfc->mUrmaJfc; + jetty_cfg.shared.jfr = mJfr; + tmpJetty = HcomUrma::CreateJetty(mUBContext->mUrmaContext, &jetty_cfg); + if (tmpJetty == nullptr) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create urma jetty for UBJetty " << mName << ", errno " << + NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + HcomUrma::DeleteJfr(mJfr); + mJfr = nullptr; + return UB_QP_CREATE_FAILED; + } + mUrmaJetty = tmpJetty; + mUrmaJettyId = mUrmaJetty->jetty_id.id; + NN_LOG_INFO("Create public jetty success, jetty id: " << mUrmaJettyId << ", local eid: " << + EidToStr(mUrmaJetty->jetty_id.eid) << ", jfr id: " << mJfr->jfr_id.id << ", recv jfc id: " << + mRecvJfc->mUrmaJfc->jfc_id.id << ", send jfc id: " << mSendJfc->mUrmaJfc->jfc_id.id << ", multi_path: " << + mUrmaJetty->jetty_cfg.jfs_cfg.flag.bs.multi_path); + return UB_OK; +} + +// create jetty mr for public jetty +UResult UBPublicJetty::CreateJettyMr() +{ + NResult result = NN_OK; + uint32_t segCount = isServer ? NN_NO32 : NN_NO8; + // create mr pool for send/receive and initialize + if ((result = UBMemoryRegionFixedBuffer::Create(mName, mUBContext, PUBLIC_JETTY_SEG_SIZE, segCount, 0, mJettyMr)) + != 0) { + NN_LOG_ERROR("Failed to create mr for send/receive in public jetty " << mName << ", result " << result); + return result; + } + mJettyMr->IncreaseRef(); + if ((result = mJettyMr->Initialize()) != 0) { + NN_LOG_ERROR("Failed to initialize mr for send/receive in public jetty " << mName << ", result " << result); + return result; + } + + return UB_OK; +} + +UResult UBPublicJetty::CreateCtxInfoPool() +{ + uint16_t blkSize = NN_NextPower2(sizeof(UBOpContextInfo)); + uint16_t blkCnt = isServer ? NN_NO32 : NN_NO8; + + mCtxInfoPool = new (std::nothrow) UBFixedMemPool(blkSize, blkCnt); + if (mCtxInfoPool == nullptr) { + NN_LOG_ERROR("Failed to create context info pool for public jetty probably out of memory"); + return UB_MEMORY_ALLOCATE_FAILED; + } + mCtxInfoPool->IncreaseRef(); + auto result = mCtxInfoPool->Initialize(); + if (result != UB_OK) { + NN_LOG_ERROR("Failed to initialize context info pool for public jetty"); + mCtxInfoPool->UnInitialize(); + } + + return result; +} + +// inirialzie public jetty +UResult UBPublicJetty::InitializePublicJetty(uint32_t id) +{ + auto result = CreateJettyMr(); + if (result != UB_OK) { + NN_LOG_ERROR("Failed to create jetty mr in public jetty"); + return result; + } + if ((result = CreateCtxInfoPool()) != UB_OK) { + NN_LOG_ERROR("Failed to create context info pool in public jetty"); + return result; + } + result = CreateUrmaPublicJetty(id); + if (result != UB_OK) { + NN_LOG_ERROR("Failed to create um jetty in public jetty"); + return result; + } + if (isServer) { + mThreadPool = new (std::nothrow) UBThreadPool(NN_NO16); + if (NN_UNLIKELY(mThreadPool == nullptr)) { + NN_LOG_ERROR("Create ub thread pool failed"); + return UB_ERROR; + } + mThreadPool->Initialize(); + } + return UB_OK; +} + +// start public jetty +UResult UBPublicJetty::StartPublicJetty() +{ + if (mIsStarted) { + return UB_OK; + } + mIsStarted = true; + uintptr_t mrBufAddress = 0; + uint32_t prePostCount = isServer ? NN_NO32 : NN_NO4; + auto *mrSegs = new (std::nothrow) uintptr_t[prePostCount]; + if (mrSegs == nullptr) { + NN_LOG_ERROR("Failed to create mr address array in Driver " << mName << ", probably out of memory"); + return UB_NEW_OBJECT_FAILED; + } + NetLocalAutoFreePtr segAutoDelete(mrSegs, true); + if (!mJettyMr->GetFreeBufferN(mrSegs, prePostCount)) { + NN_LOG_ERROR("failed to get free mr from pool, mr is not enough"); + return UB_MEMORY_ALLOCATE_FAILED; + } + urma_target_seg_t *localSeg = reinterpret_cast(mJettyMr->GetMemorySeg()); + uint32_t i = 0; + for (; i < prePostCount; i++) { + uintptr_t buf = 0; + if (NN_UNLIKELY(!mCtxInfoPool->GetFreeBuffer(buf))) { + NN_LOG_ERROR("Failed to get a free context info buffer from pool"); + mJettyMr->ReturnBuffer(mrBufAddress); + return UB_MEMORY_ALLOCATE_FAILED; + } + auto *ctx = reinterpret_cast(buf); + bzero(ctx, sizeof(UBOpContextInfo)); + ctx->mrMemAddr = mrSegs[i]; + ctx->dataSize = PUBLIC_JETTY_SEG_SIZE; + ctx->localSeg = localSeg; + ctx->opType = UBOpContextInfo::RECEIVE; + ctx->opResultType = UBOpContextInfo::SUCCESS; + + if (PostReceive(mrSegs[i], PUBLIC_JETTY_SEG_SIZE, localSeg, reinterpret_cast(ctx)) != 0) { + NN_LOG_ERROR("Failed to postrecv in start public jetty"); + mJettyMr->ReturnBuffer(ctx->mrMemAddr); + mCtxInfoPool->ReturnBuffer(reinterpret_cast(ctx)); + return UB_ERROR; + } + } + if (isServer) { + mNeedStop = false; + std::thread tmpThread(&UBPublicJetty::RunInThread, this); + mPublicJettyPollingThread = std::move(tmpThread); + } + + return UB_OK; +} + +int UBPublicJetty::NewRequest(UBOpContextInfo *ctx) +{ + if (NN_UNLIKELY(ctx == nullptr || ctx->mrMemAddr == 0)) { + NN_LOG_ERROR("Ctx or mrMemAddr is null in public jetty"); + return NN_ERROR; + } + auto exchangeInfo = reinterpret_cast(ctx->mrMemAddr); + auto msgType = exchangeInfo->msgType; + switch (msgType) { + case (UrmaConnectMsgType::CONNECT_REQ): + mNewConnectionHandler(ctx); + break; + case (UrmaConnectMsgType::EXCHANGE_MSG): + break; + default: + NN_LOG_ERROR("exchangeInfo invalid msgType " << exchangeInfo->msgType); + } + + ctx->opType = UBOpContextInfo::RECEIVE; + ctx->opResultType = UBOpContextInfo::SUCCESS; + + if (PostReceive(ctx->mrMemAddr, PUBLIC_JETTY_SEG_SIZE, GetMemorySeg(), reinterpret_cast(ctx)) != 0) { + NN_LOG_ERROR("Failed to post receive in new request handler"); + mJettyMr->ReturnBuffer(ctx->mrMemAddr); + mCtxInfoPool->ReturnBuffer(reinterpret_cast(ctx)); + return UB_QP_POST_RECEIVE_FAILED; + } + return UB_OK; +} + +int UBPublicJetty::SendFinished(UBOpContextInfo *ctx) +{ + if (NN_UNLIKELY(!ReturnBuffer(ctx->mrMemAddr))) { + NN_LOG_ERROR("Failed to return buffer mr to jetty mr pool"); + } + + if (NN_UNLIKELY(!mCtxInfoPool->ReturnBuffer(reinterpret_cast(ctx)))) { + NN_LOG_ERROR("Failed to return context info in public jetty"); + } + + return UB_OK; +} + +void UBPublicJetty::ProcessWorkerCompletion(UBOpContextInfo *ctx) +{ + NN_LOG_INFO("Start process worker completion thread id: " << pthread_self()); + switch (ctx->opType) { + case (UBOpContextInfo::OpType::RECEIVE): + NewRequest(ctx); + break; + default: + NN_LOG_ERROR("Poll cq invalid OpType " << ctx->opType); + } + NN_LOG_INFO("End process worker completion thread id: " << pthread_self()); +} + +void UBPublicJetty::ProcessPollingResult(urma_cr_t &wc) +{ + UBOpContextInfo *ctx = nullptr; + ctx = reinterpret_cast(wc.user_ctx); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Ctx is null in public jetty polling"); + return; + } + ctx->dataSize = wc.completion_len; + if (mThreadPool == nullptr) { + NN_LOG_ERROR("Failed to submit conn task as ub thrad pool not initialize"); + return; + } + // optimize to thread poll in next version + mThreadPool->Submit([this, ctx]() { + this->ProcessWorkerCompletion(ctx); + }); +} + +// public jetty polling thread +void UBPublicJetty::RunInThread() +{ + NN_LOG_INFO("OOB server public jetty accept thread started success, load balancer " << + (mWorkerLb == nullptr ? "null" : mWorkerLb->ToString())); + urma_cr_t wc{}; + uint32_t pollCount = 0; + while (!mNeedStop) { + try { + pollCount = 1; + // avoid urma event poll zero cqe bug + mRecvJfc->ProgressV(&wc, pollCount); + if (pollCount != 0) { + ProcessPollingResult(wc); + } + usleep(NN_NO100000); // 100ms + } catch (std::runtime_error &ex) { + NN_LOG_WARN("Got runtime incorrect signal in UBWorker::RunInThread '" << ex.what() << + "', ignore and continue"); + } catch (...) { + NN_LOG_WARN("Got unknown signal in UBWorker::RunInThread, ignore and continue"); + } + } +} + +inline void FillSendWr(urma_jfs_wr_t &wr, uint64_t ctx, urma_sge_t *localSge, urma_target_jetty_t *targetJetty) +{ + wr.user_ctx = reinterpret_cast(ctx); + wr.send.src.sge = localSge; + wr.send.src.num_sge = 1; + wr.send.imm_data = 0; + wr.next = nullptr; + wr.opcode = URMA_OPC_SEND; + wr.flag.bs.complete_enable = 1; + wr.tjetty = targetJetty; +} +// send a message to target jetty +UResult UBPublicJetty::SendByPublicJetty(const void *buf, uint32_t size) +{ + if (NN_UNLIKELY(mUrmaJetty == nullptr || mTargetJetty == nullptr)) { + NN_LOG_ERROR("Failed to send by public jetty as local jetty or target jetty is nullptr"); + return UB_QP_NOT_INITIALIZED; + } + + uintptr_t mrBufAddress = 0; + NResult res = NN_OK; + if (!mJettyMr->GetFreeBuffer(mrBufAddress)) { + NN_LOG_ERROR("failed to get free mr from pool, mr is not enough"); + return UB_MEMORY_ALLOCATE_FAILED; + } + urma_target_seg_t *localSeg = GetMemorySeg(); + if (NN_UNLIKELY(memcpy_s(reinterpret_cast(mrBufAddress), PUBLIC_JETTY_SEG_SIZE, buf, size) != 0)) { + NN_LOG_ERROR("Failed to copy oob port range"); + return UB_PARAM_INVALID; + } + + uintptr_t ctxBuffer = 0; + if (NN_UNLIKELY(!mCtxInfoPool->GetFreeBuffer(ctxBuffer))) { + NN_LOG_ERROR("Failed to get a free context info buffer from pool"); + mJettyMr->ReturnBuffer(mrBufAddress); + return UB_MEMORY_ALLOCATE_FAILED; + } + auto *ctx = reinterpret_cast(ctxBuffer); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to get ctx in public jetty"); + return UB_QP_CTX_FULL; + } + ctx->mrMemAddr = mrBufAddress; + ctx->dataSize = size; + ctx->localSeg = localSeg; + ctx->opType = UBOpContextInfo::SEND; + ctx->opResultType = UBOpContextInfo::SUCCESS; + + urma_jfs_wr_t *bad_wr; + urma_sge_t local_sge{}; + local_sge.addr = mrBufAddress; + local_sge.len = size; + local_sge.tseg = localSeg; + + urma_jfs_wr_t wr{}; + FillSendWr(wr, reinterpret_cast(ctx), &local_sge, mTargetJetty); + + auto ret = HcomUrma::PostJettySendWr(mUrmaJetty, &wr, &bad_wr); + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_ERROR("Failed to post send request to public jetty " << mName << ", result " << ret); + mJettyMr->ReturnBuffer(mrBufAddress); + mCtxInfoPool->ReturnBuffer(ctxBuffer); + return UB_QP_POST_SEND_FAILED; + } + + return UB_OK; +} + +UResult UBPublicJetty::PostReceive(uintptr_t bufAddr, uint32_t bufSize, urma_target_seg_t *localSeg, uint64_t context) +{ + if (NN_UNLIKELY(mUrmaJetty == nullptr || bufAddr == 0 || bufSize == 0 || localSeg == nullptr)) { + NN_LOG_ERROR("Failed to postrecv as mUrmaJetty or bufAddr or bufSize or localSeg is null"); + return NN_INVALID_PARAM; + } + + urma_jfr_wr_t *bad_wr; + + urma_sge_t local_sge{}; + local_sge.addr = bufAddr; + local_sge.len = bufSize; + local_sge.tseg = localSeg; + + urma_jfr_wr_t wr{}; + wr.src.sge = &local_sge; + wr.src.num_sge = 1; + wr.next = nullptr; + wr.user_ctx = context; + + NN_LOG_DEBUG("[Post Buffer] ------ urma_post_jetty_recv_wr2, jetty id: " << mUrmaJetty->jetty_id.id << + ", jfc id: " << mRecvJfc->mUrmaJfc->jfc_id.id); + auto ret = HcomUrma::PostJettyRecvWr(mUrmaJetty, &wr, &bad_wr); + if (NN_UNLIKELY(ret != 0)) { + NN_LOG_ERROR("Failed to post receive request to jetty " << mName << ", result " << ret); + return UB_QP_POST_RECEIVE_FAILED; + } + + return UB_OK; +} + +UResult UBPublicJetty::CheckRecvResult(urma_cr_t wc, uint32_t size, UResult result, uint32_t pollCount, + int32_t timeoutInMs) +{ + // 若pollCount == 0,大概率是超时还未收到事件,返回失败 + if (pollCount == 0) { + NN_LOG_ERROR("polled 0 cqe, jetty id: " << mUrmaJetty->jetty_id.id << ", jfc id: " << + mRecvJfc->mUrmaJfc->jfc_id.id << "timeout: " << timeoutInMs << " ms"); + return UB_CQ_POLLING_FAILED; + } + + // 若pollCount非0,判断result + if (NN_UNLIKELY(result != UB_OK)) { + NN_LOG_ERROR("Failed to event polling in public jetty Receive res = " << result << ", polling timeout " << + timeoutInMs << " ms"); + return result; + } + + if (NN_UNLIKELY(wc.status != URMA_CR_SUCCESS)) { + NN_LOG_ERROR("Poll cq failed in public jetty Receive wcStatus " << wc.status); + return UB_CQ_WC_WRONG; + } + + if (NN_UNLIKELY(wc.completion_len != size)) { + NN_LOG_ERROR("Failed to Receive in public jetty Receive as expect size:" << size << " actual size: " << + wc.completion_len); + return UB_CQ_WC_WRONG; + } + return UB_OK; +} + +UResult UBPublicJetty::Receive(void *buf, uint32_t size) +{ + if (NN_UNLIKELY(buf == nullptr || size == 0)) { + NN_LOG_ERROR("Failed to Receive as invalid param"); + return UB_PARAM_INVALID; + } + + UResult result = UB_OK; + urma_cr_t wc{}; + int32_t timeoutInMs = TimeSecToMs(mPollTimeout); + uint32_t pollCount = 1; + + // avoid urma event poll zero cqe bug + auto start = NetMonotonic::TimeMs(); + int64_t pollTime = 0; + do { + pollCount = 1; + result = mRecvJfc->ProgressV(&wc, pollCount); + pollTime = (int64_t)(NetMonotonic::TimeMs() - start); + if (pollCount == 0 && timeoutInMs >= 0 && pollTime > timeoutInMs) { + NN_LOG_ERROR("Busy poll failed pollCount = " << pollCount << " pollTime = " << pollTime << " in recv"); + return UB_CQ_EVENT_GET_TIMOUT; + } + usleep(NN_NO100000); // 100ms + } while (result == UB_OK && pollCount == 0); + if (CheckRecvResult(wc, size, result, pollCount, timeoutInMs) != UB_OK) { + return UB_ERROR; + } + + UBOpContextInfo *ctx = reinterpret_cast(wc.user_ctx); + if (ctx == nullptr || ctx->mrMemAddr == 0) { + NN_LOG_ERROR("Failed to Receive as ctx is nullptr"); + return UB_ERROR; + } + + if (NN_UNLIKELY(memcpy_s(buf, size, (void *)(ctx->mrMemAddr), size) != SER_OK)) { + NN_LOG_ERROR("Failed to copy data"); + return UB_ERROR; + } + // postrecv and return ctx + ctx->opType = UBOpContextInfo::RECEIVE; + ctx->opResultType = UBOpContextInfo::SUCCESS; + + if (PostReceive(ctx->mrMemAddr, PUBLIC_JETTY_SEG_SIZE, GetMemorySeg(), reinterpret_cast(ctx)) != 0) { + NN_LOG_ERROR("Failed to post receive in jetty receive"); + mJettyMr->ReturnBuffer(ctx->mrMemAddr); + mCtxInfoPool->ReturnBuffer(reinterpret_cast(ctx)); + return UB_QP_POST_RECEIVE_FAILED; + } + return UB_OK; +} + +UResult UBPublicJetty::PollingCompletion() +{ + if (NN_UNLIKELY(mRecvJfc == nullptr || mRecvJfc->mUrmaJfc == nullptr)) { + NN_LOG_ERROR("Failed to polling completion with public jetty as jfc is null"); + return UB_EP_NOT_INITIALIZED; + } + int32_t timeoutInMs = TimeSecToMs(mPollTimeout); + urma_cr_t wc{}; + uint32_t pollCount = 1; + NResult result = UB_OK; + + // avoid urma event poll zero cqe bug + auto start = NetMonotonic::TimeMs(); + int64_t pollTime = 0; + do { + pollCount = 1; + result = mRecvJfc->ProgressV(&wc, pollCount); + pollTime = (int64_t)(NetMonotonic::TimeMs() - start); + if (pollCount == 0 && timeoutInMs >= 0 && pollTime > timeoutInMs) { + NN_LOG_ERROR("Busy poll completion failed pollCount = " << pollCount << " pollTime = " << pollTime); + return UB_CQ_EVENT_GET_TIMOUT; + } + usleep(NN_NO100000); // 100ms + } while (result == UB_OK && pollCount == 0); + // 若pollCount == 0,大概率是超时还未收到事件,返回失败 + if (pollCount == 0) { + NN_LOG_ERROR("polled 0 cqe, jetty id: " << mUrmaJetty->jetty_id.id << ", jfc id: " << + mRecvJfc->mUrmaJfc->jfc_id.id << "timeout: " << timeoutInMs << " ms"); + return UB_CQ_POLLING_FAILED; + } + + // 若pollCount非0,判断result + if (NN_UNLIKELY(result != UB_OK)) { + NN_LOG_ERROR("Failed to epoll in jetty recv res = " << result << ", polling timeout " << timeoutInMs << " ms"); + return result; + } + + if (NN_UNLIKELY(wc.status != URMA_CR_SUCCESS)) { + NN_LOG_WARN("Poll cq failed in public jetty wcStatus " << wc.status); + if (NN_UNLIKELY(wc.status == URMA_CR_WR_SUSPEND_DONE || wc.status == URMA_CR_WR_FLUSH_ERR_DONE)) { + NN_LOG_WARN("Polled a fake cqe in public jetty"); + return UB_ERROR; + } + } + + auto ctx = reinterpret_cast(wc.user_ctx); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Ctx is null in public jetty polling completion"); + return UB_ERROR; + } + + if (NN_UNLIKELY(!ReturnBuffer(ctx->mrMemAddr))) { + NN_LOG_ERROR("Failed to return buffer mr to jetty mr pool"); + } + + if (NN_UNLIKELY(!mCtxInfoPool->ReturnBuffer(reinterpret_cast(ctx)))) { + NN_LOG_ERROR("Failed to return context info in public jetty"); + } + + return UB_OK; +} + +void UBPublicJetty::Stop() +{ + std::lock_guard lock(mStopMutex); + if (!mIsStarted) { + return; + } + mIsStarted = false; + mNeedStop = true; + if (mPublicJettyPollingThread.joinable()) { + mPublicJettyPollingThread.join(); + } + if (mThreadPool != nullptr) { + mThreadPool->Stop(); + delete mThreadPool; + mThreadPool = nullptr; + } + int result = 0; + if (mUrmaJetty != nullptr) { + struct urma_jetty_attr attr = {}; + attr.mask = JETTY_STATE; + attr.state = URMA_JETTY_STATE_ERROR; + result = HcomUrma::ModifyJetty(mUrmaJetty, &attr); + if (result != 0) { + NN_LOG_ERROR("Failed to modify jetty to URMA_JETTY_STATE_ERROR, urma result = " << result); + } + } + if (mTargetJetty != nullptr) { + result = HcomUrma::UnimportJetty(mTargetJetty); + mTargetJetty = nullptr; + if (result != 0) { + NN_LOG_ERROR("Failed to unimport target jetty, urma result = " << result); + } + } + if (mUrmaJetty != nullptr) { + result = HcomUrma::DeleteJetty(mUrmaJetty); + mUrmaJetty = nullptr; + if (result != 0) { + NN_LOG_ERROR("Failed to delete jetty, urma result = " << result); + } else { + NN_LOG_INFO("Delete public jetty success, jetty id " << mUrmaJettyId); + } + } + if (mJfr != nullptr) { + result = HcomUrma::DeleteJfr(mJfr); + mJfr = nullptr; + if (result != 0) { + NN_LOG_ERROR("Failed to delete jfr, urma result = " << result); + } + } +} + +// public jetty clear resource +UResult UBPublicJetty::UnInitialize() +{ + Stop(); + if (mJettyMr != nullptr) { + mJettyMr->DecreaseRef(); + mJettyMr = nullptr; + } + if (mCtxInfoPool != nullptr) { + mCtxInfoPool->DecreaseRef(); + mCtxInfoPool = nullptr; + } + if (mSendJfc != nullptr) { + mSendJfc->DecreaseRef(); + } + + if (mRecvJfc != nullptr && mRecvJfc != mSendJfc) { + mRecvJfc->DecreaseRef(); + } + mSendJfc = nullptr; + mRecvJfc = nullptr; + + if (mUBContext != nullptr) { + mUBContext->DecreaseRef(); + } + + if (mWorkerLb.Get() != nullptr) { + mWorkerLb.Set(nullptr); + } + NN_LOG_INFO("Uninitialize public jetty success, jetty id: " << mUrmaJettyId); + return UB_OK; +} + +// get a free buffer from mr +bool UBPublicJetty::GetFreeBuff(uintptr_t &item) +{ + return mJettyMr->GetFreeBuffer(item); +} + +// get N free buffers from mr +bool UBPublicJetty::GetFreeBufferN(uintptr_t *&items, uint32_t n) +{ + return mJettyMr->GetFreeBufferN(items, n); +} + +// return a free buffer to mr +bool UBPublicJetty::ReturnBuffer(uintptr_t value) +{ + return mJettyMr->ReturnBuffer(value); +} + +urma_target_seg_t *UBPublicJetty::GetMemorySeg() +{ + return reinterpret_cast(mJettyMr->GetMemorySeg()); +} +} // namespace hcom +} // namespace ock +#endif \ No newline at end of file diff --git a/src/transport/ub/ub_urma_wrapper_public_jetty.h b/src/transport/ub/ub_urma_wrapper_public_jetty.h new file mode 100644 index 0000000000000000000000000000000000000000..41034debdede53289273cd2fb040a1b53784bf9c --- /dev/null +++ b/src/transport/ub/ub_urma_wrapper_public_jetty.h @@ -0,0 +1,230 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_UB_URMA_WRAPPER_PUBLIC_JETTY_H +#define HCOM_UB_URMA_WRAPPER_PUBLIC_JETTY_H +#ifdef UB_BUILD_ENABLED + +#include "net_oob.h" +#include "net_load_balance.h" +#include "ub_common.h" +#include "ub_fixed_mem_pool.h" +#include "ub_urma_wrapper_jetty.h" +#include "ub_thread_pool.h" + +namespace ock { +namespace hcom { + +enum UrmaConnectMsgType : uint8_t { + CONNECT_REQ = 1, + EXCHANGE_MSG = 2, +}; + +struct JettyConnHeader { + UrmaConnectMsgType msgType; + uint64_t epId = 0; + UBJettyExchangeInfo info{}; + uint32_t controlJettyId = 0; + struct { + uint64_t magic : 16; + uint64_t version : 8; + uint64_t groupIndex : 8; + uint64_t protocol : 8; + uint64_t bandWidth : 8; + uint64_t devIndex : 8; + uint64_t majorVersion : 8; + uint64_t minorVersion : 8; + uint64_t tlsVersion : 16; + uint64_t reserve : 40; + } ConnectHeader; + uint32_t payloadLen = 0; + char payload[1024]; + + inline void SetConnHeader(uint32_t magic, uint32_t version, uint32_t groupIndex, uint32_t protocol, + uint32_t majorVersion, uint32_t minorVersion, uint32_t tlsVersion) + { + ConnectHeader.magic = magic; + ConnectHeader.version = version; + ConnectHeader.groupIndex = groupIndex; + ConnectHeader.protocol = protocol; + ConnectHeader.majorVersion = majorVersion; + ConnectHeader.minorVersion = minorVersion; + ConnectHeader.tlsVersion = tlsVersion; + } +} __attribute__((packed)); + +struct JettyConnResp { + UrmaConnectMsgType msgType; + ConnectResp connResp = OK; + uint64_t epId = 0; + UBJettyExchangeInfo info{}; + uint32_t serverCtrlJettyId = 0; + urma_eid_t serverCtrlEid{}; +} __attribute__((packed)); + +#define PUBLIC_JETTY_SEG_SIZE 2560 + + +class UBPublicJetty { +public: + using NewConnectionHandler = std::function; + + UBPublicJetty(const std::string &name, uint32_t id, UBContext *ctx, UBJfc *jfc, bool isServer = false, + JettyOptions jettyOptions = {}) : mName(name), mId(id), mUBContext(ctx), mSendJfc(jfc), mRecvJfc(jfc), + isServer(isServer), mJettyOptions(jettyOptions) + { + mIsStarted = false; + if (mUBContext != nullptr) { + mUBContext->IncreaseRef(); + } + + if (mSendJfc != nullptr) { + mSendJfc->IncreaseRef(); + } + mPollTimeout = GetPollTimeout(); + OBJ_GC_INCREASE(UBPublicJetty); + } + + ~UBPublicJetty() + { + UnInitialize(); + OBJ_GC_DECREASE(UBPublicJetty); + } + + /* create public(URMA_TM_RM) jetty */ + UResult CreateUrmaPublicJetty(uint32_t id); + UResult InitializePublicJetty(uint32_t id); + UResult CreateCtxInfoPool(); + void ProcessWorkerCompletion(UBOpContextInfo *ctx); + + UResult StartPublicJetty(); + void RunInThread(); + void ProcessPollingResult(urma_cr_t &wc); + int NewRequest(UBOpContextInfo *ctx); + int SendFinished(UBOpContextInfo *ctx); + UResult ImportPublicJetty(const urma_eid_t &remoteEid, uint32_t jettyId); + UResult SendByPublicJetty(const void *buf, uint32_t size); + UResult PollingCompletion(); + UResult Receive(void *buf, uint32_t size); + UResult PostReceive(uintptr_t bufAddr, uint32_t bufSize, urma_target_seg_t *localSeg, uint64_t context); + UResult CheckRecvResult(urma_cr_t wc, uint32_t size, UResult result, uint32_t pollCount, int32_t timeoutInMs); + + UResult CreateJettyMr(); + UResult UnInitialize(); + void Stop(); + inline void SetNewConnCB(const NewConnectionHandler &handler) + { + mNewConnectionHandler = handler; + } + inline UBSHcomNetDriverProtocol GetProtocol() + { + return mUBContext->protocol; + } + + inline uint32_t GetJettyId() + { + if (mUrmaJetty != nullptr) { + return mUrmaJetty->jetty_id.id; + } + return 0; + } + + inline urma_eid_t GetEid() + { + return mUBContext->mBestEid.urmaEid; + } + + inline void SetWorkerLb(NetWorkerLB *lb) + { + if (lb != nullptr) { + mWorkerLb = lb; + } + } + + inline const NetWorkerLBPtr &LoadBalancer() const + { + return mWorkerLb; + } + + bool GetFreeBuff(uintptr_t &item); + bool ReturnBuffer(uintptr_t value); + bool GetFreeBufferN(uintptr_t *&items, uint32_t n); + urma_target_seg_t *GetMemorySeg(); + + DEFINE_RDMA_REF_COUNT_FUNCTIONS + +public: + inline uint32_t QpNum() const + { + if (NN_UNLIKELY(mUrmaJetty == nullptr)) { + return 0xffffffff; + } + + return mUrmaJetty->jetty_id.id; + } + + static uint32_t NewId() + { + return __sync_fetch_and_add(&G_INDEX, 1); + } + +private: + void FillJfsCfg(urma_jfs_cfg_t *jfs_cfg); + void FillJfrCfg(urma_jfr_cfg_t *jfr_cfg); + static long GetPollTimeout() + { + static long timeout = []() { + long res = NetFunc::NN_GetLongEnv("HCOM_UB_CONNECTION_POLL_TIMEOUT", NN_NO1, NN_NO180, NN_NO60); + NN_LOG_INFO("Public jetty polling timeout is " << res << " s"); + return res; + }(); + return timeout; + } + +private: + std::string mName; + bool isServer = false; + std::atomic mIsStarted; + uint32_t mId = 0; + std::mutex mStopMutex; + UBContext *mUBContext = nullptr; + UBJfc *mSendJfc = nullptr; + UBJfc *mRecvJfc = nullptr; + urma_jfr_t *mJfr = nullptr; + JettyOptions mJettyOptions{}; + uint32_t mUrmaJettyId = 0; // mUrmaJetty->jetty_id.id + urma_jetty_t *mUrmaJetty = nullptr; + urma_target_jetty_t *mTargetJetty = nullptr; + UBMemoryRegionFixedBuffer *mJettyMr = nullptr; + UBFixedMemPool *mCtxInfoPool = nullptr; + // public polling thread + std::thread mPublicJettyPollingThread; + bool mNeedStop = true; + NewConnectionHandler mNewConnectionHandler = nullptr; + NetWorkerLBPtr mWorkerLb = nullptr; + UBThreadPool *mThreadPool = nullptr; + + int32_t mOneSideMaxWr = JETTY_MAX_SEND_WR - NN_NO64; + int32_t mOneSideRef = JETTY_MAX_SEND_WR - NN_NO64; + int32_t mPostSendMaxWr = NN_NO64; + int32_t mPostSendMaxSize = NN_NO1024; + int32_t mPostSendRef = NN_NO64; + long mPollTimeout = NN_NO60; + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + static uint32_t G_INDEX; +}; +} +} +#endif +#endif // HCOM_UB_URMA_WRAPPER_PUBLIC_JETTY_H \ No newline at end of file diff --git a/src/transport/ub/ub_worker.h b/src/transport/ub/ub_worker.h new file mode 100644 index 0000000000000000000000000000000000000000..1fb0b117ef31293026ae8e47556bd8fa4b774ef1 --- /dev/null +++ b/src/transport/ub/ub_worker.h @@ -0,0 +1,422 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_UB_WORKER_H +#define HCOM_UB_WORKER_H +#ifdef UB_BUILD_ENABLED + +#include +#include +#include +#include +#include +#include +#include + +#include "net_ctx_info_pool.h" +#include "net_ub_endpoint.h" +#include "ub_urma_wrapper_jetty.h" +#include "ub_jetty_ptr_map.h" + +namespace ock { +namespace hcom { +using UBNewReqHandler = std::function; +using UBPostedHandler = std::function; +using UBOneSideDoneHandler = std::function; + +// when there is no request from cq, call this +using UBIdleHandler = UBSHcomNetDriverIdleHandler; + +using UBOpContextInfoPool = OpContextInfoPool; +using UBSglContextInfoPool = OpContextInfoPool; + +enum UBWorkerType : uint8_t { + UB_SENDER = 0, + UB_RECEIVER = 1, + UB_SENDER_RECEIVER = 2, +}; + +std::string &WorkerTypeToString(UBWorkerType tp); +std::string &PollingModeToString(UBPollingMode m); + +using UBWorkerOptions = struct UBWorkerOptionsStruct { + UBWorkerType workerType = UBWorkerType::UB_RECEIVER; + UBPollingMode workerMode = UBPollingMode::UB_BUSY_POLLING; + int16_t cpuId = -1; + uint16_t completionQueueDepth = NN_NO2048; + uint16_t maxPostSendCountPerQP = NN_NO64; + uint16_t prePostReceiveSizePerQP = NN_NO64; + uint16_t pollingBatchSize = NN_NO4; + uint32_t qpSendQueueSize = NN_NO256; + uint32_t qpReceiveQueueSize = NN_NO256; + uint32_t qpMrSegSize = NN_NO1024; + uint32_t qpMrSegCount = NN_NO64; + uint32_t eventPollingTimeout = NN_NO500; + bool dontStartWorkers = false; + /* worker thread priority [-20,20], 20 is the lowest, -20 is the highest, 0 (default) means do not set priority */ + int threadPriority = 0; + uint8_t slave = 1; + UBSHcomUbcMode ubcMode = UBSHcomUbcMode::LowLatency; + + std::string ToString() const + { + std::ostringstream oss; + oss << "options type: " << WorkerTypeToString(workerType) << ", mode: " << PollingModeToString(workerMode) << + ", jfc size: " << completionQueueDepth << ", max post send: " << maxPostSendCountPerQP << + ", pre-post receive size: " << prePostReceiveSizePerQP << ", poll batch size " << pollingBatchSize << + ", cpu id: " << cpuId << ", jetty send queue: " << qpSendQueueSize << ", jetty receive queue: " << + qpReceiveQueueSize << ", dontStartWorkers: " << dontStartWorkers; + return oss.str(); + } + + void SetValue(const UBSHcomNetDriverOptions &opt) + { + workerType = UBWorkerType::UB_SENDER_RECEIVER; + completionQueueDepth = opt.completionQueueDepth; + maxPostSendCountPerQP = opt.maxPostSendCountPerQP; + prePostReceiveSizePerQP = opt.prePostReceiveSizePerQP; + pollingBatchSize = opt.pollingBatchSize; + if (opt.mode == NET_EVENT_POLLING) { + workerMode = UBPollingMode::UB_EVENT_POLLING; + } else if (opt.mode == NET_BUSY_POLLING) { + workerMode = UBPollingMode::UB_BUSY_POLLING; + } + qpSendQueueSize = opt.qpSendQueueSize; + qpReceiveQueueSize = opt.qpReceiveQueueSize; + qpMrSegSize = opt.mrSendReceiveSegSize; + qpMrSegCount = opt.prePostReceiveSizePerQP; + eventPollingTimeout = opt.eventPollingTimeout; + dontStartWorkers = opt.dontStartWorkers; + threadPriority = opt.workerThreadPriority; + slave = opt.slave; + ubcMode = opt.ubcMode; + } +}; + +class UBWorker { +public: + UBWorker(const std::string &name, UBContext *ctx, const UBWorkerOptions &options, const NetMemPoolFixedPtr &memPool, + const NetMemPoolFixedPtr &sglMemPool); + + virtual ~UBWorker() + { + UnInitialize(); + OBJ_GC_DECREASE(UBWorker); + } + + UResult Initialize(); + UResult UnInitialize(); + UResult ReInitializeCQ(); + + UResult Start(); + UResult Stop(); + + inline void SetIndex(const UBSHcomNetWorkerIndex &value) + { + mIndex = value; + } + + inline const UBSHcomNetWorkerIndex &Index() const + { + return mIndex; + } + + inline bool IsWorkStarted(uint32_t timeOutSecond = NN_NO8) + { + uint64_t count = static_cast(timeOutSecond) * NN_NO1000000 / NN_NO100; + while (--count > 0 && !mProgressThreadStarted.load()) { + usleep(NN_NO100); + } + + if (count > 0) { + return true; + } else { + return false; + } + } + + UResult CreateQP(UBJetty *&qp); + + UResult PostReceive(UBJetty *qp, uintptr_t bufAddress, uint32_t bufSize, urma_target_seg_t *localSeg); + UResult PostSend(UBJetty *qp, const UBSendReadWriteRequest &req, urma_target_seg_t *localSeg, uint32_t immData = 0); + UResult PostSendSglInline(UBJetty *qp, const UBSendSglInlineHeader &header, const UBSendReadWriteRequest &req, + uint32_t immData = 0); + UResult PostSendSgl(UBJetty *qp, const UBSHcomNetTransSglRequest &req, const UBSHcomNetTransRequest &tlsReq, + uint32_t immData, bool isEncrypted); + UResult PostRead(UBJetty *qp, const UBSendReadWriteRequest &req); + UResult PostWrite(UBJetty *qp, const UBSendReadWriteRequest &req, + UBOpContextInfo::OpType type = UBOpContextInfo::WRITE); + UResult RePostReceive(UBOpContextInfo *ctx); + UResult CreateOneSideCtx(const UBSgeCtxInfo &sgeInfo, const UBSHcomNetTransSgeIov *iov, uint32_t iovCount, + uint64_t (&ctxArr)[NET_SGE_MAX_IOV], bool isRead); + UResult PostOneSideSgl(UBJetty *qp, const UBSendSglRWRequest &req, bool isRead = true); + + inline UBOpContextInfo *GetOpContextInfo() + { + return mOpCtxInfoPool.Get(); + } + + inline void ReturnOpContextInfo(UBOpContextInfo *&ctx) + { + if (NN_LIKELY(ctx != nullptr)) { + if (NN_LIKELY(ctx->ubJetty != nullptr)) { + ctx->ubJetty->DecreaseRef(); + } + mOpCtxInfoPool.Return(ctx); + ctx = nullptr; + } + } + + inline void ReturnSglContextInfo(UBSglContextInfo *&ctx) + { + if (NN_LIKELY(ctx != nullptr)) { + mSglCtxInfoPool.Return(ctx); + ctx = nullptr; + } + } + + inline void RegisterNewRequestHandler(const UBNewReqHandler &handler) + { + mNewRequestHandler = handler; + } + + inline void RegisterPostedHandler(const UBPostedHandler &handler) + { + mSendPostedHandler = handler; + } + + inline void RegisterOneSideDoneHandler(const UBOneSideDoneHandler &handler) + { + mOneSideDoneHandler = handler; + } + + inline void RegisterIdleHandler(const UBIdleHandler &handler) + { + mIdleHandler = handler; + } + + inline const std::string &Name() const + { + return mName; + } + + std::string DetailName() const + { + std::ostringstream oss; + oss << "[name: " << mName << ", index: " << mIndex.ToString() << "]"; + return oss.str(); + } + + inline uint8_t PortNum() const + { + return mUBContext->mPortNumber; + } + + DEFINE_RDMA_REF_COUNT_FUNCTIONS +public: + static UResult Create(const std::string &name, UBContext *ctx, const UBWorkerOptions &options, + NetMemPoolFixedPtr memPool, NetMemPoolFixedPtr sglMemPool, UBWorker *&outWorker); + +protected: + void RunInThread(); + void DoWithBusyPolling(); + void DoWithCQEventPolling(); + +protected: + std::string mName; + UBSHcomNetWorkerIndex mIndex{}; + UBContext *mUBContext = nullptr; + UBJfc *mUBJfc = nullptr; + NetMemPoolFixedPtr mOpCtxMemPool = nullptr; + NetMemPoolFixedPtr mSglCtxMemPool = nullptr; + bool mInited = false; + + UBWorkerOptions mOptions{}; + + // variable for thread + std::thread mProgressThread; + std::atomic_bool mProgressThreadStarted; + int16_t mProgressCpuId = -1; + bool mNeedStop = false; + + UBOpContextInfoPool mOpCtxInfoPool; + UBSglContextInfoPool mSglCtxInfoPool; + + // request process related + UBNewReqHandler mNewRequestHandler = nullptr; + + // send request posted process related + UBPostedHandler mSendPostedHandler = nullptr; + + // one side done related + UBOneSideDoneHandler mOneSideDoneHandler = nullptr; + + // no request will this + UBIdleHandler mIdleHandler = nullptr; + + uint32_t mProgressBatchSize = NN_NO4; + + JettyPtrMap mJettyPtrMap; ///< ID -> UBJetty* 映射表,仅出错后开始记录 + + DEFINE_RDMA_REF_COUNT_VARIABLE; + + friend class UBJetty; + +private: + inline __attribute__((always_inline)) bool BusyPolling(urma_cr_t *wc, uint32_t &pollCount) + { + if (NN_UNLIKELY(mUBJfc->ProgressV(wc, pollCount) != UB_OK)) { + return true; + } + return false; + } + + inline __attribute__((always_inline)) bool CqEventPolling(urma_cr_t *wc, uint32_t &pollCount, uint32_t pollTimeOut) + { + if (NN_UNLIKELY(mUBJfc->EventProgressV(wc, pollCount, pollTimeOut) != UB_OK)) { + if (mIdleHandler != nullptr) { + mIdleHandler(mIndex); + } + return true; + } + return false; + } + + inline __attribute__((always_inline)) void ProcessPollingResult(urma_cr_t *wc, uint32_t pollCount, + UBJetty *&lastBrokenQp, urma_cr_status_t &lastErrorWcStatus) + { + for (uint32_t i = 0; i < pollCount; i++) { + const uint32_t jettyId = wc[i].local_id; + + // SQE 被硬件处理时同时 modify jetty error 了 + if (wc[i].status == URMA_CR_WR_FLUSH_ERR) { + NN_LOG_DEBUG("SQE flushed, jetty id: " << wc[i].local_id); + continue; + } + + // 按照先 modify jfr error, 再 modify jetty error 的顺序可以保证 FLUSH_ERR_DONE 必定为最后第一个错误,后续 + // 不会出现正常的CQE,之前所有的正常 Post 的资源统一在 FLUSH_ERR_DONE 时回收。 + // \see UBJetty::Stop() + if (wc[i].status == URMA_CR_WR_FLUSH_ERR_DONE || wc[i].status == URMA_CR_WR_SUSPEND_DONE) { + UBJetty *jetty = mJettyPtrMap.Lookup(jettyId); + if (jetty == nullptr) { + NN_LOG_WARN("The jetty id " << jettyId << " has no associated UBJetty"); + continue; + } + jetty->Cleanup(); + mJettyPtrMap.Clear(jettyId); + + // 如果在创建 EP 过程中失败,则 UBJetty 无对应 EP, 依赖 ClearJettyResource做清理。 + // \see ClearJettyResource + auto ep = reinterpret_cast(jetty->GetUpContext()); + if (ep != nullptr) { + // EP 存在时,driver必定存在。 + auto *driver = ep->GetDriver(); + + // 从全局 EP 表中删除 EP. + UBSHcomNetEndpointPtr nep(ep); + driver->DestroyEndpoint(nep); + } + continue; + } + + UBOpContextInfo *contextInfo = reinterpret_cast(wc[i].user_ctx); + contextInfo->opResultType = UBOpContextInfo::OpResult(wc[i]); + switch (contextInfo->ubJetty->State()) { + case UBJettyState::READY: + break; + + // 已经处于 error 状态,需要等到 FLUSH_ERR_DONE 进行资源回收 + case UBJettyState::ERROR: + continue; + + case UBJettyState::RESET: + NN_LOG_ERROR("Unreachable: A jetty with reset state is unable to recv/send. Something went wrong."); + break; + } + + CheckPollingResult(*contextInfo, wc[i], lastBrokenQp, lastErrorWcStatus); + if (!contextInfo->HasInternalError()) { + // detach the context + contextInfo->ubJetty->RemoveOpCtxInfo(contextInfo); + } + + auto ep = reinterpret_cast(contextInfo->ubJetty->GetUpContext()); + if (ep == nullptr) { + NN_LOG_ERROR("Unreachable: A jetty received message with no EP bound"); + continue; + } + if (wc[i].status == URMA_CR_SUCCESS) { + ep->UpdateTargetHbTime(); + } + DispatchByContexetInfoType(*contextInfo, wc[i]); + } + /* if there is no coming request, call up idle function */ + if (mIdleHandler != nullptr && (pollCount) == 0) { + mIdleHandler(mIndex); + } + } + + inline __attribute__((always_inline)) void CheckPollingResult(UBOpContextInfo &contextInfo, urma_cr_t &wc, + UBJetty *&lastBrokenQp, urma_cr_status_t &lastErrorWcStatus) + { + if (NN_UNLIKELY(wc.status == URMA_CR_SUCCESS)) { + return; + } + if (contextInfo.opType == UBOpContextInfo::HB_WRITE) { + lastBrokenQp = contextInfo.ubJetty; + NN_LOG_INFO("HB poll cq receive wcStatus " << wc.status << ", maybe remote ep " << + contextInfo.ubJetty->GetUpId() << " closed"); + } else if (lastBrokenQp != contextInfo.ubJetty) { + lastBrokenQp = contextInfo.ubJetty; + NN_LOG_ERROR("Poll cq failed in UBWorker " << DetailName() << ", wcStatus " << wc.status << ", opType " << + (uint32_t)(contextInfo.opType) << ", ep id = " << contextInfo.ubJetty->GetUpId() << ", context = " << + (uint64_t)(&contextInfo) << ", mrMemAddr = " << contextInfo.mrMemAddr); + } else if (lastErrorWcStatus != wc.status) { + lastErrorWcStatus = wc.status; + NN_LOG_ERROR("Poll cq failed in UBWorker " << DetailName() << ", wc Status " << wc.status << ", opType " << + (uint32_t)contextInfo.opType << ", ep id = " << contextInfo.ubJetty->GetUpId() << ", context = " << + (uint64_t)(&contextInfo) << ", mrMemAddr = " << contextInfo.mrMemAddr); + } + } + + inline __attribute__((always_inline)) void DispatchByContexetInfoType(UBOpContextInfo &contextInfo, urma_cr_t &wc) + { + switch (contextInfo.opType) { + case (UBOpContextInfo::OpType::SEND): + case (UBOpContextInfo::OpType::SEND_RAW): + case (UBOpContextInfo::OpType::SEND_RAW_SGL): + case (UBOpContextInfo::OpType::SEND_SGL_INLINE): + mSendPostedHandler(&contextInfo); + break; + case (UBOpContextInfo::OpType::RECEIVE): /* NOTE, up context is store imm data */ + (contextInfo).dataSize = wc.completion_len; + *((int32_t *)(void *)&((contextInfo).upCtx)) = wc.imm_data; + mNewRequestHandler(&contextInfo); + break; + case (UBOpContextInfo::OpType::WRITE): + case (UBOpContextInfo::OpType::SGL_WRITE): + case (UBOpContextInfo::OpType::HB_WRITE): + case (UBOpContextInfo::OpType::READ): + case (UBOpContextInfo::OpType::SGL_READ): + mOneSideDoneHandler(&contextInfo); + break; + default: + NN_LOG_ERROR("Poll cq invalid OpType " << contextInfo.opType); + } + } +}; +} // namespace hcom +} // namespace ock + +#endif +#endif // HCOM_UB_WORKER_H diff --git a/src/transport/ub/ub_worker_core.cpp b/src/transport/ub/ub_worker_core.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4486346df41aa9790a84cd75d642acb60e2bfebd --- /dev/null +++ b/src/transport/ub/ub_worker_core.cpp @@ -0,0 +1,373 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED +#include +#include + +#include "hcom_utils.h" +#include "net_common.h" +#include "ub_worker.h" +#include "net_ub_endpoint.h" + +namespace ock { +namespace hcom { +std::string &WorkerTypeToString(UBWorkerType tp) +{ + static std::string workerTypeString[3] = {"sender", "receiver", "sender&receiver"}; + static std::string unknownWorkerType = "unknown worker type"; + if (tp != UB_SENDER && tp != UB_RECEIVER && tp != UB_SENDER_RECEIVER) { + return unknownWorkerType; + } + return workerTypeString[tp]; +} + +std::string &PollingModeToString(UBPollingMode m) +{ + static std::string workerModeString[2] = {"busy_polling", "cq_event_polling"}; + static std::string unknownWorkerMode = "unknown worker mode"; + if (m != UB_BUSY_POLLING && m != UB_EVENT_POLLING) { + return unknownWorkerMode; + } + return workerModeString[m]; +} + +UBWorker::UBWorker(const std::string &name, UBContext *ctx, const UBWorkerOptions &options, + const NetMemPoolFixedPtr &memPool, const NetMemPoolFixedPtr &sglMemPool) + : mName(name), + mUBContext(ctx), + mOpCtxMemPool(memPool), + mSglCtxMemPool(sglMemPool), + mOptions(options), + mProgressThreadStarted(false) +{ + if (mUBContext != nullptr) { + mUBContext->IncreaseRef(); + } + + mProgressCpuId = options.cpuId; + mProgressBatchSize = options.pollingBatchSize; + OBJ_GC_INCREASE(UBWorker); +} + +UResult UBWorker::Initialize() +{ + if (mInited) { + return UB_OK; + } + + if (mUBContext == nullptr || mUBContext->mUrmaContext == nullptr) { + NN_LOG_ERROR("UB Context is null, probably not initialized"); + return UB_PARAM_INVALID; + } + + // create and init CQ + auto tmpCQ = new (std::nothrow) + UBJfc(DetailName(), mUBContext, mOptions.workerMode == UB_EVENT_POLLING, reinterpret_cast(this)); + if (tmpCQ == nullptr) { + NN_LOG_ERROR("Failed to new UBJfc in UBWorker " << DetailName() << ", probably out of memory"); + return UB_NEW_OBJECT_FAILED; + } + + tmpCQ->SetJfcCount(mOptions.completionQueueDepth); + + UResult result = UB_OK; + if ((result = tmpCQ->Initialize()) != UB_OK) { + NN_LOG_ERROR("Failed to initialize UBJfc in UBWorker " << DetailName() << ", result " << result); + delete tmpCQ; + tmpCQ = nullptr; + return result; + } + + if ((result = mOpCtxInfoPool.Initialize(mOpCtxMemPool)) != UB_OK) { + NN_LOG_ERROR("Failed to initialize operation context info pool in UBWorker " << DetailName()); + delete tmpCQ; + tmpCQ = nullptr; + return result; + } + + if ((result = mSglCtxInfoPool.Initialize(mSglCtxMemPool)) != UB_OK) { + NN_LOG_ERROR("Failed to initialize sgl context info pool in UBWorker " << DetailName()); + delete tmpCQ; + tmpCQ = nullptr; + return result; + } + + if ((result = mJettyPtrMap.Initialize()) != UB_OK) { + NN_LOG_ERROR("Failed to initialize jetty ptr map in UBWorker " << DetailName()); + delete tmpCQ; + tmpCQ = nullptr; + return result; + } + + mUBJfc = tmpCQ; + mUBJfc->IncreaseRef(); + mInited = true; + return UB_OK; +} + +UResult UBWorker::UnInitialize() +{ + if (!mInited) { + return UB_OK; + } + + if (mUBJfc != nullptr) { + mUBJfc->DecreaseRef(); + mUBJfc = nullptr; + } + + if (mUBContext != nullptr) { + mUBContext->DecreaseRef(); + mUBContext = nullptr; + } + + if (mOpCtxMemPool != nullptr) { + mOpCtxMemPool.Set(nullptr); + } + + mOpCtxInfoPool.UnInitialize(); + + mInited = false; + return UB_OK; +} + +UResult UBWorker::ReInitializeCQ() +{ + if (!mInited) { + return UB_OK; + } + + if (mUBJfc != nullptr) { + mUBJfc->DecreaseRef(); + mUBJfc = nullptr; + } + + // create and init CQ + auto tmpCQ = new (std::nothrow) + UBJfc(DetailName(), mUBContext, mOptions.workerMode == UB_EVENT_POLLING, reinterpret_cast(this)); + if (tmpCQ == nullptr) { + NN_LOG_ERROR("Failed to new UBJfc in UBWorker " << DetailName() << + " in reinitialization, probably out of memory"); + return UB_NEW_OBJECT_FAILED; + } + + tmpCQ->SetJfcCount(mOptions.completionQueueDepth); + + UResult result = UB_OK; + if ((result = tmpCQ->Initialize()) != UB_OK) { + NN_LOG_ERROR("Failed to initialize UBJfc in UBWorker " << DetailName() << ", result " << result); + delete tmpCQ; + tmpCQ = nullptr; + return result; + } + + mUBJfc = tmpCQ; + mUBJfc->IncreaseRef(); + + return UB_OK; +} + +UResult UBWorker::Start() +{ + if (!mInited) { + NN_LOG_ERROR("Failed to start UBWorker " << DetailName() << " as not initialized"); + return UB_WORKER_NOT_INITIALIZED; + } + + if (mOptions.dontStartWorkers) { + NN_LOG_INFO("Do not start workers " << DetailName()); + return UB_OK; + } + + if ((mOptions.workerType == UB_RECEIVER || mOptions.workerType == UB_SENDER_RECEIVER) && + mNewRequestHandler == nullptr) { + NN_LOG_ERROR("New request handler is not registered yet in UBWorker " << DetailName()); + return UB_WORKER_REQUEST_HANDLER_NOT_SET; + } + + if ((mOptions.workerType == UB_SENDER || mOptions.workerType == UB_SENDER_RECEIVER) && + mSendPostedHandler == nullptr) { + NN_LOG_ERROR("Send request posted handler is not registered yet in UBWorker " << DetailName()); + return UB_WORKER_SEND_POSTED_HANDLER_NOT_SET; + } + + if (mOneSideDoneHandler == nullptr) { + NN_LOG_WARN("One side done handler is not registered yet in UBWorker " << DetailName()); + } + + mNeedStop = false; + std::thread tmpThread(&UBWorker::RunInThread, this); + mProgressThread = std::move(tmpThread); + std::string threadName = "UBWkr" + mIndex.ToString(); + if (pthread_setname_np(mProgressThread.native_handle(), threadName.c_str()) != 0) { + NN_LOG_WARN("Unable to set name of UBWorker progress thread"); + } + + if (mProgressCpuId != -1) { + cpu_set_t cpuSet; + CPU_ZERO(&cpuSet); + CPU_SET(mProgressCpuId, &cpuSet); + if (pthread_setaffinity_np(mProgressThread.native_handle(), sizeof(cpuSet), &cpuSet) != 0) { + NN_LOG_WARN("Unable to bind UBWorker" << mIndex.ToString() << " << to cpu " << mProgressCpuId); + } + } + + while (!mProgressThreadStarted.load()) { + usleep(NN_NO10); + } + + return UB_OK; +} + +UResult UBWorker::Stop() +{ + mNeedStop = true; + if (mProgressThread.native_handle()) { + mProgressThread.join(); + } + return UB_OK; +} + +void UBWorker::DoWithBusyPolling() +{ + // allocate wc vector + auto *wc = static_cast(calloc(mProgressBatchSize, sizeof(urma_cr_t))); + if (wc == nullptr) { + NN_LOG_ERROR("Failed to allocate wc in UBWorker " << DetailName() << ", thread exiting"); + return; + } + + uint32_t pollCount = 0; + UBJetty *lastBrokenQp = nullptr; + urma_cr_status_t lastErrorWcStatus = URMA_CR_SUCCESS; + + while (!mNeedStop) { + try { + pollCount = mProgressBatchSize; + if (BusyPolling(wc, pollCount)) { + continue; + } + TRACE_DELAY_BEGIN(UB_WORKER_BUSY_POLLING); + ProcessPollingResult(wc, pollCount, lastBrokenQp, lastErrorWcStatus); + TRACE_DELAY_END(UB_WORKER_BUSY_POLLING, 0); + } catch (std::runtime_error &ex) { + NN_LOG_WARN("Got runtime incorrect signal in UBWorker::RunInThread '" << ex.what() << + "', ignore and continue"); + } catch (...) { + NN_LOG_WARN("Got unknown signal in UBWorker::RunInThread, ignore and continue"); + } + } + + free(wc); + wc = nullptr; +} + +void UBWorker::DoWithCQEventPolling() +{ + // allocate wc vector + auto *wc = static_cast(calloc(mProgressBatchSize, sizeof(urma_cr_t))); + if (wc == nullptr) { + NN_LOG_ERROR("Failed to allocate wc in UBWorker " << DetailName() << ", thread exiting"); + return; + } + + uint32_t pollCount = 0; + uint32_t pollTimeOut = 0; + UBJetty *lastBrokenQp = nullptr; + urma_cr_status_t lastErrorWcStatus = URMA_CR_SUCCESS; + + while (!mNeedStop) { + try { + pollCount = mProgressBatchSize; + pollTimeOut = mOptions.eventPollingTimeout; + if (CqEventPolling(wc, pollCount, pollTimeOut)) { + continue; + } + TRACE_DELAY_BEGIN(UB_WORKER_EVENT_POLLING); + ProcessPollingResult(wc, pollCount, lastBrokenQp, lastErrorWcStatus); + TRACE_DELAY_END(UB_WORKER_EVENT_POLLING, 0); + } catch (std::runtime_error &ex) { + NN_LOG_WARN("Got runtime incorrect signal in UB worker thread '" << ex.what() << "', ignore and continue"); + } catch (...) { + NN_LOG_WARN("Got unknown signal in UB worker thread, ignore and continue"); + } + } + + free(wc); + wc = nullptr; +} + +void UBWorker::RunInThread() +{ + if (mOptions.threadPriority != 0) { + if (NN_UNLIKELY(setpriority(PRIO_PROCESS, 0, mOptions.threadPriority) != 0)) { + char errBuf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_WARN("Unable to set worker thread priority in ub worker " << mName << ", errno:" << + NetFunc::NN_GetStrError(errno, errBuf, NET_STR_ERROR_BUF_SIZE)); + } + } + + mProgressThreadStarted.store(true); + NN_LOG_INFO("UBWorker " << DetailName() << ", cpuId: " << mProgressCpuId << ", cq count: " << + ((mUBJfc != nullptr) ? mUBJfc->GetCQCount() : 0) << ", polling batch size: " << mProgressBatchSize << + ", more " << mOptions.ToString() << "] working thread started"); + + if (mOptions.workerMode == UB_BUSY_POLLING) { + DoWithBusyPolling(); + } else if (mOptions.workerMode == UB_EVENT_POLLING) { + DoWithCQEventPolling(); + } else { + NN_LOG_ERROR("Un-reachable"); + } + + NN_LOG_INFO("UBWorker " << DetailName() << " working thread exiting"); +} + +UResult UBWorker::Create(const std::string &name, UBContext *ctx, const UBWorkerOptions &options, + NetMemPoolFixedPtr memPool, NetMemPoolFixedPtr sglMemPool, UBWorker *&outWorker) +{ + if (ctx == nullptr || name.empty()) { + NN_LOG_ERROR("Failed to create ub worker as ctx is nullptr or name empty"); + return UB_PARAM_INVALID; + } + + auto tmp = new (std::nothrow) UBWorker(name, ctx, options, std::move(memPool), std::move(sglMemPool)); + if (tmp == nullptr) { + NN_LOG_ERROR("Failed to create UBWorker, probably out of memory"); + return UB_NEW_OBJECT_FAILED; + } + + outWorker = tmp; + return UB_OK; +} + +UResult UBWorker::CreateQP(UBJetty *&qp) +{ + if (NN_UNLIKELY(!mInited)) { + NN_LOG_ERROR("Failed to create qp with UBWorker " << DetailName() << " as not initialized"); + return UB_WORKER_NOT_INITIALIZED; + } + + JettyOptions jettyOptions(mOptions.qpSendQueueSize, mOptions.qpReceiveQueueSize, mOptions.qpMrSegSize, + mOptions.qpMrSegCount, mOptions.slave, mOptions.ubcMode); + qp = new (std::nothrow) UBJetty(DetailName(), UBJetty::NewId(), mUBContext, mUBJfc, jettyOptions); + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Failed to create qp with UBWorker " << DetailName() << ", probably out of memory"); + return UB_NEW_OBJECT_FAILED; + } + + qp->SetUpContext1(reinterpret_cast(this)); + return UB_OK; +} +} +} +#endif diff --git a/src/transport/ub/ub_worker_io.cpp b/src/transport/ub/ub_worker_io.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ba2fa3199a058e6208643e6efe6849b34407ae99 --- /dev/null +++ b/src/transport/ub/ub_worker_io.cpp @@ -0,0 +1,474 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED +#include +#include + +#include "hcom_utils.h" +#include "net_common.h" +#include "ub_worker.h" + +namespace ock { +namespace hcom { + +UResult UBWorker::PostReceive(UBJetty *qp, uintptr_t bufAddress, uint32_t bufSize, urma_target_seg_t *localSeg) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Failed to PostReceive with UBWorker " << DetailName() << " as qp is null"); + return UB_PARAM_INVALID; + } + + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to PostReceive with UBWorker " << DetailName() << " as no ctx left"); + return UB_QP_CTX_FULL; + } + + /* set to all 0 */ + bzero(ctx, sizeof(UBOpContextInfo)); + ctx->ubJetty = qp; + ctx->mrMemAddr = bufAddress; + ctx->dataSize = bufSize; + ctx->qpNum = qp->QpNum(); + ctx->localSeg = localSeg; + ctx->opType = UBOpContextInfo::RECEIVE; + ctx->opResultType = UBOpContextInfo::SUCCESS; + qp->IncreaseRef(); + + // attach context to qp firstly, because post cloud be finished very fast + // if posted failed, need to remove + qp->AddOpCtxInfo(ctx); + + auto result = qp->PostReceive(bufAddress, bufSize, localSeg, reinterpret_cast(ctx)); + if (NN_UNLIKELY(result != UB_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + qp->DecreaseRef(); + qp->RemoveOpCtxInfo(ctx); + mOpCtxInfoPool.Return(ctx); + } + + // ctx could not be used if post successfully + return result; +} + +UResult UBWorker::RePostReceive(UBOpContextInfo *ctx) +{ + if (NN_UNLIKELY(ctx == nullptr || ctx->ubJetty == nullptr)) { + NN_LOG_ERROR("Failed to RePostReceive with UBWorker " << DetailName() << " as ctx or its qp is null"); + return UB_PARAM_INVALID; + } + + // attach context to qp firstly, because post cloud be finished very fast + // if posted failed, need to remove + ctx->ubJetty->AddOpCtxInfo(ctx); + + auto result = + ctx->ubJetty->PostReceive(ctx->mrMemAddr, mOptions.qpMrSegSize, ctx->localSeg, reinterpret_cast(ctx)); + if (NN_UNLIKELY(result != UB_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + ctx->ubJetty->DecreaseRef(); + ctx->ubJetty->RemoveOpCtxInfo(ctx); + mOpCtxInfoPool.Return(ctx); + } + + // ctx could not be used if post successfully + return result; +} + +UResult UBWorker::PostSend(UBJetty *qp, const UBSendReadWriteRequest &req, urma_target_seg_t *localSeg, + uint32_t immData) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Failed to PostSend with UBWorker " << DetailName() << " as qp is null"); + return UB_PARAM_INVALID; + } + + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to PostSend with UBWorker " << DetailName() << " as no reqInfo left"); + return UB_QP_CTX_FULL; + } + + if (NN_UNLIKELY(!qp->GetPostSendWr())) { + NN_LOG_ERROR("Failed to PostSend with UBWorker " << DetailName() << " as no post send wr left"); + mOpCtxInfoPool.Return(ctx); + return UB_QP_POST_SEND_WR_FULL; + } + ctx->ubJetty = qp; + ctx->mrMemAddr = req.lAddress; + ctx->dataSize = req.size; + ctx->qpNum = qp->QpNum(); + ctx->lKey = req.lKey; + ctx->opType = immData == 0 ? UBOpContextInfo::SEND : UBOpContextInfo::SEND_RAW; + ctx->opResultType = UBOpContextInfo::SUCCESS; + ctx->upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + (void)memcpy_s(ctx->upCtx, req.upCtxSize, req.upCtxData, req.upCtxSize); + } + qp->IncreaseRef(); + + // attach context to qp firstly, because post cloud be finished very fast + // if posted failed, need to remove + qp->AddOpCtxInfo(ctx); + + auto result = qp->PostSend(req.lAddress, req.size, localSeg, ctx, immData); + if (NN_UNLIKELY(result != UB_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + qp->ReturnPostSendWr(); + qp->DecreaseRef(); + qp->RemoveOpCtxInfo(ctx); + mOpCtxInfoPool.Return(ctx); + } + + // ctx could not be used if post successfully + return result; +} + +UResult UBWorker::PostSendSglInline(UBJetty *qp, const UBSendSglInlineHeader &header, const UBSendReadWriteRequest &req, + uint32_t immData) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Verbs Failed to PostSend with RDMAWorker " << DetailName() << " as qp is null"); + return UB_PARAM_INVALID; + } + + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Verbs Failed to PostSend with RDMAWorker " << DetailName() << " as no reqInfo left"); + return UB_QP_CTX_FULL; + } + + if (NN_UNLIKELY(!qp->GetPostSendWr())) { + NN_LOG_ERROR("Verbs Failed to PostSend with RDMAWorker " << DetailName() << " as no post send wr left"); + mOpCtxInfoPool.Return(ctx); + return UB_QP_POST_SEND_WR_FULL; + } + ctx->ubJetty = qp; + ctx->mrMemAddr = req.lAddress; + ctx->dataSize = req.size; + ctx->lKey = req.lKey; + ctx->qpNum = qp->QpNum(); + ctx->opType = UBOpContextInfo::SEND_SGL_INLINE; + ctx->opResultType = UBOpContextInfo::SUCCESS; + ctx->upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0 && NN_UNLIKELY(memcpy_s(ctx->upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != UB_OK)) { + NN_LOG_ERROR("Failed to copy req to ctx"); + return UB_PARAM_INVALID; + } + qp->IncreaseRef(); + + // attach context to qp firstly, because post could be finished very fast + // if posted failed, need to remove + qp->AddOpCtxInfo(ctx); + + UBSHcomNetTransDataIov netTransDataIov[NN_NO2]; + netTransDataIov[NN_NO0].address = reinterpret_cast(&header); + netTransDataIov[NN_NO0].size = sizeof(UBSendSglInlineHeader); + netTransDataIov[NN_NO1].address = req.lAddress; + netTransDataIov[NN_NO1].size = req.size; + + auto result = qp->PostSendSglInline(netTransDataIov, NN_NO2, reinterpret_cast(ctx), immData); + if (NN_UNLIKELY(result != UB_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + qp->ReturnPostSendWr(); + qp->DecreaseRef(); + qp->RemoveOpCtxInfo(ctx); + mOpCtxInfoPool.Return(ctx); + } + + // ctx could not be used if post successfully + return result; +} + +UResult UBWorker::PostSendSgl(UBJetty *qp, const UBSHcomNetTransSglRequest &req, const UBSHcomNetTransRequest &tlsReq, + uint32_t immData, bool isEncrypted) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Failed to PostSendSgl with UBWorker " << DetailName() << " as qp is null"); + return UB_PARAM_INVALID; + } + + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Failed to PostSendSgl with UBWorker " << DetailName() << " as no ctx left"); + return UB_PARAM_INVALID; + } + sglCtx->qp = qp; + sglCtx->result = UB_OK; + if (NN_UNLIKELY(memcpy_s(sglCtx->iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, req.iov, + sizeof(UBSHcomNetTransSgeIov) * req.iovCount) != UB_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + mSglCtxInfoPool.Return(sglCtx); + return UB_PARAM_INVALID; + } + sglCtx->iovCount = req.iovCount; + sglCtx->upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(sglCtx->upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != UB_OK)) { + NN_LOG_ERROR("Failed to copy req to sglCtx"); + mSglCtxInfoPool.Return(sglCtx); + return UB_PARAM_INVALID; + } + } + + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to PostSendSgl with UBWorker " << DetailName() << " as no reqInfo left"); + mSglCtxInfoPool.Return(sglCtx); + return UB_QP_CTX_FULL; + } + if (NN_UNLIKELY(!qp->GetPostSendWr())) { + NN_LOG_ERROR("Failed to PostSendSgl with UBWorker " << DetailName() << " as no post send wr left"); + mOpCtxInfoPool.Return(ctx); + mSglCtxInfoPool.Return(sglCtx); + return UB_QP_POST_SEND_WR_FULL; + } + ctx->ubJetty = qp; + // if not encrypt reqTls lAddress\size\lKey is 0 + ctx->mrMemAddr = tlsReq.lAddress; + ctx->dataSize = tlsReq.size; + ctx->lKey = tlsReq.lKey; + ctx->qpNum = qp->QpNum(); + ctx->opType = UBOpContextInfo::SEND_RAW_SGL; + ctx->opResultType = UBOpContextInfo::SUCCESS; + ctx->upCtxSize = static_cast(sizeof(UBSgeCtxInfo)); + auto upCtx = static_cast((void *)&(ctx->upCtx)); + upCtx->ctx = sglCtx; + qp->IncreaseRef(); + + // attach context to qp firstly, because post could be finished very fast + // if posted failed, need to remove + qp->AddOpCtxInfo(ctx); + UResult result = UB_OK; + if (isEncrypted != 0) { + result = qp->PostSend(tlsReq.lAddress, tlsReq.size, reinterpret_cast(tlsReq.srcSeg), + reinterpret_cast(ctx), immData); + } else { + result = qp->PostSendSgl(req.iov, req.iovCount, reinterpret_cast(ctx), immData); + } + + if (NN_UNLIKELY(result != UB_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + qp->ReturnPostSendWr(); + qp->RemoveOpCtxInfo(ctx); + qp->DecreaseRef(); + mOpCtxInfoPool.Return(ctx); + mSglCtxInfoPool.Return(sglCtx); + } + return result; +} + +UResult UBWorker::PostRead(UBJetty *qp, const UBSendReadWriteRequest &req) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Failed to PostRead with UBWorker " << DetailName() << " as qp is null"); + return UB_PARAM_INVALID; + } + + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to PostRead with UBWorker " << DetailName() << " as no reqInfo left"); + return UB_QP_CTX_FULL; + } + + if (NN_UNLIKELY(!qp->GetOneSideWr())) { + NN_LOG_ERROR("Failed to PostRead with UBWorker " << DetailName() << " as no one side wr left"); + mOpCtxInfoPool.Return(ctx); + return UB_QP_ONE_SIDE_WR_FULL; + } + ctx->ubJetty = qp; + ctx->mrMemAddr = req.lAddress; + ctx->dataSize = req.size; + ctx->qpNum = qp->QpNum(); + ctx->lKey = req.lKey; + ctx->opType = UBOpContextInfo::READ; + ctx->opResultType = UBOpContextInfo::SUCCESS; + ctx->upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + (void)memcpy_s(ctx->upCtx, req.upCtxSize, req.upCtxData, req.upCtxSize); + } + qp->IncreaseRef(); + + // attach context to qp firstly, because post cloud be finished very fast + // if posted failed, need to remove + qp->AddOpCtxInfo(ctx); + + UResult result = UB_OK; + result = qp->PostRead(req.lAddress, reinterpret_cast(req.srcSeg), req.rAddress, + req.rKey, req.size, reinterpret_cast(ctx)); + if (NN_UNLIKELY(result != UB_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + qp->ReturnOneSideWr(); + qp->DecreaseRef(); + qp->RemoveOpCtxInfo(ctx); + mOpCtxInfoPool.Return(ctx); + } + + // ctx could not be used if post successfully + return result; +} + +UResult UBWorker::CreateOneSideCtx(const UBSgeCtxInfo &sgeInfo, const UBSHcomNetTransSgeIov *iov, uint32_t iovCount, + uint64_t (&ctxArr)[NET_SGE_MAX_IOV], bool isRead) +{ + if (iov == nullptr || iovCount == NN_NO0 || iovCount > NN_NO4 || ctxArr == nullptr) { + NN_LOG_ERROR("Urma failed to create oneSide operation ctx because param invalid"); + return UB_PARAM_INVALID; + } + for (uint32_t i = 0; i < iovCount; ++i) { + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Urma failed to oneSide operation with UBWorker " << DetailName() << " as no ctx left"); + for (uint32_t j = 0; j < i; ++j) { + sgeInfo.ctx->qp->ReturnOneSideWr(); + sgeInfo.ctx->qp->RemoveOpCtxInfo(reinterpret_cast(ctxArr[j])); + sgeInfo.ctx->qp->DecreaseRef(); + mOpCtxInfoPool.Return(reinterpret_cast(ctxArr[j])); + } + return UB_QP_CTX_FULL; + } + + if (NN_UNLIKELY(!sgeInfo.ctx->qp->GetOneSideWr())) { + NN_LOG_ERROR("Urma failed to oneSide operation with UBWorker " << DetailName() << + " as no one side wr left"); + mOpCtxInfoPool.Return(ctx); + for (uint32_t j = 0; j < i; ++j) { + sgeInfo.ctx->qp->ReturnOneSideWr(); + sgeInfo.ctx->qp->RemoveOpCtxInfo(reinterpret_cast(ctxArr[j])); + sgeInfo.ctx->qp->DecreaseRef(); + mOpCtxInfoPool.Return(reinterpret_cast(ctxArr[j])); + } + return UB_QP_ONE_SIDE_WR_FULL; + } + ctx->ubJetty = sgeInfo.ctx->qp; + ctx->mrMemAddr = iov[i].lAddress; + ctx->dataSize = iov[i].size; + ctx->qpNum = sgeInfo.ctx->qp->QpNum(); + ctx->lKey = iov[i].lKey; + ctx->opType = isRead ? UBOpContextInfo::SGL_READ : UBOpContextInfo::SGL_WRITE; + ctx->opResultType = UBOpContextInfo::SUCCESS; + ctx->upCtxSize = static_cast(sizeof(UBSgeCtxInfo)); + auto upCtx = static_cast((void *)&(ctx->upCtx)); + upCtx->ctx = sgeInfo.ctx; + upCtx->idx = i; + + sgeInfo.ctx->qp->IncreaseRef(); + sgeInfo.ctx->qp->AddOpCtxInfo(ctx); + ctxArr[i] = reinterpret_cast(ctx); + } + return UB_OK; +} + +UResult UBWorker::PostOneSideSgl(UBJetty *qp, const UBSendSglRWRequest &req, bool isRead) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Urma failed to PostRead with UBWorker " << DetailName() << " as qp is null"); + return UB_PARAM_INVALID; + } + + auto sglCtx = mSglCtxInfoPool.Get(); + if (NN_UNLIKELY(sglCtx == nullptr)) { + NN_LOG_ERROR("Urma failed to get from mSglCtxInfoPool "); + return UB_PARAM_INVALID; + } + + sglCtx->qp = qp; + sglCtx->result = UB_OK; + if (NN_UNLIKELY(memcpy_s(sglCtx->iov, sizeof(UBSHcomNetTransSgeIov) * NET_SGE_MAX_IOV, req.iov, + sizeof(UBSHcomNetTransSgeIov) * req.iovCount) != UB_OK)) { + NN_LOG_ERROR("Urma failed to copy iov to sglCtx"); + mSglCtxInfoPool.Return(sglCtx); + return UB_PARAM_INVALID; + } + sglCtx->iovCount = req.iovCount; + sglCtx->upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + if (NN_UNLIKELY(memcpy_s(sglCtx->upCtx, NN_NO16, req.upCtxData, req.upCtxSize) != UB_OK)) { + NN_LOG_ERROR("Urma failed to copy upCtx to sglCtx"); + mSglCtxInfoPool.Return(sglCtx); + return UB_PARAM_INVALID; + } + } + + UBSgeCtxInfo sgeInfo(sglCtx); + sglCtx->refCount = 0; + uint64_t ctxArr[NET_SGE_MAX_IOV]; + UResult result = CreateOneSideCtx(sgeInfo, req.iov, req.iovCount, ctxArr, isRead); + if (result != UB_OK) { + NN_LOG_ERROR("Urma failed to create one side ctx."); + mSglCtxInfoPool.Return(sglCtx); + return result; + } + result = qp->PostOneSideSgl(req.iov, req.iovCount, ctxArr, isRead, NET_SGE_MAX_IOV); + if (NN_UNLIKELY(result != UB_OK)) { + for (int i = 0; i < req.iovCount; ++i) { + qp->ReturnOneSideWr(); + qp->RemoveOpCtxInfo(reinterpret_cast(ctxArr[i])); + qp->DecreaseRef(); + mOpCtxInfoPool.Return(reinterpret_cast(ctxArr[i])); + } + mSglCtxInfoPool.Return(sglCtx); + } + return result; +} + +UResult UBWorker::PostWrite(UBJetty *qp, const UBSendReadWriteRequest &req, UBOpContextInfo::OpType type) +{ + if (NN_UNLIKELY(qp == nullptr)) { + NN_LOG_ERROR("Failed to PostWrite with UBWorker " << DetailName() << " as qp is null"); + return UB_PARAM_INVALID; + } + + auto ctx = mOpCtxInfoPool.Get(); + if (NN_UNLIKELY(ctx == nullptr)) { + NN_LOG_ERROR("Failed to PostWrite with UBWorker " << DetailName() << " as no ctx left"); + return UB_QP_CTX_FULL; + } + if (NN_UNLIKELY(!qp->GetOneSideWr())) { + NN_LOG_ERROR("Failed to PostWrite with UBWorker " << DetailName() << " as no one side wr left"); + mOpCtxInfoPool.Return(ctx); + return UB_QP_ONE_SIDE_WR_FULL; + } + ctx->ubJetty = qp; + ctx->mrMemAddr = req.lAddress; + ctx->dataSize = req.size; + ctx->qpNum = qp->QpNum(); + ctx->lKey = req.lKey; + ctx->opType = type; + ctx->opResultType = UBOpContextInfo::SUCCESS; + ctx->upCtxSize = req.upCtxSize; + if (req.upCtxSize > 0) { + (void)memcpy_s(ctx->upCtx, req.upCtxSize, req.upCtxData, req.upCtxSize); + } + qp->IncreaseRef(); + + // attach context to qp firstly, because post cloud be finished very fast + // if posted failed, need to remove + qp->AddOpCtxInfo(ctx); + + UResult result = UB_OK; + result = qp->PostWrite(req.lAddress, reinterpret_cast(req.srcSeg), req.rAddress, + req.rKey, req.size, reinterpret_cast(ctx)); + if (NN_UNLIKELY(result != UB_OK)) { + // remove ctx from qp firstly, then return to pool because, ctx maybe deleted + qp->ReturnOneSideWr(); + qp->DecreaseRef(); + qp->RemoveOpCtxInfo(ctx); + mOpCtxInfoPool.Return(ctx); + } + + // ctx could not be used if post successfully + return result; +} +} +} +#endif \ No newline at end of file diff --git a/src/under_api/obmm/obmm_api_dl.cpp b/src/under_api/obmm/obmm_api_dl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7443e1fbfb0769681ae16d36838040595ba7f5a4 --- /dev/null +++ b/src/under_api/obmm/obmm_api_dl.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED +#include +#include "hcom_log.h" +#include "hcom_def.h" +#include "common/net_common.h" +#include "obmm_api_dl.h" + +using namespace ock::hcom; + +OBMM_EXPORT ObmmAPI::hcomObmmExport = nullptr; +OBMM_EXPORT_USERADDR ObmmAPI::hcomObmmExportUseraddr = nullptr; +OBMM_UNEXPORT ObmmAPI::hcomObmmUnexport = nullptr; +OBMM_IMPORT ObmmAPI::hcomObmmImport = nullptr; +OBMM_UNIMPORT ObmmAPI::hcomObmmUnimport = nullptr; +OBMM_OPEN ObmmAPI::hcomObmmOpen = nullptr; + +bool ObmmAPI::gLoaded = false; + +#define DLSYM(type, ptr, sym) \ + do { \ + auto ptr1 = dlsym(handle, sym); \ + if (ptr1 == nullptr) { \ + NN_LOG_ERROR("Failed to load function " << sym << ", error " << dlerror()); \ + dlclose(handle); \ + return NN_NOF1; \ + } \ + ptr = (type)ptr1; \ + } while (0) + +int ObmmAPI::LoadObmmAPI() +{ + if (gLoaded) { + return 0; + } + + auto handle = dlopen(OBMM_SO_PATH, RTLD_NOW | RTLD_GLOBAL); + if (handle == nullptr) { + NN_LOG_ERROR("Failed to load obmm so " << OBMM_SO_PATH << ", error " << dlerror()); + return NN_NOF1; + } + DLSYM(OBMM_EXPORT, ObmmAPI::hcomObmmExport, "obmm_export"); + DLSYM(OBMM_UNEXPORT, ObmmAPI::hcomObmmUnexport, "obmm_unexport"); + DLSYM(OBMM_IMPORT, ObmmAPI::hcomObmmImport, "obmm_import"); + DLSYM(OBMM_UNIMPORT, ObmmAPI::hcomObmmUnimport, "obmm_unimport"); + DLSYM(OBMM_OPEN, ObmmAPI::hcomObmmOpen, "obmm_open"); + + NN_LOG_INFO("Success to load obmm"); + gLoaded = true; + + return 0; +} +#endif diff --git a/src/under_api/obmm/obmm_api_dl.h b/src/under_api/obmm/obmm_api_dl.h new file mode 100644 index 0000000000000000000000000000000000000000..37614ba101b62e961284b096d40b8d468ea6859f --- /dev/null +++ b/src/under_api/obmm/obmm_api_dl.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_UNDER_API_OBMM_API_DL_H +#define HCOM_UNDER_API_OBMM_API_DL_H +#ifdef UB_BUILD_ENABLED + +#include +#include "hcom.h" + +namespace ock { +namespace hcom { + +#define OBMM_SO_PATH "libobmm.so" + +using OBMM_EXPORT = uint64_t (*)(size_t size, bitmask *nodes, unsigned long flags, struct obmm_mem_desc *desc); +using OBMM_EXPORT_USERADDR = uint64_t (*)(void *addr, size_t size, unsigned long flags, struct obmm_mem_desc *desc); +using OBMM_UNEXPORT = int (*)(uint64_t id, unsigned long flags); +using OBMM_IMPORT = uint64_t (*)(struct obmm_mem_desc *desc, unsigned long flags, int *numa); +using OBMM_UNIMPORT = int (*)(uint64_t id, unsigned long flags); +using OBMM_OPEN = int (*)(uint64_t id); + +class ObmmAPI { +public: + static OBMM_EXPORT hcomObmmExport; + static OBMM_EXPORT_USERADDR hcomObmmExportUseraddr; + static OBMM_UNEXPORT hcomObmmUnexport; + static OBMM_IMPORT hcomObmmImport; + static OBMM_UNIMPORT hcomObmmUnimport; + static OBMM_OPEN hcomObmmOpen; + + static int LoadObmmAPI(); + +private: + static bool gLoaded; +}; +} +} +#endif +#endif /* UAPI_OBMM_H */ \ No newline at end of file diff --git a/src/under_api/obmm/obmm_api_wrapper.h b/src/under_api/obmm/obmm_api_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..e48b21b99e9b64ae905a1b1818a220325e552099 --- /dev/null +++ b/src/under_api/obmm/obmm_api_wrapper.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_OBMM_API_WRAPPER_H +#define HCOM_OBMM_API_WRAPPER_H +#ifdef UB_BUILD_ENABLED + +#include "obmm_api_dl.h" + +namespace ock { +namespace hcom { +class HcomObmm { +public: + static inline uint64_t ObmmExport(size_t size, bitmask *nodes, unsigned long flags, struct obmm_mem_desc *desc) + { + return ObmmAPI::hcomObmmExport(size, nodes, flags, desc); + } + + static inline uint64_t ObmmExportUseraddr(void *addr, size_t size, unsigned long flags, + struct obmm_mem_desc *desc) + { + return ObmmAPI::hcomObmmExportUseraddr(addr, size, flags, desc); + } + + static inline int ObmmUnexport(uint64_t id, unsigned long flags) + { + return ObmmAPI::hcomObmmUnexport(id, flags); + } + + static inline int ObmmUnimport(uint64_t id, unsigned long flags) + { + return ObmmAPI::hcomObmmUnimport(id, flags); + } + + static inline uint64_t ObmmImport(struct obmm_mem_desc *desc, unsigned long flags, int *numa) + { + return ObmmAPI::hcomObmmImport(desc, flags, numa); + } + + static inline int ObmmOpen(uint64_t id) + { + return ObmmAPI::hcomObmmOpen(id); + } + + static inline int Load() + { + return ObmmAPI::LoadObmmAPI(); + } +}; +} +} +#endif +#endif // HCOM_OBMM_API_WRAPPER_H \ No newline at end of file diff --git a/src/under_api/openssl/openssl_api_dl.cpp b/src/under_api/openssl/openssl_api_dl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c9b867b15db6db0735b09943ed5ca2e18fcee2be --- /dev/null +++ b/src/under_api/openssl/openssl_api_dl.cpp @@ -0,0 +1,281 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +#include "hcom_utils.h" +#include "openssl_api_dl.h" + +#define DLSYM(handle, type, ptr, sym) \ + do { \ + auto ptr1 = dlsym((handle), (sym)); \ + if (ptr1 == nullptr) { \ + NN_LOG_ERROR("Failed to load " << (sym)); \ + dlclose(handle); \ + return -1; \ + } \ + (ptr) = (type)ptr1; \ + } while (0) + +/** @brief Adapt to different OpenSSL versions. */ +#define DLSYM_UPDATE(handle, type, ptr, sym1, sym2) \ + do { \ + auto ptr1 = dlsym((handle), (sym1)); \ + if (ptr1 == nullptr) { \ + ptr1 = dlsym((handle), (sym2)); \ + if (ptr1 == nullptr) { \ + NN_LOG_ERROR("Failed to load " << (sym1)); \ + dlclose(handle); \ + return -1; \ + } \ + } \ + (ptr) = (type)ptr1; \ + } while (0) + +namespace ock { +namespace hcom { +FuncInit SSLAPI::initSsl = nullptr; +FuncInit SSLAPI::initCypto = nullptr; +FuncOpensslCleanup SSLAPI::opensslCleanup = nullptr; + +FuncGetMethod SSLAPI::tlsServerMethod = nullptr; +FuncGetMethod SSLAPI::tlsClientMethod = nullptr; +FuncSslOperation SSLAPI::sslShutdown = nullptr; +FuncSslFd SSLAPI::sslSetFd = nullptr; +FuncSslNew SSLAPI::sslNew = nullptr; +FuncSslFree SSLAPI::sslFree = nullptr; +FuncSslCtxNew SSLAPI::sslCtxNew = nullptr; +FuncSslCtxFree SSLAPI::sslCtxFree = nullptr; +FuncSslWrite SSLAPI::sslWrite = nullptr; +FuncSslRead SSLAPI::sslRead = nullptr; +FuncSslOperation SSLAPI::sslConnect = nullptr; +FuncSslOperation SSLAPI::sslAccept = nullptr; +FuncSslGetError SSLAPI::sslGetError = nullptr; + +FuncSslCtxCtrl SSLAPI::sslCtxCtrl = nullptr; +FuncSslGetCurrentCipher SSLAPI::sslGetCurrentCipher = nullptr; +FuncSslGetVersion SSLAPI::sslGetVersion = nullptr; +FuncSetCipherSuites SSLAPI::setCipherSuites = nullptr; +FuncUsePrivKeyFile SSLAPI::usePrivKeyFile = nullptr; +FuncUseCertChainFile SSLAPI::useCertChainFile = nullptr; +FuncSslCtxSetVerify SSLAPI::sslCtxSetVerify = nullptr; +FuncSetDefaultPasswdCbUserdata SSLAPI::setDefaultPasswdCbUserdata = nullptr; +FuncSetCertVerifyCallback SSLAPI::setCertVerifyCallback = nullptr; +FuncLoadVerifyLocations SSLAPI::loadVerifyLocations = nullptr; +FuncCheckPrivateKey SSLAPI::checkPrivateKey = nullptr; +FuncSslGetVerifyResult SSLAPI::sslGetVerifyResult = nullptr; +FuncSslGetPeerCertificate SSLAPI::sslGetPeerCertificate = nullptr; +FuncSslCtxSetOptions SSLAPI::SslCtxSetOptions = nullptr; +FuncSslCtxSetPskFindSessionCallback SSLAPI::SslCtxSetPskFindSessionCallback = nullptr; +FuncSslCtxSetPskUseSessionCallback SSLAPI::SslCtxSetPskUseSessionCallback = nullptr; + +FuncSslSessionNew SSLAPI::SslSessionNew = nullptr; +FuncSslSessionSet1MasterKey SSLAPI::SslSessionSet1MasterKey = nullptr; +FuncSslSessionSetProtocolVersion SSLAPI::SslSessionSetProtocolVersion = nullptr; +FuncSslSessionSetCipher SSLAPI::SslSessionSetCipher = nullptr; +FuncSslCipherFind SSLAPI::SslCipherFind = nullptr; + +FuncEvpAesCipher SSLAPI::evpAes128Gcm = nullptr; +FuncEvpAesCipher SSLAPI::evpAes256Gcm = nullptr; +FuncEvpAesCipher SSLAPI::evpAes128Ccm = nullptr; +FuncEvpAesCipher SSLAPI::evpChacha20Poly1305 = nullptr; + +FuncEvpCipherCtxNew SSLAPI::evpCipherCtxNew = nullptr; +FuncEvpCipherCtxFree SSLAPI::evpCipherCtxFree = nullptr; +FuncEvpCipherCtxCtrl SSLAPI::evpCipherCtxCtrl = nullptr; + +FuncEvpEncryptInitEx SSLAPI::evpEncryptInitEx = nullptr; +FuncEvpEncryptUpdate SSLAPI::evpEncryptUpdate = nullptr; +FuncEvpEncryptFinalEx SSLAPI::evpEncryptFinalEx = nullptr; +FuncEvpDecryptInitEx SSLAPI::evpDecryptInitEx = nullptr; +FuncEvpDecryptUpdate SSLAPI::evpDecryptUpdate = nullptr; +FuncEvpDecryptFinalEx SSLAPI::evpDecryptFinalEx = nullptr; + +FuncRandPoll SSLAPI::randPoll = nullptr; +FuncRandStatus SSLAPI::randStatus = nullptr; +FuncRandBytes SSLAPI::randBytes = nullptr; +FuncRandBytes SSLAPI::randPrivBytes = nullptr; +FuncRandSeed SSLAPI::randSeed = nullptr; + +FuncX509VerifyCert SSLAPI::x509VerifyCert = nullptr; +FuncX509VerifyCertErrorString SSLAPI::x509VerifyCertErrorString = nullptr; +FuncX509StoreCtxGetError SSLAPI::x509StoreCtxGetError = nullptr; +FuncPemReadBioX509Crl SSLAPI::pemReadBioX509Crl = nullptr; +FuncBioSFile SSLAPI::bioSFile = nullptr; +FuncBioNew SSLAPI::bioNew = nullptr; +FuncBioFree SSLAPI::bioFree = nullptr; +FuncBioCtrl SSLAPI::bioCtrl = nullptr; +FuncX509StoreCtxGet0Store SSLAPI::x509StoreCtxGet0Store = nullptr; +FuncX509StoreCtxSetFlags SSLAPI::x509StoreCtxSetFlags = nullptr; +FuncX509StoreAddCrl SSLAPI::x509StoreAddCrl = nullptr; +FuncX509CrlFree SSLAPI::x509CrlFree = nullptr; + +bool SSLAPI::gLoaded = false; +const char *SSLAPI::gOpensslEnvPath = "HCOM_OPENSSL_PATH"; +const char *SSLAPI::gOpensslLibSslName = "libssl.so"; +const char *SSLAPI::gOpensslLibCryptoName = "libcrypto.so"; +const char *SSLAPI::gSep = "/"; + +int SSLAPI::GetLibPath(std::string &libSslPath, std::string &libCryptoPath) +{ + char *envPath = ::getenv(gOpensslEnvPath); + if (envPath == nullptr) { + libSslPath = gOpensslLibSslName; + libCryptoPath = gOpensslLibCryptoName; + return 0; + } + + std::string opensslPath = envPath; + if (!CanonicalPath(opensslPath)) { + NN_LOG_ERROR("env set for openssl is invalid " << gOpensslEnvPath); + return -1; + } + + libCryptoPath = opensslPath + gSep + gOpensslLibCryptoName; + if (::access(libCryptoPath.c_str(), F_OK) != 0) { + NN_LOG_ERROR("libcrypto.so path set in env is invalid"); + return -1; + } + + libSslPath = opensslPath + gSep + gOpensslLibSslName; + if (::access(libSslPath.c_str(), F_OK) != 0) { + NN_LOG_ERROR("libssl.so path set in env is invalid"); + return -1; + } + return 0; +} + +int SSLAPI::LoadSSLSymbols(void *sslHandle) +{ + DLSYM(sslHandle, FuncInit, initSsl, "OPENSSL_init_ssl"); + DLSYM(sslHandle, FuncInit, initCypto, "OPENSSL_init_crypto"); + DLSYM(sslHandle, FuncOpensslCleanup, opensslCleanup, "OPENSSL_cleanup"); + DLSYM(sslHandle, FuncGetMethod, tlsServerMethod, "TLS_server_method"); + DLSYM(sslHandle, FuncGetMethod, tlsClientMethod, "TLS_client_method"); + DLSYM(sslHandle, FuncSslOperation, sslShutdown, "SSL_shutdown"); + DLSYM(sslHandle, FuncSslFd, sslSetFd, "SSL_set_fd"); + DLSYM(sslHandle, FuncSslNew, sslNew, "SSL_new"); + DLSYM(sslHandle, FuncSslFree, sslFree, "SSL_free"); + DLSYM(sslHandle, FuncSslCtxNew, sslCtxNew, "SSL_CTX_new"); + DLSYM(sslHandle, FuncSslCtxFree, sslCtxFree, "SSL_CTX_free"); + DLSYM(sslHandle, FuncSslWrite, sslWrite, "SSL_write"); + DLSYM(sslHandle, FuncSslRead, sslRead, "SSL_read"); + DLSYM(sslHandle, FuncSslOperation, sslConnect, "SSL_connect"); + DLSYM(sslHandle, FuncSslOperation, sslAccept, "SSL_accept"); + DLSYM(sslHandle, FuncSslGetError, sslGetError, "SSL_get_error"); + DLSYM(sslHandle, FuncSetCipherSuites, setCipherSuites, "SSL_CTX_set_ciphersuites"); + DLSYM(sslHandle, FuncSslCtxCtrl, sslCtxCtrl, "SSL_CTX_ctrl"); + DLSYM(sslHandle, FuncSslGetCurrentCipher, sslGetCurrentCipher, "SSL_get_current_cipher"); + DLSYM(sslHandle, FuncSslGetVersion, sslGetVersion, "SSL_get_version"); + DLSYM(sslHandle, FuncUsePrivKeyFile, usePrivKeyFile, "SSL_CTX_use_PrivateKey_file"); + DLSYM(sslHandle, FuncUseCertChainFile, useCertChainFile, "SSL_CTX_use_certificate_chain_file"); + DLSYM(sslHandle, FuncSslCtxSetVerify, sslCtxSetVerify, "SSL_CTX_set_verify"); + DLSYM(sslHandle, FuncSetDefaultPasswdCbUserdata, setDefaultPasswdCbUserdata, + "SSL_CTX_set_default_passwd_cb_userdata"); + DLSYM(sslHandle, FuncSetCertVerifyCallback, setCertVerifyCallback, "SSL_CTX_set_cert_verify_callback"); + DLSYM(sslHandle, FuncLoadVerifyLocations, loadVerifyLocations, "SSL_CTX_load_verify_locations"); + DLSYM(sslHandle, FuncCheckPrivateKey, checkPrivateKey, "SSL_CTX_check_private_key"); + DLSYM(sslHandle, FuncSslGetVerifyResult, sslGetVerifyResult, "SSL_get_verify_result"); + DLSYM_UPDATE(sslHandle, FuncSslGetPeerCertificate, sslGetPeerCertificate, "SSL_get_peer_certificate", + "SSL_get1_peer_certificate"); + DLSYM(sslHandle, FuncSslCtxSetOptions, SslCtxSetOptions, "SSL_CTX_set_options"); + DLSYM(sslHandle, FuncSslCtxSetPskFindSessionCallback, SslCtxSetPskFindSessionCallback, + "SSL_CTX_set_psk_find_session_callback"); + DLSYM(sslHandle, FuncSslCtxSetPskUseSessionCallback, SslCtxSetPskUseSessionCallback, + "SSL_CTX_set_psk_use_session_callback"); + DLSYM(sslHandle, FuncSslSessionNew, SslSessionNew, "SSL_SESSION_new"); + DLSYM(sslHandle, FuncSslSessionSet1MasterKey, SslSessionSet1MasterKey, "SSL_SESSION_set1_master_key"); + DLSYM(sslHandle, FuncSslSessionSetProtocolVersion, SslSessionSetProtocolVersion, + "SSL_SESSION_set_protocol_version"); + DLSYM(sslHandle, FuncSslSessionSetCipher, SslSessionSetCipher, "SSL_SESSION_set_cipher"); + DLSYM(sslHandle, FuncSslCipherFind, SslCipherFind, "SSL_CIPHER_find"); + return 0; +} + +int SSLAPI::LoadCryptoSymbols(void *cryptoHandle) +{ + DLSYM(cryptoHandle, FuncEvpCipherCtxNew, evpCipherCtxNew, "EVP_CIPHER_CTX_new"); + DLSYM(cryptoHandle, FuncEvpCipherCtxFree, evpCipherCtxFree, "EVP_CIPHER_CTX_free"); + DLSYM(cryptoHandle, FuncEvpCipherCtxCtrl, evpCipherCtxCtrl, "EVP_CIPHER_CTX_ctrl"); + DLSYM(cryptoHandle, FuncEvpEncryptInitEx, evpEncryptInitEx, "EVP_EncryptInit_ex"); + DLSYM(cryptoHandle, FuncEvpEncryptUpdate, evpEncryptUpdate, "EVP_EncryptUpdate"); + DLSYM(cryptoHandle, FuncEvpEncryptFinalEx, evpEncryptFinalEx, "EVP_EncryptFinal_ex"); + DLSYM(cryptoHandle, FuncEvpDecryptInitEx, evpDecryptInitEx, "EVP_DecryptInit_ex"); + DLSYM(cryptoHandle, FuncEvpDecryptUpdate, evpDecryptUpdate, "EVP_DecryptUpdate"); + DLSYM(cryptoHandle, FuncEvpDecryptFinalEx, evpDecryptFinalEx, "EVP_DecryptFinal_ex"); + DLSYM(cryptoHandle, FuncEvpAesCipher, evpAes128Gcm, "EVP_aes_128_gcm"); + DLSYM(cryptoHandle, FuncEvpAesCipher, evpAes256Gcm, "EVP_aes_256_gcm"); + DLSYM(cryptoHandle, FuncEvpAesCipher, evpAes128Ccm, "EVP_aes_128_ccm"); + DLSYM(cryptoHandle, FuncEvpAesCipher, evpChacha20Poly1305, "EVP_chacha20_poly1305"); + + DLSYM(cryptoHandle, FuncRandPoll, randPoll, "RAND_poll"); + DLSYM(cryptoHandle, FuncRandStatus, randStatus, "RAND_status"); + DLSYM(cryptoHandle, FuncRandBytes, randBytes, "RAND_bytes"); + DLSYM(cryptoHandle, FuncRandBytes, randPrivBytes, "RAND_priv_bytes"); + DLSYM(cryptoHandle, FuncRandSeed, randSeed, "RAND_seed"); + + DLSYM(cryptoHandle, FuncX509VerifyCert, x509VerifyCert, "X509_verify_cert"); + DLSYM(cryptoHandle, FuncX509VerifyCertErrorString, x509VerifyCertErrorString, "X509_verify_cert_error_string"); + DLSYM(cryptoHandle, FuncX509StoreCtxGetError, x509StoreCtxGetError, "X509_STORE_CTX_get_error"); + DLSYM(cryptoHandle, FuncPemReadBioX509Crl, pemReadBioX509Crl, "PEM_read_bio_X509_CRL"); + DLSYM(cryptoHandle, FuncBioSFile, bioSFile, "BIO_s_file"); + DLSYM(cryptoHandle, FuncBioNew, bioNew, "BIO_new"); + DLSYM(cryptoHandle, FuncBioFree, bioFree, "BIO_free"); + DLSYM(cryptoHandle, FuncBioCtrl, bioCtrl, "BIO_ctrl"); + DLSYM(cryptoHandle, FuncX509StoreCtxGet0Store, x509StoreCtxGet0Store, "X509_STORE_CTX_get0_store"); + DLSYM(cryptoHandle, FuncX509StoreCtxSetFlags, x509StoreCtxSetFlags, "X509_STORE_CTX_set_flags"); + DLSYM(cryptoHandle, FuncX509StoreAddCrl, x509StoreAddCrl, "X509_STORE_add_crl"); + DLSYM(cryptoHandle, FuncX509CrlFree, x509CrlFree, "X509_CRL_free"); + return 0; +} + +int SSLAPI::LoadOpensslAPI() +{ + NN_LOG_INFO("Starting to load openssl api"); + if (gLoaded) { + return 0; + } + + std::string libSslPath; + std::string libCryptoPath; + if (GetLibPath(libSslPath, libCryptoPath) != 0) { + return -1; + } + + auto sslHandle = dlopen(libSslPath.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (sslHandle == nullptr) { + NN_LOG_ERROR("Failed to dlopen libssl.so err: " << dlerror()); + return -1; + } + + if (LoadSSLSymbols(sslHandle) == -1) { + dlclose(sslHandle); + return -1; + } + + auto cryptoHandle = dlopen(libCryptoPath.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (cryptoHandle == nullptr) { + NN_LOG_ERROR("Failed to dlopen libcrypto.so err: " << dlerror()); + dlclose(sslHandle); + return -1; + } + + if (LoadCryptoSymbols(cryptoHandle) == -1) { + dlclose(sslHandle); + dlclose(cryptoHandle); + return -1; + } + gLoaded = true; + return 0; +} +} +} \ No newline at end of file diff --git a/src/under_api/openssl/openssl_api_dl.h b/src/under_api/openssl/openssl_api_dl.h new file mode 100644 index 0000000000000000000000000000000000000000..7ee5bfb81636eaa5124ad1176b636f375f28ce4f --- /dev/null +++ b/src/under_api/openssl/openssl_api_dl.h @@ -0,0 +1,204 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_UNDER_API_OPENSSL_API_DL_H_2134 +#define HCOM_UNDER_API_OPENSSL_API_DL_H_2134 + +#include "hcom.h" + +namespace ock { +namespace hcom { +// Openssl datatype +using OPENSSL_INIT_SETTINGS = struct ossl_init_settings_st; +using SSL_METHOD = struct ssl_method_st; +using SSL = struct ssl_st; +using SSL_CTX = struct ssl_ctx_st; +using X509_STORE_CTX = struct x509_store_ctx_st; +using X509_CRL = struct x509_crl; +using ENGINE = struct engine_st; +using EVP_CIPHER = struct evp_cipher_st; +using EVP_CIPHER_CTX = struct evp_cipher_ctx_st; +using SSL_CIPHER = struct ssl_cipher_st; +using X509 = struct x509_st; +using BIO = struct bio; +using PEM_PASSWORD_CB = struct pem_password_cb; +using BIO_METHOD = struct bio_method; +using X509_STORE = struct x509_store; +using EVP_MD = struct evp_md_st; +using SSL_SESSION = struct ssl_session_st; + +using SSL_psk_find_session_cb_func = int (*)(SSL *ssl, const unsigned char *identity, size_t identity_len, + SSL_SESSION **sess); +using SSL_psk_use_session_cb_func = int (*)(SSL *ssl, const EVP_MD *md, const unsigned char **id, size_t *idlen, + SSL_SESSION **sess); + +using FuncInit = int (*)(uint64_t, const OPENSSL_INIT_SETTINGS *); +using FuncOpensslCleanup = void (*)(); +using FuncGetMethod = const SSL_METHOD *(*)(void); +using FuncSslOperation = int (*)(SSL *); +using FuncSslFd = int (*)(SSL *, int); +using FuncSslNew = SSL *(*)(SSL_CTX *); +using FuncSslFree = void (*)(SSL *); +using FuncSslCtxNew = SSL_CTX *(*)(const SSL_METHOD *); +using FuncSslCtxFree = void (*)(SSL_CTX *); +using FuncSslWrite = int (*)(SSL *, const void *, int); +using FuncSslRead = int (*)(SSL *, void *, int); +using FuncSslGetError = int (*)(const SSL *, int); + +using FuncSetCipherSuites = int (*)(SSL_CTX *, const char *); +// SSL_CTX_set_min_proto_version +using FuncSslCtxCtrl = long (*)(SSL_CTX *, int, long, void *); +using FuncSslGetCurrentCipher = const SSL_CIPHER *(*)(const SSL *); +using FuncSslGetVersion = const char *(*)(const SSL *); + +using FuncUsePrivKeyFile = int (*)(SSL_CTX *ctx, const char *, int); +using FuncUseCertChainFile = int (*)(SSL_CTX *, const char *); +using FuncSslCtxSetVerify = void (*)(SSL_CTX *, int mode, int (*)(int, X509_STORE_CTX *)); +using FuncSetDefaultPasswdCbUserdata = void (*)(SSL_CTX *, void *); +using FuncSetCertVerifyCallback = void (*)(SSL_CTX *, int (*cb)(X509_STORE_CTX *, void *), void *); +using FuncLoadVerifyLocations = int (*)(SSL_CTX *, const char *, const char *); +using FuncCheckPrivateKey = int (*)(const SSL_CTX *); +using FuncSslGetVerifyResult = long (*)(const SSL *); +using FuncSslGetPeerCertificate = X509 *(*)(const SSL *); +using FuncSslCtxSetOptions = int (*)(const SSL_CTX *, int); +using FuncSslCtxSetPskFindSessionCallback = int (*)(SSL_CTX *, SSL_psk_find_session_cb_func); +using FuncSslCtxSetPskUseSessionCallback = int (*)(SSL_CTX *, SSL_psk_use_session_cb_func); + +using FuncSslSessionNew = SSL_SESSION *(*)(); +using FuncSslSessionSet1MasterKey = int (*)(SSL_SESSION *, const unsigned char *, size_t); +using FuncSslSessionSetProtocolVersion = int (*)(SSL_SESSION *, int); +using FuncSslSessionSetCipher = int (*)(SSL_SESSION *, const SSL_CIPHER *); +using FuncSslCipherFind = const SSL_CIPHER *(*)(SSL *, const unsigned char *); + +using FuncEvpAesCipher = const EVP_CIPHER *(*)(); +using FuncEvpCipherCtxNew = EVP_CIPHER_CTX *(*)(); +using FuncEvpCipherCtxFree = void (*)(EVP_CIPHER_CTX *); +using FuncEvpCipherCtxCtrl = int (*)(EVP_CIPHER_CTX *, int, int, void *); +using FuncEvpEncryptInitEx = int (*)(EVP_CIPHER_CTX *, const EVP_CIPHER *, ENGINE *, const unsigned char *, + const unsigned char *); +using FuncEvpEncryptUpdate = int (*)(EVP_CIPHER_CTX *, unsigned char *, int *, const unsigned char *, int); +using FuncEvpEncryptFinalEx = int (*)(EVP_CIPHER_CTX *, unsigned char *, int *); +using FuncEvpDecryptInitEx = FuncEvpEncryptInitEx; +using FuncEvpDecryptUpdate = FuncEvpEncryptUpdate; +using FuncEvpDecryptFinalEx = FuncEvpEncryptFinalEx; + +using FuncRandPoll = int (*)(void); +using FuncRandStatus = FuncRandPoll; +using FuncRandBytes = int (*)(unsigned char *buf, int num); +using FuncRandSeed = void (*)(const void *, int); + +using FuncX509VerifyCert = int (*)(X509_STORE_CTX *ctx); +using FuncX509VerifyCertErrorString = const char *(*)(long n); +using FuncX509StoreCtxGetError = int (*)(const X509_STORE_CTX *ctx); +using FuncPemReadBioX509Crl = X509_CRL *(*)(BIO *bp, X509_CRL **x, PEM_PASSWORD_CB *cb, void *u); +using FuncBioSFile = const BIO_METHOD *(*)(void); +using FuncBioNew = BIO *(*)(const BIO_METHOD *); +using FuncBioFree = void (*)(BIO *b); +using FuncBioCtrl = long (*)(BIO *bp, int cmd, long larg, void *parg); +using FuncX509StoreCtxGet0Store = X509_STORE *(*)(const X509_STORE_CTX *ctx); +using FuncX509StoreCtxSetFlags = void (*)(X509_STORE_CTX *ctx, unsigned long flags); +using FuncX509StoreAddCrl = int (*)(X509_STORE *xs, X509_CRL *x); +using FuncX509CrlFree = void (*)(X509_CRL *x); + +class SSLAPI { +public: + static FuncInit initSsl; + static FuncInit initCypto; + static FuncOpensslCleanup opensslCleanup; + static FuncGetMethod tlsServerMethod; + static FuncGetMethod tlsClientMethod; + static FuncSslOperation sslShutdown; + static FuncSslFd sslSetFd; + static FuncSslNew sslNew; + static FuncSslFree sslFree; + static FuncSslCtxNew sslCtxNew; + static FuncSslCtxFree sslCtxFree; + static FuncSslWrite sslWrite; + static FuncSslRead sslRead; + static FuncSslOperation sslConnect; + static FuncSslOperation sslAccept; + static FuncSslGetError sslGetError; + + static FuncSslCtxCtrl sslCtxCtrl; + static FuncSslGetCurrentCipher sslGetCurrentCipher; + static FuncSslGetVersion sslGetVersion; + static FuncSetCipherSuites setCipherSuites; + static FuncUsePrivKeyFile usePrivKeyFile; + static FuncUseCertChainFile useCertChainFile; + static FuncSslCtxSetVerify sslCtxSetVerify; + static FuncSetDefaultPasswdCbUserdata setDefaultPasswdCbUserdata; + static FuncSetCertVerifyCallback setCertVerifyCallback; + static FuncLoadVerifyLocations loadVerifyLocations; + static FuncCheckPrivateKey checkPrivateKey; + static FuncSslGetVerifyResult sslGetVerifyResult; + static FuncSslGetPeerCertificate sslGetPeerCertificate; + static FuncSslCtxSetOptions SslCtxSetOptions; + static FuncSslCtxSetPskFindSessionCallback SslCtxSetPskFindSessionCallback; + static FuncSslCtxSetPskUseSessionCallback SslCtxSetPskUseSessionCallback; + + static FuncSslSessionNew SslSessionNew; + static FuncSslSessionSet1MasterKey SslSessionSet1MasterKey; + static FuncSslSessionSetProtocolVersion SslSessionSetProtocolVersion; + static FuncSslSessionSetCipher SslSessionSetCipher; + static FuncSslCipherFind SslCipherFind; + + static FuncEvpAesCipher evpAes128Gcm; + static FuncEvpAesCipher evpAes256Gcm; + static FuncEvpAesCipher evpAes128Ccm; + static FuncEvpAesCipher evpChacha20Poly1305; + + static FuncEvpCipherCtxNew evpCipherCtxNew; + static FuncEvpCipherCtxFree evpCipherCtxFree; + static FuncEvpCipherCtxCtrl evpCipherCtxCtrl; + + static FuncEvpEncryptInitEx evpEncryptInitEx; + static FuncEvpEncryptUpdate evpEncryptUpdate; + static FuncEvpEncryptFinalEx evpEncryptFinalEx; + static FuncEvpDecryptInitEx evpDecryptInitEx; + static FuncEvpDecryptUpdate evpDecryptUpdate; + static FuncEvpDecryptFinalEx evpDecryptFinalEx; + + static FuncRandPoll randPoll; + static FuncRandStatus randStatus; + static FuncRandBytes randBytes; + static FuncRandBytes randPrivBytes; + static FuncRandSeed randSeed; + + static FuncX509VerifyCert x509VerifyCert; + static FuncX509VerifyCertErrorString x509VerifyCertErrorString; + static FuncX509StoreCtxGetError x509StoreCtxGetError; + static FuncPemReadBioX509Crl pemReadBioX509Crl; + static FuncBioSFile bioSFile; + static FuncBioNew bioNew; + static FuncBioFree bioFree; + static FuncBioCtrl bioCtrl; + static FuncX509StoreCtxGet0Store x509StoreCtxGet0Store; + static FuncX509StoreCtxSetFlags x509StoreCtxSetFlags; + static FuncX509StoreAddCrl x509StoreAddCrl; + static FuncX509CrlFree x509CrlFree; + + static int LoadOpensslAPI(); + +private: + static const char *gOpensslEnvPath; + static const char *gOpensslLibSslName; + static const char *gOpensslLibCryptoName; + static const char *gSep; + static bool gLoaded; + + static int GetLibPath(std::string &libSslPath, std::string &libCryptoPath); + static int LoadSSLSymbols(void *sslHandle); + static int LoadCryptoSymbols(void *cryptoHandle); +}; +} +} + +#endif // HCOM_UNDER_API_OPENSSL_API_DL_H_2134 diff --git a/src/under_api/openssl/openssl_api_wrapper.h b/src/under_api/openssl/openssl_api_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..3eca8bd965ebaf5908f8c5001969245cdb571458 --- /dev/null +++ b/src/under_api/openssl/openssl_api_wrapper.h @@ -0,0 +1,367 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_OPENSSL_API_WRAPPER_H +#define HCOM_OPENSSL_API_WRAPPER_H + +#include "openssl_api_dl.h" + +namespace ock { +namespace hcom { +class HcomSsl { +public: + static const uint32_t SSL_VERIFY_NONE = NN_NO0; + static const uint32_t SSL_VERIFY_PEER = NN_NO1; + static const uint32_t SSL_VERIFY_FAIL_IF_NO_PEER_CERT = NN_NO2; + static const uint32_t SSL_FILETYPE_PEM = NN_NO1; + static const uint32_t EVP_CTRL_AEAD_SET_IVLEN = NN_NO9; + static const uint32_t EVP_CTRL_AEAD_GET_TAG = NN_NO16; + static const uint32_t EVP_CTRL_AEAD_SET_TAG = NN_NO17; + static const uint32_t OPENSSL_INIT_LOAD_SSL_STRINGS = NN_NO2097152; + static const uint32_t OPENSSL_INIT_LOAD_CRYPTO_STRINGS = NN_NO2; + static const uint32_t SSL_CTRL_SET_MIN_PROTO_VERSION = NN_NO123; + static const uint32_t SSL_CTRL_SET_MAX_PROTO_VERSION = NN_NO124; + static const uint32_t SSL_ERROR_WANT_READ = NN_NO2; + static const uint32_t SSL_ERROR_WANT_WRITE = NN_NO3; + static const uint32_t SSL_NO_TLS1_2_RENEGOTIATION = NN_NO262144; + + static const uint32_t BIO_C_SET_FILENAME = NN_NO108; + static const uint32_t BIO_CLOSE = NN_NO1; + static const uint32_t BIO_FP_READ = NN_NO2; + static const uint32_t X509_V_FLAG_CRL_CHECK = NN_NO4; + + static int OpensslInitSsl(uint64_t opts, const OPENSSL_INIT_SETTINGS *settings) + { + return SSLAPI::initSsl(opts, settings); + } + + static inline int OpensslInitCrypto(uint64_t opts, const OPENSSL_INIT_SETTINGS *settings) + { + return SSLAPI::initCypto(opts, settings); + } + + static inline const SSL_METHOD *TlsClientMethod() + { + return SSLAPI::tlsClientMethod(); + } + + static inline const SSL_METHOD *TlsServerMethod() + { + return SSLAPI::tlsServerMethod(); + } + + static inline int SslShutdown(SSL *s) + { + return SSLAPI::sslShutdown(s); + } + + static inline int SslSetFd(SSL *s, int fd) + { + return SSLAPI::sslSetFd(s, fd); + } + + static inline SSL *SslNew(SSL_CTX *ctx) + { + return SSLAPI::sslNew(ctx); + } + + static inline void SslFree(SSL *s) + { + SSLAPI::sslFree(s); + } + + static SSL_CTX *SslCtxNew(const SSL_METHOD *method) + { + return SSLAPI::sslCtxNew(method); + } + + static inline void SslCtxFree(SSL_CTX *ctx) + { + SSLAPI::sslCtxFree(ctx); + } + + static inline int SslWrite(SSL *s, const void *buf, int num) + { + return SSLAPI::sslWrite(s, buf, num); + } + + static inline int SslRead(SSL *s, void *buf, int num) + { + return SSLAPI::sslRead(s, buf, num); + } + + static inline int SslConnect(SSL *s) + { + return SSLAPI::sslConnect(s); + } + + static inline int SslAccept(SSL *s) + { + return SSLAPI::sslAccept(s); + } + + static inline int SslGetError(const SSL *s, int retCode) + { + return SSLAPI::sslGetError(s, retCode); + } + + static inline int SslCtxSetCipherSuites(SSL_CTX *ctx, const char *str) + { + return SSLAPI::setCipherSuites(ctx, str); + } + + static inline long SslCtxCtrl(SSL_CTX *ctx, int cmd, long larg, void *parg) + { + return SSLAPI::sslCtxCtrl(ctx, cmd, larg, parg); + } + + static inline const char *SslGetVersion(const SSL *ssl) + { + return SSLAPI::sslGetVersion(ssl); + } + + static inline void SslCtxSetVerify(SSL_CTX *ctx, int mode, int (*cb)(int, X509_STORE_CTX *)) + { + SSLAPI::sslCtxSetVerify(ctx, mode, cb); + } + + static inline int SslCtxUsePrivateKeyFile(SSL_CTX *ctx, const char *file, int type) + { + return SSLAPI::usePrivKeyFile(ctx, file, type); + } + + static inline int SslCtxUseCertificateChainFile(SSL_CTX *ctx, const char *file) + { + return SSLAPI::useCertChainFile(ctx, file); + } + + static inline void SslCtxSetDefaultPasswdCbUserdata(SSL_CTX *ctx, void *u) + { + SSLAPI::setDefaultPasswdCbUserdata(ctx, u); + } + + static inline void SslCtxSetCertVerifyCallback(SSL_CTX *ctx, int (*cb)(X509_STORE_CTX *, void *), void *arg) + { + SSLAPI::setCertVerifyCallback(ctx, cb, arg); + } + + static inline int SslCtxLoadVerifyLocations(SSL_CTX *ctx, const char *cafile, const char *capath) + { + return SSLAPI::loadVerifyLocations(ctx, cafile, capath); + } + + static inline int SslCtxCheckPrivateKey(const SSL_CTX *ctx) + { + return SSLAPI::checkPrivateKey(ctx); + } + + static inline void SslCtxSetPskFindSessionCallback(SSL_CTX *ctx, SSL_psk_find_session_cb_func cb) + { + SSLAPI::SslCtxSetPskFindSessionCallback(ctx, cb); + } + + static inline void SslCtxSetPskUseSessionCallback(SSL_CTX *ctx, SSL_psk_use_session_cb_func cb) + { + SSLAPI::SslCtxSetPskUseSessionCallback(ctx, cb); + } + + static inline SSL_SESSION *SslSessionNew() + { + return SSLAPI::SslSessionNew(); + } + + static inline int SslSessionSet1MasterKey(SSL_SESSION *sess, const unsigned char *in, size_t len) + { + return SSLAPI::SslSessionSet1MasterKey(sess, in, len); + } + + static inline int SslSessionSetProtocolVersion(SSL_SESSION *sess, int version) + { + return SSLAPI::SslSessionSetProtocolVersion(sess, version); + } + + static inline int SslSessionSetCipher(SSL_SESSION *sess, const SSL_CIPHER *cipher) + { + return SSLAPI::SslSessionSetCipher(sess, cipher); + } + + static inline const SSL_CIPHER *SslCipherFind(SSL *ssl, const unsigned char *ptr) + { + return SSLAPI::SslCipherFind(ssl, ptr); + } + + static inline X509 *SslGetPeerCertificate(const SSL *ssl) + { + return SSLAPI::sslGetPeerCertificate(ssl); + } + + static inline long SslGetVerifyResult(const SSL *ssl) + { + return SSLAPI::sslGetVerifyResult(ssl); + } + + static inline int SslCtxSetOption(const SSL_CTX *ctx, int options) + { + return SSLAPI::SslCtxSetOptions(ctx, options); + } + + static inline const EVP_CIPHER *EvpAes128Gcm() + { + return SSLAPI::evpAes128Gcm(); + } + + static inline const EVP_CIPHER *EvpAes256Gcm() + { + return SSLAPI::evpAes256Gcm(); + } + + static inline const EVP_CIPHER *EvpAes128Ccm() + { + return SSLAPI::evpAes128Ccm(); + } + + static inline const EVP_CIPHER *EvpChacha20Poly1305() + { + return SSLAPI::evpChacha20Poly1305(); + } + + static inline EVP_CIPHER_CTX *EvpCipherCtxNew() + { + return SSLAPI::evpCipherCtxNew(); + } + + static inline void EvpCipherCtxFree(EVP_CIPHER_CTX *ctx) + { + SSLAPI::evpCipherCtxFree(ctx); + } + + static inline int EvpCipherCtxCtrl(EVP_CIPHER_CTX *ctx, int type, int arg, void *ptr) + { + return SSLAPI::evpCipherCtxCtrl(ctx, type, arg, ptr); + } + + static inline int EvpEncryptInitEx(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *impl, + const unsigned char *key, const unsigned char *iv) + { + return SSLAPI::evpEncryptInitEx(ctx, cipher, impl, key, iv); + } + + static inline int EvpEncryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl, const unsigned char *in, + int inl) + { + return SSLAPI::evpEncryptUpdate(ctx, out, outl, in, inl); + } + + static inline int EvpEncryptFinalEx(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl) + { + return SSLAPI::evpEncryptFinalEx(ctx, out, outl); + } + + static inline int EvpDecryptInitEx(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *impl, + const unsigned char *key, const unsigned char *iv) + { + return SSLAPI::evpDecryptInitEx(ctx, cipher, impl, key, iv); + } + + static inline int EvpDecryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl, const unsigned char *in, + int inl) + { + return SSLAPI::evpDecryptUpdate(ctx, out, outl, in, inl); + } + + static inline int EvpDecryptFinalEx(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl) + { + return SSLAPI::evpDecryptFinalEx(ctx, out, outl); + } + + static inline int RandPoll() + { + return SSLAPI::randPoll(); + } + + static inline int RandStatus() + { + return SSLAPI::randStatus(); + } + + static inline int RandPrivBytes(unsigned char *buf, int num) + { + return SSLAPI::randPrivBytes(buf, num); + } + + static inline int X509VerifyCert(X509_STORE_CTX *ctx) + { + return SSLAPI::x509VerifyCert(ctx); + } + + static inline const char *X509VerifyCertErrorString(long n) + { + return SSLAPI::x509VerifyCertErrorString(n); + } + + static inline int X509StoreCtxGetError(X509_STORE_CTX *ctx) + { + return SSLAPI::x509StoreCtxGetError(ctx); + } + + static inline X509_CRL *PemReadBioX509Crl(BIO *bp, X509_CRL **x, PEM_PASSWORD_CB *cb, void *u) + { + return SSLAPI::pemReadBioX509Crl(bp, x, cb, u); + } + + static inline const BIO_METHOD *BioSFile(void) + { + return SSLAPI::bioSFile(); + } + + static inline BIO *BioNew(const BIO_METHOD *bioMethod) + { + return SSLAPI::bioNew(bioMethod); + } + + static inline int BioCtrl(BIO *bp, int cmd, long larg, void *parg) + { + return SSLAPI::bioCtrl(bp, cmd, larg, parg); + } + + static inline void BioFree(BIO *b) + { + return SSLAPI::bioFree(b); + } + + static inline X509_STORE *X509StoreCtxGet0Store(const X509_STORE_CTX *ctx) + { + return SSLAPI::x509StoreCtxGet0Store(ctx); + } + static inline void X509StoreCtxSetFlags(X509_STORE_CTX *ctx, unsigned long flags) + { + return SSLAPI::x509StoreCtxSetFlags(ctx, flags); + } + static inline int X509StoreAddCrl(X509_STORE *xs, X509_CRL *x) + { + return SSLAPI::x509StoreAddCrl(xs, x); + } + static inline void X509CrlFree(X509_CRL *x) + { + return SSLAPI::x509CrlFree(x); + } + + static inline int Load() + { + return SSLAPI::LoadOpensslAPI(); + } + + static inline void UnLoad() {} +}; +} +} +#endif // HCOM_OPENSSL_API_WRAPPER_H diff --git a/src/under_api/urma/urma_api_dl.cpp b/src/under_api/urma/urma_api_dl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0962e4de0e1e4f6562a58175d4fdb6ed1b857029 --- /dev/null +++ b/src/under_api/urma/urma_api_dl.cpp @@ -0,0 +1,284 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED +#include + +#if defined(TEST_LLT) && defined(MOCK_URMA) +#include "fake_urma.h" +#endif +#include "hcom_log.h" +#include "urma_api_dl.h" + +using namespace ock::hcom; + +URMA_INIT UrmaAPI::hcomInnerUrmaInit = nullptr; +URMA_UNINIT UrmaAPI::hcomInnerUrmaUninit = nullptr; +URMA_GET_DEVICE_LIST UrmaAPI::hcomInnerUrmaGetDeviceList = nullptr; +URMA_FREE_DEVICE_LIST UrmaAPI::hcomInnerUrmaFreeDeviceList = nullptr; +URMA_GET_EID_LIST UrmaAPI::hcomInnerUrmaGetEidList = nullptr; +URMA_FREE_EID_LIST UrmaAPI::hcomInnerUrmaFreeEidList = nullptr; +URMA_GET_DEVICE_BY_NAME UrmaAPI::hcomInnerUrmaGetDeviceByName = nullptr; +URMA_GET_DEVICE_BY_EID UrmaAPI::hcomInnerUrmaGetDeviceByEid = nullptr; +URMA_QUERY_DEVICE UrmaAPI::hcomInnerUrmaQueryDevice = nullptr; +URMA_CREATE_CONTEXT UrmaAPI::hcomInnerUrmaCreateContext = nullptr; +URMA_DELETE_CONTEXT UrmaAPI::hcomInnerUrmaDeleteContext = nullptr; +URMA_CREATE_JFC UrmaAPI::hcomInnerUrmaCreateJfc = nullptr; +URMA_MODIFY_JFC UrmaAPI::hcomInnerUrmaModifyJfc = nullptr; +URMA_DELETE_JFC UrmaAPI::hcomInnerUrmaDeleteJfc = nullptr; +URMA_CREATE_JFS UrmaAPI::hcomInnerUrmaCreateJfs = nullptr; +URMA_MODIFY_JFS UrmaAPI::hcomInnerUrmaModifyJfs = nullptr; +URMA_QUERY_JFS UrmaAPI::hcomInnerUrmaQueryJfs = nullptr; +URMA_DELETE_JFS UrmaAPI::hcomInnerUrmaDeleteJfs = nullptr; +URMA_FLUSH_JFS UrmaAPI::hcomInnerUrmaFlushJfs = nullptr; +URMA_CREATE_JFR UrmaAPI::hcomInnerUrmaCreateJfr = nullptr; +URMA_MODIFY_JFR UrmaAPI::hcomInnerUrmaModifyJfr = nullptr; +URMA_QUERY_JFR UrmaAPI::hcomInnerUrmaQueryJfr = nullptr; +URMA_DELETE_JFR UrmaAPI::hcomInnerUrmaDeleteJfr = nullptr; +URMA_IMPORT_JFR UrmaAPI::hcomInnerUrmaImportJfr = nullptr; +URMA_UNIMPORT_JFR UrmaAPI::hcomInnerUrmaUnimportJfr = nullptr; +URMA_ADVISE_JFR UrmaAPI::hcomInnerUrmaAdviseJfr = nullptr; +URMA_UNADVISE_JFR UrmaAPI::hcomInnerUrmaUnadviseJfr = nullptr; +URMA_CREATE_JETTY UrmaAPI::hcomInnerUrmaCreateJetty = nullptr; +URMA_MODIFY_JETTY UrmaAPI::hcomInnerUrmaModifyJetty = nullptr; +URMA_QUERY_JETTY UrmaAPI::hcomInnerUrmaQueryJetty = nullptr; +URMA_DELETE_JETTY UrmaAPI::hcomInnerUrmaDeleteJetty = nullptr; +URMA_IMPORT_JETTY UrmaAPI::hcomInnerUrmaImportJetty = nullptr; +URMA_UNIMPORT_JETTY UrmaAPI::hcomInnerUrmaUnimportJetty = nullptr; +URMA_ADVISE_JETTY UrmaAPI::hcomInnerUrmaAdviseJetty = nullptr; +URMA_UNADVISE_JETTY UrmaAPI::hcomInnerUrmaUnadviseJetty = nullptr; +URMA_BIND_JETTY UrmaAPI::hcomInnerUrmaBindJetty = nullptr; +URMA_UNBIND_JETTY UrmaAPI::hcomInnerUrmaUnbindJetty = nullptr; +URMA_FLUSH_JETTY UrmaAPI::hcomInnerUrmaFlushJetty = nullptr; +URMA_CREATE_JETTY_GRP UrmaAPI::hcomInnerUrmaCreateJettyGrp = nullptr; +URMA_DELETE_JETTY_GRP UrmaAPI::hcomInnerUrmaDeleteJettyGrp = nullptr; +URMA_CREATE_JFCE UrmaAPI::hcomInnerUrmaCreateJfce = nullptr; +URMA_DELETE_JFCE UrmaAPI::hcomInnerUrmaDeleteJfce = nullptr; +URMA_GET_ASYNC_EVENT UrmaAPI::hcomInnerUrmaGetAsyncEvent = nullptr; +URMA_ACK_ASYNC_EVENT UrmaAPI::hcomInnerUrmaAckAsyncEvent = nullptr; +URMA_ALLOC_TOKEN_ID UrmaAPI::hcomInnerUrmaAllocTokenId = nullptr; +URMA_FREE_TOKEN_ID UrmaAPI::hcomInnerUrmaFreeTokenId = nullptr; +URMA_REGISTER_SEG UrmaAPI::hcomInnerUrmaRegisterSeg = nullptr; +URMA_UNREGISTER_SEG UrmaAPI::hcomInnerUrmaUnregisterSeg = nullptr; +URMA_IMPORT_SEG UrmaAPI::hcomInnerUrmaImportSeg = nullptr; +URMA_UNIMPORT_SEG UrmaAPI::hcomInnerUrmaUnimportSeg = nullptr; +URMA_POST_JFS_WR UrmaAPI::hcomInnerUrmaPostJfsWr = nullptr; +URMA_POST_JFR_WR UrmaAPI::hcomInnerUrmaPostJfrWr = nullptr; +URMA_POST_JETTY_SEND_WR UrmaAPI::hcomInnerUrmaPostJettySendWr = nullptr; +URMA_POST_JETTY_RECV_WR UrmaAPI::hcomInnerUrmaPostJettyRecvWr = nullptr; +URMA_WRITE UrmaAPI::hcomInnerUrmaWrite = nullptr; +URMA_READ UrmaAPI::hcomInnerUrmaRead = nullptr; +URMA_SEND UrmaAPI::hcomInnerUrmaSend = nullptr; +URMA_RECV UrmaAPI::hcomInnerUrmaRecv = nullptr; +URMA_POLL_JFC UrmaAPI::hcomInnerUrmaPollJfc = nullptr; +URMA_REARM_JFC UrmaAPI::hcomInnerUrmaRearmJfc = nullptr; +URMA_WAIT_JFC UrmaAPI::hcomInnerUrmaWaitJfc = nullptr; +URMA_ACK_JFC UrmaAPI::hcomInnerUrmaAckJfc = nullptr; +URMA_USER_CTL UrmaAPI::hcomInnerUrmaUserCtl = nullptr; +URMA_REGISTER_LOG_FUNC UrmaAPI::hcomInnerUrmaRegisterLogFunc = nullptr; +URMA_UNREGISTER_LOG_FUNC UrmaAPI::hcomInnerUrmaUnregisterLogFunc = nullptr; +URMA_LOG_GET_LEVEL UrmaAPI::hcomInnerUrmaLogGetLevel = nullptr; +URMA_LOG_SET_LEVEL UrmaAPI::hcomInnerUrmaLogSetLevel = nullptr; +URMA_STR_TO_EID UrmaAPI::hcomInnerUrmaStrToEid = nullptr; +URMA_LOG_SET_THREAD_TAG UrmaAPI::hcomInnerUrmaLogSetThreadTag = nullptr; + +bool UrmaAPI::gLoaded = false; + +bool UrmaAPI::IsLoaded() +{ + return gLoaded; +} + +#if !defined(TEST_LLT) || !defined(MOCK_URMA) +#define DLSYM(type, ptr, sym) \ + do { \ + auto ptr1 = dlsym(handle, sym); \ + if (ptr1 == nullptr) { \ + NN_LOG_ERROR("Failed to load function " << sym << ", error " << dlerror()); \ + dlclose(handle); \ + return -1; \ + } \ + ptr = (type)ptr1; \ + } while (0) + +int UrmaAPI::LoadUrmaAPI() +{ + if (gLoaded) { + return 0; + } + + // UBC 多路径使用虚拟聚合设备,依赖 liburma_bond.so, 而它又会在 liburma.so 中隐式地打开,且依赖 liburma.so 中的符号。 + void *handle = dlopen(URMA_SO_PATH, RTLD_NOW | RTLD_GLOBAL); + if (handle == nullptr) { + NN_LOG_ERROR("Failed to load verbs so " << URMA_SO_PATH << ", error " << dlerror()); + return -1; + } + + DLSYM(URMA_INIT, UrmaAPI::hcomInnerUrmaInit, "urma_init"); + DLSYM(URMA_UNINIT, UrmaAPI::hcomInnerUrmaUninit, "urma_uninit"); + DLSYM(URMA_GET_DEVICE_LIST, UrmaAPI::hcomInnerUrmaGetDeviceList, "urma_get_device_list"); + DLSYM(URMA_FREE_DEVICE_LIST, UrmaAPI::hcomInnerUrmaFreeDeviceList, "urma_free_device_list"); + DLSYM(URMA_GET_EID_LIST, UrmaAPI::hcomInnerUrmaGetEidList, "urma_get_eid_list"); + DLSYM(URMA_FREE_EID_LIST, UrmaAPI::hcomInnerUrmaFreeEidList, "urma_free_eid_list"); + DLSYM(URMA_GET_DEVICE_BY_NAME, UrmaAPI::hcomInnerUrmaGetDeviceByName, "urma_get_device_by_name"); + DLSYM(URMA_GET_DEVICE_BY_EID, UrmaAPI::hcomInnerUrmaGetDeviceByEid, "urma_get_device_by_eid"); + DLSYM(URMA_QUERY_DEVICE, UrmaAPI::hcomInnerUrmaQueryDevice, "urma_query_device"); + DLSYM(URMA_CREATE_CONTEXT, UrmaAPI::hcomInnerUrmaCreateContext, "urma_create_context"); + DLSYM(URMA_DELETE_CONTEXT, UrmaAPI::hcomInnerUrmaDeleteContext, "urma_delete_context"); + DLSYM(URMA_CREATE_JFC, UrmaAPI::hcomInnerUrmaCreateJfc, "urma_create_jfc"); + DLSYM(URMA_MODIFY_JFC, UrmaAPI::hcomInnerUrmaModifyJfc, "urma_modify_jfc"); + DLSYM(URMA_DELETE_JFC, UrmaAPI::hcomInnerUrmaDeleteJfc, "urma_delete_jfc"); + DLSYM(URMA_CREATE_JFS, UrmaAPI::hcomInnerUrmaCreateJfs, "urma_create_jfs"); + DLSYM(URMA_MODIFY_JFS, UrmaAPI::hcomInnerUrmaModifyJfs, "urma_modify_jfs"); + DLSYM(URMA_QUERY_JFS, UrmaAPI::hcomInnerUrmaQueryJfs, "urma_query_jfs"); + DLSYM(URMA_DELETE_JFS, UrmaAPI::hcomInnerUrmaDeleteJfs, "urma_delete_jfs"); + DLSYM(URMA_FLUSH_JFS, UrmaAPI::hcomInnerUrmaFlushJfs, "urma_flush_jfs"); + DLSYM(URMA_CREATE_JFR, UrmaAPI::hcomInnerUrmaCreateJfr, "urma_create_jfr"); + DLSYM(URMA_MODIFY_JFR, UrmaAPI::hcomInnerUrmaModifyJfr, "urma_modify_jfr"); + DLSYM(URMA_QUERY_JFR, UrmaAPI::hcomInnerUrmaQueryJfr, "urma_query_jfr"); + DLSYM(URMA_DELETE_JFR, UrmaAPI::hcomInnerUrmaDeleteJfr, "urma_delete_jfr"); + DLSYM(URMA_IMPORT_JFR, UrmaAPI::hcomInnerUrmaImportJfr, "urma_import_jfr"); + DLSYM(URMA_UNIMPORT_JFR, UrmaAPI::hcomInnerUrmaUnimportJfr, "urma_unimport_jfr"); + DLSYM(URMA_ADVISE_JFR, UrmaAPI::hcomInnerUrmaAdviseJfr, "urma_advise_jfr"); + DLSYM(URMA_UNADVISE_JFR, UrmaAPI::hcomInnerUrmaUnadviseJfr, "urma_unadvise_jfr"); + DLSYM(URMA_CREATE_JETTY, UrmaAPI::hcomInnerUrmaCreateJetty, "urma_create_jetty"); + DLSYM(URMA_MODIFY_JETTY, UrmaAPI::hcomInnerUrmaModifyJetty, "urma_modify_jetty"); + DLSYM(URMA_QUERY_JETTY, UrmaAPI::hcomInnerUrmaQueryJetty, "urma_query_jetty"); + DLSYM(URMA_DELETE_JETTY, UrmaAPI::hcomInnerUrmaDeleteJetty, "urma_delete_jetty"); + DLSYM(URMA_IMPORT_JETTY, UrmaAPI::hcomInnerUrmaImportJetty, "urma_import_jetty"); + DLSYM(URMA_UNIMPORT_JETTY, UrmaAPI::hcomInnerUrmaUnimportJetty, "urma_unimport_jetty"); + DLSYM(URMA_ADVISE_JETTY, UrmaAPI::hcomInnerUrmaAdviseJetty, "urma_advise_jetty"); + DLSYM(URMA_UNADVISE_JETTY, UrmaAPI::hcomInnerUrmaUnadviseJetty, "urma_unadvise_jetty"); + DLSYM(URMA_BIND_JETTY, UrmaAPI::hcomInnerUrmaBindJetty, "urma_bind_jetty"); + DLSYM(URMA_UNBIND_JETTY, UrmaAPI::hcomInnerUrmaUnbindJetty, "urma_unbind_jetty"); + DLSYM(URMA_FLUSH_JETTY, UrmaAPI::hcomInnerUrmaFlushJetty, "urma_flush_jetty"); + DLSYM(URMA_CREATE_JETTY_GRP, UrmaAPI::hcomInnerUrmaCreateJettyGrp, "urma_create_jetty_grp"); + DLSYM(URMA_DELETE_JETTY_GRP, UrmaAPI::hcomInnerUrmaDeleteJettyGrp, "urma_delete_jetty_grp"); + DLSYM(URMA_CREATE_JFCE, UrmaAPI::hcomInnerUrmaCreateJfce, "urma_create_jfce"); + DLSYM(URMA_DELETE_JFCE, UrmaAPI::hcomInnerUrmaDeleteJfce, "urma_delete_jfce"); + DLSYM(URMA_GET_ASYNC_EVENT, UrmaAPI::hcomInnerUrmaGetAsyncEvent, "urma_get_async_event"); + DLSYM(URMA_ACK_ASYNC_EVENT, UrmaAPI::hcomInnerUrmaAckAsyncEvent, "urma_ack_async_event"); + DLSYM(URMA_ALLOC_TOKEN_ID, UrmaAPI::hcomInnerUrmaAllocTokenId, "urma_alloc_token_id"); + DLSYM(URMA_FREE_TOKEN_ID, UrmaAPI::hcomInnerUrmaFreeTokenId, "urma_free_token_id"); + DLSYM(URMA_REGISTER_SEG, UrmaAPI::hcomInnerUrmaRegisterSeg, "urma_register_seg"); + DLSYM(URMA_UNREGISTER_SEG, UrmaAPI::hcomInnerUrmaUnregisterSeg, "urma_unregister_seg"); + DLSYM(URMA_IMPORT_SEG, UrmaAPI::hcomInnerUrmaImportSeg, "urma_import_seg"); + DLSYM(URMA_UNIMPORT_SEG, UrmaAPI::hcomInnerUrmaUnimportSeg, "urma_unimport_seg"); + DLSYM(URMA_POST_JFS_WR, UrmaAPI::hcomInnerUrmaPostJfsWr, "urma_post_jfs_wr"); + DLSYM(URMA_POST_JFR_WR, UrmaAPI::hcomInnerUrmaPostJfrWr, "urma_post_jfr_wr"); + DLSYM(URMA_POST_JETTY_SEND_WR, UrmaAPI::hcomInnerUrmaPostJettySendWr, "urma_post_jetty_send_wr"); + DLSYM(URMA_POST_JETTY_RECV_WR, UrmaAPI::hcomInnerUrmaPostJettyRecvWr, "urma_post_jetty_recv_wr"); + DLSYM(URMA_WRITE, UrmaAPI::hcomInnerUrmaWrite, "urma_write"); + DLSYM(URMA_READ, UrmaAPI::hcomInnerUrmaRead, "urma_read"); + DLSYM(URMA_SEND, UrmaAPI::hcomInnerUrmaSend, "urma_send"); + DLSYM(URMA_RECV, UrmaAPI::hcomInnerUrmaRecv, "urma_recv"); + DLSYM(URMA_POLL_JFC, UrmaAPI::hcomInnerUrmaPollJfc, "urma_poll_jfc"); + DLSYM(URMA_REARM_JFC, UrmaAPI::hcomInnerUrmaRearmJfc, "urma_rearm_jfc"); + DLSYM(URMA_WAIT_JFC, UrmaAPI::hcomInnerUrmaWaitJfc, "urma_wait_jfc"); + DLSYM(URMA_ACK_JFC, UrmaAPI::hcomInnerUrmaAckJfc, "urma_ack_jfc"); + DLSYM(URMA_USER_CTL, UrmaAPI::hcomInnerUrmaUserCtl, "urma_user_ctl"); + DLSYM(URMA_REGISTER_LOG_FUNC, UrmaAPI::hcomInnerUrmaRegisterLogFunc, "urma_register_log_func"); + DLSYM(URMA_UNREGISTER_LOG_FUNC, UrmaAPI::hcomInnerUrmaUnregisterLogFunc, "urma_unregister_log_func"); + DLSYM(URMA_LOG_GET_LEVEL, UrmaAPI::hcomInnerUrmaLogGetLevel, "urma_log_get_level"); + DLSYM(URMA_LOG_SET_LEVEL, UrmaAPI::hcomInnerUrmaLogSetLevel, "urma_log_set_level"); + DLSYM(URMA_STR_TO_EID, UrmaAPI::hcomInnerUrmaStrToEid, "urma_str_to_eid"); + DLSYM(URMA_LOG_SET_THREAD_TAG, UrmaAPI::hcomInnerUrmaLogSetThreadTag, "urma_log_set_thread_tag"); + + NN_LOG_INFO("Success to load urma api"); + gLoaded = true; + + return 0; +} +#else +int UrmaAPI::LoadUrmaAPI() +{ + if (gLoaded) { + return 0; + } + + UrmaAPI::hcomInnerUrmaInit = urma_init; + UrmaAPI::hcomInnerUrmaUninit = urma_uninit; + UrmaAPI::hcomInnerUrmaGetDeviceList = urma_get_device_list; + UrmaAPI::hcomInnerUrmaFreeDeviceList = urma_free_device_list; + UrmaAPI::hcomInnerUrmaGetEidList = urma_get_eid_list; + UrmaAPI::hcomInnerUrmaFreeEidList = urma_free_eid_list; + UrmaAPI::hcomInnerUrmaGetDeviceByName = urma_get_device_by_name; + UrmaAPI::hcomInnerUrmaGetDeviceByEid = urma_get_device_by_eid; + UrmaAPI::hcomInnerUrmaQueryDevice = urma_query_device; + UrmaAPI::hcomInnerUrmaCreateContext = urma_create_context; + UrmaAPI::hcomInnerUrmaDeleteContext = urma_delete_context; + UrmaAPI::hcomInnerUrmaCreateJfc = urma_create_jfc; + UrmaAPI::hcomInnerUrmaModifyJfc = urma_modify_jfc; + UrmaAPI::hcomInnerUrmaDeleteJfc = urma_delete_jfc; + UrmaAPI::hcomInnerUrmaCreateJfs = urma_create_jfs; + UrmaAPI::hcomInnerUrmaModifyJfs = urma_modify_jfs; + UrmaAPI::hcomInnerUrmaQueryJfs = urma_query_jfs; + UrmaAPI::hcomInnerUrmaDeleteJfs = urma_delete_jfs; + UrmaAPI::hcomInnerUrmaFlushJfs = urma_flush_jfs; + UrmaAPI::hcomInnerUrmaCreateJfr = urma_create_jfr; + UrmaAPI::hcomInnerUrmaModifyJfr = urma_modify_jfr; + UrmaAPI::hcomInnerUrmaQueryJfr = urma_query_jfr; + UrmaAPI::hcomInnerUrmaDeleteJfr = urma_delete_jfr; + UrmaAPI::hcomInnerUrmaImportJfr = urma_import_jfr; + UrmaAPI::hcomInnerUrmaUnimportJfr = urma_unimport_jfr; + UrmaAPI::hcomInnerUrmaAdviseJfr = urma_advise_jfr; + UrmaAPI::hcomInnerUrmaUnadviseJfr = urma_unadvise_jfr; + UrmaAPI::hcomInnerUrmaCreateJetty = urma_create_jetty; + UrmaAPI::hcomInnerUrmaModifyJetty = urma_modify_jetty; + UrmaAPI::hcomInnerUrmaQueryJetty = urma_query_jetty; + UrmaAPI::hcomInnerUrmaDeleteJetty = urma_delete_jetty; + UrmaAPI::hcomInnerUrmaImportJetty = urma_import_jetty; + UrmaAPI::hcomInnerUrmaUnimportJetty = urma_unimport_jetty; + UrmaAPI::hcomInnerUrmaAdviseJetty = urma_advise_jetty; + UrmaAPI::hcomInnerUrmaUnadviseJetty = urma_unadvise_jetty; + UrmaAPI::hcomInnerUrmaBindJetty = urma_bind_jetty; + UrmaAPI::hcomInnerUrmaUnbindJetty = urma_unbind_jetty; + UrmaAPI::hcomInnerUrmaFlushJetty = urma_flush_jetty; + UrmaAPI::hcomInnerUrmaCreateJettyGrp = urma_create_jetty_grp; + UrmaAPI::hcomInnerUrmaDeleteJettyGrp = urma_delete_jetty_grp; + UrmaAPI::hcomInnerUrmaCreateJfce = urma_create_jfce; + UrmaAPI::hcomInnerUrmaDeleteJfce = urma_delete_jfce; + UrmaAPI::hcomInnerUrmaGetAsyncEvent = urma_get_async_event; + UrmaAPI::hcomInnerUrmaAckAsyncEvent = urma_ack_async_event; + UrmaAPI::hcomInnerUrmaAllocTokenId = urma_alloc_token_id; + UrmaAPI::hcomInnerUrmaFreeTokenId = urma_free_token_id; + UrmaAPI::hcomInnerUrmaRegisterSeg = urma_register_seg; + UrmaAPI::hcomInnerUrmaUnregisterSeg = urma_unregister_seg; + UrmaAPI::hcomInnerUrmaImportSeg = urma_import_seg; + UrmaAPI::hcomInnerUrmaUnimportSeg = urma_unimport_seg; + UrmaAPI::hcomInnerUrmaPostJfsWr = urma_post_jfs_wr; + UrmaAPI::hcomInnerUrmaPostJfrWr = urma_post_jfr_wr; + UrmaAPI::hcomInnerUrmaPostJettySendWr = urma_post_jetty_send_wr; + UrmaAPI::hcomInnerUrmaPostJettyRecvWr = urma_post_jetty_recv_wr; + UrmaAPI::hcomInnerUrmaWrite = urma_write; + UrmaAPI::hcomInnerUrmaRead = urma_read; + UrmaAPI::hcomInnerUrmaSend = urma_send; + UrmaAPI::hcomInnerUrmaRecv = urma_recv; + UrmaAPI::hcomInnerUrmaPollJfc = urma_poll_jfc; + UrmaAPI::hcomInnerUrmaRearmJfc = urma_rearm_jfc; + UrmaAPI::hcomInnerUrmaWaitJfc = urma_wait_jfc; + UrmaAPI::hcomInnerUrmaAckJfc = urma_ack_jfc; + UrmaAPI::hcomInnerUrmaUserCtl = urma_user_ctl; + UrmaAPI::hcomInnerUrmaRegisterLogFunc = urma_register_log_func; + UrmaAPI::hcomInnerUrmaUnregisterLogFunc = urma_unregister_log_func; + UrmaAPI::hcomInnerUrmaLogGetLevel = urma_log_get_level; + UrmaAPI::hcomInnerUrmaLogSetLevel = urma_log_set_level; + UrmaAPI::hcomInnerUrmaStrToEid = urma_str_to_eid; + UrmaAPI::hcomInnerUrmaLogSetThreadTag = urma_log_set_thread_tag; + + NN_LOG_INFO("Success to load fake iburma"); + gLoaded = true; + + return 0; +} +#endif + +#endif diff --git a/src/under_api/urma/urma_api_dl.h b/src/under_api/urma/urma_api_dl.h new file mode 100644 index 0000000000000000000000000000000000000000..0bd9fd96e46dc7684d5d9e0b1b5b608dc0422858 --- /dev/null +++ b/src/under_api/urma/urma_api_dl.h @@ -0,0 +1,187 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_DYLOADER_IURMA_H +#define HCOM_DYLOADER_IURMA_H +#ifdef UB_BUILD_ENABLED + +#include +#include +#include +#include +#include + +#include "ub/umdk/urma/urma_types.h" + +#define URMA_SO_PATH "liburma.so.0" + +using URMA_INIT = urma_status_t (*)(urma_init_attr_t *conf); +using URMA_UNINIT = urma_status_t (*)(void); +using URMA_GET_DEVICE_LIST = urma_device_t **(*)(int *num_devices); +using URMA_FREE_DEVICE_LIST = void (*)(urma_device_t **device_list); +using URMA_GET_EID_LIST = urma_eid_info_t *(*)(urma_device_t *dev, uint32_t *cnt); +using URMA_FREE_EID_LIST = void (*)(urma_eid_info_t *eid_list); +using URMA_GET_DEVICE_BY_NAME = urma_device_t *(*)(char *dev_name); +using URMA_GET_DEVICE_BY_EID = urma_device_t *(*)(urma_eid_t eid, urma_transport_type_t type); +using URMA_QUERY_DEVICE = urma_status_t (*)(urma_device_t *dev, urma_device_attr_t *dev_attr); +using URMA_CREATE_CONTEXT = urma_context_t *(*)(urma_device_t *dev, uint32_t eid_index); +using URMA_DELETE_CONTEXT = urma_status_t (*)(urma_context_t *ctx); +using URMA_CREATE_JFC = urma_jfc_t *(*)(urma_context_t *ctx, urma_jfc_cfg_t *jfc_cfg); +using URMA_MODIFY_JFC = urma_status_t (*)(urma_jfc_t *jfc, urma_jfc_attr_t *attr); +using URMA_DELETE_JFC = urma_status_t (*)(urma_jfc_t *jfc); +using URMA_CREATE_JFS = urma_jfs_t *(*)(urma_context_t *ctx, urma_jfs_cfg_t *jfs_cfg); +using URMA_MODIFY_JFS = urma_status_t (*)(urma_jfs_t *jfs, urma_jfs_attr_t *attr); +using URMA_QUERY_JFS = urma_status_t (*)(urma_jfs_t *jfs, urma_jfs_cfg_t *cfg, urma_jfs_attr_t *attr); +using URMA_DELETE_JFS = urma_status_t (*)(urma_jfs_t *jfs); +using URMA_FLUSH_JFS = int (*)(urma_jfs_t *jfs, int cr_cnt, urma_cr_t *cr); +using URMA_CREATE_JFR = urma_jfr_t *(*)(urma_context_t *ctx, urma_jfr_cfg_t *jfr_cfg); +using URMA_MODIFY_JFR = urma_status_t (*)(urma_jfr_t *jfr, urma_jfr_attr_t *attr); +using URMA_QUERY_JFR = urma_status_t (*)(urma_jfr_t *jfr, urma_jfr_cfg_t *cfg, urma_jfr_attr_t *attr); +using URMA_DELETE_JFR = urma_status_t (*)(urma_jfr_t *jfr); +using URMA_IMPORT_JFR = urma_target_jetty_t *(*)(urma_context_t *ctx, urma_rjfr_t *rjfr, urma_token_t *token_value); +using URMA_UNIMPORT_JFR = urma_status_t (*)(urma_target_jetty_t *target_jfr); +using URMA_ADVISE_JFR = urma_status_t (*)(urma_jfs_t *jfs, urma_target_jetty_t *tjfr); +using URMA_UNADVISE_JFR = urma_status_t (*)(urma_jfs_t *jfs, urma_target_jetty_t *tjfr); +using URMA_CREATE_JETTY = urma_jetty_t *(*)(urma_context_t *ctx, urma_jetty_cfg_t *jetty_cfg); +using URMA_MODIFY_JETTY = urma_status_t (*)(urma_jetty_t *jetty, urma_jetty_attr_t *attr); +using URMA_QUERY_JETTY = urma_status_t (*)(urma_jetty_t *jetty, urma_jetty_cfg_t *cfg, urma_jetty_attr_t *attr); +using URMA_DELETE_JETTY = urma_status_t (*)(urma_jetty_t *jetty); +using URMA_IMPORT_JETTY = urma_target_jetty_t *(*)(urma_context_t *ctx, urma_rjetty_t *rjetty, + urma_token_t *token_value); +using URMA_UNIMPORT_JETTY = urma_status_t (*)(urma_target_jetty_t *tjetty); +using URMA_ADVISE_JETTY = urma_status_t (*)(urma_jetty_t *jetty, urma_target_jetty_t *tjetty); +using URMA_UNADVISE_JETTY = urma_status_t (*)(urma_jetty_t *jetty, urma_target_jetty_t *tjetty); +using URMA_BIND_JETTY = urma_status_t (*)(urma_jetty_t *jetty, urma_target_jetty_t *tjetty); +using URMA_UNBIND_JETTY = urma_status_t (*)(urma_jetty_t *jetty); +using URMA_FLUSH_JETTY = int (*)(urma_jetty_t *jetty, int cr_cnt, urma_cr_t *cr); +using URMA_CREATE_JETTY_GRP = urma_jetty_grp_t *(*)(urma_context_t *ctx, urma_jetty_grp_cfg_t *cfg); +using URMA_DELETE_JETTY_GRP = urma_status_t (*)(urma_jetty_grp_t *jetty_grp); +using URMA_CREATE_JFCE = urma_jfce_t *(*)(urma_context_t *ctx); +using URMA_DELETE_JFCE = urma_status_t (*)(urma_jfce_t *jfce); +using URMA_GET_ASYNC_EVENT = urma_status_t (*)(urma_context_t *ctx, urma_async_event_t *event); +using URMA_ACK_ASYNC_EVENT = void (*)(urma_async_event_t *event); +using URMA_ALLOC_TOKEN_ID = urma_token_id_t *(*)(urma_context_t *ctx); +using URMA_FREE_TOKEN_ID = urma_status_t (*)(urma_token_id_t *token_id); +using URMA_REGISTER_SEG = urma_target_seg_t *(*)(urma_context_t *ctx, urma_seg_cfg_t *seg_cfg); +using URMA_UNREGISTER_SEG = urma_status_t (*)(urma_target_seg_t *target_seg); +using URMA_IMPORT_SEG = urma_target_seg_t *(*)(urma_context_t *ctx, urma_seg_t *seg, urma_token_t *token_value, + uint64_t addr, urma_import_seg_flag_t flag); +using URMA_UNIMPORT_SEG = urma_status_t (*)(urma_target_seg_t *tseg); +using URMA_POST_JFS_WR = urma_status_t (*)(urma_jfs_t *jfs, urma_jfs_wr_t *wr, urma_jfs_wr_t **bad_wr); +using URMA_POST_JFR_WR = urma_status_t (*)(urma_jfr_t *jfr, urma_jfr_wr_t *wr, urma_jfr_wr_t **bad_wr); +using URMA_POST_JETTY_SEND_WR = urma_status_t (*)(urma_jetty_t *jetty, urma_jfs_wr_t *wr, urma_jfs_wr_t **bad_wr); +using URMA_POST_JETTY_RECV_WR = urma_status_t (*)(urma_jetty_t *jetty, urma_jfr_wr_t *wr, urma_jfr_wr_t **bad_wr); +using URMA_WRITE = urma_status_t (*)(urma_jfs_t *jfs, urma_target_jetty_t *target_jfr, urma_target_seg_t *dst_tseg, + urma_target_seg_t *src_tseg, uint64_t dst, uint64_t src, uint32_t len, urma_jfs_wr_flag_t flag, uint64_t user_ctx); +using URMA_READ = urma_status_t (*)(urma_jfs_t *jfs, urma_target_jetty_t *target_jfr, urma_target_seg_t *dst_tseg, + urma_target_seg_t *src_tseg, uint64_t dst, uint64_t src, uint32_t len, urma_jfs_wr_flag_t flag, uint64_t user_ctx); +using URMA_SEND = urma_status_t (*)(urma_jfs_t *jfs, urma_target_jetty_t *target_jfr, urma_target_seg_t *src_tseg, + uint64_t src, uint32_t len, urma_jfs_wr_flag_t flag, uint64_t user_ctx); +using URMA_RECV = urma_status_t (*)(urma_jfr_t *jfr, urma_target_seg_t *recv_tseg, uint64_t buf, uint32_t len, + uint64_t user_ctx); +using URMA_POLL_JFC = int (*)(urma_jfc_t *jfc, int cr_cnt, urma_cr_t *cr); +using URMA_REARM_JFC = urma_status_t (*)(urma_jfc_t *jfc, bool solicited_only); +using URMA_WAIT_JFC = int (*)(urma_jfce_t *jfce, uint32_t jfc_cnt, int time_out, urma_jfc_t *jfc[]); +using URMA_ACK_JFC = void (*)(urma_jfc_t *jfc[], uint32_t nevents[], uint32_t jfc_cnt); +using URMA_USER_CTL = urma_status_t (*)(urma_context_t *ctx, urma_user_ctl_in_t *in, urma_user_ctl_out_t *out); +using URMA_REGISTER_LOG_FUNC = urma_status_t (*)(urma_log_cb_t func); +using URMA_UNREGISTER_LOG_FUNC = urma_status_t (*)(void); +using URMA_LOG_GET_LEVEL = urma_vlog_level_t (*)(void); +using URMA_LOG_SET_LEVEL = void (*)(urma_vlog_level_t level); +using URMA_STR_TO_EID = int (*)(const char *buf, urma_eid_t *eid); +using URMA_LOG_SET_THREAD_TAG = void (*)(const char *tag); + +class UrmaAPI { +public: + static URMA_INIT hcomInnerUrmaInit; + static URMA_UNINIT hcomInnerUrmaUninit; + static URMA_GET_DEVICE_LIST hcomInnerUrmaGetDeviceList; + static URMA_FREE_DEVICE_LIST hcomInnerUrmaFreeDeviceList; + static URMA_GET_EID_LIST hcomInnerUrmaGetEidList; + static URMA_FREE_EID_LIST hcomInnerUrmaFreeEidList; + static URMA_GET_DEVICE_BY_NAME hcomInnerUrmaGetDeviceByName; + static URMA_GET_DEVICE_BY_EID hcomInnerUrmaGetDeviceByEid; + static URMA_QUERY_DEVICE hcomInnerUrmaQueryDevice; + static URMA_CREATE_CONTEXT hcomInnerUrmaCreateContext; + static URMA_DELETE_CONTEXT hcomInnerUrmaDeleteContext; + static URMA_CREATE_JFC hcomInnerUrmaCreateJfc; + static URMA_MODIFY_JFC hcomInnerUrmaModifyJfc; + static URMA_DELETE_JFC hcomInnerUrmaDeleteJfc; + static URMA_CREATE_JFS hcomInnerUrmaCreateJfs; + static URMA_MODIFY_JFS hcomInnerUrmaModifyJfs; + static URMA_QUERY_JFS hcomInnerUrmaQueryJfs; + static URMA_DELETE_JFS hcomInnerUrmaDeleteJfs; + static URMA_FLUSH_JFS hcomInnerUrmaFlushJfs; + static URMA_CREATE_JFR hcomInnerUrmaCreateJfr; + static URMA_MODIFY_JFR hcomInnerUrmaModifyJfr; + static URMA_QUERY_JFR hcomInnerUrmaQueryJfr; + static URMA_DELETE_JFR hcomInnerUrmaDeleteJfr; + static URMA_IMPORT_JFR hcomInnerUrmaImportJfr; + static URMA_UNIMPORT_JFR hcomInnerUrmaUnimportJfr; + static URMA_ADVISE_JFR hcomInnerUrmaAdviseJfr; + static URMA_UNADVISE_JFR hcomInnerUrmaUnadviseJfr; + static URMA_CREATE_JETTY hcomInnerUrmaCreateJetty; + static URMA_MODIFY_JETTY hcomInnerUrmaModifyJetty; + static URMA_QUERY_JETTY hcomInnerUrmaQueryJetty; + static URMA_DELETE_JETTY hcomInnerUrmaDeleteJetty; + static URMA_IMPORT_JETTY hcomInnerUrmaImportJetty; + static URMA_UNIMPORT_JETTY hcomInnerUrmaUnimportJetty; + static URMA_ADVISE_JETTY hcomInnerUrmaAdviseJetty; + static URMA_UNADVISE_JETTY hcomInnerUrmaUnadviseJetty; + static URMA_BIND_JETTY hcomInnerUrmaBindJetty; + static URMA_UNBIND_JETTY hcomInnerUrmaUnbindJetty; + static URMA_FLUSH_JETTY hcomInnerUrmaFlushJetty; + static URMA_CREATE_JETTY_GRP hcomInnerUrmaCreateJettyGrp; + static URMA_DELETE_JETTY_GRP hcomInnerUrmaDeleteJettyGrp; + static URMA_CREATE_JFCE hcomInnerUrmaCreateJfce; + static URMA_DELETE_JFCE hcomInnerUrmaDeleteJfce; + static URMA_GET_ASYNC_EVENT hcomInnerUrmaGetAsyncEvent; + static URMA_ACK_ASYNC_EVENT hcomInnerUrmaAckAsyncEvent; + static URMA_ALLOC_TOKEN_ID hcomInnerUrmaAllocTokenId; + static URMA_FREE_TOKEN_ID hcomInnerUrmaFreeTokenId; + static URMA_REGISTER_SEG hcomInnerUrmaRegisterSeg; + static URMA_UNREGISTER_SEG hcomInnerUrmaUnregisterSeg; + static URMA_IMPORT_SEG hcomInnerUrmaImportSeg; + static URMA_UNIMPORT_SEG hcomInnerUrmaUnimportSeg; + static URMA_POST_JFS_WR hcomInnerUrmaPostJfsWr; + static URMA_POST_JFR_WR hcomInnerUrmaPostJfrWr; + static URMA_POST_JETTY_SEND_WR hcomInnerUrmaPostJettySendWr; + static URMA_POST_JETTY_RECV_WR hcomInnerUrmaPostJettyRecvWr; + static URMA_WRITE hcomInnerUrmaWrite; + static URMA_READ hcomInnerUrmaRead; + static URMA_SEND hcomInnerUrmaSend; + static URMA_RECV hcomInnerUrmaRecv; + static URMA_POLL_JFC hcomInnerUrmaPollJfc; + static URMA_REARM_JFC hcomInnerUrmaRearmJfc; + static URMA_WAIT_JFC hcomInnerUrmaWaitJfc; + static URMA_ACK_JFC hcomInnerUrmaAckJfc; + static URMA_USER_CTL hcomInnerUrmaUserCtl; + static URMA_REGISTER_LOG_FUNC hcomInnerUrmaRegisterLogFunc; + static URMA_UNREGISTER_LOG_FUNC hcomInnerUrmaUnregisterLogFunc; + static URMA_LOG_GET_LEVEL hcomInnerUrmaLogGetLevel; + static URMA_LOG_SET_LEVEL hcomInnerUrmaLogSetLevel; + static URMA_STR_TO_EID hcomInnerUrmaStrToEid; + static URMA_LOG_SET_THREAD_TAG hcomInnerUrmaLogSetThreadTag; + + static bool IsLoaded(); + +#if defined(TEST_LLT) && defined(MOCK_VERBS) + static int LoadUrmaAPI(); +#else + static int LoadUrmaAPI(); +#endif + +private: + static bool gLoaded; +}; + +#endif +#endif // HCOM_DYLOADER_IURMA_H diff --git a/src/under_api/urma/urma_api_wrapper.h b/src/under_api/urma/urma_api_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..4d4fb3cf555c0802943103f720f5f58158abaa3a --- /dev/null +++ b/src/under_api/urma/urma_api_wrapper.h @@ -0,0 +1,399 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_URMA_API_WRAPPER_H +#define HCOM_URMA_API_WRAPPER_H +#ifdef UB_BUILD_ENABLED + +#include "urma_api_dl.h" + +namespace ock { +namespace hcom { +class HcomUrma { +public: + static inline urma_status_t Init(urma_init_attr_t *conf) + { + return UrmaAPI::hcomInnerUrmaInit(conf); + } + + static inline urma_status_t Uninit(void) + { + return UrmaAPI::hcomInnerUrmaUninit(); + } + + static inline urma_device_t **GetDeviceList(int *num_devices) + { + return UrmaAPI::hcomInnerUrmaGetDeviceList(num_devices); + } + + static inline void FreeDeviceList(urma_device_t **device_list) + { + return UrmaAPI::hcomInnerUrmaFreeDeviceList(device_list); + } + + static inline urma_eid_info_t *GetEidList(urma_device_t *dev, uint32_t *cnt) + { + return UrmaAPI::hcomInnerUrmaGetEidList(dev, cnt); + } + + static inline void FreeEidList(urma_eid_info_t *eid_list) + { + return UrmaAPI::hcomInnerUrmaFreeEidList(eid_list); + } + + static inline urma_device_t *GetDeviceByName(char *dev_name) + { + return UrmaAPI::hcomInnerUrmaGetDeviceByName(dev_name); + } + + static inline urma_device_t *GetDeviceByEid(urma_eid_t eid, urma_transport_type_t type) + { + return UrmaAPI::hcomInnerUrmaGetDeviceByEid(eid, type); + } + + static inline urma_status_t QueryDevice(urma_device_t *dev, urma_device_attr_t *dev_attr) + { + return UrmaAPI::hcomInnerUrmaQueryDevice(dev, dev_attr); + } + + static inline urma_context_t *CreateContext(urma_device_t *dev, uint32_t eid_index) + { + return UrmaAPI::hcomInnerUrmaCreateContext(dev, eid_index); + } + + static inline urma_status_t DeleteContext(urma_context_t *ctx) + { + return UrmaAPI::hcomInnerUrmaDeleteContext(ctx); + } + + static inline urma_jfc_t *CreateJfc(urma_context_t *ctx, urma_jfc_cfg_t *jfc_cfg) + { + return UrmaAPI::hcomInnerUrmaCreateJfc(ctx, jfc_cfg); + } + + static inline urma_status_t ModifyJfc(urma_jfc_t *jfc, urma_jfc_attr_t *attr) + { + return UrmaAPI::hcomInnerUrmaModifyJfc(jfc, attr); + } + + static inline urma_status_t DeleteJfc(urma_jfc_t *jfc) + { + return UrmaAPI::hcomInnerUrmaDeleteJfc(jfc); + } + + static inline urma_jfs_t *CreateJfs(urma_context_t *ctx, urma_jfs_cfg_t *jfs_cfg) + { + return UrmaAPI::hcomInnerUrmaCreateJfs(ctx, jfs_cfg); + } + + static inline urma_status_t ModifyJfs(urma_jfs_t *jfs, urma_jfs_attr_t *attr) + { + return UrmaAPI::hcomInnerUrmaModifyJfs(jfs, attr); + } + + static inline urma_status_t QueryJfs(urma_jfs_t *jfs, urma_jfs_cfg_t *cfg, urma_jfs_attr_t *attr) + { + return UrmaAPI::hcomInnerUrmaQueryJfs(jfs, cfg, attr); + } + + static inline urma_status_t DeleteJfs(urma_jfs_t *jfs) + { + return UrmaAPI::hcomInnerUrmaDeleteJfs(jfs); + } + + static inline int FlushJfs(urma_jfs_t *jfs, int cr_cnt, urma_cr_t *cr) + { + return UrmaAPI::hcomInnerUrmaFlushJfs(jfs, cr_cnt, cr); + } + + static inline urma_jfr_t *CreateJfr(urma_context_t *ctx, urma_jfr_cfg_t *jfr_cfg) + { + return UrmaAPI::hcomInnerUrmaCreateJfr(ctx, jfr_cfg); + } + + static inline urma_status_t ModifyJfr(urma_jfr_t *jfr, urma_jfr_attr_t *attr) + { + return UrmaAPI::hcomInnerUrmaModifyJfr(jfr, attr); + } + + static inline urma_status_t QueryJfr(urma_jfr_t *jfr, urma_jfr_cfg_t *cfg, urma_jfr_attr_t *attr) + { + return UrmaAPI::hcomInnerUrmaQueryJfr(jfr, cfg, attr); + } + + static inline urma_status_t DeleteJfr(urma_jfr_t *jfr) + { + return UrmaAPI::hcomInnerUrmaDeleteJfr(jfr); + } + + static inline urma_target_jetty_t *ImportJfr(urma_context_t *ctx, urma_rjfr_t *rjfr, urma_token_t *token_value) + { + return UrmaAPI::hcomInnerUrmaImportJfr(ctx, rjfr, token_value); + } + + static inline urma_status_t UnimportJfr(urma_target_jetty_t *target_jfr) + { + return UrmaAPI::hcomInnerUrmaUnimportJfr(target_jfr); + } + + static inline urma_status_t AdviseJfr(urma_jfs_t *jfs, urma_target_jetty_t *tjfr) + { + return UrmaAPI::hcomInnerUrmaAdviseJfr(jfs, tjfr); + } + + static inline urma_status_t UnadviseJfr(urma_jfs_t *jfs, urma_target_jetty_t *tjfr) + { + return UrmaAPI::hcomInnerUrmaUnadviseJfr(jfs, tjfr); + } + + static inline urma_jetty_t *CreateJetty(urma_context_t *ctx, urma_jetty_cfg_t *jetty_cfg) + { + return UrmaAPI::hcomInnerUrmaCreateJetty(ctx, jetty_cfg); + } + + static inline urma_status_t ModifyJetty(urma_jetty_t *jetty, urma_jetty_attr_t *attr) + { + return UrmaAPI::hcomInnerUrmaModifyJetty(jetty, attr); + } + + static inline urma_status_t QueryJetty(urma_jetty_t *jetty, urma_jetty_cfg_t *cfg, urma_jetty_attr_t *attr) + { + return UrmaAPI::hcomInnerUrmaQueryJetty(jetty, cfg, attr); + } + + static inline urma_status_t DeleteJetty(urma_jetty_t *jetty) + { + return UrmaAPI::hcomInnerUrmaDeleteJetty(jetty); + } + + static inline urma_target_jetty_t *ImportJetty(urma_context_t *ctx, urma_rjetty_t *rjetty, + urma_token_t *token_value) + { + return UrmaAPI::hcomInnerUrmaImportJetty(ctx, rjetty, token_value); + } + + static inline urma_status_t UnimportJetty(urma_target_jetty_t *tjetty) + { + return UrmaAPI::hcomInnerUrmaUnimportJetty(tjetty); + } + + static inline urma_status_t AdviseJetty(urma_jetty_t *jetty, urma_target_jetty_t *tjetty) + { + return UrmaAPI::hcomInnerUrmaAdviseJetty(jetty, tjetty); + } + + static inline urma_status_t UnadviseJetty(urma_jetty_t *jetty, urma_target_jetty_t *tjetty) + { + return UrmaAPI::hcomInnerUrmaUnadviseJetty(jetty, tjetty); + } + + static inline urma_status_t BindJetty(urma_jetty_t *jetty, urma_target_jetty_t *tjetty) + { + return UrmaAPI::hcomInnerUrmaBindJetty(jetty, tjetty); + } + + static inline urma_status_t UnbindJetty(urma_jetty_t *jetty) + { + return UrmaAPI::hcomInnerUrmaUnbindJetty(jetty); + } + + static inline int FlushJetty(urma_jetty_t *jetty, int cr_cnt, urma_cr_t *cr) + { + return UrmaAPI::hcomInnerUrmaFlushJetty(jetty, cr_cnt, cr); + } + + static inline urma_jetty_grp_t *CreateJettyGrp(urma_context_t *ctx, urma_jetty_grp_cfg_t *cfg) + { + return UrmaAPI::hcomInnerUrmaCreateJettyGrp(ctx, cfg); + } + + static inline urma_status_t DeleteJettyGrp(urma_jetty_grp_t *jetty_grp) + { + return UrmaAPI::hcomInnerUrmaDeleteJettyGrp(jetty_grp); + } + + static inline urma_jfce_t *CreateJfce(urma_context_t *ctx) + { + return UrmaAPI::hcomInnerUrmaCreateJfce(ctx); + } + + static inline urma_status_t DeleteJfce(urma_jfce_t *jfce) + { + return UrmaAPI::hcomInnerUrmaDeleteJfce(jfce); + } + + static inline urma_status_t GetAsyncEvent(urma_context_t *ctx, urma_async_event_t *event) + { + return UrmaAPI::hcomInnerUrmaGetAsyncEvent(ctx, event); + } + + static inline void AckAsyncEvent(urma_async_event_t *event) + { + return UrmaAPI::hcomInnerUrmaAckAsyncEvent(event); + } + + static inline urma_token_id_t *AllocTokenId(urma_context_t *ctx) + { + return UrmaAPI::hcomInnerUrmaAllocTokenId(ctx); + } + + static inline urma_status_t FreeTokenId(urma_token_id_t *token_id) + { + return UrmaAPI::hcomInnerUrmaFreeTokenId(token_id); + } + + static inline urma_target_seg_t *RegisterSeg(urma_context_t *ctx, urma_seg_cfg_t *seg_cfg) + { + return UrmaAPI::hcomInnerUrmaRegisterSeg(ctx, seg_cfg); + } + + static inline urma_status_t UnregisterSeg(urma_target_seg_t *target_seg) + { + return UrmaAPI::hcomInnerUrmaUnregisterSeg(target_seg); + } + + static inline urma_target_seg_t *ImportSeg(urma_context_t *ctx, urma_seg_t *seg, urma_token_t *token_value, + uint64_t addr, urma_import_seg_flag_t flag) + { + return UrmaAPI::hcomInnerUrmaImportSeg(ctx, seg, token_value, addr, flag); + } + + static inline urma_status_t UnimportSeg(urma_target_seg_t *tseg) + { + return UrmaAPI::hcomInnerUrmaUnimportSeg(tseg); + } + + static inline urma_status_t PostJfsWr(urma_jfs_t *jfs, urma_jfs_wr_t *wr, urma_jfs_wr_t **bad_wr) + { + return UrmaAPI::hcomInnerUrmaPostJfsWr(jfs, wr, bad_wr); + } + + static inline urma_status_t PostJfrWr(urma_jfr_t *jfr, urma_jfr_wr_t *wr, urma_jfr_wr_t **bad_wr) + { + return UrmaAPI::hcomInnerUrmaPostJfrWr(jfr, wr, bad_wr); + } + + static inline urma_status_t PostJettySendWr(urma_jetty_t *jetty, urma_jfs_wr_t *wr, urma_jfs_wr_t **bad_wr) + { + return UrmaAPI::hcomInnerUrmaPostJettySendWr(jetty, wr, bad_wr); + } + + static inline urma_status_t PostJettySendWr( + urma_jetty_t *jetty, urma_jfs_wr_t *wr, uint32_t wrCnt, urma_jfs_wr_t **bad_wr) + { + return UrmaAPI::hcomInnerUrmaPostJettySendWr(jetty, wr, bad_wr); + } + + static inline urma_status_t PostJettyRecvWr(urma_jetty_t *jetty, urma_jfr_wr_t *wr, urma_jfr_wr_t **bad_wr) + { + return UrmaAPI::hcomInnerUrmaPostJettyRecvWr(jetty, wr, bad_wr); + } + + static inline urma_status_t Write(urma_jfs_t *jfs, urma_target_jetty_t *target_jfr, urma_target_seg_t *dst_tseg, + urma_target_seg_t *src_tseg, uint64_t dst, uint64_t src, uint32_t len, urma_jfs_wr_flag_t flag, + uint64_t user_ctx) + { + return UrmaAPI::hcomInnerUrmaWrite(jfs, target_jfr, dst_tseg, src_tseg, dst, src, len, flag, user_ctx); + } + + static inline urma_status_t Read(urma_jfs_t *jfs, urma_target_jetty_t *target_jfr, urma_target_seg_t *dst_tseg, + urma_target_seg_t *src_tseg, uint64_t dst, uint64_t src, uint32_t len, urma_jfs_wr_flag_t flag, + uint64_t user_ctx) + { + return UrmaAPI::hcomInnerUrmaRead(jfs, target_jfr, dst_tseg, src_tseg, dst, src, len, flag, user_ctx); + } + + static inline urma_status_t Send(urma_jfs_t *jfs, urma_target_jetty_t *target_jfr, urma_target_seg_t *src_tseg, + uint64_t src, uint32_t len, urma_jfs_wr_flag_t flag, uint64_t user_ctx) + { + return UrmaAPI::hcomInnerUrmaSend(jfs, target_jfr, src_tseg, src, len, flag, user_ctx); + } + + static inline urma_status_t Recv(urma_jfr_t *jfr, urma_target_seg_t *recv_tseg, uint64_t buf, uint32_t len, + uint64_t user_ctx) + { + return UrmaAPI::hcomInnerUrmaRecv(jfr, recv_tseg, buf, len, user_ctx); + } + + static inline int PollJfc(urma_jfc_t *jfc, int cr_cnt, urma_cr_t *cr) + { + return UrmaAPI::hcomInnerUrmaPollJfc(jfc, cr_cnt, cr); + } + + static inline urma_status_t RearmJfc(urma_jfc_t *jfc, bool solicited_only) + { + return UrmaAPI::hcomInnerUrmaRearmJfc(jfc, solicited_only); + } + + static inline int WaitJfc(urma_jfce_t *jfce, uint32_t jfc_cnt, int time_out, urma_jfc_t *jfc[]) + { + return UrmaAPI::hcomInnerUrmaWaitJfc(jfce, jfc_cnt, time_out, jfc); + } + + static inline void AckJfc(urma_jfc_t *jfc[], uint32_t nevents[], uint32_t jfc_cnt) + { + return UrmaAPI::hcomInnerUrmaAckJfc(jfc, nevents, jfc_cnt); + } + + static inline urma_status_t UserCtl(urma_context_t *ctx, urma_user_ctl_in_t *in, urma_user_ctl_out_t *out) + { + return UrmaAPI::hcomInnerUrmaUserCtl(ctx, in, out); + } + + static inline urma_status_t RegisterLogFunc(urma_log_cb_t func) + { + return UrmaAPI::hcomInnerUrmaRegisterLogFunc(func); + } + + static inline urma_status_t UnregisterLogFunc(void) + { + return UrmaAPI::hcomInnerUrmaUnregisterLogFunc(); + } + + static inline urma_vlog_level_t LogGetLevel(void) + { + return UrmaAPI::hcomInnerUrmaLogGetLevel(); + } + + static inline void LogSetLevel(urma_vlog_level_t level) + { + return UrmaAPI::hcomInnerUrmaLogSetLevel(level); + } + + static inline int StrToEid(const char *buf, urma_eid_t *eid) + { + return UrmaAPI::hcomInnerUrmaStrToEid(buf, eid); + } + + static inline void LogSetThreadTag(const char *tag) + { + return UrmaAPI::hcomInnerUrmaLogSetThreadTag(tag); + } + + static inline bool IsLoaded() + { + return UrmaAPI::IsLoaded(); + } + + static inline int Load() + { +#if !defined(TEST_LLT) || !defined(MOCK_VERBS) + return UrmaAPI::LoadUrmaAPI(); +#else + return UrmaAPI::LoadUrmaAPI(); +#endif + } +}; +} +} + +#endif +#endif // HCOM_URMA_API_WRAPPER_H \ No newline at end of file diff --git a/src/under_api/verbs/verbs_api_dl.cpp b/src/under_api/verbs/verbs_api_dl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12a0addb53057e397855845b96d4c97773b36876 --- /dev/null +++ b/src/under_api/verbs/verbs_api_dl.cpp @@ -0,0 +1,147 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#include + +#if defined(TEST_LLT) && defined(MOCK_VERBS) +#include "fake_ibv.h" +#endif +#include "hcom_log.h" +#include "verbs_api_dl.h" +#include "../../common/net_common.h" +using namespace ock::hcom; + +// ibv APIs +IBV_GET_DEVICE_LIST VerbsAPI::hcomInnerIbvGetDevList = nullptr; +IBV_FORK_INIT VerbsAPI::hcomInnerIbvForkInit = nullptr; +IBV_QUERY_PORT VerbsAPI::hcomInnerIbvQueryPort = nullptr; +IBV_OPEN_DEVICE VerbsAPI::hcomInnerIbvOpenDevice = nullptr; +IBV_ALLOC_PD VerbsAPI::hcomInnerIbvAllocPD = nullptr; +IBV_FREE_DEVICE_LIST VerbsAPI::hcomInnerIbvFreeDevList = nullptr; +IBV_CREATE_COMP_CHANNEL VerbsAPI::hcomInnerIbvCreateCompChannel = nullptr; +IBV_GET_CQ_EVENT VerbsAPI::hcomInnerIbvGetCQEvent = nullptr; +IBV_GET_ASYNC_EVENT VerbsAPI::hcomInnerIbvGetAsyncEvent = nullptr; +IBV_ACK_ASYNC_EVENT VerbsAPI::hcomInnerIbvAckAsyncEvent = nullptr; +IBV_CREATE_QP VerbsAPI::hcomInnerIbvCreateQP = nullptr; +IBV_CLOSE_DEVICE VerbsAPI::hcomInnerIbvCloseDev = nullptr; +IBV_DEALLOC_PD VerbsAPI::hcomInnerIbvDeallocPD = nullptr; +IBV_CREATE_CQ VerbsAPI::hcomInnerCreateCQ = nullptr; +IBV_DESTROY_COMP_CHANNEL VerbsAPI::hcomInnerDestroyCompChannel = nullptr; +IBV_DESTROY_CQ VerbsAPI::hcomInnerDestroyCQ = nullptr; +IBV_ACK_CQ_EVENTS VerbsAPI::hcomInnerAckCQ = nullptr; +IBV_DESTROY_QP VerbsAPI::hcomInnerDestroyQP = nullptr; +IBV_MODIFY_QP VerbsAPI::hcomInnerModityQP = nullptr; +IBV_DEREG_MR VerbsAPI::hcomInnerDeregMr = nullptr; +IBV_QUERY_GID VerbsAPI::hcomInnerQueryGid = nullptr; +IBV_QUERY_DEVICE VerbsAPI::hcomInnerQueryDevice = nullptr; +IBV_PORT_STATE_STR VerbsAPI::hcomInnerPortStateStr = nullptr; +IBV_REG_MR_IOVA2 VerbsAPI::hcomInnerRegMrIOVA2 = nullptr; +IBV_REG_MR VerbsAPI::hcomInnerRegMr = nullptr; + +bool VerbsAPI::gLoaded = false; + +#if !defined(TEST_LLT) || !defined(MOCK_VERBS) +#define DLSYM(type, ptr, sym) \ + do { \ + auto ptr1 = dlsym(handle, sym); \ + if (ptr1 == nullptr) { \ + NN_LOG_ERROR("Failed to load function " << sym << ", error " << dlerror()); \ + dlclose(handle); \ + return -1; \ + } \ + ptr = (type)ptr1; \ + } while (0) + +int VerbsAPI::LoadVerbsAPI() +{ + if (gLoaded) { + return 0; + } + + auto handle = dlopen(IVERBS_SO_PATH, RTLD_NOW); + if (handle == nullptr) { + NN_LOG_ERROR("Failed to load verbs so " << IVERBS_SO_PATH << ", error " << dlerror()); + return -1; + } + + DLSYM(IBV_GET_DEVICE_LIST, VerbsAPI::hcomInnerIbvGetDevList, "ibv_get_device_list"); + DLSYM(IBV_FORK_INIT, VerbsAPI::hcomInnerIbvForkInit, "ibv_fork_init"); + DLSYM(IBV_QUERY_PORT, VerbsAPI::hcomInnerIbvQueryPort, "ibv_query_port"); + DLSYM(IBV_OPEN_DEVICE, VerbsAPI::hcomInnerIbvOpenDevice, "ibv_open_device"); + DLSYM(IBV_ALLOC_PD, VerbsAPI::hcomInnerIbvAllocPD, "ibv_alloc_pd"); + DLSYM(IBV_FREE_DEVICE_LIST, VerbsAPI::hcomInnerIbvFreeDevList, "ibv_free_device_list"); + DLSYM(IBV_CREATE_COMP_CHANNEL, VerbsAPI::hcomInnerIbvCreateCompChannel, "ibv_create_comp_channel"); + DLSYM(IBV_GET_CQ_EVENT, VerbsAPI::hcomInnerIbvGetCQEvent, "ibv_get_cq_event"); + DLSYM(IBV_GET_ASYNC_EVENT, VerbsAPI::hcomInnerIbvGetAsyncEvent, "ibv_get_async_event"); + DLSYM(IBV_ACK_ASYNC_EVENT, VerbsAPI::hcomInnerIbvAckAsyncEvent, "ibv_ack_async_event"); + DLSYM(IBV_CREATE_QP, VerbsAPI::hcomInnerIbvCreateQP, "ibv_create_qp"); + DLSYM(IBV_CLOSE_DEVICE, VerbsAPI::hcomInnerIbvCloseDev, "ibv_close_device"); + DLSYM(IBV_DEALLOC_PD, VerbsAPI::hcomInnerIbvDeallocPD, "ibv_dealloc_pd"); + DLSYM(IBV_CREATE_CQ, VerbsAPI::hcomInnerCreateCQ, "ibv_create_cq"); + DLSYM(IBV_DESTROY_COMP_CHANNEL, VerbsAPI::hcomInnerDestroyCompChannel, "ibv_destroy_comp_channel"); + DLSYM(IBV_DESTROY_CQ, VerbsAPI::hcomInnerDestroyCQ, "ibv_destroy_cq"); + DLSYM(IBV_ACK_CQ_EVENTS, VerbsAPI::hcomInnerAckCQ, "ibv_ack_cq_events"); + DLSYM(IBV_DESTROY_QP, VerbsAPI::hcomInnerDestroyQP, "ibv_destroy_qp"); + DLSYM(IBV_MODIFY_QP, VerbsAPI::hcomInnerModityQP, "ibv_modify_qp"); + DLSYM(IBV_DEREG_MR, VerbsAPI::hcomInnerDeregMr, "ibv_dereg_mr"); + DLSYM(IBV_QUERY_GID, VerbsAPI::hcomInnerQueryGid, "ibv_query_gid"); + DLSYM(IBV_QUERY_DEVICE, VerbsAPI::hcomInnerQueryDevice, "ibv_query_device"); + DLSYM(IBV_REG_MR_IOVA2, VerbsAPI::hcomInnerRegMrIOVA2, "ibv_reg_mr_iova2"); + DLSYM(IBV_REG_MR, VerbsAPI::hcomInnerRegMr, "ibv_reg_mr"); + DLSYM(IBV_PORT_STATE_STR, VerbsAPI::hcomInnerPortStateStr, "ibv_port_state_str"); + + NN_LOG_INFO("Success to load ibverbs"); + gLoaded = true; + + return 0; +} +#else +int VerbsAPI::LoadFakeVerbsAPI() +{ + if (gLoaded) { + return 0; + } + + VerbsAPI::hcomInnerIbvGetDevList = ibv_get_device_list; + VerbsAPI::hcomInnerIbvForkInit = ibv_fork_init; + VerbsAPI::hcomInnerIbvQueryPort = fake_ibv_query_port; + VerbsAPI::hcomInnerIbvOpenDevice = ibv_open_device; + VerbsAPI::hcomInnerQueryDevice = ibv_query_device; + VerbsAPI::hcomInnerIbvAllocPD = ibv_alloc_pd; + VerbsAPI::hcomInnerIbvFreeDevList = ibv_free_device_list; + VerbsAPI::hcomInnerIbvCreateCompChannel = ibv_create_comp_channel; + VerbsAPI::hcomInnerIbvGetCQEvent = ibv_get_cq_event; + VerbsAPI::hcomInnerIbvGetAsyncEvent = ibv_get_async_event; + VerbsAPI::hcomInnerIbvAckAsyncEvent = ibv_ack_async_event; + VerbsAPI::hcomInnerIbvCreateQP = ibv_create_qp; + VerbsAPI::hcomInnerIbvCloseDev = ibv_close_device; + VerbsAPI::hcomInnerIbvDeallocPD = ibv_dealloc_pd; + VerbsAPI::hcomInnerCreateCQ = ibv_create_cq; + VerbsAPI::hcomInnerDestroyCompChannel = ibv_destroy_comp_channel; + VerbsAPI::hcomInnerDestroyCQ = ibv_destroy_cq; + VerbsAPI::hcomInnerAckCQ = ibv_ack_cq_events; + VerbsAPI::hcomInnerDestroyQP = ibv_destroy_qp; + VerbsAPI::hcomInnerModityQP = ibv_modify_qp; + VerbsAPI::hcomInnerDeregMr = ibv_dereg_mr; + VerbsAPI::hcomInnerQueryGid = ibv_query_gid; + VerbsAPI::hcomInnerRegMrIOVA2 = ibv_reg_mr_iova2; + VerbsAPI::hcomInnerRegMr = ibv_reg_mr; + VerbsAPI::hcomInnerPortStateStr = ibv_port_state_str; + + NN_LOG_INFO("Success to load fake ibverbs"); + gLoaded = true; + + return 0; +} +#endif + +#endif \ No newline at end of file diff --git a/src/under_api/verbs/verbs_api_dl.h b/src/under_api/verbs/verbs_api_dl.h new file mode 100644 index 0000000000000000000000000000000000000000..1de37fc9581eba9f66fd124bd33763b9f415ff51 --- /dev/null +++ b/src/under_api/verbs/verbs_api_dl.h @@ -0,0 +1,138 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_DYLOADER_IVERBS_H +#define HCOM_DYLOADER_IVERBS_H +#ifdef RDMA_BUILD_ENABLED + +#include +#include +#include +#include +#include +#include +#include + +#include "hcom_def.h" + +#define IVERBS_SO_PATH "libibverbs.so" + +using IBV_GET_DEVICE_LIST = struct ibv_device **(*)(int *num_devices); +using IBV_FORK_INIT = int (*)(); +using IBV_OPEN_DEVICE = struct ibv_context *(*)(struct ibv_device *device); +using IBV_ALLOC_PD = struct ibv_pd *(*)(struct ibv_context *context); +using IBV_QUERY_PORT = int (*)(struct ibv_context *context, uint8_t port_num, struct _compat_ibv_port_attr *port_attr); +using IBV_FREE_DEVICE_LIST = void (*)(struct ibv_device **list); +using IBV_REG_MR = struct ibv_mr *(*)(struct ibv_pd *pd, void *addr, size_t length, int access); +using IBV_CREATE_COMP_CHANNEL = struct ibv_comp_channel *(*)(struct ibv_context *context); +using IBV_GET_CQ_EVENT = int (*)(struct ibv_comp_channel *channel, struct ibv_cq **cq, void **cq_context); +using IBV_GET_ASYNC_EVENT = int (*)(struct ibv_context *context, struct ibv_async_event *event); +using IBV_ACK_ASYNC_EVENT = void (*)(struct ibv_async_event *event); +using IBV_CREATE_QP = struct ibv_qp *(*)(struct ibv_pd *pd, struct ibv_qp_init_attr *qp_init_attr); +using IBV_CLOSE_DEVICE = int (*)(struct ibv_context *context); +using IBV_DEALLOC_PD = int (*)(struct ibv_pd *pd); +using IBV_CREATE_CQ = struct ibv_cq *(*)(struct ibv_context *context, int cqe, void *cq_context, + struct ibv_comp_channel *channel, int comp_vector); +using IBV_DESTROY_COMP_CHANNEL = int (*)(struct ibv_comp_channel *channel); +using IBV_DESTROY_CQ = int (*)(struct ibv_cq *cq); +using IBV_ACK_CQ_EVENTS = void (*)(struct ibv_cq *cq, unsigned int nevents); +using IBV_DESTROY_QP = int (*)(struct ibv_qp *qp); +using IBV_MODIFY_QP = int (*)(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask); +using IBV_DEREG_MR = int (*)(struct ibv_mr *mr); +using IBV_QUERY_GID = int (*)(struct ibv_context *context, uint8_t port_num, int index, union ibv_gid *gid); +using IBV_QUERY_DEVICE = int (*)(struct ibv_context *context, struct ibv_device_attr *device_attr); +using IBV_REG_MR_IOVA2 = struct ibv_mr *(*)(struct ibv_pd *pd, void *addr, size_t length, uint64_t iova, + unsigned int access); +using IBV_PORT_STATE_STR = const char *(*)(enum ibv_port_state port_state); + +class VerbsAPI { +public: + static IBV_GET_DEVICE_LIST hcomInnerIbvGetDevList; + static IBV_FORK_INIT hcomInnerIbvForkInit; + static IBV_QUERY_PORT hcomInnerIbvQueryPort; + static IBV_OPEN_DEVICE hcomInnerIbvOpenDevice; + static IBV_ALLOC_PD hcomInnerIbvAllocPD; + static IBV_FREE_DEVICE_LIST hcomInnerIbvFreeDevList; + static IBV_CREATE_COMP_CHANNEL hcomInnerIbvCreateCompChannel; + static IBV_GET_CQ_EVENT hcomInnerIbvGetCQEvent; + static IBV_GET_ASYNC_EVENT hcomInnerIbvGetAsyncEvent; + static IBV_ACK_ASYNC_EVENT hcomInnerIbvAckAsyncEvent; + static IBV_CREATE_QP hcomInnerIbvCreateQP; + static IBV_CLOSE_DEVICE hcomInnerIbvCloseDev; + static IBV_DEALLOC_PD hcomInnerIbvDeallocPD; + static IBV_CREATE_CQ hcomInnerCreateCQ; + static IBV_DESTROY_COMP_CHANNEL hcomInnerDestroyCompChannel; + static IBV_DESTROY_CQ hcomInnerDestroyCQ; + static IBV_ACK_CQ_EVENTS hcomInnerAckCQ; + static IBV_DESTROY_QP hcomInnerDestroyQP; + static IBV_MODIFY_QP hcomInnerModityQP; + static IBV_DEREG_MR hcomInnerDeregMr; + static IBV_QUERY_GID hcomInnerQueryGid; + static IBV_QUERY_DEVICE hcomInnerQueryDevice; + static IBV_REG_MR_IOVA2 hcomInnerRegMrIOVA2; + static IBV_REG_MR hcomInnerRegMr; + static IBV_PORT_STATE_STR hcomInnerPortStateStr; + +#if defined(TEST_LLT) && defined(MOCK_VERBS) + static int LoadFakeVerbsAPI(); +#else + static int LoadVerbsAPI(); +#endif + +private: + static bool gLoaded; +}; + +#ifndef IBV_ACCESS_OPTIONAL_RANGE +#define IBV_ACCESS_OPTIONAL_RANGE 0 +#endif + +#define HCOM_IBV_REG_MR(pd, addr, length, access, is_access_const) \ + ({ \ + struct ibv_mr *ret; \ + auto noIova2 = VerbsAPI::hcomInnerRegMrIOVA2 == nullptr; \ + if (((is_access_const) && ((access)&IBV_ACCESS_OPTIONAL_RANGE) == 0) || noIova2) { \ + ret = VerbsAPI::hcomInnerRegMr((pd), (addr), (length), (access)); \ + } else { \ + ret = VerbsAPI::hcomInnerRegMrIOVA2((pd), (addr), (length), (uintptr_t)(addr), (access)); \ + } \ + ret; \ + }) + +#define HCOM_IBV_INNER_REG_MR(pd, addr, length, access) \ + HCOM_IBV_REG_MR(pd, addr, length, access, __builtin_constant_p(((access)&IBV_ACCESS_OPTIONAL_RANGE) == 0)) + +#ifndef verbs_get_ctx_op +#define HCOM_IBV_INNER_QUERY_PORT(context, port_num, port_attr) \ + ({ \ + int rc; \ + bzero((port_attr), sizeof(*(port_attr))); \ + rc = VerbsAPI::hcomInnerIbvQueryPort(context, port_num, \ + reinterpret_cast(port_attr)); \ + rc; \ + }) +#else +#define HCOM_IBV_INNER_QUERY_PORT(context, port_num, port_attr) \ + ({ \ + struct verbs_context *vctx = verbs_get_ctx_op(context, query_port); \ + int rc; \ + if (!vctx) { \ + bzero((port_attr), sizeof(*(port_attr))); \ + rc = VerbsAPI::hcomInnerIbvQueryPort(context, port_num, \ + reinterpret_cast(port_attr)); \ + } else { \ + rc = vctx->query_port(context, port_num, port_attr, sizeof(*(port_attr))); \ + } \ + rc; \ + }) +#endif +#endif +#endif // HCOM_DYLOADER_IVERBS_H \ No newline at end of file diff --git a/src/under_api/verbs/verbs_api_wrapper.h b/src/under_api/verbs/verbs_api_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..2de09e5b9d9b3e2aa73e1877324177ee8ca7b5b6 --- /dev/null +++ b/src/under_api/verbs/verbs_api_wrapper.h @@ -0,0 +1,154 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_VERBS_API_WRAPPER_H +#define HCOM_VERBS_API_WRAPPER_H +#ifdef RDMA_BUILD_ENABLED + +#include "verbs_api_dl.h" + +namespace ock { +namespace hcom { +class HcomIbv { +public: + static inline struct ibv_device **GetDevList(int *num_devices) + { + return VerbsAPI::hcomInnerIbvGetDevList(num_devices); + } + + static inline int ForkInit() + { + return VerbsAPI::hcomInnerIbvForkInit(); + } + + static inline struct ibv_context *OpenDevice(struct ibv_device *device) + { + return VerbsAPI::hcomInnerIbvOpenDevice(device); + } + + static inline struct ibv_pd *AllocPd(struct ibv_context *context) + { + return VerbsAPI::hcomInnerIbvAllocPD(context); + } + + static inline int QueryPort(struct ibv_context *context, uint8_t port_num, struct ibv_port_attr *port_attr) + { + return HCOM_IBV_INNER_QUERY_PORT(context, port_num, port_attr); + } + + static inline void FreeDevList(struct ibv_device **list) + { + VerbsAPI::hcomInnerIbvFreeDevList(list); + } + + static inline struct ibv_mr *RegMr(struct ibv_pd *pd, void *addr, size_t length, unsigned int access) + { + return HCOM_IBV_INNER_REG_MR(pd, addr, length, access); + } + + static inline struct ibv_comp_channel *CreateCompChannel(struct ibv_context *context) + { + return VerbsAPI::hcomInnerIbvCreateCompChannel(context); + } + + static inline int GetCqEvent(struct ibv_comp_channel *channel, struct ibv_cq **cq, void **cq_context) + { + return VerbsAPI::hcomInnerIbvGetCQEvent(channel, cq, cq_context); + } + + static inline int GetAsyncEvent(struct ibv_context *context, struct ibv_async_event *event) + { + return VerbsAPI::hcomInnerIbvGetAsyncEvent(context, event); + } + + static inline void AckAsyncEvent(struct ibv_async_event *event) + { + VerbsAPI::hcomInnerIbvAckAsyncEvent(event); + } + + static inline struct ibv_qp *CreateQp(struct ibv_pd *pd, struct ibv_qp_init_attr *qp_init_attr) + { + return VerbsAPI::hcomInnerIbvCreateQP(pd, qp_init_attr); + } + + static inline int CloseDev(struct ibv_context *context) + { + return VerbsAPI::hcomInnerIbvCloseDev(context); + } + + static inline int DeallocPd(struct ibv_pd *pd) + { + return VerbsAPI::hcomInnerIbvDeallocPD(pd); + } + + static inline struct ibv_cq *CreateCq(struct ibv_context *context, int cqe, void *cq_context, + struct ibv_comp_channel *channel, int comp_vector) + { + return VerbsAPI::hcomInnerCreateCQ(context, cqe, cq_context, channel, comp_vector); + } + + static inline int DestroyCompChannel(struct ibv_comp_channel *channel) + { + return VerbsAPI::hcomInnerDestroyCompChannel(channel); + } + static inline int DestroyCq(struct ibv_cq *cq) + { + return VerbsAPI::hcomInnerDestroyCQ(cq); + } + + static inline void AckCqEvents(struct ibv_cq *cq, unsigned int nevents) + { + VerbsAPI::hcomInnerAckCQ(cq, nevents); + } + static inline int DestroyQp(struct ibv_qp *qp) + { + return VerbsAPI::hcomInnerDestroyQP(qp); + } + + static inline int ModifyQp(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask) + { + return VerbsAPI::hcomInnerModityQP(qp, attr, attr_mask); + } + + static inline int DeregMr(struct ibv_mr *mr) + { + return VerbsAPI::hcomInnerDeregMr(mr); + } + + static inline int QueryGid(struct ibv_context *context, uint8_t port_num, int index, union ibv_gid *gid) + { + return VerbsAPI::hcomInnerQueryGid(context, port_num, index, gid); + } + + static inline int QueryDevice(struct ibv_context *context, struct ibv_device_attr *device_attr) + { + return VerbsAPI::hcomInnerQueryDevice(context, device_attr); + } + + static inline const char *PortStateStr(enum ibv_port_state port_state) + { + return VerbsAPI::hcomInnerPortStateStr(port_state); + } + + static inline int Load() + { +#if !defined(TEST_LLT) || !defined(MOCK_VERBS) + return VerbsAPI::LoadVerbsAPI(); +#else + return VerbsAPI::LoadFakeVerbsAPI(); +#endif + } +}; +} +} + +#endif +#endif // HCOM_VERBS_API_WRAPPER_H \ No newline at end of file diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7171c68cd668061b2a617a7779750eaeeaf43727 --- /dev/null +++ b/test/CMakeLists.txt @@ -0,0 +1,18 @@ +project(hcom_tests_integration) + +add_compile_options(-fno-access-control) +add_compile_options(-fno-inline) + +set(HCOM_3RDPARTY_INSTALL_DIR "${HCOM_OUTPUT_PATH}/hcom_3rdparty") +set(LIB_HTRACER_DIR "${HCOM_3RDPARTY_INSTALL_DIR}/hcom_tracer/lib") +set(LIB_HTRACER_STATIC "${HCOM_3RDPARTY_INSTALL_DIR}/hcom_tracer/lib/libhtracer_static.a") + +# prepare libhtracer_static for ut compile +EXEC_PROGRAM(ar ${LIB_HTRACER_DIR} ARGS x ${LIB_HTRACER_STATIC}) +EXEC_PROGRAM(ar ${LIB_HTRACER_DIR} ARGS x ${LIB_SECURE_C_3RDPARTH_STATIC}) +EXEC_PROGRAM(ar ${LIB_HTRACER_DIR} ARGS cur ${LIB_HTRACER_STATIC} ${LIB_HTRACER_DIR}/*.o) +EXEC_PROGRAM(rm ${LIB_HTRACER_DIR} ARGS ${LIB_HTRACER_DIR}/*.o) + +add_subdirectory(stub) +add_subdirectory(llt) +add_subdirectory(unit_test) diff --git a/test/deploy/deploy_physical.sh b/test/deploy/deploy_physical.sh new file mode 100644 index 0000000000000000000000000000000000000000..b0d94e0d80bddaefb0c209f834b3846a1d136bc9 --- /dev/null +++ b/test/deploy/deploy_physical.sh @@ -0,0 +1,26 @@ +#!/bin/bash +CUR_PATH=$(cd $(dirname $0);pwd) +# 判断当前用户是否已经申请仿真环境,如果没有申请,执行部署脚本直接报错 +if [ ! -e "$HOST_CFG" ]; then + echo "Error:File $HOST_CFG is not exist, please apply for a test environment first." + exit 1 +fi + +# 获取申请的环境节点的IP和端口,如果环境为多节点,业务需要自行适配 +NODES=$(cat ${HOST_CFG} | awk '{print $2, $3}') + +for NODE in ${NODES}; do + NODE_IP=$(echo ${NODE} | awk '{print $1}') + NODE_PORT=$(echo ${NODE} | awk '{print $2}') + + # 删除目标节点的旧安装包,并上传新安装包 + ssh root@${NODE_IP} -p ${NODE_PORT} rm -rf /opt/install/package/OCK-CommunicationSuite* + scp -P ${NODE_PORT} ${WORKSPACE}/BeiMing_output/OCK-CommunicationSuite_2.0.0_aarch64.tar.gz root@${NODE_IP}:/opt/install/package/ + + # 在目标节点上解压和安装 + ssh root@${NODE_IP} -p ${NODE_PORT} << EOF + cd /opt/install/package/; + tar -zxvf OCK-CommunicationSuite_2.0.0_aarch64.tar.gz; + rpm -ivh --nodeps OCK-CommunicationSuite_HCOM*.rpm --force; +EOF +done \ No newline at end of file diff --git a/test/external_libs/mockcpp_support_arm64.patch b/test/external_libs/mockcpp_support_arm64.patch new file mode 100644 index 0000000000000000000000000000000000000000..5f42f4a4fa26aeebe0a59a034439c76605fb0344 --- /dev/null +++ b/test/external_libs/mockcpp_support_arm64.patch @@ -0,0 +1,877 @@ +diff --git a/include/mockcpp/ApiHookFunctor.h b/include/mockcpp/ApiHookFunctor.h +index a860a95..6a77488 100644 +--- a/include/mockcpp/ApiHookFunctor.h ++++ b/include/mockcpp/ApiHookFunctor.h +@@ -20,6 +20,7 @@ + + #include + #include ++#include + + MOCKCPP_NS_START + +@@ -30,20 +31,21 @@ struct ApiHookFunctor + }; + + const std::string empty_caller(""); +-#define __MOCKCPP_API_HOOK_FUNCTOR_DEF(n, CallingConvention) \ +-template \ +-struct ApiHookFunctor \ ++ ++#define __MOCKCPP_C_API_HOOK_FUNCTOR_DEF(n, CallingConvention) \ ++template \ ++struct ApiHookFunctor \ + { \ + private: \ +- typedef R CallingConvention F (DECL_ARGS(n)); \ ++ typedef R (CallingConvention *F) (DECL_ARGS(n)); \ + \ + static R CallingConvention hook(DECL_PARAMS_LIST(n)) \ + { \ +- return GlobalMockObject::instance.invoke(apiAddress) \ +- (empty_caller DECL_REST_PARAMS(n)); \ ++ return GlobalMockObject::instance.invoke(apiAddress) \ ++ (empty_caller, RefAny() DECL_REST_PARAMS(n)); \ + } \ + \ +- static bool appliedBy(F* api) \ ++ static bool appliedBy(F api) \ + { return apiAddress == reinterpret_cast(api); } \ + \ + static void* getHook() \ +@@ -53,51 +55,112 @@ private: \ + { if(--refCount == 0) apiAddress = 0; } \ + public: \ + \ +- static void* getApiHook(F* api) \ ++ static void* getApiHook(F api) \ ++ { \ ++ if(!appliedBy(api)) return 0; \ ++ ++refCount; \ ++ return getHook(); \ ++ } \ ++ \ ++ static void* applyApiHook(F api) \ ++ { \ ++ if(apiAddress != 0) return 0; \ ++ apiAddress = reinterpret_cast(api); \ ++ refCount = 1; \ ++ return getHook(); \ ++ } \ ++ \ ++ static bool freeApiHook(void* hook) \ ++ { \ ++ if(getHook() != hook) return false; \ ++ freeHook(); \ ++ return true; \ ++ } \ ++private: \ ++ static void* apiAddress; \ ++ static unsigned int refCount; \ ++}; \ ++template \ ++void* ApiHookFunctor::apiAddress = 0; \ ++template \ ++unsigned int ApiHookFunctor::refCount = 0 ++ ++/* For C++ method */ ++#define __MOCKCPP_CXX_API_HOOK_FUNCTOR_DEF(n, CallingConvention, ConstConvention) \ ++template \ ++struct ApiHookFunctor \ ++{ \ ++private: \ ++ typedef ApiHookFunctor ThisType; \ ++ typedef R (CallingConvention C::*F) (DECL_ARGS(n)) ConstConvention; \ ++ \ ++ R CallingConvention hook(DECL_PARAMS_LIST(n)) \ ++ { \ ++ C *This = reinterpret_cast(this); \ ++ return GlobalMockObject::instance.invoke(apiAddress) \ ++ (empty_caller, This DECL_REST_PARAMS(n)); \ ++ } \ ++ \ ++ static bool appliedBy(F api) \ ++ { return apiAddress == Details::methodToAddr(api); } \ ++ \ ++ static void* getHook() \ ++ { return Details::methodToAddr(&ThisType::hook); } \ ++ \ ++ static void freeHook() \ ++ { if(--refCount == 0) apiAddress = 0; } \ ++public: \ ++ \ ++ static void* getApiHook(F api) \ + { \ +- if(! appliedBy(api)) return 0; \ +- ++refCount; \ +- return getHook(); \ ++ if(!appliedBy(api)) return 0; \ ++ ++refCount; \ ++ return getHook(); \ + } \ + \ +- static void* applyApiHook(F* api) \ ++ static void* applyApiHook(F api) \ + { \ +- if(apiAddress != 0) return 0; \ +- apiAddress = reinterpret_cast(api); \ +- refCount = 1; \ +- return getHook(); \ ++ if(apiAddress != 0) return 0; \ ++ apiAddress = Details::methodToAddr(api); \ ++ refCount = 1; \ ++ return getHook(); \ + } \ + \ + static bool freeApiHook(void* hook) \ + { \ +- if(getHook() != hook) return false; \ +- freeHook(); \ +- return true; \ ++ if(getHook() != hook) return false; \ ++ freeHook(); \ ++ return true; \ + } \ + private: \ + static void* apiAddress; \ + static unsigned int refCount; \ + }; \ +-template \ +-void* ApiHookFunctor::apiAddress = 0; \ +-template \ +-unsigned int ApiHookFunctor::refCount = 0 ++template \ ++void* ApiHookFunctor::apiAddress = 0; \ ++template \ ++unsigned int ApiHookFunctor::refCount = 0 + +-#if defined(_MSC_VER) +-// TODO: ApiHook related tests failed on VS2019. +-// [ ERROR ] TestApiHook.h:66: hardware exception STATUS_ILLEGAL_INSTRUCTION raised in setup or running test +-// [ ERROR ] TestApiHook.h:66 : hardware exception STATUS_ACCESS_VIOLATION raised in teardown +-#if _MSC_VER >= 1920 // VS 2019 ++#ifdef WIN32 ++#if defined(_MSC_VER) && defined(BUILD_FOR_X86) + #define MOCKCPP_API_HOOK_FUNCTOR_DEF(n) \ +-__MOCKCPP_API_HOOK_FUNCTOR_DEF(n, __stdcall) ++__MOCKCPP_C_API_HOOK_FUNCTOR_DEF(n, ); \ ++__MOCKCPP_C_API_HOOK_FUNCTOR_DEF(n, __stdcall); \ ++__MOCKCPP_CXX_API_HOOK_FUNCTOR_DEF(n, , ); \ ++__MOCKCPP_CXX_API_HOOK_FUNCTOR_DEF(n, , const); \ ++__MOCKCPP_CXX_API_HOOK_FUNCTOR_DEF(n, __stdcall, ); \ ++__MOCKCPP_CXX_API_HOOK_FUNCTOR_DEF(n, __stdcall, const) + #else + #define MOCKCPP_API_HOOK_FUNCTOR_DEF(n) \ +-__MOCKCPP_API_HOOK_FUNCTOR_DEF(n, ); \ +-__MOCKCPP_API_HOOK_FUNCTOR_DEF(n, __stdcall) ++__MOCKCPP_C_API_HOOK_FUNCTOR_DEF(n, ); \ ++__MOCKCPP_CXX_API_HOOK_FUNCTOR_DEF(n, , ); \ ++__MOCKCPP_CXX_API_HOOK_FUNCTOR_DEF(n, , const) + #endif + #else + #define MOCKCPP_API_HOOK_FUNCTOR_DEF(n) \ +-__MOCKCPP_API_HOOK_FUNCTOR_DEF(n, ) ++__MOCKCPP_C_API_HOOK_FUNCTOR_DEF(n, ); \ ++__MOCKCPP_CXX_API_HOOK_FUNCTOR_DEF(n, , ); \ ++__MOCKCPP_CXX_API_HOOK_FUNCTOR_DEF(n, , const) + #endif + + MOCKCPP_API_HOOK_FUNCTOR_DEF(0); +diff --git a/include/mockcpp/ApiHookGenerator.h b/include/mockcpp/ApiHookGenerator.h +index 569d20f..beb04a9 100644 +--- a/include/mockcpp/ApiHookGenerator.h ++++ b/include/mockcpp/ApiHookGenerator.h +@@ -26,7 +26,7 @@ MOCKCPP_NS_START + template + struct ApiHookGenerator + { +- static void* findApiHook(F* api) ++ static void* findApiHook(F api) + { + void* hook; + +@@ -36,7 +36,7 @@ struct ApiHookGenerator + return hook; + } + +- static void* appyApiHook(F* api) ++ static void* appyApiHook(F api) + { + void* hook; + +@@ -61,10 +61,10 @@ private: + template + struct ApiHookGenerator + { +- static void* findApiHook(F* api) ++ static void* findApiHook(F api) + { return 0; } + +- static void* appyApiHook(F* api) ++ static void* appyApiHook(F api) + { + oss_t oss; + +diff --git a/include/mockcpp/ApiHookHolderFactory.h b/include/mockcpp/ApiHookHolderFactory.h +index 6105caf..a40e3e1 100644 +--- a/include/mockcpp/ApiHookHolderFactory.h ++++ b/include/mockcpp/ApiHookHolderFactory.h +@@ -27,7 +27,7 @@ struct ApiHookHolder; + struct ApiHookHolderFactory + { + template +- static ApiHookHolder* create(F* api) ++ static ApiHookHolder* create(F api) + { + return new ParameterizedApiHookHolder(api); + } +diff --git a/include/mockcpp/ApiHookMocker.h b/include/mockcpp/ApiHookMocker.h +index 3d7c963..5c11422 100644 +--- a/include/mockcpp/ApiHookMocker.h ++++ b/include/mockcpp/ApiHookMocker.h +@@ -21,18 +21,119 @@ + #include + #include + #include ++#include ++#include ++#include + + MOCKCPP_NS_START + +-template +-InvocationMockBuilderGetter mockAPI(const std::string& name, API* api) ++struct mockAPIauto {}; ++template struct mockAPI; ++ ++template struct mockAPI + { ++ static InvocationMockBuilderGetter get( ++ const std::string& name, const std::string& type, API api) ++ { ++ return MOCKCPP_NS::GlobalMockObject::instance.method ++ ( type.empty() ? name : name + " #" + type + "#" ++ , Details::methodToAddr(api) ++ , ApiHookHolderFactory::create(api)); ++ } ++ ++ template ++ static InvocationMockBuilderGetter get_virtual( ++ const std::string& name, const std::string& type, const C *c, API api) ++ { ++ void ***vtbl = (void ***)c; ++ std::pair indices = ++ getIndicesOfMethod(api); ++ union { void *_addr; API _api; }; ++ _addr = (*vtbl)[indices.second]; + return MOCKCPP_NS::GlobalMockObject::instance.method +- ( name +- , reinterpret_cast(api) +- , ApiHookHolderFactory::create(api)); +-} ++ ( type.empty() ? name : name + " #" + type + "#" ++ , _addr ++ , ApiHookHolderFactory::create(_api)); ++ } ++ ++ template ++ static InvocationMockBuilderGetter get_virtual( ++ const std::string& name, const std::string& type, const C &c, API api) ++ { ++ return get_virtual(name, type, &c, api); ++ } ++}; // struct mockAPI ++ ++template<> struct mockAPI ++{ ++#define __MOCKCPP_C_API_GET_FUNCTION_DEF(n, CallingConvention) \ ++ template \ ++ static InvocationMockBuilderGetter get( \ ++ const std::string& name, const std::string& type, R (CallingConvention *api)(DECL_ARGS(n))) \ ++ { \ ++ typedef R (CallingConvention *API)(DECL_ARGS(n)); \ ++ return mockAPI::get(name, type, api); \ ++ } ++ ++#define __MOCKCPP_CXX_API_GET_FUNCTION_DEF(n, CallingConvention, ConstMethod) \ ++ template \ ++ static InvocationMockBuilderGetter get( \ ++ const std::string& name, const std::string& type, R (CallingConvention C::*api)(DECL_ARGS(n)) ConstMethod) \ ++ { \ ++ typedef R (CallingConvention C::*API)(DECL_ARGS(n)) ConstMethod; \ ++ return mockAPI::get(name, type, api); \ ++ } \ ++ template \ ++ static InvocationMockBuilderGetter get_virtual( \ ++ const std::string& name, const std::string& type, const C *c, R (CallingConvention C::*api)(DECL_ARGS(n)) ConstMethod) \ ++ { \ ++ typedef R (CallingConvention C::*API)(DECL_ARGS(n)) ConstMethod; \ ++ return mockAPI::get_virtual(name, type, c, api); \ ++ } \ ++ template \ ++ static InvocationMockBuilderGetter get_virtual( \ ++ const std::string& name, const std::string& type, const C &c, R (CallingConvention C::*api)(DECL_ARGS(n)) ConstMethod) \ ++ { \ ++ typedef R (CallingConvention C::*API)(DECL_ARGS(n)) ConstMethod; \ ++ return mockAPI::get_virtual(name, type, c, api); \ ++ } ++ ++#ifdef WIN32 ++#if defined(_MSC_VER) && defined(BUILD_FOR_X86) ++#define MOCKCPP_API_GET_FUNCTION_DEF(n) \ ++ __MOCKCPP_C_API_GET_FUNCTION_DEF(n, ); \ ++ __MOCKCPP_C_API_GET_FUNCTION_DEF(n, __stdcall); \ ++ __MOCKCPP_CXX_API_GET_FUNCTION_DEF(n, , ); \ ++ __MOCKCPP_CXX_API_GET_FUNCTION_DEF(n, , const); \ ++ __MOCKCPP_CXX_API_GET_FUNCTION_DEF(n, __stdcall, ); \ ++ __MOCKCPP_CXX_API_GET_FUNCTION_DEF(n, __stdcall, const) ++#else ++#define MOCKCPP_API_GET_FUNCTION_DEF(n) \ ++ __MOCKCPP_C_API_GET_FUNCTION_DEF(n, ); \ ++ __MOCKCPP_CXX_API_GET_FUNCTION_DEF(n, , ); \ ++ __MOCKCPP_CXX_API_GET_FUNCTION_DEF(n, , const) ++#endif ++#else ++#define MOCKCPP_API_GET_FUNCTION_DEF(n) \ ++ __MOCKCPP_C_API_GET_FUNCTION_DEF(n, ); \ ++ __MOCKCPP_CXX_API_GET_FUNCTION_DEF(n, , ); \ ++ __MOCKCPP_CXX_API_GET_FUNCTION_DEF(n, , const) ++#endif + ++ MOCKCPP_API_GET_FUNCTION_DEF(0); ++ MOCKCPP_API_GET_FUNCTION_DEF(1); ++ MOCKCPP_API_GET_FUNCTION_DEF(2); ++ MOCKCPP_API_GET_FUNCTION_DEF(3); ++ MOCKCPP_API_GET_FUNCTION_DEF(4); ++ MOCKCPP_API_GET_FUNCTION_DEF(5); ++ MOCKCPP_API_GET_FUNCTION_DEF(6); ++ MOCKCPP_API_GET_FUNCTION_DEF(7); ++ MOCKCPP_API_GET_FUNCTION_DEF(8); ++ MOCKCPP_API_GET_FUNCTION_DEF(9); ++ MOCKCPP_API_GET_FUNCTION_DEF(10); ++ MOCKCPP_API_GET_FUNCTION_DEF(11); ++ MOCKCPP_API_GET_FUNCTION_DEF(12); ++}; // struct mockAPI + + MOCKCPP_NS_END + +diff --git a/include/mockcpp/ChainableMockMethod.h b/include/mockcpp/ChainableMockMethod.h +index ba37e20..a23cf54 100644 +--- a/include/mockcpp/ChainableMockMethod.h ++++ b/include/mockcpp/ChainableMockMethod.h +@@ -39,6 +39,7 @@ struct ChainableMockMethodBase + {} + + RT operator()( const std::string& nameOfCaller ++ , const RefAny& pThisPointer = RefAny() + , const RefAny& p01 = RefAny() + , const RefAny& p02 = RefAny() + , const RefAny& p03 = RefAny() +@@ -57,7 +58,7 @@ struct ChainableMockMethodBase + + try { + const Any& result = \ +- invokable->invoke( nameOfCaller ++ invokable->invoke( nameOfCaller, pThisPointer + , p01, p02, p03, p04, p05, p06 + , p07, p08, p09, p10, p11, p12 + , resultProvider); +diff --git a/include/mockcpp/ChainableMockMethodCore.h b/include/mockcpp/ChainableMockMethodCore.h +index 98df8d7..217ef4c 100644 +--- a/include/mockcpp/ChainableMockMethodCore.h ++++ b/include/mockcpp/ChainableMockMethodCore.h +@@ -46,6 +46,7 @@ public: + // Method + const Any& + invoke( const std::string& nameOfCaller ++ , const RefAny& pThisPointer + , const RefAny& p1 + , const RefAny& p2 + , const RefAny& p3 +diff --git a/include/mockcpp/GlobalMockObject.h b/include/mockcpp/GlobalMockObject.h +index 7d2046a..1746966 100644 +--- a/include/mockcpp/GlobalMockObject.h ++++ b/include/mockcpp/GlobalMockObject.h +@@ -42,6 +42,21 @@ struct GlobalMockObject + static MockObjectType instance; + }; + ++namespace Details { ++ template ++ void *methodToAddr(Method m) ++ { ++ union ++ { ++ void *addr_; ++ Method m_; ++ }; ++ ++ m_ = m; ++ return addr_; ++ } ++} // namespace Details ++ + MOCKCPP_NS_END + + #endif +diff --git a/include/mockcpp/GnuMethodInfoReader.h b/include/mockcpp/GnuMethodInfoReader.h +index 0945ad7..7fb3358 100644 +--- a/include/mockcpp/GnuMethodInfoReader.h ++++ b/include/mockcpp/GnuMethodInfoReader.h +@@ -73,6 +73,10 @@ GnuMethodDescription getGnuDescOfVirtualMethod(Method input) + MethodDescriptionUnion m; + m.method = input; + ++#if defined(__arm__) || defined(__aarch64__) ++ m.desc.u.index++; ++#endif ++ + oss_t oss; + oss << "Virtual method address should be odd, please make sure the method " + << TypeString::value() << " is a virtual method"; +diff --git a/include/mockcpp/Invocation.h b/include/mockcpp/Invocation.h +index 03393bd..bab3477 100644 +--- a/include/mockcpp/Invocation.h ++++ b/include/mockcpp/Invocation.h +@@ -32,6 +32,7 @@ struct InvocationImpl; + struct Invocation + { + Invocation(const std::string nameOfCaller ++ , const RefAny& pThisPointer = RefAny() + , const RefAny& p01 = RefAny() + , const RefAny& p02 = RefAny() + , const RefAny& p03 = RefAny() +@@ -48,8 +49,12 @@ struct Invocation + + virtual ~Invocation(); + ++ RefAny& getThisPointer(void) const; ++ + RefAny& getParameter(const unsigned int i) const; + ++ RefAny& getParameterWithThis(const unsigned int i) const; ++ + std::string getNameOfCaller() const; + + std::string toString(void) const; +diff --git a/include/mockcpp/Invokable.h b/include/mockcpp/Invokable.h +index a1b78d7..145ba7b 100644 +--- a/include/mockcpp/Invokable.h ++++ b/include/mockcpp/Invokable.h +@@ -35,6 +35,7 @@ struct Invokable + { + virtual const Any& + invoke( const std::string& nameOfCaller ++ , const RefAny& pThisPointer + , const RefAny& p1 + , const RefAny& p2 + , const RefAny& p3 +diff --git a/include/mockcpp/JmpCode.h b/include/mockcpp/JmpCode.h +index 26f77b0..ed0fac6 100644 +--- a/include/mockcpp/JmpCode.h ++++ b/include/mockcpp/JmpCode.h +@@ -33,6 +33,7 @@ struct JmpCode + + void* getCodeData() const; + size_t getCodeSize() const; ++ void flushCache() const; + private: + JmpCodeImpl* This; + }; +diff --git a/include/mockcpp/ParameterizedApiHookHolder.h b/include/mockcpp/ParameterizedApiHookHolder.h +index 39b7670..0142ac0 100644 +--- a/include/mockcpp/ParameterizedApiHookHolder.h ++++ b/include/mockcpp/ParameterizedApiHookHolder.h +@@ -30,7 +30,7 @@ struct ParameterizedApiHookHolder + { + const static unsigned int maxSeq = 10; + +- ParameterizedApiHookHolder(F* api) ++ ParameterizedApiHookHolder(F api) + { + (m_hook = ApiHookGenerator::findApiHook(api)) || + (m_hook = ApiHookGenerator::appyApiHook(api)); +diff --git a/include/mockcpp/mockcpp.h b/include/mockcpp/mockcpp.h +index 306bc7a..8cc385d 100644 +--- a/include/mockcpp/mockcpp.h ++++ b/include/mockcpp/mockcpp.h +@@ -39,7 +39,11 @@ + #endif + + +-#if ( defined (__LP64__) \ ++#if defined (__aarch64__) ++#define BUILD_FOR_AARCH64 ++#elif defined (__arm__) ++#define BUILD_FOR_ARM32 ++#elif ( defined (__LP64__) \ + || defined (__64BIT__) \ + || defined (_LP64) \ + || ((defined(__WORDSIZE)) && (__WORDSIZE == 64)) \ +diff --git a/include/mockcpp/mokc.h b/include/mockcpp/mokc.h +index 800a766..95a0f0a 100644 +--- a/include/mockcpp/mokc.h ++++ b/include/mockcpp/mokc.h +@@ -27,7 +27,9 @@ + # define MOCKER(api) MOCKCPP_NS::GlobalMockObject::instance.method(#api) + # else + # include +-# define MOCKER(api) MOCKCPP_NS::mockAPI(#api, api) ++# define MOCKER(api, ...) MOCKCPP_NS::mockAPI<__VA_ARGS__>::get(#api, ""#__VA_ARGS__, api) ++# define MOCKER_CPP(api, ...) MOCKCPP_NS::mockAPI<__VA_ARGS__>::get(#api, ""#__VA_ARGS__, api) ++# define MOCKER_CPP_VIRTUAL(obj, api, ...) MOCKCPP_NS::mockAPI<__VA_ARGS__>::get_virtual(#api, ""#__VA_ARGS__, obj, api) + # endif + + USING_MOCKCPP_NS +diff --git a/src/ChainableMockMethodCore.cpp b/src/ChainableMockMethodCore.cpp +index c254c0f..4aa4190 100644 +--- a/src/ChainableMockMethodCore.cpp ++++ b/src/ChainableMockMethodCore.cpp +@@ -149,6 +149,7 @@ ChainableMockMethodCore::getName() const + const Any& + ChainableMockMethodCore::invoke + ( const std::string& nameOfCaller ++ , const RefAny& pThisPointer + , const RefAny& p1 + , const RefAny& p2 + , const RefAny& p3 +@@ -163,7 +164,7 @@ ChainableMockMethodCore::invoke + , const RefAny& p12 + , SelfDescribe* &resultProvider) + { +- Invocation inv(nameOfCaller,p1,p2,p3,p4,p5,p6,p7,p8,p9,p10,p11,p12); ++ Invocation inv(nameOfCaller,pThisPointer,p1,p2,p3,p4,p5,p6,p7,p8,p9,p10,p11,p12); + return This->invoke(inv, resultProvider); + } + +diff --git a/src/Invocation.cpp b/src/Invocation.cpp +index 17019f5..c27d1cc 100644 +--- a/src/Invocation.cpp ++++ b/src/Invocation.cpp +@@ -35,6 +35,7 @@ struct InvocationImpl + { + std::vector parameters; + std::string nameOfCaller; ++ RefAny thisPointer; + + std::string toString(void) const; + +@@ -70,6 +71,7 @@ std::string InvocationImpl::toString() const + + Invocation::Invocation( + const std::string name ++ , const RefAny& pThisPointer + , const RefAny& p1 + , const RefAny& p2 + , const RefAny& p3 +@@ -85,6 +87,7 @@ Invocation::Invocation( + ) + : This(new InvocationImpl(name)) + { ++ This->thisPointer = pThisPointer; + INIT_PARAMETER(1); + INIT_PARAMETER(2); + INIT_PARAMETER(3); +@@ -106,6 +109,11 @@ Invocation::~Invocation() + } + + //////////////////////////////////////////////////////////////// ++RefAny& Invocation::getThisPointer(void) const ++{ ++ return This->thisPointer; ++} ++ + RefAny& Invocation::getParameter(const unsigned int i) const + { + if (i < 1 || i > maxNumOfParameters ) +@@ -116,6 +124,21 @@ RefAny& Invocation::getParameter(const unsigned int i) const + return This->parameters[i-1]; + } + ++RefAny& Invocation::getParameterWithThis(const unsigned int i) const ++{ ++ if (This->thisPointer.empty()) ++ { ++ return getParameter(i); ++ } ++ ++ if (1 == i) ++ { ++ return This->thisPointer; ++ } ++ ++ return getParameter(i - 1); ++} ++ + //////////////////////////////////////////////////////////////// + std::string Invocation::toString(void) const + { +diff --git a/src/JmpCode.cpp b/src/JmpCode.cpp +index 35794fb..f64abb5 100644 +--- a/src/JmpCode.cpp ++++ b/src/JmpCode.cpp +@@ -30,6 +30,7 @@ struct JmpCodeImpl + //////////////////////////////////////////////// + JmpCodeImpl(const void* from, const void* to) + { ++ m_from = from; + ::memcpy(m_code, jmpCodeTemplate, JMP_CODE_SIZE); + SET_JMP_CODE(m_code, from, to); + } +@@ -47,7 +48,14 @@ struct JmpCodeImpl + } + + //////////////////////////////////////////////// ++ void flushCache() const ++ { ++ FLUSH_CACHE((const char *)m_from, JMP_CODE_SIZE); ++ } + ++ //////////////////////////////////////////////// ++ ++ const void *m_from; + unsigned char m_code[JMP_CODE_SIZE]; + }; + +@@ -77,5 +85,11 @@ JmpCode::getCodeSize() const + return This->getCodeSize(); + } + +-MOCKCPP_NS_END ++/////////////////////////////////////////////////// ++void ++JmpCode::flushCache() const ++{ ++ return This->flushCache(); ++} + ++MOCKCPP_NS_END +diff --git a/src/JmpCodeAARCH64.h b/src/JmpCodeAARCH64.h +new file mode 100644 +index 0000000..2d741c0 +--- /dev/null ++++ b/src/JmpCodeAARCH64.h +@@ -0,0 +1,69 @@ ++/*** ++ mockcpp is a C/C++ mock framework. ++ Copyright [2008] [Darwin Yuan ] ++ [Chen Guodong ] ++ Licensed under the Apache License, Version 2.0 (the "License"); ++ you may not use this file except in compliance with the License. ++ You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++ Unless required by applicable law or agreed to in writing, software ++ distributed under the License is distributed on an "AS IS" BASIS, ++ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ See the License for the specific language governing permissions and ++ limitations under the License. ++***/ ++#ifndef __MOCKCPP_JMP_CODE_AARCH64_H__ ++#define __MOCKCPP_JMP_CODE_AARCH64_H__ ++ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++ ++MOCKCPP_NS_START ++ ++struct l2cache_addr_range { ++ uintptr_t start; ++ uintptr_t end; ++}; ++ ++MOCKCPP_NS_END ++ ++#define ADDR_ALIGN_UP(addr) ((((addr) + ((4096) - 1)) & (~((4096) - 1))) & 0xffffffffffffffff) ++#define ADDR_ALIGN_DOWN(addr) (((addr) & (~((4096) - 1))) & 0xffffffffffffffff) ++#define OUTER_CACHE_INV_RANGE _IOWR('S', 0x00, struct l2cache_addr_range) ++#define OUTER_CACHE_CLEAN_RANGE _IOWR('S', 0x01, struct l2cache_addr_range) ++#define OUTER_CACHE_FLUSH_RANGE _IOWR('S', 0x02, struct l2cache_addr_range) ++#define L1_INV_I_CACHE _IOWR('S', 0x03, struct l2cache_addr_range) ++#define D_TO_I_CACHE_FLUSH_RANGE _IOWR('S', 0x04, struct l2cache_addr_range) ++ ++const unsigned char jmpCodeTemplate[] = ++ { 0x57, 0x00, 0x00, 0x58, 0xe0, 0x02, 0x1f, 0xd6, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }; ++ ++#define SET_SJMP_CODE(base, from, to) do { \ ++ using instruct_t = signed int; \ ++ instruct_t offset = (intptr_t)to - (intptr_t)from; \ ++ offset = ((offset >> 2) & 0x03FFFFFF) | 0x14000000; \ ++ *(instruct_t *)(base) = offset; \ ++ } while(0) ++ ++#define SET_JMP_CODE(base, from, to) do { \ ++ *(void **)(base + 8) = (void *)to; \ ++ } while(0) ++ ++#define FLUSH_CACHE(from, length) do { \ ++ struct l2cache_addr_range usr_data; \ ++ usr_data.start = ADDR_ALIGN_DOWN((unsigned long long)from); \ ++ usr_data.end = ADDR_ALIGN_UP((unsigned long long)from) + length; \ ++ __builtin___clear_cache((char *)usr_data.start, (char *)usr_data.end); \ ++} while (0) ++ ++#endif +diff --git a/src/JmpCodeARM32.h b/src/JmpCodeARM32.h +new file mode 100644 +index 0000000..1eec42b +--- /dev/null ++++ b/src/JmpCodeARM32.h +@@ -0,0 +1,36 @@ ++/*** ++ mockcpp is a C/C++ mock framework. ++ Copyright [2008] [Darwin Yuan ] ++ [Chen Guodong ] ++ Licensed under the Apache License, Version 2.0 (the "License"); ++ you may not use this file except in compliance with the License. ++ You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++ Unless required by applicable law or agreed to in writing, software ++ distributed under the License is distributed on an "AS IS" BASIS, ++ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ See the License for the specific language governing permissions and ++ limitations under the License. ++***/ ++#ifndef __MOCKCPP_JMP_CODE_ARM32_H__ ++#define __MOCKCPP_JMP_CODE_ARM32_H__ ++ ++#include ++ ++const unsigned char jmpCodeTemplate[] = ++ { 0xEA, 0x00, 0x00, 0x00 }; ++ ++#define SET_JMP_CODE(base, from, to) do { \ ++ int offset = (int)to - (int)from - 8; \ ++ offset = (offset >> 2) & 0x00FFFFFF; \ ++ int code = *(int *)(base) | offset; \ ++ *(int *)(base) = changeByteOrder(code); \ ++ } while(0) ++ ++#define FLUSH_CACHE(from, length) do { \ ++ ::system("echo 3 > /proc/sys/vm/drop_caches"); \ ++} while (0) ++ ++#endif +diff --git a/src/JmpCodeArch.h b/src/JmpCodeArch.h +index 26abd73..53353eb 100644 +--- a/src/JmpCodeArch.h ++++ b/src/JmpCodeArch.h +@@ -19,11 +19,29 @@ + + #include + ++template ++inline T changeByteOrder(const T v) { ++ enum { S = sizeof(T) }; ++ T rst = v; ++ char *p = (char *)&rst; ++ char tmp = 0; ++ for (unsigned int i = 0; i < S / 2; ++i) { ++ tmp = p[i]; ++ p[i] = p[S - i - 1]; ++ p [S - i - 1] = tmp; ++ } ++ ++ return rst; ++} ++ + #if BUILD_FOR_X64 + # include "JmpCodeX64.h" + #elif BUILD_FOR_X86 + # include "JmpCodeX86.h" ++#elif defined(BUILD_FOR_ARM32) ++# include "JmpCodeARM32.h" ++#elif defined(BUILD_FOR_AARCH64) ++# include "JmpCodeAARCH64.h" + #endif + + #endif +- +diff --git a/src/JmpCodeX64.h b/src/JmpCodeX64.h +index 198507a..e5b4f31 100644 +--- a/src/JmpCodeX64.h ++++ b/src/JmpCodeX64.h +@@ -27,5 +27,6 @@ const unsigned char jmpCodeTemplate[] = + *(uintptr_t *)(base + 6) = (uintptr_t)to; \ + } while(0) + +-#endif ++#define FLUSH_CACHE(from, length) ((void)0) + ++#endif +diff --git a/src/JmpCodeX86.h b/src/JmpCodeX86.h +index ebdc526..a06a02e 100644 +--- a/src/JmpCodeX86.h ++++ b/src/JmpCodeX86.h +@@ -23,5 +23,6 @@ const unsigned char jmpCodeTemplate[] = { 0xE9, 0x00, 0x00, 0x00, 0x00 }; + (unsigned long long)to - (unsigned long long)from - sizeof(jmpCodeTemplate); \ + } while(0) + +-#endif ++#define FLUSH_CACHE(from, length) ((void)0) + ++#endif +diff --git a/src/JmpOnlyApiHook.cpp b/src/JmpOnlyApiHook.cpp +index d4cfa68..964828f 100644 +--- a/src/JmpOnlyApiHook.cpp ++++ b/src/JmpOnlyApiHook.cpp +@@ -68,6 +68,7 @@ struct JmpOnlyApiHookImpl + void changeCode(const void* data) + { + CodeModifier::modify(const_cast(m_api), data, m_jmpCode.getCodeSize()); ++ m_jmpCode.flushCache(); + } + + ///////////////////////////////////////////////////// +diff --git a/src/UnixCodeModifier.cpp b/src/UnixCodeModifier.cpp +index ab4014e..e1fbe75 100644 +--- a/src/UnixCodeModifier.cpp ++++ b/src/UnixCodeModifier.cpp +@@ -18,27 +18,32 @@ + #include + #include + #include ++#include + + #include ++#include "JmpCodeArch.h" + + #define PAGE_ALIGN_BITS 12 + + ////////////////////////////////////////////////////////////////// + #define PAGE_SIZE (1 << PAGE_ALIGN_BITS) +-#define ALIGN_TO_PAGE_BOUNDARY(addr) (void*) (((uintptr_t)addr) & (~((1<<(PAGE_ALIGN_BITS))-1))) ++#define ALIGN_TO_PAGE_BOUNDARY(addr, page_size) (void*) (((uintptr_t)addr) & (~((page_size)-1))) + ////////////////////////////////////////////////////////////////// + + MOCKCPP_NS_START + + bool CodeModifier::modify(void *dest, const void *src, size_t size) + { +- if(::mprotect(ALIGN_TO_PAGE_BOUNDARY(dest), PAGE_SIZE * 2, PROT_EXEC | PROT_WRITE | PROT_READ ) != 0) +- { +- return false; ++ unsigned int page_size = getpagesize(); ++ ++ if (::mprotect(ALIGN_TO_PAGE_BOUNDARY(dest, page_size), page_size * 2, PROT_EXEC | PROT_WRITE | PROT_READ) != 0) { ++ return false; + } + + ::memcpy(dest, src, size); + ++ FLUSH_CACHE(dest, size); ++ + + #if 0 + #if BUILD_FOR_X86 diff --git a/test/hlt/mt_accept/mt_accept_test.cpp b/test/hlt/mt_accept/mt_accept_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f5e47b92e4e0e8df61baaf689d411d274e139bde --- /dev/null +++ b/test/hlt/mt_accept/mt_accept_test.cpp @@ -0,0 +1,363 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "hcom_service.h" +namespace ock { +namespace hcom { + +constexpr uint32_t NN_NO13 = 13; + +NetService *service = nullptr; +NetService *client = nullptr; +NetChannelPtr channel = nullptr; + +UBSHcomNetDriverProtocol driverType = TCP; +std::string oobIp = ""; +uint16_t g_oobPort = 9981; +std::string ipSeg = "192.168.100.0/24"; +std::string udsName = "SHM_UDS"; +int32_t g_dataSize = 1024; +int32_t g_pingCount = 1000000; +int16_t g_asyncWorkerCpuId = -1; +uint32_t g_workerNum = 1; + +uint64_t g_startTimeServer = 0; +uint64_t g_endTimeServer = 0; +int g_numprocs = 0; +int g_rank; +#define MPI_CHECK(stmt) \ + do { \ + int mpiErrno = (stmt); \ + if (MPI_SUCCESS != mpiErrno) { \ + fprintf(stderr, "[%s:%d] MPI call failed with %d \n", __FILE__, __LINE__, mpiErrno); \ + exit(EXIT_FAILURE); \ + } \ + assert(MPI_SUCCESS == mpiErrno); \ + } while (0) + +int ValidateArguments(int argc, char *argv[]) +{ + const char *usage = "usage\n" + " -d, --driver, driver type, 0 means rdma, 1 means tcp\n" + " -i, --ip, server ip mask, e.g. 10.175.118.1\n" + " -p, --port, server port, by default 9981\n" + " -s, --io size , max data size\n" + " -w, --worker num , worker num\n" + " -c, --cpuId, async worker\n"; + + if (argc != NN_NO13) { + printf("invalid param, %s, for example %s -d 0 -i rdma_nic_ip -p 9981 -s 1024 -w 1 -c 5\n", usage, argv[0]); + return -1; + } + + return 0; +} + +int ProcessOptions(int argc, char *argv[]) +{ + struct option options[] = { + {"driver", required_argument, nullptr, 'd'}, + {"ip", required_argument, nullptr, 'i'}, + {"port", required_argument, nullptr, 'p'}, + {"size", required_argument, nullptr, 's'}, + {"worker num", required_argument, nullptr, 'w'}, + {"cpuId", required_argument, nullptr, 'c'}, + {nullptr, 0, nullptr, 0}, + }; + + if (ValidateArguments(argc, argv) != 0) { + return -1; + } + + int ret = 0; + int index = 0; + std::string str = "d:i:p:s:w:c:"; + while ((ret = getopt_long(argc, argv, str.c_str(), options, &index)) != -1) { + switch (ret) { + case 'd': + driverType = static_cast((uint16_t)strtoul(optarg, NULL, 0)); + if (driverType > SHM) { + printf("invalid driver type %d", driverType); + return -1; + } + break; + case 'i': + oobIp = optarg; + ipSeg = oobIp + "/24"; + break; + case 'p': + g_oobPort = static_cast(strtoul(optarg, nullptr, 0)); + break; + case 's': + g_dataSize = static_cast(strtoul(optarg, nullptr, 0)); + break; + case 'w': + g_workerNum = static_cast(strtoul(optarg, nullptr, 0)); + break; + case 'c': + g_asyncWorkerCpuId = strtoul(optarg, nullptr, 0); + break; + default: + printf("invalid param, for example -d 0 -i rdma_nic_ip -p 9981 -s 1024 -w 1 -c 5"); + return -1; + } + } + return 0; +} +int g_count = 0; +int NewChannel(const std::string &ipPort, const NetChannelPtr &ch, const std::string &payload) +{ + g_count++; + if (g_count == 1) { + g_startTimeServer = MONOTONIC_TIME_NS(); + std::cout << "all connect startTimeServer: " << g_startTimeServer << " ns, count " << g_count << std::endl; + } + + if (g_count == g_numprocs - 1) { + g_endTimeServer = MONOTONIC_TIME_NS(); + double s = static_cast(g_endTimeServer - g_startTimeServer) / 1000000000; + std::cout << "all connect success: " << s << " s, count " << g_count << + ", numprocs " << g_numprocs << std::endl; + } + NN_LOG_INFO("new channel " << ch->Id() << " call from " << ipPort << " payload: " << payload); + return 0; +} + +void BrokenChannel(const NetChannelPtr &ch) +{ + NN_LOG_INFO("ep broken"); +} + +int ReceivedRequest(NetServiceContext &context) +{ + return 0; +} + +int PostSendRequest(NetServiceContext context) +{ + return 0; +} +int OneSideDownRequest(NetServiceContext context) +{ + return 0; +} + +bool HcomServerInitStart() +{ + if (service != nullptr) { + NN_LOG_ERROR("service already created"); + return false; + } + + service = NetService::Instance(driverType, "server1", true); + if (service == nullptr) { + NN_LOG_ERROR("failed to create service already created"); + return false; + } + NetServiceOptions options{}; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + options.mrSendReceiveSegSize = NN_NO1024 + g_dataSize; + if (driverType == SHM) { + options.oobType = NET_OOB_UDS; + UBSHcomNetOobUDSListenerOptions listenOpt; + listenOpt.Name(udsName); + listenOpt.perm = 0; + service->AddOobUdsOptions(listenOpt); + } + if (g_asyncWorkerCpuId != -1) { + std::string str = std::to_string(g_asyncWorkerCpuId) + "-" + std::to_string(g_asyncWorkerCpuId); + options.SetWorkerGroupsCpuSet(str); + NN_LOG_INFO(" set cpuId " << str); + } + options.SetNetDeviceIpMask(ipSeg); + options.SetWorkerGroups(std::to_string(g_workerNum)); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + service->SetOobIpAndPort(oobIp, g_oobPort); + service->RegisterNewChannelHandler(NewChannel); + service->RegisterChannelBrokenHandler(BrokenChannel, ock::hcom::BROKEN_ALL); + service->RegisterOpReceiveHandler(0, ReceivedRequest); + service->RegisterOpSentHandler(0, PostSendRequest); + service->RegisterOpOneSideHandler(0, OneSideDownRequest); + + int result = 0; + if ((result = service->Start(options)) != 0) { + NN_LOG_ERROR("failed to initialize service " << result); + return false; + } + NN_LOG_INFO("service initialized and start"); + + return true; +} + +bool HcomClientInitStart(int rank) +{ + if (client != nullptr) { + NN_LOG_ERROR("client already created"); + return false; + } + + client = NetService::Instance(driverType, "client" + std::to_string(rank), false); + if (client == nullptr) { + NN_LOG_ERROR("failed to create client already created"); + return false; + } + NetServiceOptions options{}; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + options.mrSendReceiveSegSize = NN_NO1024 + g_dataSize; + if (driverType == SHM) { + options.oobType = NET_OOB_UDS; + } + if (g_asyncWorkerCpuId != -1) { + std::string str = std::to_string(g_asyncWorkerCpuId) + "-" + std::to_string(g_asyncWorkerCpuId); + options.SetWorkerGroupsCpuSet(str); + NN_LOG_INFO("client set cpuId " << str); + } + options.SetNetDeviceIpMask(ipSeg); + options.SetWorkerGroups(std::to_string(g_workerNum)); + NN_LOG_INFO("client set ip mask " << options.netDeviceIpMask); + + client->SetOobIpAndPort(oobIp, g_oobPort); + + client->RegisterChannelBrokenHandler(BrokenChannel, ock::hcom::BROKEN_ALL); + client->RegisterOpReceiveHandler(0, ReceivedRequest); + client->RegisterOpSentHandler(0, PostSendRequest); + client->RegisterOpOneSideHandler(0, OneSideDownRequest); + + int result = 0; + if ((result = client->Start(options)) != 0) { + NN_LOG_ERROR("failed to start client " << result); + return false; + } + NN_LOG_INFO("client" << rank << " initialized and start"); + + return true; +} + +bool HcomInit(int rank) +{ + bool ret; + if (rank == 0) { + ret = HcomServerInitStart(); + } else { + ret = HcomClientInitStart(rank); + } + return ret; +} + +void HcomClientUninit(int rank) +{ + client->Stop(); + NetService::DestroyInstance("client" + std::to_string(rank)); +} + +void HcomServerUninit() +{ + service->Stop(); + + NetService::DestroyInstance("server1"); +} + +bool HcomClientConnect(int rank) +{ + if (client == nullptr) { + NN_LOG_ERROR("client is null" << rank); + return false; + } + int result = 0; + NetServiceConnectOptions options{}; + + if (driverType == SHM) { + result = client->Connect(udsName, 0, "hello service", channel, options); + } else { + result = client->Connect(oobIp, g_oobPort, "hello service", channel, options); + } + + if (result != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + NN_LOG_INFO("client" << rank << " success to connect to server, channel id " << channel->Id()); + + return true; +} + +int main(int argc, char *argv[]) +{ + int ret1; + bool ret; + + ret1 = ProcessOptions(argc, argv); + if (ret1 != 0) { + MPI_CHECK(MPI_Finalize()); + exit(EXIT_FAILURE); + } + + MPI_CHECK(MPI_Init(&argc, &argv)); + MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &g_rank)); + MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &g_numprocs)); + + if (g_numprocs < NN_NO2) { + if (g_rank == 0) { + fprintf(stderr, "This test requires at least two processes\n"); + } + + MPI_CHECK(MPI_Finalize()); + exit(EXIT_FAILURE); + } + + ret = HcomInit(g_rank); + if (!ret) { + MPI_CHECK(MPI_Finalize()); + exit(EXIT_FAILURE); + } + + MPI_CHECK(MPI_Barrier(MPI_COMM_WORLD)); + + if (g_rank != 0) { + // client process [do connect] + ret = HcomClientConnect(g_rank); + if (!ret) { + MPI_CHECK(MPI_Finalize()); + exit(EXIT_FAILURE); + } + } + if (g_rank != 0) { + HcomClientUninit(g_rank); + } + + MPI_CHECK(MPI_Barrier(MPI_COMM_WORLD)); + + if (g_rank == 0) { + HcomServerUninit(); + } + + MPI_CHECK(MPI_Finalize()); + + return EXIT_SUCCESS; +} +} +} \ No newline at end of file diff --git a/test/hlt/mt_accept/run_mt.sh b/test/hlt/mt_accept/run_mt.sh new file mode 100644 index 0000000000000000000000000000000000000000..96010f99826723948642574844807daabaa09008 --- /dev/null +++ b/test/hlt/mt_accept/run_mt.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# *********************************************************************** +# Copyright: (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# script for run mt +# version: 1.0.0 +# change log: +# *********************************************************************** +#设置client和server总数量 +num=$1 + +# 获取当前目录 +CURRENT_PATH=$(cd $(dirname ${0}) && pwd) + +# 获取当前目录的父目录 +parent_dir=$(dirname "$(pwd)") + +# 获取父目录的父目录 +grandparent_dir=$(dirname "$parent_dir") + +# 获取父目录的父目录的父目录 +greatgrandparent_dir=$(dirname "$grandparent_dir") + +echo "mpi building ... " +export LD_LIBRARY_PATH=${greatgrandparent_dir}/build/src:$LD_LIBRARY_PATH +mpicc -o3 -Wall -I${greatgrandparent_dir}/src -L${greatgrandparent_dir}/build/src -lhcom -lstdc++ -o test ./mt_accept_test.cpp + +#mpirun -n 1001 -x LD_LIBRARY_PATH -x HCOM_SET_LOG_LEVEL=3 -x HCOM_CONNECTION_RETRY_TIME=1 -host 96.10.130.125:128,96.10.130.126:128,96.10.130.127:128,96.10.130.128:128,96.10.130.129:128,96.10.130.130:128,96.10.130.131:128,96.10.130.132:105 ./test -d 1 -i 10.10.3.126 -p 9982 -s 1024 -w 1 -c -1 +echo "test running ... " +mpirun -n ${num} -x LD_LIBRARY_PATH -x HCOM_SET_LOG_LEVEL=3 -x HCOM_CONNECTION_RETRY_TIME=1 -hostfile ${CURRENT_PATH}/crossfile ./test -d 1 -i 10.10.3.126 -p 9982 -s 1024 -w 1 -c -1 + diff --git a/test/llt/CMakeLists.txt b/test/llt/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1b31235446717a9bd2757694fb17bfd670704799 --- /dev/null +++ b/test/llt/CMakeLists.txt @@ -0,0 +1,61 @@ +# +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# +cmake_minimum_required(VERSION 3.14) +project(HCOM_LLT C CXX) +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 11) + +# define macro that will be used in source code +add_compile_options(-DMOCK_VERBS -DTEST_LLT) + +# collect hcom source files +include_directories(${CMAKE_SOURCE_DIR}/src/) +aux_source_directory(service_v2/api SOURCE_FILES) +aux_source_directory(service_v2 SOURCE_FILES) +include_directories(${CMAKE_SOURCE_DIR}/src/api/java_sdk/jni) +include_directories(${CMAKE_SOURCE_DIR}/src/api/java_sdk/jni/include) +include_directories(${CMAKE_SOURCE_DIR}/src/api/java_sdk/jni/service) +include_directories(${CMAKE_SOURCE_DIR}/src/api/java_sdk/jni/common) +include_directories(${CMAKE_SOURCE_DIR}/src/under_api/verbs) +file(GLOB_RECURSE SOURCE_FILES + ${CMAKE_SOURCE_DIR}/src/*.cpp + ${CMAKE_SOURCE_DIR}/src/*.h) + +# collect hcom unittest files +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/testcase) +aux_source_directory(main/ SOURCE_FILES) +aux_source_directory(testcase/ SOURCE_FILES) +aux_source_directory(testcase/common SOURCE_FILES) +aux_source_directory(testcase/transport SOURCE_FILES) +aux_source_directory(testcase/transport/rdma SOURCE_FILES) +#aux_source_directory(testcase/transport/shm SOURCE_FILES) +aux_source_directory(testcase/transport/sock SOURCE_FILES) + +# include&link gtest +set(GTEST_INSTALL_DIR "${TEST_TOOL_INSTALL_PATH}/googletest") +if (NOT EXISTS ${GTEST_INSTALL_DIR}) + message(ERROR "GTEST_INSTALL_DIR(${GTEST_INSTALL_DIR}) is invalid") +endif() +include_directories(${GTEST_INSTALL_DIR}/include) +link_directories(${GTEST_INSTALL_DIR}/lib64) + +# include&link mockcpp +set(MOCKCPP_INSTALL_DIR "${TEST_TOOL_INSTALL_PATH}/mockcpp") +if (NOT EXISTS ${MOCKCPP_INSTALL_DIR}) + message(ERROR "MOCKCPP_INSTALL_DIR(${MOCKCPP_INSTALL_DIR}) is invalid") +endif() +include_directories(${MOCKCPP_INSTALL_DIR}/include) +link_directories(${MOCKCPP_INSTALL_DIR}/lib) + +# enable gcov +add_compile_options(-ftest-coverage -fprofile-arcs) + +# build hcom_ut +set(DEPEND_LIBS rt pthread gtest gcov mockcpp dl fake_ibv_static boundscheck) +add_executable(Hcomtest ${SOURCE_FILES}) +target_compile_options(Hcomtest PUBLIC -D_GNU_SOURCE) +set_target_properties(Hcomtest PROPERTIES OUTPUT_NAME "hcom_test") +set_target_properties(Hcomtest PROPERTIES CLEAN_DIRECT_OUTPUT 1) +set_target_properties(Hcomtest PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") +target_link_libraries(Hcomtest ${DEPEND_LIBS} rt) diff --git a/test/llt/main/unit_main.cpp b/test/llt/main/unit_main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f30dcebf66e1bb71813c6585c693f815754079ec --- /dev/null +++ b/test/llt/main/unit_main.cpp @@ -0,0 +1,21 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +int main(int argc, char *argv[]) +{ + testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + printf("hcom_test result %d\n", ret); + return ret; +} \ No newline at end of file diff --git a/test/llt/testcase/api/java/pom.xml b/test/llt/testcase/api/java/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..648e9e8b0e630a2c72e62c3ce8301e4239f37863 --- /dev/null +++ b/test/llt/testcase/api/java/pom.xml @@ -0,0 +1,122 @@ + + + 4.0.0 + + com.huawei.ock.hcom + hcom_unit_test + UT test for Huawei Open Computing Kit, java SDK + jar + beiming.24.4 + + + UTF-8 + 2.0.5 + + + + + com.huawei.ock + hcom-sdk + beiming.24.4 + + + com.huawei.dt + dt4j-starter-boot + ${dt4j.version} + test + + + com.huawei.dt + dt4j-starter-mockito + ${dt4j.version} + test + + + org.powermock + powermock-module-junit4 + 2.0.9 + compile + + + org.powermock + powermock-api-mockito2 + 2.0.9 + compile + + + + + ../../../../../src/api/java_sdk/src/com/huawei/ock/hcom/ + src/ + + + org.apache.maven.plugins + maven-surefire-plugin + 3.2.3 + + false + + --add-opens java.base/java.lang=ALL-UNNAMED + --add-opens java.base/java.util=ALL-UNNAMED + --add-opens java.base/java.lang.reflect=ALL-UNNAMED + + + + + maven-assembly-plugin + + + jar-with-dependencies + + ${project.artifactId}-${project.version} + false + + + + make-assembly + package + + single + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.11.0 + + + + org.projectlombok + lombok + 1.18.32 + + + + + + + com.huawei.dt + dt4j-coverage-maven-plugin + ${dt4j.version} + + TestReport + 1.0.0 + TestUser + + + + + instrument + + instrument + + + + + + + \ No newline at end of file diff --git a/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/common/ExternLoggerListenerTest.java b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/common/ExternLoggerListenerTest.java new file mode 100644 index 0000000000000000000000000000000000000000..edf2e6ea528eddb5fe9fde61736e96f65f017826 --- /dev/null +++ b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/common/ExternLoggerListenerTest.java @@ -0,0 +1,65 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +package com.huawei.ock.hcom.test.common; + +import com.huawei.ock.hcom.common.ExternLogger; +import com.huawei.ock.hcom.common.ExternLoggerListener; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; +import org.powermock.modules.junit4.PowerMockRunner; + +/** + * class ExternLoggerListener test + * + * @since 2024-08-22 + */ +@RunWith(PowerMockRunner.class) +@PrepareForTest({ExternLoggerListener.class, Long.class}) +@SuppressStaticInitializationFor("com.huawei.ock.hcom.common.ExternLoggerListener") +public class ExternLoggerListenerTest { + class ExternLoggerTest implements ExternLogger { + @Override + public void log(int level, String message) { + + } + } + + @Test + public void exitRunTest() { + ExternLoggerTest loggerTest = new ExternLoggerTest(); + PowerMockito.suppress(PowerMockito.method(ExternLoggerListener.class, "nativeAddExternLogListener")); + ExternLoggerListener listener = new ExternLoggerListener(loggerTest, 1, 100); + PowerMockito.suppress(PowerMockito.method(ExternLoggerListener.class, "nativeInterrupt")); + listener.exit(); + PowerMockito.stub(PowerMockito.method(ExternLoggerListener.class, "nativeStop")).toReturn(2); + PowerMockito.stub(PowerMockito.method(ExternLoggerListener.class, "nativePollLogMsg")).toReturn("test log"); + listener.run(); + } + + @Test + public void logTest() { + ExternLoggerTest loggerTest = new ExternLoggerTest(); + PowerMockito.suppress(PowerMockito.method(ExternLoggerListener.class, "nativeAddExternLogListener")); + ExternLoggerListener listener = new ExternLoggerListener(loggerTest, 1, 100); + PowerMockito.suppress(PowerMockito.method(ExternLoggerListener.class, "nativeInterrupt")); + listener.exit(); + PowerMockito.stub(PowerMockito.method(ExternLoggerListener.class, "nativeStop")).toReturn(2); + PowerMockito.stub(PowerMockito.method(ExternLoggerListener.class, "nativePollLogMsg")).toReturn("test log"); + PowerMockito.stub(PowerMockito.method(Long.class, "intValue")).toReturn(1); + listener.run(); + } +} diff --git a/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetChannelTest.java b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetChannelTest.java new file mode 100644 index 0000000000000000000000000000000000000000..1a6e13ad7ae12db74ed3d54f4d3102822cf0d98d --- /dev/null +++ b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetChannelTest.java @@ -0,0 +1,298 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +package com.huawei.ock.hcom.test.service; + +import com.huawei.ock.hcom.service.NetChannel; +import com.huawei.ock.hcom.service.NetServiceMessage; +import com.huawei.ock.hcom.service.NetServiceOpInfo; +import com.huawei.ock.hcom.service.NetChannelCallback; +import com.huawei.ock.hcom.service.NetServiceContext; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.Arrays; + +/** + * class NetChannel test + * + * @since 2024-02-17 + */ +@RunWith(PowerMockRunner.class) +@PrepareForTest(NetChannel.class) +@SuppressStaticInitializationFor("com.huawei.ock.hcom.service.NetChannel") +public class NetChannelTest { + @Test + public void sendTest() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data); + byte[] data1 = new byte[10]; + Arrays.fill(data, (byte) 'c'); + message.data = data1; + message.transferOwner = false; + NetServiceOpInfo opInfo = new NetServiceOpInfo((short) 0); + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeSend")).toReturn(0); + NetChannel channel = new NetChannel(123, 456, false); + channel.send(opInfo, message, 1); + } + + @Test(expected = Exception.class) + public void sendTestException() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data); + byte[] data1 = new byte[10]; + Arrays.fill(data, (byte) 'c'); + message.data = data1; + message.transferOwner = false; + NetServiceOpInfo opInfo = new NetServiceOpInfo((short) 0); + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeSend")).toReturn(1); + NetChannel channel = new NetChannel(123, 456, false); + channel.send(opInfo, message, 1); + } + + class NetChannelCallbackTest implements NetChannelCallback { + @Override + public void run(NetServiceContext ctx) { + + } + } + + @Test + public void sendTestCb() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data); + byte[] data1 = new byte[10]; + Arrays.fill(data, (byte) 'c'); + message.data = data1; + message.transferOwner = false; + NetServiceOpInfo opInfo = new NetServiceOpInfo((short) 0); + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeSendWithCb")).toReturn(0); + NetChannel channel = new NetChannel(123, 456, false); + NetChannelCallbackTest cbTest = new NetChannelCallbackTest(); + channel.send(opInfo, message, cbTest, 1); + } + + @Test(expected = Exception.class) + public void sendTestCbException() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data); + byte[] data1 = new byte[10]; + Arrays.fill(data, (byte) 'c'); + message.data = data1; + message.transferOwner = false; + NetServiceOpInfo opInfo = new NetServiceOpInfo((short) 0); + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeSendWithCb")).toReturn(1); + NetChannel channel = new NetChannel(123, 456, false); + NetChannelCallbackTest cbTest = new NetChannelCallbackTest(); + channel.send(opInfo, message, cbTest, 1); + } + + @Test(expected = Exception.class) + public void sendExceptionTest() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data); + byte[] data1 = new byte[10]; + Arrays.fill(data, (byte) 'c'); + message.data = data1; + message.transferOwner = false; + NetServiceOpInfo opInfo = new NetServiceOpInfo((short) 0); + NetChannel channel = new NetChannel(123, 456, false); + NetChannelCallbackTest cbTest = new NetChannelCallbackTest(); + channel.send(opInfo, message, cbTest); + } + + @Test + public void sendRawTest() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data); + byte[] data1 = new byte[10]; + Arrays.fill(data, (byte) 'c'); + message.data = data1; + message.transferOwner = false; + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeSendRaw")).toReturn(0); + NetChannel channel = new NetChannel(123, 456, false); + channel.sendRaw(message, 1); + } + + @Test(expected = Exception.class) + public void sendRawExceptionTest() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data); + byte[] data1 = new byte[10]; + Arrays.fill(data, (byte) 'c'); + message.data = data1; + message.transferOwner = false; + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeSendRaw")).toReturn(1); + NetChannel channel = new NetChannel(123, 456, false); + channel.sendRaw(message, 1); + } + + @Test + public void sendRawCbTest() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data); + byte[] data1 = new byte[10]; + Arrays.fill(data, (byte) 'c'); + message.data = data1; + message.transferOwner = false; + NetChannelCallbackTest cbTest = new NetChannelCallbackTest(); + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeSendRawWithCb")).toReturn(0); + NetChannel channel = new NetChannel(123, 456, false); + channel.sendRaw(message, cbTest, 1); + } + + @Test(expected = Exception.class) + public void sendRawCbExceptionTest() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data); + byte[] data1 = new byte[10]; + Arrays.fill(data, (byte) 'c'); + message.data = data1; + message.transferOwner = false; + NetChannelCallbackTest cbTest = new NetChannelCallbackTest(); + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeSendRawWithCb")).toReturn(1); + NetChannel channel = new NetChannel(123, 456, false); + channel.sendRaw(message, cbTest, 1); + } + + @Test + public void syncCallTest() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data, false); + byte[] data1 = new byte[1024]; + Arrays.fill(data1, (byte) 'h'); + NetServiceMessage response = new NetServiceMessage(data1, false); + NetServiceOpInfo opInfo = new NetServiceOpInfo((short) 0); + NetServiceOpInfo respInfo = new NetServiceOpInfo((short) 0); + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeSyncCall")).toReturn(0); + NetChannel channel = new NetChannel(123, 456, false); + channel.syncCall(opInfo, message, respInfo, response); + } + + @Test(expected = Exception.class) + public void syncCallExceptionTest() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data, false); + byte[] data1 = new byte[1024]; + Arrays.fill(data1, (byte) 'h'); + NetServiceMessage response = new NetServiceMessage(data1, false); + NetServiceOpInfo opInfo = new NetServiceOpInfo((short) 0); + NetServiceOpInfo respInfo = new NetServiceOpInfo((short) 0); + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeSyncCall")).toReturn(1); + NetChannel channel = new NetChannel(123, 456, false); + channel.syncCall(opInfo, message, respInfo, response); + } + + @Test + public void syncCallRawTest() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data, false); + byte[] data1 = new byte[1024]; + Arrays.fill(data1, (byte) 'h'); + NetServiceMessage response = new NetServiceMessage(data1, false); + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeSyncCallRaw")).toReturn(0); + NetChannel channel = new NetChannel(123, 456, false); + channel.syncCallRaw(message, response); + } + + @Test(expected = Exception.class) + public void syncCallRawExceptionTest() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data, false); + byte[] data1 = new byte[1024]; + Arrays.fill(data1, (byte) 'h'); + NetServiceMessage response = new NetServiceMessage(data1, false); + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeSyncCallRaw")).toReturn(1); + NetChannel channel = new NetChannel(123, 456, false); + channel.syncCallRaw(message, response); + } + + @Test + public void asyncCallTest() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data, false); + NetServiceOpInfo opInfo = new NetServiceOpInfo((short) 0); + NetChannelCallbackTest cbTest = new NetChannelCallbackTest(); + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeAsyncCall")).toReturn(0); + NetChannel channel = new NetChannel(123, 456, false); + channel.asyncCall(opInfo, message, cbTest); + } + + @Test(expected = Exception.class) + public void asyncCallExceptionTest() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data, false); + NetServiceOpInfo opInfo = new NetServiceOpInfo((short) 0); + NetChannelCallbackTest cbTest = new NetChannelCallbackTest(); + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeAsyncCall")).toReturn(1); + NetChannel channel = new NetChannel(123, 456, false); + channel.asyncCall(opInfo, message, cbTest); + } + + @Test + public void asyncCallRawTest() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data, false); + NetChannelCallbackTest cbTest = new NetChannelCallbackTest(); + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeAsyncCallRaw")).toReturn(0); + NetChannel channel = new NetChannel(123, 456, false); + channel.asyncCallRaw(message, cbTest); + } + + @Test(expected = Exception.class) + public void asyncCallRawExceptionTest() throws Exception { + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(data, false); + NetChannelCallbackTest cbTest = new NetChannelCallbackTest(); + PowerMockito.stub(PowerMockito.method(NetChannel.class, "nativeAsyncCallRaw")).toReturn(1); + NetChannel channel = new NetChannel(123, 456, false); + channel.asyncCallRaw(message, cbTest); + } + + @Test + public void upCtxTest() { + NetChannel channel = new NetChannel(123, 456, false); + channel.setUpCtx("12345"); + Assert.assertEquals("12345", channel.getUpCtx()); + } + + @Test + public void closeTest() { + NetChannel channel = new NetChannel(123, 456, false); + PowerMockito.suppress(PowerMockito.method(NetChannel.class, "decreaseObject")); + channel.Close(); + Assert.assertNull(channel.getUpCtx()); + } +} diff --git a/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetOobListenOptionsTest.java b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetOobListenOptionsTest.java new file mode 100644 index 0000000000000000000000000000000000000000..569abcb25da5da6880ff9af9960d6fae66890ace --- /dev/null +++ b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetOobListenOptionsTest.java @@ -0,0 +1,82 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +package com.huawei.ock.hcom.test.service; + +import com.huawei.ock.hcom.service.NetOobListenOptions; + +import org.junit.Assert; +import org.junit.Test; + +/** + * class NetOobListenOptions test + * + * @since 2024-02-17 + */ +public class NetOobListenOptionsTest { + @Test + public void getIpAndPortTest() throws Exception { + String ip = "999999.0000"; + int port = 198; + NetOobListenOptions opt = new NetOobListenOptions(ip, port); + opt.validate(); + + Assert.assertEquals(ip, opt.getIp()); + Assert.assertEquals(port, opt.getPort()); + } + + @Test + public void getTargetWorkerCountTest() { + String ip1 = "999999.0000"; + int port1 = 198; + int targetWorkerCount = 3; + NetOobListenOptions opt1 = new NetOobListenOptions(ip1, port1, targetWorkerCount); + Assert.assertEquals(targetWorkerCount, opt1.getTargetWorkerCount()); + } + + @Test(expected = Exception.class) + public void validateTest_Exception() throws Exception { + String ip = "999999.0000"; + int port = -1; + NetOobListenOptions opt = new NetOobListenOptions(ip, port); + opt.validate(); + } + + @Test(expected = Exception.class) + public void validateTest_Exception2() throws Exception { + int port = 59; + NetOobListenOptions opt = new NetOobListenOptions("", port); + opt.validate(); + } + + @Test + public void setIpAndPortTest() { + String ip = "999999.0000"; + int port = 198; + NetOobListenOptions opt = new NetOobListenOptions(ip, port); + opt.setPort(100); + Assert.assertEquals(100, opt.getPort()); + + opt.setIp("123.123"); + Assert.assertEquals("123.123", opt.getIp()); + } + + @Test + public void setTargetWorkerCountTest() { + String ip = "999999.0000"; + int port = 198; + int targetWorkerCount = 3; + NetOobListenOptions opt = new NetOobListenOptions(ip, port, targetWorkerCount); + opt.setTargetWorkerCount(55); + Assert.assertEquals(55, opt.getTargetWorkerCount()); + } +} diff --git a/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetOobUDSListenOptionsTest.java b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetOobUDSListenOptionsTest.java new file mode 100644 index 0000000000000000000000000000000000000000..af5148aaa42c29867d61bf8e6bd899219136abb8 --- /dev/null +++ b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetOobUDSListenOptionsTest.java @@ -0,0 +1,34 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +package com.huawei.ock.hcom.test.service; + +import com.huawei.ock.hcom.service.NetOobUDSListenOptions; + +import org.junit.Test; + +/** + * class NetOobUDSListenOptions test + * + * @since 2024-02-17 + */ +public class NetOobUDSListenOptionsTest { + @Test(expected = Exception.class) + public void validateTest_Execption() throws Exception { + NetOobUDSListenOptions udsOpt1 = new NetOobUDSListenOptions(); + udsOpt1.name = "yyyyy"; + udsOpt1.validate(); + + NetOobUDSListenOptions udsOpt = new NetOobUDSListenOptions(); + udsOpt.validate(); + } +} diff --git a/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetProvideSecInfoTest.java b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetProvideSecInfoTest.java new file mode 100644 index 0000000000000000000000000000000000000000..d2c55f781611376173ddd1f585902bd9635635ae --- /dev/null +++ b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetProvideSecInfoTest.java @@ -0,0 +1,35 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +package com.huawei.ock.hcom.test.service; + +import com.huawei.ock.hcom.service.NetProvideSecInfo; +import com.huawei.ock.hcom.service.NetServiceOptions; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.modules.junit4.PowerMockRunner; + +/** + * class NetProvideSecInfo test + * + * @since 2024-08-22 + */ +@RunWith(PowerMockRunner.class) +public class NetProvideSecInfoTest { + @Test + public void sendTest() { + NetProvideSecInfo info = new NetProvideSecInfo(); + info.type = NetServiceOptions.SecInfoValidateType.SEC_VALIDATE_DISABLED; + info.validate(); + } +} diff --git a/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceConnectOptionsTest.java b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceConnectOptionsTest.java new file mode 100644 index 0000000000000000000000000000000000000000..219c82681b6683a4ce9e112bfb05e48b2b23d983 --- /dev/null +++ b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceConnectOptionsTest.java @@ -0,0 +1,30 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +package com.huawei.ock.hcom.test.service; + +import com.huawei.ock.hcom.service.NetServiceConnectOptions; + +import org.junit.Test; + +/** + * class NetServiceConnectOptions test + * + * @since 2024-02-17 + */ +public class NetServiceConnectOptionsTest { + @Test + public void validateTest() throws Exception { + NetServiceConnectOptions options = new NetServiceConnectOptions(); + options.validate(); + } +} diff --git a/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceContextTest.java b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceContextTest.java new file mode 100644 index 0000000000000000000000000000000000000000..35a28b1c370c65d20ff4d0fe8e0544059d030b5e --- /dev/null +++ b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceContextTest.java @@ -0,0 +1,121 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +package com.huawei.ock.hcom.test.service; + +import com.huawei.ock.hcom.service.NetServiceContext; +import com.huawei.ock.hcom.service.NetChannelCallback; +import com.huawei.ock.hcom.service.NetChannel; +import com.huawei.ock.hcom.service.NetServiceOpInfo; +import com.huawei.ock.hcom.service.NetServiceMessage; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.Arrays; + +/** + * class NetServiceContext test + * + * @since 2024-02-17 + */ +@RunWith(PowerMockRunner.class) +@PrepareForTest({NetServiceContext.class, NetChannel.class}) +@SuppressStaticInitializationFor("com.huawei.ock.hcom.service.NetServiceContext") +public class NetServiceContextTest { + @Test + public void GetOpInfoTest() { + byte[] data = new byte[1024]; + NetServiceContext ctx = new NetServiceContext(0L, 1L, true, (short) 100, (short) 1, (short) 0, + (short) 10, 3, data, 0L); + NetChannel channel = ctx.getChannel(); + Assert.assertNotNull(channel); + NetServiceOpInfo opt = ctx.getOpInfo(); + Assert.assertEquals(10, opt.flags); + Assert.assertEquals(100, opt.opCode); + } + + @Test + public void replySendTest_Except() throws Exception { + byte[] data = new byte[1024]; + NetServiceContext ctx = new NetServiceContext(0L, 1L, true, (short) 100, (short) 1, (short) 0, + (short) 10, 3, data, 0L); + byte[] localData = new byte[100]; + Arrays.fill(localData, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(localData); + NetServiceOpInfo opInfo1 = new NetServiceOpInfo((short) 0); + PowerMockito.suppress(PowerMockito.method(NetChannel.class, "send", NetServiceOpInfo.class, + NetServiceMessage.class, long.class)); + ctx.replySend(opInfo1, message); + } + + class NetChannelCallbackTest implements NetChannelCallback { + @Override + public void run(NetServiceContext netServiceContext) {} + } + + @Test + public void replySendTest_Except2() throws Exception { + byte[] data = new byte[1024]; + NetServiceContext ctx = new NetServiceContext(0L, 1L, true, (short) 100, (short) 1, (short) 0, + (short) 10, 3, data, 0L); + byte[] localData = new byte[100]; + Arrays.fill(localData, (byte) 'h'); + NetServiceMessage message = new NetServiceMessage(localData, true); + NetServiceOpInfo opInfo1 = new NetServiceOpInfo((short) 0); + NetChannelCallbackTest cb = new NetChannelCallbackTest(); + PowerMockito.suppress(PowerMockito.method(NetChannel.class, "send", NetServiceOpInfo.class, + NetServiceMessage.class, NetChannelCallback.class, long.class)); + ctx.replySend(opInfo1, message, cb); + + ctx.invalidate(); + ctx.replySendRaw(message, cb); + } + + @Test + public void getNum() { + byte[] data = new byte[1024]; + NetServiceContext ctx1 = new NetServiceContext(0L, 1L, true, (short) 100, (short) 1, (short) 0, + (short) 10, 3, data, 0L); + // getOpCode + Assert.assertEquals(100, ctx1.getOpCode()); + + // getResult + Assert.assertEquals(3, ctx1.getResult()); + + // getData + Assert.assertEquals(data, ctx1.getData()); + + // getDataLength + Assert.assertEquals(1024, ctx1.getDataLength()); + + // isTimeout + Assert.assertEquals(false, ctx1.isTimeout()); + + // getRspCtx + Assert.assertEquals(0L, ctx1.getRspCtx()); + } + + @Test + public void getOpCode() throws Exception { + byte[] data = new byte[1024]; + NetServiceContext ctx1 = new NetServiceContext(0L, 1L, true, (short) 100, (short) 1, (short) 0, + (short) 10, 3, data, 0L); + PowerMockito.suppress(PowerMockito.field(NetServiceContext.class, "opInfo")); + Assert.assertEquals(1024, ctx1.getOpCode()); + } +} diff --git a/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceOpInfoTest.java b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceOpInfoTest.java new file mode 100644 index 0000000000000000000000000000000000000000..684f4b62cdd59b8b64f53ad08382895e735ee395 --- /dev/null +++ b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceOpInfoTest.java @@ -0,0 +1,42 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +package com.huawei.ock.hcom.test.service; + +import com.huawei.ock.hcom.service.NetServiceOpInfo; + +import org.junit.Assert; +import org.junit.Test; + +/** + * class NetServiceOpInfo test + * + * @since 2024-08-22 + */ +public class NetServiceOpInfoTest { + @Test + public void createNetServiceOpInfoTest() { + short opCode = 100; + short errorCode = 20; + NetServiceOpInfo opInfo = new NetServiceOpInfo(opCode, errorCode); + Assert.assertEquals(opInfo.errorCode, errorCode); + Assert.assertEquals(opInfo.opCode, opCode); + + short opCode1 = 111; + short errorCode1 = 5; + short timeOut = 10; + NetServiceOpInfo opInfo1 = new NetServiceOpInfo(opCode1, errorCode1, timeOut); + Assert.assertEquals(opInfo1.errorCode, errorCode1); + Assert.assertEquals(opInfo1.opCode, opCode1); + Assert.assertEquals(opInfo1.timeout, timeOut); + } +} diff --git a/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceOptionsTest.java b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceOptionsTest.java new file mode 100644 index 0000000000000000000000000000000000000000..b0b05fe8be4f1eb1a212775b2caadc1191e189b6 --- /dev/null +++ b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceOptionsTest.java @@ -0,0 +1,74 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +package com.huawei.ock.hcom.test.service; + +import com.huawei.ock.hcom.service.NetServiceOptions; +import com.huawei.ock.hcom.service.NetOobUDSListenOptions; +import com.huawei.ock.hcom.service.NetOobListenOptions; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; + +/** + * class ServiceOptions test + * + * @since 2024-02-17 + */ +public class NetServiceOptionsTest { + String ip = "12121.12312 "; + int port = 9982; + + @Test(expected = Exception.class) + public void validateTest_Fail1() throws Exception { + NetServiceOptions options = new NetServiceOptions(); + NetOobUDSListenOptions optUds = new NetOobUDSListenOptions(); + optUds.name = "null"; + options.addUDSListenOptions(optUds); + NetOobUDSListenOptions optUds1 = new NetOobUDSListenOptions(); + options.addUDSListenOptions(optUds1); + options.validate(); + } + + @Test(expected = Exception.class) + public void validateTest_Fail2() throws Exception { + NetServiceOptions options = new NetServiceOptions(); + options.validate(); + NetOobListenOptions opt = new NetOobListenOptions(ip, 1); + options.addListenOptions(opt); + NetOobListenOptions opt1 = new NetOobListenOptions(ip, -1); + options.addListenOptions(opt1); + options.validate(); + } + + @Test + public void AddListenOptionsTest() { + NetServiceOptions options = new NetServiceOptions(); + NetOobListenOptions opt = new NetOobListenOptions(ip, port); + options.addListenOptions(opt); + ArrayList list = options.getOobListenOptions(); + Assert.assertEquals(ip, list.get(0).getIp()); + } + + @Test + public void AddUDSListenOptions() { + NetServiceOptions options = new NetServiceOptions(); + NetOobUDSListenOptions opt = new NetOobUDSListenOptions(); + opt.perm = 10; + options.addUDSListenOptions(opt); + + ArrayList list = options.getOobUDSListenOptions(); + Assert.assertEquals(10, list.get(0).perm); + } +} \ No newline at end of file diff --git a/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceTest.java b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceTest.java new file mode 100644 index 0000000000000000000000000000000000000000..8f581f1898da7398aff0063835a0eae45ed614a6 --- /dev/null +++ b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/NetServiceTest.java @@ -0,0 +1,289 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +package com.huawei.ock.hcom.test.service; + +import com.huawei.ock.hcom.common.ExternLogger; +import com.huawei.ock.hcom.common.ExternLoggerListener; +import com.huawei.ock.hcom.service.NetService; +import com.huawei.ock.hcom.service.NetChannel; +import com.huawei.ock.hcom.service.NetDriverOptions; +import com.huawei.ock.hcom.service.NetProvideSecInfo; +import com.huawei.ock.hcom.service.NetServiceOptions; +import com.huawei.ock.hcom.service.NetServiceContext; +import com.huawei.ock.hcom.service.NetServiceConnectOptions; +import com.huawei.ock.hcom.service.NetServiceListener; +import com.huawei.ock.hcom.service.NetSecValidateListener; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; +import org.powermock.modules.junit4.PowerMockRunner; + +/** + * class NetService test + * + * @since 2024-08-22 + */ +@RunWith(PowerMockRunner.class) +@PrepareForTest({NetService.class, System.class, ExternLoggerListener.class, Thread.class, NetServiceTest.class}) +@SuppressStaticInitializationFor("com.huawei.ock.hcom.service.NetService") +public class NetServiceTest { + @Before + public void setUp() { + PowerMockito.suppress(PowerMockito.method(ExternLoggerListener.class, "run")); + } + + @Test + public void InstanceTest() throws Exception { + PowerMockito.suppress(PowerMockito.method(System.class, "loadLibrary")); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeInstance")).toReturn(1L); + NetService serviceTest = NetService.Instance(NetDriverOptions.Protocol.TCP, "name", true); + Assert.assertNotNull(serviceTest); + } + + @Test(expected = Exception.class) + public void InstanceExceptionTest() throws Exception { + PowerMockito.suppress(PowerMockito.method(System.class, "loadLibrary")); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeInstance")).toReturn(-1L); + NetService serviceTest = NetService.Instance(NetDriverOptions.Protocol.TCP, "name", true); + Assert.assertNull(serviceTest); + } + + /** + * 创建service实例 + * + * @return NetService 返回创建的Service实例 + * @throws Exception if an error occurs when createService + */ + public NetService createService() throws Exception { + PowerMockito.suppress(PowerMockito.method(System.class, "loadLibrary")); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeInstance")).toReturn(1L); + return NetService.Instance(NetDriverOptions.Protocol.TCP, "testService", true); + } + + class ExternLoggerTest implements ExternLogger { + @Override + public void log(int level, String message) { + } + } + + @Test + public void addExternLoggerTest() throws Exception { + ExternLoggerTest exLoggerTest = new ExternLoggerTest(); + PowerMockito.suppress(PowerMockito.method(ExternLoggerListener.class, "nativeAddExternLogListener")); + PowerMockito.suppress(PowerMockito.method(Thread.class, "start")); + NetService.addExternLogger(exLoggerTest, 1, 100); + } + + class NetServiceListenerTest implements NetServiceListener { + @Override + public int onNewChannel(String ipPort, NetChannel channel, String payload) { + return 0; + } + + @Override + public int onChannelBroken(NetChannel ch) { + return 0; + } + + @Override + public int onNewRequest(NetServiceContext ctx) { + return 0; + } + + @Override + public int onRequestSent(NetServiceContext ctx) { + return 0; + } + + @Override + public int onOneSideDone(NetServiceContext ctx) { + return 0; + } + + @Override + public int onIdle() { + return 0; + } + } + + @Test(expected = Exception.class) + public void addListenerCppAddressTest() throws Exception { + NetService serviceTest = createService(); + NetServiceListenerTest listenerTest = new NetServiceListenerTest(); + PowerMockito.suppress(PowerMockito.field(NetService.class, "cppAddress")); + serviceTest.addListener(listenerTest); + } + + @Test(expected = Exception.class) + public void addListenerNullTest() throws Exception { + NetService serviceTest = createService(); + serviceTest.addListener(null); + } + + @Test + public void addListenerTest() throws Exception { + NetService serviceTest = createService(); + NetServiceListenerTest listenerTest = new NetServiceListenerTest(); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeAddEventListener")).toReturn(0); + serviceTest.addListener(listenerTest); + } + + @Test(expected = Exception.class) + public void addListenerExceptionTest() throws Exception { + NetService serviceTest = createService(); + NetServiceListenerTest listenerTest = new NetServiceListenerTest(); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeAddEventListener")).toReturn(1); + serviceTest.addListener(listenerTest); + } + + class NetSecValidateListenerTest implements NetSecValidateListener { + @Override + public NetProvideSecInfo onProvideSecInfo(long ctx) { + return null; + } + + @Override + public int onValidateSecInfo(long ctx, long flag, String input) { + return 0; + } + } + + @Test(expected = Exception.class) + public void addSecValidateListenerCppAddressTest() throws Exception { + NetService serviceTest = createService(); + NetSecValidateListenerTest listenerTest = new NetSecValidateListenerTest(); + PowerMockito.suppress(PowerMockito.field(NetService.class, "cppAddress")); + serviceTest.addSecValidateListener(listenerTest); + } + + @Test(expected = Exception.class) + public void addSecValidateListenerNullTest() throws Exception { + NetService serviceTest = createService(); + serviceTest.addSecValidateListener(null); + } + + @Test + public void addSecValidateListenerTest() throws Exception { + NetService serviceTest = createService(); + NetSecValidateListenerTest listenerTest = new NetSecValidateListenerTest(); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeAddSecValidateListener")).toReturn(0); + serviceTest.addSecValidateListener(listenerTest); + } + + @Test(expected = Exception.class) + public void addSecValidateListenerExceptionTest() throws Exception { + NetService serviceTest = createService(); + NetSecValidateListenerTest listenerTest = new NetSecValidateListenerTest(); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeAddSecValidateListener")).toReturn(1); + serviceTest.addSecValidateListener(listenerTest); + } + + @Test(expected = Exception.class) + public void startCppAddressTest() throws Exception { + NetService serviceTest = createService(); + NetServiceOptions options = new NetServiceOptions(); + PowerMockito.suppress(PowerMockito.field(NetService.class, "cppAddress")); + serviceTest.start(options); + } + + @Test(expected = Exception.class) + public void startNullTest() throws Exception { + NetService serviceTest = createService(); + serviceTest.start(null); + } + + @Test + public void startTest() throws Exception { + NetService serviceTest = createService(); + NetServiceOptions options = new NetServiceOptions(); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeStart")).toReturn(0); + serviceTest.start(options); + } + + @Test(expected = Exception.class) + public void startExceptionTest() throws Exception { + NetService serviceTest = createService(); + NetServiceOptions options = new NetServiceOptions(); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeStart")).toReturn(1); + serviceTest.start(options); + } + + @Test + public void stopNotStartTest() throws Exception { + NetService serviceTest = createService(); + serviceTest.stop(); + } + + @Test + public void stopTest() throws Exception { + NetService serviceTest = createService(); + NetServiceOptions options = new NetServiceOptions(); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeStart")).toReturn(0); + serviceTest.start(options); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeStop")).toReturn(0); + serviceTest.stop(); + } + + @Test(expected = Exception.class) + public void connectNameNullTest() throws Exception { + NetService serviceTest = createService(); + String oobIpOrName = null; + int oobPort = 100; + String payload = "testService"; + NetServiceConnectOptions options = new NetServiceConnectOptions(); + serviceTest.connect(oobIpOrName, oobPort, payload, options); + } + + @Test + public void connectTest() throws Exception { + NetService serviceTest = createService(); + String oobIpOrName = "null"; + int oobPort = 100; + String payload = "testService"; + NetServiceConnectOptions options = new NetServiceConnectOptions(); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeConnect")).toReturn(0L); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeGetChannelProperty")).toReturn("1234456##1"); + NetChannel channel = serviceTest.connect(oobIpOrName, oobPort, payload, options); + Assert.assertNotNull(channel); + } + + @Test(expected = Exception.class) + public void connectNativeConnectTest() throws Exception { + NetService serviceTest = createService(); + String oobIpOrName = "oobIpOrName"; + int oobPort = 100; + String payload = "testService"; + NetServiceConnectOptions options = new NetServiceConnectOptions(); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeConnect")).toReturn(-1L); + serviceTest.connect(oobIpOrName, oobPort, payload, options); + } + + @Test(expected = Exception.class) + public void connectNativeGetChannelPropertyTest() throws Exception { + NetDriverOptions opt = new NetDriverOptions(); + NetService serviceTest = createService(); + String oobIpOrName = "oobIpOrName"; + int oobPort = 100; + String payload = "testService"; + NetServiceConnectOptions options = new NetServiceConnectOptions(); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeConnect")).toReturn(0L); + PowerMockito.stub(PowerMockito.method(NetService.class, "nativeGetChannelProperty")).toReturn(null); + PowerMockito.suppress(PowerMockito.method(NetService.class, "nativeDestroyChannel")); + serviceTest.connect(oobIpOrName, oobPort, payload, options); + } +} diff --git a/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/datatype/CAInfoTest.java b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/datatype/CAInfoTest.java new file mode 100644 index 0000000000000000000000000000000000000000..04ca168e456cec52487fb5cda22a8cdd5679108f --- /dev/null +++ b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/datatype/CAInfoTest.java @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +package com.huawei.ock.hcom.test.service.datatype; + +import com.huawei.ock.hcom.service.NetDriverOptions; +import com.huawei.ock.hcom.service.datatype.CAInfo; + +import org.junit.Assert; +import org.junit.Test; + +/** + * class CAInfo test + * + * @since 2024-02-17 + */ +public class CAInfoTest { + private String caPath = "yyyyyy"; + private String crlPath = "yyyyyy/ccccc"; + private NetDriverOptions.PeerCertVerifyType verifyType = NetDriverOptions.PeerCertVerifyType.VERIFY_BY_DEFAULT; + + @Test + public void getUT() { + CAInfo caInfo = new CAInfo(caPath, crlPath, verifyType); + Assert.assertEquals(caPath, caInfo.getCaPath()); + Assert.assertEquals(crlPath, caInfo.getCrlPath()); + Assert.assertEquals(1, caInfo.getVerifyType()); + } + + @Test + public void getCaPathUT() { + CAInfo caInfo = new CAInfo(caPath); + Assert.assertEquals(caPath, caInfo.getCaPath()); + } +} diff --git a/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/datatype/PKeyTest.java b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/datatype/PKeyTest.java new file mode 100644 index 0000000000000000000000000000000000000000..badd73d1e670685c941978a988390cfa7d9924de --- /dev/null +++ b/test/llt/testcase/api/java/src/com/huawei/ock/hcom/test/service/datatype/PKeyTest.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +package com.huawei.ock.hcom.test.service.datatype; + +import com.huawei.ock.hcom.service.datatype.PKey; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; + +/** + * class PKey test + * + * @since 2024-02-17 + */ +public class PKeyTest { + private String path = "yyyyyy"; + byte[] data = new byte[10]; + + @Test + public void getUT() { + PKey pKey = new PKey(path, data); + Assert.assertEquals(path, pKey.getPath()); + Assert.assertEquals(data, pKey.getKeypass()); + Arrays.fill(data, (byte) 0); + pKey.Clean(); + Assert.assertEquals(data, pKey.getKeypass()); + } +} diff --git a/test/llt/testcase/capi/test_capi_service.cpp b/test/llt/testcase/capi/test_capi_service.cpp new file mode 100644 index 0000000000000000000000000000000000000000..979eb2b6efd4823a27a46441932475d5dd9c05f9 --- /dev/null +++ b/test/llt/testcase/capi/test_capi_service.cpp @@ -0,0 +1,876 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include +#define protected public +#include "hcom_service.h" +#include "hcom_service_c.h" +#include "net_service_default_imp.h" +#include "test_capi_service.h" + +using namespace ock::hcom; +TestCapiService::TestCapiService() {} + +char oobIp[24] = "127.0.0.1"; +uint16_t oobPort = 9981; +char capiIpSeg[24] = "127.0.0.0/16"; + +uint32_t dataSize = 1024; + +ubs_hcom_service_type serviceType = C_SERVICE_SHM; +Net_Service service = 0; +Net_Service client = 0; +Net_Channel channel = 0; +Service_ConnectOpt capiOptions; +const char *udsName = "SHM_UDS"; +typedef struct { + uintptr_t lAddress; + uint32_t lKey; + uint32_t size; +} TestRegMrInfo; + +TestRegMrInfo capiRemoteMrInfo[4]; + +char *data = NULL; +char *data1 = NULL; + +ubs_hcom_channel_callback cb; +void CommonCb(void *arg, ubs_hcom_service_context context) +{ + return; +} + +static int NewEndPoint(Net_Channel newCh, uint64_t usrCtx, const char *payLoad) +{ + NN_LOG_INFO("Channel new, payLoad: " << payLoad); + channel = newCh; + return 0; +} + +static int EndPointBroken(Net_Channel ch, uint64_t usrCtx, const char *payLoad) +{ + NN_LOG_ERROR("Channel broken, payLoad:" << payLoad); + return 0; +} + +static int RequestReceived(ubs_hcom_service_context ctx, uint64_t usrCtx) +{ + NN_LOG_INFO("Get context type start"); + Service_Message message; + message.data = ubs_hcom_context_get_data(ctx); + message.size = ubs_hcom_context_get_datalen(ctx); + if (message.data == NULL) { + NN_LOG_ERROR("failed to get message"); + return -1; + } + + ubs_hcom_service_context_type type; + if (ubs_hcom_context_get_type(ctx, &type) != 0) { + NN_LOG_ERROR("Get context type failed"); + return -1; + } + + Net_Channel tmpChannel; + if (ubs_hcom_context_get_channel(ctx, &tmpChannel) != 0) { + NN_LOG_ERROR("Get channel failed"); + return -1; + } + + cb.cb = CommonCb; + cb.arg = NULL; + if (type == SERVICE_RECEIVED_RAW) { + char *receive = (char *)message.data; + if (receive[0] == 0) { + // receive send message + return 0; + } + + // receive call message, need send response + Service_Message req = message; + + Service_OpInfo sendOpInfo = { 0, 0, 0, 0 }; + Service_RspCtx rsp = 0; + if (ubs_hcom_context_get_rspctx(ctx, &rsp) != 0) { + NN_LOG_ERROR("Get response ctx failed"); + return -1; + } + + // in hcom receive thread, need async send message; in user thread, there is no limit + if (Channel_PostResponse(tmpChannel, rsp, &sendOpInfo, &req, &cb) != 0) { + NN_LOG_ERROR("failed to post message to data to server"); + return -1; + } + return 0; + } + + // SERVICE_RECEIVED type + Service_OpInfo opInfo; + if (Service_GetOpInfo(ctx, &opInfo) != 0) { + NN_LOG_ERROR("Get op info failed"); + return -1; + } + if (opInfo.opCode == 0) { + printf("receive msg, op code 0"); + } else if (opInfo.opCode == 2) { + Service_Message req = { capiRemoteMrInfo, sizeof(capiRemoteMrInfo) }; + Service_OpInfo sendOpInfo = opInfo; + + // post send callback + Service_RspCtx rsp = 0; + if (ubs_hcom_context_get_rspctx(ctx, &rsp) != 0) { + NN_LOG_ERROR("Get response ctx failed"); + return -1; + } + // in hcom receive thread, need async send message; in user thread, there is no limit + if (Channel_PostResponse(tmpChannel, rsp, &sendOpInfo, &req, &cb) != 0) { + NN_LOG_ERROR("failed to post message to data to server"); + return -1; + } + } else { + Service_Message req = message; + // send the same message back to verify + + Service_OpInfo sendOpInfo = opInfo; + + // post send callback + Service_RspCtx rsp = 0; + if (ubs_hcom_context_get_rspctx(ctx, &rsp) != 0) { + NN_LOG_ERROR("Get response ctx failed"); + return -1; + } + + // in hcom receive thread, need async send message; in user thread, there is no limit + if (Channel_PostResponse(tmpChannel, rsp, &sendOpInfo, &req, &cb) != 0) { + NN_LOG_ERROR("failed to post message to data to server"); + return -1; + } + } + return 0; +} + +static int RequestPosted(ubs_hcom_service_context ctx, uint64_t usrCtx) +{ + NN_LOG_INFO("posted"); + return 0; +} + +static int OneSideDone(ubs_hcom_service_context ctx, uint64_t usrCtx) +{ + NN_LOG_INFO("one side done"); + return 0; +} + +int RegCapiSglMem(Net_Service driver, TestRegMrInfo capiMrInfo[], std::vector &mrs) +{ + for (uint16_t i = 0; i < 4; i++) { + ubs_hcom_memory_region mrArray; + int result = ubs_hcom_service_register_memory_region(driver, dataSize, &mrArray); + if (result != 0) { + NN_LOG_ERROR("reg mr failed"); + return -1; + } + + ubs_hcom_mr_info mrInfo; + result = ubs_hcom_service_get_memory_region_info(mrArray, &mrInfo); + if (result != 0) { + NN_LOG_ERROR("parse mr failed"); + return -1; + } + capiMrInfo[i].lAddress = mrInfo.lAddress; + capiMrInfo[i].lKey = mrInfo.lKey; + capiMrInfo[i].size = mrInfo.size; + mrs.push_back(mrArray); + memset(reinterpret_cast(capiMrInfo[i].lAddress), ' ', capiMrInfo[i].size); + } + + return 0; +} + +void DestoryCapiSglMem(Net_Service driver, std::vector &mrs) +{ + while (!mrs.empty()) { + ubs_hcom_service_destroy_memory_region(driver, mrs.back()); + mrs.pop_back(); + } +} + +typedef struct { + Service_Message rsp; + sem_t sem; + int ret; +} CallAsyncStruct; +void CallAsyncCb(void *arg, ubs_hcom_service_context context) +{ + CallAsyncStruct *param = (CallAsyncStruct *)arg; + + if (ubs_hcom_context_get_result(context, ¶m->ret) != 0) { + NN_LOG_ERROR("Call async failed " << param->ret); + return; + } + + Service_Message message; + message.data = ubs_hcom_context_get_data(context); + message.size = ubs_hcom_context_get_datalen(context); + if (message.data == NULL) { + sem_post(¶m->sem); + param->ret = -1; + NN_LOG_ERROR("failed to get message"); + return; + } + + if (message.size != param->rsp.size) { + sem_post(¶m->sem); + param->ret = -1; + NN_LOG_ERROR("Receive unwanted message"); + return; + } + + memcpy(param->rsp.data, message.data, message.size); + sem_post(¶m->sem); +} + +int CreateCapiService(ubs_hcom_service_request_handler receiveCb) +{ + int result = 0; + + if (service != 0) { + NN_LOG_ERROR("service already created"); + return -1; + } + + result = ubs_hcom_service_create(serviceType, "server_capi", 1, &service); + if (result != 0) { + NN_LOG_ERROR("failed to create service already created"); + return -1; + } + + ubs_hcom_service_options capiOptions; + bzero(&capiOptions, sizeof(ubs_hcom_service_options)); + capiOptions.mode = C_SERVICE_BUSY_POLLING; + capiOptions.mrSendReceiveSegSize = 1024 + dataSize; + capiOptions.mrSendReceiveSegCount = 8192; + capiOptions.enableTls = false; + strcpy(capiOptions.netDeviceIpMask, capiIpSeg); + sprintf(capiOptions.workerGroups, "%u", 1); + if (serviceType == C_SERVICE_SHM) { + capiOptions.oobType = C_SERVICE_OOB_UDS; + capiOptions.mode = C_SERVICE_EVENT_POLLING; + Service_OobUDSListenerOptions listenOpt; + strcpy(listenOpt.name, udsName); + listenOpt.perm = 0; + Service_AddOobUdsOptions(service, listenOpt); + } + + Service_RegisterChannelHandler(service, C_CHANNEL_NEW, &NewEndPoint, C_CHANNEL_RECONNECT, 1); + Service_RegisterChannelHandler(service, C_CHANNEL_BROKEN, &EndPointBroken, C_CHANNEL_RECONNECT, 1); + Service_RegisterOpHandler(service, 0, C_SERVICE_REQUEST_RECEIVED, receiveCb, 1); + Service_RegisterOpHandler(service, 0, C_SERVICE_REQUEST_POSTED, &RequestPosted, 1); + Service_RegisterOpHandler(service, 0, C_SERVICE_READWRITE_DONE, &OneSideDone, 1); + + Service_SetOobIpAndPort(service, oobIp, oobPort); + + if ((result = ubs_hcom_service_start(service, capiOptions)) != 0) { + NN_LOG_ERROR("failed to start service " << result); + return -1; + } + NN_LOG_INFO("service started"); + + return 0; +} + +int CreateCapiClient() +{ + int result = 0; + + if (client != 0) { + NN_LOG_ERROR("service already created"); + return -1; + } + + result = ubs_hcom_service_create(serviceType, "client_capi", 0, &client); + if (result != 0) { + NN_LOG_ERROR("failed to create service already created"); + return -1; + } + + ubs_hcom_service_options capiOptions; + bzero(&capiOptions, sizeof(ubs_hcom_service_options)); + capiOptions.mode = C_SERVICE_BUSY_POLLING; + capiOptions.mrSendReceiveSegSize = 1024 + dataSize; + capiOptions.mrSendReceiveSegCount = 8192; + capiOptions.enableTls = false; + strcpy(capiOptions.netDeviceIpMask, capiIpSeg); + sprintf(capiOptions.workerGroups, "%u", 1); + if (serviceType == C_SERVICE_SHM) { + capiOptions.oobType = C_SERVICE_OOB_UDS; + capiOptions.mode = C_SERVICE_EVENT_POLLING; + } + + Service_RegisterChannelHandler(client, C_CHANNEL_BROKEN, &EndPointBroken, C_CHANNEL_RECONNECT, 1); + Service_RegisterOpHandler(client, 0, C_SERVICE_REQUEST_RECEIVED, &RequestReceived, 1); + Service_RegisterOpHandler(client, 0, C_SERVICE_REQUEST_POSTED, &RequestPosted, 1); + Service_RegisterOpHandler(client, 0, C_SERVICE_READWRITE_DONE, &OneSideDone, 1); + + Service_SetOobIpAndPort(client, oobIp, oobPort); + + if ((result = ubs_hcom_service_start(client, capiOptions)) != 0) { + NN_LOG_ERROR("failed to start service " << result); + return -1; + } + NN_LOG_INFO("service started"); + + return 0; +} + + +void TestCapiService::SetUp() +{ + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); + setenv("HCOM_CONNECTION_RECV_TIMEOUT_SEC", "1", 1); + setenv("HCOM_CONNECTION_SEND_TIMEOUT_SEC", "1", 1); + capiOptions = { 0, 0, C_CHANNEL_FUNC_CB, 0, 0, 0 }; + capiOptions.epSize = 1; +} + +void TestCapiService::TearDown() +{ + service = 0; + client = 0; + capiOptions = { 0, 0, C_CHANNEL_FUNC_CB, 0, 0, 0 }; + GlobalMockObject::verify(); +} + +TEST_F(TestCapiService, ServiceGetMemoryRegionInfo) +{ + CreateCapiService(&RequestReceived); + ubs_hcom_memory_region mr; + int result = ubs_hcom_service_register_memory_region(0, dataSize, &mr); + EXPECT_EQ(501, result); + + result = ubs_hcom_service_register_memory_region(service, dataSize, nullptr); + EXPECT_EQ(501, result); + + result = ubs_hcom_service_register_memory_region(service, 0, &mr); + EXPECT_EQ(103, result); + + result = ubs_hcom_service_register_memory_region(service, dataSize, &mr); + EXPECT_EQ(0, result); + + ubs_hcom_mr_info mrInfo; + result = ubs_hcom_service_get_memory_region_info(0, &mrInfo); + EXPECT_EQ(501, result); + + result = ubs_hcom_service_get_memory_region_info(mr, &mrInfo); + EXPECT_EQ(0, result); + + ubs_hcom_service_destroy_memory_region(service, mr); + Service_Stop(service); + ubs_hcom_service_destroy(service); +} + +TEST_F(TestCapiService, ChannelPostSend) +{ + CreateCapiService(&RequestReceived); + CreateCapiClient(); + int result = 0; + Net_Channel clientChannel = 0; + + result = ubs_hcom_service_connect(0, udsName, oobPort, "hello service c", &clientChannel, &capiOptions); + EXPECT_EQ(501, result); + + result = ubs_hcom_service_connect(client, udsName, oobPort, nullptr, &clientChannel, &capiOptions); + EXPECT_EQ(501, result); + + result = ubs_hcom_service_connect(client, udsName, oobPort, "hello service c", nullptr, &capiOptions); + EXPECT_EQ(501, result); + + capiOptions.epSize = 19; + result = ubs_hcom_service_connect(client, udsName, oobPort, "hello service c", &clientChannel, &capiOptions); + EXPECT_EQ(501, result); + + capiOptions.epSize = 1; + result = ubs_hcom_service_connect(client, udsName, oobPort, "hello service c", &clientChannel, &capiOptions); + EXPECT_EQ(0, result); + + data = (char *)malloc(dataSize); + Service_Message message = { data, dataSize }; + + Service_OpInfo opInfo = { 0, 0, 0, 0 }; + opInfo.opCode = 0; + opInfo.timeout = 1; + + /* postSend */ + result = Channel_PostSend(clientChannel, &opInfo, &message, NULL); + EXPECT_EQ(0, result); + + /* creat callback */ + char data2[dataSize]; + CallAsyncStruct asyncParam; + asyncParam.rsp.data = data2; + asyncParam.rsp.size = dataSize; + sem_init(&asyncParam.sem, 0, 0); + asyncParam.ret = 0; + + ubs_hcom_channel_callback cb; + cb.arg = &asyncParam; + cb.cb = CallAsyncCb; + result = Channel_PostSend(clientChannel, &opInfo, &message, &cb); + EXPECT_EQ(0, result); + sem_wait(&asyncParam.sem); + sem_destroy(&asyncParam.sem); + + result = Channel_PostSend(0, &opInfo, &message, NULL); + EXPECT_EQ(501, result); + + result = Channel_PostSend(clientChannel, nullptr, &message, NULL); + EXPECT_EQ(501, result); + + result = Channel_PostSend(clientChannel, &opInfo, nullptr, NULL); + EXPECT_EQ(501, result); + + MOCKER_CPP(&NetChannel::SendInner).stubs().will(returnValue(501)); + result = Channel_PostSend(clientChannel, &opInfo, &message, &cb); + EXPECT_EQ(501, result); + + /* postsendRaw */ + result = Channel_PostSendRaw(clientChannel, &message, NULL); + EXPECT_EQ(0, result); + + result = Channel_PostSendRaw(0, &message, NULL); + EXPECT_EQ(501, result); + + result = Channel_PostSendRaw(clientChannel, nullptr, NULL); + EXPECT_EQ(501, result); + + MOCKER_CPP(&NetChannel::SendRawInner).defaults().will(returnValue(501)); + result = Channel_PostSendRaw(clientChannel, &message, NULL); + EXPECT_EQ(501, result); + + GlobalMockObject::verify(); + + /* postSendRawSgl */ + TestRegMrInfo capiLocalMrInfo[4]; + std::vector mrClient; + result = RegCapiSglMem(client, capiLocalMrInfo, mrClient); + EXPECT_EQ(0, result); + std::vector mrService; + result = RegCapiSglMem(service, capiRemoteMrInfo, mrService); + EXPECT_EQ(0, result); + + Service_Message req; + req.data = data; + req.size = dataSize; + + Service_Message rsp; + TestRegMrInfo getRemoteMrInfo[4]; + rsp.data = getRemoteMrInfo; + rsp.size = sizeof(getRemoteMrInfo); + + Service_OpInfo reqInfo = { 0, 0, 0, 0 }; + reqInfo.opCode = 2; + + Service_OpInfo rspInfo = { 0, 0, 0, 0 }; + + result = Channel_SyncCall(clientChannel, &reqInfo, &req, &rspInfo, &rsp); + EXPECT_EQ(0, result); + + Service_Request reqRawSgl; + reqRawSgl.lAddress = capiLocalMrInfo[0].lAddress; + reqRawSgl.rAddress = getRemoteMrInfo[0].lAddress; + reqRawSgl.lKey = capiLocalMrInfo[0].lKey; + reqRawSgl.rKey = getRemoteMrInfo[0].lKey; + reqRawSgl.size = capiLocalMrInfo[0].size; + + Service_SglRequest request; + request.iov = &reqRawSgl; + request.iovCount = 1; + memset((void *)(reqRawSgl.lAddress), 0, 1); + + result = Channel_PostSendRawSgl(clientChannel, &request, NULL); + EXPECT_EQ(0, result); + + result = Channel_PostSendRawSgl(0, &request, NULL); + EXPECT_EQ(501, result); + + result = Channel_PostSendRawSgl(clientChannel, nullptr, NULL); + EXPECT_EQ(501, result); + + MOCKER_CPP(&NetChannel::SendRawSglInner).defaults().will(returnValue(501)); + result = Channel_PostSendRawSgl(clientChannel, &request, NULL); + EXPECT_EQ(501, result); + DestoryCapiSglMem(client, mrClient); + DestoryCapiSglMem(service, mrService); + + GlobalMockObject::verify(); + + /* SyncCall */ + req = { data, dataSize }; + rsp = { data, dataSize }; + reqInfo = { 0, 0, 0, 0 }; + rspInfo = { 0, 0, 0, 0 }; + reqInfo.opCode = 1; + reqInfo.timeout = 0; + reqInfo.errorCode = 0xff; + + result = Channel_SyncCall(clientChannel, &reqInfo, &req, &rspInfo, &rsp); + EXPECT_EQ(0, result); + + result = Channel_SyncCall(0, &reqInfo, &req, &rspInfo, &rsp); + EXPECT_EQ(501, result); + + result = Channel_SyncCall(clientChannel, nullptr, &req, &rspInfo, &rsp); + EXPECT_EQ(501, result); + + result = Channel_SyncCall(clientChannel, &reqInfo, nullptr, &rspInfo, &rsp); + EXPECT_EQ(501, result); + + result = Channel_SyncCall(clientChannel, &reqInfo, &req, &rspInfo, nullptr); + EXPECT_EQ(501, result); + + MOCKER_CPP(&NetChannel::SyncCallInner).defaults().will(returnValue(501)); + result = Channel_SyncCall(clientChannel, &reqInfo, &req, &rspInfo, &rsp); + EXPECT_EQ(501, result); + + GlobalMockObject::verify(); + + /* AsyncCall */ + req = { data, dataSize }; + opInfo = { 0, 0, 0, 0 }; + opInfo.opCode = 1; + + // CallAsyncStruct asyncParam + asyncParam.rsp.data = data2; + asyncParam.rsp.size = dataSize; + sem_init(&asyncParam.sem, 0, 0); + asyncParam.ret = 0; + + // ubs_hcom_channel_callback cb + cb.arg = &asyncParam; + cb.cb = CallAsyncCb; + + result = Channel_AsyncCall(clientChannel, &opInfo, &req, &cb); + EXPECT_EQ(0, result); + + result = Channel_AsyncCall(0, &opInfo, &req, &cb); + EXPECT_EQ(501, result); + + result = Channel_AsyncCall(clientChannel, nullptr, &req, &cb); + EXPECT_EQ(501, result); + + result = Channel_AsyncCall(clientChannel, &opInfo, nullptr, &cb); + EXPECT_EQ(501, result); + + result = Channel_AsyncCall(clientChannel, &opInfo, &req, nullptr); + EXPECT_EQ(501, result); + + sem_wait(&asyncParam.sem); + sem_destroy(&asyncParam.sem); + + MOCKER_CPP(&NetChannel::AsyncCallInner).defaults().will(returnValue(501)); + result = Channel_AsyncCall(channel, &opInfo, &req, &cb); + EXPECT_EQ(501, result); + + GlobalMockObject::verify(); + + /* Channel_SyncCallRaw */ + + req = { data, dataSize }; + rsp = { data, dataSize }; + data[0] = 1; + + result = Channel_SyncCallRaw(clientChannel, &req, &rsp); + EXPECT_EQ(0, result); + + result = Channel_SyncCallRaw(0, &req, &rsp); + EXPECT_EQ(501, result); + + result = Channel_SyncCallRaw(clientChannel, nullptr, &rsp); + EXPECT_EQ(501, result); + + result = Channel_SyncCallRaw(clientChannel, &req, nullptr); + EXPECT_EQ(501, result); + + MOCKER_CPP(&NetChannel::SyncCallRawInner).defaults().will(returnValue(501)); + result = Channel_SyncCallRaw(clientChannel, &req, &rsp); + EXPECT_EQ(501, result); + + GlobalMockObject::verify(); + + /* Channel_SyncCallRawSgl */ + result = RegCapiSglMem(client, capiLocalMrInfo, mrClient); + EXPECT_EQ(0, result); + result = RegCapiSglMem(service, capiRemoteMrInfo, mrService); + EXPECT_EQ(0, result); + + data = (char *)malloc(dataSize); + // Service_Message req + req.data = data; + req.size = dataSize; + + + rsp.data = getRemoteMrInfo; + rsp.size = sizeof(getRemoteMrInfo); + + reqInfo = { 0, 0, 0, 0 }; + reqInfo.opCode = 2; + rspInfo = { 0, 0, 0, 0 }; + + result = Channel_SyncCall(clientChannel, &reqInfo, &req, &rspInfo, &rsp); + EXPECT_EQ(0, result); + + Service_Request reqIov; + reqIov.lAddress = capiLocalMrInfo[0].lAddress; + reqIov.rAddress = getRemoteMrInfo[0].lAddress; + reqIov.lKey = capiLocalMrInfo[0].lKey; + reqIov.rKey = getRemoteMrInfo[0].lKey; + reqIov.size = capiLocalMrInfo[0].size; + + Service_SglRequest reqRawSgl2; + char *buff = (char *)(reqIov.lAddress); + buff[0] = 1; // mark call + reqRawSgl2.iov = &reqIov; + reqRawSgl2.iovCount = 1; + + Service_Message rspRawSgl = { data, dataSize }; + data[0] = 1; // mark call + result = Channel_SyncCallRawSgl(clientChannel, &reqRawSgl2, &rspRawSgl); + EXPECT_EQ(0, result); + + result = Channel_SyncCallRawSgl(0, &reqRawSgl2, &rspRawSgl); + EXPECT_EQ(501, result); + + result = Channel_SyncCallRawSgl(clientChannel, nullptr, &rspRawSgl); + EXPECT_EQ(501, result); + + result = Channel_SyncCallRawSgl(clientChannel, &reqRawSgl2, nullptr); + EXPECT_EQ(501, result); + + MOCKER_CPP(&NetChannel::SyncCallRawSglInner).defaults().will(returnValue(501)); + result = Channel_SyncCallRawSgl(clientChannel, &reqRawSgl2, &rspRawSgl); + EXPECT_EQ(501, result); + + DestoryCapiSglMem(client, mrClient); + DestoryCapiSglMem(service, mrService); + + GlobalMockObject::verify(); + + /* Channel_WriteRead */ + + result = RegCapiSglMem(client, capiLocalMrInfo, mrClient); + EXPECT_EQ(0, result); + result = RegCapiSglMem(service, capiRemoteMrInfo, mrService); + EXPECT_EQ(0, result); + + req.data = data; + req.size = dataSize; + + rsp.data = getRemoteMrInfo; + rsp.size = sizeof(getRemoteMrInfo); + + reqInfo = { 0, 0, 0, 0 }; + reqInfo.opCode = 2; + rspInfo = { 0, 0, 0, 0 }; + + result = Channel_SyncCall(clientChannel, &reqInfo, &req, &rspInfo, &rsp); + EXPECT_EQ(0, result); + + result = Channel_SyncCall(clientChannel, &reqInfo, &req, &rspInfo, nullptr); + EXPECT_EQ(501, result); + + result = Channel_SyncCall(clientChannel, nullptr, &req, &rspInfo, &rsp); + EXPECT_EQ(501, result); + + result = Channel_SyncCall(0, &reqInfo, &req, &rspInfo, &rsp); + EXPECT_EQ(501, result); + + result = Channel_SyncCall(clientChannel, &reqInfo, nullptr, &rspInfo, &rsp); + EXPECT_EQ(501, result); + + Service_Request reqWrite; + reqWrite.lAddress = capiLocalMrInfo[0].lAddress; + reqWrite.rAddress = getRemoteMrInfo[0].lAddress; + reqWrite.lKey = capiLocalMrInfo[0].lKey; + reqWrite.rKey = getRemoteMrInfo[0].lKey; + reqWrite.size = capiLocalMrInfo[0].size; + + /* write */ + result = Channel_Write(clientChannel, &reqWrite, NULL); + EXPECT_EQ(0, result); + + result = Channel_Write(0, &reqWrite, NULL); + EXPECT_EQ(501, result); + + result = Channel_Write(clientChannel, nullptr, NULL); + EXPECT_EQ(501, result); + + MOCKER_CPP(&NetChannel::WriteInner, int (NetChannel::*)(const NetServiceRequest &, const NetCallback *)) + .defaults() + .will(returnValue(501)); + result = Channel_Write(clientChannel, &reqWrite, NULL); + EXPECT_EQ(501, result); + + GlobalMockObject::verify(); + + Service_SglRequest reqSgl; + reqSgl.iov = &reqWrite; + reqSgl.iovCount = 1; + + /* write sgl */ + result = Channel_WriteSgl(clientChannel, &reqSgl, NULL); + EXPECT_EQ(0, result); + + result = Channel_WriteSgl(0, &reqSgl, NULL); + EXPECT_EQ(501, result); + + result = Channel_WriteSgl(clientChannel, nullptr, NULL); + EXPECT_EQ(501, result); + + MOCKER_CPP(&NetChannel::WriteSglInner, int (NetChannel::*)(const NetServiceSglRequest &, const NetCallback *)) + .defaults() + .will(returnValue(501)); + result = Channel_WriteSgl(clientChannel, &reqSgl, NULL); + EXPECT_EQ(501, result); + + GlobalMockObject::verify(); + + /* read */ + result = Channel_Read(clientChannel, &reqWrite, NULL); + EXPECT_EQ(0, result); + + result = Channel_Read(0, &reqWrite, NULL); + EXPECT_EQ(501, result); + + result = Channel_Read(clientChannel, nullptr, NULL); + EXPECT_EQ(501, result); + + MOCKER_CPP(&NetChannel::ReadInner, int (NetChannel::*)(const NetServiceRequest &, const NetCallback *)) + .defaults() + .will(returnValue(501)); + result = Channel_Read(clientChannel, &reqWrite, NULL); + EXPECT_EQ(501, result); + + GlobalMockObject::verify(); + + /* read sgl */ + result = Channel_ReadSgl(clientChannel, &reqSgl, NULL); + EXPECT_EQ(0, result); + + result = Channel_ReadSgl(0, &reqSgl, NULL); + EXPECT_EQ(501, result); + + result = Channel_ReadSgl(clientChannel, nullptr, NULL); + EXPECT_EQ(501, result); + + MOCKER_CPP(&NetChannel::ReadSglInner, int (NetChannel::*)(const NetServiceSglRequest &, const NetCallback *)) + .defaults() + .will(returnValue(501)); + result = Channel_ReadSgl(clientChannel, &reqSgl, NULL); + EXPECT_EQ(501, result); + + DestoryCapiSglMem(client, mrClient); + DestoryCapiSglMem(service, mrService); + GlobalMockObject::verify(); + + free(data); + data = nullptr; + Channel_Destroy(clientChannel); + Channel_Destroy(channel); + Service_Stop(service); + ubs_hcom_service_destroy(service); + Service_Stop(client); + ubs_hcom_service_destroy(client); +} + + +static int RequestReceivedContext(ubs_hcom_service_context ctx, uint64_t usrCtx) +{ + NN_LOG_INFO("Get context type start"); + Service_Message message; + message.data = ubs_hcom_context_get_data(ctx); + message.size = ubs_hcom_context_get_datalen(ctx); + if (message.data == NULL) { + NN_LOG_ERROR("failed to get message"); + return -1; + } + + ubs_hcom_service_context_type type; + if (ubs_hcom_context_get_type(ctx, &type) != 0) { + NN_LOG_ERROR("Get context type failed"); + return -1; + } + + cb.cb = CommonCb; + cb.arg = NULL; + if (type == SERVICE_RECEIVED_RAW) { + // receive call message, need send response + Service_Message req = message; + Service_OpInfo sendOpInfo = { 0, 0, 0, 0 }; + + ubs_hcom_service_context cloneCtx = Service_ContextClone(ctx); + EXPECT_NE(0ul, cloneCtx); + + // in hcom receive thread, need async send message; in user thread, there is no limit + auto ret = Service_ContextReplyRaw(cloneCtx, &req, &cb); + Service_ContextDeClone(cloneCtx); + EXPECT_EQ(ret, 0); + return 0; + } + + // SERVICE_RECEIVED type + Service_OpInfo opInfo; + Service_Message req = message; + // send the same message back to verify + Service_OpInfo sendOpInfo = opInfo; + ubs_hcom_service_context cloneCtx = Service_ContextClone(ctx); + EXPECT_NE(0ul, cloneCtx); + + // in hcom receive thread, need async send message; in user thread, there is no limit + auto ret = Service_ContextReply(cloneCtx, &sendOpInfo, &req, &cb); + Service_ContextDeClone(cloneCtx); + EXPECT_EQ(ret, 0); + + return 0; +} + +TEST_F(TestCapiService, ServiceContextTest) +{ + CreateCapiService(&RequestReceivedContext); + CreateCapiClient(); + + Net_Channel clientChannel; + capiOptions.epSize = 1; + auto result = ubs_hcom_service_connect(client, udsName, oobPort, "hello service c context", &clientChannel, &capiOptions); + EXPECT_EQ(0, result); + + char testData[512]; + Service_Message req = { testData, sizeof(testData) }; + Service_Message rsp = { testData, sizeof(testData) }; + Service_OpInfo reqInfo = { 0, 0, 0, 0 }; + Service_OpInfo rspInfo = { 0, 0, 0, 0 }; + reqInfo.opCode = 1; + reqInfo.timeout = 0; + reqInfo.errorCode = 0xff; + + result = Channel_SyncCall(clientChannel, &reqInfo, &req, &rspInfo, &rsp); + EXPECT_EQ(0, result); + + result = Channel_SyncCallRaw(clientChannel, &req, &rsp); + EXPECT_EQ(0, result); + + Channel_Destroy(clientChannel); + Channel_Destroy(channel); + Service_Stop(service); + ubs_hcom_service_destroy(service); + Service_Stop(client); + ubs_hcom_service_destroy(client); +} \ No newline at end of file diff --git a/test/llt/testcase/capi/test_capi_service.h b/test/llt/testcase/capi/test_capi_service.h new file mode 100644 index 0000000000000000000000000000000000000000..26ad51e2cb9c870b30f0af215496b7e6b6710f82 --- /dev/null +++ b/test/llt/testcase/capi/test_capi_service.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_TEST_CAPI_SERVICE_H +#define HCOM_TEST_CAPI_SERVICE_H +#include +#include + +class TestCapiService : public testing::Test { +public: + TestCapiService(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TEST_CAPI_SERVICE_H diff --git a/test/llt/testcase/capi/test_capi_service_rndv.cpp b/test/llt/testcase/capi/test_capi_service_rndv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c5bc8bc2dfb65a95443dbe29e81e5dd496a10b12 --- /dev/null +++ b/test/llt/testcase/capi/test_capi_service_rndv.cpp @@ -0,0 +1,556 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "test_capi_service_rndv.h" + +#include "hcom.h" +#include "hcom_service_c.h" +#include "rdma_common.h" +#include "hcom_env.h" + +using namespace ock::hcom; +TestCapiServiceRndv::TestCapiServiceRndv() {} +#define BASE_IP "127.0.0.1" +#define IP_SEG "127.0.0.0/16" +Net_Channel serviceChannel = 0; +Net_Service serviceRndv = 0; +Net_Service clientRndv = 0; +ubs_hcom_service_type serviceRndvType = C_SERVICE_RDMA; +int32_t rndvSize = 1024; +int32_t rndvPort = 0; +int rdnvBasePort = 4900; +Service_ConnectOpt rndvOptions; +ubs_hcom_channel_callback cbRndv; + + +static int NewEndPoint(Net_Channel newCh, uint64_t usrCtx, const char *payLoad) +{ + NN_LOG_INFO("Channel new, payLoad: " << payLoad); + serviceChannel = newCh; + return 0; +} + +static int EndPointBroken(Net_Channel ch, uint64_t usrCtx, const char *payLoad) +{ + NN_LOG_ERROR("Channel broken, payLoad:" << payLoad); + return 0; +} + +static int RequestReceived(ubs_hcom_service_context ctx, uint64_t usrCtx) +{ + NN_LOG_INFO("Request received"); + return 0; +} + +static int RequestPosted(ubs_hcom_service_context ctx, uint64_t usrCtx) +{ + return 0; +} + +static int OneSideDone(ubs_hcom_service_context ctx, uint64_t usrCtx) +{ + NN_LOG_INFO("one side done"); + return 0; +} + +Net_MemoryAllocator memAllocator = 0; +void *addressRndv = NULL; +Net_MemoryAllocator memClientAllocator = 0; +void *addressClientRndv = NULL; +uint64_t memSizeRndv = 1024 * 1024 * 128; + +int PrepareMemAllocator() +{ + addressRndv = memalign(4096, memSizeRndv); + if (addressRndv == NULL) { + NN_LOG_ERROR("Failed to alloc memory, maybe lack of spare memory in system."); + return -1; + } + + Net_MemoryAllocatorOptions options1; + options1.address = (uintptr_t)addressRndv; + options1.size = memSizeRndv; + options1.minBlockSize = 4096; + options1.alignedAddress = 1; + options1.cacheTierCount = 10; + options1.cacheBlockCountPerTier = 8; + options1.bucketCount = 8; + options1.cacheTierPolicy = C_TIER_POWER; + + int result = Net_MemoryAllocatorCreate(C_DYNAMIC_SIZE_WITH_CACHE, &options1, &memAllocator); + if (result != 0) { + NN_LOG_ERROR("Failed to create memory allocator"); + return -1; + } + + return 0; +} + +int PrepareClientMemAllocator() +{ + addressClientRndv = memalign(4096, memSizeRndv); + if (addressClientRndv == NULL) { + NN_LOG_ERROR("Failed to alloc memory, maybe lack of spare memory in system."); + return -1; + } + + Net_MemoryAllocatorOptions options1; + options1.address = (uintptr_t)addressClientRndv; + options1.size = memSizeRndv; + options1.minBlockSize = 4096; + options1.alignedAddress = 1; + options1.cacheTierCount = 10; + options1.cacheBlockCountPerTier = 8; + options1.bucketCount = 8; + options1.cacheTierPolicy = C_TIER_POWER; + + int result = Net_MemoryAllocatorCreate(C_DYNAMIC_SIZE_WITH_CACHE, &options1, &memClientAllocator); + if (result != 0) { + NN_LOG_ERROR("Failed to create memory allocator"); + return -1; + } + + return 0; +} + + +int MemAllocate(uint64_t rndvSize, uintptr_t *add, uint32_t *key) +{ + return Net_MemoryAllocatorAllocate(memAllocator, rndvSize, add, key); +} + +int MemFree(uintptr_t add) +{ + return Net_MemoryAllocatorFree(memAllocator, add); +} + +int MemClientAllocate(uint64_t rndvSize, uintptr_t *add, uint32_t *key) +{ + return Net_MemoryAllocatorAllocate(memClientAllocator, rndvSize, add, key); +} + +int MemClientFree(uintptr_t add) +{ + return Net_MemoryAllocatorFree(memClientAllocator, add); +} + +void CommonCbRndv(void *arg, ubs_hcom_service_context context) +{ + return; +} + +int RndvHandlerC(Service_RndvContext context) +{ + // step1 direct handle message or change to other thread + + // step2 rsp message + int ret = 0; + Service_Message req = { &ret, sizeof(ret) }; + Service_OpInfo reqInfo = { 0 }; + cbRndv.cb = CommonCbRndv; + ret = Service_RndvReply(context, &reqInfo, &req, &cbRndv); + if (ret != 0) { + NN_LOG_ERROR("Reply message failed " << ret); + } + + // step3 free context + Service_RndvFreeContext(context); + return ret; +} + +int CreateRndvService() +{ + int result = 0; + + if (serviceRndv != 0) { + NN_LOG_ERROR("service already created."); + return -1; + } + + result = ubs_hcom_service_create(serviceRndvType, "server_rndv", 1, &serviceRndv); + if (result != 0) { + NN_LOG_ERROR("failed to create service already created."); + return -1; + } + + result = PrepareMemAllocator(); + if (result != 0) { + NN_LOG_ERROR("failed to prepare mem."); + return -1; + } + + ubs_hcom_service_options options; + bzero(&options, sizeof(ubs_hcom_service_options)); + options.enableRndv = 1; + options.mode = C_SERVICE_BUSY_POLLING; + options.mrSendReceiveSegSize = 1024 + rndvSize; + options.mrSendReceiveSegCount = 8192; + options.enableTls = false; + strcpy(options.netDeviceIpMask, IP_SEG); + + Service_RegisterChannelHandler(serviceRndv, C_CHANNEL_NEW, &NewEndPoint, C_CHANNEL_RECONNECT, 1); + Service_RegisterChannelHandler(serviceRndv, C_CHANNEL_BROKEN, &EndPointBroken, C_CHANNEL_RECONNECT, 1); + Service_RegisterOpHandler(serviceRndv, 0, C_SERVICE_REQUEST_RECEIVED, &RequestReceived, 1); + Service_RegisterOpHandler(serviceRndv, 0, C_SERVICE_REQUEST_POSTED, &RequestPosted, 1); + Service_RegisterOpHandler(serviceRndv, 0, C_SERVICE_READWRITE_DONE, &OneSideDone, 1); + Service_RegisterAllocateHandler(serviceRndv, &MemAllocate); + Service_RegisterFreeHandler(serviceRndv, &MemFree); + Service_RegisterRndvOpHandler(serviceRndv, &RndvHandlerC); + rndvPort = ++rdnvBasePort; + Service_SetOobIpAndPort(serviceRndv, BASE_IP, rndvPort); + + if ((result = ubs_hcom_service_start(serviceRndv, options)) != 0) { + NN_LOG_ERROR("failed to start service " << result); + return -1; + } + NN_LOG_INFO("service started"); + + ubs_hcom_memory_region mr; + result = ubs_hcom_service_register_assign_memory_region(serviceRndv, (uintptr_t)addressRndv, memSizeRndv, &mr); + if (result != 0) { + NN_LOG_ERROR("Register mr failed"); + return -1; + } + + ubs_hcom_mr_info info; + ubs_hcom_service_get_memory_region_info(mr, &info); + Net_MemoryAllocatorSetMrKey(memAllocator, info.lKey); + + return 0; +} + +int CreateRndvClient() +{ + int result = 0; + + if (clientRndv != 0) { + NN_LOG_ERROR("service already created."); + return -1; + } + + result = ubs_hcom_service_create(serviceRndvType, "client_rndv", 0, &clientRndv); + if (result != 0) { + NN_LOG_ERROR("failed to create service already created."); + return -1; + } + + result = PrepareClientMemAllocator(); + if (result != 0) { + NN_LOG_ERROR("failed to prepare mem."); + return -1; + } + + ubs_hcom_service_options options; + bzero(&options, sizeof(ubs_hcom_service_options)); + options.enableRndv = 1; + options.mode = C_SERVICE_BUSY_POLLING; + options.mrSendReceiveSegSize = 1024 + rndvSize; + options.mrSendReceiveSegCount = 8192; + options.enableTls = false; + strcpy(options.netDeviceIpMask, IP_SEG); + + Service_RegisterChannelHandler(clientRndv, C_CHANNEL_BROKEN, &EndPointBroken, C_CHANNEL_RECONNECT, 1); + Service_RegisterOpHandler(clientRndv, 0, C_SERVICE_REQUEST_RECEIVED, &RequestReceived, 1); + Service_RegisterOpHandler(clientRndv, 0, C_SERVICE_REQUEST_POSTED, &RequestPosted, 1); + Service_RegisterOpHandler(clientRndv, 0, C_SERVICE_READWRITE_DONE, &OneSideDone, 1); + Service_RegisterAllocateHandler(clientRndv, &MemClientAllocate); + Service_RegisterFreeHandler(clientRndv, &MemClientFree); + Service_RegisterRndvOpHandler(clientRndv, &RndvHandlerC); + + Service_SetOobIpAndPort(clientRndv, BASE_IP, rndvPort); + + if ((result = ubs_hcom_service_start(clientRndv, options)) != 0) { + NN_LOG_ERROR("failed to start service " << result); + return -1; + } + NN_LOG_INFO("client started."); + + ubs_hcom_memory_region mr; + result = ubs_hcom_service_register_assign_memory_region(clientRndv, (uintptr_t)addressClientRndv, memSizeRndv, &mr); + if (result != 0) { + NN_LOG_ERROR("Register mr failed."); + return -1; + } + + ubs_hcom_mr_info info; + ubs_hcom_service_get_memory_region_info(mr, &info); + Net_MemoryAllocatorSetMrKey(memClientAllocator, info.lKey); + + return 0; +} + + +void TestCapiServiceRndv::SetUp() +{ + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); + MOCKER(ReadRoCEVersionFromFile).stubs().will(returnValue(0)); + rndvOptions = { 0, 0, C_CHANNEL_FUNC_CB, 0, 0, 0 }; + rndvOptions.epSize = 2; + CreateRndvService(); + CreateRndvClient(); +} + +void TestCapiServiceRndv::TearDown() +{ + Channel_Destroy(serviceChannel); + Service_Stop(serviceRndv); + ubs_hcom_service_destroy(serviceRndv); + Service_Stop(clientRndv); + ubs_hcom_service_destroy(clientRndv); + serviceRndv = 0; + clientRndv = 0; + rndvOptions = { 0, 0, C_CHANNEL_FUNC_CB, 0, 0, 0 }; + free(addressRndv); + addressRndv = NULL; + free(addressClientRndv); + addressClientRndv = NULL; + GlobalMockObject::verify(); +} + +typedef struct { + Service_Message rsp; + sem_t sem; + int ret; +} CallAsyncStruct; +void CallAsyncCbRdnv(void *arg, ubs_hcom_service_context context) +{ + CallAsyncStruct *param = (CallAsyncStruct *)arg; + + if (ubs_hcom_context_get_result(context, ¶m->ret) != 0) { + NN_LOG_ERROR("Call async failed " << param->ret); + return; + } + + sem_post(¶m->sem); +} + +TEST_F(TestCapiServiceRndv, CallRequest) +{ + int result = 0; + Net_Channel clientChannel = 0; + result = ubs_hcom_service_connect(clientRndv, BASE_IP, rndvPort, "hello service c", &clientChannel, &rndvOptions); + EXPECT_EQ(0, result); + + /* CallRequest */ + char *data = (char *)malloc(rndvSize); + Service_Request req; + uintptr_t addressPtr; + uint32_t key; + result = MemClientAllocate(rndvSize, &addressPtr, &key); + EXPECT_EQ(0, result); + + req.lAddress = addressPtr; + req.lKey = key; + req.size = rndvSize; + + Service_Message rsp = { data, rndvSize }; + Service_OpInfo reqInfo = { 0, 0, 0, 0 }; + reqInfo.timeout = 1; + Service_OpInfo rspInfo = { 0, 0, 0, 0 }; + + result = Channel_SyncRndvCall(clientChannel, &reqInfo, &req, &rspInfo, &rsp); + EXPECT_EQ(0, result); + + result = Channel_SyncRndvCall(0, &reqInfo, &req, &rspInfo, &rsp); + EXPECT_EQ(501, result); + + result = Channel_SyncRndvCall(clientChannel, nullptr, &req, &rspInfo, &rsp); + EXPECT_EQ(501, result); + + result = Channel_SyncRndvCall(clientChannel, &reqInfo, nullptr, &rspInfo, &rsp); + EXPECT_EQ(501, result); + + result = Channel_SyncRndvCall(clientChannel, &reqInfo, &req, nullptr, &rsp); + EXPECT_EQ(501, result); + + result = Channel_SyncRndvCall(clientChannel, &reqInfo, &req, &rspInfo, nullptr); + EXPECT_EQ(501, result); + + free(data); + MemClientFree(addressPtr); + GlobalMockObject::verify(); + + /* CallRawSglRequest */ + char *data1 = (char *)malloc(rndvSize); + Service_Request req1[2]; + for (uint32_t i = 0; i < 2; i++) { + uintptr_t addressPtr; + uint32_t key; + + result = MemClientAllocate(rndvSize, &addressPtr, &key); + EXPECT_EQ(0, result); + + req1[i].lAddress = addressPtr; + req1[i].lKey = key; + req1[i].size = rndvSize; + } + + Service_SglRequest sgl1 = { req1, 1 }; + Service_Message rsp1 = { data1, rndvSize }; + Service_OpInfo reqInfo1 = { 0, 0, 0, 0 }; + Service_OpInfo rspInfo1 = { 0, 0, 0, 0 }; + + result = Channel_SyncRndvSglCall(clientChannel, &reqInfo1, &sgl1, &rspInfo1, &rsp1); + EXPECT_EQ(0, result); + + result = Channel_SyncRndvSglCall(0, &reqInfo1, &sgl1, &rspInfo1, &rsp1); + EXPECT_EQ(501, result); + + result = Channel_SyncRndvSglCall(clientChannel, nullptr, &sgl1, &rspInfo1, &rsp1); + EXPECT_EQ(501, result); + + result = Channel_SyncRndvSglCall(clientChannel, &reqInfo1, nullptr, &rspInfo1, &rsp1); + EXPECT_EQ(501, result); + + result = Channel_SyncRndvSglCall(clientChannel, &reqInfo1, &sgl1, nullptr, &rsp1); + EXPECT_EQ(501, result); + + result = Channel_SyncRndvSglCall(clientChannel, &reqInfo1, &sgl1, &rspInfo1, nullptr); + EXPECT_EQ(501, result); + + free(data1); + MemClientFree(req1[0].lAddress); + MemClientFree(req1[1].lAddress); + + GlobalMockObject::verify(); + + /* CallAsyncRequest */ + Service_Request req2; + Service_OpInfo reqInfo2 = { 0, 0, 0, 0 }; + + uintptr_t addressPtr2; + uint32_t key2; + + result = MemClientAllocate(rndvSize, &addressPtr2, &key2); + EXPECT_EQ(0, result); + + req2.lAddress = addressPtr2; + req2.lKey = key2; + req2.size = rndvSize; + + char data2[rndvSize]; + CallAsyncStruct asyncParam2; + asyncParam2.rsp.data = data2; + asyncParam2.rsp.size = rndvSize; + sem_init(&asyncParam2.sem, 0, 0); + asyncParam2.ret = 0; + + ubs_hcom_channel_callback cb2; + cb2.arg = &asyncParam2; + cb2.cb = CallAsyncCbRdnv; + + result = Channel_AsyncRndvCall(clientChannel, &reqInfo2, &req2, &cb2); + EXPECT_EQ(0, result); + sem_wait(&asyncParam2.sem); + sem_destroy(&asyncParam2.sem); + + + result = Channel_AsyncRndvCall(0, &reqInfo2, &req2, &cb2); + EXPECT_EQ(501, result); + + result = Channel_AsyncRndvCall(clientChannel, nullptr, &req2, &cb2); + EXPECT_EQ(501, result); + + result = Channel_AsyncRndvCall(clientChannel, &reqInfo2, nullptr, &cb2); + EXPECT_EQ(501, result); + + result = Channel_AsyncRndvCall(clientChannel, &reqInfo2, &req2, nullptr); + EXPECT_EQ(501, result); + + MemClientFree(addressPtr2); + EXPECT_EQ(0, asyncParam2.ret); + + GlobalMockObject::verify(); + + /* CallAsyncRawSglRequest */ + Service_Request req3[2]; + for (uint32_t i = 0; i < 2; i++) { + uintptr_t addressPtr; + uint32_t key; + + result = MemClientAllocate(rndvSize, &addressPtr, &key); + EXPECT_EQ(0, result); + + req3[i].lAddress = addressPtr; + req3[i].lKey = key; + req3[i].size = rndvSize; + } + + Service_SglRequest sgl3 = { req3, 1 }; + Service_OpInfo reqInfo3 = { 0, 0, 0, 0 }; + + CallAsyncStruct asyncParam3; + sem_init(&asyncParam3.sem, 0, 0); + asyncParam3.ret = 0; + + ubs_hcom_channel_callback cb3; + cb3.arg = &asyncParam3; + cb3.cb = CallAsyncCbRdnv; + + result = Channel_AsyncRndvSglCall(clientChannel, &reqInfo3, &sgl3, &cb3); + EXPECT_EQ(0, result); + + result = Channel_AsyncRndvSglCall(0, &reqInfo3, &sgl3, &cb3); + EXPECT_EQ(501, result); + + result = Channel_AsyncRndvSglCall(clientChannel, nullptr, &sgl3, &cb3); + EXPECT_EQ(501, result); + + result = Channel_AsyncRndvSglCall(clientChannel, &reqInfo3, nullptr, &cb3); + EXPECT_EQ(501, result); + + result = Channel_AsyncRndvSglCall(clientChannel, &reqInfo3, &sgl3, nullptr); + EXPECT_EQ(501, result); + + sem_wait(&asyncParam3.sem); + sem_destroy(&asyncParam3.sem); + MemClientFree(req3[0].lAddress); + MemClientFree(req3[1].lAddress); + EXPECT_EQ(0, asyncParam3.ret); + GlobalMockObject::verify(); + + Channel_Destroy(clientChannel); +} + +TEST_F(TestCapiServiceRndv, CallRequestInline) +{ + int result = 0; + Net_Channel clientChannel = 0; + result = ubs_hcom_service_connect(clientRndv, BASE_IP, rndvPort, "hello service c", &clientChannel, &rndvOptions); + EXPECT_EQ(0, result); + + /* CallRequest */ + char *data = (char *)malloc(rndvSize); + Service_Request req; + uintptr_t addressPtr; + uint32_t key; + result = MemClientAllocate(rndvSize, &addressPtr, &key); + EXPECT_EQ(0, result); + + req.lAddress = addressPtr; + req.lKey = key; + req.size = rndvSize; + + Service_Message rsp = { data, rndvSize }; + Service_OpInfo reqInfo = { 0, 0, 0, 0 }; + reqInfo.timeout = 1; + Service_OpInfo rspInfo = { 0, 0, 0, 0 }; + + MOCKER_CPP(&HcomEnv::InlineThreshold).stubs().will(returnValue(4096)); + + result = Channel_SyncRndvCall(clientChannel, &reqInfo, &req, &rspInfo, &rsp); + EXPECT_EQ(0, result); + + free(data); + MemClientFree(addressPtr); + GlobalMockObject::verify(); +} \ No newline at end of file diff --git a/test/llt/testcase/capi/test_capi_service_rndv.h b/test/llt/testcase/capi/test_capi_service_rndv.h new file mode 100644 index 0000000000000000000000000000000000000000..eda23d067d27e96f7a219e5f6e8ce313d1df78cc --- /dev/null +++ b/test/llt/testcase/capi/test_capi_service_rndv.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_TEST_CAPI_SERVICE_RNDV_H +#define HCOM_TEST_CAPI_SERVICE_RNDV_H +#include +#include + +class TestCapiServiceRndv : public testing::Test { +public: + TestCapiServiceRndv(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TEST_CAPI_SERVICE_RNDV_H diff --git a/test/llt/testcase/capi/test_rdma_c.cpp b/test/llt/testcase/capi/test_rdma_c.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4af7f2e3c77b75399074868e402f4cddb453179e --- /dev/null +++ b/test/llt/testcase/capi/test_rdma_c.cpp @@ -0,0 +1,781 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifdef RDMA_BUILD_ENABLED +#include +#include "mockcpp/mockcpp.hpp" +#include "test_rdma_c.hpp" +#include "capi/hcom_c.h" +#include "string.h" +#include "hcom.h" +#include "common/net_util.h" +#include "transport/rdma/verbs/net_rdma_sync_endpoint.h" +#include "transport/rdma/verbs/net_rdma_async_endpoint.h" +#include "transport/rdma/rdma_mr_dm_buf.h" +#include "transport/rdma/rdma_mr_fixed_buf.h" +#include "transport/rdma/verbs/rdma_worker.h" +#include "fake_ibv.h" +#include "transport/rdma/verbs/net_rdma_driver.h" +#include "ut_helper.h" + +TestCaseRdmaC::TestCaseRdmaC() {} + +void TestCaseRdmaC::SetUp() +{ + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestCaseRdmaC::TearDown() +{ + GlobalMockObject::verify(); +} + +using namespace ock::hcom; +#ifdef MOCK_VERBS +#ifdef __cplusplus +extern "C" { +#endif +int fake_ibv_post_send(fake_qp_t *my_qp, struct ibv_send_wr *wr); +int fake_post_read(fake_qp_t *my_qp, struct ibv_send_wr *wr); +int fake_post_write(fake_qp_t *my_qp, struct ibv_send_wr *wr); +#ifdef __cplusplus +} +#endif +#endif +// cpp case +using CTestOpCode = enum { + C_GET_MR = 1, + C_SET_MR, + C_CHECK_SYNC_RESPONSE, + C_SEND_RAW, +}; + +#define CHECK_RESULT_TRUE(result) \ + EXPECT_EQ(true, result); \ + if (!result) { \ + return; \ + } + +#define CLEAN_UP_ALL_STUBS() GlobalMockObject::verify() + +constexpr uint64_t C_SYNC_SEND_VALUE = 0xAAAA0000; +constexpr uint64_t C_SYNC_RECEIVE_VALUE = 0x0000AAAA; +constexpr uint64_t C_SET_MSG_SUCCESS = 0xCCCCCCCC; +constexpr uint64_t ASYNC_RW_COUNT = 4; +constexpr uint64_t C_RDMA_LISTEN_PORT = 7778; +// server +Net_Driver cServerDriver = 0; +Net_EndPoint epServer = 0; +Net_MemoryRegion mrRegion[NN_NO4]; + +char *ipSeg = IP_SEG; +char *certPath; +bool enableTls = true; +Net_DriverCipherSuite cipherSuite = C_AES_GCM_256; + +using TestRegMrInfo = struct _reg_sgl_info_test_ { + uintptr_t lAddress = 0; + uint32_t lKey = 0; + uint32_t size = 0; +} __attribute__((packed)); +TestRegMrInfo cServerLocalMrInfo[NN_NO4]; + +char *join(char *a, char *b) +{ + char *c = (char *)malloc(strlen(a) + strlen(b) + 1); + if (c == NULL) + exit(1); + char *tempc = c; + while (*a != '\0') { + *c++ = *a++; + } + while ((*c++ = *b++) != '\0') { + } + return tempc; +} + +static int SNewEndPoint(Net_EndPoint newEp, uint64_t usrCtx, const char *payLoad) +{ + NN_LOG_INFO("ep new"); + epServer = newEp; + return 0; +} + +static int SEndPointBroken(Net_EndPoint bEp, uint64_t usrCtx, const char *payLoad) +{ + NN_LOG_INFO("ep broken"); + return 0; +} + +static int SRequestReceived(Net_RequestContext *ctx, uint64_t usrCtx) +{ + int result = 0; + Net_SendRequest rsp = { 0 }; + if (ctx->type == C_RECEIVED) { + if (ctx->opCode == C_GET_MR) { + rsp.data = (uintptr_t)cServerLocalMrInfo; + rsp.size = sizeof(cServerLocalMrInfo); + if ((result = Net_EPPostSend(ctx->ep, ctx->opCode, &rsp)) != 0) { + NN_LOG_INFO("failed to post message to data to server, result " << result); + return result; + } + + NN_LOG_TRACE_INFO("request rsp Mr info"); + for (uint16_t i = 0; i < NN_NO4; i++) { + NN_LOG_TRACE_INFO("idx:" << i << " key:" << cServerLocalMrInfo[i].lKey << " address:" << + cServerLocalMrInfo[i].lAddress << " size" << cServerLocalMrInfo[i].size); + } + } else if (ctx->opCode == C_SET_MR) { + memset(cServerLocalMrInfo, 0, sizeof(cServerLocalMrInfo)); + uint64_t ret = C_SET_MSG_SUCCESS; + rsp.data = (uintptr_t)&ret; + rsp.size = sizeof(uint64_t); + if ((result = Net_EPPostSend(ctx->ep, ctx->opCode, &rsp)) != 0) { + NN_LOG_INFO("failed to post message to data to server, result " << result); + return result; + } + } else if (ctx->opCode == C_CHECK_SYNC_RESPONSE) { + uint64_t *readValue = reinterpret_cast((void *)(ctx->msgData)); + EXPECT_EQ(C_SYNC_SEND_VALUE, *readValue); + uint64_t returnValue = C_SYNC_RECEIVE_VALUE; + rsp.data = (uintptr_t)&returnValue; + rsp.size = sizeof(uint64_t); + if ((result = Net_EPPostSend(ctx->ep, ctx->opCode, &rsp)) != 0) { + NN_LOG_INFO("failed to post message to data to server, result " << result); + return result; + } + } + } else if (ctx->type == C_RECEIVED_RAW) { + if (ctx->seqNo == C_SEND_RAW) { + uint64_t returnValue = 0; + rsp.data = (uintptr_t)&returnValue; + rsp.size = sizeof(uint64_t); + if ((result = Net_EPPostSendRaw(ctx->ep, &rsp, ctx->seqNo)) != 0) { + NN_LOG_INFO("failed to post message to data to server, result " << result); + return result; + } + } + } + + return 0; +} + +static int SRequestPosted(Net_RequestContext *ctx, uint64_t usrCtx) +{ + NN_LOG_TRACE_INFO("posted"); + return 0; +} + +static int SOneSideDone(Net_RequestContext *ctx, uint64_t usrCtx) +{ + NN_LOG_TRACE_INFO("one side done"); + return 0; +} + +static void SIdle(uint8_t wkrGrpIdx, uint16_t workerIndex, uint64_t usrCtx) {} + +static void SErase(char *pass, int len) {} + +static int SVerify(void *x509, const char *path) +{ + NN_LOG_INFO("verify"); + return 0; +} + +static int SCertCallback(const char *name, char **value) +{ + char cert[] = "/server/cert.pem"; + *value = join(certPath, cert); + printf("cert callback v: %s \n", *value); + return 0; +} + +static int SPrivateKeyCallback(const char *name, char **value, char **keyPass, Net_TlsKeyPassErase *erase) +{ + printf("private key cb"); + static char content[] = "huawei"; + *keyPass = content; + char keypem[] = "/server/key.pem"; + *value = join(certPath, keypem); + *erase = &SErase; + return 0; +} + +static int SCACallback(const char *name, char **caPath, char **crlPath, Net_PeerCertVerifyType *peerCertVerifyType, + Net_TlsCertVerify *verify) +{ + char caCert[] = "/CA/cacert.pem"; + *caPath = join(certPath, caCert); + *peerCertVerifyType = C_VERIFY_BY_DEFAULT; + return 0; +} + +bool CServerCreateDriver() +{ + int result = 0; + Net_DriverOptions options; + + if (cServerDriver != 0) { + NN_LOG_ERROR("cServerDriver already created"); + return false; + } + + result = Net_DriverCreate(C_DRIVER_RDMA, "c_server", 1, &cServerDriver); + if (result != 0) { + NN_LOG_ERROR("failed to create cServerDriver already created"); + return false; + } + + bzero(&options, sizeof(Net_DriverOptions)); + options.mode = C_EVENT_POLLING; + options.mrSendReceiveSegSize = 1024; + options.mrSendReceiveSegCount = 8192; + options.enableTls = enableTls; + options.cipherSuite = cipherSuite; + strcpy(options.netDeviceIpMask, ipSeg); + sprintf(options.workerGroupsCpuSet, "%u-%u", 20, 20); + + Net_DriverRegisterEpHandler(cServerDriver, C_EP_NEW, &SNewEndPoint, 1); + Net_DriverRegisterEpHandler(cServerDriver, C_EP_BROKEN, &SEndPointBroken, 1); + Net_DriverRegisterOpHandler(cServerDriver, C_OP_REQUEST_RECEIVED, &SRequestReceived, 1); + Net_DriverRegisterOpHandler(cServerDriver, C_OP_REQUEST_POSTED, &SRequestPosted, 1); + Net_DriverRegisterOpHandler(cServerDriver, C_OP_READWRITE_DONE, &SOneSideDone, 1); + Net_DriverRegisterIdleHandler(cServerDriver, &SIdle, 1); + if (enableTls) { + Net_DriverRegisterTLSCb(cServerDriver, &SCertCallback, &SPrivateKeyCallback, &SCACallback); + } + + Net_DriverSetOobIpAndPort(cServerDriver, "127.0.0.1", C_RDMA_LISTEN_PORT); + + if ((result = Net_DriverInitialize(cServerDriver, options)) != 0) { + NN_LOG_ERROR("failed to initialize cServerDriver " << result); + return false; + } + NN_LOG_INFO("cServerDriver initialized"); + + if ((result = Net_DriverStart(cServerDriver)) != 0) { + NN_LOG_INFO("failed to start cServerDriver " << result); + return false; + } + NN_LOG_INFO("cServerDriver started"); + + return true; +} + +bool CServerRegSglMem() +{ + for (uint16_t i = 0; i < NN_NO4; i++) { + int result = Net_DriverCreateMemoryRegion(cServerDriver, NN_NO8, &mrRegion[i]); + if (result != 0) { + NN_LOG_INFO("reg mr failed"); + return false; + } + + Net_MemoryRegionInfo mrInfo; + result = Net_DriverGetMemoryRegionInfo(mrRegion[i], &mrInfo); + if (result != 0) { + NN_LOG_INFO("parse mr failed"); + return false; + } + cServerLocalMrInfo[i].lAddress = mrInfo.lAddress; + cServerLocalMrInfo[i].lKey = mrInfo.lKey; + cServerLocalMrInfo[i].size = mrInfo.size; + memset((void *)(cServerLocalMrInfo[i].lAddress), 0, cServerLocalMrInfo[i].size); + } + + return true; +} + +// client +TestRegMrInfo cRemoteMrInfo[NN_NO4]; +TestRegMrInfo cSelfLocalMrInfo[NN_NO4]; +Net_MemoryRegion clientMrInfo[NN_NO4]; +sem_t cSem; +uint32_t cExecCount = 0; +Net_Driver cDriver = 0; +Net_EndPoint cAsyncEp = 0; +Net_EndPoint cSyncEp = 0; + +int16_t asyncWorkerCpuId = 10; +void CSendAsyncReadWriteRequest(Net_ReadWriteSge *iov, uint64_t index) +{ + int result = 0; + + Net_ReadWriteSglRequest req; + req.iov = iov; + req.iovCount = NN_NO4; + req.upCtxSize = 0; + result = Net_EPPostSglRead(cAsyncEp, &req); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + sem_wait(&cSem); + + NN_LOG_TRACE_INFO("sgl read value idx:" << cExecCount++); + for (uint16_t i = 0; i < NN_NO4; i++) { + uint64_t *readValue = (uint64_t *)((void *)(iov[i].lAddress)); + uint64_t value = *readValue; + EXPECT_EQ(value, index); + NN_LOG_TRACE_INFO("value[" << i << "]=" << *readValue); + *readValue = ++value; + } + + result = Net_EPPostSglWrite(cAsyncEp, &req); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + sem_wait(&cSem); + + Net_ReadWriteRequest buffReq = { 0 }; + buffReq.lMRA = iov[0].lAddress; + buffReq.rMRA = iov[0].rAddress; + buffReq.lKey = iov[0].lKey; + buffReq.rKey = iov[0].rKey; + buffReq.size = iov[0].size; + result = Net_EPPostRead(cAsyncEp, &buffReq); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + sem_wait(&cSem); + uint64_t *readBuff = reinterpret_cast((void *)(iov[0].lAddress)); + uint64_t readValue = *readBuff; + EXPECT_EQ(readValue, index + 1); + + result = Net_EPPostWrite(cAsyncEp, &buffReq); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + sem_wait(&cSem); +} + +void CAsyncReadWriteRequest() +{ + Net_ReadWriteSge iov[NN_NO4]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = cSelfLocalMrInfo[i].lAddress; + iov[i].rAddress = cRemoteMrInfo[i].lAddress; + iov[i].lKey = cSelfLocalMrInfo[i].lKey; + iov[i].rKey = cRemoteMrInfo[i].lKey; + iov[i].size = NN_NO8; + } + sem_init(&cSem, 0, 0); + for (int i = 0; i < NN_NO4; i++) { + CSendAsyncReadWriteRequest(iov, i); + } +} + +void CAsyncPostSend() +{ + uint64_t data = C_SYNC_SEND_VALUE; + Net_SendRequest req = { 0 }; + int result = 0; + + req.data = (uintptr_t)&data; + req.size = sizeof(data); + + result = Net_EPPostSend(cAsyncEp, C_CHECK_SYNC_RESPONSE, &req); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to post message to data to server" << result); + return; + } + sem_wait(&cSem); + + result = Net_EPPostSend(cAsyncEp, C_SET_MR, &req); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to post message to data to server" << result); + return; + } + sem_wait(&cSem); +} + +void CAsyncRequest() +{ + CAsyncPostSend(); + CAsyncReadWriteRequest(); +} + +void CSyncPostSend() +{ + uint64_t data = C_SYNC_SEND_VALUE; + Net_SendRequest req = { 0 }; + int result = 0; + + req.data = (uintptr_t)&data; + req.size = sizeof(data); + + result = Net_EPPostSend(cSyncEp, C_CHECK_SYNC_RESPONSE, &req); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to post message to data to server" << result); + return; + } + + result = Net_EPWaitCompletion(cSyncEp, -1); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to wait data to server" << result); + return; + } + + Net_ResponseContext *ctx; + result = Net_EPReceive(cSyncEp, -1, &ctx); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to receive raw data to server" << result); + return; + } + + result = Net_EPPostSendRaw(cSyncEp, &req, C_SEND_RAW); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to post message to data to server" << result); + return; + } + + result = Net_EPWaitCompletion(cSyncEp, -1); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to wait data to server" << result); + return; + } + + result = Net_EPReceiveRaw(cSyncEp, -1, &ctx); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to receive raw data to server" << result); + return; + } +} + +void CSyncRequest() +{ + CSyncPostSend(); +} + + +int CRequestReceived(Net_RequestContext *ctx, uint64_t usrCtx) +{ + if (ctx->type == C_RECEIVED) { + if (ctx->opCode == C_CHECK_SYNC_RESPONSE) { + uint64_t *readValue = reinterpret_cast((void *)(ctx->msgData)); + EXPECT_EQ(C_SYNC_RECEIVE_VALUE, *readValue); + sem_post(&cSem); + } else if (ctx->opCode == C_SET_MR) { + uint64_t *readValue = reinterpret_cast((void *)(ctx->msgData)); + EXPECT_EQ(C_SET_MSG_SUCCESS, *readValue); + sem_post(&cSem); + } else if (ctx->opCode == C_GET_MR) { + memcpy(cRemoteMrInfo, ctx->msgData, ctx->msgSize); + sem_post(&cSem); + } + } else if (ctx->type == C_RECEIVED_RAW) { + } + + return 0; +} + +static int CEndPointBroken(Net_EndPoint bep, uint64_t usrCtx, const char *payLoad) +{ + NN_LOG_INFO("end point " << Net_EPGetContext(bep) << " broken"); + return 0; +} + + +int CRequestPosted(Net_RequestContext *ctx, uint64_t usrCtx) +{ + return 0; +} + +int COneSideDone(Net_RequestContext *ctx, uint64_t usrCtx) +{ + sem_post(&cSem); + return 0; +} + +void CIdle(uint8_t wkrGrpIdx, uint16_t workerIndex, uint64_t usrCtx) {} + +static void CErase(char *pass, int len) {} + +static int CVerify(void *x509, const char *path) +{ + NN_LOG_INFO("verify"); + return 0; +} + +static int CCertCallback(const char *name, char **value) +{ + char cert[] = "/client/cert.pem"; + *value = join(certPath, cert); + return 1; +} + +static int CPrivateKeyCallback(const char *name, char **value, char **keyPass, Net_TlsKeyPassErase *erase) +{ + static char content[] = "huawei"; + *keyPass = content; + char keyPerm[] = "/client/key.pem"; + *value = join(certPath, keyPerm); + *erase = &CErase; + + return 1; +} + +static int CCACallback(const char *name, char **caPath, char **crlPath, Net_PeerCertVerifyType *verifyType, + Net_TlsCertVerify *cb) +{ + char caCert[] = "/CA/cacert.pem"; + *caPath = join(certPath, caCert); + *verifyType = C_VERIFY_BY_NONE; + return 1; +} + +static bool CCreateDriver() +{ + int result = 0; + Net_DriverOptions options; + + if (cDriver != 0) { + NN_LOG_ERROR("cDriver already created"); + return false; + } + + Net_DeviceInfo deviceInfo; + EXPECT_EQ(Net_LocalSupport(C_DRIVER_RDMA, &deviceInfo), 1); + + result = Net_DriverCreate(C_DRIVER_RDMA, "c_client", 0, &cDriver); + if (result != 0) { + NN_LOG_ERROR("failed to create cDriver already created"); + return false; + } + + bzero(&options, sizeof(Net_DriverOptions)); + options.mode = C_EVENT_POLLING; + options.mrSendReceiveSegSize = 2048; + options.mrSendReceiveSegCount = 8192; + strcpy(options.netDeviceIpMask, ipSeg); + options.qpSendQueueSize = 512; + options.qpReceiveQueueSize = 512; + options.version = 1; + options.enableTls = enableTls; + options.cipherSuite = cipherSuite; + + Net_DriverRegisterEpHandler(cDriver, C_EP_BROKEN, &CEndPointBroken, 2); + Net_DriverRegisterOpHandler(cDriver, C_OP_REQUEST_RECEIVED, &CRequestReceived, 2); + Net_DriverRegisterOpHandler(cDriver, C_OP_REQUEST_POSTED, &CRequestPosted, 2); + Net_DriverRegisterOpHandler(cDriver, C_OP_READWRITE_DONE, &COneSideDone, 2); + + if (enableTls) { + Net_DriverRegisterTLSCb(cDriver, &CCertCallback, &CPrivateKeyCallback, &CCACallback); + } + + auto handle = Net_DriverRegisterIdleHandler(cDriver, &CIdle, 2); + EXPECT_NE(handle, 0); + + Net_DriverSetOobIpAndPort(cDriver, "0.0.0.0", C_RDMA_LISTEN_PORT); + + if ((result = Net_DriverInitialize(cDriver, options)) != 0) { + NN_LOG_ERROR("failed to initialize cDriver " << result); + return false; + } + printf("cDriver initialized"); + + if ((result = Net_DriverStart(cDriver)) != 0) { + NN_LOG_ERROR("failed to start cDriver %d" << result); + return false; + } + NN_LOG_INFO("cDriver started"); + + return true; +} + +bool CAsyncConnect() +{ + int result = 0; + + if (cDriver == 0) { + NN_LOG_ERROR("cDriver is null"); + return false; + } + + if ((result = Net_DriverConnect(cDriver, "hello world", &cAsyncEp, 0)) != 0) { + NN_LOG_ERROR("failed to connect to server, result %d" << result); + return false; + } + + Net_EPSetContext(cAsyncEp, 0); + EXPECT_EQ(Net_EPGetContext(cAsyncEp), 0); + + // default config worker group = 1, num = 1, index = 0 + EXPECT_EQ(Net_EPGetWorkerIndex(cAsyncEp), 0); + EXPECT_EQ(Net_EPGetWorkerGroupIndex(cAsyncEp), 0); + + // config when create driver + EXPECT_EQ(Net_EPGetListenPort(cAsyncEp), C_RDMA_LISTEN_PORT); + EXPECT_EQ(Net_EPGetVersion(cAsyncEp), 1); + + sem_init(&cSem, 0, 0); + Net_SendRequest req = { 0 }; + + char value[] = "hello world"; + req.data = (uintptr_t)value; + req.size = sizeof(value); + + if ((result = Net_EPPostSend(cAsyncEp, C_GET_MR, &req)) != 0) { + NN_LOG_ERROR("failed to post message to data to server"); + return false; + } + + sem_wait(&cSem); + return true; +} + +bool CSyncConnect() +{ + int result = 0; + + if (cDriver == 0) { + NN_LOG_ERROR("cDriver is null"); + return false; + } + + if ((result = Net_DriverConnectToIpPort(cDriver, "0.0.0.0", C_RDMA_LISTEN_PORT, "hello world", &cSyncEp, + NET_C_EP_EVENT_POLLING)) != 0) { + NN_LOG_ERROR("failed to connect to server, result %d" << result); + return false; + } + + return true; +} + +bool CClientRegSglMem() +{ + for (uint16_t i = 0; i < NN_NO4; i++) { + auto result = Net_DriverCreateMemoryRegion(cDriver, NN_NO8, &clientMrInfo[i]); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + + Net_MemoryRegionInfo regMr; + result = Net_DriverGetMemoryRegionInfo(clientMrInfo[i], ®Mr); + if (result != NN_OK) { + printf("parse mr failed\n"); + return false; + } + cSelfLocalMrInfo[i].lAddress = regMr.lAddress; + cSelfLocalMrInfo[i].lKey = regMr.lKey; + cSelfLocalMrInfo[i].size = regMr.size; + memset(reinterpret_cast(cSelfLocalMrInfo[i].lAddress), 0, cSelfLocalMrInfo[i].size); + } + + char *buff = (char *)malloc(4096); + Net_MemoryRegion tmp; + auto result = Net_DriverCreateAssignMemoryRegion(cDriver, (uintptr_t)buff, 4096, &tmp); + EXPECT_EQ(result, 0); + if (result == 0) { + Net_DriverDestroyMemoryRegion(cDriver, tmp); + } + + return true; +} + +bool CClientUnRegSglMem() +{ + for (uint16_t i = 0; i < NN_NO4; i++) { + Net_DriverDestroyMemoryRegion(cDriver, clientMrInfo[i]); + } + + return true; +} + +static int ValidateTlsCert() +{ + char *buffer; + + if ((buffer = getcwd(NULL, 0)) == NULL) { + NN_LOG_ERROR("Cet path for TLS cert failed"); + return -1; + } + + char *currentPath = buffer; + char base[] = "/../test/opensslcrt/normalCert1"; + certPath = join(currentPath, base); + + char cacert[] = "/CA/cacert.pem"; + if (::access(join(certPath, cacert), F_OK) != 0) { + NN_LOG_ERROR("cacert.pem cannot be found under " << certPath); + return -1; + } + + char cert[] = "/server/cert.pem"; + if (::access(join(certPath, cert), F_OK) != 0) { + NN_LOG_ERROR("cert.pem cannot be found under " << certPath); + return -1; + } + + char key[] = "/server/key.pem"; + if (::access(join(certPath, key), F_OK) != 0) { + NN_LOG_ERROR("key.pem cannot be found under " << certPath); + return -1; + } + + return 0; +} + +TEST_F(TestCaseRdmaC, RDMA_C_BASIC_OPERATE) +{ + MOCK_VERSION + + if (enableTls) { + ValidateTlsCert(); + } + + int result = 0; + + result = CServerCreateDriver(); + CHECK_RESULT_TRUE(result); + result = CServerRegSglMem(); + CHECK_RESULT_TRUE(result); + + result = CCreateDriver(); + CHECK_RESULT_TRUE(result); + result = CAsyncConnect(); + CHECK_RESULT_TRUE(result); + result = CClientRegSglMem(); + CHECK_RESULT_TRUE(result); + CAsyncRequest(); + + result = CSyncConnect(); + CHECK_RESULT_TRUE(result); + CSyncRequest(); + + CClientUnRegSglMem(); + + Net_DriverStop(cServerDriver); + Net_DriverUnInitialize(cServerDriver); + Net_DriverDestroy(cServerDriver); + + Net_DriverStop(cDriver); + Net_DriverUnInitialize(cDriver); + Net_DriverDestroy(cDriver); +} +#endif \ No newline at end of file diff --git a/test/llt/testcase/capi/test_rdma_c.hpp b/test/llt/testcase/capi/test_rdma_c.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ad674eb466b5197a30f7936cbb76a6e789a2afd9 --- /dev/null +++ b/test/llt/testcase/capi/test_rdma_c.hpp @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef _TEST_RDMA_C_HPP_ +#define _TEST_RDMA_C_HPP_ +#include + +class TestCaseRdmaC : public testing::Test { +public: + TestCaseRdmaC(); + virtual void SetUp(void); + virtual void TearDown(void); + +protected: +}; + +#endif diff --git a/test/llt/testcase/common/test_address_size_hash_map.cpp b/test/llt/testcase/common/test_address_size_hash_map.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c0ea2cecdef9081cb3d995681382383318abd186 --- /dev/null +++ b/test/llt/testcase/common/test_address_size_hash_map.cpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "net_addr_size_map.h" +#include "test_address_size_hash_map.h" + +using namespace ock::hcom; + +void TestAddress2SizeHashmap::SetUp() {} + +void TestAddress2SizeHashmap::TearDown() {} + +TEST_F(TestAddress2SizeHashmap, PutRemove) +{ + NetAddress2SizeHashMap hMap {}; + auto result = hMap.Initialize(1024); + ASSERT_EQ(result, 0); + result = hMap.Put(1, 1); + ASSERT_EQ(result, 0); + uint32_t size = 0; + result = hMap.Remove(1, size); + ASSERT_EQ(result, 0); + ASSERT_EQ(size, 1); + hMap.UnInitialize(); +} + +TEST_F(TestAddress2SizeHashmap, DoubleInitialize) +{ + NetAddress2SizeHashMap hMap {}; + hMap.Initialize(1024); + auto result = hMap.Initialize(1024); + ASSERT_EQ(result, 0); + hMap.UnInitialize(); +} + +TEST_F(TestAddress2SizeHashmap, HashBucketPutAndRemove) +{ + NetHashBucket netHashBucket; + auto result = netHashBucket.Put(1, 1); + ASSERT_EQ(result, 1); + result = netHashBucket.Put(2, 2); + ASSERT_EQ(result, 1); + result = netHashBucket.Put(3, 3); + ASSERT_EQ(result, 1); + result = netHashBucket.Put(4, 4); + ASSERT_EQ(result, 1); + result = netHashBucket.Put(5, 5); + ASSERT_EQ(result, 1); + result = netHashBucket.Put(6, 6); + ASSERT_EQ(result, 1); + result = netHashBucket.Put(7, 7); + ASSERT_EQ(result, 0); + uint32_t size = 0; + result = netHashBucket.Remove(1, size); + ASSERT_EQ(result, 1); + ASSERT_EQ(size, 1); + result = netHashBucket.Remove(2, size); + ASSERT_EQ(result, 1); + ASSERT_EQ(size, 2); + result = netHashBucket.Remove(3, size); + ASSERT_EQ(result, 1); + ASSERT_EQ(size, 3); + result = netHashBucket.Remove(4, size); + ASSERT_EQ(result, 1); + ASSERT_EQ(size, 4); + result = netHashBucket.Remove(5, size); + ASSERT_EQ(result, 1); + ASSERT_EQ(size, 5); + result = netHashBucket.Remove(6, size); + ASSERT_EQ(result, 1); + ASSERT_EQ(size, 6); + result = netHashBucket.Remove(7, size); + ASSERT_EQ(result, 0); +} + +TEST_F(TestAddress2SizeHashmap, RemoveAbsentAddress) +{ + NetAddress2SizeHashMap hMap {}; + hMap.Initialize(1024); + uint32_t size = 0; + auto result = hMap.Remove(1, size); + ASSERT_EQ(result, 100); + hMap.UnInitialize(); +} \ No newline at end of file diff --git a/test/llt/testcase/common/test_address_size_hash_map.h b/test/llt/testcase/common/test_address_size_hash_map.h new file mode 100644 index 0000000000000000000000000000000000000000..ae5115aad78c1a6de8ea61f2a53d7bab90cfeb40 --- /dev/null +++ b/test/llt/testcase/common/test_address_size_hash_map.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TEST_ADDRESS_SIZE_HASH_MAP_H +#define HCOM_TEST_ADDRESS_SIZE_HASH_MAP_H + +#include + +class TestAddress2SizeHashmap : public testing::Test { +public: + TestAddress2SizeHashmap() = default; + virtual void SetUp(void); + virtual void TearDown(void); +}; + + +#endif // HCOM_TEST_ADDRESS_SIZE_HASH_MAP_H diff --git a/test/llt/testcase/common/test_memory_allocator.cpp b/test/llt/testcase/common/test_memory_allocator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d21353248ede98e5c2ccb6f851ec1f6d459f6eb0 --- /dev/null +++ b/test/llt/testcase/common/test_memory_allocator.cpp @@ -0,0 +1,307 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +#include "hcom.h" +#include "hcom_def.h" +#include "hcom_utils.h" +#include "net_mem_allocator.h" +#include "test_memory_allocator.h" + +#define SIZE (256UL << 12) +#define MRKEY 1 + +using namespace ock::hcom; + +TestMemoryAllocator::TestMemoryAllocator() {} + +void TestMemoryAllocator::SetUp() {} + +void TestMemoryAllocator::TearDown() {} + +static void ConcurrentRoutine(UBSHcomNetMemoryAllocatorPtr ptr, const int count, bool random, + std::atomic_uint64_t &allocCost, std::atomic_uint64_t &freeCost) +{ + uint64_t allocTotalTime = 0; + + std::vector addrs; + auto blocks = SIZE / NN_NO256; + + for (int i = 0; i < count; ++i) { + uint64_t addr = 0; + auto allocTime = MONOTONIC_TIME_NS(); + auto size = SIZE; + if (random) { + size = (i % blocks + 1) * (NN_NO256); + } + auto res = ptr->Allocate(size, addr); + allocTotalTime += (MONOTONIC_TIME_NS() - allocTime); + ASSERT_EQ(res, NN_OK); + if (addr > 0) { + addrs.emplace_back(addr); + } + } + + uint64_t freeTotalTime = MONOTONIC_TIME_NS(); + + for (uint32_t i = 0; i < addrs.size(); ++i) { + auto ret = ptr->Free(addrs[i]); + ASSERT_EQ(ret, NN_OK); + } + + freeTotalTime = MONOTONIC_TIME_NS() - freeTotalTime; + allocCost.fetch_add(allocTotalTime); + freeCost.fetch_add(freeTotalTime); +} + +TEST_F(TestMemoryAllocator, Serial) +{ + for (int k = 0; k < 4; ++k) { + bool res = false; + auto address = memalign(NN_NO4096, SIZE); + UBSHcomNetMemoryAllocatorPtr ptr; + UBSHcomNetMemoryAllocatorOptions options; + options.address = reinterpret_cast(address); + options.size = SIZE; + options.minBlockSize = 4096; + UBSHcomNetMemoryAllocator::Create(ock::hcom::DYNAMIC_SIZE, options, ptr); + uint64_t addr = 0; + for (int i = 0; i < 4; ++i) { + auto expectSize = SIZE / ((i % 2) * 2 + 2) - 16; + res = ptr->Allocate(expectSize, addr); + ASSERT_EQ(res, NN_OK); + res = ptr->Free(addr); + ASSERT_EQ(res, NN_OK); + } + res = ptr->Allocate(SIZE, addr); + ASSERT_EQ(res, NN_OK); + } +} + +TEST_F(TestMemoryAllocator, GetSizeNoAlign) +{ + bool res = false; + auto address = memalign(NN_NO4096, SIZE * 16); + UBSHcomNetMemoryAllocatorPtr ptr; + UBSHcomNetMemoryAllocatorOptions options; + options.address = reinterpret_cast(address); + options.size = SIZE * 16; + options.minBlockSize = 4096; + UBSHcomNetMemoryAllocator::Create(ock::hcom::DYNAMIC_SIZE, options, ptr); + uint64_t addr = 0; + uint64_t addrs[16]; + uint64_t sizes[16]; + for (int i = 0; i < 4; ++i) { + auto expectSize = random() % SIZE + 1; + res = ptr->Allocate(expectSize, addr); + ASSERT_EQ(res, NN_OK); + addrs[i] = addr; + sizes[i] = expectSize; + } + + for (int i = 0; i < 4; ++i) { + addr = addrs[i]; + auto expectSize = sizes[i]; + uint64_t retSize; + res = ptr.ToChild()->GetSizeByAddressNoAlign(addr, retSize); + ASSERT_EQ(res, NN_OK); + ASSERT_EQ(expectSize, retSize); + } +} + +TEST_F(TestMemoryAllocator, SerialAlign4k) +{ + auto address = memalign(NN_NO4096, SIZE); + UBSHcomNetMemoryAllocatorPtr ptr; + UBSHcomNetMemoryAllocatorOptions options; + options.address = reinterpret_cast(address); + options.size = SIZE; + options.minBlockSize = 4096; + options.alignedAddress = true; + UBSHcomNetMemoryAllocator::Create(ock::hcom::DYNAMIC_SIZE, options, ptr); + uint64_t addr = 0; + for (int i = 0; i < 4; ++i) { + auto expectSize = NN_NO4096; + ptr->Allocate(expectSize, addr); + EXPECT_EQ(addr % NN_NO4096, 0); + } +} + +TEST_F(TestMemoryAllocator, SimpleConcurrent) +{ + uint64_t size = 8192 * 16; + auto address = memalign(NN_NO4096, size); + + UBSHcomNetMemoryAllocatorPtr ptr; + UBSHcomNetMemoryAllocatorOptions options; + options.address = reinterpret_cast(address); + options.size = size; + options.minBlockSize = 4096; + options.alignedAddress = true; + UBSHcomNetMemoryAllocator::Create(ock::hcom::DYNAMIC_SIZE, options, ptr); + std::vector ths; + for (int i = 0; i < 4; ++i) { + std::thread th([&]() { + for (int j = 0; j < 4; ++j) { + uint64_t addr; + auto res = ptr->Allocate(8192, addr); + ASSERT_EQ(res, NN_OK); + } + }); + ths.push_back(std::move(th)); + } + for (int i = 0; i < 4; ++i) { + ths[i].join(); + } + uint64_t addr; + auto res = ptr->Allocate(8192, addr); + ASSERT_NE(res, NN_OK); +} + +TEST_F(TestMemoryAllocator, Concurrent) +{ + std::atomic_uint64_t allocCost { 0 }; + std::atomic_uint64_t freeCost { 0 }; + + for (int k = 0; k < 4; ++k) { + const auto threadCount = 10; + const auto blockCount = 20; + auto totalSize = SIZE * blockCount * threadCount; + auto address = memalign(NN_NO4096, totalSize); + + UBSHcomNetMemoryAllocatorPtr ptr; + UBSHcomNetMemoryAllocatorOptions options; + options.address = reinterpret_cast(address); + options.size = totalSize; + options.minBlockSize = 4096; + options.alignedAddress = true; + UBSHcomNetMemoryAllocator::Create(ock::hcom::DYNAMIC_SIZE, options, ptr); + + std::vector threads; + + for (int i = 0; i < threadCount; ++i) { + threads.emplace_back(ConcurrentRoutine, ptr, blockCount, false, std::ref(allocCost), std::ref(freeCost)); + } + for (int i = 0; i < threadCount; ++i) { + threads[i].join(); + } + + ptr->Destroy(); + free(address); + } +} + +TEST_F(TestMemoryAllocator, PerfSerialWithRandomSize) +{ + auto bigSize = SIZE << 4; + + for (int k = 0; k < 4; ++k) { + bool res = false; + auto address = memalign(NN_NO4096, bigSize); + UBSHcomNetMemoryAllocatorPtr ptr; + UBSHcomNetMemoryAllocatorOptions options; + options.address = reinterpret_cast(address); + options.size = bigSize; + options.minBlockSize = 4096; + UBSHcomNetMemoryAllocator::Create(ock::hcom::DYNAMIC_SIZE, options, ptr); + uint64_t addr = 0; + for (int i = 0; i < 4; ++i) { + auto expectSize = random() % bigSize - 16; + res = ptr->Allocate(expectSize, addr); + ASSERT_EQ(res, NN_OK); + res = ptr->Free(addr); + ASSERT_EQ(res, NN_OK); + } + res = ptr->Allocate(SIZE, addr); + ASSERT_EQ(res, NN_OK); + } +} + +TEST_F(TestMemoryAllocator, CompareToNudeMalloc) +{ + auto size = SIZE << 4; + auto address = memalign(NN_NO4096, SIZE * 16); + uintptr_t addr = 0; + uint64_t block = size / NN_NO256; + uint64_t cost[4] = {0, 0, 0, 0}; + int loopCount = 100; + + UBSHcomNetMemoryAllocatorPtr ptr; + UBSHcomNetMemoryAllocatorOptions options; + options.address = reinterpret_cast(address); + options.size = SIZE * 16; + options.minBlockSize = 4096; + UBSHcomNetMemoryAllocator::Create(ock::hcom::DYNAMIC_SIZE, options, ptr); + for (int i = 0; i < loopCount; ++i) { + auto esize = (random() % block) * NN_NO256; + auto maAllocCost = MONOTONIC_TIME_NS(); + ptr->Allocate(esize, addr); + maAllocCost = MONOTONIC_TIME_NS() - maAllocCost; + cost[0] += maAllocCost; + + auto maFreeCost = MONOTONIC_TIME_NS(); + ptr->Free(addr); + maFreeCost = MONOTONIC_TIME_NS() - maFreeCost; + cost[1] += maFreeCost; + + auto nuAllocCost = MONOTONIC_TIME_NS(); + auto addr1 = malloc(esize); + nuAllocCost = MONOTONIC_TIME_NS() - nuAllocCost; + cost[2] += nuAllocCost; + + auto nuFreeCost = MONOTONIC_TIME_NS(); + free(addr1); + nuFreeCost = MONOTONIC_TIME_NS() - nuFreeCost; + cost[3] += nuFreeCost; + } + + NN_LOG_INFO("ma alloc cost:" << cost[0] / loopCount << "ns, " + << "ma free cost:" << cost[1] / loopCount << "ns, " + << "na free cost:" << cost[2] / loopCount << "ns, " + << "na free cost:" << cost[3] / loopCount << "ns"); +} + +TEST_F(TestMemoryAllocator, PerfConcurrentWithRandomSize) +{ + std::atomic_uint64_t allocCost { 0 }; + std::atomic_uint64_t freeCost { 0 }; + const auto threadCount = 10; + const auto blockCount = 20; + + auto totalSize = SIZE * blockCount * threadCount; + + auto address = memalign(NN_NO4096, totalSize); + + bzero(address, totalSize); + + UBSHcomNetMemoryAllocatorPtr ptr; + UBSHcomNetMemoryAllocatorOptions options; + options.address = reinterpret_cast(address); + options.size = totalSize; + options.minBlockSize = 4096; + UBSHcomNetMemoryAllocator::Create(ock::hcom::DYNAMIC_SIZE, options, ptr); + + std::vector threads; + + for (int i = 0; i < threadCount; ++i) { + threads.emplace_back(ConcurrentRoutine, ptr, blockCount, true, std::ref(allocCost), std::ref(freeCost)); + } + for (int i = 0; i < threadCount; ++i) { + threads[i].join(); + } + + NN_LOG_INFO("alloc avg cost " << allocCost / threadCount / blockCount << "ns" + << " free avg cost " << freeCost / threadCount / blockCount << "ns"); + ptr->Destroy(); + free(address); +} \ No newline at end of file diff --git a/test/llt/testcase/common/test_memory_allocator.h b/test/llt/testcase/common/test_memory_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..9226495de4428cb33118fb2b4dcc025352413b6a --- /dev/null +++ b/test/llt/testcase/common/test_memory_allocator.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TEST_MEMORY_ALLOCATOR_H +#define HCOM_TEST_MEMORY_ALLOCATOR_H + +#include + +class TestMemoryAllocator : public testing::Test { +public: + TestMemoryAllocator(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TEST_MEMORY_ALLOCATOR_H diff --git a/test/llt/testcase/common/test_memory_allocator_cache.cpp b/test/llt/testcase/common/test_memory_allocator_cache.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2ac4dec718dc96d9b8e560d7afd5a3990d7a3d28 --- /dev/null +++ b/test/llt/testcase/common/test_memory_allocator_cache.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include + +#include "hcom.h" +#include "hcom_def.h" +#include "hcom_utils.h" +#include "test_memory_allocator_cache.h" + +#define SIZE (256UL << 12) + +using namespace ock::hcom; + +TestMemoryAllocatorCache::TestMemoryAllocatorCache() {} + +void TestMemoryAllocatorCache::SetUp() {} + +void TestMemoryAllocatorCache::TearDown() {} + +TEST_F(TestMemoryAllocatorCache, AllocateAndFree) +{ + NResult res = NN_OK; + auto startAddress = memalign(NN_NO4096, SIZE); + UBSHcomNetMemoryAllocatorPtr ptr; + UBSHcomNetMemoryAllocatorOptions options; + options.address = reinterpret_cast(startAddress); + options.size = SIZE; + options.minBlockSize = NN_NO4096; + options.alignedAddress = true; + res = UBSHcomNetMemoryAllocator::Create(ock::hcom::DYNAMIC_SIZE_WITH_CACHE, options, ptr); + ASSERT_EQ(res, NN_OK); + uint64_t address = 0; + auto expectSize = NN_NO4096; + res = ptr->Allocate(expectSize, address); + ASSERT_EQ(res, NN_OK); + res = ptr->Free(address); + ASSERT_EQ(res, NN_OK); +} + +TEST_F(TestMemoryAllocatorCache, AllocateOverSizeAndFree) +{ + NResult res = NN_OK; + auto startAddress = memalign(NN_NO4096, SIZE); + UBSHcomNetMemoryAllocatorPtr ptr; + UBSHcomNetMemoryAllocatorOptions options; + options.address = reinterpret_cast(startAddress); + options.size = SIZE; + options.minBlockSize = NN_NO4096; + options.alignedAddress = true; + uint64_t maxBlockSize = options.minBlockSize * options.cacheTierCount; + res = UBSHcomNetMemoryAllocator::Create(ock::hcom::DYNAMIC_SIZE_WITH_CACHE, options, ptr); + ASSERT_EQ(res, NN_OK); + uint64_t address = 0; + auto expectSize = maxBlockSize + 1; + res = ptr->Allocate(expectSize, address); + ASSERT_EQ(res, NN_OK); + res = ptr->Free(address); + ASSERT_EQ(res, NN_OK); +} \ No newline at end of file diff --git a/test/llt/testcase/common/test_memory_allocator_cache.h b/test/llt/testcase/common/test_memory_allocator_cache.h new file mode 100644 index 0000000000000000000000000000000000000000..53e46183d37602503934eb0ae1e950afe8dea706 --- /dev/null +++ b/test/llt/testcase/common/test_memory_allocator_cache.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TEST_MEMORY_ALLOCATOR_CACHE_H +#define HCOM_TEST_MEMORY_ALLOCATOR_CACHE_H + +#include + +class TestMemoryAllocatorCache : public testing::Test { +public: + TestMemoryAllocatorCache(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TEST_MEMORY_ALLOCATOR_CACHE_H diff --git a/test/llt/testcase/common/test_net_crc32.cpp b/test/llt/testcase/common/test_net_crc32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..76a34e3418ba86fe388d91c74fe43decea155606 --- /dev/null +++ b/test/llt/testcase/common/test_net_crc32.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "test_net_crc32.h" +#include "hcom_def.h" +#include "net_crc32.h" + +using namespace ock::hcom; +TestCaseNetCrc32::TestCaseNetCrc32() {} + +void TestCaseNetCrc32::SetUp() {} + +void TestCaseNetCrc32::TearDown() {} + +TEST_F(TestCaseNetCrc32, TestSameString) +{ + std::string buff = "abcdefghijklnmopqrstuvwxyz"; + auto crc1 = NetCrc32::CalcCrc32(buff.data(), buff.size()); + auto crc2 = NetCrc32::CalcCrc32(buff.data(), buff.size()); + + EXPECT_EQ(crc1, crc2); +} + +TEST_F(TestCaseNetCrc32, TestDifferentString) +{ + std::string buff1 = "abcdefghijklnmopqrstuvwxyz0123456789"; + std::string buff2 = "abcdefghijklnmopqrstuvwxyz"; + auto crc1 = NetCrc32::CalcCrc32(buff1.data(), buff1.size()); + auto crc2 = NetCrc32::CalcCrc32(buff2.data(), buff2.size()); + + EXPECT_NE(crc1, crc2); +} diff --git a/test/llt/testcase/common/test_net_crc32.h b/test/llt/testcase/common/test_net_crc32.h new file mode 100644 index 0000000000000000000000000000000000000000..911127f3890a82a1278fa7d35e56e7d80e3a98fb --- /dev/null +++ b/test/llt/testcase/common/test_net_crc32.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TEST_NET_CRC32_H +#define HCOM_TEST_NET_CRC32_H +#include + +class TestCaseNetCrc32 : public testing::Test { +public: + TestCaseNetCrc32(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TEST_NET_CRC32_H diff --git a/test/llt/testcase/common/test_net_execution_service.cpp b/test/llt/testcase/common/test_net_execution_service.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0005550e9b41631061b2b103f9573f8c7e3d6188 --- /dev/null +++ b/test/llt/testcase/common/test_net_execution_service.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "test_net_execution_service.h" + +#include "net_execution_service.h" + +using namespace ock::hcom; + +TestNetExecutionService::TestNetExecutionService() = default; + +void TestNetExecutionService::SetUp() {} + +void TestNetExecutionService::TearDown() {} + +class Task : public NetRunnable { +public: + void Run() override + { + std::cout << "task is executed" << std::endl; + } +}; + +TEST_F(TestNetExecutionService, ExecutionService) +{ + NetExecutorServicePtr es = NetExecutorService::Create(1, 128); + EXPECT_EQ(es.Get() != nullptr, true); + + es->SetThreadName("tt"); + + EXPECT_EQ(es->Start(), true); + + auto t = new (std::nothrow) Task(); + + EXPECT_EQ(es->Execute(t), true); + + sleep(1); + + es->Stop(); +} \ No newline at end of file diff --git a/test/llt/testcase/common/test_net_execution_service.h b/test/llt/testcase/common/test_net_execution_service.h new file mode 100644 index 0000000000000000000000000000000000000000..5d4e7fd0a23ba6cd09f6429ff0e9101c90e0afac --- /dev/null +++ b/test/llt/testcase/common/test_net_execution_service.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HDAGGER_UT_DG_EXECUTION_SERVICE_H +#define HDAGGER_UT_DG_EXECUTION_SERVICE_H + +#include + +class TestNetExecutionService : public testing::Test { +public: + TestNetExecutionService(); + virtual void SetUp(void); + virtual void TearDown(void); + +protected: +}; +#endif // HDAGGER_UT_DG_EXECUTION_SERVICE_H diff --git a/test/llt/testcase/common/test_net_rbtree.cpp b/test/llt/testcase/common/test_net_rbtree.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6504f56e3689488fda478edcf47c36c9e7f82158 --- /dev/null +++ b/test/llt/testcase/common/test_net_rbtree.cpp @@ -0,0 +1,271 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include "test_net_rbtree.h" +#include "hcom_def.h" +#include "common/net_rb_tree.h" +#include "hcom_utils.h" + +using namespace ock::hcom; + +TestNetRbTree::TestNetRbTree() {} + +void TestNetRbTree::SetUp() {} + +void TestNetRbTree::TearDown() {} + +static void InsertToRbTree(NetRbTree *tree, NetRbNode *node) +{ + auto cur = &tree->ref; + NetRbNode *parent = nullptr; + while (*cur != nullptr) { + parent = *cur; + if (node->data < (*cur)->data) { + cur = &(parent->left); + } else if (node->data > (*cur)->data) { + cur = &(parent->right); + } else { + return; + } + } + node->Link(parent, cur); + tree->Insert(node); +} + +static void InsertToRbTree(NetRbTree *tree, int val) +{ + InsertToRbTree(tree, new NetRbNode(val)); +} + +static void CppErase(NetRbTree *tree, int val) +{ + auto cur = tree->ref; + NetRbNode *target = nullptr; + while (cur) { + if (cur->data == val) { + target = cur; + break; + } else if (cur->data > val) { + cur = cur->left; + } else { + cur = cur->right; + } + } + tree->Erase(target); +} + + +static void LevelOrderTraverseDump(NetRbTree *tree) +{ + std::cout << "===========================dump rbtree===========================" << std::endl; + std::vector *> seq { tree->ref }; + std::vector ret; + while (!seq.empty()) { + std::vector *> newSeq; + for (uint i = 0; i < seq.size(); ++i) { + if (!seq[i]) { + std::cout << "null "; + continue; + } + ret.push_back(seq[i]->data); + std::cout << seq[i]->data << " "; + newSeq.push_back(seq[i]->left); + newSeq.push_back(seq[i]->right); + } + seq = newSeq; + std::cout << std::endl; + } +} + +static bool IsEveryRedNodeHasTwoBlackChildren(NetRbNode *node) +{ + if (!node) { + return true; + } + if (node->IsRed()) { + if (node->left == nullptr && node->right == nullptr) { + return true; + } else if (node->left == nullptr || node->right == nullptr) { + return false; + } else if (node->left->IsRed() || node->right->IsRed()) { + return false; + } + } + + if (node->left && !IsEveryRedNodeHasTwoBlackChildren(node->left)) { + return false; + } + if (node->right && !IsEveryRedNodeHasTwoBlackChildren(node->right)) { + return false; + } + return true; +} + +static std::pair ChildrenHasSameBlackHeight(NetRbNode *node) +{ + if (!node) { + return { 0, true }; + } + int leftBH = node->IsBlack() ? 1 : 0; + int rightBH = node->IsBlack() ? 1 : 0; + if (node->left) { + auto ret = ChildrenHasSameBlackHeight(node->left); + if (!ret.second) { + return { 0, false }; + } + leftBH += ret.first; + } + if (node->right) { + auto ret = ChildrenHasSameBlackHeight(node->left); + if (!ret.second) { + return { 0, false }; + } + rightBH += ret.first; + } + return { leftBH, leftBH == rightBH }; +} + +static bool IsInorderAscend(NetRbNode *node) +{ + if (!node) { + return true; + } + + if (node->left) { + if (node->data < node->left->data) { + return false; + } + if (!IsInorderAscend(node->left)) { + return false; + } + } + + if (node->right) { + if (node->data > node->right->data) { + return false; + } + if (!IsInorderAscend(node->right)) { + return false; + } + } + + return true; +} + +static bool IsValidRBTree(NetRbTree *tree) +{ + if (!tree || !tree->ref) { + return true; + } + + /* law1:node must be black or red, no need to check for only 1 bit present for color + * law3:every leaf is black, for we just use nullptr for leaf,do not count it for bh,no need to check + * law3:black root */ + if (tree->ref->IsRed()) { + std::cout << "Invalid RBTree, Red Root" << std::endl; + return false; + } + + /* law4:every red node must has two black children or no child */ + if (!IsEveryRedNodeHasTwoBlackChildren(tree->ref)) { + std::cout << "Invalid RBTree, Red Parent Has Red Child or Single Black Child" << std::endl; + return false; + } + + /* law5:every path from one node has same black height */ + if (!ChildrenHasSameBlackHeight(tree->ref).second) { + std::cout << "Invalid RBTree, BH is not balance" << std::endl; + return false; + } + + /* data check:inorder ascend */ + if (!IsInorderAscend(tree->ref)) { + std::cout << "Invalid RBTree, wrong data order" << std::endl; + return false; + } + + return true; +} + +TEST_F(TestNetRbTree, Serial) +{ + NetRbTree rbTree; + + std::set seeds; + std::set delSeeds; + uint64_t totalTime = 0; + for (int k = 0; k < 10; ++k) { + seeds.clear(); + delSeeds.clear(); + + for (int i = 0; i < 10000; ++i) { + auto v = random() % 10000; + seeds.insert(v); + if (v % 3 == 1) { + delSeeds.insert(v); + } + } + auto cost = MONOTONIC_TIME_NS(); + for (const auto &item : seeds) { + InsertToRbTree(&rbTree, item); + } + cost = MONOTONIC_TIME_NS() - cost; + totalTime += cost; + auto ret = IsValidRBTree(&rbTree); + + ASSERT_EQ(ret, true); + + for (const auto &item : delSeeds) { + CppErase(&rbTree, item); + } + + ret = IsValidRBTree(&rbTree); + + ASSERT_EQ(ret, true); + } + + std::cout << "RbTree insert avg cost:" << totalTime / 10 / 10000 << "ns" << std::endl; +} + +TEST_F(TestNetRbTree, EraseColorLeft) +{ + NetRbTree rbTree; + NetRbNode *root = new NetRbNode(0); + InsertToRbTree(&rbTree, root); + + NetRbNode *node = new NetRbNode(2); + NetRbNode *parent = new NetRbNode(0); + parent->right = nullptr; + + ASSERT_EQ(rbTree.EraseColorLeft(node, parent), false); + + delete node; + delete parent; + delete root; +} + +TEST_F(TestNetRbTree, EraseColorRight) +{ + NetRbTree rbTree; + NetRbNode *root = new NetRbNode(0); + InsertToRbTree(&rbTree, root); + + NetRbNode *node = new NetRbNode(2); + NetRbNode *parent = new NetRbNode(0); + parent->left = nullptr; + + ASSERT_EQ(rbTree.EraseColorRight(node, parent), false); + + delete node; + delete parent; + delete root; +} \ No newline at end of file diff --git a/test/llt/testcase/common/test_net_rbtree.h b/test/llt/testcase/common/test_net_rbtree.h new file mode 100644 index 0000000000000000000000000000000000000000..04f8f446b33d2931e3470c46706d19d7ebbefe04 --- /dev/null +++ b/test/llt/testcase/common/test_net_rbtree.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TEST_NET_RBTREE_H +#define HCOM_TEST_NET_RBTREE_H + +#include + +class TestNetRbTree : public testing::Test { +public: + TestNetRbTree(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TEST_MEMORY_ALLOCATOR_H diff --git a/test/llt/testcase/common/test_net_security_alg.cpp b/test/llt/testcase/common/test_net_security_alg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9e5a81b384b53b520424d44e462ce6d7b286aa34 --- /dev/null +++ b/test/llt/testcase/common/test_net_security_alg.cpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "hcom.h" +#include "common/net_security_alg.h" +#include "common/net_util.h" +#include "test_net_security_alg.hpp" + +using namespace ock::hcom; + +AesGcm128 mAes; +NetSecrets secrets; + +TestNetSecurityAlg::TestNetSecurityAlg() {} + +void TestNetSecurityAlg::SetUp() +{ + EXPECT_EQ(HcomSsl::Load(), 0); + secrets.Init(AES_CCM_128); +} + +void TestNetSecurityAlg::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestNetSecurityAlg, EncryptSuccess) +{ + void *dest = malloc(1024); + std::string value = "hello"; + + uint32_t destLen; + bool result = + mAes.Encrypt(secrets, value.c_str(), value.length(), dest, destLen); + + EXPECT_EQ(true, result); + EXPECT_EQ(mAes.EstimatedEncryptLen(value.length()), destLen); +} + +TEST_F(TestNetSecurityAlg, EncryptKeySecretNullFailed) +{ + void *dest = malloc(1024); + std::string value = "hello"; + const void *keySecrets = nullptr; + MOCKER_CPP(&NetSecrets::GetKeySecret).stubs().will(returnValue(keySecrets)); + uint32_t destLen; + bool result = mAes.Encrypt(secrets, value.c_str(), value.length(), dest, destLen); + + EXPECT_EQ(false, result); +} + +TEST_F(TestNetSecurityAlg, EncryptAadSecretNullFailed) +{ + void *dest = malloc(1024); + std::string value = "hello"; + const void *aadSecrets = nullptr; + MOCKER_CPP(&NetSecrets::GetAADSecret).stubs().will(returnValue(aadSecrets)); + uint32_t destLen; + bool result = mAes.Encrypt(secrets, value.c_str(), value.length(), dest, destLen); + + EXPECT_EQ(false, result); +} + +TEST_F(TestNetSecurityAlg, DecryptCipherLenTooShortFailed) +{ + void *decrypt = malloc(1024); + std::string cipher = "hello"; + + uint32_t decryptRawLen = mAes.GetRawLen(cipher.length()); + bool result = mAes.Decrypt(secrets, cipher.c_str(), cipher.length(), decrypt, decryptRawLen); + + EXPECT_EQ(false, result); +} + +TEST_F(TestNetSecurityAlg, DecryptWrongCipherSuccess) +{ + void *decrypt = malloc(1024); + std::string cipher = "aad iv aes fake cipher of hello world cipher cipher cipher"; + + uint32_t decryptRawLen = mAes.GetRawLen(cipher.length()); + bool result = mAes.Decrypt(secrets, cipher.c_str(), cipher.length(), decrypt, decryptRawLen); + + EXPECT_EQ(true, result); +} + +TEST_F(TestNetSecurityAlg, DecryptSuccess) +{ + void *dest = malloc(1024); + std::string value = "hello"; + + uint32_t destLen = mAes.EstimatedEncryptLen(value.length()); + bool result = + mAes.Encrypt(secrets, value.c_str(), value.length(), dest, destLen); + + EXPECT_EQ(true, result); + EXPECT_EQ(mAes.EstimatedEncryptLen(value.length()), destLen); + + uint32_t decryptRawLen = mAes.GetRawLen(destLen); + void *decrypt = malloc(decryptRawLen); + result = mAes.Decrypt(secrets, dest, destLen, decrypt, decryptRawLen); + + EXPECT_EQ(true, result); + EXPECT_EQ(value.length(), decryptRawLen); + EXPECT_EQ(0, strncmp(value.c_str(), (char *)decrypt, value.length())); +} \ No newline at end of file diff --git a/test/llt/testcase/common/test_net_security_alg.hpp b/test/llt/testcase/common/test_net_security_alg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a293909178b7d672c4233cd47a060210f54244a4 --- /dev/null +++ b/test/llt/testcase/common/test_net_security_alg.hpp @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef _TEST_NET_SECURITY_ALG_HPP_ +#define _TEST_NET_SECURITY_ALG_HPP_ +#include +#include + +class TestNetSecurityAlg : public testing::Test { +public: + TestNetSecurityAlg(); + virtual void SetUp(void); + virtual void TearDown(void); + +protected: +}; +#endif \ No newline at end of file diff --git a/test/llt/testcase/common/test_net_trace.cpp b/test/llt/testcase/common/test_net_trace.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dfb5d5a32116ddf04e74c8b77f441519968b1e85 --- /dev/null +++ b/test/llt/testcase/common/test_net_trace.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "test_net_trace.h" +#include "hcom_def.h" +#include "hcom_log.h" +#include "hcom_service.h" +#include "net_trace.h" + +using namespace ock::hcom; +TestCaseNetTrace::TestCaseNetTrace() {} + +void TestCaseNetTrace::SetUp() {} + +void TestCaseNetTrace::TearDown() {} + +TEST_F(TestCaseNetTrace, TestTraceLevel2) +{ + // NetService::Instance(RDMA, "trace", false); + for (uint32_t i = 0; i < 10; i++) { + TRACE_DELAY_BEGIN(SOCK_WORKER_HANDLE_EPOLL_WRNORM_EVENT); + TRACE_DELAY_END(SOCK_WORKER_HANDLE_EPOLL_WRNORM_EVENT, 0); + } + + TRACE_DELAY_BEGIN(SERVICE_RECONNECT_COMFIRM); + TRACE_DELAY_END(SERVICE_RECONNECT_COMFIRM, 0); + + TRACE_DELAY_BEGIN(SOCK_DRIVER_CREATE_WORKER_RESOURCE); + TRACE_DELAY_END(SOCK_DRIVER_CREATE_WORKER_RESOURCE, 0); + + std::string dumpStr; + dumpStr = NetTrace::TraceDump(); + + NN_LOG_INFO(dumpStr); +} \ No newline at end of file diff --git a/test/llt/testcase/common/test_net_trace.h b/test/llt/testcase/common/test_net_trace.h new file mode 100644 index 0000000000000000000000000000000000000000..5569f1340349e51fd10c39a9a07e1b9159511815 --- /dev/null +++ b/test/llt/testcase/common/test_net_trace.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TEST_NET_TRACE_H +#define HCOM_TEST_NET_TRACE_H +#include + +class TestCaseNetTrace : public testing::Test { +public: + TestCaseNetTrace(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TEST_NET_TRACE_H \ No newline at end of file diff --git a/test/llt/testcase/common/test_obj_pool.cpp b/test/llt/testcase/common/test_obj_pool.cpp new file mode 100644 index 0000000000000000000000000000000000000000..06aed4d3592daa18ea71722176ade643eb059b48 --- /dev/null +++ b/test/llt/testcase/common/test_obj_pool.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "test_obj_pool.h" +#include "hcom_def.h" +#include "net_common.h" +#include "net_obj_pool.h" +#include "ut_helper.h" + +using namespace ock::hcom; +TestCaseObjPool::TestCaseObjPool() {} + +void TestCaseObjPool::SetUp() {} + +void TestCaseObjPool::TearDown() {} + +TEST_F(TestCaseObjPool, NetObjPool_Serial) +{ + NResult result; + NetObjPool pool("test_pool", 2); + result = pool.Initialize(); + EXPECT_EQ(result, NN_OK); + + DummyObj *rObj1, *rObj2, *rObj3; + auto ret = pool.Dequeue(rObj1); + EXPECT_EQ(ret, true); + rObj1->tag = 1; + ret = pool.Dequeue(rObj2); + EXPECT_EQ(ret, true); + rObj2->tag = 2; + ret = pool.Dequeue(rObj3); + EXPECT_EQ(ret, true); + pool.Enqueue(rObj1); + pool.Enqueue(rObj2); + pool.Enqueue(rObj3); + DummyObj *rObj4, *rObj5; + ret = pool.Dequeue(rObj4); + EXPECT_EQ(ret, true); + EXPECT_EQ(rObj4->tag, 2); + ret = pool.Dequeue(rObj5); + EXPECT_EQ(ret, true); + EXPECT_EQ(rObj5->tag, 1); +} + +TEST_F(TestCaseObjPool, NetObjPool_Concurrency) +{ + std::vector objs(11); + + NetObjPool pool("test_pool", 10); + std::vector ths; + int vsum = 0; + for (int i = 0; i < 10; ++i) { + vsum += i; + std::thread th([&]() { pool.Initialize(); }); + ths.push_back(std::move(th)); + } + for (uint64_t i = 0; i < ths.size(); ++i) { + ths[i].join(); + } + ths.clear(); + + for (int i = 0; i < 10; ++i) { + pool.Dequeue(objs[i]); + objs[i]->tag = i; + } + + for (int i = 0; i < 10; ++i) { + std::thread th([&, i]() { pool.Enqueue(objs[i]); }); + ths.push_back(std::move(th)); + } + for (uint64_t i = 0; i < ths.size(); ++i) { + ths[i].join(); + } + int sum = 0; + for (int i = 0; i < 10; ++i) { + DummyObj *obj; + pool.Dequeue(obj); + sum += obj->tag; + } + EXPECT_EQ(sum, vsum); +} \ No newline at end of file diff --git a/test/llt/testcase/common/test_obj_pool.h b/test/llt/testcase/common/test_obj_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..859c15cded113325e63a40e5dcf989262c1b35d2 --- /dev/null +++ b/test/llt/testcase/common/test_obj_pool.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TESTCASE_COMMONS_H +#define HCOM_TESTCASE_COMMONS_H +#include + +class TestCaseObjPool : public testing::Test { +public: + TestCaseObjPool(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TESTCASE_COMMONS_H diff --git a/test/llt/testcase/service/test_service_common.h b/test/llt/testcase/service/test_service_common.h new file mode 100644 index 0000000000000000000000000000000000000000..ba919a5861ad0a555512f975d53389be60821f9b --- /dev/null +++ b/test/llt/testcase/service/test_service_common.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TESTCASE_SERVICE_COMMON_H +#define HCOM_TESTCASE_SERVICE_COMMON_H +#include "hcom_service.h" +using namespace ock::hcom; + +#define TEST_SERVICE_IP "127.0.0.2" +#define TEST_SERVICE_MASK "127.0.0.2/24" +#define TEST_SERVICE_IP1 "127.0.0.3" +#define TEST_SERVICE_MASK1 "127.0.0.3/24" +#define TEST_SERVICE_PORT 11111 +#define TEST_SERVICE_SEG_SIZE 1024 + +#endif // HCOM_TESTCASE_SERVICE_COMMON_H diff --git a/test/llt/testcase/service/test_service_ctx_store.cpp b/test/llt/testcase/service/test_service_ctx_store.cpp new file mode 100644 index 0000000000000000000000000000000000000000..155c5166d46c0b3b6052aaa8dfb1474f3dd40b8e --- /dev/null +++ b/test/llt/testcase/service/test_service_ctx_store.cpp @@ -0,0 +1,129 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "net_mem_pool_fixed.h" +#include "service_ctx_store.h" +#include "hcom_service.h" +#include "net_service_default_imp.h" +#include "net_monotonic.h" +#include "test_service_ctx_store.h" + +using namespace ock::hcom; +TestCaseCtxStore::TestCaseCtxStore() {} + +void TestCaseCtxStore::SetUp() {} + +void TestCaseCtxStore::TearDown() {} + +TEST_F(TestCaseCtxStore, BASIC) +{ + NetMemPoolFixedPtr ctxMemPool; + NetMemPoolFixedOptions options = {}; + options.superBlkSizeMB = NN_NO4; + options.minBlkSize = NN_NO64; + options.tcExpandBlkCnt = NN_NO64; + ctxMemPool = new (std::nothrow) NetMemPoolFixed("test", options); + ASSERT_NE(ctxMemPool.Get(), nullptr); + + auto ret = ctxMemPool->Initialize(); + ASSERT_EQ(ret, 0); + + uint32_t flatSize = NN_NO128; + NetServiceCtxStorePtr ctxStore = new (std::nothrow) NetServiceCtxStore(flatSize, ctxMemPool); + ASSERT_NE(ctxStore.Get(), nullptr); + + ret = ctxStore->Initialize(); + ASSERT_EQ(ret, 0); + + NetSeqNo dumpNetSeq(0); + dumpNetSeq.SetValue(1, 7, 16777214); + NN_LOG_INFO(dumpNetSeq.ToString()); + + uint32_t seqNoFlat[NN_NO128]; + /* set flat full */ + for (uint32_t i = 0; i < flatSize - 1; i++) { + uint32_t seqNoId = 0; + auto result = ctxStore->PutAndGetSeqNo(&seqNoFlat[i], seqNoId); + ASSERT_EQ(result, 0); + + NetSeqNo netSeq(seqNoId); + NN_LOG_TRACE_INFO("flag set i = " << i << ", realSeq = " << netSeq.realSeq << ", value " << &seqNoFlat[i]); + ASSERT_EQ(netSeq.realSeq, (i + 1) % flatSize); + + ASSERT_EQ(netSeq.version, 0u); + ASSERT_EQ(netSeq.fromFlat, 1u); + seqNoFlat[i] = seqNoId; + } + + uint32_t seqNoMap[NN_NO128]; + /* set map, every time seq += 3 */ + for (uint32_t i = 0; i < flatSize / 3; i++) { + uint32_t seqNoId = 0; + auto result = ctxStore->PutAndGetSeqNo(&seqNoMap[i], seqNoId); + ASSERT_EQ(result, 0); + + NetSeqNo netSeq(seqNoId); + ASSERT_EQ(netSeq.realSeq / 3u, i + 1); + + ASSERT_EQ(netSeq.version, 1u); + ASSERT_EQ(netSeq.fromFlat, 0u); + seqNoMap[i] = seqNoId; + } + + for (uint32_t i = 0; i < flatSize - 1; i++) { + uint32_t *seqFlatAdd = nullptr; + + NetSeqNo logNetSeq(seqNoFlat[i]); + NN_LOG_TRACE_INFO("flag get i = " << i << ", realSeq = " << logNetSeq.realSeq << ", value " << &seqNoFlat[i]); + auto result = ctxStore->GetSeqNoAndRemove(seqNoFlat[i], seqFlatAdd); + ASSERT_EQ(result, 0); + ASSERT_EQ(seqFlatAdd, &seqNoFlat[i]); + } + + for (uint32_t i = 0; i < flatSize / 3; i++) { + uint32_t *seqMapAdd = nullptr; + auto result = ctxStore->GetSeqNoAndRemove(seqNoMap[i], seqMapAdd); + ASSERT_EQ(result, 0); + ASSERT_EQ(seqMapAdd, &seqNoMap[i]); + } +} + + +TEST_F(TestCaseCtxStore, PERF) +{ + NetMemPoolFixedPtr ctxMemPool; + NetMemPoolFixedOptions options = {}; + options.superBlkSizeMB = NN_NO4; + options.minBlkSize = NN_NO64; + options.tcExpandBlkCnt = NN_NO256; + ctxMemPool = new (std::nothrow) NetMemPoolFixed("test", options); + ASSERT_NE(ctxMemPool.Get(), nullptr); + + auto ret = ctxMemPool->Initialize(); + ASSERT_EQ(ret, 0); + + uint32_t flatSize = NN_NO1048576; + NetServiceCtxStorePtr ctxStore = new (std::nothrow) NetServiceCtxStore(flatSize, ctxMemPool); + ASSERT_NE(ctxStore.Get(), nullptr); + + ret = ctxStore->Initialize(); + ASSERT_EQ(ret, 0); + + uint32_t seqNoFlat; + uint64_t start = NetMonotonic::TimeUs(); + /* set flat full */ + for (uint32_t i = 0; i < flatSize - 1; i++) { + uint32_t seqNoId = 0; + auto result = ctxStore->PutAndGetSeqNo(&seqNoFlat, seqNoId); + ASSERT_EQ(result, 0); + } + NN_LOG_INFO("Put flat seq no " << flatSize << " cost " << (NetMonotonic::TimeUs() - start) << "us"); +} \ No newline at end of file diff --git a/test/llt/testcase/service/test_service_ctx_store.h b/test/llt/testcase/service/test_service_ctx_store.h new file mode 100644 index 0000000000000000000000000000000000000000..02d1e5be20a9c453e6bcaacf926ecafec9b60e58 --- /dev/null +++ b/test/llt/testcase/service/test_service_ctx_store.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TESTCASE_SERVICE_CTX_H +#define HCOM_TESTCASE_SERVICE_CTX_H +#include + +class TestCaseCtxStore : public testing::Test { +public: + TestCaseCtxStore(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TESTCASE_SERVICE_CTX_H diff --git a/test/llt/testcase/service/test_service_io.cpp b/test/llt/testcase/service/test_service_io.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b8f04871cf1fb276dd071c0c5e1c7bf216acef9 --- /dev/null +++ b/test/llt/testcase/service/test_service_io.cpp @@ -0,0 +1,481 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +#include "net_mem_pool_fixed.h" +#include "service_ctx_store.h" +#include "hcom_service.h" +#include "net_service_default_imp.h" +#include "test_service_io.h" +#include "test_service_common.h" +#include "rdma_common.h" + +using namespace ock::hcom; +TestCaseServiceIO::TestCaseServiceIO() {} +void TestCaseServiceIO::SetUp() +{ + MOCKER(ReadRoCEVersionFromFile).stubs().will(returnValue(0)); + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} +void TestCaseServiceIO::TearDown() +{ + GlobalMockObject::verify(); +} + +int NewChannel(const std::string &ipPort, const NetChannelPtr &ch, const std::string &payload) +{ + NN_LOG_INFO("new channel call from " << ipPort << " payload: " << payload); + return 0; +} + +void BrokenChannel(const NetChannelPtr &ch) +{ + NN_LOG_INFO("ep broken"); +} + +int ReceivedRequest(NetServiceContext &context) +{ + NetServiceMessage message(context.MessageData(), context.MessageDataLen()); + + if (context.OpType() == NetServiceContext::SER_RECEIVED_RAW) { + char *receive = reinterpret_cast(message.data); + if (receive[0] == 0) { + // receive send message + return 0; + } + + // receive call message + NetServiceMessage req = message; + NetCallback *newCallback = NewCallback([](NetServiceContext &context) {}, std::placeholders::_1); + + // post send callback + if ((context.Channel()->SendRaw(req, newCallback, context.RspCtx()) != 0)) { + NN_LOG_ERROR("failed to post message to data to server"); + return -1; + } + return 0; + } + + if (context.OpCode() == 0) { + NN_LOG_TRACE_INFO("receive msg, channel id " << context.Channel()->Id() << ", info " << + reinterpret_cast(context.MessageData())); + } else { + NetServiceMessage req = message; + // send the same message back to verify + NetCallback *newCallback = NewCallback([](NetServiceContext &context) {}, std::placeholders::_1); + + // post send callback + if ((context.Channel()->Send(context.OpInfo(), req, newCallback, context.RspCtx())) != 0) { + NN_LOG_ERROR("failed to post message to data to server"); + return -1; + } + } + return 0; +} + +int PostSendRequest(NetServiceContext context) +{ + return 0; +} +int OneSideDownRequest(NetServiceContext context) +{ + return 0; +} + +UBSHcomNetMemoryAllocatorPtr memPtr = nullptr; +uint64_t memSize = 1024 * 1024 * 128; +void *address = nullptr; + +int RndvAllocate(uint64_t size, uintptr_t &outAddress, uint32_t &outKey) +{ + outKey = memPtr->MrKey(); + return memPtr->Allocate(size, outAddress); +} + +int RndvFree(uintptr_t addressFree) +{ + return memPtr->Free(addressFree); +} + +int RndvHandler(NetServiceRndvContext &ctx) +{ + // step1 direct handle message + + // step2 rsp message + int ret = 0; + NetServiceMessage req(&ret, sizeof(ret)); + NetServiceOpInfo opInfo {}; + NetCallback *newCallback = NewCallback([](NetServiceContext &context) {}, std::placeholders::_1); + if (ctx.ReplyRndv(opInfo, req, newCallback) != 0) { + NN_LOG_ERROR("Reply rndv message failed"); + } + + // step3 free context + ctx.FreeMessage(); + return SER_OK; +} + +bool CreateService(NetService *&service, NetServiceProtocol protocol, const std::string &mask, const std::string &ip, + bool startOob) +{ + if (service != nullptr) { + NN_LOG_ERROR("service already created"); + return false; + } + + if (protocol == UBSHcomNetDriverProtocol::SHM) { + NN_LOG_ERROR("service not support shm protocol"); + return false; + } + + std::string name; + static int nameNeed = 0; + if (startOob) { + name = "test_service_server_"; + name += std::to_string(nameNeed++); + } else { + name = "test_service_client_"; + name += std::to_string(nameNeed++); + } + + service = NetService::Instance(protocol, name, startOob); + if (service == nullptr) { + NN_LOG_ERROR("failed to create service already created"); + return false; + } + + NetServiceOptions options {}; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + options.mrSendReceiveSegSize = TEST_SERVICE_SEG_SIZE; + options.mrSendReceiveSegCount = 32; + options.enableTls = false; + + options.SetNetDeviceIpMask(mask); + options.SetWorkerGroups("1"); + + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + service->SetOobIpAndPort(ip, TEST_SERVICE_PORT); + service->RegisterNewChannelHandler(NewChannel); + service->RegisterChannelBrokenHandler(BrokenChannel, ock::hcom::BROKEN_ALL); + service->RegisterOpReceiveHandler(0, ReceivedRequest); + service->RegisterOpSentHandler(0, PostSendRequest); + service->RegisterOpOneSideHandler(0, OneSideDownRequest); + int result = 0; + if ((result = service->Start(options)) != 0) { + NN_LOG_ERROR("failed to initialize service " << result); + return false; + } + NN_LOG_INFO("service initialized"); + return true; +} + +bool Connect(NetService *service, NetChannelPtr &ch, const std::string &ip, bool selfPoll) +{ + if (service == nullptr) { + NN_LOG_ERROR("service is null"); + return false; + } + + NetServiceConnectOptions options {}; + if (selfPoll) { + options.flags = NET_EP_SELF_POLLING; + } + int result = service->Connect(ip, TEST_SERVICE_PORT, "hello service", ch, options); + if (result != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + return true; +} + +bool RegSglMem(NetService *client, NetService *service, uint32_t dataSize, NetServiceRequest iov[], uint16_t iovSize) +{ + for (uint16_t i = 0; i < iovSize; i++) { + UBSHcomNetMemoryRegionPtr mr; + auto result = client->RegisterMemoryRegion(dataSize, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + iov[i].lAddress = mr->GetAddress(); + iov[i].lKey = mr->GetLKey(); + iov[i].size = dataSize; + + result = service->RegisterMemoryRegion(dataSize, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + + iov[i].rAddress = mr->GetAddress(); + iov[i].rKey = mr->GetLKey(); + } + + return true; +} + +static void TestServiceSends(NetChannelPtr &ch, NetServiceRequest *iov, uint16_t iovCnt) +{ + char data[128]; + NetServiceMessage message(data, sizeof(data)); + NetServiceOpInfo opInfo {}; + opInfo.opCode = 0; + + auto result = ch->Send(opInfo, message, nullptr); + ASSERT_EQ(result, SER_OK); + + data[0] = 0; // mark send + result = ch->SendRaw(message, nullptr); + ASSERT_EQ(result, SER_OK); + + NetServiceSglRequest request; + request.iov = iov; + request.iovCount = iovCnt; + + memset(reinterpret_cast(iov[0].lAddress), 0, 1); // mark send + result = ch->SendRawSgl(request, nullptr); + ASSERT_EQ(result, SER_OK); +} + +static void TestServiceCall(NetChannelPtr &ch, NetServiceRequest *iov, uint16_t iovCnt) +{ + char data1[128]; + char data2[128]; + NetServiceMessage req(data1, sizeof(data1)); + NetServiceMessage rsp(data2, sizeof(data2)); + NetServiceOpInfo reqInfo {}; + NetServiceOpInfo rspInfo {}; + reqInfo.opCode = 1; + reqInfo.timeout = 1; + + auto result = ch->SyncCall(reqInfo, req, rspInfo, rsp); + ASSERT_EQ(result, SER_OK); + + result = memcmp(&reqInfo, &rspInfo, sizeof(NetServiceOpInfo)); + ASSERT_EQ(result, SER_OK); + + data1[0] = 1; // mark call + result = ch->SyncCallRaw(req, rsp); + ASSERT_EQ(result, SER_OK); + + int ret = 0; + sem_t sem; + sem_init(&sem, 0, 0); + NetCallback *newCallback = NewCallback( + [&sem, &ret, &rsp](NetServiceContext &context) { + if (context.Result() != 0 || context.MessageDataLen() != rsp.size) { + NN_LOG_ERROR("Async call result failed or get unwanted message"); + ret = -1; + sem_post(&sem); + return; + } + + memcpy(rsp.data, context.MessageData(), context.MessageDataLen()); + sem_post(&sem); + }, + std::placeholders::_1); + ASSERT_NE(newCallback, nullptr); + + result = ch->AsyncCall(reqInfo, req, newCallback); + + sem_wait(&sem); + sem_destroy(&sem); + ASSERT_EQ(ret, SER_OK); + + // validate data + result = memcmp(data1, data2, sizeof(data1)); + ASSERT_EQ(result, SER_OK); + + NetServiceSglRequest reqSgl; + char *buff = reinterpret_cast(iov[0].lAddress); + buff[0] = 1; // mark call + reqSgl.iov = iov; + reqSgl.iovCount = iovCnt; + + result = ch->SyncCallRawSgl(reqSgl, rsp); + ASSERT_EQ(result, SER_OK); +} + +TEST_F(TestCaseServiceIO, ALL_IO) +{ + setenv("HCOM_TRACE_LEVEL", "2", 1); + NetService *client = nullptr; + NetService *server = nullptr; + + auto result = CreateService(client, UBSHcomNetDriverProtocol::TCP, TEST_SERVICE_MASK, TEST_SERVICE_IP, false); + ASSERT_EQ(result, true); + + result = CreateService(server, UBSHcomNetDriverProtocol::TCP, TEST_SERVICE_MASK, TEST_SERVICE_IP, true); + ASSERT_EQ(result, true); + + NetChannelPtr ch; + result = Connect(client, ch, TEST_SERVICE_IP, false); + ASSERT_EQ(result, true); + + NetServiceRequest iov[NET_SGE_MAX_IOV]; + result = RegSglMem(client, server, NN_NO8, iov, NET_SGE_MAX_IOV); + ASSERT_EQ(result, true); + + TestServiceSends(ch, iov, NET_SGE_MAX_IOV); + TestServiceCall(ch, iov, NET_SGE_MAX_IOV); + + client->Stop(); + server->Stop(); + NetService::DestroyInstance(client->Name()); + NetService::DestroyInstance(server->Name()); + std::string dumpStr; + dumpStr = NetService::TraceLog(); + NN_LOG_INFO(dumpStr); +} + +TEST_F(TestCaseServiceIO, START_FAILED) +{ + NetService *service = nullptr; + + std::string name = "SERVICE"; + + service = NetService::Instance(UBSHcomNetDriverProtocol::TCP, name, true); + + NetServiceOptions options {}; + int result = service->Start(options); + ASSERT_EQ(result, SER_INVALID_PARAM); + + options.enableRndv = true; + service->RegisterNewChannelHandler(NewChannel); + result = service->Start(options); + ASSERT_EQ(result, SER_INVALID_PARAM); + + service->RegisterRndvAllocateHandler(RndvAllocate); + service->RegisterRndvFreeHandler(RndvFree); + result = service->Start(options); + ASSERT_EQ(result, SER_INVALID_PARAM); + + service->RegisterRndvHandler(RndvHandler); + result = service->Start(options); + ASSERT_EQ(result, SER_INVALID_PARAM); + + service->RegisterChannelBrokenHandler(BrokenChannel, ock::hcom::BROKEN_ALL); + options.maxTypeIndexSize = 17; + result = service->Start(options); + ASSERT_EQ(result, SER_INVALID_PARAM); + + options.maxTypeIndexSize = 1; + result = service->Start(options); + ASSERT_EQ(result, SER_INVALID_PARAM); + + service->RegisterOpReceiveHandler(0, ReceivedRequest); + result = service->Start(options); + ASSERT_EQ(result, SER_INVALID_PARAM); + + service->RegisterOpSentHandler(0, PostSendRequest); + result = service->Start(options); + ASSERT_EQ(result, SER_INVALID_PARAM); + + service->RegisterOpOneSideHandler(0, OneSideDownRequest); + options.periodicThreadNum = 0; + result = service->Start(options); + ASSERT_EQ(result, SER_INVALID_PARAM); + + options.periodicThreadNum = 1; + MOCKER(epoll_create).stubs().will(returnValue(-1)); + result = service->Start(options); + ASSERT_EQ(result, NN_INVALID_IP); + GlobalMockObject::verify(); + + MOCKER(::setsockopt).defaults().will(returnValue(-1)); + options.SetNetDeviceIpMask(TEST_SERVICE_MASK); + options.SetWorkerGroups("1"); + service->SetOobIpAndPort(TEST_SERVICE_IP, TEST_SERVICE_PORT); + result = service->Start(options); + ASSERT_EQ(result, NN_OOB_LISTEN_SOCKET_ERROR); + NetService::DestroyInstance(service->Name()); + GlobalMockObject::verify(); +} + +TEST_F(TestCaseServiceIO, CONNECT_FAILED) +{ + NetService *client = nullptr; + NetService *server = nullptr; + + CreateService(client, UBSHcomNetDriverProtocol::TCP, TEST_SERVICE_MASK, TEST_SERVICE_IP, false); + CreateService(server, UBSHcomNetDriverProtocol::TCP, TEST_SERVICE_MASK, TEST_SERVICE_IP, true); + + NetChannelPtr ch; + + NetServiceConnectOptions options {}; + options.epSize = 0; + int result = client->Connect(TEST_SERVICE_IP, TEST_SERVICE_PORT, "hello service", ch, options); + ASSERT_EQ(result, SER_INVALID_PARAM); + + options.epSize = 1; + options.index = 1; + result = client->Connect(TEST_SERVICE_IP, TEST_SERVICE_PORT, "hello service", ch, options); + ASSERT_EQ(result, SER_INVALID_PARAM); + + options.index = 0; + MOCKER_CPP(&ServiceSerializeConnInfo, int (*)(ServiceConnInfo &, const std::string &, std::string &)) + .defaults() + .will(returnObjectList(500)); + result = client->Connect(TEST_SERVICE_IP, TEST_SERVICE_PORT, "hello service", ch, options); + ASSERT_EQ(result, SER_INVALID_PARAM); + GlobalMockObject::verify(); + + options.clientGrpNo = 1; + result = client->Connect(TEST_SERVICE_IP, TEST_SERVICE_PORT, "hello service", ch, options); + ASSERT_EQ(result, NN_ERROR); + + client->Stop(); + NetService::DestroyInstance(client->Name()); + server->Stop(); + NetService::DestroyInstance(server->Name()); +} + +TEST_F(TestCaseServiceIO, CONNECT_MULTI_PROTOCOL) +{ + NetService *client = nullptr; + NetService *server = nullptr; + + auto result = CreateService(client, UBSHcomNetDriverProtocol::TCP, TEST_SERVICE_MASK, TEST_SERVICE_IP, false); + ASSERT_EQ(result, true); + + result = CreateService(server, UBSHcomNetDriverProtocol::TCP, TEST_SERVICE_MASK, TEST_SERVICE_IP, true); + ASSERT_EQ(result, true); + + NetService *client1 = nullptr; + NetService *server1 = nullptr; + + result = CreateService(client1, UBSHcomNetDriverProtocol::RDMA, TEST_SERVICE_MASK1, TEST_SERVICE_IP1, false); + ASSERT_EQ(result, true); + + result = CreateService(server1, UBSHcomNetDriverProtocol::RDMA, TEST_SERVICE_MASK1, TEST_SERVICE_IP1, true); + ASSERT_EQ(result, true); + + NetChannelPtr ch; + result = Connect(client, ch, TEST_SERVICE_IP, false); + ASSERT_EQ(result, true); + + NetChannelPtr ch1; + result = Connect(client1, ch1, TEST_SERVICE_IP1, false); + ASSERT_EQ(result, true); + + client->Stop(); + NetService::DestroyInstance(client->Name()); + server->Stop(); + NetService::DestroyInstance(server->Name()); + client1->Stop(); + NetService::DestroyInstance(client1->Name()); + server1->Stop(); + NetService::DestroyInstance(server1->Name()); +} \ No newline at end of file diff --git a/test/llt/testcase/service/test_service_io.h b/test/llt/testcase/service/test_service_io.h new file mode 100644 index 0000000000000000000000000000000000000000..bdd3b161a7cd5a889e47be9910e6bc1c9d9cc402 --- /dev/null +++ b/test/llt/testcase/service/test_service_io.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TESTCASE_SERVICE_IO_H +#define HCOM_TESTCASE_SERVICE_IO_H +#include + +class TestCaseServiceIO : public testing::Test { +public: + TestCaseServiceIO(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TESTCASE_SERVICE_IO_H diff --git a/test/llt/testcase/test_blocking_queue.cpp b/test/llt/testcase/test_blocking_queue.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d8fbadcbd5fde610583fc126cd9dc96c964d0c3 --- /dev/null +++ b/test/llt/testcase/test_blocking_queue.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include + +#include "test_blocking_queue.h" +#include "hcom_def.h" +#include "net_obj_pool.h" +#include "ut_helper.h" + +using namespace ock::hcom; +TestCaseBlockingQueue::TestCaseBlockingQueue() {} + +void TestCaseBlockingQueue::SetUp() {} + +void TestCaseBlockingQueue::TearDown() {} + +TEST_F(TestCaseBlockingQueue, Serial) +{ + setenv("HCOM_TRACE_LEVEL", "2", 1); + NResult result; + bool ret; + NetBlockingQueue queue(2); + result = queue.Initialize(); + EXPECT_EQ(result, NN_OK); + DummyObj obj0(0), obj1(1), obj2(2); + + ret = queue.Enqueue(obj1); + EXPECT_EQ(ret, true); + ret = queue.EnqueueFirst(obj0); + EXPECT_EQ(ret, true); + ret = queue.Enqueue(obj2); + EXPECT_EQ(ret, false); + ret = queue.EnqueueFirst(obj2); + EXPECT_EQ(ret, false); + + DummyObj obj3, obj4; + ret = queue.Dequeue(obj3); + EXPECT_EQ(ret, true); + EXPECT_EQ(obj3.tag, 0); + ret = queue.Dequeue(obj4); + EXPECT_EQ(ret, true); + EXPECT_EQ(obj4.tag, 1); +} + +TEST_F(TestCaseBlockingQueue, Concurrency) +{ + NResult result; + bool ret; + NetBlockingQueue queue(3); + result = queue.Initialize(); + EXPECT_EQ(result, NN_OK); + DummyObj obj0(0), obj1(1), obj2(2), obj3(3); + + std::thread th([&]() { + DummyObj obj0, obj1, obj2; + bool ret; + ret = queue.Dequeue(obj0); + EXPECT_EQ(ret, true); + EXPECT_EQ(obj0.tag, 0); + ret = queue.Dequeue(obj1); + EXPECT_EQ(ret, true); + EXPECT_EQ(obj1.tag, 1); + ret = queue.Dequeue(obj2); + EXPECT_EQ(ret, true); + EXPECT_EQ(obj2.tag, 2); + }); + + ret = queue.Enqueue(obj0); + EXPECT_EQ(ret, true); + ret = queue.Enqueue(obj1); + EXPECT_EQ(ret, true); + ret = queue.Enqueue(obj2); + EXPECT_EQ(ret, true); + th.join(); + ret = queue.Enqueue(obj0); + EXPECT_EQ(ret, true); + ret = queue.Enqueue(obj1); + EXPECT_EQ(ret, true); + ret = queue.Enqueue(obj2); + EXPECT_EQ(ret, true); + ret = queue.Enqueue(obj3); + EXPECT_EQ(ret, false); +} \ No newline at end of file diff --git a/test/llt/testcase/test_blocking_queue.h b/test/llt/testcase/test_blocking_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..821eeabdd705af2763c27f06c8105645d5b75d74 --- /dev/null +++ b/test/llt/testcase/test_blocking_queue.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TESTCASE_BLOCKING_QUEUE_H +#define HCOM_TESTCASE_BLOCKING_QUEUE_H + +#include + +class TestCaseBlockingQueue : public testing::Test { +public: + TestCaseBlockingQueue(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TESTCASE_BLOCKING_QUEUE_H diff --git a/test/llt/testcase/test_hcom.cpp b/test/llt/testcase/test_hcom.cpp new file mode 100644 index 0000000000000000000000000000000000000000..10c700337193eeca775e905c33d2452751408ac2 --- /dev/null +++ b/test/llt/testcase/test_hcom.cpp @@ -0,0 +1,188 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "test_hcom.hpp" +#include "transport/rdma/verbs/net_rdma_driver.h" +#include "transport/rdma/verbs/net_rdma_driver_oob.h" +#include "common/net_util.h" +#include "ut_helper.h" +#include "net_trace.h" + +using namespace ock::hcom; + +static int NewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload); + return 0; +} + +static void EndPointBroken(const UBSHcomNetEndpointPtr &ep) +{ + NN_LOG_INFO("end point " << ep->Id()); +} + +static int RequestReceived(const UBSHcomNetRequestContext &ctx) +{ + // std::string req((char *)ctx.Message()->Data(), ctx.Header().dataLength); + return 0; +} + +static int RequestPosted(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request posted"); + return 0; +} +static int OneSideDone(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("one side done"); + return 0; +} + +TestHcom::TestHcom() {} +static UBSHcomNetDriverOptions options {}; +void TestHcom::SetUp() +{ +#ifdef RDMA_BUILD_ENABLED + MOCK_VERSION + if (HcomIbv::Load() != 0) { + NN_LOG_ERROR("Failed to load verbs API"); + } +#endif + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; // 只支持EVENT模式 + options.mrSendReceiveSegSize = 1024; + options.mrSendReceiveSegCount = 1024; + options.enableTls = false; + options.SetNetDeviceIpMask(IP_SEG); +} + +void TestHcom::TearDown() +{ + GlobalMockObject::verify(); +} + +static void Log(int level, const char *msg) +{ + struct timeval tv {}; + char strTime[24]; + + gettimeofday(&tv, nullptr); + strftime(strTime, sizeof strTime, "%Y-%m-%d %H:%M:%S.", localtime(&tv.tv_sec)); + + static std::string levelInfo[4] = {"debug", "info", "warn", "error"}; + + std::cout << strTime << tv.tv_usec << " " << levelInfo[level & 3] << " " << msg << " ExteralLogFunc" << std::endl; +} + +TEST_F(TestHcom, InstanceOfTcpProtocolSuccess) +{ + auto tcpDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "tcpServer", true); + UBSHcomNetDriverDeviceInfo deviceInfo; + bool ret = tcpDriver->LocalSupport(ock::hcom::TCP, deviceInfo); + EXPECT_EQ(true, ret); + + tcpDriver->OobIpAndPort(BASE_IP, 9989); + tcpDriver->Initialize(options); + tcpDriver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + tcpDriver->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); + tcpDriver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + tcpDriver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + NResult result = tcpDriver->Initialize(options); + EXPECT_EQ(NNCode::NN_OK, result); + tcpDriver->Stop(); +} + + +TEST_F(TestHcom, InstanceOfUdsProtocolSuccess) +{ + auto udsDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::UDS, "udsSerDriver", true); + UBSHcomNetDriverDeviceInfo deviceInfo; + bool ret = udsDriver->LocalSupport(ock::hcom::UDS, deviceInfo); + EXPECT_EQ(true, ret); + + udsDriver->OobIpAndPort(BASE_IP, 9989); + udsDriver->Initialize(options); + udsDriver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + udsDriver->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); + udsDriver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + udsDriver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + NResult result = udsDriver->Initialize(options); + EXPECT_EQ(NNCode::NN_OK, result); + udsDriver->Stop(); +} + +TEST_F(TestHcom, ExteralLogFunc) +{ + NetLogger::Instance()->SetExternalLogFunction(Log); + NN_LOG_DEBUG("debug_log"); + NN_LOG_INFO("info_log"); + NN_LOG_WARN("warn_log"); + NN_LOG_ERROR("error_log"); + NetLogger::Instance()->SetExternalLogFunction(nullptr); +} + +#ifdef RDMA_BUILD_ENABLED +TEST_F(TestHcom, InstanceOfRDMAProtocolSuccess) +{ + auto rdmaDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "rServer", true); + UBSHcomNetDriverDeviceInfo deviceInfo; + bool ret = rdmaDriver->LocalSupport(ock::hcom::RDMA, deviceInfo); + EXPECT_EQ(true, ret); + + rdmaDriver->OobIpAndPort(BASE_IP, 9989); + rdmaDriver->Initialize(options); + rdmaDriver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + rdmaDriver->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); + rdmaDriver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + rdmaDriver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + NResult result = rdmaDriver->Initialize(options); + EXPECT_EQ(NNCode::NN_OK, result); + rdmaDriver->Stop(); +} +#endif + +TEST_F(TestHcom, InstanceOfOtherProtocolFailed) +{ + UBSHcomNetDriverProtocol driverProtocol; + driverProtocol = (UBSHcomNetDriverProtocol)100; + + auto otherDriver = UBSHcomNetDriver::Instance(driverProtocol, "otherServer", true); + EXPECT_EQ(nullptr, otherDriver); +} + +#ifdef RDMA_BUILD_ENABLED +TEST_F(TestHcom, LocalSupportOtherFailed) +{ + auto otherDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "rServer", true); + UBSHcomNetDriverProtocol driverProtocol; + driverProtocol = (UBSHcomNetDriverProtocol)100; + + UBSHcomNetDriverDeviceInfo deviceInfo; + bool ret = otherDriver->LocalSupport(driverProtocol, deviceInfo); + EXPECT_EQ(false, ret); +} +#endif + +TEST_F(TestHcom, InstanceWithoutHtracer) +{ + auto tcpDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "driver_without_htracer", true); + ASSERT_EQ(NetTrace::gTraceInst, nullptr); + tcpDriver->Stop(); +} + +TEST_F(TestHcom, InstanceWithHtracer) +{ + setenv("HCOM_ENABLE_TRACE", "1", 1); + auto tcpDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "driver_with_htracer", true); + ASSERT_NE(NetTrace::gTraceInst, nullptr); + tcpDriver->Stop(); + unsetenv("HCOM_ENABLE_TRACE"); +} \ No newline at end of file diff --git a/test/llt/testcase/test_hcom.hpp b/test/llt/testcase/test_hcom.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bd3f7bc3fc9de172b7c6d82f2dbfd4fc5c9b6162 --- /dev/null +++ b/test/llt/testcase/test_hcom.hpp @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef _TEST_HCOM_HPP_ +#define _TEST_HCOM_HPP_ +#include +#include + +class TestHcom : public testing::Test { +public: + TestHcom(); + virtual void SetUp(void); + virtual void TearDown(void); + +protected: +}; +#endif \ No newline at end of file diff --git a/test/llt/testcase/test_openssl.cpp b/test/llt/testcase/test_openssl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..52e5ae9e63f0dc8cc4a6928ed9ad7275e6763730 --- /dev/null +++ b/test/llt/testcase/test_openssl.cpp @@ -0,0 +1,654 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include "hcom.h" +#include "openssl_api_wrapper.h" +#include "test_openssl.h" + +using namespace ock::hcom; +TestOpenSsl::TestOpenSsl() {} + +void TestOpenSsl::SetUp() +{ + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestOpenSsl::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestOpenSsl, OpenSslLoadError) +{ + int result = 0; + void *ptr = nullptr; + MOCKER(dlsym).stubs().will(returnValue(ptr)); + result = HcomSsl::Load(); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError1) +{ + int result = 0; + void *ptr = nullptr; + MOCKER(dlopen).stubs().will(returnValue(ptr)); + result = HcomSsl::Load(); + EXPECT_EQ(-1, result); +} + +int openSize = 64; +int times = 1; +TEST_F(TestOpenSsl, OpenSslLoadError2) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(returnValue(ptr1)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError3) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError4) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError5) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError6) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError7) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError8) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError9) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError10) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError11) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError12) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError13) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError14) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError15) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError16) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError17) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError18) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError19) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError20) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError21) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError22) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError23) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError24) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError25) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError26) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError27) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError28) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError29) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError30) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError31) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError32) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError33) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError34) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError35) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError36) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError37) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError38) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError39) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError40) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError41) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError42) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError43) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError44) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError45) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError46) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError47) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError48) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError49) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError50) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError51) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError52) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError53) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError54) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError55) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} + +TEST_F(TestOpenSsl, OpenSslLoadError56) +{ + int result = 0; + void *ptr = nullptr; + void *ptr1 = malloc(openSize); + MOCKER(dlsym).stubs().will(repeat(ptr1, ++times)).then(returnValue(ptr)); + result = HcomSsl::Load(); + free(ptr1); + EXPECT_EQ(-1, result); +} \ No newline at end of file diff --git a/test/llt/testcase/test_openssl.h b/test/llt/testcase/test_openssl.h new file mode 100644 index 0000000000000000000000000000000000000000..931a4dcece20dd7bb1b8fe501e37c0d113b9dd87 --- /dev/null +++ b/test/llt/testcase/test_openssl.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_TEST_OPENSSL_H +#define HCOM_TEST_OPENSSL_H +#include +#include + +class TestOpenSsl : public testing::Test { +public: + TestOpenSsl(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TEST_OPENSSL_H diff --git a/test/llt/testcase/test_pure_functions.cpp b/test/llt/testcase/test_pure_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ace6a905bdc07cfc4039a7481721fc55a64a5992 --- /dev/null +++ b/test/llt/testcase/test_pure_functions.cpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "test_pure_functions.h" + + +TestPureFunctions::TestPureFunctions() {} + +void TestPureFunctions::SetUp() {} + +void TestPureFunctions::TearDown() {} + +#if 0 +#include "hcom_securec.h" +using namespace ock::hcom; +TEST_F(TestPureFunctions, Memcpy_s) +{ + auto src = (int *)malloc(sizeof(int) * 10); + for (int i = 0; i < 10; ++i) { + src[i] = i; + } + auto dst = (int *)malloc(sizeof(int) * 8); + bzero(dst, sizeof(int) * 8); + for (int i = 0; i < 8; ++i) { + dst[i] = 11; + } + + auto ret = memcpy_s(nullptr, 0, nullptr, 0); + EXPECT_EQ(ret, SEC_ERANGE); + + ret = memcpy_s(nullptr, 0x7fffffffUL + 1, nullptr, 0); + EXPECT_EQ(ret, SEC_ERANGE); + + ret = memcpy_s(dst, 0x7fffffffUL + 1, src, 0); + EXPECT_EQ(ret, SEC_ERANGE); + EXPECT_EQ(dst[0], 11); + + ret = memcpy_s(nullptr, 1, src, 0); + EXPECT_EQ(ret, SEC_EINVAL); + + ret = memcpy_s(dst, sizeof(int), nullptr, 0); + EXPECT_EQ(ret, SEC_EINVAL_AND_RESET); + EXPECT_EQ(dst[0], 0); + EXPECT_EQ(dst[1], 11); + + dst[0] = 11; + ret = memcpy_s(dst, sizeof(int), src, sizeof(int) * 2); + EXPECT_EQ(ret, SEC_ERANGE_AND_RESET); + EXPECT_EQ(dst[0], 0); + EXPECT_EQ(dst[1], 11); + + dst[0] = 11; + ret = memcpy_s(dst, sizeof(int) * 2, dst + 1, sizeof(int) * 2); + EXPECT_EQ(ret, SEC_EOVERLAP_AND_RESET); + EXPECT_EQ(dst[0], 0); + EXPECT_EQ(dst[1], 0); + + dst[0] = 11; + dst[1] = 11; + ret = memcpy_s(dst, sizeof(int) * 8, src, sizeof(int) * 4); + EXPECT_EQ(ret, EOK); + EXPECT_EQ(dst[0], 0); + EXPECT_EQ(dst[3], 3); + EXPECT_EQ(dst[4], 11); +} + +TEST_F(TestPureFunctions, Strcpy_s) +{ + char src[8] = "abcdefg"; + char dst[10] = "zzzzzzzzz"; + auto result = strcpy_s(nullptr, 0, nullptr); + EXPECT_EQ(result, SEC_ERANGE); + result = strcpy_s(dst, 0, src); + EXPECT_EQ(result, SEC_ERANGE); + result = strcpy_s(nullptr, 1, nullptr); + EXPECT_EQ(result, SEC_EINVAL); + result = strcpy_s(nullptr, 1, src); + EXPECT_EQ(result, SEC_EINVAL); + result = strcpy_s(dst, 8, nullptr); + EXPECT_EQ(result, SEC_EINVAL_AND_RESET); + EXPECT_EQ(dst[0], '\0'); + EXPECT_EQ(dst[1], 'z'); + dst[0] = 'z'; + result = strcpy_s(dst, 7, src); + EXPECT_EQ(result, SEC_ERANGE_AND_RESET); + EXPECT_EQ(dst[0], '\0'); + EXPECT_EQ(dst[1], 'z'); + dst[0] = 'z'; + result = strcpy_s(dst, 6, &(dst[4])); + EXPECT_EQ(result, SEC_EOVERLAP_AND_RESET); + EXPECT_EQ(dst[0], '\0'); + EXPECT_EQ(dst[1], 'z'); + dst[0] = 'z'; + result = strcpy_s(dst, 8, src); + EXPECT_EQ(result, EOK); + EXPECT_EQ(dst[0], 'a'); + EXPECT_EQ(dst[6], 'g'); + EXPECT_EQ(dst[7], '\0'); +} +#endif \ No newline at end of file diff --git a/test/llt/testcase/test_pure_functions.h b/test/llt/testcase/test_pure_functions.h new file mode 100644 index 0000000000000000000000000000000000000000..ffcdb96b8d604890b54aa7bae14960b31fe24c85 --- /dev/null +++ b/test/llt/testcase/test_pure_functions.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TEST_PURE_FUNCTIONS_H +#define HCOM_TEST_PURE_FUNCTIONS_H + +#include +class TestPureFunctions : public testing::Test { +public: + TestPureFunctions(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + + +#endif // HCOM_TEST_PURE_FUNCTIONS_H diff --git a/test/llt/testcase/test_ringbuffer.cpp b/test/llt/testcase/test_ringbuffer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c5197488f7120b13bbff64ff1f51382996488460 --- /dev/null +++ b/test/llt/testcase/test_ringbuffer.cpp @@ -0,0 +1,172 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include + +#include "hcom.h" +#include "hcom_def.h" +#include "test_ringbuffer.h" + +using namespace ock::hcom; + +TestCaseRingBuffer::TestCaseRingBuffer() {} + +void TestCaseRingBuffer::SetUp() {} + +void TestCaseRingBuffer::TearDown() {} + +TEST_F(TestCaseRingBuffer, NetRingBuffer_Ser_OK) +{ + NResult result; + bool ret; + NetRingBuffer ringBuffer(4); + EXPECT_EQ(ringBuffer.Capacity(), 4); + result = ringBuffer.Initialize(); + EXPECT_EQ(result, NN_OK); + ret = ringBuffer.PushBack(1); + EXPECT_EQ(ret, true); + ringBuffer.PushBack(2); + EXPECT_EQ(ringBuffer.Size(), 2); + int a = 0; + ringBuffer.PopFront(a); + EXPECT_EQ(a, 1); + ringBuffer.PopFront(a); + EXPECT_EQ(a, 2); + EXPECT_EQ(ringBuffer.Size(), 0); + ringBuffer.PushFront(2); + ringBuffer.PushFront(1); + ringBuffer.PushFront(0); + int *b = new int[2]; + ret = ringBuffer.PopFrontN(b, 2); + EXPECT_EQ(ret, true); + EXPECT_EQ(b[0], 0); + EXPECT_EQ(b[1], 1); + EXPECT_EQ(ringBuffer.Size(), 1); + ringBuffer.UnInitialize(); + delete[] b; +} + +TEST_F(TestCaseRingBuffer, NetRingBuffer_Ser_Fail) +{ + NResult result; + bool ret; + + NetRingBuffer zringBuffer(0); + result = zringBuffer.Initialize(); + EXPECT_NE(result, NN_OK); + + NetRingBuffer ringBuffer(2); + result = ringBuffer.Initialize(); + EXPECT_EQ(result, NN_OK); + + ret = ringBuffer.PushBack(0); + EXPECT_EQ(ret, true); + ret = ringBuffer.PushBack(1); + EXPECT_EQ(ret, true); + ret = ringBuffer.PushBack(2); + EXPECT_EQ(ret, false); + EXPECT_EQ(ringBuffer.Size(), 2); + + int a = -1; + ret = ringBuffer.PopFront(a); + EXPECT_EQ(ret, true); + EXPECT_EQ(a, 0); + ret = ringBuffer.PopFront(a); + EXPECT_EQ(ret, true); + EXPECT_EQ(a, 1); + ret = ringBuffer.PopFront(a); + EXPECT_EQ(ret, false); + EXPECT_EQ(a, 1); + EXPECT_EQ(ringBuffer.Size(), 0); + + ringBuffer.PushBack(0); + ringBuffer.PushBack(1); + int *b = new int[2]; + b[0] = -1; + ret = ringBuffer.PopFrontN(b, 3); + EXPECT_EQ(ret, false); + EXPECT_EQ(b[0], -1); + + delete[] b; +} + + +TEST_F(TestCaseRingBuffer, NetRingBuffer_Con_OK) +{ + NResult result; + bool ret = true; + + int count = 100; + int vsum = 0; + + NetRingBuffer ringBuffer(count); + result = ringBuffer.Initialize(); + EXPECT_EQ(result, NN_OK); + + std::vector ths; + for (int i = 0; i < count; ++i) { + vsum += i; + auto v = i; + std::thread th([&, v]() { ret = ringBuffer.PushBack(v) && ret; }); + ths.push_back(std::move(th)); + } + for (int i = 0; i < count; ++i) { + ths[i].join(); + } + EXPECT_EQ(ret, true); + int sum = 0; + int dup = 0; + std::unordered_set reads; + for (int i = 0; i < count; ++i) { + int r = -1; + ret = ringBuffer.PopFront(r) && ret; + if (reads.count(r) > 0) { + dup++; + } + reads.insert(r); + sum += r; + } + EXPECT_EQ(ret, true); + EXPECT_EQ(sum, vsum); + EXPECT_EQ(dup, 0); + + for (int i = 0; i < count; ++i) { + ret = ringBuffer.PushBack(i) && ret; + } + EXPECT_EQ(ret, true); + + ths.clear(); + EXPECT_EQ(ringBuffer.Size(), count); + sum = 0; + dup = 0; + std::mutex locker; + reads.clear(); + ret = true; + for (int i = 0; i < count; ++i) { + std::thread th([&]() { + int r = -1; + ret = ringBuffer.PopFront(r) && ret; + locker.lock(); + if (reads.count(r)) { + dup++; + } + sum += r; + locker.unlock(); + }); + ths.push_back(std::move(th)); + } + for (int i = 0; i < count; ++i) { + ths[i].join(); + } + EXPECT_EQ(dup, 0); + EXPECT_EQ(ret, true); + EXPECT_EQ(sum, vsum); +} \ No newline at end of file diff --git a/test/llt/testcase/test_ringbuffer.h b/test/llt/testcase/test_ringbuffer.h new file mode 100644 index 0000000000000000000000000000000000000000..efe74429062259158ae580863cfa2033610f3bb3 --- /dev/null +++ b/test/llt/testcase/test_ringbuffer.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TESTCASE_RINGBUFFER_H +#define HCOM_TESTCASE_RINGBUFFER_H +#include +class TestCaseRingBuffer : public testing::Test { +public: + TestCaseRingBuffer(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TESTCASE_RINGBUFFER_H diff --git a/test/llt/testcase/test_spinlock.cpp b/test/llt/testcase/test_spinlock.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6ce01de8074a9984b6f5bde0ef43fb1be11840eb --- /dev/null +++ b/test/llt/testcase/test_spinlock.cpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include + +#include "net_obj_pool.h" +#include "test_spinlock.h" + +using namespace ock::hcom; +TestCaseSpinLock::TestCaseSpinLock() {} + +void TestCaseSpinLock::SetUp() {} + +void TestCaseSpinLock::TearDown() {} + +TEST_F(TestCaseSpinLock, Lock) +{ + std::vector ths; + int count = 1000; + int counter = 0; + int scounter = 0; + NetSpinLock lock; + for (int i = 0; i < count; ++i) { + scounter++; + std::thread th([&]() { + lock.Lock(); + counter++; + lock.Unlock(); + }); + ths.push_back(std::move(th)); + } + for (int i = 0; i < count; ++i) { + ths[i].join(); + } + ths.clear(); + EXPECT_EQ(scounter, counter); + for (int i = 0; i < count; ++i) { + scounter++; + std::thread th([&]() { + lock.Lock(); + counter--; + lock.Unlock(); + }); + ths.push_back(std::move(th)); + } + for (int i = 0; i < count; ++i) { + ths[i].join(); + } + EXPECT_EQ(0, counter); +} diff --git a/test/llt/testcase/test_spinlock.h b/test/llt/testcase/test_spinlock.h new file mode 100644 index 0000000000000000000000000000000000000000..516ec8ab76d055274faf57b32c52123877cfb7d4 --- /dev/null +++ b/test/llt/testcase/test_spinlock.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TESTCASE_SPINLOCK_H +#define HCOM_TESTCASE_SPINLOCK_H +#include + +class TestCaseSpinLock : public testing::Test { +public: + TestCaseSpinLock(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TESTCASE_SPINLOCK_H diff --git a/test/llt/testcase/test_thin_classes.cpp b/test/llt/testcase/test_thin_classes.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e22c8fb5d98ffeaf1cab23d6fb62a142dfd5f6de --- /dev/null +++ b/test/llt/testcase/test_thin_classes.cpp @@ -0,0 +1,91 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +#include "net_obj_pool.h" +#include "test_thin_classes.h" +#include "ut_helper.h" + +using namespace ock::hcom; +TestCaseThinClasses::TestCaseThinClasses() {} + +void TestCaseThinClasses::SetUp() {} + +void TestCaseThinClasses::TearDown() {} + +TEST_F(TestCaseThinClasses, NetUId) +{ + int count = 1000; + std::unordered_set set; + for (int i = 0; i < count; ++i) { + auto id = NetUuid::GenerateUuid(); + EXPECT_EQ(set.count(id), 0); + set.insert(id); + } + set.clear(); + + const std::string ip = "10.10.1.14"; + std::mutex locker; + std::vector ths; + for (int i = 0; i < count; ++i) { + std::thread th([&]() { + auto id = NetUuid::GenerateUuid(ip); + locker.lock(); + EXPECT_EQ(set.count(id), 0); + set.insert(id); + locker.unlock(); + }); + ths.push_back(std::move(th)); + } + for (int i = 0; i < count; ++i) { + ths[i].join(); + } +} + +TEST_F(TestCaseThinClasses, NetRef) +{ + OBJ_LIFE_CYCLE olc(NONE), olc1(NONE), olc2(NONE); + auto obj = new NoisyObj(olc); + auto obj1 = new NoisyObj(olc1); + EXPECT_EQ(olc, INIT); + EXPECT_EQ(obj->GetRef(), 0); + NetRef ref(obj); + EXPECT_EQ(obj->GetRef(), 1); + ref.Set(obj1); + EXPECT_EQ(obj1->GetRef(), 1); + EXPECT_EQ(olc, DEINIT); + { + auto obj2 = new NoisyObj(olc2); + EXPECT_EQ(obj2->GetRef(), 0); + NetRef ref2(obj2); + EXPECT_EQ(obj2->GetRef(), 1); + } + EXPECT_EQ(olc2, DEINIT); +} + +TEST_F(TestCaseThinClasses, UBSHcomNetAtomicState) +{ + UBSHcomNetAtomicState atomicState(2); + EXPECT_EQ(atomicState.Get(), 2); + auto ret = atomicState.Compare(2); + EXPECT_EQ(ret, true); + ret = atomicState.Compare(1); + EXPECT_EQ(ret, false); + atomicState.Set(1); + EXPECT_EQ(atomicState.Get(), 1); + ret = atomicState.CAS(0, 1); + EXPECT_EQ(ret, false); + ret = atomicState.CAS(1, 2); + EXPECT_EQ(ret, true); + EXPECT_EQ(atomicState.Get(), 2); +} \ No newline at end of file diff --git a/test/llt/testcase/test_thin_classes.h b/test/llt/testcase/test_thin_classes.h new file mode 100644 index 0000000000000000000000000000000000000000..970f73ef71364ed4e685c867a973a3aab59d7af9 --- /dev/null +++ b/test/llt/testcase/test_thin_classes.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TESTCASE_THIN_CLASSES_H +#define HCOM_TESTCASE_THIN_CLASSES_H + +#include + +class TestCaseThinClasses : public testing::Test { +public: + TestCaseThinClasses(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TESTCASE_THIN_CLASSES_H diff --git a/test/llt/testcase/transport/rdma/test_negative_rdma_driver.cpp b/test/llt/testcase/transport/rdma/test_negative_rdma_driver.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8f19a8cb0ef1263d78b5d9a086413424f473d9ed --- /dev/null +++ b/test/llt/testcase/transport/rdma/test_negative_rdma_driver.cpp @@ -0,0 +1,168 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifdef RDMA_BUILD_ENABLED +#include "test_negative_rdma_driver.h" +#include "mockcpp/mockcpp.hpp" +#include "ut_helper.h" + +using namespace ock::hcom; + +TestNegativeRdmaDriver::TestNegativeRdmaDriver() {} + +void TestNegativeRdmaDriver::SetUp() +{ + MOCK_VERSION + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestNegativeRdmaDriver::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestNegativeRdmaDriver, FakeBusyPolling) +{ + NResult result; + UBSHcomNetDriver *driver, *server; + result = UTHelper::GetDriver(server, DRIVER_STATE_START, true); + UT_CHECK_RESULT_OK(result) +} + +TEST_F(TestNegativeRdmaDriver, UseBeforeInit) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriver *driver, *server; + result = UTHelper::GetDriver(server, DRIVER_STATE_START, true); + UT_CHECK_RESULT_OK(result) + UT_CHECK_RESULT_NOT_NULL(server) + result = UTHelper::GetDriver(driver, DRIVER_STATE_NONE, false); + UT_CHECK_RESULT_OK(result) + UT_CHECK_RESULT_NOT_NULL(driver) + + UBSHcomNetMemoryRegionPtr mr; + result = driver->CreateMemoryRegion(NN_NO1024, mr); + UT_CHECK_RESULT_NOK(result) + result = driver->Start(); + UT_CHECK_RESULT_NOK(result) + result = driver->Connect("halo", ep); + UT_CHECK_RESULT_NOK(result) + result = UTHelper::ForwardDriverStateMask(driver, DRIVER_STATE_INIT); + UT_CHECK_RESULT_OK(result) + result = driver->Connect("halo", ep); + UT_CHECK_RESULT_NOK(result) + result = UTHelper::ForwardDriverStateMask(driver, DRIVER_STATE_START); + UT_CHECK_RESULT_OK(result) + result = driver->Connect("halo", ep); + UT_CHECK_RESULT_OK(result) + result = UTHelper::ForwardDriverStateMask(driver, DRIVER_STATE_STOP | DRIVER_STATE_UNINIT); + UT_CHECK_RESULT_OK(result) + result = UTHelper::ForwardDriverStateMask(server, DRIVER_STATE_STOP | DRIVER_STATE_UNINIT); + UT_CHECK_RESULT_OK(result) + UBSHcomNetDriver::DestroyInstance(server->Name()); + UBSHcomNetDriver::DestroyInstance(driver->Name()); +} + +TEST_F(TestNegativeRdmaDriver, DestroyUnownedMr) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriver *driver, *driver1; + + result = UTHelper::GetDriver(driver, DRIVER_STATE_INIT, false); + UT_CHECK_RESULT_OK(result) + + result = UTHelper::GetDriver(driver1, DRIVER_STATE_INIT, false); + UT_CHECK_RESULT_OK(result) + + UBSHcomNetMemoryRegionPtr mr1; + result = driver1->CreateMemoryRegion(NN_NO1024, mr1); + EXPECT_EQ(result, NN_OK); + + driver->DestroyMemoryRegion(mr1); + EXPECT_NE(mr1.Get()->GetAddress(), 0); + + driver->UnInitialize(); + driver1->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(driver->Name()); + UBSHcomNetDriver::DestroyInstance(driver1->Name()); +} + +TEST_F(TestNegativeRdmaDriver, UseAfterStop) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriver *driver, *server; + result = UTHelper::GetDriver(server, DRIVER_STATE_STOP, true); + UT_CHECK_RESULT_OK(result) + result = UTHelper::GetDriver(driver, DRIVER_STATE_STOP, false); + UT_CHECK_RESULT_OK(result) + UT_CHECK_RESULT_FALSE(driver->IsStarted()) + result = driver->Start(); + UT_CHECK_RESULT_OK(result) + UT_CHECK_RESULT_TRUE(driver->IsStarted()) + UBSHcomNetDriver::DestroyInstance(server->Name()); + UBSHcomNetDriver::DestroyInstance(driver->Name()); +} + +TEST_F(TestNegativeRdmaDriver, DiscontinuousState) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriver *driver; + result = UTHelper::GetDriverStateMask(driver, DRIVER_STATE_INIT | DRIVER_STATE_START | DRIVER_STATE_UNINIT, false); + UT_CHECK_RESULT_OK(result) + UT_CHECK_RESULT_TRUE(driver->IsStarted()) + UT_CHECK_RESULT_TRUE(driver->IsInited()) + driver->Stop(); + UT_CHECK_RESULT_FALSE(driver->IsStarted()) + UT_CHECK_RESULT_TRUE(driver->IsInited()) + UBSHcomNetDriver::DestroyInstance(driver->Name()); +} + +TEST_F(TestNegativeRdmaDriver, UseAfterUninit) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriver *driver, *server; + result = UTHelper::GetDriver(server, DRIVER_STATE_START, true); + UT_CHECK_RESULT_OK(result) + result = UTHelper::GetDriver(driver, DRIVER_STATE_UNINIT, false); + UT_CHECK_RESULT_OK(result) + UT_CHECK_RESULT_FALSE(driver->IsStarted()) + UT_CHECK_RESULT_FALSE(driver->IsInited()) + + UBSHcomNetMemoryRegionPtr mr; + result = driver->CreateMemoryRegion(NN_NO1024, mr); + UT_CHECK_RESULT_NOK(result) + result = driver->Start(); + UT_CHECK_RESULT_NOK(result) + result = driver->Connect("halo", ep); + UT_CHECK_RESULT_NOK(result) + + result = UTHelper::ForwardDriverStateMask(driver, DRIVER_STATE_INIT); + UT_CHECK_RESULT_OK(result) + result = driver->Connect("halo", ep); + UT_CHECK_RESULT_NOK(result) + result = UTHelper::ForwardDriverStateMask(driver, DRIVER_STATE_START); + UT_CHECK_RESULT_OK(result) + result = driver->Connect("halo", ep); + UT_CHECK_RESULT_OK(result) + result = UTHelper::ForwardDriverStateMask(driver, DRIVER_STATE_STOP | DRIVER_STATE_UNINIT); + UT_CHECK_RESULT_OK(result) + result = UTHelper::ForwardDriverStateMask(server, DRIVER_STATE_STOP | DRIVER_STATE_UNINIT); + UT_CHECK_RESULT_OK(result) + UBSHcomNetDriver::DestroyInstance(server->Name()); + UBSHcomNetDriver::DestroyInstance(driver->Name()); +} +#endif \ No newline at end of file diff --git a/test/llt/testcase/transport/rdma/test_negative_rdma_driver.h b/test/llt/testcase/transport/rdma/test_negative_rdma_driver.h new file mode 100644 index 0000000000000000000000000000000000000000..e64abfa8934fded9b163e90b74d6faf08a1e1c7c --- /dev/null +++ b/test/llt/testcase/transport/rdma/test_negative_rdma_driver.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_LLT_TEST_NEGATIVE_CASES_H +#define HCOM_LLT_TEST_NEGATIVE_CASES_H +#include + +class TestNegativeRdmaDriver : public testing::Test { +public: + TestNegativeRdmaDriver(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_LLT_TEST_NEGATIVE_CASES_H diff --git a/test/llt/testcase/transport/rdma/test_negative_rdma_endpoint.cpp b/test/llt/testcase/transport/rdma/test_negative_rdma_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6f651f2a30dcf1088378842803767d58256a5efb --- /dev/null +++ b/test/llt/testcase/transport/rdma/test_negative_rdma_endpoint.cpp @@ -0,0 +1,297 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#include + +#include "test_negative_rdma_endpoint.h" +#include "mockcpp/mockcpp.hpp" +#include "ut_helper.h" + +#define OVERLOAD_CONNECT 64 +using namespace ock::hcom; + +TestNegativeRdmaEndpoint::TestNegativeRdmaEndpoint() {} + +void TestNegativeRdmaEndpoint::SetUp() +{ + MOCK_VERSION + UTHelper::GetDriver(server, DRIVER_STATE_START, true); + UTHelper::GetDriver(client, DRIVER_STATE_START, false); + auto result = client->CreateMemoryRegion(NN_NO1024, clientMr); + UT_CHECK_RESULT_OK(result) + result = server->CreateMemoryRegion(NN_NO1024, serverMr); + UT_CHECK_RESULT_OK(result) + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestNegativeRdmaEndpoint::TearDown() +{ + GlobalMockObject::verify(); + UTHelper::ForwardDriverStateMask(server, DRIVER_STATE_STOP | DRIVER_STATE_UNINIT); + UTHelper::ForwardDriverStateMask(client, DRIVER_STATE_STOP | DRIVER_STATE_UNINIT); + UBSHcomNetDriver::DestroyInstance(server->Name()); + UBSHcomNetDriver::DestroyInstance(client->Name()); +} + +static void InvalidRequestForReadWrite(UBSHcomNetEndpointPtr &ep, UBSHcomNetTransRequest &req) +{ + NResult result = ep->PostWrite(req); + UT_CHECK_RESULT_NOK(result) + + result = ep->PostRead(req); + UT_CHECK_RESULT_NOK(result) +} + +static void InvalidRequestForSend(UBSHcomNetEndpointPtr &ep, UBSHcomNetTransRequest &req, uint16_t opCode, + uint32_t seqNo, UBSHcomNetTransOpInfo &opInfo) +{ + NResult result = ep->PostSend(opCode, req); + UT_CHECK_RESULT_NOK(result) + + result = ep->PostSend(opCode, req, seqNo); + UT_CHECK_RESULT_NOK(result) + + result = ep->PostSend(opCode, req, opInfo); + UT_CHECK_RESULT_NOK(result) +} + +static void InvalidRequestForReadWriteSgl(UBSHcomNetEndpointPtr &ep, UBSHcomNetTransSglRequest &req) +{ + NResult result = ep->PostWrite(req); + UT_CHECK_RESULT_NOK(result) + + result = ep->PostRead(req); + UT_CHECK_RESULT_NOK(result) +} + +static void InvalidRequestForSendSgl(UBSHcomNetEndpointPtr &ep, UBSHcomNetTransSglRequest &req, uint32_t seqNo) +{ + NResult result = ep->PostSendRawSgl(req, seqNo); + UT_CHECK_RESULT_NOK(result) +} + +TEST_F(TestNegativeRdmaEndpoint, AsyncEpBadReq) +{ + NResult result; + UBSHcomNetEndpointPtr ep; + client->Connect("haha", ep); + + // invalid sgl test + UBSHcomNetTransSglRequest reqSgl; + reqSgl.upCtxSize = sizeof(UBSHcomNetTransSglRequest::upCtxData) + 1; + UBSHcomNetTransSgeIov iov; + InvalidRequestForReadWriteSgl(ep, reqSgl); + InvalidRequestForSendSgl(ep, reqSgl, 1); + + reqSgl.iov = &iov; + InvalidRequestForReadWriteSgl(ep, reqSgl); + InvalidRequestForSendSgl(ep, reqSgl, 1); + + iov.size = 100; + InvalidRequestForReadWriteSgl(ep, reqSgl); + InvalidRequestForSendSgl(ep, reqSgl, 1); + + reqSgl.iovCount = 1; + InvalidRequestForReadWriteSgl(ep, reqSgl); + InvalidRequestForSendSgl(ep, reqSgl, 1); + + reqSgl.upCtxSize = 0; + InvalidRequestForReadWriteSgl(ep, reqSgl); + InvalidRequestForSendSgl(ep, reqSgl, 1); + + iov.lAddress = clientMr->GetAddress(); + InvalidRequestForReadWriteSgl(ep, reqSgl); + InvalidRequestForSendSgl(ep, reqSgl, 1); + + iov.lKey = clientMr->GetLKey(); + InvalidRequestForReadWriteSgl(ep, reqSgl); + + iov.rAddress = serverMr->GetAddress(); + InvalidRequestForReadWriteSgl(ep, reqSgl); + + // invalid request test + UBSHcomNetTransRequest req; + UBSHcomNetTransOpInfo opInfo; + req.upCtxSize = sizeof(UBSHcomNetTransRequest::upCtxData) + 1; + + InvalidRequestForReadWrite(ep, req); + InvalidRequestForSend(ep, req, 0, 0, opInfo); + + req.lAddress = clientMr->GetAddress(); + InvalidRequestForReadWrite(ep, req); + InvalidRequestForSend(ep, req, 0, 0, opInfo); + + req.size = NN_NO100; + InvalidRequestForReadWrite(ep, req); + InvalidRequestForSend(ep, req, 0, 0, opInfo); + + req.upCtxSize = 0; + InvalidRequestForReadWrite(ep, req); + + sem_t sem; + sem_init(&sem, 0, 0); + NResult asyncRes = NN_OK; + // valid lkey rkey success + client->RegisterOneSideDoneHandler([&](const UBSHcomNetRequestContext &ctx) { + asyncRes = ctx.Result(); + sem_post(&sem); + return 0; + }); + + req.rAddress = serverMr->GetAddress(); + req.lKey = clientMr->GetLKey(); + req.rKey = serverMr->GetLKey(); + result = ep->PostWrite(req); + UT_CHECK_RESULT_OK(result) + sem_wait(&sem); + UT_CHECK_RESULT_OK(asyncRes) + + std::string value = "hello world"; + UBSHcomNetTransRequest reqValid((void *)(const_cast(value.c_str())), value.length(), 0); + result = ep->PostSend(0, reqValid); + UT_CHECK_RESULT_OK(result) +} + +TEST_F(TestNegativeRdmaEndpoint, SyncEpBadReq) +{ + NResult result; + UBSHcomNetEndpointPtr ep; + client->Connect("haha", ep, NET_EP_SELF_POLLING); + + UBSHcomNetTransRequest onesideReq; + onesideReq.lAddress = clientMr->GetAddress(); + onesideReq.rAddress = serverMr->GetAddress(); + onesideReq.lKey = 0; + onesideReq.rKey = 0; + onesideReq.size = NN_NO100; + onesideReq.upCtxSize = 0; + + // invalid lkey rkey failed + result = ep->PostWrite(onesideReq); + UT_CHECK_RESULT_NOK(result) + result = ep->WaitCompletion(1); + UT_CHECK_RESULT_NOK(result) + + result = ep->PostRead(onesideReq); + UT_CHECK_RESULT_NOK(result) + result = ep->WaitCompletion(1); + UT_CHECK_RESULT_NOK(result) + + // valid lkey rkey success + onesideReq.rKey = serverMr->GetLKey(); + onesideReq.lKey = clientMr->GetLKey(); + result = ep->PostWrite(onesideReq); + UT_CHECK_RESULT_OK(result) + result = ep->WaitCompletion(1); + UT_CHECK_RESULT_OK(result) + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + result = ep->PostSend(0, req); + UT_CHECK_RESULT_OK(result) + + result = ep->WaitCompletion(1); + UT_CHECK_RESULT_OK(result) + + UBSHcomNetResponseContext respCtx {}; + result = ep->Receive(1, respCtx); + UT_CHECK_RESULT_NOK(result) +} + +TEST_F(TestNegativeRdmaEndpoint, AsyncEpUseAfterStopFailed) +{ + NResult result; + UBSHcomNetEndpointPtr ep; + client->Connect("haha", ep); + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + + result = ep->PostSend(0, req); + UT_CHECK_RESULT_OK(result) + client->Stop(); + + result = ep->PostSend(0, req); + UT_CHECK_RESULT_NOK(result) +} + +TEST_F(TestNegativeRdmaEndpoint, SyncEpUseAfterStopFailed) +{ + NResult result; + UBSHcomNetEndpointPtr ep; + client->Connect("haha", ep, NET_EP_SELF_POLLING); + client->Stop(); + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + result = ep->PostSend(0, req); + UT_CHECK_RESULT_NOK(result) + result = ep->WaitCompletion(2); + UT_CHECK_RESULT_NOK(result) +} + +TEST_F(TestNegativeRdmaEndpoint, AsyncOverload) +{ + NResult result; + for (int i = 0; i < OVERLOAD_CONNECT; ++i) { + UBSHcomNetEndpointPtr ep; + result = client->Connect("haha", ep); + if (result == NN_OK) { + ASSERT_NE(ep.Get(), nullptr); + client->Stop(); + } + } +} + +TEST_F(TestNegativeRdmaEndpoint, SyncOverload) +{ + NResult result; + for (int i = 0; i < OVERLOAD_CONNECT; ++i) { + UBSHcomNetEndpointPtr ep; + result = client->Connect("haha", ep, NET_EP_EVENT_POLLING); + if (result == NN_OK) { + ASSERT_NE(ep.Get(), nullptr); + } + } +} + +TEST_F(TestNegativeRdmaEndpoint, AsyncOverloadPost) +{ + NResult result; + UBSHcomNetEndpointPtr ep; + client->Connect("haha", ep, NET_EP_EVENT_POLLING); + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + ASSERT_NE(ep.Get(), nullptr); + for (int i = 0; i < OVERLOAD_CONNECT; ++i) { + result = ep->PostSend(0, req); + UT_CHECK_RESULT_OK(result) + } +} + +TEST_F(TestNegativeRdmaEndpoint, SyncOverloadPost) +{ + NResult result; + UBSHcomNetEndpointPtr ep; + client->Connect("haha", ep, NET_EP_EVENT_POLLING); + ASSERT_NE(ep.Get(), nullptr); + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + + for (int i = 0; i < OVERLOAD_CONNECT; ++i) { + result = ep->PostSend(0, req); + UT_CHECK_RESULT_OK(result) + } +} + +#endif \ No newline at end of file diff --git a/test/llt/testcase/transport/rdma/test_negative_rdma_endpoint.h b/test/llt/testcase/transport/rdma/test_negative_rdma_endpoint.h new file mode 100644 index 0000000000000000000000000000000000000000..99fe2500c7f49dbf8bf973dbcd9045d427f3c36c --- /dev/null +++ b/test/llt/testcase/transport/rdma/test_negative_rdma_endpoint.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_LLT_TEST_NEGATIVE_RDMA_ENDPOINT_H +#define HCOM_LLT_TEST_NEGATIVE_RDMA_ENDPOINT_H +#include +#include "hcom.h" +using namespace ock::hcom; +class TestNegativeRdmaEndpoint : public testing::Test { +protected: + UBSHcomNetDriver *client = nullptr; + UBSHcomNetDriver *server = nullptr; + UBSHcomNetMemoryRegionPtr clientMr; + UBSHcomNetMemoryRegionPtr serverMr; + +public: + TestNegativeRdmaEndpoint(); + virtual void SetUp(void); + virtual void TearDown(void); +}; +#endif // HCOM_LLT_TEST_NEGATIVE_RDMA_ENDPOINT_H diff --git a/test/llt/testcase/transport/rdma/test_net_rdma_driver.cpp b/test/llt/testcase/transport/rdma/test_net_rdma_driver.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5d9fabeaddea6855da2a0a329da96446efddc84b --- /dev/null +++ b/test/llt/testcase/transport/rdma/test_net_rdma_driver.cpp @@ -0,0 +1,758 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifdef RDMA_BUILD_ENABLED +#include "transport/rdma/verbs/net_rdma_driver.h" +#include "transport/rdma/verbs/net_rdma_driver_oob.h" +#include "transport/rdma/verbs/net_rdma_async_endpoint.h" +#include "common/net_util.h" +#include "test_net_rdma_driver.hpp" +#include "ut_helper.h" + +using namespace ock::hcom; + +int NewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload); + return 0; +} + +void EndPointBroken(const UBSHcomNetEndpointPtr &ep) +{ + NN_LOG_INFO("end point " << ep->Id()); +} + +int RequestReceived(const UBSHcomNetRequestContext &ctx) +{ + return 0; +} + +int RequestPosted(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request posted"); + return 0; +} +int OneSideDone(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("one side done"); + return 0; +} + +TestNetDriverRDMA::TestNetDriverRDMA() {} + +UBSHcomNetDriver *driver = nullptr; +UBSHcomNetDriverOptions options{}; + +UBSHcomNetDriver *CreateDriver(std::string name, bool isOobServer, UBSHcomNetDriverSecType secType = NET_SEC_DISABLED) +{ + UBSHcomNetDriver *innerDriver = nullptr; + std::string ipSeg = IP_SEG; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; // 只支持EVENT模式 + options.mrSendReceiveSegSize = 1024; + options.mrSendReceiveSegCount = 1024; + options.secType = secType; + options.enableTls = false; + options.SetNetDeviceIpMask(ipSeg); + + innerDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, name, isOobServer); + innerDriver->OobIpAndPort(BASE_IP, 9989); + + innerDriver->RegisterNewEPHandler( + std::bind(&NewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + innerDriver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + innerDriver->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); + innerDriver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + innerDriver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + return innerDriver; +} + +void TestNetDriverRDMA::SetUp() +{ + MOCK_VERSION + if (HcomIbv::Load() != 0) { + NN_LOG_ERROR("Failed to load verbs API"); + } + driver = CreateDriver("rdmaServer1", true); + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestNetDriverRDMA::TearDown() +{ + GlobalMockObject::verify(); + driver->Stop(); + driver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(driver->Name()); +} + +TEST_F(TestNetDriverRDMA, InitSuccess) +{ + NResult result = driver->Initialize(options); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetDriverRDMA, InitTwiceSuccess) +{ + NResult result = driver->Initialize(options); + result = driver->Initialize(options); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetDriverRDMA, InitSizeOverSuccess) +{ + options.prePostReceiveSizePerQP = NN_NO1024; + options.maxPostSendCountPerQP = NN_NO1024; + NResult result = driver->Initialize(options); + EXPECT_EQ(NNCode::NN_OK, result); + + options.prePostReceiveSizePerQP = NN_NO64; + options.maxPostSendCountPerQP = NN_NO64; +} + +TEST_F(TestNetDriverRDMA, InitSizeZeroFail) +{ + options.prePostReceiveSizePerQP = 0; + NResult result = driver->Initialize(options); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); + + options.prePostReceiveSizePerQP = NN_NO64; +} + +TEST_F(TestNetDriverRDMA, InitPostSendCountZeroFail) +{ + options.maxPostSendCountPerQP = 0; + NResult result = driver->Initialize(options); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); + + options.maxPostSendCountPerQP = NN_NO64; +} + +TEST_F(TestNetDriverRDMA, InitWithoutIpMaskFailed) +{ + options.SetNetDeviceIpMask(""); + NResult result = driver->Initialize(options); + EXPECT_EQ(NNCode::NN_INVALID_IP, result); + options.SetNetDeviceIpMask(IP_SEG); +} + +TEST_F(TestNetDriverRDMA, InitWithInvalidTypeFailed) +{ + NetDriverOobType x; + x = (NetDriverOobType)9; + options.oobType = x; + NResult result = driver->Initialize(options); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); + options.oobType = ock::hcom::NET_OOB_TCP; +} + +TEST_F(TestNetDriverRDMA, StartSuccess) +{ + options.eventPollingTimeout = 1000; + NResult result = driver->Initialize(options); + result = driver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetDriverRDMA, StartTwiceSuccess) +{ + NResult result = driver->Initialize(options); + result = driver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + result = driver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetDriverRDMA, StartWithoutStartWorkersSuccess) +{ + options.dontStartWorkers = true; + NResult result = driver->Initialize(options); + + result = driver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + options.dontStartWorkers = false; + + driver->DumpObjectStatistics(); +} + +TEST_F(TestNetDriverRDMA, StartWithoutInitializeFailed) +{ + NResult result = driver->Start(); + EXPECT_EQ(NNCode::NN_ERROR, result); +} + +TEST_F(TestNetDriverRDMA, StartWithoutNewEPHdFailed) +{ + NResult result = driver->Initialize(options); + driver->RegisterNewEPHandler(nullptr); + + result = driver->Start(); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestNetDriverRDMA, StartWithoutEPBrokenHdFailed) +{ + NResult result = driver->Initialize(options); + driver->RegisterEPBrokenHandler(nullptr); + + result = driver->Start(); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestNetDriverRDMA, StartWithoutNewReqHdFailed) +{ + NResult result = driver->Initialize(options); + driver->RegisterNewReqHandler(nullptr); + + result = driver->Start(); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestNetDriverRDMA, StartWithoutReqPostedHdFailed) +{ + NResult result = driver->Initialize(options); + driver->RegisterReqPostedHandler(nullptr); + result = driver->Start(); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestNetDriverRDMA, StartWithoutOneSideHdFailed) +{ + NResult result = driver->Initialize(options); + driver->RegisterOneSideDoneHandler(nullptr); + + result = driver->Start(); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + + +TEST_F(TestNetDriverRDMA, DriverOobConnectWithoutClientCbFailed) +{ + driver->Initialize(options); + driver->Start(); + + UBSHcomNetDriver *oobC = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "c", true); + + oobC->OobIpAndPort(BASE_IP, 9989); + oobC->Initialize(options); + NResult result = oobC->Start(); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); + + UBSHcomNetEndpointPtr ep = nullptr; + result = oobC->Connect("a", ep, 0, 0, 0); + EXPECT_EQ(NNCode::NN_ERROR, result); +} + +TEST_F(TestNetDriverRDMA, DriverOobConnectSuccess) +{ + driver->Initialize(options); + driver->Start(); + + UBSHcomNetDriver *oobC = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "cSuccess", false); + + oobC->OobIpAndPort(BASE_IP, 9989); + oobC->Initialize(options); + oobC->RegisterNewEPHandler( + std::bind(&NewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + oobC->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + oobC->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); + oobC->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + oobC->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + + NResult result = oobC->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + UBSHcomNetEndpointPtr ep = nullptr; + result = oobC->Connect("a", ep, 0, 0, 0); + EXPECT_EQ(NNCode::NN_OK, result); + + oobC->Stop(); + oobC->UnInitialize(); +} + +int CreateAuthInfo(uint64_t ctx, int64_t &flag, UBSHcomNetDriverSecType &type, char *&output, uint32_t &outLen, + bool &needAutoFree) +{ + const char *kToken = "token"; + flag = 1; + output = const_cast(kToken); + outLen = strlen(kToken); + type = NET_SEC_VALID_TWO_WAY; + NN_LOG_INFO("auth info " << output << " len:" << outLen << " flag:" << flag << " sec type:" << + UBSHcomNetDriverSecTypeToString(type)); + return 0; +} + +int CreateAuthInfoFailed(uint64_t ctx, int64_t &flag, UBSHcomNetDriverSecType &type, char *&output, uint32_t &outLen, + bool &needAutoFree) +{ + return -1; +} + +int AuthValidate(uint64_t ctx, int64_t flag, const char *input, uint32_t inputLen) +{ + if (input != nullptr) { + NN_LOG_INFO("Auth validate flag:" << flag << " ctx:" << ctx); + } else { + NN_LOG_INFO("Auth validate flag:" << flag << " ctx:" << ctx << " input:" << input << " input Len:" << inputLen); + } + + return 0; +} + +int AuthValidateFailed(uint64_t ctx, int64_t flag, const char *input, uint32_t inputLen) +{ + if (input != nullptr) { + NN_LOG_INFO("Auth validate flag:" << flag << " ctx:" << ctx); + } else { + NN_LOG_INFO("Auth validate flag:" << flag << " ctx:" << ctx << " input:" << input << " input Len:" << inputLen); + } + + return -1; +} + +TEST_F(TestNetDriverRDMA, DriverOobSecTwoWaySecSuccess) +{ + auto sDriver = CreateDriver("twoWaySecServer", true); + sDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + sDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidate, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + + sDriver->Initialize(options); + NResult result = sDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + + auto cDriver = CreateDriver("twoWaySecClient", false); + cDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + cDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidate, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + + cDriver->Initialize(options); + result = cDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + UBSHcomNetEndpointPtr ep = nullptr; + result = cDriver->Connect("a", ep, 0, 0, 0); + EXPECT_EQ(NNCode::NN_OK, result); + + cDriver->Stop(); + cDriver->UnInitialize(); + sDriver->Stop(); + sDriver->UnInitialize(); +} + +TEST_F(TestNetDriverRDMA, DriverOobSecTwoWaySecFailedWithClientSendError) +{ + auto sDriver = CreateDriver("twoWaySecServer1", true, NET_SEC_VALID_TWO_WAY); + sDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + sDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidate, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + sDriver->Initialize(options); + NResult result = sDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + auto cDriver = CreateDriver("CreateAuthInfoFailedCli", false, NET_SEC_VALID_TWO_WAY); + cDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfoFailed, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, + std::placeholders::_6)); + cDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidate, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + + cDriver->Initialize(options); + result = cDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + UBSHcomNetEndpointPtr ep = nullptr; + result = cDriver->Connect("a", ep, 0, 0, 0); + EXPECT_EQ(NNCode::NN_OOB_SEC_PROCESS_ERROR, result); + + cDriver->Stop(); + cDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(cDriver->Name()); + sDriver->Stop(); + sDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(sDriver->Name()); +} + +TEST_F(TestNetDriverRDMA, DriverOobSecTwoWaySecFailedWithServerValidError) +{ + auto sDriver = CreateDriver("AuthValidateFailedServer", true, NET_SEC_VALID_TWO_WAY); + sDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + sDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidateFailed, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4)); + sDriver->Initialize(options); + NResult result = sDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + auto cDriver = CreateDriver("twoWaySecClient1", false, NET_SEC_VALID_TWO_WAY); + cDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + cDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidate, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + + cDriver->Initialize(options); + result = cDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + UBSHcomNetEndpointPtr ep = nullptr; + result = cDriver->Connect("a", ep, 0, 0, 0); + EXPECT_EQ(NNCode::NN_OOB_SEC_PROCESS_ERROR, result); + + cDriver->Stop(); + cDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(cDriver->Name()); + sDriver->Stop(); + sDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(sDriver->Name()); +} + +TEST_F(TestNetDriverRDMA, DriverOobSecTwoWaySecFailedWithServerSendError) +{ + auto sDriver = CreateDriver("CreateAuthInfoFailedSrv", true, NET_SEC_VALID_TWO_WAY); + sDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfoFailed, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, + std::placeholders::_6)); + sDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidate, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + sDriver->Initialize(options); + NResult result = sDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + auto cDriver = CreateDriver("twoWaySecClient2", false, NET_SEC_VALID_TWO_WAY); + cDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + cDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidate, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + + cDriver->Initialize(options); + result = cDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + UBSHcomNetEndpointPtr ep = nullptr; + result = cDriver->Connect("a", ep, 0, 0, 0); + EXPECT_EQ(NNCode::NN_OOB_SEC_PROCESS_ERROR, result); + + cDriver->Stop(); + cDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(cDriver->Name()); + sDriver->Stop(); + sDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(sDriver->Name()); +} + +TEST_F(TestNetDriverRDMA, DriverOobSecTwoWaySecFailedWithClientValidError) +{ + auto sDriver = CreateDriver("twoWaySecServer", true, NET_SEC_VALID_TWO_WAY); + sDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + sDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidate, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + sDriver->Initialize(options); + NResult result = sDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + auto cDriver = CreateDriver("AuthValidateFailedClient", false, NET_SEC_VALID_TWO_WAY); + cDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + cDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidateFailed, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4)); + + cDriver->Initialize(options); + result = cDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + UBSHcomNetEndpointPtr ep = nullptr; + result = cDriver->Connect("a", ep, 0, 0, 0); + EXPECT_EQ(NNCode::NN_OOB_SEC_PROCESS_ERROR, result); + + cDriver->Stop(); + cDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(cDriver->Name()); + sDriver->Stop(); + sDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(sDriver->Name()); +} + +TEST_F(TestNetDriverRDMA, DriverOobSecTwoWaySecFailedWithClientNotSetProvider) +{ + auto sDriver = CreateDriver("twoWaySecServer3", true, NET_SEC_VALID_TWO_WAY); + sDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + sDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidate, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + sDriver->Initialize(options); + NResult result = sDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + auto cDriver = CreateDriver("NotSetProviderClient", false, NET_SEC_VALID_TWO_WAY); + cDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidate, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + + cDriver->Initialize(options); + result = cDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + UBSHcomNetEndpointPtr ep = nullptr; + result = cDriver->Connect("a", ep, 0, 0, 0); + EXPECT_EQ(NNCode::NN_OOB_SEC_PROCESS_ERROR, result); + + cDriver->Stop(); + cDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(cDriver->Name()); + sDriver->Stop(); + sDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(sDriver->Name()); +} + +TEST_F(TestNetDriverRDMA, DriverOobSecTwoWaySecFailedWithClientNotSetValidator) +{ + auto sDriver = CreateDriver("twoWaySecServer4", true, NET_SEC_VALID_TWO_WAY); + sDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + sDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidate, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + sDriver->Initialize(options); + NResult result = sDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + auto cDriver = CreateDriver("NotSetValidatorClient", false, NET_SEC_VALID_TWO_WAY); + cDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + + cDriver->Initialize(options); + result = cDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + UBSHcomNetEndpointPtr ep = nullptr; + result = cDriver->Connect("a", ep, 0, 0, 0); + EXPECT_EQ(NNCode::NN_OOB_SEC_PROCESS_ERROR, result); + + cDriver->Stop(); + cDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(cDriver->Name()); + sDriver->Stop(); + sDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(sDriver->Name()); +} + +TEST_F(TestNetDriverRDMA, DriverOobSecTwoWaySecFailedWithServerNotSetValidator) +{ + auto sDriver = CreateDriver("NotSetValidatorServer", true, NET_SEC_VALID_TWO_WAY); + sDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + + sDriver->Initialize(options); + NResult result = sDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + auto cDriver = CreateDriver("twoWaySecClient3", false, NET_SEC_VALID_TWO_WAY); + cDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + cDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidate, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + + cDriver->Initialize(options); + result = cDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + UBSHcomNetEndpointPtr ep = nullptr; + result = cDriver->Connect("a", ep, 0, 0, 0); + EXPECT_EQ(NNCode::NN_OOB_SEC_PROCESS_ERROR, result); + + cDriver->Stop(); + cDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(cDriver->Name()); + sDriver->Stop(); + sDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(sDriver->Name()); +} + +TEST_F(TestNetDriverRDMA, DriverOobSecTwoWaySecFailedWithServerNotSetProvider) +{ + auto sDriver = CreateDriver("NotSetProviderServer", true, NET_SEC_VALID_TWO_WAY); + sDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + sDriver->Initialize(options); + NResult result = sDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + auto cDriver = CreateDriver("twoWaySecClient4", false, NET_SEC_VALID_TWO_WAY); + cDriver->RegisterEndpointSecInfoProvider(std::bind(&CreateAuthInfo, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + cDriver->RegisterEndpointSecInfoValidator(std::bind(&AuthValidate, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + + cDriver->Initialize(options); + result = cDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + UBSHcomNetEndpointPtr ep = nullptr; + result = cDriver->Connect("a", ep, 0, 0, 0); + EXPECT_EQ(NNCode::NN_OOB_SEC_PROCESS_ERROR, result); + + cDriver->Stop(); + cDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(cDriver->Name()); + sDriver->Stop(); + sDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(sDriver->Name()); +} + +TEST_F(TestNetDriverRDMA, DriverOobConnectSendReceiveFailed) +{ + options.secType = NET_SEC_VALID_TWO_WAY; + driver->Initialize(options); + driver->Start(); + + UBSHcomNetDriver *oobC = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "cSendReceiveFailed", false); + + oobC->OobIpAndPort(BASE_IP, 9989); + oobC->Initialize(options); + oobC->RegisterNewEPHandler( + std::bind(&NewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + oobC->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + oobC->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); + oobC->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + oobC->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + NResult result = oobC->Start(); + EXPECT_EQ(NNCode::NN_OK, result); + + UBSHcomNetEndpointPtr ep = nullptr; + result = oobC->Connect("a", ep, 0, 0, 0); + EXPECT_EQ(NNCode::NN_OOB_SEC_PROCESS_ERROR, result); + + MOCKER(::recv).stubs().will(returnValue(static_cast(-1))); + MOCKER(::send).stubs().will(returnValue(static_cast(-1))); + result = oobC->Connect("a", ep, 0, 0, 0); + EXPECT_EQ(NNCode::NN_OOB_CLIENT_SOCKET_ERROR, result); + + oobC->Stop(); + oobC->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(oobC->Name()); + GlobalMockObject::verify(); +} + +TEST_F(TestNetDriverRDMA, DriverOobConnectEmplaceFailed) +{ + options.enableTls = false; + driver->Initialize(options); + driver->Start(); + uint64_t epId = 1; + UBSHcomNetWorkerIndex workerIndex{}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(epId, nullptr, nullptr, workerIndex); + EXPECT_NE(ep.Get(), nullptr); + driver->mEndPoints.emplace(epId, ep); + MOCKER_CPP(&UBSHcomNetEndpoint::Id).stubs().will(returnValue(epId)); + + UBSHcomNetEndpointPtr ep1 = nullptr; + NResult result = driver->Connect(BASE_IP, 9989, "a", ep1, 0, 0); + EXPECT_EQ(result, NN_ERROR); + GlobalMockObject::verify(); +} + +TEST_F(TestNetDriverRDMA, DriverOobConnectUdsSuccess) +{ + UBSHcomNetOobUDSListenerOptions opt{}; + + opt.Name("udsServer"); + opt.perm = 0; + UBSHcomNetDriver *oobUdsS = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "udsServer", true); + + oobUdsS->RegisterNewEPHandler( + std::bind(&NewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + oobUdsS->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + oobUdsS->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); + oobUdsS->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + oobUdsS->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + + options.oobType = ock::hcom::NET_OOB_UDS; + + oobUdsS->AddOobUdsOptions(opt); + oobUdsS->OobUdsName("udsServer"); + + oobUdsS->Initialize(options); + NResult result = oobUdsS->Start(); + UBSHcomNetDriver *oobUdsC = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "udsClient", false); + oobUdsC->RegisterNewEPHandler( + std::bind(&NewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + oobUdsC->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + oobUdsC->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); + oobUdsC->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + oobUdsC->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + + options.oobType = ock::hcom::NET_OOB_UDS; + oobUdsC->OobUdsName("udsServer"); + oobUdsC->Initialize(options); + oobUdsC->AddOobUdsOptions(opt); + result = oobUdsC->Start(); + + EXPECT_EQ(NNCode::NN_OK, result); + UBSHcomNetEndpointPtr ep = nullptr; + + result = oobUdsC->Connect("a", ep, 0, 0, 0); + EXPECT_EQ(NNCode::NN_OK, result); + + options.oobType = ock::hcom::NET_OOB_TCP; + oobUdsS->Stop(); + oobUdsS->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(oobUdsS->Name()); + oobUdsC->Stop(); + oobUdsC->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(oobUdsC->Name()); +} + +TEST_F(TestNetDriverRDMA, CreateMemoryRegionSuccess) +{ + driver->Initialize(options); + + UBSHcomNetMemoryRegionPtr mr; + NResult result = driver->CreateMemoryRegion(NN_NO1024, mr); + + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetDriverRDMA, CreateMemoryWithAddRegionSuccess) +{ + driver->Initialize(options); + + UBSHcomNetMemoryRegionPtr mr; + uintptr_t address = 1; + NResult result = driver->CreateMemoryRegion(address, NN_NO1024, mr); + + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetDriverRDMA, CreateMemoryWithAddRegionFailed) +{ + driver->Initialize(options); + + UBSHcomNetMemoryRegionPtr mr; + NResult result = driver->CreateMemoryRegion(0, NN_NO1024, mr); + + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestNetDriverRDMA, DestroyMemoryRegionSuccess) +{ + driver->Initialize(options); + + UBSHcomNetMemoryRegionPtr mr; + NResult result = driver->CreateMemoryRegion(NN_NO1024, mr); + EXPECT_EQ(NNCode::NN_OK, result); + driver->DestroyMemoryRegion(mr); + EXPECT_EQ(NNCode::NN_OK, result); +} +#endif diff --git a/test/llt/testcase/transport/rdma/test_net_rdma_driver.hpp b/test/llt/testcase/transport/rdma/test_net_rdma_driver.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f3e5b047e07289dadfb183d405f4ad3c00b199b6 --- /dev/null +++ b/test/llt/testcase/transport/rdma/test_net_rdma_driver.hpp @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef _TEST_NET_RDMA_DRIVER_HPP_ +#define _TEST_NET_RDMA_DRIVER_HPP_ +#include +#include + +class TestNetDriverRDMA : public testing::Test { +public: + TestNetDriverRDMA(); + virtual void SetUp(void); + virtual void TearDown(void); + +protected: +}; +#endif \ No newline at end of file diff --git a/test/llt/testcase/transport/rdma/test_rdma.cpp b/test/llt/testcase/transport/rdma/test_rdma.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0b7ef8c51c868ff0872f65938acf62f5b4753a7c --- /dev/null +++ b/test/llt/testcase/transport/rdma/test_rdma.cpp @@ -0,0 +1,886 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifdef RDMA_BUILD_ENABLED +#include +#include "mockcpp/mockcpp.hpp" +#include "test_rdma.hpp" +#include "string.h" +#include "hcom.h" +#include "common/net_util.h" +#include "transport/rdma/verbs/net_rdma_sync_endpoint.h" +#include "transport/rdma/verbs/net_rdma_async_endpoint.h" +#include "transport/rdma/rdma_mr_dm_buf.h" +#include "transport/rdma/rdma_mr_fixed_buf.h" +#include "transport/rdma/verbs/rdma_worker.h" +#include "fake_ibv.h" +#include "transport/rdma/verbs/net_rdma_driver.h" +#include "ut_helper.h" + +TestCaseRdma::TestCaseRdma() {} + +void TestCaseRdma::SetUp() +{ + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestCaseRdma::TearDown() +{ + GlobalMockObject::verify(); +} + +using namespace ock::hcom; +#ifdef MOCK_VERBS +#ifdef __cplusplus +extern "C" { +#endif +int fake_ibv_post_send(fake_qp_t *my_qp, struct ibv_send_wr *wr); +int fake_post_read(fake_qp_t *my_qp, struct ibv_send_wr *wr); +int fake_post_write(fake_qp_t *my_qp, struct ibv_send_wr *wr); +#ifdef __cplusplus +} +#endif +#endif +// cpp case +using TestOpCode = enum { + GET_MR = 1, + CHECK_SYNC_RESPONSE, + SEND_RAW, + RECEIVE_RAW, + POST_SEND_FAIL, + SET_MR, +}; + +#define CHECK_RESULT_TRUE(result) \ + EXPECT_EQ(true, result); \ + if (!result) { \ + return; \ + } + +#define CLEAN_UP_ALL_STUBS() GlobalMockObject::verify() + +constexpr uint64_t SYNC_SEND_VALUE = 0xffff0000; +constexpr uint64_t SYNC_RECEIVE_VALUE = 0x0000ffff; +constexpr uint64_t ASYNC_RW_COUNT = 4; +constexpr uint64_t RDMA_LISTEN_PORT = 22222; +// server +UBSHcomNetDriver *serverDriver = nullptr; + +std::string ipSeg = IP_SEG; + +using TestRegMrInfo = struct _reg_sgl_info_test_ { + uintptr_t lAddress = 0; + uint32_t lKey = 0; + uint32_t size = 0; +} __attribute__((packed)); +TestRegMrInfo serverLocalMrInfo[4]; + +int ServerNewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload); + return 0; +} + +void ServerEndPointBroken(const UBSHcomNetEndpointPtr &serverEp) +{ + NN_LOG_TRACE_INFO("end point " << serverEp->Id()); +} + +int ServerRequestReceived(const UBSHcomNetRequestContext &ctx) +{ + std::string req((char *)ctx.Message()->Data(), ctx.Header().dataLength); + NN_LOG_INFO("request received - " << ctx.Header().opCode << ", dataLen " << ctx.Header().dataLength); + + int result = 0; + if (ctx.OpType() == UBSHcomNetRequestContext::NN_RECEIVED) { + if (ctx.Header().opCode == GET_MR) { + UBSHcomNetTransRequest rsp((void *)(serverLocalMrInfo), sizeof(serverLocalMrInfo), 0); + if ((result = ctx.EndPoint()->PostSend(ctx.Header().opCode, rsp)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + + NN_LOG_INFO("request rsp Mr info"); + for (uint16_t i = 0; i < 4; i++) { + NN_LOG_TRACE_INFO("idx:" << i << " key:" << serverLocalMrInfo[i].lKey << " address:" << + serverLocalMrInfo[i].lAddress << " size" << serverLocalMrInfo[i].size); + } + } else if (ctx.Header().opCode == CHECK_SYNC_RESPONSE) { + uint64_t *readValue = reinterpret_cast((void *)(ctx.Message()->Data())); + EXPECT_EQ(SYNC_SEND_VALUE, *readValue); + + uint64_t rspData = SYNC_RECEIVE_VALUE; + UBSHcomNetTransRequest rsp((void *)(&rspData), sizeof(rspData), 0); + if ((result = ctx.EndPoint()->PostSend(ctx.Header().opCode, rsp)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + } else if (ctx.Header().opCode == SET_MR) { + for (uint16_t i = 0; i < 4; i++) { + memset(reinterpret_cast(serverLocalMrInfo[i].lAddress), 0, NN_NO16); + } + uint64_t rspData = 0; + UBSHcomNetTransRequest rsp((void *)(&rspData), sizeof(rspData), 0); + if ((result = ctx.EndPoint()->PostSend(ctx.Header().opCode, rsp)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + } + } else if (ctx.OpType() == UBSHcomNetRequestContext::NN_RECEIVED_RAW) { + int32_t *readValue = reinterpret_cast((void *)(ctx.Message()->Data())); + EXPECT_EQ(SEND_RAW, *readValue); + int32_t *localAddress = reinterpret_cast(serverLocalMrInfo[0].lAddress); + *localAddress = RECEIVE_RAW; + UBSHcomNetTransRequest req((void *)(serverLocalMrInfo[0].lAddress), NN_NO4, 0); + if ((result = ctx.EndPoint()->PostSendRaw(req, 1)) != 0) { + NN_LOG_ERROR("failed to post message to data to server"); + return result; + } + } + + return 0; +} + +int ServerRequestPosted(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_TRACE_INFO("request posted"); + return 0; +} + +int ServerOneSideDone(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_TRACE_INFO("one side done"); + return 0; +} + + +bool ServerCreateDriver() +{ + if (serverDriver != nullptr) { + NN_LOG_ERROR("serverDriver already created"); + return false; + } + + serverDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "rdmaServer", true); + if (serverDriver == nullptr) { + NN_LOG_ERROR("failed to create serverDriver already created"); + return false; + } + + UBSHcomNetDriverOptions options {}; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; // 只支持EVENT模式 + options.mrSendReceiveSegSize = 1024; + options.mrSendReceiveSegCount = 1024; + options.enableTls = false; + options.SetNetDeviceIpMask(ipSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + options.prePostReceiveSizePerQP = 32; + + serverDriver->RegisterNewEPHandler( + std::bind(&ServerNewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + serverDriver->RegisterEPBrokenHandler(std::bind(&ServerEndPointBroken, std::placeholders::_1)); + serverDriver->RegisterNewReqHandler(std::bind(&ServerRequestReceived, std::placeholders::_1)); + serverDriver->RegisterReqPostedHandler(std::bind(&ServerRequestPosted, std::placeholders::_1)); + serverDriver->RegisterOneSideDoneHandler(std::bind(&ServerOneSideDone, std::placeholders::_1)); + + serverDriver->OobIpAndPort(BASE_IP, RDMA_LISTEN_PORT); + + int result = 0; + if ((result = serverDriver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize serverDriver " << result); + return false; + } + NN_LOG_INFO("serverDriver initialized"); + + if ((result = serverDriver->Start()) != 0) { + NN_LOG_ERROR("failed to start serverDriver " << result); + return false; + } + NN_LOG_INFO("serverDriver started"); + + return true; +} + +bool ServerRegSglMem() +{ + for (uint16_t i = 0; i < 4; i++) { + UBSHcomNetMemoryRegionPtr mr; + auto result = serverDriver->CreateMemoryRegion(NN_NO16, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + serverLocalMrInfo[i].lAddress = mr->GetAddress(); + serverLocalMrInfo[i].lKey = mr->GetLKey(); + serverLocalMrInfo[i].size = NN_NO16; + memset(reinterpret_cast(serverLocalMrInfo[i].lAddress), 0, NN_NO16); + } + + return true; +} + +// client +UBSHcomNetDriver *clientDriver = nullptr; +UBSHcomNetEndpointPtr clientAsyncEp = nullptr; +UBSHcomNetEndpointPtr clientSyncEp = nullptr; +TestRegMrInfo localMrInfo[NN_NO4]; +TestRegMrInfo remoteMrInfo[NN_NO4]; +sem_t sem; + +uint32_t execCount = 0; + +void ClientEndPointBroken(const UBSHcomNetEndpointPtr &clientEp) +{ + if (clientSyncEp != nullptr && clientEp->Id() == clientSyncEp->Id()) { + NN_LOG_INFO("client sync end point " << clientEp->Id() << " broken"); + clientSyncEp.Set(nullptr); + } else if (clientAsyncEp != nullptr && clientEp->Id() == clientAsyncEp->Id()) { + NN_LOG_INFO("client async end point " << clientEp->Id() << " broken"); + clientAsyncEp.Set(nullptr); + } +} + +int ClientRequestReceived(const UBSHcomNetRequestContext &ctx) +{ + if (ctx.OpType() == UBSHcomNetRequestContext::NN_RECEIVED) { + if (ctx.Header().opCode == GET_MR) { + memcpy(remoteMrInfo, ctx.Message()->Data(), ctx.Message()->DataLen()); + NN_LOG_INFO("get remote Mr info"); + for (uint16_t i = 0; i < NN_NO4; i++) { + NN_LOG_TRACE_INFO("idx:" << i << " key:" << remoteMrInfo[i].lKey << " address:" << + remoteMrInfo[i].lAddress << " size" << remoteMrInfo[i].size); + } + sem_post(&sem); + } else if (ctx.Header().opCode == CHECK_SYNC_RESPONSE) { + uint64_t *readValue = reinterpret_cast((void *)(ctx.Message()->Data())); + EXPECT_EQ(SYNC_RECEIVE_VALUE, *readValue); + } + } else if (ctx.OpType() == UBSHcomNetRequestContext::NN_RECEIVED_RAW) { + int32_t *readValue = reinterpret_cast((void *)(ctx.Message()->Data())); + EXPECT_EQ(RECEIVE_RAW, *readValue); + sem_post(&sem); + } + + return 0; +} + +int ClientRequestPosted(const UBSHcomNetRequestContext &ctx) +{ + return 0; +} + +int ClientOneSideDone(const UBSHcomNetRequestContext &ctx) +{ + sem_post(&sem); + return 0; +} + +bool ClientCreateDriver() +{ + if (clientDriver != nullptr) { + NN_LOG_ERROR("clientDriver already created"); + return false; + } + + clientDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "rdmaClient", false); + if (clientDriver == nullptr) { + NN_LOG_ERROR("failed to create clientDriver already created"); + return false; + } + + UBSHcomNetDriverOptions options {}; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; // 只支持EVENT模式 + options.mrSendReceiveSegSize = 1024; + options.mrSendReceiveSegCount = 1024; + options.heartBeatIdleTime = 1; + options.heartBeatProbeInterval = 1; + options.enableTls = false; + options.SetNetDeviceIpMask(ipSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + clientDriver->RegisterEPBrokenHandler(std::bind(&ClientEndPointBroken, std::placeholders::_1)); + clientDriver->RegisterNewReqHandler(std::bind(&ClientRequestReceived, std::placeholders::_1)); + clientDriver->RegisterReqPostedHandler(std::bind(&ClientRequestPosted, std::placeholders::_1)); + clientDriver->RegisterOneSideDoneHandler(std::bind(&ClientOneSideDone, std::placeholders::_1)); + + clientDriver->OobIpAndPort(BASE_IP, RDMA_LISTEN_PORT); + + int result = 0; + if ((result = clientDriver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize clientDriver " << result); + return false; + } + NN_LOG_INFO("clientDriver initialized"); + + if ((result = clientDriver->Start()) != 0) { + NN_LOG_ERROR("failed to start clientDriver " << result); + return false; + } + NN_LOG_INFO("clientDriver started"); + + return true; +} + +bool AsyncClientConnect() +{ + if (clientDriver == nullptr) { + NN_LOG_ERROR("clientDriver is null"); + return false; + } + + int result = 0; + if ((result = clientDriver->Connect("hello world", clientAsyncEp, 0)) != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + clientAsyncEp->PeerIpAndPort(); + sem_init(&sem, 0, 0); + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + if ((result = clientAsyncEp->PostSend(GET_MR, req)) != 0) { + NN_LOG_ERROR("failed to post message to data to server"); + return false; + } + + sem_wait(&sem); + return true; +} + +bool SyncClientConnect() +{ + if (clientDriver == nullptr) { + NN_LOG_ERROR("clientDriver is null"); + return false; + } + + int result = 0; + if ((result = clientDriver->Connect("hello world", clientSyncEp, NET_EP_EVENT_POLLING)) != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + clientSyncEp->PeerIpAndPort(); + return true; +} + +void SendAsyncOneSideRequest(UBSHcomNetTransSgeIov *iov, uint64_t index) +{ + int result = 0; + MOCKER(fake_post_read).stubs().will(returnValue(0)); + MOCKER(fake_post_write).stubs().will(returnValue(0)); + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + result = clientAsyncEp->PostRead(sglReq); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to read sgl data from server"); + return; + } + + UBSHcomNetTransSglRequest reqWrite(iov, NN_NO4, 0); + result = clientAsyncEp->PostWrite(sglReq); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to write sgl data from server"); + return; + } + + UBSHcomNetTransRequest buffReq(iov[0].lAddress, iov[0].rAddress, iov[0].lKey, iov[0].rKey, iov[0].size, 0); + result = clientAsyncEp->PostRead(buffReq); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + + result = clientAsyncEp->PostWrite(buffReq); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to write data from server"); + return; + } +} + +void AsyncPostSendRawRequest() +{ + int result = 0; + int32_t *localAddress = reinterpret_cast(localMrInfo[0].lAddress); + *localAddress = SEND_RAW; + UBSHcomNetTransRequest req((void *)(localMrInfo[0].lAddress), NN_NO4, 0); + if ((result = clientAsyncEp->PostSendRaw(req, 1)) != 0) { + NN_LOG_ERROR("failed to post message to data to server"); + return; + } + sem_wait(&sem); + EXPECT_EQ(result, NN_OK); +} + +void AsyncPostSendFailRequest() +{ + int result = 0; + uint64_t data = 0; + UBSHcomNetTransRequest req((void *)(localMrInfo[0].lAddress), NN_NO4, 0); + clientAsyncEp->DefaultTimeout(0); + + MOCKER(RDMAMemoryRegionFixedBuffer::GetFreeBuffer).stubs().will(returnValue(false)); + result = clientAsyncEp->PostSend(POST_SEND_FAIL, req); + EXPECT_EQ(result, NN_GET_BUFF_FAILED); + + result = clientAsyncEp->PostSendRaw(req, 1); + EXPECT_EQ(result, NN_GET_BUFF_FAILED); + CLEAN_UP_ALL_STUBS(); + + MOCKER(RDMAQp::GetPostSendWr).stubs().will(returnValue(false)); + result = clientAsyncEp->PostSend(POST_SEND_FAIL, req); + EXPECT_EQ(result, RR_QP_POST_SEND_WR_FULL); + + result = clientAsyncEp->PostSendRaw(req, 1); + EXPECT_EQ(result, RR_QP_POST_SEND_WR_FULL); + CLEAN_UP_ALL_STUBS(); + + req.upCtxSize = NN_NO100; + result = clientAsyncEp->PostSend(POST_SEND_FAIL, req); + EXPECT_EQ(result, NN_PARAM_INVALID); + + result = clientAsyncEp->PostSendRaw(req, 1); + EXPECT_EQ(result, NN_PARAM_INVALID); + req.upCtxSize = 0; +#ifdef MOCK_VERBS + MOCKER(fake_ibv_post_send).stubs().will(returnValue(-1)); + result = clientAsyncEp->PostSend(POST_SEND_FAIL, req); + EXPECT_EQ(result, RR_QP_POST_SEND_FAILED); + + result = clientAsyncEp->PostSendRaw(req, 1); + EXPECT_EQ(result, RR_QP_POST_SEND_FAILED); +#endif + CLEAN_UP_ALL_STUBS(); +} + +void AsyncReadRequestCheckResult(UBSHcomNetTransSgeIov *iov, int checkResult, uint16_t upCtxSize) +{ + int result = 0; + + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, upCtxSize); + result = clientAsyncEp->PostRead(sglReq); + EXPECT_EQ(result, checkResult); + + UBSHcomNetTransRequest buffReq(iov[0].lAddress, iov[0].rAddress, iov[0].lKey, iov[0].rKey, iov[0].size, upCtxSize); + result = clientAsyncEp->PostRead(buffReq); + EXPECT_EQ(result, checkResult); +} + +void AsyncWriteRequestCheckResult(UBSHcomNetTransSgeIov *iov, int checkResult, uint16_t upCtxSize) +{ + int result = 0; + + UBSHcomNetTransSglRequest reqWrite(iov, NN_NO4, upCtxSize); + result = clientAsyncEp->PostWrite(reqWrite); + EXPECT_EQ(result, checkResult); + + UBSHcomNetTransRequest buffReq(iov[0].lAddress, iov[0].rAddress, iov[0].lKey, iov[0].rKey, iov[0].size, upCtxSize); + result = clientAsyncEp->PostWrite(buffReq); + EXPECT_EQ(result, checkResult); +} + +void AsyncOneSideFailRequest(UBSHcomNetTransSgeIov *iov) +{ + uint16_t upCtxSize = 0; + clientAsyncEp->DefaultTimeout(0); + MOCKER(NetDriverRDMA::ValidateMemoryRegion).stubs().will(returnValue(-1)); + AsyncReadRequestCheckResult(iov, NN_INVALID_LKEY, upCtxSize); + AsyncWriteRequestCheckResult(iov, NN_INVALID_LKEY, upCtxSize); + CLEAN_UP_ALL_STUBS(); + + MOCKER(RDMAQp::GetOneSideWr).stubs().will(returnValue(false)); + AsyncReadRequestCheckResult(iov, RR_QP_ONE_SIDE_WR_FULL, upCtxSize); + AsyncWriteRequestCheckResult(iov, RR_QP_ONE_SIDE_WR_FULL, upCtxSize); + CLEAN_UP_ALL_STUBS(); + + upCtxSize = NN_NO100; + AsyncReadRequestCheckResult(iov, NN_PARAM_INVALID, upCtxSize); + AsyncWriteRequestCheckResult(iov, NN_PARAM_INVALID, upCtxSize); + upCtxSize = 0; +#ifdef MOCK_VERBS + MOCKER(fake_post_read).stubs().will(returnValue(-1)); + AsyncReadRequestCheckResult(iov, RR_QP_POST_READ_FAILED, upCtxSize); + MOCKER(fake_post_write).stubs().will(returnValue(-1)); + AsyncWriteRequestCheckResult(iov, RR_QP_POST_WRITE_FAILED, upCtxSize); +#endif + CLEAN_UP_ALL_STUBS(); +} + +void AsyncNotSupportOperation() +{ + EXPECT_EQ(clientAsyncEp->WaitCompletion(), NN_INVALID_OPERATION); + + UBSHcomNetResponseContext ctx; + EXPECT_EQ(clientAsyncEp->Receive(ctx), NN_INVALID_OPERATION); + EXPECT_EQ(clientAsyncEp->ReceiveRaw(ctx), NN_INVALID_OPERATION); +} + +void AsyncQpErrorHandle() +{ + MOCKER(RDMAOpContextInfo::OpResult).stubs().will(returnValue(RDMAOpContextInfo::ERR_IO_ERROR)); + int result = 0; + int32_t *localAddress = reinterpret_cast(localMrInfo[0].lAddress); + *localAddress = SEND_RAW; + UBSHcomNetTransRequest req((void *)(localMrInfo[0].lAddress), NN_NO4, 0); + if ((result = clientAsyncEp->PostSendRaw(req, 1)) != 0) { + NN_LOG_ERROR("failed to post message to data to server"); + return; + } + sleep(1); + CLEAN_UP_ALL_STUBS(); +} + +void AsyncRequest() +{ + UBSHcomNetTransSgeIov iov[NN_NO4]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = localMrInfo[i].lAddress; + iov[i].rAddress = remoteMrInfo[i].lAddress; + iov[i].lKey = localMrInfo[i].lKey; + iov[i].rKey = remoteMrInfo[i].lKey; + iov[i].size = NN_NO8; + } + sem_init(&sem, 0, 0); + SendAsyncOneSideRequest(iov, 0); + + AsyncPostSendRawRequest(); + AsyncPostSendFailRequest(); + AsyncOneSideFailRequest(iov); + AsyncNotSupportOperation(); + AsyncQpErrorHandle(); + // clientAsyncEp destroy when broken handle, do not use anymore +} + +void SyncPostSendFailRequest() +{ + int result = 0; + uint64_t data = 0; + UBSHcomNetTransRequest req((void *)(localMrInfo[0].lAddress), NN_NO4, 0); + clientSyncEp->DefaultTimeout(0); + + MOCKER(RDMAMemoryRegionFixedBuffer::GetFreeBuffer).stubs().will(returnValue(false)); + result = clientSyncEp->PostSend(POST_SEND_FAIL, req); + EXPECT_EQ(result, NN_GET_BUFF_FAILED); + + result = clientSyncEp->PostSendRaw(req, 1); + EXPECT_EQ(result, NN_GET_BUFF_FAILED); + CLEAN_UP_ALL_STUBS(); + + req.upCtxSize = NN_NO100; + result = clientSyncEp->PostSend(POST_SEND_FAIL, req); + EXPECT_EQ(result, NN_PARAM_INVALID); + + result = clientSyncEp->PostSendRaw(req, 1); + EXPECT_EQ(result, NN_PARAM_INVALID); + req.upCtxSize = 0; +#ifdef MOCK_VERBS + MOCKER(fake_ibv_post_send).stubs().will(returnValue(-1)); + result = clientSyncEp->PostSend(POST_SEND_FAIL, req); + EXPECT_EQ(result, RR_QP_POST_SEND_FAILED); + + result = clientSyncEp->PostSendRaw(req, 1); + EXPECT_EQ(result, RR_QP_POST_SEND_FAILED); +#endif + CLEAN_UP_ALL_STUBS(); +} + +void SyncReadRequestCheckResult(UBSHcomNetTransSgeIov *iov, int checkResult, uint16_t upCtxSize) +{ + int result = 0; + + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, upCtxSize); + result = clientSyncEp->PostRead(sglReq); + EXPECT_EQ(result, checkResult); + + UBSHcomNetTransRequest buffReq(iov[0].lAddress, iov[0].rAddress, iov[0].lKey, iov[0].rKey, iov[0].size, upCtxSize); + result = clientSyncEp->PostRead(buffReq); + EXPECT_EQ(result, checkResult); +} + +void SyncWriteRequestCheckResult(UBSHcomNetTransSgeIov *iov, int checkResult, uint16_t upCtxSize) +{ + int result = 0; + + UBSHcomNetTransSglRequest reqWrite(iov, NN_NO4, upCtxSize); + result = clientSyncEp->PostWrite(reqWrite); + EXPECT_EQ(result, checkResult); + + UBSHcomNetTransRequest buffReq(iov[0].lAddress, iov[0].rAddress, iov[0].lKey, iov[0].rKey, iov[0].size, upCtxSize); + result = clientSyncEp->PostWrite(buffReq); + EXPECT_EQ(result, checkResult); +} + +void SyncOneSideFailRequest(UBSHcomNetTransSgeIov *iov) +{ + uint16_t upCtxSize = 0; + clientSyncEp->DefaultTimeout(0); + MOCKER(NetDriverRDMA::ValidateMemoryRegion).stubs().will(returnValue(-1)); + SyncReadRequestCheckResult(iov, NN_INVALID_LKEY, upCtxSize); + SyncWriteRequestCheckResult(iov, NN_INVALID_LKEY, upCtxSize); + CLEAN_UP_ALL_STUBS(); + + upCtxSize = NN_NO100; + SyncReadRequestCheckResult(iov, NN_PARAM_INVALID, upCtxSize); + SyncWriteRequestCheckResult(iov, NN_PARAM_INVALID, upCtxSize); + upCtxSize = 0; +#ifdef MOCK_VERBS + MOCKER(fake_post_read).stubs().will(returnValue(-1)); + SyncReadRequestCheckResult(iov, RR_QP_POST_READ_FAILED, upCtxSize); + MOCKER(fake_post_write).stubs().will(returnValue(-1)); + SyncWriteRequestCheckResult(iov, RR_QP_POST_WRITE_FAILED, upCtxSize); +#endif + CLEAN_UP_ALL_STUBS(); +} + +void SyncRequestsSuccess() +{ + // get one mr seg from pool + uint64_t data = 0; + UBSHcomNetResponseContext respCtx {}; + + int result = 0; + data = SYNC_SEND_VALUE; + UBSHcomNetTransRequest req((void *)(&data), sizeof(data), 0); + result = clientSyncEp->PostSend(CHECK_SYNC_RESPONSE, req); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return; + } + + result = clientSyncEp->WaitCompletion(-1); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to wait completion, result " << result); + return; + } + + result = clientSyncEp->Receive(respCtx); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to get response, result " << result); + return; + } + + UBSHcomNetTransRequest buffReq(localMrInfo[0].lAddress, remoteMrInfo[0].lAddress, localMrInfo[0].lKey, + remoteMrInfo[0].lKey, localMrInfo[0].size, 0); + result = clientSyncEp->PostRead(buffReq); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to read data from server"); + return; + } + + result = clientSyncEp->WaitCompletion(-1); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to wait completion, result " << result); + return; + } + uint64_t *readBuff = reinterpret_cast((void *)(localMrInfo[0].lAddress)); + uint64_t readValue = *readBuff; + EXPECT_EQ(readValue, ASYNC_RW_COUNT); + + result = clientSyncEp->PostWrite(buffReq); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to write data from server"); + return; + } + + result = clientSyncEp->WaitCompletion(-1); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to wait completion, result " << result); + return; + } + + int32_t *localAddress = reinterpret_cast(localMrInfo[0].lAddress); + *localAddress = SEND_RAW; + UBSHcomNetTransRequest reqRaw((void *)(localMrInfo[0].lAddress), NN_NO4, 0); + result = clientSyncEp->PostSendRaw(reqRaw, 1); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to post message to data to server"); + return; + } + result = clientSyncEp->WaitCompletion(-1); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to wait completion, result " << result); + return; + } + + UBSHcomNetResponseContext rawCtx; + result = clientSyncEp->ReceiveRaw(rawCtx); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to get response, result " << result); + return; + } +} + +void SendSyncOneSideRequest(UBSHcomNetTransSgeIov *iov, uint64_t index) +{ + int result = 0; + + UBSHcomNetTransSglRequest sglReq(iov, NN_NO1, 0); + result = clientSyncEp->PostRead(sglReq); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to read sgl data from server"); + return; + } + + result = clientSyncEp->WaitCompletion(); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to wait read sgl from server"); + return; + } + NN_LOG_TRACE_INFO("sgl read value idx:" << execCount++); + for (uint16_t i = 0; i < NN_NO1; i++) { + uint64_t *readValue = reinterpret_cast((void *)(localMrInfo[i].lAddress)); + uint64_t value = *readValue; + NN_LOG_TRACE_INFO("value[" << i << "]=" << *readValue); + EXPECT_EQ(value, index); + *readValue = ++value; + } + + UBSHcomNetTransSglRequest reqWrite(iov, NN_NO1, 0); + result = clientSyncEp->PostWrite(sglReq); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to write sgl data from server"); + return; + } + + result = clientSyncEp->WaitCompletion(); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to wait write sgl from server"); + return; + } +} + +void SyncSetRemoteMrZero() +{ + uint64_t data = 0; + UBSHcomNetResponseContext respCtx {}; + + int result = 0; + data = SYNC_SEND_VALUE; + UBSHcomNetTransRequest req((void *)(&data), sizeof(data), 0); + result = clientSyncEp->PostSend(SET_MR, req); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return; + } + + result = clientSyncEp->WaitCompletion(-1); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to wait completion, result " << result); + return; + } + + result = clientSyncEp->Receive(respCtx); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_ERROR("failed to get response, result " << result); + return; + } +} + +void SyncRequests() +{ + SyncRequestsSuccess(); + UBSHcomNetTransSgeIov iov[NN_NO4]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = localMrInfo[i].lAddress; + iov[i].rAddress = remoteMrInfo[i].lAddress; + iov[i].lKey = localMrInfo[i].lKey; + iov[i].rKey = remoteMrInfo[i].lKey; + iov[i].size = NN_NO8; + } + + SyncSetRemoteMrZero(); + SendSyncOneSideRequest(iov, 0); + + SyncPostSendFailRequest(); + SyncOneSideFailRequest(iov); +} + +bool ClientRegSglMem() +{ + for (uint16_t i = 0; i < NN_NO4; i++) { + UBSHcomNetMemoryRegionPtr mr; + auto result = clientDriver->CreateMemoryRegion(NN_NO16, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + localMrInfo[i].lAddress = mr->GetAddress(); + localMrInfo[i].lKey = mr->GetLKey(); + localMrInfo[i].size = NN_NO16; + memset(reinterpret_cast(localMrInfo[i].lAddress), 0, NN_NO16); + } + + return true; +} + +TEST_F(TestCaseRdma, RDMA_BASIC_OPERATE) +{ +#ifdef MOCK_VERBS + MOCK_VERSION +#endif + bool result = ServerCreateDriver(); + CHECK_RESULT_TRUE(result); + + result = ServerRegSglMem(); + CHECK_RESULT_TRUE(result); +#ifdef MOCK_VERBS + MOCK_VERSION +#endif + result = ClientCreateDriver(); + CHECK_RESULT_TRUE(result); + result = AsyncClientConnect(); + CHECK_RESULT_TRUE(result); + result = ClientRegSglMem(); + CHECK_RESULT_TRUE(result); + AsyncRequest(); + // clientAsyncEp destroy when broken handle, do not use anymore + + result = SyncClientConnect(); + CHECK_RESULT_TRUE(result); + SyncRequests(); + + if (clientDriver->IsStarted()) { + clientDriver->Stop(); + } + if (clientDriver->IsInited()) { + clientDriver->UnInitialize(); + } + if (serverDriver->IsStarted()) { + serverDriver->Stop(); + } + if (serverDriver->IsInited()) { + serverDriver->UnInitialize(); + } + UBSHcomNetDriver::DestroyInstance(clientDriver->Name()); + UBSHcomNetDriver::DestroyInstance(serverDriver->Name()); +} +#endif \ No newline at end of file diff --git a/test/llt/testcase/transport/rdma/test_rdma.hpp b/test/llt/testcase/transport/rdma/test_rdma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9903cd5813547ef1648a157ff924e38308bb2c14 --- /dev/null +++ b/test/llt/testcase/transport/rdma/test_rdma.hpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef _TEST_RDMA_HPP_ +#define _TEST_RDMA_HPP_ +#include + +class TestCaseRdma : public testing::Test { +public: + TestCaseRdma(); + virtual void SetUp(void); + virtual void TearDown(void); +protected: +}; + +#endif + + diff --git a/test/llt/testcase/transport/rdma/test_rdma_heartbeat.cpp b/test/llt/testcase/transport/rdma/test_rdma_heartbeat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..da64068de0c7c468f8ebc4570ecd7bbdb9e31fe1 --- /dev/null +++ b/test/llt/testcase/transport/rdma/test_rdma_heartbeat.cpp @@ -0,0 +1,262 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#include +#include +#include "test_rdma_heartbeat.hpp" +#include "transport/rdma/rdma_heartbeat.h" +#include +#include +#include +#include + +using namespace ock::hcom; +TestCaseRDMAHeartBeat::TestCaseRDMAHeartBeat() {} + +void TestCaseRDMAHeartBeat::SetUp() {} + +void TestCaseRDMAHeartBeat::TearDown() +{ + GlobalMockObject::verify(); +} + +bool TestConnBrokenCheckCB(int fd) {} + +void TestConnBrokenPostCB(int fd) {} + +TEST_F(TestCaseRDMAHeartBeat, InitFailed) +{ + RIPDeviceHeartbeatManager hbMgr("test init"); + + auto result = hbMgr.Initialize(); + EXPECT_EQ(result, NN_PARAM_INVALID); + + hbMgr.SetConnBrokenCheckHandler(std::bind(&TestConnBrokenCheckCB, std::placeholders::_1)); + + result = hbMgr.Initialize(); + EXPECT_EQ(result, NN_PARAM_INVALID); + + hbMgr.SetConnBrokenPostHandler(std::bind(&TestConnBrokenPostCB, std::placeholders::_1)); + + MOCKER(epoll_create).stubs().will(returnValue(-1)); + result = hbMgr.Initialize(); + EXPECT_EQ(result, NN_HEARTBEAT_CREATE_EPOLL_FAILED); +} + +TEST_F(TestCaseRDMAHeartBeat, Start) +{ + RIPDeviceHeartbeatManager hbMgr("test start"); + + hbMgr.SetConnBrokenCheckHandler(std::bind(&TestConnBrokenCheckCB, std::placeholders::_1)); + hbMgr.SetConnBrokenPostHandler(std::bind(&TestConnBrokenPostCB, std::placeholders::_1)); + + auto result = hbMgr.Initialize(); + EXPECT_EQ(result, NN_OK); + + result = hbMgr.Start(); + EXPECT_EQ(result, NN_OK); + + result = hbMgr.Start(); + EXPECT_EQ(result, NN_OK); + + hbMgr.Stop(); +} + +std::string ip = "0.0.0.0"; +uint32_t port = 6323; + +bool gNeedStop = false; +std::atomic mStarted; + +bool Accept(RIPDeviceHeartbeatManager &hbMgr, int fd, const std::string &ip, int16_t port) +{ + int result = 0; + if ((result = hbMgr.AddNewIP(ip, fd)) != 0) { + NN_LOG_ERROR("Failed to add fd to heartbeat manager " << ip << "-" << fd << " result " << result); + } + return true; +} + +void RunServer() +{ + RIPDeviceHeartbeatManager hbMgr("test server"); + + hbMgr.SetConnBrokenCheckHandler( + std::bind(&RIPDeviceHeartbeatManager::DefaultConnBrokenCheckCB, std::placeholders::_1)); + hbMgr.SetConnBrokenPostHandler( + std::bind(&RIPDeviceHeartbeatManager::DefaultConnBrokenPostCB, std::placeholders::_1)); + hbMgr.SetKeepaliveConfig(1, 1, 1); + + int result = 0; + if ((result = hbMgr.Initialize()) != 0 || (result = hbMgr.Start()) != 0) { + NN_LOG_ERROR("Failed to initialize start RIPDeviceHeartbeatManager, result " << result); + return; + } + + auto listenFD = ::socket(AF_INET, SOCK_STREAM, 0); + if (listenFD < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create listen socket as " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return; + } + + // assign address + struct sockaddr_in addr {}; + bzero(&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = inet_addr(ip.c_str()); + addr.sin_port = htons(port); + + // set option, bind and listen + int flags = 1; + if (::setsockopt(listenFD, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&flags), sizeof(flags)) < 0 || + ::bind(listenFD, reinterpret_cast(&addr), sizeof(addr)) < 0 || + ::listen(listenFD, OOB_DEFAULT_LISTEN_BACKLOG) < 0) { + ::close(listenFD); + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to set option or bind or listen on listen socket as " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return; + } + + struct sockaddr_in addressIn {}; + socklen_t len = sizeof(addressIn); + mStarted.store(true); + + bzero(&addressIn, sizeof(struct sockaddr_in)); + auto fd = ::accept(listenFD, reinterpret_cast(&addressIn), &len); + if (fd < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_WARN("Failed to accept on new socket with " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE) << ", ignore and continue"); + return; + } + + Accept(hbMgr, fd, inet_ntoa(addressIn.sin_addr), ntohs(addressIn.sin_port)); + + + close(listenFD); + hbMgr.Stop(); + hbMgr.UnInitialize(); +} + +bool ConnBrokenCheckCB(int fd) +{ + char data[1]; + int result = recv(fd, data, 1, MSG_DONTWAIT); + if (result < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + // connection is still ok + return true; + } + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + // connection is wrong + NN_LOG_INFO("Connection is wrong, fd " << fd << ", error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return false; + } else if (result == 0) { + NN_LOG_INFO("Connection is broken, fd " << fd); + return false; // connection really broken + } else { + return true; + } +} + +void ConnBrokenPostCB(int fd) +{ + NN_LOG_INFO("ConnBrokenPostCB called fd " << fd); + close(fd); +} + +bool Connect(const std::string &ip, uint16_t port, int &fd) +{ + auto tmpFD = ::socket(AF_INET, SOCK_STREAM, 0); + if (tmpFD < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create socket as " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return false; + } + + struct sockaddr_in addr {}; + bzero(&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = inet_addr(ip.c_str()); + addr.sin_port = htons(port); + + if (connect(tmpFD, reinterpret_cast(&addr), sizeof(addr)) != 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to connect to " << ip << ":" << port << " as " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE)); + return false; + } + fd = tmpFD; + return true; +} + +void RunClient() +{ + RIPDeviceHeartbeatManager hbMgr("test client"); + + hbMgr.SetConnBrokenCheckHandler(std::bind(&ConnBrokenCheckCB, std::placeholders::_1)); + hbMgr.SetConnBrokenPostHandler(std::bind(&ConnBrokenPostCB, std::placeholders::_1)); + hbMgr.SetKeepaliveConfig(1, 1, 1); + + int result = 0; + if ((result = hbMgr.Initialize()) != 0 || (result = hbMgr.Start()) != 0) { + NN_LOG_ERROR("Failed to initialize start RIPDeviceHeartbeatManager, result " << result); + return; + } + + NN_LOG_INFO("Heartbeat manager started"); + int fd = -1; + if (Connect(ip, port, fd)) { + NN_LOG_INFO("Connected to " << ip << ":" << port << ", fd " << fd); + + if ((result = hbMgr.AddNewIP(ip, fd)) != 0) { + NN_LOG_ERROR("Failed to add fd to heartbeat manager " << fd << " result " << result); + return; + } + + int getFd; + result = hbMgr.GetFdByIP(ip, getFd); + EXPECT_EQ(result, NN_OK); + EXPECT_EQ(fd, getFd); + + result = hbMgr.RemoveIP(ip); + EXPECT_EQ(result, NN_OK); + + result = hbMgr.RemoveByFD(fd); + EXPECT_EQ(result, NN_HEARTBEAT_IP_NO_FOUND); + + close(fd); + } else { + NN_LOG_ERROR("Failed to connect to " << ip << ":" << port); + } + hbMgr.Stop(); +} + +TEST_F(TestCaseRDMAHeartBeat, ClientServerHb) +{ + mStarted.store(false); + std::thread tmpThread(RunServer); + while (!mStarted.load()) { + usleep(10); + } + + RunClient(); + gNeedStop = true; + + tmpThread.join(); +} +#endif \ No newline at end of file diff --git a/test/llt/testcase/transport/rdma/test_rdma_heartbeat.hpp b/test/llt/testcase/transport/rdma/test_rdma_heartbeat.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a0929b9e109e98a3ea7f5c8f7aebc0312f354c6a --- /dev/null +++ b/test/llt/testcase/transport/rdma/test_rdma_heartbeat.hpp @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef _TEST_RDMA_HEARTBEAT__HPP_ +#define _TEST_RDMA_HEARTBEAT__HPP_ +#include + +class TestCaseRDMAHeartBeat : public testing::Test { +public: + TestCaseRDMAHeartBeat(); + virtual void SetUp(void); + virtual void TearDown(void); + +protected: +}; + +#endif diff --git a/test/llt/testcase/transport/rdma/test_rdma_tls.cpp b/test/llt/testcase/transport/rdma/test_rdma_tls.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c27c43d28b17567c2a42fd83e9c9bb49d29783f7 --- /dev/null +++ b/test/llt/testcase/transport/rdma/test_rdma_tls.cpp @@ -0,0 +1,1611 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef RDMA_BUILD_ENABLED +#include + +#include "hcom.h" +#include "hcom_utils.h" +#include "common/net_util.h" +#include "string.h" +#include "test_rdma_tls.hpp" +#include "ut_helper.h" + +using namespace ock::hcom; + +// server +UBSHcomNetDriver *tlsDriver = nullptr; +UBSHcomNetDriver *invalidTlsDriver = nullptr; +UBSHcomNetDriver *notSameCaTlsDriver = nullptr; +UBSHcomNetDriver *certExpiredTlsDriver = nullptr; +UBSHcomNetDriver *certRevokedTlsDriver = nullptr; +UBSHcomNetDriver *cVerifyByNoneTlsDriver = nullptr; +UBSHcomNetDriver *multiLevelCertTlsDriver = nullptr; +UBSHcomNetDriver *abnormalCertChainDriver = nullptr; +UBSHcomNetDriver *normalCertChainDriver = nullptr; +UBSHcomNetDriver *customVerifyTlsDriver = nullptr; + +static UBSHcomNetDriverOptions options {}; + +UBSHcomNetEndpointPtr tlsServerEp = nullptr; +std::string tlsIpSeg = IP_SEG; +std::string certPath; +std::string expiredCertPath; +std::string otherCertPath; +std::string revokedCertPath; +std::string cliVerifyByNoneCertPath; +std::string multiCertPath; +std::string abnormalCertChainPath; +std::string normalCertChainPath; + +std::string syncSendValue = "sync send value"; +std::string syncReplyValue = "sync response by server value"; +std::string asyncSendValue = "async send value"; +std::string asyncSendRawValue = "async send raw value"; +std::string syncSendRawValue = "sync send raw value"; + +using TestOpCode = enum { + CHECK_ASYNC_RESPONSE = 1, + CHECK_SYNC_RESPONSE, + SEND_RAW, + RECEIVE_RAW, +}; + + +using TestRegMrInfo = struct _reg_sgl_info_test_ { + uintptr_t lAddress = 0; + uint32_t lKey = 0; + uint32_t size = 0; +} __attribute__((packed)); +TestRegMrInfo tlsSerlocalMrInfo[NN_NO4]; + +bool driverInitAndStart(UBSHcomNetDriver *driver) +{ + int result = 0; + if ((result = driver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize driver " << result); + return false; + } + NN_LOG_INFO("driver initialized"); + + if ((result = driver->Start()) != 0) { + NN_LOG_ERROR("failed to start driver " << result); + return false; + } + NN_LOG_INFO("driver started"); + return true; +} + + +int TlsServerNewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload); + tlsServerEp = newEP; + return 0; +} + +void TlsServerEndPointBroken(const UBSHcomNetEndpointPtr &tlsServerEp) +{ + NN_LOG_INFO("end point " << tlsServerEp->Id()); +} + +int TlsServerRequestReceived(const UBSHcomNetRequestContext &ctx) +{ + std::string req((char *)ctx.Message()->Data(), ctx.Header().dataLength); + NN_LOG_INFO("request received - " << ctx.Header().opCode << ", dataLen " << req.length()); + + int result = 0; + if (ctx.OpType() == UBSHcomNetRequestContext::NN_RECEIVED) { + if (ctx.Header().opCode == CHECK_SYNC_RESPONSE) { + EXPECT_EQ(syncSendValue.length(), req.length()); + EXPECT_EQ(0, memcmp(syncSendValue.c_str(), req.c_str(), syncSendValue.size())); + + UBSHcomNetTransRequest rsp((void *)(const_cast(syncReplyValue.c_str())), + syncReplyValue.length(), 0); + + if ((result = ctx.EndPoint()->PostSend(ctx.Header().opCode, rsp)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + } else { + EXPECT_EQ(asyncSendValue.length(), req.length()); + EXPECT_EQ(0, memcmp(asyncSendValue.c_str(), req.c_str(), asyncSendValue.size())); + } + } else if (ctx.OpType() == UBSHcomNetRequestContext::NN_RECEIVED_RAW) { + if (ctx.Header().seqNo == CHECK_SYNC_RESPONSE) { + EXPECT_EQ(syncSendRawValue.length(), req.length()); + EXPECT_EQ(0, memcmp(syncSendRawValue.c_str(), req.c_str(), syncSendRawValue.size())); + + UBSHcomNetTransRequest rsp((void *)(const_cast(syncReplyValue.c_str())), + syncReplyValue.length(), 0); + + if ((result = ctx.EndPoint()->PostSendRaw(rsp, ctx.Header().opCode)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + } else { + EXPECT_EQ(asyncSendRawValue.length(), req.length()); + EXPECT_EQ(0, memcmp(asyncSendRawValue.c_str(), req.c_str(), asyncSendRawValue.size())); + } + } + + return 0; +} + +int TlsServerRequestPosted(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request posted"); + return 0; +} + +int TlsServerOneSideDone(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("one side done"); + return 0; +} + +static void Erase(void *pass, int len) {} +static int Verify(void *x509, const char *path) +{ + NN_LOG_INFO("verify by custom func"); + return 0; +} + +static bool CertCallback(const std::string &name, std::string &value) +{ + value = certPath + "/server/cert.pem"; + return true; +} + +static bool PrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = certPath + "/server/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + + return true; +} + +static bool CACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + caPath = certPath + "/CA/cacert.pem"; + cb = nullptr; + return true; +} + +int ValidateTlsCert() +{ + char *buffer; + + if ((buffer = getcwd(NULL, 0)) == NULL) { + NN_LOG_ERROR("Cet path for TLS cert failed"); + return -1; + } + + std::string currentPath = buffer; + + certPath = currentPath + "/../test/opensslcrt/normalCert1"; + otherCertPath = currentPath + "/../test/opensslcrt/normalCert2"; + expiredCertPath = currentPath + "/../test/opensslcrt/expiredCert"; + revokedCertPath = currentPath + "/../test/opensslcrt/crlRevokedCert"; + cliVerifyByNoneCertPath = currentPath + "/../test/opensslcrt/serExpCertCliNoCheck"; + multiCertPath = currentPath + "/../test/opensslcrt/multiLevelCert"; + abnormalCertChainPath = currentPath + "/../test/opensslcrt/abnormalCertChain"; + normalCertChainPath = currentPath + "/../test/opensslcrt/normalCertChain"; + + if (!CanonicalPath(certPath)) { + NN_LOG_ERROR("TLS cert path check failed " << certPath); + return -1; + } + + if (!CanonicalPath(otherCertPath)) { + NN_LOG_ERROR("TLS cert path check failed " << certPath); + return -1; + } + + if (!CanonicalPath(expiredCertPath)) { + NN_LOG_ERROR("TLS cert path check failed " << certPath); + return -1; + } + + return 0; +} + +void setServerDriverCallback(UBSHcomNetDriver *driver, UBSHcomTLSCertificationCallback CertCallback, + UBSHcomTLSCaCallback CACallback, UBSHcomTLSPrivateKeyCallback PrivateKeyCallback) +{ + driver->RegisterNewEPHandler( + std::bind(&TlsServerNewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + driver->RegisterEPBrokenHandler(std::bind(&TlsServerEndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(std::bind(&TlsServerRequestReceived, std::placeholders::_1)); + driver->RegisterReqPostedHandler(std::bind(&TlsServerRequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&TlsServerOneSideDone, std::placeholders::_1)); + + driver->RegisterTLSCertificationCallback(std::bind(CertCallback, std::placeholders::_1, std::placeholders::_2)); + driver->RegisterTLSCaCallback(std::bind(CACallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + driver->RegisterTLSPrivateKeyCallback(std::bind(PrivateKeyCallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); +} + +// server with invalid cert path +static bool InvalidCertCallback(const std::string &name, std::string &value) +{ + value = certPath + "/server/cacert.pem"; + return true; +} + +static bool InvalidPrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = certPath + "/server/cert.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + + return true; +} + +static bool InvalidCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + caPath = certPath + "/CA/key.pem"; + return true; +} + +bool ServerCreateDriverWithInvalidTls() +{ + if (invalidTlsDriver != nullptr) { + NN_LOG_ERROR("invalidTlsDriver already created"); + } + + invalidTlsDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "rdmaTlsCertErrServer", true); + if (invalidTlsDriver == nullptr) { + NN_LOG_ERROR("failed to create invalidTlsDriver already created"); + return false; + } + + options.SetNetDeviceIpMask(tlsIpSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + setServerDriverCallback(invalidTlsDriver, InvalidCertCallback, InvalidCACallback, InvalidPrivateKeyCallback); + + invalidTlsDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(invalidTlsDriver); +} + + +// server with not same Ca cert +static bool VerifyFailedCertCallback(const std::string &name, std::string &value) +{ + value = otherCertPath + "/server/cert.pem"; + return true; +} + +static bool VerifyFailedPrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = otherCertPath + "/server/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + + return true; +} + +static bool VerifyFailedCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + caPath = otherCertPath + "/CA/cacert.pem"; + return true; +} + +bool ServerCreateDriverTlsNotSameCACert() +{ + if (notSameCaTlsDriver != nullptr) { + NN_LOG_ERROR("notSameCaTlsDriver already created"); + } + + notSameCaTlsDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "verifyFailedServer", true); + if (notSameCaTlsDriver == nullptr) { + NN_LOG_ERROR("failed to create notSameCaTlsDriver already created"); + return false; + } + + options.SetNetDeviceIpMask(tlsIpSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + setServerDriverCallback(notSameCaTlsDriver, VerifyFailedCertCallback, VerifyFailedCACallback, + VerifyFailedPrivateKeyCallback); + + notSameCaTlsDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(notSameCaTlsDriver); +} + +// server with expired cert +static bool ExpiredCertCertCallback(const std::string &name, std::string &value) +{ + value = expiredCertPath + "/server/cert.pem"; + return true; +} + +static bool ExpiredCertPrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = expiredCertPath + "/server/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + + return true; +} + +static bool ExpiredCertCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + caPath = expiredCertPath + "/CA/cacert.pem"; + return true; +} + +bool ServerCreateDriverTlsWithExpiredCert() +{ + if (certExpiredTlsDriver != nullptr) { + NN_LOG_ERROR("certExpiredTlsDriver already created"); + } + + certExpiredTlsDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "certExpiredServer", true); + if (certExpiredTlsDriver == nullptr) { + NN_LOG_ERROR("failed to create certExpiredTlsDriver already created"); + } + + options.SetNetDeviceIpMask(tlsIpSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + setServerDriverCallback(certExpiredTlsDriver, ExpiredCertCertCallback, ExpiredCertCACallback, + ExpiredCertPrivateKeyCallback); + + certExpiredTlsDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(certExpiredTlsDriver); +} + +// server with revoked cert +static bool RevokedCertCertCallback(const std::string &name, std::string &value) +{ + value = revokedCertPath + "/server/cert.pem"; + return true; +} + +static bool RevokedCertPrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = revokedCertPath + "/server/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + + return true; +} + +static bool RevokedCertCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + crlPath = revokedCertPath + "/CA/ca.crl"; + caPath = revokedCertPath + "/CA/cacert.pem"; + return true; +} + +bool ServerCreateDriverTlsWithRevokedCert() +{ + if (certRevokedTlsDriver != nullptr) { + NN_LOG_ERROR("certRevokedTlsDriver already created"); + } + + certRevokedTlsDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "certRevokedServer", true); + if (certRevokedTlsDriver == nullptr) { + NN_LOG_ERROR("failed to create certRevokedTlsDriver already created"); + return false; + } + + + options.SetNetDeviceIpMask(tlsIpSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + setServerDriverCallback(certRevokedTlsDriver, RevokedCertCertCallback, RevokedCertCACallback, + RevokedCertPrivateKeyCallback); + + certRevokedTlsDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(certRevokedTlsDriver); +} + +// server with client verify by none +static bool CliVerifyByNoneCertCallback(const std::string &name, std::string &value) +{ + value = cliVerifyByNoneCertPath + "/server/cert.pem"; + return true; +} + +static bool CliVerifyByNonePrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = cliVerifyByNoneCertPath + "/server/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + + return true; +} + +static bool CliVerifyByNoneACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + caPath = cliVerifyByNoneCertPath + "/CA/cacert.pem"; + return true; +} + +bool ServerCreateDriverTlsWithCVerifyByNone() +{ + if (cVerifyByNoneTlsDriver != nullptr) { + NN_LOG_ERROR("cVerifyByNoneTlsDriver already created"); + } + + cVerifyByNoneTlsDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "cliVerifyByNoneServer", true); + if (certRevokedTlsDriver == nullptr) { + NN_LOG_ERROR("failed to create cVerifyByNoneTlsDriver already created"); + return false; + } + + options.SetNetDeviceIpMask(tlsIpSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + setServerDriverCallback(cVerifyByNoneTlsDriver, CliVerifyByNoneCertCallback, CliVerifyByNoneACallback, + CliVerifyByNonePrivateKeyCallback); + + cVerifyByNoneTlsDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(cVerifyByNoneTlsDriver); +} + +// server with multi level CA cert +static bool MultiLevelCertCallback(const std::string &name, std::string &value) +{ + value = multiCertPath + "/server/server.crt"; + return true; +} + +static bool MultiLevelPrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = multiCertPath + "/server/server.key"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + + return true; +} + +static bool MultiLevelCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + std::string rootCa = multiCertPath + "/CA/rootca.crt"; + std::string secondCa = multiCertPath + "/CA/secondca.crt"; + caPath = rootCa + ":" + secondCa; + return true; +} + +bool ServerCreateDriverTlsWithMultiLevelCert() +{ + if (multiLevelCertTlsDriver != nullptr) { + NN_LOG_ERROR("multiLevelCertTlsDriver already created"); + } + + multiLevelCertTlsDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "multiLeverCertServer", true); + if (multiLevelCertTlsDriver == nullptr) { + NN_LOG_ERROR("failed to create multiLevelCertTlsDriver already created"); + return false; + } + + options.SetNetDeviceIpMask(tlsIpSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + setServerDriverCallback(multiLevelCertTlsDriver, MultiLevelCertCallback, MultiLevelCACallback, + MultiLevelPrivateKeyCallback); + + multiLevelCertTlsDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(multiLevelCertTlsDriver); +} + +// server with abnormal cert chain +static bool AbnormalCertChainCertCallback(const std::string &name, std::string &value) +{ + value = abnormalCertChainPath + "/server/cert.pem"; + return true; +} + +static bool AbnormalCertChainPrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = abnormalCertChainPath + "/server/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + return true; +} + +static bool AbnormalCertChainCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + std::string rootCa = abnormalCertChainPath + "/CA/cacert.pem"; + std::string secondCa = abnormalCertChainPath + "/CA/secondca.crt"; + caPath = rootCa; + return true; +} + +bool ServerCreateDriverTlsWithAbnormalCertChain() +{ + if (abnormalCertChainDriver != nullptr) { + NN_LOG_ERROR("multiLevelCertTlsDriver already created"); + } + + abnormalCertChainDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, + "AbnormalCertChainServer", true); + if (abnormalCertChainDriver == nullptr) { + NN_LOG_ERROR("failed to create multiLevelCertTlsDriver already created"); + return false; + } + + options.SetNetDeviceIpMask(tlsIpSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + setServerDriverCallback(abnormalCertChainDriver, AbnormalCertChainCertCallback, AbnormalCertChainCACallback, + AbnormalCertChainPrivateKeyCallback); + + abnormalCertChainDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(abnormalCertChainDriver); +} + +// server with normal cert chain +static bool NormalCertChainCertCallback(const std::string &name, std::string &value) +{ + value = normalCertChainPath + "/server/cert.pem"; + return true; +} + +static bool NormalCertChainPrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = normalCertChainPath + "/server/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + return true; +} + +static bool NormalCertChainCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + std::string rootCa = normalCertChainPath + "/CA/cacert.pem"; + caPath = rootCa; + return true; +} + +bool ServerCreateDriverTlsWithNormalCertChain() +{ + if (abnormalCertChainDriver != nullptr) { + NN_LOG_ERROR("abnormalCertChainDriver already created"); + } + + normalCertChainDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "NormalCertChainServer", true); + if (normalCertChainDriver == nullptr) { + NN_LOG_ERROR("failed to create normalCertChainDriver already created"); + return false; + } + + options.SetNetDeviceIpMask(tlsIpSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + setServerDriverCallback(normalCertChainDriver, NormalCertChainCertCallback, NormalCertChainCACallback, + NormalCertChainPrivateKeyCallback); + + normalCertChainDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(normalCertChainDriver); +} + +// server with custom verify func +static bool CustomVerifyCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + caPath = certPath + "/CA/cacert.pem"; + cb = std::bind(&Verify, std::placeholders::_1, std::placeholders::_2); + peerCertVerifyType = ock::hcom::VERIFY_BY_CUSTOM_FUNC; + return true; +} + +bool ServerCreateDriverTlsCustomVerify() +{ + if (customVerifyTlsDriver != nullptr) { + NN_LOG_ERROR("multiLevelCertTlsDriver already created"); + } + + customVerifyTlsDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "customVerifyServer", true); + if (multiLevelCertTlsDriver == nullptr) { + NN_LOG_ERROR("failed to create multiLevelCertTlsDriver already created"); + return false; + } + + options.SetNetDeviceIpMask(tlsIpSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + setServerDriverCallback(customVerifyTlsDriver, CertCallback, CustomVerifyCACallback, PrivateKeyCallback); + + customVerifyTlsDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(customVerifyTlsDriver); +} + +// server in norma case +bool ServerCreateDriverWithTls() +{ + if (tlsDriver != nullptr) { + NN_LOG_ERROR("tlsDriver already created"); + return false; + } + + tlsDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "rdmaTlsServer", true); + if (tlsDriver == nullptr) { + NN_LOG_ERROR("failed to create tlsDriver already created"); + return false; + } + + options.SetNetDeviceIpMask(tlsIpSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + setServerDriverCallback(tlsDriver, CertCallback, CACallback, PrivateKeyCallback); + + tlsDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(tlsDriver); +} + +bool ServerRegSglMemWithTls() +{ + for (uint16_t i = 0; i < NN_NO4; i++) { + UBSHcomNetMemoryRegionPtr mr; + auto result = tlsDriver->CreateMemoryRegion(NN_NO8, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + tlsSerlocalMrInfo[i].lAddress = mr->GetAddress(); + tlsSerlocalMrInfo[i].lKey = mr->GetLKey(); + tlsSerlocalMrInfo[i].size = NN_NO8; + memset(reinterpret_cast(tlsSerlocalMrInfo[i].lAddress), 0, NN_NO8); + } + + return true; +} + +// client +UBSHcomNetDriver *tlsClientDriver = nullptr; +UBSHcomNetDriver *tlsClientCertExpiredDriver = nullptr; +UBSHcomNetDriver *tlsClientCertRevokedDriver = nullptr; +UBSHcomNetDriver *tlsClientVerifyByNoneDriver = nullptr; +UBSHcomNetDriver *tlsClientMultiLevelCertDriver = nullptr; +UBSHcomNetDriver *tlsClientAbnormalCertChainDriver = nullptr; +UBSHcomNetDriver *tlsClientNormalCertChainDriver = nullptr; +UBSHcomNetDriver *tlsClientCustomVerifyTlsDriver = nullptr; + +UBSHcomNetEndpointPtr tlsClientSyncEp = nullptr; +UBSHcomNetEndpointPtr tlsClientAsyncEp = nullptr; + +void TlsClientEndPointBroken(const UBSHcomNetEndpointPtr &clientEp) +{ + NN_LOG_INFO("end point " << clientEp->Id() << " broken"); +} + +int TlsClientRequestReceived(const UBSHcomNetRequestContext &ctx) +{ + return 0; +} + +int TlsClientRequestPosted(const UBSHcomNetRequestContext &ctx) +{ + return 0; +} + +int TlsClientOneSideDone(const UBSHcomNetRequestContext &ctx) +{ + return 0; +} + +void setClientDriverCallback(UBSHcomNetDriver *driver, UBSHcomTLSCertificationCallback CertCallback, + UBSHcomTLSCaCallback CACallback, UBSHcomTLSPrivateKeyCallback PrivateKeyCallback) +{ + driver->RegisterEPBrokenHandler(std::bind(&TlsClientEndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(std::bind(&TlsClientRequestReceived, std::placeholders::_1)); + driver->RegisterReqPostedHandler(std::bind(&TlsClientRequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&TlsClientOneSideDone, std::placeholders::_1)); + + driver->RegisterTLSCertificationCallback(std::bind(CertCallback, std::placeholders::_1, std::placeholders::_2)); + driver->RegisterTLSCaCallback(std::bind(CACallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + driver->RegisterTLSPrivateKeyCallback(std::bind(PrivateKeyCallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); +} + +// client in normal case +static bool ClientCertCallback(const std::string &name, std::string &value) +{ + value = certPath + "/client/cert.pem"; + return true; +} + +static bool ClientPrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = certPath + "/client/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + + return true; +} + +static bool ClientCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + caPath = certPath + "/CA/cacert.pem"; + return true; +} + + +bool ClientCreateDriverWithTls() +{ + if (tlsClientDriver != nullptr) { + NN_LOG_ERROR("tlsClientDriver already created"); + } + + tlsClientDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "rdmaTlsClient1", false); + if (tlsClientDriver == nullptr) { + NN_LOG_ERROR("failed to create tlsClientDriver already created"); + } + setClientDriverCallback(tlsClientDriver, &ClientCertCallback, &ClientCACallback, &ClientPrivateKeyCallback); + + tlsClientDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(tlsClientDriver); +} + +bool SyncClientConnectWithInvalidTls() +{ + if (tlsClientDriver == nullptr) { + NN_LOG_ERROR("tlsClientDriver is null"); + return false; + } + + int result = 0; + if ((result = tlsClientDriver->Connect("hello world", tlsClientSyncEp, NET_EP_EVENT_POLLING)) != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + EXPECT_EQ(NNCode::NN_OOB_CLIENT_SOCKET_ERROR, result); + return false; + } + + tlsClientSyncEp->PeerIpAndPort(); + return true; +} + +bool SyncClientConnectWithTls() +{ + if (tlsClientDriver == nullptr) { + NN_LOG_ERROR("tlsClientDriver is null"); + return false; + } + + int result = 0; + if ((result = tlsClientDriver->Connect("hello world", tlsClientSyncEp, NET_EP_EVENT_POLLING)) != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + return true; +} + +bool AsyncClientConnectWithTls() +{ + if (tlsClientDriver == nullptr) { + NN_LOG_ERROR("clientDriver is null"); + return false; + } + + int result = 0; + if ((result = tlsClientDriver->Connect("hello world", tlsClientAsyncEp, 0)) != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + return true; +} + +void TlsAsyncRequest() +{ + int result; + UBSHcomNetTransRequest req((void *)(const_cast(asyncSendValue.c_str())), asyncSendValue.length(), 0); + + if ((result = tlsClientAsyncEp->PostSend(0, req)) != 0) { + NN_LOG_INFO("failed to post message to data to server"); + return; + } + + EXPECT_EQ(NN_OK, result); +} + +void TlsAsyncSendRawRequest() +{ + int result = 0; + UBSHcomNetTransRequest req((void *)(const_cast(asyncSendRawValue.c_str())), asyncSendRawValue.length(), 0); + + if ((result = tlsClientAsyncEp->PostSendRaw(req, 1)) != 0) { + NN_LOG_ERROR("failed to post message to data to server"); + return; + } + + EXPECT_EQ(result, NN_OK); +} + + +void TlsSyncRequests() +{ + int result; + UBSHcomNetTransRequest req((void *)(const_cast(syncSendValue.c_str())), syncSendValue.length(), 0); + + if ((result = tlsClientSyncEp->PostSend(CHECK_SYNC_RESPONSE, req)) != 0) { + NN_LOG_INFO("failed to post message to data to server"); + return; + } + + + if ((result = tlsClientSyncEp->WaitCompletion(0)) != 0) { + NN_LOG_INFO("failed to wait completion, result " << result); + return; + } + + UBSHcomNetResponseContext respCtx {}; + if ((result = tlsClientSyncEp->Receive(respCtx)) != 0) { + NN_LOG_INFO("failed to get response, result " << result); + return; + } + + EXPECT_EQ(NN_OK, result); + EXPECT_EQ(syncReplyValue.length(), respCtx.Message()->DataLen()); + EXPECT_EQ(0, strncmp(syncReplyValue.c_str(), (char *)respCtx.Message()->Data(), syncReplyValue.length())); +} + +void TlsSyncSendRawRequests() +{ + int result; + UBSHcomNetTransRequest req((void *)(const_cast(syncSendRawValue.c_str())), syncSendRawValue.length(), 0); + + if ((result = tlsClientSyncEp->PostSendRaw(req, CHECK_SYNC_RESPONSE)) != 0) { + NN_LOG_INFO("failed to post message to data to server"); + return; + } + + + if ((result = tlsClientSyncEp->WaitCompletion(0)) != 0) { + NN_LOG_INFO("failed to wait completion, result " << result); + return; + } + + UBSHcomNetResponseContext respCtx {}; + if ((result = tlsClientSyncEp->ReceiveRaw(respCtx)) != 0) { + NN_LOG_INFO("failed to get response, result " << result); + return; + } + + EXPECT_EQ(NN_OK, result); + EXPECT_EQ(syncReplyValue.length(), respCtx.Message()->DataLen()); + EXPECT_EQ(0, strncmp(syncReplyValue.c_str(), (char *)respCtx.Message()->Data(), syncReplyValue.length())); +} + +// client with expired cert +static bool CertExpiredClientCertCallback(const std::string &name, std::string &value) +{ + value = expiredCertPath + "/client/cert.pem"; + return true; +} + +static bool CertExpiredClientPrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = expiredCertPath + "/client/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + + return true; +} + +static bool CertExpiredClientCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + caPath = expiredCertPath + "/CA/cacert.pem"; + return true; +} + +bool ClientCreateDriverWithTlsExpiredCert() +{ + if (tlsClientCertExpiredDriver != nullptr) { + NN_LOG_ERROR("tlsClientCertExpiredDriver already created"); + } + + tlsClientCertExpiredDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "certExpiredClient", false); + if (tlsClientCertExpiredDriver == nullptr) { + NN_LOG_ERROR("failed to create tlsClientCertExpiredDriver already created"); + return false; + } + setClientDriverCallback(tlsClientCertExpiredDriver, &CertExpiredClientCertCallback, &CertExpiredClientCACallback, + &CertExpiredClientPrivateKeyCallback); + + tlsClientCertExpiredDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(tlsClientCertExpiredDriver); +} + +bool SyncClientConnectWithTlsCertExpired() +{ + if (tlsClientCertExpiredDriver == nullptr) { + NN_LOG_ERROR("tlsClientCertExpiredDriver is null"); + return false; + } + + int result = 0; + if ((result = tlsClientCertExpiredDriver->Connect("hello world", tlsClientSyncEp, NET_EP_EVENT_POLLING)) != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + return true; +} + +// client with revoked cert +static bool CertRevokedClientCertCallback(const std::string &name, std::string &value) +{ + value = revokedCertPath + "/client/cert.pem"; + return true; +} + +static bool CertRevokedClientPrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = revokedCertPath + "/client/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + + return true; +} + +static bool CertRevokedClientCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + crlPath = revokedCertPath + "/CA/ca.crl"; + caPath = revokedCertPath + "/CA/cacert.pem"; + return true; +} + +bool ClientCreateDriverWithTlsRevokedCert() +{ + if (tlsClientCertRevokedDriver != nullptr) { + NN_LOG_ERROR("tlsClientCertRevokedDriver already created"); + } + + tlsClientCertRevokedDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "certRevokedClient", false); + if (tlsClientCertRevokedDriver == nullptr) { + NN_LOG_ERROR("failed to create tlsClientCertRevokedDriver already created"); + return false; + } + + setClientDriverCallback(tlsClientCertRevokedDriver, &CertRevokedClientCertCallback, &CertRevokedClientCACallback, + &CertRevokedClientPrivateKeyCallback); + + + tlsClientCertRevokedDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(tlsClientCertRevokedDriver); +} + + +bool SyncClientConnectWithTlsCertRevoked() +{ + if (tlsClientCertRevokedDriver == nullptr) { + NN_LOG_ERROR("tlsClientCertRevokedDriver is null"); + return false; + } + + int result = 0; + if ((result = tlsClientCertRevokedDriver->Connect("hello world", tlsClientSyncEp, NET_EP_EVENT_POLLING)) != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + return true; +} + +// client with verify by none +static bool VerifyByNoneClientCertCallback(const std::string &name, std::string &value) +{ + value = cliVerifyByNoneCertPath + "/client/cert.pem"; + return true; +} + +static bool VerifyByNoneClientPrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = cliVerifyByNoneCertPath + "/client/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + return true; +} + +static bool VerifyByNoneClientCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + caPath = cliVerifyByNoneCertPath + "/CA/cacert.pem"; + peerCertVerifyType = ock::hcom::VERIFY_BY_NONE; + return true; +} + +bool ClientCreateDriverWithTlsVerifyByNone() +{ + if (tlsClientVerifyByNoneDriver != nullptr) { + NN_LOG_ERROR("tlsClientVerifyByNoneDriver already created"); + } + + tlsClientVerifyByNoneDriver = + UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "rdmaTlsClientVerifyByNoneDriver", false); + if (tlsClientVerifyByNoneDriver == nullptr) { + NN_LOG_ERROR("failed to create tlsClientVerifyByNoneDriver already created"); + return false; + } + setClientDriverCallback(tlsClientVerifyByNoneDriver, &VerifyByNoneClientCertCallback, &VerifyByNoneClientCACallback, + &VerifyByNoneClientPrivateKeyCallback); + + tlsClientVerifyByNoneDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(tlsClientVerifyByNoneDriver); +} + + +bool SyncClientConnectWithTlsVerifyByNone() +{ + if (tlsClientVerifyByNoneDriver == nullptr) { + NN_LOG_ERROR("tlsClientCertRevokedDriver is null"); + return false; + } + + int result = 0; + if ((result = tlsClientVerifyByNoneDriver->Connect("hello world", tlsClientSyncEp, NET_EP_EVENT_POLLING)) != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + return true; +} + +// client with multi level ca +static bool MultiLevelClientCertCallback(const std::string &name, std::string &value) +{ + value = multiCertPath + "/client/client.crt"; + return true; +} + +static bool MultiLevelClientPrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = multiCertPath + "/client/client.key"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + return true; +} + +static bool MultiLevelClientCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + std::string rootCa = multiCertPath + "/CA/rootca.crt"; + std::string secondCa = multiCertPath + "/CA/secondca.crt"; + caPath = rootCa + ":" + secondCa; + return true; +} + +bool ClientCreateDriverWithTlsMultiLevelCert() +{ + if (tlsClientMultiLevelCertDriver != nullptr) { + NN_LOG_ERROR("tlsClientMultiLevelCertDriver already created"); + return false; + } + + tlsClientMultiLevelCertDriver = + UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "rdmaTlsClientMultiLevelDriver", false); + if (tlsClientMultiLevelCertDriver == nullptr) { + NN_LOG_ERROR("failed to create tlsClientMultiLevelCertDriver already created"); + return false; + } + + setClientDriverCallback(tlsClientMultiLevelCertDriver, &MultiLevelClientCertCallback, &MultiLevelClientCACallback, + &MultiLevelClientPrivateKeyCallback); + + tlsClientMultiLevelCertDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(tlsClientMultiLevelCertDriver); +} + +bool SyncClientConnectWithMultiLevelCert() +{ + if (tlsClientMultiLevelCertDriver == nullptr) { + NN_LOG_ERROR("tlsClientCertRevokedDriver is null"); + return false; + } + + int result = 0; + if ((result = tlsClientMultiLevelCertDriver->Connect("hello world", tlsClientSyncEp, NET_EP_EVENT_POLLING)) != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + return true; +} + +// client with abnormal cert chain +static bool AbnormalClientCertChainCallback(const std::string &name, std::string &value) +{ + value = abnormalCertChainPath + "/client/cert.pem"; + return true; +} + +static bool AbnormalClientCertChainPrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, + int &len, UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = abnormalCertChainPath + "/client/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + return true; +} + +static bool AbnormalClientCertChainCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + std::string rootCa = abnormalCertChainPath + "/CA/cacert.pem"; + caPath = rootCa; + return true; +} + +bool ClientCreateDriverWithTlsAbnormalCertChain() +{ + if (tlsClientAbnormalCertChainDriver != nullptr) { + NN_LOG_ERROR("tlsClientAbnormalCertChainDriver already created"); + return false; + } + + tlsClientAbnormalCertChainDriver = + UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "tlsClientAbnormalCertChainDriver", false); + if (tlsClientAbnormalCertChainDriver == nullptr) { + NN_LOG_ERROR("failed to create tlsClientAbnormalCertChainDriver already created"); + return false; + } + + setClientDriverCallback(tlsClientAbnormalCertChainDriver, &AbnormalClientCertChainCallback, + &AbnormalClientCertChainCACallback, &AbnormalClientCertChainPrivateKeyCallback); + + tlsClientAbnormalCertChainDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(tlsClientAbnormalCertChainDriver); +} + +bool SyncClientConnectWithAbnormalCertChain() +{ + if (tlsClientAbnormalCertChainDriver == nullptr) { + NN_LOG_ERROR("tlsClientAbnormalCertChainDriver is null"); + return false; + } + + int result = 0; + if ((result = tlsClientAbnormalCertChainDriver->Connect("hello world", tlsClientSyncEp, NET_EP_EVENT_POLLING)) != + 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + return true; +} + +// client with normal cert chain +static bool NormalClientCertChainCallback(const std::string &name, std::string &value) +{ + value = normalCertChainPath + "/client/cert.pem"; + return true; +} + +static bool NormalClientCertChainPrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, + int &len, UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = normalCertChainPath + "/client/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + return true; +} + +static bool NormalClientCertChainCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + std::string rootCa = normalCertChainPath + "/CA/cacert.pem"; + caPath = rootCa; + return true; +} + +bool ClientCreateDriverWithTlsNormalCertChain() +{ + if (tlsClientNormalCertChainDriver != nullptr) { + NN_LOG_ERROR("tlsClientNormalCertChainDriver already created"); + return false; + } + + tlsClientNormalCertChainDriver = + UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "tlsClientNormalCertChainDriver", false); + if (tlsClientNormalCertChainDriver == nullptr) { + NN_LOG_ERROR("failed to create tlsClientNormalCertChainDriver already created"); + return false; + } + + setClientDriverCallback(tlsClientNormalCertChainDriver, &NormalClientCertChainCallback, + &NormalClientCertChainCACallback, &NormalClientCertChainPrivateKeyCallback); + + tlsClientNormalCertChainDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(tlsClientNormalCertChainDriver); +} + +bool SyncClientConnectWithNormalCertChain() +{ + if (tlsClientNormalCertChainDriver == nullptr) { + NN_LOG_ERROR("tlsClientNormalCertChainDriver is null"); + return false; + } + + int result = 0; + if ((result = tlsClientNormalCertChainDriver->Connect("hello world", tlsClientSyncEp, NET_EP_EVENT_POLLING)) != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + return true; +} + +// client with custom verify func +static bool CustomVerifyClientCACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + caPath = certPath + "/CA/cacert.pem"; + cb = std::bind(&Verify, std::placeholders::_1, std::placeholders::_2); + peerCertVerifyType = ock::hcom::VERIFY_BY_CUSTOM_FUNC; + return true; +} + +bool ClientCreateDriverWithTlsCustomVerify() +{ + if (tlsClientCustomVerifyTlsDriver != nullptr) { + NN_LOG_ERROR("tlsClientCustomVerifyTlsDriver already created"); + return false; + } + + tlsClientCustomVerifyTlsDriver = + UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, "rdmaTlsClientCustomVerifyDriver", false); + if (tlsClientCustomVerifyTlsDriver == nullptr) { + NN_LOG_ERROR("failed to create tlsClientCustomVerifyTlsDriver already created"); + return false; + } + + setClientDriverCallback(tlsClientCustomVerifyTlsDriver, &ClientCertCallback, &CustomVerifyClientCACallback, + &ClientPrivateKeyCallback); + + tlsClientCustomVerifyTlsDriver->OobIpAndPort(BASE_IP, 9998); + + return driverInitAndStart(tlsClientCustomVerifyTlsDriver); +} + +bool SyncClientConnectWithCustomVerify() +{ + if (tlsClientCustomVerifyTlsDriver == nullptr) { + NN_LOG_ERROR("tlsClientCustomVerifyTlsDriver is null"); + return false; + } + + int result = 0; + if ((result = tlsClientCustomVerifyTlsDriver->Connect("hello world", tlsClientSyncEp, NET_EP_EVENT_POLLING)) != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + return true; +} + +TestCaseRdmaTLS::TestCaseRdmaTLS() {} + +void TestCaseRdmaTLS::SetUp() +{ + ASSERT_EQ(0, ValidateTlsCert()); + MOCKER(ReadRoCEVersionFromFile).stubs().will(returnValue(0)); + + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; // 只支持EVENT模式 + options.mrSendReceiveSegSize = 1024; + options.mrSendReceiveSegCount = 1024; + options.prePostReceiveSizePerQP = 32; + options.enableTls = true; + options.cipherSuite = ock::hcom::AES_GCM_256; + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestCaseRdmaTLS::TearDown() +{ + GlobalMockObject::verify(); +} + + +TEST_F(TestCaseRdmaTLS, RDMATLSSuccess) +{ + bool result = ServerCreateDriverWithTls(); + EXPECT_EQ(true, result); + + result = ServerRegSglMemWithTls(); + EXPECT_EQ(true, result); + + result = ClientCreateDriverWithTls(); + EXPECT_EQ(true, result); + + result = SyncClientConnectWithTls(); + EXPECT_EQ(true, result); + TlsSyncRequests(); + TlsSyncSendRawRequests(); + + result = AsyncClientConnectWithTls(); + EXPECT_EQ(true, result); + TlsAsyncRequest(); + TlsAsyncSendRawRequest(); + + tlsClientDriver->Stop(); + tlsClientDriver->UnInitialize(); + + tlsDriver->Stop(); + tlsDriver->UnInitialize(); + + UBSHcomNetDriver::DestroyInstance(tlsClientDriver->Name()); + UBSHcomNetDriver::DestroyInstance(tlsDriver->Name()); +} + +TEST_F(TestCaseRdmaTLS, RDMATLSCertPathInvalid) +{ + ASSERT_EQ(0, ValidateTlsCert()); + MOCKER(ReadRoCEVersionFromFile).stubs().will(returnValue(0)); + bool result = ServerCreateDriverWithInvalidTls(); + EXPECT_EQ(true, result); + + result = ClientCreateDriverWithTls(); + EXPECT_EQ(true, result); + + result = SyncClientConnectWithInvalidTls(); + EXPECT_EQ(false, result); + + tlsClientDriver->Stop(); + tlsClientDriver->UnInitialize(); + + invalidTlsDriver->Stop(); + invalidTlsDriver->UnInitialize(); + + UBSHcomNetDriver::DestroyInstance(tlsClientDriver->Name()); + UBSHcomNetDriver::DestroyInstance(invalidTlsDriver->Name()); +} + +TEST_F(TestCaseRdmaTLS, RDMATLSCertNotSameCAFailed) +{ + bool result = ServerCreateDriverTlsNotSameCACert(); + EXPECT_EQ(true, result); + + result = ClientCreateDriverWithTls(); + EXPECT_EQ(true, result); + + result = SyncClientConnectWithTls(); + EXPECT_EQ(false, result); + + tlsClientDriver->Stop(); + tlsClientDriver->UnInitialize(); + + notSameCaTlsDriver->Stop(); + notSameCaTlsDriver->UnInitialize(); + + UBSHcomNetDriver::DestroyInstance(tlsClientDriver->Name()); + UBSHcomNetDriver::DestroyInstance(notSameCaTlsDriver->Name()); +} + +TEST_F(TestCaseRdmaTLS, RDMATLSCertExpiredFailed) +{ + bool result = ServerCreateDriverTlsWithExpiredCert(); + EXPECT_EQ(true, result); + + result = ClientCreateDriverWithTlsExpiredCert(); + EXPECT_EQ(true, result); + + result = SyncClientConnectWithTlsCertExpired(); + EXPECT_EQ(false, result); + + tlsClientCertExpiredDriver->Stop(); + tlsClientCertExpiredDriver->UnInitialize(); + + certExpiredTlsDriver->Stop(); + certExpiredTlsDriver->UnInitialize(); + + UBSHcomNetDriver::DestroyInstance(tlsClientCertExpiredDriver->Name()); + UBSHcomNetDriver::DestroyInstance(certExpiredTlsDriver->Name()); +} + +TEST_F(TestCaseRdmaTLS, RDMATLSCertRevokedFailed) +{ + bool result = ServerCreateDriverTlsWithRevokedCert(); + EXPECT_EQ(true, result); + + result = ClientCreateDriverWithTlsRevokedCert(); + EXPECT_EQ(true, result); + + result = SyncClientConnectWithTlsCertRevoked(); + EXPECT_EQ(false, result); + + tlsClientCertRevokedDriver->Stop(); + tlsClientCertRevokedDriver->UnInitialize(); + + certRevokedTlsDriver->Stop(); + certRevokedTlsDriver->UnInitialize(); + + UBSHcomNetDriver::DestroyInstance(tlsClientCertRevokedDriver->Name()); + UBSHcomNetDriver::DestroyInstance(certRevokedTlsDriver->Name()); +} + +TEST_F(TestCaseRdmaTLS, RDMATLSVerifyByNoneInClientSuccess) +{ + bool result = ServerCreateDriverTlsWithCVerifyByNone(); + EXPECT_EQ(true, result); + + result = ClientCreateDriverWithTlsVerifyByNone(); + EXPECT_EQ(true, result); + + result = SyncClientConnectWithTlsVerifyByNone(); + EXPECT_EQ(true, result); + + tlsClientVerifyByNoneDriver->Stop(); + tlsClientVerifyByNoneDriver->UnInitialize(); + + cVerifyByNoneTlsDriver->Stop(); + cVerifyByNoneTlsDriver->UnInitialize(); + + UBSHcomNetDriver::DestroyInstance(tlsClientVerifyByNoneDriver->Name()); + UBSHcomNetDriver::DestroyInstance(cVerifyByNoneTlsDriver->Name()); +} + +TEST_F(TestCaseRdmaTLS, RDMATLSMultiLevelCertSuccess) +{ + bool result = ServerCreateDriverTlsWithMultiLevelCert(); + EXPECT_EQ(true, result); + + result = ClientCreateDriverWithTlsMultiLevelCert(); + EXPECT_EQ(true, result); + + result = SyncClientConnectWithMultiLevelCert(); + EXPECT_EQ(true, result); + + multiLevelCertTlsDriver->Stop(); + multiLevelCertTlsDriver->UnInitialize(); + + tlsClientMultiLevelCertDriver->Stop(); + tlsClientMultiLevelCertDriver->UnInitialize(); + + UBSHcomNetDriver::DestroyInstance(multiLevelCertTlsDriver->Name()); + UBSHcomNetDriver::DestroyInstance(tlsClientMultiLevelCertDriver->Name()); +} + +TEST_F(TestCaseRdmaTLS, RDMATLSAbnormalCertChainFailed) +{ + bool result = ServerCreateDriverTlsWithAbnormalCertChain(); + EXPECT_EQ(true, result); + + result = ClientCreateDriverWithTlsAbnormalCertChain(); + EXPECT_EQ(true, result); + + result = SyncClientConnectWithAbnormalCertChain(); + EXPECT_EQ(false, result); + + abnormalCertChainDriver->Stop(); + abnormalCertChainDriver->UnInitialize(); + + tlsClientAbnormalCertChainDriver->Stop(); + tlsClientAbnormalCertChainDriver->UnInitialize(); + + UBSHcomNetDriver::DestroyInstance(abnormalCertChainDriver->Name()); + UBSHcomNetDriver::DestroyInstance(tlsClientAbnormalCertChainDriver->Name()); +} + +TEST_F(TestCaseRdmaTLS, RDMATLSNormalCertChainSuccess) +{ + bool result = ServerCreateDriverTlsWithNormalCertChain(); + EXPECT_EQ(true, result); + + result = ClientCreateDriverWithTlsNormalCertChain(); + EXPECT_EQ(true, result); + + result = SyncClientConnectWithNormalCertChain(); + EXPECT_EQ(true, result); + + normalCertChainDriver->Stop(); + normalCertChainDriver->UnInitialize(); + + tlsClientNormalCertChainDriver->Stop(); + tlsClientNormalCertChainDriver->UnInitialize(); + + UBSHcomNetDriver::DestroyInstance(normalCertChainDriver->Name()); + UBSHcomNetDriver::DestroyInstance(tlsClientNormalCertChainDriver->Name()); +} + +TEST_F(TestCaseRdmaTLS, RDMATLSVerifyByCustomFuncSuccess) +{ + bool result = ServerCreateDriverTlsCustomVerify(); + EXPECT_EQ(true, result); + + result = ClientCreateDriverWithTlsCustomVerify(); + EXPECT_EQ(true, result); + + result = SyncClientConnectWithCustomVerify(); + EXPECT_EQ(true, result); + + tlsClientCustomVerifyTlsDriver->Stop(); + tlsClientCustomVerifyTlsDriver->UnInitialize(); + + customVerifyTlsDriver->Stop(); + customVerifyTlsDriver->UnInitialize(); + + UBSHcomNetDriver::DestroyInstance(tlsClientCustomVerifyTlsDriver->Name()); + UBSHcomNetDriver::DestroyInstance(customVerifyTlsDriver->Name()); +} +#endif \ No newline at end of file diff --git a/test/llt/testcase/transport/rdma/test_rdma_tls.hpp b/test/llt/testcase/transport/rdma/test_rdma_tls.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ae0ae7f4ff5e3de8d9f01020a77ca45a60f8635b --- /dev/null +++ b/test/llt/testcase/transport/rdma/test_rdma_tls.hpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef _TEST_RDMA_TLS_HPP_ +#define _TEST_RDMA_TLS_HPP_ +#include +#include + +class TestCaseRdmaTLS : public testing::Test { +public: + TestCaseRdmaTLS(); + virtual void SetUp(void); + virtual void TearDown(void); + +protected: +}; + +#endif diff --git a/test/llt/testcase/transport/shm/test_shm_common.h b/test/llt/testcase/transport/shm/test_shm_common.h new file mode 100644 index 0000000000000000000000000000000000000000..cc155ab8c12725e71f4d9391ec45c7e98f2500df --- /dev/null +++ b/test/llt/testcase/transport/shm/test_shm_common.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_TEST_SHM_COMMON_H +#define HCOM_TEST_SHM_COMMON_H +#define BASE_IP "192.168.100.204" +#define IP_SEG "192.168.100.0/24" +#define UDSNAME "shm_server_ut" + +using TestRegMrInfo = struct _reg_sgl_info_test_ { + uintptr_t lAddress = 0; + uint32_t lKey = 0; + uint32_t size = 0; +} __attribute__((packed)); +#endif // HCOM_TEST_SHM_COMMON_H diff --git a/test/llt/testcase/transport/shm/test_shm_driver_oob.cpp b/test/llt/testcase/transport/shm/test_shm_driver_oob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..258ec2c535da275468f877c4bb88e45189141c85 --- /dev/null +++ b/test/llt/testcase/transport/shm/test_shm_driver_oob.cpp @@ -0,0 +1,522 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include +#include + +#include "hcom.h" +#include "net_mem_pool_fixed.h" +#include "openssl_api_wrapper.h" +#include "shm_common.h" +#include "shm_handle.h" +#include "shm_composed_endpoint.h" +#include "shm_mr_pool.h" +#include "test_shm_common.h" +#include "test_shm_driver_oob.h" + +using namespace ock::hcom; +TestShmDriverOob::TestShmDriverOob() {} + +UBSHcomNetEndpointPtr shmEp = nullptr; +UBSHcomNetDriverOptions shmOptions {}; +static int port = 8091; +UBSHcomNetDriver *shmServerDriver; +UBSHcomNetDriver *shmClientDriver; +UBSHcomNetTransSgeIov iovPtrShmServer[4]; +UBSHcomNetTransSgeIov iovPtrShmClient[4]; +static int g_nameSeed = 0; + +int shmOobNewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort); + shmEp = newEP; + return 0; +} + +void shmOobEndPointBroken(const UBSHcomNetEndpointPtr &brokenEp) +{ + NN_LOG_INFO("end point " << brokenEp->Id()); +} + +int shmOobRequestReceived(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request received - " << ctx.Header().opCode << ", dataLen " << ctx.Header().dataLength); + return 0; +} + +int shmOobRequestPosted(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request posted"); + return 0; +} + + +int shmOobOneSideDone(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("one side done"); + return 0; +} + + +void SetCallBack(UBSHcomNetDriver *driver) +{ + driver->RegisterNewEPHandler( + std::bind(&shmOobNewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + driver->RegisterEPBrokenHandler(std::bind(&shmOobEndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(std::bind(&shmOobRequestReceived, std::placeholders::_1)); + driver->RegisterReqPostedHandler(std::bind(&shmOobRequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&shmOobOneSideDone, std::placeholders::_1)); +} + +bool RegisterShmMemory(UBSHcomNetDriver *driver, UBSHcomNetTransSgeIov iovs[], + std::vector &mrs) +{ + for (int i = 0; i < 4; i++) { + auto &iov = iovs[i]; + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(NN_NO8, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + iov.lAddress = mr->GetAddress(); + iov.lKey = mr->GetLKey(); + iov.size = NN_NO8; + mrs.push_back(mr); + memset(reinterpret_cast(iov.lAddress), 0, iov.size); + } + return true; +} + +static void DestoryShmMemory(UBSHcomNetDriver *driver, std::vector &mrs) +{ + while (!mrs.empty()) { + driver->DestroyMemoryRegion(mrs.back()); + mrs.pop_back(); + } +} + +void TestShmDriverOob::SetUp() +{ + shmOptions.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + shmOptions.SetNetDeviceIpMask(IP_SEG); + shmOptions.pollingBatchSize = 16; + shmOptions.SetWorkerGroups("1"); + shmOptions.SetWorkerGroupsCpuSet("1-1"); + shmOptions.dontStartWorkers = false; + shmOptions.oobType = ock::hcom::NET_OOB_UDS; + shmOptions.enableTls = false; + + shmServerDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::SHM, + "shm_oob" + std::to_string(g_nameSeed++), true); + UBSHcomNetOobUDSListenerOptions listenOpt; + listenOpt.Name(UDSNAME); + listenOpt.perm = 0; + shmServerDriver->AddOobUdsOptions(listenOpt); + + shmClientDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::SHM, + "shm_oob" + std::to_string(g_nameSeed++), false); + shmServerDriver->OobIpAndPort(BASE_IP, port); + shmClientDriver->OobIpAndPort(BASE_IP, port++); + SetCallBack(shmServerDriver); + SetCallBack(shmClientDriver); + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestShmDriverOob::TearDown() +{ + if (shmEp != nullptr) { + shmEp.Set(nullptr); + } + std::string serverName = shmServerDriver->Name(); + std::string clientName = shmClientDriver->Name(); + if (shmServerDriver->IsStarted()) { + shmServerDriver->Stop(); + } + if (shmServerDriver->IsInited()) { + shmServerDriver->UnInitialize(); + } + if (shmClientDriver->IsStarted()) { + shmClientDriver->Stop(); + } + if (shmClientDriver->IsInited()) { + shmClientDriver->UnInitialize(); + } + UBSHcomNetDriver::DestroyInstance(serverName); + UBSHcomNetDriver::DestroyInstance(clientName); + + GlobalMockObject::verify(); +} + +TEST_F(TestShmDriverOob, InitSuccess) +{ + NResult result = shmServerDriver->Initialize(shmOptions); + EXPECT_EQ(NNCode::NN_OK, result); + shmServerDriver->UnInitialize(); +} + +TEST_F(TestShmDriverOob, InitSuccessTwice) +{ + shmServerDriver->Initialize(shmOptions); + NResult result = shmServerDriver->Initialize(shmOptions); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestShmDriverOob, InitSuccessWithBusyPolling) +{ + shmOptions.mode = UBSHcomNetDriverWorkingMode::NET_BUSY_POLLING; + NResult result = shmServerDriver->Initialize(shmOptions); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestShmDriverOob, InitFailWithInvaildParam) +{ + shmOptions.qpSendQueueSize = 4; + NResult result = shmServerDriver->Initialize(shmOptions); + shmOptions.qpSendQueueSize = NN_NO256; + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestShmDriverOob, InitSuccessWithoutSetWorkGroup) +{ + shmOptions.SetWorkerGroups(""); + NResult result = shmServerDriver->Initialize(shmOptions); + shmOptions.SetWorkerGroups("1"); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestShmDriverOob, InitFailWithFailToInitWorker) +{ + MOCKER((int(*)(int))syscall).defaults().will(returnValue(-1)); + NResult result = shmServerDriver->Initialize(shmOptions); + EXPECT_EQ(ShCode::SH_PARAM_INVALID, result); +} + +TEST_F(TestShmDriverOob, ConnectFailWithCreateChannelFail) +{ + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + MOCKER((int(*)(int))syscall).defaults().will(returnValue(-1)); + NResult result = shmClientDriver->Connect(UDSNAME, 0, "halo", shmEp); + EXPECT_EQ(ShCode::SH_FILE_OP_FAILED, result); +} + +TEST_F(TestShmDriverOob, InitFailWithWorkGroupHasZeroWorker) +{ + shmOptions.SetWorkerGroups("0"); + NResult result = shmServerDriver->Initialize(shmOptions); + shmOptions.SetWorkerGroups("1"); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestShmDriverOob, StartFailWithStartOobServerFail) +{ + shmServerDriver->Initialize(shmOptions); + MOCKER(::socket).defaults().will(returnValue(-1)); + NResult result = shmServerDriver->Start(); + EXPECT_EQ(NNCode::NN_OOB_LISTEN_SOCKET_ERROR, result); +} + +/* CreateMemoryRegion */ +TEST_F(TestShmDriverOob, CreateMemoryRegionSuccess) +{ + shmServerDriver->Initialize(shmOptions); + UBSHcomNetMemoryRegionPtr mr; + NResult result = shmServerDriver->CreateMemoryRegion(16, mr); + EXPECT_EQ(NNCode::NN_OK, result); + shmServerDriver->DestroyMemoryRegion(mr); +} + +TEST_F(TestShmDriverOob, DestoryMemoryRegionSuccess) +{ + shmServerDriver->Initialize(shmOptions); + UBSHcomNetMemoryRegionPtr mr; + NResult result = shmServerDriver->CreateMemoryRegion(16, mr); + EXPECT_EQ(NNCode::NN_OK, result); + shmServerDriver->DestroyMemoryRegion(mr); +} + +TEST_F(TestShmDriverOob, CreateMemoryRegionFail) +{ + shmServerDriver->Initialize(shmOptions); + UBSHcomNetMemoryRegionPtr mr; + NResult result = shmServerDriver->CreateMemoryRegion(0, mr); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestShmDriverOob, CreateMemoryRegionInitializeFail) +{ + shmServerDriver->Initialize(shmOptions); + UBSHcomNetMemoryRegionPtr mr; + MOCKER_CPP(&ShmHandle::Initialize).defaults().will(returnValue(301)); + NResult result = shmServerDriver->CreateMemoryRegion(16, mr); + EXPECT_EQ(NNCode::NN_NOT_INITIALIZED, result); +} + +TEST_F(TestShmDriverOob, CreateMemoryRegionMrHandleZero) +{ + shmServerDriver->Initialize(shmOptions); + UBSHcomNetMemoryRegionPtr mr; + uintptr_t mAddress = 0; + MOCKER_CPP(&ShmHandle::ShmAddress).defaults().will(returnValue(mAddress)); + NResult result = shmServerDriver->CreateMemoryRegion(16, mr); + EXPECT_EQ(NNCode::NN_MALLOC_FAILED, result); +} + +TEST_F(TestShmDriverOob, DestoryMemoryRegionFail) +{ + shmServerDriver->Initialize(shmOptions); + UBSHcomNetMemoryRegionPtr mr; + NResult result = shmServerDriver->CreateMemoryRegion(0, mr); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); + shmServerDriver->DestroyMemoryRegion(mr); +} + +TEST_F(TestShmDriverOob, CreateMemoryRegionWithAddressFail) +{ + shmServerDriver->Initialize(shmOptions); + auto tmpBuf = memalign(NN_NO4096, 10); + UBSHcomNetMemoryRegionPtr mr; + NResult result = shmServerDriver->CreateMemoryRegion(reinterpret_cast(tmpBuf), 16, mr); + EXPECT_EQ(NNCode::NN_INVALID_OPERATION, result); + free(tmpBuf); +} + +TEST_F(TestShmDriverOob, CreateMemoryRegionWithAddressFailWithAddressIsZero) +{ + shmServerDriver->Initialize(shmOptions); + UBSHcomNetMemoryRegionPtr mr; + NResult result = shmServerDriver->CreateMemoryRegion(0, 16, mr); + EXPECT_EQ(NNCode::NN_INVALID_OPERATION, result); +} + +/* connect */ +TEST_F(TestShmDriverOob, ConnectSuccess) +{ + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + NResult result = shmClientDriver->Connect(UDSNAME, 0, "halo", shmEp); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestShmDriverOob, ConnectTcpFail) +{ + shmOptions.oobType = ock::hcom::NET_OOB_TCP; + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + NResult result = shmClientDriver->Connect("halo", shmEp); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestShmDriverOob, ConnectFailWithoutInit) +{ + NResult result = shmClientDriver->Connect(UDSNAME, 0, "halo", shmEp); + EXPECT_EQ(NNCode::NN_NOT_INITIALIZED, result); +} + +TEST_F(TestShmDriverOob, ConnectFailWithPayloadOversize) +{ + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + char payload[1030]; + for (char &i : payload) { + i = '1'; + } + payload[1029] = '\0'; + NResult result = shmClientDriver->Connect(UDSNAME, 0, payload, shmEp); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestShmDriverOob, ConnectFail) +{ + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + MOCKER(::connect).defaults().will(returnValue(-1)); + NResult result = shmClientDriver->Connect(UDSNAME, 0, "halo", shmEp); + EXPECT_EQ(NNCode::NN_OOB_CLIENT_SOCKET_ERROR, result); +} + +TEST_F(TestShmDriverOob, ConnectFailWithMagicMismatch) +{ + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmOptions.magic = 104; + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + NResult result = shmClientDriver->Connect(UDSNAME, 0, "halo", shmEp); + shmOptions.magic = NN_NO256; + EXPECT_EQ(NNCode::NN_CONNECT_REFUSED, result); +} + +TEST_F(TestShmDriverOob, ConnectCreateShmHandleFail) +{ + ShmHandle *tmpHandle = nullptr; + MOCKER_CPP(&NetRef::Get).defaults().will(returnValue(tmpHandle)); + NResult result = shmServerDriver->Initialize(shmOptions); + EXPECT_EQ(SH_NEW_OBJECT_FAILED, result); +} + +TEST_F(TestShmDriverOob, ConnectCreateEventQueueFail) +{ + ShmEventQueue *tmpQueue = nullptr; + MOCKER_CPP(&NetRef::Get).defaults().will(returnValue(tmpQueue)); + NResult result = shmServerDriver->Initialize(shmOptions); + EXPECT_EQ(SH_NEW_OBJECT_FAILED, result); +} + +/* sync connect */ +TEST_F(TestShmDriverOob, ConnectSyncSuccess) +{ + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + NResult result = shmClientDriver->Connect(UDSNAME, 0, "hello world", shmEp, NET_EP_SELF_POLLING); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestShmDriverOob, ConnectSyncFail1) +{ + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + MOCKER(::connect).defaults().will(returnValue(-1)); + NResult result = shmClientDriver->Connect(UDSNAME, 0, "hello world", shmEp, NET_EP_SELF_POLLING); + EXPECT_EQ(NNCode::NN_OOB_CLIENT_SOCKET_ERROR, result); +} + +TEST_F(TestShmDriverOob, ConnectSyncFailWithMagicMismatch) +{ + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmOptions.magic = 104; + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + NResult result = shmClientDriver->Connect(UDSNAME, 0, "halo", shmEp, NET_EP_SELF_POLLING); + EXPECT_EQ(NNCode::NN_CONNECT_REFUSED, result); +} + +TEST_F(TestShmDriverOob, ConnectSyncWithCreateFail) +{ + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + ShmSyncEndpoint *tmpEp = nullptr; + MOCKER_CPP(&NetRef::Get).defaults().will(returnValue(tmpEp)); + NResult result = shmClientDriver->Connect(UDSNAME, 0, "hello world", shmEp, NET_EP_SELF_POLLING); + EXPECT_EQ(SH_NEW_OBJECT_FAILED, result); +} + +TEST_F(TestShmDriverOob, ConnectSyncCreateShmHandleFail) +{ + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + ShmHandle *tmpHandle = nullptr; + MOCKER_CPP(&NetRef::Get).defaults().will(returnValue(tmpHandle)); + NResult result = shmClientDriver->Connect(UDSNAME, 0, "hello world", shmEp, NET_EP_SELF_POLLING); + EXPECT_EQ(SH_NEW_OBJECT_FAILED, result); +} + +TEST_F(TestShmDriverOob, ConnectSyncCreateEventQueueFail) +{ + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + ShmEventQueue *tmpQueue = nullptr; + MOCKER_CPP(&NetRef::Get).defaults().will(returnValue(tmpQueue)); + NResult result = shmClientDriver->Connect(UDSNAME, 0, "hello world", shmEp, NET_EP_SELF_POLLING); + EXPECT_EQ(SH_NEW_OBJECT_FAILED, result); +} + +TEST_F(TestShmDriverOob, SendSuccess) +{ + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + shmClientDriver->Connect(UDSNAME, 0, "halo", shmEp); + static char data[900] = {}; + UBSHcomNetTransRequest req((void *)(data), sizeof(data), 0); + req.upCtxSize = NN_NO16; + for (auto i = 0; i < 16; i++) { + req.upCtxData[i] = 'a'; + } + NResult result = shmEp->PostSend(1, req); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestShmDriverOob, SendRawSglSuccess) +{ + shmServerDriver->RegisterReqPostedHandler(std::bind(&shmOobRequestPosted, std::placeholders::_1)); + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + shmClientDriver->Connect(UDSNAME, 0, "halo", shmEp); + + std::vector mrServer; + RegisterShmMemory(shmServerDriver, iovPtrShmServer, mrServer); + std::vector mrClient; + RegisterShmMemory(shmClientDriver, iovPtrShmClient, mrClient); + UBSHcomNetTransSgeIov iov[4]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = iovPtrShmClient[i].lAddress; + iov[i].rAddress = iovPtrShmServer[i].lAddress; + iov[i].lKey = iovPtrShmClient[i].lKey; + iov[i].rKey = iovPtrShmServer[i].lKey; + iov[i].size = NN_NO4; + } + UBSHcomNetTransSglRequest req(iov, NN_NO4, 0); + req.upCtxSize = NN_NO16; + for (auto i = 0; i < 16; i++) { + req.upCtxData[i] = 'a'; + } + NResult result = shmEp->PostSendRawSgl(req, 1); + EXPECT_EQ(NNCode::NN_OK, result); + + DestoryShmMemory(shmServerDriver, mrServer); + DestoryShmMemory(shmClientDriver, mrClient); +} + +TEST_F(TestShmDriverOob, DestoryEp) +{ + shmServerDriver->Initialize(shmOptions); + shmServerDriver->Start(); + shmClientDriver->Initialize(shmOptions); + shmClientDriver->Start(); + NResult result = shmClientDriver->Connect(UDSNAME, 0, "hello world", shmEp, NET_EP_SELF_POLLING); + EXPECT_EQ(NNCode::NN_OK, result); + shmClientDriver->DestroyEndpoint(shmEp); +} + +TEST_F(TestShmDriverOob, DestoryEpFail) +{ + UBSHcomNetEndpointPtr shmEp2 = nullptr; + shmClientDriver->DestroyEndpoint(shmEp2); +} diff --git a/test/llt/testcase/transport/shm/test_shm_driver_oob.h b/test/llt/testcase/transport/shm/test_shm_driver_oob.h new file mode 100644 index 0000000000000000000000000000000000000000..4519d129d2f71f1b2f0433c284b0d19ad848ad76 --- /dev/null +++ b/test/llt/testcase/transport/shm/test_shm_driver_oob.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_TEST_SHM_DRIVER_OOB_H +#define HCOM_TEST_SHM_DRIVER_OOB_H +#include +#include +using namespace ock::hcom; +class TestShmDriverOob : public testing::Test { +public: + TestShmDriverOob(); + virtual void SetUp(void); + virtual void TearDown(void); + +protected: +}; + +#endif // HCOM_TEST_SHM_DRIVER_OOB_H \ No newline at end of file diff --git a/test/llt/testcase/transport/shm/test_shm_endpoint.cpp b/test/llt/testcase/transport/shm/test_shm_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a7f734eda4764287636d4e770cd7e5dff1a5cc68 --- /dev/null +++ b/test/llt/testcase/transport/shm/test_shm_endpoint.cpp @@ -0,0 +1,701 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "hcom.h" +#include "net_shm_sync_endpoint.h" +#include "net_shm_async_endpoint.h" +#include "shm_worker.h" +#include "shm_composed_endpoint.h" +#include "test_shm_common.h" +#include "test_shm_endpoint.h" + +using namespace ock::hcom; +TestShmEndpoint::TestShmEndpoint() {} + +static uint32_t iovCnt = NN_NO4; +static UBSHcomNetEndpointPtr asyncEp = nullptr; +static sem_t sem; +static int g_nameSeed = 0; + +static int ServerNewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload); + asyncEp = newEP; + return 0; +} + +static void EndPointBroken(const UBSHcomNetEndpointPtr &ep) +{ + if (asyncEp != nullptr) { + asyncEp.Set(nullptr); + } + NN_LOG_INFO("end point " << ep->Id()); +} + +static int RequestReceived(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("client request received - " << ctx.Header().opCode << ", dataLen " << ctx.Header().dataLength); + return 0; +} + +static int RequestPosted(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request posted"); + return 0; +} + +static int OneSideDone(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("one side done"); + sem_post(&sem); + return 0; +} + + +static bool RegSglMem(UBSHcomNetDriver *driver, UBSHcomNetTransSgeIov mrInfo[], + std::vector &mrs) +{ + for (int i = 0; i < 4; ++i) { + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(NN_NO16, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + mrInfo[i].lAddress = mr->GetAddress(); + mrInfo[i].lKey = mr->GetLKey(); + mrInfo[i].size = NN_NO8; + mrs.push_back(mr); + memset(reinterpret_cast(mrInfo[i].lAddress), 1, mrInfo[i].size); + } + return true; +} + +static void DestorySglMem(UBSHcomNetDriver *driver, std::vector &mrs) +{ + while (!mrs.empty()) { + driver->DestroyMemoryRegion(mrs.back()); + mrs.pop_back(); + } +} + +static bool RegReadWriteMem(UBSHcomNetDriver *driver, TestRegMrInfo mrInfo[], + std::vector &mrReadWrite) +{ + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(NN_NO16, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + mrInfo[0].lAddress = mr->GetAddress(); + mrInfo[0].lKey = mr->GetLKey(); + mrInfo[0].size = NN_NO8; + mrReadWrite.push_back(mr); + memset(reinterpret_cast(mrInfo[0].lAddress), 0, mrInfo[0].size); + + return true; +} + +static bool RegReadWriteSglMem(UBSHcomNetDriver *driver, TestRegMrInfo mrInfo[], + std::vector &mrReadWrite) +{ + for (int i = 0; i < 4; ++i) { + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(NN_NO16, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + mrInfo[i].lAddress = mr->GetAddress(); + mrInfo[i].lKey = mr->GetLKey(); + mrInfo[i].size = NN_NO8; + mrReadWrite.push_back(mr); + memset(reinterpret_cast(mrInfo[i].lAddress), '1', mrInfo[i].size); + } + return true; +} + +/* server new request sgl callback */ +TestRegMrInfo asyncServerMrInfo[NN_NO4]; +static int RequestReceivedServer(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("server request received - " << ctx.Header().opCode << ", dataLen " << ctx.Header().dataLength); + + int result = 0; + UBSHcomNetTransRequest rsp((void *)(asyncServerMrInfo), sizeof(asyncServerMrInfo), 0); + if ((result = asyncEp->PostSend(1, rsp)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + + NN_LOG_INFO("request rsp Mr info"); + std::string readValue((char *)asyncServerMrInfo[0].lAddress, asyncServerMrInfo[0].size); + NN_LOG_INFO("idx:" << 0 << " key:" << asyncServerMrInfo[0].lKey << " address:" << asyncServerMrInfo[0].lAddress << + " size: " << asyncServerMrInfo[0].size << "string: " << readValue); + return 0; +} +static int RequestReceivedSglServer(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("server request received - " << ctx.Header().opCode << ", dataLen " << ctx.Header().dataLength); + + int result = 0; + UBSHcomNetTransRequest rsp((void *)(asyncServerMrInfo), sizeof(asyncServerMrInfo), 0); + if ((result = asyncEp->PostSend(1, rsp)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + + NN_LOG_INFO("request rsp Mr info"); + for (uint16_t i = 0; i < NN_NO4; i++) { + NN_LOG_INFO("idx:" << i << " key:" << asyncServerMrInfo[i].lKey << " address:" << + asyncServerMrInfo[i].lAddress << " size: " << asyncServerMrInfo[i].size); + } + return 0; +} + +/* client new request sgl callback */ + +TestRegMrInfo getRemoteMrInfo[NN_NO4]; +static int RequestReceivedClient(const UBSHcomNetRequestContext &ctx) +{ + memcpy(getRemoteMrInfo, ctx.Message()->Data(), ctx.Message()->DataLen()); + NN_LOG_INFO("get remote Mr info"); + std::string readValue((char *)getRemoteMrInfo[0].lAddress, getRemoteMrInfo[0].size); + NN_LOG_INFO("idx:" << 0 << " key:" << getRemoteMrInfo[0].lKey << " address:" << getRemoteMrInfo[0].lAddress << + " size:" << getRemoteMrInfo[0].size << "string: " << readValue); + + sem_post(&sem); + return 0; +} + +static int RequestReceivedSglClient(const UBSHcomNetRequestContext &ctx) +{ + memcpy(getRemoteMrInfo, ctx.Message()->Data(), ctx.Message()->DataLen()); + NN_LOG_INFO("get remote Mr info"); + sem_post(&sem); + return 0; +} + + +static bool CreateServerDriver(UBSHcomNetDriver *&driver, int (*reqHandler)(const UBSHcomNetRequestContext &), + UBSHcomNetDriverOptions &asyncShmOptions) +{ + auto name = "server_ep_" + std::to_string(g_nameSeed++); + + driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::SHM, name, true); + + if (driver == nullptr) { + NN_LOG_ERROR("failed to create asyncServerDriver already created"); + return false; + } + asyncShmOptions.oobType = ock::hcom::NET_OOB_UDS; + asyncShmOptions.mode = ock::hcom::NET_EVENT_POLLING; + asyncShmOptions.enableTls = false; + + UBSHcomNetOobUDSListenerOptions listenOpt; + listenOpt.Name(UDSNAME); + listenOpt.perm = 0; + driver->AddOobUdsOptions(listenOpt); + + driver->RegisterNewEPHandler( + std::bind(&ServerNewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(reqHandler); + driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + + int result = 0; + if ((result = driver->Initialize(asyncShmOptions)) != 0) { + NN_LOG_ERROR("failed to initialize driver " << result); + return false; + } + NN_LOG_INFO("asyncServerDriver initialized"); + + if ((result = driver->Start()) != 0) { + NN_LOG_ERROR("failed to start asyncServerDriver " << result); + return false; + } + NN_LOG_INFO("asyncServerDriver started"); + return true; +} + +static bool CreateClientDriver(UBSHcomNetDriver *&driver, int (*reqHandler)(const UBSHcomNetRequestContext &), + UBSHcomNetDriverOptions &asyncShmOptions) +{ + auto name = "client_ep_" + std::to_string(g_nameSeed++); + + driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::SHM, name, false); + if (driver == nullptr) { + NN_LOG_ERROR("failed to create asyncClientDriver already created"); + return false; + } + asyncShmOptions.oobType = ock::hcom::NET_OOB_UDS; + asyncShmOptions.mode = ock::hcom::NET_EVENT_POLLING; + asyncShmOptions.enableTls = false; + + driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(reqHandler); + driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + + int result = 0; + if ((result = driver->Initialize(asyncShmOptions)) != 0) { + NN_LOG_ERROR("failed to initialize driver " << result); + return false; + } + NN_LOG_INFO("asyncClientDriver initialized"); + + if ((result = driver->Start()) != 0) { + NN_LOG_ERROR("failed to start asyncClientDriver " << result); + return false; + } + NN_LOG_INFO("asyncClientDriver started"); + return true; +} + + +void CloseShmDriver(UBSHcomNetDriver *&asyncClientDriver, UBSHcomNetDriver *&asyncServerDriver) +{ + asyncEp->Close(); + if (asyncEp != nullptr) { + asyncEp.Set(nullptr); + } + std::string serverName = asyncServerDriver->Name(); + std::string clientName = asyncClientDriver->Name(); + if (asyncServerDriver->IsStarted()) { + asyncServerDriver->Stop(); + } + if (asyncServerDriver->IsInited()) { + asyncServerDriver->UnInitialize(); + } + + if (asyncClientDriver->IsStarted()) { + asyncClientDriver->Stop(); + } + if (asyncClientDriver->IsInited()) { + asyncClientDriver->UnInitialize(); + } + UBSHcomNetDriver::DestroyInstance(serverName); + UBSHcomNetDriver::DestroyInstance(clientName); +} + +void TestShmEndpoint::SetUp() +{ + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestShmEndpoint::TearDown() +{ + GlobalMockObject::verify(); +} + + +TEST_F(TestShmEndpoint, PostSendRetry) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriverOptions asyncShmOptions {}; + UBSHcomNetDriver *asyncClientDriver = nullptr; + UBSHcomNetDriver *asyncServerDriver = nullptr; + CreateServerDriver(asyncServerDriver, RequestReceived, asyncShmOptions); + CreateClientDriver(asyncClientDriver, RequestReceived, asyncShmOptions); + asyncClientDriver->Connect(UDSNAME, 0, "hello server", ep); + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + + result = ep->PostSend(1, req); + EXPECT_EQ(SH_OK, result); + + UBSHcomNetTransOpInfo innerOpInfo(2, 0, 0, NTH_TWO_SIDE); + result = ep->PostSend(1, req, innerOpInfo); + EXPECT_EQ(SH_OK, result); + + result = ep->PostSendRaw(req, 1); + EXPECT_EQ(SH_OK, result); + + ep->DefaultTimeout(1); + MOCKER_CPP(&ShmQueue::EnqueueAndNotify, int32_t(ShmQueue::*)(ShmEvent &)) + .stubs() + .will(returnValue(0)) + .then(returnValue(-1)); + result = ep->PostSend(1, req); + EXPECT_EQ(SH_SEND_COMPLETION_CALLBACK_FAILURE, result); + + GlobalMockObject::verify(); + + MOCKER_CPP(&ShmQueue::EnqueueAndNotify, int32_t(ShmQueue::*)(ShmEvent &)) + .stubs() + .will(returnValue(0)) + .then(returnValue(-1)); + + result = ep->PostSend(1, req, innerOpInfo); + EXPECT_EQ(SH_SEND_COMPLETION_CALLBACK_FAILURE, result); + + GlobalMockObject::verify(); + + MOCKER_CPP(&ShmQueue::EnqueueAndNotify, int32_t(ShmQueue::*)(ShmEvent &)) + .stubs() + .will(returnValue(0)) + .then(returnValue(-1)); + + result = ep->PostSendRaw(req, 1); + EXPECT_EQ(SH_SEND_COMPLETION_CALLBACK_FAILURE, result); + + GlobalMockObject::verify(); + + MOCKER_CPP(&ShmDataChannel::TryOccupyWithWait).defaults().will(returnValue(305)); + result = ep->PostSend(1, req); + EXPECT_EQ(SH_NOT_INITIALIZED, result); + + result = ep->PostSend(1, req, innerOpInfo); + EXPECT_EQ(SH_NOT_INITIALIZED, result); + + result = ep->PostSendRaw(req, 1); + EXPECT_EQ(SH_NOT_INITIALIZED, result); + + GlobalMockObject::verify(); + + MOCKER_CPP(&ShmQueue::EnqueueAndNotify, int32_t(ShmQueue::*)(ShmEvent &)) + .defaults() + .will(returnValue(-1)); + result = ep->PostSend(1, req); + EXPECT_EQ(SH_RETRY_FULL, result); + + result = ep->PostSend(1, req, innerOpInfo); + EXPECT_EQ(SH_RETRY_FULL, result); + + result = ep->PostSendRaw(req, 1); + EXPECT_EQ(SH_RETRY_FULL, result); + + GlobalMockObject::verify(); + + ep->Close(); + CloseShmDriver(asyncClientDriver, asyncServerDriver); +} + +TEST_F(TestShmEndpoint, PostSendRawSglRetry) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriverOptions asyncShmOptions {}; + UBSHcomNetDriver *asyncClientDriver = nullptr; + UBSHcomNetDriver *asyncServerDriver = nullptr; + CreateServerDriver(asyncServerDriver, RequestReceived, asyncShmOptions); + CreateClientDriver(asyncClientDriver, RequestReceived, asyncShmOptions); + asyncClientDriver->Connect(UDSNAME, 0, "hello server", ep); + + std::vector mrs; + UBSHcomNetTransSgeIov clientMrInfo[NN_NO4]; + bool res = RegSglMem(asyncClientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + UBSHcomNetTransSglRequest req(clientMrInfo, iovCnt, 0); + + result = ep->PostSendRawSgl(req, 1); + EXPECT_EQ(SH_OK, result); + + UBSHcomNetTransSglRequest req2(clientMrInfo, iovCnt, NN_NO29); + result = ep->PostSendRawSgl(req2, 1); + EXPECT_EQ(SH_PARAM_INVALID, result); + + UBSHcomNetTransSglRequest reqSgl(clientMrInfo, iovCnt, 0); + result = ep->PostSendRawSgl(reqSgl, 0); + EXPECT_EQ(NN_INVALID_PARAM, result); + + ep->DefaultTimeout(1); + MOCKER_CPP(&UBSHcomNetAtomicState::Compare).defaults().will(returnValue(true)); + result = ep->PostSendRawSgl(reqSgl, 1); + EXPECT_EQ(SH_CH_BROKEN, result); + + GlobalMockObject::verify(); + + MOCKER_CPP(&ShmQueue::EnqueueAndNotify, int32_t(ShmQueue::*)(ShmEvent &)) + .defaults() + .will(returnValue(-1)); + result = ep->PostSendRawSgl(reqSgl, 1); + EXPECT_EQ(SH_RETRY_FULL, result); + + GlobalMockObject::verify(); + + MOCKER_CPP(&ShmQueue::EnqueueAndNotify, int32_t(ShmQueue::*)(ShmEvent &)) + .stubs() + .will(returnValue(0)) + .then(returnValue(-1)); + + MOCKER_CPP(&ShmChannel::RemoveOpCompInfo, HResult(ShmChannel::*)(ShmOpCompInfo *)) + .defaults() + .will(returnValue(317)); + result = ep->PostSendRawSgl(reqSgl, 1); + EXPECT_EQ(SH_SEND_COMPLETION_CALLBACK_FAILURE, result); + + GlobalMockObject::verify(); + + ShmSglOpContextInfo *infoSgl = nullptr; + MOCKER_CPP(&OpContextInfoPool::Get).defaults().will(returnValue(infoSgl)); + result = ep->PostSendRawSgl(reqSgl, 1); + EXPECT_EQ(SH_PARAM_INVALID, result); + + GlobalMockObject::verify(); + + ShmOpCompInfo *info = nullptr; + MOCKER_CPP(&OpContextInfoPool::Get).defaults().will(returnValue(info)); + result = ep->PostSendRawSgl(reqSgl, 1); + EXPECT_EQ(SH_OP_CTX_FULL, result); + + GlobalMockObject::verify(); + + MOCKER_CPP(&UBSHcomNetAtomicState::Compare).defaults().will(returnValue(false)); + result = ep->PostSendRawSgl(reqSgl, 1); + EXPECT_EQ(NN_EP_NOT_ESTABLISHED, result); + + GlobalMockObject::verify(); + + MOCKER_CPP(&MemoryRegionChecker::Validate).defaults().will(returnValue(100)); + result = ep->PostSendRawSgl(reqSgl, 1); + EXPECT_EQ(NN_INVALID_LKEY, result); + + GlobalMockObject::verify(); + + MOCKER_CPP(&ShmDataChannel::TryOccupyWithWait).defaults().will(returnValue(305)); + result = ep->PostSendRawSgl(reqSgl, 1); + EXPECT_EQ(SH_NOT_INITIALIZED, result); + + GlobalMockObject::verify(); + + DestorySglMem(asyncClientDriver, mrs); + ep->Close(); + + CloseShmDriver(asyncClientDriver, asyncServerDriver); +} + +TEST_F(TestShmEndpoint, PostReadWrite) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriverOptions asyncShmOptions {}; + UBSHcomNetDriver *asyncClientDriver = nullptr; + UBSHcomNetDriver *asyncServerDriver = nullptr; + CreateServerDriver(asyncServerDriver, RequestReceivedServer, asyncShmOptions); + CreateClientDriver(asyncClientDriver, RequestReceivedClient, asyncShmOptions); + asyncClientDriver->Connect(UDSNAME, 0, "hello server", ep); + ep->DefaultTimeout(1); + sem_init(&sem, 0, 0); + + bool res; + std::vector mrServer; + res = RegReadWriteMem(asyncServerDriver, asyncServerMrInfo, mrServer); + EXPECT_TRUE(res); + TestRegMrInfo asyncClientMrInfo[NN_NO4]; + std::vector mrClient; + res = RegReadWriteMem(asyncClientDriver, asyncClientMrInfo, mrClient); + EXPECT_TRUE(res); + + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + + sem_wait(&sem); + + UBSHcomNetTransRequest req; + req.lAddress = asyncClientMrInfo[0].lAddress; + req.rAddress = getRemoteMrInfo[0].lAddress; + req.lKey = asyncClientMrInfo[0].lKey; + req.rKey = getRemoteMrInfo[0].lKey; + req.size = getRemoteMrInfo[0].size; + + result = ep->PostRead(req); + EXPECT_EQ(SH_OK, result); + sem_wait(&sem); + + std::string readValue((char *)asyncClientMrInfo->lAddress, asyncClientMrInfo->size); + NN_LOG_INFO("value[" << 0 << "]= " << readValue); + + result = ep->PostWrite(req); + EXPECT_EQ(SH_OK, result); + sem_wait(&sem); + + ShmHandlePtr localMrHandle = nullptr; + MOCKER_CPP(&ShmMRHandleMap::GetFromLocalMap).defaults().will(returnValue(localMrHandle)); + result = ep->PostWrite(req); + EXPECT_EQ(SH_ERROR, result); + + GlobalMockObject::verify(); + + MOCKER_CPP(&ShmQueue::EnqueueAndNotify, int32_t(ShmQueue::*)(ShmEvent &)) + .defaults() + .will(returnValue(-1)); + result = ep->PostWrite(req); + EXPECT_EQ(SH_SEND_COMPLETION_CALLBACK_FAILURE, result); + + GlobalMockObject::verify(); + + MOCKER_CPP(&UBSHcomNetAtomicState::Compare).defaults().will(returnValue(true)); + result = ep->PostRead(req); + EXPECT_EQ(SH_CH_BROKEN, result); + + GlobalMockObject::verify(); + + ShmOpContextInfo *info = nullptr; + MOCKER_CPP(&OpContextInfoPool::Get).defaults().will(returnValue(info)); + result = ep->PostRead(req); + EXPECT_EQ(SH_OP_CTX_FULL, result); + + result = ep->PostWrite(req); + EXPECT_EQ(SH_OP_CTX_FULL, result); + + GlobalMockObject::verify(); + + MOCKER_CPP(&ShmQueue::EnqueueAndNotify, int32_t(ShmQueue::*)(ShmEvent &)) + .stubs() + .will(returnValue(-1)); + MOCKER_CPP(&ShmChannel::RemoveOpCtxInfo).defaults().will(returnValue(317)); + result = ep->PostRead(req); + EXPECT_EQ(SH_SEND_COMPLETION_CALLBACK_FAILURE, result); + + GlobalMockObject::verify(); + + DestorySglMem(asyncServerDriver, mrServer); + DestorySglMem(asyncClientDriver, mrClient); + ep->Close(); + CloseShmDriver(asyncClientDriver, asyncServerDriver); +} + +TEST_F(TestShmEndpoint, PostReadWriteSgl) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriverOptions asyncShmOptions {}; + UBSHcomNetDriver *asyncClientDriver = nullptr; + UBSHcomNetDriver *asyncServerDriver = nullptr; + CreateServerDriver(asyncServerDriver, RequestReceivedSglServer, asyncShmOptions); + CreateClientDriver(asyncClientDriver, RequestReceivedSglClient, asyncShmOptions); + asyncClientDriver->Connect(UDSNAME, 0, "hello server", ep); + ep->DefaultTimeout(1); + sem_init(&sem, 0, 0); + + bool res; + std::vector mrServer; + res = RegReadWriteSglMem(asyncServerDriver, asyncServerMrInfo, mrServer); + EXPECT_TRUE(res); + TestRegMrInfo asyncClientMrInfo[NN_NO4]; + std::vector mrClient; + res = RegReadWriteSglMem(asyncClientDriver, asyncClientMrInfo, mrClient); + EXPECT_TRUE(res); + + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + + sem_wait(&sem); + + UBSHcomNetTransSgeIov iov[NN_NO4]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = asyncClientMrInfo[i].lAddress; + iov[i].rAddress = getRemoteMrInfo[i].lAddress; + iov[i].lKey = asyncClientMrInfo[i].lKey; + iov[i].rKey = getRemoteMrInfo[i].lKey; + iov[i].size = getRemoteMrInfo[i].size; + } + + UBSHcomNetTransSglRequest req(iov, NN_NO4, 0); + result = ep->PostRead(req); + EXPECT_EQ(SH_OK, result); + sem_wait(&sem); + + for (uint16_t i = 0; i < NN_NO4; i++) { + std::string readValue((char *)asyncClientMrInfo[i].lAddress, asyncClientMrInfo[i].size); + NN_LOG_INFO("value[" << i << "]= " << readValue); + } + result = ep->PostWrite(req); + EXPECT_EQ(SH_OK, result); + sem_wait(&sem); + + ShmOpContextInfo *info = nullptr; + MOCKER_CPP(&OpContextInfoPool::Get).defaults().will(returnValue(info)); + result = ep->PostRead(req); + EXPECT_EQ(SH_OP_CTX_FULL, result); + + result = ep->PostWrite(req); + EXPECT_EQ(SH_OP_CTX_FULL, result); + + GlobalMockObject::verify(); + + ShmSglOpContextInfo *infoSgl = nullptr; + MOCKER_CPP(&OpContextInfoPool::Get).defaults().will(returnValue(infoSgl)); + result = ep->PostRead(req); + EXPECT_EQ(SH_PARAM_INVALID, result); + + GlobalMockObject::verify(); + + ShmHandlePtr localMrHandle = nullptr; + MOCKER_CPP(&ShmMRHandleMap::GetFromLocalMap).defaults().will(returnValue(localMrHandle)); + MOCKER_CPP(&ShmChannel::RemoveOpCtxInfo).defaults().will(returnValue(317)); + result = ep->PostRead(req); + EXPECT_EQ(SH_ERROR, result); + + GlobalMockObject::verify(); + + MOCKER_CPP(&ShmQueue::EnqueueAndNotify, int32_t(ShmQueue::*)(ShmEvent &)) + .defaults() + .will(returnValue(-1)); + MOCKER_CPP(&ShmChannel::RemoveOpCtxInfo).defaults().will(returnValue(317)); + result = ep->PostRead(req); + EXPECT_EQ(SH_SEND_COMPLETION_CALLBACK_FAILURE, result); + + GlobalMockObject::verify(); + + DestorySglMem(asyncServerDriver, mrServer); + DestorySglMem(asyncClientDriver, mrClient); + + CloseShmDriver(asyncClientDriver, asyncServerDriver); +} + +TEST_F(TestShmEndpoint, GetRemoteUdsIdInfo) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriverOptions asyncShmOptions {}; + UBSHcomNetDriver *asyncClientDriver = nullptr; + UBSHcomNetDriver *asyncServerDriver = nullptr; + CreateServerDriver(asyncServerDriver, RequestReceived, asyncShmOptions); + CreateClientDriver(asyncClientDriver, RequestReceived, asyncShmOptions); + asyncClientDriver->Connect(UDSNAME, 0, "hello server", ep); + + UBSHcomEpOptions epOptions {}; + result = ep->SetEpOption(epOptions); + EXPECT_EQ(NN_OK, result); + + UBSHcomNetUdsIdInfo idInfo {}; + if (asyncEp != nullptr) { + result = asyncEp->GetRemoteUdsIdInfo(idInfo); + EXPECT_EQ(NN_OK, result); + NN_LOG_INFO("=======new endpoint remote uds ids, pid: " << idInfo.pid << " uid: " << idInfo.uid << " gid: " << + idInfo.gid << " result:" << result); + } + + GlobalMockObject::verify(); + + MOCKER_CPP(&UBSHcomNetAtomicState::Compare).defaults().will(returnValue(false)); + result = asyncEp->GetRemoteUdsIdInfo(idInfo); + EXPECT_EQ(NN_EP_NOT_ESTABLISHED, result); + + GlobalMockObject::verify(); + + result = ep->GetRemoteUdsIdInfo(idInfo); + EXPECT_EQ(NN_UDS_ID_INFO_NOT_SUPPORT, result); + + ep->Close(); + CloseShmDriver(asyncClientDriver, asyncServerDriver); +} \ No newline at end of file diff --git a/test/llt/testcase/transport/shm/test_shm_endpoint.h b/test/llt/testcase/transport/shm/test_shm_endpoint.h new file mode 100644 index 0000000000000000000000000000000000000000..bc4affda6a5fc5b90444b5906d9e921d19a8b8cd --- /dev/null +++ b/test/llt/testcase/transport/shm/test_shm_endpoint.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_SHM_ENDPOINT_H +#define HCOM_SHM_ENDPOINT_H +#include +#include + +class TestShmEndpoint : public testing::Test { +public: + TestShmEndpoint(); + virtual void SetUp(void); + virtual void TearDown(void); +}; +#endif // HCOM_SHM_ENDPOINT_H diff --git a/test/llt/testcase/transport/shm/test_shm_send_recv_msg.cpp b/test/llt/testcase/transport/shm/test_shm_send_recv_msg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8525e5951264e050d36bc0ebbf8989dae309b40 --- /dev/null +++ b/test/llt/testcase/transport/shm/test_shm_send_recv_msg.cpp @@ -0,0 +1,333 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include +#include +#include + +#include "hcom.h" +#include "net_mem_pool_fixed.h" +#include "openssl_api_wrapper.h" +#include "shm_common.h" +#include "shm_channel.h" +#include "shm_handle.h" +#include "test_shm_common.h" +#include "test_shm_send_recv_msg.h" + +using namespace ock::hcom; +TestShmSendRecvMsg::TestShmSendRecvMsg() {} + +UBSHcomNetEndpointPtr clientEp = nullptr; +UBSHcomNetEndpointPtr serverEp = nullptr; +UBSHcomNetDriverOptions shmMsgOptions {}; +static int port = 8091; +UBSHcomNetDriver *shmMsgServerDriver; +UBSHcomNetDriver *shmMsgClientDriver; +UBSHcomNetTransSgeIov iovPtrMsgServer[4]; +UBSHcomNetTransSgeIov iovPtrMsgClient[4]; +static int g_nameSeed = 0; +uint32_t fdsLen = 3; + + +int ShmMsgNewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload); + serverEp = newEP; + return 0; +} + +void ShmMsgEndPointBroken(const UBSHcomNetEndpointPtr &brokenEp) +{ + NN_LOG_INFO("end point " << brokenEp->Id()); +} + +int ShmMsgRequestReceived(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request received - " << ctx.Header().opCode << ", dataLen " << ctx.Header().dataLength); + return 0; +} + +int ShmMsgRequestPosted(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request posted"); + return 0; +} + + +int ShmMsgOneSideDone(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("one side done"); + return 0; +} + + +void SetShmCallBack(UBSHcomNetDriver *driver) +{ + driver->RegisterNewEPHandler( + std::bind(&ShmMsgNewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + driver->RegisterEPBrokenHandler(std::bind(&ShmMsgEndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(std::bind(&ShmMsgRequestReceived, std::placeholders::_1)); + driver->RegisterReqPostedHandler(std::bind(&ShmMsgRequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&ShmMsgOneSideDone, std::placeholders::_1)); +} + +bool MsgRegisterShmMemory(UBSHcomNetDriver *driver, UBSHcomNetTransSgeIov iovs[], + std::vector &mrs) +{ + for (int i = 0; i < 4; i++) { + auto &iov = iovs[i]; + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(NN_NO8, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + iov.lAddress = mr->GetAddress(); + iov.lKey = mr->GetLKey(); + iov.size = NN_NO8; + mrs.push_back(mr); + memset(reinterpret_cast(iov.lAddress), 0, iov.size); + } + return true; +} + +static void DestorySglMem(UBSHcomNetDriver *driver, std::vector &mrs) +{ + while (!mrs.empty()) { + driver->DestroyMemoryRegion(mrs.back()); + mrs.pop_back(); + } +} + +void CreateFds(int shmFds[]) +{ + for (uint32_t i = 0; i < fdsLen; i++) { + std::string name = "example_shm_fd_" + std::to_string(i); + auto tmpFd = shm_open(name.c_str(), O_CREAT | O_RDWR, 0755); + if (tmpFd < 0) { + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to create shm file error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE));); + } + shmFds[i] = tmpFd; + } +} + +void DestoryFds(int outFds[]) +{ + NN_LOG_INFO("Destory fds len:" << fdsLen << " fds[0]:" << outFds[0] << " fds[1]:" << outFds[1] << " fds[2]:" << + outFds[2]); + + for (uint32_t i = 0; i < fdsLen; i++) { + auto mappedAddress = mmap(nullptr, 10, PROT_READ | PROT_WRITE, MAP_SHARED, outFds[i], 0); + if (mappedAddress == MAP_FAILED) { + close(outFds[i]); + char buf[NET_STR_ERROR_BUF_SIZE] = {0}; + NN_LOG_ERROR("Failed to mmap file error " + << NetFunc::NN_GetStrError(errno, buf, NET_STR_ERROR_BUF_SIZE));); + } + NN_LOG_INFO("shm map fds:" << outFds[i]); + + close(outFds[i]); + outFds[i] = -1; + } +} + +void TestShmSendRecvMsg::SetUp() +{ + shmMsgOptions.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + shmMsgOptions.SetNetDeviceIpMask(IP_SEG); + shmMsgOptions.pollingBatchSize = 16; + shmMsgOptions.SetWorkerGroups("1"); + shmMsgOptions.SetWorkerGroupsCpuSet("1-1"); + shmMsgOptions.dontStartWorkers = false; + shmMsgOptions.oobType = ock::hcom::NET_OOB_UDS; + shmMsgOptions.enableTls = false; + + shmMsgServerDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::SHM, + "shm_msg" + std::to_string(g_nameSeed++), true); + UBSHcomNetOobUDSListenerOptions listenOpt; + listenOpt.Name(UDSNAME); + listenOpt.perm = 0; + shmMsgServerDriver->AddOobUdsOptions(listenOpt); + + shmMsgClientDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::SHM, + "shm_msg" + std::to_string(g_nameSeed++), false); + shmMsgServerDriver->OobIpAndPort(BASE_IP, port); + shmMsgClientDriver->OobIpAndPort(BASE_IP, port++); + SetShmCallBack(shmMsgServerDriver); + SetShmCallBack(shmMsgClientDriver); + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestShmSendRecvMsg::TearDown() +{ + clientEp->Close(); + serverEp->Close(); + if (serverEp != nullptr) { + serverEp.Set(nullptr); + } + if (clientEp != nullptr) { + clientEp.Set(nullptr); + } + std::string serverName = shmMsgServerDriver->Name(); + std::string clientName = shmMsgClientDriver->Name(); + if (shmMsgClientDriver->IsStarted()) { + shmMsgClientDriver->Stop(); + } + if (shmMsgClientDriver->IsInited()) { + shmMsgClientDriver->UnInitialize(); + } + + if (shmMsgServerDriver->IsStarted()) { + shmMsgServerDriver->Stop(); + } + if (shmMsgServerDriver->IsInited()) { + shmMsgServerDriver->UnInitialize(); + } + UBSHcomNetDriver::DestroyInstance(serverName); + UBSHcomNetDriver::DestroyInstance(clientName); + GlobalMockObject::verify(); +} + +TEST_F(TestShmSendRecvMsg, SendMsgRecvMsgSuccess) +{ + shmMsgServerDriver->Initialize(shmMsgOptions); + shmMsgServerDriver->Start(); + shmMsgClientDriver->Initialize(shmMsgOptions); + shmMsgClientDriver->Start(); + NResult result = shmMsgClientDriver->Connect(UDSNAME, 0, "halo", clientEp); + int shmFds[3]; + int outFds[3]; + CreateFds(shmFds); + result = clientEp->SendFds(shmFds, fdsLen); + EXPECT_EQ(NNCode::NN_OK, result); + result = serverEp->ReceiveFds(outFds, fdsLen, 1); + EXPECT_EQ(NNCode::NN_OK, result); + DestoryFds(outFds); +} + +TEST_F(TestShmSendRecvMsg, SendMsgFail1) +{ + shmMsgServerDriver->Initialize(shmMsgOptions); + shmMsgServerDriver->Start(); + shmMsgClientDriver->Initialize(shmMsgOptions); + shmMsgClientDriver->Start(); + NResult result = shmMsgClientDriver->Connect(UDSNAME, 0, "halo", clientEp); + int shmFds[3]; + int outFds[3]; + CreateFds(shmFds); + result = clientEp->SendFds(shmFds, NN_NO5); + EXPECT_EQ(NN_PARAM_INVALID, result); + + DestoryFds(shmFds); +} + +TEST_F(TestShmSendRecvMsg, SendMsgFail2) +{ + shmMsgServerDriver->Initialize(shmMsgOptions); + shmMsgServerDriver->Start(); + shmMsgClientDriver->Initialize(shmMsgOptions); + shmMsgClientDriver->Start(); + NResult result = shmMsgClientDriver->Connect(UDSNAME, 0, "halo", clientEp); + int shmFds[3]; + + result = clientEp->SendFds(shmFds, fdsLen); + EXPECT_EQ(NN_INVALID_PARAM, result); +} + +TEST_F(TestShmSendRecvMsg, SendMsgFail3) +{ + shmMsgServerDriver->Initialize(shmMsgOptions); + shmMsgServerDriver->Start(); + shmMsgClientDriver->Initialize(shmMsgOptions); + shmMsgClientDriver->Start(); + NResult result = shmMsgClientDriver->Connect(UDSNAME, 0, "halo", clientEp); + int shmFds[3]; + CreateFds(shmFds); + MOCKER_CPP(&UBSHcomNetAtomicState::Compare).defaults().will(returnValue(false)); + result = clientEp->SendFds(shmFds, fdsLen); + EXPECT_EQ(NN_EP_NOT_ESTABLISHED, result); + + DestoryFds(shmFds); +} + +TEST_F(TestShmSendRecvMsg, RecvMsgFail1) +{ + shmMsgServerDriver->Initialize(shmMsgOptions); + shmMsgServerDriver->Start(); + shmMsgClientDriver->Initialize(shmMsgOptions); + shmMsgClientDriver->Start(); + NResult result = shmMsgClientDriver->Connect(UDSNAME, 0, "halo", clientEp); + int shmFds[3]; + int outFds[3]; + CreateFds(shmFds); + result = clientEp->SendFds(shmFds, fdsLen); + EXPECT_EQ(NNCode::NN_OK, result); + result = serverEp->ReceiveFds(outFds, NN_NO5, 1); + EXPECT_EQ(NN_PARAM_INVALID, result); + + DestoryFds(shmFds); +} + +TEST_F(TestShmSendRecvMsg, RecvMsgFail2) +{ + shmMsgServerDriver->Initialize(shmMsgOptions); + shmMsgServerDriver->Start(); + shmMsgClientDriver->Initialize(shmMsgOptions); + shmMsgClientDriver->Start(); + NResult result = shmMsgClientDriver->Connect(UDSNAME, 0, "halo", clientEp); + int shmFds[3]; + int outFds[3]; + CreateFds(shmFds); + result = clientEp->SendFds(shmFds, fdsLen); + EXPECT_EQ(NNCode::NN_OK, result); + MOCKER_CPP(&UBSHcomNetAtomicState::Compare).defaults().will(returnValue(false)); + result = serverEp->ReceiveFds(outFds, fdsLen, 1); + EXPECT_EQ(NN_EP_NOT_ESTABLISHED, result); + DestoryFds(shmFds); +} + +TEST_F(TestShmSendRecvMsg, ReadWriteSglSuccess) +{ + shmMsgClientDriver->RegisterOneSideDoneHandler(std::bind(&ShmMsgOneSideDone, std::placeholders::_1)); + shmMsgServerDriver->RegisterOneSideDoneHandler(std::bind(&ShmMsgOneSideDone, std::placeholders::_1)); + shmMsgServerDriver->Initialize(shmMsgOptions); + shmMsgServerDriver->Start(); + shmMsgClientDriver->Initialize(shmMsgOptions); + shmMsgClientDriver->Start(); + shmMsgClientDriver->Connect(UDSNAME, 0, "halo", clientEp); + std::vector mrServer; + MsgRegisterShmMemory(shmMsgServerDriver, iovPtrMsgServer, mrServer); + std::vector mrClient; + MsgRegisterShmMemory(shmMsgClientDriver, iovPtrMsgClient, mrClient); + UBSHcomNetTransSgeIov iov[NN_NO4]; + for (uint16_t i = 0; i < 4; i++) { + iov[i].lAddress = iovPtrMsgClient[i].lAddress; + iov[i].rAddress = iovPtrMsgServer[i].lAddress; + iov[i].lKey = iovPtrMsgClient[i].lKey; + iov[i].rKey = iovPtrMsgServer[i].lKey; + iov[i].size = NN_NO4; + } + UBSHcomNetTransSglRequest req(iov, NN_NO4, 0); + req.upCtxSize = NN_NO16; + for (auto i = 0; i < 16; i++) { + req.upCtxData[i] = 'a'; + } + MOCKER_CPP(&ShmChannel::AddMrFd, HResult(ShmChannel::*)(int)).defaults().will(returnValue(315)); + HResult result = clientEp->PostRead(req); + EXPECT_EQ(SH_TIME_OUT, result); + + DestorySglMem(shmMsgServerDriver, mrServer); + DestorySglMem(shmMsgClientDriver, mrClient); +} diff --git a/test/llt/testcase/transport/shm/test_shm_send_recv_msg.h b/test/llt/testcase/transport/shm/test_shm_send_recv_msg.h new file mode 100644 index 0000000000000000000000000000000000000000..d4befb77874488ff62f4bb1665aab12efc344de4 --- /dev/null +++ b/test/llt/testcase/transport/shm/test_shm_send_recv_msg.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_SHM_SENDMSG_RECVMSG_H +#define HCOM_SHM_SENDMSG_RECVMSG_H +#include +#include + +class TestShmSendRecvMsg : public testing::Test { +public: + TestShmSendRecvMsg(); + virtual void SetUp(void); + virtual void TearDown(void); +}; +#endif // HCOM_SHM_SENDMSG_RECVMSG_H diff --git a/test/llt/testcase/transport/shm/test_shm_sync_endpoint.cpp b/test/llt/testcase/transport/shm/test_shm_sync_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2f749c6dfbd9f60e3af29da6b136285e1b7d9348 --- /dev/null +++ b/test/llt/testcase/transport/shm/test_shm_sync_endpoint.cpp @@ -0,0 +1,1549 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "hcom.h" +#include "net_shm_sync_endpoint.h" +#include "net_shm_async_endpoint.h" +#include "shm_queue.h" +#include "shm_common.h" +#include "shm_worker.h" +#include "test_shm_common.h" +#include "test_shm_sync_endpoint.h" + +using namespace ock::hcom; +TestShmSyncEndpoint::TestShmSyncEndpoint() {} + +static UBSHcomNetTransSgeIov clientMrInfo[NN_NO4]; +static uint32_t iovCnt = NN_NO4; +static UBSHcomNetEndpointPtr syncEp = nullptr; +static TestRegMrInfo syncClientMrInfo[NN_NO4]; +static int g_nameSeed = 0; + + +static int ServerNewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload); + syncEp = newEP; + return 0; +} + +static void EndPointBroken(const UBSHcomNetEndpointPtr &ep) +{ + if (syncEp != nullptr) { + syncEp.Set(nullptr); + } + NN_LOG_INFO("end point " << ep->Id()); +} + +static int RequestReceivedSend(const UBSHcomNetRequestContext &ctx) +{ + std::string respMsg = "Hello client, this is a reply message"; + + int result = 0; + UBSHcomNetTransRequest req((void *)(const_cast(respMsg.c_str())), respMsg.length(), 0); + if ((result = syncEp->PostSend(0, req)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + } + return 0; +} + +static int RequestReceivedSendRaw(const UBSHcomNetRequestContext &ctx) +{ + std::string respMsg = "Hello client, this is a reply message"; + + int result = 0; + UBSHcomNetTransRequest req((void *)(const_cast(respMsg.c_str())), respMsg.length(), 0); + if ((result = syncEp->PostSendRaw(req, 1)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + } + return 0; +} + +static int RequestPosted(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request posted"); + return 0; +} + +static int OneSideDone(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("one side done"); + return 0; +} + +static bool RegSglMem(UBSHcomNetDriver *driver, UBSHcomNetTransSgeIov mrInfo[], + std::vector &mrs) +{ + for (int i = 0; i < 4; ++i) { + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(NN_NO16, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + mrInfo[i].lAddress = mr->GetAddress(); + mrInfo[i].lKey = mr->GetLKey(); + mrInfo[i].size = NN_NO8; + mrs.push_back(mr); + memset(reinterpret_cast(mrInfo[i].lAddress), 1, mrInfo[i].size); + } + return true; +} + +static void DestoryMem(UBSHcomNetDriver *driver, std::vector &mrs) +{ + while (!mrs.empty()) { + driver->DestroyMemoryRegion(mrs.back()); + mrs.pop_back(); + } +} + +static bool RegReadWriteMem(UBSHcomNetDriver *driver, TestRegMrInfo mrInfo[], + std::vector &mrs) +{ + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(NN_NO16, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + mrInfo[0].lAddress = mr->GetAddress(); + mrInfo[0].lKey = mr->GetLKey(); + mrInfo[0].size = NN_NO8; + mrs.push_back(mr); + memset(reinterpret_cast(mrInfo[0].lAddress), '1', mrInfo[0].size); + return true; +} + +static bool RegReadWriteSglMem(UBSHcomNetDriver *driver, TestRegMrInfo mrInfo[], + std::vector &mrs) +{ + for (int i = 0; i < 4; ++i) { + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(NN_NO16, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + mrInfo[i].lAddress = mr->GetAddress(); + mrInfo[i].lKey = mr->GetLKey(); + mrInfo[i].size = NN_NO8; + mrs.push_back(mr); + memset(reinterpret_cast(mrInfo[i].lAddress), '1', mrInfo[i].size); + } + return true; +} + +/* server new request sgl callback */ +static TestRegMrInfo syncServerMrInfo[NN_NO4]; +static int RequestReceivedServer(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("server request received - " << ctx.Header().opCode << ", dataLen " << ctx.Header().dataLength); + + int result = 0; + UBSHcomNetTransRequest rsp((void *)(syncServerMrInfo), sizeof(syncServerMrInfo), 0); + if ((result = syncEp->PostSend(1, rsp)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + + NN_LOG_INFO("request rsp Mr info"); + std::string readValue((char *)syncServerMrInfo[0].lAddress, syncServerMrInfo[0].size); + NN_LOG_INFO("idx:" << 0 << " key:" << syncServerMrInfo[0].lKey << " address:" << syncServerMrInfo[0].lAddress << + " size: " << syncServerMrInfo[0].size << "string: " << readValue); + return 0; +} + +static int RequestReceivedSglServer(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("server request received - " << ctx.Header().opCode << ", dataLen " << ctx.Header().dataLength); + + int result = 0; + UBSHcomNetTransRequest rsp((void *)(syncServerMrInfo), sizeof(syncServerMrInfo), 0); + if ((result = syncEp->PostSend(1, rsp)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + + NN_LOG_INFO("request rsp Mr info"); + for (uint16_t i = 0; i < NN_NO4; i++) { + NN_LOG_INFO("idx:" << i << " key:" << syncServerMrInfo[i].lKey << " address:" << syncServerMrInfo[i].lAddress << + " size: " << syncServerMrInfo[i].size); + } + return 0; +} + +/* client receive server mr info */ +static TestRegMrInfo getRemoteMrInfo[NN_NO4]; + +static bool CreateServerDriver(UBSHcomNetDriver *&driver, int (*reqHandler)(const UBSHcomNetRequestContext &), + UBSHcomNetDriverOptions &options) +{ + auto name = "serverSync_ep_" + std::to_string(g_nameSeed++); + + driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::SHM, name, true); + + if (driver == nullptr) { + NN_LOG_ERROR("failed to create serverDriver already created"); + return false; + } + + options.oobType = ock::hcom::NET_OOB_UDS; + options.enableTls = false; + UBSHcomNetOobUDSListenerOptions listenOpt; + listenOpt.Name(UDSNAME); + listenOpt.perm = 0; + driver->AddOobUdsOptions(listenOpt); + + driver->RegisterNewEPHandler( + std::bind(&ServerNewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(reqHandler); + driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + + int result = 0; + if ((result = driver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize driver " << result); + return false; + } + NN_LOG_INFO("serverDriver initialized"); + + if ((result = driver->Start()) != 0) { + NN_LOG_ERROR("failed to start serverDriver " << result); + return false; + } + NN_LOG_INFO("serverDriver started"); + return true; +} + +static bool CreateSyncClientDriver(UBSHcomNetDriver *&driver, UBSHcomNetDriverOptions &options) +{ + auto name = "clientSync_ep_" + std::to_string(g_nameSeed); + + driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::SHM, name, false); + if (driver == nullptr) { + NN_LOG_ERROR("failed to create clientDriver already created"); + return false; + } + + options.oobType = ock::hcom::NET_OOB_UDS; + options.dontStartWorkers = true; + options.enableTls = false; + + int result = 0; + if ((result = driver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize driver " << result); + return false; + } + NN_LOG_INFO("clientDriver initialized"); + + if ((result = driver->Start()) != 0) { + NN_LOG_ERROR("failed to start clientDriver " << result); + return false; + } + NN_LOG_INFO("clientDriver started"); + return true; +} + +void closeShmDriver(UBSHcomNetDriver *&clientDriver, UBSHcomNetDriver *&serverDriver) +{ + syncEp->Close(); + if (syncEp != nullptr) { + syncEp.Set(nullptr); + } + std::string serverName = serverDriver->Name(); + std::string clientName = clientDriver->Name(); + if (serverDriver->IsStarted()) { + serverDriver->Stop(); + } + if (serverDriver->IsInited()) { + serverDriver->UnInitialize(); + } + if (clientDriver->IsStarted()) { + clientDriver->Stop(); + } + if (clientDriver->IsInited()) { + clientDriver->UnInitialize(); + } + UBSHcomNetDriver::DestroyInstance(serverName); + UBSHcomNetDriver::DestroyInstance(clientName); +} + +void TestShmSyncEndpoint::SetUp() +{ + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestShmSyncEndpoint::TearDown() +{ + GlobalMockObject::verify(); +} + + +TEST_F(TestShmSyncEndpoint, SyncPostSendRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(1); + std::string msg = "Hello Hello"; + UBSHcomNetTransRequest req((void *)(const_cast(msg.c_str())), msg.length(), 0); + + result = ep->PostSend(1, req); + EXPECT_EQ(SH_OK, result); + + result = ep->WaitCompletion(2); + EXPECT_EQ(SH_OK, result); + + UBSHcomNetResponseContext respCtx {}; + result = ep->Receive(2, respCtx); + std::string resp((char *)respCtx.Message()->Data(), respCtx.Header().dataLength); + NN_LOG_INFO("server response received - " << respCtx.Header().opCode << ", dataLen " << + respCtx.Header().dataLength); + EXPECT_EQ(SH_OK, result); + + MOCKER_CPP(&ShmChannel::EQEventEnqueue).defaults().will(returnValue(-1)); + result = ep->PostSend(1, req); + EXPECT_EQ(SH_RETRY_FULL, result); + + UBSHcomNetTransOpInfo innerOpInfo(1, 0, 0, NTH_TWO_SIDE); + result = ep->PostSend(1, req, innerOpInfo); + EXPECT_EQ(SH_RETRY_FULL, result); + + result = ep->PostSendRaw(req, 1); + EXPECT_EQ(SH_RETRY_FULL, result); + + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +/* set time out */ +TEST_F(TestShmSyncEndpoint, SyncPostSendFail2) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(1); + std::string msg = "Hello Hello"; + UBSHcomNetTransRequest req((void *)(const_cast(msg.c_str())), msg.length(), 0); + + MOCKER_CPP(&ShmQueue::EnqueueAndNotify, int32_t(ShmQueue::*)(ShmEvent &)) + .stubs() + .will(returnValue(0)) + .then(returnValue(-1)); + + result = ep->PostSend(1, req); + EXPECT_EQ(SH_SEND_COMPLETION_CALLBACK_FAILURE, result); + + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +/* no set timeout */ +TEST_F(TestShmSyncEndpoint, SyncPostSendFail3) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(0); + std::string msg = "Hello Hello"; + UBSHcomNetTransRequest req((void *)(const_cast(msg.c_str())), msg.length(), 0); + + MOCKER_CPP(&ShmQueue::EnqueueAndNotify, int32_t(ShmQueue::*)(ShmEvent &)) + .defaults() + .will(returnObjectList(0, -1)); + + result = ep->PostSend(1, req); + EXPECT_EQ(SH_SEND_COMPLETION_CALLBACK_FAILURE, result); + + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncPostSendFail4) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(0); + std::string msg = "Hello Hello"; + char upctx[NN_NO29]; + for (uint32_t i = 0; i < NN_NO29; ++i) { + upctx[i] = '2'; + } + UBSHcomNetTransRequest req((void *)(const_cast(msg.c_str())), msg.length(), *upctx); + + result = ep->PostSend(1, req); + EXPECT_EQ(SH_PARAM_INVALID, result); + + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncPostSendFail5) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(0); + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + + MOCKER_CPP(&ShmDataChannel::TryOccupyWithWait).defaults().will(returnValue(305)); + result = ep->PostSend(1, req); + EXPECT_EQ(SH_NOT_INITIALIZED, result); + + + UBSHcomNetTransOpInfo innerOpInfo(2, 0, 0, NTH_TWO_SIDE); + result = ep->PostSend(1, req, innerOpInfo); + EXPECT_EQ(SH_NOT_INITIALIZED, result); + + result = ep->PostSendRaw(req, 1); + EXPECT_EQ(SH_NOT_INITIALIZED, result); + + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncReceiveRetry1) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::string msg = "Hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(msg.c_str())), msg.length(), 0); + UBSHcomNetResponseContext respCtx {}; + + result = ep->PostSend(1, req); + EXPECT_EQ(SH_OK, result); + result = ep->WaitCompletion(2); + EXPECT_EQ(SH_OK, result); + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset).defaults().will(returnValue(301)); + result = ep->Receive(2, respCtx); + EXPECT_EQ(SH_PARAM_INVALID, result); + + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncReceiveRetry2) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::string msg = "Hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(msg.c_str())), msg.length(), 0); + UBSHcomNetResponseContext respCtx {}; + + result = ep->PostSend(1, req); + EXPECT_EQ(SH_OK, result); + result = ep->WaitCompletion(2); + EXPECT_EQ(SH_OK, result); + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset).defaults().will(returnValue(305)); + result = ep->Receive(2, respCtx); + EXPECT_EQ(SH_NOT_INITIALIZED, result); + + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncReceiveRetry3) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::string msg = "Hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(msg.c_str())), msg.length(), 0); + UBSHcomNetResponseContext respCtx {}; + + result = ep->PostSend(1, req); + EXPECT_EQ(SH_OK, result); + result = ep->WaitCompletion(2); + EXPECT_EQ(SH_OK, result); + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed).defaults().will(returnValue(false)); + result = ep->Receive(2, respCtx); + EXPECT_EQ(NN_MALLOC_FAILED, result); + + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncPostSendRawRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSendRaw, options); + EXPECT_TRUE(res); + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::string msg = "Hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSendRaw(req, 1); + EXPECT_EQ(SH_OK, result); + + result = ep->WaitCompletion(2); + EXPECT_EQ(SH_OK, result); + + UBSHcomNetResponseContext respCtx {}; + result = ep->ReceiveRaw(-1, respCtx); + std::string resp((char *)respCtx.Message()->Data(), respCtx.Header().dataLength); + NN_LOG_INFO("server response received - " << respCtx.Header().opCode << ", dataLen " << + respCtx.Header().dataLength); + EXPECT_EQ(SH_OK, result); + + MOCKER_CPP(&ShmSyncEndpoint::PostSend).defaults().will(returnObjectList(301, 314)); + result = ep->PostSendRaw(req, 1); + EXPECT_EQ(SH_PARAM_INVALID, result); + + result = ep->PostSendRaw(req, 1); + EXPECT_EQ(SH_SEND_COMPLETION_CALLBACK_FAILURE, result); + + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncPostSendRawSglRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + + bool createRes; + createRes = CreateServerDriver(serverDriver, RequestReceivedSendRaw, options); + EXPECT_TRUE(createRes); + createRes = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(createRes); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::vector mrs; + bool res = RegSglMem(clientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + UBSHcomNetTransSglRequest req(clientMrInfo, iovCnt, 0); + UBSHcomNetResponseContext respCtx {}; + + result = ep->PostSendRawSgl(req, 1); + EXPECT_EQ(SH_OK, result); + result = ep->WaitCompletion(-1); + EXPECT_EQ(SH_OK, result); + result = ep->ReceiveRawSgl(respCtx); + EXPECT_EQ(SH_OK, result); + + MOCKER_CPP(&ShmSyncEndpoint::PostSendRawSgl).defaults().will(returnObjectList(301, 314)); + result = ep->PostSendRawSgl(req, 2); + EXPECT_EQ(SH_PARAM_INVALID, result); + + result = ep->PostSendRawSgl(req, 2); + EXPECT_EQ(SH_SEND_COMPLETION_CALLBACK_FAILURE, result); + + DestoryMem(clientDriver, mrs); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncPostSendRawSglRetry1) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool createRes; + createRes = CreateServerDriver(serverDriver, RequestReceivedSendRaw, options); + EXPECT_TRUE(createRes); + createRes = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(createRes); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(1); + + std::vector mrs; + bool res = RegSglMem(clientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + + UBSHcomNetTransSglRequest req(clientMrInfo, iovCnt, 0); + MOCKER_CPP(&ShmChannel::EQEventEnqueue).defaults().will(returnValue(-1)); + result = ep->PostSendRawSgl(req, 2); + EXPECT_EQ(SH_RETRY_FULL, result); + + DestoryMem(clientDriver, mrs); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} +TEST_F(TestShmSyncEndpoint, SyncPostSendRawSglRetry2) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool createRes; + createRes = CreateServerDriver(serverDriver, RequestReceivedSendRaw, options); + EXPECT_TRUE(createRes); + createRes = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(createRes); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::vector mrs; + bool res = RegSglMem(clientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + + UBSHcomNetTransSglRequest req(clientMrInfo, iovCnt, 0); + MOCKER_CPP(&ShmSyncEndpoint::FillSglCtx).defaults().will(returnValue(301)); + result = ep->PostSendRawSgl(req, 2); + EXPECT_EQ(SH_PARAM_INVALID, result); + + DestoryMem(clientDriver, mrs); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +/* set time out */ +TEST_F(TestShmSyncEndpoint, SyncPostSendRawSglFail1) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(1); + + std::vector mrs; + res = RegSglMem(clientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + + UBSHcomNetTransSglRequest req(clientMrInfo, iovCnt, 0); + MOCKER_CPP(&ShmQueue::EnqueueAndNotify, int32_t(ShmQueue::*)(ShmEvent &)) + .stubs() + .will(returnValue(0)) + .then(returnValue(-1)); + + result = ep->PostSendRawSgl(req, 2); + EXPECT_EQ(SH_SEND_COMPLETION_CALLBACK_FAILURE, result); + + DestoryMem(clientDriver, mrs); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +/* no set timeout */ +TEST_F(TestShmSyncEndpoint, SyncPostSendRawSglFail2) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(0); + + std::vector mrs; + res = RegSglMem(clientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + + UBSHcomNetTransSglRequest req(clientMrInfo, iovCnt, 0); + MOCKER_CPP(&ShmQueue::EnqueueAndNotify, int32_t(ShmQueue::*)(ShmEvent &)) + .defaults() + .will(returnObjectList(0, -1)); + + result = ep->PostSendRawSgl(req, 2); + EXPECT_EQ(SH_SEND_COMPLETION_CALLBACK_FAILURE, result); + + DestoryMem(clientDriver, mrs); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncPostSendRawSglFail3) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(0); + std::string msg = "Hello Hello"; + + std::vector mrs; + res = RegSglMem(clientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + + UBSHcomNetTransSglRequest req(clientMrInfo, iovCnt, NN_NO29); + result = ep->PostSendRawSgl(req, 2); + EXPECT_EQ(SH_PARAM_INVALID, result); + + DestoryMem(clientDriver, mrs); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncPostSendRawSglFail4) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(0); + std::string msg = "Hello Hello"; + + std::vector mrs; + res = RegSglMem(clientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + + UBSHcomNetTransSglRequest req(clientMrInfo, iovCnt, NN_NO16); + result = ep->PostSendRawSgl(req, 0); + EXPECT_EQ(NN_INVALID_PARAM, result); + + DestoryMem(clientDriver, mrs); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncPostSendRawSglFail5) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(0); + std::string msg = "Hello Hello"; + + std::vector mrs; + res = RegSglMem(clientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + + UBSHcomNetTransSglRequest req(clientMrInfo, iovCnt, NN_NO16); + MOCKER_CPP(&UBSHcomNetAtomicState::Compare).defaults().will(returnValue(false)); + result = ep->PostSendRawSgl(req, 2); + EXPECT_EQ(NN_EP_NOT_ESTABLISHED, result); + + DestoryMem(clientDriver, mrs); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncPostSendRawSglFail6) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(0); + std::string msg = "Hello Hello"; + + std::vector mrs; + res = RegSglMem(clientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + + UBSHcomNetTransSglRequest req(clientMrInfo, iovCnt, NN_NO16); + MOCKER_CPP(&MemoryRegionChecker::Validate).defaults().will(returnValue(100)); + result = ep->PostSendRawSgl(req, 2); + EXPECT_EQ(NN_INVALID_LKEY, result); + + DestoryMem(clientDriver, mrs); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncReceiveRawRetry1) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSendRaw, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::vector mrs; + res = RegSglMem(clientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + + UBSHcomNetTransSglRequest req(clientMrInfo, iovCnt, 0); + UBSHcomNetResponseContext respCtx {}; + result = ep->PostSendRawSgl(req, 1); + EXPECT_EQ(SH_OK, result); + result = ep->WaitCompletion(-1); + EXPECT_EQ(SH_OK, result); + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset).defaults().will(returnValue(301)); + result = ep->ReceiveRaw(2, respCtx); + EXPECT_EQ(SH_PARAM_INVALID, result); + + DestoryMem(clientDriver, mrs); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncReceiveRawRetry2) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::vector mrs; + res = RegSglMem(clientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + + UBSHcomNetTransSglRequest req(clientMrInfo, iovCnt, 0); + UBSHcomNetResponseContext respCtx {}; + result = ep->PostSendRawSgl(req, 1); + EXPECT_EQ(SH_OK, result); + result = ep->WaitCompletion(2); + EXPECT_EQ(SH_OK, result); + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset).defaults().will(returnValue(305)); + result = ep->ReceiveRaw(2, respCtx); + EXPECT_EQ(SH_NOT_INITIALIZED, result); + + DestoryMem(clientDriver, mrs); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncReceiveRawRetry3) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::vector mrs; + res = RegSglMem(clientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + + UBSHcomNetTransSglRequest req(clientMrInfo, iovCnt, 0); + UBSHcomNetResponseContext respCtx {}; + result = ep->PostSendRawSgl(req, 1); + EXPECT_EQ(SH_OK, result); + result = ep->WaitCompletion(2); + EXPECT_EQ(SH_OK, result); + MOCKER_CPP(&ShmSyncEndpoint::Receive).defaults().will(returnValue(306)); + result = ep->Receive(2, respCtx); + EXPECT_EQ(SH_TIME_OUT, result); + + DestoryMem(clientDriver, mrs); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, SyncReceiveRawRetry4) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::vector mrs; + res = RegSglMem(clientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + + UBSHcomNetTransSglRequest req(clientMrInfo, iovCnt, 0); + UBSHcomNetResponseContext respCtx {}; + result = ep->PostSendRawSgl(req, 1); + EXPECT_EQ(SH_OK, result); + result = ep->WaitCompletion(2); + EXPECT_EQ(SH_OK, result); + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed, bool (UBSHcomNetMessage::*)(uint32_t)) + .defaults().will(returnValue(false)); + result = ep->Receive(2, respCtx); + EXPECT_EQ(NN_MALLOC_FAILED, result); + + DestoryMem(clientDriver, mrs); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, PostReadWriteSgl) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSglServer, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::vector mrServer; + res = RegReadWriteSglMem(serverDriver, syncServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteSglMem(clientDriver, syncClientMrInfo, mrClient); + EXPECT_TRUE(res); + + /* exchange mr info */ + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + result = ep->WaitCompletion(NN_NO2); + EXPECT_EQ(SH_OK, result); + UBSHcomNetResponseContext respCtx {}; + result = ep->Receive(NN_NO2, respCtx); + EXPECT_EQ(SH_OK, result); + memcpy(getRemoteMrInfo, respCtx.Message()->Data(), respCtx.Message()->DataLen()); + + UBSHcomNetTransSgeIov iov[NN_NO4]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = syncClientMrInfo[i].lAddress; + iov[i].rAddress = getRemoteMrInfo[i].lAddress; + iov[i].lKey = syncClientMrInfo[i].lKey; + iov[i].rKey = getRemoteMrInfo[i].lKey; + iov[i].size = getRemoteMrInfo[i].size; + } + + UBSHcomNetTransSglRequest req(iov, NN_NO4, 0); + + result = ep->PostRead(req); + EXPECT_EQ(SH_OK, result); + + for (uint16_t i = 0; i < NN_NO4; i++) { + std::string readValue((char *)syncClientMrInfo[i].lAddress, syncClientMrInfo[i].size); + NN_LOG_INFO("value[" << i << "]= " << readValue); + } + result = ep->PostWrite(req); + EXPECT_EQ(SH_OK, result); + + MOCKER_CPP(&ShmSyncEndpoint::PostRead, + HResult(ShmSyncEndpoint::*)(ShmChannel *, const UBSHcomNetTransSglRequest &, ShmMRHandleMap &)) + .defaults() + .will(returnObjectList(301, 300)); + result = ep->PostRead(req); + EXPECT_EQ(SH_PARAM_INVALID, result); + result = ep->PostRead(req); + EXPECT_EQ(SH_ERROR, result); + + MOCKER_CPP(&ShmSyncEndpoint::PostWrite, + HResult(ShmSyncEndpoint::*)(ShmChannel *, const UBSHcomNetTransSglRequest &, ShmMRHandleMap &)) + .defaults() + .will(returnObjectList(301, 300)); + result = ep->PostWrite(req); + EXPECT_EQ(SH_PARAM_INVALID, result); + result = ep->PostWrite(req); + EXPECT_EQ(SH_ERROR, result); + + DestoryMem(serverDriver, mrServer); + DestoryMem(clientDriver, mrClient); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, PostReadWriteSglFail) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSglServer, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::vector mrServer; + res = RegReadWriteSglMem(serverDriver, syncServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteSglMem(clientDriver, syncClientMrInfo, mrClient); + EXPECT_TRUE(res); + + /* exchange mr info */ + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + result = ep->WaitCompletion(NN_NO2); + EXPECT_EQ(SH_OK, result); + UBSHcomNetResponseContext respCtx {}; + result = ep->Receive(NN_NO2, respCtx); + EXPECT_EQ(SH_OK, result); + memcpy(getRemoteMrInfo, respCtx.Message()->Data(), respCtx.Message()->DataLen()); + + UBSHcomNetTransSgeIov iov[NN_NO4]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = syncClientMrInfo[i].lAddress; + iov[i].rAddress = getRemoteMrInfo[i].lAddress; + iov[i].lKey = syncClientMrInfo[i].lKey; + iov[i].rKey = getRemoteMrInfo[i].lKey; + iov[i].size = getRemoteMrInfo[i].size; + } + + UBSHcomNetTransSglRequest req(iov, NN_NO4, 17); + + result = ep->PostRead(req); + EXPECT_EQ(SH_PARAM_INVALID, result); + + DestoryMem(serverDriver, mrServer); + DestoryMem(clientDriver, mrClient); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, GetFromLocalMapFail) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSglServer, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::vector mrServer; + res = RegReadWriteSglMem(serverDriver, syncServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteSglMem(clientDriver, syncClientMrInfo, mrClient); + EXPECT_TRUE(res); + + /* exchange mr info */ + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + result = ep->WaitCompletion(NN_NO2); + EXPECT_EQ(SH_OK, result); + UBSHcomNetResponseContext respCtx {}; + result = ep->Receive(NN_NO2, respCtx); + EXPECT_EQ(SH_OK, result); + memcpy(getRemoteMrInfo, respCtx.Message()->Data(), respCtx.Message()->DataLen()); + + UBSHcomNetTransSgeIov iov[NN_NO4]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = syncClientMrInfo[i].lAddress; + iov[i].rAddress = getRemoteMrInfo[i].lAddress; + iov[i].lKey = syncClientMrInfo[i].lKey; + iov[i].rKey = getRemoteMrInfo[i].lKey; + iov[i].size = getRemoteMrInfo[i].size; + } + + UBSHcomNetTransSglRequest req(iov, NN_NO4, 0); + ShmHandlePtr localMrHandle = nullptr; + MOCKER_CPP(&ShmMRHandleMap::GetFromLocalMap).defaults().will(returnValue(localMrHandle)); + result = ep->PostRead(req); + EXPECT_EQ(SH_ERROR, result); + + DestoryMem(serverDriver, mrServer); + DestoryMem(clientDriver, mrClient); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, PostReadWrite) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedServer, options); + EXPECT_TRUE(res); + + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::vector mrServer; + res = RegReadWriteMem(serverDriver, syncServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteMem(clientDriver, syncClientMrInfo, mrClient); + EXPECT_TRUE(res); + + /* exchange mr info */ + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + result = ep->WaitCompletion(NN_NO2); + EXPECT_EQ(SH_OK, result); + UBSHcomNetResponseContext respCtx {}; + result = ep->Receive(NN_NO2, respCtx); + EXPECT_EQ(SH_OK, result); + memcpy(getRemoteMrInfo, respCtx.Message()->Data(), respCtx.Message()->DataLen()); + + UBSHcomNetTransRequest req; + req.lAddress = syncClientMrInfo[0].lAddress; + req.rAddress = getRemoteMrInfo[0].lAddress; + req.lKey = syncClientMrInfo[0].lKey; + req.rKey = getRemoteMrInfo[0].lKey; + req.size = getRemoteMrInfo[0].size; + + NN_LOG_INFO("req " + << "req.lAddress: " << req.lAddress << " req.rAddress: " << req.rAddress << " req.lKey: " << req.lKey << + " req.rKey: " << req.rKey << " req.size:" << req.size); + result = ep->PostRead(req); + EXPECT_EQ(SH_OK, result); + + std::string readValue((char *)syncClientMrInfo[0].lAddress, syncClientMrInfo[0].size); + NN_LOG_INFO("value[" << 0 << "]= " << readValue); + result = ep->PostWrite(req); + EXPECT_EQ(SH_OK, result); + + + MOCKER_CPP(&ShmSyncEndpoint::PostRead, + HResult(ShmSyncEndpoint::*)(ShmChannel *, const UBSHcomNetTransRequest &, ShmMRHandleMap &)) + .defaults() + .will(returnObjectList(301, 300)); + result = ep->PostRead(req); + EXPECT_EQ(SH_PARAM_INVALID, result); + result = ep->PostRead(req); + EXPECT_EQ(SH_ERROR, result); + + MOCKER_CPP(&ShmSyncEndpoint::PostWrite, + HResult(ShmSyncEndpoint::*)(ShmChannel *, const UBSHcomNetTransRequest &, ShmMRHandleMap &)) + .defaults() + .will(returnObjectList(301, 300)); + result = ep->PostWrite(req); + EXPECT_EQ(SH_PARAM_INVALID, result); + result = ep->PostWrite(req); + EXPECT_EQ(SH_ERROR, result); + + DestoryMem(serverDriver, mrServer); + DestoryMem(clientDriver, mrClient); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, PostReadWriteFail) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedServer, options); + EXPECT_TRUE(res); + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::vector mrServer; + res = RegReadWriteMem(serverDriver, syncServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteMem(clientDriver, syncClientMrInfo, mrClient); + EXPECT_TRUE(res); + + /* exchange mr info */ + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + result = ep->WaitCompletion(NN_NO2); + EXPECT_EQ(SH_OK, result); + UBSHcomNetResponseContext respCtx {}; + result = ep->Receive(NN_NO2, respCtx); + EXPECT_EQ(SH_OK, result); + memcpy(getRemoteMrInfo, respCtx.Message()->Data(), respCtx.Message()->DataLen()); + + UBSHcomNetTransRequest req; + req.lAddress = syncClientMrInfo[0].lAddress; + req.rAddress = getRemoteMrInfo[0].lAddress; + req.lKey = syncClientMrInfo[0].lKey; + req.rKey = getRemoteMrInfo[0].lKey; + req.size = getRemoteMrInfo[0].size; + req.upCtxSize = 18; + NN_LOG_INFO("req " + << "req.lAddress: " << req.lAddress << " req.rAddress: " << req.rAddress << " req.lKey: " << req.lKey << + " req.rKey: " << req.rKey << " req.size:" << req.size); + + result = ep->PostWrite(req); + EXPECT_EQ(SH_PARAM_INVALID, result); + + DestoryMem(serverDriver, mrServer); + DestoryMem(clientDriver, mrClient); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, GetRemoteMrFdsFail) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedServer, options); + EXPECT_TRUE(res); + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + std::vector mrServer; + res = RegReadWriteMem(serverDriver, syncServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteMem(clientDriver, syncClientMrInfo, mrClient); + EXPECT_TRUE(res); + + /* exchange mr info */ + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + result = ep->WaitCompletion(NN_NO2); + EXPECT_EQ(SH_OK, result); + UBSHcomNetResponseContext respCtx {}; + result = ep->Receive(NN_NO2, respCtx); + EXPECT_EQ(SH_OK, result); + memcpy(getRemoteMrInfo, respCtx.Message()->Data(), respCtx.Message()->DataLen()); + + UBSHcomNetTransRequest req; + req.lAddress = syncClientMrInfo[0].lAddress; + req.rAddress = getRemoteMrInfo[0].lAddress; + req.lKey = syncClientMrInfo[0].lKey; + req.rKey = getRemoteMrInfo[0].lKey; + req.size = getRemoteMrInfo[0].size; + + NN_LOG_INFO("req " + << "req.lAddress: " << req.lAddress << " req.rAddress: " << req.rAddress << " req.lKey: " << req.lKey << + " req.rKey: " << req.rKey << " req.size:" << req.size); + + MOCKER_CPP(&ShmChannel::GetRemoteMrFds).defaults().will(returnValue(306)); + result = ep->PostWrite(req); + EXPECT_EQ(SH_TIME_OUT, result); + + DestoryMem(serverDriver, mrServer); + DestoryMem(clientDriver, mrClient); + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, NetSyncEndpointShmFuncation) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + UBSHcomEpOptions epOptions {}; + result = ep->SetEpOption(epOptions); + EXPECT_EQ(NN_OK, result); + + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, GetRemoteUdsIdInfo) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + UBSHcomNetUdsIdInfo idInfo {}; + if (syncEp != nullptr) { + result = syncEp->GetRemoteUdsIdInfo(idInfo); + EXPECT_EQ(NN_OK, result); + NN_LOG_INFO("========new endpoint remote uds ids, pid: " << idInfo.pid << " uid: " << idInfo.uid << " gid: " << + idInfo.gid << " result:" << result); + } + + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, GetRemoteUdsIdInfoFail1) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + UBSHcomNetUdsIdInfo idInfo {}; + result = ep->GetRemoteUdsIdInfo(idInfo); + EXPECT_EQ(NN_UDS_ID_INFO_NOT_SUPPORT, result); + NN_LOG_INFO("=======new endpoint remote uds ids, pid: " << idInfo.pid << " uid: " << idInfo.uid << " gid: " << + idInfo.gid << " result:" << result); + + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} + +TEST_F(TestShmSyncEndpoint, GetRemoteUdsIdInfoFail2) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriverOptions options {}; + options.mode = NET_EVENT_POLLING; + + bool res; + res = CreateServerDriver(serverDriver, RequestReceivedSend, options); + EXPECT_TRUE(res); + res = CreateSyncClientDriver(clientDriver, options); + EXPECT_TRUE(res); + clientDriver->Connect(UDSNAME, 0, "hello server", ep, NET_EP_SELF_POLLING); + + UBSHcomNetUdsIdInfo idInfo {}; + MOCKER_CPP(&UBSHcomNetAtomicState::Compare).defaults().will(returnValue(false)); + result = syncEp->GetRemoteUdsIdInfo(idInfo); + EXPECT_EQ(NN_EP_NOT_ESTABLISHED, result); + NN_LOG_INFO("=======new endpoint remote uds ids, pid: " << idInfo.pid << " uid: " << idInfo.uid << " gid: " << + idInfo.gid << " result:" << result); + + ep->Close(); + closeShmDriver(clientDriver, serverDriver); +} diff --git a/test/llt/testcase/transport/shm/test_shm_sync_endpoint.h b/test/llt/testcase/transport/shm/test_shm_sync_endpoint.h new file mode 100644 index 0000000000000000000000000000000000000000..666b233cacb2a7826020835b739cbc22501707eb --- /dev/null +++ b/test/llt/testcase/transport/shm/test_shm_sync_endpoint.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TEST_SHM_SYNC_ENDPOINT_H +#define HCOM_TEST_SHM_SYNC_ENDPOINT_H +#include +#include + +class TestShmSyncEndpoint : public testing::Test { +public: + TestShmSyncEndpoint(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TEST_SHM_SYNC_ENDPOINT_H diff --git a/test/llt/testcase/transport/shm/test_shm_tls.cpp b/test/llt/testcase/transport/shm/test_shm_tls.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c4570b515262515b36d593e034c83dc1cdaeceb7 --- /dev/null +++ b/test/llt/testcase/transport/shm/test_shm_tls.cpp @@ -0,0 +1,1034 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "hcom.h" +#include "net_shm_sync_endpoint.h" +#include "net_shm_async_endpoint.h" +#include "net_security_alg.h" +#include "shm_worker.h" +#include "test_shm_common.h" +#include "test_shm_tls.h" + +using namespace ock::hcom; +TestShmTls::TestShmTls() {} + +UBSHcomNetEndpointPtr tlsShmServerEp = nullptr; +UBSHcomNetEndpointPtr tlsShmClientEp = nullptr; +static TestRegMrInfo tlsClientMrInfo; +static TestRegMrInfo tlsServerMrInfo; +static UBSHcomNetTransSgeIov tlsShmClientMrInfo[NN_NO4]; +UBSHcomNetDriverOptions tlsShmOptions {}; + +UBSHcomNetDriver *tlsShmCDriver = nullptr; +UBSHcomNetDriver *tlsShmSDriver = nullptr; +std::string shmCertPath; +static uint32_t iovCnt = NN_NO4; +static sem_t sem; +static int g_nameSeed = 0; + +int ShmValidateTlsCert() +{ + char *buffer; + + if ((buffer = getcwd(NULL, 0)) == NULL) { + NN_LOG_ERROR("Cet path for TLS cert failed"); + return -1; + } + + std::string currentPath = buffer; + shmCertPath = currentPath + "/../test/opensslcrt/normalCert1"; + + if (!CanonicalPath(shmCertPath)) { + NN_LOG_ERROR("TLS cert path check failed " << shmCertPath); + return -1; + } + + return 0; +} + +static void SetEncryptValue() +{ + // this step should exec after client connect and ep created + std::string value = "value from server"; + size_t encryptLen = tlsShmServerEp->EstimatedEncryptLen(value.length()); + void *cipher = malloc(encryptLen); + tlsShmServerEp->Encrypt(value.c_str(), value.length(), cipher, encryptLen); + memcpy(reinterpret_cast(tlsServerMrInfo.lAddress), cipher, encryptLen); +} + +static size_t SetClientEncryptValue() +{ + std::string value = "value from client"; + size_t encryptLen = tlsShmClientEp->EstimatedEncryptLen(value.length()); + void *cipher = malloc(encryptLen); + tlsShmClientEp->Encrypt(value.c_str(), value.length(), cipher, encryptLen); + + memcpy(reinterpret_cast(tlsClientMrInfo.lAddress), cipher, encryptLen); + return encryptLen; +} + +static int ServerNewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload); + tlsShmServerEp = newEP; + SetEncryptValue(); + return 0; +} + +static int ServerNewEndPointSend(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, + const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload); + tlsShmServerEp = newEP; + return 0; +} + +static void EndPointBroken(const UBSHcomNetEndpointPtr &ep) +{ + NN_LOG_INFO("end point " << ep->Id()); +} + +static int RequestReceivedServer(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("server request received - " << ctx.Header().opCode << ", dataLen " << ctx.Header().dataLength); + + int result = 0; + UBSHcomNetTransRequest rsp(&tlsServerMrInfo, sizeof(tlsServerMrInfo), 0); + if ((result = tlsShmServerEp->PostSend(1, rsp)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + NN_LOG_INFO("request rsp Mr info"); + return 0; +} + +static TestRegMrInfo getRemoteMrInfo; +static int RequestReceivedClient(const UBSHcomNetRequestContext &ctx) +{ + memcpy(&getRemoteMrInfo, ctx.Message()->Data(), ctx.Message()->DataLen()); + NN_LOG_INFO("get remote Mr info"); + sem_post(&sem); + return 0; +} + +static int RequestPosted(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request posted"); + return 0; +} + +static int OneSideDone(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("one side done"); + sem_post(&sem); + return 0; +} + +static void Erase(void *pass, int len) {} +static int Verify(void *x509, const char *path) +{ + return 0; +} + +static bool CertCallback(const std::string &name, std::string &value) +{ + value = shmCertPath + "/server/cert.pem"; + return true; +} + +static bool PrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = shmCertPath + "/server/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + + return true; +} + +static bool CACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + caPath = shmCertPath + "/CA/cacert.pem"; + cb = std::bind(&Verify, std::placeholders::_1, std::placeholders::_2); + return true; +} + +static bool RegSglMem(UBSHcomNetDriver *driver, UBSHcomNetTransSgeIov mrInfo[], + std::vector &mrs) +{ + for (int i = 0; i < 4; ++i) { + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(NN_NO16, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + mrInfo[i].lAddress = mr->GetAddress(); + mrInfo[i].lKey = mr->GetLKey(); + mrInfo[i].size = NN_NO8; + mrs.push_back(mr); + memset(reinterpret_cast(mrInfo[i].lAddress), 0, mrInfo[i].size); + } + return true; +} + +static void DestoryTlsMem(UBSHcomNetDriver *driver, std::vector &mrs) +{ + while (!mrs.empty()) { + driver->DestroyMemoryRegion(mrs.back()); + mrs.pop_back(); + } +} + +static bool RegReadWriteMem(UBSHcomNetDriver *driver, TestRegMrInfo mrInfo[], + std::vector &mrs) +{ + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(NN_NO1024, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + mrInfo[0].lAddress = mr->GetAddress(); + mrInfo[0].lKey = mr->GetLKey(); + mrInfo[0].size = NN_NO1024; + mrs.push_back(mr); + memset(reinterpret_cast(mrInfo[0].lAddress), 0, mrInfo[0].size); + + return true; +} + +static bool CreateServerDriver(UBSHcomNetDriver *&driver, int (*reqHandler)(const UBSHcomNetRequestContext &), + UBSHcomNetDriverOptions &tlsShmOptions) +{ + auto name = "server_tls_" + std::to_string(g_nameSeed++); + + driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::SHM, name, true); + + if (driver == nullptr) { + NN_LOG_ERROR("failed to create tlsShmSDriver already created"); + return false; + } + tlsShmOptions.oobType = ock::hcom::NET_OOB_UDS; + tlsShmOptions.mode = ock::hcom::NET_EVENT_POLLING; + + UBSHcomNetOobUDSListenerOptions listenOpt; + listenOpt.Name(UDSNAME); + listenOpt.perm = 0; + driver->AddOobUdsOptions(listenOpt); + + driver->RegisterNewEPHandler( + std::bind(&ServerNewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(reqHandler); + driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + + driver->RegisterTLSCertificationCallback(std::bind(&CertCallback, std::placeholders::_1, std::placeholders::_2)); + driver->RegisterTLSCaCallback(std::bind(&CACallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + driver->RegisterTLSPrivateKeyCallback(std::bind(&PrivateKeyCallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + + int result = 0; + if ((result = driver->Initialize(tlsShmOptions)) != 0) { + NN_LOG_ERROR("failed to initialize driver " << result); + return false; + } + NN_LOG_INFO("tlsShmSDriver initialized"); + + if ((result = driver->Start()) != 0) { + NN_LOG_ERROR("failed to start asyncServerDriver " << result); + return false; + } + NN_LOG_INFO("tlsShmSDriver started"); + return true; +} + +static bool CreateServerDriverSend(UBSHcomNetDriver *&driver, int (*reqHandler)(const UBSHcomNetRequestContext &), + UBSHcomNetDriverOptions &tlsShmOptions) +{ + auto name = "server_tls_" + std::to_string(g_nameSeed++); + + driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::SHM, name, true); + + if (driver == nullptr) { + NN_LOG_ERROR("failed to create tlsShmSDriver already created"); + return false; + } + tlsShmOptions.oobType = ock::hcom::NET_OOB_UDS; + tlsShmOptions.mode = ock::hcom::NET_EVENT_POLLING; + + UBSHcomNetOobUDSListenerOptions listenOpt; + listenOpt.Name(UDSNAME); + listenOpt.perm = 0; + driver->AddOobUdsOptions(listenOpt); + + driver->RegisterNewEPHandler( + std::bind(&ServerNewEndPointSend, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(reqHandler); + driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + + driver->RegisterTLSCertificationCallback(std::bind(&CertCallback, std::placeholders::_1, std::placeholders::_2)); + driver->RegisterTLSCaCallback(std::bind(&CACallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + driver->RegisterTLSPrivateKeyCallback(std::bind(&PrivateKeyCallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + + int result = 0; + if ((result = driver->Initialize(tlsShmOptions)) != 0) { + NN_LOG_ERROR("failed to initialize driver " << result); + return false; + } + NN_LOG_INFO("tlsShmSDriver initialized"); + + if ((result = driver->Start()) != 0) { + NN_LOG_ERROR("failed to start asyncServerDriver " << result); + return false; + } + NN_LOG_INFO("tlsShmSDriver started"); + return true; +} + +static bool CreateClientDriver(UBSHcomNetDriver *&driver, int (*reqHandler)(const UBSHcomNetRequestContext &), + UBSHcomNetDriverOptions &tlsShmOptions) +{ + auto name = "client_tls_" + std::to_string(g_nameSeed++); + + driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::SHM, name, false); + if (driver == nullptr) { + NN_LOG_ERROR("failed to create tlsShmCDriver already created"); + return false; + } + tlsShmOptions.oobType = ock::hcom::NET_OOB_UDS; + tlsShmOptions.mode = ock::hcom::NET_EVENT_POLLING; + + driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(reqHandler); + driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + + driver->RegisterTLSCertificationCallback(std::bind(&CertCallback, std::placeholders::_1, std::placeholders::_2)); + driver->RegisterTLSCaCallback(std::bind(&CACallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + driver->RegisterTLSPrivateKeyCallback(std::bind(&PrivateKeyCallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + + int result = 0; + if ((result = driver->Initialize(tlsShmOptions)) != 0) { + NN_LOG_ERROR("failed to initialize driver " << result); + return false; + } + NN_LOG_INFO("tlsShmCDriver initialized"); + + if ((result = driver->Start()) != 0) { + NN_LOG_ERROR("failed to start tlsShmCDriver " << result); + return false; + } + NN_LOG_INFO("tlsShmCDriver started"); + return true; +} + +static bool CreateSyncClientDriver(UBSHcomNetDriver *&driver, UBSHcomNetDriverOptions &options) +{ + auto name = "clientSync_ep_" + std::to_string(g_nameSeed++); + + driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::SHM, name, false); + if (driver == nullptr) { + NN_LOG_ERROR("failed to create clientDriver already created"); + return false; + } + + options.mode = ock::hcom::NET_EVENT_POLLING; + options.oobType = ock::hcom::NET_OOB_UDS; + options.dontStartWorkers = true; + + driver->RegisterTLSCertificationCallback(std::bind(&CertCallback, std::placeholders::_1, std::placeholders::_2)); + driver->RegisterTLSCaCallback(std::bind(&CACallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + driver->RegisterTLSPrivateKeyCallback(std::bind(&PrivateKeyCallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + + int result = 0; + if ((result = driver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize driver " << result); + return false; + } + NN_LOG_INFO("clientDriver initialized"); + + if ((result = driver->Start()) != 0) { + NN_LOG_ERROR("failed to start clientDriver " << result); + return false; + } + NN_LOG_INFO("clientDriver started"); + return true; +} + +void TlsCloseShmDriver(UBSHcomNetDriver *&tlsShmCDriver, UBSHcomNetDriver *&tlsShmSDriver) +{ + tlsShmClientEp->Close(); + tlsShmServerEp->Close(); + if (tlsShmServerEp != nullptr) { + tlsShmServerEp.Set(nullptr); + } + if (tlsShmClientEp != nullptr) { + tlsShmClientEp.Set(nullptr); + } + std::string serverName = tlsShmSDriver->Name(); + std::string clientName = tlsShmCDriver->Name(); + if (tlsShmCDriver->IsStarted()) { + tlsShmCDriver->Stop(); + } + if (tlsShmCDriver->IsInited()) { + tlsShmCDriver->UnInitialize(); + } + if (tlsShmSDriver->IsStarted()) { + tlsShmSDriver->Stop(); + } + if (tlsShmSDriver->IsInited()) { + tlsShmSDriver->UnInitialize(); + } + UBSHcomNetDriver::DestroyInstance(serverName); + UBSHcomNetDriver::DestroyInstance(clientName); +} + +void TestShmTls::SetUp() +{ + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); + ASSERT_EQ(0, ShmValidateTlsCert()); +} + +void TestShmTls::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestShmTls, PostSendTls) +{ + NResult result; + + tlsShmOptions.enableTls = true; + CreateServerDriverSend(tlsShmSDriver, RequestReceivedServer, tlsShmOptions); + CreateClientDriver(tlsShmCDriver, RequestReceivedClient, tlsShmOptions); + + tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp); + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + + result = tlsShmClientEp->PostSend(1, req); + EXPECT_EQ(SH_OK, result); + + MOCKER_CPP(&AesGcm128::Encrypt, bool (AesGcm128::*)(const unsigned char *, const unsigned char *, + const unsigned char *, size_t, unsigned char *, size_t &)) + .defaults() + .will(returnValue(false)); + + result = tlsShmClientEp->PostSend(1, req); + EXPECT_EQ(NN_ENCRYPT_FAILED, result); + + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} + +TEST_F(TestShmTls, PostSendTlsCipherSuite256) +{ + NResult result; + + tlsShmOptions.enableTls = true; + tlsShmOptions.cipherSuite = ock::hcom::AES_GCM_256; + CreateServerDriverSend(tlsShmSDriver, RequestReceivedServer, tlsShmOptions); + CreateClientDriver(tlsShmCDriver, RequestReceivedClient, tlsShmOptions); + + result = tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp); + EXPECT_EQ(SH_OK, result); + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + + result = tlsShmClientEp->PostSend(1, req); + EXPECT_EQ(SH_OK, result); + + MOCKER_CPP(&AesGcm128::Encrypt, bool (AesGcm128::*)(const unsigned char *, const unsigned char *, + const unsigned char *, size_t, unsigned char *, size_t &)) + .defaults() + .will(returnValue(false)); + + result = tlsShmClientEp->PostSend(1, req); + EXPECT_EQ(NN_ENCRYPT_FAILED, result); + + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} + +TEST_F(TestShmTls, PostSendTlsCipherSuiteUnknown) +{ + tlsShmOptions.enableTls = true; + tlsShmOptions.cipherSuite = ock::hcom::UBSHcomNetCipherSuite(4); + auto result = CreateServerDriverSend(tlsShmSDriver, RequestReceivedServer, tlsShmOptions); + EXPECT_EQ(false, result); + + result = CreateClientDriver(tlsShmCDriver, RequestReceivedClient, tlsShmOptions); + EXPECT_EQ(false, result); +} + +TEST_F(TestShmTls, PostSendOpInfoTls) +{ + NResult result; + + tlsShmOptions.enableTls = true; + tlsShmOptions.cipherSuite = ock::hcom::AES_GCM_128; + CreateServerDriverSend(tlsShmSDriver, RequestReceivedServer, tlsShmOptions); + CreateClientDriver(tlsShmCDriver, RequestReceivedClient, tlsShmOptions); + + tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp); + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + UBSHcomNetTransOpInfo innerOpInfo(2, 0, 0, NTH_TWO_SIDE); + result = tlsShmClientEp->PostSend(1, req, innerOpInfo); + EXPECT_EQ(SH_OK, result); + + MOCKER_CPP(&AesGcm128::Encrypt, bool (AesGcm128::*)(const unsigned char *, const unsigned char *, + const unsigned char *, size_t, unsigned char *, size_t &)) + .defaults() + .will(returnValue(false)); + + result = tlsShmClientEp->PostSend(1, req, innerOpInfo); + EXPECT_EQ(NN_ENCRYPT_FAILED, result); + + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} + +TEST_F(TestShmTls, PostSendRawTls) +{ + NResult result; + + tlsShmOptions.enableTls = true; + CreateServerDriverSend(tlsShmSDriver, RequestReceivedServer, tlsShmOptions); + CreateClientDriver(tlsShmCDriver, RequestReceivedClient, tlsShmOptions); + + tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp); + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + + result = tlsShmClientEp->PostSendRaw(req, 1); + EXPECT_EQ(SH_OK, result); + + MOCKER_CPP(&AesGcm128::Encrypt, bool (AesGcm128::*)(const unsigned char *, const unsigned char *, + const unsigned char *, size_t, unsigned char *, size_t &)) + .defaults() + .will(returnValue(false)); + + result = tlsShmClientEp->PostSendRaw(req, 1); + EXPECT_EQ(NN_ENCRYPT_FAILED, result); + + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} + +TEST_F(TestShmTls, PostSendRawSglTls) +{ + NResult result; + + tlsShmOptions.enableTls = true; + CreateServerDriverSend(tlsShmSDriver, RequestReceivedServer, tlsShmOptions); + CreateClientDriver(tlsShmCDriver, RequestReceivedClient, tlsShmOptions); + + tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp); + + std::vector mrs; + bool res = RegSglMem(tlsShmCDriver, tlsShmClientMrInfo, mrs); + EXPECT_TRUE(res); + UBSHcomNetTransSglRequest reqSgl(tlsShmClientMrInfo, iovCnt, 0); + result = tlsShmClientEp->PostSendRawSgl(reqSgl, 1); + EXPECT_EQ(SH_OK, result); + + MOCKER_CPP(&AesGcm128::Encrypt, bool (AesGcm128::*)(const unsigned char *, const unsigned char *, + const unsigned char *, size_t, unsigned char *, size_t &)) + .defaults() + .will(returnValue(false)); + + result = tlsShmClientEp->PostSendRawSgl(reqSgl, 1); + EXPECT_EQ(NN_ENCRYPT_FAILED, result); + + DestoryTlsMem(tlsShmCDriver, mrs); + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} + +TEST_F(TestShmTls, PostTlsReadWrite) +{ + NResult result; + tlsShmOptions.enableTls = true; + CreateServerDriver(tlsShmSDriver, RequestReceivedServer, tlsShmOptions); + CreateClientDriver(tlsShmCDriver, RequestReceivedClient, tlsShmOptions); + + bool res; + std::vector mrServer; + res = RegReadWriteMem(tlsShmSDriver, &tlsServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteMem(tlsShmCDriver, &tlsClientMrInfo, mrClient); + EXPECT_TRUE(res); + + tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp); + sem_init(&sem, 0, 0); + + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = tlsShmClientEp->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + + sem_wait(&sem); + + UBSHcomNetTransRequest req; + size_t encryptLen = SetClientEncryptValue(); + req.lAddress = tlsClientMrInfo.lAddress; + req.rAddress = getRemoteMrInfo.lAddress; + req.lKey = tlsClientMrInfo.lKey; + req.rKey = getRemoteMrInfo.lKey; + req.size = encryptLen; + + result = tlsShmClientEp->PostRead(req); + EXPECT_EQ(SH_OK, result); + sem_wait(&sem); + + void *readValue = reinterpret_cast(req.lAddress); + size_t rawLen = tlsShmClientEp->EstimatedDecryptLen(req.size); + void *rawValue = malloc(rawLen); + tlsShmClientEp->Decrypt(readValue, req.size, rawValue, rawLen); + NN_LOG_INFO("post read value is : " << rawValue); + NN_LOG_INFO("value[" << 0 << "]= " << readValue); + + NN_LOG_INFO("=========Read end ,Write start==========="); + SetClientEncryptValue(); + result = tlsShmClientEp->PostWrite(req); + EXPECT_EQ(SH_OK, result); + void *readServerValue = reinterpret_cast(req.rAddress); + size_t rawServerLen = tlsShmClientEp->EstimatedDecryptLen(req.size); + void *rawServerValue = malloc(rawLen); + tlsShmClientEp->Decrypt(readServerValue, req.size, rawServerValue, rawServerLen); + NN_LOG_INFO("post Write value is : " << rawServerValue); + NN_LOG_INFO("value[" << 0 << "]= " << rawServerValue); + + free(rawValue); + free(rawServerValue); + DestoryTlsMem(tlsShmSDriver, mrServer); + DestoryTlsMem(tlsShmCDriver, mrClient); + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} + +TEST_F(TestShmTls, PostTlsEncryptFail) +{ + NResult result; + tlsShmOptions.enableTls = true; + CreateServerDriver(tlsShmSDriver, RequestReceivedServer, tlsShmOptions); + CreateClientDriver(tlsShmCDriver, RequestReceivedClient, tlsShmOptions); + + bool res; + std::vector mrServer; + res = RegReadWriteMem(tlsShmSDriver, &tlsServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteMem(tlsShmCDriver, &tlsClientMrInfo, mrClient); + EXPECT_TRUE(res); + + tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp); + sem_init(&sem, 0, 0); + + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = tlsShmClientEp->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + + sem_wait(&sem); + + /* Set Client Encrypt Value ,when value.length() is 0 */ + std::string value; + size_t encryptLen = tlsShmClientEp->EstimatedEncryptLen(value.length()); + EXPECT_EQ(0, encryptLen); + value = "value from client"; + encryptLen = tlsShmClientEp->EstimatedEncryptLen(value.length()); + void *cipher = malloc(encryptLen); + + /* Set Client Encrypt Value ,AesGcm128::Encrypt is fail */ + MOCKER_CPP(&AesGcm128::Encrypt, bool (AesGcm128::*)(const unsigned char *, const unsigned char *, + const unsigned char *, size_t, unsigned char *, size_t &)) + .defaults() + .will(returnValue(false)); + result = tlsShmClientEp->Encrypt(value.c_str(), value.length(), cipher, encryptLen); + EXPECT_EQ(NN_ERROR, result); + + free(cipher); + DestoryTlsMem(tlsShmSDriver, mrServer); + DestoryTlsMem(tlsShmCDriver, mrClient); + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} + +TEST_F(TestShmTls, PostTlsEncryptFail1) +{ + NResult result; + tlsShmOptions.enableTls = false; + CreateServerDriver(tlsShmSDriver, RequestReceivedServer, tlsShmOptions); + CreateClientDriver(tlsShmCDriver, RequestReceivedClient, tlsShmOptions); + + bool res; + std::vector mrServer; + res = RegReadWriteMem(tlsShmSDriver, &tlsServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteMem(tlsShmCDriver, &tlsClientMrInfo, mrClient); + EXPECT_TRUE(res); + + tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp); + sem_init(&sem, 0, 0); + + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = tlsShmClientEp->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + + sem_wait(&sem); + + /* Encrypt ,when Options.enableTls = false */ + std::string value = "value from client"; + size_t encryptLen = tlsShmClientEp->EstimatedEncryptLen(value.length()); + EXPECT_EQ(0, encryptLen); + void *cipher = malloc(encryptLen); + result = tlsShmClientEp->Encrypt(value.c_str(), value.length(), cipher, encryptLen); + EXPECT_EQ(NN_ERROR, result); + + DestoryTlsMem(tlsShmSDriver, mrServer); + DestoryTlsMem(tlsShmCDriver, mrClient); + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} + +TEST_F(TestShmTls, PostTlsDecryptFail) +{ + NResult result; + tlsShmOptions.enableTls = true; + CreateServerDriver(tlsShmSDriver, RequestReceivedServer, tlsShmOptions); + CreateClientDriver(tlsShmCDriver, RequestReceivedClient, tlsShmOptions); + + bool res; + std::vector mrServer; + res = RegReadWriteMem(tlsShmSDriver, &tlsServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteMem(tlsShmCDriver, &tlsClientMrInfo, mrClient); + EXPECT_TRUE(res); + + tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp); + sem_init(&sem, 0, 0); + + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = tlsShmClientEp->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + + sem_wait(&sem); + + UBSHcomNetTransRequest req; + size_t encryptLen = SetClientEncryptValue(); + req.lAddress = tlsClientMrInfo.lAddress; + req.rAddress = getRemoteMrInfo.lAddress; + req.lKey = tlsClientMrInfo.lKey; + req.rKey = getRemoteMrInfo.lKey; + req.size = encryptLen; + + void *readValue = reinterpret_cast(req.lAddress); + size_t rawLen = tlsShmClientEp->EstimatedDecryptLen(req.size); + void *rawValue = malloc(rawLen); + + /* Set Decrypt ,AesGcm128::Decrypt is fail */ + MOCKER_CPP(&AesGcm128::Decrypt, + bool (AesGcm128::*)(const unsigned char *, const unsigned char *, size_t, unsigned char *, size_t &)) + .defaults() + .will(returnValue(false)); + result = tlsShmClientEp->Decrypt(readValue, req.size, rawValue, rawLen); + EXPECT_EQ(NN_ERROR, result); + + free(rawValue); + DestoryTlsMem(tlsShmSDriver, mrServer); + DestoryTlsMem(tlsShmCDriver, mrClient); + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} + +TEST_F(TestShmTls, PostTlsDecryptFail1) +{ + NResult result; + tlsShmOptions.enableTls = false; + CreateServerDriver(tlsShmSDriver, RequestReceivedServer, tlsShmOptions); + CreateClientDriver(tlsShmCDriver, RequestReceivedClient, tlsShmOptions); + + bool res; + std::vector mrServer; + res = RegReadWriteMem(tlsShmSDriver, &tlsServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteMem(tlsShmCDriver, &tlsClientMrInfo, mrClient); + EXPECT_TRUE(res); + + tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp); + sem_init(&sem, 0, 0); + + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = tlsShmClientEp->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + + sem_wait(&sem); + + /* Decrypt ,when Options.enableTls = false */ + void *readValue = reinterpret_cast(tlsClientMrInfo.lAddress); + size_t rawLen = tlsShmClientEp->EstimatedDecryptLen(tlsClientMrInfo.size); + EXPECT_EQ(0, rawLen); + void *rawValue = malloc(rawLen); + result = tlsShmClientEp->Decrypt(readValue, tlsClientMrInfo.lAddress, rawValue, rawLen); + EXPECT_EQ(NN_ERROR, result); + + DestoryTlsMem(tlsShmSDriver, mrServer); + DestoryTlsMem(tlsShmCDriver, mrClient); + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} + +TEST_F(TestShmTls, SyncPostTlsReadWrite) +{ + NResult result; + UBSHcomNetDriverOptions tlsSyncShmOptions {}; + tlsSyncShmOptions.enableTls = true; + CreateServerDriver(tlsShmSDriver, RequestReceivedServer, tlsSyncShmOptions); + CreateSyncClientDriver(tlsShmCDriver, tlsSyncShmOptions); + + bool res; + std::vector mrServer; + res = RegReadWriteMem(tlsShmSDriver, &tlsServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteMem(tlsShmCDriver, &tlsClientMrInfo, mrClient); + EXPECT_TRUE(res); + + tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp, NET_EP_SELF_POLLING); + + /* exchange mr info */ + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = tlsShmClientEp->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + result = tlsShmClientEp->WaitCompletion(NN_NO2); + EXPECT_EQ(SH_OK, result); + UBSHcomNetResponseContext respCtx {}; + result = tlsShmClientEp->Receive(NN_NO2, respCtx); + EXPECT_EQ(SH_OK, result); + memcpy(&getRemoteMrInfo, respCtx.Message()->Data(), respCtx.Message()->DataLen()); + + + UBSHcomNetTransRequest req; + size_t encryptLen = SetClientEncryptValue(); + req.lAddress = tlsClientMrInfo.lAddress; + req.rAddress = getRemoteMrInfo.lAddress; + req.lKey = tlsClientMrInfo.lKey; + req.rKey = getRemoteMrInfo.lKey; + req.size = encryptLen; + + result = tlsShmClientEp->PostRead(req); + EXPECT_EQ(SH_OK, result); + + + void *readValue = reinterpret_cast(req.lAddress); + size_t rawLen = tlsShmClientEp->EstimatedDecryptLen(req.size); + void *rawValue = malloc(rawLen); + tlsShmClientEp->Decrypt(readValue, req.size, rawValue, rawLen); + NN_LOG_INFO("post read value is : " << rawValue); + NN_LOG_INFO("value[" << 0 << "]= " << readValue); + + NN_LOG_INFO("=========Read end ,Write start==========="); + SetClientEncryptValue(); + result = tlsShmClientEp->PostWrite(req); + EXPECT_EQ(SH_OK, result); + void *readServerValue = reinterpret_cast(req.rAddress); + size_t rawServerLen = tlsShmClientEp->EstimatedDecryptLen(req.size); + void *rawServerValue = malloc(rawLen); + tlsShmClientEp->Decrypt(readServerValue, req.size, rawServerValue, rawServerLen); + NN_LOG_INFO("post Write value is : " << rawServerValue); + NN_LOG_INFO("value[" << 0 << "]= " << rawServerValue); + + free(rawValue); + free(rawServerValue); + DestoryTlsMem(tlsShmSDriver, mrServer); + DestoryTlsMem(tlsShmCDriver, mrClient); + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} + +TEST_F(TestShmTls, SyncPostTlsEncryptFail) +{ + NResult result; + UBSHcomNetDriverOptions tlsSyncShmOptions {}; + tlsSyncShmOptions.enableTls = true; + CreateServerDriver(tlsShmSDriver, RequestReceivedServer, tlsSyncShmOptions); + CreateSyncClientDriver(tlsShmCDriver, tlsSyncShmOptions); + + bool res; + std::vector mrServer; + res = RegReadWriteMem(tlsShmSDriver, &tlsServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteMem(tlsShmCDriver, &tlsClientMrInfo, mrClient); + EXPECT_TRUE(res); + + tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp, NET_EP_SELF_POLLING); + + /* Set Client Encrypt Value ,when value.length() is 0 */ + std::string value; + size_t encryptLen = tlsShmClientEp->EstimatedEncryptLen(value.length()); + EXPECT_EQ(0, encryptLen); + + /* Set Client Encrypt Value ,AesGcm128::Encrypt is fail */ + value = "value from client"; + encryptLen = tlsShmClientEp->EstimatedEncryptLen(value.length()); + void *cipher = malloc(encryptLen); + MOCKER_CPP(&AesGcm128::Encrypt, bool (AesGcm128::*)(const unsigned char *, const unsigned char *, + const unsigned char *, size_t, unsigned char *, size_t &)) + .defaults() + .will(returnValue(false)); + result = tlsShmClientEp->Encrypt(value.c_str(), value.length(), cipher, encryptLen); + EXPECT_EQ(NN_ERROR, result); + + free(cipher); + DestoryTlsMem(tlsShmSDriver, mrServer); + DestoryTlsMem(tlsShmCDriver, mrClient); + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} + +TEST_F(TestShmTls, SyncPostTlsEncryptFail1) +{ + NResult result; + UBSHcomNetDriverOptions tlsSyncShmOptions {}; + tlsSyncShmOptions.enableTls = false; + CreateServerDriver(tlsShmSDriver, RequestReceivedServer, tlsSyncShmOptions); + CreateSyncClientDriver(tlsShmCDriver, tlsSyncShmOptions); + + bool res; + std::vector mrServer; + res = RegReadWriteMem(tlsShmSDriver, &tlsServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteMem(tlsShmCDriver, &tlsClientMrInfo, mrClient); + EXPECT_TRUE(res); + + tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp, NET_EP_SELF_POLLING); + + /* Encrypt ,when Options.enableTls = false */ + std::string value = "value from client"; + size_t encryptLen = tlsShmClientEp->EstimatedEncryptLen(value.length()); + EXPECT_EQ(0, encryptLen); + void *cipher = malloc(encryptLen); + result = tlsShmClientEp->Encrypt(value.c_str(), value.length(), cipher, encryptLen); + EXPECT_EQ(NN_ERROR, result); + + DestoryTlsMem(tlsShmSDriver, mrServer); + DestoryTlsMem(tlsShmCDriver, mrClient); + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} + +TEST_F(TestShmTls, SyncPostTlsDecryptFail) +{ + NResult result; + UBSHcomNetDriverOptions tlsSyncShmOptions {}; + tlsSyncShmOptions.enableTls = true; + CreateServerDriver(tlsShmSDriver, RequestReceivedServer, tlsSyncShmOptions); + CreateSyncClientDriver(tlsShmCDriver, tlsSyncShmOptions); + + bool res; + std::vector mrServer; + res = RegReadWriteMem(tlsShmSDriver, &tlsServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteMem(tlsShmCDriver, &tlsClientMrInfo, mrClient); + EXPECT_TRUE(res); + + tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp, NET_EP_SELF_POLLING); + /* exchange mr info */ + std::string msg = "Transfer MrInfo of the client to the server."; + UBSHcomNetTransRequest msgReq((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = tlsShmClientEp->PostSend(1, msgReq); + EXPECT_EQ(SH_OK, result); + result = tlsShmClientEp->WaitCompletion(NN_NO2); + EXPECT_EQ(SH_OK, result); + UBSHcomNetResponseContext respCtx {}; + result = tlsShmClientEp->Receive(NN_NO2, respCtx); + EXPECT_EQ(SH_OK, result); + memcpy(&getRemoteMrInfo, respCtx.Message()->Data(), respCtx.Message()->DataLen()); + + + UBSHcomNetTransRequest req; + size_t encryptLen = SetClientEncryptValue(); + req.lAddress = tlsClientMrInfo.lAddress; + req.rAddress = getRemoteMrInfo.lAddress; + req.lKey = tlsClientMrInfo.lKey; + req.rKey = getRemoteMrInfo.lKey; + req.size = encryptLen; + + result = tlsShmClientEp->PostRead(req); + EXPECT_EQ(SH_OK, result); + + void *readValue = reinterpret_cast(req.lAddress); + size_t rawLen = tlsShmClientEp->EstimatedDecryptLen(req.size); + void *rawValue = malloc(rawLen); + + /* Set Decrypt ,AesGcm128::Decrypt is fail */ + MOCKER_CPP(&AesGcm128::Decrypt, + bool (AesGcm128::*)(const unsigned char *, const unsigned char *, size_t, unsigned char *, size_t &)) + .defaults() + .will(returnValue(false)); + result = tlsShmClientEp->Decrypt(readValue, req.size, rawValue, rawLen); + EXPECT_EQ(NN_ERROR, result); + + free(rawValue); + DestoryTlsMem(tlsShmSDriver, mrServer); + DestoryTlsMem(tlsShmCDriver, mrClient); + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} + +TEST_F(TestShmTls, SyncPostTlsDecryptFail1) +{ + NResult result; + UBSHcomNetDriverOptions tlsSyncShmOptions {}; + tlsSyncShmOptions.enableTls = false; + CreateServerDriver(tlsShmSDriver, RequestReceivedServer, tlsSyncShmOptions); + CreateSyncClientDriver(tlsShmCDriver, tlsSyncShmOptions); + + bool res; + std::vector mrServer; + res = RegReadWriteMem(tlsShmSDriver, &tlsServerMrInfo, mrServer); + EXPECT_TRUE(res); + std::vector mrClient; + res = RegReadWriteMem(tlsShmCDriver, &tlsClientMrInfo, mrClient); + EXPECT_TRUE(res); + + tlsShmCDriver->Connect(UDSNAME, 0, "hello server", tlsShmClientEp, NET_EP_SELF_POLLING); + + /* Decrypt ,when Options.enableTls = false */ + void *readValue = reinterpret_cast(tlsClientMrInfo.lAddress); + size_t rawLen = tlsShmClientEp->EstimatedDecryptLen(tlsClientMrInfo.size); + EXPECT_EQ(0, rawLen); + void *rawValue = malloc(rawLen); + result = tlsShmClientEp->Decrypt(readValue, tlsClientMrInfo.lAddress, rawValue, rawLen); + EXPECT_EQ(NN_ERROR, result); + + DestoryTlsMem(tlsShmSDriver, mrServer); + DestoryTlsMem(tlsShmCDriver, mrClient); + TlsCloseShmDriver(tlsShmCDriver, tlsShmSDriver); +} \ No newline at end of file diff --git a/test/llt/testcase/transport/shm/test_shm_tls.h b/test/llt/testcase/transport/shm/test_shm_tls.h new file mode 100644 index 0000000000000000000000000000000000000000..dca73f279409e6a70cd205ab52be5ebd677fb4db --- /dev/null +++ b/test/llt/testcase/transport/shm/test_shm_tls.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_TEST_SHM_TLS_H +#define HCOM_TEST_SHM_TLS_H +#include +#include +class TestShmTls : public testing::Test { +public: + TestShmTls(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TEST_SHM_TLS_H diff --git a/test/llt/testcase/transport/sock/test_negative_tcp_driver.cpp b/test/llt/testcase/transport/sock/test_negative_tcp_driver.cpp new file mode 100644 index 0000000000000000000000000000000000000000..58ef0b6bad0f53a40e828b20bfa713f7415a6cc3 --- /dev/null +++ b/test/llt/testcase/transport/sock/test_negative_tcp_driver.cpp @@ -0,0 +1,192 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "test_negative_tcp_driver.h" +#include "mockcpp/mockcpp.hpp" +#include "ut_helper.h" + +using namespace ock::hcom; + +TestNegativeTcpDriver::TestNegativeTcpDriver() {} + +void TestNegativeTcpDriver::SetUp() +{ + MOCK_VERSION + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestNegativeTcpDriver::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestNegativeTcpDriver, Overload) +{ + NResult result; + UBSHcomNetDriver *server; + UBSHcomNetDriverDeviceInfo deviceInfo; + bool ret = server->LocalSupport(ock::hcom::TCP, deviceInfo); + ASSERT_EQ(true, ret); + for (int i = 0; i < 10; ++i) { + result = UTHelper::GetDriver(server, DRIVER_STATE_START, true, UBSHcomNetDriverProtocol::TCP); + std::cout << "create driver " << i << std::endl; + UT_CHECK_RESULT_OK(result) + server->Stop(); + server->UnInitialize(); + } +} + +TEST_F(TestNegativeTcpDriver, UseBeforeInit) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriver *driver, *server; + result = UTHelper::GetDriver(server, DRIVER_STATE_START, true, UBSHcomNetDriverProtocol::TCP); + UT_CHECK_RESULT_OK(result) + UT_CHECK_RESULT_NOT_NULL(server) + result = UTHelper::GetDriver(driver, DRIVER_STATE_NONE, false, UBSHcomNetDriverProtocol::TCP); + UT_CHECK_RESULT_OK(result) + UT_CHECK_RESULT_NOT_NULL(driver) + + UBSHcomNetMemoryRegionPtr mr; + result = driver->CreateMemoryRegion(NN_NO1024, mr); + UT_CHECK_RESULT_NOK(result) + result = driver->Start(); + UT_CHECK_RESULT_NOK(result) + result = driver->Connect("halo", ep); + UT_CHECK_RESULT_NOK(result) + result = UTHelper::ForwardDriverStateMask(driver, DRIVER_STATE_INIT); + UT_CHECK_RESULT_OK(result) + result = driver->Connect("halo", ep); + UT_CHECK_RESULT_NOK(result) + result = UTHelper::ForwardDriverStateMask(driver, DRIVER_STATE_START); + UT_CHECK_RESULT_OK(result) + result = driver->Connect("halo", ep); + UT_CHECK_RESULT_OK(result) + result = UTHelper::ForwardDriverStateMask(driver, DRIVER_STATE_STOP | DRIVER_STATE_UNINIT); + UT_CHECK_RESULT_OK(result) + result = UTHelper::ForwardDriverStateMask(server, DRIVER_STATE_STOP | DRIVER_STATE_UNINIT); + UT_CHECK_RESULT_OK(result) + std::string name1 = server->Name(); + std::string name2 = driver->Name(); + UBSHcomNetDriver::DestroyInstance(name1); + UBSHcomNetDriver::DestroyInstance(name2); +} + +TEST_F(TestNegativeTcpDriver, DestroyUnownedMr) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriver *driver, *driver1; + + result = UTHelper::GetDriver(driver, DRIVER_STATE_INIT, false, UBSHcomNetDriverProtocol::TCP); + UT_CHECK_RESULT_OK(result) + + result = UTHelper::GetDriver(driver1, DRIVER_STATE_INIT, false, UBSHcomNetDriverProtocol::TCP); + UT_CHECK_RESULT_OK(result) + + UBSHcomNetMemoryRegionPtr mr1; + result = driver1->CreateMemoryRegion(NN_NO1024, mr1); + UT_CHECK_RESULT_OK(result) + + driver->DestroyMemoryRegion(mr1); + ASSERT_NE(mr1.Get()->GetAddress(), 0); + + driver->UnInitialize(); + driver1->UnInitialize(); + std::string name1 = driver->Name(); + std::string name2 = driver1->Name(); + UBSHcomNetDriver::DestroyInstance(name1); + UBSHcomNetDriver::DestroyInstance(name2); +} + +TEST_F(TestNegativeTcpDriver, UseAfterStop) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriver *driver, *server; + result = UTHelper::GetDriver(server, DRIVER_STATE_STOP, true, UBSHcomNetDriverProtocol::TCP); + UT_CHECK_RESULT_OK(result) + result = UTHelper::GetDriver(driver, DRIVER_STATE_STOP, false, UBSHcomNetDriverProtocol::TCP); + UT_CHECK_RESULT_OK(result) + UT_CHECK_RESULT_FALSE(driver->IsStarted()) + result = driver->Start(); + UT_CHECK_RESULT_OK(result) + UT_CHECK_RESULT_TRUE(driver->IsStarted()) + result = UTHelper::ForwardDriverStateMask(driver, DRIVER_STATE_STOP | DRIVER_STATE_UNINIT); + UT_CHECK_RESULT_OK(result) + result = UTHelper::ForwardDriverStateMask(server, DRIVER_STATE_STOP | DRIVER_STATE_UNINIT); + UT_CHECK_RESULT_OK(result) + std::string name1 = server->Name(); + std::string name2 = driver->Name(); + UBSHcomNetDriver::DestroyInstance(name1); + UBSHcomNetDriver::DestroyInstance(name2); +} + +TEST_F(TestNegativeTcpDriver, DiscontinuousState) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriver *driver; + result = + UTHelper::GetDriverStateMask(driver, DRIVER_STATE_INIT | DRIVER_STATE_START, false, + UBSHcomNetDriverProtocol::TCP); + UT_CHECK_RESULT_OK(result) + UT_CHECK_RESULT_TRUE(driver->IsStarted()) + UT_CHECK_RESULT_TRUE(driver->IsInited()) + driver->Stop(); + UT_CHECK_RESULT_FALSE(driver->IsStarted()) + UT_CHECK_RESULT_TRUE(driver->IsInited()) + std::string name1 = driver->Name(); + UBSHcomNetDriver::DestroyInstance(name1); +} + +TEST_F(TestNegativeTcpDriver, UseAfterUninit) +{ + NResult result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriver *driver, *server; + result = UTHelper::GetDriver(server, DRIVER_STATE_START, true, UBSHcomNetDriverProtocol::TCP); + UT_CHECK_RESULT_OK(result) + result = UTHelper::GetDriver(driver, DRIVER_STATE_UNINIT, false, UBSHcomNetDriverProtocol::TCP); + UT_CHECK_RESULT_OK(result) + UT_CHECK_RESULT_FALSE(driver->IsStarted()) + UT_CHECK_RESULT_FALSE(driver->IsInited()) + + UBSHcomNetMemoryRegionPtr mr; + result = driver->CreateMemoryRegion(NN_NO1024, mr); + UT_CHECK_RESULT_NOK(result) + result = driver->Start(); + UT_CHECK_RESULT_NOK(result) + result = driver->Connect("halo", ep); + UT_CHECK_RESULT_NOK(result) + + result = UTHelper::ForwardDriverStateMask(driver, DRIVER_STATE_INIT); + UT_CHECK_RESULT_OK(result) + result = driver->Connect("halo", ep); + UT_CHECK_RESULT_NOK(result) + result = UTHelper::ForwardDriverStateMask(driver, DRIVER_STATE_START); + UT_CHECK_RESULT_OK(result) + result = driver->Connect("halo", ep); + UT_CHECK_RESULT_OK(result) + result = UTHelper::ForwardDriverStateMask(driver, DRIVER_STATE_STOP | DRIVER_STATE_UNINIT); + UT_CHECK_RESULT_OK(result) + result = UTHelper::ForwardDriverStateMask(server, DRIVER_STATE_STOP | DRIVER_STATE_UNINIT); + UT_CHECK_RESULT_OK(result) + result = UTHelper::ForwardDriverStateMask(driver, DRIVER_STATE_STOP | DRIVER_STATE_UNINIT); + UT_CHECK_RESULT_OK(result) + result = UTHelper::ForwardDriverStateMask(server, DRIVER_STATE_STOP | DRIVER_STATE_UNINIT); + UT_CHECK_RESULT_OK(result) + std::string name1 = server->Name(); + std::string name2 = driver->Name(); + UBSHcomNetDriver::DestroyInstance(name1); + UBSHcomNetDriver::DestroyInstance(name2); +} diff --git a/test/llt/testcase/transport/sock/test_negative_tcp_driver.h b/test/llt/testcase/transport/sock/test_negative_tcp_driver.h new file mode 100644 index 0000000000000000000000000000000000000000..fc0e6e692e8c64c720109a6357fce4a934767e6b --- /dev/null +++ b/test/llt/testcase/transport/sock/test_negative_tcp_driver.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_LLT_TEST_NEGATIVE_CASES_H +#define HCOM_LLT_TEST_NEGATIVE_CASES_H +#include + +class TestNegativeTcpDriver : public testing::Test { +public: + TestNegativeTcpDriver(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_LLT_TEST_NEGATIVE_CASES_H diff --git a/test/llt/testcase/transport/sock/test_net_sock_driver_oob.cpp b/test/llt/testcase/transport/sock/test_net_sock_driver_oob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..455a673f97b4331891ca396d247fbe8111dab489 --- /dev/null +++ b/test/llt/testcase/transport/sock/test_net_sock_driver_oob.cpp @@ -0,0 +1,694 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include + +#include "hcom.h" +#include "mockcpp/mockcpp.hpp" +#include "sock_common.h" +#include "test_net_sock_driver_oob.h" +#include "ut_helper.h" + +using namespace ock::hcom; + +TestNetSockDriverOob::TestNetSockDriverOob() {} + +UBSHcomNetEndpointPtr ep = nullptr; +UBSHcomNetDriverOptions sockOptions {}; +static int g_nameSeed = 159753; +static int port = 9031; +UBSHcomNetDriver *sockServerDriver; +UBSHcomNetDriver *sockClientDriver; +UBSHcomNetTransSgeIov iovPtrServer[4]; +UBSHcomNetTransSgeIov iovPtrClient[4]; +std::string certificatePath; + +int sockOobNewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload); + ep = newEP; + return 0; +} + +void sockOobEndPointBroken(const UBSHcomNetEndpointPtr &brokenEp) +{ + NN_LOG_INFO("end point " << brokenEp->Id()); + if (ep != nullptr) { + ep.Set(nullptr); + } +} + +int sockOobRequestReceived(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request received - " << ctx.Header().opCode << ", dataLen " << ctx.Header().dataLength); + return 0; +} + +int sockOobRequestPosted(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request posted"); + return 0; +} + +int sockOobRequestPostedFail(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request posted fail"); + return -1; +} + +int sockOobOneSideDone(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("one side done"); + return 0; +} + +int sockOobOneSideDoneFail(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("one side done fail"); + return -1; +} + +void Idle(const UBSHcomNetWorkerIndex &workerIndex) {} + +static void Erase(void *pass, int len) {} + +static int Verify(void *x509, const char *path) +{ + return 0; +} + +int SockOobValidateTlsCert() +{ + char *buffer; + if ((buffer = getcwd(NULL, 0)) == NULL) { + NN_LOG_ERROR("Cet path for TLS cert failed"); + return -1; + } + + std::string currentPath = buffer; + certificatePath = currentPath + "/../test/opensslcrt/normalCert1"; + + if (!CanonicalPath(certificatePath)) { + NN_LOG_ERROR("TLS cert path check failed " << certificatePath); + return -1; + } + + return 0; +} + +static bool CertCallback(const std::string &name, std::string &value) +{ + value = certificatePath + "/server/cert.pem"; + return true; +} + +static bool PrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = certificatePath + "/server/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + + return true; +} + +static bool CACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + caPath = certificatePath + "/CA/cacert.pem"; + cb = std::bind(&Verify, std::placeholders::_1, std::placeholders::_2); + return true; +} + +void SetCB(UBSHcomNetDriver *driver) +{ + driver->RegisterNewEPHandler( + std::bind(&sockOobNewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + driver->RegisterEPBrokenHandler(std::bind(&sockOobEndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(std::bind(&sockOobRequestReceived, std::placeholders::_1)); + driver->RegisterReqPostedHandler(std::bind(&sockOobRequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&sockOobOneSideDone, std::placeholders::_1)); + + driver->RegisterTLSCertificationCallback(std::bind(&CertCallback, std::placeholders::_1, std::placeholders::_2)); + driver->RegisterTLSCaCallback(std::bind(&CACallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + driver->RegisterTLSPrivateKeyCallback(std::bind(&PrivateKeyCallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); +} + +bool RegisterMemory(UBSHcomNetDriver *driver, UBSHcomNetTransSgeIov iovs[]) +{ + for (int i = 0; i < 4; i++) { + auto &iov = iovs[i]; + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(NN_NO8, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + iov.lAddress = mr->GetAddress(); + iov.lKey = mr->GetLKey(); + iov.size = NN_NO8; + memset(reinterpret_cast(iov.lAddress), 0, iov.size); + } + return true; +} + +bool RegisterMemoryWithAddress(UBSHcomNetDriver *driver, UBSHcomNetTransSgeIov iovs[]) +{ + for (int i = 0; i < 16; i++) { + auto &iov = iovs[i]; + UBSHcomNetMemoryRegionPtr mr; + auto tmpBuf = memalign(1024, NN_NO8); + auto result = driver->CreateMemoryRegion(reinterpret_cast(tmpBuf), NN_NO8, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + iov.lAddress = mr->GetAddress(); + iov.lKey = mr->GetLKey(); + iov.size = NN_NO8; + memset(reinterpret_cast(iov.lAddress), 0, iov.size); + } + return true; +} + +void TestNetSockDriverOob::SetUp() +{ + MOCK_VERSION + SockOobValidateTlsCert(); + sockOptions.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + sockOptions.SetNetDeviceIpMask(IP_SEG); + sockOptions.pollingBatchSize = 16; + sockOptions.SetWorkerGroups("1"); + sockOptions.SetWorkerGroupsCpuSet("1-1"); + sockOptions.enableTls = false; + sockOptions.dontStartWorkers = false; + sockOptions.magic = NN_NO256; + sockOptions.oobType = ock::hcom::NET_OOB_TCP; + sockServerDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, std::to_string(g_nameSeed++), true); + sockClientDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, std::to_string(g_nameSeed++), false); + sockServerDriver->OobIpAndPort(BASE_IP, port); + sockClientDriver->OobIpAndPort(BASE_IP, port++); + SetCB(sockServerDriver); + SetCB(sockClientDriver); + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestNetSockDriverOob::TearDown() +{ + std::string clientName = sockClientDriver->Name(); + std::string serverName = sockServerDriver->Name(); + if (sockServerDriver->IsStarted()) { + sockServerDriver->Stop(); + } + if (sockServerDriver->IsInited()) { + sockServerDriver->UnInitialize(); + } + if (sockClientDriver->IsStarted()) { + sockClientDriver->Stop(); + } + if (sockClientDriver->IsInited()) { + sockClientDriver->UnInitialize(); + } + UBSHcomNetDriver::DestroyInstance(clientName); + UBSHcomNetDriver::DestroyInstance(serverName); + GlobalMockObject::verify(); +} + +TEST_F(TestNetSockDriverOob, InitSuccess) +{ + sockOptions.workerThreadPriority = -21; + NResult result = sockServerDriver->Initialize(sockOptions); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); + sockOptions.workerThreadPriority = 21; + result = sockServerDriver->Initialize(sockOptions); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); + sockOptions.workerThreadPriority = 1; + result = sockServerDriver->Initialize(sockOptions); + EXPECT_EQ(NNCode::NN_OK, result); + sockServerDriver->UnInitialize(); +} + +TEST_F(TestNetSockDriverOob, InitSuccessTwice) +{ + sockServerDriver->Initialize(sockOptions); + NResult result = sockServerDriver->Initialize(sockOptions); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, InitFailWithEmptyIp) +{ + sockOptions.SetNetDeviceIpMask(""); + NResult result = sockServerDriver->Initialize(sockOptions); + EXPECT_EQ(NNCode::NN_INVALID_IP, result); +} + +TEST_F(TestNetSockDriverOob, InitSuccessWithBusyPolling) +{ + sockOptions.mode = UBSHcomNetDriverWorkingMode::NET_BUSY_POLLING; + NResult result = sockServerDriver->Initialize(sockOptions); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, InitSuccessWithTLS) +{ + sockOptions.enableTls = true; + NResult result = sockServerDriver->Initialize(sockOptions); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, InitSuccessWithoutSetWorkGroup) +{ + sockOptions.SetWorkerGroups(""); + NResult result = sockServerDriver->Initialize(sockOptions); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, InitFailWithWorkGroupHasZeroWorker) +{ + sockOptions.SetWorkerGroups("0"); + NResult result = sockServerDriver->Initialize(sockOptions); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestNetSockDriverOob, InitFailWithoutSetListeningIpAndPort) +{ + sockOptions.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + sockOptions.SetNetDeviceIpMask(IP_SEG); + sockOptions.pollingBatchSize = 16; + sockOptions.SetWorkerGroups("1"); + sockOptions.SetWorkerGroupsCpuSet("1-1"); + sockOptions.enableTls = false; + sockOptions.dontStartWorkers = false; + sockOptions.magic = NN_NO256; + UBSHcomNetDriver *netDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, + std::to_string(g_nameSeed++), true); + SetCB(netDriver); + NResult result = netDriver->Initialize(sockOptions); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestNetSockDriverOob, InitFailWithFailToInitWorker) +{ + MOCKER(epoll_create).defaults().will(returnValue(-1)); + NResult result = sockServerDriver->Initialize(sockOptions); + EXPECT_EQ(NNCode::NN_NEW_OBJECT_FAILED, result); +} + +TEST_F(TestNetSockDriverOob, InitFailWithTLSCipherSuiteUnknown) +{ + sockOptions.enableTls = true; + sockOptions.cipherSuite = ock::hcom::UBSHcomNetCipherSuite(4); + auto result = sockServerDriver->Initialize(sockOptions); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); + result = sockClientDriver->Initialize(sockOptions); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); + sockOptions.cipherSuite = ock::hcom::AES_GCM_128; +} + +TEST_F(TestNetSockDriverOob, StartSuccessWithIdleHandler) +{ + sockServerDriver->RegisterIdleHandler(std::bind(&Idle, std::placeholders::_1)); + sockServerDriver->Initialize(sockOptions); + NResult result = sockServerDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, StartFailWithStartOobServerFail) +{ + sockServerDriver->Initialize(sockOptions); + MOCKER(::socket).defaults().will(returnValue(-1)); + NResult result = sockServerDriver->Start(); + EXPECT_EQ(NNCode::NN_OOB_LISTEN_SOCKET_ERROR, result); +} + +TEST_F(TestNetSockDriverOob, StartSuccessWithDontStartWorker) +{ + sockOptions.dontStartWorkers = true; + sockServerDriver->Initialize(sockOptions); + NResult result = sockServerDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, StartFailWithoutInit) +{ + NResult result = sockServerDriver->Start(); + EXPECT_EQ(NNCode::NN_ERROR, result); +} + +TEST_F(TestNetSockDriverOob, StartSuccessTwice) +{ + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + NResult result = sockServerDriver->Start(); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, CreateMemoryRegionSuccess) +{ + sockServerDriver->Initialize(sockOptions); + UBSHcomNetMemoryRegionPtr mr; + NResult result = sockServerDriver->CreateMemoryRegion(16, mr); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, CreateMemoryRegionFailWithoutInit) +{ + UBSHcomNetMemoryRegionPtr mr; + NResult result = sockServerDriver->CreateMemoryRegion(16, mr); + EXPECT_EQ(NNCode::NN_NOT_INITIALIZED, result); +} + +TEST_F(TestNetSockDriverOob, CreateMemoryRegionFail) +{ + UBSHcomNetMemoryRegionPtr mr; + NResult result = sockServerDriver->CreateMemoryRegion(0, mr); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestNetSockDriverOob, CreateMemoryRegionWithAddressSuccess) +{ + sockServerDriver->Initialize(sockOptions); + auto tmpBuf = memalign(NN_NO4096, 10); + UBSHcomNetMemoryRegionPtr mr; + NResult result = sockServerDriver->CreateMemoryRegion(reinterpret_cast(tmpBuf), 16, mr); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, CreateMemoryRegionWithAddressFailWithoutInit) +{ + auto tmpBuf = memalign(NN_NO4096, 10); + UBSHcomNetMemoryRegionPtr mr; + NResult result = sockServerDriver->CreateMemoryRegion(reinterpret_cast(tmpBuf), 16, mr); + EXPECT_EQ(NNCode::NN_NOT_INITIALIZED, result); +} + +TEST_F(TestNetSockDriverOob, CreateMemoryRegionWithAddressFailWithAddressIsZero) +{ + sockServerDriver->Initialize(sockOptions); + UBSHcomNetMemoryRegionPtr mr; + NResult result = sockServerDriver->CreateMemoryRegion(0, 16, mr); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestNetSockDriverOob, ConnectSuccess) +{ + sockOptions.tcpUserTimeout = 0; + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + NResult result = sockClientDriver->Connect("hello world", ep, 0); + EXPECT_EQ(NNCode::NN_OK, result); +} + + +TEST_F(TestNetSockDriverOob, ConnectUdsSuccess) +{ + const char *testFile = "hcom-server1"; + std::ofstream file(testFile); + file.close(); + sockOptions.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + sockOptions.SetNetDeviceIpMask(IP_SEG); + sockOptions.pollingBatchSize = 16; + sockOptions.SetWorkerGroups("1"); + sockOptions.SetWorkerGroupsCpuSet("1-1"); + sockOptions.enableTls = false; + sockOptions.dontStartWorkers = false; + sockOptions.magic = NN_NO256; + sockOptions.oobType = ock::hcom::NET_OOB_UDS; + sockServerDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::UDS, std::to_string(g_nameSeed++), true); + sockClientDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::UDS, std::to_string(g_nameSeed++), false); + UBSHcomNetOobUDSListenerOptions opt {}; + opt.Name("hcom-server1"); + opt.perm = 0; + sockServerDriver->AddOobUdsOptions(opt); + sockClientDriver->OobUdsName("hcom-server1"); + SetCB(sockServerDriver); + SetCB(sockClientDriver); + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + NResult result = sockClientDriver->Connect("hello world", ep, 0); + EXPECT_EQ(NNCode::NN_OK, result); + remove(testFile); +} + +TEST_F(TestNetSockDriverOob, ConnectSuccessWithTLS) +{ + sockOptions.enableTls = true; + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + NResult result = sockClientDriver->Connect("hello world", ep, 0); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, ConnectSuccessWithTLSCipherSuite256) +{ + sockOptions.enableTls = true; + sockOptions.cipherSuite = ock::hcom::AES_GCM_256; + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + NResult result = sockClientDriver->Connect("hello world", ep, 0); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, ConnectFailWithoutInit) +{ + NResult result = sockClientDriver->Connect("hello world", ep, 0); + EXPECT_EQ(NNCode::NN_NOT_INITIALIZED, result); +} + +TEST_F(TestNetSockDriverOob, ConnectFailWithPayloadOversize) +{ + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + char payload[1030]; + for (char &i : payload) { + i = '1'; + } + payload[1029] = '\0'; + NResult result = sockClientDriver->Connect(payload, ep, 0); + EXPECT_EQ(NNCode::NN_INVALID_PARAM, result); +} + +TEST_F(TestNetSockDriverOob, ConnectFail) +{ + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + MOCKER(::connect).defaults().will(returnValue(-1)); + NResult result = sockClientDriver->Connect("hello world", ep, 0); + EXPECT_EQ(NNCode::NN_OOB_CLIENT_SOCKET_ERROR, result); +} + +TEST_F(TestNetSockDriverOob, ConnectFailWithMagicMismatch) +{ + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockOptions.magic = 104; + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + NResult result = sockClientDriver->Connect("hello world", ep, 0); + EXPECT_EQ(NNCode::NN_CONNECT_REFUSED, result); +} + +TEST_F(TestNetSockDriverOob, ConnectFailWithFailToInitSock) +{ + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + MOCKER(setsockopt).defaults().will(returnValue(-1)); + NResult result = sockClientDriver->Connect("hello world", ep, 0); + EXPECT_EQ(NNCode::NN_NEW_OBJECT_FAILED, result); +} + +TEST_F(TestNetSockDriverOob, ConnectSyncFailWithUninit) +{ + sockOptions.enableTls = true; + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + NResult result = sockClientDriver->Connect("hello world", ep, NET_EP_SELF_POLLING); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, ConnectSyncFailWithFailToInitSock) +{ + sockOptions.enableTls = true; + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + MOCKER(setsockopt).defaults().will(returnValue(-1)); + NResult result = sockClientDriver->Connect("hello world", ep, NET_EP_SELF_POLLING); + EXPECT_EQ(NNCode::NN_NEW_OBJECT_FAILED, result); +} + +TEST_F(TestNetSockDriverOob, ConnectSyncSuccess) +{ + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + NResult result = sockClientDriver->Connect("hello world", ep, NET_EP_SELF_POLLING); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, ConnectSyncFail) +{ + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + MOCKER(::connect).defaults().will(returnValue(-1)); + NResult result = sockClientDriver->Connect("hello world", ep, NET_EP_SELF_POLLING); + EXPECT_EQ(NNCode::NN_OOB_CLIENT_SOCKET_ERROR, result); +} + +TEST_F(TestNetSockDriverOob, ConnectSyncFailWithMagicMismatch) +{ + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockOptions.magic = 104; + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + NResult result = sockClientDriver->Connect("hello world", ep, NET_EP_SELF_POLLING); + EXPECT_EQ(NNCode::NN_CONNECT_REFUSED, result); +} + +TEST_F(TestNetSockDriverOob, SendSuccess) +{ + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + sockClientDriver->Connect("hello world", ep, 0); + static char data[100] = {}; + UBSHcomNetTransRequest req((void *)(data), sizeof(data), 0); + req.upCtxSize = NN_NO16; + for (auto i = 0; i < 16; i++) { + req.upCtxData[i] = 'a'; + } + NResult result = ep->PostSend(1, req); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, ReadWriteSuccess) +{ + sockServerDriver->RegisterOneSideDoneHandler(std::bind(&sockOobOneSideDoneFail, std::placeholders::_1)); + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + sockClientDriver->Connect("hello world", ep, 0); + RegisterMemory(sockServerDriver, iovPtrServer); + RegisterMemory(sockClientDriver, iovPtrClient); + UBSHcomNetTransRequest req; + req.lAddress = iovPtrClient[0].lAddress; + req.rAddress = iovPtrServer[0].lAddress; + req.lKey = iovPtrClient[0].lKey; + req.rKey = iovPtrServer[0].lKey; + req.size = NN_NO4; + req.upCtxSize = NN_NO16; + for (uint32_t i = 0; i < NN_NO16; i++) { + req.upCtxData[i] = 'a'; + } + NResult result = ep->PostRead(req); + EXPECT_EQ(NNCode::NN_OK, result); + result = ep->PostWrite(req); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, ReadWriteSglSuccess) +{ + sockServerDriver->RegisterOneSideDoneHandler(std::bind(&sockOobOneSideDoneFail, std::placeholders::_1)); + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + sockClientDriver->Connect("hello world", ep, 0); + RegisterMemory(sockServerDriver, iovPtrServer); + RegisterMemory(sockClientDriver, iovPtrClient); + UBSHcomNetTransSgeIov iov[NN_NO4]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = iovPtrClient[i].lAddress; + iov[i].rAddress = iovPtrServer[i].lAddress; + iov[i].lKey = iovPtrClient[i].lKey; + iov[i].rKey = iovPtrServer[i].lKey; + iov[i].size = NN_NO4; + } + UBSHcomNetTransSglRequest req(iov, NN_NO4, 0); + req.upCtxSize = NN_NO16; + for (auto i = 0; i < 16; i++) { + req.upCtxData[i] = 'a'; + } + NResult result = ep->PostRead(req); + EXPECT_EQ(NNCode::NN_OK, result); + result = ep->PostWrite(req); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, SendRawSglSuccess) +{ + sockServerDriver->RegisterReqPostedHandler(std::bind(&sockOobRequestPostedFail, std::placeholders::_1)); + sockServerDriver->Initialize(sockOptions); + sockServerDriver->Start(); + sockClientDriver->Initialize(sockOptions); + sockClientDriver->Start(); + sockClientDriver->Connect("hello world", ep, 0); + RegisterMemory(sockServerDriver, iovPtrServer); + RegisterMemory(sockClientDriver, iovPtrClient); + UBSHcomNetTransSgeIov iov[NN_NO16]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = iovPtrClient[i].lAddress; + iov[i].rAddress = iovPtrServer[i].lAddress; + iov[i].lKey = iovPtrClient[i].lKey; + iov[i].rKey = iovPtrServer[i].lKey; + iov[i].size = NN_NO4; + } + UBSHcomNetTransSglRequest req(iov, NN_NO4, 0); + req.upCtxSize = NN_NO16; + for (auto i = 0; i < 16; i++) { + req.upCtxData[i] = 'a'; + } + NResult result = ep->PostSendRawSgl(req, 1); + EXPECT_EQ(NNCode::NN_OK, result); +} + +TEST_F(TestNetSockDriverOob, DestroyMemoryRegion) +{ + sockServerDriver->Initialize(sockOptions); + sockClientDriver->Initialize(sockOptions); + UBSHcomNetMemoryRegionPtr mr1; + UBSHcomNetMemoryRegionPtr mr2; + UBSHcomNetMemoryRegionPtr mr3; + sockServerDriver->CreateMemoryRegion(NN_NO8, mr1); + sockClientDriver->CreateMemoryRegion(NN_NO8, mr3); + sockServerDriver->DestroyMemoryRegion(mr1); + sockServerDriver->DestroyMemoryRegion(mr2); + sockServerDriver->DestroyMemoryRegion(mr3); +} \ No newline at end of file diff --git a/test/llt/testcase/transport/sock/test_net_sock_driver_oob.h b/test/llt/testcase/transport/sock/test_net_sock_driver_oob.h new file mode 100644 index 0000000000000000000000000000000000000000..413d2ac12ca2f6d1e8d7f7bca0bd632239b4f46c --- /dev/null +++ b/test/llt/testcase/transport/sock/test_net_sock_driver_oob.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_TEST_NET_SOCK_DRIVER_OOB_H +#define HCOM_TEST_NET_SOCK_DRIVER_OOB_H +#include +#include + +class TestNetSockDriverOob : public testing::Test { +public: + TestNetSockDriverOob(); + virtual void SetUp(void); + virtual void TearDown(void); + +protected: +}; +#endif // HCOM_TEST_NET_SOCK_DRIVER_OOB_H \ No newline at end of file diff --git a/test/llt/testcase/transport/sock/test_net_sock_endpoint.cpp b/test/llt/testcase/transport/sock/test_net_sock_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f15be2be9e75e61eef13b6a481a0ab0c1a91f272 --- /dev/null +++ b/test/llt/testcase/transport/sock/test_net_sock_endpoint.cpp @@ -0,0 +1,1086 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "gtest/gtest.h" +#include "hcom.h" +#include "hcom_def.h" +#include "net_sock_common.h" +#include "test_net_sock_endpoint.h" + +using namespace ock::hcom; +TestNetSockEndpoint::TestNetSockEndpoint() {} + +#define BASE_IP "127.0.0.1" +#define IP_SEG "127.0.0.0/16" + +struct TestRegMrInfo { + uintptr_t lAddress = 0; + uint32_t lKey = 0; + uint32_t size = 0; +} __attribute__((packed)); + +static UBSHcomNetEndpointPtr serverEp = nullptr; +static sem_t sem; +std::string certPath1; + +void TestNetSockEndpoint::SetUp() +{ + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestNetSockEndpoint::TearDown() +{ + GlobalMockObject::verify(); +} + +/* callback functions */ +static int NewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload); + serverEp = newEP; + return 0; +} + +static void EndPointBroken(const UBSHcomNetEndpointPtr &ep) +{ + NN_LOG_INFO("end point " << ep->Id()); +} + +static int RequestReceived(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("client request received - " << ctx.Header().opCode << ", dataLen " << ctx.Header().dataLength); + sem_post(&sem); + return 0; +} + +static int RequestReceivedServer(const UBSHcomNetRequestContext &ctx) +{ + std::string respMsg = "Hello client, this is a reply message"; + + int result = 0; + UBSHcomNetTransRequest req((void *)(const_cast(respMsg.c_str())), respMsg.length(), 0); + if ((result = serverEp->PostSend(1, req)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + } + return 0; +} + +/* server new request sgl callback */ +static TestRegMrInfo localMrInfo[NN_NO4]; +static int RequestReceivedSglServer(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("server request received - " << ctx.Header().opCode << ", dataLen " << ctx.Header().dataLength); + + int result = 0; + UBSHcomNetTransRequest rsp((void *)(localMrInfo), sizeof(localMrInfo), 0); + if ((result = serverEp->PostSend(1, rsp)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + + NN_LOG_INFO("request rsp Mr info"); + for (uint16_t i = 0; i < NN_NO4; i++) { + NN_LOG_INFO("idx:" << i << " key:" << localMrInfo[i].lKey << " address:" << localMrInfo[i].lAddress << + " size: " << localMrInfo[i].size); + } + return 0; +} + +/* client new request sgl callback */ +static TestRegMrInfo remoteMrInfo[NN_NO4]; +static int RequestReceivedSglClient(const UBSHcomNetRequestContext &ctx) +{ + memcpy(remoteMrInfo, ctx.Message()->Data(), ctx.Message()->DataLen()); + NN_LOG_INFO("get remote Mr info"); + for (uint16_t i = 0; i < NN_NO4; i++) { + NN_LOG_INFO("idx:" << i << " key:" << remoteMrInfo[i].lKey << " address:" << remoteMrInfo[i].lAddress << + " size" << remoteMrInfo[i].size); + } + + sem_post(&sem); + return 0; +} + +static int RequestPosted(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request posted"); + return 0; +} + +static int OneSideDone(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("one side done"); + sem_post(&sem); + return 0; +} + +static bool RegSglMem(UBSHcomNetDriver *driver, TestRegMrInfo mrInfo[], std::vector &mrs) +{ + for (uint32_t i = 0; i < NN_NO4; ++i) { + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(NN_NO100, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + mrInfo[i].lAddress = mr->GetAddress(); + mrInfo[i].lKey = mr->GetLKey(); + mrInfo[i].size = NN_NO100; + mrs.push_back(mr); + memset(reinterpret_cast(mrInfo[i].lAddress), 0, mrInfo[i].size); + NN_LOG_INFO(driver->Name() << ": lAddress = " << mrInfo[i].lAddress << ", lKey = " << mrInfo[i].lKey); + } + return true; +} + +static void SockEpDestroyMem(UBSHcomNetDriver *driver, std::vector &mrs) +{ + while (!mrs.empty()) { + driver->DestroyMemoryRegion(mrs.back()); + mrs.pop_back(); + } +} + +static void Erase(void *pass, int len) {} + +static int Verify(void *x509, const char *path) +{ + return 0; +} + +static bool CertCallback(const std::string &name, std::string &value) +{ + value = certPath1 + "/server/cert.pem"; + return true; +} + +static bool PrivateKeyCallback(const std::string &name, std::string &value, void *&keyPass, int &len, + UBSHcomTLSEraseKeypass &erase) +{ + static char content[] = "huawei"; + keyPass = reinterpret_cast(content); + len = sizeof(content); + value = certPath1 + "/server/key.pem"; + erase = std::bind(&Erase, std::placeholders::_1, std::placeholders::_2); + + return true; +} + +static bool CACallback(const std::string &name, std::string &caPath, std::string &crlPath, + UBSHcomPeerCertVerifyType &peerCertVerifyType, UBSHcomTLSCertVerifyCallback &cb) +{ + caPath = certPath1 + "/CA/cacert.pem"; + cb = std::bind(&Verify, std::placeholders::_1, std::placeholders::_2); + return true; +} + +static bool CreateServerDriver(UBSHcomNetDriver *&driver, uint16_t port, + int (*reqHandler)(const UBSHcomNetRequestContext &), bool enableTls, uint32_t segSize = 1024, uint16_t buffSize = 0) +{ + auto name = "server-" + std::to_string(port); + + driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, name, true); + + if (driver == nullptr) { + NN_LOG_ERROR("failed to create serverDriver already created"); + return false; + } + + UBSHcomNetDriverOptions options {}; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + options.SetNetDeviceIpMask(IP_SEG); + options.enableTls = enableTls; + options.mrSendReceiveSegCount = 10; + options.mrSendReceiveSegSize = segSize; + options.tcpSendBufSize = buffSize; + options.tcpReceiveBufSize = buffSize; + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + driver->RegisterNewEPHandler( + std::bind(&NewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(reqHandler); + driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + + if (enableTls) { + driver->RegisterTLSCertificationCallback( + std::bind(&CertCallback, std::placeholders::_1, std::placeholders::_2)); + driver->RegisterTLSCaCallback(std::bind(&CACallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + driver->RegisterTLSPrivateKeyCallback(std::bind(&PrivateKeyCallback, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + } + + driver->OobIpAndPort(BASE_IP, port); + + int result = 0; + if ((result = driver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize driver " << result); + return false; + } + NN_LOG_INFO("serverDriver initialized"); + + if ((result = driver->Start()) != 0) { + NN_LOG_ERROR("failed to start serverDriver " << result); + return false; + } + NN_LOG_INFO("serverDriver started"); + return true; +} + +static bool CreateClientDriver(UBSHcomNetDriver *&driver, uint16_t port, + int (*reqHandler)(const UBSHcomNetRequestContext &), bool enableTls) +{ + auto name = "client-" + std::to_string(port); + + driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, name, false); + if (driver == nullptr) { + NN_LOG_ERROR("failed to create clientDriver already created"); + return false; + } + + UBSHcomNetDriverOptions options {}; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + options.SetNetDeviceIpMask(IP_SEG); + options.enableTls = enableTls; + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(reqHandler); + driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + + if (enableTls) { + driver->RegisterTLSCertificationCallback( + std::bind(&CertCallback, std::placeholders::_1, std::placeholders::_2)); + driver->RegisterTLSCaCallback(std::bind(&CACallback, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + driver->RegisterTLSPrivateKeyCallback(std::bind(&PrivateKeyCallback, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)); + } + + driver->OobIpAndPort(BASE_IP, port); + + int result = 0; + if ((result = driver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize driver " << result); + return false; + } + NN_LOG_INFO("clientDriver initialized"); + + if ((result = driver->Start()) != 0) { + NN_LOG_ERROR("failed to start clientDriver " << result); + return false; + } + NN_LOG_INFO("clientDriver started"); + return true; +} + +static bool CreateClientDriverSync(UBSHcomNetDriver *&driver, uint16_t port, uint32_t segSize = 1024, + uint16_t buffSize = 0) +{ + auto name = "client-" + std::to_string(port); + + driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, name, false); + if (driver == nullptr) { + NN_LOG_ERROR("failed to create clientDriver already created"); + return false; + } + UBSHcomNetDriverOptions options {}; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + options.mrSendReceiveSegCount = 10; + options.mrSendReceiveSegSize = segSize; + options.dontStartWorkers = true; + options.tcpSendBufSize = buffSize; + options.tcpReceiveBufSize = buffSize; + options.enableTls = false; + options.SetNetDeviceIpMask(IP_SEG); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + driver->OobIpAndPort(BASE_IP, port); + + int result = 0; + if ((result = driver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize driver " << result); + return false; + } + NN_LOG_INFO("clientDriver initialized"); + + if ((result = driver->Start()) != 0) { + NN_LOG_ERROR("failed to start clientDriver " << result); + return false; + } + NN_LOG_INFO("clientDriver started"); + return true; +} + +void CloseDriver(UBSHcomNetDriver *&clientDriver, UBSHcomNetDriver *&serverDriver) +{ + std::string clientName = clientDriver->Name(); + std::string serverName = serverDriver->Name(); + if (clientDriver->IsStarted()) { + clientDriver->Stop(); + clientDriver->UnInitialize(); + } + if (serverDriver->IsStarted()) { + serverDriver->Stop(); + serverDriver->UnInitialize(); + } + UBSHcomNetDriver::DestroyInstance(clientName); + UBSHcomNetDriver::DestroyInstance(serverName); +} + +TEST_F(TestNetSockEndpoint, PostSendRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + CreateServerDriver(serverDriver, 9911, RequestReceived, false); + CreateClientDriver(clientDriver, 9911, RequestReceived, false); + + clientDriver->Connect("hello server", ep, 0); + + ep->DefaultTimeout(1); + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + + auto peerIpAndPort = ep->PeerIpAndPort(); + NN_LOG_INFO(peerIpAndPort); + EXPECT_EQ("127.0.0.1:9911", peerIpAndPort); + + result = ep->PostSend(1, req); + EXPECT_EQ(SS_OK, result); + + UBSHcomNetTransOpInfo innerOpInfo(0, 0, 0, NTH_TWO_SIDE); + result = ep->PostSend(1, req, innerOpInfo); + EXPECT_EQ(SS_OK, result); + + MOCKER_CPP(&SockWorker::PostSend).defaults().will(returnObjectList(413, 400, 413, 400)); + result = ep->PostSend(1, req); + EXPECT_EQ(SS_ERROR, result); + + result = ep->PostSend(1, req, innerOpInfo); + EXPECT_EQ(SS_ERROR, result); + + result = ep->WaitCompletion(); + EXPECT_EQ(NN_INVALID_OPERATION, result); + + UBSHcomNetResponseContext respCtx {}; + result = ep->Receive(2, respCtx); + EXPECT_EQ(NN_INVALID_OPERATION, result); + + result = ep->ReceiveRaw(2, respCtx); + EXPECT_EQ(NN_INVALID_OPERATION, result); + + CloseDriver(clientDriver, serverDriver); +} + +TEST_F(TestNetSockEndpoint, PostSendRawRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + + CreateServerDriver(serverDriver, 9912, RequestReceived, false); + CreateClientDriver(clientDriver, 9912, RequestReceived, false); + + clientDriver->Connect("hello world", ep, 0); + ep->DefaultTimeout(1); + std::string value = "sock ping pong client"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + result = ep->PostSendRaw(req, 1); + EXPECT_EQ(SS_OK, result); + + MOCKER_CPP(&SockWorker::PostSend).defaults().will(returnObjectList(413, 400)); + result = ep->PostSendRaw(req, 0); + EXPECT_EQ(SS_ERROR, result); + + CloseDriver(clientDriver, serverDriver); +} + +TEST_F(TestNetSockEndpoint, PostSendRawSglRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + + CreateServerDriver(serverDriver, 9913, RequestReceived, false); + CreateClientDriver(clientDriver, 9913, RequestReceived, false); + + clientDriver->Connect("hello server", ep, 0); + ep->DefaultTimeout(1); + sem_init(&sem, 0, 0); + + TestRegMrInfo clientMrInfo[NN_NO4]; + std::vector mrs; + bool res = RegSglMem(clientDriver, clientMrInfo, mrs); + EXPECT_TRUE(res); + + UBSHcomNetTransSgeIov iov[NN_NO4]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = clientMrInfo[i].lAddress; + iov[i].lKey = clientMrInfo[i].lKey; + iov[i].size = clientMrInfo[i].size; + } + UBSHcomNetTransSglRequest req(iov, NN_NO4, 0); + + result = ep->PostSendRawSgl(req, 1); + EXPECT_EQ(SS_OK, result); + sem_wait(&sem); + + MOCKER_CPP(&SockWorker::PostSendRawSgl).defaults().will(returnObjectList(413, 400)); + result = ep->PostSendRawSgl(req, 0); + EXPECT_EQ(SS_ERROR, result); + + SockEpDestroyMem(clientDriver, mrs); + CloseDriver(clientDriver, serverDriver); +} + +TEST_F(TestNetSockEndpoint, PostReadWriteRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + + + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + + CreateServerDriver(serverDriver, 9914, RequestReceivedSglServer, false); + CreateClientDriver(clientDriver, 9914, RequestReceivedSglClient, false); + + clientDriver->Connect("hello server", ep, 0); + + ep->DefaultTimeout(1); + sem_init(&sem, 0, 0); + + TestRegMrInfo clientMrInfo[NN_NO4]; + bool res; + std::vector clientMrs; + res = RegSglMem(clientDriver, clientMrInfo, clientMrs); + EXPECT_TRUE(res); + std::vector serverMrs; + res = RegSglMem(serverDriver, localMrInfo, serverMrs); + EXPECT_TRUE(res); + + std::string msg = "Transfer MrInfo of the server to the client."; + UBSHcomNetTransRequest rsp((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSend(1, rsp); + EXPECT_EQ(SS_OK, result); + sem_wait(&sem); + + UBSHcomNetTransRequest req; + req.lAddress = clientMrInfo[0].lAddress; + req.rAddress = localMrInfo[0].lAddress; + req.lKey = clientMrInfo[0].lKey; + req.rKey = localMrInfo[0].lKey; + req.size = localMrInfo[0].size; + + NN_LOG_INFO("++++++++++++rAddress = " << req.rAddress << ", value = " << + *(reinterpret_cast((void *)req.rAddress))); + NN_LOG_INFO("++++++++++++lAddress = " << req.lAddress << ", value = " << + *(reinterpret_cast((void *)req.lAddress))); + result = ep->PostRead(req); + sem_wait(&sem); + + for (uint16_t i = 0; i < 1; i++) { + uint64_t *readValue = reinterpret_cast((void *)(clientMrInfo[i].lAddress)); + NN_LOG_INFO("value[" << i << "]=" << *readValue); + } + EXPECT_EQ(SS_OK, result); + + result = ep->PostWrite(req); + sem_wait(&sem); + + MOCKER_CPP(&SockWorker::PostRead, SResult(SockWorker::*)(Sock *, SockTransHeader &, + const UBSHcomNetTransRequest &)) + .defaults() + .will(returnObjectList(413, 400)); + result = ep->PostRead(req); + EXPECT_EQ(SS_ERROR, result); + + MOCKER_CPP(&SockWorker::PostWrite, SResult(SockWorker::*)(Sock *, SockTransHeader &, + const UBSHcomNetTransRequest &)) + .defaults() + .will(returnObjectList(413, 400)); + result = ep->PostWrite(req); + EXPECT_EQ(SS_ERROR, result); + + SockEpDestroyMem(clientDriver, clientMrs); + SockEpDestroyMem(serverDriver, serverMrs); + CloseDriver(clientDriver, serverDriver); +} + +TEST_F(TestNetSockEndpoint, PostReadWriteSglRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + + + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + + CreateServerDriver(serverDriver, 9915, RequestReceivedSglServer, false); + CreateClientDriver(clientDriver, 9915, RequestReceivedSglClient, false); + + clientDriver->Connect("hello server", ep, 0); + ep->DefaultTimeout(1); + sem_init(&sem, 0, 0); + + TestRegMrInfo clientMrInfo[NN_NO4]; + bool res; + std::vector clientMrs; + res = RegSglMem(clientDriver, clientMrInfo, clientMrs); + EXPECT_TRUE(res); + std::vector serverMrs; + res = RegSglMem(serverDriver, localMrInfo, serverMrs); + EXPECT_TRUE(res); + + std::string msg = "Transfer MrInfo of the server to the client."; + UBSHcomNetTransRequest rsp((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSend(1, rsp); + EXPECT_EQ(SS_OK, result); + + UBSHcomNetTransSgeIov iov[NN_NO4]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = clientMrInfo[i].lAddress; + iov[i].rAddress = localMrInfo[i].lAddress; + iov[i].lKey = clientMrInfo[i].lKey; + iov[i].rKey = localMrInfo[i].lKey; + iov[i].size = localMrInfo[i].size; + } + UBSHcomNetTransSglRequest reqRead(iov, NN_NO4, 0); + result = ep->PostRead(reqRead); + sem_wait(&sem); + + for (uint16_t i = 0; i < NN_NO4; i++) { + uint64_t *readValue = reinterpret_cast((void *)(clientMrInfo[i].lAddress)); + uint64_t value = *readValue; + NN_LOG_INFO("value[" << i << "]=" << *readValue); + *readValue = ++value; + } + EXPECT_EQ(SS_OK, result); + + UBSHcomNetTransSglRequest reqWrite(iov, NN_NO4, 0); + result = ep->PostWrite(reqWrite); + sem_wait(&sem); + + MOCKER_CPP(&SockWorker::PostRead, SResult(SockWorker::*)(Sock *, SockTransHeader &, + const UBSHcomNetTransSglRequest &)) + .defaults() + .will(returnObjectList(413, 400)); + result = ep->PostRead(reqRead); + EXPECT_EQ(SS_ERROR, result); + + MOCKER_CPP(&SockWorker::PostWrite, SResult(SockWorker::*)(Sock *, SockTransHeader &, + const UBSHcomNetTransSglRequest &)) + .defaults() + .will(returnObjectList(413, 400)); + result = ep->PostWrite(reqRead); + EXPECT_EQ(SS_ERROR, result); + + SockEpDestroyMem(clientDriver, clientMrs); + SockEpDestroyMem(serverDriver, serverMrs); + CloseDriver(clientDriver, serverDriver); +} + +TEST_F(TestNetSockEndpoint, SyncPostSendRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + + bool res; + res = CreateServerDriver(serverDriver, 9916, RequestReceivedServer, false); + EXPECT_TRUE(res); + + res = CreateClientDriverSync(clientDriver, 9916); + EXPECT_TRUE(res); + + clientDriver->Connect("hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(1); + auto peerIpAndPort = ep->PeerIpAndPort(); + NN_LOG_INFO(peerIpAndPort); + EXPECT_EQ("127.0.0.1:9916", peerIpAndPort); + + std::string msg = "Hello server, this is a message"; + UBSHcomNetTransRequest req((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSend(1, req); + EXPECT_EQ(SS_OK, result); + + UBSHcomNetResponseContext respCtx {}; + result = ep->Receive(2, respCtx); + std::string resp((char *)respCtx.Message()->Data(), respCtx.Header().dataLength); + NN_LOG_INFO("server response received - " << respCtx.Header().opCode << ", dataLen " << + respCtx.Header().dataLength); + EXPECT_EQ(SS_OK, result); + + UBSHcomNetTransOpInfo innerOpInfo(0, 0, 0, NTH_TWO_SIDE); + result = ep->PostSend(1, req, innerOpInfo); + EXPECT_EQ(SS_OK, result); + + MOCKER_CPP(&Sock::PostSend, SResult(Sock::*)(SockTransHeader &, const UBSHcomNetTransRequest &)) + .defaults() + .will(returnObjectList(413, 400, 413, 400)); + result = ep->PostSend(3, req); + EXPECT_EQ(SS_ERROR, result); + + result = ep->PostSend(1, req, innerOpInfo); + EXPECT_EQ(SS_ERROR, result); + + CloseDriver(clientDriver, serverDriver); +} + +TEST_F(TestNetSockEndpoint, SyncReceiveRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + + bool res; + res = CreateServerDriver(serverDriver, 9917, RequestReceived, false); + EXPECT_TRUE(res); + + res = CreateClientDriverSync(clientDriver, 9917); + EXPECT_TRUE(res); + + clientDriver->Connect("hello server", ep, NET_EP_SELF_POLLING); + + UBSHcomNetResponseContext respCtx {}; + + MOCKER_CPP(&Sock::PostReceiveHeader).defaults().will(returnObjectList(400, 0, 0)); + result = ep->Receive(4, respCtx); + EXPECT_EQ(SS_ERROR, result); + + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed).defaults().will(returnObjectList(false, true)); + result = ep->Receive(6, respCtx); + EXPECT_EQ(NN_MALLOC_FAILED, result); + + MOCKER_CPP(&Sock::PostReceiveBody).defaults().will(returnValue(400)); + result = ep->Receive(4, respCtx); + EXPECT_EQ(SS_ERROR, result); + + CloseDriver(clientDriver, serverDriver); +} + +TEST_F(TestNetSockEndpoint, SyncPostSendRawRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + + + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + + bool res; + res = CreateServerDriver(serverDriver, 9918, RequestReceivedServer, false); + EXPECT_TRUE(res); + res = CreateClientDriverSync(clientDriver, 9918); + EXPECT_TRUE(res); + + clientDriver->Connect("hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(1); + std::string msg = "Hello server, this is a message"; + UBSHcomNetTransRequest req((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSendRaw(req, 1); + EXPECT_EQ(SS_OK, result); + + UBSHcomNetResponseContext respCtx {}; + result = ep->ReceiveRaw(-1, respCtx); + std::string resp((char *)respCtx.Message()->Data(), respCtx.Header().dataLength); + NN_LOG_INFO("server response received - " << respCtx.Header().opCode << ", dataLen " << + respCtx.Header().dataLength); + EXPECT_EQ(SS_OK, result); + + MOCKER_CPP(&Sock::PostSend, SResult(Sock::*)(SockTransHeader &, const UBSHcomNetTransRequest &)) + .defaults() + .will(returnObjectList(413, 400)); + result = ep->PostSendRaw(req, 1); + EXPECT_EQ(SS_ERROR, result); + + CloseDriver(clientDriver, serverDriver); +} + +TEST_F(TestNetSockEndpoint, SyncReceiveRawRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + + bool res; + res = CreateServerDriver(serverDriver, 9919, RequestReceivedServer, false); + EXPECT_TRUE(res); + res = CreateClientDriverSync(clientDriver, 9919); + EXPECT_TRUE(res); + + clientDriver->Connect("hello server", ep, NET_EP_SELF_POLLING); + + std::string msg = "Hello server, this is a message"; + UBSHcomNetResponseContext respCtx {}; + + MOCKER_CPP(&Sock::PostReceiveHeader).defaults().will(returnObjectList(400, 0, 0)); + result = ep->ReceiveRaw(4, respCtx); + EXPECT_EQ(SS_ERROR, result); + + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed).defaults().will(returnObjectList(false, true)); + result = ep->ReceiveRaw(6, respCtx); + EXPECT_EQ(NN_MALLOC_FAILED, result); + + MOCKER_CPP(&Sock::PostReceiveBody).defaults().will(returnValue(400)); + result = ep->ReceiveRaw(4, respCtx); + EXPECT_EQ(SS_ERROR, result); + + CloseDriver(clientDriver, serverDriver); +} + +TEST_F(TestNetSockEndpoint, SyncPostSendRawSglRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + + + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + + bool createRes; + createRes = CreateServerDriver(serverDriver, 9920, RequestReceivedServer, false); + EXPECT_TRUE(createRes); + createRes = CreateClientDriverSync(clientDriver, 9920); + EXPECT_TRUE(createRes); + + clientDriver->Connect("hello server", ep, NET_EP_SELF_POLLING); + ep->DefaultTimeout(1); + sem_init(&sem, 0, 0); + + TestRegMrInfo clientMrInfo[NN_NO4]; + std::vector clientMrs; + bool res = RegSglMem(clientDriver, clientMrInfo, clientMrs); + EXPECT_TRUE(res); + + UBSHcomNetTransSgeIov iov[NN_NO4]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = clientMrInfo[i].lAddress; + iov[i].lKey = clientMrInfo[i].lKey; + iov[i].size = clientMrInfo[i].size; + } + + UBSHcomNetTransSglRequest req(iov, NN_NO4, 0); + UBSHcomNetResponseContext respCtx {}; + + result = ep->PostSendRawSgl(req, 0); + EXPECT_EQ(SS_OK, result); + result = ep->ReceiveRawSgl(respCtx); + EXPECT_EQ(SS_OK, result); + + MOCKER_CPP(&Sock::PostSendSgl, SResult(Sock::*)(SockTransHeader &, const UBSHcomNetTransSglRequest &)) + .defaults() + .will(returnObjectList(413, 400)); + result = ep->PostSendRawSgl(req, 0); + EXPECT_EQ(SS_ERROR, result); + + SockEpDestroyMem(clientDriver, clientMrs); + CloseDriver(clientDriver, serverDriver); +} + +TEST_F(TestNetSockEndpoint, SyncPostReadWriteRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + + bool createRes; + createRes = CreateServerDriver(serverDriver, 9921, RequestReceivedSglServer, false); + EXPECT_TRUE(createRes); + createRes = CreateClientDriverSync(clientDriver, 9921); + EXPECT_TRUE(createRes); + + clientDriver->Connect("hello server", ep, NET_EP_SELF_POLLING); + + TestRegMrInfo clientMrInfo[NN_NO4]; + bool res; + std::vector clientMrs; + res = RegSglMem(clientDriver, clientMrInfo, clientMrs); + EXPECT_TRUE(res); + std::vector serverMrs; + res = RegSglMem(serverDriver, localMrInfo, serverMrs); + EXPECT_TRUE(res); + + std::string msg = "Transfer MrInfo of the server to the client."; + UBSHcomNetTransRequest rsp((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSend(1, rsp); + EXPECT_EQ(SS_OK, result); + + UBSHcomNetResponseContext respCtx {}; + result = ep->Receive(respCtx); + EXPECT_EQ(SS_OK, result); + + memcpy(remoteMrInfo, respCtx.Message()->Data(), respCtx.Message()->DataLen()); + + UBSHcomNetTransRequest req; + req.lAddress = clientMrInfo[0].lAddress; + req.rAddress = localMrInfo[0].lAddress; + req.lKey = clientMrInfo[0].lKey; + req.rKey = localMrInfo[0].lKey; + req.size = localMrInfo[0].size; + + result = ep->PostRead(req); + EXPECT_EQ(SS_OK, result); + + result = ep->WaitCompletion(); + EXPECT_EQ(SS_OK, result); + + for (uint16_t i = 0; i < 1; i++) { + uint64_t *readValue = reinterpret_cast((void *)(clientMrInfo[i].lAddress)); + NN_LOG_INFO("value[" << i << "]=" << *readValue); + } + + result = ep->PostWrite(req); + EXPECT_EQ(SS_OK, result); + + result = ep->WaitCompletion(); + EXPECT_EQ(SS_OK, result); + + MOCKER_CPP(&Sock::PostRead, SResult(Sock::*)(SockOpContextInfo *)).defaults().will(returnObjectList(401)); + result = ep->PostRead(req); + EXPECT_EQ(SS_PARAM_INVALID, result); + + MOCKER_CPP(&Sock::PostWrite, SResult(Sock::*)(SockOpContextInfo *)).defaults().will(returnObjectList(401)); + result = ep->PostWrite(req); + EXPECT_EQ(SS_PARAM_INVALID, result); + + SockEpDestroyMem(clientDriver, clientMrs); + SockEpDestroyMem(serverDriver, serverMrs); + CloseDriver(clientDriver, serverDriver); +} + +TEST_F(TestNetSockEndpoint, SyncPostReadWriteSglRetry) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + + bool createRes; + createRes = CreateServerDriver(serverDriver, 9922, RequestReceivedSglServer, false); + EXPECT_TRUE(createRes); + createRes = CreateClientDriverSync(clientDriver, 9922); + EXPECT_TRUE(createRes); + + clientDriver->Connect("hello server", ep, NET_EP_SELF_POLLING); + + TestRegMrInfo clientMrInfo[NN_NO4]; + bool res; + std::vector clientMrs; + res = RegSglMem(clientDriver, clientMrInfo, clientMrs); + EXPECT_TRUE(res); + std::vector serverMrs; + res = RegSglMem(serverDriver, localMrInfo, serverMrs); + EXPECT_TRUE(res); + + std::string msg = "Transfer MrInfo of the server to the client."; + UBSHcomNetTransRequest rsp((void *)(const_cast(msg.c_str())), msg.length(), 0); + result = ep->PostSend(1, rsp); + EXPECT_EQ(SS_OK, result); + + UBSHcomNetResponseContext respCtx {}; + result = ep->Receive(respCtx); + EXPECT_EQ(SS_OK, result); + + memcpy(localMrInfo, respCtx.Message()->Data(), respCtx.Message()->DataLen()); + + UBSHcomNetTransSgeIov iov[NN_NO4]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = clientMrInfo[i].lAddress; + iov[i].rAddress = localMrInfo[i].lAddress; + iov[i].lKey = clientMrInfo[i].lKey; + iov[i].rKey = localMrInfo[i].lKey; + iov[i].size = localMrInfo[i].size; + NN_LOG_INFO("idx:" << i << " lKey:" << iov[i].lKey << " lAddress:" << iov[i].lAddress << " rKey:" << + iov[i].rKey << " rAddress:" << iov[i].rAddress << " size:" << iov[i].size); + } + UBSHcomNetTransSglRequest reqRead(iov, NN_NO4, 0); + result = ep->PostRead(reqRead); + EXPECT_EQ(SS_OK, result); + + result = ep->WaitCompletion(); + EXPECT_EQ(SS_OK, result); + + UBSHcomNetTransSglRequest reqWrite(iov, NN_NO4, 0); + result = ep->PostWrite(reqWrite); + EXPECT_EQ(SS_OK, result); + + result = ep->WaitCompletion(); + EXPECT_EQ(SS_OK, result); + + MOCKER_CPP(&Sock::PostReadSgl, SResult(Sock::*)(SockOpContextInfo *)).defaults().will(returnObjectList(401)); + result = ep->PostRead(reqRead); + EXPECT_EQ(SS_PARAM_INVALID, result); + + MOCKER_CPP(&Sock::PostWriteSgl, SResult(Sock::*)(SockOpContextInfo *)).defaults().will(returnObjectList(401)); + result = ep->PostWrite(reqWrite); + EXPECT_EQ(SS_PARAM_INVALID, result); + + SockEpDestroyMem(clientDriver, clientMrs); + SockEpDestroyMem(serverDriver, serverMrs); + CloseDriver(clientDriver, serverDriver); +} + +int SockValidateTlsCert() +{ + char *buffer; + if ((buffer = getcwd(NULL, 0)) == NULL) { + NN_LOG_ERROR("Cet path for TLS cert failed"); + return -1; + } + + std::string currentPath = buffer; + certPath1 = currentPath + "/../test/opensslcrt/normalCert1"; + + if (!CanonicalPath(certPath1)) { + NN_LOG_ERROR("TLS cert path check failed " << certPath1); + return -1; + } + + return 0; +} + +TEST_F(TestNetSockEndpoint, EpEncrypt) +{ + SockValidateTlsCert(); + UBSHcomNetEndpointPtr ep = nullptr; + + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + + CreateServerDriver(serverDriver, 9923, RequestReceivedSglServer, true); + CreateClientDriver(clientDriver, 9923, RequestReceivedSglClient, true); + + auto ret = clientDriver->Connect("hello server", ep, 0); + ASSERT_EQ(ret, NN_OK); + std::string value = "EpEncrypt"; + uint64_t encryptLen = ep->EstimatedEncryptLen(value.length()); + void *cipher = malloc(encryptLen); + ep->Encrypt(value.c_str(), value.length(), cipher, encryptLen); + uint64_t decryptLen = ep->EstimatedDecryptLen(encryptLen); + void *rawData = malloc(decryptLen); + ep->Decrypt(cipher, encryptLen, rawData, decryptLen); + auto result = memcmp(value.c_str(), reinterpret_cast(rawData), value.length()); + EXPECT_EQ(0, result); + ep->Close(); + + clientDriver->Connect("hello server", ep, NET_EP_SELF_POLLING); + encryptLen = ep->EstimatedEncryptLen(value.length()); + cipher = malloc(encryptLen); + ep->Encrypt(value.c_str(), value.length(), cipher, encryptLen); + decryptLen = ep->EstimatedDecryptLen(encryptLen); + rawData = malloc(decryptLen); + ep->Decrypt(cipher, encryptLen, rawData, decryptLen); + result = memcmp(value.c_str(), reinterpret_cast(rawData), value.length()); + EXPECT_EQ(0, result); + ep->Close(); + + clientDriver = nullptr; + serverDriver = nullptr; + + CreateServerDriver(serverDriver, 9924, RequestReceivedSglServer, false); + CreateClientDriver(clientDriver, 9924, RequestReceivedSglClient, false); + + clientDriver->Connect("hello server", ep, 0); + value = "EpEncrypt"; + encryptLen = ep->EstimatedEncryptLen(0); + EXPECT_EQ(0, encryptLen); + encryptLen = ep->EstimatedEncryptLen(value.length()); + EXPECT_EQ(0, encryptLen); + cipher = malloc(encryptLen); + auto res = ep->Encrypt(value.c_str(), value.length(), cipher, encryptLen); + EXPECT_EQ(NN_ERROR, res); + decryptLen = ep->EstimatedDecryptLen(encryptLen); + EXPECT_EQ(0, decryptLen); + rawData = malloc(decryptLen); + res = ep->Decrypt(cipher, encryptLen, rawData, decryptLen); + EXPECT_EQ(NN_ERROR, res); + ep->Close(); + + clientDriver->Connect("hello server", ep, NET_EP_SELF_POLLING); + value = "EpEncrypt"; + encryptLen = ep->EstimatedEncryptLen(0); + EXPECT_EQ(0, encryptLen); + encryptLen = ep->EstimatedEncryptLen(value.length()); + EXPECT_EQ(0, encryptLen); + cipher = malloc(encryptLen); + res = ep->Encrypt(value.c_str(), value.length(), cipher, encryptLen); + EXPECT_EQ(NN_ERROR, res); + decryptLen = ep->EstimatedDecryptLen(encryptLen); + EXPECT_EQ(0, decryptLen); + rawData = malloc(decryptLen); + res = ep->Decrypt(cipher, encryptLen, rawData, decryptLen); + EXPECT_EQ(NN_ERROR, res); + ep->Close(); +} + +TEST_F(TestNetSockEndpoint, SyncPostSendTimeout) +{ + UBSHcomNetEndpointPtr ep = nullptr; + NResult result; + + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetDriver *serverDriver = nullptr; + + bool res; + uint32_t segSize = 3 * 1024 * 1024; + res = CreateServerDriver(serverDriver, 9925, RequestReceivedServer, false, segSize, 1); + EXPECT_TRUE(res); + res = CreateClientDriverSync(clientDriver, 9925, segSize, 1); + EXPECT_TRUE(res); + clientDriver->Connect("hello server", ep, NET_EP_SELF_POLLING); + UBSHcomNetResponseContext respCtx{}; + static char data[20] = "sock_pp_client"; + UBSHcomNetTransRequest req((void *)(data), sizeof(data), 0); + UBSHcomEpOptions epOptions{}; + + epOptions.sendTimeout = -1; + ep->SetEpOption(epOptions); + result = ep->PostSend(1, req); + EXPECT_EQ(SS_OK, result); + + result = ep->Receive(-1, respCtx); + EXPECT_EQ(SS_OK, result); + + static char data1[2 * 1024 * 1024] = "sock_pp_client"; + UBSHcomNetTransRequest req1((void *)(data1), sizeof(data1), 0); + + epOptions.sendTimeout = 0; + ep->SetEpOption(epOptions); + result = ep->PostSend(1, req1); + EXPECT_EQ(SS_TIMEOUT, result); + + clientDriver->Connect("hello server", ep, NET_EP_SELF_POLLING); + result = ep->Receive(0, respCtx); + EXPECT_EQ(SS_TIMEOUT, result); + + CloseDriver(clientDriver, serverDriver); +} \ No newline at end of file diff --git a/test/llt/testcase/transport/sock/test_net_sock_endpoint.h b/test/llt/testcase/transport/sock/test_net_sock_endpoint.h new file mode 100644 index 0000000000000000000000000000000000000000..0eae9b1e7eacbd3292b30e5f78baf91616760671 --- /dev/null +++ b/test/llt/testcase/transport/sock/test_net_sock_endpoint.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_TEST_NET_SOCK_ENDPOINT_H +#define HCOM_TEST_NET_SOCK_ENDPOINT_H +#include +#include + +class TestNetSockEndpoint : public testing::Test { +public: + TestNetSockEndpoint(); + virtual void SetUp(void); + virtual void TearDown(void); +}; +#endif // HCOM_TEST_NET_SOCK_ENDPOINT_H diff --git a/test/llt/testcase/transport/sock/test_sock.cpp b/test/llt/testcase/transport/sock/test_sock.cpp new file mode 100644 index 0000000000000000000000000000000000000000..658fb5c8728dfbc137efa6cd2d653135e3c607a9 --- /dev/null +++ b/test/llt/testcase/transport/sock/test_sock.cpp @@ -0,0 +1,378 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include "test_sock.hpp" +#include "hcom.h" +#include "ut_helper.h" + +using namespace ock::hcom; +TestCaseSock::TestCaseSock() {} + +void TestCaseSock::SetUp() +{ + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} + +void TestCaseSock::TearDown() +{ + GlobalMockObject::verify(); +} + +static std::string ipSeg = IP_SEG; +using TestRegMrInfo = struct _reg_sgl_info_test_ { + uintptr_t lAddress = 0; + uint32_t lKey = 0; + uint32_t size = 0; +} __attribute__((packed)); + +#define SOCK_CHECK_RESULT_TRUE(result) \ + EXPECT_EQ(true, (result)); \ + if (!(result)) { \ + return; \ + } + +static UBSHcomNetDriver *sockServerDriver = nullptr; +static UBSHcomNetEndpointPtr sockServerEp = nullptr; +constexpr uint64_t SOCK_PORT = 9925; + +static TestRegMrInfo localMrInfo[4]; + +int SockServerNewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload); + sockServerEp = newEP; + return 0; +} + +void SockServerEndPointBroken(const UBSHcomNetEndpointPtr &sockServerEp1) +{ + NN_LOG_INFO("end point " << sockServerEp1->Id()); + if (sockServerEp != nullptr) { + sockServerEp.Set(nullptr); + } +} + +int SockServerRequestReceived(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request received - " << ctx.Header().opCode << ", dataLen " << ctx.Header().dataLength); + + int result = 0; + UBSHcomNetTransRequest rsp((void *)(localMrInfo), sizeof(localMrInfo), 0); + if ((result = sockServerEp->PostSend(1, rsp)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + + NN_LOG_INFO("request rsp Mr info"); + for (uint16_t i = 0; i < 4; i++) { + NN_LOG_TRACE_INFO("idx:" << i << " key:" << localMrInfo[i].lKey << " address:" << localMrInfo[i].lAddress << + " size" << localMrInfo[i].size); + } + return 0; +} + +int SockServerRequestPosted(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("request posted"); + return 0; +} + +int SockServerOneSideDone(const UBSHcomNetRequestContext &ctx) +{ + NN_LOG_INFO("one side done"); + return 0; +} + + +bool SockServerCreateDriver() +{ + if (sockServerDriver != nullptr) { + NN_LOG_ERROR("sockServerDriver already created"); + return false; + } + + sockServerDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::UDS, "sockServer", true); + if (sockServerDriver == nullptr) { + NN_LOG_ERROR("failed to create sockServerDriver already created"); + return false; + } + + UBSHcomNetDriverOptions options {}; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + options.enableTls = false; + options.SetNetDeviceIpMask(ipSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + sockServerDriver->RegisterNewEPHandler( + std::bind(&SockServerNewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + sockServerDriver->RegisterEPBrokenHandler(std::bind(&SockServerEndPointBroken, std::placeholders::_1)); + sockServerDriver->RegisterNewReqHandler(std::bind(&SockServerRequestReceived, std::placeholders::_1)); + sockServerDriver->RegisterReqPostedHandler(std::bind(&SockServerRequestPosted, std::placeholders::_1)); + sockServerDriver->RegisterOneSideDoneHandler(std::bind(&SockServerOneSideDone, std::placeholders::_1)); + + sockServerDriver->OobIpAndPort(BASE_IP, SOCK_PORT); + + int result = 0; + if ((result = sockServerDriver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize sockServerDriver " << result); + return false; + } + NN_LOG_ERROR("sockServerDriver initialized"); + + if ((result = sockServerDriver->Start()) != 0) { + NN_LOG_ERROR("failed to start sockServerDriver " << result); + return false; + } + NN_LOG_ERROR("sockServerDriver started"); + + return true; +} + +bool SockServerRegSglMem() +{ + for (uint16_t i = 0; i < 4; i++) { + UBSHcomNetMemoryRegionPtr mr; + auto result = sockServerDriver->CreateMemoryRegion(NN_NO16, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + localMrInfo[i].lAddress = mr->GetAddress(); + localMrInfo[i].lKey = mr->GetLKey(); + localMrInfo[i].size = NN_NO16; + memset(reinterpret_cast(localMrInfo[i].lAddress), 0, NN_NO16); + } + + return true; +} + +// client +static UBSHcomNetDriver *sockClientDriver = nullptr; +static UBSHcomNetEndpointPtr sockClientEp = nullptr; + +static TestRegMrInfo ClientLocalMrInfo[NN_NO4]; +static TestRegMrInfo remoteMrInfo[NN_NO4]; +static sem_t sem; + +void SockClientEndPointBroken(const UBSHcomNetEndpointPtr &sockClientEp1) +{ + NN_LOG_INFO("end point " << sockClientEp1->Id() << " broken"); + if (sockClientEp != nullptr) { + sockClientEp.Set(nullptr); + } +} + +int SockClientRequestReceived(const UBSHcomNetRequestContext &ctx) +{ + memcpy(remoteMrInfo, ctx.Message()->Data(), ctx.Message()->DataLen()); + NN_LOG_INFO("get remote Mr info"); + for (uint16_t i = 0; i < NN_NO4; i++) { + NN_LOG_TRACE_INFO("idx:" << i << " key:" << remoteMrInfo[i].lKey << " address:" << remoteMrInfo[i].lAddress << + " size" << remoteMrInfo[i].size); + } + + sem_post(&sem); + return 0; +} + +int SockClientRequestPosted(const UBSHcomNetRequestContext &ctx) +{ + return 0; +} + +int SockClientOneSideDone(const UBSHcomNetRequestContext &ctx) +{ + sem_post(&sem); + return 0; +} + +bool SockClientCreateDriver() +{ + if (sockClientDriver != nullptr) { + NN_LOG_ERROR("sockClientDriver already created"); + return false; + } + + sockClientDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::UDS, "sockClient", false); + if (sockClientDriver == nullptr) { + NN_LOG_ERROR("failed to create sockClientDriver already created"); + return false; + } + + UBSHcomNetDriverOptions options {}; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + options.enableTls = false; + options.SetNetDeviceIpMask(ipSeg); + NN_LOG_INFO("set ip mask " << options.netDeviceIpMask); + + sockClientDriver->RegisterEPBrokenHandler(std::bind(&SockClientEndPointBroken, std::placeholders::_1)); + sockClientDriver->RegisterNewReqHandler(std::bind(&SockClientRequestReceived, std::placeholders::_1)); + sockClientDriver->RegisterReqPostedHandler(std::bind(&SockClientRequestPosted, std::placeholders::_1)); + sockClientDriver->RegisterOneSideDoneHandler(std::bind(&SockClientOneSideDone, std::placeholders::_1)); + + sockClientDriver->OobIpAndPort(BASE_IP, SOCK_PORT); + + int result = 0; + if ((result = sockClientDriver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize sockClientDriver " << result); + return false; + } + NN_LOG_ERROR("sockClientDriver initialized"); + + if ((result = sockClientDriver->Start()) != 0) { + NN_LOG_ERROR("failed to start sockClientDriver " << result); + return false; + } + NN_LOG_ERROR("sockClientDriver started"); + + return true; +} + +bool SockClientConnect() +{ + if (sockClientDriver == nullptr) { + NN_LOG_ERROR("sockClientDriver is null"); + return false; + } + + int result = 0; + if ((result = sockClientDriver->Connect("hello world", sockClientEp, 0)) != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + + sem_init(&sem, 0, 0); + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + if ((result = sockClientEp->PostSend(1, req)) != 0) { + NN_LOG_INFO("failed to post message to data to server"); + return false; + } + + sem_wait(&sem); + return true; +} + +void SockSendSingleRequest(UBSHcomNetTransSgeIov *iov, uint64_t index) +{ + int result = 0; + + UBSHcomNetTransSglRequest reqRead(iov, NN_NO4, 0); + result = sockClientEp->PostRead(reqRead); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_INFO("failed to read data from server"); + return; + } + sem_wait(&sem); + for (uint16_t i = 0; i < NN_NO4; i++) { + uint64_t *readValue = reinterpret_cast((void *)(ClientLocalMrInfo[i].lAddress)); + uint64_t value = *readValue; + NN_LOG_TRACE_INFO("value[" << i << "]=" << *readValue); + EXPECT_EQ(value, index); + *readValue = ++value; + } + + UBSHcomNetTransSglRequest reqWrite(iov, NN_NO4, 0); + result = sockClientEp->PostWrite(reqWrite); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_INFO("failed to read data from server"); + return; + } + sem_wait(&sem); + + UBSHcomNetTransRequest buffReq(iov[0].lAddress, iov[0].rAddress, iov[0].lKey, iov[0].rKey, iov[0].size, 0); + result = sockClientEp->PostRead(buffReq); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_INFO("failed to read data from server"); + return; + } + sem_wait(&sem); + uint64_t *readBuff = reinterpret_cast((void *)(iov[0].lAddress)); + uint64_t readValue = *readBuff; + EXPECT_EQ(readValue, index + 1); + + result = sockClientEp->PostWrite(buffReq); + EXPECT_EQ(result, 0); + if (result != 0) { + NN_LOG_INFO("failed to read data from server"); + return; + } + sem_wait(&sem); +} + +void SockSendRequest() +{ + UBSHcomNetTransSgeIov iov[NN_NO4]; + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = ClientLocalMrInfo[i].lAddress; + iov[i].rAddress = remoteMrInfo[i].lAddress; + iov[i].lKey = ClientLocalMrInfo[i].lKey; + iov[i].rKey = remoteMrInfo[i].lKey; + iov[i].size = NN_NO8; + } + sem_init(&sem, 0, 0); + for (int i = 0; i < 4; i++) { + SockSendSingleRequest(iov, i); + } +} + +bool SockClientRegSglMem() +{ + for (uint16_t i = 0; i < NN_NO4; i++) { + UBSHcomNetMemoryRegionPtr mr; + auto result = sockClientDriver->CreateMemoryRegion(NN_NO16, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return false; + } + ClientLocalMrInfo[i].lAddress = mr->GetAddress(); + ClientLocalMrInfo[i].lKey = mr->GetLKey(); + ClientLocalMrInfo[i].size = NN_NO16; + memset(reinterpret_cast(ClientLocalMrInfo[i].lAddress), 0, NN_NO16); + } + + return true; +} + +static void CloseDriver(UBSHcomNetDriver *&driver) +{ + std::string name = driver->Name(); + if (driver->IsStarted()) { + driver->Stop(); + driver->UnInitialize(); + } + UBSHcomNetDriver::DestroyInstance(name); +} + +TEST_F(TestCaseSock, SOCK_CASE_TCP_READWRITE) +{ + bool result = SockServerCreateDriver(); + SOCK_CHECK_RESULT_TRUE(result); + + result = SockServerRegSglMem(); + SOCK_CHECK_RESULT_TRUE(result); + + result = SockClientCreateDriver(); + SOCK_CHECK_RESULT_TRUE(result); + result = SockClientConnect(); + SOCK_CHECK_RESULT_TRUE(result); + result = SockClientRegSglMem(); + SOCK_CHECK_RESULT_TRUE(result); + SockSendRequest(); + CloseDriver(sockClientDriver); + CloseDriver(sockServerDriver); +} diff --git a/test/llt/testcase/transport/sock/test_sock.hpp b/test/llt/testcase/transport/sock/test_sock.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1d814214588d49d322d1b019a9011eae9dde78f8 --- /dev/null +++ b/test/llt/testcase/transport/sock/test_sock.hpp @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef _TEST_SOCK_HPP_ +#define _TEST_SOCK_HPP_ +#include + +class TestCaseSock : public testing::Test { +public: + TestCaseSock(); + virtual void SetUp(void); + virtual void TearDown(void); + +protected: +}; + +#endif diff --git a/test/llt/testcase/transport/sock/test_sock_wrapper.cpp b/test/llt/testcase/transport/sock/test_sock_wrapper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..270880c03d763ff759ec5cdc59d603cd9419f132 --- /dev/null +++ b/test/llt/testcase/transport/sock/test_sock_wrapper.cpp @@ -0,0 +1,866 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "test_sock_wrapper.h" +#include "hcom.h" +#include "net_sock_common.h" +#include "net_oob_ssl.h" + +using namespace ock::hcom; + +TestSockWrapper::TestSockWrapper() {} + +#define BASE_IP "127.0.0.1" +#define IP_SEG "127.0.0.0/16" +static char sendTemp[] = "hello world"; +static UBSHcomNetEndpointPtr sockServerEp = nullptr; + +using TestRegMrInfo = struct _reg_sgl_info_test_ { + uintptr_t lAddress = 0; + uint32_t lKey = 0; + uint32_t size = 0; +} __attribute__((packed)); +static TestRegMrInfo localMrInfo[4]; +static TestRegMrInfo remoteMrInfo[4]; +static TestRegMrInfo serverMrInfo[4]; + +static UBSHcomNetTransSgeIov iovPtr[NN_NO4]; +static uint32_t iovCnt = NN_NO4; +static UBSHcomNetMemoryRegionPtr mr = nullptr; +static UBSHcomNetTransSgeIov iov[NN_NO4]; +sem_t sem_sock; + +static int NewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload << " ep id " << newEP->Id()); + sockServerEp = newEP; + return 0; +} + +static void EndPointBroken(const UBSHcomNetEndpointPtr &ep) +{ + NN_LOG_INFO("end point " << ep->Id()); + if (sockServerEp != nullptr) { + sockServerEp.Set(nullptr); + } +} + +static int RequestReceived(const UBSHcomNetRequestContext &ctx) // 0 +{ + return 0; +} + +static int RequestReceivedWithSend(const UBSHcomNetRequestContext &ctx) // 1 +{ + static char data[100] = {}; + + UBSHcomNetTransRequest req((void *)(data), sizeof(data), 0); + sockServerEp->PostSend(1, req); + return 0; +} + +static int RequestReceivedSglClient(const UBSHcomNetRequestContext &ctx) // 2 +{ + memcpy(remoteMrInfo, ctx.Message()->Data(), ctx.Message()->DataLen()); + NN_LOG_INFO("get remote Mr info"); + for (uint16_t i = 0; i < NN_NO4; i++) { + NN_LOG_INFO("idx:" << i << " key:" << remoteMrInfo[i].lKey << " address:" << remoteMrInfo[i].lAddress << + " size" << remoteMrInfo[i].size); + } + sem_post(&sem_sock); + return 0; +} + +static int RequestReceivedSglServer(const UBSHcomNetRequestContext &ctx) // 3 +{ + NN_LOG_INFO("request received - " << ctx.Header().opCode << ", dataLen " << ctx.Header().dataLength); + + int result = 0; + UBSHcomNetTransRequest rsp((void *)(serverMrInfo), sizeof(serverMrInfo), 0); + if ((result = sockServerEp->PostSend(1, rsp)) != 0) { + NN_LOG_ERROR("failed to post message to data to server, result " << result); + return result; + } + + NN_LOG_INFO("request rsp Mr info"); + for (uint16_t i = 0; i < NN_NO4; i++) { + NN_LOG_INFO("idx:" << i << " key:" << serverMrInfo[i].lKey << " address:" << serverMrInfo[i].lAddress << + " size" << serverMrInfo[i].size); + } + return 0; +} +static int RequestPosted(const UBSHcomNetRequestContext &ctx) +{ + return 0; +} + +static int OneSideDone(const UBSHcomNetRequestContext &ctx) +{ + sem_post(&sem_sock); + return 0; +} + +static void CreateServerMR(UBSHcomNetDriver *driver, std::vector &mrs) +{ + for (uint16_t i = 0; i < NN_NO4; i++) { + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(NN_NO16, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return; + } + serverMrInfo[i].lAddress = mr->GetAddress(); + serverMrInfo[i].lKey = mr->GetLKey(); + serverMrInfo[i].size = NN_NO16; + mrs.push_back(mr); + memset(reinterpret_cast(serverMrInfo[i].lAddress), 0, NN_NO16); + } +} + +static void CreateClientMR(UBSHcomNetDriver *driver, std::vector &mrs) +{ + for (uint16_t i = 0; i < NN_NO4; i++) { + UBSHcomNetMemoryRegionPtr mr; + auto result = driver->CreateMemoryRegion(NN_NO16, mr); + if (result != NN_OK) { + NN_LOG_ERROR("reg mr failed"); + return; + } + localMrInfo[i].lAddress = mr->GetAddress(); + localMrInfo[i].lKey = mr->GetLKey(); + localMrInfo[i].size = NN_NO16; + mrs.push_back(mr); + memset(reinterpret_cast(localMrInfo[i].lAddress), 0, NN_NO16); + } +} + +static void SockWrapperDestoryMem(UBSHcomNetDriver *driver, std::vector &mrs) +{ + while (!mrs.empty()) { + driver->DestroyMemoryRegion(mrs.back()); + mrs.pop_back(); + } +} + + +static void SetCB(UBSHcomNetDriver *driver, bool isServer, uint8_t reqHandlerMode) +{ + if (isServer) { + driver->RegisterNewEPHandler( + std::bind(&NewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + } + driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + switch (reqHandlerMode) { + case 0: + driver->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); + break; + case 1: + driver->RegisterNewReqHandler(std::bind(&RequestReceivedWithSend, std::placeholders::_1)); + break; + case 2: + driver->RegisterNewReqHandler(std::bind(&RequestReceivedSglClient, std::placeholders::_1)); + break; + case 3: + driver->RegisterNewReqHandler(std::bind(&RequestReceivedSglServer, std::placeholders::_1)); + break; + } + driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); +} + +static int TestNewConnectionHandler(ock::hcom::OOBTCPConnection &conn1) +{ + char *sendBuff = sendTemp; + NResult ret = conn1.Send(sendBuff, strlen(sendBuff)); + EXPECT_EQ(ret, ock::hcom::NN_OK); + return ret; +} + +static void SetDriverOptions(UBSHcomNetDriverOptions &sockOptions) +{ + sockOptions.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + sockOptions.SetNetDeviceIpMask(IP_SEG); + sockOptions.pollingBatchSize = 16; + sockOptions.enableTls = false; + sockOptions.SetWorkerGroups("1"); + sockOptions.SetWorkerGroupsCpuSet("10-10"); +} + +static int SockConnect(uint16_t port) +{ + OOBTCPConnection *conn = nullptr; + OOBTCPServer oobServer(BASE_IP, port); + oobServer.SetNewConnCB(std::bind(&TestNewConnectionHandler, std::placeholders::_1)); + UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + oobServer.Start(); + ock::hcom::OOBTCPClientPtr client = new (std::nothrow) OOBTCPClient(BASE_IP, port); + + client->Connect(BASE_IP, port, conn); + auto connFd = conn->TransferFd(); + return connFd; +} + +static void CloseDriver(UBSHcomNetDriver *&driver) +{ + std::string name = driver->Name(); + if (driver->IsStarted()) { + driver->Stop(); + driver->UnInitialize(); + } + UBSHcomNetDriver::DestroyInstance(name); +} +void TestSockWrapper::SetUp() +{ + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); +} +void TestSockWrapper::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestSockWrapper, SockInitializeSuccess) +{ + auto connFd = SockConnect(9981); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-1", newSockId, connFd, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + SResult result; + result = sock->Initialize(sockWorkerOptions); + EXPECT_EQ(SS_OK, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockInitializeFailedWithInvalidType) +{ + auto connFd = SockConnect(9981); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_UDS_TCP, "sock-2", newSockId, connFd, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + SResult result; + result = sock->Initialize(sockWorkerOptions); + EXPECT_EQ(SS_PARAM_INVALID, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockInitializeTwice) +{ + auto connFd = SockConnect(9981); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-3", newSockId, connFd, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + SResult result; + result = sock->Initialize(sockWorkerOptions); + EXPECT_EQ(SS_OK, result); + result = sock->Initialize(sockWorkerOptions); + EXPECT_EQ(SS_OK, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockInitializeFailedWithInvalidFd) +{ + auto connFd = SockConnect(9981); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-4", newSockId, -1, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + SResult result = sock->Initialize(sockWorkerOptions); + EXPECT_EQ(SS_PARAM_INVALID, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockInitializeFailedWithInvalidReceiveBuf) +{ + auto connFd = SockConnect(9981); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-5", newSockId, connFd, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + MOCKER(setsockopt).defaults().will(returnValue(-1)); + SResult result; + result = sock->Initialize(sockWorkerOptions); + EXPECT_EQ(SS_TCP_SET_OPTION_FAILED, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockInitializeFailedWithInvalidSendBuf) +{ + auto connFd = SockConnect(9981); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-6", newSockId, connFd, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + MOCKER(setsockopt).defaults().will(returnValue(0)).then(returnValue(-1)); + SResult result = sock->Initialize(sockWorkerOptions); + EXPECT_EQ(SS_TCP_SET_OPTION_FAILED, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockInitializeFailedWithUDS) +{ + auto connFd = SockConnect(9981); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_UDS, "sock-7", newSockId, connFd, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + SResult result; + result = sock->Initialize(sockWorkerOptions); + EXPECT_EQ(NN_OK, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockInitializeFailedWithExpand) +{ + auto connFd = SockConnect(9981); + SockOptions option {}; + option.receiveBufSizeKB = 0; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-8", newSockId, connFd, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + sockWorkerOptions.keepaliveIdleTime = -1; + sockWorkerOptions.keepaliveProbeInterval = -1; + sockWorkerOptions.keepaliveProbeTimes = -1; + SResult result; + result = sock->Initialize(sockWorkerOptions); + EXPECT_EQ(SS_TCP_SET_OPTION_FAILED, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockInitializeFailedWithNoDelay) +{ + auto connFd = SockConnect(9981); + SockOptions option{}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-9", newSockId, connFd, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions{}; + MOCKER(setsockopt) + .defaults() + .will(returnValue(0)) + .then(returnValue(0)) + .then(returnValue(0)) + .then(returnValue(0)) + .then(returnValue(-1)); + SResult result = sock->Initialize(sockWorkerOptions); + EXPECT_EQ(SS_TCP_SET_OPTION_FAILED, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockSendSuccess) +{ + auto connFd = SockConnect(9981); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-10", newSockId, connFd, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + sockWorkerOptions.tcpEnableNoDelay = false; + SResult result; + sock->Initialize(sockWorkerOptions); + std::string payload = "hello world"; + void *tmpBuf = const_cast(payload.c_str()); + result = sock->Send(tmpBuf, payload.length()); + EXPECT_EQ(NN_OK, result); + char receiveBuf[payload.length() + 1]; + bzero(receiveBuf, payload.length() + 1); + void *buff = reinterpret_cast(receiveBuf); + result = sock->Receive(buff, payload.length()); + std::string receivePayload = reinterpret_cast(receiveBuf); + EXPECT_EQ(SS_OK, result); + EXPECT_EQ(payload, receivePayload); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockSendFailedWithInvalidFdBuf) +{ + auto connFd = SockConnect(9981); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-11", newSockId, connFd, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + SResult result; + result = sock->Initialize(sockWorkerOptions); + EXPECT_EQ(SS_OK, result); + void *tmpBuf1 = nullptr; + result = sock->Send(tmpBuf1, 0); + EXPECT_EQ(SS_PARAM_INVALID, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockReceiveFailedWithInvalidFdBuf) +{ + auto connFd = SockConnect(9981); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-12", newSockId, connFd, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + SResult result; + sock->Initialize(sockWorkerOptions); + std::string payload = "hello world"; + void *tmpBuf = const_cast(payload.c_str()); + sock->Send(tmpBuf, payload.length()); + void *receiveBuf1 = nullptr; + result = sock->Receive(receiveBuf1, payload.length()); + EXPECT_EQ(SS_PARAM_INVALID, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockReceiveFailedWithInvalidSize2) +{ + auto connFd = SockConnect(9981); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-14", newSockId, connFd, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + sock->Initialize(sockWorkerOptions); + std::string payload = "hello world"; + void *tmpBuf = const_cast(payload.c_str()); + sock->Send(tmpBuf, payload.length()); + auto receiveBuf = memalign(NN_NO1024, payload.length()); + MOCKER(::recv).defaults().will(returnValue(payload.length() - 1)); + SResult result = sock->Receive(receiveBuf, payload.length()); + EXPECT_EQ(SS_SOCK_DATA_SIZE_UN_MATCHED, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockPostSendSglFailed) +{ + auto connFd = SockConnect(9982); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-15", newSockId, connFd, option); + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + SResult result; + sock->Initialize(sockWorkerOptions); + UBSHcomNetTransSglRequest req(iovPtr, 0, 0); + + UBSHcomNetTransHeader header {}; + header.immData = 1; + header.seqNo = 1; + header.flags = NTH_TWO_SIDE_SGL; + header.dataLength = 0; + for (uint16_t i = 0; i < req.iovCount; i++) { + header.dataLength += req.iov[i].size; + } + + result = sock->PostSendSgl(header, req); + EXPECT_EQ(SS_OK, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockPostSendInvalidAddress) +{ + auto connFd = SockConnect(9983); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-16", newSockId, connFd, option); + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + SResult result; + sock->Initialize(sockWorkerOptions); + std::string payload = "hello world"; + static char data[1023] = {}; + UBSHcomNetTransRequest req(0, sizeof(data), 0); + UBSHcomNetTransHeader header {}; + header.opCode = 1; + header.seqNo = 1; + header.flags = NTH_TWO_SIDE; + header.dataLength = req.size; + result = sock->PostSend(header, req); + EXPECT_EQ(SS_SOCK_SEND_FAILED, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockPostReceiveHeaderNormalReceive) +{ + static char data[1023] = {}; + UBSHcomNetTransRequest req(0, sizeof(data), 0); + SockTransHeader header {}; + header.dataLength = req.size; + auto connFd = SockConnect(9983); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "test", newSockId, connFd, option); + MOCKER_CPP(setsockopt).stubs().will(returnValue(0)); + MOCKER_CPP(::recv).stubs().will(returnValue(sizeof(SockTransHeader))); + MOCKER_CPP(NetFunc::ValidateHeader).stubs().will(returnValue(0)); + + SResult result = sock->PostReceiveHeader(header, 1); + EXPECT_EQ(SS_OK, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockPostReceiveHeaderMultipleReceives) +{ + static char data[1023] = {}; + UBSHcomNetTransRequest req(0, sizeof(data), 0); + SockTransHeader header {}; + header.dataLength = req.size; + auto connFd = SockConnect(9983); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "test", newSockId, connFd, option); + MOCKER_CPP(setsockopt).stubs().will(returnValue(0)); + MOCKER_CPP(::recv).stubs() + .will(returnValue(sizeof(SockTransHeader) / NN_NO2)); + MOCKER_CPP(NetFunc::ValidateHeader).stubs().will(returnValue(0)); + + SResult result = sock->PostReceiveHeader(header, 1); + EXPECT_EQ(SS_OK, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockPostReceiveBodyNormalReceive) +{ + void *buff = malloc(NN_NO1024); + uint32_t dataLength = NN_NO1024; + bool isOneSide = true; + auto connFd = SockConnect(9983); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "test", newSockId, connFd, option); + sock->mEnableTls = false; + MOCKER_CPP(::recv).stubs().will(returnValue(dataLength)); + + SResult result = sock->PostReceiveBody(buff, dataLength, isOneSide); + EXPECT_EQ(SS_OK, result); + free(buff); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockPostReceiveBodyMultipleReceives) +{ + void *buff = malloc(NN_NO1024); + uint32_t dataLength = NN_NO1024; + bool isOneSide = true; + auto connFd = SockConnect(9983); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "test", newSockId, connFd, option); + sock->mEnableTls = false; + MOCKER_CPP(::recv).stubs().will(returnValue(dataLength / NN_NO2)); + + SResult result = sock->PostReceiveBody(buff, dataLength, isOneSide); + EXPECT_EQ(SS_OK, result); + free(buff); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockPostReceiveBodyTlsNormalReceive) +{ + void *buff = malloc(NN_NO1024); + uint32_t dataLength = NN_NO1024; + bool isOneSide = false; + auto connFd = SockConnect(9983); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "test", newSockId, connFd, option); + sock->mEnableTls = true; + MOCKER_CPP(HcomSsl::SslRead).stubs().will(returnValue(static_cast(dataLength))); + + SResult result = sock->PostReceiveBody(buff, dataLength, isOneSide); + EXPECT_EQ(SS_OK, result); + free(buff); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockPostReceiveBodyTlsMultipleReceives) +{ + void *buff = malloc(NN_NO1024); + uint32_t dataLength = NN_NO1024; + bool isOneSide = false; + auto connFd = SockConnect(9983); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "test", newSockId, connFd, option); + sock->mEnableTls = true; + MOCKER_CPP(HcomSsl::SslRead).stubs().will(returnValue(static_cast(dataLength / NN_NO2))); + + SResult result = sock->PostReceiveBody(buff, dataLength, isOneSide); + EXPECT_EQ(SS_OK, result); + free(buff); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockPostReceiveHeaderSuccess) +{ + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + + UBSHcomNetDriverOptions sockOptions {}; + SetDriverOptions(sockOptions); + NResult result; + serverDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "sock-server-9000", true); + SetCB(serverDriver, true, 1); + serverDriver->OobIpAndPort(BASE_IP, 9990); + serverDriver->Initialize(sockOptions); + serverDriver->Start(); + + clientDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "sock-client-9000", false); + clientDriver->OobIpAndPort(BASE_IP, 9990); + sockOptions.dontStartWorkers = true; + clientDriver->Initialize(sockOptions); + clientDriver->Start(); + result = clientDriver->Connect("hello world", clientEp, NET_EP_SELF_POLLING); + EXPECT_EQ(SS_OK, result); + + static char data[100] = {}; + UBSHcomNetResponseContext respCtx {}; + UBSHcomNetTransRequest req((void *)(data), sizeof(data), 0); + result = clientEp->PostSend(1, req); + EXPECT_EQ(SS_OK, result); + result = clientEp->Receive(2, respCtx); + EXPECT_EQ(SS_OK, result); + + CloseDriver(serverDriver); + CloseDriver(clientDriver); +} + +TEST_F(TestSockWrapper, SockPostWriteSglSuccess) +{ + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + + UBSHcomNetDriverOptions sockOptions {}; + SetDriverOptions(sockOptions); + NResult result; + + sem_init(&sem_sock, 0, 0); + serverDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "sock-server-9992", true); + SetCB(serverDriver, true, 3); + serverDriver->OobIpAndPort(BASE_IP, 9992); + serverDriver->Initialize(sockOptions); + result = serverDriver->Start(); + EXPECT_EQ(result, SS_OK); + std::vector serverMrs; + CreateServerMR(serverDriver, serverMrs); + + clientDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "sock-client-9992", false); + SetCB(clientDriver, false, 2); + clientDriver->OobIpAndPort(BASE_IP, 9992); + clientDriver->Initialize(sockOptions); + clientDriver->Start(); + clientDriver->Connect("hello world", clientEp, 0); + std::vector clientMrs; + CreateClientMR(clientDriver, clientMrs); + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + if ((result = clientEp->PostSend(1, req)) != 0) { + NN_LOG_INFO("failed to post message to data to server"); + } + sem_wait(&sem_sock); + + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = localMrInfo[i].lAddress; + iov[i].rAddress = remoteMrInfo[i].lAddress; + iov[i].lKey = localMrInfo[i].lKey; + iov[i].rKey = remoteMrInfo[i].lKey; + iov[i].size = NN_NO16; + } + UBSHcomNetTransSglRequest reqRead(iov, NN_NO4, 0); + result = clientEp->PostRead(reqRead); + sem_wait(&sem_sock); + EXPECT_EQ(SS_OK, result); + UBSHcomNetTransSglRequest reqWrite(iov, NN_NO4, 0); + result = clientEp->PostWrite(reqWrite); + sem_wait(&sem_sock); + EXPECT_EQ(SS_OK, result); + + SockWrapperDestoryMem(serverDriver, serverMrs); + SockWrapperDestoryMem(clientDriver, clientMrs); + CloseDriver(serverDriver); + CloseDriver(clientDriver); +} + +TEST_F(TestSockWrapper, SockPostWriteSglFailedRead) +{ + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + + NResult result; + UBSHcomNetDriverOptions sockOptions {}; + SetDriverOptions(sockOptions); + sem_init(&sem_sock, 0, 0); + + serverDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "sock-server-9993", true); + SetCB(serverDriver, true, 3); + serverDriver->OobIpAndPort(BASE_IP, 9993); + serverDriver->Initialize(sockOptions); + serverDriver->Start(); + std::vector serverMrs; + CreateServerMR(serverDriver, serverMrs); + + clientDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "sock-client-9993", false); + SetCB(clientDriver, false, 2); + clientDriver->OobIpAndPort(BASE_IP, 9993); + clientDriver->Initialize(sockOptions); + clientDriver->Start(); + clientDriver->Connect("hello world", clientEp, 0); + UBSHcomEpOptions epOptions {}; + epOptions.tcpBlockingIo = true; + clientEp->SetEpOption(epOptions); + std::vector clientMrs; + CreateClientMR(clientDriver, clientMrs); + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + if ((result = clientEp->PostSend(1, req)) != 0) { + NN_LOG_INFO("failed to post message to data to server"); + } + sem_wait(&sem_sock); + + ssize_t res = -1; + MOCKER(writev).defaults().will(returnValue(res)); + + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = localMrInfo[i].lAddress; + iov[i].rAddress = remoteMrInfo[i].lAddress; + iov[i].lKey = localMrInfo[i].lKey; + iov[i].rKey = remoteMrInfo[i].lKey; + iov[i].size = NN_NO16; + } + UBSHcomNetTransSglRequest reqRead(iov, NN_NO4, 0); + result = clientEp->PostRead(reqRead); + EXPECT_EQ(SS_TIMEOUT, result); + + SockWrapperDestoryMem(serverDriver, serverMrs); + SockWrapperDestoryMem(clientDriver, clientMrs); + CloseDriver(serverDriver); + CloseDriver(clientDriver); +} + +TEST_F(TestSockWrapper, SockInitializeFailedWithInvalidAliveTime) +{ + auto connFd = SockConnect(9984); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-17", newSockId, connFd, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + sockWorkerOptions.keepaliveIdleTime = -1; + sockWorkerOptions.keepaliveProbeInterval = -1; + sockWorkerOptions.keepaliveProbeTimes = -1; + SResult result; + result = sock->Initialize(sockWorkerOptions); + EXPECT_EQ(SS_TCP_SET_OPTION_FAILED, result); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestSockWrapper, SockPostWriteSglFailedWrite) +{ + UBSHcomNetDriver *serverDriver = nullptr; + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + + NResult result; + UBSHcomNetDriverOptions sockOptions {}; + SetDriverOptions(sockOptions); + + sem_init(&sem_sock, 0, 0); + serverDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "sock-server-9994", true); + SetCB(serverDriver, true, 3); + serverDriver->OobIpAndPort(BASE_IP, 9994); + serverDriver->Initialize(sockOptions); + serverDriver->Start(); + std::vector serverMrs; + CreateServerMR(serverDriver, serverMrs); + + clientDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "sock-client-9994", false); + SetCB(clientDriver, false, 2); + clientDriver->OobIpAndPort(BASE_IP, 9994); + clientDriver->Initialize(sockOptions); + clientDriver->Start(); + clientDriver->Connect("hello world", clientEp, 0); + UBSHcomEpOptions epOptions {}; + epOptions.tcpBlockingIo = true; + clientEp->SetEpOption(epOptions); + std::vector clientMrs; + CreateClientMR(clientDriver, clientMrs); + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + if ((result = clientEp->PostSend(1, req)) != 0) { + NN_LOG_INFO("failed to post message to data to server"); + } + sem_wait(&sem_sock); + + for (uint16_t i = 0; i < NN_NO4; i++) { + iov[i].lAddress = localMrInfo[i].lAddress; + iov[i].rAddress = remoteMrInfo[i].lAddress; + iov[i].lKey = localMrInfo[i].lKey; + iov[i].rKey = remoteMrInfo[i].lKey; + iov[i].size = NN_NO16; + } + UBSHcomNetTransSglRequest reqRead(iov, NN_NO4, 0); + result = clientEp->PostRead(reqRead); + sem_wait(&sem_sock); + EXPECT_EQ(SS_OK, result); + + SockWrapperDestoryMem(serverDriver, serverMrs); + SockWrapperDestoryMem(clientDriver, clientMrs); + CloseDriver(serverDriver); + CloseDriver(clientDriver); +} + +TEST_F(TestSockWrapper, SockSendFail) +{ + auto connFd = SockConnect(9981); + SockOptions option {}; + uint64_t newSockId = NetUuid::GenerateUuid(); + auto sock = new (std::nothrow) Sock(SockType::SOCK_TCP, "sock-18", newSockId, connFd, option); + + NetLocalAutoDecreasePtr autoDecSock(sock); + SockWorkerOptions sockWorkerOptions {}; + SResult result; + sock->Initialize(sockWorkerOptions); + std::string payload = "hello world"; + void *tmpBuf = const_cast(payload.c_str()); + ssize_t res = -1; + MOCKER(::send).defaults().will(returnValue(res)); + result = sock->SendRealConnHeader(connFd, tmpBuf, payload.length()); + EXPECT_EQ(SS_SOCK_SEND_FAILED, result); + NetFunc::NN_SafeCloseFd(connFd); +} diff --git a/test/llt/testcase/transport/sock/test_sock_wrapper.h b/test/llt/testcase/transport/sock/test_sock_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..6a97c293505fdb599f97b4ccf8f447d789483154 --- /dev/null +++ b/test/llt/testcase/transport/sock/test_sock_wrapper.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HCOM_TEST_SOCK_WRAPPER_H +#define HCOM_TEST_SOCK_WRAPPER_H +#include +#include + +class TestSockWrapper : public testing::Test { +public: + TestSockWrapper(); + virtual void SetUp(void); + virtual void TearDown(void); + +protected: +}; +#endif // HCOM_TEST_SOCK_WRAPPER_H \ No newline at end of file diff --git a/test/llt/testcase/transport/test_load_balance.cpp b/test/llt/testcase/transport/test_load_balance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..025067bff39c9c4bd2de52ce6e800aadbfbc4762 --- /dev/null +++ b/test/llt/testcase/transport/test_load_balance.cpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "test_load_balance.h" +#include "ut_helper.h" +#include "hcom_def.h" +#include "mockcpp/mockcpp.hpp" + +using namespace ock::hcom; +TestLoadBalance::TestLoadBalance() {} +uint16_t TestLoadBalance::basePort = 8899; +void TestLoadBalance::SetUp() +{ + MOCK_VERSION +} + +void TestLoadBalance::TearDown() {} + +TEST_F(TestLoadBalance, OK) +{ + sem_t sem; + bool result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriver *server = nullptr, *client = nullptr; + std::unordered_map semMap; + Handlers handlers {}; + handlers.receivedHandler = [&](const UBSHcomNetRequestContext &ctx) -> int { + sem_post(&sem); + return 0; + }; + UBSHcomNetDriverOptions options {}; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + options.mrSendReceiveSegSize = 1024; + options.mrSendReceiveSegCount = 8192; + options.lbPolicy = ock::hcom::NET_HASH_IP_PORT; + strcpy(options.workerGroups, "1,3,3"); + strcpy(options.workerGroupsCpuSet, "10-10,11-13,na"); + options.SetNetDeviceIpMask(IP_SEG); + result = UTHelper::ServerCreateDriver(server, handlers, options, ++basePort); + ASSERT_EQ(result, true); + result = UTHelper::ClientCreateDriver(client, handlers, options, basePort); + ASSERT_EQ(result, true); + result = UTHelper::ClientConnect(client, ep, 0); + ASSERT_EQ(result, true); + result = UTHelper::ClientSend(ep, &sem); + ASSERT_EQ(result, true); +} + +TEST_F(TestLoadBalance, WrongGroups) +{ + sem_t sem; + bool result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriver *server = nullptr, *client = nullptr; + std::unordered_map semMap; + Handlers handlers {}; + handlers.receivedHandler = [&](const UBSHcomNetRequestContext &ctx) -> int { + sem_post(&sem); + return 0; + }; + UBSHcomNetDriverOptions options {}; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + options.mrSendReceiveSegSize = 1024; + options.mrSendReceiveSegCount = 8192; + options.lbPolicy = ock::hcom::NET_HASH_IP_PORT; + strcpy(options.workerGroups, "1,3"); + strcpy(options.workerGroupsCpuSet, "10-11,12-20,12-20,12-20,1-20"); + options.SetNetDeviceIpMask(IP_SEG); + result = UTHelper::ServerCreateDriver(server, handlers, options, ++basePort); + EXPECT_EQ(result, false); + result = UTHelper::ClientCreateDriver(client, handlers, options, basePort); + EXPECT_EQ(result, false); +} + +TEST_F(TestLoadBalance, WrongPolicy) +{ + sem_t sem; + bool result; + UBSHcomNetEndpointPtr ep = nullptr; + UBSHcomNetDriver *server = nullptr, *client = nullptr; + std::unordered_map semMap; + Handlers handlers {}; + handlers.receivedHandler = [&](const UBSHcomNetRequestContext &ctx) -> int { + sem_post(&sem); + return 0; + }; + UBSHcomNetDriverOptions options {}; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + options.mrSendReceiveSegSize = 1024; + options.mrSendReceiveSegCount = 8192; + options.lbPolicy = (UBSHcomNetDriverLBPolicy)3; + options.SetNetDeviceIpMask(IP_SEG); + // wrong policy check will fail and result is false + result = UTHelper::ClientCreateDriver(client, handlers, options, ++basePort); + ASSERT_EQ(result, false); + result = UTHelper::ServerCreateDriver(server, handlers, options, basePort); + ASSERT_EQ(result, false); + result = UTHelper::ClientConnect(client, ep, 0); + EXPECT_EQ(result, false); +} \ No newline at end of file diff --git a/test/llt/testcase/transport/test_load_balance.h b/test/llt/testcase/transport/test_load_balance.h new file mode 100644 index 0000000000000000000000000000000000000000..cc99456e1f30145d58f864a7c9aa5fbd6e615c02 --- /dev/null +++ b/test/llt/testcase/transport/test_load_balance.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TEST_LOAD_BALANCE_H +#define HCOM_TEST_LOAD_BALANCE_H + +#include + +class TestLoadBalance : public testing::Test { +public: + TestLoadBalance(); + virtual void SetUp(void); + virtual void TearDown(void); + static uint16_t basePort; +}; + +#endif // HCOM_TEST_LOAD_BALANCE_H diff --git a/test/llt/testcase/transport/test_memory_region.cpp b/test/llt/testcase/transport/test_memory_region.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c097014127bdeca4ca823c480a7e2dbfa8fdfdba --- /dev/null +++ b/test/llt/testcase/transport/test_memory_region.cpp @@ -0,0 +1,77 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "test_memory_region.h" +#include "ut_helper.h" +#include "hcom_def.h" +#include "transport/net_memory_region.h" +#include "mockcpp/mockcpp.hpp" + +using namespace ock::hcom; +TestMemoryRegion::TestMemoryRegion() {} +void TestMemoryRegion::SetUp() +{ + MOCK_VERSION +} + +void TestMemoryRegion::TearDown() {} + +TEST_F(TestMemoryRegion, OK) +{ + NResult result; + NormalMemoryRegion *mr; + result = NormalMemoryRegion::Create("mr", NN_NO4096, mr); + EXPECT_EQ(result, NN_OK); + EXPECT_EQ(mr->Size(), NN_NO4096); + result = mr->Initialize(); + EXPECT_EQ(result, NN_OK); + mr->UnInitialize(); + + void *extMem = malloc(sizeof(NN_NO4096)); + result = NormalMemoryRegion::Create("mr1", (uintptr_t)extMem, NN_NO4096, mr); + EXPECT_EQ(result, NN_OK); + EXPECT_EQ(mr->Size(), NN_NO4096); + EXPECT_EQ(mr->GetAddress(), (uintptr_t)extMem); + + NormalMemoryRegionFixedBuffer *fixedBuffer; + result = NormalMemoryRegionFixedBuffer::Create("mr", NN_NO4096, NN_NO8, fixedBuffer); + EXPECT_EQ(result, NN_OK); + result = fixedBuffer->Initialize(); + EXPECT_EQ(result, NN_OK); + EXPECT_EQ(fixedBuffer->Size(), NN_NO4096 * NN_NO8); + EXPECT_EQ(fixedBuffer->GetFreeBufferCount(), NN_NO8); + uintptr_t item; + auto ret = fixedBuffer->GetFreeBuffer(item); + EXPECT_EQ(ret, true); + EXPECT_EQ(fixedBuffer->GetFreeBufferCount(), NN_NO8 - 1); + uintptr_t items[7]; + uintptr_t *itemsPtr = &items[0]; + ret = fixedBuffer->GetFreeBufferN(itemsPtr, 7); + EXPECT_EQ(ret, true); + EXPECT_EQ(fixedBuffer->GetFreeBufferCount(), 0); + ret = fixedBuffer->GetFreeBuffer(item); + EXPECT_EQ(ret, false); + fixedBuffer->UnInitialize(); +} + +TEST_F(TestMemoryRegion, Fail) +{ + NResult result; + NormalMemoryRegion *mr = nullptr; + result = NormalMemoryRegion::Create("mr", 0, mr); + EXPECT_NE(result, NN_OK); + + NormalMemoryRegion *mr1 = (NormalMemoryRegion *)malloc(sizeof(NormalMemoryRegion)); + result = NormalMemoryRegion::Create("mr1", 0, NN_NO4096, mr1); + EXPECT_NE(result, NN_OK); + result = NormalMemoryRegion::Create("mr1", (uintptr_t)mr1, 0, mr1); + EXPECT_NE(result, NN_OK); +} \ No newline at end of file diff --git a/test/llt/testcase/transport/test_memory_region.h b/test/llt/testcase/transport/test_memory_region.h new file mode 100644 index 0000000000000000000000000000000000000000..8de510a9ae9dfa7f7432b95d1b04832d484fe2ef --- /dev/null +++ b/test/llt/testcase/transport/test_memory_region.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TEST_MEMORY_REGION_H +#define HCOM_TEST_MEMORY_REGION_H + +#include + +class TestMemoryRegion : public testing::Test { +public: + TestMemoryRegion(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +#endif // HCOM_TEST_MEMORY_REGION_H diff --git a/test/llt/testcase/transport/test_net_oob.cpp b/test/llt/testcase/transport/test_net_oob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3785fed845aca468b711df5739250d7b300e2dc --- /dev/null +++ b/test/llt/testcase/transport/test_net_oob.cpp @@ -0,0 +1,607 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "hcom.h" +#include "test_net_oob.h" +#include "transport/rdma/rdma_common.h" +#include +#include +#include + +namespace ock { +namespace hcom { + +#define BASE_IP "127.0.0.1" +#define IP_SEG "127.0.0.0/16" +char sendTemp[] = "SendSuccess"; + +int testNewConnectionHandler(ock::hcom::OOBTCPConnection &conn) +{ + char *sendBuff = sendTemp; + NResult ret = conn.Send(sendBuff, strlen(sendBuff)); + EXPECT_EQ(ret, ock::hcom::NN_OK); + return ret; +} + +int testNewConnectionHandlerEmpty(ock::hcom::OOBTCPConnection &conn) +{ + NResult ret = conn.Send(nullptr, 0); + EXPECT_EQ(ret, ock::hcom::NN_PARAM_INVALID); + return ret; +} + +int testNewConnectionHandlerFailure(ock::hcom::OOBTCPConnection &conn) +{ + return -1; +} + +std::string GetFilPrefixEnv() +{ + char path[255]; + getcwd(path, 255); + std::string pathStr = path; + std::string envString = "HCOM_FILE_PATH_PREFIX=" + pathStr; + return envString; +} + +TEST_F(TestNetOob, ConnectSuccess) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OK); + ock::hcom::OOBTCPConnection *conn = nullptr; + ock::hcom::OOBTCPClientPtr client = new (std::nothrow) ock::hcom::OOBTCPClient(BASE_IP, 9444); + NResult ret1 = client->Connect(BASE_IP, 9444, conn); + EXPECT_EQ(ret1, ock::hcom::NN_OK); +} + +TEST_F(TestNetOob, ConnectSuccessWithFailToAcceptSocket) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + MOCKER(::accept).defaults().will(returnValue(-1)); + MOCKER(::recv).defaults().will(returnValue(-1)); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OK); +} + +TEST_F(TestNetOob, SendSuccess) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OK); + ock::hcom::OOBTCPConnection *conn = nullptr; + ock::hcom::OOBTCPClientPtr client = new (std::nothrow) ock::hcom::OOBTCPClient(BASE_IP, 9444); + NResult ret1 = client->Connect(BASE_IP, 9444, conn); + ASSERT_EQ(ret1, ock::hcom::NN_OK); + char revTemp[1024]; + void *revBuff = (void *)revTemp; + NResult ret3 = conn->Receive(revBuff, strlen(sendTemp)); + EXPECT_EQ(ret3, ock::hcom::NN_OK); +} + +TEST_F(TestNetOob, SendFailureWithFailToHandshakeWithClient) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandlerFailure, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + ASSERT_EQ(ret, ock::hcom::NN_OK); + ock::hcom::OOBTCPConnection *conn = nullptr; + ock::hcom::OOBTCPClientPtr client = new (std::nothrow) ock::hcom::OOBTCPClient(BASE_IP, 9444); + NResult ret1 = client->Connect(BASE_IP, 9444, conn); + EXPECT_EQ(ret1, ock::hcom::NN_OK); + auto connFd = conn->TransferFd(); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestNetOob, SendFailureWithEmptyContent) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandlerEmpty, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + ASSERT_EQ(ret, ock::hcom::NN_OK); + ock::hcom::OOBTCPConnection *conn = nullptr; + ock::hcom::OOBTCPClientPtr client = new (std::nothrow) ock::hcom::OOBTCPClient(BASE_IP, 9444); + NResult ret1 = client->Connect(BASE_IP, 9444, conn); + EXPECT_EQ(ret1, ock::hcom::NN_OK); + auto connFd = conn->TransferFd(); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestNetOob, ReceiveFailureWithEmptyContent) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + ASSERT_EQ(ret, ock::hcom::NN_OK); + ock::hcom::OOBTCPConnection *conn = nullptr; + ock::hcom::OOBTCPClientPtr client = new (std::nothrow) ock::hcom::OOBTCPClient(BASE_IP, 9444); + NResult ret1 = client->Connect(BASE_IP, 9444, conn); + ASSERT_EQ(ret1, ock::hcom::NN_OK); + void *revBuff = nullptr; + NResult ret3 = conn->Receive(revBuff, strlen(sendTemp)); + EXPECT_EQ(ret3, ock::hcom::NN_PARAM_INVALID); + auto connFd = conn->TransferFd(); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestNetOob, ReceiveFailureWithUnmatchedSize) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + ASSERT_EQ(ret, ock::hcom::NN_OK); + ock::hcom::OOBTCPConnection *conn = nullptr; + ock::hcom::OOBTCPClientPtr client = new (std::nothrow) ock::hcom::OOBTCPClient(BASE_IP, 9444); + NResult ret1 = client->Connect(BASE_IP, 9444, conn); + ASSERT_EQ(ret1, ock::hcom::NN_OK); + char revTemp[1024]; + void *revBuff = (void *)revTemp; + NResult ret3 = conn->Receive(revBuff, strlen(sendTemp) + 1); + EXPECT_EQ(ret3, ock::hcom::NN_OOB_CONN_RECEIVE_ERROR); + auto connFd = conn->TransferFd(); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestNetOob, ConnectFailure) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + ASSERT_EQ(ret, ock::hcom::NN_OK); + ock::hcom::OOBTCPConnection *conn = nullptr; + ock::hcom::OOBTCPClientPtr client = new (std::nothrow) ock::hcom::OOBTCPClient(BASE_IP, 9444); + MOCKER(::connect).defaults().will(returnValue(-1)); + NResult ret1 = client->Connect(BASE_IP, 9444, conn); + EXPECT_EQ(ret1, ock::hcom::NN_OOB_CLIENT_SOCKET_ERROR); +} + +TEST_F(TestNetOob, ConnectFailureWithFailToCreateSocketInClient) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + ASSERT_EQ(ret, ock::hcom::NN_OK); + ock::hcom::OOBTCPConnection *conn = nullptr; + ock::hcom::OOBTCPClientPtr client = new (std::nothrow) ock::hcom::OOBTCPClient(BASE_IP, 9444); + MOCKER(::socket).defaults().will(returnValue(-1)); + NResult ret1 = client->Connect(BASE_IP, 9444, conn); + EXPECT_EQ(ret1, ock::hcom::NN_OOB_CLIENT_SOCKET_ERROR); +} + +TEST_F(TestNetOob, StartSuccess) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OK); +} + +TEST_F(TestNetOob, StartSuccessWithFailToSetThreadNameOfOobServer) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + MOCKER(pthread_setname_np).defaults().will(returnValue(-1)); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OK); +} + +TEST_F(TestNetOob, StartFailureWithFailedToSetNewConnectionCallBack) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OOB_CONN_CB_NOT_SET); +} + +TEST_F(TestNetOob, StartFailureWithFailedToSetLoadBalancer) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_INVALID_PARAM); +} + +TEST_F(TestNetOob, StartFailureWithInvalidOobType) +{ + ock::hcom::OOBTCPServer oobServer("", 0); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_INVALID_PARAM); +} + +TEST_F(TestNetOob, StartFailureWithFailedToCreateListenSocket) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + MOCKER(::socket).defaults().will(returnValue(-1)); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OOB_LISTEN_SOCKET_ERROR); +} + +TEST_F(TestNetOob, StartFailureWithFailedToSetOption) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + MOCKER(::setsockopt).defaults().will(returnValue(-1)); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OOB_LISTEN_SOCKET_ERROR); +} + +TEST_F(TestNetOob, StartFailureWithFailedToBind) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + MOCKER(::bind).defaults().will(returnValue(-1)); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OOB_LISTEN_SOCKET_ERROR); +} + +TEST_F(TestNetOob, StartFailureWithFailedToListen) +{ + ock::hcom::OOBTCPServer oobServer(BASE_IP, 9444); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + MOCKER(::listen).defaults().will(returnValue(-1)); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OOB_LISTEN_SOCKET_ERROR); +} + +TEST_F(TestNetOob, StartForUdsFailureWithEmptyFilePath) +{ + ock::hcom::OOBTCPServer oobServer(ock::hcom::NET_OOB_UDS, "", 640); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_INVALID_PARAM); +} + +TEST_F(TestNetOob, StartForUdsFailureWithLongFileName) +{ + ock::hcom::OOBTCPServer oobServer(ock::hcom::NET_OOB_UDS, + "ThisFileNameIs107InLengthThisFileNameIs107InLengthThisFileNameIs107InLengthThisFileNameIs107InLengthXXXXXXX", + 0); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_INVALID_PARAM); +} + +TEST_F(TestNetOob, StartForUdsSuccessWithAbstractPath) +{ + ock::hcom::OOBTCPServer oobServer(ock::hcom::NET_OOB_UDS, "server.socket", 0); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OK); +} + +TEST_F(TestNetOob, StartForUdsFailureWithInvalidPath) +{ + ock::hcom::OOBTCPServer oobServer(ock::hcom::NET_OOB_UDS, "/xxx/xxx/fake.socket", 640); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_INVALID_PARAM); +} + +TEST_F(TestNetOob, StartForUdsSuccess) +{ + std::string envString = GetFilPrefixEnv(); + ::putenv(const_cast(envString.c_str())); + + ock::hcom::OOBTCPServer oobServer(ock::hcom::NET_OOB_UDS, testFile, 640); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OK); +} + +TEST_F(TestNetOob, StartForUdsFailureWithSlashAbstractPath) +{ + ock::hcom::OOBTCPServer oobServer(ock::hcom::NET_OOB_UDS, "/fake.socket", 0); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_INVALID_PARAM); +} + +TEST_F(TestNetOob, StartForUdsSuccessWithFailToSetThreadNameOfOobServer) +{ + std::string envString = GetFilPrefixEnv(); + ::putenv(const_cast(envString.c_str())); + + ock::hcom::OOBTCPServer oobServer(ock::hcom::NET_OOB_UDS, testFile, 640); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + MOCKER(pthread_setname_np).defaults().will(returnValue(-1)); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OK); +} + +TEST_F(TestNetOob, StartForUdsFailureWithFailToUnlinkFile) +{ + std::ofstream file(testFile); + file.close(); + + std::string envString = GetFilPrefixEnv(); + ::putenv(const_cast(envString.c_str())); + + ock::hcom::OOBTCPServer oobServer(ock::hcom::NET_OOB_UDS, testFile, 640); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + MOCKER(unlink).defaults().will(returnValue(-1)); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_INVALID_PARAM); +} + +TEST_F(TestNetOob, StartForUdsFailureWithFailToCreateListenSocket) +{ + std::string envString = GetFilPrefixEnv(); + ::putenv(const_cast(envString.c_str())); + + ock::hcom::OOBTCPServer oobServer(ock::hcom::NET_OOB_UDS, testFile, 640); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + MOCKER(::socket).defaults().will(returnValue(-1)); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OOB_LISTEN_SOCKET_ERROR); +} + +TEST_F(TestNetOob, StartForUdsFailureWithFailToBind) +{ + std::string envString = GetFilPrefixEnv(); + ::putenv(const_cast(envString.c_str())); + + ock::hcom::OOBTCPServer oobServer(ock::hcom::NET_OOB_UDS, testFile, 640); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + MOCKER(::bind).defaults().will(returnValue(-1)); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OOB_LISTEN_SOCKET_ERROR); +} + +TEST_F(TestNetOob, StartForUdsFailureWithFailToListen) +{ + std::string envString = GetFilPrefixEnv(); + ::putenv(const_cast(envString.c_str())); + + ock::hcom::OOBTCPServer oobServer(ock::hcom::NET_OOB_UDS, testFile, 640); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + MOCKER(::listen).defaults().will(returnValue(-1)); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OOB_LISTEN_SOCKET_ERROR); +} + +TEST_F(TestNetOob, ConnectForUdsSuccessWithAbstractPath) +{ + ock::hcom::OOBTCPServer oobServer(ock::hcom::NET_OOB_UDS, "server.socket", 0); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + EXPECT_EQ(ret, ock::hcom::NN_OK); + ock::hcom::OOBTCPConnection *conn = nullptr; + ock::hcom::OOBTCPClientPtr client = + new (std::nothrow) ock::hcom::OOBTCPClient(ock::hcom::NET_OOB_UDS, "client.socket", 0); + NResult ret1 = client->Connect("server.socket", conn); + EXPECT_EQ(ret1, ock::hcom::NN_OK); + auto connFd = conn->TransferFd(); + NetFunc::NN_SafeCloseFd(connFd); +} + +TEST_F(TestNetOob, ConnectForUdsFailureWithEmptyPath) +{ + ock::hcom::OOBTCPServer oobServer(ock::hcom::NET_OOB_UDS, "server.socket", 0); + oobServer.SetNewConnCB(std::bind(&testNewConnectionHandler, std::placeholders::_1)); + ock::hcom::UBSHcomNetDriverOptions mOptions; + auto *lb = new (std::nothrow) ock::hcom::NetWorkerLB("mName", mOptions.lbPolicy, UINT16_MAX); + oobServer.SetWorkerLb(lb); + NResult ret = oobServer.Start(); + ASSERT_EQ(ret, ock::hcom::NN_OK); + ock::hcom::OOBTCPConnection *conn = nullptr; + ock::hcom::OOBTCPClientPtr client = + new (std::nothrow) ock::hcom::OOBTCPClient(ock::hcom::NET_OOB_UDS, "client.socket", 0); + NResult ret1 = client->Connect("", conn); + EXPECT_EQ(ret1, ock::hcom::NN_OOB_CLIENT_SOCKET_ERROR); +} + +/* ep id UT */ +static UBSHcomNetEndpointPtr serverEp = nullptr; +static std::string udsName = "server-conn-epId-"; +static std::atomic_uint64_t driverIndex(0); +static sem_t sem; + +static int NewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + serverEp = newEP; + sem_post(&sem); + return 0; +} +static void EndPointBroken(const UBSHcomNetEndpointPtr &ep) +{ + NN_LOG_INFO("end point " << ep->Id() << " broken"); +} +static bool CreateDriver(UBSHcomNetDriver *&driver, uint16_t port, bool isServer, NetDriverOobType oobType) +{ + std::string driverName = ""; + if (isServer) { + driverName = "server-epId-"; + } else { + driverName = "client-epId-"; + } + + if (port > 0) { + driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, driverName + std::to_string(port), isServer); + driver->OobIpAndPort(BASE_IP, 10000); + } else { + driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::SHM, + driverName + std::to_string(driverIndex.fetch_add(1)), isServer); + } + + if (oobType == NET_OOB_UDS && isServer) { + UBSHcomNetOobUDSListenerOptions listenOpt; + listenOpt.Name(udsName); + listenOpt.perm = 0; + driver->AddOobUdsOptions(listenOpt); + } + + UBSHcomNetDriverOptions options {}; + options.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + options.oobType = oobType; + options.SetNetDeviceIpMask(IP_SEG); + options.pollingBatchSize = 16; + options.SetWorkerGroups("1"); + options.SetWorkerGroupsCpuSet("12-12"); + options.enableTls = false; + + if (isServer) { + driver->RegisterNewEPHandler( + std::bind(&NewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + } + driver->Initialize(options); + driver->Start(); + return true; +} + +static void CloseDriver(UBSHcomNetDriver *&driver) +{ + if (driver->IsStarted()) { + driver->Stop(); + driver->UnInitialize(); + } +} + +TEST_F(TestNetOob, ConnectShmEpId) +{ + MOCKER_CPP(&UBSHcomNetDriver::ValidateHandlesCheck).stubs().will(returnValue(static_cast(SER_OK))); + int count = 40; + UBSHcomNetDriver *serverDriver = nullptr; + CreateDriver(serverDriver, 0, true, NET_OOB_UDS); + + std::unordered_set set; + std::mutex locker; + std::vector ths; + std::atomic_uint16_t cnt(0); + + for (int i = 0; i < count; ++i) { + std::thread th([&]() { + auto index = cnt.fetch_add(1); + UBSHcomNetDriver *clientDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + CreateDriver(clientDriver, 0, false, NET_OOB_UDS); + locker.lock(); + if (index & 1) { + clientDriver->Connect(udsName, 0, "hello world", clientEp, 0); + } else { + clientDriver->Connect(udsName, 0, "hello world", clientEp, NET_EP_SELF_POLLING); + } + if (serverEp->Id() == clientEp->Id()) { + set.insert(clientEp->Id()); + } + locker.unlock(); + CloseDriver(clientDriver); + }); + ths.push_back(std::move(th)); + } + for (int i = 0; i < count; ++i) { + ths[i].join(); + } + EXPECT_EQ(set.size(), count); + CloseDriver(serverDriver); + if (serverEp.Get() != nullptr) { + serverEp.Set(nullptr); + } +} + +TEST_F(TestNetOob, TestClientTcpConnect) +{ + std::string oobIp = "255.255.255.255"; + OOBTCPClientPtr client = new (std::nothrow) OOBTCPClient(NET_OOB_TCP, oobIp, NN_NO8192); + OOBTCPConnection *conn = nullptr; + int result = client->Connect(conn); + EXPECT_EQ(NN_INVALID_IP, result); +} +} +} \ No newline at end of file diff --git a/test/llt/testcase/transport/test_net_oob.h b/test/llt/testcase/transport/test_net_oob.h new file mode 100644 index 0000000000000000000000000000000000000000..a48dd575cec02aad2a3800533d46fbeebc909498 --- /dev/null +++ b/test/llt/testcase/transport/test_net_oob.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TEST_NET_OOB_H +#define HCOM_TEST_NET_OOB_H + +#include +#include + +#include "gtest/gtest.h" +#include "mockcpp/mokc.h" +#include "transport/net_oob.h" + +class TestNetOob : public testing ::Test { +public: + TestNetOob() {}; + + ~TestNetOob() {}; + + void SetUp() + { + char path[255]; + getcwd(path, 255); + std::string pathStr = path; + testFile = pathStr + "/test.socket"; + + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); + }; + + void TearDown() + { + GlobalMockObject::verify(); + }; + +protected: + std::string testFile; +}; + +#endif // HCOM_TEST_NET_OOB_H diff --git a/test/llt/testcase/transport/test_secure.cpp b/test/llt/testcase/transport/test_secure.cpp new file mode 100644 index 0000000000000000000000000000000000000000..05c39a8b6b2d9259765f200ec5eb5bd14212bdcd --- /dev/null +++ b/test/llt/testcase/transport/test_secure.cpp @@ -0,0 +1,878 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "test_secure.h" +#include "hcom.h" +#include "net_sock_common.h" + +using namespace ock::hcom; + +#define BASE_IP "127.0.0.1" +#define IP_SEG "127.0.0.0/16" +int ipPort = 6550; +static UBSHcomNetEndpointPtr serverEp; + +static int NewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload); + serverEp = newEP; + return 0; +} +static void EndPointBroken(const UBSHcomNetEndpointPtr &ep) {} +static int RequestReceived(const UBSHcomNetRequestContext &ctx) +{ + return 0; +} +static int RequestPosted(const UBSHcomNetRequestContext &ctx) +{ + return 0; +} +static int OneSideDone(const UBSHcomNetRequestContext &ctx) +{ + return 0; +} + +// one way, provider registered, return valid +static int SecInfoProviderValidOne(uint64_t ctx, int64_t &flag, UBSHcomNetDriverSecType &type, char *&output, + uint32_t &outLen, bool &needAutoFree) +{ + const char *kToken = "6G5NXCPJZB-" + "eyJsaWNlbnNlSWQiOiI2RzVOWENQSlpCIiwibGljZW5zZWVOYW1lIjoic2lnbnVwIHNjb290ZXIiLCJhc3NpZ25lZU5hbWUiOiIiLCJhc3NpZ2" + "5lZUVtYWlsIjoiIiwibGljZW5zZVJlc3RyaWN0aW9uIjoiIiwiY2hlY2tDb25jdXJyZW50VXNlIjpmYWxzZSwicHJvZHVjdHMiOlt7ImNvZGUi" + "OiJQU0kiLCJmYWxsYmFja0RhdGUiOiIyMDI1LTA4LTAxIiwicGFpZFVwVG8iOiIyMDI1LTA4LTAxIiwiZXh0ZW5kZWQiOnRydWV9LHsiY29kZS" + "I6IlBEQiIsImZhbGxiYWNrRGF0ZSI6IjIwMjUtMDgtMDEiLCJwYWlkVXBUbyI6IjIwMjUtMDgtMDEiLCJleHRlbmRlZCI6dHJ1ZX0seyJjb2Rl" + "IjoiSUkiLCJmYWxsYmFja0RhdGUiOiIyMDI1LTA4LTAxIiwicGFpZFVwVG8iOiIyMDI1LTA4LTAxIiwiZXh0ZW5kZWQiOmZhbHNlfSx7ImNvZG" + "UiOiJQUEMiLCJmYWxsYmFja0RhdGUiOiIyMDI1LTA4LTAxIiwicGFpZFVwVG8iOiIyMDI1LTA4LTAxIiwiZXh0ZW5kZWQiOnRydWV9LHsiY29k" + "ZSI6IlBHTyIsImZhbGxiYWNrRGF0ZSI6IjIwMjUtMDgtMDEiLCJwYWlkVXBUbyI6IjIwMjUtMDgtMDEiLCJleHRlbmRlZCI6dHJ1ZX0seyJjb2" + "RlIjoiUFNXIiwiZmFsbGJhY2tEYXRlIjoiMjAyNS0wOC0wMSIsInBhaWRVcFRvIjoiMjAyNS0wOC0wMSIsImV4dGVuZGVkIjp0cnVlfSx7ImNv" + "ZGUiOiJQV1MiLCJmYWxsYmFja0RhdGUiOiIyMDI1LTA4LTAxIiwicGFpZFVwVG8iOiIyMDI1LTA4LTAxIiwiZXh0ZW5kZWQiOnRydWV9LHsiY2" + "9kZSI6IlBQUyIsImZhbGxiYWNrRGF0ZSI6IjIwMjUtMDgtMDEiLCJwYWlkVXBUbyI6IjIwMjUtMDgtMDEiLCJleHRlbmRlZCI6dHJ1ZX0seyJj" + "b2RlIjoiUFJCIiwiZmFsbGJhY2tEYXRlIjoiMjAyNS0wOC0wMSIsInBhaWRVcFRvIjoiMjAyNS0wOC0wMSIsImV4dGVuZGVkIjp0cnVlfSx7Im" + "NvZGUiOiJQQ1dNUCIsImZhbGxiYWNrRGF0ZSI6IjIwMjUtMDgtMDEiLCJwYWlkVXBUbyI6IjIwMjUtMDgtMDEiLCJleHRlbmRlZCI6dHJ1ZX1d" + "LCJtZXRhZGF0YSI6IjAxMjAyMjA5MDJQU0FOMDAwMDA1IiwiaGFzaCI6IlRSSUFMOi0xMDc4MzkwNTY4IiwiZ3JhY2VQZXJpb2REYXlzIjo3LC" + "JhdXRvUHJvbG9uZ2F0ZWQiOmZhbHNlLCJpc0F1dG9Qcm9sb25nYXRlZCI6ZmFsc2V9-SnRVlQQR1/" + "9nxZ2AXsQ0seYwU5OjaiUMXrnQIIdNRvykzqQ0Q+" + "vjXlmO7iAUwhwlsyfoMrLuvmLYwoD7fV8Mpz9Gs2gsTR8DfSHuAdvZlFENlIuFoIqyO8BneM9paD0yLxiqxy/" + "WWuOqW6c1v9ubbfdT6z9UnzSUjPKlsjXfq9J2gcDALrv9E0RPTOZqKfnsg7PF0wNQ0/d00dy1k3zI+zJyTRpDxkCaGgijlY/LZ/wqd/" + "kRfcbQuRzdJ/JXa3nj26rACqykKXaBH5thuvkTyySOpZwZMJVJyW7B7ro/" + "hkFCljZug3K+bTw5VwySzJtDcQ9tDYuu0zSAeXrcv2qrOg==-" + "MIIETDCCAjSgAwIBAgIBDTANBgkqhkiG9w0BAQsFADAYMRYwFAYDVQQDDA1KZXRQcm9maWxlIENBMB4XDTIwMTAxOTA5MDU1M1oXDTIyMTAyMT" + "A5MDU1M1owHzEdMBsGA1UEAwwUcHJvZDJ5LWZyb20tMjAyMDEwMTkwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCUlaUFc1wf+" + "CfY9wzFWEL2euKQ5nswqb57V8QZG7d7RoR6rwYUIXseTOAFq210oMEe++LCjzKDuqwDfsyhgDNTgZBPAaC4vUU2oy+XR+" + "Fq8nBixWIsH668HeOnRK6RRhsr0rJzRB95aZ3EAPzBuQ2qPaNGm17pAX0Rd6MPRgjp75IWwI9eA6aMEdPQEVN7uyOtM5zSsjoj79Lbu1fjShOn" + "QZuJcsV8tqnayeFkNzv2LTOlofU/Tbx502Ro073gGjoeRzNvrynAP03pL486P3KCAyiNPhDs2z8/COMrxRlZW5mfzo0xsK0dQGNH3UoG/" + "9RVwHG4eS8LFpMTR9oetHZBAgMBAAGjgZkwgZYwCQYDVR0TBAIwADAdBgNVHQ4EFgQUJNoRIpb1hUHAk0foMSNM9MCEAv8wSAYDVR0jBEEwP4A" + "Uo562SGdCEjZBvW3gubSgUouX8bOhHKQaMBgxFjAUBgNVBAMMDUpldFByb2ZpbGUgQ0GCCQDSbLGDsoN54TATBgNVHSUEDDAKBggrBgEFBQcDA" + "TALBgNVHQ8EBAMCBaAwDQYJKoZIhvcNAQELBQADggIBABqRoNGxAQct9dQUFK8xqhiZaYPd30TlmCmSAaGJ0eBpvkVeqA2jGYhAQRqFiAlFC63" + "JKvWvRZO1iRuWCEfUMkdqQ9VQPXziE/" + "BlsOIgrL6RlJfuFcEZ8TK3syIfIGQZNCxYhLLUuet2HE6LJYPQ5c0jH4kDooRpcVZ4rBxNwddpctUO2te9UU5/" + "FjhioZQsPvd92qOTsV+8Cyl2fvNhNKD1Uu9ff5AkVIQn4JU23ozdB/R5oUlebwaTE6WZNBs+TA/qPj+5/" + "we9NH71WRB0hqUoLI2AKKyiPw++FtN4Su1vsdDlrAzDj9ILjpjJKA1ImuVcG329/" + "WTYIKysZ1CWK3zATg9BeCUPAV1pQy8ToXOq+RSYen6winZ2OO93eyHv2Iw5kbn1dqfBw1BuTE29V2FJKicJSu8iEOpfoafwJISXmz1wnnWL3V/" + "0NxTulfWsXugOoLfv0ZIBP1xH9kmf22jjQ2JiHhQZP7ZDsreRrOeIQ/" + "c4yR8IQvMLfC0WKQqrHu5ZzXTH4NO3CwGWSlTY74kE91zXB5mwWAx1jig+UXYc2w4RkVhy0//lOmVya/" + "PEepuuTTI4+UJwC7qbVlh5zfhj8oTNUXgN0AOc+Q0/WFPl1aw5VV/VrO8FCoB15lFVlpKaQ1Yh+DVU8ke+rt9Th0BCHXe0uZOEmH0nOnH/" + "0onD"; + flag = 1; + output = const_cast(kToken); + outLen = strlen(kToken); + type = ock::hcom::NET_SEC_VALID_ONE_WAY; + needAutoFree = false; + NN_LOG_INFO("client auth info " << output << " len:" << outLen << " flag:" << flag << " sec type:" << + UBSHcomNetDriverSecTypeToString(type)); + return 0; +} + +// two way, provider registered, return valid +static int SecInfoProviderValidTwo(uint64_t ctx, int64_t &flag, UBSHcomNetDriverSecType &type, char *&output, + uint32_t &outLen, bool &needAutoFree) +{ + const char *kToken = "clientservertoken"; + flag = 1; + output = const_cast(kToken); + outLen = strlen(kToken); + type = ock::hcom::NET_SEC_VALID_TWO_WAY; + needAutoFree = false; + NN_LOG_INFO("client auth info " << output << " len:" << outLen << " flag:" << flag << " sec type:" << + UBSHcomNetDriverSecTypeToString(type)); + return 0; +} + +// token is empty string +static int SecInfoProviderValid(uint64_t ctx, int64_t &flag, UBSHcomNetDriverSecType &type, char *&output, + uint32_t &outLen, bool &needAutoFree) +{ + const char *kToken = ""; + flag = 1; + output = const_cast(kToken); + outLen = strlen(kToken); + type = ock::hcom::NET_SEC_VALID_ONE_WAY; + needAutoFree = false; + NN_LOG_INFO("client auth info " << output << " len:" << outLen << " flag:" << flag << " sec type:" << + UBSHcomNetDriverSecTypeToString(type)); + return 0; +} + +// provider not registered, return valid +static int ProviderValid(uint64_t ctx, int64_t &flag, UBSHcomNetDriverSecType &type, char *&output, + uint32_t &outLen, bool &needAutoFree) +{ + NN_LOG_WARN("client provider is not registered, but return valid"); + return 0; +} + +// provider not registered, return invalid +static int SecInfoProviderInvalid(uint64_t ctx, int64_t &flag, UBSHcomNetDriverSecType &type, char *&output, + uint32_t &outLen, bool &needAutoFree) +{ + NN_LOG_ERROR("invalid sec info"); + return -1; +} + +// validator register, return valid +static int AuthValidatorValid(uint64_t ctx, int64_t flag, const char *input, uint32_t inputLen) +{ + if (input != nullptr) { + NN_LOG_INFO("client auth validate flag:" << flag); + } else { + NN_LOG_INFO("client auth validate flag:" << flag << " input:" << input << " input Len:" << inputLen); + } + return 0; +} + +// validator not register, return valid +static int ValidatorValid(uint64_t ctx, int64_t flag, const char *input, uint32_t inputLen) +{ + NN_LOG_WARN("server validator is not registered, but return valid"); + return 0; +} + +// validator register, return invalid +static int AuthValidatorInvalid(uint64_t ctx, int64_t flag, const char *input, uint32_t inputLen) +{ + NN_LOG_ERROR("Client authentication failed"); + return -1; +} + +static NResult SendSingleRequest(UBSHcomNetEndpointPtr clientEp) +{ + std::string value = "hello world"; + NResult result = NN_OK; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + if ((result = clientEp->PostSend(1, req)) != 0) { + NN_LOG_INFO("failed to post message to data to server"); + return result; + } + return NN_OK; +} + +static void SetCB(UBSHcomNetDriver *&driver, uint16_t port, bool isServer, + const UBSHcomNetDriverEndpointSecInfoProvider &SecInfoProvider, + const UBSHcomNetDriverEndpointSecInfoValidator &SecInfoValidator) +{ + if (isServer) { + driver->RegisterNewEPHandler( + std::bind(&NewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + } + driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); + driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + if (SecInfoProvider == nullptr) { + driver->RegisterEndpointSecInfoProvider(nullptr); + } else { + driver->RegisterEndpointSecInfoProvider(std::bind(SecInfoProvider, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)); + } + + if (SecInfoValidator == nullptr) { + driver->RegisterEndpointSecInfoValidator(nullptr); + } else { + driver->RegisterEndpointSecInfoValidator(std::bind(SecInfoValidator, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4)); + } + + driver->OobIpAndPort(BASE_IP, port); +} + +static void SetDriverOptions(UBSHcomNetDriverOptions &sockOptions) +{ + sockOptions.mode = UBSHcomNetDriverWorkingMode::NET_EVENT_POLLING; + sockOptions.SetNetDeviceIpMask(IP_SEG); + sockOptions.pollingBatchSize = 16; + sockOptions.SetWorkerGroups("1"); + sockOptions.SetWorkerGroupsCpuSet("10-10"); + sockOptions.enableTls = false; +} + +static void CloseDriver(UBSHcomNetDriver *&driver) +{ + if (driver->IsStarted()) { + driver->Stop(); + driver->UnInitialize(); + } +} + +// client : (provider, invalid) | server : (/, Y/N) | failed +TEST_F(TestSecure, OneWayCase1) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + /* client is registered, return invalid */ + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_ONE_WAY; + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_one_case_1", false); + SetCB(cDriver, ipPort, false, SecInfoProviderInvalid, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + + /* 1-1 server validator is registered, return valid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_one_case_1", true); + SetCB(sDriver, ipPort, true, nullptr, AuthValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(sDriver); + + /* server validator is not registered, return valid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_one_case_1", true); + SetCB(sDriver, ipPort, true, nullptr, ValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(sDriver); + + /* 1-3 server validator is registered, return invalid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_one_case_1", true); + SetCB(sDriver, ipPort, true, nullptr, AuthValidatorInvalid); + sDriver->Initialize(options); + sDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(sDriver); + + /* 1-2 server validator is not registered, set nullptr */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_one_case_1", true); + SetCB(sDriver, ipPort, true, nullptr, nullptr); + sDriver->Initialize(options); + sDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(sDriver); + + CloseDriver(cDriver); +} + +// client : (provider, valid) | server : (/, Y/N) | pass +TEST_F(TestSecure, OneWayCase2) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + /* client provider is registered, return valid */ + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_ONE_WAY; + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_one_case_2", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidOne, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + + /* 2-1 server validator is not registered, set nullptr */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_one_case_2", true); + SetCB(sDriver, ipPort, true, nullptr, nullptr); + sDriver->Initialize(options); + sDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OK); + CloseDriver(sDriver); + + /* 2-2 server validator is not registered, return valid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_one_case_2", true); + SetCB(sDriver, ipPort, true, nullptr, ValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OK); + CloseDriver(sDriver); + + CloseDriver(cDriver); +} + +// client : (provider, valid) | server : (validator, invalid) | failed +TEST_F(TestSecure, OneWayCase3) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + /* client provider is registered, return valid */ + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_ONE_WAY; + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_one_case_3", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidOne, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + + /* 3-1 server validator is registered, return invalid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_one_case_3", true); + SetCB(sDriver, ipPort, true, nullptr, AuthValidatorInvalid); + sDriver->Initialize(options); + sDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + + CloseDriver(sDriver); + CloseDriver(cDriver); +} + +// client : (provider, valid) | server: (validator, valid) | pass +TEST_F(TestSecure, OneWayCase4) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + + /* client provider is registered, return valid */ + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_ONE_WAY; + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_one_case_4", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidOne, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + + /* 4-1 server validator is registered, return valid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_one_case_4", true); + SetCB(sDriver, ipPort, true, nullptr, AuthValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OK); + result = SendSingleRequest(clientEp); + EXPECT_EQ(result, NN_OK); + + CloseDriver(sDriver); + CloseDriver(cDriver); +} + +// client : (/, valid) | server : (validator, valid) | failed +TEST_F(TestSecure, OneWayCase5) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + /* server validator is registered, return valid */ + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_ONE_WAY; + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_one_case_5", true); + SetCB(sDriver, ipPort, true, nullptr, AuthValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + + /* 5-1 client provider is not registered, but return valid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_one_case_5", false); + SetCB(cDriver, ipPort, false, ProviderValid, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + + /* 5-2 client provider is nullptr */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_one_case_5", false); + SetCB(cDriver, ipPort, false, nullptr, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + + CloseDriver(cDriver); + CloseDriver(sDriver); +} + +// client : (/, valid) | server : (validator, invalid) | failed +TEST_F(TestSecure, OneWayCase6) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_ONE_WAY; + /* server validator is registered, but return invalid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_one_case_6", true); + SetCB(sDriver, ipPort, true, nullptr, AuthValidatorInvalid); + sDriver->Initialize(options); + sDriver->Start(); + + /* 6-1 client provider is not registered, but return valid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_one_case_6", false); + SetCB(cDriver, ipPort, false, ProviderValid, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + + /* 6-2 client provider is nullptr */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_one_case_6", false); + SetCB(cDriver, ipPort, false, nullptr, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + CloseDriver(sDriver); +} + +// client : (/, valid/invalid) | server : (/, valid/invalid) | pass +TEST_F(TestSecure, OneWayCase7) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_ONE_WAY; + /* server validator is not registered, but return valid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_one_case_7", true); + SetCB(sDriver, ipPort, true, nullptr, ValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + + /* 7-1 client provider is not registered, but return valid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_one_case_7", false); + SetCB(cDriver, ipPort, false, ProviderValid, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + + /* 7-2 client provider is nullptr */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_one_case_7", false); + SetCB(cDriver, ipPort, false, nullptr, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + CloseDriver(sDriver); + + /* server validator is nullptr */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_one_case_7", true); + SetCB(sDriver, ipPort, true, nullptr, nullptr); + sDriver->Initialize(options); + sDriver->Start(); + + /* 7-3 client provider is not registered, but return valid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_one_case_7", false); + SetCB(cDriver, ipPort, false, ProviderValid, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + + /* 7-4 client provider is nullptr */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_one_case_7", false); + SetCB(cDriver, ipPort, false, nullptr, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + CloseDriver(sDriver); +} + +// client : (provider, Y)(/, Y/N) | server : (provider, N)(validator, Y) | failed +TEST_F(TestSecure, TwoWayCase8) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_TWO_WAY; + /* server provider is registered, but return invalid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_two_case_8", true); + SetCB(sDriver, ipPort, true, SecInfoProviderInvalid, AuthValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + + /* 8-1 client validator is not registered, but return valid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_8", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, ValidatorValid); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + + /* 8-2 client validator is not registered, set nullptr */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_8", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + CloseDriver(sDriver); +} + +// client : (provider, Y)(validator, Y/N) | server : (provider, N)(validator, Y) | failed,failed +TEST_F(TestSecure, TwoWayCase9) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_TWO_WAY; + /* server provider is registered, but return invalid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_two_case_9", true); + SetCB(sDriver, ipPort, true, SecInfoProviderInvalid, AuthValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + + /* 9-1 client validator is registered, return valid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_9", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, AuthValidatorValid); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + + /* 9-2 client validator is registered, return invalid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_9", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, AuthValidatorInvalid); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + CloseDriver(sDriver); +} + +// client : (provider, Y)(validator, Y/N) | server : (provider, Y)(validator, Y) | pass,failed +TEST_F(TestSecure, TwoWayCase10) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_TWO_WAY; + /* server provider is registered, return valid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_two_case_10", true); + SetCB(sDriver, ipPort, true, SecInfoProviderValidTwo, AuthValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + + /* 10-1 client validator is not registered, but return valid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_10", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, ValidatorValid); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OK); + CloseDriver(cDriver); + + /* 10-2 client validator is not registered, set nullptr */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_10", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + CloseDriver(sDriver); +} + +// client : (provider, Y)(validator, N) | server : (provider, Y)(validator, Y) | failed +TEST_F(TestSecure, TwoWayCase11) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_TWO_WAY; + /* server provider is registered, return valid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_two_case_11", true); + SetCB(sDriver, ipPort, true, SecInfoProviderValidTwo, AuthValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + + /* 11-1 client validator is not registered, set nullptr */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_11", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, AuthValidatorInvalid); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + CloseDriver(sDriver); +} + +// client : (provider, Y)(validator, Y) | server : (provider, Y)(validator, Y) | pass +TEST_F(TestSecure, TwoWayCase12) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_TWO_WAY; + /* server provider is registered, return valid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_two_case_12", true); + SetCB(sDriver, ipPort, true, SecInfoProviderValidTwo, AuthValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + + /* 12-1 client validator is not registered, set nullptr */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_12", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, AuthValidatorValid); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OK); + CloseDriver(cDriver); + CloseDriver(sDriver); +} + +// client : (provider, Y)(/, Y/N) | server : (/, Y)(validator, Y) | failed +TEST_F(TestSecure, TwoWayCase13) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_TWO_WAY; + /* server provider is not registered, return valid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_two_case_13", true); + SetCB(sDriver, ipPort, true, ProviderValid, AuthValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + + /* 13-1 client validator is not registered, but return valid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_13", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, ValidatorValid); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + + /* 13-2 client validator is not registered, but return invalid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_13", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + + CloseDriver(sDriver); +} + +// client : (provider, Y)(/, Y/N) | server : (/, N)(validator, Y) | failed +TEST_F(TestSecure, TwoWayCase14) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_TWO_WAY; + /* server provider is not registered, set nullptr */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_two_case_14", true); + SetCB(sDriver, ipPort, true, nullptr, AuthValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + + /* 14-1 client validator is not registered, but return valid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_14", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, ValidatorValid); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + + /* 14-2 client validator is not registered, but return invalid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_14", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + + CloseDriver(sDriver); +} + +// client : (provider, Y)(validator, Y/N) | server : (/, Y)(validator, Y) | failed +TEST_F(TestSecure, TwoWayCase15) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_TWO_WAY; + /* server provider is not registered, return valid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_two_case_15", true); + SetCB(sDriver, ipPort, true, ProviderValid, AuthValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + + /* 15-1 client validator is not registered, but return valid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_15", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, AuthValidatorValid); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + + /* 15-2 client validator is not registered, but return invalid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_15", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, AuthValidatorInvalid); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + + CloseDriver(sDriver); +} + +// client : (provider, Y)(validator, Y/N) | server : (/, N)(validator, Y) | failed +TEST_F(TestSecure, TwoWayCase16) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + options.secType = ock::hcom::NET_SEC_VALID_TWO_WAY; + /* server provider is not registered, set nullptr */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_two_case_16", true); + SetCB(sDriver, ipPort, true, nullptr, AuthValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + + /* 16-1 client validator is not registered, but return valid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_16", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, AuthValidatorValid); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + + /* 16-2 client validator is not registered, but return invalid */ + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_two_case_16", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValidTwo, AuthValidatorInvalid); + cDriver->Initialize(options); + cDriver->Start(); + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OOB_SEC_PROCESS_ERROR); + CloseDriver(cDriver); + + CloseDriver(sDriver); +} + +// client : (provider, valid) | server: (validator, valid) | pass +TEST_F(TestSecure, TokenEmptyString) +{ + UBSHcomNetDriver *sDriver = nullptr; + UBSHcomNetDriver *cDriver = nullptr; + UBSHcomNetEndpointPtr clientEp = nullptr; + NResult result = NN_OK; + ipPort++; + + /* client provider is registered, return valid */ + UBSHcomNetDriverOptions options; + SetDriverOptions(options); + cDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "client_empty_string", false); + SetCB(cDriver, ipPort, false, SecInfoProviderValid, nullptr); + cDriver->Initialize(options); + cDriver->Start(); + + /* server validator is registered, return valid */ + sDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "server_empty_string", true); + SetCB(sDriver, ipPort, true, nullptr, AuthValidatorValid); + sDriver->Initialize(options); + sDriver->Start(); + + result = cDriver->Connect("hello world", clientEp, 0); + EXPECT_EQ(result, NN_OK); + + CloseDriver(sDriver); + CloseDriver(cDriver); +} \ No newline at end of file diff --git a/test/llt/testcase/transport/test_secure.h b/test/llt/testcase/transport/test_secure.h new file mode 100644 index 0000000000000000000000000000000000000000..c9bc697c8400994a42c29f63cd22acc51395a9e8 --- /dev/null +++ b/test/llt/testcase/transport/test_secure.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_TEST_SECURE_H +#define HCOM_TEST_SECURE_H + +#include +#include + +class TestSecure : public testing::Test { +public: + TestSecure() =default; + ~TestSecure() = default; + void SetUp() + { + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); + } + void TearDown() {} +}; +#endif //HCOM_TEST_SECURE_H diff --git a/test/llt/testcase/ut_helper.cpp b/test/llt/testcase/ut_helper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a2cc9191f489afb5d301072e863c5d6c00e2fa43 --- /dev/null +++ b/test/llt/testcase/ut_helper.cpp @@ -0,0 +1,206 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include "ut_helper.h" + +bool UTHelper::ServerCreateDriver(UBSHcomNetDriver *&serverDriver, Handlers &handlers, UBSHcomNetDriverOptions &options, + uint16_t port) +{ + serverDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, + "rdmaServer" + std::to_string(port), true); + if (serverDriver == nullptr) { + NN_LOG_ERROR("failed to create serverDriver already created"); + return false; + } + + serverDriver->RegisterNewEPHandler(handlers.newEpHandler); + serverDriver->RegisterEPBrokenHandler(handlers.epBrokenHandler); + serverDriver->RegisterNewReqHandler(handlers.receivedHandler); + serverDriver->RegisterReqPostedHandler(handlers.sentHandler); + serverDriver->RegisterOneSideDoneHandler(handlers.oneSideDoneHandler); + + serverDriver->OobIpAndPort(BASE_IP, port); + options.enableTls = false; + int result = 0; + if ((result = serverDriver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize serverDriver " << result); + return false; + } + NN_LOG_INFO("serverDriver initialized"); + + if ((result = serverDriver->Start()) != 0) { + NN_LOG_ERROR("failed to start serverDriver " << result); + return false; + } + NN_LOG_INFO("serverDriver started"); + UBSHcomNetMemoryRegionPtr mr; + if (serverDriver->CreateMemoryRegion(NN_NO8192 * 16, mr) != 0) { + NN_LOG_ERROR("failed to create server CreateMemoryRegion " << result); + return false; + } + return true; +} + +bool UTHelper::ClientCreateDriver(UBSHcomNetDriver *&clientDriver, Handlers &handlers, UBSHcomNetDriverOptions &options, + uint16_t port) +{ + auto name = "rdmaClient" + std::to_string(port); + clientDriver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::RDMA, name, false); + if (clientDriver == nullptr) { + NN_LOG_ERROR("failed to create clientDriver already created"); + return false; + } + + clientDriver->RegisterEPBrokenHandler(handlers.epBrokenHandler); + clientDriver->RegisterNewReqHandler(handlers.receivedHandler); + clientDriver->RegisterReqPostedHandler(handlers.sentHandler); + clientDriver->RegisterOneSideDoneHandler(handlers.oneSideDoneHandler); + + clientDriver->OobIpAndPort(BASE_IP, port); + options.enableTls = false; + int result = 0; + if ((result = clientDriver->Initialize(options)) != 0) { + NN_LOG_ERROR("failed to initialize clientDriver " << result); + return false; + } + NN_LOG_INFO("clientDriver initialized"); + + if ((result = clientDriver->Start()) != 0) { + NN_LOG_ERROR("failed to start clientDriver " << result); + return false; + } + NN_LOG_INFO("clientDriver started"); + UBSHcomNetMemoryRegionPtr mr; + if (clientDriver->CreateMemoryRegion(NN_NO8192 * 16, mr) != 0) { + NN_LOG_ERROR("failed to create client CreateMemoryRegion " << result); + return false; + } + return true; +} + +bool UTHelper::ClientConnect(UBSHcomNetDriver *clientDriver, UBSHcomNetEndpointPtr &clientEp, uint16_t grpNo, + uint16_t clientNo) +{ + setenv("HCOM_CONNECTION_RETRY_TIMES", "1", 1); + if (clientDriver == nullptr) { + NN_LOG_ERROR("clientDriver is null"); + return false; + } + + int result = 0; + if ((result = clientDriver->Connect("hello world", clientEp, 0, grpNo, clientNo)) != 0) { + NN_LOG_ERROR("failed to connect to server, result " << result); + return false; + } + return true; +} + +bool UTHelper::ClientSend(UBSHcomNetEndpointPtr &clientEp, sem_t *sem) +{ + int result = 0; + sem_init(sem, 0, 0); + + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + + if ((result = clientEp->PostSend(1, req)) != 0) { + NN_LOG_INFO("failed to post message to data to server"); + return false; + } + sem_wait(sem); + return true; +} + +static uint16_t basePort = 6900; +static int nameSeed = 0; + +NResult UTHelper::GetDriverStateMask(UBSHcomNetDriver *&driver, uint16_t stateMask, bool isServer, + UBSHcomNetDriverProtocol protocol) +{ + driver = UBSHcomNetDriver::Instance(protocol, std::to_string(nameSeed++), isServer); + driver->OobIpAndPort(BASE_IP, isServer ? ++basePort : basePort); + + return ForwardDriverStateMask(driver, stateMask); +} + +NResult UTHelper::GetDriver(UBSHcomNetDriver *&driver, DRIVER_STATE state, bool isServer, + UBSHcomNetDriverProtocol protocol) +{ + driver = UBSHcomNetDriver::Instance(protocol, std::to_string(nameSeed++), isServer); + driver->OobIpAndPort(BASE_IP, isServer ? ++basePort : basePort); + + return ForwardDriverState(driver, state); +} + +NResult UTHelper::ForwardDriverState(UBSHcomNetDriver *&driver, DRIVER_STATE state) +{ + NResult result = NN_OK; + if (state >= DRIVER_STATE_INIT) { + UBSHcomNetDriverOptions options; + options.mode = ock::hcom::NET_EVENT_POLLING; + + options.SetNetDeviceIpMask(IP_SEG); + options.enableTls = false; + result = driver->Initialize(options); + if (result != NN_OK) { + return result; + } + } + if (state >= DRIVER_STATE_START) { + Handlers handlers; + driver->RegisterNewEPHandler(handlers.newEpHandler); + driver->RegisterEPBrokenHandler(handlers.epBrokenHandler); + driver->RegisterNewReqHandler(handlers.receivedHandler); + driver->RegisterReqPostedHandler(handlers.sentHandler); + driver->RegisterOneSideDoneHandler(handlers.oneSideDoneHandler); + result = driver->Start(); + if (result != NN_OK) { + return result; + } + } + if (state >= DRIVER_STATE_STOP) + driver->Stop(); + if (state >= DRIVER_STATE_UNINIT) + driver->UnInitialize(); + return NN_OK; +} + +NResult UTHelper::ForwardDriverStateMask(UBSHcomNetDriver *&driver, uint16_t state) +{ + NResult result = NN_OK; + if (state & DRIVER_STATE_INIT) { + UBSHcomNetDriverOptions options; + options.mode = ock::hcom::NET_EVENT_POLLING; + options.SetNetDeviceIpMask(IP_SEG); + options.enableTls = false; + result = driver->Initialize(options); + if (result != NN_OK) { + return result; + } + } + if (state & DRIVER_STATE_START) { + Handlers handlers; + driver->RegisterNewEPHandler(handlers.newEpHandler); + driver->RegisterEPBrokenHandler(handlers.epBrokenHandler); + driver->RegisterNewReqHandler(handlers.receivedHandler); + driver->RegisterReqPostedHandler(handlers.sentHandler); + driver->RegisterOneSideDoneHandler(handlers.oneSideDoneHandler); + result = driver->Start(); + if (result != NN_OK) { + return result; + } + } + if (state & DRIVER_STATE_STOP) + driver->Stop(); + if (state & DRIVER_STATE_UNINIT) + driver->UnInitialize(); + return NN_OK; +} \ No newline at end of file diff --git a/test/llt/testcase/ut_helper.h b/test/llt/testcase/ut_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..9ab0c54d6d8c558dc427608d00d0ce14cba8b259 --- /dev/null +++ b/test/llt/testcase/ut_helper.h @@ -0,0 +1,115 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_UT_HELPER_H +#define HCOM_UT_HELPER_H + +#ifdef RDMA_BUILD_ENABLED +#include "transport/rdma/rdma_common.h" +#endif +#include "common/net_util.h" +#include "hcom.h" +#include "hcom_def.h" + +using namespace ock::hcom; + +#ifdef MOCK_VERBS +#define BASE_IP "127.0.0.1" +#define IP_SEG "127.0.0.0/16" +#ifdef RDMA_BUILD_ENABLED +#define MOCK_VERSION MOCKER(ReadRoCEVersionFromFile).stubs().will(returnValue(0)); +#else +#define MOCK_VERSION +#endif +#else +#define BASE_IP "192.168.100.204" +#define IP_SEG "192.168.100.0/24" +#define MOCK_VERSION +#endif + +using Handlers = struct hdlrs { + UBSHcomNetDriverNewEndPointHandler newEpHandler = [](const std::string &ipPort, const UBSHcomNetEndpointPtr &, + const std::string &payload) { return 0; }; + UBSHcomNetDriverEndpointBrokenHandler epBrokenHandler = [](const UBSHcomNetEndpointPtr &) { return 0; }; + UBSHcomNetDriverSentHandler sentHandler = [](const UBSHcomNetRequestContext &) { return 0; }; + UBSHcomNetDriverOneSideDoneHandler oneSideDoneHandler = [](const UBSHcomNetRequestContext &) { return 0; }; + UBSHcomNetDriverReceivedHandler receivedHandler = [](const UBSHcomNetRequestContext &) { return 0; }; +}; + +struct DummyObj { + int tag = -1; + explicit DummyObj(int _tag) : tag(_tag) {} + DummyObj() = default; + DEFINE_RDMA_REF_COUNT_VARIABLE; + DEFINE_RDMA_REF_COUNT_FUNCTIONS + ~DummyObj() {} +}; + +using OBJ_LIFE_CYCLE = enum _o_l_c_ { + NONE = 0, + INIT, + DEINIT +}; + +using DRIVER_STATE = enum _d_s_ { + DRIVER_STATE_NONE = 0, + DRIVER_STATE_INIT = 1 << 0, + DRIVER_STATE_START = 1 << 1, + DRIVER_STATE_STOP = 1 << 2, + DRIVER_STATE_UNINIT = 1 << 3 +}; + +struct NoisyObj { + OBJ_LIFE_CYCLE &state; + explicit NoisyObj(OBJ_LIFE_CYCLE &_state) : state(_state) + { + state = INIT; + } + DEFINE_RDMA_REF_COUNT_VARIABLE; + DEFINE_RDMA_REF_COUNT_FUNCTIONS + ~NoisyObj() + { + state = DEINIT; + } +}; + +#define UT_CHECK_RESULT_TRUE(result) ASSERT_EQ(true, (result)); + + +#define UT_CHECK_RESULT_FALSE(result) ASSERT_EQ(false, (result)); + +#define UT_CHECK_RESULT_OK(result) ASSERT_EQ(NN_OK, (result)); + + +#define UT_CHECK_RESULT_NOK(result) ASSERT_NE(NN_OK, result); + + +#define UT_CHECK_RESULT_NOT_NULL(result) ASSERT_NE(nullptr, result); + + +class UTHelper { +public: + static bool ServerCreateDriver(UBSHcomNetDriver *&serverDriver, Handlers &handlers, + UBSHcomNetDriverOptions &options, uint16_t port); + static bool ClientCreateDriver(UBSHcomNetDriver *&clientDriver, Handlers &handlers, + UBSHcomNetDriverOptions &options, uint16_t port); + static bool ClientConnect(UBSHcomNetDriver *clientDriver, UBSHcomNetEndpointPtr &clientEp, uint16_t srvNo = 0, + uint16_t clientNo = 0); + static bool ClientSend(UBSHcomNetEndpointPtr &clientEp, sem_t *sem); + static NResult GetDriver(UBSHcomNetDriver *&driver, DRIVER_STATE state, bool isServer, + UBSHcomNetDriverProtocol protocol = UBSHcomNetDriverProtocol::RDMA); + static NResult GetDriverStateMask(UBSHcomNetDriver *&driver, uint16_t stateMask, bool isServer, + UBSHcomNetDriverProtocol protocol = UBSHcomNetDriverProtocol::RDMA); + static NResult ForwardDriverState(UBSHcomNetDriver *&driver, DRIVER_STATE state); + static NResult ForwardDriverStateMask(UBSHcomNetDriver *&driver, uint16_t state); +}; + +#endif // HCOM_UT_HELPER_H diff --git a/test/opensslcrt/abnormalCertChain/CA/cacert.pem b/test/opensslcrt/abnormalCertChain/CA/cacert.pem new file mode 100644 index 0000000000000000000000000000000000000000..8e52b7a68dd73f59d2eaf00709e3d409fd2d1226 --- /dev/null +++ b/test/opensslcrt/abnormalCertChain/CA/cacert.pem @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID0TCCArmgAwIBAgIURA/lUFrZwOfJXaaLF9ol7wkhAxIwDQYJKoZIhvcNAQEL +BQAweDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQHDAJOSjEPMA0G +A1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxDzANBgNVBAMMBlJvb3RDQTEfMB0G +CSqGSIb3DQEJARYQcm9vdGNhQHdvcmxkLmNvbTAeFw0yMjEwMTIwOTM3MjFaFw0z +MjEwMDkwOTM3MjFaMHgxCzAJBgNVBAYTAkNOMQswCQYDVQQIDAJKUzELMAkGA1UE +BwwCTkoxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MQ8wDQYDVQQDDAZS +b290Q0ExHzAdBgkqhkiG9w0BCQEWEHJvb3RjYUB3b3JsZC5jb20wggEiMA0GCSqG +SIb3DQEBAQUAA4IBDwAwggEKAoIBAQCgvTn1zACCZ4uny3BW8Utilwoztkhb/XM+ +ZI/trCZg1smsnuCNJHyIJVoFz4PoxXESCueTD0UwIcrftvzQPJZzVZEOY/ND4ZFq +Bj4TbCybSNuFIIAXn2yL1x5oLGz5wuEr7XClqUECVZPTyDv2ozg7+L6NRXNnQ3DQ +jL3QqEaH0M0hA/4FX7ySXrSC2BFX5LZzv8cjKla+3jqJUbUxokxEWMfYVNU+JjwV +vS4ieVqIsWcfd+FhqFCpvWj92PpJB5Lk6GkuGi026lgYutK30Gx133QLQjrNRRLG +4dM0KMevpSM1Ug7dXy60dIjlJkFjekEf7umCGFNILf3UKQTd1irnAgMBAAGjUzBR +MB0GA1UdDgQWBBRXGByy5N4UbiP0NcTKWpj2yDsJ3zAfBgNVHSMEGDAWgBRXGByy +5N4UbiP0NcTKWpj2yDsJ3zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA +A4IBAQAaLNCcpUpKjxz3DKPU632uqljaliu5mCm9OAQRDLccfZifIHacyxx45rm7 +8TOSgQnSS41UtsvTMn63LXqWCi+P0K6LVZUDAiQzcehGsKYyj7yG8UiQev3qUXuU +HRD+n92AVKYz4ABkKqepgVqE1G0CQeFISU/czTv4r3+Qv5IGXWuZ9D9fNu2Al8DG +Zo4FdcxsBESksDP07Gjj3zMRgV0uhsoSP7gEc2zrucfSIZasEuIgPGL07vMc4Ybm +EaY/cZcuSj+o012wmP6YMQw4t74Fxqe41l4lY/yPtgUO9LDx9ygiDzEKaMCTea8M +33Oqg/tD4bzToCPlpeh7qhpUfLpP +-----END CERTIFICATE----- diff --git a/test/opensslcrt/abnormalCertChain/CA/rootca.key b/test/opensslcrt/abnormalCertChain/CA/rootca.key new file mode 100644 index 0000000000000000000000000000000000000000..243405ce091dbafb55cbc6d84a9a8f9fb506a969 --- /dev/null +++ b/test/opensslcrt/abnormalCertChain/CA/rootca.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAoL059cwAgmeLp8twVvFLYpcKM7ZIW/1zPmSP7awmYNbJrJ7g +jSR8iCVaBc+D6MVxEgrnkw9FMCHK37b80DyWc1WRDmPzQ+GRagY+E2wsm0jbhSCA +F59si9ceaCxs+cLhK+1wpalBAlWT08g79qM4O/i+jUVzZ0Nw0Iy90KhGh9DNIQP+ +BV+8kl60gtgRV+S2c7/HIypWvt46iVG1MaJMRFjH2FTVPiY8Fb0uInlaiLFnH3fh +YahQqb1o/dj6SQeS5OhpLhotNupYGLrSt9Bsdd90C0I6zUUSxuHTNCjHr6UjNVIO +3V8utHSI5SZBY3pBH+7pghhTSC391CkE3dYq5wIDAQABAoIBADKXPC3btnFUy8TV +KBeFPJfcOA7MmXuyitohZpeErlOeZr1ZCA4EZNmo/+uCQ984fX0TR42mqb0bdbHx +8yJLX4MPdGdWGBPOZCk9q74LNwLs7IK7FvXYbJ6a52wcR3RY3OwpgGHzoo1sh+mJ +RS48cw+VG8x1BnyC4ngRRBDvVbubAYMGceUmlYtFgIRWqZrBhKvXgzpB2zv+5jxR +NnXDqkQRYl9fuAaG3ajN/qpiMZZlcKfZskjfRvrcGcHi8bD+xGutxXSPFq3qPUT+ +tJ1pexOv2xUx0A27441LoPW1FzSyOj2+iPuZpjTSseM1AIxk3lvXwJU/9CAXYBiS +RxA7YQECgYEA1jqeAfoWtEM1XOl3K1KsxfCqcPFsMUxglhcPz1n6Fv0zgSXHr/Sl +HDPmE7faiqNZoyy81rG7+xjpZkSa6LfQ4gCzwjN5soed7Bi7zeRM5QImta2NMpV+ +x6OXdlQQj3QDeBivmuilL47ckENYCk/t/Tu/jCpuPCYC/syuqvq1g2cCgYEAwBSh +LFcJG4xFtClnrPMljbqlDeQi4S02HYl10+8JpVbjEiPmLTbzkDSAUfm4UR+ZblJX +CGvd/TLDy1iju0/m1X+Ddb2lMUMAeihuZjaWxbpaX7UU0QUNR8jcvxQrkcFqc/hy +dpw3aVD+74+WcYoN+YqG9KGp/7N7Lkpnxslk7IECgYEAksozfpNIf1gV9oYam9rY +fAD+KMmkItt8yxseQCwdCyeP5QxoGY7+m6aMHjK6Uoi/YOnEsy+x6MoXE3Yq1w8s +1883XPg8iTIX6bDA7sFiVwD0WUSEHYcGCfF0VSYg+sq5nc78dJ64oS+4vjkG2HoQ +TpZkF7zzL8+z+bdyb8G+Ij0CgYAYdggYd3UHdxOhX+x+D/DmXbCLVlRCzNkpZcoF +lVlrHueH9d5oP6lA4g69YcnhOt71N7MxtVrt1bsteDpRrlk9MyHwqpgQ7/FtnRyC +E82bnKHJsmvWOoh4bdH+23i49SKzZh5dkINV/CSbKXQFPYmOD+Aj4zqc/6RePsd8 +f0VFAQKBgEqkOim1lmubatkYFy/Gm/3sr4pn/ENkPig7nRZDbSqzSOV50/rU40O2 ++QA5mhUyzQuG8S8zPhXWBnJt1T+O57uDRjEtVw9zrgZw/MddXEHILVjK0qtIR1Hq +opEzRRcl+keFb2N5YXfOd+qsnBl0UklTCOo/E12PwhaW6nX6HujD +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/abnormalCertChain/CA/secondca.crt b/test/opensslcrt/abnormalCertChain/CA/secondca.crt new file mode 100644 index 0000000000000000000000000000000000000000..1de5943ff08d32b36bd333221eb5f15cd0911433 --- /dev/null +++ b/test/opensslcrt/abnormalCertChain/CA/secondca.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIBATANBgkqhkiG9w0BAQsFADB4MQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxCzAJBgNVBAcMAk5KMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNV +BAsMA0RldjEPMA0GA1UEAwwGUm9vdENBMR8wHQYJKoZIhvcNAQkBFhByb290Y2FA +d29ybGQuY29tMB4XDTIyMTAxMjA5NDIwOFoXDTMyMTAwOTA5NDIwOFowbzELMAkG +A1UEBhMCQ04xCzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsM +A0RldjERMA8GA1UEAwwIU2Vjb25kQ0ExITAfBgkqhkiG9w0BCQEWEnNlY29uZGNh +QHdvcmxkLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANcpo7lG +5ttiS8yJtzxGVlO7lA0RqODL1xO84ES+jE3CPLCVo83l95E+ZHQZZWLyAr/DrGCm +rP93+ODINGkMvLZlEV+vYcxzMZeRwctITygaSaxV6bON9m2F3k1iQZS/gKkL9L1i +/F3nkOAEuCYMgwHfPy+/E97H9ECRpwt2mzld40UQZj1nQY1nc8l5F10p5GSiNp5w +98N/Fnk9ZCESUKtPviK541ljvs+tvgSvuKtJJIbdhbn+xW5O9hQE79u9iumo1zBS +jZKMMMdRqTIaqt/0Ofqu39g/OsfeLCfGr+NF9suTaeJEMFPMzOhSdgy9zf+fykVJ +GuXxfedLzbgP06kCAwEAAaNTMFEwHQYDVR0OBBYEFMeDrbiKqXhHPcdp/CNpS9Vn +/MTfMB8GA1UdIwQYMBaAFFcYHLLk3hRuI/Q1xMpamPbIOwnfMA8GA1UdEwEB/wQF +MAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGDOfMIV1AC5bdbpnhSDhpzkBjh3OXqj +PUplhCm+VSFUcSE5vTe4EnUsCmpIAYkAA/2fnr6Kb9GFJeuoWlWpo4vRW9g8qvRD +B7AoV/ojafPb4V2epZLaHpJQcxyqplGFKqPKDAXZlXLXUbWLAutq8q7GLL3MlqDF +nRDHBafXEHJOBQG7HyTDt20uoIM0+10ehJW0LlO2UD/SQ5ZEkYwB5fNk4rJMgtmb ++Nibo2x2ZdMd0rakv5CGCjf13aMczngYcbNjnUFjiLofHTEcoYmDVTqjd4W416Fj +2R+WfdalPEctoRCjDMi7dZLxQOY7/IesZbXJkVMtbP8lrVb+9Xl7lj0= +-----END CERTIFICATE----- diff --git a/test/opensslcrt/abnormalCertChain/CA/secondca.csr b/test/opensslcrt/abnormalCertChain/CA/secondca.csr new file mode 100644 index 0000000000000000000000000000000000000000..012c84a927e1af42c1e6379275991e853c7d7ee9 --- /dev/null +++ b/test/opensslcrt/abnormalCertChain/CA/secondca.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICwTCCAakCAQAwfDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQH +DAJOSjEPMA0GA1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxETAPBgNVBAMMCFNl +Y29uZENBMSEwHwYJKoZIhvcNAQkBFhJzZWNvbmRjYUB3b3JsZC5jb20wggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDXKaO5RubbYkvMibc8RlZTu5QNEajg +y9cTvOBEvoxNwjywlaPN5feRPmR0GWVi8gK/w6xgpqz/d/jgyDRpDLy2ZRFfr2HM +czGXkcHLSE8oGkmsVemzjfZthd5NYkGUv4CpC/S9Yvxd55DgBLgmDIMB3z8vvxPe +x/RAkacLdps5XeNFEGY9Z0GNZ3PJeRddKeRkojaecPfDfxZ5PWQhElCrT74iueNZ +Y77Prb4Er7irSSSG3YW5/sVuTvYUBO/bvYrpqNcwUo2SjDDHUakyGqrf9Dn6rt/Y +PzrH3iwnxq/jRfbLk2niRDBTzMzoUnYMvc3/n8pFSRrl8X3nS824D9OpAgMBAAGg +ADANBgkqhkiG9w0BAQsFAAOCAQEALBW59/ZzFd97b5jzqamQnkKH2fN/kk7+vfu8 +0FiHN2liCnaHAa7+zlxch8XZY/LdQWdBcTtOMQgTz8dEuHsQaAxT4dLTTm9rs70w +QoxoGLy7okbvGKyhxzJM6BHJVDzaq2AXMtB1BlI+9DFBmwxbpDQyqtc0XaABBkV5 +GahaE2WAP1t3LM+JDOdJ+5VLSNIhneJrFR465HmHaVSVe1ivD3tk7394DwcLPOE4 +qsuIP3nQIzMFKyzyaMbaKXNd1mU3SfitOQuckm7oycexVgtd1oU6kZs0PF8q6e1Z +tcAxZqbYllNZx1cHyg/lu59+gPOvH3FB6LWiloL3usrHLm9pgQ== +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/abnormalCertChain/CA/secondca.key b/test/opensslcrt/abnormalCertChain/CA/secondca.key new file mode 100644 index 0000000000000000000000000000000000000000..803c11daf385b85f69e000111c71a08968d0bced --- /dev/null +++ b/test/opensslcrt/abnormalCertChain/CA/secondca.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEA1ymjuUbm22JLzIm3PEZWU7uUDRGo4MvXE7zgRL6MTcI8sJWj +zeX3kT5kdBllYvICv8OsYKas/3f44Mg0aQy8tmURX69hzHMxl5HBy0hPKBpJrFXp +s432bYXeTWJBlL+AqQv0vWL8XeeQ4AS4JgyDAd8/L78T3sf0QJGnC3abOV3jRRBm +PWdBjWdzyXkXXSnkZKI2nnD3w38WeT1kIRJQq0++IrnjWWO+z62+BK+4q0kkht2F +uf7Fbk72FATv272K6ajXMFKNkowwx1GpMhqq3/Q5+q7f2D86x94sJ8av40X2y5Np +4kQwU8zM6FJ2DL3N/5/KRUka5fF950vNuA/TqQIDAQABAoIBAQDL1t8NQGaloNI+ +zJmTuO9AFI2GdBySG4t/X4j4l61EXagxgxLUlfGc4Ic6lnS+8Jg6JJ7CUiXDQV2/ +VuyQOUjvY4C6LeVxVBC/j48Rj0eurnjtk9b8DJpR2Glq1pNa4LJ7dKBAa+666A8Q +rGfpZCEZPO8XxOaGQNjd8x9WdN9J0DQ1oSsvRC8xLyuKK1jrxrIaXMK1XZRgaNMC +atr3fCo+yChyuzsSKHdscq+2MA6xnp3dc8Q+nwhx1xFSdfah+NCqGfmVYJRDgyCT +GuZCkyi8R1+jkGs4dO/cBTTRlkiWHHFLJo+/c8vB4Do1JkA93PoSYa6G4UAUJA3N +B1e16TihAoGBAOttoHES6CpCiZdOtTTk//zyWxrG8CjnHRCj8g39QHnAYyzlz88h +RYuyfrTPviGuE3+3CMWsjvPag1E20NpyrsJtx3OBhcDG6fj6l7j1vuTngl+mlM9K ++n36BcClTAnY3JZf51FErmUYkTVh9RcQ4iycV7cl65OouNwGVc90WCpPAoGBAOn2 +r3OIC7bal0kxrI0bT65gkbvg3q1sA1cpcu4HOZJUK7RjdYzsd7+lCsSPfVyp9yN1 +iw3Bn2umEkkogxy7wShT27Ng2LNyHRrYhVIuEee/+SSZgMDa+0cTOdKPqRVITffb +weM4Z3gVrc27EtWuc9vGUkNxoceJ3MKHpesDKTyHAoGAMSoFlVdzcE/Q1+4x3Uft +RW9/IwpkYMZSxYTXKaC3dDV/AINFcGXsVg4Cc9PmSrZFkCgzBsTQXZBGWBFwcA3+ +/M9cFXz455ciiUIbqR54rOjDyyHIdbmcse4igWaDiJLnDegdMFV9bdNBj7pTKmv2 +L4a+spqSpZVYdWpFRTtwpfUCgYEA6b4F7bOWmHlsubiB/nuxsLJUBtMTRUlrUPJd +G0dmkjW7cD4Jm+BHhtTZnCULBr/b47Y0VWsC3aaOED8ENnmx8ZtOHLj95tF0GHUH +RWI3i0Q1IgamJobgklK36xCRyWxyUNVhsKOSY9usx6RFnevrXj+VwkHNci/euQ6S +ieeflBMCgYEApVqc4NqEuS4i//GGdQ9fXPx9eXBOxnqjwXL4SUdhwSruBYPIDavD +Vb66fdGQU7wAeVQczeIZEHDFuLkePvbwPeQmdDgJBN88lB1TjxvwhYvOBz0cJbl0 +NN8bkkMUsjZ93sZaVKgKYf7Q2JEwBRdagmWktcx9omJYPs+GAxIdVnE= +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/abnormalCertChain/client/cert.pem b/test/opensslcrt/abnormalCertChain/client/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..8b0df4f25fe076fa11ef29e4a77b55019842e0b8 --- /dev/null +++ b/test/opensslcrt/abnormalCertChain/client/cert.pem @@ -0,0 +1,70 @@ +-----BEGIN CERTIFICATE----- +MIID1jCCAr6gAwIBAgIBAzANBgkqhkiG9w0BAQsFADBvMQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MREwDwYD +VQQDDAhTZWNvbmRDQTEhMB8GCSqGSIb3DQEJARYSc2Vjb25kY2FAd29ybGQuY29t +MB4XDTIyMTAxMjA5NDY0NVoXDTMyMTAwOTA5NDY0NVowcTELMAkGA1UEBhMCQ04x +CzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsMA0RldjEXMBUG +A1UEAwwOdGVzdGNsaWVudC5jb20xHTAbBgkqhkiG9w0BCQEWDnRlc3RAd29ybGQu +Y29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA5evH/YQgDHryEt/6 +OV1ALabNSIGgZjLe0Rjouo684vPgROD6qLnfDKZO2llfkupWOtyQJvk29giiNa20 +gtnEinhr6xemJ8PfYxXoTQ0pDCJQQ+bozuJtjQ6D5+IdECe+rbecxEixIMJiUIWL +7Pu5wDrqHKI95WJr3XDil86+jtXuMB6U4Do3RP7KW5MFaoyn9mrQIKUr2Q6ssQ4E +UQaYBa5KyUfHv21HspK9x1T7uLWGZ7PeU/DVfJQdLk9YOldyIVwrtAC4kla/KSVW +R9lMpSm3kLq84Z1DCky0wnizQ0eG+GbSI8hJAFn3bwfBmu99RDitlz5GYY//mD5z +eDXpHwIDAQABo3sweTAJBgNVHRMEAjAAMCwGCWCGSAGG+EIBDQQfFh1PcGVuU1NM +IEdlbmVyYXRlZCBDZXJ0aWZpY2F0ZTAdBgNVHQ4EFgQUpij+s+IWcrTecYoQvnlc +kJoxfAswHwYDVR0jBBgwFoAUx4OtuIqpeEc9x2n8I2lL1Wf8xN8wDQYJKoZIhvcN +AQELBQADggEBAGtB6+WR4yZt8DuYJ7SQ4gsEBAo6BEC/lX6Nv4CeSRKBz8RDGuN5 +zAfJtC9k9c2l03tAYQQnwkYiqsEQoN6ojz0JckmQb14y3dj7lClgmMCBy6BPjkDL +85dz/4NFR2K4VxJ8owYaNjBCQFp1oXRq+wcluWseejhoJIJd49H14gkNcVg2j1ep +cRdK4kZL3g/b6ub11NUvt+2RfrQQUCp43FpK86fLd1dEnecyFCgFOyH+larPUPYt +7ftZWkLKih/T9kO1UHmaXb4uZ0VZg5+t6Ut/jeD3d8QnHMftL/TKOR+tihwlEZmW +TyObahHmz0SMWQfUIW3hfNDuQzPybKnycC0= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDvTCCAqWgAwIBAgIBBDANBgkqhkiG9w0BAQsFADB4MQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxCzAJBgNVBAcMAk5KMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNV +BAsMA0RldjEPMA0GA1UEAwwGUm9vdENBMR8wHQYJKoZIhvcNAQkBFhByb290Y2FA +d29ybGQuY29tMB4XDTIzMTAxODA2NTYxOFoXDTIzMTAxOTA2NTYxOFowdzELMAkG +A1UEBhMCQ04xCzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsM +A0RldjETMBEGA1UEAwwKU2Vjb25kQ0FFWDEnMCUGCSqGSIb3DQEJARYYc2Vjb25k +Y2FleHBpcmVAd29ybGQuY29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAuX/+vVcic13kBJy4ZmhSxIaQRaA2ElVS5WTZ4PG8BsMRAhtcooz/SKF+4NWo +8bbI/Djf9T5349zHDrYkVwJW4jTwzkZjI/aIEKwFqgotbH2K3a6AcY9mHvnOvfsY +1xPC0a/438y1SkmJRTsfMKdUYKcCJ4/NiQrVZ5Z9lKb9x/Zo9Rhm7+PJL6MH3Mdy +vUDMLRLsUyDZTH4ofoafPbXFLtq4eTzzdKxgXrixT+4M2aHbPS8BDsDxndwzse6G +vXwZ0LLGJ65LqlVwHbBZDtZeWoNw+q9Bhhxk0Cf/qd9Ck2hMwj/YV/Ihiz9BFcTt ++rZhVkumWkAuZYkD9xkLdzNvewIDAQABo1MwUTAdBgNVHQ4EFgQUwMtjuRuv+cY1 +Ea8kWrf8MrQdlvAwHwYDVR0jBBgwFoAUVxgcsuTeFG4j9DXEylqY9sg7Cd8wDwYD +VR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAmMTLv3E672GVkCtioYjG +oh+taqXMyskc9vCONK0CuEWStgSDPdzCZNz8uoX+dCPmgYxeju+98c1cG9Rjg2Fv +xPqWi/TKes9fy3VBYkZFSU72tu1fXBXhoIFO34Pmq98wsZmzNgN6bfT8f9tDcif1 +UUmHkO6L4L9ZP4GwqgHCzTKWuzbbvNnWpFKXfbzyhESGP6H18RdtWSWg8i1zKeVI +d26GnNHNLxAi92rKQaCAN5OqsmhmBGehdfNdmkIdJ/Kxk7LYbhn+3KLARUIwnS4x +VCN10ITB8V7692SD9pF7BZQSkW5K7GakfI5L5RBxIej4awVCzMrDD4tl9aQriEKw +FQ== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIID0TCCArmgAwIBAgIURA/lUFrZwOfJXaaLF9ol7wkhAxIwDQYJKoZIhvcNAQEL +BQAweDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQHDAJOSjEPMA0G +A1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxDzANBgNVBAMMBlJvb3RDQTEfMB0G +CSqGSIb3DQEJARYQcm9vdGNhQHdvcmxkLmNvbTAeFw0yMjEwMTIwOTM3MjFaFw0z +MjEwMDkwOTM3MjFaMHgxCzAJBgNVBAYTAkNOMQswCQYDVQQIDAJKUzELMAkGA1UE +BwwCTkoxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MQ8wDQYDVQQDDAZS +b290Q0ExHzAdBgkqhkiG9w0BCQEWEHJvb3RjYUB3b3JsZC5jb20wggEiMA0GCSqG +SIb3DQEBAQUAA4IBDwAwggEKAoIBAQCgvTn1zACCZ4uny3BW8Utilwoztkhb/XM+ +ZI/trCZg1smsnuCNJHyIJVoFz4PoxXESCueTD0UwIcrftvzQPJZzVZEOY/ND4ZFq +Bj4TbCybSNuFIIAXn2yL1x5oLGz5wuEr7XClqUECVZPTyDv2ozg7+L6NRXNnQ3DQ +jL3QqEaH0M0hA/4FX7ySXrSC2BFX5LZzv8cjKla+3jqJUbUxokxEWMfYVNU+JjwV +vS4ieVqIsWcfd+FhqFCpvWj92PpJB5Lk6GkuGi026lgYutK30Gx133QLQjrNRRLG +4dM0KMevpSM1Ug7dXy60dIjlJkFjekEf7umCGFNILf3UKQTd1irnAgMBAAGjUzBR +MB0GA1UdDgQWBBRXGByy5N4UbiP0NcTKWpj2yDsJ3zAfBgNVHSMEGDAWgBRXGByy +5N4UbiP0NcTKWpj2yDsJ3zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA +A4IBAQAaLNCcpUpKjxz3DKPU632uqljaliu5mCm9OAQRDLccfZifIHacyxx45rm7 +8TOSgQnSS41UtsvTMn63LXqWCi+P0K6LVZUDAiQzcehGsKYyj7yG8UiQev3qUXuU +HRD+n92AVKYz4ABkKqepgVqE1G0CQeFISU/czTv4r3+Qv5IGXWuZ9D9fNu2Al8DG +Zo4FdcxsBESksDP07Gjj3zMRgV0uhsoSP7gEc2zrucfSIZasEuIgPGL07vMc4Ybm +EaY/cZcuSj+o012wmP6YMQw4t74Fxqe41l4lY/yPtgUO9LDx9ygiDzEKaMCTea8M +33Oqg/tD4bzToCPlpeh7qhpUfLpP +-----END CERTIFICATE----- + diff --git a/test/opensslcrt/abnormalCertChain/client/client.csr b/test/opensslcrt/abnormalCertChain/client/client.csr new file mode 100644 index 0000000000000000000000000000000000000000..ff6829e03e04ae02ef15b910f4f3b44e4748f0d1 --- /dev/null +++ b/test/opensslcrt/abnormalCertChain/client/client.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICwzCCAasCAQAwfjELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQH +DAJOSjEPMA0GA1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxFzAVBgNVBAMMDnRl +c3RjbGllbnQuY29tMR0wGwYJKoZIhvcNAQkBFg50ZXN0QHdvcmxkLmNvbTCCASIw +DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOXrx/2EIAx68hLf+jldQC2mzUiB +oGYy3tEY6LqOvOLz4ETg+qi53wymTtpZX5LqVjrckCb5NvYIojWttILZxIp4a+sX +pifD32MV6E0NKQwiUEPm6M7ibY0Og+fiHRAnvq23nMRIsSDCYlCFi+z7ucA66hyi +PeVia91w4pfOvo7V7jAelOA6N0T+yluTBWqMp/Zq0CClK9kOrLEOBFEGmAWuSslH +x79tR7KSvcdU+7i1hmez3lPw1XyUHS5PWDpXciFcK7QAuJJWvyklVkfZTKUpt5C6 +vOGdQwpMtMJ4s0NHhvhm0iPISQBZ928HwZrvfUQ4rZc+RmGP/5g+c3g16R8CAwEA +AaAAMA0GCSqGSIb3DQEBCwUAA4IBAQCsGnU+byNkW1GA8XNiya/xcwFhBrgZ8ytk +neqb0s1VdcTpM0F3OhnaYsBbAuWvuuXnKX2fUGD1TPiB/cuiOULxMkR2FQvA6l6g +pOUTRP1sHsUbAB1nLEZ6ZNmYISLY5e52265SFGw8moQ29TEXjg1Fpsgpiol5Xc9O +pf8ojhCgKmCgUgMd48N11BPzTB+8vwiBG7c+pJr3RU+g/FjtGXTE2g7GqWDAvcaC +4F7L+PBtlocjI4K2ci1yEsFYXoa0ZmOZjLvTOS7A3hDFM1ga1pjTh1hejGwC+d1u +uaFXa2XJa89jimpZH609gAv+6KlGXB5SVgfwVG9DjKMFmJJYVmpV +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/abnormalCertChain/client/key.pem b/test/opensslcrt/abnormalCertChain/client/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..d0e8e05724fdfadc4ea309d388bac03c07e18029 --- /dev/null +++ b/test/opensslcrt/abnormalCertChain/client/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA5evH/YQgDHryEt/6OV1ALabNSIGgZjLe0Rjouo684vPgROD6 +qLnfDKZO2llfkupWOtyQJvk29giiNa20gtnEinhr6xemJ8PfYxXoTQ0pDCJQQ+bo +zuJtjQ6D5+IdECe+rbecxEixIMJiUIWL7Pu5wDrqHKI95WJr3XDil86+jtXuMB6U +4Do3RP7KW5MFaoyn9mrQIKUr2Q6ssQ4EUQaYBa5KyUfHv21HspK9x1T7uLWGZ7Pe +U/DVfJQdLk9YOldyIVwrtAC4kla/KSVWR9lMpSm3kLq84Z1DCky0wnizQ0eG+GbS +I8hJAFn3bwfBmu99RDitlz5GYY//mD5zeDXpHwIDAQABAoIBAQCnLeM0Nm8rQ/Zi +vRNvxJtW1nNr5j1gMlsLxUXr6L/1cgi/bKs2JjjGNOMfJ180L0pV8Gysugc5rJtt +1olrn7amTNuDjKWXQnhazuIjrI8NMKIWTX84dzHbIBPPdv1U8uFV5S2LF6QbwtvD +2uccgQjWesAh4+KHuSHfWSaZ5Y1vw0MGNCRkyVx13ovDM0gjT694aI6y0Qd/mXm4 +oAvTZCU7tJ/45nSDiWqSQdOJJdQjV0nOoKj0Fpcr7mF11QYALjXmC04YfO5DYaCd +O4wQxZYGzSWyYoHLFonlc5uL1M/ZYP6n2HPvHCeyvJY6q+kTA3OmhnOIKwoCpObb +FV3JETmhAoGBAP4Kji7Ig2bNK8oYNm6sUFJR1tnLSHIkSmNCbpJ4vQE0RBIU1/US +NrrUBAWm6AppFZPsUsLc42O96GOv/P4tOioKsFUJCi/w59A3Xt04tapHBTEfUiJy +bOLIC1WYEhsZ4u1e1u6hfvyAuJVYLqkiPVsRLp2QV6PUP1cSyNY3mHEPAoGBAOex +nZmytTkL/na/aQNKPEFQaBvQJQUKw7tWcwZyjgm4Ar03ojYw04TkuP0UHw+M0R+M +rjoeuGrnOktjwlMlmt8rMRxO8WeVr1tA5U6658tzt2Rprq1INkvVeLoBtGP29FqM +LI2YG4uKou43LU+jgPrkLkEt0bhb/mAdhHrR4ObxAoGAdgKZQgpLYDn3GY5d2tOZ +DGSQFeRk5wEMvUdi7g/AXQrWhD/CgknPusI6jBWYvR1LtMeXOoY561+Q0J40PC7u +UhFdEGN+o/6Y8RSHsORjH5KWStdt5Cqbgk3DViOqZYSE8heYaIoE3288T8QDCPaq +4d79dJxU2foC4oQLX9e7rOkCgYADaV0dt0Dt3xxXGUhtkPlEKO/vgOgao+bv6jz1 +Wlh3EiuQJ7KOw7dJnKiQqWwvqW4m3cZu+qbShCcalxR0bvhR0uv9M7hgQxb67AC0 +YRIqr8CCjP/Sc17BTRpi+sVyN1+vuaKqTxQQwPDXOx7CrnCmwRdhRFBzO3+KYMTj +nhWGsQKBgFjQkaxPXGpNaMOGoEzNir/4gvGsxPICUEACfA/HAAaTZm23Ly5+ThGO +Pvg4hqY1C1T+ksn0pGIlkGRi3Tq/QnX4NyxouskQHGswDq6wVGYHu00HizmHgYhx +XXuYhcynXchaHnpHNVxuq85cw5LWs3rPwkH50Kq4jnJ7wGhQLVB9 +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/abnormalCertChain/server/cert.pem b/test/opensslcrt/abnormalCertChain/server/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..12664c29bfda75fce75584e7bd8f2da99a396085 --- /dev/null +++ b/test/opensslcrt/abnormalCertChain/server/cert.pem @@ -0,0 +1,71 @@ +-----BEGIN CERTIFICATE----- +MIID1DCCArygAwIBAgIBAjANBgkqhkiG9w0BAQsFADBvMQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MREwDwYD +VQQDDAhTZWNvbmRDQTEhMB8GCSqGSIb3DQEJARYSc2Vjb25kY2FAd29ybGQuY29t +MB4XDTIyMTAxMjA5NDQxM1oXDTMyMTAwOTA5NDQxM1owbzELMAkGA1UEBhMCQ04x +CzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsMA0RldjEVMBMG +A1UEAwwMdGVzdHNlcnQuY29tMR0wGwYJKoZIhvcNAQkBFg50ZXN0QHdvcmxkLmNv +bTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKfbo1HR8Hxv3qMqNGh1 +ufPv+Hb2UA8uTzyMPv1gG+331X2sQUqbv15jEuTLcjxoOYuCOmiFQPHJpj+I75cP +RrIuscOj8BUnWJlpejaPaICS2XOwzMG5Ae5rUkPoLyVcjhlu6Y3A4kFYaKvH8+58 +XwI1FM8Obswe+XWnkZIj4C5OChYIj71Zh2IS2zpPQjlnYkAA3FJmK7LKocMubjZO +SlhvIwzbBI9069IgEOOrg1xKrlm1gjbmfGFlXMYi9zLEm0SO8dhU5cWLY8QQKYkd +uEF9hurwe+fspGRvTPhbT7IR8yW3xw7kah7TUDs5b3jM6Me9U0M7frb23s7wT31p +2mMCAwEAAaN7MHkwCQYDVR0TBAIwADAsBglghkgBhvhCAQ0EHxYdT3BlblNTTCBH +ZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0OBBYEFC7exNc3BjDBOciBzo5qVY+d +qPKIMB8GA1UdIwQYMBaAFMeDrbiKqXhHPcdp/CNpS9Vn/MTfMA0GCSqGSIb3DQEB +CwUAA4IBAQBq2XwbLsO8cImd1dOu05V6U0zW+/QM9WjP94hVUJiyQ+yexHNLKcvP +x+2G44n4/qUr80tvgHr1/ok9yQzaL+CKQx/hLYs45IGElLcsPU2tV0QZFOlcUgmu +mZODjuTUZzH5oZMcVuPm1IiAV7TFhn7Z0q4I+tticZnPLsdW+bkPQQ2ZrJodmbBb +85vWb92qi9kWUlBPaLx7ntvP7FMOWQYPyWpQE2ipV3vIUqXUxmXBjanHjOxIjMUE +DpRpJkrpPUrznj/1/jYpUQR7ohEsRSbNzKgJzKIgm2QvT9C7cK40bsp9eX3VSEYU +f1JHV5LnmGU3073GUBamF+OOlxWyT47A +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIBATANBgkqhkiG9w0BAQsFADB4MQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxCzAJBgNVBAcMAk5KMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNV +BAsMA0RldjEPMA0GA1UEAwwGUm9vdENBMR8wHQYJKoZIhvcNAQkBFhByb290Y2FA +d29ybGQuY29tMB4XDTIyMTAxMjA5NDIwOFoXDTMyMTAwOTA5NDIwOFowbzELMAkG +A1UEBhMCQ04xCzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsM +A0RldjERMA8GA1UEAwwIU2Vjb25kQ0ExITAfBgkqhkiG9w0BCQEWEnNlY29uZGNh +QHdvcmxkLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANcpo7lG +5ttiS8yJtzxGVlO7lA0RqODL1xO84ES+jE3CPLCVo83l95E+ZHQZZWLyAr/DrGCm +rP93+ODINGkMvLZlEV+vYcxzMZeRwctITygaSaxV6bON9m2F3k1iQZS/gKkL9L1i +/F3nkOAEuCYMgwHfPy+/E97H9ECRpwt2mzld40UQZj1nQY1nc8l5F10p5GSiNp5w +98N/Fnk9ZCESUKtPviK541ljvs+tvgSvuKtJJIbdhbn+xW5O9hQE79u9iumo1zBS +jZKMMMdRqTIaqt/0Ofqu39g/OsfeLCfGr+NF9suTaeJEMFPMzOhSdgy9zf+fykVJ +GuXxfedLzbgP06kCAwEAAaNTMFEwHQYDVR0OBBYEFMeDrbiKqXhHPcdp/CNpS9Vn +/MTfMB8GA1UdIwQYMBaAFFcYHLLk3hRuI/Q1xMpamPbIOwnfMA8GA1UdEwEB/wQF +MAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGDOfMIV1AC5bdbpnhSDhpzkBjh3OXqj +PUplhCm+VSFUcSE5vTe4EnUsCmpIAYkAA/2fnr6Kb9GFJeuoWlWpo4vRW9g8qvRD +B7AoV/ojafPb4V2epZLaHpJQcxyqplGFKqPKDAXZlXLXUbWLAutq8q7GLL3MlqDF +nRDHBafXEHJOBQG7HyTDt20uoIM0+10ehJW0LlO2UD/SQ5ZEkYwB5fNk4rJMgtmb ++Nibo2x2ZdMd0rakv5CGCjf13aMczngYcbNjnUFjiLofHTEcoYmDVTqjd4W416Fj +2R+WfdalPEctoRCjDMi7dZLxQOY7/IesZbXJkVMtbP8lrVb+9Xl7lj0= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIID0TCCArmgAwIBAgIURA/lUFrZwOfJXaaLF9ol7wkhAxIwDQYJKoZIhvcNAQEL +BQAweDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQHDAJOSjEPMA0G +A1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxDzANBgNVBAMMBlJvb3RDQTEfMB0G +CSqGSIb3DQEJARYQcm9vdGNhQHdvcmxkLmNvbTAeFw0yMjEwMTIwOTM3MjFaFw0z +MjEwMDkwOTM3MjFaMHgxCzAJBgNVBAYTAkNOMQswCQYDVQQIDAJKUzELMAkGA1UE +BwwCTkoxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MQ8wDQYDVQQDDAZS +b290Q0ExHzAdBgkqhkiG9w0BCQEWEHJvb3RjYUB3b3JsZC5jb20wggEiMA0GCSqG +SIb3DQEBAQUAA4IBDwAwggEKAoIBAQCgvTn1zACCZ4uny3BW8Utilwoztkhb/XM+ +ZI/trCZg1smsnuCNJHyIJVoFz4PoxXESCueTD0UwIcrftvzQPJZzVZEOY/ND4ZFq +Bj4TbCybSNuFIIAXn2yL1x5oLGz5wuEr7XClqUECVZPTyDv2ozg7+L6NRXNnQ3DQ +jL3QqEaH0M0hA/4FX7ySXrSC2BFX5LZzv8cjKla+3jqJUbUxokxEWMfYVNU+JjwV +vS4ieVqIsWcfd+FhqFCpvWj92PpJB5Lk6GkuGi026lgYutK30Gx133QLQjrNRRLG +4dM0KMevpSM1Ug7dXy60dIjlJkFjekEf7umCGFNILf3UKQTd1irnAgMBAAGjUzBR +MB0GA1UdDgQWBBRXGByy5N4UbiP0NcTKWpj2yDsJ3zAfBgNVHSMEGDAWgBRXGByy +5N4UbiP0NcTKWpj2yDsJ3zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA +A4IBAQAaLNCcpUpKjxz3DKPU632uqljaliu5mCm9OAQRDLccfZifIHacyxx45rm7 +8TOSgQnSS41UtsvTMn63LXqWCi+P0K6LVZUDAiQzcehGsKYyj7yG8UiQev3qUXuU +HRD+n92AVKYz4ABkKqepgVqE1G0CQeFISU/czTv4r3+Qv5IGXWuZ9D9fNu2Al8DG +Zo4FdcxsBESksDP07Gjj3zMRgV0uhsoSP7gEc2zrucfSIZasEuIgPGL07vMc4Ybm +EaY/cZcuSj+o012wmP6YMQw4t74Fxqe41l4lY/yPtgUO9LDx9ygiDzEKaMCTea8M +33Oqg/tD4bzToCPlpeh7qhpUfLpP +-----END CERTIFICATE----- + + + diff --git a/test/opensslcrt/abnormalCertChain/server/key.pem b/test/opensslcrt/abnormalCertChain/server/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..71ec9e7e81db54ff9436879ed610102db7e881fa --- /dev/null +++ b/test/opensslcrt/abnormalCertChain/server/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAp9ujUdHwfG/eoyo0aHW58+/4dvZQDy5PPIw+/WAb7ffVfaxB +Spu/XmMS5MtyPGg5i4I6aIVA8cmmP4jvlw9Gsi6xw6PwFSdYmWl6No9ogJLZc7DM +wbkB7mtSQ+gvJVyOGW7pjcDiQVhoq8fz7nxfAjUUzw5uzB75daeRkiPgLk4KFgiP +vVmHYhLbOk9COWdiQADcUmYrssqhwy5uNk5KWG8jDNsEj3Tr0iAQ46uDXEquWbWC +NuZ8YWVcxiL3MsSbRI7x2FTlxYtjxBApiR24QX2G6vB75+ykZG9M+FtPshHzJbfH +DuRqHtNQOzlveMzox71TQzt+tvbezvBPfWnaYwIDAQABAoIBAQCPFoHAK5Au40YM +HNwT99cOBI/vCMTyS+2rlXnUj2r/jfZlbMMzkFSvZxEiC/NTXx0+uUKE+qKD+ftH +yblDMfh3x6otNcBgp+u0yt8tR04z2/qVzi6dLNJipQW5cWFPHfjb4VoiRjwYq/5+ +ALMFputufEVCw/Da+8R28OL8iqx9ixNN/bfxktbP5NqGkcEUun5wE+/0KnFzylcV +XyKIzFFun8U6P0oupjv+aoB12SET/X75Xgh3Elx5/Q9FedHRXGjFcam6NJhOG79s +G6IkyciKP7Pf1SYFX6b5ps2bo3oPKWVvJLGMlxSuxJpqmoTQOzochoZF8uak+hkQ +eHl07+TRAoGBANBEVtTyMkRv6dGgmRCPsqolhxsOkHHd5+YOThg5d3EDWRcsUidR +VzwU9/GMrF3grV7FuQk2GYqMRDjFg87+/0O0H6vANxXjNc6sv+dblOikbAJoLksr +245OeCgd383K5+f00pOHvd8KfrxwQPqarpvW2JORXHludgZRgx8FGulfAoGBAM5U +YSKDw+d/D6+PHazakLucZNc7hUUd8Gvh8EYdUxObX+Dr8UIjV/1qqYZzhPUNDpn9 +//cYq8ndAQ7gCTANESJc//ikqDBXHJ3AoMwpNX86ozNUyt6+eoQx2tizgkEMzndo +BVt7A3RjWjn/+bJJX6mIsjg8salQrMW1bOyECHl9AoGBAJwdfhFl87RFR7oxbktx +y/Wq59mqUzBnrPtQYc3a1ePLJK8wM+zxFjkdZraUQmikkJDoGcoD2aV3e3Qq6qDx +mJtBnDP8g85OYPkpmTht9/NjvOsY+Qq0N4I24+7+ZdM3dBr19BtOt09H6LSMWMkB +xj1fET2cyvrjiGk4FNfd1cx1AoGBAL+xdW1zrhbt3czl0lQ93Cnx615sVi0Y273f +dDQwGnck67c0fjlMTPuMlWPs/6IMN3yql50itrgdNFZ1nxOdkEW00bxYfkorJNML +nFkSEDncaLPQG4tGvN0E1KZwYJu/IjOd2Rxc9aC0jadFQt95e/8umSXWfdkostwc +6s3y/UyhAoGAJAG6S/k1Xg0TbVxeZvMlnmfEc7WKh1QCnsH9Va/VnPXlIlTWmHu/ +xBtBb0kC0IF5zWgHhSADz3SHDsYZ87SEqoQxx3BKzd+EU52RF2x74GQDY49mQlUR +zMnxYP3TUSEfkTuSpT8ZwuxQ7KSk5nREDzKIVfcxgFw/cmrW/eChT7w= +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/abnormalCertChain/server/server.csr b/test/opensslcrt/abnormalCertChain/server/server.csr new file mode 100644 index 0000000000000000000000000000000000000000..64037cc9b7f13aa23a7aec95eafe0b396a33d471 --- /dev/null +++ b/test/opensslcrt/abnormalCertChain/server/server.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICwTCCAakCAQAwfDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQH +DAJOSjEPMA0GA1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxFTATBgNVBAMMDHRl +c3RzZXJ0LmNvbTEdMBsGCSqGSIb3DQEJARYOdGVzdEB3b3JsZC5jb20wggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCn26NR0fB8b96jKjRodbnz7/h29lAP +Lk88jD79YBvt99V9rEFKm79eYxLky3I8aDmLgjpohUDxyaY/iO+XD0ayLrHDo/AV +J1iZaXo2j2iAktlzsMzBuQHua1JD6C8lXI4ZbumNwOJBWGirx/PufF8CNRTPDm7M +Hvl1p5GSI+AuTgoWCI+9WYdiEts6T0I5Z2JAANxSZiuyyqHDLm42TkpYbyMM2wSP +dOvSIBDjq4NcSq5ZtYI25nxhZVzGIvcyxJtEjvHYVOXFi2PEECmJHbhBfYbq8Hvn +7KRkb0z4W0+yEfMlt8cO5Goe01A7OW94zOjHvVNDO3629t7O8E99adpjAgMBAAGg +ADANBgkqhkiG9w0BAQsFAAOCAQEAC2AVrLOyTNhAdoVMzdqlXNLBmuoKSJdJdePF +uM3jkNkgfV77opTDFVL2nYxTLddfUpYq8xMpqK2shXWz5nrjn+XbqVqDyP5F6oVl +Rp0EiTKPolvr6+qREnquF7AKRn6qZkSst3/QbdFJrIZ6FjfReFxR+8d+MkhdKcUL +hX0FD8/njwO6twXWqBADZrV8rCsfuIER8+nCVCo827J7ZPNtvli31aFEi1QXo1Em +9Azvn6EULZyLUdvgu5hANyXNRa0yTY+QGZ37lTHTAuogr7PwCW2PTr5AVLU3oaDA +KJT1hWaJgpSKRI9nFut7BVaGjrRQHNH+HTN12mduJFaIIDt2Qg== +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/crlRevokedCert/CA/ca.crl b/test/opensslcrt/crlRevokedCert/CA/ca.crl new file mode 100644 index 0000000000000000000000000000000000000000..ebe976aed211be7be9ce6ad4443a4d11fe3583e3 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/CA/ca.crl @@ -0,0 +1,13 @@ +-----BEGIN X509 CRL----- +MIICATCB6gIBATANBgkqhkiG9w0BAQsFADAXMRUwEwYDVQQDEwxjcmxSZXZva2Vk +Q0EXDTI0MTIwNjAzMzQwMFoXDTM0MTIwNjAzMzQwMFowbjA1AggKiSuy5iJGshcN +MjQxMjA2MDMzNDIzWjAaMBgGA1UdGAQRGA8yMDI0MTIwNjAzMzQwMFowNQIIbYHq +/AAt2m0XDTI0MTIwNjAzMzQyOFowGjAYBgNVHRgEERgPMjAyNDEyMDYwMzM0MDBa +oC8wLTAfBgNVHSMEGDAWgBSwBN0pJZOxSzO1w60tJY5g2hRNxTAKBgNVHRQEAwIB +ATANBgkqhkiG9w0BAQsFAAOCAQEAoQPvdcxxbNzqWGVF5yvigUawx1xcalDpMNAu +XSy2ZIEbMCN/Tq3CwwZZViMsdcnnSLBGu2VeuoY9WlDAKkWCMVbLAx3h5ZFu5rQq +8KtS5SE2KudCvsJcVIlex53d4rIeOKHifALXqEE5mXpQjJV4Nhx0QeloWzA0k44G +UVbCdS9pOPfuScROT4KRMF8y7qY0sTHf1YOu8AGSIBroatqvWmdvezR7Le73zbMe +D9JqhLr4Eu8oXf3KpRGDRqdCQkqAPgOGZguYbik8y1LlMcaxD7ZMaPAHi50hpzUX +ZpAKKnElP/STJppR8bKUp4UfxjtCVk63iQzdNf/YBNSrmFx2xA== +-----END X509 CRL----- diff --git a/test/opensslcrt/crlRevokedCert/CA/ca.csr b/test/opensslcrt/crlRevokedCert/CA/ca.csr new file mode 100644 index 0000000000000000000000000000000000000000..77b5f58c0ba02c3c1ce0b79c319f469c5b9be769 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/CA/ca.csr @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBsTCCARoCAQAwcTELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkdEMQswCQYDVQQH +DAJTWjEMMAoGA1UECgwDQ09NMQwwCgYDVQQLDANOU1AxCzAJBgNVBAMMAkNBMR8w +HQYJKoZIhvcNAQkBFhB5b3VyZW1haWxAcXEuY29tMIGfMA0GCSqGSIb3DQEBAQUA +A4GNADCBiQKBgQDRgBAq+XqNt6FIP4wjozb7VmZGUwtpRKRbcWYWEzJa2ERH9mxE +1hZYfxrevj0goWXVZH+Cxwi4eAzghdOAx9Y3cGR+1Jl3pd4GT8xANjZ4D5Yo2Zbu +anwWk21C2MfBcJOfwxNgVVqFaU23fRnemq+LtHAOnnUNFXUMx16D6R+HKwIDAQAB +oAAwDQYJKoZIhvcNAQELBQADgYEAZPfQjs5uqtOv6OASv5n1T7nWwjLUX/Nqhn9a +WJ3ut00k0yc7NmFPOlWW/FfUI4uyKkRLo7GtXPG0wG4Pn+Pe2UvM0LLXg0Sg47zL +bPV75ampPYD3KkfuQM+GV1dwHLz8kgl2OC23Jc1viZ4Z/ZsJQzuO2D1Wgqm40Dc6 +oVmSbvs= +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/crlRevokedCert/CA/cacert.pem b/test/opensslcrt/crlRevokedCert/CA/cacert.pem new file mode 100644 index 0000000000000000000000000000000000000000..8c3470d84c7abfeac46a8e4af261c9621fcb3001 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/CA/cacert.pem @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDNjCCAh6gAwIBAgIIXsZPXhXOcIcwDQYJKoZIhvcNAQELBQAwFzEVMBMGA1UE +AxMMY3JsUmV2b2tlZENBMB4XDTI0MTIwNjAzMjcwMFoXDTM0MTIwNjAzMjcwMFow +FzEVMBMGA1UEAxMMY3JsUmV2b2tlZENBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAqUSZF7HGpEaA+vAB1u7s7LUbAmJZ2FZmEDaUimbXmPM8S5Ipabxd +f33YWR8V8W5ZdVWOBV84ginwYGobiyFt8wNoFbE4EO6GjQ/bBS6PJaSyJHBvDHoD +NKgmm1EP8vMsmUHfw25BpAVKn1O1oxLnVDmFEZNUzP9xwRskgWODcJZBop4Qfw// +ByvfYXsQNG4uXQZjAMhi9k5t7OlSePdkUhWFcZ6WSphgSVssumGQGhDqzyTTeJVY +Dq+BybLLMlCE7JaNmoI+UlD829HMrlsFTgluc7xUAzOTJIDPG6Bbj8bXySjBeQSw ++D/ziNTCJp3B6HCYkZOJJnS+NNP59L0F9QIDAQABo4GFMIGCMA8GA1UdEwEB/wQF +MAMBAf8wHQYDVR0OBBYEFLAE3Sklk7FLM7XDrS0ljmDaFE3FMDAGA1UdIwQpMCeh +G6QZMBcxFTATBgNVBAMTDGNybFJldm9rZWRDQYIIXsZPXhXOcIcwHgYJYIZIAYb4 +QgENBBEWD3hjYSBjZXJ0aWZpY2F0ZTANBgkqhkiG9w0BAQsFAAOCAQEAYJRmR07m +G0RYrwCC/xjzWnrpj5uTrEp5C1EO9H0pBEpBshm2Y+nFMMd8mO240rR0FOXr5t+3 +EG4Nr76XpzIwH2nt0hcR03+0VanAHtM4QqMifRQgD4Q1OgKCGB1qYWY+d5fY8zBj +VeKsW7FwRraL+MSjqdB0VxM6wCHc1xyJpy8JBO48UprqOXP4A5ZJwfSXICa4cKQ7 +Y9OIjT/9/CyIo/8ErgwnZBywiIY6suS/Q523XBrg+BdYyEOoJL1+kNZyAnFpv3pn +eRq00uA6JR24aU8hkyJY6//Uq8M5OuE5d2wtrExsd2WxoTq0pnLBibzM8HNSdG4Q +t0RSmlMCHE/5Zw== +-----END CERTIFICATE----- diff --git a/test/opensslcrt/crlRevokedCert/CA/cacert.srl b/test/opensslcrt/crlRevokedCert/CA/cacert.srl new file mode 100644 index 0000000000000000000000000000000000000000..581cf6858fcecb3a24c7226df85f09a7a83e8751 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/CA/cacert.srl @@ -0,0 +1 @@ +4D6E3E9110C5EBC3BF0FCE8D5C34A714B72F963D diff --git a/test/opensslcrt/crlRevokedCert/CA/cakey.pem b/test/opensslcrt/crlRevokedCert/CA/cakey.pem new file mode 100644 index 0000000000000000000000000000000000000000..3b414f03b037103bfff5af6f8d3559ebdc8341e9 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/CA/cakey.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAqUSZF7HGpEaA+vAB1u7s7LUbAmJZ2FZmEDaUimbXmPM8S5Ip +abxdf33YWR8V8W5ZdVWOBV84ginwYGobiyFt8wNoFbE4EO6GjQ/bBS6PJaSyJHBv +DHoDNKgmm1EP8vMsmUHfw25BpAVKn1O1oxLnVDmFEZNUzP9xwRskgWODcJZBop4Q +fw//ByvfYXsQNG4uXQZjAMhi9k5t7OlSePdkUhWFcZ6WSphgSVssumGQGhDqzyTT +eJVYDq+BybLLMlCE7JaNmoI+UlD829HMrlsFTgluc7xUAzOTJIDPG6Bbj8bXySjB +eQSw+D/ziNTCJp3B6HCYkZOJJnS+NNP59L0F9QIDAQABAoIBAALgDKMZ5eKJdzD5 +uR9rTIv3WG1RlVwbdCsVDIuX3Xjx2DEibYU6V6mpYFxXLugqKzBpZR9lheJq/T03 +KVXlmfiVad9pmUtrMKpVK4sqofBFHEQcnjv9qbj9wVqyd7r1PFczhsBvk8yTOrh6 +03WFc8CfkclvkZcdjkEIZRFrk/TF0GTR5AgEjRKQwjdoX+VuX3fOfX/+8SYl9Vus +cF7zZ2I9Xh3srIY4DtX0nSzDhrHwj7lZ9gE5BiGIICG6ofSmcFxY3sz3HkcaFpr0 +kz+29BUwWjRn8waqkCbodqgiCl8YxzT7AkAzV9t381AwlcrjUGBcwZd3ygIL/CNG +bblO4ysCgYEA2gpeQh1HdQtYNJ9f97Azmn7Ub/ZWzQz+pxn4EJJthJup7qTSkEaP +s/ReN1xSOymSe6pjDhCJJEzpNXVkSL+LfcTHEBCWKTxVKhde64B+8npsE5TNPm6/ +RwyUhpZqSUg6+hB+DWAONEd3NR4uEOloCfGTq+s8XiNsy+wZB/NgemMCgYEAxryI +3uDaTsHl6tB9lQmtiKxibyz32gaftFB3ui5W4E3C6PpgyiJz0vq42+LJEhOrNO7r +R6wLfyPDQS03s1lum6WdIIGttL7oOomA4H0AzKvL4a+KzW0xI9szwSn36gl0TpdJ +6tSLDCIebEl2BqknXf27PsLmZGz/Q4y4BZBGgccCgYEAwfCY+Ku8ZJJrqZrdLJ4w +uEn5wYDWolrdo1qI+IyWF26yTw+SLzxkE8fXidx3VCJdJdxfcSIqKfyzLpa1VbPy +ajL+cP6f4hvFEMRq7ISF5j4YJa6khNTTFGpPLgtajMZPlx/WTyVSXRP4J2opxWcm +4DldhXyeXImb+yKM/TJwbUsCgYBVYAUEil/+WXnmCC8K6Z9tCXE4BX9sYFAXbEIi +72lO7tj+sSEHdQmA50im3aQ3ac+w67qTjatasQheZYh6Ob80ik8X+d2fuiuViiCd +T9YaFut9RqZAXdowZyiyoM7p8anPYVrV0Pzy921OY7iaX+yTij74VU4YcmKuqDMd +4L8JXQKBgQCoZ8iIQVjzLlwoYwK3FTSxub6s89E1I4Ej28wtmuE8cZpGqpZ1p2Af +BxuCq5EwzB9uwomy4sVnlI7eL+1RPBLTd4dz3hMzINNzs4E0S9358oJnbaVgJolx +9jTw37UmogGQFRK/SI6El4Uey/8Sg8NGca4V647DyT5hHXlJjd3w4A== +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/crlRevokedCert/CA/crl.pem b/test/opensslcrt/crlRevokedCert/CA/crl.pem new file mode 100644 index 0000000000000000000000000000000000000000..7d1eb005222b656feb51bf9170120d63f037ef41 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/CA/crl.pem @@ -0,0 +1,10 @@ +-----BEGIN X509 CRL----- +MIIBYDCBygIBATANBgkqhkiG9w0BAQsFADBxMQswCQYDVQQGEwJDTjELMAkGA1UE +CAwCR0QxCzAJBgNVBAcMAlNaMQwwCgYDVQQKDANDT00xDDAKBgNVBAsMA05TUDEL +MAkGA1UEAwwCQ0ExHzAdBgkqhkiG9w0BCQEWEHlvdXJlbWFpbEBxcS5jb20XDTIy +MTAxMTA4MzQyN1oXDTIyMTExMDA4MzQyN1owFTATAgIAmRcNMjIxMDExMDgzMjM2 +WqAOMAwwCgYDVR0UBAMCAQEwDQYJKoZIhvcNAQELBQADgYEAcrqdK6qqVnX/01hg +eIzyVV8dudN8lojKtzw7EZ9ukxmcAvihTV+58Le0UhfBX2Bt9iTmrsekzwgTbuoM +/vqdMej+jRkQ0phjyn5oIRMbFTQDZmqID3mcc2dMYL0jdmbonIV1NgaAAQqpoapx +dGc9hJWvHOSrWsQ0b1bbIt4tOXU= +-----END X509 CRL----- diff --git a/test/opensslcrt/crlRevokedCert/client/cert.pem b/test/opensslcrt/crlRevokedCert/client/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..40650c392eb179083beb6a3c5c1d963c9059a74f --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/client/cert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDADCCAeigAwIBAgIICokrsuYiRrIwDQYJKoZIhvcNAQELBQAwFzEVMBMGA1UE +AxMMY3JsUmV2b2tlZENBMB4XDTI0MTIwNjAzMjkwMFoXDTM0MTIwNjAzMjcwMFow +FzEVMBMGA1UEAxMMY3JsUmV2b2tlZENBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAuusB0DBPmC2jVjE/xXE1DYLLBJKvE14Ad2XKx/+sIURu2qV2bNio +Oj7DOfOJPSGmCnexo941rSSAP2+Uay+vrT3ePY4jahYbscR440PVx0SQRWaZJDTB +n2mD3NlcSEWkxNqxMdEyOLy96hoPJjnqOe3pBqMHqIvy2ZlasHl/b4ALvaN+v0C+ +AX7lKmwD85TEmFora1L6vVtJ6WQbFDJHs+SYKM9V8y0F3mGVrUAj3Bh94nSiwrip +hDE9g1ODt3/GLsNzhScPAYpyp5UQxy1YCUP3L9+QuJLLV5MQAbol5I9DFpLSEHEb +NJodgrdD7VVbZfgJVdW7az97WLFOQIS+ewIDAQABo1AwTjAMBgNVHRMBAf8EAjAA +MB0GA1UdDgQWBBRK8HIpowCVXrT26anmVra7VuhqGzAfBgNVHSMEGDAWgBSwBN0p +JZOxSzO1w60tJY5g2hRNxTANBgkqhkiG9w0BAQsFAAOCAQEAcjovMcD2dw3Oabk2 +qwov+8Fy7calla3tc09GdxXksqcNeLj+DxmdeNAj7Ei/0p/e/Y0gEKhWMUeVBUB1 +hmapuiYWPi1WID+ZcZA1a/jHpRWTARnr8cheZ6IZKAPG6i1hV5m4JPQ//y9mgpKQ +uvS1MWnqJkLEL1fR+jpuM55l6PKJWxRkRPce2XTdV0SPmIirYfL9gWmi7S1HRxmu +QKo/CRLON8JtE0G4m54K+3Dv9zKbC/up01IEDstoCdxMTNawoDZbEkLV4SnI7Odq +SPoVF/q0xIN8SLr+xvrkQnnHq30juEwaoiB7NR1joMm7Nk7hWAmZMlAfI5W5hq/V +GthzWw== +-----END CERTIFICATE----- diff --git a/test/opensslcrt/crlRevokedCert/client/client.csr b/test/opensslcrt/crlRevokedCert/client/client.csr new file mode 100644 index 0000000000000000000000000000000000000000..178f04b32d7fb0477a6b4d4da7c2ac551736fa90 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/client/client.csr @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBsjCCARsCAQAwcjELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkdEMQswCQYDVQQH +DAJTWjEMMAoGA1UECgwDQ09NMQwwCgYDVQQLDANOU1AxDDAKBgNVBAMMA0NsaTEf +MB0GCSqGSIb3DQEJARYQeW91cmVtYWlsQHFxLmNvbTCBnzANBgkqhkiG9w0BAQEF +AAOBjQAwgYkCgYEAoBRHl+VNL2AM/f9aFJfhNlPlPPlXpIr8Ba9Uu3C1i4gfbtg1 +DZNKUUz9vw83Ruit2d+wm1mwdZV74T//HNCSmWZQRcFSUX/X8tf7vse3xS4cL6K2 +LQ04PbN1GBMf81SMtlVAzsMDwi8WbSHpi6RY04Ufo5Z+U2IgGqM2RdJgub0CAwEA +AaAAMA0GCSqGSIb3DQEBCwUAA4GBAHnis+sXD1GxQGCI/NipXqFY/kB+lErQDE/c +fBXhA4eccbx3/6Q8AVmQKCADYnpbAnNxoFXpitZaOGDtqJwIYMolYfvqDJvVW7Dd +2f0GkbeFiwo2U3jH0wXswzG0IXu8B+frBgeESUlRGsthMkMKX/w1WaV+bq12XtBY +DM75APZn +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/crlRevokedCert/client/key.pem b/test/opensslcrt/crlRevokedCert/client/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..8671fd19c1f1137c65438aebca72b8d175409fc8 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/client/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAuusB0DBPmC2jVjE/xXE1DYLLBJKvE14Ad2XKx/+sIURu2qV2 +bNioOj7DOfOJPSGmCnexo941rSSAP2+Uay+vrT3ePY4jahYbscR440PVx0SQRWaZ +JDTBn2mD3NlcSEWkxNqxMdEyOLy96hoPJjnqOe3pBqMHqIvy2ZlasHl/b4ALvaN+ +v0C+AX7lKmwD85TEmFora1L6vVtJ6WQbFDJHs+SYKM9V8y0F3mGVrUAj3Bh94nSi +wriphDE9g1ODt3/GLsNzhScPAYpyp5UQxy1YCUP3L9+QuJLLV5MQAbol5I9DFpLS +EHEbNJodgrdD7VVbZfgJVdW7az97WLFOQIS+ewIDAQABAoIBAC+VPs9sVP7U82Uw +QV82a/6GAr+lxbrk7bJ3b9FNFVkcrimsAINB/kVhj9pTVGcDB11Xzhl8qmZqydkc +QwisSYkNHGPIKNzLRv6QmQOl06iYHDAGp8qPQZEez8orbTIaxJC2aXBYpHPMHH6Z +ZqzhBm770Rv26a6u/cIGJ4+QRnAwFS5FWV0Hi/UDEeShrdP2WpiJMpMdvOmgADfx +xFiPZ2UskgCn7JjJOxlnaL7OyXWtsPXcMLQ22ppjFcO0s3JmQspYI0R3HzCWcbiN +6n71/nhBGgNxyG7H9sHWr9g/g+TbWzaPCwKMkXP4iLTCo7SyExTrEj0QQHqSSdeH +v3ys4LECgYEA8Gc3utfB0Pix4gGJGrZ7/pl7azEFivtKw/rczyWW88nqgDfK74pH +A/x7cDB5f6yKelBjsN8g095lbBNduqXIfTVfNMnYz0BfRXJebFRPkpgO66gSdigg +bkoQ9wMYiWNrHlHwgecBJW0A6hdEMplpO5yMqX7f4fyeVSsfRdap6pECgYEAxwt4 +ZEd6H5AbtUSW+rylKMzvQodH81D2J/KQd9nF0zfre5uz7/B2TRdnuIXUJeihULD1 +xM7wQpMWnNELmGUYRF6FR7D3FQwvlVncl6c8u1WEV2CtAztg3sNLnToV8Ym4+/Xw +jb7gUfYP0Z31UexL8EuxIMl87xvlP9Y4cZfzpksCgYEApvr60bPlctIbNGFeqwhP +rfNLr5O4RuXu85t+3f89D46TbIXUVnfUhedMkguJiO3+raDC7eq98YvLJLt64hce +4CE/RPpitun+gs1C2rnhH3CxBGyUji+m/xRBHmGvHFbK45+Ej6Y1vJovBWotu3ul +hXtilKSEu4JtM7klGqxVD5ECgYBaGkwmgOC8lvykAuehm3HcVWDdqeWdNIIptTlG +t9NAeSjlLKe4rw6ueTDfq91MR9F2eJTYwaCrUnBDDrKECweAM5V7zZUb9wc3QZWF +JRRt7RlagI3xT5jbXmbFRr0WOPT/QebiE4zyLUQd7LZXnnMF1Q3avLsMdoFBc5Tb +fyCxnwKBgQDeR/EU4pAFEyT2Y3Usqwpcq+BpENXke0bghl4O13g9GiQjhvz/jph1 +OxkTSVe9IZDUTk8r5F9FJNEEYBIdujBDSeRirIk4Z2T8j9FhgqhKYO+i6bwgv2E6 +BDw8QQPF1xx066vhcHDVNbAVEiNw3sREvi4hrOtYtrYRrnu+9E+orA== +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/crlRevokedCert/demoCA/crlnumber b/test/opensslcrt/crlRevokedCert/demoCA/crlnumber new file mode 100644 index 0000000000000000000000000000000000000000..64969239d5f72d674bbedc24eb0a155a59d0e607 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/demoCA/crlnumber @@ -0,0 +1 @@ +04 diff --git a/test/opensslcrt/crlRevokedCert/demoCA/crlnumber.old b/test/opensslcrt/crlRevokedCert/demoCA/crlnumber.old new file mode 100644 index 0000000000000000000000000000000000000000..75016ea3625245b1aac79cc5586c3f33ce8b7c78 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/demoCA/crlnumber.old @@ -0,0 +1 @@ +03 diff --git a/test/opensslcrt/crlRevokedCert/demoCA/index.txt b/test/opensslcrt/crlRevokedCert/demoCA/index.txt new file mode 100644 index 0000000000000000000000000000000000000000..60bccd610dfc0e96acc64c132f3816b4c78c283e --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/demoCA/index.txt @@ -0,0 +1,2 @@ +R 321007083201Z 221010084405Z 4D6E3E9110C5EBC3BF0FCE8D5C34A714B72F963D unknown /C=CN/ST=GD/L=SZ/O=COM/OU=NSP/CN=Cli/emailAddress=youremail@qq.com +R 321007083151Z 221011113659Z 4D6E3E9110C5EBC3BF0FCE8D5C34A714B72F963C unknown /C=CN/ST=GD/L=SZ/O=COM/OU=NSP/CN=SER/emailAddress=youremail@qq.com diff --git a/test/opensslcrt/crlRevokedCert/demoCA/index.txt.attr b/test/opensslcrt/crlRevokedCert/demoCA/index.txt.attr new file mode 100644 index 0000000000000000000000000000000000000000..8f7e63a3475ce82ed03dba035f5c01a42ca38c65 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/demoCA/index.txt.attr @@ -0,0 +1 @@ +unique_subject = yes diff --git a/test/opensslcrt/crlRevokedCert/demoCA/index.txt.attr.old b/test/opensslcrt/crlRevokedCert/demoCA/index.txt.attr.old new file mode 100644 index 0000000000000000000000000000000000000000..8f7e63a3475ce82ed03dba035f5c01a42ca38c65 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/demoCA/index.txt.attr.old @@ -0,0 +1 @@ +unique_subject = yes diff --git a/test/opensslcrt/crlRevokedCert/demoCA/index.txt.old b/test/opensslcrt/crlRevokedCert/demoCA/index.txt.old new file mode 100644 index 0000000000000000000000000000000000000000..21c93253eccfbb77ec555b4bd0f2d00a2f2e95e3 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/demoCA/index.txt.old @@ -0,0 +1 @@ +R 321007083201Z 221010084405Z 4D6E3E9110C5EBC3BF0FCE8D5C34A714B72F963D unknown /C=CN/ST=GD/L=SZ/O=COM/OU=NSP/CN=Cli/emailAddress=youremail@qq.com diff --git a/test/opensslcrt/crlRevokedCert/server/cert.pem b/test/opensslcrt/crlRevokedCert/server/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..956ef5158f072a005e908593ba676943dd700d8a --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/server/cert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDADCCAeigAwIBAgIIbYHq/AAt2m0wDQYJKoZIhvcNAQELBQAwFzEVMBMGA1UE +AxMMY3JsUmV2b2tlZENBMB4XDTI0MTIwNjAzMzIwMFoXDTM0MTIwNjAzMjcwMFow +FzEVMBMGA1UEAxMMY3JsUmV2b2tlZENBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAyHQ4O9J4n2mew+H4MHNRCBmsjlE54VQ+BgpoEavf99boNREJ8vWt +rCadtUrNzg35XylkCUzn1YzkfOC8arhXbg3UeyvFqg0nu3WXjyZ9IgNkcc4bzdIx +YpKVzDr6XrgHQ1YPy38osomfEkJL5BUx6b/e+JEOnfiBlt//uZUhg/9QvgnYEywa +ZSK+cw3bpUVZms27VHpigLLzhnm+RST58M8196szW5D/rk+Uq2RNLKwlVMwnkPka +K/aFZTH6Fj/jlcf/E/o6JVhuy6uV1eeFlJclYTkTs4A6jg7CpckHb8OyL7inuDVN +jvQHEvORUR11hNLBOW6/AcvQAjBKzY/6LwIDAQABo1AwTjAMBgNVHRMBAf8EAjAA +MB0GA1UdDgQWBBRxawkWYSBiWfGW4DFxT2qNfbQdgTAfBgNVHSMEGDAWgBSwBN0p +JZOxSzO1w60tJY5g2hRNxTANBgkqhkiG9w0BAQsFAAOCAQEAV6C2pq1jpKgfN634 +rM/guKw4hmHCxZa/3y2XG9wtwTHSmTPNbgZw8S8fq7qPGciHMRd1/2op4sK4s/d8 +HiJsc1028fz5tADV8BVNhWcM9W2JtQtFKPO/VGaPjFfZzGv1P+MkhT3mv653rj7B +6YbHC9V6/ewFjash7D38bQ1xQYkpgfI64FZWrjJluYZEYzpDkan4MqcbeBCKgUAq +Gh0YMd80dEEzv++MsZhysxWhs0wrM0NAyHoKnb5FxaPcTweBvutl+W0j5WBFrsxx +KtmllV3wfiWF+6+iJcrbMsqOAXFqnPvl9It86ylRZ5v2SSZlzbVtN7QywQ2uNLsz +zos25g== +-----END CERTIFICATE----- diff --git a/test/opensslcrt/crlRevokedCert/server/key.pem b/test/opensslcrt/crlRevokedCert/server/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..2c9f5fe96c79a6e3a35ec7b8e3329f5c6dfcf735 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/server/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAyHQ4O9J4n2mew+H4MHNRCBmsjlE54VQ+BgpoEavf99boNREJ +8vWtrCadtUrNzg35XylkCUzn1YzkfOC8arhXbg3UeyvFqg0nu3WXjyZ9IgNkcc4b +zdIxYpKVzDr6XrgHQ1YPy38osomfEkJL5BUx6b/e+JEOnfiBlt//uZUhg/9QvgnY +EywaZSK+cw3bpUVZms27VHpigLLzhnm+RST58M8196szW5D/rk+Uq2RNLKwlVMwn +kPkaK/aFZTH6Fj/jlcf/E/o6JVhuy6uV1eeFlJclYTkTs4A6jg7CpckHb8OyL7in +uDVNjvQHEvORUR11hNLBOW6/AcvQAjBKzY/6LwIDAQABAoIBAD+9kQoYval3NPh1 +uO1zJc+tFk007/JCTuytz3TbEf9ls3X4YpEC5URKDTkSPcopcGP9reU1Mr0Zrag1 +8EQBNS/2LtwxyYtu3nUCCf0HkJ/Js11q3DWBApolIYQNNwL6gri7O4+Kcj7/evkb +4rT2tmXtULH6skqyNEbtGIyDKuKlM12qJD+Gsa3NXbWynDARpGYPuvkfRdzOczUR +deradJb+HH0IYM5Fqml1g/pfhIWzQADQCjrH2bRTdV+ZAC08teALQFI3atnCNE6g +X9WRDpYdezl54q3yEX9aRYrwPQu0rZqZbR1uhUzR4qc30HRanzNOXy+07ouDVlkW +CDRz5NkCgYEA7osGbcoxzkQCKGQo5L57uo1Ugii6sXazrkswGYJBnlesWFGWg95c +qbFTZrAgBLYatP5/1fTB1wA3S2yambUOHinu5jXfe5PtXjCP1G/xI4/QrQcDZJrl +P6VRoElUSTczWz7n9tUlPwmXOImwq9LQK0VVPJQZhfoefdU5Exp4krkCgYEA1x+d +wcppolp4W5iatPaGpCXsiUXcwa1gFugzR7QP/NMwi3ClPNMyAXhjd6I/hi7qjaDe +rJclyCvpK0DyoKRz4/BL5v5UXpm1oVFZ0pwCQo5okHRveuCapJs3C6g10p6swwcD +PqAtyDN5AYCzEzz/TueYRHcX1a/gQPEOa05uoCcCgYATwraMn8RSdvXKzMlKcbEk +OhL3GVKl4tRtqtLYOh8fc9nWEyQp1mDMueDTz+FHIqLBvZdvbPAl1GHAQMXp18LM +ouKkdXAG97EhLVqs2X0bWg2KV9tjGSXGiPZjFdACpKz/cQ2cN/n25998EoPtJ/CQ +tRSjiQ04OQkPmXs55uZ4GQKBgQC0LnsIEe2gWn5s2cjawZHMcydHYwigAAN3rE0n +RZ4OeSUz5cnAxHQh3yAQ2Ai07x7uOUvI62Bt5LmLzP5rptKanG42r9ci1UPYNjcx +vBH4hSb3t06YcP+V15unW+CY5OZ3A6yoC6nNAa/cnltfRbvh7wEOd2GzTXhbbIxT +PIJUCQKBgFWEsb3npVGQqgXvaEd2dvvtRfJ5J1MWWYVoJQhwaMJSTYRRgO8T/z/x +D3GwZ4BnYxr6dX/LJratDOxT6J7GE4Z68S1i3tyxEuL7Fh9EQ6x1KgksNYsiq/Vm +MmB/MPGdMmtBveMo+hFiDtrcOysqqhce0sIdvczebqN5M1SoEUX1 +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/crlRevokedCert/server/server.csr b/test/opensslcrt/crlRevokedCert/server/server.csr new file mode 100644 index 0000000000000000000000000000000000000000..2f975e3b0ddee9079e63a41ef0f3d6e6a817ef01 --- /dev/null +++ b/test/opensslcrt/crlRevokedCert/server/server.csr @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBsjCCARsCAQAwcjELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkdEMQswCQYDVQQH +DAJTWjEMMAoGA1UECgwDQ09NMQwwCgYDVQQLDANOU1AxDDAKBgNVBAMMA1NFUjEf +MB0GCSqGSIb3DQEJARYQeW91cmVtYWlsQHFxLmNvbTCBnzANBgkqhkiG9w0BAQEF +AAOBjQAwgYkCgYEAvyjgp/jHDcgapAgOd9PpBZ4DBCYg93lQDWljJfBBwhJazPio +I8WdcO6eTIRNEByi7bF4QLKLYnitKWM1OAkGWB9nCVBcwRjyGk3luXZWmi5n44aD +ngKf+v9ivn4TDboVNmd5STDN0VhH9VJL4hkgQd4G92vs+ursEnXwr7m8V5ECAwEA +AaAAMA0GCSqGSIb3DQEBCwUAA4GBAI8TUA5qIswYxS61IXe013agnG1T7KP9+8px +n7q4COPwOz2W3o2WqzY/T895OQDgMKCZIpoKbLTkTImtiqe3FNXPbxfk45/H2VD1 +8Zk7Gy+rnXZjmbIHNvNDKN+JfqQcoPVomy/Zsnp7fElReeIqfJdKT36eF8IIC5P5 +UswdTrwC +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/expiredCert/CA/ca.csr b/test/opensslcrt/expiredCert/CA/ca.csr new file mode 100644 index 0000000000000000000000000000000000000000..b64effae9b89a5e70b378ec0c30cd9f7b62b9881 --- /dev/null +++ b/test/opensslcrt/expiredCert/CA/ca.csr @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBsTCCARoCAQAwcTELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkdEMQswCQYDVQQH +DAJTWjEMMAoGA1UECgwDQ09NMQwwCgYDVQQLDANOU1AxCzAJBgNVBAMMAkNBMR8w +HQYJKoZIhvcNAQkBFhB5b3VyZW1haWxAcXEuY29tMIGfMA0GCSqGSIb3DQEBAQUA +A4GNADCBiQKBgQDb6iEU/m0GsmkxY4zHpodIqa59ew3NdwRdEjQ5M6fgI1dLWfQT +aYmOiZt0zia610ujxiSyLdoO4n6cQO57801AuzBy4L66WUHf+1wvHSvDG0kUOlfU +HukOeg+cF/hO6fzQPle/If0ZZF84DVZBBohwoUYRsMZs+2vgj7l4+nDocwIDAQAB +oAAwDQYJKoZIhvcNAQELBQADgYEAGLx2JYBYujabQZOy4rr1zxZ38Cgjjn1CTHyV +T2gO3iHrtsKmBnwveDQdUnrY22rAF0BCmioOMnImzEtY5p6J9gp3wmkBlLreqD+X +Fy+XBkKCKJcV/761dvOFV/Ar/K6QCsLQVeV6hAwTIUqfdsIN97aog90sTy16YovG +x+QAwO8= +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/expiredCert/CA/cacert.pem b/test/opensslcrt/expiredCert/CA/cacert.pem new file mode 100644 index 0000000000000000000000000000000000000000..63dd79d09b75330db3ba48f18afadf888556193b --- /dev/null +++ b/test/opensslcrt/expiredCert/CA/cacert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDLDCCAhSgAwIBAgIIJG5G7f8Y0l4wDQYJKoZIhvcNAQELBQAwFDESMBAGA1UE +AxMJZXhwaXJlZENBMB4XDTE0MTIwNTAzMTMwMFoXDTI0MTIwNTAzMTMwMFowFDES +MBAGA1UEAxMJZXhwaXJlZENBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAs3dSFAq3G2n2upBQuJlDyUJUk9N6OxCLZrCsbnyBVYHavxSGqDXHu+u4m019 +cX9S5Y3MoV+OGFEvdZIWWjjYdgidVb+kQGJH6QUJhfLTKFS6wHoq04YOy2uIBbPL +TdIcvYcJqynBlbvGkJ5kV0PUAXi35sCqZG2qF4xA32U1tj3+UN8lJCa64eksZiYR +iBodZd2yryahELBE6XXiv1r2Cl8dDU266nX37iyoQBoLaEEdOvNr8QrXLUgT76ps ++ZIZoRNicL90GDOWsZEjbBIiCafLU+hr2sIHhe5kTVWIKD1M+C8pS/iX6NVYUpdA +Km4ldNEXoL+IetKdlTM3/4FCVQIDAQABo4GBMH8wDwYDVR0TAQH/BAUwAwEB/zAd +BgNVHQ4EFgQUXlW0PLV6+JzYFjJUqAVn4pqZA0swLQYDVR0jBCYwJKEYpBYwFDES +MBAGA1UEAxMJZXhwaXJlZENBgggkbkbt/xjSXjAeBglghkgBhvhCAQ0EERYPeGNh +IGNlcnRpZmljYXRlMA0GCSqGSIb3DQEBCwUAA4IBAQAbLlpifZhya7HNI+MpvT5+ +/136AGflP/ZkKvilQAgoIZDkw0ceko3RfmbUSd9Mo9A+O6ExTLZnHtwgiBxZ1N1i ++8e+4rxwcrcQl5ErMWAEBuFV4igrUTkP/n0gxZCroDehcVabrf0/FpF4i2bv2Wny +kCkw/ukvZhda0pem1ErImtxpzfOVVweIzxl0t3MXDpTGHPQKubCL67A8khjHdnL3 +kzmTnZBQlaj9WQQdw0Pv3IfT3NHrTngF4WwSijahPW+HmHJ6lSoRC9jpOARgAhF4 +G3+ABjkDlbML9orATqny9lUyP3VAhI99sMl3il92w51J1ZtViBYav6VIAtHCNC1j +-----END CERTIFICATE----- diff --git a/test/opensslcrt/expiredCert/CA/cacert.srl b/test/opensslcrt/expiredCert/CA/cacert.srl new file mode 100644 index 0000000000000000000000000000000000000000..8c83420fc7fd75ce3d731ef5cba41300acb155b1 --- /dev/null +++ b/test/opensslcrt/expiredCert/CA/cacert.srl @@ -0,0 +1 @@ +6E3D22A8A1E23DCBABC443BE17CF99467D580D36 diff --git a/test/opensslcrt/expiredCert/CA/cakey.pem b/test/opensslcrt/expiredCert/CA/cakey.pem new file mode 100644 index 0000000000000000000000000000000000000000..499d278ac5d3ec72b65000d6dd2bde6c5f86d86e --- /dev/null +++ b/test/opensslcrt/expiredCert/CA/cakey.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAs3dSFAq3G2n2upBQuJlDyUJUk9N6OxCLZrCsbnyBVYHavxSG +qDXHu+u4m019cX9S5Y3MoV+OGFEvdZIWWjjYdgidVb+kQGJH6QUJhfLTKFS6wHoq +04YOy2uIBbPLTdIcvYcJqynBlbvGkJ5kV0PUAXi35sCqZG2qF4xA32U1tj3+UN8l +JCa64eksZiYRiBodZd2yryahELBE6XXiv1r2Cl8dDU266nX37iyoQBoLaEEdOvNr +8QrXLUgT76ps+ZIZoRNicL90GDOWsZEjbBIiCafLU+hr2sIHhe5kTVWIKD1M+C8p +S/iX6NVYUpdAKm4ldNEXoL+IetKdlTM3/4FCVQIDAQABAoIBAAkKY1IF+dQ8Sd+l +6kxo+3NIRN/0g4q1qjwS/YC3A2yQtSRdOZeq1o620MQmtWwcrKBKkiAf+EBtLeLy +PW+FP7xEcf6SP/PhWjuDaN0nTsVQQrR0WF4dpExRigGK3Yc5pNUprWOYCnvglh+4 +vEJxdNzacrMsX05eADdIS11QPnzZZrTFGpLYuv7QkEpyrbHqwpF2ZKPgMzQUR5ja +4Ii+0yNxJtNhGxHyjCWYu8tu1R/H3vGCgYcwN/5MkCTSXkK9+jEoIbw7Nry8RGP2 +6lMzIB6t06q40Wllf7WVFbjVCXWcrgBjH9x4CtEKiX6hQ9c+I4elGmBnv8xCblbt +hXqM25kCgYEA8ghlYIO23hySjT3AxxUTdCzHEMsj6EX5pIVq9dRBeLLa6djb3VGQ +UZSbD+tQoTuTt/MQbL7jRMrMkjEk07wypfFst/NZIeWGrVMpeNgn/cfEkX9riqTq +UAGsryC7kb/Qh70ib7X+10frwuj+CvRqRNToiXQzt33B1pBsnyXpizkCgYEAvdKc +6C1kXbUiRbl+MLRLXwr3cH6hObXcuz+yV39YimgOQa1TbXrJ3o6XWRu2TXSkwhN6 +d3f2YgQlzQwBV/C/uDKLIJxN53OmM2xaD56gMxH6GZaFBjURVETWIFGJarq9BKEe +PkKeU776WxtjweyT2khuIi67wrhVDWY+mEaZA/0CgYEAvwYrZSJv2Rv5+unrVfVn +dkDFGJcMDw3ze4sloUJWLjoTl49l8GaongrI26ag67BG5V887mc2npvlG9kXtNmL +q/dBFGpHRf6O7gt95vfLkHvzw9YqfbGtkVnQ/iO6hs06e3emTz74FNeZ1yrnHqRL +n2ne4mXJT6A8teklIfLdQdkCgYEAokfYVLXZxDMCpdE7DLBWb16Uw9u+QfQ0TPDb +qun36/uhlIx1cncy9c25njiO1hEOeczlOhWY0ZryJiZJT8FGZlbvkmWORH0ebYAI +TqoyEvcfdZi1COV9dymSVfbfIBudVMPZcyjI6peuuQzlb/itPkQFw7toUcP9EoOC +p/RTBJECgYEA0158GrdEaBQzJMQjDFwJ02cUXLrGZe13dVOqLnrf5tt5dPnfI733 +H3gtqXglJByuSCVEDyDP5SZkz+/UsFCkFk7ZRhz0jz7cLT8/+gQP/vQzyxUkikRE +1fKWOmYI+v1myopA5Vmoqg19FeMwjdSd9MxGG+F+PzvTRWm8hqBxNf8= +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/expiredCert/client/cert.pem b/test/opensslcrt/expiredCert/client/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..73699d9740be4bb2df95d7f487239fb35fcfa417 --- /dev/null +++ b/test/opensslcrt/expiredCert/client/cert.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIIENOftxjknyUwDQYJKoZIhvcNAQELBQAwFDESMBAGA1UE +AxMJZXhwaXJlZENBMB4XDTE0MTIwNTAzMTQwMFoXDTI0MTIwNTAzMTMwMFowFDES +MBAGA1UEAxMJZXhwaXJlZENBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAwIAbmttZ5Kz69+Ngh3Rxi53hC6zecOyLsfjKHUPCwjkT8yJ3iycTVC2DUfaG +0KK2rXjylfiDlR4U6gxMmEsG6aSSfw7fs9FCHZXcPqyXFB+GmO9XWNQ3VsARfOXx +/HdAutgVB2WTShM30Hhej5zytaibzx/4is9YDDmO/0GNjA+DmyzYAPo/PKNC7ABE +mq9MUyIzkTlH8++PHOZTDy4zpoy3MARoKxf/7XDk+2IOFRcyMtNGIC+u5RKGgKIz +CF9LUW0dAu1nKw0JKmJ+XZlHo+/TWmZDqAbrdLgXtrvpVy73TyfHv8tOMGhp6gWu +LOBERpwIJ95aVMz0eaBv4168KwIDAQABo1AwTjAMBgNVHRMBAf8EAjAAMB0GA1Ud +DgQWBBT6u50QH6/zw1wkBPUyompO8WYURzAfBgNVHSMEGDAWgBReVbQ8tXr4nNgW +MlSoBWfimpkDSzANBgkqhkiG9w0BAQsFAAOCAQEAn0AXIhpCWorNH3zDeMmvpSb8 +IHH/kSfkp6ym/QPepsXaSrXZnhg4ZdZGRdWHCtVokr0fhOG2jaqIhNNpzm1jSzhB +xzOtQ8jg17KZNgN2t6HPF9iHFChm49CuqIvPBiBNleBUPzpyrCMuXZB5oBErx1Xy +wMQOXedo5qoE6IVVuwicwjn5S9pF9P3j2Ks9JijSqx0HOzhnVegO0rpsW7iGeatU +4d7hilGybnhRpkQONNktbl1+xVEvoPTukHYEnudyARHYqR35CghX9CM4XumQJgm6 +LBaXAt/UzyXsO3QrQYr+5JbA57TGsKLwaqWsLPy+WQjcZhS9j8AnPLdbtU38kQ== +-----END CERTIFICATE----- diff --git a/test/opensslcrt/expiredCert/client/client.csr b/test/opensslcrt/expiredCert/client/client.csr new file mode 100644 index 0000000000000000000000000000000000000000..c3142c6a8124b52bd57a8f5f7ca3758489b6fa09 --- /dev/null +++ b/test/opensslcrt/expiredCert/client/client.csr @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBsjCCARsCAQAwcjELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkdEMQswCQYDVQQH +DAJTWjEMMAoGA1UECgwDQ09NMQwwCgYDVQQLDANOU1AxDDAKBgNVBAMMA0NsaTEf +MB0GCSqGSIb3DQEJARYQeW91cmVtYWlsQHFxLmNvbTCBnzANBgkqhkiG9w0BAQEF +AAOBjQAwgYkCgYEAxoL/9c0v2WnaZ9Km8/4/r5faQMQ6z8ecAN/imXUjEbMBJWbx +Z6AiQBiG2txCx1jUUYOSBLXm24wtVEMFfcEm40vdMT4pTcmHZ2oYgJO3NS6Y/f0y +KEzw8iQ6SbT4IExs0AmzNCnH0kTVNUNoA0vf2lQzK+NhWU7sQ9TXg4dg0IMCAwEA +AaAAMA0GCSqGSIb3DQEBCwUAA4GBAGQ5tpuAkhbmntWmc3Nc93qjgTyxLxTEz6yf +nStiPsOw1WXT/SKvgT/4KK54bekf3+wWFXPLtwazg0RvfaKIKHs03F/oy9ESwhVU +PkaLVIMy8v40UHYK9uM4SdyZGDOp2XDTXAFdIF2BDHwy4ckqos2SIJrZvLh1AoBX +MdjfR1LH +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/expiredCert/client/key.pem b/test/opensslcrt/expiredCert/client/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..682c5a7b44b9fe4b9878e58a7a4d7dfeb25b28c2 --- /dev/null +++ b/test/opensslcrt/expiredCert/client/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAwIAbmttZ5Kz69+Ngh3Rxi53hC6zecOyLsfjKHUPCwjkT8yJ3 +iycTVC2DUfaG0KK2rXjylfiDlR4U6gxMmEsG6aSSfw7fs9FCHZXcPqyXFB+GmO9X +WNQ3VsARfOXx/HdAutgVB2WTShM30Hhej5zytaibzx/4is9YDDmO/0GNjA+DmyzY +APo/PKNC7ABEmq9MUyIzkTlH8++PHOZTDy4zpoy3MARoKxf/7XDk+2IOFRcyMtNG +IC+u5RKGgKIzCF9LUW0dAu1nKw0JKmJ+XZlHo+/TWmZDqAbrdLgXtrvpVy73TyfH +v8tOMGhp6gWuLOBERpwIJ95aVMz0eaBv4168KwIDAQABAoIBAAj1bVN1pedYjY1/ +zXxAVRB0VobekrOarp2nwHBAOQ4k7eLA1eNp/CJMw/HKkVegWvQTzuZf3G/SKJMD +44Sq6TVybUbwgES0FmmeEdPw4E8VcIQpj10Y80JcSfQQF4TyX9bSO5hOh95Iqq5P +C8ePPFRr9mRbTovXPnfDLf5/JIFCGszoR7H2dtIUrpUHdtt3cgrVAY0BJdak6ryt +JFy86AcZ7SN26wpkaNU5qDp8sXpt5uphoLbxpYU0reByO9c7lQIaFeecQeG5uems +VNgguZSUS/FWoPDfkLq5iwRaLGF0K6xfIjM/u02CE3TVlIaLJPZIx3jg1uoqGpe6 +TljqeiECgYEA9qBFsTELPTgBDG+ZmLG1uRoazZRh8q4JCPKrZPNrIxAAyxVbHI8A +ez+81Bd9xYgEcvnshAc+FLKcQgisCr6yTn8Hva/KybM1z6/aRGlnLmgw4bHEmcxq +7uOGEairOTABQSjtnSejwlKo+DAKRWQFiIV9nb9/V8JgzsjXMQ04uakCgYEAx9Eu +RwXOYCWT42pEb5EnccsnmNndpDBZkPFVMsbhuHTPDLjJ/MTuHqUZ1yGXEQ9kN+Vs +yx1HLS2+sikNzYP5g/uZlDqc3cNqI9+8jttjJAkPTzE6kY8ExSVmlplL4mCLHq+k +4gr7hLYu1AVvoNenky2VdIIyX9Khd0n5lduec7MCgYEAxj1IJ8jbVucYeK2QEhvu +jY49MzhtjwtzebzOJkQ/vxxS9usApLER2v/9waHsWAYgRWc9RVcpAKwdTDr6R1zH +qK8VQtT2NTEXNQ0ObmUnKjdX62LgPBwDbGh40OL6VuqOZ5kFfohpan7VEUJUTzi5 +9eYPdeiC7MDy3eS0jNoq2ZkCgYAGX0tkhHDibBBdTkREcpKZdGsc2sXQuKICl0+q +QOFt6nu34iS+5ODbJVS/oZiZuK3vgmeHyrU8YZsVP23rQewxI1LwgTYDdHnsDvSK +ccClo4xTcDR38+GpD9pHrzfWTlHhdqSBeOwwfUbdCBdZP8deUDPV9Vj56VOw5DEL +cGeLNQKBgH3S5IvBmbv2tNbgcln+Qlsqvkw4lbT3cnmrpOFcvx52ZB6kt1TVC46c +DOu/Bi2pl26DyGf65S+X7WARp6ZxleZCFPxEELh9MXgcEWTUrG7si4kQakeihBIO +GiLaOxnIbGZsGPWl5xpiVQ59P035pYMTcL/jUOWPLb6gaBy9TCQq +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/expiredCert/server/cert.pem b/test/opensslcrt/expiredCert/server/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..b0854c8c84195baeb4f75567a7db4f37d0be39b4 --- /dev/null +++ b/test/opensslcrt/expiredCert/server/cert.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIIFnXaZZ5c/DkwDQYJKoZIhvcNAQELBQAwFDESMBAGA1UE +AxMJZXhwaXJlZENBMB4XDTE0MTIwNTAzMTgwMFoXDTI0MTIwNTAzMTMwMFowFDES +MBAGA1UEAxMJZXhwaXJlZENBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAlrsVjvPYZ0WqorIUHf8P8qRFeda5xqgVBLuMOUgzyTDCik9Pu2yK4z57ny0+ +KDRLV5hRT/Lfvs1+Q2kdskaPlBiUdjzbFuUVBt4sw0eJ7pdAsy6T0ZzeGKi/wG3Q +IFimkYzKN92B8PVP0j3hE3ixAN3yr8MUKtYDcXLqGfKv9Qn1e9bf2PiRrgWQDrqo +M7ieZcP8uzlJmTROYy5JiPeqEkLGkzNUDxszjXCSW+gjX/dE4b2rDKHIlNEt+TGV +qPXOIooVpjcmnp9UslpRZhlU6QmxN4InJg+3KOpPcAqkYsWOOWyVTimdUg5EUcMG +TK4FuDypTMK6G2jFcjx750+2tQIDAQABo1AwTjAMBgNVHRMBAf8EAjAAMB0GA1Ud +DgQWBBSExGX00Zp01xoll2pr5k3DkUpC/jAfBgNVHSMEGDAWgBReVbQ8tXr4nNgW +MlSoBWfimpkDSzANBgkqhkiG9w0BAQsFAAOCAQEABYzDubSKucCyJPqGJGska+G7 +5bupUUAvCR7NixoeVAZp3DIG5YwNAFAEwB46qzziAoycwf0sS6BEI9neoX/XmeO0 +RtolM5wO7lgEPwsxRVC42T2LgCNrbjPx1xsyt9RI55WoeTvyiz5P1Bonoxj6MmbC +ylN21ER73YQsPidIJFrj0r82hcCYEudPPvFQxd+ZhJiHSlobXXaLwkUX8VWayp3M +neCiQR1la3kk5RJ7v7gEeNktV0jiocm9B2hwMvEkepy1XC1q9Qv9iu28VkCTliJH +YlioEw2cwhqoqAHyDtKTVVeGFxNmNW4NAL3/52YOcgK8SsjqEaUV/NUqnnUWzQ== +-----END CERTIFICATE----- diff --git a/test/opensslcrt/expiredCert/server/key.pem b/test/opensslcrt/expiredCert/server/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..62123dd4dff5b7ee31ad911c0368b8653f22d9b9 --- /dev/null +++ b/test/opensslcrt/expiredCert/server/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEoQIBAAKCAQEAlrsVjvPYZ0WqorIUHf8P8qRFeda5xqgVBLuMOUgzyTDCik9P +u2yK4z57ny0+KDRLV5hRT/Lfvs1+Q2kdskaPlBiUdjzbFuUVBt4sw0eJ7pdAsy6T +0ZzeGKi/wG3QIFimkYzKN92B8PVP0j3hE3ixAN3yr8MUKtYDcXLqGfKv9Qn1e9bf +2PiRrgWQDrqoM7ieZcP8uzlJmTROYy5JiPeqEkLGkzNUDxszjXCSW+gjX/dE4b2r +DKHIlNEt+TGVqPXOIooVpjcmnp9UslpRZhlU6QmxN4InJg+3KOpPcAqkYsWOOWyV +TimdUg5EUcMGTK4FuDypTMK6G2jFcjx750+2tQIDAQABAoH/JwNwD5/GuhwFijYk +OpP76S49we74bd1xTv++97mU7GMoizUVqF2et11lLUygt3p2lAPJE2+G0kPM4Jea +MyK4II08JEeqUP/goylHosxUpFoppO0x4NIkEJu/Be7SfKlCOt6xwncihIQtUPUA +BxnREaWIskJKdJi0cGunigdFSBzqSRbffnL8pKz3JdK4HBAiMZLW0f5hWMWJK3dG +M+NQuuBwlh8SQ7S0vh+Qe8mwimGMmsipNmeA4Wo+80HfnxRnS1zEkUBsuDo2598l +n0mFWxch4KYArewUkqXB26Cm1VHoqfotHMFXsFxgBXvkH0QwdC1b8ycAWoXfBcXc +5EzpAoGBAMvTuPuNmJungtaGzqMIgDgFafPpwrN0mpNOM3zV/lfiK5+cqr84Pbmd +Mzu2TxqBc96lw3/rvKZIh/kMeVuzcp3zUGG7PeADQJTZfKbQ2Fhemor5I6hL8+Nu +fXIV0yfUqeCch6PxTppp+AHGPojuK6C82zx0zpB7Af4ELNcqxxztAoGBAL1QGOaR +3MiHnEwwdTdrCXEVIs9XDnwD8lYxHLbXZN8Sb75BGVNSXt4NOzqsppihzEQOAKeX +H6A4jb8HaIZo6VMZW1Uap9RWLxM2sfoQ53yvkOo5zqKW1Kh0aOKGmTIvC/Q/jv8N +1SeGb0J2mSds9/j7xInHqTo883yJ0LmK5o/pAoGBAI1nctvBXfAOhCyEFXLxgNJc +nybCM3sAGAS9qeafJvadR5pRu/sw71GIB3UTg0lmKZZ69WgXiSbrBrn2t3KwyYFe +vZMvrTttxi39vAaWuKCF8T0cnmoxVx+fFZRCI91sIfZbYZaQ3/EsNww0Fko0wTug +CZClkHCXhchN6TXUzZH5AoGAXGYisS7SNuHRjHI+U90fT51ETzSrciYu+pif7jH1 +HteNyKtXZA6ZFQIaPYoVCfw2iaTX9vRQ0E+qB2njP28nKpL+u+v1rDKgaV1Rwr7L +bVT3gwrR6xN5GfsvUhjl3tONnxoCfkRPHCqGVUcze0W5RkID7EeSNkWXhdpyEI1V ++TkCgYBSTRfMipMoifMERGyuwuc0SNsXW+VE8+KgMsL5zh3/2AQ3M+DNghOg260D +4RZCzYbcNbwwLQD/Lv+3r9t4yQpC5Tp+fHmN17VhTSZtv0Fl+0QjgrygEhgXYQHn ++ZwLnciTDAreYaasogD6C/mBtjLLpXp2C0dKaPF10HdE4OPS4A== +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/expiredCert/server/server.csr b/test/opensslcrt/expiredCert/server/server.csr new file mode 100644 index 0000000000000000000000000000000000000000..489db0a65ee52c2f6b2df4a5aed4dc0b974aae6f --- /dev/null +++ b/test/opensslcrt/expiredCert/server/server.csr @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBsjCCARsCAQAwcjELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkdEMQswCQYDVQQH +DAJTWjEMMAoGA1UECgwDQ09NMQwwCgYDVQQLDANOU1AxDDAKBgNVBAMMA1NFUjEf +MB0GCSqGSIb3DQEJARYQeW91cmVtYWlsQHFxLmNvbTCBnzANBgkqhkiG9w0BAQEF +AAOBjQAwgYkCgYEArkUtnn2ceKOOtXQwxi+x/KTW0ZSKBQP48sPaKpBXLQLN6WdT +dKCz96QG2kzrJ5ujtZYrNyekkEirXCRcoGFRD/6soYOeEOho3EUTwhNbdB1+fGSj +SiICxv4TYcSBgbYwPKjeBRCmri/aTC5/hNo6+aZIPrd8ByJ21doClORkhm0CAwEA +AaAAMA0GCSqGSIb3DQEBCwUAA4GBAFi38a1K0vL6/CBkGT5auHD05j0AVbwQ9GR0 +I6ZZIAe1laqpNBGXHQwzutab0XoErjGAVmwxkdq5L+w8t7uEt2weQv+YN6xfWqc+ +3wP/h//gjU2NF9zK9SpNcCG0e4E0+3jfCy3PKZ6lGgTCTweO0EiHYqyFdBnJwdyf +UNzQVvIt +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/multiLevelCert/CA/cacert.pem b/test/opensslcrt/multiLevelCert/CA/cacert.pem new file mode 100644 index 0000000000000000000000000000000000000000..8e52b7a68dd73f59d2eaf00709e3d409fd2d1226 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/CA/cacert.pem @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID0TCCArmgAwIBAgIURA/lUFrZwOfJXaaLF9ol7wkhAxIwDQYJKoZIhvcNAQEL +BQAweDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQHDAJOSjEPMA0G +A1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxDzANBgNVBAMMBlJvb3RDQTEfMB0G +CSqGSIb3DQEJARYQcm9vdGNhQHdvcmxkLmNvbTAeFw0yMjEwMTIwOTM3MjFaFw0z +MjEwMDkwOTM3MjFaMHgxCzAJBgNVBAYTAkNOMQswCQYDVQQIDAJKUzELMAkGA1UE +BwwCTkoxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MQ8wDQYDVQQDDAZS +b290Q0ExHzAdBgkqhkiG9w0BCQEWEHJvb3RjYUB3b3JsZC5jb20wggEiMA0GCSqG +SIb3DQEBAQUAA4IBDwAwggEKAoIBAQCgvTn1zACCZ4uny3BW8Utilwoztkhb/XM+ +ZI/trCZg1smsnuCNJHyIJVoFz4PoxXESCueTD0UwIcrftvzQPJZzVZEOY/ND4ZFq +Bj4TbCybSNuFIIAXn2yL1x5oLGz5wuEr7XClqUECVZPTyDv2ozg7+L6NRXNnQ3DQ +jL3QqEaH0M0hA/4FX7ySXrSC2BFX5LZzv8cjKla+3jqJUbUxokxEWMfYVNU+JjwV +vS4ieVqIsWcfd+FhqFCpvWj92PpJB5Lk6GkuGi026lgYutK30Gx133QLQjrNRRLG +4dM0KMevpSM1Ug7dXy60dIjlJkFjekEf7umCGFNILf3UKQTd1irnAgMBAAGjUzBR +MB0GA1UdDgQWBBRXGByy5N4UbiP0NcTKWpj2yDsJ3zAfBgNVHSMEGDAWgBRXGByy +5N4UbiP0NcTKWpj2yDsJ3zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA +A4IBAQAaLNCcpUpKjxz3DKPU632uqljaliu5mCm9OAQRDLccfZifIHacyxx45rm7 +8TOSgQnSS41UtsvTMn63LXqWCi+P0K6LVZUDAiQzcehGsKYyj7yG8UiQev3qUXuU +HRD+n92AVKYz4ABkKqepgVqE1G0CQeFISU/czTv4r3+Qv5IGXWuZ9D9fNu2Al8DG +Zo4FdcxsBESksDP07Gjj3zMRgV0uhsoSP7gEc2zrucfSIZasEuIgPGL07vMc4Ybm +EaY/cZcuSj+o012wmP6YMQw4t74Fxqe41l4lY/yPtgUO9LDx9ygiDzEKaMCTea8M +33Oqg/tD4bzToCPlpeh7qhpUfLpP +-----END CERTIFICATE----- diff --git a/test/opensslcrt/multiLevelCert/CA/rootca.crt b/test/opensslcrt/multiLevelCert/CA/rootca.crt new file mode 100644 index 0000000000000000000000000000000000000000..8e52b7a68dd73f59d2eaf00709e3d409fd2d1226 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/CA/rootca.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID0TCCArmgAwIBAgIURA/lUFrZwOfJXaaLF9ol7wkhAxIwDQYJKoZIhvcNAQEL +BQAweDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQHDAJOSjEPMA0G +A1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxDzANBgNVBAMMBlJvb3RDQTEfMB0G +CSqGSIb3DQEJARYQcm9vdGNhQHdvcmxkLmNvbTAeFw0yMjEwMTIwOTM3MjFaFw0z +MjEwMDkwOTM3MjFaMHgxCzAJBgNVBAYTAkNOMQswCQYDVQQIDAJKUzELMAkGA1UE +BwwCTkoxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MQ8wDQYDVQQDDAZS +b290Q0ExHzAdBgkqhkiG9w0BCQEWEHJvb3RjYUB3b3JsZC5jb20wggEiMA0GCSqG +SIb3DQEBAQUAA4IBDwAwggEKAoIBAQCgvTn1zACCZ4uny3BW8Utilwoztkhb/XM+ +ZI/trCZg1smsnuCNJHyIJVoFz4PoxXESCueTD0UwIcrftvzQPJZzVZEOY/ND4ZFq +Bj4TbCybSNuFIIAXn2yL1x5oLGz5wuEr7XClqUECVZPTyDv2ozg7+L6NRXNnQ3DQ +jL3QqEaH0M0hA/4FX7ySXrSC2BFX5LZzv8cjKla+3jqJUbUxokxEWMfYVNU+JjwV +vS4ieVqIsWcfd+FhqFCpvWj92PpJB5Lk6GkuGi026lgYutK30Gx133QLQjrNRRLG +4dM0KMevpSM1Ug7dXy60dIjlJkFjekEf7umCGFNILf3UKQTd1irnAgMBAAGjUzBR +MB0GA1UdDgQWBBRXGByy5N4UbiP0NcTKWpj2yDsJ3zAfBgNVHSMEGDAWgBRXGByy +5N4UbiP0NcTKWpj2yDsJ3zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA +A4IBAQAaLNCcpUpKjxz3DKPU632uqljaliu5mCm9OAQRDLccfZifIHacyxx45rm7 +8TOSgQnSS41UtsvTMn63LXqWCi+P0K6LVZUDAiQzcehGsKYyj7yG8UiQev3qUXuU +HRD+n92AVKYz4ABkKqepgVqE1G0CQeFISU/czTv4r3+Qv5IGXWuZ9D9fNu2Al8DG +Zo4FdcxsBESksDP07Gjj3zMRgV0uhsoSP7gEc2zrucfSIZasEuIgPGL07vMc4Ybm +EaY/cZcuSj+o012wmP6YMQw4t74Fxqe41l4lY/yPtgUO9LDx9ygiDzEKaMCTea8M +33Oqg/tD4bzToCPlpeh7qhpUfLpP +-----END CERTIFICATE----- diff --git a/test/opensslcrt/multiLevelCert/CA/rootca.key b/test/opensslcrt/multiLevelCert/CA/rootca.key new file mode 100644 index 0000000000000000000000000000000000000000..243405ce091dbafb55cbc6d84a9a8f9fb506a969 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/CA/rootca.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAoL059cwAgmeLp8twVvFLYpcKM7ZIW/1zPmSP7awmYNbJrJ7g +jSR8iCVaBc+D6MVxEgrnkw9FMCHK37b80DyWc1WRDmPzQ+GRagY+E2wsm0jbhSCA +F59si9ceaCxs+cLhK+1wpalBAlWT08g79qM4O/i+jUVzZ0Nw0Iy90KhGh9DNIQP+ +BV+8kl60gtgRV+S2c7/HIypWvt46iVG1MaJMRFjH2FTVPiY8Fb0uInlaiLFnH3fh +YahQqb1o/dj6SQeS5OhpLhotNupYGLrSt9Bsdd90C0I6zUUSxuHTNCjHr6UjNVIO +3V8utHSI5SZBY3pBH+7pghhTSC391CkE3dYq5wIDAQABAoIBADKXPC3btnFUy8TV +KBeFPJfcOA7MmXuyitohZpeErlOeZr1ZCA4EZNmo/+uCQ984fX0TR42mqb0bdbHx +8yJLX4MPdGdWGBPOZCk9q74LNwLs7IK7FvXYbJ6a52wcR3RY3OwpgGHzoo1sh+mJ +RS48cw+VG8x1BnyC4ngRRBDvVbubAYMGceUmlYtFgIRWqZrBhKvXgzpB2zv+5jxR +NnXDqkQRYl9fuAaG3ajN/qpiMZZlcKfZskjfRvrcGcHi8bD+xGutxXSPFq3qPUT+ +tJ1pexOv2xUx0A27441LoPW1FzSyOj2+iPuZpjTSseM1AIxk3lvXwJU/9CAXYBiS +RxA7YQECgYEA1jqeAfoWtEM1XOl3K1KsxfCqcPFsMUxglhcPz1n6Fv0zgSXHr/Sl +HDPmE7faiqNZoyy81rG7+xjpZkSa6LfQ4gCzwjN5soed7Bi7zeRM5QImta2NMpV+ +x6OXdlQQj3QDeBivmuilL47ckENYCk/t/Tu/jCpuPCYC/syuqvq1g2cCgYEAwBSh +LFcJG4xFtClnrPMljbqlDeQi4S02HYl10+8JpVbjEiPmLTbzkDSAUfm4UR+ZblJX +CGvd/TLDy1iju0/m1X+Ddb2lMUMAeihuZjaWxbpaX7UU0QUNR8jcvxQrkcFqc/hy +dpw3aVD+74+WcYoN+YqG9KGp/7N7Lkpnxslk7IECgYEAksozfpNIf1gV9oYam9rY +fAD+KMmkItt8yxseQCwdCyeP5QxoGY7+m6aMHjK6Uoi/YOnEsy+x6MoXE3Yq1w8s +1883XPg8iTIX6bDA7sFiVwD0WUSEHYcGCfF0VSYg+sq5nc78dJ64oS+4vjkG2HoQ +TpZkF7zzL8+z+bdyb8G+Ij0CgYAYdggYd3UHdxOhX+x+D/DmXbCLVlRCzNkpZcoF +lVlrHueH9d5oP6lA4g69YcnhOt71N7MxtVrt1bsteDpRrlk9MyHwqpgQ7/FtnRyC +E82bnKHJsmvWOoh4bdH+23i49SKzZh5dkINV/CSbKXQFPYmOD+Aj4zqc/6RePsd8 +f0VFAQKBgEqkOim1lmubatkYFy/Gm/3sr4pn/ENkPig7nRZDbSqzSOV50/rU40O2 ++QA5mhUyzQuG8S8zPhXWBnJt1T+O57uDRjEtVw9zrgZw/MddXEHILVjK0qtIR1Hq +opEzRRcl+keFb2N5YXfOd+qsnBl0UklTCOo/E12PwhaW6nX6HujD +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/multiLevelCert/CA/secondca.crt b/test/opensslcrt/multiLevelCert/CA/secondca.crt new file mode 100644 index 0000000000000000000000000000000000000000..866099c21c401d7bcaea279480e11a288068eb06 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/CA/secondca.crt @@ -0,0 +1,79 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 1 (0x1) + Signature Algorithm: sha256WithRSAEncryption + Issuer: C=CN, ST=JS, L=NJ, O=Huawei, OU=Dev, CN=RootCA/emailAddress=rootca@world.com + Validity + Not Before: Oct 12 09:42:08 2022 GMT + Not After : Oct 9 09:42:08 2032 GMT + Subject: C=CN, ST=JS, O=Huawei, OU=Dev, CN=SecondCA/emailAddress=secondca@world.com + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + RSA Public-Key: (2048 bit) + Modulus: + 00:d7:29:a3:b9:46:e6:db:62:4b:cc:89:b7:3c:46: + 56:53:bb:94:0d:11:a8:e0:cb:d7:13:bc:e0:44:be: + 8c:4d:c2:3c:b0:95:a3:cd:e5:f7:91:3e:64:74:19: + 65:62:f2:02:bf:c3:ac:60:a6:ac:ff:77:f8:e0:c8: + 34:69:0c:bc:b6:65:11:5f:af:61:cc:73:31:97:91: + c1:cb:48:4f:28:1a:49:ac:55:e9:b3:8d:f6:6d:85: + de:4d:62:41:94:bf:80:a9:0b:f4:bd:62:fc:5d:e7: + 90:e0:04:b8:26:0c:83:01:df:3f:2f:bf:13:de:c7: + f4:40:91:a7:0b:76:9b:39:5d:e3:45:10:66:3d:67: + 41:8d:67:73:c9:79:17:5d:29:e4:64:a2:36:9e:70: + f7:c3:7f:16:79:3d:64:21:12:50:ab:4f:be:22:b9: + e3:59:63:be:cf:ad:be:04:af:b8:ab:49:24:86:dd: + 85:b9:fe:c5:6e:4e:f6:14:04:ef:db:bd:8a:e9:a8: + d7:30:52:8d:92:8c:30:c7:51:a9:32:1a:aa:df:f4: + 39:fa:ae:df:d8:3f:3a:c7:de:2c:27:c6:af:e3:45: + f6:cb:93:69:e2:44:30:53:cc:cc:e8:52:76:0c:bd: + cd:ff:9f:ca:45:49:1a:e5:f1:7d:e7:4b:cd:b8:0f: + d3:a9 + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Subject Key Identifier: + C7:83:AD:B8:8A:A9:78:47:3D:C7:69:FC:23:69:4B:D5:67:FC:C4:DF + X509v3 Authority Key Identifier: + keyid:57:18:1C:B2:E4:DE:14:6E:23:F4:35:C4:CA:5A:98:F6:C8:3B:09:DF + + X509v3 Basic Constraints: critical + CA:TRUE + Signature Algorithm: sha256WithRSAEncryption + 60:ce:7c:c2:15:d4:00:b9:6d:d6:e9:9e:14:83:86:9c:e4:06: + 38:77:39:7a:a3:3d:4a:65:84:29:be:55:21:54:71:21:39:bd: + 37:b8:12:75:2c:0a:6a:48:01:89:00:03:fd:9f:9e:be:8a:6f: + d1:85:25:eb:a8:5a:55:a9:a3:8b:d1:5b:d8:3c:aa:f4:43:07: + b0:28:57:fa:23:69:f3:db:e1:5d:9e:a5:92:da:1e:92:50:73: + 1c:aa:a6:51:85:2a:a3:ca:0c:05:d9:95:72:d7:51:b5:8b:02: + eb:6a:f2:ae:c6:2c:bd:cc:96:a0:c5:9d:10:c7:05:a7:d7:10: + 72:4e:05:01:bb:1f:24:c3:b7:6d:2e:a0:83:34:fb:5d:1e:84: + 95:b4:2e:53:b6:50:3f:d2:43:96:44:91:8c:01:e5:f3:64:e2: + b2:4c:82:d9:9b:f8:d8:9b:a3:6c:76:65:d3:1d:d2:b6:a4:bf: + 90:86:0a:37:f5:dd:a3:1c:ce:78:18:71:b3:63:9d:41:63:88: + ba:1f:1d:31:1c:a1:89:83:55:3a:a3:77:85:b8:d7:a1:63:d9: + 1f:96:7d:d6:a5:3c:47:2d:a1:10:a3:0c:c8:bb:75:92:f1:40: + e6:3b:fc:87:ac:65:b5:c9:91:53:2d:6c:ff:25:ad:56:fe:f5: + 79:7b:96:3d +-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIBATANBgkqhkiG9w0BAQsFADB4MQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxCzAJBgNVBAcMAk5KMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNV +BAsMA0RldjEPMA0GA1UEAwwGUm9vdENBMR8wHQYJKoZIhvcNAQkBFhByb290Y2FA +d29ybGQuY29tMB4XDTIyMTAxMjA5NDIwOFoXDTMyMTAwOTA5NDIwOFowbzELMAkG +A1UEBhMCQ04xCzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsM +A0RldjERMA8GA1UEAwwIU2Vjb25kQ0ExITAfBgkqhkiG9w0BCQEWEnNlY29uZGNh +QHdvcmxkLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANcpo7lG +5ttiS8yJtzxGVlO7lA0RqODL1xO84ES+jE3CPLCVo83l95E+ZHQZZWLyAr/DrGCm +rP93+ODINGkMvLZlEV+vYcxzMZeRwctITygaSaxV6bON9m2F3k1iQZS/gKkL9L1i +/F3nkOAEuCYMgwHfPy+/E97H9ECRpwt2mzld40UQZj1nQY1nc8l5F10p5GSiNp5w +98N/Fnk9ZCESUKtPviK541ljvs+tvgSvuKtJJIbdhbn+xW5O9hQE79u9iumo1zBS +jZKMMMdRqTIaqt/0Ofqu39g/OsfeLCfGr+NF9suTaeJEMFPMzOhSdgy9zf+fykVJ +GuXxfedLzbgP06kCAwEAAaNTMFEwHQYDVR0OBBYEFMeDrbiKqXhHPcdp/CNpS9Vn +/MTfMB8GA1UdIwQYMBaAFFcYHLLk3hRuI/Q1xMpamPbIOwnfMA8GA1UdEwEB/wQF +MAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGDOfMIV1AC5bdbpnhSDhpzkBjh3OXqj +PUplhCm+VSFUcSE5vTe4EnUsCmpIAYkAA/2fnr6Kb9GFJeuoWlWpo4vRW9g8qvRD +B7AoV/ojafPb4V2epZLaHpJQcxyqplGFKqPKDAXZlXLXUbWLAutq8q7GLL3MlqDF +nRDHBafXEHJOBQG7HyTDt20uoIM0+10ehJW0LlO2UD/SQ5ZEkYwB5fNk4rJMgtmb ++Nibo2x2ZdMd0rakv5CGCjf13aMczngYcbNjnUFjiLofHTEcoYmDVTqjd4W416Fj +2R+WfdalPEctoRCjDMi7dZLxQOY7/IesZbXJkVMtbP8lrVb+9Xl7lj0= +-----END CERTIFICATE----- diff --git a/test/opensslcrt/multiLevelCert/CA/secondca.csr b/test/opensslcrt/multiLevelCert/CA/secondca.csr new file mode 100644 index 0000000000000000000000000000000000000000..012c84a927e1af42c1e6379275991e853c7d7ee9 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/CA/secondca.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICwTCCAakCAQAwfDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQH +DAJOSjEPMA0GA1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxETAPBgNVBAMMCFNl +Y29uZENBMSEwHwYJKoZIhvcNAQkBFhJzZWNvbmRjYUB3b3JsZC5jb20wggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDXKaO5RubbYkvMibc8RlZTu5QNEajg +y9cTvOBEvoxNwjywlaPN5feRPmR0GWVi8gK/w6xgpqz/d/jgyDRpDLy2ZRFfr2HM +czGXkcHLSE8oGkmsVemzjfZthd5NYkGUv4CpC/S9Yvxd55DgBLgmDIMB3z8vvxPe +x/RAkacLdps5XeNFEGY9Z0GNZ3PJeRddKeRkojaecPfDfxZ5PWQhElCrT74iueNZ +Y77Prb4Er7irSSSG3YW5/sVuTvYUBO/bvYrpqNcwUo2SjDDHUakyGqrf9Dn6rt/Y +PzrH3iwnxq/jRfbLk2niRDBTzMzoUnYMvc3/n8pFSRrl8X3nS824D9OpAgMBAAGg +ADANBgkqhkiG9w0BAQsFAAOCAQEALBW59/ZzFd97b5jzqamQnkKH2fN/kk7+vfu8 +0FiHN2liCnaHAa7+zlxch8XZY/LdQWdBcTtOMQgTz8dEuHsQaAxT4dLTTm9rs70w +QoxoGLy7okbvGKyhxzJM6BHJVDzaq2AXMtB1BlI+9DFBmwxbpDQyqtc0XaABBkV5 +GahaE2WAP1t3LM+JDOdJ+5VLSNIhneJrFR465HmHaVSVe1ivD3tk7394DwcLPOE4 +qsuIP3nQIzMFKyzyaMbaKXNd1mU3SfitOQuckm7oycexVgtd1oU6kZs0PF8q6e1Z +tcAxZqbYllNZx1cHyg/lu59+gPOvH3FB6LWiloL3usrHLm9pgQ== +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/multiLevelCert/CA/secondca.key b/test/opensslcrt/multiLevelCert/CA/secondca.key new file mode 100644 index 0000000000000000000000000000000000000000..803c11daf385b85f69e000111c71a08968d0bced --- /dev/null +++ b/test/opensslcrt/multiLevelCert/CA/secondca.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEA1ymjuUbm22JLzIm3PEZWU7uUDRGo4MvXE7zgRL6MTcI8sJWj +zeX3kT5kdBllYvICv8OsYKas/3f44Mg0aQy8tmURX69hzHMxl5HBy0hPKBpJrFXp +s432bYXeTWJBlL+AqQv0vWL8XeeQ4AS4JgyDAd8/L78T3sf0QJGnC3abOV3jRRBm +PWdBjWdzyXkXXSnkZKI2nnD3w38WeT1kIRJQq0++IrnjWWO+z62+BK+4q0kkht2F +uf7Fbk72FATv272K6ajXMFKNkowwx1GpMhqq3/Q5+q7f2D86x94sJ8av40X2y5Np +4kQwU8zM6FJ2DL3N/5/KRUka5fF950vNuA/TqQIDAQABAoIBAQDL1t8NQGaloNI+ +zJmTuO9AFI2GdBySG4t/X4j4l61EXagxgxLUlfGc4Ic6lnS+8Jg6JJ7CUiXDQV2/ +VuyQOUjvY4C6LeVxVBC/j48Rj0eurnjtk9b8DJpR2Glq1pNa4LJ7dKBAa+666A8Q +rGfpZCEZPO8XxOaGQNjd8x9WdN9J0DQ1oSsvRC8xLyuKK1jrxrIaXMK1XZRgaNMC +atr3fCo+yChyuzsSKHdscq+2MA6xnp3dc8Q+nwhx1xFSdfah+NCqGfmVYJRDgyCT +GuZCkyi8R1+jkGs4dO/cBTTRlkiWHHFLJo+/c8vB4Do1JkA93PoSYa6G4UAUJA3N +B1e16TihAoGBAOttoHES6CpCiZdOtTTk//zyWxrG8CjnHRCj8g39QHnAYyzlz88h +RYuyfrTPviGuE3+3CMWsjvPag1E20NpyrsJtx3OBhcDG6fj6l7j1vuTngl+mlM9K ++n36BcClTAnY3JZf51FErmUYkTVh9RcQ4iycV7cl65OouNwGVc90WCpPAoGBAOn2 +r3OIC7bal0kxrI0bT65gkbvg3q1sA1cpcu4HOZJUK7RjdYzsd7+lCsSPfVyp9yN1 +iw3Bn2umEkkogxy7wShT27Ng2LNyHRrYhVIuEee/+SSZgMDa+0cTOdKPqRVITffb +weM4Z3gVrc27EtWuc9vGUkNxoceJ3MKHpesDKTyHAoGAMSoFlVdzcE/Q1+4x3Uft +RW9/IwpkYMZSxYTXKaC3dDV/AINFcGXsVg4Cc9PmSrZFkCgzBsTQXZBGWBFwcA3+ +/M9cFXz455ciiUIbqR54rOjDyyHIdbmcse4igWaDiJLnDegdMFV9bdNBj7pTKmv2 +L4a+spqSpZVYdWpFRTtwpfUCgYEA6b4F7bOWmHlsubiB/nuxsLJUBtMTRUlrUPJd +G0dmkjW7cD4Jm+BHhtTZnCULBr/b47Y0VWsC3aaOED8ENnmx8ZtOHLj95tF0GHUH +RWI3i0Q1IgamJobgklK36xCRyWxyUNVhsKOSY9usx6RFnevrXj+VwkHNci/euQ6S +ieeflBMCgYEApVqc4NqEuS4i//GGdQ9fXPx9eXBOxnqjwXL4SUdhwSruBYPIDavD +Vb66fdGQU7wAeVQczeIZEHDFuLkePvbwPeQmdDgJBN88lB1TjxvwhYvOBz0cJbl0 +NN8bkkMUsjZ93sZaVKgKYf7Q2JEwBRdagmWktcx9omJYPs+GAxIdVnE= +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/multiLevelCert/client/cert.pem b/test/opensslcrt/multiLevelCert/client/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..29106673ef0a58abba477f078d16d0e6bc23da38 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/client/cert.pem @@ -0,0 +1,82 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 3 (0x3) + Signature Algorithm: sha256WithRSAEncryption + Issuer: C=CN, ST=JS, O=Huawei, OU=Dev, CN=SecondCA/emailAddress=secondca@world.com + Validity + Not Before: Oct 12 09:46:45 2022 GMT + Not After : Oct 9 09:46:45 2032 GMT + Subject: C=CN, ST=JS, O=Huawei, OU=Dev, CN=testclient.com/emailAddress=test@world.com + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + RSA Public-Key: (2048 bit) + Modulus: + 00:e5:eb:c7:fd:84:20:0c:7a:f2:12:df:fa:39:5d: + 40:2d:a6:cd:48:81:a0:66:32:de:d1:18:e8:ba:8e: + bc:e2:f3:e0:44:e0:fa:a8:b9:df:0c:a6:4e:da:59: + 5f:92:ea:56:3a:dc:90:26:f9:36:f6:08:a2:35:ad: + b4:82:d9:c4:8a:78:6b:eb:17:a6:27:c3:df:63:15: + e8:4d:0d:29:0c:22:50:43:e6:e8:ce:e2:6d:8d:0e: + 83:e7:e2:1d:10:27:be:ad:b7:9c:c4:48:b1:20:c2: + 62:50:85:8b:ec:fb:b9:c0:3a:ea:1c:a2:3d:e5:62: + 6b:dd:70:e2:97:ce:be:8e:d5:ee:30:1e:94:e0:3a: + 37:44:fe:ca:5b:93:05:6a:8c:a7:f6:6a:d0:20:a5: + 2b:d9:0e:ac:b1:0e:04:51:06:98:05:ae:4a:c9:47: + c7:bf:6d:47:b2:92:bd:c7:54:fb:b8:b5:86:67:b3: + de:53:f0:d5:7c:94:1d:2e:4f:58:3a:57:72:21:5c: + 2b:b4:00:b8:92:56:bf:29:25:56:47:d9:4c:a5:29: + b7:90:ba:bc:e1:9d:43:0a:4c:b4:c2:78:b3:43:47: + 86:f8:66:d2:23:c8:49:00:59:f7:6f:07:c1:9a:ef: + 7d:44:38:ad:97:3e:46:61:8f:ff:98:3e:73:78:35: + e9:1f + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + Netscape Comment: + OpenSSL Generated Certificate + X509v3 Subject Key Identifier: + A6:28:FE:B3:E2:16:72:B4:DE:71:8A:10:BE:79:5C:90:9A:31:7C:0B + X509v3 Authority Key Identifier: + keyid:C7:83:AD:B8:8A:A9:78:47:3D:C7:69:FC:23:69:4B:D5:67:FC:C4:DF + + Signature Algorithm: sha256WithRSAEncryption + 6b:41:eb:e5:91:e3:26:6d:f0:3b:98:27:b4:90:e2:0b:04:04: + 0a:3a:04:40:bf:95:7e:8d:bf:80:9e:49:12:81:cf:c4:43:1a: + e3:79:cc:07:c9:b4:2f:64:f5:cd:a5:d3:7b:40:61:04:27:c2: + 46:22:aa:c1:10:a0:de:a8:8f:3d:09:72:49:90:6f:5e:32:dd: + d8:fb:94:29:60:98:c0:81:cb:a0:4f:8e:40:cb:f3:97:73:ff: + 83:45:47:62:b8:57:12:7c:a3:06:1a:36:30:42:40:5a:75:a1: + 74:6a:fb:07:25:b9:6b:1e:7a:38:68:24:82:5d:e3:d1:f5:e2: + 09:0d:71:58:36:8f:57:a9:71:17:4a:e2:46:4b:de:0f:db:ea: + e6:f5:d4:d5:2f:b7:ed:91:7e:b4:10:50:2a:78:dc:5a:4a:f3: + a7:cb:77:57:44:9d:e7:32:14:28:05:3b:21:fe:95:aa:cf:50: + f6:2d:ed:fb:59:5a:42:ca:8a:1f:d3:f6:43:b5:50:79:9a:5d: + be:2e:67:45:59:83:9f:ad:e9:4b:7f:8d:e0:f7:77:c4:27:1c: + c7:ed:2f:f4:ca:39:1f:ad:8a:1c:25:11:99:96:4f:23:9b:6a: + 11:e6:cf:44:8c:59:07:d4:21:6d:e1:7c:d0:ee:43:33:f2:6c: + a9:f2:70:2d +-----BEGIN CERTIFICATE----- +MIID1jCCAr6gAwIBAgIBAzANBgkqhkiG9w0BAQsFADBvMQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MREwDwYD +VQQDDAhTZWNvbmRDQTEhMB8GCSqGSIb3DQEJARYSc2Vjb25kY2FAd29ybGQuY29t +MB4XDTIyMTAxMjA5NDY0NVoXDTMyMTAwOTA5NDY0NVowcTELMAkGA1UEBhMCQ04x +CzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsMA0RldjEXMBUG +A1UEAwwOdGVzdGNsaWVudC5jb20xHTAbBgkqhkiG9w0BCQEWDnRlc3RAd29ybGQu +Y29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA5evH/YQgDHryEt/6 +OV1ALabNSIGgZjLe0Rjouo684vPgROD6qLnfDKZO2llfkupWOtyQJvk29giiNa20 +gtnEinhr6xemJ8PfYxXoTQ0pDCJQQ+bozuJtjQ6D5+IdECe+rbecxEixIMJiUIWL +7Pu5wDrqHKI95WJr3XDil86+jtXuMB6U4Do3RP7KW5MFaoyn9mrQIKUr2Q6ssQ4E +UQaYBa5KyUfHv21HspK9x1T7uLWGZ7PeU/DVfJQdLk9YOldyIVwrtAC4kla/KSVW +R9lMpSm3kLq84Z1DCky0wnizQ0eG+GbSI8hJAFn3bwfBmu99RDitlz5GYY//mD5z +eDXpHwIDAQABo3sweTAJBgNVHRMEAjAAMCwGCWCGSAGG+EIBDQQfFh1PcGVuU1NM +IEdlbmVyYXRlZCBDZXJ0aWZpY2F0ZTAdBgNVHQ4EFgQUpij+s+IWcrTecYoQvnlc +kJoxfAswHwYDVR0jBBgwFoAUx4OtuIqpeEc9x2n8I2lL1Wf8xN8wDQYJKoZIhvcN +AQELBQADggEBAGtB6+WR4yZt8DuYJ7SQ4gsEBAo6BEC/lX6Nv4CeSRKBz8RDGuN5 +zAfJtC9k9c2l03tAYQQnwkYiqsEQoN6ojz0JckmQb14y3dj7lClgmMCBy6BPjkDL +85dz/4NFR2K4VxJ8owYaNjBCQFp1oXRq+wcluWseejhoJIJd49H14gkNcVg2j1ep +cRdK4kZL3g/b6ub11NUvt+2RfrQQUCp43FpK86fLd1dEnecyFCgFOyH+larPUPYt +7ftZWkLKih/T9kO1UHmaXb4uZ0VZg5+t6Ut/jeD3d8QnHMftL/TKOR+tihwlEZmW +TyObahHmz0SMWQfUIW3hfNDuQzPybKnycC0= +-----END CERTIFICATE----- diff --git a/test/opensslcrt/multiLevelCert/client/client.crt b/test/opensslcrt/multiLevelCert/client/client.crt new file mode 100644 index 0000000000000000000000000000000000000000..29106673ef0a58abba477f078d16d0e6bc23da38 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/client/client.crt @@ -0,0 +1,82 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 3 (0x3) + Signature Algorithm: sha256WithRSAEncryption + Issuer: C=CN, ST=JS, O=Huawei, OU=Dev, CN=SecondCA/emailAddress=secondca@world.com + Validity + Not Before: Oct 12 09:46:45 2022 GMT + Not After : Oct 9 09:46:45 2032 GMT + Subject: C=CN, ST=JS, O=Huawei, OU=Dev, CN=testclient.com/emailAddress=test@world.com + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + RSA Public-Key: (2048 bit) + Modulus: + 00:e5:eb:c7:fd:84:20:0c:7a:f2:12:df:fa:39:5d: + 40:2d:a6:cd:48:81:a0:66:32:de:d1:18:e8:ba:8e: + bc:e2:f3:e0:44:e0:fa:a8:b9:df:0c:a6:4e:da:59: + 5f:92:ea:56:3a:dc:90:26:f9:36:f6:08:a2:35:ad: + b4:82:d9:c4:8a:78:6b:eb:17:a6:27:c3:df:63:15: + e8:4d:0d:29:0c:22:50:43:e6:e8:ce:e2:6d:8d:0e: + 83:e7:e2:1d:10:27:be:ad:b7:9c:c4:48:b1:20:c2: + 62:50:85:8b:ec:fb:b9:c0:3a:ea:1c:a2:3d:e5:62: + 6b:dd:70:e2:97:ce:be:8e:d5:ee:30:1e:94:e0:3a: + 37:44:fe:ca:5b:93:05:6a:8c:a7:f6:6a:d0:20:a5: + 2b:d9:0e:ac:b1:0e:04:51:06:98:05:ae:4a:c9:47: + c7:bf:6d:47:b2:92:bd:c7:54:fb:b8:b5:86:67:b3: + de:53:f0:d5:7c:94:1d:2e:4f:58:3a:57:72:21:5c: + 2b:b4:00:b8:92:56:bf:29:25:56:47:d9:4c:a5:29: + b7:90:ba:bc:e1:9d:43:0a:4c:b4:c2:78:b3:43:47: + 86:f8:66:d2:23:c8:49:00:59:f7:6f:07:c1:9a:ef: + 7d:44:38:ad:97:3e:46:61:8f:ff:98:3e:73:78:35: + e9:1f + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + Netscape Comment: + OpenSSL Generated Certificate + X509v3 Subject Key Identifier: + A6:28:FE:B3:E2:16:72:B4:DE:71:8A:10:BE:79:5C:90:9A:31:7C:0B + X509v3 Authority Key Identifier: + keyid:C7:83:AD:B8:8A:A9:78:47:3D:C7:69:FC:23:69:4B:D5:67:FC:C4:DF + + Signature Algorithm: sha256WithRSAEncryption + 6b:41:eb:e5:91:e3:26:6d:f0:3b:98:27:b4:90:e2:0b:04:04: + 0a:3a:04:40:bf:95:7e:8d:bf:80:9e:49:12:81:cf:c4:43:1a: + e3:79:cc:07:c9:b4:2f:64:f5:cd:a5:d3:7b:40:61:04:27:c2: + 46:22:aa:c1:10:a0:de:a8:8f:3d:09:72:49:90:6f:5e:32:dd: + d8:fb:94:29:60:98:c0:81:cb:a0:4f:8e:40:cb:f3:97:73:ff: + 83:45:47:62:b8:57:12:7c:a3:06:1a:36:30:42:40:5a:75:a1: + 74:6a:fb:07:25:b9:6b:1e:7a:38:68:24:82:5d:e3:d1:f5:e2: + 09:0d:71:58:36:8f:57:a9:71:17:4a:e2:46:4b:de:0f:db:ea: + e6:f5:d4:d5:2f:b7:ed:91:7e:b4:10:50:2a:78:dc:5a:4a:f3: + a7:cb:77:57:44:9d:e7:32:14:28:05:3b:21:fe:95:aa:cf:50: + f6:2d:ed:fb:59:5a:42:ca:8a:1f:d3:f6:43:b5:50:79:9a:5d: + be:2e:67:45:59:83:9f:ad:e9:4b:7f:8d:e0:f7:77:c4:27:1c: + c7:ed:2f:f4:ca:39:1f:ad:8a:1c:25:11:99:96:4f:23:9b:6a: + 11:e6:cf:44:8c:59:07:d4:21:6d:e1:7c:d0:ee:43:33:f2:6c: + a9:f2:70:2d +-----BEGIN CERTIFICATE----- +MIID1jCCAr6gAwIBAgIBAzANBgkqhkiG9w0BAQsFADBvMQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MREwDwYD +VQQDDAhTZWNvbmRDQTEhMB8GCSqGSIb3DQEJARYSc2Vjb25kY2FAd29ybGQuY29t +MB4XDTIyMTAxMjA5NDY0NVoXDTMyMTAwOTA5NDY0NVowcTELMAkGA1UEBhMCQ04x +CzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsMA0RldjEXMBUG +A1UEAwwOdGVzdGNsaWVudC5jb20xHTAbBgkqhkiG9w0BCQEWDnRlc3RAd29ybGQu +Y29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA5evH/YQgDHryEt/6 +OV1ALabNSIGgZjLe0Rjouo684vPgROD6qLnfDKZO2llfkupWOtyQJvk29giiNa20 +gtnEinhr6xemJ8PfYxXoTQ0pDCJQQ+bozuJtjQ6D5+IdECe+rbecxEixIMJiUIWL +7Pu5wDrqHKI95WJr3XDil86+jtXuMB6U4Do3RP7KW5MFaoyn9mrQIKUr2Q6ssQ4E +UQaYBa5KyUfHv21HspK9x1T7uLWGZ7PeU/DVfJQdLk9YOldyIVwrtAC4kla/KSVW +R9lMpSm3kLq84Z1DCky0wnizQ0eG+GbSI8hJAFn3bwfBmu99RDitlz5GYY//mD5z +eDXpHwIDAQABo3sweTAJBgNVHRMEAjAAMCwGCWCGSAGG+EIBDQQfFh1PcGVuU1NM +IEdlbmVyYXRlZCBDZXJ0aWZpY2F0ZTAdBgNVHQ4EFgQUpij+s+IWcrTecYoQvnlc +kJoxfAswHwYDVR0jBBgwFoAUx4OtuIqpeEc9x2n8I2lL1Wf8xN8wDQYJKoZIhvcN +AQELBQADggEBAGtB6+WR4yZt8DuYJ7SQ4gsEBAo6BEC/lX6Nv4CeSRKBz8RDGuN5 +zAfJtC9k9c2l03tAYQQnwkYiqsEQoN6ojz0JckmQb14y3dj7lClgmMCBy6BPjkDL +85dz/4NFR2K4VxJ8owYaNjBCQFp1oXRq+wcluWseejhoJIJd49H14gkNcVg2j1ep +cRdK4kZL3g/b6ub11NUvt+2RfrQQUCp43FpK86fLd1dEnecyFCgFOyH+larPUPYt +7ftZWkLKih/T9kO1UHmaXb4uZ0VZg5+t6Ut/jeD3d8QnHMftL/TKOR+tihwlEZmW +TyObahHmz0SMWQfUIW3hfNDuQzPybKnycC0= +-----END CERTIFICATE----- diff --git a/test/opensslcrt/multiLevelCert/client/client.csr b/test/opensslcrt/multiLevelCert/client/client.csr new file mode 100644 index 0000000000000000000000000000000000000000..ff6829e03e04ae02ef15b910f4f3b44e4748f0d1 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/client/client.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICwzCCAasCAQAwfjELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQH +DAJOSjEPMA0GA1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxFzAVBgNVBAMMDnRl +c3RjbGllbnQuY29tMR0wGwYJKoZIhvcNAQkBFg50ZXN0QHdvcmxkLmNvbTCCASIw +DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOXrx/2EIAx68hLf+jldQC2mzUiB +oGYy3tEY6LqOvOLz4ETg+qi53wymTtpZX5LqVjrckCb5NvYIojWttILZxIp4a+sX +pifD32MV6E0NKQwiUEPm6M7ibY0Og+fiHRAnvq23nMRIsSDCYlCFi+z7ucA66hyi +PeVia91w4pfOvo7V7jAelOA6N0T+yluTBWqMp/Zq0CClK9kOrLEOBFEGmAWuSslH +x79tR7KSvcdU+7i1hmez3lPw1XyUHS5PWDpXciFcK7QAuJJWvyklVkfZTKUpt5C6 +vOGdQwpMtMJ4s0NHhvhm0iPISQBZ928HwZrvfUQ4rZc+RmGP/5g+c3g16R8CAwEA +AaAAMA0GCSqGSIb3DQEBCwUAA4IBAQCsGnU+byNkW1GA8XNiya/xcwFhBrgZ8ytk +neqb0s1VdcTpM0F3OhnaYsBbAuWvuuXnKX2fUGD1TPiB/cuiOULxMkR2FQvA6l6g +pOUTRP1sHsUbAB1nLEZ6ZNmYISLY5e52265SFGw8moQ29TEXjg1Fpsgpiol5Xc9O +pf8ojhCgKmCgUgMd48N11BPzTB+8vwiBG7c+pJr3RU+g/FjtGXTE2g7GqWDAvcaC +4F7L+PBtlocjI4K2ci1yEsFYXoa0ZmOZjLvTOS7A3hDFM1ga1pjTh1hejGwC+d1u +uaFXa2XJa89jimpZH609gAv+6KlGXB5SVgfwVG9DjKMFmJJYVmpV +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/multiLevelCert/client/client.key b/test/opensslcrt/multiLevelCert/client/client.key new file mode 100644 index 0000000000000000000000000000000000000000..d0e8e05724fdfadc4ea309d388bac03c07e18029 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/client/client.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA5evH/YQgDHryEt/6OV1ALabNSIGgZjLe0Rjouo684vPgROD6 +qLnfDKZO2llfkupWOtyQJvk29giiNa20gtnEinhr6xemJ8PfYxXoTQ0pDCJQQ+bo +zuJtjQ6D5+IdECe+rbecxEixIMJiUIWL7Pu5wDrqHKI95WJr3XDil86+jtXuMB6U +4Do3RP7KW5MFaoyn9mrQIKUr2Q6ssQ4EUQaYBa5KyUfHv21HspK9x1T7uLWGZ7Pe +U/DVfJQdLk9YOldyIVwrtAC4kla/KSVWR9lMpSm3kLq84Z1DCky0wnizQ0eG+GbS +I8hJAFn3bwfBmu99RDitlz5GYY//mD5zeDXpHwIDAQABAoIBAQCnLeM0Nm8rQ/Zi +vRNvxJtW1nNr5j1gMlsLxUXr6L/1cgi/bKs2JjjGNOMfJ180L0pV8Gysugc5rJtt +1olrn7amTNuDjKWXQnhazuIjrI8NMKIWTX84dzHbIBPPdv1U8uFV5S2LF6QbwtvD +2uccgQjWesAh4+KHuSHfWSaZ5Y1vw0MGNCRkyVx13ovDM0gjT694aI6y0Qd/mXm4 +oAvTZCU7tJ/45nSDiWqSQdOJJdQjV0nOoKj0Fpcr7mF11QYALjXmC04YfO5DYaCd +O4wQxZYGzSWyYoHLFonlc5uL1M/ZYP6n2HPvHCeyvJY6q+kTA3OmhnOIKwoCpObb +FV3JETmhAoGBAP4Kji7Ig2bNK8oYNm6sUFJR1tnLSHIkSmNCbpJ4vQE0RBIU1/US +NrrUBAWm6AppFZPsUsLc42O96GOv/P4tOioKsFUJCi/w59A3Xt04tapHBTEfUiJy +bOLIC1WYEhsZ4u1e1u6hfvyAuJVYLqkiPVsRLp2QV6PUP1cSyNY3mHEPAoGBAOex +nZmytTkL/na/aQNKPEFQaBvQJQUKw7tWcwZyjgm4Ar03ojYw04TkuP0UHw+M0R+M +rjoeuGrnOktjwlMlmt8rMRxO8WeVr1tA5U6658tzt2Rprq1INkvVeLoBtGP29FqM +LI2YG4uKou43LU+jgPrkLkEt0bhb/mAdhHrR4ObxAoGAdgKZQgpLYDn3GY5d2tOZ +DGSQFeRk5wEMvUdi7g/AXQrWhD/CgknPusI6jBWYvR1LtMeXOoY561+Q0J40PC7u +UhFdEGN+o/6Y8RSHsORjH5KWStdt5Cqbgk3DViOqZYSE8heYaIoE3288T8QDCPaq +4d79dJxU2foC4oQLX9e7rOkCgYADaV0dt0Dt3xxXGUhtkPlEKO/vgOgao+bv6jz1 +Wlh3EiuQJ7KOw7dJnKiQqWwvqW4m3cZu+qbShCcalxR0bvhR0uv9M7hgQxb67AC0 +YRIqr8CCjP/Sc17BTRpi+sVyN1+vuaKqTxQQwPDXOx7CrnCmwRdhRFBzO3+KYMTj +nhWGsQKBgFjQkaxPXGpNaMOGoEzNir/4gvGsxPICUEACfA/HAAaTZm23Ly5+ThGO +Pvg4hqY1C1T+ksn0pGIlkGRi3Tq/QnX4NyxouskQHGswDq6wVGYHu00HizmHgYhx +XXuYhcynXchaHnpHNVxuq85cw5LWs3rPwkH50Kq4jnJ7wGhQLVB9 +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/multiLevelCert/client/key.pem b/test/opensslcrt/multiLevelCert/client/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..d0e8e05724fdfadc4ea309d388bac03c07e18029 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/client/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA5evH/YQgDHryEt/6OV1ALabNSIGgZjLe0Rjouo684vPgROD6 +qLnfDKZO2llfkupWOtyQJvk29giiNa20gtnEinhr6xemJ8PfYxXoTQ0pDCJQQ+bo +zuJtjQ6D5+IdECe+rbecxEixIMJiUIWL7Pu5wDrqHKI95WJr3XDil86+jtXuMB6U +4Do3RP7KW5MFaoyn9mrQIKUr2Q6ssQ4EUQaYBa5KyUfHv21HspK9x1T7uLWGZ7Pe +U/DVfJQdLk9YOldyIVwrtAC4kla/KSVWR9lMpSm3kLq84Z1DCky0wnizQ0eG+GbS +I8hJAFn3bwfBmu99RDitlz5GYY//mD5zeDXpHwIDAQABAoIBAQCnLeM0Nm8rQ/Zi +vRNvxJtW1nNr5j1gMlsLxUXr6L/1cgi/bKs2JjjGNOMfJ180L0pV8Gysugc5rJtt +1olrn7amTNuDjKWXQnhazuIjrI8NMKIWTX84dzHbIBPPdv1U8uFV5S2LF6QbwtvD +2uccgQjWesAh4+KHuSHfWSaZ5Y1vw0MGNCRkyVx13ovDM0gjT694aI6y0Qd/mXm4 +oAvTZCU7tJ/45nSDiWqSQdOJJdQjV0nOoKj0Fpcr7mF11QYALjXmC04YfO5DYaCd +O4wQxZYGzSWyYoHLFonlc5uL1M/ZYP6n2HPvHCeyvJY6q+kTA3OmhnOIKwoCpObb +FV3JETmhAoGBAP4Kji7Ig2bNK8oYNm6sUFJR1tnLSHIkSmNCbpJ4vQE0RBIU1/US +NrrUBAWm6AppFZPsUsLc42O96GOv/P4tOioKsFUJCi/w59A3Xt04tapHBTEfUiJy +bOLIC1WYEhsZ4u1e1u6hfvyAuJVYLqkiPVsRLp2QV6PUP1cSyNY3mHEPAoGBAOex +nZmytTkL/na/aQNKPEFQaBvQJQUKw7tWcwZyjgm4Ar03ojYw04TkuP0UHw+M0R+M +rjoeuGrnOktjwlMlmt8rMRxO8WeVr1tA5U6658tzt2Rprq1INkvVeLoBtGP29FqM +LI2YG4uKou43LU+jgPrkLkEt0bhb/mAdhHrR4ObxAoGAdgKZQgpLYDn3GY5d2tOZ +DGSQFeRk5wEMvUdi7g/AXQrWhD/CgknPusI6jBWYvR1LtMeXOoY561+Q0J40PC7u +UhFdEGN+o/6Y8RSHsORjH5KWStdt5Cqbgk3DViOqZYSE8heYaIoE3288T8QDCPaq +4d79dJxU2foC4oQLX9e7rOkCgYADaV0dt0Dt3xxXGUhtkPlEKO/vgOgao+bv6jz1 +Wlh3EiuQJ7KOw7dJnKiQqWwvqW4m3cZu+qbShCcalxR0bvhR0uv9M7hgQxb67AC0 +YRIqr8CCjP/Sc17BTRpi+sVyN1+vuaKqTxQQwPDXOx7CrnCmwRdhRFBzO3+KYMTj +nhWGsQKBgFjQkaxPXGpNaMOGoEzNir/4gvGsxPICUEACfA/HAAaTZm23Ly5+ThGO +Pvg4hqY1C1T+ksn0pGIlkGRi3Tq/QnX4NyxouskQHGswDq6wVGYHu00HizmHgYhx +XXuYhcynXchaHnpHNVxuq85cw5LWs3rPwkH50Kq4jnJ7wGhQLVB9 +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/multiLevelCert/demoCA/crlnumber b/test/opensslcrt/multiLevelCert/demoCA/crlnumber new file mode 100644 index 0000000000000000000000000000000000000000..4daddb72ffc0402845066df20e4480037235ff6b --- /dev/null +++ b/test/opensslcrt/multiLevelCert/demoCA/crlnumber @@ -0,0 +1 @@ +00 diff --git a/test/opensslcrt/multiLevelCert/demoCA/index.txt b/test/opensslcrt/multiLevelCert/demoCA/index.txt new file mode 100644 index 0000000000000000000000000000000000000000..3b9c3faa1f9537d9c661f71824a2a0b3897dbe17 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/demoCA/index.txt @@ -0,0 +1,3 @@ +V 321009094208Z 01 unknown /C=CN/ST=JS/O=Huawei/OU=Dev/CN=SecondCA/emailAddress=secondca@world.com +V 321009094413Z 02 unknown /C=CN/ST=JS/O=Huawei/OU=Dev/CN=testsert.com/emailAddress=test@world.com +V 321009094645Z 03 unknown /C=CN/ST=JS/O=Huawei/OU=Dev/CN=testclient.com/emailAddress=test@world.com diff --git a/test/opensslcrt/multiLevelCert/demoCA/index.txt.attr b/test/opensslcrt/multiLevelCert/demoCA/index.txt.attr new file mode 100644 index 0000000000000000000000000000000000000000..8f7e63a3475ce82ed03dba035f5c01a42ca38c65 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/demoCA/index.txt.attr @@ -0,0 +1 @@ +unique_subject = yes diff --git a/test/opensslcrt/multiLevelCert/demoCA/index.txt.attr.old b/test/opensslcrt/multiLevelCert/demoCA/index.txt.attr.old new file mode 100644 index 0000000000000000000000000000000000000000..8f7e63a3475ce82ed03dba035f5c01a42ca38c65 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/demoCA/index.txt.attr.old @@ -0,0 +1 @@ +unique_subject = yes diff --git a/test/opensslcrt/multiLevelCert/demoCA/index.txt.old b/test/opensslcrt/multiLevelCert/demoCA/index.txt.old new file mode 100644 index 0000000000000000000000000000000000000000..c698fcf327c721c2a866cfeba300935ab468067e --- /dev/null +++ b/test/opensslcrt/multiLevelCert/demoCA/index.txt.old @@ -0,0 +1,2 @@ +V 321009094208Z 01 unknown /C=CN/ST=JS/O=Huawei/OU=Dev/CN=SecondCA/emailAddress=secondca@world.com +V 321009094413Z 02 unknown /C=CN/ST=JS/O=Huawei/OU=Dev/CN=testsert.com/emailAddress=test@world.com diff --git a/test/opensslcrt/multiLevelCert/demoCA/newcerts/01.pem b/test/opensslcrt/multiLevelCert/demoCA/newcerts/01.pem new file mode 100644 index 0000000000000000000000000000000000000000..866099c21c401d7bcaea279480e11a288068eb06 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/demoCA/newcerts/01.pem @@ -0,0 +1,79 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 1 (0x1) + Signature Algorithm: sha256WithRSAEncryption + Issuer: C=CN, ST=JS, L=NJ, O=Huawei, OU=Dev, CN=RootCA/emailAddress=rootca@world.com + Validity + Not Before: Oct 12 09:42:08 2022 GMT + Not After : Oct 9 09:42:08 2032 GMT + Subject: C=CN, ST=JS, O=Huawei, OU=Dev, CN=SecondCA/emailAddress=secondca@world.com + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + RSA Public-Key: (2048 bit) + Modulus: + 00:d7:29:a3:b9:46:e6:db:62:4b:cc:89:b7:3c:46: + 56:53:bb:94:0d:11:a8:e0:cb:d7:13:bc:e0:44:be: + 8c:4d:c2:3c:b0:95:a3:cd:e5:f7:91:3e:64:74:19: + 65:62:f2:02:bf:c3:ac:60:a6:ac:ff:77:f8:e0:c8: + 34:69:0c:bc:b6:65:11:5f:af:61:cc:73:31:97:91: + c1:cb:48:4f:28:1a:49:ac:55:e9:b3:8d:f6:6d:85: + de:4d:62:41:94:bf:80:a9:0b:f4:bd:62:fc:5d:e7: + 90:e0:04:b8:26:0c:83:01:df:3f:2f:bf:13:de:c7: + f4:40:91:a7:0b:76:9b:39:5d:e3:45:10:66:3d:67: + 41:8d:67:73:c9:79:17:5d:29:e4:64:a2:36:9e:70: + f7:c3:7f:16:79:3d:64:21:12:50:ab:4f:be:22:b9: + e3:59:63:be:cf:ad:be:04:af:b8:ab:49:24:86:dd: + 85:b9:fe:c5:6e:4e:f6:14:04:ef:db:bd:8a:e9:a8: + d7:30:52:8d:92:8c:30:c7:51:a9:32:1a:aa:df:f4: + 39:fa:ae:df:d8:3f:3a:c7:de:2c:27:c6:af:e3:45: + f6:cb:93:69:e2:44:30:53:cc:cc:e8:52:76:0c:bd: + cd:ff:9f:ca:45:49:1a:e5:f1:7d:e7:4b:cd:b8:0f: + d3:a9 + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Subject Key Identifier: + C7:83:AD:B8:8A:A9:78:47:3D:C7:69:FC:23:69:4B:D5:67:FC:C4:DF + X509v3 Authority Key Identifier: + keyid:57:18:1C:B2:E4:DE:14:6E:23:F4:35:C4:CA:5A:98:F6:C8:3B:09:DF + + X509v3 Basic Constraints: critical + CA:TRUE + Signature Algorithm: sha256WithRSAEncryption + 60:ce:7c:c2:15:d4:00:b9:6d:d6:e9:9e:14:83:86:9c:e4:06: + 38:77:39:7a:a3:3d:4a:65:84:29:be:55:21:54:71:21:39:bd: + 37:b8:12:75:2c:0a:6a:48:01:89:00:03:fd:9f:9e:be:8a:6f: + d1:85:25:eb:a8:5a:55:a9:a3:8b:d1:5b:d8:3c:aa:f4:43:07: + b0:28:57:fa:23:69:f3:db:e1:5d:9e:a5:92:da:1e:92:50:73: + 1c:aa:a6:51:85:2a:a3:ca:0c:05:d9:95:72:d7:51:b5:8b:02: + eb:6a:f2:ae:c6:2c:bd:cc:96:a0:c5:9d:10:c7:05:a7:d7:10: + 72:4e:05:01:bb:1f:24:c3:b7:6d:2e:a0:83:34:fb:5d:1e:84: + 95:b4:2e:53:b6:50:3f:d2:43:96:44:91:8c:01:e5:f3:64:e2: + b2:4c:82:d9:9b:f8:d8:9b:a3:6c:76:65:d3:1d:d2:b6:a4:bf: + 90:86:0a:37:f5:dd:a3:1c:ce:78:18:71:b3:63:9d:41:63:88: + ba:1f:1d:31:1c:a1:89:83:55:3a:a3:77:85:b8:d7:a1:63:d9: + 1f:96:7d:d6:a5:3c:47:2d:a1:10:a3:0c:c8:bb:75:92:f1:40: + e6:3b:fc:87:ac:65:b5:c9:91:53:2d:6c:ff:25:ad:56:fe:f5: + 79:7b:96:3d +-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIBATANBgkqhkiG9w0BAQsFADB4MQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxCzAJBgNVBAcMAk5KMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNV +BAsMA0RldjEPMA0GA1UEAwwGUm9vdENBMR8wHQYJKoZIhvcNAQkBFhByb290Y2FA +d29ybGQuY29tMB4XDTIyMTAxMjA5NDIwOFoXDTMyMTAwOTA5NDIwOFowbzELMAkG +A1UEBhMCQ04xCzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsM +A0RldjERMA8GA1UEAwwIU2Vjb25kQ0ExITAfBgkqhkiG9w0BCQEWEnNlY29uZGNh +QHdvcmxkLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANcpo7lG +5ttiS8yJtzxGVlO7lA0RqODL1xO84ES+jE3CPLCVo83l95E+ZHQZZWLyAr/DrGCm +rP93+ODINGkMvLZlEV+vYcxzMZeRwctITygaSaxV6bON9m2F3k1iQZS/gKkL9L1i +/F3nkOAEuCYMgwHfPy+/E97H9ECRpwt2mzld40UQZj1nQY1nc8l5F10p5GSiNp5w +98N/Fnk9ZCESUKtPviK541ljvs+tvgSvuKtJJIbdhbn+xW5O9hQE79u9iumo1zBS +jZKMMMdRqTIaqt/0Ofqu39g/OsfeLCfGr+NF9suTaeJEMFPMzOhSdgy9zf+fykVJ +GuXxfedLzbgP06kCAwEAAaNTMFEwHQYDVR0OBBYEFMeDrbiKqXhHPcdp/CNpS9Vn +/MTfMB8GA1UdIwQYMBaAFFcYHLLk3hRuI/Q1xMpamPbIOwnfMA8GA1UdEwEB/wQF +MAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGDOfMIV1AC5bdbpnhSDhpzkBjh3OXqj +PUplhCm+VSFUcSE5vTe4EnUsCmpIAYkAA/2fnr6Kb9GFJeuoWlWpo4vRW9g8qvRD +B7AoV/ojafPb4V2epZLaHpJQcxyqplGFKqPKDAXZlXLXUbWLAutq8q7GLL3MlqDF +nRDHBafXEHJOBQG7HyTDt20uoIM0+10ehJW0LlO2UD/SQ5ZEkYwB5fNk4rJMgtmb ++Nibo2x2ZdMd0rakv5CGCjf13aMczngYcbNjnUFjiLofHTEcoYmDVTqjd4W416Fj +2R+WfdalPEctoRCjDMi7dZLxQOY7/IesZbXJkVMtbP8lrVb+9Xl7lj0= +-----END CERTIFICATE----- diff --git a/test/opensslcrt/multiLevelCert/demoCA/newcerts/02.pem b/test/opensslcrt/multiLevelCert/demoCA/newcerts/02.pem new file mode 100644 index 0000000000000000000000000000000000000000..070e5de126289aeb9c0ca5c37e3ca0495fe172eb --- /dev/null +++ b/test/opensslcrt/multiLevelCert/demoCA/newcerts/02.pem @@ -0,0 +1,82 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 2 (0x2) + Signature Algorithm: sha256WithRSAEncryption + Issuer: C=CN, ST=JS, O=Huawei, OU=Dev, CN=SecondCA/emailAddress=secondca@world.com + Validity + Not Before: Oct 12 09:44:13 2022 GMT + Not After : Oct 9 09:44:13 2032 GMT + Subject: C=CN, ST=JS, O=Huawei, OU=Dev, CN=testsert.com/emailAddress=test@world.com + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + RSA Public-Key: (2048 bit) + Modulus: + 00:a7:db:a3:51:d1:f0:7c:6f:de:a3:2a:34:68:75: + b9:f3:ef:f8:76:f6:50:0f:2e:4f:3c:8c:3e:fd:60: + 1b:ed:f7:d5:7d:ac:41:4a:9b:bf:5e:63:12:e4:cb: + 72:3c:68:39:8b:82:3a:68:85:40:f1:c9:a6:3f:88: + ef:97:0f:46:b2:2e:b1:c3:a3:f0:15:27:58:99:69: + 7a:36:8f:68:80:92:d9:73:b0:cc:c1:b9:01:ee:6b: + 52:43:e8:2f:25:5c:8e:19:6e:e9:8d:c0:e2:41:58: + 68:ab:c7:f3:ee:7c:5f:02:35:14:cf:0e:6e:cc:1e: + f9:75:a7:91:92:23:e0:2e:4e:0a:16:08:8f:bd:59: + 87:62:12:db:3a:4f:42:39:67:62:40:00:dc:52:66: + 2b:b2:ca:a1:c3:2e:6e:36:4e:4a:58:6f:23:0c:db: + 04:8f:74:eb:d2:20:10:e3:ab:83:5c:4a:ae:59:b5: + 82:36:e6:7c:61:65:5c:c6:22:f7:32:c4:9b:44:8e: + f1:d8:54:e5:c5:8b:63:c4:10:29:89:1d:b8:41:7d: + 86:ea:f0:7b:e7:ec:a4:64:6f:4c:f8:5b:4f:b2:11: + f3:25:b7:c7:0e:e4:6a:1e:d3:50:3b:39:6f:78:cc: + e8:c7:bd:53:43:3b:7e:b6:f6:de:ce:f0:4f:7d:69: + da:63 + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + Netscape Comment: + OpenSSL Generated Certificate + X509v3 Subject Key Identifier: + 2E:DE:C4:D7:37:06:30:C1:39:C8:81:CE:8E:6A:55:8F:9D:A8:F2:88 + X509v3 Authority Key Identifier: + keyid:C7:83:AD:B8:8A:A9:78:47:3D:C7:69:FC:23:69:4B:D5:67:FC:C4:DF + + Signature Algorithm: sha256WithRSAEncryption + 6a:d9:7c:1b:2e:c3:bc:70:89:9d:d5:d3:ae:d3:95:7a:53:4c: + d6:fb:f4:0c:f5:68:cf:f7:88:55:50:98:b2:43:ec:9e:c4:73: + 4b:29:cb:cf:c7:ed:86:e3:89:f8:fe:a5:2b:f3:4b:6f:80:7a: + f5:fe:89:3d:c9:0c:da:2f:e0:8a:43:1f:e1:2d:8b:38:e4:81: + 84:94:b7:2c:3d:4d:ad:57:44:19:14:e9:5c:52:09:ae:99:93: + 83:8e:e4:d4:67:31:f9:a1:93:1c:56:e3:e6:d4:88:80:57:b4: + c5:86:7e:d9:d2:ae:08:fa:db:62:71:99:cf:2e:c7:56:f9:b9: + 0f:41:0d:99:ac:9a:1d:99:b0:5b:f3:9b:d6:6f:dd:aa:8b:d9: + 16:52:50:4f:68:bc:7b:9e:db:cf:ec:53:0e:59:06:0f:c9:6a: + 50:13:68:a9:57:7b:c8:52:a5:d4:c6:65:c1:8d:a9:c7:8c:ec: + 48:8c:c5:04:0e:94:69:26:4a:e9:3d:4a:f3:9e:3f:f5:fe:36: + 29:51:04:7b:a2:11:2c:45:26:cd:cc:a8:09:cc:a2:20:9b:64: + 2f:4f:d0:bb:70:ae:34:6e:ca:7d:79:7d:d5:48:46:14:7f:52: + 47:57:92:e7:98:65:37:d3:bd:c6:50:16:a6:17:e3:8e:97:15: + b2:4f:8e:c0 +-----BEGIN CERTIFICATE----- +MIID1DCCArygAwIBAgIBAjANBgkqhkiG9w0BAQsFADBvMQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MREwDwYD +VQQDDAhTZWNvbmRDQTEhMB8GCSqGSIb3DQEJARYSc2Vjb25kY2FAd29ybGQuY29t +MB4XDTIyMTAxMjA5NDQxM1oXDTMyMTAwOTA5NDQxM1owbzELMAkGA1UEBhMCQ04x +CzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsMA0RldjEVMBMG +A1UEAwwMdGVzdHNlcnQuY29tMR0wGwYJKoZIhvcNAQkBFg50ZXN0QHdvcmxkLmNv +bTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKfbo1HR8Hxv3qMqNGh1 +ufPv+Hb2UA8uTzyMPv1gG+331X2sQUqbv15jEuTLcjxoOYuCOmiFQPHJpj+I75cP +RrIuscOj8BUnWJlpejaPaICS2XOwzMG5Ae5rUkPoLyVcjhlu6Y3A4kFYaKvH8+58 +XwI1FM8Obswe+XWnkZIj4C5OChYIj71Zh2IS2zpPQjlnYkAA3FJmK7LKocMubjZO +SlhvIwzbBI9069IgEOOrg1xKrlm1gjbmfGFlXMYi9zLEm0SO8dhU5cWLY8QQKYkd +uEF9hurwe+fspGRvTPhbT7IR8yW3xw7kah7TUDs5b3jM6Me9U0M7frb23s7wT31p +2mMCAwEAAaN7MHkwCQYDVR0TBAIwADAsBglghkgBhvhCAQ0EHxYdT3BlblNTTCBH +ZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0OBBYEFC7exNc3BjDBOciBzo5qVY+d +qPKIMB8GA1UdIwQYMBaAFMeDrbiKqXhHPcdp/CNpS9Vn/MTfMA0GCSqGSIb3DQEB +CwUAA4IBAQBq2XwbLsO8cImd1dOu05V6U0zW+/QM9WjP94hVUJiyQ+yexHNLKcvP +x+2G44n4/qUr80tvgHr1/ok9yQzaL+CKQx/hLYs45IGElLcsPU2tV0QZFOlcUgmu +mZODjuTUZzH5oZMcVuPm1IiAV7TFhn7Z0q4I+tticZnPLsdW+bkPQQ2ZrJodmbBb +85vWb92qi9kWUlBPaLx7ntvP7FMOWQYPyWpQE2ipV3vIUqXUxmXBjanHjOxIjMUE +DpRpJkrpPUrznj/1/jYpUQR7ohEsRSbNzKgJzKIgm2QvT9C7cK40bsp9eX3VSEYU +f1JHV5LnmGU3073GUBamF+OOlxWyT47A +-----END CERTIFICATE----- diff --git a/test/opensslcrt/multiLevelCert/demoCA/newcerts/03.pem b/test/opensslcrt/multiLevelCert/demoCA/newcerts/03.pem new file mode 100644 index 0000000000000000000000000000000000000000..29106673ef0a58abba477f078d16d0e6bc23da38 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/demoCA/newcerts/03.pem @@ -0,0 +1,82 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 3 (0x3) + Signature Algorithm: sha256WithRSAEncryption + Issuer: C=CN, ST=JS, O=Huawei, OU=Dev, CN=SecondCA/emailAddress=secondca@world.com + Validity + Not Before: Oct 12 09:46:45 2022 GMT + Not After : Oct 9 09:46:45 2032 GMT + Subject: C=CN, ST=JS, O=Huawei, OU=Dev, CN=testclient.com/emailAddress=test@world.com + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + RSA Public-Key: (2048 bit) + Modulus: + 00:e5:eb:c7:fd:84:20:0c:7a:f2:12:df:fa:39:5d: + 40:2d:a6:cd:48:81:a0:66:32:de:d1:18:e8:ba:8e: + bc:e2:f3:e0:44:e0:fa:a8:b9:df:0c:a6:4e:da:59: + 5f:92:ea:56:3a:dc:90:26:f9:36:f6:08:a2:35:ad: + b4:82:d9:c4:8a:78:6b:eb:17:a6:27:c3:df:63:15: + e8:4d:0d:29:0c:22:50:43:e6:e8:ce:e2:6d:8d:0e: + 83:e7:e2:1d:10:27:be:ad:b7:9c:c4:48:b1:20:c2: + 62:50:85:8b:ec:fb:b9:c0:3a:ea:1c:a2:3d:e5:62: + 6b:dd:70:e2:97:ce:be:8e:d5:ee:30:1e:94:e0:3a: + 37:44:fe:ca:5b:93:05:6a:8c:a7:f6:6a:d0:20:a5: + 2b:d9:0e:ac:b1:0e:04:51:06:98:05:ae:4a:c9:47: + c7:bf:6d:47:b2:92:bd:c7:54:fb:b8:b5:86:67:b3: + de:53:f0:d5:7c:94:1d:2e:4f:58:3a:57:72:21:5c: + 2b:b4:00:b8:92:56:bf:29:25:56:47:d9:4c:a5:29: + b7:90:ba:bc:e1:9d:43:0a:4c:b4:c2:78:b3:43:47: + 86:f8:66:d2:23:c8:49:00:59:f7:6f:07:c1:9a:ef: + 7d:44:38:ad:97:3e:46:61:8f:ff:98:3e:73:78:35: + e9:1f + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + Netscape Comment: + OpenSSL Generated Certificate + X509v3 Subject Key Identifier: + A6:28:FE:B3:E2:16:72:B4:DE:71:8A:10:BE:79:5C:90:9A:31:7C:0B + X509v3 Authority Key Identifier: + keyid:C7:83:AD:B8:8A:A9:78:47:3D:C7:69:FC:23:69:4B:D5:67:FC:C4:DF + + Signature Algorithm: sha256WithRSAEncryption + 6b:41:eb:e5:91:e3:26:6d:f0:3b:98:27:b4:90:e2:0b:04:04: + 0a:3a:04:40:bf:95:7e:8d:bf:80:9e:49:12:81:cf:c4:43:1a: + e3:79:cc:07:c9:b4:2f:64:f5:cd:a5:d3:7b:40:61:04:27:c2: + 46:22:aa:c1:10:a0:de:a8:8f:3d:09:72:49:90:6f:5e:32:dd: + d8:fb:94:29:60:98:c0:81:cb:a0:4f:8e:40:cb:f3:97:73:ff: + 83:45:47:62:b8:57:12:7c:a3:06:1a:36:30:42:40:5a:75:a1: + 74:6a:fb:07:25:b9:6b:1e:7a:38:68:24:82:5d:e3:d1:f5:e2: + 09:0d:71:58:36:8f:57:a9:71:17:4a:e2:46:4b:de:0f:db:ea: + e6:f5:d4:d5:2f:b7:ed:91:7e:b4:10:50:2a:78:dc:5a:4a:f3: + a7:cb:77:57:44:9d:e7:32:14:28:05:3b:21:fe:95:aa:cf:50: + f6:2d:ed:fb:59:5a:42:ca:8a:1f:d3:f6:43:b5:50:79:9a:5d: + be:2e:67:45:59:83:9f:ad:e9:4b:7f:8d:e0:f7:77:c4:27:1c: + c7:ed:2f:f4:ca:39:1f:ad:8a:1c:25:11:99:96:4f:23:9b:6a: + 11:e6:cf:44:8c:59:07:d4:21:6d:e1:7c:d0:ee:43:33:f2:6c: + a9:f2:70:2d +-----BEGIN CERTIFICATE----- +MIID1jCCAr6gAwIBAgIBAzANBgkqhkiG9w0BAQsFADBvMQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MREwDwYD +VQQDDAhTZWNvbmRDQTEhMB8GCSqGSIb3DQEJARYSc2Vjb25kY2FAd29ybGQuY29t +MB4XDTIyMTAxMjA5NDY0NVoXDTMyMTAwOTA5NDY0NVowcTELMAkGA1UEBhMCQ04x +CzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsMA0RldjEXMBUG +A1UEAwwOdGVzdGNsaWVudC5jb20xHTAbBgkqhkiG9w0BCQEWDnRlc3RAd29ybGQu +Y29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA5evH/YQgDHryEt/6 +OV1ALabNSIGgZjLe0Rjouo684vPgROD6qLnfDKZO2llfkupWOtyQJvk29giiNa20 +gtnEinhr6xemJ8PfYxXoTQ0pDCJQQ+bozuJtjQ6D5+IdECe+rbecxEixIMJiUIWL +7Pu5wDrqHKI95WJr3XDil86+jtXuMB6U4Do3RP7KW5MFaoyn9mrQIKUr2Q6ssQ4E +UQaYBa5KyUfHv21HspK9x1T7uLWGZ7PeU/DVfJQdLk9YOldyIVwrtAC4kla/KSVW +R9lMpSm3kLq84Z1DCky0wnizQ0eG+GbSI8hJAFn3bwfBmu99RDitlz5GYY//mD5z +eDXpHwIDAQABo3sweTAJBgNVHRMEAjAAMCwGCWCGSAGG+EIBDQQfFh1PcGVuU1NM +IEdlbmVyYXRlZCBDZXJ0aWZpY2F0ZTAdBgNVHQ4EFgQUpij+s+IWcrTecYoQvnlc +kJoxfAswHwYDVR0jBBgwFoAUx4OtuIqpeEc9x2n8I2lL1Wf8xN8wDQYJKoZIhvcN +AQELBQADggEBAGtB6+WR4yZt8DuYJ7SQ4gsEBAo6BEC/lX6Nv4CeSRKBz8RDGuN5 +zAfJtC9k9c2l03tAYQQnwkYiqsEQoN6ojz0JckmQb14y3dj7lClgmMCBy6BPjkDL +85dz/4NFR2K4VxJ8owYaNjBCQFp1oXRq+wcluWseejhoJIJd49H14gkNcVg2j1ep +cRdK4kZL3g/b6ub11NUvt+2RfrQQUCp43FpK86fLd1dEnecyFCgFOyH+larPUPYt +7ftZWkLKih/T9kO1UHmaXb4uZ0VZg5+t6Ut/jeD3d8QnHMftL/TKOR+tihwlEZmW +TyObahHmz0SMWQfUIW3hfNDuQzPybKnycC0= +-----END CERTIFICATE----- diff --git a/test/opensslcrt/multiLevelCert/demoCA/serial b/test/opensslcrt/multiLevelCert/demoCA/serial new file mode 100644 index 0000000000000000000000000000000000000000..64969239d5f72d674bbedc24eb0a155a59d0e607 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/demoCA/serial @@ -0,0 +1 @@ +04 diff --git a/test/opensslcrt/multiLevelCert/demoCA/serial.old b/test/opensslcrt/multiLevelCert/demoCA/serial.old new file mode 100644 index 0000000000000000000000000000000000000000..75016ea3625245b1aac79cc5586c3f33ce8b7c78 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/demoCA/serial.old @@ -0,0 +1 @@ +03 diff --git a/test/opensslcrt/multiLevelCert/openssl.cnf b/test/opensslcrt/multiLevelCert/openssl.cnf new file mode 100644 index 0000000000000000000000000000000000000000..fb721d6eb14432876b937e45a6d44ff20d80ba49 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/openssl.cnf @@ -0,0 +1,350 @@ +# +# OpenSSL example configuration file. +# This is mostly being used for generation of certificate requests. +# + +# Note that you can include other files from the main configuration +# file using the .include directive. +#.include filename + +# This definition stops the following lines choking if HOME isn't +# defined. +HOME = . + +# Extra OBJECT IDENTIFIER info: +#oid_file = $ENV::HOME/.oid +oid_section = new_oids + +# To use this configuration file with the "-extfile" option of the +# "openssl x509" utility, name here the section containing the +# X.509v3 extensions to use: +# extensions = +# (Alternatively, use a configuration file that has only +# X.509v3 extensions in its main [= default] section.) + +[ new_oids ] + +# We can add new OIDs in here for use by 'ca', 'req' and 'ts'. +# Add a simple OID like this: +# testoid1=1.2.3.4 +# Or use config file substitution like this: +# testoid2=${testoid1}.5.6 + +# Policies used by the TSA examples. +tsa_policy1 = 1.2.3.4.1 +tsa_policy2 = 1.2.3.4.5.6 +tsa_policy3 = 1.2.3.4.5.7 + +#################################################################### +[ ca ] +default_ca = CA_default # The default ca section + +#################################################################### +[ CA_default ] + +dir = ./CA # Where everything is kept +certs = $dir/certs # Where the issued certs are kept +crl_dir = $dir/crl # Where the issued crl are kept +database = $dir/index.txt # database index file. +#unique_subject = no # Set to 'no' to allow creation of + # several certs with same subject. +new_certs_dir = $dir/newcerts # default place for new certs. + +certificate = $dir/cacert.pem # The CA certificate +serial = $dir/serial # The current serial number +crlnumber = $dir/crlnumber # the current crl number + # must be commented out to leave a V1 CRL +crl = $dir/crl.pem # The current CRL +private_key = $dir/private/cakey.pem# The private key + +x509_extensions = usr_cert # The extensions to add to the cert + +# Comment out the following two lines for the "traditional" +# (and highly broken) format. +name_opt = ca_default # Subject Name options +cert_opt = ca_default # Certificate field options + +# Extension copying option: use with caution. +# copy_extensions = copy + +# Extensions to add to a CRL. Note: Netscape communicator chokes on V2 CRLs +# so this is commented out by default to leave a V1 CRL. +# crlnumber must also be commented out to leave a V1 CRL. +# crl_extensions = crl_ext + +default_days = 365 # how long to certify for +default_crl_days= 30 # how long before next CRL +default_md = default # use public key default MD +preserve = no # keep passed DN ordering + +# A few difference way of specifying how similar the request should look +# For type CA, the listed attributes must be the same, and the optional +# and supplied fields are just that :-) +policy = policy_match + +# For the CA policy +[ policy_match ] +countryName = match +stateOrProvinceName = match +organizationName = match +organizationalUnitName = optional +commonName = supplied +emailAddress = optional + +# For the 'anything' policy +# At this point in time, you must list all acceptable 'object' +# types. +[ policy_anything ] +countryName = optional +stateOrProvinceName = optional +localityName = optional +organizationName = optional +organizationalUnitName = optional +commonName = supplied +emailAddress = optional + +#################################################################### +[ req ] +default_bits = 2048 +default_keyfile = privkey.pem +distinguished_name = req_distinguished_name +attributes = req_attributes +x509_extensions = v3_ca # The extensions to add to the self signed cert + +# Passwords for private keys if not present they will be prompted for +# input_password = secret +# output_password = secret + +# This sets a mask for permitted string types. There are several options. +# default: PrintableString, T61String, BMPString. +# pkix : PrintableString, BMPString (PKIX recommendation before 2004) +# utf8only: only UTF8Strings (PKIX recommendation after 2004). +# nombstr : PrintableString, T61String (no BMPStrings or UTF8Strings). +# MASK:XXXX a literal mask value. +# WARNING: ancient versions of Netscape crash on BMPStrings or UTF8Strings. +string_mask = utf8only + +# req_extensions = v3_req # The extensions to add to a certificate request + +[ req_distinguished_name ] +countryName = Country Name (2 letter code) +countryName_default = AU +countryName_min = 2 +countryName_max = 2 + +stateOrProvinceName = State or Province Name (full name) +stateOrProvinceName_default = Some-State + +localityName = Locality Name (eg, city) + +0.organizationName = Organization Name (eg, company) +0.organizationName_default = Internet Widgits Pty Ltd + +# we can do this but it is not needed normally :-) +#1.organizationName = Second Organization Name (eg, company) +#1.organizationName_default = World Wide Web Pty Ltd + +organizationalUnitName = Organizational Unit Name (eg, section) +#organizationalUnitName_default = + +commonName = Common Name (e.g. server FQDN or YOUR name) +commonName_max = 64 + +emailAddress = Email Address +emailAddress_max = 64 + +# SET-ex3 = SET extension number 3 + +[ req_attributes ] +challengePassword = A challenge password +challengePassword_min = 4 +challengePassword_max = 20 + +unstructuredName = An optional company name + +[ usr_cert ] + +# These extensions are added when 'ca' signs a request. + +# This goes against PKIX guidelines but some CAs do it and some software +# requires this to avoid interpreting an end user certificate as a CA. + +basicConstraints=CA:FALSE + +# Here are some examples of the usage of nsCertType. If it is omitted +# the certificate can be used for anything *except* object signing. + +# This is OK for an SSL server. +# nsCertType = server + +# For an object signing certificate this would be used. +# nsCertType = objsign + +# For normal client use this is typical +# nsCertType = client, email + +# and for everything including object signing: +# nsCertType = client, email, objsign + +# This is typical in keyUsage for a client certificate. +# keyUsage = nonRepudiation, digitalSignature, keyEncipherment + +# This will be displayed in Netscape's comment listbox. +nsComment = "OpenSSL Generated Certificate" + +# PKIX recommendations harmless if included in all certificates. +subjectKeyIdentifier=hash +authorityKeyIdentifier=keyid,issuer + +# This stuff is for subjectAltName and issuerAltname. +# Import the email address. +# subjectAltName=email:copy +# An alternative to produce certificates that aren't +# deprecated according to PKIX. +# subjectAltName=email:move + +# Copy subject details +# issuerAltName=issuer:copy + +#nsCaRevocationUrl = http://www.domain.dom/ca-crl.pem +#nsBaseUrl +#nsRevocationUrl +#nsRenewalUrl +#nsCaPolicyUrl +#nsSslServerName + +# This is required for TSA certificates. +# extendedKeyUsage = critical,timeStamping + +[ v3_req ] + +# Extensions to add to a certificate request + +basicConstraints = CA:FALSE +keyUsage = nonRepudiation, digitalSignature, keyEncipherment + +[ v3_ca ] + + +# Extensions for a typical CA + + +# PKIX recommendation. + +subjectKeyIdentifier=hash + +authorityKeyIdentifier=keyid:always,issuer + +basicConstraints = critical,CA:true + +# Key usage: this is typical for a CA certificate. However since it will +# prevent it being used as an test self-signed certificate it is best +# left out by default. +# keyUsage = cRLSign, keyCertSign + +# Some might want this also +# nsCertType = sslCA, emailCA + +# Include email address in subject alt name: another PKIX recommendation +# subjectAltName=email:copy +# Copy issuer details +# issuerAltName=issuer:copy + +# DER hex encoding of an extension: beware experts only! +# obj=DER:02:03 +# Where 'obj' is a standard or added object +# You can even override a supported extension: +# basicConstraints= critical, DER:30:03:01:01:FF + +[ crl_ext ] + +# CRL extensions. +# Only issuerAltName and authorityKeyIdentifier make any sense in a CRL. + +# issuerAltName=issuer:copy +authorityKeyIdentifier=keyid:always + +[ proxy_cert_ext ] +# These extensions should be added when creating a proxy certificate + +# This goes against PKIX guidelines but some CAs do it and some software +# requires this to avoid interpreting an end user certificate as a CA. + +basicConstraints=CA:FALSE + +# Here are some examples of the usage of nsCertType. If it is omitted +# the certificate can be used for anything *except* object signing. + +# This is OK for an SSL server. +# nsCertType = server + +# For an object signing certificate this would be used. +# nsCertType = objsign + +# For normal client use this is typical +# nsCertType = client, email + +# and for everything including object signing: +# nsCertType = client, email, objsign + +# This is typical in keyUsage for a client certificate. +# keyUsage = nonRepudiation, digitalSignature, keyEncipherment + +# This will be displayed in Netscape's comment listbox. +nsComment = "OpenSSL Generated Certificate" + +# PKIX recommendations harmless if included in all certificates. +subjectKeyIdentifier=hash +authorityKeyIdentifier=keyid,issuer + +# This stuff is for subjectAltName and issuerAltname. +# Import the email address. +# subjectAltName=email:copy +# An alternative to produce certificates that aren't +# deprecated according to PKIX. +# subjectAltName=email:move + +# Copy subject details +# issuerAltName=issuer:copy + +#nsCaRevocationUrl = http://www.domain.dom/ca-crl.pem +#nsBaseUrl +#nsRevocationUrl +#nsRenewalUrl +#nsCaPolicyUrl +#nsSslServerName + +# This really needs to be in place for it to be a proxy certificate. +proxyCertInfo=critical,language:id-ppl-anyLanguage,pathlen:3,policy:foo + +#################################################################### +[ tsa ] + +default_tsa = tsa_config1 # the default TSA section + +[ tsa_config1 ] + +# These are used by the TSA reply generation only. +dir = ./demoCA # TSA root directory +serial = $dir/tsaserial # The current serial number (mandatory) +crypto_device = builtin # OpenSSL engine to use for signing +signer_cert = $dir/tsacert.pem # The TSA signing certificate + # (optional) +certs = $dir/cacert.pem # Certificate chain to include in reply + # (optional) +signer_key = $dir/private/tsakey.pem # The TSA private key (optional) +signer_digest = sha256 # Signing digest to use. (Optional) +default_policy = tsa_policy1 # Policy if request did not specify it + # (optional) +other_policies = tsa_policy2, tsa_policy3 # acceptable policies (optional) +digests = sha1, sha256, sha384, sha512 # Acceptable message digests (mandatory) +accuracy = secs:1, millisecs:500, microsecs:100 # (optional) +clock_precision_digits = 0 # number of digits after dot. (optional) +ordering = yes # Is ordering defined for timestamps? + # (optional, default: no) +tsa_name = yes # Must the TSA name be included in the reply? + # (optional, default: no) +ess_cert_id_chain = no # Must the ESS cert id chain be included? + # (optional, default: no) +ess_cert_id_alg = sha1 # algorithm to compute certificate + # identifier (optional, default: sha1) diff --git a/test/opensslcrt/multiLevelCert/server/cert.pem b/test/opensslcrt/multiLevelCert/server/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..070e5de126289aeb9c0ca5c37e3ca0495fe172eb --- /dev/null +++ b/test/opensslcrt/multiLevelCert/server/cert.pem @@ -0,0 +1,82 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 2 (0x2) + Signature Algorithm: sha256WithRSAEncryption + Issuer: C=CN, ST=JS, O=Huawei, OU=Dev, CN=SecondCA/emailAddress=secondca@world.com + Validity + Not Before: Oct 12 09:44:13 2022 GMT + Not After : Oct 9 09:44:13 2032 GMT + Subject: C=CN, ST=JS, O=Huawei, OU=Dev, CN=testsert.com/emailAddress=test@world.com + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + RSA Public-Key: (2048 bit) + Modulus: + 00:a7:db:a3:51:d1:f0:7c:6f:de:a3:2a:34:68:75: + b9:f3:ef:f8:76:f6:50:0f:2e:4f:3c:8c:3e:fd:60: + 1b:ed:f7:d5:7d:ac:41:4a:9b:bf:5e:63:12:e4:cb: + 72:3c:68:39:8b:82:3a:68:85:40:f1:c9:a6:3f:88: + ef:97:0f:46:b2:2e:b1:c3:a3:f0:15:27:58:99:69: + 7a:36:8f:68:80:92:d9:73:b0:cc:c1:b9:01:ee:6b: + 52:43:e8:2f:25:5c:8e:19:6e:e9:8d:c0:e2:41:58: + 68:ab:c7:f3:ee:7c:5f:02:35:14:cf:0e:6e:cc:1e: + f9:75:a7:91:92:23:e0:2e:4e:0a:16:08:8f:bd:59: + 87:62:12:db:3a:4f:42:39:67:62:40:00:dc:52:66: + 2b:b2:ca:a1:c3:2e:6e:36:4e:4a:58:6f:23:0c:db: + 04:8f:74:eb:d2:20:10:e3:ab:83:5c:4a:ae:59:b5: + 82:36:e6:7c:61:65:5c:c6:22:f7:32:c4:9b:44:8e: + f1:d8:54:e5:c5:8b:63:c4:10:29:89:1d:b8:41:7d: + 86:ea:f0:7b:e7:ec:a4:64:6f:4c:f8:5b:4f:b2:11: + f3:25:b7:c7:0e:e4:6a:1e:d3:50:3b:39:6f:78:cc: + e8:c7:bd:53:43:3b:7e:b6:f6:de:ce:f0:4f:7d:69: + da:63 + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + Netscape Comment: + OpenSSL Generated Certificate + X509v3 Subject Key Identifier: + 2E:DE:C4:D7:37:06:30:C1:39:C8:81:CE:8E:6A:55:8F:9D:A8:F2:88 + X509v3 Authority Key Identifier: + keyid:C7:83:AD:B8:8A:A9:78:47:3D:C7:69:FC:23:69:4B:D5:67:FC:C4:DF + + Signature Algorithm: sha256WithRSAEncryption + 6a:d9:7c:1b:2e:c3:bc:70:89:9d:d5:d3:ae:d3:95:7a:53:4c: + d6:fb:f4:0c:f5:68:cf:f7:88:55:50:98:b2:43:ec:9e:c4:73: + 4b:29:cb:cf:c7:ed:86:e3:89:f8:fe:a5:2b:f3:4b:6f:80:7a: + f5:fe:89:3d:c9:0c:da:2f:e0:8a:43:1f:e1:2d:8b:38:e4:81: + 84:94:b7:2c:3d:4d:ad:57:44:19:14:e9:5c:52:09:ae:99:93: + 83:8e:e4:d4:67:31:f9:a1:93:1c:56:e3:e6:d4:88:80:57:b4: + c5:86:7e:d9:d2:ae:08:fa:db:62:71:99:cf:2e:c7:56:f9:b9: + 0f:41:0d:99:ac:9a:1d:99:b0:5b:f3:9b:d6:6f:dd:aa:8b:d9: + 16:52:50:4f:68:bc:7b:9e:db:cf:ec:53:0e:59:06:0f:c9:6a: + 50:13:68:a9:57:7b:c8:52:a5:d4:c6:65:c1:8d:a9:c7:8c:ec: + 48:8c:c5:04:0e:94:69:26:4a:e9:3d:4a:f3:9e:3f:f5:fe:36: + 29:51:04:7b:a2:11:2c:45:26:cd:cc:a8:09:cc:a2:20:9b:64: + 2f:4f:d0:bb:70:ae:34:6e:ca:7d:79:7d:d5:48:46:14:7f:52: + 47:57:92:e7:98:65:37:d3:bd:c6:50:16:a6:17:e3:8e:97:15: + b2:4f:8e:c0 +-----BEGIN CERTIFICATE----- +MIID1DCCArygAwIBAgIBAjANBgkqhkiG9w0BAQsFADBvMQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MREwDwYD +VQQDDAhTZWNvbmRDQTEhMB8GCSqGSIb3DQEJARYSc2Vjb25kY2FAd29ybGQuY29t +MB4XDTIyMTAxMjA5NDQxM1oXDTMyMTAwOTA5NDQxM1owbzELMAkGA1UEBhMCQ04x +CzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsMA0RldjEVMBMG +A1UEAwwMdGVzdHNlcnQuY29tMR0wGwYJKoZIhvcNAQkBFg50ZXN0QHdvcmxkLmNv +bTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKfbo1HR8Hxv3qMqNGh1 +ufPv+Hb2UA8uTzyMPv1gG+331X2sQUqbv15jEuTLcjxoOYuCOmiFQPHJpj+I75cP +RrIuscOj8BUnWJlpejaPaICS2XOwzMG5Ae5rUkPoLyVcjhlu6Y3A4kFYaKvH8+58 +XwI1FM8Obswe+XWnkZIj4C5OChYIj71Zh2IS2zpPQjlnYkAA3FJmK7LKocMubjZO +SlhvIwzbBI9069IgEOOrg1xKrlm1gjbmfGFlXMYi9zLEm0SO8dhU5cWLY8QQKYkd +uEF9hurwe+fspGRvTPhbT7IR8yW3xw7kah7TUDs5b3jM6Me9U0M7frb23s7wT31p +2mMCAwEAAaN7MHkwCQYDVR0TBAIwADAsBglghkgBhvhCAQ0EHxYdT3BlblNTTCBH +ZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0OBBYEFC7exNc3BjDBOciBzo5qVY+d +qPKIMB8GA1UdIwQYMBaAFMeDrbiKqXhHPcdp/CNpS9Vn/MTfMA0GCSqGSIb3DQEB +CwUAA4IBAQBq2XwbLsO8cImd1dOu05V6U0zW+/QM9WjP94hVUJiyQ+yexHNLKcvP +x+2G44n4/qUr80tvgHr1/ok9yQzaL+CKQx/hLYs45IGElLcsPU2tV0QZFOlcUgmu +mZODjuTUZzH5oZMcVuPm1IiAV7TFhn7Z0q4I+tticZnPLsdW+bkPQQ2ZrJodmbBb +85vWb92qi9kWUlBPaLx7ntvP7FMOWQYPyWpQE2ipV3vIUqXUxmXBjanHjOxIjMUE +DpRpJkrpPUrznj/1/jYpUQR7ohEsRSbNzKgJzKIgm2QvT9C7cK40bsp9eX3VSEYU +f1JHV5LnmGU3073GUBamF+OOlxWyT47A +-----END CERTIFICATE----- diff --git a/test/opensslcrt/multiLevelCert/server/key.pem b/test/opensslcrt/multiLevelCert/server/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..71ec9e7e81db54ff9436879ed610102db7e881fa --- /dev/null +++ b/test/opensslcrt/multiLevelCert/server/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAp9ujUdHwfG/eoyo0aHW58+/4dvZQDy5PPIw+/WAb7ffVfaxB +Spu/XmMS5MtyPGg5i4I6aIVA8cmmP4jvlw9Gsi6xw6PwFSdYmWl6No9ogJLZc7DM +wbkB7mtSQ+gvJVyOGW7pjcDiQVhoq8fz7nxfAjUUzw5uzB75daeRkiPgLk4KFgiP +vVmHYhLbOk9COWdiQADcUmYrssqhwy5uNk5KWG8jDNsEj3Tr0iAQ46uDXEquWbWC +NuZ8YWVcxiL3MsSbRI7x2FTlxYtjxBApiR24QX2G6vB75+ykZG9M+FtPshHzJbfH +DuRqHtNQOzlveMzox71TQzt+tvbezvBPfWnaYwIDAQABAoIBAQCPFoHAK5Au40YM +HNwT99cOBI/vCMTyS+2rlXnUj2r/jfZlbMMzkFSvZxEiC/NTXx0+uUKE+qKD+ftH +yblDMfh3x6otNcBgp+u0yt8tR04z2/qVzi6dLNJipQW5cWFPHfjb4VoiRjwYq/5+ +ALMFputufEVCw/Da+8R28OL8iqx9ixNN/bfxktbP5NqGkcEUun5wE+/0KnFzylcV +XyKIzFFun8U6P0oupjv+aoB12SET/X75Xgh3Elx5/Q9FedHRXGjFcam6NJhOG79s +G6IkyciKP7Pf1SYFX6b5ps2bo3oPKWVvJLGMlxSuxJpqmoTQOzochoZF8uak+hkQ +eHl07+TRAoGBANBEVtTyMkRv6dGgmRCPsqolhxsOkHHd5+YOThg5d3EDWRcsUidR +VzwU9/GMrF3grV7FuQk2GYqMRDjFg87+/0O0H6vANxXjNc6sv+dblOikbAJoLksr +245OeCgd383K5+f00pOHvd8KfrxwQPqarpvW2JORXHludgZRgx8FGulfAoGBAM5U +YSKDw+d/D6+PHazakLucZNc7hUUd8Gvh8EYdUxObX+Dr8UIjV/1qqYZzhPUNDpn9 +//cYq8ndAQ7gCTANESJc//ikqDBXHJ3AoMwpNX86ozNUyt6+eoQx2tizgkEMzndo +BVt7A3RjWjn/+bJJX6mIsjg8salQrMW1bOyECHl9AoGBAJwdfhFl87RFR7oxbktx +y/Wq59mqUzBnrPtQYc3a1ePLJK8wM+zxFjkdZraUQmikkJDoGcoD2aV3e3Qq6qDx +mJtBnDP8g85OYPkpmTht9/NjvOsY+Qq0N4I24+7+ZdM3dBr19BtOt09H6LSMWMkB +xj1fET2cyvrjiGk4FNfd1cx1AoGBAL+xdW1zrhbt3czl0lQ93Cnx615sVi0Y273f +dDQwGnck67c0fjlMTPuMlWPs/6IMN3yql50itrgdNFZ1nxOdkEW00bxYfkorJNML +nFkSEDncaLPQG4tGvN0E1KZwYJu/IjOd2Rxc9aC0jadFQt95e/8umSXWfdkostwc +6s3y/UyhAoGAJAG6S/k1Xg0TbVxeZvMlnmfEc7WKh1QCnsH9Va/VnPXlIlTWmHu/ +xBtBb0kC0IF5zWgHhSADz3SHDsYZ87SEqoQxx3BKzd+EU52RF2x74GQDY49mQlUR +zMnxYP3TUSEfkTuSpT8ZwuxQ7KSk5nREDzKIVfcxgFw/cmrW/eChT7w= +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/multiLevelCert/server/server.crt b/test/opensslcrt/multiLevelCert/server/server.crt new file mode 100644 index 0000000000000000000000000000000000000000..070e5de126289aeb9c0ca5c37e3ca0495fe172eb --- /dev/null +++ b/test/opensslcrt/multiLevelCert/server/server.crt @@ -0,0 +1,82 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 2 (0x2) + Signature Algorithm: sha256WithRSAEncryption + Issuer: C=CN, ST=JS, O=Huawei, OU=Dev, CN=SecondCA/emailAddress=secondca@world.com + Validity + Not Before: Oct 12 09:44:13 2022 GMT + Not After : Oct 9 09:44:13 2032 GMT + Subject: C=CN, ST=JS, O=Huawei, OU=Dev, CN=testsert.com/emailAddress=test@world.com + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + RSA Public-Key: (2048 bit) + Modulus: + 00:a7:db:a3:51:d1:f0:7c:6f:de:a3:2a:34:68:75: + b9:f3:ef:f8:76:f6:50:0f:2e:4f:3c:8c:3e:fd:60: + 1b:ed:f7:d5:7d:ac:41:4a:9b:bf:5e:63:12:e4:cb: + 72:3c:68:39:8b:82:3a:68:85:40:f1:c9:a6:3f:88: + ef:97:0f:46:b2:2e:b1:c3:a3:f0:15:27:58:99:69: + 7a:36:8f:68:80:92:d9:73:b0:cc:c1:b9:01:ee:6b: + 52:43:e8:2f:25:5c:8e:19:6e:e9:8d:c0:e2:41:58: + 68:ab:c7:f3:ee:7c:5f:02:35:14:cf:0e:6e:cc:1e: + f9:75:a7:91:92:23:e0:2e:4e:0a:16:08:8f:bd:59: + 87:62:12:db:3a:4f:42:39:67:62:40:00:dc:52:66: + 2b:b2:ca:a1:c3:2e:6e:36:4e:4a:58:6f:23:0c:db: + 04:8f:74:eb:d2:20:10:e3:ab:83:5c:4a:ae:59:b5: + 82:36:e6:7c:61:65:5c:c6:22:f7:32:c4:9b:44:8e: + f1:d8:54:e5:c5:8b:63:c4:10:29:89:1d:b8:41:7d: + 86:ea:f0:7b:e7:ec:a4:64:6f:4c:f8:5b:4f:b2:11: + f3:25:b7:c7:0e:e4:6a:1e:d3:50:3b:39:6f:78:cc: + e8:c7:bd:53:43:3b:7e:b6:f6:de:ce:f0:4f:7d:69: + da:63 + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + Netscape Comment: + OpenSSL Generated Certificate + X509v3 Subject Key Identifier: + 2E:DE:C4:D7:37:06:30:C1:39:C8:81:CE:8E:6A:55:8F:9D:A8:F2:88 + X509v3 Authority Key Identifier: + keyid:C7:83:AD:B8:8A:A9:78:47:3D:C7:69:FC:23:69:4B:D5:67:FC:C4:DF + + Signature Algorithm: sha256WithRSAEncryption + 6a:d9:7c:1b:2e:c3:bc:70:89:9d:d5:d3:ae:d3:95:7a:53:4c: + d6:fb:f4:0c:f5:68:cf:f7:88:55:50:98:b2:43:ec:9e:c4:73: + 4b:29:cb:cf:c7:ed:86:e3:89:f8:fe:a5:2b:f3:4b:6f:80:7a: + f5:fe:89:3d:c9:0c:da:2f:e0:8a:43:1f:e1:2d:8b:38:e4:81: + 84:94:b7:2c:3d:4d:ad:57:44:19:14:e9:5c:52:09:ae:99:93: + 83:8e:e4:d4:67:31:f9:a1:93:1c:56:e3:e6:d4:88:80:57:b4: + c5:86:7e:d9:d2:ae:08:fa:db:62:71:99:cf:2e:c7:56:f9:b9: + 0f:41:0d:99:ac:9a:1d:99:b0:5b:f3:9b:d6:6f:dd:aa:8b:d9: + 16:52:50:4f:68:bc:7b:9e:db:cf:ec:53:0e:59:06:0f:c9:6a: + 50:13:68:a9:57:7b:c8:52:a5:d4:c6:65:c1:8d:a9:c7:8c:ec: + 48:8c:c5:04:0e:94:69:26:4a:e9:3d:4a:f3:9e:3f:f5:fe:36: + 29:51:04:7b:a2:11:2c:45:26:cd:cc:a8:09:cc:a2:20:9b:64: + 2f:4f:d0:bb:70:ae:34:6e:ca:7d:79:7d:d5:48:46:14:7f:52: + 47:57:92:e7:98:65:37:d3:bd:c6:50:16:a6:17:e3:8e:97:15: + b2:4f:8e:c0 +-----BEGIN CERTIFICATE----- +MIID1DCCArygAwIBAgIBAjANBgkqhkiG9w0BAQsFADBvMQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MREwDwYD +VQQDDAhTZWNvbmRDQTEhMB8GCSqGSIb3DQEJARYSc2Vjb25kY2FAd29ybGQuY29t +MB4XDTIyMTAxMjA5NDQxM1oXDTMyMTAwOTA5NDQxM1owbzELMAkGA1UEBhMCQ04x +CzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsMA0RldjEVMBMG +A1UEAwwMdGVzdHNlcnQuY29tMR0wGwYJKoZIhvcNAQkBFg50ZXN0QHdvcmxkLmNv +bTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKfbo1HR8Hxv3qMqNGh1 +ufPv+Hb2UA8uTzyMPv1gG+331X2sQUqbv15jEuTLcjxoOYuCOmiFQPHJpj+I75cP +RrIuscOj8BUnWJlpejaPaICS2XOwzMG5Ae5rUkPoLyVcjhlu6Y3A4kFYaKvH8+58 +XwI1FM8Obswe+XWnkZIj4C5OChYIj71Zh2IS2zpPQjlnYkAA3FJmK7LKocMubjZO +SlhvIwzbBI9069IgEOOrg1xKrlm1gjbmfGFlXMYi9zLEm0SO8dhU5cWLY8QQKYkd +uEF9hurwe+fspGRvTPhbT7IR8yW3xw7kah7TUDs5b3jM6Me9U0M7frb23s7wT31p +2mMCAwEAAaN7MHkwCQYDVR0TBAIwADAsBglghkgBhvhCAQ0EHxYdT3BlblNTTCBH +ZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0OBBYEFC7exNc3BjDBOciBzo5qVY+d +qPKIMB8GA1UdIwQYMBaAFMeDrbiKqXhHPcdp/CNpS9Vn/MTfMA0GCSqGSIb3DQEB +CwUAA4IBAQBq2XwbLsO8cImd1dOu05V6U0zW+/QM9WjP94hVUJiyQ+yexHNLKcvP +x+2G44n4/qUr80tvgHr1/ok9yQzaL+CKQx/hLYs45IGElLcsPU2tV0QZFOlcUgmu +mZODjuTUZzH5oZMcVuPm1IiAV7TFhn7Z0q4I+tticZnPLsdW+bkPQQ2ZrJodmbBb +85vWb92qi9kWUlBPaLx7ntvP7FMOWQYPyWpQE2ipV3vIUqXUxmXBjanHjOxIjMUE +DpRpJkrpPUrznj/1/jYpUQR7ohEsRSbNzKgJzKIgm2QvT9C7cK40bsp9eX3VSEYU +f1JHV5LnmGU3073GUBamF+OOlxWyT47A +-----END CERTIFICATE----- diff --git a/test/opensslcrt/multiLevelCert/server/server.csr b/test/opensslcrt/multiLevelCert/server/server.csr new file mode 100644 index 0000000000000000000000000000000000000000..64037cc9b7f13aa23a7aec95eafe0b396a33d471 --- /dev/null +++ b/test/opensslcrt/multiLevelCert/server/server.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICwTCCAakCAQAwfDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQH +DAJOSjEPMA0GA1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxFTATBgNVBAMMDHRl +c3RzZXJ0LmNvbTEdMBsGCSqGSIb3DQEJARYOdGVzdEB3b3JsZC5jb20wggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCn26NR0fB8b96jKjRodbnz7/h29lAP +Lk88jD79YBvt99V9rEFKm79eYxLky3I8aDmLgjpohUDxyaY/iO+XD0ayLrHDo/AV +J1iZaXo2j2iAktlzsMzBuQHua1JD6C8lXI4ZbumNwOJBWGirx/PufF8CNRTPDm7M +Hvl1p5GSI+AuTgoWCI+9WYdiEts6T0I5Z2JAANxSZiuyyqHDLm42TkpYbyMM2wSP +dOvSIBDjq4NcSq5ZtYI25nxhZVzGIvcyxJtEjvHYVOXFi2PEECmJHbhBfYbq8Hvn +7KRkb0z4W0+yEfMlt8cO5Goe01A7OW94zOjHvVNDO3629t7O8E99adpjAgMBAAGg +ADANBgkqhkiG9w0BAQsFAAOCAQEAC2AVrLOyTNhAdoVMzdqlXNLBmuoKSJdJdePF +uM3jkNkgfV77opTDFVL2nYxTLddfUpYq8xMpqK2shXWz5nrjn+XbqVqDyP5F6oVl +Rp0EiTKPolvr6+qREnquF7AKRn6qZkSst3/QbdFJrIZ6FjfReFxR+8d+MkhdKcUL +hX0FD8/njwO6twXWqBADZrV8rCsfuIER8+nCVCo827J7ZPNtvli31aFEi1QXo1Em +9Azvn6EULZyLUdvgu5hANyXNRa0yTY+QGZ37lTHTAuogr7PwCW2PTr5AVLU3oaDA +KJT1hWaJgpSKRI9nFut7BVaGjrRQHNH+HTN12mduJFaIIDt2Qg== +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/multiLevelCert/server/server.key b/test/opensslcrt/multiLevelCert/server/server.key new file mode 100644 index 0000000000000000000000000000000000000000..71ec9e7e81db54ff9436879ed610102db7e881fa --- /dev/null +++ b/test/opensslcrt/multiLevelCert/server/server.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAp9ujUdHwfG/eoyo0aHW58+/4dvZQDy5PPIw+/WAb7ffVfaxB +Spu/XmMS5MtyPGg5i4I6aIVA8cmmP4jvlw9Gsi6xw6PwFSdYmWl6No9ogJLZc7DM +wbkB7mtSQ+gvJVyOGW7pjcDiQVhoq8fz7nxfAjUUzw5uzB75daeRkiPgLk4KFgiP +vVmHYhLbOk9COWdiQADcUmYrssqhwy5uNk5KWG8jDNsEj3Tr0iAQ46uDXEquWbWC +NuZ8YWVcxiL3MsSbRI7x2FTlxYtjxBApiR24QX2G6vB75+ykZG9M+FtPshHzJbfH +DuRqHtNQOzlveMzox71TQzt+tvbezvBPfWnaYwIDAQABAoIBAQCPFoHAK5Au40YM +HNwT99cOBI/vCMTyS+2rlXnUj2r/jfZlbMMzkFSvZxEiC/NTXx0+uUKE+qKD+ftH +yblDMfh3x6otNcBgp+u0yt8tR04z2/qVzi6dLNJipQW5cWFPHfjb4VoiRjwYq/5+ +ALMFputufEVCw/Da+8R28OL8iqx9ixNN/bfxktbP5NqGkcEUun5wE+/0KnFzylcV +XyKIzFFun8U6P0oupjv+aoB12SET/X75Xgh3Elx5/Q9FedHRXGjFcam6NJhOG79s +G6IkyciKP7Pf1SYFX6b5ps2bo3oPKWVvJLGMlxSuxJpqmoTQOzochoZF8uak+hkQ +eHl07+TRAoGBANBEVtTyMkRv6dGgmRCPsqolhxsOkHHd5+YOThg5d3EDWRcsUidR +VzwU9/GMrF3grV7FuQk2GYqMRDjFg87+/0O0H6vANxXjNc6sv+dblOikbAJoLksr +245OeCgd383K5+f00pOHvd8KfrxwQPqarpvW2JORXHludgZRgx8FGulfAoGBAM5U +YSKDw+d/D6+PHazakLucZNc7hUUd8Gvh8EYdUxObX+Dr8UIjV/1qqYZzhPUNDpn9 +//cYq8ndAQ7gCTANESJc//ikqDBXHJ3AoMwpNX86ozNUyt6+eoQx2tizgkEMzndo +BVt7A3RjWjn/+bJJX6mIsjg8salQrMW1bOyECHl9AoGBAJwdfhFl87RFR7oxbktx +y/Wq59mqUzBnrPtQYc3a1ePLJK8wM+zxFjkdZraUQmikkJDoGcoD2aV3e3Qq6qDx +mJtBnDP8g85OYPkpmTht9/NjvOsY+Qq0N4I24+7+ZdM3dBr19BtOt09H6LSMWMkB +xj1fET2cyvrjiGk4FNfd1cx1AoGBAL+xdW1zrhbt3czl0lQ93Cnx615sVi0Y273f +dDQwGnck67c0fjlMTPuMlWPs/6IMN3yql50itrgdNFZ1nxOdkEW00bxYfkorJNML +nFkSEDncaLPQG4tGvN0E1KZwYJu/IjOd2Rxc9aC0jadFQt95e/8umSXWfdkostwc +6s3y/UyhAoGAJAG6S/k1Xg0TbVxeZvMlnmfEc7WKh1QCnsH9Va/VnPXlIlTWmHu/ +xBtBb0kC0IF5zWgHhSADz3SHDsYZ87SEqoQxx3BKzd+EU52RF2x74GQDY49mQlUR +zMnxYP3TUSEfkTuSpT8ZwuxQ7KSk5nREDzKIVfcxgFw/cmrW/eChT7w= +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/normalCert1/CA/ca.csr b/test/opensslcrt/normalCert1/CA/ca.csr new file mode 100644 index 0000000000000000000000000000000000000000..d1daee988dbb89be209fd79b626f0e491ec4886e --- /dev/null +++ b/test/opensslcrt/normalCert1/CA/ca.csr @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBsTCCARoCAQAwcTELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkdEMQswCQYDVQQH +DAJTWjEMMAoGA1UECgwDQ09NMQwwCgYDVQQLDANOU1AxCzAJBgNVBAMMAkNBMR8w +HQYJKoZIhvcNAQkBFhB5b3VyZW1haWxAcXEuY29tMIGfMA0GCSqGSIb3DQEBAQUA +A4GNADCBiQKBgQDFR8aCTgT27M5MDDi5cM/BlJ+kGVMCPGlNCbE7/50pCapaLQVh +q+NNl0InhuEb/zmmg7CVHboWiqdBiOcn/inoSZAGnZbstTX7LrhjAISu4wVQYpL2 +k7SMxecZx6M7XUeamasnZrOqgWVD/6wweTPka2fIy5OzU/kuMXCHF3jCMwIDAQAB +oAAwDQYJKoZIhvcNAQELBQADgYEAwMS2H9xmphiRpxO7YrOcHF4QmXIIYh+ibDCM +G3y5254jkNCOYMz18aHFb9LpQq7+8quxMlEsn7G+1G+YJnUIDr4c8TxaRP+LRzZW +mcJsiEx+bbH2uhVpyCY5HXSUFMvOiam+h3fULVcqtlKU2DQzW+8TBk8gMFop+4fh +FIM1HeY= +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/normalCert1/CA/cacert.pem b/test/opensslcrt/normalCert1/CA/cacert.pem new file mode 100644 index 0000000000000000000000000000000000000000..4aca16799f7b73d250cd4a4c402727be41302ca1 --- /dev/null +++ b/test/opensslcrt/normalCert1/CA/cacert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDCzCCAfOgAwIBAgIIKu+kkFNKgPowDQYJKoZIhvcNAQELBQAwFDESMBAGA1UE +AxMJbm9ybWFsMUNBMB4XDTI0MTIwNjAyMDQwMFoXDTM0MTIwNjAyMDQwMFowFDES +MBAGA1UEAxMJbm9ybWFsMUNBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEA0wG9EmrYsfE3e5XDD/qf9a63ZZH0x9oYFsiCcTdCGn/I0ydXxLwrgtQGUcBz +r9KMI7vhaoX1IPtY48f2A9ov/+5gZ+2VaxXZ2Fs0/nFjpO3ED7mTp77oojQX82xR +NGT4pZXHViLz55/RT3bFjCyKBYQ6z9IOn6yxUVFKCsxjWAlG+oz0N6Jq1GyEHY80 +b8MSwGW1e1hDxvxlZEbCHGH165e1fYM5goxkgkf55TNEZdTgieom2v+vF86dqTkM +mfIX4dJ58Qg4Bmx/MpT75eS4+qHkEEBsnXQ0XZC7QgibPySoM2MLwx/WCa+RW0Xi +Oe24kgOfRlq8tG4NfEotatc5PQIDAQABo2EwXzAPBgNVHRMBAf8EBTADAQH/MB0G +A1UdDgQWBBQaUG/8sfx7Yiz+f647bbVMQ1AuujAtBgNVHSMEJjAkoRikFjAUMRIw +EAYDVQQDEwlub3JtYWwxQ0GCCCrvpJBTSoD6MA0GCSqGSIb3DQEBCwUAA4IBAQCU +DKQh8KKJLLYkhkqVJeEDHJWEGJDkEBQmNpVJI/NIgvPGG26M/gd6RxVExkM53kcY +KgNUZYg5c1YBOPKcpe0HkdeV3OeWJv0RlIYhfIeeE4Iuc1L6Ut8QCkyw+KVubXsX +mC58y+RkTY1F4B/X6LAeu2CZYE29nOQYhfBbQAOJ8HQVuqSA7VG+YWJVByJXnNgs +TpD56UlXUi59MXyVXp4PtnLuiH5Svby1gmL/0ssIpWFYHK5OfnoVzyicGDHVzmvx +yNilngyOai0qYSeBceIZPtABY001wJs2TRHIE34oVn08V8RU2srkHJQ0YhgFCoUw +v8pjkPTtEheJN5Yzlbc+ +-----END CERTIFICATE----- diff --git a/test/opensslcrt/normalCert1/CA/cacert.srl b/test/opensslcrt/normalCert1/CA/cacert.srl new file mode 100644 index 0000000000000000000000000000000000000000..2be4f375a8272acfb7d9ee3cc019e312dc51d193 --- /dev/null +++ b/test/opensslcrt/normalCert1/CA/cacert.srl @@ -0,0 +1 @@ +5075711B813D68F7309883F4E543B8C264FDC9C7 diff --git a/test/opensslcrt/normalCert1/CA/cakey.pem b/test/opensslcrt/normalCert1/CA/cakey.pem new file mode 100644 index 0000000000000000000000000000000000000000..6566151a0cbd3073005b04da9ba3a5e72ec546b0 --- /dev/null +++ b/test/opensslcrt/normalCert1/CA/cakey.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA0wG9EmrYsfE3e5XDD/qf9a63ZZH0x9oYFsiCcTdCGn/I0ydX +xLwrgtQGUcBzr9KMI7vhaoX1IPtY48f2A9ov/+5gZ+2VaxXZ2Fs0/nFjpO3ED7mT +p77oojQX82xRNGT4pZXHViLz55/RT3bFjCyKBYQ6z9IOn6yxUVFKCsxjWAlG+oz0 +N6Jq1GyEHY80b8MSwGW1e1hDxvxlZEbCHGH165e1fYM5goxkgkf55TNEZdTgieom +2v+vF86dqTkMmfIX4dJ58Qg4Bmx/MpT75eS4+qHkEEBsnXQ0XZC7QgibPySoM2ML +wx/WCa+RW0XiOe24kgOfRlq8tG4NfEotatc5PQIDAQABAoIBABdLvFehsz3mD0WW +Mbx8Y5eoy/erRxafdgRh68DbJHGvhPkHAXIngjhF91KkSEFzoeQ6FHeBxZ/4EAbJ +51fiiWUcqMkgmlwBxBTWcQHce+9UDs7borkk4yMxVCmcXlTGBwuk/yZ74mjCpMW8 +q6+/pRdZ03JpUxL+6m67ZkzN5vPrkvejFTTksDliHs+TWkYZomr11W45LsIDIMQ+ +prWhAZvqWYy0PDiRUs0ue6QtQIEswCbEueFELdlRB7QXkz7R2qB0MZo3gRd03aFz +PB/CkPQ+X01XBAejOvcokN5AQy1XC6FgVjdVbuIj8Tfgs1Rqi5MDswsmfdbcxW4V +YZFl7+MCgYEA6msnFzZkJXT37QmliAkFh+yb7nFYwsM4twet1MO+M7IAmwN8n4M0 +q+/t0kQdhu/w1tousQzYDk4kmN4hCan//HO8ZaxdEO1iDrjOdjPkAPfx6lJf3Bed +UGUpe6O9BqGJZuSyj4TZjDtjF1v8Xx7UWCRZQDvu+oasjfFV4j/paycCgYEA5m7P +ZJqSLFINLpcnOWWO8dqOBc2ZN8TX1qXJPd0d2biFfCBpvfgkGsWS7hgKPBwaKM2v +y/5yqtYNQxPT1nwRr8owitDYQLKLEHU3UIkPZPXHiaofkLG8+F8YitPN9mNBUT/U +my2RxCav2XGx4Qnfec/tRl4hoq1RisOEhPPtxvsCgYEA3gR6U3vKQceUgMXgJwUU +XfX0gVKM1Hl9H6yAMMDrRZ7S//2/bHwhyK5GuhyVMpXRrkuaaUlW38WW18mZ1MNT +lVAwIMOsqTmK45KYyE7BJUGvt1QpQhSDg/8r2NPtVXhs4Cy+CTuzpyMf6KHQVm6m +gox6k2GwJ0qh5xWpV81cT4UCgYBXgg1Qb0LDggVvhAdpTKAUSKNWoNVm0GIHKb5a +t/X2EJTgpo6BjhJn+E/sC0UXvrRQeowgM+jn55Hxvz3bXhJ5Z5c3oEf/ic412c2/ +z80A0jWMmgaStGzHBZYUYor83oSXjl328D9C9k4hjYO3qkArykKZTnYHiTpGMCnL +M7dAsQKBgGTcxwF/FtbOlKR9ugLGJHAz7irI8liqUd3S9bLYC9nr0/6PH/MTJyGN +kou0VqDgPDbq4JwZkvffNa4UcQrme+00+K+06T+os4szOWGuc+sznBEvdwJYHAGa +beNAzimdAL9S49OUm6GEusC6EbOeS+zp0h38RrOWogrTRVM7XNen +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/normalCert1/client/cert.pem b/test/opensslcrt/normalCert1/client/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..bae8c70a0ede04fa6fb9a172098bf828e32f88e9 --- /dev/null +++ b/test/opensslcrt/normalCert1/client/cert.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIISA1JYsFZ4hswDQYJKoZIhvcNAQELBQAwFDESMBAGA1UE +AxMJbm9ybWFsMUNBMB4XDTI0MTIwNjAyMDcwMFoXDTM0MTIwNjAyMDQwMFowFDES +MBAGA1UEAxMJbm9ybWFsMUNBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAr8oU4ojxSYTQduO2FJr1aZOxF4FT2dgkebLgpJ1FEtoSeNKeEl3Du3Uv25og +gM/bkqD4HVbNuasZB5QWu2vfVyvJ//1KUaE4rjO8atnLTot7PwtuAaWfZnu/3/0S +Q+fB7f+KqBer8wDgsiCC/gIt9a/DN7D89d7uG5AnXXCS5BISB245MOqH81tuIx0j +AAaXi1jbVChLQ3AZAKVE2K9cRC0Y3YoHBXcpmrWilbkvPqy5SKda2f6Uc+k60f2T +3tkBmwY9j9kBD3q2uLLqw8kFMnDWSqryYJ8If9FD3CVWdVsGO8uN5eSgjsJRAweW +MySy3glsh0J/PL3sQ/J4HQL03wIDAQABo1AwTjAMBgNVHRMBAf8EAjAAMB0GA1Ud +DgQWBBSy9AeEQv0G8p1JufQFgFIgKCPdhjAfBgNVHSMEGDAWgBQaUG/8sfx7Yiz+ +f647bbVMQ1AuujANBgkqhkiG9w0BAQsFAAOCAQEAb7N5WB91Hih+QdiwXZMpYCZ6 +SXo7ayrSgJzOeKWa6V1mUixHW0lGX0V0kH2fzpHckg8Al1rvtfmO5JeAnZOGfJlJ +Hl/2hVGDZ60NijVt1VIyUP3xYApm3OvBpWGJGCXLFvBuQwmHqhB/Eb0mp4/dAuzT +tmAXydjteHjY6607kmwbU0VWDyClV5WFD1QSQPk8h/GhnMga6TsH7rfmGKgPh5xS +ahPcdAMTs2+TaLZFPAn0podOQBo9hmkzqXZSeRuAgskEJr1ZuTK5IC5UVJevx5Xq +izffuKPY4drwNJWRaOx9fmHXqQ6QiTGarok8/lkGyH6KehfDs7x9/WPfP91qSA== +-----END CERTIFICATE----- diff --git a/test/opensslcrt/normalCert1/client/client.csr b/test/opensslcrt/normalCert1/client/client.csr new file mode 100644 index 0000000000000000000000000000000000000000..2b2fa04943e644ca150d975c0017e8e2e9f10bc1 --- /dev/null +++ b/test/opensslcrt/normalCert1/client/client.csr @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBsjCCARsCAQAwcjELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkdEMQswCQYDVQQH +DAJTWjEMMAoGA1UECgwDQ09NMQwwCgYDVQQLDANOU1AxDDAKBgNVBAMMA0NsaTEf +MB0GCSqGSIb3DQEJARYQeW91cmVtYWlsQHFxLmNvbTCBnzANBgkqhkiG9w0BAQEF +AAOBjQAwgYkCgYEAnoE0d6pWQuDAj3Ec1jMCjU6aN3vSmo6emIqWGPn5EYRMY+gp +NGNNgopKEkd/2Hr8eo/NVpjzzB2eDAwWlzwvqbghrMDskJEUblnJ6Xat/jAm2a6w +3kCR/ZJ/USGKmbGwDU25cx5NBkMzyIxydqCxkmCYkkgr4OODML9FFxVSaP8CAwEA +AaAAMA0GCSqGSIb3DQEBCwUAA4GBAG76gRwIr6oXzTsWkK67uaU/PzlnuUYVQ86M +maA4cY5J9X/t7hXOFcSpcKcSrVCgVydKV407yshjiiq4teYGOsx6fizPyyT+p28O +7mbssL03JtgVwJfJSXrUe7lQUuHFTeB99uwMfzYE9nqbhY5jJLknEpOoza3JbKFG +S6Ub/LL0 +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/normalCert1/client/key.pem b/test/opensslcrt/normalCert1/client/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..b9e26ef8e7cdc965ef46ad7e71c2ae6e30c776a5 --- /dev/null +++ b/test/opensslcrt/normalCert1/client/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAr8oU4ojxSYTQduO2FJr1aZOxF4FT2dgkebLgpJ1FEtoSeNKe +El3Du3Uv25oggM/bkqD4HVbNuasZB5QWu2vfVyvJ//1KUaE4rjO8atnLTot7Pwtu +AaWfZnu/3/0SQ+fB7f+KqBer8wDgsiCC/gIt9a/DN7D89d7uG5AnXXCS5BISB245 +MOqH81tuIx0jAAaXi1jbVChLQ3AZAKVE2K9cRC0Y3YoHBXcpmrWilbkvPqy5SKda +2f6Uc+k60f2T3tkBmwY9j9kBD3q2uLLqw8kFMnDWSqryYJ8If9FD3CVWdVsGO8uN +5eSgjsJRAweWMySy3glsh0J/PL3sQ/J4HQL03wIDAQABAoIBABza7rYHb+v3KYJK +e0nt7G4pzdOjk/sEosAjSlgFxMqFEgEpOW9JsDu7ebacxHcV0jTDHYHbYiZMXUMZ +iOhiRaVOd82rFcxJID+1ba6bFOdODVYE70gamAwU23h4rHXfEL3wyMwyI98zUVmL +7VSlUPFXqj4gXErpDsFm1pSySjpg2ELU/JUbuJ0H9NdkxnZE+DPtlRnDZDe5Ipu3 +k6wUBMnOkhdjtJPQCTunBSVWViLOKUCOw4aodX4aZbj3n23ZwXBWsElTrmI6DUQC +Xz02FcdiWmNSZjSMAzKpZsbw9AuPlNzAIdJp1G8E26FdJKywZPjDbnyQhDODU+rT +3t1D2UECgYEA7F4WlFmQuZzj2bVQ7xaci9vNIWn4I6GBdeqWZ6clO1wiOU8c2fKu +RMnjrnC+BXfUn63MPbChYP2MQD/EWvk8f0z575cM8OawqSW8eORV0RpBQH4OW3qN +/fcTv2u99hJvkRZli9uFQNADCQWANtS+etNQVPgqY/kHgUWjNMIGnP8CgYEAvmPp +qQb4IkqeVxZ61BaUHYs4igr8CC2aqFU8EayPu2B0YCoEFuHcAndaczY6+bYdoCDE +4rK+kRcRHDgo2zqJ1FSxPAlYQAZyoo6+TURjm9iY4mC0kntO4TTfl5ReEnyU8AST +GwGluLWyiMzUZwcF55y9WIJKk9nd6ysNkST/SCECgYANeW+gFVsCucbjakqq1ocm +EemntfrHTTWsGytnzUd6Es6ApdnnMsZsdXXdQ+ARP7uHPskd2yvX85xDLV9sMDka +qLC2z82VDhGUEuqjmaCqwuxlWMpIvLvLdoJRrzqtBHAHvVccme8GutTxdrbQaWBS +cClsLpl8CX48CFQ9dbIa0QKBgDQxU6bgoDlWV13sbbOilOwcdF1zcXAUgGdRJmx5 +79UgNgb+vaxAMn4ClisrXLOD908kbJTxB0jjF++yfZDL1Wj95rBcp3K10cs6cnTg +IKZilFbKx/W4FAGs6va0160dtf6uYl1u8C1Ysh7KeBfeVwINd+LBpQwvF6UWYlu9 +1rXBAoGAV6hvewLU1OwTxRzwvHa/1OR9vrRi9CMUPBNkGPNs8+5j4WuqMrBfp7YW +Zdd4mY1C7/SaCLoM410MEhYfJ5GWahlGKfySVh1lSccb24HbM1WJvDd54d86PlZH +1zzt8i9m1AFM6fH7Yf+oczJkxAUCx7o3DAWXgRxbJeLgI0Thqws= +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/normalCert1/server/cert.pem b/test/opensslcrt/normalCert1/server/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..e41cdce015c24ecee0e362334325f4a99db740bd --- /dev/null +++ b/test/opensslcrt/normalCert1/server/cert.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIIUz0GJG4axRMwDQYJKoZIhvcNAQELBQAwFDESMBAGA1UE +AxMJbm9ybWFsMUNBMB4XDTI0MTIwNjAyMDgwMFoXDTM0MTIwNjAyMDQwMFowFDES +MBAGA1UEAxMJbm9ybWFsMUNBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAqqBDFR14ajk5kzlbgijk918CL5+tRtH7oYO1sodfCORnFQ5xu4EH2RwkGxVx +R1XIeECySzhenWMUsXuIYBo609Qh2CHtDBpIlDXZ4nKXzZdfJG9jgPHmxYiphnJq +MHlG6bUc50OYngzGrRy3Y4AL59BEzFnBFhe8IqnXouIsOBibEqxDv59kKOe5EkIO +5A7wNQ2d65QnJA9AYfGuZqUgKhJxv3Diem9iv2S2L3A4ZB2J0MGErdIirXZeDpMM +Dx3Q0ogFriwktAAwyDmVlmcobwDrBGQ1VNON2ndV9+BaMjR9Lf9sifSXWadDlyRl +FRLwq13LId3ZJeuvT6qTz/w8PQIDAQABo1AwTjAMBgNVHRMBAf8EAjAAMB0GA1Ud +DgQWBBSkWsdyW73iEEfdTJJz1Jsv99E9iDAfBgNVHSMEGDAWgBQaUG/8sfx7Yiz+ +f647bbVMQ1AuujANBgkqhkiG9w0BAQsFAAOCAQEAgCajECLbbQS1FePUOhlaHP97 +BKx62f6lHrt+NhTsMMXW4hv/zt0yy9CBFGC3CsJ/bFJ19bFvQU9FWeJY25A5WEbe +0Ro8yOVJQ5aU8Cbs2dcDpK6T5MHM3ymzr3VJoMTIpMnEYZgLm9gMdhCkShXpklm9 +mYdeEysxtwdJLOFgjLV/qpRpUP8sX52f0ZgYrdqjV6/pTyWNep+SHqWb2BaTm+fR +gkOgBovOX/TCBAu1pt866qGO39nSBln+rWWCDkM8haff6vF9VQ3OrHpe+jMWq3va +JD6dWDVr/Cum5G9ncBrtyGCz4dWjr8xQlEYyihrIiBZN/DRHddj/cGEXkNz8Yg== +-----END CERTIFICATE----- diff --git a/test/opensslcrt/normalCert1/server/key.pem b/test/opensslcrt/normalCert1/server/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..0082ec2f1d5519885e708dd46e96e748d0254897 --- /dev/null +++ b/test/opensslcrt/normalCert1/server/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAqqBDFR14ajk5kzlbgijk918CL5+tRtH7oYO1sodfCORnFQ5x +u4EH2RwkGxVxR1XIeECySzhenWMUsXuIYBo609Qh2CHtDBpIlDXZ4nKXzZdfJG9j +gPHmxYiphnJqMHlG6bUc50OYngzGrRy3Y4AL59BEzFnBFhe8IqnXouIsOBibEqxD +v59kKOe5EkIO5A7wNQ2d65QnJA9AYfGuZqUgKhJxv3Diem9iv2S2L3A4ZB2J0MGE +rdIirXZeDpMMDx3Q0ogFriwktAAwyDmVlmcobwDrBGQ1VNON2ndV9+BaMjR9Lf9s +ifSXWadDlyRlFRLwq13LId3ZJeuvT6qTz/w8PQIDAQABAoIBAANHdDKWL+e+tMwP +CKKuPYRBh625oF9hjrfya44ESv/FP/PoJKhpWOjDwS/iEBzCWclx2NtYOH1G17F3 +gc3zAxlqFMJuGHuoj6eFh2vyLtetJo9xMpTaUUiPkU57qW3IDuRT7jvFkUvMTA8s +o5uebNwmJQRWXz8Y8q0p3XKMbTQgwNYewisAoPkw6Ra3T1yrtus0vtqVQBmg7G5U +flMDNjrelaDDb76whPbK4S9iZgeJpnOVevox+y1WRSJQsbSIcrZavhnaJakMUMcq +OXouzYfK2RFKS7Cx05w/dUwhWngeRWuzLI8MH0UFZzGoWvtkNd5GWofzvs3H/xe9 +Ere5UGECgYEA7FdKsRzdLZzl8JusLxtYnZaZDvlmwjiWsyaHgWQqrtNXeVjarVLk +dUtkMxATdfOruS3y8OjooDDCNEGsvmrQGFBqV+Y7XZvsOPvvzGTC3PNPwu2lGG9G +HuHxi7+Ood+sdWZQq2kGfBLWiHygN1NMgvZJTPW7XW6knUAMvvrSRHECgYEAuNGe +SZ013sb5fo3g9gtdR1q857Ont3dOQzIUkvDjrf50Iul4U9IFJ4VPez5Xfg6KaTur +MsfeJ0Mxed2Z+g7HN68/baJ9u5IJOxRXegt5d13wmcg3V6OxwRIebOhoxO4TDI/4 +r86TL+C/AQLTuYvflFQvDp/vG/owbuage4sSKo0CgYEAl3CP0dGXAEVLKdP2jvDM +5Z28UdYJvGYaWo1Twtt2ZjPSF1WSOgGllmtKt1WTwr7yyGPjCe+UMCFOL8HteM5k +rU/J9Kz4WVEyGomE7Mmb+4yFDXjNk0yp99v7tPp34M28ajW3dz6DAznm8eo5l3nu +yvQZoBDNkF73aDPEdxOtbGECgYAEjqaoigfaBm4AgN8bMgyKxeIClJt+GqMDZSi8 +ttPmZm/WrIsbBgK6hr7++vqNUS54idAe4G7J6Q4/QweY8uRWqyuVOHyBM4imvEmr +6LyGBPr5z8YNkyu+5wN5DAIJWRV/Kc2oCB/4/kG3dKpj2N8aeTjv02HB/tlS4rzy +TuuTwQKBgCbc8JvgJ4B4FmKMtdGZtE5F5z+8/VYTQ8lK6tORPzRZujY7jY32kE4p +cW4XiJy1+TAQl2uekL/8hlkTAi3pRJaUTjqcwrIciwYMxV4pIdiSUsUTLc0tcaMk +SaVisaKpWm0LjyilYF90oJie5kUssXxlh7Nq4O36ACQJslU85+g1 +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/normalCert1/server/server.csr b/test/opensslcrt/normalCert1/server/server.csr new file mode 100644 index 0000000000000000000000000000000000000000..22dd202ade1cb34d6e778e7e4b7890fb9c6857ae --- /dev/null +++ b/test/opensslcrt/normalCert1/server/server.csr @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBsjCCARsCAQAwcjELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkdEMQswCQYDVQQH +DAJTWjEMMAoGA1UECgwDQ09NMQwwCgYDVQQLDANOU1AxDDAKBgNVBAMMA1NFUjEf +MB0GCSqGSIb3DQEJARYQeW91cmVtYWlsQHFxLmNvbTCBnzANBgkqhkiG9w0BAQEF +AAOBjQAwgYkCgYEAzRL036OUQAx74qhLXCpiAoQjoGEkFdj2MDIN2FaU9i2qgDvW +yeRidWDK4gnjq8jItRcJFDdnQV7at6Qg5IyiKChIzEMDVx6muTet1+k0UNWt6OtB +iMzj3FQ3glaAoGNUKykJTHq+rosrTWpcuahH7acEca+yNoG4E3UIdk25x9cCAwEA +AaAAMA0GCSqGSIb3DQEBCwUAA4GBAAaFTb861L7Glp3Br8vZBhDHR+S6TztMaUXg +KvdWkvzvXAJdhW8Kh3FrHBn3Xi10moSrRIvGMPG4K5/JIg4lbhBeM/F1HuimaBps +rQzFgel3QUV277HLgfy37NWQEq5hYwAvv6vyNNJ3xAj6yiekQVc3CXjceZghfBtK +KP7fXj0M +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/normalCert2/CA/ca.csr b/test/opensslcrt/normalCert2/CA/ca.csr new file mode 100644 index 0000000000000000000000000000000000000000..b630c2f7735e27e8e5f0ab141142cae3ea070e98 --- /dev/null +++ b/test/opensslcrt/normalCert2/CA/ca.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICsjCCAZoCAQAwbTELMAkGA1UEBhMCQ04xEzARBgNVBAgMCm15cHJvdmluY2Ux +DzANBgNVBAcMBm15Y2l0eTEXMBUGA1UECgwObXlvcmdhbml6YXRpb24xEDAOBgNV +BAsMB215Z3JvdXAxDTALBgNVBAMMBG15Q0EwggEiMA0GCSqGSIb3DQEBAQUAA4IB +DwAwggEKAoIBAQCuxxFSBSgW359L2ANVQn8d1qEznEtax3j5NGQ2U3V+mV4ghK7b +10GYT4Ak2Txt59CRwXjggZt/GS1/9aEXg6I8zXnU/+MPCsG5hIg1qWjl2Pje8Nzv +3SlR8kyX9TgYpGJq1zrgHoAFWPPytnxK0qTuxJ+MOQc+6czPY4NwEH3t/5Vv9hza +Th4BSIRXS82vVzzN0mhrPypFImc+qDK/Vy2NhDeH6w49c5RCVmDA1dnVn53kq9Ne +uhRpfKZsaizGSPVMzwM8lOkDngajVryTXCcsFOWLyS5ow1n9X+qlz6b/WBt5UeNA +eO8hZHcB8tFyAawyZ1pAXE5SwKdCXDKAUEZRAgMBAAGgADANBgkqhkiG9w0BAQsF +AAOCAQEARM6tB38FaegEA6vDXlHr14r7Ka200wu1JHDrFCwhktg7poQM5UlEdORu +trnC8XpAArzmY2/McVkc2f/IqRmpBlL02MfdxgGBac3b8KloDYeac/5kRa8VXpBz +z3GSzuWhII3t9+DYCjdz7opu5BTLWm8Llv6iUcsx1vyji+i00kOxz7WDDqxPVybs +aQJAhO3tms9f7nyaItN0uzkx0sMLlxQoKNTy/r0p88BwEcmP7iZ8id4I7TAJOLnk +E8Gh+ubSh156fgJVVWjt00d3ByPmKkvVpMLedjEwIOKO6uQFTOgCyG1RW8w9CgRk +ONgo2f+5ZPRzZENaowm+fLNJ8fZYSw== +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/normalCert2/CA/cacert.pem b/test/opensslcrt/normalCert2/CA/cacert.pem new file mode 100644 index 0000000000000000000000000000000000000000..e7dcd37e9e492c80b2071ea36addb7b745d2b95f --- /dev/null +++ b/test/opensslcrt/normalCert2/CA/cacert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDCzCCAfOgAwIBAgIIIpF8oMqMJMowDQYJKoZIhvcNAQELBQAwFDESMBAGA1UE +AxMJbm9ybWFsMkNBMB4XDTI0MTIwNjAzMDQwMFoXDTM0MTIwNjAzMDQwMFowFDES +MBAGA1UEAxMJbm9ybWFsMkNBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEArLeYi4I9e8ZF1r0k9hw9c/UZjxTLvwAQly/cxMD6tqSWo8XchBX83OGqDels +a+55awoNpfj8xpvQPHvqOKyCKICO3w04FZkd1+Ihni30xDdoe3OugbzJQ6oty32M +IIs7yLgrItb10rWrcVTeo1aXWKCuW5c7DfVcImaCrDsqbYkKdyxlyo0I98BvTaGx +AF3NXduFTrsPrZ4qQ/fHNPrmeU5NhPXq/KjjnBtoTEm+HyUbTMKeSOxxTs0A9Uzj +owDFCgwSbqyz+rA4GKBtJciF5nX/F34ZrWuoBKQ0eSHMiIayuOSpo28ew4680xN/ +5x7PYq+q9Y4/HhksWDpdQ/UpaQIDAQABo2EwXzAPBgNVHRMBAf8EBTADAQH/MB0G +A1UdDgQWBBQ8cJmO3qKHdQQUmO80sbAnQZMudTAtBgNVHSMEJjAkoRikFjAUMRIw +EAYDVQQDEwlub3JtYWwyQ0GCCCKRfKDKjCTKMA0GCSqGSIb3DQEBCwUAA4IBAQAZ +3bdEDSuE1D64CnNpv6BYxsMqidVcO7DARyzGkQbfL3k6GkWBvYOqAA+3b/ajPvjc +/+YRsHcPj//ofMCRMslsIXMWgFX+Ew1fcSZmMKwaYC/JAnAl4GBa/J+tt8iCY5Td +Abyp3LCPhDtqO+o0ov7izAK9jqupEIsCCNmbusKOLDEoxbPj4/B9R2Hdl7PCzgJS +lqnvHn4MFx3fBxarUXc9rJVDsCnXZvGV21oefj6ZNzjtR+nSGRFu7iq6SwUE2rc0 +KBKYnSLpGM+BT0qFocKlBACGKDUiaOpIwHA40HM9gdCbGIGFmCQkSz/94b9ggZjO +CZsunDpLohZnCAWOoGmH +-----END CERTIFICATE----- diff --git a/test/opensslcrt/normalCert2/CA/cakey.pem b/test/opensslcrt/normalCert2/CA/cakey.pem new file mode 100644 index 0000000000000000000000000000000000000000..6f54c1393434db56756a468b0cc61035fac75137 --- /dev/null +++ b/test/opensslcrt/normalCert2/CA/cakey.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEArLeYi4I9e8ZF1r0k9hw9c/UZjxTLvwAQly/cxMD6tqSWo8Xc +hBX83OGqDelsa+55awoNpfj8xpvQPHvqOKyCKICO3w04FZkd1+Ihni30xDdoe3Ou +gbzJQ6oty32MIIs7yLgrItb10rWrcVTeo1aXWKCuW5c7DfVcImaCrDsqbYkKdyxl +yo0I98BvTaGxAF3NXduFTrsPrZ4qQ/fHNPrmeU5NhPXq/KjjnBtoTEm+HyUbTMKe +SOxxTs0A9UzjowDFCgwSbqyz+rA4GKBtJciF5nX/F34ZrWuoBKQ0eSHMiIayuOSp +o28ew4680xN/5x7PYq+q9Y4/HhksWDpdQ/UpaQIDAQABAoIBAAOv26LGwfc9cUo3 +OW3DhpBOICCMyhimsLP5gpX0pKV3d/vBAocTTRWN/6pSXE0lhxQA3++2E/dNZjGU +V5ikHhctP1+FkbfsT9XG7v7Iblwt7p4MByKf3BAEtvuMD4y2TC1puQoCWrcHx8Qg +kqwb8hkjPPWZkdonbbMSSyHQFYTHCiBg5QFetqFtaly1Jn1MPNylcMFTEZD3Eez7 +zyvm8fgwg806s7zKQ75D5wxgYrT+O3HD2YiMOv88W4oZbdOWMa+G6om1Tgw9ccnX +Rwe7EYqxzpaMOxmve8Fy/iDukNAEqdu463CqAdpXHbrAiR/sD2R8bUCUNUa8CuIe +LKNLtQ0CgYEA4r1scyHOUpeYyekbkDYl1BsfxxFjG9/ovh/9g9Gm6aOKYRta+9/f +MZj0dPJJ1Shxx7OjZwgs96KC8mxkvIVd7ObAZvLCgW30eN+olbyqKBE5TF9LzypP +m+A2TkKzhhCbgH0oPEJAFn2uTxjhfyDxaGhYnF3ER0Ve4tD+SzFrwI8CgYEAwwF6 +KMZwwm1SUhxVO/kdC5SoFwggWBMgrFwFNtCM65wJGDGEzHAMBblKNbmfdcuGx8vB +eLCYw+F5Twu/ud6AZIpj/s/vXaKaEm4FxajHw28/wde0GAXWzBvf+lrmumA551IZ +54I5AIYcb5TROZl+l16mhYizHwQ4b491Y/cGgocCgYAuiksUZr/+2/cUmrB5rWOe +YeLn7X22XwNa4x4aTJJCKrtgq0jFonIhsA+dYY4eHqscHpdPsrHoModGU1FqEMXy +tnqPpwydnIAtv1rPQWZ7yu1D69uXrkug5yi+b5qBzi7c9NOpar/U3f9FZQoEGjLz +B1gI4IZFSsvrKpLHxGElqQKBgQCIXhIVTaAw9s4oiZO5Wt3K6Wwanx4vOUESB0/G +I1VItldNyeM6K+jcX9TbeTFf6LIaKgzj32sVICn68xZiXoyewoBnAvY7Hqhr1vDG +Zt87diK55aQVYI/zMFMS+lp7Xyte7nUbBl/iU+ayyPj/NK3oINKGcy//sq3cpg/L +lkXBJQKBgQCtJwzB+hF1NKyaX3QqtQZj62rG04hzL87gKnMnbImx4WQ3smGEZZ8Y +CnRUBwraf22V+g0Hi7Y8nxXmgMbCtEWtKEKbEL8V3iM56ZtFbYAfatL8hPznnpzg +DhFOCi9viPBSc/LNiKdifaHNCthyAnLBe/GwGRevU1QD8F+hNYmAoQ== +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/normalCert2/client/cert.pem b/test/opensslcrt/normalCert2/client/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..f7e762530716e04701317523f19ca664facb3b21 --- /dev/null +++ b/test/opensslcrt/normalCert2/client/cert.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIIfSJxdH9/JuAwDQYJKoZIhvcNAQELBQAwFDESMBAGA1UE +AxMJbm9ybWFsMkNBMB4XDTI0MTIwNjAzMDYwMFoXDTM0MTIwNjAzMDQwMFowFDES +MBAGA1UEAxMJbm9ybWFsMkNBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEA7spWMDzza1zTGSfkDd8gXDeLtgoihzkw7PQm09Pv5iZ9PXLAwbrZPIoRZdJB +d9kZK1qFUkMwxZtrMUcEk54vM3N1ZEibf6olfHaFtPO9ebdZzJJUQGO4On/GXqK0 +PzIZ5bjNAKfocHAE56GtmHCaUoEPO26XTQ87f9Qc3nkeNENPiEY7VuhMTjSJ9Kre +DhM3skyOLg2yPeAqgLMyom5CUqfYpIneH2/fkCQKUv6EZem6JYMIZXL5+ibNoAeT +AfZ8zzZryl20aN3vcc4hBatZ92wBEmBNTbrQG4aRWKmSkkoSu9XKuOpopX6k4agN +mqL35lqgv6+vueUBYz6g2VdkuwIDAQABo1AwTjAMBgNVHRMBAf8EAjAAMB0GA1Ud +DgQWBBS1XkSXnyyuG5rLla9AER13tY/KNjAfBgNVHSMEGDAWgBQ8cJmO3qKHdQQU +mO80sbAnQZMudTANBgkqhkiG9w0BAQsFAAOCAQEANKEAeBtHIHxwfGWFFZ7HpX44 +VLaj2smgYUspsNxaTwgRnoppGHrl+cdVgm8aIdxePClflEmGF5ziE6vK5BwoD9Gp +3WfMdPvEr7GgjZ9V47fS2sRD8D9VAxFiFVQ8e4CspSWatrAYXCXAOtFhXiXScGfE +VPENzGA6YkRti1XazRlUes6AWB1ZsmZQ4siMtrlPXpVBEWjfCy4rRTe+mNOI3Lm+ +zJzC4MgXcuTLiHqEonVFtlADsdm7+LxNllr++3Qr94ITG0TuStfe/xB+lSvmA+k2 +O+lNyiTjmyOzr5jIwOduXxGAUp/CffB+NJtgjmEqYY0I/NWsSKHtSD7NZBPh0Q== +-----END CERTIFICATE----- diff --git a/test/opensslcrt/normalCert2/client/client.csr b/test/opensslcrt/normalCert2/client/client.csr new file mode 100644 index 0000000000000000000000000000000000000000..01617d863148900787a089f679cba7b51787d4f7 --- /dev/null +++ b/test/opensslcrt/normalCert2/client/client.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICtjCCAZ4CAQAwcTELMAkGA1UEBhMCQ04xEzARBgNVBAgMCm15cHJvdmluY2Ux +DzANBgNVBAcMBm15Y2l0eTEXMBUGA1UECgwObXlvcmdhbml6YXRpb24xEDAOBgNV +BAsMB215Z3JvdXAxETAPBgNVBAMMCG15Q2xpZW50MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEArWqjM1jTxY6/7VSzZu7fV9d9wn1CdGpm0LWxF75BEa4J +kZDTpuXUwJOIZFPSNB3ZXwaW3v5o10tCB5pCu+X1xLTVXdAmg3ThOAdRg0OH/q3V +lEkFjGaW7dqqCj6wGXfKwzIZowpBFHGVeLsiYJBB9fOm9HvhxPKQ1vjLalLCGhTX +GLclY1NuEPjOK9Gao/jphoOM7YID7yl4MX+vbhu8r+fmSiJm65FcErOlav8D47TE ++D6gkwAOw0Q256TdSI0Vuj7LdDTx4+9cUNi2dK9ZC4bHwyKOkwMnX0T3OnuMbW47 +F9kWA6kc1uKEak7NjGdMEV/OLoQkLMKJ5MkAeFBfjwIDAQABoAAwDQYJKoZIhvcN +AQELBQADggEBAFswSBUD/UxT7JDRb7IorEAm46F9xhkIgXqLcTHopPjHSGskBn4N +0RTzqZ5yv/UTbUGuY9speA/TtZRiGkb3r6KL7z5/OpiUajP+QguC3HU314xpAtuu +RIToMbONvB0WUrRTUF1xkiNew6R6ZW2Hr9siptlMDNI4wEIpqrOt/y6m0sqewMmR +5o1Fd4PTa5Dfm+at4S73upT2cVmAMU3BxUNdi6Nnec2/FiYP8+W/NGllLjVyechH +OwTCgarKmGYyk7Mto6WBgJL/xRfO1/azfJWq4EuZyWW69TgPqTLVL6HuvMcYBqIk +dbCZKruAp4BHnFRn3FsOAVU3XNbQW+vUgJo= +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/normalCert2/client/key.pem b/test/opensslcrt/normalCert2/client/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..f50395b04af6dd105f3bf56bca47660e6008ad33 --- /dev/null +++ b/test/opensslcrt/normalCert2/client/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEA7spWMDzza1zTGSfkDd8gXDeLtgoihzkw7PQm09Pv5iZ9PXLA +wbrZPIoRZdJBd9kZK1qFUkMwxZtrMUcEk54vM3N1ZEibf6olfHaFtPO9ebdZzJJU +QGO4On/GXqK0PzIZ5bjNAKfocHAE56GtmHCaUoEPO26XTQ87f9Qc3nkeNENPiEY7 +VuhMTjSJ9KreDhM3skyOLg2yPeAqgLMyom5CUqfYpIneH2/fkCQKUv6EZem6JYMI +ZXL5+ibNoAeTAfZ8zzZryl20aN3vcc4hBatZ92wBEmBNTbrQG4aRWKmSkkoSu9XK +uOpopX6k4agNmqL35lqgv6+vueUBYz6g2VdkuwIDAQABAoIBAAikgiIU0qqcHXfX +ncVywfUUfZYFH+LNCxxIae9YyGIIGphSwu9AGAS86oxI4922mdab3B9yFx3P1j1e +WKgPHZ47G5CZoCcqvViX4Zb/C2iQXkutMCd0dIKJhWB/ew6efZl26tSPPzZyl3ra +uzG9QZii2y5Hfmpnlru1raXPtGwnwZ9jKY9BKHuL9sW+v1ny48r1mWbmYl6A6pdb +HPJ9iBTbVQPjD/rfx4m30JvwVwlogn5mwK4qu1l2bNCwd05CzHj+4L+hg1Xhx404 +Z2iYfxUOhXTTZj1naHPWSAObcd9tbZWYh6EjoKcP7TUkXqY9DZAVVBJ8aNaSFDHA +icuNN2ECgYEA95rKvPYmp2iBVuzohhg6DBUAu5RHmXiyOMcVIKZIVZ2vge/3E0i7 +aNhkiwSM1aXVRqstmuL1VcgS8tL196COREHbN3CteOyccZnmrOI0S47uH9ZTKUd5 +KhI/CV11dfgt4M9KhyM9G/Is4cHp9GsLFCdbS4Lkw/kepFwiMPkKArUCgYEA9uMJ +a/xO1uYVYJs81Dc5rxWXVOvSDCevVkhTyNqjtFPmb9QUUVIysUSNmaq54zq7h+Kr +3hXD2Gkk4II+77M0HxqX4utcKWtMj6ax2saovKYNXaK0T4ouBYYtpSIdgOD8I8h8 +BavGJSALzBMNjeVfW6MObQBNXkOyo2nkXU2hP68CgYA6Mpvshj7XomykLdMJkbAz +HsypSlT58bCSP9jrb6oY1ZNnCywb0ECciCA5vLwNi4jxr1Js0WOw4fHyOTuvFIHb +8U7BXS6o+ZRb7zyU7iY0jEfUG+A4F6CXDuraHKl7LPY7nP17fgDjDK0X8yg2VuKn +vghrWzzCA0/8edtff2mm8QKBgQCDCjR5BX1PIOZd1Weu9/T9Eke1gkmFvBabhF48 +MJhNYW7hEj8Kfsd7QXirBGwPyJ+62zv/76t4lS5Gn7sJ/SMAE1xkzCmprlL/uCmU +hncPUk+r0G1F/oC42+QkTvw76K7Ly2DQncDY+a5pTf+WvhUkDWLdxhhKnUIW7GCS +iV657wKBgQD1+rWh/nNomGRR4R9yDs2z12u85zOZYpWXDGicg+dLLjfo1VYuHKYp +KADEqxxTaU39NEjGtzNfgNdStvzrfSHcIUr+mJEzSLRYTC58N1kBZsNoWdlPIqjI +goPA/iEqKJw5MekAwk7QbI9nniYUVg+gYd4Nb5GfTGEeLvqGwlQjPw== +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/normalCert2/server/ca.srl b/test/opensslcrt/normalCert2/server/ca.srl new file mode 100644 index 0000000000000000000000000000000000000000..6143b0567fb1ca0c9da2c2c875cc3cf40139160a --- /dev/null +++ b/test/opensslcrt/normalCert2/server/ca.srl @@ -0,0 +1 @@ +57C9ED9AEA46550117888FCEA14CA541B98251D3 diff --git a/test/opensslcrt/normalCert2/server/cert.pem b/test/opensslcrt/normalCert2/server/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..ba0126e6811f8d1898ad7bcc121d3a6fa65d3c21 --- /dev/null +++ b/test/opensslcrt/normalCert2/server/cert.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIIXCs1gv7boSQwDQYJKoZIhvcNAQELBQAwFDESMBAGA1UE +AxMJbm9ybWFsMkNBMB4XDTI0MTIwNjAzMDgwMFoXDTM0MTIwNjAzMDQwMFowFDES +MBAGA1UEAxMJbm9ybWFsMkNBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAshLbAYIyaToULFcYj4TB16jJ7ysITicoFPHj6DDUqhE/UzzPk0ddEnPW2tvZ +W31RWBmnxl/IsdHXUVYcSga19q9mUryb+zjV23t7+LYSxCUpqqtkYQXKWgSoMsJb +MSjUepSFwZ0aoF1AyBVVUDePpL9K8qsDra/v7wYfQTagDqALwqvxI0NUP3h2vMJ+ +/vcV/zwLjos9PCoW2Oe7bSgqfT7gF1EURRxE5CMVG9yJio09FfilPl4NnHjYuv98 +KMgFpFyLtUcMwn/xYpZeqOUCD3LZbQYwYXJF5SuQ9SwW+/EAN2PUenSBg2sQO/e1 +GIJ7XtONcpluXrApFct9SrbnPQIDAQABo1AwTjAMBgNVHRMBAf8EAjAAMB0GA1Ud +DgQWBBQBLG6XoAxOd6F3XyNdgFPeVj3JxDAfBgNVHSMEGDAWgBQ8cJmO3qKHdQQU +mO80sbAnQZMudTANBgkqhkiG9w0BAQsFAAOCAQEAfacDpBe0N4/qZ2zhExHjcL5R +pPkuchTh2OIppcCr024BEtbTSXPtDu/ncFcS+y6m7gDYmNUOUHwtTd+jt3/xjEcP +TmF1o3eGwUM8wRkzxSiYSWreGW4v8E2qmkNVdKC0A3Xs+5WXWRXJvlTTor5MPbq2 +8uG6QrEscILAmKmygZMh+Q9DW21SLz1X8dr/UdDtuZ50CZxGby1JlJh0ivsbwJ/X +NA/4ZVRht/R8g+MDP8mQl8HGjDoLWwtG1KlFEk0LEkUDdyPNpizncArbvlRBcjwc +1k+7oEACKyDzdpwOb8WCZc3n0DdSn2+sasw5+z4VIl6p7AxRh7lRaMPqYID7eg== +-----END CERTIFICATE----- diff --git a/test/opensslcrt/normalCert2/server/key.pem b/test/opensslcrt/normalCert2/server/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..c63927e8d5c72f23aa988739f6a22a3f15d87048 --- /dev/null +++ b/test/opensslcrt/normalCert2/server/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAshLbAYIyaToULFcYj4TB16jJ7ysITicoFPHj6DDUqhE/UzzP +k0ddEnPW2tvZW31RWBmnxl/IsdHXUVYcSga19q9mUryb+zjV23t7+LYSxCUpqqtk +YQXKWgSoMsJbMSjUepSFwZ0aoF1AyBVVUDePpL9K8qsDra/v7wYfQTagDqALwqvx +I0NUP3h2vMJ+/vcV/zwLjos9PCoW2Oe7bSgqfT7gF1EURRxE5CMVG9yJio09Ffil +Pl4NnHjYuv98KMgFpFyLtUcMwn/xYpZeqOUCD3LZbQYwYXJF5SuQ9SwW+/EAN2PU +enSBg2sQO/e1GIJ7XtONcpluXrApFct9SrbnPQIDAQABAoIBAAfwKj4neSrSsXqG +iWxAbUm5VwMn5m9hdHp2jAPaRWitaybBQ7mbiU0cx9uqyEY5UGp+fkmsFGzE33Xb +QliUcwEjsYqHqOZqkrH9j7m3GCARTgmzKT7f/LTFl1/n5RvMZ6htPYczgy/Z6Fzj +ApIYX/lmo6css3XHLYl4uuTz0+/C5Jh8HtOr6NgTFm+ECjyfzBkmEQewkzmMTE6r +zQlvGpvCgyYSAinB6yaFAf21NEc3c90vAxW2JJfBtwcRnJKZBSlsV5d5Ly+EdLgA +6FBtsTLu/GdG6iAIfZMPMxGAQLtCfqNxBX6eFV2HMc9sI1B+QvQnbPgcROWmGwRo +p12/w3ECgYEA6HQob2+eGww5SxgOObqYUIKmt7EYPxH9zaKxz72cVbIEtJkn9z58 +mZ/bnVD1Vlp9M4KY7asVuKrDgr05jOtgoUruKofgltgiAvfgqvVevVCxON9Efy61 +1ImSkn+YpOZxEjp9dG+2Ko8PzD70HSlOSHMiJIkihGpgCs37dtNBs9UCgYEAxByM +PfDrUrNn9agUdkeQyVjRP46exQWLEQCS6uuaBeBg3Wkd1TyYJCcKvfZZPyfZcG/a +Mez96RXIi4GNAke8Uqdf1RRFaV4FaHvUYxl6HwLTYsET5Ob5HVHkFBXcJ7AvGna6 +6wx9PCGQ/p829ywvYlV36o1Cxo9/r0z3siKJYckCgYAA6yyZvCQw0SyMymyL99vQ +PaPxB+cjoKuTG3h+bK5ofEqeeBh7VqjhZotUpNRDYhoqkPKPxeRRFYOuSZcnomqA +WK52RpExp2sC/f3KTyvvrZj1s5QuuH2JeH22zIBK7oo12ztRsXtT0brrQLhwQRCJ +IsNYx532TgFRKZ07y/vvfQKBgQCYihwLkfl4ulOCoNpJpiDYfF1GMokMduNUtj2Q +vPFw4LZ3gevSkO9GL5gLajLBDR30dwwhCVX2lxsGaB5mR6YUJFxGQR48dwV939Oz +wX1b4bbKeaQfyAi85ZWiid5UgJYi33dEnNJMk09UvkLiv1ucAR7lMjtedv+6OwPv +ay2xAQKBgQCD7a4ORp/xIXDkpVuaiNkPSK72/kgpTBJsBQjvcgS2BZmt5d0apu2i +hi6LJd5DZudJ/DPMKd3axpEW548IXWuOZZwhUG4+oZM0QY5m9ZSPQhMMRaAnIuNg +mx9ndUofO/m6pyKOtp+vBsuE3JL3sP/1ZJ7EBLEPQ0lb2+F/qHjoWQ== +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/normalCert2/server/server.csr b/test/opensslcrt/normalCert2/server/server.csr new file mode 100644 index 0000000000000000000000000000000000000000..966f22a084830e0551de859d0aef86c15d863229 --- /dev/null +++ b/test/opensslcrt/normalCert2/server/server.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICtjCCAZ4CAQAwcTELMAkGA1UEBhMCQ04xEzARBgNVBAgMCm15cHJvdmluY2Ux +DzANBgNVBAcMBm15Y2l0eTEXMBUGA1UECgwObXlvcmdhbml6YXRpb24xEDAOBgNV +BAsMB215Z3JvdXAxETAPBgNVBAMMCG15U2VydmVyMIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEAul7pVfomKoiOrnZStx9ua3f4PgDMtK0abkKS6As0AVBX +4j6Yen6LaN6n+v8J46YX8ffwLvwSKRsUpCFnnpNpXvLR0zyJZaHLftg6hpOMKFt/ +mk6YnRB+Fbaqfl82GyDJD1QBT4lpJR3tURZzoEwrAC4dY7SU4y6QfqPzd4dsIWdO +SmNbFQCQSAovcotjsfEgHSfe969TlYt+uv3AEaeDKSYxwVfeV7B8SsrkRh8zDh7k +ERuRYfyQZet19GpCW/AUjQMsRQ+qRHteNaNLETgYI3zLRmjujjl+y8YdpXWIYNPk +kdVDGoy0EZfFL6/ncKPaGPwozFE4Rhzd5rRVlnDtgQIDAQABoAAwDQYJKoZIhvcN +AQELBQADggEBACH4BWHpq3R8g6nOkT8cMqQlboG/pz/3OvDrHS927wbJzfJDSko7 +/5dzurJc59PPCVsjQnWtrDLYxr9v2Qs/dHJwAGacYqY/rZAHVSRHoAfaKQkTIgw7 +AO3lbus4KCuks0NCv0TxolD762SbvzojMltD2FIQmpYQ/bvCSHSYol2mMvBdmBvA +B6eWuU8jy69kEvplKVNs2JMWRVW+d2u/fVi93LAhMg/zipofoVgBAS4lWKC5fMCD ++THWYDIpPirj0cuDjXTfbpr8S8r1HmPkFx4Q1CmOoL65ynctVlS9qJQVfVqFfu/n +KTb4s5r0vwB+mp4xn5MceiPLNQmsSNS87F8= +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/normalCertChain/CA/cacert.pem b/test/opensslcrt/normalCertChain/CA/cacert.pem new file mode 100644 index 0000000000000000000000000000000000000000..8e52b7a68dd73f59d2eaf00709e3d409fd2d1226 --- /dev/null +++ b/test/opensslcrt/normalCertChain/CA/cacert.pem @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID0TCCArmgAwIBAgIURA/lUFrZwOfJXaaLF9ol7wkhAxIwDQYJKoZIhvcNAQEL +BQAweDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQHDAJOSjEPMA0G +A1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxDzANBgNVBAMMBlJvb3RDQTEfMB0G +CSqGSIb3DQEJARYQcm9vdGNhQHdvcmxkLmNvbTAeFw0yMjEwMTIwOTM3MjFaFw0z +MjEwMDkwOTM3MjFaMHgxCzAJBgNVBAYTAkNOMQswCQYDVQQIDAJKUzELMAkGA1UE +BwwCTkoxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MQ8wDQYDVQQDDAZS +b290Q0ExHzAdBgkqhkiG9w0BCQEWEHJvb3RjYUB3b3JsZC5jb20wggEiMA0GCSqG +SIb3DQEBAQUAA4IBDwAwggEKAoIBAQCgvTn1zACCZ4uny3BW8Utilwoztkhb/XM+ +ZI/trCZg1smsnuCNJHyIJVoFz4PoxXESCueTD0UwIcrftvzQPJZzVZEOY/ND4ZFq +Bj4TbCybSNuFIIAXn2yL1x5oLGz5wuEr7XClqUECVZPTyDv2ozg7+L6NRXNnQ3DQ +jL3QqEaH0M0hA/4FX7ySXrSC2BFX5LZzv8cjKla+3jqJUbUxokxEWMfYVNU+JjwV +vS4ieVqIsWcfd+FhqFCpvWj92PpJB5Lk6GkuGi026lgYutK30Gx133QLQjrNRRLG +4dM0KMevpSM1Ug7dXy60dIjlJkFjekEf7umCGFNILf3UKQTd1irnAgMBAAGjUzBR +MB0GA1UdDgQWBBRXGByy5N4UbiP0NcTKWpj2yDsJ3zAfBgNVHSMEGDAWgBRXGByy +5N4UbiP0NcTKWpj2yDsJ3zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA +A4IBAQAaLNCcpUpKjxz3DKPU632uqljaliu5mCm9OAQRDLccfZifIHacyxx45rm7 +8TOSgQnSS41UtsvTMn63LXqWCi+P0K6LVZUDAiQzcehGsKYyj7yG8UiQev3qUXuU +HRD+n92AVKYz4ABkKqepgVqE1G0CQeFISU/czTv4r3+Qv5IGXWuZ9D9fNu2Al8DG +Zo4FdcxsBESksDP07Gjj3zMRgV0uhsoSP7gEc2zrucfSIZasEuIgPGL07vMc4Ybm +EaY/cZcuSj+o012wmP6YMQw4t74Fxqe41l4lY/yPtgUO9LDx9ygiDzEKaMCTea8M +33Oqg/tD4bzToCPlpeh7qhpUfLpP +-----END CERTIFICATE----- diff --git a/test/opensslcrt/normalCertChain/CA/rootca.key b/test/opensslcrt/normalCertChain/CA/rootca.key new file mode 100644 index 0000000000000000000000000000000000000000..243405ce091dbafb55cbc6d84a9a8f9fb506a969 --- /dev/null +++ b/test/opensslcrt/normalCertChain/CA/rootca.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAoL059cwAgmeLp8twVvFLYpcKM7ZIW/1zPmSP7awmYNbJrJ7g +jSR8iCVaBc+D6MVxEgrnkw9FMCHK37b80DyWc1WRDmPzQ+GRagY+E2wsm0jbhSCA +F59si9ceaCxs+cLhK+1wpalBAlWT08g79qM4O/i+jUVzZ0Nw0Iy90KhGh9DNIQP+ +BV+8kl60gtgRV+S2c7/HIypWvt46iVG1MaJMRFjH2FTVPiY8Fb0uInlaiLFnH3fh +YahQqb1o/dj6SQeS5OhpLhotNupYGLrSt9Bsdd90C0I6zUUSxuHTNCjHr6UjNVIO +3V8utHSI5SZBY3pBH+7pghhTSC391CkE3dYq5wIDAQABAoIBADKXPC3btnFUy8TV +KBeFPJfcOA7MmXuyitohZpeErlOeZr1ZCA4EZNmo/+uCQ984fX0TR42mqb0bdbHx +8yJLX4MPdGdWGBPOZCk9q74LNwLs7IK7FvXYbJ6a52wcR3RY3OwpgGHzoo1sh+mJ +RS48cw+VG8x1BnyC4ngRRBDvVbubAYMGceUmlYtFgIRWqZrBhKvXgzpB2zv+5jxR +NnXDqkQRYl9fuAaG3ajN/qpiMZZlcKfZskjfRvrcGcHi8bD+xGutxXSPFq3qPUT+ +tJ1pexOv2xUx0A27441LoPW1FzSyOj2+iPuZpjTSseM1AIxk3lvXwJU/9CAXYBiS +RxA7YQECgYEA1jqeAfoWtEM1XOl3K1KsxfCqcPFsMUxglhcPz1n6Fv0zgSXHr/Sl +HDPmE7faiqNZoyy81rG7+xjpZkSa6LfQ4gCzwjN5soed7Bi7zeRM5QImta2NMpV+ +x6OXdlQQj3QDeBivmuilL47ckENYCk/t/Tu/jCpuPCYC/syuqvq1g2cCgYEAwBSh +LFcJG4xFtClnrPMljbqlDeQi4S02HYl10+8JpVbjEiPmLTbzkDSAUfm4UR+ZblJX +CGvd/TLDy1iju0/m1X+Ddb2lMUMAeihuZjaWxbpaX7UU0QUNR8jcvxQrkcFqc/hy +dpw3aVD+74+WcYoN+YqG9KGp/7N7Lkpnxslk7IECgYEAksozfpNIf1gV9oYam9rY +fAD+KMmkItt8yxseQCwdCyeP5QxoGY7+m6aMHjK6Uoi/YOnEsy+x6MoXE3Yq1w8s +1883XPg8iTIX6bDA7sFiVwD0WUSEHYcGCfF0VSYg+sq5nc78dJ64oS+4vjkG2HoQ +TpZkF7zzL8+z+bdyb8G+Ij0CgYAYdggYd3UHdxOhX+x+D/DmXbCLVlRCzNkpZcoF +lVlrHueH9d5oP6lA4g69YcnhOt71N7MxtVrt1bsteDpRrlk9MyHwqpgQ7/FtnRyC +E82bnKHJsmvWOoh4bdH+23i49SKzZh5dkINV/CSbKXQFPYmOD+Aj4zqc/6RePsd8 +f0VFAQKBgEqkOim1lmubatkYFy/Gm/3sr4pn/ENkPig7nRZDbSqzSOV50/rU40O2 ++QA5mhUyzQuG8S8zPhXWBnJt1T+O57uDRjEtVw9zrgZw/MddXEHILVjK0qtIR1Hq +opEzRRcl+keFb2N5YXfOd+qsnBl0UklTCOo/E12PwhaW6nX6HujD +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/normalCertChain/CA/secondca.crt b/test/opensslcrt/normalCertChain/CA/secondca.crt new file mode 100644 index 0000000000000000000000000000000000000000..1de5943ff08d32b36bd333221eb5f15cd0911433 --- /dev/null +++ b/test/opensslcrt/normalCertChain/CA/secondca.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIBATANBgkqhkiG9w0BAQsFADB4MQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxCzAJBgNVBAcMAk5KMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNV +BAsMA0RldjEPMA0GA1UEAwwGUm9vdENBMR8wHQYJKoZIhvcNAQkBFhByb290Y2FA +d29ybGQuY29tMB4XDTIyMTAxMjA5NDIwOFoXDTMyMTAwOTA5NDIwOFowbzELMAkG +A1UEBhMCQ04xCzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsM +A0RldjERMA8GA1UEAwwIU2Vjb25kQ0ExITAfBgkqhkiG9w0BCQEWEnNlY29uZGNh +QHdvcmxkLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANcpo7lG +5ttiS8yJtzxGVlO7lA0RqODL1xO84ES+jE3CPLCVo83l95E+ZHQZZWLyAr/DrGCm +rP93+ODINGkMvLZlEV+vYcxzMZeRwctITygaSaxV6bON9m2F3k1iQZS/gKkL9L1i +/F3nkOAEuCYMgwHfPy+/E97H9ECRpwt2mzld40UQZj1nQY1nc8l5F10p5GSiNp5w +98N/Fnk9ZCESUKtPviK541ljvs+tvgSvuKtJJIbdhbn+xW5O9hQE79u9iumo1zBS +jZKMMMdRqTIaqt/0Ofqu39g/OsfeLCfGr+NF9suTaeJEMFPMzOhSdgy9zf+fykVJ +GuXxfedLzbgP06kCAwEAAaNTMFEwHQYDVR0OBBYEFMeDrbiKqXhHPcdp/CNpS9Vn +/MTfMB8GA1UdIwQYMBaAFFcYHLLk3hRuI/Q1xMpamPbIOwnfMA8GA1UdEwEB/wQF +MAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGDOfMIV1AC5bdbpnhSDhpzkBjh3OXqj +PUplhCm+VSFUcSE5vTe4EnUsCmpIAYkAA/2fnr6Kb9GFJeuoWlWpo4vRW9g8qvRD +B7AoV/ojafPb4V2epZLaHpJQcxyqplGFKqPKDAXZlXLXUbWLAutq8q7GLL3MlqDF +nRDHBafXEHJOBQG7HyTDt20uoIM0+10ehJW0LlO2UD/SQ5ZEkYwB5fNk4rJMgtmb ++Nibo2x2ZdMd0rakv5CGCjf13aMczngYcbNjnUFjiLofHTEcoYmDVTqjd4W416Fj +2R+WfdalPEctoRCjDMi7dZLxQOY7/IesZbXJkVMtbP8lrVb+9Xl7lj0= +-----END CERTIFICATE----- diff --git a/test/opensslcrt/normalCertChain/CA/secondca.csr b/test/opensslcrt/normalCertChain/CA/secondca.csr new file mode 100644 index 0000000000000000000000000000000000000000..012c84a927e1af42c1e6379275991e853c7d7ee9 --- /dev/null +++ b/test/opensslcrt/normalCertChain/CA/secondca.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICwTCCAakCAQAwfDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQH +DAJOSjEPMA0GA1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxETAPBgNVBAMMCFNl +Y29uZENBMSEwHwYJKoZIhvcNAQkBFhJzZWNvbmRjYUB3b3JsZC5jb20wggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDXKaO5RubbYkvMibc8RlZTu5QNEajg +y9cTvOBEvoxNwjywlaPN5feRPmR0GWVi8gK/w6xgpqz/d/jgyDRpDLy2ZRFfr2HM +czGXkcHLSE8oGkmsVemzjfZthd5NYkGUv4CpC/S9Yvxd55DgBLgmDIMB3z8vvxPe +x/RAkacLdps5XeNFEGY9Z0GNZ3PJeRddKeRkojaecPfDfxZ5PWQhElCrT74iueNZ +Y77Prb4Er7irSSSG3YW5/sVuTvYUBO/bvYrpqNcwUo2SjDDHUakyGqrf9Dn6rt/Y +PzrH3iwnxq/jRfbLk2niRDBTzMzoUnYMvc3/n8pFSRrl8X3nS824D9OpAgMBAAGg +ADANBgkqhkiG9w0BAQsFAAOCAQEALBW59/ZzFd97b5jzqamQnkKH2fN/kk7+vfu8 +0FiHN2liCnaHAa7+zlxch8XZY/LdQWdBcTtOMQgTz8dEuHsQaAxT4dLTTm9rs70w +QoxoGLy7okbvGKyhxzJM6BHJVDzaq2AXMtB1BlI+9DFBmwxbpDQyqtc0XaABBkV5 +GahaE2WAP1t3LM+JDOdJ+5VLSNIhneJrFR465HmHaVSVe1ivD3tk7394DwcLPOE4 +qsuIP3nQIzMFKyzyaMbaKXNd1mU3SfitOQuckm7oycexVgtd1oU6kZs0PF8q6e1Z +tcAxZqbYllNZx1cHyg/lu59+gPOvH3FB6LWiloL3usrHLm9pgQ== +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/normalCertChain/CA/secondca.key b/test/opensslcrt/normalCertChain/CA/secondca.key new file mode 100644 index 0000000000000000000000000000000000000000..803c11daf385b85f69e000111c71a08968d0bced --- /dev/null +++ b/test/opensslcrt/normalCertChain/CA/secondca.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEA1ymjuUbm22JLzIm3PEZWU7uUDRGo4MvXE7zgRL6MTcI8sJWj +zeX3kT5kdBllYvICv8OsYKas/3f44Mg0aQy8tmURX69hzHMxl5HBy0hPKBpJrFXp +s432bYXeTWJBlL+AqQv0vWL8XeeQ4AS4JgyDAd8/L78T3sf0QJGnC3abOV3jRRBm +PWdBjWdzyXkXXSnkZKI2nnD3w38WeT1kIRJQq0++IrnjWWO+z62+BK+4q0kkht2F +uf7Fbk72FATv272K6ajXMFKNkowwx1GpMhqq3/Q5+q7f2D86x94sJ8av40X2y5Np +4kQwU8zM6FJ2DL3N/5/KRUka5fF950vNuA/TqQIDAQABAoIBAQDL1t8NQGaloNI+ +zJmTuO9AFI2GdBySG4t/X4j4l61EXagxgxLUlfGc4Ic6lnS+8Jg6JJ7CUiXDQV2/ +VuyQOUjvY4C6LeVxVBC/j48Rj0eurnjtk9b8DJpR2Glq1pNa4LJ7dKBAa+666A8Q +rGfpZCEZPO8XxOaGQNjd8x9WdN9J0DQ1oSsvRC8xLyuKK1jrxrIaXMK1XZRgaNMC +atr3fCo+yChyuzsSKHdscq+2MA6xnp3dc8Q+nwhx1xFSdfah+NCqGfmVYJRDgyCT +GuZCkyi8R1+jkGs4dO/cBTTRlkiWHHFLJo+/c8vB4Do1JkA93PoSYa6G4UAUJA3N +B1e16TihAoGBAOttoHES6CpCiZdOtTTk//zyWxrG8CjnHRCj8g39QHnAYyzlz88h +RYuyfrTPviGuE3+3CMWsjvPag1E20NpyrsJtx3OBhcDG6fj6l7j1vuTngl+mlM9K ++n36BcClTAnY3JZf51FErmUYkTVh9RcQ4iycV7cl65OouNwGVc90WCpPAoGBAOn2 +r3OIC7bal0kxrI0bT65gkbvg3q1sA1cpcu4HOZJUK7RjdYzsd7+lCsSPfVyp9yN1 +iw3Bn2umEkkogxy7wShT27Ng2LNyHRrYhVIuEee/+SSZgMDa+0cTOdKPqRVITffb +weM4Z3gVrc27EtWuc9vGUkNxoceJ3MKHpesDKTyHAoGAMSoFlVdzcE/Q1+4x3Uft +RW9/IwpkYMZSxYTXKaC3dDV/AINFcGXsVg4Cc9PmSrZFkCgzBsTQXZBGWBFwcA3+ +/M9cFXz455ciiUIbqR54rOjDyyHIdbmcse4igWaDiJLnDegdMFV9bdNBj7pTKmv2 +L4a+spqSpZVYdWpFRTtwpfUCgYEA6b4F7bOWmHlsubiB/nuxsLJUBtMTRUlrUPJd +G0dmkjW7cD4Jm+BHhtTZnCULBr/b47Y0VWsC3aaOED8ENnmx8ZtOHLj95tF0GHUH +RWI3i0Q1IgamJobgklK36xCRyWxyUNVhsKOSY9usx6RFnevrXj+VwkHNci/euQ6S +ieeflBMCgYEApVqc4NqEuS4i//GGdQ9fXPx9eXBOxnqjwXL4SUdhwSruBYPIDavD +Vb66fdGQU7wAeVQczeIZEHDFuLkePvbwPeQmdDgJBN88lB1TjxvwhYvOBz0cJbl0 +NN8bkkMUsjZ93sZaVKgKYf7Q2JEwBRdagmWktcx9omJYPs+GAxIdVnE= +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/normalCertChain/client/cert.pem b/test/opensslcrt/normalCertChain/client/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..8ac18d1c638b4709fb4eeea16a4f403d32b13198 --- /dev/null +++ b/test/opensslcrt/normalCertChain/client/cert.pem @@ -0,0 +1,92 @@ +-----BEGIN CERTIFICATE----- +MIID1jCCAr6gAwIBAgIBAzANBgkqhkiG9w0BAQsFADBvMQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MREwDwYD +VQQDDAhTZWNvbmRDQTEhMB8GCSqGSIb3DQEJARYSc2Vjb25kY2FAd29ybGQuY29t +MB4XDTIyMTAxMjA5NDY0NVoXDTMyMTAwOTA5NDY0NVowcTELMAkGA1UEBhMCQ04x +CzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsMA0RldjEXMBUG +A1UEAwwOdGVzdGNsaWVudC5jb20xHTAbBgkqhkiG9w0BCQEWDnRlc3RAd29ybGQu +Y29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA5evH/YQgDHryEt/6 +OV1ALabNSIGgZjLe0Rjouo684vPgROD6qLnfDKZO2llfkupWOtyQJvk29giiNa20 +gtnEinhr6xemJ8PfYxXoTQ0pDCJQQ+bozuJtjQ6D5+IdECe+rbecxEixIMJiUIWL +7Pu5wDrqHKI95WJr3XDil86+jtXuMB6U4Do3RP7KW5MFaoyn9mrQIKUr2Q6ssQ4E +UQaYBa5KyUfHv21HspK9x1T7uLWGZ7PeU/DVfJQdLk9YOldyIVwrtAC4kla/KSVW +R9lMpSm3kLq84Z1DCky0wnizQ0eG+GbSI8hJAFn3bwfBmu99RDitlz5GYY//mD5z +eDXpHwIDAQABo3sweTAJBgNVHRMEAjAAMCwGCWCGSAGG+EIBDQQfFh1PcGVuU1NM +IEdlbmVyYXRlZCBDZXJ0aWZpY2F0ZTAdBgNVHQ4EFgQUpij+s+IWcrTecYoQvnlc +kJoxfAswHwYDVR0jBBgwFoAUx4OtuIqpeEc9x2n8I2lL1Wf8xN8wDQYJKoZIhvcN +AQELBQADggEBAGtB6+WR4yZt8DuYJ7SQ4gsEBAo6BEC/lX6Nv4CeSRKBz8RDGuN5 +zAfJtC9k9c2l03tAYQQnwkYiqsEQoN6ojz0JckmQb14y3dj7lClgmMCBy6BPjkDL +85dz/4NFR2K4VxJ8owYaNjBCQFp1oXRq+wcluWseejhoJIJd49H14gkNcVg2j1ep +cRdK4kZL3g/b6ub11NUvt+2RfrQQUCp43FpK86fLd1dEnecyFCgFOyH+larPUPYt +7ftZWkLKih/T9kO1UHmaXb4uZ0VZg5+t6Ut/jeD3d8QnHMftL/TKOR+tihwlEZmW +TyObahHmz0SMWQfUIW3hfNDuQzPybKnycC0= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIBATANBgkqhkiG9w0BAQsFADB4MQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxCzAJBgNVBAcMAk5KMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNV +BAsMA0RldjEPMA0GA1UEAwwGUm9vdENBMR8wHQYJKoZIhvcNAQkBFhByb290Y2FA +d29ybGQuY29tMB4XDTIyMTAxMjA5NDIwOFoXDTMyMTAwOTA5NDIwOFowbzELMAkG +A1UEBhMCQ04xCzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsM +A0RldjERMA8GA1UEAwwIU2Vjb25kQ0ExITAfBgkqhkiG9w0BCQEWEnNlY29uZGNh +QHdvcmxkLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANcpo7lG +5ttiS8yJtzxGVlO7lA0RqODL1xO84ES+jE3CPLCVo83l95E+ZHQZZWLyAr/DrGCm +rP93+ODINGkMvLZlEV+vYcxzMZeRwctITygaSaxV6bON9m2F3k1iQZS/gKkL9L1i +/F3nkOAEuCYMgwHfPy+/E97H9ECRpwt2mzld40UQZj1nQY1nc8l5F10p5GSiNp5w +98N/Fnk9ZCESUKtPviK541ljvs+tvgSvuKtJJIbdhbn+xW5O9hQE79u9iumo1zBS +jZKMMMdRqTIaqt/0Ofqu39g/OsfeLCfGr+NF9suTaeJEMFPMzOhSdgy9zf+fykVJ +GuXxfedLzbgP06kCAwEAAaNTMFEwHQYDVR0OBBYEFMeDrbiKqXhHPcdp/CNpS9Vn +/MTfMB8GA1UdIwQYMBaAFFcYHLLk3hRuI/Q1xMpamPbIOwnfMA8GA1UdEwEB/wQF +MAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGDOfMIV1AC5bdbpnhSDhpzkBjh3OXqj +PUplhCm+VSFUcSE5vTe4EnUsCmpIAYkAA/2fnr6Kb9GFJeuoWlWpo4vRW9g8qvRD +B7AoV/ojafPb4V2epZLaHpJQcxyqplGFKqPKDAXZlXLXUbWLAutq8q7GLL3MlqDF +nRDHBafXEHJOBQG7HyTDt20uoIM0+10ehJW0LlO2UD/SQ5ZEkYwB5fNk4rJMgtmb ++Nibo2x2ZdMd0rakv5CGCjf13aMczngYcbNjnUFjiLofHTEcoYmDVTqjd4W416Fj +2R+WfdalPEctoRCjDMi7dZLxQOY7/IesZbXJkVMtbP8lrVb+9Xl7lj0= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDvTCCAqWgAwIBAgIBBDANBgkqhkiG9w0BAQsFADB4MQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxCzAJBgNVBAcMAk5KMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNV +BAsMA0RldjEPMA0GA1UEAwwGUm9vdENBMR8wHQYJKoZIhvcNAQkBFhByb290Y2FA +d29ybGQuY29tMB4XDTIzMTAxODA2NTYxOFoXDTIzMTAxOTA2NTYxOFowdzELMAkG +A1UEBhMCQ04xCzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsM +A0RldjETMBEGA1UEAwwKU2Vjb25kQ0FFWDEnMCUGCSqGSIb3DQEJARYYc2Vjb25k +Y2FleHBpcmVAd29ybGQuY29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAuX/+vVcic13kBJy4ZmhSxIaQRaA2ElVS5WTZ4PG8BsMRAhtcooz/SKF+4NWo +8bbI/Djf9T5349zHDrYkVwJW4jTwzkZjI/aIEKwFqgotbH2K3a6AcY9mHvnOvfsY +1xPC0a/438y1SkmJRTsfMKdUYKcCJ4/NiQrVZ5Z9lKb9x/Zo9Rhm7+PJL6MH3Mdy +vUDMLRLsUyDZTH4ofoafPbXFLtq4eTzzdKxgXrixT+4M2aHbPS8BDsDxndwzse6G +vXwZ0LLGJ65LqlVwHbBZDtZeWoNw+q9Bhhxk0Cf/qd9Ck2hMwj/YV/Ihiz9BFcTt ++rZhVkumWkAuZYkD9xkLdzNvewIDAQABo1MwUTAdBgNVHQ4EFgQUwMtjuRuv+cY1 +Ea8kWrf8MrQdlvAwHwYDVR0jBBgwFoAUVxgcsuTeFG4j9DXEylqY9sg7Cd8wDwYD +VR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAmMTLv3E672GVkCtioYjG +oh+taqXMyskc9vCONK0CuEWStgSDPdzCZNz8uoX+dCPmgYxeju+98c1cG9Rjg2Fv +xPqWi/TKes9fy3VBYkZFSU72tu1fXBXhoIFO34Pmq98wsZmzNgN6bfT8f9tDcif1 +UUmHkO6L4L9ZP4GwqgHCzTKWuzbbvNnWpFKXfbzyhESGP6H18RdtWSWg8i1zKeVI +d26GnNHNLxAi92rKQaCAN5OqsmhmBGehdfNdmkIdJ/Kxk7LYbhn+3KLARUIwnS4x +VCN10ITB8V7692SD9pF7BZQSkW5K7GakfI5L5RBxIej4awVCzMrDD4tl9aQriEKw +FQ== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIID0TCCArmgAwIBAgIURA/lUFrZwOfJXaaLF9ol7wkhAxIwDQYJKoZIhvcNAQEL +BQAweDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQHDAJOSjEPMA0G +A1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxDzANBgNVBAMMBlJvb3RDQTEfMB0G +CSqGSIb3DQEJARYQcm9vdGNhQHdvcmxkLmNvbTAeFw0yMjEwMTIwOTM3MjFaFw0z +MjEwMDkwOTM3MjFaMHgxCzAJBgNVBAYTAkNOMQswCQYDVQQIDAJKUzELMAkGA1UE +BwwCTkoxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MQ8wDQYDVQQDDAZS +b290Q0ExHzAdBgkqhkiG9w0BCQEWEHJvb3RjYUB3b3JsZC5jb20wggEiMA0GCSqG +SIb3DQEBAQUAA4IBDwAwggEKAoIBAQCgvTn1zACCZ4uny3BW8Utilwoztkhb/XM+ +ZI/trCZg1smsnuCNJHyIJVoFz4PoxXESCueTD0UwIcrftvzQPJZzVZEOY/ND4ZFq +Bj4TbCybSNuFIIAXn2yL1x5oLGz5wuEr7XClqUECVZPTyDv2ozg7+L6NRXNnQ3DQ +jL3QqEaH0M0hA/4FX7ySXrSC2BFX5LZzv8cjKla+3jqJUbUxokxEWMfYVNU+JjwV +vS4ieVqIsWcfd+FhqFCpvWj92PpJB5Lk6GkuGi026lgYutK30Gx133QLQjrNRRLG +4dM0KMevpSM1Ug7dXy60dIjlJkFjekEf7umCGFNILf3UKQTd1irnAgMBAAGjUzBR +MB0GA1UdDgQWBBRXGByy5N4UbiP0NcTKWpj2yDsJ3zAfBgNVHSMEGDAWgBRXGByy +5N4UbiP0NcTKWpj2yDsJ3zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA +A4IBAQAaLNCcpUpKjxz3DKPU632uqljaliu5mCm9OAQRDLccfZifIHacyxx45rm7 +8TOSgQnSS41UtsvTMn63LXqWCi+P0K6LVZUDAiQzcehGsKYyj7yG8UiQev3qUXuU +HRD+n92AVKYz4ABkKqepgVqE1G0CQeFISU/czTv4r3+Qv5IGXWuZ9D9fNu2Al8DG +Zo4FdcxsBESksDP07Gjj3zMRgV0uhsoSP7gEc2zrucfSIZasEuIgPGL07vMc4Ybm +EaY/cZcuSj+o012wmP6YMQw4t74Fxqe41l4lY/yPtgUO9LDx9ygiDzEKaMCTea8M +33Oqg/tD4bzToCPlpeh7qhpUfLpP +-----END CERTIFICATE----- + diff --git a/test/opensslcrt/normalCertChain/client/client.csr b/test/opensslcrt/normalCertChain/client/client.csr new file mode 100644 index 0000000000000000000000000000000000000000..ff6829e03e04ae02ef15b910f4f3b44e4748f0d1 --- /dev/null +++ b/test/opensslcrt/normalCertChain/client/client.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICwzCCAasCAQAwfjELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQH +DAJOSjEPMA0GA1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxFzAVBgNVBAMMDnRl +c3RjbGllbnQuY29tMR0wGwYJKoZIhvcNAQkBFg50ZXN0QHdvcmxkLmNvbTCCASIw +DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOXrx/2EIAx68hLf+jldQC2mzUiB +oGYy3tEY6LqOvOLz4ETg+qi53wymTtpZX5LqVjrckCb5NvYIojWttILZxIp4a+sX +pifD32MV6E0NKQwiUEPm6M7ibY0Og+fiHRAnvq23nMRIsSDCYlCFi+z7ucA66hyi +PeVia91w4pfOvo7V7jAelOA6N0T+yluTBWqMp/Zq0CClK9kOrLEOBFEGmAWuSslH +x79tR7KSvcdU+7i1hmez3lPw1XyUHS5PWDpXciFcK7QAuJJWvyklVkfZTKUpt5C6 +vOGdQwpMtMJ4s0NHhvhm0iPISQBZ928HwZrvfUQ4rZc+RmGP/5g+c3g16R8CAwEA +AaAAMA0GCSqGSIb3DQEBCwUAA4IBAQCsGnU+byNkW1GA8XNiya/xcwFhBrgZ8ytk +neqb0s1VdcTpM0F3OhnaYsBbAuWvuuXnKX2fUGD1TPiB/cuiOULxMkR2FQvA6l6g +pOUTRP1sHsUbAB1nLEZ6ZNmYISLY5e52265SFGw8moQ29TEXjg1Fpsgpiol5Xc9O +pf8ojhCgKmCgUgMd48N11BPzTB+8vwiBG7c+pJr3RU+g/FjtGXTE2g7GqWDAvcaC +4F7L+PBtlocjI4K2ci1yEsFYXoa0ZmOZjLvTOS7A3hDFM1ga1pjTh1hejGwC+d1u +uaFXa2XJa89jimpZH609gAv+6KlGXB5SVgfwVG9DjKMFmJJYVmpV +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/normalCertChain/client/key.pem b/test/opensslcrt/normalCertChain/client/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..d0e8e05724fdfadc4ea309d388bac03c07e18029 --- /dev/null +++ b/test/opensslcrt/normalCertChain/client/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA5evH/YQgDHryEt/6OV1ALabNSIGgZjLe0Rjouo684vPgROD6 +qLnfDKZO2llfkupWOtyQJvk29giiNa20gtnEinhr6xemJ8PfYxXoTQ0pDCJQQ+bo +zuJtjQ6D5+IdECe+rbecxEixIMJiUIWL7Pu5wDrqHKI95WJr3XDil86+jtXuMB6U +4Do3RP7KW5MFaoyn9mrQIKUr2Q6ssQ4EUQaYBa5KyUfHv21HspK9x1T7uLWGZ7Pe +U/DVfJQdLk9YOldyIVwrtAC4kla/KSVWR9lMpSm3kLq84Z1DCky0wnizQ0eG+GbS +I8hJAFn3bwfBmu99RDitlz5GYY//mD5zeDXpHwIDAQABAoIBAQCnLeM0Nm8rQ/Zi +vRNvxJtW1nNr5j1gMlsLxUXr6L/1cgi/bKs2JjjGNOMfJ180L0pV8Gysugc5rJtt +1olrn7amTNuDjKWXQnhazuIjrI8NMKIWTX84dzHbIBPPdv1U8uFV5S2LF6QbwtvD +2uccgQjWesAh4+KHuSHfWSaZ5Y1vw0MGNCRkyVx13ovDM0gjT694aI6y0Qd/mXm4 +oAvTZCU7tJ/45nSDiWqSQdOJJdQjV0nOoKj0Fpcr7mF11QYALjXmC04YfO5DYaCd +O4wQxZYGzSWyYoHLFonlc5uL1M/ZYP6n2HPvHCeyvJY6q+kTA3OmhnOIKwoCpObb +FV3JETmhAoGBAP4Kji7Ig2bNK8oYNm6sUFJR1tnLSHIkSmNCbpJ4vQE0RBIU1/US +NrrUBAWm6AppFZPsUsLc42O96GOv/P4tOioKsFUJCi/w59A3Xt04tapHBTEfUiJy +bOLIC1WYEhsZ4u1e1u6hfvyAuJVYLqkiPVsRLp2QV6PUP1cSyNY3mHEPAoGBAOex +nZmytTkL/na/aQNKPEFQaBvQJQUKw7tWcwZyjgm4Ar03ojYw04TkuP0UHw+M0R+M +rjoeuGrnOktjwlMlmt8rMRxO8WeVr1tA5U6658tzt2Rprq1INkvVeLoBtGP29FqM +LI2YG4uKou43LU+jgPrkLkEt0bhb/mAdhHrR4ObxAoGAdgKZQgpLYDn3GY5d2tOZ +DGSQFeRk5wEMvUdi7g/AXQrWhD/CgknPusI6jBWYvR1LtMeXOoY561+Q0J40PC7u +UhFdEGN+o/6Y8RSHsORjH5KWStdt5Cqbgk3DViOqZYSE8heYaIoE3288T8QDCPaq +4d79dJxU2foC4oQLX9e7rOkCgYADaV0dt0Dt3xxXGUhtkPlEKO/vgOgao+bv6jz1 +Wlh3EiuQJ7KOw7dJnKiQqWwvqW4m3cZu+qbShCcalxR0bvhR0uv9M7hgQxb67AC0 +YRIqr8CCjP/Sc17BTRpi+sVyN1+vuaKqTxQQwPDXOx7CrnCmwRdhRFBzO3+KYMTj +nhWGsQKBgFjQkaxPXGpNaMOGoEzNir/4gvGsxPICUEACfA/HAAaTZm23Ly5+ThGO +Pvg4hqY1C1T+ksn0pGIlkGRi3Tq/QnX4NyxouskQHGswDq6wVGYHu00HizmHgYhx +XXuYhcynXchaHnpHNVxuq85cw5LWs3rPwkH50Kq4jnJ7wGhQLVB9 +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/normalCertChain/server/cert.pem b/test/opensslcrt/normalCertChain/server/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..12664c29bfda75fce75584e7bd8f2da99a396085 --- /dev/null +++ b/test/opensslcrt/normalCertChain/server/cert.pem @@ -0,0 +1,71 @@ +-----BEGIN CERTIFICATE----- +MIID1DCCArygAwIBAgIBAjANBgkqhkiG9w0BAQsFADBvMQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MREwDwYD +VQQDDAhTZWNvbmRDQTEhMB8GCSqGSIb3DQEJARYSc2Vjb25kY2FAd29ybGQuY29t +MB4XDTIyMTAxMjA5NDQxM1oXDTMyMTAwOTA5NDQxM1owbzELMAkGA1UEBhMCQ04x +CzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsMA0RldjEVMBMG +A1UEAwwMdGVzdHNlcnQuY29tMR0wGwYJKoZIhvcNAQkBFg50ZXN0QHdvcmxkLmNv +bTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKfbo1HR8Hxv3qMqNGh1 +ufPv+Hb2UA8uTzyMPv1gG+331X2sQUqbv15jEuTLcjxoOYuCOmiFQPHJpj+I75cP +RrIuscOj8BUnWJlpejaPaICS2XOwzMG5Ae5rUkPoLyVcjhlu6Y3A4kFYaKvH8+58 +XwI1FM8Obswe+XWnkZIj4C5OChYIj71Zh2IS2zpPQjlnYkAA3FJmK7LKocMubjZO +SlhvIwzbBI9069IgEOOrg1xKrlm1gjbmfGFlXMYi9zLEm0SO8dhU5cWLY8QQKYkd +uEF9hurwe+fspGRvTPhbT7IR8yW3xw7kah7TUDs5b3jM6Me9U0M7frb23s7wT31p +2mMCAwEAAaN7MHkwCQYDVR0TBAIwADAsBglghkgBhvhCAQ0EHxYdT3BlblNTTCBH +ZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0OBBYEFC7exNc3BjDBOciBzo5qVY+d +qPKIMB8GA1UdIwQYMBaAFMeDrbiKqXhHPcdp/CNpS9Vn/MTfMA0GCSqGSIb3DQEB +CwUAA4IBAQBq2XwbLsO8cImd1dOu05V6U0zW+/QM9WjP94hVUJiyQ+yexHNLKcvP +x+2G44n4/qUr80tvgHr1/ok9yQzaL+CKQx/hLYs45IGElLcsPU2tV0QZFOlcUgmu +mZODjuTUZzH5oZMcVuPm1IiAV7TFhn7Z0q4I+tticZnPLsdW+bkPQQ2ZrJodmbBb +85vWb92qi9kWUlBPaLx7ntvP7FMOWQYPyWpQE2ipV3vIUqXUxmXBjanHjOxIjMUE +DpRpJkrpPUrznj/1/jYpUQR7ohEsRSbNzKgJzKIgm2QvT9C7cK40bsp9eX3VSEYU +f1JHV5LnmGU3073GUBamF+OOlxWyT47A +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIBATANBgkqhkiG9w0BAQsFADB4MQswCQYDVQQGEwJDTjEL +MAkGA1UECAwCSlMxCzAJBgNVBAcMAk5KMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNV +BAsMA0RldjEPMA0GA1UEAwwGUm9vdENBMR8wHQYJKoZIhvcNAQkBFhByb290Y2FA +d29ybGQuY29tMB4XDTIyMTAxMjA5NDIwOFoXDTMyMTAwOTA5NDIwOFowbzELMAkG +A1UEBhMCQ04xCzAJBgNVBAgMAkpTMQ8wDQYDVQQKDAZIdWF3ZWkxDDAKBgNVBAsM +A0RldjERMA8GA1UEAwwIU2Vjb25kQ0ExITAfBgkqhkiG9w0BCQEWEnNlY29uZGNh +QHdvcmxkLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANcpo7lG +5ttiS8yJtzxGVlO7lA0RqODL1xO84ES+jE3CPLCVo83l95E+ZHQZZWLyAr/DrGCm +rP93+ODINGkMvLZlEV+vYcxzMZeRwctITygaSaxV6bON9m2F3k1iQZS/gKkL9L1i +/F3nkOAEuCYMgwHfPy+/E97H9ECRpwt2mzld40UQZj1nQY1nc8l5F10p5GSiNp5w +98N/Fnk9ZCESUKtPviK541ljvs+tvgSvuKtJJIbdhbn+xW5O9hQE79u9iumo1zBS +jZKMMMdRqTIaqt/0Ofqu39g/OsfeLCfGr+NF9suTaeJEMFPMzOhSdgy9zf+fykVJ +GuXxfedLzbgP06kCAwEAAaNTMFEwHQYDVR0OBBYEFMeDrbiKqXhHPcdp/CNpS9Vn +/MTfMB8GA1UdIwQYMBaAFFcYHLLk3hRuI/Q1xMpamPbIOwnfMA8GA1UdEwEB/wQF +MAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGDOfMIV1AC5bdbpnhSDhpzkBjh3OXqj +PUplhCm+VSFUcSE5vTe4EnUsCmpIAYkAA/2fnr6Kb9GFJeuoWlWpo4vRW9g8qvRD +B7AoV/ojafPb4V2epZLaHpJQcxyqplGFKqPKDAXZlXLXUbWLAutq8q7GLL3MlqDF +nRDHBafXEHJOBQG7HyTDt20uoIM0+10ehJW0LlO2UD/SQ5ZEkYwB5fNk4rJMgtmb ++Nibo2x2ZdMd0rakv5CGCjf13aMczngYcbNjnUFjiLofHTEcoYmDVTqjd4W416Fj +2R+WfdalPEctoRCjDMi7dZLxQOY7/IesZbXJkVMtbP8lrVb+9Xl7lj0= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIID0TCCArmgAwIBAgIURA/lUFrZwOfJXaaLF9ol7wkhAxIwDQYJKoZIhvcNAQEL +BQAweDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQHDAJOSjEPMA0G +A1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxDzANBgNVBAMMBlJvb3RDQTEfMB0G +CSqGSIb3DQEJARYQcm9vdGNhQHdvcmxkLmNvbTAeFw0yMjEwMTIwOTM3MjFaFw0z +MjEwMDkwOTM3MjFaMHgxCzAJBgNVBAYTAkNOMQswCQYDVQQIDAJKUzELMAkGA1UE +BwwCTkoxDzANBgNVBAoMBkh1YXdlaTEMMAoGA1UECwwDRGV2MQ8wDQYDVQQDDAZS +b290Q0ExHzAdBgkqhkiG9w0BCQEWEHJvb3RjYUB3b3JsZC5jb20wggEiMA0GCSqG +SIb3DQEBAQUAA4IBDwAwggEKAoIBAQCgvTn1zACCZ4uny3BW8Utilwoztkhb/XM+ +ZI/trCZg1smsnuCNJHyIJVoFz4PoxXESCueTD0UwIcrftvzQPJZzVZEOY/ND4ZFq +Bj4TbCybSNuFIIAXn2yL1x5oLGz5wuEr7XClqUECVZPTyDv2ozg7+L6NRXNnQ3DQ +jL3QqEaH0M0hA/4FX7ySXrSC2BFX5LZzv8cjKla+3jqJUbUxokxEWMfYVNU+JjwV +vS4ieVqIsWcfd+FhqFCpvWj92PpJB5Lk6GkuGi026lgYutK30Gx133QLQjrNRRLG +4dM0KMevpSM1Ug7dXy60dIjlJkFjekEf7umCGFNILf3UKQTd1irnAgMBAAGjUzBR +MB0GA1UdDgQWBBRXGByy5N4UbiP0NcTKWpj2yDsJ3zAfBgNVHSMEGDAWgBRXGByy +5N4UbiP0NcTKWpj2yDsJ3zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA +A4IBAQAaLNCcpUpKjxz3DKPU632uqljaliu5mCm9OAQRDLccfZifIHacyxx45rm7 +8TOSgQnSS41UtsvTMn63LXqWCi+P0K6LVZUDAiQzcehGsKYyj7yG8UiQev3qUXuU +HRD+n92AVKYz4ABkKqepgVqE1G0CQeFISU/czTv4r3+Qv5IGXWuZ9D9fNu2Al8DG +Zo4FdcxsBESksDP07Gjj3zMRgV0uhsoSP7gEc2zrucfSIZasEuIgPGL07vMc4Ybm +EaY/cZcuSj+o012wmP6YMQw4t74Fxqe41l4lY/yPtgUO9LDx9ygiDzEKaMCTea8M +33Oqg/tD4bzToCPlpeh7qhpUfLpP +-----END CERTIFICATE----- + + + diff --git a/test/opensslcrt/normalCertChain/server/key.pem b/test/opensslcrt/normalCertChain/server/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..71ec9e7e81db54ff9436879ed610102db7e881fa --- /dev/null +++ b/test/opensslcrt/normalCertChain/server/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAp9ujUdHwfG/eoyo0aHW58+/4dvZQDy5PPIw+/WAb7ffVfaxB +Spu/XmMS5MtyPGg5i4I6aIVA8cmmP4jvlw9Gsi6xw6PwFSdYmWl6No9ogJLZc7DM +wbkB7mtSQ+gvJVyOGW7pjcDiQVhoq8fz7nxfAjUUzw5uzB75daeRkiPgLk4KFgiP +vVmHYhLbOk9COWdiQADcUmYrssqhwy5uNk5KWG8jDNsEj3Tr0iAQ46uDXEquWbWC +NuZ8YWVcxiL3MsSbRI7x2FTlxYtjxBApiR24QX2G6vB75+ykZG9M+FtPshHzJbfH +DuRqHtNQOzlveMzox71TQzt+tvbezvBPfWnaYwIDAQABAoIBAQCPFoHAK5Au40YM +HNwT99cOBI/vCMTyS+2rlXnUj2r/jfZlbMMzkFSvZxEiC/NTXx0+uUKE+qKD+ftH +yblDMfh3x6otNcBgp+u0yt8tR04z2/qVzi6dLNJipQW5cWFPHfjb4VoiRjwYq/5+ +ALMFputufEVCw/Da+8R28OL8iqx9ixNN/bfxktbP5NqGkcEUun5wE+/0KnFzylcV +XyKIzFFun8U6P0oupjv+aoB12SET/X75Xgh3Elx5/Q9FedHRXGjFcam6NJhOG79s +G6IkyciKP7Pf1SYFX6b5ps2bo3oPKWVvJLGMlxSuxJpqmoTQOzochoZF8uak+hkQ +eHl07+TRAoGBANBEVtTyMkRv6dGgmRCPsqolhxsOkHHd5+YOThg5d3EDWRcsUidR +VzwU9/GMrF3grV7FuQk2GYqMRDjFg87+/0O0H6vANxXjNc6sv+dblOikbAJoLksr +245OeCgd383K5+f00pOHvd8KfrxwQPqarpvW2JORXHludgZRgx8FGulfAoGBAM5U +YSKDw+d/D6+PHazakLucZNc7hUUd8Gvh8EYdUxObX+Dr8UIjV/1qqYZzhPUNDpn9 +//cYq8ndAQ7gCTANESJc//ikqDBXHJ3AoMwpNX86ozNUyt6+eoQx2tizgkEMzndo +BVt7A3RjWjn/+bJJX6mIsjg8salQrMW1bOyECHl9AoGBAJwdfhFl87RFR7oxbktx +y/Wq59mqUzBnrPtQYc3a1ePLJK8wM+zxFjkdZraUQmikkJDoGcoD2aV3e3Qq6qDx +mJtBnDP8g85OYPkpmTht9/NjvOsY+Qq0N4I24+7+ZdM3dBr19BtOt09H6LSMWMkB +xj1fET2cyvrjiGk4FNfd1cx1AoGBAL+xdW1zrhbt3czl0lQ93Cnx615sVi0Y273f +dDQwGnck67c0fjlMTPuMlWPs/6IMN3yql50itrgdNFZ1nxOdkEW00bxYfkorJNML +nFkSEDncaLPQG4tGvN0E1KZwYJu/IjOd2Rxc9aC0jadFQt95e/8umSXWfdkostwc +6s3y/UyhAoGAJAG6S/k1Xg0TbVxeZvMlnmfEc7WKh1QCnsH9Va/VnPXlIlTWmHu/ +xBtBb0kC0IF5zWgHhSADz3SHDsYZ87SEqoQxx3BKzd+EU52RF2x74GQDY49mQlUR +zMnxYP3TUSEfkTuSpT8ZwuxQ7KSk5nREDzKIVfcxgFw/cmrW/eChT7w= +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/normalCertChain/server/server.csr b/test/opensslcrt/normalCertChain/server/server.csr new file mode 100644 index 0000000000000000000000000000000000000000..64037cc9b7f13aa23a7aec95eafe0b396a33d471 --- /dev/null +++ b/test/opensslcrt/normalCertChain/server/server.csr @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICwTCCAakCAQAwfDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkpTMQswCQYDVQQH +DAJOSjEPMA0GA1UECgwGSHVhd2VpMQwwCgYDVQQLDANEZXYxFTATBgNVBAMMDHRl +c3RzZXJ0LmNvbTEdMBsGCSqGSIb3DQEJARYOdGVzdEB3b3JsZC5jb20wggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCn26NR0fB8b96jKjRodbnz7/h29lAP +Lk88jD79YBvt99V9rEFKm79eYxLky3I8aDmLgjpohUDxyaY/iO+XD0ayLrHDo/AV +J1iZaXo2j2iAktlzsMzBuQHua1JD6C8lXI4ZbumNwOJBWGirx/PufF8CNRTPDm7M +Hvl1p5GSI+AuTgoWCI+9WYdiEts6T0I5Z2JAANxSZiuyyqHDLm42TkpYbyMM2wSP +dOvSIBDjq4NcSq5ZtYI25nxhZVzGIvcyxJtEjvHYVOXFi2PEECmJHbhBfYbq8Hvn +7KRkb0z4W0+yEfMlt8cO5Goe01A7OW94zOjHvVNDO3629t7O8E99adpjAgMBAAGg +ADANBgkqhkiG9w0BAQsFAAOCAQEAC2AVrLOyTNhAdoVMzdqlXNLBmuoKSJdJdePF +uM3jkNkgfV77opTDFVL2nYxTLddfUpYq8xMpqK2shXWz5nrjn+XbqVqDyP5F6oVl +Rp0EiTKPolvr6+qREnquF7AKRn6qZkSst3/QbdFJrIZ6FjfReFxR+8d+MkhdKcUL +hX0FD8/njwO6twXWqBADZrV8rCsfuIER8+nCVCo827J7ZPNtvli31aFEi1QXo1Em +9Azvn6EULZyLUdvgu5hANyXNRa0yTY+QGZ37lTHTAuogr7PwCW2PTr5AVLU3oaDA +KJT1hWaJgpSKRI9nFut7BVaGjrRQHNH+HTN12mduJFaIIDt2Qg== +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/serExpCertCliNoCheck/CA/ca.csr b/test/opensslcrt/serExpCertCliNoCheck/CA/ca.csr new file mode 100644 index 0000000000000000000000000000000000000000..d1daee988dbb89be209fd79b626f0e491ec4886e --- /dev/null +++ b/test/opensslcrt/serExpCertCliNoCheck/CA/ca.csr @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBsTCCARoCAQAwcTELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkdEMQswCQYDVQQH +DAJTWjEMMAoGA1UECgwDQ09NMQwwCgYDVQQLDANOU1AxCzAJBgNVBAMMAkNBMR8w +HQYJKoZIhvcNAQkBFhB5b3VyZW1haWxAcXEuY29tMIGfMA0GCSqGSIb3DQEBAQUA +A4GNADCBiQKBgQDFR8aCTgT27M5MDDi5cM/BlJ+kGVMCPGlNCbE7/50pCapaLQVh +q+NNl0InhuEb/zmmg7CVHboWiqdBiOcn/inoSZAGnZbstTX7LrhjAISu4wVQYpL2 +k7SMxecZx6M7XUeamasnZrOqgWVD/6wweTPka2fIy5OzU/kuMXCHF3jCMwIDAQAB +oAAwDQYJKoZIhvcNAQELBQADgYEAwMS2H9xmphiRpxO7YrOcHF4QmXIIYh+ibDCM +G3y5254jkNCOYMz18aHFb9LpQq7+8quxMlEsn7G+1G+YJnUIDr4c8TxaRP+LRzZW +mcJsiEx+bbH2uhVpyCY5HXSUFMvOiam+h3fULVcqtlKU2DQzW+8TBk8gMFop+4fh +FIM1HeY= +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/serExpCertCliNoCheck/CA/cacert.pem b/test/opensslcrt/serExpCertCliNoCheck/CA/cacert.pem new file mode 100644 index 0000000000000000000000000000000000000000..abb595145226b5346e6d1b182a1febf8f7a9f515 --- /dev/null +++ b/test/opensslcrt/serExpCertCliNoCheck/CA/cacert.pem @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICZDCCAc0CFHQEAvrAjyEw3R4xI3BMlSeRXPD9MA0GCSqGSIb3DQEBCwUAMHEx +CzAJBgNVBAYTAkNOMQswCQYDVQQIDAJHRDELMAkGA1UEBwwCU1oxDDAKBgNVBAoM +A0NPTTEMMAoGA1UECwwDTlNQMQswCQYDVQQDDAJDQTEfMB0GCSqGSIb3DQEJARYQ +eW91cmVtYWlsQHFxLmNvbTAeFw0yMjEwMTAwNjQ0NTBaFw0zMjEwMDcwNjQ0NTBa +MHExCzAJBgNVBAYTAkNOMQswCQYDVQQIDAJHRDELMAkGA1UEBwwCU1oxDDAKBgNV +BAoMA0NPTTEMMAoGA1UECwwDTlNQMQswCQYDVQQDDAJDQTEfMB0GCSqGSIb3DQEJ +ARYQeW91cmVtYWlsQHFxLmNvbTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA +xUfGgk4E9uzOTAw4uXDPwZSfpBlTAjxpTQmxO/+dKQmqWi0FYavjTZdCJ4bhG/85 +poOwlR26FoqnQYjnJ/4p6EmQBp2W7LU1+y64YwCEruMFUGKS9pO0jMXnGcejO11H +mpmrJ2azqoFlQ/+sMHkz5GtnyMuTs1P5LjFwhxd4wjMCAwEAATANBgkqhkiG9w0B +AQsFAAOBgQARkRz98PEu58ReSL5VSBDu41a+03zm9uj5AflnJfGFRTtdjpJINlgG +yA2t9d85oPrdUtVrkTACMQTQxMvn7TdmjeLdRokqymyMYOAlUY9A0xyXSLLo8S6G +7wrOPOUqiaBVOFahoo6FN0pItETgsPfa5iqTjUqpBifZ21xtBY7jqA== +-----END CERTIFICATE----- diff --git a/test/opensslcrt/serExpCertCliNoCheck/CA/cacert.srl b/test/opensslcrt/serExpCertCliNoCheck/CA/cacert.srl new file mode 100644 index 0000000000000000000000000000000000000000..7c8154c99bccbe1bd01b9ab008cdc58309702c80 --- /dev/null +++ b/test/opensslcrt/serExpCertCliNoCheck/CA/cacert.srl @@ -0,0 +1 @@ +5075711B813D68F7309883F4E543B8C264FDC9C8 diff --git a/test/opensslcrt/serExpCertCliNoCheck/CA/cakey.pem b/test/opensslcrt/serExpCertCliNoCheck/CA/cakey.pem new file mode 100644 index 0000000000000000000000000000000000000000..a63424c81a85f9557b463eb7748eb2d1ccd38416 --- /dev/null +++ b/test/opensslcrt/serExpCertCliNoCheck/CA/cakey.pem @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQDFR8aCTgT27M5MDDi5cM/BlJ+kGVMCPGlNCbE7/50pCapaLQVh +q+NNl0InhuEb/zmmg7CVHboWiqdBiOcn/inoSZAGnZbstTX7LrhjAISu4wVQYpL2 +k7SMxecZx6M7XUeamasnZrOqgWVD/6wweTPka2fIy5OzU/kuMXCHF3jCMwIDAQAB +AoGABel8vXLxGyVFmWnUWVpUH40Aq75Gio6c6T4dHZsvbodnn4Qx4RdxnGWoCd80 +583iQGc5534YkYxeLsyXgM9RGiNp93G+2FCmIqP7DKb2ewdVKragb6j/rmOUEqV7 +BlQhSFg7j/oJWA+wnNisSxtlayRZXAwDFk9tEaZxzg5QGmECQQD84vTa8AAcaODg +9tYAwg3jlfN361v224NSiQGoK4vuzoXvcmz4kzax3jPlXYpyoGKD2uz6ZJl9L6zz +WeiGXCKvAkEAx7WPc5/gq8XlXjhYE6+ZdkkJACRwMYdCYMxWfJPdbHAnwjeofqOR +ZRPEVZi2wB7oYw2W0MaFcCe8OPIOpjUJvQJBAM3nOTi+a0i2eEuTj0GEv2xL3rYI +c5O2veFI3omAE2q4y0LgZyuqKcF/S7/4fs7AGhaD/aoOmQ7d77MgxHwJrs8CQFHx +urHb2SlCY1Ic1m8Fb3eskiffy7VW5bKoRZiKZ6GWB4pWoimMXh1WVgo1Rk9gn0P9 +kHOL8bbKZx4KnF4whxUCQHVsGTDxAZQA31ZEG0QVJDAiRGpG5Adp/Yz48F0JuLzh +DUa9LVvkoV22x0N/207aQEYExmxfM44yestJElz1CSI= +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/serExpCertCliNoCheck/client/cert.pem b/test/opensslcrt/serExpCertCliNoCheck/client/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..8e2a0d41d5bddec2cf35d6801bcccfed72448ca6 --- /dev/null +++ b/test/opensslcrt/serExpCertCliNoCheck/client/cert.pem @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICZTCCAc4CFFB1cRuBPWj3MJiD9OVDuMJk/cnHMA0GCSqGSIb3DQEBCwUAMHEx +CzAJBgNVBAYTAkNOMQswCQYDVQQIDAJHRDELMAkGA1UEBwwCU1oxDDAKBgNVBAoM +A0NPTTEMMAoGA1UECwwDTlNQMQswCQYDVQQDDAJDQTEfMB0GCSqGSIb3DQEJARYQ +eW91cmVtYWlsQHFxLmNvbTAeFw0yMjEwMTAwNjQ1MjBaFw0zMjEwMDcwNjQ1MjBa +MHIxCzAJBgNVBAYTAkNOMQswCQYDVQQIDAJHRDELMAkGA1UEBwwCU1oxDDAKBgNV +BAoMA0NPTTEMMAoGA1UECwwDTlNQMQwwCgYDVQQDDANDbGkxHzAdBgkqhkiG9w0B +CQEWEHlvdXJlbWFpbEBxcS5jb20wgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGB +AJ6BNHeqVkLgwI9xHNYzAo1Omjd70pqOnpiKlhj5+RGETGPoKTRjTYKKShJHf9h6 +/HqPzVaY88wdngwMFpc8L6m4IazA7JCRFG5Zyel2rf4wJtmusN5Akf2Sf1Ehipmx +sA1NuXMeTQZDM8iMcnagsZJgmJJIK+DjgzC/RRcVUmj/AgMBAAEwDQYJKoZIhvcN +AQELBQADgYEAslmntRuNWFLzZ7gt8GlkOCiHeLG/YuaBxsx3UQ3iXxS0ZCXtd98Y +XDx2kwvsKlUQRIUQpyh5MPi2Pa5NEQHKgjNxuoi/Vfzzez7Fmjf13FP9881uRXl7 +/R9tGJh8nxS7/hRyrk8mvi5/DrsEVxGIMOdtOCzFayh1sNlDcgyDI9g= +-----END CERTIFICATE----- diff --git a/test/opensslcrt/serExpCertCliNoCheck/client/client.csr b/test/opensslcrt/serExpCertCliNoCheck/client/client.csr new file mode 100644 index 0000000000000000000000000000000000000000..2b2fa04943e644ca150d975c0017e8e2e9f10bc1 --- /dev/null +++ b/test/opensslcrt/serExpCertCliNoCheck/client/client.csr @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBsjCCARsCAQAwcjELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkdEMQswCQYDVQQH +DAJTWjEMMAoGA1UECgwDQ09NMQwwCgYDVQQLDANOU1AxDDAKBgNVBAMMA0NsaTEf +MB0GCSqGSIb3DQEJARYQeW91cmVtYWlsQHFxLmNvbTCBnzANBgkqhkiG9w0BAQEF +AAOBjQAwgYkCgYEAnoE0d6pWQuDAj3Ec1jMCjU6aN3vSmo6emIqWGPn5EYRMY+gp +NGNNgopKEkd/2Hr8eo/NVpjzzB2eDAwWlzwvqbghrMDskJEUblnJ6Xat/jAm2a6w +3kCR/ZJ/USGKmbGwDU25cx5NBkMzyIxydqCxkmCYkkgr4OODML9FFxVSaP8CAwEA +AaAAMA0GCSqGSIb3DQEBCwUAA4GBAG76gRwIr6oXzTsWkK67uaU/PzlnuUYVQ86M +maA4cY5J9X/t7hXOFcSpcKcSrVCgVydKV407yshjiiq4teYGOsx6fizPyyT+p28O +7mbssL03JtgVwJfJSXrUe7lQUuHFTeB99uwMfzYE9nqbhY5jJLknEpOoza3JbKFG +S6Ub/LL0 +-----END CERTIFICATE REQUEST----- diff --git a/test/opensslcrt/serExpCertCliNoCheck/client/key.pem b/test/opensslcrt/serExpCertCliNoCheck/client/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..a5b35b73169aa670c5de17b2cbd2249d47e829c3 --- /dev/null +++ b/test/opensslcrt/serExpCertCliNoCheck/client/key.pem @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQCegTR3qlZC4MCPcRzWMwKNTpo3e9Kajp6YipYY+fkRhExj6Ck0 +Y02CikoSR3/Yevx6j81WmPPMHZ4MDBaXPC+puCGswOyQkRRuWcnpdq3+MCbZrrDe +QJH9kn9RIYqZsbANTblzHk0GQzPIjHJ2oLGSYJiSSCvg44Mwv0UXFVJo/wIDAQAB +AoGABjVbX8CNRmtVP4kXrxAqbmyAv+GauzEQ5zqubGYGKnh7IYKrcoDdPVq7UgH4 +8PrIPui5C0ZMbldOpHwoul3CJBzk/uomkc9NEwZichqmbhS2cSehkBqh1WZexZqn +K+yDEtUMb1hbn9KuiH1SOdjLk25PhIM2v/nci+xBN8bf5UECQQDTI1IPSemcNlNM +48SlfCuA3QwJFW8MMhyeT0ci6HHbZ9ihoFFDhHJJsqiewteOVJR0sUEpgCpt64o7 +aHQTxzafAkEAwC7wxNWd5Ow8A61Zqb/DA2fvJeJ3dmZ1INkt2qd7vUAxeAONSBZy +b4vHSRCKIHt8YxU1hamzAfTRN9sJt0eRoQJBAKtcMU+jR2y3UqhG++TrgEtYHFqL +ANO/ICcEZNHaIf5WrCWRfiL0LeXOkLg5nsUvtEV5T0+la8pGrqxEvGEJADUCQGan +LBgXelVTEeNs7t3K/iGNnKIBy4nExH+dQe7vUxsNdN8EFq0QwGNwqCwQ15buHszW +AVa6BFHhMPfC2fe6FMECQGt2g0Z7tyYFvSKy3fgb4+txM6tyes8y6wA0GEMLCqlV +/vSjH3oSNLahwEq+rIhVyZRISS+q9fH1SKdYbZZUc0w= +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/serExpCertCliNoCheck/server/cert.pem b/test/opensslcrt/serExpCertCliNoCheck/server/cert.pem new file mode 100644 index 0000000000000000000000000000000000000000..56541b2837abca99e661240ec8268aef7cd91ae7 --- /dev/null +++ b/test/opensslcrt/serExpCertCliNoCheck/server/cert.pem @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICZTCCAc4CFFB1cRuBPWj3MJiD9OVDuMJk/cnIMA0GCSqGSIb3DQEBCwUAMHEx +CzAJBgNVBAYTAkNOMQswCQYDVQQIDAJHRDELMAkGA1UEBwwCU1oxDDAKBgNVBAoM +A0NPTTEMMAoGA1UECwwDTlNQMQswCQYDVQQDDAJDQTEfMB0GCSqGSIb3DQEJARYQ +eW91cmVtYWlsQHFxLmNvbTAeFw0yMjEwMDcxNjAxMDBaFw0yMjEwMDgxNjAxMDBa +MHIxCzAJBgNVBAYTAkNOMQswCQYDVQQIDAJHRDELMAkGA1UEBwwCU1oxDDAKBgNV +BAoMA0NPTTEMMAoGA1UECwwDTlNQMQwwCgYDVQQDDANTRVIxHzAdBgkqhkiG9w0B +CQEWEHlvdXJlbWFpbEBxcS5jb20wgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGB +AJPsvIFYux0IlCRZOk1b3vwR5GZZrPqOZPEV2F0HmvHzFLjtf1hRzfC0XOjMu1XB +vJT6QaiF5FfkO84bB04p12PfN1AldYwJtpDqO4n//pFfon3XaXsYQldcI6qBxcN/ +1tiB23kaEU9vdjepbSbAsEO5nCfOKqTq2vgmf9MAmFdDAgMBAAEwDQYJKoZIhvcN +AQELBQADgYEAlZj31gXePMy5FvyA0ZYzV7fPEHKu3AcPr7CAygfp5YQd7tVvf/jc +kJ3kF+rlM+/wH5S+JNSnJbjhNuHyvlRZNl0pXu5l717gyFMD97DycbSksxlMquOg +IfRQDaSZ7FoAB5jtjdo1aYtureqjFOtAwDsuJaE0/nQ8S8m23k9EPZM= +-----END CERTIFICATE----- diff --git a/test/opensslcrt/serExpCertCliNoCheck/server/key.pem b/test/opensslcrt/serExpCertCliNoCheck/server/key.pem new file mode 100644 index 0000000000000000000000000000000000000000..7149b4de98ca46ce527e7c8491eefc738bd59a5c --- /dev/null +++ b/test/opensslcrt/serExpCertCliNoCheck/server/key.pem @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQCT7LyBWLsdCJQkWTpNW978EeRmWaz6jmTxFdhdB5rx8xS47X9Y +Uc3wtFzozLtVwbyU+kGoheRX5DvOGwdOKddj3zdQJXWMCbaQ6juJ//6RX6J912l7 +GEJXXCOqgcXDf9bYgdt5GhFPb3Y3qW0mwLBDuZwnziqk6tr4Jn/TAJhXQwIDAQAB +AoGAAUJswHUu8qpWCJEx6+KyXVcRqMVusZtwxJS5COG9sf2t2X08LlZT+I6wk9La +bXp+zo3q7TJmpVDuKW0VfNbiXJYDoB5rYibHrvOtwznkLeZ7a70VWtJg6JnzaRbD +5ybk6iYstXkP2h59/M/gaSJ+PGLm5gnTca8Oa4mIxY88AqECQQDCuhyHg27a4lib +qoDAsauzbEwK/Sflhh3/Jis0T1v4JoOTZa4kLoCJXQnFpg8UCNXTC9THoLaFwKVP +kYU/KW4bAkEAwniNhHt8iNNs5o3eR9TKwottNXD0z8MdcSvQ4pH4I6TXpKlWiQF2 +p6e4gQmcTm8N87zafg8z5m53kCMFQ+yt+QJATfKmHPDV3TSh954m7/uvGaGORw/T +eqNXOp2SydvCmvD/1m7raHxc45+km8O/YWuv/E1OHaMNrTjSc6lyxzfUZQJADBF1 +HsqXANq2AFy/dY+0AXyri0x2NMpz8kj4zoamcnfRVguYLipkFaVn4sIY7BkiMYT/ +viDtZJhoXQ5/TibK6QJBAJhHYvioy/TCWHaob0JoeuzpRE+F8vnnq/Gj2WYzD1PB +FBSl4ufFZ/1fVi7wNvu1Tw2vEKfvkG64zkl/Ogjhhl4= +-----END RSA PRIVATE KEY----- diff --git a/test/opensslcrt/serExpCertCliNoCheck/server/server.csr b/test/opensslcrt/serExpCertCliNoCheck/server/server.csr new file mode 100644 index 0000000000000000000000000000000000000000..f9beacbb19fe9da45616119d0844c4c8065b5b45 --- /dev/null +++ b/test/opensslcrt/serExpCertCliNoCheck/server/server.csr @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBsjCCARsCAQAwcjELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAkdEMQswCQYDVQQH +DAJTWjEMMAoGA1UECgwDQ09NMQwwCgYDVQQLDANOU1AxDDAKBgNVBAMMA1NFUjEf +MB0GCSqGSIb3DQEJARYQeW91cmVtYWlsQHFxLmNvbTCBnzANBgkqhkiG9w0BAQEF +AAOBjQAwgYkCgYEAk+y8gVi7HQiUJFk6TVve/BHkZlms+o5k8RXYXQea8fMUuO1/ +WFHN8LRc6My7VcG8lPpBqIXkV+Q7zhsHTinXY983UCV1jAm2kOo7if/+kV+ifddp +exhCV1wjqoHFw3/W2IHbeRoRT292N6ltJsCwQ7mcJ84qpOra+CZ/0wCYV0MCAwEA +AaAAMA0GCSqGSIb3DQEBCwUAA4GBAA7ENSgF3BhQCf3mnhUdXPrGmVaUiXtf4mv+ +A6uiEXVxvmmK7JOOY5k2kgfYo550m8e3D2by7bwisyAzmrrfkVA9YGjfEnre/ebG +53/mykM7+vs8oT3xM9xQEGOOquF9/4u+yJBJxWWpEyOY6zEQvqoFVTK+N8pfmReN +nRDrJM/B +-----END CERTIFICATE REQUEST----- diff --git a/test/stub/CMakeLists.txt b/test/stub/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5e914ad81c07c85f6e5f6f00de0498df27e4d27f --- /dev/null +++ b/test/stub/CMakeLists.txt @@ -0,0 +1,7 @@ +file(GLOB_RECURSE FAKE_IBV *.cpp *.h) + +add_library(fake_ibv_static STATIC ${FAKE_IBV}) +target_link_libraries(fake_ibv_static + -Wl,--start-group + pthread dl rt + -Wl,--end-group) \ No newline at end of file diff --git a/test/stub/fake_ibv.cpp b/test/stub/fake_ibv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..af31bdc3c66c847d1655dd530bc392ee62acbdb2 --- /dev/null +++ b/test/stub/fake_ibv.cpp @@ -0,0 +1,1108 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "fake_ibv.h" +#include +#include +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define FAKE_FILENAME (strrchr(__FILE__, '/') ? strrchr(__FILE__, '/') + 1 : __FILE__) +#define FAKE_LOG(fmt, ...) \ + do { \ + char log[512UL]; \ + int len = 0; \ + len += sprintf(log, "[%s %s:%d]", __FUNCTION__, FAKE_FILENAME, __LINE__); \ + len += sprintf(log + len, fmt, ##__VA_ARGS__); \ + sprintf(log + len, "\n"); \ + printf(log); \ + } while (0) + +#define IBV_ERROR (-1) +#define FAKE_NULL_FD (-1) +#define FAKE_NULL_DWORD 0XFFFFFFFF +#ifndef container_of +#define container_of(ptr, type, member) ((type *)(void *)((char *)(ptr) - offsetof(type, member))) +#endif + +struct ibv_device g_ibvdevice[FAKE_IBV_DEVICE_NUM] = { + { + {NULL, NULL}, + IBV_NODE_RNIC, + IBV_TRANSPORT_IB, + "hrn0_0", + "uverbs0", + "/tmp/hrn0_0", + "/tmp/uverbs0" + }, + { + {NULL, NULL}, + IBV_NODE_RNIC, + IBV_TRANSPORT_IB, + "hrn1_0", + "uverbs1", + "/tmp/hrn1_0", + "/tmp/uverbs1" + } +}; + +static uint64_t readBuff[100]; +fake_lock_list_t g_f_qp_list; +uint32_t g_fake_qp_num_gen = 0; +uint32_t g_keyId = 0; +/* ********************************************************************* + 功能描述 : 获取key号 +********************************************************************** */ +uint32_t fake_get_key(void) +{ + return (uint32_t)__sync_add_and_fetch(&g_keyId, 1); +} +/* ********************************************************************* + 功能描述 : 获取qp链表 +********************************************************************** */ +fake_lock_list_t *fake_get_qp_list(void) +{ + return &g_f_qp_list; +} + +/* ********************************************************************* + 功能描述 : 获取qp号 +********************************************************************** */ +uint32_t fake_get_qp_num(void) +{ + return (uint32_t)__sync_add_and_fetch(&g_fake_qp_num_gen, 1); +} + +/* ********************************************************************* + 功能描述 : 销毁srq +********************************************************************** */ +int ibv_destroy_srq(struct ibv_srq *srq) +{ + if (srq != NULL) { + free(srq); + } + + return 0; +} + +/* ********************************************************************* + 功能描述 : 初始化recv wr mgr + +********************************************************************** */ +void fake_recv_wr_mgr_init(fake_recv_wr_mgr_t *recv_wr_mgr) +{ + memset(recv_wr_mgr, 0x00, sizeof(fake_recv_wr_mgr_t)); + pthread_mutex_init(&recv_wr_mgr->wrLock, NULL); +} + +/* ********************************************************************* + 功能描述 : 创建srq +********************************************************************** */ +struct ibv_srq *ibv_create_srq(struct ibv_pd *pd, struct ibv_srq_init_attr *attr) +{ + fake_srq_t *fsrq = (fake_srq_t *)malloc(sizeof(fake_srq_t)); + if (fsrq == NULL) { + return NULL; + } + fsrq->srq.pd = pd; + fsrq->srq.srq_context = attr->srq_context; + fsrq->srq.context = pd->context; + fake_recv_wr_mgr_init(&fsrq->recv_wr_mgr); + return (struct ibv_srq *)fsrq; +} + +/* ********************************************************************* + 功能描述 : recv wr mgr 产生item +********************************************************************** */ +fake_recv_wr_item_t *fake_recv_wr_mgr_produce_item(fake_recv_wr_mgr_t *recv_wr_mgr) +{ + pthread_mutex_lock(&recv_wr_mgr->wrLock); + if (((recv_wr_mgr->producer + 1) % FAKE_RECV_WR_DEPTH) == recv_wr_mgr->comsuer) { + FAKE_LOG("Fake: mgr(%p) recv wr queue is empty.", recv_wr_mgr); + pthread_mutex_unlock(&recv_wr_mgr->wrLock); + return NULL; + } + + uint32_t idx = recv_wr_mgr->producer; + recv_wr_mgr->producer = (recv_wr_mgr->producer + 1) % FAKE_RECV_WR_DEPTH; + pthread_mutex_unlock(&recv_wr_mgr->wrLock); + + return &recv_wr_mgr->item[idx]; +} + +/* ********************************************************************* + 功能描述 : 从qp中获取一个recv wr +********************************************************************** */ +fake_recv_wr_item_t *fake_recv_wr_mgr_comsume_item(fake_recv_wr_mgr_t *recv_wr_mgr) +{ + pthread_mutex_lock(&recv_wr_mgr->wrLock); + if (recv_wr_mgr->comsuer == recv_wr_mgr->producer) { + FAKE_LOG("Fake: mgr(%p) recv wr queue is full.", recv_wr_mgr); + pthread_mutex_unlock(&recv_wr_mgr->wrLock); + return NULL; + } + + uint32_t idx = recv_wr_mgr->comsuer++; + recv_wr_mgr->comsuer %= FAKE_RECV_WR_DEPTH; + pthread_mutex_unlock(&recv_wr_mgr->wrLock); + + return &recv_wr_mgr->item[idx]; +} + +/* ********************************************************************* + 功能描述 : qp链表初始化 +********************************************************************** */ +void fake_ibv_init(void) +{ + fake_lock_list_init(&g_f_qp_list); +} + +/* ********************************************************************* + 功能描述 : fork init桩 +********************************************************************** */ +int ibv_fork_init(void) +{ + return 0; +} + +/* ********************************************************************* + 功能描述 : 申请pd +********************************************************************** */ +struct ibv_pd *ibv_alloc_pd(struct ibv_context *context) +{ + struct ibv_pd *pd = (struct ibv_pd *)malloc(sizeof(struct ibv_pd)); + if (pd == NULL) { + return NULL; + } + + pd->context = context; + + return pd; +} + +/* ********************************************************************* + 功能描述 : 释放pd +********************************************************************** */ +int ibv_dealloc_pd(struct ibv_pd *pd) +{ + free(pd); + return 0; +} + +/* ********************************************************************* + 功能描述 : 创建CC +********************************************************************** */ +struct ibv_comp_channel *ibv_create_comp_channel(struct ibv_context *context) +{ + fake_cc_t *fcc = (fake_cc_t *)malloc(sizeof(fake_cc_t)); + if (fcc == NULL) { + return NULL; + } + + fcc->cc.context = context; + fcc->cc.fd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); + fcc->cc.refcnt = 0; + + fake_lock_list_init(&fcc->cq_list); + + return &fcc->cc; +} + +/* ********************************************************************* + 功能描述 : 释放CC +********************************************************************** */ +int ibv_destroy_comp_channel(struct ibv_comp_channel *channel) +{ + fake_cc_t *fcc = (fake_cc_t *)channel; + + if (fcc->cc.fd != FAKE_NULL_FD) { + close(fcc->cc.fd); + fcc->cc.fd = FAKE_NULL_FD; + } + + free(fcc); + + return 0; +} + +/* ********************************************************************* + 功能描述 : 发送一个完成事件 +********************************************************************** */ +void fake_send_event_on_cc(struct ibv_comp_channel *cmc) +{ + eventfd_t val = 1; // 任意数据 + + (void)eventfd_write(cmc->fd, val); +} + +/* ********************************************************************* + 功能描述 : 创建一个CQ +********************************************************************** */ +struct ibv_cq *ibv_create_cq(struct ibv_context *context, int cqe, void *cq_context, struct ibv_comp_channel *channel, + int comp_vector) +{ + fake_cq_t *fcq = (fake_cq_t *)malloc(sizeof(fake_cq_t)); + if (fcq == NULL) { + return NULL; + } + if (channel == nullptr) { + auto cl = ibv_create_comp_channel(context); + channel = cl; + } + struct ibv_cq *cq = &fcq->cq; + cq->context = context; + cq->cqe = cqe; + cq->cq_context = cq_context; + cq->channel = channel; + + pthread_mutex_init(&fcq->cqLock, NULL); + fcq->wcq.producer = 0; + fcq->wcq.comsumer = 0; + fcq->wcq.buff = (struct ibv_wc *)malloc(cqe * sizeof(struct ibv_wc)); + if (fcq->wcq.buff == NULL) { + free(fcq); + return NULL; + } + + UNREF_PARAM(comp_vector); + + FAKE_LOG("Create cq(%p) num(%d) success.", cq, cqe); + + fake_cc_t *fcc = (fake_cc_t *)channel; + fake_lock_list_add_tail(&fcq->entry, &fcc->cq_list); + + return cq; +} + +/* ********************************************************************* + 功能描述 : 生成一个cqe +********************************************************************** */ +struct ibv_wc *fake_fetch_cqe(fake_cq_t *fcq) +{ + if (((fcq->wcq.producer + 1) % fcq->cq.cqe) == fcq->wcq.comsumer) { + return NULL; + } + + uint32_t idx = fcq->wcq.producer; + fcq->wcq.producer = (fcq->wcq.producer + 1) % fcq->cq.cqe; + + return &fcq->wcq.buff[idx]; +} + +/* ********************************************************************* + 功能描述 : 产生CQE +********************************************************************** */ +int fake_prodce_cqe(fake_cq_t *fcq, struct ibv_wc *wc) +{ + int ret = IBV_ERROR; + pthread_mutex_lock(&fcq->cqLock); + struct ibv_wc *next_wc = fake_fetch_cqe(fcq); + if (next_wc != NULL) { + *next_wc = *wc; + ret = 0; + } else { + FAKE_LOG("Fake, fetch cqe failed, producer(%u) comsumer(%u).", fcq->wcq.producer, fcq->wcq.comsumer); + } + pthread_mutex_unlock(&fcq->cqLock); + return ret; +} + +/* ********************************************************************* + 功能描述 : 消费一个cqe +********************************************************************** */ +struct ibv_wc *fake_comsume_cqe(fake_cq_t *fcq) +{ + if (fcq->wcq.comsumer == fcq->wcq.producer) { + return NULL; + } + + uint32_t idx = fcq->wcq.comsumer++; + fcq->wcq.comsumer %= fcq->cq.cqe; + + return &fcq->wcq.buff[idx]; +} + +/* ********************************************************************* + 功能描述 : 创建一个CQ +********************************************************************** */ +int ibv_destroy_cq(struct ibv_cq *cq) +{ + fake_cq_t *fcq = (fake_cq_t *)(void *)cq; + + fake_cc_t *fcc = (fake_cc_t *)cq->channel; + fake_lock_list_del_node(&fcq->entry, &fcc->cq_list); + + if (fcq->wcq.buff != NULL) { + free(fcq->wcq.buff); + fcq->wcq.buff = NULL; + } + + free(fcq); + + return 0; +} + +/* ********************************************************************* + 功能描述 : 创建一个QP +********************************************************************** */ +struct ibv_qp *ibv_create_qp(struct ibv_pd *pd, struct ibv_qp_init_attr *qp_init_attr) +{ + fake_qp_t *fqp = (fake_qp_t *)malloc(sizeof(fake_qp_t)); + if (fqp == NULL) { + return NULL; + } + + struct ibv_qp *qp = &(fqp->qp); + + qp->context = pd->context; + qp->qp_context = qp_init_attr->qp_context; + qp->pd = pd; + qp->send_cq = qp_init_attr->send_cq; + qp->recv_cq = qp_init_attr->recv_cq; + qp->srq = qp_init_attr->srq; + qp->state = IBV_QPS_RESET; + qp->qp_type = qp_init_attr->qp_type; + qp->qp_num = fake_get_qp_num(); + fqp->dest_qp_num = FAKE_NULL_DWORD; + fake_recv_wr_mgr_init(&fqp->recv_wr_mgr); + + fake_lock_list_add_tail(&fqp->entry, &g_f_qp_list); + + return qp; +} + +/* ********************************************************************* + 功能描述 : 释放一个QP +********************************************************************** */ +int ibv_destroy_qp(struct ibv_qp *qp) +{ + fake_qp_t *fqp = (fake_qp_t *)(void *)qp; + fake_lock_list_del_node(&fqp->entry, &g_f_qp_list); + free(fqp); + + return 0; +} + +/* ********************************************************************* + 功能描述 : 获取事件名字 +********************************************************************** */ +const char *ibv_event_type_str(enum ibv_event_type event) +{ + static const char *event_type_str[] = { + [IBV_EVENT_CQ_ERR] = "CQ error", + [IBV_EVENT_QP_FATAL] = "local work queue catastrophic error", + [IBV_EVENT_QP_REQ_ERR] = "invalid request local work queue error", + [IBV_EVENT_QP_ACCESS_ERR] = "local access violation work queue error", + [IBV_EVENT_COMM_EST] = "communication established", + [IBV_EVENT_SQ_DRAINED] = "send queue drained", + [IBV_EVENT_PATH_MIG] = "path migrated", + [IBV_EVENT_PATH_MIG_ERR] = "path migration request error", + [IBV_EVENT_DEVICE_FATAL] = "local catastrophic error", + [IBV_EVENT_PORT_ACTIVE] = "port active", + [IBV_EVENT_PORT_ERR] = "port error", + [IBV_EVENT_LID_CHANGE] = "LID change", + [IBV_EVENT_PKEY_CHANGE] = "P_Key change", + [IBV_EVENT_SM_CHANGE] = "SM change", + [IBV_EVENT_SRQ_ERR] = "SRQ catastrophic error", + [IBV_EVENT_SRQ_LIMIT_REACHED] = "SRQ limit reached", + [IBV_EVENT_QP_LAST_WQE_REACHED] = "last WQE reached", + [IBV_EVENT_CLIENT_REREGISTER] = "client reregistration", + [IBV_EVENT_GID_CHANGE] = "GID table change" + }; + + if (event < IBV_EVENT_CQ_ERR || event > IBV_EVENT_GID_CHANGE) { + return "unknown"; + } + + return event_type_str[event]; +} + +int ibv_get_async_event(struct ibv_context *context, struct ibv_async_event *event) +{ + (void)memset(event, 0, sizeof(struct ibv_async_event)); + FAKE_LOG("fake_ibv: call fake function ibv_get_async_event."); + UNREF_PARAM(context); + + /* no block mode return -1 will be ingnored */ + return -1; +} + +void ibv_ack_async_event(struct ibv_async_event *event) +{ + FAKE_LOG("fake_ibv: call ibv_ack_async_event."); + UNREF_PARAM(event); + return; +} + +/* ********************************************************************* + 功能描述 : 获取cc上有数据的cq +********************************************************************** */ +int ibv_get_cq_event(struct ibv_comp_channel *channel, struct ibv_cq **cq, void **cq_context) +{ + list_head_t *node = NULL; + list_head_t *next = NULL; + fake_cc_t *fcc = (fake_cc_t *)channel; + + UNREF_PARAM(cq_context); + + *cq = NULL; + { + pthread_mutex_lock(&fcc->cq_list.listLock); + list_for_each_safe(node, next, &fcc->cq_list.list.list_head) + { + fake_cq_t *fcq = list_entry(node, fake_cq_t, entry); + if (fcq->wcq.comsumer != fcq->wcq.producer) { + *cq = &fcq->cq; + break; + } + } + pthread_mutex_unlock(&fcc->cq_list.listLock); + } + + if (*cq != NULL) { + fake_send_event_on_cc(channel); + } + + read(channel->fd, readBuff, sizeof(uint64_t) * 100); + + return 0; +} + +/* ********************************************************************* + 功能描述 : ack events +********************************************************************** */ +void ibv_ack_cq_events(struct ibv_cq *cq, unsigned int nevents) +{ + UNREF_PARAM(cq); + UNREF_PARAM(nevents); + return; +} + +struct ibv_device **ibv_get_device_list(int *device_num) +{ + struct ibv_device **l; + int i; + + l = (struct ibv_device **)calloc(FAKE_IBV_DEVICE_NUM + 1, sizeof(struct ibv_device *)); + if (!l) { + FAKE_LOG("fake_ibv: create device list fail."); + return NULL; + } + + for (i = 0; i < FAKE_IBV_DEVICE_NUM; ++i) { + l[i] = &g_ibvdevice[i]; + } + + *device_num = FAKE_IBV_DEVICE_NUM; + + return l; +} + +void ibv_free_device_list(struct ibv_device **list) +{ + free(list); + + return; +} + +/* ********************************************************************* + 功能描述 : 将所有的recv wr全部放入cq中 +********************************************************************** */ +void fake_flash_all_recv_wr(fake_qp_t *fqp) +{ + uint32_t total = 0; + fake_cq_t *fcq = (fake_cq_t *)fqp->qp.recv_cq; + fake_recv_wr_item_t *item = fake_recv_wr_mgr_comsume_item(&fqp->recv_wr_mgr); + struct ibv_wc wc = {}; + + while (item != NULL) { + wc.wr_id = item->wr.wr_id; + wc.status = IBV_WC_WR_FLUSH_ERR; + wc.opcode = IBV_WC_RECV; + wc.qp_num = fqp->qp.qp_num; + if (fake_prodce_cqe(fcq, &wc) != 0) { + break; + } + + item = fake_recv_wr_mgr_comsume_item(&fqp->recv_wr_mgr); + + total++; + } + + fake_send_event_on_cc(fcq->cq.channel); + + FAKE_LOG("Recv_wr_mgr(%p) Qp(%p) cq(%p) flash to wr(%d) to cq(%u-%u).", &fqp->recv_wr_mgr, fqp, fcq, total, + fcq->wcq.comsumer, fcq->wcq.producer); +} + +/* ********************************************************************* + 功能描述 : 改变qp +********************************************************************** */ +int ibv_modify_qp(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask) +{ + fake_qp_t *fqp = (fake_qp_t *)qp; + + qp->state = attr->qp_state; + + if (attr->qp_state == IBV_QPS_RTR) { + fqp->dest_qp_num = attr->dest_qp_num; + } + + if (attr->qp_state == IBV_QPS_ERR) { + fake_flash_all_recv_wr((fake_qp_t *)qp); + } + + FAKE_LOG("Modify qp(%p, %u) to %d.", qp, qp->qp_num, attr->qp_state); + + UNREF_PARAM(attr_mask); + + return 0; +} + +/* ********************************************************************* + 功能描述 : post recv +********************************************************************** */ +int fake_post_recv(struct ibv_qp *qp, struct ibv_recv_wr *wr, struct ibv_recv_wr **bad_wr) +{ + fake_qp_t *fqp = (fake_qp_t *)(void *)qp; + + UNREF_PARAM(bad_wr); + if (qp->state == IBV_QPS_ERR) { + return IBV_ERROR; + } + + fake_recv_wr_item_t *item = fake_recv_wr_mgr_produce_item(&fqp->recv_wr_mgr); + if (item == NULL) { + return IBV_ERROR; + } + + item->wr = *wr; + memcpy(item->sg_list, wr->sg_list, wr->num_sge * sizeof(struct ibv_sge)); + + return 0; +} + +int fake_post_srq_recv(struct ibv_srq *srq, struct ibv_recv_wr *recv_wr, struct ibv_recv_wr **bad_recv_wr) +{ + UNREF_PARAM(bad_recv_wr); + + fake_srq_t *fsrq = (fake_srq_t *)(void *)srq; + fake_recv_wr_item_t *item = fake_recv_wr_mgr_produce_item(&fsrq->recv_wr_mgr); + if (item == NULL) { + return IBV_ERROR; + } + + item->wr = *recv_wr; + memcpy(item->sg_list, recv_wr->sg_list, recv_wr->num_sge * sizeof(struct ibv_sge)); + + return 0; +} + + +/* ********************************************************************* + 功能描述 : 查找对端的qp +********************************************************************** */ +fake_qp_t *fake_find_peer_qp(fake_qp_t *my_qp) +{ + list_head_t *node = NULL; + list_head_t *next = NULL; + fake_qp_t *fqp = NULL; + fake_qp_t *find = NULL; + + pthread_mutex_lock(&g_f_qp_list.listLock); + list_for_each_safe(node, next, &g_f_qp_list.list.list_head) + { + fqp = list_entry(node, fake_qp_t, entry); + if (fqp->dest_qp_num == my_qp->qp.qp_num) { + find = fqp; + break; + } + } + pthread_mutex_unlock(&g_f_qp_list.listLock); + + return find; +} + +/* ********************************************************************* + 功能描述 : wr转换成wc的opcode +********************************************************************** */ +enum ibv_wc_opcode fake_wr_to_wc_opcode(enum ibv_wr_opcode opcode) +{ + switch (opcode) { + case IBV_WR_RDMA_WRITE: + return IBV_WC_RDMA_WRITE; + + case IBV_WR_SEND: + return IBV_WC_SEND; + + case IBV_WR_RDMA_READ: + return IBV_WC_RDMA_READ; + + default: + return IBV_WC_RECV; + } +} + +/* ********************************************************************* + 功能描述 : 通知发送完成 +********************************************************************** */ +void fake_create_send_wc(fake_qp_t *fqp, struct ibv_send_wr *wr, uint64_t size) +{ + if (wr == NULL) { + /* dcc combo write 分2次post,前一次post没有singled标记,不产生完成事件 */ + return; + } + + fake_cq_t *fcq = (fake_cq_t *)fqp->qp.send_cq; + + struct ibv_wc wc = {}; + wc.wr_id = wr->wr_id; + wc.status = IBV_WC_SUCCESS; + wc.opcode = fake_wr_to_wc_opcode(wr->opcode); + wc.qp_num = fqp->qp.qp_num; + wc.byte_len = size; + wc.imm_data = wr->imm_data; + + if (fake_prodce_cqe(fcq, &wc) == 0) { + fake_send_event_on_cc(fcq->cq.channel); + } +} + +/* ********************************************************************* + 功能描述 : 通知接收完成 +********************************************************************** */ +void fake_create_recv_wc(fake_cq_t *fcq, uint32_t qp_num, fake_recv_wr_item_t *item, uint64_t size, uint32_t immData) +{ + struct ibv_wc wc = {}; + wc.wr_id = item->wr.wr_id; + wc.status = IBV_WC_SUCCESS; + wc.opcode = IBV_WC_RECV; + wc.byte_len = size; + wc.qp_num = qp_num; + wc.imm_data = immData; + + if (fake_prodce_cqe(fcq, &wc) == 0) { + fake_send_event_on_cc(fcq->cq.channel); + } +} + +/* ********************************************************************* + 功能描述 : post send处理 +********************************************************************** */ +int fake_ibv_post_send(fake_qp_t *my_qp, struct ibv_send_wr *wr) +{ + fake_qp_t *peer_qp = fake_find_peer_qp(my_qp); + if (peer_qp == NULL) { + return IBV_ERROR; + } + fake_recv_wr_mgr_t *recv_wr_mgr; + if (peer_qp->qp.srq != NULL) { + fake_srq_t *fsrq = (fake_srq_t *)peer_qp->qp.srq; + recv_wr_mgr = &(fsrq->recv_wr_mgr); + } else { + recv_wr_mgr = &peer_qp->recv_wr_mgr; + } + + fake_recv_wr_item_t *item = fake_recv_wr_mgr_comsume_item(recv_wr_mgr); + if (item == NULL) { + FAKE_LOG("Item is null."); + return IBV_ERROR; + } + + uint64_t size = 0; + for (int i = 0; i < wr->num_sge; i++) { + if (item->sg_list[i].length == 0) { + // 兼容一个wqe接受sgl的场景,直接拷贝在上一个 + memcpy(reinterpret_cast( + static_cast(item->sg_list[i - 1].addr + wr->sg_list[i - 1].length)), + reinterpret_cast(static_cast(wr->sg_list[i].addr)), + wr->sg_list[i].length); + size += wr->sg_list[i].length; + continue; + } + memcpy(reinterpret_cast( + static_cast(item->sg_list[i].addr)), + reinterpret_cast(static_cast(wr->sg_list[i].addr)), + wr->sg_list[i].length); + size += wr->sg_list[i].length; + } + + if (wr->send_flags == IBV_SEND_SIGNALED) { + fake_create_send_wc(my_qp, wr, size); + } + + fake_cq_t *fcq = (fake_cq_t *)peer_qp->qp.recv_cq; + fake_create_recv_wc(fcq, peer_qp->qp.qp_num, item, size, wr->imm_data); + + return 0; +} + +/* ********************************************************************* + 功能描述 : post read处理 +********************************************************************** */ +int fake_post_read(fake_qp_t *my_qp, struct ibv_send_wr *wr) +{ + uint64_t size = 0; + struct ibv_send_wr *wr_temp = wr; + struct ibv_send_wr *finish_wr = NULL; + + while (wr_temp != NULL) { + int i; + int remote_addr_offset = 0; + for (i = 0; i < wr_temp->num_sge; i++) { + remote_addr_offset += ((i == 0) ? 0 : wr_temp->sg_list[i - 1].length); + char *src = (char *)(uintptr_t)wr_temp->wr.rdma.remote_addr + remote_addr_offset; + if (src == 0x0 || wr_temp->wr.rdma.rkey == 0) { + FAKE_LOG("Illegal rkey %d or rAddress %p", wr_temp->wr.rdma.rkey, src); + return IBV_ERROR; + } + memcpy((void *)(uintptr_t)wr_temp->sg_list[i].addr, (char *)(uintptr_t)src, wr_temp->sg_list[i].length); + size += wr_temp->sg_list[i].length; + } + + if (wr_temp->send_flags == IBV_SEND_SIGNALED) { + finish_wr = wr_temp; + break; + } + wr_temp = wr_temp->next; + } + + if (finish_wr == NULL) { + FAKE_LOG("Wr isn't set any send signal flag."); + return IBV_ERROR; + } + + if (finish_wr->next != NULL) { + FAKE_LOG("Wr must be the last one, if it is been set send signal flag."); + return IBV_ERROR; + } + + fake_create_send_wc(my_qp, finish_wr, size); + + return 0; +} + +/* ********************************************************************* + 功能描述 : post write处理 +********************************************************************** */ +int fake_post_write(fake_qp_t *my_qp, struct ibv_send_wr *wr) +{ + struct ibv_send_wr *wr_temp = wr; + struct ibv_send_wr *finish_wr = NULL; + uint64_t size = 0; + + while (wr_temp != NULL) { + int i; + int remote_addr_offset = 0; + for (i = 0; i < wr_temp->num_sge; i++) { + remote_addr_offset += ((i == 0) ? 0 : wr_temp->sg_list[i - 1].length); + char *src = (char *)(uintptr_t)wr_temp->wr.rdma.remote_addr + remote_addr_offset; + if (src == 0x0 || wr_temp->wr.rdma.rkey == 0) { + FAKE_LOG("Illegal rkey %d or rAddress %p", wr_temp->wr.rdma.rkey, src); + return IBV_ERROR; + } + memcpy(src, (void *)(uintptr_t)wr_temp->sg_list[i].addr, wr_temp->sg_list[i].length); + size += wr_temp->sg_list[i].length; + } + + if (wr_temp->send_flags == IBV_SEND_SIGNALED) { + finish_wr = wr_temp; + break; + } + wr_temp = wr_temp->next; + } + + /* dcc combo write 分2次post,前一次post没有singled标记,不要求每次post一定有singled标记 */ + if (finish_wr != NULL && finish_wr->next != NULL) { + FAKE_LOG("Wr must be the last one, if it is been set send signal flag."); + return IBV_ERROR; + } + + fake_create_send_wc(my_qp, finish_wr, size); + + return 0; +} + +/* ********************************************************************* + 功能描述 : post send +********************************************************************** */ +int fake_post_send(struct ibv_qp *qp, struct ibv_send_wr *wr, struct ibv_send_wr **bad_wr) +{ + int ret; + fake_qp_t *my_qp = (fake_qp_t *)qp; + + UNREF_PARAM(bad_wr); + + switch (wr->opcode) { + case IBV_WR_SEND: + case IBV_WR_SEND_WITH_IMM: + ret = fake_ibv_post_send(my_qp, wr); + break; + + case IBV_WR_RDMA_READ: + ret = fake_post_read(my_qp, wr); + break; + + case IBV_WR_RDMA_WRITE: + ret = fake_post_write(my_qp, wr); + break; + + default: + ret = IBV_ERROR; + FAKE_LOG("Fake don't support opcode(%d).", wr->opcode); + break; + } + + return ret; +} + +/* ********************************************************************* + 功能描述 : 模拟进行post cq,将已经完成的事件全部放到用户的wc中 +********************************************************************** */ +int fake_poll_cq(struct ibv_cq *cq, int num_entries, struct ibv_wc *wc) +{ + int i; + fake_cq_t *fcq = (fake_cq_t *)(void *)cq; + + pthread_mutex_lock(&fcq->cqLock); + for (i = 0; i < num_entries; i++) { + struct ibv_wc *my_wc = fake_comsume_cqe(fcq); + if (my_wc == NULL) { + break; + } + + (void)memcpy(&wc[i], my_wc, sizeof(struct ibv_wc)); + } + pthread_mutex_unlock(&fcq->cqLock); + + return i; +} + +int fake_req_notify_cq(struct ibv_cq *cq, int solicited_only) +{ + if (cq == NULL) { + return 0; + } + + fake_cq_t *fcq = (fake_cq_t *)cq; + + if (fcq->wcq.comsumer != fcq->wcq.producer) { + fake_send_event_on_cc(fcq->cq.channel); + } + + UNREF_PARAM(solicited_only); + + return 0; +} + +typedef struct { + struct verbs_context v_ctx; +} ibv_verb_ctx_all_t; + +struct ibv_context *ibv_open_device(struct ibv_device *device) +{ + ibv_verb_ctx_all_t *ctx = (ibv_verb_ctx_all_t *)calloc(1, sizeof(ibv_verb_ctx_all_t)); + if (ctx == NULL) { + FAKE_LOG("fake_ibv: create verbs_context fail."); + return NULL; + } + + struct verbs_context *v_ctx = &ctx->v_ctx; + struct ibv_context *context = &v_ctx->context; + + context->abi_compat = ((uint8_t *)nullptr) - 1; /* verbs 扩展接口判断 */ + v_ctx->sz = sizeof(struct verbs_context); + v_ctx->create_qp_ex = nullptr; + + context->device = device; + context->cmd_fd = FAKE_NULL_FD; + context->async_fd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); + context->num_comp_vectors = 0; + + context->ops.poll_cq = fake_poll_cq; + context->ops.post_send = fake_post_send; + context->ops.post_recv = fake_post_recv; + context->ops.req_notify_cq = fake_req_notify_cq; + context->ops.post_srq_recv = fake_post_srq_recv; + + return context; +} + +int ibv_close_device(struct ibv_context *context) +{ + if (context->cmd_fd != FAKE_NULL_FD) { + close(context->cmd_fd); + context->cmd_fd = FAKE_NULL_FD; + } + + if (context->async_fd != FAKE_NULL_FD) { + close(context->async_fd); + context->async_fd = FAKE_NULL_FD; + } + + struct verbs_context *v_ctx = container_of(context, struct verbs_context, context); + ibv_verb_ctx_all_t *ctx = container_of(v_ctx, ibv_verb_ctx_all_t, v_ctx); + free(ctx); + + return 0; +} + +int ibv_query_device(struct ibv_context *context, struct ibv_device_attr *device_attr) +{ + (void)memset(device_attr, 0, sizeof(struct ibv_device_attr)); + + /* 当前xnet只用了,gid和portcnt,故只初始化这两个字段,可根据代码需要自行增加 */ + if (snprintf(device_attr->fw_ver, sizeof(device_attr->fw_ver) - 1, "xnet_fake_0.1") == -1) { + return IBV_ERROR; + } + + device_attr->node_guid = 1; + device_attr->phys_port_cnt = 1; + + UNREF_PARAM(context); + return 0; +} + +#ifdef ibv_reg_mr +#undef ibv_reg_mr +#endif +struct ibv_mr *ibv_reg_mr(struct ibv_pd *pd, void *address, size_t length, int access) +{ + struct ibv_mr *mr = (struct ibv_mr *)calloc(1, sizeof(struct ibv_mr)); + if (!mr) { + FAKE_LOG("fake_ibv: ibv_reg_mr fail."); + return NULL; + } + + mr->context = pd->context; + mr->pd = pd; + mr->addr = address; + mr->length = length; + mr->lkey = fake_get_key(); + mr->rkey = fake_get_key(); + mr->handle = 0; + + UNREF_PARAM(access); + + return mr; +} + +struct ibv_mr *ibv_reg_mr_iova2(struct ibv_pd *pd, void *address, size_t length, uint64_t iova, unsigned int access) +{ + struct ibv_mr *mr = (struct ibv_mr *)calloc(1, sizeof(struct ibv_mr)); + if (!mr) { + FAKE_LOG("fake_ibv: ibv_reg_mr fail."); + return NULL; + } + + mr->context = pd->context; + mr->pd = pd; + mr->addr = address; + mr->length = length; + mr->lkey = fake_get_key(); + mr->rkey = fake_get_key(); + mr->handle = 0; + + UNREF_PARAM(access); + + return mr; +} + +int ibv_dereg_mr(struct ibv_mr *mr) +{ + free(mr); + return 0; +} + +struct ibv_mr *ibv_reg_umm_page_mr(struct ibv_pd *pd, void *addr, void *knl_addr, size_t length, int access_flag) +{ + UNREF_PARAM(knl_addr); + + return ibv_reg_mr(pd, addr, length, access_flag); +} + +int ibv_dereg_umm_page_mr(struct ibv_mr *mr) +{ + free(mr); + return 0; +} + + +int ibv_query_gid(struct ibv_context *context, uint8_t port_num, int index, union ibv_gid *gid) +{ + (void)memset(gid, 0, sizeof(union ibv_gid)); + auto devI6Address = reinterpret_cast(gid->raw); + + devI6Address->s6_addr32[2UL] = htonl(0x0000ffff); + devI6Address->s6_addr32[3UL] = inet_addr("127.0.0.1"); + + UNREF_PARAM(context); + UNREF_PARAM(port_num); + UNREF_PARAM(index); + return 0; +} + +/* ******************************************************************** + * 下面几个函数无实际流程调用,打空桩 + * ******************************************************************** */ +int ibv_query_qp(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask, struct ibv_qp_init_attr *init_attr) +{ + FAKE_LOG("fake_ibv: call fake function ibv_query_qp."); + UNREF_PARAM(qp); + UNREF_PARAM(attr); + UNREF_PARAM(attr_mask); + UNREF_PARAM(init_attr); + return 0; +} + +int ibv_check_qp(struct ibv_context *context, pid_t pid) +{ + FAKE_LOG("fake_ibv: call fake function ibv_check_qp."); + + UNREF_PARAM(context); + UNREF_PARAM(pid); + return 0; +} + +void ibv_dfx_qp_wr_dump(struct ibv_qp *qp) +{ + FAKE_LOG("fake_ibv: call fake function dfx_qp_wr_dump."); + UNREF_PARAM(qp); +} + +void ibv_dfx_port_traffic_dump(struct ibv_context *context) +{ + FAKE_LOG("fake_ibv: call fake function port_traffic_dump."); + UNREF_PARAM(context); +} + +void ibv_dfx_counter(struct ibv_qp *qp) +{ + FAKE_LOG("fake_ibv: call fake function dfx_counter."); + UNREF_PARAM(qp); +} + +const char *ibv_port_state_str(enum ibv_port_state port_state) +{ + return "OK"; +} + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/test/stub/fake_ibv.h b/test/stub/fake_ibv.h new file mode 100644 index 0000000000000000000000000000000000000000..a897e2d0a8dc91a2c35068a2f8afa5272591ad48 --- /dev/null +++ b/test/stub/fake_ibv.h @@ -0,0 +1,263 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef __FAKE_IBV_H__ +#define __FAKE_IBV_H__ + +#include +#include +#include +#include +#include + +#include "hcom_log.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct list_head { + struct list_head *next, *prev; +}; +typedef struct list_head list_head_t; + +#define FAKE_IBV_DEVICE_NUM 2 +#define FAKE_RECV_WR_SGE_NUM 8 +#define FAKE_RECV_WR_DEPTH 1024 + +typedef struct dsw_list_s { + list_head_t list_head; + int node_num; +} fake_list_t; + +typedef struct { + fake_list_t list; + pthread_mutex_t listLock; +} fake_lock_list_t; + +#define INIT_LIST_HEAD(ptr) \ + { \ + (ptr)->next = (ptr); \ + (ptr)->prev = (ptr); \ + } + +#define list_for_each_safe(pos, n, head) \ + for ((pos) = (head)->next, (n) = (pos)->next; (pos) != (head); (pos) = (n), (n) = (pos)->next) + +#define list_entry(ptr, type, member) ((type *)(void *)((char *)(ptr) - offsetof(type, member))) + +static inline void __list_add(struct list_head *newnode, struct list_head *prevnode, struct list_head *nextnode) +{ + nextnode->prev = newnode; + newnode->next = nextnode; + newnode->prev = prevnode; + prevnode->next = newnode; +} + +static inline void __list_del(struct list_head *prev, struct list_head *next) +{ + next->prev = prev; + prev->next = next; +} + +static inline void list_add_tail(struct list_head *new_head, struct list_head *head) +{ + __list_add(new_head, head->prev, head); +} + +static inline void list_del_node(list_head_t *node, fake_list_t *list) +{ + __list_del(node->prev, node->next); + INIT_LIST_HEAD(node); +} + +static inline void fake_lock_list_init(fake_lock_list_t *list) +{ + pthread_mutex_init(&list->listLock, NULL); + INIT_LIST_HEAD(&(list->list.list_head)); + list->list.node_num = 0; +} + +static inline void fake_lock_list_add_tail(list_head_t *node, fake_lock_list_t *lock_list) +{ + pthread_mutex_lock(&lock_list->listLock); + lock_list->list.node_num += 1; + list_add_tail(node, &(lock_list->list.list_head)); + pthread_mutex_unlock(&lock_list->listLock); +} + +static inline void fake_lock_list_del_node(list_head_t *node, fake_lock_list_t *lock_list) +{ + pthread_mutex_lock(&lock_list->listLock); + list_del_node(node, &(lock_list->list)); + lock_list->list.node_num -= 1; + pthread_mutex_unlock(&lock_list->listLock); +} + +typedef struct { + struct ibv_recv_wr wr; + struct ibv_sge sg_list[FAKE_RECV_WR_SGE_NUM]; +} fake_recv_wr_item_t; + +typedef struct { + uint32_t producer; + uint32_t comsuer; + fake_recv_wr_item_t item[FAKE_RECV_WR_DEPTH]; + pthread_mutex_t wrLock; +} fake_recv_wr_mgr_t; + +/* 结构体定义 */ +typedef struct fake_qp_s { + struct ibv_qp qp; /* 对外呈现的qp结构 */ + struct ibv_recv_wr head; /* 用来挂载外部注册的recv,外部post的时候申请结构挂到链表上 */ + list_head_t entry; + uint32_t dest_qp_num; + fake_recv_wr_mgr_t recv_wr_mgr; +} fake_qp_t; + +typedef struct fake_srq_s { + struct ibv_srq srq; + fake_recv_wr_mgr_t recv_wr_mgr; +} fake_srq_t; + +typedef struct cqe_queue_s { + uint32_t producer; + uint32_t comsumer; + struct ibv_wc *buff; +} cqe_queue_t; + +typedef struct fake_cq_s { + struct ibv_cq cq; + pthread_mutex_t cqLock; + cqe_queue_t wcq; + list_head_t entry; +} fake_cq_t; + +#define MAX_CQ_ON_CC 128 + +typedef struct { + struct ibv_comp_channel cc; + fake_lock_list_t cq_list; +} fake_cc_t; + + +struct ibv_mr *ibv_reg_umm_page_mr(struct ibv_pd *pd, void *addr, void *knl_addr, size_t length, int access_flag); +int ibv_dereg_umm_page_mr(struct ibv_mr *mr); +int ibv_check_qp(struct ibv_context *context, pid_t pid); +void ibv_dfx_qp_wr_dump(struct ibv_qp *qp); +void ibv_dfx_port_traffic_dump(struct ibv_context *context); +void ibv_dfx_counter(struct ibv_qp *qp); + +int fake_prodce_cqe(fake_cq_t *fcq, struct ibv_wc *wc); +void fake_create_recv_wc(fake_cq_t *fcq, uint32_t qp_num, fake_recv_wr_item_t *item, uint64_t size, uint32_t immData); +void fake_recv_wr_mgr_init(fake_recv_wr_mgr_t *recv_wr_mgr); +__attribute__((constructor)) void fake_ibv_init(void); +void fake_send_event_on_cc(struct ibv_comp_channel *cmc); +fake_recv_wr_item_t *fake_recv_wr_mgr_comsume_item(fake_recv_wr_mgr_t *recv_wr_mgr); +void fake_create_send_wc(fake_qp_t *fqp, struct ibv_send_wr *wr, uint64_t size); +fake_lock_list_t *fake_get_qp_list(void); +uint32_t fake_get_qp_num(void); +void fake_flash_all_recv_wr(fake_qp_t *fqp); + +#define UNREF_PARAM(x) ((void)(x)) + +#ifdef ibv_query_port +#undef ibv_query_port +#endif +static int fake_ibv_query_port(struct ibv_context *context, uint8_t port_num, + struct _compat_ibv_port_attr *port_attr_in) +{ + struct ibv_port_attr *port_attr = (struct ibv_port_attr *)port_attr_in; + port_attr->state = IBV_PORT_ACTIVE; + port_attr->active_mtu = IBV_MTU_4096; + port_attr->gid_tbl_len = 1; /* gid table长度不知道为多少,先设置为1,不行再改 */ + port_attr->lid = 0; /* lid值调试的时候再定 */ + port_attr->active_speed = 0; /* 速率也乱填一个值,应该不会影响LLT功能 */ + port_attr->link_layer = IBV_LINK_LAYER_ETHERNET; + + UNREF_PARAM(context); + UNREF_PARAM(port_num); + + return 0; +} + +#define fake_ibv_query_port(context, port_num, port_attr) fake_ibv_query_port(context, port_num, port_attr) +#ifdef __cplusplus +} +#endif + +const int BT_SIZE = 40u; +const int THREAD_MAX_NAME_LEN = 16u; +const int STR_SIZE = 512u; + +inline std::string DemangleFuncName(const char *str) +{ + size_t size = 0; + int status = 0; + std::string tmpStr; + tmpStr.resize(STR_SIZE); + + if (str == nullptr) { + std::string emptyStr = "empty"; + return emptyStr; + } + + if (1 == sscanf(str, "%*[^(]%*[^_]%255[^)+]", &tmpStr[0])) { + char *tmp = abi::__cxa_demangle(&tmpStr[0], nullptr, &size, &status); + if (tmp) { + std::string result(tmp); + free(tmp); + return result; + } + } + + if (1 == sscanf(str, "%255s", &tmpStr[0])) { + return tmpStr; + } + + return str; +} + +inline void NetBacktrace(uint64_t id) +{ + void **list = nullptr; + char **stacks = nullptr; + int size; + list = new (std::nothrow) void *[BT_SIZE]; + if (list == nullptr) { + printf("Failed to alloc memory for list"); + return; + } + size = backtrace(list, BT_SIZE); + stacks = backtrace_symbols(list, size); + if (stacks != nullptr) { + char *thName = new (std::nothrow) char[THREAD_MAX_NAME_LEN]; + if (thName == nullptr) { + printf("Failed to alloc memory for thName"); + free(stacks); + delete[] list; + return; + } + if (pthread_getname_np(pthread_self(), thName, THREAD_MAX_NAME_LEN) != 0) { + printf("Failed to get the thread name for %lu", pthread_self()); + } else { + printf("%s id: %lu backtrace:", thName, id); + } + delete[] thName; + for (int i = 0; i < size; i++) { + printf("Id(%d) :[%s]\n", i, DemangleFuncName(stacks[i]).c_str()); + } + free(stacks); + } + delete[] list; +} + +#endif /* __FAKE_IBV_H__ */ diff --git a/test/stub/hcom_securec.cpp b/test/stub/hcom_securec.cpp new file mode 100644 index 0000000000000000000000000000000000000000..95476e1c924ecc83b3b43b466afd2c71692b2c77 --- /dev/null +++ b/test/stub/hcom_securec.cpp @@ -0,0 +1,153 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include +#include + +#include "hcom_securec.h" + +namespace ock { +namespace hcom { +#ifndef SECUREC_MEM_MAX_LEN +#define SECUREC_MEM_MAX_LEN 0x7fffffffUL +#endif + +#ifndef SECUREC_STRING_MAX_LEN +#define SECUREC_STRING_MAX_LEN 0x7fffffffUL +#endif + +#define SECUREC_LIKELY(x) __builtin_expect(!!(x), 1) + +#define SECUREC_MEMORY_NO_OVERLAP(dest, src, count) \ + (((src) < (dest) && ((const char *)(src) + (count)) <= (char *)(dest)) || \ + ((dest) < (src) && ((char *)(dest) + (count)) <= (const char *)(src))) + +#define SECUREC_MEMORY_IS_OVERLAP(dest, src, count) \ + (((src) < (dest) && ((const char *)(src) + (count)) > (char *)(dest)) || \ + ((dest) < (src) && ((char *)(dest) + (count)) > (const char *)(src))) + +#define SECUREC_MEMCPY_PARAM_OK(dest, destMax, src, count) \ + (SECUREC_LIKELY((count) <= (destMax) && (dest) != nullptr && (src) != nullptr && \ + (destMax) <= SECUREC_MEM_MAX_LEN && (count) > 0 && SECUREC_MEMORY_NO_OVERLAP((dest), (src), (count)))) + +#define SECUREC_STRCPY_PARAM_OK(strDest, destMax, strSrc) \ + ((destMax) > 0 && (destMax) <= SECUREC_STRING_MAX_LEN && (strDest) != nullptr && (strSrc) != nullptr && \ + (strDest) != (strSrc)) + +#define SECUREC_CALC_STR_LEN(str, maxLen, outLen) \ + do { \ + *(outLen) = strnlen((str), (maxLen)); \ + } while (0) + +#define SECUREC_STRCPY_OPT(dest, src, lenWithTerm) \ + do { \ + memcpy((dest), (src), (lenWithTerm)); \ + } while (0) + +typedef enum { + SEC_EOK = 0, + SEC_EINVAL = 22, + SEC_ERANGE = 34, + SEC_EINVAL_AND_RESET = 150, + SEC_ERANGE_AND_RESET = 162, + SEC_EOVERLAP_AND_RESET = 182, +} MEMCPY_S_CODE; + +inline int SecMemcpyError(void *dest, size_t destMax, const void *src, size_t count) +{ + if (destMax == 0 || destMax > SECUREC_MEM_MAX_LEN) { + return SEC_ERANGE; + } + if (dest == nullptr || src == nullptr) { + if (dest != nullptr) { + (void)memset(dest, 0, destMax); + return SEC_EINVAL_AND_RESET; + } + return SEC_EINVAL; + } + if (count > destMax) { + (void)memset(dest, 0, destMax); + return SEC_ERANGE_AND_RESET; + } + if (SECUREC_MEMORY_IS_OVERLAP(dest, src, count)) { + (void)memset(dest, 0, destMax); + return SEC_EOVERLAP_AND_RESET; + } + /* Count is 0 or dest equal src also ret EOK */ + return SEC_EOK; +} + +int memcpy_s(void *dest, size_t destMax, const void *src, size_t count) +{ + if (SECUREC_MEMCPY_PARAM_OK(dest, destMax, src, count)) { + memcpy(dest, src, count); + return SEC_EOK; + } + /* Meet some runtime violation, return error code */ + return SecMemcpyError(dest, destMax, src, count); +} + +inline int CheckSrcRange(char *strDest, size_t destMax, const char *strSrc) +{ + size_t tmpDestMax = destMax; + const char *tmpSrc = strSrc; + /* Use destMax as boundary checker and destMax must be greater than zero */ + while (*tmpSrc != '\0' && tmpDestMax > 0) { + ++tmpSrc; + --tmpDestMax; + } + if (tmpDestMax == 0) { + strDest[0] = '\0'; + return SEC_ERANGE_AND_RESET; + } + return SEC_EOK; +} + +int strcpy_error(char *strDest, size_t destMax, const char *strSrc) +{ + if (destMax == 0 || destMax > SECUREC_STRING_MAX_LEN) { + return SEC_ERANGE; + } + if (strDest == nullptr || strSrc == nullptr) { + if (strDest != nullptr) { + strDest[0] = '\0'; + return SEC_EINVAL_AND_RESET; + } + return SEC_EINVAL; + } + return CheckSrcRange(strDest, destMax, strSrc); +} + +int strcpy_s(char *strDest, size_t destMax, const char *strSrc) +{ + if (SECUREC_STRCPY_PARAM_OK(strDest, destMax, strSrc)) { + size_t srcStrLen; + SECUREC_CALC_STR_LEN(strSrc, destMax, &srcStrLen); + ++srcStrLen; /* The length include '\0' */ + + if (srcStrLen <= destMax) { + /* Use mem overlap check include '\0' */ + if (SECUREC_MEMORY_NO_OVERLAP(strDest, strSrc, srcStrLen)) { + /* Performance optimization srcStrLen include '\0' */ + SECUREC_STRCPY_OPT(strDest, strSrc, srcStrLen); + return SEC_EOK; + } else { + strDest[0] = '\0'; + return SEC_EOVERLAP_AND_RESET; + } + } + } + return strcpy_error(strDest, destMax, strSrc); +} +} +} \ No newline at end of file diff --git a/test/stub/hcom_securec.h b/test/stub/hcom_securec.h new file mode 100644 index 0000000000000000000000000000000000000000..5e111e0a9979b7b31df8d7a9c1f698d4150501a3 --- /dev/null +++ b/test/stub/hcom_securec.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifndef HCOM_HCOM_SECUREC_H +#define HCOM_HCOM_SECUREC_H + +#include +#include +#include + +namespace ock { +namespace hcom { + +int memcpy_s(void *dest, size_t destMax, const void *src, size_t count); +int strcpy_error(char *strDest, size_t destMax, const char *strSrc); +int strcpy_s(char *strDest, size_t destMax, const char *strSrc); + +} +} +#endif // HCOM_HCOM_SECUREC_H diff --git a/test/tools/hcom_tracer/CMakeLists.txt b/test/tools/hcom_tracer/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..294671fe33d8b82a10307b4a3fc5f22a0973d181 --- /dev/null +++ b/test/tools/hcom_tracer/CMakeLists.txt @@ -0,0 +1,36 @@ +CMAKE_MINIMUM_REQUIRED(VERSION 3.12.1) +PROJECT(htracer) + +SET(CMAKE_CXX_STANDARD 14) +SET(CMAKE_BUILD_TYPE "Release") +add_compile_options( + -Wall + -fstack-protector-strong +) + +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +add_link_options( + -pie + -Wl,-z,relro,-z,now,-z,noexecstack + -s +) + +set(HCOM_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/../../../dist/hcom/include") +set(HCOM_LIB_DIR "${CMAKE_SOURCE_DIR}/../../../dist/hcom/lib") +set(SECUREC_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/../../../dist/hcom_3rdparty/huawei_secure_c/include") +set(SECUREC_LIB_DIR "${CMAKE_SOURCE_DIR}/../../../dist/hcom_3rdparty/huawei_secure_c/lib") + +include_directories("${HCOM_INCLUDE_DIR}" + "${SECUREC_INCLUDE_DIR}") +link_directories("${HCOM_LIB_DIR}" + "${SECUREC_LIB_DIR}") + +set(CMAKE_SKIP_BUILD_RPATH TRUE) + +FILE(GLOB_RECURSE HTRACER_CLI_SRC + "${CMAKE_SOURCE_DIR}/src/*") +ADD_EXECUTABLE(htracer_cli ${HTRACER_CLI_SRC}) +TARGET_LINK_LIBRARIES(htracer_cli PUBLIC pthread boundscheck) +TARGET_INCLUDE_DIRECTORIES(htracer_cli PUBLIC . message common rpc) +install(TARGETS htracer_cli DESTINATION ${TARGET_INSTALL_BIN}/ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE) diff --git a/test/tools/hcom_tracer/src/cmd_handler.cpp b/test/tools/hcom_tracer/src/cmd_handler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13315cd60770f643bcd8ad03280feb74fc11417d --- /dev/null +++ b/test/tools/hcom_tracer/src/cmd_handler.cpp @@ -0,0 +1,229 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include +#include +#include +#include "cmd_handler.h" +#include "cmd_helper.h" +#include "htracer_client.h" +#include "htracer_utils.h" +#include "hcom/hcom_num_def.h" + +using namespace ock::hcom; + +static CmdHelper cmdHelper; + +void HTracerCliHelper::Initialize() +{ + std::map> cmdHandlers = { { "show", std::make_shared() }, + { "reset", std::make_shared() }, + { "conf", std::make_shared() } }; + cmdHelper.UpdateHost(); + cmdHandlers.swap(mCmdHandlers); +} + +SerCode ShowCmdHandler::Handle(std::vector cmds) +{ + /* ! + * -s 1 -i 1 -c 10000 -d /tmp/local + */ + if (HTracerUtils::ExistCmdOption(cmds, "-h")) { + std::cout << "Help info:" << std::endl; + std::cout << HelpInfo() << std::endl; + return SER_OK; + } + + /* + * interval time + */ + uint32_t interval = ParseUintOption(cmds, "-i", 1); + + /* + * latency quantile + */ + double quantile = ParseDoubleOption(cmds, "-tp", -1.0, 0, NN_NO100); + + /* + * round count + */ + uint32_t count = ParseUintOption(cmds, "-n", 1); + + /* + * dump result to file + */ + std::ofstream dump; + auto dumpPath = HTracerUtils::GetCmdOption(cmds, "-d"); + if (!dumpPath.empty()) { + dump.open(dumpPath, std::ios::out | std::ios::app); + } + + std::ostream &out = dump.is_open() ? dump : std::cout; + for (uint32_t i = 0; i < count; ++i) { + out << "Round:" << (i + 1) << " " << HTracerUtils::CurrentTime() << std::endl; + ProcessTraceData(out, quantile); + if (count != 1) { + sleep(interval); + } + } + + return SER_OK; +} + +uint32_t ShowCmdHandler::ParseUintOption(const std::vector &cmds, const std::string &opt, + uint32_t defaultValue) +{ + auto param = HTracerUtils::GetCmdOption(cmds, opt); + return param.empty() ? defaultValue : std::atol(param.c_str()); +} + +double ShowCmdHandler::ParseDoubleOption(const std::vector &cmds, const std::string &opt, + double defaultValue, double min, double max) +{ + auto param = HTracerUtils::GetCmdOption(cmds, opt); + if (param.empty()){ + return defaultValue; + } + double val = std::atof(param.c_str()); + return (val > min && val < max) ? val : defaultValue; +} + +void ShowCmdHandler::ProcessTraceData(std::ostream &out, double quantile) +{ + std::map sumTraceInfoMap; + cmdHelper.UpdateHost(quantile); + auto hostInfo = cmdHelper.GetHostInfo(); + auto &processes = hostInfo.GetAllProcesses(); + for (const auto &process : processes) { + auto &traceInfos = process.second->GetAllTraceInfos(); + for (const auto &traceInfo : traceInfos) { + auto iter = sumTraceInfoMap.find(traceInfo.name); + if (iter == sumTraceInfoMap.end()) { + sumTraceInfoMap.insert(std::make_pair(traceInfo.name, traceInfo)); + } else { + iter->second += traceInfo; + } + } + } + + out << TTraceInfo::HeaderString() << std::endl; + for (const auto &traceInfo : sumTraceInfoMap) { + out << "\t" << traceInfo.second.ToString() << std::endl; + } +} + +std::string ShowCmdHandler::HelpInfo() +{ + std::stringstream ss; + ss << "\t -i print interval. "<< std::endl + << "\t -n number of times. "<< std::endl + << "\t -d dump trace point information. -d /opt/dump.text " << std::endl + << "\t -tp show percentile of latency, need to use \'conf -p\' to enable it first!" << std::endl; + return ss.str(); +} + +SerCode HTracerCliHelper::HandleCmd(std::string cmd) +{ + auto cmds = HTracerUtils::StrSplit(cmd, ' '); + if (cmds.empty()) { + std::cout << "Invalid command!" << std::endl << + "\tshow : show trace information" << std::endl << + "\treset : clear invalid host and reset trace" << std::endl << + "\tconf : config trace" << std::endl << + "\tquit : quit trace" << std::endl << + "\tcommand -h : show help information for command. e.g. show -h" << std::endl; + return SER_ERROR; + } + + std::string cmdType = cmds[0]; + cmds.erase(cmds.begin()); + auto cmdHandlerIt = mCmdHandlers.find(cmdType); + if (cmdHandlerIt == mCmdHandlers.end()) { + std::cout << "Invalid command!" << std::endl << + "\tshow : show trace information" << std::endl << + "\treset : clear invalid host and reset trace" << std::endl << + "\tconf : config trace" << std::endl << + "\tquit : quit trace" << std::endl << + "\tcommand -h : show help information for command. e.g. show -h" << std::endl; + return SER_ERROR; + } + return cmdHandlerIt->second->Handle(cmds); +} + +SerCode ResetCmdHandler::Handle(std::vector cmds) +{ + if (!cmds.empty()) { + std::cout << "Invalid param" << std::endl; + return SER_ERROR; + } + + cmdHelper.ResetTraceInfo(); + + return SER_OK; +} + +std::string ResetCmdHandler::HelpInfo() +{ + return std::string(""); +} + +SerCode ConfCmdHandler::Handle(std::vector cmds) +{ + if (HTracerUtils::ExistCmdOption(cmds, "-h")) { + std::cout << "Help info:" << std::endl; + std::cout << HelpInfo() << std::endl; + return SER_OK; + } + /* enable trace */ + bool enable = true; + auto enableParam = HTracerUtils::GetCmdOption(cmds, "-t"); + if (!enableParam.empty()) { + enable = std::stoi(enableParam); + } + + /* enable tp */ + bool enableTp = false; + auto enableTpParam = HTracerUtils::GetCmdOption(cmds, "-p"); + if (!enableTpParam.empty()) { + enableTp = std::atoi(enableTpParam.c_str()); + } + + /* enable log */ + bool enableLog = false; + auto enableLogParam = HTracerUtils::GetCmdOption(cmds, "-o"); + if (!enableLogParam.empty()) { + enableLog = std::atoi(enableLogParam.c_str()); + } + + /* log path */ + auto logPath = HTracerUtils::GetCmdOption(cmds, "-d"); + if (logPath.size() > NN_NO260) { + std::cout << "Invalid log path param" << std::endl; + logPath = ""; + } + + HandlerConfPara confPara(enable, enableTp, enableLog, logPath); + cmdHelper.EnableTrace(confPara); + return SER_OK; +} + +std::string ConfCmdHandler::HelpInfo() +{ + std::stringstream ss; + ss << "\t -t 1:enable or 0:disable trace, default 1. " << std::endl; + ss << "\t -p 1:enable tp or 0:disable tp, default 0. " << std::endl; + ss << "\t -o 1:enable log or 0:disable log, default 0. " << std::endl; + ss << "\t -d set dump log path, default path /tmp/htrace/log. " << std::endl; + ss << "\t To successfully change the log path, set \"conf -o 0\" first. " << std::endl; + return ss.str(); +} \ No newline at end of file diff --git a/test/tools/hcom_tracer/src/cmd_handler.h b/test/tools/hcom_tracer/src/cmd_handler.h new file mode 100644 index 0000000000000000000000000000000000000000..7fe705d66d0cce5d914340e9f11939219198bf1b --- /dev/null +++ b/test/tools/hcom_tracer/src/cmd_handler.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef CMD_HANDLER +#define CMD_HANDLER + +#include +#include +#include +#include +#include +#include "hcom/hcom_err.h" + +using namespace ock::hcom; + +class CmdHandler { +public: + virtual SerCode Handle(std::vector cmds) = 0; + virtual std::string HelpInfo() = 0; +}; + +class ShowCmdHandler : public CmdHandler { +public: + SerCode Handle(std::vector cmds) override; + + std::string HelpInfo() override; + + uint32_t ParseUintOption(const std::vector &cmds, const std::string &opt, uint32_t defaultValue); + + double ParseDoubleOption(const std::vector &cmds, const std::string &opt, double defaultValue, + double min, double max); + + void ProcessTraceData(std::ostream &out, double quantile); +}; + +class ResetCmdHandler : public CmdHandler { +public: + SerCode Handle(std::vector cmds) override; + + std::string HelpInfo() override; +}; + +class ConfCmdHandler : public CmdHandler { +public: + SerCode Handle(std::vector cmds) override; + + std::string HelpInfo() override; +}; + +class HTracerCliHelper { +public: + void Initialize(); + + SerCode HandleCmd(std::string cmd); + +private: + std::map> mCmdHandlers; +}; +#endif // CMD_HANDLER \ No newline at end of file diff --git a/test/tools/hcom_tracer/src/cmd_helper.cpp b/test/tools/hcom_tracer/src/cmd_helper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c2fe51110803fe36d3b950ead876512da41c978a --- /dev/null +++ b/test/tools/hcom_tracer/src/cmd_helper.cpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include +#include "cmd_helper.h" +#include "htracer_utils.h" +#include "htracer_client.h" + +#define INVALID_SERVICE_ID (0xFFFF) + +void ProcessInfo::UpdateLastTime() +{ + mLastUpdateTime = HTracerUtils::CurrentTime(); +} + +std::string ProcessInfo::ToString() +{ + std::stringstream ss; + ss << StateToString() << std::endl; + ss << TTraceInfo::HeaderString() << std::endl; + time_t rawTime; + time(&rawTime); + for (uint32_t i = 0; i < mTraceInfo.size(); ++i) { + ss << "\t" << mTraceInfo[i].ToString() << std::endl; + } + return ss.str(); +} + +void ProcessInfo::SetTraceInfo(std::vector traceInfo) +{ + traceInfo.swap(mTraceInfo); + UpdateLastTime(); +} + +std::string ProcessInfo::StateToString() +{ + std::stringstream ss; + auto processName = "localhost@" + std::to_string(mPid); + if (mActive) { + ss << " " << processName << "(status: active) "; + } else { + ss << " " << processName << "(status: inactive, out of date from " << mLastUpdateTime << ") "; + } + return ss.str(); +} + +void CmdHelper::UpdateHost(double quantile) +{ + std::lock_guard lock(mMutex); + mHostInfo.UpdateHostInfo(INVALID_SERVICE_ID, quantile); +} + +void CmdHelper::EnableTrace(const HandlerConfPara &confPara) +{ + std::lock_guard lock(mMutex); + mClient.EnableTrace(confPara); +} + +HostInfo CmdHelper::GetHostInfo() +{ + return mHostInfo; +} + +void CmdHelper::ResetTraceInfo() +{ + mHostInfo.ResetTraceInfo(); +} + +std::shared_ptr HostInfo::GetOrCreateProcess(pid_t pid) +{ + auto processIt = mAllProcesses.find(pid); + if (processIt != mAllProcesses.end()) { + return processIt->second; + } + + auto process = std::make_shared(pid); + if (process == nullptr) { + LOG_ERR("failed to malloc process, pid: " << pid); + return nullptr; + } + mAllProcesses[pid] = process; + return process; +} + +SerCode HostInfo::UpdateHostInfo(uint16_t serviceId, double quantile) +{ + HTracerClient client; + pid_t pid = -1; + std::vector traceInfos; + SerCode query_ret = client.Query(serviceId, quantile, traceInfos, pid); + if (query_ret == SER_OK) { + auto process = GetOrCreateProcess(pid); + process->SetTraceInfo(traceInfos); + process->SetActive(true); + mActiveProcess = process; + } else { + for (const auto &processIt : mAllProcesses) { + processIt.second->SetActive(false); + } + mActiveProcess = nullptr; + } + return SER_OK; +} + +SerCode HostInfo::ResetTraceInfo() +{ + HTracerClient client; + client.Reset(); + return SER_OK; +} + +std::string HostInfo::ToString(bool detail) +{ + std::stringstream ss; + if (detail) { + for (const auto &processIt : mAllProcesses) { + auto process = processIt.second; + ss << process->ToString(); + } + } else { + if (IsActive()) { + ss << " " + << "localHost " + << "(status: active) "; + } else { + ss << " " + << "localHost " + << "(status: inactive) "; + } + } + return ss.str(); +} \ No newline at end of file diff --git a/test/tools/hcom_tracer/src/cmd_helper.h b/test/tools/hcom_tracer/src/cmd_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..c612efa753f9410dc64e739338a8e6bf4ef15baf --- /dev/null +++ b/test/tools/hcom_tracer/src/cmd_helper.h @@ -0,0 +1,111 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef CMD_HELPER +#define CMD_HELPER + +#include +#include +#include +#include +#include +#include "htracer_msg.h" +#include "hcom/hcom_err.h" +#include "htracer_client.h" + +class ProcessInfo { +public: + explicit ProcessInfo(pid_t pid) : mPid(pid) {} + + void SetTraceInfo(std::vector traceInfo); + + void SetActive(bool active) + { + mActive = active; + } + + bool IsActive() + { + return mActive; + } + + const std::vector &GetAllTraceInfos() + { + return mTraceInfo; + } + + std::string ToString(); + + std::string StateToString(); + +private: + void UpdateLastTime(); + +private: + bool mActive = true; + pid_t mPid; + std::string mLastUpdateTime = "--:--:--"; + std::vector mTraceInfo; +}; + +class HostInfo { +public: + SerCode UpdateHostInfo(uint16_t serviceId, double quantile = -1.0); + SerCode ResetTraceInfo(); + + void Inactive() + { + mActiveProcess = nullptr; + } + + bool IsActive() + { + return mActiveProcess != nullptr; + } + + std::string ToString(bool detail = false); + + const std::map> &GetAllProcesses() + { + return mAllProcesses; + } + +private: + std::shared_ptr GetOrCreateProcess(pid_t pid); + +private: + std::shared_ptr mActiveProcess = nullptr; + std::map> mAllProcesses; +}; + +class CmdHelper { +public: + void UpdateHost(double quantile = -1.0); + + void EnableTrace(const HandlerConfPara &confPara); + + HostInfo GetHostInfo(); + + void ResetTraceInfo(); + +private: + std::shared_ptr InsertHost(const std::string host); + + std::shared_ptr GetHostInfo(const std::string host); + +private: + HostInfo mHostInfo; + std::mutex mMutex; + HTracerClient mClient; +}; + +#endif // CMD_HELPER \ No newline at end of file diff --git a/test/tools/hcom_tracer/src/htracer_cli.cpp b/test/tools/hcom_tracer/src/htracer_cli.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c87db17fe5d36a7f899f99592dffd055e0bba993 --- /dev/null +++ b/test/tools/hcom_tracer/src/htracer_cli.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include "hcom/hcom_num_def.h" +#include "hcom/hcom_err.h" +#include "rpc_client.h" +#include "htracer_utils.h" +#include "cmd_handler.h" + +using namespace ock::hcom; + +/* ! + * support + * 1. config service trace information support + * 1.1 trace level + * 1.2 trace on/off + * 2. show trace information by service id support + * 3. cross-node query support + */ + +void InvalidParamPrint() +{ + std::cout << "Invalid parameters!"<> "; + getline(std::cin, cmd); + if (cmd == "quit") { + return SER_OK; + } + cliHelper.HandleCmd(cmd); + std::cout << "Execution Done." << std::endl; + } + return SER_OK; +} \ No newline at end of file diff --git a/test/tools/hcom_tracer/src/htracer_client.cpp b/test/tools/hcom_tracer/src/htracer_client.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d7ed93a67b08dd16548576072f9226ac08aeb52c --- /dev/null +++ b/test/tools/hcom_tracer/src/htracer_client.cpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "htracer_client.h" +#include "rpc_client.h" +#include "htracer_msg.h" + +HTracerClient::HTracerClient() +{ + mQueryRecvBuffer = new char[mQueryRecvBufferSize]; + if (mQueryRecvBuffer == nullptr) { + LOG_ERR("failed to malloc query recv buffer"); + return; + } +} + +HTracerClient::~HTracerClient() +{ + if (mQueryRecvBuffer != nullptr) { + delete[] mQueryRecvBuffer; + } +} + +SerCode HTracerClient::Query(uint16_t serviceId, double quantile, std::vector &traceInfos, pid_t &pid) +{ + QueryTraceInfoRequest queryRequest; + queryRequest.serviceId = serviceId; + queryRequest.quantile = quantile; + Message request(&queryRequest, sizeof(queryRequest)); + Message response(mQueryRecvBuffer, mQueryRecvBufferSize); + if (mClient.SyncCall(request, response) != SER_OK) { + return SER_ERROR; + } + + auto queryResponse = reinterpret_cast(mQueryRecvBuffer); + for (uint32_t i = 0; i < queryResponse->traceInfoNum; ++i) { + auto &traceInfo = queryResponse->traceInfo[i]; + traceInfos.push_back(traceInfo); + } + pid = queryResponse->pid; + return SER_OK; +} + +SerCode HTracerClient::Reset() +{ + ResetTraceInfoRequest resetRequest; + ResetTraceInfoResponse resetResponse; + Message request(&resetRequest, sizeof(resetRequest)); + Message response(&resetResponse, sizeof(resetResponse)); + if (mClient.SyncCall(request, response) != SER_OK) { + return SER_ERROR; + } + return SER_OK; +} + +SerCode HTracerClient::EnableTrace(const HandlerConfPara &confPara) +{ + EnableTraceRequest enableRequest; + enableRequest.enable = confPara.enable; + enableRequest.enableTp = confPara.enableTp; + enableRequest.enableLog = confPara.enableLog; + // HandlerConfPara and EnableTraceRequest logPath is same length. + if (strcpy_s(enableRequest.logPath, sizeof(enableRequest.logPath), confPara.logPath) != 0) { + return SER_ERROR; + } + EnableTraceResponse enableResponse; + Message request(&enableRequest, sizeof(enableRequest)); + Message response(&enableResponse, sizeof(enableResponse)); + if (mClient.SyncCall(request, response) != SER_OK) { + return SER_ERROR; + } + return SER_OK; +} \ No newline at end of file diff --git a/test/tools/hcom_tracer/src/htracer_client.h b/test/tools/hcom_tracer/src/htracer_client.h new file mode 100644 index 0000000000000000000000000000000000000000..b6825e9e410aeb42dc5290964c31680172f093e6 --- /dev/null +++ b/test/tools/hcom_tracer/src/htracer_client.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HTRACE_CLIENT_H +#define HTRACE_CLIENT_H + +#include +#include +#include +#include "hcom/hcom_err.h" +#include "htracer_msg.h" +#include "rpc_client.h" + +#define MAX_SERVICE_NUM (256) +#define MAX_INNER_ID_NUM (2) + +class HTracerClient { +public: + HTracerClient(); + + ~HTracerClient(); + + SerCode Query(uint16_t serviceId, double quantile, std::vector &traceInfos, pid_t &pid); + + SerCode Reset(); + + SerCode EnableTrace(const HandlerConfPara &confPara); + +private: + RpcClient mClient; + char *mQueryRecvBuffer = nullptr; + const static uint32_t mQueryRecvBufferSize = + sizeof(QueryTraceInfoResponse) + sizeof(TTraceInfo) * MAX_INNER_ID_NUM * MAX_SERVICE_NUM; +}; + + +#endif // HTRACE_CLIENT_H \ No newline at end of file diff --git a/test/tools/hcom_tracer/src/htracer_log.h b/test/tools/hcom_tracer/src/htracer_log.h new file mode 100644 index 0000000000000000000000000000000000000000..68a674f4dfdbcaa7f4c6feb78284eacfef2d3b8f --- /dev/null +++ b/test/tools/hcom_tracer/src/htracer_log.h @@ -0,0 +1,95 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HTRACER_LOG_H +#define HTRACER_LOG_H + +#include +#include +#include +#include +#include + +using ExternalLog = void (*)(int level, const char *msg); + +class Logger { +public: + Logger() {} + + static std::shared_ptr Instance() + { + static auto logger = std::make_shared(); + return logger; + } + + inline void SetExternalLogFunction(ExternalLog func) + { + mLogFunc = func; + } + + inline void Log(int32_t level, const std::ostringstream &oss) + { + if (mLogFunc != nullptr) { + mLogFunc(level, oss.str().c_str()); + return; + } + static const char *levelName[] = {"DEBUG", "INFO", "WARN", "ERROR"}; + struct timeval tv {}; + char strTime[24]; + + gettimeofday(&tv, nullptr); + strftime(strTime, sizeof strTime, "%Y-%m-%d %H:%M:%S.", localtime(&tv.tv_sec)); + + std::cout << "[" << strTime << tv.tv_usec << "]" + << "[" << levelName[level] << "]" << oss.str() << std::endl; + } + +private: + ExternalLog mLogFunc = nullptr; +}; + +#ifndef __LOG_FILENAME__ +#define __LOG_FILENAME__ (strrchr(__FILE__, '/') ? strrchr(__FILE__, '/') + 1 : __FILE__) +#endif + +#ifndef LOG +#define LOG(level, message) \ + do { \ + std::ostringstream oss; \ + oss << "[ " << __LOG_FILENAME__ << ":" << __LINE__ << " ][" << __FUNCTION__ << "]" << (message); \ + Logger::Instance()->Log(level, oss); \ + } while (0) +#endif + +#ifdef LOG_ENABLED +#ifndef LOG_ERR +#define LOG_ERR(message) LOG(3, message) +#endif + +#ifndef LOG_WARN +#define LOG_WARN(message) LOG(2, message) +#endif + +#ifndef LOG_INFO +#define LOG_INFO(message) LOG(1, message) +#endif + +#ifndef LOG_DEBUG +#define LOG_DEBUG(message) LOG(0, message) +#endif +#else +#define LOG_ERR(message) +#define LOG_WARN(message) +#define LOG_INFO(message) +#define LOG_DEBUG(message) +#endif +#endif // HTRACER_LOG_H \ No newline at end of file diff --git a/test/tools/hcom_tracer/src/htracer_msg.h b/test/tools/hcom_tracer/src/htracer_msg.h new file mode 100644 index 0000000000000000000000000000000000000000..f9f63e63feaafbd6569277e06b0c1b381ea4b49a --- /dev/null +++ b/test/tools/hcom_tracer/src/htracer_msg.h @@ -0,0 +1,290 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HTRACE_MSG_H +#define HTRACE_MSG_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include "rpc_msg.h" +#include "htracer_log.h" +#include "hcom/hcom_num_def.h" +#include "hcom/hcom_err.h" +#include "securec.h" + +#define TRACE_INFO_MAX_LEN 63 + +using namespace ock::hcom; + +constexpr uint32_t LOG_PATH_LENGTH = 260; + +enum MessageOpcode { + TRACE_OP_PING = 0, + TRACE_OP_QUERY = 1, + TRACE_OP_ENABLE_TRACE = 2, + TRACE_OP_RESET = 3 +}; + +struct HandlerConfPara { + bool enable; // enable trace + bool enableTp; // enable tp + bool enableLog; // enable log + char reserved[1]; + char logPath[LOG_PATH_LENGTH]; + HandlerConfPara(bool enable1, bool enable2, bool enable3, const std::string &path) + : enable(enable1), enableTp(enable2), enableLog(enable3) + { + if (path.size() < sizeof(logPath)) { + strcpy_s(logPath, path.length() + 1, path.c_str()); + } + } +}; + +struct TTraceInfo { + char name[TRACE_INFO_MAX_LEN + 1] = {0}; + uint64_t begin = 0; + uint64_t goodEnd = 0; + uint64_t badEnd = 0; + uint64_t min = UINT64_MAX; + uint64_t max = 0; + uint64_t total = 0; + double latencyQuentile = 0.0; + + explicit TTraceInfo(const char *name) + { + strncpy_s(this->name, TRACE_INFO_MAX_LEN + 1, name, TRACE_INFO_MAX_LEN); + } + + void operator += (const TTraceInfo &other) + { + begin += other.begin; + goodEnd += other.goodEnd; + badEnd += other.badEnd; + if (min >= other.min) { + min = other.min; + } + if (max <= other.max) { + max = other.max; + } + total += other.total; + } + + enum TracePointTimeUnit { + NANO_SECOND, + MICRO_SECOND, + MILLI_SECOND, + SECOND, + TP_TIME_UNIT + }; + + std::string ToString(TracePointTimeUnit unit = MICRO_SECOND) const + { + static uint64_t TIME_UNIT_STEP[TP_TIME_UNIT] = { + 1, + NN_NO1000, + NN_NO1000000, + NN_NO1000000000 + }; + + static std::string TIME_UNIT_NAME[TP_TIME_UNIT] = { + "ns", + "us", + "ms", + "s" + }; + std::string str; + std::ostringstream os(str); + os.flags(std::ios::fixed); + os.precision(NN_NO3); + auto unitStep = TIME_UNIT_STEP[unit]; + auto unitName = TIME_UNIT_NAME[unit]; + os << "[" << std::left << std::setw(NN_NO50) << name << "]" + << "\t" << std::left << std::setw(NN_NO15) << begin << "\t" << std::left << std::setw(NN_NO15) << goodEnd << + "\t" << std::left << std::setw(NN_NO15) << badEnd << "\t" << std::left << std::setw(NN_NO15) << + ((begin > goodEnd - badEnd) ? (begin - goodEnd - badEnd) : 0) << "\t" << std::left << std::setw(NN_NO15) << + (min == UINT64_MAX ? 0 : ((double)min / unitStep)) << "\t" << std::left << std::setw(NN_NO15) << + (double)max / unitStep << "\t" << std::left << std::setw(NN_NO15) << + (goodEnd == 0 ? 0 : (double)total / goodEnd / unitStep) << "\t" << std::left << std::setw(NN_NO15) << + (double)total / unitStep << "\t" << std::left << std::setw(NN_NO15) << + (latencyQuentile > 0 ? std::to_string(latencyQuentile) : "OFF"); + return os.str(); + } + + static std::string HeaderString() + { + std::stringstream ss; + ss << "\t[" << std::left << std::setw(NN_NO50) << "TP_NAME" + << "]" + << "\t" << std::left << std::setw(NN_NO15) << "TOTAL" + << "\t" << std::left << std::setw(NN_NO15) << "SUCCESS" + << "\t" << std::left << std::setw(NN_NO15) << "FAILURE" + << "\t" << std::left << std::setw(NN_NO15) << "UNFINISHED" + << "\t" << std::left << std::setw(NN_NO15) << "MIN(us)" + << "\t" << std::left << std::setw(NN_NO15) << "MAX(us)" + << "\t" << std::left << std::setw(NN_NO15) << "AVG(us)" + << "\t" << std::left << std::setw(NN_NO15) << "TOTAL(us)" + << "\t" << std::left << std::setw(NN_NO15) << "TPX(us)"; + return ss.str(); + } +}; + +struct PingRequest : public MessageHeader { + PingRequest() : MessageHeader(TRACE_OP_PING) {} +}; + +struct PingResponse : public MessageHeader { + PingResponse() : MessageHeader(TRACE_OP_PING) {} + pid_t pid; + + static SerCode BuildMessage(Message &message) + { + uint32_t bodySize = sizeof(pid_t); + uint32_t messageSize = sizeof(MessageHeader) + bodySize; + void *messageData = malloc(messageSize); + auto pingResponse = reinterpret_cast(messageData); + if (messageData == nullptr) { + LOG_ERR("failed to malloc message data, size:" << messageSize); + return SER_ERROR; + } + pingResponse->version = VERSION; + pingResponse->magicCode = MAGIC_CODE; + pingResponse->crc = 0; + pingResponse->opcode = TRACE_OP_QUERY; + pingResponse->bodySize = bodySize; + pingResponse->pid = getpid(); + message.SetData(messageData); + message.SetSize(messageSize); + return SER_OK; + } +}; + +struct ResetTraceInfoRequest : public MessageHeader { + ResetTraceInfoRequest() : MessageHeader(TRACE_OP_RESET) {} +}; + +struct ResetTraceInfoResponse : public MessageHeader { + ResetTraceInfoResponse() : MessageHeader(TRACE_OP_RESET) {} + + static SerCode BuildMessage(Message &message) + { + uint32_t messageSize = sizeof(ResetTraceInfoResponse); + void *messageData = malloc(messageSize); + if (messageData == nullptr) { + LOG_ERR("failed to malloc message data, size:" << messageSize); + return SER_ERROR; + } + bzero(messageData, messageSize); + + // fill message header. + auto queryResponse = reinterpret_cast(messageData); + queryResponse->version = VERSION; + queryResponse->magicCode = MAGIC_CODE; + queryResponse->crc = 0; + queryResponse->opcode = TRACE_OP_RESET; + queryResponse->bodySize = 0; + + message.SetData(messageData); + message.SetSize(messageSize); + + return SER_OK; + } +}; + +struct EnableTraceRequest : public MessageHeader { + bool enable; + bool enableTp; + bool enableLog; + char reserved[1]; + char logPath[LOG_PATH_LENGTH]; + EnableTraceRequest() : MessageHeader(TRACE_OP_ENABLE_TRACE) {} +}; + +struct EnableTraceResponse : public MessageHeader { + EnableTraceResponse() : MessageHeader(TRACE_OP_ENABLE_TRACE) {} + + static SerCode BuildMessage(Message &message) + { + uint32_t messageSize = sizeof(EnableTraceResponse); + void *messageData = malloc(messageSize); + if (messageData == nullptr) { + LOG_ERR("failed to malloc message data, size:" << messageSize); + return SER_ERROR; + } + bzero(messageData, messageSize); + + // fill message header. + auto queryResponse = reinterpret_cast(messageData); + queryResponse->version = VERSION; + queryResponse->magicCode = MAGIC_CODE; + queryResponse->crc = 0; + queryResponse->opcode = TRACE_OP_ENABLE_TRACE; + queryResponse->bodySize = 0; + + message.SetData(messageData); + message.SetSize(messageSize); + + return SER_OK; + } +}; + +struct QueryTraceInfoRequest : public MessageHeader { + uint16_t serviceId; + double quantile; + QueryTraceInfoRequest() : MessageHeader(TRACE_OP_QUERY) {} +}; + +struct QueryTraceInfoResponse : public MessageHeader { + uint32_t traceInfoNum = 0; + pid_t pid; + TTraceInfo traceInfo[0]; + + QueryTraceInfoResponse() : MessageHeader(TRACE_OP_QUERY) {} + + static SerCode BuildMessage(const std::vector &tTranceInfos, Message &message) + { + uint32_t bodySize = sizeof(uint32_t) + sizeof(pid_t) + sizeof(TTraceInfo) * tTranceInfos.size(); + uint32_t messageSize = sizeof(MessageHeader) + bodySize; + void *messageData = malloc(messageSize); + if (messageData == nullptr) { + LOG_ERR("failed to malloc message data, size:" << messageSize); + return SER_ERROR; + } + bzero(messageData, messageSize); + + // fill message header. + auto queryResponse = reinterpret_cast(messageData); + queryResponse->version = VERSION; + queryResponse->magicCode = MAGIC_CODE; + queryResponse->crc = 0; + queryResponse->opcode = TRACE_OP_QUERY; + queryResponse->bodySize = bodySize; + queryResponse->pid = getpid(); + queryResponse->traceInfoNum = tTranceInfos.size(); + + // file message body. + int i = 0; + for (const auto &info : tTranceInfos) { + queryResponse->traceInfo[i++] = info; + } + message.SetData(messageData); + message.SetSize(messageSize); + return SER_OK; + } +}; + +#endif // HTRACE_MSG_H \ No newline at end of file diff --git a/test/tools/hcom_tracer/src/htracer_utils.h b/test/tools/hcom_tracer/src/htracer_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..ed7c0f450c3b1e21e0b4626dd1957beb51911338 --- /dev/null +++ b/test/tools/hcom_tracer/src/htracer_utils.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef HTRACE_UTILS_H +#define HTRACE_UTILS_H + +#include "securec.h" +#include "hcom/hcom_num_def.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ock { +namespace hcom { +class HTracerUtils { +public: + static std::vector StrSplit(const std::string &str, char delim) + { + std::vector res; + std::string::size_type start = 0; + while (true) { + auto pos = str.find(delim, start); + if (pos == std::string::npos) { + res.push_back(str.substr(start)); + break; + } + res.push_back(str.substr(start, pos - start)); + start = pos + 1; + } + return res; + } + + static std::string GetCmdOption(char **begin, char **end, const std::string &option) + { + char **itr = std::find(begin, end, option); + if (itr != end && ++itr != end) { + return *itr; + } + return ""; + } + + static bool ExistCmdOption(std::vector cmds, const std::string &option) + { + auto itr = std::find(cmds.begin(), cmds.end(), option); + return itr != cmds.end(); + } + + static std::string GetCmdOption(std::vector cmds, const std::string &option) + { + auto itr = std::find(cmds.begin(), cmds.end(), option); + if (itr != cmds.end() && ++itr != cmds.end()) { + return *itr; + } + return ""; + } + + static std::string CurrentTime() + { + time_t rawTime; + time(&rawTime); + auto tmInfo = localtime(&rawTime); + std::stringstream ss; + ss << std::setfill('0') << std::setw(NN_NO4) << std::right << (NN_NO1900 + tmInfo->tm_year) << "-" << + std::setfill('0') << std::setw(NN_NO2) << std::right << (NN_NO1 + tmInfo->tm_mon) << "-" << + std::setfill('0') << std::setw(NN_NO2) << std::right << tmInfo->tm_mday << " " << std::setfill('0') << + std::setw(NN_NO2) << std::right << tmInfo->tm_hour << ":" << std::setfill('0') << std::setw(NN_NO2) << + std::right << tmInfo->tm_min << ":" << std::setfill('0') << std::setw(NN_NO2) << std::right << + tmInfo->tm_sec; + return ss.str(); + } +}; +} +} +#endif // HTRACE_UTILS_H \ No newline at end of file diff --git a/test/tools/hcom_tracer/src/rpc_client.cpp b/test/tools/hcom_tracer/src/rpc_client.cpp new file mode 100644 index 0000000000000000000000000000000000000000..48b175950931e1080bf78945830a94445d7e1baf --- /dev/null +++ b/test/tools/hcom_tracer/src/rpc_client.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "rpc_client.h" +#include "htracer_log.h" +#include "htracer_utils.h" + +std::string RpcClient::serverName = "udx_server"; + +SerCode RpcClient::SyncCall(const Message &request, Message &response) +{ + if (!MessageValidator::Validate(request)) { + LOG_ERR("request message is invalidate"); + return SER_ERROR; + } + + int32_t sockFd = Connect(); + if (sockFd == -1) { + LOG_ERR("failed to connect, please check server is available"); + return SER_ERROR; + } + + if (::send(sockFd, request.GetData(), request.GetSize(), 0) == -1) { + LOG_ERR("failed to send message"); + ::close(sockFd); + return SER_ERROR; + } + + if (::recv(sockFd, response.GetData(), response.GetSize(), 0) == -1) { + LOG_ERR("failed to receive message"); + ::close(sockFd); + return SER_ERROR; + } + + if (!MessageValidator::Validate(response)) { + LOG_ERR("response message is invalidate"); + ::close(sockFd); + return SER_ERROR; + } + ::close(sockFd); + return SER_OK; +} + +int32_t RpcClient::Connect() +{ + std::string abstractSockName(1, '\0'); + abstractSockName += RpcClient::serverName; + + struct sockaddr_un un; + auto ret = memset_s(&un, sizeof(un), 0, sizeof(un)); + if (ret != 0) { + LOG_ERR("failed to memset_s sockaddr un"); + return -1; + } + un.sun_family = AF_UNIX; + ret = memcpy_s(un.sun_path, abstractSockName.length() + 1, abstractSockName.c_str(), abstractSockName.length() + 1); + if (ret != 0) { + LOG_ERR("failed to memcpy_s to sun_path"); + return -1; + } + + int32_t sockFd = socket(AF_UNIX, SOCK_STREAM, 0); + if (sockFd == -1) { + LOG_ERR("failed to create connection socket"); + return -1; + } + + if (connect(sockFd, reinterpret_cast(&un), sizeof(un)) < 0) { + LOG_ERR("connect failed"); + close(sockFd); + return -1; + } + + return sockFd; +} \ No newline at end of file diff --git a/test/tools/hcom_tracer/src/rpc_client.h b/test/tools/hcom_tracer/src/rpc_client.h new file mode 100644 index 0000000000000000000000000000000000000000..c1766b73c215e216e4bb099cd8c185a6e21def4f --- /dev/null +++ b/test/tools/hcom_tracer/src/rpc_client.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef RPC_CLIENT_H +#define RPC_CLIENT_H + +#include +#include "hcom/hcom_err.h" +#include "rpc_msg.h" + +using namespace ock::hcom; + +class RpcClient { +public: + SerCode SyncCall(const Message &request, Message &response); + static std::string serverName; + +private: + int32_t Connect(); +}; + +#endif // RPC_CLIENT_H \ No newline at end of file diff --git a/test/tools/hcom_tracer/src/rpc_msg.h b/test/tools/hcom_tracer/src/rpc_msg.h new file mode 100644 index 0000000000000000000000000000000000000000..91f30bc95202ef20eb8d183dff79ae9be7d7710d --- /dev/null +++ b/test/tools/hcom_tracer/src/rpc_msg.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef RPC_MSG_H +#define RPC_MSG_H +#include + +#define VERSION 1 +#define MAGIC_CODE 0xABABABAB +#define INVALID_OPCODE 0xFFFFFFFF + +struct MessageHeader { + uint32_t version = VERSION; + uint32_t magicCode = MAGIC_CODE; + uint32_t crc = 0; + uint32_t opcode = INVALID_OPCODE; + uint32_t bodySize = 0; + uint32_t reserved = 0; + explicit MessageHeader(uint32_t opcode) : opcode(opcode) {} +}; + +class Message { +public: + Message(void *data, uint32_t dataSize) : mData(data), mSize(dataSize) {} + Message() : Message(nullptr, 0) {} + + void *GetData() const + { + return mData; + } + + void SetData(void *data) + { + mData = data; + } + + uint32_t GetSize() const + { + return mSize; + } + + void SetSize(uint32_t size) + { + mSize = size; + } + + const MessageHeader *GetHeader() const + { + if (mData == nullptr) { + return nullptr; + } + return reinterpret_cast(mData); + } + +private: + void *mData = nullptr; + uint32_t mSize = 0; +}; + +class MessageValidator { +public: + static bool Validate(const Message &message) + { + void *messageData = message.GetData(); + uint32_t messageSize = message.GetSize(); + if (messageData == nullptr || messageSize == 0) { + return false; + } + + MessageHeader *header = reinterpret_cast(messageData); + if (header->version != VERSION || header->magicCode != MAGIC_CODE || + header->bodySize + sizeof(MessageHeader) > messageSize) { + return false; + } + return true; + } +}; +#endif // RPC_MSG_H \ No newline at end of file diff --git a/test/tools/perf_test/CMakeLists.txt b/test/tools/perf_test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..6c356ebed7d809974f4a7cdca7e3c9b2346418e6 --- /dev/null +++ b/test/tools/perf_test/CMakeLists.txt @@ -0,0 +1,60 @@ +# +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +# + +cmake_minimum_required(VERSION 3.14.1) +project(HCOM_PERF_TEST C CXX) +set(CMAKE_SKIP_RPATH TRUE) + +if (NOT EXISTS ${HCOM_INCLUDE_DIR}) + set(HCOM_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/../../dist/hcom/include") + message(INFO "-- HCOM_INCLUDE_DIR is empty, use default value(${HCOM_INCLUDE_DIR})") +endif() + +if (NOT EXISTS ${HCOM_LIB_DIR}) + set(HCOM_LIB_DIR "${CMAKE_SOURCE_DIR}/../../dist/hcom/lib") + message(INFO "-- HCOM_LIB_DIR is empty, use default value(${HCOM_LIB_DIR})") +endif() + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE release) +endif() +message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") + +include_directories("${HCOM_INCLUDE_DIR}") +link_directories("${HCOM_LIB_DIR}") +include_directories("${CMAKE_SOURCE_DIR}") + +if (${CMAKE_BUILD_TYPE} MATCHES "release") + set(CXX_FLAGS + -Wno-format-overflow + -O3 + -Wno-unused-function + -fPIE -pie + -fstack-protector-strong + -Wl,-z,relro,-z,now,-z,noexecstack + -s + ) +else () + set(CXX_FLAGS + -Wall + -Wno-format-overflow + -Wno-unused-function + -fPIE -pie + -fstack-protector-strong + -Wl,-z,relro,-z,now,-z,noexecstack + -s + ) +endif () + +file(GLOB_RECURSE PERF_TEST_SOURCES + "${CMAKE_SOURCE_DIR}/hcom_perf_test.cpp" + "${CMAKE_SOURCE_DIR}/test_case/service_v2/*.cpp" + "${CMAKE_SOURCE_DIR}/test_case/transport/*.cpp" + "${CMAKE_SOURCE_DIR}/report/*.cpp" + "${CMAKE_SOURCE_DIR}/common/*.cpp") + +set(HCOM_PERF_TOOL hcom_perf) +add_executable(${HCOM_PERF_TOOL} ${PERF_TEST_SOURCES}) +target_link_libraries(${HCOM_PERF_TOOL} hcom_static pthread rt dl) +target_link_libraries(${HCOM_PERF_TOOL} hcom_static pthread boundscheck) diff --git a/test/tools/perf_test/README.md b/test/tools/perf_test/README.md new file mode 100644 index 0000000000000000000000000000000000000000..102d94d709001c0657e71d8b41f2e63288e845fa --- /dev/null +++ b/test/tools/perf_test/README.md @@ -0,0 +1,89 @@ +## 1 简介 + +`hcom_perf`是一个性能测试工具,类似 [Infiniband Verbs Performance Tests](https://github.com/linux-rdma/perftest/tree/master)。 + +## 2 编译 + +```shell +# 下载hcom源码(包含hcom_perf工具) +$ git clone --recurse-submodules + +# 编译hcom(hcom_perf编译依赖libhcomstatic.a) +$ cd hcom +$ bash build.sh + +# 编译hcom_perf,注意测试性能要求hcom_perf和hcom都用release版本 +$ cd test/tools/perf_test +$ mkdir build && cd build + +# 头文件默认使用是/output/hcom/include, 如果需要更改, 可通过环境变量HCOM_INCLUDE_DIR设置。 +# 库文件默认使用是/output/hcom/lib, 如果需要更改, 可通过环境变量HCOM_LIB_DIR设置。 +# 例如, cmake -DHCOM_INCLUDE_DIR=/usr/include/hcom -DHCOM_LIB_DIR=/usr/local/lib/hcom .. +$ cmake -DCMAKE_BUILD_TYPE=release .. +$ make -j8 +``` + +## 3 使用 + +```shell +#rdma +# server端(假设server端IP为10.10.1.63) +./hcom_perf --case transport_send_lat --role server --protocol rdma -i 10.10.1.63 -n 1000 --all + +# client端 +./hcom_perf --case transport_send_lat --role client --protocol rdma -i 10.10.1.63 -n 1000 --all + +#service UBC +# server端(假设server端IP为10.10.1.63) +./hcom_perf --case service_send_lat --role server --protocol UBC -i 10.10.1.63 -n 1000 --all + +# client端 +./hcom_perf --case service_send_lat --role client --protocol UBC -i 10.10.1.63 -n 1000 --all + +#server 端输入 “q” 退出 + +# 如果有需要,可以执行如下命令,调整HCOM打印级别 +export HCOM_SET_LOG_LEVEL=3 +``` + + + +> 注意事项: +> +> 1. 性能测试时,`hcom_perf`工具和`hcom`都要使用release版本。 +> 2. 默认不绑核,绑定网卡所在的`numa`中的某个核效率最高。 +> 3. 参考`ibv`或者`urma`的`perf`场景,本工具也仅支持单线程。 +> 4. 调整`hcom`日志打印级别(`export HCOM_SET_LOG_LEVEL=3`),可以精简打印。 +> 5. 通过环境变量`HCOM_PERF_TEST_LOG_LEVEL`可以调整`hcom_perf`的打印级别,`0/1/2/3`依次对应`debug/info/warning/error`日志级别。 + + + +## 4 开发指南 + +> 注意: +> +> ​ 本章节仅针对`hcom_perf`开发人员。 + +开发过程中,要注意如下事项: + +1. 关键路径(发送、接收)逻辑要尽可能简单和高效。 +2. 传输层或者服务层创建、初始化及启动`hcom`实例有较多公共逻辑,建议将公共代码放到*_helper辅助类中,避免重复代码。 + + + +遗留如下问题待处理: + +1. 带宽的统计方式及打印,需要参考`ibv`或`urma`做实现和验证。 +2. 仅实现和验证了“传输层-RDMA协议-双边发送时延”,其它场景需要进一步开发和验证。 +3. 后续可考虑移除role参数,内部通过其它方式判断client或者server。如完成初始建链后,client可以向sever发送一条消息(包大小、测试次数等),以实现简化server的参数输入。 +4. 使用`ibv`测试工具时,client和server均会打印结果。`hcom—perf`当前仅client端打印的耗时情况,如果需要server端也统计和打印结果,每次测试的开始和结束,需要client和server额外交互控制信息。 + + + +## 参考资料 + +1. [Infiniband Verbs Performance Tests](https://github.com/linux-rdma/perftest/tree/master) +2. [Infiniband Verbs Performance Test 时延统计](https://github.com/linux-rdma/perftest/blob/master/src/perftest_parameters.c#L4118) +3. [URMA Performance Tests](https://codehub-y.huawei.com/nStack/nStack/Cloud/UBus/urma/files?ref=master&filePath=code%2Ftools%2Furma_perftest) +4. [URMA perftest运行与开发](https://wiki.huawei.com/domains/13898/wiki/30209/WIKI202412045316527) + diff --git a/test/tools/perf_test/common/perf_test_common.h b/test/tools/perf_test/common/perf_test_common.h new file mode 100644 index 0000000000000000000000000000000000000000..700149cbc3417065ee2b9526dd54a180c5509524 --- /dev/null +++ b/test/tools/perf_test/common/perf_test_common.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_PERF_TEST_COMMON_H +#define HCOM_PERF_TEST_COMMON_H + +#include "common/perf_test_config.h" + +namespace hcom { +namespace perftest { +constexpr uint64_t MESSAGE_SIZE_BASE = 2; +constexpr uint64_t UB_MAX_SIZE = 65536; + +class PerfTestContext { +public: + uint64_t tposted[MAX_ITERATIONS] = {0}; + uint64_t cnt = 0; + uint64_t mIterations = 0; + uint32_t mSize = 0; + uint64_t totrcnt = 0; +}; + +class MrInfo { +public: + uintptr_t lAddress = 0; + uint64_t lKey = 0; + uint32_t size = 0; +}; + +} +} + +#endif diff --git a/test/tools/perf_test/common/perf_test_config.cpp b/test/tools/perf_test/common/perf_test_config.cpp new file mode 100644 index 0000000000000000000000000000000000000000..00336dc4047cb26683359f42ba26947fb33ee902 --- /dev/null +++ b/test/tools/perf_test/common/perf_test_config.cpp @@ -0,0 +1,199 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include +#include + +#include "common/perf_test_logger.h" +#include "common/perf_test_utils.h" +#include "common/perf_test_config.h" + +namespace hcom { +namespace perftest { +static const std::vector> gPerfTestType = { + { "TRANSPORT_SEND_LAT", PERF_TEST_TYPE::TRANSPORT_SEND_LAT }, + { "TRANSPORT_SEND_BW", PERF_TEST_TYPE::TRANSPORT_SEND_BW }, + { "TRANSPORT_READ_LAT", PERF_TEST_TYPE::TRANSPORT_READ_LAT }, + { "TRANSPORT_READ_BW", PERF_TEST_TYPE::TRANSPORT_READ_BW }, + { "TRANSPORT_WRITE_LAT", PERF_TEST_TYPE::TRANSPORT_WRITE_LAT }, + { "TRANSPORT_WRITE_BW", PERF_TEST_TYPE::TRANSPORT_WRITE_BW }, + { "SERVICE_SEND_LAT", PERF_TEST_TYPE::SERVICE_SEND_LAT }, + { "SERVICE_SEND_BW", PERF_TEST_TYPE::SERVICE_SEND_BW }, + { "SERVICE_READ_LAT", PERF_TEST_TYPE::SERVICE_READ_LAT }, + { "SERVICE_READ_BW", PERF_TEST_TYPE::SERVICE_READ_BW }, + { "SERVICE_WRITE_LAT", PERF_TEST_TYPE::SERVICE_WRITE_LAT }, + { "SERVICE_WRITE_BW", PERF_TEST_TYPE::SERVICE_WRITE_BW }, +}; + +bool PerfTestConfig::SetType(const std::string &cmd) +{ + for (const auto &item : gPerfTestType) { + if (PerfTestUtils::IsStringCaseInsensitiveEqual(cmd, item.first)) { + mType = item.second; + return true; + } + } + LOG_ERROR("Get perftest type for cmd(" << cmd << ") failed!"); + mType = PERF_TEST_TYPE::UNKNOWN; + return false; +} + +static const std::vector> gPerfTestProtocol = { + { "RDMA", ock::hcom::UBSHcomNetDriverProtocol::RDMA }, { "TCP", ock::hcom::UBSHcomNetDriverProtocol::TCP }, + { "SHM", ock::hcom::UBSHcomNetDriverProtocol::SHM }, { "UBC", ock::hcom::UBSHcomNetDriverProtocol::UBC }, +}; + +bool PerfTestConfig::SetProtocol(const std::string &cmd) +{ + for (const auto &item : gPerfTestProtocol) { + if (PerfTestUtils::IsStringCaseInsensitiveEqual(cmd, item.first)) { + mProtocol = item.second; + return true; + } + } + LOG_ERROR("Get perftest protocol for cmd(" << cmd << ") failed!"); + mProtocol = ock::hcom::UBSHcomNetDriverProtocol::UNKNOWN; + return false; +} + +PERF_TEST_TYPE PerfTestConfig::GetType() const +{ + return mType; +} + +static void HelpInfo(const char *argv0) +{ + std::cout << "[example]" << std::endl; + std::cout << "hcom_perf --case transport_send_lat --role server --protocol rdma -i 10.10.1.63 -n 1000 -d 0 --all"; + std::cout << std::endl; + std::cout << "hcom_perf --case transport_send_lat --role client --protocol rdma -i 10.10.1.63 -n 1000 -d 0 --all"; + std::cout << std::endl; +} + +void PerfTestConfig::Print() +{ + LOG_DEBUG("mIsServer = " << mIsServer); + LOG_DEBUG("mIterations = " << mIterations); + LOG_DEBUG("mSize = " << mSize); + LOG_DEBUG("oobIp = " << mOobIp); + LOG_DEBUG("oobPort = " << mOobPort); + LOG_DEBUG("ipSeg = " << mIpMask); + LOG_DEBUG("protocol = " << mProtocol); + LOG_DEBUG("mIsTestAllSize = " << mIsTestAllSize); + LOG_DEBUG("cpuId = " << mCpuId); +} + +PerfTestConfig::PerfTestConfig() +{ + mIsServer = true; + mIterations = 1000; + mSize = 1024; + mOobIp = ""; + mOobPort = 8850; + mIpMask = "192.168.100.0/24"; + mProtocol = ock::hcom::RDMA; + mCpuId = -1; + mIsTestAllSize = false; +} + +bool PerfTestConfig::SetIsServer(const std::string &role) +{ + if (PerfTestUtils::IsStringCaseInsensitiveEqual(role, "server")) { + mIsServer = true; + } else { + mIsServer = false; + } + return true; +} + +bool PerfTestConfig::SelfCheck() +{ + if (GetIterations() >= MAX_ITERATIONS) { + LOG_WARN("Input Iteration(=" << GetIterations() << ") is larger than MAX_ITERATIONS(=" << MAX_ITERATIONS << ")" + << "iters is set to MAX_ITERATIONS."); + SetIterations(MAX_ITERATIONS); + } + + if (GetSize() >= MAX_MESSAGE_SIZE) { + LOG_WARN("Input size(=" << GetSize() << ") is larger than MAX_MESSAGE_SIZE(=" << MAX_MESSAGE_SIZE << ")" + << "size is set to MAX_MESSAGE_SIZE."); + SetIterations(MAX_ITERATIONS); + } + + // 如果需要测试所有尺寸,则按照最大尺寸准备缓冲区 + if (GetIsTestAllSize()) { + SetSize(MAX_MESSAGE_SIZE); + } + return true; +} + +bool PerfTestConfig::ParseArgs(int argc, char *argv[]) +{ + struct option options[] = { + {"case", required_argument, nullptr, 'C'}, + {"role", required_argument, nullptr, 'R'}, + {"protocol", required_argument, nullptr, 'P'}, + {"ip", required_argument, nullptr, 'i'}, + {"port", optional_argument, nullptr, 'p'}, + {"ipMask", optional_argument, nullptr, 'm'}, + {"all", no_argument, nullptr, 'a'}, + {"size", optional_argument, nullptr, 's'}, + {"iters", optional_argument, nullptr, 'n'}, + {"coreId", optional_argument, nullptr, 'c'}, + {"help", no_argument, nullptr, 'h'}, + {nullptr, 0, nullptr, 0}, + }; + + int ret = 0; + int index = 0; + char inputChar[] = "C:R:P:i:p:m:a:s:n:c:h"; + while ((ret = getopt_long(argc, argv, inputChar, options, &index)) != -1) { + switch (ret) { + case 'C': + if (!SetType(optarg)) { + return false; + } + break; + case 'R': + if (!SetIsServer(optarg)) { + return false; + } + break; + case 'P': + if (!SetProtocol(optarg)) { + return false; + } + break; + case 'i': + mOobIp = optarg; + mIpMask = mOobIp + "/16"; + break; + case 'm': + // 后续根据需要,扩展支持更多网段 + break; + case 'p': + mOobPort = (uint16_t)strtoul(optarg, nullptr, 0); + break; + case 'n': + SetIterations(static_cast(strtoul(optarg, nullptr, 0))); + break; + case 's': + SetSize(static_cast(strtoul(optarg, nullptr, 0))); + break; + case 'c': + mCpuId = strtoul(optarg, nullptr, 0); + break; + case 'h': + HelpInfo(argv[0]); + break; + case 'a': + SetIsTestAllSize(true); + break; + default: + break; + } + } + return SelfCheck(); +} +} +} \ No newline at end of file diff --git a/test/tools/perf_test/common/perf_test_config.h b/test/tools/perf_test/common/perf_test_config.h new file mode 100644 index 0000000000000000000000000000000000000000..46496808046a7beb72e33c30dd3382a0f592495f --- /dev/null +++ b/test/tools/perf_test/common/perf_test_config.h @@ -0,0 +1,140 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_PERF_TEST_CONFIG_H +#define HCOM_PERF_TEST_CONFIG_H + +#include +#include +#include + +#include "hcom/hcom.h" + +namespace hcom { +namespace perftest { +constexpr uint32_t MAX_MESSAGE_SIZE = 1048576; // 1MB = 2^20B +// option中的mrSendReceiveSegSize要配置为 MAX_MESSAGE_SIZE + HCOM_HEADER_SIZE +// sizeof(UBSHcomNetTransHeader) = 32 +constexpr uint32_t HCOM_HEADER_SIZE = 1024; // 需要内存对齐 +constexpr uint32_t MAX_ITERATIONS = 200000; + +enum class PERF_TEST_TYPE { + // 固定使用偶数枚举值代表时延,奇数枚举值代表带宽 + TRANSPORT_SEND_LAT = 0, + TRANSPORT_SEND_BW = 1, + TRANSPORT_READ_LAT = 2, + TRANSPORT_READ_BW = 3, + TRANSPORT_WRITE_LAT = 4, + TRANSPORT_WRITE_BW = 5, + + SERVICE_SEND_LAT = 100, + SERVICE_SEND_BW = 101, + SERVICE_READ_LAT = 102, + SERVICE_READ_BW = 103, + SERVICE_WRITE_LAT = 104, + SERVICE_WRITE_BW = 105, + + UNKNOWN = 0xFFFF +}; + +class PerfTestConfig { +public: + PerfTestConfig(); + bool ParseArgs(int argc, char *argv[]); + + std::string GetOobIp() const + { + return mOobIp; + } + + uint16_t GetOobPort() const + { + return mOobPort; + } + + bool SetProtocol(const std::string &cmd); + ock::hcom::UBSHcomNetDriverProtocol GetProtocol() const + { + return mProtocol; + } + void Print(); + + bool SetType(const std::string &cmd); + PERF_TEST_TYPE GetType() const; + + bool SetIsServer(const std::string &cmd); + bool GetIsServer() const + { + return mIsServer; + } + + void SetIterations(uint64_t iters) + { + mIterations = iters; + } + + uint64_t GetIterations() const + { + return mIterations; + } + + void SetSize(uint32_t size) + { + mSize = size; + } + + uint32_t GetSize() const + { + return mSize; + } + + void SetIsTestAllSize(bool flag) + { + mIsTestAllSize = flag; + } + + bool GetIsTestAllSize() const + { + return mIsTestAllSize; + } + + bool GetIsBwNoPeak() + { + // 不确定该参数对bw结果的影响,固定返回false + return false; + } + + void SetCpuId(int16_t cpuId) + { + mCpuId = cpuId; + } + + int32_t GetCpuId() const + { + return mCpuId; + } + + std::string GetIpMask() const + { + return mIpMask; + } + +private: + // 检查参数有效性 + bool SelfCheck(); + +private: + PERF_TEST_TYPE mType; // 测试类型 + ock::hcom::UBSHcomNetDriverProtocol mProtocol; // 底层通信协议类型 + bool mIsServer; // 是否为server + std::string mOobIp; // server OOB Ip + uint16_t mOobPort; // server OOB Port + std::string mIpMask; // OOB网段,client server相同 + int16_t mCpuId; // 亲和性配置,-1为不绑核 + uint64_t mIterations; // 测试次数 + uint32_t mSize; // 测试的最大包大小 + bool mIsTestAllSize = false; // 是否测试所有大小的包 +}; +} +} +#endif \ No newline at end of file diff --git a/test/tools/perf_test/common/perf_test_logger.cpp b/test/tools/perf_test/common/perf_test_logger.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72dbc884e1e41336e79f8757cda2cc427f9304b1 --- /dev/null +++ b/test/tools/perf_test/common/perf_test_logger.cpp @@ -0,0 +1,85 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include "common/perf_test_logger.h" + +namespace hcom { +namespace perftest { + +PerfTestLogger *PerfTestLogger::gLogger = nullptr; +std::mutex PerfTestLogger::gMutex; +int PerfTestLogger::logLevel = PERF_TEST_NO1; + +PerfTestLogger *PerfTestLogger::Instance() +{ + if (gLogger == nullptr) { + std::lock_guard lock(gMutex); + if (gLogger == nullptr) { + gLogger = new (std::nothrow) PerfTestLogger(); + if (gLogger == nullptr) { + std::cout << "Failed to new PerfTestLogger, probably out of memory" << std::endl; + } + SetLogLevel(); + } + } + + return gLogger; +} + +void PerfTestLogger::SetLogLevel() +{ + /* set one of 0,1,2,3 */ + char *envSize = ::getenv("HCOM_PERF_TEST_LOG_LEVEL"); + if (envSize != nullptr) { + long value = 0; + if (!SetStrStol(envSize, value)) { + std::cout << "Invalid setting 'HCOM_PERF_TEST_LOG_LEVEL', should set one of 0,1,2,3 " << std::endl; + return; + } + logLevel = value; + } +} + +void PerfTestLogger::SetLogLevel(int level) +{ + if (level >= static_cast(PERF_TEST_NO0) && level <= static_cast(PERF_TEST_NO3)) { + logLevel = level; + } +} + +bool PerfTestLogger::SetStrStol(const std::string &str, long &value) +{ + char *remain = nullptr; + errno = 0; + value = std::strtol(str.c_str(), &remain, 10); // 10 is decimal digits + if (remain == nullptr || strlen(remain) > 0 || value < PERF_TEST_NO0 || value > PERF_TEST_NO3 || errno == ERANGE) { + return false; + } else if (value == 0 && str != "0") { + return false; + } + + return true; +} + +void PerfTestLogger::Log(int level, const std::ostringstream &oss) const +{ + struct timeval tv {}; + char strTime[24]; + + int ret = gettimeofday(&tv, nullptr); + if (ret != 0) { + std::cout << "Fail to get the current system time, " << ret << "." << std::endl; + } + time_t timeStamp = tv.tv_sec; + struct tm localTime {}; + struct tm *resultTime = localtime_r(&timeStamp, &localTime); + if ((resultTime != nullptr) && + (strftime(strTime, sizeof strTime, "%Y-%m-%d %H:%M:%S.", resultTime) != PERF_TEST_NO0)) { + std::cout << strTime << tv.tv_usec << " " << level << " " << oss.str().c_str() << std::endl; + } else { + std::cout << "Invalid time trace " << tv.tv_usec << " " << level << " " << oss.str().c_str() << std::endl; + } +} + +} +} diff --git a/test/tools/perf_test/common/perf_test_logger.h b/test/tools/perf_test/common/perf_test_logger.h new file mode 100644 index 0000000000000000000000000000000000000000..a8eceaf89bbcfdec069b5cee5673c9a47ac15455 --- /dev/null +++ b/test/tools/perf_test/common/perf_test_logger.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef PERF_TEST_LOG_H +#define PERF_TEST_LOG_H + +#include +#include +#include +#include +#include +#include + +namespace hcom { +namespace perftest { + +#define PERF_TEST_NO0 0 +#define PERF_TEST_NO1 1 +#define PERF_TEST_NO2 2 +#define PERF_TEST_NO3 3 + +class PerfTestLogger { +public: + static PerfTestLogger *Instance(); + + static void SetLogLevel(); + + static void SetLogLevel(int level); + + static bool SetStrStol(const std::string &str, long &value); + + void Log(int level, const std::ostringstream &oss) const; + + PerfTestLogger(const PerfTestLogger &) = delete; + PerfTestLogger &operator = (const PerfTestLogger &) = delete; + PerfTestLogger(PerfTestLogger &&) = delete; + PerfTestLogger &operator = (PerfTestLogger &&) = delete; + + ~PerfTestLogger() {} + + inline int GetLogLevel() + { + return logLevel; + } + +private: + PerfTestLogger() = default; + +private: + static PerfTestLogger *gLogger; + static std::mutex gMutex; + static int logLevel; +}; + +// macro for log +#ifndef PERF_TEST_LOG_FILENAME +#define PERF_TEST_LOG_FILENAME (strrchr(__FILE__, '/') ? strrchr(__FILE__, '/') + 1 : __FILE__) +#endif + +#define PERTTEST_LOG(level, args) \ + do { \ + if ((level) >= (PerfTestLogger::Instance()->GetLogLevel())) { \ + std::ostringstream oss; \ + oss << "[perf_test " << PERF_TEST_LOG_FILENAME << ":" << __LINE__ << "] " << args; \ + PerfTestLogger::Instance()->Log(level, oss); \ + } \ + } while (0) + +#define LOG_DEBUG(args) PERTTEST_LOG(0, args) +#define LOG_INFO(args) PERTTEST_LOG(1, args) +#define LOG_WARN(args) PERTTEST_LOG(2, args) +#define LOG_ERROR(args) PERTTEST_LOG(3, args) + +} +} + +#endif diff --git a/test/tools/perf_test/common/perf_test_utils.cpp b/test/tools/perf_test/common/perf_test_utils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..afdb0f7219fe79e472c0268999bf9a36f6838319 --- /dev/null +++ b/test/tools/perf_test/common/perf_test_utils.cpp @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include "common/perf_test_utils.h" + +namespace hcom { +namespace perftest { + + +bool PerfTestUtils::IsStringCaseInsensitiveEqual(const std::string& left, const std::string& right) +{ + if (left.size() != right.size()) { + return false; + } + + for (std::size_t i = 0; i < left.size(); ++i) { + if (std::tolower(left[i]) != std::tolower(right[i])) { + return false; + } + } + return true; +} + +} +} \ No newline at end of file diff --git a/test/tools/perf_test/common/perf_test_utils.h b/test/tools/perf_test/common/perf_test_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..c3008e0d9278359f322c087f6f8aba1668f788bf --- /dev/null +++ b/test/tools/perf_test/common/perf_test_utils.h @@ -0,0 +1,19 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef PERF_TEST_UTILS_H +#define PERF_TEST_UTILS_H + +#include + +namespace hcom { +namespace perftest { + +class PerfTestUtils { +public: + static bool IsStringCaseInsensitiveEqual(const std::string& left, const std::string& right); +}; + +} +} +#endif \ No newline at end of file diff --git a/test/tools/perf_test/docs/.gitkeep b/test/tools/perf_test/docs/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/tools/perf_test/hcom_perf_test.cpp b/test/tools/perf_test/hcom_perf_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7fb85083f34e8b22508ded16dd9e65bb38afaaaa --- /dev/null +++ b/test/tools/perf_test/hcom_perf_test.cpp @@ -0,0 +1,165 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ + +#include +#include + +#include "hcom/hcom.h" + +#include "common/perf_test_config.h" +#include "common/perf_test_common.h" +#include "common/perf_test_logger.h" +#include "test_case/perf_test_base.h" +#include "test_case/perf_test_factory.h" +#include "report/perf_test_report_base.h" +#include "report/perf_test_report_factory.h" + +namespace hcom { +namespace perftest { + +static bool IsUbcProtocol(ock::hcom::UBSHcomNetDriverProtocol &protocol) +{ + if (protocol == ock::hcom::UBSHcomNetDriverProtocol::UBC) { + return true; + } + return false; +} + +static bool IsSendType(PERF_TEST_TYPE &type) +{ + if ((type == PERF_TEST_TYPE::TRANSPORT_SEND_BW || type == PERF_TEST_TYPE::SERVICE_SEND_BW || + type == PERF_TEST_TYPE::TRANSPORT_SEND_LAT || type == PERF_TEST_TYPE::SERVICE_SEND_LAT)) { + return true; + } + return false; +} + +static bool RunAllSizeTest(PerfTestBase *pTest, const PerfTestConfig &cfg, PerfTestReportBase *report) +{ + uint64_t size = MESSAGE_SIZE_BASE; + ock::hcom::UBSHcomNetDriverProtocol protocol = cfg.GetProtocol(); + PERF_TEST_TYPE type = cfg.GetType(); + while (size <= cfg.GetSize()) { + PerfTestContext ctx; + ctx.mIterations = cfg.GetIterations(); + ctx.mSize = size; + if (!pTest->RunTest(&ctx)) { + LOG_ERROR("run test failed"); + return false; + } + report->PrintReportElement(&ctx); + sleep(1); + size *= 2; + if ((size >= UB_MAX_SIZE) && IsUbcProtocol(protocol) && IsSendType(type)) { + return true; + } + } + return true; +} + +static void ServerRunAllSizeTest(PerfTestBase *pTest, const PerfTestConfig &cfg, PerfTestReportBase *report) +{ + uint64_t size = MESSAGE_SIZE_BASE; + ock::hcom::UBSHcomNetDriverProtocol protocol = cfg.GetProtocol(); + PERF_TEST_TYPE type = cfg.GetType(); + while (size <= cfg.GetSize()) { + PerfTestContext ctx; + ctx.mIterations = cfg.GetIterations(); + ctx.mSize = size; + if (!pTest->RunTest(&ctx)) { + LOG_ERROR("run test failed"); + break; + } + size *= 2; + if ((size >= UB_MAX_SIZE) && IsUbcProtocol(protocol) && IsSendType(type)) { + break; + } + } + pTest->UnInitialize(); + return; +} + +void RunTest(const PerfTestConfig &cfg, PerfTestReportBase *report) +{ + PerfTestBase *pTest = PerfTestFactory::GetInstance().CreatePerfTest(cfg.GetType(), cfg); + if (pTest == nullptr) { + LOG_ERROR("create perf test failed!"); + return; + } + + if (!pTest->Initialize()) { + LOG_ERROR("instance create and start failed!"); + return; + } + PERF_TEST_TYPE type = cfg.GetType(); + // server死循环等待,input 'q'停止server进程 + // 如期望server也输出结果,则每个iteration开始和结束,需增加交互(client通知server) + if (cfg.GetIsServer()) { + if (type == PERF_TEST_TYPE::TRANSPORT_WRITE_LAT || type == PERF_TEST_TYPE::SERVICE_WRITE_LAT) { + if (cfg.GetIsTestAllSize()) { + ServerRunAllSizeTest(pTest, cfg, report); + } else { + PerfTestContext ctx; + ctx.mIterations = cfg.GetIterations(); + ctx.mSize = cfg.GetSize(); + if (!pTest->RunTest(&ctx)) { + LOG_ERROR("run test failed"); + } + } + pTest->UnInitialize(); + return; + } + while (true) { + auto tmpChar = getchar(); + switch (tmpChar) { + case 'q': + pTest->UnInitialize(); + return; + default: + continue; + } + } + } + + report->PrintReportHead(); + if (cfg.GetIsTestAllSize()) { + RunAllSizeTest(pTest, cfg, report); + } else { + PerfTestContext ctx; + ctx.mIterations = cfg.GetIterations(); + ctx.mSize = cfg.GetSize(); + if (!pTest->RunTest(&ctx)) { + LOG_ERROR("run test failed"); + return; + } + report->PrintReportElement(&ctx); + } + + report->PrintReportTail(); + pTest->UnInitialize(); +} + +int main(int argc, char *argv[]) +{ + // Parse parameters and check for conflicts + PerfTestConfig cfg; + if (!cfg.ParseArgs(argc, argv)) { + LOG_ERROR("parse cfg failed"); + return -1; + } + + cfg.Print(); + + PerfTestReportBase *report = PerfTestReportFactory::GetInstance().CreatePerfTestReport(cfg); + if (report == nullptr) { + LOG_ERROR("create perf test report failed!"); + return -1; + } + + RunTest(cfg, report); + + return 0; +} +} +} diff --git a/test/tools/perf_test/report/perf_test_report_base.cpp b/test/tools/perf_test/report/perf_test_report_base.cpp new file mode 100644 index 0000000000000000000000000000000000000000..995356141fb55bf6ba8c866f6b50a64112227dcf --- /dev/null +++ b/test/tools/perf_test/report/perf_test_report_base.cpp @@ -0,0 +1,18 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include + +#include "perf_test_report_base.h" + +namespace hcom { +namespace perftest { + + +void PerfTestReportBase::PrintReportTail() +{ + std::cout << PERF_TEST_RESULT_LINE << std::endl; +} + +} +} \ No newline at end of file diff --git a/test/tools/perf_test/report/perf_test_report_base.h b/test/tools/perf_test/report/perf_test_report_base.h new file mode 100644 index 0000000000000000000000000000000000000000..a25694514398942b036bf1c8d79f2104c997336b --- /dev/null +++ b/test/tools/perf_test/report/perf_test_report_base.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef PERF_TEST_REPORT_H +#define PERF_TEST_REPORT_H + +#include + +#include "common/perf_test_config.h" +#include "common/perf_test_common.h" + + +namespace hcom { +namespace perftest { + +enum class PERF_TEST_REPORT_TYPE { + LATENCY = 0, + BAND_WIDTH = 1 +}; + +#define PERF_TEST_RESULT_LINE "---------------------------------------------------------------------------------------" + +class PerfTestReportBase { +public: + explicit PerfTestReportBase(const PerfTestConfig& cfg) : mCfg(cfg) {}; + // 打印结果的头部 + virtual void PrintReportHead() = 0; + // 打印单条结果项,每个包尺寸调用该接口打印一条 + virtual void PrintReportElement(PerfTestContext *ctx) = 0; + // 打印结果的尾部 + void PrintReportTail(); + +protected: + PerfTestConfig mCfg; +}; + +} +} + +#endif diff --git a/test/tools/perf_test/report/perf_test_report_bw.cpp b/test/tools/perf_test/report/perf_test_report_bw.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f86e8f9e5dc311377a3169ca0c74896a3ae51bfe --- /dev/null +++ b/test/tools/perf_test/report/perf_test_report_bw.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include +#include +#include "report/perf_test_report_factory.h" +#include "common/perf_test_logger.h" +#include "report/perf_test_report_bw.h" + +namespace hcom { +namespace perftest { +constexpr uint32_t MIN_ITERATIONS = 1; // min Iterations num +constexpr double NS_TO_S = 1000000000; +constexpr double TO_M = 1000000; +constexpr double BYTE_TO_MB = 1048576; +constexpr char FILL_CHAR = ' '; +constexpr uint32_t DOUBLE_WIDTH_6 = 6; +constexpr uint32_t OUTPUT_WIDTH_8 = 8; +constexpr uint32_t OUTPUT_WIDTH_13 = 13; +constexpr uint32_t OUTPUT_WIDTH_18 = 18; +constexpr uint32_t OUTPUT_WIDTH_20 = 20; +constexpr uint32_t OUTPUT_PRECISION = 2; + +// 与ib_send_bw打印保持一致 +#define RESULT_BW_HEADER " #bytes #iterations BW peak[MB/sec] BW average[MB/sec] MsgRate[Mpps]" + +void PerfTestReportBw::PrintReportElement(PerfTestContext *ctx) +{ + if (ctx == nullptr) { + LOG_ERROR("ctx is nullptr!"); + return; + } + + uint64_t iters = ctx->mIterations; + if (iters < MIN_ITERATIONS + 1) { + LOG_ERROR("iteration is less than 1!"); + return; + } + + double delta = static_cast(ctx->tposted[iters] - ctx->tposted[0]) / NS_TO_S; + double run_inf_bi_factor = 1; + uint32_t tSize = ctx->mSize * run_inf_bi_factor; + double bw_avg = (double)(tSize * iters) / delta / BYTE_TO_MB; // MB/s + double msgRate = (double)(run_inf_bi_factor * iters) / delta / TO_M; // Mpps + std::stringstream sstream; + // 统一设置左对齐,统一设置填充字符(配合位宽使用,通过修改填充字符方便调整打印格式) + sstream << std::left << std::setfill(FILL_CHAR); + sstream << " " << std::setw(OUTPUT_WIDTH_8) << ctx->mSize; + sstream << " " << std::setw(OUTPUT_WIDTH_13) << ctx->mIterations; + // 固定保留小数点后两位 + sstream << std::fixed << std::setprecision(OUTPUT_PRECISION); + sstream << " " << std::setw(OUTPUT_WIDTH_20) << "NA"; + sstream << " " << std::setw(OUTPUT_WIDTH_20) << bw_avg; + // 固定保留小数点后六位 + sstream << " " << std::setprecision(DOUBLE_WIDTH_6) << std::setw(OUTPUT_WIDTH_18) << msgRate; + std::cout << sstream.str() << std::endl; +} + +void PerfTestReportBw::PrintReportHead() +{ + std::stringstream sstream; + sstream << PERF_TEST_RESULT_LINE << std::endl; + sstream << " BandWidth Test" << std::endl; + sstream << " Cpu id : " << mCfg.GetCpuId() << std::endl; + if (mCfg.GetIsTestAllSize() == false) { + sstream << " Datasize : " << mCfg.GetSize(); + sstream << ", Iterations : " << mCfg.GetIterations() << std::endl; + } + sstream << PERF_TEST_RESULT_LINE << std::endl; + sstream << RESULT_BW_HEADER << std::endl; + std::cout << sstream.str(); +} + +REGIST_PERF_TEST_REPORT_CREATOR(PERF_TEST_REPORT_TYPE::BAND_WIDTH, PerfTestReportBw); +} +} \ No newline at end of file diff --git a/test/tools/perf_test/report/perf_test_report_bw.h b/test/tools/perf_test/report/perf_test_report_bw.h new file mode 100644 index 0000000000000000000000000000000000000000..c11e4fc4902c41c1643cef7f6c346aef88681d99 --- /dev/null +++ b/test/tools/perf_test/report/perf_test_report_bw.h @@ -0,0 +1,19 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef PERF_TEST_REPORT_BANDWIDTH_H +#define PERF_TEST_REPORT_BANDWIDTH_H + +#include "report/perf_test_report_base.h" + +namespace hcom { +namespace perftest { +class PerfTestReportBw : public PerfTestReportBase { +public: + PerfTestReportBw(const PerfTestConfig &cfg) : PerfTestReportBase(cfg){}; + void PrintReportElement(PerfTestContext *ctx) override; + void PrintReportHead() override; +}; +} +} +#endif \ No newline at end of file diff --git a/test/tools/perf_test/report/perf_test_report_factory.cpp b/test/tools/perf_test/report/perf_test_report_factory.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bba8cfe39c410cf7ed1918700e28c6f7a0c16a89 --- /dev/null +++ b/test/tools/perf_test/report/perf_test_report_factory.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include "report/perf_test_report_factory.h" + +namespace hcom { +namespace perftest { + +PerfTestReportBase* PerfTestReportFactory::CreatePerfTestReport(const PerfTestConfig& cfg) +{ + PERF_TEST_REPORT_TYPE reportType = PERF_TEST_REPORT_TYPE::LATENCY; + if (static_cast(cfg.GetType()) % 2 != 0) { + reportType = PERF_TEST_REPORT_TYPE::BAND_WIDTH; + } + + for (auto it : m_createFuncs) { + if (it.first == static_cast(reportType)) { + return it.second(cfg); + } + } + + LOG_ERROR("Can't find create function for perf test report(type=" << static_cast(reportType) << ")!"); + return nullptr; +} + +} +} diff --git a/test/tools/perf_test/report/perf_test_report_factory.h b/test/tools/perf_test/report/perf_test_report_factory.h new file mode 100644 index 0000000000000000000000000000000000000000..173f71cc2b7303f0a9a10dfe77f36f7ce3d61a7e --- /dev/null +++ b/test/tools/perf_test/report/perf_test_report_factory.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_PERF_TEST_REPORT_FACTORY_H +#define HCOM_PERF_TEST_REPORT_FACTORY_H + +#include + +#include "common/perf_test_logger.h" +#include "report/perf_test_report_base.h" + +namespace hcom { +namespace perftest { + +using ReportCreateFunc = PerfTestReportBase* (*)(const PerfTestConfig& cfg); + +class PerfTestReportFactory { +public: + ~PerfTestReportFactory() = default; + static PerfTestReportFactory &GetInstance() + { + static PerfTestReportFactory instance; + return instance; + } + + PerfTestReportBase* CreatePerfTestReport(const PerfTestConfig& cfg); + + void RegistCreateFunc(PERF_TEST_REPORT_TYPE type, ReportCreateFunc func) + { + LOG_DEBUG("RegistCreateFunc for PERF_TEST_REPORT_TYPE(" << static_cast(type) << ")"); + m_createFuncs.emplace(static_cast(type), func); + } + +private: + PerfTestReportFactory() = default; + std::map m_createFuncs; +}; + + +class PerfTestReportRegister { +public: + PerfTestReportRegister(PERF_TEST_REPORT_TYPE type, ReportCreateFunc func) + { + PerfTestReportFactory::GetInstance().RegistCreateFunc(type, func); + } +}; + +#define REGIST_PERF_TEST_REPORT_CREATOR(ReportType, ReportClass) \ + static PerfTestReportBase* Create##ReportClass(const PerfTestConfig& cfg) \ + { \ + return new (std::nothrow) ReportClass(cfg); \ + } \ + static PerfTestReportRegister g_register_##ReportClass(ReportType, Create##ReportClass) + +} +} +#endif diff --git a/test/tools/perf_test/report/perf_test_report_lat.cpp b/test/tools/perf_test/report/perf_test_report_lat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0816f50f0fbd95c2c330340211d634a225601d12 --- /dev/null +++ b/test/tools/perf_test/report/perf_test_report_lat.cpp @@ -0,0 +1,141 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include +#include +#include + +#include "report/perf_test_report_factory.h" +#include "common/perf_test_logger.h" +#include "report/perf_test_report_lat.h" + +namespace hcom { +namespace perftest { +constexpr uint32_t HALF = 2; +constexpr uint32_t NS_TO_US = 1000; +constexpr uint32_t LAT_MEASURE_TAIL = 2; // Remove the two max value +constexpr double PERF_TEST_ITERS_99 = 0.99; +constexpr double PERF_TEST_ITERS_99_9 = 0.999; +constexpr char FILL_CHAR = ' '; + +constexpr uint32_t OUTPUT_WIDTH_2 = 2; +constexpr uint32_t OUTPUT_WIDTH_7 = 7; +constexpr uint32_t OUTPUT_WIDTH_12 = 12; +constexpr uint32_t OUTPUT_WIDTH_13 = 13; +constexpr uint32_t OUTPUT_WIDTH_14 = 14; +constexpr uint32_t OUTPUT_WIDTH_15 = 15; +constexpr uint32_t OUTPUT_WIDTH_18 = 18; +constexpr uint32_t OUTPUT_WIDTH_22 = 22; +constexpr uint32_t OUTPUT_PRECISION = 2; + +// 与ib_send_lat打印保持一致 +#define RESULT_LAT_HEADER \ + " #bytes #iterations t_min[usec] t_max[usec] t_typical[usec] t_avg[usec] t_stdev[usec]" \ + " 99\% percentile[usec] 99.9\% percentile[usec]" + +static inline double GetMedian(uint64_t num, double *deltaArr) +{ + if ((num - 1) % HALF != 0) { + return (deltaArr[num / HALF] + deltaArr[num / HALF - 1]) / HALF; + } else { + return deltaArr[num / HALF]; + } +} + +bool PerfTestReportLat::isDuplex(const PERF_TEST_TYPE &type) +{ + if (type == PERF_TEST_TYPE::TRANSPORT_SEND_LAT || type == PERF_TEST_TYPE::TRANSPORT_WRITE_LAT || + type == PERF_TEST_TYPE::SERVICE_WRITE_LAT || type == PERF_TEST_TYPE::SERVICE_SEND_LAT) { + return true; + } + return false; +} + +void PerfTestReportLat::PrintReportElement(PerfTestContext *ctx) +{ + if (ctx == nullptr) { + LOG_ERROR("ctx is nullptr!"); + return; + } + + uint64_t iters = ctx->mIterations; + if (iters < LAT_MEASURE_TAIL + 1) { + LOG_ERROR("iteration is less than 3!"); + return; + } + + double *delta = new double[iters]; + if (delta == nullptr) { + LOG_ERROR("Failed to allocate memory for delta!"); + return; + } + + double biSend = 1; + PERF_TEST_TYPE type = mCfg.GetType(); + if (isDuplex(type)) { + biSend = 2; + } + + for (uint64_t i = 0; i < iters; i++) { + // 纳秒(ns)转微秒(us), 单向时延需要RTT/2 + delta[i] = static_cast(ctx->tposted[i + 1] - ctx->tposted[i]) / NS_TO_US / biSend; + } + + std::sort(delta, delta + iters); + iters = iters - LAT_MEASURE_TAIL; // Remove the two largest values + + double median = GetMedian(iters, delta); + double average = 0.0; + for (uint64_t i = 0; i < iters; i++) { + average += delta[i]; + } + average /= iters; + + /* variance lat */ + double stdev_sum = 0; + for (uint64_t i = 0; i < iters; i++) { + stdev_sum += pow(average - delta[i], 2); + } + double stdev = sqrt(stdev_sum / iters); + + uint64_t iters_99 = static_cast(ceil(iters * PERF_TEST_ITERS_99)); + uint64_t iters_99_9 = static_cast(ceil(iters * PERF_TEST_ITERS_99_9)); + + std::stringstream sstream; + // 统一设置左对齐,统一设置填充字符(配合位宽使用,通过修改填充字符方便调整打印格式) + sstream << std::left << std::setfill(FILL_CHAR); + sstream << " " << std::setw(OUTPUT_WIDTH_7) << ctx->mSize; + sstream << " " << std::setw(OUTPUT_WIDTH_13) << ctx->mIterations; + // 固定保留小数点后两位 + sstream << std::fixed << std::setprecision(OUTPUT_PRECISION); + sstream << " " << std::setw(OUTPUT_WIDTH_14) << delta[0]; + sstream << " " << std::setw(OUTPUT_WIDTH_12) << delta[iters]; + sstream << " " << std::setw(OUTPUT_WIDTH_18) << median; + sstream << " " << std::setw(OUTPUT_WIDTH_14) << average; + sstream << " " << std::setw(OUTPUT_WIDTH_15) << stdev; + sstream << " " << std::setw(OUTPUT_WIDTH_22) << delta[iters_99]; + sstream << " " << std::setw(OUTPUT_WIDTH_22) << delta[iters_99_9]; + std::cout << sstream.str() << std::endl; + + delete[] delta; + delta = nullptr; +} + +void PerfTestReportLat::PrintReportHead() +{ + std::stringstream sstream; + sstream << PERF_TEST_RESULT_LINE << std::endl; + sstream << " Latency Test" << std::endl; + sstream << " Cpu id : " << mCfg.GetCpuId() << std::endl; + if (mCfg.GetIsTestAllSize() == false) { + sstream << " Datasize : " << mCfg.GetSize(); + sstream << ", Iterations : " << mCfg.GetIterations() << std::endl; + } + sstream << PERF_TEST_RESULT_LINE << std::endl; + sstream << RESULT_LAT_HEADER << std::endl; + std::cout << sstream.str(); +} + +REGIST_PERF_TEST_REPORT_CREATOR(PERF_TEST_REPORT_TYPE::LATENCY, PerfTestReportLat); +} +} \ No newline at end of file diff --git a/test/tools/perf_test/report/perf_test_report_lat.h b/test/tools/perf_test/report/perf_test_report_lat.h new file mode 100644 index 0000000000000000000000000000000000000000..953c5cc97524ca36a0ec0447ce2e9ca8f20dce1d --- /dev/null +++ b/test/tools/perf_test/report/perf_test_report_lat.h @@ -0,0 +1,21 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef PERF_TEST_REPORT_LATENCY_H +#define PERF_TEST_REPORT_LATENCY_H + +#include "report/perf_test_report_base.h" + +namespace hcom { +namespace perftest { +class PerfTestReportLat : public PerfTestReportBase { +public: + PerfTestReportLat(const PerfTestConfig &cfg) : PerfTestReportBase(cfg){}; + void PrintReportElement(PerfTestContext *ctx) override; + void PrintReportHead() override; + bool isDuplex(const PERF_TEST_TYPE &type); +}; +} +} + +#endif diff --git a/test/tools/perf_test/test_case/perf_test_base.h b/test/tools/perf_test/test_case/perf_test_base.h new file mode 100644 index 0000000000000000000000000000000000000000..064e4bf9147790ca509e6aa2260e9e5d1e4e430b --- /dev/null +++ b/test/tools/perf_test/test_case/perf_test_base.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_PERF_TEST_CASE_H +#define HCOM_PERF_TEST_CASE_H + +#include "common/perf_test_common.h" +#include "common/perf_test_config.h" + +namespace hcom { +namespace perftest { + +class PerfTestBase { +public: + explicit PerfTestBase(const PerfTestConfig& cfg) : mCfg(cfg) {}; + virtual ~PerfTestBase() {}; + virtual bool Initialize() = 0; + virtual bool RunTest(PerfTestContext* ctx) = 0; + virtual void UnInitialize() = 0; + +protected: + PerfTestConfig mCfg; +}; + +} +} +#endif \ No newline at end of file diff --git a/test/tools/perf_test/test_case/perf_test_factory.h b/test/tools/perf_test/test_case/perf_test_factory.h new file mode 100644 index 0000000000000000000000000000000000000000..dccd53d58a3e4a0f21e51646b01b5bb8e48bd7c6 --- /dev/null +++ b/test/tools/perf_test/test_case/perf_test_factory.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_PERF_TEST_FACTORY_H +#define HCOM_PERF_TEST_FACTORY_H + +#include + +#include "common/perf_test_logger.h" +#include "test_case/perf_test_base.h" + +namespace hcom { +namespace perftest { + +using CreateFunc = PerfTestBase* (*)(const PerfTestConfig& cfg); + +class PerfTestFactory { +public: + ~PerfTestFactory() = default; + static PerfTestFactory &GetInstance() + { + static PerfTestFactory instance; + return instance; + } + + PerfTestBase* CreatePerfTest(PERF_TEST_TYPE type, const PerfTestConfig& cfg) + { + for (auto it : m_createFuncs) { + if (it.first == static_cast(type)) { + return it.second(cfg); + } + } + LOG_ERROR("Can't find create function for perf test(type=" << static_cast(type) << ")!"); + return nullptr; + } + + void RegistCreateFunc(PERF_TEST_TYPE type, CreateFunc func) + { + LOG_DEBUG("RegistCreateFunc for perf test(type=" << static_cast(type) << ")"); + m_createFuncs.emplace(static_cast(type), func); + } + +private: + PerfTestFactory() = default; + std::map m_createFuncs; +}; + + +class PerfTestRegister { +public: + PerfTestRegister(PERF_TEST_TYPE type, CreateFunc func) + { + PerfTestFactory::GetInstance().RegistCreateFunc(type, func); + } +}; + +#define REGIST_PERF_TEST_CREATOR(TestType, TestClass) \ + static PerfTestBase *Create##TestClass(const PerfTestConfig& cfg) \ + { \ + return new (std::nothrow) TestClass(cfg); \ + } \ + static PerfTestRegister g_register_##TestClass(TestType, Create##TestClass) + +} +} +#endif diff --git a/test/tools/perf_test/test_case/service_v2/service_helper.cpp b/test/tools/perf_test/test_case/service_v2/service_helper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..607bde9c528735749a0d16c9835d765761e079aa --- /dev/null +++ b/test/tools/perf_test/test_case/service_v2/service_helper.cpp @@ -0,0 +1,167 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include "test_case/service_v2/service_helper.h" +#include "common/perf_test_logger.h" + +namespace hcom { +namespace perftest { +using namespace ock::hcom; + +static int NewChannel(const std::string &ipPort, const ock::hcom::UBSHcomChannelPtr &ch, const std::string &payload) +{ + return 0; +} + +static void ChannelBroken(const ock::hcom::UBSHcomChannelPtr &ch) +{ + return; +} + +static int RequestReceived(ock::hcom::UBSHcomServiceContext &ctx) +{ + return 0; +} + +static int RequestPosted(const ock::hcom::UBSHcomServiceContext &ctx) +{ + return 0; +} + +static int OneSideDone(const ock::hcom::UBSHcomServiceContext &ctx) +{ + return 0; +} + +ServiceHelper::ServiceHelper(const PerfTestConfig &cfg) +{ + mCfg = cfg; + mService = nullptr; + // 回调函数提供默认空实现,简化测试用例 + mNewChHandler = NewChannel; + mChBrokenHandler = ChannelBroken; + mRecvHandler = RequestReceived; + mSendHandler = RequestPosted; + mOneSideDoneHandler = OneSideDone; +} + +void ServiceHelper::RegisterNewChHandler(const UBSHcomServiceNewChannelHandler &handler) +{ + mNewChHandler = handler; +} + +void ServiceHelper::RegisterChBrokenHandler(const UBSHcomServiceChannelBrokenHandler &handler) +{ + mChBrokenHandler = handler; +} + +void ServiceHelper::RegisterRecvHandler(const UBSHcomServiceRecvHandler &handler) +{ + mRecvHandler = handler; +} + +void ServiceHelper::RegisterSendHandler(const UBSHcomServiceSendHandler &handler) +{ + mSendHandler = handler; +} + +void ServiceHelper::RegisterOneSideDoneHandler(const UBSHcomServiceOneSideDoneHandler &handler) +{ + mOneSideDoneHandler = handler; +} + +bool ServiceHelper::CreateService() +{ + if (mService != nullptr) { + LOG_WARN("UBSHcomNetDriver already created"); + return true; + } + std::string name; + if (mCfg.GetIsServer()) { + name = "ServicePerfTest_server"; + } else { + name = "ServicePerfTest_client"; + } + + UBSHcomServiceOptions options; + options.maxSendRecvDataSize = MAX_MESSAGE_SIZE + HCOM_HEADER_SIZE; + options.workerGroupMode = ock::hcom::NET_BUSY_POLLING; + if (mCfg.GetCpuId() != -1) { + options.workerGroupCpuIdsRange = { mCfg.GetCpuId(), mCfg.GetCpuId() }; + } + + mService = UBSHcomService::Create(mCfg.GetProtocol(), name, options); + if (mService == nullptr) { + LOG_ERROR("Failed to create service"); + return false; + } + + mService->SetDeviceIpMask({ mCfg.GetIpMask() }); + + mService->RegisterRecvHandler(mRecvHandler); + mService->RegisterChannelBrokenHandler(mChBrokenHandler, ock::hcom::UBSHcomChannelBrokenPolicy::BROKEN_ALL); + mService->RegisterSendHandler(mSendHandler); + mService->RegisterOneSideHandler(mOneSideDoneHandler); + + if (mCfg.GetIsServer()) { + mService->Bind("tcp://" + mCfg.GetOobIp() + ":" + std::to_string(mCfg.GetOobPort()), mNewChHandler); + } + + UBSHcomTlsOptions tlsOptions; + tlsOptions.enableTls = false; + mService->SetTlsOptions(tlsOptions); + + int result = 0; + + if ((result = mService->Start()) != 0) { + LOG_ERROR("Failed to start NetService " << result); + return false; + } + + LOG_DEBUG("NetService started"); + return true; +} + +void ServiceHelper::DestroyService() +{ + std::string name; + if (mCfg.GetIsServer()) { + name = "ServicePerfTest_server"; + } else { + name = "ServicePerfTest_client"; + } + + if (mService != nullptr) { + if (!mMrVector.empty()) { + for (auto mr : mMrVector) { + mService->DestroyMemoryRegion(mr); + } + mMrVector.clear(); + } + UBSHcomService::Destroy(name); + mService = nullptr; + } +} + +bool ServiceHelper::CreateMemoryRegion(RegMrInfo &mrInfo) +{ + if (mService == nullptr) { + return false; + } + + UBSHcomRegMemoryRegion mr; + // 按照最大包大小申请内存,以支持同时测试多个不同大小的包 + auto result = mService->RegisterMemoryRegion(MAX_MESSAGE_SIZE, mr); + if (result != 0) { + LOG_ERROR("Create memory region failed"); + return false; + } + + mrInfo.lAddress = mr.GetAddress(); + mr.GetMemoryKey(mrInfo.lKey); + mrInfo.size = MAX_MESSAGE_SIZE; + mMrVector.emplace_back(mr); + return true; +} +} +} diff --git a/test/tools/perf_test/test_case/service_v2/service_helper.h b/test/tools/perf_test/test_case/service_v2/service_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..20456c20b42fc60a0946026ce267c9da8372f491 --- /dev/null +++ b/test/tools/perf_test/test_case/service_v2/service_helper.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_SERVICE_HELPER_H +#define HCOM_SERVICE_HELPER_H + +#include "hcom/hcom_service.h" +#include "hcom/hcom_service_context.h" +#include "hcom/hcom_service_channel.h" +#include "hcom/hcom.h" +#include "common/perf_test_common.h" +#include "common/perf_test_config.h" + +namespace hcom { +namespace perftest { +using UBSHcomServiceNewChannelHandler = + std::function; +using UBSHcomServiceChannelBrokenHandler = std::function; +using UBSHcomServiceRecvHandler = std::function; +using UBSHcomServiceSendHandler = std::function; +using UBSHcomServiceOneSideDoneHandler = std::function; + +class RegMrInfo { +public: + uintptr_t lAddress = 0; + ock::hcom::UBSHcomMemoryKey lKey; + uint32_t size = 0; +}; + +class ServiceHelper { +public: + ServiceHelper(const PerfTestConfig &cfg); + bool CreateMemoryRegion(RegMrInfo &mrInfo); + bool CreateService(); + void DestroyService(); + inline ock::hcom::UBSHcomService *GetNetService() const + { + return mService; + } + void RegisterNewChHandler(const UBSHcomServiceNewChannelHandler &handler); + void RegisterChBrokenHandler(const UBSHcomServiceChannelBrokenHandler &handler); + void RegisterRecvHandler(const UBSHcomServiceRecvHandler &handler); + void RegisterSendHandler(const UBSHcomServiceSendHandler &handler); + void RegisterOneSideDoneHandler(const UBSHcomServiceOneSideDoneHandler &handler); + +private: + PerfTestConfig mCfg; + ock::hcom::UBSHcomService *mService = nullptr; + std::vector mMrVector; + + UBSHcomServiceNewChannelHandler mNewChHandler; + UBSHcomServiceChannelBrokenHandler mChBrokenHandler; + UBSHcomServiceRecvHandler mRecvHandler; + UBSHcomServiceSendHandler mSendHandler; + UBSHcomServiceOneSideDoneHandler mOneSideDoneHandler; +}; +} +} +#endif // HCOM_SERVICE_HELPER_H diff --git a/test/tools/perf_test/test_case/service_v2/service_read_bw_test.cpp b/test/tools/perf_test/test_case/service_v2/service_read_bw_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2cb23af6bb78d233096255e612c8cf25e2375add --- /dev/null +++ b/test/tools/perf_test/test_case/service_v2/service_read_bw_test.cpp @@ -0,0 +1,193 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include +#include "common/perf_test_logger.h" +#include "test_case/perf_test_factory.h" +#include "test_case/service_v2/service_read_bw_test.h" + +namespace hcom { +namespace perftest { +using namespace ock::hcom; +constexpr uint16_t OP_SERVICE_READ_BW = 204; + +int ServiceReadBwTest::DoPostRead() +{ + int res = 0; + mCtx->tposted[0] = ock::hcom::MONOTONIC_TIME_NS(); + for (uint64_t i = 0; i < mCtx->mIterations; ++i) { + ock::hcom::Callback *newCallback = ock::hcom::UBSHcomNewCallback( + [this](ock::hcom::UBSHcomServiceContext &context) { + PerfTestContext *testCtx = this->GetPerfTestContext(); + testCtx->totrcnt++; + if (testCtx->totrcnt == testCtx->mIterations) { + testCtx->tposted[testCtx->mIterations] = MONOTONIC_TIME_NS(); + sem_post(&this->mSem); + } + }, + std::placeholders::_1); + if (newCallback == nullptr) { + LOG_ERROR("Create callback failed"); + return -1; + } + res = mCh->Get(mReq, newCallback); + if (res != 0) { + if (newCallback != nullptr) { + delete newCallback; + } + LOG_ERROR("failed to send to server"); + return res; + } + } + return res; +} + +int ServiceReadBwTest::NewChannel(const std::string &ipPort, const ock::hcom::UBSHcomChannelPtr &ch, + const std::string &payload) +{ + mCh = ch; + LOG_DEBUG("New connection from " << ipPort << " !"); + return 0; +} + +int ServiceReadBwTest::RequestReceived(const ock::hcom::UBSHcomServiceContext &ctx) +{ + int result = 0; + if (mCfg.GetIsServer()) { + // server + UBSHcomRequest req(&mPostMrInfo, sizeof(mPostMrInfo), OP_SERVICE_READ_BW); + Callback *newCallback = UBSHcomNewCallback([](UBSHcomServiceContext &context) {}, std::placeholders::_1); + if (newCallback == nullptr) { + LOG_ERROR("Create callback failed"); + return -1; + } + // post send callback + UBSHcomReplyContext replyCtx; + replyCtx.rspCtx = ctx.RspCtx(); + if ((ctx.Channel()->Reply(replyCtx, req, newCallback)) != 0) { + if (newCallback != nullptr) { + delete newCallback; + } + LOG_ERROR("Failed to post message to data to server"); + return -1; + } + } + return result; +} + +bool ServiceReadBwTest::Initialize() +{ + sem_init(&mSem, 0, 0); + + // create NetService + UBSHcomServiceNewChannelHandler funcNewChannel = + bind(&ServiceReadBwTest::NewChannel, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3); + UBSHcomServiceRecvHandler funcReqReceived = bind(&ServiceReadBwTest::RequestReceived, this, std::placeholders::_1); + + mHelper.RegisterRecvHandler(funcReqReceived); + mHelper.RegisterNewChHandler(funcNewChannel); + if (!mHelper.CreateService()) { + goto ERROR_HANDLE; + } + + if (!RegMemory()) { + LOG_ERROR("register memory failed"); + goto ERROR_HANDLE; + } + + // client connect to server + if (!mCfg.GetIsServer()) { + if (!Connect()) { + LOG_ERROR("client connect failed"); + goto ERROR_HANDLE; + } + if (!ExchangeAddress()) { + LOG_ERROR("client exchange address failed"); + goto ERROR_HANDLE; + } + } + + return true; + +ERROR_HANDLE: + mHelper.DestroyService(); + sem_destroy(&mSem); + return false; +} + +bool ServiceReadBwTest::RegMemory() +{ + if (!mHelper.CreateMemoryRegion(mPostMrInfo)) { + LOG_ERROR("Create memoryRegion failed"); + return false; + } + return true; +} + +bool ServiceReadBwTest::ExchangeAddress() +{ + if (mCh == nullptr) { + LOG_ERROR("Exchange address failed, ch is nullptr!"); + return false; + } + + std::string value = "hello world"; + UBSHcomRequest req((void *)(const_cast(value.c_str())), value.length(), OP_SERVICE_READ_BW); + UBSHcomResponse rsp(&mPeerMrInfo, sizeof(mPeerMrInfo)); + + if ((mCh->Call(req, rsp, nullptr)) != 0) { + LOG_ERROR("Failed to call message to data to server"); + return false; + } + + return true; +} + +void ServiceReadBwTest::UnInitialize() +{ + if (mCh != nullptr) { + mHelper.GetNetService()->Disconnect(mCh); + mCh.Set(nullptr); + } + + mHelper.DestroyService(); + sem_destroy(&mSem); +} + +bool ServiceReadBwTest::Connect() +{ + auto service = mHelper.GetNetService(); + if (service == nullptr) { + LOG_ERROR("Connect failed, net service is nullptr!"); + return false; + } + UBSHcomConnectOptions opt; + int res = service->Connect("tcp://" + mCfg.GetOobIp() + ":" + std::to_string(mCfg.GetOobPort()), mCh, opt); + if (res != 0) { + LOG_ERROR("Connect failed, error code: " << res); + return false; + } + return true; +} + +bool ServiceReadBwTest::RunTest(PerfTestContext *ctx) +{ + // ctx会记录测试中每个Iteration耗时,故每次使用不同的ctx + SetPerfTestContext(ctx); + if (!mCfg.GetIsServer()) { + mReq.lAddress = mPostMrInfo.lAddress; + mReq.rAddress = mPeerMrInfo.lAddress; + mReq.lKey = mPostMrInfo.lKey; + mReq.rKey = mPeerMrInfo.lKey; + mReq.size = ctx->mSize; + + DoPostRead(); + } + // 等待测试结束 + sem_wait(&mSem); + return true; +} + +REGIST_PERF_TEST_CREATOR(PERF_TEST_TYPE::SERVICE_READ_BW, ServiceReadBwTest); +} +} diff --git a/test/tools/perf_test/test_case/service_v2/service_read_bw_test.h b/test/tools/perf_test/test_case/service_v2/service_read_bw_test.h new file mode 100644 index 0000000000000000000000000000000000000000..894dc1d299712dd6797fb96a3fe2e262a854fe7e --- /dev/null +++ b/test/tools/perf_test/test_case/service_v2/service_read_bw_test.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_PERF_TEST_SERVICE_READ_BW_H +#define HCOM_PERF_TEST_SERVICE_READ_BW_H +#include +#include "hcom/hcom.h" +#include "test_case/perf_test_base.h" +#include "test_case/service_v2/service_helper.h" + +namespace hcom { +namespace perftest { +class ServiceReadBwTest : public PerfTestBase { +public: + ServiceReadBwTest(const PerfTestConfig &cfg) : PerfTestBase(cfg), mHelper(cfg){}; + bool Initialize() override; + void UnInitialize() override; + bool RunTest(PerfTestContext *ctx) override; + +private: + bool Connect(); + int DoPostRead(); + int NewChannel(const std::string &ipPort, const ock::hcom::UBSHcomChannelPtr &ch, const std::string &payload); + int RequestReceived(const ock::hcom::UBSHcomServiceContext &ctx); + bool RegMemory(); + bool ExchangeAddress(); + +private: + bool SetPerfTestContext(PerfTestContext *ctx) + { + if (ctx == nullptr) { + return false; + } + mCtx = ctx; + return true; + } + + PerfTestContext *GetPerfTestContext() const + { + return mCtx; + } + PerfTestContext *mCtx = nullptr; + +private: + ock::hcom::UBSHcomChannelPtr mCh = nullptr; + ock::hcom::UBSHcomOneSideRequest mReq; + ServiceHelper mHelper; + RegMrInfo mPostMrInfo; + RegMrInfo mPeerMrInfo; + sem_t mSem; +}; +} +} + +#endif diff --git a/test/tools/perf_test/test_case/service_v2/service_read_lat_test.cpp b/test/tools/perf_test/test_case/service_v2/service_read_lat_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1d34153f797bc2d56b117b656c0e8e90f63b6509 --- /dev/null +++ b/test/tools/perf_test/test_case/service_v2/service_read_lat_test.cpp @@ -0,0 +1,164 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include +#include "common/perf_test_logger.h" +#include "test_case/perf_test_factory.h" +#include "test_case/service_v2/service_read_lat_test.h" + +namespace hcom { +namespace perftest { +using namespace ock::hcom; +constexpr uint16_t OP_SERVICE_READ_LAT = 203; + +int ServiceReadLatTest::NewChannel(const std::string &ipPort, const ock::hcom::UBSHcomChannelPtr &ch, + const std::string &payload) +{ + mCh = ch; + LOG_DEBUG("New connection from " << ipPort << " !"); + return 0; +} + +int ServiceReadLatTest::RequestReceived(const ock::hcom::UBSHcomServiceContext &ctx) +{ + int result = 0; + if (mCfg.GetIsServer()) { + // server + UBSHcomRequest req(&mPostMrInfo, sizeof(mPostMrInfo), OP_SERVICE_READ_LAT); + Callback *newCallback = UBSHcomNewCallback([](UBSHcomServiceContext &context) {}, std::placeholders::_1); + if (newCallback == nullptr) { + LOG_ERROR("Create callback failed"); + return -1; + } + // post send callback + UBSHcomReplyContext replyCtx; + replyCtx.rspCtx = ctx.RspCtx(); + if ((ctx.Channel()->Reply(replyCtx, req, newCallback)) != 0) { + if (newCallback != nullptr) { + delete newCallback; + } + LOG_ERROR("Failed to post message to data to server"); + return -1; + } + } + return result; +} + +bool ServiceReadLatTest::Initialize() +{ + sem_init(&mSem, 0, 0); + + // create NetService + UBSHcomServiceNewChannelHandler funcNewChannel = bind(&ServiceReadLatTest::NewChannel, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3); + UBSHcomServiceRecvHandler funcReqReceived = bind(&ServiceReadLatTest::RequestReceived, this, std::placeholders::_1); + + mHelper.RegisterRecvHandler(funcReqReceived); + mHelper.RegisterNewChHandler(funcNewChannel); + + if (!mHelper.CreateService()) { + goto ERROR_HANDLE; + } + + if (!RegMemory()) { + LOG_ERROR("register memory failed"); + goto ERROR_HANDLE; + } + + // client connect to server + if (!mCfg.GetIsServer()) { + if (!Connect()) { + LOG_ERROR("client connect failed"); + goto ERROR_HANDLE; + } + if (!ExchangeAddress()) { + LOG_ERROR("client exchange address failed"); + goto ERROR_HANDLE; + } + } + + return true; + +ERROR_HANDLE: + mHelper.DestroyService(); + sem_destroy(&mSem); + return false; +} + +bool ServiceReadLatTest::RegMemory() +{ + if (!mHelper.CreateMemoryRegion(mPostMrInfo)) { + LOG_ERROR("Create memoryRegion failed"); + return false; + } + return true; +} + +bool ServiceReadLatTest::ExchangeAddress() +{ + if (mCh == nullptr) { + LOG_ERROR("Exchange address failed, ch is nullptr!"); + return false; + } + + std::string value = "hello world"; + UBSHcomRequest req((void *)(const_cast(value.c_str())), value.length(), OP_SERVICE_READ_LAT); + UBSHcomResponse rsp(&mPeerMrInfo, sizeof(mPeerMrInfo)); + + if ((mCh->Call(req, rsp, nullptr)) != 0) { + LOG_ERROR("Failed to call message to data to server"); + return false; + } + + return true; +} + +void ServiceReadLatTest::UnInitialize() +{ + if (mCh != nullptr) { + mHelper.GetNetService()->Disconnect(mCh); + mCh.Set(nullptr); + } + + mHelper.DestroyService(); + sem_destroy(&mSem); +} + +bool ServiceReadLatTest::Connect() +{ + auto service = mHelper.GetNetService(); + if (service == nullptr) { + LOG_ERROR("Connect failed, net service is nullptr!"); + return false; + } + UBSHcomConnectOptions opt; + int res = service->Connect("tcp://" + mCfg.GetOobIp() + ":" + std::to_string(mCfg.GetOobPort()), mCh, opt); + if (res != 0) { + LOG_ERROR("Connect failed, error code: " << res); + return false; + } + return true; +} + +bool ServiceReadLatTest::RunTest(PerfTestContext *ctx) +{ + // ctx会记录测试中每个Iteration耗时,故每次使用不同的ctx + SetPerfTestContext(ctx); + + if (!mCfg.GetIsServer()) { + mReq.lAddress = mPostMrInfo.lAddress; + mReq.rAddress = mPeerMrInfo.lAddress; + mReq.lKey = mPostMrInfo.lKey; + mReq.rKey = mPeerMrInfo.lKey; + mReq.size = ctx->mSize; + + DoPostRead(); + } + // 等待测试结束 + sem_wait(&mSem); + return true; +} + +REGIST_PERF_TEST_CREATOR(PERF_TEST_TYPE::SERVICE_READ_LAT, ServiceReadLatTest); +} +} diff --git a/test/tools/perf_test/test_case/service_v2/service_read_lat_test.h b/test/tools/perf_test/test_case/service_v2/service_read_lat_test.h new file mode 100644 index 0000000000000000000000000000000000000000..35dc759cc5da58b2df9a8ed59dd01bc97ebd3093 --- /dev/null +++ b/test/tools/perf_test/test_case/service_v2/service_read_lat_test.h @@ -0,0 +1,95 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_PERF_TEST_SERVICE_READ_LAT_H +#define HCOM_PERF_TEST_SERVICE_READ_LAT_H +#include +#include "hcom/hcom.h" +#include "test_case/perf_test_base.h" +#include "test_case/service_v2/service_helper.h" + +namespace hcom { +namespace perftest { +class ServiceReadLatTest : public PerfTestBase { +public: + ServiceReadLatTest(const PerfTestConfig &cfg) : PerfTestBase(cfg), mHelper(cfg){}; + bool Initialize() override; + void UnInitialize() override; + bool RunTest(PerfTestContext *ctx) override; + +private: + bool Connect(); + + inline int DoPostRead() + { + rcnt.store(0); + while (mCtx->cnt < mCtx->mIterations) { + ock::hcom::Callback *newCallback = ock::hcom::UBSHcomNewCallback( + [this](ock::hcom::UBSHcomServiceContext &context) { + PerfTestContext *testCtx = this->GetPerfTestContext(); + this->rcnt.fetch_add(1); + if (static_cast(this->rcnt.load()) == testCtx->mIterations) { + testCtx->tposted[testCtx->mIterations] = ock::hcom::MONOTONIC_TIME_NS(); + sem_post(&this->mSem); + } + }, + std::placeholders::_1); + if (newCallback == nullptr) { + LOG_ERROR("Create callback failed"); + sem_post(&mSem); + return -1; + } + mCtx->tposted[mCtx->cnt] = ock::hcom::MONOTONIC_TIME_NS(); + int res = mCh->Get(mReq, newCallback); + if (res != 0) { + if (newCallback != nullptr) { + delete newCallback; + } + LOG_ERROR("failed to send to server"); + sem_post(&mSem); + return -1; + } + ++mCtx->cnt; + while (mCtx->cnt != static_cast(rcnt.load())) + ; + } + mCtx->tposted[mCtx->cnt] = ock::hcom::MONOTONIC_TIME_NS(); + LOG_DEBUG("One Iteration Done!"); + sem_post(&mSem); + return 0; + } + + int NewChannel(const std::string &ipPort, const ock::hcom::UBSHcomChannelPtr &ch, const std::string &payload); + int RequestReceived(const ock::hcom::UBSHcomServiceContext &ctx); + bool RegMemory(); + bool ExchangeAddress(); + +private: + bool SetPerfTestContext(PerfTestContext *ctx) + { + if (ctx == nullptr) { + return false; + } + mCtx = ctx; + return true; + } + + PerfTestContext *GetPerfTestContext() const + { + return mCtx; + } + PerfTestContext *mCtx = nullptr; + +private: + ock::hcom::UBSHcomChannelPtr mCh = nullptr; + ock::hcom::UBSHcomOneSideRequest mReq; + volatile std::atomic rcnt{ 0 }; + ServiceHelper mHelper; + RegMrInfo mPostMrInfo; + RegMrInfo mPeerMrInfo; + sem_t mSem; +}; +} +} + +#endif diff --git a/test/tools/perf_test/test_case/service_v2/service_send_bw_test.cpp b/test/tools/perf_test/test_case/service_v2/service_send_bw_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fd162aa68db34fc18bc6e96910ca45fca485109f --- /dev/null +++ b/test/tools/perf_test/test_case/service_v2/service_send_bw_test.cpp @@ -0,0 +1,154 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include +#include "common/perf_test_logger.h" +#include "test_case/perf_test_factory.h" +#include "test_case/service_v2/service_send_bw_test.h" + +namespace hcom { +namespace perftest { +using namespace ock::hcom; +constexpr uint16_t OP_CODE_SEND_BW = 201; + +int ServiceSendBwTest::DoPostSend() +{ + PerfTestContext *ctx = GetPerfTestContext(); + UBSHcomRequest req(mDataAddr, ctx->mSize, OP_CODE_SEND_BW); + ctx->tposted[0] = MONOTONIC_TIME_NS(); + for (uint64_t i = 0; i < ctx->mIterations; ++i) { + Callback *newCallback = UBSHcomNewCallback( + [this](UBSHcomServiceContext &context) { + PerfTestContext *testCtx = this->GetPerfTestContext(); + testCtx->totrcnt++; + if (testCtx->totrcnt == testCtx->mIterations) { + testCtx->tposted[testCtx->mIterations] = MONOTONIC_TIME_NS(); + sem_post(&this->mSem); + } + }, + std::placeholders::_1); + if (newCallback == nullptr) { + LOG_ERROR("Create callback failed"); + return -1; + } + int res = mCh->Send(req, newCallback); + if (res != 0) { + if (newCallback != nullptr) { + delete newCallback; + } + LOG_ERROR("failed to send to server"); + return res; + } + } + return 0; +} + +int ServiceSendBwTest::NewChannel(const std::string &ipPort, const ock::hcom::UBSHcomChannelPtr &ch, + const std::string &payload) +{ + mCh = ch; + LOG_DEBUG("New connection from " << ipPort << " !"); + return 0; +} + +int ServiceSendBwTest::RequestPosted(const ock::hcom::UBSHcomServiceContext &ctx) +{ + PerfTestContext *testCtx = GetPerfTestContext(); + testCtx->totrcnt++; + if (testCtx->totrcnt == testCtx->mIterations) { + testCtx->tposted[testCtx->mIterations] = MONOTONIC_TIME_NS(); + sem_post(&mSem); + } + return 0; +} + +bool ServiceSendBwTest::Initialize() +{ + sem_init(&mSem, 0, 0); + + // create NetService + UBSHcomServiceNewChannelHandler funcNewChannel = + bind(&ServiceSendBwTest::NewChannel, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3); + UBSHcomServiceSendHandler funcReqPosted = bind(&ServiceSendBwTest::RequestPosted, this, std::placeholders::_1); + + mHelper.RegisterNewChHandler(funcNewChannel); + mHelper.RegisterSendHandler(funcReqPosted); + + if (!mHelper.CreateService()) { + goto ERROR_HANDLE; + } + + // init data buffer + mDataAddr = new (std::nothrow) char[MAX_MESSAGE_SIZE]; + if (mDataAddr == nullptr) { + LOG_ERROR("Create data buffer failed"); + goto ERROR_HANDLE; + } + + // client connect to server + if (!mCfg.GetIsServer()) { + if (!Connect()) { + LOG_ERROR("Client connect failed"); + goto ERROR_HANDLE; + } + } + + return true; + +ERROR_HANDLE: + mHelper.DestroyService(); + if (mDataAddr != nullptr) { + delete[] mDataAddr; + mDataAddr = nullptr; + } + sem_destroy(&mSem); + return false; +} + +void ServiceSendBwTest::UnInitialize() +{ + if (mCh != nullptr) { + mHelper.GetNetService()->Disconnect(mCh); + mCh.Set(nullptr); + } + + mHelper.DestroyService(); + + if (mDataAddr != nullptr) { + delete[] mDataAddr; + mDataAddr = nullptr; + } + sem_destroy(&mSem); +} + +bool ServiceSendBwTest::Connect() +{ + auto service = mHelper.GetNetService(); + if (service == nullptr) { + LOG_ERROR("Connect failed, net service is nullptr!"); + return false; + } + UBSHcomConnectOptions opt; + int res = service->Connect("tcp://" + mCfg.GetOobIp() + ":" + std::to_string(mCfg.GetOobPort()), mCh, opt); + if (res != 0) { + LOG_ERROR("Connect failed, error code: " << res); + return false; + } + return true; +} + +bool ServiceSendBwTest::RunTest(PerfTestContext *ctx) +{ + // ctx会记录测试中每个Iteration耗时,故每次使用不同的ctx + SetPerfTestContext(ctx); + if (!mCfg.GetIsServer()) { + DoPostSend(); + } + // 等待测试结束 + sem_wait(&mSem); + return true; +} + +REGIST_PERF_TEST_CREATOR(PERF_TEST_TYPE::SERVICE_SEND_BW, ServiceSendBwTest); +} +} diff --git a/test/tools/perf_test/test_case/service_v2/service_send_bw_test.h b/test/tools/perf_test/test_case/service_v2/service_send_bw_test.h new file mode 100644 index 0000000000000000000000000000000000000000..a0248b595ebb484bd1a0555e01021008a275343b --- /dev/null +++ b/test/tools/perf_test/test_case/service_v2/service_send_bw_test.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_PERF_TEST_SERVICE_SEND_BW_H +#define HCOM_PERF_TEST_SERVICE_SEND_BW_H +#include +#include "hcom/hcom.h" +#include "test_case/perf_test_base.h" +#include "test_case/service_v2/service_helper.h" + +namespace hcom { +namespace perftest { +class ServiceSendBwTest : public PerfTestBase { +public: + ServiceSendBwTest(const PerfTestConfig &cfg) : PerfTestBase(cfg), mHelper(cfg){}; + bool Initialize() override; + void UnInitialize() override; + bool RunTest(PerfTestContext *ctx) override; + +private: + bool Connect(); + int DoPostSend(); + int NewChannel(const std::string &ipPort, const ock::hcom::UBSHcomChannelPtr &ch, const std::string &payload); + int RequestPosted(const ock::hcom::UBSHcomServiceContext &ctx); + +private: + bool SetPerfTestContext(PerfTestContext *ctx) + { + if (ctx == nullptr) { + return false; + } + mCtx = ctx; + return true; + } + + PerfTestContext *GetPerfTestContext() const + { + return mCtx; + } + PerfTestContext *mCtx = nullptr; + +private: + ock::hcom::UBSHcomChannelPtr mCh = nullptr; + ServiceHelper mHelper; + char *mDataAddr = nullptr; + sem_t mSem; +}; +} +} + +#endif diff --git a/test/tools/perf_test/test_case/service_v2/service_send_lat_test.cpp b/test/tools/perf_test/test_case/service_v2/service_send_lat_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..68f16230883c09f899176a205cce9f4c2bc49212 --- /dev/null +++ b/test/tools/perf_test/test_case/service_v2/service_send_lat_test.cpp @@ -0,0 +1,168 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include +#include "common/perf_test_logger.h" +#include "test_case/perf_test_factory.h" +#include "test_case/service_v2/service_send_lat_test.h" + +namespace hcom { +namespace perftest { +using namespace ock::hcom; +constexpr uint16_t OP_CODE_SEND_LAT = 200; + +int ServiceSendLatTest::DoPostSend() +{ + PerfTestContext *ctx = GetPerfTestContext(); + int res = 0; + // ctx->tposted[i+1] - ctx->tposted[i] 为一次RTT(Round-Trip Time,往返时间) + if (ctx->cnt < ctx->mIterations) { + ctx->tposted[ctx->cnt] = MONOTONIC_TIME_NS(); + UBSHcomRequest req(mDataAddr, ctx->mSize, OP_CODE_SEND_LAT); + Callback *newCallback = UBSHcomNewCallback([](UBSHcomServiceContext &context) {}, std::placeholders::_1); + if (newCallback == nullptr) { + LOG_ERROR("Create callback failed"); + return -1; + } + res = mCh->Send(req, newCallback); + if (res != 0) { + if (newCallback != nullptr) { + delete newCallback; + } + LOG_ERROR("Failed to send to server"); + } + ++ctx->cnt; + return 0; + } + + if (ctx->cnt == ctx->mIterations) { + ctx->tposted[ctx->cnt] = MONOTONIC_TIME_NS(); + LOG_DEBUG("One Iteration Done!"); + sem_post(&mSem); + } + return 0; +} + +int ServiceSendLatTest::NewChannel(const std::string &ipPort, const ock::hcom::UBSHcomChannelPtr &ch, + const std::string &payload) +{ + mCh = ch; + LOG_DEBUG("New connection from " << ipPort << " !"); + return 0; +} + +int ServiceSendLatTest::RequestReceived(const ock::hcom::UBSHcomServiceContext &ctx) +{ + int res = 0; + if (mCfg.GetIsServer()) { + // server 直接回复相同大小的消息即可 + UBSHcomRequest req(ctx.MessageData(), ctx.MessageDataLen(), OP_CODE_SEND_LAT); + Callback *newCallback = UBSHcomNewCallback([](UBSHcomServiceContext &context) {}, std::placeholders::_1); + if (newCallback == nullptr) { + LOG_ERROR("Create callback failed"); + return -1; + } + res = mCh->Send(req, newCallback); + if (res != 0) { + if (newCallback != nullptr) { + delete newCallback; + } + LOG_ERROR("UBSHcomResponse meaasge error"); + return -1; + } + } else { + DoPostSend(); + } + return 0; +} + +bool ServiceSendLatTest::Initialize() +{ + sem_init(&mSem, 0, 0); + + // create NetService + std::function + funcNewChannel = bind(&ServiceSendLatTest::NewChannel, this, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3); + std::function funcReqReceived = + bind(&ServiceSendLatTest::RequestReceived, this, std::placeholders::_1); + mHelper.RegisterNewChHandler(funcNewChannel); + mHelper.RegisterRecvHandler(funcReqReceived); + if (!mHelper.CreateService()) { + goto ERROR_HANDLE; + } + + // init data buffer + mDataAddr = new char[MAX_MESSAGE_SIZE]; + if (mDataAddr == nullptr) { + LOG_ERROR("Create data buffer failed"); + goto ERROR_HANDLE; + } + + // client connect to server + if (!mCfg.GetIsServer()) { + if (!Connect()) { + LOG_ERROR("Client connect failed"); + return false; + } + } + + return true; + +ERROR_HANDLE: + mHelper.DestroyService(); + if (mDataAddr != nullptr) { + delete[] mDataAddr; + mDataAddr = nullptr; + } + sem_destroy(&mSem); + return false; +} + +void ServiceSendLatTest::UnInitialize() +{ + if (mCh != nullptr) { + mHelper.GetNetService()->Disconnect(mCh); + mCh.Set(nullptr); + } + + mHelper.DestroyService(); + + if (mDataAddr != nullptr) { + delete[] mDataAddr; + mDataAddr = nullptr; + } + sem_destroy(&mSem); +} + +bool ServiceSendLatTest::Connect() +{ + auto service = mHelper.GetNetService(); + if (service == nullptr) { + LOG_ERROR("Connect failed, net service is nullptr!"); + return false; + } + UBSHcomConnectOptions opt; + int res = service->Connect("tcp://" + mCfg.GetOobIp() + ":" + std::to_string(mCfg.GetOobPort()), mCh, opt); + if (res != 0) { + LOG_ERROR("Connect failed, error code: " << res); + return false; + } + return true; +} + +bool ServiceSendLatTest::RunTest(PerfTestContext *ctx) +{ + // ctx会记录测试中每个Iteration耗时,故每次使用不同的ctx + SetPerfTestContext(ctx); + if (!mCfg.GetIsServer()) { + DoPostSend(); + } + // 等待测试结束 + sem_wait(&mSem); + return true; +} + +REGIST_PERF_TEST_CREATOR(PERF_TEST_TYPE::SERVICE_SEND_LAT, ServiceSendLatTest); +} +} diff --git a/test/tools/perf_test/test_case/service_v2/service_send_lat_test.h b/test/tools/perf_test/test_case/service_v2/service_send_lat_test.h new file mode 100644 index 0000000000000000000000000000000000000000..d9d86abd01e1a355e304bcc9ba1a5a8293f2a016 --- /dev/null +++ b/test/tools/perf_test/test_case/service_v2/service_send_lat_test.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_PERF_TEST_SERVICE_SEND_LAT_H +#define HCOM_PERF_TEST_SERVICE_SEND_LAT_H +#include +#include "hcom/hcom.h" +#include "test_case/perf_test_base.h" +#include "test_case/service_v2/service_helper.h" + +namespace hcom { +namespace perftest { +class ServiceSendLatTest : public PerfTestBase { +public: + ServiceSendLatTest(const PerfTestConfig &cfg) : PerfTestBase(cfg), mHelper(cfg){}; + bool Initialize() override; + void UnInitialize() override; + bool RunTest(PerfTestContext *ctx) override; + +private: + bool Connect(); + int DoPostSend(); + int NewChannel(const std::string &ipPort, const ock::hcom::UBSHcomChannelPtr &ch, const std::string &payload); + int RequestReceived(const ock::hcom::UBSHcomServiceContext &ctx); + int OneSideDone(const ock::hcom::UBSHcomServiceContext &ctx); + +private: + bool SetPerfTestContext(PerfTestContext *ctx) + { + if (ctx == nullptr) { + return false; + } + mCtx = ctx; + return true; + } + + PerfTestContext *GetPerfTestContext() const + { + return mCtx; + } + PerfTestContext *mCtx = nullptr; + +private: + ock::hcom::UBSHcomChannelPtr mCh = nullptr; + ServiceHelper mHelper; + + char *mDataAddr = nullptr; + sem_t mSem; +}; +} +} + +#endif diff --git a/test/tools/perf_test/test_case/service_v2/service_write_bw_test.cpp b/test/tools/perf_test/test_case/service_v2/service_write_bw_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aeeb3b4fee7f3acc02dd8fbd1ad310e9154c2747 --- /dev/null +++ b/test/tools/perf_test/test_case/service_v2/service_write_bw_test.cpp @@ -0,0 +1,193 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include +#include "common/perf_test_logger.h" +#include "test_case/perf_test_factory.h" +#include "test_case/service_v2/service_write_bw_test.h" + +namespace hcom { +namespace perftest { +using namespace ock::hcom; +constexpr uint16_t OP_SERVICE_WRITE_BW = 206; + +int ServiceWriteBwTest::DoPostWrite() +{ + mCtx->tposted[0] = ock::hcom::MONOTONIC_TIME_NS(); + for (uint64_t i = 0; i < mCtx->mIterations; ++i) { + ock::hcom::Callback *newCallback = ock::hcom::UBSHcomNewCallback( + [this](ock::hcom::UBSHcomServiceContext &context) { + PerfTestContext *testCtx = this->GetPerfTestContext(); + testCtx->totrcnt++; + if (testCtx->totrcnt == testCtx->mIterations) { + testCtx->tposted[testCtx->mIterations] = MONOTONIC_TIME_NS(); + sem_post(&this->mSem); + } + }, + std::placeholders::_1); + if (newCallback == nullptr) { + LOG_ERROR("Create callback failed"); + return -1; + } + int res = mCh->Put(mReq, newCallback); + if (res != 0) { + if (newCallback != nullptr) { + delete newCallback; + } + LOG_ERROR("failed to send to server"); + return res; + } + } + + return 0; +} + +int ServiceWriteBwTest::NewChannel(const std::string &ipPort, const ock::hcom::UBSHcomChannelPtr &ch, + const std::string &payload) +{ + mCh = ch; + LOG_DEBUG("New connection from " << ipPort << " !"); + return 0; +} + +int ServiceWriteBwTest::RequestReceived(const ock::hcom::UBSHcomServiceContext &ctx) +{ + int result = 0; + if (mCfg.GetIsServer()) { + // server + UBSHcomRequest req(&mPostMrInfo, sizeof(mPostMrInfo), OP_SERVICE_WRITE_BW); + Callback *newCallback = UBSHcomNewCallback([](UBSHcomServiceContext &context) {}, std::placeholders::_1); + if (newCallback == nullptr) { + LOG_ERROR("Create callback failed"); + return -1; + } + // post send callback + UBSHcomReplyContext replyCtx; + replyCtx.rspCtx = ctx.RspCtx(); + if ((ctx.Channel()->Reply(replyCtx, req, newCallback)) != 0) { + if (newCallback != nullptr) { + delete newCallback; + } + LOG_ERROR("Failed to post message to data to server"); + return -1; + } + } + return result; +} + +bool ServiceWriteBwTest::Initialize() +{ + sem_init(&mSem, 0, 0); + + // create NetService + UBSHcomServiceNewChannelHandler funcNewChannel = bind(&ServiceWriteBwTest::NewChannel, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3); + UBSHcomServiceRecvHandler funcReqReceived = bind(&ServiceWriteBwTest::RequestReceived, this, std::placeholders::_1); + + mHelper.RegisterRecvHandler(funcReqReceived); + mHelper.RegisterNewChHandler(funcNewChannel); + if (!mHelper.CreateService()) { + goto ERROR_HANDLE; + } + + if (!RegMemory()) { + LOG_ERROR("register memory failed"); + goto ERROR_HANDLE; + } + + // client connect to server + if (!mCfg.GetIsServer()) { + if (!Connect()) { + LOG_ERROR("client connect failed"); + goto ERROR_HANDLE; + } + if (!ExchangeAddress()) { + LOG_ERROR("client exchange address failed"); + goto ERROR_HANDLE; + } + } + + return true; + +ERROR_HANDLE: + mHelper.DestroyService(); + sem_destroy(&mSem); + return false; +} + +bool ServiceWriteBwTest::RegMemory() +{ + if (!mHelper.CreateMemoryRegion(mPostMrInfo)) { + LOG_ERROR("Create memoryRegion failed"); + return false; + } + return true; +} + +bool ServiceWriteBwTest::ExchangeAddress() +{ + if (mCh == nullptr) { + LOG_ERROR("Exchange address failed, ch is nullptr!"); + return false; + } + + std::string value = "hello world"; + UBSHcomRequest req((void *)(const_cast(value.c_str())), value.length(), OP_SERVICE_WRITE_BW); + UBSHcomResponse rsp(&mPeerMrInfo, sizeof(mPeerMrInfo)); + + if ((mCh->Call(req, rsp, nullptr)) != 0) { + LOG_ERROR("Failed to call message to data to server"); + return false; + } + + return true; +} + +void ServiceWriteBwTest::UnInitialize() +{ + if (mCh != nullptr) { + mHelper.GetNetService()->Disconnect(mCh); + mCh.Set(nullptr); + } + + mHelper.DestroyService(); + sem_destroy(&mSem); +} + +bool ServiceWriteBwTest::Connect() +{ + auto service = mHelper.GetNetService(); + if (service == nullptr) { + LOG_ERROR("Connect failed, net service is nullptr!"); + return false; + } + UBSHcomConnectOptions opt; + int res = service->Connect("tcp://" + mCfg.GetOobIp() + ":" + std::to_string(mCfg.GetOobPort()), mCh, opt); + if (res != 0) { + LOG_ERROR("Connect failed, error code: " << res); + return false; + } + return true; +} + +bool ServiceWriteBwTest::RunTest(PerfTestContext *ctx) +{ + // ctx会记录测试中每个Iteration耗时,故每次使用不同的ctx + SetPerfTestContext(ctx); + if (!mCfg.GetIsServer()) { + mReq.lAddress = mPostMrInfo.lAddress; + mReq.rAddress = mPeerMrInfo.lAddress; + mReq.lKey = mPostMrInfo.lKey; + mReq.rKey = mPeerMrInfo.lKey; + mReq.size = ctx->mSize; + + DoPostWrite(); + } + // 等待测试结束 + sem_wait(&mSem); + return true; +} + +REGIST_PERF_TEST_CREATOR(PERF_TEST_TYPE::SERVICE_WRITE_BW, ServiceWriteBwTest); +} +} diff --git a/test/tools/perf_test/test_case/service_v2/service_write_bw_test.h b/test/tools/perf_test/test_case/service_v2/service_write_bw_test.h new file mode 100644 index 0000000000000000000000000000000000000000..055f43b494011c2a6194d207ba9f9c18389d69c7 --- /dev/null +++ b/test/tools/perf_test/test_case/service_v2/service_write_bw_test.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_PERF_TEST_SERVICE_WRITE_BW_H +#define HCOM_PERF_TEST_SERVICE_WRITE_BW_H +#include +#include "hcom/hcom.h" +#include "test_case/perf_test_base.h" +#include "test_case/service_v2/service_helper.h" + +namespace hcom { +namespace perftest { +class ServiceWriteBwTest : public PerfTestBase { +public: + ServiceWriteBwTest(const PerfTestConfig &cfg) : PerfTestBase(cfg), mHelper(cfg){}; + bool Initialize() override; + void UnInitialize() override; + bool RunTest(PerfTestContext *ctx) override; + +private: + bool Connect(); + int DoPostWrite(); + int NewChannel(const std::string &ipPort, const ock::hcom::UBSHcomChannelPtr &ch, const std::string &payload); + int RequestReceived(const ock::hcom::UBSHcomServiceContext &ctx); + bool RegMemory(); + bool ExchangeAddress(); + +private: + bool SetPerfTestContext(PerfTestContext *ctx) + { + if (ctx == nullptr) { + return false; + } + mCtx = ctx; + return true; + } + + PerfTestContext *GetPerfTestContext() const + { + return mCtx; + } + PerfTestContext *mCtx = nullptr; + +private: + ock::hcom::UBSHcomChannelPtr mCh = nullptr; + ock::hcom::UBSHcomOneSideRequest mReq; + ServiceHelper mHelper; + RegMrInfo mPostMrInfo; + RegMrInfo mPeerMrInfo; + sem_t mSem; +}; +} +} + +#endif diff --git a/test/tools/perf_test/test_case/service_v2/service_write_lat_test.cpp b/test/tools/perf_test/test_case/service_v2/service_write_lat_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..04b7140933685438191910ae75253c651dc41f91 --- /dev/null +++ b/test/tools/perf_test/test_case/service_v2/service_write_lat_test.cpp @@ -0,0 +1,188 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include +#include "securec.h" + +#include "common/perf_test_logger.h" +#include "test_case/perf_test_factory.h" +#include "test_case/service_v2/service_write_lat_test.h" + +namespace hcom { +namespace perftest { +using namespace ock::hcom; +constexpr uint16_t OP_SERVICE_WRITE_LAT = 205; + +int ServiceWriteLatTest::NewChannel(const std::string &ipPort, const ock::hcom::UBSHcomChannelPtr &ch, + const std::string &payload) +{ + mCh = ch; + isConnect.store(true); + LOG_DEBUG("New connection from " << ipPort << " !"); + return 0; +} + +int ServiceWriteLatTest::RequestReceived(const ock::hcom::UBSHcomServiceContext &ctx) +{ + int result = 0; + if (mCfg.GetIsServer()) { + // server + if (memcpy_s(&mPeerMrInfo, sizeof(mPeerMrInfo), ctx.MessageData(), ctx.MessageDataLen()) != 0) { + LOG_ERROR("memcpy_s failed"); + return -1; + } + UBSHcomRequest req(&mPollMrInfo, sizeof(mPollMrInfo), OP_SERVICE_WRITE_LAT); + // NetServiceOpInfo sendOpInfo{}; + Callback *newCallback = UBSHcomNewCallback([](UBSHcomServiceContext &context) {}, std::placeholders::_1); + if (newCallback == nullptr) { + LOG_ERROR("Create callback failed"); + sem_post(&mSem); + return -1; + } + // post send callback + UBSHcomReplyContext replyCtx; + replyCtx.rspCtx = ctx.RspCtx(); + if ((ctx.Channel()->Reply(replyCtx, req, newCallback)) != 0) { + if (newCallback != nullptr) { + delete newCallback; + } + LOG_ERROR("Failed to post message to data to server"); + result = -1; + } + sem_post(&mSem); + } + return result; +} + +void ServiceWriteLatTest::ChannelBroken(const ock::hcom::UBSHcomChannelPtr &ch) +{ + isConnect.store(false); + return; +} + +bool ServiceWriteLatTest::Initialize() +{ + sem_init(&mSem, 0, 0); + + // create NetService + UBSHcomServiceNewChannelHandler funcNewChannel = bind(&ServiceWriteLatTest::NewChannel, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3); + UBSHcomServiceRecvHandler funcReqReceived = bind(&ServiceWriteLatTest::RequestReceived, this, std::placeholders::_1); + UBSHcomServiceChannelBrokenHandler funcChBroken = bind(&ServiceWriteLatTest::ChannelBroken, this, + std::placeholders::_1); + + mHelper.RegisterRecvHandler(funcReqReceived); + mHelper.RegisterNewChHandler(funcNewChannel); + mHelper.RegisterChBrokenHandler(funcChBroken); + if (!mHelper.CreateService()) { + goto ERROR_HANDLE; + } + + if (!RegMemory()) { + LOG_ERROR("register memory failed"); + goto ERROR_HANDLE; + } + + // client connect to server + if (!mCfg.GetIsServer()) { + if (!Connect()) { + LOG_ERROR("client connect failed"); + goto ERROR_HANDLE; + } + if (!ExchangeAddress()) { + LOG_ERROR("client exchange address failed"); + goto ERROR_HANDLE; + } + } + + return true; + +ERROR_HANDLE: + mHelper.DestroyService(); + sem_destroy(&mSem); + return false; +} + +bool ServiceWriteLatTest::RegMemory() +{ + if (!mHelper.CreateMemoryRegion(mPostMrInfo)) { + LOG_ERROR("Create memoryRegion failed"); + return false; + } + if (!mHelper.CreateMemoryRegion(mPollMrInfo)) { + LOG_ERROR("Create memoryRegion failed"); + return false; + } + return true; +} + +bool ServiceWriteLatTest::ExchangeAddress() +{ + if (mCh == nullptr) { + LOG_ERROR("Exchange address failed, ch is nullptr!"); + return false; + } + + UBSHcomRequest req(&mPollMrInfo, sizeof(mPollMrInfo), OP_SERVICE_WRITE_LAT); + UBSHcomResponse rsp(&mPeerMrInfo, sizeof(mPeerMrInfo)); + + if ((mCh->Call(req, rsp, nullptr)) != 0) { + LOG_ERROR("Failed to call message to data to server"); + return false; + } + + return true; +} + +void ServiceWriteLatTest::UnInitialize() +{ + if (mCh != nullptr) { + mHelper.GetNetService()->Disconnect(mCh); + mCh.Set(nullptr); + } + + mHelper.DestroyService(); + sem_destroy(&mSem); +} + +bool ServiceWriteLatTest::Connect() +{ + auto service = mHelper.GetNetService(); + if (service == nullptr) { + LOG_ERROR("Connect failed, net service is nullptr!"); + return false; + } + UBSHcomConnectOptions opt; + int res = service->Connect("tcp://" + mCfg.GetOobIp() + ":" + std::to_string(mCfg.GetOobPort()), mCh, opt); + if (res != 0) { + LOG_ERROR("Connect failed, error code: " << res); + return false; + } + return true; +} + +bool ServiceWriteLatTest::RunTest(PerfTestContext *ctx) +{ + // ctx会记录测试中每个Iteration耗时,故每次使用不同的ctx + SetPerfTestContext(ctx); + if (mCfg.GetIsServer() && !isConnect.load()) { + // server等到地址交换结束 + sem_wait(&mSem); + } + + mReq.lAddress = mPostMrInfo.lAddress; + mReq.rAddress = mPeerMrInfo.lAddress; + mReq.lKey = mPostMrInfo.lKey; + mReq.rKey = mPeerMrInfo.lKey; + mReq.size = ctx->mSize; + + DoPostWrite(); + + // 等待测试结束 + sem_wait(&mSem); + return true; +} + +REGIST_PERF_TEST_CREATOR(PERF_TEST_TYPE::SERVICE_WRITE_LAT, ServiceWriteLatTest); +} +} diff --git a/test/tools/perf_test/test_case/service_v2/service_write_lat_test.h b/test/tools/perf_test/test_case/service_v2/service_write_lat_test.h new file mode 100644 index 0000000000000000000000000000000000000000..0533bb3fb087562bd9f37ef4fe33fb1d88121bf9 --- /dev/null +++ b/test/tools/perf_test/test_case/service_v2/service_write_lat_test.h @@ -0,0 +1,111 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_PERF_TEST_SERVICE_WRITE_LAT_H +#define HCOM_PERF_TEST_SERVICE_WRITE_LAT_H +#include +#include "hcom/hcom.h" +#include "test_case/perf_test_base.h" +#include "test_case/service_v2/service_helper.h" + +namespace hcom { +namespace perftest { +class ServiceWriteLatTest : public PerfTestBase { +public: + ServiceWriteLatTest(const PerfTestConfig &cfg) : PerfTestBase(cfg), mHelper(cfg){}; + bool Initialize() override; + void UnInitialize() override; + bool RunTest(PerfTestContext *ctx) override; + +private: + bool Connect(); + + inline int DoPostWrite() + { + volatile uint64_t *pollData = reinterpret_cast(mPollMrInfo.lAddress); + volatile uint64_t *postData = reinterpret_cast(mPostMrInfo.lAddress); + uint64_t num = 0; + *pollData = num; + *postData = num; + PerfTestContext *ctx = GetPerfTestContext(); + ctx->cnt = 0; + rcnt = 0; + ccnt.store(0); + while (ctx->cnt < ctx->mIterations || rcnt < ctx->mIterations || + static_cast(ccnt.load()) < ctx->mIterations) { + if (rcnt < ctx->mIterations && !(ctx->cnt < 1 && !mCfg.GetIsServer())) { + rcnt++; + while ((*pollData != rcnt) && ctx->cnt < ctx->mIterations) + ; + } + if (ctx->cnt < ctx->mIterations) { + ++ctx->cnt; + ock::hcom::Callback *newCallback = ock::hcom::UBSHcomNewCallback( + [this](ock::hcom::UBSHcomServiceContext &context) { this->ccnt.fetch_add(1); }, std::placeholders::_1); + if (newCallback == nullptr) { + LOG_ERROR("Create callback failed"); + sem_post(&mSem); + return -1; + } + *postData = ctx->cnt; + ctx->tposted[mCtx->cnt - 1] = ock::hcom::MONOTONIC_TIME_NS(); + int res = mCh->Put(mReq, newCallback); + if (res != 0) { + if (newCallback != nullptr) { + delete newCallback; + } + LOG_ERROR("failed to write to server"); + sem_post(&mSem); + return res; + } + } + + while (ctx->cnt != static_cast(ccnt.load())) + ; + } + if (ctx->cnt == ctx->mIterations) { + ctx->tposted[ctx->cnt] = ock::hcom::MONOTONIC_TIME_NS(); + LOG_DEBUG("One Iteration Done!"); + sem_post(&mSem); + } + return 0; + } + + int NewChannel(const std::string &ipPort, const ock::hcom::UBSHcomChannelPtr &ch, const std::string &payload); + int RequestReceived(const ock::hcom::UBSHcomServiceContext &ctx); + void ChannelBroken(const ock::hcom::UBSHcomChannelPtr &ch); + bool RegMemory(); + bool ExchangeAddress(); + +private: + bool SetPerfTestContext(PerfTestContext *ctx) + { + if (ctx == nullptr) { + return false; + } + mCtx = ctx; + return true; + } + + PerfTestContext *GetPerfTestContext() const + { + return mCtx; + } + PerfTestContext *mCtx = nullptr; + +private: + ock::hcom::UBSHcomChannelPtr mCh = nullptr; + ock::hcom::UBSHcomOneSideRequest mReq; + ServiceHelper mHelper; + RegMrInfo mPostMrInfo; + RegMrInfo mPollMrInfo; + RegMrInfo mPeerMrInfo; + uint64_t rcnt = 0; + std::atomic ccnt{ 0 }; + std::atomic isConnect{ false }; + sem_t mSem; +}; +} +} + +#endif diff --git a/test/tools/perf_test/test_case/transport/transport_helper.cpp b/test/tools/perf_test/test_case/transport/transport_helper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..40db9ecf54c433cafc53829005003657170f3993 --- /dev/null +++ b/test/tools/perf_test/test_case/transport/transport_helper.cpp @@ -0,0 +1,171 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include "test_case/transport/transport_helper.h" +#include "common/perf_test_logger.h" + +namespace hcom { +namespace perftest { +using namespace ock::hcom; + +static int NewEndPoint(const std::string &ipPort, const ock::hcom::UBSHcomNetEndpointPtr &ep, const std::string &payload) +{ + return 0; +} + +static void EndPointBroken(const ock::hcom::UBSHcomNetEndpointPtr &ep) +{ + return; +} + +static int RequestReceived(const ock::hcom::UBSHcomNetRequestContext &ctx) +{ + return 0; +} + +static int RequestPosted(const ock::hcom::UBSHcomNetRequestContext &ctx) +{ + return 0; +} + +static int OneSideDone(const ock::hcom::UBSHcomNetRequestContext &ctx) +{ + return 0; +} + +TransportHelper::TransportHelper(const PerfTestConfig &cfg) +{ + mCfg = cfg; + mDriver = nullptr; + // 回调函数提供默认空实现,简化测试用例 + mNewEpHandler = NewEndPoint; + mEpBrokenHandler = EndPointBroken; + mReqRecvHandler = RequestReceived; + mOneSideDoneHandler = OneSideDone; + mReqPostedHandler = RequestPosted; +} + +void TransportHelper::RegisterNewEPHandler(const NewEpHandler &handler) +{ + mNewEpHandler = handler; +} + +void TransportHelper::RegisterEpBrokenHandler(const EpBrokenHandler &handler) +{ + mEpBrokenHandler = handler; +} + +void TransportHelper::RegisterReqRecvHandler(const ReqRecvHandler &handler) +{ + mReqRecvHandler = handler; +} + +void TransportHelper::RegisterOneSideDoneHandler(const OneSideDoneHandler &handler) +{ + mOneSideDoneHandler = handler; +} + +void TransportHelper::RegisterReqPostedHandler(const ReqPostedHandler &handler) +{ + mReqPostedHandler = handler; +} + +bool TransportHelper::CreateNetDriver() +{ + if (mDriver != nullptr) { + LOG_WARN("UBSHcomNetDriver already created"); + return true; + } + + mDriver = UBSHcomNetDriver::Instance(mCfg.GetProtocol(), "PerfTest", mCfg.GetIsServer()); + + UBSHcomNetDriverOptions options{}; + FillNetDriverOption(options); + mDriver->RegisterNewEPHandler(mNewEpHandler); + mDriver->RegisterEPBrokenHandler(mEpBrokenHandler); + mDriver->RegisterReqPostedHandler(mReqPostedHandler); + mDriver->RegisterNewReqHandler(mReqRecvHandler); + mDriver->RegisterOneSideDoneHandler(mOneSideDoneHandler); + + if (mCfg.GetIsServer()) { + mDriver->OobIpAndPort(mCfg.GetOobIp(), mCfg.GetOobPort()); + } + + int result = 0; + if ((result = mDriver->Initialize(options)) != 0) { + LOG_ERROR("failed to initialize driver " << result); + return false; + } + LOG_DEBUG("UBSHcomNetDriver initialized"); + + if ((result = mDriver->Start()) != 0) { + LOG_ERROR("failed to start UBSHcomNetDriver " << result); + return false; + } + + LOG_DEBUG("UBSHcomNetDriver started"); + return true; +} + +void TransportHelper::DestroyNetDriver() +{ + if (mDriver != nullptr) { + if (!mMrVector.empty()) { + for (auto mr : mMrVector) { + mDriver->DestroyMemoryRegion(mr); + } + mMrVector.clear(); + } + mDriver->Stop(); + mDriver->UnInitialize(); + UBSHcomNetDriver::DestroyInstance(mDriver->Name()); + mDriver = nullptr; + } +} + +bool TransportHelper::FillNetDriverOption(ock::hcom::UBSHcomNetDriverOptions &opts) +{ + opts.mode = UBSHcomNetDriverWorkingMode::NET_BUSY_POLLING; + opts.mrSendReceiveSegSize = MAX_MESSAGE_SIZE + HCOM_HEADER_SIZE; + opts.mrSendReceiveSegCount = NN_NO2048; + opts.pollingBatchSize = 16; + PERF_TEST_TYPE type = mCfg.GetType(); + if (type == PERF_TEST_TYPE::TRANSPORT_SEND_BW) { + // 为TRANSPORT_SEND_BW模式时,留一个oneSide wr给心跳,剩余都用于send wr + opts.qpSendQueueSize = 1024; + opts.qpReceiveQueueSize = 1024; + opts.prePostReceiveSizePerQP = 1023; + } + opts.SetNetDeviceIpMask(mCfg.GetIpMask()); + opts.SetWorkerGroups("1"); + opts.enableTls = 0; + if (mCfg.GetCpuId() != -1) { + std::string str = std::to_string(mCfg.GetCpuId()) + "-" + std::to_string(mCfg.GetCpuId()); + opts.SetWorkerGroupsCpuSet(str); + } + return true; +} + +bool TransportHelper::CreateMemoryRegion(MrInfo &mrInfo) +{ + if (mDriver == nullptr) { + return false; + } + + UBSHcomNetMemoryRegionPtr mr; + // 按照最大包大小申请内存,以支持同时测试多个不同大小的包 + auto result = mDriver->CreateMemoryRegion(MAX_MESSAGE_SIZE, mr); + if (result != 0) { + LOG_ERROR("Create memory region failed"); + return false; + } + + mrInfo.lAddress = mr->GetAddress(); + mrInfo.lKey = mr->GetLKey(); + mrInfo.size = MAX_MESSAGE_SIZE; + mMrVector.emplace_back(mr); + LOG_DEBUG("register addr: " << mrInfo.lAddress << ", lKey = " << mrInfo.lKey << ", size = " << MAX_MESSAGE_SIZE); + return true; +} +} +} \ No newline at end of file diff --git a/test/tools/perf_test/test_case/transport/transport_helper.h b/test/tools/perf_test/test_case/transport/transport_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..35cf2abddd3fcf5699f9d8551206eeac10a257b2 --- /dev/null +++ b/test/tools/perf_test/test_case/transport/transport_helper.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_PERF_TEST_TRANSPORT_HELPER_H +#define HCOM_PERF_TEST_TRANSPORT_HELPER_H + +#include "hcom/hcom.h" +#include "common/perf_test_common.h" +#include "common/perf_test_config.h" + +namespace hcom { +namespace perftest { +using NewEpHandler = + std::function; +using EpBrokenHandler = std::function; +using ReqRecvHandler = std::function; +using OneSideDoneHandler = std::function; +using ReqPostedHandler = std::function; + +class TransportHelper { +public: + TransportHelper(const PerfTestConfig &cfg); + bool FillNetDriverOption(ock::hcom::UBSHcomNetDriverOptions &opts); + bool CreateMemoryRegion(MrInfo &mrInfo); + + bool CreateNetDriver(); + void DestroyNetDriver(); + inline ock::hcom::UBSHcomNetDriver *GetNetDriver() const + { + return mDriver; + } + + void RegisterNewEPHandler(const NewEpHandler &handler); + void RegisterEpBrokenHandler(const EpBrokenHandler &handler); + void RegisterReqRecvHandler(const ReqRecvHandler &handler); + void RegisterOneSideDoneHandler(const OneSideDoneHandler &handler); + void RegisterReqPostedHandler(const ReqPostedHandler &handler); + +private: + PerfTestConfig mCfg; + ock::hcom::UBSHcomNetDriver *mDriver = nullptr; + std::vector mMrVector; + + NewEpHandler mNewEpHandler; + EpBrokenHandler mEpBrokenHandler; + ReqRecvHandler mReqRecvHandler; + OneSideDoneHandler mOneSideDoneHandler; + ReqPostedHandler mReqPostedHandler; +}; +} +} + +#endif \ No newline at end of file diff --git a/test/tools/perf_test/test_case/transport/transport_read_bw_test.cpp b/test/tools/perf_test/test_case/transport/transport_read_bw_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..abc550b1e8335830884f07b23099791ce1267062 --- /dev/null +++ b/test/tools/perf_test/test_case/transport/transport_read_bw_test.cpp @@ -0,0 +1,191 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include +#include "securec.h" + +#include "common/perf_test_logger.h" +#include "test_case/perf_test_factory.h" +#include "transport_read_bw_test.h" + +namespace hcom { +namespace perftest { +using namespace ock::hcom; +constexpr uint16_t OP_CODE_READ_BW = 4; + +int TransportReadBwTest::DoPostRead() +{ + PerfTestContext *ctx = GetPerfTestContext(); + // ctx->tposted[i+1] - ctx->tposted[i] 为一次read完成时间 + ock::hcom::UBSHcomNetTransRequest req; + req.lAddress = mPostMrInfo.lAddress; + req.rAddress = mPeerMrInfo.lAddress; + req.lKey = mPostMrInfo.lKey; + req.rKey = mPeerMrInfo.lKey; + req.size = ctx->mSize; + ctx->tposted[0] = MONOTONIC_TIME_NS(); + for (uint64_t i = 0; i < ctx->mIterations; ++i) { + int res = mEp->PostRead(req); + if (res != 0) { + LOG_ERROR("failed to send to server"); + } + } + return 0; +} + +int TransportReadBwTest::NewEndPoint(const std::string &ipPort, const ock::hcom::UBSHcomNetEndpointPtr &ep, + const std::string &payload) +{ + mEp = ep; + LOG_DEBUG("new connection from " << ipPort << " !"); + return 0; +} + +int TransportReadBwTest::RequestReceived(const ock::hcom::UBSHcomNetRequestContext &ctx) +{ + int result = 0; + if (!mCfg.GetIsServer()) { + // client + if (memcpy_s(&mPeerMrInfo, sizeof(mPeerMrInfo), ctx.Message()->Data(), ctx.Message()->DataLen()) != 0) { + LOG_ERROR("memcpy_s failed"); + return -1; + } + sem_post(&mSem); + return result; + } + // server + UBSHcomNetTransRequest rsp((void *)(&mPostMrInfo), sizeof(mPostMrInfo), 0); + if ((result = mEp->PostSend(OP_CODE_READ_BW, rsp)) != 0) { + LOG_ERROR("Failed to post message to data to server, result " << result); + } + return result; +} + +int TransportReadBwTest::OneSideDone(const ock::hcom::UBSHcomNetRequestContext &ctx) +{ + PerfTestContext *testCtx = GetPerfTestContext(); + testCtx->totrcnt++; + if (testCtx->totrcnt == testCtx->mIterations) { + testCtx->tposted[testCtx->mIterations] = MONOTONIC_TIME_NS(); + sem_post(&mSem); + } + return 0; +} + +bool TransportReadBwTest::Initialize() +{ + sem_init(&mSem, 0, 0); + + // create UBSHcomNetDriver + NewEpHandler funcNewEndpoint = bind(&TransportReadBwTest::NewEndPoint, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3); + ReqRecvHandler funcReqReceived = bind(&TransportReadBwTest::RequestReceived, this, std::placeholders::_1); + OneSideDoneHandler funcOneSide = bind(&TransportReadBwTest::OneSideDone, this, std::placeholders::_1); + mHelper.RegisterNewEPHandler(funcNewEndpoint); + mHelper.RegisterReqRecvHandler(funcReqReceived); + mHelper.RegisterOneSideDoneHandler(funcOneSide); + if (!mHelper.CreateNetDriver()) { + goto ERROR_HANDLE; + } + + // init data buffer + mDataAddr = new char[MAX_MESSAGE_SIZE]; + if (mDataAddr == nullptr) { + LOG_ERROR("create data buffer failed"); + goto ERROR_HANDLE; + } + + if (!RegMemory()) { + LOG_ERROR("register memory failed"); + goto ERROR_HANDLE; + } + // client connect to server + if (!mCfg.GetIsServer()) { + if (!Connect()) { + LOG_ERROR("client connect failed"); + goto ERROR_HANDLE; + } + if (!ExchangeAddress()) { + LOG_ERROR("client exchange address failed"); + goto ERROR_HANDLE; + } + } + return true; + +ERROR_HANDLE: + mHelper.DestroyNetDriver(); + sem_destroy(&mSem); + return false; +} + +bool TransportReadBwTest::RegMemory() +{ + if (!mHelper.CreateMemoryRegion(mPostMrInfo)) { + LOG_ERROR("server create memoryRegion failed"); + return false; + } + return true; +} + +bool TransportReadBwTest::ExchangeAddress() +{ + if (mEp == nullptr) { + LOG_ERROR("Exchange address failed, ep is nullptr!"); + return false; + } + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + if (mEp->PostSend(OP_CODE_READ_BW, req) != 0) { + LOG_ERROR("Failed to exchange address to data to server"); + return false; + } + sem_wait(&mSem); + return true; +} + +void TransportReadBwTest::UnInitialize() +{ + if (mEp != nullptr) { + mEp->Close(); + mEp.Set(nullptr); + } + + mHelper.DestroyNetDriver(); + + if (mDataAddr != nullptr) { + delete[] mDataAddr; + mDataAddr = nullptr; + } + sem_destroy(&mSem); +} + +bool TransportReadBwTest::Connect() +{ + auto driver = mHelper.GetNetDriver(); + if (driver == nullptr) { + LOG_ERROR("connect failed, net driver is nullptr!"); + return false; + } + int res = driver->Connect(mCfg.GetOobIp(), mCfg.GetOobPort(), "xx", mEp, 0); + if (res != 0) { + LOG_ERROR("connect failed, error code: " << res); + return false; + } + return true; +} + +bool TransportReadBwTest::RunTest(PerfTestContext *ctx) +{ + // ctx会记录测试中每个Iteration耗时,故每次使用不同的ctx + SetPerfTestContext(ctx); + if (!mCfg.GetIsServer()) { + DoPostRead(); + } + // 等待测试结束 + sem_wait(&mSem); + return true; +} + +REGIST_PERF_TEST_CREATOR(PERF_TEST_TYPE::TRANSPORT_READ_BW, TransportReadBwTest); +} +} \ No newline at end of file diff --git a/test/tools/perf_test/test_case/transport/transport_read_bw_test.h b/test/tools/perf_test/test_case/transport/transport_read_bw_test.h new file mode 100644 index 0000000000000000000000000000000000000000..5027aa220ba5cb20397e53d5e47f1d0f0fb6d5cd --- /dev/null +++ b/test/tools/perf_test/test_case/transport/transport_read_bw_test.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ + +#ifndef HCOM_TRANSPORT_READ_BW_TEST_H +#define HCOM_TRANSPORT_READ_BW_TEST_H + +#include + +#include "hcom/hcom.h" +#include "test_case/perf_test_base.h" +#include "test_case/transport/transport_helper.h" + +namespace hcom { +namespace perftest { +class TransportReadBwTest : public PerfTestBase { +public: + explicit TransportReadBwTest(const PerfTestConfig &cfg) : PerfTestBase(cfg), mHelper(cfg){}; + bool Initialize() override; + void UnInitialize() override; + bool RunTest(PerfTestContext *ctx) override; + +private: + bool Connect(); + int DoPostRead(); + int NewEndPoint(const std::string &ipPort, const ock::hcom::UBSHcomNetEndpointPtr &ep, const std::string &payload); + int RequestReceived(const ock::hcom::UBSHcomNetRequestContext &ctx); + int RequestPosted(const ock::hcom::UBSHcomNetRequestContext &ctx); + int OneSideDone(const ock::hcom::UBSHcomNetRequestContext &ctx); + bool RegMemory(); + bool ExchangeAddress(); + +private: + bool SetPerfTestContext(PerfTestContext *ctx) + { + if (ctx == nullptr) { + return false; + } + mCtx = ctx; + return true; + } + + PerfTestContext *GetPerfTestContext() const + { + return mCtx; + } + PerfTestContext *mCtx = nullptr; + +private: + ock::hcom::UBSHcomNetEndpointPtr mEp = nullptr; + TransportHelper mHelper; + MrInfo mPostMrInfo; + MrInfo mPeerMrInfo; + char *mDataAddr = nullptr; + sem_t mSem; +}; +} +} + +#endif // HCOM_TRANSPORT_READ_BW_TEST_H diff --git a/test/tools/perf_test/test_case/transport/transport_read_lat_test.cpp b/test/tools/perf_test/test_case/transport/transport_read_lat_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c61fe229bafd9d565d706d33debae08093ae4dfc --- /dev/null +++ b/test/tools/perf_test/test_case/transport/transport_read_lat_test.cpp @@ -0,0 +1,168 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include +#include "securec.h" + +#include "common/perf_test_logger.h" +#include "test_case/perf_test_factory.h" +#include "transport_read_lat_test.h" + +namespace hcom { +namespace perftest { +using namespace ock::hcom; +constexpr uint16_t OP_CODE_READ_LAT = 2; +int TransportReadLatTest::NewEndPoint(const std::string &ipPort, const ock::hcom::UBSHcomNetEndpointPtr &ep, + const std::string &payload) +{ + mEp = ep; + LOG_DEBUG("new connection from " << ipPort << " !"); + return 0; +} + +int TransportReadLatTest::RequestReceived(const ock::hcom::UBSHcomNetRequestContext &ctx) +{ + int result = 0; + if (!mCfg.GetIsServer()) { + // client + if (memcpy_s(&serverMrInfo, sizeof(serverMrInfo), ctx.Message()->Data(), ctx.Message()->DataLen()) != 0) { + LOG_ERROR("memcpy_s failed"); + return -1; + } + sem_post(&mSem); + return result; + } + // server + UBSHcomNetTransRequest rsp((void *)(&serverMrInfo), sizeof(serverMrInfo), 0); + if ((result = mEp->PostSend(OP_CODE_READ_LAT, rsp)) != 0) { + LOG_ERROR("Failed to post message to data to server, result " << result); + } + return result; +} + +int TransportReadLatTest::OneSideDone(const ock::hcom::UBSHcomNetRequestContext &ctx) +{ + return DoPostRead(); +} + +bool TransportReadLatTest::Initialize() +{ + sem_init(&mSem, 0, 0); + + // create UBSHcomNetDriver + NewEpHandler funcNewEndpoint = bind(&TransportReadLatTest::NewEndPoint, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3); + ReqRecvHandler funcReqReceived = bind(&TransportReadLatTest::RequestReceived, this, std::placeholders::_1); + OneSideDoneHandler funcOneSide = bind(&TransportReadLatTest::OneSideDone, this, std::placeholders::_1); + mHelper.RegisterNewEPHandler(funcNewEndpoint); + mHelper.RegisterReqRecvHandler(funcReqReceived); + mHelper.RegisterOneSideDoneHandler(funcOneSide); + if (!mHelper.CreateNetDriver()) { + goto ERROR_HANDLE; + } + + if (!RegMemory()) { + LOG_ERROR("register memory failed"); + goto ERROR_HANDLE; + } + + // client connect to server + if (!mCfg.GetIsServer()) { + if (!Connect()) { + LOG_ERROR("client connect failed"); + goto ERROR_HANDLE; + } + if (!ExchangeAddress()) { + LOG_ERROR("client exchange address failed"); + goto ERROR_HANDLE; + } + } + + return true; + +ERROR_HANDLE: + mHelper.DestroyNetDriver(); + sem_destroy(&mSem); + return false; +} + +bool TransportReadLatTest::RegMemory() +{ + if (!mCfg.GetIsServer()) { + // client + if (!mHelper.CreateMemoryRegion(clientMrInfo)) { + LOG_ERROR("client create memoryRegion failed"); + return false; + } + } else { + // server + if (!mHelper.CreateMemoryRegion(serverMrInfo)) { + LOG_ERROR("server create memoryRegion failed"); + return false; + } + } + return true; +} + +bool TransportReadLatTest::ExchangeAddress() +{ + if (mEp == nullptr) { + LOG_ERROR("Exchange address failed, ep is nullptr!"); + return false; + } + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + if (mEp->PostSend(OP_CODE_READ_LAT, req) != 0) { + LOG_ERROR("Failed to exchange address to data to server"); + return false; + } + sem_wait(&mSem); + return true; +} + +void TransportReadLatTest::UnInitialize() +{ + if (mEp != nullptr) { + mEp->Close(); + mEp.Set(nullptr); + } + + mHelper.DestroyNetDriver(); + sem_destroy(&mSem); +} + +bool TransportReadLatTest::Connect() +{ + auto driver = mHelper.GetNetDriver(); + if (driver == nullptr) { + LOG_ERROR("connect failed, net driver is nullptr!"); + return false; + } + int res = driver->Connect(mCfg.GetOobIp(), mCfg.GetOobPort(), "read_test", mEp, 0); + if (res != 0) { + LOG_ERROR("connect failed, error code: " << res); + return false; + } + return true; +} + +bool TransportReadLatTest::RunTest(PerfTestContext *ctx) +{ + // ctx会记录测试中每个Iteration耗时,故每次使用不同的ctx + SetPerfTestContext(ctx); + mReq.lAddress = clientMrInfo.lAddress; + mReq.rAddress = serverMrInfo.lAddress; + mReq.lKey = clientMrInfo.lKey; + mReq.rKey = serverMrInfo.lKey; + mReq.size = ctx->mSize; + if (!mCfg.GetIsServer()) { + DoPostRead(); + } + // 等待测试结束 + sem_wait(&mSem); + return true; +} + +REGIST_PERF_TEST_CREATOR(PERF_TEST_TYPE::TRANSPORT_READ_LAT, TransportReadLatTest); +} +} diff --git a/test/tools/perf_test/test_case/transport/transport_read_lat_test.h b/test/tools/perf_test/test_case/transport/transport_read_lat_test.h new file mode 100644 index 0000000000000000000000000000000000000000..e4930e836c65feaea4dcd5916589eb6f1c8e3596 --- /dev/null +++ b/test/tools/perf_test/test_case/transport/transport_read_lat_test.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ + +#ifndef HCOM_TRANSPORT_READ_LAT_TEST_H +#define HCOM_TRANSPORT_READ_LAT_TEST_H + +#include + +#include "hcom/hcom.h" +#include "test_case/perf_test_base.h" +#include "test_case/transport/transport_helper.h" + + +namespace hcom { +namespace perftest { +class TransportReadLatTest : public PerfTestBase { +public: + explicit TransportReadLatTest(const PerfTestConfig &cfg) : PerfTestBase(cfg), mHelper(cfg){}; + bool Initialize() override; + void UnInitialize() override; + bool RunTest(PerfTestContext *ctx) override; + +private: + bool Connect(); + + inline int DoPostRead() + { + if (mCtx->cnt < mCtx->mIterations) { + mCtx->tposted[mCtx->cnt] = ock::hcom::MONOTONIC_TIME_NS(); + int res = mEp->PostRead(mReq); + if (res != 0) { + LOG_ERROR("failed to send to server"); + } + ++mCtx->cnt; + return 0; + } + + mCtx->tposted[mCtx->cnt] = ock::hcom::MONOTONIC_TIME_NS(); + LOG_DEBUG("One Iteration Done!"); + sem_post(&mSem); + return 0; + } + + int NewEndPoint(const std::string &ipPort, const ock::hcom::UBSHcomNetEndpointPtr &ep, const std::string &payload); + int RequestReceived(const ock::hcom::UBSHcomNetRequestContext &ctx); + int OneSideDone(const ock::hcom::UBSHcomNetRequestContext &ctx); + bool RegMemory(); + bool ExchangeAddress(); + +private: + bool SetPerfTestContext(PerfTestContext *ctx) + { + if (ctx == nullptr) { + return false; + } + mCtx = ctx; + return true; + } + + PerfTestContext *GetPerfTestContext() const + { + return mCtx; + } + PerfTestContext *mCtx = nullptr; + +private: + ock::hcom::UBSHcomNetEndpointPtr mEp = nullptr; + ock::hcom::UBSHcomNetTransRequest mReq; + TransportHelper mHelper; + MrInfo clientMrInfo; + MrInfo serverMrInfo; + sem_t mSem; +}; +} +} + +#endif // HCOM_TRANSPORT_READ_LAT_TEST_H diff --git a/test/tools/perf_test/test_case/transport/transport_send_bw_test.cpp b/test/tools/perf_test/test_case/transport/transport_send_bw_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8ba1b456d6a32612982065d8c1cab6b3e61ea46c --- /dev/null +++ b/test/tools/perf_test/test_case/transport/transport_send_bw_test.cpp @@ -0,0 +1,133 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include +#include +#include "common/perf_test_logger.h" +#include "test_case/perf_test_factory.h" +#include "transport_send_bw_test.h" + +namespace hcom { +namespace perftest { +using namespace ock::hcom; +constexpr uint16_t OP_CODE_SEND_BW = 200; + +int TransportSendBwTest::DoPostSend() +{ + PerfTestContext *ctx = GetPerfTestContext(); + UBSHcomNetTransRequest req(mDataAddr, ctx->mSize, 0); + ctx->tposted[0] = MONOTONIC_TIME_NS(); + for (uint64_t i = 0; i < ctx->mIterations; ++i) { + int res = mEp->PostSend(OP_CODE_SEND_BW, req); + if (res != 0) { + LOG_ERROR("failed to send to server"); + } + } + return 0; +} + +int TransportSendBwTest::NewEndPoint(const std::string &ipPort, const ock::hcom::UBSHcomNetEndpointPtr &ep, + const std::string &payload) +{ + mEp = ep; + LOG_DEBUG("new connection from " << ipPort << " !"); + return 0; +} + +int TransportSendBwTest::RequestPosted(const ock::hcom::UBSHcomNetRequestContext &ctx) +{ + PerfTestContext *testCtx = GetPerfTestContext(); + testCtx->totrcnt++; + if (testCtx->totrcnt == testCtx->mIterations) { + testCtx->tposted[testCtx->mIterations] = MONOTONIC_TIME_NS(); + sem_post(&mSem); + } + return 0; +} + +bool TransportSendBwTest::Initialize() +{ + sem_init(&mSem, 0, 0); + // create UBSHcomNetDriver + NewEpHandler funcNewEndpoint = bind(&TransportSendBwTest::NewEndPoint, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3); + ReqPostedHandler funcReqPosted = bind(&TransportSendBwTest::RequestPosted, this, std::placeholders::_1); + mHelper.RegisterNewEPHandler(funcNewEndpoint); + mHelper.RegisterReqPostedHandler(funcReqPosted); + if (!mHelper.CreateNetDriver()) { + goto ERROR_HANDLE; + } + + // init data buffer + mDataAddr = new char[MAX_MESSAGE_SIZE]; + if (mDataAddr == nullptr) { + LOG_ERROR("create data buffer failed"); + goto ERROR_HANDLE; + } + + // client connect to server + if (!mCfg.GetIsServer()) { + if (!Connect()) { + LOG_ERROR("client connect failed"); + goto ERROR_HANDLE; + } + } + + return true; + +ERROR_HANDLE: + mHelper.DestroyNetDriver(); + if (mDataAddr != nullptr) { + delete[] mDataAddr; + mDataAddr = nullptr; + } + sem_destroy(&mSem); + return false; +} + +void TransportSendBwTest::UnInitialize() +{ + if (mEp != nullptr) { + mEp->Close(); + mEp.Set(nullptr); + } + + mHelper.DestroyNetDriver(); + + if (mDataAddr != nullptr) { + delete[] mDataAddr; + mDataAddr = nullptr; + } + sem_destroy(&mSem); +} + +bool TransportSendBwTest::Connect() +{ + auto driver = mHelper.GetNetDriver(); + if (driver == nullptr) { + LOG_ERROR("connect failed, net driver is nullptr!"); + return false; + } + int res = driver->Connect(mCfg.GetOobIp(), mCfg.GetOobPort(), "xx", mEp, 0); + if (res != 0) { + LOG_ERROR("connect failed, error code: " << res); + return false; + } + return true; +} + +bool TransportSendBwTest::RunTest(PerfTestContext *ctx) +{ + // ctx会记录测试中每个Iteration耗时,故每次使用不同的ctx + SetPerfTestContext(ctx); + if (!mCfg.GetIsServer()) { + DoPostSend(); + } + // 等待测试结束 + sem_wait(&mSem); + return true; +} + +REGIST_PERF_TEST_CREATOR(PERF_TEST_TYPE::TRANSPORT_SEND_BW, TransportSendBwTest); +} +} diff --git a/test/tools/perf_test/test_case/transport/transport_send_bw_test.h b/test/tools/perf_test/test_case/transport/transport_send_bw_test.h new file mode 100644 index 0000000000000000000000000000000000000000..6e9a7011e7e3084f5b972e843eadb75214bb9b49 --- /dev/null +++ b/test/tools/perf_test/test_case/transport/transport_send_bw_test.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_TRANSPORT_SEND_BW_TEST_H +#define HCOM_TRANSPORT_SEND_BW_TEST_H + +#include +#include "hcom/hcom.h" +#include "test_case/perf_test_base.h" +#include "test_case/transport/transport_helper.h" + +namespace hcom { +namespace perftest { +class TransportSendBwTest : public PerfTestBase { +public: + explicit TransportSendBwTest(const PerfTestConfig &cfg) : PerfTestBase(cfg), mHelper(cfg){}; + bool Initialize() override; + void UnInitialize() override; + bool RunTest(PerfTestContext *ctx) override; + +private: + bool Connect(); + int DoPostSend(); + int NewEndPoint(const std::string &ipPort, const ock::hcom::UBSHcomNetEndpointPtr &ep, const std::string &payload); + int RequestPosted(const ock::hcom::UBSHcomNetRequestContext &ctx); + +private: + bool SetPerfTestContext(PerfTestContext *ctx) + { + if (ctx == nullptr) { + return false; + } + mCtx = ctx; + return true; + } + + PerfTestContext *GetPerfTestContext() const + { + return mCtx; + } + PerfTestContext *mCtx = nullptr; + +private: + ock::hcom::UBSHcomNetEndpointPtr mEp = nullptr; + TransportHelper mHelper; + + char *mDataAddr = nullptr; + sem_t mSem; +}; +} +} +#endif // HCOM_TRANSPORT_SEND_BW_TEST_H diff --git a/test/tools/perf_test/test_case/transport/transport_send_lat_test.cpp b/test/tools/perf_test/test_case/transport/transport_send_lat_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b2d819a54125595d0ef0f1e4883f1d4bd737d7f --- /dev/null +++ b/test/tools/perf_test/test_case/transport/transport_send_lat_test.cpp @@ -0,0 +1,150 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include + +#include "common/perf_test_logger.h" +#include "test_case/perf_test_factory.h" +#include "test_case/transport/transport_send_lat_test.h" + + +namespace hcom { +namespace perftest { +using namespace ock::hcom; +constexpr uint16_t OP_CODE_SEND_LAT = 200; + +int TransportSendLatTest::DoPostSend() +{ + PerfTestContext *ctx = GetPerfTestContext(); + // ctx->tposted[i+1] - ctx->tposted[i] 为一次RTT(Round-Trip Time,往返时间) + if (ctx->cnt < ctx->mIterations) { + ctx->tposted[ctx->cnt] = MONOTONIC_TIME_NS(); + UBSHcomNetTransRequest req(mDataAddr, ctx->mSize, 0); + int res = mEp->PostSend(OP_CODE_SEND_LAT, req); + if (res != 0) { + LOG_ERROR("failed to send to server"); + } + ++ctx->cnt; + return 0; + } + + if (ctx->cnt == ctx->mIterations) { + ctx->tposted[ctx->cnt] = MONOTONIC_TIME_NS(); + LOG_DEBUG("One Iteration Done!"); + sem_post(&mSem); + } + return 0; +} + +int TransportSendLatTest::NewEndPoint(const std::string &ipPort, const ock::hcom::UBSHcomNetEndpointPtr &ep, + const std::string &payload) +{ + mEp = ep; + LOG_DEBUG("new connection from " << ipPort << " !"); + return 0; +} + +int TransportSendLatTest::RequestReceived(const ock::hcom::UBSHcomNetRequestContext &ctx) +{ + if (ctx.Header().opCode == OP_CODE_SEND_LAT) { + if (mCfg.GetIsServer()) { + // server 直接回复相同大小的消息即可 + UBSHcomNetTransRequest req(mDataAddr, ctx.Header().dataLength, 0); + int res = mEp->PostSend(OP_CODE_SEND_LAT, req); + } else { + DoPostSend(); + } + return 0; + } + + LOG_ERROR("receive unexpected opcode(=" << ctx.Header().opCode << ")."); + return -1; +} + +bool TransportSendLatTest::Initialize() +{ + sem_init(&mSem, 0, 0); + + // create UBSHcomNetDriver + NewEpHandler funcNewEndpoint = bind(&TransportSendLatTest::NewEndPoint, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3); + ReqRecvHandler funcReqReceived = bind(&TransportSendLatTest::RequestReceived, this, std::placeholders::_1); + mHelper.RegisterNewEPHandler(funcNewEndpoint); + mHelper.RegisterReqRecvHandler(funcReqReceived); + if (!mHelper.CreateNetDriver()) { + goto ERROR_HANDLE; + } + + // init data buffer + mDataAddr = new char[MAX_MESSAGE_SIZE]; + if (mDataAddr == nullptr) { + LOG_ERROR("create data buffer failed"); + goto ERROR_HANDLE; + } + + // client connect to server + if (!mCfg.GetIsServer()) { + if (!Connect()) { + LOG_ERROR("client connect failed"); + goto ERROR_HANDLE; + } + } + + return true; + +ERROR_HANDLE: + mHelper.DestroyNetDriver(); + if (mDataAddr != nullptr) { + delete[] mDataAddr; + mDataAddr = nullptr; + } + sem_destroy(&mSem); + return false; +} + +void TransportSendLatTest::UnInitialize() +{ + if (mEp != nullptr) { + mEp->Close(); + mEp.Set(nullptr); + } + + mHelper.DestroyNetDriver(); + + if (mDataAddr != nullptr) { + delete[] mDataAddr; + mDataAddr = nullptr; + } + sem_destroy(&mSem); +} + +bool TransportSendLatTest::Connect() +{ + auto driver = mHelper.GetNetDriver(); + if (driver == nullptr) { + LOG_ERROR("connect failed, net driver is nullptr!"); + return false; + } + int res = driver->Connect(mCfg.GetOobIp(), mCfg.GetOobPort(), "xx", mEp, 0); + if (res != 0) { + LOG_ERROR("connect failed, error code: " << res); + return false; + } + return true; +} + +bool TransportSendLatTest::RunTest(PerfTestContext *ctx) +{ + // ctx会记录测试中每个Iteration耗时,故每次使用不同的ctx + SetPerfTestContext(ctx); + if (!mCfg.GetIsServer()) { + DoPostSend(); + } + // 等待测试结束 + sem_wait(&mSem); + return true; +} + +REGIST_PERF_TEST_CREATOR(PERF_TEST_TYPE::TRANSPORT_SEND_LAT, TransportSendLatTest); +} +} diff --git a/test/tools/perf_test/test_case/transport/transport_send_lat_test.h b/test/tools/perf_test/test_case/transport/transport_send_lat_test.h new file mode 100644 index 0000000000000000000000000000000000000000..f9ba883ff4f44a18334ce1f07f5f6a55b5e0f5af --- /dev/null +++ b/test/tools/perf_test/test_case/transport/transport_send_lat_test.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#ifndef HCOM_PERF_TEST_TRANSPORT_SEND_LAT_H +#define HCOM_PERF_TEST_TRANSPORT_SEND_LAT_H + +#include + +#include "hcom/hcom.h" +#include "test_case/perf_test_base.h" +#include "test_case/transport/transport_helper.h" + + +namespace hcom { +namespace perftest { +class TransportSendLatTest : public PerfTestBase { +public: + explicit TransportSendLatTest(const PerfTestConfig &cfg) : PerfTestBase(cfg), mHelper(cfg){}; + bool Initialize() override; + void UnInitialize() override; + bool RunTest(PerfTestContext *ctx) override; + +private: + bool Connect(); + int DoPostSend(); + int NewEndPoint(const std::string &ipPort, const ock::hcom::UBSHcomNetEndpointPtr &ep, const std::string &payload); + int RequestReceived(const ock::hcom::UBSHcomNetRequestContext &ctx); + int OneSideDone(const ock::hcom::UBSHcomNetRequestContext &ctx); + +private: + bool SetPerfTestContext(PerfTestContext *ctx) + { + if (ctx == nullptr) { + return false; + } + mCtx = ctx; + return true; + } + + PerfTestContext *GetPerfTestContext() const + { + return mCtx; + } + PerfTestContext *mCtx = nullptr; + +private: + ock::hcom::UBSHcomNetEndpointPtr mEp = nullptr; + TransportHelper mHelper; + + char *mDataAddr = nullptr; + sem_t mSem; +}; +} +} + +#endif diff --git a/test/tools/perf_test/test_case/transport/transport_write_bw_test.cpp b/test/tools/perf_test/test_case/transport/transport_write_bw_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4e852dc3d9f81b0cd75566ce33e3f5366bb7615b --- /dev/null +++ b/test/tools/perf_test/test_case/transport/transport_write_bw_test.cpp @@ -0,0 +1,179 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include +#include "securec.h" + +#include "common/perf_test_logger.h" +#include "test_case/perf_test_factory.h" +#include "transport_write_bw_test.h" + +namespace hcom { +namespace perftest { +using namespace ock::hcom; +constexpr uint16_t OP_CODE_WRITE_BW = 5; + +int TransportWriteBwTest::DoPostWrite() +{ + PerfTestContext *ctx = GetPerfTestContext(); + // ctx->tposted[i+1] - ctx->tposted[i] 为一次write完成时间 + ock::hcom::UBSHcomNetTransRequest req; + req.lAddress = mPostMrInfo.lAddress; + req.rAddress = mPeerMrInfo.lAddress; + req.lKey = mPostMrInfo.lKey; + req.rKey = mPeerMrInfo.lKey; + req.size = ctx->mSize; + ctx->tposted[0] = MONOTONIC_TIME_NS(); + for (uint64_t i = 0; i < ctx->mIterations; ++i) { + int res = mEp->PostWrite(req); + if (res != 0) { + LOG_ERROR("failed to send to server"); + } + } + return 0; +} + +int TransportWriteBwTest::NewEndPoint(const std::string &ipPort, const ock::hcom::UBSHcomNetEndpointPtr &ep, + const std::string &payload) +{ + mEp = ep; + LOG_DEBUG("new connection from " << ipPort << " !"); + return 0; +} + +int TransportWriteBwTest::RequestReceived(const ock::hcom::UBSHcomNetRequestContext &ctx) +{ + int result = 0; + if (!mCfg.GetIsServer()) { + // client + if (memcpy_s(&mPeerMrInfo, sizeof(mPeerMrInfo), ctx.Message()->Data(), ctx.Message()->DataLen()) != 0) { + LOG_ERROR("memcpy_s failed"); + return -1; + } + sem_post(&mSem); + return result; + } + // server + UBSHcomNetTransRequest rsp((void *)(&mPostMrInfo), sizeof(mPostMrInfo), 0); + if ((result = mEp->PostSend(OP_CODE_WRITE_BW, rsp)) != 0) { + LOG_ERROR("Failed to post message to data to server, result " << result); + } + return result; +} + +int TransportWriteBwTest::OneSideDone(const ock::hcom::UBSHcomNetRequestContext &ctx) +{ + PerfTestContext *testCtx = GetPerfTestContext(); + testCtx->totrcnt++; + if (testCtx->totrcnt == testCtx->mIterations) { + testCtx->tposted[testCtx->mIterations] = MONOTONIC_TIME_NS(); + sem_post(&mSem); + } + return 0; +} + +bool TransportWriteBwTest::Initialize() +{ + sem_init(&mSem, 0, 0); + + // create UBSHcomNetDriver + NewEpHandler funcNewEndpoint = bind(&TransportWriteBwTest::NewEndPoint, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3); + ReqRecvHandler funcReqReceived = bind(&TransportWriteBwTest::RequestReceived, this, std::placeholders::_1); + OneSideDoneHandler funcOneSide = bind(&TransportWriteBwTest::OneSideDone, this, std::placeholders::_1); + mHelper.RegisterNewEPHandler(funcNewEndpoint); + mHelper.RegisterReqRecvHandler(funcReqReceived); + mHelper.RegisterOneSideDoneHandler(funcOneSide); + if (!mHelper.CreateNetDriver()) { + goto ERROR_HANDLE; + } + + if (!RegMemory()) { + LOG_ERROR("register memory failed"); + goto ERROR_HANDLE; + } + // client connect to server + if (!mCfg.GetIsServer()) { + if (!Connect()) { + LOG_ERROR("client connect failed"); + goto ERROR_HANDLE; + } + if (!ExchangeAddress()) { + LOG_ERROR("client exchange address failed"); + goto ERROR_HANDLE; + } + } + return true; + +ERROR_HANDLE: + mHelper.DestroyNetDriver(); + sem_destroy(&mSem); + return false; +} + +bool TransportWriteBwTest::RegMemory() +{ + if (!mHelper.CreateMemoryRegion(mPostMrInfo)) { + LOG_ERROR("create memoryRegion failed"); + return false; + } + return true; +} + +bool TransportWriteBwTest::ExchangeAddress() +{ + if (mEp == nullptr) { + LOG_ERROR("Exchange address failed, ep is nullptr!"); + return false; + } + std::string value = "hello world"; + UBSHcomNetTransRequest req((void *)(const_cast(value.c_str())), value.length(), 0); + if (mEp->PostSend(OP_CODE_WRITE_BW, req) != 0) { + LOG_ERROR("Failed to exchange address to data to server"); + return false; + } + sem_wait(&mSem); + return true; +} + +void TransportWriteBwTest::UnInitialize() +{ + if (mEp != nullptr) { + mEp->Close(); + mEp.Set(nullptr); + } + + mHelper.DestroyNetDriver(); + sem_destroy(&mSem); +} + +bool TransportWriteBwTest::Connect() +{ + auto driver = mHelper.GetNetDriver(); + if (driver == nullptr) { + LOG_ERROR("connect failed, net driver is nullptr!"); + return false; + } + int res = driver->Connect(mCfg.GetOobIp(), mCfg.GetOobPort(), "xx", mEp, 0); + if (res != 0) { + LOG_ERROR("connect failed, error code: " << res); + return false; + } + return true; +} + +bool TransportWriteBwTest::RunTest(PerfTestContext *ctx) +{ + // ctx会记录测试中每个Iteration耗时,故每次使用不同的ctx + SetPerfTestContext(ctx); + if (!mCfg.GetIsServer()) { + DoPostWrite(); + } + // 等待测试结束 + sem_wait(&mSem); + return true; +} + +REGIST_PERF_TEST_CREATOR(PERF_TEST_TYPE::TRANSPORT_WRITE_BW, TransportWriteBwTest); +} +} \ No newline at end of file diff --git a/test/tools/perf_test/test_case/transport/transport_write_bw_test.h b/test/tools/perf_test/test_case/transport/transport_write_bw_test.h new file mode 100644 index 0000000000000000000000000000000000000000..f8eb9576d7b85e44599587df2e6b3e478d6cc128 --- /dev/null +++ b/test/tools/perf_test/test_case/transport/transport_write_bw_test.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ + +#ifndef HCOM_TRANSPORT_WRITE_BW_TEST_H +#define HCOM_TRANSPORT_WRITE_BW_TEST_H + +#include + +#include "hcom/hcom.h" +#include "test_case/perf_test_base.h" +#include "test_case/transport/transport_helper.h" + +namespace hcom { +namespace perftest { +class TransportWriteBwTest : public PerfTestBase { +public: + explicit TransportWriteBwTest(const PerfTestConfig &cfg) : PerfTestBase(cfg), mHelper(cfg){}; + bool Initialize() override; + void UnInitialize() override; + bool RunTest(PerfTestContext *ctx) override; + +private: + bool Connect(); + int DoPostWrite(); + int NewEndPoint(const std::string &ipPort, const ock::hcom::UBSHcomNetEndpointPtr &ep, const std::string &payload); + int RequestReceived(const ock::hcom::UBSHcomNetRequestContext &ctx); + int RequestPosted(const ock::hcom::UBSHcomNetRequestContext &ctx); + int OneSideDone(const ock::hcom::UBSHcomNetRequestContext &ctx); + bool RegMemory(); + bool ExchangeAddress(); + +private: + bool SetPerfTestContext(PerfTestContext *ctx) + { + if (ctx == nullptr) { + return false; + } + mCtx = ctx; + return true; + } + + PerfTestContext *GetPerfTestContext() const + { + return mCtx; + } + PerfTestContext *mCtx = nullptr; + +private: + ock::hcom::UBSHcomNetEndpointPtr mEp = nullptr; + TransportHelper mHelper; + MrInfo mPostMrInfo; + MrInfo mPeerMrInfo; + sem_t mSem; +}; +} +} + +#endif // HCOM_TRANSPORT_WRITE_BW_TEST_H diff --git a/test/tools/perf_test/test_case/transport/transport_write_lat_test.cpp b/test/tools/perf_test/test_case/transport/transport_write_lat_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..84cfd92ef83872bab0b99e5b654a8a4b670f1ac2 --- /dev/null +++ b/test/tools/perf_test/test_case/transport/transport_write_lat_test.cpp @@ -0,0 +1,178 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +#include +#include "securec.h" + +#include "common/perf_test_logger.h" +#include "test_case/perf_test_factory.h" +#include "transport_write_lat_test.h" + +namespace hcom { +namespace perftest { +using namespace ock::hcom; +constexpr uint16_t OP_CODE_WRITE_LAT = 3; +int TransportWriteLatTest::NewEndPoint(const std::string &ipPort, const ock::hcom::UBSHcomNetEndpointPtr &ep, + const std::string &payload) +{ + mEp = ep; + isConnect.store(true); + LOG_DEBUG("new connection from " << ipPort << " !"); + return 0; +} + +int TransportWriteLatTest::RequestReceived(const ock::hcom::UBSHcomNetRequestContext &ctx) +{ + int result = 0; + if (memcpy_s(&mPeerMrInfo, sizeof(mPeerMrInfo), ctx.Message()->Data(), ctx.Message()->DataLen()) != 0) { + LOG_ERROR("memcpy_s failed"); + return -1; + } + if (!mCfg.GetIsServer()) { + // client + sem_post(&mSem); + return result; + } + // server + UBSHcomNetTransRequest rsp((void *)(&mPollMrInfo), sizeof(mPollMrInfo), 0); + if ((result = mEp->PostSend(OP_CODE_WRITE_LAT, rsp)) != 0) { + LOG_ERROR("Failed to post message to data to server, result " << result); + } + sem_post(&mSem); + return result; +} + +int TransportWriteLatTest::OneSideDone(const ock::hcom::UBSHcomNetRequestContext &ctx) +{ + ccnt.fetch_add(1); + return 0; +} + +int TransportWriteLatTest::EpBroken(const ock::hcom::UBSHcomNetEndpointPtr &ep) +{ + isConnect.store(false); + return 0; +} + +bool TransportWriteLatTest::Initialize() +{ + sem_init(&mSem, 0, 0); + + // create UBSHcomNetDriver + NewEpHandler funcNewEndpoint = bind(&TransportWriteLatTest::NewEndPoint, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3); + ReqRecvHandler funcReqReceived = bind(&TransportWriteLatTest::RequestReceived, this, std::placeholders::_1); + OneSideDoneHandler funcOneSide = bind(&TransportWriteLatTest::OneSideDone, this, std::placeholders::_1); + EpBrokenHandler funcBrokenEp = bind(&TransportWriteLatTest::EpBroken, this, std::placeholders::_1); + + mHelper.RegisterNewEPHandler(funcNewEndpoint); + mHelper.RegisterReqRecvHandler(funcReqReceived); + mHelper.RegisterOneSideDoneHandler(funcOneSide); + mHelper.RegisterEpBrokenHandler(funcBrokenEp); + if (!mHelper.CreateNetDriver()) { + goto ERROR_HANDLE; + } + + if (!RegMemory()) { + LOG_ERROR("register memory failed"); + goto ERROR_HANDLE; + } + + // client connect to server + if (!mCfg.GetIsServer()) { + if (!Connect()) { + LOG_ERROR("client connect failed"); + goto ERROR_HANDLE; + } + if (!ExchangeAddress()) { + LOG_ERROR("client exchange address failed"); + goto ERROR_HANDLE; + } + } + + return true; + +ERROR_HANDLE: + mHelper.DestroyNetDriver(); + sem_destroy(&mSem); + return false; +} + +bool TransportWriteLatTest::RegMemory() +{ + if (!mHelper.CreateMemoryRegion(mPostMrInfo)) { + LOG_ERROR("Create send memoryRegion failed"); + return false; + } + if (!mHelper.CreateMemoryRegion(mPollMrInfo)) { + LOG_ERROR("Create receive memoryRegion failed"); + return false; + } + return true; +} + +bool TransportWriteLatTest::ExchangeAddress() +{ + if (mEp == nullptr) { + LOG_ERROR("Exchange address failed, ep is nullptr!"); + return false; + } + UBSHcomNetTransRequest req((void *)(&mPollMrInfo), sizeof(mPollMrInfo), 0); + if (mEp->PostSend(OP_CODE_WRITE_LAT, req) != 0) { + LOG_ERROR("Failed to exchange address to data to server"); + return false; + } + sem_wait(&mSem); + return true; +} + +void TransportWriteLatTest::UnInitialize() +{ + if (mEp != nullptr) { + mEp->Close(); + mEp.Set(nullptr); + } + + mHelper.DestroyNetDriver(); + sem_destroy(&mSem); +} + +bool TransportWriteLatTest::Connect() +{ + auto driver = mHelper.GetNetDriver(); + if (driver == nullptr) { + LOG_ERROR("connect failed, net driver is nullptr!"); + return false; + } + int res = driver->Connect(mCfg.GetOobIp(), mCfg.GetOobPort(), "write_test", mEp, 0); + if (res != 0) { + LOG_ERROR("connect failed, error code: " << res); + return false; + } + return true; +} + +bool TransportWriteLatTest::RunTest(PerfTestContext *ctx) +{ + // ctx会记录测试中每个Iteration耗时,故每次使用不同的ctx + SetPerfTestContext(ctx); + if (mCfg.GetIsServer() && !isConnect.load()) { + // server等到地址交换结束 + sem_wait(&mSem); + } + + mReq.lAddress = mPostMrInfo.lAddress; + mReq.rAddress = mPeerMrInfo.lAddress; + mReq.lKey = mPostMrInfo.lKey; + mReq.rKey = mPeerMrInfo.lKey; + mReq.size = ctx->mSize; + + DoPostWrite(); + // 等待测试结束 + sem_wait(&mSem); + return true; +} + +REGIST_PERF_TEST_CREATOR(PERF_TEST_TYPE::TRANSPORT_WRITE_LAT, TransportWriteLatTest); +} +} diff --git a/test/tools/perf_test/test_case/transport/transport_write_lat_test.h b/test/tools/perf_test/test_case/transport/transport_write_lat_test.h new file mode 100644 index 0000000000000000000000000000000000000000..078bac5f5f805ff64183f6d77906772252481ed6 --- /dev/null +++ b/test/tools/perf_test/test_case/transport/transport_write_lat_test.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ + +#ifndef HCOM_TRANSPORT_READ_LAT_TEST_H +#define HCOM_TRANSPORT_READ_LAT_TEST_H + +#include +#include +#include "hcom/hcom.h" +#include "test_case/perf_test_base.h" +#include "test_case/transport/transport_helper.h" + + +namespace hcom { +namespace perftest { +class TransportWriteLatTest : public PerfTestBase { +public: + explicit TransportWriteLatTest(const PerfTestConfig &cfg) : PerfTestBase(cfg), mHelper(cfg){}; + bool Initialize() override; + void UnInitialize() override; + bool RunTest(PerfTestContext *ctx) override; + +private: + bool Connect(); + + inline int DoPostWrite() + { + volatile uint64_t *pollData = reinterpret_cast(mPollMrInfo.lAddress); + volatile uint64_t *postData = reinterpret_cast(mPostMrInfo.lAddress); + uint64_t num = 0; + *pollData = num; + *postData = num; + PerfTestContext *ctx = GetPerfTestContext(); + ctx->cnt = 0; + rcnt = 0; + ccnt.store(0); + while (ctx->cnt < ctx->mIterations || rcnt < ctx->mIterations || ccnt.load() < ctx->mIterations) { + if (rcnt < ctx->mIterations && !(ctx->cnt < 1 && !mCfg.GetIsServer())) { + rcnt++; + while ((*pollData != rcnt) && ctx->cnt < ctx->mIterations) { + } + } + if (ctx->cnt < ctx->mIterations) { + ++ctx->cnt; + ctx->tposted[ctx->cnt - 1] = ock::hcom::MONOTONIC_TIME_NS(); + *postData = ctx->cnt; + int res = mEp->PostWrite(mReq); + if (res != 0) { + LOG_ERROR("failed to send to server"); + return -1; + } + } + while (ccnt.load() != ctx->cnt) { + } + } + if (ctx->cnt == ctx->mIterations) { + ctx->tposted[ctx->cnt] = ock::hcom::MONOTONIC_TIME_NS(); + LOG_DEBUG("One Iteration Done!"); + sem_post(&mSem); + } + return 0; + } + + int NewEndPoint(const std::string &ipPort, const ock::hcom::UBSHcomNetEndpointPtr &ep, const std::string &payload); + int RequestReceived(const ock::hcom::UBSHcomNetRequestContext &ctx); + int OneSideDone(const ock::hcom::UBSHcomNetRequestContext &ctx); + int EpBroken(const ock::hcom::UBSHcomNetEndpointPtr &ep); + bool RegMemory(); + bool ExchangeAddress(); + +private: + bool SetPerfTestContext(PerfTestContext *ctx) + { + if (ctx == nullptr) { + return false; + } + mCtx = ctx; + return true; + } + + PerfTestContext *GetPerfTestContext() const + { + return mCtx; + } + PerfTestContext *mCtx = nullptr; + +private: + ock::hcom::UBSHcomNetEndpointPtr mEp = nullptr; + TransportHelper mHelper; + ock::hcom::UBSHcomNetTransRequest mReq; + MrInfo mPostMrInfo; + MrInfo mPollMrInfo; + MrInfo mPeerMrInfo; + uint64_t rcnt = 0; + std::atomic ccnt{ 0 }; + std::atomic isConnect{ false }; + sem_t mSem; +}; +} +} + +#endif // HCOM_TRANSPORT_READ_LAT_TEST_H diff --git a/test/unit_test/CMakeLists.txt b/test/unit_test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..8415d3f1dae7b773648312c46d0f3218d793552c --- /dev/null +++ b/test/unit_test/CMakeLists.txt @@ -0,0 +1,99 @@ +# +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +# +cmake_minimum_required(VERSION 3.14) +project(HCOM_UNIT_TEST C CXX) +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 11) + +# define macro that will be used in source code +add_compile_options(-DMOCK_VERBS -DTEST_LLT) + +# collect hcom source files +include_directories(${CMAKE_SOURCE_DIR}/src/) +include_directories(${CMAKE_SOURCE_DIR}/src/api/capi) +include_directories(${CMAKE_SOURCE_DIR}/src/api/java_sdk/jni) +include_directories(${CMAKE_SOURCE_DIR}/src/api/java_sdk/jni/include) +include_directories(${CMAKE_SOURCE_DIR}/src/api/java_sdk/jni/service) +include_directories(${CMAKE_SOURCE_DIR}/src/api/java_sdk/jni/common) +include_directories(${CMAKE_SOURCE_DIR}/src/under_api/verbs) +include_directories(${CMAKE_SOURCE_DIR}/src/service_v2/api/) +include_directories(${CMAKE_SOURCE_DIR}/src/service_v2/) +include_directories(${CMAKE_SOURCE_DIR}/src/api/capi_v2/) +include_directories(${CMAKE_SOURCE_DIR}/src/common/trace) + + +file(GLOB_RECURSE SOURCE_FILES + ${CMAKE_SOURCE_DIR}/src/*.cpp + ${CMAKE_SOURCE_DIR}/src/*.h) + +# remove jni file from hcom source files +file(GLOB_RECURSE HCOM_JNI_SRCS "${HCOM_SRC_DIR}/src/api/java_sdk/*.cpp" + "${HCOM_SRC_DIR}/src/api/java_sdk/*.h") +list(REMOVE_ITEM SOURCE_FILES ${HCOM_JNI_SRCS}) + +# collect hcom unittest files +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_FILES) +aux_source_directory(service_v2/api SOURCE_FILES) +aux_source_directory(service_v2 SOURCE_FILES) +aux_source_directory(capi_v2 SOURCE_FILES) +include_directories(${CMAKE_SOURCE_DIR}/src/api/capi_v2/) +aux_source_directory(common SOURCE_FILES) +aux_source_directory(under_api SOURCE_FILES) +aux_source_directory(transport SOURCE_FILES) +aux_source_directory(transport/rdma SOURCE_FILES) +aux_source_directory(transport/rdma/verbs SOURCE_FILES) +aux_source_directory(transport/shm SOURCE_FILES) +aux_source_directory(transport/sock SOURCE_FILES) +aux_source_directory(transport/ub SOURCE_FILES) +aux_source_directory(transport/common SOURCE_FILES) +aux_source_directory(under_api/openssl SOURCE_FILES) +aux_source_directory(under_api/urma SOURCE_FILES) + +if (BUILD_WITH_HTRACER) +aux_source_directory(common/trace SOURCE_FILES) +endif() + +file(GLOB_RECURSE HCOM_SERVICE_SRCS + "${HCOM_SRC_DIR}/src/service/*" + "${HCOM_SRC_DIR}/src/api/capi/*") + +file(GLOB_RECURSE HCOM_SERVICE_V2_SRCS + "${HCOM_SRC_DIR}/src/service_v2/*" + "${HCOM_SRC_DIR}/src/api/capi_v2/*") + +# include&link gtest +set(GTEST_INSTALL_DIR "${TEST_TOOL_INSTALL_PATH}/googletest") +if (NOT EXISTS ${GTEST_INSTALL_DIR}) + message(ERROR "GTEST_INSTALL_DIR(${GTEST_INSTALL_DIR}) is invalid") +endif() +include_directories(${GTEST_INSTALL_DIR}/include) +link_directories(${GTEST_INSTALL_DIR}/lib64) + +# include&link mockcpp +set(MOCKCPP_INSTALL_DIR "${TEST_TOOL_INSTALL_PATH}/mockcpp") +if (NOT EXISTS ${MOCKCPP_INSTALL_DIR}) + message(ERROR "MOCKCPP_INSTALL_DIR(${MOCKCPP_INSTALL_DIR}) is invalid") +endif() +include_directories(${MOCKCPP_INSTALL_DIR}/include) +link_directories(${MOCKCPP_INSTALL_DIR}/lib) + +# enable gcov +add_compile_options(-ftest-coverage -fprofile-arcs) + +# Ignore access control in class(like private or protected) +add_compile_options(-fno-access-control) + +# enable asan +# add_compile_options(-fsanitize=address) +# add_link_options(-fsanitize=address) + +# build hcom_ut +set(DEPEND_LIBS rt pthread gtest gcov mockcpp dl boundscheck fake_ibv_static) +set(HCOM_UT hcom_ut) +add_executable(${HCOM_UT} ${SOURCE_FILES}) +target_compile_options(${HCOM_UT} PUBLIC -D_GNU_SOURCE) +set_target_properties(${HCOM_UT} PROPERTIES CLEAN_DIRECT_OUTPUT 1) +set_target_properties(${HCOM_UT} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") +target_link_libraries(${HCOM_UT} ${DEPEND_LIBS}) diff --git a/test/unit_test/capi/test_hcom_def_inner_c.cpp b/test/unit_test/capi/test_hcom_def_inner_c.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6758da7f93fed8b8218d1b41c72bda20e899a204 --- /dev/null +++ b/test/unit_test/capi/test_hcom_def_inner_c.cpp @@ -0,0 +1,363 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +#include "hcom_def_inner_c.h" +#include "net_rdma_async_endpoint.h" + +namespace ock { +namespace hcom { +class TestHcomDefInnerC : public testing::Test { +public: + TestHcomDefInnerC(); + virtual void SetUp(void); + virtual void TearDown(void); + EpHdlAdp *testEpHdlAdp = nullptr; + EpOpHdlAdp *testEpOpHdlAdp = nullptr; + OOBSecInfoProviderAdp *testOobSecInfoProviderAdp = nullptr; + OOBSecInfoValidatorAdp *testOobSecInfoValidatorAdp = nullptr; + OOBPskUseSessionAdp *testOobPskUseSessionAdp = nullptr; + OOBPskFindSessionAdp *testOobPskFindSessionAdp = nullptr; + EpIdleHdlAdp *testEpIdleHdlAdp = nullptr; + EpTLSHdlAdp *testEpTLSHdlAdp = nullptr; + ServiceHdlAdp *testServiceHdlAdp = nullptr; + ChannelOpHdlAdp *testChannelOpHdlAdp = nullptr; + ServiceIdleHdlAdp *testServiceIdleHdlAdp = nullptr; +}; + +TestHcomDefInnerC::TestHcomDefInnerC() {} + +void TestHcomDefInnerC::SetUp() +{ + Net_EPHandlerType t = C_EP_NEW; + Net_EPHandler h = nullptr; + Net_RequestHandler handler = nullptr; + uint64_t usrCtx = 0; + Net_SecInfoProvider provider = nullptr; + Net_SecInfoValidator validator = nullptr; + Net_PskUseSessionCb cb = nullptr; + Net_PskFindSessionCb FindCb = nullptr; + Net_IdleHandler idleHandler = nullptr; + ubs_hcom_service_handler_type serviceHandlerType = C_CHANNEL_NEW; + ubs_hcom_service_channel_policy policy = C_CHANNEL_BROKEN_ALL; + ubs_hcom_service_channel_handler channelHandler = nullptr; + ubs_hcom_service_request_handler requestHandler = nullptr; + testEpHdlAdp = new (std::nothrow) EpHdlAdp(t, h, usrCtx); + testEpOpHdlAdp = new (std::nothrow) EpOpHdlAdp(handler, usrCtx); + testOobSecInfoProviderAdp = new (std::nothrow) OOBSecInfoProviderAdp(provider); + testOobSecInfoValidatorAdp = new (std::nothrow) OOBSecInfoValidatorAdp(validator); + testOobPskUseSessionAdp = new (std::nothrow) OOBPskUseSessionAdp(cb); + testOobPskFindSessionAdp = new (std::nothrow) OOBPskFindSessionAdp(FindCb); + testEpIdleHdlAdp = new (std::nothrow) EpIdleHdlAdp(idleHandler, usrCtx); + testEpTLSHdlAdp = new (std::nothrow) EpTLSHdlAdp(); + testServiceHdlAdp = new (std::nothrow) ServiceHdlAdp(serviceHandlerType, policy, channelHandler, usrCtx); + testChannelOpHdlAdp = new (std::nothrow) ChannelOpHdlAdp(requestHandler, usrCtx); + testServiceIdleHdlAdp = new (std::nothrow) ServiceIdleHdlAdp(idleHandler, usrCtx); +} + +void TestHcomDefInnerC::TearDown() +{ + if (testEpHdlAdp != nullptr) { + delete testEpHdlAdp; + testEpHdlAdp = nullptr; + } + + if (testEpOpHdlAdp != nullptr) { + delete testEpOpHdlAdp; + testEpOpHdlAdp = nullptr; + } + + if (testOobSecInfoProviderAdp != nullptr) { + delete testOobSecInfoProviderAdp; + testOobSecInfoProviderAdp = nullptr; + } + + if (testOobSecInfoValidatorAdp != nullptr) { + delete testOobSecInfoValidatorAdp; + testOobSecInfoValidatorAdp = nullptr; + } + + if (testOobPskUseSessionAdp != nullptr) { + delete testOobPskUseSessionAdp; + testOobPskUseSessionAdp = nullptr; + } + + if (testOobPskFindSessionAdp != nullptr) { + delete testOobPskFindSessionAdp; + testOobPskFindSessionAdp = nullptr; + } + + if (testEpIdleHdlAdp != nullptr) { + delete testEpIdleHdlAdp; + testEpIdleHdlAdp = nullptr; + } + + if (testEpTLSHdlAdp != nullptr) { + delete testEpTLSHdlAdp; + testEpTLSHdlAdp = nullptr; + } + + if (testServiceHdlAdp != nullptr) { + delete testServiceHdlAdp; + testServiceHdlAdp = nullptr; + } + + if (testChannelOpHdlAdp != nullptr) { + delete testChannelOpHdlAdp; + testChannelOpHdlAdp = nullptr; + } + + if (testServiceIdleHdlAdp != nullptr) { + delete testServiceIdleHdlAdp; + testServiceIdleHdlAdp = nullptr; + } + + GlobalMockObject::verify(); +} + +TEST_F(TestHcomDefInnerC, EpHdlAdpNewEndPointNullErr) +{ + uint64_t epId = NN_NO8; + UBSHcomNetWorkerIndex workerIndex{}; + uint32_t workerIdx = 5; + uint32_t gIdx = 6; + uint16_t dIdx = 7; + workerIndex.Set(workerIdx, gIdx, dIdx); + UBSHcomNetEndpointPtr newEP = new (std::nothrow) NetAsyncEndpoint(epId, nullptr, nullptr, workerIndex); + std::string ipPort = "1.2.3.4:1234"; + std::string payload = "payload"; + + testEpHdlAdp->mHandler = nullptr; + int ret = testEpHdlAdp->NewEndPoint(ipPort, newEP, payload); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestHcomDefInnerC, EpHdlAdpEndPointBrokenNullErr) +{ + uint64_t epId = NN_NO8; + UBSHcomNetWorkerIndex workerIndex{}; + uint32_t workerIdx = 5; + uint32_t gIdx = 6; + uint16_t dIdx = 7; + workerIndex.Set(workerIdx, gIdx, dIdx); + UBSHcomNetEndpointPtr newEP = new (std::nothrow) NetAsyncEndpoint(epId, nullptr, nullptr, workerIndex); + + testEpHdlAdp->mHandler = nullptr; + EXPECT_NO_FATAL_FAILURE(testEpHdlAdp->EndPointBroken(newEP)); +} + +TEST_F(TestHcomDefInnerC, EpOpHdlAdpRequestedNullErr) +{ + UBSHcomNetRequestContext ctx{}; + testEpOpHdlAdp->mHandler = nullptr; + int ret = testEpOpHdlAdp->Requested(ctx); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestHcomDefInnerC, EpOpHdlAdpRequestedSentMemCpyErr) +{ + UBSHcomNetRequestContext ctx{}; + ctx.mOpType = UBSHcomNetRequestContext::NN_SENT; + testEpOpHdlAdp->mHandler = [](Net_RequestContext *, uint64_t) { return 1; }; + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + int ret = testEpOpHdlAdp->Requested(ctx); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestHcomDefInnerC, EpOpHdlAdpRequestedWrittenMemCpyErr) +{ + UBSHcomNetRequestContext ctx{}; + ctx.mOpType = UBSHcomNetRequestContext::NN_WRITTEN; + testEpOpHdlAdp->mHandler = [](Net_RequestContext *, uint64_t) { return 1; }; + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + int ret = testEpOpHdlAdp->Requested(ctx); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestHcomDefInnerC, EpOpHdlAdpRequestedSglWrittenMemCpyErr) +{ + UBSHcomNetRequestContext ctx{}; + ctx.mOpType = UBSHcomNetRequestContext::NN_SGL_WRITTEN; + testEpOpHdlAdp->mHandler = [](Net_RequestContext *, uint64_t) { return 1; }; + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + int ret = testEpOpHdlAdp->Requested(ctx); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestHcomDefInnerC, OOBSecInfoProviderAdpCreateSecInfoNullErr) +{ + testOobSecInfoProviderAdp->mProvider = nullptr; + UBSHcomNetDriverSecType type = NET_SEC_DISABLED; + int64_t flag = 0; + char *output = nullptr; + uint32_t outLen = 0; + bool needAutoFree = true; + int ret = testOobSecInfoProviderAdp->CreateSecInfo(0, flag, type, output, outLen, needAutoFree); + EXPECT_EQ(ret, -1); +} + +TEST_F(TestHcomDefInnerC, OOBSecInfoProviderAdpCreateSecInfo) +{ + testOobSecInfoProviderAdp->mProvider + = [](uint64_t, int64_t *, Net_DriverSecType *, char **, uint32_t *, int *e) { return 0; }; + UBSHcomNetDriverSecType type = NET_SEC_DISABLED; + int64_t flag = 0; + char *output = nullptr; + uint32_t outLen = 0; + bool needAutoFree = true; + int ret = testOobSecInfoProviderAdp->CreateSecInfo(0, flag, type, output, outLen, needAutoFree); + EXPECT_EQ(ret, 0); + + type = NET_SEC_VALID_ONE_WAY; + ret = testOobSecInfoProviderAdp->CreateSecInfo(0, flag, type, output, outLen, needAutoFree); + EXPECT_EQ(ret, 0); + + type = NET_SEC_VALID_TWO_WAY; + ret = testOobSecInfoProviderAdp->CreateSecInfo(0, flag, type, output, outLen, needAutoFree); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestHcomDefInnerC, OOBSecInfoValidatorAdpSecInfoValidate) +{ + testOobSecInfoValidatorAdp->mValidator = nullptr; + int64_t flag = 0; + char *intput = nullptr; + uint32_t inputLen = 0; + int ret = testOobSecInfoValidatorAdp->SecInfoValidate(0, flag, intput, inputLen); + EXPECT_EQ(ret, -1); + + testOobSecInfoValidatorAdp->mValidator = [](uint64_t, int64_t, const char *, uint32_t) { return 0; }; + ret = testOobSecInfoValidatorAdp->SecInfoValidate(0, flag, intput, inputLen); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestHcomDefInnerC, OOBPskUseSessionAdpUseSession) +{ + testOobPskUseSessionAdp->mCb = nullptr; + void *ssl = nullptr; + const void *md = nullptr; + const unsigned char **id = nullptr; + size_t *idLen = nullptr; + void **session = nullptr; + int ret = testOobPskUseSessionAdp->UseSession(ssl, md, id, idLen, session); + EXPECT_EQ(ret, 0); + + testOobPskUseSessionAdp->mCb = [](void *, const void *, const unsigned char **, size_t *, void **) { return 1; }; + ret = testOobPskUseSessionAdp->UseSession(ssl, md, id, idLen, session); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestHcomDefInnerC, OOBPskFindSessionAdpFindSession) +{ + testOobPskFindSessionAdp->mCb = nullptr; + void *ssl = nullptr; + unsigned char *identity = nullptr; + size_t identityLen = 0; + void **session = nullptr; + int ret = testOobPskFindSessionAdp->FindSession(ssl, identity, identityLen, session); + EXPECT_EQ(ret, 0); + + testOobPskFindSessionAdp->mCb = [](void *, const unsigned char *, size_t, void **) { return 1; }; + ret = testOobPskFindSessionAdp->FindSession(ssl, identity, identityLen, session); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestHcomDefInnerC, EpIdleHdlAdpIdleHandlerNullErr) +{ + testEpIdleHdlAdp->mHandler = nullptr; + UBSHcomNetWorkerIndex index{}; + EXPECT_NO_FATAL_FAILURE(testEpIdleHdlAdp->Idle(index)); +} + +TEST_F(TestHcomDefInnerC, EpTLSHdlAdpTLSCertificationCallbackErr) +{ + testEpTLSHdlAdp->mGetCert = nullptr; + std::string name = "name"; + std::string path = "path"; + bool ret = testEpTLSHdlAdp->UBSHcomTLSCertificationCallback(name, path); + EXPECT_EQ(ret, false); + + testEpTLSHdlAdp->mGetCert = [](const char *, char **) { return 0; }; + ret = testEpTLSHdlAdp->UBSHcomTLSCertificationCallback(name, path); + EXPECT_EQ(ret, false); +} + +TEST_F(TestHcomDefInnerC, EpTLSHdlAdpTLSPrivateKeyCallbackErr) +{ + testEpTLSHdlAdp->mGetPriKey = nullptr; + std::string name = "name"; + std::string path = "path"; + void *keyPass = nullptr; + int len = 0; + UBSHcomTLSEraseKeypass cb; + bool ret = testEpTLSHdlAdp->UBSHcomTLSPrivateKeyCallback(name, path, keyPass, len, cb); + EXPECT_EQ(ret, false); + + testEpTLSHdlAdp->mGetPriKey = [](const char *, char **, char **, Net_TlsKeyPassErase *) { return 0; }; + ret = testEpTLSHdlAdp->UBSHcomTLSPrivateKeyCallback(name, path, keyPass, len, cb); + EXPECT_EQ(ret, false); +} + +TEST_F(TestHcomDefInnerC, EpTLSHdlAdpTLSCaCallbackErr) +{ + testEpTLSHdlAdp->mGetCA = nullptr; + std::string name = "name"; + std::string caPath = "caPath"; + std::string crlPath = "crlPath"; + UBSHcomPeerCertVerifyType peerCertVerifyType = VERIFY_BY_DEFAULT; + UBSHcomTLSCertVerifyCallback cb; + bool ret = testEpTLSHdlAdp->UBSHcomTLSCaCallback(name, caPath, crlPath, peerCertVerifyType, cb); + EXPECT_EQ(ret, false); + + testEpTLSHdlAdp->mGetCA + = [](const char *, char **, char **, Net_PeerCertVerifyType *, Net_TlsCertVerify *) { return 0; }; + ret = testEpTLSHdlAdp->UBSHcomTLSCaCallback(name, caPath, crlPath, peerCertVerifyType, cb); + EXPECT_EQ(ret, false); +} + +TEST_F(TestHcomDefInnerC, ServiceHdlAdpNewChannelErr) +{ + testServiceHdlAdp->mHandler = nullptr; + std::string ipPort = "1.2.3.4:1234"; + NetChannelPtr newCh = nullptr; + std::string payload = "payload"; + int ret = testServiceHdlAdp->NewChannel(ipPort, newCh, payload); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestHcomDefInnerC, ServiceHdlAdpChannelBrokenErr) +{ + testServiceHdlAdp->mHandler = nullptr; + NetChannelPtr ch = nullptr; + EXPECT_NO_FATAL_FAILURE(testServiceHdlAdp->ChannelBroken(ch)); +} + +TEST_F(TestHcomDefInnerC, ChannelOpHdlAdpRequestedErr) +{ + testChannelOpHdlAdp->mHandler = nullptr; + NetServiceContext ctx{}; + int ret = testChannelOpHdlAdp->Requested(ctx); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestHcomDefInnerC, ServiceIdleHdlAdpIdle) +{ + testServiceIdleHdlAdp->mServiceHandler = nullptr; + UBSHcomNetWorkerIndex index{}; + EXPECT_NO_FATAL_FAILURE(testServiceIdleHdlAdp->Idle(index)); + + testServiceIdleHdlAdp->mServiceHandler = [](uint8_t, uint16_t, uint64_t) {}; + EXPECT_NO_FATAL_FAILURE(testServiceIdleHdlAdp->Idle(index)); +} +} +} \ No newline at end of file diff --git a/test/unit_test/capi_v2/test_hcom_c.cpp b/test/unit_test/capi_v2/test_hcom_c.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f31dcf25591495c7d529567d21ac8b819efb6054 --- /dev/null +++ b/test/unit_test/capi_v2/test_hcom_c.cpp @@ -0,0 +1,154 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +#include "hcom_c.h" +#include "hcom_service_c.h" +#include "service_channel_imp.h" +#include "hcom.h" + +namespace ock { +namespace hcom { +class TestHcomCapi : public testing::Test { +public: + TestHcomCapi(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +TestHcomCapi::TestHcomCapi() {} + +void TestHcomCapi::SetUp() +{ +} + +void TestHcomCapi::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestHcomCapi, TestCopySglInfo) +{ + ubs_hcom_readwrite_request_sgl *src + = static_cast(malloc(sizeof(ubs_hcom_readwrite_request_sgl))); + ASSERT_NE(src, nullptr); + bzero(src, sizeof(ubs_hcom_readwrite_request_sgl)); + EXPECT_EQ(ubs_hcom_ep_post_send_raw_sgl(1, src, 1), static_cast(NN_INVALID_PARAM)); + free(src); +} + +TEST_F(TestHcomCapi, TestSendRecvFds) +{ + EXPECT_EQ(ubs_hcom_channel_send_fds(0, nullptr, 0), static_cast(SER_INVALID_PARAM)); + EXPECT_EQ(ubs_hcom_channel_recv_fds(0, nullptr, 0, 0), static_cast(SER_INVALID_PARAM)); + InnerConnectOptions opt {}; + UBSHcomChannel *ch = new (std::nothrow) HcomChannelImp(0, false, opt); + EXPECT_NE(ch, nullptr); + ubs_hcom_channel channel = reinterpret_cast(ch); + EXPECT_EQ(ubs_hcom_channel_send_fds(channel, nullptr, 0), static_cast(SER_ERROR)); + EXPECT_EQ(ubs_hcom_channel_recv_fds(channel, nullptr, 0, 0), static_cast(SER_ERROR)); + channel = 0; + delete ch; +} + +TEST_F(TestHcomCapi, TestConvertServiceConnectOptionsToInnerOptions) +{ + ubs_hcom_service_options opt {}; + ubs_hcom_service service = 0; + int ret = ubs_hcom_service_create(C_SERVICE_RDMA, "service0", opt, &service); + ASSERT_EQ(ret, 0); + ubs_hcom_service_connect_options connectOpt {}; + connectOpt.mode = C_CLIENT_SELF_POLL_BUSY; + ubs_hcom_channel channel = 0; + EXPECT_NE(ubs_hcom_service_connect(service, "url", &channel, connectOpt), 0); + connectOpt.mode = C_CLIENT_SELF_POLL_EVENT; + EXPECT_NE(ubs_hcom_service_connect(service, "url", &channel, connectOpt), 0); + ubs_hcom_service_destroy(service, "service0"); +} + +TEST_F(TestHcomCapi, TestSetTlsOptions) +{ + ubs_hcom_service_options opt {}; + ubs_hcom_service service = 0; + int ret = ubs_hcom_service_create(C_SERVICE_RDMA, "service0", opt, &service); + ASSERT_EQ(ret, 0); + EXPECT_NO_FATAL_FAILURE(ubs_hcom_service_set_tls_opt(service, false, C_SERVICE_TLS_1_2, C_SERVICE_AES_GCM_128, + nullptr, nullptr, nullptr)); + ubs_hcom_service_destroy(service, "service0"); +} + +TEST_F(TestHcomCapi, TestSetUbcMode) +{ + ubs_hcom_service_options opt {}; + ubs_hcom_service service = 0; + int ret = ubs_hcom_service_create(C_SERVICE_UBC, "service0", opt, &service); + ASSERT_EQ(ret, 0); + ubs_hcom_service_ubc_mode ubcMode = C_SERVICE_HIGHBANDWIDTH; + EXPECT_NO_FATAL_FAILURE(ubs_hcom_service_set_ubcmode(service, ubcMode)); + ubs_hcom_service_destroy(service, "service0"); +} + +void CommonCb(void *arg, ubs_hcom_service_context context) +{ + return; +} + +TEST_F(TestHcomCapi, TestChannelRecv) +{ + EXPECT_EQ(ubs_hcom_channel_recv(0, 0, 0, 0, nullptr), static_cast(SER_INVALID_PARAM)); + + InnerConnectOptions opt{}; + UBSHcomChannel *ch = new (std::nothrow) HcomChannelImp(0, false, opt); + EXPECT_NE(ch, nullptr); + ubs_hcom_channel channel = reinterpret_cast(ch); + EXPECT_EQ(ubs_hcom_channel_recv(channel, 0, 0, 0, nullptr), static_cast(SER_INVALID_PARAM)); + + UBSHcomServiceContext ctx{}; + ubs_hcom_service_context serviceContext = reinterpret_cast(&ctx); + EXPECT_EQ(ubs_hcom_channel_recv(channel, serviceContext, 0, 0, nullptr), static_cast(SER_INVALID_PARAM)); + + uintptr_t address = NN_NO100; + EXPECT_EQ(ubs_hcom_channel_recv(channel, serviceContext, address, 0, nullptr), + static_cast(SER_INVALID_PARAM)); + + MOCKER_CPP_VIRTUAL(*ch, &UBSHcomChannel::Recv).stubs().will(returnValue(static_cast(SER_OK))); + uint32_t size = NN_NO1024; + EXPECT_EQ(ubs_hcom_channel_recv(channel, serviceContext, address, size, nullptr), static_cast(SER_OK)); + + ubs_hcom_channel_callback cb; + cb.cb = CommonCb; + cb.arg = NULL; + EXPECT_EQ(ubs_hcom_channel_recv(channel, serviceContext, address, size, &cb), static_cast(SER_OK)); + channel = 0; + delete ch; +} + +TEST_F(TestHcomCapi, SetTwoSideThreshold) +{ + ubs_hcom_twoside_threshold twoSideThreshold; + twoSideThreshold.splitThreshold = NN_NO8192; + twoSideThreshold.rndvThreshold = NN_NO8192; + EXPECT_EQ(ubs_hcom_channel_set_twoside_threshold(0, twoSideThreshold), static_cast(SER_INVALID_PARAM)); + + InnerConnectOptions opt{}; + UBSHcomChannel *ch = new (std::nothrow) HcomChannelImp(0, false, opt); + EXPECT_NE(ch, nullptr); + ubs_hcom_channel channel = reinterpret_cast(ch); + MOCKER_CPP_VIRTUAL(*ch, &UBSHcomChannel::SetTwoSideThreshold).stubs().will(returnValue(static_cast(SER_OK))); + EXPECT_EQ(ubs_hcom_channel_set_twoside_threshold(channel, twoSideThreshold), static_cast(SER_OK)); + + channel = 0; + delete ch; +} +} +} \ No newline at end of file diff --git a/test/unit_test/common/test_log.cpp b/test/unit_test/common/test_log.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a25ac4b614fa0f474c37d1f127a469d331eb3b3a --- /dev/null +++ b/test/unit_test/common/test_log.cpp @@ -0,0 +1,47 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +// Author: zhiwei + +#include +#include +#include + +#include "hcom_log.h" + +using ::testing::StartsWith; + +namespace ock { +namespace hcom { +class TestLog : public testing::Test { +public: + void SetUp() override + { + } + + void TearDown() override + { + GlobalMockObject::verify(); + } +}; + +TEST_F(TestLog, gettimeofday) +{ + MOCKER_CPP(gettimeofday).stubs().will(returnValue(-1)); + + testing::internal::CaptureStdout(); + UBSHcomNetOutLogger::Print(0, "something went wrong"); + std::string out = testing::internal::GetCapturedStdout(); + EXPECT_THAT(out, StartsWith("Fail to get the current system time, -1.")); +} + +TEST_F(TestLog, localtime_r) +{ + MOCKER_CPP(localtime_r).stubs().will(returnValue(static_cast(nullptr))); + + testing::internal::CaptureStdout(); + UBSHcomNetOutLogger::Print(0, "something went wrong"); + std::string out = testing::internal::GetCapturedStdout(); + EXPECT_THAT(out, StartsWith("Invalid time trace")); +} + +} // namespace hcom +} // namespace ock diff --git a/test/unit_test/common/test_net_mem_allocator.cpp b/test/unit_test/common/test_net_mem_allocator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b0cc1aed6a83374381c67f741e90ca85ba02c646 --- /dev/null +++ b/test/unit_test/common/test_net_mem_allocator.cpp @@ -0,0 +1,99 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include + +#include "net_util.h" +#include "net_addr_size_map.h" +#include "net_mem_allocator_cache.h" + +namespace ock { +namespace hcom { + +class TestNetMemAllocator : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + + UBSHcomNetMemoryAllocatorOptions mOptions {}; + void *mAddress = nullptr; +}; + +void TestNetMemAllocator::SetUp() +{ + mAddress = memalign(NN_NO4096, 256UL << NN_NO24); + mOptions.address = reinterpret_cast(mAddress); + mOptions.size = 256UL << NN_NO24; + mOptions.minBlockSize = NN_NO4096; + mOptions.alignedAddress = true; +} + +void TestNetMemAllocator::TearDown() +{ + GlobalMockObject::verify(); + free(mAddress); +} + +TEST_F(TestNetMemAllocator, InitializeSuccess) +{ + NetLocalAutoDecreasePtr alloc(new (std::nothrow) NetMemAllocator()); + int ret = alloc.Get()->Initialize(mOptions.address, mOptions.size, mOptions.minBlockSize, mOptions.alignedAddress); + EXPECT_EQ(ret, 0); + NetLocalAutoDecreasePtr cache(new (std::nothrow) NetAllocatorCache(alloc.Get())); + ret = cache.Get()->Initialize(mOptions); + EXPECT_EQ(ret, 0); + ret = cache.Get()->Initialize(mOptions); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestNetMemAllocator, InitializeFail_INVALID_PARAM) +{ + NetLocalAutoDecreasePtr alloc(new (std::nothrow) NetMemAllocator()); + int ret = alloc.Get()->Initialize(mOptions.address, mOptions.size, mOptions.minBlockSize, mOptions.alignedAddress); + EXPECT_EQ(ret, 0); + NetLocalAutoDecreasePtr cache(new (std::nothrow) NetAllocatorCache(alloc.Get())); + + mOptions.cacheTierPolicy = TIER_POWER; + mOptions.cacheTierCount = NN_NO32; + ret = cache.Get()->Initialize(mOptions); + EXPECT_EQ(ret, NN_INVALID_PARAM); + mOptions.cacheTierPolicy = TIER_TIMES; + mOptions.cacheTierCount = NN_NO8; + + cache.Get()->mMajorAllocator = nullptr; + ret = cache.Get()->Initialize(mOptions); + EXPECT_EQ(ret, NN_INVALID_PARAM); + + mOptions.cacheBlockCountPerTier = NN_NO0; + ret = cache.Get()->Initialize(mOptions); + EXPECT_EQ(ret, NN_INVALID_PARAM); + + mOptions.cacheTierCount = NN_NO0; + ret = cache.Get()->Initialize(mOptions); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetMemAllocator, InitializeFail_NEW_OBJECT_FAILED) +{ + NetLocalAutoDecreasePtr alloc(new (std::nothrow) NetMemAllocator()); + int ret = alloc.Get()->Initialize(mOptions.address, mOptions.size, mOptions.minBlockSize, mOptions.alignedAddress); + EXPECT_EQ(ret, 0); + NetLocalAutoDecreasePtr cache(new (std::nothrow) NetAllocatorCache(alloc.Get())); + + MOCKER_CPP(&NetAddress2SizeHashMap::Initialize).stubs().will(returnValue(1)); + ret = cache.Get()->Initialize(mOptions); + EXPECT_EQ(ret, NN_NEW_OBJECT_FAILED); +} + +} // namespace hcom +} // namespace ock \ No newline at end of file diff --git a/test/unit_test/common/test_net_mem_pool.cpp b/test/unit_test/common/test_net_mem_pool.cpp new file mode 100644 index 0000000000000000000000000000000000000000..71c18fb1c5662e5a9d35d0e9e8228d3718b2fca2 --- /dev/null +++ b/test/unit_test/common/test_net_mem_pool.cpp @@ -0,0 +1,107 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include + +#include "net_util.h" +#include "net_mem_pool_fixed.h" + +namespace ock { +namespace hcom { + +class TestNetMemPool : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + + NetMemPoolFixedOptions options {}; + NetMemPoolFixedPtr globalPool = nullptr; +}; + +void TestNetMemPool::SetUp() {} + +void TestNetMemPool::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestNetMemPool, Fixed) +{ + /* init global pool options */ + options.superBlkSizeMB = NN_NO4; + options.tcExpandBlkCnt = NN_NO64; + options.minBlkSize = NN_NO64; + + NetLocalAutoDecreasePtr netPtr(new (std::nothrow) NetMemPoolFixed("test", options)); + globalPool = netPtr.Get(); + EXPECT_NE(globalPool, nullptr); + int ret = globalPool->Initialize(); + EXPECT_EQ(ret, 0); + globalPool->mFreeCount = 0; + + thread_local NetTCacheFixed tc(globalPool.Get()); + NN_LOG_INFO("mem pool mFreeCount " << globalPool->mFreeCount); + char *pointer = tc.Allocate(); + EXPECT_NE(pointer, nullptr); + NN_LOG_INFO(tc.ToString()); + tc.Free(pointer); +} + +TEST_F(TestNetMemPool, KeyedThreadLocalCache) +{ + NetMemPoolFixedOptions options{}; + options.superBlkSizeMB = NN_NO1; + options.tcExpandBlkCnt = NN_NO8; // 每个扩容时小块个数 + options.minBlkSize = NN_NO64; // 每个小块大小 + + NetLocalAutoDecreasePtr mempool(new (std::nothrow) NetMemPoolFixed("keyed", options)); + mempool.Get()->Initialize(); + + KeyedThreadLocalCache<4> cache; + + // Allocate + // key 超过最大值,更新失败 + EXPECT_NO_THROW(cache.UpdateIf(11, mempool.Get())); + EXPECT_EQ(cache.Allocate(11), nullptr); + + // key 更新成功,使用 0 号 cache 分配内存 + EXPECT_NO_THROW(cache.UpdateIf(0, mempool.Get())); + int *arr[16] = {nullptr}; + for (auto &a : arr) { + a = cache.Allocate(0); + EXPECT_NE(a, nullptr); + } + + // key 超过最大值,alloc 失败 + EXPECT_EQ(cache.Allocate(11), nullptr); + + // key 对应的 thread_local cache 不存在 + EXPECT_EQ(cache.Allocate(1), nullptr); + + // Free + // key 超过最大值,归还失败 + EXPECT_NO_THROW(cache.Free(11, nullptr)); + + // key 对应的 thread_local cache 不存在 + EXPECT_NO_THROW(cache.Free(1, nullptr)); + + // 归还成功,由于一次性归还了 16 个小块,达到了回收至上层 mempool 的要求,当前 cache[0] 本地的 freelist 将归还一半, + // 剩余 8 个小块 + for (auto &a : arr) { + EXPECT_NO_THROW(cache.Free(0, a)); + } + EXPECT_EQ(cache.mTCacheFixeds[0]->mCurrentFree, 8); +} + +} // namespace hcom +} // namespace ock diff --git a/test/unit_test/common/test_net_pgtable.cpp b/test/unit_test/common/test_net_pgtable.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0f6b2d5d906cda4bca8030dd92540a23d4f44d95 --- /dev/null +++ b/test/unit_test/common/test_net_pgtable.cpp @@ -0,0 +1,445 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include +#include + +#include "net_pgtable.h" + +constexpr size_t AlignDown(size_t n, size_t alignment) +{ + return n - (n % alignment); +} + +namespace ock { +namespace hcom { +class TestPgTable : public testing::Test { +public: + void SetUp() override {} + + void TearDown() override + { + GlobalMockObject::verify(); + } + +protected: + using SearchResult = std::vector; + PgTable mPgTable { pgdAlloc, pgdFree }; + + SearchResult Search(PgtAddress from, PgtAddress to) + { + NN_LOG_INFO("Begin to search from " + << "[0x" << std::hex << from << ".. 0x" << std::hex << to << "]"); + SearchResult result; + mPgTable.SearchRange(from, to, pgdSearchCb, reinterpret_cast(&result)); + return result; + } + + static PgtRegion *MakeRegion(PgtAddress start, PgtAddress end) + { + PgtRegion r = { start, end }; + return new PgtRegion(r); + } + + static bool IsOverlap(const PgtRegion *region, PgtAddress from, PgtAddress to) + { + NN_LOG_DEBUG("regions" << region << " in the range 0x" << std::hex << region->start << "..0x" << region->end << + " from 0x" << from << " to: 0x" << to); + return std::max(from, region->start) <= std::min(to, region->end); + } + + static uint32_t CountOverlap(const std::vector ®ions, PgtAddress from, PgtAddress to) + { + uint32_t count = 0; + for (const auto &item : regions) { + if (IsOverlap(item, from, to)) { + ++count; + } + } + return count; + } + + void TestSearchRegion(const PgtRegion ®ion) + { + SearchResult result; + + result = Search(region.start, region.end - 1); + ASSERT_EQ(NN_NO1, result.size()); + EXPECT_EQ(®ion, result.front()); + + result = Search(region.start, region.end); + ASSERT_EQ(NN_NO1, result.size()); + EXPECT_EQ(®ion, result.front()); + + result = Search(region.start, region.end + 1); + ASSERT_EQ(NN_NO1, result.size()); + EXPECT_EQ(®ion, result.front()); + } + +private: + static PgtDir *pgdAlloc(const PgTable &pgtable) + { + return new (std::nothrow) PgtDir; + } + + static PgtDir *pgdAllocFailed(const PgTable &pgtable) + { + return nullptr; + } + + static void pgdFree(const PgTable &pgtable, PgtDir *pgdir) + { + delete pgdir; + } + + static void pgdSearchCb(const PgTable &pgtable, PgtRegion ®ion, void *arg) + { + NN_LOG_INFO("find the region push to result " << ®ion << "[0x" << std::hex << region.start << ".. 0x" << + std::hex << region.end << "]"); + SearchResult *result = reinterpret_cast(arg); + result->push_back(®ion); + } +}; + +TEST_F(TestPgTable, BasicSuccess) +{ + PgtRegion region; + + region.start = 0x600800; + region.end = 0x603400; + + NResult status = mPgTable.Insert(region); + EXPECT_EQ(status, NN_OK); + + mPgTable.Dump(); + + EXPECT_EQ(®ion, mPgTable.Lookup(0x600800)); + EXPECT_EQ(®ion, mPgTable.Lookup(0x602020)); + EXPECT_EQ(®ion, mPgTable.Lookup(0x6033ff)); + EXPECT_EQ(nullptr, mPgTable.Lookup(0x603400)); + EXPECT_EQ(nullptr, mPgTable.Lookup(0x0)); + EXPECT_EQ(nullptr, mPgTable.Lookup(std::numeric_limits::max())); + EXPECT_EQ(NN_NO1, mPgTable.mRegionCount); + + status = mPgTable.Remove(region); + EXPECT_EQ(status, NN_OK); + EXPECT_EQ(NN_NO0, mPgTable.mRegionCount); + + status = mPgTable.Insert(region); + EXPECT_EQ(status, NN_OK); + + mPgTable.Dump(); +} + +TEST_F(TestPgTable, InsertPgdAllocFailed) +{ + PgtRegion region; + region.start = 0x600800; + region.end = 0x603400; + + PgTable pgTable { pgdAllocFailed, pgdFree }; + NResult status = pgTable.Insert(region); + EXPECT_EQ(status, NN_ERROR); +} + +TEST_F(TestPgTable, PgtExpandFailed) +{ + PgTable pgTable { pgdAlloc, pgdFree }; + pgTable.mIndexShift = NN_NO64; + bool ret = pgTable.PgtExpand(); + EXPECT_EQ(ret, false); +} + +TEST_F(TestPgTable, InsertAndLookupAdjSuccess) +{ + // [0xc600000, 0xc600400) [0xc600400, 0xc600800) + PgtRegion region1 = { 0xc600000, 0xc600400 }; + PgtRegion region2 = { 0xc600400, 0xc600800 }; + NResult status = mPgTable.Insert(region1); + EXPECT_EQ(status, NN_OK); + + status = mPgTable.Insert(region2); + EXPECT_EQ(status, NN_OK); + + mPgTable.Dump(); + EXPECT_EQ(®ion2, mPgTable.Lookup(0xc600400)); + EXPECT_EQ(®ion1, mPgTable.Lookup(0xc600000)); + + status = mPgTable.Remove(region1); + EXPECT_EQ(status, NN_OK); + + status = mPgTable.Remove(region2); + EXPECT_EQ(status, NN_OK); +} + +TEST_F(TestPgTable, InsertAlreadyExistFailed) +{ + PgtRegion region1 = { 0x4000, 0x6000 }; + NResult ret = mPgTable.Insert(region1); + EXPECT_EQ(ret, NN_OK); + + PgtRegion region2 = { 0x5000, 0x7000 }; + ret = mPgTable.Insert(region2); + EXPECT_EQ(ret, NN_ERROR); + + PgtRegion region3 = { 0x3000, 0x5000 }; + ret = mPgTable.Insert(region3); + EXPECT_EQ(ret, NN_ERROR); + + ret = mPgTable.Remove(region1); + EXPECT_EQ(ret, NN_OK); +} + +TEST_F(TestPgTable, RemoveNonExistFailed) +{ + PgtRegion region1 = { 0x5000, 0x7000 }; + auto ret = mPgTable.Remove(region1); + EXPECT_EQ(ret, NN_ERROR); + + PgtRegion region2 = { 0x6000, 0x8000 }; + ret = mPgTable.Insert(region2); + EXPECT_EQ(ret, NN_OK); + + ret = mPgTable.Remove(region1); + EXPECT_EQ(ret, NN_ERROR); + + region1.start = 0x6000; + region1.end = 0x6000; + ret = mPgTable.Remove(region1); + EXPECT_EQ(ret, NN_ERROR); + + region1 = region2; + ret = mPgTable.Remove(region1); + EXPECT_EQ(ret, NN_ERROR); /* should be pointer-equal */ + + ret = mPgTable.Remove(region2); + EXPECT_EQ(ret, NN_OK); +} + +TEST_F(TestPgTable, SearchLargeRegionSuccess) +{ + PgtRegion region = { 0x3c03cb00, 0x3c03f600 }; + NResult ret = mPgTable.Insert(region); + EXPECT_EQ(ret, NN_OK); + + SearchResult result; + + result = Search(0x36990000, 0x3c810000); + EXPECT_EQ(NN_NO1, result.size()); + EXPECT_EQ(®ion, result.front()); + + result = Search(region.start - 1, region.start); + EXPECT_EQ(NN_NO1, result.size()); + + result = Search(region.start, region.start + 1); + EXPECT_EQ(NN_NO1, result.size()); + EXPECT_EQ(®ion, result.front()); + + result = Search(region.end - 1, region.end); + EXPECT_EQ(NN_NO1, result.size()); + EXPECT_EQ(®ion, result.front()); + + result = Search(region.end, region.end + 1); + EXPECT_EQ(0u, result.size()); + + ret = mPgTable.Remove(region); + EXPECT_EQ(ret, NN_OK); +} + +TEST_F(TestPgTable, SearchNonContigRegionsSuccess) +{ + const size_t regionSize = (1UL << NN_NO28); + + // insert [0x7f6ef0000000 .. 0x7f6f00000000] + auto start = 0x7f6ef0000000; + auto end = start + regionSize; + PgtRegion region1 = { start, end }; + NResult ret = mPgTable.Insert(region1); + EXPECT_EQ(ret, NN_OK); + + // insert [0x7f6f2c021000 .. 0x7f6f3c021000] + start = 0x7f6f2c021000; + end = start + regionSize; + PgtRegion region2 = { start, end }; + ret = mPgTable.Insert(region2); + EXPECT_EQ(ret, NN_OK); + + // insert [0x7f6f42000000 .. 0x7f6f52000000] + start = 0x7f6f42000000; + end = start + regionSize; + PgtRegion region3 = { start, end }; + ret = mPgTable.Insert(region3); + EXPECT_EQ(ret, NN_OK); + + SearchResult result; + + // search the 1st region + TestSearchRegion(region1); + + // search the 2nd region + TestSearchRegion(region2); + + // search the 3rd region + TestSearchRegion(region3); + + ret = mPgTable.Remove(region1); + EXPECT_EQ(ret, NN_OK); + + ret = mPgTable.Remove(region2); + EXPECT_EQ(ret, NN_OK); + + ret = mPgTable.Remove(region3); + EXPECT_EQ(ret, NN_OK); +} + +TEST_F(TestPgTable, SearchAdjRegionsSuccess) +{ + const size_t regionSize = (1UL << NN_NO28); + // insert [0x7f6ef0000000 .. 0x7f6f00000000] + auto start = 0x7f6ef0000000; + auto end = start + regionSize; + PgtRegion region1 = { start, end }; + NResult ret = mPgTable.Insert(region1); + EXPECT_EQ(ret, NN_OK); + + // insert [0x7f6f00000000 .. 0x7f6f10000000] + start = end; + end = start + regionSize; + PgtRegion region2 = { region1.end, 0x7f6f40000000 }; + ret = mPgTable.Insert(region2); + EXPECT_EQ(ret, NN_OK); + + // insert [0x7f6f10000000 .. 0x7f6f20000000] + start = end; + end = start + regionSize; + PgtRegion region3 = { region2.end, 0x7f6f48000000 }; + ret = mPgTable.Insert(region3); + EXPECT_EQ(ret, NN_OK); + + SearchResult result; + + // search the 1st region + result = Search(region1.start, region1.end - 1); + EXPECT_EQ(NN_NO1, result.size()); + EXPECT_EQ(®ion1, result.front()); + + result = Search(region1.start, region1.end); + EXPECT_EQ(NN_NO2, result.size()); + EXPECT_EQ(®ion1, result.front()); + + result = Search(region1.start, region1.end + 1); + EXPECT_EQ(NN_NO2, result.size()); + EXPECT_EQ(®ion1, result.front()); + + // search the 2nd region + result = Search(region2.start, region2.end - 1); + EXPECT_EQ(NN_NO1, result.size()); + EXPECT_EQ(®ion2, result.front()); + + result = Search(region2.start, region2.end); + EXPECT_EQ(NN_NO2, result.size()); + EXPECT_EQ(®ion2, result.front()); + + result = Search(region2.start, region2.end + 1); + EXPECT_EQ(NN_NO2, result.size()); + EXPECT_EQ(®ion2, result.front()); + + // search the 3rd region + result = Search(region3.start, region3.end - 1); + EXPECT_EQ(NN_NO1, result.size()); + EXPECT_EQ(®ion3, result.front()); + + result = Search(region3.start, region3.end); + EXPECT_EQ(NN_NO1, result.size()); + EXPECT_EQ(®ion3, result.front()); + + result = Search(region3.start, region3.end + 1); + EXPECT_EQ(NN_NO1, result.size()); + EXPECT_EQ(®ion3, result.front()); + + ret = mPgTable.Remove(region1); + EXPECT_EQ(ret, NN_OK); + + ret = mPgTable.Remove(region2); + EXPECT_EQ(ret, NN_OK); + + ret = mPgTable.Remove(region3); + EXPECT_EQ(ret, NN_OK); +} + +TEST_F(TestPgTable, MultiSearchSuccess) +{ + uint32_t ucsRandSeed = ock::hcom::NN_NO1073741824; + std::vector regions; + // Repeat execution 5 times, using different random data each time,to verify the stability of the page table + for (int count = 0; count < NN_NO10; ++count) { + PgtAddress min = std::numeric_limits::max(); + PgtAddress max = 0; + + /* generate random regions */ + uint32_t regionCount = 0; + for (int i = 0; i < NN_NO10; ++i) { + PgtAddress start = (rand_r(&ucsRandSeed) & 0x7fffffff) << NN_NO24; + size_t randomSize = static_cast(rand_r(&ucsRandSeed)); + size_t size = std::min(randomSize, std::numeric_limits::max() - start); + PgtAddress end = start + AlignDown(size, PAGE_ADDR_ALIGN_MIN); + + min = std::min(start, min); + max = std::max(start, max); + auto region = MakeRegion(start, end); + + NN_LOG_INFO("begin to check insert count:" << count << " region index:" << i << + " regions in the range 0x" << std::hex << region->start << "..0x" << region->end << std::dec << + " total num:" << regionCount); + + if (CountOverlap(regions, region->start, region->end) != 0) { + /* Make sure regions do not overlap */ + continue; + } + + regions.push_back(region); + ++regionCount; + } + + /* Insert regions */ + for (const auto &item : regions) { + mPgTable.Insert(*item); + } + + /* Count how many fall in the [1/4, 3/4] range */ + PgtAddress from = ((min * NN_NO90) + (max * NN_NO10)) / NN_NO100; + PgtAddress to = ((min * NN_NO10) + (max * NN_NO90)) / NN_NO100; + uint32_t numInRange = CountOverlap(regions, from, to); + + SearchResult result = Search(from, to); + NN_LOG_INFO("total region num " << regionCount << " found " << result.size() << "/" << numInRange << + " regions in the range 0x" << std::hex << from << "..0x" << to << std::dec); + EXPECT_EQ(numInRange, result.size()); + } +} + +TEST_F(TestPgTable, CleanUpSuccess) +{ + PgtRegion region1 = { 0xc600000, 0xc600400 }; + PgtRegion region2 = { 0xc600400, 0xc600800 }; + PgtRegion region3 = { 0xc600800, 0xc600b00 }; + EXPECT_EQ(mPgTable.Insert(region1), NN_OK); + EXPECT_EQ(mPgTable.Insert(region2), NN_OK); + EXPECT_EQ(mPgTable.Insert(region3), NN_OK); + mPgTable.Dump(); + mPgTable.Cleanup(); + EXPECT_EQ(mPgTable.mRootEntry.IsPresent(), false); +} +} // namespace hcom +} // namespace ock diff --git a/test/unit_test/common/test_net_util.cpp b/test/unit_test/common/test_net_util.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a52b966c839659e9a1615f7c7af9bdeb75eab06f --- /dev/null +++ b/test/unit_test/common/test_net_util.cpp @@ -0,0 +1,78 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +// Author: zhiwei + +#include + +#include "net_util.h" + +namespace ock { +namespace hcom { + +class TestNetUtil : public testing::Test { +public: + void SetUp() override + { + } + + void TearDown() override + { + } +}; + +TEST_F(TestNetUtil, ScopeExitSimple) +{ + bool flag = false; + auto guard0 = MakeScopeExit([&flag]() { EXPECT_TRUE(flag); }); + auto guard1 = MakeScopeExit([&flag]() { flag = true; }); +} + +TEST_F(TestNetUtil, ScopeExitActive) +{ + auto guard = MakeScopeExit([]() { EXPECT_TRUE(true); }); + EXPECT_TRUE(guard.Active()); +} + +TEST_F(TestNetUtil, ScopeExitDeactivate) +{ + bool flag = true; + auto guard0 = MakeScopeExit([&flag]() { EXPECT_TRUE(flag); }); + auto guard1 = MakeScopeExit([&flag]() { flag = false; }); + + guard1.Deactivate(); + EXPECT_FALSE(guard1.Active()); +} + +TEST_F(TestNetUtil, HexStringToBuffFailed) +{ + uint8_t buf[4]; + + EXPECT_FALSE(HexStringToBuff("112233", sizeof(buf), nullptr)); + EXPECT_FALSE(HexStringToBuff("112233", sizeof(buf), buf)); + EXPECT_FALSE(HexStringToBuff("112233xyz", sizeof(buf), buf)); +} + +TEST_F(TestNetUtil, HexStringToBuffOk) +{ + uint8_t buf[4]; + + EXPECT_TRUE(HexStringToBuff("61626364", sizeof(buf), buf)); + EXPECT_EQ(buf[0], 0x61); + EXPECT_EQ(buf[1], 0x62); + EXPECT_EQ(buf[2], 0x63); + EXPECT_EQ(buf[3], 0x64); +} + +TEST_F(TestNetUtil, BuffToHexStringFailed) +{ + std::string out; + EXPECT_FALSE(BuffToHexString(nullptr, 8, out)); +} + +TEST_F(TestNetUtil, BuffToHexStringOk) +{ + char buf[] = "12345678"; + std::string out = "01234567"; + EXPECT_TRUE(BuffToHexString(buf, sizeof(buf), out)); +} +} // namespace hcom +} // namespace ock diff --git a/test/unit_test/common/trace/test_htracer_manager.cpp b/test/unit_test/common/trace/test_htracer_manager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..098b58e30d598be1e844cb980d8d4a3746e09f29 --- /dev/null +++ b/test/unit_test/common/trace/test_htracer_manager.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +#include +#include +#include + +#include "htracer_manager.h" + +namespace ock { +namespace hcom { +class TestHTracerManager : public testing::Test { +public: + void SetUp() override {} + + void TearDown() override + { + GlobalMockObject::verify(); + } +}; + +TEST_F(TestHTracerManager, TestCreateInstanceSuccess) +{ + TraceManager traceManager {}; + EXPECT_NE(traceManager.CreateInstance(), nullptr); +} + +TEST_F(TestHTracerManager, TestCreateInstanceMemSetFailed) +{ + TraceManager traceManager {}; + MOCKER_CPP(memset_s).stubs().will(returnValue(1)).then(returnValue(-1)); + EXPECT_EQ(traceManager.CreateInstance(), nullptr); +} + +TEST_F(TestHTracerManager, DumpTraceSplitInfoSuccess) +{ + std::string tpName = "test_tracer_manager"; + uint64_t diff = 1; + int32_t retCode = 0; + TraceManager::DumpTraceSplitInfo(tpName, diff, retCode); + ASSERT_EQ(TraceManager::mDumpEnable, false); +} +} // namespace hcom +} // namespace ock diff --git a/test/unit_test/common/trace/test_htracer_rpc_server.cpp b/test/unit_test/common/trace/test_htracer_rpc_server.cpp new file mode 100644 index 0000000000000000000000000000000000000000..84be9b4a4bf752476091f14942055e2c74bc1d4b --- /dev/null +++ b/test/unit_test/common/trace/test_htracer_rpc_server.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +#include +#include +#include + +#include "hcom_num_def.h" +#include "rpc_server.h" +#include "securec.h" + +namespace ock { +namespace hcom { +class TestHTracerRpcServer : public testing::Test { +public: + void SetUp() override {} + + void TearDown() override + { + GlobalMockObject::verify(); + } +}; + +TEST_F(TestHTracerRpcServer, TestRpcServerStartSuccess) +{ + auto FakePort = NN_NO7200; + auto rpcServer = new (std::nothrow) RpcServer(); + EXPECT_NE(rpcServer, nullptr); + EXPECT_EQ(rpcServer->Start(std::to_string(FakePort)), NN_OK); + rpcServer->Stop(); + delete rpcServer; +} + +TEST_F(TestHTracerRpcServer, TestRpcServerMemSetFailed) +{ + auto FakePort = NN_NO7200; + auto rpcServer = new (std::nothrow) RpcServer(); + EXPECT_NE(rpcServer, nullptr); + MOCKER_CPP(memset_s).stubs().will(returnValue(1)).then(returnValue(-1)); + EXPECT_EQ(rpcServer->Start(std::to_string(FakePort)), SER_ERROR); + rpcServer->Stop(); + delete rpcServer; +} + +TEST_F(TestHTracerRpcServer, TestRpcServerMemCpyFailed) +{ + auto FakePort = NN_NO7200; + auto rpcServer = new (std::nothrow) RpcServer(); + EXPECT_NE(rpcServer, nullptr); + MOCKER_CPP(memcpy_s).stubs().will(returnValue(1)).then(returnValue(-1)); + EXPECT_EQ(rpcServer->Start(std::to_string(FakePort)), SER_ERROR); + rpcServer->Stop(); + delete rpcServer; +} +} // namespace hcom +} // namespace ock diff --git a/test/unit_test/common/trace/test_msg.cpp b/test/unit_test/common/trace/test_msg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ca687537ca49f230c0e96a7363a7efeeed71d53f --- /dev/null +++ b/test/unit_test/common/trace/test_msg.cpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "test_msg.h" + +using namespace std; +using namespace ock::hcom; +void TestMsg::SetUp() const {} +void TestMsg::TearDown() const {} + +void TestMsg::SetUpTestCase() {} +void TestMsg::TearDownTestCase() {} + +TEST_F(TestMsg, test_msg_response_give_nullptr_return_ok) +{ + EXPECT_EQ(HTracerInit("30000"), NN_OK); + + TRACE_DELAY_BEGIN(0); + TRACE_DELAY_END(0, 0); + auto tTranceInfos = TracerServiceHelper::GetTraceInfos(0, 0, TraceManager::IsLatencyQuantileEnable()); + Message response(nullptr, 0); + EXPECT_EQ(QueryTraceInfoResponse::BuildMessage(tTranceInfos, response), NN_OK); + HTracerExit(); +} + +TEST_F(TestMsg, test_msg_response_give_new_return_true) +{ + EXPECT_EQ(HTracerInit("2999"), NN_OK); + QueryTraceInfoRequest *queryRequest = static_cast(malloc(sizeof(QueryTraceInfoRequest))); + queryRequest->serviceId = 1; + Message request(queryRequest, sizeof(QueryTraceInfoRequest)); + + TRACE_DELAY_BEGIN(0); + TRACE_DELAY_END(0, 0); + TRACE_DELAY_BEGIN(1); + TRACE_DELAY_END(1, 0); + TRACE_DELAY_BEGIN(2); + TRACE_DELAY_END(2, 0); + + auto tTranceInfos = TracerServiceHelper::GetTraceInfos(0, 0, TraceManager::IsLatencyQuantileEnable()); + + uint32_t bodySize = sizeof(uint32_t) + sizeof(TTraceInfo) * tTranceInfos.size(); + uint32_t messageSize = sizeof(MessageHeader) + bodySize; + Message queryResponse {}; + + EXPECT_EQ(QueryTraceInfoResponse::BuildMessage(tTranceInfos, queryResponse), NN_OK); + + HTracerExit(); +} + +TEST_F(TestMsg, test_msg_give_request_return_true) +{ + QueryTraceInfoRequest queryRequest; + queryRequest.serviceId = 1; + EXPECT_EQ(queryRequest.serviceId, 1); + EXPECT_EQ(queryRequest.bodySize, 0); + EXPECT_EQ(queryRequest.crc, 0); + EXPECT_EQ(queryRequest.version, VERSION); + EXPECT_EQ(queryRequest.magicCode, MAGIC_CODE); + EXPECT_EQ(queryRequest.opcode, TRACE_OP_QUERY); +} \ No newline at end of file diff --git a/test/unit_test/common/trace/test_msg.h b/test/unit_test/common/trace/test_msg.h new file mode 100644 index 0000000000000000000000000000000000000000..5ed7b9b1da3f6f34b19e5843be505329c243c410 --- /dev/null +++ b/test/unit_test/common/trace/test_msg.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + + +#ifndef DFS_TRACER_TEST_MSG_H +#define DFS_TRACER_TEST_MSG_H +#include +#include "gtest/gtest.h" +#include "htracer_msg.h" +#include "rpc_msg.h" +#include "htracer_service_helper.h" + +class TestMsg : public testing::Test { +public: + TestMsg() {} + ~TestMsg() {} + + // TestCase only enter once + static void SetUpTestCase(); + static void TearDownTestCase(); + + // every TEST_F macro will enter one + void SetUp() const; + void TearDown() const; +}; + + +#endif // DFS_TRACER_TEST_MSG_H diff --git a/test/unit_test/common/trace/test_trace_rpc.cpp b/test/unit_test/common/trace/test_trace_rpc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..76e966dff15dc3585d7a961ad4fc0bf9a96b1cfe --- /dev/null +++ b/test/unit_test/common/trace/test_trace_rpc.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + + +#include "test_trace_rpc.h" +#include "htracer_service_helper.h" +#include "rpc_server.h" +#include "rpc_msg.h" + +using namespace std; +using namespace ock::hcom; +void TestRpc::SetUp() const {} +void TestRpc::TearDown() +{ + if (mRpcServer != nullptr) { + delete mRpcServer; + mRpcServer = nullptr; + } +} + +void TestRpc::SetUpTestCase() {} +void TestRpc::TearDownTestCase() {} + +TEST_F(TestRpc, test_one_rpc_sercer_start_return_true_shutdown_return_true) +{ + int port = 12345; + if (mRpcServer == nullptr) { + mRpcServer = new (std::nothrow) RpcServer(); + } + EXPECT_NE(mRpcServer, nullptr); + EXPECT_EQ(mRpcServer->Start(std::to_string(port)), NN_OK); + + mRpcServer->Stop(); +} + +TEST_F(TestRpc, test_one_rpc_sercer_start_18_port_return_true) +{ + if (mRpcServer == nullptr) { + mRpcServer = new (std::nothrow) RpcServer(); + } + EXPECT_NE(mRpcServer, nullptr); + for (int port = 50000; port < 50100; port++) { + EXPECT_EQ(mRpcServer->Start(std::to_string(port)), NN_OK); + } + + mRpcServer->Stop(); +} + +TEST_F(TestRpc, test_msg_validate_0_size) +{ + QueryTraceInfoRequest *queryRequest = static_cast(malloc(sizeof(QueryTraceInfoRequest))); + queryRequest->serviceId = 1; + Message request(queryRequest, 0); + ASSERT_EQ(MessageValidator::Validate(request), false); +} + +TEST_F(TestRpc, test_msg_validate_nullptr_msg) +{ + QueryTraceInfoRequest *queryRequest = static_cast(malloc(sizeof(QueryTraceInfoRequest))); + Message request(nullptr, sizeof(QueryTraceInfoRequest)); + ASSERT_EQ(MessageValidator::Validate(request), false); +} + +TEST_F(TestRpc, test_msg_validate_normal_msg) +{ + QueryTraceInfoRequest *queryRequest = static_cast(malloc(sizeof(QueryTraceInfoRequest))); + queryRequest->serviceId = 1; + queryRequest->version = VERSION; + queryRequest->magicCode = MAGIC_CODE; + queryRequest->bodySize = 0; + + Message request(queryRequest, sizeof(QueryTraceInfoRequest)); + ASSERT_EQ(MessageValidator::Validate(request), true); +} \ No newline at end of file diff --git a/test/unit_test/common/trace/test_trace_rpc.h b/test/unit_test/common/trace/test_trace_rpc.h new file mode 100644 index 0000000000000000000000000000000000000000..53aa16d19872f078c29a91c8f06443987e330559 --- /dev/null +++ b/test/unit_test/common/trace/test_trace_rpc.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef DFS_TRACER_TEST_TRACE_RPC_H +#define DFS_TRACER_TEST_TRACE_RPC_H + +#include +#include "gtest/gtest.h" +#include "htracer_manager.h" +#include "rpc_server.h" +#include "htracer_msg.h" + +class TestRpc : public testing::Test { +public: + TestRpc() {} + ~TestRpc() {} + + // TestCase only enter once + static void SetUpTestCase(); + static void TearDownTestCase(); + + // every TEST_F macro will enter one + void SetUp() const; + void TearDown(); + ock::hcom::RpcServer *mRpcServer = nullptr; +}; + + +#endif // DFS_TRACER_TEST_TRACE_RPC_H diff --git a/test/unit_test/common/trace/test_trace_service.cpp b/test/unit_test/common/trace/test_trace_service.cpp new file mode 100644 index 0000000000000000000000000000000000000000..381b4a0f6c57a03c5287b252ffa3c63262f2e681 --- /dev/null +++ b/test/unit_test/common/trace/test_trace_service.cpp @@ -0,0 +1,237 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + + +#include "test_trace_service.h" +#include +#include +#include +#include "htracer.h" + + +namespace ock { +namespace hcom { +void TestService::SetUp() const {} +void TestService::TearDown() const {} + +void TestService::SetUpTestCase() {} +void TestService::TearDownTestCase() {} + +constexpr uint32_t TRACE_ID_0 = TRACE_ID(0, NN_NO1); +constexpr uint32_t TRACE_ID_1 = TRACE_ID(1, NN_NO1); +constexpr uint32_t TRACE_ID_2 = TRACE_ID(2, NN_NO0); + +TEST_F(TestService, test_one_trace_service_start_return_true_shutdown_return_true) +{ + traceService = new HTracerService(); + EXPECT_EQ(traceService->StartUp("1234"), NN_OK); + + traceService->ShutDown(); + + delete traceService; + traceService = nullptr; +} + +TEST_F(TestService, test_one_trace_service_create_port_return_false_when_used_and_return_true_when_new_port) +{ + traceService = new HTracerService(); + EXPECT_EQ(traceService->StartUp("12345"), NN_OK); + EXPECT_EQ(traceService->StartUp("12346"), NN_OK); + + traceService->ShutDown(); + delete traceService; + traceService = nullptr; +} + +TEST_F(TestService, test_get_traceinfo_give_normal_value_id_return_normal_value) +{ + EnableHtrace(true); + EXPECT_EQ(HTracerInit("30000"), NN_OK); + TRACE_DELAY_BEGIN(TRACE_ID_0); + TRACE_DELAY_END(TRACE_ID_0, 0); + TRACE_DELAY_BEGIN(TRACE_ID_2); + TRACE_DELAY_END(TRACE_ID_2, 0); + + auto tTranceInfos = TracerServiceHelper::GetTraceInfos(0, 0, TraceManager::IsLatencyQuantileEnable()); + ASSERT_EQ(tTranceInfos.size(), 1); + for (int i = 0; i < 1; i++) { + EXPECT_EQ(tTranceInfos[i].begin, 1); + EXPECT_NE(tTranceInfos[i].total, 0); + EXPECT_EQ(tTranceInfos[i].goodEnd, 1); + EXPECT_EQ(tTranceInfos[i].badEnd, 0); + } + + auto tTranceInfos1 = TracerServiceHelper::GetTraceInfos(2, 0, TraceManager::IsLatencyQuantileEnable()); + ASSERT_EQ(tTranceInfos1.size(), 1); + for (int i = 0; i < 1; i++) { + EXPECT_EQ(tTranceInfos1[i].begin, 1); + EXPECT_NE(tTranceInfos1[i].total, 0); + EXPECT_EQ(tTranceInfos1[i].goodEnd, 1); + EXPECT_EQ(tTranceInfos1[i].badEnd, 0); + } + HTracerExit(); +} + +TEST_F(TestService, test_get_traceinfo_give_morethan_MAX_SERVICE_NUM_return_empty) +{ + EXPECT_EQ(HTracerInit("30001"), NN_OK); + TRACE_DELAY_BEGIN(TRACE_ID_0); + TRACE_DELAY_END(TRACE_ID_0, 0); + TRACE_DELAY_BEGIN(TRACE_ID_2); + TRACE_DELAY_END(TRACE_ID_2, 0); + auto tTranceInfos = TracerServiceHelper::GetTraceInfos(MAX_SERVICE_NUM + 1, + 0, TraceManager::IsLatencyQuantileEnable()); + + EXPECT_EQ(tTranceInfos.size(), 0); + + HTracerExit(); +} + +TEST_F(TestService, test_get_traceinfo_give_invalid_value_return_all_records) +{ + TracerServiceHelper::ResetTraceInfos(); + EXPECT_EQ(HTracerInit("30002"), NN_OK); + TRACE_DELAY_BEGIN(TRACE_ID_0); + TRACE_DELAY_END(TRACE_ID_0, 0); + TRACE_DELAY_BEGIN(TRACE_ID_2); + TRACE_DELAY_END(TRACE_ID_2, 0); + auto tTranceInfos = TracerServiceHelper::GetTraceInfos(INVALID_SERVICE_ID, + 0, TraceManager::IsLatencyQuantileEnable()); + + EXPECT_EQ(tTranceInfos.size(), 2); + for (int i = 0; i < 2; i++) { + EXPECT_EQ(tTranceInfos[i].begin, 1); + EXPECT_NE(tTranceInfos[i].total, 0); + EXPECT_EQ(tTranceInfos[i].goodEnd, 1); + EXPECT_EQ(tTranceInfos[i].badEnd, 0); + } + HTracerExit(); +} + +TEST_F(TestService, test_sent_response_return_ok) +{ + traceService = new HTracerService(); + EXPECT_EQ(traceService->StartUp("33333"), NN_OK); + Message response(nullptr, 0); +} + + +TEST_F(TestService, test_sent_request_nullptr_return_false) +{ + traceService = new HTracerService(); + EXPECT_EQ(traceService->StartUp("33331"), NN_OK); + Message response(nullptr, 0); + Message request(nullptr, 0); + traceService->ShutDown(); + delete traceService; + traceService = nullptr; +} + +TEST_F(TestService, test_sent_response_nullptr_return_ok) +{ + traceService = new HTracerService(); + EXPECT_EQ(traceService->StartUp("33334"), NN_OK); + Message response(nullptr, 0); + + int32_t recvBufferSize = 1024; + char *recvBuffer = static_cast(malloc(recvBufferSize)); + Message request(recvBuffer, recvBufferSize); + + traceService->ShutDown(); + delete traceService; + traceService = nullptr; +} + +TEST_F(TestService, test_sent_request_opcode_TRACE_OP_MODIFY_return_true) +{ + traceService = new HTracerService(); + EXPECT_EQ(traceService->StartUp("33335"), NN_OK); + Message response(nullptr, 0); + + QueryTraceInfoRequest *queryRequest = + static_cast(malloc(sizeof(QueryTraceInfoRequest))); + queryRequest->serviceId = 1; + queryRequest->opcode = INVALID_OPCODE; + Message request(queryRequest, sizeof(QueryTraceInfoRequest)); + traceService->ShutDown(); + delete traceService; + traceService = nullptr; +} + +TEST_F(TestService, test_sent_request_opcode_TRACE_OP_QUERY_return_true) +{ + traceService = new HTracerService(); + EXPECT_EQ(traceService->StartUp("33336"), NN_OK); + Message response(nullptr, 0); + + QueryTraceInfoRequest *queryRequest = + static_cast(malloc(sizeof(QueryTraceInfoRequest))); + queryRequest->serviceId = 1; + + Message request(queryRequest, sizeof(queryRequest)); + traceService->ShutDown(); + delete traceService; + traceService = nullptr; +} + +TEST_F(TestService, TestTraceManagerBegin) +{ + EXPECT_NO_FATAL_FAILURE(TraceManager::DelayBegin(NN_NO4, "name")); +} + +TEST_F(TestService, TestBuildMessage) +{ + Message msg {}; + ResetTraceInfoResponse::BuildMessage(msg); +} + +TEST_F(TestService, TestCentroidList) +{ + CentroidList lst {2}; + EXPECT_EQ(lst.Insert(1, 2), InsertResultCode::NO_NEED_COMPERSS); + EXPECT_EQ(lst.Insert(1, 2), InsertResultCode::NEED_COMPERSS); + EXPECT_EQ(lst.Insert(-1, 2), InsertResultCode::NO_NEED_COMPERSS); +} + +TEST_F(TestService, TestCentroid_GetMean) +{ + Centroid centroid(1.0, 1); + EXPECT_EQ(centroid.GetMean(), 1.0); +} + +TEST_F(TestService, TestCentroid_GetWeight) +{ + Centroid centroid(1.0, 1); + EXPECT_EQ(centroid.GetWeight(), 1); +} + +TEST_F(TestService, TestCentroidList_GetAndSetCentroids) +{ + CentroidList centroidList(1); + centroidList.Insert(1.0, 1); + std::vector centroids = centroidList.GetAndSetCentroids(); + EXPECT_EQ(centroids.size(), 1); + EXPECT_EQ(centroids[0].GetWeight(), 1); +} + +TEST_F(TestService, TestTdigest) +{ + Tdigest tdigest(100); + for (int i = 1; i <= 100; i++) { + tdigest.Insert(i); + } + tdigest.Merge(); + double p90 = tdigest.Quantile(90); + tdigest.Reset(); +} +} // namespace hcom +} // namespace ock diff --git a/test/unit_test/common/trace/test_trace_service.h b/test/unit_test/common/trace/test_trace_service.h new file mode 100644 index 0000000000000000000000000000000000000000..5debecc9ddc78fd07b5714ba0c45fc77ccf25ce1 --- /dev/null +++ b/test/unit_test/common/trace/test_trace_service.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + + +#ifndef DFS_TRACER_TEST_TRACE_SERVICE_H +#define DFS_TRACER_TEST_TRACE_SERVICE_H + +#include +#include "gtest/gtest.h" +#include "htracer_service.h" +#include "htracer_service_helper.h" +#include "rpc_server.h" + +namespace ock { +namespace hcom { +class TestService : public testing::Test { +public: + TestService() {} + ~TestService() {} + + // TestCase only enter once + static void SetUpTestCase(); + static void TearDownTestCase(); + + // every TEST_F macro will enter one + void SetUp() const; + void TearDown() const; + + ock::hcom::HTracerService *traceService = nullptr; +}; +} +} + +#endif // DFS_TRACER_TEST_TRACE_SERVICE_H diff --git a/test/unit_test/main.cpp b/test/unit_test/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..39470198a4c74f6f280f4788cfe4bafd606d0029 --- /dev/null +++ b/test/unit_test/main.cpp @@ -0,0 +1,21 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +int main(int argc, char *argv[]) +{ + testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + printf("hcom_ut result %d\n", ret); + return ret; +} \ No newline at end of file diff --git a/test/unit_test/service/test_service_ctx_store.cpp b/test/unit_test/service/test_service_ctx_store.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eea98c72ddf85eb4b6451b7e061e573d6e0c643a --- /dev/null +++ b/test/unit_test/service/test_service_ctx_store.cpp @@ -0,0 +1,85 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +// Author: zhiwei + +#include + +#include +#include + +#include "net_mem_pool_fixed.h" +#include "service_ctx_store.h" + +namespace ock { +namespace hcom { +class TestServiceCtxStore : public testing::Test { +public: + virtual void SetUp(void) + { + } + + virtual void TearDown(void) + { + GlobalMockObject::verify(); + } +}; + +TEST_F(TestServiceCtxStore, CrossThreadReturn) +{ + NetMemPoolFixedOptions options{}; + options.superBlkSizeMB = NN_NO1; // 1M 一共有 16384 个小块 + options.tcExpandBlkCnt = NN_NO8; // 扩容时分配的小块个数 + options.minBlkSize = NN_NO64; // 每个小块大小为 64 + + NetMemPoolFixed *mempool1 = new (std::nothrow) NetMemPoolFixed("service1", options); + ASSERT_NE(mempool1, nullptr); + mempool1->IncreaseRef(); + mempool1->Initialize(); + + NetMemPoolFixed *mempool2 = new (std::nothrow) NetMemPoolFixed("service2", options); + ASSERT_NE(mempool2, nullptr); + mempool2->IncreaseRef(); + mempool2->Initialize(); + + NetServiceCtxStore *store1 = new (std::nothrow) NetServiceCtxStore(NN_NO1024, mempool1, + UBSHcomNetDriverProtocol::TCP); + NetServiceCtxStore *store2 = new (std::nothrow) NetServiceCtxStore(NN_NO1024, mempool2, + UBSHcomNetDriverProtocol::RDMA); + ASSERT_NE(store1, nullptr); + ASSERT_NE(store2, nullptr); + + // 模拟用户线程 + std::vector ptrs1; + std::vector ptrs2; + std::thread user([&store1, &store2, &ptrs1, &ptrs2]() { + for (int i = 0; i < NN_NO16; ++i) { + int *p1 = store1->GetCtxObj(); + int *p2 = store2->GetCtxObj(); + ptrs1.push_back(p1); + ptrs2.push_back(p2); + } + }); + if (user.joinable()) { + user.join(); + } + + // mempool1 的 1M 内存首地址 + int *start = reinterpret_cast(mempool1->mSuperBlocks[0].buffer); + int *end = reinterpret_cast(reinterpret_cast(start) + NN_NO1048576); + + // 如果是旧实现,ptrs2 中的地址都将是在 [start1, start1 + 1MB) 范围之间,后续令 mempool1 析构就会导致 ptrs2 中的内存 + // 访问失败 + for (auto *p : ptrs2) { + EXPECT_FALSE(p >= start && p < end); + } + + // 令 mempool1 提前析构,不会出现 coredump + delete store1; + mempool1->DecreaseRef(); + EXPECT_NO_FATAL_FAILURE(*ptrs2[0] = 11); + + delete store2; + mempool2->DecreaseRef(); +} + +} // namespace hcom +} // namespace ock diff --git a/test/unit_test/service_v2/test_hcom_service_imp.cpp b/test/unit_test/service_v2/test_hcom_service_imp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4fe2331cf2402e28edcc0b012baa1005129a2510 --- /dev/null +++ b/test/unit_test/service_v2/test_hcom_service_imp.cpp @@ -0,0 +1,1010 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include + +#include "hcom.h" +#include "service_imp.h" +#include "service_channel_imp.h" +#include "net_rdma_driver_oob.h" +#include "net_rdma_async_endpoint.h" + +namespace ock { +namespace hcom { +std::string name = "service1"; +std::string serviceIpInfo = "127.0.0.1"; +std::string serviceUdsPath = "/home/udsPath"; +std::string oobPort = "1234"; + +class TestHcomServiceImp : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + HcomServiceImp *service; +}; + +void TestHcomServiceImp::SetUp() +{ + UBSHcomServiceOptions options{}; + service = new HcomServiceImp(UBSHcomNetDriverProtocol::RDMA, name, options); + service->mOptions.enableMultiRail = false; + service->mEnableMrCache = true; + ASSERT_NE(service, nullptr); +} + +void TestHcomServiceImp::TearDown() +{ + if (service != nullptr) { + delete service; + } + GlobalMockObject::verify(); +} + +int NewChannel(const std::string &ipPort, const UBSHcomChannelPtr &ch, const std::string &payload) +{ + return 0; +} + +TEST_F(TestHcomServiceImp, TestServiceBind) +{ + EXPECT_EQ(service->Bind("tcp://" + serviceIpInfo + ":" + oobPort, NewChannel), static_cast(SER_OK)); + EXPECT_EQ(service->Bind("uds://" + serviceUdsPath, NewChannel), static_cast(SER_OK)); + EXPECT_EQ(service->Bind("abc://" + serviceIpInfo, NewChannel), static_cast(SER_INVALID_PARAM)); + EXPECT_EQ(service->Bind("abc" + serviceIpInfo, NewChannel), static_cast(SER_INVALID_PARAM)); +} + +TEST_F(TestHcomServiceImp, TestServiceAddListener) +{ + EXPECT_NO_FATAL_FAILURE(service->AddListener("tcp://" + serviceIpInfo + ":" + oobPort, 1)); + EXPECT_NO_FATAL_FAILURE(service->AddListener("uds://" + serviceUdsPath, 1)); + EXPECT_NO_FATAL_FAILURE(service->AddListener("abc://" + serviceIpInfo, 1)); + EXPECT_NO_FATAL_FAILURE(service->AddListener("abc" + serviceIpInfo, 1)); +} + +TEST_F(TestHcomServiceImp, TestServiceAddTcpOobListener) +{ + EXPECT_EQ(service->AddTcpOobListener(serviceIpInfo + ":" + oobPort, 1), static_cast(SER_OK)); + EXPECT_EQ(service->AddTcpOobListener(serviceIpInfo + oobPort, 1), static_cast(SER_INVALID_PARAM)); + EXPECT_EQ(service->AddTcpOobListener("127.127.127.127.127:" + oobPort, 1), static_cast(SER_INVALID_PARAM)); + EXPECT_EQ(service->AddTcpOobListener(serviceIpInfo + ":" + oobPort, 1), static_cast(SER_INVALID_PARAM)); +} + +TEST_F(TestHcomServiceImp, TestServiceAddUdsOobListener) +{ + EXPECT_EQ(service->AddUdsOobListener(serviceUdsPath, 1), 0); + EXPECT_EQ(service->AddUdsOobListener(serviceUdsPath, 1), static_cast(SER_INVALID_PARAM)); +} + +TEST_F(TestHcomServiceImp, TestServiceStart) +{ + EXPECT_EQ(service->Start(), static_cast(SER_INVALID_PARAM)); + + service->RegisterChannelBrokenHandler([](const UBSHcomChannelPtr &channel) {}, + UBSHcomChannelBrokenPolicy::BROKEN_ALL); + service->RegisterRecvHandler([](const UBSHcomServiceContext &ctx) {return 0;}); + service->RegisterSendHandler([](const UBSHcomServiceContext &ctx) {return 0;}); + service->RegisterOneSideHandler([](const UBSHcomServiceContext &ctx) {return 0;}); + service->RegisterIdleHandler([](const UBSHcomNetWorkerIndex &ctx) {return 0;}); + EXPECT_EQ(service->Start(), static_cast(NN_INVALID_IP)); + + UBSHcomNetDriver *driver = new (std::nothrow) NetDriverRDMAWithOob(name, false, UBSHcomNetDriverProtocol::RDMA); + MOCKER_CPP(&UBSHcomNetDriver::IsStarted).stubs().will(returnValue(true)); + MOCKER_CPP_VIRTUAL(driver, &UBSHcomNetDriver::Initialize).stubs().will(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&UBSHcomNetDriver::Instance).stubs().will(returnValue(driver)); + EXPECT_EQ(service->Start(), static_cast(NN_ERROR)); +} + +TEST_F(TestHcomServiceImp, TestServiceStartSuccess) +{ + service->RegisterChannelBrokenHandler([](const UBSHcomChannelPtr &channel) {}, + UBSHcomChannelBrokenPolicy::BROKEN_ALL); + service->RegisterRecvHandler([](const UBSHcomServiceContext &ctx) {return 0;}); + service->RegisterSendHandler([](const UBSHcomServiceContext &ctx) {return 0;}); + service->RegisterOneSideHandler([](const UBSHcomServiceContext &ctx) {return 0;}); + service->RegisterIdleHandler([](const UBSHcomNetWorkerIndex &ctx) {return 0;}); + service->SetDeviceIpMask({serviceIpInfo}); + UBSHcomNetDriver *driver = new (std::nothrow) NetDriverRDMAWithOob(name, false, UBSHcomNetDriverProtocol::RDMA); + MOCKER_CPP(&UBSHcomNetDriver::Instance).stubs().will(returnValue(driver)); + MOCKER_CPP_VIRTUAL(driver, &UBSHcomNetDriver::Initialize).stubs().will(returnValue(static_cast(SER_OK))); + MOCKER_CPP_VIRTUAL(driver, &UBSHcomNetDriver::Start).stubs().will(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&UBSHcomNetDriver::IsStarted).stubs().will(returnValue(true)); + EXPECT_EQ(service->Start(), 0); + EXPECT_EQ(service->Start(), 0); +} + +TEST_F(TestHcomServiceImp, TestServiceStartFailed) +{ + service->RegisterChannelBrokenHandler([](const UBSHcomChannelPtr &channel) {}, + UBSHcomChannelBrokenPolicy::BROKEN_ALL); + service->RegisterRecvHandler([](const UBSHcomServiceContext &ctx) {return 0;}); + service->RegisterSendHandler([](const UBSHcomServiceContext &ctx) {return 0;}); + service->RegisterOneSideHandler([](const UBSHcomServiceContext &ctx) {return 0;}); + EXPECT_EQ(service->Start(), static_cast(NN_INVALID_IP)); + + MOCKER_CPP(&HcomServiceImp::CreatePeriodicMgr) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&HcomServiceImp::CreateCtxMemPool) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(returnValue(static_cast(SER_OK))); + EXPECT_EQ(service->Start(), static_cast(SER_NEW_OBJECT_FAILED)); + EXPECT_EQ(service->Start(), static_cast(SER_NEW_OBJECT_FAILED)); +} + +SerResult MockGetEnableDevCnt(std::string ipMask, uint16_t &enableDevCount, std::vector &enableIps, + std::string ipGroup) +{ + enableDevCount = 2; + return SER_OK; +} + +TEST_F(TestHcomServiceImp, TestServiceStartMultirail) +{ + service->RegisterChannelBrokenHandler([](const UBSHcomChannelPtr &channel) {}, + UBSHcomChannelBrokenPolicy::BROKEN_ALL); + service->RegisterRecvHandler([](const UBSHcomServiceContext &ctx) {return 0;}); + service->RegisterSendHandler([](const UBSHcomServiceContext &ctx) {return 0;}); + service->RegisterOneSideHandler([](const UBSHcomServiceContext &ctx) {return 0;}); + + UBSHcomNetDriver *driver = new (std::nothrow) NetDriverRDMAWithOob(name, false, UBSHcomNetDriverProtocol::RDMA); + MOCKER_CPP(&UBSHcomNetDriver::IsStarted).stubs().will(returnValue(true)); + MOCKER_CPP_VIRTUAL(driver, &UBSHcomNetDriver::Initialize).stubs().will(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&UBSHcomNetDriver::Instance).stubs().will(returnValue(driver)); + + service->mOptions.startOobSvr = true; + service->mOptions.enableMultiRail = true; + UBSHcomWorkerGroupInfo groupInfo; + groupInfo.cpuIdsRange = {1, 1}; + service->mOptions.workerGroupInfos = {{groupInfo}, {groupInfo}}; + + service->SetDeviceIpGroups({serviceIpInfo, serviceIpInfo}); + MOCKER_CPP_VIRTUAL(driver, &UBSHcomNetDriver::Start).stubs().will(returnValue(static_cast(SER_OK))); + + MOCKER_CPP(&RDMADeviceHelper::GetEnableDeviceCount) + .stubs() + .will(invoke(MockGetEnableDevCnt)); + EXPECT_EQ(service->Bind("tcp://" + serviceIpInfo + ":" + oobPort, NewChannel), 0); + EXPECT_EQ(service->Start(), 0); +} + +TEST_F(TestHcomServiceImp, TestServiceCreateOobListenersFailed) +{ + UBSHcomNetOobListenerOptions option {}; + UBSHcomNetDriverOptions opt {}; + opt.oobType = NetDriverOobType::NET_OOB_TCP; + for (int i = 0; i < NN_NO65536; i++) { + service->mOptions.oobOption[std::to_string(i)] = option; + } + EXPECT_EQ(service->CreateOobListeners(opt), static_cast(SER_INVALID_PARAM)); +} + +TEST_F(TestHcomServiceImp, TestServiceCreateOobUdsListenersFailed) +{ + UBSHcomNetOobUDSListenerOptions option {}; + UBSHcomNetDriverOptions opt {}; + opt.oobType = NetDriverOobType::NET_OOB_UDS; + for (int i = 0; i < NN_NO65536; i++) { + service->mOptions.udsOobOption[std::to_string(i)] = option; + } + EXPECT_EQ(service->CreateOobUdsListeners(opt), static_cast(SER_INVALID_PARAM)); +} + +TEST_F(TestHcomServiceImp, TestServiceCreateOobUdsListeners) +{ + UBSHcomNetOobUDSListenerOptions option {}; + EXPECT_EQ(option.Set(serviceUdsPath, 1), true); + + UBSHcomNetDriverOptions opt {}; + opt.oobType = NetDriverOobType::NET_OOB_UDS; + EXPECT_EQ(service->CreateOobListeners(opt), static_cast(SER_INVALID_PARAM)); + service->mOptions.udsOobOption[serviceUdsPath] = option; + EXPECT_EQ(service->CreateOobListeners(opt), static_cast(SER_OK)); +} + +TEST_F(TestHcomServiceImp, TestServiceDoDestroy) +{ + EXPECT_EQ(service->DoDestroy(name), static_cast(SER_OK)); +} + + +SerResult MockDoConnect(const std::string &serverUrl, SerConnInfo &opt, + const std::string &payLoad, UBSHcomChannelPtr &channel) +{ + channel = new (std::nothrow) HcomChannelImp(opt.channelId, false, opt.options); + return SER_OK; +} + +TEST_F(TestHcomServiceImp, TestServiceConnect) +{ + UBSHcomChannelPtr ch; + UBSHcomConnectOptions opt; + opt.linkCount = NN_NO1; + EXPECT_EQ(service->Connect("tcp://" + serviceIpInfo + ":" + oobPort, ch, opt), static_cast(SER_STOP)); + + MOCKER_CPP(&HcomServiceImp::DoConnect) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(invoke(MockDoConnect)); + service->mStarted = true; + EXPECT_EQ(service->Connect("tcp://" + serviceIpInfo + ":" + oobPort, ch, opt), static_cast(SER_INVALID_PARAM)); + + MOCKER_CPP(&HcomServiceImp::ExchangeTimestamp) + .stubs() + .will(returnValue(static_cast(SER_TIMEOUT))) + .then(returnValue(static_cast(SER_OK))); + EXPECT_EQ(service->Connect("tcp://" + serviceIpInfo + ":" + oobPort, ch, opt), static_cast(SER_TIMEOUT)); + EXPECT_EQ(service->Connect("tcp://" + serviceIpInfo + ":" + oobPort, ch, opt), static_cast(SER_OK)); +} + +TEST_F(TestHcomServiceImp, TestExchangeTimestamp) +{ + UBSHcomChannel *ch = nullptr; + EXPECT_EQ(service->ExchangeTimestamp(ch), static_cast(SER_ERROR)); + InnerConnectOptions opt{}; + ch = new (std::nothrow) HcomChannelImp(0, false, opt); + EXPECT_NE(ch, nullptr); + MOCKER_CPP(&HcomChannelImp::SyncCallInner).stubs().will(returnValue(static_cast(SER_OK))); + EXPECT_EQ(service->ExchangeTimestamp(ch), static_cast(SER_OK)); + if (ch != nullptr) { + delete ch; + ch = nullptr; + } +} + +TEST_F(TestHcomServiceImp, TestExchangeTimestamp2) +{ + InnerConnectOptions opt{}; + UBSHcomChannel *ch = new (std::nothrow) HcomChannelImp(0, false, opt); + EXPECT_NE(ch, nullptr); + MOCKER_CPP(&HcomChannelImp::SyncCallInner).stubs().will(returnValue(static_cast(SER_ERROR))); + EXPECT_EQ(service->ExchangeTimestamp(ch), static_cast(SER_ERROR)); + if (ch != nullptr) { + delete ch; + ch = nullptr; + } +} + +TEST_F(TestHcomServiceImp, TestServiceExchangeTimeStampHandle) +{ + UBSHcomServiceContext ctx{}; + ctx.mResult = SER_ERROR; + EXPECT_EQ(service->ServiceExchangeTimeStampHandle(ctx), static_cast(SER_ERROR)); + + ctx.mResult = SER_OK; + MOCKER_CPP(&UBSHcomServiceContext::MessageDataLen) + .stubs() + .will(returnValue(static_cast(sizeof(HcomExchangeTimestamp) - NN_NO1))) + .then(returnValue(static_cast(sizeof(HcomExchangeTimestamp)))); + EXPECT_EQ(service->ServiceExchangeTimeStampHandle(ctx), static_cast(SER_INVALID_PARAM)); + + HcomExchangeTimestamp timestamp{}; + timestamp.deltaTimeStamp = NN_NO0; + + HcomExchangeTimestamp timestamp2{}; + timestamp2.deltaTimeStamp = NN_NO1024; + MOCKER_CPP(&UBSHcomServiceContext::MessageData) + .stubs() + .will(returnValue(static_cast(×tamp))) + .then(returnValue(static_cast(×tamp2))); + EXPECT_EQ(service->ServiceExchangeTimeStampHandle(ctx), static_cast(SER_INVALID_PARAM)); + + UBSHcomChannelPtr ch1 = nullptr; + InnerConnectOptions opt{}; + UBSHcomChannelPtr ch2 = new (std::nothrow) HcomChannelImp(0, false, opt); + EXPECT_NE(ch2.Get(), nullptr); + MOCKER_CPP(&UBSHcomServiceContext::Channel).stubs().will(returnValue(ch1)).then(returnValue(ch2)); + EXPECT_EQ(service->ServiceExchangeTimeStampHandle(ctx), static_cast(SER_INVALID_PARAM)); + + MOCKER_CPP(&HcomChannelImp::ReplyInner).stubs().will(returnValue(static_cast(SER_OK))); + EXPECT_EQ(service->ServiceExchangeTimeStampHandle(ctx), static_cast(SER_OK)); +} + +SerResult MockDoConnectInner(const std::string &serverUrl, SerConnInfo &opt, + const std::string &payLoad, std::vector &epVector, uint32_t &totalBandWidth) +{ + UBSHcomNetWorkerIndex idx {}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(NN_NO100, nullptr, nullptr, idx); + epVector.push_back(ep); + return SER_OK; +} + +TEST_F(TestHcomServiceImp, TestServiceDoConnect) +{ + UBSHcomChannelPtr ch; + UBSHcomConnectOptions opt; + opt.linkCount = NN_NO1; + NetMemPoolFixedOptions options = {}; + service->mContextMemPool = new (std::nothrow) NetMemPoolFixed("ServiceContextTimer-test", options); + service->mPeriodicMgr = new (std::nothrow) HcomPeriodicManager(NN_NO1, "mName"); + service->mPgtable = new NetPgTable(HcomServiceImp::pgdAlloc, HcomServiceImp::pgdFree); + SerConnInfo connInfo(0, NetUuid::GenerateUuid(serviceIpInfo), NN_NO1, service->mOptions.chBrokenPolicy, opt); + MOCKER_CPP(&HcomServiceImp::DoConnectInner) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))) + .then(invoke(MockDoConnectInner)); + EXPECT_EQ(service->DoConnect("tcp://" + serviceIpInfo + ":" + oobPort, connInfo, "", ch), + static_cast(SER_INVALID_PARAM)); + EXPECT_EQ(service->DoConnect("tcp://" + serviceIpInfo + ":" + oobPort, connInfo, "", ch), + static_cast(SER_NEW_OBJECT_FAILED)); + EXPECT_EQ(service->DoConnect("tcp://" + serviceIpInfo + ":" + oobPort, connInfo, "", ch), + static_cast(SER_OK)); + service->mPeriodicMgr.Set(nullptr); + service->mContextMemPool.Set(nullptr); + service->mPgtable.Set(nullptr); + ch.Set(nullptr); +} + +TEST_F(TestHcomServiceImp, TestServiceDoConnectInner) +{ + NetDriverPtr driverPtr = new (std::nothrow) NetDriverRDMAWithOob(name, false, RDMA); + service->mDriverPtrs.push_back(driverPtr); + UBSHcomConnectOptions opt; + opt.linkCount = NN_NO1; + SerConnInfo connInfo(0, NetUuid::GenerateUuid(serviceIpInfo), NN_NO1, service->mOptions.chBrokenPolicy, opt); + std::vector epVector; + uint32_t bandWidth = 0; + MOCKER_CPP(&SerConnInfo::Serialize) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + EXPECT_EQ(service->DoConnectInner("tcp://" + serviceIpInfo + ":" + oobPort, connInfo, "", epVector, bandWidth), + static_cast(SER_INVALID_PARAM)); + EXPECT_EQ(service->DoConnectInner("tcp://" + serviceIpInfo + ":" + oobPort, connInfo, "", epVector, bandWidth), + static_cast(NN_NOT_INITIALIZED)); + MOCKER_CPP_VIRTUAL(*(service->mDriverPtrs[0].Get()), &UBSHcomNetDriver::Connect, + SerResult(UBSHcomNetDriver::*)(const std::string &, const std::string &, UBSHcomNetEndpointPtr &, + uint32_t, uint8_t, uint8_t, uint64_t)) + .stubs() + .will(returnValue(static_cast(SER_OK))); + EXPECT_EQ(service->DoConnectInner("tcp://" + serviceIpInfo + ":" + oobPort, connInfo, "", epVector, bandWidth), + static_cast(SER_OK)); +} + +TEST_F(TestHcomServiceImp, TestServiceDoChooseDriver) +{ + NetDriverPtr driverPtr = new (std::nothrow) NetDriverRDMAWithOob(name, false, RDMA); + service->mDriverPtrs.push_back(driverPtr); + int8_t selectDevIndex = 0; + uint8_t selectBandWidth = 0; + UBSHcomNetDriver *driver = nullptr; + EXPECT_NO_FATAL_FAILURE(service->DoChooseDriver(0, 0, selectDevIndex, selectBandWidth, driver)); +} + +TEST_F(TestHcomServiceImp, TestServiceChooseDriver) +{ + NetDriverPtr driverPtr = new (std::nothrow) NetDriverRDMAWithOob(name, false, RDMA); + service->mDriverPtrs.push_back(driverPtr); + UBSHcomNetDriver *driver = nullptr; + OOBTCPConnection conn(NN_NO6); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Receive) + .stubs() + .will(returnValue(static_cast(NN_PARAM_INVALID))) + .then(returnValue(static_cast(NN_OK))); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Send) + .stubs() + .will(returnValue(static_cast(NN_PARAM_INVALID))) + .then(returnValue(static_cast(NN_OK))); + MOCKER_CPP(&HcomServiceImp::DoChooseDriver) + .stubs(); + EXPECT_EQ(service->ChooseDriver(conn, driver), static_cast(NN_PARAM_INVALID)); + EXPECT_EQ(service->ChooseDriver(conn, driver), static_cast(SER_ERROR)); + driver = driverPtr.Get(); + EXPECT_EQ(service->ChooseDriver(conn, driver), static_cast(NN_PARAM_INVALID)); + EXPECT_EQ(service->ChooseDriver(conn, driver), static_cast(SER_OK)); + driver = nullptr; +} + +TEST_F(TestHcomServiceImp, TestServiceDisconnect) +{ + EXPECT_NO_FATAL_FAILURE(service->Disconnect(nullptr)); + InnerConnectOptions opt {}; + UBSHcomChannelPtr ch = new (std::nothrow) HcomChannelImp(0, false, opt); + EXPECT_NO_FATAL_FAILURE(service->Disconnect(ch)); +} + +TEST_F(TestHcomServiceImp, TestServiceRegisterMemoryRegion) +{ + UBSHcomRegMemoryRegion mr {}; + EXPECT_EQ(service->RegisterMemoryRegion(NN_NO1024, mr), static_cast(NN_ERROR)); + + NetDriverPtr driverPtr = new (std::nothrow) NetDriverRDMAWithOob(name, false, RDMA); + service->mDriverPtrs.push_back(driverPtr); + MOCKER_CPP_VIRTUAL(*(service->mDriverPtrs[0].Get()), &UBSHcomNetDriver::CreateMemoryRegion, + SerResult(UBSHcomNetDriver::*)(uint64_t, UBSHcomNetMemoryRegionPtr &)) + .stubs() + .will(returnValue(static_cast(NN_ERROR))); + EXPECT_EQ(service->RegisterMemoryRegion(NN_NO1024, mr), static_cast(NN_ERROR)); +} + +NResult MockCreateMemoryRegion(uintptr_t address, uint64_t size, UBSHcomNetMemoryRegionPtr &mr) +{ + mr = new (std::nothrow) RDMAMemoryRegion(name, nullptr, address, size); + return NN_OK; +} + +NResult MockCreateMemoryRegion2(uint64_t size, UBSHcomNetMemoryRegionPtr &mr) +{ + mr = new (std::nothrow) RDMAMemoryRegion(name, nullptr, size); + return NN_OK; +} + +TEST_F(TestHcomServiceImp, TestServiceRegisterMemoryRegion2) +{ + UBSHcomRegMemoryRegion mr {}; + NetDriverPtr driverPtr = new (std::nothrow) NetDriverRDMAWithOob(name, false, RDMA); + service->mDriverPtrs.push_back(driverPtr); + MOCKER_CPP_VIRTUAL(*(service->mDriverPtrs[0].Get()), &UBSHcomNetDriver::CreateMemoryRegion, + SerResult(UBSHcomNetDriver::*)(uint64_t, UBSHcomNetMemoryRegionPtr &)) + .stubs() + .will(invoke(MockCreateMemoryRegion2)); + + MOCKER_CPP(&PgTable::Insert).stubs().will(returnValue(0)).then(returnValue(static_cast(NN_ERROR))); + EXPECT_EQ(service->RegisterMemoryRegion(NN_NO1024, mr), static_cast(NN_OK)); + EXPECT_EQ(service->RegisterMemoryRegion(NN_NO1024, mr), static_cast(NN_ERROR)); +} + +TEST_F(TestHcomServiceImp, TestServiceRegisterMemoryRegion3) +{ + UBSHcomRegMemoryRegion mr {}; + uintptr_t addr = 0; + EXPECT_EQ(service->RegisterMemoryRegion(addr, NN_NO1024, mr), static_cast(NN_ERROR)); + + NetDriverPtr driverPtr = new (std::nothrow) NetDriverRDMAWithOob(name, false, RDMA); + service->mDriverPtrs.push_back(driverPtr); + MOCKER_CPP_VIRTUAL(*(service->mDriverPtrs[0].Get()), &UBSHcomNetDriver::CreateMemoryRegion, + SerResult(UBSHcomNetDriver::*)(uintptr_t, uint64_t, UBSHcomNetMemoryRegionPtr &)) + .stubs() + .will(returnValue(static_cast(NN_ERROR))); + EXPECT_EQ(service->RegisterMemoryRegion(addr, NN_NO1024, mr), static_cast(NN_ERROR)); +} + +TEST_F(TestHcomServiceImp, TestServiceRegisterMemoryRegion4) +{ + UBSHcomRegMemoryRegion mr {}; + uintptr_t addr = 0; + NetDriverPtr driverPtr = new (std::nothrow) NetDriverRDMAWithOob(name, false, RDMA); + service->mDriverPtrs.push_back(driverPtr); + MOCKER_CPP_VIRTUAL(*(service->mDriverPtrs[0].Get()), &UBSHcomNetDriver::CreateMemoryRegion, + SerResult(UBSHcomNetDriver::*)(uintptr_t, uint64_t, UBSHcomNetMemoryRegionPtr &)) + .stubs() + .will(invoke(MockCreateMemoryRegion)); + + MOCKER_CPP(&PgTable::Insert).stubs().will(returnValue(0)).then(returnValue(static_cast(NN_ERROR))); + EXPECT_EQ(service->RegisterMemoryRegion(addr, NN_NO1024, mr), static_cast(NN_OK)); + EXPECT_EQ(service->RegisterMemoryRegion(addr, NN_NO1024, mr), static_cast(NN_ERROR)); +} + +TEST_F(TestHcomServiceImp, TestServiceDestroyMemoryRegion) +{ + UBSHcomRegMemoryRegion mr {}; + EXPECT_NO_FATAL_FAILURE(service->DestroyMemoryRegion(mr)); + mr.mHcomMrs.resize(1); + EXPECT_NO_FATAL_FAILURE(service->DestroyMemoryRegion(mr)); + NetDriverPtr driverPtr = new (std::nothrow) NetDriverRDMAWithOob(name, false, RDMA); + service->mDriverPtrs.push_back(driverPtr); + MOCKER_CPP_VIRTUAL(*(service->mDriverPtrs[0].Get()), &UBSHcomNetDriver::DestroyMemoryRegion) + .stubs(); + + mr.mHcomMrs.resize(0); + UBSHcomMemoryRegionPtr mrPtr = new (std::nothrow) RDMAMemoryRegion(name, nullptr, 0, 0); + PgtRegion *pgtRegion = new PgtRegion(); + mrPtr->mPgRegion = reinterpret_cast(pgtRegion); + mr.mHcomMrs.emplace_back(mrPtr); + EXPECT_NO_FATAL_FAILURE(service->DestroyMemoryRegion(mr)); +} + +TEST_F(TestHcomServiceImp, TestServiceSetOptions) +{ + std::pair cpuIdsPair; + EXPECT_NO_FATAL_FAILURE(service->AddWorkerGroup(0, 0, cpuIdsPair)); + EXPECT_NO_FATAL_FAILURE(service->AddWorkerGroup(0, 0, cpuIdsPair, 0, NN_NO4)); + UBSHcomServiceLBPolicy lbPolicy {}; + EXPECT_NO_FATAL_FAILURE(service->SetConnectLBPolicy(lbPolicy)); + UBSHcomTlsOptions opt {}; + EXPECT_NO_FATAL_FAILURE(service->SetTlsOptions(opt)); + UBSHcomConnSecureOptions secureOpt {}; + EXPECT_NO_FATAL_FAILURE(service->SetConnSecureOpt(secureOpt)); + uint16_t timeOutSec = 0; + EXPECT_NO_FATAL_FAILURE(service->SetTcpUserTimeOutSec(timeOutSec)); +} + +TEST_F(TestHcomServiceImp, TestServiceSetOptions2) +{ + bool tcpSendZCopy = false; + EXPECT_NO_FATAL_FAILURE(service->SetTcpSendZCopy(tcpSendZCopy)); + uint16_t depth = 0; + EXPECT_NO_FATAL_FAILURE(service->SetCompletionQueueDepth(depth)); + uint32_t sqSize = 0; + EXPECT_NO_FATAL_FAILURE(service->SetSendQueueSize(sqSize)); + uint32_t rqSize = 0; + EXPECT_NO_FATAL_FAILURE(service->SetRecvQueueSize(rqSize)); + uint32_t prePostSize = 10; + EXPECT_NO_FATAL_FAILURE(service->SetQueuePrePostSize(prePostSize)); + uint16_t pollSize = 0; + EXPECT_NO_FATAL_FAILURE(service->SetPollingBatchSize(pollSize)); +} + +TEST_F(TestHcomServiceImp, TestServiceSetOptions3) +{ + uint16_t pollTimeout = 0; + EXPECT_NO_FATAL_FAILURE(service->SetEventPollingTimeOutUs(pollTimeout)); + uint32_t threadNum = 0; + EXPECT_NO_FATAL_FAILURE(service->SetTimeOutDetectionThreadNum(threadNum)); + uint32_t maxConnCount = 0; + EXPECT_NO_FATAL_FAILURE(service->SetMaxConnectionCount(maxConnCount)); + UBSHcomHeartBeatOptions opt {}; + EXPECT_NO_FATAL_FAILURE(service->SetHeartBeatOptions(opt)); + UBSHcomMultiRailOptions multiRailOpt {}; + EXPECT_NO_FATAL_FAILURE(service->SetMultiRailOptions(multiRailOpt)); +} + +TEST_F(TestHcomServiceImp, TestServiceGenerateUuid) +{ + std::string uuid; + MOCKER_CPP(&HcomServiceImp::GetIpAddressByIpPort) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + MOCKER(&BuffToHexString) + .stubs() + .will(returnValue(false)) + .then(returnValue(true)) + .then(returnValue(false)) + .then(returnValue(true)); + EXPECT_EQ(service->GenerateUuid(serviceIpInfo, NN_NO1, uuid), static_cast(SER_INVALID_PARAM)); + EXPECT_EQ(service->GenerateUuid(serviceIpInfo, NN_NO1, uuid), static_cast(SER_ERROR)); + EXPECT_EQ(service->GenerateUuid(serviceIpInfo, NN_NO1, uuid), static_cast(SER_OK)); + EXPECT_EQ(service->GenerateUuid(NN_NO123, NN_NO1, uuid), static_cast(SER_ERROR)); + EXPECT_EQ(service->GenerateUuid(NN_NO123, NN_NO1, uuid), static_cast(SER_OK)); +} + +TEST_F(TestHcomServiceImp, TestServiceEmplaceNewEndpoint) +{ + UBSHcomNetWorkerIndex idx {}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(NN_NO100, nullptr, nullptr, idx); + ConnectingEpInfoPtr epInfo {}; + SerConnInfo connInfo {}; + connInfo.totalLinkCount = 1; + connInfo.options.linkCount = 1; + EXPECT_EQ(service->EmplaceNewEndpoint(ep, epInfo, connInfo, name), static_cast(SER_OK)); + connInfo.index = NN_NO1; + EXPECT_EQ(service->EmplaceNewEndpoint(ep, epInfo, connInfo, name), static_cast(SER_OK)); + connInfo.totalLinkCount = NN_NO17; + EXPECT_EQ(service->EmplaceNewEndpoint(ep, epInfo, connInfo, name), static_cast(SER_INVALID_PARAM)); + + service->mNewEpMap.insert(std::make_pair(name, epInfo)); + MOCKER_CPP(&HcomConnectingEpInfo::Compare) + .stubs() + .will(returnValue(false)) + .then(returnValue(true)); + EXPECT_EQ(service->EmplaceNewEndpoint(ep, epInfo, connInfo, name), static_cast(SER_INVALID_PARAM)); + MOCKER_CPP(&HcomConnectingEpInfo::AddEp) + .stubs() + .will(returnValue(false)); + EXPECT_EQ(service->EmplaceNewEndpoint(ep, epInfo, connInfo, name), + static_cast(SER_EP_BROKEN_DURING_CONNECTING)); +} + +TEST_F(TestHcomServiceImp, TestServiceEmplaceNewEndpointInvalidLinkCount) +{ + UBSHcomNetWorkerIndex idx {}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(NN_NO100, nullptr, nullptr, idx); + ConnectingEpInfoPtr epInfo {}; + SerConnInfo connInfo {}; + connInfo.totalLinkCount = 1; + connInfo.options.linkCount = 0; + EXPECT_EQ(service->EmplaceNewEndpoint(ep, epInfo, connInfo, name), static_cast(SER_INVALID_PARAM)); + connInfo.options.linkCount = NN_NO20; + EXPECT_EQ(service->EmplaceNewEndpoint(ep, epInfo, connInfo, name), static_cast(SER_INVALID_PARAM)); +} + +TEST_F(TestHcomServiceImp, TestServiceEmplaceNewEndpointInvalidTotalLinkCount) +{ + UBSHcomNetWorkerIndex idx {}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(NN_NO100, nullptr, nullptr, idx); + ConnectingEpInfoPtr epInfo {}; + SerConnInfo connInfo {}; + connInfo.totalLinkCount = NN_NO0; + connInfo.options.linkCount = 1; + EXPECT_EQ(service->EmplaceNewEndpoint(ep, epInfo, connInfo, name), static_cast(SER_INVALID_PARAM)); + connInfo.totalLinkCount = NN_NO100; + EXPECT_EQ(service->EmplaceNewEndpoint(ep, epInfo, connInfo, name), static_cast(SER_INVALID_PARAM)); +} + +TEST_F(TestHcomServiceImp, TestServiceServiceHandleNewEndPoint) +{ + EXPECT_EQ(service->ServiceHandleNewEndPoint(serviceIpInfo, nullptr, ""), static_cast(SER_INVALID_PARAM)); + UBSHcomNetWorkerIndex idx {}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(NN_NO100, nullptr, nullptr, idx); + + UBSHcomConnectOptions opt; + SerConnInfo connInfo(0, NetUuid::GenerateUuid(serviceIpInfo), NN_NO1, service->mOptions.chBrokenPolicy, opt); + connInfo.options.linkCount = 1; + MOCKER_CPP(&SerConnInfo::Deserialize) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + EXPECT_EQ(service->ServiceHandleNewEndPoint(serviceIpInfo, ep, ""), static_cast(SER_INVALID_PARAM)); + MOCKER_CPP(&HcomServiceImp::GenerateUuid, + SerResult(HcomServiceImp::*)(const std::string &, uint64_t, std::string &)) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + EXPECT_EQ(service->ServiceHandleNewEndPoint(serviceIpInfo, ep, ""), static_cast(SER_INVALID_PARAM)); +} + +TEST_F(TestHcomServiceImp, TestServiceServiceNewChannel) +{ + UBSHcomNetWorkerIndex idx {}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(NN_NO100, nullptr, nullptr, idx); + std::vector epVector; + UBSHcomConnectOptions opt; + SerConnInfo connInfo(0, NetUuid::GenerateUuid(serviceIpInfo), NN_NO1, service->mOptions.chBrokenPolicy, opt); + connInfo.options.linkCount = 1; + EXPECT_EQ(service->ServiceNewChannel(serviceIpInfo, connInfo, "", epVector), + static_cast(SER_NEW_OBJECT_FAILED)); + MOCKER_CPP(&SerConnInfo::Deserialize) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + NetMemPoolFixedOptions options = {}; + service->mContextMemPool = new (std::nothrow) NetMemPoolFixed("ServiceContextTimer-test", options); + service->mPeriodicMgr = new (std::nothrow) HcomPeriodicManager(NN_NO1, name); + service->mPgtable = new NetPgTable(HcomServiceImp::pgdAlloc, HcomServiceImp::pgdFree); + epVector.push_back(ep); + MOCKER_CPP(&HcomServiceImp::GenerateUuid, + SerResult(HcomServiceImp::*)(const std::string &, uint64_t, std::string &)) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + EXPECT_EQ(service->ServiceNewChannel(serviceIpInfo, connInfo, "", epVector), + static_cast(SER_INVALID_PARAM)); + EXPECT_EQ(service->ServiceNewChannel(serviceIpInfo, connInfo, "", epVector), + static_cast(SER_INVALID_PARAM)); + service->mOptions.chNewHandler = [](const std::string &ipPort, const UBSHcomChannelPtr &, + const std::string &payload) { + return SER_OK; + }; + EXPECT_EQ(service->ServiceNewChannel(serviceIpInfo, connInfo, "", epVector), + static_cast(SER_OK)); + service->mPeriodicMgr.Set(nullptr); + service->mContextMemPool.Set(nullptr); +} + +TEST_F(TestHcomServiceImp, TestServiceDelayEraseChannel) +{ + InnerConnectOptions opt {}; + HcomChannelImp *ch = new (std::nothrow) HcomChannelImp(0, false, opt); + ch->mCtxStore = new (std::nothrow) HcomServiceCtxStore(NN_NO2097152, nullptr, UBSHcomNetDriverProtocol::RDMA); + UBSHcomChannelPtr chPtr = ch; + NetMemPoolFixedOptions options = {}; + service->mContextMemPool = new (std::nothrow) NetMemPoolFixed("ServiceContextTimer-test", options); + service->mPeriodicMgr = new (std::nothrow) HcomPeriodicManager(NN_NO1, name); + HcomServiceTimer *timer = new HcomServiceTimer(); + MOCKER_CPP(&HcomServiceCtxStore::GetCtxObj) + .stubs() + .will(returnValue(timer)); + MOCKER_CPP(&HcomServiceCtxStore::PutAndGetSeqNo) + .stubs() + .will(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&HcomServiceCtxStore::Return) + .stubs(); + MOCKER_CPP(&HcomPeriodicManager::AddTimer).stubs().will(returnValue(static_cast(SER_OK))); + + EXPECT_EQ(service->DelayEraseChannel(chPtr, 0), static_cast(SER_OK)); + delete timer; + service->mPeriodicMgr.Set(nullptr); + service->mContextMemPool.Set(nullptr); +} + +TEST_F(TestHcomServiceImp, TestServiceEraseChannel) +{ + InnerConnectOptions opt {}; + UBSHcomChannel *ch = new (std::nothrow) HcomChannelImp(0, false, opt); + EXPECT_NO_FATAL_FAILURE(service->EraseChannel(reinterpret_cast(ch))); +} + +TEST_F(TestHcomServiceImp, TestServiceServiceEndPointBroken) +{ + EXPECT_NO_FATAL_FAILURE(service->ServiceEndPointBroken(nullptr)); + UBSHcomNetWorkerIndex idx {}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(NN_NO100, nullptr, nullptr, idx); + EXPECT_NO_FATAL_FAILURE(service->ServiceEndPointBroken(ep)); +} + +TEST_F(TestHcomServiceImp, TestServiceEndPointBrokenFail2) +{ + MOCKER_CPP(&HcomConnectingEpInfo::AllEPBroken).stubs().will(returnValue(false)).then(returnValue(true)); + UBSHcomNetWorkerIndex workerIndex{}; + workerIndex.Set(NN_NO6, NN_NO4, NN_NO8); + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(NN_NO8, nullptr, nullptr, workerIndex); + ConnectingEpInfoPtr epInfo = new (std::nothrow) HcomConnectingEpInfo(); + Ep2ChanUpCtx ctx(NN_NO0, reinterpret_cast(epInfo.Get()), NN_NO4); + ep->UpCtx(ctx.wholeUpCtx); + EXPECT_NO_FATAL_FAILURE(service->ServiceEndPointBroken(ep)); + EXPECT_NO_FATAL_FAILURE(service->ServiceEndPointBroken(ep)); + + InnerConnectOptions opt {}; + UBSHcomChannelPtr ch = new (std::nothrow) HcomChannelImp(0, false, opt); + Ep2ChanUpCtx ctx1(NN_NO1, reinterpret_cast(ch.Get()), NN_NO4); + ep->UpCtx(ctx1.wholeUpCtx); + MOCKER_CPP_VIRTUAL(*(ch.Get()), &UBSHcomChannel::AllEpBroken) + .stubs() + .will(returnValue(false)) + .then(returnValue(true)); + EXPECT_NO_FATAL_FAILURE(service->ServiceEndPointBroken(ep)); + MOCKER_CPP_VIRTUAL(*(ch.Get()), &UBSHcomChannel::NeedProcessBroken) + .stubs() + .will(returnValue(false)) + .then(returnValue(true)); + EXPECT_NO_FATAL_FAILURE(service->ServiceEndPointBroken(ep)); + MOCKER_CPP_VIRTUAL(*(ch.Get()), &UBSHcomChannel::SetChannelState).stubs().will(returnValue(true)); + MOCKER_CPP_VIRTUAL(*(ch.Get()), &UBSHcomChannel::ProcessIoInBroken).stubs(); + MOCKER_CPP_VIRTUAL(*(ch.Get()), &UBSHcomChannel::InvokeChannelBrokenCb).stubs(); + MOCKER_CPP_VIRTUAL(*(ch.Get()), &UBSHcomChannel::GetDelayEraseTime) + .stubs() + .will(returnValue(static_cast(0))); + MOCKER_CPP(&HcomServiceImp::DelayEraseChannel).stubs().will(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&HcomServiceImp::EraseChannel).stubs(); + EXPECT_NO_FATAL_FAILURE(service->ServiceEndPointBroken(ep)); +} + +TEST_F(TestHcomServiceImp, TestServiceServiceRequestReceived) +{ + InnerConnectOptions opt {}; + UBSHcomChannelPtr ch = new (std::nothrow) HcomChannelImp(0, false, opt); + Ep2ChanUpCtx ep2ChUpCtx(NN_NO1, reinterpret_cast(ch.Get()), NN_NO0); + UBSHcomNetWorkerIndex workerIndex{}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(NN_NO8, nullptr, nullptr, workerIndex); + ep->UpCtx(ep2ChUpCtx.wholeUpCtx); + UBSHcomNetRequestContext ctx{}; + ctx.mEp = ep; + ctx.mHeader.opCode = NN_NO8192; + EXPECT_EQ(service->ServiceRequestReceived(ctx), static_cast(SER_ERROR)); + Ep2ChanUpCtx ep2ChUpCtx1(NN_NO1, reinterpret_cast(nullptr), NN_NO0); + ep->UpCtx(ep2ChUpCtx.wholeUpCtx); + ctx.mEp = ep; + EXPECT_EQ(service->ServiceRequestReceived(ctx), static_cast(SER_ERROR)); + ctx.mHeader.opCode = NN_NO1; + service->mOptions.recvHandler = [](const UBSHcomServiceContext &ctx) {return 0;}; + EXPECT_EQ(service->ServiceRequestReceived(ctx), static_cast(SER_OK)); + + MOCKER_CPP(&HcomServiceCtxStore::GetSeqNoAndRemove) + .stubs() + .will(returnValue(static_cast(SER_STORE_SEQ_NO_FOUND))); + MOCKER_CPP(&HcomSeqNo::IsResp).stubs().will(returnValue(true)); + HcomServiceCtxStore *store = new (std::nothrow) HcomServiceCtxStore(NN_NO2097152, nullptr, + UBSHcomNetDriverProtocol::RDMA); + ASSERT_NE(store, nullptr); + MOCKER_CPP_VIRTUAL(*(ch.Get()), &UBSHcomChannel::GetCtxStore) + .stubs() + .will(returnValue(store)); + ep->UpCtx(ep2ChUpCtx.wholeUpCtx); + ctx.mEp = ep; + EXPECT_EQ(service->ServiceRequestReceived(ctx), static_cast(SER_ERROR)); + if (store != nullptr) { + delete store; + } +} + +TEST_F(TestHcomServiceImp, TestServiceServiceRequestReceivedSplit) +{ + InnerConnectOptions opt {}; + UBSHcomChannelPtr ch = new (std::nothrow) HcomChannelImp(0, false, opt); + Ep2ChanUpCtx ep2ChUpCtx(NN_NO1, reinterpret_cast(ch.Get()), NN_NO0); + UBSHcomNetWorkerIndex workerIndex{}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(NN_NO8, nullptr, nullptr, workerIndex); + UBSHcomNetRequestContext ctx{}; + ep->UpCtx(ep2ChUpCtx.wholeUpCtx); + ctx.mEp = ep; + ctx.mHeader.opCode = NN_NO1; + ctx.extHeaderType = UBSHcomExtHeaderType::RAW; + service->mOptions.recvHandler = [](const UBSHcomServiceContext &ctx) {return 0;}; + EXPECT_EQ(service->ServiceRequestReceived(ctx), static_cast(SER_OK)); + + ctx.extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + SpliceMessageResultType result = SpliceMessageResultType::INDETERMINATE; + SerResult code = SER_OK; + std::string out = ""; + auto tmp = std::make_tuple(result, code, out); + MOCKER_CPP_VIRTUAL(*(ch.Get()), &UBSHcomChannel::SpliceMessage) + .stubs() + .will(returnValue(tmp)); + EXPECT_EQ(service->ServiceRequestReceived(ctx), static_cast(SER_OK)); +} + +TEST_F(TestHcomServiceImp, TestServiceRunRequestCallback) +{ + UBSHcomNetRequestContext ctx{}; + ctx.mOpType = UBSHcomNetRequestContext::NN_INVALID_OP_TYPE; + UBSHcomServiceContext context{}; + EXPECT_EQ(service->RunRequestCallback(nullptr, ctx, context), false); + + ctx.mOpType = UBSHcomNetRequestContext::NN_SENT; + Callback *newCallback = UBSHcomNewCallback([](UBSHcomServiceContext &context) {}, std::placeholders::_1); + SerTransContext upCtx {}; + upCtx.callback = newCallback; + memcpy_s(ctx.mOriginalReq.upCtxData, NN_NO16, reinterpret_cast(&upCtx), NN_NO16); + EXPECT_EQ(service->RunRequestCallback(nullptr, ctx, context), true); + + upCtx.callback = nullptr; + memcpy_s(ctx.mOriginalReq.upCtxData, NN_NO16, reinterpret_cast(&upCtx), NN_NO16); + InnerConnectOptions opt {}; + UBSHcomChannelPtr ch = new (std::nothrow) HcomChannelImp(0, false, opt); + ASSERT_NE(ch.Get(), nullptr); + HcomServiceCtxStore *store = new (std::nothrow) HcomServiceCtxStore(NN_NO2097152, nullptr, + UBSHcomNetDriverProtocol::RDMA); + ASSERT_NE(store, nullptr); + MOCKER_CPP_VIRTUAL(*(ch.Get()), &UBSHcomChannel::GetCtxStore) + .stubs() + .will(returnValue(store)); + MOCKER_CPP(&HcomServiceCtxStore::GetSeqNoAndRemove) + .stubs() + .will(returnValue(static_cast(SER_STORE_SEQ_NO_FOUND))); + EXPECT_EQ(service->RunRequestCallback(ch.Get(), ctx, context), false); +} + +TEST_F(TestHcomServiceImp, TestServiceServiceRequestPosted) +{ + InnerConnectOptions opt {}; + UBSHcomChannelPtr ch = new (std::nothrow) HcomChannelImp(0, false, opt); + Ep2ChanUpCtx ep2ChUpCtx(NN_NO1, reinterpret_cast(ch.Get()), NN_NO0); + UBSHcomNetWorkerIndex workerIndex{}; + MOCKER_CPP_VIRTUAL(*(ch.Get()), &UBSHcomChannel::GetCallBackType) + .stubs() + .will(returnValue(UBSHcomChannelCallBackType::CHANNEL_GLOBAL_CB)); + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(NN_NO8, nullptr, nullptr, workerIndex); + ep->UpCtx(ep2ChUpCtx.wholeUpCtx); + UBSHcomNetRequestContext ctx{}; + ctx.mEp = ep; + ctx.mOpType = UBSHcomNetRequestContext::NN_INVALID_OP_TYPE; + EXPECT_EQ(service->ServiceRequestPosted(ctx), static_cast(SER_ERROR)); + service->mOptions.sendHandler = [](const UBSHcomServiceContext &ctx) {return 0;}; + EXPECT_EQ(service->ServiceRequestPosted(ctx), static_cast(SER_OK)); + MOCKER_CPP(&HcomServiceImp::RunRequestCallback) + .stubs() + .will(returnValue(true)); + EXPECT_EQ(service->ServiceRequestPosted(ctx), static_cast(SER_OK)); +} + +TEST_F(TestHcomServiceImp, TestServiceServiceOneSideDone) +{ + InnerConnectOptions opt {}; + UBSHcomChannelPtr ch = new (std::nothrow) HcomChannelImp(0, false, opt); + Ep2ChanUpCtx ep2ChUpCtx(NN_NO1, reinterpret_cast(ch.Get()), NN_NO0); + UBSHcomNetWorkerIndex workerIndex{}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(NN_NO8, nullptr, nullptr, workerIndex); + ep->UpCtx(ep2ChUpCtx.wholeUpCtx); + UBSHcomNetRequestContext ctx{}; + ctx.mEp = ep; + ctx.mOpType = UBSHcomNetRequestContext::NN_INVALID_OP_TYPE; + MOCKER_CPP_VIRTUAL(*(ch.Get()), &UBSHcomChannel::GetCallBackType) + .stubs() + .will(returnValue(UBSHcomChannelCallBackType::CHANNEL_GLOBAL_CB)); + EXPECT_EQ(service->ServiceOneSideDone(ctx), static_cast(SER_ERROR)); + service->mOptions.oneSideDoneHandler = [](const UBSHcomServiceContext &ctx) {return 0;}; + EXPECT_EQ(service->ServiceOneSideDone(ctx), static_cast(SER_OK)); + MOCKER_CPP(&HcomServiceImp::RunRequestCallback) + .stubs() + .will(returnValue(true)); + EXPECT_EQ(service->ServiceOneSideDone(ctx), static_cast(SER_OK)); +} + +TEST_F(TestHcomServiceImp, TestServiceServiceSecInfoProvider) +{ + int64_t flag = 0; + UBSHcomNetDriverSecType type; + char *output = nullptr; + uint32_t outLen = 0; + bool needAutoFree = false; + EXPECT_EQ(service->ServiceSecInfoProvider(0, flag, type, output, outLen, needAutoFree), + static_cast(SER_ERROR)); + service->mOptions.connSecOption.provider = [](uint64_t ctx, int64_t &flag, UBSHcomNetDriverSecType &type, + char *&output, uint32_t &outLen, bool &needAutoFree) {return 0;}; + MOCKER_CPP(&ConnectingSecInfo::Initialize).stubs(); + EXPECT_EQ(service->ServiceSecInfoProvider(0, flag, type, output, outLen, needAutoFree), static_cast(SER_OK)); +} + +TEST_F(TestHcomServiceImp, TestServiceServiceSecInfoValidator) +{ + uint64_t ctx = 0; + int64_t flag = 0; + char *input = nullptr; + uint32_t inputLen = 0; + EXPECT_EQ(service->ServiceSecInfoValidator(ctx, flag, input, inputLen), static_cast(SER_ERROR)); + service->mOptions.connSecOption.validator = [](uint64_t ctx, int64_t flag, const char *input, uint32_t inputLen) { + return 0; + }; + MOCKER_CPP(&ConnectingSecInfo::Initialize).stubs(); + EXPECT_EQ(service->ServiceSecInfoValidator(ctx, flag, input, inputLen), static_cast(SER_OK)); +} + +TEST_F(TestHcomServiceImp, TestServiceProtocol) +{ + EXPECT_NO_FATAL_FAILURE(service->Protocol()); +} + +TEST_F(TestHcomServiceImp, TestServiceGetIpAddressByIpPort) +{ + service->mOptions.protocol = SHM; + uint32_t ip; + EXPECT_EQ(service->GetIpAddressByIpPort(serviceIpInfo, ip), static_cast(SER_OK)); + service->mOptions.protocol = RDMA; + EXPECT_EQ(service->GetIpAddressByIpPort(serviceIpInfo, ip), static_cast(SER_INVALID_PARAM)); +} + +TEST_F(TestHcomServiceImp, TestServiceRegisterDriverCb) +{ + service->mOptions.tlsOption.enableTls = true; + EXPECT_NO_FATAL_FAILURE(service->RegisterDriverCb()); +} + +TEST_F(TestHcomServiceImp, TestServiceServicePrivateOpHandle) +{ + UBSHcomServiceContext context{}; + InnerConnectOptions opt {}; + context.mCh = new (std::nothrow) HcomChannelImp(0, false, opt); + EXPECT_EQ(service->ServicePrivateOpHandle(context), static_cast(SER_ERROR)); + context.mCh.Set(nullptr); +} + +TEST_F(TestHcomServiceImp, TestServiceAddTimerCtx) +{ + SerTimerListHeader header {}; + HcomServiceTimer timer {}; + EXPECT_NO_FATAL_FAILURE(header.AddTimerCtx(&timer)); + EXPECT_NO_FATAL_FAILURE(header.RemoveTimerCtx(&timer)); +} + +TEST_F(TestHcomServiceImp, TestServiceSetServiceTransCtx) +{ + SerTransContext ctx {}; + char *ctxData = reinterpret_cast(&ctx); + EXPECT_NO_FATAL_FAILURE(SetServiceTransCtx(ctxData, 1)); + EXPECT_NO_FATAL_FAILURE(SetServiceTransCtx(ctxData, nullptr)); + EXPECT_NO_FATAL_FAILURE(SetServiceTransCtx(ctxData, 0, false)); +} + +TEST_F(TestHcomServiceImp, TestServiceSetMaxSendRecvDataCount) +{ + EXPECT_NO_FATAL_FAILURE(service->SetMaxSendRecvDataCount(1)); +} + +TEST_F(TestHcomServiceImp, TestServiceConnectFailed) +{ + UBSHcomChannelPtr ch; + UBSHcomConnectOptions opt; + opt.linkCount = NN_NO1; + + MOCKER_CPP(&HcomServiceImp::DoConnect) + .stubs() + .will(invoke(MockDoConnect)); + service->mStarted = true; + MOCKER_CPP(&HcomServiceImp::ExchangeTimestamp) + .stubs() + .will(returnValue(static_cast(SER_TIMEOUT))) + .then(returnValue(static_cast(SER_OK))); + EXPECT_EQ(service->Connect("tcp://" + serviceIpInfo + ":" + oobPort, ch, opt), static_cast(SER_TIMEOUT)); + + MOCKER_CPP(&HcomServiceImp::EmplaceChannelUuid) + .stubs() + .will(returnValue(static_cast(SER_ERROR))) + .then(returnValue(static_cast(SER_OK))); + EXPECT_EQ(service->Connect("tcp://" + serviceIpInfo + ":" + oobPort, ch, opt), + static_cast(SER_CHANNEL_ID_DUP)); +} +} +} diff --git a/test/unit_test/service_v2/test_hcom_service_v2.cpp b/test/unit_test/service_v2/test_hcom_service_v2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..597097f88486e3239cd5cf9a4935f265dc1a97c3 --- /dev/null +++ b/test/unit_test/service_v2/test_hcom_service_v2.cpp @@ -0,0 +1,398 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include + +#include "hcom.h" +#include "service_channel_imp.h" +#include "service_callback.h" +#include "net_rdma_async_endpoint.h" +#include "under_api/urma/urma_api_wrapper.h" + +namespace ock { +namespace hcom { + +class TestHcomServiceV2 : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); +}; + +void TestHcomServiceV2::SetUp() +{} + +void TestHcomServiceV2::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestHcomServiceV2, TestHcomServiceV2Create) +{ + UBSHcomServiceOptions options{}; + options.maxSendRecvDataSize = 0; + UBSHcomService *service = UBSHcomService::Create(UBSHcomNetDriverProtocol::RDMA, "client1", options); + EXPECT_EQ(service, nullptr); + options.maxSendRecvDataSize = NN_NO1024; + std::string longName(NN_NO64 + 1, 'a'); + service = UBSHcomService::Create(UBSHcomNetDriverProtocol::RDMA, longName, options); + EXPECT_EQ(service, nullptr); + service = UBSHcomService::Create(UBSHcomNetDriverProtocol::RDMA, "client1", options); + EXPECT_NE(service, nullptr); + UBSHcomService *service1 = UBSHcomService::Create(UBSHcomNetDriverProtocol::RDMA, "client1", options); + EXPECT_NE(service1, nullptr); + EXPECT_EQ(service, service1); + + EXPECT_EQ(UBSHcomService::Destroy("client1"), SER_OK); +} + +TEST_F(TestHcomServiceV2, TestHcomServiceV2Destroy) +{ + EXPECT_EQ(UBSHcomService::Destroy("client1"), SER_ERROR); + UBSHcomServiceOptions options{}; + UBSHcomService *service = new HcomServiceImp(UBSHcomNetDriverProtocol::RDMA, "client1", options); + MOCKER_CPP_VIRTUAL(*service, &UBSHcomService::DoDestroy).stubs().will(returnValue(static_cast(SER_ERROR))); + EXPECT_EQ(UBSHcomService::Destroy("client1"), SER_ERROR); + delete service; +} + +TEST_F(TestHcomServiceV2, TestHcomServiceTimer) +{ + HcomServiceTimer *timer = new HcomServiceTimer(); + EXPECT_NE(timer, nullptr); + EXPECT_EQ(timer->SeqNo(), 0); + EXPECT_EQ(timer->Timeout(), 0); + EXPECT_EQ(timer->Callback(), 0); + EXPECT_NO_FATAL_FAILURE(timer->TimeoutDump()); + + UBSHcomServiceContext ctx; + EXPECT_NO_FATAL_FAILURE(timer->RunCallBack(ctx)); + EXPECT_NO_FATAL_FAILURE(timer->DeleteCallBack()); + delete timer; +} + +TEST_F(TestHcomServiceV2, TestHcomServiceTimer2) +{ + HcomServiceTimer *timer = new HcomServiceTimer(); + EXPECT_NE(timer, nullptr); + EXPECT_EQ(timer->IsFinished(), false); + EXPECT_NO_FATAL_FAILURE(timer->MarkFinished()); + EXPECT_NO_FATAL_FAILURE(timer->MarkTimeout()); + EXPECT_EQ(timer->IsTimeOut(), false); + + timer->mTimeout = NN_NO10; + ASSERT_EQ(timer->IsTimeOut(), true); + delete timer; +} + +TEST_F(TestHcomServiceV2, TestHcomServiceTimerFail) +{ + HcomServiceTimer *timer = new HcomServiceTimer(); + ASSERT_NE(timer, nullptr); + EXPECT_NO_FATAL_FAILURE(timer->EraseSeqNo()); + EXPECT_EQ(timer->EraseSeqNoWithRet(), false); + timer->mCtxStore = new (std::nothrow) HcomServiceCtxStore(NN_NO2097152, nullptr, UBSHcomNetDriverProtocol::RDMA); + ASSERT_NE(timer->mCtxStore, nullptr); + MOCKER_CPP(&HcomServiceCtxStore::GetSeqNoAndRemove) + .stubs() + .will(returnValue(static_cast(SER_STORE_SEQ_NO_FOUND))); + EXPECT_NO_FATAL_FAILURE(timer->EraseSeqNo()); + EXPECT_EQ(timer->EraseSeqNoWithRet(), false); + delete timer; +} + +TEST_F(TestHcomServiceV2, TestHcomServiceTimerCompare) +{ + HcomServiceTimer *timer1 = new HcomServiceTimer(); + EXPECT_NE(timer1, nullptr); + HcomServiceTimer *timer2 = new HcomServiceTimer(); + EXPECT_NE(timer2, nullptr); + + HcomServiceTimerCompare compare; + timer1->mTimeout = 2000; + timer2->mTimeout = 1000; + EXPECT_TRUE(compare(timer1, timer2)); + timer1->mTimeout = 2000; + timer2->mTimeout = 2000; + timer1->mSeqNo = 2; + timer2->mSeqNo = 1; + EXPECT_TRUE(compare(timer1, timer2)); + timer1->mTimeout = 1000; + timer2->mTimeout = 2000; + EXPECT_FALSE(compare(timer1, timer2)); + + delete timer2; + delete timer1; +} + +TEST_F(TestHcomServiceV2, TestHexStringToBuff) +{ + std::string input = "1A2B3C4D"; + uint8_t buff[NN_NO4] = {0}; + EXPECT_TRUE(HexStringToBuff(input, NN_NO4, buff)); + EXPECT_EQ(buff[NN_NO0], 0x1A); + EXPECT_EQ(buff[NN_NO1], 0x2B); + EXPECT_EQ(buff[NN_NO2], 0x3C); + EXPECT_EQ(buff[NN_NO3], 0x4D); +} + +TEST_F(TestHcomServiceV2, TestHexStringToBuff2) +{ + uint8_t *buff = nullptr; + EXPECT_FALSE(HexStringToBuff("1A2B3C4D", NN_NO4, buff)); + + uint8_t buff1[NN_NO4] = {0}; + EXPECT_FALSE(HexStringToBuff("1A2B3C", NN_NO4, buff1)); + uint8_t buff2[NN_NO4] = {0}; + EXPECT_FALSE(HexStringToBuff("1A2B3C5", NN_NO4, buff2)); + + uint8_t buff3[NN_NO4] = {0}; + std::string invalidInput = "1G"; // 'G' 不是有效的十六进制字符 + EXPECT_FALSE(HexStringToBuff(invalidInput, NN_NO4, buff3)); +} + +TEST_F(TestHcomServiceV2, TestBuffToHexString) +{ + uint8_t *buff = nullptr; + uint32_t bufferSize = 10; + std::string output; + EXPECT_FALSE(BuffToHexString(buff, bufferSize, output)); + EXPECT_TRUE(output.empty()); +} + +TEST_F(TestHcomServiceV2, TestSerialize) +{ + SerConnInfo connInfo; + std::string payload = "1A2B3C4D"; + std::string out; + EXPECT_EQ(SerConnInfo::Serialize(connInfo, payload, out), SER_OK); +} + +TEST_F(TestHcomServiceV2, TestSerializeFail) +{ + SerConnInfo *connInfo = nullptr; + std::string payload = "TestPayload"; + std::string out; + EXPECT_EQ(SerConnInfo::Serialize(*connInfo, payload, out), SER_ERROR); + EXPECT_TRUE(out.empty()); +} + +TEST_F(TestHcomServiceV2, TestDeserialize) +{ + SerConnInfo connInfo; + std::string payload = "00000000000000000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000" + "0000000000000000DEADBEEF"; // 大于sizeof(SerConnInfo)*2 + std::string userPayload; + + MOCKER_CPP(&SerConnInfo::Validate).stubs().will(returnValue(true)); + EXPECT_EQ(SerConnInfo::Deserialize(payload, connInfo, userPayload), NN_OK); +} + +TEST_F(TestHcomServiceV2, TestDeserializeFail) +{ + SerConnInfo connInfo; + std::string payload1 = "1A2B3C4D"; + std::string userPayload; + EXPECT_EQ(SerConnInfo::Deserialize(payload1, connInfo, userPayload), SER_INVALID_PARAM); + + std::string payload2 = "1A2B"; // 长度不足 sizeof(SerConnInfo) * 2 + EXPECT_EQ(SerConnInfo::Deserialize(payload2, connInfo, userPayload), SER_INVALID_PARAM); + + std::string payload3 = "1A2B3C4DGG"; // 包含无效字符 'GG' + EXPECT_EQ(SerConnInfo::Deserialize(payload3, connInfo, userPayload), SER_INVALID_PARAM); + + std::string payload4 = "00000000FFFFFFFF0000000000000000"; // CRC 校验失败 + EXPECT_EQ(SerConnInfo::Deserialize(payload4, connInfo, userPayload), SER_INVALID_PARAM); +} + +TEST_F(TestHcomServiceV2, TestHcomServiceGlobalObjectInitialize) +{ + EXPECT_EQ(HcomServiceGlobalObject::Initialize(), SER_OK); + EXPECT_TRUE(HcomServiceGlobalObject::gInited); + EXPECT_EQ(HcomServiceGlobalObject::Initialize(), SER_OK); + EXPECT_NE(HcomServiceGlobalObject::gEmptyCallback, nullptr); + HcomServiceGlobalObject::UnInitialize(); + EXPECT_EQ(HcomServiceGlobalObject::gEmptyCallback, nullptr); +} + +TEST_F(TestHcomServiceV2, TestHcomServiceGlobalObjectBuildCtx) +{ + UBSHcomServiceContext ctx; + EXPECT_NO_FATAL_FAILURE(HcomServiceGlobalObject::BuildBrokenCtx(ctx)); +} + +TEST_F(TestHcomServiceV2, TestHcomConnectingEpInfoAllEPBroken) +{ + UBSHcomNetWorkerIndex workerIndex{}; + uint32_t workerIdx = NN_NO4; + uint32_t gIdx = NN_NO6; + uint16_t dIdx = NN_NO8; + workerIndex.Set(workerIdx, gIdx, dIdx); + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(NN_NO100, nullptr, nullptr, workerIndex); + SerConnInfo info{}; + std::string id = "123"; + HcomConnectingEpInfo *epInfo = new (std::nothrow) HcomConnectingEpInfo(id, ep, info); + bool ret = epInfo->AllEPBroken(NN_NO6); + ASSERT_EQ(ret, false); + if (epInfo != nullptr) { + delete epInfo; + epInfo = nullptr; + } +} + +TEST_F(TestHcomServiceV2, TestHcomConnectingEpInfoCompareFail) +{ + HcomConnectingEpInfo *connectChannelInfo = new (std::nothrow) HcomConnectingEpInfo(); + SerConnInfo info; + bool ret; + + // Fail 1 + connectChannelInfo->mConnInfo.version = NN_NO1; + info.version = NN_NO3; + ret = connectChannelInfo->Compare(info); + ASSERT_EQ(ret, false); + + // Fail 2 + info.version = NN_NO1; + connectChannelInfo->mConnInfo.channelId = NN_NO1; + info.channelId = NN_NO3; + ret = connectChannelInfo->Compare(info); + ASSERT_EQ(ret, false); + + // Fail 3 + info.channelId = NN_NO1; + connectChannelInfo->mConnInfo.policy = UBSHcomChannelBrokenPolicy::BROKEN_ALL; + info.policy = UBSHcomChannelBrokenPolicy::RECONNECT; + ret = connectChannelInfo->Compare(info); + ASSERT_EQ(ret, false); + + // Fail 4 + info.policy = UBSHcomChannelBrokenPolicy::BROKEN_ALL; + info.index = NN_NO1; + ret = connectChannelInfo->Compare(info); + ASSERT_EQ(ret, false); + + // Fail 5 + info.index = NN_NO0; + connectChannelInfo->mConnInfo.options.linkCount = NN_NO1; + info.options.linkCount = NN_NO3; + ret = connectChannelInfo->Compare(info); + ASSERT_EQ(ret, false); + + // Fail 6 + info.options.linkCount = NN_NO1; + connectChannelInfo->mConnInfo.options.cbType = UBSHcomChannelCallBackType::CHANNEL_FUNC_CB; + info.options.cbType = UBSHcomChannelCallBackType::CHANNEL_GLOBAL_CB; + ret = connectChannelInfo->Compare(info); + ASSERT_EQ(ret, false); + + // Fail 7 + info.options.cbType = UBSHcomChannelCallBackType::CHANNEL_FUNC_CB; + connectChannelInfo->mConnInfo.options.clientGroupId = NN_NO1; + info.options.clientGroupId = NN_NO3; + ret = connectChannelInfo->Compare(info); + ASSERT_EQ(ret, false); + + // Fail 8 + info.options.clientGroupId = NN_NO1; + connectChannelInfo->mConnInfo.options.serverGroupId = NN_NO1; + info.options.serverGroupId = NN_NO3; + ret = connectChannelInfo->Compare(info); + ASSERT_EQ(ret, false); + + // Success + info.options.serverGroupId = NN_NO1; + ret = connectChannelInfo->Compare(info); + ASSERT_EQ(ret, true); + + if (connectChannelInfo != nullptr) { + delete connectChannelInfo; + connectChannelInfo = nullptr; + } +} + +TEST_F(TestHcomServiceV2, TestHcomConnectingEpInfoCompare) +{ + HcomConnectingEpInfo *connectChannelInfo = new (std::nothrow) HcomConnectingEpInfo(); + SerConnInfo info; + bool ret; + + connectChannelInfo->mConnInfo.version = NN_NO1; + connectChannelInfo->mConnInfo.channelId = NN_NO1; + connectChannelInfo->mConnInfo.policy = UBSHcomChannelBrokenPolicy::BROKEN_ALL; + connectChannelInfo->mConnInfo.options.linkCount = NN_NO1; + connectChannelInfo->mConnInfo.options.cbType = UBSHcomChannelCallBackType::CHANNEL_FUNC_CB; + connectChannelInfo->mConnInfo.options.clientGroupId = NN_NO1; + connectChannelInfo->mConnInfo.options.serverGroupId = NN_NO1; + info.version = NN_NO1; + info.channelId = NN_NO1; + info.policy = UBSHcomChannelBrokenPolicy::BROKEN_ALL; + info.index = NN_NO0; + info.options.linkCount = NN_NO1; + info.options.cbType = UBSHcomChannelCallBackType::CHANNEL_FUNC_CB; + info.options.clientGroupId = NN_NO1; + info.options.serverGroupId = NN_NO1; + ret = connectChannelInfo->Compare(info); + ASSERT_EQ(ret, true); + + if (connectChannelInfo != nullptr) { + delete connectChannelInfo; + connectChannelInfo = nullptr; + } +} + +TEST_F(TestHcomServiceV2, TestMemoryRegion) +{ + UBSHcomRegMemoryRegion region{}; + UBSHcomMemoryKey key{}; + EXPECT_NO_FATAL_FAILURE(region.GetMemoryKey(key)); + EXPECT_EQ(region.GetAddress(), 0); + EXPECT_EQ(region.GetSize(), 0); +} + +TEST_F(TestHcomServiceV2, TestGetServiceTransNeedPostedCall) +{ + SerTransContext ctxData {}; + char *ctx = reinterpret_cast(&ctxData); + EXPECT_EQ(GetServiceTransNeedPostedCall(ctx), true); +} + +TEST_F(TestHcomServiceV2, TestIsNeedInvokeCallback) +{ + UBSHcomRequestContext ctx{}; + MOCKER_CPP(&GetServiceTransNeedPostedCall) + .stubs() + .will(returnValue(true)); + EXPECT_EQ(IsNeedInvokeCallback(ctx), false); + ctx.mResult = 1; + ctx.mOpType = UBSHcomRequestContext::NN_SENT; + EXPECT_EQ(IsNeedInvokeCallback(ctx), true); + ctx.mOpType = UBSHcomRequestContext::NN_SENT_RAW_SGL; + EXPECT_EQ(IsNeedInvokeCallback(ctx), true); + ctx.mOpType = UBSHcomRequestContext::NN_INVALID_OP_TYPE; + EXPECT_EQ(IsNeedInvokeCallback(ctx), true); +} + +TEST_F(TestHcomServiceV2, TestSetTraceIdInner) +{ +#ifdef UB_BUILD_ENABLED + MOCKER(HcomUrma::IsLoaded).stubs().will(returnValue(false)).then(returnValue(true)); + std::string traceId = "This is a test trace id"; + + EXPECT_NO_FATAL_FAILURE(SetTraceIdInner(traceId)); + EXPECT_NO_FATAL_FAILURE(SetTraceIdInner(traceId)); +#endif +} +} // namespace hcom +} // namespace ock diff --git a/test/unit_test/service_v2/test_net_channel_imp.cpp b/test/unit_test/service_v2/test_net_channel_imp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3c202571b9a51aa816762c10b0e575256b0723df --- /dev/null +++ b/test/unit_test/service_v2/test_net_channel_imp.cpp @@ -0,0 +1,1263 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include + +#include "hcom.h" +#include "service_channel_imp.h" +#include "net_rdma_async_endpoint.h" +#include "under_api/urma/urma_api_wrapper.h" + +namespace ock { +namespace hcom { +class TestNetChannelImp : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + +private: + UBSHcomService *service = nullptr; + HcomChannelImp *channel = nullptr; + char *data = nullptr; + int32_t dataSize = 1024; + std::vector epVector; + NetMemPoolFixedPtr ctxMemPool = nullptr; + HcomServiceCtxStorePtr mCtxStore = nullptr; + HcomPeriodicManagerPtr mPeriodicMgr = nullptr; + NetPgTablePtr mPgtable = nullptr; + UBSHcomNetWorkerIndex workerIndex{}; + UBSHcomNetEndpointPtr ep = nullptr; + NetMemPoolFixedOptions options = {}; + UBSHcomFlowCtrlOptions ctrlOptions{}; +}; + +void TestNetChannelImp::SetUp() +{ + uint64_t id = NN_NO60; + bool selfPoll = true; + InnerConnectOptions connectOptions{}; + channel = new HcomChannelImp(id, selfPoll, connectOptions); + ASSERT_NE(channel, nullptr); + channel->SetChannelTimeOut(0, 0); + + data = new char[dataSize]; + ASSERT_NE(data, nullptr); + + ctxMemPool = new (std::nothrow) NetMemPoolFixed("test", options); + ASSERT_NE(ctxMemPool, nullptr); + mCtxStore = new (std::nothrow) HcomServiceCtxStore(1, ctxMemPool, UBSHcomNetDriverProtocol::RDMA); + ASSERT_NE(mCtxStore, nullptr); + mPgtable = new NetPgTable(HcomServiceImp::pgdAlloc, HcomServiceImp::pgdFree); + ASSERT_NE(mPgtable, nullptr); + + mPeriodicMgr = new (std::nothrow) HcomPeriodicManager(1, "mOptions.name"); + epVector.reserve(1); + workerIndex.Set(NN_NO4, NN_NO6, NN_NO8); + ep = new (std::nothrow) NetAsyncEndpoint(NN_NO100, nullptr, nullptr, workerIndex); + epVector.emplace_back(ep); +} + +void TestNetChannelImp::TearDown() +{ + if (data != nullptr) { + delete[] data; + data = nullptr; + } + + if (channel != nullptr) { + delete channel; + channel = nullptr; + } + + GlobalMockObject::verify(); +} + +TEST_F(TestNetChannelImp, TestSendFail) +{ + UBSHcomRequest req(data, sizeof(data), 0); + MOCKER_CPP(&HcomChannelImp::FlowControl) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->Send(req, nullptr), SER_INVALID_PARAM); + + MOCKER_CPP(&HcomChannelImp::SendInner) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(returnValue(static_cast(SER_ERROR))); + ASSERT_EQ(channel->Send(req, nullptr), SER_NEW_OBJECT_FAILED); + + ASSERT_EQ(channel->Send(req, nullptr), SER_ERROR); +} + +TEST_F(TestNetChannelImp, TestSendOK) +{ + UBSHcomRequest req(data, sizeof(data), 0); + MOCKER_CPP(&HcomChannelImp::FlowControl).stubs().will(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&HcomChannelImp::SendInner).stubs().will(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->Send(req, nullptr), SER_OK); +} + +TEST_F(TestNetChannelImp, TestSendInner) +{ + UBSHcomRequest req(data, sizeof(data), 0); + ASSERT_EQ(channel->SendInner(req, nullptr), SER_NOT_ESTABLISHED); + + Callback *callback = UBSHcomNewCallback([] + (UBSHcomServiceContext &context) { ASSERT_EQ(context.Result(), 0); }, std::placeholders::_1); + ASSERT_EQ(channel->SendInner(req, callback), SER_INVALID_PARAM); + + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + channel->SetChannelState(CH_ESTABLISHED); + ASSERT_EQ(channel->SendInner(req, nullptr), NN_EP_NOT_ESTABLISHED); +} + +TEST_F(TestNetChannelImp, TestCallFail) +{ + UBSHcomRequest req(data, dataSize, 1); + UBSHcomResponse rsp(data, dataSize); + MOCKER_CPP(&HcomChannelImp::FlowControl) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->Call(req, rsp, nullptr), SER_INVALID_PARAM); + + MOCKER_CPP(&HcomChannelImp::CallInner) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(returnValue(static_cast(SER_ERROR))); + ASSERT_EQ(channel->Call(req, rsp, nullptr), SER_NEW_OBJECT_FAILED); + + ASSERT_EQ(channel->Call(req, rsp, nullptr), SER_ERROR); +} + +TEST_F(TestNetChannelImp, TestCallOK) +{ + UBSHcomRequest req(data, dataSize, 1); + UBSHcomResponse rsp(data, dataSize); + MOCKER_CPP(&HcomChannelImp::FlowControl).stubs().will(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&HcomChannelImp::CallInner).stubs().will(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->Call(req, rsp, nullptr), SER_OK); +} + +TEST_F(TestNetChannelImp, TestCallInner) +{ + UBSHcomRequest req(data, dataSize, 1); + UBSHcomResponse rsp(data, dataSize); + ASSERT_EQ(channel->CallInner(req, rsp, nullptr), SER_NOT_ESTABLISHED); + + int32_t ret = 0; + sem_t sem; + sem_init(&sem, 0, 0); + Callback *callback = UBSHcomNewCallback( + [&sem, &ret, &rsp](UBSHcomServiceContext &context) { + ASSERT_EQ(context.Result(), 0); + memcpy_s(rsp.address, rsp.size, context.MessageData(), context.MessageDataLen()); + sem_post(&sem); + }, + std::placeholders::_1); + ASSERT_EQ(channel->CallInner(req, rsp, callback), SER_INVALID_PARAM); + + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + channel->SetChannelState(CH_ESTABLISHED); + ASSERT_EQ(channel->CallInner(req, rsp, nullptr), NN_EP_NOT_ESTABLISHED); +} + +TEST_F(TestNetChannelImp, TestReplyFail) +{ + UBSHcomReplyContext ctx; + UBSHcomRequest req(data, dataSize, 0); + MOCKER_CPP(&HcomChannelImp::FlowControl) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->Reply(ctx, req, nullptr), SER_INVALID_PARAM); + + MOCKER_CPP(&HcomChannelImp::ReplyInner) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(returnValue(static_cast(SER_ERROR))); + ASSERT_EQ(channel->Reply(ctx, req, nullptr), SER_NEW_OBJECT_FAILED); + + ASSERT_EQ(channel->Reply(ctx, req, nullptr), SER_ERROR); +} + +TEST_F(TestNetChannelImp, TestReplyOK) +{ + UBSHcomReplyContext ctx; + UBSHcomRequest req(data, dataSize, 0); + MOCKER_CPP(&HcomChannelImp::FlowControl).stubs().will(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&HcomChannelImp::ReplyInner).stubs().will(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->Reply(ctx, req, nullptr), SER_OK); +} + +TEST_F(TestNetChannelImp, TestReplyInner) +{ + UBSHcomReplyContext ctx; + UBSHcomRequest req(data, dataSize, 0); + ASSERT_EQ(channel->ReplyInner(ctx, req, nullptr), SER_NOT_ESTABLISHED); + + Callback *callback = UBSHcomNewCallback([] + (UBSHcomServiceContext &context) { ASSERT_EQ(context.Result(), 0); }, std::placeholders::_1); + ASSERT_EQ(channel->ReplyInner(ctx, req, callback), SER_NOT_ESTABLISHED); + + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + channel->SetChannelState(CH_ESTABLISHED); + ASSERT_EQ(channel->ReplyInner(ctx, req, nullptr), SER_NEW_OBJECT_FAILED); +} + +TEST_F(TestNetChannelImp, TestPutFail) +{ + UBSHcomOneSideRequest req{}; + req.lAddress = reinterpret_cast(data); + req.size = dataSize; + + MOCKER_CPP(&HcomChannelImp::FlowControl) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->Put(req, nullptr), SER_INVALID_PARAM); + + MOCKER_CPP(&HcomChannelImp::OneSideInner) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(returnValue(static_cast(SER_ERROR))); + ASSERT_EQ(channel->Put(req, nullptr), SER_NEW_OBJECT_FAILED); + + ASSERT_EQ(channel->Put(req, nullptr), SER_ERROR); +} + +TEST_F(TestNetChannelImp, TestPutOK) +{ + UBSHcomOneSideRequest req{}; + req.lAddress = reinterpret_cast(data); + req.size = dataSize; + + MOCKER_CPP(&HcomChannelImp::FlowControl).stubs().will(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&HcomChannelImp::OneSideInner).stubs().will(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->Put(req, nullptr), SER_OK); +} + +TEST_F(TestNetChannelImp, TestGetFail) +{ + UBSHcomOneSideRequest req{}; + req.lAddress = reinterpret_cast(data); + req.size = dataSize; + + MOCKER_CPP(&HcomChannelImp::FlowControl) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->Get(req, nullptr), SER_INVALID_PARAM); + + MOCKER_CPP(&HcomChannelImp::OneSideInner) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(returnValue(static_cast(SER_ERROR))); + ASSERT_EQ(channel->Get(req, nullptr), SER_NEW_OBJECT_FAILED); + + ASSERT_EQ(channel->Get(req, nullptr), SER_ERROR); +} + +TEST_F(TestNetChannelImp, TestGetOK) +{ + UBSHcomOneSideRequest req{}; + req.lAddress = reinterpret_cast(data); + req.size = dataSize; + + MOCKER_CPP(&HcomChannelImp::FlowControl).stubs().will(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&HcomChannelImp::OneSideInner).stubs().will(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->Get(req, nullptr), SER_OK); +} + +TEST_F(TestNetChannelImp, TestOneSideInner) +{ + UBSHcomOneSideRequest req{}; + req.lAddress = reinterpret_cast(data); + req.size = dataSize; + ASSERT_EQ(channel->OneSideInner(req, nullptr, true), SER_NOT_ESTABLISHED); + + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + channel->SetChannelState(CH_ESTABLISHED); + ASSERT_EQ(channel->OneSideInner(req, nullptr, true), NN_EP_NOT_ESTABLISHED); +} + +TEST_F(TestNetChannelImp, TestInitialize) +{ + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + + ASSERT_EQ(channel->ToString(), "Connect channel id " + std::to_string(NN_NO60) + " with 1 eps :[100]"); + ASSERT_EQ(channel->SetFlowControlConfig(ctrlOptions), SER_OK); + channel->SetChannelTimeOut(1, 1); + channel->UnInitialize(); +} + +TEST_F(TestNetChannelImp, TestInitializeFail) +{ + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), 0, 0), SER_INVALID_PARAM); +} + +TEST_F(TestNetChannelImp, TestOthers) +{ + channel->SetUuid("1"); + ASSERT_EQ(channel->GetUuid(), "1"); + ASSERT_EQ(channel->GetId(), NN_NO60); + ASSERT_EQ(channel->GetTimerList(), 0); + ASSERT_EQ(channel->GetDelayEraseTime(), NN_NO1); + channel->mOptions.brokenPolicy = UBSHcomChannelBrokenPolicy::RECONNECT; + ASSERT_EQ(channel->GetDelayEraseTime(), NN_NO60); +} + +TEST_F(TestNetChannelImp, TestOthers1) +{ + ASSERT_EQ(channel->GetCtxStore(), nullptr); + ASSERT_EQ(channel->GetCallBackType(), UBSHcomChannelCallBackType::CHANNEL_FUNC_CB); +} +TEST_F(TestNetChannelImp, TestNextWorkerPollEp) +{ + UBSHcomNetEndpoint *nextEp = nullptr; + ASSERT_EQ(channel->NextWorkerPollEp(nextEp, 0), SER_NOT_ESTABLISHED); + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + ASSERT_EQ(channel->NextWorkerPollEp(nextEp, 0), SER_OK); +} + +TEST_F(TestNetChannelImp, TestPrepareTimerCtx) +{ + HcomServiceTimer *serviceTimer = new HcomServiceTimer(); + MOCKER_CPP(&HcomServiceCtxStore::GetCtxObj).stubs().will(returnValue(serviceTimer)); + MOCKER_CPP(&HcomServiceCtxStore::PutAndGetSeqNo) + .stubs() + .will(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&HcomServiceCtxStore::Return).stubs(); + MOCKER_CPP(&HcomPeriodicManager::AddTimer).stubs().will(returnValue(static_cast(SER_OK))); + + TimerCtx TimerCtx {}; + ASSERT_EQ(channel->PrepareTimerContext(nullptr, 0, TimerCtx), 0); + delete serviceTimer; +} + +TEST_F(TestNetChannelImp, TestDestroyTimerCtx) +{ + HcomServiceTimer *timer = new HcomServiceTimer(); + TimerCtx TimerCtx {}; + TimerCtx.timer = timer; + TimerCtx.timer->IncreaseRef(); + MOCKER_CPP(&HcomServiceTimer::EraseSeqNoWithRet).stubs().will(returnValue(false)).then(returnValue(true)); + EXPECT_NO_FATAL_FAILURE(channel->DestroyTimerContext(TimerCtx)); + EXPECT_NO_FATAL_FAILURE(channel->DestroyTimerContext(TimerCtx)); + delete timer; +} + +SerResult MockPrepareTimerCtx(Callback *cb, int16_t timeout, TimerCtx &context) +{ + if (cb != nullptr) { + UBSHcomServiceContext ctx {}; + ctx.mResult = SER_OK; + cb->Run(ctx); + } + return SER_OK; +} + +TEST_F(TestNetChannelImp, TestSyncSendInner) +{ + channel->mOptions.selfPoll = false; + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + MOCKER_CPP(&HcomChannelImp::PrepareTimerContext) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(invoke(MockPrepareTimerCtx)); + + UBSHcomRequest req(data, sizeof(data), 0); + ASSERT_EQ(channel->SyncSendInner(req), SER_NEW_OBJECT_FAILED); + + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &)) + .stubs() + .will(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->SyncSendInner(req), SER_OK); + + MOCKER_CPP(&HcomChannelImp::RndvInner).stubs().will(returnValue(static_cast(SER_OK))); + channel->mRndvThreshold = NN_NO10; + ASSERT_EQ(channel->SyncSendInner(req), SER_OK); +} + +TEST_F(TestNetChannelImp, TestAsyncSendInner) +{ + channel->mOptions.selfPoll = false; + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + MOCKER_CPP(&HcomChannelImp::DestroyTimerContext).stubs(); + MOCKER_CPP(&HcomChannelImp::PrepareTimerContext) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(returnValue(static_cast(SER_OK))); + UBSHcomRequest req(data, sizeof(data), 0); + ASSERT_EQ(channel->AsyncSendInner(req, nullptr), SER_NEW_OBJECT_FAILED); + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &)) + .stubs() + .will(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->AsyncSendInner(req, nullptr), SER_OK); + + MOCKER_CPP(&HcomChannelImp::RndvInner).stubs().will(returnValue(static_cast(SER_OK))); + channel->mRndvThreshold = NN_NO10; + ASSERT_EQ(channel->AsyncSendInner(req, nullptr), SER_OK); +} + +TEST_F(TestNetChannelImp, TestSyncSendWithSelfPoll) +{ + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &)) + .stubs() + .will(returnValue(static_cast(SER_OK))); + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::WaitCompletion, SerResult(UBSHcomNetEndpoint::*)(int32_t)) + .stubs() + .will(returnValue(static_cast(SER_OK))); + channel->mOptions.selfPoll = true; + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + UBSHcomRequest req(data, sizeof(data), 0); + ASSERT_EQ(channel->SyncSendWithSelfPoll(req), SER_OK); +} + +void MockSyncCallCbForWorkerPoll(UBSHcomServiceContext &context, UBSHcomResponse *rsp, + HcomServiceSelfSyncParam *syncParam) +{ + syncParam->Result(SER_OK); + syncParam->Signal(); +} + +TEST_F(TestNetChannelImp, TestSyncCallInner) +{ + channel->mOptions.selfPoll = false; + UBSHcomResponse rsp{}; + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + MOCKER_CPP(&HcomChannelImp::PrepareTimerContext) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(invoke(MockPrepareTimerCtx)); + + UBSHcomRequest req(data, sizeof(data), 0); + ASSERT_EQ(channel->SyncCallInner(req, rsp), SER_NEW_OBJECT_FAILED); + + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &)) + .stubs() + .will(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->SyncCallInner(req, rsp), SER_INVALID_PARAM); + + MOCKER_CPP(&HcomChannelImp::RndvInner).stubs().will(returnValue(static_cast(SER_OK))); + channel->mRndvThreshold = NN_NO10; + ASSERT_EQ(channel->SyncCallInner(req, rsp), SER_INVALID_PARAM); +} + +TEST_F(TestNetChannelImp, TestSyncCallWithSelfPoll) +{ + channel->mOptions.selfPoll = true; + UBSHcomResponse rsp{}; + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + UBSHcomRequest req(data, sizeof(data), 0); + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &)) + .stubs() + .will(returnValue(static_cast(SER_OK))); + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::WaitCompletion, SerResult(UBSHcomNetEndpoint::*)(int32_t)) + .stubs() + .will(returnValue(static_cast(SER_OK))); + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::Receive, + SerResult(UBSHcomNetEndpoint::*)(int32_t, UBSHcomNetResponseContext &)) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))); + ASSERT_EQ(channel->SyncCallWithSelfPoll(req, rsp), SER_INVALID_PARAM); +} + +TEST_F(TestNetChannelImp, TestAsyncCallInner) +{ + channel->mOptions.selfPoll = false; + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + MOCKER_CPP(&HcomChannelImp::PrepareTimerContext) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&HcomChannelImp::DestroyTimerContext).stubs(); + UBSHcomRequest req(data, sizeof(data), 0); + ASSERT_EQ(channel->AsyncCallInner(req, nullptr), SER_NEW_OBJECT_FAILED); + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &)) + .stubs() + .will(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->AsyncCallInner(req, nullptr), SER_OK); + + MOCKER_CPP(&HcomChannelImp::RndvInner).stubs().will(returnValue(static_cast(SER_OK))); + channel->mRndvThreshold = NN_NO10; + ASSERT_EQ(channel->AsyncCallInner(req, nullptr), SER_OK); +} + +TEST_F(TestNetChannelImp, TestPrepareCallback) +{ + MOCKER_CPP(&HcomChannelImp::PrepareTimerContext) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(invoke(MockPrepareTimerCtx)); + HcomServiceSelfSyncParam syncParam {}; + TimerCtx syncContext {}; + ASSERT_EQ(channel->PrepareCallback(syncParam, syncContext), SER_NEW_OBJECT_FAILED); + ASSERT_EQ(channel->PrepareCallback(syncParam, syncContext), SER_OK); +} + +TEST_F(TestNetChannelImp, TestOneSideSyncWithWorkerPoll) +{ + channel->mOptions.selfPoll = false; + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + UBSHcomOneSideRequest request {}; + request.lAddress = reinterpret_cast(data); + request.size = dataSize; + MOCKER_CPP(&HcomChannelImp::PrepareTimerContext).stubs().will(invoke(MockPrepareTimerCtx)); + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostWrite, + SerResult(UBSHcomNetEndpoint::*)(const UBSHcomNetTransRequest &)) + .stubs() + .will(returnValue(static_cast(SER_OK))); + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostRead, + SerResult(UBSHcomNetEndpoint::*)(const UBSHcomNetTransRequest &)) + .stubs() + .will(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->OneSideSyncWithWorkerPoll(request, true), SER_OK); + ASSERT_EQ(channel->OneSideSyncWithWorkerPoll(request, false), SER_OK); +} + +TEST_F(TestNetChannelImp, TestOneSideAsyncWithWorkerPoll) +{ + channel->mOptions.selfPoll = false; + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + UBSHcomOneSideRequest req {}; + req.lAddress = reinterpret_cast(data); + req.size = dataSize; + MOCKER_CPP(&HcomChannelImp::PrepareTimerContext).stubs().will(invoke(MockPrepareTimerCtx)); + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostWrite, + SerResult(UBSHcomNetEndpoint::*)(const UBSHcomNetTransRequest &)) + .stubs() + .will(returnValue(static_cast(SER_OK))); + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostRead, + SerResult(UBSHcomNetEndpoint::*)(const UBSHcomNetTransRequest &)) + .stubs() + .will(returnValue(static_cast(SER_OK))); + Callback *callback = UBSHcomNewCallback([] + (UBSHcomServiceContext &context) { ASSERT_EQ(context.Result(), 0); }, std::placeholders::_1); + ASSERT_EQ(channel->OneSideAsyncWithWorkerPoll(req, callback, true), SER_OK); + callback = UBSHcomNewCallback([] + (UBSHcomServiceContext &context) { ASSERT_EQ(context.Result(), 0); }, std::placeholders::_1); + ASSERT_EQ(channel->OneSideAsyncWithWorkerPoll(req, callback, false), SER_OK); + + MOCKER_CPP(&HcomChannelImp::NextWorkerPollEp).stubs().will(returnValue(static_cast(SER_ERROR))); + ASSERT_EQ(channel->OneSideAsyncWithWorkerPoll(req, callback, false), SER_ERROR); +} + +TEST_F(TestNetChannelImp, TestRecvFail) +{ + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + UBSHcomServiceContext context{}; + UBSHcomRequest req(data, sizeof(data), 0); + HcomServiceRndvMessage rndvMessage(NN_NO2, req); + context.mData = reinterpret_cast(&rndvMessage); + + context.mDataLen = sizeof(HcomServiceRndvMessage) - NN_NO1; + uintptr_t address = 0; + uint32_t size = NN_NO16; + ASSERT_EQ(channel->Recv(context, address, size), SER_ERROR); + + context.mDataLen = sizeof(HcomServiceRndvMessage); + ASSERT_EQ(channel->Recv(context, address, size), SER_ERROR); + + address = reinterpret_cast(data); + size = sizeof(data); + + PgtRegion *pgtRegion = nullptr; + PgtRegion pgtRegion2{}; + pgtRegion2.start = reinterpret_cast(data); + pgtRegion2.end = reinterpret_cast(data) + sizeof(data); + MOCKER_CPP(&PgTable::Lookup).stubs().will(returnValue(pgtRegion)).then(returnValue(&pgtRegion2)); + MOCKER_CPP(&HcomServiceRndvMessage::IsTimeout).stubs().will(returnValue(false)); + ASSERT_EQ(channel->Recv(context, address, size), SER_ERROR); + + MOCKER_CPP_VIRTUAL(*channel, &HcomChannelImp::Get, + int32_t(HcomChannelImp::*)(const UBSHcomOneSideRequest &, const Callback *)) + .stubs() + .will(returnValue(static_cast(SER_ERROR))) + .then(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->Recv(context, address, size), SER_ERROR); + + ASSERT_EQ(channel->Recv(context, address, size), SER_OK); +} + +TEST_F(TestNetChannelImp, TestRndvInnerFail) +{ + UBSHcomTwoSideThreshold threshold{}; + threshold.rndvThreshold = NN_NO1024; + channel->SetTwoSideThreshold(threshold); + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + UBSHcomRequest req(data, sizeof(data), 0); + UBSHcomNetTransOpInfo transOp{}; + + PgtRegion *pgtRegion = nullptr; + PgtRegion pgtRegion2{}; + pgtRegion2.start = reinterpret_cast(data); + pgtRegion2.end = reinterpret_cast(data) + sizeof(data) - NN_NO1; + MOCKER_CPP(&PgTable::Lookup).stubs().will(returnValue(pgtRegion)).then(returnValue(&pgtRegion2)); + + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &)) + .stubs() + .will(returnValue(static_cast(SER_OK))) + .then(returnValue(static_cast(SER_ERROR))); + + ASSERT_EQ(channel->RndvInner(ep.Get(), req, transOp, true), SER_OK); + + ASSERT_EQ(channel->RndvInner(ep.Get(), req, transOp, false), SER_ERROR); + + threshold.rndvThreshold = UINT32_MAX; + channel->SetTwoSideThreshold(threshold); +} +TEST_F(TestNetChannelImp, TestFlowControl) +{ + ASSERT_EQ(channel->FlowControl(0, 0, 0), SER_OK); + RateLimiter *limiter = new (std::nothrow) RateLimiter; + ASSERT_NE(limiter, nullptr); + channel->mOptions.rateLimit = reinterpret_cast(limiter); + MOCKER_CPP(&RateLimiter::AcquireQuota).stubs().will(returnValue(true)).then(returnValue(false)); + ASSERT_EQ(channel->FlowControl(0, 0, 0), SER_OK); + MOCKER_CPP(&RateLimiter::InvalidateSize).stubs().will(returnValue(true)).then(returnValue(false)); + + ASSERT_EQ(channel->FlowControl(0, 0, 0), SER_INVALID_PARAM); + MOCKER_CPP(&RateLimiter::WaitUntilNextWindow).stubs(); + MOCKER_CPP(&RateLimiter::BuildNextWindow).stubs(); + ASSERT_EQ(channel->FlowControl(0, 0, 0), SER_TIMEOUT); + channel->mOptions.rateLimit = 0; + delete limiter; +} + +TEST_F(TestNetChannelImp, TestAcquireQuotaFalse) +{ + RateLimiter *limiter = new (std::nothrow) RateLimiter; + ASSERT_NE(limiter, nullptr); + limiter->windowPassedByte = UINT64_MAX; + limiter->thresholdByte = UINT64_MAX; + auto ret = limiter->AcquireQuota(NN_NO1024); + ASSERT_EQ(ret, false); + delete limiter; +} + +TEST_F(TestNetChannelImp, TestAcquireQuotaSuccess) +{ + RateLimiter *limiter = new (std::nothrow) RateLimiter; + ASSERT_NE(limiter, nullptr); + limiter->windowPassedByte = NN_NO1024; + limiter->thresholdByte = UINT64_MAX; + auto ret = limiter->AcquireQuota(NN_NO1024); + ASSERT_EQ(ret, true); + delete limiter; +} + +TEST_F(TestNetChannelImp, TestSetFlowControlConfig) +{ + UBSHcomFlowCtrlOptions opt {}; + ASSERT_EQ(channel->SetFlowControlConfig(opt), SER_NOT_ESTABLISHED); + RateLimiter *limiter = new (std::nothrow) RateLimiter; + ASSERT_NE(limiter, nullptr); + channel->mOptions.rateLimit = reinterpret_cast(limiter); + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + ASSERT_EQ(channel->SetFlowControlConfig(opt), SER_OK); + channel->mOptions.rateLimit = 0; + delete limiter; +} + +TEST_F(TestNetChannelImp, TestAllEpBroken) +{ + EpInfo *info = new (std::nothrow) EpInfo; + channel->mEpInfo = info; + ASSERT_EQ(channel->AllEpBroken(), true); + delete info; + channel->mEpInfo = nullptr; + + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + ASSERT_EQ(channel->AllEpBroken(), false); +} + +TEST_F(TestNetChannelImp, TestNeedProcessBroken) +{ + ASSERT_EQ(channel->NeedProcessBroken(), true); +} + +TEST_F(TestNetChannelImp, TestInvokeChannelBrokenCb) +{ + UBSHcomChannelPtr chPtr = channel; + chPtr->IncreaseRef(); + + EXPECT_NO_FATAL_FAILURE(channel->InvokeChannelBrokenCb(chPtr)); + channel->mOptions.brokenHandler = [](const UBSHcomChannelPtr &ch) { + printf("enter cb\n"); + return 0; + }; + EXPECT_NO_FATAL_FAILURE(channel->InvokeChannelBrokenCb(chPtr)); +} + +TEST_F(TestNetChannelImp, TestProcessIoInBroken) +{ + ASSERT_EQ(channel->Initialize(epVector, reinterpret_cast(ctxMemPool.Get()), + reinterpret_cast(mPeriodicMgr.Get()), reinterpret_cast(mPgtable.Get())), + SER_OK); + EXPECT_NO_FATAL_FAILURE(channel->ProcessIoInBroken()); +} + +TEST_F(TestNetChannelImp, TestCalculateOffsetAndSize) +{ + UBSHcomOneSideRequest request {}; + request.size = NN_NO1024; + uint32_t remain = 0; + uint32_t offset = 0; + uint32_t size = 0; + channel->mOptions.enableMultiRail = true; + channel->mOptions.multiRailThresh = NN_NO1; + EXPECT_NO_FATAL_FAILURE(channel->CalculateOffsetAndSize(request, ep.Get(), remain, offset, size)); +} + +TEST_F(TestNetChannelImp, TestGetRemoteUdsIdInfo) +{ + UBSHcomNetUdsIdInfo info {}; + EXPECT_EQ(channel->GetRemoteUdsIdInfo(info), static_cast(SER_ERROR)); + + EpInfo *epInfo = new (std::nothrow) EpInfo; + ASSERT_NE(epInfo, nullptr); + channel->mEpInfo = epInfo; + EXPECT_EQ(channel->GetRemoteUdsIdInfo(info), static_cast(SER_ERROR)); + + channel->mEpInfo->epArr[0] = ep.Get(); + EXPECT_EQ(channel->GetRemoteUdsIdInfo(info), static_cast(NN_EP_NOT_ESTABLISHED)); + + channel->mEpInfo->epArr[0] = nullptr; + channel->mEpInfo = nullptr; + delete epInfo; +} + +TEST_F(TestNetChannelImp, TestSendFds) +{ + EXPECT_EQ(channel->SendFds(nullptr, 0), static_cast(SER_ERROR)); + + EpInfo *epInfo = new (std::nothrow) EpInfo; + ASSERT_NE(epInfo, nullptr); + channel->mEpInfo = epInfo; + EXPECT_EQ(channel->SendFds(nullptr, 0), static_cast(SER_ERROR)); + + channel->mEpInfo->epArr[0] = ep.Get(); + EXPECT_EQ(channel->SendFds(nullptr, 0), static_cast(NN_EXCHANGE_FD_NOT_SUPPORT)); + + channel->mEpInfo->epArr[0] = nullptr; + channel->mEpInfo = nullptr; + delete epInfo; +} + +TEST_F(TestNetChannelImp, TestReceiveFds) +{ + EXPECT_EQ(channel->ReceiveFds(nullptr, 0, 0), static_cast(SER_ERROR)); + + EpInfo *epInfo = new (std::nothrow) EpInfo; + ASSERT_NE(epInfo, nullptr); + channel->mEpInfo = epInfo; + EXPECT_EQ(channel->ReceiveFds(nullptr, 0, 0), static_cast(SER_ERROR)); + + channel->mEpInfo->epArr[0] = ep.Get(); + EXPECT_EQ(channel->ReceiveFds(nullptr, 0, 0), static_cast(NN_EXCHANGE_FD_NOT_SUPPORT)); + + channel->mEpInfo->epArr[0] = nullptr; + channel->mEpInfo = nullptr; + delete epInfo; +} + +void MockReleaseSelfPollEp(uint32_t index) {} + +TEST_F(TestNetChannelImp, TestSyncSendSplitWithWorkerPoll) +{ + MOCKER_CPP(&HcomChannelImp::PrepareTimerContext) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(invoke(MockPrepareTimerCtx)); + MOCKER_CPP(&HcomChannelImp::DestroyTimerContext).stubs(); + + UBSHcomRequest req(data, NN_NO65536, 0); + auto tmpEp = ep.Get(); + ASSERT_EQ(channel->SyncSendSplitWithWorkerPoll(tmpEp, req, 1), SER_NEW_OBJECT_FAILED); + + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &, + const UBSHcomExtHeaderType, const void *, uint32_t)) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->SyncSendSplitWithWorkerPoll(tmpEp, req, 1), SER_INVALID_PARAM); + ASSERT_EQ(channel->SyncSendSplitWithWorkerPoll(tmpEp, req, 1), SER_OK); +} + +TEST_F(TestNetChannelImp, TestSyncSendSplitWithSelfPoll) +{ + MOCKER_CPP(&HcomChannelImp::ReleaseSelfPollEp).stubs().will(invoke(MockReleaseSelfPollEp)); + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &, + const UBSHcomExtHeaderType, const void *, uint32_t)) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + UBSHcomRequest req(data, NN_NO65536, 0); + auto tmpEp = ep.Get(); + ASSERT_EQ(channel->SyncSendSplitWithSelfPoll(tmpEp, req, 1, 0), SER_INVALID_PARAM); + + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::WaitCompletion, SerResult(UBSHcomNetEndpoint::*)(int32_t)) + .stubs() + .will(returnValue(static_cast(SER_OK))) + .then(returnValue(static_cast(SER_INVALID_PARAM))); + + ASSERT_EQ(channel->SyncSendSplitWithSelfPoll(tmpEp, req, 1, 0), SER_OK); + ASSERT_EQ(channel->SyncSendSplitWithSelfPoll(tmpEp, req, 1, 0), SER_INVALID_PARAM); +} + +TEST_F(TestNetChannelImp, TestAsyncSendSplitWithWorkerPoll) +{ + MOCKER_CPP(&HcomChannelImp::DestroyTimerContext).stubs(); + MOCKER_CPP(&HcomChannelImp::PrepareTimerContext) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(returnValue(static_cast(SER_OK))); + UBSHcomRequest req(data, NN_NO65536, 0); + auto tmpEp = ep.Get(); + ASSERT_EQ(channel->AsyncSendSplitWithWorkerPoll(tmpEp, req, 1, nullptr), SER_NEW_OBJECT_FAILED); + + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &, + const UBSHcomExtHeaderType, const void *, uint32_t)) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->AsyncSendSplitWithWorkerPoll(tmpEp, req, 1, nullptr), SER_INVALID_PARAM); + ASSERT_EQ(channel->AsyncSendSplitWithWorkerPoll(tmpEp, req, 1, nullptr), SER_OK); +} + +TEST_F(TestNetChannelImp, TestSyncSpliceMessage) +{ + auto tmpEp = ep.Get(); + std::string acc; + void *data; + uint32_t dataLen; + UBSHcomNetResponseContext ctx; + UBSHcomNetTransHeader mHeader; + mHeader.extHeaderType = UBSHcomExtHeaderType::RAW; + + ctx.mHeader = mHeader; + + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::Receive, + SerResult(UBSHcomNetEndpoint::*)(int32_t, UBSHcomNetResponseContext&)) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))); + + ASSERT_EQ(SyncSpliceMessage(ctx, tmpEp, 1, acc, data, dataLen), SER_INVALID_PARAM); +} + +// SpliceMessage +std::shared_ptr CreateNRC() +{ + alignas(UBSHcomNetMessage) static char msg_buf[sizeof(UBSHcomNetMessage)]; + auto msg = reinterpret_cast(msg_buf); + + alignas(UBSHcomNetEndpoint) static char ep_buf[sizeof(UBSHcomNetEndpoint)]; + auto p = reinterpret_cast(ep_buf); + // Since we use static memory, no need to free ep_buf. + p->IncreaseRef(); + + auto sp = std::make_shared(); + sp->mMessage = msg; + sp->mEp = p; + return sp; +} + +// payloadLen = 1 +UBSHcomFragmentHeader *GetFragmentHeader(int msgId, int totalLength, int offset) +{ + alignas(UBSHcomFragmentHeader) static char buf[sizeof(UBSHcomFragmentHeader) + 1]; + + auto f = reinterpret_cast(buf); + f->msgId = {0, msgId}; + f->totalLength = totalLength; + f->offset = offset; + + return f; +} + +// Typically the payload pointer refers the GetFragmentHeader::buf. +template void SetNRCPayload(UBSHcomNetRequestContext &ctx, T *payload, + uint32_t sz = sizeof(UBSHcomFragmentHeader) + 1) +{ + ctx.mMessage->mBuf = payload; + ctx.mMessage->mDataLen = sz; +} + +std::shared_ptr GetNetAsyncEndpoint() +{ + UBSHcomNetWorkerIndex idx; + auto sp = std::make_shared(0xdead, nullptr, nullptr, idx); + return sp; +} + +TEST_F(TestNetChannelImp, TestSpliceMessageMsgInvalid) +{ + auto ctx = CreateNRC(); + SetNRCPayload(*ctx, (void *)nullptr, 0); + + SpliceMessageResultType result; + SerResult code; + std::string out; + std::tie(result, code, out) = channel->SpliceMessage(*ctx, false); + EXPECT_EQ(result, SpliceMessageResultType::ERROR); + EXPECT_EQ(code, SER_SPLIT_INVALID_MSG); +} + +TEST_F(TestNetChannelImp, TestSpliceMessageFirstFragmentLost) +{ + auto ctx = CreateNRC(); + + // offset = 1, the first fragment (offset=0) is lost. + auto fh = GetFragmentHeader(0x11, 2, 1); + SetNRCPayload(*ctx, fh); + + SpliceMessageResultType result; + SerResult code; + std::string out; + std::tie(result, code, out) = channel->SpliceMessage(*ctx, false); + EXPECT_EQ(result, SpliceMessageResultType::ERROR); + EXPECT_EQ(code, SER_ERROR); +} + +namespace internal { +void DoMockThen(mockcpp::MoreStubBuilder<> *builder) +{ +} + +template void DoMockThen(mockcpp::MoreStubBuilder<> *builder, SerResult err, Ts... errs) +{ + builder = &builder->then(returnValue(static_cast(err))); + DoMockThen(builder, errs...); +} +} // namespace internal + +template void MockPrepareTimerContext(SerResult err, Ts... errs) +{ + auto builder = MOCKER_CPP(&HcomChannelImp::PrepareTimerContext).stubs(); + auto *b = &builder.will(returnValue(static_cast(err))); + internal::DoMockThen(b, errs...); +} + +TEST_F(TestNetChannelImp, TestSpliceMessageOffsetError) +{ + auto ctx = CreateNRC(); + MockPrepareTimerContext(SER_OK); + + SpliceMessageResultType result; + SerResult code; + std::string out; + + // the first fragment of msg (id=0x11), with totalLength = 2 + auto first = GetFragmentHeader(0x11, 2, 0); + SetNRCPayload(*ctx, first); + std::tie(result, code, out) = channel->SpliceMessage(*ctx, false); + EXPECT_EQ(result, SpliceMessageResultType::INDETERMINATE); + EXPECT_EQ(code, SER_OK); + + // the second fragment of msg (id=0x11), but one bit of the offset flipped + auto second = GetFragmentHeader(0x11, 2, 1 + 8); + SetNRCPayload(*ctx, second); + std::tie(result, code, out) = channel->SpliceMessage(*ctx, false); + EXPECT_EQ(result, SpliceMessageResultType::ERROR); + EXPECT_EQ(code, SER_SPLIT_INVALID_MSG); +} + +TEST_F(TestNetChannelImp, TestSpliceMessageLargePayload) +{ + auto ctx = CreateNRC(); + MockPrepareTimerContext(SER_OK); + + SpliceMessageResultType result; + SerResult code; + std::string out; + + // totalLength = 2, but payload length = 0xffff. + auto first = GetFragmentHeader(0x11, 2, 0); + SetNRCPayload(*ctx, first, 0xffff); + std::tie(result, code, out) = channel->SpliceMessage(*ctx, false); + EXPECT_EQ(result, SpliceMessageResultType::ERROR); + EXPECT_EQ(code, SER_SPLIT_INVALID_MSG); +} + +template void MockGetSeqNoAndRemove(SerResult err, Ts... errs) +{ + auto builder = MOCKER_CPP(&HcomServiceCtxStore::GetSeqNoAndRemove).stubs(); + auto *b = &builder.will(returnValue(static_cast(err))); + internal::DoMockThen(b, errs...); +} + +TEST_F(TestNetChannelImp, TestSpliceMessageOk) +{ + auto ctx = CreateNRC(); + MockPrepareTimerContext(SER_OK); + MockGetSeqNoAndRemove(SER_OK); + MOCKER_CPP(&HcomServiceTimer::MarkFinished).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&HcomServiceTimer::DecreaseRef).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&HcomServiceTimer::DeleteCallBack).stubs().will(ignoreReturnValue()); + + SpliceMessageResultType result; + SerResult code; + std::string out; + + auto first = GetFragmentHeader(0x11, 2, 0); + SetNRCPayload(*ctx, first); + std::tie(result, code, out) = channel->SpliceMessage(*ctx, false); + EXPECT_EQ(result, SpliceMessageResultType::INDETERMINATE); + EXPECT_EQ(code, SER_OK); + + auto second = GetFragmentHeader(0x11, 2, 1); + SetNRCPayload(*ctx, second); + std::tie(result, code, out) = channel->SpliceMessage(*ctx, false); + EXPECT_EQ(result, SpliceMessageResultType::OK); + EXPECT_EQ(code, SER_OK); +} + +TEST_F(TestNetChannelImp, TestSpliceMessageTwo) +{ + auto ctx = CreateNRC(); + MockPrepareTimerContext(SER_OK); + MockGetSeqNoAndRemove(SER_OK); + MOCKER_CPP(&HcomServiceTimer::MarkFinished).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&HcomServiceTimer::DecreaseRef).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&HcomServiceTimer::DeleteCallBack).stubs().will(ignoreReturnValue()); + + SpliceMessageResultType result; + SerResult code; + std::string out; + + // message1: totalLength=2, offset=0, payload=1 + auto m1First = GetFragmentHeader(0x11, 2, 0); + SetNRCPayload(*ctx, m1First); + std::tie(result, code, out) = channel->SpliceMessage(*ctx, false); + EXPECT_EQ(result, SpliceMessageResultType::INDETERMINATE); + EXPECT_EQ(code, SER_OK); + + // message2: totalLength=2, offset=0, payload=1 + auto m2First = GetFragmentHeader(0x12, 2, 0); + SetNRCPayload(*ctx, m2First); + std::tie(result, code, out) = channel->SpliceMessage(*ctx, false); + EXPECT_EQ(result, SpliceMessageResultType::INDETERMINATE); + EXPECT_EQ(code, SER_OK); + + // message1: totalLength=2, offset=1, payload=1 + auto m1Second = GetFragmentHeader(0x11, 2, 1); + SetNRCPayload(*ctx, m1Second); + std::tie(result, code, out) = channel->SpliceMessage(*ctx, false); + EXPECT_EQ(result, SpliceMessageResultType::OK); + EXPECT_EQ(code, SER_OK); + + // message2: totalLength=2, offset=1, payload=1 + auto m2Second = GetFragmentHeader(0x12, 2, 1); + SetNRCPayload(*ctx, m2Second); + std::tie(result, code, out) = channel->SpliceMessage(*ctx, false); + EXPECT_EQ(result, SpliceMessageResultType::OK); + EXPECT_EQ(code, SER_OK); +} + +TEST_F(TestNetChannelImp, TestSpliceRespMessageOk) +{ + auto ctx = CreateNRC(); + MockPrepareTimerContext(SER_OK); + MockGetSeqNoAndRemove(SER_OK); + MOCKER_CPP(&HcomServiceTimer::MarkFinished).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&HcomServiceTimer::DecreaseRef).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&HcomServiceTimer::DeleteCallBack).stubs().will(ignoreReturnValue()); + + SpliceMessageResultType result; + SerResult code; + std::string out; + + auto first = GetFragmentHeader(0x11, 2, 0); + SetNRCPayload(*ctx, first); + std::tie(result, code, out) = channel->SpliceMessage(*ctx, true); + EXPECT_EQ(result, SpliceMessageResultType::INDETERMINATE); + EXPECT_EQ(code, SER_OK); + + auto second = GetFragmentHeader(0x11, 2, 1); + SetNRCPayload(*ctx, second); + std::tie(result, code, out) = channel->SpliceMessage(*ctx, true); + EXPECT_EQ(result, SpliceMessageResultType::OK); + EXPECT_EQ(code, SER_OK); +} + +TEST_F(TestNetChannelImp, TestAsyncReplySplitWithWorkerPoll) +{ + UBSHcomReplyContext ctx; + + MOCKER_CPP(&HcomChannelImp::PrepareTimerContext) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&HcomChannelImp::DestroyTimerContext).stubs(); + auto tmp = ep.Get(); + UBSHcomRequest req(data, NN_NO65536, 0); + ASSERT_EQ(channel->AsyncReplySplitWithWorkerPoll(ctx, tmp, req, 1, nullptr), SER_NEW_OBJECT_FAILED); + + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &, + const UBSHcomExtHeaderType, const void *, uint32_t)) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->AsyncReplySplitWithWorkerPoll(ctx, tmp, req, 1, nullptr), SER_INVALID_PARAM); + ASSERT_EQ(channel->AsyncReplySplitWithWorkerPoll(ctx, tmp, req, 1, nullptr), SER_OK); +} + +TEST_F(TestNetChannelImp, TestSyncReplySplitWithWorkerPoll) +{ + UBSHcomReplyContext ctx; + + MOCKER_CPP(&HcomChannelImp::PrepareTimerContext) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&HcomChannelImp::DestroyTimerContext).stubs(); + auto tmp = ep.Get(); + UBSHcomRequest req(data, NN_NO65536, 0); + + ASSERT_EQ(channel->SyncReplySplitWithWorkerPoll(ctx, tmp, req, 1), SER_NEW_OBJECT_FAILED); + + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &, + const UBSHcomExtHeaderType, const void *, uint32_t)) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))); + ASSERT_EQ(channel->SyncReplySplitWithWorkerPoll(ctx, tmp, req, 1), SER_INVALID_PARAM); +} + +TEST_F(TestNetChannelImp, TestSyncCallSplitWithWorkerPoll) +{ + UBSHcomResponse rsp{}; + UBSHcomRequest req(data, NN_NO65536, 0); + auto tmpEp = ep.Get(); + + MOCKER_CPP(&HcomChannelImp::DestroyTimerContext).stubs(); + MOCKER_CPP(&HcomChannelImp::PrepareTimerContext) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->SyncCallSplitWithWorkerPoll(tmpEp, req, 1, rsp), SER_NEW_OBJECT_FAILED); + + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &, + const UBSHcomExtHeaderType, const void *, uint32_t)) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))); + + ASSERT_EQ(channel->SyncCallSplitWithWorkerPoll(tmpEp, req, 1, rsp), SER_INVALID_PARAM); +} + +TEST_F(TestNetChannelImp, TestAsyncCallSplitWithWorkerPoll) +{ + MOCKER_CPP(&HcomChannelImp::PrepareTimerContext) + .stubs() + .will(returnValue(static_cast(SER_NEW_OBJECT_FAILED))) + .then(returnValue(static_cast(SER_OK))); + MOCKER_CPP(&HcomChannelImp::DestroyTimerContext).stubs(); + UBSHcomRequest req(data, NN_NO65536, 0); + auto tmpEp = ep.Get(); + + ASSERT_EQ(channel->AsyncCallSplitWithWorkerPoll(tmpEp, req, 1, nullptr), SER_NEW_OBJECT_FAILED); + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &, + const UBSHcomExtHeaderType, const void *, uint32_t)) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + ASSERT_EQ(channel->AsyncCallSplitWithWorkerPoll(tmpEp, req, 1, nullptr), SER_INVALID_PARAM); + ASSERT_EQ(channel->AsyncCallSplitWithWorkerPoll(tmpEp, req, 1, nullptr), SER_OK); +} + +TEST_F(TestNetChannelImp, TestSyncCallSplitWithSelfPoll) +{ + UBSHcomResponse rsp{}; + UBSHcomRequest req(data, NN_NO65536, 0); + auto tmpEp = ep.Get(); + + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::PostSend, + SerResult(UBSHcomNetEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, const UBSHcomNetTransOpInfo &, + const UBSHcomExtHeaderType, const void *, uint32_t)) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))) + .then(returnValue(static_cast(SER_OK))); + + ASSERT_EQ(channel->SyncCallSplitWithSelfPoll(tmpEp, req, 1, 0, rsp), SER_INVALID_PARAM); + + MOCKER_CPP_VIRTUAL(*(ep.Get()), &UBSHcomNetEndpoint::WaitCompletion, SerResult(UBSHcomNetEndpoint::*)(int32_t)) + .stubs() + .will(returnValue(static_cast(SER_INVALID_PARAM))); + ASSERT_EQ(channel->SyncCallSplitWithSelfPoll(tmpEp, req, 1, 0, rsp), SER_INVALID_PARAM); +} + +TEST_F(TestNetChannelImp, TestSetTraceId) +{ +#ifdef build_BUILD_ENABLED + MOCKER(HcomUrma::IsLoaded).stubs().will(returnValue(false)).then(returnValue(true)); + std::string traceId = "This is a test trace id"; + + EXPECT_NO_FATAL_FAILURE(channel->SetTraceId(traceId)); + EXPECT_NO_FATAL_FAILURE(channel->SetTraceId(traceId)); +#endif +} + +} // namespace HCOM +} // namespace OCK diff --git a/test/unit_test/transport/common/test_net_common.cpp b/test/unit_test/transport/common/test_net_common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..56c224038767272db739fbe38ccfff65c738a474 --- /dev/null +++ b/test/unit_test/transport/common/test_net_common.cpp @@ -0,0 +1,170 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +#include "net_common.h" +#include "hcom.h" + +namespace ock { +namespace hcom { + +class TestNetCommon : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); +}; + +void TestNetCommon::SetUp() +{ +} + +void TestNetCommon::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestNetCommon, ParseWorkersGroupsErr) +{ + std::string workerStr; + for (int i = 0; i < (NN_NO128 + 1); ++i) { + if (i > 0) { + workerStr += ","; + } + workerStr += "1"; + } + std::vector workerGroups; + bool ret = NetFunc::NN_ParseWorkersGroups(workerStr, workerGroups); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetCommon, ParseWorkerGroupsCpusErr) +{ + std::string workerGroupCpusStr; + for (int i = 0; i < (NN_NO128 + 1); ++i) { + if (i > 0) { + workerGroupCpusStr += ","; + } + workerGroupCpusStr += "1-1"; + } + std::vector> workerGroupCpus; + bool ret = NetFunc::NN_ParseWorkerGroupsCpus(workerGroupCpusStr, workerGroupCpus); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetCommon, FinalizeWorkerGroupCpusErr) +{ + std::vector workerGroups = {2, 3}; + std::vector> workerGroupCpus = {{1, 3}, {2, 4}}; + std::vector flatWorkersCpus; + bool ret = NetFunc::NN_FinalizeWorkerGroupCpus(workerGroups, workerGroupCpus, true, flatWorkersCpus); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetCommon, ParseWorkersGroupsThreadPriorityErr) +{ + std::string threadPriorityStr = "1,2,3"; + int groupNum = NN_NO4; + std::vector threadPriority; + bool ret = NetFunc::NN_ParseWorkersGroupsThreadPriority(threadPriorityStr, threadPriority, groupNum); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetCommon, VecstrToStr) +{ + std::vector vec{}; + std::string linkStr{}; + std::string result{}; + std::string expect = "testtest2"; + vec.emplace_back("test"); + vec.emplace_back("test2"); + NetFunc::NN_VecStrToStr(vec, linkStr, result); + EXPECT_EQ(expect, result); +} + +TEST_F(TestNetCommon, ConvertIpAndPort) +{ + std::string badUrl = "1.2.3.4"; + std::string badUrl2 = "1.2.3.4:0"; + std::string goodUrl = "1.2.3.4:9981"; + std::string ip{}; + uint16_t port = 0; + EXPECT_EQ(NetFunc::NN_ConvertIpAndPort(badUrl, ip, port), false); + EXPECT_EQ(NetFunc::NN_ConvertIpAndPort(badUrl2, ip, port), false); + EXPECT_EQ(NetFunc::NN_ConvertIpAndPort(goodUrl, ip, port), true); +} + +TEST_F(TestNetCommon, SplitProtoUrl) +{ + std::string badUrl = "127.0.0.1:9981"; + std::string testUrl = "tcp://127.0.0.1:9981"; + std::string testUrl2 = "uds://name"; + std::string testUrl3 = "unknown://name"; + std::string testUrl4 = "ubc://1111:2222:0000:0000:0000:0000:0100:0000:888"; + + NetProtocol protocol; + std::string url{}; + EXPECT_EQ(NetFunc::NN_SplitProtoUrl(badUrl, protocol, url), false); + EXPECT_EQ(NetFunc::NN_SplitProtoUrl(testUrl, protocol, url), true); + EXPECT_EQ(NetFunc::NN_SplitProtoUrl(testUrl2, protocol, url), true); + EXPECT_EQ(NetFunc::NN_SplitProtoUrl(testUrl3, protocol, url), false); + EXPECT_EQ(NetFunc::NN_SplitProtoUrl(testUrl4, protocol, url), true); + EXPECT_EQ(protocol, NetProtocol::NET_UBC); + EXPECT_EQ(url, "1111:2222:0000:0000:0000:0000:0100:0000:888"); +} + +TEST_F(TestNetCommon, ConvertEidAndJettyId) +{ + std::string badUrl1 = "127.0.0.1:8888"; + std::string badUrl2 = "1111:2222:0000:0000:0000:0000:0100:0000"; + std::string badUrl3 = "1111:2222:0000:0000:0000:0000:0100:0000:"; + std::string badUrl4 = "1111:2222:0000:0000:0000:0000:0100:888"; + std::string badUrl5 = "1111:2222:0000:0000:0000:0000:0100:0000:2"; + std::string badUrl6 = "1111:2222:0000:0000:0000:0000:0100:0000:1024"; + std::string badUrl7 = "1111:2222:0000:0000:0000:0000:0100:0000::22"; + + std::string testUrl1 = "1111:2222:3333:0000:0000:0000:0000:0000:4"; + std::string testUrl2 = "1111:2222:3333:0000:0000:0000:0000:0000:1023"; + std::string testUrl3 = "1111:2222:0000:0000:0000:0000:0100:0000:888"; + + NetProtocol protocol; + std::string eid; + uint16_t jettyId = 0; + EXPECT_EQ(NetFunc::NN_ConvertEidAndJettyId(badUrl1, eid, jettyId), false); + EXPECT_EQ(NetFunc::NN_ConvertEidAndJettyId(badUrl2, eid, jettyId), false); + EXPECT_EQ(NetFunc::NN_ConvertEidAndJettyId(badUrl3, eid, jettyId), false); + EXPECT_EQ(NetFunc::NN_ConvertEidAndJettyId(badUrl4, eid, jettyId), false); + EXPECT_EQ(NetFunc::NN_ConvertEidAndJettyId(badUrl5, eid, jettyId), false); + EXPECT_EQ(NetFunc::NN_ConvertEidAndJettyId(badUrl6, eid, jettyId), false); + EXPECT_EQ(NetFunc::NN_ConvertEidAndJettyId(badUrl7, eid, jettyId), false); + + EXPECT_EQ(NetFunc::NN_ConvertEidAndJettyId(testUrl1, eid, jettyId), true); + EXPECT_EQ(eid, "1111:2222:3333:0000:0000:0000:0000:0000"); + EXPECT_EQ(jettyId, 4); + + EXPECT_EQ(NetFunc::NN_ConvertEidAndJettyId(testUrl2, eid, jettyId), true); + EXPECT_EQ(eid, "1111:2222:3333:0000:0000:0000:0000:0000"); + EXPECT_EQ(jettyId, 1023); + + EXPECT_EQ(NetFunc::NN_ConvertEidAndJettyId(testUrl3, eid, jettyId), true); + EXPECT_EQ(eid, "1111:2222:0000:0000:0000:0000:0100:0000"); + EXPECT_EQ(jettyId, 888); +} + +TEST_F(TestNetCommon, TestValidateUrl) +{ + EXPECT_EQ(NetFunc::NN_ValidateUrl("!@#$^"), static_cast(NN_INVALID_PARAM)); + EXPECT_EQ(NetFunc::NN_ValidateUrl(""), static_cast(NN_INVALID_PARAM)); + EXPECT_EQ(NetFunc::NN_ValidateUrl("tcp://127.0.0.1:8888"), static_cast(NN_OK)); +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/common/test_net_heartbeat.cpp b/test/unit_test/transport/common/test_net_heartbeat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..62a49c30c287fe4e3ce0d09c0e8088783fc26592 --- /dev/null +++ b/test/unit_test/transport/common/test_net_heartbeat.cpp @@ -0,0 +1,128 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include + +#include "hcom.h" +#include "net_rdma_driver_oob.h" +#include "net_rdma_async_endpoint.h" +#include "rdma_composed_endpoint.h" + +namespace ock { +namespace hcom { +class TestNetHeartbeat : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + + UBSHcomNetDriver *mDriver = nullptr; + NetHeartbeat *mHb = nullptr; +}; + +void TestNetHeartbeat::SetUp() +{ + mDriver = new (std::nothrow) NetDriverRDMAWithOob("test_driver", 1, RDMA); + mHb = new NetHeartbeat(mDriver, NN_NO60, NN_NO2); + mDriver->IncreaseRef(); +} + +void TestNetHeartbeat::TearDown() +{ + if (mDriver != nullptr) { + delete mDriver; + mDriver = nullptr; + } + + if (mHb != nullptr) { + delete mHb; + mHb = nullptr; + } + + GlobalMockObject::verify(); +} + +TEST_F(TestNetHeartbeat, NewHeartbeatZero) +{ + NetHeartbeat *hb = new (std::nothrow) NetHeartbeat(mDriver, NN_NO60, 0); + uint32_t interval = 5000; + EXPECT_EQ(hb->mHeartBeatProbeInterval, interval); + delete (hb); + hb = nullptr; +} + +TEST_F(TestNetHeartbeat, NewHeartbeatMax) +{ + uint32_t maxInterval = 1023; + NetHeartbeat *hb = new (std::nothrow) NetHeartbeat(mDriver, NN_NO60, maxInterval); + EXPECT_EQ(hb->mHeartBeatProbeInterval, maxInterval * NN_NO1000 * NN_NO1000); + delete (hb); + hb = nullptr; +} + +TEST_F(TestNetHeartbeat, StartWithMrError) +{ + MOCKER_CPP_VIRTUAL(mDriver, &UBSHcomNetDriver::CreateMemoryRegion) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)) + .then(returnValue(2)); + MOCKER_CPP(&NetHeartbeat::RunInHbThread).stubs().then(ignoreReturnValue()); + + mHb->mHBStarted = true; + NResult result = mHb->Start(); + EXPECT_EQ(result, 1); + result = mHb->Start(); + EXPECT_EQ(result, 2); +} + +TEST_F(TestNetHeartbeat, DetectSingleEpHbStateWithInvalidParam) +{ + UBSHcomNetEndpoint *ep = nullptr; + UBSHcomNetTransRequest req {}; + EXPECT_NO_FATAL_FAILURE(mHb->DetectSingleEpHbState(dynamic_cast(ep), + dynamic_cast(mDriver), req, RDMAOpContextInfo::HB_WRITE)); +} + +TEST_F(TestNetHeartbeat, DetectSingleEpHbStateWithBrokenEp) +{ + RDMAAsyncEndPoint *rdmaEp = nullptr; + UBSHcomNetWorkerIndex index {}; + UBSHcomNetEndpoint *ep = new NetAsyncEndpoint(0, rdmaEp, (NetDriverRDMAWithOob*)mDriver, index); + UBSHcomNetTransRequest req {}; + + MOCKER_CPP(&NetAsyncEndpoint::checkTargetHbTime).stubs().will(returnValue(true)); + MOCKER_CPP(&NetAsyncEndpoint::HbCheckStateNormal).stubs().will(returnValue(false)); + MOCKER_CPP(&NetAsyncEndpoint::HbBrokenEp).stubs().will(returnValue(true)); + MOCKER_CPP(&NetDriverRDMAWithOob::ProcessEpError).stubs().will(ignoreReturnValue()); + + EXPECT_NO_FATAL_FAILURE(mHb->DetectSingleEpHbState(dynamic_cast(ep), + dynamic_cast(mDriver), req, RDMAOpContextInfo::HB_WRITE)); +} + +TEST_F(TestNetHeartbeat, DetectSingleEpHbStateWithOUTBrokenEp) +{ + RDMAAsyncEndPoint *rdmaEp = nullptr; + UBSHcomNetWorkerIndex index {}; + UBSHcomNetEndpoint *ep = new NetAsyncEndpoint(0, rdmaEp, (NetDriverRDMAWithOob*)mDriver, index); + UBSHcomNetTransRequest req {}; + + MOCKER_CPP(&NetAsyncEndpoint::checkTargetHbTime).stubs().will(returnValue(true)); + MOCKER_CPP(&NetAsyncEndpoint::HbCheckStateNormal).stubs().will(returnValue(false)); + MOCKER_CPP(&NetAsyncEndpoint::HbBrokenEp).stubs().will(returnValue(false)); + + EXPECT_NO_FATAL_FAILURE(mHb->DetectSingleEpHbState(dynamic_cast(ep), + dynamic_cast(mDriver), req, RDMAOpContextInfo::HB_WRITE)); + EXPECT_EQ(ep->State().Compare(NEP_BROKEN), true); +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/common/test_net_oob.cpp b/test/unit_test/transport/common/test_net_oob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..589068b4da07c2625b8a801878d7d0df1b01ca8b --- /dev/null +++ b/test/unit_test/transport/common/test_net_oob.cpp @@ -0,0 +1,197 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include +#include +#include +#include "hcom.h" +#include "net_oob.h" + +namespace ock { +namespace hcom { + +const std::string SERVER_IP("127.0.0.1"); +constexpr uint16_t SERVER_PORT_ZERO = 0; + + +class TestNetOob : public testing::Test { +public: + static void SetUpTestSuite() {} + static void TearDownTestSuite() {} + virtual void SetUp(void) {} + virtual void TearDown(void) + { + GlobalMockObject::verify(); + } +}; + +// tcp server is already started +TEST_F(TestNetOob, EnableAutoPortSelectionFailed1) +{ + uint16_t minPort = 2000; + uint16_t maxPort = 3000; + OOBTCPServer server(SERVER_IP, SERVER_PORT_ZERO); + server.mStarted = true; + server.mOobType = NetDriverOobType::NET_OOB_TCP; + NResult ret = server.EnableAutoPortSelection(minPort, maxPort); + EXPECT_EQ(ret, NN_ERROR); + + // make sure the server exits correctly + server.mStarted = false; +} + +// tcp server oob is not tcp +TEST_F(TestNetOob, EnableAutoPortSelectionFailed2) +{ + uint16_t minPort = 2000; + uint16_t maxPort = 3000; + OOBTCPServer server(SERVER_IP, SERVER_PORT_ZERO); + server.mStarted = false; + server.mOobType = NetDriverOobType::NET_OOB_UDS; + NResult ret = server.EnableAutoPortSelection(minPort, maxPort); + EXPECT_EQ(ret, NN_ERROR); +} + +// port range error +TEST_F(TestNetOob, EnableAutoPortSelectionFailed3) +{ + uint16_t minPort = 0; + uint16_t maxPort = 3000; + OOBTCPServer server(SERVER_IP, SERVER_PORT_ZERO); + server.mStarted = false; + server.mOobType = NetDriverOobType::NET_OOB_TCP; + NResult ret = server.EnableAutoPortSelection(minPort, maxPort); + EXPECT_EQ(ret, NN_ERROR); + + minPort = 1; + maxPort = 1000; + ret = server.EnableAutoPortSelection(minPort, maxPort); + EXPECT_EQ(ret, NN_ERROR); + + minPort = 2000; + maxPort = 1000; + ret = server.EnableAutoPortSelection(minPort, maxPort); + EXPECT_EQ(ret, NN_ERROR); + + minPort = 3000; + maxPort = 2000; + ret = server.EnableAutoPortSelection(minPort, maxPort); + EXPECT_EQ(ret, NN_ERROR); +} + +// port range error +TEST_F(TestNetOob, EnableAutoPortSelectionSuccess) +{ + uint16_t minPort = 2000; + uint16_t maxPort = 3000; + OOBTCPServer server(SERVER_IP, SERVER_PORT_ZERO); + server.mStarted = false; + server.mOobType = NetDriverOobType::NET_OOB_TCP; + NResult ret = server.EnableAutoPortSelection(minPort, maxPort); + EXPECT_EQ(ret, NN_OK); + + server.mListenPort = 2500; + ret = server.EnableAutoPortSelection(minPort, maxPort); + EXPECT_EQ(ret, NN_OK); +} + +TEST_F(TestNetOob, GetListenPortFailed) +{ + OOBTCPServer server(SERVER_IP, SERVER_PORT_ZERO); + server.mStarted = false; + uint16_t port = 0; + NResult ret = server.GetListenPort(port); + EXPECT_EQ(ret, NN_ERROR); +} + +TEST_F(TestNetOob, GetListenIpFailed) +{ + OOBTCPServer server(SERVER_IP, SERVER_PORT_ZERO); + server.mStarted = false; + std::string listenIp; + NResult ret = server.GetListenIp(listenIp); + EXPECT_EQ(ret, NN_ERROR); +} + +TEST_F(TestNetOob, GetUdsNameFailed) +{ + OOBTCPServer server(SERVER_IP, SERVER_PORT_ZERO); + server.mStarted = false; + std::string udsName; + NResult ret = server.GetUdsName(udsName); + EXPECT_EQ(ret, NN_ERROR); + + server.mStarted = true; + ret = server.GetUdsName(udsName); + EXPECT_EQ(ret, NN_ERROR); + + // make sure the server exits correctly + server.mStarted = false; +} + +TEST_F(TestNetOob, BindAndListenAutoSuccess) +{ + OOBTCPServer server(SERVER_IP, SERVER_PORT_ZERO); + server.mStarted = false; + int socketFD = 0; + int ret = server.CreateAndConfigSocket(socketFD); + EXPECT_EQ(ret, NN_OK); + uint16_t minPort = 2000; + uint16_t maxPort = 3000; + ret = server.EnableAutoPortSelection(minPort, maxPort); + EXPECT_EQ(ret, NN_OK); + + ret = server.BindAndListenCommon(socketFD); + EXPECT_EQ(ret, NN_OK); + + NetFunc::NN_SafeCloseFd(socketFD); +} + +// bind always failed +TEST_F(TestNetOob, BindAndListenAutoFailed1) +{ + OOBTCPServer server(SERVER_IP, SERVER_PORT_ZERO); + server.mStarted = false; + int socketFD = 0; + int ret = server.CreateAndConfigSocket(socketFD); + EXPECT_EQ(ret, NN_OK); + uint16_t minPort = 2000; + uint16_t maxPort = 2003; + ret = server.EnableAutoPortSelection(minPort, maxPort); + EXPECT_EQ(ret, NN_OK); + + MOCKER_CPP(::bind).stubs().will(returnValue(int(-1))); + ret = server.BindAndListenAuto(socketFD); + EXPECT_NE(ret, NN_OK); +} + +// listen always failed +TEST_F(TestNetOob, BindAndListenAutoFailed2) +{ + OOBTCPServer server(SERVER_IP, SERVER_PORT_ZERO); + server.mStarted = false; + int socketFD = 0; + int ret = server.CreateAndConfigSocket(socketFD); + EXPECT_EQ(ret, NN_OK); + uint16_t minPort = 2000; + uint16_t maxPort = 2003; + ret = server.EnableAutoPortSelection(minPort, maxPort); + EXPECT_EQ(ret, NN_OK); + + MOCKER_CPP(::listen).stubs().will(returnValue(int(-1))); + ret = server.BindAndListenAuto(socketFD); + EXPECT_NE(ret, NN_OK); +} + +} +} diff --git a/test/unit_test/transport/common/test_net_security.cpp b/test/unit_test/transport/common/test_net_security.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4abe22a8ae7453f5a466d9098ba87cdbf059b150 --- /dev/null +++ b/test/unit_test/transport/common/test_net_security.cpp @@ -0,0 +1,254 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include +#include "net_security_alg.h" +#include "net_security_rand.h" + +namespace ock { +namespace hcom { +class TestNetSecurity : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + + NetSecrets mSecrets; + AesGcm128 mAes; + + UBSHcomNetCipherSuite mCipherSuite = AES_GCM_128; + + unsigned char *mRawData; + uint32_t mRawLen = NN_NO16; + + unsigned char *mCipher; + uint32_t mCipherLen = NN_NO60; + + char dummyCtx[1]; + EVP_CIPHER_CTX *ctx; +}; + +void TestNetSecurity::SetUp() +{ + mAes.SetEncryptOptions(mCipherSuite); + mCipher = static_cast(malloc(mCipherLen)); + mRawData = static_cast(malloc(mRawLen)); + ctx = reinterpret_cast(dummyCtx); +} + +void TestNetSecurity::TearDown() +{ + free(mCipher); + free(mRawData); + GlobalMockObject::verify(); +} + +TEST_F(TestNetSecurity, EncryptOpenSSLSuccess) +{ + MOCKER_CPP(&AesGcm128::EncryptInner).stubs().will(returnValue(0)); + bool ret = mAes.Encrypt(mSecrets, mRawData, mRawLen, mCipher, mCipherLen); + EXPECT_EQ(ret, true); +} + +TEST_F(TestNetSecurity, EncryptOpenSSLSuccess_AES_CCM_128) +{ + mAes.SetEncryptOptions(AES_CCM_128); + + MOCKER_CPP(&AesGcm128::EncryptInner).stubs().will(returnValue(0)); + bool ret = mAes.Encrypt(mSecrets, mRawData, mRawLen, mCipher, mCipherLen); + EXPECT_EQ(ret, true); +} + +TEST_F(TestNetSecurity, EncryptOpenSSLSuccess_CHACHA20_POLY1305) +{ + mAes.SetEncryptOptions(CHACHA20_POLY1305); + + MOCKER_CPP(&AesGcm128::EncryptInner).stubs().will(returnValue(0)); + bool ret = mAes.Encrypt(mSecrets, mRawData, mRawLen, mCipher, mCipherLen); + EXPECT_EQ(ret, true); +} + +TEST_F(TestNetSecurity, EncryptWithInvalidParamFail) +{ + const void *keySecrets = nullptr; + MOCKER_CPP(&NetSecrets::GetKeySecret).stubs().will(returnValue(keySecrets)); + mAes.SetEncryptOptions(mCipherSuite); + bool ret = mAes.Encrypt(mSecrets, mRawData, mRawLen, mCipher, mCipherLen); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetSecurity, EncryptInnerCtxNewFail) +{ + EVP_CIPHER_CTX *ctx = nullptr; + MOCKER_CPP(&HcomSsl::EvpCipherCtxNew).stubs().will(returnValue(ctx)); + NResult ret = mAes.EncryptInner(static_cast(mSecrets.GetKeySecret()), + static_cast(mSecrets.GetAADSecret()), mRawData, mRawLen, mCipher, mCipherLen); + EXPECT_EQ(ret, NN_ENCRYPT_FAILED); +} + +TEST_F(TestNetSecurity, EncryptInnerSetEncryptInfoFail) +{ + NResult result = NN_ENCRYPT_FAILED; + MOCKER_CPP(&HcomSsl::EvpCipherCtxNew).stubs().will(returnValue(ctx)); + MOCKER_CPP(&HcomSsl::EvpCipherCtxFree).stubs().will(returnValue(0)); + MOCKER_CPP(&AesGcm128::SetEncryptInfo).stubs().will(returnValue(result)); + NResult ret = mAes.EncryptInner(static_cast(mSecrets.GetKeySecret()), + static_cast(mSecrets.GetAADSecret()), mRawData, mRawLen, mCipher, mCipherLen); + EXPECT_EQ(ret, NN_ENCRYPT_FAILED); +} + +TEST_F(TestNetSecurity, EncryptInnerUpdateFail) +{ + NResult result = NN_OK; + MOCKER_CPP(&HcomSsl::EvpCipherCtxNew).stubs().will(returnValue(ctx)); + MOCKER_CPP(&HcomSsl::EvpCipherCtxFree).stubs().will(returnValue(0)); + MOCKER_CPP(&AesGcm128::SetEncryptInfo).stubs().will(returnValue(result)); + MOCKER_CPP(&HcomSsl::EvpEncryptUpdate).stubs().will(returnValue(-1)); + NResult ret = mAes.EncryptInner(static_cast(mSecrets.GetKeySecret()), + static_cast(mSecrets.GetAADSecret()), mRawData, mRawLen, mCipher, mCipherLen); + EXPECT_EQ(ret, NN_ENCRYPT_FAILED); +} + +TEST_F(TestNetSecurity, EncryptInnerFinalExFail) +{ + NResult result = NN_OK; + MOCKER_CPP(&HcomSsl::EvpCipherCtxNew).stubs().will(returnValue(ctx)); + MOCKER_CPP(&HcomSsl::EvpCipherCtxFree).stubs().will(returnValue(0)); + MOCKER_CPP(&AesGcm128::SetEncryptInfo).stubs().will(returnValue(result)); + MOCKER_CPP(&HcomSsl::EvpEncryptUpdate).stubs().will(returnValue(1)); + MOCKER_CPP(&HcomSsl::EvpEncryptFinalEx).stubs().will(returnValue(-1)); + NResult ret = mAes.EncryptInner(static_cast(mSecrets.GetKeySecret()), + static_cast(mSecrets.GetAADSecret()), mRawData, mRawLen, mCipher, mCipherLen); + EXPECT_EQ(ret, NN_ENCRYPT_FAILED); +} + +TEST_F(TestNetSecurity, EncryptInnerCipherCtxCtrlFail) +{ + NResult result = NN_OK; + MOCKER_CPP(&HcomSsl::EvpCipherCtxNew).stubs().will(returnValue(ctx)); + MOCKER_CPP(&HcomSsl::EvpCipherCtxFree).stubs().will(returnValue(0)); + MOCKER_CPP(&AesGcm128::SetEncryptInfo).stubs().will(returnValue(result)); + MOCKER_CPP(&HcomSsl::EvpEncryptUpdate).stubs().will(returnValue(1)); + MOCKER_CPP(&HcomSsl::EvpEncryptFinalEx).stubs().will(returnValue(1)); + MOCKER_CPP(&HcomSsl::EvpCipherCtxCtrl).stubs().will(returnValue(-1)); + NResult ret = mAes.EncryptInner(static_cast(mSecrets.GetKeySecret()), + static_cast(mSecrets.GetAADSecret()), mRawData, mRawLen, mCipher, mCipherLen); + EXPECT_EQ(ret, NN_ENCRYPT_FAILED); +} + +TEST_F(TestNetSecurity, EncryptInnerSuccess) +{ + NResult result = NN_OK; + MOCKER_CPP(&HcomSsl::EvpCipherCtxNew).stubs().will(returnValue(ctx)); + MOCKER_CPP(&HcomSsl::EvpCipherCtxFree).stubs().will(returnValue(0)); + MOCKER_CPP(&AesGcm128::SetEncryptInfo).stubs().will(returnValue(result)); + MOCKER_CPP(&HcomSsl::EvpEncryptUpdate).stubs().will(returnValue(1)); + MOCKER_CPP(&HcomSsl::EvpEncryptFinalEx).stubs().will(returnValue(1)); + MOCKER_CPP(&HcomSsl::EvpCipherCtxCtrl).stubs().will(returnValue(1)); + NResult ret = mAes.EncryptInner(static_cast(mSecrets.GetKeySecret()), + static_cast(mSecrets.GetAADSecret()), mRawData, mRawLen, mCipher, mCipherLen); + EXPECT_EQ(ret, NN_OK); +} + +TEST_F(TestNetSecurity, DecryptOpenSSLSuccess) +{ + MOCKER_CPP(&AesGcm128::DecryptInner).stubs().will(returnValue(0)); + bool ret = mAes.Decrypt(mSecrets, mCipher, mCipherLen, mRawData, mRawLen); + EXPECT_EQ(ret, true); +} + +TEST_F(TestNetSecurity, DecryptOpenSSLSuccess_AES_CCM_128) +{ + mAes.SetEncryptOptions(AES_CCM_128); + + MOCKER_CPP(&AesGcm128::DecryptInner).stubs().will(returnValue(0)); + bool ret = mAes.Decrypt(mSecrets, mCipher, mCipherLen, mRawData, mRawLen); + EXPECT_EQ(ret, true); +} + +TEST_F(TestNetSecurity, DecryptOpenSSLSuccess_CHACHA20_POLY1305) +{ + mAes.SetEncryptOptions(CHACHA20_POLY1305); + + MOCKER_CPP(&AesGcm128::DecryptInner).stubs().will(returnValue(0)); + bool ret = mAes.Decrypt(mSecrets, mCipher, mCipherLen, mRawData, mRawLen); + EXPECT_EQ(ret, true); +} + +TEST_F(TestNetSecurity, DecryptWithInvalidParamFail) +{ + const void *keySecrets = nullptr; + MOCKER_CPP(&NetSecrets::GetKeySecret).stubs().will(returnValue(keySecrets)); + bool ret = mAes.Decrypt(mSecrets, mCipher, mCipherLen, mRawData, mRawLen); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetSecurity, DecryptInnerCtxNewFail) +{ + EVP_CIPHER_CTX *ctx = nullptr; + MOCKER_CPP(&HcomSsl::EvpCipherCtxNew).stubs().will(returnValue(ctx)); + NResult ret = mAes.DecryptInner(static_cast(mSecrets.GetKeySecret()), mCipher, mCipherLen, + mRawData, mRawLen); + EXPECT_EQ(ret, NN_DECRYPT_FAILED); +} + +TEST_F(TestNetSecurity, DecryptInnerSetDecryptInfoFail) +{ + NResult result = NN_DECRYPT_FAILED; + MOCKER_CPP(&HcomSsl::EvpCipherCtxNew).stubs().will(returnValue(ctx)); + MOCKER_CPP(&HcomSsl::EvpCipherCtxFree).stubs().will(returnValue(0)); + MOCKER_CPP(&AesGcm128::SetDecryptInfo).stubs().will(returnValue(result)); + NResult ret = mAes.DecryptInner(static_cast(mSecrets.GetKeySecret()), mCipher, mCipherLen, + mRawData, mRawLen); + EXPECT_EQ(ret, NN_DECRYPT_FAILED); +} + +TEST_F(TestNetSecurity, DecryptInnerUpdateFail) +{ + NResult result = NN_OK; + MOCKER_CPP(&HcomSsl::EvpCipherCtxNew).stubs().will(returnValue(ctx)); + MOCKER_CPP(&HcomSsl::EvpCipherCtxFree).stubs().will(returnValue(0)); + MOCKER_CPP(&AesGcm128::SetDecryptInfo).stubs().will(returnValue(result)); + MOCKER_CPP(&HcomSsl::EvpDecryptUpdate).stubs().will(returnValue(-1)); + NResult ret = mAes.DecryptInner(static_cast(mSecrets.GetKeySecret()), mCipher, mCipherLen, + mRawData, mRawLen); + EXPECT_EQ(ret, NN_DECRYPT_FAILED); +} + +TEST_F(TestNetSecurity, DecryptInnerCipherCtxCtrlFail) +{ + NResult result = NN_OK; + MOCKER_CPP(&HcomSsl::EvpCipherCtxNew).stubs().will(returnValue(ctx)); + MOCKER_CPP(&HcomSsl::EvpCipherCtxFree).stubs().will(returnValue(0)); + MOCKER_CPP(&AesGcm128::SetDecryptInfo).stubs().will(returnValue(result)); + MOCKER_CPP(&HcomSsl::EvpDecryptUpdate).stubs().will(returnValue(1)); + MOCKER_CPP(&HcomSsl::EvpCipherCtxCtrl).stubs().will(returnValue(-1)); + NResult ret = mAes.DecryptInner(static_cast(mSecrets.GetKeySecret()), mCipher, mCipherLen, + mRawData, mRawLen); + EXPECT_EQ(ret, NN_DECRYPT_FAILED); +} + +TEST_F(TestNetSecurity, DecryptInnerSuccess) +{ + NResult result = NN_OK; + MOCKER_CPP(&HcomSsl::EvpCipherCtxNew).stubs().will(returnValue(ctx)); + MOCKER_CPP(&HcomSsl::EvpCipherCtxFree).stubs().will(returnValue(0)); + MOCKER_CPP(&AesGcm128::SetDecryptInfo).stubs().will(returnValue(result)); + MOCKER_CPP(&HcomSsl::EvpDecryptUpdate).stubs().will(returnValue(1)); + MOCKER_CPP(&HcomSsl::EvpDecryptFinalEx).stubs().will(returnValue(1)); + MOCKER_CPP(&HcomSsl::EvpCipherCtxCtrl).stubs().will(returnValue(1)); + NResult ret = mAes.DecryptInner(static_cast(mSecrets.GetKeySecret()), mCipher, mCipherLen, + mRawData, mRawLen); + EXPECT_EQ(ret, NN_OK); +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/rdma/verbs/test_net_rdma_async_endpoint.cpp b/test/unit_test/transport/rdma/verbs/test_net_rdma_async_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..274c425453d89c46f13a51960f29b64ae0436766 --- /dev/null +++ b/test/unit_test/transport/rdma/verbs/test_net_rdma_async_endpoint.cpp @@ -0,0 +1,876 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +//#ifdef RDMA_BUILD_ENABLED +#include +#include +#include "hcom.h" +#include "net_common.h" +#include "rdma_composed_endpoint.h" +#include "net_rdma_driver_oob.h" +#include "net_security_rand.h" +#include "rdma_validation.h" +#include "net_rdma_async_endpoint.h" + +namespace ock { +namespace hcom { + +UBSHcomNetTransHeader mockAsyncMrBuf{}; +class TestNetRdmaAsyncEndpoint : public testing::Test { +public: + TestNetRdmaAsyncEndpoint(); + virtual void SetUp(void); + virtual void TearDown(void); + + std::string name; + std::string ip; + uint16_t port; + NetDriverRDMAWithOob *mDriver = nullptr; + RDMAContext *ctx = nullptr; + RDMAWorker *mWorker = nullptr; + RDMACq *cq = nullptr; + RDMAQp *qp = nullptr; + UBSHcomNetWorkerIndex mWorkerIndex; + RDMAAsyncEndPoint *ep = nullptr; + UBSHcomNetTransRequest request; + UBSHcomNetTransSglRequest sglRequest; + UBSHcomNetTransSgeIov *iov = nullptr; + NetAsyncEndpoint *NEP = nullptr; + RDMAMemoryRegionFixedBuffer *Mr = nullptr; + NetHeartbeat *mHeartBeat = nullptr; +}; + +TestNetRdmaAsyncEndpoint::TestNetRdmaAsyncEndpoint() {} + +void TestNetRdmaAsyncEndpoint::SetUp() +{ + bool useDevX = true; + RDMAGId gid; + ctx = new (std::nothrow) RDMAContext(name, useDevX, gid); + ASSERT_NE(ctx, nullptr); + + bool startOobSvr = true; + UBSHcomNetDriverProtocol protocol = RDMA; + mDriver = new (std::nothrow) NetDriverRDMAWithOob(name, startOobSvr, protocol); + mDriver->mStarted = true; + Mr = mDriver->mDriverSendMR = new (std::nothrow) RDMAMemoryRegionFixedBuffer(name, ctx, 1, 1); + ASSERT_NE(mDriver, nullptr); + + NetMemPoolFixedOptions memOptions{}; + memOptions.minBlkSize = sizeof(RDMAOpContextInfo); + + NetMemPoolFixedPtr memPool = new (std::nothrow) NetMemPoolFixed(name, memOptions); + ASSERT_EQ(memPool->Initialize(), NN_OK); + + NetMemPoolFixedPtr sglMemPool = new (std::nothrow) NetMemPoolFixed(name, memOptions); + ASSERT_EQ(sglMemPool->Initialize(), NN_OK); + + RDMAWorkerOptions options; + mWorker = new (std::nothrow) RDMAWorker(name, ctx, options, memPool, sglMemPool); + mWorker->mOpCtxInfoPool.Initialize(mWorker->mOpCtxMemPool); + mWorker->mSglCtxInfoPool.Initialize(mWorker->mSglCtxMemPool); + ASSERT_NE(mWorker, nullptr); + + cq = new (std::nothrow) RDMACq(name, ctx, false, 0); + ASSERT_NE(cq, nullptr); + + uint32_t mid = 0; + QpOptions qpOptions; + qp = new (std::nothrow) RDMAQp(name, mid, ctx, cq, qpOptions); + qp->UpContext1((uintptr_t)mWorker); + ASSERT_NE(qp, nullptr); + + uint64_t id = 0; + ep = new (std::nothrow) RDMAAsyncEndPoint(name, mWorker, qp); + ASSERT_NE(ep, nullptr); + + NEP = new (std::nothrow) NetAsyncEndpoint(id, ep, mDriver, mWorkerIndex); + NEP->mState.Set(NEP_ESTABLISHED); + NEP->mAllowedSize = NN_NO128; + NEP->mSegSize = NN_NO128; + + mHeartBeat = new (std::nothrow) NetHeartbeat(mDriver, NN_NO60, NN_NO2); + + request.lAddress = reinterpret_cast(&mWorkerIndex); + request.size = 1; + + iov = new (std::nothrow) UBSHcomNetTransSgeIov(); + sglRequest = UBSHcomNetTransSglRequest(iov, 1, 1); +} + +void TestNetRdmaAsyncEndpoint::TearDown() +{ + GlobalMockObject::verify(); + if (Mr != nullptr) { + delete Mr; + Mr = nullptr; + } + if (NEP != nullptr) { + delete NEP; + NEP = nullptr; + } + if (ctx != nullptr) { + delete ctx; + ctx = nullptr; + } + if (iov != nullptr) { + delete iov; + iov = nullptr; + } +} + +static bool MockGetFreeBuffer(uintptr_t &mrBufAddress) +{ + mrBufAddress = reinterpret_cast(&mockAsyncMrBuf); + return true; +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendSeqFailed) +{ + name = "NetAsyncEndpointRdmaPostSend"; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(returnValue(false)); + int ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_GET_BUFF_FAILED)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendSeq) +{ + name = "NetAsyncEndpointRdmaPostSend"; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer) + .stubs() + .will(invoke(MockGetFreeBuffer)); + + MOCKER_CPP(&RDMAWorker::PostSend) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)) + .then(returnValue(static_cast(RR_QP_POST_SEND_FAILED))); + + MOCKER_CPP(&AesGcm128::Encrypt, + bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + + NEP->mIsNeedEncrypt = 1; + ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendSeqTwo) +{ + name = "NetAsyncEndpointRdmaPostSend"; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&RDMAWorker::PostSend).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendInfoFailed) +{ + name = "NetAsyncEndpointRdmaPostSend"; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(returnValue(false)); + UBSHcomNetTransOpInfo OpInfo{}; + int ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_GET_BUFF_FAILED)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendOpInfo) +{ + name = "NetAsyncEndpointRdmaPostSend"; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + + MOCKER_CPP(&RDMAWorker::PostSend) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)) + .then(returnValue(static_cast(RR_QP_POST_SEND_FAILED))); + + MOCKER_CPP(&AesGcm128::Encrypt, + bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + UBSHcomNetTransOpInfo OpInfo{}; + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + + NEP->mIsNeedEncrypt = 1; + ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendOpInfoTwo) +{ + name = "NetAsyncEndpointRdmaPostSend"; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&RDMAWorker::PostSend).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + UBSHcomNetTransOpInfo OpInfo{}; + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendOpInfoWithHeaderRaw) +{ + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::RAW; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendOpInfoWithHeaderNull) +{ + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, nullptr, sizeof(extHeader)); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendOpInfoWithHeaderValidateFailed) +{ + NEP->mState.Set(NEP_BROKEN); + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + NEP->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendOpInfoWithHeaderBuffer) +{ + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(returnValue(false)); + + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_GET_BUFF_FAILED); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendOpInfoWithHeaderMemcpy) +{ + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)).then(returnValue(1)); + + NEP->mIsNeedEncrypt = 1; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_INVALID_PARAM); + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendOpInfoWithHeaderWorkerSend) +{ + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&RDMAWorker::PostSend) + .stubs() + .will(returnValue(static_cast(NN_OK))) + .then(returnValue(static_cast(RR_QP_POST_SEND_FAILED))); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_OK); + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, RR_QP_POST_SEND_FAILED); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendRawSgl) +{ + name = "NetAsyncEndpointRdmaPostSendRawSgl"; + MOCKER_CPP(&RDMAWorker::PostSendSgl).stubs().will(returnValue(1)).then(returnValue(0)); + + MOCKER_CPP(&NetDriverRDMAWithOob::ValidateMemoryRegion, + NResult(NetDriverRDMAWithOob::*)(uint64_t, uintptr_t, uint64_t)) + .stubs() + .will(returnValue(0)); + + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen).stubs().will(returnValue(static_cast(0))); + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendRawSglFail) +{ + MOCKER_CPP(&NetDriverRDMAWithOob::ValidateMemoryRegion, + NResult(NetDriverRDMAWithOob::*)(uint64_t, uintptr_t, uint64_t)) + .stubs() + .will(returnValue(0)); + + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen).stubs().will(returnValue(static_cast(0))); + NEP->mIsNeedEncrypt = true; + int ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaSetEpOption) +{ + UBSHcomEpOptions epOptions; + int ret = NEP->SetEpOption(epOptions); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaEstimatedEncryptLen) +{ + int ret = NEP->EstimatedEncryptLen(0); + EXPECT_EQ(ret, 0); + NEP->mIsNeedEncrypt = 0; + ret = NEP->EstimatedEncryptLen(1); + EXPECT_EQ(ret, 0); +} +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaEstimatedEncryptLenTwo) +{ + NEP->mIsNeedEncrypt = 1; + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen).stubs().will(returnValue(1)); + int ret = NEP->EstimatedEncryptLen(1); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaEncrypt) +{ + uint64_t cipherLen = 0; + MOCKER_CPP(&AesGcm128::Encrypt, + bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + int ret = NEP->Encrypt(reinterpret_cast(0), 0, reinterpret_cast(0), cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + NEP->mIsNeedEncrypt = 0; + ret = NEP->Encrypt(reinterpret_cast(0), 0, reinterpret_cast(0), cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaEncryptTwo) +{ + uint64_t cipherLen = 0; + MOCKER_CPP(&AesGcm128::Encrypt, + bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(true)); + + NEP->mIsNeedEncrypt = 1; + int ret = NEP->Encrypt(reinterpret_cast(0), 0, reinterpret_cast(0), cipherLen); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaEstimatedDecryptLen) +{ + NEP->mIsNeedEncrypt = 0; + int ret = NEP->EstimatedDecryptLen(0); + EXPECT_EQ(ret, 0); + + NEP->mIsNeedEncrypt = 1; + MOCKER_CPP(&AesGcm128::GetRawLen).stubs().will(returnValue(1)); + ret = NEP->EstimatedDecryptLen(0); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaDecrypt) +{ + uint64_t rawLen = 0; + MOCKER_CPP(&AesGcm128::Decrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + int ret = NEP->Decrypt(reinterpret_cast(0), 0, reinterpret_cast(0), rawLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + NEP->mIsNeedEncrypt = 0; + ret = NEP->Decrypt(reinterpret_cast(0), 0, reinterpret_cast(0), rawLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaDecryptTwo) +{ + uint64_t rawLen = 0; + MOCKER_CPP(&AesGcm128::Decrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(true)); + + NEP->mIsNeedEncrypt = 1; + int ret = NEP->Decrypt(reinterpret_cast(0), 0, reinterpret_cast(0), rawLen); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaGetRemoteUdsIdInfo) +{ + UBSHcomNetUdsIdInfo verbsIdInfo; + NEP->mState.Set(NEP_NEW); + int ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + + NEP->mState.Set(NEP_ESTABLISHED); + NEP->mDriver->mStartOobSvr = false; + ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_UDS_ID_INFO_NOT_SUPPORT)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaGetRemoteUdsIdInfoTwo) +{ + UBSHcomNetUdsIdInfo verbsIdInfo; + + NEP->mState.Set(NEP_ESTABLISHED); + NEP->mDriver->mStartOobSvr = true; + NEP->mDriver->mOptions.oobType = NET_OOB_TCP; + int ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_UDS_ID_INFO_NOT_SUPPORT)); + + NEP->mDriver->mOptions.oobType = NET_OOB_UDS; + ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaGetPeerIpPort) +{ + if (NEP->mEp->mQP != nullptr) { + delete NEP->mEp->mQP; + NEP->mEp->mQP = nullptr; + } + int ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + if (NEP->mEp != nullptr) { + delete NEP->mEp; + NEP->mEp = nullptr; + } + ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaGetPeerIpPortTwo) +{ + NEP->mEp->mQP->mPeerIpPort = "0.0.0.0"; + int ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + NEP->mEp->mQP->mPeerIpPort = "0.0.0.0:sss"; + ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaGetPeerIpPortThree) +{ + NEP->mEp->mQP->mPeerIpPort = "0.0.0.0:0"; + int ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + NEP->mEp->mQP->mPeerIpPort = "0.0.0.0:16"; + ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, true); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaGetSendQueueSize) +{ + name = "NetAsyncEndpointRdmaPostSend"; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&RDMAQp::PostSend) + .stubs() + .will(returnValue(static_cast(NN_OK))); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_OK)); + EXPECT_EQ(NEP->GetSendQueueCount(), 1); + + ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_OK)); + EXPECT_EQ(NEP->GetSendQueueCount(), 2); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendSglInlineOne) +{ + name = "NetAsyncEndpointRdmaPostSendSglInlineOne"; + NEP->mIsNeedEncrypt = true; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + + MOCKER_CPP(&AesGcm128::Encrypt, + bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + MOCKER_CPP(&RDMAQp::PostSend) + .stubs() + .will(returnValue(static_cast(NN_OK))); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + UBSHcomNetTransOpInfo opInfo{}; + int ret = NEP->PostSendSglInline(0, request, opInfo); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendSglInlineTwo) +{ + name = "NetAsyncEndpointRdmaPostSendSglInlineTwo"; + NEP->mIsNeedEncrypt = true; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + + MOCKER_CPP(&AesGcm128::Encrypt, + bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(true)); + + MOCKER_CPP(&RDMAQp::PostSend) + .stubs() + .will(returnValue(static_cast(NN_OK))); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + UBSHcomNetTransOpInfo opInfo{}; + int ret = NEP->PostSendSglInline(0, request, opInfo); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendSglInlineThree) +{ + name = "NetAsyncEndpointRdmaPostSendSglInlineThree"; + NEP->mIsNeedEncrypt = false; + MOCKER_CPP(&RDMAWorker::PostSendSglInline) + .stubs() + .will(returnValue(0)); + + UBSHcomNetTransOpInfo opInfo{}; + int ret = NEP->PostSendSglInline(0, request, opInfo); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendSglInlineFour) +{ + name = "NetAsyncEndpointRdmaPostSendSglInlineFour"; + NEP->mIsNeedEncrypt = false; + MOCKER_CPP(&RDMAWorker::PostSendSglInline) + .stubs() + .will(returnValue(1)); + + UBSHcomNetTransOpInfo opInfo{}; + int ret = NEP->PostSendSglInline(0, request, opInfo); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendSglInlineFive) +{ + name = "NetAsyncEndpointRdmaPostSendSglInlineFive"; + NEP->mIsNeedEncrypt = false; + MOCKER_CPP(&RDMAWorker::PostSendSglInline) + .stubs() + .will(returnValue(static_cast(RR_QP_POST_SEND_WR_FULL))) + .then(returnValue(0)); + + UBSHcomNetTransOpInfo opInfo{}; + int ret = NEP->PostSendSglInline(0, request, opInfo); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestNetRdmaAsyncEndpoint, NetAsyncEndpointRdmaPostSendSglInlineValidateFail) +{ + NEP->mState.Set(NEP_BROKEN); + NEP->mIsNeedEncrypt = false; + UBSHcomNetTransOpInfo opInfo{}; + int ret = NEP->PostSendSglInline(0, request, opInfo); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + NEP->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetRdmaAsyncEndpoint, QpInitializeFail) +{ + int ret; + + MOCKER_CPP(&RDMAQp::CreateIbvQp).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&RDMAQp::CreateQpMr).stubs().will(returnValue(1)); + MOCKER_CPP(HcomIbv::DestroyQp).stubs().will(returnValue(0)); + + ret = qp->Initialize(); + EXPECT_NE(ret, 0); + + ret = qp->Initialize(); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetRdmaAsyncEndpoint, TestCreateFailed) +{ + auto ret = RDMAAsyncEndPoint::Create("", nullptr, ep); + EXPECT_EQ(ret, static_cast(RR_PARAM_INVALID)); + + MOCKER_CPP(RDMAWorker::CreateQP).stubs().will(returnValue(static_cast(RR_PARAM_INVALID))); + ret = RDMAAsyncEndPoint::Create("name", mWorker, ep); + EXPECT_EQ(ret, static_cast(RR_PARAM_INVALID)); + + QpOptions option {}; + RDMASyncEndpoint *syncEp = nullptr; + ret = RDMASyncEndpoint::Create("name", nullptr, EVENT_POLLING, 0, option, syncEp); + EXPECT_EQ(ret, static_cast(RR_PARAM_INVALID)); +} + +TEST_F(TestNetRdmaAsyncEndpoint, TestRDMAEndpointFunction) +{ + RDMAAsyncEndPoint asyncEp {"name", nullptr, nullptr}; + RDMAQpExchangeInfo info {}; + EXPECT_EQ(asyncEp.GetExchangeInfo(info), static_cast(RR_QP_NOT_INITIALIZED)); + EXPECT_EQ(asyncEp.ChangeToReady(info), static_cast(RR_EP_NOT_INITIALIZED)); + + asyncEp.mQP = qp; + + MOCKER_CPP(RDMAQp::ReturnBuffer).stubs().will(returnValue(false)); + EXPECT_NO_FATAL_FAILURE(asyncEp.ReturnBuffer(0)); + asyncEp.mQP = nullptr; + + EXPECT_EQ(asyncEp.Initialize(), static_cast(RR_EP_NOT_INITIALIZED)); + asyncEp.mQP = qp; + MOCKER_CPP(RDMAQp::Initialize).stubs().will(returnValue(static_cast(RR_PARAM_INVALID))); + EXPECT_EQ(asyncEp.Initialize(), static_cast(RR_PARAM_INVALID)); + asyncEp.mQP = nullptr; +} + +TEST_F(TestNetRdmaAsyncEndpoint, TestRDMASizeValidateFail) +{ + request.size = NN_NO2; + uint32_t allowedSize = NN_NO1; + AesGcm128 mAes; + EXPECT_EQ(SizeValidate(request, allowedSize, false, mAes), NN_TWO_SIDE_MESSAGE_TOO_LARGE); +} + +TEST_F(TestNetRdmaAsyncEndpoint, TestRDMAPostSendValidationFail) +{ + UBSHcomNetAtomicState state{NEP_ESTABLISHED}; + request.size = NN_NO2; + uint32_t allowedSize = NN_NO1; + AesGcm128 mAes; + EXPECT_EQ(PostSendValidation(state, 1, mDriver, 1, request, allowedSize, false, mAes), + NN_TWO_SIDE_MESSAGE_TOO_LARGE); + + allowedSize = NN_NO3; + EXPECT_EQ(PostSendValidation(state, 1, mDriver, MAX_OPCODE + 1, request, allowedSize, false, mAes), + NN_INVALID_OPCODE); +} + +TEST_F(TestNetRdmaAsyncEndpoint, TestRDMAPostSendRawValidationFail) +{ + UBSHcomNetAtomicState state{NEP_BROKEN}; + uint32_t allowedSize = NN_NO1; + AesGcm128 mAes; + EXPECT_EQ(PostSendRawValidation(state, 1, mDriver, 1, request, allowedSize, false, mAes), + NN_EP_NOT_ESTABLISHED); + + state.Set(NEP_ESTABLISHED); + request.size = NN_NO2; + allowedSize = NN_NO1; + EXPECT_EQ(PostSendRawValidation(state, 1, mDriver, 1, request, allowedSize, false, mAes), + NN_TWO_SIDE_MESSAGE_TOO_LARGE); + + allowedSize = NN_NO3; + EXPECT_EQ(PostSendRawValidation(state, 1, mDriver, 0, request, allowedSize, false, mAes), + NN_PARAM_INVALID); +} + +TEST_F(TestNetRdmaAsyncEndpoint, TestRDMAReadWriteValidationFail) +{ + UBSHcomNetAtomicState state{NEP_BROKEN}; + EXPECT_EQ(ReadWriteValidation(state, 1, mDriver, request), NN_EP_NOT_ESTABLISHED); + + state.Set(NEP_ESTABLISHED); + request.size = NET_SGE_MAX_SIZE + 1; + EXPECT_EQ(ReadWriteValidation(state, 1, mDriver, request), NN_PARAM_INVALID); +} + +TEST_F(TestNetRdmaAsyncEndpoint, TestRDMASglValidationFail) +{ + size_t totalSize = 0; + sglRequest.iov[0].size = NET_SGE_MAX_SIZE + 1; + EXPECT_EQ(SglValidation(sglRequest, totalSize, mDriver), NN_PARAM_INVALID); +} + +TEST_F(TestNetRdmaAsyncEndpoint, TestRDMAPostSendSglValidationFail) +{ + size_t totalSize = 0; + UBSHcomNetAtomicState state{NEP_BROKEN}; + uint32_t allowedSize = NN_NO1; + AesGcm128 mAes; + EXPECT_EQ(PostSendSglValidation(state, 1, mDriver, 1, sglRequest, allowedSize, totalSize, false, mAes), + NN_EP_NOT_ESTABLISHED); + + state.Set(NEP_ESTABLISHED); + EXPECT_EQ(PostSendSglValidation(state, 1, mDriver, 0, sglRequest, allowedSize, totalSize, false, mAes), + NN_PARAM_INVALID); + + allowedSize = 0; + MOCKER_CPP(&NetDriverRDMAWithOob::ValidateMemoryRegion, + NResult(NetDriverRDMAWithOob::*)(uint64_t, uintptr_t, uint64_t)) + .stubs() + .will(returnValue(0)); + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen).stubs().will(returnValue(static_cast(1))); + EXPECT_EQ(PostSendSglValidation(state, 1, mDriver, 1, sglRequest, allowedSize, totalSize, true, mAes), + NN_TWO_SIDE_MESSAGE_TOO_LARGE); +} + +TEST_F(TestNetRdmaAsyncEndpoint, TestRDMAEncryptRawSglSuccess) +{ + UBSHcomNetTransRequest tlsReq {}; + uintptr_t mrBufAddress = 0; + size_t size = 0; + AesGcm128 mAes; + NetSecrets mSecrets; + + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs().will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&AesGcm128::Encrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs().will(returnValue(true)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::ReturnBuffer).stubs().will(returnValue(true)); + + EXPECT_EQ(EncryptRawSgl(tlsReq, mrBufAddress, size, mAes, mDriver, sglRequest, mSecrets), NN_OK); +} + +TEST_F(TestNetRdmaAsyncEndpoint, TestRDMAEncryptRawSglGetBufferFail) +{ + UBSHcomNetTransRequest tlsReq {}; + uintptr_t mrBufAddress = 0; + size_t size = 0; + AesGcm128 mAes; + NetSecrets mSecrets; + + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs().will(returnValue(false)); + + EXPECT_EQ(EncryptRawSgl(tlsReq, mrBufAddress, size, mAes, mDriver, sglRequest, mSecrets), NN_GET_BUFF_FAILED); +} + +TEST_F(TestNetRdmaAsyncEndpoint, TestRDMAEncryptRawSglMemCpyFail) +{ + UBSHcomNetTransRequest tlsReq {}; + uintptr_t mrBufAddress = 0; + size_t size = 0; + AesGcm128 mAes; + NetSecrets mSecrets; + + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs().will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::ReturnBuffer).stubs().will(returnValue(true)); + + EXPECT_EQ(EncryptRawSgl(tlsReq, mrBufAddress, size, mAes, mDriver, sglRequest, mSecrets), NN_INVALID_PARAM); +} + +TEST_F(TestNetRdmaAsyncEndpoint, TestRDMAEncryptRawSglGetSecondBufferFail) +{ + UBSHcomNetTransRequest tlsReq {}; + uintptr_t mrBufAddress = 0; + size_t size = 0; + AesGcm128 mAes; + NetSecrets mSecrets; + + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs().will(returnValue(true)).then(returnValue(false)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::ReturnBuffer).stubs().will(returnValue(true)); + + EXPECT_EQ(EncryptRawSgl(tlsReq, mrBufAddress, size, mAes, mDriver, sglRequest, mSecrets), NN_GET_BUFF_FAILED); +} + +TEST_F(TestNetRdmaAsyncEndpoint, TestRDMAEncryptRawSglEncryptFail) +{ + UBSHcomNetTransRequest tlsReq {}; + uintptr_t mrBufAddress = 0; + size_t size = 0; + AesGcm128 mAes; + NetSecrets mSecrets; + + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs().will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&AesGcm128::Encrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs().will(returnValue(false)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::ReturnBuffer).stubs().will(returnValue(true)); + + EXPECT_EQ(EncryptRawSgl(tlsReq, mrBufAddress, size, mAes, mDriver, sglRequest, mSecrets), NN_ENCRYPT_FAILED); +} +} +} +//#endif diff --git a/test/unit_test/transport/rdma/verbs/test_net_rdma_driver_oob.cpp b/test/unit_test/transport/rdma/verbs/test_net_rdma_driver_oob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9daf2983dd5911973b7b63198d50877de0e1f9bf --- /dev/null +++ b/test/unit_test/transport/rdma/verbs/test_net_rdma_driver_oob.cpp @@ -0,0 +1,570 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include + +#include "net_monotonic.h" +#include "net_oob_ssl.h" +#include "net_rdma_sync_endpoint.h" +#include "net_rdma_async_endpoint.h" +#include "rdma_mr_dm_buf.h" +#include "rdma_mr_fixed_buf.h" +#include "net_rdma_driver_oob.h" +#include "net_oob_secure.h" + +namespace ock { +namespace hcom { + +class TestNetRdmaDriverOob : public testing::Test { +public: + TestNetRdmaDriverOob(); + virtual void SetUp(void); + virtual void TearDown(void); + const std::string name = "TestNetRdmaDriverOob"; + NetDriverRDMAWithOob *testDriver = nullptr; +}; + +TestNetRdmaDriverOob::TestNetRdmaDriverOob() {} + +void TestNetRdmaDriverOob::SetUp() +{ + bool startOobSvr = true; + UBSHcomNetDriverProtocol protocol = RDMA; + testDriver = new (std::nothrow) NetDriverRDMAWithOob(name, startOobSvr, protocol); + ASSERT_NE(testDriver, nullptr); +} + +void TestNetRdmaDriverOob::TearDown() +{ + if (testDriver != nullptr) { + delete testDriver; + testDriver = nullptr; + } + + GlobalMockObject::verify(); +} + +OOBTCPConnection *newConn = nullptr; +NResult MockConnect(const std::string &ip, uint32_t port, OOBTCPConnection *&conn) +{ + conn = newConn; + conn->SetIpAndPort("xx.xx", 1); + return NN_OK; +} + +NResult MockReceiveTest(void *&buf, uint32_t size) +{ + ConnectHeader *bufHeader = reinterpret_cast(buf); + bufHeader->devIndex = NN_NO4; + return NN_OK; +} + +TEST_F(TestNetRdmaDriverOob, TestConnectMultiRailFail) +{ + std::string oobIp = "127.0.0.1"; + uint16_t oobPort = 1; + testDriver->mOptions.enableMultiRail = true; + testDriver->mEnableTls = false; + std::string payload = "Test"; + UBSHcomNetEndpointPtr outEp; + uint32_t flags = 0; + uint8_t serverGrpNo = 1; + uint8_t clientGrpNo = 1; + uint64_t ctx = 1; + int fd = -1; + testDriver->mInited = true; + testDriver->mStarted = true; + newConn = new (std::nothrow) OOBTCPConnection(fd); + newConn->IncreaseRef(); + + MOCKER_CPP(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER_CPP_VIRTUAL(*newConn, &OOBTCPConnection::Send).stubs().will(returnValue(0)); + MOCKER_CPP_VIRTUAL(*newConn, &OOBTCPConnection::Receive).stubs().will(invoke(MockReceiveTest)); + + NResult res = testDriver->Connect(oobIp, oobPort, payload, outEp, flags, serverGrpNo, clientGrpNo, ctx); + EXPECT_EQ(res, NN_ERROR); + + newConn->DecreaseRef(); +} + +TEST_F(TestNetRdmaDriverOob, DestroyEpByPortNum) +{ + UBSHcomNetWorkerIndex index {}; + RDMAAsyncEndPoint *ep1 = (RDMAAsyncEndPoint *)malloc(sizeof(RDMAAsyncEndPoint)); + RDMAQp *qp1 = (RDMAQp *)malloc(sizeof(RDMAQp)); + RDMAContext *context1 = (RDMAContext *)malloc(sizeof(RDMAContext)); + ep1->mQP = qp1; + ep1->mQP->mRDMAContext = context1; + ep1->mQP->mRDMAContext->mPortNumber = 0; + UBSHcomNetEndpointPtr fakeEp1 = new (std::nothrow) NetAsyncEndpoint(0, ep1, testDriver, index); + testDriver->IncreaseRef(); + testDriver->mEndPoints.emplace(fakeEp1->mId, fakeEp1); + + RDMAAsyncEndPoint *ep2 = (RDMAAsyncEndPoint *)malloc(sizeof(RDMAAsyncEndPoint)); + RDMAQp *qp2 = (RDMAQp *)malloc(sizeof(RDMAQp)); + RDMAContext *context2 = (RDMAContext *)malloc(sizeof(RDMAContext)); + ep2->mQP = qp2; + ep2->mQP->mRDMAContext = context2; + ep2->mQP->mRDMAContext->mPortNumber = 1; + UBSHcomNetEndpointPtr fakeEp2 = new (std::nothrow) NetAsyncEndpoint(1, ep2, testDriver, index); + testDriver->IncreaseRef(); + testDriver->mEndPoints.emplace(fakeEp2->mId, fakeEp2); + + MOCKER_CPP(&NetDriverRDMAWithOob::ProcessEpError).stubs().will(ignoreReturnValue()); + EXPECT_NO_FATAL_FAILURE(testDriver->DestroyEpByPortNum(1)); + free(context1); + free(context2); + free(qp1); + free(qp2); + free(ep1); + free(ep2); +} + +TEST_F(TestNetRdmaDriverOob, HandlePortDown) +{ + RDMAWorker *fakeWorker = (RDMAWorker *)malloc(sizeof(RDMAWorker)); + RDMAContext *context = (RDMAContext *)malloc(sizeof(RDMAContext)); + fakeWorker->mRDMAContext = context; + fakeWorker->mRDMAContext->mPortNumber = 1; + testDriver->mWorkers.emplace_back(fakeWorker); + MOCKER_CPP(&RDMAWorker::Stop).stubs().will(returnValue(0)); + EXPECT_NO_FATAL_FAILURE(testDriver->HandlePortDown(1)); + free(context); + free(fakeWorker); +} + +TEST_F(TestNetRdmaDriverOob, HandlePortActive) +{ + RDMAWorker *fakeWorker = (RDMAWorker *)malloc(sizeof(RDMAWorker)); + RDMAContext *context = (RDMAContext *)malloc(sizeof(RDMAContext)); + fakeWorker->mRDMAContext = context; + fakeWorker->mRDMAContext->mPortNumber = 1; + testDriver->mWorkers.emplace_back(fakeWorker); + MOCKER_CPP(&RDMAWorker::Start).stubs().will(returnValue(0)); + EXPECT_NO_FATAL_FAILURE(testDriver->HandlePortActive(1)); + free(context); + free(fakeWorker); +} + +TEST_F(TestNetRdmaDriverOob, DestroyEpInWorker) +{ + UBSHcomNetWorkerIndex index{}; + RDMAWorker *fakeWorker = (RDMAWorker *)malloc(sizeof(RDMAWorker)); + RDMAContext *context1 = (RDMAContext *)malloc(sizeof(RDMAContext)); + context1->mPortNumber = 0; + fakeWorker->mRDMAContext = context1; + testDriver->mWorkers.emplace_back(fakeWorker); + + RDMAAsyncEndPoint *ep1 = (RDMAAsyncEndPoint *)malloc(sizeof(RDMAAsyncEndPoint)); + RDMAQp *qp1 = (RDMAQp *)malloc(sizeof(RDMAQp)); + ep1->mQP = qp1; + ep1->mQP->mRDMAContext = context1; + ep1->mWorker = fakeWorker; + UBSHcomNetEndpointPtr fakeEp1 = new (std::nothrow) NetAsyncEndpoint(0, ep1, testDriver, fakeWorker->mIndex); + testDriver->IncreaseRef(); + testDriver->mEndPoints.emplace(fakeEp1->mId, fakeEp1); + + RDMAAsyncEndPoint *ep2 = (RDMAAsyncEndPoint *)malloc(sizeof(RDMAAsyncEndPoint)); + RDMAQp *qp2 = (RDMAQp *)malloc(sizeof(RDMAQp)); + RDMAContext *context2 = (RDMAContext *)malloc(sizeof(RDMAContext)); + ep2->mQP = qp2; + ep2->mQP->mRDMAContext = context2; + ep2->mQP->mRDMAContext->mPortNumber = 1; + UBSHcomNetEndpointPtr fakeEp2 = new (std::nothrow) NetAsyncEndpoint(1, ep2, testDriver, index); + testDriver->IncreaseRef(); + testDriver->mEndPoints.emplace(fakeEp2->mId, fakeEp2); + + MOCKER_CPP(&NetDriverRDMAWithOob::ProcessEpError).stubs().will(ignoreReturnValue()); + EXPECT_NO_FATAL_FAILURE(testDriver->DestroyEpInWorker(fakeWorker)); + free(context1); + free(context2); + free(qp1); + free(qp2); + free(ep1); + free(ep2); + free(fakeWorker); +} + +TEST_F(TestNetRdmaDriverOob, HandleCqEventParamErr) +{ + ibv_async_event event{}; + ibv_cq *cq = (ibv_cq *)malloc(sizeof(ibv_cq)); + event.element.cq = cq; + event.element.cq->cq_context = nullptr; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleCqEvent(&event)); + + RDMAWorker *fakeWorker = (RDMAWorker *)malloc(sizeof(RDMAWorker)); + event.element.cq->cq_context = (void *)fakeWorker; + MOCKER_CPP(&RDMAWorker::Stop).stubs().will(returnValue(1)); + EXPECT_NO_FATAL_FAILURE(testDriver->HandleCqEvent(&event)); + free(cq); + free(fakeWorker); +} + +TEST_F(TestNetRdmaDriverOob, HandleCqEvent) +{ + RDMAWorker *fakeWorker = (RDMAWorker *)malloc(sizeof(RDMAWorker)); + ibv_async_event event{}; + ibv_cq *cq = (ibv_cq *)malloc(sizeof(ibv_cq)); + event.element.cq = cq; + event.element.cq->cq_context = (void *)fakeWorker; + + MOCKER_CPP(&RDMAWorker::Stop).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverRDMAWithOob::DestroyEpInWorker).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&RDMAWorker::ReInitializeCQ).stubs() + .will(returnValue(1)) + .then(returnValue(0)); + MOCKER_CPP(&RDMAWorker::Start).stubs().will(returnValue(1)); + EXPECT_NO_FATAL_FAILURE(testDriver->HandleCqEvent(&event)); + EXPECT_NO_FATAL_FAILURE(testDriver->HandleCqEvent(&event)); + free(cq); + free(fakeWorker); +} + +TEST_F(TestNetRdmaDriverOob, HandleAsyncEvent) +{ + RDMAContext *ctx = nullptr; + RDMACq *cq = nullptr; + ibv_async_event event {}; + ibv_qp *qp = (ibv_qp *)malloc(sizeof(ibv_qp)); + RDMAQp *rdmaQp = new RDMAQp("rdma qp", 0, ctx, cq); + qp->qp_context = (void *)rdmaQp; + event.element.qp = qp; + char *name = "qp"; + + MOCKER_CPP(&NetDriverRDMAWithOob::HandleCqEvent).expects(once()); + event.event_type = IBV_EVENT_CQ_ERR; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + + MOCKER_CPP(&NetDriverRDMAWithOob::HandlePortDown).expects(once()); + event.event_type = IBV_EVENT_PORT_ERR; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + + MOCKER_CPP(&NetDriverRDMAWithOob::HandlePortActive).expects(once()); + event.event_type = IBV_EVENT_PORT_ACTIVE; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + + MOCKER_CPP(&RDMAContext::UpdateGid).expects(once()).will(returnValue(0)); + event.event_type = IBV_EVENT_GID_CHANGE; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + + event.event_type = IBV_EVENT_QP_FATAL; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + event.event_type = IBV_EVENT_QP_REQ_ERR; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + event.event_type = IBV_EVENT_QP_ACCESS_ERR; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + event.event_type = IBV_EVENT_COMM_EST; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + event.event_type = IBV_EVENT_SQ_DRAINED; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + event.event_type = IBV_EVENT_PATH_MIG; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + event.event_type = IBV_EVENT_PATH_MIG_ERR; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + event.event_type = IBV_EVENT_QP_LAST_WQE_REACHED; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + event.event_type = IBV_EVENT_SRQ_ERR; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + event.event_type = IBV_EVENT_SRQ_LIMIT_REACHED; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + event.event_type = IBV_EVENT_LID_CHANGE; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + event.event_type = IBV_EVENT_PKEY_CHANGE; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + event.event_type = IBV_EVENT_SM_CHANGE; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + event.event_type = IBV_EVENT_CLIENT_REREGISTER; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + event.event_type = IBV_EVENT_DEVICE_FATAL; + EXPECT_NO_FATAL_FAILURE(testDriver->HandleAsyncEvent(&event)); + + delete(rdmaQp); + free(qp); +} + +int MockRequestPostedHandler(const UBSHcomNetRequestContext &) +{ + return 0; +} + +TEST_F(TestNetRdmaDriverOob, SendFinishedCB) +{ + RDMAOpContextInfo ctx {}; + ctx.opType = RDMAOpContextInfo::SEND_RAW_SGL; + ctx.upCtxSize = 1; + RDMAQp *qp = (RDMAQp *)malloc(sizeof(RDMAQp)); + ctx.qp = qp; + UBSHcomNetEndpoint *ep = (UBSHcomNetEndpoint *)malloc((sizeof(UBSHcomNetEndpoint))); + ctx.qp->mUpContext = (uintptr_t)ep; + RDMASgeCtxInfo sgeCtx {}; + RDMASglContextInfo sglCtx {}; + sgeCtx.ctx = &sglCtx; + memcpy_s(ctx.upCtx, sizeof(RDMASgeCtxInfo), &sgeCtx, sizeof(RDMASgeCtxInfo)); + ctx.upCtxSize = sizeof(RDMASgeCtxInfo); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + MOCKER_CPP(&RDMAQp::ReturnPostSendWr).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&RDMAWorker::ReturnSglContextInfo).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&RDMAWorker::ReturnOpContextInfo).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::ReturnBuffer).stubs().will(returnValue(true)); + testDriver->mRequestPostedHandler = MockRequestPostedHandler; + testDriver->mEnableTls = true; + + int result = testDriver->SendFinishedCB(&ctx); + EXPECT_EQ(result, NN_OK); + free(ep); + free(qp); +} + +TEST_F(TestNetRdmaDriverOob, ProcessErrorSendFinished) +{ + RDMAOpContextInfo ctx {}; + EXPECT_NO_FATAL_FAILURE(testDriver->ProcessErrorSendFinished(&ctx)); + + RDMAQp *qp = (RDMAQp *)malloc(sizeof(RDMAQp)); + qp->mUpContext1 = 1; + ctx.qp = qp; + EXPECT_NO_FATAL_FAILURE(testDriver->ProcessErrorSendFinished(&ctx)); + free(qp); +} + +TEST_F(TestNetRdmaDriverOob, TestRDMAMemoryRegionCreate) +{ + int ret; + std::string name = "mr"; + RDMAContext *ctx = nullptr; + uint64_t size = NN_NO64; + RDMAMemoryRegion *buf = nullptr; + uintptr_t address = 0; + + ret = RDMAMemoryRegion::Create(name, ctx, size, buf); + EXPECT_EQ(ret, RR_PARAM_INVALID); + + ret = RDMAMemoryRegion::Create(name, ctx, address, size, buf); + EXPECT_EQ(ret, RR_PARAM_INVALID); +} + +TEST_F(TestNetRdmaDriverOob, TestRDMAFixBufferCreate) +{ + int ret; + std::string name = "mr"; + RDMAContext *ctx = nullptr; + RDMAMemoryRegionFixedBuffer *mr = nullptr; + + mr = new RDMAMemoryRegionFixedBuffer(name, ctx, 0, 0); + EXPECT_NE(mr->Initialize(), RR_OK); +} + +TEST_F(TestNetRdmaDriverOob, TestDriverInitializeFail) +{ + int ret; + UBSHcomNetDriverOptions option{}; + + testDriver->mInited = false; + option.enableTls = false; + + MOCKER_CPP(&NetDriverRDMA::CreateContext).stubs().will(returnValue(0)); + MOCKER_CPP(&RDMAContext::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverRDMA::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverRDMA::CreateWorkerResource).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverRDMA::CreateWorkers).stubs().will(returnValue(1)); + + ret = testDriver->Initialize(option); + EXPECT_NE(ret, 0); + + ret = testDriver->Initialize(option); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetRdmaDriverOob, TestDriverInitializeFail2) +{ + int ret; + UBSHcomNetDriverOptions option{}; + + testDriver->mInited = false; + option.enableTls = false; + + MOCKER_CPP(&NetDriverRDMA::CreateContext).stubs().will(returnValue(0)); + MOCKER_CPP(&RDMAContext::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverRDMA::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverRDMA::CreateWorkerResource).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverRDMA::CreateWorkers).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverRDMA::CreateClientLB).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverRDMA::CreateListeners).stubs().will(returnValue(1)); + + ret = testDriver->Initialize(option); + EXPECT_NE(ret, 0); + + ret = testDriver->Initialize(option); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetRdmaDriverOob, CreateWorkerResourceFail) +{ + int ret; + + MOCKER_CPP(&NetDriverRDMA::CreateSendMr).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverRDMA::CreateOpCtxMemPool).stubs().will(returnValue(1)); + + ret = testDriver->CreateWorkerResource(); + EXPECT_NE(ret, 0); + + ret = testDriver->CreateWorkerResource(); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetRdmaDriverOob, CreateWorkerResourceFail2) +{ + int ret; + + MOCKER_CPP(&NetDriverRDMA::CreateSendMr).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverRDMA::CreateOpCtxMemPool).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverRDMA::CreateSglCtxMemPool).stubs().will(returnValue(1)); + + ret = testDriver->CreateWorkerResource(); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetRdmaDriverOob, Connect) +{ + int ret; + testDriver->mInited = true; + testDriver->mStarted = true; + std::string badUrl = "unknown://127.0.0.1"; + std::string serverUrl = "tcp://127.0.0.1:9981"; + std::string payload{}; + UBSHcomNetEndpointPtr outEp; + MOCKER_CPP(&NetDriverRDMAWithOob::Connect, + NResult(NetDriverRDMAWithOob::*)(const OOBTCPClientPtr &, const std::string &, UBSHcomNetEndpointPtr &, uint8_t, + uint8_t, uint64_t)).stubs().will(returnValue(1)); + MOCKER_CPP(&NetDriverRDMAWithOob::ConnectSyncEp).stubs().will(returnValue(0)); + ret = testDriver->Connect(badUrl, payload, outEp, 0, 0, 0, 0); + EXPECT_EQ(ret, NN_INVALID_PARAM); + + testDriver->mEnableTls = true; + ret = testDriver->Connect(serverUrl, payload, outEp, 0, 0, 0, 0); + EXPECT_EQ(ret, 1); + + testDriver->mEnableTls = false; + ret = testDriver->Connect(serverUrl, payload, outEp, NET_EP_SELF_POLLING, 0, 0, 0); + EXPECT_EQ(ret, 0); +} + +static ssize_t MockSend(int socket, void const *buf, size_t size, int flags) +{ + return size; +} + +ConnectResp mockResp; +static ssize_t MockRecv(int socket, void *buf, size_t size, int flags) +{ + switch (size) { + case sizeof(ConnectHeader): { + ConnectHeader *tmp = reinterpret_cast(buf); + tmp->magic = 1; + tmp->protocol = UBSHcomNetDriverProtocol::RDMA; + break; + } + case sizeof(uint32_t): { + uint32_t *tmp = reinterpret_cast(buf); + *tmp = 1; + break; + } + case sizeof(ConnRespWithUId): { + ConnRespWithUId *tmp = reinterpret_cast(buf); + tmp->connResp = mockResp; + break; + } + default: + break; + } + + return size; +} + +TEST_F(TestNetRdmaDriverOob, Connect2) +{ + std::string ip("127.0.0.1"); + std::string payload{}; + UBSHcomNetEndpointPtr outEp; + OOBTCPClientPtr client = new (std::nothrow) OOBTCPClient(ip, 1); + ASSERT_NE(client.Get(), nullptr); + client->mOobType = NET_OOB_UDS; + testDriver->mOptions.enableMultiRail = true; + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, int &)).stubs().will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecv)); + MOCKER(::send).stubs().will(invoke(MockSend)); + EXPECT_EQ(testDriver->ConnectSyncEp(client, payload, outEp, 0, 0, 0), RR_PARAM_INVALID); +} + +TEST_F(TestNetRdmaDriverOob, Connect3) +{ + std::string ip("127.0.0.1"); + std::string payload{}; + UBSHcomNetEndpointPtr outEp; + OOBTCPClientPtr client = new (std::nothrow) OOBTCPClient(ip, 1); + ASSERT_NE(client.Get(), nullptr); + client->mOobType = NET_OOB_UDS; + RDMASyncEndpoint *rep = new (std::nothrow) RDMASyncEndpoint(ip, nullptr, BUSY_POLLING, nullptr, nullptr, 0); + ASSERT_NE(rep, nullptr); + testDriver->mOptions.enableMultiRail = true; + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, int &)).stubs().will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecv)); + MOCKER(::send).stubs().will(invoke(MockSend)); + MOCKER(RDMASyncEndpoint::Create).stubs() + .with(any(), any(), any(), any(), any(), outBound(rep)) + .will(returnValue(0)); + EXPECT_EQ(testDriver->ConnectSyncEp(client, payload, outEp, 0, 0, 0), RR_EP_NOT_INITIALIZED); +} + +TEST_F(TestNetRdmaDriverOob, Connect4) +{ + std::string ip("127.0.0.1"); + std::string payload{}; + UBSHcomNetEndpointPtr outEp; + OOBTCPClientPtr client = new (std::nothrow) OOBTCPClient(ip, 1); + ASSERT_NE(client.Get(), nullptr); + client->mOobType = NET_OOB_UDS; + RDMASyncEndpoint *rep = new (std::nothrow) RDMASyncEndpoint(ip, nullptr, BUSY_POLLING, nullptr, nullptr, 0); + ASSERT_NE(rep, nullptr); + testDriver->mOptions.enableMultiRail = true; + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, int &)).stubs().will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecv)); + MOCKER(::send).stubs().will(invoke(MockSend)); + MOCKER(RDMASyncEndpoint::Create).stubs() + .with(any(), any(), any(), any(), any(), outBound(rep)) + .will(returnValue(0)); + MOCKER_CPP_VIRTUAL(*rep, &RDMASyncEndpoint::Initialize).stubs().will(returnValue(0)); + rep->IncreaseRef(); + mockResp = MAGIC_MISMATCH; + EXPECT_EQ(testDriver->ConnectSyncEp(client, payload, outEp, 0, 0, 0), NN_CONNECT_REFUSED); + rep->IncreaseRef(); + mockResp = WORKER_GRPNO_MISMATCH; + EXPECT_EQ(testDriver->ConnectSyncEp(client, payload, outEp, 0, 0, 0), NN_CONNECT_REFUSED); + rep->IncreaseRef(); + mockResp = PROTOCOL_MISMATCH; + EXPECT_EQ(testDriver->ConnectSyncEp(client, payload, outEp, 0, 0, 0), NN_CONNECT_PROTOCOL_MISMATCH); + rep->IncreaseRef(); + mockResp = SERVER_INTERNAL_ERROR; + EXPECT_EQ(testDriver->ConnectSyncEp(client, payload, outEp, 0, 0, 0), NN_ERROR); + rep->IncreaseRef(); + mockResp = VERSION_MISMATCH; + EXPECT_EQ(testDriver->ConnectSyncEp(client, payload, outEp, 0, 0, 0), NN_CONNECT_REFUSED); + mockResp = TLS_VERSION_MISMATCH; + EXPECT_EQ(testDriver->ConnectSyncEp(client, payload, outEp, 0, 0, 0), NN_CONNECT_REFUSED); +} + +} +} \ No newline at end of file diff --git a/test/unit_test/transport/rdma/verbs/test_net_rdma_driver_oob1.cpp b/test/unit_test/transport/rdma/verbs/test_net_rdma_driver_oob1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..11b662476a42078981e61248e57bdea58c603972 --- /dev/null +++ b/test/unit_test/transport/rdma/verbs/test_net_rdma_driver_oob1.cpp @@ -0,0 +1,118 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include + +#include "net_monotonic.h" +#include "net_oob_ssl.h" +#include "net_rdma_sync_endpoint.h" +#include "net_rdma_async_endpoint.h" +#include "rdma_mr_dm_buf.h" +#include "rdma_mr_fixed_buf.h" +#include "net_rdma_driver_oob.h" +#include "net_oob_secure.h" + +namespace ock { +namespace hcom { + +class TestNetRdmaDriverOob1 : public testing::Test { +public: + TestNetRdmaDriverOob1(); + virtual void SetUp(void); + virtual void TearDown(void); + const std::string name = "TestNetRdmaDriverOob1"; + NetDriverRDMAWithOob *testDriver1 = nullptr; +}; + +TestNetRdmaDriverOob1::TestNetRdmaDriverOob1() {} + +void TestNetRdmaDriverOob1::SetUp() +{ + bool startOobSvr = true; + UBSHcomNetDriverProtocol protocol = RDMA; + testDriver1 = new (std::nothrow) NetDriverRDMAWithOob(name, startOobSvr, protocol); + ASSERT_NE(testDriver1, nullptr); +} + +void TestNetRdmaDriverOob1::TearDown() +{ + if (testDriver1 != nullptr) { + delete testDriver1; + testDriver1 = nullptr; + } + + GlobalMockObject::verify(); +} + +TEST_F(TestNetRdmaDriverOob1, TestDoUnInitialize) +{ + testDriver1->mStarted = true; + EXPECT_NO_FATAL_FAILURE(testDriver1->DoUnInitialize()); +} + +TEST_F(TestNetRdmaDriverOob1, TestNewConnectionCBFailed) +{ + OOBTCPConnection *conn = new (std::nothrow) OOBTCPConnection(-1); + MOCKER_CPP(&OOBSecureProcess::SecProcessInOOBServer).stubs() + .will(returnValue(static_cast(NN_OK))); + MOCKER_CPP(&OOBSecureProcess::SecProcessCompareEpNum, + NResult(uint32_t, uint32_t, const std::string &, const std::vector &)).stubs() + .will(returnValue(static_cast(NN_OOB_SEC_PROCESS_ERROR))) + .then(returnValue(static_cast(NN_OK))); + EXPECT_EQ(testDriver1->NewConnectionCB(*conn), static_cast(NN_OOB_SEC_PROCESS_ERROR)); + + MOCKER_CPP_VIRTUAL(*conn, &OOBTCPConnection::Receive) + .stubs() + .will(returnValue(static_cast(NN_OK))); + MOCKER_CPP(&OOBSecureProcess::SecCheckConnectionHeader).stubs() + .will(returnValue(static_cast(NN_OOB_SEC_PROCESS_ERROR))) + .then(returnValue(static_cast(NN_OK))); + EXPECT_EQ(testDriver1->NewConnectionCB(*conn), static_cast(NN_ERROR)); + + testDriver1->mOptions.enableMultiRail = true; + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs() + .will(returnValue(static_cast(NN_ERROR))); + EXPECT_EQ(testDriver1->NewConnectionCB(*conn), static_cast(NN_ERROR)); +} + +TEST_F(TestNetRdmaDriverOob1, TestConnectFailed) +{ + std::string payload(1025, 'a'); + UBSHcomNetEndpointPtr ep = nullptr; + testDriver1->mInited = true; + testDriver1->mStarted = true; + auto ret = testDriver1->Connect(std::string("127.0.0.1"), 0, payload, ep, 0, 0, 0, 0); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); +} + +TEST_F(TestNetRdmaDriverOob1, TestProcessError) +{ + EXPECT_NO_FATAL_FAILURE(testDriver1->ProcessErrorNewRequest(nullptr)); + EXPECT_NO_FATAL_FAILURE(testDriver1->ProcessErrorOneSideDone(nullptr)); + EXPECT_NO_FATAL_FAILURE(testDriver1->ProcessQPError(nullptr)); + EXPECT_NO_FATAL_FAILURE(testDriver1->SendFinished(nullptr)); + EXPECT_NO_FATAL_FAILURE(testDriver1->OneSideDone(nullptr)); +} + +TEST_F(TestNetRdmaDriverOob1, TestNewRequestError) +{ + EXPECT_EQ(testDriver1->NewRequest(nullptr), static_cast(NN_ERROR)); + UBSHcomNetRequestContext ctx {}; + UBSHcomNetMessage msg {}; + EXPECT_EQ(testDriver1->NewReceivedRequestWithoutCopy(nullptr, ctx, msg, nullptr, nullptr, nullptr), + static_cast(NN_INVALID_PARAM)); +} + +} +} \ No newline at end of file diff --git a/test/unit_test/transport/rdma/verbs/test_net_rdma_sync_endpoint.cpp b/test/unit_test/transport/rdma/verbs/test_net_rdma_sync_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f61b267b70e8f41e417f082960a530f10124d933 --- /dev/null +++ b/test/unit_test/transport/rdma/verbs/test_net_rdma_sync_endpoint.cpp @@ -0,0 +1,736 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +//#ifdef RDMA_BUILD_ENABLED +#include +#include +#include "hcom.h" +#include "net_common.h" +#include "net_rdma_driver_oob.h" +#include "net_security_rand.h" +#include "rdma_validation.h" +#include "rdma_composed_endpoint.h" +#include "net_rdma_sync_endpoint.h" + +namespace ock { +namespace hcom { + +class TestNetRdmaSyncEndpoint : public testing::Test { +public: + TestNetRdmaSyncEndpoint(); + virtual void SetUp(void); + virtual void TearDown(void); + + std::string name; + std::string ip; + uint16_t port; + NetDriverRDMAWithOob *mDriver = nullptr; + RDMAContext *ctx = nullptr; + RDMAWorker *mWorker = nullptr; + RDMACq *cq = nullptr; + RDMAQp *qp = nullptr; + UBSHcomNetWorkerIndex mWorkerIndex; + RDMASyncEndpoint *ep = nullptr; + UBSHcomNetTransRequest request; + UBSHcomNetTransSglRequest sglRequest; + UBSHcomNetTransSgeIov *iov = nullptr; + NetSyncEndpoint *NEP = nullptr; + RDMAMemoryRegionFixedBuffer *Mr = nullptr; +}; + +TestNetRdmaSyncEndpoint::TestNetRdmaSyncEndpoint() {} + +void TestNetRdmaSyncEndpoint::SetUp() +{ + bool useDevX = true; + RDMAGId gid; + ctx = new (std::nothrow) RDMAContext(name, useDevX, gid); + ASSERT_NE(ctx, nullptr); + + bool startOobSvr = true; + UBSHcomNetDriverProtocol protocol = RDMA; + mDriver = new (std::nothrow) NetDriverRDMAWithOob(name, startOobSvr, protocol); + mDriver->mStarted = true; + + Mr = mDriver->mDriverSendMR = new (std::nothrow) RDMAMemoryRegionFixedBuffer(name, ctx, 1, 1); + ASSERT_NE(mDriver, nullptr); + + RDMAWorkerOptions options; + NetMemPoolFixedPtr memPool; + NetMemPoolFixedPtr sglMemPool; + mWorker = new (std::nothrow) RDMAWorker(name, ctx, options, memPool, sglMemPool); + ASSERT_NE(mWorker, nullptr); + + cq = new (std::nothrow) RDMACq(name, ctx, false, 0); + ASSERT_NE(cq, nullptr); + + uint32_t mid = 0; + QpOptions qpOptions; + qp = new (std::nothrow) RDMAQp(name, mid, ctx, cq, qpOptions); + ASSERT_NE(qp, nullptr); + + uint32_t rdmaOpCtxPoolSize = NN_NO1; + RDMAPollingMode pollMode = EVENT_POLLING; + ep = new (std::nothrow) RDMASyncEndpoint(name, ctx, pollMode, cq, qp, rdmaOpCtxPoolSize); + ASSERT_NE(ep, nullptr); + + uint64_t id = 0; + NEP = new (std::nothrow) NetSyncEndpoint(id, ep, mDriver, mWorkerIndex); + NEP->mState.Set(NEP_ESTABLISHED); + NEP->mAllowedSize = NN_NO128; + NEP->mSegSize = NN_NO128; + + request.lAddress = reinterpret_cast(&mWorkerIndex); + request.size = 1; + iov = new (std::nothrow) UBSHcomNetTransSgeIov(); + sglRequest = UBSHcomNetTransSglRequest(iov, 1, 1); +} + +void TestNetRdmaSyncEndpoint::TearDown() +{ + if (Mr != nullptr) { + delete Mr; + Mr = nullptr; + } + if (NEP != nullptr) { + delete NEP; + NEP = nullptr; + } + if (iov != nullptr) { + delete iov; + iov = nullptr; + } + if (ctx != nullptr) { + delete ctx; + ctx = nullptr; + } + if (mWorker != nullptr) { + delete mWorker; + mWorker = nullptr; + } + + GlobalMockObject::verify(); +} +UBSHcomNetTransHeader mockSyncMrBuf{}; +static bool MockGetFreeBuffer(uintptr_t &mrBufAddress) +{ + mrBufAddress = reinterpret_cast(&mockSyncMrBuf); + return true; +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaUdsName) +{ + std::string ret = NEP->UdsName(); + EXPECT_EQ(ret, CONST_EMPTY_STRING); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaGetSendQueueCount) +{ + int ret = NEP->GetSendQueueCount(); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendSeqFailed) +{ + name = "NetSyncEndpointRdmaPostSend"; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(returnValue(false)); + int ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_GET_BUFF_FAILED)); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendSeq) +{ + name = "NetSyncEndpointRdmaPostSend"; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + + MOCKER_CPP(&RDMASyncEndpoint::PostSend) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)) + .then(returnValue(static_cast(RR_QP_POST_SEND_FAILED))); + + MOCKER_CPP(&AesGcm128::Encrypt, + bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + + NEP->mIsNeedEncrypt = 1; + ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendSeqTwo) +{ + name = "NetSyncEndpointRdmaPostSend"; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&RDMASyncEndpoint::PostSend).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendInfoFailed) +{ + name = "NetSyncEndpointRdmaPostSend"; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(returnValue(false)); + UBSHcomNetTransOpInfo OpInfo{}; + int ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_GET_BUFF_FAILED)); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendOpInfo) +{ + name = "NetSyncEndpointRdmaPostSend"; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + + MOCKER_CPP(&RDMASyncEndpoint::PostSend) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)) + .then(returnValue(static_cast(RR_QP_POST_SEND_FAILED))); + + MOCKER_CPP(&AesGcm128::Encrypt, + bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + UBSHcomNetTransOpInfo OpInfo{}; + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + + NEP->mIsNeedEncrypt = 1; + ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendOpInfoTwo) +{ + name = "NetSyncEndpointRdmaPostSend"; + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&RDMASyncEndpoint::PostSend).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + UBSHcomNetTransOpInfo OpInfo{}; + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendOpInfoValidateFail) +{ + NEP->mState.Set(NEP_BROKEN); + UBSHcomNetTransOpInfo OpInfo{}; + int ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + NEP->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendOpInfoWithHeaderRaw) +{ + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::RAW; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendOpInfoWithHeaderNull) +{ + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, nullptr, sizeof(extHeader)); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendOpInfoWithHeaderValidateFailed) +{ + NEP->mState.Set(NEP_BROKEN); + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + NEP->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendOpInfoWithHeaderBuffer) +{ + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(returnValue(false)); + + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_GET_BUFF_FAILED); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendOpInfoWithHeaderMemcpy) +{ + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)).then(returnValue(1)); + + NEP->mIsNeedEncrypt = 1; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_INVALID_PARAM); + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendOpInfoWithHeaderEpSend) +{ + MOCKER_CPP(&RDMAMemoryRegionFixedBuffer::GetFreeBuffer, bool(RDMAMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&RDMASyncEndpoint::PostSend) + .stubs() + .will(returnValue(static_cast(NN_OK))) + .then(returnValue(static_cast(RR_QP_POST_SEND_FAILED))); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_OK); + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, RR_QP_POST_SEND_FAILED); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendRawSgl) +{ + name = "NetSyncEndpointRdmaPostSendRawSgl"; + MOCKER_CPP(&RDMASyncEndpoint::PostSendSgl).stubs().will(returnValue(1)).then(returnValue(0)); + + MOCKER_CPP(&NetDriverRDMAWithOob::ValidateMemoryRegion, + NResult(NetDriverRDMAWithOob::*)(uint64_t, uintptr_t, uint64_t)) + .stubs() + .will(returnValue(0)); + + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen).stubs().will(returnValue(static_cast(0))); + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaPostSendRawSglFail) +{ + MOCKER_CPP(&NetDriverRDMAWithOob::ValidateMemoryRegion, + NResult(NetDriverRDMAWithOob::*)(uint64_t, uintptr_t, uint64_t)) + .stubs() + .will(returnValue(0)); + + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen).stubs().will(returnValue(static_cast(0))); + NEP->mIsNeedEncrypt = true; + int ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); + + NEP->mState.Set(NEP_BROKEN); + ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + NEP->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetRdmaSyncEndpoint, ComposedEndpointRdmaPostSendSgl) +{ + name = "ComposedEndpointRdmaPostSendSgl"; + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen).stubs().will(returnValue(static_cast(0))); + + int ret = ep->PostSendSgl(sglRequest, request, 0, false); + EXPECT_EQ(ret, static_cast(RR_PARAM_INVALID)); + if (ep->mQP != nullptr) { + delete ep->mQP; + ep->mQP = nullptr; + } + ret = ep->PostSendSgl(sglRequest, request, 0, false); + EXPECT_EQ(ret, static_cast(RR_PARAM_INVALID)); +} +TEST_F(TestNetRdmaSyncEndpoint, ComposedEndpointRdmaPostSendSglTwo) +{ + name = "ComposedEndpointRdmaPostSendSglTwo"; + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)).then(returnValue(1)); + sglRequest.upCtxSize = 1; + int ret = ep->PostSendSgl(sglRequest, request, 0, false); + EXPECT_EQ(ret, static_cast(RR_PARAM_INVALID)); +} + +TEST_F(TestNetRdmaSyncEndpoint, ComposedEndpointRdmaPostSendSglThree) +{ + name = "ComposedEndpointRdmaPostSendSglThree"; + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + sglRequest.upCtxSize = 0; + MOCKER_CPP(&RDMAQp::PostSend).stubs().will(returnValue(1)); + MOCKER_CPP(&RDMAQp::PostSendSgl).stubs().will(returnValue(0)); + int ret = ep->PostSendSgl(sglRequest, request, 0, true); + EXPECT_EQ(ret, 1); + ret = ep->PostSendSgl(sglRequest, request, 0, false); + EXPECT_EQ(ret, 0); +} +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaSetEpOption) +{ + UBSHcomEpOptions epOptions; + int ret = NEP->SetEpOption(epOptions); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaEstimatedEncryptLen) +{ + int ret = NEP->EstimatedEncryptLen(0); + EXPECT_EQ(ret, 0); + NEP->mIsNeedEncrypt = 0; + ret = NEP->EstimatedEncryptLen(1); + EXPECT_EQ(ret, 0); +} +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaEstimatedEncryptLenTwo) +{ + NEP->mIsNeedEncrypt = 1; + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen).stubs().will(returnValue(1)); + int ret = NEP->EstimatedEncryptLen(1); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaEncrypt) +{ + uint64_t cipherLen = 0; + MOCKER_CPP(&AesGcm128::Encrypt, + bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + int ret = NEP->Encrypt(reinterpret_cast(0), 0, reinterpret_cast(0), cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + NEP->mIsNeedEncrypt = 0; + ret = NEP->Encrypt(reinterpret_cast(0), 0, reinterpret_cast(0), cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaEncryptTwo) +{ + uint64_t cipherLen = 0; + MOCKER_CPP(&AesGcm128::Encrypt, + bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(true)); + + NEP->mIsNeedEncrypt = 1; + int ret = NEP->Encrypt(reinterpret_cast(0), 0, reinterpret_cast(0), cipherLen); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaEstimatedDecryptLen) +{ + NEP->mIsNeedEncrypt = 0; + int ret = NEP->EstimatedDecryptLen(0); + EXPECT_EQ(ret, 0); + + NEP->mIsNeedEncrypt = 1; + MOCKER_CPP(&AesGcm128::GetRawLen).stubs().will(returnValue(1)); + ret = NEP->EstimatedDecryptLen(0); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaDecrypt) +{ + uint64_t rawLen = 0; + MOCKER_CPP(&AesGcm128::Decrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + int ret = NEP->Decrypt(reinterpret_cast(0), 0, reinterpret_cast(0), rawLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + NEP->mIsNeedEncrypt = 0; + ret = NEP->Decrypt(reinterpret_cast(0), 0, reinterpret_cast(0), rawLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaDecryptTwo) +{ + uint64_t rawLen = 0; + MOCKER_CPP(&AesGcm128::Decrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(true)); + + NEP->mIsNeedEncrypt = 1; + int ret = NEP->Decrypt(reinterpret_cast(0), 0, reinterpret_cast(0), rawLen); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaGetRemoteUdsIdInfo) +{ + UBSHcomNetUdsIdInfo verbsIdInfo; + NEP->mState.Set(NEP_NEW); + int ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + + NEP->mState.Set(NEP_ESTABLISHED); + NEP->mDriver->mStartOobSvr = false; + ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_UDS_ID_INFO_NOT_SUPPORT)); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaGetRemoteUdsIdInfoTwo) +{ + UBSHcomNetUdsIdInfo verbsIdInfo; + + NEP->mState.Set(NEP_ESTABLISHED); + NEP->mDriver->mStartOobSvr = true; + NEP->mDriver->mOptions.oobType = NET_OOB_TCP; + int ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_UDS_ID_INFO_NOT_SUPPORT)); + + NEP->mDriver->mOptions.oobType = NET_OOB_UDS; + ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaGetPeerIpPort) +{ + if (NEP->mEp->mQP != nullptr) { + delete NEP->mEp->mQP; + NEP->mEp->mQP = nullptr; + } + int ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + if (NEP->mEp != nullptr) { + delete NEP->mEp; + NEP->mEp = nullptr; + } + ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaGetPeerIpPortTwo) +{ + NEP->mEp->mQP->mPeerIpPort = "0.0.0.0"; + int ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + NEP->mEp->mQP->mPeerIpPort = "0.0.0.0:sss"; + ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetRdmaSyncEndpoint, NetSyncEndpointRdmaGetPeerIpPortThree) +{ + NEP->mEp->mQP->mPeerIpPort = "0.0.0.0:0"; + int ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + NEP->mEp->mQP->mPeerIpPort = "0.0.0.0:16"; + ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, true); +} + +TEST_F(TestNetRdmaSyncEndpoint, SyncReceiveFailWithErrorOpType) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + RDMAOpContextInfo opCtx{}; + opCtx.opType = RDMAOpContextInfo::SEND; + NEP->mDelayHandleReceiveCtx = &opCtx; + + MOCKER_CPP(&RDMASyncEndpoint::RePostReceive) + .stubs() + .will(returnValue(0)); + + NResult ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_ERROR); +} + +TEST_F(TestNetRdmaSyncEndpoint, SyncReceiveFailWithOverDataSize) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + RDMAOpContextInfo opCtx{}; + opCtx.opType = RDMAOpContextInfo::RECEIVE; + UBSHcomNetTransHeader header{}; + header.seqNo = 0; + header.dataLength = NET_SGE_MAX_SIZE + NN_NO1; + opCtx.mrMemAddr = reinterpret_cast(&header); + NEP->mDelayHandleReceiveCtx = &opCtx; + + MOCKER_CPP(&RDMASyncEndpoint::RePostReceive) + .stubs() + .will(returnValue(0)); + + NResult ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetRdmaSyncEndpoint, SyncReceiveFailWithErrDataLen) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + RDMAOpContextInfo opCtx{}; + opCtx.opType = RDMAOpContextInfo::RECEIVE; + opCtx.dataSize = NN_NO1024; + UBSHcomNetTransHeader header{}; + header.seqNo = 0; + header.dataLength = NN_NO2048; + opCtx.mrMemAddr = reinterpret_cast(&header); + NEP->mDelayHandleReceiveCtx = &opCtx; + + MOCKER_CPP(&RDMASyncEndpoint::RePostReceive) + .stubs() + .will(returnValue(0)); + + NResult ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetRdmaSyncEndpoint, SyncReceiveFailWithInvalidHeader) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + RDMAOpContextInfo opCtx{}; + opCtx.opType = RDMAOpContextInfo::RECEIVE; + opCtx.dataSize = NN_NO1024; + UBSHcomNetTransHeader header{}; + header.seqNo = 0; + header.dataLength = NN_NO1024 - sizeof(UBSHcomNetTransHeader); + opCtx.mrMemAddr = reinterpret_cast(&header); + NEP->mDelayHandleReceiveCtx = &opCtx; + + MOCKER_CPP(&RDMASyncEndpoint::RePostReceive) + .stubs() + .will(returnValue(0)); + + NResult ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_VALIDATE_HEADER_CRC_INVALID); +} + +TEST_F(TestNetRdmaSyncEndpoint, TestRDMASyncEndpointFunction) +{ + RDMASyncEndpoint syncEp {"name", nullptr, EVENT_POLLING, nullptr, nullptr, 0}; + EXPECT_EQ(syncEp.Initialize(), static_cast(RR_EP_NOT_INITIALIZED)); + syncEp.mQP = qp; + EXPECT_EQ(syncEp.Initialize(), static_cast(RR_EP_NOT_INITIALIZED)); + syncEp.mCq = cq; + + MOCKER_CPP(RDMACq::Initialize).stubs() + .will(returnValue(static_cast(RR_PARAM_INVALID))) + .then(returnValue(static_cast(RR_OK))); + EXPECT_EQ(syncEp.Initialize(), static_cast(RR_PARAM_INVALID)); + MOCKER_CPP(RDMAQp::Initialize).stubs() + .will(returnValue(static_cast(RR_PARAM_INVALID))) + .then(returnValue(static_cast(RR_OK))); + EXPECT_EQ(syncEp.Initialize(), static_cast(RR_PARAM_INVALID)); + EXPECT_EQ(syncEp.Initialize(), static_cast(NN_INVALID_PARAM)); + syncEp.mQP = nullptr; + syncEp.mCq = nullptr; +} + +TEST_F(TestNetRdmaSyncEndpoint, PostReceiveFail) +{ + RDMASyncEndpoint syncEp {"name", nullptr, EVENT_POLLING, nullptr, nullptr, 0}; + EXPECT_EQ(syncEp.PostReceive(0, 0, 0), static_cast(RR_PARAM_INVALID)); + EXPECT_EQ(syncEp.RePostReceive(nullptr), static_cast(RR_PARAM_INVALID)); + RDMASendReadWriteRequest rwReq {}; + EXPECT_EQ(syncEp.PostSend(rwReq), static_cast(RR_PARAM_INVALID)); + EXPECT_EQ(syncEp.PostRead(rwReq), static_cast(RR_PARAM_INVALID)); + EXPECT_EQ(syncEp.PostWrite(rwReq), static_cast(RR_PARAM_INVALID)); +} + +TEST_F(TestNetRdmaSyncEndpoint, PostOneSideSglFail) +{ + RDMASyncEndpoint syncEp {"name", nullptr, EVENT_POLLING, nullptr, nullptr, 0}; + RDMASendSglRWRequest sglRwReq {}; + EXPECT_EQ(syncEp.PostOneSideSgl(sglRwReq), static_cast(RR_PARAM_INVALID)); + + RDMASgeCtxInfo sge {}; + uint64_t ctxArr[NET_SGE_MAX_IOV]; + EXPECT_EQ(syncEp.CreateOneSideCtx(sge, nullptr, 0, ctxArr, false), static_cast(RR_PARAM_INVALID)); + + RDMAOpContextInfo *opCtxInfo = nullptr; + uint32_t immData = 0; + EXPECT_EQ(syncEp.PollingCompletion(opCtxInfo, 0, immData), static_cast(RR_EP_NOT_INITIALIZED)); +} + +TEST_F(TestNetRdmaSyncEndpoint, SyncPollingCompletionContextInfoNull) +{ + RDMAOpContextInfo *opCtx = nullptr; + uint32_t immData = 0; + MOCKER(RDMACq::EventProgressV).stubs().will(returnValue(0)); + RResult ret = ep->PollingCompletion(opCtx, 0, immData); + EXPECT_EQ(ret, RR_CQ_WC_WRONG); +} +} +} +//#endif diff --git a/test/unit_test/transport/rdma/verbs/test_net_rdma_worker.cpp b/test/unit_test/transport/rdma/verbs/test_net_rdma_worker.cpp new file mode 100644 index 0000000000000000000000000000000000000000..29036b0cac7edf8b46f192ab2c737a5ecf2234e6 --- /dev/null +++ b/test/unit_test/transport/rdma/verbs/test_net_rdma_worker.cpp @@ -0,0 +1,258 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +//#ifdef RDMA_BUILD_ENABLED +#include +#include +#include "hcom.h" +#include "net_common.h" +#include "net_rdma_driver_oob.h" +#include "net_security_rand.h" +#include "rdma_validation.h" +#include "rdma_composed_endpoint.h" +#include "net_rdma_sync_endpoint.h" + +namespace ock { +namespace hcom { + +class TestNetRdmaWorker : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + + std::string name; + RDMAContext *ctx = nullptr; + RDMAWorker *mWorker = nullptr; + RDMACq *cq = nullptr; + RDMAQp *qp = nullptr; + UBSHcomNetWorkerIndex mWorkerIndex{}; + UBSHcomNetTransRequest request; + RDMASglContextInfo *sglCtx; + RDMAOpContextInfo *rdmaCtx; + UBSHcomNetTransSglRequest sglRequest; + UBSHcomNetTransSgeIov *iov = nullptr; +}; + +void TestNetRdmaWorker::SetUp() +{ + RDMAGId gid = {}; + ctx = new (std::nothrow) RDMAContext(name, true, gid); + ASSERT_NE(ctx, nullptr); + + RDMAWorkerOptions options{}; + NetMemPoolFixedPtr memPool; + NetMemPoolFixedPtr sglMemPool; + mWorker = new (std::nothrow) RDMAWorker(name, ctx, options, memPool, sglMemPool); + ASSERT_NE(mWorker, nullptr); + + cq = new (std::nothrow) RDMACq(name, ctx, false, 0); + ASSERT_NE(cq, nullptr); + + uint32_t mid = 0; + QpOptions qpOptions = {}; + qp = new (std::nothrow) RDMAQp(name, mid, ctx, cq, qpOptions); + ASSERT_NE(qp, nullptr); + + request.lAddress = reinterpret_cast(&mWorkerIndex); + request.size = 1; + iov = new (std::nothrow) UBSHcomNetTransSgeIov(); + sglRequest = UBSHcomNetTransSglRequest(iov, 1, 0); + sglCtx = new (std::nothrow) RDMASglContextInfo(); + rdmaCtx = new (std::nothrow) RDMAOpContextInfo(); +} + +void TestNetRdmaWorker::TearDown() +{ + if (mWorker != nullptr) { + delete mWorker; + mWorker = nullptr; + } + if (sglCtx != nullptr) { + delete sglCtx; + sglCtx = nullptr; + } + if (rdmaCtx != nullptr) { + delete rdmaCtx; + rdmaCtx = nullptr; + } + if (iov != nullptr) { + delete iov; + iov = nullptr; + } + if (qp != nullptr) { + delete qp; + qp = nullptr; + } + if (ctx != nullptr) { + delete ctx; + ctx = nullptr; + } + GlobalMockObject::verify(); +} + +TEST_F(TestNetRdmaWorker, RdmaWorkerPostSendSgl) +{ + name = "NetSyncEndpointRdmaPostSendSgl"; + RDMAQp *tmpQp = nullptr; + int ret = mWorker->PostSendSgl(tmpQp, sglRequest, request, 0, 0); + EXPECT_EQ(ret, static_cast(RR_PARAM_INVALID)); + RDMASglContextInfo *tmpSglCtx = nullptr; + MOCKER_CPP(&RDMASglContextInfoPool::Get).stubs().will(returnValue(tmpSglCtx)); + ret = mWorker->PostSendSgl(qp, sglRequest, request, 0, 0); + EXPECT_EQ(ret, static_cast(RR_PARAM_INVALID)); +} + +TEST_F(TestNetRdmaWorker, RdmaWorkerPostSendSglTwo) +{ + name = "NetSyncEndpointRdmaPostSendSgltwo"; + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + + MOCKER_CPP(&RDMASglContextInfoPool::Get).stubs().will(returnValue(sglCtx)); + int ret = mWorker->PostSendSgl(qp, sglRequest, request, 0, 0); + EXPECT_EQ(ret, static_cast(RR_PARAM_INVALID)); + + sglRequest.upCtxSize = 1; + ret = mWorker->PostSendSgl(qp, sglRequest, request, 0, 0); + EXPECT_EQ(ret, static_cast(RR_PARAM_INVALID)); +} + +TEST_F(TestNetRdmaWorker, RdmaWorkerPostSendSglThree) +{ + name = "NetSyncEndpointRdmaPostSendSglThree"; + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + MOCKER_CPP(&RDMASglContextInfoPool::Get).stubs().will(returnValue(sglCtx)); + + RDMAOpContextInfo *tmpCtx = nullptr; + MOCKER_CPP(&RDMAOpContextInfoPool::Get).stubs().will(returnValue(tmpCtx)).then(returnValue(rdmaCtx)); + + MOCKER_CPP(&RDMAQp::GetPostSendWr).stubs().will(returnValue(false)); + + sglRequest.upCtxSize = 0; + int ret = mWorker->PostSendSgl(qp, sglRequest, request, 0, 0); + EXPECT_EQ(ret, static_cast(RR_QP_CTX_FULL)); + MOCKER_CPP(&RDMAOpContextInfoPool::Return).stubs().will(returnValue(0)); + ret = mWorker->PostSendSgl(qp, sglRequest, request, 0, 0); + EXPECT_EQ(ret, static_cast(RR_QP_POST_SEND_WR_FULL)); +} + +TEST_F(TestNetRdmaWorker, RdmaWorkerPostSendSglFour) +{ + name = "NetSyncEndpointRdmaPostSendSglFour"; + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + RDMASglContextInfoPool mSglCtxInfoPool; + MOCKER_CPP(&RDMASglContextInfoPool::Get).stubs().will(returnValue(sglCtx)); + + MOCKER_CPP(&RDMAOpContextInfoPool::Get).stubs().will(returnValue(rdmaCtx)); + + MOCKER_CPP(&RDMAQp::GetPostSendWr).stubs().will(returnValue(true)); + + MOCKER_CPP(&RDMAQp::PostSend).stubs().will(returnValue(1)); + + MOCKER_CPP(&RDMAQp::PostSendSgl).stubs().will(returnValue(0)); + + MOCKER_CPP(&RDMAOpContextInfoPool::Return).stubs().will(returnValue(0)); + + MOCKER_CPP(&RDMASglContextInfoPool ::Return).stubs().will(returnValue(0)); + sglRequest.upCtxSize = 0; + int ret = mWorker->PostSendSgl(qp, sglRequest, request, 0, 0); + EXPECT_EQ(ret, 0); + + ret = mWorker->PostSendSgl(qp, sglRequest, request, 0, 1); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestNetRdmaWorker, RdmaWorkerPostSendSglInlineOne) +{ + name = "RdmaWorkerPostSendSglInlineOne"; + RDMASendSglInlineHeader header; + RDMASendReadWriteRequest req; + uint32_t immData = 0; + RResult ret = mWorker->PostSendSglInline(nullptr, header, req, immData); + EXPECT_EQ(ret, 200); +} + +TEST_F(TestNetRdmaWorker, RdmaWorkerPostSendSglInlineTwo) +{ + name = "RdmaWorkerPostSendSglInlineTwo"; + RDMASendSglInlineHeader header; + RDMASendReadWriteRequest req; + uint32_t immData = 0; + RDMAOpContextInfo *tmpRdmaCtx = nullptr; + MOCKER_CPP(&RDMAOpContextInfoPool::Get).stubs().will(returnValue(tmpRdmaCtx)); + RResult ret = mWorker->PostSendSglInline(qp, header, req, immData); + EXPECT_EQ(ret, 232); +} + +TEST_F(TestNetRdmaWorker, RdmaWorkerPostSendSglInlineThree) +{ + name = "RdmaWorkerPostSendSglInlineThree"; + RDMASendSglInlineHeader header; + RDMASendReadWriteRequest req; + uint32_t immData = 0; + MOCKER_CPP(&RDMAOpContextInfoPool::Get).stubs().will(returnValue(rdmaCtx)); + MOCKER_CPP(&RDMAQp::GetPostSendWr).stubs().will(returnValue(false)); + MOCKER_CPP(&RDMAOpContextInfoPool::Return).stubs().will(returnValue(0)); + RResult ret = mWorker->PostSendSglInline(qp, header, req, immData); + EXPECT_EQ(ret, 230); +} + +TEST_F(TestNetRdmaWorker, RdmaWorkerPostSendSglInlineFive) +{ + name = "RdmaWorkerPostSendSglInlineFive"; + RDMASendSglInlineHeader header; + RDMASendReadWriteRequest req; + req.upCtxSize = 100; + uint32_t immData = 0; + MOCKER_CPP(&RDMAOpContextInfoPool::Get).stubs().will(returnValue(rdmaCtx)); + MOCKER_CPP(&RDMAQp::GetPostSendWr).stubs().will(returnValue(true)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + + RResult ret = mWorker->PostSendSglInline(qp, header, req, immData); + EXPECT_EQ(ret, 200); +} + +TEST_F(TestNetRdmaWorker, RdmaWorkerPostSendSglInlineSix) +{ + name = "RdmaWorkerPostSendSglInlineSix"; + RDMASendSglInlineHeader header; + RDMASendReadWriteRequest req; + uint32_t immData = 0; + MOCKER_CPP(&RDMAOpContextInfoPool::Get).stubs().will(returnValue(rdmaCtx)); + MOCKER_CPP(&RDMAQp::GetPostSendWr).stubs().will(returnValue(true)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + MOCKER_CPP(&RDMAQp::PostSendSglInline).stubs().will(returnValue(201)); + MOCKER_CPP(&RDMAQp::ReturnPostSendWr).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&RDMAQp::DecreaseRef).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&RDMAOpContextInfoPool::Return).stubs().will(returnValue(0)); + RResult ret = mWorker->PostSendSglInline(qp, header, req, immData); + EXPECT_EQ(ret, 201); +} + +TEST_F(TestNetRdmaWorker, RdmaWorkerPostSendSglInlineSeven) +{ + name = "RdmaWorkerPostSendSglInlineSeven"; + RDMASendSglInlineHeader header; + RDMASendReadWriteRequest req; + uint32_t immData = 0; + MOCKER_CPP(&RDMAOpContextInfoPool::Get).stubs().will(returnValue(rdmaCtx)); + MOCKER_CPP(&RDMAQp::GetPostSendWr).stubs().will(returnValue(true)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + MOCKER_CPP(&RDMAQp::PostSendSglInline).stubs().will(returnValue(0)); + RResult ret = mWorker->PostSendSglInline(qp, header, req, immData); + EXPECT_EQ(ret, 0); +} + +} +} +//#endif \ No newline at end of file diff --git a/test/unit_test/transport/rdma/verbs/test_rdma_verbs_wrapper.cpp b/test/unit_test/transport/rdma/verbs/test_rdma_verbs_wrapper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7fcb4c1b35f5afe3413fe549d1eec36854d4d5f6 --- /dev/null +++ b/test/unit_test/transport/rdma/verbs/test_rdma_verbs_wrapper.cpp @@ -0,0 +1,229 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include + +#include "hcom_utils.h" +#include "net_common.h" +#include "rdma_verbs_wrapper_qp.h" +#include "rdma_mr_dm_buf.h" +#include "rdma_mr_fixed_buf.h" + +namespace ock { +namespace hcom { +class TestRdmaVerbsWrapper : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); +}; + +void TestRdmaVerbsWrapper::SetUp() +{ +} + +void TestRdmaVerbsWrapper::TearDown() +{ +} + +TEST_F(TestRdmaVerbsWrapper, DeviceHelperUnInitialize) +{ + RDMADeviceHelper::G_Inited = true; + EXPECT_NO_FATAL_FAILURE(RDMADeviceHelper::UnInitialize()); + EXPECT_EQ(RDMADeviceHelper::G_Inited, false); +} + +inline NResult MockFilterIp(const std::string &ipMask, std::vector &outIps) +{ + outIps.emplace_back("192.168.0.0"); + return 0; +} + +TEST_F(TestRdmaVerbsWrapper, DeviceHelperGetEnableDeviceCountWithEmptyMatchIp) +{ + std::vector enableIps; + uint16_t enableCount = 0; + RResult result = RDMADeviceHelper::GetEnableDeviceCount("12345", enableCount, enableIps, ""); + EXPECT_EQ(result, NN_INVALID_IP); + + MOCKER_CPP(&FilterIp).stubs().will(returnValue(0)).then(invoke(MockFilterIp)); + result = 0; + result = RDMADeviceHelper::GetEnableDeviceCount("192.168.0.0/24", enableCount, enableIps, ""); + EXPECT_EQ(result, NN_INVALID_IP); +} + +RResult MockGetDeviceByIp(const std::string &ip, RDMAGId &gid) +{ + gid.devIndex = 0; + return 0; +} + +TEST_F(TestRdmaVerbsWrapper, DeviceHelperGetEnableDeviceCountWithMatchIp) +{ + std::vector enableIps; + uint16_t enableCount = 0; + MOCKER(RDMADeviceHelper::Initialize).stubs().will(returnValue(205)).then(returnValue(0)); + MOCKER_CPP(&FilterIp).stubs().will(invoke(MockFilterIp)); + MOCKER(RDMADeviceHelper::GetDeviceByIp).stubs().will(invoke(MockGetDeviceByIp)); + + RResult result = RDMADeviceHelper::GetEnableDeviceCount("192.168.0.0/24", enableCount, enableIps, ""); + EXPECT_EQ(result, RR_DEVICE_FAILED_OPEN); + + RDMADeviceSimpleInfo simpleInfo {}; + simpleInfo.active = true; + RDMADeviceHelper::G_RDMADevMap[0] = simpleInfo; + result = RDMADeviceHelper::GetEnableDeviceCount("192.168.0.0/24", enableCount, enableIps, ""); + EXPECT_EQ(enableCount, 1); + EXPECT_EQ(result, RR_OK); +} + +TEST_F(TestRdmaVerbsWrapper, OpResult) +{ + ibv_wc wc {}; + wc.status = IBV_WC_SUCCESS; + EXPECT_EQ(RDMAOpContextInfo::OpResult(wc), RDMAOpContextInfo::OpResultType::SUCCESS); + wc.status = IBV_WC_RETRY_EXC_ERR; + EXPECT_EQ(RDMAOpContextInfo::OpResult(wc), RDMAOpContextInfo::OpResultType::ERR_TIMEOUT); + wc.status = IBV_WC_RNR_RETRY_EXC_ERR; + EXPECT_EQ(RDMAOpContextInfo::OpResult(wc), RDMAOpContextInfo::OpResultType::ERR_TIMEOUT); + wc.status = IBV_WC_WR_FLUSH_ERR; + EXPECT_EQ(RDMAOpContextInfo::OpResult(wc), RDMAOpContextInfo::OpResultType::ERR_CANCELED); + wc.status = IBV_WC_LOC_LEN_ERR; + EXPECT_EQ(RDMAOpContextInfo::OpResult(wc), RDMAOpContextInfo::OpResultType::ERR_IO_ERROR); +} + +TEST_F(TestRdmaVerbsWrapper, GetNResult) +{ + RDMAOpContextInfo::OpResultType type = RDMAOpContextInfo::OpResultType::SUCCESS; + EXPECT_EQ(RDMAOpContextInfo::GetNResult(type), NN_OK); + type = RDMAOpContextInfo::OpResultType::ERR_TIMEOUT; + EXPECT_EQ(RDMAOpContextInfo::GetNResult(type), NN_MSG_TIMEOUT); + type = RDMAOpContextInfo::OpResultType::ERR_CANCELED; + EXPECT_EQ(RDMAOpContextInfo::GetNResult(type), NN_MSG_CANCELED); + type = RDMAOpContextInfo::OpResultType::ERR_EP_BROKEN; + EXPECT_EQ(RDMAOpContextInfo::GetNResult(type), NN_EP_BROKEN); + type = RDMAOpContextInfo::OpResultType::ERR_EP_CLOSE; + EXPECT_EQ(RDMAOpContextInfo::GetNResult(type), NN_EP_CLOSE); + type = RDMAOpContextInfo::OpResultType::ERR_IO_ERROR; + EXPECT_EQ(RDMAOpContextInfo::GetNResult(type), NN_MSG_ERROR); +} + +TEST_F(TestRdmaVerbsWrapper, DoInitialize) +{ + MOCKER_CPP(RDMADeviceHelper::DoUpdate).stubs() + .will(returnValue(static_cast(RR_DEVICE_FAILED_OPEN))); + EXPECT_EQ(RDMADeviceHelper::DoInitialize(), RR_DEVICE_FAILED_OPEN); + EXPECT_NO_FATAL_FAILURE(RDMADeviceHelper::Update()); +} + +TEST_F(TestRdmaVerbsWrapper, DoUpdate) +{ + std::vector outGidVec; + EXPECT_NO_FATAL_FAILURE(RDMADeviceHelper::GetGidVec(nullptr, "name", 0, 0, 0, outGidVec)); + + ibv_context ctx {}; + VerbsAPI::hcomInnerQueryGid = + [](struct ibv_context *context, uint8_t port_num, int index, union ibv_gid *gid) { return 1; }; + EXPECT_NO_FATAL_FAILURE(RDMADeviceHelper::GetGidVec(&ctx, "name", 0, 0, 1, outGidVec)); + + VerbsAPI::hcomInnerQueryGid = + [](struct ibv_context *context, uint8_t port_num, int index, union ibv_gid *gid) { return 0; }; + EXPECT_NO_FATAL_FAILURE(RDMADeviceHelper::GetGidVec(&ctx, "name", 0, 0, 1, outGidVec)); +} + +TEST_F(TestRdmaVerbsWrapper, StrToRoCEVersion) +{ + EXPECT_EQ(RDMADeviceHelper::StrToRoCEVersion("IB/RoCE v1"), RoCE_V1); + EXPECT_EQ(RDMADeviceHelper::StrToRoCEVersion("RoCE v2"), RoCE_V2); + EXPECT_EQ(RDMADeviceHelper::StrToRoCEVersion("5555"), RoCE_V15); +} + +TEST_F(TestRdmaVerbsWrapper, RDMAContextInitializeFail) +{ + RDMAGId gid {}; + RDMAContext ctx {"name", false, gid}; + ibv_context ctx1 {}; + ctx.mContext = &ctx1; + + MOCKER_CPP(RDMAContext::UnInitialize).stubs() + .will(returnValue(static_cast(RR_OK))); + + EXPECT_EQ(ctx.Initialize(), RR_OK); + ctx.mContext = nullptr; +} + +TEST_F(TestRdmaVerbsWrapper, Initialize) +{ + RDMAGId gid {}; + RDMAContext ctx {"name", false, gid}; + MOCKER_CPP(RDMADeviceHelper::Update).stubs() + .will(returnValue(static_cast(RR_DEVICE_FAILED_OPEN))) + .then(returnValue(static_cast(RR_OK))); + EXPECT_NO_FATAL_FAILURE(ctx.UpdateGid("IP")); + + MOCKER_CPP(RDMADeviceHelper::GetDeviceByIp).stubs() + .will(returnValue(static_cast(RR_DEVICE_FAILED_OPEN))) + .then(returnValue(static_cast(RR_OK))); + + EXPECT_NO_FATAL_FAILURE(ctx.UpdateGid("IP")); +} + +TEST_F(TestRdmaVerbsWrapper, RDMACqInitializeFail) +{ + RDMACq cq {"name", nullptr}; + + MOCKER_CPP(RDMACq::UnInitialize).stubs() + .will(returnValue(static_cast(RR_OK))); + ibv_cq ibvCq {}; + EXPECT_EQ(cq.Initialize(), RR_PARAM_INVALID); + cq.mCompletionQueue = &ibvCq; + EXPECT_EQ(cq.Initialize(), RR_OK); + int count = 0; + EXPECT_EQ(cq.ProgressV(nullptr, count), RR_CQ_NOT_INITIALIZED); +} + +TEST_F(TestRdmaVerbsWrapper, RDMAQpCreateFail) +{ + RDMAQp qp {"name", 0, nullptr, nullptr}; + MOCKER_CPP(RDMAQp::UnInitialize).stubs() + .will(returnValue(static_cast(RR_OK))); + EXPECT_EQ(qp.CreateIbvQp(), RR_PARAM_INVALID); + RDMAQpExchangeInfo info {}; + EXPECT_EQ(qp.ChangeToReady(info), RR_QP_CHANGE_STATE_FAILED); + EXPECT_EQ(qp.GetExchangeInfo(info), RR_QP_NOT_INITIALIZED); + + MOCKER_CPP(RDMAMemoryRegionFixedBuffer::Create).stubs() + .will(returnValue(static_cast(RR_PARAM_INVALID))); + EXPECT_EQ(qp.CreateQpMr(), RR_PARAM_INVALID); + + RDMAMemoryRegionFixedBuffer mr {"name", nullptr, 0, 0}; + qp.mQpMr = &mr; + + uintptr_t item = 0; + EXPECT_EQ(qp.GetFreeBuff(item), false); +} + +TEST_F(TestRdmaVerbsWrapper, GetEnableDeviceCount) +{ + int ret; + std::string ipMask(NN_NO1024, 'a'); + uint16_t enableDevCount = 0; + std::vector enableIps{}; + std::string ipGroup{}; + + ret = RDMADeviceHelper::GetEnableDeviceCount(ipMask, enableDevCount, enableIps, ipGroup); + EXPECT_EQ(ret, NN_INVALID_IP); +} + +} +} \ No newline at end of file diff --git a/test/unit_test/transport/rdma/verbs/test_rdma_worker.cpp b/test/unit_test/transport/rdma/verbs/test_rdma_worker.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e8bce488d226bd63f72646c5c2716a6597606520 --- /dev/null +++ b/test/unit_test/transport/rdma/verbs/test_rdma_worker.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +//#ifdef RDMA_BUILD_ENABLED +#include +#include +#include +#include +#include "hcom_utils.h" +#include "net_common.h" +#include "rdma_worker.h" + +namespace ock { +namespace hcom { + +class TestRdmaWorker : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + + std::string name; + RDMAContext *ctx = nullptr; + RDMAWorker *mWorker = nullptr; +}; + +void TestRdmaWorker::SetUp() +{ + RDMAGId gid = {}; + ctx = new (std::nothrow) RDMAContext(name, true, gid); + ASSERT_NE(ctx, nullptr); + + RDMAWorkerOptions options{}; + NetMemPoolFixedPtr memPool; + NetMemPoolFixedPtr sglMemPool; + mWorker = new (std::nothrow) RDMAWorker(name, ctx, options, memPool, sglMemPool); + ASSERT_NE(mWorker, nullptr); +} + +void TestRdmaWorker::TearDown() +{ + if (mWorker != nullptr) { + delete mWorker; + mWorker = nullptr; + } + if (ctx != nullptr) { + delete ctx; + ctx = nullptr; + } + GlobalMockObject::verify(); +} + +TEST_F(TestRdmaWorker, RdmaWorkerReInitializeCQ) +{ + mWorker->mInited = false; + int ret = mWorker->ReInitializeCQ(); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestRdmaWorker, RdmaWorkerReInitializeCQTwo) +{ + RDMACq *tmpCQ = new (std::nothrow) RDMACq(name, ctx, false, 0); + mWorker->mRDMACq = tmpCQ; + MOCKER_CPP(&RDMACq::Initialize).stubs().will(returnValue(1)).then(returnValue(0)); + mWorker->mInited = true; + int ret = mWorker->ReInitializeCQ(); + EXPECT_EQ(ret, 1); + ret = mWorker->ReInitializeCQ(); + EXPECT_EQ(ret, 0); + if (tmpCQ != nullptr) { + delete tmpCQ; + tmpCQ = nullptr; + } +} + +TEST_F(TestRdmaWorker, RdmaReadRoCEVersionFromFile) +{ + std::string version = ""; + std::string deviceName = ""; + EXPECT_EQ(ReadRoCEVersionFromFile(deviceName, 0, 0, version), static_cast(RR_PARAM_INVALID)); +} + +} +} +//#endif \ No newline at end of file diff --git a/test/unit_test/transport/shm/test_net_shm_driver_oob.cpp b/test/unit_test/transport/shm/test_net_shm_driver_oob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ba25b8e7f36a9217e8ffd46cfe85314f676e28b0 --- /dev/null +++ b/test/unit_test/transport/shm/test_net_shm_driver_oob.cpp @@ -0,0 +1,793 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include +#include "net_shm_sync_endpoint.h" +#include "net_shm_async_endpoint.h" +#include "shm_composed_endpoint.h" +#include "net_oob_secure.h" +#include "net_oob_ssl.h" +#include "shm_validation.h" +#include "net_shm_driver_oob.h" + +namespace ock { +namespace hcom { +class TestNetShmDriverOob : public testing::Test { +public: + TestNetShmDriverOob(); + virtual void SetUp(void); + virtual void TearDown(void); + + NetDriverShmWithOOB *driver = nullptr; +}; + +TestNetShmDriverOob::TestNetShmDriverOob() {} + +void TestNetShmDriverOob::SetUp() +{ + driver = new (std::nothrow) NetDriverShmWithOOB("ShmDriverClearShmLeftFile", false, SHM); + ASSERT_NE(driver, nullptr); +} + +void TestNetShmDriverOob::TearDown() +{ + if (driver != nullptr) { + delete driver; + driver = nullptr; + } + GlobalMockObject::verify(); +} + +TEST_F(TestNetShmDriverOob, ShmDriverClearShmLeftFile) +{ + ShmChannelPtr ch; + ShmChannel::CreateAndInit("ShmDriverClearShmLeftFile", 0, NN_NO128, NN_NO4, ch); + ch->mUpCtx = 0; + + driver->ProcessEpError(ch); + driver->ClearShmLeftFile(); + EXPECT_EQ(driver->mClearThreadStarted, true); +} + +TEST_F(TestNetShmDriverOob, ShmDriverCreateMemoryRegion) +{ + int ret; + ASSERT_NE(driver, nullptr); + UBSHcomNetMemoryRegionPtr mr{}; + + ret = driver->CreateMemoryRegion(0, 0, mr); + EXPECT_EQ(ret, static_cast(NN_INVALID_OPERATION)); +} + +TEST_F(TestNetShmDriverOob, DriverInitializeInited) +{ + int ret; + UBSHcomNetDriverOptions option{}; + driver->mInited = true; + + ret = driver->Initialize(option); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestNetShmDriverOob, DriverInitializeValidateCommonOptionsFail) +{ + int ret; + UBSHcomNetDriverOptions option{}; + driver->mInited = false; + MOCKER_CPP(&UBSHcomNetDriverOptions::ValidateCommonOptions).stubs().will(returnValue(1)); + + ret = driver->Initialize(option); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, DriverInitializeOutLoggerInstanceFail) +{ + int ret; + UBSHcomNetDriverOptions option{}; + UBSHcomNetOutLogger *logger = nullptr; + driver->mInited = false; + MOCKER_CPP(&NetDriverShmWithOOB::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&UBSHcomNetOutLogger::Instance).stubs().will(returnValue(logger)); + + ret = driver->Initialize(option); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, DriverInitializeLoadSslFail) +{ + int ret; + UBSHcomNetDriverOptions option{}; + + driver->mInited = false; + option.enableTls = true; + MOCKER_CPP(&NetDriverShmWithOOB::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(HcomSsl::Load).stubs().will(returnValue(1)); + + ret = driver->Initialize(option); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, DriverInitializeFail) +{ + int ret; + UBSHcomNetDriverOptions option{}; + + driver->mInited = false; + option.enableTls = false; + MOCKER_CPP(&NetDriverShmWithOOB::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverShmWithOOB::CreateWorkerResource).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverShmWithOOB::CreateWorkers).stubs().will(returnValue(1)); + + ret = driver->Initialize(option); + EXPECT_NE(ret, 0); + + ret = driver->Initialize(option); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, DriverInitializeFail2) +{ + int ret; + UBSHcomNetDriverOptions option{}; + + driver->mInited = false; + option.enableTls = false; + MOCKER_CPP(&NetDriverShmWithOOB::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverShmWithOOB::CreateWorkerResource).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverShmWithOOB::CreateWorkers).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverShmWithOOB::CreateClientLB).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverShmWithOOB::CreateListeners).stubs().will(returnValue(1)); + + ret = driver->Initialize(option); + EXPECT_NE(ret, 0); + + driver->mStartOobSvr = true; + ret = driver->Initialize(option); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, ValidateOptionsFail) +{ + int ret; + MOCKER_CPP(&UBSHcomNetDriver::ValidateAndParseOobPortRange) + .stubs() + .will(returnValue(static_cast(NN_INVALID_PARAM))) + .then(returnValue(0)); + MOCKER_CPP(&UBSHcomNetDriver::ValidateOptionsOobType).stubs().will(returnValue(static_cast(NN_INVALID_PARAM))); + ret = driver->ValidateOptions(); + EXPECT_EQ(ret, NN_INVALID_PARAM); + ret = driver->ValidateOptions(); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetShmDriverOob, CreateWorkersFail) +{ + int ret; + UBSHcomNetDriverOptions option{}; + + MOCKER_CPP(NetFunc::NN_ParseWorkersGroups).stubs().will(returnValue(false)); + ret = driver->CreateWorkers(); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, UnInitializeNotInit) +{ + driver->mInited = false; + EXPECT_NO_FATAL_FAILURE(driver->UnInitialize()); +} + +TEST_F(TestNetShmDriverOob, UnInitializeStarted) +{ + driver->mInited = true; + driver->mStarted = true; + EXPECT_NO_FATAL_FAILURE(driver->UnInitialize()); +} + +TEST_F(TestNetShmDriverOob, CreateWorkerResourceOpCompMemPoolFail) +{ + int ret; + NetMemPoolFixed *testOpCompMemPool = nullptr; + MOCKER_CPP(&NetRef::Get).stubs().will(returnValue(testOpCompMemPool)); + ret = driver->CreateWorkerResource(); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, CreateWorkerResourceMemPoolInitializeFail) +{ + int ret; + MOCKER_CPP(&NetMemPoolFixed::Initialize).stubs().will(returnValue(0)).then(returnValue(1)); + ret = driver->CreateWorkerResource(); + EXPECT_NE(ret, 0); + + ret = driver->CreateWorkerResource(); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, CreateWorkerResourceSglCompMemPoolInitializeFail) +{ + int ret; + MOCKER_CPP(&NetMemPoolFixed::Initialize).stubs().will(returnValue(0)).then(returnValue(0)).then(returnValue(1)); + ret = driver->CreateWorkerResource(); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, StartNotInited) +{ + int ret; + driver->mInited = false; + ret = driver->Start(); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, StartChannelKeeperNull) +{ + int ret; + driver->mInited = true; + driver->mChannelKeeper = nullptr; + ret = driver->Start(); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, CreateMemoryRegionNotInited) +{ + int ret; + UBSHcomNetMemoryRegionPtr mr = nullptr; + driver->mInited = false; + ret = driver->CreateMemoryRegion(1, mr); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, CreateMemoryRegionShmMemoryRegionCreateFail) +{ + int ret; + UBSHcomNetMemoryRegionPtr mr = nullptr; + driver->mInited = true; + MOCKER_CPP(ShmMemoryRegion::Create, NResult(const std::string &, uint64_t, ShmMemoryRegion *&)) + .stubs() + .will(returnValue(1)); + ret = driver->CreateMemoryRegion(1, mr); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, CreateMemoryRegionMemId) +{ + int ret; + UBSHcomNetMemoryRegionPtr mr = nullptr; + ret = driver->CreateMemoryRegion(1, mr, 1); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, MultiRailNewConnectionErr) +{ + int ret; + OOBTCPConnection conn(-1); + ret = driver->MultiRailNewConnection(conn); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, DestroyEndpointNull) +{ + UBSHcomNetEndpointPtr ep = nullptr; + EXPECT_NO_FATAL_FAILURE(driver->DestroyEndpoint(ep)); +} + +TEST_F(TestNetShmDriverOob, DestroyMemoryRegionNull) +{ + UBSHcomNetMemoryRegionPtr mr = nullptr; + EXPECT_NO_FATAL_FAILURE(driver->DestroyMemoryRegion(mr)); +} + +TEST_F(TestNetShmDriverOob, MapAndRegVaForUBErr) +{ + uint64_t va = 0; + void *ret = driver->MapAndRegVaForUB(1, va); + EXPECT_EQ(ret, nullptr); +} + +TEST_F(TestNetShmDriverOob, UnmapVaForUBErr) +{ + int ret; + uint64_t va = 0; + ret = driver->UnmapVaForUB(va); + EXPECT_NE(ret, 0); +} + +TEST_F(TestNetShmDriverOob, HandleNewRequestChannelStateFail) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + ch->mState.Set(ShmChannelState::CH_BROKEN); + ch->UpContext(1); + + ShmOpContextInfo ctx{ ch.Get(), 1, 1, ShmOpContextInfo::ShmOpType::SH_RECEIVE, + ShmOpContextInfo::ShmErrorType::SH_NO_ERROR }; + ret = driver->HandleNewRequest(ctx, 0); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestNetShmDriverOob, HandleNewRequestValidateHeaderWithDataSizeFail) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + ch->mState.Set(ShmChannelState::CH_NEW); + + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), nullptr, 0, index, map); + ASSERT_NE(ep, nullptr); + ch->UpContext(reinterpret_cast(ep.Get())); + + ShmOpContextInfo ctx{ ch.Get(), 1, 1, ShmOpContextInfo::ShmOpType::SH_RECEIVE, + ShmOpContextInfo::ShmErrorType::SH_NO_ERROR }; + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize).stubs().will(returnValue(100)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree).stubs().will(returnValue(0)); + ret = driver->HandleNewRequest(ctx, 0); + EXPECT_EQ(ret, NN_ERROR); +} + +TEST_F(TestNetShmDriverOob, HandleNewRequestEpToChildFail) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + ch->mState.Set(ShmChannelState::CH_NEW); + + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new (std::nothrow) NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, nullptr, map); + ASSERT_NE(ep, nullptr); + ch->UpContext(reinterpret_cast(ep)); + + char testData[128] = "Hello, this is a test data."; + ShmOpContextInfo ctx{ ch.Get(), (uintptr_t)testData, 1, ShmOpContextInfo::ShmOpType::SH_RECEIVE, + ShmOpContextInfo::ShmErrorType::SH_NO_ERROR }; + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize).stubs().will(returnValue(0)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree).stubs().will(returnValue(0)); + ret = driver->HandleNewRequest(ctx, 0); + EXPECT_EQ(ret, NN_PARAM_INVALID); + + ret = driver->HandleNewRequest(ctx, 1); + EXPECT_EQ(ret, NN_PARAM_INVALID); +} + +TEST_F(TestNetShmDriverOob, HandleNewRequestMallocDecryptFail) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + ch->mState.Set(ShmChannelState::CH_NEW); + + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), nullptr, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mIsNeedEncrypt = true; + ch->UpContext(reinterpret_cast(ep)); + + char testData[128] = "Hello, this is a test data."; + ShmOpContextInfo ctx{ ch.Get(), (uintptr_t)testData, 1, ShmOpContextInfo::ShmOpType::SH_RECEIVE, + ShmOpContextInfo::ShmErrorType::SH_NO_ERROR }; + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(false)) + .then(returnValue(true)) + .then(returnValue(false)) + .then(returnValue(true)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree).stubs().will(returnValue(0)); + ret = driver->HandleNewRequest(ctx, 0); + EXPECT_EQ(ret, NN_MALLOC_FAILED); + + MOCKER_CPP(&AesGcm128::Decrypt).stubs().will(returnValue(false)); + ret = driver->HandleNewRequest(ctx, 0); + EXPECT_EQ(ret, NN_DECRYPT_FAILED); + + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed).stubs().will(returnValue(false)).then(returnValue(true)); + ret = driver->HandleNewRequest(ctx, 1); + EXPECT_EQ(ret, NN_MALLOC_FAILED); + + ret = driver->HandleNewRequest(ctx, 1); + EXPECT_EQ(ret, NN_DECRYPT_FAILED); +} + +TEST_F(TestNetShmDriverOob, HandleNewRequestValidateDecryptLengthFail) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + ch->mState.Set(ShmChannelState::CH_NEW); + + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), nullptr, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mIsNeedEncrypt = true; + ch->UpContext(reinterpret_cast(ep)); + + char testData[128] = "Hello, this is a test data."; + ShmOpContextInfo ctx{ ch.Get(), (uintptr_t)testData, 1, ShmOpContextInfo::ShmOpType::SH_RECEIVE, + ShmOpContextInfo::ShmErrorType::SH_NO_ERROR }; + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed).stubs().will(returnValue(true)); + MOCKER_CPP(&AesGcm128::Decrypt).stubs().will(returnValue(true)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree).stubs().will(returnValue(0)); + ret = driver->HandleNewRequest(ctx, 0); + EXPECT_EQ(ret, NN_DECRYPT_FAILED); +} + +TEST_F(TestNetShmDriverOob, HandleNewRequestMallocMemcpyFail) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + ch->mState.Set(ShmChannelState::CH_NEW); + + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), nullptr, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mIsNeedEncrypt = false; + ch->UpContext(reinterpret_cast(ep)); + + char testData[128] = "Hello, this is a test data."; + ShmOpContextInfo ctx{ ch.Get(), (uintptr_t)testData, 1, ShmOpContextInfo::ShmOpType::SH_RECEIVE, + ShmOpContextInfo::ShmErrorType::SH_NO_ERROR }; + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize).stubs().will(returnValue(0)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree).stubs().will(returnValue(0)); + ret = driver->HandleNewRequest(ctx, 0); + EXPECT_EQ(ret, NN_INVALID_PARAM); + + ret = driver->HandleNewRequest(ctx, 1); + EXPECT_EQ(ret, NN_INVALID_PARAM); + + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed).stubs().will(returnValue(false)); + ret = driver->HandleNewRequest(ctx, 0); + EXPECT_EQ(ret, NN_MALLOC_FAILED); + + ret = driver->HandleNewRequest(ctx, 1); + EXPECT_EQ(ret, NN_MALLOC_FAILED); +} + +TEST_F(TestNetShmDriverOob, HandleNewRequest) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + ch->mState.Set(ShmChannelState::CH_NEW); + ch->UpContext(1); + + ShmOpContextInfo ctx{ ch.Get(), 1, 1, ShmOpContextInfo::ShmOpType::SH_SEND, + ShmOpContextInfo::ShmErrorType::SH_NO_ERROR }; + ret = driver->HandleNewRequest(ctx, 0); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestNetShmDriverOob, HandleReqPostedFail) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + ch->mState.Set(ShmChannelState::CH_NEW); + + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), nullptr, 0, index, map); + ASSERT_NE(ep, nullptr); + ch->UpContext(reinterpret_cast(ep)); + + UBSHcomNetWorkerIndex indexWorker; + ShmWorkerOptions options{}; + NetMemPoolFixedPtr opMemPool; + NetMemPoolFixedPtr opCtxMemPool; + NetMemPoolFixedPtr sglOpMemPool; + ShmWorker *worker = + new (std::nothrow) ShmWorker("shm", indexWorker, options, opMemPool, opCtxMemPool, sglOpMemPool); + ch->UpContext1(reinterpret_cast(worker)); + + ShmOpCompInfo ctx{}; + ctx.channel = ch.Get(); + ctx.opType = ShmOpContextInfo::ShmOpType::SH_RECEIVE; + driver->mRequestPostedHandler = [](const UBSHcomNetRequestContext &ctx) -> int { return SER_ERROR; }; + ret = driver->HandleReqPosted(ctx); + EXPECT_EQ(ret, SER_ERROR); +} + +TEST_F(TestNetShmDriverOob, ProcessEpError) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + ch->mState.Set(ShmChannelState::CH_NEW); + + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), nullptr, 0, index, map); + ASSERT_NE(ep, nullptr); + ch->UpContext(reinterpret_cast(ep)); + + EXPECT_NO_FATAL_FAILURE(driver->ProcessEpError(ch)); +} + +TEST_F(TestNetShmDriverOob, ProcessEpErrorEPBroken) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + ch->mState.Set(ShmChannelState::CH_NEW); + + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), nullptr, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mEPBrokenProcessed = true; + ch->UpContext(reinterpret_cast(ep)); + + EXPECT_NO_FATAL_FAILURE(driver->ProcessEpError(ch)); +} + +TEST_F(TestNetShmDriverOob, ProcessEpErrorEPState) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + ch->mState.Set(ShmChannelState::CH_NEW); + + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), nullptr, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + ch->UpContext(reinterpret_cast(ep)); + + EXPECT_NO_FATAL_FAILURE(driver->ProcessEpError(ch)); +} + +TEST_F(TestNetShmDriverOob, ProcessEpErrorTwoSideRemaining) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + ch->mState.Set(ShmChannelState::CH_NEW); + + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), nullptr, 0, index, map); + ASSERT_NE(ep, nullptr); + ch->UpContext(reinterpret_cast(ep)); + + ShmOpCompInfo ctx{}; + ctx.channel = ch.Get(); + ctx.opType = ShmOpContextInfo::ShmOpType::SH_RECEIVE; + ch.Get()->mCompPosted.next = &ctx; + + EXPECT_NO_FATAL_FAILURE(driver->ProcessEpError(ch)); +} + +TEST_F(TestNetShmDriverOob, ProcessEpErrorOneSideRemaining) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + ch->mState.Set(ShmChannelState::CH_NEW); + + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), nullptr, 0, index, map); + ASSERT_NE(ep, nullptr); + ch->UpContext(reinterpret_cast(ep)); + + ShmOpContextInfo ctx{ ch.Get(), 1, 1, ShmOpContextInfo::ShmOpType::SH_RECEIVE, + ShmOpContextInfo::ShmErrorType::SH_NO_ERROR }; + ch.Get()->mCtxPosted.next = &ctx; + + EXPECT_NO_FATAL_FAILURE(driver->ProcessEpError(ch)); +} + +TEST_F(TestNetShmDriverOob, Connect) +{ + int ret; + std::string payload = "Hello, this is a test data."; + UBSHcomNetEndpointPtr ep; + ret = driver->Connect(payload, ep, 0, 0, 0); + EXPECT_NE(ret, NN_OK); + + driver->mOptions.oobType = NET_OOB_UDS; + ret = driver->Connect(payload, ep, 0, 0, 0); + EXPECT_NE(ret, NN_OK); +} + +TEST_F(TestNetShmDriverOob, Connect2) +{ + int ret; + std::string payload{}; + UBSHcomNetEndpointPtr ep; + std::string serverUrl = "tcp://127.0.0.1:9981"; + std::string serverUrl2 = "uds://name"; + std::string badUrl = "unknown://127.0.0.1:9981"; + driver->mInited = true; + driver->mStarted = false; + ret = driver->Connect(serverUrl, payload, ep, 0, 0, 0, 0); + EXPECT_EQ(ret, NN_INVALID_PARAM); + + ret = driver->Connect(serverUrl2, payload, ep, 0, 0, 0, 0); + EXPECT_EQ(ret, NN_ERROR); + + ret = driver->Connect(badUrl, payload, ep, 0, 0, 0, 0); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetShmDriverOob, ShmDriverHandleNewRequestFail) +{ + int ret; + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("ShmDriverCreateMemoryRegion", false, SHM); + ASSERT_NE(driver, nullptr); + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointShmPostSend", 0, NN_NO128, NN_NO4, ch); + // mWorker create + UBSHcomNetWorkerIndex indexWorker; + ShmWorkerOptions options{}; + NetMemPoolFixedPtr opMemPool; + NetMemPoolFixedPtr opCtxMemPool; + NetMemPoolFixedPtr sglOpMemPool; + ShmWorker *mWorker = new (std::nothrow) + ShmWorker("NetAsyncEndpointShmPostSend", indexWorker, options, opMemPool, opCtxMemPool, sglOpMemPool); + ASSERT_NE(mWorker, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), mWorker, 0, index, map); + ASSERT_NE(ep, nullptr); + ch->UpContext(reinterpret_cast(ep)); + ShmOpContextInfo ctx{}; + ctx.channel = ch.Get(); + ctx.dataAddress = 1; + ctx.dataSize = NN_NO1024; + + uint32_t immData = 0; + MOCKER_CPP(&ShmChannel::UpContext, uint64_t(ShmChannel::*)() const) + .stubs() + .will(returnValue(static_cast(1))) + .then(returnValue(static_cast(0))); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree).stubs().will(returnValue(0)); + ret = driver->HandleNewRequest(ctx, immData); + EXPECT_EQ(ret, static_cast(NN_PARAM_INVALID)); + + delete driver; + driver = nullptr; +} + +TEST_F(TestNetShmDriverOob, ShmDriverHandleNewRequestFailTwo) +{ + int ret; + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("ShmDriverCreateMemoryRegion", false, SHM); + ASSERT_NE(driver, nullptr); + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointShmPostSend", 0, NN_NO128, NN_NO4, ch); + // mWorker create + UBSHcomNetWorkerIndex indexWorker; + ShmWorkerOptions options{}; + NetMemPoolFixedPtr opMemPool; + NetMemPoolFixedPtr opCtxMemPool; + NetMemPoolFixedPtr sglOpMemPool; + ShmWorker *mWorker = new (std::nothrow) + ShmWorker("NetAsyncEndpointShmPostSend", indexWorker, options, opMemPool, opCtxMemPool, sglOpMemPool); + ASSERT_NE(mWorker, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), mWorker, 0, index, map); + ASSERT_NE(ep, nullptr); + ch->UpContext(reinterpret_cast(ep)); + ShmOpContextInfo ctx{}; + ctx.channel = ch.Get(); + UBSHcomNetTransHeader header{}; + ctx.dataAddress = reinterpret_cast(&header); + ctx.dataSize = NN_NO1024; + + uint32_t immData = 0; + MOCKER_CPP(&NetFunc::ValidateHeaderWithDataSize).stubs().will(returnValue(static_cast(NN_PARAM_INVALID))); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree).stubs().will(returnValue(0)); + ret = driver->HandleNewRequest(ctx, immData); + EXPECT_EQ(ret, static_cast(NN_PARAM_INVALID)); + + delete driver; + driver = nullptr; +} + +TEST_F(TestNetShmDriverOob, HandleChanelKeeperMsgFail) +{ + int ret; + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("ShmDriverCreateMemoryRegion", false, SHM); + ASSERT_NE(driver, nullptr); + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointShmPostSend", 0, NN_NO128, NN_NO4, ch); + ShmChKeeperMsgHeader header{}; + header.msgType = GET_MR_FD; + header.dataSize = sizeof(uint32_t); + MOCKER(::recv).defaults().will(returnValue(1)); + ShmHandlePtr shmHandlePtr = nullptr; + MOCKER_CPP(&ShmMRHandleMap::GetFromLocalMap).stubs().will(returnValue(shmHandlePtr)); + driver->HandleChanelKeeperMsg(header, ch); + EXPECT_EQ(header.msgType, GET_MR_FD); + delete driver; + driver = nullptr; +} + +TEST_F(TestNetShmDriverOob, HandleChanelKeeperMsgFailTwo) +{ + int ret; + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("ShmDriverCreateMemoryRegion", false, SHM); + ASSERT_NE(driver, nullptr); + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointShmPostSend", 0, NN_NO128, NN_NO4, ch); + ShmChKeeperMsgHeader header{}; + header.msgType = GET_MR_FD; + header.dataSize = sizeof(uint32_t); + MOCKER(::recv).defaults().will(returnValue(1)); + ShmHandlePtr shmHandlePtr = new (std::nothrow) ShmHandle("mName", SHM_F_EVENT_QUEUE_PREFIX, 1, NN_NO128, true); + MOCKER_CPP(&ShmMRHandleMap::GetFromLocalMap).stubs().will(returnValue(shmHandlePtr)); + MOCKER_CPP(&ShmHandle::Fd).stubs().will(returnValue(0)); + driver->HandleChanelKeeperMsg(header, ch); + EXPECT_EQ(header.msgType, GET_MR_FD); + delete driver; + driver = nullptr; +} + +TEST_F(TestNetShmDriverOob, HandleChanelKeeperMsgFailThree) +{ + int ret; + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("ShmDriverCreateMemoryRegion", false, SHM); + ASSERT_NE(driver, nullptr); + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointShmPostSend", 0, NN_NO128, NN_NO4, ch); + ShmChKeeperMsgHeader header{}; + header.msgType = GET_MR_FD; + driver->HandleChanelKeeperMsg(header, ch); + + header.dataSize = sizeof(uint32_t); + MOCKER(::recv).defaults().will(returnValue(0)); + driver->HandleChanelKeeperMsg(header, ch); + EXPECT_EQ(header.msgType, GET_MR_FD); + delete driver; + driver = nullptr; +} + +TEST_F(TestNetShmDriverOob, HandleChanelKeeperMsgFailFour) +{ + int ret; + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("ShmDriverCreateMemoryRegion", false, SHM); + ASSERT_NE(driver, nullptr); + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointShmPostSend", 0, NN_NO128, NN_NO4, ch); + ShmChKeeperMsgHeader header{}; + header.msgType = GET_MR_FD; + header.dataSize = sizeof(uint32_t); + MOCKER(::recv).defaults().will(returnValue(1)); + ShmHandlePtr shmHandlePtr = new (std::nothrow) ShmHandle("mName", SHM_F_EVENT_QUEUE_PREFIX, 1, NN_NO128, true); + MOCKER_CPP(&ShmMRHandleMap::GetFromLocalMap).stubs().will(returnValue(shmHandlePtr)); + MOCKER_CPP(&ShmHandle::Fd).stubs().will(returnValue(1)); + MOCKER(::send).defaults().will(returnValue(0)).then(returnValue(1)); + driver->HandleChanelKeeperMsg(header, ch); + EXPECT_EQ(header.msgType, GET_MR_FD); + + MOCKER_CPP(&ShmHandleFds::SendMsgFds).defaults().will(returnValue(1)); + driver->HandleChanelKeeperMsg(header, ch); + EXPECT_EQ(header.msgType, GET_MR_FD); + delete driver; + driver = nullptr; +} +} // namespace hcom +} // namespace ock \ No newline at end of file diff --git a/test/unit_test/transport/shm/test_net_shm_endpoint.cpp b/test/unit_test/transport/shm/test_net_shm_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..92cab985857c9e8d3153242382fefadabd179af5 --- /dev/null +++ b/test/unit_test/transport/shm/test_net_shm_endpoint.cpp @@ -0,0 +1,2191 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include "hcom.h" +#include "net_shm_sync_endpoint.h" +#include "net_shm_async_endpoint.h" + +namespace ock { +namespace hcom { +uint8_t mockData[8]; +UBSHcomNetTransHeader mockReq{}; + +class TestNetShmEndpoint : public testing::Test { +public: + TestNetShmEndpoint(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +TestNetShmEndpoint::TestNetShmEndpoint() {} + +static HResult MockGetPeerDataAddressByOffset(uint64_t offset, uintptr_t &address) +{ + address = reinterpret_cast(&mockReq); + offset = 0; + return 0; +} + +static HResult MockDequeueEvent(int32_t timeout, ShmEvent &opEvent) +{ + opEvent.opType = static_cast(mockData[0]); + opEvent.peerChannelAddress = 0; + return 0; +} +void TestNetShmEndpoint::SetUp() {} + +void TestNetShmEndpoint::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointShmPostSend) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointShmPostSend", 0, NN_NO128, NN_NO4, ch); + // mWorker create + UBSHcomNetWorkerIndex indexWorker; + ShmWorkerOptions options{}; + NetMemPoolFixedPtr opMemPool; + NetMemPoolFixedPtr opCtxMemPool; + NetMemPoolFixedPtr sglOpMemPool; + ShmWorker *mWorker = new (std::nothrow) ShmWorker("NetAsyncEndpointShmPostSend", indexWorker, options, opMemPool, + opCtxMemPool, sglOpMemPool); + ASSERT_NE(mWorker, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), mWorker, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + + UBSHcomNetTransRequest request; + request.lAddress = reinterpret_cast(&indexWorker); + request.size = 1; + ep->mAllowedSize = NN_NO128; + + MOCKER_CPP(&ShmWorker::PostSend) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)) + .then(returnValue(static_cast(SH_SEND_COMPLETION_CALLBACK_FAILURE))); + MOCKER_CPP(&AesGcm128::Encrypt, + bool (AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + ep->mIsNeedEncrypt = 1; + ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); + + ep->mIsNeedEncrypt = 0; + ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, 1); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointShmPostSendTwo) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointShmPostSendTwo", 0, NN_NO128, NN_NO4, ch); + // mWorker create + UBSHcomNetWorkerIndex indexWorker; + ShmWorkerOptions options{}; + NetMemPoolFixedPtr opMemPool; + NetMemPoolFixedPtr opCtxMemPool; + NetMemPoolFixedPtr sglOpMemPool; + ShmWorker *mWorker = new (std::nothrow) ShmWorker("NetAsyncEndpointShmPostSendTwo", indexWorker, options, opMemPool, + opCtxMemPool, sglOpMemPool); + ASSERT_NE(mWorker, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), mWorker, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + + UBSHcomNetTransRequest request; + request.lAddress = reinterpret_cast(&indexWorker); + request.size = 1; + ep->mAllowedSize = NN_NO128; + + MOCKER_CPP(&ShmWorker::PostSend) + .stubs() + .will(returnValue(0)) + .then(returnValue(static_cast(SH_SEND_COMPLETION_CALLBACK_FAILURE))); + + ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, 0); + + ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(SH_SEND_COMPLETION_CALLBACK_FAILURE)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointShmPostSend) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointShmPostSend", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new (std::nothrow) ShmSyncEndpoint("NetSyncEndpointShmPostSend", 0, SHM_EVENT_POLLING); + ASSERT_NE(shmEp, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new (std::nothrow) NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + UBSHcomNetTransRequest request; + request.lAddress = reinterpret_cast(&index); + request.size = 1; + ep->mAllowedSize = NN_NO128; + UBSHcomNetTransOpInfo opInfo{}; + + MOCKER_CPP(&AesGcm128::Encrypt, + bool (AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + MOCKER(NetFunc::CalcHeaderCrc32, uint32_t(UBSHcomNetTransHeader *)) + .stubs() + .will(returnValue(static_cast(0))); + MOCKER_CPP(&ShmSyncEndpoint::PostSend) + .stubs() + .will(returnValue(1)); + + ep->mIsNeedEncrypt = true; + ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); + + ep->mIsNeedEncrypt = false; + ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, 1); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointShmPostSendTwo) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointShmPostSendTwo", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new (std::nothrow) ShmSyncEndpoint("NetSyncEndpointShmPostSendTwo", 0, SHM_EVENT_POLLING); + ASSERT_NE(shmEp, nullptr); + // shmEp create + ShmMRHandleMap map; + UBSHcomNetWorkerIndex index; + NetSyncEndpointShm *ep = new (std::nothrow) NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + UBSHcomNetTransRequest request; + request.lAddress = reinterpret_cast(&index); + request.size = 1; + ep->mAllowedSize = NN_NO128; + UBSHcomNetTransOpInfo opInfo{}; + + MOCKER(NetFunc::CalcHeaderCrc32, uint32_t(UBSHcomNetTransHeader *)) + .stubs() + .will(returnValue(static_cast(0))); + MOCKER_CPP(&ShmSyncEndpoint::PostSend) + .stubs() + .will(returnValue(0)) + .then(returnValue(static_cast(SH_SEND_COMPLETION_CALLBACK_FAILURE))); + + ep->mIsNeedEncrypt = false; + ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, 0); + ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(SH_SEND_COMPLETION_CALLBACK_FAILURE)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointShmPostSendThree) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointShmPostSendThree", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new (std::nothrow) ShmSyncEndpoint("NetSyncEndpointShmPostSendThree", 0, + SHM_EVENT_POLLING); + ASSERT_NE(shmEp, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new (std::nothrow) NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + UBSHcomNetTransRequest request; + request.lAddress = reinterpret_cast(&index); + request.size = 1; + ep->mAllowedSize = NN_NO128; + UBSHcomNetTransOpInfo opInfo{}; + + MOCKER_CPP(&ShmSyncEndpoint::PostSend) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)); + MOCKER(NetFunc::CalcHeaderCrc32, uint32_t(UBSHcomNetTransHeader *)) + .stubs() + .will(returnValue(static_cast(0))); + + ep->mIsNeedEncrypt = false; + + ret = ep->PostSend(0, request, opInfo); + EXPECT_EQ(ret, 1); + + ret = ep->PostSend(0, request, opInfo); + EXPECT_EQ(ret, 0); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointShmPostSendRaw) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointShmPostSendRaw", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new (std::nothrow) ShmSyncEndpoint("NetSyncEndpointShmPostSendRaw", 0, SHM_EVENT_POLLING); + ASSERT_NE(shmEp, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new (std::nothrow) NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + UBSHcomNetTransRequest request; + request.lAddress = reinterpret_cast(&index); + request.size = 1; + ep->mAllowedSize = NN_NO128; + ep->mSegSize = NN_NO128; + + MOCKER_CPP(&ShmSyncEndpoint::PostSend) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)); + MOCKER_CPP(&AesGcm128::Encrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + ep->mIsNeedEncrypt = true; + ret = ep->PostSendRaw(request, 0); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); + + ep->mIsNeedEncrypt = false; + ret = ep->PostSendRaw(request, 0); + EXPECT_EQ(ret, 1); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointShmPostSendRawTwo) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointShmPostSendRawTwo", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new (std::nothrow) ShmSyncEndpoint("NetSyncEndpointShmPostSendRawTwo", 0, + SHM_EVENT_POLLING); + ASSERT_NE(shmEp, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new (std::nothrow) NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + ep->mAllowedSize = NN_NO128; + ep->mSegSize = NN_NO128; + UBSHcomNetTransRequest request; + request.lAddress = reinterpret_cast(&index); + request.size = 1; + + MOCKER_CPP(&ShmSyncEndpoint::PostSend) + .stubs() + .will(returnValue(0)); + + ep->mIsNeedEncrypt = false; + ret = ep->PostSendRaw(request, 0); + EXPECT_EQ(ret, 0); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointShmPostSendRawThree) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointShmPostSendRawThree", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new (std::nothrow) ShmSyncEndpoint("NetSyncEndpointShmPostSendRawThree", 0, + SHM_EVENT_POLLING); + ASSERT_NE(shmEp, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new (std::nothrow) NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + UBSHcomNetTransRequest request; + request.lAddress = reinterpret_cast(&index); + request.size = 1; + ep->mAllowedSize = NN_NO128; + ep->mSegSize = NN_NO128; + + MOCKER_CPP(&ShmChannel::DCGetFreeBuck) + .stubs() + .will(returnValue(1)); + + ret = ep->PostSendRaw(request, 0); + EXPECT_EQ(ret, 1); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointShmPostSendRawSgl) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointShmPostSendRawSgl", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new (std::nothrow) ShmSyncEndpoint("NetSyncEndpointShmPostSendRawSgl", 0, + SHM_EVENT_POLLING); + ASSERT_NE(shmEp, nullptr); + // driver create + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("NetSyncEndpointShmPostSendRawSgl", false, + SHM); + ASSERT_NE(driver, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new (std::nothrow) NetSyncEndpointShm(ch->Id(), ch.Get(), driver, index, shmEp, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + // request create + UBSHcomNetTransSglRequest request{}; + UBSHcomNetTransSgeIov iov; + request.iovCount = 1; + request.iov = &iov; + iov.size = 1; + iov.lAddress = reinterpret_cast(&index); + ep->mSegSize = NN_NO128; + + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + MOCKER_CPP(&ShmSyncEndpoint::PostSendRawSgl) + .stubs() + .will(returnValue(1)); + MOCKER_CPP(&AesGcm128::Encrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + ep->mIsNeedEncrypt = true; + ret = ep->PostSendRawSgl(request, 1); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); + + ep->mIsNeedEncrypt = false; + ret = ep->PostSendRawSgl(request, 1); + EXPECT_EQ(ret, 1); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointShmPostSendRawSglTwo) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointShmPostSendRawSglTwo", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new (std::nothrow) ShmSyncEndpoint("NetSyncEndpointShmPostSendRawSglTwo", 0, + SHM_EVENT_POLLING); + ASSERT_NE(shmEp, nullptr); + // driver create + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("NetSyncEndpointShmPostSendRawSglTwo", false, + SHM); + ASSERT_NE(driver, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new (std::nothrow) NetSyncEndpointShm(ch->Id(), ch.Get(), driver, index, shmEp, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + // request create + UBSHcomNetTransSglRequest request{}; + UBSHcomNetTransSgeIov iov; + request.iovCount = 1; + request.iov = &iov; + iov.size = 1; + iov.lAddress = reinterpret_cast(&index); + ep->mSegSize = NN_NO128; + + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + MOCKER_CPP(&ShmSyncEndpoint::PostSendRawSgl) + .stubs() + .will(returnValue(0)); + + ep->mIsNeedEncrypt = false; + ret = ep->PostSendRawSgl(request, 1); + EXPECT_EQ(ret, 0); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointShmPostRead) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointShmPostRead", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new (std::nothrow) ShmSyncEndpoint("NetSyncEndpointShmPostRead", 0, SHM_EVENT_POLLING); + ASSERT_NE(shmEp, nullptr); + // driver create + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("NetSyncEndpointShmPostRead", false, SHM); + ASSERT_NE(driver, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new (std::nothrow) NetSyncEndpointShm(ch->Id(), ch.Get(), driver, index, shmEp, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + // request create + UBSHcomNetTransRequest request{}; + request.size = 1; + request.lAddress = reinterpret_cast(&index); + ep->mSegSize = NN_NO128; + + UBSHcomNetTransOpInfo opInfo{}; + + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + MOCKER_CPP(&ShmSyncEndpoint::PostRead, HResult(ShmSyncEndpoint::*)(ShmChannel *, const UBSHcomNetTransRequest &, + ShmMRHandleMap &)) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)); + + ret = ep->PostRead(request); + EXPECT_EQ(ret, 1); + + ret = ep->PostRead(request); + EXPECT_EQ(ret, 0); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointShmPostReadTwo) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointShmPostReadTwo", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new (std::nothrow) ShmSyncEndpoint("NetSyncEndpointShmPostReadTwo", 0, SHM_EVENT_POLLING); + ASSERT_NE(shmEp, nullptr); + // driver create + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("NetSyncEndpointShmPostReadTwo", false, SHM); + ASSERT_NE(driver, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new (std::nothrow) NetSyncEndpointShm(ch->Id(), ch.Get(), driver, index, shmEp, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + // request create + UBSHcomNetTransSglRequest request{}; + UBSHcomNetTransSgeIov iov; + request.iovCount = 1; + request.iov = &iov; + iov.size = 1; + iov.lAddress = reinterpret_cast(&index); + ep->mSegSize = NN_NO128; + + UBSHcomNetTransOpInfo opInfo{}; + + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + MOCKER_CPP(&ShmSyncEndpoint::PostRead, HResult(ShmSyncEndpoint::*)(ShmChannel *, const UBSHcomNetTransSglRequest &, + ShmMRHandleMap &)) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)); + + ret = ep->PostRead(request); + EXPECT_EQ(ret, 1); + + ret = ep->PostRead(request); + EXPECT_EQ(ret, 0); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointShmPostWrite) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ret = ShmChannel::CreateAndInit("NetSyncEndpointShmPostWrite", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new (std::nothrow) ShmSyncEndpoint("NetSyncEndpointShmPostWrite", 0, SHM_EVENT_POLLING); + ASSERT_NE(shmEp, nullptr); + // driver create + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("NetSyncEndpointShmPostWrite", false, SHM); + ASSERT_NE(driver, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new (std::nothrow) NetSyncEndpointShm(ch->Id(), ch.Get(), driver, index, shmEp, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + // request create + UBSHcomNetTransRequest request{}; + request.size = 1; + request.lAddress = reinterpret_cast(&index); + ep->mSegSize = NN_NO128; + + UBSHcomNetTransOpInfo opInfo{}; + + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + MOCKER_CPP(&ShmSyncEndpoint::PostWrite, HResult(ShmSyncEndpoint::*)(ShmChannel *, const UBSHcomNetTransRequest &, + ShmMRHandleMap &)) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)); + + ret = ep->PostWrite(request); + EXPECT_EQ(ret, 1); + + ret = ep->PostWrite(request); + EXPECT_EQ(ret, 0); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointShmPostWriteTwo) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ret = ShmChannel::CreateAndInit("NetSyncEndpointShmPostWriteTwo", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new (std::nothrow) ShmSyncEndpoint("NetSyncEndpointShmPostWriteTwo", 0, SHM_EVENT_POLLING); + ASSERT_NE(shmEp, nullptr); + // driver create + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("NetSyncEndpointShmPostWriteTwo", false, SHM); + ASSERT_NE(driver, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new (std::nothrow) NetSyncEndpointShm(ch->Id(), ch.Get(), driver, index, shmEp, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + // request create + UBSHcomNetTransSglRequest request{}; + UBSHcomNetTransSgeIov iov; + request.iovCount = 1; + request.iov = &iov; + iov.size = 1; + iov.lAddress = reinterpret_cast(&index); + ep->mSegSize = NN_NO128; + + UBSHcomNetTransOpInfo opInfo{}; + + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + MOCKER_CPP(&ShmSyncEndpoint::PostWrite, HResult(ShmSyncEndpoint::*)(ShmChannel *, const UBSHcomNetTransSglRequest &, + ShmMRHandleMap &)) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)); + + ret = ep->PostWrite(request); + EXPECT_EQ(ret, 1); + + ret = ep->PostWrite(request); + EXPECT_EQ(ret, 0); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointSetEpOption) +{ + int ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + UBSHcomEpOptions epOptions{}; + + ret = ep->SetEpOption(epOptions); + EXPECT_EQ(ret, 0); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointGetSendQueueCount) +{ + int ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + + ret = ep->GetSendQueueCount(); + EXPECT_EQ(ret, 0); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointPeerIpAndPort) +{ + std::string ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + + ret = ep->PeerIpAndPort(); + EXPECT_EQ(ret, CONST_EMPTY_STRING); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointPeerIpAndPortTwo) +{ + int ret; + std::string result; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointPeerIpAndPortTwo", 0, NN_NO128, NN_NO4, ch); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + + result = ep->PeerIpAndPort(); + EXPECT_EQ(result, ""); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointUdsName) +{ + std::string ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + + ret = ep->UdsName(); + EXPECT_EQ(ret, CONST_EMPTY_STRING); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointUdsNameTwo) +{ + int ret; + std::string result; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointUdsNameTwo", 0, NN_NO128, NN_NO4, ch); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + + result = ep->UdsName(); + EXPECT_EQ(result, ""); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointInvalidOperation) +{ + int ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + UBSHcomNetResponseContext ctx{}; + + ret = ep->WaitCompletion(0); + EXPECT_EQ(ret, static_cast(NN_INVALID_OPERATION)); + + ret = ep->Receive(0, ctx); + EXPECT_EQ(ret, static_cast(NN_INVALID_OPERATION)); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointInvalidOperationTwo) +{ + int ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + UBSHcomNetResponseContext ctx{}; + + ret = ep->ReceiveRaw(0, ctx); + EXPECT_EQ(ret, static_cast(NN_INVALID_OPERATION)); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointEnableEncrypt) +{ + // shmSecrets create + NetSecrets shmSecrets{}; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + + UBSHcomNetDriverOptions options; + options.cipherSuite = AES_GCM_256; + ep->EnableEncrypt(options); + EXPECT_EQ(ep->mAes.mCipherSuite, AES_GCM_256); + + ep->SetSecrets(shmSecrets); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointEstimatedEncryptLen) +{ + int ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen) + .stubs() + .will(returnValue(1)); + ret = ep->EstimatedEncryptLen(0); + EXPECT_EQ(ret, 0); + + ret = ep->EstimatedEncryptLen(1); + EXPECT_EQ(ret, 0); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointEstimatedEncryptLenTwo) +{ + int ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen) + .stubs() + .will(returnValue(1)); + + ep->mIsNeedEncrypt = true; + ret = ep->EstimatedEncryptLen(1); + EXPECT_EQ(ret, 1); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointEncrypt) +{ + int ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + // Encrypt data + uint8_t encryptData = 0; + uint8_t *cipher = reinterpret_cast(&index); + uint64_t cipherLen; + + MOCKER_CPP(&AesGcm128::Encrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + ret = ep->Encrypt(&encryptData, 1, cipher, cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + ep->mIsNeedEncrypt = true; + ret = ep->Encrypt(&encryptData, 1, cipher, cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointEncryptTwo) +{ + int ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + // Encrypt data + uint8_t encryptData = 0; + uint8_t *cipher = reinterpret_cast(&index); + uint64_t cipherLen = 0; + + MOCKER_CPP(&AesGcm128::Encrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(true)); + + ep->mIsNeedEncrypt = true; + ret = ep->Encrypt(&encryptData, 1, cipher, cipherLen); + EXPECT_EQ(ret, static_cast(NN_OK)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointEstimatedDecryptLen) +{ + int ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + + MOCKER_CPP(&AesGcm128::GetRawLen) + .stubs() + .will(returnValue(1)); + ret = ep->EstimatedDecryptLen(1); + EXPECT_EQ(ret, 0); + + ep->mIsNeedEncrypt = true; + ret = ep->EstimatedDecryptLen(1); + EXPECT_EQ(ret, 1); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointDecrypt) +{ + int ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + // Encrypt data + uint8_t encryptData = 0; + uint8_t *cipher = reinterpret_cast(&index); + uint64_t cipherLen; + + MOCKER_CPP(&AesGcm128::Decrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + ret = ep->Decrypt(&encryptData, 1, cipher, cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + ep->mIsNeedEncrypt = true; + ret = ep->Decrypt(&encryptData, 1, cipher, cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointDecryptTwo) +{ + int ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + // Encrypt data + uint8_t encryptData = 0; + uint8_t *cipher = reinterpret_cast(&index); + uint64_t cipherLen = 0; + + MOCKER_CPP(&AesGcm128::Decrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(true)); + ep->mIsNeedEncrypt = true; + ret = ep->Decrypt(&encryptData, 1, cipher, cipherLen); + EXPECT_EQ(ret, static_cast(NN_OK)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointSendFds) +{ + int ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_NEW); + // SendFds data + int fds[1] = {1}; + + ret = ep->SendFds(fds, 0); + EXPECT_EQ(ret, static_cast(NN_PARAM_INVALID)); + + ret = ep->SendFds(fds, 1); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointReceiveFds) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointReceiveFds", 0, NN_NO128, NN_NO4, ch); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), 0, 0, index, map); + ASSERT_NE(ep, nullptr); + // ReceiveFds data + int fds[1] = {1}; + + MOCKER_CPP(&ShmChannel::Close).stubs(); + MOCKER_CPP(&ShmChannel::RemoveUserFds) + .stubs() + .will(returnValue(0)); + + ret = ep->ReceiveFds(fds, 0, 0); + EXPECT_EQ(ret, static_cast(NN_PARAM_INVALID)); + + ret = ep->ReceiveFds(fds, 1, 0); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointReceiveFdsTwo) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointReceiveFdsTwo", 0, NN_NO128, NN_NO4, ch); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(ch->Id(), ch.Get(), 0, 0, index, map); + ASSERT_NE(ep, nullptr); + // ReceiveFds data + int fds[1] = {1}; + + MOCKER_CPP(&ShmChannel::RemoveUserFds) + .stubs() + .will(returnValue(0)); + MOCKER_CPP(&ShmChannel::Close).stubs(); + + ep->mState.Set(NEP_ESTABLISHED); + ret = ep->ReceiveFds(fds, 1, 0); + EXPECT_EQ(ret, static_cast(NN_OK)); + ep->Close(); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointGetRemoteUdsIdInfo) +{ + int ret; + // driver create + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("NetAsyncEndpointGetRemoteUdsIdInfo", false, + SHM); + ASSERT_NE(driver, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, driver, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_NEW); + UBSHcomNetUdsIdInfo idInfo{}; + + MOCKER_CPP(&ShmChannel::Close).stubs(); + ret = ep->GetRemoteUdsIdInfo(idInfo); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + + ep->mState.Set(NEP_ESTABLISHED); + driver->mStartOobSvr = false; + ret = ep->GetRemoteUdsIdInfo(idInfo); + EXPECT_EQ(ret, static_cast(NN_UDS_ID_INFO_NOT_SUPPORT)); + + ep->Close(); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointGetRemoteUdsIdInfoTwo) +{ + int ret; + // driver create + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("NetAsyncEndpointGetRemoteUdsIdInfo", false, + SHM); + ASSERT_NE(driver, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, driver, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_NEW); + UBSHcomNetUdsIdInfo idInfo{}; + + MOCKER_CPP(&ShmChannel::Close).stubs(); + + ep->mState.Set(NEP_ESTABLISHED); + driver->mStartOobSvr = true; + ret = ep->GetRemoteUdsIdInfo(idInfo); + EXPECT_EQ(ret, static_cast(NN_OK)); + + ep->Close(); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointGetPeerIpPort) +{ + bool ret; + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new (std::nothrow) NetAsyncEndpointShm(0, 0, 0, 0, index, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + + std::string ip("127.0.0.1"); + uint16_t port = 1234; + + ret = ep->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointSetEpOption) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new (std::nothrow) ShmSyncEndpoint("NetSyncEndpointSetEpOption", 0, SHM_EVENT_POLLING); + ASSERT_NE(shmEp, nullptr); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new (std::nothrow) NetSyncEndpointShm(0, 0, 0, index, shmEp, map); + ASSERT_NE(ep, nullptr); + ep->mState.Set(NEP_ESTABLISHED); + UBSHcomEpOptions epOptions{}; + + ret = ep->SetEpOption(epOptions); + EXPECT_EQ(ret, 0); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointGetSendQueueCount) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointGetSendQueueCount", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + ret = ep->GetSendQueueCount(); + EXPECT_EQ(ret, 0); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointPeerIpAndPort) +{ + std::string ret; + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointPeerIpAndPort", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + ret = ep->PeerIpAndPort(); + EXPECT_EQ(ret, CONST_EMPTY_STRING); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointPeerIpAndPortTwo) +{ + int ret; + std::string result; + // mShmCh create + ShmChannelPtr ch; + ret = ShmChannel::CreateAndInit("NetSyncEndpointPeerIpAndPort2", 0, NN_NO128, NN_NO4, ch); + if (NN_UNLIKELY(ret != NN_OK)) { + return; + } + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointPeerIpAndPort2", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + result = ep->PeerIpAndPort(); + EXPECT_EQ(result, ""); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointUdsName) +{ + std::string ret; + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointUdsName", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + ret = ep->UdsName(); + EXPECT_EQ(ret, CONST_EMPTY_STRING); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointUdsNameTwo) +{ + int ret; + std::string result; + // mShmCh create + ShmChannelPtr ch; + ret = ShmChannel::CreateAndInit("NetSyncEndpointUdsName2", 0, NN_NO128, NN_NO4, ch); + if (NN_UNLIKELY(ret != NN_OK)) { + return; + } + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointUdsName2", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + result = ep->UdsName(); + EXPECT_EQ(result, ""); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointEnableEncrypt) +{ + // shmSecrets create + NetSecrets shmSecrets{}; + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointEnableEncrypt", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + UBSHcomNetDriverOptions options; + options.cipherSuite = AES_GCM_256; + ep->EnableEncrypt(options); + EXPECT_EQ(ep->mAes.mCipherSuite, AES_GCM_256); + + ep->SetSecrets(shmSecrets); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointEstimatedEncryptLen) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointEstimatedEncryptLen", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen) + .stubs() + .will(returnValue(1)); + ret = ep->EstimatedEncryptLen(0); + EXPECT_EQ(ret, 0); + + ret = ep->EstimatedEncryptLen(1); + EXPECT_EQ(ret, 0); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointEstimatedEncryptLenTwo) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointEstimatedEncryptLen2", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen) + .stubs() + .will(returnValue(1)); + + ep->mIsNeedEncrypt = true; + ret = ep->EstimatedEncryptLen(1); + EXPECT_EQ(ret, 1); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointEncrypt) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointEncrypt", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + // Encrypt data + uint8_t encryptData = 0; + uint8_t *cipher = reinterpret_cast(&index); + uint64_t cipherLen; + + ret = ep->Encrypt(&encryptData, 1, cipher, cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointEncryptTwo) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointEncrypt2", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + // Encrypt data + uint8_t encryptData = 0; + uint8_t *cipher = reinterpret_cast(&index); + uint64_t cipherLen = 0; + + MOCKER_CPP(&AesGcm128::Encrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)) + .then(returnValue(true)); + + ep->mIsNeedEncrypt = true; + ret = ep->Encrypt(&encryptData, 1, cipher, cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + ret = ep->Encrypt(&encryptData, 1, cipher, cipherLen); + EXPECT_EQ(ret, static_cast(NN_OK)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointEstimatedDecryptLen) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointEstimatedDecryptLen", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + MOCKER_CPP(&AesGcm128::GetRawLen) + .stubs() + .will(returnValue(1)); + ret = ep->EstimatedDecryptLen(1); + EXPECT_EQ(ret, 0); + + ep->mIsNeedEncrypt = true; + ret = ep->EstimatedDecryptLen(1); + EXPECT_EQ(ret, 1); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointDecrypt) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointDecrypt", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + // Encrypt data + uint8_t encryptData = 0; + uint8_t *cipher = reinterpret_cast(&index); + uint64_t cipherLen; + + ret = ep->Decrypt(&encryptData, 1, cipher, cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointDecryptTwo) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointDecrypt2", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + // Encrypt data + uint8_t encryptData = 0; + uint8_t *cipher = reinterpret_cast(&index); + uint64_t cipherLen = 0; + + MOCKER_CPP(&AesGcm128::Decrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)) + .then(returnValue(true)); + + ep->mIsNeedEncrypt = true; + ret = ep->Decrypt(&encryptData, 1, cipher, cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + ret = ep->Decrypt(&encryptData, 1, cipher, cipherLen); + EXPECT_EQ(ret, static_cast(NN_OK)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointSendFds) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointSendFds", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, 0, index, shmEp, map); + // SendFds data + int fds[1] = {1}; + + ret = ep->SendFds(fds, 0); + EXPECT_EQ(ret, static_cast(NN_PARAM_INVALID)); + + ret = ep->SendFds(fds, 1); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointReceiveFds) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ret = ShmChannel::CreateAndInit("NetSyncEndpointReceiveFds", 0, NN_NO128, NN_NO4, ch); + if (NN_UNLIKELY(ret != NN_OK)) { + return; + } + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointReceiveFds", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_NEW); + // ReceiveFds data + int fds[1] = {1}; + + MOCKER_CPP(&ShmChannel::RemoveUserFds) + .stubs() + .will(returnValue(0)); + MOCKER_CPP(&ShmChannel::Close).stubs(); + + ret = ep->ReceiveFds(fds, 0, 0); + EXPECT_EQ(ret, static_cast(NN_PARAM_INVALID)); + + ret = ep->ReceiveFds(fds, 1, 0); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + ep->Close(); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointReceiveFdsTwo) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ret = ShmChannel::CreateAndInit("NetSyncEndpointReceiveFds2", 0, NN_NO128, NN_NO4, ch); + if (NN_UNLIKELY(ret != NN_OK)) { + return; + } + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointReceiveFds2", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_NEW); + // ReceiveFds data + int fds[1] = {1}; + + MOCKER_CPP(&ShmChannel::RemoveUserFds) + .stubs() + .will(returnValue(0)); + MOCKER_CPP(&ShmChannel::Close).stubs(); + + ep->mState.Set(NEP_ESTABLISHED); + ret = ep->ReceiveFds(fds, 1, 0); + EXPECT_EQ(ret, static_cast(NN_OK)); + ep->Close(); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointGetRemoteUdsIdInfo) +{ + int ret; + // driver create + NetDriverShmWithOOB *driver = new NetDriverShmWithOOB("NetSyncEndpointGetRemoteUdsIdInfo", false, SHM); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointGetRemoteUdsIdInfo", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, driver, index, shmEp, map); + ep->mState.Set(NEP_NEW); + UBSHcomNetUdsIdInfo idInfo{}; + + ret = ep->GetRemoteUdsIdInfo(idInfo); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + + ep->mState.Set(NEP_ESTABLISHED); + driver->mStartOobSvr = false; + ret = ep->GetRemoteUdsIdInfo(idInfo); + EXPECT_EQ(ret, static_cast(NN_UDS_ID_INFO_NOT_SUPPORT)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointGetRemoteUdsIdInfoTwo) +{ + int ret; + // driver create + NetDriverShmWithOOB *driver = new NetDriverShmWithOOB("NetSyncEndpointGetRemoteUdsIdInfo2", false, SHM); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointGetRemoteUdsIdInfo2", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, driver, index, shmEp, map); + ep->mState.Set(NEP_NEW); + UBSHcomNetUdsIdInfo idInfo{}; + + ep->mState.Set(NEP_ESTABLISHED); + driver->mStartOobSvr = true; + ret = ep->GetRemoteUdsIdInfo(idInfo); + EXPECT_EQ(ret, static_cast(NN_OK)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointGetPeerIpPort) +{ + bool ret; + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointGetPeerIpPort", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(0, 0, 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + std::string ip("127.0.0.1"); + uint16_t port = 1234; + + ret = ep->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointReceive) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointReceive", 0, NN_NO128, NN_NO4, ch); + // peerChannel create + ShmChannelPtr peerCh; + ShmChannel::CreateAndInit("NetSyncEndpointReceive", 0, NN_NO128, NN_NO4, peerCh); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointReceive", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + UBSHcomNetResponseContext ctx{}; + int32_t timeout = 0; + + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)); + MOCKER_CPP(&ShmSyncEndpoint::Receive) + .stubs() + .will(returnValue(1)); + + ep->mExistDelayEvent = true; + ep->mDelayHandleReceiveEvent.peerChannelAddress = 0; + ret = ep->Receive(0, ctx); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + ep->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(peerCh.Get()); + ep->mDelayHandleReceiveEvent.dataOffset = 0; + ret = ep->Receive(0, ctx); + EXPECT_EQ(ret, 1); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointWaitCompletion) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointWaitCompletion", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointWaitCompletion", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + MOCKER_CPP(&ShmSyncEndpoint::DequeueEvent) + .stubs() + .will(returnValue(1)); + + ret = ep->WaitCompletion(0); + EXPECT_EQ(ret, 1); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointWaitCompletionTwo) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointWaitCompletion2", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointWaitCompletion2", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + MOCKER_CPP(&ShmSyncEndpoint::DequeueEvent) + .stubs() + .will(invoke(MockDequeueEvent)); + + mockData[0] = static_cast(ShmOpContextInfo::SH_RECEIVE); + ep->mExistDelayEvent = true; + ret = ep->WaitCompletion(0); + EXPECT_EQ(ret, static_cast(SH_ERROR)); + + mockData[0] = static_cast(ShmOpContextInfo::SH_SEND); + ret = ep->WaitCompletion(0); + EXPECT_EQ(ret, 0); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointWaitCompletionThree) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointWaitCompletion3", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointWaitCompletion3", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + MOCKER_CPP(&ShmSyncEndpoint::DequeueEvent) + .stubs() + .will(invoke(MockDequeueEvent)); + + ep->mExistDelayEvent = true; + mockData[0] = static_cast(ShmOpContextInfo::SH_SGL_WRITE); + ret = ep->WaitCompletion(0); + EXPECT_EQ(ret, 0); + + mockData[0] = static_cast(ShmOpContextInfo::SH_RECEIVE_RAW); + ret = ep->WaitCompletion(0); + EXPECT_EQ(ret, static_cast(SH_ERROR)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointShmPostSendRawSgl) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointShmPostSendRawSgl", 0, NN_NO128, NN_NO4, ch); + // mWorker create + UBSHcomNetWorkerIndex indexWorker; + ShmWorkerOptions options{}; + NetMemPoolFixedPtr opMemPool; + NetMemPoolFixedPtr opCtxMemPool; + NetMemPoolFixedPtr sglOpMemPool; + ShmWorker *mWorker = new ShmWorker("NetAsyncEndpointShmPostSendRawSgl", indexWorker, options, opMemPool, + opCtxMemPool, sglOpMemPool); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new NetAsyncEndpointShm(ch->Id(), ch.Get(), mWorker, 0, index, map); + ep->mState.Set(NEP_ESTABLISHED); + // request create + UBSHcomNetTransSglRequest request{}; + UBSHcomNetTransSgeIov iov; + uint32_t data; + request.iovCount = 1; + request.iov = &iov; + iov.size = 1; + iov.lAddress = reinterpret_cast(&data); + ep->mSegSize = NN_NO128; + + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(false)) + .then(returnValue(true)); + + MOCKER_CPP(&AesGcm128::Encrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + MOCKER_CPP(&MemoryRegionChecker::Validate) + .stubs() + .will(returnValue(0)); + + ep->mIsNeedEncrypt = 1; + ret = ep->PostSendRawSgl(request, 1); + EXPECT_EQ(ret, static_cast(NN_MALLOC_FAILED)); + + ret = ep->PostSendRawSgl(request, 1); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointShmPostSendRawSglTwo) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointShmPostSendRawSgl2", 0, NN_NO128, NN_NO4, ch); + // mWorker create + UBSHcomNetWorkerIndex indexWorker; + ShmWorkerOptions options{}; + NetMemPoolFixedPtr opMemPool; + NetMemPoolFixedPtr opCtxMemPool; + NetMemPoolFixedPtr sglOpMemPool; + ShmWorker *mWorker = new ShmWorker("NetAsyncEndpointShmPostSendRawSgl2", indexWorker, options, opMemPool, + opCtxMemPool, sglOpMemPool); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new NetAsyncEndpointShm(ch->Id(), ch.Get(), mWorker, 0, index, map); + ep->mState.Set(NEP_ESTABLISHED); + // request create + UBSHcomNetTransSglRequest request{}; + UBSHcomNetTransSgeIov iov; + uint32_t data; + request.iovCount = 1; + request.iov = &iov; + iov.size = 1; + iov.lAddress = reinterpret_cast(&data); + ep->mSegSize = NN_NO128; + + MOCKER_CPP(&MemoryRegionChecker::Validate) + .stubs() + .will(returnValue(0)); + MOCKER_CPP(&ShmWorker::PostSendRawSgl) + .stubs() + .will(returnValue(static_cast(SH_SEND_COMPLETION_CALLBACK_FAILURE))) + .then(returnValue(static_cast(SH_PEER_FD_ERROR))); + + ep->mIsNeedEncrypt = 0; + ret = ep->PostSendRawSgl(request, 1); + EXPECT_EQ(ret, static_cast(SH_SEND_COMPLETION_CALLBACK_FAILURE)); + + ret = ep->PostSendRawSgl(request, 1); + EXPECT_EQ(ret, static_cast(SH_PEER_FD_ERROR)); + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetAsyncEndpointShmPostReadTwo) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetAsyncEndpointShmPostRead2", 0, NN_NO128, NN_NO4, ch); + /// driver create + NetDriverShmWithOOB *driver = new NetDriverShmWithOOB("NetAsyncEndpointShmPostRead2", false, SHM); + // mWorker create + UBSHcomNetWorkerIndex indexWorker; + ShmWorkerOptions options{}; + NetMemPoolFixedPtr opMemPool; + NetMemPoolFixedPtr opCtxMemPool; + NetMemPoolFixedPtr sglOpMemPool; + ShmWorker *mWorker = new ShmWorker("NetAsyncEndpointShmPostRead2", indexWorker, options, opMemPool, + opCtxMemPool, sglOpMemPool); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetAsyncEndpointShm *ep = new NetAsyncEndpointShm(ch->Id(), ch.Get(), mWorker, driver, index, map); + ep->mState.Set(NEP_ESTABLISHED); + // request create + UBSHcomNetTransRequest request; + uint32_t data; + request.lAddress = reinterpret_cast(&data); + request.size = 1; + ep->mAllowedSize = NN_NO128; + + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion) + .stubs() + .will(returnValue(0)); + MOCKER_CPP(&ShmWorker::PostRead) + .stubs() + .will(returnValue(static_cast(SH_OP_CTX_FULL))) + .then(returnValue(static_cast(SH_SEND_COMPLETION_CALLBACK_FAILURE))); + + ep->mIsNeedEncrypt = 0; + ep->mDefaultTimeout = 1; + ret = ep->PostRead(request); + EXPECT_EQ(ret, static_cast(SH_SEND_COMPLETION_CALLBACK_FAILURE)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, SyncReceiveFailWithErrorOpType) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("SyncReceiveFailWithErrorOpType", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("SyncReceiveFailWithErrorOpType", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + ep->mExistDelayEvent = true; + ep->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset) + .stubs() + .will(returnValue(1)) + .then(invoke(MockGetPeerDataAddressByOffset)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree) + .stubs() + .will(returnValue(0)); + + ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, 1); + + ep->mExistDelayEvent = true; + ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, SyncReceiveFailWithOverDataSize) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("SyncReceiveFailWithOverDataSize", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("SyncReceiveFailWithOverDataSize", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + // param init + // dataLength is over NET_SGE_MAX_SIZE + mockReq.dataLength = NET_SGE_MAX_SIZE + NN_NO1; + int32_t timeout = 0; + UBSHcomNetResponseContext ctx {}; + ep->mExistDelayEvent = true; + ep->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + ep->mDelayHandleReceiveEvent.opType = ShmOpContextInfo::SH_RECEIVE; + + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset) + .stubs() + .will(returnValue(1)) + .then(invoke(MockGetPeerDataAddressByOffset)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree).stubs().will(returnValue(0)); + + ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, 1); + + ep->mExistDelayEvent = true; + ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + + delete (ep); +} + +TEST_F(TestNetShmEndpoint, SyncReceiveFailWithErrorSeqNo) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("SyncReceiveFailWithErrorSeqNo", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("SyncReceiveFailWithErrorSeqNo", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + // param init + mockReq.seqNo = 1; + mockReq.dataLength = NN_NO1024 - sizeof(UBSHcomNetTransHeader); + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + ep->mExistDelayEvent = true; + ep->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + ep->mDelayHandleReceiveEvent.dataSize = NN_NO1024; + + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset) + .stubs() + .will(invoke(MockGetPeerDataAddressByOffset)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree) + .stubs() + .will(returnValue(0)); + MOCKER(NetFunc::ValidateHeaderCrc32, bool(UBSHcomNetTransHeader *)) + .stubs() + .will(returnValue(false)); + + ep->mDelayHandleReceiveEvent.opType = ShmOpContextInfo::SH_RECEIVE; + ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, static_cast(NN_SEQ_NO_NOT_MATCHED)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, SyncReceiveFailWithErrDataLen) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("SyncReceiveFailWithErrDataLen", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("SyncReceiveFailWithErrDataLen", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + + mockReq.seqNo = 0; + mockReq.dataLength = NN_NO2048; + int32_t timeout = 0; + UBSHcomNetResponseContext ctx {}; + ep->mExistDelayEvent = true; + ep->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + // data length in header is not equal to dataSize in event + ep->mDelayHandleReceiveEvent.dataSize = NN_NO1024; + + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset).stubs().will(invoke(MockGetPeerDataAddressByOffset)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree).stubs().will(returnValue(0)); + MOCKER(NetFunc::ValidateHeaderCrc32, bool(UBSHcomNetTransHeader *)).stubs().will(returnValue(false)); + + ep->mDelayHandleReceiveEvent.opType = ShmOpContextInfo::SH_RECEIVE; + ep->mExistDelayEvent = true; + + ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + + delete (ep); +} + +TEST_F(TestNetShmEndpoint, SyncReceiveFailWithInvalidHeader) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("SyncReceiveFailWithInvalidHeader", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("SyncReceiveFailWithInvalidHeader", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + // param init + mockReq.seqNo = 0; + mockReq.dataLength = NN_NO1024 - sizeof(UBSHcomNetTransHeader); + int32_t timeout = 0; + UBSHcomNetResponseContext ctx {}; + ep->mExistDelayEvent = true; + ep->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + ep->mDelayHandleReceiveEvent.opType = ShmOpContextInfo::SH_RECEIVE; + ep->mDelayHandleReceiveEvent.dataSize = NN_NO1024; + + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset).stubs().will(invoke(MockGetPeerDataAddressByOffset)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree).stubs().will(returnValue(0)); + MOCKER(NetFunc::ValidateHeaderCrc32, bool(UBSHcomNetTransHeader *)).stubs().will(returnValue(false)); + + ep->mExistDelayEvent = true; + ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, static_cast(NN_VALIDATE_HEADER_CRC_INVALID)); + + delete (ep); +} + +TEST_F(TestNetShmEndpoint, SyncReceiveFailToAllocate) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("SyncReceiveFailToAllocate", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("SyncReceiveFailToAllocate", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + // param init + mockReq.seqNo = 0; + mockReq.dataLength = NN_NO1024 - sizeof(UBSHcomNetTransHeader); + mockReq.opCode = 0; + mockReq.headerCrc = NetFunc::CalcHeaderCrc32(mockReq); + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + ep->mExistDelayEvent = true; + ep->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + ep->mDelayHandleReceiveEvent.opType = ShmOpContextInfo::SH_RECEIVE; + ep->mDelayHandleReceiveEvent.dataSize = NN_NO1024; + + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset) + .stubs() + .will(invoke(MockGetPeerDataAddressByOffset)); + MOCKER(NetFunc::ValidateHeaderCrc32, bool(UBSHcomNetTransHeader *)) + .stubs() + .will(returnValue(true)); + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(false)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree) + .stubs() + .will(returnValue(0)); + + ep->mExistDelayEvent = true; + ep->mIsNeedEncrypt = true; + ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, static_cast(NN_MALLOC_FAILED)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointShmReceiveFour) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointShmReceive4", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointShmReceive4", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + // param init + mockReq.seqNo = 0; + mockReq.dataLength = NN_NO1024 - sizeof(UBSHcomNetTransHeader); + mockReq.opCode = 0; + mockReq.headerCrc = NetFunc::CalcHeaderCrc32(mockReq); + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + ep->mIsNeedEncrypt = true; + ep->mExistDelayEvent = true; + ep->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + ep->mDelayHandleReceiveEvent.opType = ShmOpContextInfo::SH_RECEIVE; + ep->mDelayHandleReceiveEvent.dataSize = NN_NO1024; + + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset) + .stubs() + .will(invoke(MockGetPeerDataAddressByOffset)); + MOCKER(NetFunc::ValidateHeaderCrc32, bool(UBSHcomNetTransHeader *)) + .stubs() + .will(returnValue(true)); + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(true)); + MOCKER_CPP(&AesGcm128::Decrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)) + .then(returnValue(true)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree) + .stubs() + .will(returnValue(0)); + + ep->mExistDelayEvent = true; + ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, static_cast(NN_DECRYPT_FAILED)); + + ep->mExistDelayEvent = true; + ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, 0); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, NetSyncEndpointShmReceiveFailFive) +{ + int ret; + // mShmCh create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetSyncEndpointShmReceive4", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpoint *shmEp = new ShmSyncEndpoint("NetSyncEndpointShmReceive4", 0, SHM_EVENT_POLLING); + // shmEp create + UBSHcomNetWorkerIndex index; + ShmMRHandleMap map; + NetSyncEndpointShm *ep = new NetSyncEndpointShm(ch->Id(), ch.Get(), 0, index, shmEp, map); + ep->mState.Set(NEP_ESTABLISHED); + // param init + mockReq.seqNo = 0; + mockReq.dataLength = NN_NO1024 - sizeof(UBSHcomNetTransHeader); + mockReq.opCode = 0; + mockReq.headerCrc = NetFunc::CalcHeaderCrc32(mockReq); + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + ep->mIsNeedEncrypt = false; + ep->mExistDelayEvent = true; + ep->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + ep->mDelayHandleReceiveEvent.opType = ShmOpContextInfo::SH_RECEIVE; + ep->mDelayHandleReceiveEvent.dataSize = NN_NO1024; + + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset) + .stubs() + .will(invoke(MockGetPeerDataAddressByOffset)); + MOCKER(NetFunc::ValidateHeaderCrc32, bool(UBSHcomNetTransHeader *)) + .stubs() + .will(returnValue(true)); + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(false)) + .then(returnValue(true)); + MOCKER_CPP(&memcpy_s) + .stubs() + .will(returnValue(0)) + .then(returnValue(1)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree) + .stubs() + .will(returnValue(0)); + + ep->mExistDelayEvent = true; + ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, static_cast(NN_MALLOC_FAILED)); + + ep->mExistDelayEvent = true; + ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + + ep->mExistDelayEvent = true; + ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + + delete(ep); +} + +TEST_F(TestNetShmEndpoint, ShmWorkerInitializeFail) +{ + UBSHcomNetWorkerIndex indexWorker; + ShmWorkerOptions options{}; + NetMemPoolFixedPtr opMemPool; + NetMemPoolFixedPtr opCtxMemPool; + NetMemPoolFixedPtr sglOpMemPool; + ShmWorker *worker = new (std::nothrow) ShmWorker("shm", indexWorker, options, opMemPool, + opCtxMemPool, sglOpMemPool); + + MOCKER_CPP(&ShmWorker::Validate).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&ShmWorker::CreateEventQueue).stubs().will(returnValue(1)); + + EXPECT_NE(worker->Initialize(), 0); + EXPECT_NE(worker->Initialize(), 0); + + delete worker; +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/shm/test_shm_async_endpoint.cpp b/test/unit_test/transport/shm/test_shm_async_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2e0a05883940cef7e8b18e752131bf4f462f2702 --- /dev/null +++ b/test/unit_test/transport/shm/test_shm_async_endpoint.cpp @@ -0,0 +1,553 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include + +#include "securec.h" +#include "hcom_def.h" +#include "hcom_log.h" +#include "net_shm_sync_endpoint.h" +#include "net_shm_async_endpoint.h" +#include "shm_common.h" +#include "shm_validation.h" + + +namespace ock { +namespace hcom { + +constexpr uint32_t ASYNC_EP_SHM_ALLOWD_SIZE = 256; +constexpr uint32_t REQUEST_SIZE = 128; + +class TestShmAsyncEndpoint : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + static void SetUpTestSuite() {} + static void TearDownTestSuite() {} + + NetAsyncEndpointShm* mShmAsyncEp = nullptr; + UBSHcomNetTransRequest mReq; +}; + +static HResult MockDCGetFreeBuck(uintptr_t &address, uint64_t &offsetToBase, + uint16_t waitPeriodUs = NN_NO100, int32_t timeoutSecond = -1) +{ + static UBSHcomNetTransHeader mockAsyncBuf{}; + address = reinterpret_cast(&mockAsyncBuf); + offsetToBase = 0; + return SH_OK; +} + +void TestShmAsyncEndpoint::SetUp() +{ + // create and configure NetSyncEndpointShm object + UBSHcomNetWorkerIndex workerId; + workerId.wholeIdx = 0; + ShmMRHandleMap tmpShmMRHandleMap; + + ShmChannelPtr shmCh = new ShmChannel("TestShmAsyncEndpoint", 0, 0, 0); + if (shmCh == nullptr) { + NN_LOG_ERROR("new ShmChannel failed"); + return; + } + + mShmAsyncEp = new (std::nothrow) NetAsyncEndpointShm(0, shmCh.Get(), nullptr, nullptr, workerId, tmpShmMRHandleMap); + if (mShmAsyncEp == nullptr) { + NN_LOG_ERROR("new NetSyncEndpointShm failed"); + return; + } + + mShmAsyncEp->mState.Set(NEP_ESTABLISHED); + mShmAsyncEp->mAllowedSize = ASYNC_EP_SHM_ALLOWD_SIZE; + mShmAsyncEp->mIsNeedEncrypt = false; + + // create and config req + static char buffer[REQUEST_SIZE]; + auto ret = memset_s(buffer, REQUEST_SIZE, '\0', REQUEST_SIZE); + ASSERT_EQ(ret, 0); + ret = memset_s(&mReq, sizeof(mReq), '\0', sizeof(mReq)); + ASSERT_EQ(ret, 0); + mReq.lAddress = reinterpret_cast(buffer); + mReq.size = REQUEST_SIZE; +} + +void TestShmAsyncEndpoint::TearDown() +{ + if (mShmAsyncEp != nullptr) { + delete mShmAsyncEp; + mShmAsyncEp = nullptr; + } + GlobalMockObject::verify(); +} + +TEST_F(TestShmAsyncEndpoint, PostSendFailWhenValidateStateFail) +{ + mShmAsyncEp->mState.Set(NEP_BROKEN); + NResult ret = mShmAsyncEp->PostSend(0, mReq, 0); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmAsyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmAsyncEndpoint, PostSendFailWhenValidateSizeFail) +{ + mShmAsyncEp->mAllowedSize = NN_NO1; + NResult ret = mShmAsyncEp->PostSend(0, mReq, 0); + EXPECT_EQ(ret, NN_TWO_SIDE_MESSAGE_TOO_LARGE); + mShmAsyncEp->mAllowedSize = ASYNC_EP_SHM_ALLOWD_SIZE; +} + +TEST_F(TestShmAsyncEndpoint, PostSendFailWhenGetFreeBuckFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck) + .stubs() + .will(returnValue(static_cast(SH_NOT_INITIALIZED))); + + NResult ret = mShmAsyncEp->PostSend(0, mReq, 0); + EXPECT_EQ(ret, SH_NOT_INITIALIZED); +} + +TEST_F(TestShmAsyncEndpoint, PostSendFailWhenMemcpyFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + + mShmAsyncEp->mIsNeedEncrypt = false; + NResult ret = mShmAsyncEp->PostSend(0, mReq, 0); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestShmAsyncEndpoint, PostSendOpInfoFailWhenValidateStateFail) +{ + mShmAsyncEp->mState.Set(NEP_BROKEN); + UBSHcomNetTransOpInfo opInfo{}; + NResult ret = mShmAsyncEp->PostSend(0, mReq, opInfo); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmAsyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmAsyncEndpoint, PostSendOpInfoFailWhenValidateSizeFail) +{ + mShmAsyncEp->mAllowedSize = NN_NO1; + UBSHcomNetTransOpInfo opInfo{}; + NResult ret = mShmAsyncEp->PostSend(0, mReq, opInfo); + EXPECT_EQ(ret, NN_TWO_SIDE_MESSAGE_TOO_LARGE); + mShmAsyncEp->mAllowedSize = ASYNC_EP_SHM_ALLOWD_SIZE; +} + +TEST_F(TestShmAsyncEndpoint, PostSendOpInfoFailWhenGetFreeBuckFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck) + .stubs() + .will(returnValue(static_cast(SH_NOT_INITIALIZED))); + + UBSHcomNetTransOpInfo opInfo{}; + NResult ret = mShmAsyncEp->PostSend(0, mReq, opInfo); + EXPECT_EQ(ret, SH_NOT_INITIALIZED); +} + +TEST_F(TestShmAsyncEndpoint, PostSendOpInfoFailWhenEncryptFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen) + .stubs() + .will(returnValue(static_cast(0))); + + MOCKER_CPP(&AesGcm128::Encrypt, + bool (AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + mShmAsyncEp->mIsNeedEncrypt = true; + UBSHcomNetTransOpInfo opInfo{}; + NResult ret = mShmAsyncEp->PostSend(0, mReq, opInfo); + EXPECT_EQ(ret, NN_ENCRYPT_FAILED); +} + +TEST_F(TestShmAsyncEndpoint, PostSendOpInfoFailWhenMemcpyFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + + mShmAsyncEp->mIsNeedEncrypt = false; + UBSHcomNetTransOpInfo opInfo{}; + NResult ret = mShmAsyncEp->PostSend(0, mReq, opInfo); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestShmAsyncEndpoint, PostSendOpInfoFailWhenSendFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + MOCKER_CPP(&ShmWorker::PostSend, + HResult (ShmWorker::*)(ShmChannel *, const UBSHcomNetTransRequest&, uint64_t, uint32_t, int32_t)) + .stubs() + .will(returnValue(static_cast(SH_OP_CTX_FULL))) + .then(returnValue(static_cast(SH_SEND_COMPLETION_CALLBACK_FAILURE))) + .then(returnValue(1)); + + mShmAsyncEp->mIsNeedEncrypt = false; + UBSHcomNetTransOpInfo opInfo{}; + NResult ret = mShmAsyncEp->PostSend(0, mReq, opInfo); + EXPECT_EQ(ret, SH_SEND_COMPLETION_CALLBACK_FAILURE); + + ret = mShmAsyncEp->PostSend(0, mReq, opInfo); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestShmAsyncEndpoint, PostSendRawFailWhenValidateStateFail) +{ + mShmAsyncEp->mState.Set(NEP_BROKEN); + NResult ret = mShmAsyncEp->PostSendRaw(mReq, 0); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmAsyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmAsyncEndpoint, PostSendRawFailWhenValidateSizeFail) +{ + mShmAsyncEp->mSegSize = NN_NO1; + NResult ret = mShmAsyncEp->PostSendRaw(mReq, 0); + EXPECT_EQ(ret, NN_TWO_SIDE_MESSAGE_TOO_LARGE); + mShmAsyncEp->mSegSize = ASYNC_EP_SHM_ALLOWD_SIZE; +} + +TEST_F(TestShmAsyncEndpoint, PostSendRawFailWhenGetFreeBuckFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck) + .stubs() + .will(returnValue(static_cast(SH_NOT_INITIALIZED))); + + mShmAsyncEp->mSegSize = ASYNC_EP_SHM_ALLOWD_SIZE; + NResult ret = mShmAsyncEp->PostSendRaw(mReq, 0); + EXPECT_EQ(ret, SH_NOT_INITIALIZED); +} + +TEST_F(TestShmAsyncEndpoint, PostSendRawFailWhenEncryptFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen) + .stubs() + .will(returnValue(static_cast(0))); + + MOCKER_CPP(&AesGcm128::Encrypt, + bool (AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + mShmAsyncEp->mIsNeedEncrypt = true; + mShmAsyncEp->mSegSize = ASYNC_EP_SHM_ALLOWD_SIZE; + NResult ret = mShmAsyncEp->PostSendRaw(mReq, 0); + EXPECT_EQ(ret, NN_ENCRYPT_FAILED); +} + +TEST_F(TestShmAsyncEndpoint, PostSendRawFailWhenMemcpyFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + + mShmAsyncEp->mIsNeedEncrypt = false; + mShmAsyncEp->mSegSize = ASYNC_EP_SHM_ALLOWD_SIZE; + NResult ret = mShmAsyncEp->PostSendRaw(mReq, 0); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestShmAsyncEndpoint, PostSendRawFailWhenSendFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + MOCKER_CPP(&ShmWorker::PostSend, + HResult (ShmWorker::*)(ShmChannel *, const UBSHcomNetTransRequest&, uint64_t, uint32_t, int32_t)) + .stubs() + .will(returnValue(static_cast(SH_OP_CTX_FULL))) + .then(returnValue(static_cast(SH_SEND_COMPLETION_CALLBACK_FAILURE))) + .then(returnValue(1)); + + mShmAsyncEp->mIsNeedEncrypt = false; + mShmAsyncEp->mSegSize = ASYNC_EP_SHM_ALLOWD_SIZE; + NResult ret = mShmAsyncEp->PostSendRaw(mReq, 0); + EXPECT_EQ(ret, SH_SEND_COMPLETION_CALLBACK_FAILURE); + + ret = mShmAsyncEp->PostSendRaw(mReq, 0); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestShmAsyncEndpoint, PostSendRawSglFailWhenValidateFail) +{ + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + mShmAsyncEp->mState.Set(NEP_BROKEN); + NResult ret = mShmAsyncEp->PostSendRawSgl(sglReq, 1); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmAsyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmAsyncEndpoint, PostSendRawSglFailWhenGetFreeBuckFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck) + .stubs() + .will(returnValue(static_cast(SH_NOT_INITIALIZED))); + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + NResult ret = mShmAsyncEp->PostSendRawSgl(sglReq, 1); + EXPECT_EQ(ret, SH_NOT_INITIALIZED); +} + +TEST_F(TestShmAsyncEndpoint, PostSendRawSglFailWhenEncryptFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(true)); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + MOCKER_CPP(&AesGcm128::Encrypt, + bool (AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + mShmAsyncEp->mIsNeedEncrypt = true; + NResult ret = mShmAsyncEp->PostSendRawSgl(sglReq, 1); + EXPECT_EQ(ret, NN_ENCRYPT_FAILED); +} + +TEST_F(TestShmAsyncEndpoint, PostSendRawSglFailWhenMemcpyFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + mShmAsyncEp->mIsNeedEncrypt = false; + NResult ret = mShmAsyncEp->PostSendRawSgl(sglReq, 1); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestShmAsyncEndpoint, PostReadValidateFail) +{ + mShmAsyncEp->mState.Set(NEP_BROKEN); + NResult ret = mShmAsyncEp->PostRead(mReq); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmAsyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmAsyncEndpoint, PostReadSglValidateFail) +{ + mShmAsyncEp->mState.Set(NEP_BROKEN); + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + NResult ret = mShmAsyncEp->PostRead(sglReq); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmAsyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmAsyncEndpoint, PostReadSglFail) +{ + MOCKER_CPP(&ShmWorker::PostReadSgl) + .stubs() + .will(returnValue(static_cast(SH_OP_CTX_FULL))) + .then(returnValue(1)); + + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("PostReadSgl", false, SHM); + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + mShmAsyncEp->mDriver = driver; + NResult ret = mShmAsyncEp->PostRead(sglReq); + EXPECT_EQ(ret, 1); + + mShmAsyncEp->mDriver = nullptr; + if (driver != nullptr) { + delete driver; + driver = nullptr; + } +} + +TEST_F(TestShmAsyncEndpoint, PostWriteValidateFail) +{ + mShmAsyncEp->mState.Set(NEP_BROKEN); + NResult ret = mShmAsyncEp->PostWrite(mReq); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmAsyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmAsyncEndpoint, PostWriteFail) +{ + MOCKER_CPP(&ShmWorker::PostWrite) + .stubs() + .will(returnValue(static_cast(SH_OP_CTX_FULL))) + .then(returnValue(1)); + + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("PostWrite", false, SHM); + mShmAsyncEp->mDriver = driver; + NResult ret = mShmAsyncEp->PostWrite(mReq); + EXPECT_EQ(ret, 1); + + mShmAsyncEp->mDriver = nullptr; + if (driver != nullptr) { + delete driver; + driver = nullptr; + } +} + +TEST_F(TestShmAsyncEndpoint, PostWriteSglValidateFail) +{ + mShmAsyncEp->mState.Set(NEP_BROKEN); + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + NResult ret = mShmAsyncEp->PostWrite(sglReq); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmAsyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmAsyncEndpoint, PostWriteSglFail) +{ + MOCKER_CPP(&ShmWorker::PostWriteSgl) + .stubs() + .will(returnValue(static_cast(SH_OP_CTX_FULL))) + .then(returnValue(1)); + + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + + NetDriverShmWithOOB *driver = new (std::nothrow) NetDriverShmWithOOB("PostWriteSgl", false, SHM); + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + mShmAsyncEp->mDriver = driver; + NResult ret = mShmAsyncEp->PostWrite(sglReq); + EXPECT_EQ(ret, 1); + + mShmAsyncEp->mDriver = nullptr; + if (driver != nullptr) { + delete driver; + driver = nullptr; + } +} + +TEST_F(TestShmAsyncEndpoint, SendFdsFail) +{ + int fds1[NN_NO4] = {1, 2, 3, 4}; + int fds2[NN_NO4] = {0}; + uint32_t len = NN_NO4; + MOCKER_CPP(::send).stubs().will(returnValue(0)); + EXPECT_EQ(mShmAsyncEp->SendFds(fds1, len), NN_ERROR); + EXPECT_EQ(mShmAsyncEp->SendFds(fds2, len), NN_INVALID_PARAM); +} + +TEST_F(TestShmAsyncEndpoint, DCMarkPeerBuckFree) +{ + ShmChannelPtr shmCh = new ShmChannel("TestShmAsyncEndpoint", 0, 0, 0); + EXPECT_NO_FATAL_FAILURE(shmCh->DCMarkPeerBuckFree(0)); + EXPECT_NO_FATAL_FAILURE(shmCh->DCMarkBuckFree(0)); +} + +TEST_F(TestShmAsyncEndpoint, PostSendValidationOpCodeFail) +{ + UBSHcomNetAtomicState state{NEP_ESTABLISHED}; + uint16_t opCode = MAX_OPCODE; + UBSHcomNetTransRequest req{}; + EXPECT_EQ(PostSendValidation(state, 0, opCode, req), NN_INVALID_OPCODE); +} + +TEST_F(TestShmAsyncEndpoint, PostSendValidationSizeFail) +{ + UBSHcomNetAtomicState state{NEP_ESTABLISHED}; + UBSHcomNetTransRequest req{}; + req.size = 0; + EXPECT_EQ(PostSendValidation(state, 0, 0, req), NN_INVALID_PARAM); +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/shm/test_shm_channel.cpp b/test/unit_test/transport/shm/test_shm_channel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..202e11db17aa9524816910ebdaf5804ab3415dcb --- /dev/null +++ b/test/unit_test/transport/shm/test_shm_channel.cpp @@ -0,0 +1,222 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include "shm_channel.h" + +namespace ock { +namespace hcom { + +class TestShmChannel : public testing::Test { +public: + TestShmChannel(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +TestShmChannel::TestShmChannel() {} + +void TestShmChannel::SetUp() +{ +} + +void TestShmChannel::TearDown() +{ + GlobalMockObject::verify(); +} + +ShmChannel *MockShmChannelGet() +{ + return nullptr; +} +TEST_F(TestShmChannel, ShmChannelCreateAndInit) +{ + int ret; + ShmChannelPtr ch; + + MOCKER_CPP(&ShmChannelPtr::Get) + .stubs() + .will(invoke(MockShmChannelGet)); + + ret = ShmChannel::CreateAndInit("ShmChannelCreateAndInit", 0, NN_NO128, NN_NO4, ch); + EXPECT_EQ(ret, static_cast(SH_NEW_OBJECT_FAILED)); +} + +TEST_F(TestShmChannel, ShmChannelCreateAndInitTwo) +{ + int ret; + ShmChannelPtr ch; + + MOCKER_CPP(&ShmChannel::Initialize) + .stubs() + .will(returnValue(1)); + + ret = ShmChannel::CreateAndInit("ShmChannelCreateAndInit2", 0, NN_NO128, NN_NO4, ch); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestShmChannel, ShmChannelClose) +{ + int ret; + ShmChannelPtr ch; + ret = ShmChannel::CreateAndInit("ShmChannelClose", 0, NN_NO128, NN_NO4, ch); + + ch->mFd = -1; + ch->Close(); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestShmChannel, ShmChannelAddMrFd) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("ShmChannelAddMrFd", 0, NN_NO128, NN_NO4, ch); + + ShmChannel::gQueueSizeCap = 0; + ch->mFd = -1; + ret = ch->AddMrFd(0); + EXPECT_EQ(ret, static_cast(SH_FDS_QUEUE_FULL)); +} + +TEST_F(TestShmChannel, ShmChannelAddUserFds) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("ShmChannelAddUserFds", 0, NN_NO128, NN_NO4, ch); + // param init + int fds[4]; + + ShmChannel::gQueueSizeCap = 0; + ret = ch->AddUserFds(fds, 1); + EXPECT_EQ(ret, static_cast(SH_FDS_QUEUE_FULL)); + + ShmChannel::gQueueSizeCap = NN_NO2; + ret = ch->AddUserFds(fds, 1); + EXPECT_EQ(ret, static_cast(SH_OK)); +} + +TEST_F(TestShmChannel, ShmChannelRemoveUserFds) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("ShmChannelRemoveUserFds", 0, NN_NO128, NN_NO4, ch); + // param init + int fds[4]; + + ch->mUserFdQueue.push(1); + ret = ch->RemoveUserFds(fds, 1, 1); + EXPECT_EQ(ret, static_cast(SH_OK)); +} + +TEST_F(TestShmChannel, ShmChannelRemoveUserFdsTwo) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("ShmChannelRemoveUserFds2", 0, NN_NO128, NN_NO4, ch); + // param init + int fds[4]; + + MOCKER_CPP(NetMonotonic::TimeUs) + .stubs() + .will(returnValue(static_cast(0))) + .then(returnValue(static_cast(NN_NO2 * NN_NO1000000))); + ch->mUserFdQueue.push(1); + + ret = ch->RemoveUserFds(fds, NN_NO2, 1); + EXPECT_EQ(ret, static_cast(SH_TIME_OUT)); +} + +TEST_F(TestShmChannel, ShmChannelGetCtxPosted) +{ + int ret; + ShmChannelPtr ch; + ret = ShmChannel::CreateAndInit("ShmChannelGetCtxPosted", 0, NN_NO128, NN_NO4, ch); + // param init + ShmOpContextInfo *remaining; + + ch->GetCtxPosted(remaining); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestShmChannel, ShmChannelGetCompPosted) +{ + int ret; + ShmChannelPtr ch; + ret = ShmChannel::CreateAndInit("ShmChannelGetCompPosted", 0, NN_NO128, NN_NO4, ch); + // param init + ShmOpCompInfo *remaining; + + ch->GetCompPosted(remaining); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestShmChannel, ShmChannelGValidateExchangeInfo) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("ShmChannelGValidateExchangeInfo", 0, NN_NO128, NN_NO4, ch); + // param init + ShmConnExchangeInfo info{}; + + info.qCapacity = 0; + ret = ch->ValidateExchangeInfo(info); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); + + info.qCapacity = 1; + info.queueFd = 1; + info.qName[0] = 0; + ret = ch->ValidateExchangeInfo(info); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); +} + +TEST_F(TestShmChannel, ShmChannelGValidateExchangeInfoTwo) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("ShmChannelGValidateExchangeInfoTwo", 0, NN_NO128, NN_NO4, ch); + // param init + ShmConnExchangeInfo info{}; + + info.qCapacity = 1; + info.queueFd = 1; + strncpy_s(info.qName, NN_NO32, "test", 1); + info.dcBuckCount = 0; + ret = ch->ValidateExchangeInfo(info); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); + + info.dcBuckCount = 1; + info.dcBuckSize = 1; + info.dcName[0] = 0; + ret = ch->ValidateExchangeInfo(info); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); +} + +TEST_F(TestShmChannel, ShmChannelGValidateExchangeInfoThree) +{ + int ret; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("ShmChannelGValidateExchangeInfoThree", 0, NN_NO128, NN_NO4, ch); + // param init + ShmConnExchangeInfo info{}; + + info.qCapacity = 1; + info.queueFd = 1; + strncpy_s(info.qName, NN_NO32, "test", 1); + info.dcBuckCount = 1; + info.dcBuckSize = 1; + strncpy_s(info.dcName, NN_NO64, "test", 1); + info.channelId = 0; + ret = ch->ValidateExchangeInfo(info); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/shm/test_shm_channel_keeper.cpp b/test/unit_test/transport/shm/test_shm_channel_keeper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c55a6cefa850013283e4a59b485c4bce592de4e9 --- /dev/null +++ b/test/unit_test/transport/shm/test_shm_channel_keeper.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include "shm_channel.h" +#include "shm_handle_fds.h" +#include "shm_channel_keeper.h" + +namespace ock { +namespace hcom { + +class TestShmChannelKeeper : public testing::Test { +public: + TestShmChannelKeeper(); + virtual void SetUp(void); + virtual void TearDown(void); + + ShmChannelKeeper *keeper = nullptr; +}; + +TestShmChannelKeeper::TestShmChannelKeeper() +{} + +void TestShmChannelKeeper::SetUp() +{ + keeper = new (std::nothrow) ShmChannelKeeper("channel_keeper", 0); + ASSERT_NE(keeper, nullptr); +} + +void TestShmChannelKeeper::TearDown() +{ + if (keeper != nullptr) { + delete keeper; + keeper = nullptr; + } + GlobalMockObject::verify(); +} + +TEST_F(TestShmChannelKeeper, StartMessageHandlerFail) +{ + int ret = keeper->Start(); + EXPECT_EQ(ret, SH_PARAM_INVALID); +} + +TEST_F(TestShmChannelKeeper, Stop) +{ + keeper->Stop(); +} + +TEST_F(TestShmChannelKeeper, AddShmChannelEpollInFail) +{ + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + ch.Get()->mFd = 0; + + int ret = keeper->AddShmChannel(ch); + EXPECT_EQ(ret, SH_CH_ADD_FAILURE_IN_KEEPER); +} + +TEST_F(TestShmChannelKeeper, RemoveShmChannelFail) +{ + int ret = keeper->RemoveShmChannel(0); + EXPECT_EQ(ret, SH_CH_REMOVE_FAILURE_IN_KEEPER); +} + +TEST_F(TestShmChannelKeeper, RemoveShmChannelSuccess) +{ + uint64_t channelId = 123; + ShmChannelPtr channel; + ShmChannel::CreateAndInit("TestShmChannelKeeper", channelId, NN_NO128, NN_NO4, channel); + keeper->mShmChannels[channelId] = channel; + MOCKER_CPP(&epoll_ctl).stubs().will(returnValue(0)); + int ret = keeper->RemoveShmChannel(channelId); + EXPECT_EQ(ret, SH_OK); + channel.Set(nullptr); +} + +TEST_F(TestShmChannelKeeper, ExchangeFdProcessFail) +{ + ShmChKeeperMsgHeader header; + header.msgType = EXCHANGE_USER_FD; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + int ret = keeper->ExchangeFdProcess(header, ch); + EXPECT_EQ(ret, SH_ERROR); + + MOCKER_CPP(ShmHandleFds::ReceiveMsgFds).stubs().will(returnValue(0)); + MOCKER_CPP(ShmChannel::AddUserFds).stubs().will(returnValue(300)); + ret = keeper->ExchangeFdProcess(header, ch); + EXPECT_EQ(ret, SH_ERROR); +} + +TEST_F(TestShmChannelKeeper, ExchangeFdProcess) +{ + ShmChKeeperMsgHeader header; + header.msgType = EXCHANGE_USER_FD; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("NetShmDriverOob", 0, NN_NO128, NN_NO4, ch); + MOCKER_CPP(ShmHandleFds::ReceiveMsgFds).stubs().will(returnValue(0)); + int ret = keeper->ExchangeFdProcess(header, ch); + EXPECT_EQ(ret, SH_OK); +} + +} // namespace hcom +} // namespace ock \ No newline at end of file diff --git a/test/unit_test/transport/shm/test_shm_composed_endpoint.cpp b/test/unit_test/transport/shm/test_shm_composed_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc468067cbe7e8d48a76b4e1328880755282cac1 --- /dev/null +++ b/test/unit_test/transport/shm/test_shm_composed_endpoint.cpp @@ -0,0 +1,590 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include "shm_composed_endpoint.h" +#include "shm_queue.h" +#include "shm_mr_pool.h" + +namespace ock { +namespace hcom { +class TestShmComposedEndpoint : public testing::Test { +public: + TestShmComposedEndpoint(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +TestShmComposedEndpoint::TestShmComposedEndpoint() {} + +void TestShmComposedEndpoint::SetUp() +{ +} + +void TestShmComposedEndpoint::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestShmComposedEndpoint, ShmSyncEndpointCreate) +{ + int ret; + // ShmChannel create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("ShmSyncEndpointCreate", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + + MOCKER_CPP(&ShmSyncEndpoint::CreateEventQueue) + .stubs() + .will(returnValue(1)); + + ret = ShmSyncEndpoint::Create("ShmSyncEndpointCreate", NN_NO128, SHM_EVENT_POLLING, shmEp); + EXPECT_EQ(ret, 1); +} + + +TEST_F(TestShmComposedEndpoint, ShmSyncEndpointCreateFail) +{ + int ret; + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint *nullEp = nullptr; + + MOCKER_CPP(&ShmSyncEndpointPtr::Get) + .stubs() + .will(returnValue(nullEp)); + + ret = ShmSyncEndpoint::Create("ShmSyncEndpointCreate", NN_NO128, SHM_EVENT_POLLING, shmEp); + EXPECT_EQ(ret, static_cast(SH_NEW_OBJECT_FAILED)); +} + +TEST_F(TestShmComposedEndpoint, ShmSyncEndpointCreateEventQueueFail) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("ShmSyncEndpointPostSend", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + ShmHandle *nullShmHandle = nullptr; + ShmEventQueue *nullEventQueue = nullptr; + + MOCKER_CPP(&ShmEventQueuePtr::Get) + .stubs() + .will(returnValue(nullEventQueue)); + ret = ep->CreateEventQueue(); + EXPECT_EQ(ret, static_cast(SH_NEW_OBJECT_FAILED)); + + MOCKER_CPP(&ShmHandlePtr::Get) + .stubs() + .will(returnValue(nullShmHandle)); + ret = ep->CreateEventQueue(); + EXPECT_EQ(ret, static_cast(SH_NEW_OBJECT_FAILED)); +} + +TEST_F(TestShmComposedEndpoint, ShmSyncEndpointPostSend) +{ + int ret; + // ShmChannel create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("ShmSyncEndpointPostSend", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("ShmSyncEndpointPostSend", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + + UBSHcomNetTransRequest req{}; + uint64_t offset = 0; + uint32_t immData = 0; + int32_t defaultTimeout = 0; + UBSHcomNetTransHeader header{}; + + MOCKER_CPP(&ShmChannel::EQEventEnqueue) + .stubs() + .will(returnValue(-1)) + .then(returnValue(0)); + MOCKER_CPP(&ShmEventQueue::EnqueueAndNotify) + .stubs() + .will(returnValue(0)) + .then(returnValue(-1)) + .then(returnValue(1)); + + req.upCtxSize = sizeof(ShmOpContextInfo::upCtx) + 1; + ret = ep->PostSend(ch.Get(), req, offset, immData, defaultTimeout); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); + + req.upCtxSize = sizeof(ShmOpContextInfo::upCtx); + req.lAddress = reinterpret_cast(&header); + ret = ep->PostSend(ch.Get(), req, offset, immData, defaultTimeout); + EXPECT_EQ(ret, static_cast(SH_RETRY_FULL)); + + ret = ep->PostSend(ch.Get(), req, offset, immData, defaultTimeout); + EXPECT_EQ(ret, static_cast(SH_OK)); + + ret = ep->PostSend(ch.Get(), req, offset, immData, defaultTimeout); + EXPECT_EQ(ret, static_cast(SH_SEND_COMPLETION_CALLBACK_FAILURE)); + + ret = ep->PostSend(ch.Get(), req, offset, immData, defaultTimeout); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestShmComposedEndpoint, ShmSyncEndpointFillSglCtx) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("ShmSyncEndpointFillSglCtx", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // request create + ShmSglOpContextInfo sglCtx{}; + UBSHcomNetTransSglRequest request{}; + int data; + UBSHcomNetTransSgeIov iov; + request.iovCount = 1; + request.iov = &iov; + iov.size = 1; + iov.lAddress = reinterpret_cast(&data); + request.upCtxSize = 1; + + ret = ep->FillSglCtx(nullptr, request); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); + + ret = ep->FillSglCtx(&sglCtx, request); + EXPECT_EQ(ret, static_cast(SH_OK)); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)).then(returnValue(1)); + ret = ep->FillSglCtx(&sglCtx, request); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); + + ret = ep->FillSglCtx(&sglCtx, request); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); +} + +TEST_F(TestShmComposedEndpoint, ShmSyncEndpointPostSendRawSgl) +{ + int ret; + // ShmChannel create + ShmChannelPtr ch; + ShmChannel::CreateAndInit("ShmSyncEndpointPostSendRawSgl", 0, NN_NO128, NN_NO4, ch); + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("ShmSyncEndpointPostSendRawSgl", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // request create + UBSHcomNetTransRequest req{}; + UBSHcomNetTransSglRequest request{}; + int data; + UBSHcomNetTransHeader header{}; + UBSHcomNetTransSgeIov iov; + request.iovCount = 1; + request.iov = &iov; + iov.size = 1; + iov.lAddress = reinterpret_cast(&data); + request.upCtxSize = 1; + req.lAddress = reinterpret_cast(&header); + + MOCKER_CPP(&ShmChannel::EQEventEnqueue) + .stubs() + .will(returnValue(-1)) + .then(returnValue(0)); + MOCKER_CPP(&ShmEventQueue::EnqueueAndNotify) + .stubs() + .will(returnValue(0)) + .then(returnValue(-1)) + .then(returnValue(1)); + + request.upCtxSize = sizeof(ShmOpContextInfo::upCtx) + 1; + ret = ep->PostSendRawSgl(ch.Get(), req, request, 0, 0, 0); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); + + request.upCtxSize = sizeof(ShmOpContextInfo::upCtx); + ret = ep->PostSendRawSgl(ch.Get(), req, request, 0, 0, 0); + EXPECT_EQ(ret, static_cast(SH_RETRY_FULL)); + + ret = ep->PostSendRawSgl(ch.Get(), req, request, 0, 0, 0); + EXPECT_EQ(ret, static_cast(SH_OK)); + + ret = ep->PostSendRawSgl(ch.Get(), req, request, 0, 0, 0); + EXPECT_EQ(ret, static_cast(SH_SEND_COMPLETION_CALLBACK_FAILURE)); + + ret = ep->PostSendRawSgl(ch.Get(), req, request, 0, 0, 0); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestShmComposedEndpoint, ShmSyncEndpointSendLocalEventForOneSideDone) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("ShmSyncEndpointSendLocalEventForOneSideDone", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // param create + ShmOpContextInfo ctx{}; + ShmOpContextInfo::ShmOpType type = ShmOpContextInfo::ShmOpType::SH_READ; + + MOCKER_CPP(&ShmEventQueue::EnqueueAndNotify) + .stubs() + .will(returnValue(0)) + .then(returnValue(-1)) + .then(returnValue(1)); + + ret = ep->SendLocalEventForOneSideDone(&ctx, type); + EXPECT_EQ(ret, static_cast(SH_OK)); + + ret = ep->SendLocalEventForOneSideDone(&ctx, type); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestShmComposedEndpoint, ShmSyncEndpointReceive) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("ShmSyncEndpointReceive", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // param create + ShmOpContextInfo opCtx{}; + uint32_t immData = 0; + + MOCKER_CPP(&ShmSyncEndpoint::DequeueEvent) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)) + .then(returnValue(1)); + + ret = ep->Receive(0, opCtx, immData); + EXPECT_EQ(ret, 1); + + ret = ep->Receive(0, opCtx, immData); + EXPECT_EQ(ret, static_cast(SH_ERROR)); +} + +TEST_F(TestShmComposedEndpoint, ShmSyncEndpointDequeueEvent) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("ShmSyncEndpointDequeueEvent", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // param create + ShmEvent opEvent{}; + + MOCKER_CPP(&ShmEventQueue::DequeueOrWait) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)) + .then(returnValue(1)); + + ret = ep->DequeueEvent(0, opEvent); + EXPECT_EQ(ret, 1); + + ret = ep->DequeueEvent(0, opEvent); + EXPECT_EQ(ret, static_cast(SH_OK)); +} + +TEST_F(TestShmComposedEndpoint, ShmSyncEndpointPostRead) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("ShmSyncEndpointPostRead", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // param create + UBSHcomNetTransRequest req{}; + req.upCtxSize = sizeof(ShmOpContextInfo::upCtx) + 1; + ShmMRHandleMap mrHandleMap{}; + + ret = ep->PostRead(0, req, mrHandleMap); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); +} + +TEST_F(TestShmComposedEndpoint, ShmSyncEndpointPostReadTwo) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("ShmSyncEndpointPostRead2", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // param create + UBSHcomNetTransSglRequest req{}; + req.upCtxSize = sizeof(ShmOpContextInfo::upCtx) + 1; + ShmMRHandleMap mrHandleMap{}; + + ret = ep->PostRead(0, req, mrHandleMap); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); +} + +TEST_F(TestShmComposedEndpoint, ShmSyncEndpointPostWrite) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("ShmSyncEndpointPostWrite", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // param create + UBSHcomNetTransRequest req{}; + req.upCtxSize = sizeof(ShmOpContextInfo::upCtx) + 1; + ShmMRHandleMap mrHandleMap{}; + + ret = ep->PostWrite(0, req, mrHandleMap); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); +} + +TEST_F(TestShmComposedEndpoint, ShmSyncEndpointPostWriteTwo) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("ShmSyncEndpointPostWrite2", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // param create + UBSHcomNetTransSglRequest req{}; + req.upCtxSize = sizeof(ShmOpContextInfo::upCtx) + 1; + ShmMRHandleMap mrHandleMap{}; + + ret = ep->PostWrite(0, req, mrHandleMap); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); +} + +TEST_F(TestShmComposedEndpoint, SyncReadWriteProcess) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("SyncReadWriteProcess", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // param init + ShmMRHandleMap mrHandleMap{}; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("SyncReadWriteProcess", 0, NN_NO128, NN_NO4, ch); + ShmOpContextInfo::ShmOpType type = ShmOpContextInfo::SH_SEND; + ShmHandlePtr localMrHandle = nullptr; + UBSHcomNetTransRequest req{}; + + MOCKER_CPP(&ShmMRHandleMap::GetFromLocalMap) + .stubs() + .will(returnValue(localMrHandle)); + + ret = ep->PostReadWrite(ch.Get(), req, mrHandleMap, type); + EXPECT_EQ(ret, static_cast(SH_ERROR)); +} + +TEST_F(TestShmComposedEndpoint, SyncReadWriteProcessTwo) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("SyncReadWriteProcess2", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // param init + ShmMRHandleMap mrHandleMap{}; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("SyncReadWriteProcess2", 0, NN_NO128, NN_NO4, ch); + ShmOpContextInfo::ShmOpType type = ShmOpContextInfo::SH_SEND; + ShmHandlePtr localMrHandle = new (std::nothrow) ShmHandle("SyncReadWriteProcess2", "", 0, 0, false); + ASSERT_NE(localMrHandle, nullptr); + ShmHandlePtr remoteMrHandle; + UBSHcomNetTransRequest req{}; + + MOCKER_CPP(&ShmMRHandleMap::GetFromLocalMap) + .stubs() + .will(returnValue(localMrHandle)); + MOCKER_CPP(&ShmMRHandleMap::GetFromRemoteMap) + .stubs() + .will(returnValue(remoteMrHandle)); + MOCKER_CPP(&ShmChannel::GetRemoteMrHandle) + .stubs() + .will(returnValue(1)); + + ret = ep->PostReadWrite(ch.Get(), req, mrHandleMap, type); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestShmComposedEndpoint, PostReadWriteCopyFail) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("PostReadWriteCopyFail", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // param init + ShmMRHandleMap mrHandleMap{}; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("PostReadWriteCopyFail", 0, NN_NO128, NN_NO4, ch); + ShmOpContextInfo::ShmOpType type = ShmOpContextInfo::ShmOpType::SH_READ; + UBSHcomNetTransRequest req{}; + req.upCtxSize = 1; + ShmHandlePtr localMrHandle = new (std::nothrow) ShmHandle("localMrHandle", "", 0, 0, false); + ASSERT_NE(localMrHandle, nullptr); + ShmHandlePtr remoteMrHandle = new (std::nothrow) ShmHandle("remoteMrHandle", "", 0, 0, false); + ASSERT_NE(remoteMrHandle, nullptr); + + MOCKER_CPP(&ShmMRHandleMap::GetFromLocalMap) + .stubs() + .will(returnValue(localMrHandle)); + MOCKER_CPP(&ShmMRHandleMap::GetFromRemoteMap) + .stubs() + .will(returnValue(remoteMrHandle)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)).then(returnValue(1)); + + ret = ep->PostReadWrite(ch.Get(), req, mrHandleMap, type); + EXPECT_EQ(ret, static_cast(SH_PARAM_INVALID)); +} + +TEST_F(TestShmComposedEndpoint, PostReadWriteSendLocalEventFail) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("PostReadWriteSendLocalEventFail", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // param init + ShmMRHandleMap mrHandleMap{}; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("PostReadWriteSendLocalEventFail", 0, NN_NO128, NN_NO4, ch); + ShmOpContextInfo::ShmOpType type = ShmOpContextInfo::ShmOpType::SH_READ; + UBSHcomNetTransRequest req{}; + req.upCtxSize = 0; + ShmHandlePtr localMrHandle = new (std::nothrow) ShmHandle("localMrHandle", "", 0, 0, false); + ASSERT_NE(localMrHandle, nullptr); + ShmHandlePtr remoteMrHandle = new (std::nothrow) ShmHandle("remoteMrHandle", "", 0, 0, false); + ASSERT_NE(remoteMrHandle, nullptr); + + MOCKER_CPP(&ShmMRHandleMap::GetFromLocalMap) + .stubs() + .will(returnValue(localMrHandle)); + MOCKER_CPP(&ShmMRHandleMap::GetFromRemoteMap) + .stubs() + .will(returnValue(remoteMrHandle)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + MOCKER_CPP(&ShmSyncEndpoint::SendLocalEventForOneSideDone).stubs().will(returnValue(1)); + + ret = ep->PostReadWrite(ch.Get(), req, mrHandleMap, type); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestShmComposedEndpoint, PostReadWriteSglSyncReadWriteProcessFail) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("PostReadWriteSglSyncReadWriteProcessFail", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // param init + ShmMRHandleMap mrHandleMap{}; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("PostReadWriteSglSyncReadWriteProcessFail", 0, NN_NO128, NN_NO4, ch); + ShmOpContextInfo::ShmOpType type = ShmOpContextInfo::ShmOpType::SH_SGL_READ; + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + ShmHandlePtr localMrHandle = new (std::nothrow) ShmHandle("localMrHandle", "", 0, 0, false); + ASSERT_NE(localMrHandle, nullptr); + ShmHandlePtr remoteMrHandle; + + MOCKER_CPP(&ShmMRHandleMap::GetFromLocalMap) + .stubs() + .will(returnValue(localMrHandle)); + MOCKER_CPP(&ShmMRHandleMap::GetFromRemoteMap) + .stubs() + .will(returnValue(remoteMrHandle)); + MOCKER_CPP(&ShmChannel::GetRemoteMrHandle) + .stubs() + .will(returnValue(1)); + + ret = ep->PostReadWriteSgl(ch.Get(), sglReq, mrHandleMap, type); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestShmComposedEndpoint, PostReadWriteSglFillSglCtxFail) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("PostReadWriteSglFillSglCtxFail", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // param init + ShmMRHandleMap mrHandleMap{}; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("PostReadWriteSglFillSglCtxFail", 0, NN_NO128, NN_NO4, ch); + ShmOpContextInfo::ShmOpType type = ShmOpContextInfo::ShmOpType::SH_SGL_WRITE; + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + ShmHandlePtr localMrHandle = new (std::nothrow) ShmHandle("localMrHandle", "", 0, 0, false); + ASSERT_NE(localMrHandle, nullptr); + ShmHandlePtr remoteMrHandle = new (std::nothrow) ShmHandle("remoteMrHandle", "", 0, 0, false); + ASSERT_NE(remoteMrHandle, nullptr); + + MOCKER_CPP(&ShmMRHandleMap::GetFromLocalMap) + .stubs() + .will(returnValue(localMrHandle)); + MOCKER_CPP(&ShmMRHandleMap::GetFromRemoteMap) + .stubs() + .will(returnValue(remoteMrHandle)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + MOCKER_CPP(&ShmSyncEndpoint::FillSglCtx).stubs().will(returnValue(1)); + ret = ep->PostReadWriteSgl(ch.Get(), sglReq, mrHandleMap, type); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestShmComposedEndpoint, PostReadWriteSglSendLocalEventFail) +{ + int ret; + // ShmSyncEndpoint create + ShmSyncEndpointPtr shmEp; + ShmSyncEndpoint::Create("PostReadWriteSglSendLocalEventFail", NN_NO128, SHM_EVENT_POLLING, shmEp); + ShmSyncEndpoint *ep = shmEp.Get(); + // param init + ShmMRHandleMap mrHandleMap{}; + ShmChannelPtr ch; + ShmChannel::CreateAndInit("PostReadWriteSglSendLocalEventFail", 0, NN_NO128, NN_NO4, ch); + ShmOpContextInfo::ShmOpType type = ShmOpContextInfo::ShmOpType::SH_SGL_READ; + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + ShmHandlePtr localMrHandle = new (std::nothrow) ShmHandle("localMrHandle", "", 0, 0, false); + ASSERT_NE(localMrHandle, nullptr); + ShmHandlePtr remoteMrHandle = new (std::nothrow) ShmHandle("remoteMrHandle", "", 0, 0, false); + ASSERT_NE(remoteMrHandle, nullptr); + + MOCKER_CPP(&ShmMRHandleMap::GetFromLocalMap) + .stubs() + .will(returnValue(localMrHandle)); + MOCKER_CPP(&ShmMRHandleMap::GetFromRemoteMap) + .stubs() + .will(returnValue(remoteMrHandle)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + MOCKER_CPP(&ShmSyncEndpoint::FillSglCtx).stubs().will(returnValue(0)); + MOCKER_CPP(&ShmSyncEndpoint::SendLocalEventForOneSideDone).stubs().will(returnValue(1)); + ret = ep->PostReadWriteSgl(ch.Get(), sglReq, mrHandleMap, type); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestShmComposedEndpoint, CreateEventQueueFail) +{ + int ret; + ShmSyncEndpoint *ep = new ShmSyncEndpoint("shm", 0, SHM_EVENT_POLLING); + EXPECT_NE(ep->CreateEventQueue(), SH_OK); + + delete ep; +} + +TEST_F(TestShmComposedEndpoint, ShmMemoryRegionCreateFail) +{ + int ret; + ShmMemoryRegion *mr = nullptr; + + ret = ShmMemoryRegion::Create("shmMr", 0, mr); + EXPECT_EQ(ret, NN_INVALID_PARAM); + + ret = ShmMemoryRegion::Create("shmMr", 0, 0, mr); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/shm/test_shm_data_channel.cpp b/test/unit_test/transport/shm/test_shm_data_channel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c05bd2734b23aa3c6ab93432bd2a3accae073e87 --- /dev/null +++ b/test/unit_test/transport/shm/test_shm_data_channel.cpp @@ -0,0 +1,125 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include "shm_data_channel.h" + +namespace ock { +namespace hcom { + +class TestShmDataChannel : public testing::Test { +public: + TestShmDataChannel(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +TestShmDataChannel::TestShmDataChannel() +{} + +void TestShmDataChannel::SetUp() +{} + +void TestShmDataChannel::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestShmDataChannel, ValidateOptionsNullNameFail) +{ + int ret; + UBSHcomNetAtomicState state{CH_NEW}; + ShmDataChannelOptions opt(0, NN_NO256, NN_NO16, true); + ShmDataChannel *dc = new (std::nothrow) ShmDataChannel("", opt, &state); + ASSERT_NE(dc, nullptr); + + ret = dc->ValidateOptions(); + EXPECT_EQ(ret, SH_PARAM_INVALID); + + if (dc != nullptr) { + delete dc; + dc = nullptr; + } +} + +TEST_F(TestShmDataChannel, ValidateOptionsNullStateFail) +{ + int ret; + ShmDataChannelOptions opt(0, NN_NO256, NN_NO16, true); + ShmDataChannel *dc = new (std::nothrow) ShmDataChannel("TestShmDataChannel", opt, nullptr); + ASSERT_NE(dc, nullptr); + + ret = dc->ValidateOptions(); + EXPECT_EQ(ret, SH_PARAM_INVALID); + + if (dc != nullptr) { + delete dc; + dc = nullptr; + } +} + +TEST_F(TestShmDataChannel, ValidateOptionsOptionsFail) +{ + int ret; + UBSHcomNetAtomicState state{CH_NEW}; + ShmDataChannelOptions opt(0, 0, 0, true); + ShmDataChannel *dc = new (std::nothrow) ShmDataChannel("TestShmDataChannel", opt, &state); + ASSERT_NE(dc, nullptr); + + ret = dc->ValidateOptions(); + EXPECT_EQ(ret, SH_PARAM_INVALID); + + if (dc != nullptr) { + delete dc; + dc = nullptr; + } +} + +TEST_F(TestShmDataChannel, InitializeInvalidOptionFail) +{ + int ret; + UBSHcomNetAtomicState state{CH_NEW}; + ShmDataChannelOptions opt(0, NN_NO256, NN_NO16, true); + ShmDataChannel *dc = new (std::nothrow) ShmDataChannel("TestShmDataChannel", opt, &state); + ASSERT_NE(dc, nullptr); + + MOCKER_CPP(ShmDataChannel::ValidateOptions).stubs().will(returnValue(300)); + ret = dc->Initialize(); + EXPECT_EQ(ret, SH_ERROR); + + if (dc != nullptr) { + delete dc; + dc = nullptr; + } +} + +TEST_F(TestShmDataChannel, Initialize) +{ + int ret; + UBSHcomNetAtomicState state{CH_NEW}; + ShmDataChannelOptions opt(0, NN_NO256, NN_NO16, true); + ShmDataChannel *dc = new (std::nothrow) ShmDataChannel("TestShmDataChannel", opt, &state); + ASSERT_NE(dc, nullptr); + + ret = dc->Initialize(); + EXPECT_EQ(ret, SH_OK); + ret = dc->Initialize(); // Initialize again + EXPECT_EQ(ret, SH_OK); + + if (dc != nullptr) { + delete dc; + dc = nullptr; + } +} + +} // namespace hcom +} // namespace ock \ No newline at end of file diff --git a/test/unit_test/transport/shm/test_shm_mr_pool.cpp b/test/unit_test/transport/shm/test_shm_mr_pool.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5cefc6827810d5f41630fb881c46de17d9438d82 --- /dev/null +++ b/test/unit_test/transport/shm/test_shm_mr_pool.cpp @@ -0,0 +1,91 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include "shm_mr_pool.h" + +namespace ock { +namespace hcom { + +class TestShmMemoryRegion : public testing::Test { +public: + TestShmMemoryRegion(); + virtual void SetUp(void); + virtual void TearDown(void); + + ShmMemoryRegion *mr = nullptr; + ShmMemoryRegion *noExternalMr = nullptr; +}; + +TestShmMemoryRegion::TestShmMemoryRegion() +{} + +void TestShmMemoryRegion::SetUp() +{ + mr = new (std::nothrow) ShmMemoryRegion("shmMr", true, 0, 0); + ASSERT_NE(mr, nullptr); + noExternalMr = new (std::nothrow) ShmMemoryRegion("noExternalshmMr", false, 0, 0); + ASSERT_NE(mr, nullptr); +} + +void TestShmMemoryRegion::TearDown() +{ + if (mr != nullptr) { + delete mr; + mr = nullptr; + } + if (noExternalMr != nullptr) { + delete noExternalMr; + noExternalMr = nullptr; + } + GlobalMockObject::verify(); +} + +TEST_F(TestShmMemoryRegion, UnInitializeFail) +{ + mr->UnInitialize(); +} + +TEST_F(TestShmMemoryRegion, CreateAndInitialize) +{ + ShmMemoryRegion *tmpShmMr = nullptr; + int ret = ShmMemoryRegion::Create("tmpShmMr", 1, 1, tmpShmMr); + EXPECT_EQ(ret, NN_OK); + + ret = tmpShmMr->Initialize(); + EXPECT_EQ(ret, NN_OK); +} + +TEST_F(TestShmMemoryRegion, InitializeInvalidParam) +{ + int ret = mr->Initialize(); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestShmMemoryRegion, InitializeHandlerNoInitializedFail) +{ + MOCKER_CPP(ShmHandle::Initialize).stubs().will(returnValue(300)); + int ret = noExternalMr->Initialize(); + EXPECT_EQ(ret, NN_NOT_INITIALIZED); +} + +TEST_F(TestShmMemoryRegion, InitializeHandlerMallocFail) +{ + MOCKER_CPP(ShmHandle::Initialize).stubs().will(returnValue(0)); + uintptr_t mAddress = 0; + MOCKER_CPP(ShmHandle::ShmAddress).stubs().will(returnValue(mAddress)); + int ret = noExternalMr->Initialize(); + EXPECT_EQ(ret, NN_MALLOC_FAILED); +} + +} // namespace hcom +} // namespace ock \ No newline at end of file diff --git a/test/unit_test/transport/shm/test_shm_queue.cpp b/test/unit_test/transport/shm/test_shm_queue.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b0b5fadf6105846f346261e40968a07c84a56166 --- /dev/null +++ b/test/unit_test/transport/shm/test_shm_queue.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +#include "shm_worker.h" +#include "shm_handle.h" +#include "shm_queue.h" +#include "shm_common.h" + +namespace ock { +namespace hcom { +class TestShmQueue : public testing::Test { +public: + TestShmQueue(); + virtual void SetUp(void); + virtual void TearDown(void); + std::string name = "TestShmQueue"; + ShmEvent event {}; + ShmEventQueuePtr queue; + ShmQueueMeta *queueMeta; +}; + +TestShmQueue::TestShmQueue() {} + +void TestShmQueue::SetUp() +{ + queueMeta = new (std::nothrow) ShmQueueMeta(); + ASSERT_NE(queueMeta, nullptr); + queue = new (std::nothrow) ShmEventQueue(name, NN_NO2048, nullptr); + ASSERT_NE(queue, nullptr); + queue->mQueueMeta = queueMeta; + queue->mQueueData = new (std::nothrow) ShmEvent[NN_NO2048]; + ASSERT_NE(queue->mQueueData, nullptr); +} + +void TestShmQueue::TearDown() +{ + if (queueMeta != nullptr) { + delete queueMeta; + queue->mQueueMeta = nullptr; + } + if (queue->mQueueData != nullptr) { + delete queue->mQueueData; + queue->mQueueData = nullptr; + } + if (queue != nullptr) { + queue.Set(nullptr); + } + GlobalMockObject::verify(); +} + +TEST_F(TestShmQueue, EnqueueFailed) +{ + queue->mInited = true; + queue->mMaxEnqueueTimeout = 0; + queue->mQueueMeta->prod.tail = 0; + queue->mQueueMeta->prod.head = 1; + queue->mQueueMeta->cons.tail = 0; + queue->mQueueMeta->cons.head = 0; + EXPECT_EQ(queue->Enqueue(event), static_cast(ShmEventQueue::SHM_QUEUE_FULL)); +} + +TEST_F(TestShmQueue, DequeueFailed) +{ + queue->mInited = true; + queue->mMaxEnqueueTimeout = 0; + queue->mQueueMeta->prod.tail = 0; + queue->mQueueMeta->prod.head = 0; + queue->mQueueMeta->cons.tail = 0; + queue->mQueueMeta->cons.head = 0; + queue->mFailedProd = 0; + EXPECT_EQ(queue->Dequeue(event), static_cast(ShmEventQueue::SHM_QUEUE_EMPTY)); +} + +TEST_F(TestShmQueue, CheckState) +{ + queue->mInited = true; + queue->mMaxFailedTime = 0; + queue->mQueueMeta->prod.tail = 0; + queue->mQueueMeta->prod.head = 1; + queue->mTempProdIdx = UINT64_MAX; + queue->mFailedProd = UINT64_MAX; + EXPECT_NO_FATAL_FAILURE(queue->CheckAndMarkProducerState()); + EXPECT_EQ(queue->mTempProdIdx, queue->mQueueMeta->prod.tail); + + EXPECT_NO_FATAL_FAILURE(queue->CheckAndMarkProducerState()); + EXPECT_EQ(queue->mQueueMeta->prod.tail, NN_NO1); +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/shm/test_shm_sync_endpoint.cpp b/test/unit_test/transport/shm/test_shm_sync_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c7866b37a5dddb1d3f286074520189340cb8305f --- /dev/null +++ b/test/unit_test/transport/shm/test_shm_sync_endpoint.cpp @@ -0,0 +1,613 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include + +#include "securec.h" +#include "hcom_def.h" +#include "hcom_log.h" +#include "net_shm_sync_endpoint.h" +#include "net_shm_async_endpoint.h" +#include "shm_common.h" + + +namespace ock { +namespace hcom { + +constexpr uint32_t SYNC_EP_SHM_ALLOWD_SIZE = 256; +constexpr uint32_t REQUEST_SIZE = 128; +static UBSHcomNetTransHeader mockReq{}; + +class TestShmSyncEndpointNew : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + static void SetUpTestSuite() {} + static void TearDownTestSuite() {} + + NetSyncEndpointShm* mShmSyncEp = nullptr; + UBSHcomNetTransRequest mReq; +}; + +static HResult MockDCGetFreeBuck(uintptr_t &address, uint64_t &offsetToBase, + uint16_t waitPeriodUs = NN_NO100, int32_t timeoutSecond = -1) +{ + static char buffer[SYNC_EP_SHM_ALLOWD_SIZE]; + auto ret = memset_s(buffer, SYNC_EP_SHM_ALLOWD_SIZE, '\0', SYNC_EP_SHM_ALLOWD_SIZE); + if (ret != 0) { + NN_LOG_ERROR("MockDCGetFreeBuck memset_s failed"); + return SH_ERROR; + } + address = reinterpret_cast(buffer); + offsetToBase = 0; + return SH_OK; +} + +void TestShmSyncEndpointNew::SetUp() +{ + // create and configure NetSyncEndpointShm object + UBSHcomNetWorkerIndex workerId; + workerId.wholeIdx = 0; + ShmMRHandleMap tmpShmMRHandleMap; + + ShmChannelPtr shmCh = new ShmChannel("TestShmSyncEndpoint", 0, 0, 0); + if (shmCh == nullptr) { + NN_LOG_ERROR("new ShmChannel failed"); + return; + } + + ShmSyncEndpointPtr shmEp = new ShmSyncEndpoint("TestShmSyncEndpoint", 0, SHM_EVENT_POLLING); + if (shmEp == nullptr) { + NN_LOG_ERROR("new ShmSyncEndpoint failed"); + return; + } + + mShmSyncEp = new (std::nothrow) + NetSyncEndpointShm(0, shmCh.Get(), nullptr, workerId, shmEp.Get(), tmpShmMRHandleMap); + if (mShmSyncEp == nullptr) { + NN_LOG_ERROR("new NetSyncEndpointShm failed"); + return; + } + + mShmSyncEp->mState.Set(NEP_ESTABLISHED); + mShmSyncEp->mAllowedSize = SYNC_EP_SHM_ALLOWD_SIZE; + mShmSyncEp->mIsNeedEncrypt = false; + + // create and config req + static char buffer[REQUEST_SIZE]; + auto ret = memset_s(buffer, REQUEST_SIZE, '\0', REQUEST_SIZE); + ASSERT_EQ(ret, 0); + ret = memset_s(&mReq, sizeof(mReq), '\0', sizeof(mReq)); + ASSERT_EQ(ret, 0); + mReq.lAddress = reinterpret_cast(buffer); + mReq.size = REQUEST_SIZE; +} + +void TestShmSyncEndpointNew::TearDown() +{ + if (mShmSyncEp != nullptr) { + delete mShmSyncEp; + mShmSyncEp = nullptr; + } + GlobalMockObject::verify(); +} + +TEST_F(TestShmSyncEndpointNew, PostSendFailWhenValidateStateFail) +{ + mShmSyncEp->mState.Set(NEP_BROKEN); + NResult ret = mShmSyncEp->PostSend(0, mReq, 0); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmSyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmSyncEndpointNew, PostSendFailWhenMessageTooLarge) +{ + mReq.size = SYNC_EP_SHM_ALLOWD_SIZE + 1; + NResult ret = mShmSyncEp->PostSend(0, mReq, 0); + ASSERT_EQ(ret, NN_TWO_SIDE_MESSAGE_TOO_LARGE); +} + +TEST_F(TestShmSyncEndpointNew, PostSendFailWhenGetFreeBuckFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck) + .stubs() + .will(returnValue(static_cast(SH_NOT_INITIALIZED))); + + NResult ret = mShmSyncEp->PostSend(0, mReq, 0); + ASSERT_EQ(ret, SH_NOT_INITIALIZED); +} + +TEST_F(TestShmSyncEndpointNew, PostSendFailWhenEncryptFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen) + .stubs() + .will(returnValue(static_cast(0))); + + MOCKER_CPP(&AesGcm128::Encrypt, + bool (AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + mShmSyncEp->mIsNeedEncrypt = true; + NResult ret = mShmSyncEp->PostSend(0, mReq, 0); + ASSERT_EQ(ret, NN_ENCRYPT_FAILED); +} + +TEST_F(TestShmSyncEndpointNew, PostSendFailWhenMemcpyFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + + mShmSyncEp->mIsNeedEncrypt = false; + NResult ret = mShmSyncEp->PostSend(0, mReq, 0); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestShmSyncEndpointNew, PostSendSuccess) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmSyncEndpoint::PostSend, + HResult (ShmSyncEndpoint::*)(ShmChannel *, const UBSHcomNetTransRequest&, uint64_t, uint32_t, int32_t)) + .stubs() + .will(returnValue(static_cast(SH_OK))); + MOCKER(memcpy_s).stubs().will(returnValue(0)); + + NResult ret = mShmSyncEp->PostSend(0, mReq, 0); + ASSERT_EQ(ret, SH_OK); +} + +TEST_F(TestShmSyncEndpointNew, PostSendOpInfoFailWhenValidateStateFail) +{ + mShmSyncEp->mState.Set(NEP_BROKEN); + UBSHcomNetTransOpInfo opInfo{}; + NResult ret = mShmSyncEp->PostSend(0, mReq, opInfo); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmSyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmSyncEndpointNew, PostSendOpInfoFailWhenValidateSizeFail) +{ + mShmSyncEp->mAllowedSize = NN_NO1; + UBSHcomNetTransOpInfo opInfo{}; + NResult ret = mShmSyncEp->PostSend(0, mReq, opInfo); + EXPECT_EQ(ret, NN_TWO_SIDE_MESSAGE_TOO_LARGE); + mShmSyncEp->mAllowedSize = SYNC_EP_SHM_ALLOWD_SIZE; +} + +TEST_F(TestShmSyncEndpointNew, PostSendOpInfoFailWhenGetFreeBuckFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck) + .stubs() + .will(returnValue(static_cast(SH_NOT_INITIALIZED))); + + UBSHcomNetTransOpInfo opInfo{}; + NResult ret = mShmSyncEp->PostSend(0, mReq, opInfo); + EXPECT_EQ(ret, SH_NOT_INITIALIZED); +} + +TEST_F(TestShmSyncEndpointNew, PostSendOpInfoFailWhenEncryptFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen) + .stubs() + .will(returnValue(static_cast(0))); + + MOCKER_CPP(&AesGcm128::Encrypt, + bool (AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + mShmSyncEp->mIsNeedEncrypt = true; + UBSHcomNetTransOpInfo opInfo{}; + NResult ret = mShmSyncEp->PostSend(0, mReq, opInfo); + EXPECT_EQ(ret, NN_ENCRYPT_FAILED); +} + +TEST_F(TestShmSyncEndpointNew, PostSendOpInfoFailWhenMemcpyFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + + mShmSyncEp->mIsNeedEncrypt = false; + UBSHcomNetTransOpInfo opInfo{}; + NResult ret = mShmSyncEp->PostSend(0, mReq, opInfo); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestShmSyncEndpointNew, PostSendRawFailWhenValidateStateFail) +{ + mShmSyncEp->mState.Set(NEP_BROKEN); + NResult ret = mShmSyncEp->PostSendRaw(mReq, 0); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmSyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmSyncEndpointNew, PostSendRawFailWhenValidateSizeFail) +{ + mShmSyncEp->mSegSize = NN_NO1; + NResult ret = mShmSyncEp->PostSendRaw(mReq, 0); + EXPECT_EQ(ret, NN_TWO_SIDE_MESSAGE_TOO_LARGE); + mShmSyncEp->mSegSize = SYNC_EP_SHM_ALLOWD_SIZE; +} + +TEST_F(TestShmSyncEndpointNew, PostSendRawFailWhenMemcpyFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + + mShmSyncEp->mIsNeedEncrypt = false; + mShmSyncEp->mSegSize = SYNC_EP_SHM_ALLOWD_SIZE; + NResult ret = mShmSyncEp->PostSendRaw(mReq, 0); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestShmSyncEndpointNew, PostSendRawSglFailWhenValidateFail) +{ + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + mShmSyncEp->mState.Set(NEP_BROKEN); + NResult ret = mShmSyncEp->PostSendRawSgl(sglReq, 1); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmSyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmSyncEndpointNew, PostSendRawSglFailWhenGetFreeBuckFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck) + .stubs() + .will(returnValue(static_cast(SH_NOT_INITIALIZED))); + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + NResult ret = mShmSyncEp->PostSendRawSgl(sglReq, 1); + EXPECT_EQ(ret, SH_NOT_INITIALIZED); +} + +TEST_F(TestShmSyncEndpointNew, PostSendRawSglFailWhenEncryptFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(false)) + .then(returnValue(true)); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + mShmSyncEp->mIsNeedEncrypt = true; + NResult ret = mShmSyncEp->PostSendRawSgl(sglReq, 1); + EXPECT_EQ(ret, NN_MALLOC_FAILED); + + ret = mShmSyncEp->PostSendRawSgl(sglReq, 1); + EXPECT_EQ(ret, NN_ERROR); +} + +TEST_F(TestShmSyncEndpointNew, PostSendRawSglFailWhenMemcpyFail) +{ + MOCKER_CPP(&ShmChannel::DCGetFreeBuck, + HResult (ShmChannel::*)(uintptr_t&, uint64_t&, uint16_t, int32_t)) + .stubs() + .will(invoke(MockDCGetFreeBuck)); + + MOCKER_CPP(&ShmChannel::DCMarkBuckFree) + .stubs() + .will(returnValue(static_cast(SH_OK))); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + + MOCKER_CPP(&NetDriverShmWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + mShmSyncEp->mIsNeedEncrypt = false; + NResult ret = mShmSyncEp->PostSendRawSgl(sglReq, 1); + EXPECT_EQ(ret, NN_ERROR); +} + +TEST_F(TestShmSyncEndpointNew, PostReadValidateFail) +{ + mShmSyncEp->mState.Set(NEP_BROKEN); + NResult ret = mShmSyncEp->PostRead(mReq); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmSyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmSyncEndpointNew, PostReadSglValidateFail) +{ + mShmSyncEp->mState.Set(NEP_BROKEN); + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + NResult ret = mShmSyncEp->PostRead(sglReq); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmSyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmSyncEndpointNew, PostWriteValidateFail) +{ + mShmSyncEp->mState.Set(NEP_BROKEN); + NResult ret = mShmSyncEp->PostWrite(mReq); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmSyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmSyncEndpointNew, PostWriteSglValidateFail) +{ + mShmSyncEp->mState.Set(NEP_BROKEN); + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + NResult ret = mShmSyncEp->PostWrite(sglReq); + EXPECT_EQ(ret, NN_EP_NOT_ESTABLISHED); + mShmSyncEp->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestShmSyncEndpointNew, SendFdsFail) +{ + int fds1[NN_NO4] = {1, 2, 3, 4}; + int fds2[NN_NO4] = {0}; + uint32_t len = NN_NO4; + MOCKER_CPP(::send).stubs().will(returnValue(0)); + EXPECT_EQ(mShmSyncEp->SendFds(fds1, len), NN_ERROR); + EXPECT_EQ(mShmSyncEp->SendFds(fds2, len), NN_INVALID_PARAM); +} + +static HResult MockGetPeerDataAddressByOffset(uint64_t offset, uintptr_t &address) +{ + address = reinterpret_cast(&mockReq); + offset = 0; + return 0; +} + +TEST_F(TestShmSyncEndpointNew, ReceiveRawPeerChannelAddressFail) +{ + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + mShmSyncEp->mShmEp->mName = "Test"; + mShmSyncEp->mExistDelayEvent = true; + mShmSyncEp->mDelayHandleReceiveEvent.peerChannelAddress = 0; + EXPECT_EQ(mShmSyncEp->ReceiveRaw(timeout, ctx), NN_ERROR); +} + +TEST_F(TestShmSyncEndpointNew, ReceiveRawGetPeerDataAddressByOffsetAndOpTypeFail) +{ + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + mShmSyncEp->mExistDelayEvent = true; + mShmSyncEp->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset) + .stubs() + .will(returnValue(1)) + .then(invoke(MockGetPeerDataAddressByOffset)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree) + .stubs() + .will(returnValue(0)); + + EXPECT_EQ(mShmSyncEp->ReceiveRaw(timeout, ctx), 1); + mShmSyncEp->mExistDelayEvent = true; + EXPECT_EQ(mShmSyncEp->ReceiveRaw(timeout, ctx), NN_ERROR); +} + +TEST_F(TestShmSyncEndpointNew, ReceiveRawNotExistDelayEvent) +{ + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + mShmSyncEp->mExistDelayEvent = false; + MOCKER_CPP(&ShmSyncEndpoint::Receive).stubs().will(returnValue(1)); + EXPECT_EQ(mShmSyncEp->ReceiveRaw(timeout, ctx), 1); +} + +TEST_F(TestShmSyncEndpointNew, ReceiveRawSeqNoErr) +{ + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + mShmSyncEp->mExistDelayEvent = true; + mShmSyncEp->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + mShmSyncEp->mDelayHandleReceiveEvent.opType = ShmOpContextInfo::SH_RECEIVE; + mShmSyncEp->mLastSendSeqNo = 1; + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset) + .stubs() + .will(invoke(MockGetPeerDataAddressByOffset)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree) + .stubs() + .will(returnValue(0)); + + EXPECT_EQ(mShmSyncEp->ReceiveRaw(timeout, ctx), NN_SEQ_NO_NOT_MATCHED); +} + +TEST_F(TestShmSyncEndpointNew, ReceiveRawEncryptAllocateFail) +{ + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + mockReq.dataLength = NN_NO1024 - sizeof(UBSHcomNetTransHeader); + mShmSyncEp->mExistDelayEvent = true; + mShmSyncEp->mIsNeedEncrypt = true; + mShmSyncEp->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + mShmSyncEp->mDelayHandleReceiveEvent.opType = ShmOpContextInfo::SH_RECEIVE; + mShmSyncEp->mLastSendSeqNo = 0; + mShmSyncEp->mDelayHandleReceiveEvent.dataSize = NN_NO1024; + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset) + .stubs() + .will(invoke(MockGetPeerDataAddressByOffset)); + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(false)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree) + .stubs() + .will(returnValue(0)); + + EXPECT_EQ(mShmSyncEp->ReceiveRaw(timeout, ctx), NN_MALLOC_FAILED); +} + +TEST_F(TestShmSyncEndpointNew, ReceiveRawDecryptFail) +{ + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + mockReq.dataLength = NN_NO1024 - sizeof(UBSHcomNetTransHeader); + mShmSyncEp->mExistDelayEvent = true; + mShmSyncEp->mIsNeedEncrypt = true; + mShmSyncEp->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + mShmSyncEp->mDelayHandleReceiveEvent.opType = ShmOpContextInfo::SH_RECEIVE; + mShmSyncEp->mLastSendSeqNo = 0; + mShmSyncEp->mDelayHandleReceiveEvent.dataSize = NN_NO1024; + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset) + .stubs() + .will(invoke(MockGetPeerDataAddressByOffset)); + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(true)); + MOCKER_CPP(&AesGcm128::Decrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree) + .stubs() + .will(returnValue(0)); + + EXPECT_EQ(mShmSyncEp->ReceiveRaw(timeout, ctx), NN_DECRYPT_FAILED); +} + +TEST_F(TestShmSyncEndpointNew, ReceiveRawDecryptSuccess) +{ + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + mockReq.dataLength = NN_NO1024 - sizeof(UBSHcomNetTransHeader); + mShmSyncEp->mExistDelayEvent = true; + mShmSyncEp->mIsNeedEncrypt = true; + mShmSyncEp->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + mShmSyncEp->mDelayHandleReceiveEvent.opType = ShmOpContextInfo::SH_RECEIVE; + mShmSyncEp->mLastSendSeqNo = 0; + mShmSyncEp->mDelayHandleReceiveEvent.dataSize = NN_NO1024; + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset) + .stubs() + .will(invoke(MockGetPeerDataAddressByOffset)); + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(true)); + MOCKER_CPP(&AesGcm128::Decrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(true)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree) + .stubs() + .will(returnValue(0)); + + EXPECT_EQ(mShmSyncEp->ReceiveRaw(timeout, ctx), NN_OK); +} + +TEST_F(TestShmSyncEndpointNew, ReceiveRawAllocateFail) +{ + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + mockReq.dataLength = NN_NO1024 - sizeof(UBSHcomNetTransHeader); + mShmSyncEp->mExistDelayEvent = true; + mShmSyncEp->mIsNeedEncrypt = false; + mShmSyncEp->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + mShmSyncEp->mDelayHandleReceiveEvent.opType = ShmOpContextInfo::SH_RECEIVE; + mShmSyncEp->mLastSendSeqNo = 0; + mShmSyncEp->mDelayHandleReceiveEvent.dataSize = NN_NO1024; + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset) + .stubs() + .will(invoke(MockGetPeerDataAddressByOffset)); + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(false)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree) + .stubs() + .will(returnValue(0)); + + EXPECT_EQ(mShmSyncEp->ReceiveRaw(timeout, ctx), NN_MALLOC_FAILED); +} + +TEST_F(TestShmSyncEndpointNew, ReceiveRawMemcpyFail) +{ + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + mockReq.dataLength = NN_NO1024 - sizeof(UBSHcomNetTransHeader); + mShmSyncEp->mExistDelayEvent = true; + mShmSyncEp->mIsNeedEncrypt = false; + mShmSyncEp->mDelayHandleReceiveEvent.peerChannelAddress = reinterpret_cast(&mockReq); + mShmSyncEp->mDelayHandleReceiveEvent.opType = ShmOpContextInfo::SH_RECEIVE; + mShmSyncEp->mLastSendSeqNo = 0; + mShmSyncEp->mDelayHandleReceiveEvent.dataSize = NN_NO1024; + MOCKER_CPP(&ShmChannel::GetPeerDataAddressByOffset) + .stubs() + .will(invoke(MockGetPeerDataAddressByOffset)); + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(true)); + MOCKER_CPP(&memcpy_s) + .stubs() + .will(returnValue(1)); + MOCKER_CPP(&ShmChannel::DCMarkPeerBuckFree) + .stubs() + .will(returnValue(0)); + + EXPECT_EQ(mShmSyncEp->ReceiveRaw(timeout, ctx), NN_INVALID_PARAM); +} + +} +} \ No newline at end of file diff --git a/test/unit_test/transport/shm/test_shm_worker.cpp b/test/unit_test/transport/shm/test_shm_worker.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7e2212a253c4f0e298e2aa9c348415f7f727f4a8 --- /dev/null +++ b/test/unit_test/transport/shm/test_shm_worker.cpp @@ -0,0 +1,333 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +#include "shm_worker.h" +#include "shm_handle.h" +#include "shm_queue.h" +#include "shm_common.h" + +namespace ock { +namespace hcom { +class TestShmWorker : public testing::Test { +public: + TestShmWorker(); + virtual void SetUp(void); + virtual void TearDown(void); + ShmWorker *worker = nullptr; + ShmChannel *ch = nullptr; + std::string name = "TestShmWorker"; + UBSHcomNetWorkerIndex index{}; + ShmWorkerOptions options{}; + NetMemPoolFixedPtr opMemPool; + NetMemPoolFixedPtr opCtxMemPool; + NetMemPoolFixedPtr sglOpMemPool; +}; + +TestShmWorker::TestShmWorker() {} + +void TestShmWorker::SetUp() +{ + worker = new (std::nothrow) ShmWorker(name, index, options, opMemPool, opCtxMemPool, sglOpMemPool); + ch = new (std::nothrow) ShmChannel(name, 0, NN_NO128, NN_NO4); +} + +void TestShmWorker::TearDown() +{ + if (worker != nullptr) { + delete worker; + worker = nullptr; + } + + if (ch != nullptr) { + delete ch; + ch = nullptr; + } + GlobalMockObject::verify(); +} + +TEST_F(TestShmWorker, InitializeInitedFail) +{ + worker->mInited = true; + HResult res = worker->Initialize(); + EXPECT_EQ(res, SH_OK); +} + +TEST_F(TestShmWorker, InitializeValidateFail) +{ + worker->mInited = false; + MOCKER_CPP(&ShmWorker::Validate).stubs().will(returnValue(1)); + HResult res = worker->Initialize(); + EXPECT_EQ(res, 1); +} + +TEST_F(TestShmWorker, InitializeCreateEventQueueFail) +{ + worker->mInited = false; + MOCKER_CPP(&ShmWorker::Validate).stubs().will(returnValue(0)); + MOCKER_CPP(&ShmWorker::CreateEventQueue).stubs().will(returnValue(1)); + HResult res = worker->Initialize(); + EXPECT_EQ(res, 1); +} + +TEST_F(TestShmWorker, CreateEventQueueAlreadyCreatedFail) +{ + worker->CreateEventQueue(); + HResult res = worker->CreateEventQueue(); + EXPECT_EQ(res, SH_ERROR); + + worker->mEventQueue->DecreaseRef(); +} + +TEST_F(TestShmWorker, CreateEventQueueGetPtrFail) +{ + worker->mEventQueue = nullptr; + ShmEventQueue *eventQueueNullPtr = nullptr; + ShmHandle *handleNullPtr = nullptr; + MOCKER_CPP(&ShmEventQueuePtr::Get).stubs().will(returnValue(eventQueueNullPtr)); + + HResult res = worker->CreateEventQueue(); + EXPECT_EQ(res, SH_NEW_OBJECT_FAILED); + + MOCKER_CPP(&ShmHandlePtr::Get).stubs().will(returnValue(handleNullPtr)); + res = worker->CreateEventQueue(); + EXPECT_EQ(res, SH_NEW_OBJECT_FAILED); +} + +TEST_F(TestShmWorker, RunInThreadBusyPoll) +{ + worker->mOptions.threadPriority = 1; + worker->mOptions.mode = SHM_BUSY_POLLING; + MOCKER_CPP(&setpriority).stubs().will(returnValue(1)); + MOCKER_CPP(&ShmWorker::DoBusyPolling).stubs().will(ignoreReturnValue()); + EXPECT_NO_FATAL_FAILURE(worker->RunInThread(-1)); +} + +TEST_F(TestShmWorker, StartNotInitedErr) +{ + worker->mInited = false; + HResult res = worker->Start(); + EXPECT_EQ(res, SH_ERROR); +} + +TEST_F(TestShmWorker, StartAlreadyStarted) +{ + worker->mInited = true; + worker->mStarted = true; + + HResult res = worker->Start(); + EXPECT_EQ(res, SH_OK); +} + +TEST_F(TestShmWorker, StartNewRequestHandlerNull) +{ + worker->mInited = true; + worker->mStarted = false; + worker->mNewRequestHandler = nullptr; + HResult res = worker->Start(); + EXPECT_EQ(res, SH_PARAM_INVALID); +} + +TEST_F(TestShmWorker, StartSendPostedHandlerNull) +{ + ShmNewReqHandler shmNewReqHandler{}; + worker->mInited = true; + worker->mStarted = false; + worker->RegisterNewReqHandler(shmNewReqHandler); + worker->mSendPostedHandler = nullptr; + HResult res = worker->Start(); + EXPECT_EQ(res, SH_PARAM_INVALID); + worker->mNewRequestHandler = nullptr; +} + +TEST_F(TestShmWorker, StartOneSideDoneHandlerNull) +{ + ShmNewReqHandler shmNewReqHandler{}; + ShmPostedHandler shmPostedHandler{}; + worker->mInited = true; + worker->mStarted = false; + worker->RegisterNewReqHandler(shmNewReqHandler); + worker->RegisterReqPostedHandler(shmPostedHandler); + worker->mOneSideDoneHandler = nullptr; + HResult res = worker->Start(); + EXPECT_EQ(res, SH_PARAM_INVALID); + worker->mNewRequestHandler = nullptr; + worker->mSendPostedHandler = nullptr; +} + +TEST_F(TestShmWorker, FillSglCtxNullErr) +{ + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + + HResult res = worker->FillSglCtx(nullptr, sglReq); + EXPECT_EQ(res, SH_PARAM_INVALID); +} + +TEST_F(TestShmWorker, FillSglCtxCopyErr) +{ + ShmSglOpContextInfo sglCtx{}; + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 1); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)).then(returnValue(1)); + + HResult res = worker->FillSglCtx(&sglCtx, sglReq); + EXPECT_EQ(res, SH_PARAM_INVALID); + + res = worker->FillSglCtx(&sglCtx, sglReq); + EXPECT_EQ(res, SH_PARAM_INVALID); +} + +TEST_F(TestShmWorker, SendLocalEvent) +{ + ShmOpContextInfo::ShmOpType type = ShmOpContextInfo::ShmOpType::SH_SEND; + worker->mOptions.mode = SHM_BUSY_POLLING; + + MOCKER_CPP(&ShmEventQueue::Enqueue) + .stubs() + .will(returnValue(-1)) + .then(returnValue(1)); + + HResult res = worker->SendLocalEvent(0, ch, type); + EXPECT_EQ(res, 1); + ch = nullptr; +} + +TEST_F(TestShmWorker, SendLocalEventQueueFull) +{ + ShmOpContextInfo::ShmOpType type = ShmOpContextInfo::ShmOpType::SH_SEND; + worker->mOptions.mode = SHM_BUSY_POLLING; + + MOCKER_CPP(&ShmEventQueue::Enqueue) + .stubs() + .will(returnValue(-1)); + + worker->mDefaultTimeout = 0; + HResult res = worker->SendLocalEvent(0, ch, type); + EXPECT_EQ(res, SH_SEND_COMPLETION_CALLBACK_FAILURE); + ch = nullptr; +} + +TEST_F(TestShmWorker, PostSendRawSglParamErr) +{ + UBSHcomNetTransRequest req{}; + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + sglReq.upCtxSize = sizeof(ShmOpContextInfo::upCtx) + 1; + + HResult res = worker->PostSendRawSgl(ch, req, sglReq, 0, 0, -1); + EXPECT_EQ(res, SH_PARAM_INVALID); +} + +TEST_F(TestShmWorker, PostSendRawSglChBroken) +{ + UBSHcomNetTransRequest req{}; + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + ch->mState.Set(ShmChannelState::CH_BROKEN); + + HResult res = worker->PostSendRawSgl(ch, req, sglReq, 0, 0, -1); + EXPECT_EQ(res, SH_CH_BROKEN); +} + +TEST_F(TestShmWorker, PostSendRawSglRetryFull) +{ + UBSHcomNetTransRequest req{}; + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + MOCKER_CPP(&ShmChannel::EQEventEnqueue).stubs().will(returnValue(-1)); + + HResult res = worker->PostSendRawSgl(ch, req, sglReq, 0, 0, -1); + EXPECT_EQ(res, SH_RETRY_FULL); +} + +TEST_F(TestShmWorker, PostReadWriteParamErr) +{ + UBSHcomNetTransRequest req{}; + ShmMRHandleMap map; + ShmOpContextInfo::ShmOpType type = ShmOpContextInfo::ShmOpType::SH_READ; + req.upCtxSize = sizeof(ShmOpContextInfo::upCtx) + 1; + + HResult res = worker->PostReadWrite(ch, req, map, type); + EXPECT_EQ(res, SH_PARAM_INVALID); +} + +TEST_F(TestShmWorker, PostReadWriteChBroken) +{ + UBSHcomNetTransRequest req{}; + ShmMRHandleMap map; + ShmOpContextInfo::ShmOpType type = ShmOpContextInfo::ShmOpType::SH_READ; + ch->mState.Set(ShmChannelState::CH_BROKEN); + + HResult res = worker->PostReadWrite(ch, req, map, type); + EXPECT_EQ(res, SH_CH_BROKEN); +} + +TEST_F(TestShmWorker, RegisterIdleHandler) +{ + ShmIdleHandler h{}; + EXPECT_NO_FATAL_FAILURE(worker->RegisterIdleHandler(h)); + worker->mIdleHandler = nullptr; +} + +TEST_F(TestShmWorker, GetFinishTime) +{ + worker->mDefaultTimeout = 1; + EXPECT_NO_FATAL_FAILURE(worker->GetFinishTime()); + worker->mDefaultTimeout = 0; + EXPECT_EQ(worker->GetFinishTime(), 0); +} + +TEST_F(TestShmWorker, NeedRetry) +{ + HResult result = -1; + bool res = worker->NeedRetry(result, ch); + EXPECT_EQ(res, true); + + result = 0; + res = worker->NeedRetry(result, ch); + EXPECT_EQ(res, false); + + ch->mState.Set(ShmChannelState::CH_BROKEN); + res = worker->NeedRetry(result, ch); + EXPECT_EQ(res, false); +} + +TEST_F(TestShmWorker, PostSendParamErr) +{ + UBSHcomNetTransRequest req{}; + req.upCtxSize = sizeof(ShmOpContextInfo::upCtx) + 1; + + HResult res = worker->PostSend(ch, req, 0, 0, -1); + EXPECT_EQ(res, SH_PARAM_INVALID); +} + +TEST_F(TestShmWorker, PostSendChBroken) +{ + UBSHcomNetTransRequest req{}; + ch->mState.Set(ShmChannelState::CH_BROKEN); + + HResult res = worker->PostSend(ch, req, 0, 0, -1); + EXPECT_EQ(res, SH_CH_BROKEN); +} + +TEST_F(TestShmWorker, PostSendRetryFull) +{ + UBSHcomNetTransRequest req{}; + MOCKER_CPP(&ShmChannel::EQEventEnqueue).stubs().will(returnValue(-1)); + + HResult res = worker->PostSend(ch, req, 0, 0, -1); + EXPECT_EQ(res, SH_RETRY_FULL); +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/sock/test_net_sock_async_endpoint.cpp b/test/unit_test/transport/sock/test_net_sock_async_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72be3be7fa9ec1fee8c06e934ba82a6d915238a5 --- /dev/null +++ b/test/unit_test/transport/sock/test_net_sock_async_endpoint.cpp @@ -0,0 +1,524 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include "hcom.h" +#include "net_common.h" +#include "net_sock_driver_oob.h" +#include "sock_validation.h" +#include "net_sock_async_endpoint.h" + +namespace ock { +namespace hcom { +class TestNetSockAsyncEndpoint : public testing::Test { +public: + TestNetSockAsyncEndpoint(); + virtual void SetUp(void); + virtual void TearDown(void); + + std::string name; + std::string ip; + uint16_t port; + NetDriverSockWithOOB *mDriver = nullptr; + SockWorker *mWorker = nullptr; + Sock *sock = nullptr; + UBSHcomNetWorkerIndex mWorkerIndex; + NetAsyncEndpointSock *ep = nullptr; + UBSHcomNetTransRequest request; +}; + +TestNetSockAsyncEndpoint::TestNetSockAsyncEndpoint() {} + +void TestNetSockAsyncEndpoint::SetUp() +{ + bool startOobSvr = true; + UBSHcomNetDriverProtocol protocol = TCP; + mDriver = new (std::nothrow) NetDriverSockWithOOB(name, startOobSvr, protocol, SOCK_TCP); + mDriver->mStarted = true; + + SockWorkerOptions options; + NetMemPoolFixedPtr memPool; + NetMemPoolFixedPtr sglMemPool; + NetMemPoolFixedPtr headerReqMemPool; + UBSHcomNetWorkerIndex index; + mWorker = new (std::nothrow) SockWorker(SOCK_TCP, name, index, memPool, sglMemPool, headerReqMemPool, options); + ASSERT_NE(mWorker, nullptr); + + uint32_t sockId = NN_NO100; + uint32_t mid = 0; + SockOptions sockOptions; + sock = new (std::nothrow) Sock(SOCK_TCP, name, sockId, -1, sockOptions); + ASSERT_NE(sock, nullptr); + + ep = new (std::nothrow) NetAsyncEndpointSock(sockId, sock, mDriver, index); + ASSERT_NE(ep, nullptr); + + ep->mState.Set(NEP_ESTABLISHED); + ep->mAllowedSize = NN_NO128; + ep->mSegSize = NN_NO128; + + request.lAddress = reinterpret_cast(&mWorkerIndex); + request.size = 1; +} + +void TestNetSockAsyncEndpoint::TearDown() +{ + if (ep != nullptr) { + delete ep; + ep = nullptr; + } + if (mWorker != nullptr) { + delete mWorker; + mWorker = nullptr; + } + + GlobalMockObject::verify(); +} + +static UBSHcomNetTransHeader mockMrBuf{}; +static bool MockGetFreeBuffer(uintptr_t &mrBufAddress) +{ + mrBufAddress = reinterpret_cast(&mockMrBuf); + return true; +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPeerIpAndPortErr) +{ + ep->mSock = nullptr; + std::string ret = ep->PeerIpAndPort(); + EXPECT_EQ(ret, CONST_EMPTY_STRING); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncUdsNameErr) +{ + std::string ret = ep->UdsName(); + EXPECT_EQ(ret, CONST_EMPTY_STRING); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncGetRemoteUdsIdInfo) +{ + ep->mState.Set(NEP_BROKEN); + ep->mDriver->mStartOobSvr = true; + ep->mDriver->mOptions.oobType = NET_OOB_UDS; + ep->mRemoteUdsIdInfo.gid = NN_NO1024; + ep->mRemoteUdsIdInfo.pid = NN_NO1024; + ep->mRemoteUdsIdInfo.uid = NN_NO1024; + UBSHcomNetUdsIdInfo sockIdInfo{}; + NResult ret = ep->GetRemoteUdsIdInfo(sockIdInfo); + EXPECT_EQ(ret, NN_OK); + + ep->mRemoteUdsIdInfo.gid = 0; + ep->mRemoteUdsIdInfo.pid = 0; + ep->mRemoteUdsIdInfo.uid = 0; + ret = ep->GetRemoteUdsIdInfo(sockIdInfo); + EXPECT_EQ(ret, NN_ERROR); + + ep->mDriver->mOptions.oobType = NET_OOB_TCP; + ret = ep->GetRemoteUdsIdInfo(sockIdInfo); + EXPECT_EQ(ret, NN_UDS_ID_INFO_NOT_SUPPORT); + + ep->mDriver->mStartOobSvr = false; + ret = ep->GetRemoteUdsIdInfo(sockIdInfo); + EXPECT_EQ(ret, NN_UDS_ID_INFO_NOT_SUPPORT); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncGetPeerIpPort) +{ + std::string ip; + uint16_t port; + ep->mSock->mPeerIpPort = ""; + bool ret = ep->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + + ep->mSock->mPeerIpPort = "1.2.3.4"; + ret = ep->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + + ep->mSock->mPeerIpPort = "1.2.3.4:test"; + ret = ep->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + + ep->mSock->mPeerIpPort = "1.2.3.4:0"; + ret = ep->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + + ep->mSock->mPeerIpPort = "1.2.3.4:16"; + ret = ep->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, true); + + ep->mSock = nullptr; + ret = ep->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncSetEpOptionReSetErr) +{ + UBSHcomEpOptions options{}; + options.tcpBlockingIo = false; + NResult ret = ep->SetEpOption(options); + EXPECT_EQ(ret, NN_OK); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncSetEpOptionTimeoutErr) +{ + UBSHcomEpOptions options{}; + options.tcpBlockingIo = true; + options.sendTimeout = NN_NO2; + ep->mDefaultTimeout = NN_NO1; + NResult ret = ep->SetEpOption(options); + EXPECT_EQ(ret, NN_ERROR); + ep->mDefaultTimeout = -1; +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncSetEpOptionSetTimeoutErr) +{ + UBSHcomEpOptions options{}; + options.tcpBlockingIo = true; + MOCKER_CPP(&Sock::SetBlockingIo, SResult(Sock::*)(UBSHcomEpOptions &)).stubs().will(returnValue(1)); + NResult ret = ep->SetEpOption(options); + EXPECT_EQ(ret, NN_ERROR); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncGetSendQueueCount) +{ + MOCKER_CPP(&Sock::GetSendQueueCount).stubs().will(returnValue(1)); + uint32_t ret = ep->GetSendQueueCount(); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendZCopy) +{ + UBSHcomNetTransOpInfo OpInfo{}; + MOCKER_CPP(&SockWorker::PostSend).stubs() + .will(returnValue(static_cast(SS_TCP_RETRY))) + .then(returnValue(1)) + .then(returnValue(0)); + + int ret = ep->PostSendZCopy(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = ep->PostSendZCopy(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendSeqZCopy) +{ + ep->mSendZCopy = true; + MOCKER_CPP(&NetAsyncEndpointSock::PostSendZCopy).stubs().will(returnValue(0)); + int ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendSeqValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + int ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendSeqValidateBuffFail) +{ + request.upCtxSize = sizeof(SockOpContextInfo::upCtx) + 1; + int ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + request.upCtxSize = 0; +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendSeqGetBuffErr) +{ + ep->mSendZCopy = false; + MOCKER_CPP(&NormalMemoryRegionFixedBuffer::GetFreeBuffer, bool(NormalMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs().will(returnValue(false)); + int ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_GET_BUFF_FAILED)); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendSeqCopyErr) +{ + ep->mSendZCopy = false; + NormalMemoryRegionFixedBuffer *Mr = mDriver->mSockDriverSendMR + = new (std::nothrow)NormalMemoryRegionFixedBuffer(name, 1, 1); + MOCKER_CPP(&NormalMemoryRegionFixedBuffer::GetFreeBuffer, bool(NormalMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs().will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + MOCKER_CPP(&NormalMemoryRegionFixedBuffer::ReturnBuffer, bool(NormalMemoryRegionFixedBuffer::*)(uintptr_t)) + .stubs().will(returnValue(true)); + int ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + + if (Mr != nullptr) { + delete Mr; + Mr = nullptr; + } +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendOpInfoValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + UBSHcomNetTransOpInfo opInfo{}; + int ret = ep->PostSend(0, request, opInfo); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendOpInfoValidateBuffFail) +{ + request.size = 0; + UBSHcomNetTransOpInfo opInfo{}; + int ret = ep->PostSend(0, request, opInfo); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + request.size = 1; +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendOpInfoZCopy) +{ + ep->mSendZCopy = true; + UBSHcomNetTransOpInfo opInfo{}; + MOCKER_CPP(&NetAsyncEndpointSock::PostSendZCopy).stubs().will(returnValue(0)); + int ret = ep->PostSend(0, request, opInfo); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendOpInfoGetBuffErr) +{ + ep->mSendZCopy = false; + UBSHcomNetTransOpInfo opInfo{}; + MOCKER_CPP(&NormalMemoryRegionFixedBuffer::GetFreeBuffer, bool(NormalMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs().will(returnValue(false)); + int ret = ep->PostSend(0, request, opInfo); + EXPECT_EQ(ret, static_cast(NN_GET_BUFF_FAILED)); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendOpInfoCopyErr) +{ + ep->mSendZCopy = false; + UBSHcomNetTransOpInfo opInfo{}; + NormalMemoryRegionFixedBuffer *Mr = mDriver->mSockDriverSendMR + = new (std::nothrow)NormalMemoryRegionFixedBuffer(name, 1, 1); + MOCKER_CPP(&NormalMemoryRegionFixedBuffer::GetFreeBuffer, bool(NormalMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs().will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + MOCKER_CPP(&NormalMemoryRegionFixedBuffer::ReturnBuffer, bool(NormalMemoryRegionFixedBuffer::*)(uintptr_t)) + .stubs().will(returnValue(true)); + int ret = ep->PostSend(0, request, opInfo); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + + if (Mr != nullptr) { + delete Mr; + Mr = nullptr; + } +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendRawValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + int ret = ep->PostSendRaw(request, 0); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendRawValidateBuffFail) +{ + request.size = 0; + int ret = ep->PostSendRaw(request, 0); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + request.size = 1; +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendRawSeqZCopy) +{ + ep->mSendZCopy = true; + MOCKER_CPP(&NetAsyncEndpointSock::PostSendZCopy).stubs().will(returnValue(0)); + int ret = ep->PostSendRaw(request, 0); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendRawSeqGetBuffErr) +{ + ep->mSendZCopy = false; + MOCKER_CPP(&NormalMemoryRegionFixedBuffer::GetFreeBuffer, bool(NormalMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs().will(returnValue(false)); + int ret = ep->PostSendRaw(request, 0); + EXPECT_EQ(ret, static_cast(NN_GET_BUFF_FAILED)); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendRawSeqCopyErr) +{ + ep->mSendZCopy = false; + NormalMemoryRegionFixedBuffer *Mr = mDriver->mSockDriverSendMR + = new (std::nothrow)NormalMemoryRegionFixedBuffer(name, 1, 1); + MOCKER_CPP(&NormalMemoryRegionFixedBuffer::GetFreeBuffer, bool(NormalMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs().will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + MOCKER_CPP(&NormalMemoryRegionFixedBuffer::ReturnBuffer, bool(NormalMemoryRegionFixedBuffer::*)(uintptr_t)) + .stubs().will(returnValue(true)); + int ret = ep->PostSendRaw(request, 0); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + + if (Mr != nullptr) { + delete Mr; + Mr = nullptr; + } +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendRawSglValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + int ret = ep->PostSendRawSgl(sglReq, 0); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostSendRawSglValidateBuffFail) +{ + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + sglReq.upCtxSize = sizeof(SockOpContextInfo::upCtx) + 1; + int ret = ep->PostSendRawSgl(sglReq, 0); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostReadValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + int ret = ep->PostRead(request); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostReadValidateBuffFail) +{ + request.size = 0; + int ret = ep->PostRead(request); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + request.size = 1; +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostReadOneSideValidateFail) +{ + MOCKER_CPP(&NetDriverSockWithOOB::ValidateMemoryRegion).stubs().will(returnValue(1)); + int ret = ep->PostRead(request); + EXPECT_EQ(ret, static_cast(NN_INVALID_LKEY)); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostReadSglValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + int ret = ep->PostRead(sglReq); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostReadSglOneSideSglValidateFail) +{ + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + sglReq.upCtxSize = sizeof(SockOpContextInfo::upCtx) + 1; + int ret = ep->PostRead(sglReq); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostWriteValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + int ret = ep->PostWrite(request); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostWriteValidateBuffFail) +{ + request.size = 0; + int ret = ep->PostWrite(request); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + request.size = 1; +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostWriteOneSideValidateFail) +{ + MOCKER_CPP(&NetDriverSockWithOOB::ValidateMemoryRegion).stubs().will(returnValue(1)); + int ret = ep->PostWrite(request); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostWriteSglValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + int ret = ep->PostWrite(sglReq); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncPostWriteSglOneSideSglValidateFail) +{ + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + sglReq.upCtxSize = sizeof(SockOpContextInfo::upCtx) + 1; + int ret = ep->PostWrite(sglReq); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncEnableSendZCopy) +{ + EXPECT_NO_FATAL_FAILURE(ep->EnableSendZCopy()); + ep->mSendZCopy = false; +} + +TEST_F(TestNetSockAsyncEndpoint, AsyncGetFinishTime) +{ + ep->mDefaultTimeout = 0; + uint64_t ret = ep->GetFinishTime(); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestNetSockAsyncEndpoint, StateValidationFail) +{ + int ret = StateValidation(ep->mState, 0, nullptr, nullptr); + EXPECT_EQ(ret, static_cast(NN_ERROR)); +} + +TEST_F(TestNetSockAsyncEndpoint, TwoSideSglValidationFail) +{ + size_t totalSize = 0; + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + sglReq.iovCount = 0; + int ret = TwoSideSglValidation(sglReq, mDriver, 1, totalSize); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); +} + +TEST_F(TestNetSockAsyncEndpoint, TwoSideSglValidationLkeyFail) +{ + size_t totalSize = 0; + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + MOCKER_CPP(&NetDriverSockWithOOB::ValidateMemoryRegion).stubs().will(returnValue(1)); + int ret = TwoSideSglValidation(sglReq, mDriver, 1, totalSize); + EXPECT_EQ(ret, static_cast(NN_INVALID_LKEY)); +} + +TEST_F(TestNetSockAsyncEndpoint, OneSideValidationFail) +{ + request.upCtxSize = sizeof(SockOpContextInfo::upCtx) + 1; + int ret = OneSideValidation(request, mDriver); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + request.upCtxSize = 0; +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/sock/test_net_sock_driver_oob.cpp b/test/unit_test/transport/sock/test_net_sock_driver_oob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e8f24968d24e3e260ac3e7db7e5d767ec3925c6d --- /dev/null +++ b/test/unit_test/transport/sock/test_net_sock_driver_oob.cpp @@ -0,0 +1,406 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include +#include "net_sock_driver_oob.h" +#include "net_sock_async_endpoint.h" +#include "net_oob_secure.h" + +namespace ock { +namespace hcom { +class TestNetSockDriverOob : public testing::Test { +public: + std::string name; + NetDriverSockWithOOB *mDriver = nullptr; + Sock *sock = nullptr; + UBSHcomNetDriverOptions option{}; + + TestNetSockDriverOob(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +TestNetSockDriverOob::TestNetSockDriverOob() {} + +void TestNetSockDriverOob::SetUp() +{ + bool startOobSvr = true; + UBSHcomNetDriverProtocol protocol = TCP; + mDriver = new (std::nothrow) NetDriverSockWithOOB(name, startOobSvr, protocol, SOCK_TCP); + mDriver->mStarted = true; + + SockOptions sockOptions; + sock = new (std::nothrow) Sock(SOCK_TCP, name, NN_NO100, -1, sockOptions); + ASSERT_NE(sock, nullptr); +} + +void TestNetSockDriverOob::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestNetSockDriverOob, InitializeNetOutLoggerInstanceErr) +{ + mDriver->mInited = false; + UBSHcomNetOutLogger *logger = nullptr; + MOCKER_CPP(&UBSHcomNetDriverOptions::ValidateCommonOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&UBSHcomNetOutLogger::Instance).stubs().will(returnValue(logger)); + EXPECT_EQ(mDriver->Initialize(option), NN_NOT_INITIALIZED); +} + +TEST_F(TestNetSockDriverOob, ValidateOptionsOobTypeErr) +{ + mDriver->mOptions.oobType = NET_OOB_UB; + EXPECT_EQ(mDriver->ValidateOptionsOobType(), NN_INVALID_PARAM); +} + +TEST_F(TestNetSockDriverOob, InitializeLoadOpensslErr) +{ + mDriver->mInited = false; + MOCKER_CPP(&UBSHcomNetDriverOptions::ValidateCommonOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::ValidateOptions).stubs().will(returnValue(0)); + option.enableTls = true; + MOCKER_CPP(HcomSsl::Load).stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->Initialize(option), NN_NOT_INITIALIZED); + option.enableTls = false; +} + +TEST_F(TestNetSockDriverOob, InitializeCreateWorkerResourceErr) +{ + mDriver->mInited = false; + MOCKER_CPP(&UBSHcomNetDriverOptions::ValidateCommonOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::CreateWorkerResource).stubs().will(returnValue(1)); + MOCKER_CPP(&NetDriverSockWithOOB::UnInitializeInner).stubs().will(ignoreReturnValue()); + EXPECT_EQ(mDriver->Initialize(option), 1); +} + +TEST_F(TestNetSockDriverOob, InitializeCreateClientLBErr) +{ + mDriver->mInited = false; + MOCKER_CPP(&UBSHcomNetDriverOptions::ValidateCommonOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::CreateWorkerResource).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::CreateWorkers).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::CreateClientLB).stubs().will(returnValue(1)); + MOCKER_CPP(&NetDriverSockWithOOB::UnInitializeInner).stubs().will(ignoreReturnValue()); + EXPECT_EQ(mDriver->Initialize(option), 1); +} + +TEST_F(TestNetSockDriverOob, InitializeCreateListenersErr) +{ + mDriver->mInited = false; + mDriver->mStartOobSvr = true; + MOCKER_CPP(&UBSHcomNetDriverOptions::ValidateCommonOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::CreateWorkerResource).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::CreateWorkers).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::CreateClientLB).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::CreateListeners).stubs().will(returnValue(1)); + MOCKER_CPP(&NetDriverSockWithOOB::UnInitializeInner).stubs().will(ignoreReturnValue()); + EXPECT_EQ(mDriver->Initialize(option), 1); +} + +TEST_F(TestNetSockDriverOob, UnInitializeStartedErr) +{ + mDriver->mInited = false; + mDriver->mStarted = true; + EXPECT_NO_FATAL_FAILURE(mDriver->UnInitialize()); +} + +TEST_F(TestNetSockDriverOob, ValidateOptionsArrErr) +{ + MOCKER_CPP(ValidateArrayOptions).stubs().will(returnValue(false)); + EXPECT_EQ(mDriver->ValidateOptions(), NN_INVALID_PARAM); +} + +TEST_F(TestNetSockDriverOob, ValidateOptionsParamErr) +{ + MOCKER_CPP(ValidateArrayOptions).stubs().will(returnValue(true)); + mDriver->mOptions.tcpSendBufSize = NN_NO10000; + EXPECT_EQ(mDriver->ValidateOptions(), NN_INVALID_PARAM); + + mDriver->mOptions.tcpSendBufSize = NN_NO1024; + mDriver->mOptions.tcpReceiveBufSize = NN_NO10000; + EXPECT_EQ(mDriver->ValidateOptions(), NN_INVALID_PARAM); +} + +TEST_F(TestNetSockDriverOob, ValidateOptionsErrTwo) +{ + MOCKER_CPP(ValidateArrayOptions).stubs().will(returnValue(true)); + mDriver->mSockType = SOCK_UDS; + MOCKER_CPP(&UBSHcomNetDriver::ValidateAndParseOobPortRange).stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(mDriver->ValidateOptions(), NN_INVALID_PARAM); + + MOCKER_CPP(&UBSHcomNetDriver::ValidateOptionsOobType).stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->ValidateOptions(), NN_INVALID_PARAM); +} + +TEST_F(TestNetSockDriverOob, CreateWorkerResourceOpCtxMemPoolErr) +{ + MOCKER_CPP(&NetDriverSockWithOOB::CreateOpCtxMemPool).stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->CreateWorkerResource(), 1); +} + +TEST_F(TestNetSockDriverOob, CreateWorkerResourceSglCtxMemPoolErr) +{ + MOCKER_CPP(&NetDriverSockWithOOB::CreateOpCtxMemPool).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::CreateSglCtxMemPool).stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->CreateWorkerResource(), 1); +} + +TEST_F(TestNetSockDriverOob, CreateWorkerResourceHeaderReqMemPoolErr) +{ + mDriver->mOptions.tcpSendZCopy = true; + MOCKER_CPP(&NetDriverSockWithOOB::CreateOpCtxMemPool).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::CreateSglCtxMemPool).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::CreateHeaderReqMemPool).stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->CreateWorkerResource(), 1); +} + +TEST_F(TestNetSockDriverOob, CreateWorkerResourceSendMrErr) +{ + mDriver->mOptions.tcpSendZCopy = false; + MOCKER_CPP(&NetDriverSockWithOOB::CreateOpCtxMemPool).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::CreateSglCtxMemPool).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverSockWithOOB::CreateSendMr).stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->CreateWorkerResource(), 1); +} + +TEST_F(TestNetSockDriverOob, CreateOpCtxMemPoolErr) +{ + MOCKER_CPP(&NetMemPoolFixed::Initialize).stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->CreateOpCtxMemPool(), 1); +} + +TEST_F(TestNetSockDriverOob, CreateOpCtxMemPoolNullErr) +{ + NetMemPoolFixed *testOpCtxMemPool = nullptr; + MOCKER_CPP(&NetRef::Get).stubs().will(returnValue(testOpCtxMemPool)); + EXPECT_EQ(mDriver->CreateOpCtxMemPool(), NN_INVALID_PARAM); +} + +TEST_F(TestNetSockDriverOob, CreateSglCtxMemPoolErr) +{ + MOCKER_CPP(&NetMemPoolFixed::Initialize).stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->CreateSglCtxMemPool(), 1); +} + +TEST_F(TestNetSockDriverOob, CreateSglCtxMemPoolNullErr) +{ + NetMemPoolFixed *testSglCtxMemPool = nullptr; + MOCKER_CPP(&NetRef::Get).stubs().will(returnValue(testSglCtxMemPool)); + EXPECT_EQ(mDriver->CreateSglCtxMemPool(), NN_INVALID_PARAM); +} + +TEST_F(TestNetSockDriverOob, CreateHeaderReqMemPoolErr) +{ + MOCKER_CPP(&NetMemPoolFixed::Initialize).stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->CreateHeaderReqMemPool(), 1); +} + +TEST_F(TestNetSockDriverOob, CreateHeaderReqMemPoolNullErr) +{ + NetMemPoolFixed *testHeaderReqMemPool = nullptr; + MOCKER_CPP(&NetRef::Get).stubs().will(returnValue(testHeaderReqMemPool)); + EXPECT_EQ(mDriver->CreateHeaderReqMemPool(), NN_INVALID_PARAM); +} + +TEST_F(TestNetSockDriverOob, CreateHeaderReqMemPoolSuccess) +{ + MOCKER_CPP(&NetMemPoolFixed::Initialize).stubs().will(returnValue(0)); + EXPECT_EQ(mDriver->CreateHeaderReqMemPool(), 0); + mDriver->mHeaderReqMemPool.Set(nullptr); +} + +TEST_F(TestNetSockDriverOob, CreateSendMrErr) +{ + MOCKER_CPP(&NormalMemoryRegionFixedBuffer::Create).stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->CreateSendMr(), 1); +} + +TEST_F(TestNetSockDriverOob, CreateMemoryRegionErr) +{ + UBSHcomNetMemoryRegionPtr mr = nullptr; + mDriver->mInited = true; + MOCKER_CPP(NormalMemoryRegion::Create, + NResult(const std::string &, uint64_t, NormalMemoryRegion *&)) + .stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->CreateMemoryRegion(NN_NO8, mr), 1); +} + +TEST_F(TestNetSockDriverOob, CreateMemoryRegionInitErr) +{ + UBSHcomNetMemoryRegionPtr mr = nullptr; + mDriver->mInited = true; + auto tmpBuf = new (std::nothrow) NormalMemoryRegion(name, false, 0, NN_NO8); + MOCKER(NormalMemoryRegion::Create, NResult(const std::string &, uint64_t, NormalMemoryRegion *&)).stubs() + .with(any(), any(), outBound(tmpBuf)) + .will(returnValue(0)); + MOCKER_CPP_VIRTUAL(*tmpBuf, &NormalMemoryRegion::Initialize).stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->CreateMemoryRegion(NN_NO8, mr), 1); +} + +TEST_F(TestNetSockDriverOob, CreateMemoryRegionRegisterErr) +{ + UBSHcomNetMemoryRegionPtr mr = nullptr; + mDriver->mInited = true; + auto tmpBuf = new (std::nothrow) NormalMemoryRegion(name, false, 0, NN_NO8); + MOCKER(NormalMemoryRegion::Create, NResult(const std::string &, uint64_t, NormalMemoryRegion *&)).stubs() + .with(any(), any(), outBound(tmpBuf)) + .will(returnValue(0)); + MOCKER_CPP_VIRTUAL(*tmpBuf, &NormalMemoryRegion::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(MemoryRegionChecker::Register).stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->CreateMemoryRegion(NN_NO8, mr), 1); +} + +TEST_F(TestNetSockDriverOob, CreateMemoryRegionAddressErr) +{ + UBSHcomNetMemoryRegionPtr mr = nullptr; + uintptr_t address = 1; + mDriver->mInited = true; + MOCKER_CPP(NormalMemoryRegion::Create, + NResult(const std::string &, uintptr_t, uint64_t, NormalMemoryRegion *&)) + .stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->CreateMemoryRegion(address, NN_NO8, mr), 1); +} + +TEST_F(TestNetSockDriverOob, CreateMemoryRegionAddressInitErr) +{ + UBSHcomNetMemoryRegionPtr mr = nullptr; + uintptr_t address = 1; + mDriver->mInited = true; + auto tmpBuf = new (std::nothrow) NormalMemoryRegion(name, true, address, NN_NO8); + MOCKER(NormalMemoryRegion::Create, NResult(const std::string &, uintptr_t, uint64_t, NormalMemoryRegion *&)).stubs() + .with(any(), any(), any(), outBound(tmpBuf)) + .will(returnValue(0)); + MOCKER_CPP_VIRTUAL(*tmpBuf, &NormalMemoryRegion::Initialize).stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->CreateMemoryRegion(address, NN_NO8, mr), 1); +} + +TEST_F(TestNetSockDriverOob, CreateMemoryRegionAddressRegisterErr) +{ + UBSHcomNetMemoryRegionPtr mr = nullptr; + uintptr_t address = 1; + mDriver->mInited = true; + auto tmpBuf = new (std::nothrow) NormalMemoryRegion(name, true, address, NN_NO8); + MOCKER(NormalMemoryRegion::Create, NResult(const std::string &, uintptr_t, uint64_t, NormalMemoryRegion *&)).stubs() + .with(any(), any(), any(), outBound(tmpBuf)) + .will(returnValue(0)); + MOCKER_CPP_VIRTUAL(*tmpBuf, &NormalMemoryRegion::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(MemoryRegionChecker::Register).stubs().will(returnValue(1)); + EXPECT_EQ(mDriver->CreateMemoryRegion(address, NN_NO8, mr), 1); +} + +TEST_F(TestNetSockDriverOob, CreateMemoryRegionMemidErr) +{ + UBSHcomNetMemoryRegionPtr mr = nullptr; + unsigned long memid = 1; + EXPECT_EQ(mDriver->CreateMemoryRegion(0, mr, memid), NN_ERROR); +} + +TEST_F(TestNetSockDriverOob, MultiRailNewConnectionErr) +{ + OOBTCPConnection conn(-1); + EXPECT_EQ(mDriver->MultiRailNewConnection(conn), NN_ERROR); +} + +TEST_F(TestNetSockDriverOob, MapAndRegVaForUBErr) +{ + unsigned long memid = 1; + uint64_t va = 0; + EXPECT_EQ(mDriver->MapAndRegVaForUB(memid, va), nullptr); +} + +TEST_F(TestNetSockDriverOob, UnmapVaForUBErr) +{ + uint64_t va = 0; + EXPECT_EQ(mDriver->UnmapVaForUB(va), NN_ERROR); +} + +TEST_F(TestNetSockDriverOob, HandleSockRealConnectWithDupId) +{ + SockOpContextInfo ctx {}; + ctx.sock = sock; + sock->IncreaseRef(); + + UBSHcomNetWorkerIndex index; + NetAsyncEndpointSock *ep = new (std::nothrow) NetAsyncEndpointSock(sock->mId, sock, mDriver, index); + ASSERT_NE(ep, nullptr); + mDriver->mEndPoints.emplace(sock->mId, ep); + ep->IncreaseRef(); + + auto ret = mDriver->HandleSockRealConnect(ctx); + EXPECT_EQ(ret, NN_ERROR); + + mDriver->mEndPoints.erase(sock->mId); + ep->DecreaseRef(); + sock->DecreaseRef(); +} + +TEST_F(TestNetSockDriverOob, HandleSockRealConnectWithOverPayload) +{ + SockOpContextInfo ctx {}; + ctx.sock = sock; + sock->IncreaseRef(); + UBSHcomNetTransHeader mockHeader {}; + mockHeader.dataLength = NN_NO2048; + ctx.header = &mockHeader; + mDriver->mOptions.magic = 0; + mDriver->mEnableTls = false; + + SockWorkerOptions options; + NetMemPoolFixedPtr memPool; + NetMemPoolFixedPtr sglMemPool; + NetMemPoolFixedPtr headerReqMemPool; + UBSHcomNetWorkerIndex index; + SockWorker *worker = + new (std::nothrow) SockWorker(SOCK_TCP, name, index, memPool, sglMemPool, headerReqMemPool, options); + ASSERT_NE(worker, nullptr); + ctx.sock->UpContext1(reinterpret_cast(worker)); + worker->IncreaseRef(); + + auto ret = mDriver->HandleSockRealConnect(ctx); + EXPECT_EQ(ret, NN_EP_CLOSE); + + sock->DecreaseRef(); + worker->DecreaseRef(); +} + +TEST_F(TestNetSockDriverOob, HandleSockErrorWithErrUpCtx) +{ + SockWorkerOptions options; + NetMemPoolFixedPtr memPool; + NetMemPoolFixedPtr sglMemPool; + NetMemPoolFixedPtr headerReqMemPool; + UBSHcomNetWorkerIndex index; + SockWorker *worker = + new (std::nothrow) SockWorker(SOCK_TCP, name, index, memPool, sglMemPool, headerReqMemPool, options); + ASSERT_NE(worker, nullptr); + sock->UpContext1(reinterpret_cast(worker)); + worker->IncreaseRef(); + + auto ret = mDriver->HandleSockError(sock); + EXPECT_EQ(ret, NN_ERROR); + worker->DecreaseRef(); +} + +TEST_F(TestNetSockDriverOob, GetConnRespOther) +{ + SockType t = SOCK_UDS_TCP; + EXPECT_EQ(mDriver->GetConnResp(t), OK); +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/sock/test_net_sock_sync_endpoint.cpp b/test/unit_test/transport/sock/test_net_sock_sync_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8832a60377cf886f1fa4f77384bb1969245236c5 --- /dev/null +++ b/test/unit_test/transport/sock/test_net_sock_sync_endpoint.cpp @@ -0,0 +1,572 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include +#include +#include "hcom.h" +#include "net_common.h" +#include "net_sock_driver_oob.h" +#include "net_security_rand.h" +#include "sock_validation.h" +#include "net_sock_sync_endpoint.h" + +namespace ock { +namespace hcom { +class TestNetSockSyncEndpoint : public testing::Test { +public: + TestNetSockSyncEndpoint(); + virtual void SetUp(void); + virtual void TearDown(void); + + std::string name; + std::string ip; + uint16_t port; + NetDriverSockWithOOB *mDriver = nullptr; + SockWorker *mWorker = nullptr; + Sock *sock = nullptr; + UBSHcomNetWorkerIndex mWorkerIndex; + NetSyncEndpointSock *ep = nullptr; + UBSHcomNetTransRequest request; + UBSHcomNetTransSglRequest sglRequest; + UBSHcomNetTransSgeIov *iov = nullptr; +}; + +TestNetSockSyncEndpoint::TestNetSockSyncEndpoint() {} + +void TestNetSockSyncEndpoint::SetUp() +{ + bool startOobSvr = true; + UBSHcomNetDriverProtocol protocol = TCP; + mDriver = new (std::nothrow) NetDriverSockWithOOB(name, startOobSvr, protocol, SOCK_TCP); + mDriver->mStarted = true; + + SockWorkerOptions options; + NetMemPoolFixedPtr memPool; + NetMemPoolFixedPtr sglMemPool; + NetMemPoolFixedPtr headerReqMemPool; + UBSHcomNetWorkerIndex index; + mWorker = new (std::nothrow) SockWorker(SOCK_TCP, name, index, memPool, sglMemPool, headerReqMemPool, options); + ASSERT_NE(mWorker, nullptr); + + uint32_t sockId = NN_NO100; + uint32_t mid = 0; + SockOptions sockOptions; + sock = new (std::nothrow) Sock(SOCK_TCP, name, sockId, -1, sockOptions); + ASSERT_NE(sock, nullptr); + + ep = new (std::nothrow) NetSyncEndpointSock(sockId, sock, mDriver, index); + ASSERT_NE(ep, nullptr); + + ep->mState.Set(NEP_ESTABLISHED); + ep->mAllowedSize = NN_NO128; + ep->mSegSize = NN_NO128; + + request.lAddress = reinterpret_cast(&mWorkerIndex); + request.size = 1; + iov = new (std::nothrow) UBSHcomNetTransSgeIov(); + sglRequest = UBSHcomNetTransSglRequest(iov, 1, 1); +} + +void TestNetSockSyncEndpoint::TearDown() +{ + if (iov != nullptr) { + delete iov; + iov = nullptr; + } + if (ep != nullptr) { + delete ep; + ep = nullptr; + } + if (mWorker != nullptr) { + delete mWorker; + mWorker = nullptr; + } + + GlobalMockObject::verify(); +} + +UBSHcomNetTransHeader mockHeader; +TEST_F(TestNetSockSyncEndpoint, SyncSetEpOptionTimeoutErr) +{ + UBSHcomEpOptions options{}; + options.sendTimeout = NN_NO2; + ep->mDefaultTimeout = NN_NO1; + NResult ret = ep->SetEpOption(options); + EXPECT_EQ(ret, NN_ERROR); + ep->mDefaultTimeout = -1; +} + +TEST_F(TestNetSockSyncEndpoint, SyncSetEpOptionSetTimeoutErr) +{ + UBSHcomEpOptions options{}; + MOCKER_CPP(&Sock::SetBlockingSendTimeout).stubs().will(returnValue(1)); + NResult ret = ep->SetEpOption(options); + EXPECT_EQ(ret, NN_ERROR); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostSendSeqValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + int ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostSendSeqValidateBuffFail) +{ + request.upCtxSize = sizeof(SockOpContextInfo::upCtx) + 1; + int ret = ep->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + request.upCtxSize = 0; +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostSendOpInfoValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + UBSHcomNetTransOpInfo opInfo{}; + int ret = ep->PostSend(0, request, opInfo); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostSendOpInfoValidateBuffFail) +{ + request.size = 0; + UBSHcomNetTransOpInfo opInfo{}; + int ret = ep->PostSend(0, request, opInfo); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + request.size = 1; +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostSendRawValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + int ret = ep->PostSendRaw(request, 0); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostSendRawValidateBuffFail) +{ + request.size = 0; + int ret = ep->PostSendRaw(request, 0); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + request.size = 1; +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostSendRawSglValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + int ret = ep->PostSendRawSgl(sglReq, 0); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostSendRawSglValidateBuffFail) +{ + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + sglReq.upCtxSize = sizeof(SockOpContextInfo::upCtx) + 1; + int ret = ep->PostSendRawSgl(sglReq, 0); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostReadValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + int ret = ep->PostRead(request); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostReadValidateBuffFail) +{ + request.size = 0; + int ret = ep->PostRead(request); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + request.size = 1; +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostReadOneSideValidateFail) +{ + MOCKER_CPP(&NetDriverSockWithOOB::ValidateMemoryRegion).stubs().will(returnValue(1)); + int ret = ep->PostRead(request); + EXPECT_EQ(ret, static_cast(NN_INVALID_LKEY)); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostReadCtxInfoErr) +{ + SockOpContextInfo *ctxInfo = nullptr; + MOCKER_CPP(&SockOpContextInfoPool::Get).stubs().will(returnValue(ctxInfo)); + MOCKER_CPP(&NetDriverSockWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + NResult ret = ep->PostRead(request); + EXPECT_EQ(ret, SS_CTX_FULL); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostReadCtxSglInfoErr) +{ + SockOpContextInfo *ctxInfo = new (std::nothrow) SockOpContextInfo(); + SockSglContextInfo *sglCtxInfo = nullptr; + MOCKER_CPP(&SockOpContextInfoPool::Get).stubs().will(returnValue(ctxInfo)); + MOCKER_CPP(&SockSglContextInfoPool::Get).stubs().will(returnValue(sglCtxInfo)); + MOCKER_CPP(&SockOpContextInfoPool::Return).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&NetDriverSockWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + NResult ret = ep->PostRead(request); + EXPECT_EQ(ret, SS_CTX_FULL); + + if (ctxInfo != nullptr) { + delete ctxInfo; + ctxInfo = nullptr; + } +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostReadSglValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + int ret = ep->PostRead(sglReq); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostReadSglOneSideSglValidateFail) +{ + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + sglReq.upCtxSize = sizeof(SockOpContextInfo::upCtx) + 1; + int ret = ep->PostRead(sglReq); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostReadSglCtxInfoErr) +{ + SockOpContextInfo *ctxInfo = nullptr; + MOCKER_CPP(&SockOpContextInfoPool::Get).stubs().will(returnValue(ctxInfo)); + MOCKER_CPP(&NetDriverSockWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + NResult ret = ep->PostRead(sglRequest); + EXPECT_EQ(ret, SS_CTX_FULL); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostReadSglCtxSglInfoErr) +{ + SockOpContextInfo *ctxInfo = new (std::nothrow) SockOpContextInfo(); + SockSglContextInfo *sglCtxInfo = nullptr; + MOCKER_CPP(&SockOpContextInfoPool::Get).stubs().will(returnValue(ctxInfo)); + MOCKER_CPP(&SockSglContextInfoPool::Get).stubs().will(returnValue(sglCtxInfo)); + MOCKER_CPP(&SockOpContextInfoPool::Return).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&NetDriverSockWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + NResult ret = ep->PostRead(sglRequest); + EXPECT_EQ(ret, SS_CTX_FULL); + + if (ctxInfo != nullptr) { + delete ctxInfo; + ctxInfo = nullptr; + } +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostWriteValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + int ret = ep->PostWrite(request); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostWriteValidateBuffFail) +{ + request.size = 0; + int ret = ep->PostWrite(request); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); + request.size = 1; +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostWriteOneSideValidateFail) +{ + MOCKER_CPP(&NetDriverSockWithOOB::ValidateMemoryRegion).stubs().will(returnValue(1)); + int ret = ep->PostWrite(request); + EXPECT_EQ(ret, static_cast(NN_INVALID_LKEY)); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostWriteCtxInfoErr) +{ + SockOpContextInfo *ctxInfo = nullptr; + MOCKER_CPP(&SockOpContextInfoPool::Get).stubs().will(returnValue(ctxInfo)); + MOCKER_CPP(&NetDriverSockWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + NResult ret = ep->PostWrite(request); + EXPECT_EQ(ret, SS_CTX_FULL); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostWriteCtxSglInfoErr) +{ + SockOpContextInfo *ctxInfo = new (std::nothrow) SockOpContextInfo(); + SockSglContextInfo *sglCtxInfo = nullptr; + MOCKER_CPP(&SockOpContextInfoPool::Get).stubs().will(returnValue(ctxInfo)); + MOCKER_CPP(&SockSglContextInfoPool::Get).stubs().will(returnValue(sglCtxInfo)); + MOCKER_CPP(&SockOpContextInfoPool::Return).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&NetDriverSockWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + NResult ret = ep->PostWrite(request); + EXPECT_EQ(ret, SS_CTX_FULL); + + if (ctxInfo != nullptr) { + delete ctxInfo; + ctxInfo = nullptr; + } +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostWriteSglValidateStateFail) +{ + ep->mState.Set(NEP_BROKEN); + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + int ret = ep->PostWrite(sglReq); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + ep->mState.Set(NEP_ESTABLISHED); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostWriteSglOneSideSglValidateFail) +{ + UBSHcomNetTransSgeIov iov[NN_NO4]; + UBSHcomNetTransSglRequest sglReq(iov, NN_NO4, 0); + sglReq.upCtxSize = sizeof(SockOpContextInfo::upCtx) + 1; + int ret = ep->PostWrite(sglReq); + EXPECT_EQ(ret, static_cast(NN_INVALID_PARAM)); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostWriteSglCtxInfoErr) +{ + SockOpContextInfo *ctxInfo = nullptr; + MOCKER_CPP(&SockOpContextInfoPool::Get).stubs().will(returnValue(ctxInfo)); + MOCKER_CPP(&NetDriverSockWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + NResult ret = ep->PostWrite(sglRequest); + EXPECT_EQ(ret, SS_PARAM_INVALID); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPostWriteSglCtxSglInfoErr) +{ + SockOpContextInfo *ctxInfo = new (std::nothrow) SockOpContextInfo(); + SockSglContextInfo *sglCtxInfo = nullptr; + MOCKER_CPP(&SockOpContextInfoPool::Get).stubs().will(returnValue(ctxInfo)); + MOCKER_CPP(&SockSglContextInfoPool::Get).stubs().will(returnValue(sglCtxInfo)); + MOCKER_CPP(&SockOpContextInfoPool::Return).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&NetDriverSockWithOOB::ValidateMemoryRegion).stubs().will(returnValue(0)); + NResult ret = ep->PostWrite(sglRequest); + EXPECT_EQ(ret, SS_PARAM_INVALID); + + if (ctxInfo != nullptr) { + delete ctxInfo; + ctxInfo = nullptr; + } +} + +TEST_F(TestNetSockSyncEndpoint, SyncWaitCompletionFail) +{ + // param init + int32_t timeout = 0; + mockHeader.dataLength = 0; + ep->mLastFlag = NTH_READ; + ep->mRespCtx.mHeader = mockHeader; + MOCKER_CPP(setsockopt).stubs().will(returnValue(0)); + MOCKER_CPP(::recv).stubs().will(returnValue(sizeof(UBSHcomNetTransHeader))); + MOCKER_CPP(close).stubs().will(returnValue(0)); + + NResult ret = ep->WaitCompletion(timeout); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetSockSyncEndpoint, SyncReceiveFailWithErrorDataLen) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx {}; + mockHeader.dataLength = 0; + ep->mRespCtx.mHeader = mockHeader; + MOCKER_CPP(setsockopt).stubs().will(returnValue(0)); + MOCKER_CPP(::recv).stubs().will(returnValue(sizeof(UBSHcomNetTransHeader))); + MOCKER_CPP(close).stubs().will(returnValue(0)); + + NResult ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetSockSyncEndpoint, SyncReceiveFailWithOverDataLen) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx {}; + mockHeader.dataLength = NET_SGE_MAX_SIZE + 1; + ep->mRespCtx.mHeader = mockHeader; + MOCKER_CPP(setsockopt).stubs().will(returnValue(0)); + MOCKER_CPP(::recv).stubs().will(returnValue(sizeof(UBSHcomNetTransHeader))); + MOCKER_CPP(close).stubs().will(returnValue(0)); + + NResult ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetSockSyncEndpoint, SyncReceiveFailWithInvalidCRC) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx {}; + mockHeader.dataLength = NN_NO1024; + ep->mRespCtx.mHeader = mockHeader; + MOCKER_CPP(setsockopt).stubs().will(returnValue(0)); + MOCKER_CPP(::recv).stubs().will(returnValue(sizeof(UBSHcomNetTransHeader))); + MOCKER_CPP(close).stubs().will(returnValue(0)); + + NResult ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_VALIDATE_HEADER_CRC_INVALID); +} + +TEST_F(TestNetSockSyncEndpoint, SyncReceiveFailWithInvalidSeqNo) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx {}; + mockHeader.dataLength = NN_NO1024; + mockHeader.seqNo = 1; + mockHeader.headerCrc = NetFunc::CalcHeaderCrc32(mockHeader); + ep->mRespCtx.mHeader = mockHeader; + MOCKER_CPP(setsockopt).stubs().will(returnValue(0)); + MOCKER_CPP(::recv).stubs().will(returnValue(sizeof(UBSHcomNetTransHeader))); + MOCKER_CPP(close).stubs().will(returnValue(0)); + + NResult ret = ep->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_SEQ_NO_NOT_MATCHED); +} + +TEST_F(TestNetSockSyncEndpoint, SyncGetSendQueueCount) +{ + uint32_t ret = ep->GetSendQueueCount(); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestNetSockSyncEndpoint, SyncPeerIpAndPortErr) +{ + ep->mSock = nullptr; + std::string ret = ep->PeerIpAndPort(); + EXPECT_EQ(ret, CONST_EMPTY_STRING); +} + +TEST_F(TestNetSockSyncEndpoint, SyncUdsNameErr) +{ + std::string ret = ep->UdsName(); + EXPECT_EQ(ret, CONST_EMPTY_STRING); +} + +TEST_F(TestNetSockSyncEndpoint, SyncGetRemoteUdsIdInfo) +{ + ep->mState.Set(NEP_ESTABLISHED); + ep->mDriver->mStartOobSvr = true; + ep->mDriver->mOptions.oobType = NET_OOB_UDS; + UBSHcomNetUdsIdInfo sockIdInfo{}; + ep->mRemoteUdsIdInfo.gid = NN_NO1024; + ep->mRemoteUdsIdInfo.pid = NN_NO1024; + ep->mRemoteUdsIdInfo.uid = NN_NO1024; + NResult ret = ep->GetRemoteUdsIdInfo(sockIdInfo); + EXPECT_EQ(ret, NN_OK); + + ep->mState.Set(NEP_BROKEN); + ep->mRemoteUdsIdInfo.gid = 0; + ep->mRemoteUdsIdInfo.pid = 0; + ep->mRemoteUdsIdInfo.uid = 0; + ret = ep->GetRemoteUdsIdInfo(sockIdInfo); + EXPECT_EQ(ret, NN_ERROR); + + ep->mDriver->mOptions.oobType = NET_OOB_TCP; + ret = ep->GetRemoteUdsIdInfo(sockIdInfo); + EXPECT_EQ(ret, NN_UDS_ID_INFO_NOT_SUPPORT); + + ep->mDriver->mStartOobSvr = false; + ret = ep->GetRemoteUdsIdInfo(sockIdInfo); + EXPECT_EQ(ret, NN_UDS_ID_INFO_NOT_SUPPORT); +} + +TEST_F(TestNetSockSyncEndpoint, SyncGetPeerIpPort) +{ + std::string ip; + uint16_t port; + ep->mSock->mPeerIpPort = ""; + bool ret = ep->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + + ep->mSock->mPeerIpPort = "1.2.3.4"; + ret = ep->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + + ep->mSock->mPeerIpPort = "1.2.3.4:test"; + ret = ep->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + + ep->mSock->mPeerIpPort = "1.2.3.4:0"; + ret = ep->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + + ep->mSock->mPeerIpPort = "1.2.3.4:16"; + ret = ep->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, true); + + ep->mSock = nullptr; + ret = ep->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetSockSyncEndpoint, SyncGetFinishTime) +{ + ep->mDefaultTimeout = 0; + uint64_t ret = ep->GetFinishTime(); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestNetSockSyncEndpoint, GetRemoteUdsIdInfoFail2) +{ + int ret; + UBSHcomNetUdsIdInfo sockIdInfo{}; + + ep->mState.Set(NEP_ESTABLISHED); + mDriver->mStartOobSvr = true; + mDriver->mOptions.oobType = NET_OOB_TCP; + ret = ep->GetRemoteUdsIdInfo(sockIdInfo); + EXPECT_EQ(ret, NN_UDS_ID_INFO_NOT_SUPPORT); +} + +TEST_F(TestNetSockSyncEndpoint, Connect) +{ + int ret; + std::string payload{}; + UBSHcomNetEndpointPtr ep; + std::string badUrl = "unknown://127.0.0.1:9981"; + std::string serverUrl = "tcp://127.0.0.1:9981"; + mDriver->mInited = true; + mDriver->mStarted = true; + mDriver->mWorkerGroups.emplace_back(std::make_pair(1, 1)); + + MOCKER_CPP(&NetDriverSockWithOOB::Connect, + NResult(NetDriverSockWithOOB::*)(const OOBTCPClientPtr &, const std::string &, UBSHcomNetEndpointPtr &, uint8_t, + uint8_t, uint64_t)).stubs().will(returnValue(1)); + MOCKER_CPP(&NetDriverSockWithOOB::ConnectSyncEp).stubs().will(returnValue(0)); + ret = mDriver->Connect(badUrl, payload, ep, 0, 0, 0, 0); + EXPECT_EQ(ret, NN_INVALID_PARAM); + + mDriver->mEnableTls = true; + ret = mDriver->Connect(serverUrl, payload, ep, 0, 0, 0, 0); + EXPECT_EQ(ret, 1); + + mDriver->mEnableTls = false; + ret = mDriver->Connect(serverUrl, payload, ep, NET_EP_SELF_POLLING, 0, 0, 0); + EXPECT_EQ(ret, 0); +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/sock/test_sock_worker.cpp b/test/unit_test/transport/sock/test_sock_worker.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f8edc0b93d926d6adbac715ee4d03305769e944b --- /dev/null +++ b/test/unit_test/transport/sock/test_sock_worker.cpp @@ -0,0 +1,363 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +#include "sock_worker.h" + +namespace ock { +namespace hcom { +class TestSockWorker : public testing::Test { +public: + TestSockWorker(); + virtual void SetUp(void); + virtual void TearDown(void); + SockWorker *mSockWorker = nullptr; + Sock *mSock = nullptr; +}; + +TestSockWorker::TestSockWorker() {} + +void TestSockWorker::SetUp() +{ + SockType mT = SOCK_UDS; + std::string mName = "TestSockWorker"; + UBSHcomNetWorkerIndex mIndex{}; + NetMemPoolFixedPtr mOpCtxMemPool; + NetMemPoolFixedPtr mSglCtxMemPool; + NetMemPoolFixedPtr mHeaderReqMemPool; + SockWorkerOptions mSockWorkerOptions{}; + mSockWorker = new (std::nothrow) + SockWorker(mT, mName, mIndex, mOpCtxMemPool, mSglCtxMemPool, mHeaderReqMemPool, mSockWorkerOptions); + ASSERT_TRUE(mSockWorker != nullptr); + uint64_t mId = 1; + int mFd = -1; + SockOptions mSockOptions{}; + mSock = new (std::nothrow) Sock(mT, mName, mId, mFd, mSockOptions); + ASSERT_TRUE(mSock != nullptr); +} + +void TestSockWorker::TearDown() +{ + if (mSock != nullptr) { + delete mSock; + mSock = nullptr; + } + + if (mSockWorker != nullptr) { + delete mSockWorker; + mSockWorker = nullptr; + } + + GlobalMockObject::verify(); +} + +TEST_F(TestSockWorker, TestInitializeInitedFail) +{ + mSockWorker->mInited = true; + SResult res = mSockWorker->Initialize(); + EXPECT_EQ(res, SS_OK); +} + +TEST_F(TestSockWorker, TestInitializeValidateFail) +{ + MOCKER_CPP(mSockWorker->Validate).stubs().will(returnValue(static_cast(SS_ERROR))); + SResult res = mSockWorker->Initialize(); + EXPECT_EQ(res, static_cast(SS_ERROR)); +} + +TEST_F(TestSockWorker, TestInitializeOpCtxMemPoolFail) +{ + MOCKER_CPP(&OpContextInfoPool::Initialize, + NResult(OpContextInfoPool::*)(const NetMemPoolFixedPtr &, const UBSHcomNetDriverProtocol)) + .stubs() + .will(returnValue(static_cast(SS_ERROR))); + SResult res = mSockWorker->Initialize(); + EXPECT_EQ(res, static_cast(SS_ERROR)); +} + +TEST_F(TestSockWorker, TestInitializeSglCtxMemPoolFail) +{ + MOCKER_CPP(&OpContextInfoPool::Initialize, + NResult(OpContextInfoPool::*)(const NetMemPoolFixedPtr &, const UBSHcomNetDriverProtocol)) + .stubs() + .will(returnValue(static_cast(SS_ERROR))); + SResult res = mSockWorker->Initialize(); + EXPECT_EQ(res, static_cast(SS_ERROR)); +} + +TEST_F(TestSockWorker, TestInitializeHeaderReqMemPoolFail) +{ + mSockWorker->mOptions.tcpSendZCopy = true; + MOCKER_CPP(&OpContextInfoPool::Initialize, + NResult(OpContextInfoPool::*)(const NetMemPoolFixedPtr &, const UBSHcomNetDriverProtocol)) + .stubs() + .will(returnValue(static_cast(SS_ERROR))); + SResult res = mSockWorker->Initialize(); + EXPECT_EQ(res, static_cast(SS_ERROR)); +} + +TEST_F(TestSockWorker, TestStartNotInitializedFail) +{ + mSockWorker->mInited = false; + SResult res = mSockWorker->Start(); + EXPECT_EQ(res, static_cast(SS_ERROR)); +} + +TEST_F(TestSockWorker, TestStartStartedFail) +{ + mSockWorker->mInited = true; + mSockWorker->mStarted = true; + SResult res = mSockWorker->Start(); + EXPECT_EQ(res, static_cast(SS_OK)); +} + +TEST_F(TestSockWorker, TestStartNewRequestHandlerFail) +{ + mSockWorker->mInited = true; + mSockWorker->mStarted = false; + mSockWorker->mNewRequestHandler = nullptr; + SResult res = mSockWorker->Start(); + EXPECT_EQ(res, static_cast(SS_PARAM_INVALID)); +} + +TEST_F(TestSockWorker, TestStartSendPostedHandlerFail) +{ + mSockWorker->mInited = true; + mSockWorker->mStarted = false; + mSockWorker->mSendPostedHandler = nullptr; + SResult res = mSockWorker->Start(); + EXPECT_EQ(res, static_cast(SS_PARAM_INVALID)); +} + +TEST_F(TestSockWorker, TestStartOneSideDoneHandlerFail) +{ + mSockWorker->mInited = true; + mSockWorker->mStarted = false; + mSockWorker->mOneSideDoneHandler = nullptr; + SResult res = mSockWorker->Start(); + EXPECT_EQ(res, static_cast(SS_PARAM_INVALID)); +} + +TEST_F(TestSockWorker, TestPostSendSockAddQueueFail) +{ + SockTransHeader mHeader{}; + UBSHcomNetTransRequest mReq; + MOCKER_CPP(mSock->GetQueueSpace).stubs().will(returnValue(false)); + SResult res = mSockWorker->PostSend(mSock, mHeader, mReq); + EXPECT_EQ(res, static_cast(SS_SOCK_ADD_QUEUE_FAILED)); +} + +TEST_F(TestSockWorker, TestPostSendNoCtxLeftFail) +{ + SockTransHeader mHeader{}; + UBSHcomNetTransRequest mReq; + MOCKER_CPP(mSock->GetQueueSpace).stubs().will(returnValue(true)); + SResult res = mSockWorker->PostSend(mSock, mHeader, mReq); + EXPECT_EQ(res, static_cast(SS_CTX_FULL)); +} + +TEST_F(TestSockWorker, TestPostSendRawSglSockAddQueueFail) +{ + SockTransHeader mHeader{}; + UBSHcomNetTransSglRequest mReq; + MOCKER_CPP(mSock->GetQueueSpace).stubs().will(returnValue(false)); + SResult res = mSockWorker->PostSendRawSgl(mSock, mHeader, mReq); + EXPECT_EQ(res, static_cast(SS_SOCK_ADD_QUEUE_FAILED)); +} + +TEST_F(TestSockWorker, TestPostSendRawSglNoCtxLeftFail) +{ + SockTransHeader mHeader{}; + UBSHcomNetTransSglRequest mReq; + MOCKER_CPP(mSock->GetQueueSpace).stubs().will(returnValue(true)); + SResult res = mSockWorker->PostSendRawSgl(mSock, mHeader, mReq); + EXPECT_EQ(res, static_cast(SS_CTX_FULL)); +} + +TEST_F(TestSockWorker, TestPostReadSockAddQueueFail) +{ + SockTransHeader mHeader{}; + UBSHcomNetTransRequest mReq; + UBSHcomNetTransSglRequest mSglReq; + MOCKER_CPP(mSock->GetQueueSpace).stubs().will(returnValue(false)); + SResult res = mSockWorker->PostRead(mSock, mHeader, mReq); + EXPECT_EQ(res, static_cast(SS_SOCK_ADD_QUEUE_FAILED)); + res = mSockWorker->PostRead(mSock, mHeader, mSglReq); + EXPECT_EQ(res, static_cast(SS_SOCK_ADD_QUEUE_FAILED)); +} + +TEST_F(TestSockWorker, TestPostReadNoCtxLeftFail) +{ + SockTransHeader mHeader{}; + UBSHcomNetTransRequest mReq; + UBSHcomNetTransSglRequest mSglReq; + MOCKER_CPP(mSock->GetQueueSpace).stubs().will(returnValue(true)); + SResult res = mSockWorker->PostRead(mSock, mHeader, mReq); + EXPECT_EQ(res, static_cast(SS_PARAM_INVALID)); + res = mSockWorker->PostRead(mSock, mHeader, mSglReq); + EXPECT_EQ(res, static_cast(SS_CTX_FULL)); +} + +TEST_F(TestSockWorker, TestPostWriteSockAddQueueFail) +{ + SockTransHeader mHeader{}; + UBSHcomNetTransRequest mReq; + UBSHcomNetTransSglRequest mSglReq; + MOCKER_CPP(mSock->GetQueueSpace).stubs().will(returnValue(false)); + SResult res = mSockWorker->PostWrite(mSock, mHeader, mReq); + EXPECT_EQ(res, static_cast(SS_SOCK_ADD_QUEUE_FAILED)); + res = mSockWorker->PostWrite(mSock, mHeader, mSglReq); + EXPECT_EQ(res, static_cast(SS_SOCK_ADD_QUEUE_FAILED)); +} + +TEST_F(TestSockWorker, TestPostWriteNoCtxLeftFail) +{ + SockTransHeader mHeader{}; + UBSHcomNetTransRequest mReq; + UBSHcomNetTransSglRequest mSglReq; + MOCKER_CPP(mSock->GetQueueSpace).stubs().will(returnValue(true)); + SResult res = mSockWorker->PostWrite(mSock, mHeader, mReq); + EXPECT_EQ(res, static_cast(SS_CTX_FULL)); + res = mSockWorker->PostWrite(mSock, mHeader, mSglReq); + EXPECT_EQ(res, static_cast(SS_CTX_FULL)); +} + +TEST_F(TestSockWorker, TestAddToEpollInvalidFdFail) +{ + MOCKER_CPP(&Sock::FD).stubs().will(returnValue(INVALID_FD)); + SResult res = mSockWorker->AddToEpoll(mSock, 1); + EXPECT_EQ(res, static_cast(SS_PARAM_INVALID)); +} + +TEST_F(TestSockWorker, TestAddToEpollEpollFail) +{ + MOCKER_CPP(&Sock::FD).stubs().will(returnValue(1)); + MOCKER_CPP(&epoll_ctl).stubs().will(returnValue(1)); + SResult res = mSockWorker->AddToEpoll(mSock, 1); + EXPECT_EQ(res, static_cast(SS_SOCK_EPOLL_OP_FAILED)); +} + +TEST_F(TestSockWorker, TestModifyInEpollInvalidFdFail) +{ + MOCKER_CPP(&Sock::FD).stubs().will(returnValue(INVALID_FD)); + SResult res = mSockWorker->ModifyInEpoll(mSock, 1); + EXPECT_EQ(res, static_cast(SS_PARAM_INVALID)); +} + +TEST_F(TestSockWorker, TestModifyInEpollEpollFail) +{ + MOCKER_CPP(&Sock::FD).stubs().will(returnValue(1)); + MOCKER_CPP(&epoll_ctl).stubs().will(returnValue(1)); + errno = ENOENT; + SResult res = mSockWorker->ModifyInEpoll(mSock, 1); + EXPECT_EQ(res, static_cast(SS_SOCK_EPOLL_OP_FAILED)); + + errno = 0; + res = mSockWorker->ModifyInEpoll(mSock, 1); + EXPECT_EQ(res, static_cast(SS_SOCK_EPOLL_OP_FAILED)); +} + +TEST_F(TestSockWorker, TestRemoveFromEpollEpollFail) +{ + MOCKER_CPP(&Sock::FD).stubs().will(returnValue(1)); + MOCKER_CPP(&epoll_ctl).stubs().will(returnValue(1)); + errno = ENOENT; + SResult res = mSockWorker->RemoveFromEpoll(mSock); + EXPECT_EQ(res, static_cast(SS_OK)); + + errno = 0; + res = mSockWorker->RemoveFromEpoll(mSock); + EXPECT_EQ(res, static_cast(SS_SOCK_EPOLL_OP_FAILED)); +} + +TEST_F(TestSockWorker, TestTCPInitializeOpCtxMemPoolFail) +{ + mSockWorker->mType = SOCK_TCP; + MOCKER_CPP(&OpContextInfoPool::Initialize, + NResult(OpContextInfoPool::*)(const NetMemPoolFixedPtr &)) + .stubs() + .will(returnValue(static_cast(SS_ERROR))); + SResult res = mSockWorker->Initialize(); + EXPECT_EQ(res, static_cast(SS_ERROR)); +} + +TEST_F(TestSockWorker, TestTCPInitializeSglCtxMemPoolFail) +{ + mSockWorker->mType = SOCK_TCP; + MOCKER_CPP(&OpContextInfoPool::Initialize, + NResult(OpContextInfoPool::*)(const NetMemPoolFixedPtr &)) + .stubs() + .will(returnValue(static_cast(SS_ERROR))); + SResult res = mSockWorker->Initialize(); + EXPECT_EQ(res, static_cast(SS_ERROR)); +} + +TEST_F(TestSockWorker, TestTCPInitializeHeaderReqMemPoolFail) +{ + mSockWorker->mType = SOCK_TCP; + mSockWorker->mOptions.tcpSendZCopy = true; + MOCKER_CPP(&OpContextInfoPool::Initialize, + NResult(OpContextInfoPool::*)(const NetMemPoolFixedPtr &)) + .stubs() + .will(returnValue(static_cast(SS_ERROR))); + SResult res = mSockWorker->Initialize(); + EXPECT_EQ(res, static_cast(SS_ERROR)); +} + +TEST_F(TestSockWorker, CheckIovLen) +{ + uint16_t iovCount = 0; + SockOpContextInfo opCtx{}; + opCtx.dataSize = 0; + EXPECT_EQ(mSockWorker->CheckIovLen(opCtx, iovCount), false); + + opCtx.dataSize = sizeof(UBSHcomNetTransSglRequest::iovCount); + opCtx.dataAddress = reinterpret_cast(&iovCount); + EXPECT_EQ(mSockWorker->CheckIovLen(opCtx, iovCount), false); + + iovCount = 1; + EXPECT_EQ(mSockWorker->CheckIovLen(opCtx, iovCount), false); +} + +TEST_F(TestSockWorker, PostReadAck) +{ + SockOpContextInfo opCtx{}; + opCtx.sock = mSock; + mSock->mQueueVacantSize = NN_NO100; + mSock->UpContext(NN_NO1); + opCtx.dataSize = 0; + EXPECT_EQ(mSockWorker->PostReadAck(opCtx), SS_PARAM_INVALID); +} + +TEST_F(TestSockWorker, PostWriteAck) +{ + SockOpContextInfo opCtx{}; + opCtx.sock = mSock; + mSock->mQueueVacantSize = NN_NO100; + mSock->UpContext(NN_NO1); + opCtx.dataSize = 0; + EXPECT_EQ(mSockWorker->PostWriteAck(opCtx), SS_PARAM_INVALID); +} + +TEST_F(TestSockWorker, PostWriteSglAck) +{ + SockOpContextInfo opCtx{}; + opCtx.sock = mSock; + mSock->UpContext(NN_NO1); + mSock->mQueueVacantSize = NN_NO100; + opCtx.dataSize = 0; + MOCKER_CPP(&SockWorker::CheckIovLen).stubs().will(returnValue(1)); + EXPECT_EQ(mSockWorker->PostWriteSglAck(opCtx), SS_PARAM_INVALID); +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/sock/test_sock_wrapper.cpp b/test/unit_test/transport/sock/test_sock_wrapper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..25032c0a9484512fbe5f3af86401ef3caf838601 --- /dev/null +++ b/test/unit_test/transport/sock/test_sock_wrapper.cpp @@ -0,0 +1,556 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include + +#include "sock_wrapper.h" +#include "openssl_api_wrapper.h" + +namespace ock { +namespace hcom { + +class TestSockWrapper : public testing::Test { +public: + TestSockWrapper(); + virtual void SetUp(void); + virtual void TearDown(void); + Sock *mSock = nullptr; + SockOpContextInfo *ctx = nullptr; + SockSglContextInfo *sendCtx = nullptr; + SockHeaderReqInfo *reqInfo = nullptr; +}; + +TestSockWrapper::TestSockWrapper() {} + +void TestSockWrapper::SetUp() +{ + SockType mT = SOCK_UDS; + std::string mName = "TestSockWrapper"; + uint64_t mId = 1; + int mFd = -1; + SockOptions mSockOptions{}; + mSock = new (std::nothrow) Sock(mT, mName, mId, mFd, mSockOptions); + ASSERT_TRUE(mSock != nullptr); + ctx = new (std::nothrow) SockOpContextInfo(); + ASSERT_TRUE(ctx != nullptr); + sendCtx = new (std::nothrow) SockSglContextInfo(); + ASSERT_TRUE(sendCtx != nullptr); + reqInfo = new (std::nothrow) SockHeaderReqInfo(); + ASSERT_TRUE(reqInfo != nullptr); + ctx->sendCtx = sendCtx; + ctx->headerRequest = reqInfo; +} + +void TestSockWrapper::TearDown() +{ + if (mSock != nullptr) { + delete mSock; + mSock = nullptr; + } + + if (sendCtx != nullptr) { + delete sendCtx; + sendCtx = nullptr; + } + + if (reqInfo != nullptr) { + delete reqInfo; + reqInfo = nullptr; + } + + if (ctx != nullptr) { + delete ctx; + ctx = nullptr; + } + + GlobalMockObject::verify(); +} + +TEST_F(TestSockWrapper, TestInitializeValidateOptionsFail) +{ + mSock->mInited = false; + SockWorkerOptions options; + MOCKER_CPP(&Sock::ValidateOptions).stubs().will(returnValue(1)); + SResult ret = mSock->Initialize(options); + EXPECT_EQ(ret, static_cast(NN_NO1)); +} + +TEST_F(TestSockWrapper, TestInitializeMemAllocateFail) +{ + mSock->mInited = false; + SockWorkerOptions options; + MOCKER_CPP(&Sock::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&Sock::SetSockOption).stubs().will(returnValue(0)); + MOCKER_CPP(&SockBuff::ExpandIfNeed).stubs().will(returnValue(false)); + SResult ret = mSock->Initialize(options); + EXPECT_EQ(ret, SS_MEMORY_ALLOCATE_FAILED); +} + +TEST_F(TestSockWrapper, TestValidateOptions) +{ + mSock->mOptions.receiveBufSizeKB = 0; + SResult ret = mSock->ValidateOptions(); + EXPECT_EQ(ret, SS_OK); +} + +TEST_F(TestSockWrapper, TestSetSockOptionSetRecvBufferFail) +{ + mSock->mFd = 1; + SockWorkerOptions workerOptions{}; + workerOptions.sockReceiveBufKB = 1; + MOCKER_CPP(setsockopt).stubs().will(returnValue(-1)); + SResult ret = mSock->SetSockOption(workerOptions); + EXPECT_EQ(ret, SS_TCP_SET_OPTION_FAILED); + mSock->mFd = -1; +} + +TEST_F(TestSockWrapper, TestSetSockOptionSetSendBufferFail) +{ + mSock->mFd = 1; + SockWorkerOptions workerOptions{}; + workerOptions.sockReceiveBufKB = 0; + workerOptions.sockSendBufKB = 1; + MOCKER_CPP(setsockopt).stubs().will(returnValue(-1)); + SResult ret = mSock->SetSockOption(workerOptions); + EXPECT_EQ(ret, SS_TCP_SET_OPTION_FAILED); + mSock->mFd = -1; +} + +TEST_F(TestSockWrapper, TestSetSockOptionGetSendBufferFail) +{ + mSock->mFd = 1; + SockWorkerOptions workerOptions{}; + workerOptions.sockReceiveBufKB = 0; + workerOptions.sockSendBufKB = 0; + MOCKER_CPP(getsockopt).stubs().will(returnValue(-1)).then(returnValue(0)); + SResult ret = mSock->SetSockOption(workerOptions); + EXPECT_EQ(ret, SS_TCP_GET_OPTION_FAILED); + ret = mSock->SetSockOption(workerOptions); + EXPECT_EQ(ret, SS_TCP_GET_OPTION_FAILED); + mSock->mFd = -1; +} + +TEST_F(TestSockWrapper, TestSetBlockingSendTimeoutFail) +{ + MOCKER_CPP(setsockopt).stubs().will(returnValue(-1)); + SResult ret = mSock->SetBlockingSendTimeout(1); + EXPECT_EQ(ret, SS_TCP_SET_OPTION_FAILED); +} + +TEST_F(TestSockWrapper, TestSetBlockingIoGetControlValueFail) +{ + SResult ret = mSock->SetBlockingIo(); + EXPECT_EQ(ret, SS_TCP_SET_OPTION_FAILED); +} + +TEST_F(TestSockWrapper, TestSetNonBlockingIoGetControlValueFail) +{ + SResult ret = mSock->SetNonBlockingIo(); + EXPECT_EQ(ret, SS_TCP_SET_OPTION_FAILED); +} + +TEST_F(TestSockWrapper, TestSetBlockingIoFail) +{ + UBSHcomEpOptions epOptions{}; + MOCKER_CPP(&Sock::SetBlockingIo, SResult(Sock::*)()).stubs().will(returnValue(0)).then(returnValue(-1)); + MOCKER_CPP(&Sock::SetBlockingSendTimeout).stubs().will(returnValue(-1)); + + SResult ret = mSock->SetBlockingIo(epOptions); + EXPECT_EQ(ret, SS_TCP_SET_OPTION_FAILED); + ret = mSock->SetBlockingIo(epOptions); + EXPECT_EQ(ret, SS_TCP_SET_OPTION_FAILED); +} + +TEST_F(TestSockWrapper, TestGetSendQueueCount) +{ + MOCKER_CPP(&NetRingBuffer::Size).stubs().will(returnValue(1)); + uint32_t ret = mSock->GetSendQueueCount(); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestSockWrapper, TestSendRealConnHeaderParamFail) +{ + SResult ret = mSock->SendRealConnHeader(-1, nullptr, 0); + EXPECT_EQ(ret, SS_PARAM_INVALID); +} + +TEST_F(TestSockWrapper, TestPostSendParamFail) +{ + SockOpContextInfo *nullCtx = nullptr; + SResult ret = mSock->PostSend(nullCtx); + EXPECT_EQ(ret, SS_PARAM_INVALID); +} + +TEST_F(TestSockWrapper, TestPostSend) +{ + mSock->mTcpBlockingMode = true; + mSock->mOptions.sendZCopy = true; + mSock->mCbByWorkerInBlocking = true; + ssize_t size = NN_NO128; + MOCKER_CPP(&writev).stubs().will(returnValue(size)); + MOCKER_CPP(&NetRingBuffer::PushBack).stubs().will(returnValue(true)); + SResult ret = mSock->PostSend(ctx); + EXPECT_EQ(ret, SS_SOCK_SEND_EAGAIN); +} + +TEST_F(TestSockWrapper, TestPostSendSglParamFail) +{ + SockOpContextInfo *nullCtx = nullptr; + SResult ret = mSock->PostSendSgl(nullCtx); + EXPECT_EQ(ret, SS_PARAM_INVALID); +} + +TEST_F(TestSockWrapper, TestPostSendSglFail) +{ + mSock->mTcpBlockingMode = true; + ssize_t size = static_cast(sizeof(SockTransHeader)) - 1; + MOCKER_CPP(&writev).stubs().will(returnValue(0)).then(returnValue(size)); + SResult ret = mSock->PostSendSgl(ctx); + EXPECT_EQ(ret, SS_TCP_RETRY); + + errno = EAGAIN; + ret = mSock->PostSendSgl(ctx); + EXPECT_EQ(ret, SS_SOCK_SEND_FAILED); + + errno = 0; + ret = mSock->PostSendSgl(ctx); + EXPECT_EQ(ret, SS_TIMEOUT); + + mSock->mEnableTls = true; + ctx->opType = SockOpContextInfo::SockOpType::SS_SEND_RAW_SGL; + ret = mSock->PostSendSgl(ctx); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestSockWrapper, TestPostSendSglHeaderFail) +{ + SockTransHeader header{}; + UBSHcomNetTransSglRequest req{}; + ssize_t size = static_cast(sizeof(SockTransHeader)) - 1; + MOCKER_CPP(&writev).stubs().will(returnValue(0)).then(returnValue(size)); + SResult ret = mSock->PostSendSgl(header, req); + EXPECT_EQ(ret, SS_TCP_RETRY); + + errno = 0; + ret = mSock->PostSendSgl(header, req); + EXPECT_EQ(ret, SS_TIMEOUT); +} + +TEST_F(TestSockWrapper, TestPostSendSglHeaderTlsFail) +{ + SockTransHeader header{}; + UBSHcomNetTransSglRequest req{}; + mSock->mEnableTls = true; + ssize_t size = static_cast(sizeof(SockTransHeader)) - 1; + MOCKER_CPP(&writev).stubs().will(returnValue(0)).then(returnValue(size)); + SResult ret = mSock->PostSendSgl(header, req); + EXPECT_EQ(ret, SS_TCP_RETRY); + + errno = EAGAIN; + ret = mSock->PostSendSgl(header, req); + EXPECT_EQ(ret, SS_SOCK_SEND_FAILED); + + errno = 0; + ret = mSock->PostSendSgl(header, req); + EXPECT_EQ(ret, SS_TIMEOUT); +} + +TEST_F(TestSockWrapper, TestPostSendSglHeaderSSLSendFail) +{ + SockTransHeader header{}; + UBSHcomNetTransSgeIov *iov = new (std::nothrow) UBSHcomNetTransSgeIov(); + UBSHcomNetTransSglRequest req = UBSHcomNetTransSglRequest(iov, 1, 1); + mSock->mEnableTls = true; + ssize_t size = NN_NO128; + MOCKER_CPP(&writev).stubs().will(returnValue(size)); + MOCKER_CPP(&Sock::SSLSend).stubs() + .will(returnValue(static_cast(SS_OOB_SSL_WRITE_ERROR))) + .then(returnValue(static_cast(SS_TIMEOUT))); + SResult ret = mSock->PostSendSgl(header, req); + EXPECT_EQ(ret, SS_SOCK_SEND_FAILED); + + ret = mSock->PostSendSgl(header, req); + EXPECT_EQ(ret, SS_TIMEOUT); + + if (iov != nullptr) { + delete iov; + iov = nullptr; + } +} + +TEST_F(TestSockWrapper, TestPostSendHeadParamFail) +{ + SockOpContextInfo *nullCtx = nullptr; + SResult ret = mSock->PostSendHead(nullCtx); + EXPECT_EQ(ret, SS_PARAM_INVALID); +} + +TEST_F(TestSockWrapper, TestPostSendHeadFail) +{ + mSock->mTcpBlockingMode = true; + ssize_t size = static_cast(sizeof(SockTransHeader)) - 1; + MOCKER_CPP(::send).stubs().will(returnValue(0)).then(returnValue(size)); + SResult ret = mSock->PostSendHead(ctx); + EXPECT_EQ(ret, SS_TCP_RETRY); + + errno = EAGAIN; + ret = mSock->PostSendHead(ctx); + EXPECT_EQ(ret, SS_SOCK_SEND_FAILED); + + errno = 0; + ret = mSock->PostSendHead(ctx); + EXPECT_EQ(ret, SS_TIMEOUT); +} + +TEST_F(TestSockWrapper, TestPostSendHeadSuccess) +{ + mSock->mTcpBlockingMode = true; + ssize_t size = NN_NO128; + MOCKER_CPP(::send).stubs().will(returnValue(size)); + SResult ret = mSock->PostSendHead(ctx); + EXPECT_EQ(ret, SS_OK); +} + +TEST_F(TestSockWrapper, TestPostReadParamFail) +{ + SockOpContextInfo *nullCtx = nullptr; + SResult ret = mSock->PostRead(nullCtx); + EXPECT_EQ(ret, SS_PARAM_INVALID); +} + +TEST_F(TestSockWrapper, TestPostReadFail) +{ + mSock->mTcpBlockingMode = true; + ssize_t size = static_cast(sizeof(SockTransHeader)) - 1; + MOCKER_CPP(&writev).stubs().will(returnValue(0)).then(returnValue(size)); + SResult ret = mSock->PostRead(ctx); + EXPECT_EQ(ret, SS_TCP_RETRY); + + errno = 0; + ret = mSock->PostRead(ctx); + EXPECT_EQ(ret, SS_TIMEOUT); +} + +TEST_F(TestSockWrapper, TestPostWriteParamFail) +{ + SockOpContextInfo *nullCtx = nullptr; + SResult ret = mSock->PostWrite(nullCtx); + EXPECT_EQ(ret, SS_PARAM_INVALID); +} + +TEST_F(TestSockWrapper, TestPostWriteFail) +{ + mSock->mTcpBlockingMode = true; + ssize_t size = static_cast(sizeof(SockTransHeader)) - 1; + MOCKER_CPP(&writev).stubs().will(returnValue(0)).then(returnValue(size)); + SResult ret = mSock->PostWrite(ctx); + EXPECT_EQ(ret, SS_TCP_RETRY); + + errno = EAGAIN; + ret = mSock->PostWrite(ctx); + EXPECT_EQ(ret, SS_SOCK_SEND_FAILED); + + errno = 0; + ret = mSock->PostWrite(ctx); + EXPECT_EQ(ret, SS_TIMEOUT); +} + +TEST_F(TestSockWrapper, TestPostReceiveHeaderFail) +{ + SockTransHeader header {}; + mSock->mRevTimeoutSecond = -1; + + MOCKER_CPP(setsockopt).stubs().will(returnValue(-1)).then(returnValue(0)); + MOCKER_CPP(::recv).stubs().will(returnValue(0)); + + SResult ret = mSock->PostReceiveHeader(header, 1); + EXPECT_EQ(ret, SS_TCP_SET_OPTION_FAILED); + + ret = mSock->PostReceiveHeader(header, 1); + EXPECT_EQ(ret, SS_SOCK_SEND_FAILED); +} + +TEST_F(TestSockWrapper, TestPostReceiveBodyParamFail) +{ + void *buff = nullptr; + uint32_t dataLength = 0; + bool isOneSide = true; + + SResult ret = mSock->PostReceiveBody(buff, dataLength, isOneSide); + EXPECT_EQ(ret, SS_PARAM_INVALID); + + buff = malloc(NN_NO1024); + ret = mSock->PostReceiveBody(buff, dataLength, isOneSide); + EXPECT_EQ(ret, SS_PARAM_INVALID); + + free(buff); +} + +TEST_F(TestSockWrapper, TestPostReceiveBodyFail) +{ + void *buff = malloc(NN_NO1024); + uint32_t dataLength = NN_NO1024; + bool isOneSide = true; + mSock->mEnableTls = false; + MOCKER_CPP(::recv).stubs().will(returnValue(0)); + + errno = EAGAIN; + SResult ret = mSock->PostReceiveBody(buff, dataLength, isOneSide); + EXPECT_EQ(ret, SS_TIMEOUT); + free(buff); +} + +TEST_F(TestSockWrapper, TestPostReceiveBodyTlsFail) +{ + void *buff = malloc(NN_NO1024); + uint32_t dataLength = NN_NO1024; + bool isOneSide = false; + mSock->mEnableTls = true; + MOCKER_CPP(&Sock::SSLRead).stubs() + .will(returnValue(static_cast(SS_TIMEOUT))) + .then(returnValue(static_cast(SS_SSL_READ_FAILED))); + + SResult ret = mSock->PostReceiveBody(buff, dataLength, isOneSide); + EXPECT_EQ(ret, SS_TIMEOUT); + + ret = mSock->PostReceiveBody(buff, dataLength, isOneSide); + EXPECT_EQ(ret, SS_SSL_READ_FAILED); + + free(buff); +} + +TEST_F(TestSockWrapper, TestPostReadSglParamFail) +{ + SockOpContextInfo *nullCtx = nullptr; + SResult ret = mSock->PostReadSgl(nullCtx); + EXPECT_EQ(ret, SS_PARAM_INVALID); +} + +TEST_F(TestSockWrapper, TestPostReadSglFail) +{ + mSock->mTcpBlockingMode = true; + ssize_t size = NN_NO1; + MOCKER_CPP(&writev).stubs().will(returnValue(0)).then(returnValue(size)); + SResult ret = mSock->PostReadSgl(ctx); + EXPECT_EQ(ret, SS_TCP_RETRY); + + errno = 0; + ret = mSock->PostReadSgl(ctx); + EXPECT_EQ(ret, SS_TIMEOUT); +} + +TEST_F(TestSockWrapper, TestPostReadSglAckParamFail) +{ + SockOpContextInfo *nullCtx = nullptr; + SResult ret = mSock->PostReadSglAck(nullCtx); + EXPECT_EQ(ret, SS_PARAM_INVALID); +} + +TEST_F(TestSockWrapper, TestPostReadSglAckFail) +{ + mSock->mTcpBlockingMode = true; + ssize_t size = NN_NO1; + MOCKER_CPP(&writev).stubs().will(returnValue(0)).then(returnValue(size)); + SResult ret = mSock->PostReadSglAck(ctx); + EXPECT_EQ(ret, SS_TCP_RETRY); + + errno = EAGAIN; + ret = mSock->PostReadSglAck(ctx); + EXPECT_EQ(ret, SS_SOCK_SEND_FAILED); + + errno = 0; + ret = mSock->PostReadSglAck(ctx); + EXPECT_EQ(ret, SS_TIMEOUT); +} + +TEST_F(TestSockWrapper, TestPostReadSglAckSuccess) +{ + mSock->mTcpBlockingMode = true; + ssize_t size = NN_NO128; + MOCKER_CPP(&writev).stubs().will(returnValue(size)); + SResult ret = mSock->PostReadSglAck(ctx); + EXPECT_EQ(ret, SS_OK); +} + +TEST_F(TestSockWrapper, TestPostWriteSglParamFail) +{ + SockOpContextInfo *nullCtx = nullptr; + SResult ret = mSock->PostWriteSgl(nullCtx); + EXPECT_EQ(ret, SS_PARAM_INVALID); +} + +TEST_F(TestSockWrapper, TestPostWriteSglFail) +{ + mSock->mTcpBlockingMode = true; + ssize_t size = NN_NO1; + MOCKER_CPP(&writev).stubs().will(returnValue(0)).then(returnValue(size)); + SResult ret = mSock->PostWriteSgl(ctx); + EXPECT_EQ(ret, SS_TCP_RETRY); + + errno = 0; + ret = mSock->PostWriteSgl(ctx); + EXPECT_EQ(ret, SS_TIMEOUT); +} + +TEST_F(TestSockWrapper, TestSSLSendOpenssl) +{ + uint32_t readLen = 0; + MOCKER_CPP(HcomSsl::SslWrite).stubs().will(returnValue(0)); + MOCKER_CPP(HcomSsl::SslGetError).stubs().will(returnValue(0)); + + SResult ret = mSock->SSLSend(nullptr, 0, readLen); + EXPECT_EQ(ret, SS_OOB_SSL_WRITE_ERROR); +} + +TEST_F(TestSockWrapper, TestSSLReadOpenssl) +{ + uint32_t readLen = 0; + MOCKER_CPP(HcomSsl::SslRead).stubs().will(returnValue(0)); + MOCKER_CPP(HcomSsl::SslGetError).stubs().will(returnValue(0)); + + SResult ret = mSock->SSLRead(nullptr, 0, readLen); + EXPECT_EQ(ret, SS_SSL_READ_FAILED); +} + +TEST_F(TestSockWrapper, TestPostSendSglSsl) +{ + struct iovec iov[NN_NO5]; + SResult ret = mSock->PostSendSglSsl(ctx, iov); + EXPECT_EQ(ret, SS_SOCK_SEND_FAILED); + + ssize_t size = NN_NO128; + MOCKER_CPP(&writev).stubs() + .will(returnValue(static_cast(0))) + .then(returnValue(sizeof(SockTransHeader) - 1)) + .then(returnValue(size)); + ret = mSock->PostSendSglSsl(ctx, iov); + EXPECT_EQ(ret, SS_TCP_RETRY); + + errno = 0; + ret = mSock->PostSendSglSsl(ctx, iov); + EXPECT_EQ(ret, SS_TIMEOUT); + + MOCKER_CPP(&Sock::SSLSend).stubs() + .will(returnValue(static_cast(SS_OOB_SSL_WRITE_ERROR))) + .then(returnValue(static_cast(SS_TIMEOUT))); + errno = 0; + ctx->sendCtx->iovCount = 1; + ret = mSock->PostSendSglSsl(ctx, iov); + EXPECT_EQ(ret, SS_SOCK_SEND_FAILED); + + ret = mSock->PostSendSglSsl(ctx, iov); + EXPECT_EQ(ret, SS_TIMEOUT); + + ctx->sendBuff = nullptr; +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/test_net_driver.cpp b/test/unit_test/transport/test_net_driver.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8421929d54718f722260b09920eba3fb8a1ebecd --- /dev/null +++ b/test/unit_test/transport/test_net_driver.cpp @@ -0,0 +1,35 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + +#include +#include + +#include "hcom.h" +#include "net_common.h" + +namespace ock { +namespace hcom { + +class TestNetDriver : public testing::Test { +public: + virtual void SetUp(void) + { + } + + virtual void TearDown(void) + { + GlobalMockObject::verify(); + } +}; + +TEST_F(TestNetDriver, VersionParseFailed) +{ + MOCKER(NetFunc::NN_SplitStr).stubs().will(ignoreReturnValue()); + + UBSHcomNetDriver *driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "test", false); + EXPECT_EQ(driver, nullptr); + + driver = UBSHcomNetDriver::Instance(UBSHcomNetDriverProtocol::TCP, "test", false); + EXPECT_EQ(driver, nullptr); +} +} +} diff --git a/test/unit_test/transport/test_net_driver_options.cpp b/test/unit_test/transport/test_net_driver_options.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12e1a5b885c88db89e2a447938914453172a0a6b --- /dev/null +++ b/test/unit_test/transport/test_net_driver_options.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +// Author: zhiwei + +#include + +#include "hcom.h" + +namespace ock { +namespace hcom { + +class TestNetDriverOptions : public testing::Test { +public: + void SetUp() override + { + } + + void TearDown() override + { + } +}; + +TEST_F(TestNetDriverOptions, SetNetDeviceEidFailed) +{ + UBSHcomNetDriverOptions opts; + EXPECT_FALSE(opts.SetNetDeviceEid("length < 32")); + EXPECT_FALSE(opts.SetNetDeviceEid("0000:0000:0000:0000:0000:ffff:0102:xxyy")); +} + +TEST_F(TestNetDriverOptions, SetNetDeviceEidOk) +{ + UBSHcomNetDriverOptions opts; + EXPECT_TRUE(opts.SetNetDeviceEid("0000:0000:0000:0000:0000:ffff:0102:0304")); + EXPECT_EQ(opts.netDeviceEid[0], 0x00); + EXPECT_EQ(opts.netDeviceEid[9], 0x00); + EXPECT_EQ(opts.netDeviceEid[10], 0xff); + EXPECT_EQ(opts.netDeviceEid[11], 0xff); + EXPECT_EQ(opts.netDeviceEid[12], 0x01); + EXPECT_EQ(opts.netDeviceEid[13], 0x02); + EXPECT_EQ(opts.netDeviceEid[14], 0x03); + EXPECT_EQ(opts.netDeviceEid[15], 0x04); +} + +} // namespace hcom +} // namespace ock diff --git a/test/unit_test/transport/test_net_oob.cpp b/test/unit_test/transport/test_net_oob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e172aab3899ba0b88f918dd02f5f874caf3e15c0 --- /dev/null +++ b/test/unit_test/transport/test_net_oob.cpp @@ -0,0 +1,150 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include +#include + +#include "hcom_utils.h" +#include "net_common.h" +#include "rdma_worker.h" +#include "transport/net_delay_release_timer.h" +#include "transport/net_heartbeat.h" +#include "transport/net_load_balance.h" +#include "transport/rdma/rdma_common.h" +#include "transport/rdma/verbs/net_rdma_async_endpoint.h" +#include "net_sock_driver_oob.h" + +namespace ock { +namespace hcom { + +class TestNetOob : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); +}; + +void TestNetOob::SetUp() +{ +} + +void TestNetOob::TearDown() +{ + GlobalMockObject::verify(); +} + +void FakeThread(OOBTCPServer *This) +{ + return; +} + +TEST_F(TestNetOob, OOBTCPServerStart) +{ + ock::hcom::OOBTCPServer oobServer("192.168.100.204", 9444); + + oobServer.mStarted = true; + std::thread tmpThread(&FakeThread, &oobServer); + oobServer.mAcceptThread = std::move(tmpThread); + EXPECT_EQ(oobServer.Start(), 0); + + MOCKER_CPP(OOBTCPServer::CreateAndConfigSocket).stubs().will(returnValue(0)); + MOCKER_CPP(OOBTCPServer::BindAndListenAuto).stubs().will(returnValue(0)); + oobServer.mIsAutoPortSelectionEnabled = true; + EXPECT_EQ(oobServer.CreateAndStartSocket(), 0); +} + +TEST_F(TestNetOob, OOBTCPServerStop) +{ + ock::hcom::OOBTCPServer oobServer("192.168.100.204", 9444); + oobServer.mStarted = true; + std::thread tmpThread(&FakeThread, &oobServer); + oobServer.mAcceptThread = std::move(tmpThread); + oobServer.mOobType = NET_OOB_UDS; + oobServer.mUdsPerm = 600; + EXPECT_EQ(oobServer.Stop(), static_cast(NN_INVALID_PARAM)); + oobServer.mStarted = false; +} + +TEST_F(TestNetOob, OOBTCPServerStop1) +{ + ock::hcom::OOBTCPServer oobServer("192.168.100.204", 9444); + oobServer.mStarted = true; + std::thread tmpThread(&FakeThread, &oobServer); + oobServer.mAcceptThread = std::move(tmpThread); + oobServer.mOobType = NET_OOB_UDS; + oobServer.mUdsPerm = 600; + MOCKER_CPP(CanonicalPath).stubs().will(returnValue(true)); + + EXPECT_EQ(oobServer.Stop(), static_cast(NN_INVALID_PARAM)); + oobServer.mStarted = false; +} + +static ssize_t MockConnSend(int socket, void const *buf, size_t size, int flags) +{ + return -1; +} + +TEST_F(TestNetOob, OOBTCPServerFunc) +{ + ock::hcom::OOBTCPConnection oobConn(-1); + + void *buf = malloc(1); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + errno = ENOMEM; + EXPECT_EQ(oobConn.Send(buf, 1), static_cast(NN_OOB_CONN_SEND_ERROR)); + + if (buf != nullptr) { + free(buf); + buf = nullptr; + } +} + +TEST_F(TestNetOob, ConnectWithFdSocket) +{ + int fd = 0; + int err = 0; + + MOCKER(::socket).stubs().will(returnValue(static_cast(-1))); + + err = OOBTCPClient::ConnectWithFd("127.0.0.1", 2233, fd); + EXPECT_EQ(err, NN_OOB_CLIENT_SOCKET_ERROR); +} + +TEST_F(TestNetOob, ConnectWithFdConnect) +{ + int fd = 0; + int err = 0; + + MOCKER(::sleep).stubs().will(returnValue(static_cast(0))); + MOCKER(::connect).stubs().will(returnValue(static_cast(-1))); + err = OOBTCPClient::ConnectWithFd("127.0.0.1", 2233, fd); + EXPECT_EQ(err, NN_OOB_CLIENT_SOCKET_ERROR); +} + +TEST_F(TestNetOob, ConnectWithFdRecv) +{ + int fd = 0; + int err = 0; + + MOCKER(::sleep).stubs().will(returnValue(static_cast(0))); + MOCKER(::connect).stubs().will(returnValue(static_cast(0))); + MOCKER(::recv) + .stubs() + .will(returnValue(static_cast(-1))) + .then(returnValue(static_cast(4))); + + err = OOBTCPClient::ConnectWithFd("127.0.0.1", 2233, fd); + EXPECT_EQ(err, NN_OK); +} +} // namespace hcom +} // namespace ock diff --git a/test/unit_test/transport/test_net_oob_secure.cpp b/test/unit_test/transport/test_net_oob_secure.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d608607b082ac3c6e688c3af27c1924014e263c7 --- /dev/null +++ b/test/unit_test/transport/test_net_oob_secure.cpp @@ -0,0 +1,250 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include + +#include "hcom_utils.h" +#include "net_common.h" +#include "transport/net_oob_secure.h" + +namespace ock { +namespace hcom { + +class TestNetOobSecure : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); +}; + +void TestNetOobSecure::SetUp() +{ +} + +void TestNetOobSecure::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestNetOobSecure, SecProcessCompareEpNum) +{ + OOBSecureProcess proc {}; + std::vector oobServers; + NetOOBServer *server = new (std::nothrow) OOBTCPServer("127.0.0.1", 9980); + server->mStarted = false; + oobServers.push_back(server); + EXPECT_EQ(proc.SecProcessCompareEpNum(0, 0, "0", oobServers), static_cast(NN_OK)); + EXPECT_EQ(proc.SecProcessCompareEpNum("udsName", "IpPort", oobServers), static_cast(NN_OK)); + EXPECT_NO_FATAL_FAILURE(proc.SecProcessAddEpNum(0, 0, "0", oobServers)); + EXPECT_NO_FATAL_FAILURE(proc.SecProcessAddEpNum("udsName", "IpPort", oobServers)); + EXPECT_NO_FATAL_FAILURE(proc.SecProcessDelEpNum(0, 0, "0", oobServers)); + + if (server != nullptr) { + delete server; + server = nullptr; + } +} + +TEST_F(TestNetOobSecure, SecProcessDelEpNum) +{ + OOBSecureProcess proc {}; + std::vector oobServers; + NetOOBServer *server = new (std::nothrow) OOBTCPServer("127.0.0.1", 9980); + server->mStarted = false; + oobServers.push_back(server); + EXPECT_NO_FATAL_FAILURE(proc.SecProcessDelEpNum("udsName", "IpPort", oobServers)); + + server->mStarted = true; + server->mOobType = NET_OOB_UDS; + server->mUdsName = "udsName"; + EXPECT_NO_FATAL_FAILURE(proc.SecProcessDelEpNum("udsName", "IpPort", oobServers)); +} + +TEST_F(TestNetOobSecure, SecProcessInOOBServer) +{ + OOBSecureProcess proc {}; + UBSHcomNetDriverEndpointSecInfoProvider provider = nullptr; + UBSHcomNetDriverEndpointSecInfoValidator validator = nullptr; + OOBTCPConnection *tcpConn = new (std::nothrow) OOBTCPConnection(-1); + MOCKER_CPP(OOBSecureProcess::ValidateSecInfo).stubs() + .will(returnValue(static_cast(NN_OOB_SEC_PROCESS_ERROR))) + .then(returnValue(static_cast(NN_OK))); + MOCKER_CPP_VIRTUAL(*tcpConn, &OOBTCPConnection::Send) + .stubs() + .will(returnValue(static_cast((NN_ERROR)))); + EXPECT_EQ(proc.SecProcessInOOBServer(provider, validator, *tcpConn, "name", + UBSHcomNetDriverSecType::NET_SEC_DISABLED), + static_cast(NN_OOB_SEC_PROCESS_ERROR)); + EXPECT_EQ(proc.SecProcessInOOBServer(provider, validator, *tcpConn, "name", + UBSHcomNetDriverSecType::NET_SEC_DISABLED), + static_cast(NN_OOB_SEC_PROCESS_ERROR)); + + if (tcpConn != nullptr) { + delete tcpConn; + tcpConn = nullptr; + } +} + +TEST_F(TestNetOobSecure, SecProcessInOOBClient) +{ + OOBSecureProcess proc {}; + UBSHcomNetDriverEndpointSecInfoProvider provider = nullptr; + UBSHcomNetDriverEndpointSecInfoValidator validator = nullptr; + OOBTCPConnection *tcpConn = new (std::nothrow) OOBTCPConnection(-1); + MOCKER_CPP_VIRTUAL(*tcpConn, &OOBTCPConnection::Send) + .stubs() + .will(returnValue(static_cast((NN_ERROR)))) + .then(returnValue(static_cast((NN_OK)))); + EXPECT_EQ(proc.SecProcessInOOBClient(provider, validator, tcpConn, "name", 0, + UBSHcomNetDriverSecType::NET_SEC_DISABLED), + static_cast(NN_OOB_SEC_PROCESS_ERROR)); + + MOCKER_CPP_VIRTUAL(*tcpConn, &OOBTCPConnection::Receive) + .stubs() + .will(returnValue(static_cast((NN_ERROR)))); + + EXPECT_EQ(proc.SecProcessInOOBClient(provider, validator, tcpConn, "name", 0, + UBSHcomNetDriverSecType::NET_SEC_DISABLED), + static_cast(NN_ERROR)); + + if (tcpConn != nullptr) { + delete tcpConn; + tcpConn = nullptr; + } +} + +TEST_F(TestNetOobSecure, SendSecInfo) +{ + OOBSecureProcess proc {}; + UBSHcomNetDriverEndpointSecInfoProvider provider = + [](uint64_t ctx, int64_t &flag, UBSHcomNetDriverSecType &type, char *&output, uint32_t &outLen, + bool &needAutoFree) { + outLen = NN_NO2147483646 + 1; + return 0; + }; + UBSHcomNetDriverEndpointSecInfoValidator validator = nullptr; + OOBTCPConnection *tcpConn = new (std::nothrow) OOBTCPConnection(-1); + UBSHcomNetDriverSecType type = UBSHcomNetDriverSecType::NET_SEC_DISABLED; + EXPECT_EQ(proc.SendSecInfo(provider, validator, nullptr, "name", type, 0), + static_cast(NN_OOB_SEC_PROCESS_ERROR)); + + EXPECT_EQ(proc.SendSecInfo(provider, validator, tcpConn, "name", type, 0), + static_cast(NN_OOB_SEC_PROCESS_ERROR)); + + provider = + [](uint64_t ctx, int64_t &flag, UBSHcomNetDriverSecType &type, char *&output, uint32_t &outLen, + bool &needAutoFree) { + outLen = 0; + type = UBSHcomNetDriverSecType::NET_SEC_VALID_ONE_WAY; + return 0; + }; + MOCKER_CPP_VIRTUAL(*tcpConn, &OOBTCPConnection::Send) + .stubs() + .will(returnValue(static_cast((NN_ERROR)))) + .then(returnValue(static_cast((NN_OK)))) + .then(returnValue(static_cast((NN_ERROR)))); + + EXPECT_EQ(proc.SendSecInfo(provider, validator, tcpConn, "name", type, 0), + static_cast(NN_OOB_SEC_PROCESS_ERROR)); + + EXPECT_EQ(proc.SendSecInfo(provider, validator, tcpConn, "name", type, 0), + static_cast(NN_OOB_SEC_PROCESS_ERROR)); + + if (tcpConn != nullptr) { + delete tcpConn; + tcpConn = nullptr; + } +} + +NResult FakeReceive(void *&buf, uint32_t size) +{ + ConnSecHeader *header = static_cast(buf); + header->type = 3; + return 0; +} + +NResult FakeReceive1(void *&buf, uint32_t size) +{ + ConnSecHeader *header = static_cast(buf); + header->secInfoLen = NN_NO2147483646 + 1; + header->type = UBSHcomNetDriverSecType::NET_SEC_VALID_TWO_WAY; + return 0; +} + +NResult FakeReceive2(void *&buf, uint32_t size) +{ + ConnSecHeader *header = static_cast(buf); + header->type = UBSHcomNetDriverSecType::NET_SEC_VALID_TWO_WAY; + return 0; +} + +TEST_F(TestNetOobSecure, ValidateSecInfo) +{ + OOBSecureProcess proc {}; + OOBTCPConnection *tcpConn = new (std::nothrow) OOBTCPConnection(-1); + UBSHcomNetDriverSecType type = UBSHcomNetDriverSecType::NET_SEC_DISABLED; + UBSHcomNetDriverEndpointSecInfoProvider provider = nullptr; + UBSHcomNetDriverEndpointSecInfoValidator validator = nullptr; + + MOCKER_CPP_VIRTUAL(*tcpConn, &OOBTCPConnection::Receive) + .stubs() + .will(invoke(FakeReceive)) + .then(returnValue(0)) + .then(invoke(FakeReceive1)) + .then(invoke(FakeReceive2)) + .then(returnValue(1)); + EXPECT_EQ(proc.SendSecInfo(provider, validator, tcpConn, "name", type, 0), + static_cast(NN_OOB_SEC_PROCESS_ERROR)); + EXPECT_EQ(proc.SendSecInfo(provider, validator, tcpConn, "name", type, 0), + static_cast(NN_OOB_SEC_PROCESS_ERROR)); + EXPECT_EQ(proc.SendSecInfo(provider, validator, tcpConn, "name", type, 0), + static_cast(NN_OOB_SEC_PROCESS_ERROR)); + EXPECT_EQ(proc.SendSecInfo(provider, validator, tcpConn, "name", type, 0), + static_cast(NN_OOB_SEC_PROCESS_ERROR)); + EXPECT_EQ(proc.SendSecInfo(provider, validator, tcpConn, "name", type, 0), + static_cast(NN_OOB_SEC_PROCESS_ERROR)); +} + +TEST_F(TestNetOobSecure, SecCheckConnectionHeader) +{ + OOBSecureProcess proc {}; + ConnectHeader header {}; + UBSHcomNetDriverOptions options {}; + bool enableTls = true; + UBSHcomNetDriverProtocol protocol = UBSHcomNetDriverProtocol::TCP; + uint32_t majorVersion = 1; + uint32_t minorVersion = 0; + ConnRespWithUId resp {}; + + options.magic = header.magic; + header.protocol = UBSHcomNetDriverProtocol::RDMA; + EXPECT_EQ(proc.SecCheckConnectionHeader(header, options, enableTls, protocol, majorVersion, minorVersion, resp), + static_cast(NN_ERROR)); + + header.protocol = UBSHcomNetDriverProtocol::TCP; + header.majorVersion = majorVersion + 1; + EXPECT_EQ(proc.SecCheckConnectionHeader(header, options, enableTls, protocol, majorVersion, minorVersion, resp), + static_cast(VERSION_MISMATCH)); + + header.majorVersion = majorVersion; + header.minorVersion = minorVersion + 1; + EXPECT_EQ(proc.SecCheckConnectionHeader(header, options, enableTls, protocol, majorVersion, minorVersion, resp), + static_cast(VERSION_MISMATCH)); + + header.minorVersion = minorVersion; + header.tlsVersion = TLS_1_3 + 1; + EXPECT_EQ(proc.SecCheckConnectionHeader(header, options, enableTls, protocol, majorVersion, minorVersion, resp), + static_cast(NN_ERROR)); +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/test_net_oob_ssl.cpp b/test/unit_test/transport/test_net_oob_ssl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..07bb601cf13555c7778d1e96add3c2728987b547 --- /dev/null +++ b/test/unit_test/transport/test_net_oob_ssl.cpp @@ -0,0 +1,182 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include + +#include "hcom_utils.h" +#include "net_common.h" +#include "net_security_rand.h" +#include "transport/net_oob_ssl.h" + +namespace ock { +namespace hcom { + +class TestNetOobSsl : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); +}; + +void TestNetOobSsl::SetUp() +{ +} + +void TestNetOobSsl::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestNetOobSsl, RunInThread) +{ + NetDriverOobType type = NetDriverOobType::NET_OOB_TCP; + std::string name = "test"; + UBSHcomTLSPrivateKeyCallback pkCallback = nullptr; + UBSHcomTLSCertificationCallback ccCallback = nullptr; + UBSHcomTLSCaCallback caCallback = nullptr; + OOBSSLServer server {type, name, 0, pkCallback, ccCallback, caCallback}; + server.mOobType = NET_OOB_UDS; + MOCKER_CPP(std::atomic::load).stubs() + .will(returnValue(false)) + .then(returnValue(true)); + server.mNeedStop = true; + NetExecutorServicePtr es = new (std::nothrow) NetExecutorService(0, 0); + server.mEs = es; + EXPECT_NO_FATAL_FAILURE(server.RunInThread()); +} + +TEST_F(TestNetOobSsl, SendSecret) +{ + OOBSSLConnection conn {-1}; + MOCKER_CPP(NetSecrets::Init).stubs() + .will(returnValue(false)) + .then(returnValue(true)); + + EXPECT_EQ(conn.SendSecret(), static_cast(NN_ERROR)); + + MOCKER_CPP(NetSecrets::Serialize).stubs() + .will(returnValue(false)) + .then(returnValue(true)); + EXPECT_EQ(conn.SendSecret(), static_cast(NN_OOB_SSL_INIT_ERROR)); + + EXPECT_EQ(conn.SendSecret(), static_cast(NN_OOB_CONN_SEND_ERROR)); +} + +TEST_F(TestNetOobSsl, RecvSecret) +{ + OOBSSLConnection *conn = new (std::nothrow) OOBSSLConnection (-1); + MOCKER_CPP(NetSecrets::Init).stubs() + .will(returnValue(false)) + .then(returnValue(true)); + + EXPECT_EQ(conn->RecvSecret(), static_cast(NN_ERROR)); + + MOCKER_CPP(NetSecrets::Deserialize).stubs() + .will(returnValue(false)); + + OOBTCPConnection *tcpConn = static_cast(conn); + MOCKER_CPP_VIRTUAL(*tcpConn, &OOBTCPConnection::Receive) + .stubs() + .will(returnValue(static_cast(NN_OK))); + + EXPECT_EQ(conn->RecvSecret(), static_cast(NN_OOB_SSL_INIT_ERROR)); + + if (conn != nullptr) { + delete conn; + conn = nullptr; + } +} + +TEST_F(TestNetOobSsl, SSLClientRecvHandler) +{ + OOBSSLConnection conn {-1}; + EXPECT_EQ(conn.SSLClientRecvHandler(-1), static_cast(NN_ERROR)); +} + +TEST_F(TestNetOobSsl, TlsConnectCbTaskRun) +{ + TlsConnectCbTask task {nullptr, -1, nullptr}; + MOCKER(::send).stubs().will(returnValue(0)).then(returnValue(1)); + EXPECT_NO_FATAL_FAILURE(task.Run()); + EXPECT_NO_FATAL_FAILURE(task.Run()); +} + +TEST_F(TestNetOobSsl, TestSslRand) +{ + EXPECT_EQ(SecurityRandGenerator::SslRand(nullptr, 0), false); + SSLAPI::randPrivBytes = [](unsigned char *buf, int num) { return 0; }; + SSLAPI::randStatus = []() { return 0; }; + SSLAPI::randPoll = []() { return 0; }; + void *out = malloc(1); + EXPECT_EQ(SecurityRandGenerator::SslRand(out, 1), false); + if (out != nullptr) { + free(out); + out = nullptr; + } +} + +TEST_F(TestNetOobSsl, NetSecretsInitSSLRandSecret) +{ + NetSecrets secret {}; + SSLAPI::randPrivBytes = [](unsigned char *buf, int num) { return 1; }; + SSLAPI::randStatus = []() { return 1; }; + SSLAPI::randPoll = []() { return 1; }; + EXPECT_EQ(secret.InitSSLRandSecret(), false); + secret.mKeySecretLen = 1; + EXPECT_EQ(secret.InitSSLRandSecret(), true); + + MOCKER_CPP(SecurityRandGenerator::SslRand).stubs() + .will(returnValue(true)) + .then(returnValue(false)) + .then(returnValue(true)) + .then(returnValue(true)) + .then(returnValue(false)) + .then(returnValue(true)); + + EXPECT_EQ(secret.InitSSLRandSecret(), false); + EXPECT_EQ(secret.InitSSLRandSecret(), false); + EXPECT_EQ(secret.InitSSLRandSecret(), true); +} + +TEST_F(TestNetOobSsl, NetSecretsSerialize) +{ + NetSecrets secret {}; + + EXPECT_EQ(secret.Serialize(nullptr, 0), false); + + size_t len = sizeof(uint8_t); + char *dest = static_cast(malloc(len)); + EXPECT_EQ(secret.Serialize(dest, len), false); + + if (dest != nullptr) { + free(dest); + dest = nullptr; + } +} + +TEST_F(TestNetOobSsl, NetSecretsDeserialize) +{ + NetSecrets secret {}; + + EXPECT_EQ(secret.Deserialize(nullptr, 0), false); + + size_t len = sizeof(uint8_t); + char *dest = static_cast(malloc(len)); + EXPECT_EQ(secret.Deserialize(dest, len), false); + if (dest != nullptr) { + free(dest); + dest = nullptr; + } +} +} +} \ No newline at end of file diff --git a/test/unit_test/transport/test_transport_common.cpp b/test/unit_test/transport/test_transport_common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a083e25d88f878e74f7e62571887d78aff2b48aa --- /dev/null +++ b/test/unit_test/transport/test_transport_common.cpp @@ -0,0 +1,251 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include +#include +#include + +#include "hcom_utils.h" +#include "net_common.h" +#include "rdma_worker.h" +#include "transport/net_delay_release_timer.h" +#include "transport/net_heartbeat.h" +#include "transport/net_load_balance.h" +#include "transport/rdma/rdma_common.h" +#include "transport/rdma/verbs/net_rdma_async_endpoint.h" +#include "transport/ub/ub_urma_wrapper_jetty.h" +#include "transport/ub/ub_worker.h" +#include "transport/ub/net_ub_endpoint.h" +#include "transport/ub/net_ub_driver_oob.h" +#include "net_sock_driver_oob.h" + +namespace ock { +namespace hcom { + +class TestTransportCommon : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); +}; + +void TestTransportCommon::SetUp() +{ +} + +void TestTransportCommon::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestTransportCommon, NetDelayReleaseTimerStarted) +{ + std::string name = "timer_name"; + NetDelayReleaseTimer *timer = new (std::nothrow) NetDelayReleaseTimer(name, 0); + + ASSERT_NE(timer, nullptr); + timer->mStarted = true; + EXPECT_EQ(timer->Start(), static_cast(NN_OK)); + EXPECT_NO_FATAL_FAILURE(timer->Stop()); + + if (timer != nullptr) { + delete timer; + timer = nullptr; + } +} + +TEST_F(TestTransportCommon, NetDelayReleaseTimerStartedThread) +{ + MOCKER(epoll_create).defaults().will(returnValue(-1)); + + std::string name = "timer_name"; + NetDelayReleaseTimer *timer = new (std::nothrow) NetDelayReleaseTimer(name, 0); + + ASSERT_NE(timer, nullptr); + EXPECT_NO_FATAL_FAILURE(timer->RunDelayReleaseThread()); + + UBSHcomNetWorkerIndex workerIndex{}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(0, nullptr, nullptr, workerIndex); + auto epRes = NetDelayReleaseResource(ep, NN_NO1); + epRes.mTimeout = 0; + timer->mDelayReleaseQueue.push(epRes); + EXPECT_NO_FATAL_FAILURE(timer->DequeueDelayRelease()); + + auto epRes1 = NetDelayReleaseResource(ep, NN_NO1); + epRes1.mTimeout = 0; + epRes1.mEp.Set(nullptr); + timer->mDelayReleaseQueue.push(epRes1); + EXPECT_NO_FATAL_FAILURE(timer->DequeueDelayRelease()); + + EXPECT_NO_FATAL_FAILURE(timer->EnqueueDelayRelease(ep)); + if (timer != nullptr) { + delete timer; + timer = nullptr; + } +} + +void FakeRunInHbThread(NetHeartbeat *This) +{ + return; +} + +TEST_F(TestTransportCommon, NetHeartbeatStart) +{ + NetHeartbeat *heartbeat = new (std::nothrow) NetHeartbeat(nullptr, 0, 0); + EXPECT_EQ(heartbeat->Start(), static_cast(NN_INVALID_PARAM)); + + std::thread tmpThread(&FakeRunInHbThread, heartbeat); + heartbeat->mHbThread = std::move(tmpThread); + EXPECT_NO_FATAL_FAILURE(heartbeat->Stop()); + + NetDriverRDMAWithOob *driver = new (std::nothrow) NetDriverRDMAWithOob("name", false, + UBSHcomNetDriverProtocol::UBC); + heartbeat->mDriver = driver; + EXPECT_NO_FATAL_FAILURE(heartbeat->DetectHbState()); + + UBSHcomNetWorkerIndex workerIndex{}; + NetAsyncEndpoint *ep = new (std::nothrow) NetAsyncEndpoint(0, nullptr, nullptr, workerIndex); + + MOCKER_CPP_VIRTUAL(*ep, &NetAsyncEndpoint::PostSend, + NResult(NetAsyncEndpoint::*)(uint16_t, const UBSHcomNetTransRequest &, uint32_t)).stubs() + .will(returnValue(static_cast(NN_OK))) + .then(returnValue(static_cast(NN_INVALID_PARAM))); + EXPECT_EQ(heartbeat->SendTwoSideHeartBeat(ep), static_cast(NN_OK)); + EXPECT_EQ(heartbeat->SendTwoSideHeartBeat(ep), static_cast(NN_INVALID_PARAM)); + if (heartbeat != nullptr) { + delete heartbeat; + heartbeat = nullptr; + } + if (ep != nullptr) { + delete ep; + ep = nullptr; + } + if (driver != nullptr) { + delete driver; + driver = nullptr; + } +} + +TEST_F(TestTransportCommon, NetHeartbeatDetectSingleEpHbState) +{ + NetHeartbeat *heartbeat = new (std::nothrow) NetHeartbeat(nullptr, 0, 0); + UBSHcomNetTransRequest req {}; + EXPECT_NO_FATAL_FAILURE(heartbeat->DetectSingleEpHbState(req, nullptr)); + + UBSHcomNetWorkerIndex workerIndex{}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetAsyncEndpoint(0, nullptr, nullptr, workerIndex); + NetDriverSockWithOOB *driver = + new (std::nothrow) NetDriverSockWithOOB("name", false, UBSHcomNetDriverProtocol::TCP, SOCK_TCP); + heartbeat->mDriver = driver; + EXPECT_NO_FATAL_FAILURE(heartbeat->DetectSingleEpHbState(req, nullptr)); + if (heartbeat != nullptr) { + delete heartbeat; + heartbeat = nullptr; + } +} + +TEST_F(TestTransportCommon, NetWorkerLBFunction) +{ + NetWorkerLB *lb = new (std::nothrow) NetWorkerLB("name", UBSHcomNetDriverLBPolicy::NET_ROUND_ROBIN, 64); + std::vector> groups; + groups.push_back(std::make_pair(0, 0)); + EXPECT_EQ(lb->AddWorkerGroups(groups), static_cast(NN_INVALID_PARAM)); + + uint16_t wkrIdx = 0; + EXPECT_EQ(lb->ChooseWorker(2, "127.0.0.1", wkrIdx), false); + NetWorkerGroupLbInfo info {}; + info.wrkCntLimited = 1; + info.wrkCntInGrp = 0; + lb->mWrkGroups.push_back(info); + lb->mWrkGroups[0].wrkLimited.push_back(1); + EXPECT_EQ(lb->ChooseWorker(0, "127.0.0.1", wkrIdx), true); + + lb->mWorkerLimitedCnt = 0; + lb->AddWorkerGroup(0, 1); + if (lb != nullptr) { + delete lb; + lb = nullptr; + } +} + +TEST_F(TestTransportCommon, NetWorkerLBChooseWorkerLimited) +{ + NetWorkerLB *lb = new (std::nothrow) NetWorkerLB("name", UBSHcomNetDriverLBPolicy::NET_HASH_IP_PORT, 64); + uint16_t wkrIdx = 0; + EXPECT_EQ(lb->ChooseWorkerLimited(2, "127.0.0.1", wkrIdx), false); + + NetWorkerGroupLbInfo info {}; + info.wrkCntLimited = 1; + info.wrkCntInGrp = 0; + lb->mWrkGroups.push_back(info); + lb->mWrkGroups[0].wrkLimited.push_back(1); + EXPECT_EQ(lb->ChooseWorkerLimited(0, "127.0.0.1", wkrIdx), true); + if (lb != nullptr) { + delete lb; + lb = nullptr; + } +} + +TEST_F(TestTransportCommon, NormalMemoryRegionInit) +{ + NormalMemoryRegion *region = new (std::nothrow) NormalMemoryRegion("name", false, 0, 0); + region->mInited = true; + EXPECT_EQ(region->Initialize(), static_cast(NN_OK)); + + region->mInited = false; + region->mExternalMemory = true; + region->mBuf = 0; + region->mSize = 0; + EXPECT_EQ(region->Initialize(), static_cast(NN_INVALID_PARAM)); + + region->mExternalMemory = false; + MOCKER(memalign).defaults().will(returnValue(static_cast(nullptr))); + EXPECT_EQ(region->Initialize(), static_cast(NN_MALLOC_FAILED)); + if (region != nullptr) { + delete region; + region = nullptr; + } +} + +TEST_F(TestTransportCommon, NormalMemoryRegionFixedBufferInit) +{ + NormalMemoryRegionFixedBuffer *buffer = new (std::nothrow) NormalMemoryRegionFixedBuffer("name", 0, 0); + EXPECT_EQ(buffer->Initialize(), static_cast(NN_INVALID_PARAM)); + MOCKER_CPP(NetRingBuffer::Initialize).stubs() + .will(returnValue(static_cast(NN_INVALID_PARAM))); + EXPECT_EQ(buffer->Initialize(), static_cast(NN_INVALID_PARAM)); + + if (buffer != nullptr) { + delete buffer; + buffer = nullptr; + } +} + +TEST_F(TestTransportCommon, NormalMemoryRegionGetMemorySeg) +{ + NormalMemoryRegion region("name", false, 0, 0); + + EXPECT_EQ(region.GetMemorySeg(), static_cast(nullptr)); + uint64_t va = 0; + uint64_t vaLen = 0; + uint32_t token = 0; + EXPECT_NO_FATAL_FAILURE(region.GetVa(va, vaLen, token)); +} + +TEST_F(TestTransportCommon, NormalMemoryRegionFixedBufferGetFreeBufferN) +{ + NormalMemoryRegionFixedBuffer buffer("name", 0, 0); + EXPECT_EQ(buffer.GetFreeBufferN(nullptr, 0), false); +} +} +} +//#endif \ No newline at end of file diff --git a/test/unit_test/transport/ub/test_net_driver_ub.cpp b/test/unit_test/transport/ub/test_net_driver_ub.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ae100e250c049bcf3d74af4e42b48d5ca8c8498a --- /dev/null +++ b/test/unit_test/transport/ub/test_net_driver_ub.cpp @@ -0,0 +1,767 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifdef UB_BUILD_ENABLED +#include +#include + +#include "hcom_def.h" +#include "net_ub_driver.h" +#include "net_ub_endpoint.h" +#include "openssl_api_wrapper.h" +#include "under_api/obmm/obmm_api_wrapper.h" +#include "under_api/urma/urma_api_wrapper.h" +#include "ub_common.h" +#include "ub_mr_fixed_buf.h" +#include "ub_worker.h" +#include "net_oob_secure.h" + +namespace ock { +namespace hcom { +class TestNetDriverUB : public testing::Test { +public: + TestNetDriverUB(); + virtual void SetUp(void); + virtual void TearDown(void); + std::string mName = "TestNetDriverUB"; + NetDriverUBWithOob *driver = nullptr; + UBSHcomNetDriverOptions option{}; + UBContext *ctx = nullptr; + UBEId eid{}; + urma_context_t mUrmaContext{}; + char mem[NN_NO8]{}; + // worker + UBWorker *worker = nullptr; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + // qp + UBJetty *qp = nullptr; + // mDriverSendMR + UBMemoryRegionFixedBuffer *mDriverSendMR = nullptr; +}; + +TestNetDriverUB::TestNetDriverUB() {} + +void TestNetDriverUB::SetUp() +{ + driver = new (std::nothrow) NetDriverUBWithOob(mName, true, UBSHcomNetDriverProtocol::UBC); + ASSERT_NE(driver, nullptr); + ctx = new (std::nothrow) UBContext("ubTest", eid); + ASSERT_NE(ctx, nullptr); + ctx->mUrmaContext = &mUrmaContext; + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + driver->mContext = ctx; + worker = new (std::nothrow) UBWorker(mName, ctx, workerOptions, memPool, sglMemPool); + ASSERT_NE(worker, nullptr); + qp = new (std::nothrow) UBJetty(mName, 0, nullptr, nullptr); + ASSERT_NE(qp, nullptr); + MOCKER_CPP(HcomUrma::Uninit).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBDeviceHelper::UnInitialize).stubs().will(ignoreReturnValue()); +} + +void TestNetDriverUB::TearDown() +{ + if (ctx != nullptr) { + ctx->mUrmaContext = nullptr; + delete ctx; + ctx = nullptr; + } + if (driver != nullptr) { + driver->mContext = nullptr; + delete driver; + driver = nullptr; + } + if (worker != nullptr) { + worker->mUBContext = nullptr; + delete worker; + worker = nullptr; + } + if (qp != nullptr) { + delete qp; + qp = nullptr; + } + GlobalMockObject::verify(); +} + +std::vector filters{}; + +TEST_F(TestNetDriverUB, InitializeParamErr) +{ + UBSHcomNetOutLogger *trueLogger = UBSHcomNetOutLogger::Instance(); + UBSHcomNetOutLogger *logger = nullptr; + driver->mInited = true; + EXPECT_EQ(driver->Initialize(option), NN_OK); + + driver->mInited = false; + MOCKER(UBSHcomNetOutLogger::Instance) + .stubs() + .will(returnValue(logger)) + .then(returnValue(trueLogger)); + EXPECT_EQ(driver->Initialize(option), NN_NOT_INITIALIZED); +} + +TEST_F(TestNetDriverUB, ValidateOptionsOobTypeErr) +{ + driver->mOptions.oobType = NET_OOB_UB; + driver->mProtocol = UBSHcomNetDriverProtocol::UBC; + driver->mOptions.enableTls = true; + EXPECT_EQ(driver->ValidateOptionsOobType(), NN_INVALID_PARAM); +} + +TEST_F(TestNetDriverUB, InitializeOptErr) +{ + driver->mInited = false; + MOCKER_CPP(&UBSHcomNetDriverOptions::ValidateCommonOptions).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverUB::ValidateOptions).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->Initialize(option), 1); + EXPECT_EQ(driver->Initialize(option), 1); +} + +TEST_F(TestNetDriverUB, InitializeLoadOpensslErr) +{ + driver->mInited = false; + MOCKER_CPP(&UBSHcomNetDriverOptions::ValidateCommonOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::ValidateOptions).stubs().will(returnValue(0)); + option.enableTls = true; + MOCKER_CPP(HcomSsl::Load).stubs().will(returnValue(1)); + EXPECT_EQ(driver->Initialize(option), NN_NOT_INITIALIZED); + option.enableTls = false; +} + +TEST_F(TestNetDriverUB, InitializeSizeErr) +{ + driver->mInited = false; + MOCKER_CPP(&UBSHcomNetDriverOptions::ValidateCommonOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::ValidaQpQueueSizeOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&UBContext::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::CreateContext).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverUB::UnInitializeInner).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&NetDriverUB::CreateWorkerResource).stubs().will(returnValue(1)); + + driver->mProtocol = UBSHcomNetDriverProtocol::UBC; + option.mrSendReceiveSegSize = OBMM_SIZE; + + EXPECT_EQ(driver->Initialize(option), 1); + EXPECT_EQ(driver->Initialize(option), 1); +} + +TEST_F(TestNetDriverUB, InitializeSizeErrTwo) +{ + driver->mInited = false; + MOCKER_CPP(&UBSHcomNetDriverOptions::ValidateCommonOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&UBContext::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::CreateContext).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::UnInitializeInner).stubs().will(ignoreReturnValue()); + + driver->mProtocol = UBSHcomNetDriverProtocol::UBC; + option.mrSendReceiveSegSize = OBMM_SIZE; + + EXPECT_EQ(driver->Initialize(option), 100); +} + +TEST_F(TestNetDriverUB, InitializeCtxErr) +{ + driver->mInited = false; + MOCKER_CPP(&UBSHcomNetDriverOptions::ValidateCommonOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::CreateContext).stubs().will(returnValue(0)); + MOCKER_CPP(&UBContext::Initialize).stubs().will(returnValue(1)); + MOCKER_CPP(&NetDriverUB::UnInitializeInner).stubs().will(ignoreReturnValue()); + + driver->mProtocol = UBSHcomNetDriverProtocol::UBC; + option.mrSendReceiveSegSize = OBMM_SIZE; + + EXPECT_EQ(driver->Initialize(option), 1); +} + +TEST_F(TestNetDriverUB, InitializeCreateWorkerErr) +{ + driver->mInited = false; + MOCKER_CPP(&UBSHcomNetDriverOptions::ValidateCommonOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::ValidaQpQueueSizeOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&UBContext::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::CreateContext).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::UnInitializeInner).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&NetDriverUB::CreateWorkerResource).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::CreateWorkers).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverUB::CreateClientLB).stubs().will(returnValue(1)); + + driver->mProtocol = UBSHcomNetDriverProtocol::UBC; + option.mrSendReceiveSegSize = OBMM_SIZE; + + EXPECT_EQ(driver->Initialize(option), 1); + EXPECT_EQ(driver->Initialize(option), 1); +} + +TEST_F(TestNetDriverUB, InitializeDoInitializeErr) +{ + driver->mInited = false; + MOCKER_CPP(&UBSHcomNetDriverOptions::ValidateCommonOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::ValidateOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::ValidaQpQueueSizeOptions).stubs().will(returnValue(0)); + MOCKER_CPP(&UBContext::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::CreateContext).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::UnInitializeInner).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&NetDriverUB::CreateWorkerResource).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::CreateWorkers).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::CreateClientLB).stubs().will(returnValue(0)); + MOCKER_CPP_VIRTUAL(*driver, &NetDriverUBWithOob::DoInitialize).stubs() + .will(returnValue(1)).then(returnValue(0)); + + driver->mProtocol = UBSHcomNetDriverProtocol::UBC; + option.mrSendReceiveSegSize = OBMM_SIZE; + + EXPECT_EQ(driver->Initialize(option), 1); + EXPECT_EQ(driver->Initialize(option), 0); +} + +TEST_F(TestNetDriverUB, ValidateOptionsParamErr) +{ + MOCKER_CPP(ValidateArrayOptions).stubs().will(returnValue(true)); + MOCKER_CPP(&UBSHcomNetDriver::ValidateAndParseOobPortRange).stubs().will(returnValue(1)).then(returnValue(0)); + + driver->mOptions.prePostReceiveSizePerQP = 0; + driver->mOptions.maxPostSendCountPerQP = NN_NO2048; + EXPECT_EQ(driver->ValidateOptions(), NN_INVALID_PARAM); + + // 65535 大于硬件上限,警告,最终因 ValidateAndParseOobPortRange 错误 + driver->mOptions.prePostReceiveSizePerQP = 65535; + driver->mOptions.maxPostSendCountPerQP = 65535; + EXPECT_EQ(driver->ValidateOptions(), NN_INVALID_PARAM); + + driver->mOptions.prePostReceiveSizePerQP = NN_NO2048; + driver->mOptions.maxPostSendCountPerQP = 0; + EXPECT_EQ(driver->ValidateOptions(), NN_INVALID_PARAM); +} + +TEST_F(TestNetDriverUB, ValidateOptions) +{ + MOCKER_CPP(ValidateArrayOptions).stubs().will(returnValue(true)); + MOCKER_CPP(&UBSHcomNetDriver::ValidateAndParseOobPortRange).stubs().will(returnValue(1)).then(returnValue(0)); + + MOCKER_CPP(&UBSHcomNetDriver::ValidateOptionsOobType).stubs().will(returnValue(1)).then(returnValue(0)); + + driver->mOptions.prePostReceiveSizePerQP = NN_NO2048; + driver->mOptions.maxPostSendCountPerQP = NN_NO2048; + EXPECT_EQ(driver->ValidateOptions(), NN_INVALID_PARAM); + + driver->mOptions.prePostReceiveSizePerQP = NN_NO512; + driver->mOptions.maxPostSendCountPerQP = NN_NO2048; + EXPECT_EQ(driver->ValidateOptions(), NN_INVALID_PARAM); + + EXPECT_EQ(driver->ValidateOptions(), NN_OK); +} + +static void MockSplitStr(const std::string &str, const std::string &separator, std::vector &result) +{ + result = filters; +} + +TEST_F(TestNetDriverUB, CreateContextMultiRailNoDevice) +{ + MOCKER_CPP(UBDeviceHelper::GetEnableDeviceCount).stubs().will(returnValue(static_cast(UB_PARAM_INVALID))); + driver->mProtocol = UBSHcomNetDriverProtocol::TCP; + driver->mContext = nullptr; + driver->mOptions.enableMultiRail = true; + EXPECT_EQ(driver->CreateContext(), NN_ERROR); +} + +UResult MockedGetEnableDeviceCount(std::string ipMask, uint16_t &enableDevCount, std::vector &enableIps, + std::string ipGroup) +{ + enableIps.emplace_back("1.1.1.1"); + return UB_OK; +} + +TEST_F(TestNetDriverUB, CreateContextMultiRailOk) +{ + MOCKER_CPP(UBDeviceHelper::GetEnableDeviceCount).stubs().will(invoke(MockedGetEnableDeviceCount)); + MOCKER_CPP(UBDeviceHelper::GetDeviceByIp).stubs().will(returnValue(static_cast(UB_OK))); + + driver->mContext = nullptr; + driver->mOptions.enableMultiRail = true; + driver->mDevIndex = 0; + driver->mProtocol = UBSHcomNetDriverProtocol::TCP; + EXPECT_EQ(driver->CreateContext(), NN_OK); +} + +TEST_F(TestNetDriverUB, CreateContextParamErr) +{ + EXPECT_EQ(driver->CreateContext(), NN_OK); + + driver->mContext = nullptr; + filters.clear(); + driver->mProtocol = UBSHcomNetDriverProtocol::TCP; + MOCKER(NetFunc::NN_SplitStr).stubs().will(invoke(MockSplitStr)); + EXPECT_EQ(driver->CreateContext(), NN_ERROR); +} + +TEST_F(TestNetDriverUB, CreateContextIPErr) +{ + std::vector matchIps{}; + driver->mContext = nullptr; + filters.clear(); + filters.emplace_back("192.168.1.1"); + driver->mProtocol = UBSHcomNetDriverProtocol::TCP; + MOCKER(NetFunc::NN_SplitStr).stubs().will(invoke(MockSplitStr)); + MOCKER(FilterIp).stubs().with(any(), outBound(matchIps)).will(returnValue(0)); + MOCKER(UBDeviceHelper::Initialize).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->CreateContext(), NN_ERROR); +} + +TEST_F(TestNetDriverUB, CreateContextInitializErr) +{ + std::vector matchIps{}; + matchIps.emplace_back("192.168.1.1"); + driver->mContext = nullptr; + filters.clear(); + filters.emplace_back("192.168.1.1"); + driver->mProtocol = UBSHcomNetDriverProtocol::TCP; + MOCKER(NetFunc::NN_SplitStr).stubs().will(invoke(MockSplitStr)); + MOCKER(FilterIp).stubs().with(any(), outBound(matchIps)).will(returnValue(0)); + MOCKER(UBDeviceHelper::Initialize).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER(UBDeviceHelper::GetDeviceByIp).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->CreateContext(), NN_ERROR); + EXPECT_EQ(driver->CreateContext(), NN_ERROR); +} + +TEST_F(TestNetDriverUB, CreateContextCreateErr) +{ + std::vector matchIps{}; + matchIps.emplace_back("192.168.1.1"); + driver->mContext = nullptr; + filters.clear(); + filters.emplace_back("192.168.1.1"); + driver->mProtocol = UBSHcomNetDriverProtocol::TCP; + MOCKER(NetFunc::NN_SplitStr).stubs().will(invoke(MockSplitStr)); + MOCKER(FilterIp).stubs().with(any(), outBound(matchIps)).will(returnValue(0)); + MOCKER(UBDeviceHelper::Initialize).stubs().will(returnValue(0)); + MOCKER(UBDeviceHelper::GetDeviceByIp).stubs().will(returnValue(0)); + MOCKER(UBContext::Create).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->CreateContext(), 1); +} + +TEST_F(TestNetDriverUB, CreateContextSuccess) +{ + std::vector matchIps{}; + matchIps.emplace_back("192.168.1.1"); + driver->mContext = nullptr; + filters.clear(); + filters.emplace_back("192.168.1.1"); + driver->mProtocol = UBSHcomNetDriverProtocol::TCP; + MOCKER(NetFunc::NN_SplitStr).stubs().will(invoke(MockSplitStr)); + MOCKER(FilterIp).stubs().with(any(), outBound(matchIps)).will(returnValue(0)); + MOCKER(UBDeviceHelper::Initialize).stubs().will(returnValue(0)); + MOCKER(UBDeviceHelper::GetDeviceByIp).stubs().will(returnValue(0)); + MOCKER(UBContext::Create).stubs().with(any(), any(), outBound(ctx)).will(returnValue(0)); + + EXPECT_EQ(driver->CreateContext(), 0); +} + +TEST_F(TestNetDriverUB, CreateContextUbcModeMismatch) +{ + driver->mContext = nullptr; + driver->mOptions.SetUbcMode(UBSHcomUbcMode::LowLatency); + MOCKER_CPP(&NetDriverUB::GetDeviceByName).stubs().will(returnValue(1)); + EXPECT_EQ(driver->CreateContext(), NN_ERROR); +} + +TEST_F(TestNetDriverUB, CreateContextUbcModeOk) +{ + MOCKER_CPP(UBDeviceHelper::GetDeviceByEid).stubs().will(returnValue(static_cast(UB_OK))); + MOCKER_CPP(UBDeviceHelper::Initialize).stubs().will(returnValue(static_cast(UB_OK))); + + driver->mContext = nullptr; + driver->mOptions.SetUbcMode(UBSHcomUbcMode::LowLatency); + driver->mProtocol = UBSHcomNetDriverProtocol::UBC; + EXPECT_EQ(driver->CreateContext(), NN_ERROR); +} + +TEST_F(TestNetDriverUB, CreateSendMrErr) +{ + MOCKER(UBMemoryRegionFixedBuffer::Create).stubs().will(returnValue(1)); + EXPECT_EQ(driver->CreateSendMr(1), 1); +} + +TEST_F(TestNetDriverUB, CreateSendMrInitializeErr) +{ + mDriverSendMR = new (std::nothrow) UBMemoryRegionFixedBuffer(mName, ctx, 0, 0, 0); + ASSERT_NE(mDriverSendMR, nullptr); + MOCKER(UBMemoryRegionFixedBuffer::Create) + .stubs() + .with(any(), any(), any(), any(), any(), outBound(mDriverSendMR)) + .will(returnValue(1)) + .then(returnValue(0)); + MOCKER_CPP_VIRTUAL(*mDriverSendMR, &UBMemoryRegionFixedBuffer::Initialize).stubs().will(returnValue(1)); + EXPECT_EQ(driver->CreateSendMr(1), 1); + EXPECT_EQ(driver->CreateSendMr(1), 1); + + if (mDriverSendMR != nullptr) { + delete mDriverSendMR; + mDriverSendMR = nullptr; + } +} + +TEST_F(TestNetDriverUB, CreateSendMr) +{ + mDriverSendMR = new (std::nothrow) UBMemoryRegionFixedBuffer(mName, ctx, 0, 0, 0); + ASSERT_NE(mDriverSendMR, nullptr); + MOCKER(UBMemoryRegionFixedBuffer::Create) + .stubs() + .with(any(), any(), any(), any(), any(), outBound(mDriverSendMR)) + .will(returnValue(0)); + MOCKER_CPP_VIRTUAL(*mDriverSendMR, &UBMemoryRegionFixedBuffer::Initialize).stubs().will(returnValue(0)); + EXPECT_EQ(driver->CreateSendMr(1), 0); + + if (mDriverSendMR != nullptr) { + delete mDriverSendMR; + mDriverSendMR = nullptr; + } +} + +TEST_F(TestNetDriverUB, CreateOpCtxMemPoolErr) +{ + MOCKER_CPP(&NetMemPoolFixed::Initialize).stubs().will(returnValue(1)).then(returnValue(0)); + + EXPECT_EQ(driver->CreateOpCtxMemPool(), 1); + EXPECT_EQ(driver->CreateOpCtxMemPool(), 0); +} + +TEST_F(TestNetDriverUB, CreateOpCtxMemPoolNullErr) +{ + NetMemPoolFixed *testOpCtxMemPool = nullptr; + MOCKER_CPP(&NetRef::Get).stubs().will(returnValue(testOpCtxMemPool)); + EXPECT_EQ(driver->CreateOpCtxMemPool(), NN_INVALID_PARAM); +} + +TEST_F(TestNetDriverUB, CreateSglCtxMemPoolErr) +{ + MOCKER_CPP(&NetMemPoolFixed::Initialize).stubs().will(returnValue(1)).then(returnValue(0)); + + EXPECT_EQ(driver->CreateSglCtxMemPool(), 1); + EXPECT_EQ(driver->CreateSglCtxMemPool(), 0); +} + +TEST_F(TestNetDriverUB, CreateSglCtxMemPoolNullErr) +{ + NetMemPoolFixed *testSglOpCtxInfoPool = nullptr; + MOCKER_CPP(&NetRef::Get).stubs().will(returnValue(testSglOpCtxInfoPool)); + EXPECT_EQ(driver->CreateSglCtxMemPool(), NN_INVALID_PARAM); +} + +TEST_F(TestNetDriverUB, CreateWorkerResourceMrErr) +{ + MOCKER_CPP(&NetDriverUB::CreateSendMr).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverUB::CreateOpCtxMemPool).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->CreateWorkerResource(), 1); + EXPECT_EQ(driver->CreateWorkerResource(), 1); +} + +TEST_F(TestNetDriverUB, CreateWorkerResourceSuccess) +{ + MOCKER_CPP(&NetDriverUB::CreateSendMr).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::CreateOpCtxMemPool).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUB::CreateSglCtxMemPool).stubs().will(returnValue(1)).then(returnValue(0)); + + EXPECT_EQ(driver->CreateWorkerResource(), 0); + EXPECT_EQ(driver->CreateWorkerResource(), 0); +} + +TEST_F(TestNetDriverUB, DestroyEndpointErr) +{ + UBSHcomNetEndpointPtr ubEp = nullptr; + EXPECT_NO_FATAL_FAILURE(driver->DestroyEndpoint(ubEp)); +} + +TEST_F(TestNetDriverUB, ClearWorkers) +{ + UBWorker *tmpWorker = new (std::nothrow) UBWorker(mName, ctx, workerOptions, memPool, sglMemPool); + tmpWorker->mInited = false; + tmpWorker->IncreaseRef(); + driver->mWorkers.emplace_back(tmpWorker); + EXPECT_NO_FATAL_FAILURE(driver->ClearWorkers()); +} + +TEST_F(TestNetDriverUB, CreateWorkersErr) +{ + MOCKER(NetFunc::NN_ParseWorkersGroups).stubs().will(returnValue(true)); + MOCKER(NetFunc::NN_ParseWorkerGroupsCpus).stubs().will(returnValue(true)); + MOCKER(NetFunc::NN_FinalizeWorkerGroupCpus).stubs().will(returnValue(true)); + MOCKER(NetFunc::NN_ParseWorkersGroupsThreadPriority).stubs().will(returnValue(false)); + EXPECT_EQ(driver->CreateWorkers(), NN_INVALID_PARAM); +} + +TEST_F(TestNetDriverUB, CreateWorkers) +{ + std::vector workerGroups{ 1 }; + std::vector workerThreadPriority{ 1 }; + std::vector flatWorkerCpus{ 1 }; + MOCKER(NetFunc::NN_ParseWorkersGroups).stubs().with(any(), outBound(workerGroups)).will(returnValue(true)); + MOCKER(NetFunc::NN_ParseWorkerGroupsCpus).stubs().will(returnValue(true)); + MOCKER(NetFunc::NN_FinalizeWorkerGroupCpus) + .stubs() + .with(any(), any(), any(), outBound(flatWorkerCpus)) + .will(returnValue(true)); + MOCKER(NetFunc::NN_ParseWorkersGroupsThreadPriority) + .stubs() + .with(any(), outBound(workerThreadPriority), any()) + .will(returnValue(true)); + MOCKER_CPP(&UBWorker::Initialize).stubs().will(returnValue(1)).then(returnValue(0)); + + driver->mOptions.workerThreadPriority = 1; + EXPECT_EQ(driver->CreateWorkers(), NN_NEW_OBJECT_FAILED); + EXPECT_EQ(driver->CreateWorkers(), 0); + + // clear resources + delete driver->mWorkers[0]; + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUB, UnInitializeErr) +{ + driver->mInited = false; + EXPECT_NO_FATAL_FAILURE(driver->UnInitialize()); + driver->mInited = true; + driver->mStarted = true; + EXPECT_NO_FATAL_FAILURE(driver->UnInitialize()); +} + +TEST_F(TestNetDriverUB, UnInitialize) +{ + driver->mInited = true; + driver->mStarted = false; + driver->mContext = nullptr; + EXPECT_NO_FATAL_FAILURE(driver->UnInitialize()); +} + +TEST_F(TestNetDriverUB, UnmapVaForUBErr) +{ + uint64_t va; + EXPECT_EQ(driver->UnmapVaForUB(va), NN_ERROR); +} + +TEST_F(TestNetDriverUB, CreateMemoryRegion) +{ + UBSHcomNetMemoryRegionPtr mr = nullptr; + + driver->mInited = true; + MOCKER_CPP(&UBMemoryRegion::InitializeForOneSide).stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(driver->CreateMemoryRegion(reinterpret_cast(mem), NN_NO8, mr), 1); + EXPECT_EQ(driver->CreateMemoryRegion(reinterpret_cast(mem), NN_NO8, mr), 0); +} + +TEST_F(TestNetDriverUB, CreateMemoryRegionAddressErr) +{ + UBSHcomNetMemoryRegionPtr mr = nullptr; + driver->mInited = false; + EXPECT_EQ(driver->CreateMemoryRegion(reinterpret_cast(mem), NN_NO8, mr), NN_EP_NOT_INITIALIZED); + + driver->mInited = true; + EXPECT_EQ(driver->CreateMemoryRegion(0, NN_NO8, mr), NN_INVALID_PARAM); +} + +TEST_F(TestNetDriverUB, CreateMemoryRegionInternalParamErr) +{ + UBSHcomNetMemoryRegionPtr mr = nullptr; + + EXPECT_EQ(driver->CreateMemoryRegion(0, mr), NN_INVALID_PARAM); + driver->mInited = false; + EXPECT_EQ(driver->CreateMemoryRegion(NN_NO8, mr), NN_EP_NOT_INITIALIZED); +} + +TEST_F(TestNetDriverUB, CreateMemoryRegionInternal) +{ + UBSHcomNetMemoryRegionPtr mr = nullptr; + driver->mInited = true; + MOCKER_CPP(&UBMemoryRegion::InitializeForOneSide).stubs().will(returnValue(1)).then(returnValue(0)); + + EXPECT_EQ(driver->CreateMemoryRegion(NN_NO8, mr), 1); + EXPECT_EQ(driver->CreateMemoryRegion(NN_NO8, mr), 0); +} + +TEST_F(TestNetDriverUB, StartParamErr) +{ + driver->mStarted = true; + EXPECT_EQ(driver->Start(), NN_OK); + + driver->mStarted = false; + driver->mInited = false; + EXPECT_EQ(driver->Start(), NN_ERROR); +} + +TEST_F(TestNetDriverUB, Start) +{ + driver->mStarted = false; + driver->mInited = true; + driver->mOptions.dontStartWorkers = true; + MOCKER_CPP_VIRTUAL(*driver, &NetDriverUBWithOob::DoStart).stubs().will(returnValue(0)); + EXPECT_EQ(driver->Start(), NN_OK); +} + +TEST_F(TestNetDriverUB, Stop) +{ + driver->mStarted = false; + EXPECT_NO_FATAL_FAILURE(driver->Stop()); + + driver->mStarted = true; + MOCKER_CPP_VIRTUAL(*driver, &NetDriverUBWithOob::DoStop).stubs().will(ignoreReturnValue()); + EXPECT_NO_FATAL_FAILURE(driver->Stop()); +} + +TEST_F(TestNetDriverUB, DoInitializeWithoutWorker) +{ + driver->mStartOobSvr = false; + EXPECT_EQ(driver->DoInitialize(), NN_OK); + + driver->mStartOobSvr = true; + MOCKER_CPP(&UBSHcomNetDriver::CreateListeners).stubs().will(returnValue(1)); + EXPECT_EQ(driver->DoInitialize(), NN_ERROR); +} + +TEST_F(TestNetDriverUB, DoUnInitialize) +{ + driver->mStarted = true; + EXPECT_NO_FATAL_FAILURE(driver->DoUnInitialize()); + + driver->mStarted = false; + EXPECT_NO_FATAL_FAILURE(driver->DoUnInitialize()); +} + +TEST_F(TestNetDriverUB, HandleCqEventParamErr) +{ + urma_async_event_t event{}; + urma_jfc_t jfc{}; + event.element.jfc = nullptr; + EXPECT_NO_FATAL_FAILURE(driver->HandleCqEvent(&event)); + + event.element.jfc = &jfc; + jfc.jfc_cfg.user_ctx = 1; + event.element.jfc->jfc_cfg.user_ctx = reinterpret_cast(&worker); + MOCKER_CPP(&UBWorker::Stop).stubs().will(returnValue(1)); + EXPECT_NO_FATAL_FAILURE(driver->HandleCqEvent(&event)); +} + +TEST_F(TestNetDriverUB, HandleCqEvent) +{ + urma_async_event_t event{}; + urma_jfc_t jfc{}; + + event.element.jfc = &jfc; + jfc.jfc_cfg.user_ctx = 1; + event.element.jfc->jfc_cfg.user_ctx = reinterpret_cast(&worker); + MOCKER_CPP(&UBWorker::Stop).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::DestroyEpInWorker).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBWorker::ReInitializeCQ).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBWorker::Start).stubs().will(returnValue(1)); + EXPECT_NO_FATAL_FAILURE(driver->HandleCqEvent(&event)); + EXPECT_NO_FATAL_FAILURE(driver->HandleCqEvent(&event)); +} + +TEST_F(TestNetDriverUB, ParseUrl) +{ + std::string badUrl = "127.0.0.1:9981"; + std::string testUrl = "uds://name"; + std::string testUrl2 = "tcp://127.0.0.1:0"; + std::string testUrl3 = "tcp://127.0.0.1:9981"; + std::string testUrl4 = "ubc://1111:2222:0000:0000:0000:0000:0100:0000:1"; + std::string testUrl5 = "ubc://1111:2222:0000:0000:0000:0000:0100:0000:888"; + NetDriverOobType type; + std::string ip{}; + uint16_t port = 0; + + EXPECT_EQ(driver->ParseUrl(badUrl, type, ip, port), NN_PARAM_INVALID); + EXPECT_EQ(driver->ParseUrl(testUrl, type, ip, port), SER_OK); + EXPECT_EQ(driver->ParseUrl(testUrl2, type, ip, port), NN_PARAM_INVALID); + EXPECT_EQ(driver->ParseUrl(testUrl3, type, ip, port), SER_OK); + EXPECT_EQ(driver->ParseUrl(testUrl4, type, ip, port), NN_PARAM_INVALID); + EXPECT_EQ(driver->ParseUrl(testUrl5, type, ip, port), SER_OK); + EXPECT_EQ(type, NetDriverOobType::NET_OOB_UB); + EXPECT_EQ(ip, "1111:2222:0000:0000:0000:0000:0100:0000"); + EXPECT_EQ(port, 888); +} + +TEST_F(TestNetDriverUB, SetNetDeviceIpMask) +{ + UBSHcomNetDriverOptions opt{}; + std::vector mask{}; + mask.emplace_back("1.2.3.4"); + mask.emplace_back("1.2.3.5"); + EXPECT_EQ(opt.SetNetDeviceIpMask(mask), true); +} + +TEST_F(TestNetDriverUB, SetNetDeviceIpGroup) +{ + UBSHcomNetDriverOptions opt{}; + std::vector ipGroup{}; + ipGroup.emplace_back("1.2.3.4"); + ipGroup.emplace_back("1.2.3.5"); + EXPECT_EQ(opt.SetNetDeviceIpGroup(ipGroup), true); +} + +TEST_F(TestNetDriverUB, SetWorkerGroupsInfo) +{ + UBSHcomNetDriverOptions opt{}; + std::vector workerGroups{}; + std::vector workerGroups2{}; + UBSHcomWorkerGroupInfo info1{}; + UBSHcomWorkerGroupInfo info2{}; + workerGroups.emplace_back(info1); + workerGroups.emplace_back(info2); + EXPECT_EQ(opt.SetWorkerGroupsInfo(workerGroups), true); + EXPECT_EQ(opt.SetWorkerGroupsInfo(workerGroups2), false); +} + +TEST_F(TestNetDriverUB, GetDeviceByEid) +{ + UBEId tmpEid{}; + MOCKER_CPP(UBDeviceHelper::Initialize).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(UBDeviceHelper::GetDeviceByEid).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(UBDeviceHelper::UnInitialize).stubs().then(ignoreReturnValue()); + + driver->mProtocol = UBSHcomNetDriverProtocol::UBC; + EXPECT_EQ(driver->GetDeviceByEid(tmpEid), 1); + EXPECT_EQ(driver->GetDeviceByEid(tmpEid), 1); + EXPECT_EQ(driver->GetDeviceByEid(tmpEid), 0); +} + +TEST_F(TestNetDriverUB, GetDeviceByName) +{ + UBEId tmpEid{}; + MOCKER_CPP(UBDeviceHelper::Initialize).stubs().will(returnValue(1)); + EXPECT_EQ(driver->GetDeviceByName(tmpEid), 1); +} + +TEST_F(TestNetDriverUB, DestroyMemoryRegion) +{ + UBMemoryRegion *mr = new (std::nothrow) UBMemoryRegion("name", nullptr, 0, 0, 0); + MOCKER_CPP_VIRTUAL(*mr, &UBMemoryRegion::UnInitialize).stubs(); + UBSHcomNetMemoryRegionPtr mrPtr = mr; + EXPECT_NO_FATAL_FAILURE(driver->DestroyMemoryRegion(mrPtr)); +} + +TEST_F(TestNetDriverUB, GetTseg) +{ + urma_target_seg_t *tseg = nullptr; + EXPECT_EQ(driver->GetTseg(0, tseg), static_cast(UB_PARAM_INVALID)); +} +} +} +#endif diff --git a/test/unit_test/transport/ub/test_net_driver_ub_oob_public_jetty.cpp b/test/unit_test/transport/ub/test_net_driver_ub_oob_public_jetty.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dab09d795af08251a4812dcb74f4eddfbbc9be6b --- /dev/null +++ b/test/unit_test/transport/ub/test_net_driver_ub_oob_public_jetty.cpp @@ -0,0 +1,1226 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED +#include +#include +#include +#include + +#include "net_monotonic.h" +#include "net_oob_ssl.h" +#include "net_ub_endpoint.h" +#include "ub_mr_fixed_buf.h" +#include "ub_worker.h" +#include "net_ub_driver_oob.h" +#include "net_oob_secure.h" +#include "ub_urma_wrapper_public_jetty.h" + +namespace ock { +namespace hcom { +class TestNetDriverUBPublicJetty : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + + std::string name = "driver-public-jetty"; + NetDriverUBWithOob *driver = nullptr; + UBPublicJetty *jetty = nullptr; + UBJetty *qp = nullptr; +}; + +void TestNetDriverUBPublicJetty::SetUp() +{ + driver = new (std::nothrow) NetDriverUBWithOob(name, false, UBC); + ASSERT_NE(driver, nullptr); + jetty = new (std::nothrow) UBPublicJetty(name, 0, nullptr, nullptr); + ASSERT_NE(jetty, nullptr); + qp = new (std::nothrow) UBJetty(name, 0, nullptr, nullptr); + ASSERT_NE(qp, nullptr); + qp->StoreExchangeInfo(new UBJettyExchangeInfo); +} + +void TestNetDriverUBPublicJetty::TearDown() +{ + if (driver != nullptr) { + delete driver; + driver = nullptr; + } + if (jetty != nullptr) { + delete jetty; + jetty = nullptr; + } + if (qp != nullptr) { + delete qp; + qp = nullptr; + } + GlobalMockObject::verify(); +} + +TEST_F(TestNetDriverUBPublicJetty, CreateUrmaListeners) +{ + UBPublicJetty *publicJetty = nullptr; + EXPECT_EQ(driver->CreateUrmaListeners(publicJetty), NN_INVALID_PARAM); + + UBSHcomNetOobListenerOptions opt{}; + opt.port = NN_NO4; + driver->mOobListenOptions.emplace_back(opt); + MOCKER_CPP(&NetDriverUBWithOob::CreatePublicJetty) + .stubs() + .with(outBound(jetty), any()) + .will(returnValue(1)) + .then(returnValue(0)); + MOCKER_CPP(&NetWorkerLB::AddWorkerGroups).stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(driver->CreateUrmaListeners(publicJetty), NN_ERROR); + EXPECT_EQ(driver->CreateUrmaListeners(publicJetty), NN_ERROR); + EXPECT_EQ(driver->CreateUrmaListeners(publicJetty), NN_OK); +} + +TEST_F(TestNetDriverUBPublicJetty, CreateUrmaListenersFailed) +{ + UBPublicJetty *publicJetty = nullptr; + UBSHcomNetOobListenerOptions opt{}; + opt.port = NN_NO2; + driver->mOobListenOptions.emplace_back(opt); + EXPECT_EQ(driver->CreateUrmaListeners(publicJetty), NN_ERROR); +} + +TEST_F(TestNetDriverUBPublicJetty, ConnectByPublicJetty) +{ + std::string oobIp("1.2.3.4"); + std::string payload("hello"); + UBSHcomNetEndpointPtr outEp = nullptr; + + MOCKER_CPP(&NetDriverUBWithOob::ClientCheckState).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::ConnectSyncEpByPublicJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::ConnectASyncEpByPublicJetty).stubs().will(returnValue(0)); + + EXPECT_EQ(driver->ConnectByPublicJetty(oobIp, NN_NO4, payload, outEp, 0, 0, 0), NN_ERROR); + EXPECT_EQ(driver->ConnectByPublicJetty(oobIp, NN_NO4, payload, outEp, NET_EP_SELF_POLLING, 0, 0), 0); + EXPECT_EQ(driver->ConnectByPublicJetty(oobIp, NN_NO4, payload, outEp, 0, 0, 0), 0); +} + +TEST_F(TestNetDriverUBPublicJetty, ConnectByPublicJettyFail) +{ + std::string oobIp("1.2.3.4"); + std::string payload("hello"); + UBSHcomNetEndpointPtr outEp = nullptr; + MOCKER_CPP(&NetDriverUBWithOob::ClientCheckState).stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(driver->ConnectByPublicJetty(oobIp, NN_NO2, payload, outEp, 0, 0, 0), NN_ERROR); +} + +TEST_F(TestNetDriverUBPublicJetty, ClientCheckState) +{ + std::string payload("hello"); + + driver->mInited.store(false); + EXPECT_EQ(driver->ClientCheckState(payload), NN_NOT_INITIALIZED); + + driver->mInited.store(true); + driver->mStarted = false; + EXPECT_EQ(driver->ClientCheckState(payload), NN_ERROR); + + driver->mStarted = true; + std::string payload2(NN_NO2048, 'a'); + EXPECT_EQ(driver->ClientCheckState(payload2), NN_INVALID_PARAM); + EXPECT_EQ(driver->ClientCheckState(payload), NN_OK); +} + +TEST_F(TestNetDriverUBPublicJetty, CreatePublicJetty) +{ + int err = 1; + UBPublicJetty *publicJetty = nullptr; + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(err)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::InitializePublicJetty).stubs().will(returnValue(err)).then(returnValue(0)); + + EXPECT_EQ(driver->CreatePublicJetty(publicJetty, 0), err); + EXPECT_EQ(driver->CreatePublicJetty(publicJetty, 0), NN_ERROR); + EXPECT_EQ(driver->CreatePublicJetty(publicJetty, 0), NN_OK); + + if (publicJetty != nullptr) { + delete publicJetty; + publicJetty = nullptr; + } +} + +TEST_F(TestNetDriverUBPublicJetty, PublicJettyConnect1) +{ + std::string oobIp("1.2.3.4"); + UBPublicJetty *clientPublicJetty = nullptr; + + MOCKER_CPP(&NetDriverUBWithOob::CreatePublicJetty) + .stubs() + .will(returnValue(1)); + + EXPECT_EQ(driver->PublicJettyConnect(oobIp, 1, clientPublicJetty), NN_ERROR); +} + +TEST_F(TestNetDriverUBPublicJetty, PublicJettyConnect2) +{ + std::string oobIp("1.2.3.4"); + UBPublicJetty *clientPublicJetty = nullptr; + + UBPublicJetty* tmp = new (std::nothrow) UBPublicJetty(name, 0, nullptr, nullptr); + MOCKER_CPP(&NetDriverUBWithOob::CreatePublicJetty) + .stubs() + .with(outBound(tmp), any()) + .will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::StartPublicJetty).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->PublicJettyConnect(oobIp, 1, clientPublicJetty), NN_ERROR); +} + +TEST_F(TestNetDriverUBPublicJetty, PublicJettyConnect3) +{ + std::string oobIp("1.2.3.4"); + UBPublicJetty *clientPublicJetty = nullptr; + + UBPublicJetty* tmp = new (std::nothrow) UBPublicJetty(name, 0, nullptr, nullptr); + MOCKER_CPP(&NetDriverUBWithOob::CreatePublicJetty) + .stubs() + .with(outBound(tmp), any()) + .will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::StartPublicJetty).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::StrToEid).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->PublicJettyConnect(oobIp, 1, clientPublicJetty), NN_ERROR); +} + +TEST_F(TestNetDriverUBPublicJetty, PublicJettyConnect4) +{ + std::string oobIp("1.2.3.4"); + UBPublicJetty *clientPublicJetty = nullptr; + + UBPublicJetty* tmp = new (std::nothrow) UBPublicJetty(name, 0, nullptr, nullptr); + MOCKER_CPP(&NetDriverUBWithOob::CreatePublicJetty) + .stubs() + .with(outBound(tmp), any()) + .will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::StartPublicJetty).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::StrToEid).stubs().will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::ImportPublicJetty).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->PublicJettyConnect(oobIp, 1, clientPublicJetty), NN_ERROR); +} + +TEST_F(TestNetDriverUBPublicJetty, PublicJettyConnect5) +{ + std::string oobIp("1.2.3.4"); + UBPublicJetty *clientPublicJetty = nullptr; + + UBPublicJetty* tmp = new (std::nothrow) UBPublicJetty(name, 0, nullptr, nullptr); + MOCKER_CPP(&NetDriverUBWithOob::CreatePublicJetty) + .stubs() + .with(outBound(tmp), any()) + .will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::StartPublicJetty).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::StrToEid).stubs().will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::ImportPublicJetty).stubs().will(returnValue(0)); + + EXPECT_EQ(driver->PublicJettyConnect(oobIp, 1, clientPublicJetty), NN_OK); + delete tmp; +} + +TEST_F(TestNetDriverUBPublicJetty, ClientSelectWorker) +{ + NetWorkerLB *lb = new NetWorkerLB(name, NET_ROUND_ROBIN, 1); + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + UBWorker *outWorker = nullptr; + + driver->mClientLb = lb; + driver->mWorkers.emplace_back(worker); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(false)).then(returnValue(true)); + + EXPECT_EQ(driver->ClientSelectWorker(outWorker, 0, 0), NN_ERROR); + EXPECT_EQ(driver->ClientSelectWorker(outWorker, 0, 0), NN_OK); + if (lb != nullptr) { + delete lb; + lb = nullptr; + } + if (worker != nullptr) { + delete worker; + worker = nullptr; + } +} + +TEST_F(TestNetDriverUBPublicJetty, ClientSendConnReq) +{ + std::string payload("hello"); + + MOCKER_CPP(&NetDriverUBWithOob::FillExchMsg).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::PollingCompletion).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::GetJettyId).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->ClientSendConnReq(payload, 0, 0, nullptr, qp, nullptr), UB_PARAM_INVALID); + EXPECT_EQ(driver->ClientSendConnReq(payload, 0, 0, jetty, qp, nullptr), NN_ERROR); + EXPECT_EQ(driver->ClientSendConnReq(payload, 0, 0, jetty, qp, nullptr), NN_ERROR); + EXPECT_EQ(driver->ClientSendConnReq(payload, 0, 0, jetty, qp, nullptr), NN_ERROR); + EXPECT_EQ(driver->ClientSendConnReq(payload, 0, 0, jetty, qp, nullptr), NN_OK); +} + +TEST_F(TestNetDriverUBPublicJetty, CheckServerACK) +{ + JettyConnResp exchangeMsg{}; + exchangeMsg.connResp = MAGIC_MISMATCH; + EXPECT_EQ(driver->CheckServerACK(exchangeMsg), NN_CONNECT_REFUSED); + exchangeMsg.connResp = WORKER_GRPNO_MISMATCH; + EXPECT_EQ(driver->CheckServerACK(exchangeMsg), NN_CONNECT_REFUSED); + exchangeMsg.connResp = PROTOCOL_MISMATCH; + EXPECT_EQ(driver->CheckServerACK(exchangeMsg), NN_CONNECT_PROTOCOL_MISMATCH); + exchangeMsg.connResp = SERVER_INTERNAL_ERROR; + EXPECT_EQ(driver->CheckServerACK(exchangeMsg), NN_ERROR); + exchangeMsg.connResp = VERSION_MISMATCH; + EXPECT_EQ(driver->CheckServerACK(exchangeMsg), NN_CONNECT_REFUSED); + exchangeMsg.connResp = TLS_VERSION_MISMATCH; + EXPECT_EQ(driver->CheckServerACK(exchangeMsg), NN_CONNECT_REFUSED); + exchangeMsg.connResp = OK; + EXPECT_EQ(driver->CheckServerACK(exchangeMsg), NN_OK); + exchangeMsg.connResp = OK_PROTOCOL_TCP; + EXPECT_EQ(driver->CheckServerACK(exchangeMsg), NN_ERROR); +} + +TEST_F(TestNetDriverUBPublicJetty, ClientEstablishConnOnReply) +{ + UBJettyExchangeInfo info{}; + MOCKER_CPP(&UBPublicJetty::Receive).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::CheckServerACK).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::SetBondingInfo).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::ImportPublicJetty).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::PollingCompletion).stubs().will(returnValue(0)); + + EXPECT_EQ(driver->ClientEstablishConnOnReply(nullptr, qp, info), UB_PARAM_INVALID); + EXPECT_EQ(driver->ClientEstablishConnOnReply(jetty, qp, info), NN_ERROR); + EXPECT_EQ(driver->ClientEstablishConnOnReply(jetty, qp, info), NN_ERROR); + EXPECT_EQ(driver->ClientEstablishConnOnReply(jetty, qp, info), NN_ERROR); + EXPECT_EQ(driver->ClientEstablishConnOnReply(jetty, qp, info), NN_ERROR); + EXPECT_EQ(driver->ClientEstablishConnOnReply(jetty, qp, info), NN_ERROR); + EXPECT_EQ(driver->ClientEstablishConnOnReply(jetty, qp, info), NN_OK); + + MOCKER_CPP(&operator new, void *(*) (size_t, const std::nothrow_t &)) + .stubs() + .will(returnValue(static_cast(nullptr))); + EXPECT_EQ(driver->ClientEstablishConnOnReply(jetty, qp, info), NN_MALLOC_FAILED); +} + +TEST_F(TestNetDriverUBPublicJetty, ClientCreateJetty1) +{ + UBJetty *outQp = nullptr; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + + EXPECT_EQ(driver->ClientCreateJetty(outQp, nullptr), NN_PARAM_INVALID); +} + +TEST_F(TestNetDriverUBPublicJetty, ClientCreateJetty2) +{ + UBJetty *outQp = nullptr; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + + MOCKER_CPP(&UBWorker::CreateQP).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->ClientCreateJetty(outQp, worker), NN_ERROR); + + if (worker != nullptr) { + delete worker; + worker = nullptr; + } +} + +TEST_F(TestNetDriverUBPublicJetty, ClientCreateJetty3) +{ + UBJetty *outQp = nullptr; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + UBJetty* tmp = new (std::nothrow) UBJetty(name, 0, nullptr, nullptr); + + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(tmp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->ClientCreateJetty(outQp, worker), NN_ERROR); + + if (worker != nullptr) { + delete worker; + worker = nullptr; + } +} + +TEST_F(TestNetDriverUBPublicJetty, ClientCreateJetty4) +{ + UBJetty *outQp = nullptr; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + UBJetty* tmp = new (std::nothrow) UBJetty(name, 0, nullptr, nullptr); + + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(tmp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + + EXPECT_EQ(driver->ClientCreateJetty(outQp, worker), NN_OK); + + delete tmp; + if (worker != nullptr) { + delete worker; + worker = nullptr; + } +} + +TEST_F(TestNetDriverUBPublicJetty, CheckMagicAndProtocol) +{ + JettyConnResp exchangeMsg{}; + JettyConnHeader exchangeInfo{}; + + exchangeInfo.ConnectHeader.magic = NN_NO2; + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::Protocol).stubs().will(returnValue(UBC)); + EXPECT_EQ(driver->CheckMagicAndProtocol(exchangeMsg, &exchangeInfo, jetty), NN_ERROR); + + exchangeInfo.ConnectHeader.magic = NN_NO256; + exchangeInfo.ConnectHeader.protocol = UBC; + EXPECT_EQ(driver->CheckMagicAndProtocol(exchangeMsg, &exchangeInfo, jetty), NN_OK); +} + +TEST_F(TestNetDriverUBPublicJetty, FillExchMsg) +{ + JettyConnHeader *exchangeInfo = (JettyConnHeader *)malloc(sizeof(JettyConnHeader) + NN_NO6); + std::string payload("hello"); + + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::FillBondingMsg).stubs().will(returnValue(0)); + MOCKER_CPP(memcpy_s).stubs().will(returnValue(1)).then(returnValue(0)); + + EXPECT_EQ(driver->FillExchMsg(exchangeInfo, qp, payload, 0, jetty), 1); + EXPECT_EQ(driver->FillExchMsg(exchangeInfo, qp, payload, 0, jetty), NN_ERROR); + EXPECT_EQ(driver->FillExchMsg(exchangeInfo, qp, payload, 0, jetty), NN_OK); + free(exchangeInfo); +} + +TEST_F(TestNetDriverUBPublicJetty, FillExchMsgHeartBeat) +{ + JettyConnHeader *exchangeInfo = (JettyConnHeader *)malloc(sizeof(JettyConnHeader) + NN_NO6); + std::string payload("hello"); + NetHeartbeat *hb = new (std::nothrow) NetHeartbeat(nullptr, 0, 0); + driver->mHeartBeat = hb; + UBEId eid{}; + UBContext *ctx = new (std::nothrow) UBContext("name", eid); + qp->mUBContext = ctx; + qp->mUBContext->protocol = UBSHcomNetDriverProtocol::UBC; + + MOCKER_CPP(&UBJetty::CreateHBMemoryRegion).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetRemoteHbInfo).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::FillBondingMsg).stubs().will(returnValue(0)); + MOCKER_CPP(memcpy_s).stubs().will(returnValue(0)); + EXPECT_EQ(driver->FillExchMsg(exchangeInfo, qp, payload, 0, jetty), NN_OK); + free(exchangeInfo); + if (hb != nullptr) { + delete hb; + } + driver->mHeartBeat = nullptr; +} + +TEST_F(TestNetDriverUBPublicJetty, FillExchMsgHeartBeatErr) +{ + JettyConnHeader *exchangeInfo = (JettyConnHeader *)malloc(sizeof(JettyConnHeader) + NN_NO6); + std::string payload("hello"); + NetHeartbeat *hb = new (std::nothrow) NetHeartbeat(nullptr, 0, 0); + driver->mHeartBeat = hb; + UBEId eid{}; + UBContext *ctx = new (std::nothrow) UBContext("name", eid); + qp->mUBContext = ctx; + qp->mUBContext->protocol = UBSHcomNetDriverProtocol::UBC; + + MOCKER_CPP(&UBJetty::CreateHBMemoryRegion).stubs().will(returnValue(0)).then(returnValue(1)); + MOCKER_CPP(&UBPublicJetty::FillBondingMsg).stubs().will(returnValue(0)); + + EXPECT_EQ(driver->FillExchMsg(exchangeInfo, qp, payload, 0, jetty), 1); + EXPECT_EQ(driver->FillExchMsg(exchangeInfo, qp, payload, 0, jetty), 1); + free(exchangeInfo); + if (hb != nullptr) { + delete hb; + } + driver->mHeartBeat = nullptr; +} + +TEST_F(TestNetDriverUBPublicJetty, PrePostReceiveOnConnection) +{ + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + urma_target_seg_t *tseg = nullptr; + + MOCKER_CPP(&UBJetty::GetFreeBufferN).stubs().will(returnValue(false)).then(returnValue(true)); + MOCKER_CPP(&UBWorker::PostReceive).stubs().will(returnValue(1)); + MOCKER_CPP(&UBJetty::ReturnBuffer).stubs().will(returnValue(true)); + MOCKER_CPP(&UBJetty::GetMemorySeg).stubs().will(returnValue(tseg)); + + EXPECT_EQ(driver->PrePostReceiveOnConnection(qp, nullptr), UB_PARAM_INVALID); + EXPECT_EQ(driver->PrePostReceiveOnConnection(qp, worker), NN_ERROR); + EXPECT_EQ(driver->PrePostReceiveOnConnection(qp, worker), 1); + if (worker != nullptr) { + delete worker; + worker = nullptr; + } +} + +TEST_F(TestNetDriverUBPublicJetty, ServerSelectWorker) +{ + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + UBWorker *outWorker = nullptr; + JettyConnResp exchangeMsg{}; + + driver->mPublicJetty = jetty; + driver->mWorkers.emplace_back(worker); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(false)).then(returnValue(true)); + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(false)).then(returnValue(true)); + NetWorkerLBPtr lbOne = nullptr; + NetWorkerLBPtr lb = new NetWorkerLB(name, NET_ROUND_ROBIN, 1); + MOCKER_CPP(&UBPublicJetty::LoadBalancer).stubs().will(returnValue(lbOne)).then(returnValue(lb)); + + EXPECT_EQ(driver->ServerSelectWorker(outWorker, exchangeMsg, 0, jetty), NN_ERROR); + jetty->mWorkerLb = lb; + lb->IncreaseRef(); + EXPECT_EQ(driver->ServerSelectWorker(outWorker, exchangeMsg, 0, jetty), NN_ERROR); + lb->IncreaseRef(); + EXPECT_EQ(driver->ServerSelectWorker(outWorker, exchangeMsg, 0, jetty), NN_ERROR); + lb->IncreaseRef(); + EXPECT_EQ(driver->ServerSelectWorker(outWorker, exchangeMsg, 0, jetty), NN_OK); + + driver->mPublicJetty = nullptr; + if (worker != nullptr) { + delete worker; + worker = nullptr; + } +} + +TEST_F(TestNetDriverUBPublicJetty, ServerCreateJetty1) +{ + UBJetty *outQp = nullptr; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + JettyConnResp exchangeMsg{}; + JettyConnHeader info{}; + + MOCKER_CPP(&UBWorker::CreateQP).stubs().will(returnValue(1)); + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(0)); + EXPECT_EQ(driver->ServerCreateJetty(outQp, worker, exchangeMsg, &info, jetty), NN_ERROR); + + if (worker != nullptr) { + delete worker; + worker = nullptr; + } +} + +TEST_F(TestNetDriverUBPublicJetty, ServerCreateJetty2) +{ + UBJetty *outQp = nullptr; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + JettyConnResp exchangeMsg{}; + JettyConnHeader info{}; + + UBJetty* tmp = new (std::nothrow) UBJetty(name, 0, nullptr, nullptr); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(tmp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(1)); + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(0)); + + EXPECT_EQ(driver->ServerCreateJetty(outQp, worker, exchangeMsg, &info, jetty), NN_ERROR); + if (worker != nullptr) { + delete worker; + worker = nullptr; + } +} + +TEST_F(TestNetDriverUBPublicJetty, ServerCreateJetty3) +{ + UBJetty *outQp = nullptr; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + JettyConnResp exchangeMsg{}; + JettyConnHeader info{}; + + UBJetty* tmp = new (std::nothrow) UBJetty(name, 0, nullptr, nullptr); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(tmp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->ServerCreateJetty(outQp, worker, exchangeMsg, &info, jetty), 1); + if (worker != nullptr) { + delete worker; + worker = nullptr; + } +} + +TEST_F(TestNetDriverUBPublicJetty, ServerCreateJetty4) +{ + UBJetty *outQp = nullptr; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + JettyConnResp exchangeMsg{}; + JettyConnHeader info{}; + + UBJetty* tmp = new (std::nothrow) UBJetty(name, 0, nullptr, nullptr); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(tmp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(0)); + + MOCKER_CPP(&operator new, void *(*) (size_t, const std::nothrow_t &)) + .stubs() + .will(returnValue(static_cast(nullptr))); + EXPECT_EQ(driver->ServerCreateJetty(outQp, worker, exchangeMsg, &info, jetty), NN_MALLOC_FAILED); + + if (worker != nullptr) { + delete worker; + worker = nullptr; + } +} + +TEST_F(TestNetDriverUBPublicJetty, ServerCreateJetty5) +{ + UBJetty *outQp = nullptr; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + JettyConnResp exchangeMsg{}; + JettyConnHeader info{}; + + UBJetty* tmp = new (std::nothrow) UBJetty(name, 0, nullptr, nullptr); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(tmp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(0)); + + EXPECT_EQ(driver->ServerCreateJetty(outQp, worker, exchangeMsg, &info, jetty), NN_OK); + delete tmp; + if (worker != nullptr) { + delete worker; + worker = nullptr; + } +} + +TEST_F(TestNetDriverUBPublicJetty, ServerReplyMsg) +{ + JettyConnResp exchangeMsg{}; + std::string name = "test-public-jetty"; + UBEId eid{}; + UBContext *ctx = new (std::nothrow) UBContext(name, eid); + ASSERT_NE(ctx, nullptr); + jetty->mUBContext = ctx; + uint32_t jettyId = 0; + + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::PollingCompletion).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::GetJettyId).stubs().will(returnValue(jettyId)); + MOCKER_CPP(&UBPublicJetty::FillBondingMsg).stubs().will(returnValue(0)); + MOCKER_CPP(&UBDeviceHelper::UnInitialize).stubs().will(ignoreReturnValue()); + + EXPECT_EQ(driver->ServerReplyMsg(nullptr, exchangeMsg, nullptr), UB_PARAM_INVALID); + EXPECT_EQ(driver->ServerReplyMsg(qp, exchangeMsg, jetty), 1); + EXPECT_EQ(driver->ServerReplyMsg(qp, exchangeMsg, jetty), NN_ERROR); + EXPECT_EQ(driver->ServerReplyMsg(qp, exchangeMsg, jetty), NN_ERROR); + EXPECT_EQ(driver->ServerReplyMsg(qp, exchangeMsg, jetty), NN_OK); + jetty->mUBContext = nullptr; + delete ctx; +} + +TEST_F(TestNetDriverUBPublicJetty, ServerReplyMsgHeartBeat) +{ + JettyConnResp exchangeMsg{}; + std::string name = "test-public-jetty"; + UBEId eid{}; + UBContext *ctx = new (std::nothrow) UBContext(name, eid); + ASSERT_NE(ctx, nullptr); + jetty->mUBContext = ctx; + uint32_t jettyId = 0; + NetHeartbeat *hb = new (std::nothrow) NetHeartbeat(nullptr, 0, 0); + ASSERT_NE(hb, nullptr); + driver->mHeartBeat = hb; + + UBEId eid2{}; + UBContext *ctx2 = new (std::nothrow) UBContext("name", eid2); + qp->mUBContext = ctx2; + qp->mUBContext->protocol = UBSHcomNetDriverProtocol::UBC; + + MOCKER_CPP(&UBJetty::CreateHBMemoryRegion).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetRemoteHbInfo).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBPublicJetty::GetJettyId).stubs().will(returnValue(jettyId)); + MOCKER_CPP(&UBPublicJetty::FillBondingMsg).stubs().will(returnValue(0)); + MOCKER_CPP(&UBDeviceHelper::UnInitialize).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->ServerReplyMsg(qp, exchangeMsg, jetty), 1); + + if (hb != nullptr) { + delete hb; + } + driver->mHeartBeat = nullptr; + jetty->mUBContext = nullptr; + delete ctx; +} + +TEST_F(TestNetDriverUBPublicJetty, ServerReplyMsgHeartBeatErr) +{ + JettyConnResp exchangeMsg{}; + std::string name = "test-public-jetty"; + UBEId eid{}; + UBContext *ctx = new (std::nothrow) UBContext(name, eid); + ASSERT_NE(ctx, nullptr); + jetty->mUBContext = ctx; + uint32_t jettyId = 0; + NetHeartbeat *hb = new (std::nothrow) NetHeartbeat(nullptr, 0, 0); + ASSERT_NE(hb, nullptr); + driver->mHeartBeat = hb; + UBEId eid2{}; + UBContext *ctx2 = new (std::nothrow) UBContext("name", eid2); + qp->mUBContext = ctx2; + qp->mUBContext->protocol = UBSHcomNetDriverProtocol::UBC; + + MOCKER_CPP(&UBJetty::CreateHBMemoryRegion).stubs().will(returnValue(0)).then(returnValue(1)); + MOCKER_CPP(&UBJetty::GetRemoteHbInfo).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBPublicJetty::GetJettyId).stubs().will(returnValue(jettyId)); + MOCKER_CPP(&UBPublicJetty::FillBondingMsg).stubs().will(returnValue(0)); + MOCKER_CPP(&UBDeviceHelper::UnInitialize).stubs().will(ignoreReturnValue()); + + EXPECT_EQ(driver->ServerReplyMsg(qp, exchangeMsg, jetty), 1); + EXPECT_EQ(driver->ServerReplyMsg(qp, exchangeMsg, jetty), 1); + + if (hb != nullptr) { + delete hb; + } + driver->mHeartBeat = nullptr; + jetty->mUBContext = nullptr; + delete ctx; +} + +TEST_F(TestNetDriverUBPublicJetty, ConnectSyncEpByPublicJetty) +{ + std::string oobIp("1.2.3.4"); + std::string payload("hello"); + UBSHcomNetEndpointPtr outEp = nullptr; + + MOCKER_CPP(&NetDriverUBWithOob::PublicJettyConnect).stubs().will(returnValue(1)).then(returnValue(0)); + + MOCKER_CPP(&NetDriverUBWithOob::ClientSyncEpCreateJetty) + .stubs() + .with(outBound(qp), any(), any()) + .will(returnValue(1)) + .then(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::ClientSendConnReq).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::ClientEstablishConnOnReply).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::PrePostReceiveOnSyncEp).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::ClientSyncEpSetInfo).stubs().will(ignoreReturnValue()); + + qp->IncreaseRef(); + EXPECT_EQ(driver->ConnectSyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); + EXPECT_EQ(driver->ConnectSyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); + EXPECT_EQ(driver->ConnectSyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); + EXPECT_EQ(driver->ConnectSyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); + driver->IncreaseRef(); + EXPECT_EQ(driver->ConnectSyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); + driver->IncreaseRef(); + EXPECT_EQ(driver->ConnectSyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); +} + +TEST_F(TestNetDriverUBPublicJetty, ClientSyncEpSetInfo) +{ + UBSHcomNetWorkerIndex workerIndex{}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetUBSyncEndpoint(0, nullptr, nullptr, NN_NO4, nullptr, workerIndex); + reinterpret_cast(ep.Get())->mJetty = qp; + UBSHcomNetEndpointPtr outEp = nullptr; + + EXPECT_NO_FATAL_FAILURE(driver->ClientSyncEpSetInfo(ep, qp, outEp)); + reinterpret_cast(ep.Get())->mJetty = nullptr; +} + +TEST_F(TestNetDriverUBPublicJetty, PrePostReceiveOnSyncEp) +{ + urma_target_seg_t *tseg = nullptr; + UBSHcomNetWorkerIndex workerIndex{}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetUBSyncEndpoint(0, nullptr, nullptr, NN_NO4, nullptr, workerIndex); + + MOCKER_CPP(&UBJetty::GetFreeBufferN).stubs().will(returnValue(false)).then(returnValue(true)); + MOCKER_CPP(&UBJetty::ReturnBuffer).stubs().will(returnValue(true)); + MOCKER_CPP(&UBJetty::GetMemorySeg).stubs().will(returnValue(tseg)); + MOCKER_CPP(&NetUBSyncEndpoint::PostReceive).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->PrePostReceiveOnSyncEp(ep, NN_NO4, qp), NN_ERROR); + EXPECT_EQ(driver->PrePostReceiveOnSyncEp(ep, NN_NO4, qp), 1); +} + +TEST_F(TestNetDriverUBPublicJetty, ClientSyncEpCreateJettyFail) +{ + UBJetty *outQp = nullptr; + UBJfc *outCq = nullptr; + UBJfc *cq = new (std::nothrow) UBJfc(name, nullptr); + ASSERT_NE(cq, nullptr); + UBJetty *qpIn = new (std::nothrow) UBJetty(name, 0, nullptr, nullptr); + ASSERT_NE(qpIn, nullptr); + MOCKER_CPP(&NetUBSyncEndpoint::CreateResources) + .stubs() + .with(any(), any(), any(), any(), outBound(qpIn), outBound(cq)) + .will(returnValue(1)); + EXPECT_EQ(driver->ClientSyncEpCreateJetty(outQp, outCq, UB_BUSY_POLLING), 1); + delete cq; + delete qpIn; +} + +TEST_F(TestNetDriverUBPublicJetty, ClientSyncEpCreateJettyFailTwo) +{ + UBJetty *outQp = nullptr; + UBJfc *outCq = nullptr; + UBJfc *cq = new (std::nothrow) UBJfc(name, nullptr); + ASSERT_NE(cq, nullptr); + UBJetty *qpIn = new (std::nothrow) UBJetty(name, 0, nullptr, nullptr); + ASSERT_NE(qpIn, nullptr); + MOCKER_CPP(&NetUBSyncEndpoint::CreateResources) + .stubs() + .with(any(), any(), any(), any(), outBound(qpIn), outBound(cq)) + .will(returnValue(0)); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(1)); + EXPECT_EQ(driver->ClientSyncEpCreateJetty(outQp, outCq, UB_BUSY_POLLING), NN_ERROR); +} + +TEST_F(TestNetDriverUBPublicJetty, ClientSyncEpCreateJettyFailThree) +{ + UBJetty *outQp = nullptr; + UBJfc *outCq = nullptr; + UBJfc *cq = new (std::nothrow) UBJfc(name, nullptr); + ASSERT_NE(cq, nullptr); + UBJetty *qpIn = new (std::nothrow) UBJetty(name, 0, nullptr, nullptr); + ASSERT_NE(qpIn, nullptr); + MOCKER_CPP(&NetUBSyncEndpoint::CreateResources) + .stubs() + .with(any(), any(), any(), any(), outBound(qpIn), outBound(cq)) + .will(returnValue(0)); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(1)); + MOCKER_CPP(&UBJetty::UnInitialize).stubs().will(returnValue(0)); + EXPECT_EQ(driver->ClientSyncEpCreateJetty(outQp, outCq, UB_BUSY_POLLING), NN_ERROR); +} + +TEST_F(TestNetDriverUBPublicJetty, ClientSyncEpCreateJettySuccess) +{ + UBJetty *outQp = nullptr; + UBJfc *outCq = nullptr; + UBJfc *cq = new (std::nothrow) UBJfc(name, nullptr); + ASSERT_NE(cq, nullptr); + UBJetty *qpIn = new (std::nothrow) UBJetty(name, 0, nullptr, nullptr); + ASSERT_NE(qpIn, nullptr); + MOCKER_CPP(&NetUBSyncEndpoint::CreateResources) + .stubs() + .with(any(), any(), any(), any(), outBound(qpIn), outBound(cq)) + .will(returnValue(0)); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + EXPECT_EQ(driver->ClientSyncEpCreateJetty(outQp, outCq, UB_BUSY_POLLING), NN_OK); + delete cq; + delete qpIn; +} + +TEST_F(TestNetDriverUBPublicJetty, ClearJettyResource) +{ + UBOpContextInfo ctxInfo{}; + qp->mCtxPosted.next = &ctxInfo; + MOCKER_CPP(&NetDriverUBWithOob::ProcessErrorNewRequest).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBJetty::Stop).stubs().will(returnValue(0)); + EXPECT_NO_FATAL_FAILURE(driver->ClearJettyResource(qp)); +} + +TEST_F(TestNetDriverUBPublicJetty, ClientCreateEpRecvFail) +{ + UBSHcomNetEndpointPtr outEp = nullptr; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + ASSERT_NE(worker, nullptr); + worker->IncreaseRef(); + UBJettyExchangeInfo info{}; + + MOCKER_CPP(&UBPublicJetty::Receive).stubs().will(returnValue(1)); + worker->IncreaseRef(); + qp->IncreaseRef(); + driver->IncreaseRef(); + // auto free by std::unique_ptr + UBJettyExchangeInfo *exchInfo = new UBJettyExchangeInfo(); + qp->StoreExchangeInfo(exchInfo); + EXPECT_EQ(driver->ClientCreateEp(outEp, 0, qp, worker, info, jetty), NN_ERROR); + if (worker != nullptr) { + delete worker; + } +} + +TEST_F(TestNetDriverUBPublicJetty, ClientCreateEpSendFail) +{ + UBSHcomNetEndpointPtr outEp = nullptr; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + ASSERT_NE(worker, nullptr); + worker->IncreaseRef(); + UBJettyExchangeInfo info{}; + + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::PollingCompletion).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::Receive).stubs().will(returnValue(1)).then(returnValue(0)); + worker->IncreaseRef(); + qp->IncreaseRef(); + driver->IncreaseRef(); + // auto free by std::unique_ptr + UBJettyExchangeInfo *exchInfo = new UBJettyExchangeInfo(); + qp->StoreExchangeInfo(exchInfo); + EXPECT_EQ(driver->ClientCreateEp(outEp, 0, qp, worker, info, jetty), NN_ERROR); + + worker->IncreaseRef(); + qp->IncreaseRef(); + driver->IncreaseRef(); + // auto free by std::unique_ptr + exchInfo = new UBJettyExchangeInfo(); + qp->StoreExchangeInfo(exchInfo); + EXPECT_EQ(driver->ClientCreateEp(outEp, 0, qp, worker, info, jetty), NN_ERROR); + + worker->IncreaseRef(); + qp->IncreaseRef(); + driver->IncreaseRef(); + // auto free by std::unique_ptr + exchInfo = new UBJettyExchangeInfo(); + qp->StoreExchangeInfo(exchInfo); + EXPECT_EQ(driver->ClientCreateEp(outEp, 0, qp, worker, info, jetty), NN_ERROR); + + worker->IncreaseRef(); + qp->IncreaseRef(); + driver->IncreaseRef(); + // auto free by std::unique_ptr + exchInfo = new UBJettyExchangeInfo(); + qp->StoreExchangeInfo(exchInfo); + EXPECT_EQ(driver->ClientCreateEp(outEp, 0, qp, worker, info, jetty), NN_ERROR); + if (worker != nullptr) { + delete worker; + } +} + +int MockNewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + if (payload == "test-handshake") { + return 1; + } + return 0; +} + +TEST_F(TestNetDriverUBPublicJetty, ServerCreateEpSendFail) +{ + UBJettyExchangeInfo info = qp->GetExchangeInfo(); + JettyConnHeader exchangeInfo{}; + exchangeInfo.payloadLen = 1; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + ASSERT_NE(worker, nullptr); + + driver->RegisterNewEPHandler( + std::bind(&MockNewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(1)); + MOCKER_CPP(&NetDriverUBWithOob::ServerHandshake).stubs().will(returnValue(0)); + EXPECT_EQ(driver->ServerCreateEp(info, qp, nullptr, &exchangeInfo, jetty), NN_PARAM_INVALID); + worker->IncreaseRef(); + qp->IncreaseRef(); + driver->IncreaseRef(); + EXPECT_EQ(driver->ServerCreateEp(info, qp, worker, &exchangeInfo, jetty), 0); + if (worker != nullptr) { + delete worker; + } +} + +TEST_F(TestNetDriverUBPublicJetty, ServerCreateEpSuccess) +{ + UBJettyExchangeInfo info = qp->GetExchangeInfo(); + JettyConnHeader exchangeInfo{}; + exchangeInfo.payloadLen = 1; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + UBWorker *worker = new UBWorker(name, 0, workerOptions, memPool, sglMemPool); + ASSERT_NE(worker, nullptr); + + driver->RegisterNewEPHandler( + std::bind(&MockNewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::ServerHandshake).stubs().will(returnValue(0)); + worker->IncreaseRef(); + qp->IncreaseRef(); + driver->IncreaseRef(); + EXPECT_EQ(driver->ServerCreateEp(info, qp, worker, &exchangeInfo, jetty), NN_OK); + if (worker != nullptr) { + delete worker; + } +} + +TEST_F(TestNetDriverUBPublicJetty, ServerEstablishCtrlConn) +{ + JettyConnHeader exchangeInfo{}; + MOCKER_CPP(&UBPublicJetty::StartPublicJetty).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::SetBondingInfo).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::ImportPublicJetty).stubs().will(returnValue(1)).then(returnValue(0)); + + EXPECT_EQ(driver->ServerEstablishCtrlConn(nullptr, jetty), NN_PARAM_INVALID); + EXPECT_EQ(driver->ServerEstablishCtrlConn(&exchangeInfo, jetty), NN_ERROR); + EXPECT_EQ(driver->ServerEstablishCtrlConn(&exchangeInfo, jetty), NN_ERROR); + EXPECT_EQ(driver->ServerEstablishCtrlConn(&exchangeInfo, jetty), NN_ERROR); + EXPECT_EQ(driver->ServerEstablishCtrlConn(&exchangeInfo, jetty), NN_OK); +} + +TEST_F(TestNetDriverUBPublicJetty, ServerHandshakeAckFail) +{ + std::string payload = "hello"; + std::string eidAndPort = "4245:4944:0000:0000:0000:0000:0100:0000"; + + UBSHcomNetWorkerIndex workerIndex{}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetUBSyncEndpoint(0, nullptr, nullptr, NN_NO4, nullptr, workerIndex); + driver->RegisterNewEPHandler( + std::bind(&MockNewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + MOCKER_CPP(&UBPublicJetty::Receive).stubs().will(returnValue(1)).then(returnValue(0)); + + EXPECT_EQ(driver->ServerHandshake(ep, payload, eidAndPort, jetty), NN_ERROR); + EXPECT_EQ(driver->ServerHandshake(ep, payload, eidAndPort, jetty), NN_ERROR); +} + +template UResult MockReceive(void *buf, uint32_t size) +{ + int8_t *data = reinterpret_cast(buf); + *data = Value; + return UB_OK; +} + +TEST_F(TestNetDriverUBPublicJetty, ServerHandshake) +{ + std::string payload = "hello"; + std::string test = "test-handshake"; + std::string eidAndPort = "4245:4944:0000:0000:0000:0000:0100:0000"; + + UBSHcomNetWorkerIndex workerIndex{}; + UBSHcomNetEndpointPtr ep = new (std::nothrow) NetUBSyncEndpoint(0, nullptr, nullptr, NN_NO4, nullptr, workerIndex); + driver->RegisterNewEPHandler( + std::bind(&MockNewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + + MOCKER_CPP(&UBPublicJetty::Receive).stubs().will(invoke(MockReceive<1>)); + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::PollingCompletion).stubs().will(returnValue(1)).then(returnValue(0)); + + EXPECT_EQ(driver->ServerHandshake(ep, payload, eidAndPort, jetty), NN_ERROR); + EXPECT_EQ(driver->ServerHandshake(ep, payload, eidAndPort, jetty), NN_ERROR); + EXPECT_EQ(driver->ServerHandshake(ep, payload, eidAndPort, jetty), NN_OK); + EXPECT_EQ(driver->ServerHandshake(ep, test, eidAndPort, jetty), NN_ERROR); +} + +template void *NewExceptFor(size_t sz, const std::nothrow_t &) +{ + if (sz == sizeof(T)) { + return nullptr; + } + + return std::malloc(sz); +} + +TEST_F(TestNetDriverUBPublicJetty, CreateSyncEpMallocFail) +{ + UBSHcomNetEndpointPtr outEp = nullptr; + UBJetty *tmpJetty = new (std::nothrow) UBJetty("name", 0, nullptr, nullptr); + ASSERT_NE(tmpJetty, nullptr); + tmpJetty->IncreaseRef(); + + MOCKER_CPP(&operator new, void *(*)(size_t, const std::nothrow_t &)) + .stubs() + .will(invoke(NewExceptFor)); + + EXPECT_EQ(driver->CreateSyncEp(tmpJetty, nullptr, 0, outEp, nullptr), NN_NEW_OBJECT_FAILED); + if (tmpJetty != nullptr) { + delete tmpJetty; + } +} + +TEST_F(TestNetDriverUBPublicJetty, CreateSyncEpFail) +{ + UBSHcomNetEndpointPtr outEp = nullptr; + UBJetty *tmpJetty = new (std::nothrow) UBJetty("name", 0, nullptr, nullptr); + ASSERT_NE(tmpJetty, nullptr); + tmpJetty->IncreaseRef(); + + UBPublicJetty *pubJetty = new (std::nothrow) UBPublicJetty("pubJetty", 0x1122, nullptr, nullptr); + ASSERT_NE(pubJetty, nullptr); + pubJetty->IncreaseRef(); + + MOCKER_CPP(&NetDriverUBWithOob::PrePostReceiveOnSyncEp) + .stubs() + .will(returnValue(static_cast(NN_ERROR))) + .then(returnValue(static_cast(NN_OK))); + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty) + .stubs() + .will(returnValue(static_cast(UB_PARAM_INVALID))) + .then(returnValue(static_cast(NN_OK))); + MOCKER_CPP(&UBPublicJetty::PollingCompletion) + .stubs() + .will(returnValue(static_cast(UB_CQ_EVENT_GET_TIMOUT))) + .then(returnValue(static_cast(NN_OK))); + MOCKER_CPP(&UBPublicJetty::Receive).stubs() + .will(returnValue(static_cast(NN_ERROR))) + .then(returnValue(static_cast(NN_OK))); + + // PrePostReceiveOnSyncEp 失败 + driver->IncreaseRef(); + EXPECT_EQ(driver->CreateSyncEp(tmpJetty, nullptr, 0, outEp, pubJetty), NN_ERROR); + + // SendByPublicJetty 失败 + driver->IncreaseRef(); + EXPECT_EQ(driver->CreateSyncEp(tmpJetty, nullptr, 0, outEp, pubJetty), NN_ERROR); + + // PollingCompletion 失败 + driver->IncreaseRef(); + EXPECT_EQ(driver->CreateSyncEp(tmpJetty, nullptr, 0, outEp, pubJetty), NN_ERROR); + + // Receive 失败 + driver->IncreaseRef(); + EXPECT_EQ(driver->CreateSyncEp(tmpJetty, nullptr, 0, outEp, pubJetty), NN_ERROR); + + // Receive 成功,因为默认 serverAck = 1, 所以正常 + driver->IncreaseRef(); + EXPECT_EQ(driver->CreateSyncEp(tmpJetty, nullptr, 0, outEp, pubJetty), NN_OK); + + delete pubJetty; + delete tmpJetty; +} + +TEST_F(TestNetDriverUBPublicJetty, CreateSyncEpCheckServerAckFail) +{ + UBSHcomNetEndpointPtr outEp = nullptr; + UBJetty *tmpJetty = new (std::nothrow) UBJetty("name", 0, nullptr, nullptr); + ASSERT_NE(tmpJetty, nullptr); + tmpJetty->IncreaseRef(); + + UBPublicJetty *pubJetty = new (std::nothrow) UBPublicJetty("pubJetty", 0x1122, nullptr, nullptr); + ASSERT_NE(pubJetty, nullptr); + pubJetty->IncreaseRef(); + + MOCKER_CPP(&NetDriverUBWithOob::PrePostReceiveOnSyncEp).stubs().will(returnValue(static_cast(NN_OK))); + MOCKER_CPP(&UBPublicJetty::SendByPublicJetty).stubs().will(returnValue(static_cast(NN_OK))); + MOCKER_CPP(&UBPublicJetty::PollingCompletion).stubs().will(returnValue(static_cast(NN_OK))); + MOCKER_CPP(&UBPublicJetty::Receive).stubs().will(invoke(&MockReceive<-1>)); + + // Receive 成功,但是 serverAck = -1 + driver->IncreaseRef(); + EXPECT_EQ(driver->CreateSyncEp(tmpJetty, nullptr, 0, outEp, pubJetty), NN_ERROR); + + delete pubJetty; + delete tmpJetty; +} + +TEST_F(TestNetDriverUBPublicJetty, ConnectASyncEpByPublicJetty) +{ + std::string oobIp("1.2.3.4"); + std::string payload("hello"); + UBSHcomNetEndpointPtr outEp = nullptr; + MOCKER_CPP(&NetDriverUBWithOob::PublicJettyConnect).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::CreatePublicJetty).stubs().with(outBound(jetty), any(), any()) + .will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::StartPublicJetty).stubs().then(returnValue(1)); + + EXPECT_EQ(driver->ConnectASyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); + jetty->IncreaseRef(); + EXPECT_EQ(driver->ConnectASyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); + jetty->IncreaseRef(); + EXPECT_EQ(driver->ConnectASyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); +} + +TEST_F(TestNetDriverUBPublicJetty, PublicJettyNewConnectionCB) +{ + UBOpContextInfo ctx{}; + JettyConnHeader exchangeInfo{}; + + driver->mPublicJetty = jetty; + ctx.mrMemAddr = 0; + EXPECT_EQ(driver->PublicJettyNewConnectionCB(&ctx), NN_ERROR); + + ctx.mrMemAddr = reinterpret_cast(&exchangeInfo); + MOCKER_CPP(&NetDriverUBWithOob::CreatePublicJetty).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::ServerEstablishCtrlConn).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::CheckMagicAndProtocol).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->PublicJettyNewConnectionCB(&ctx), NN_ERROR); + EXPECT_EQ(driver->PublicJettyNewConnectionCB(&ctx), NN_ERROR); + EXPECT_EQ(driver->PublicJettyNewConnectionCB(&ctx), NN_ERROR); + driver->mPublicJetty = nullptr; +} + +TEST_F(TestNetDriverUBPublicJetty, ConnectSyncEpByPublicJettyFail) +{ + std::string oobIp("1.2.3.4"); + std::string payload("hello"); + UBSHcomNetEndpointPtr outEp = nullptr; + + MOCKER_CPP(&NetDriverUBWithOob::PublicJettyConnect).stubs().will(returnValue(1)).then(returnValue(0)); + + MOCKER_CPP(&NetDriverUBWithOob::CreatePublicJetty).stubs().with(outBound(jetty), any(), any()) + .will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::StartPublicJetty).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::ClientEstablishConnOnReply) + .stubs() + .will(returnValue(NN_ERROR)) + .then(returnValue(NN_OK)); + MOCKER_CPP(&NetDriverUBWithOob::CreateSyncEp).stubs().will(returnValue(NN_ERROR)).then(returnValue(NN_OK)); + + jetty->IncreaseRef(); + EXPECT_EQ(driver->ConnectSyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); + jetty->IncreaseRef(); + EXPECT_EQ(driver->ConnectSyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); + jetty->IncreaseRef(); + EXPECT_EQ(driver->ConnectSyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); + + // ClientEstablishConnOnReply 失败 + jetty->IncreaseRef(); + EXPECT_EQ(driver->ConnectSyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); + + // CreateSyncEp 失败 + EXPECT_EQ(driver->ConnectSyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); + EXPECT_EQ(driver->ConnectSyncEpByPublicJetty(oobIp, 1, payload, outEp, 0, 0, 0), NN_ERROR); +} + +} +} +#endif diff --git a/test/unit_test/transport/ub/test_net_driver_ub_with_oob.cpp b/test/unit_test/transport/ub/test_net_driver_ub_with_oob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d351c837c6d60abe785c5d44809051ab5624897a --- /dev/null +++ b/test/unit_test/transport/ub/test_net_driver_ub_with_oob.cpp @@ -0,0 +1,1573 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifdef UB_BUILD_ENABLED +#include +#include + +#include +#include + +#include "net_monotonic.h" +#include "net_oob_ssl.h" +#include "net_ub_endpoint.h" +#include "ub_mr_fixed_buf.h" +#include "ub_worker.h" +#include "net_ub_driver_oob.h" +#include "net_oob_secure.h" +#include "ub_urma_wrapper_jetty.h" + +namespace ock { +namespace hcom { + +int NewEndPoint(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("new endpoint from " << ipPort << " payload " << payload << " id " << newEP->Id()); + UBSHcomNetEndpoint *ep = newEP.Get(); + reinterpret_cast(ep)->mDriver = nullptr; + return 0; +} + +int RequestPosted(const UBSHcomNetRequestContext &ctx) +{ + return 0; +} + +int OneSideDone(const UBSHcomNetRequestContext &ctx) +{ + return 1; +} + +void EndPointBroken(const UBSHcomNetEndpointPtr &ep) +{ + UBSHcomNetEndpointPtr tmpEp = ep; + tmpEp.Set(nullptr); +} + +int RequestReceived(const UBSHcomNetRequestContext &ctx) +{ + UBSHcomNetMessage *msg = ctx.Message(); + if (msg->mBuf != nullptr) { + free(msg->mBuf); + msg->mBuf = nullptr; + } + return 1; +} + +class TestNetDriverUBWithOob : public testing::Test { +public: + TestNetDriverUBWithOob(); + virtual void SetUp(void); + virtual void TearDown(void); + std::string mName = "TestNetDriverUBWithOob"; + NetDriverUBWithOob *driver = nullptr; + UBSHcomNetDriverOptions option{}; + UBContext *ctx = nullptr; + UBEId eid{}; + urma_context_t mUrmaContext{}; + char mem[NN_NO8]{}; + JettyOptions jettyOptions{}; + // worker + UBWorker *worker = nullptr; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + UBWorkerOptions workerOptions{}; + // qp + UBJetty *qp = nullptr; + // lb + NetWorkerLB *lb = nullptr; + // tSeg + urma_target_seg_t tSeg{}; + // jfc + UBJfc *jfc = nullptr; + // ctxInfo + UBOpContextInfo ctxInfo{}; + // CallbackEp + NetUBAsyncEndpoint *CallbackEp = nullptr; +}; + +TestNetDriverUBWithOob::TestNetDriverUBWithOob() {} + +void TestNetDriverUBWithOob::SetUp() +{ + // create ctx + ctx = new (std::nothrow) UBContext("ubTest", eid); + ASSERT_NE(ctx, nullptr); + ctx->mUrmaContext = &mUrmaContext; + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + // create drivver + driver = new (std::nothrow) NetDriverUBWithOob(mName, true, UBSHcomNetDriverProtocol::UBC); + ASSERT_NE(driver, nullptr); + driver->mOptions.enableTls = false; + driver->RegisterNewEPHandler( + std::bind(&NewEndPoint, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + driver->RegisterReqPostedHandler(std::bind(&RequestPosted, std::placeholders::_1)); + driver->RegisterOneSideDoneHandler(std::bind(&OneSideDone, std::placeholders::_1)); + driver->RegisterEPBrokenHandler(std::bind(&EndPointBroken, std::placeholders::_1)); + driver->RegisterNewReqHandler(std::bind(&RequestReceived, std::placeholders::_1)); + driver->mEnableTls = false; + driver->mContext = ctx; + driver->mProtocol = UBSHcomNetDriverProtocol::UBC; + driver->mInited = true; + driver->mStarted = true; + driver->IncreaseRef(); + driver->mMajorVersion = 0; + // create worker + worker = new (std::nothrow) UBWorker(mName, ctx, workerOptions, memPool, sglMemPool); + ASSERT_NE(worker, nullptr); + worker->mInited = true; + worker->IncreaseRef(); + // create qp + qp = new (std::nothrow) UBJetty(mName, 0, ctx, nullptr); + ASSERT_NE(qp, nullptr); + qp->IncreaseRef(); + qp->mUpContext1 = reinterpret_cast(worker); + qp->StoreExchangeInfo(new UBJettyExchangeInfo); + // create lb + lb = new (std::nothrow) NetWorkerLB(mName, UBSHcomNetDriverLBPolicy::NET_ROUND_ROBIN, 0); + ASSERT_NE(lb, nullptr); + lb->IncreaseRef(); + // initialize ctxInfo + ctxInfo.ubJetty = qp; + ctxInfo.upCtxSize = 1; + // create CallbackEp + CallbackEp = new (std::nothrow) NetUBAsyncEndpoint(1, qp, nullptr, worker); + ASSERT_NE(CallbackEp, nullptr); + CallbackEp->IncreaseRef(); + qp->mUpContext = reinterpret_cast(CallbackEp); + MOCKER_CPP(HcomUrma::Uninit).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBDeviceHelper::UnInitialize).stubs().will(ignoreReturnValue()); +} + +void TestNetDriverUBWithOob::TearDown() +{ + if (driver != nullptr) { + driver->mContext = nullptr; + delete driver; + driver = nullptr; + } + if (CallbackEp != nullptr) { + CallbackEp->mJetty = nullptr; + delete CallbackEp; + CallbackEp = nullptr; + } + if (worker != nullptr) { + worker->mUBContext = nullptr; + delete worker; + worker = nullptr; + } + if (qp != nullptr) { + delete qp; + qp = nullptr; + } + if (ctx != nullptr) { + ctx->mUrmaContext = nullptr; + delete ctx; + ctx = nullptr; + } + if (lb != nullptr) { + delete lb; + lb = nullptr; + } + if (jfc != nullptr) { + delete jfc; + jfc = nullptr; + } + GlobalMockObject::verify(); +} + +static ssize_t MockRecv(int socket, void *buf, size_t size, int flags) +{ + switch (size) { + case sizeof(ConnectHeader): { + ConnectHeader *tmp = reinterpret_cast(buf); + tmp->magic = 1; + tmp->protocol = UBSHcomNetDriverProtocol::UBC; + break; + } + case sizeof(uint32_t): { + uint32_t *tmp = reinterpret_cast(buf); + *tmp = 1; + break; + } + default: + break; + } + + return size; +} + +static ssize_t MockRecvFakeSize(int socket, void *buf, size_t size, int flags) +{ + switch (size) { + case sizeof(ConnectHeader): { + ConnectHeader *tmp = reinterpret_cast(buf); + tmp->magic = 1; + tmp->protocol = UBSHcomNetDriverProtocol::UBC; + break; + } + case sizeof(uint32_t): { + uint32_t *tmp = reinterpret_cast(buf); + *tmp = NN_NO1000; + break; + } + default: + break; + } + + return size; +} + +static ssize_t MockRecvUBC(int socket, void *buf, size_t size, int flags) +{ + switch (size) { + case sizeof(ConnectHeader): { + ConnectHeader *tmp = reinterpret_cast(buf); + tmp->magic = 1; + tmp->protocol = UBSHcomNetDriverProtocol::UBC; + break; + } + case sizeof(uint32_t): { + uint32_t *tmp = reinterpret_cast(buf); + *tmp = 1; + break; + } + default: + break; + } + + return size; +} + +static ssize_t MockConnSend(int socket, void const *buf, size_t size, int flags) +{ + return size; +} + +TEST_F(TestNetDriverUBWithOob, NewConnectionCBSecErr) +{ + OOBTCPConnection conn(-1); + MOCKER(OOBSecureProcess::SecProcessInOOBServer).stubs() + .will(returnValue(1)) + .then(returnValue(0)); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Receive).stubs() + .will(returnValue(1)); + EXPECT_EQ(driver->NewConnectionCB(conn), NN_OOB_SEC_PROCESS_ERROR); + EXPECT_EQ(driver->NewConnectionCB(conn), NN_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, NewConnectionCBMagicErr) +{ + OOBTCPConnection conn(-1); + MOCKER(OOBSecureProcess::SecProcessInOOBServer).stubs() + .will(returnValue(0)); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Send).stubs() + .will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecv)); + EXPECT_EQ(driver->NewConnectionCB(conn), NN_ERROR); + driver->mOptions.magic = 1; + driver->mProtocol = UBSHcomNetDriverProtocol::RDMA; + EXPECT_EQ(driver->NewConnectionCB(conn), NN_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, NewConnectionLbErr) +{ + OOBTCPConnection conn(-1); + conn.mLb = lb; + MOCKER(OOBSecureProcess::SecProcessInOOBServer).stubs() + .will(returnValue(0)); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Send).stubs() + .will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecv)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs() + .will(returnValue(false)); + driver->mOptions.magic = 1; + EXPECT_EQ(driver->NewConnectionCB(conn), NN_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, NewConnectionWorkerErr) +{ + OOBTCPConnection conn(-1); + conn.mLb = lb; + MOCKER(OOBSecureProcess::SecProcessInOOBServer).stubs().will(returnValue(0)); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Send).stubs().will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecv)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(false)); + + driver->mOptions.magic = 1; + driver->mWorkers.emplace_back(worker); + + EXPECT_EQ(driver->NewConnectionCB(conn), NN_ERROR); + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, NewConnectionQpCreateErr) +{ + OOBTCPConnection conn(-1); + conn.mLb = lb; + MOCKER(OOBSecureProcess::SecProcessInOOBServer).stubs().will(returnValue(0)); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Send).stubs().will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecv)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().will(returnValue(1)); + + driver->mOptions.magic = 1; + driver->mWorkers.emplace_back(worker); + + EXPECT_EQ(driver->NewConnectionCB(conn), NN_ERROR); + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, NewConnectionQpErr) +{ + OOBTCPConnection conn(-1); + conn.mLb = lb; + MOCKER(OOBSecureProcess::SecProcessInOOBServer).stubs().will(returnValue(0)); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Send).stubs().will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecv)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(1)); + + driver->mOptions.magic = 1; + driver->mWorkers.emplace_back(worker); + + EXPECT_EQ(driver->NewConnectionCB(conn), NN_ERROR); + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, NewConnectionHccsSizeCheckErr) +{ + OOBTCPConnection conn(-1); + conn.mLb = lb; + conn.mIpAndPort = "192.168.1.1:5684"; + MOCKER(OOBSecureProcess::SecProcessInOOBServer).stubs().will(returnValue(0)); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Send).stubs().will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecvFakeSize)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + + driver->mOptions.magic = 1; + driver->mWorkers.emplace_back(worker); + + EXPECT_EQ(driver->NewConnectionCB(conn), NN_ERROR); + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, NewConnectionUBCHeartBeat) +{ + OOBTCPConnection conn(-1); + conn.mLb = lb; + conn.mIpAndPort = "192.168.1.1:5684"; + MOCKER(OOBSecureProcess::SecProcessInOOBServer).stubs().will(returnValue(0)); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Send).stubs().will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecvUBC)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::CreateHBMemoryRegion).stubs().will(returnValue(0)).then(returnValue(1)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(1)).then(returnValue(1)); + + driver->mOptions.magic = 1; + driver->mProtocol = UBSHcomNetDriverProtocol::UBC; + driver->mWorkers.emplace_back(worker); + driver->mHeartBeat = new (std::nothrow) NetHeartbeat(driver, NN_NO60, NN_NO2); + qp->mUBContext->protocol = UBSHcomNetDriverProtocol::UBC; + + EXPECT_EQ(driver->NewConnectionCB(conn), 1); + EXPECT_EQ(driver->NewConnectionCB(conn), 1); + driver->mWorkers.clear(); + if (driver->mHeartBeat != nullptr) { + delete driver->mHeartBeat; + driver->mHeartBeat = nullptr; + } +} + +TEST_F(TestNetDriverUBWithOob, NewConnectionExchangeErr) +{ + OOBTCPConnection conn(-1); + conn.mLb = lb; + conn.mIpAndPort = "192.168.1.1:5684"; + MOCKER(OOBSecureProcess::SecProcessInOOBServer).stubs().will(returnValue(0)); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Send).stubs().will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecv)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(1)); + + driver->mOptions.magic = 1; + driver->mWorkers.emplace_back(worker); + + EXPECT_EQ(driver->NewConnectionCB(conn), NN_ERROR); + EXPECT_EQ(driver->NewConnectionCB(conn), 1); + + // 必须放在最后 MOCK,保证之前通过 std::nothrow new 分配的实例已完成。否则会 + // 遇到 NetLogger 的 this 为空. + MOCKER_CPP(&operator new, void *(*)(size_t, const std::nothrow_t &)) + .stubs() + .will(returnValue(static_cast(nullptr))); + EXPECT_EQ(driver->NewConnectionCB(conn), NN_MALLOC_FAILED); + + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, NewConnectionPostRecvErr) +{ + int err = 1; + OOBTCPConnection conn(-1); + conn.mLb = lb; + conn.mIpAndPort = "192.168.1.1:5684"; + MOCKER(OOBSecureProcess::SecProcessInOOBServer).stubs().will(returnValue(0)); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Send).stubs().will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecv)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetFreeBufferN).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::PostReceive).stubs().will(returnValue(err)); + MOCKER_CPP(&UBJetty::Stop).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetMemorySeg).stubs().will(returnValue(&tSeg)); + + driver->mOptions.magic = err; + driver->mWorkers.emplace_back(worker); + + EXPECT_EQ(driver->NewConnectionCB(conn), err); + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, NewConnectionGetbufferErr) +{ + OOBTCPConnection conn(-1); + conn.mLb = lb; + conn.mIpAndPort = "192.168.1.1:5684"; + MOCKER(OOBSecureProcess::SecProcessInOOBServer).stubs().will(returnValue(0)); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Send).stubs().will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecv)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetFreeBufferN).stubs().will(returnValue(false)).then(returnValue(true)); + MOCKER_CPP(&UBWorker::PostReceive).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ReturnBuffer).stubs().will(returnValue(true)); + MOCKER_CPP(&UBJetty::GetMemorySeg).stubs().will(returnValue(&tSeg)); + + driver->mOptions.magic = 1; + driver->mWorkers.emplace_back(worker); + + EXPECT_EQ(driver->NewConnectionCB(conn), NN_MALLOC_FAILED); + EXPECT_EQ(driver->NewConnectionCB(conn), 0); + driver->mWorkers.clear(); +} + +template void *NewExceptFor(size_t sz, const std::nothrow_t &) +{ + if (sz == sizeof(T)) { + return nullptr; + } + + return std::malloc(sz); +} + +TEST_F(TestNetDriverUBWithOob, NewConnectionAllocEPFail) +{ + OOBTCPConnection conn(-1); + conn.mLb = lb; + conn.mIpAndPort = "192.168.1.1:5684"; + MOCKER(OOBSecureProcess::SecProcessInOOBServer).stubs().will(returnValue(0)); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Send).stubs().will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecv)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetFreeBufferN).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::PostReceive).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ReturnBuffer).stubs().will(returnValue(true)); + MOCKER_CPP(&UBJetty::GetMemorySeg).stubs().will(returnValue(&tSeg)); + MOCKER_CPP(&UBJetty::Stop).stubs().will(returnValue(0)); + MOCKER_CPP(&operator new, void *(*)(size_t, const std::nothrow_t &)) + .stubs() + .will(invoke(NewExceptFor)); + driver->mOptions.magic = 1; + driver->mWorkers.emplace_back(worker); + + EXPECT_EQ(driver->NewConnectionCB(conn), NN_NEW_OBJECT_FAILED); + driver->mWorkers.clear(); +} + +int NewEPFail(const std::string &ipPort, const UBSHcomNetEndpointPtr &newEP, const std::string &payload) +{ + NN_LOG_INFO("Mock user callback fail"); + UBSHcomNetEndpoint *ep = newEP.Get(); + reinterpret_cast(ep)->mDriver = nullptr; + return 1; +} + +TEST_F(TestNetDriverUBWithOob, NewConnectionUsrCbFail) +{ + OOBTCPConnection conn(-1); + conn.mLb = lb; + conn.mIpAndPort = "192.168.1.1:5684"; + driver->RegisterNewEPHandler( + std::bind(&NewEPFail, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + MOCKER(OOBSecureProcess::SecProcessInOOBServer).stubs().will(returnValue(0)); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Send).stubs().will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecv)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetFreeBufferN).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::PostReceive).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ReturnBuffer).stubs().will(returnValue(true)); + MOCKER_CPP(&UBJetty::GetMemorySeg).stubs().will(returnValue(&tSeg)); + + driver->mOptions.magic = 1; + driver->mWorkers.emplace_back(worker); + + EXPECT_EQ(driver->NewConnectionCB(conn), NN_ERROR); + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, NewConnectionTest) +{ + OOBTCPConnection conn(-1); + conn.mLb = lb; + conn.mIpAndPort = "192.168.1.1:5684"; + MOCKER(OOBSecureProcess::SecProcessInOOBServer).stubs().will(returnValue(0)); + MOCKER_CPP_VIRTUAL(conn, &OOBTCPConnection::Send).stubs().will(returnValue(0)); + MOCKER(::recv).stubs().will(invoke(MockRecv)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetFreeBufferN).stubs().will(returnValue(false)).then(returnValue(true)); + MOCKER_CPP(&UBWorker::PostReceive).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ReturnBuffer).stubs().will(returnValue(true)); + MOCKER_CPP(&UBJetty::GetMemorySeg).stubs().will(returnValue(&tSeg)); + + driver->mOptions.magic = 1; + driver->mWorkers.emplace_back(worker); + + EXPECT_EQ(driver->NewConnectionCB(conn), NN_MALLOC_FAILED); + EXPECT_EQ(driver->NewConnectionCB(conn), 0); + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, ConnectBranch) +{ + std::string payload{}; + UBSHcomNetEndpointPtr ep = nullptr; + driver->mInited = false; + + driver->mOptions.oobType = NET_OOB_TCP; + EXPECT_EQ(driver->Connect(payload, ep, 0, 0, 0), NN_ERROR); + + driver->mOptions.oobType = NET_OOB_UDS; + EXPECT_EQ(driver->Connect(payload, ep, 0, 0, 0), NN_ERROR); + + driver->mOptions.oobType = NET_OOB_TCP; +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectInitErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload{}; + + driver->mInited = false; + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_ERROR); + + driver->mInited = true; + driver->mStarted = false; + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectParamErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload(NN_NO2048, 'a'); + + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectTCPErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs() + .will(returnValue(1)) + .then(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), 1); + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_OOB_SEC_PROCESS_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectTCPRecvErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(returnValue(ssize_t(0))); + + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_OOB_CONN_RECEIVE_ERROR); +} + +ConnectResp respCode = ConnectResp::OK; +int8_t ready = -1; +static ssize_t MockConnRecv(int socket, void *buf, size_t size, int flags) +{ + switch (size) { + case sizeof(ConnRespWithUId): { + ConnRespWithUId *tmpConnResp = reinterpret_cast(buf); + tmpConnResp->connResp = respCode; + break; + } + case sizeof(uint32_t): { + uint32_t *tmp = reinterpret_cast(buf); + *tmp = 1; + break; + } + case sizeof(int8_t): { + int8_t *tmp = reinterpret_cast(buf); + *tmp = ready; + break; + } + default: + break; + } + + return size; +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectTCPAckErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + + respCode = MAGIC_MISMATCH; + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_CONNECT_REFUSED); + + respCode = WORKER_GRPNO_MISMATCH; + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_CONNECT_REFUSED); +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectProtoErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + + respCode = PROTOCOL_MISMATCH; + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_CONNECT_PROTOCOL_MISMATCH); + + respCode = SERVER_INTERNAL_ERROR; + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectElseErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + + respCode = CONN_ACCEPT_QUEUE_FULL; + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectWorkerErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + respCode = ConnectResp::OK; + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(false)); + driver->mWorkers.emplace_back(worker); + + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_ERROR); + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectQpCreateErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + respCode = ConnectResp::OK; + driver->mWorkers.emplace_back(worker); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_ERROR); + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectQpErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + respCode = ConnectResp::OK; + driver->mWorkers.emplace_back(worker); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_ERROR); + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectUBCHeartBeatErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + respCode = ConnectResp::OK; + driver->mWorkers.emplace_back(worker); + driver->mHeartBeat = new (std::nothrow) NetHeartbeat(driver, NN_NO60, NN_NO2); + qp->mUBContext->protocol = UBSHcomNetDriverProtocol::UBC; + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::CreateHBMemoryRegion).stubs().will(returnValue(0)).then(returnValue(1)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), 1); + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), 1); + driver->mWorkers.clear(); + if (driver->mHeartBeat != nullptr) { + delete driver->mHeartBeat; + driver->mHeartBeat = nullptr; + } +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectExchangeErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + respCode = ConnectResp::OK; + driver->mWorkers.emplace_back(worker); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), 1); + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), 1); + + // 必须放在最后 MOCK,保证之前通过 std::nothrow new 分配的实例已完成。否则会 + // 遇到 NetLogger 的 this 为空. + // + // 并且在正常流程中也会遇到 OOBTCPClient, OOBTCPConnection, UBJetty 等通过 + // std::nothrow 版本的 new 来分配内存,这些必须避免。 + MOCKER_CPP(&operator new, void *(*)(size_t, const std::nothrow_t &)) + .stubs() + .will(invoke(NewExceptFor)); + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_MALLOC_FAILED); + + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectPostrecvErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + respCode = ConnectResp::OK; + driver->mWorkers.emplace_back(worker); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetFreeBufferN).stubs().will(returnValue(false)).then(returnValue(true)); + MOCKER_CPP(&UBWorker::PostReceive).stubs().will(returnValue(1)); + MOCKER_CPP(&UBJetty::GetMemorySeg).stubs().will(returnValue(&tSeg)); + + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_ERROR); + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), 1); + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectReadyErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + respCode = ConnectResp::OK; + driver->mWorkers.emplace_back(worker); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetFreeBufferN).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::PostReceive).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetMemorySeg).stubs().will(returnValue(&tSeg)); + MOCKER_CPP(&UBJetty::ReturnBuffer).stubs().will(returnValue(true)); + + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_ERROR); + ready = -1; + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_ERROR); + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, AsyncConnectSuccess) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + respCode = ConnectResp::OK; + ready = 1; + driver->mWorkers.emplace_back(worker); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + MOCKER_CPP(&NetWorkerLB::ChooseWorker).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::IsWorkStarted).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::CreateQP).stubs().with(outBound(qp)).will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetFreeBufferN).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::PostReceive).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetMemorySeg).stubs().will(returnValue(&tSeg)); + MOCKER_CPP(&UBJetty::ReturnBuffer).stubs().will(returnValue(true)); + + EXPECT_EQ(driver->Connect(oobIp, oobPort, payload, outEp, 0, 0, 0, 0), NN_OK); + driver->mWorkers.clear(); +} + +TEST_F(TestNetDriverUBWithOob, ConnectSyncEp) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs() + .will(returnValue(1)) + .then(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(1)); + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), 1); + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), NN_OOB_SEC_PROCESS_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, ConnectSyncEpCreateResourcesErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(NetUBSyncEndpoint::CreateResources).stubs() + .with(any(), any(), any(), any(), outBound(qp), outBound(jfc)) + .will(returnValue(1)) + .then(returnValue(0)); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), 1); + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), NN_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, ConnectSyncEpQpErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(NetUBSyncEndpoint::CreateResources).stubs() + .with(any(), any(), any(), any(), outBound(qp), outBound(jfc)) + .will(returnValue(0)); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), NN_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, ConnectSyncEpTCPAckErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(NetUBSyncEndpoint::CreateResources).stubs() + .with(any(), any(), any(), any(), outBound(qp), outBound(jfc)) + .will(returnValue(0)); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + + respCode = MAGIC_MISMATCH; + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), NN_CONNECT_REFUSED); + + respCode = WORKER_GRPNO_MISMATCH; + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), NN_CONNECT_REFUSED); +} + +TEST_F(TestNetDriverUBWithOob, ConnectSyncEpProtoErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(NetUBSyncEndpoint::CreateResources).stubs() + .with(any(), any(), any(), any(), outBound(qp), outBound(jfc)) + .will(returnValue(0)); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + + respCode = PROTOCOL_MISMATCH; + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), NN_CONNECT_PROTOCOL_MISMATCH); + + respCode = SERVER_INTERNAL_ERROR; + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), NN_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, ConnectSyncEpElseErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(NetUBSyncEndpoint::CreateResources).stubs() + .with(any(), any(), any(), any(), outBound(qp), outBound(jfc)) + .will(returnValue(0)); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + + respCode = CONN_ACCEPT_QUEUE_FULL; + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), NN_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, ConnectSyncExchangeErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + respCode = ConnectResp::OK; + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(NetUBSyncEndpoint::CreateResources).stubs() + .with(any(), any(), any(), any(), outBound(qp), outBound(jfc)) + .will(returnValue(0)); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(1)); + + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), 1); + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), 1); + + // 必须放在最后 MOCK,保证之前通过 std::nothrow new 分配的实例已完成。否则会 + // 遇到 NetLogger 的 this 为空. + // + // 并且在正常流程中也会遇到 OOBTCPClient, OOBTCPConnection, UBJetty 等通过 + // std::nothrow 版本的 new 来分配内存,这些必须避免。 + MOCKER_CPP(&operator new, void *(*)(size_t, const std::nothrow_t &)) + .stubs() + .will(invoke(NewExceptFor)); + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), NN_MALLOC_FAILED); +} + +TEST_F(TestNetDriverUBWithOob, ConnectSyncPostrecvErr) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + respCode = ConnectResp::OK; + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(NetUBSyncEndpoint::CreateResources).stubs() + .with(any(), any(), any(), any(), outBound(qp), outBound(jfc)) + .will(returnValue(0)); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetFreeBufferN).stubs().will(returnValue(false)).then(returnValue(true)); + MOCKER_CPP(&NetUBSyncEndpoint::PostReceive).stubs().will(returnValue(1)); + MOCKER_CPP(&UBJetty::GetMemorySeg).stubs().will(returnValue(&tSeg)); + + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), NN_ERROR); + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), 1); +} + +TEST_F(TestNetDriverUBWithOob, ConnectSyncSuccess) +{ + std::string oobIp = "192.168.1.1"; + uint16_t oobPort = 1; + UBSHcomNetEndpointPtr outEp = nullptr; + std::string payload("hello world"); + respCode = ConnectResp::OK; + + MOCKER(OOBTCPClient::ConnectWithFd, NResult(const std::string &, uint32_t, int &)).stubs().will(returnValue(0)); + MOCKER(OOBSecureProcess::SecProcessInOOBClient).stubs().will(returnValue(0)); + MOCKER(NetUBSyncEndpoint::CreateResources).stubs() + .with(any(), any(), any(), any(), outBound(qp), outBound(jfc)) + .will(returnValue(0)); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::Initialize).stubs().will(returnValue(0)); + MOCKER(::send).stubs().will(invoke(MockConnSend)); + MOCKER(::recv).stubs().will(invoke(MockConnRecv)); + MOCKER_CPP(&UBJetty::FillExchangeInfo).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::ChangeToReady).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetFreeBufferN).stubs().will(returnValue(true)); + MOCKER_CPP(&NetUBSyncEndpoint::PostReceive).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetMemorySeg).stubs().will(returnValue(&tSeg)); + + ready = -1; + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), NN_ERROR); + ready = 1; + EXPECT_EQ(driver->ConnectSyncEp(oobIp, oobPort, payload, outEp, 0, 0, 0), NN_OK); +} + +TEST_F(TestNetDriverUBWithOob, ProcessErrorNewRequestParamErr) +{ + EXPECT_NO_FATAL_FAILURE(driver->ProcessErrorNewRequest(nullptr)); +} + +TEST_F(TestNetDriverUBWithOob, ProcessErrorNewRequest) +{ + ctxInfo.opType = UBOpContextInfo::RECEIVE; + MOCKER_CPP(&UBJetty::ReturnBuffer).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::ReturnOpContextInfo).stubs().will(ignoreReturnValue()); + EXPECT_NO_FATAL_FAILURE(driver->ProcessErrorNewRequest(&ctxInfo)); + + ctxInfo.opType = UBOpContextInfo::SEND_RAW; + EXPECT_NO_FATAL_FAILURE(driver->ProcessErrorNewRequest(&ctxInfo)); +} + +TEST_F(TestNetDriverUBWithOob, SendRawSglFinishedCB) +{ + UBSHcomNetRequestContext netCtx{}; + UBSglContextInfo sglCtx{}; + ctxInfo.upCtxSize = static_cast(sizeof(UBSgeCtxInfo)); + auto upCtx = static_cast((void *)&(ctxInfo.upCtx)); + upCtx->ctx = &sglCtx; + + MOCKER_CPP(&UBWorker::ReturnSglContextInfo).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBWorker::ReturnOpContextInfo).stubs().will(ignoreReturnValue()); + EXPECT_EQ(driver->SendRawSglFinishedCB(&ctxInfo, netCtx), NN_OK); +} + +TEST_F(TestNetDriverUBWithOob, SendFinishedCB) +{ + ctxInfo.opType = UBOpContextInfo::SEND; + ctxInfo.upCtxSize = 1; + + MOCKER_CPP(&UBJetty::ReturnPostSendWr).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBMemoryRegionFixedBuffer::ReturnBuffer).stubs().will(returnValue(false)); + MOCKER_CPP(&UBWorker::ReturnOpContextInfo).stubs().will(ignoreReturnValue()); + EXPECT_EQ(driver->SendFinishedCB(&ctxInfo), NN_OK); +} + +TEST_F(TestNetDriverUBWithOob, ProcessErrorSendFinished) +{ + EXPECT_NO_FATAL_FAILURE(driver->ProcessErrorSendFinished(nullptr)); +} + +TEST_F(TestNetDriverUBWithOob, RWSglOneSideDoneCB) +{ + UBSHcomNetRequestContext netCtx{}; + UBSglContextInfo sglCtx{}; + ctxInfo.upCtxSize = static_cast(sizeof(UBSgeCtxInfo)); + auto upCtx = static_cast((void *)&(ctxInfo.upCtx)); + upCtx->ctx = &sglCtx; + sglCtx.iovCount = 1; + + MOCKER_CPP(&UBWorker::ReturnSglContextInfo).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBWorker::ReturnOpContextInfo).stubs().will(ignoreReturnValue()); + EXPECT_EQ(driver->RWSglOneSideDoneCB(&ctxInfo, netCtx), NN_OK); +} + +TEST_F(TestNetDriverUBWithOob, OneSideDoneCB) +{ + ctxInfo.opType = UBOpContextInfo::WRITE; + ctxInfo.upCtxSize = 1; + + MOCKER_CPP(&UBJetty::ReturnOneSideWr).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBWorker::ReturnOpContextInfo).stubs().will(ignoreReturnValue()); + EXPECT_EQ(driver->OneSideDoneCB(&ctxInfo), NN_OK); +} + +TEST_F(TestNetDriverUBWithOob, ProcessErrorOneSideDone) +{ + EXPECT_NO_FATAL_FAILURE(driver->ProcessErrorOneSideDone(nullptr)); +} + +TEST_F(TestNetDriverUBWithOob, ProcessEpError) +{ + UBOpContextInfo remainingOpCtx{}; + UBOpContextInfo nextOpCtx{}; + remainingOpCtx.next = &nextOpCtx; + uint32_t a = 1; + uint32_t b = 0; + remainingOpCtx.opType = UBOpContextInfo::OpType::SEND; + remainingOpCtx.opResultType = UBOpContextInfo::SUCCESS; + nextOpCtx.opType = UBOpContextInfo::OpType::WRITE; + remainingOpCtx.opResultType = UBOpContextInfo::SUCCESS; + CallbackEp->State().Set(NEP_ESTABLISHED); + + MOCKER_CPP(&UBJetty::Stop).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::GetCtxPosted).stubs().with(outBound(&remainingOpCtx)).will(ignoreReturnValue()); + MOCKER_CPP(&NetDriverUBWithOob::ProcessErrorSendFinished).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&NetDriverUBWithOob::ProcessErrorOneSideDone).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBJetty::GetPostedCount).stubs().will(returnValue(a)).then(returnValue(b)); + + EXPECT_NO_FATAL_FAILURE(driver->ProcessEpError(reinterpret_cast(CallbackEp))); +} + +TEST_F(TestNetDriverUBWithOob, ProcessQPError) +{ + EXPECT_NO_FATAL_FAILURE(driver->ProcessQPError(nullptr)); +} + +TEST_F(TestNetDriverUBWithOob, ProcessTwoSideHeartbeat) +{ + UBSHcomNetRequestContext netCtx{}; + netCtx.mHeader.opCode == HB_SEND_OP; + MOCKER_CPP(&UBJetty::PostSend).stubs().will(returnValue(0)); + EXPECT_NO_FATAL_FAILURE(driver->ProcessTwoSideHeartbeat(&ctxInfo, netCtx)); + + netCtx.mHeader.opCode == HB_RECV_OP; + MOCKER_CPP(&NetUBAsyncEndpoint::HbRecordCount).stubs().will(ignoreReturnValue()); + EXPECT_NO_FATAL_FAILURE(driver->ProcessTwoSideHeartbeat(&ctxInfo, netCtx)); +} + +TEST_F(TestNetDriverUBWithOob, NewRequestParamErr) +{ + EXPECT_EQ(driver->NewRequest(nullptr), NN_ERROR); + + ctxInfo.opResultType = UBOpContextInfo::ERR_TIMEOUT; + MOCKER_CPP(&NetDriverUBWithOob::ProcessQPError).stubs().will(ignoreReturnValue()); + EXPECT_EQ(driver->NewRequest(&ctxInfo), NN_OK); +} + +TEST_F(TestNetDriverUBWithOob, NewRequestRecvRaw) +{ + ctxInfo.opResultType = UBOpContextInfo::SUCCESS; + ctxInfo.opType = UBOpContextInfo::RECEIVE; + // imm_data + int *tmp = reinterpret_cast(ctxInfo.upCtx); + *tmp = 1; + // mrMemAddr free in callback + UBSHcomNetTransHeader *header = (UBSHcomNetTransHeader *)malloc(sizeof(UBSHcomNetTransHeader) + NN_NO8); + ctxInfo.dataSize = NN_NO8; + ctxInfo.mrMemAddr = reinterpret_cast(header); + MOCKER(NetFunc::ValidateHeaderCrc32, bool(UBSHcomNetTransHeader *)).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::RePostReceive).stubs().will(returnValue(0)); + + EXPECT_EQ(driver->NewRequest(&ctxInfo), NN_OK); + free(header); +} + +TEST_F(TestNetDriverUBWithOob, SendFinishedParamErr) +{ + EXPECT_EQ(driver->SendFinished(nullptr), NN_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, SendFinished) +{ + ctxInfo.opResultType = UBOpContextInfo::ERR_TIMEOUT; + MOCKER_CPP(&NetDriverUBWithOob::ProcessQPError).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&NetDriverUBWithOob::SendFinishedCB).stubs().will(returnValue(0)); + EXPECT_EQ(driver->SendFinished(&ctxInfo), NN_OK); + + ctxInfo.opResultType = UBOpContextInfo::SUCCESS; + EXPECT_EQ(driver->SendFinished(&ctxInfo), 0); +} + +TEST_F(TestNetDriverUBWithOob, OneSideDoneParamErr) +{ + EXPECT_EQ(driver->OneSideDone(nullptr), NN_ERROR); +} + +TEST_F(TestNetDriverUBWithOob, OneSideDone) +{ + ctxInfo.opResultType = UBOpContextInfo::ERR_TIMEOUT; + MOCKER_CPP(&NetDriverUBWithOob::ProcessQPError).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&NetDriverUBWithOob::OneSideDoneCB).stubs().will(returnValue(0)); + EXPECT_EQ(driver->OneSideDone(&ctxInfo), NN_OK); + + ctxInfo.opResultType = UBOpContextInfo::SUCCESS; + EXPECT_EQ(driver->OneSideDone(&ctxInfo), 0); +} + +std::string MockUBDetailName() +{ + std::string str = "test"; + return str; +} + +TEST_F(TestNetDriverUBWithOob, NewRequestGetEpErr) +{ + UBJetty *qp1 = (UBJetty *)malloc(sizeof(UBJetty)); + UBWorker *fakeWorker = (UBWorker *)malloc(sizeof(UBWorker)); + char fakeUpCtx[NN_NO16] = {}; + + ctxInfo.ubJetty = qp1; + ctxInfo.opResultType = UBOpContextInfo::SUCCESS; + memcpy_s(ctxInfo.upCtx, sizeof(ctxInfo.upCtx), fakeUpCtx, sizeof(ctxInfo.upCtx)); + ctxInfo.opType = UBOpContextInfo::RECEIVE; + driver->mOptions.enableTls = true; + + MOCKER_CPP(&UBJetty::GetUpContext1, uintptr_t(UBJetty::*)() const).stubs() + .will(returnValue(reinterpret_cast(fakeWorker))); + MOCKER_CPP(&UBJetty::GetUpContext, uintptr_t(UBJetty::*)() const).stubs() + .will(returnValue(static_cast(0))); + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBWorker::RePostReceive).stubs().will(returnValue(0)); + + EXPECT_EQ(driver->NewRequest(&ctxInfo), NN_ERROR); + free(qp1); + free(fakeWorker); +} + +TEST_F(TestNetDriverUBWithOob, NewReceivedRawRequestGetEpErr) +{ + UBSHcomNetRequestContext netCtx{}; + UBSHcomNetMessage msg{}; + UBJetty *qp1 = (UBJetty *)malloc(sizeof(UBJetty)); + ctxInfo.ubJetty = qp1; + driver->mOptions.enableTls = true; + + MOCKER_CPP(&UBJetty::GetUpContext, uintptr_t(UBJetty::*)() const).stubs() + .will(returnValue(static_cast(0))); + MOCKER_CPP(&UBWorker::RePostReceive).stubs().will(returnValue(0)); + + EXPECT_EQ(driver->NewReceivedRawRequest(&ctxInfo, netCtx, msg, nullptr, 0), NN_ERROR); + free(qp1); +} + +TEST_F(TestNetDriverUBWithOob, NewReceivedRequestGetEpErr) +{ + UBSHcomNetRequestContext netCtx{}; + UBSHcomNetMessage msg{}; + UBJetty *qp1 = (UBJetty *)malloc(sizeof(UBJetty)); + ctxInfo.ubJetty = qp1; + driver->mOptions.enableTls = true; + + MOCKER_CPP(&UBJetty::GetUpContext, uintptr_t(UBJetty::*)() const).stubs() + .will(returnValue(static_cast(0))); + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBWorker::RePostReceive).stubs().will(returnValue(0)); + + EXPECT_EQ(driver->NewReceivedRequest(&ctxInfo, netCtx, msg, nullptr), NN_ERROR); + free(qp1); +} + +TEST_F(TestNetDriverUBWithOob, NewReceivedRequestEnableTlsOnEpEnctypeTrueDecryptFail) +{ + UBSHcomNetRequestContext netCtx{}; + UBSHcomNetMessage msg{}; + UBJetty *qp1 = (UBJetty *)malloc(sizeof(UBJetty)); + ctxInfo.ubJetty = qp1; + ctxInfo.ubJetty->mUpContext1 = reinterpret_cast(worker); + ctxInfo.ubJetty->mUpContext = reinterpret_cast(CallbackEp); + UBSHcomNetTransHeader header{}; + ctxInfo.mrMemAddr = reinterpret_cast(&header); + driver->mOptions.enableTls = true; + CallbackEp->mIsNeedEncrypt = true; + + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize).stubs().will(returnValue(0)); + MOCKER_CPP(&AesGcm128::GetRawLen).stubs().will(returnValue(1)); + MOCKER_CPP(&AesGcm128::Decrypt).stubs().will(returnValue(false)); + MOCKER_CPP(&UBWorker::RePostReceive).stubs().will(returnValue(0)); + + EXPECT_EQ(driver->NewReceivedRequest(&ctxInfo, netCtx, msg, nullptr), NN_DECRYPT_FAILED); + free(qp1); +} + +TEST_F(TestNetDriverUBWithOob, NewReceivedRequestEnableTlsOnEpEnctypeTrueDecryptSuccess) +{ + UBSHcomNetRequestContext netCtx{}; + UBSHcomNetMessage msg{}; + UBJetty *qp1 = (UBJetty *)malloc(sizeof(UBJetty)); + ctxInfo.ubJetty = qp1; + ctxInfo.ubJetty->mUpContext1 = reinterpret_cast(worker); + ctxInfo.ubJetty->mUpContext = reinterpret_cast(CallbackEp); + UBSHcomNetTransHeader header{}; + ctxInfo.mrMemAddr = reinterpret_cast(&header); + driver->mOptions.enableTls = true; + CallbackEp->mIsNeedEncrypt = true; + + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize).stubs().will(returnValue(0)); + MOCKER_CPP(&AesGcm128::GetRawLen).stubs().will(returnValue(1)); + MOCKER_CPP(&AesGcm128::Decrypt).stubs().will(returnValue(true)); + MOCKER_CPP(&UBWorker::RePostReceive).stubs().will(returnValue(0)); + + EXPECT_EQ(driver->NewReceivedRequest(&ctxInfo, netCtx, msg, nullptr), NN_OK); + free(qp1); +} + +TEST_F(TestNetDriverUBWithOob, NewReceivedRequestEnableTlsOnEpEnctypeFalse) +{ + UBSHcomNetRequestContext netCtx{}; + UBSHcomNetMessage msg{}; + UBJetty *qp1 = (UBJetty *)malloc(sizeof(UBJetty)); + ctxInfo.ubJetty = qp1; + ctxInfo.ubJetty->mUpContext1 = reinterpret_cast(worker); + ctxInfo.ubJetty->mUpContext = reinterpret_cast(CallbackEp); + UBSHcomNetTransHeader header{}; + ctxInfo.mrMemAddr = reinterpret_cast(&header); + driver->mOptions.enableTls = true; + CallbackEp->mIsNeedEncrypt = false; + + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBWorker::RePostReceive).stubs().will(returnValue(0)); + EXPECT_EQ(driver->NewReceivedRequest(&ctxInfo, netCtx, msg, nullptr), NN_INVALID_PARAM); + free(qp1); +} + +TEST_F(TestNetDriverUBWithOob, NewReceivedRequestEnableTlsOff) +{ + UBSHcomNetRequestContext netCtx{}; + UBSHcomNetMessage msg{}; + UBJetty *qp1 = (UBJetty *)malloc(sizeof(UBJetty)); + ctxInfo.ubJetty = qp1; + driver->mOptions.enableTls = false; + + MOCKER_CPP(&UBJetty::GetUpContext, uintptr_t(UBJetty::*)() const).stubs() + .will(returnValue(static_cast(0))); + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize).stubs().will(returnValue(0)); + MOCKER_CPP(&NetDriverUBWithOob::NewReceivedRequestWithoutCopy).stubs().will(returnValue(0)); + + EXPECT_EQ(driver->NewReceivedRequest(&ctxInfo, netCtx, msg, nullptr), NN_OK); + free(qp1); +} + +TEST_F(TestNetDriverUBWithOob, NewReceivedRequestWithoutCopy) +{ + UBSHcomNetRequestContext netCtx{}; + UBJetty *qp1 = (UBJetty *)malloc(sizeof(UBJetty)); + ctxInfo.ubJetty = qp1; + ctxInfo.ubJetty->mUpContext1 = reinterpret_cast(worker); + ctxInfo.ubJetty->mUpContext = reinterpret_cast(CallbackEp); + UBSHcomNetMessage msg; + UBSHcomNetTransHeader header{}; + header.dataLength = 10; + + MOCKER_CPP(&UBWorker::RePostReceive).stubs().will(returnValue(0)); + auto result = driver->NewReceivedRequestWithoutCopy(&ctxInfo, netCtx, msg, worker, nullptr, &header); + EXPECT_EQ(result, NN_OK); + free(qp1); +} + +TEST_F(TestNetDriverUBWithOob, NewRequestOnEncryption) +{ + UBSHcomNetRequestContext netCtx{}; + UBSHcomNetMessage msg{}; + bool messageReady = true; + ctxInfo.ubJetty->mUpContext1 = reinterpret_cast(worker); + ctxInfo.ubJetty->mUpContext = reinterpret_cast(CallbackEp); + UBSHcomNetTransHeader header{}; + ctxInfo.mrMemAddr = reinterpret_cast(&header); + + MOCKER_CPP(&UBWorker::RePostReceive).stubs().will(returnValue(0)); + MOCKER_CPP(&AesGcm128::GetRawLen).stubs().will(returnValue(1)); + MOCKER_CPP(&AesGcm128::Decrypt).stubs().will(returnValue(false)).then(returnValue(true)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(driver->NewRequestOnEncryption(nullptr, msg, messageReady, netCtx), NN_INVALID_PARAM); + CallbackEp->mIsNeedEncrypt = false; + EXPECT_EQ(driver->NewRequestOnEncryption(&ctxInfo, msg, messageReady, netCtx), NN_INVALID_PARAM); + CallbackEp->mIsNeedEncrypt = true; + EXPECT_EQ(driver->NewRequestOnEncryption(&ctxInfo, msg, messageReady, netCtx), NN_DECRYPT_FAILED); + EXPECT_EQ(driver->NewRequestOnEncryption(&ctxInfo, msg, messageReady, netCtx), NN_ERROR); + EXPECT_EQ(driver->NewRequestOnEncryption(&ctxInfo, msg, messageReady, netCtx), NN_OK); + if (msg.mBuf != nullptr) { + free(msg.mBuf); + msg.mBuf = nullptr; + } +} + +TEST_F(TestNetDriverUBWithOob, NetOobListenerOptionsSetEid) +{ + UBSHcomNetOobListenerOptions opt{}; + std::string eid = "00000000000000000000ffffc0a80164"; + EXPECT_EQ(opt.SetEid(eid, 0, 0), true); +} + +TEST_F(TestNetDriverUBWithOob, OobEidAndJettyId) +{ + std::string eid = "0000:0000:0000:0000:0000:ffff:c0a8:0164"; + std::string ip = "1.2.3.4"; + + MOCKER_CPP(HexStringToBuff).stubs().will(returnValue(false)).then(returnValue(true)); + driver->mStartOobSvr = true; + EXPECT_NO_FATAL_FAILURE(driver->OobEidAndJettyId(ip, 0)); + EXPECT_NO_FATAL_FAILURE(driver->OobEidAndJettyId(eid, 0)); + EXPECT_NO_FATAL_FAILURE(driver->OobEidAndJettyId(eid, NN_NO64)); + EXPECT_NO_FATAL_FAILURE(driver->OobEidAndJettyId(eid, NN_NO64)); + + driver->mStartOobSvr = false; + EXPECT_NO_FATAL_FAILURE(driver->OobEidAndJettyId(eid, NN_NO64)); +} + +TEST_F(TestNetDriverUBWithOob, connectUrl) +{ + int ret; + driver->mInited = false; + std::string badUrl = "unknown://127.0.0.1"; + std::string serverUrl = "tcp://127.0.0.1:9981"; + std::string payload{}; + UBSHcomNetEndpointPtr outEp; + + ret = driver->Connect(badUrl, payload, outEp, 0, 0, 0, 0); + EXPECT_EQ(ret, NN_INVALID_PARAM); + + ret = driver->Connect(serverUrl, payload, outEp, 0, 0, 0, 0); + EXPECT_EQ(ret, NN_ERROR); +} + +} +} +#endif diff --git a/test/unit_test/transport/ub/test_net_ub_endpoint.cpp b/test/unit_test/transport/ub/test_net_ub_endpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..243eb8b6478eb2452def55848fbba22d2939a50a --- /dev/null +++ b/test/unit_test/transport/ub/test_net_ub_endpoint.cpp @@ -0,0 +1,1801 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED +#include +#include + +#include "hcom.h" +#include "ub_common.h" +#include "ub_worker.h" +#include "net_ub_driver_oob.h" +#include "net_ub_endpoint.h" + +#include "net_monotonic.h" +#include "net_security_alg.h" +#include "hcom_utils.h" +#include "ub_urma_wrapper_jetty.h" + +namespace ock { +namespace hcom { + +class TestNetUBAsyncEndpoint : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + + std::string name; + NetDriverUBWithOob *mDriver = nullptr; + UBContext *ctx = nullptr; + UBWorker *mWorker = nullptr; + UBJfc *cq = nullptr; + UBJetty *qp = nullptr; + UBSHcomNetWorkerIndex mWorkerIndex; + UBSHcomNetTransRequest request; + UBSHcomNetTransSglRequest sglRequest; + UBSHcomNetTransSgeIov *iov = nullptr; + NetUBAsyncEndpoint *NEP = nullptr; + UBMemoryRegionFixedBuffer *Mr = nullptr; + NetHeartbeat *mHeartBeat = nullptr; +}; + +void TestNetUBAsyncEndpoint::SetUp() +{ + UBEId eid{}; + UBContext::Create("test_net_ub_endpoint", eid, ctx); + ASSERT_NE(ctx, nullptr); + + bool startOobSvr = true; + UBSHcomNetDriverProtocol protocol = UBC; + mDriver = new (std::nothrow) NetDriverUBWithOob(name, startOobSvr, protocol); + mDriver->mStarted = true; + Mr = mDriver->mDriverSendMR = new (std::nothrow) UBMemoryRegionFixedBuffer(name, ctx, 1, 1, 1); + ASSERT_NE(mDriver, nullptr); + + UBWorkerOptions options; + NetMemPoolFixedPtr memPool; + NetMemPoolFixedPtr sglMemPool; + mWorker = new (std::nothrow) UBWorker(name, ctx, options, memPool, sglMemPool); + ASSERT_NE(mWorker, nullptr); + + cq = new (std::nothrow) UBJfc(name, ctx, false, 0); + ASSERT_NE(cq, nullptr); + + JettyOptions jettyOptions; + qp = new (std::nothrow) UBJetty(name, 0, ctx, cq, jettyOptions); + ASSERT_NE(qp, nullptr); + + qp->StoreExchangeInfo(new UBJettyExchangeInfo); + + NEP = new (std::nothrow) NetUBAsyncEndpoint(0, qp, mDriver, mWorker); + ASSERT_NE(NEP, nullptr); + NEP->mState.Set(NEP_ESTABLISHED); + NEP->mAllowedSize = NN_NO128; + NEP->mSegSize = NN_NO128; + + mHeartBeat = new (std::nothrow) NetHeartbeat(mDriver, NN_NO60, NN_NO2); + + request.lAddress = reinterpret_cast(&mWorkerIndex); + request.rAddress = reinterpret_cast(&mWorkerIndex); + request.size = 1; + + iov = new (std::nothrow) UBSHcomNetTransSgeIov(); + ASSERT_NE(iov, nullptr); + iov->lAddress = reinterpret_cast(&mWorkerIndex); + iov->rAddress = reinterpret_cast(&mWorkerIndex); + iov->size = 1; + sglRequest = UBSHcomNetTransSglRequest(iov, 1, 1); +} + +void TestNetUBAsyncEndpoint::TearDown() +{ + qp->mSendJfc = nullptr; + qp->mRecvJfc = nullptr; + qp->mUrmaJetty = nullptr; + qp->mJettyMr = nullptr; + cq->mUBContext = nullptr; + cq->mUrmaJfc = nullptr; + if (cq != nullptr) { + delete cq; + cq = nullptr; + } + if (Mr != nullptr) { + delete Mr; + Mr = nullptr; + } + if (NEP != nullptr) { + delete NEP; + NEP = nullptr; + } + if (iov != nullptr) { + delete iov; + iov = nullptr; + } + GlobalMockObject::verify(); +} +static UBSHcomNetTransHeader mockMrBuf{}; +static bool MockGetFreeBuffer(uintptr_t &mrBufAddress) +{ + mrBufAddress = reinterpret_cast(&mockMrBuf); + return true; +} + +static NResult FakePollingCompletion(UBOpContextInfo *&ctx, int32_t timeout, uint32_t &immData) +{ + static char buf[128]; + static UBOpContextInfo info; + info.dataSize = 1; + info.mrMemAddr = reinterpret_cast(buf); // used by NetUBSyncEndpoint::ReceiveRawHandle + + ctx = &info; + return NN_OK; +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendSeqFailed) +{ + name = "NetUBAsyncEndpointPostSendSeqFailed"; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(returnValue(false)); + int ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_GET_BUFF_FAILED)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendSeq) +{ + name = "NetUBAsyncEndpointPostSend"; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + + MOCKER_CPP(&UBWorker::PostSend) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)) + .then(returnValue(static_cast(RR_QP_POST_SEND_FAILED))); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + MOCKER_CPP(&AesGcm128::Encrypt).stubs().will(returnValue(false)); + + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + NEP->mIsNeedEncrypt = 1; + ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendSeqTwo) +{ + name = "NetUBAsyncEndpointPostSend"; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&UBWorker::PostSend) + .stubs() + .will(returnValue(static_cast(UB_QP_POST_SEND_WR_FULL))) + .then(returnValue(1)) + .then(returnValue(0)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostRead) +{ + name = "NetUBAsyncEndpointPostRead"; + MOCKER_CPP(&UBWorker::PostRead, UResult(UBWorker::*)(UBJetty *, const UBSendReadWriteRequest &)) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)); + + int ret = 0; + uint32_t key = sglRequest.iov[0].lKey; + mDriver->mMapTseg.emplace(key, nullptr); + + ret = NEP->PostRead(request); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostRead(request); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostWrite) +{ + name = "NetUBAsyncEndpointPostWrite"; + MOCKER_CPP(&UBWorker::PostWrite, + UResult(UBWorker::*)(UBJetty *, const UBSendReadWriteRequest &, UBOpContextInfo::OpType)) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)); + int ret = 0; + uint32_t key = sglRequest.iov[0].lKey; + mDriver->mMapTseg.emplace(key, nullptr); + + ret = NEP->PostWrite(request); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostWrite(request); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBAsyncEndpoint, PostSendSglInlineEncrypt) +{ + name = "NetUBAsyncEndpointPostSendSglInline"; + NEP->mIsNeedEncrypt = true; + + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&UBWorker::PostSend) + .stubs() + .will(returnValue(0)); + MOCKER_CPP(&AesGcm128::Encrypt).stubs().will(returnValue(true)); + + UBSHcomNetTransOpInfo OpInfo{}; + auto ret = NEP->PostSendSglInline(0, request, OpInfo); + EXPECT_EQ(ret, NN_OK); +} + +TEST_F(TestNetUBAsyncEndpoint, PostSendSglInlineNotUB) +{ + name = "NetUBAsyncEndpointPostSendSglInline"; + NEP->mIsNeedEncrypt = false; + NEP->mJetty->mUBContext->protocol = UBSHcomNetDriverProtocol::RDMA; + + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&UBWorker::PostSend) + .stubs() + .will(returnValue(0)); + MOCKER_CPP(&AesGcm128::Encrypt).stubs().will(returnValue(true)); + + UBSHcomNetTransOpInfo OpInfo{}; + auto ret = NEP->PostSendSglInline(0, request, OpInfo); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestNetUBAsyncEndpoint, PostSendSglInlineUBSuccess) +{ + name = "NetUBAsyncEndpointPostSendSglInline"; + NEP->mIsNeedEncrypt = false; + NEP->mJetty->mUBContext->protocol = UBSHcomNetDriverProtocol::UBC; + + MOCKER_CPP(&UBWorker::PostSendSglInline) + .stubs() + .will(returnValue(0)); + + UBSHcomNetTransOpInfo OpInfo{}; + auto ret = NEP->PostSendSglInline(0, request, OpInfo); + EXPECT_EQ(ret, 0); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendInfoFailed) +{ + name = "NetUBAsyncEndpointPostSend"; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(returnValue(false)); + UBSHcomNetTransOpInfo OpInfo{}; + int ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_GET_BUFF_FAILED)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendOpInfo) +{ + name = "NetUBAsyncEndpointPostSend"; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + + MOCKER_CPP(&UBWorker::PostSend) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)) + .then(returnValue(static_cast(RR_QP_POST_SEND_FAILED))); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + MOCKER_CPP(&AesGcm128::Encrypt).stubs().will(returnValue(false)); + + UBSHcomNetTransOpInfo OpInfo{}; + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + NEP->mIsNeedEncrypt = 1; + ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendOpInfoTwo) +{ + name = "NetUBAsyncEndpointPostSend"; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&UBWorker::PostSend) + .stubs() + .will(returnValue(static_cast(UB_QP_POST_SEND_WR_FULL))) + .then(returnValue(1)) + .then(returnValue(0)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + UBSHcomNetTransOpInfo OpInfo{}; + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendOpInfoWithHeaderRaw) +{ + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader; + int ret; + + extHeaderType = UBSHcomExtHeaderType::RAW; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendOpInfoWithHeaderNull) +{ + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, nullptr, sizeof(extHeader)); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendOpInfoWithHeaderBuffer) +{ + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer).stubs().will(returnValue(false)); + + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_GET_BUFF_FAILED); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendOpInfoWithHeaderMemcpy) +{ + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer).stubs().will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)).then(returnValue(1)); + + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_INVALID_PARAM); + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendOpInfoWithHeaderWorkerSend) +{ + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer).stubs().will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&UBWorker::PostSend) + .stubs() + .will(returnValue(static_cast(UB_QP_POST_SEND_WR_FULL))) + .then(returnValue(static_cast(NN_OK))) + .then(returnValue(static_cast(RR_QP_POST_SEND_FAILED))); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_OK); + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, RR_QP_POST_SEND_FAILED); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendAll) +{ + name = "NetUBAsyncEndpointPostSendAll"; + + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + + MOCKER_CPP(&NetDriverUB::ValidateMemoryRegion, NResult(NetDriverUB::*)(uint64_t, uintptr_t, uint64_t)).stubs() + .will(returnValue(0)); + + MOCKER_CPP(&UBWorker::PostSend, NResult(UBWorker::*)(UBJetty *, const UBSendReadWriteRequest &, + urma_target_seg_t *, uint32_t)).stubs().will(returnValue(0)); + + MOCKER_CPP(&UBJetty::GetUpContext1, uintptr_t(UBJetty::*)() const).stubs() + .will(returnValue(reinterpret_cast(mWorker))); + + int ret = NEP->PostSendRaw(request, 1); + EXPECT_EQ(ret, static_cast(UB_OK)); + + MOCKER_CPP(&UBWorker::PostSendSgl, NResult(UBWorker::*)(UBJetty *, const UBSHcomNetTransSglRequest &, + const UBSHcomNetTransRequest &, uint32_t, bool)).stubs().will(returnValue(0)); + + ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); + + uint32_t key = sglRequest.iov[0].lKey; + mDriver->mMapTseg.emplace(key, nullptr); + + ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(UB_OK)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendRawAllTwo) +{ + MOCKER_CPP(&UBJetty::GetUpContext1, uintptr_t(UBJetty::*)() const).stubs() + .will(returnValue(reinterpret_cast(mWorker))); + + MOCKER_CPP(&UBWorker::PostOneSideSgl) + .stubs() + .will(returnValue(static_cast(UB_QP_POST_SEND_WR_FULL))) + .then(returnValue(1)) + .then(returnValue(0)) + .then(returnValue(static_cast(UB_QP_POST_SEND_WR_FULL))) + .then(returnValue(1)) + .then(returnValue(0)); + + int ret = NEP->PostRead(sglRequest); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); + + uint32_t key = sglRequest.iov[0].lKey; + mDriver->mMapTseg.emplace(key, nullptr); + + ret = NEP->PostRead(sglRequest); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostRead(sglRequest); + EXPECT_EQ(ret, static_cast(UB_OK)); + + ret = NEP->PostWrite(sglRequest); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostWrite(sglRequest); + EXPECT_EQ(ret, static_cast(UB_OK)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendRawGetBufferErr) +{ + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(returnValue(false)); + int ret = NEP->PostSendRaw(request, 1); + EXPECT_EQ(ret, static_cast(NN_GET_BUFF_FAILED)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendRawCopyErr) +{ + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + MOCKER_CPP(&AesGcm128::Encrypt, + bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + NEP->mIsNeedEncrypt = false; + int ret = NEP->PostSendRaw(request, 1); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); + + NEP->mIsNeedEncrypt = true; + ret = NEP->PostSendRaw(request, 1); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendRaw) +{ + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&UBWorker::PostSend) + .stubs() + .will(returnValue(static_cast(UB_QP_POST_SEND_WR_FULL))) + .then(returnValue(1)) + .then(returnValue(0)); + + NEP->mIsNeedEncrypt = false; + int ret = NEP->PostSendRaw(request, 1); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostSendRaw(request, 1); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPostSendRawSgl) +{ + MOCKER_CPP(&UBWorker::PostSendSgl) + .stubs() + .will(returnValue(static_cast(RR_QP_POST_SEND_WR_FULL))) + .then(returnValue(1)) + .then(returnValue(0)); + int ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); + + uint32_t key = sglRequest.iov[0].lKey; + mDriver->mMapTseg.emplace(key, nullptr); + NEP->mIsNeedEncrypt = false; + ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointSetEpOption) +{ + UBSHcomEpOptions epOptions; + int ret = NEP->SetEpOption(epOptions); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointGetSendQueueCount) +{ + MOCKER_CPP(&UBJetty::GetSendQueueSize).stubs().will(returnValue(1)); + int ret = NEP->GetSendQueueCount(); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointPeerIpAndPort) +{ + NEP->mJetty->mPeerIpPort = "1.2.3.4"; + std::string ret = NEP->PeerIpAndPort(); + EXPECT_EQ(ret, "1.2.3.4"); + + NEP->mJetty = nullptr; + ret = NEP->PeerIpAndPort(); + EXPECT_EQ(ret, CONST_EMPTY_STRING); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointUdsName) +{ + std::string ret = NEP->UdsName(); + EXPECT_EQ(ret, CONST_EMPTY_STRING); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointCheckTargetHbTime) +{ + uint64_t currTime = 1; + NEP->mTargetHbTime = 0; + bool ret = NEP->checkTargetHbTime(currTime); + EXPECT_EQ(ret, true); + + NEP->mTargetHbTime = 1; + ret = NEP->checkTargetHbTime(currTime); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointInvalidOperation) +{ + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + int ret = NEP->WaitCompletion(timeout); + EXPECT_EQ(ret, static_cast(NN_INVALID_OPERATION)); + ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, static_cast(NN_INVALID_OPERATION)); + ret = NEP->ReceiveRaw(timeout, ctx); + EXPECT_EQ(ret, static_cast(NN_INVALID_OPERATION)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointHbRecordCount) +{ + EXPECT_NO_FATAL_FAILURE(NEP->HbRecordCount()); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointHbCheckStateNormal) +{ + NEP->mHbCount = 1; + NEP->mHbLastCount = 0; + bool ret = NEP->HbCheckStateNormal(); + EXPECT_EQ(ret, true); + + NEP->mHbLastCount = 1; + ret = NEP->HbCheckStateNormal(); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointSetRemoteHbInfo) +{ + uintptr_t address = 0; + uint32_t key = 0; + uint64_t size = 0; + EXPECT_NO_FATAL_FAILURE(NEP->SetRemoteHbInfo(address, key, size)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointSetHbBrokenEp) +{ + EXPECT_NO_FATAL_FAILURE(NEP->SetHbBrokenEp()); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointHbBrokenEp) +{ + NEP->mHbBrokenEp = false; + bool ret = NEP->HbBrokenEp(); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointEnableEncrypt) +{ + UBSHcomNetDriverOptions options{}; + EXPECT_NO_FATAL_FAILURE(NEP->EnableEncrypt(options)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointSetSecrets) +{ + NetSecrets secrets; + EXPECT_NO_FATAL_FAILURE(NEP->SetSecrets(secrets)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointEstimatedEncryptLen) +{ + int ret = NEP->EstimatedEncryptLen(0); + EXPECT_EQ(ret, 0); + NEP->mIsNeedEncrypt = 0; + ret = NEP->EstimatedEncryptLen(1); + EXPECT_EQ(ret, 0); +} +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointEstimatedEncryptLenTwo) +{ + NEP->mIsNeedEncrypt = 1; + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen).stubs().will(returnValue(1)); + int ret = NEP->EstimatedEncryptLen(1); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointEncrypt) +{ + uint64_t cipherLen = 0; + MOCKER_CPP(&AesGcm128::Encrypt).stubs().will(returnValue(false)); + int ret = NEP->Encrypt(reinterpret_cast(0), 0, reinterpret_cast(0), cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + NEP->mIsNeedEncrypt = 0; + ret = NEP->Encrypt(reinterpret_cast(0), 0, reinterpret_cast(0), cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointEncryptTwo) +{ + uint64_t cipherLen = 0; + MOCKER_CPP(&AesGcm128::Encrypt).stubs().will(returnValue(true)); + + NEP->mIsNeedEncrypt = 1; + int ret = NEP->Encrypt(reinterpret_cast(0), 0, reinterpret_cast(0), cipherLen); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointEstimatedDecryptLen) +{ + NEP->mIsNeedEncrypt = 0; + int ret = NEP->EstimatedDecryptLen(0); + EXPECT_EQ(ret, 0); + + NEP->mIsNeedEncrypt = 1; + MOCKER_CPP(&AesGcm128::GetRawLen).stubs().will(returnValue(1)); + ret = NEP->EstimatedDecryptLen(0); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointDecrypt) +{ + uint64_t rawLen = 0; + MOCKER_CPP(&AesGcm128::Decrypt).stubs().will(returnValue(false)); + int ret = NEP->Decrypt(reinterpret_cast(0), 0, reinterpret_cast(0), rawLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + NEP->mIsNeedEncrypt = 0; + ret = NEP->Decrypt(reinterpret_cast(0), 0, reinterpret_cast(0), rawLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointDecryptTwo) +{ + uint64_t rawLen = 0; + MOCKER_CPP(&AesGcm128::Decrypt).stubs().will(returnValue(true)); + + NEP->mIsNeedEncrypt = 1; + int ret = NEP->Decrypt(reinterpret_cast(0), 0, reinterpret_cast(0), rawLen); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointGetRemoteUdsIdInfo) +{ + UBSHcomNetUdsIdInfo verbsIdInfo; + NEP->mState.Set(NEP_NEW); + int ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + + NEP->mState.Set(NEP_ESTABLISHED); + NEP->mDriver->mStartOobSvr = false; + ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_UDS_ID_INFO_NOT_SUPPORT)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointGetRemoteUdsIdInfoTwo) +{ + UBSHcomNetUdsIdInfo verbsIdInfo; + + NEP->mState.Set(NEP_ESTABLISHED); + NEP->mDriver->mStartOobSvr = true; + NEP->mDriver->mOptions.oobType = NET_OOB_TCP; + int ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_UDS_ID_INFO_NOT_SUPPORT)); + + NEP->mDriver->mOptions.oobType = NET_OOB_UDS; + ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointGetPeerIpPortErr) +{ + std::string ip; + uint16_t port; + bool ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + + NEP->mJetty = nullptr; + ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointGetPeerIpPort) +{ + std::string ip; + uint16_t port; + NEP->mJetty->mPeerIpPort = "1.2.3.4"; + bool ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetUBAsyncEndpoint, NetUBAsyncEndpointClose) +{ + EXPECT_NO_FATAL_FAILURE(NEP->Close()); +} + +class TestNetUBSyncEndpoint : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + + std::string name; + NetDriverUBWithOob *mDriver = nullptr; + UBContext *ctx = nullptr; + UBJfc *cq = nullptr; + UBJetty *qp = nullptr; + UBSHcomNetWorkerIndex mWorkerIndex; + UBSHcomNetTransSglRequest sglRequest; + UBSHcomNetTransRequest request; + UBSHcomNetTransSgeIov *iov = nullptr; + NetUBSyncEndpoint *NEP = nullptr; + UBMemoryRegionFixedBuffer *Mr = nullptr; +}; + +void TestNetUBSyncEndpoint::SetUp() +{ + UBEId gid; + ctx = new (std::nothrow) UBContext(name, gid); + ASSERT_NE(ctx, nullptr); + + bool startOobSvr = true; + UBSHcomNetDriverProtocol protocol = UBC; + mDriver = new (std::nothrow) NetDriverUBWithOob(name, startOobSvr, protocol); + mDriver->mStarted = true; + Mr = mDriver->mDriverSendMR = new (std::nothrow) UBMemoryRegionFixedBuffer(name, ctx, 1, 1, 1); + ASSERT_NE(mDriver, nullptr); + + cq = new (std::nothrow) UBJfc(name, ctx, false, 0); + ASSERT_NE(cq, nullptr); + + JettyOptions jettyOptions; + qp = new (std::nothrow) UBJetty(name, 0, ctx, cq, jettyOptions); + ASSERT_NE(qp, nullptr); + + NEP = new (std::nothrow) NetUBSyncEndpoint(0, qp, cq, 0, mDriver, mWorkerIndex); + NEP->mState.Set(NEP_ESTABLISHED); + NEP->mAllowedSize = NN_NO128; + NEP->mSegSize = NN_NO128; + + request.lAddress = reinterpret_cast(&mWorkerIndex); + request.rAddress = reinterpret_cast(&mWorkerIndex); + request.size = 1; + + iov = new (std::nothrow) UBSHcomNetTransSgeIov(); + iov->lAddress = reinterpret_cast(&mWorkerIndex); + iov->rAddress = reinterpret_cast(&mWorkerIndex); + iov->size = 1; + sglRequest = UBSHcomNetTransSglRequest(iov, 1, 1); +} + +void TestNetUBSyncEndpoint::TearDown() +{ + GlobalMockObject::verify(); + qp->mSendJfc = nullptr; + qp->mRecvJfc = nullptr; + qp->mUrmaJetty = nullptr; + qp->mJettyMr = nullptr; + cq->mUBContext = nullptr; + cq->mUrmaJfc = nullptr; + + if (cq != nullptr) { + delete cq; + cq = nullptr; + } + if (Mr != nullptr) { + delete Mr; + Mr = nullptr; + } + if (NEP != nullptr) { + delete NEP; + NEP = nullptr; + } + if (iov != nullptr) { + delete iov; + iov = nullptr; + } +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendSeqFailed) +{ + name = "NetUBSyncEndpointPostSendSeqFailed"; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(returnValue(false)); + int ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_GET_BUFF_FAILED)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendSeq) +{ + name = "NetUBSyncEndpointPostSend"; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + + MOCKER_CPP(&UBWorker::PostSend) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)) + .then(returnValue(static_cast(RR_QP_POST_SEND_FAILED))); + + MOCKER_CPP(&AesGcm128::Encrypt).stubs().will(returnValue(false)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(UB_QP_NOT_INITIALIZED)); + + NEP->mIsNeedEncrypt = 1; + ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendSeqTwo) +{ + name = "NetUBSyncEndpointPostSend"; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&NetUBSyncEndpoint::InnerPostSend) + .stubs() + .will(returnValue(static_cast(UB_QP_POST_SEND_WR_FULL))) + .then(returnValue(1)) + .then(returnValue(0)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostSend(0, request, 0); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostRead) +{ + name = "NetUBSyncEndpointPostRead"; + MOCKER_CPP(&NetUBSyncEndpoint::InnerPostRead, NResult(NetUBSyncEndpoint::*)(const UBSendReadWriteRequest &)).stubs() + .will(returnValue(static_cast(UB_QP_POST_SEND_WR_FULL))) + .then(returnValue(1)) + .then(returnValue(0)); + + int ret = NEP->PostRead(request); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostRead(request); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostWrite) +{ + name = "NetUBSyncEndpointPostWrite"; + MOCKER_CPP(&NetUBSyncEndpoint::InnerPostWrite, NResult(NetUBSyncEndpoint::*)(const UBSendReadWriteRequest &)) + .stubs() + .will(returnValue(static_cast(UB_QP_POST_SEND_WR_FULL))) + .then(returnValue(1)) + .then(returnValue(0)); + + int ret = NEP->PostWrite(request); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostWrite(request); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendInfoFailed) +{ + name = "NetUBSyncEndpointPostSend"; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(returnValue(false)); + UBSHcomNetTransOpInfo OpInfo{}; + int ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_GET_BUFF_FAILED)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendOpInfo) +{ + name = "NetUBSyncEndpointPostSend"; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + + MOCKER_CPP(&UBWorker::PostSend) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)) + .then(returnValue(static_cast(RR_QP_POST_SEND_FAILED))); + + MOCKER_CPP(&AesGcm128::Encrypt).stubs().will(returnValue(false)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + UBSHcomNetTransOpInfo OpInfo{}; + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(UB_QP_NOT_INITIALIZED)); + + NEP->mIsNeedEncrypt = 1; + ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendOpInfoTwo) +{ + name = "NetUBSyncEndpointPostSend"; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&NetUBSyncEndpoint::InnerPostSend) + .stubs() + .will(returnValue(static_cast(UB_QP_POST_SEND_WR_FULL))) + .then(returnValue(1)) + .then(returnValue(0)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + UBSHcomNetTransOpInfo OpInfo{}; + NEP->mIsNeedEncrypt = 0; + int ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostSend(0, request, OpInfo); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendOpInfoWithHeaderRaw) +{ + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo opInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::RAW; + ret = NEP->PostSend(0, request, opInfo, extHeaderType, &extHeader, sizeof(extHeaderType)); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendOpInfoWithHeaderNull) +{ + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo opInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, opInfo, extHeaderType, nullptr, sizeof(extHeaderType)); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendOpInfoWithHeaderBuffer) +{ + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer).stubs().will(returnValue(false)); + + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo opInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, opInfo, extHeaderType, &extHeader, sizeof(extHeaderType)); + EXPECT_EQ(ret, NN_GET_BUFF_FAILED); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendOpInfoWithHeaderMemory) +{ + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer).stubs().will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)).then(returnValue(1)); + + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo opInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, opInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_INVALID_PARAM); + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, opInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendOpInfoWithHeaderWorkerSend) +{ + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer).stubs().will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&NetUBSyncEndpoint::InnerPostSend) + .stubs() + .will(returnValue(static_cast(UB_QP_POST_SEND_WR_FULL))) + .then(returnValue(static_cast(NN_OK))) + .then(returnValue(static_cast(RR_QP_POST_SEND_FAILED))); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + NEP->mIsNeedEncrypt = 0; + + UBSHcomNetTransOpInfo OpInfo{}; + UBSHcomExtHeaderType extHeaderType; + UBSHcomFragmentHeader extHeader{}; + int ret; + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, NN_OK); + + extHeaderType = UBSHcomExtHeaderType::FRAGMENT; + ret = NEP->PostSend(0, request, OpInfo, extHeaderType, &extHeader, sizeof(extHeader)); + EXPECT_EQ(ret, RR_QP_POST_SEND_FAILED); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendAll) +{ + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&UBJetty::PostSend) + .stubs() + .will(returnValue(static_cast(UB_OK))); + MOCKER_CPP(&UBJetty::PostSendSgl) + .stubs() + .will(returnValue(static_cast(UB_OK))); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + + int ret = NEP->PostSendRaw(request, 1); + EXPECT_EQ(ret, static_cast(UB_OK)); + + ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); + + uint32_t key = sglRequest.iov[0].lKey; + mDriver->mMapTseg.emplace(key, nullptr); + + ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(UB_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendRawAllTwo) +{ + MOCKER_CPP(&NetUBSyncEndpoint::PostOneSideSgl) + .stubs() + .will(returnValue(static_cast(UB_QP_POST_SEND_WR_FULL))) + .then(returnValue(1)) + .then(returnValue(0)) + .then(returnValue(static_cast(UB_QP_POST_SEND_WR_FULL))) + .then(returnValue(1)) + .then(returnValue(0)); + + int ret = NEP->PostRead(sglRequest); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostRead(sglRequest); + EXPECT_EQ(ret, static_cast(UB_OK)); + + ret = NEP->PostWrite(sglRequest); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostWrite(sglRequest); + EXPECT_EQ(ret, static_cast(UB_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendRawGetBufferErr) +{ + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(returnValue(false)); + int ret = NEP->PostSendRaw(request, 1); + EXPECT_EQ(ret, static_cast(UB_MEMORY_ALLOCATE_FAILED)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendRawCopyErr) +{ + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + MOCKER_CPP(&AesGcm128::Encrypt, + bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + NEP->mIsNeedEncrypt = false; + int ret = NEP->PostSendRaw(request, 1); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); + + NEP->mIsNeedEncrypt = true; + ret = NEP->PostSendRaw(request, 1); + EXPECT_EQ(ret, static_cast(NN_ENCRYPT_FAILED)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendRaw) +{ + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool(UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(invoke(MockGetFreeBuffer)); + MOCKER_CPP(&NetUBSyncEndpoint::InnerPostSend) + .stubs() + .will(returnValue(static_cast(UB_QP_POST_SEND_WR_FULL))) + .then(returnValue(1)) + .then(returnValue(0)); + + NEP->mIsNeedEncrypt = false; + int ret = NEP->PostSendRaw(request, 1); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostSendRaw(request, 1); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointInnerPostSendSglNullErr) +{ + NEP->mJetty = nullptr; + UBSendReadWriteRequest tlsReq; + UBSendSglRWRequest sglRWRequest; + int ret = NEP->InnerPostSendSgl(sglRWRequest, tlsReq, 0); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointInnerPostSendCopyErr) +{ + UBSendReadWriteRequest tlsReq; + UBSendSglRWRequest sglRWRequest; + sglRWRequest.upCtxSize = 1; + MOCKER_CPP(&memcpy_s).stubs() + .will(returnValue(0)) + .then(returnValue(1)); + + int ret = NEP->InnerPostSendSgl(sglRWRequest, tlsReq, 0); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); + + ret = NEP->InnerPostSendSgl(sglRWRequest, tlsReq, 0); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostSendRawSgl) +{ + MOCKER_CPP(&NetUBSyncEndpoint::InnerPostSendSgl) + .stubs() + .will(returnValue(static_cast(RR_QP_POST_SEND_WR_FULL))) + .then(returnValue(1)) + .then(returnValue(0)); + + NEP->mIsNeedEncrypt = false; + int ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(NN_NO1)); + + ret = NEP->PostSendRawSgl(sglRequest, 1); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointWaitCompletionErr) +{ + MOCKER_CPP(&NetUBSyncEndpoint::PollingCompletion).stubs().will(returnValue(1)); + int ret = NEP->WaitCompletion(0); + EXPECT_EQ(ret, static_cast(NN_NO1)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointReceiveRaw) +{ + MOCKER_CPP(&NetUBSyncEndpoint::PollingCompletion) + .stubs() + .will(invoke(FakePollingCompletion)); + MOCKER_CPP(&NetUBSyncEndpoint::RePostReceive) + .stubs() + .will(returnValue(static_cast(UB_OK))); + + UBSHcomNetResponseContext resCtx; + int ret = NEP->ReceiveRaw(0, resCtx); + EXPECT_EQ(ret, static_cast(UB_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPollingCompletionCqNull) +{ + UBOpContextInfo *ctx = nullptr; + NEP->mJfc = nullptr; + uint32_t immData = 0; + + int ret = NEP->PollingCompletion(ctx, 0, immData); + EXPECT_EQ(ret, static_cast(UB_EP_NOT_INITIALIZED)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostReceiveJettyNull) +{ + NEP->mJetty = nullptr; + int ret = NEP->PostReceive(0, 0, nullptr); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostReceiveDequeueErr) +{ + MOCKER_CPP(&NetObjPool::Dequeue) + .stubs() + .will(returnValue(false)); + int ret = NEP->PostReceive(0, 0, nullptr); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointRePostReceive) +{ + UBOpContextInfo *ctx = nullptr; + int ret = NEP->RePostReceive(ctx); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointCreateResourcesParamErr) +{ + UBPollingMode pollMode = UB_EVENT_POLLING; + JettyOptions options{}; + int ret = NEP->CreateResources(name, nullptr, pollMode, options, qp, cq); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointCreateResources) +{ + UBPollingMode pollMode = UB_EVENT_POLLING; + JettyOptions options{}; + name = "test"; + int ret = NEP->CreateResources(name, ctx, pollMode, options, qp, cq); + EXPECT_EQ(ret, static_cast(UB_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointInnerPostReadJettyNull) +{ + NEP->mJetty = nullptr; + UBSendReadWriteRequest rwReq{}; + int ret = NEP->InnerPostRead(rwReq); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointInnerPostWriteJettyNull) +{ + NEP->mJetty = nullptr; + UBSendReadWriteRequest rwReq{}; + int ret = NEP->InnerPostWrite(rwReq); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointCreateOneSideCtxParamErr) +{ + UBSgeCtxInfo sgeInfo{}; + uint64_t ctxArr[NET_SGE_MAX_IOV]; + int ret = NEP->CreateOneSideCtx(sgeInfo, nullptr, 0, ctxArr, true); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointCreateOneSideCtx) +{ + UBSgeCtxInfo sgeInfo{}; + uint64_t ctxArr[NET_SGE_MAX_IOV]; + int ret = NEP->CreateOneSideCtx(sgeInfo, iov, 1, ctxArr, true); + EXPECT_EQ(ret, static_cast(UB_OK)); + + NEP->mJetty->DecreaseRef(); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostOneSideSglParamErr) +{ + sglRequest.upCtxSize = 1; + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)).then(returnValue(1)); + + int ret = NEP->PostOneSideSgl(sglRequest); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); + + ret = NEP->PostOneSideSgl(sglRequest); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); + + NEP->mJetty = nullptr; + ret = NEP->PostOneSideSgl(sglRequest); + EXPECT_EQ(ret, static_cast(UB_PARAM_INVALID)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPostOneSideSglCreateOneSideCtxErr) +{ + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + MOCKER_CPP(&NetUBSyncEndpoint::CreateOneSideCtx) + .stubs() + .will(returnValue(1)); + int ret = NEP->PostOneSideSgl(sglRequest); + EXPECT_EQ(ret, static_cast(NN_NO1)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointSetEpOption) +{ + UBSHcomEpOptions epOptions; + int ret = NEP->SetEpOption(epOptions); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointGetSendQueueCount) +{ + int ret = NEP->GetSendQueueCount(); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointPeerIpAndPort) +{ + NEP->mJetty->mPeerIpPort = "1.2.3.4"; + std::string ret = NEP->PeerIpAndPort(); + EXPECT_EQ(ret, "1.2.3.4"); + + NEP->mJetty = nullptr; + ret = NEP->PeerIpAndPort(); + EXPECT_EQ(ret, CONST_EMPTY_STRING); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointUdsName) +{ + std::string ret = NEP->UdsName(); + EXPECT_EQ(ret, CONST_EMPTY_STRING); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointEnableEncrypt) +{ + UBSHcomNetDriverOptions options{}; + EXPECT_NO_FATAL_FAILURE(NEP->EnableEncrypt(options)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointSetSecrets) +{ + NetSecrets secrets; + EXPECT_NO_FATAL_FAILURE(NEP->SetSecrets(secrets)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointEstimatedEncryptLen) +{ + int ret = NEP->EstimatedEncryptLen(0); + EXPECT_EQ(ret, 0); + NEP->mIsNeedEncrypt = 0; + ret = NEP->EstimatedEncryptLen(1); + EXPECT_EQ(ret, 0); +} +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointEstimatedEncryptLenTwo) +{ + NEP->mIsNeedEncrypt = 1; + MOCKER_CPP(&AesGcm128::EstimatedEncryptLen).stubs().will(returnValue(1)); + int ret = NEP->EstimatedEncryptLen(1); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointEncrypt) +{ + uint64_t cipherLen = 0; + MOCKER_CPP(&AesGcm128::Encrypt).stubs().will(returnValue(false)); + int ret = NEP->Encrypt(reinterpret_cast(0), 0, reinterpret_cast(0), cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + NEP->mIsNeedEncrypt = 0; + ret = NEP->Encrypt(reinterpret_cast(0), 0, reinterpret_cast(0), cipherLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointEncryptTwo) +{ + uint64_t cipherLen = 0; + MOCKER_CPP(&AesGcm128::Encrypt).stubs().will(returnValue(true)); + + NEP->mIsNeedEncrypt = 1; + int ret = NEP->Encrypt(reinterpret_cast(0), 0, reinterpret_cast(0), cipherLen); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointEstimatedDecryptLen) +{ + NEP->mIsNeedEncrypt = 0; + int ret = NEP->EstimatedDecryptLen(0); + EXPECT_EQ(ret, 0); + + NEP->mIsNeedEncrypt = 1; + MOCKER_CPP(&AesGcm128::GetRawLen).stubs().will(returnValue(1)); + ret = NEP->EstimatedDecryptLen(0); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointDecrypt) +{ + uint64_t rawLen = 0; + MOCKER_CPP(&AesGcm128::Decrypt).stubs().will(returnValue(false)); + int ret = NEP->Decrypt(reinterpret_cast(0), 0, reinterpret_cast(0), rawLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); + + NEP->mIsNeedEncrypt = 0; + ret = NEP->Decrypt(reinterpret_cast(0), 0, reinterpret_cast(0), rawLen); + EXPECT_EQ(ret, static_cast(NN_ERROR)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointDecryptTwo) +{ + uint64_t rawLen = 0; + MOCKER_CPP(&AesGcm128::Decrypt).stubs().will(returnValue(true)); + + NEP->mIsNeedEncrypt = 1; + int ret = NEP->Decrypt(reinterpret_cast(0), 0, reinterpret_cast(0), rawLen); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointGetRemoteUdsIdInfo) +{ + UBSHcomNetUdsIdInfo verbsIdInfo; + NEP->mState.Set(NEP_NEW); + int ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_EP_NOT_ESTABLISHED)); + + NEP->mState.Set(NEP_ESTABLISHED); + NEP->mDriver->mStartOobSvr = false; + ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_UDS_ID_INFO_NOT_SUPPORT)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointGetRemoteUdsIdInfoTwo) +{ + UBSHcomNetUdsIdInfo verbsIdInfo; + + NEP->mState.Set(NEP_ESTABLISHED); + NEP->mDriver->mStartOobSvr = true; + NEP->mDriver->mOptions.oobType = NET_OOB_TCP; + int ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_UDS_ID_INFO_NOT_SUPPORT)); + + NEP->mDriver->mOptions.oobType = NET_OOB_UDS; + ret = NEP->GetRemoteUdsIdInfo(verbsIdInfo); + EXPECT_EQ(ret, static_cast(NN_OK)); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointGetPeerIpPortErr) +{ + std::string ip; + uint16_t port; + bool ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); + + NEP->mJetty = nullptr; + ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointGetPeerIpPort) +{ + std::string ip; + uint16_t port; + NEP->mJetty->mPeerIpPort = "1.2.3.4"; + bool ret = NEP->GetPeerIpPort(ip, port); + EXPECT_EQ(ret, false); +} + +TEST_F(TestNetUBSyncEndpoint, NetUBSyncEndpointClose) +{ + EXPECT_NO_FATAL_FAILURE(NEP->Close()); +} + +TEST_F(TestNetUBSyncEndpoint, SyncReceiveFailWithErrorOpType) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + UBOpContextInfo opCtx{}; + opCtx.opType = UBOpContextInfo::SEND; + NEP->mDelayHandleReceiveCtx = &opCtx; + + MOCKER_CPP(&NetUBSyncEndpoint::RePostReceive) + .stubs() + .will(returnValue(0)); + + NResult ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_ERROR); +} + +TEST_F(TestNetUBSyncEndpoint, SyncReceiveFailWithNullptr) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + UBOpContextInfo opCtx{}; + opCtx.opType = UBOpContextInfo::RECEIVE; + NEP->mDelayHandleReceiveCtx = nullptr; + + MOCKER_CPP(&NetUBSyncEndpoint::PollingCompletion) + .stubs() + .will(returnValue(1)); + int ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, static_cast(NN_NO1)); +} + +TEST_F(TestNetUBSyncEndpoint, SyncReceiveCopyErr) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + UBOpContextInfo opCtx{}; + opCtx.opType = UBOpContextInfo::RECEIVE; + opCtx.dataSize = NN_NO1024; + UBSHcomNetTransHeader header{}; + header.seqNo = 0; + header.dataLength = NN_NO1024; + opCtx.mrMemAddr = reinterpret_cast(&header); + NEP->mDelayHandleReceiveCtx = &opCtx; + + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize) + .stubs() + .will(returnValue(0)); + + MOCKER_CPP(&AesGcm128::GetRawLen) + .stubs() + .will(returnValue(1)); + + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(false)); + + MOCKER_CPP(&NetUBSyncEndpoint::RePostReceive) + .stubs() + .will(returnValue(0)); + + NResult ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_MALLOC_FAILED); + + NEP->mIsNeedEncrypt = true; + ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, UB_CQ_NOT_INITIALIZED); +} + +TEST_F(TestNetUBSyncEndpoint, SyncReceiveMemCopyHeaderErr) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + UBOpContextInfo opCtx{}; + opCtx.opType = UBOpContextInfo::RECEIVE; + opCtx.dataSize = NN_NO1024; + UBSHcomNetTransHeader header{}; + header.seqNo = 0; + header.dataLength = NN_NO1024; + opCtx.mrMemAddr = reinterpret_cast(&header); + NEP->mDelayHandleReceiveCtx = &opCtx; + + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize) + .stubs() + .will(returnValue(0)); + + MOCKER_CPP(&AesGcm128::GetRawLen) + .stubs() + .will(returnValue(1)); + + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(true)); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)).then(returnValue(1)); + + MOCKER_CPP(&NetUBSyncEndpoint::RePostReceive) + .stubs() + .will(returnValue(0)); + + NEP->mIsNeedEncrypt = false; + NResult ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetUBSyncEndpoint, SyncReceiveMemCopyAddressErr) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + UBOpContextInfo opCtx{}; + opCtx.opType = UBOpContextInfo::RECEIVE; + opCtx.dataSize = NN_NO1024; + UBSHcomNetTransHeader header{}; + header.seqNo = 0; + header.dataLength = NN_NO1024; + opCtx.mrMemAddr = reinterpret_cast(&header); + NEP->mDelayHandleReceiveCtx = &opCtx; + + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize) + .stubs() + .will(returnValue(0)); + + MOCKER_CPP(&AesGcm128::GetRawLen) + .stubs() + .will(returnValue(1)); + + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(true)); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(1)); + + MOCKER_CPP(&NetUBSyncEndpoint::RePostReceive) + .stubs() + .will(returnValue(0)); + + NEP->mIsNeedEncrypt = false; + NResult ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetUBSyncEndpoint, SyncReceiveDecryptErr) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + UBOpContextInfo opCtx{}; + opCtx.opType = UBOpContextInfo::RECEIVE; + opCtx.dataSize = NN_NO1024; + UBSHcomNetTransHeader header{}; + header.seqNo = 0; + header.dataLength = NN_NO1024; + opCtx.mrMemAddr = reinterpret_cast(&header); + NEP->mDelayHandleReceiveCtx = &opCtx; + + MOCKER_CPP(NetFunc::ValidateHeaderWithDataSize) + .stubs() + .will(returnValue(0)); + + MOCKER_CPP(&UBSHcomNetMessage::AllocateIfNeed) + .stubs() + .will(returnValue(true)); + + MOCKER_CPP(&AesGcm128::Decrypt, bool(AesGcm128::*)(NetSecrets &, const void *, uint32_t, void *, uint32_t &)) + .stubs() + .will(returnValue(false)); + + MOCKER_CPP(&NetUBSyncEndpoint::RePostReceive) + .stubs() + .will(returnValue(0)); + + NEP->mIsNeedEncrypt = true; + NResult ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_DECRYPT_FAILED); +} + +TEST_F(TestNetUBSyncEndpoint, SyncReceiveFailWithOverDataSize) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + UBOpContextInfo opCtx{}; + opCtx.opType = UBOpContextInfo::RECEIVE; + UBSHcomNetTransHeader header{}; + header.seqNo = 0; + header.dataLength = NET_SGE_MAX_SIZE + NN_NO1; + opCtx.mrMemAddr = reinterpret_cast(&header); + NEP->mDelayHandleReceiveCtx = &opCtx; + + MOCKER_CPP(&NetUBSyncEndpoint::RePostReceive) + .stubs() + .will(returnValue(0)); + + NResult ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetUBSyncEndpoint, SyncReceiveFailWithErrDataLen) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + UBOpContextInfo opCtx{}; + opCtx.opType = UBOpContextInfo::RECEIVE; + opCtx.dataSize = NN_NO1024; + UBSHcomNetTransHeader header{}; + header.seqNo = 0; + header.dataLength = NN_NO2048; + opCtx.mrMemAddr = reinterpret_cast(&header); + NEP->mDelayHandleReceiveCtx = &opCtx; + + MOCKER_CPP(&NetUBSyncEndpoint::RePostReceive) + .stubs() + .will(returnValue(0)); + + NResult ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_INVALID_PARAM); +} + +TEST_F(TestNetUBSyncEndpoint, SyncReceiveFailWithInvalidHeader) +{ + // param init + int32_t timeout = 0; + UBSHcomNetResponseContext ctx{}; + UBOpContextInfo opCtx{}; + opCtx.opType = UBOpContextInfo::RECEIVE; + opCtx.dataSize = NN_NO1024; + UBSHcomNetTransHeader header{}; + header.seqNo = 0; + header.dataLength = NN_NO1024 - sizeof(UBSHcomNetTransHeader); + opCtx.mrMemAddr = reinterpret_cast(&header); + NEP->mDelayHandleReceiveCtx = &opCtx; + + MOCKER_CPP(&NetUBSyncEndpoint::RePostReceive) + .stubs() + .will(returnValue(0)); + + NResult ret = NEP->Receive(timeout, ctx); + EXPECT_EQ(ret, NN_VALIDATE_HEADER_CRC_INVALID); +} +} +} +#endif diff --git a/test/unit_test/transport/ub/test_net_ub_oob_driver.cpp b/test/unit_test/transport/ub/test_net_ub_oob_driver.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f8c7d2169ef4a7bf4d51f3aab8128542521f97ea --- /dev/null +++ b/test/unit_test/transport/ub/test_net_ub_oob_driver.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED +#include +#include +#include +#include + +#include "hcom.h" +#include "ub_common.h" +#include "net_ub_driver_oob.h" +#include "net_ub_endpoint.h" + +#include "net_monotonic.h" +#include "net_security_alg.h" +#include "hcom_utils.h" +#include "ub_urma_wrapper_jetty.h" + +namespace ock { +namespace hcom { +class TestNetUBOobDriver : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + + std::string name; + UBEId ubEid{}; + NetDriverUBWithOob *mDriver = nullptr; + UBContext *ctx = nullptr; + UBWorker *mWorker = nullptr; + UBJfc *cq = nullptr; + UBJetty *qp = nullptr; + UBSHcomNetWorkerIndex mWorkerIndex; + UBSHcomNetTransRequest request; + UBSHcomNetTransSglRequest sglRequest; + UBSHcomNetTransSgeIov *iov = nullptr; + NetUBAsyncEndpoint *NEP = nullptr; + UBMemoryRegionFixedBuffer *Mr = nullptr; +}; + +void TestNetUBOobDriver::SetUp() +{ + bool startOobSvr = false; + UBSHcomNetDriverProtocol protocol = UBC; + mDriver = new (std::nothrow) NetDriverUBWithOob(name, startOobSvr, protocol); + ASSERT_NE(mDriver, nullptr); +} + +void TestNetUBOobDriver::TearDown() +{ + GlobalMockObject::verify(); + if (mDriver != nullptr) { + delete mDriver; + mDriver = nullptr; + } +} + +TEST_F(TestNetUBOobDriver, NetUBDriverRunInUbEventThread) +{ + mDriver->mNeedStopEvent = true; + + UBContext *ubCtx = nullptr; + int result = UBContext::Create(name, ubEid, ubCtx); + ASSERT_EQ(result, 0); + + ubCtx->mUrmaContext = new (std::nothrow) urma_context_t(); + ASSERT_NE(ubCtx->mUrmaContext, nullptr); + + mDriver->mContext = ubCtx; + EXPECT_NO_FATAL_FAILURE(mDriver->RunInUbEventThread()); + + urma_context_t *urmaContext = nullptr; + MOCKER_CPP(&UBContext::GetContext).stubs().will(returnValue(urmaContext)); + EXPECT_NO_FATAL_FAILURE(mDriver->RunInUbEventThread()); + + delete ubCtx->mUrmaContext; + ubCtx->mUrmaContext = nullptr; +} +} +} +#endif \ No newline at end of file diff --git a/test/unit_test/transport/ub/test_ub_mem_region.cpp b/test/unit_test/transport/ub/test_ub_mem_region.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fc8592a24f741428eb7443bbbf738be62db617f2 --- /dev/null +++ b/test/unit_test/transport/ub/test_ub_mem_region.cpp @@ -0,0 +1,220 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifdef UB_BUILD_ENABLED +#include +#include +#include + +#include "ub_common.h" +#include "ub_mr_fixed_buf.h" +#include "under_api/urma/urma_api_wrapper.h" +#include "under_api/obmm/obmm_api_wrapper.h" + +namespace ock { +namespace hcom { +class TestUbMemRegion : public testing::Test { +public: + TestUbMemRegion(); + virtual void SetUp(void); + virtual void TearDown(void); + std::string mName = "TestUbMemRegion"; + UBMemoryRegion *MemRegion = nullptr; + UBContext *ctx = nullptr; + UBEId eid{}; + urma_context_t mUrmaContext{}; + urma_target_seg_t mMemSeg{}; + char mem[NN_NO8]{}; +}; + +TestUbMemRegion::TestUbMemRegion() {} + +void TestUbMemRegion::SetUp() +{ + ctx = new (std::nothrow) UBContext("ubTest", eid); + ASSERT_NE(ctx, nullptr); + ctx->mUrmaContext = &mUrmaContext; + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + MemRegion = new (std::nothrow) UBMemoryRegion(mName, ctx, reinterpret_cast(mem), NN_NO8); + ASSERT_NE(MemRegion, nullptr); + MemRegion->mMemSeg = &mMemSeg; + mMemSeg.seg.ubva.va = 0; + MOCKER_CPP(HcomUrma::Uninit).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBDeviceHelper::UnInitialize).stubs().will(ignoreReturnValue()); +} + +void TestUbMemRegion::TearDown() +{ + ctx->mUrmaContext = nullptr; + MemRegion->mMemSeg = nullptr; + + if (ctx != nullptr) { + delete ctx; + ctx = nullptr; + } + + if (MemRegion != nullptr) { + delete MemRegion; + MemRegion = nullptr; + } + GlobalMockObject::verify(); +} + +TEST_F(TestUbMemRegion, GetProtocol) +{ + EXPECT_EQ(MemRegion->GetProtocol(), UBSHcomNetDriverProtocol::UBC); +} + +TEST_F(TestUbMemRegion, GetVa) +{ + uint64_t va; + uint64_t vaLen; + uint32_t token_id; + + EXPECT_NO_FATAL_FAILURE(MemRegion->GetVa(va, vaLen, token_id)); + EXPECT_EQ(va, 0); +} + +TEST_F(TestUbMemRegion, Create) +{ + UBMemoryRegion *tmpBuf = nullptr; + UBMemoryRegion::Create(mName, ctx, NN_NO8, tmpBuf); + EXPECT_NE(tmpBuf, nullptr); + tmpBuf->mUBContext = nullptr; + if (tmpBuf != nullptr) { + delete tmpBuf; + tmpBuf = nullptr; + } +} + +TEST_F(TestUbMemRegion, CreateExtMem) +{ + UBMemoryRegion *tmpBuf = nullptr; + UBMemoryRegion::Create(mName, ctx, reinterpret_cast(mem), NN_NO8, tmpBuf); + EXPECT_NE(tmpBuf, nullptr); + tmpBuf->mUBContext = nullptr; + if (tmpBuf != nullptr) { + delete tmpBuf; + tmpBuf = nullptr; + } +} + +TEST_F(TestUbMemRegion, InitializeParamErr) +{ + EXPECT_EQ(MemRegion->Initialize(), UB_OK); + MemRegion->mMemSeg = nullptr; + MemRegion->mUBContext = nullptr; + EXPECT_EQ(MemRegion->Initialize(), UB_PARAM_INVALID); +} + +TEST_F(TestUbMemRegion, InitializeExtMem) +{ + urma_target_seg_t tmpMr{}; + urma_target_seg_t *tmpPtr = nullptr; + + MemRegion->mMemSeg = nullptr; + MemRegion->mExternalMemory = true; + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + MOCKER(HcomUrma::RegisterSeg).stubs().will(returnValue(tmpPtr)).then(returnValue(&tmpMr)); + EXPECT_EQ(MemRegion->Initialize(), UB_MR_REG_FAILED); + EXPECT_EQ(MemRegion->Initialize(), UB_OK); +} + +TEST_F(TestUbMemRegion, InitializeFail) +{ + void *tmpPtr = nullptr; + urma_target_seg_t *tmpMr = nullptr; + + MemRegion->mMemSeg = nullptr; + MemRegion->mExternalMemory = false; + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + MOCKER(HcomUrma::RegisterSeg).stubs().will(returnValue(tmpMr)); + MOCKER(memalign).stubs().will(returnValue(tmpPtr)); + EXPECT_EQ(MemRegion->Initialize(), UB_MEMORY_ALLOCATE_FAILED); +} + +TEST_F(TestUbMemRegion, InitializeFailTwo) +{ + urma_target_seg_t *tmpMr = nullptr; + + MemRegion->mMemSeg = nullptr; + MemRegion->mExternalMemory = false; + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + MOCKER(HcomUrma::RegisterSeg).stubs().will(returnValue(tmpMr)); + EXPECT_EQ(MemRegion->Initialize(), UB_MR_REG_FAILED); +} + +TEST_F(TestUbMemRegion, InitializeForOneSideParamErr) +{ + EXPECT_EQ(MemRegion->InitializeForOneSide(), UB_OK); + MemRegion->mMemSeg = nullptr; + MemRegion->mUBContext = nullptr; + EXPECT_EQ(MemRegion->InitializeForOneSide(), UB_PARAM_INVALID); +} + +TEST_F(TestUbMemRegion, InitializeForOneSideExtMem) +{ + urma_target_seg_t tmpMr{}; + urma_target_seg_t *tmpPtr = nullptr; + + MemRegion->mMemSeg = nullptr; + MemRegion->mExternalMemory = true; + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + MOCKER(HcomUrma::RegisterSeg).stubs().will(returnValue(tmpPtr)).then(returnValue(&tmpMr)); + EXPECT_EQ(MemRegion->InitializeForOneSide(), UB_MR_REG_FAILED); + EXPECT_EQ(MemRegion->InitializeForOneSide(), UB_OK); +} + +TEST_F(TestUbMemRegion, InitializeForOneSideFail) +{ + void *tmpPtr = nullptr; + urma_target_seg_t *tmpMr = nullptr; + + MemRegion->mMemSeg = nullptr; + MemRegion->mExternalMemory = false; + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + MOCKER(HcomUrma::RegisterSeg).stubs().will(returnValue(tmpMr)); + MOCKER(memalign).stubs().will(returnValue(tmpPtr)); + EXPECT_EQ(MemRegion->InitializeForOneSide(), UB_MEMORY_ALLOCATE_FAILED); +} + +TEST_F(TestUbMemRegion, InitializeForOneSideFailTwo) +{ + urma_target_seg_t *tmpMr = nullptr; + + MemRegion->mMemSeg = nullptr; + MemRegion->mExternalMemory = false; + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + MOCKER(HcomUrma::RegisterSeg).stubs().will(returnValue(tmpMr)); + EXPECT_EQ(MemRegion->InitializeForOneSide(), UB_MR_REG_FAILED); +} + +TEST_F(TestUbMemRegion, InitializeWithPAParamErr) +{ + MOCKER(HcomObmm::ObmmOpen).stubs().will(returnValue(-1)).then(returnValue(1)); + MOCKER(mmap).stubs().will(returnValue(MAP_FAILED)); + MOCKER_CPP(close).stubs().will(returnValue(0)); + EXPECT_EQ(MemRegion->InitializeWithPA(1), UB_MEMORY_ALLOCATE_FAILED); + EXPECT_EQ(MemRegion->InitializeWithPA(1), UB_MEMORY_ALLOCATE_FAILED); +} + +TEST_F(TestUbMemRegion, InitializeWithPA) +{ + int tmp; + + MOCKER(HcomObmm::ObmmOpen).stubs().will(returnValue(1)); + MOCKER(mmap).stubs().will(returnValue(reinterpret_cast(&tmp))); + EXPECT_EQ(MemRegion->InitializeWithPA(1), UB_OK); +} +} +} +#endif \ No newline at end of file diff --git a/test/unit_test/transport/ub/test_ub_public_jetty.cpp b/test/unit_test/transport/ub/test_ub_public_jetty.cpp new file mode 100644 index 0000000000000000000000000000000000000000..16ebb86393ab16f16cd28f9b3a834e6875032e10 --- /dev/null +++ b/test/unit_test/transport/ub/test_ub_public_jetty.cpp @@ -0,0 +1,486 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED +#include +#include +#include +#include + +#include "ub_urma_wrapper_public_jetty.h" +#include "ub_mr_fixed_buf.h" +#include "ub_fixed_mem_pool.h" +namespace ock { +namespace hcom { + +class TestUBPublicJetty : public testing::Test { +public: + virtual void SetUp(void); + virtual void TearDown(void); + + std::string name = "test-public-jetty"; + UBEId eid{}; + UBPublicJetty *jetty = nullptr; + UBContext *ctx = nullptr; + UBJfc *jfc = nullptr; +}; + +void TestUBPublicJetty::SetUp() +{ + ctx = new (std::nothrow) UBContext(name, eid); + ASSERT_NE(ctx, nullptr); + jfc = new (std::nothrow) UBJfc(name, ctx); + ASSERT_NE(jfc, nullptr); + jetty = new (std::nothrow) UBPublicJetty(name, 0, ctx, jfc); + ASSERT_NE(jetty, nullptr); + MOCKER_CPP(HcomUrma::Uninit).stubs().will(returnValue(0)); +} + +void TestUBPublicJetty::TearDown() +{ + if (jetty != nullptr) { + delete jetty; + jetty = nullptr; + } + GlobalMockObject::verify(); +} + +TEST_F(TestUBPublicJetty, ImportPublicJetty) +{ + urma_target_jetty_t targetJetty{}; + urma_eid_t remoteEid{}; + urma_jetty_t tmpJetty{}; + urma_target_jetty_t *out = nullptr; + jetty->mUrmaJetty = &tmpJetty; + MOCKER_CPP(HcomUrma::ImportJetty).stubs().will(returnValue(out)).then(returnValue(&targetJetty)); + EXPECT_EQ(jetty->ImportPublicJetty(remoteEid, 0), UB_QP_IMPORT_FAILED); + EXPECT_EQ(jetty->ImportPublicJetty(remoteEid, 0), UB_OK); + jetty->mUrmaJetty = nullptr; + jetty->mTargetJetty = nullptr; +} + +TEST_F(TestUBPublicJetty, FillJfsCfg) +{ + urma_jfs_cfg_t jfs_cfg{}; + EXPECT_NO_FATAL_FAILURE(jetty->FillJfsCfg(&jfs_cfg)); +} + +TEST_F(TestUBPublicJetty, FillJfrCfg) +{ + urma_jfr_cfg_t jfr_cfg{}; + EXPECT_NO_FATAL_FAILURE(jetty->FillJfrCfg(&jfr_cfg)); +} + +TEST_F(TestUBPublicJetty, CreateUrmaPublicJetty) +{ + urma_jfc_t mUrmaJfc{}; + urma_context_t urmaContext{}; + EXPECT_EQ(jetty->CreateUrmaPublicJetty(0), UB_PARAM_INVALID); + + jfc->mUrmaJfc = &mUrmaJfc; + ctx->mUrmaContext = &urmaContext; + urma_jfr_t *outJfr = nullptr; + urma_jfr_t outJfr2{}; + MOCKER_CPP(HcomUrma::CreateJfr).stubs().will(returnValue(outJfr)).then(returnValue(&outJfr2)); + urma_jetty_t *outJetty = nullptr; + urma_jetty_t outJetty2{}; + urma_status_t res = 0; + MOCKER_CPP(HcomUrma::CreateJetty).stubs().will(returnValue(outJetty)).then(returnValue(&outJetty2)); + MOCKER_CPP(HcomUrma::DeleteJfr).stubs().will(returnValue(res)); + EXPECT_EQ(jetty->CreateUrmaPublicJetty(0), UB_PARAM_INVALID); + EXPECT_EQ(jetty->CreateUrmaPublicJetty(0), UB_QP_CREATE_FAILED); + EXPECT_EQ(jetty->CreateUrmaPublicJetty(0), UB_OK); + + jfc->mUrmaJfc = nullptr; + ctx->mUrmaContext = nullptr; + jetty->mUrmaJetty = nullptr; +} + +TEST_F(TestUBPublicJetty, CreateJettyMr) +{ + urma_target_seg_t tmpMr{}; + MOCKER(HcomUrma::RegisterSeg).stubs().will(returnValue(&tmpMr)); + MOCKER(HcomUrma::UnregisterSeg).stubs().will(returnValue(0)); + EXPECT_EQ(jetty->CreateJettyMr(), UB_OK); +} + +TEST_F(TestUBPublicJetty, InitializePublicJetty) +{ + MOCKER_CPP(&UBPublicJetty::CreateJettyMr).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::CreateCtxInfoPool).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::CreateUrmaPublicJetty).stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(jetty->InitializePublicJetty(0), 1); + EXPECT_EQ(jetty->InitializePublicJetty(0), 1); + EXPECT_EQ(jetty->InitializePublicJetty(0), 1); + EXPECT_EQ(jetty->InitializePublicJetty(0), 0); +} + +TEST_F(TestUBPublicJetty, StartPublicJettyFail) +{ + urma_target_seg_t tmpMr{}; + MOCKER(HcomUrma::RegisterSeg).stubs().will(returnValue(&tmpMr)); + MOCKER(HcomUrma::UnregisterSeg).stubs().will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::CreateUrmaPublicJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBufferN).stubs().will(returnValue(false)); + jetty->InitializePublicJetty(0); + EXPECT_EQ(jetty->StartPublicJetty(), UB_MEMORY_ALLOCATE_FAILED); +} + +TEST_F(TestUBPublicJetty, StartPublicJetty) +{ + urma_target_seg_t tmpMr{}; + MOCKER(HcomUrma::RegisterSeg).stubs().will(returnValue(&tmpMr)); + MOCKER(HcomUrma::UnregisterSeg).stubs().will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::CreateUrmaPublicJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::PostReceive).stubs().will(returnValue(1)).then(returnValue(0)); + jetty->InitializePublicJetty(0); + EXPECT_EQ(jetty->StartPublicJetty(), UB_ERROR); + EXPECT_EQ(jetty->StartPublicJetty(), UB_OK); +} + +NResult MockNewRequestHandler(UBOpContextInfo *ctx) +{ + return NN_OK; +} + +TEST_F(TestUBPublicJetty, NewRequest) +{ + UBOpContextInfo ctx{}; + JettyConnHeader header{}; + urma_target_seg_t *tseg = nullptr; + + MOCKER_CPP(&UBPublicJetty::GetMemorySeg).stubs().will(returnValue(tseg)); + MOCKER_CPP(&UBFixedMemPool::ReturnBuffer).stubs().will(returnValue(true)); + MOCKER_CPP(&UBMemoryRegionFixedBuffer::ReturnBuffer).stubs().will(returnValue(true)); + MOCKER_CPP(&UBPublicJetty::PostReceive).stubs().will(returnValue(1)).then(returnValue(0)); + ctx.mrMemAddr = reinterpret_cast(&header); + jetty->SetNewConnCB(MockNewRequestHandler); + EXPECT_EQ(jetty->NewRequest(nullptr), NN_ERROR); + + header.msgType = UrmaConnectMsgType::CONNECT_REQ; + EXPECT_EQ(jetty->NewRequest(&ctx), UB_QP_POST_RECEIVE_FAILED); + + header.msgType = UrmaConnectMsgType::EXCHANGE_MSG; + EXPECT_EQ(jetty->NewRequest(&ctx), UB_OK); +} + +TEST_F(TestUBPublicJetty, SendFinished) +{ + UBOpContextInfo ctx{}; + + MOCKER_CPP(&UBPublicJetty::ReturnBuffer).stubs().will(returnValue(false)).then(returnValue(true)); + MOCKER_CPP(&UBFixedMemPool::ReturnBuffer).stubs().will(returnValue(true)); + + EXPECT_EQ(jetty->SendFinished(&ctx), UB_OK); + EXPECT_EQ(jetty->SendFinished(&ctx), UB_OK); +} + +TEST_F(TestUBPublicJetty, RunInThread) +{ + jetty->mNeedStop = true; + EXPECT_NO_FATAL_FAILURE(jetty->RunInThread()); +} + +TEST_F(TestUBPublicJetty, SendByPublicJettyFail) +{ + urma_jetty_t tmpJetty{}; + uint8_t data; + urma_target_jetty_t targetJetty{}; + jetty->mTargetJetty = &targetJetty; + urma_target_seg_t tmpMr{}; + MOCKER(HcomUrma::RegisterSeg).stubs().will(returnValue(&tmpMr)); + MOCKER(HcomUrma::UnregisterSeg).stubs().will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::CreateUrmaPublicJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer).stubs().will(returnValue(false)); + EXPECT_EQ(jetty->SendByPublicJetty(&data, 1), UB_QP_NOT_INITIALIZED); + + jetty->mUrmaJetty = &tmpJetty; + jetty->InitializePublicJetty(0); + EXPECT_EQ(jetty->SendByPublicJetty(&data, 1), UB_MEMORY_ALLOCATE_FAILED); + jetty->mUrmaJetty = nullptr; + jetty->mTargetJetty = nullptr; +} + +TEST_F(TestUBPublicJetty, SendByPublicJetty) +{ + urma_jetty_t tmpJetty{}; + urma_target_jetty_t targetJetty{}; + jetty->mTargetJetty = &targetJetty; + uint8_t data; + urma_target_seg_t *tseg = nullptr; + urma_target_seg_t tmpMr{}; + MOCKER(HcomUrma::RegisterSeg).stubs().will(returnValue(&tmpMr)); + MOCKER(HcomUrma::UnregisterSeg).stubs().will(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::GetMemorySeg).stubs().will(returnValue(tseg)); + MOCKER_CPP(&UBPublicJetty::CreateUrmaPublicJetty).stubs().will(returnValue(0)); + MOCKER(HcomUrma::UnimportJetty).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::PostJettySendWr, urma_status_t(urma_jetty_t *, urma_jfs_wr_t *, urma_jfs_wr_t **)) + .stubs().will(returnValue(1)).then(returnValue(0)); + jetty->mUrmaJetty = &tmpJetty; + jetty->InitializePublicJetty(0); + + EXPECT_EQ(jetty->SendByPublicJetty(&data, 1), UB_QP_POST_SEND_FAILED); + EXPECT_EQ(jetty->SendByPublicJetty(&data, 1), UB_OK); + jetty->mUrmaJetty = nullptr; +} + +TEST_F(TestUBPublicJetty, PostReceive) +{ + urma_target_seg_t localSeg{}; + urma_jetty_t tmpJetty{}; + uint8_t data; + urma_target_seg_t *tseg = nullptr; + + MOCKER_CPP(&UBPublicJetty::GetMemorySeg).stubs().will(returnValue(tseg)); + MOCKER_CPP(HcomUrma::PostJettyRecvWr).stubs().will(returnValue(1)).then(returnValue(0)); + + jetty->mUrmaJetty = &tmpJetty; + EXPECT_EQ(jetty->PostReceive(reinterpret_cast(&data), 0, &localSeg, 0), NN_INVALID_PARAM); + EXPECT_EQ(jetty->PostReceive(reinterpret_cast(&data), 1, &localSeg, 0), UB_QP_POST_RECEIVE_FAILED); + EXPECT_EQ(jetty->PostReceive(reinterpret_cast(&data), 1, &localSeg, 0), UB_OK); + jetty->mUrmaJetty = nullptr; +} + +UBOpContextInfo tmpCtxInfo{}; +UResult MockEventProgressV(urma_cr_t *cr, uint32_t &countInOut, int32_t timeoutInMs = NN_NO500) +{ + switch (timeoutInMs) { + case NN_NO1000: + countInOut = 0; + break; + case NN_NO2000: + countInOut = 1; + cr->status = URMA_CR_LOC_LEN_ERR; + break; + case NN_NO10000: + countInOut = 1; + cr->status = URMA_CR_SUCCESS; + cr->user_ctx = 0; + break; + case (-1): + countInOut = 1; + cr->status = URMA_CR_SUCCESS; + cr->user_ctx = reinterpret_cast(&tmpCtxInfo); + break; + default: + break; + } + return UB_OK; +} + +TEST_F(TestUBPublicJetty, PollingCompletion) +{ + urma_jfc_t mUrmaJfc{}; + urma_jetty_t mUrmaJetty{}; + uint32_t pollCount = 0; + + MOCKER_CPP(&UBJfc::ProgressV).stubs().with(any(), outBound(pollCount)) + .will(returnValue(1)); + MOCKER_CPP(&UBPublicJetty::ReturnBuffer).stubs().will(returnValue(false)); + MOCKER_CPP(&UBFixedMemPool::ReturnBuffer).stubs().will(returnValue(true)); + EXPECT_EQ(jetty->PollingCompletion(), UB_EP_NOT_INITIALIZED); + + jetty->mUrmaJetty = &mUrmaJetty; + jfc->mUrmaJfc = &mUrmaJfc; + EXPECT_EQ(jetty->PollingCompletion(), UB_CQ_POLLING_FAILED); + jfc->mUrmaJfc = nullptr; + jetty->mUrmaJetty = nullptr; +} + +TEST_F(TestUBPublicJetty, ReceiveFail) +{ + uint8_t buf; + urma_jfc_t mUrmaJfc{}; + urma_jetty_t mUrmaJetty{}; + uint32_t pollCount = 0; + MOCKER_CPP(&UBJfc::ProgressV).stubs().with(any(), outBound(pollCount)) + .will(returnValue(1)); + + jetty->mUrmaJetty = &mUrmaJetty; + jfc->mUrmaJfc = &mUrmaJfc; + EXPECT_EQ(jetty->Receive(nullptr, 1), UB_PARAM_INVALID); + EXPECT_EQ(jetty->Receive(&buf, 1), UB_ERROR); + jfc->mUrmaJfc = nullptr; + jetty->mUrmaJetty = nullptr; +} + +UResult MockEventProgressV2(urma_cr_t *cr, uint32_t &countInOut, int32_t timeoutInMs = NN_NO500) +{ + countInOut = 1; + cr->completion_len = 1; + switch (timeoutInMs) { + case NN_NO1000: + cr->user_ctx = 0; + break; + case NN_NO2000: + cr->user_ctx = reinterpret_cast(&tmpCtxInfo); + tmpCtxInfo.mrMemAddr = reinterpret_cast(&tmpCtxInfo); + break; + default: + break; + } + return UB_OK; +} + +UResult MockProgressV(urma_cr_t *cr, uint32_t &countInOut) +{ + countInOut = 1; + cr->completion_len = 1; + cr->user_ctx = reinterpret_cast(&tmpCtxInfo); + tmpCtxInfo.mrMemAddr = reinterpret_cast(&tmpCtxInfo); + return UB_OK; +} + +TEST_F(TestUBPublicJetty, Receive) +{ + uint8_t buf; + urma_target_seg_t *tseg = nullptr; + + MOCKER_CPP(&UBJfc::ProgressV).stubs().will(invoke(MockProgressV)); + MOCKER_CPP(&UBPublicJetty::GetMemorySeg).stubs().will(returnValue(tseg)); + MOCKER_CPP(&UBPublicJetty::CheckRecvResult).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBPublicJetty::PostReceive).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBFixedMemPool::ReturnBuffer).stubs().will(returnValue(true)); + MOCKER_CPP(&UBMemoryRegionFixedBuffer::ReturnBuffer).stubs().will(returnValue(true)); + + EXPECT_EQ(jetty->Receive(&buf, 1), UB_ERROR); + EXPECT_EQ(jetty->Receive(&buf, 1), UB_QP_POST_RECEIVE_FAILED); + EXPECT_EQ(jetty->Receive(&buf, 1), UB_OK); +} + +TEST_F(TestUBPublicJetty, FillBondingMsg) +{ + JettyConnHeader exchangeInfo{}; + urma_jetty_t tmpJetty{}; + jetty->mUrmaJetty = &tmpJetty; + MOCKER_CPP(HcomUrma::UserCtl).stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(jetty->FillBondingMsg(&(exchangeInfo.clientCtrlBondInfo)), UB_ERROR); + EXPECT_EQ(jetty->FillBondingMsg(&(exchangeInfo.clientCtrlBondInfo)), UB_OK); + jetty->mUrmaJetty = nullptr; +} + +TEST_F(TestUBPublicJetty, SetBondingInfo) +{ + JettyConnHeader exchangeInfo{}; + urma_jetty_t tmpJetty{}; + jetty->mUrmaJetty = &tmpJetty; + MOCKER_CPP(HcomUrma::UserCtl).stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(jetty->SetBondingInfo(&(exchangeInfo.clientCtrlBondInfo)), UB_ERROR); + EXPECT_EQ(jetty->SetBondingInfo(&(exchangeInfo.clientCtrlBondInfo)), UB_OK); + jetty->mUrmaJetty = nullptr; +} + +TEST_F(TestUBPublicJetty, Stop) +{ + urma_jetty_t tmpJetty{}; + urma_target_jetty_t tmpTargetJetty{}; + jetty->mUrmaJetty = &tmpJetty; + jetty->mTargetJetty = &tmpTargetJetty; + MOCKER(HcomUrma::ModifyJetty).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER(HcomUrma::ModifyJfr).stubs().will(returnValue(0)); + MOCKER(HcomUrma::UnimportJetty).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER(HcomUrma::DeleteJetty).stubs().will(returnValue(1)).then(returnValue(0)); + + EXPECT_NO_FATAL_FAILURE(jetty->Stop()); + EXPECT_NO_FATAL_FAILURE(jetty->Stop()); + EXPECT_NO_FATAL_FAILURE(jetty->Stop()); +} + +TEST_F(TestUBPublicJetty, UBFixedMemPool) +{ + uintptr_t buf = 0; + UBFixedMemPool *mempool = new (std::nothrow) UBFixedMemPool(NN_NO128, NN_NO64); + EXPECT_EQ(mempool->GetFreeBuffer(buf), false); + EXPECT_EQ(mempool->MakeFreeList(), UB_PARAM_INVALID); + mempool->Initialize(); + EXPECT_EQ(mempool->ReturnBuffer(0), false); + delete mempool; +} + +TEST_F(TestUBPublicJetty, ProcessWorkerCompletion) +{ + UBOpContextInfo ctx{}; + MOCKER_CPP(&UBPublicJetty::NewRequest).stubs().will(returnValue(0)); + ctx.opType = UBOpContextInfo::OpType::RECEIVE; + EXPECT_NO_FATAL_FAILURE(jetty->ProcessWorkerCompletion(&ctx)); + + ctx.opType = UBOpContextInfo::OpType::SEND_RAW; + EXPECT_NO_FATAL_FAILURE(jetty->ProcessWorkerCompletion(&ctx)); +} + +TEST_F(TestUBPublicJetty, CheckRecvResult) +{ + urma_cr_t wc{}; + uint32_t size = 0; + UResult result = UB_OK; + uint32_t pollCount = 0; + int32_t timeoutInMs = 1; + + urma_jfc_t mUrmaJfc{}; + urma_jetty_t mUrmaJetty{}; + jetty->mUrmaJetty = &mUrmaJetty; + jfc->mUrmaJfc = &mUrmaJfc; + + EXPECT_EQ(jetty->CheckRecvResult(wc, size, result, pollCount, timeoutInMs), UB_CQ_POLLING_FAILED); + pollCount = 1; + result = UB_ERROR; + EXPECT_EQ(jetty->CheckRecvResult(wc, size, result, pollCount, timeoutInMs), UB_ERROR); + result = UB_OK; + wc.status = URMA_CR_LOC_LEN_ERR; + EXPECT_EQ(jetty->CheckRecvResult(wc, size, result, pollCount, timeoutInMs), UB_CQ_WC_WRONG); + wc.status = URMA_CR_SUCCESS; + wc.completion_len = 1; + EXPECT_EQ(jetty->CheckRecvResult(wc, size, result, pollCount, timeoutInMs), UB_CQ_WC_WRONG); + size = 1; + EXPECT_EQ(jetty->CheckRecvResult(wc, size, result, pollCount, timeoutInMs), UB_OK); + + jfc->mUrmaJfc = nullptr; + jetty->mUrmaJetty = nullptr; +} + +TEST_F(TestUBPublicJetty, ProcessPollingResult) +{ + urma_cr_t wc{}; + + MOCKER_CPP(&UBPublicJetty::NewRequest).stubs().will(returnValue(0)); + EXPECT_NO_FATAL_FAILURE(jetty->ProcessPollingResult(wc)); + + UBOpContextInfo ctxInfo{}; + wc.user_ctx = reinterpret_cast(&ctxInfo); + EXPECT_NO_FATAL_FAILURE(jetty->ProcessPollingResult(wc)); +} + +TEST_F(TestUBPublicJetty, UBThreadPool) +{ + UBThreadPool *threadPool = new (std::nothrow) UBThreadPool(NN_NO3); + ASSERT_NE(threadPool, nullptr); + + threadPool->Stop(); + threadPool->Initialize(); + threadPool->Submit([]() {NN_LOG_INFO("Run a test task");}); + threadPool->Submit([]() { + NN_LOG_INFO("Run a std error task"); + throw std::runtime_error("Run a std error task"); + }); + threadPool->Submit([]() { + NN_LOG_INFO("Run a unknown error task"); + throw NN_NO58; + }); + sleep(NN_NO1); + EXPECT_NO_FATAL_FAILURE(threadPool->Stop()); + if (threadPool != nullptr) { + delete threadPool; + threadPool = nullptr; + } +} +} +} +#endif diff --git a/test/unit_test/transport/ub/test_ub_urma_jetty.cpp b/test/unit_test/transport/ub/test_ub_urma_jetty.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0baa879c3c773d58f20363f7707394074b418536 --- /dev/null +++ b/test/unit_test/transport/ub/test_ub_urma_jetty.cpp @@ -0,0 +1,800 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifdef UB_BUILD_ENABLED + +#include +#include +#include + +#include "net_monotonic.h" +#include "ub_common.h" +#include "ub_mr_fixed_buf.h" +#include "ub_urma_wrapper_jetty.h" +#include "under_api/urma/urma_api_wrapper.h" +#include "under_api/obmm/obmm_api_wrapper.h" + +namespace ock { +namespace hcom { +class TestUbUrmaJetty : public testing::Test { +public: + TestUbUrmaJetty(); + virtual void SetUp(void); + virtual void TearDown(void); + std::string mName = "TestUbUrmaJetty"; + UBJetty *jetty = nullptr; + UBContext *ctx = nullptr; + UBJfc *jfc = nullptr; + urma_jfc_t mUrmaJfc{}; + urma_jetty_t UrmaJetty{}; + UBEId eid{}; + urma_context_t mUrmaContext{}; + UBMemoryRegionFixedBuffer *mJettyMr = nullptr; +}; + +TestUbUrmaJetty::TestUbUrmaJetty() {} + +void TestUbUrmaJetty::SetUp() +{ + ctx = new (std::nothrow) UBContext("ubTest", eid); + ASSERT_NE(ctx, nullptr); + ctx->mUrmaContext = &mUrmaContext; + + mJettyMr = new (std::nothrow) UBMemoryRegionFixedBuffer(mName, ctx, 1, 1, 1); + ASSERT_NE(mJettyMr, nullptr); + + jfc = new (std::nothrow) UBJfc(mName, ctx, false, 0); + ASSERT_NE(jfc, nullptr); + jfc->mUrmaJfc = &mUrmaJfc; + jetty = new (std::nothrow) UBJetty(mName, 0, ctx, jfc); + ASSERT_NE(jetty, nullptr); + jetty->mJettyOptions.ubcMode = UBSHcomUbcMode::Disabled; + jetty->mUrmaJetty = &UrmaJetty; + jetty->mJettyMr = mJettyMr; + jetty->StoreExchangeInfo(new UBJettyExchangeInfo); + MOCKER_CPP(HcomUrma::Uninit).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBDeviceHelper::UnInitialize).stubs().will(ignoreReturnValue()); +} + +void TestUbUrmaJetty::TearDown() +{ + ctx->mUrmaContext = nullptr; + jetty->mSendJfc = nullptr; + jetty->mRecvJfc = nullptr; + jetty->mUBContext = nullptr; + jetty->mUrmaJetty = nullptr; + jetty->mJettyMr = nullptr; + jfc->mUBContext = nullptr; + jfc->mUrmaJfc = nullptr; + + if (mJettyMr != nullptr) { + delete mJettyMr; + mJettyMr = nullptr; + } + + if (ctx != nullptr) { + delete ctx; + ctx = nullptr; + } + + if (jfc != nullptr) { + delete jfc; + jfc = nullptr; + } + + if (jetty != nullptr) { + delete jetty; + jetty = nullptr; + } + GlobalMockObject::verify(); +} + +TEST_F(TestUbUrmaJetty, UnInitializeSuccess) +{ + int tmp; + mJettyMr->IncreaseRef(); + mJettyMr->IncreaseRef(); + ctx->IncreaseRef(); + jfc->IncreaseRef(); + urma_jfr_t mJfr{}; + jetty->mJfr = &mJfr; + + MOCKER_CPP(&Enabled).stubs().will(returnValue(false)); + MOCKER_CPP(&HcomUrma::UnbindJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&HcomUrma::UnimportJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&HcomUrma::DeleteJetty).stubs().will(returnValue(11)); + MOCKER_CPP(&HcomUrma::DeleteJfr).stubs().will(returnValue(11)); + MOCKER_CPP(&HcomUrma::DeleteJfc).stubs().will(returnValue(11)); + MOCKER_CPP(&UBContext::UnInitialize).stubs().will(returnValue(0)); + EXPECT_EQ(jetty->UnInitialize(), 0); +} + +TEST_F(TestUbUrmaJetty, UnInitializeHB) +{ + int tmp; + mJettyMr->IncreaseRef(); + mJettyMr->IncreaseRef(); + ctx->IncreaseRef(); + jfc->IncreaseRef(); + urma_jfr_t mJfr{}; + jetty->mJfr = &mJfr; + auto localMr = new (std::nothrow) UBMemoryRegion("localMr", nullptr, 0, 0, 0); + auto remoteMr = new (std::nothrow) UBMemoryRegion("remoteMr", nullptr, 0, 0, 0); + jetty->mHBLocalMr = localMr; + jetty->mHBRemoteMr = remoteMr; + + MOCKER_CPP(&Enabled).stubs().will(returnValue(false)); + MOCKER_CPP(&HcomUrma::UnbindJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&HcomUrma::UnimportJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&HcomUrma::DeleteJetty).stubs().will(returnValue(11)); + MOCKER_CPP(&HcomUrma::DeleteJfr).stubs().will(returnValue(11)); + MOCKER_CPP(&HcomUrma::DeleteJfc).stubs().will(returnValue(11)); + MOCKER_CPP(&UBContext::UnInitialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::DestroyHBMemoryRegion).stubs().will(ignoreReturnValue()); + EXPECT_EQ(jetty->UnInitialize(), 0); + localMr = nullptr; + remoteMr = nullptr; +} + +TEST_F(TestUbUrmaJetty, GetProtocol) +{ + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + EXPECT_EQ(jetty->GetProtocol(), UBSHcomNetDriverProtocol::UBC); +} + +TEST_F(TestUbUrmaJetty, FillJfsCfg) +{ + urma_jfs_cfg_t cfg{}; + + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + EXPECT_NO_THROW(jetty->FillJfsCfg(&cfg)); +} + +TEST_F(TestUbUrmaJetty, FillJfrCfg) +{ + urma_jfr_cfg_t cfg{}; + + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + EXPECT_NO_THROW(jetty->FillJfrCfg(&cfg)); +} + +TEST_F(TestUbUrmaJetty, PostReceiveParamErr) +{ + jetty->mUrmaJetty = nullptr; + EXPECT_EQ(jetty->PostReceive(0, 0, nullptr, 0), UB_QP_NOT_INITIALIZED); +} + +TEST_F(TestUbUrmaJetty, PostReceive) +{ + MOCKER(HcomUrma::PostJettyRecvWr).stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(jetty->PostReceive(0, 0, nullptr, 0), UB_QP_POST_RECEIVE_FAILED); + EXPECT_EQ(jetty->PostReceive(0, 0, nullptr, 0), UB_OK); +} + +TEST_F(TestUbUrmaJetty, PostSendSglInlineJettyNull) +{ + jetty->mUrmaJetty = nullptr; + EXPECT_EQ(jetty->PostSendSglInline(nullptr, 10, 10), UB_QP_NOT_INITIALIZED); +} + +TEST_F(TestUbUrmaJetty, PostSendSglInlineJettyFail) +{ + jetty->mUrmaJetty = &UrmaJetty; + UBSHcomNetTransDataIov iov[1]; + iov[0].address = 0X1234; + iov[0].key = 123; + iov[0].size = 10; + + MOCKER(HcomUrma::PostJettySendWr, urma_status_t(urma_jetty_t *, urma_jfs_wr_t *, urma_jfs_wr_t **)) + .stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(jetty->PostSendSglInline(iov, 1, 0), UB_QP_POST_SEND_FAILED); + EXPECT_EQ(jetty->PostSendSglInline(iov, 1, 0), UB_OK); +} + +TEST_F(TestUbUrmaJetty, PostSendSglParamErr) +{ + UBSHcomNetTransSgeIov iov{}; + uint32_t iovCount = 1; + + jetty->mUrmaJetty = nullptr; + EXPECT_EQ(jetty->PostSendSgl(&iov, iovCount, 0, 0), UB_QP_NOT_INITIALIZED); +} + +TEST_F(TestUbUrmaJetty, PostSendSgl) +{ + UBSHcomNetTransSgeIov iov{}; + uint32_t iovCount = 1; + MOCKER(HcomUrma::PostJettySendWr, urma_status_t(urma_jetty_t *, urma_jfs_wr_t *, urma_jfs_wr_t **)) + .stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(jetty->PostSendSgl(&iov, iovCount, 0, 0), UB_QP_POST_SEND_FAILED); + EXPECT_EQ(jetty->PostSendSgl(&iov, iovCount, 0, 0), UB_OK); + + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + EXPECT_EQ(jetty->PostSendSgl(&iov, iovCount, 0, 0), UB_OK); +} + +TEST_F(TestUbUrmaJetty, PostReadParamErr) +{ + jetty->mUrmaJetty = nullptr; + EXPECT_EQ(jetty->PostRead(0, nullptr, 0, nullptr, 0, 0), UB_QP_NOT_INITIALIZED); +} + +TEST_F(TestUbUrmaJetty, PostRead) +{ + MOCKER(HcomUrma::PostJettySendWr, urma_status_t(urma_jetty_t *, urma_jfs_wr_t *, urma_jfs_wr_t **)) + .stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(jetty->PostRead(0, nullptr, 0, nullptr, 0, 0), UB_QP_POST_WRITE_FAILED); + EXPECT_EQ(jetty->PostRead(0, nullptr, 0, nullptr, 0, 0), UB_OK); +} + +TEST_F(TestUbUrmaJetty, UBCPostReadParamErr) +{ + jetty->mUrmaJetty = nullptr; + EXPECT_EQ(jetty->PostRead(0, 1, 0, 0, 0, 0), UB_QP_NOT_INITIALIZED); +} + +TEST_F(TestUbUrmaJetty, UBCPostRead) +{ + urma_target_seg_t seg{}; + urma_target_seg_t *tmpSeg1 = nullptr; + urma_target_seg_t *tmpSeg2 = &seg; + MOCKER(HcomUrma::ImportSeg).stubs().will(returnValue(tmpSeg1)).then(returnValue(tmpSeg2)); + MOCKER(HcomUrma::PostJettySendWr, urma_status_t(urma_jetty_t *, urma_jfs_wr_t *, urma_jfs_wr_t **)) + .stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER(HcomUrma::UnimportSeg).stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(jetty->PostRead(0, 1, 0, 0, 0, 0), UB_QP_POST_READ_FAILED); + EXPECT_EQ(jetty->PostRead(0, 1, 0, 0, 0, 0), UB_QP_POST_READ_FAILED); + EXPECT_EQ(jetty->PostRead(0, 1, 0, 0, 0, 0), UB_OK); +} + +TEST_F(TestUbUrmaJetty, UBCPostReadTseg) +{ + urma_target_seg_t *tmpSeg1 = nullptr; + urma_target_seg_t seg{}; + urma_target_seg_t *tmpSeg2 = &seg; + MOCKER(HcomUrma::UnimportSeg).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER(HcomUrma::PostJettySendWr, urma_status_t(urma_jetty_t *, urma_jfs_wr_t *, urma_jfs_wr_t **)) + .stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER(HcomUrma::ImportSeg).stubs().will(returnValue(tmpSeg1)).then(returnValue(tmpSeg2)); + + EXPECT_EQ(jetty->PostRead(0, static_cast(nullptr), static_cast(0), + static_cast(0), static_cast(0), static_cast(0)), UB_QP_POST_READ_FAILED); + + EXPECT_EQ(jetty->PostRead(0, static_cast(nullptr), static_cast(0), + static_cast(0), static_cast(0), static_cast(0)), UB_QP_POST_READ_FAILED); + + EXPECT_EQ(jetty->PostRead(0, static_cast(nullptr), static_cast(0), + static_cast(0), static_cast(0), static_cast(0)), UB_OK); + + jetty->mUrmaJetty = nullptr; + EXPECT_EQ(jetty->PostRead(0, static_cast(nullptr), static_cast(0), + static_cast(0), static_cast(0), static_cast(0)), UB_QP_NOT_INITIALIZED); +} + +TEST_F(TestUbUrmaJetty, PostWriteParamErr) +{ + jetty->mUrmaJetty = nullptr; + EXPECT_EQ(jetty->PostWrite(0, nullptr, 0, nullptr, 0, 0), UB_QP_NOT_INITIALIZED); +} + +TEST_F(TestUbUrmaJetty, PostWrite) +{ + MOCKER(HcomUrma::PostJettySendWr, urma_status_t(urma_jetty_t *, urma_jfs_wr_t *, urma_jfs_wr_t **)) + .stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(jetty->PostWrite(0, nullptr, 0, nullptr, 0, 0), UB_QP_POST_WRITE_FAILED); + EXPECT_EQ(jetty->PostWrite(0, nullptr, 0, nullptr, 0, 0), UB_OK); +} + +TEST_F(TestUbUrmaJetty, UBCPostWriteParamErr) +{ + jetty->mUrmaJetty = nullptr; + EXPECT_EQ(jetty->PostWrite(0, 1, 0, 0, 0, 0), UB_QP_NOT_INITIALIZED); +} + +TEST_F(TestUbUrmaJetty, UBCPostWrite) +{ + urma_target_seg_t seg{}; + urma_target_seg_t *tmpSeg1 = nullptr; + urma_target_seg_t *tmpSeg2 = &seg; + MOCKER(HcomUrma::ImportSeg).stubs().will(returnValue(tmpSeg1)).then(returnValue(tmpSeg2)); + MOCKER(HcomUrma::PostJettySendWr, urma_status_t(urma_jetty_t *, urma_jfs_wr_t *, urma_jfs_wr_t **)) + .stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER(HcomUrma::UnimportSeg).stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(jetty->PostWrite(0, 1, 0, 0, 0, 0), UB_QP_POST_WRITE_FAILED); + EXPECT_EQ(jetty->PostWrite(0, 1, 0, 0, 0, 0), UB_QP_POST_WRITE_FAILED); + EXPECT_EQ(jetty->PostWrite(0, 1, 0, 0, 0, 0), UB_OK); +} + +TEST_F(TestUbUrmaJetty, UBCPostWriteTseg) +{ + urma_target_seg_t *tmpSeg1 = nullptr; + urma_target_seg_t seg{}; + urma_target_seg_t *tmpSeg2 = &seg; + MOCKER(HcomUrma::UnimportSeg).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER(HcomUrma::PostJettySendWr, urma_status_t(urma_jetty_t *, urma_jfs_wr_t *, urma_jfs_wr_t **)) + .stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER(HcomUrma::ImportSeg).stubs().will(returnValue(tmpSeg1)).then(returnValue(tmpSeg2)); + + EXPECT_EQ(jetty->PostWrite(0, static_cast(nullptr), static_cast(0), + static_cast(0), static_cast(0), static_cast(0)), UB_QP_POST_WRITE_FAILED); + + EXPECT_EQ(jetty->PostWrite(0, static_cast(nullptr), static_cast(0), + static_cast(0), static_cast(0), static_cast(0)), UB_QP_POST_WRITE_FAILED); + + EXPECT_EQ(jetty->PostWrite(0, static_cast(nullptr), static_cast(0), + static_cast(0), static_cast(0), static_cast(0)), UB_OK); + + jetty->mUrmaJetty = nullptr; + EXPECT_EQ(jetty->PostWrite(0, static_cast(nullptr), static_cast(0), + static_cast(0), static_cast(0), static_cast(0)), UB_QP_NOT_INITIALIZED); +} + +TEST_F(TestUbUrmaJetty, GetId) +{ + EXPECT_NO_FATAL_FAILURE(jetty->GetId()); +} + +TEST_F(TestUbUrmaJetty, SetAndGetUpId) +{ + EXPECT_NO_FATAL_FAILURE(jetty->SetUpId(1)); + EXPECT_NO_FATAL_FAILURE(jetty->GetUpId()); +} + +TEST_F(TestUbUrmaJetty, SetAndGetName) +{ + EXPECT_NO_FATAL_FAILURE(jetty->GetName()); + EXPECT_NO_FATAL_FAILURE(jetty->SetName(mName)); +} + +TEST_F(TestUbUrmaJetty, SetAndGetContext) +{ + EXPECT_NO_FATAL_FAILURE(jetty->SetUpContext(1)); + EXPECT_NO_FATAL_FAILURE(jetty->GetUpContext()); +} + +TEST_F(TestUbUrmaJetty, SetAndGetContextOne) +{ + EXPECT_NO_FATAL_FAILURE(jetty->SetUpContext1(1)); + EXPECT_NO_FATAL_FAILURE(jetty->GetUpContext1()); +} + +TEST_F(TestUbUrmaJetty, GetCtxPosted) +{ + UBOpContextInfo *remaining = nullptr; + EXPECT_NO_FATAL_FAILURE(jetty->GetCtxPosted(remaining)); +} + +TEST_F(TestUbUrmaJetty, GetPostedCount) +{ + jetty->mCtxPostedCount = 1; + EXPECT_EQ(jetty->GetPostedCount(), 1); +} + +TEST_F(TestUbUrmaJetty, GetPostSendWr) +{ + jetty->mPostSendRef = 0; + EXPECT_EQ(jetty->GetPostSendWr(), false); + jetty->mPostSendRef = NN_NO64; + EXPECT_EQ(jetty->GetPostSendWr(), true); +} + +TEST_F(TestUbUrmaJetty, GetOneSideWr) +{ + jetty->mOneSideRef = 0; + EXPECT_EQ(jetty->GetOneSideWr(), false); + jetty->mOneSideRef = NN_NO64; + EXPECT_EQ(jetty->GetOneSideWr(), true); +} + +TEST_F(TestUbUrmaJetty, NewId) +{ + EXPECT_NO_FATAL_FAILURE(UBJetty::NewId()); +} + +TEST_F(TestUbUrmaJetty, PostRegMrSize) +{ + EXPECT_EQ(jetty->PostRegMrSize(), NN_NO1024); +} + +TEST_F(TestUbUrmaJetty, StopParamErr) +{ + jetty->isStarted = false; + EXPECT_EQ(jetty->Stop(), UB_OK); +} + +TEST_F(TestUbUrmaJetty, Stop) +{ + jetty->isStarted = true; + MOCKER(HcomUrma::ModifyJetty).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER(HcomUrma::ModifyJfr).stubs().will(returnValue(0)); + EXPECT_EQ(jetty->Stop(), 1); + EXPECT_EQ(jetty->Stop(), UB_OK); +} + +TEST_F(TestUbUrmaJetty, StopTwo) +{ + jetty->isStarted = true; + + MOCKER(HcomUrma::ModifyJetty).stubs().will(returnValue(1)); + jetty->mUBContext->protocol = UBSHcomNetDriverProtocol::UBC; + EXPECT_EQ(jetty->Stop(), 1); +} + +TEST_F(TestUbUrmaJetty, StopModifyJfrFail) +{ + jetty->isStarted = true; + + urma_jfr_t tJfr; + jetty->mJfr = &tJfr; + + MOCKER(HcomUrma::ModifyJetty).stubs().will(returnValue(0)); + MOCKER(HcomUrma::ModifyJfr).stubs().will(returnValue(1)); + EXPECT_EQ(jetty->Stop(), 1); + + jetty->mJfr = nullptr; +} + +TEST_F(TestUbUrmaJetty, CreateUrmaJettyParamErr) +{ + jetty->mUBContext = nullptr; + EXPECT_EQ(jetty->CreateUrmaJetty(0, 0, 0), UB_PARAM_INVALID); +} + +TEST_F(TestUbUrmaJetty, CreateUrmaJettyUBC) +{ + urma_jfr_t jfr{}; + urma_jfr_t *tmpJfr = nullptr; + urma_jetty_t *tmpJetty = nullptr; + jetty->mUBContext->protocol = UBSHcomNetDriverProtocol::UBC; + MOCKER(HcomUrma::CreateJfr).stubs().will(returnValue(tmpJfr)).then(returnValue(&jfr)); + MOCKER(HcomUrma::CreateJetty).stubs().will(returnValue(tmpJetty)); + urma_status_t res = 10; + MOCKER(HcomUrma::DeleteJfr).stubs().will(returnValue(res)); + EXPECT_EQ(jetty->CreateUrmaJetty(0, 0, 0), UB_PARAM_INVALID); + EXPECT_EQ(jetty->CreateUrmaJetty(0, 0, 0), UB_QP_CREATE_FAILED); +} + +TEST_F(TestUbUrmaJetty, GetJettyMrAndKey) +{ + EXPECT_NO_FATAL_FAILURE(jetty->GetJettyMr()); + EXPECT_NO_FATAL_FAILURE(jetty->GetLKey()); +} + +TEST_F(TestUbUrmaJetty, InitializeParamErr) +{ + jetty->mUBContext->protocol = UBSHcomNetDriverProtocol::UBC; + + MOCKER_CPP(&UBJetty::CreateJettyMr).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBJetty::CreateUrmaJetty).stubs().will(returnValue(1)); + + EXPECT_EQ(jetty->Initialize(0, 0), 1); + EXPECT_EQ(jetty->Initialize(0, 0), 1); +} + +TEST_F(TestUbUrmaJetty, InitializeSuccess) +{ + int tmp; + + jetty->mUBContext->protocol = UBSHcomNetDriverProtocol::UBC; + + MOCKER_CPP(&UBJetty::CreateJettyMr).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::CreateUrmaJetty).stubs().will(returnValue(1)).then(returnValue(0)); + + EXPECT_EQ(jetty->Initialize(0, 0), 1); + EXPECT_EQ(jetty->Initialize(0, 0), 0); +} + +TEST_F(TestUbUrmaJetty, ChangeToInitAndReceive) +{ + urma_jetty_attr_t attr{}; + UBJettyExchangeInfo exInfo{}; + + EXPECT_NO_FATAL_FAILURE(jetty->ChangeToInit(attr)); + EXPECT_NO_FATAL_FAILURE(jetty->ChangeToReceive(exInfo, attr)); +} + +TEST_F(TestUbUrmaJetty, ChangeToSend) +{ + urma_jetty_attr_t attr{}; + EXPECT_NO_FATAL_FAILURE(jetty->ChangeToSend(attr)); +} + +TEST_F(TestUbUrmaJetty, ChangeToReadyParamErr) +{ + UBJettyExchangeInfo exInfo{}; + jetty->mUrmaJetty = nullptr; + EXPECT_EQ(jetty->ChangeToReady(exInfo), UB_QP_CHANGE_STATE_FAILED); + jetty->mUrmaJetty = &UrmaJetty; + MOCKER_CPP(&UBJetty::SetMaxSendWrConfig).stubs().will(returnValue(1)); + EXPECT_EQ(jetty->ChangeToReady(exInfo), 1); +} + +TEST_F(TestUbUrmaJetty, ChangeToReadyUbcUserCtlFailed) +{ + MOCKER_CPP(&UBJetty::SetMaxSendWrConfig).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::UserCtl).stubs().will(returnValue(static_cast(1))); + + UBJettyExchangeInfo exInfo{}; + + jetty->mUBContext->protocol = UBSHcomNetDriverProtocol::UBC; + EXPECT_EQ(jetty->ChangeToReady(exInfo), UB_ERROR); +} + +TEST_F(TestUbUrmaJetty, ChangeToReadyUbc) +{ + MOCKER_CPP(&UBJetty::SetMaxSendWrConfig).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::UserCtl).stubs().will(returnValue(static_cast(URMA_SUCCESS))); + MOCKER_CPP(&UBJetty::ImportAndBindJetty).stubs().will(returnValue(0)); + + UBJettyExchangeInfo exInfo{}; + EXPECT_EQ(jetty->ChangeToReady(exInfo), 0); +} + +TEST_F(TestUbUrmaJetty, SetMaxSendWrConfig) +{ + UBJettyExchangeInfo exInfo{}; + + exInfo.maxReceiveWr = 0; + EXPECT_EQ(jetty->SetMaxSendWrConfig(exInfo), UB_QP_RECEIVE_CONFIG_ERR); + + exInfo.maxReceiveWr = JETTY_MAX_RECV_WR; + EXPECT_EQ(jetty->SetMaxSendWrConfig(exInfo), 0); +} + +TEST_F(TestUbUrmaJetty, FillExchangeInfoUbcUserCtlFailed) +{ + MOCKER_CPP(HcomUrma::UserCtl).stubs().will(returnValue(static_cast(URMA_FAIL))); + + UBJettyExchangeInfo exInfo{}; + + jetty->mUBContext->protocol = UBSHcomNetDriverProtocol::UBC; + EXPECT_EQ(jetty->FillExchangeInfo(exInfo), UB_ERROR); +} + +TEST_F(TestUbUrmaJetty, FillExchangeInfoUbc) +{ + MOCKER_CPP(HcomUrma::UserCtl).stubs().will(returnValue(static_cast(URMA_SUCCESS))); + + UBJettyExchangeInfo exInfo{}; + + jetty->mJettyOptions.ubcMode = UBSHcomUbcMode::LowLatency; + EXPECT_EQ(jetty->FillExchangeInfo(exInfo), UB_OK); +} + +TEST_F(TestUbUrmaJetty, StoreExchangeInfo) +{ + EXPECT_NO_FATAL_FAILURE(jetty->StoreExchangeInfo(new UBJettyExchangeInfo)); +} + +TEST_F(TestUbUrmaJetty, ImportAndBindJettyErr) +{ + urma_target_jetty_t *tmpJetty = nullptr; + urma_target_jetty_t tmpJetty2{}; + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + MOCKER(HcomUrma::ImportJetty).stubs().will(returnValue(tmpJetty)).then(returnValue(&tmpJetty2)); + MOCKER(HcomUrma::BindJetty).stubs().will(returnValue(1)); + MOCKER(HcomUrma::UnimportJetty).stubs().will(returnValue(0)); + EXPECT_EQ(jetty->ImportAndBindJetty(), UB_QP_IMPORT_FAILED); + EXPECT_EQ(jetty->ImportAndBindJetty(), UB_OK); +} + +TEST_F(TestUbUrmaJetty, ImportAndBindJetty) +{ + urma_target_jetty_t tmpJetty2{}; + ctx->protocol = UBSHcomNetDriverProtocol::UBC; + MOCKER(HcomUrma::ImportJetty).stubs().will(returnValue(&tmpJetty2)); + MOCKER(HcomUrma::BindJetty).stubs().will(returnValue(0)); + + EXPECT_EQ(jetty->ImportAndBindJetty(), 0); +} + +TEST_F(TestUbUrmaJetty, CreatePollingCq) +{ + urma_jfc_t *tmpJfc = nullptr; + MOCKER(HcomUrma::CreateJfc).stubs().will(returnValue(tmpJfc)).then(returnValue(&mUrmaJfc)); + EXPECT_EQ(jfc->CreatePollingCq(), UB_NEW_OBJECT_FAILED); + EXPECT_EQ(jfc->CreatePollingCq(), UB_OK); +} + +TEST_F(TestUbUrmaJetty, CtxInitializeParamErr) +{ + urma_device_t **devList = nullptr; + EXPECT_EQ(ctx->Initialize(), UB_OK); + + MOCKER(HcomUrma::GetDeviceList).stubs().will(returnValue(devList)); + ctx->mUrmaContext = nullptr; + EXPECT_EQ(ctx->Initialize(), UB_DEVICE_FAILED_OPEN); +} + +TEST_F(TestUbUrmaJetty, CreateJettyMrErr) +{ + MOCKER(UBMemoryRegionFixedBuffer::Create).stubs().will(returnValue(1)); + EXPECT_EQ(jetty->CreateJettyMr(), 1); +} + +TEST_F(TestUbUrmaJetty, CreateJettyMr) +{ + UBMemoryRegionFixedBuffer *mJettyMR = new (std::nothrow) UBMemoryRegionFixedBuffer(mName, ctx, 0, 0, 0); + ASSERT_NE(mJettyMR, nullptr); + MOCKER(UBMemoryRegionFixedBuffer::Create) + .stubs() + .with(any(), any(), any(), any(), any(), outBound(mJettyMR)) + .will(returnValue(0)); + MOCKER_CPP_VIRTUAL(*mJettyMR, &UBMemoryRegionFixedBuffer::Initialize) + .stubs() + .will(returnValue(1)) + .then(returnValue(0)); + EXPECT_EQ(jetty->CreateJettyMr(), 1); + + EXPECT_EQ(jetty->CreateJettyMr(), 0); + jetty->mJettyMr->DecreaseRef(); + jetty->mJettyMr->DecreaseRef(); +} + +TEST_F(TestUbUrmaJetty, GetFreeBuff) +{ + uintptr_t item = 0; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBuffer, bool (UBMemoryRegionFixedBuffer::*)(uintptr_t &)) + .stubs() + .will(returnValue(false)); + EXPECT_EQ(jetty->GetFreeBuff(item), false); +} + +TEST_F(TestUbUrmaJetty, GetFreeBufferN) +{ + uintptr_t *items = nullptr; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::GetFreeBufferN).stubs().will(returnValue(false)); + EXPECT_EQ(jetty->GetFreeBufferN(items, 0), false); +} + +TEST_F(TestUbUrmaJetty, ReturnBuffer) +{ + uintptr_t value = 0; + MOCKER_CPP(&UBMemoryRegionFixedBuffer::ReturnBuffer).stubs().will(returnValue(false)); + EXPECT_EQ(jetty->ReturnBuffer(value), false); +} + +TEST_F(TestUbUrmaJetty, PostOneSideSglImportSegFail) +{ + UBSHcomNetTransSgeIov iov[NET_SGE_MAX_IOV]; + uint32_t iovCount = NET_SGE_MAX_IOV; + uint64_t ctx[NET_SGE_MAX_IOV]; + urma_target_seg_t seg{}; + urma_target_seg_t *tmpSeg = nullptr; + MOCKER(HcomUrma::ImportSeg).stubs().will(returnValue(tmpSeg)); + MOCKER(HcomUrma::PostJettySendWr, + urma_status_t(urma_jetty_t *, urma_jfs_wr_t *, uint32_t, urma_jfs_wr_t **)) + .stubs().will(returnValue(1)); + EXPECT_EQ(jetty->PostOneSideSgl(iov, iovCount, ctx, true, NET_SGE_MAX_IOV), UB_QP_POST_READ_FAILED); +} + +TEST_F(TestUbUrmaJetty, PostOneSideSglFail) +{ + UBSHcomNetTransSgeIov iov[NET_SGE_MAX_IOV]; + uint32_t iovCount = NET_SGE_MAX_IOV; + uint64_t ctx[NET_SGE_MAX_IOV]; + urma_target_seg_t seg{}; + urma_target_seg_t *tmpSeg = &seg; + MOCKER(HcomUrma::ImportSeg).stubs().will(returnValue(tmpSeg)); + MOCKER(HcomUrma::PostJettySendWr, + urma_status_t(urma_jetty_t *, urma_jfs_wr_t *, uint32_t, urma_jfs_wr_t **)) + .stubs().will(returnValue(1)); + MOCKER_CPP(HcomUrma::UnimportSeg).stubs().will(returnValue(0)); + EXPECT_EQ(jetty->PostOneSideSgl(iov, iovCount, ctx, true, NET_SGE_MAX_IOV), UB_QP_POST_READ_FAILED); + EXPECT_EQ(jetty->PostOneSideSgl(iov, iovCount, ctx, false, NET_SGE_MAX_IOV), UB_QP_POST_WRITE_FAILED); +} + +TEST_F(TestUbUrmaJetty, PostOneSideSgl) +{ + UBSHcomNetTransSgeIov iov[NET_SGE_MAX_IOV]; + uint32_t iovCount = NET_SGE_MAX_IOV; + uint64_t ctx[NET_SGE_MAX_IOV]; + urma_target_seg_t seg{}; + urma_target_seg_t *tmpSeg = &seg; + MOCKER(HcomUrma::ImportSeg).stubs().will(returnValue(tmpSeg)); + MOCKER(HcomUrma::PostJettySendWr, + urma_status_t(urma_jetty_t *, urma_jfs_wr_t *, uint32_t, urma_jfs_wr_t **)) + .stubs().will(returnValue(0)); + MOCKER(HcomUrma::UnimportSeg).stubs().will(returnValue(1)); + EXPECT_EQ(jetty->PostOneSideSgl(iov, iovCount, ctx, false, NET_SGE_MAX_IOV), UB_OK); +} + +TEST_F(TestUbUrmaJetty, UnInitialize) +{ + urma_target_jetty_t tmpJetty{}; + jetty->mTargetJetty = &tmpJetty; + jetty->mJettyOptions.ubcMode = UBSHcomUbcMode::LowLatency; + jfc->IncreaseRef(); + MOCKER_CPP(&HcomUrma::UnbindJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&HcomUrma::UnimportJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&HcomUrma::DeleteJetty).stubs().will(returnValue(0)); + MOCKER_CPP(&UBContext::UnInitialize).stubs().will(returnValue(0)); + EXPECT_EQ(jetty->UnInitialize(), 0); + jetty->mTargetJetty = nullptr; +} + +TEST_F(TestUbUrmaJetty, ImportAndBindJettyFail) +{ + urma_target_jetty_t tmpJetty2{}; + jetty->mJettyOptions.ubcMode = UBSHcomUbcMode::LowLatency; + MOCKER(HcomUrma::ImportJetty).stubs().will(returnValue(&tmpJetty2)); + MOCKER(HcomUrma::BindJetty).stubs().will(returnValue(1)); + MOCKER(HcomUrma::UnimportJetty).stubs().will(returnValue(0)); + + EXPECT_EQ(jetty->ImportAndBindJetty(), UB_QP_BIND_FAILED); +} + +TEST_F(TestUbUrmaJetty, CreateHBMemoryRegion) +{ + UBSHcomNetMemoryRegionPtr mr = nullptr; + EXPECT_EQ(jetty->CreateHBMemoryRegion(0, mr), NN_INVALID_PARAM); + + MOCKER_CPP(UBMemoryRegion::Create, UResult(const std::string &, UBContext *, uint64_t, UBMemoryRegion *&)) + .stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(jetty->CreateHBMemoryRegion(1, mr), 1); + + MOCKER_CPP(&UBMemoryRegion::InitializeForOneSide).stubs().will(returnValue(1)).then(returnValue(0)); + EXPECT_EQ(jetty->CreateHBMemoryRegion(1, mr), 1); + EXPECT_EQ(jetty->CreateHBMemoryRegion(1, mr), 0); +} + +TEST_F(TestUbUrmaJetty, DestroyHBMemoryRegion) +{ + UBSHcomNetMemoryRegionPtr mr = nullptr; + EXPECT_NO_FATAL_FAILURE(jetty->DestroyHBMemoryRegion(mr)); +} + +TEST_F(TestUbUrmaJetty, GetNextLocalHBAddress) +{ + auto localMr = new (std::nothrow) UBMemoryRegion("localMr", nullptr, 1, 1, 1); + jetty->mHBLocalMr = localMr; + jetty->mHBLocalMr->IncreaseRef(); + EXPECT_NO_FATAL_FAILURE(jetty->GetNextLocalHBAddress()); + if (localMr != nullptr) { + delete localMr; + localMr = nullptr; + } + jetty->mHBLocalMr.Set(nullptr); +} + +TEST_F(TestUbUrmaJetty, GetLocalHBKey) +{ + auto localMr = new (std::nothrow) UBMemoryRegion("localMr", nullptr, 1, 1, 1); + jetty->mHBLocalMr = localMr; + jetty->mHBLocalMr->IncreaseRef(); + MOCKER_CPP(&UBMemoryRegion::GetLKey).stubs().will(returnValue(1)); + EXPECT_NO_FATAL_FAILURE(jetty->GetLocalHBKey()); + if (localMr != nullptr) { + delete localMr; + localMr = nullptr; + } + jetty->mHBLocalMr.Set(nullptr); +} + +TEST_F(TestUbUrmaJetty, GetRemoteHbInfo) +{ + auto remoteMr = new (std::nothrow) UBMemoryRegion("remoteMr", nullptr, 1, 1, 1); + jetty->mHBRemoteMr = remoteMr; + jetty->mHBRemoteMr->IncreaseRef(); + UBJettyExchangeInfo info{}; + MOCKER_CPP(&UBMemoryRegion::GetLKey).stubs().will(returnValue(1)); + EXPECT_NO_FATAL_FAILURE(jetty->GetRemoteHbInfo(info)); + if (remoteMr != nullptr) { + delete remoteMr; + remoteMr = nullptr; + } + jetty->mHBRemoteMr.Set(nullptr); +} + +} +} +#endif diff --git a/test/unit_test/transport/ub/test_ub_urma_wrapper.cpp b/test/unit_test/transport/ub/test_ub_urma_wrapper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5fa6bc0cbef3c6509ae67f70a8cefee98f867de1 --- /dev/null +++ b/test/unit_test/transport/ub/test_ub_urma_wrapper.cpp @@ -0,0 +1,626 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifdef UB_BUILD_ENABLED + +#include +#include +#include + +#include "net_monotonic.h" +#include "ub_common.h" +#include "ub_mr_fixed_buf.h" +#include "ub_urma_wrapper_jetty.h" +#include "under_api/urma/urma_api_wrapper.h" +#include "under_api/obmm/obmm_api_wrapper.h" + +namespace ock { +namespace hcom { +urma_device_t **resList = nullptr; + +class TestUbUrmaWrapper : public testing::Test { +public: + TestUbUrmaWrapper(); + virtual void SetUp(void); + virtual void TearDown(void); + std::string name = "test"; + UBEId eid{}; + UBDeviceHelper *mUBDeviceHelper = nullptr; + UBContext *ctx = nullptr; + UBJfc *jfc = nullptr; +}; + +TestUbUrmaWrapper::TestUbUrmaWrapper() {} + +void TestUbUrmaWrapper::SetUp() +{ + mUBDeviceHelper = new (std::nothrow) UBDeviceHelper(); + ctx = new (std::nothrow) UBContext("ctx", eid); + jfc = new (std::nothrow) UBJfc(name, ctx, false, 0); + resList = (urma_device_t **)malloc(sizeof(urma_device_t *)); +} + +void TestUbUrmaWrapper::TearDown() +{ + MOCKER_CPP(HcomUrma::Uninit).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBDeviceHelper::UnInitialize).stubs().will(ignoreReturnValue()); + + jfc->mUBContext = nullptr; + + if (mUBDeviceHelper != nullptr) { + delete mUBDeviceHelper; + mUBDeviceHelper = nullptr; + } + + if (ctx != nullptr) { + delete ctx; + ctx = nullptr; + } + + if (jfc != nullptr) { + delete jfc; + jfc = nullptr; + } + free(resList); + GlobalMockObject::verify(); +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperInitialize) +{ + mUBDeviceHelper->G_InitRef = 1; + UResult ret = mUBDeviceHelper->Initialize(); + EXPECT_EQ(ret, UB_OK); + + mUBDeviceHelper->G_InitRef = 0; + MOCKER_CPP(&UBDeviceHelper::DoInitialize).stubs().will(returnValue(0)); + ret = mUBDeviceHelper->Initialize(); + EXPECT_EQ(ret, UB_OK); +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperUnInitialize) +{ + MOCKER_CPP(HcomUrma::Uninit).stubs().will(returnValue(0)); + EXPECT_NO_FATAL_FAILURE(mUBDeviceHelper->UnInitialize()); +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperDoInitialize) +{ + MOCKER_CPP(&UBDeviceHelper::DoUpdate).stubs().will(returnValue(1)).then(returnValue(0)); + UResult ret = mUBDeviceHelper->DoInitialize(); + EXPECT_EQ(ret, 1); + ret = mUBDeviceHelper->DoInitialize(); + EXPECT_EQ(ret, UB_OK); + mUBDeviceHelper->G_InitRef = 0; +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperDoUpdate) +{ + MOCKER_CPP(HcomUrma::Init).stubs().will(returnValue(0)).then(returnValue(1)); + urma_device_t **devList = nullptr; + MOCKER_CPP(HcomUrma::GetDeviceList).stubs().will(returnValue(devList)); + UResult ret = mUBDeviceHelper->DoUpdate(); + EXPECT_EQ(ret, UB_DEVICE_FAILED_OPEN); + ret = mUBDeviceHelper->DoUpdate(); + EXPECT_EQ(ret, 1); +} + +void MockFreeDeviceList(urma_device_t **device_list) +{ + return; +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperDoUpdateErr) +{ + MOCKER_CPP(&HcomUrma::Init).stubs().will(returnValue(0)); + MOCKER_CPP(&HcomUrma::GetDeviceList).stubs().will(returnValue(resList)); + void *devAttr = nullptr; + MOCKER(malloc).stubs().will(returnValue(devAttr)); + MOCKER_CPP(&HcomUrma::FreeDeviceList).stubs().will(invoke(MockFreeDeviceList)); + UResult ret = mUBDeviceHelper->DoUpdate(); + EXPECT_EQ(ret, UB_NEW_OBJECT_FAILED); +} + +urma_device_t **MockGetDeviceList(int *num_devices) +{ + *num_devices = NN_NO8; + return resList; +} + +TEST_F(TestUbUrmaWrapper, UBContextInitErr) +{ + MOCKER_CPP(&HcomUrma::Init).stubs().will(returnValue(0)); + MOCKER_CPP(&HcomUrma::GetDeviceList).stubs().will(invoke(MockGetDeviceList)); + urma_context_t tmpCtx{}; + MOCKER_CPP(&HcomUrma::CreateContext).stubs().will(returnValue(&tmpCtx)); + void *devAttr = nullptr; + MOCKER(malloc).stubs().will(returnValue(devAttr)); + MOCKER_CPP(&HcomUrma::DeleteContext).stubs().will(returnValue(0)); + MOCKER_CPP(&HcomUrma::FreeDeviceList).stubs().will(invoke(MockFreeDeviceList)); + UResult ret = ctx->Initialize(); + EXPECT_EQ(ret, UB_MEMORY_ALLOCATE_FAILED); +} + +TEST_F(TestUbUrmaWrapper, UBContextInitErrTwo) +{ + MOCKER_CPP(&HcomUrma::Init).stubs().will(returnValue(0)); + MOCKER_CPP(&HcomUrma::GetDeviceList).stubs().will(invoke(MockGetDeviceList)); + urma_context_t tmpCtx{}; + MOCKER_CPP(&HcomUrma::CreateContext).stubs().will(returnValue(&tmpCtx)); + MOCKER_CPP(&HcomUrma::QueryDevice).stubs().will(returnValue(1)); + MOCKER_CPP(&HcomUrma::DeleteContext).stubs().will(returnValue(0)); + MOCKER_CPP(&HcomUrma::FreeDeviceList).stubs().will(invoke(MockFreeDeviceList)); + UResult ret = ctx->Initialize(); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperGetEidVec) +{ + std::string devName = "device"; + uint16_t devIndex = 0; + uint32_t eidCnt = 1; + urma_eid_info_t eidInfoList{}; + eidInfoList.eid.in6.interface_id = 1; + std::vector outGidVec; + uint8_t bandWidth = 1; + EXPECT_NO_FATAL_FAILURE(mUBDeviceHelper->GetEidVec(devName, devIndex, eidCnt, &eidInfoList, outGidVec, bandWidth)); +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperGetDeviceCountInitializeFailed) +{ + uint16_t deviceCount = 0; + std::vector enabledDevices; + MOCKER_CPP(UBDeviceHelper::Initialize).stubs().will(returnValue(1)); + UResult ret = mUBDeviceHelper->GetDeviceCount(deviceCount, enabledDevices); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperGetDeviceCount) +{ + uint16_t deviceCount = 0; + std::vector enabledDevices; + MOCKER_CPP(UBDeviceHelper::Initialize).stubs().will(returnValue(0)); + UResult ret = mUBDeviceHelper->GetDeviceCount(deviceCount, enabledDevices); + EXPECT_EQ(ret, UB_OK); +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperGetEnableDeviceCountInvalidIPMaskOrNoMatchedIP) +{ + uint16_t enableDevCount = 0; + std::string ipMask = ""; + std::string ipGroup = ""; + std::vector enableIps; + UResult ret = mUBDeviceHelper->GetEnableDeviceCount(ipMask, enableDevCount, enableIps, ipGroup); + EXPECT_EQ(ret, NN_INVALID_IP); + + ipMask = "10.0.0.0/24"; + ret = mUBDeviceHelper->GetEnableDeviceCount(ipMask, enableDevCount, enableIps, ipGroup); + EXPECT_EQ(ret, UB_DEVICE_NO_IP_MATCHED); +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperGetEnableDeviceCountInitializeFailed) +{ + uint16_t enableDevCount = 0; + std::string ipMask = ""; + std::string ipGroup = "192.168.0.1;192.168.0.2"; + std::vector enableIps; + MOCKER_CPP(UBDeviceHelper::Initialize).stubs().will(returnValue(1)); + UResult ret = mUBDeviceHelper->GetEnableDeviceCount(ipMask, enableDevCount, enableIps, ipGroup); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperGetEnableDeviceCount) +{ + uint16_t enableDevCount = 0; + std::string ipMask = ""; + std::string ipGroup = "192.168.0.1;192.168.0.2"; + std::vector enableIps; + UBDeviceHelper::G_UBDevMap[0].active = true; + UBDeviceHelper::G_UBDevMap[1].active = true; + MOCKER_CPP(UBDeviceHelper::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(UBDeviceHelper::GetDeviceByIp).stubs().will(returnValue(1)).then(returnValue(0)); + UResult ret = mUBDeviceHelper->GetEnableDeviceCount(ipMask, enableDevCount, enableIps, ipGroup); + EXPECT_EQ(ret, 0); + UBDeviceHelper::G_UBDevMap.clear(); +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperGetIfAddressByIp) +{ + std::string ip = "192.168.0.1"; + struct sockaddr_in address; + MOCKER_CPP(&getifaddrs).stubs().will(returnValue(1)); + UResult ret = mUBDeviceHelper->GetIfAddressByIp(ip, address); + EXPECT_EQ(ret, UB_DEVICE_FAILED_GET_IP_ADDRESS); +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperGetDeviceByAddressInitializeFailed) +{ + std::string ip = "192.168.0.1"; + struct sockaddr_in address; + UBEId eid; + MOCKER_CPP(UBDeviceHelper::Initialize).stubs().will(returnValue(1)); + UResult ret = mUBDeviceHelper->GetDeviceByAddress(ip, address, eid); + EXPECT_EQ(ret, 1); +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperGetDeviceByAddressNoDeviceFound) +{ + std::string ip = "192.168.0.1"; + struct sockaddr_in address; + address.sin_addr.s_addr = inet_addr(ip.c_str()); + UBEId eid; + UBEId testEid; + testEid.devIndex = 1; + testEid.eidIndex = 1; + std::string key = std::to_string(testEid.devIndex); + UBDeviceHelper::G_UBDevEidTable[key].push_back(testEid); + MOCKER_CPP(UBDeviceHelper::Initialize).stubs().will(returnValue(0)); + UResult ret = mUBDeviceHelper->GetDeviceByAddress(ip, address, eid); + EXPECT_EQ(ret, UB_DEVICE_NO_IP_TO_GID_MATCHED); + UBDeviceHelper::G_UBDevEidTable.clear(); +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperGetDeviceByEidNotFound) +{ + UBEId eid; + uint8_t user[URMA_EID_SIZE] = {}; + EXPECT_EQ(UBDeviceHelper::GetDeviceByEid(user, eid), UB_DEVICE_NO_IP_TO_GID_MATCHED); +} + +TEST_F(TestUbUrmaWrapper, UBDeviceHelperGetPortNumber) +{ + uint32_t ret = mUBDeviceHelper->GetPortNumber(); + EXPECT_EQ(ret, mUBDeviceHelper->PORT_NUMBER); +} + +TEST_F(TestUbUrmaWrapper, UBJfcCreateEventCqCreatJfceFail) +{ + urma_jfce_t *jfcePtr = nullptr; + MOCKER_CPP(HcomUrma::CreateJfce).stubs().will(returnValue(jfcePtr)); + UResult ret = jfc->CreateEventCq(); + EXPECT_EQ(ret, UB_NEW_OBJECT_FAILED); +} + +TEST_F(TestUbUrmaWrapper, UBJfcCreateEventCqCreatJfcFail) +{ + urma_jfce_t jfce{}; + urma_jfce_t *jfcePtr = &jfce; + urma_jfc_t *jfcPtr = nullptr; + MOCKER_CPP(HcomUrma::CreateJfce).stubs().will(returnValue(jfcePtr)); + MOCKER_CPP(HcomUrma::CreateJfc).stubs().will(returnValue(jfcPtr)); + MOCKER_CPP(HcomUrma::DeleteJfce).stubs().will(returnValue(0)); + UResult ret = jfc->CreateEventCq(); + EXPECT_EQ(ret, UB_NEW_OBJECT_FAILED); +} + +TEST_F(TestUbUrmaWrapper, UBJfcCreateEventCqRearmJfcFail) +{ + urma_jfce_t jfce{}; + urma_jfce_t *jfcePtr = &jfce; + urma_jfc_t testJfc{}; + urma_jfc_t *jfcPtr = &testJfc; + MOCKER_CPP(HcomUrma::CreateJfce).stubs().will(returnValue(jfcePtr)); + MOCKER_CPP(HcomUrma::CreateJfc).stubs().will(returnValue(jfcPtr)); + MOCKER_CPP(HcomUrma::RearmJfc).stubs().will(returnValue(1)); + MOCKER_CPP(HcomUrma::DeleteJfce).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::DeleteJfc).stubs().will(returnValue(0)); + UResult ret = jfc->CreateEventCq(); + EXPECT_EQ(ret, UB_NEW_OBJECT_FAILED); +} + +TEST_F(TestUbUrmaWrapper, UBJfcCreateEventCqSuccess) +{ + urma_jfce_t jfce{}; + urma_jfce_t *jfcePtr = &jfce; + urma_jfc_t testJfc{}; + urma_jfc_t *jfcPtr = &testJfc; + MOCKER_CPP(HcomUrma::CreateJfce).stubs().will(returnValue(jfcePtr)); + MOCKER_CPP(HcomUrma::CreateJfc).stubs().will(returnValue(jfcPtr)); + MOCKER_CPP(HcomUrma::RearmJfc).stubs().will(returnValue(0)); + UResult ret = jfc->CreateEventCq(); + EXPECT_EQ(ret, UB_OK); + + jfc->mUrmaJfc = nullptr; + jfc->mUrmaJfcEvent = nullptr; +} + +TEST_F(TestUbUrmaWrapper, UBJfcInitialize) +{ + urma_jfc_t testJfc{}; + jfc->mUrmaJfc = &testJfc; + UResult ret = jfc->Initialize(); + EXPECT_EQ(ret, UB_OK); + jfc->mUrmaJfc = nullptr; +} + +TEST_F(TestUbUrmaWrapper, UBJfcInitializeCtxNull) +{ + jfc->mUrmaJfc = nullptr; + jfc->mUBContext = nullptr; + UResult ret = jfc->Initialize(); + EXPECT_EQ(ret, UB_PARAM_INVALID); +} + +TEST_F(TestUbUrmaWrapper, UBJfcInitializeCreateEventCq) +{ + jfc->mUrmaJfc = nullptr; + jfc->mUBContext = ctx; + urma_context_t urmaCtx{}; + jfc->mUBContext->mUrmaContext = &urmaCtx; + jfc->mCreateCompletionChannel = true; + MOCKER_CPP(UBJfc::CreateEventCq).stubs().will(returnValue(0)); + UResult ret = jfc->Initialize(); + EXPECT_EQ(ret, UB_OK); + jfc->mUBContext->mUrmaContext = nullptr; + jfc->mUBContext = nullptr; +} + +TEST_F(TestUbUrmaWrapper, UBJfcInitializeCreatePollingCq) +{ + jfc->mUrmaJfc = nullptr; + jfc->mUBContext = ctx; + urma_context_t urmaCtx{}; + jfc->mUBContext->mUrmaContext = &urmaCtx; + jfc->mCreateCompletionChannel = false; + MOCKER_CPP(UBJfc::CreatePollingCq).stubs().will(returnValue(0)); + UResult ret = jfc->Initialize(); + EXPECT_EQ(ret, UB_OK); + jfc->mUBContext->mUrmaContext = nullptr; + jfc->mUBContext = nullptr; +} + +TEST_F(TestUbUrmaWrapper, UBJfcUnInitialize) +{ + urma_jfc_t testJfc{}; + urma_jfce_t jfce{}; + jfc->mUrmaJfc = &testJfc; + jfc->mUrmaJfcEvent = &jfce; + jfc->mUBContext = nullptr; + MOCKER_CPP(HcomUrma::DeleteJfc).stubs().will(returnValue(11)); + MOCKER_CPP(HcomUrma::DeleteJfce).stubs().will(returnValue(0)); + UResult ret = jfc->UnInitialize(); + EXPECT_EQ(ret, UB_OK); +} + +TEST_F(TestUbUrmaWrapper, UBJfcProgressVPollJfcFail) +{ + urma_cr_t cr{}; + uint32_t countInOut = 0; + urma_jfc_t testJfc{}; + jfc->mUrmaJfc = &testJfc; + MOCKER_CPP(HcomUrma::PollJfc).stubs().will(returnValue(-1)); + UResult ret = jfc->ProgressV(&cr, countInOut); + EXPECT_EQ(ret, UB_CQ_POLLING_FAILED); + jfc->mUrmaJfc = nullptr; +} + +TEST_F(TestUbUrmaWrapper, UBJfcProgressVSuccess) +{ + urma_cr_t cr{}; + uint32_t countInOut = 0; + urma_jfc_t testJfc{}; + jfc->mUrmaJfc = &testJfc; + MOCKER_CPP(HcomUrma::PollJfc).stubs().will(returnValue(0)).then(returnValue(1)); + UResult ret = jfc->ProgressV(&cr, countInOut); + EXPECT_EQ(ret, UB_OK); + jfc->mUrmaJfc = nullptr; +} + +TEST_F(TestUbUrmaWrapper, UBJfcEventProgressVUrmaJfcNull) +{ + urma_cr_t cr{}; + uint32_t countInOut = 0; + int32_t timeoutInMs = 0; + jfc->mUrmaJfc = nullptr; + UResult ret = jfc->EventProgressV(&cr, countInOut, timeoutInMs); + EXPECT_EQ(ret, UB_CQ_NOT_INITIALIZED); +} + +TEST_F(TestUbUrmaWrapper, UBJfcEventProgressVRearmJfcFail) +{ + urma_cr_t cr{}; + uint32_t countInOut = 0; + int32_t timeoutInMs = 0; + urma_jfc_t testJfc{}; + jfc->mUrmaJfc = &testJfc; + urma_jfce_t jfce{}; + jfc->mUrmaJfcEvent = &jfce; + MOCKER_CPP(HcomUrma::WaitJfc).stubs().will(returnValue(1)); + MOCKER_CPP(HcomUrma::PollJfc).stubs().will(returnValue(1)); + MOCKER_CPP(HcomUrma::AckJfc).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::RearmJfc).stubs().will(returnValue(1)); + UResult ret = jfc->EventProgressV(&cr, countInOut, timeoutInMs); + EXPECT_EQ(ret, UB_CQ_EVENT_NOTIFY_FAILED); + jfc->mUrmaJfc = nullptr; + jfc->mUrmaJfcEvent = nullptr; +} + +TEST_F(TestUbUrmaWrapper, UBJfcEventProgressVWaitJfcTimeOut) +{ + urma_cr_t cr{}; + uint32_t countInOut = 0; + int32_t timeoutInMs = 0; + urma_jfc_t testJfc{}; + jfc->mUrmaJfc = &testJfc; + urma_jfce_t jfce{}; + jfc->mUrmaJfcEvent = &jfce; + MOCKER_CPP(HcomUrma::WaitJfc).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::PollJfc).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::AckJfc).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::RearmJfc).stubs().will(returnValue(0)); + UResult ret = jfc->EventProgressV(&cr, countInOut, timeoutInMs); + EXPECT_EQ(ret, UB_OK); + jfc->mUrmaJfc = nullptr; + jfc->mUrmaJfcEvent = nullptr; +} + +TEST_F(TestUbUrmaWrapper, UBJfcEventProgressVWaitJfcFail) +{ + urma_cr_t cr{}; + uint32_t countInOut = 0; + int32_t timeoutInMs = -1; + urma_jfc_t testJfc{}; + jfc->mUrmaJfc = &testJfc; + urma_jfce_t jfce{}; + jfc->mUrmaJfcEvent = &jfce; + MOCKER_CPP(HcomUrma::WaitJfc).stubs().will(returnValue(-1)); + MOCKER_CPP(HcomUrma::PollJfc).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::AckJfc).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::RearmJfc).stubs().will(returnValue(0)); + UResult ret = jfc->EventProgressV(&cr, countInOut, timeoutInMs); + EXPECT_EQ(ret, UB_CQ_EVENT_GET_FAILED); + jfc->mUrmaJfc = nullptr; + jfc->mUrmaJfcEvent = nullptr; +} + +int FakeWaitJfc1(urma_jfce_t *jfce, uint32_t jfc_cnt, int time_out, urma_jfc_t *jfc[]) +{ + jfc[0] = (urma_jfc_t *)0xdeadbabe; + return 1; +} + +TEST_F(TestUbUrmaWrapper, UBJfcEventProgressVAckJfc) +{ + urma_cr_t cr{}; + uint32_t countInOut = 0; + int32_t timeoutInMs = -1; + urma_jfc_t testJfc{}; + jfc->mUrmaJfc = &testJfc; + urma_jfce_t jfce{}; + jfc->mUrmaJfcEvent = &jfce; + MOCKER_CPP(HcomUrma::WaitJfc).stubs().will(invoke(FakeWaitJfc1)); + MOCKER_CPP(HcomUrma::PollJfc).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::AckJfc).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::RearmJfc).stubs().will(returnValue(0)); + UResult ret = jfc->EventProgressV(&cr, countInOut, timeoutInMs); + EXPECT_EQ(ret, UB_OK); + jfc->mUrmaJfc = nullptr; + jfc->mUrmaJfcEvent = nullptr; +} + +TEST_F(TestUbUrmaWrapper, UBJfcEventProgressVPollJfcFail) +{ + urma_cr_t cr{}; + uint32_t countInOut = 0; + int32_t timeoutInMs = -1; + urma_jfc_t testJfc{}; + jfc->mUrmaJfc = &testJfc; + urma_jfce_t jfce{}; + jfc->mUrmaJfcEvent = &jfce; + MOCKER_CPP(HcomUrma::RearmJfc).stubs().will(returnValue(0)); + MOCKER_CPP(&poll).stubs().will(returnValue(1)); + MOCKER_CPP(HcomUrma::WaitJfc).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::PollJfc).stubs().will(returnValue(-1)); + UResult ret = jfc->EventProgressV(&cr, countInOut, timeoutInMs); + EXPECT_EQ(ret, UB_CQ_POLLING_FAILED); + jfc->mUrmaJfc = nullptr; + jfc->mUrmaJfcEvent = nullptr; +} + +TEST_F(TestUbUrmaWrapper, UBJfcEventProgressVSuccess) +{ + urma_cr_t cr{}; + uint32_t countInOut = 0; + int32_t timeoutInMs = -1; + urma_jfc_t testJfc{}; + jfc->mUrmaJfc = &testJfc; + urma_jfce_t jfce{}; + jfc->mUrmaJfcEvent = &jfce; + MOCKER_CPP(HcomUrma::RearmJfc).stubs().will(returnValue(0)); + MOCKER_CPP(&poll).stubs().will(returnValue(1)); + MOCKER_CPP(HcomUrma::WaitJfc).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::PollJfc).stubs().will(returnValue(0)); + MOCKER_CPP(HcomUrma::AckJfc).stubs().will(returnValue(0)); + UResult ret = jfc->EventProgressV(&cr, countInOut, timeoutInMs); + EXPECT_EQ(ret, UB_OK); + jfc->mUrmaJfc = nullptr; + jfc->mUrmaJfcEvent = nullptr; +} + +TEST_F(TestUbUrmaWrapper, HasInternalError) +{ + UBOpContextInfo info; + + info.opResultType = UBOpContextInfo::OpResultType::SUCCESS; + EXPECT_FALSE(info.HasInternalError()); + + info.opResultType = UBOpContextInfo::OpResultType::ERR_TIMEOUT; + EXPECT_TRUE(info.HasInternalError()); + + info.opResultType = UBOpContextInfo::OpResultType::ERR_CANCELED; + EXPECT_TRUE(info.HasInternalError()); + + info.opResultType = UBOpContextInfo::OpResultType::ERR_IO_ERROR; + EXPECT_TRUE(info.HasInternalError()); + + info.opResultType = UBOpContextInfo::OpResultType::ERR_EP_BROKEN; + EXPECT_TRUE(info.HasInternalError()); + + info.opResultType = UBOpContextInfo::OpResultType::ERR_EP_CLOSE; + EXPECT_TRUE(info.HasInternalError()); + + info.opResultType = UBOpContextInfo::OpResultType::ERR_ACCESS_ABRT; + EXPECT_FALSE(info.HasInternalError()); + + info.opResultType = UBOpContextInfo::OpResultType::ERR_ACK_TIMEOUT; + EXPECT_FALSE(info.HasInternalError()); + + info.opResultType = UBOpContextInfo::OpResultType::INVALID_MAGIC; + EXPECT_TRUE(info.HasInternalError()); +} + +TEST_F(TestUbUrmaWrapper, OpResult) +{ + urma_cr_t result; + result.status = URMA_CR_SUCCESS; + EXPECT_EQ(UBOpContextInfo::OpResult(result), UBOpContextInfo::OpResultType::SUCCESS); + result.status = URMA_CR_RNR_RETRY_CNT_EXC_ERR; + EXPECT_EQ(UBOpContextInfo::OpResult(result), UBOpContextInfo::OpResultType::ERR_TIMEOUT); + result.status = URMA_CR_WR_FLUSH_ERR; + EXPECT_EQ(UBOpContextInfo::OpResult(result), UBOpContextInfo::OpResultType::ERR_CANCELED); + result.status = URMA_CR_WR_FLUSH_ERR_DONE; + EXPECT_EQ(UBOpContextInfo::OpResult(result), UBOpContextInfo::OpResultType::ERR_CANCELED); + result.status = URMA_CR_REM_OPERATION_ERR; + EXPECT_EQ(UBOpContextInfo::OpResult(result), UBOpContextInfo::OpResultType::ERR_IO_ERROR); + result.status = URMA_CR_REM_ACCESS_ABORT_ERR; + EXPECT_EQ(UBOpContextInfo::OpResult(result), UBOpContextInfo::OpResultType::ERR_ACCESS_ABRT); + result.status = URMA_CR_ACK_TIMEOUT_ERR; + EXPECT_EQ(UBOpContextInfo::OpResult(result), UBOpContextInfo::OpResultType::ERR_ACK_TIMEOUT); +} + +TEST_F(TestUbUrmaWrapper, GetNResult) +{ + UBOpContextInfo context{}; + UBOpContextInfo::OpResultType opResult = UBOpContextInfo::OpResultType::ERR_TIMEOUT; + EXPECT_EQ(UBOpContextInfo::GetNResult(opResult), NN_MSG_TIMEOUT); + opResult = UBOpContextInfo::OpResultType::ERR_CANCELED; + EXPECT_EQ(UBOpContextInfo::GetNResult(opResult), NN_MSG_CANCELED); + opResult = UBOpContextInfo::OpResultType::ERR_EP_BROKEN; + EXPECT_EQ(UBOpContextInfo::GetNResult(opResult), NN_EP_BROKEN); + opResult = UBOpContextInfo::OpResultType::ERR_EP_CLOSE; + EXPECT_EQ(UBOpContextInfo::GetNResult(opResult), NN_EP_CLOSE); + opResult = UBOpContextInfo::OpResultType::ERR_IO_ERROR; + EXPECT_EQ(UBOpContextInfo::GetNResult(opResult), NN_MSG_ERROR); + opResult = UBOpContextInfo::OpResultType::ERR_ACCESS_ABRT; + EXPECT_EQ(UBOpContextInfo::GetNResult(opResult), NN_URMA_ACCESS_ABRT); + opResult = UBOpContextInfo::OpResultType::ERR_ACK_TIMEOUT; + EXPECT_EQ(UBOpContextInfo::GetNResult(opResult), NN_URMA_ACK_TIMEOUT); +} + + +TEST_F(TestUbUrmaWrapper, GetIfAddressByIp) +{ + struct sockaddr_in addr {}; + EXPECT_EQ(UBDeviceHelper::GetIfAddressByIp("127.0.0.1", addr), 0); +} +} +} +#endif diff --git a/test/unit_test/transport/ub/test_ub_worker.cpp b/test/unit_test/transport/ub/test_ub_worker.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fc3b015746fad9acf70bbc2afa255ac33451399f --- /dev/null +++ b/test/unit_test/transport/ub/test_ub_worker.cpp @@ -0,0 +1,777 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifdef UB_BUILD_ENABLED + +#include +#include +#include +#include + +#include "hcom_utils.h" +#include "net_common.h" +#include "ub_worker.h" + + +namespace ock { +namespace hcom { +UBOpContextInfo opCtxInfoPool{}; +UBSglContextInfo sglOpCtxInfoPool{}; + +static UBOpContextInfo *MockOpCtxInfoPoolGet() +{ + return &opCtxInfoPool; +} + +static UBSglContextInfo *MockSglOpCtxInfoPoolGet() +{ + return &sglOpCtxInfoPool; +} + +class TestUbWorker : public testing::Test { +public: + TestUbWorker(); + virtual void SetUp(void); + virtual void TearDown(void); + std::string mName = "TestUbWorker"; + UBContext *ctx = nullptr; + UBWorkerOptions options{}; + NetMemPoolFixed *memPool = nullptr; + NetMemPoolFixed *sglMemPool = nullptr; + NetMemPoolFixed *opCtxInfoPool = nullptr; + NetMemPoolFixedOptions memOptions{}; + UBWorker *worker = nullptr; + UBJetty *qp = nullptr; +}; + +TestUbWorker::TestUbWorker() {} + +void TestUbWorker::SetUp() +{ + opCtxInfoPool = new (std::nothrow) NetMemPoolFixed(mName, memOptions); + ASSERT_NE(opCtxInfoPool, nullptr); + worker = new (std::nothrow) UBWorker(mName, ctx, options, memPool, sglMemPool); + ASSERT_NE(worker, nullptr); + worker->mInited = false; + qp = new (std::nothrow) UBJetty(mName, 0, nullptr, nullptr); + ASSERT_NE(qp, nullptr); + + MOCKER_CPP(&UBSglContextInfoPool::Get).stubs().will(invoke(MockSglOpCtxInfoPoolGet)); + MOCKER_CPP(&UBSglContextInfoPool::Return).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBOpContextInfoPool::Get).stubs().will(invoke(MockOpCtxInfoPoolGet)); + MOCKER_CPP(&UBOpContextInfoPool::Return).stubs().will(ignoreReturnValue()); +} + +void TestUbWorker::TearDown() +{ + GlobalMockObject::verify(); + if (worker != nullptr) { + delete worker; + worker = nullptr; + } + if (opCtxInfoPool != nullptr) { + delete opCtxInfoPool; + opCtxInfoPool = nullptr; + } + if (qp != nullptr) { + delete qp; + qp = nullptr; + } +} + +TEST_F(TestUbWorker, ToString) +{ + UBWorkerOptions option{}; + EXPECT_NO_FATAL_FAILURE(option.ToString()); +} + +TEST_F(TestUbWorker, IsWorkStarted) +{ + worker->mProgressThreadStarted = true; + EXPECT_EQ(worker->IsWorkStarted(1), true); +} + +TEST_F(TestUbWorker, SetIndex) +{ + UBSHcomNetWorkerIndex value{}; + worker->SetIndex(value); + EXPECT_EQ(worker->mIndex.wholeIdx, 0); +} + +TEST_F(TestUbWorker, ReturnOpContextInfo) +{ + UBOpContextInfo *ctx = nullptr; + UBSglContextInfo *sglCtx = nullptr; + EXPECT_NO_FATAL_FAILURE(worker->ReturnOpContextInfo(ctx)); + EXPECT_NO_FATAL_FAILURE(worker->ReturnSglContextInfo(sglCtx)); +} + +TEST_F(TestUbWorker, RegisterHandler) +{ + UBNewReqHandler ubNewReqHandler{}; + UBPostedHandler ubPostedHandler{}; + EXPECT_NO_FATAL_FAILURE(worker->RegisterNewRequestHandler(ubNewReqHandler)); + EXPECT_NO_FATAL_FAILURE(worker->RegisterPostedHandler(ubPostedHandler)); + worker->mNewRequestHandler = nullptr; + worker->mSendPostedHandler = nullptr; +} + +TEST_F(TestUbWorker, RegisterOneSideAndIdleHandler) +{ + UBOneSideDoneHandler ubOneSideDoneHandler{}; + UBIdleHandler ubIdleHandler{}; + EXPECT_NO_FATAL_FAILURE(worker->RegisterOneSideDoneHandler(ubOneSideDoneHandler)); + EXPECT_NO_FATAL_FAILURE(worker->RegisterIdleHandler(ubIdleHandler)); + worker->mOneSideDoneHandler = nullptr; + worker->mIdleHandler = nullptr; +} + +TEST_F(TestUbWorker, DetailName) +{ + EXPECT_NO_FATAL_FAILURE(worker->DetailName()); +} + +TEST_F(TestUbWorker, PortNum) +{ + UBEId eid{}; + UBContext *ubCtx = new UBContext("ubTest", eid); + ubCtx->mPortNumber = 1; + worker->mUBContext = ubCtx; + MOCKER_CPP(HcomUrma::Uninit).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBDeviceHelper::UnInitialize).stubs().will(ignoreReturnValue()); + EXPECT_EQ(worker->PortNum(), 1); + delete ubCtx; + worker->mUBContext = nullptr; +} + +TEST_F(TestUbWorker, WorkerTypeToString) +{ + std::string send("sender"); + std::string unknown("unknown worker type"); + EXPECT_EQ(WorkerTypeToString(UBWorkerType::UB_SENDER), send); + EXPECT_EQ(WorkerTypeToString(static_cast(NN_NO4)), unknown); +} + +TEST_F(TestUbWorker, PollingModeToString) +{ + std::string busy("busy_polling"); + std::string unknown("unknown worker mode"); + EXPECT_EQ(PollingModeToString(UBPollingMode::UB_BUSY_POLLING), busy); + EXPECT_EQ(PollingModeToString(static_cast(NN_NO2)), unknown); +} + +TEST_F(TestUbWorker, Initialize) +{ + worker->mInited = true; + EXPECT_EQ(worker->Initialize(), UB_OK); + worker->mInited = false; + + EXPECT_EQ(worker->Initialize(), UB_PARAM_INVALID); +} + +TEST_F(TestUbWorker, InitializeSuccess) +{ + urma_context_t UrmaContext{}; + UBEId eid{}; + UBContext *ubCtx = new UBContext("ubTest", eid); + ubCtx->mUrmaContext = &UrmaContext; + worker->mUBContext = ubCtx; + MOCKER_CPP(HcomUrma::Uninit).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBDeviceHelper::UnInitialize).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBOpContextInfoPool::Initialize, NResult(UBOpContextInfoPool::*)(const NetMemPoolFixedPtr &)) + .stubs() + .will(returnValue(0)); + MOCKER_CPP(&UBSglContextInfoPool::Initialize, NResult(UBSglContextInfoPool::*)(const NetMemPoolFixedPtr &)) + .stubs() + .will(returnValue(0)); + EXPECT_EQ(worker->Initialize(), 0); + + ubCtx->mUrmaContext = nullptr; + worker->mUBContext = nullptr; + if (worker->mUBJfc != nullptr) { + delete worker->mUBJfc; + worker->mUBJfc = nullptr; + } + worker->mUBJfc = nullptr; +} + +TEST_F(TestUbWorker, InitializeUBJfcErr) +{ + urma_context_t UrmaContext{}; + UBEId eid{}; + UBContext *ubCtx = new UBContext("ubTest", eid); + ubCtx->mUrmaContext = &UrmaContext; + worker->mUBContext = ubCtx; + MOCKER_CPP(HcomUrma::Uninit).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBDeviceHelper::UnInitialize).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(1)); + MOCKER_CPP(&UBJfc::UnInitialize).stubs().will(returnValue(0)); + + EXPECT_EQ(worker->Initialize(), 1); + + ubCtx->mUrmaContext = nullptr; + worker->mUBContext = nullptr; + if (worker->mUBJfc != nullptr) { + delete worker->mUBJfc; + worker->mUBJfc = nullptr; + } + worker->mUBJfc = nullptr; + delete ubCtx; + ubCtx = nullptr; +} + +TEST_F(TestUbWorker, InitializeSglCtxInfoPoolErr) +{ + urma_context_t UrmaContext{}; + UBEId eid{}; + UBContext *ubCtx = new UBContext("ubTest", eid); + ubCtx->mUrmaContext = &UrmaContext; + worker->mUBContext = ubCtx; + MOCKER_CPP(HcomUrma::Uninit).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBDeviceHelper::UnInitialize).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJfc::UnInitialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBOpContextInfoPool::Initialize, NResult(UBOpContextInfoPool::*)(const NetMemPoolFixedPtr &)) + .stubs() + .will(returnValue(0)); + MOCKER_CPP(&UBSglContextInfoPool::Initialize, NResult(UBSglContextInfoPool::*)(const NetMemPoolFixedPtr &)) + .stubs() + .will(returnValue(1)); + + EXPECT_EQ(worker->Initialize(), 1); + + ubCtx->mUrmaContext = nullptr; + worker->mUBContext = nullptr; + if (worker->mUBJfc != nullptr) { + delete worker->mUBJfc; + worker->mUBJfc = nullptr; + } + delete ubCtx; + ubCtx = nullptr; +} + +TEST_F(TestUbWorker, InitializeOpCtxInfoPoolErr) +{ + urma_context_t UrmaContext{}; + UBEId eid{}; + UBContext *ubCtx = new UBContext("ubTest", eid); + ubCtx->mUrmaContext = &UrmaContext; + worker->mUBContext = ubCtx; + MOCKER_CPP(HcomUrma::Uninit).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBDeviceHelper::UnInitialize).stubs().will(ignoreReturnValue()); + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJfc::UnInitialize).stubs().will(returnValue(0)); + MOCKER_CPP(&UBOpContextInfoPool::Initialize, NResult(UBOpContextInfoPool::*)(const NetMemPoolFixedPtr &)) + .stubs() + .will(returnValue(1)); + + EXPECT_EQ(worker->Initialize(), 1); + + ubCtx->mUrmaContext = nullptr; + worker->mUBContext = nullptr; + if (worker->mUBJfc != nullptr) { + delete worker->mUBJfc; + worker->mUBJfc = nullptr; + } + delete ubCtx; + ubCtx = nullptr; +} + +TEST_F(TestUbWorker, UnInitialize) +{ + EXPECT_EQ(worker->UnInitialize(), UB_OK); + + worker->mInited = true; + MOCKER_CPP(&UBOpContextInfoPool::UnInitialize).stubs().will(returnValue(0)); + EXPECT_EQ(worker->UnInitialize(), UB_OK); +} + +TEST_F(TestUbWorker, ReInitializeCQ) +{ + EXPECT_EQ(worker->ReInitializeCQ(), UB_OK); + + worker->mInited = true; + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(0)); + EXPECT_EQ(worker->ReInitializeCQ(), UB_OK); + if (worker->mUBJfc != nullptr) { + delete worker->mUBJfc; + worker->mUBJfc = nullptr; + } + worker->mUBJfc = nullptr; +} + +TEST_F(TestUbWorker, ReInitializeCQErr) +{ + worker->mInited = true; + MOCKER_CPP(&UBJfc::Initialize).stubs().will(returnValue(1)); + EXPECT_EQ(worker->ReInitializeCQ(), 1); + if (worker->mUBJfc != nullptr) { + delete worker->mUBJfc; + worker->mUBJfc = nullptr; + } + worker->mUBJfc = nullptr; +} + +TEST_F(TestUbWorker, Start) +{ + EXPECT_EQ(worker->Start(), UB_WORKER_NOT_INITIALIZED); + + worker->mInited = true; + worker->mOptions.dontStartWorkers = true; + EXPECT_EQ(worker->Start(), UB_OK); + worker->mOptions.dontStartWorkers = false; +} + +TEST_F(TestUbWorker, StartTypeErr) +{ + worker->mInited = true; + worker->mOptions.workerType = UB_RECEIVER; + EXPECT_EQ(worker->Start(), UB_WORKER_REQUEST_HANDLER_NOT_SET); + + worker->mOptions.workerType = UB_SENDER; + worker->mOptions.dontStartWorkers = false; + EXPECT_EQ(worker->Start(), UB_WORKER_SEND_POSTED_HANDLER_NOT_SET); +} + +TEST_F(TestUbWorker, Stop) +{ + EXPECT_EQ(worker->Stop(), UB_OK); +} + +TEST_F(TestUbWorker, RunInThreadErr) +{ + worker->mOptions.threadPriority = 1; + worker->mOptions.workerMode = static_cast(NN_NO2); + EXPECT_NO_FATAL_FAILURE(worker->RunInThread()); +} + +TEST_F(TestUbWorker, CreateCtxNullErr) +{ + UBWorker *outWorker = nullptr; + EXPECT_EQ(worker->Create(mName, nullptr, options, memPool, sglMemPool, outWorker), UB_PARAM_INVALID); +} + +TEST_F(TestUbWorker, PostReceiveParamErr) +{ + EXPECT_EQ(worker->PostReceive(nullptr, 0, 0, nullptr), UB_PARAM_INVALID); + + MOCKER_CPP(&UBJetty::PostReceive).stubs().will(returnValue(1)); + qp->IncreaseRef(); + EXPECT_EQ(worker->PostReceive(qp, 0, 0, nullptr), 1); +} + +TEST_F(TestUbWorker, PostReceiveCtxFull) +{ + GlobalMockObject::verify(); + UBOpContextInfo *testPool = nullptr; + MOCKER_CPP(&UBOpContextInfoPool::Get).stubs().will(returnValue(testPool)); + EXPECT_EQ(worker->PostReceive(qp, 0, 0, nullptr), UB_QP_CTX_FULL); +} + +TEST_F(TestUbWorker, RePostReceive) +{ + EXPECT_EQ(worker->RePostReceive(nullptr), UB_PARAM_INVALID); + + UBOpContextInfo ctx{}; + ctx.ubJetty = qp; + MOCKER_CPP(&UBJetty::PostReceive).stubs().will(returnValue(1)); + qp->IncreaseRef(); + qp->IncreaseRef(); + EXPECT_EQ(worker->RePostReceive(&ctx), NN_NO200); +} + +TEST_F(TestUbWorker, PostSendParamErr) +{ + UBSendReadWriteRequest req{}; + EXPECT_EQ(worker->PostSend(nullptr, req, nullptr, 0), UB_PARAM_INVALID); + + MOCKER_CPP(&UBJetty::GetPostSendWr).stubs().will(returnValue(false)); + EXPECT_EQ(worker->PostSend(qp, req, nullptr, 0), UB_QP_POST_SEND_WR_FULL); +} + +TEST_F(TestUbWorker, PostSendCtxFull) +{ + GlobalMockObject::verify(); + UBSendReadWriteRequest req{}; + UBOpContextInfo *testPool = nullptr; + MOCKER_CPP(&UBOpContextInfoPool::Get).stubs().will(returnValue(testPool)); + EXPECT_EQ(worker->PostSend(qp, req, nullptr, 0), UB_QP_CTX_FULL); +} + +TEST_F(TestUbWorker, PostSend) +{ + UBSendReadWriteRequest req{}; + + MOCKER_CPP(&UBJetty::GetPostSendWr).stubs().will(returnValue(true)); + MOCKER_CPP(&UBJetty::PostSend).stubs().will(returnValue(1)); + + qp->IncreaseRef(); + EXPECT_EQ(worker->PostSend(qp, req, nullptr, 0), 1); +} + +TEST_F(TestUbWorker, PostSendSglInlineParamErr) +{ + UBSendSglInlineHeader header{}; + UBSendReadWriteRequest req{}; + EXPECT_EQ(worker->PostSendSglInline(nullptr, header, req, 0), 200); +} + +TEST_F(TestUbWorker, PostSendSglInlineCtxNull) +{ + GlobalMockObject::verify(); + UBSendSglInlineHeader header{}; + UBSendReadWriteRequest req{}; + UBOpContextInfo *testPool = nullptr; + MOCKER_CPP(&UBOpContextInfoPool::Get).stubs().will(returnValue(testPool)); + EXPECT_EQ(worker->PostSendSglInline(qp, header, req, 0), UB_QP_CTX_FULL); +} + +TEST_F(TestUbWorker, PostSendSglInlineWrFull) +{ + UBSendSglInlineHeader header{}; + UBSendReadWriteRequest req{}; + + MOCKER_CPP(&UBJetty::GetPostSendWr).stubs().will(returnValue(false)); + + EXPECT_EQ(worker->PostSendSglInline(qp, header, req, 0), UB_QP_POST_SEND_WR_FULL); +} + +TEST_F(TestUbWorker, PostSendSglInlineSuccess) +{ + UBSendSglInlineHeader header{}; + UBSendReadWriteRequest req{}; + + MOCKER_CPP(&UBJetty::GetPostSendWr).stubs().will(returnValue(true)); + + MOCKER_CPP(&UBJetty::PostSendSglInline).stubs().will(returnValue(0)); + + EXPECT_EQ(worker->PostSendSglInline(qp, header, req, 0), 0); +} + +TEST_F(TestUbWorker, PostSendSglParamErr) +{ + UBSHcomNetTransSglRequest req{}; + req.upCtxSize = 1; + UBSHcomNetTransRequest tlsReq{}; + EXPECT_EQ(worker->PostSendSgl(nullptr, req, tlsReq, 0, false), UB_PARAM_INVALID); + + MOCKER_CPP(&UBJetty::GetPostSendWr).stubs().will(returnValue(false)); + EXPECT_EQ(worker->PostSendSgl(qp, req, tlsReq, 0, false), UB_PARAM_INVALID); +} + +TEST_F(TestUbWorker, PostSendSglCtxFull) +{ + GlobalMockObject::verify(); + UBSHcomNetTransSglRequest req{}; + UBSHcomNetTransRequest tlsReq{}; + UBSglContextInfo *testPool = nullptr; + MOCKER_CPP(&UBSglContextInfoPool::Get).stubs().will(returnValue(testPool)); + EXPECT_EQ(worker->PostSendSgl(qp, req, tlsReq, 0, false), UB_PARAM_INVALID); +} + +TEST_F(TestUbWorker, PostSendSgl) +{ + UBSHcomNetTransSglRequest req{}; + UBSHcomNetTransSgeIov iov{}; + req.iov = &iov; + req.iovCount = 1; + req.upCtxSize = 1; + UBSHcomNetTransRequest tlsReq{}; + + MOCKER_CPP(&UBJetty::GetPostSendWr).stubs().will(returnValue(true)); + MOCKER_CPP(&UBJetty::PostSendSgl).stubs().will(returnValue(1)); + qp->IncreaseRef(); + EXPECT_EQ(worker->PostSendSgl(qp, req, tlsReq, 0, false), 1); +} + +TEST_F(TestUbWorker, PostReadParamErr) +{ + UBSendReadWriteRequest req{}; + EXPECT_EQ(worker->PostRead(nullptr, req), UB_PARAM_INVALID); + + MOCKER_CPP(&UBJetty::GetOneSideWr).stubs().will(returnValue(false)); + EXPECT_EQ(worker->PostRead(qp, req), UB_QP_ONE_SIDE_WR_FULL); +} + +TEST_F(TestUbWorker, PostRead) +{ + UBSendReadWriteRequest req{}; + req.upCtxSize = 1; + MOCKER_CPP(&UBJetty::GetOneSideWr).stubs().will(returnValue(true)); + MOCKER_CPP(&UBJetty::GetProtocol).stubs().will(returnValue(UBSHcomNetDriverProtocol::UBC)); + MOCKER_CPP(&UBJetty::PostRead, UResult(UBJetty::*)(uintptr_t, urma_target_seg_t *, + uintptr_t, uint64_t, uint32_t, uint64_t)) + .stubs() + .will(returnValue(1)); + qp->IncreaseRef(); + EXPECT_EQ(worker->PostRead(qp, req), 1); +} + +TEST_F(TestUbWorker, PostReadCtxFull) +{ + GlobalMockObject::verify(); + UBSendReadWriteRequest req{}; + UBOpContextInfo *testPool = nullptr; + MOCKER_CPP(&UBOpContextInfoPool::Get).stubs().will(returnValue(testPool)); + EXPECT_EQ(worker->PostRead(qp, req), UB_QP_CTX_FULL); +} + +TEST_F(TestUbWorker, PostWriteParamErr) +{ + UBSendReadWriteRequest req{}; + EXPECT_EQ(worker->PostWrite(nullptr, req), UB_PARAM_INVALID); + + MOCKER_CPP(&UBJetty::GetOneSideWr).stubs().will(returnValue(false)); + EXPECT_EQ(worker->PostWrite(qp, req), UB_QP_ONE_SIDE_WR_FULL); +} + +TEST_F(TestUbWorker, PostWriteCtxFull) +{ + GlobalMockObject::verify(); + UBSendReadWriteRequest req{}; + UBOpContextInfo *testPool = nullptr; + MOCKER_CPP(&UBOpContextInfoPool::Get).stubs().will(returnValue(testPool)); + EXPECT_EQ(worker->PostWrite(qp, req), UB_QP_CTX_FULL); +} + +TEST_F(TestUbWorker, PostWrite) +{ + UBSendReadWriteRequest req{}; + req.upCtxSize = 1; + MOCKER_CPP(&UBJetty::GetOneSideWr).stubs().will(returnValue(true)); + MOCKER_CPP(&UBJetty::GetProtocol).stubs().will(returnValue(UBSHcomNetDriverProtocol::UBC)); + MOCKER_CPP(&UBJetty::PostWrite, UResult(UBJetty::*)(uintptr_t, urma_target_seg_t *, + uintptr_t, uint64_t, uint32_t, uint64_t)) + .stubs() + .will(returnValue(1)); + qp->IncreaseRef(); + EXPECT_EQ(worker->PostWrite(qp, req), 1); +} + +TEST_F(TestUbWorker, CreateOneSideCtxParamErr) +{ + UBSgeCtxInfo sgeInfo{}; + uint64_t ctxArr[NET_SGE_MAX_IOV]; + EXPECT_EQ(worker->CreateOneSideCtx(sgeInfo, nullptr, 0, ctxArr, true), UB_PARAM_INVALID); +} + +TEST_F(TestUbWorker, CreateOneSideCtxCtxFull) +{ + GlobalMockObject::verify(); + UBSgeCtxInfo sgeInfo{}; + uint64_t ctxArr[NET_SGE_MAX_IOV]; + UBSHcomNetTransSgeIov *iov = nullptr; + iov = new (std::nothrow) UBSHcomNetTransSgeIov(); + UBOpContextInfo *testPool = nullptr; + MOCKER_CPP(&UBOpContextInfoPool::Get).stubs().will(returnValue(testPool)); + + EXPECT_EQ(worker->CreateOneSideCtx(sgeInfo, iov, 1, ctxArr, true), UB_QP_CTX_FULL); + + if (iov != nullptr) { + delete iov; + iov = nullptr; + } +} + +TEST_F(TestUbWorker, CreateOneSideCtxOneSideWrFull) +{ + UBSgeCtxInfo sgeInfo{}; + sgeInfo.ctx = MockSglOpCtxInfoPoolGet(); + sgeInfo.ctx->qp = qp; + uint64_t ctxArr[NET_SGE_MAX_IOV]; + UBSHcomNetTransSgeIov *iov = nullptr; + iov = new (std::nothrow) UBSHcomNetTransSgeIov(); + MOCKER_CPP(&UBJetty::GetOneSideWr).stubs().will(returnValue(false)); + + EXPECT_EQ(worker->CreateOneSideCtx(sgeInfo, iov, 1, ctxArr, true), UB_QP_ONE_SIDE_WR_FULL); + + if (iov != nullptr) { + delete iov; + iov = nullptr; + } +} + +TEST_F(TestUbWorker, CreateOneSideCtx) +{ + UBSgeCtxInfo sgeInfo{}; + sgeInfo.ctx = MockSglOpCtxInfoPoolGet(); + sgeInfo.ctx->qp = qp; + uint64_t ctxArr[NET_SGE_MAX_IOV]; + UBSHcomNetTransSgeIov *iov = nullptr; + iov = new (std::nothrow) UBSHcomNetTransSgeIov(); + MOCKER_CPP(&UBJetty::GetOneSideWr).stubs().will(returnValue(true)); + + EXPECT_EQ(worker->CreateOneSideCtx(sgeInfo, iov, 1, ctxArr, true), UB_OK); + + if (iov != nullptr) { + delete iov; + iov = nullptr; + } +} + +TEST_F(TestUbWorker, PostOneSideSglParamErr) +{ + UBSendSglRWRequest rwReq{}; + rwReq.upCtxSize = 1; + EXPECT_EQ(worker->PostOneSideSgl(nullptr, rwReq, false), UB_PARAM_INVALID); + + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)).then(returnValue(1)); + EXPECT_EQ(worker->PostOneSideSgl(qp, rwReq, false), UB_PARAM_INVALID); + EXPECT_EQ(worker->PostOneSideSgl(qp, rwReq, false), UB_PARAM_INVALID); +} + +TEST_F(TestUbWorker, PostOneSideSgl) +{ + UBSendSglRWRequest rwReq{}; + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + MOCKER_CPP(&UBWorker::CreateOneSideCtx).stubs().will(returnValue(1)).then(returnValue(0)); + MOCKER_CPP(&UBJetty::GetProtocol) + .stubs() + .will(returnValue(UBSHcomNetDriverProtocol::UBC)); + MOCKER_CPP(&UBJetty::PostOneSideSgl).stubs().will(returnValue(0)); + + EXPECT_EQ(worker->PostOneSideSgl(qp, rwReq, false), NN_NO1); + EXPECT_EQ(worker->PostOneSideSgl(qp, rwReq, false), UB_OK); +} + +TEST_F(TestUbWorker, CreateQPErr) +{ + UBJetty *tmpQp = nullptr; + worker->mInited = false; + EXPECT_EQ(worker->CreateQP(tmpQp), UB_WORKER_NOT_INITIALIZED); +} + +TEST_F(TestUbWorker, CreateQP) +{ + UBJetty *tmpQp = nullptr; + worker->mInited = true; + EXPECT_EQ(worker->CreateQP(tmpQp), UB_OK); + delete tmpQp; +} + +TEST_F(TestUbWorker, CheckJettyLiveness) +{ + EXPECT_FALSE(worker->CheckJettyLiveness(1)); + + worker->AddJettyId(1); + EXPECT_TRUE(worker->CheckJettyLiveness(1)); + + worker->RemoveJettyId(1); + EXPECT_FALSE(worker->CheckJettyLiveness(1)); +} + +TEST_F(TestUbWorker, TestPostSendSgl) +{ + UBSHcomNetTransSglRequest req{}; + req.upCtxSize = 0; + UBSHcomNetTransRequest tlsReq{}; + + MOCKER_CPP(&UBJetty::GetPostSendWr).stubs().will(returnValue(true)); + MOCKER_CPP(&memcpy_s).stubs().will(returnValue(0)); + MOCKER_CPP(&UBJetty::PostSend).stubs().will(returnValue(0)); + EXPECT_EQ(worker->PostSendSgl(qp, req, tlsReq, 0, true), 0); +} + +void MockRemoveOpCtxInfo(UBJetty *This, UBOpContextInfo *ctxInfo) +{ + return; +} + +void MockUpdateTargetHbTime(NetUBAsyncEndpoint *This) +{ + return; +} + +void TestUbEndPointBroken(const UBSHcomNetEndpointPtr &ep) +{ + NN_LOG_INFO("end point broken"); +} + +int TestUbRequestReceived(const UBOpContextInfo *ctx) +{ + return 0; +} + +int TestUbRequestPosted(const UBOpContextInfo *ctx) +{ + return 0; +} + +int TestUbOneSideDone(const UBOpContextInfo *ctx) +{ + return 0; +} + +TEST_F(TestUbWorker, TestProcessPollingResult) +{ + urma_cr_t *wc1 = nullptr; + uint32_t pollCount = NN_NO1; + UBJetty *lastBrokenQp = nullptr; + urma_cr_status_t lastErrorWcStatus = URMA_CR_SUCCESS; + worker->ProcessPollingResult(wc1, pollCount, lastBrokenQp, lastErrorWcStatus); + + auto *wc = static_cast(calloc(NN_NO1, sizeof(urma_cr_t))); + EXPECT_NE(wc, nullptr); + wc[0].status = URMA_CR_WR_FLUSH_ERR_DONE; + worker->ProcessPollingResult(wc, pollCount, lastBrokenQp, lastErrorWcStatus); + + wc[0].status = URMA_CR_ACK_TIMEOUT_ERR; + UBOpContextInfo contextInfo{}; + wc[0].user_ctx = reinterpret_cast(&contextInfo); + contextInfo.ubJetty = new (std::nothrow) UBJetty("testUbJetty", NN_NO0, nullptr, nullptr); + EXPECT_NE(contextInfo.ubJetty, nullptr); + worker->ProcessPollingResult(wc, pollCount, lastBrokenQp, lastErrorWcStatus); + + contextInfo.ubJetty->isStarted = true; + MOCKER_CPP(&UBOpContextInfo::HasInternalError).stubs().will(returnValue(false)); + MOCKER_CPP(&UBJetty::RemoveOpCtxInfo).stubs().will(invoke(MockRemoveOpCtxInfo)); + worker->ProcessPollingResult(wc, pollCount, lastBrokenQp, lastErrorWcStatus); + + contextInfo.opType = UBOpContextInfo::HB_WRITE; + worker->ProcessPollingResult(wc, pollCount, lastBrokenQp, lastErrorWcStatus); + + contextInfo.opType = UBOpContextInfo::SEND; + worker->ProcessPollingResult(wc, pollCount, lastBrokenQp, lastErrorWcStatus); + + lastBrokenQp = contextInfo.ubJetty; + worker->ProcessPollingResult(wc, pollCount, lastBrokenQp, lastErrorWcStatus); + + wc[0].status = URMA_CR_SUCCESS; + NetUBAsyncEndpoint *ep = new (std::nothrow) NetUBAsyncEndpoint(0, nullptr, nullptr, nullptr); + contextInfo.ubJetty->mUpContext = reinterpret_cast(ep); + MOCKER_CPP(&NetUBAsyncEndpoint::UpdateTargetHbTime).stubs().will(invoke(MockUpdateTargetHbTime)); + worker->RegisterPostedHandler(std::bind(&TestUbRequestReceived, std::placeholders::_1)); + worker->RegisterNewRequestHandler(std::bind(&TestUbRequestPosted, std::placeholders::_1)); + worker->RegisterOneSideDoneHandler(std::bind(&TestUbOneSideDone, std::placeholders::_1)); + worker->ProcessPollingResult(wc, pollCount, lastBrokenQp, lastErrorWcStatus); + + contextInfo.opType = UBOpContextInfo::SGL_READ; + worker->ProcessPollingResult(wc, pollCount, lastBrokenQp, lastErrorWcStatus); + + contextInfo.opType = UBOpContextInfo::RECEIVE; + worker->ProcessPollingResult(wc, pollCount, lastBrokenQp, lastErrorWcStatus); + + free(wc); + wc = nullptr; + if (contextInfo.ubJetty != nullptr) { + delete contextInfo.ubJetty; + contextInfo.ubJetty = nullptr; + } + if (ep != nullptr) { + delete ep; + ep = nullptr; + } +} +} +} +#endif diff --git a/test/unit_test/under_api/openssl/test_openssl_api_dl.cpp b/test/unit_test/under_api/openssl/test_openssl_api_dl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..79ce65b3aa8c6577b1c0f48dee92a925e8b10277 --- /dev/null +++ b/test/unit_test/under_api/openssl/test_openssl_api_dl.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include + +#include "openssl_api_dl.h" + +namespace ock { +namespace hcom { +class TestOpenSSLApiDl : public testing::Test { +public: + TestOpenSSLApiDl(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +TestOpenSSLApiDl::TestOpenSSLApiDl() {} + +void TestOpenSSLApiDl::SetUp() {} + +void TestOpenSSLApiDl::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestOpenSSLApiDl, TestLoadOpensslApiSuccess) +{ + int res = SSLAPI::LoadOpensslAPI(); + EXPECT_EQ(res, 0); +} +} +} \ No newline at end of file diff --git a/test/unit_test/under_api/urma/test_urma_api_dl.cpp b/test/unit_test/under_api/urma/test_urma_api_dl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7efa24b23d0187cb7df27c7fe05d0ce3f6ee71e1 --- /dev/null +++ b/test/unit_test/under_api/urma/test_urma_api_dl.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#ifdef UB_BUILD_ENABLED +#include +#include +#include + +#if defined(TEST_LLT) && defined(MOCK_URMA) +#include "fake_urma.h" +#endif +#include "hcom_log.h" +#include "urma_api_dl.h" + +namespace ock { +namespace hcom { + +constexpr uint32_t NN_NO67 = 67; + +class TestUrmaApiDl : public testing::Test { +public: + TestUrmaApiDl(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +TestUrmaApiDl::TestUrmaApiDl() {} + +void TestUrmaApiDl::SetUp() {} + +void TestUrmaApiDl::TearDown() +{ + GlobalMockObject::verify(); +} + +TEST_F(TestUrmaApiDl, TestLoadUrmaApiDlFail) +{ + int apiNum = NN_NO67; + void *ptr = nullptr; + void *ptr1 = malloc(NN_NO64); + UrmaAPI::gLoaded = false; + + for (int i = 0; i < apiNum; i++) { + MOCKER(dlopen).stubs().will(repeat(ptr1, NN_NO2)).then(returnValue(ptr)); + MOCKER(dlsym).stubs().will(repeat(ptr1, i)).then(returnValue(ptr)); + MOCKER(dlclose).stubs().will(returnValue(0)); + int res = UrmaAPI::LoadUrmaAPI(); + EXPECT_EQ(res, -1); + GlobalMockObject::verify(); + } + free(ptr1); +} + +TEST_F(TestUrmaApiDl, TestLoadUrmaApiDlSuccess) +{ + void *ptr1 = malloc(NN_NO64); + UrmaAPI::gLoaded = false; + MOCKER(dlopen).stubs().will(returnValue(ptr1)); + MOCKER(dlsym).stubs().will(returnValue(ptr1)); + int res = UrmaAPI::LoadUrmaAPI(); + EXPECT_EQ(res, 0); + free(ptr1); +} + +} +} +#endif \ No newline at end of file diff --git a/test/unit_test/under_api/verbs/test_verbs_api_dl.cpp b/test/unit_test/under_api/verbs/test_verbs_api_dl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf3aadc0af6b426915424af6baf322b13034d1a3 --- /dev/null +++ b/test/unit_test/under_api/verbs/test_verbs_api_dl.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + + * ubs-hcom is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ +#include +#include +#include + +#include "hcom_log.h" +#include "verbs_api_dl.h" +#include "../../common/net_common.h" +namespace ock { +namespace hcom { + +class TestVerbsApiDl : public testing::Test { +public: + TestVerbsApiDl(); + virtual void SetUp(void); + virtual void TearDown(void); +}; + +TestVerbsApiDl::TestVerbsApiDl() {} + +void TestVerbsApiDl::SetUp() {} + +void TestVerbsApiDl::TearDown() +{ + GlobalMockObject::verify(); +} + +int g_apiNum = NN_NO26; + +TEST_F(TestVerbsApiDl, TestLoadVerbsApiGLoaded) +{ + VerbsAPI::gLoaded = true; + int res = VerbsAPI::LoadVerbsAPI(); + EXPECT_EQ(res, 0); +} + +TEST_F(TestVerbsApiDl, TestLoadVerbsApiFail) +{ + void *ptr = nullptr; + void *ptr1 = malloc(NN_NO64); + + for (int i = 0; i < g_apiNum; i++) { + VerbsAPI::gLoaded = false; + MOCKER(dlopen).stubs().will(repeat(ptr1, NN_NO2)).then(returnValue(ptr)); + MOCKER(dlsym).stubs().will(repeat(ptr1, i)).then(returnValue(ptr)); + int res = VerbsAPI::LoadVerbsAPI(); + EXPECT_EQ(res, -1); + GlobalMockObject::verify(); + } + free(ptr1); +} +} +} \ No newline at end of file