12 #ifndef THRILL_EXAMPLES_SELECT_SELECT_HEADER 13 #define THRILL_EXAMPLES_SELECT_SELECT_HEADER 33 static constexpr
bool debug =
false;
35 static constexpr
double delta = 0.1;
39 #define LOGM LOGC(debug && ctx.my_rank() == 0) 41 template <
typename ValueType,
typename InStack,
42 typename Compare = std::less<ValueType> >
43 std::pair<ValueType, ValueType>
45 const Compare& compare = Compare()) {
49 const double size_d =
static_cast<double>(size);
51 const double p = 20 * sqrt(static_cast<double>(num_workers)) / size_d;
56 std::pair<ValueType, ValueType> pivots;
58 LOG <<
"got " << sample.size() <<
" samples (p = " << p <<
")";
60 std::sort(sample.begin(), sample.end(), compare);
62 const double base_pos =
63 static_cast<double>(rank * sample.size()) / size_d;
64 const double offset = pow(size_d, 0.25 + delta);
66 long lower_pos =
static_cast<long>(floor(base_pos - offset));
67 long upper_pos =
static_cast<long>(floor(base_pos + offset));
69 size_t lower =
static_cast<size_t>(
std::max(0L, lower_pos));
70 size_t upper =
static_cast<size_t>(
71 std::min(upper_pos, static_cast<long>(sample.size() - 1)));
73 assert(0 <= lower && lower < sample.size());
74 assert(0 <= upper && upper < sample.size());
76 LOG <<
"Selected pivots at positions " << lower <<
" and " << upper
77 <<
": " << sample[lower] <<
" and " << sample[upper];
79 pivots = std::make_pair(sample[lower], sample[upper]);
84 LOGM <<
"pivots: " << pivots.first <<
" and " << pivots.second;
89 template <
typename ValueType,
typename InStack,
90 typename Compare = std::less<ValueType> >
92 const Compare& compare = Compare()) {
94 const size_t size = data.
Keep().
Size();
96 assert(0 <= rank && rank < size);
98 if (size <= base_case_size) {
100 ValueType result = ValueType();
101 auto elements = data.
Gather();
104 assert(rank < elements.size());
105 std::nth_element(elements.begin(), elements.begin() + rank,
106 elements.end(), compare);
108 result = elements[rank];
110 LOG <<
"base case: " << size <<
" elements remaining, result is " 118 ValueType left_pivot, right_pivot;
119 std::tie(left_pivot, right_pivot) =
PickPivots(data, size, rank, compare);
121 size_t left_size, middle_size, right_size;
123 using PartSizes = std::pair<size_t, size_t>;
124 std::tie(left_size, middle_size) =
126 [&](
const ValueType& elem) -> PartSizes {
127 if (compare(elem, left_pivot))
128 return PartSizes { 1, 0 };
129 else if (!compare(right_pivot, elem))
130 return PartSizes { 0, 1 };
132 return PartSizes { 0, 0 };
135 [](
const PartSizes& a,
const PartSizes& b) -> PartSizes {
136 return PartSizes { a.first + b.first, a.second + b.second };
139 right_size = size - left_size - middle_size;
141 LOGM <<
"left_size = " << left_size <<
", middle_size = " << middle_size
142 <<
", right_size = " << right_size <<
", rank = " << rank;
144 if (rank == left_size) {
147 LOGM <<
"result is left pivot: " << left_pivot;
150 else if (rank == left_size + middle_size - 1) {
153 LOGM <<
"result is right pivot: " << right_pivot;
156 else if (rank < left_size) {
158 LOGM <<
"Recursing left, " << left_size
159 <<
" elements remaining (rank = " << rank <<
")\n";
162 [&](
const ValueType& elem) ->
bool {
163 return compare(elem, left_pivot);
165 return Select(left, rank, compare);
167 else if (left_size + middle_size <= rank) {
169 LOGM <<
"Recursing right, " << right_size
170 <<
" elements remaining (rank = " << rank - left_size - middle_size
174 [&](
const ValueType& elem) ->
bool {
175 return compare(right_pivot, elem);
177 return Select(right, rank - left_size - middle_size, compare);
181 LOGM <<
"Recursing middle, " << middle_size
182 <<
" elements remaining (rank = " << rank - left_size <<
")\n";
184 auto middle = data.
Filter(
185 [&](
const ValueType& elem) ->
bool {
186 return !compare(elem, left_pivot) &&
187 !compare(right_pivot, elem);
189 return Select(middle, rank - left_size, compare);
196 #endif // !THRILL_EXAMPLES_SELECT_SELECT_HEADER net::FlowControlChannel & net
DIA is the interface between the user and the Thrill framework.
static uint_pair max()
Returns a uint_pair instance containing the largest possible value.
static constexpr size_t base_case_size
size_t num_workers() const
Global number of workers in the system.
static constexpr double delta
Context & context() const
Return context_ of DIANode, e.g. for creating new LOps and DOps.
The Context of a job is a unique instance per worker which holds references to all underlying parts of Thrill.
std::pair< ValueType, ValueType > PickPivots(const DIA< ValueType, InStack > &data, size_t size, size_t rank, const Compare &compare=Compare())
auto BernoulliSample(double p) const
Each item of a DIA is copied into the output DIA with success probability p (an independent Bernoulli trial).
auto Filter(const FilterFunction &filter_function) const
Each item of a DIA is tested using filter_function : $A \to \textrm{bool}$ to determine whether it is copied into the output DIA or excluded.
auto Map(const MapFunction &map_function) const
Map applies map_function : $A \to B$ to each item of a DIA and delivers a new DIA containing the returned values, which may be of a different type.
size_t my_rank() const
Global rank of this worker among all other workers in the system.
static constexpr bool debug
static uint_pair min()
Returns a uint_pair instance containing the smallest possible value.
ValueType Select(const DIA< ValueType, InStack > &data, size_t rank, const Compare &compare=Compare())
size_t Size() const
Computes the total size of all elements across all workers.
const DIA & Keep(size_t increase=1) const
Mark the referenced DIANode for keeping, which makes children not consume the data when executing.
std::vector< ValueType > Gather(size_t target_id=0) const
Gather is an Action, which collects all data of the DIA into a vector at the given worker...
T TLX_ATTRIBUTE_WARN_UNUSED_RESULT Broadcast(const T &value, size_t origin=0)
Broadcasts a value of a serializable type T from the master (the worker with id 0) to all other workers.
#define LOG
Default logging method: output if the local debug variable is true.