/*******************************************************************************
 * thrill/api/group_by_key.hpp
 *
 * DIANode for a groupby operation. Performs the actual groupby operation.
 *
 * Part of Project Thrill - http://project-thrill.org
 *
 * Copyright (C) 2015 Huyen Chau Nguyen <[email protected]>
 * Copyright (C) 2016 Alexander Noe <[email protected]>
 *
 * All rights reserved. Published under the BSD-2 license in the LICENSE file.
 ******************************************************************************/

#pragma once
#ifndef THRILL_API_GROUP_BY_KEY_HEADER
#define THRILL_API_GROUP_BY_KEY_HEADER

#include <thrill/api/dia.hpp>
#include <thrill/api/dop_node.hpp>
#include <thrill/api/group_by_iterator.hpp>
#include <thrill/common/functional.hpp>
#include <thrill/common/logger.hpp>
#include <thrill/core/location_detection.hpp>
#include <thrill/core/multiway_merge.hpp>
#include <thrill/data/file.hpp>

#include <algorithm>
#include <cassert>
#include <deque>
#include <functional>
#include <tuple>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

namespace thrill {
namespace api {

/*!
 * DIANode for a GroupByKey operation: elements of the parent DIA are grouped
 * by the key returned by the KeyExtractor, and the GroupFunction is applied
 * once per key to an iterator over all values sharing that key.
 *
 * \ingroup api_layer
 */
template <typename ValueType,
          typename KeyExtractor, typename GroupFunction, typename HashFunction,
          bool UseLocationDetection>
class GroupByNode final : public DOpNode<ValueType>
{
private:
    static constexpr bool debug = false;

    using Super = DOpNode<ValueType>;
    using Super::context_;

    using Key = typename common::FunctionTraits<KeyExtractor>::result_type;
    using ValueOut = ValueType;
    using ValueIn =
        typename common::FunctionTraits<KeyExtractor>::template arg_plain<0>;

    struct ValueComparator {
    public:
        explicit ValueComparator(const GroupByNode& node) : node_(node) { }

        bool operator () (const ValueIn& a, const ValueIn& b) const {
            return node_.key_extractor_(a) < node_.key_extractor_(b);
        }

    private:
        //! reference to the GroupByNode whose key_extractor_ is used
        const GroupByNode& node_;
    };

    //! Hash counter used by distributed location detection: counts the local
    //! occurrences of one key hash.
    class HashCount
    {
    public:
        using HashType = size_t;
        using CounterType = uint8_t;

        size_t hash;
        CounterType count;

        static constexpr size_t counter_bits_ = 8 * sizeof(CounterType);

        HashCount operator + (const HashCount& b) const {
            assert(hash == b.hash);
            return HashCount { hash, common::AddTruncToType(count, b.count) };
        }

        HashCount& operator += (const HashCount& b) {
            assert(hash == b.hash);
            count = common::AddTruncToType(count, b.count);
            return *this;
        }

        bool operator < (const HashCount& b) const { return hash < b.hash; }

        //! method to check if this hash count should be broadcasted to all
        //! workers interested -- for GroupByKey -> always.
        bool NeedBroadcast() const {
            return true;
        }

        //! Read count from BitReader
        template <typename BitReader>
        void ReadBits(BitReader& reader) {
            count = reader.GetBits(counter_bits_);
        }

        //! Write count to BitWriter
        template <typename BitWriter>
        void WriteBits(BitWriter& writer) const {
            writer.PutBits(count, counter_bits_);
        }
    };
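
    // Illustrative note (not from the original source): the per-hash counter
    // is deliberately small. CounterType is uint8_t, so merging two counts via
    // operator+ is assumed to saturate at 255 (common::AddTruncToType clamps
    // the sum to the counter type), e.g.
    //
    //     HashCount a { 42, 200 };   // hash = 42, count = 200
    //     HashCount b { 42, 100 };
    //     HashCount c = a + b;       // c.count == 255, not 300
    //
    // Exact counts beyond 255 are irrelevant for location detection; only
    // which workers hold items of a given hash (and roughly how many) matters.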

public:
    /*!
     * Constructor for a GroupByNode. Sets the DataManager, parent, stack,
     * key_extractor and groupby_function.
     */
    template <typename ParentDIA>
    GroupByNode(const ParentDIA& parent,
                const KeyExtractor& key_extractor,
                const GroupFunction& groupby_function,
                const HashFunction& hash_function = HashFunction())
        : Super(parent.ctx(), "GroupByKey", { parent.id() }, { parent.node() }),
          key_extractor_(key_extractor),
          groupby_function_(groupby_function),
          hash_function_(hash_function),
          location_detection_(parent.ctx(), Super::id()),
          pre_file_(context_.GetFile(this))
    {
        // Hook PreOp
        auto pre_op_fn = [=](const ValueIn& input) {
                             PreOp(input);
                         };
        // close the function stack with our pre op and register it at
        // parent node for output
        auto lop_chain = parent.stack().push(pre_op_fn).fold();
        parent.node()->AddChild(this, lop_chain);
    }

    void StartPreOp(size_t /* id */) final {
        emitters_ = stream_->GetWriters();
        pre_writer_ = pre_file_.GetWriter();
        if (UseLocationDetection)
            location_detection_.Initialize(DIABase::mem_limit_ / 2);
    }

    //! Send all elements to their designated PEs
    void PreOp(const ValueIn& v) {
        size_t hash = hash_function_(key_extractor_(v));
        if (UseLocationDetection) {
            pre_writer_.Put(v);
            location_detection_.Insert(HashCount { hash, 1 });
        }
        else {
            const size_t recipient = hash % emitters_.size();
            emitters_[recipient].Put(v);
        }
    }
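
    // Illustrative note (not from the original source): without location
    // detection, routing depends on the key hash alone. For example, with
    // emitters_.size() == 4 workers, an item whose key hashes to 11 is sent
    // to worker 11 % 4 == 3, so all items that share a key meet on the same
    // worker before grouping. With location detection, items stay in
    // pre_file_ for now and only their HashCounts are aggregated.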

    void StopPreOp(size_t /* id */) final {
        pre_writer_.Close();
    }

    DIAMemUse PreOpMemUse() final {
        return DIAMemUse::Max();
    }

    DIAMemUse ExecuteMemUse() final {
        return DIAMemUse::Max();
    }

    DIAMemUse PushDataMemUse() final {
        if (files_.size() <= 1) {
            // direct push, no merge necessary
            return 0;
        }
        else {
            // need to perform multiway merging
            return DIAMemUse::Max();
        }
    }

    void Execute() override {
        if (UseLocationDetection) {
            std::unordered_map<size_t, size_t> target_processors;
            size_t max_hash = location_detection_.Flush(target_processors);
            auto file_reader = pre_file_.GetConsumeReader();
            while (file_reader.HasNext()) {
                ValueIn in = file_reader.template Next<ValueIn>();
                Key key = key_extractor_(in);

                size_t hr = hash_function_(key) % max_hash;
                auto target_processor = target_processors.find(hr);
                emitters_[target_processor->second].Put(in);
            }
        }
        // data has been pushed during pre-op -> close emitters
        emitters_.Close();

        MainOp();
    }
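
    // Illustrative note (not from the original source): Flush() is assumed to
    // return the size of the hash range used for bucketing and to fill
    // target_processors with one designated worker per occupied bucket. If,
    // say, max_hash == 1024 and bucket 37 maps to worker 2, every locally
    // cached item whose key hash h satisfies h % 1024 == 37 is emitted to
    // worker 2, so each key is still grouped on exactly one worker.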

    void PushData(bool consume) final {
        LOG << "sort data";
        common::StatsTimerStart timer;
        const size_t num_runs = files_.size();
        if (num_runs == 0) {
            // nothing to push
        }
        else if (num_runs == 1) {
            // if there's only one run, call user funcs
            RunUserFunc(files_[0], consume);
        }
        else {
            // otherwise sort all runs using multiway merge
            size_t merge_degree, prefetch;

            // merge batches of files if necessary
            while (files_.size() > MaxMergeDegreePrefetch().first)
            {
                std::tie(merge_degree, prefetch) = MaxMergeDegreePrefetch();

                sLOG1 << "Partial multi-way-merge of"
                      << merge_degree << "files with prefetch" << prefetch;

                // create merger for the first merge_degree Files
                std::vector<data::File::ConsumeReader> seq;
                seq.reserve(merge_degree);

                for (size_t t = 0; t < merge_degree; ++t)
                    seq.emplace_back(files_[t].GetConsumeReader(0));

                StartPrefetch(seq, prefetch);

                auto puller = core::make_multiway_merge_tree<ValueIn>(
                    seq.begin(), seq.end(), ValueComparator(*this));

                // create new File for merged items
                files_.emplace_back(context_.GetFile(this));
                auto writer = files_.back().GetWriter();

                while (puller.HasNext()) {
                    writer.Put(puller.Next());
                }
                writer.Close();

                // this clear is important to release references to the files.
                seq.clear();

                // remove merged files
                files_.erase(files_.begin(), files_.begin() + merge_degree);
            }
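
            // Illustrative note (not from the original source): each pass of
            // the loop above replaces merge_degree sorted runs by one longer
            // run. Assuming MaxMergeDegreePrefetch() yields a merge degree of
            // 8, a worker holding 20 runs first shrinks them to
            // 20 - 8 + 1 = 13 runs, then to 13 - 8 + 1 = 6 runs; 6 runs fit
            // into a single merge, so the loop exits and the tree below feeds
            // the user function directly.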

            // merge the remaining files in one pass
            std::vector<data::File::Reader> seq;
            seq.reserve(files_.size());

            for (size_t t = 0; t < files_.size(); ++t)
                seq.emplace_back(files_[t].GetReader(consume));

            LOG << "start multiwaymerge for real";
            auto puller = core::make_multiway_merge_tree<ValueIn>(
                seq.begin(), seq.end(), ValueComparator(*this));

            LOG << "run user func";
            if (puller.HasNext()) {
                // create iterator to pass to user_function
                auto user_iterator = GroupByMultiwayMergeIterator<
                    ValueIn, KeyExtractor, ValueComparator>(
                    puller, key_extractor_);

                while (user_iterator.HasNextForReal()) {
                    // call user function
                    const ValueOut res = groupby_function_(
                        user_iterator, user_iterator.GetNextKey());
                    // push result to callback functions
                    this->PushItem(res);
                }
            }
        }
        timer.Stop();
        LOG << "RESULT"
            << " name=multiwaymerge"
            << " time=" << timer
            << " multiwaymerge=" << (num_runs > 1);
    }

    void Dispose() override { }

private:
    KeyExtractor key_extractor_;
    GroupFunction groupby_function_;
    HashFunction hash_function_;

    //! CatStream used to exchange items between workers
    data::CatStreamPtr stream_ { context_.GetNewCatStream(this) };
    //! BlockWriters to the other workers
    data::CatStream::Writers emitters_;

    //! sorted runs of received items, merged in PushData()
    std::deque<data::File> files_;
    //! total number of items stored in files_
    size_t totalsize_ = 0;

    //! location detection and associated files
    core::LocationDetection<HashCount> location_detection_;
    data::File pre_file_;
    data::File::Writer pre_writer_;

    //! Run the user's GroupFunction over a single sorted file.
    void RunUserFunc(data::File& f, bool consume) {
        auto r = f.GetReader(consume);
        if (r.HasNext()) {
            // create iterator to pass to user_function
            LOG << "get iterator";
            auto user_iterator = GroupByIterator<
                ValueIn, KeyExtractor, ValueComparator>(r, key_extractor_);
            LOG << "start running user func";
            while (user_iterator.HasNextForReal()) {
                // call user function
                const ValueOut res = groupby_function_(
                    user_iterator, user_iterator.GetNextKey());
                // push result to callback functions
                this->PushItem(res);
            }
            LOG << "finished user func";
        }
    }

    //! Sort and store elements in a file
    void FlushVectorToFile(std::vector<ValueIn>& v) {
        // sort the run and write it to a new file
        std::sort(v.begin(), v.end(), ValueComparator(*this));
        totalsize_ += v.size();

        files_.emplace_back(context_.GetFile(this));
        data::File::Writer w = files_.back().GetWriter();
        for (const ValueIn& e : v) {
            w.Put(e);
        }
        w.Close();
    }

    //! Receive elements from other workers.
    void MainOp() {
        LOG << "running group by main op";

        std::vector<ValueIn> incoming;

        common::StatsTimerStart timer;
        // get incoming elements
        auto reader = stream_->GetCatReader(/* consume */ true);
        while (reader.HasNext()) {
            // if the memory limit is exceeded, spill the buffered run to disk
            if (mem::memory_exceeded) {
                FlushVectorToFile(incoming);
                incoming.clear();
            }
            // store incoming element
            incoming.emplace_back(reader.template Next<ValueIn>());
        }
        FlushVectorToFile(incoming);
        std::vector<ValueIn>().swap(incoming);
        LOG << "finished receiving elems";
        stream_.reset();

        timer.Stop();

        LOG << "RESULT"
            << " name=mainop"
            << " time=" << timer
            << " number_files=" << files_.size();
    }

    //! Calculate the maximum merging degree from available memory and the
    //! number of files. Additionally calculate the prefetch size of each File.
    std::pair<size_t, size_t> MaxMergeDegreePrefetch() {
        size_t avail_blocks = DIABase::mem_limit_ / data::default_block_size;
        if (files_.size() >= avail_blocks) {
            // more files than blocks available -> partial merge of avail_blocks
            // Files with prefetch = 0, which is one read Block per File.
            return std::make_pair(avail_blocks, 0u);
        }
        else {
            // fewer files than available Blocks -> split blocks equally among
            // Files.
            return std::make_pair(
                files_.size(),
                std::min<size_t>(16u, (avail_blocks / files_.size()) - 1));
        }
    }
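
    // Illustrative note (not from the original source): with an assumed
    // mem_limit_ of 1 GiB and default_block_size of 2 MiB, avail_blocks is
    // 512. Holding 8 files, this returns a merge degree of 8 with a prefetch
    // of std::min(16, 512 / 8 - 1) = 16 blocks per file; holding 600 files,
    // it returns (512, 0), which triggers the partial merge loop in
    // PushData().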
};

/******************************************************************************/

template <typename ValueType, typename Stack>
template <typename ValueOut, bool LocationDetectionValue,
          typename KeyExtractor, typename GroupFunction, typename HashFunction>
auto DIA<ValueType, Stack>::GroupByKey(
    const LocationDetectionFlag<LocationDetectionValue>&,
    const KeyExtractor& key_extractor,
    const GroupFunction& groupby_function,
    const HashFunction& hash_function) const {

    static_assert(
        std::is_same<
            typename std::decay<typename common::FunctionTraits<KeyExtractor>
                                ::template arg<0> >::type,
            ValueType>::value,
        "KeyExtractor has the wrong input type");

    using GroupByNode = api::GroupByNode<
        ValueOut, KeyExtractor, GroupFunction, HashFunction,
        LocationDetectionValue>;

    auto node = tlx::make_counting<GroupByNode>(
        *this, key_extractor, groupby_function, hash_function);

    return DIA<ValueOut>(node);
}

template <typename ValueType, typename Stack>
template <typename ValueOut, typename KeyExtractor,
          typename GroupFunction, typename HashFunction>
auto DIA<ValueType, Stack>::GroupByKey(
    const KeyExtractor& key_extractor,
    const GroupFunction& groupby_function,
    const HashFunction& hash_function) const {
    // forward to the overload _without_ location detection
    return GroupByKey<ValueOut>(
        NoLocationDetectionTag, key_extractor, groupby_function, hash_function);
}

template <typename ValueType, typename Stack>
template <typename ValueOut, typename KeyExtractor, typename GroupFunction>
auto DIA<ValueType, Stack>::GroupByKey(
    const KeyExtractor& key_extractor,
    const GroupFunction& groupby_function) const {
    // forward to the overload _without_ location detection, using std::hash
    return GroupByKey<ValueOut>(
        NoLocationDetectionTag, key_extractor, groupby_function,
        std::hash<typename FunctionTraits<KeyExtractor>::result_type>());
}
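
// Illustrative usage sketch (not part of this header; the sample lambdas and
// the DIA variable are made up). GroupByKey groups a DIA by the key returned
// from the key extractor and calls the group function once per key with an
// iterator over that key's values:
//
//     // dia : DIA<std::pair<int, double>> of (key, value) pairs
//     auto sums = dia.GroupByKey<double>(
//         // key extractor: group by the integer key
//         [](const std::pair<int, double>& p) { return p.first; },
//         // group function: fold all values of one key into their sum
//         [](auto& iter, const int& key) {
//             double sum = 0;
//             while (iter.HasNext())
//                 sum += iter.Next().second;
//             return sum;
//         });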

} // namespace api
} // namespace thrill

#endif // !THRILL_API_GROUP_BY_KEY_HEADER

/******************************************************************************/