Thrill  0.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
select.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * examples/select/select.hpp
3  *
4  * Part of Project Thrill - http://project-thrill.org
5  *
6  * Copyright (C) 2016 Lorenz Hübschle-Schneider <[email protected]>
7  *
8  * All rights reserved. Published under the BSD-2 license in the LICENSE file.
9  ******************************************************************************/
10 
11 #pragma once
12 #ifndef THRILL_EXAMPLES_SELECT_SELECT_HEADER
13 #define THRILL_EXAMPLES_SELECT_SELECT_HEADER
14 
#include <thrill/api/bernoulli_sample.hpp>
#include <thrill/api/collapse.hpp>
#include <thrill/api/dia.hpp>
#include <thrill/api/gather.hpp>
#include <thrill/api/size.hpp>
#include <thrill/api/sum.hpp>
#include <thrill/common/logger.hpp>

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <functional>
#include <tuple>
#include <utility>
#include <vector>
27 
28 namespace examples {
29 namespace select {
30 
31 using namespace thrill; // NOLINT
32 
//! Enables the LOG / LOGM debug output of this file.
static constexpr bool debug = false;

//! Extra exponent added to 0.25 when PickPivots() computes the width of its
//! pivot window (offset = size^(0.25 + delta)).
static constexpr double delta = 0.1; // 0 < delta < 0.25
static_assert(0.0 < delta && delta < 0.25,
              "delta must satisfy 0 < delta < 0.25");

//! Inputs with at most this many elements are gathered and solved directly
//! on worker 0 by Select().
static constexpr size_t base_case_size = 1024;

//! Like LOG, but only emits output on worker 0. Requires a variable named
//! `ctx` (an api::Context) to be in scope wherever it is used.
#define LOGM LOGC(debug && ctx.my_rank() == 0)
40 
41 template <typename ValueType, typename InStack,
42  typename Compare = std::less<ValueType> >
43 std::pair<ValueType, ValueType>
44 PickPivots(const DIA<ValueType, InStack>& data, size_t size, size_t rank,
45  const Compare& compare = Compare()) {
46  api::Context& ctx = data.context();
47 
48  const size_t num_workers(ctx.num_workers());
49  const double size_d = static_cast<double>(size);
50 
51  const double p = 20 * sqrt(static_cast<double>(num_workers)) / size_d;
52 
53  // materialized at worker 0
54  auto sample = data.Keep().BernoulliSample(p).Gather();
55 
56  std::pair<ValueType, ValueType> pivots;
57  if (ctx.my_rank() == 0) {
58  LOG << "got " << sample.size() << " samples (p = " << p << ")";
59  // Sort the samples
60  std::sort(sample.begin(), sample.end(), compare);
61 
62  const double base_pos =
63  static_cast<double>(rank * sample.size()) / size_d;
64  const double offset = pow(size_d, 0.25 + delta);
65 
66  long lower_pos = static_cast<long>(floor(base_pos - offset));
67  long upper_pos = static_cast<long>(floor(base_pos + offset));
68 
69  size_t lower = static_cast<size_t>(std::max(0L, lower_pos));
70  size_t upper = static_cast<size_t>(
71  std::min(upper_pos, static_cast<long>(sample.size() - 1)));
72 
73  assert(0 <= lower && lower < sample.size());
74  assert(0 <= upper && upper < sample.size());
75 
76  LOG << "Selected pivots at positions " << lower << " and " << upper
77  << ": " << sample[lower] << " and " << sample[upper];
78 
79  pivots = std::make_pair(sample[lower], sample[upper]);
80  }
81 
82  pivots = ctx.net.Broadcast(pivots);
83 
84  LOGM << "pivots: " << pivots.first << " and " << pivots.second;
85 
86  return pivots;
87 }
88 
89 template <typename ValueType, typename InStack,
90  typename Compare = std::less<ValueType> >
91 ValueType Select(const DIA<ValueType, InStack>& data, size_t rank,
92  const Compare& compare = Compare()) {
93  api::Context& ctx = data.context();
94  const size_t size = data.Keep().Size();
95 
96  assert(0 <= rank && rank < size);
97 
98  if (size <= base_case_size) {
99  // base case, gather all data at worker with rank 0
100  ValueType result = ValueType();
101  auto elements = data.Gather();
102 
103  if (ctx.my_rank() == 0) {
104  assert(rank < elements.size());
105  std::nth_element(elements.begin(), elements.begin() + rank,
106  elements.end(), compare);
107 
108  result = elements[rank];
109 
110  LOG << "base case: " << size << " elements remaining, result is "
111  << result;
112  }
113 
114  result = ctx.net.Broadcast(result);
115  return result;
116  }
117 
118  ValueType left_pivot, right_pivot;
119  std::tie(left_pivot, right_pivot) = PickPivots(data, size, rank, compare);
120 
121  size_t left_size, middle_size, right_size;
122 
123  using PartSizes = std::pair<size_t, size_t>;
124  std::tie(left_size, middle_size) =
125  data.Keep().Map(
126  [&](const ValueType& elem) -> PartSizes {
127  if (compare(elem, left_pivot))
128  return PartSizes { 1, 0 };
129  else if (!compare(right_pivot, elem))
130  return PartSizes { 0, 1 };
131  else
132  return PartSizes { 0, 0 };
133  })
134  .Sum(
135  [](const PartSizes& a, const PartSizes& b) -> PartSizes {
136  return PartSizes { a.first + b.first, a.second + b.second };
137  },
138  PartSizes { 0, 0 });
139  right_size = size - left_size - middle_size;
140 
141  LOGM << "left_size = " << left_size << ", middle_size = " << middle_size
142  << ", right_size = " << right_size << ", rank = " << rank;
143 
144  if (rank == left_size) {
145  // all the elements strictly smaller than the left pivot are on the left
146  // side -> left_size-th element is the left pivot
147  LOGM << "result is left pivot: " << left_pivot;
148  return left_pivot;
149  }
150  else if (rank == left_size + middle_size - 1) {
151  // only the elements strictly greater than the right pivot are on the
152  // right side, so the result is the right pivot in this case
153  LOGM << "result is right pivot: " << right_pivot;
154  return right_pivot;
155  }
156  else if (rank < left_size) {
157  // recurse on the left partition
158  LOGM << "Recursing left, " << left_size
159  << " elements remaining (rank = " << rank << ")\n";
160 
161  auto left = data.Filter(
162  [&](const ValueType& elem) -> bool {
163  return compare(elem, left_pivot);
164  }).Collapse();
165  return Select(left, rank, compare);
166  }
167  else if (left_size + middle_size <= rank) {
168  // recurse on the right partition
169  LOGM << "Recursing right, " << right_size
170  << " elements remaining (rank = " << rank - left_size - middle_size
171  << ")\n";
172 
173  auto right = data.Filter(
174  [&](const ValueType& elem) -> bool {
175  return compare(right_pivot, elem);
176  }).Collapse();
177  return Select(right, rank - left_size - middle_size, compare);
178  }
179  else {
180  // recurse on the middle partition
181  LOGM << "Recursing middle, " << middle_size
182  << " elements remaining (rank = " << rank - left_size << ")\n";
183 
184  auto middle = data.Filter(
185  [&](const ValueType& elem) -> bool {
186  return !compare(elem, left_pivot) &&
187  !compare(right_pivot, elem);
188  }).Collapse();
189  return Select(middle, rank - left_size, compare);
190  }
191 }
192 
193 } // namespace select
194 } // namespace examples
195 
196 #endif // !THRILL_EXAMPLES_SELECT_SELECT_HEADER
197 
198 /******************************************************************************/
net::FlowControlChannel & net
Definition: context.hpp:443
static uint_pair max()
return an uint_pair instance containing the largest value possible
Definition: uint_types.hpp:226
static constexpr size_t base_case_size
Definition: select.hpp:37
#define LOGM
Definition: select.hpp:39
size_t my_rank() const
Global rank of this worker among all other workers in the system.
Definition: context.hpp:240
static constexpr double delta
Definition: select.hpp:35
The Context of a job is a unique instance per worker which holds references to all underlying parts o...
Definition: context.hpp:218
std::pair< ValueType, ValueType > PickPivots(const DIA< ValueType, InStack > &data, size_t size, size_t rank, const Compare &compare=Compare())
Definition: select.hpp:44
static constexpr bool debug
Definition: select.hpp:33
static uint_pair min()
return an uint_pair instance containing the smallest value possible
Definition: uint_types.hpp:217
ValueType Select(const DIA< ValueType, InStack > &data, size_t rank, const Compare &compare=Compare())
Definition: select.hpp:91
size_t num_workers() const
Global number of workers in the system.
Definition: context.hpp:248
T TLX_ATTRIBUTE_WARN_UNUSED_RESULT Broadcast(const T &value, size_t origin=0)
Broadcasts a value of a serializable type T from the master (the worker with id 0) to all other worke...
#define LOG
Default logging method: output if the local debug variable is true.
Definition: logger.hpp:24