import time

import numpy as np
from scipy.sparse import csr_array

import mwpf
from bposd.css import css_code
from ldpc import bp_decoder


def _column_supports(H):
    """Return, for each column of ``H``, the sorted row indices of its nonzero entries.

    Accepts dense array-likes and SciPy sparse matrices/arrays. Each entry is a
    list of Python ints, used below as the vertex list of an ``mwpf.HyperEdge``.
    """
    if hasattr(H, "tocsc"):
        # Sparse path: CSC stores each column's row indices contiguously.
        # copy=True so the in-place clean-up never mutates the caller's matrix.
        H_csc = H.tocsc(copy=True)
        H_csc.eliminate_zeros()
        H_csc.sort_indices()
        return [
            H_csc.indices[H_csc.indptr[j]:H_csc.indptr[j + 1]].tolist()
            for j in range(H_csc.shape[1])
        ]
    H_dense = np.asarray(H)
    return [np.nonzero(H_dense[:, j])[0].tolist() for j in range(H_dense.shape[1])]


class first_min_bp_decoder:
    """Product-sum BP run with a growing iteration budget, keeping the best result.

    BP is re-run with ``max_iter = 1, 2, 3, ...``. A candidate is accepted only
    while the weight of its residual syndrome strictly decreases; the last
    accepted candidate (or all-zeros if none was accepted) is returned.
    """

    def __init__(self, H, error_rate):
        self.H = H
        self.error_rate = error_rate
        self.n = H.shape[1]
        self.m = H.shape[0]

    def decode(self, syndrome):
        """Return the BP estimate whose residual syndrome weight was smallest.

        Args:
            syndrome: syndrome vector; its weight is taken as ``np.sum(syndrome)``.

        Returns:
            Length-``n`` error estimate.
        """
        error = np.zeros(self.n, dtype=np.uint8)
        weight = np.sum(syndrome)
        max_iter = 0
        # The residual weight is a non-negative integer that must strictly
        # decrease, so the loop terminates. Once it reaches zero no later
        # candidate can be accepted, so no further decoder is built.
        while weight > 0:
            max_iter += 1
            bp = bp_decoder(
                self.H,
                error_rate=self.error_rate,
                bp_method="ps",
                max_iter=max_iter,
            )
            candidate = bp.decode(syndrome)
            residual_weight = np.sum((syndrome + self.H @ candidate) % 2)
            if residual_weight >= weight:
                break
            weight = residual_weight
            error = candidate
        return error


class uf_decoder:
    """Union-find decoder built with mwpf.

    Each column of ``H`` becomes a unit-weight hyperedge over the rows it
    touches. Decoding runs ``mwpf.SolverSerialUnionFind`` on the nonzero
    syndrome positions.
    """

    def __init__(self, H, error_rate):
        # error_rate is accepted for interface parity with the other decoders
        # in this file; it is not used here.
        self.num_error = H.shape[1]
        vertex_num = H.shape[0]
        weighted_edges = [mwpf.HyperEdge(support, 1) for support in _column_supports(H)]
        initializer = mwpf.SolverInitializer(vertex_num, weighted_edges)
        self.solver = mwpf.SolverSerialUnionFind(initializer)

    def decode(self, syndrome):
        """Return a 0/1 ``uint8`` vector marking the columns in the solver's subgraph."""
        defects = np.nonzero(syndrome)[0]
        try:
            self.solver.solve(mwpf.SyndromePattern(defects))
            subgraph = self.solver.subgraph()
        finally:
            # Always reset the shared solver so a failed decode cannot leak
            # state into the next call.
            self.solver.clear()
        error = np.zeros(self.num_error, dtype=np.uint8)
        for i in subgraph:
            error[i] = 1
        return error


class bp_uf_decoder:
    """BP followed by LLR-weighted union-find on any residual syndrome.

    BP's hard decision is returned directly if it reproduces the syndrome.
    Otherwise mwpf's union-find solver is run on the residual syndrome. Each
    column is weighted by ``|LLR|`` from BP, and the selected bits of the BP
    estimate are flipped.
    """

    def __init__(self, H, error_rate):
        self.H = H
        self.n = H.shape[1]
        self.m = H.shape[0]
        # NOTE(review): for n < 10 this passes max_iter=0; confirm how ldpc
        # interprets 0 before relying on small codes.
        self.bp_decoder = bp_decoder(
            H, error_rate=error_rate, bp_method="ps", max_iter=int(self.n / 10)
        )
        # Retained for backward compatibility; not updated by this class.
        self.time = 0
        # Column supports depend only on H, so compute them once rather than
        # on every decode() call.
        self._supports = _column_supports(H)

    def decode(self, syndrome):
        """Decode ``syndrome``; returns a length-``n`` error estimate."""
        bp = self.bp_decoder
        # Copy so the in-place bit flips below can never write into memory
        # owned by the BP decoder.
        error = np.array(bp.decode(syndrome), copy=True)
        residual = np.nonzero((syndrome + self.H @ error) % 2)[0]
        if len(residual) == 0:
            return error
        # Read the LLR vector once; indexing the property per column may
        # rebuild the array on each access.
        abs_llrs = np.abs(bp.log_prob_ratios)
        # Scale |LLR| to integer weights (presumably mwpf expects ints) and cap
        # at 1e18 so infinite LLRs stay representable.
        weighted_edges = [
            mwpf.HyperEdge(support, int(min(llr * 1e6, 1e18)))
            for support, llr in zip(self._supports, abs_llrs)
        ]
        initializer = mwpf.SolverInitializer(self.m, weighted_edges)
        solver = mwpf.SolverSerialUnionFind(initializer)
        solver.solve(mwpf.SyndromePattern(residual))
        subgraph = solver.subgraph()
        for i in subgraph:
            error[i] = (error[i] + 1) % 2
        return error