#!/usr/bin/env python
"""Print Rust wrapper functions for the `ffi::*_` declarations of a sys crate.

The script reads ``<sys>/src/pfapack.rs``, parses every declaration found
there with ``Function.parse`` and writes one ``pub unsafe fn`` wrapper per
declaration to standard output.
"""

import argparse
import os
import re

from function import Function
from function import read

# Matches callback type names such as `LAPACK_S_SELECT2`; group 1 is the
# precision letter and group 2 is the arity digit.
select_re = re.compile(r'LAPACK_(\w)_SELECT(\d)')

# Argument names that are treated as scalars in every routine.
_SCALAR_NAMES = frozenset([
    'abnrm',
    'abstol',
    'amax',
    'anorm',
    'bbnrm',
    'colcnd',
    'ihi',
    'il',
    'ilo',
    'info',
    'iter',
    'iu',
    'l',
    'liwork',
    'lrwork',
    'lwork',
    'm',
    'mm',
    'n',
    'n_err_bnds',
    'nb',
    'nrhs',
    'rank',
    'rcond',
    'rowcnd',
    'rpvgrw',
    'sdim',
    'tryrac',
    'vu',
])

# Argument-name prefixes that are treated as scalars in every routine.
_SCALAR_PREFIXES = ('inc', 'ld', 'tol', 'vers')

# Routines in which `vl`/`vr` are NOT scalars.
_VL_VR_ARRAY_ROUTINES = (
    'geev',
    'ggev',
    'hsein',
    'tgevc',
    'tgsna',
    'trevc',
    'trsna',
)

# Precision letter of a select callback -> suffix of the Rust type name.
_SELECT_SUFFIXES = {'S': 'F32', 'D': 'F64', 'C': 'C32', 'Z': 'C64'}

# Ordered (needle, Rust type) pairs; the first needle found in the
# normalized type text wins, so the order is significant.
_BASE_TYPES = (
    ('c_char', 'u8'),
    ('c_int', 'i32'),
    ('c_float', 'f32'),
    ('c_double', 'f64'),
    ('lapack_complex_float', 'c32'),
    ('lapack_complex_double', 'c64'),
    ('size_t', 'size_t'),
)

# Ordered (old, new) rewrites applied to the type text before the lookup in
# `_BASE_TYPES`; the complex names are rewritten before the bare `f32`/`f64`.
_TYPE_REWRITES = (
    ('floatcmplx', 'lapack_complex_float'),
    ('doublecmplx', 'lapack_complex_double'),
    ('lapack_float_return', 'c_float'),
    ('f32', 'c_float'),
    ('f64', 'c_double'),
)


def _mentions(routine, fragments):
    """Return True if any of `fragments` is a substring of `routine`."""
    return any(fragment in routine for fragment in fragments)


def is_scalar(name, cty, f):
    """Tell whether argument `name` of function `f` is a scalar.

    `name` is the (lower-cased) argument name, `cty` the type text of the
    argument and `f` an object whose `name` attribute is the routine name.
    The decision is made from fixed name lists plus a few per-routine
    exceptions; `translate_signature_type` uses it to choose between a
    value/reference and a slice.
    """
    if 'c_char' in cty or name in _SCALAR_NAMES:
        return True
    if name == 'alpha':
        return 'larfg' in f.name
    if name == 'dif':
        return not _mentions(f.name, ('tgsen', 'tgsna'))
    if name == 'p':
        return 'tgevc' not in f.name
    if name == 'q':
        return 'lapack_int' in cty
    if name in ('vl', 'vr'):
        return not _mentions(f.name, _VL_VR_ARRAY_ROUTINES)
    if name.startswith('k'):
        return not _mentions(f.name, ('lapmr', 'lapmt'))
    return name.startswith(_SCALAR_PREFIXES)


def translate_name(name):
    """Return the argument name used in the generated code (lower case)."""
    return name.lower()


def translate_base_type(cty):
    """Map the type text `cty` to the element type used in the wrapper.

    Returns one of `u8`, `i32`, `f32`, `f64`, `c32`, `c64` or `size_t`.

    Raises:
        AssertionError: if `cty` mentions none of the known types.
    """
    for old, new in _TYPE_REWRITES:
        cty = cty.replace(old, new)
    for needle, rust_type in _BASE_TYPES:
        if needle in cty:
            return rust_type
    # Raised explicitly (not via `assert`) so that the check survives
    # `python -O` instead of silently returning None.
    raise AssertionError('cannot translate `{}`'.format(cty))


