Thrill  0.1
sample.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * thrill/api/sample.hpp
3  *
4  * Part of Project Thrill - http://project-thrill.org
5  *
6  * Copyright (C) 2016 Timo Bingmann <[email protected]>
7  * Copyright (C) 2017-2018 Lorenz Hübschle-Schneider <[email protected]>
8  *
9  * All rights reserved. Published under the BSD-2 license in the LICENSE file.
10  ******************************************************************************/
11 
12 #pragma once
13 #ifndef THRILL_API_SAMPLE_HEADER
14 #define THRILL_API_SAMPLE_HEADER
15 
16 #include <thrill/api/dia.hpp>
17 #include <thrill/api/dop_node.hpp>
19 #include <thrill/common/logger.hpp>
22 
23 #include <tlx/math.hpp>
24 #include <tlx/vector_free.hpp>
25 
26 #include <algorithm>
27 #include <cassert>
28 #include <vector>
29 
30 namespace thrill {
31 namespace api {
32 
33 /*!
34  * A DIANode which performs sampling *without* replacement.
35  *
36  * The implementation is an adaptation of Algorithm P from Sanders, Lamm,
37  * Hübschle-Schneider, Schrade, Dachsbacher, ACM TOMS 2017: "Efficient Random
38  * Sampling - Parallel, Vectorized, Cache-Efficient, and Online". The
39  * modification is in how samples are assigned to workers. Instead of doing
40  * log(num_workers) splits to assign samples to ranges of workers, do
41  * O(log(input_size)) splits to assign samples to input ranges. Workers only
42  * compute the ranges which overlap their local input range, and then add up the
43  * ranges that are fully contained in their local input range. This ensures
44  * consistency while requiring only a single prefix-sum and two scalar
45  * broadcasts.
46  *
47  * \ingroup api_layer
48  */
49 template <typename ValueType>
50 class SampleNode final : public DOpNode<ValueType>
51 {
52  static constexpr bool debug = false;
53 
55  using Super::context_;
56 
57 public:
58  template <typename ParentDIA>
59  SampleNode(const ParentDIA& parent, size_t sample_size)
60  : Super(parent.ctx(), "Sample", { parent.id() }, { parent.node() }),
61  local_size_(0), sample_size_(sample_size), local_samples_(0),
62  hyp_(42 /* dummy seed */),
63  sampler_(sample_size, samples_, rng_),
64  parent_stack_empty_(ParentDIA::stack_empty) {
65  auto presample_fn = [this](const ValueType& input) {
66  sampler_.add(input);
67  };
68  auto lop_chain = parent.stack().push(presample_fn).fold();
69  parent.node()->AddChild(this, lop_chain);
70  }
71 
72  void Execute() final {
73  local_timer_.Start();
75  sLOG << "SampleNode::Execute() processing" << local_size_
76  << "elements of which" << samples_.size()
77  << "were presampled, global sample size =" << sample_size_;
78 
79  if (context_.num_workers() == 1) {
82  sLOG << "SampleNode::Execute (alone) => all"
83  << local_samples_ << "samples";
84  return;
85  }
86 
87  // Compute number of input elements left of self and total input size
88  size_t local_rank = local_size_;
89  size_t global_size = context_.net.ExPrefixSumTotal(local_rank);
90 
91  if (global_size <= sample_size_) {
92  // Requested sample is larger than the number of elements,
93  // return everything
94  assert(samples_.size() == local_size_);
96  sLOG << "SampleNode::Execute (underfull)"
97  << local_samples_ << "of" << sample_size_ << "samples";
98  return;
99  }
100 
101  // Determine and broadcast seed
102  size_t seed = 0;
103  if (context_.my_rank() == 0) {
104  seed = std::random_device { } ();
105  }
106  local_timer_.Stop(), comm_timer_.Start();
107  seed = context_.net.Broadcast(seed);
108  comm_timer_.Stop(), local_timer_.Start();
109 
110  // Calculate number of local samples by recursively splitting the range
111  // considered in half and assigning samples there
113  local_rank, local_rank + local_size_,
114  0, global_size, sample_size_, seed);
115 
116  assert(local_samples_ <= local_size_);
117  assert(local_samples_ <= samples_.size());
118 
119  local_timer_.Stop();
120  sLOG << "SampleNode::Execute"
121  << local_samples_ << "of" << sample_size_ << "samples"
122  << "(got" << local_size_ << "=>"
123  << samples_.size() << "elements),"
124  << "communication time:" << comm_timer_.Microseconds() / 1000.0;
125  }
126 
127  void PushData(bool consume) final {
128  // don't start global timer in pushdata!
129  common::StatsTimerStart push_timer;
130 
131  sLOGC(local_samples_ > samples_.size())
132  << "WTF ERROR CAN'T DRAW" << local_samples_ << "FROM"
133  << samples_.size() << "PRESAMPLES";
134 
135  // Most likely, we'll need to draw the requested number of samples from
136  // the presample that we computed in the PreOp
137  if (local_samples_ < samples_.size()) {
138  sLOG << "Drawing" << local_samples_ << "samples locally from"
139  << samples_.size() << "pre-samples";
140  std::vector<ValueType> subsample;
142  subsampler(samples_.begin(), samples_.end(),
143  local_samples_, subsample);
144  samples_.swap(subsample);
145  LOGC(samples_.size() != local_samples_)
146  << "ERROR: SAMPLE SIZE IS WRONG";
147  }
148  push_timer.Stop(); // don't measure PushItem
149  local_timer_ += push_timer;
150 
151  for (const ValueType& v : samples_) {
152  this->PushItem(v);
153  }
154  if (consume)
155  tlx::vector_free(samples_);
156 
157  sLOG << "SampleNode::PushData finished; total local time excl PushData:"
158  << local_timer_.Microseconds() / 1000.0
159  << "ms, communication:" << comm_timer_.Microseconds() / 1000.0
160  << "ms =" << comm_timer_.Microseconds() * 100.0 /
161  (comm_timer_.Microseconds() + local_timer_.Microseconds())
162  << "%";
163  }
164 
165  void Dispose() final {
167  }
168 
169 private:
170  size_t hash_combine(size_t seed, size_t v) {
171  // technically v needs to be hashed...
172  return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
173  }
174 
175  // ranges are exclusive, like iterator begin / end
176  size_t calc_local_samples(size_t my_begin, size_t my_end,
177  size_t range_begin, size_t range_end,
178  size_t sample_size, size_t seed) {
179  // handle empty ranges and case without any samples
180  if (range_begin >= range_end) return 0;
181  if (my_begin >= my_end) return 0;
182  if (sample_size == 0) return 0;
183 
184  // is the range contained in my part? then all samples are mine
185  if (my_begin <= range_begin && range_end <= my_end) {
186  LOG << "my range [" << my_begin << ", " << my_end
187  << ") is contained in the currently considered range ["
188  << range_begin << ", " << range_end << ") and thus gets all "
189  << sample_size << " samples";
190  return sample_size;
191  }
192 
193  // does my range overlap the global range?
194  if ((range_begin <= my_begin && my_begin < range_end) ||
195  (range_begin < my_end && my_end <= range_end)) {
196 
197  // seed the distribution so that all PEs generate the same values in
198  // the same subtrees, but different values in different subtrees
199  size_t new_seed = hash_combine(hash_combine(seed, range_begin),
200  range_end);
201  hyp_.seed(new_seed);
202 
203  const size_t left_size = (range_end - range_begin) / 2,
204  right_size = (range_end - range_begin) - left_size;
205  const size_t left_samples = hyp_(left_size, right_size, sample_size);
206 
207  LOG << "my range [" << my_begin << ", " << my_end
208  << ") overlaps the currently considered range ["
209  << range_begin << ", " << range_end << "), splitting: "
210  << "left range [" << range_begin << ", " << range_begin + left_size
211  << ") gets " << left_samples << " samples, right range ["
212  << range_begin + left_size << ", " << range_end << ") the remaining "
213  << sample_size - left_samples << " for a total of " << sample_size
214  << " samples";
215 
216  const size_t
217  left_result = calc_local_samples(
218  my_begin, my_end, range_begin, range_begin + left_size,
219  left_samples, seed),
220  right_result = calc_local_samples(
221  my_begin, my_end, range_begin + left_size, range_end,
222  sample_size - left_samples, seed);
223  return left_result + right_result;
224  }
225 
226  // no overlap
227  return 0;
228  }
229 
230  //! local input size, number of samples to draw globally, and locally
232  //! local samples
233  std::vector<ValueType> samples_;
234  //! Hypergeometric distribution to calculate local sample sizes
236  //! Random generator for reservoir sampler
237  std::mt19937_64 rng_ { std::random_device { } () };
238  //! Reservoir sampler for pre-op
240  //! Timers for local work and communication
242  //! Whether the parent stack is empty
244 };
245 
246 template <typename ValueType, typename Stack>
247 auto DIA<ValueType, Stack>::Sample(size_t sample_size) const {
248  assert(IsValid());
249 
251 
252  auto node = tlx::make_counting<SampleNode>(
253  *this, sample_size);
254 
255  return DIA<ValueType>(node);
256 }
257 
258 } // namespace api
259 } // namespace thrill
260 
261 #endif // !THRILL_API_SAMPLE_HEADER
262 
263 /******************************************************************************/
std::mt19937_64 rng_
Random generator for reservoir sampler.
Definition: sample.hpp:237
common::StatsTimerStopped local_timer_
Timers for local work and communication.
Definition: sample.hpp:241
net::FlowControlChannel & net
Definition: context.hpp:446
#define sLOG
Default logging method: output if the local debug variable is true.
Definition: logger.hpp:34
DIA is the interface between the user and the Thrill framework.
Definition: dia.hpp:141
void PushItem(const ValueType &item) const
Method for derived classes to Push a single item to all children.
Definition: dia_node.hpp:147
static constexpr bool debug
Definition: sample.hpp:52
void add(const Type &item)
visit item, maybe add it to the sample.
SampleNode(const ParentDIA &parent, size_t sample_size)
Definition: sample.hpp:59
size_t num_workers() const
Global number of workers in the system.
Definition: context.hpp:251
#define sLOGC(cond)
Explicitly specify the condition for logging.
Definition: logger.hpp:31
size_t local_size_
local input size, number of samples to draw globally, and locally
Definition: sample.hpp:231
void Execute() final
Virtual execution method. Triggers actual computation in sub-classes.
Definition: sample.hpp:72
size_t count() const
number of items seen
size_t calc_local_samples(size_t my_begin, size_t my_end, size_t range_begin, size_t range_end, size_t sample_size, size_t seed)
Definition: sample.hpp:176
auto Sample(size_t sample_size) const
Select up to sample_size items uniformly at random and return a new DIA<T>.
Definition: sample.hpp:247
void PushData(bool consume) final
Virtual method for pushing data. Triggers actual pushing in sub-classes.
Definition: sample.hpp:127
common::ReservoirSamplingFast< ValueType, decltype(rng_)> sampler_
Reservoir sampler for pre-op.
Definition: sample.hpp:239
common::hypergeometric hyp_
Hypergeometric distribution to calculate local sample sizes.
Definition: sample.hpp:235
size_t my_rank() const
Global rank of this worker among all other workers in the system.
Definition: context.hpp:243
unsigned seed
T TLX_ATTRIBUTE_WARN_UNUSED_RESULT ExPrefixSumTotal(T &value, const BinarySumOp &sum_op=BinarySumOp(), const T &initial=T())
Calculates the exclusive prefix sum over all workers, and delivers the total sum as well.
void vector_free(std::vector< Type > &v)
Definition: vector_free.hpp:21
const bool parent_stack_empty_
Whether the parent stack is empty.
Definition: sample.hpp:243
A DOpNode is a typed node representing distributed operations in Thrill.
Definition: dop_node.hpp:32
static uint_pair min()
return an uint_pair instance containing the smallest value possible
Definition: uint_types.hpp:217
A DIANode which performs sampling without replacement.
Definition: sample.hpp:50
void Dispose() final
Virtual clear method. Triggers actual disposing in sub-classes.
Definition: sample.hpp:165
size_t hash_combine(size_t seed, size_t v)
Definition: sample.hpp:170
std::vector< ValueType > samples_
local samples
Definition: sample.hpp:233
common::StatsTimerStopped comm_timer_
Definition: sample.hpp:241
T TLX_ATTRIBUTE_WARN_UNUSED_RESULT Broadcast(const T &value, size_t origin=0)
Broadcasts a value of a serializable type T from the master (the worker with id 0) to all other workers.
#define LOG
Default logging method: output if the local debug variable is true.
Definition: logger.hpp:24
Context & context_
associated Context
Definition: dia_base.hpp:293
#define LOGC(cond)
Explicitly specify the condition for logging.
Definition: logger.hpp:21