Thrill  0.1
logistic_regression.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * examples/logistic_regression/logistic_regression.hpp
3  *
4  * Part of Project Thrill - http://project-thrill.org
5  *
6  * Copyright (C) 2016 Lorenz Hübschle-Schneider <[email protected]>
7  * Copyright (C) 2015-2016 Utkarsh Bhardwaj <[email protected]>
8  *
10  ******************************************************************************/
11
12 #pragma once
15
16 #include <thrill/api/cache.hpp>
17 #include <thrill/api/dia.hpp>
18 #include <thrill/api/size.hpp>
19 #include <thrill/api/sum.hpp>
20 #include <thrill/common/logger.hpp>
21
22 #include <algorithm>
23 #include <array>
24 #include <cmath>
25 #include <functional>
26 #include <numeric>
27 #include <utility>
28
29 #define LOGM LOGC(debug && ctx.my_rank() == 0)
30
31 namespace examples {
32 namespace logistic_regression {
33
34 using namespace thrill; // NOLINT
35 static constexpr bool debug = true;
36
//! Logistic (sigmoid) function: maps any real x into (0, 1).
//! Can't be constexpr because std::exp isn't one.
//! Fix: use the qualified std::exp — plain `exp` is only guaranteed to be
//! declared in the global namespace by <math.h>, not by <cmath>.
template <typename T>
inline T sigmoid(const T& x) {
    return 1.0 / (1.0 + std::exp(-x));
}
41
//! Euclidean (L2) distance between two weight vectors, used as the
//! convergence criterion of gradient descent.
template <typename T, size_t dim>
T calc_norm(const std::array<T, dim>& weights,
            const std::array<T, dim>& new_weights) {
    // accumulate the sum of squared component-wise differences
    T sum_of_squares = std::inner_product(
        weights.begin(), weights.end(), new_weights.begin(), T { 0 },
        std::plus<T>(),
        [](const T& a, const T& b) {
            T diff = a - b;
            return diff * diff;
        });
    return std::sqrt(sum_of_squares);
}
52
//! Gradient of the logistic loss for a single labelled sample.
//!
//! \param y ground-truth label of the sample
//! \param x feature vector of the sample
//! \param w current weight vector
//! \return per-component gradient (sigmoid(w.x) - y) * x
//!
//! NOTE(review): the listing this was recovered from had dropped the
//! declaration of `grad` and the `return grad;` statement, leaving the
//! function ill-formed; both are restored here.
template <typename T, size_t dim>
auto gradient(const bool& y, const std::array<T, dim>& x,
              const std::array<T, dim>& w) {
    std::array<T, dim> grad;
    // scalar factor shared by all components: prediction error
    T dot_product = std::inner_product(w.begin(), w.end(), x.begin(), T { 0.0 });
    T s = sigmoid(dot_product) - y;
    for (size_t i = 0; i < dim; ++i) {
        grad[i] = s * x[i];
    }
    return grad;
}
64
65 template <typename T, size_t dim, typename InStack,
66  typename Element = std::array<T, dim> >
67 auto logit_train(const DIA<std::pair<bool, Element>, InStack>& data,
68  size_t max_iterations, double gamma = 0.002,
69  double epsilon = 0.0001) {
70  // weights, initialized to zero
71  Element weights, new_weights;
72  weights[0] = weights[1] = weights[2] = 0;
73  size_t iter = 0;
74  T norm = 0.0;
75
76  while (iter < max_iterations) {
78  data.Keep()
79  .Map([&weights](const std::pair<bool, Element>& elem) -> Element {
81  })
82  .Sum([](const Element& a, const Element& b) -> Element {
83  Element result;
84  std::transform(a.begin(), a.end(), b.begin(),
85  result.begin(), std::plus<T>());
86  return result;
87  });
88
90  new_weights.begin(),
91  [&gamma](const T& a, const T& b) -> T
92  { return a - gamma * b; });
93
94  norm = calc_norm(new_weights, weights);
95  weights = new_weights;
96
97  iter++;
98  if (norm < epsilon) break;
99  }
100
101  return std::make_tuple(weights, norm, iter);
102 }
103
104 template <typename T, size_t dim, typename InStack,
105  typename Element = std::array<T, dim> >
106 auto logit_test(const DIA<std::pair<bool, Element>, InStack>& data,
107  const Element& weights) {
108  size_t expected_true =
109  data.Keep()
110  .Filter([](const std::pair<T, Element>& elem) -> bool {
111  return elem.first;
112  })
113  .Size();
114
115  size_t expected_false = data.Keep().Size() - expected_true;
116
117  using Prediction = std::pair<bool, bool>;
118  auto classification =
119  data.Keep()
120  .Map([&weights](const std::pair<T, Element>& elem) -> Prediction {
121  const Element& coords = elem.second;
122  T predicted_y = std::inner_product(
123  weights.begin(), weights.end(), coords.begin(), T { 0.0 });
124
125  bool prediction = (sigmoid(predicted_y) > 0.5);
126  return Prediction { elem.first, prediction };
127  })
128  .Collapse(); // don't evaluate this twice
129
130  size_t true_trues =
131  classification.Keep()
132  .Filter([](const Prediction& p) { return p.first && p.second; })
133  .Size();
134
135  size_t true_falses =
136  classification
137  .Filter([](const Prediction& p) { return !p.first && !p.second; })
138  .Size();
139
140  return std::make_tuple(expected_true, true_trues,
141  expected_false, true_falses);
142 }
143
144 } // namespace logistic_regression
145 } // namespace examples
146