Thrill  0.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
logistic_regression.cpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * examples/logistic_regression/logistic_regression.cpp
3  *
4  * Part of Project Thrill - http://project-thrill.org
5  *
6  * Copyright (C) 2016 Lorenz Hübschle-Schneider <[email protected]>
7  *
8  * All rights reserved. Published under the BSD-2 license in the LICENSE file.
9  ******************************************************************************/
10 
#include <thrill/api/cache.hpp>
#include <thrill/api/dia.hpp>
#include <thrill/api/generate.hpp>
#include <thrill/api/print.hpp>
#include <thrill/api/read_lines.hpp>
#include <thrill/common/logger.hpp>
#include <thrill/common/string.hpp>
#include <tlx/cmdline_parser.hpp>

#include <array>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "logistic_regression.hpp"
28 
29 using namespace thrill; // NOLINT
30 using namespace examples::logistic_regression; // NOLINT
31 
32 // Dimensions of the data
33 constexpr size_t dim = 3;
34 using T = double;
35 
36 using Element = std::array<T, dim>;
37 using DataObject = std::pair<bool, Element>;
38 
39 #define LOGM LOGC(debug && ctx.my_rank() == 0)
40 
41 template <typename Input>
42 static auto ReadInputFile(api::Context& ctx, const Input& input_path) {
43  return ReadLines(ctx, input_path)
44  .Map([](const std::string& line) {
45  // parse "value,dim_1,dim_2,...,dim_n" lines
46  char* endptr;
47  DataObject obj;
48  // yikes C stuff, TODO template
49  obj.first = common::from_cstr<T>(line.c_str(), &endptr);
50  die_unless(endptr && *endptr == ',' &&
51  "Could not parse input line");
52 
53  for (size_t i = 0; i < dim; ++i) {
54  T value = common::from_cstr<T>(endptr + 1, &endptr);
55  die_unless(endptr &&
56  ((i + 1 <= dim && *endptr == ',') ||
57  (i + 1 == dim && *endptr == 0)) &&
58  "Could not parse input line");
59  obj.second[i] = value;
60  }
61  return obj;
62  })
63  .Cache();
64 }
65 
66 static auto GenerateInput(api::Context& ctx, size_t size) {
67 
68  std::default_random_engine rng(std::random_device { } ());
69  std::normal_distribution<double> norm_dist(0.0, 1.0);
70  std::lognormal_distribution<double> lognorm_dist(0.0, 1.0);
71 
72  return Generate(
73  ctx, size,
74  [&](size_t index) {
75  bool c = (2 * index < size);
76 
77  // add noise to features
78  Element p;
79  p[0] = index * 0.1 + size / 100.0 * norm_dist(rng);
80  p[1] = index * index * 0.1 + size / 100.0 * norm_dist(rng);
81  p[2] = size - index * 0.1 + size / 100.0 * norm_dist(rng);
82 
83  return DataObject(c, p);
84  })
85  // cache generated data, otherwise random generators are used again.
86  .Cache();
87 }
88 
89 static auto GenerateTestData(api::Context& ctx, size_t size) {
90 
91  std::default_random_engine rng(std::random_device { } ());
92  std::normal_distribution<double> norm_dist(0.0, 1.0);
93  std::lognormal_distribution<double> lognorm_dist(0.0, 1.0);
94 
95  return Generate(
96  ctx, size,
97  [size](size_t index) {
98  bool c = (2 * index < size);
99 
100  // do not add noise to features
101  Element p;
102  p[0] = index * 0.1;
103  p[1] = index * index * 0.1;
104  p[2] = size - index * 0.1;
105 
106  return DataObject(c, p);
107  });
108 }
109 
110 template <typename InputDIA>
112  const InputDIA& input_dia,
113  size_t max_iterations, double gamma, double epsilon) {
114 
115  Element weights;
116  double norm;
117  size_t iterations;
118  std::tie(weights, norm, iterations) =
119  logit_train<T, dim>(input_dia.Keep(), max_iterations, gamma, epsilon);
120 
121  LOGM << "Iterations: " << iterations;
122  LOGM << "Norm: " << norm;
123  LOGM << "Final weights (model):";
124  for (size_t i = 0; i < dim; ++i) {
125  LOGM << "Model[" << i << "] = " << weights[i];
126  }
127  return weights;
128 }
129 
130 template <typename InputDIA>
131 void TestLogit(api::Context& ctx, const std::string& test_file,
132  const InputDIA& input_dia, const Element& weights) {
133  size_t num_trues, true_trues, num_falses, true_falses;
134  std::tie(num_trues, true_trues, num_falses, true_falses)
135  = logit_test<T, dim>(input_dia, weights);
136  LOGM << "Evaluation result for " << test_file << ":";
137  LOGM << "\tTrue: " << true_trues << " of " << num_trues << " correct, "
138  << num_trues - true_trues << " incorrect, "
139  << static_cast<double>(true_trues) / num_trues * 100.0 << "% matched";
140  LOGM << "\tFalse: " << true_falses << " of " << num_falses
141  << " correct, " << num_falses - true_falses << " incorrect, "
142  << static_cast<double>(true_falses) / num_falses * 100.0 << "% matched";
143 }
144 
145 int main(int argc, char* argv[]) {
146  tlx::CmdlineParser clp;
147 
148  std::string training_path;
149  std::vector<std::string> test_path;
150  clp.add_param_string("input", training_path, "training file pattern(s)");
151  clp.add_param_stringlist("test", test_path, "test file pattern(s)");
152 
153  size_t max_iterations = 1000;
154  clp.add_size_t('n', "iterations", max_iterations,
155  "Maximum number of iterations, default: 1000");
156 
157  double gamma = 0.002, epsilon = 0.0001;
158  clp.add_double('g', "gamma", gamma, "Gamma, default: 0.002");
159  clp.add_double('e', "epsilon", epsilon, "Epsilon, default: 0.0001");
160 
161  bool generate = false;
162  clp.add_bool('G', "generate", generate,
163  "Generate some random data to train and classify");
164 
165  if (!clp.process(argc, argv)) {
166  return -1;
167  }
168 
169  clp.print_result();
170 
171  return api::Run(
172  [&](api::Context& ctx) {
173  ctx.enable_consume();
174 
175  Element weights;
176 
177  if (generate) {
178  size_t size = common::from_cstr<size_t>(training_path.c_str());
179  weights = TrainLogit(ctx, GenerateInput(ctx, size),
180  max_iterations, gamma, epsilon);
181 
182  TestLogit(ctx, "generated",
183  GenerateTestData(ctx, size / 10), weights);
184  }
185  else {
186  weights = TrainLogit(ctx, ReadInputFile(ctx, training_path),
187  max_iterations, gamma, epsilon);
188 
189  for (const auto& test_file : test_path) {
190  auto data = ReadInputFile(ctx, test_file);
191  TestLogit(ctx, test_file, data, weights);
192  }
193  }
194  });
195 }
196 
197 /******************************************************************************/
int main(int argc, char *argv[])
auto TrainLogit(api::Context &ctx, const InputDIA &input_dia, size_t max_iterations, double gamma, double epsilon)
void add_size_t(char key, const std::string &longkey, const std::string &keytype, size_t &dest, const std::string &desc)
auto Generate(Context &ctx, size_t size, const GenerateFunction &generate_function)
Generate is a Source-DOp, which creates a DIA of given size using a generator function.
Definition: generate.hpp:85
#define die_unless(X)
Definition: die.hpp:52
double T
static auto ReadInputFile(api::Context &ctx, const Input &input_path)
void add_double(char key, const std::string &longkey, const std::string &keytype, double &dest, const std::string &desc)
int Run(const std::function< void(Context &)> &job_startpoint)
Runs the given job startpoint with a Context instance.
Definition: context.cpp:887
static auto GenerateInput(api::Context &ctx, size_t size)
The Context of a job is a unique instance per worker which holds references to all underlying parts o...
Definition: context.hpp:218
void enable_consume(bool consume=true)
Sets consume-mode flag such that DIA contents may be consumed during PushData().
Definition: context.hpp:385
DIA< std::string > ReadLines(Context &ctx, const std::string &filepath)
ReadLines is a DOp, which reads a file from the file system and creates an ordered DIA according to a...
Definition: read_lines.hpp:452
void print_result(std::ostream &os)
print nicely formatted result of processing
int value
Definition: gen_data.py:41
std::basic_string< char, std::char_traits< char >, Allocator< char > > string
string with Manager tracking
Definition: allocator.hpp:220
Command line parser which automatically fills variables and prints nice usage messages.
void add_param_string(const std::string &name, std::string &dest, const std::string &desc)
add string parameter [name] with description and store to dest
static auto GenerateTestData(api::Context &ctx, size_t size)
std::pair< bool, Element > DataObject
void add_param_stringlist(const std::string &name, std::vector< std::string > &dest, const std::string &desc)
constexpr size_t dim
std::array< T, dim > Element
void TestLogit(api::Context &ctx, const std::string &test_file, const InputDIA &input_dia, const Element &weights)
void add_bool(char key, const std::string &longkey, const std::string &keytype, bool &dest, const std::string &desc)
#define LOGM
bool process(int argc, const char *const *argv, std::ostream &os)