/* Copyright (c) by respective owners including Yahoo!, Microsoft, and individual contributors.
   All rights reserved. Released under a BSD (revised) license as described in the file LICENSE. */

#include <cstring>
#include "gd.h"
#include "vw.h"
#include "rand48.h"
#include "reductions.h"
#include <cmath>
#include <memory>

using namespace LEARNER;
using namespace VW::config;

#define NORM2 (m + 1)

struct update_data
{
  struct OjaNewton* ON;
  float g;
  float sketch_cnt;
  float norm2_x;
  float* Zx;
  float* AZx;
  float* delta;
  float bdelta;
  float prediction;
};

struct OjaNewton
{
  vw* all;
  std::shared_ptr<rand_state> _random_state;
  int m;
  int epoch_size;
  float alpha;
  int cnt;
  int t;

  float* ev;
  float* b;
  float* D;
  float** A;
  float** K;

  float* zv;
  float* vv;
  float* tmp;

  example** buffer;
  float* weight_buffer;
  struct update_data data;

  float learning_rate_cnt;
  bool normalize;
  bool random_init;

  void initialize_Z(parameters& weights)
  {
    uint32_t length = 1 << all->num_bits;
    if (normalize)  // initialize normalization part
    {
      for (uint32_t i = 0; i < length; i++) (&(weights.strided_index(i)))[NORM2] = 0.1f;
    }
    if (!random_init)
    {
      // simple initialization
      for (int i = 1; i <= m; i++) (&(weights.strided_index(i)))[i] = 1.f;
    }
    else
    {
      // more complicated initialization: orthogonal basis of a random matrix
      const double PI2 = 2.f * 3.1415927f;
      for (uint32_t i = 0; i < length; i++)
      {
        weight& w = weights.strided_index(i);
        float r1, r2;
        for (int j = 1; j <= m; j++)
        {
          // Box-Muller transform: https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
          // redraw until r1 is strictly positive
          do
          {
            r1 = _random_state->get_and_update_random();
            r2 = _random_state->get_and_update_random();
          } while (r1 == 0.f);

          (&w)[j] = std::sqrt(-2.f * log(r1)) * (float)cos(PI2 * r2);
        }
      }
    }

    // Gram-Schmidt
    for (int j = 1; j <= m; j++)
    {
      for (int k = 1; k <= j - 1; k++)
      {
        double tmp = 0;

        for (uint32_t i = 0; i < length; i++)
          tmp += ((double)(&(weights.strided_index(i)))[j]) * (&(weights.strided_index(i)))[k];
        for (uint32_t i = 0; i < length; i++)
          (&(weights.strided_index(i)))[j] -= (float)tmp * (&(weights.strided_index(i)))[k];
      }
      double norm = 0;
      for (uint32_t i = 0; i < length; i++)
        norm += ((double)(&(weights.strided_index(i)))[j]) * (&(weights.strided_index(i)))[j];
      norm = std::sqrt(norm);
      for (uint32_t i = 0; i < length; i++) (&(weights.strided_index(i)))[j] /= (float)norm;
    }
  }

  void compute_AZx()
  {
    for (int i = 1; i <= m; i++)
    {
      data.AZx[i] = 0;
      for (int j = 1; j <= i; j++)
      {
        data.AZx[i] += A[i][j] * data.Zx[j];
      }
    }
  }

  void update_eigenvalues()
  {
    for (int i = 1; i <= m; i++)
    {
      float gamma = fmin(learning_rate_cnt / t, 1.f);
      float tmp = data.AZx[i] * data.sketch_cnt;

      if (t == 1)
      {
        ev[i] = gamma * tmp * tmp;
      }
      else
      {
        ev[i] = (1 - gamma) * t * ev[i] / (t - 1) + gamma * t * tmp * tmp;
      }
    }
  }

  void compute_delta()
  {
    data.bdelta = 0;
    for (int i = 1; i <= m; i++)
    {
      float gamma = fmin(learning_rate_cnt / t, 1.f);

      // if different learning rates are used
      /*data.delta[i] = gamma * data.AZx[i] * data.sketch_cnt;
      for (int j = 1; j < i; j++)
      {
        data.delta[i] -= A[i][j] * data.delta[j];
      }
      data.delta[i] /= A[i][i];*/

      // if the same learning rate is used
      data.delta[i] = gamma * data.Zx[i] * data.sketch_cnt;

      data.bdelta += data.delta[i] * b[i];
    }
  }

  void update_K()
  {
    float tmp = data.norm2_x * data.sketch_cnt * data.sketch_cnt;
    for (int i = 1; i <= m; i++)
    {
      for (int j = 1; j <= m; j++)
      {
        K[i][j] += data.delta[i] * data.Zx[j] * data.sketch_cnt;
        K[i][j] += data.delta[j] * data.Zx[i] * data.sketch_cnt;
        K[i][j] += data.delta[i] * data.delta[j] * tmp;
      }
    }
  }
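  // A sketch of the math the helpers above implement, written in the code's own variable names
  // (s = sketch_cnt, Zx = projection of the scaled features onto the stored sketch directions,
  //  a = AZx = A * Zx, gamma = min(learning_rate_cnt / t, 1)):
  //
  //   ev[i] <- (1 - gamma) * t / (t - 1) * ev[i] + gamma * t * (a[i] * s)^2              (update_eigenvalues)
  //   delta  = gamma * s * Zx                                                            (compute_delta)
  //   K     <- K + s * (delta * Zx^T + Zx * delta^T) + ||x||^2 * s^2 * delta * delta^T   (update_K)
  //
  // K tracks the Gram matrix of the sketch updates; update_A() below re-orthogonalizes the rows of A
  // against K (roughly, keeping A K A^T close to the identity) so that the sketch stays well conditioned.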
  void update_A()
  {
    for (int i = 1; i <= m; i++)
    {
      for (int j = 1; j < i; j++)
      {
        zv[j] = 0;
        for (int k = 1; k <= i; k++)
        {
          zv[j] += A[i][k] * K[k][j];
        }
      }

      for (int j = 1; j < i; j++)
      {
        vv[j] = 0;
        for (int k = 1; k <= j; k++)
        {
          vv[j] += A[j][k] * zv[k];
        }
      }

      for (int j = 1; j < i; j++)
      {
        for (int k = j; k < i; k++)
        {
          A[i][j] -= vv[k] * A[k][j];
        }
      }

      float norm = 0;
      for (int j = 1; j <= i; j++)
      {
        float temp = 0;
        for (int k = 1; k <= i; k++)
        {
          temp += K[j][k] * A[i][k];
        }
        norm += A[i][j] * temp;
      }
      norm = sqrtf(norm);

      for (int j = 1; j <= i; j++)
      {
        A[i][j] /= norm;
      }
    }
  }

  void update_b()
  {
    for (int j = 1; j <= m; j++)
    {
      float tmp = 0;
      for (int i = j; i <= m; i++)
      {
        tmp += ev[i] * data.AZx[i] * A[i][j] / (alpha * (alpha + ev[i]));
      }
      b[j] += tmp * data.g;
    }
  }

  void update_D()
  {
    for (int j = 1; j <= m; j++)
    {
      float scale = fabs(A[j][j]);
      for (int i = j + 1; i <= m; i++) scale = fmin(fabs(A[i][j]), scale);
      if (scale < 1e-10) continue;
      for (int i = 1; i <= m; i++)
      {
        A[i][j] /= scale;
        K[j][i] *= scale;
        K[i][j] *= scale;
      }
      b[j] /= scale;
      D[j] *= scale;
      // printf("D[%d] = %f\n", j, D[j]);
    }
  }
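  // Note on the implicit representation: the sketch Z is stored inside the weight vector and is only
  // read through a lower-triangular transform A and a diagonal scaling D, so the effective sketch is
  // A * D * Z and most updates only touch the small m x m state.  check() below watches the magnitude
  // of K and, once it grows too large, folds everything back into an explicit form:
  //   K <- A K A^T,   w0 <- w0 + (D Z)^T b,   b <- 0,   Z <- A D Z,   A <- I,   D <- I.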
  void check()
  {
    double max_norm = 0;
    for (int i = 1; i <= m; i++)
      for (int j = i; j <= m; j++) max_norm = fmax(max_norm, fabs(K[i][j]));
    // printf("|K| = %f\n", max_norm);
    if (max_norm < 1e7) return;

    // implicit -> explicit representation
    // printf("begin conversion: t = %d, norm(K) = %f\n", t, max_norm);

    // first step: K <- AKA'
    // K <- AK
    for (int j = 1; j <= m; j++)
    {
      memset(tmp, 0, sizeof(float) * (m + 1));

      for (int i = 1; i <= m; i++)
      {
        for (int h = 1; h <= m; h++)
        {
          tmp[i] += A[i][h] * K[h][j];
        }
      }

      for (int i = 1; i <= m; i++) K[i][j] = tmp[i];
    }
    // K <- KA'
    for (int i = 1; i <= m; i++)
    {
      memset(tmp, 0, sizeof(float) * (m + 1));

      for (int j = 1; j <= m; j++)
        for (int h = 1; h <= m; h++) tmp[j] += K[i][h] * A[j][h];

      for (int j = 1; j <= m; j++)
      {
        K[i][j] = tmp[j];
      }
    }

    // second step: w[0] <- w[0] + (DZ)'b, b <- 0.
    uint32_t length = 1 << all->num_bits;
    for (uint32_t i = 0; i < length; i++)
    {
      weight& w = all->weights.strided_index(i);
      for (int j = 1; j <= m; j++) w += (&w)[j] * b[j] * D[j];
    }

    memset(b, 0, sizeof(float) * (m + 1));

    // third step: Z <- ADZ, A, D <- Identity
    // double norm = 0;
    for (uint32_t i = 0; i < length; ++i)
    {
      memset(tmp, 0, sizeof(float) * (m + 1));
      weight& w = all->weights.strided_index(i);
      for (int j = 1; j <= m; j++)
      {
        for (int h = 1; h <= m; ++h) tmp[j] += A[j][h] * D[h] * (&w)[h];
      }
      for (int j = 1; j <= m; ++j)
      {
        // norm = std::max(norm, fabs(tmp[j]));
        (&w)[j] = tmp[j];
      }
    }
    // printf("|Z| = %f\n", norm);

    for (int i = 1; i <= m; i++)
    {
      memset(A[i], 0, sizeof(float) * (m + 1));
      D[i] = 1;
      A[i][i] = 1;
    }
  }

  ~OjaNewton()
  {
    free(ev);
    free(b);
    free(D);
    free(buffer);
    free(weight_buffer);
    free(zv);
    free(vv);
    free(tmp);
    if (A)
    {
      for (int i = 1; i <= m; i++)
      {
        free(A[i]);
        free(K[i]);
      }
    }
    free(A);
    free(K);
    free(data.Zx);
    free(data.AZx);
    free(data.delta);
  }
};

void keep_example(vw& all, OjaNewton& /* ON */, example& ec) { output_and_account_example(all, ec); }

void make_pred(update_data& data, float x, float& wref)
{
  int m = data.ON->m;
  float* w = &wref;
  if (data.ON->normalize)
  {
    x /= std::sqrt(w[NORM2]);
  }

  data.prediction += w[0] * x;
  for (int i = 1; i <= m; i++)
  {
    data.prediction += w[i] * x * data.ON->D[i] * data.ON->b[i];
  }
}

void predict(OjaNewton& ON, base_learner&, example& ec)
{
  ON.data.prediction = 0;
  GD::foreach_feature<update_data, make_pred>(*ON.all, ec, ON.data);
  ec.partial_prediction = (float)ON.data.prediction;
  ec.pred.scalar = GD::finalize_prediction(ON.all->sd, ec.partial_prediction);
}

void update_Z_and_wbar(update_data& data, float x, float& wref)
{
  float* w = &wref;
  int m = data.ON->m;
  if (data.ON->normalize) x /= std::sqrt(w[NORM2]);
  float s = data.sketch_cnt * x;

  for (int i = 1; i <= m; i++)
  {
    w[i] += data.delta[i] * s / data.ON->D[i];
  }
  w[0] -= s * data.bdelta;
}

void compute_Zx_and_norm(update_data& data, float x, float& wref)
{
  float* w = &wref;
  int m = data.ON->m;
  if (data.ON->normalize) x /= std::sqrt(w[NORM2]);

  for (int i = 1; i <= m; i++)
  {
    data.Zx[i] += w[i] * x * data.ON->D[i];
  }
  data.norm2_x += x * x;
}

void update_wbar_and_Zx(update_data& data, float x, float& wref)
{
  float* w = &wref;
  int m = data.ON->m;
  if (data.ON->normalize) x /= std::sqrt(w[NORM2]);

  float g = data.g * x;

  for (int i = 1; i <= m; i++)
  {
    data.Zx[i] += w[i] * x * data.ON->D[i];
  }
  w[0] -= g / data.ON->alpha;
}

void update_normalization(update_data& data, float x, float& wref)
{
  float* w = &wref;
  int m = data.ON->m;

  w[NORM2] += x * x * data.g * data.g;
}
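// Rough flow of learn() below: predict, take the loss derivative, and buffer the example; once
// epoch_size examples are buffered, replay them to refresh the eigenvalue estimates, delta, K and
// finally A; then update b for the current example and let check() decide whether to convert the
// implicit representation back into explicit weights.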
void learn(OjaNewton& ON, base_learner& base, example& ec)
{
  assert(ec.in_use);

  // predict
  predict(ON, base, ec);

  update_data& data = ON.data;
  data.g = ON.all->loss->first_derivative(ON.all->sd, ec.pred.scalar, ec.l.simple.label) * ec.l.simple.weight;
  data.g /= 2;  // for half square loss

  if (ON.normalize) GD::foreach_feature<update_data, update_normalization>(*ON.all, ec, data);

  ON.buffer[ON.cnt] = &ec;
  ON.weight_buffer[ON.cnt++] = data.g / 2;

  if (ON.cnt == ON.epoch_size)
  {
    for (int k = 0; k < ON.epoch_size; k++, ON.t++)
    {
      example& ex = *(ON.buffer[k]);
      data.sketch_cnt = ON.weight_buffer[k];

      data.norm2_x = 0;
      memset(data.Zx, 0, sizeof(float) * (ON.m + 1));
      GD::foreach_feature<update_data, compute_Zx_and_norm>(*ON.all, ex, data);
      ON.compute_AZx();

      ON.update_eigenvalues();
      ON.compute_delta();

      ON.update_K();

      GD::foreach_feature<update_data, update_Z_and_wbar>(*ON.all, ex, data);
    }

    ON.update_A();
    // ON.update_D();
  }

  memset(data.Zx, 0, sizeof(float) * (ON.m + 1));
  GD::foreach_feature<update_data, update_wbar_and_Zx>(*ON.all, ec, data);
  ON.compute_AZx();

  ON.update_b();
  ON.check();

  if (ON.cnt == ON.epoch_size)
  {
    ON.cnt = 0;
    for (int k = 0; k < ON.epoch_size; k++)
    {
      VW::finish_example(*ON.all, *ON.buffer[k]);
    }
  }
}

void save_load(OjaNewton& ON, io_buf& model_file, bool read, bool text)
{
  vw& all = *ON.all;
  if (read)
  {
    initialize_regressor(all);
    ON.initialize_Z(all.weights);
  }

  if (model_file.files.size() > 0)
  {
    bool resume = all.save_resume;
    std::stringstream msg;
    msg << ":" << resume << "\n";
    bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, msg, text);

    double temp = 0.;
    if (resume)
      GD::save_load_online_state(all, model_file, read, text, temp);
    else
      GD::save_load_regressor(all, model_file, read, text);
  }
}

base_learner* OjaNewton_setup(options_i& options, vw& all)
{
  auto ON = scoped_calloc_or_throw<OjaNewton>();

  bool oja_newton;
  float alpha_inverse;

  // These are the only two boolean options that default to true. For now they are kept as string
  // options, because the options infrastructure doesn't easily support a default-true bool while
  // still providing the usual switch behavior, and the switch behavior matters more here because
  // of the positional data argument.
  std::string normalize = "true";
  std::string random_init = "true";

  option_group_definition new_options("OjaNewton options");
  new_options.add(make_option("OjaNewton", oja_newton).keep().help("Online Newton with Oja's Sketch"))
      .add(make_option("sketch_size", ON->m).default_value(10).help("size of sketch"))
      .add(make_option("epoch_size", ON->epoch_size).default_value(1).help("size of epoch"))
      .add(make_option("alpha", ON->alpha).default_value(1.f).help("multiplicative constant for identity"))
      .add(make_option("alpha_inverse", alpha_inverse).help("one over alpha, similar to learning rate"))
      .add(make_option("learning_rate_cnt", ON->learning_rate_cnt)
               .default_value(2.f)
               .help("constant for the learning rate 1/t"))
      .add(make_option("normalize", normalize).help("normalize the features or not"))
      .add(make_option("random_init", random_init).help("randomize initialization of Oja or not"));
  options.add_and_parse(new_options);

  if (!options.was_supplied("OjaNewton")) return nullptr;

  ON->all = &all;
  ON->_random_state = all.get_random_state();

  ON->normalize = normalize == "true";
  ON->random_init = random_init == "true";

  if (options.was_supplied("alpha_inverse")) ON->alpha = 1.f / alpha_inverse;

  ON->cnt = 0;
  ON->t = 1;
  ON->ev = calloc_or_throw<float>(ON->m + 1);
  ON->b = calloc_or_throw<float>(ON->m + 1);
  ON->D = calloc_or_throw<float>(ON->m + 1);
  ON->A = calloc_or_throw<float*>(ON->m + 1);
  ON->K = calloc_or_throw<float*>(ON->m + 1);
  for (int i = 1; i <= ON->m; i++)
  {
    ON->A[i] = calloc_or_throw<float>(ON->m + 1);
    ON->K[i] = calloc_or_throw<float>(ON->m + 1);
    ON->A[i][i] = 1;
    ON->K[i][i] = 1;
    ON->D[i] = 1;
  }

  ON->buffer = calloc_or_throw<example*>(ON->epoch_size);
  ON->weight_buffer = calloc_or_throw<float>(ON->epoch_size);

  ON->zv = calloc_or_throw<float>(ON->m + 1);
  ON->vv = calloc_or_throw<float>(ON->m + 1);
  ON->tmp = calloc_or_throw<float>(ON->m + 1);

  ON->data.ON = ON.get();
  ON->data.Zx = calloc_or_throw<float>(ON->m + 1);
  ON->data.AZx = calloc_or_throw<float>(ON->m + 1);
  ON->data.delta = calloc_or_throw<float>(ON->m + 1);

  all.weights.stride_shift((uint32_t)ceil(log2(ON->m + 2)));

  learner<OjaNewton, example>& l = init_learner(ON, learn, predict, all.weights.stride());
  l.set_save_load(save_load);
  l.set_finish_example(keep_example);
  return make_base(l);
}
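// Illustrative invocation (the flags mirror the options registered in OjaNewton_setup, shown with
// their default values; "train.dat" is a placeholder dataset path):
//
//   vw --OjaNewton --sketch_size 10 --epoch_size 1 --alpha 1 --learning_rate_cnt 2 \
//      --normalize true --random_init true train.dat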