Thrill 0.1
In step 2 of this tutorial we constructed a DIA containing each point and the id of its closest center. The next step of Lloyd's algorithm is to calculate, for each cluster center, the mean (average) of all points associated with it. These mean points become the new centers in the next iteration.
How do we calculate the mean point? In Thrill this can be done using a ReduceByKey() reduction, which is similar to the reduce step of MapReduce. The idea is to sum all points associated with a center and then divide the sum by the number of those points, which yields the average point.
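Written as a formula, the new center of cluster j is the mean of its point set S_j (the symbols S_j, s_j and n_j are notation introduced here for illustration). Because an average of averages is not, in general, the overall average, the reduction carries (sum, count) pairs instead. These pairs combine associatively and commutatively, so they can be merged in any order, and the division happens only once at the end:

```latex
c_j^{\mathrm{new}} = \frac{1}{|S_j|} \sum_{p \in S_j} p,
\qquad
(s_a, n_a) \oplus (s_b, n_b) = (s_a + s_b,\; n_a + n_b),
\qquad
c_j^{\mathrm{new}} = \frac{s_j}{n_j}.
```

For example, the program output further below shows cluster 0 reduced to the sum (63.1991, 406.621) with count 2, which gives the new center (31.5996, 203.311).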
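To accomplish this in the code, we extend the ClosestCenter class with another field: count. The reduction adds both the point field and the count field; together they hold the sum of a cluster's points and the number of points in that sum, from which the average point follows.
The count field is initialized to 1 in the result of the Map() that calculates the closest center, since each element at that stage represents exactly one point.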
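As a rough standalone sketch, not Thrill's actual code, the extended struct and the body of the closest-center Map() might look like the C++ below. The names Point, DistanceSquare and FindClosest are hypothetical stand-ins; in the real program this logic would live in the lambda passed to Map(), and the centers would be the current iteration's centers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical minimal 2-D point; stands in for the tutorial's Point class.
struct Point {
    double x, y;
};

// Squared Euclidean distance; enough for comparing which center is closer.
double DistanceSquare(const Point& a, const Point& b) {
    double dx = a.x - b.x, dy = a.y - b.y;
    return dx * dx + dy * dy;
}

// A point associated with its closest center, plus the new count field.
struct ClosestCenter {
    std::size_t cluster_id;  // index of the closest center
    Point point;             // a single point, or the sum of several points
    std::size_t count;       // number of points summed into `point`
};

// Sketch of the Map() body: find the closest center and start count at 1,
// because each element produced here represents exactly one point.
ClosestCenter FindClosest(const Point& p, const std::vector<Point>& centers) {
    assert(!centers.empty());
    std::size_t best = 0;
    double best_dist = DistanceSquare(p, centers[0]);
    for (std::size_t i = 1; i < centers.size(); ++i) {
        double d = DistanceSquare(p, centers[i]);
        if (d < best_dist) {
            best_dist = d;
            best = i;
        }
    }
    return ClosestCenter{best, p, 1};
}
```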
Then we use ReduceByKey() for the reduction. The ReduceByKey() DIA operation requires two lambda functions: a key extractor and a reduction function. The key extractor simply returns the cluster_id field, since that is the key by which we want the point associations to be grouped.
The reduction function "adds" two ClosestCenter structs by adding their point and count fields while keeping the cluster_id unchanged.
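Since a Thrill installation cannot be assumed here, the following sequential C++ sketch emulates what ReduceByKey() does with the two lambdas, using a std::unordered_map. ReduceByKeyLocal is a hypothetical helper, not part of Thrill's API; the real distributed operation also exchanges items between workers, which this sketch ignores.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>

struct Point { double x, y; };

struct ClosestCenter {
    std::size_t cluster_id;
    Point point;
    std::size_t count;
};

// Hypothetical sequential stand-in for ReduceByKey(key_extractor, reduce_function):
// items with equal keys are folded together with the reduce function.
template <typename T, typename KeyFn, typename ReduceFn>
std::vector<T> ReduceByKeyLocal(const std::vector<T>& items, KeyFn key_fn, ReduceFn reduce_fn) {
    using Key = decltype(key_fn(std::declval<const T&>()));
    std::unordered_map<Key, T> acc;
    for (const T& item : items) {
        Key k = key_fn(item);
        auto it = acc.find(k);
        if (it == acc.end())
            acc.emplace(k, item);                    // first item with this key
        else
            it->second = reduce_fn(it->second, item);  // merge into the running value
    }
    std::vector<T> out;
    out.reserve(acc.size());
    for (const auto& kv : acc) out.push_back(kv.second);
    return out;  // order is unspecified, as with a hash map
}

// Key extractor: group associations by cluster id.
const auto cluster_key = [](const ClosestCenter& cc) { return cc.cluster_id; };

// Reduction function: add points and counts, keep the cluster id.
const auto add_centers = [](const ClosestCenter& a, const ClosestCenter& b) {
    assert(a.cluster_id == b.cluster_id);  // only items with equal keys are combined
    return ClosestCenter{a.cluster_id,
                         Point{a.point.x + b.point.x, a.point.y + b.point.y},
                         a.count + b.count};
};
```

As with reduced_centers in the program output below, the result is not sorted by cluster_id; in this sketch its order depends on the hash map.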
In step 3's code we also added a .Print() call to help debug the reduction.
After the reduction, we must divide each summed point by the number of points aggregated into it to obtain the new average centers. This is a simple application of Map(), which takes a ClosestCenter and returns a Point.
The resulting new_centers is a DIA<Point> containing the cluster centers for the next iteration.
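The final Map() step can be sketched in the same standalone style. ToMean and NewCenters are hypothetical names: ToMean corresponds to the lambda that would be passed to Map(), and NewCenters stands in for Map() applying it to every element.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Point { double x, y; };

struct ClosestCenter {
    std::size_t cluster_id;
    Point point;        // sum of the cluster's points after the reduction
    std::size_t count;  // number of points in that sum
};

// Sketch of the Map() body: turn a (sum, count) pair into the mean point.
Point ToMean(const ClosestCenter& cc) {
    assert(cc.count > 0);  // every reduced entry holds at least one point
    double n = static_cast<double>(cc.count);
    return Point{cc.point.x / n, cc.point.y / n};
}

// Apply ToMean to each element, as Map() would on each DIA item.
std::vector<Point> NewCenters(const std::vector<ClosestCenter>& reduced) {
    std::vector<Point> out;
    out.reserve(reduced.size());
    for (const ClosestCenter& cc : reduced) out.push_back(ToMean(cc));
    return out;
}
```

Applied to reduced_centers[0] from the output below, (0:(63.1991,406.621):2), the division by 2 matches the printed new_centers[0]: (31.5996,203.311).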
Inside the reduction and the final Map() we implicitly used some vector operators on Point: addition of two points and division by a scalar. Again, with high-flying object-oriented spirits, we extended the Point class with the appropriate operators, which makes the functions above very readable.
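A minimal version of those operators, assuming a two-dimensional Point with double coordinates (the tutorial's own Point class may be defined differently), could look like this:

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical 2-D point with the two vector operators used in this step.
struct Point {
    double x, y;

    // Component-wise addition, used by the reduction function.
    Point operator + (const Point& b) const { return Point{x + b.x, y + b.y}; }

    // Division by a scalar count, used by the final Map().
    Point operator / (std::size_t n) const {
        assert(n > 0);
        double d = static_cast<double>(n);
        return Point{x / d, y / d};
    }
};
```

With operators like these, the reduction function could return ClosestCenter{a.cluster_id, a.point + b.point, a.count + b.count}, and the final Map() could return cc.point / cc.count.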
See the complete example code in examples/tutorial/k-means_step3.cpp.
The output of our program so far looks something like the following:
[... as before ...]
[host 0 worker 0 000011] PushData() stage Cache.2 with targets [ReduceByKey.8]
[host 0 worker 0 000012] Execute() stage ReduceByKey.8
[host 0 worker 0 000013] PushData() stage ReduceByKey.8 with targets [Print.9]
[host 0 worker 0 000014] Execute() stage Print.9
reduced_centers --- Begin DIA.Print() --- size=10
reduced_centers[0]: (0:(63.1991,406.621):2)
reduced_centers[1]: (1:(4152.89,999.313):5)
reduced_centers[2]: (7:(6394.55,2896.39):7)
reduced_centers[3]: (5:(17904.5,18761.1):24)
reduced_centers[4]: (4:(5667.67,3197.76):9)
reduced_centers[5]: (9:(1842.71,6814.27):8)
reduced_centers[6]: (2:(3620.52,9055.77):16)
reduced_centers[7]: (8:(1473.69,2804.14):11)
reduced_centers[8]: (3:(7621.35,1216.32):13)
reduced_centers[9]: (6:(903.419,125.585):5)
reduced_centers --- End DIA.Print() --- size=10
[host 0 worker 0 000015] PushData() stage ReduceByKey.8 with targets [Print.11]
[host 0 worker 0 000016] Execute() stage Print.11
new_centers --- Begin DIA.Print() --- size=10
new_centers[0]: (31.5996,203.311)
new_centers[1]: (830.578,199.863)
new_centers[2]: (913.507,413.77)
new_centers[3]: (746.021,781.712)
new_centers[4]: (629.741,355.307)
new_centers[5]: (230.339,851.784)
new_centers[6]: (226.283,565.986)
new_centers[7]: (133.972,254.921)
new_centers[8]: (586.258,93.5627)
new_centers[9]: (180.684,25.1171)
new_centers --- End DIA.Print() --- size=10
[...]