Skip to content

Commit

Permalink
[feature](pipelineX) check output type in some node (apache#33716)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mryange authored Apr 23, 2024
1 parent d84b35c commit e78db71
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 7 deletions.
9 changes: 7 additions & 2 deletions be/src/pipeline/exec/aggregation_sink_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,8 @@ AggSinkOperatorX::AggSinkOperatorX(ObjectPool* pool, int operator_id, const TPla
(tnode.__isset.conjuncts && !tnode.conjuncts.empty())),
_partition_exprs(tnode.__isset.distribute_expr_lists ? tnode.distribute_expr_lists[0]
: std::vector<TExpr> {}),
_is_colocate(tnode.agg_node.__isset.is_colocate && tnode.agg_node.is_colocate) {}
_is_colocate(tnode.agg_node.__isset.is_colocate && tnode.agg_node.is_colocate),
_agg_fn_output_row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples) {}

Status AggSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) {
RETURN_IF_ERROR(DataSinkOperatorX<AggSinkLocalState>::init(tnode, state));
Expand Down Expand Up @@ -714,7 +715,11 @@ Status AggSinkOperatorX::prepare(RuntimeState* state) {
alignment_of_next_state * alignment_of_next_state;
}
}

// check output type
if (_needs_finalize) {
RETURN_IF_ERROR(vectorized::AggFnEvaluator::check_agg_fn_output(
_probe_expr_ctxs.size(), _aggregate_evaluators, _agg_fn_output_row_descriptor));
}
return Status::OK();
}

Expand Down
2 changes: 2 additions & 0 deletions be/src/pipeline/exec/aggregation_sink_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ class AggSinkOperatorX final : public DataSinkOperatorX<AggSinkLocalState> {

const std::vector<TExpr> _partition_exprs;
const bool _is_colocate;

RowDescriptor _agg_fn_output_row_descriptor;
};

} // namespace pipeline
Expand Down
22 changes: 19 additions & 3 deletions be/src/pipeline/exec/hashjoin_build_sink.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "exprs/bloom_filter_func.h"
#include "pipeline/exec/hashjoin_probe_operator.h"
#include "pipeline/exec/operator.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/exec/join/vhash_join_node.h"
#include "vec/utils/template_helpers.hpp"

Expand Down Expand Up @@ -461,9 +462,24 @@ Status HashJoinBuildSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* st

const std::vector<TEqJoinCondition>& eq_join_conjuncts = tnode.hash_join_node.eq_join_conjuncts;
for (const auto& eq_join_conjunct : eq_join_conjuncts) {
vectorized::VExprContextSPtr ctx;
RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(eq_join_conjunct.right, ctx));
_build_expr_ctxs.push_back(ctx);
vectorized::VExprContextSPtr build_ctx;
RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(eq_join_conjunct.right, build_ctx));
{
// for type check
vectorized::VExprContextSPtr probe_ctx;
RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(eq_join_conjunct.left, probe_ctx));
auto build_side_expr_type = build_ctx->root()->data_type();
auto probe_side_expr_type = probe_ctx->root()->data_type();
if (!vectorized::make_nullable(build_side_expr_type)
->equals(*vectorized::make_nullable(probe_side_expr_type))) {
return Status::InternalError(
"build side type {}, not match probe side type {} , node info "
"{}",
build_side_expr_type->get_name(), probe_side_expr_type->get_name(),
this->debug_string(0));
}
}
_build_expr_ctxs.push_back(build_ctx);

const auto vexpr = _build_expr_ctxs.back()->root();

Expand Down
9 changes: 7 additions & 2 deletions be/src/pipeline/exec/streaming_aggregation_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,8 @@ StreamingAggOperatorX::StreamingAggOperatorX(ObjectPool* pool, int operator_id,
_needs_finalize(tnode.agg_node.need_finalize),
_is_merge(false),
_is_first_phase(tnode.agg_node.__isset.is_first_phase && tnode.agg_node.is_first_phase),
_have_conjuncts(tnode.__isset.vconjunct && !tnode.vconjunct.nodes.empty()) {}
_have_conjuncts(tnode.__isset.vconjunct && !tnode.vconjunct.nodes.empty()),
_agg_fn_output_row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples) {}

Status StreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* state) {
RETURN_IF_ERROR(StatefulOperatorX<StreamingAggLocalState>::init(tnode, state));
Expand Down Expand Up @@ -1235,7 +1236,11 @@ Status StreamingAggOperatorX::prepare(RuntimeState* state) {
alignment_of_next_state * alignment_of_next_state;
}
}

// check output type
if (_needs_finalize) {
RETURN_IF_ERROR(vectorized::AggFnEvaluator::check_agg_fn_output(
_probe_expr_ctxs.size(), _aggregate_evaluators, _agg_fn_output_row_descriptor));
}
return Status::OK();
}

