From 9a93109ac7fcd687cc5ada0357455ca52f6e1499 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Tue, 16 Sep 2025 07:19:20 +0800 Subject: [PATCH 1/4] =?UTF-8?q?1.=E6=96=B0=E5=A2=9E=E6=8E=A7=E5=88=B6?= =?UTF-8?q?=E5=BC=80=E5=85=B3=20=E7=94=A8=E4=BA=8E=E6=8E=A7=E5=88=B6spark?= =?UTF-8?q?=E7=9A=84=E8=BE=93=E5=87=BA=E6=98=AF=E5=90=A6=E5=88=86=E7=A6=BB?= =?UTF-8?q?=E5=88=B0stdout=E5=92=8Cstderr=202.=E4=BF=AE=E6=AD=A3=E5=AF=B9?= =?UTF-8?q?=E4=BA=8E--database=E7=9A=84=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/config/common_config.ini | 2 ++ omniadvisor/src/common/constant.py | 1 + .../omniadvisor/interface/hijack_recommend.py | 13 +++++++-- .../service/spark_service/spark_cmd_parser.py | 4 ++- .../service/spark_service/spark_executor.py | 11 +++++-- .../service/spark_service/spark_run.py | 1 + omniadvisor/src/omniadvisor/utils/utils.py | 10 +++++-- .../interface/test_hijack_recommend.py | 12 ++++---- .../spark_service/test_spark_cmd_parser.py | 4 +-- .../spark_service/test_spark_executor.py | 29 ++++++++++--------- .../tests/omniadvisor/utils/test_utils.py | 23 ++++++++------- 11 files changed, 71 insertions(+), 39 deletions(-) diff --git a/omniadvisor/config/common_config.ini b/omniadvisor/config/common_config.ini index cb390b5fe..053d72060 100755 --- a/omniadvisor/config/common_config.ini +++ b/omniadvisor/config/common_config.ini @@ -29,6 +29,8 @@ spark.fetch.trace.timeout=30 spark.fetch.trace.interval=5 # Spark任务执行的超时时间,对比基线的比例 spark.exec.timeout.ratio=10.0 +# 控制spark的输出结果中stdout和stderr是否要分开 +spark.output.merge.switch=False [webpage] # admin 页面显示的时区 diff --git a/omniadvisor/src/common/constant.py b/omniadvisor/src/common/constant.py index 049e73fc3..4a1461077 100644 --- a/omniadvisor/src/common/constant.py +++ b/omniadvisor/src/common/constant.py @@ -145,6 +145,7 @@ class OmniAdvisorConf: spark_fetch_trace_timeout = _common_config.getint('spark', 'spark.fetch.trace.timeout') spark_fetch_trace_interval = _common_config.getint('spark', 'spark.fetch.trace.interval') spark_exec_timeout_ratio = _common_config.getfloat('spark', 'spark.exec.timeout.ratio') + spark_output_merge_switch = _common_config.getboolean('spark', 'spark.output.merge.switch') # webpage页 admin_timezone = _common_config.get('webpage', 'admin.timezone') # django secret key diff --git a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py index 3c8274f27..cd0ab8256 100644 --- a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py +++ b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py @@ -245,6 +245,14 @@ def _any_modify_keywords_in_sql(sql: str) -> bool: return bool(pattern.search(sql)) +def _spark_output_print(output): + if OA_CONF.spark_output_merge_switch: + print(output.stdout, end="", flush=True) + else: + print(output.stdout, end="", flush=True) + print(output.stderr, end="", flush=True, file=sys.stderr) + + def hijack_recommend(argv: list) -> None: """ 对用户的任务进行劫持,使能参数并下发执行任务 @@ -252,6 +260,7 @@ def hijack_recommend(argv: list) -> None: :param argv: Spark任务的执行命令 :return: """ + # 非SUBMIT动作(指kill任务/查询状态/查询版本)的提交直接回退到原生spark-submit脚本执行 不被特性所劫持 SparkCMDParser.validate_submit_arguments(argv) @@ -281,14 +290,14 @@ def hijack_recommend(argv: list) -> None: if exam_record.status == OA_CONF.ExecStatus.success: # 打印结果输出 global_logger.info("Spark execute success, going to print Spark output.") - 
print(output, end="", flush=True) + _spark_output_print(output) # 若执行失败 则判断是否需要拉起安全机制 else: if exec_config != user_config: raise RuntimeError("Spark execute failed, ready to activate security protection mechanism.") else: global_logger.warning("Spark execute failed in user config, going to print Spark output.") - print(output, end="", flush=True) + _spark_output_print(output) def main(): diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_cmd_parser.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_cmd_parser.py index d2bbbd6af..0143aa4dd 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_cmd_parser.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_cmd_parser.py @@ -87,7 +87,9 @@ class SparkCMDParser: _parser.add_argument('-e', '--e', type=str, help='SQL statement to execute.') _parser.add_argument('-f', type=str, help='File containing SQL script.') _parser.add_argument('-i', help='Initialization SQL file') - _parser.add_argument('-d', '--database') + _parser.add_argument('-d', '--define') + _parser.add_argument('--database') + @classmethod def validate_submit_arguments(cls, args: List) -> None: diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py index 7bbcdb679..650bebd94 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py @@ -76,11 +76,18 @@ class SparkExecutor: :raises RuntimeError: 当无法从输出中匹配到 Application ID 时抛出。 :return: application_id, total_time_taken """ + # 对stdout和stderr进行合并 + spark_output_str = "" + # 拼接字符串之前确保输入流非空有数据 + if spark_output.stdout: + spark_output_str += str(spark_output.stdout) + if spark_output.stderr: + spark_output_str += str(spark_output.stderr) # 解析Application id: # 该条匹配模式的文本来自Spark3.3.1源码中yarn/Client.scala文件在Line224所打印的日志 # 若Spark版本迭代时对该日志内容做出修改 则存在application id匹配失败的风险 pattern = r'Client: Submitting application (application_\d+_\d+) to ResourceManager|Application Id: (.*)\n' - search_obj = re.search(pattern, spark_output) + search_obj = re.search(pattern, spark_output_str) if search_obj: application_id = search_obj.group(1) or search_obj.group(2) if search_obj.group(1): @@ -92,7 +99,7 @@ class SparkExecutor: # 解析Time Taken,只解析行首匹配字符串,避免Spark debug log中重复匹配 pattern = r'^Time taken: (.*) seconds' - match_objs = re.findall(pattern, spark_output, re.MULTILINE) + match_objs = re.findall(pattern, spark_output_str, re.MULTILINE) if match_objs: time_taken_list = [float(obj) for obj in match_objs] total_time_taken = sum(time_taken_list) diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py index 3ac03196d..8fc170148 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py @@ -57,6 +57,7 @@ def spark_run(load: Load, config: dict, wait_for_trace: bool = True, # 执行spark任务,并获得执行结果 global_logger.debug('Spark exec cmd is: %s, timeout is: %d', submit_cmd, timeout) exec_result = SparkExecutor.submit_spark_task(cmd_fields=submit_cmd, timeout=timeout) + if exec_result.exitcode == 0: global_logger.info('Spark Load %d execute success', load.id) exam_record_status = OA_CONF.ExecStatus.success diff --git a/omniadvisor/src/omniadvisor/utils/utils.py b/omniadvisor/src/omniadvisor/utils/utils.py index e822a09da..ccc4e0494 100644 --- 
a/omniadvisor/src/omniadvisor/utils/utils.py +++ b/omniadvisor/src/omniadvisor/utils/utils.py @@ -3,16 +3,18 @@ import os import subprocess import uuid from typing import Tuple, List, Dict - +from common.constant import OA_CONF from common.constant import OmniAdvisorConf from common.exceptions import UnknownEncodingError from omniadvisor.utils.logger import global_logger + def run_cmd(cmd_fields: list) -> Tuple[int, str]: """ 在shell终端执行命令,并返回执行结果。 + :param merged_results: 布尔型变量 默认为真 用于控制是否需要将标准输出和标准错误合并 :param cmd_fields: 需要执行的命令字段List :return: 返回一个包含退出状态码和命令输出的元组。 第一个元素是整数类型的退出状态码(0表示成功), @@ -21,12 +23,14 @@ def run_cmd(cmd_fields: list) -> Tuple[int, str]: global_logger.debug(f"Executor system command: {cmd_fields}") kwargs = { 'stdout': subprocess.PIPE, - 'stderr': subprocess.STDOUT, + 'stderr': subprocess.PIPE, 'shell': False, 'text': True } + if OA_CONF.spark_output_merge_switch: + kwargs['stderr'] = subprocess.STDOUT result = subprocess.run(cmd_fields, **kwargs) - return result.returncode, result.stdout + return result.returncode, result def save_trace_data(data: List[Dict[str, str]], data_dir): diff --git a/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py b/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py index a078b85f5..268552fec 100644 --- a/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py +++ b/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py @@ -32,7 +32,7 @@ class TestHijackRecommend: mock_parse_cmd.return_value = (exec_attr, user_config) mock_create_or_update_load.return_value = load mock_get_config.return_value = exec_config - mock_spark_run.return_value = (MagicMock(status="success"), "job output") + mock_spark_run.return_value = (MagicMock(status="success"), MagicMock(stdout="job output", stderr=None)) hijack_recommend(argv) mock_spark_run.assert_called_once_with(load=load, config=exec_config, wait_for_trace=False, exec_in_isolation=False) @@ -58,8 +58,8 @@ class TestHijackRecommend: mock_get_config.return_value = exec_config mock_spark_run.side_effect = [ - (MagicMock(status="fail"), "fail output"), - (MagicMock(status="success"), "safe output") + (MagicMock(status="fail"), MagicMock(stdout=None, stderr="fail output")), + (MagicMock(status="success"), MagicMock(stdout="Safe output", stderr=None)) ] mock_process = MagicMock() mock_multiprocess.return_value = mock_process @@ -87,7 +87,8 @@ class TestHijackRecommend: mock_parse_cmd.return_value = (exec_attr, user_config) mock_create_or_update_load.return_value = load mock_get_config.return_value = exec_config - mock_spark_run.return_value = (MagicMock(status="fail"), "fail output") + mock_spark_output = MagicMock(stdout="job output", stderr=None) + mock_spark_run.return_value = (MagicMock(status="fail"), mock_spark_output) hijack_recommend(argv) # 不进入安全机制的情况call_count的值为1 @@ -113,7 +114,8 @@ class TestHijackRecommend: mock_parse_cmd.return_value = (exec_attr, user_config) mock_create_or_update_load.return_value = load mock_get_config.return_value = exec_config - mock_spark_run.return_value = (MagicMock(status=OA_CONF.ExecStatus.success), "job output") + mock_spark_run.return_value = (MagicMock(status=OA_CONF.ExecStatus.success), + MagicMock(stdout="job output", stderr=None)) mock_process = MagicMock() mock_multiprocess.return_value = mock_process diff --git a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_cmd_parser.py b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_cmd_parser.py index 9de13b04e..36544a236 100644 --- 
a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_cmd_parser.py +++ b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_cmd_parser.py @@ -96,7 +96,7 @@ class TestSparkCMDParser: assert exec_attr["e"] == "SELECT 1" assert exec_attr["f"] == "script.sql" assert exec_attr["i"] == "init.sql" - assert exec_attr["database"] == "testdb" + assert exec_attr["define"] == "testdb" assert exec_attr[_CMD_PRIMARY_RESOURCE] == ["example.jar"] # 场景 5: 空 argv 抛出异常 @@ -201,7 +201,7 @@ class TestSparkCMDParser: --usage-error --help -i init_script.sql - -d my_database + --database my_database --e "SELECT * FROM table" -f query.sql my_app.jar diff --git a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_executor.py b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_executor.py index b4a1eea84..8173ab77e 100644 --- a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_executor.py +++ b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_executor.py @@ -1,6 +1,6 @@ import pytest from common.constant import OA_CONF -from unittest.mock import patch +from unittest.mock import patch, MagicMock from datetime import datetime from omniadvisor.service.spark_service.spark_executor import SparkExecutor @@ -24,11 +24,12 @@ class TestSparkExecutor: self.exitcode = 0 self.application_id = 'application_123456_303209' self.time_taken_list = [23.42, 25.83] - self.output = (f'Client: Submitting application {self.application_id} to ResourceManager\n' - f'XXXXXX\n' - f'Time taken: {self.time_taken_list[0]} seconds\n' - f'Time taken: {self.time_taken_list[1]} seconds\n' - f'Spark Log [INFO]: Time taken: {self.time_taken_list[1]} seconds\n') + self.output = MagicMock(stdout=None, + stderr=f'Client: Submitting application {self.application_id} to ResourceManager\n' + f'XXXXXX\n' + f'Time taken: {self.time_taken_list[0]} seconds\n' + f'Time taken: {self.time_taken_list[1]} seconds\n' + f'Spark Log [INFO]: Time taken: {self.time_taken_list[1]} seconds\n') @patch('omniadvisor.service.spark_service.spark_executor.run_cmd') def test_submit_spark_task(self, mock_run_cmd): @@ -60,22 +61,22 @@ class TestSparkExecutor: assert total_time_taken == OA_CONF.exec_fail_return_runtime # 解析spark output缺失application id - output = (f'Time taken: {self.time_taken_list[0]} seconds\n' - f'Time taken: {self.time_taken_list[1]} seconds\n' - f'Spark Log [INFO]: Time taken: {self.time_taken_list[1]} seconds\n') + self.output.stderr = (f'Time taken: {self.time_taken_list[0]} seconds\n' + f'Time taken: {self.time_taken_list[1]} seconds\n' + f'Spark Log [INFO]: Time taken: {self.time_taken_list[1]} seconds\n') # 验证结果 with pytest.raises(RuntimeError): - _, _ = self.spark_executor._parser_spark_output(output) + _, _ = self.spark_executor._parser_spark_output(self.output) # 解析spark output缺失行首Time taken - output = f'Spark Log [INFO]: Time taken: {self.time_taken_list[1]} seconds\n' + self.output.stderr = f'Spark Log [INFO]: Time taken: {self.time_taken_list[1]} seconds\n' # 验证结果 with pytest.raises(RuntimeError): - _, _ = self.spark_executor._parser_spark_output(output) + _, _ = self.spark_executor._parser_spark_output(self.output) # 解析spark output为空 - output = '' + self.output.stderr = '' # 验证结果 with pytest.raises(RuntimeError): - _, _ = self.spark_executor._parser_spark_output(output) + _, _ = self.spark_executor._parser_spark_output(self.output) diff --git a/omniadvisor/tests/omniadvisor/utils/test_utils.py b/omniadvisor/tests/omniadvisor/utils/test_utils.py index 7a5801753..3034a413a 
100644 --- a/omniadvisor/tests/omniadvisor/utils/test_utils.py +++ b/omniadvisor/tests/omniadvisor/utils/test_utils.py @@ -1,5 +1,5 @@ import subprocess -from unittest.mock import patch,mock_open +from unittest.mock import patch, mock_open, MagicMock from omniadvisor.utils.utils import run_cmd, save_trace_data, float_format # 假设你的函数定义在 'your_module' 中 @@ -7,27 +7,30 @@ class TestRunCmd: @patch('subprocess.run') def test_run_cmd_success(self, mock_run): # 配置mock的返回值 - mock_run.return_value.returncode = 0, + mock_run.return_value.returncode = 0 mock_run.return_value.stdout = "Success output" - - exitcode, output = run_cmd('echo hello') + mock_run.return_value.stderr = None + exitcode, output = run_cmd(['echo', 'hello']) # 验证结果 - assert exitcode[0] == 0 - assert output == "Success output" + assert exitcode == 0 + assert output.stdout == "Success output" + assert output.stderr is None mock_run.assert_called_once() @patch('subprocess.run') def test_run_cmd_failure(self, mock_run): # 配置mock的返回值 - mock_run.return_value.returncode = 1, - mock_run.return_value.stdout = "Error output" + mock_run.return_value.returncode = 1 + mock_run.return_value.stdout = None + mock_run.return_value.stderr = "Error output" exitcode, output = run_cmd('false') # 验证结果 - assert exitcode[0] == 1 - assert output == "Error output" + assert exitcode == 1 + assert output.stdout is None + assert output.stderr == "Error output" class TestTraceDataSaver: -- Gitee From 7b9ebcf2f36ea362ad7f5a78a81380c8194d8ce1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Tue, 16 Sep 2025 19:10:01 +0800 Subject: [PATCH 2/4] cleancode --- omniadvisor/src/omniadvisor/utils/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/omniadvisor/src/omniadvisor/utils/utils.py b/omniadvisor/src/omniadvisor/utils/utils.py index ccc4e0494..760cc5604 100644 --- a/omniadvisor/src/omniadvisor/utils/utils.py +++ b/omniadvisor/src/omniadvisor/utils/utils.py @@ -9,12 +9,10 @@ from common.exceptions import UnknownEncodingError from omniadvisor.utils.logger import global_logger - def run_cmd(cmd_fields: list) -> Tuple[int, str]: """ 在shell终端执行命令,并返回执行结果。 - :param merged_results: 布尔型变量 默认为真 用于控制是否需要将标准输出和标准错误合并 :param cmd_fields: 需要执行的命令字段List :return: 返回一个包含退出状态码和命令输出的元组。 第一个元素是整数类型的退出状态码(0表示成功), -- Gitee From 145a6c69b847a4c83ee8d7a1568e6029db9947ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Tue, 16 Sep 2025 20:21:40 +0800 Subject: [PATCH 3/4] =?UTF-8?q?web=E6=BC=8F=E6=B4=9E=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/src/server/engine/settings.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/omniadvisor/src/server/engine/settings.py b/omniadvisor/src/server/engine/settings.py index 465ca8af1..105eab26b 100644 --- a/omniadvisor/src/server/engine/settings.py +++ b/omniadvisor/src/server/engine/settings.py @@ -55,6 +55,8 @@ CSP_FONT_SRC = ("'self'", "fonts.gstatic.com") CSP_IMG_SRC = ("'self'", "data:") CSP_OBJECT_SRC = ("'none'",) CSP_FRAME_ANCESTORS = ("'none'",) +# CSRF Token存储在session中,而不是单独的cookie +CSRF_USE_SESSIONS = True ALLOWED_HOSTS = ['*'] -- Gitee From 649e5ff9dfdc6d5483699a42ee12b73cc896d845 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Tue, 16 Sep 2025 20:24:04 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E4=BF=9D=E7=95=99html?= 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omniadvisor/compile.py b/omniadvisor/compile.py index fdea3742d..ebd8f4609 100644 --- a/omniadvisor/compile.py +++ b/omniadvisor/compile.py @@ -28,7 +28,7 @@ OUTPUT_ROOT_DIR = os.path.join(PROJECT_DIR, COMPILE_DEST_ROOT_NAME) OUTPUT_SRC_DIR = os.path.join(OUTPUT_ROOT_DIR, COMPILE_FROM_NAME) # 原样保留的文件 -RETAIN_FILES = ['*.so', '*.json', '*.cfg', '*.pyc'] +RETAIN_FILES = ['*.so', '*.json', '*.cfg', '*.pyc', '*.html'] # 这个是产生的临时的编译目录 TEMP_BUILD_DIR = os.path.join(OUTPUT_SRC_DIR, 'build') -- Gitee
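
Appendix (illustration only, not part of the patch series): the behaviour change around spark.output.merge.switch in PATCH 1/4 is easiest to see in isolation. The sketch below mirrors the patched run_cmd and _spark_output_print; SPARK_OUTPUT_MERGE_SWITCH and the __main__ demo command are stand-ins invented for this sketch, while the real code reads OA_CONF.spark_output_merge_switch from common_config.ini.

import subprocess
import sys

# Stand-in for OA_CONF.spark_output_merge_switch (common_config.ini ships it as False).
SPARK_OUTPUT_MERGE_SWITCH = False


def run_cmd(cmd_fields: list):
    """Run a command; keep stderr separate unless the merge switch is on."""
    kwargs = {
        'stdout': subprocess.PIPE,
        'stderr': subprocess.PIPE,
        'shell': False,
        'text': True,
    }
    if SPARK_OUTPUT_MERGE_SWITCH:
        # Legacy behaviour: fold stderr into stdout.
        kwargs['stderr'] = subprocess.STDOUT
    result = subprocess.run(cmd_fields, **kwargs)
    # Return the whole CompletedProcess so callers can reach stdout and stderr separately.
    return result.returncode, result


def spark_output_print(output):
    """Replay the captured streams on their original channels, as _spark_output_print does."""
    if SPARK_OUTPUT_MERGE_SWITCH:
        print(output.stdout, end="", flush=True)
    else:
        print(output.stdout, end="", flush=True)
        print(output.stderr, end="", flush=True, file=sys.stderr)


if __name__ == "__main__":
    exitcode, output = run_cmd([sys.executable, "-c",
                                "import sys; print('result row'); print('spark log line', file=sys.stderr)"])
    spark_output_print(output)

With the switch off (the shipped default), query results stay on stdout and Spark's log lines stay on stderr, which is also why _parser_spark_output now concatenates output.stdout and output.stderr before matching the application id and the "Time taken" lines.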
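
A second, equally standalone sketch for the option-parsing fix in the same patch: -d/--define and --database are now registered as separate argparse options instead of aliases sharing one destination. This toy parser is not the project's SparkCMDParser, which registers many more spark-submit/spark-sql options; it only shows the corrected mapping.

import argparse

parser = argparse.ArgumentParser(prog='spark-sql-options-sketch')
# After the fix: '-d' maps to args.define, '--database' maps to args.database.
parser.add_argument('-d', '--define', help='Variable substitution, e.g. -d A=B')
parser.add_argument('--database', help='Database to use for the session')

args = parser.parse_args(['-d', 'A=B', '--database', 'testdb'])
print(args.define)    # A=B
print(args.database)  # testdb

Before the change, -d and --database landed in the same attribute, so Hive-style variable substitution and the database selection could silently overwrite each other; the updated test_spark_cmd_parser cases now exercise the two spellings separately.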