Thrill  0.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
all_reduce.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * thrill/api/all_reduce.hpp
3  *
4  * Part of Project Thrill - http://project-thrill.org
5  *
6  * Copyright (C) 2015 Matthias Stumpp <[email protected]>
7  * Copyright (C) 2015 Sebastian Lamm <[email protected]>
8  *
9  * All rights reserved. Published under the BSD-2 license in the LICENSE file.
10  ******************************************************************************/
11 
12 #pragma once
13 #ifndef THRILL_API_ALL_REDUCE_HEADER
14 #define THRILL_API_ALL_REDUCE_HEADER
15 
17 #include <thrill/api/dia.hpp>
18 
19 #include <type_traits>
20 
21 namespace thrill {
22 namespace api {
23 
24 /*!
 * DIA action node that folds every item arriving from its parent into a local
 * accumulator via PreOp() and, in Execute(), combines the accumulators of all
 * workers through context_.net.AllReduce(). The outcome is exposed by
 * result().
 *
 * Seeding, as implemented below: sum_ always starts as initial_value. On
 * workers whose rank is not 0 the first incoming item replaces it, whereas
 * worker 0 keeps initial_value and reduces every item into it. A worker with
 * rank != 0 that receives no item at all therefore still contributes
 * initial_value to the global reduction.
 *
 * NOTE(review): the alias Super used below is not declared in the visible
 * lines (listing line 32 is absent); presumably it names
 * ActionResultNode<ValueType> -- confirm.
 *
25  * \ingroup api_layer
26  */
27 template <typename ValueType, typename ReduceFunction>
28 class AllReduceNode final : public ActionResultNode<ValueType>
29 {
 //! Compile-time debug switch; not referenced in the visible part of the class.
30  static constexpr bool debug = false;
31 
 //! Make the base class' context_ member usable without qualification.
33  using Super::context_;
34 
35 public:
 /*!
  * Creates the node and hooks it into the parent's function stack.
  *
  * \param parent DIA whose items are reduced; supplies ctx(), id(), node()
  * and stack().
  * \param label label forwarded to the base class constructor.
  * \param initial_value starting value of the local accumulator.
  * \param reduce_function binary function used to combine two values.
  */
36  template <typename ParentDIA>
37  AllReduceNode(const ParentDIA& parent,
38  const char* label,
39  const ValueType& initial_value,
40  const ReduceFunction& reduce_function = ReduceFunction())
41  : Super(parent.ctx(), label, { parent.id() }, { parent.node() }),
42  reduce_function_(reduce_function),
43  sum_(initial_value),
 // only workers other than rank 0 let their first item overwrite sum_
44  first_(parent.ctx().my_rank() != 0)
45  {
46  // Register PreOp() as the function that receives each parent item.
47  auto pre_op_fn = [this](const ValueType& input) {
48  PreOp(input);
49  };
50 
 // append the hook to the parent's stack, fold it, and attach this node
 // to the parent node as a child together with the folded chain
51  auto lop_chain = parent.stack().push(pre_op_fn).fold();
52  parent.node()->AddChild(this, lop_chain);
53  }
54 
 //! Folds one incoming item into the local accumulator sum_.
55  void PreOp(const ValueType& input) {
56  if (TLX_UNLIKELY(first_)) {
 // first item on this worker: it replaces the stored initial_value
57  first_ = false;
58  sum_ = input;
59  }
60  else {
61  sum_ = reduce_function_(sum_, input);
62  }
63  }
64 
65  //! Combines the local accumulators of all workers using reduce_function_.
66  void Execute() final {
67  // every worker stores the value returned by the network all-reduce
68  sum_ = context_.net.AllReduce(sum_, reduce_function_);
69  }
70 
71  //! Returns the accumulator; after Execute() this is the reduced value.
72  const ValueType& result() const final {
73  return sum_;
74  }
75 
76 private:
77  //! The reduce function which is applied to two values.
78  ReduceFunction reduce_function_;
79  //! Local accumulator before Execute(), reduced value afterwards.
80  ValueType sum_;
81  //! True while sum_ still holds only initial_value and the next item must
82  //! replace it instead of being reduced into it. Starts as false on worker 0.
83  bool first_;
84 };
85 
//! Definition of the member function that the index of this listing names
//! AllReduce (signature on listing line 88, which is absent here). Builds an
//! AllReduceNode labelled "AllReduce" from this object, runs it and returns
//! the node's result.
//!
//! \param sum_function binary function handed to the node as reduce function.
//! \param initial_value starting value handed to the node.
//! \return the value of node->result() after node->RunScope().
86 template <typename ValueType, typename Stack>
87 template <typename ReduceFunction>
89  const ReduceFunction& sum_function, const ValueType& initial_value) const {
90  assert(IsValid());
91 
93 
 // Compile-time checks of ReduceFunction against ValueType. The second
 // operand of each is_convertible (listing lines 97, 103) is absent here.
94  static_assert(
95  std::is_convertible<
96  ValueType,
98  "ReduceFunction has the wrong input type");
99 
100  static_assert(
101  std::is_convertible<
102  ValueType,
104  "ReduceFunction has the wrong input type");
105 
 // NOTE(review): here ValueType is the conversion target, not the source as
 // in the two checks above, yet the message still says "input type" --
 // presumably this guards the function's result type; confirm the wording.
106  static_assert(
107  std::is_convertible<
109  ValueType>::value,
110  "ReduceFunction has the wrong input type");
111 
 // NOTE(review): AllReduceNode is used without template arguments; a local
 // alias is presumably declared on the absent listing line 92 -- confirm.
112  auto node = tlx::make_counting<AllReduceNode>(
113  *this, "AllReduce", initial_value, sum_function);
114 
 // run the node now so that result() can be read right afterwards
115  node->RunScope();
116 
117  return node->result();
118 }
119 
//! Definition of the member function that the index of this listing names
//! AllReduceFuture (signature on listing line 122, which is absent here).
//! Builds an AllReduceNode labelled "AllReduce" from this object and wraps it
//! in a Future<ValueType>; unlike the AllReduce definition in this file,
//! RunScope() is not called here.
//!
//! \param sum_function binary function handed to the node as reduce function.
//! \param initial_value starting value handed to the node.
//! \return a Future<ValueType> constructed from the new node.
120 template <typename ValueType, typename Stack>
121 template <typename ReduceFunction>
123  const ReduceFunction& sum_function, const ValueType& initial_value) const {
124  assert(IsValid());
125 
127 
 // Compile-time checks of ReduceFunction against ValueType. The second
 // operand of each is_convertible (listing lines 131, 137) is absent here.
128  static_assert(
129  std::is_convertible<
130  ValueType,
132  "ReduceFunction has the wrong input type");
133 
134  static_assert(
135  std::is_convertible<
136  ValueType,
138  "ReduceFunction has the wrong input type");
139 
 // NOTE(review): here ValueType is the conversion target, not the source as
 // in the two checks above, yet the message still says "input type" --
 // presumably this guards the function's result type; confirm the wording.
140  static_assert(
141  std::is_convertible<
143  ValueType>::value,
144  "ReduceFunction has the wrong input type");
145 
 // NOTE(review): AllReduceNode is used without template arguments; a local
 // alias is presumably declared on the absent listing line 126 -- confirm.
146  auto node = tlx::make_counting<AllReduceNode>(
147  *this, "AllReduce", initial_value, sum_function);
148 
 // the node is only handed to the future here, it is not run
149  return Future<ValueType>(node);
150 }
151 
152 } // namespace api
153 } // namespace thrill
154 
155 #endif // !THRILL_API_ALL_REDUCE_HEADER
156 
157 /******************************************************************************/
net::FlowControlChannel & net
Definition: context.hpp:443
ValueType_ ValueType
Definition: dia.hpp:152
Future< ValueType > AllReduceFuture(const ReduceFunction &reduce_function, const ValueType &initial_value=ValueType()) const
AllReduce is an ActionFuture, which computes the reduction sum of all elements globally and delivers ...
Definition: all_reduce.hpp:122
virtual const ValueType & result() const =0
virtual method to return result via an ActionFuture
static constexpr bool debug
Definition: all_reduce.hpp:30
const char * label() const
return label() of DIANode subclass as stored by StatsNode
Definition: dia_base.hpp:218
#define TLX_UNLIKELY(c)
Definition: likely.hpp:24
ValueType AllReduce(const ReduceFunction &reduce_function, const ValueType &initial_value=ValueType()) const
AllReduce is an Action, which computes the reduction sum of all elements globally and delivers the sa...
Definition: all_reduce.hpp:88
T TLX_ATTRIBUTE_WARN_UNUSED_RESULT AllReduce(const T &value, const BinarySumOp &sum_op=BinarySumOp())
Reduces a value of a serializable type T over all workers given a certain reduce function.
AllReduceNode(const ParentDIA &parent, const char *label, const ValueType &initial_value, const ReduceFunction &reduce_function=ReduceFunction())
Definition: all_reduce.hpp:37
virtual void Execute()=0
Virtual execution method. Triggers actual computation in sub-classes.
int value
Definition: gen_data.py:41
common::FunctionTraits< Function > FunctionTraits
alias for convenience.
Definition: dia.hpp:147
The return type class for all ActionFutures.
Definition: action_node.hpp:83
ValueType sum_
Local/global sum to be used in all reduce operation.
Definition: all_reduce.hpp:41
Context & context_
associated Context
Definition: dia_base.hpp:293