Thrill  0.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
group_by_key.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * thrill/api/group_by_key.hpp
3  *
4  * DIANode for a groupby operation. Performs the actual groupby operation
5  *
6  * Part of Project Thrill - http://project-thrill.org
7  *
8  * Copyright (C) 2015 Huyen Chau Nguyen <[email protected]>
9  * Copyright (C) 2016 Alexander Noe <[email protected]>
10  *
11  * All rights reserved. Published under the BSD-2 license in the LICENSE file.
12  ******************************************************************************/
13 
14 #pragma once
15 #ifndef THRILL_API_GROUP_BY_KEY_HEADER
16 #define THRILL_API_GROUP_BY_KEY_HEADER
17 
18 #include <thrill/api/dia.hpp>
19 #include <thrill/api/dop_node.hpp>
22 #include <thrill/common/logger.hpp>
25 #include <thrill/data/file.hpp>
26 
27 #include <tlx/vector_free.hpp>
28 
29 #include <algorithm>
30 #include <deque>
31 #include <functional>
32 #include <type_traits>
33 #include <typeinfo>
34 #include <unordered_map>
35 #include <utility>
36 #include <vector>
37 
38 namespace thrill {
39 namespace api {
40 
41 /*!
42  * \ingroup api_layer
43  */
44 template <typename ValueType,
45  typename KeyExtractor, typename GroupFunction, typename HashFunction,
46  bool UseLocationDetection>
47 class GroupByNode final : public DOpNode<ValueType>
48 {
49 private:
50  static constexpr bool debug = false;
51 
53  using Super::context_;
54 
55  using Key = typename common::FunctionTraits<KeyExtractor>::result_type;
56  using ValueOut = ValueType;
57  using ValueIn =
58  typename common::FunctionTraits<KeyExtractor>::template arg_plain<0>;
59 
60  struct ValueComparator {
61  public:
62  explicit ValueComparator(const GroupByNode& node) : node_(node) { }
63 
64  bool operator () (const ValueIn& a, const ValueIn& b) const {
65  return node_.key_extractor_(a) < node_.key_extractor_(b);
66  }
67 
68  private:
70  };
71 
72  class HashCount
73  {
74  public:
75  using HashType = size_t;
76  using CounterType = uint8_t;
77 
78  size_t hash;
80 
81  static constexpr size_t counter_bits_ = 8 * sizeof(CounterType);
82 
83  HashCount operator + (const HashCount& b) const {
84  assert(hash == b.hash);
86  }
87 
89  assert(hash == b.hash);
91  return *this;
92  }
93 
94  bool operator < (const HashCount& b) const { return hash < b.hash; }
95 
96  //! method to check if this hash count should be broadcasted to all
97  //! workers interested -- for GroupByKey -> always.
98  bool NeedBroadcast() const {
99  return true;
100  }
101 
102  //! Read count from BitReader
103  template <typename BitReader>
104  void ReadBits(BitReader& reader) {
105  count = reader.GetBits(counter_bits_);
106  }
107 
108  //! Write count and dia_mask to BitWriter
109  template <typename BitWriter>
110  void WriteBits(BitWriter& writer) const {
111  writer.PutBits(count, counter_bits_);
112  }
113  };
114 
115 public:
116  /*!
117  * Constructor for a GroupByNode. Stores key_extractor, groupby_function
118  * and hash_function, then registers PreOp as a child of the parent node.
     *
     * \param parent           DIA whose items are grouped; its ctx(), id(),
     *                         node() and stack() are used here
     * \param key_extractor    maps an input item (ValueIn) to its Key
     * \param groupby_function called with a group iterator and that group's
     *                         key, returns one ValueOut (see RunUserFunc)
     * \param hash_function    hashes a Key; PreOp uses the hash to choose
     *                         the receiving worker
119  */
120  template <typename ParentDIA>
121  GroupByNode(const ParentDIA& parent,
122  const KeyExtractor& key_extractor,
123  const GroupFunction& groupby_function,
124  const HashFunction& hash_function = HashFunction())
125  : Super(parent.ctx(), "GroupByKey", { parent.id() }, { parent.node() }),
126  key_extractor_(key_extractor),
127  groupby_function_(groupby_function),
128  hash_function_(hash_function),
129  location_detection_(parent.ctx(), Super::dia_id()),
130  pre_file_(context_.GetFile(this)) {
131  // Hook PreOp
     // NOTE: [=] captures `this` implicitly (deprecated since C++20); the
     // lambda only forwards each item to the member function PreOp().
132  auto pre_op_fn = [=](const ValueIn& input) {
133  PreOp(input);
134  };
135  // close the function stack with our pre op and register it at
136  // parent node for output
137  auto lop_chain = parent.stack().push(pre_op_fn).fold();
138  parent.node()->AddChild(this, lop_chain);
139  }
140 
141  void StartPreOp(size_t /* parent_index */) final {
142  emitters_ = stream_->GetWriters();
144  if (UseLocationDetection)
146  }
147 
148  //! Route one input item by the hash of its key.
     //! With UseLocationDetection, the item is buffered via pre_writer_ and
     //! its hash is recorded in location_detection_ with count 1. The actual
     //! send happens later in Execute(). Otherwise, the item is sent right
     //! away to worker (hash % emitters_.size()).
     //! NOTE(review): pre_writer_ is assumed to write into pre_file_, which
     //! Execute() reads back. Its setup line is not visible in this listing.
149  void PreOp(const ValueIn& v) {
150  size_t hash = hash_function_(key_extractor_(v));
151  if (UseLocationDetection) {
152  pre_writer_.Put(v);
153  location_detection_.Insert(HashCount { hash, 1 });
154  }
155  else {
156  const size_t recipient = hash % emitters_.size();
157  emitters_[recipient].Put(v);
158  }
159  }
160 
161  void StopPreOp(size_t /* parent_index */) final {
     // Finish the pre-op phase by closing pre_writer_. This also happens
     // when UseLocationDetection is false, in which case PreOp() never
     // wrote to it.
162  pre_writer_.Close();
163  }
164 
165  DIAMemUse PreOpMemUse() final {
     // Request the maximum memory share for the pre-op phase.
166  return DIAMemUse::Max();
167  }
168 
169  DIAMemUse ExecuteMemUse() final {
     // Execute() runs MainOp(). MainOp() collects received items in a vector
     // and flushes sorted runs to files_ once mem::memory_exceeded is set,
     // so request the maximum memory share.
170  return DIAMemUse::Max();
171  }
172 
173  DIAMemUse PushDataMemUse() final {
     // With zero or one sorted run, PushData() needs no extra memory: it
     // either does nothing or streams the single run into the user function.
     // Two or more runs need a multiway merge, so request the maximum share.
174  if (files_.size() <= 1) {
175  // direct push, no merge necessary
176  return 0;
177  }
178  else {
179  // need to perform multiway merging
180  return DIAMemUse::Max();
181  }
182  }
183 
184  void Execute() override {
     // With location detection: Flush() is given target_processors
     // (presumably filled with hash -> worker index) and returns max_hash.
     // The items buffered in pre_file_ are then consumed and each one is sent
     // to the worker looked up for its reduced hash. Without location
     // detection, PreOp() already sent every item. In both cases the stream
     // writers are closed and MainOp() receives and sorts the data.
185  if (UseLocationDetection) {
186  std::unordered_map<size_t, size_t> target_processors;
187  size_t max_hash = location_detection_.Flush(target_processors);
188  auto file_reader = pre_file_.GetConsumeReader();
189  while (file_reader.HasNext()) {
190  ValueIn in = file_reader.template Next<ValueIn>();
191  Key key = key_extractor_(in);
192 
     // NOTE(review): assumes max_hash > 0. The hash is reduced modulo
     // max_hash, presumably to match the keys reported by Flush(). Confirm.
193  size_t hr = hash_function_(key) % max_hash;
     // NOTE(review): find() is dereferenced without checking against end().
     // A hash missing from target_processors would be undefined behavior.
194  auto target_processor = target_processors.find(hr);
195  emitters_[target_processor->second].Put(in);
196  }
197  }
198  // all data has been sent (in PreOp or in the loop above) -> close emitters
199  emitters_.Close();
200 
201  MainOp();
202  }
203 
204  void PushData(bool consume) final {
205  LOG << "sort data";
207  const size_t num_runs = files_.size();
208  if (num_runs == 0) {
209  // nothing to push
210  }
211  else if (num_runs == 1) {
212  // if there's only one run, call user funcs
213  RunUserFunc(files_[0], consume);
214  }
215  else {
216  // otherwise sort all runs using multiway merge
217  size_t merge_degree, prefetch;
218 
219  // merge batches of files if necessary
220  while (std::tie(merge_degree, prefetch) =
222  files_.size() > merge_degree)
223  {
224  sLOG1 << "Partial multi-way-merge of"
225  << merge_degree << "files with prefetch" << prefetch;
226 
227  // create merger for first merge_degree_ Files
228  std::vector<data::File::ConsumeReader> seq;
229  seq.reserve(merge_degree);
230 
231  for (size_t t = 0; t < merge_degree; ++t) {
232  seq.emplace_back(
233  files_[t].GetConsumeReader(/* prefetch */ 0));
234  }
235 
236  StartPrefetch(seq, prefetch);
237 
238  auto puller = core::make_multiway_merge_tree<ValueIn>(
239  seq.begin(), seq.end(), ValueComparator(*this));
240 
241  // create new File for merged items
242  files_.emplace_back(context_.GetFile(this));
243  auto writer = files_.back().GetWriter();
244 
245  while (puller.HasNext()) {
246  writer.Put(puller.Next());
247  }
248  writer.Close();
249 
250  // this clear is important to release references to the files.
251  seq.clear();
252 
253  // remove merged files
254  files_.erase(files_.begin(), files_.begin() + merge_degree);
255  }
256 
257  std::vector<data::File::Reader> seq;
258  seq.reserve(num_runs);
259 
260  for (size_t t = 0; t < num_runs; ++t) {
261  seq.emplace_back(
262  files_[t].GetReader(consume, /* prefetch */ 0));
263  }
264 
265  StartPrefetch(seq, prefetch);
266 
267  LOG << "start multiwaymerge for real";
268  auto puller = core::make_multiway_merge_tree<ValueIn>(
269  seq.begin(), seq.end(), ValueComparator(*this));
270 
271  LOG << "run user func";
272  if (puller.HasNext()) {
273  // create iterator to pass to user_function
274  auto user_iterator = GroupByMultiwayMergeIterator<
275  ValueIn, KeyExtractor, ValueComparator>(
276  puller, key_extractor_);
277 
278  while (user_iterator.HasNextForReal()) {
279  // call user function
280  const ValueOut res = groupby_function_(
281  user_iterator, user_iterator.GetNextKey());
282  // push result to callback functions
283  this->PushItem(res);
284  }
285  }
286  }
287  timer.Stop();
288  LOG << "RESULT"
289  << " name=multiwaymerge"
290  << " time=" << timer
291  << " multiwaymerge=" << (num_runs > 1);
292  }
293 
294  void Dispose() override { }
     // Empty override: this node does nothing in Dispose().
295 
296 private:
297  KeyExtractor key_extractor_;
298  GroupFunction groupby_function_;
299  HashFunction hash_function_;
300 
302 
305 
306  std::deque<data::File> files_;
308  size_t totalsize_ = 0;
309 
310  //! location detection and associated files
313 
314  void RunUserFunc(data::File& f, bool consume) {
     // Single-run path of PushData(). Reads File f, which is a key-sorted run
     // written by FlushVectorToFile(), and consumes it if `consume` is true.
     // A GroupByIterator walks the run, and groupby_function_ is called for
     // each group it yields with (iterator, group key). Every result is
     // pushed to this node's children. Does nothing if f is empty.
315  auto r = f.GetReader(consume);
316  if (r.HasNext()) {
317  // create iterator to pass to user_function
318  LOG << "get iterator";
319  auto user_iterator = GroupByIterator<
320  ValueIn, KeyExtractor, ValueComparator>(r, key_extractor_);
321  LOG << "start running user func";
322  while (user_iterator.HasNextForReal()) {
323  // call user function
324  const ValueOut res = groupby_function_(user_iterator,
325  user_iterator.GetNextKey());
326  // push result to callback functions
327  this->PushItem(res);
328  }
329  LOG << "finished user func";
330  }
331  }
332 
333  //! Sort v by key (in place) and append it to files_ as a new sorted run.
     //! v is not cleared here; the caller does that. MainOp() also calls this
     //! unconditionally after its receive loop, so v may be empty, in which
     //! case an empty File is appended. totalsize_ grows by v.size().
     //! Note: std::sort is not stable, so items with equal keys may be
     //! reordered.
334  void FlushVectorToFile(std::vector<ValueIn>& v) {
335  // sort run and write it to a new file
336  std::sort(v.begin(), v.end(), ValueComparator(*this));
337  totalsize_ += v.size();
338 
339  files_.emplace_back(context_.GetFile(this));
340  data::File::Writer w = files_.back().GetWriter();
341  for (const ValueIn& e : v) {
342  w.Put(e);
343  }
344  w.Close();
345  }
346 
347  //! Receive elements from other workers.
348  void MainOp() {
349  LOG << "running group by main op";
350 
351  std::vector<ValueIn> incoming;
352 
354  // get incoming elements
355  auto reader = stream_->GetCatReader(/* consume */ true);
356  while (reader.HasNext()) {
357  // if vector is full save to disk
358  if (mem::memory_exceeded) {
359  FlushVectorToFile(incoming);
360  incoming.clear();
361  }
362  // store incoming element
363  incoming.emplace_back(reader.template Next<ValueIn>());
364  }
365  FlushVectorToFile(incoming);
366  tlx::vector_free(incoming);
367  LOG << "finished receiving elems";
368  stream_.reset();
369 
370  timer.Stop();
371 
372  LOG << "RESULT"
373  << " name=mainop"
374  << " time=" << timer
375  << " number_files=" << files_.size();
376  }
377 };
378 
379 /******************************************************************************/
380 
381 template <typename ValueType, typename Stack>
382 template <typename ValueOut, bool LocationDetectionValue,
383  typename KeyExtractor, typename GroupFunction, typename HashFunction>
386  const KeyExtractor& key_extractor,
387  const GroupFunction& groupby_function,
388  const HashFunction& hash_function) const {
389 
390  static_assert(
391  std::is_same<
392  typename std::decay<typename common::FunctionTraits<KeyExtractor>
393  ::template arg<0> >::type,
394  ValueType>::value,
395  "KeyExtractor has the wrong input type");
396 
398  ValueOut, KeyExtractor, GroupFunction, HashFunction,
399  LocationDetectionValue>;
400 
401  auto node = tlx::make_counting<GroupByNode>(
402  *this, key_extractor, groupby_function, hash_function);
403 
404  return DIA<ValueOut>(node);
405 }
406 
407 template <typename ValueType, typename Stack>
408 template <typename ValueOut, typename KeyExtractor,
409  typename GroupFunction, typename HashFunction>
411  const KeyExtractor& key_extractor,
412  const GroupFunction& groupby_function,
413  const HashFunction& hash_function) const {
414  // forward to other method _without_ location detection
415  return GroupByKey<ValueOut>(
416  NoLocationDetectionTag, key_extractor, groupby_function, hash_function);
417 }
418 
419 template <typename ValueType, typename Stack>
420 template <typename ValueOut, typename KeyExtractor, typename GroupFunction>
422  const KeyExtractor& key_extractor,
423  const GroupFunction& groupby_function) const {
424  // forward to other method _without_ location detection
425  return GroupByKey<ValueOut>(
426  NoLocationDetectionTag, key_extractor, groupby_function,
427  std::hash<typename FunctionTraits<KeyExtractor>::result_type>());
428 }
429 
430 } // namespace api
431 } // namespace thrill
432 
433 #endif // !THRILL_API_GROUP_BY_KEY_HEADER
434 
435 /******************************************************************************/
void StartPrefetch(std::vector< Reader > &readers, size_t prefetch_size)
Take a vector of Readers and prefetch equally from them.
Definition: file.hpp:585
static DIAMemUse Max()
Definition: dia_base.hpp:60
virtual void Dispose()
Virtual clear method. Triggers actual disposing in sub-classes.
Definition: dia_base.hpp:188
DIA is the interface between the user and the Thrill framework.
Definition: dia.hpp:141
core::LocationDetection< HashCount > location_detection_
ValueType_ ValueType
Definition: dia.hpp:152
virtual DIAMemUse ExecuteMemUse()
Amount of RAM used by Execute()
Definition: dia_base.hpp:176
GroupFunction groupby_function_
std::pair< size_t, size_t > MaxMergeDegreePrefetch(size_t num_files)
Definition: block_pool.cpp:703
void WriteBits(BitWriter &writer) const
Write count to BitWriter.
A File is an ordered sequence of Block objects for storing items.
Definition: file.hpp:56
void reset()
release contained pointer, frees object if this is the last reference.
BlockWriter contains a temporary Block object into which a) any serializable item can be stored or b)...
typename common::FunctionTraits< KeyExtractor >::result_type Key
bool operator<(const HashCount &b) const
bool memory_exceeded
memory limit exceeded indicator
void ReadBits(BitReader &reader)
Read count from BitReader.
#define sLOG1
Definition: logger.hpp:38
virtual void PushData(bool consume)=0
Virtual method for pushing data. Triggers actual pushing in sub-classes.
data::File pre_file_
location detection and associated files
void Close()
custom destructor to close writers in a cyclic fashion
Definition: stream_data.cpp:66
static constexpr bool debug
virtual DIAMemUse PreOpMemUse()
Amount of RAM used by PreOp after StartPreOp()
Definition: dia_base.hpp:160
An extra class derived from std::vector<> for delivery of the BlockWriters of a Stream.
Definition: stream_data.hpp:56
data::CatStreamPtr stream_
data::CatStreamPtr GetNewCatStream(size_t dia_id)
Definition: context.cpp:1151
void PushItem(const ValueType &item) const
Method for derived classes to push a single item to all children.
Definition: dia_node.hpp:147
ConsumeReader GetConsumeReader(size_t prefetch_size=File::default_prefetch_size_)
Get consuming BlockReader for beginning of File.
Definition: file.cpp:73
virtual void StopPreOp(size_t)
Virtual method for preparing end of PushData.
Definition: dia_base.hpp:173
virtual void Execute()=0
Virtual execution method. Triggers actual computation in sub-classes.
int value
Definition: gen_data.py:41
HashCount & operator+=(const HashCount &b)
StatsTimerBaseStarted< true > StatsTimerStart
typename common::FunctionTraits< KeyExtractor >::template arg_plain< 0 > ValueIn
data::File::Writer pre_writer_
data::CatStream::Writers emitters_
HashCount operator+(const HashCount &b) const
auto GroupByKey(const KeyExtractor &key_extractor, const GroupByFunction &groupby_function) const
GroupByKey is a DOp that groups the elements of the DIA by their key.
const struct LocationDetectionFlag< false > NoLocationDetectionTag
global const LocationDetectionFlag instance
Definition: dia.hpp:125
virtual void StartPreOp(size_t)
Virtual method for preparing start of PushData.
Definition: dia_base.hpp:163
static IntegerType AddTruncToType(const IntegerType &a, const IntegerType &b)
Definition: math.hpp:31
data::File GetFile(size_t dia_id)
Returns a new File object containing a sequence of local Blocks.
Definition: context.hpp:283
DIAMemUse mem_limit_
Definition: dia_base.hpp:314
void RunUserFunc(data::File &f, bool consume)
void vector_free(std::vector< Type > &v)
Definition: vector_free.hpp:21
void FlushVectorToFile(std::vector< ValueIn > &v)
Sort and store elements in a file.
A DOpNode is a typed node representing distributed operations in Thrill.
Definition: dop_node.hpp:32
void MainOp()
Receive elements from other workers.
Reader GetReader(bool consume, size_t prefetch_size=File::default_prefetch_size_)
Get BlockReader or a consuming BlockReader for beginning of File.
Definition: file.cpp:78
std::deque< data::File > files_
TLX_ATTRIBUTE_ALWAYS_INLINE BlockWriter & Put(const T &x)
Put appends a complete item, or fails with a FullException.
GroupByNode(const ParentDIA &parent, const KeyExtractor &key_extractor, const GroupFunction &groupby_function, const HashFunction &hash_function=HashFunction())
Constructor for a GroupByNode.
void Close()
Explicitly close the writer.
tag structure for GroupByKey() and InnerJoin()
Definition: dia.hpp:116
HashCrc32< T > hash
Select a hashing method.
Definition: hash.hpp:262
#define LOG
Default logging method: output if the local debug variable is true.
Definition: logger.hpp:24
ValueComparator(const GroupByNode &node)
static constexpr size_t counter_bits_
bool operator()(const ValueIn &a, const ValueIn &b) const
Writer GetWriter(size_t block_size=default_block_size)
Get BlockWriter.
Definition: file.cpp:63
Context & context_
associated Context
Definition: dia_base.hpp:293
virtual DIAMemUse PushDataMemUse()
Amount of RAM used by PushData()
Definition: dia_base.hpp:182
data::BlockPool & block_pool()
the block manager keeps all data blocks moving through the system.
Definition: context.hpp:324
const size_t & dia_id() const
return unique id of DIANode subclass as stored by StatsNode
Definition: dia_base.hpp:213