Thrill  0.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
all_reduce.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * thrill/api/all_reduce.hpp
3  *
4  * Part of Project Thrill - http://project-thrill.org
5  *
6  * Copyright (C) 2015 Matthias Stumpp <[email protected]>
7  * Copyright (C) 2015 Sebastian Lamm <[email protected]>
8  *
9  * All rights reserved. Published under the BSD-2 license in the LICENSE file.
10  ******************************************************************************/
11 
12 #pragma once
13 #ifndef THRILL_API_ALL_REDUCE_HEADER
14 #define THRILL_API_ALL_REDUCE_HEADER
15 
17 #include <thrill/api/dia.hpp>
18 
19 #include <type_traits>
20 
21 namespace thrill {
22 namespace api {
23 
24 /*!
25  * \ingroup api_layer
26  */
27 template <typename ValueType, typename ReduceFunction>
28 class AllReduceNode final : public ActionResultNode<ValueType>
29 {
30  static constexpr bool debug = false;
31 
33  using Super::context_;
34 
35 public:
36  template <typename ParentDIA>
37  AllReduceNode(const ParentDIA& parent,
38  const char* label,
39  const ValueType& initial_value,
40  const ReduceFunction& reduce_function = ReduceFunction())
41  : Super(parent.ctx(), label, { parent.id() }, { parent.node() }),
42  reduce_function_(reduce_function),
43  sum_(initial_value),
44  first_(parent.ctx().my_rank() != 0) {
45  // Hook PreOp(s)
46  auto pre_op_fn = [this](const ValueType& input) {
47  PreOp(input);
48  };
49 
50  auto lop_chain = parent.stack().push(pre_op_fn).fold();
51  parent.node()->AddChild(this, lop_chain);
52  }
53 
54  void PreOp(const ValueType& input) {
55  if (TLX_UNLIKELY(first_)) {
56  first_ = false;
57  sum_ = input;
58  }
59  else {
60  sum_ = reduce_function_(sum_, input);
61  }
62  }
63 
64  //! Executes the sum operation.
65  void Execute() final {
66  // process the reduce
67  sum_ = context_.net.AllReduce(sum_, reduce_function_);
68  }
69 
70  //! Returns result of global sum.
71  const ValueType& result() const final {
72  return sum_;
73  }
74 
75 private:
76  //! The sum function which is applied to two values.
77  ReduceFunction reduce_function_;
78  //! Local/global sum to be used in all reduce operation.
79  ValueType sum_;
80  //! indicate that sum_ is the default constructed first value. Worker 0's
81  //! value is already set to initial_value.
82  bool first_;
83 };
84 
85 template <typename ValueType, typename Stack>
86 template <typename ReduceFunction>
88  const ReduceFunction& sum_function, const ValueType& initial_value) const {
89  assert(IsValid());
90 
92 
93  static_assert(
94  std::is_convertible<
95  ValueType,
97  "ReduceFunction has the wrong input type");
98 
99  static_assert(
100  std::is_convertible<
101  ValueType,
103  "ReduceFunction has the wrong input type");
104 
105  static_assert(
106  std::is_convertible<
108  ValueType>::value,
109  "ReduceFunction has the wrong input type");
110 
111  auto node = tlx::make_counting<AllReduceNode>(
112  *this, "AllReduce", initial_value, sum_function);
113 
114  node->RunScope();
115 
116  return node->result();
117 }
118 
119 template <typename ValueType, typename Stack>
120 template <typename ReduceFunction>
122  const ReduceFunction& sum_function, const ValueType& initial_value) const {
123  assert(IsValid());
124 
126 
127  static_assert(
128  std::is_convertible<
129  ValueType,
131  "ReduceFunction has the wrong input type");
132 
133  static_assert(
134  std::is_convertible<
135  ValueType,
137  "ReduceFunction has the wrong input type");
138 
139  static_assert(
140  std::is_convertible<
142  ValueType>::value,
143  "ReduceFunction has the wrong input type");
144 
145  auto node = tlx::make_counting<AllReduceNode>(
146  *this, "AllReduce", initial_value, sum_function);
147 
148  return Future<ValueType>(node);
149 }
150 
151 } // namespace api
152 } // namespace thrill
153 
154 #endif // !THRILL_API_ALL_REDUCE_HEADER
155 
156 /******************************************************************************/
net::FlowControlChannel & net
Definition: context.hpp:446
ValueType_ ValueType
Definition: dia.hpp:152
Future< ValueType > AllReduceFuture(const ReduceFunction &reduce_function, const ValueType &initial_value=ValueType()) const
AllReduce is an ActionFuture, which computes the reduction sum of all elements globally and delivers the same value on all workers.
Definition: all_reduce.hpp:121
virtual const ValueType & result() const =0
virtual method to return result via an ActionFuture
static constexpr bool debug
Definition: all_reduce.hpp:30
const char * label() const
return label() of DIANode subclass as stored by StatsNode
Definition: dia_base.hpp:218
#define TLX_UNLIKELY(c)
Definition: likely.hpp:24
ValueType AllReduce(const ReduceFunction &reduce_function, const ValueType &initial_value=ValueType()) const
AllReduce is an Action, which computes the reduction sum of all elements globally and delivers the same value on all workers.
Definition: all_reduce.hpp:87
T TLX_ATTRIBUTE_WARN_UNUSED_RESULT AllReduce(const T &value, const BinarySumOp &sum_op=BinarySumOp())
Reduces a value of a serializable type T over all workers given a certain reduce function.
AllReduceNode(const ParentDIA &parent, const char *label, const ValueType &initial_value, const ReduceFunction &reduce_function=ReduceFunction())
Definition: all_reduce.hpp:37
virtual void Execute()=0
Virtual execution method. Triggers actual computation in sub-classes.
int value
Definition: gen_data.py:41
common::FunctionTraits< Function > FunctionTraits
alias for convenience.
Definition: dia.hpp:147
The return type class for all ActionFutures.
Definition: action_node.hpp:83
ValueType sum_
Local/global sum to be used in all reduce operation.
Definition: all_reduce.hpp:41
Context & context_
associated Context
Definition: dia_base.hpp:293