13 #ifndef THRILL_API_ALL_REDUCE_HEADER 14 #define THRILL_API_ALL_REDUCE_HEADER 19 #include <type_traits> 27 template <
typename ValueType,
typename ReduceFunction>
30 static constexpr
bool debug =
false;
36 template <
typename ParentDIA>
39 const ValueType& initial_value = ValueType(),
40 bool with_initial_value =
false,
41 const ReduceFunction& reduce_function = ReduceFunction())
42 :
Super(parent.ctx(), label, { parent.id() }, { parent.node() }),
46 first_(!(with_initial_value && parent.ctx().my_rank() == 0)) {
48 auto pre_op_fn = [
this](
const ValueType& input) {
52 auto lop_chain = parent.stack().push(pre_op_fn).fold();
53 parent.node()->AddChild(
this, lop_chain);
56 void PreOp(
const ValueType& input) {
73 const ValueType&
result() const final {
87 template <
typename ValueType,
typename Stack>
88 template <
typename ReduceFunction>
90 const ReduceFunction& sum_function)
const {
99 "ReduceFunction has the wrong input type");
105 "ReduceFunction has the wrong input type");
111 "ReduceFunction has the wrong input type");
113 auto node = tlx::make_counting<AllReduceNode>(
119 return node->result();
122 template <
typename ValueType,
typename Stack>
123 template <
typename ReduceFunction>
125 const ReduceFunction& sum_function,
const ValueType& initial_value)
const {
134 "ReduceFunction has the wrong input type");
140 "ReduceFunction has the wrong input type");
146 "ReduceFunction has the wrong input type");
148 auto node = tlx::make_counting<AllReduceNode>(
149 *
this,
"AllReduce", initial_value,
true,
154 return node->result();
157 template <
typename ValueType,
typename Stack>
158 template <
typename ReduceFunction>
160 const ReduceFunction& sum_function)
const {
169 "ReduceFunction has the wrong input type");
175 "ReduceFunction has the wrong input type");
181 "ReduceFunction has the wrong input type");
183 auto node = tlx::make_counting<AllReduceNode>(
190 template <
typename ValueType,
typename Stack>
191 template <
typename ReduceFunction>
193 const ReduceFunction& sum_function,
const ValueType& initial_value)
const {
202 "ReduceFunction has the wrong input type");
208 "ReduceFunction has the wrong input type");
214 "ReduceFunction has the wrong input type");
216 auto node = tlx::make_counting<AllReduceNode>(
217 *
this,
"AllReduce", initial_value,
true,
226 #endif // !THRILL_API_ALL_REDUCE_HEADER net::FlowControlChannel & net
static constexpr bool debug
const char * label() const
return label() of DIANode subclass as stored by StatsNode
T TLX_ATTRIBUTE_WARN_UNUSED_RESULT AllReduce(const T &value, const BinarySumOp &sum_op=BinarySumOp())
Reduces a value of a serializable type T over all workers given a certain reduce function.
void Execute() final
Executes the sum operation.
ReduceFunction reduce_function_
The sum function which is applied to two values.
Future< ValueType > AllReduceFuture(const ReduceFunction &reduce_function) const
AllReduce is an ActionFuture, which computes the reduction sum of all elements globally and delivers ...
common::FunctionTraits< Function > FunctionTraits
alias for convenience.
The return type class for all ActionFutures.
void PreOp(const ValueType &input)
const ValueType & result() const final
Returns result of global sum.
ValueType AllReduce(const ReduceFunction &reduce_function) const
AllReduce is an Action, which computes the reduction sum of all elements globally and delivers the sa...
AllReduceNode(const ParentDIA &parent, const char *label, const ValueType &initial_value=ValueType(), bool with_initial_value=false, const ReduceFunction &reduce_function=ReduceFunction())
ValueType sum_
Local/global sum to be used in all reduce operation.
Context & context_
associated Context