33 constexpr
size_t dim = 3;
39 #define LOGM LOGC(debug && ctx.my_rank() == 0) 41 template <
typename Input>
49 obj.first = common::from_cstr<T>(line.c_str(), &endptr);
51 "Could not parse input line");
53 for (
size_t i = 0; i <
dim; ++i) {
54 T value = common::from_cstr<T>(endptr + 1, &endptr);
56 ((i + 1 <= dim && *endptr ==
',') ||
57 (i + 1 == dim && *endptr == 0)) &&
58 "Could not parse input line");
59 obj.second[i] =
value;
68 std::default_random_engine rng(std::random_device { } ());
69 std::normal_distribution<double> norm_dist(0.0, 1.0);
70 std::lognormal_distribution<double> lognorm_dist(0.0, 1.0);
75 bool c = (2 * index < size);
79 p[0] = index * 0.1 + size / 100.0 * norm_dist(rng);
80 p[1] = index * index * 0.1 + size / 100.0 * norm_dist(rng);
81 p[2] = size - index * 0.1 + size / 100.0 * norm_dist(rng);
91 std::default_random_engine rng(std::random_device { } ());
92 std::normal_distribution<double> norm_dist(0.0, 1.0);
93 std::lognormal_distribution<double> lognorm_dist(0.0, 1.0);
97 [size](
size_t index) {
98 bool c = (2 * index < size);
103 p[1] = index * index * 0.1;
104 p[2] = size - index * 0.1;
110 template <
typename InputDIA>
112 const InputDIA& input_dia,
113 size_t max_iterations,
double gamma,
double epsilon) {
118 std::tie(weights, norm, iterations) =
119 logit_train<T, dim>(input_dia.Keep(), max_iterations, gamma, epsilon);
121 LOGM <<
"Iterations: " << iterations;
122 LOGM <<
"Norm: " << norm;
123 LOGM <<
"Final weights (model):";
124 for (
size_t i = 0; i <
dim; ++i) {
125 LOGM <<
"Model[" << i <<
"] = " << weights[i];
130 template <
typename InputDIA>
132 const InputDIA& input_dia,
const Element& weights) {
133 size_t num_trues, true_trues, num_falses, true_falses;
134 std::tie(num_trues, true_trues, num_falses, true_falses)
135 = logit_test<T, dim>(input_dia, weights);
136 LOGM <<
"Evaluation result for " << test_file <<
":";
137 LOGM <<
"\tTrue: " << true_trues <<
" of " << num_trues <<
" correct, " 138 << num_trues - true_trues <<
" incorrect, " 139 <<
static_cast<double>(true_trues) / num_trues * 100.0 <<
"% matched";
140 LOGM <<
"\tFalse: " << true_falses <<
" of " << num_falses
141 <<
" correct, " << num_falses - true_falses <<
" incorrect, " 142 <<
static_cast<double>(true_falses) / num_falses * 100.0 <<
"% matched";
145 int main(
int argc,
char* argv[]) {
149 std::vector<std::string> test_path;
153 size_t max_iterations = 1000;
154 clp.
add_size_t(
'n',
"iterations", max_iterations,
155 "Maximum number of iterations, default: 1000");
157 double gamma = 0.002, epsilon = 0.0001;
158 clp.
add_double(
'g',
"gamma", gamma,
"Gamma, default: 0.002");
159 clp.
add_double(
'e',
"epsilon", epsilon,
"Epsilon, default: 0.0001");
161 bool generate =
false;
162 clp.
add_bool(
'G',
"generate", generate,
163 "Generate some random data to train and classify");
165 if (!clp.
process(argc, argv)) {
178 size_t size = common::from_cstr<size_t>(training_path.c_str());
180 max_iterations, gamma, epsilon);
187 max_iterations, gamma, epsilon);
189 for (
const auto& test_file : test_path) {
191 TestLogit(ctx, test_file, data, weights);
int main(int argc, char *argv[])
auto TrainLogit(api::Context &ctx, const InputDIA &input_dia, size_t max_iterations, double gamma, double epsilon)
auto Generate(Context &ctx, size_t size, const GenerateFunction &generate_function)
Generate is a Source-DOp, which creates a DIA of given size using a generator function.
static auto ReadInputFile(api::Context &ctx, const Input &input_path)
int Run(const std::function< void(Context &)> &job_startpoint)
Runs the given job startpoint with a Context instance.
void add_size_t(char key, const std::string &longkey, size_t &dest, const std::string &desc)
add size_t option -key, --longkey with description and store to dest
static auto GenerateInput(api::Context &ctx, size_t size)
The Context of a job is a unique instance per worker which holds references to all underlying parts o...
void enable_consume(bool consume=true)
Sets consume-mode flag such that DIA contents may be consumed during PushData().
DIA< std::string > ReadLines(Context &ctx, const std::string &filepath)
ReadLines is a DOp, which reads a file from the file system and creates an ordered DIA according to a...
void print_result(std::ostream &os)
print nicely formatted result of processing
std::basic_string< char, std::char_traits< char >, Allocator< char > > string
string with Manager tracking
Command line parser which automatically fills variables and prints nice usage messages.
void add_param_string(const std::string &name, std::string &dest, const std::string &desc)
add string parameter [name] with description and store to dest
static auto GenerateTestData(api::Context &ctx, size_t size)
void add_double(char key, const std::string &longkey, double &dest, const std::string &desc)
add double option -key, --longkey with description and store to dest
std::pair< bool, Element > DataObject
void add_param_stringlist(const std::string &name, std::vector< std::string > &dest, const std::string &desc)
void add_bool(char key, const std::string &longkey, bool &dest, const std::string &desc)
std::array< T, dim > Element
void TestLogit(api::Context &ctx, const std::string &test_file, const InputDIA &input_dia, const Element &weights)
bool process(int argc, const char *const *argv, std::ostream &os)