Thrill  0.1
all_reduce.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * thrill/api/all_reduce.hpp
3  *
4  * Part of Project Thrill - http://project-thrill.org
5  *
6  * Copyright (C) 2015 Matthias Stumpp <[email protected]>
7  * Copyright (C) 2015 Sebastian Lamm <[email protected]>
8  *
9  * All rights reserved. Published under the BSD-2 license in the LICENSE file.
10  ******************************************************************************/
11 
12 #pragma once
13 #ifndef THRILL_API_ALL_REDUCE_HEADER
14 #define THRILL_API_ALL_REDUCE_HEADER
15 
17 #include <thrill/api/dia.hpp>
18 
19 #include <type_traits>
20 
21 namespace thrill {
22 namespace api {
23 
24 /*!
25  * \ingroup api_layer
26  */
27 template <typename ValueType, typename ReduceFunction>
28 class AllReduceNode final : public ActionResultNode<ValueType>
29 {
30  static constexpr bool debug = false;
31 
33  using Super::context_;
34 
35 public:
36  template <typename ParentDIA>
37  AllReduceNode(const ParentDIA& parent,
38  const char* label,
39  const ValueType& initial_value = ValueType(),
40  bool with_initial_value = false,
41  const ReduceFunction& reduce_function = ReduceFunction())
42  : Super(parent.ctx(), label, { parent.id() }, { parent.node() }),
43  reduce_function_(reduce_function),
44  sum_(initial_value),
45  // add to initial value if defined and we are first worker
46  first_(!(with_initial_value && parent.ctx().my_rank() == 0)) {
47  // Hook PreOp(s)
48  auto pre_op_fn = [this](const ValueType& input) {
49  PreOp(input);
50  };
51 
52  auto lop_chain = parent.stack().push(pre_op_fn).fold();
53  parent.node()->AddChild(this, lop_chain);
54  }
55 
56  void PreOp(const ValueType& input) {
57  if (TLX_UNLIKELY(first_)) {
58  first_ = false;
59  sum_ = input;
60  }
61  else {
62  sum_ = reduce_function_(sum_, input);
63  }
64  }
65 
66  //! Executes the sum operation.
67  void Execute() final {
68  // process the reduce
70  }
71 
72  //! Returns result of global sum.
73  const ValueType& result() const final {
74  return sum_;
75  }
76 
77 private:
78  //! The sum function which is applied to two values.
79  ReduceFunction reduce_function_;
80  //! Local/global sum to be used in all reduce operation.
81  ValueType sum_;
82  //! indicate that sum_ is the default constructed first value. Worker 0's
83  //! value is already set to initial_value.
84  bool first_;
85 };
86 
87 template <typename ValueType, typename Stack>
88 template <typename ReduceFunction>
90  const ReduceFunction& sum_function) const {
91  assert(IsValid());
92 
94 
95  static_assert(
96  std::is_convertible<
97  ValueType,
99  "ReduceFunction has the wrong input type");
100 
101  static_assert(
102  std::is_convertible<
103  ValueType,
105  "ReduceFunction has the wrong input type");
106 
107  static_assert(
108  std::is_convertible<
110  ValueType>::value,
111  "ReduceFunction has the wrong input type");
112 
113  auto node = tlx::make_counting<AllReduceNode>(
114  *this, "AllReduce", ValueType(), /* with_initial_value */ false,
115  sum_function);
116 
117  node->RunScope();
118 
119  return node->result();
120 }
121 
122 template <typename ValueType, typename Stack>
123 template <typename ReduceFunction>
125  const ReduceFunction& sum_function, const ValueType& initial_value) const {
126  assert(IsValid());
127 
129 
130  static_assert(
131  std::is_convertible<
132  ValueType,
134  "ReduceFunction has the wrong input type");
135 
136  static_assert(
137  std::is_convertible<
138  ValueType,
140  "ReduceFunction has the wrong input type");
141 
142  static_assert(
143  std::is_convertible<
145  ValueType>::value,
146  "ReduceFunction has the wrong input type");
147 
148  auto node = tlx::make_counting<AllReduceNode>(
149  *this, "AllReduce", initial_value, /* with_initial_value */ true,
150  sum_function);
151 
152  node->RunScope();
153 
154  return node->result();
155 }
156 
157 template <typename ValueType, typename Stack>
158 template <typename ReduceFunction>
160  const ReduceFunction& sum_function) const {
161  assert(IsValid());
162 
164 
165  static_assert(
166  std::is_convertible<
167  ValueType,
169  "ReduceFunction has the wrong input type");
170 
171  static_assert(
172  std::is_convertible<
173  ValueType,
175  "ReduceFunction has the wrong input type");
176 
177  static_assert(
178  std::is_convertible<
180  ValueType>::value,
181  "ReduceFunction has the wrong input type");
182 
183  auto node = tlx::make_counting<AllReduceNode>(
184  *this, "AllReduce", ValueType(), /* with_initial_value */ false,
185  sum_function);
186 
187  return Future<ValueType>(node);
188 }
189 
190 template <typename ValueType, typename Stack>
191 template <typename ReduceFunction>
193  const ReduceFunction& sum_function, const ValueType& initial_value) const {
194  assert(IsValid());
195 
197 
198  static_assert(
199  std::is_convertible<
200  ValueType,
202  "ReduceFunction has the wrong input type");
203 
204  static_assert(
205  std::is_convertible<
206  ValueType,
208  "ReduceFunction has the wrong input type");
209 
210  static_assert(
211  std::is_convertible<
213  ValueType>::value,
214  "ReduceFunction has the wrong input type");
215 
216  auto node = tlx::make_counting<AllReduceNode>(
217  *this, "AllReduce", initial_value, /* with_initial_value */ true,
218  sum_function);
219 
220  return Future<ValueType>(node);
221 }
222 
223 } // namespace api
224 } // namespace thrill
225 
226 #endif // !THRILL_API_ALL_REDUCE_HEADER
227 
228 /******************************************************************************/
net::FlowControlChannel & net
Definition: context.hpp:446
static constexpr bool debug
Definition: all_reduce.hpp:30
const char * label() const
return label() of DIANode subclass as stored by StatsNode
Definition: dia_base.hpp:218
#define TLX_UNLIKELY(c)
Definition: likely.hpp:24
T TLX_ATTRIBUTE_WARN_UNUSED_RESULT AllReduce(const T &value, const BinarySumOp &sum_op=BinarySumOp())
Reduces a value of a serializable type T over all workers given a certain reduce function.
void Execute() final
Executes the sum operation.
Definition: all_reduce.hpp:67
ReduceFunction reduce_function_
The sum function which is applied to two values.
Definition: all_reduce.hpp:79
int value
Definition: gen_data.py:41
Future< ValueType > AllReduceFuture(const ReduceFunction &reduce_function) const
AllReduce is an ActionFuture, which computes the reduction sum of all elements globally and delivers the same value on all workers through a Future.
Definition: all_reduce.hpp:159
common::FunctionTraits< Function > FunctionTraits
alias for convenience.
Definition: dia.hpp:147
The return type class for all ActionFutures.
Definition: action_node.hpp:83
void PreOp(const ValueType &input)
Definition: all_reduce.hpp:56
const ValueType & result() const final
Returns result of global sum.
Definition: all_reduce.hpp:73
ValueType AllReduce(const ReduceFunction &reduce_function) const
AllReduce is an Action, which computes the reduction sum of all elements globally and delivers the same value on all workers.
Definition: all_reduce.hpp:89
AllReduceNode(const ParentDIA &parent, const char *label, const ValueType &initial_value=ValueType(), bool with_initial_value=false, const ReduceFunction &reduce_function=ReduceFunction())
Definition: all_reduce.hpp:37
ValueType sum_
Local/global sum to be used in all reduce operation.
Definition: all_reduce.hpp:81
Context & context_
associated Context
Definition: dia_base.hpp:293