15 #ifndef THRILL_API_MERGE_HEADER 16 #define THRILL_API_MERGE_HEADER 75 template <
typename ValueType,
typename Comparator,
size_t kNumInputs>
78 static constexpr
bool debug =
false;
87 static_assert(kNumInputs >= 2,
"Merge requires at least two inputs.");
90 template <
typename ParentDIA0,
typename... ParentDIAs>
92 const ParentDIA0& parent0,
const ParentDIAs& ...
parents)
93 :
Super(parent0.ctx(),
"Merge",
94 { parent0.id(), parents.id() ... },
95 { parent0.node(),
parents.node() ... }),
99 std::array<bool, kNumInputs>{
100 { ParentDIA0::stack_empty, (ParentDIAs::stack_empty)... }
103 for (
size_t i = 0; i < kNumInputs; ++i)
106 for (
size_t i = 0; i < kNumInputs; ++i)
121 template <
typename Index,
typename Parent>
126 auto pre_op_fn = [writer](
const ValueType& input) ->
void {
132 auto lop_chain = parent.stack().push(pre_op_fn).fold();
134 parent.node()->AddChild(
merge_node_, lop_chain, Index::index);
143 assert(parent_index < kNumInputs);
147 assert(
files_[parent_index]->num_items() == 0);
148 *
files_[parent_index] = file.Copy();
161 size_t result_count = 0;
162 static constexpr
bool debug =
false;
167 std::vector<data::CatStream::CatReader> readers;
168 readers.reserve(kNumInputs);
170 for (
size_t i = 0; i < kNumInputs; i++)
171 readers.emplace_back(
streams_[i]->GetCatReader(consume));
173 auto puller = core::make_multiway_merge_tree<ValueType>(
176 while (puller.HasNext())
181 sLOG <<
"Merge: result_count" << result_count;
218 std::stringstream oss;
219 for (
const Pivot& elem : data) {
220 oss <<
"(" << elem.value
221 <<
", itie: " << elem.tie_idx
222 <<
", len: " << elem.segment_len <<
") ";
264 size_t result_size_ = 0;
266 size_t iterations_ = 0;
271 LOG1 <<
"RESULT " <<
"operation=" << label <<
" time=" << value
272 <<
" workers=" << p <<
" result_size_=" << result_size_;
282 size_t pivot_selection =
283 ctx.
net.
AllReduce(pivot_selection_timer_.Milliseconds()) / p;
285 ctx.
net.
AllReduce(search_step_timer_.Milliseconds()) / p;
296 PrintToSQLPlotTool(
"merge", p, merge);
297 PrintToSQLPlotTool(
"balance", p, balance);
298 PrintToSQLPlotTool(
"pivot_selection", p, pivot_selection);
299 PrintToSQLPlotTool(
"search_step", p, search_step);
300 PrintToSQLPlotTool(
"file_op", p, file_op);
301 PrintToSQLPlotTool(
"communication", p, comm);
302 PrintToSQLPlotTool(
"scatter", p, scatter);
303 PrintToSQLPlotTool(
"iterations", p, iterations_);
326 const std::vector<ArrayNumInputsSizeT>& left,
327 const std::vector<ArrayNumInputsSizeT>& width,
328 std::vector<Pivot>& out_pivots) {
332 for (
size_t s = 0; s < width.size(); s++) {
336 for (
size_t p = 1; p < width[s].size(); p++) {
337 if (width[s][p] > width[s][mp]) {
345 ValueType pivot_elem = ValueType();
346 size_t pivot_idx = left[s][mp];
348 if (width[s][mp] > 0) {
349 pivot_idx = left[s][mp] + (
context_.
rng_() % width[s][mp]);
350 assert(pivot_idx < files_[mp]->num_items());
352 pivot_elem = files_[mp]->template GetItemAt<ValueType>(pivot_idx);
356 out_pivots[s] =
Pivot {
363 LOG <<
"local pivots: " <<
VToStr(out_pivots);
378 const std::vector<Pivot>& pivots,
379 std::vector<size_t>& global_ranks,
380 std::vector<ArrayNumInputsSizeT>& out_local_ranks,
381 const std::vector<ArrayNumInputsSizeT>& left,
382 const std::vector<ArrayNumInputsSizeT>& width) {
386 for (
size_t s = 0; s < pivots.size(); s++) {
388 for (
size_t i = 0; i < kNumInputs; i++) {
391 size_t idx = files_[i]->GetIndexOf(
392 pivots[s].
value, pivots[s].tie_idx,
393 left[s][i], left[s][i] + width[s][i],
399 out_local_ranks[s][i] = idx;
401 global_ranks[s] = rank;
430 const std::vector<size_t>& global_ranks,
431 const std::vector<ArrayNumInputsSizeT>& local_ranks,
432 const std::vector<size_t>& target_ranks,
433 std::vector<ArrayNumInputsSizeT>& left,
434 std::vector<ArrayNumInputsSizeT>& width) {
436 for (
size_t s = 0; s < width.size(); s++) {
437 for (
size_t p = 0; p < width[s].size(); p++) {
439 if (width[s][p] == 0)
442 size_t local_rank = local_ranks[s][p];
443 size_t old_width = width[s][p];
444 assert(left[s][p] <= local_rank);
446 if (global_ranks[s] < target_ranks[s]) {
447 width[s][p] -= local_rank - left[s][p];
448 left[s][p] = local_rank;
450 else if (global_ranks[s] >= target_ranks[s]) {
451 width[s][p] = local_rank - left[s][p];
470 LOG <<
"splitting to " << p <<
" workers";
473 size_t local_size = 0;
475 for (
size_t i = 0; i < kNumInputs; i++) {
476 local_size += files_[i]->num_items();
481 for (
size_t i = 0; i < kNumInputs; i++) {
482 auto reader = files_[i]->GetKeepReader();
483 if (!reader.HasNext())
continue;
485 ValueType prev = reader.template Next<ValueType>();
486 while (reader.HasNext()) {
487 ValueType next = reader.template Next<ValueType>();
489 die(
"Merge input was not sorted!");
491 prev = std::move(next);
501 LOG <<
"local size: " << local_size;
502 LOG <<
"global size: " << global_size;
506 std::vector<size_t> target_ranks(p - 1);
508 for (
size_t r = 0; r < p - 1; r++) {
509 target_ranks[r] = (global_size / p) * (r + 1);
512 if (r < global_size % p)
513 target_ranks[r] += 1;
517 LOG <<
"target_ranks: " << target_ranks;
525 std::vector<size_t> global_ranks(p - 1);
528 std::vector<ArrayNumInputsSizeT> left(p - 1), width(p - 1);
531 std::vector<Pivot> pivots(p - 1);
532 std::vector<ArrayNumInputsSizeT> local_ranks(p - 1);
536 for (
size_t r = 0; r < p - 1; r++) {
537 for (
size_t q = 0; q < kNumInputs; q++) {
538 width[r][q] = files_[q]->num_items();
542 bool finished =
false;
550 LOG0 <<
"left: " << left;
551 LOG0 <<
"width: " << width;
554 for (
size_t q = 0; q < kNumInputs; q++) {
555 std::ostringstream oss;
556 for (
size_t i = 0; i < p - 1; ++i) {
557 if (i != 0) oss <<
" # ";
558 oss <<
'[' << left[i][q] <<
',' << left[i][q] + width[i][q] <<
')';
560 LOG1 <<
"left/right[" << q <<
"]: " << oss.str();
569 LOG <<
"final pivots: " <<
VToStr(pivots);
575 LOG <<
"global_ranks: " << global_ranks;
576 LOG <<
"local_ranks: " << local_ranks;
578 SearchStep(global_ranks, local_ranks, target_ranks, left, width);
581 for (
size_t q = 0; q < kNumInputs; q++) {
582 std::ostringstream oss;
583 for (
size_t i = 0; i < p - 1; ++i) {
584 if (i != 0) oss <<
" # ";
585 oss <<
'[' << left[i][q] <<
',' << left[i][q] + width[i][q] <<
')';
587 LOG1 <<
"left/right[" << q <<
"]: " << oss.str();
593 for (
size_t i = 0; i < p - 1; i++) {
594 size_t a = global_ranks[i], b = target_ranks[i];
608 LOG <<
"Creating channels";
611 for (
size_t j = 0; j < kNumInputs; j++)
616 LOG <<
"Scattering.";
621 std::vector<size_t> tx_items(p);
622 for (
size_t j = 0; j < kNumInputs; j++) {
624 std::vector<size_t> offsets(p + 1, 0);
626 for (
size_t r = 0; r < p - 1; r++)
627 offsets[r + 1] = local_ranks[r][j];
629 offsets[p] = files_[j]->num_items();
631 LOG <<
"Scatter from file " << j <<
" to other workers: " 634 for (
size_t r = 0; r < p; ++r) {
635 tx_items[r] += offsets[r + 1] - offsets[r];
638 streams_[j]->template ScatterConsume<ValueType>(
639 *files_[j], offsets);
642 LOG <<
"tx_items: " << tx_items;
648 LOG1 <<
"Merge(): total_items: " << tx_items;
673 template <
typename Comparator,
typename FirstDIA,
typename... DIAs>
674 auto Merge(
const Comparator& comparator,
675 const FirstDIA& first_dia,
const DIAs& ... dias) {
677 tlx::vexpand((first_dia.AssertValid(), 0), (dias.AssertValid(), 0) ...);
679 using ValueType =
typename FirstDIA::ValueType;
681 using CompareResult =
682 typename common::FunctionTraits<Comparator>::result_type;
685 ValueType, Comparator, 1 +
sizeof ... (DIAs)>;
691 typename common::FunctionTraits<Comparator>::template arg<0>
693 "Comparator has the wrong input type in argument 0");
698 typename common::FunctionTraits<Comparator>::template arg<1>
700 "Comparator has the wrong input type in argument 1");
708 "Comparator must return bool");
711 tlx::make_counting<MergeNode>(comparator, first_dia, dias...);
716 template <
typename ValueType,
typename Stack>
717 template <
typename Comparator,
typename SecondDIA>
719 const SecondDIA& second_dia,
const Comparator& comparator)
const {
720 return api::Merge(comparator, *
this, second_dia);
730 #endif // !THRILL_API_MERGE_HEADER
Comparator comparator_
Merge comparator.
net::FlowControlChannel & net
#define sLOG
Default logging method: output if the local debug variable is true.
DIA is the interface between the user and the Thrill framework.
size_t prefix_size_
Count of items on all previous workers.
void PushItem(const ValueType &item) const
Method for derived classes to Push a single item to all children.
StatsTimer scatter_timer_
void GetGlobalRanks(const std::vector< Pivot > &pivots, std::vector< size_t > &global_ranks, std::vector< ArrayNumInputsSizeT > &out_local_ranks, const std::vector< ArrayNumInputsSizeT > &left, const std::vector< ArrayNumInputsSizeT > &width)
Calculates the global ranks of the given pivots.
auto Merge(const SecondDIA &second_dia, const Comparator &comparator=Comparator()) const
Merge is a DOp, which merges two sorted DIAs to a single sorted DIA.
static constexpr bool stats_enabled
Set this variable to true to enable generation and output of merge stats.
void SelectPivots(const std::vector< ArrayNumInputsSizeT > &left, const std::vector< ArrayNumInputsSizeT > &width, std::vector< Pivot > &out_pivots)
Selects random global pivots for all splitter searches based on all workers' search ranges...
size_t num_workers() const
Global number of workers in the system.
A File is an ordered sequence of Block objects for storing items.
#define LOG0
Override default output: never or always output log.
void StopPreOp(size_t parent_index) final
Virtual method for preparing end of PushData.
StatsTimer balancing_timer_
A Timer accumulating all time spent while re-balancing the data.
size_t iterations_
The count of search iterations needed for balancing.
MergeNode(const Comparator &comparator, const ParentDIA0 &parent0, const ParentDIAs &... parents)
const std::vector< DIABasePtr > & parents() const
Returns the parents of this DIABase.
const char * label() const
Returns the label() of the DIANode subclass as stored by StatsNode.
BlockWriter contains a temporary Block object into which a) any serializable item can be stored or b)...
#define die(msg)
Instead of calling std::terminate(), throw the output message via an exception.
void SearchStep(const std::vector< size_t > &global_ranks, const std::vector< ArrayNumInputsSizeT > &local_ranks, const std::vector< size_t > &target_ranks, std::vector< ArrayNumInputsSizeT > &left, std::vector< ArrayNumInputsSizeT > &width)
Shrinks the search ranges according to the global ranks of the pivots.
std::default_random_engine rng_
a random generator
void Execute() final
Virtual execution method. Triggers actual computation in sub-classes.
T TLX_ATTRIBUTE_WARN_UNUSED_RESULT AllReduce(const T &value, const BinarySumOp &sum_op=BinarySumOp())
Reduces a value of a serializable type T over all workers given a certain reduce function.
The Context of a job is a unique instance per worker which holds references to all underlying parts o...
Stats stats_
Instance of merge statistics.
static constexpr bool self_verify
const std::array< bool, kNumInputs > parent_stack_empty_
Whether the parent stack is empty.
template for computing the component-wise sum of std::array or std::vector.
StatsTimer pivot_selection_timer_
data::File::Writer writers_[kNumInputs]
Writers to intermediate files.
data::CatStreamPtr GetNewCatStream(size_t dia_id)
Stats holds timers for measuring merge performance, that supports accumulating the output and printin...
auto Merge(const Comparator &comparator, const FirstDIA &first_dia, const DIAs &... dias)
Merge is a DOp, which merges any number of sorted DIAs to a single sorted DIA.
void PrintToSQLPlotTool(const std::string &label, size_t p, size_t value)
T abs_diff(const T &a, const T &b)
absolute difference, which also works for unsigned types
std::basic_string< char, std::char_traits< char >, Allocator< char > > string
string with Manager tracking
StatsTimer merge_timer_
A Timer accumulating all time spent while actually merging.
static std::string VToStr(const std::vector< Pivot > &data)
Logging helper to print a vector of pivots.
data::FilePtr GetFilePtr(size_t dia_id)
size_t my_rank() const
Global rank of this worker among all other workers in the system.
RegisterParent(MergeNode *merge_node)
void MainOp()
Receives elements from other workers and re-balances them, so each worker has the same amount after me...
std::array< size_t, kNumInputs > ArrayNumInputsSizeT
void operator()(const Index &, Parent &parent)
static constexpr bool debug
static constexpr bool g_debug_mode
debug mode is active, if NDEBUG is false.
A DOpNode is a typed node representing a distributed operation in Thrill.
StatsTimer file_op_timer_
A Timer accumulating all time spent in File operations.
TLX_ATTRIBUTE_ALWAYS_INLINE BlockWriter & Put(const T &x)
Put appends a complete item, or fails with a FullException.
StatsTimer comm_timer_
A Timer accumulating all time spent communicating.
bool OnPreOpFile(const data::File &file, size_t parent_index) final
Receive a whole data::File of ValueType, but only if our stack is empty.
void Close()
Explicitly close the writer.
data::FilePtr files_[kNumInputs]
Files for intermediate storage.
void PushData(bool consume) final
Virtual method for pushing data. Triggers actual pushing in sub-classes.
T TLX_ATTRIBUTE_WARN_UNUSED_RESULT Broadcast(const T &value, size_t origin=0)
Broadcasts a value of a serializable type T from the master (the worker with id 0) to all other worke...
size_t result_size_
The count of all elements processed on this host.
#define LOG
Default logging method: output if the local debug variable is true.
void Dispose() final
Virtual clear method. Triggers actual disposing in sub-classes.
Implementation of Thrill's merge.
data::CatStreamPtr streams_[kNumInputs]
Array of inbound CatStreams.
Context & context_
associated Context
StatsTimer search_step_timer_
A Timer accumulating all time spent in global search steps.