Expand Down
1 change: 1 addition & 0 deletions be/src/pipeline/exec/streaming_aggregation_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ class StreamingAggOperatorX final : public StatefulOperatorX<StreamingAggLocalSt
bool _can_short_circuit = false;
std::vector<size_t> _make_nullable_keys;
bool _have_conjuncts;
RowDescriptor _agg_fn_output_row_descriptor;
};

} // namespace pipeline
Expand Down
1 change: 1 addition & 0 deletions be/src/pipeline/exec/union_sink_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ Status UnionSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) {

// Prepares the union sink's child expressions and validates their output types.
//
// Step 1 prepares _child_expr against the row descriptor of the child operator
// (_child_x->row_desc()). Step 2 passes the same expressions together with this
// operator's _row_descriptor to VExpr::check_expr_output_type, so a type
// mismatch is reported at prepare time.
// NOTE(review): the exact matching rules live in VExpr::check_expr_output_type,
// which is not visible here.
//
// Each step is wrapped in RETURN_IF_ERROR, so a failing Status is handed back to
// the caller; Status::OK() is returned when both steps succeed.
Status UnionSinkOperatorX::prepare(RuntimeState* state) {
RETURN_IF_ERROR(vectorized::VExpr::prepare(_child_expr, state, _child_x->row_desc()));
RETURN_IF_ERROR(vectorized::VExpr::check_expr_output_type(_child_expr, _row_descriptor));
return Status::OK();
}

Expand Down
19 changes: 19 additions & 0 deletions be/src/vec/exprs/vectorized_agg_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,4 +350,23 @@ AggFnEvaluator::AggFnEvaluator(AggFnEvaluator& evaluator, RuntimeState* state)
}
}

/// Checks that the aggregate output columns of `output_row_desc` are
/// type-compatible with the return types of the aggregate functions.
///
/// The columns of `output_row_desc` are expanded into (name, type) pairs. The
/// first `key_size` pairs are skipped; the pair at position `key_size + j` is
/// compared against the return type of `agg_fn[j]`.
///
/// A column is accepted when its type equals the function's return type, or
/// when the column is nullable, the return type is not nullable, and the
/// column's nested (non-nullable) type equals the return type.
///
/// @param key_size        number of leading columns in the output row that are
///                        not produced by aggregate functions
/// @param agg_fn          aggregate evaluators, in output column order
/// @param output_row_desc row descriptor describing the node's output
/// @return Status::OK() when every checked column is compatible;
///         Status::InternalError when a column type does not match, or when the
///         row descriptor holds more non-key columns than there are evaluators.
Status AggFnEvaluator::check_agg_fn_output(int key_size,
                                           const std::vector<vectorized::AggFnEvaluator*>& agg_fn,
                                           const RowDescriptor& output_row_desc) {
    const auto name_and_types = VectorizedUtils::create_name_and_data_types(output_row_desc);
    const size_t column_count = name_and_types.size();
    const size_t agg_fn_count = agg_fn.size();

    // Unsigned indices avoid the signed/unsigned comparison against size().
    for (size_t i = static_cast<size_t>(key_size), j = 0; i < column_count; ++i, ++j) {
        // Guard the evaluator lookup: a descriptor with more non-key columns
        // than evaluators would otherwise read past the end of `agg_fn`.
        if (j >= agg_fn_count) {
            return Status::InternalError(
                    "output column count not match agg function count in agg node, "
                    "output columns={}, key size={}, agg functions={}",
                    column_count, key_size, agg_fn_count);
        }

        const auto& [name, column_type] = name_and_types[i];
        const auto agg_return_type = agg_fn[j]->function()->get_return_type();

        if (column_type->equals(*agg_return_type)) {
            continue;
        }

        // The only tolerated difference: a nullable output column wrapping the
        // exact (non-nullable) return type of the aggregate function.
        const bool nullable_wrapper_of_return_type =
                column_type->is_nullable() && !agg_return_type->is_nullable() &&
                remove_nullable(column_type)->equals(*agg_return_type);
        if (!nullable_wrapper_of_return_type) {
            return Status::InternalError(
                    "column_type not match data_types in agg node, column_type={}, "
                    "data_types={},column name={}",
                    column_type->get_name(), agg_return_type->get_name(), name);
        }
    }
    return Status::OK();
}
} // namespace doris::vectorized
4 changes: 4 additions & 0 deletions be/src/vec/exprs/vectorized_agg_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ class AggFnEvaluator {
// Returns the value of the _is_merge flag held by this evaluator.
bool is_merge() const { return _is_merge; }
// Returns a const reference to this evaluator's input expression contexts
// (_input_exprs_ctxs); no copy is made.
const VExprContextSPtrs& input_exprs_ctxs() const { return _input_exprs_ctxs; }

// Checks that the aggregate output columns of `output_row_desc` are
// type-compatible with the return types of the evaluators in `agg_fn`.
//
// The first `key_size` columns of the output row are skipped; each following
// column is compared, in order, with the return type of the corresponding
// entry of `agg_fn`. A column is accepted when the types are equal, or when
// the column is nullable and wraps the evaluator's non-nullable return type.
//
// Returns Status::OK() when all checked columns are compatible and an
// InternalError Status describing the offending column otherwise.
static Status check_agg_fn_output(int key_size,
const std::vector<vectorized::AggFnEvaluator*>& agg_fn,
const RowDescriptor& output_row_desc);

// Forwards `version` to the underlying aggregate function object (_function).
void set_version(const int version) { _function->set_version(version); }

AggFnEvaluator* clone(RuntimeState* state, ObjectPool* pool);
Expand Down

0 comments on commit e78db71

Please sign in to comment.