def translate_signature_type(name, cty, f):
    """Return the type of argument `name` in the wrapper's signature.

    Select callbacks become `Select<N><F32|F64|C32|C64>`; `*const` pointers
    become a value (scalar) or `&[T]`; `*mut` pointers become `&mut T`
    (scalar) or `&mut [T]`; anything else is the bare base type.
    """
    m = select_re.match(cty)
    if m is not None:
        suffix = _SELECT_SUFFIXES.get(m.group(1))
        if suffix is not None:
            return 'Select{}{}'.format(m.group(2), suffix)
        # An unknown precision letter falls through to the generic path.

    base = translate_base_type(cty)
    if '*const' in cty:
        if is_scalar(name, cty, f):
            return base
        return '&[{}]'.format(base)
    if '*mut' in cty:
        if is_scalar(name, cty, f):
            return '&mut {}'.format(base)
        return '&mut [{}]'.format(base)
    return base


def translate_body_argument(name, rty):
    """Return the expression that passes `name` (of wrapper type `rty`) on.

    The checks are order-sensitive: exact matches come before the prefix
    tests.

    Raises:
        AssertionError: if `rty` is not a type the generator knows about.
    """
    if rty.startswith('Select'):
        return 'transmute({})'.format(name)

    if rty == 'u8':
        return '&({} as c_char)'.format(name)
    if rty == '&mut u8':
        return '{} as *mut _ as *mut _'.format(name)

    if rty == 'i32':
        return '&{}'.format(name)
    if rty == '&mut i32':
        return name
    if rty == '&[i32]':
        return '{}.as_ptr()'.format(name)
    if rty == '&mut [i32]':
        return '{}.as_mut_ptr()'.format(name)

    if rty.startswith('f'):
        return '&{}'.format(name)
    if rty.startswith('&mut f'):
        return name
    if rty.startswith('&[f'):
        return '{}.as_ptr()'.format(name)
    if rty.startswith('&mut [f'):
        return '{}.as_mut_ptr()'.format(name)

    if rty.startswith('c'):
        return '&{} as *const _ as *const _'.format(name)
    if rty.startswith('&mut c'):
        return '{} as *mut _ as *mut _'.format(name)
    if rty.startswith('&[c'):
        return '{}.as_ptr() as *const _'.format(name)
    if rty.startswith('&mut [c'):
        return '{}.as_mut_ptr() as *mut _'.format(name)

    if rty == 'size_t':
        return name

    # Raised explicitly (not via `assert`) so that the check survives
    # `python -O` instead of silently returning None.
    raise AssertionError('cannot translate `{}: {}`'.format(name, rty))


def format_signature(f):
    """Return the `pub unsafe fn ...` header for `f`, without the body."""
    args = format_signature_arguments(f)
    if f.ret is None:
        return 'pub unsafe fn {}({})'.format(f.name, args)
    return 'pub unsafe fn {}({}) -> {}'.format(
        f.name, args, translate_base_type(f.ret))


def format_signature_arguments(f):
    """Return the comma-separated `name: type` parameter list for `f`."""
    params = []
    for name, cty in f.args:
        name = translate_name(name)
        params.append(
            '{}: {}'.format(name, translate_signature_type(name, cty, f)))
    return ', '.join(params)


def format_body(f):
    """Return the `ffi::<name>_(...)` call that forms the wrapper body."""
    return 'ffi::{}_({})'.format(f.name, format_body_arguments(f))


def format_body_arguments(f):
    """Return the comma-separated argument expressions of the ffi call."""
    passed = []
    for name, cty in f.args:
        name = translate_name(name)
        rty = translate_signature_type(name, cty, f)
        passed.append(translate_body_argument(name, rty))
    return ', '.join(passed)


def process(code):
    """Parse the declarations contained in `code`.

    Lines that consist only of a `//` comment are dropped, the remaining
    text is joined, runs of whitespace are collapsed to one space, and the
    result is split on `;`. Every non-blank piece is handed to
    `Function.parse`; the list of parsed functions is returned.
    """
    kept = [
        line for line in code.split('\n')
        if not re.match(r'^\s*//.*', line)
    ]
    flattened = re.sub(r'\s+', ' ', ''.join(kept)).strip()
    declarations = [
        piece for piece in flattened.split(';')
        if not re.match(r'^\s*$', piece)
    ]
    return [Function.parse(piece) for piece in declarations]


def write(functions):
    """Print an `#[inline]` wrapper for each function except `lsame`."""
    for f in functions:
        if f.name in ('lsame',):
            continue
        print('\n#[inline]')
        print(format_signature(f) + ' {')
        print(' ' + format_body(f) + '\n}')


def main():
    """Read `<sys>/src/pfapack.rs` and print the generated wrappers."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--sys', default='pfapack')
    arguments = parser.parse_args()
    path = os.path.join(arguments.sys, 'src', 'pfapack.rs')
    write(process(read(path)))


if __name__ == '__main__':
    main()