syntax = "proto3";

package tensorflow.tensorforest;

import "tensorflow/contrib/decision_trees/proto/generic_tree_model.proto";

option cc_enable_arenas = true;

// Top-level container holding one FertileSlot of statistics per tree node.
message FertileStats {
  // Tracks stats for each node. node_to_slot[i] is the FertileSlot for node i.
  // This may be sized to max_nodes initially, or grow dynamically as needed.
  repeated FertileSlot node_to_slot = 1;
}

// Running sum-of-squares term used to compute Gini impurity for
// classification.
message GiniStats {
  // This allows us to quickly track and calculate impurity (classification)
  // from the sum of input weights and the sum of the squares of the
  // input weights. Weighted gini is then: 1 - (square / (sum * sum)).
  // Updates to these numbers are:
  //   old_i = leaf->value(label)
  //   new_i = old_i + incoming_weight
  //   sum -> sum + incoming_weight
  //   square -> square - (old_i ^ 2) + (new_i ^ 2)
  //   total_left_sum -> total_left_sum - old_left_i * old_total_i +
  //                                      new_left_i * new_total_i
  //
  // Only `square` is stored in this message.
  // NOTE(review): `sum` is presumably LeafStat.weight_sum, and
  // `total_left_sum` has no corresponding field in this file; confirm against
  // the code that maintains these statistics.
  // NOTE(review): field number 1 is unused here. If it belonged to a removed
  // field, declare it `reserved` so it cannot be reused; confirm from history.
  float square = 2;
}

// Statistics accumulated at a leaf, for either classification or regression.
message LeafStat {
  // The sum of the weights of the training examples that we have seen.
  // This is here, outside of the leaf_stat oneof, because almost all
  // types will want it.
  float weight_sum = 3;

  // Per-class counts (dense or sparse) plus the Gini bookkeeping term.
  //
  // TODO(thomaswc): Move the GiniStats out of LeafStats and into something
  // that only tracks them for splits.
  message GiniImpurityClassificationStats {
    // Exactly one representation of the counts is set.
    oneof counts {
      decision_trees.Vector dense_counts = 1;
      decision_trees.SparseVector sparse_counts = 2;
    }

    GiniStats gini = 3;
  }

  // This is the info needed for calculating variance for regression.
  // Variance will still have to be summed over every output, but the
  // number of outputs in regression problems is almost always 1.
  message LeastSquaresRegressionStats {
    decision_trees.Vector mean_output = 1;
    decision_trees.Vector mean_output_squares = 2;
  }

  // The problem-type-specific payload; at most one member is set.
  oneof leaf_stat {
    GiniImpurityClassificationStats classification = 1;
    LeastSquaresRegressionStats regression = 2;
    // TODO(thomaswc): Add in v5's SparseClassStats.
  }
}

// Statistics and split candidates tracked for a single node.
//
// NOTE(review): field numbers 2 and 3 are unused in this message. If they
// belonged to removed fields, declare them `reserved`; confirm from history.
message FertileSlot {
  // The statistics for *all* the examples seen at this leaf.
  LeafStat leaf_stats = 4;

  // The candidate splits being evaluated for this node.
  repeated SplitCandidate candidates = 1;

  // The statistics for the examples seen at this leaf after all the
  // splits have been initialized. If post_init_leaf_stats.weight_sum
  // is > 0, then all candidates have been initialized. We need to track
  // both leaf_stats and post_init_leaf_stats because the first is used
  // to create the decision_tree::Leaf and the second is used to infer
  // the statistics for the right side of a split (given the left side
  // stats).
  LeafStat post_init_leaf_stats = 6;

  int32 node_id = 5;
  int32 depth = 7;
}

// A potential split of a node, with the statistics gathered for it.
//
// NOTE(review): field numbers 2 and 3 are unused in this message. If they
// belonged to removed fields, declare them `reserved`; confirm from history.
message SplitCandidate {
  // proto representing the potential node.
  decision_trees.BinaryNode split = 1;

  // Statistics for the left side of the split.
  // Right counts are inferred from FertileSlot.leaf_stats and left.
  // NOTE(review): the FertileSlot.post_init_leaf_stats comment says that
  // field is what the right side is inferred from; confirm which is used.
  LeafStat left_stats = 4;

  // Right stats (not full counts) are kept here.
  LeafStat right_stats = 5;

  // Fields used when training with a graph runner.
  string unique_id = 6;
}

// Proto used for tracking tree paths during inference time.
message TreePath {
  // Nodes are listed in order that they were traversed. i.e. nodes_visited[0]
  // is the tree's root node.
  repeated decision_trees.TreeNode nodes_visited = 1;
}