Thrill
0.1
From step 1 of this tutorial, we have a DIA<Point> points containing random points. In this step, we pick 10 random points as initial centers. Then we iterate over all points to determine the closest cluster center for each of them.
Selecting 10 random cluster centers is easy, since Thrill provides a Sample() DIA operation, which selects a fixed number of items uniformly at random.
Next we want to classify all points using the centers. This can be done in Thrill using a Map(), which takes a point and outputs its associated cluster id.
For the classification, this Map() operation requires the cluster centers. The easiest way to accomplish this is to use C++ lambda captures to copy them into the lambda context.
However, we first have to collect the center points on all workers, since they are currently stored in a DIA, which cannot be accessed directly. In Thrill, this collection and distribution of the centers is called AllGather(), after the corresponding MPI collective; it delivers the same std::vector<Point> on each worker thread.
The final classification Map() thus captures the gathered vector of centers by copy and maps each point to its classification result.
To perform the actual classification, we need to calculate the distance from the point to each center. In a high-flying object-oriented spirit, we decided to add a .DistanceSquare() method to our Point. For the classification, the squared distance is sufficient: since the square root is monotonically increasing, the center with the smallest squared distance is also the closest one, and we avoid computing square roots.
The question of what kind of items the classification Map() should return is still open. For step 2 of the tutorial, we decided it is best to just output the result of the classification to check that it works correctly.
So we create a struct which contains the point and its resulting cluster id. This struct should have an operator<< so that it can be printed easily.
The actual calculation loop is simple: it iterates over all centers and picks the one with the smallest squared distance.
See the complete example code in examples/tutorial/k-means_step2.cpp.
The output of our program so far is something like the following:
[... as before ...]
points[99]: (721.08,599.95)
points --- End DIA.Print() --- size=100
[host 0 worker 0 000005] PushData() stage Cache.2 with targets [Sample.4]
[host 0 worker 0 000006] Execute() stage Sample.4
[host 0 worker 0 000007] PushData() stage Sample.4 with targets [AllGather.5]
[host 0 worker 0 000008] Execute() stage AllGather.5
[host 0 worker 0 000009] PushData() stage Cache.2 with targets [Print.7]
[host 0 worker 0 000010] Execute() stage Print.7
closest --- Begin DIA.Print() --- size=100
closest[0]: (2:(173.567,374.421))
closest[1]: (6:(827.163,471.481))
closest[2]: (3:(796.444,955.701))
[... more closest ...]
closest[97]: (7:(532.274,41.1314))
closest[98]: (5:(474.302,201.813))
closest[99]: (3:(619.357,889.185))
closest --- End DIA.Print() --- size=100
[...]