// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
#define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_

#include <array>
#include <vector>

#include "lib/jxl/fields.h"
#include "lib/jxl/modular/modular_image.h"
#include "lib/jxl/modular/options.h"

namespace jxl {

namespace weighted {
constexpr static size_t kNumPredictors = 4;
constexpr static int64_t kPredExtraBits = 3;
constexpr static int64_t kPredictionRound = ((1 << kPredExtraBits) >> 1) - 1;
constexpr static size_t kNumProperties = 1;

struct Header : public Fields {
  JXL_FIELDS_NAME(WeightedPredictorHeader)
  // TODO(janwas): move to cc file, avoid including fields.h.
  Header() { Bundle::Init(this); }

  Status VisitFields(Visitor *JXL_RESTRICT visitor) override {
    if (visitor->AllDefault(*this, &all_default)) {
      // Overwrite all serialized fields, but not any nonserialized_*.
      visitor->SetDefault(this);
      return true;
    }
    auto visit_p = [visitor](pixel_type val, pixel_type *p) {
      uint32_t up = *p;
      JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up));
      *p = up;
      return Status(true);
    };
    JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3]));
    return true;
  }

  bool all_default;
  pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0,
             p3Ce = 0;
  uint32_t w[kNumPredictors] = {};
};

struct State {
  pixel_type_w prediction[kNumPredictors] = {};
  pixel_type_w pred = 0;  // *before* removing the added bits.
  std::vector<uint32_t> pred_errors[kNumPredictors];
  std::vector<int32_t> error;
  Header header;

  // Allows approximating division by a number from 1 to 64.
  uint32_t divlookup[64];

  constexpr static pixel_type_w AddBits(pixel_type_w x) {
    return uint64_t(x) << kPredExtraBits;
  }

  State(Header header, size_t xsize, size_t ysize) : header(header) {
    // Extra margin to avoid out-of-bounds writes.
    // All have space for two rows of data.
    for (size_t i = 0; i < 4; i++) {
      pred_errors[i].resize((xsize + 2) * 2);
    }
    error.resize((xsize + 2) * 2);
    // Initialize division lookup table.
    for (int i = 0; i < 64; i++) {
      divlookup[i] = (1 << 24) / (i + 1);
    }
  }

  // Approximates 4+(maxweight<<24)/(x+1), avoiding division.
  JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const {
    int shift = static_cast<int>(FloorLog2Nonzero(x + 1)) - 5;
    if (shift < 0) shift = 0;
    return 4 + ((maxweight * divlookup[x >> shift]) >> shift);
  }
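
  // Note (added, illustrative only): divlookup[i] holds (1 << 24) / (i + 1),
  // so dividing by (x + 1) becomes a table lookup, a multiply and a shift.
  // The table only covers divisors 1..64, so x is first scaled down by
  // `shift`. Example with x = 100, maxweight = 13:
  //   shift = FloorLog2Nonzero(101) - 5 = 1
  //   divlookup[100 >> 1] = (1 << 24) / 51 = 328965
  //   4 + ((13 * 328965) >> 1) = 2138276
  // versus the exact 4 + (13 << 24) / 101 = 2159447, i.e. about 1% low.
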
  // Approximates the weighted average of the input values with the given
  // weights, avoiding division. Weights must sum to at least 16.
  JXL_INLINE pixel_type_w WeightedAverage(
      const pixel_type_w *JXL_RESTRICT p,
      std::array<uint32_t, kNumPredictors> w) const {
    uint32_t weight_sum = 0;
    for (size_t i = 0; i < kNumPredictors; i++) {
      weight_sum += w[i];
    }
    JXL_DASSERT(weight_sum > 15);
    uint32_t log_weight = FloorLog2Nonzero(weight_sum);  // at least 4.
    weight_sum = 0;
    for (size_t i = 0; i < kNumPredictors; i++) {
      w[i] >>= log_weight - 4;
      weight_sum += w[i];
    }
    // for rounding.
    pixel_type_w sum = (weight_sum >> 1) - 1;
    for (size_t i = 0; i < kNumPredictors; i++) {
      sum += p[i] * w[i];
    }
    return (sum * divlookup[weight_sum - 1]) >> 24;
  }

  template <bool compute_properties>
  JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize,
                                  pixel_type_w N, pixel_type_w W,
                                  pixel_type_w NE, pixel_type_w NW,
                                  pixel_type_w NN, Properties *properties,
                                  size_t offset) {
    size_t cur_row = y & 1 ? 0 : (xsize + 2);
    size_t prev_row = y & 1 ? (xsize + 2) : 0;
    size_t pos_N = prev_row + x;
    size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N;
    size_t pos_NW = x > 0 ? pos_N - 1 : pos_N;
    std::array<uint32_t, kNumPredictors> weights;
    for (size_t i = 0; i < kNumPredictors; i++) {
      // pred_errors[pos_N] also contains the error of pixel W.
      // pred_errors[pos_NW] also contains the error of pixel WW.
      weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] +
                   pred_errors[i][pos_NW];
      weights[i] = ErrorWeight(weights[i], header.w[i]);
    }

    N = AddBits(N);
    W = AddBits(W);
    NE = AddBits(NE);
    NW = AddBits(NW);
    NN = AddBits(NN);

    pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1];
    pixel_type_w teN = error[pos_N];
    pixel_type_w teNW = error[pos_NW];
    pixel_type_w sumWN = teN + teW;
    pixel_type_w teNE = error[pos_NE];

    if (compute_properties) {
      pixel_type_w p = teW;
      if (std::abs(teN) > std::abs(p)) p = teN;
      if (std::abs(teNW) > std::abs(p)) p = teNW;
      if (std::abs(teNE) > std::abs(p)) p = teNE;
      (*properties)[offset++] = p;
    }

    prediction[0] = W + NE - N;
    prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5);
    prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5);
    prediction[3] =
        N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc +
              (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >>
             5);

    pred = WeightedAverage(prediction, weights);

    // If all three have the same sign, skip clamping.
    if (((teN ^ teW) | (teN ^ teNW)) > 0) {
      return (pred + kPredictionRound) >> kPredExtraBits;
    }

    // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N).
    pixel_type_w mx = std::max(W, std::max(NE, N));
    pixel_type_w mn = std::min(W, std::min(NE, N));
    pred = std::max(mn, std::min(mx, pred));
    return (pred + kPredictionRound) >> kPredExtraBits;
  }

  JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y,
                               size_t xsize) {
    size_t cur_row = y & 1 ? 0 : (xsize + 2);
    size_t prev_row = y & 1 ? (xsize + 2) : 0;
    val = AddBits(val);
    error[cur_row + x] = pred - val;
    for (size_t i = 0; i < kNumPredictors; i++) {
      pixel_type_w err =
          (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits;
      // For predicting in the next row.
      pred_errors[i][cur_row + x] = err;
      // Add the error on this pixel to the error on the NE pixel. This has the
      // effect of adding the error on this pixel to the E and EE pixels.
      pred_errors[i][prev_row + x + 1] += err;
    }
  }
};
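
// Rough usage sketch (added for clarity, not part of the original header):
// the weighted predictor is driven one pixel at a time, in scan order.
// `wp_header`, `row`, `row_above` and `residual` below are hypothetical
// caller-side names; edge handling mirrors detail::Predict further down.
//
//   weighted::State wp_state(wp_header, xsize, ysize);
//   for (size_t y = 0; y < ysize; y++) {
//     for (size_t x = 0; x < xsize; x++) {
//       pixel_type_w W = x ? row[x - 1] : (y ? row_above[x] : 0);
//       pixel_type_w N = y ? row_above[x] : W;
//       pixel_type_w NW = (x && y) ? row_above[x - 1] : W;
//       pixel_type_w NE = (x + 1 < xsize && y) ? row_above[x + 1] : N;
//       pixel_type_w NN = N;  // simplification: ignores the row above N
//       pixel_type_w guess = wp_state.Predict</*compute_properties=*/false>(
//           x, y, xsize, N, W, NE, NW, NN, /*properties=*/nullptr, 0);
//       row[x] = guess + residual[x];
//       wp_state.UpdateErrors(row[x], x, y, xsize);
//     }
//   }
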
// Encoder helper function to set the parameters to some presets.
inline void PredictorMode(int i, Header *header) {
  switch (i) {
    case 0:
      // ~ lossless16 predictor
      header->w[0] = 0xd;
      header->w[1] = 0xc;
      header->w[2] = 0xc;
      header->w[3] = 0xc;
      header->p1C = 16;
      header->p2C = 10;
      header->p3Ca = 7;
      header->p3Cb = 7;
      header->p3Cc = 7;
      header->p3Cd = 0;
      header->p3Ce = 0;
      break;
    case 1:
      // ~ default lossless8 predictor
      header->w[0] = 0xd;
      header->w[1] = 0xc;
      header->w[2] = 0xc;
      header->w[3] = 0xb;
      header->p1C = 8;
      header->p2C = 8;
      header->p3Ca = 4;
      header->p3Cb = 0;
      header->p3Cc = 3;
      header->p3Cd = 23;
      header->p3Ce = 2;
      break;
    case 2:
      // ~ west lossless8 predictor
      header->w[0] = 0xd;
      header->w[1] = 0xc;
      header->w[2] = 0xd;
      header->w[3] = 0xc;
      header->p1C = 10;
      header->p2C = 9;
      header->p3Ca = 7;
      header->p3Cb = 0;
      header->p3Cc = 0;
      header->p3Cd = 16;
      header->p3Ce = 9;
      break;
    case 3:
      // ~ north lossless8 predictor
      header->w[0] = 0xd;
      header->w[1] = 0xd;
      header->w[2] = 0xc;
      header->w[3] = 0xc;
      header->p1C = 16;
      header->p2C = 8;
      header->p3Ca = 0;
      header->p3Cb = 16;
      header->p3Cc = 0;
      header->p3Cd = 23;
      header->p3Ce = 0;
      break;
    case 4:
    default:
      // something else, because why not
      header->w[0] = 0xd;
      header->w[1] = 0xc;
      header->w[2] = 0xc;
      header->w[3] = 0xc;
      header->p1C = 10;
      header->p2C = 10;
      header->p3Ca = 5;
      header->p3Cb = 5;
      header->p3Cc = 5;
      header->p3Cd = 12;
      header->p3Ce = 4;
      break;
  }
}
}  // namespace weighted

// Stores a node and its two children at the same time. This significantly
// reduces the number of branches needed during decoding.
struct FlatDecisionNode {
  // Property + splitval of the top node.
  int32_t property0;  // -1 if leaf.
  union {
    PropertyVal splitval0;
    Predictor predictor;
  };
  uint32_t childID;  // childID is ctx id if leaf.
  // Property+splitval of the two child nodes.
  union {
    PropertyVal splitvals[2];
    int32_t multiplier;
  };
  union {
    int32_t properties[2];
    int64_t predictor_offset;
  };
};
using FlatTree = std::vector<FlatDecisionNode>;

class MATreeLookup {
 public:
  explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {}
  struct LookupResult {
    uint32_t context;
    Predictor predictor;
    int64_t offset;
    int32_t multiplier;
  };
  JXL_INLINE LookupResult Lookup(const Properties &properties) const {
    uint32_t pos = 0;
    while (true) {
      const FlatDecisionNode &node = nodes_[pos];
      if (node.property0 < 0) {
        return {node.childID, node.predictor, node.predictor_offset,
                node.multiplier};
      }
      bool p0 = properties[node.property0] <= node.splitval0;
      uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0];
      uint32_t off1 =
          2 | (properties[node.properties[1]] <= node.splitvals[1] ? 1 : 0);
      pos = node.childID + (p0 ? off1 : off0);
    }
  }

 private:
  const FlatTree &nodes_;
};
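
// Note (added, not in the original header): each FlatDecisionNode stores an
// inner node together with the split data of both of its children, so one
// iteration of Lookup() descends two tree levels with a single unpredictable
// branch. The four possible next positions are stored contiguously starting
// at childID, and one step computes
//   p0   = properties[property0]     <= splitval0
//   cmp0 = properties[properties[0]] <= splitvals[0]
//   cmp1 = properties[properties[1]] <= splitvals[1]
//   pos  = childID + (p0 ? (2 | cmp1) : cmp0)
// A leaf is signalled by property0 < 0 and reuses the unions to store
// (childID = context, predictor, predictor_offset, multiplier).
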
static constexpr size_t kExtraPropsPerChannel = 4;
static constexpr size_t kNumNonrefProperties =
    kNumStaticProperties + 13 + weighted::kNumProperties;

constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties;
constexpr size_t kGradientProp = 9;

// Clamps gradient to the min/max of n, w (and l, implicitly).
static JXL_INLINE int32_t ClampedGradient(const int32_t n, const int32_t w,
                                          const int32_t l) {
  const int32_t m = std::min(n, w);
  const int32_t M = std::max(n, w);
  // The end result of this operation doesn't overflow or underflow if the
  // result is between m and M, but the intermediate value may overflow, so we
  // do the intermediate operations in uint32_t and check later if we had an
  // overflow or underflow condition comparing m, M and l directly.
  // grad = M + m - l = n + w - l
  const int32_t grad = static_cast<int32_t>(static_cast<uint32_t>(n) +
                                            static_cast<uint32_t>(w) -
                                            static_cast<uint32_t>(l));
  // We use two sets of ternary operators to force their evaluation in any
  // case, allowing the compiler to avoid branches and use cmovl/cmovg in x86.
  const int32_t grad_clamp_M = (l < m) ? M : grad;
  return (l > M) ? m : grad_clamp_M;
}

inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) {
  pixel_type_w p = a + b - c;
  pixel_type_w pa = std::abs(p - a);
  pixel_type_w pb = std::abs(p - b);
  return pa < pb ? a : b;
}

inline void PrecomputeReferences(const Channel &ch, size_t y,
                                 const Image &image, uint32_t i,
                                 Channel *references) {
  ZeroFillImage(&references->plane);
  uint32_t offset = 0;
  size_t num_extra_props = references->w;
  intptr_t onerow = references->plane.PixelsPerRow();
  for (int32_t j = static_cast<int32_t>(i) - 1;
       j >= 0 && offset < num_extra_props; j--) {
    if (image.channel[j].w != image.channel[i].w ||
        image.channel[j].h != image.channel[i].h) {
      continue;
    }
    if (image.channel[j].hshift != image.channel[i].hshift) continue;
    if (image.channel[j].vshift != image.channel[i].vshift) continue;
    pixel_type *JXL_RESTRICT rp = references->Row(0) + offset;
    const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y);
    const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0);
    for (size_t x = 0; x < ch.w; x++, rp += onerow) {
      pixel_type_w v = rpp[x];
      rp[0] = std::abs(v);
      rp[1] = v;
      pixel_type_w vleft = (x ? rpp[x - 1] : 0);
      pixel_type_w vtop = (y ? rpprev[x] : vleft);
      pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft);
      pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft);
      rp[2] = std::abs(v - vpredicted);
      rp[3] = v - vpredicted;
    }
    offset += kExtraPropsPerChannel;
  }
}

struct PredictionResult {
  int context = 0;
  pixel_type_w guess = 0;
  Predictor predictor;
  int32_t multiplier;
};

inline void InitPropsRow(
    Properties *p,
    const std::array<pixel_type, kNumStaticProperties> &static_props,
    const int y) {
  for (size_t i = 0; i < kNumStaticProperties; i++) {
    (*p)[i] = static_props[i];
  }
  (*p)[2] = y;
  (*p)[9] = 0;  // local gradient.
}
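
// Note (added for clarity): with the 2 static properties mentioned below in
// detail::Predict ("2 static properties + y"), the property vector filled by
// InitPropsRow() and detail::Predict() is laid out as
//   [0..1]  static properties
//   [2]     y              [3]  x
//   [4]     |N|            [5]  |W|
//   [6]     N              [7]  W
//   [8]     W minus the previous pixel's entry [9] ("local gradient")
//   [9]     W + N - NW     (kGradientProp)
//   [10..14] FFV1-style differences: W-NW, NW-N, N-NE, N-NN, W-WW
//   [15]    weighted-predictor property, the neighbour prediction error with
//           the largest magnitude (kWPProp)
//   [16..]  kExtraPropsPerChannel entries per eligible previous channel
//           (see PrecomputeReferences above).
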
namespace detail {
enum PredictorMode {
  kUseTree = 1,
  kUseWP = 2,
  kForceComputeProperties = 4,
  kAllPredictions = 8,
  kNoEdgeCases = 16
};
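
// Note (added for clarity): these flags are OR-ed together to select a
// detail::Predict<mode> instantiation; the public wrappers below pick the
// combinations, e.g. PredictTreeWP uses kUseTree | kUseWP and PredictLearnAll
// uses kForceComputeProperties | kUseWP | kAllPredictions. kNoEdgeCases skips
// the x/y boundary handling and must only be used when all neighbours exist
// (see PredictTreeNoWPNEC).
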
JXL_INLINE pixel_type_w PredictOne(Predictor p, pixel_type_w left,
                                   pixel_type_w top, pixel_type_w toptop,
                                   pixel_type_w topleft, pixel_type_w topright,
                                   pixel_type_w leftleft,
                                   pixel_type_w toprightright,
                                   pixel_type_w wp_pred) {
  switch (p) {
    case Predictor::Zero:
      return pixel_type_w{0};
    case Predictor::Left:
      return left;
    case Predictor::Top:
      return top;
    case Predictor::Select:
      return Select(left, top, topleft);
    case Predictor::Weighted:
      return wp_pred;
    case Predictor::Gradient:
      return pixel_type_w{ClampedGradient(left, top, topleft)};
    case Predictor::TopLeft:
      return topleft;
    case Predictor::TopRight:
      return topright;
    case Predictor::LeftLeft:
      return leftleft;
    case Predictor::Average0:
      return (left + top) / 2;
    case Predictor::Average1:
      return (left + topleft) / 2;
    case Predictor::Average2:
      return (topleft + top) / 2;
    case Predictor::Average3:
      return (top + topright) / 2;
    case Predictor::Average4:
      return (6 * top - 2 * toptop + 7 * left + 1 * leftleft +
              1 * toprightright + 3 * topright + 8) /
             16;
    default:
      return pixel_type_w{0};
  }
}

template <int mode>
JXL_INLINE PredictionResult Predict(
    Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp,
    const intptr_t onerow, const size_t x, const size_t y, Predictor predictor,
    const MATreeLookup *lookup, const Channel *references,
    weighted::State *wp_state, pixel_type_w *predictions) {
  // We start in position 3 because of 2 static properties + y.
  size_t offset = 3;
  constexpr bool compute_properties =
      mode & kUseTree || mode & kForceComputeProperties;
  constexpr bool nec = mode & kNoEdgeCases;
  pixel_type_w left = (nec || x ? pp[-1] : (y ? pp[-onerow] : 0));
  pixel_type_w top = (nec || y ? pp[-onerow] : left);
  pixel_type_w topleft = (nec || (x && y) ? pp[-1 - onerow] : left);
  pixel_type_w topright = (nec || (x + 1 < w && y) ? pp[1 - onerow] : top);
  pixel_type_w leftleft = (nec || x > 1 ? pp[-2] : left);
  pixel_type_w toptop = (nec || y > 1 ? pp[-onerow - onerow] : top);
  pixel_type_w toprightright =
      (nec || (x + 2 < w && y) ? pp[2 - onerow] : topright);

  if (compute_properties) {
    // location
    (*p)[offset++] = x;
    // neighbors
    (*p)[offset++] = std::abs(top);
    (*p)[offset++] = std::abs(left);
    (*p)[offset++] = top;
    (*p)[offset++] = left;

    // local gradient
    (*p)[offset] = left - (*p)[offset + 1];
    offset++;
    // local gradient
    (*p)[offset++] = left + top - topleft;

    // FFV1 context properties
    (*p)[offset++] = left - topleft;
    (*p)[offset++] = topleft - top;
    (*p)[offset++] = top - topright;
    (*p)[offset++] = top - toptop;
    (*p)[offset++] = left - leftleft;
  }

  pixel_type_w wp_pred = 0;
  if (mode & kUseWP) {
    wp_pred = wp_state->Predict<compute_properties>(
        x, y, w, top, left, topright, topleft, toptop, p, offset);
  }
  if (!nec && compute_properties) {
    offset += weighted::kNumProperties;
    // Extra properties.
    const pixel_type *JXL_RESTRICT rp = references->Row(x);
    for (size_t i = 0; i < references->w; i++) {
      (*p)[offset++] = rp[i];
    }
  }
  PredictionResult result;
  if (mode & kUseTree) {
    MATreeLookup::LookupResult lr = lookup->Lookup(*p);
    result.context = lr.context;
    result.guess = lr.offset;
    result.multiplier = lr.multiplier;
    predictor = lr.predictor;
  }
  if (mode & kAllPredictions) {
    for (size_t i = 0; i < kNumModularPredictors; i++) {
      predictions[i] = PredictOne((Predictor)i, left, top, toptop, topleft,
                                  topright, leftleft, toprightright, wp_pred);
    }
  }
  result.guess += PredictOne(predictor, left, top, toptop, topleft, topright,
                             leftleft, toprightright, wp_pred);
  result.predictor = predictor;

  return result;
}
}  // namespace detail

inline PredictionResult PredictNoTreeNoWP(size_t w,
                                          const pixel_type *JXL_RESTRICT pp,
                                          const intptr_t onerow, const int x,
                                          const int y, Predictor predictor) {
  return detail::Predict</*mode=*/0>(
      /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
      /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr);
}

inline PredictionResult PredictNoTreeWP(size_t w,
                                        const pixel_type *JXL_RESTRICT pp,
                                        const intptr_t onerow, const int x,
                                        const int y, Predictor predictor,
                                        weighted::State *wp_state) {
  return detail::Predict<detail::kUseWP>(
      /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
      /*references=*/nullptr, wp_state, /*predictions=*/nullptr);
}

inline PredictionResult PredictTreeNoWP(Properties *p, size_t w,
                                        const pixel_type *JXL_RESTRICT pp,
                                        const intptr_t onerow, const int x,
                                        const int y,
                                        const MATreeLookup &tree_lookup,
                                        const Channel &references) {
  return detail::Predict<detail::kUseTree>(
      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
      /*wp_state=*/nullptr, /*predictions=*/nullptr);
}

// Only use for y > 1, x > 1, x < w - 2, and empty references.
JXL_INLINE PredictionResult PredictTreeNoWPNEC(
    Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp,
    const intptr_t onerow, const int x, const int y,
    const MATreeLookup &tree_lookup, const Channel &references) {
  return detail::Predict<detail::kUseTree | detail::kNoEdgeCases>(
      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
      /*wp_state=*/nullptr, /*predictions=*/nullptr);
}

inline PredictionResult PredictTreeWP(Properties *p, size_t w,
                                      const pixel_type *JXL_RESTRICT pp,
                                      const intptr_t onerow, const int x,
                                      const int y,
                                      const MATreeLookup &tree_lookup,
                                      const Channel &references,
                                      weighted::State *wp_state) {
  return detail::Predict<detail::kUseTree | detail::kUseWP>(
      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
      wp_state, /*predictions=*/nullptr);
}

inline PredictionResult PredictLearn(Properties *p, size_t w,
                                     const pixel_type *JXL_RESTRICT pp,
                                     const intptr_t onerow, const int x,
                                     const int y, Predictor predictor,
                                     const Channel &references,
                                     weighted::State *wp_state) {
  return detail::Predict<detail::kForceComputeProperties | detail::kUseWP>(
      p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references,
      wp_state, /*predictions=*/nullptr);
}

inline void PredictLearnAll(Properties *p, size_t w,
                            const pixel_type *JXL_RESTRICT pp,
                            const intptr_t onerow, const int x, const int y,
                            const Channel &references,
                            weighted::State *wp_state,
                            pixel_type_w *predictions) {
  detail::Predict<detail::kForceComputeProperties | detail::kUseWP |
                  detail::kAllPredictions>(
      p, w, pp, onerow, x, y, Predictor::Zero,
      /*lookup=*/nullptr, &references, wp_state, predictions);
}

inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp,
                           const intptr_t onerow, const int x, const int y,
                           pixel_type_w *predictions) {
  detail::Predict<detail::kAllPredictions>(
      /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero,
      /*lookup=*/nullptr, /*references=*/nullptr, /*wp_state=*/nullptr,
      predictions);
}
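
// Rough decode-side sketch (added for clarity; names other than the ones
// declared in this header, e.g. `wp_header`, `static_props`, `tree_lookup`,
// `references`, `chan_index`, `nrefs`, `onerow` and `residual`, are
// hypothetical caller-side values). A caller typically sizes the Properties
// vector for the non-reference properties plus the reference properties, then
// walks the channel in scan order:
//
//   Properties props(kNumNonrefProperties + kExtraPropsPerChannel * nrefs);
//   weighted::State wp_state(wp_header, channel.w, channel.h);
//   for (size_t y = 0; y < channel.h; y++) {
//     PrecomputeReferences(channel, y, image, chan_index, &references);
//     InitPropsRow(&props, static_props, y);
//     pixel_type *JXL_RESTRICT row = channel.Row(y);
//     for (size_t x = 0; x < channel.w; x++) {
//       PredictionResult res =
//           PredictTreeWP(&props, channel.w, row + x, onerow, x, y,
//                         tree_lookup, references, &wp_state);
//       row[x] = res.guess + res.multiplier * residual;  // residual: decoded
//       wp_state.UpdateErrors(row[x], x, y, channel.w);
//     }
//   }
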
}  // namespace jxl

#endif  // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_