From 323f9376edc4a02ebdf6c696c9965a8119fc6ed7 Mon Sep 17 00:00:00 2001 From: gbananda Date: Wed, 1 Mar 2023 14:13:44 +0530 Subject: [PATCH] Handle CTE cyclic dependency issue --- ...rioritizeUtilizationExecutionSchedule.java | 60 ++++++++++++++++++- .../sql/planner/plan/PlanFragmentId.java | 36 +++++++++++ 2 files changed, 95 insertions(+), 1 deletion(-) diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PrioritizeUtilizationExecutionSchedule.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PrioritizeUtilizationExecutionSchedule.java index 2b3cf8783..157e681c8 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PrioritizeUtilizationExecutionSchedule.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PrioritizeUtilizationExecutionSchedule.java @@ -48,6 +48,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; @@ -93,6 +94,8 @@ public class PrioritizeUtilizationExecutionSchedule @GuardedBy("this") private SettableFuture rescheduleFuture = SettableFuture.create(); + private Map> fragmentIdBuildMap = new HashMap<>(); + public static PrioritizeUtilizationExecutionSchedule forStages(Collection stages, DynamicFilterService dynamicFilterService) { PrioritizeUtilizationExecutionSchedule schedule = new PrioritizeUtilizationExecutionSchedule(stages, dynamicFilterService); @@ -117,10 +120,50 @@ public class PrioritizeUtilizationExecutionSchedule fragmentDependency.vertexSet().stream() .filter(fragmentId -> fragmentDependency.inDegreeOf(fragmentId) == 0) .forEach(fragmentsToExecute::add); + prioritizeBuildSideFragments(sortedFragments); fragmentOrdering = Ordering.explicit(sortedFragments); selectForExecution(fragmentsToExecute.build()); } + private void prioritizeBuildSideFragments(List sortedFragments) + { + List tempList = new ArrayList<>(); + tempList.addAll(sortedFragments); + int index = 0; + for (PlanFragmentId fragment : tempList) { + fragment.setIndex(index++); + } + + for (Map.Entry> cteEntryMap : fragmentIdBuildMap.entrySet()) { + Integer commonCTERefNum = cteEntryMap.getKey(); + Set fragmentIds = cteEntryMap.getValue(); + if (fragmentIds.size() > 1) { + // find First element in HashSet + Integer firstCTE = fragmentIds.stream().findFirst().get(); + // find Last element in HashSet + Integer lastCTE = fragmentIds.stream().reduce((one, two) -> two).get(); + replaceFirstAndLastCTE(firstCTE, lastCTE); + } + } + } + + private void replaceFirstAndLastCTE(Integer firstCTE, Integer lastCTE) + { + int indexOfFirstCTE = 0; + int indexOfLastCTE = 0; + for (PlanFragmentId planFragId : sortedFragments) { + if (firstCTE.equals(Integer.valueOf(planFragId.toString()))) { + indexOfFirstCTE = planFragId.getIndex(); + } + else if (lastCTE.equals(Integer.valueOf(planFragId.toString()))) { + indexOfLastCTE = planFragId.getIndex(); + } + } + PlanFragmentId firstPlanFragId = sortedFragments.get(indexOfFirstCTE); + sortedFragments.set(indexOfFirstCTE, sortedFragments.get(indexOfLastCTE)); + sortedFragments.set(indexOfLastCTE, firstPlanFragId); + } + @Override public StagesScheduleResult getStagesToSchedule() { @@ -326,6 +369,10 @@ public class PrioritizeUtilizationExecutionSchedule FragmentSubGraph subGraph = processFragment(fragments.get(planFragmentId)); verify(fragmentSubGraphs.put(planFragmentId, subGraph) == null, "fragment %s was already processed", planFragmentId); + if (fragments.get(planFragmentId).getRoot() instanceof CTEScanNode) { + planFragmentId.setCTEScanNode(true); + planFragmentId.setCommonCTERefNum(((CTEScanNode) fragments.get(planFragmentId).getRoot()).getCommonCTERefNum()); + } sortedFragments.add(planFragmentId); return subGraph; } @@ -402,6 +449,17 @@ public class PrioritizeUtilizationExecutionSchedule FragmentSubGraph probeSubGraph = probe.accept(this, currentFragmentId); FragmentSubGraph buildSubGraph = build.accept(this, currentFragmentId); + for (PlanFragmentId planFragmentId : buildSubGraph.getUpstreamFragments()) { + if (fragments.get(planFragmentId).getRoot() instanceof CTEScanNode) { + Set buildFragmentIdSet = fragmentIdBuildMap.get(planFragmentId.getCommonCTERefNum()); + if (buildFragmentIdSet == null) { + fragmentIdBuildMap.put(planFragmentId.getCommonCTERefNum(), new LinkedHashSet<>()); + buildFragmentIdSet = fragmentIdBuildMap.get(planFragmentId.getCommonCTERefNum()); + } + buildFragmentIdSet.add(Integer.valueOf(planFragmentId.toString())); + } + } + // start probe source stages after all build source stages finish addDependencyEdges(buildSubGraph.getUpstreamFragments(), probeSubGraph.getLazyUpstreamFragments()); @@ -457,6 +515,7 @@ public class PrioritizeUtilizationExecutionSchedule .collect(toImmutableList()); node.getSourceFragmentIds() .forEach(sourceFragmentId -> fragmentTopology.addEdge(sourceFragmentId, currentFragmentId)); + return new FragmentSubGraph( subGraphs.stream() .flatMap(source -> source.getUpstreamFragments().stream()) @@ -517,7 +576,6 @@ public class PrioritizeUtilizationExecutionSchedule List sourceSubGraphs = node.getSources().stream() .map(subPlanNode -> subPlanNode.accept(this, currentFragmentId)) .collect(toImmutableList()); - return new FragmentSubGraph( sourceSubGraphs.stream() .flatMap(source -> source.getUpstreamFragments().stream()) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/PlanFragmentId.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/PlanFragmentId.java index e4775f8e3..937d66f9c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/PlanFragmentId.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/PlanFragmentId.java @@ -32,6 +32,42 @@ public class PlanFragmentId this.id = id; } + private boolean isCTEScanNode; + + public Integer getCommonCTERefNum() + { + return commnonCTERefNum; + } + + public void setCommonCTERefNum(Integer commnonCTERefNum) + { + this.commnonCTERefNum = commnonCTERefNum; + } + + private Integer commnonCTERefNum; + + public boolean isCTEScanNode() + { + return isCTEScanNode; + } + + public void setCTEScanNode(boolean isCTEScanNode) + { + this.isCTEScanNode = isCTEScanNode; + } + + int index; + + public int getIndex() + { + return index; + } + + public void setIndex(int index) + { + this.index = index; + } + @Override @JsonValue public String toString() -- Gitee