diff --git a/ccsrc/base/ms_kernels_internal/pyboost/internal_pyboost_runner.cc b/ccsrc/base/ms_kernels_internal/pyboost/internal_pyboost_runner.cc
index c2ada7cda1dd2a16bc3cf84f2183011de2b6171d..307d602d0c19ebf7f6d02363e8dc09468eb1c8f4 100644
--- a/ccsrc/base/ms_kernels_internal/pyboost/internal_pyboost_runner.cc
+++ b/ccsrc/base/ms_kernels_internal/pyboost/internal_pyboost_runner.cc
@@ -42,6 +42,8 @@ void InternalPyboostRunner::GetOrCreateKernel(const TensorList &inputs,
     hash_map_[key] = internal_op_;
   }
 
+  internal_inputs_shape_.clear();
+  internal_outputs_shape_.clear();
   internal_inputs_shape_.resize(inputs.size());
   internal_outputs_shape_.resize(outputs.size());
   TransInternalShapes(&internal_inputs_shape_, inputs, true);
@@ -149,7 +151,7 @@ void InternalPyboostRunner::TransInternalShapes(
     bool is_input) {
   for (size_t i = 0; i < tensorlist.size(); i++) {
     if (!tensorlist[i].is_defined()) {
-      shapelist->at(i) = mindspore::internal::ShapeInfo{0};
+      shapelist->at(i) = mindspore::internal::ShapeInfo{};
       continue;
     }
 
diff --git a/ccsrc/ops/ms_kernels_internal/reshape_and_cache/reshape_and_cache.cc b/ccsrc/ops/ms_kernels_internal/reshape_and_cache/reshape_and_cache.cc
index b6bfc7fa6984c4d30e4d226f6a7a60ecb128d144..4eb88fecd3b89153865b0cb4e43b643f96698f61 100644
--- a/ccsrc/ops/ms_kernels_internal/reshape_and_cache/reshape_and_cache.cc
+++ b/ccsrc/ops/ms_kernels_internal/reshape_and_cache/reshape_and_cache.cc
@@ -59,7 +59,7 @@ constexpr size_t kInputHeadNumIndex = 5;
 constexpr size_t kOutputIndex = 0;
 class CustomReshapeAndCache : public InternalKernelMod {
 public:
-  CustomReshapeAndCache() : InternalKernelMod() {}
+  CustomReshapeAndCache() : InternalKernelMod(), skip_execution_(false) {}
   ~CustomReshapeAndCache() = default;
 
   void InitKernelInputsOutputsIndex() override {
@@ -69,15 +69,64 @@ public:
     kernel_outputs_index_ = {kOutputIndex};
   }
 
+  int Resize(const std::vector<KernelTensor *> &inputs,
+             const std::vector<KernelTensor *> &outputs) override {
+    // Check if any input has shape containing 0
+    for (const auto &input : inputs) {
+      if (input == nullptr)
+        continue;
+      auto shape = input->GetShapeVector();
+      for (const auto &dim : shape) {
+        if (dim == 0) {
+          MS_LOG(INFO) << "ReshapeAndCache: Skipping execution due to zero "
+                          "dimension in input shape: "
+                       << shape;
+          skip_execution_ = true;
+          return KernelMod::Resize(inputs, outputs);  // Skip execution
+        }
+      }
+    }
+
+    skip_execution_ = false;
+    // Call base class implementation
+    return InternalKernelMod::Resize(inputs, outputs);
+  }
+
+  bool Launch(const std::vector<KernelTensor *> &inputs,
+              const std::vector<KernelTensor *> &workspace,
+              const std::vector<KernelTensor *> &outputs,
+              void *stream_ptr) override {
+    // Skip execution if flag is set
+    if (skip_execution_) {
+      return true;  // Skip execution, return success
+    }
+
+    // Call base class implementation
+    return InternalKernelMod::Launch(inputs, workspace, outputs, stream_ptr);
+  }
+
 protected:
   internal::InternalOpPtr
   CreateKernel(const internal::InputsImmutableInfoList &inputs,
                const internal::OutputsImmutableInfoList &outputs,
                const std::vector<KernelTensor *> &ms_inputs,
                const std::vector<KernelTensor *> &ms_outputs) override {
-    return internal::CreateReshapeAndCacheOp(
-        inputs, outputs, internal::kInternalReshapeAndCacheOpName);
+    internal::ReshapeAndCacheParam param;
+    auto head_num = ms_inputs.at(internal::kIndex5);
+    if (head_num->dtype_id() == TypeId::kNumberTypeInt64) {
+      param.head_num =
+          static_cast<int32_t>(head_num->GetValue<int64_t>().value());
+    } else {
+      MS_LOG(EXCEPTION)
+          << "ReshapeAndCache [head_num]'s dtype wrong, expect int64, but got: "
+          << head_num->dtype_id();
+    }
+    return internal::CreateAsdReshapeAndCacheOp(
+        inputs, outputs, param, internal::kInternalAsdReshapeAndCacheOpName);
   }
+
+private:
+  bool skip_execution_;  // Flag to skip execution when shape contains 0
 };
 } // namespace ms_custom_ops
@@ -102,15 +151,17 @@ protected:
   internal::InternalOpPtr
   CreateKernel(const internal::InputsImmutableInfoList &inputs,
                const internal::OutputsImmutableInfoList &outputs) override {
-    return internal::CreateReshapeAndCacheOp(
-        inputs, outputs, internal::kInternalReshapeAndCacheOpName);
+    internal::ReshapeAndCacheParam param;
+    param.head_num = this->head_num_;
+    return internal::CreateAsdReshapeAndCacheOp(
+        inputs, outputs, param, internal::kInternalAsdReshapeAndCacheOpName);
   }
 
 private:
   int32_t head_num_{0};
 };
 
 MS_KERNELS_INTERNAL_NAME_REG(ReshapeAndCache,
-                             internal::kInternalReshapeAndCacheOpName);
+                             internal::kInternalAsdReshapeAndCacheOpName);
 } // namespace ms::pynative
 
 namespace ms_custom_ops {
diff --git a/tests/st/test_custom_reshape_and_cache.py b/tests/st/test_custom_reshape_and_cache.py
index e1545c81288c52c0704c9952366b128dc4de0256..bfad5e31f70626d32401d8381651d539e4b71a5e 100644
--- a/tests/st/test_custom_reshape_and_cache.py
+++ b/tests/st/test_custom_reshape_and_cache.py
@@ -37,7 +37,8 @@ SLOT_SIZE = 64
 BATCH_SIZE = 13
 SEQ_LEN = 3
 NUM_HEADS = 16
-HEAD_DIM = 32
+K_HEAD_DIM = 32
+V_HEAD_DIM = 32
 
 
 class CacheFormat(Enum):
@@ -59,7 +60,6 @@ class ReshapeAndCacheAll(nn.Cell):
     def __init__(self):
         super().__init__()
 
-    @jit
     def construct(self, key, value, key_cache, value_cache, slot_map, head_num=0):
         return ms_custom_ops.reshape_and_cache(
             key, value, key_cache, value_cache, slot_map, head_num)
@@ -137,6 +137,25 @@ class TestConfig:
             context.set_context(jit_config=self.jit_config)
 
 
+class DimensionTestHelper:
+    """Helper class for testing different dimension combinations"""
+
+    @staticmethod
+    def run_with_dimensions(k_head_dim: int, v_head_dim: int, test_func):
+        """Run test with specified dimensions and restore original values"""
+        global K_HEAD_DIM, V_HEAD_DIM
+        original_k_head_dim = K_HEAD_DIM
+        original_v_head_dim = V_HEAD_DIM
+
+        try:
+            K_HEAD_DIM = k_head_dim
+            V_HEAD_DIM = v_head_dim
+            test_func()
+        finally:
+            K_HEAD_DIM = original_k_head_dim
+            V_HEAD_DIM = original_v_head_dim
+
+
 # ===============================
 # test nd format
 # ===============================
@@ -157,40 +176,57 @@ class TestDataGenerator:
         return np.random.choice(np.arange(num_tokens), num_tokens, replace=False).astype(np.int32)
 
     @staticmethod
-    def get_update_shape(kv_dim: int) -> Tuple[Tuple[int, ...], int]:
-        """Get update shape and number of tokens based on dimension"""
+    def get_update_shapes(kv_dim: int, k_head_dim=None, v_head_dim=None) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
+        """Get update shapes for key and value, and number of tokens based on dimension"""
+        # Use provided dimensions or fall back to global constants
+        actual_k_head_dim = k_head_dim if k_head_dim is not None else K_HEAD_DIM
+        actual_v_head_dim = v_head_dim if v_head_dim is not None else V_HEAD_DIM
+
         if kv_dim == 2:
-            update_shape = (BATCH_SIZE * SEQ_LEN, NUM_HEADS * HEAD_DIM)
-            num_tokens = update_shape[0]
+            key_update_shape = (BATCH_SIZE * SEQ_LEN, NUM_HEADS * actual_k_head_dim)
+            value_update_shape = (BATCH_SIZE * SEQ_LEN, NUM_HEADS * actual_v_head_dim)
+            num_tokens = key_update_shape[0]
         elif kv_dim == 3:
-            update_shape = (BATCH_SIZE, SEQ_LEN, NUM_HEADS * HEAD_DIM)
-            num_tokens = update_shape[0] * update_shape[1]
+            key_update_shape = (BATCH_SIZE, SEQ_LEN, NUM_HEADS * actual_k_head_dim)
+            value_update_shape = (BATCH_SIZE, SEQ_LEN, NUM_HEADS * actual_v_head_dim)
+            num_tokens = key_update_shape[0] * key_update_shape[1]
         else:
             raise ValueError(f"Key's dim should be 2 or 3, but got {kv_dim}")
-        return update_shape, num_tokens
+        return key_update_shape, value_update_shape, num_tokens
+
+    @staticmethod
+    def get_update_shape(kv_dim: int, is_key: bool = True, k_head_dim=None, v_head_dim=None) -> Tuple[Tuple[int, ...], int]:
+        """Legacy method for backward compatibility"""
+        key_shape, value_shape, num_tokens = TestDataGenerator.get_update_shapes(kv_dim, k_head_dim, v_head_dim)
+        return (key_shape if is_key else value_shape), num_tokens
 
 
 class NDDataGenerator(TestDataGenerator):
     """Data generator for ND format"""
 
     @staticmethod
-    def create_inputs(dtype: np.dtype, kv_dim: int) -> Tuple[np.ndarray, ...]:
+    def create_inputs(dtype: np.dtype, kv_dim: int, k_head_dim=None, v_head_dim=None) -> Tuple[np.ndarray, ...]:
         """Create ND format inputs"""
-        cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS, HEAD_DIM)
-        update_shape, num_tokens = TestDataGenerator.get_update_shape(kv_dim)
+        # Use provided dimensions or fall back to global constants
+        actual_k_head_dim = k_head_dim if k_head_dim is not None else K_HEAD_DIM
+        actual_v_head_dim = v_head_dim if v_head_dim is not None else V_HEAD_DIM
+
+        key_cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS, actual_k_head_dim)
+        value_cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS, actual_v_head_dim)
+        key_update_shape, value_update_shape, num_tokens = TestDataGenerator.get_update_shapes(kv_dim, k_head_dim, v_head_dim)
 
