Thrill  0.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
sample.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * thrill/api/sample.hpp
3  *
4  * Part of Project Thrill - http://project-thrill.org
5  *
6  * Copyright (C) 2016 Timo Bingmann <[email protected]>
7  *
8  * All rights reserved. Published under the BSD-2 license in the LICENSE file.
9  ******************************************************************************/
10 
11 #pragma once
12 #ifndef THRILL_API_SAMPLE_HEADER
13 #define THRILL_API_SAMPLE_HEADER
14 
15 #include <thrill/api/dia.hpp>
16 #include <thrill/api/dop_node.hpp>
18 #include <thrill/common/logger.hpp>
21 
22 #include <tlx/math.hpp>
23 
24 #include <algorithm>
25 #include <cassert>
26 #include <vector>
27 
28 namespace thrill {
29 namespace api {
30 
31 /*!
32  * A DIANode which performs sampling *without* replacement.
33  *
34  * The implementation is an adaptation of Algorithm P from Sanders, Lamm,
35  * Hübschle-Schneider, Schrade, Dachsbacher, ACM TOMS 2017: "Efficient Random
36  * Sampling - Parallel, Vectorized, Cache-Efficient, and Online". The
37  * modification is in how samples are assigned to workers. Instead of doing
38  * log(num_workers) splits to assign samples to ranges of workers, do
39  * O(log(input_size)) splits to assign samples to input ranges. Workers only
40  * compute the ranges which overlap their local input range, and then add up the
41  * ranges that are fully contained in their local input range. This ensures
42  * consistency while requiring only a single prefix-sum and two scalar
43  * broadcasts.
44  *
45  * \ingroup api_layer
46  */
47 template <typename ValueType>
48 class SampleNode final : public DOpNode<ValueType>
49 {
50  static constexpr bool debug = false;
51 
53  using Super::context_;
54 
55 public:
56  template <typename ParentDIA>
57  SampleNode(const ParentDIA& parent, size_t sample_size)
58  : Super(parent.ctx(), "Sample", { parent.id() }, { parent.node() }),
59  local_size_(0), sample_size_(sample_size), local_samples_(0),
60  hyp_(42 /* dummy seed */),
61  sampler_(sample_size, samples_, rng_),
62  parent_stack_empty_(ParentDIA::stack_empty)
63  {
64  auto presample_fn = [this](const ValueType& input) {
65  sampler_.add(input);
66  };
67  auto lop_chain = parent.stack().push(presample_fn).fold();
68  parent.node()->AddChild(this, lop_chain);
69  }
70 
71  void Execute() final {
72  local_timer_.Start();
73  local_size_ = sampler_.count();
74  sLOG << "SampleNode::Execute() processing" << local_size_
75  << "elements of which" << samples_.size()
76  << "were presampled, global sample size =" << sample_size_;
77 
78  if (context_.num_workers() == 1) {
79  sample_size_ = std::min(sample_size_, local_size_);
80  local_samples_ = sample_size_;
81  sLOG << "SampleNode::Execute (alone) => all"
82  << local_samples_ << "samples";
83  return;
84  }
85 
86  // Compute number of input elements left of self and total input size
87  size_t local_rank = local_size_;
88  size_t global_size = context_.net.ExPrefixSumTotal(local_rank);
89 
90  if (global_size <= sample_size_) {
91  // Requested sample is larger than the number of elements,
92  // return everything
93  assert(samples_.size() == local_size_);
94  local_samples_ = local_size_;
95  sLOG << "SampleNode::Execute (underfull)"
96  << local_samples_ << "of" << sample_size_ << "samples";
97  return;
98  }
99 
100  // Determine and broadcast seed
101  size_t seed = 0;
102  if (context_.my_rank() == 0) {
103  seed = std::random_device { } ();
104  }
105  local_timer_.Stop(), comm_timer_.Start();
106  seed = context_.net.Broadcast(seed);
107  comm_timer_.Stop(), local_timer_.Start();
108 
109  // Calculate number of local samples by recursively splitting the range
110  // considered in half and assigning samples there
111  local_samples_ = calc_local_samples(
112  local_rank, local_rank + local_size_,
113  0, global_size, sample_size_, seed);
114 
115  assert(local_samples_ <= local_size_);
116  assert(local_samples_ <= samples_.size());
117 
118  local_timer_.Stop();
119  sLOG << "SampleNode::Execute"
120  << local_samples_ << "of" << sample_size_ << "samples"
121  << "(got" << local_size_ << "=>"
122  << samples_.size() << "elements),"
123  << "communication time:" << comm_timer_.Microseconds() / 1000.0;
124  }
125 
126  void PushData(bool consume) final {
127  // don't start global timer in pushdata!
128  common::StatsTimerStart push_timer;
129 
130  sLOGC(local_samples_ > samples_.size())
131  << "WTF ERROR CAN'T DRAW" << local_samples_ << "FROM"
132  << samples_.size() << "PRESAMPLES";
133 
134  // Most likely, we'll need to draw the requested number of samples from
135  // the presample that we computed in the PreOp
136  if (local_samples_ < samples_.size()) {
137  sLOG << "Drawing" << local_samples_ << "samples locally from"
138  << samples_.size() << "pre-samples";
139  std::vector<ValueType> subsample;
140  common::Sampling<> subsampler(rng_);
141  subsampler(samples_.begin(), samples_.end(),
142  local_samples_, subsample);
143  samples_.swap(subsample);
144  LOGC(samples_.size() != local_samples_)
145  << "ERROR: SAMPLE SIZE IS WRONG";
146  }
147  push_timer.Stop(); // don't measure PushItem
148  local_timer_ += push_timer;
149 
150  for (const ValueType& v : samples_) {
151  this->PushItem(v);
152  }
153  if (consume)
154  std::vector<ValueType>().swap(samples_);
155 
156  sLOG << "SampleNode::PushData finished; total local time excl PushData:"
157  << local_timer_.Microseconds() / 1000.0
158  << "ms, communication:" << comm_timer_.Microseconds() / 1000.0
159  << "ms =" << comm_timer_.Microseconds() * 100.0 /
160  (comm_timer_.Microseconds() + local_timer_.Microseconds())
161  << "%";
162  }
163 
164  void Dispose() final {
165  std::vector<ValueType>().swap(samples_);
166  }
167 
168 private:
169  size_t hash_combine(size_t seed, size_t v) {
170  // technically v needs to be hashed...
171  return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
172  }
173 
174  // ranges are exclusive, like iterator begin / end
175  size_t calc_local_samples(size_t my_begin, size_t my_end,
176  size_t range_begin, size_t range_end,
177  size_t sample_size, size_t seed) {
178  // handle empty ranges and case without any samples
179  if (range_begin >= range_end) return 0;
180  if (my_begin >= my_end) return 0;
181  if (sample_size == 0) return 0;
182 
183  // is the range contained in my part? then all samples are mine
184  if (my_begin <= range_begin && range_end <= my_end) {
185  LOG << "my range [" << my_begin << ", " << my_end
186  << ") is contained in the currently considered range ["
187  << range_begin << ", " << range_end << ") and thus gets all "
188  << sample_size << " samples";
189  return sample_size;
190  }
191 
192  // does my range overlap the global range?
193  if ((range_begin <= my_begin && my_begin < range_end) ||
194  (range_begin < my_end && my_end <= range_end)) {
195 
196  // seed the distribution so that all PEs generate the same values in
197  // the same subtrees, but different values in different subtrees
198  size_t new_seed = hash_combine(hash_combine(seed, range_begin),
199  range_end);
200  hyp_.seed(new_seed);
201 
202  const size_t left_size = (range_end - range_begin) / 2,
203  right_size = (range_end - range_begin) - left_size;
204  const size_t left_samples = hyp_(left_size, right_size, sample_size);
205 
206  LOG << "my range [" << my_begin << ", " << my_end
207  << ") overlaps the currently considered range ["
208  << range_begin << ", " << range_end << "), splitting: "
209  << "left range [" << range_begin << ", " << range_begin + left_size
210  << ") gets " << left_samples << " samples, right range ["
211  << range_begin + left_size << ", " << range_end << ") the remaining "
212  << sample_size - left_samples << " for a total of " << sample_size
213  << " samples";
214 
215  const size_t
216  left_result = calc_local_samples(
217  my_begin, my_end, range_begin, range_begin + left_size,
218  left_samples, seed),
219  right_result = calc_local_samples(
220  my_begin, my_end, range_begin + left_size, range_end,
221  sample_size - left_samples, seed);
222  return left_result + right_result;
223  }
224 
225  // no overlap
226  return 0;
227  }
228 
229  //! local input size, number of samples to draw globally, and locally
230  size_t local_size_, sample_size_, local_samples_;
231  //! local samples
232  std::vector<ValueType> samples_;
233  //! Hypergeometric distribution to calculate local sample sizes
235  //! Random generator for reservoir sampler
236  std::mt19937 rng_ { std::random_device { } () };
237  //! Reservoir sampler for pre-op
239  //! Timers for local work and communication
241  //! Whether the parent stack is empty
243 };
244 
245 template <typename ValueType, typename Stack>
246 auto DIA<ValueType, Stack>::Sample(size_t sample_size) const {
247  assert(IsValid());
248 
250 
251  auto node = tlx::make_counting<SampleNode>(
252  *this, sample_size);
253 
254  return DIA<ValueType>(node);
255 }
256 
257 } // namespace api
258 } // namespace thrill
259 
260 #endif // !THRILL_API_SAMPLE_HEADER
261 
262 /******************************************************************************/
common::StatsTimerStopped local_timer_
Timers for local work and communication.
Definition: sample.hpp:240
T TLX_ATTRIBUTE_WARN_UNUSED_RESULT ExPrefixSumTotal(T &value, const T &initial=T(), const BinarySumOp &sum_op=BinarySumOp())
Calculates the exclusive prefix sum over all workers, and delivers the total sum as well...
net::FlowControlChannel & net
Definition: context.hpp:443
virtual void Dispose()
Virtual clear method. Triggers actual disposing in sub-classes.
Definition: dia_base.hpp:188
DIA is the interface between the user and the Thrill framework.
Definition: dia.hpp:141
auto Sample(size_t sample_size) const
Select up to sample_size items uniformly at random and return a new DIA<T>.
Definition: sample.hpp:246
size_t my_rank() const
Global rank of this worker among all other workers in the system.
Definition: context.hpp:240
static constexpr bool debug
Definition: sample.hpp:50
void add(const Type &item)
visit item, maybe add it to the sample.
SampleNode(const ParentDIA &parent, size_t sample_size)
Definition: sample.hpp:57
virtual void PushData(bool consume)=0
Virtual method for pushing data. Triggers actual pushing in sub-classes.
#define sLOG
Default logging method: output if the local debug variable is true.
Definition: logger.hpp:184
void PushItem(const ValueType &item) const
Method for derived classes to Push a single item to all children.
Definition: dia_node.hpp:147
void swap(CountingPtr< A, D > &a1, CountingPtr< A, D > &a2) noexcept
#define LOGC(cond)
Explicitly specify the condition for logging.
Definition: logger.hpp:167
virtual void Execute()=0
Virtual execution method. Triggers actual computation in sub-classes.
StatsTimerBaseStarted< true > StatsTimerStart
common::ReservoirSamplingFast< ValueType, decltype(rng_)> sampler_
Reservoir sampler for pre-op.
Definition: sample.hpp:238
common::hypergeometric hyp_
Hypergeometric distribution to calculate local sample sizes.
Definition: sample.hpp:234
size_t count() const
number of items seen
#define sLOGC(cond)
Explicitly specify the condition for logging.
Definition: logger.hpp:179
unsigned seed
const bool parent_stack_empty_
Whether the parent stack is empty.
Definition: sample.hpp:242
A DOpNode is a typed node representing a distributed operation in Thrill.
Definition: dop_node.hpp:32
A DIANode which performs sampling without replacement.
Definition: sample.hpp:48
std::mt19937 rng_
Random generator for reservoir sampler.
Definition: sample.hpp:236
size_t num_workers() const
Global number of workers in the system.
Definition: context.hpp:248
static constexpr const T & min(const T &a, const T &b)
template for constexpr min, because std::min is not good enough.
Definition: functional.hpp:59
std::vector< ValueType > samples_
local samples
Definition: sample.hpp:232
common::StatsTimerStopped comm_timer_
Definition: sample.hpp:240
T TLX_ATTRIBUTE_WARN_UNUSED_RESULT Broadcast(const T &value, size_t origin=0)
Broadcasts a value of a serializable type T from the master (the worker with id 0) to all other worke...
Context & context_
associated Context
Definition: dia_base.hpp:293
#define LOG
Default logging method: output if the local debug variable is true.
Definition: logger.hpp:172