group_by_key.hpp
/*******************************************************************************
 * thrill/api/group_by_key.hpp
 *
 * DIANode for a groupby operation. Performs the actual groupby operation.
 *
 * Part of Project Thrill - http://project-thrill.org
 *
 * Copyright (C) 2015 Huyen Chau Nguyen <[email protected]>
 * Copyright (C) 2016 Alexander Noe <[email protected]>
 *
 * All rights reserved. Published under the BSD-2 license in the LICENSE file.
 ******************************************************************************/

#pragma once
#ifndef THRILL_API_GROUP_BY_KEY_HEADER
#define THRILL_API_GROUP_BY_KEY_HEADER

#include <thrill/api/dia.hpp>
#include <thrill/api/dop_node.hpp>
#include <thrill/api/group_by_iterator.hpp>
#include <thrill/common/function_traits.hpp>
#include <thrill/common/logger.hpp>
#include <thrill/core/location_detection.hpp>
#include <thrill/core/multiway_merge.hpp>
#include <thrill/data/file.hpp>

#include <algorithm>
#include <deque>
#include <functional>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

namespace thrill {
namespace api {

/*!
 * \ingroup api_layer
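 *
 * GroupByNode is a DOpNode which performs the actual groupby operation: in
 * the PreOp each item is routed to a worker based on the hash of its
 * extracted key, the received items are sorted into runs in MainOp, and
 * PushData merges the runs and calls the user's group function once per key.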
 */
template <typename ValueType,
          typename KeyExtractor, typename GroupFunction, typename HashFunction,
          bool UseLocationDetection>
class GroupByNode final : public DOpNode<ValueType>
{
private:
    static constexpr bool debug = false;

    using Super = DOpNode<ValueType>;
    using Super::context_;

    using Key = typename common::FunctionTraits<KeyExtractor>::result_type;
    using ValueOut = ValueType;
    using ValueIn =
        typename common::FunctionTraits<KeyExtractor>::template arg_plain<0>;

    struct ValueComparator {
    public:
        explicit ValueComparator(const GroupByNode& node) : node_(node) { }

        bool operator () (const ValueIn& a, const ValueIn& b) const {
            return node_.key_extractor_(a) < node_.key_extractor_(b);
        }

    private:
        const GroupByNode& node_;
    };

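    //! HashCount pairs the hash of a key with an 8-bit counter. It is the item
    //! type aggregated by the distributed location detection, whose Flush() in
    //! Execute() yields a target worker for every hash bucket.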
    class HashCount
    {
    public:
        using HashType = size_t;
        using CounterType = uint8_t;

        size_t hash;
        CounterType count;

        static constexpr size_t counter_bits_ = 8 * sizeof(CounterType);

        HashCount operator + (const HashCount& b) const {
            assert(hash == b.hash);
            return HashCount { hash, common::AddTruncToType(count, b.count) };
        }

        HashCount& operator += (const HashCount& b) {
            assert(hash == b.hash);
            count = common::AddTruncToType(count, b.count);
            return *this;
        }

        bool operator < (const HashCount& b) const { return hash < b.hash; }

        //! Method to check whether this hash count should be broadcast to all
        //! interested workers -- for GroupByKey: always.
        bool NeedBroadcast() const {
            return true;
        }

        //! Read count from BitReader
        template <typename BitReader>
        void ReadBits(BitReader& reader) {
            count = reader.GetBits(counter_bits_);
        }

        //! Write count to BitWriter
        template <typename BitWriter>
        void WriteBits(BitWriter& writer) const {
            writer.PutBits(count, counter_bits_);
        }
    };

public:
    /*!
     * Constructor for a GroupByNode. Sets the DataManager, parent, stack,
     * key_extractor and groupby_function.
     */
    template <typename ParentDIA>
    GroupByNode(const ParentDIA& parent,
                const KeyExtractor& key_extractor,
                const GroupFunction& groupby_function,
                const HashFunction& hash_function = HashFunction())
        : Super(parent.ctx(), "GroupByKey", { parent.id() }, { parent.node() }),
          key_extractor_(key_extractor),
          groupby_function_(groupby_function),
          hash_function_(hash_function),
          location_detection_(parent.ctx(), Super::id()),
          pre_file_(context_.GetFile(this))
    {
        // Hook PreOp
        auto pre_op_fn = [=](const ValueIn& input) {
                             PreOp(input);
                         };
        // close the function stack with our pre op and register it at
        // parent node for output
        auto lop_chain = parent.stack().push(pre_op_fn).fold();
        parent.node()->AddChild(this, lop_chain);
    }

    void StartPreOp(size_t /* id */) final {
        emitters_ = stream_->GetWriters();
        pre_writer_ = pre_file_.GetWriter();
        if (UseLocationDetection)
            location_detection_.Initialize(DIABase::mem_limit_);
    }

    //! Send all elements to their designated PEs
    void PreOp(const ValueIn& v) {
        size_t hash = hash_function_(key_extractor_(v));
        if (UseLocationDetection) {
            pre_writer_.Put(v);
            location_detection_.Insert(HashCount { hash, 1 });
        }
        else {
            const size_t recipient = hash % emitters_.size();
            emitters_[recipient].Put(v);
        }
    }

    void StopPreOp(size_t /* id */) final {
        pre_writer_.Close();
    }

    DIAMemUse PreOpMemUse() final {
        return DIAMemUse::Max();
    }

    DIAMemUse ExecuteMemUse() final {
        return DIAMemUse::Max();
    }

    DIAMemUse PushDataMemUse() final {
        if (files_.size() <= 1) {
            // direct push, no merge necessary
            return 0;
        }
        else {
            // need to perform multiway merging
            return DIAMemUse::Max();
        }
    }

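    //! With location detection enabled, Execute() flushes the collected
    //! HashCounts, obtains one target worker per (hash % max_hash) bucket, and
    //! re-reads the locally buffered items from pre_file_ to ship them to
    //! those workers; afterwards the emitters are closed and MainOp() runs.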
    void Execute() override {
        if (UseLocationDetection) {
            std::unordered_map<size_t, size_t> target_processors;
            size_t max_hash = location_detection_.Flush(target_processors);
            auto file_reader = pre_file_.GetConsumeReader();
            while (file_reader.HasNext()) {
                ValueIn in = file_reader.template Next<ValueIn>();
                Key key = key_extractor_(in);

                size_t hr = hash_function_(key) % max_hash;
                auto target_processor = target_processors.find(hr);
                emitters_[target_processor->second].Put(in);
            }
        }
        // data has been pushed during pre-op -> close emitters
        emitters_.Close();

        MainOp();
    }

    void PushData(bool consume) final {
        LOG << "sort data";
        common::StatsTimerStart timer;
        const size_t num_runs = files_.size();
        if (num_runs == 0) {
            // nothing to push
        }
        else if (num_runs == 1) {
            // if there's only one run, call user funcs
            RunUserFunc(files_[0], consume);
        }
        else {
            // otherwise sort all runs using multiway merge
            size_t merge_degree, prefetch;

            // merge batches of files if necessary
            while (std::tie(merge_degree, prefetch) =
                       context_.block_pool().MaxMergeDegreePrefetch(files_.size()),
                   files_.size() > merge_degree)
            {
                sLOG1 << "Partial multi-way-merge of"
                      << merge_degree << "files with prefetch" << prefetch;

                // create merger for first merge_degree Files
                std::vector<data::File::ConsumeReader> seq;
                seq.reserve(merge_degree);

                for (size_t t = 0; t < merge_degree; ++t) {
                    seq.emplace_back(
                        files_[t].GetConsumeReader(/* prefetch */ 0));
                }

                StartPrefetch(seq, prefetch);

                auto puller = core::make_multiway_merge_tree<ValueIn>(
                    seq.begin(), seq.end(), ValueComparator(*this));

                // create new File for merged items
                files_.emplace_back(context_.GetFile(this));
                auto writer = files_.back().GetWriter();

                while (puller.HasNext()) {
                    writer.Put(puller.Next());
                }
                writer.Close();

                // this clear is important to release references to the files.
                seq.clear();

                // remove merged files
                files_.erase(files_.begin(), files_.begin() + merge_degree);
            }

            std::vector<data::File::Reader> seq;
            seq.reserve(num_runs);

            for (size_t t = 0; t < num_runs; ++t) {
                seq.emplace_back(
                    files_[t].GetReader(consume, /* prefetch */ 0));
            }

            StartPrefetch(seq, prefetch);

            LOG << "start multiwaymerge for real";
            auto puller = core::make_multiway_merge_tree<ValueIn>(
                seq.begin(), seq.end(), ValueComparator(*this));

            LOG << "run user func";
            if (puller.HasNext()) {
                // create iterator to pass to user_function
                auto user_iterator = GroupByMultiwayMergeIterator<
                    ValueIn, KeyExtractor, ValueComparator>(
                    puller, key_extractor_);

                while (user_iterator.HasNextForReal()) {
                    // call user function
                    const ValueOut res = groupby_function_(
                        user_iterator, user_iterator.GetNextKey());
                    // push result to callback functions
                    this->PushItem(res);
                }
            }
        }
        timer.Stop();
        LOG << "RESULT"
            << " name=multiwaymerge"
            << " time=" << timer
            << " multiwaymerge=" << (num_runs > 1);
    }

    void Dispose() override { }

private:
    KeyExtractor key_extractor_;
    GroupFunction groupby_function_;
    HashFunction hash_function_;

    data::CatStreamPtr stream_ { context_.GetNewCatStream(Super::id()) };
    data::CatStream::Writers emitters_;

    std::deque<data::File> files_;
    //! total number of items stored in files_
    size_t totalsize_ = 0;

    //! location detection and associated files
    core::LocationDetection<HashCount> location_detection_;
    data::File pre_file_;
    data::File::Writer pre_writer_;

    void RunUserFunc(data::File& f, bool consume) {
        auto r = f.GetReader(consume);
        if (r.HasNext()) {
            // create iterator to pass to user_function
            LOG << "get iterator";
            auto user_iterator = GroupByIterator<
                ValueIn, KeyExtractor, ValueComparator>(r, key_extractor_);
            LOG << "start running user func";
            while (user_iterator.HasNextForReal()) {
                // call user function
                const ValueOut res = groupby_function_(user_iterator,
                                                       user_iterator.GetNextKey());
                // push result to callback functions
                this->PushItem(res);
            }
            LOG << "finished user func";
        }
    }

    //! Sort and store elements in a file
    void FlushVectorToFile(std::vector<ValueIn>& v) {
        // sort the run and write it to a file
        std::sort(v.begin(), v.end(), ValueComparator(*this));
        totalsize_ += v.size();

        files_.emplace_back(context_.GetFile(this));
        data::File::Writer w = files_.back().GetWriter();
        for (const ValueIn& e : v) {
            w.Put(e);
        }
        w.Close();
    }

    //! Receive elements from other workers.
    void MainOp() {
        LOG << "running group by main op";

        std::vector<ValueIn> incoming;

        common::StatsTimerStart timer;
        // get incoming elements
        auto reader = stream_->GetCatReader(/* consume */ true);
        while (reader.HasNext()) {
            // if memory is exceeded, spill the current run to disk
            if (mem::memory_exceeded) {
                FlushVectorToFile(incoming);
                incoming.clear();
            }
            // store incoming element
            incoming.emplace_back(reader.template Next<ValueIn>());
        }
        FlushVectorToFile(incoming);
        std::vector<ValueIn>().swap(incoming);
        LOG << "finished receiving elems";
        stream_.reset();

        timer.Stop();

        LOG << "RESULT"
            << " name=mainop"
            << " time=" << timer
            << " number_files=" << files_.size();
    }
};

/******************************************************************************/

template <typename ValueType, typename Stack>
template <typename ValueOut, bool LocationDetectionValue,
          typename KeyExtractor, typename GroupFunction, typename HashFunction>
auto DIA<ValueType, Stack>::GroupByKey(
    const LocationDetectionFlag<LocationDetectionValue>&,
    const KeyExtractor& key_extractor,
    const GroupFunction& groupby_function,
    const HashFunction& hash_function) const {

    static_assert(
        std::is_same<
            typename std::decay<typename common::FunctionTraits<KeyExtractor>
                                ::template arg<0> >::type,
            ValueType>::value,
        "KeyExtractor has the wrong input type");

    using GroupByNode = api::GroupByNode<
        ValueOut, KeyExtractor, GroupFunction, HashFunction,
        LocationDetectionValue>;

    auto node = tlx::make_counting<GroupByNode>(
        *this, key_extractor, groupby_function, hash_function);

    return DIA<ValueOut>(node);
}

template <typename ValueType, typename Stack>
template <typename ValueOut, typename KeyExtractor,
          typename GroupFunction, typename HashFunction>
auto DIA<ValueType, Stack>::GroupByKey(
    const KeyExtractor& key_extractor,
    const GroupFunction& groupby_function,
    const HashFunction& hash_function) const {
    // forward to other method _without_ location detection
    return GroupByKey<ValueOut>(
        NoLocationDetectionTag, key_extractor, groupby_function, hash_function);
}

template <typename ValueType, typename Stack>
template <typename ValueOut, typename KeyExtractor, typename GroupFunction>
auto DIA<ValueType, Stack>::GroupByKey(
    const KeyExtractor& key_extractor,
    const GroupFunction& groupby_function) const {
    // forward to other method _without_ location detection
    return GroupByKey<ValueOut>(
        NoLocationDetectionTag, key_extractor, groupby_function,
        std::hash<typename FunctionTraits<KeyExtractor>::result_type>());
}

} // namespace api
} // namespace thrill

#endif // !THRILL_API_GROUP_BY_KEY_HEADER

/******************************************************************************/
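
Usage sketch (not part of the header): GroupByKey() expects a key extractor and a group function; the group function receives an iterator over all values sharing one key together with the key itself. The Generate/AllGather pipeline, the modulus 10, and the lambda names below are illustrative assumptions, not code from this file.

    #include <thrill/api/all_gather.hpp>
    #include <thrill/api/context.hpp>
    #include <thrill/api/generate.hpp>
    #include <thrill/api/group_by_key.hpp>

    #include <vector>

    void GroupByKeyExample(thrill::Context& ctx) {
        // key: residue modulo 10 (illustrative choice)
        auto key_fn = [](size_t x) { return x % 10; };

        // group function: sum all values of one key group; the iterator
        // exposes HasNext()/Next() over the values of the current key
        auto sum_fn = [](auto& iter, const size_t& /* key */) {
            size_t sum = 0;
            while (iter.HasNext()) sum += iter.Next();
            return sum;
        };

        // numbers 0..999 grouped into 10 groups, each summed by its owner
        std::vector<size_t> sums =
            thrill::api::Generate(ctx, 1000)
            .GroupByKey<size_t>(key_fn, sum_fn)
            .AllGather();
    }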