-        key_update = TestDataGenerator.create_random_data(update_shape, dtype)
-        value_update = TestDataGenerator.create_random_data(update_shape, dtype)
-        key_cache = TestDataGenerator.create_random_data(cache_shape, dtype)
-        value_cache = TestDataGenerator.create_random_data(cache_shape, dtype)
+        key_update = TestDataGenerator.create_random_data(key_update_shape, dtype)
+        value_update = TestDataGenerator.create_random_data(value_update_shape, dtype)
+        key_cache = TestDataGenerator.create_random_data(key_cache_shape, dtype)
+        value_cache = TestDataGenerator.create_random_data(value_cache_shape, dtype)
         slot_map = TestDataGenerator.create_slot_map(num_tokens)
 
         return key_update, value_update, key_cache, value_cache, slot_map
 
 
-def create_nd_inputs(dtype=np.float16, kv_dim=3):
+def create_nd_inputs(dtype=np.float16, kv_dim=3, k_head_dim=None, v_head_dim=None):
     """Legacy function for backward compatibility"""
-    return NDDataGenerator.create_inputs(dtype, kv_dim)
+    return NDDataGenerator.create_inputs(dtype, kv_dim, k_head_dim, v_head_dim)
 
 
 class InferenceEngine:
@@ -206,10 +242,14 @@ class InferenceEngine:
         key_cache_ans = key_cache.copy()
         value_cache_ans = value_cache.copy()
 
