group_by_key.hpp
/*******************************************************************************
 * thrill/api/group_by_key.hpp
 *
 * DIANode for a groupby operation. Performs the actual groupby operation.
 *
 * Part of Project Thrill - http://project-thrill.org
 *
 * Copyright (C) 2015 Huyen Chau Nguyen <[email protected]>
 * Copyright (C) 2016 Alexander Noe <[email protected]>
 *
 * All rights reserved. Published under the BSD-2 license in the LICENSE file.
 ******************************************************************************/

#pragma once
#ifndef THRILL_API_GROUP_BY_KEY_HEADER
#define THRILL_API_GROUP_BY_KEY_HEADER

#include <thrill/api/dia.hpp>
#include <thrill/api/dop_node.hpp>
#include <thrill/api/group_by_iterator.hpp>
#include <thrill/common/logger.hpp>
#include <thrill/common/math.hpp>
#include <thrill/core/location_detection.hpp>
#include <thrill/core/multiway_merge.hpp>
#include <thrill/data/file.hpp>

#include <algorithm>
#include <deque>
#include <functional>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

namespace thrill {
namespace api {

/*!
 * A DOpNode which performs the GroupByKey operation: it groups the elements
 * of the parent DIA by key and runs the user's group function on each group.
 *
 * \ingroup api_layer
 */
template <typename ValueType,
          typename KeyExtractor, typename GroupFunction, typename HashFunction,
          bool UseLocationDetection>
class GroupByNode final : public DOpNode<ValueType>
{
private:
    static constexpr bool debug = false;

    using Super = DOpNode<ValueType>;
    using Super::context_;

    using Key = typename common::FunctionTraits<KeyExtractor>::result_type;
    using ValueOut = ValueType;
    using ValueIn =
        typename common::FunctionTraits<KeyExtractor>::template arg_plain<0>;

    struct ValueComparator {
    public:
        explicit ValueComparator(const GroupByNode& node) : node_(node) { }

        bool operator () (const ValueIn& a, const ValueIn& b) const {
            return node_.key_extractor_(a) < node_.key_extractor_(b);
        }

    private:
        const GroupByNode& node_;
    };

    class HashCount
    {
    public:
        using HashType = size_t;
        using CounterType = uint8_t;

        size_t hash;
        CounterType count;

        static constexpr size_t counter_bits_ = 8 * sizeof(CounterType);

        HashCount operator + (const HashCount& b) const {
            assert(hash == b.hash);
            return HashCount { hash, common::AddTruncToType(count, b.count) };
        }

        HashCount& operator += (const HashCount& b) {
            assert(hash == b.hash);
            count = common::AddTruncToType(count, b.count);
            return *this;
        }

        bool operator < (const HashCount& b) const { return hash < b.hash; }

        //! method to check if this hash count should be broadcasted to all
        //! workers interested -- for GroupByKey -> always.
        bool NeedBroadcast() const {
            return true;
        }

        //! Read count from BitReader
        template <typename BitReader>
        void ReadBits(BitReader& reader) {
            count = reader.GetBits(counter_bits_);
        }

        //! Write count to BitWriter
        template <typename BitWriter>
        void WriteBits(BitWriter& writer) const {
            writer.PutBits(count, counter_bits_);
        }
    };

public:
    /*!
     * Constructor for a GroupByNode. Sets the DataManager, parent, stack,
     * key_extractor and groupby_function.
     */
    template <typename ParentDIA>
    GroupByNode(const ParentDIA& parent,
                const KeyExtractor& key_extractor,
                const GroupFunction& groupby_function,
                const HashFunction& hash_function = HashFunction())
        : Super(parent.ctx(), "GroupByKey", { parent.id() }, { parent.node() }),
          key_extractor_(key_extractor),
          groupby_function_(groupby_function),
          hash_function_(hash_function),
          location_detection_(parent.ctx(), Super::dia_id()),
          pre_file_(context_.GetFile(this)) {
        // Hook PreOp
        auto pre_op_fn = [=](const ValueIn& input) {
                             PreOp(input);
                         };
        // close the function stack with our pre op and register it at
        // parent node for output
        auto lop_chain = parent.stack().push(pre_op_fn).fold();
        parent.node()->AddChild(this, lop_chain);
    }

    void StartPreOp(size_t /* parent_index */) final {
        emitters_ = stream_->GetWriters();
        pre_writer_ = pre_file_.GetWriter();
        if (UseLocationDetection)
            location_detection_.Initialize(DIABase::mem_limit_);
    }

    //! Send all elements to their designated PEs
    void PreOp(const ValueIn& v) {
        size_t hash = hash_function_(key_extractor_(v));
        if (UseLocationDetection) {
            pre_writer_.Put(v);
            location_detection_.Insert(HashCount { hash, 1 });
        }
        else {
            const size_t recipient = hash % emitters_.size();
            emitters_[recipient].Put(v);
        }
    }

    void StopPreOp(size_t /* parent_index */) final {
        pre_writer_.Close();
    }

    DIAMemUse PreOpMemUse() final {
        return DIAMemUse::Max();
    }

    DIAMemUse ExecuteMemUse() final {
        return DIAMemUse::Max();
    }

    DIAMemUse PushDataMemUse() final {
        if (files_.size() <= 1) {
            // direct push, no merge necessary
            return 0;
        }
        else {
            // need to perform multiway merging
            return DIAMemUse::Max();
        }
    }

    void Execute() override {
        if (UseLocationDetection) {
            std::unordered_map<size_t, size_t> target_processors;
            size_t max_hash = location_detection_.Flush(target_processors);
            auto file_reader = pre_file_.GetConsumeReader();
            while (file_reader.HasNext()) {
                ValueIn in = file_reader.template Next<ValueIn>();
                Key key = key_extractor_(in);

                size_t hr = hash_function_(key) % max_hash;
                auto target_processor = target_processors.find(hr);
                emitters_[target_processor->second].Put(in);
            }
        }
        // data has been pushed during pre-op -> close emitters
        emitters_.Close();

        MainOp();
    }

    void PushData(bool consume) final {
        LOG << "sort data";
        common::StatsTimerStart timer;
        const size_t num_runs = files_.size();
        if (num_runs == 0) {
            // nothing to push
        }
        else if (num_runs == 1) {
            // if there's only one run, call user funcs
            RunUserFunc(files_[0], consume);
        }
        else {
            // otherwise sort all runs using multiway merge
            size_t merge_degree, prefetch;

            // merge batches of files if necessary
            while (std::tie(merge_degree, prefetch) =
                       context_.block_pool().MaxMergeDegreePrefetch(files_.size()),
                   files_.size() > merge_degree)
            {
                sLOG1 << "Partial multi-way-merge of"
                      << merge_degree << "files with prefetch" << prefetch;

                // create merger for first merge_degree Files
                std::vector<data::File::ConsumeReader> seq;
                seq.reserve(merge_degree);

                for (size_t t = 0; t < merge_degree; ++t) {
                    seq.emplace_back(
                        files_[t].GetConsumeReader(/* prefetch */ 0));
                }

                StartPrefetch(seq, prefetch);

                auto puller = core::make_multiway_merge_tree<ValueIn>(
                    seq.begin(), seq.end(), ValueComparator(*this));

                // create new File for merged items
                files_.emplace_back(context_.GetFile(this));
                auto writer = files_.back().GetWriter();

                while (puller.HasNext()) {
                    writer.Put(puller.Next());
                }
                writer.Close();

                // this clear is important to release references to the files.
                seq.clear();

                // remove merged files
                files_.erase(files_.begin(), files_.begin() + merge_degree);
            }

            std::vector<data::File::Reader> seq;
            seq.reserve(num_runs);

            // files_.size() may have shrunk due to the partial merges above
            for (size_t t = 0; t < files_.size(); ++t) {
                seq.emplace_back(
                    files_[t].GetReader(consume, /* prefetch */ 0));
            }

            StartPrefetch(seq, prefetch);

            LOG << "start multiwaymerge for real";
            auto puller = core::make_multiway_merge_tree<ValueIn>(
                seq.begin(), seq.end(), ValueComparator(*this));

            LOG << "run user func";
            if (puller.HasNext()) {
                // create iterator to pass to user_function
                auto user_iterator = GroupByMultiwayMergeIterator<
                    ValueIn, KeyExtractor, ValueComparator>(
                    puller, key_extractor_);

                while (user_iterator.HasNextForReal()) {
                    // call user function
                    const ValueOut res = groupby_function_(
                        user_iterator, user_iterator.GetNextKey());
                    // push result to callback functions
                    this->PushItem(res);
                }
            }
        }
        timer.Stop();
        LOG << "RESULT"
            << " name=multiwaymerge"
            << " time=" << timer
            << " multiwaymerge=" << (num_runs > 1);
    }

    void Dispose() override { }

private:
    KeyExtractor key_extractor_;
    GroupFunction groupby_function_;
    HashFunction hash_function_;

    data::CatStreamPtr stream_ { context_.GetNewCatStream(Super::dia_id()) };
    data::CatStream::Writers emitters_;

    std::deque<data::File> files_;
    size_t totalsize_ = 0;

    //! location detection and associated files
    core::LocationDetection<HashCount> location_detection_;
    data::File pre_file_;
    data::File::Writer pre_writer_;

    void RunUserFunc(data::File& f, bool consume) {
        auto r = f.GetReader(consume);
        if (r.HasNext()) {
            // create iterator to pass to user_function
            LOG << "get iterator";
            auto user_iterator = GroupByIterator<
                ValueIn, KeyExtractor, ValueComparator>(r, key_extractor_);
            LOG << "start running user func";
            while (user_iterator.HasNextForReal()) {
                // call user function
                const ValueOut res = groupby_function_(user_iterator,
                                                       user_iterator.GetNextKey());
                // push result to callback functions
                this->PushItem(res);
            }
            LOG << "finished user func";
        }
    }

    //! Sort and store elements in a file
    void FlushVectorToFile(std::vector<ValueIn>& v) {
        // sort run and write it to a file
        std::sort(v.begin(), v.end(), ValueComparator(*this));
        totalsize_ += v.size();

        files_.emplace_back(context_.GetFile(this));
        data::File::Writer w = files_.back().GetWriter();
        for (const ValueIn& e : v) {
            w.Put(e);
        }
        w.Close();
    }

    //! Receive elements from other workers.
    void MainOp() {
        LOG << "running group by main op";

        std::vector<ValueIn> incoming;

        common::StatsTimerStart timer;
        // get incoming elements
        auto reader = stream_->GetCatReader(/* consume */ true);
        while (reader.HasNext()) {
            // if memory is exceeded, sort the vector and save it to disk
            if (mem::memory_exceeded) {
                FlushVectorToFile(incoming);
                incoming.clear();
            }
            // store incoming element
            incoming.emplace_back(reader.template Next<ValueIn>());
        }
        FlushVectorToFile(incoming);
        std::vector<ValueIn>().swap(incoming);
        LOG << "finished receiving elems";
        stream_.reset();

        timer.Stop();

        LOG << "RESULT"
            << " name=mainop"
            << " time=" << timer
            << " number_files=" << files_.size();
    }
};

/******************************************************************************/

template <typename ValueType, typename Stack>
template <typename ValueOut, bool LocationDetectionValue,
          typename KeyExtractor, typename GroupFunction, typename HashFunction>
auto DIA<ValueType, Stack>::GroupByKey(
    const LocationDetectionFlag<LocationDetectionValue>&,
    const KeyExtractor& key_extractor,
    const GroupFunction& groupby_function,
    const HashFunction& hash_function) const {

    static_assert(
        std::is_same<
            typename std::decay<typename common::FunctionTraits<KeyExtractor>
                                ::template arg<0> >::type,
            ValueType>::value,
        "KeyExtractor has the wrong input type");

    using GroupByNode = api::GroupByNode<
        ValueOut, KeyExtractor, GroupFunction, HashFunction,
        LocationDetectionValue>;

    auto node = tlx::make_counting<GroupByNode>(
        *this, key_extractor, groupby_function, hash_function);

    return DIA<ValueOut>(node);
}

template <typename ValueType, typename Stack>
template <typename ValueOut, typename KeyExtractor,
          typename GroupFunction, typename HashFunction>
auto DIA<ValueType, Stack>::GroupByKey(
    const KeyExtractor& key_extractor,
    const GroupFunction& groupby_function,
    const HashFunction& hash_function) const {
    // forward to other method _without_ location detection
    return GroupByKey<ValueOut>(
        NoLocationDetectionTag, key_extractor, groupby_function, hash_function);
}

template <typename ValueType, typename Stack>
template <typename ValueOut, typename KeyExtractor, typename GroupFunction>
auto DIA<ValueType, Stack>::GroupByKey(
    const KeyExtractor& key_extractor,
    const GroupFunction& groupby_function) const {
    // forward to other method _without_ location detection
    return GroupByKey<ValueOut>(
        NoLocationDetectionTag, key_extractor, groupby_function,
        std::hash<typename FunctionTraits<KeyExtractor>::result_type>());
}

} // namespace api
} // namespace thrill

#endif // !THRILL_API_GROUP_BY_KEY_HEADER

/******************************************************************************/
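
Usage sketch (not part of the header file): a minimal, hedged example of calling GroupByKey as declared above. The DIA `integers`, the lambdas, and the output pair type are illustrative assumptions; only the call shape follows the overloads defined in this file, and the group iterator's HasNext()/Next() interface matches its use in RunUserFunc() above.

// assumption: `integers` is an existing thrill DIA<size_t>
auto counts = integers.GroupByKey<std::pair<size_t, size_t> >(
    // key extractor: group values by their residue modulo 10
    [](const size_t& i) { return i % 10; },
    // group function: consume all values of one group, emit (key, group size)
    [](auto& iter, const size_t& key) {
        size_t count = 0;
        while (iter.HasNext()) {
            iter.Next();
            ++count;
        }
        return std::make_pair(key, count);
    });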