""" Generate a Rust file containing the quadrature rules generated by polyquad. """ import os import numpy as np all_rule_files = [] orders = [] npoints = [] for dirpath, dirnames, filenames in os.walk("."): all_rule_files += [os.path.join(dirpath, file) for file in filenames if file.endswith(".txt")] for rule_file in all_rule_files: base = os.path.basename(rule_file) order_str, points_str = os.path.splitext(base)[0].split("-") orders += [int(order_str)] npoints += [int(points_str)] with open("simplex_rule_definitions.rs", "w") as f: f.write("//! Definition of simplex rules.\n") f.write("\n") f.write("#![allow(clippy::excessive_precision)]\n"); f.write("use std::collections::HashMap;\n") f.write("use lazy_static::lazy_static;\n") f.write("\n") f.write("type HM = HashMap, Vec)>;\n") f.write("\n") f.write("lazy_static! {\n") f.write("pub(crate) static ref SIMPLEX_RULE_DEFINITIONS_INTERVAL: HM = {\n") # Add the standard Gauss Legendre rules f.write("let mut m = HM::new();\n") nmax = 100 for n in range(1, nmax + 1): p, w = np.polynomial.legendre.leggauss(n) sorted_indices = np.argsort(p) points = 0.5 * (1.0 + p[sorted_indices]) weights = 0.5 * w[sorted_indices] f.write("m.insert(\n") f.write(str(len(w)) + ", \n") f.write("(" + str(2 * len(w) - 1) + ",vec![") for point in points: f.write(f"{point},") f.write("],\n") f.write("vec![\n") for weight in weights: f.write(f"{weight},") f.write("]));\n") f.write("m\n") f.write("};\n") for cell, file_id in [ ("triangle", "tri"), ("quadrilateral", "quad"), ("tetrahedron", "tet"), ("hexahedron", "hex"), ("prism", "pri"), ("pyramid", "pyr"), ]: f.write(f"pub(crate) static ref SIMPLEX_RULE_DEFINITIONS_{cell.upper()}: HM = {{\n") f.write("let mut m = HM::new();\n") for index, rule_file in enumerate(all_rule_files): if not rule_file.startswith(f"./{file_id}"): continue arr = np.atleast_2d(np.loadtxt(rule_file)) points = arr[:, :-1] weights = arr[:, -1] if rule_file.startswith("./quad"): points = 0.5 * (1.0 + points) weights = weights / 4.0 elif rule_file.startswith("./tri"): points = 0.5 * (1.0 + points) weights = weights / 4.0 elif rule_file.startswith("./hex"): points = 0.5 * (1.0 + points) weights = weights / 8.0 elif rule_file.startswith("./pri"): points = 0.5 * (1.0 + points) weights = weights / 8.0 elif rule_file.startswith("./tet"): points = 0.5 * (1.0 + points) weights = weights / 8.0 elif rule_file.startswith("./pyr"): points = (1.0 + points) @ np.array( [[0.5, 0, 0], [0, 0.5, 0], [-0.25, -0.25, 0.5]], dtype="float64" ) weights = weights / 8.0 else: raise ValueError("Unknown simplex type.") points = points.flatten() weights = weights.flatten() f.write("m.insert(\n") f.write(str(npoints[index]) + ", \n") f.write("(" + str(orders[index]) + ",vec![") for point in points: f.write(f"{point},") f.write("],\n") f.write("vec![\n") for weight in weights: f.write(f"{weight},") f.write("]));\n") f.write("m\n") f.write("};\n") f.write("}") os.system("rustfmt ./simplex_rule_definitions.rs") os.system("cp ./simplex_rule_definitions.rs ../src/") os.system("rm ./simplex_rule_definitions.rs")