/* Copyright (c) by respective owners including Yahoo!, Microsoft, and individual contributors. All rights reserved. Released under a BSD (revised) license as described in the file LICENSE. */ #include #include #include "reductions.h" #include "vw.h" using namespace LEARNER; using namespace VW::config; namespace { const float max_multiplier = 1000.f; const size_t baseline_enabled_idx = 1357; // feature index for enabling baseline } // namespace namespace BASELINE { void set_baseline_enabled(example* ec) { auto& fs = ec->feature_space[message_namespace]; for (auto& f : fs) { if (f.index() == baseline_enabled_idx) { f.value() = 1; return; } } // if not found, push new feature fs.push_back(1, baseline_enabled_idx); } void reset_baseline_disabled(example* ec) { auto& fs = ec->feature_space[message_namespace]; for (auto& f : fs) { if (f.index() == baseline_enabled_idx) { f.value() = 0; return; } } } bool baseline_enabled(example* ec) { auto& fs = ec->feature_space[message_namespace]; for (auto& f : fs) { if (f.index() == baseline_enabled_idx) return f.value() == 1; } return false; } } // namespace BASELINE struct baseline { example* ec; vw* all; bool lr_scaling; // whether to scale baseline learning rate based on max label float lr_multiplier; bool global_only; // only use a global constant for the baseline bool global_initialized; bool check_enabled; // only use baseline when the example contains enabled flag ~baseline() { if (ec) VW::dealloc_example(simple_label.delete_label, *ec); free(ec); } }; void init_global(baseline& data) { if (!data.global_only) return; // use a separate global constant data.ec->indices.push_back(constant_namespace); // different index from constant to avoid conflicts data.ec->feature_space[constant_namespace].push_back( 1, ((constant - 17) * data.all->wpp) << data.all->weights.stride_shift()); data.ec->total_sum_feat_sq++; data.ec->num_features++; } template void predict_or_learn(baseline& data, single_learner& base, example& ec) { // no baseline if check_enabled is true and example contains flag if (data.check_enabled && !BASELINE::baseline_enabled(&ec)) { if (is_learn) base.learn(ec); else base.predict(ec); return; } // always do a full prediction, for safety in accurate predictive validation if (data.global_only) { if (!data.global_initialized) { init_global(data); data.global_initialized = true; } VW::copy_example_metadata(/*audit=*/false, data.ec, &ec); base.predict(*data.ec); ec.l.simple.initial = data.ec->pred.scalar; base.predict(ec); } else base.predict(ec); if (is_learn) { const float pred = ec.pred.scalar; // save 'safe' prediction // now learn data.ec->l.simple = ec.l.simple; if (!data.global_only) { // move label & constant features data over to baseline example VW::copy_example_metadata(/*audit=*/false, data.ec, &ec); VW::move_feature_namespace(data.ec, &ec, constant_namespace); } // regress baseline on label if (data.lr_scaling) { float multiplier = data.lr_multiplier; if (multiplier == 0) { multiplier = std::max(0.0001f, std::max(std::abs(data.all->sd->min_label), std::abs(data.all->sd->max_label))); if (multiplier > max_multiplier) multiplier = max_multiplier; } data.all->eta *= multiplier; base.learn(*data.ec); data.all->eta /= multiplier; } else base.learn(*data.ec); // regress residual ec.l.simple.initial = data.ec->pred.scalar; base.learn(ec); if (!data.global_only) { // move feature data back to the original example VW::move_feature_namespace(&ec, data.ec, constant_namespace); } // return the safe prediction ec.pred.scalar = pred; } } float sensitivity(baseline& data, base_learner& base, example& ec) { // no baseline if check_enabled is true and example contains flag if (data.check_enabled && !BASELINE::baseline_enabled(&ec)) return base.sensitivity(ec); if (!data.global_only) THROW("sensitivity for baseline without --global_only not implemented"); // sensitivity of baseline term VW::copy_example_metadata(/*audit=*/false, data.ec, &ec); data.ec->l.simple.label = ec.l.simple.label; data.ec->pred.scalar = ec.pred.scalar; // std::cout << "before base" << std::endl; const float baseline_sens = base.sensitivity(*data.ec); // std::cout << "base sens: " << baseline_sens << std::endl; // sensitivity of residual as_singleline(&base)->predict(*data.ec); ec.l.simple.initial = data.ec->pred.scalar; const float sens = base.sensitivity(ec); // std::cout << " residual sens: " << sens << std::endl; return baseline_sens + sens; } base_learner* baseline_setup(options_i& options, vw& all) { auto data = scoped_calloc_or_throw(); bool baseline_option = false; std::string loss_function; option_group_definition new_options("Baseline options"); new_options .add(make_option("baseline", baseline_option) .keep() .help("Learn an additive baseline (from constant features) and a residual separately in regression.")) .add(make_option("lr_multiplier", data->lr_multiplier).help("learning rate multiplier for baseline model")) .add(make_option("global_only", data->global_only) .keep() .help("use separate example with only global constant for baseline predictions")) .add(make_option("check_enabled", data->check_enabled) .keep() .help("only use baseline when the example contains enabled flag")); options.add_and_parse(new_options); if (!baseline_option) return nullptr; // initialize baseline example data->ec = VW::alloc_examples(simple_label.label_size, 1); data->ec->interactions = &all.interactions; data->ec->in_use = true; data->all = &all; auto loss_function_type = all.loss->getType(); if (loss_function_type != "logistic") data->lr_scaling = true; auto base = as_singleline(setup_base(options, all)); learner& l = init_learner(data, base, predict_or_learn, predict_or_learn); l.set_sensitivity(sensitivity); return make_base(l); }