Thrill  0.1
k-means_step4.cpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * examples/tutorial/k-means_step4.cpp
3  *
4  * Part of Project Thrill - http://project-thrill.org
5  *
6  * Copyright (C) 2016 Timo Bingmann <[email protected]>
7  *
8  * All rights reserved. Published under the BSD-2 license in the LICENSE file.
9  ******************************************************************************/
10 
11 //! \example examples/tutorial/k-means_step4.cpp
12 //!
13 //! This example is part of the k-means tutorial. See \ref kmeans_tutorial_step4
14 
#include <thrill/api/all_gather.hpp>
#include <thrill/api/cache.hpp>
#include <thrill/api/generate.hpp>
#include <thrill/api/print.hpp>
#include <thrill/api/reduce_by_key.hpp>
#include <thrill/api/sample.hpp>

#include <ostream>
#include <random>
#include <vector>
25 
//! A 2-dimensional point with double precision
struct Point {
    //! point coordinates
    double x, y;

    //! squared Euclidean distance between this point and b
    double DistanceSquare(const Point& b) const {
        return (x - b.x) * (x - b.x) + (y - b.y) * (y - b.y);
    }
    //! component-wise sum of this point and b
    Point operator + (const Point& b) const {
        return Point { x + b.x, y + b.y };
    }
    //! divide both coordinates by the scalar s
    Point operator / (double s) const {
        return Point { x / s, y / s };
    }
};
41 
42 //! make ostream-able for Print()
43 std::ostream& operator << (std::ostream& os, const Point& p) {
44  return os << '(' << p.x << ',' << p.y << ')';
45 }
46 
47 //! Assignment of a point to a cluster.
48 struct ClosestCenter {
49  size_t cluster_id;
50  Point point;
51  size_t count;
52 };
53 //! make ostream-able for Print()
54 std::ostream& operator << (std::ostream& os, const ClosestCenter& cc) {
55  return os << '(' << cc.cluster_id
56  << ':' << cc.point << ':' << cc.count << ')';
57 }
58 
59 //! our main processing method
61 
62  std::default_random_engine rng(std::random_device { } ());
63  std::uniform_real_distribution<double> dist(0.0, 1000.0);
64 
65  // generate 100 random points using uniform distribution
66  auto points =
67  Generate(
68  ctx, /* size */ 100,
69  [&](const size_t&) {
70  return Point { dist(rng), dist(rng) };
71  })
72  .Cache();
73 
74  // print out the points
75  points.Print("points");
76 
77  //! [step4 iteration loop]
78  // pick some initial random cluster centers
79  thrill::DIA<Point> centers = points.Sample(/* num_clusters */ 10);
80 
81  for (size_t iter = 0; iter < /* iterations */ 10; ++iter)
82  {
83  // collect centers in a local vector on each worker
84  std::vector<Point> local_centers = centers.AllGather();
85 
86  auto new_centers =
87  points
88  // calculate the closest center for each point
89  .Map(
90  [local_centers](const Point& p) {
91  double min_dist = p.DistanceSquare(local_centers[0]);
92  size_t cluster_id = 0;
93 
94  for (size_t i = 1; i < local_centers.size(); ++i) {
95  double dist = p.DistanceSquare(local_centers[i]);
96  if (dist < min_dist)
97  min_dist = dist, cluster_id = i;
98  }
99  return ClosestCenter { cluster_id, p, /* count */ 1 };
100  })
101  // new centers as the mean of all points associated with it
102  .ReduceByKey(
103  // key extractor: the cluster id
104  [](const ClosestCenter& cc) { return cc.cluster_id; },
105  // reduction: add points and the counter
106  [](const ClosestCenter& a, const ClosestCenter& b) {
107  return ClosestCenter {
108  a.cluster_id, a.point + b.point, a.count + b.count
109  };
110  })
111  .Map([](const ClosestCenter& cc) {
112  return cc.point / cc.count;
113  });
114 
115  new_centers.Print("new_centers");
116 
117  // Collapse() is needed to fold lambda chain to DIA<Points>
118  centers = new_centers.Collapse();
119  }
120 
121  centers.Print("final centers");
122 
123  return centers.AllGather();
124  //! [step4 iteration loop]
125 }
126 
127 int main() {
128  // launch Thrill program: the lambda function will be run on each worker.
129  return thrill::Run(
130  [&](thrill::Context& ctx) { Process(ctx); });
131 }
132 
133 /******************************************************************************/
DIA is the interface between the user and the Thrill framework.
Definition: dia.hpp:141
std::ostream & operator<<(std::ostream &os, const Point &p)
make ostream-able for Print()
auto Process(thrill::Context &ctx)
our main processing method
std::vector< ValueType > AllGather() const
Returns the whole DIA in an std::vector on each worker.
Definition: all_gather.hpp:114
auto Generate(Context &ctx, size_t size, const GenerateFunction &generate_function)
Generate is a Source-DOp, which creates a DIA of given size using a generator function.
Definition: generate.hpp:87
int Run(const std::function< void(Context &)> &job_startpoint)
Runs the given job startpoint with a Context instance.
Definition: context.cpp:947
thrill::common::Vector< D, double > Point
Compile-Time Fixed-Dimensional Points.
Definition: k-means.hpp:39
The Context of a job is a unique instance per worker which holds references to all underlying parts of Thrill.
Definition: context.hpp:221
auto Sample(size_t sample_size) const
Select up to sample_size items uniformly at random and return a new DIA<T>.
Definition: sample.hpp:247
void Print(const std::string &name=std::string()) const
Print is an Action, which collects all data of the DIA at the worker 0 and prints using ostream serialization.
Definition: print.hpp:50
list x
Definition: gen_data.py:39
int main()
DIA< ValueType > Collapse() const
Create a CollapseNode which is mainly used to collapse the LOp chain into a DIA<T> with an empty stack.
Definition: collapse.hpp:159