38 static constexpr
bool debug =
false;
41 : node_(node), context_(node->context()),
42 verbose_(context_.mem_config().verbose_)
46 template <
typename Lambda>
47 void Targets(
const Lambda& lambda)
const {
48 std::vector<DIABase*> children = node_->children();
49 std::reverse(children.begin(), children.end());
51 while (!children.empty()) {
52 DIABase* child = children.back();
55 if (child->ForwardDataOnly()) {
57 std::vector<DIABase*> sub = child->children();
58 children.insert(children.end(), sub.begin(), sub.end());
69 std::ostringstream oss;
70 std::vector<DIABase*> children = node_->children();
71 std::reverse(children.begin(), children.end());
75 while (!children.empty())
77 DIABase* child = children.back();
80 if (child ==
nullptr) {
83 else if (child->ForwardDataOnly()) {
85 std::vector<DIABase*> sub = child->children();
86 children.push_back(
nullptr);
87 children.insert(children.end(), sub.begin(), sub.end());
93 oss << *child <<
' ' <<
'[';
109 std::vector<size_t> TargetIds()
const {
110 std::vector<size_t> ids;
111 Targets([&ids](DIABase* child) { ids.emplace_back(child->dia_id()); });
115 std::vector<DIABase*> TargetPtrs()
const {
116 std::vector<DIABase*> children;
117 Targets([&children](DIABase* child) { children.emplace_back(child); });
122 sLOG <<
"START (EXECUTE) stage" << *node_ <<
"targets" << TargetsString();
124 if (context_.my_rank() == 0) {
126 <<
"Execute() stage" << *node_;
129 std::vector<size_t> target_ids = TargetIds();
131 logger_ <<
"class" <<
"StageBuilder" <<
"event" <<
"execute-start" 132 <<
"targets" << target_ids;
134 DIAMemUse mem_use = node_->ExecuteMemUse();
135 if (mem_use.is_max())
136 mem_use = context_.mem_limit();
137 node_->set_mem_limit(mem_use);
146 catch (std::exception& e) {
147 LOG1 <<
"StageBuilder: caught exception from Execute()" 148 <<
" of stage " << *node_ <<
" - what(): " << e.what();
154 sLOG <<
"FINISH (EXECUTE) stage" << *node_ <<
"targets" << TargetsString()
155 <<
"took" << timer <<
"ms";
157 logger_ <<
"class" <<
"StageBuilder" <<
"event" <<
"execute-done" 158 <<
"targets" << target_ids <<
"elapsed" << timer;
160 LOG <<
"DIA bytes: " << node_->context().block_pool().total_bytes();
164 sLOG <<
"START (PUSHDATA) stage" << *node_ <<
"targets" << TargetsString();
166 if (context_.my_rank() == 0) {
168 <<
"PushData() stage" << *node_
169 <<
"with targets" << TargetsString();
172 if (context_.consume() && node_->consume_counter() == 0) {
173 sLOG1 <<
"StageBuilder: attempt to PushData from" 174 <<
"stage" << *node_ <<
"to" << TargetsString()
175 <<
"failed, it was already consumed. Add .Keep()";
179 std::vector<size_t> target_ids = TargetIds();
181 logger_ <<
"class" <<
"StageBuilder" <<
"event" <<
"pushdata-start" 182 <<
"targets" << target_ids;
186 std::vector<DIABase*> targets = TargetPtrs();
188 const size_t mem_limit = context_.mem_limit();
189 std::vector<DIABase*> max_mem_nodes;
190 size_t const_mem = 0;
194 DIAMemUse m = node_->PushDataMemUse();
196 max_mem_nodes.emplace_back(node_.get());
199 const_mem += m.limit();
200 node_->set_mem_limit(m.limit());
205 for (DIABase* target : TargetPtrs()) {
206 DIAMemUse m = target->PreOpMemUse();
208 max_mem_nodes.emplace_back(target);
211 const_mem += m.limit();
212 target->set_mem_limit(m.limit());
217 if (const_mem > mem_limit) {
218 LOG1 <<
"StageBuilder: constant memory usage of DIANodes in Stage: " 220 <<
", already exceeds Context's mem_limit: " << mem_limit;
226 if (!max_mem_nodes.empty()) {
227 size_t remaining_mem = mem_limit - const_mem;
228 remaining_mem /= max_mem_nodes.size();
230 if (context_.my_rank() == 0) {
231 LOG <<
"StageBuilder: distribute remaining worker memory " 232 << remaining_mem <<
" to " 233 << max_mem_nodes.size() <<
" DIANodes";
236 for (DIABase* target : max_mem_nodes) {
237 target->set_mem_limit(remaining_mem);
241 const_mem = mem_limit;
252 node_->RunPushData();
254 catch (std::exception& e) {
255 LOG1 <<
"StageBuilder: caught exception from PushData()" 256 <<
" of stage " << *node_ <<
" targets " << TargetsString()
257 <<
" - what(): " << e.what();
260 node_->RemoveAllChildren();
263 sLOG <<
"FINISH (PUSHDATA) stage" << *node_ <<
"targets" << TargetsString()
264 <<
"took" << timer <<
"ms";
266 logger_ <<
"class" <<
"StageBuilder" <<
"event" <<
"pushdata-done" 267 <<
"targets" << target_ids <<
"elapsed" << timer;
269 LOG <<
"DIA bytes: " << node_->context().block_pool().total_bytes();
275 return node_->dia_id() < s.node_->dia_id();
285 common::JsonLogger& logger_ { node_->logger_ };
291 mutable bool cycle_mark_ =
false;
294 mutable bool topo_seen_ =
false;
297 template <
typename T>
307 LOG <<
"Finding Stages:";
312 bfs_stack.push_back(action);
313 stages->insert(Stage(action));
315 while (!bfs_stack.empty()) {
317 bfs_stack.pop_front();
319 const std::vector<DIABasePtr>& parents = curr->parents();
321 for (
size_t i = 0; i < parents.size(); ++i) {
325 if (stages->count(Stage(p)) != 0)
continue;
327 if (!curr->ForwardDataOnly()) {
329 LOG <<
" Stage: " << *p;
335 bfs_stack.push_back(p);
339 if (curr->RequireParentPushData(i)) {
341 LOG <<
" Stage: " << *p;
343 bfs_stack.push_back(p);
353 die_unless(!s.cycle_mark_ &&
"Cycle in toposort of Stages? Impossible.");
354 if (s.topo_seen_)
return;
356 s.cycle_mark_ =
true;
358 for (
DIABase* child : s.node_->children()) {
359 auto it = stages->find(Stage(
DIABasePtr(child)));
362 if (it == stages->end())
continue;
369 s.cycle_mark_ =
false;
370 result->push_back(s);
375 for (
const Stage& s : *stages) {
376 if (s.topo_seen_)
continue;
384 LOG <<
"DIABase::Execute() this=" << *
this;
387 LOG <<
"DIA node " << *
this <<
" was already executed.";
391 if (ForwardDataOnly()) {
408 if (context_.my_rank() == 0) {
409 LOG <<
"Topological order";
410 for (
auto top = toporder.rbegin(); top != toporder.rend(); ++top) {
411 LOG <<
" " << *top->node_;
415 assert(toporder.front().node_.get() ==
this);
417 while (!toporder.empty())
419 Stage& s = toporder.back();
421 if (s.node_->ForwardDataOnly()) {
431 if (s.node_.get() !=
this)
435 if (s.node_.get() !=
this)
The DIABase has not been computed yet.
std::set< T, std::less< T >, mem::Allocator< T > > mm_set
#define sLOG
Default logging method: output if the local debug variable is true.
#define sLOGC(cond)
Explicitly specify the condition for logging.
const char * label() const
return label() of DIANode subclass as stored by StatsNode
The Context of a job is a unique instance per worker which holds references to all underlying parts o...
void malloc_tracker_print_status()
user function which prints current and peak allocation to stderr
The DIABase is the untyped super class of DIANode.
static void FindStages(Context &ctx, const DIABasePtr &action, mm_set< Stage > *stages)
StatsTimerBaseStarted< true > StatsTimerStart
std::basic_string< char, std::char_traits< char >, Allocator< char > > string
string with Manager tracking
static void TopoSortStages(mm_set< Stage > *stages, mem::vector< Stage > *result)
static constexpr bool debug
std::vector< T, Allocator< T > > vector
vector with Manager tracking
static void TopoSortVisit(const Stage &s, mm_set< Stage > *stages, mem::vector< Stage > *result)
size_t my_rank() const
Global rank of this worker among all other workers in the system.
const size_t & dia_id() const
return unique id of DIANode subclass as stored by StatsNode
std::deque< T, Allocator< T > > deque
deque with Manager tracking
tlx::CountingPtr< DIABase > DIABasePtr
bool operator<(const uint_pair &b) const
less-than comparison operator
#define LOG
Default logging method: output if the local debug variable is true.
std::ostream & operator<<(std::ostream &os, const DIABase &d)
make ostream-able.