-        head = key_cache.shape[2]
-        head_dim = key_cache.shape[3]
-        key_tmp = key_tmp.reshape(-1, head, head_dim)
-        value_tmp = value_tmp.reshape(-1, head, head_dim)
+        # Use different dimensions for key and value
+        key_head = key_cache.shape[2]
+        key_head_dim = key_cache.shape[3]
+        value_head = value_cache.shape[2]
+        value_head_dim = value_cache.shape[3]
+
+        key_tmp = key_tmp.reshape(-1, key_head, key_head_dim)
+        value_tmp = value_tmp.reshape(-1, value_head, value_head_dim)
 
         for i, slot in enumerate(slot_map):
             slot_idx = slot // key_cache.shape[1]
@@ -229,8 +269,9 @@ class InferenceEngine:
         key_cache_ans = key_cache.copy()
         value_cache_ans = value_cache.copy()
 
+        # Use different dimensions for key and value
         key_tmp = key_tmp.reshape(-1, key_cache.shape[2])
-        value_tmp = value_tmp.reshape(-1, key_cache.shape[2])
+        value_tmp = value_tmp.reshape(-1, value_cache.shape[2])
 
         for i, slot in enumerate(slot_map):
             slot_idx = slot // key_cache.shape[1]
@@ -294,7 +335,8 @@ def test_reshape_and_cache_nd_key(np_dtype, kv_dim, run_mode):
     Description: Test ND format with key only.
     Expectation: Assert that results are consistent with numpy.
     """
-    test_config = TestConfig(device_target="Ascend", mode=run_mode)
+    test_config = TestConfig(device_target="Ascend", mode=run_mode,
+                             jit_config={"jit_level": "O0"})
     test_config.apply()
 
     net = ReshapeAndCacheKey()
@@ -310,3 +352,349 @@
     # Run test
     _ = net(ms_k, key_cache=ms_k_cache, slot_map=ms_slot_map)
     TestResultVerifier.verify_results(ms_k_cache, np_k_cache_out, np_dtype)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_ascend910b
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('np_dtype', [np.float16, np.int8, bfloat16])
+@pytest.mark.parametrize('kv_dim', [2, 3])
+@pytest.mark.parametrize('k_head_dim', [32, 64, 128])
+@pytest.mark.parametrize('v_head_dim', [32, 64, 128])
+@pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_reshape_and_cache_nd_key_value_different_dimensions(np_dtype, kv_dim, k_head_dim, v_head_dim, run_mode):
+    """
+    Feature: Test ReshapeAndCache.
+    Description: Test ND format with different K_HEAD_DIM and V_HEAD_DIM combinations.
+    Expectation: Assert that results are consistent with numpy.
+    """
+    def run_test():
+        test_config = TestConfig(device_target="Ascend", mode=run_mode)
+        test_config.apply()
+
+        net = ReshapeAndCacheAll()
+
+        np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nd_inputs(
+            np_dtype, kv_dim, k_head_dim, v_head_dim)
+        np_k_cache_out, np_v_cache_out = nd_inference(
+            np_k, np_v, np_k_cache, np_v_cache, np_slot_map)
+
+        ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs(
+            np_k, np_v, np_k_cache, np_v_cache, np_slot_map)
+
+        # Run test
+        _ = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map)
+        TestResultVerifier.verify_results(ms_k_cache, np_k_cache_out, np_dtype)
+        TestResultVerifier.verify_results(ms_v_cache, np_v_cache_out, np_dtype)
+
+    DimensionTestHelper.run_with_dimensions(k_head_dim, v_head_dim, run_test)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_ascend910b
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('kv_dim', [2, 3])
+@pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_reshape_and_cache_nz_different_key_value_dimensions(kv_dim, run_mode):
+    """
+    Feature: Test ReshapeAndCache.
+    Description: Test NZ format with significantly different K_HEAD_DIM and V_HEAD_DIM.
+    Expectation: Assert that results are consistent with numpy.
+ """ + def run_test(): + # Setup context + jit_config = {"jit_level": "O0"} + test_config = TestConfig(device_target="Ascend", mode=run_mode, jit_config=jit_config) + test_config.apply() + + net = ReshapeAndCacheAll() + + np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nz_inputs( + np.float16, np.float16, kv_dim) + + # Verify that key and value have different shapes + assert np_k.shape != np_v.shape, f"Key and value should have different shapes: {np_k.shape} vs {np_v.shape}" + assert np_k_cache.shape != np_v_cache.shape, f"Key and value cache should have different shapes: {np_k_cache.shape} vs {np_v_cache.shape}" + + np_k_cache_out, np_v_cache_out = nz_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + + # Create MindSpore inputs with appropriate format + if run_mode == context.GRAPH_MODE: + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="FRACTAL_NZ") + else: + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="") + acl_format = 29 + ms_k_cache = ops.auto_generate.format_cast(ms_k_cache, acl_format) + ms_v_cache = ops.auto_generate.format_cast(ms_v_cache, acl_format) + + _ = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, head_num=NUM_HEADS) + + # Extract and verify results + ms_k_cache_np = ms_k_cache.asnumpy() + ms_v_cache_np = ms_v_cache.asnumpy() + + # Extract cached slots for verification + ms_k_output = get_nd_cached_slots(ms_k_cache_np, np_slot_map) + golden_k_output = get_nd_cached_slots(np_k_cache_out, np_slot_map) + + ms_v_output = get_nd_cached_slots(ms_v_cache_np, np_slot_map) + golden_v_output = get_nd_cached_slots(np_v_cache_out, np_slot_map) + + # Verify results + assert np.allclose(ms_k_output, golden_k_output, 0.001, 0.001) + assert np.allclose(ms_v_output, golden_v_output, 0.001, 0.001) + + # Test with very different dimensions + DimensionTestHelper.run_with_dimensions(96, 16, run_test) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize('kv_dim', [2, 3]) +@pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_reshape_and_cache_different_key_value_dimensions(kv_dim, run_mode): + """ + Feature: Test ReshapeAndCache. + Description: Test with significantly different K_HEAD_DIM and V_HEAD_DIM. + Expectation: Assert that results are consistent with numpy. 
+ """ + def run_test(): + test_config = TestConfig(device_target="Ascend", mode=run_mode) + test_config.apply() + + net = ReshapeAndCacheAll() + + # Test with very different dimensions + np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nd_inputs( + np.float16, kv_dim) + + # Verify that key and value have different shapes + assert np_k.shape != np_v.shape, f"Key and value should have different shapes: {np_k.shape} vs {np_v.shape}" + assert np_k_cache.shape != np_v_cache.shape, f"Key and value cache should have different shapes: {np_k_cache.shape} vs {np_v_cache.shape}" + + np_k_cache_out, np_v_cache_out = nd_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + + # Run test + _ = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map) + TestResultVerifier.verify_results(ms_k_cache, np_k_cache_out, np.float16) + TestResultVerifier.verify_results(ms_v_cache, np_v_cache_out, np.float16) + + # Test with very different dimensions + DimensionTestHelper.run_with_dimensions(128, 32, run_test) + + +# =============================== +# test nz format +# =============================== +class NZDataGenerator(TestDataGenerator): + """Data generator for NZ format""" + + @staticmethod + def create_inputs(k_dtype: np.dtype, v_dtype: np.dtype, kv_dim: int, k_head_dim=None, v_head_dim=None) -> Tuple[np.ndarray, ...]: + """Create NZ format inputs""" + # Use provided dimensions or fall back to global constants + actual_k_head_dim = k_head_dim if k_head_dim is not None else K_HEAD_DIM + actual_v_head_dim = v_head_dim if v_head_dim is not None else V_HEAD_DIM + + k_cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS * actual_k_head_dim) + v_cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS * actual_v_head_dim) + key_update_shape, value_update_shape, num_tokens = TestDataGenerator.get_update_shapes(kv_dim, k_head_dim, v_head_dim) + + key_update = TestDataGenerator.create_random_data(key_update_shape, k_dtype) + value_update = TestDataGenerator.create_random_data(value_update_shape, v_dtype) + key_cache = np.zeros(k_cache_shape, dtype=k_dtype) + value_cache = np.zeros(v_cache_shape, dtype=v_dtype) + slot_map = TestDataGenerator.create_slot_map(num_tokens) + + return key_update, value_update, key_cache, value_cache, slot_map + + +def create_nz_inputs(k_dtype=np.float16, v_dtype=np.float16, kv_dim=3, k_head_dim=None, v_head_dim=None): + """Legacy function for backward compatibility""" + return NZDataGenerator.create_inputs(k_dtype, v_dtype, kv_dim, k_head_dim, v_head_dim) + + +def get_nz_cached_slots(cache, slot_map): + ans = [] + tmp = [] + + num_slots = cache.shape[0] + slot_size = cache.shape[1] + hidden_size = cache.shape[2] + + if cache.dtype == np.int8: + cache_shape = (num_slots, hidden_size // 32, slot_size, 32) + else: + cache_shape = (num_slots, hidden_size // 16, slot_size, 16) + cache = cache.reshape(cache_shape) + for i, slot in enumerate(slot_map): + if slot < 0: + continue + slot_idx = slot // slot_size + slot_offset = slot % slot_size + for j in range(cache.shape[1]): + tmp.append(cache[slot_idx][j][slot_offset]) + ans.append(np.concatenate(tmp, axis=0)) + ans = np.concatenate(ans) + return ans + + +def get_nd_cached_slots(cache, slot_map): + ans = [] + for slot in slot_map: + if slot < 0: + continue + slot_idx = slot // SLOT_SIZE + slot_offset = slot % SLOT_SIZE + ans.append(cache[slot_idx][slot_offset]) + ans = np.concatenate(ans) + return ans + + +@pytest.mark.level0 
+@pytest.mark.platform_ascend910b
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('kv_dim', [2, 3])
+@pytest.mark.parametrize('k_dtype', [np.float16, bfloat16, np.int8])
+@pytest.mark.parametrize('v_dtype', [np.float16, bfloat16])
+@pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_reshape_and_cache_nz(k_dtype, v_dtype, kv_dim, run_mode):
+    """
+    Feature: Test ReshapeAndCache.
+    Description: Test NZ format with key and value.
+    Expectation: Assert that results are consistent with numpy.
+    """
+    # Skip invalid combinations
+    if (k_dtype == np.float16 and v_dtype != np.float16) or \
+       (k_dtype == bfloat16 and v_dtype != bfloat16):
+        pytest.skip(f"Invalid combo: {k_dtype} -> {v_dtype}")
+
+    # Setup context
+    jit_config = {"jit_level": "O0"}
+    test_config = TestConfig(device_target="Ascend", mode=run_mode, jit_config=jit_config)
+    test_config.apply()
+
+    net = ReshapeAndCacheAll()
+
+    np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nz_inputs(
+        k_dtype, v_dtype, kv_dim)
+    np_k_cache_out, np_v_cache_out = nz_inference(
+        np_k, np_v, np_k_cache, np_v_cache, np_slot_map)
+
+    # Create MindSpore inputs with appropriate format
+    if run_mode == context.GRAPH_MODE:
+        ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs(
+            np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="FRACTAL_NZ")
+    else:
+        ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs(
+            np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="")
+        acl_format = 29
+        ms_k_cache = ops.auto_generate.format_cast(ms_k_cache, acl_format)
+        ms_v_cache = ops.auto_generate.format_cast(ms_v_cache, acl_format)
+
+    _ = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, head_num=NUM_HEADS)
+
+    # Extract and verify results
+    ms_k_cache_np = ms_k_cache.asnumpy()
+    ms_v_cache_np = ms_v_cache.asnumpy()
+
+    # Handle bfloat16 conversion
+    if k_dtype == bfloat16:
+        ms_k_cache_np = ms_k_cache_np.astype(np.float32)
+        np_k_cache_out = np_k_cache_out.astype(np.float32)
+
+    if v_dtype == bfloat16:
+        ms_v_cache_np = ms_v_cache_np.astype(np.float32)
+        np_v_cache_out = np_v_cache_out.astype(np.float32)
+
+    # Extract cached slots for verification
+    ms_k_output = get_nd_cached_slots(ms_k_cache_np, np_slot_map)
+    golden_k_output = get_nd_cached_slots(np_k_cache_out, np_slot_map)
+
+    ms_v_output = get_nd_cached_slots(ms_v_cache_np, np_slot_map)
+    golden_v_output = get_nd_cached_slots(np_v_cache_out, np_slot_map)
+
+    # Verify results
+    assert np.allclose(ms_k_output, golden_k_output, 0.001, 0.001)
+    assert np.allclose(ms_v_output, golden_v_output, 0.001, 0.001)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_ascend910b
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('kv_dim', [2, 3])
+@pytest.mark.parametrize('k_dtype', [np.float16, bfloat16, np.int8])
+@pytest.mark.parametrize('v_dtype', [np.float16, bfloat16])
+@pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+@pytest.mark.parametrize('k_head_dim', [32, 64, 128])
+@pytest.mark.parametrize('v_head_dim', [32, 64, 128])
+def test_reshape_and_cache_nz_different_dimensions(k_dtype, v_dtype, kv_dim, run_mode, k_head_dim, v_head_dim):
+    """
+    Feature: Test ReshapeAndCache.
+    Description: Test NZ format with different K_HEAD_DIM and V_HEAD_DIM combinations.
+    Expectation: Assert that results are consistent with numpy.
+ """ + # Skip invalid combinations + if (k_dtype == np.float16 and v_dtype != np.float16) or \ + (k_dtype == bfloat16 and v_dtype != bfloat16): + pytest.skip(f"Invalid combo: {k_dtype} -> {v_dtype}") + + def run_test(): + # Setup context + jit_config = {"jit_level": "O0"} + test_config = TestConfig(device_target="Ascend", mode=run_mode, jit_config=jit_config) + test_config.apply() + + net = ReshapeAndCacheAll() + + np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nz_inputs( + k_dtype, v_dtype, kv_dim, k_head_dim, v_head_dim) + np_k_cache_out, np_v_cache_out = nz_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + + # Create MindSpore inputs with appropriate format + if run_mode == context.GRAPH_MODE: + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="FRACTAL_NZ") + else: + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="") + acl_format = 29 + ms_k_cache = ops.auto_generate.format_cast(ms_k_cache, acl_format) + ms_v_cache = ops.auto_generate.format_cast(ms_v_cache, acl_format) + + _ = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, head_num=NUM_HEADS) + + # Extract and verify results + ms_k_cache_np = ms_k_cache.asnumpy() + ms_v_cache_np = ms_v_cache.asnumpy() + + # Handle bfloat16 conversion + if k_dtype == bfloat16: + ms_k_cache_np = ms_k_cache_np.astype(np.float32) + np_k_cache_out = np_k_cache_out.astype(np.float32) + + if v_dtype == bfloat16: + ms_v_cache_np = ms_v_cache_np.astype(np.float32) + np_v_cache_out = np_v_cache_out.astype(np.float32) + + # Extract cached slots for verification + ms_k_output = get_nd_cached_slots(ms_k_cache_np, np_slot_map) + golden_k_output = get_nd_cached_slots(np_k_cache_out, np_slot_map) + + ms_v_output = get_nd_cached_slots(ms_v_cache_np, np_slot_map) + golden_v_output = get_nd_cached_slots(np_v_cache_out, np_slot_map) + + # Verify results + assert np.allclose(ms_k_output, golden_k_output, 0.001, 0.001) + assert np.allclose(ms_v_output, golden_v_output, 0.001, 0.001) + + DimensionTestHelper.run_with_dimensions(k_head_dim, v_head_dim, run_test)