diff --git a/cpp/core/api/common/state/ValueState.h b/cpp/core/api/common/state/ValueState.h index 180ae9ef1c7da36e79aaaea10221164e829cfe3c..c1d504e8945b6f907f4ef55d00a5424e791a9dad 100644 --- a/cpp/core/api/common/state/ValueState.h +++ b/cpp/core/api/common/state/ValueState.h @@ -21,6 +21,7 @@ public: ~ValueState() override = default; virtual T value() = 0; virtual void update(const T &value, bool copyKey = false) = 0; + virtual void updateByBatch(std::unordered_map& pendingUpdates) {}; }; using DataStreamValueState = ValueState; diff --git a/cpp/runtime/state/rocksdb/RocksdbMapState.h b/cpp/runtime/state/rocksdb/RocksdbMapState.h index e916ab1db4c5d87e22d4506fe463d7ec2cea2dba..264946e5e4fd500acb3602da2724cb2ebab60f51 100644 --- a/cpp/runtime/state/rocksdb/RocksdbMapState.h +++ b/cpp/runtime/state/rocksdb/RocksdbMapState.h @@ -53,6 +53,7 @@ public: void put(const UK &userKey, const UV &userValue) override; void putByBatch(const K &key,const std::unordered_map &dataToAdd); void putByBatch(std::unordered_map> &dataToAdd); + void putByBatch(std::vector>>> &dataToAdd); void remove(const UK &userKey) override; void removeByBatch(std::unordered_map> &dataToRemove); @@ -86,6 +87,7 @@ public: void createTable(ROCKSDB_NAMESPACE::DB* db, std::string cfName, std::unordered_map> *kvStateInformation); + std::shared_ptr getRawBytes(UK& uk); private: RocksdbMapStateTable *stateTable; @@ -150,6 +152,12 @@ Object* RocksdbMapState::Get(Object* userKey) } } +template +std::shared_ptr RocksdbMapState::getRawBytes(UK& uk) +{ + return stateTable->getRawBytes(currentNamespace,uk); +} + template void RocksdbMapState::GetByBatch(std::unordered_map> &dataToGet,std::unordered_map,UV> &result) @@ -192,6 +200,12 @@ void RocksdbMapState::putByBatch(std::unordered_mapputByBatch(dataToAdd); } +template +void RocksdbMapState::putByBatch(std::vector>>> &dataToAdd) +{ + stateTable->putByBatch(dataToAdd); +} + template void RocksdbMapState::remove(const UK &userKey) { diff --git a/cpp/runtime/state/rocksdb/RocksdbMapStateTable.h b/cpp/runtime/state/rocksdb/RocksdbMapStateTable.h index e41efe79739016359649b515d1c5b106b96f3053..c6071dfb1840719a6ff1bb1846d8e60f1aa2bec0 100644 --- a/cpp/runtime/state/rocksdb/RocksdbMapStateTable.h +++ b/cpp/runtime/state/rocksdb/RocksdbMapStateTable.h @@ -38,6 +38,8 @@ #include "utils/VectorBatchDeserializationUtils.h" #include "utils/VectorBatchSerializationUtils.h" #include "state/RocksDbKvStateInfo.h" +#include "runtime/state/DefaultConfigurableOptionsFactory.h" + /* S is the value used in the State, * like RowData* for HeapValueState, @@ -126,6 +128,23 @@ public: } }; + std::shared_ptr getRawBytes(const N &nameSpace,UK& uk) + { + // outputSerializer free need after Get called + DataOutputSerializer outputSerializer; + OutputBufferStatus outputBufferStatus; + outputSerializer.setBackendBuffer(&outputBufferStatus); + + ROCKSDB_NAMESPACE::Slice sliceKey = serializerKeyAndUserKey(outputSerializer,uk); + + ROCKSDB_NAMESPACE::PinnableSlice pinSlice; + ROCKSDB_NAMESPACE::Status s = rocksDb->Get(readOptions, table, sliceKey, &pinSlice); + if (!s.ok() || pinSlice.size() == 0) { + return nullptr; + } + return std::make_shared(pinSlice.data(), pinSlice.size()); + }; + void GetByBatch(std::unordered_map> &dataToGet,std::unordered_map,UV> &result) { @@ -278,8 +297,32 @@ public: putBatch.Put(table,sliceKey,sliceValue); } } + writeOptions.memtable_insert_hint_per_batch = true; auto ret = rocksDb->Write(writeOptions, &putBatch); + } + + void putByBatch(std::vector>>>& dataToAdd) + { + ROCKSDB_NAMESPACE::WriteBatch putBatch; + for (auto& item : dataToAdd) { + K key = std::get<0>(*item); + UK ukey = std::get<1>(*item); + std::shared_ptr strPtr = std::get<2>(*item); + // outputSerializer free need after Put called + DataOutputSerializer outputSerializer; + OutputBufferStatus outputBufferStatus; + outputSerializer.setBackendBuffer(&outputBufferStatus); + ROCKSDB_NAMESPACE::Slice sliceKey = serializerKeyAndUserKey(outputSerializer,ukey); + // valueOutputSerializer free need after Put called + DataOutputSerializer valueOutputSerializer; + OutputBufferStatus valueOutputBufferStatus; + valueOutputSerializer.setBackendBuffer(&valueOutputBufferStatus); + ROCKSDB_NAMESPACE::Slice sliceValue = serializerValue(valueOutputSerializer, strPtr->data()); + putBatch.Put(table, sliceKey, sliceValue); + } + writeOptions.memtable_insert_hint_per_batch = true; + auto ret = rocksDb->Write(writeOptions, &putBatch); } void remove(const N &nameSpace, const UK &userKey) @@ -833,6 +876,22 @@ protected: outputSerializer.length()); } + ROCKSDB_NAMESPACE::Slice serializerKey(DataOutputSerializer &outputSerializer) + { + auto currentKey = keyContext->getCurrentKey(); + + // 序列化key, userKey + + if constexpr (std::is_pointer_v) { + getKeySerializer()->serialize(currentKey, outputSerializer); + } else { + getKeySerializer()->serialize(¤tKey, outputSerializer); + } + + return ROCKSDB_NAMESPACE::Slice(reinterpret_cast(outputSerializer.getData()), + outputSerializer.length()); + } + ROCKSDB_NAMESPACE::Slice serializerValue(DataOutputSerializer &valueOutputSerializer, UV userValue) { // value序列化 diff --git a/cpp/runtime/state/rocksdb/RocksdbStateTable.h b/cpp/runtime/state/rocksdb/RocksdbStateTable.h index a462b053271001951dfe6d9483524d81fc6db03a..90bea17a5d70f48b625fb52855d77a0f7c897507 100644 --- a/cpp/runtime/state/rocksdb/RocksdbStateTable.h +++ b/cpp/runtime/state/rocksdb/RocksdbStateTable.h @@ -169,6 +169,45 @@ public: } } + void putByBatch(N &nameSpace, std::unordered_map& pendingUpdates) + { + // 存入 + ROCKSDB_NAMESPACE::WriteBatch putBatch; + for (auto& entry : pendingUpdates) { + RowData* key = entry.first; + S state = entry.second; + + keyContext->setCurrentKey(key); + LOG("RocksDB put"); + DataOutputSerializer outputSerializer; + OutputBufferStatus outputBufferStatus; + outputSerializer.setBackendBuffer(&outputBufferStatus); + ROCKSDB_NAMESPACE::Slice sliceKey = GetKeyNameSpaceSlice(outputSerializer, nameSpace); + + // value序列化 + TypeSerializer *vSerializer = getStateSerializer(); + DataOutputSerializer valueOutputSerializer; + OutputBufferStatus valueOutputBufferStatus; + valueOutputSerializer.setBackendBuffer(&valueOutputBufferStatus); + + S tmpS = state; + + if constexpr (std::is_pointer_v) { + vSerializer->serialize(tmpS, valueOutputSerializer); + } else { + vSerializer->serialize(&tmpS, valueOutputSerializer); + } + + ROCKSDB_NAMESPACE::Slice sliceValue(reinterpret_cast(valueOutputSerializer.getData()), + valueOutputSerializer.length()); + putBatch.Put(table,sliceKey,sliceValue); + } + auto s3 = rocksDb->Write(writeOptions, &putBatch); + + if (s3.ok()) { + } + }; + void put(N &nameSpace, const S &state) { // 存入 diff --git a/cpp/runtime/state/rocksdb/RocksdbValueState.h b/cpp/runtime/state/rocksdb/RocksdbValueState.h index 4aa54b8ef2a7b3322d6ff2d801437047a23f8568..963577408bd4a04f63f6cb7cf95c8fe3c5d45b3f 100644 --- a/cpp/runtime/state/rocksdb/RocksdbValueState.h +++ b/cpp/runtime/state/rocksdb/RocksdbValueState.h @@ -77,6 +77,7 @@ public: void addVectorBatch(omnistream::VectorBatch *vectorBatch) override; omnistream::VectorBatch *getVectorBatch(int batchId) override; long getVectorBatchesSize() override; + void updateByBatch(std::unordered_map& pendingUpdates); private: RocksdbStateTable *stateTable; @@ -184,4 +185,11 @@ void RocksdbValueState::clear() stateTable->clear(currentNamespace); } +template +void RocksdbValueState::updateByBatch(std::unordered_map& pendingUpdates) +{ + stateTable->putByBatch(currentNamespace,pendingUpdates); +} + + #endif // OMNISTREAM_ROCKSDBVALUESTATE_H diff --git a/cpp/table/runtime/dataview/StateMapView.h b/cpp/table/runtime/dataview/StateMapView.h index b18a71878bbd02f28f8a2b2d5384fb44b8aa515a..bbb29c4e2e2de7b0a578c765675e4bffe8637aba 100644 --- a/cpp/table/runtime/dataview/StateMapView.h +++ b/cpp/table/runtime/dataview/StateMapView.h @@ -10,10 +10,16 @@ */ #ifndef FLINK_TNEL_STATEMAPVIEW_H #define FLINK_TNEL_STATEMAPVIEW_H + +#include #include "MapView.h" #include "StateDataView.h" #include "core/api/common/state/ValueState.h" #include "core/api/common/state/MapState.h" +#include "../runtime/state/rocksdb/RocksdbMapState.h" +using json = nlohmann::json; + + template class StateMapView : public MapView, public StateDataView { @@ -36,6 +42,39 @@ public: void put(const std::optional& key, const EV& value) override { key == std::nullopt ? getNullState()->update(value) : getMapState()->put(*key, value); }; void remove(const std::optional& key) { key == std::nullopt ? getNullState()->clear() : getMapState()->remove(*key); }; void contains(const std::optional& key) { return key == std::nullopt ? getNullState()->value() != nullptr : getMapState()->contains(*key); }; + emhash7::HashMap *entries() + { + return getMapState()->entries(); + }; + void putByBatch(std::vector>>> & batchData) + { + auto rocksDBMap = dynamic_cast *>(getMapState()); + if (rocksDBMap) { + rocksDBMap->putByBatch(batchData); + } + } + + std::shared_ptr getInnerMap(EK& ek) + { + auto rocksDBMap = dynamic_cast *>(getMapState()); + if (rocksDBMap) { + std::shared_ptr rawString= rocksDBMap->getRawBytes(ek); + try { + if (rawString == nullptr || rawString->empty()) { + return nullptr; + } + return std::make_shared(json::parse(*rawString)); + } catch (const json::parse_error& e) { + LOG("parse json error............"); + } + } + return nullptr; + } + + void cleanup() + { + getMapState()->clearEntriesCache(); + } protected: virtual ValueState *getNullState() = 0; virtual MapState *getMapState() = 0; diff --git a/cpp/table/runtime/generated/AggsHandleFunction.h b/cpp/table/runtime/generated/AggsHandleFunction.h index 7e740e3269b6205d96bc41a4fdf453c55e214ad5..2bdc78b380d9d987c60ffbe149cd35066c087799 100644 --- a/cpp/table/runtime/generated/AggsHandleFunction.h +++ b/cpp/table/runtime/generated/AggsHandleFunction.h @@ -35,6 +35,11 @@ public: virtual void getAccumulators(BinaryRowData* accumulators) = 0; virtual void cleanup() = 0; virtual void close() = 0; + virtual void setCurrentGroupKey(RowData* key) {}; + virtual void setBackend(int backend) {this->backend = backend;} + virtual void updateInnerState() {}; +protected: + int backend=0; // 0: memory, 1: bss, 2: rocksdb }; #endif // FLINK_TNEL_AGGS_HANDLE_FUNCTION_H \ No newline at end of file diff --git a/cpp/table/runtime/generated/function/CountDistinctFunction.cpp b/cpp/table/runtime/generated/function/CountDistinctFunction.cpp index 6e5899ba22bad2c72f364b2bab86fd9c968d5491..1bc347ac27cee87fccc602b9e5722e6635b0ddf8 100644 --- a/cpp/table/runtime/generated/function/CountDistinctFunction.cpp +++ b/cpp/table/runtime/generated/function/CountDistinctFunction.cpp @@ -8,10 +8,11 @@ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. * See the Mulan PSL v2 for more details. */ - +#include #include "CountDistinctFunction.h" #include "typeutils/InternalSerializers.h" #include "runtime/dataview/PerKeyStateDataViewStore.h" +using json = nlohmann::json; bool CountDistinctFunction::equaliser(BinaryRowData *r1, BinaryRowData *r2) { @@ -105,12 +106,86 @@ void CountDistinctFunction::accumulate(RowData *accInput) LOG("Accumulate. Count: " << aggCount << " countIsNull: " << valueIsNull) } -void CountDistinctFunction::accumulate(omnistream::VectorBatch *input, const std::vector &indices) +void CountDistinctFunction::accumulate(omnistream::VectorBatch* input, const std::vector& indices) { + if (backend == 2) { + this->accumulateInRocksDB(input, indices); + } else { + auto columnData = input->Get(aggIdx); + const bool hasFilterCol = hasFilter; + const auto filterData = hasFilterCol + ? reinterpret_cast*>(input->Get(filterIndex)) + : nullptr; + for (int rowIndex : indices) { + bool shouldDoAccumulate = true; + if (hasFilterCol) { + bool isFilterNull = filterData->IsNull(rowIndex); + shouldDoAccumulate = !isFilterNull && filterData->GetValue(rowIndex); + } + if (!shouldDoAccumulate) continue; + bool isFieldNull = columnData->IsNull(rowIndex); + long fieldValue; + switch (typeId) { + case DataTypeId::OMNI_INT: { + fieldValue = isFieldNull + ? -1L + : dynamic_cast*>(columnData)->GetValue(rowIndex); + break; + } + case DataTypeId::OMNI_LONG: { + fieldValue = isFieldNull + ? -1L + : dynamic_cast*>(columnData)->GetValue(rowIndex); + break; + } + default: + LOG("Data type is not supported."); + throw std::runtime_error("Data type is not supported."); + } + std::optional distinctKey = isFieldNull ? std::nullopt : std::optional{fieldValue}; + std::optional value = distinctMapView->get(distinctKey); + long trueValue = value.has_value() ? value.value() : 0L; + uint64_t uValue = static_cast(trueValue); + long existed = uValue & (1 << 0); + if (existed == 0) { + uValue = uValue | (1 << 0); + trueValue = static_cast(uValue); + if (!isFieldNull) { + if (!valueIsNull) { + aggCount++; + } + else { + aggCount = 1L; + valueIsNull = false; + } + } + distinctMapView->put(distinctKey, trueValue); + } + } + LOG("Accumulate. Count: " << aggCount << " valueIsNull: " << valueIsNull); + } +} + + +void CountDistinctFunction::accumulateInRocksDB(omnistream::VectorBatch* input, const std::vector& indices) +{ + std::shared_ptr jsonData = distinctMapView->getInnerMap(stateKey); + std::unordered_map distinctCache; + bool needUpdate = false; + + if (jsonData != nullptr) { + for (auto& item : jsonData->items()) { + long key = std::stol(item.key()); // convert JSON key string → long + long value = item.value().get(); // convert JSON number → long + distinctCache.emplace(key, value); + } + } + auto columnData = input->Get(aggIdx); const bool hasFilterCol = hasFilter; const auto filterData = hasFilterCol - ? reinterpret_cast *>(input->Get(filterIndex)) : nullptr; + ? reinterpret_cast*>(input->Get(filterIndex)) + : nullptr; for (int rowIndex : indices) { bool shouldDoAccumulate = true; if (hasFilterCol) { @@ -121,43 +196,53 @@ void CountDistinctFunction::accumulate(omnistream::VectorBatch *input, const std bool isFieldNull = columnData->IsNull(rowIndex); long fieldValue; switch (typeId) { - case DataTypeId::OMNI_INT: { - fieldValue = isFieldNull - ? -1L : dynamic_cast *>(columnData)->GetValue(rowIndex); - break; - } - case DataTypeId::OMNI_LONG: { - fieldValue = isFieldNull - ? -1L : dynamic_cast *>(columnData)->GetValue(rowIndex); - break; - } - default: - LOG("Data type is not supported."); - throw std::runtime_error("Data type is not supported."); + case DataTypeId::OMNI_INT: { + fieldValue = isFieldNull + ? -1L + : dynamic_cast*>(columnData)->GetValue(rowIndex); + break; + } + case DataTypeId::OMNI_LONG: { + fieldValue = isFieldNull + ? -1L + : dynamic_cast*>(columnData)->GetValue(rowIndex); + break; + } + default: + LOG("Data type is not supported."); + throw std::runtime_error("Data type is not supported."); } std::optional distinctKey = isFieldNull ? std::nullopt : std::optional{fieldValue}; - std::optional value = distinctMapView->get(distinctKey); - long trueValue = value.has_value() ? value.value() : 0L; - uint64_t uValue = static_cast(trueValue); - long existed = uValue & (1 << 0); - if (existed == 0) { - uValue = uValue | (1 << 0); - trueValue = static_cast(uValue); - if (!isFieldNull) { - if (!valueIsNull) { - aggCount++; - } else { - aggCount = 1L; - valueIsNull = false; - } + + if (distinctKey.has_value()) { + if (auto it = distinctCache.find(distinctKey.value()); it != distinctCache.end()) { + continue; } - distinctMapView->put(distinctKey, trueValue); + distinctCache.emplace(distinctKey.value(), 1L); + needUpdate = true; } + + if (!isFieldNull) { + if (!valueIsNull) { + aggCount++; + } + else { + aggCount = 1L; + valueIsNull = false; + } + } + } + if (needUpdate) { + json needUpdateKeysJson = distinctCache; + auto dumpedPtr = std::make_shared(needUpdateKeysJson.dump()); + + keyAndValuesTuples.push_back(std::make_shared>> + (std::make_tuple(this->currentGroupKey,stateKey, dumpedPtr))); } + LOG("Accumulate. Count: " << aggCount << " valueIsNull: " << valueIsNull); } - void CountDistinctFunction::retract(RowData *retractInput) { } @@ -219,3 +304,16 @@ void CountDistinctFunction::getValue(BinaryRowData *newAggValue) } LOG("Get value: " << aggCount) } + +void CountDistinctFunction::setCurrentGroupKey(RowData* key) +{ + this->currentGroupKey = key; +} + +void CountDistinctFunction::updateInnerState() +{ + this->distinctMapView->putByBatch(keyAndValuesTuples); + keyAndValuesTuples.clear(); + this->distinctMapView->cleanup(); + +} diff --git a/cpp/table/runtime/generated/function/CountDistinctFunction.h b/cpp/table/runtime/generated/function/CountDistinctFunction.h index 966e7509988d89ab60327cb88f844727927151f0..6354290498be28c40f7012f7fd48aa5664343b8e 100644 --- a/cpp/table/runtime/generated/function/CountDistinctFunction.h +++ b/cpp/table/runtime/generated/function/CountDistinctFunction.h @@ -25,6 +25,7 @@ public: { hasFilter = filterIndex != -1; typeId = LogicalType::flinkTypeToOmniTypeId(inputType); + stateKey = (aggIdx << 16) | (filterIndex == -1 ? 0 : filterIndex); } void setWindowSize(int windowSize) override {}; @@ -42,6 +43,10 @@ public: void getValue(BinaryRowData *aggValue) override; void cleanup() override {}; void close() override {}; + void setCurrentGroupKey(RowData* key) override; + void accumulateInRocksDB(omnistream::VectorBatch *input, const std::vector &indices); + void updateInnerState(); + private: long aggCount; @@ -55,6 +60,11 @@ private: omniruntime::type::DataTypeId typeId; StateDataViewStore *store; KeyedStateMapViewWithKeysNullable *distinctMapView; + RowData * currentGroupKey; + // std::unordered_map> groupKeyToDistinctSetMap; + // std::unordered_map> groupKeyToDistinctSetMap; + std::vector>>> keyAndValuesTuples; + long stateKey; }; diff --git a/cpp/table/runtime/operators/aggregate/GroupAggFunction.cpp b/cpp/table/runtime/operators/aggregate/GroupAggFunction.cpp index 9dab0ed94ecd4e40966d64401732b19593da72b7..d336ae09c13315452f95ab8d8dd32ee9d71f6420 100644 --- a/cpp/table/runtime/operators/aggregate/GroupAggFunction.cpp +++ b/cpp/table/runtime/operators/aggregate/GroupAggFunction.cpp @@ -122,6 +122,11 @@ void GroupAggFunction::open(const Configuration& parameters) // This kind of specific template type should all be solved by an if-else based on stateDescription accState = static_cast *>(getRuntimeContext())->getState(accDesc); + + if (dynamic_cast *>(accState)) { + this->backend=2; + } + int accStartingIndex = 0; int aggValueIndex = 0; InitAggFunctions(accStartingIndex, aggValueIndex); @@ -237,7 +242,7 @@ void GroupAggFunction::processElement(RowData* input, Context* ctx, TimestampedC binRowAcc->setNullAt(i); } // Flink don't do update here, it updates it in if (!recordCounter->recordCountIsZero(accumulators)){} line 146 - static_cast *>(accState)->update(accumulators); + // static_cast *>(accState)->update(accumulators); } else { firstRow = false; } @@ -341,8 +346,7 @@ void GroupAggFunction::processBatch(omnistream::VectorBatch *input, KeyedProcess for (auto& pair : keyToRowIndices) { bool isEqual = true; RowData* currentKey = pair.first; - RowData* copyKey = currentKey->copy(); - ctx.setCurrentKey(copyKey); + ctx.setCurrentKey(currentKey); std::vector& groupInfo = pair.second; RowData* accumulators = accState->value(); bool firstRow = accumulators == nullptr; @@ -355,12 +359,25 @@ void GroupAggFunction::processBatch(omnistream::VectorBatch *input, KeyedProcess } for (auto& func : functions) { func->setAccumulators(accumulators); + func->setCurrentGroupKey(currentKey); + func->setBackend(backend); } processBatchColumnar(input, groupInfo, accumulators); LOG("functions loop aggregateCallsCount end") AssembleResultForBatch(accumulators, isEqual, firstRow, currentKey, resultKeys, resultValues, resultRowKinds); - delete copyKey; } + + if (backend == 2) { + UpdateAccumulatorsInRocksDB(pendingUpdates); + for (auto& pair : pendingUpdates) { + delete pair.second; + } + pendingUpdates.clear(); + for (auto& func : functions) { + func->updateInnerState(); + } + } + ClearEnv(input, resultKeys, resultValues, resultRowKinds, out, keyToRowIndices); LOG("GroupAggFunction processBatch end") } @@ -514,7 +531,7 @@ bool GroupAggFunction::FirstRowAccumulate(std::vector& groupInfo, RowDa func->createAccumulators(dynamic_cast(accumulators)); } // Flink don't do update here, it updates it in if (!recordCounter->recordCountIsZero(accumulators)){} - static_cast *>(accState)->update(accumulators); + // static_cast *>(accState)->update(accumulators); return true; } @@ -545,7 +562,11 @@ void GroupAggFunction::AssembleResultForBatch(RowData* accumulators, bool isEqua std::vector& resultRowKinds) { if (!recordCounter->recordCountIsZero(accumulators)) { - static_cast *>(accState)->update(accumulators); + if (backend == 2) { + pendingUpdates.emplace(currentKey, accumulators); + }else { + accState->update(accumulators); + } // Flink update accumulators in state here. But since we directly take the RowData* and updates in getAccumulator, the value in statebackend is already updated! if (!firstRow) { if (EndAssemble(isEqual)) { @@ -631,3 +652,8 @@ bool GroupAggFunction::EndAssemble(bool isEqual) } return false; } + +void GroupAggFunction::UpdateAccumulatorsInRocksDB(std::unordered_map& pendingUpdates) +{ + accState->updateByBatch(pendingUpdates); +} diff --git a/cpp/table/runtime/operators/aggregate/GroupAggFunction.h b/cpp/table/runtime/operators/aggregate/GroupAggFunction.h index b3c104c572afcb7ca9ee8a3ec124d8b7bfa3e6c7..d346828752d857b361f53d91150db635fb9e3d28 100644 --- a/cpp/table/runtime/operators/aggregate/GroupAggFunction.h +++ b/cpp/table/runtime/operators/aggregate/GroupAggFunction.h @@ -80,6 +80,7 @@ public: void FillRowIndices(omnistream::VectorBatch *input, std::unordered_map>& keyToRowIndices, int rowCount); bool EndAssemble(bool isEqual); + void UpdateAccumulatorsInRocksDB(std::unordered_map& pendingUpdates); private: std::vector accTypes; @@ -117,6 +118,9 @@ private: std::vector handleInputTypes(); std::map handleDistinctInfo(); void deleteRowData(vector &rowVector); + int backend=0; //0: memory, 1: bss, 2: rocksdb + //rocksdb update container + std::unordered_map pendingUpdates; }; #endif // FLINK_TNEL_GROUP_AGG_FUNCTION_H diff --git a/cpp/table/utils/VectorBatchDeserializationUtils.h b/cpp/table/utils/VectorBatchDeserializationUtils.h index 6a9e876d8a3a101a2138ec347c0dba44580aa85e..e83fe839025736d44908e6778e9e244b30269c3c 100644 --- a/cpp/table/utils/VectorBatchDeserializationUtils.h +++ b/cpp/table/utils/VectorBatchDeserializationUtils.h @@ -179,7 +179,7 @@ public: int32_t size) { auto nullData = UnsafeBaseVector::GetNulls(baseVector); - auto nullByteSize = omniruntime::vec::NullsBuffer::CalculateNbytes(size); + auto nullByteSize = omniruntime::vec::NullsBuffer::CalculateNbytes(size*8); memcpy_s(nullData, sizeof(bool) * size, buffer, sizeof(bool) * size); buffer += nullByteSize; } @@ -267,7 +267,7 @@ public: memcpy_s(&stringBodySize, sizeof(int32_t), buffer, sizeof(int32_t)); buffer += sizeof(int32_t); - auto nullByteSize = omniruntime::vec::NullsBuffer::CalculateNbytes(rowCnt); + auto nullByteSize = omniruntime::vec::NullsBuffer::CalculateNbytes(rowCnt*8); std::shared_ptr> nullsBuffer = std::make_shared>(nullByteSize); memcpy(nullsBuffer->GetBuffer(), buffer, diff --git a/cpp/table/utils/VectorBatchSerializationUtils.cpp b/cpp/table/utils/VectorBatchSerializationUtils.cpp index fafa5ea7ea94e944aca520303e1c0350b3366be3..4951f91c7e9da1930d1fa8ee55169c5e826a2c7e 100644 --- a/cpp/table/utils/VectorBatchSerializationUtils.cpp +++ b/cpp/table/utils/VectorBatchSerializationUtils.cpp @@ -122,7 +122,7 @@ int32_t omnistream::VectorBatchSerializationUtils::calculateVectorSerializableSi int32_t vectorHeaderSize = sizeof(int32_t) + sizeof(int8_t) + sizeof(int8_t); - auto nullByteSize = omniruntime::vec::NullsBuffer::CalculateNbytes(size); + auto nullByteSize = omniruntime::vec::NullsBuffer::CalculateNbytes(size*8); totalSize += nullByteSize + vectorHeaderSize; int32_t dataSize = 0; @@ -207,7 +207,7 @@ void omnistream::VectorBatchSerializationUtils::serializePrimitiveVector(BaseVec int32_t rowCount = baseVector->GetSize(); serializeVectorBatchHeader(baseVector, buffer, bufferSize); auto nullData = UnsafeBaseVector::GetNulls(baseVector); - auto nullByteSize = omniruntime::vec::NullsBuffer::CalculateNbytes(rowCount); + auto nullByteSize = omniruntime::vec::NullsBuffer::CalculateNbytes(rowCount*8); auto ret = memcpy_s(buffer, bufferSize, nullData, sizeof(bool) * rowCount); if (ret != EOK) { @@ -274,7 +274,7 @@ void omnistream::VectorBatchSerializationUtils::serializeCharVector(BaseVector * // nullData auto nullData = UnsafeBaseVector::GetNulls(baseVector); - auto nullByteSize = omniruntime::vec::NullsBuffer::CalculateNbytes(rowCount); + auto nullByteSize = omniruntime::vec::NullsBuffer::CalculateNbytes(rowCount*8); ret = memcpy_s(buffer, bufferSize, nullData, nullByteSize); if (ret != EOK) { throw std::runtime_error("memcpy_s failed"); diff --git a/cpp/table/utils/VectorBatchSerializationUtils.h b/cpp/table/utils/VectorBatchSerializationUtils.h index 46e7a20fca634348e7c3fbaabe343d0507500db4..e9082871eedab467fe88d040936bf277711528dd 100644 --- a/cpp/table/utils/VectorBatchSerializationUtils.h +++ b/cpp/table/utils/VectorBatchSerializationUtils.h @@ -203,7 +203,7 @@ public: { // nullData auto nullData = UnsafeBaseVector::GetNulls(baseVector); - auto nullByteSize = omniruntime::vec::NullsBuffer::CalculateNbytes(rowCount); + auto nullByteSize = omniruntime::vec::NullsBuffer::CalculateNbytes(rowCount *8); size_t len = nullByteSize; auto ret =