# Copyright 2022 The IREE Authors # # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import gc import unittest import iree.runtime as rt NONE_CTOR = lambda iface: None class PyModuleInterfaceTest(unittest.TestCase): def setUp(self): self._instance = rt.VmInstance() def testEmptyModuleLifecycle(self): iface = rt.PyModuleInterface("test1", NONE_CTOR) print(iface) self.assertFalse(iface.initialized) m = iface.create() print(iface) self.assertTrue(iface.initialized) print(m) m = None gc.collect() print(iface) self.assertTrue(iface.destroyed) def testEmptyModuleInstance(self): iface = rt.PyModuleInterface("test1", NONE_CTOR) m = iface.create() context = rt.VmContext(self._instance, modules=(m,)) self.assertTrue(iface.initialized) print(context) # Make sure no circular refs and that everything frees. context = None m = None gc.collect() self.assertTrue(iface.destroyed) def testMultiModuleInstance(self): calls = [] def ctor(iface): calls.append(iface) return None iface = rt.PyModuleInterface("test1", ctor) m = iface.create() context1 = rt.VmContext(self._instance, modules=(m,)) self.assertTrue(iface.initialized) context2 = rt.VmContext(self._instance, modules=(m,)) self.assertTrue(iface.initialized) self.assertEqual(2, len(calls)) # Make sure no circular refs and that everything frees. calls = None context1 = None m = None context2 = None gc.collect() self.assertTrue(iface.destroyed) def testVoidFunctionExport(self): messages = [] class Methods: def __init__(self, iface): self.iface = iface self.counter = 0 def say_hello(self): messages.append(f"Hello! Your number is {self.counter}") print(messages[-1]) self.counter += 1 iface = rt.PyModuleInterface("test1", Methods) iface.export("say_hello", "0v", Methods.say_hello) m = iface.create() context = rt.VmContext(self._instance, modules=(m,)) f = m.lookup_function("say_hello") self.assertIsNotNone(f) args = rt.VmVariantList(0) results = rt.VmVariantList(0) # Invoke twice - should produce two messages. context.invoke(f, args, results) context.invoke(f, args, results) self.assertListEqual( messages, [ "Hello! Your number is 0", "Hello! Your number is 1", ], ) # Make sure no circular refs and that everything frees. context = None m = None gc.collect() self.assertTrue(iface.destroyed) def testPythonException(self): messages = [] class Methods: def __init__(self, iface): pass def do_it(self): raise ValueError("This is from Python") iface = rt.PyModuleInterface("test1", Methods) iface.export("do_it", "0v", Methods.do_it) m = iface.create() context = rt.VmContext(self._instance, modules=(m,)) f = m.lookup_function("do_it") self.assertIsNotNone(f) args = rt.VmVariantList(0) results = rt.VmVariantList(0) # We are testing here that the Python level exception is caught and # translated to an IREE status (surfacing as a RuntimeError) vs percolating # through the C call stack. with self.assertRaisesRegex(RuntimeError, "ValueError: This is from Python"): context.invoke(f, args, results) # Make sure no circular refs and that everything frees. context = None m = None gc.collect() self.assertTrue(iface.destroyed) def testPrimitiveArguments(self): values = [] class Methods: def __init__(self, iface): pass def do_it(self, a, b): values.append((a, b)) iface = rt.PyModuleInterface("test1", Methods) iface.export("do_it_i32", "0ii", Methods.do_it) iface.export("do_it_i64", "0II", Methods.do_it) iface.export("do_it_f32", "0ff", Methods.do_it) iface.export("do_it_f64", "0FF", Methods.do_it) m = iface.create() context = rt.VmContext(self._instance, modules=(m,)) args = rt.VmVariantList(2) results = rt.VmVariantList(0) args.push_int(42) args.push_int(43) context.invoke(m.lookup_function("do_it_i32"), args, results) context.invoke(m.lookup_function("do_it_i64"), args, results) args = rt.VmVariantList(2) args.push_float(2.0) args.push_float(4.0) # TODO: Python doesn't have 32bit floats, so we are populating f64 args. # These are coming back as zeros, and I expected something to be # doing a conversion? The same is being done with i64 above but is # working there. context.invoke(m.lookup_function("do_it_f32"), args, results) context.invoke(m.lookup_function("do_it_f64"), args, results) print(values) self.assertEqual(repr(values), "[(42, 43), (42, 43), (0.0, 0.0), (2.0, 4.0)]") # Make sure no circular refs and that everything frees. context = None m = None gc.collect() self.assertTrue(iface.destroyed) def testPrimitiveResults(self): next_results = None class Methods: def __init__(self, iface): pass def do_it(self): return next_results iface = rt.PyModuleInterface("test1", Methods) iface.export("do_it_i32", "0v_ii", Methods.do_it) iface.export("do_it_i64", "0v_II", Methods.do_it) iface.export("do_it_f32", "0v_ff", Methods.do_it) iface.export("do_it_f64", "0v_FF", Methods.do_it) iface.export("do_it_unary_i32", "0v_i", Methods.do_it) m = iface.create() context = rt.VmContext(self._instance, modules=(m,)) args = rt.VmVariantList(0) # i32 results = rt.VmVariantList(2) next_results = (42, 43) context.invoke(m.lookup_function("do_it_i32"), args, results) self.assertEqual(repr(results), "") # i64 results = rt.VmVariantList(2) next_results = (42, 43) context.invoke(m.lookup_function("do_it_i64"), args, results) self.assertEqual(repr(results), "") # f32 results = rt.VmVariantList(2) next_results = (2.0, 4.0) context.invoke(m.lookup_function("do_it_f32"), args, results) self.assertEqual(repr(results), "") # f64 results = rt.VmVariantList(2) next_results = (2.0, 4.0) context.invoke(m.lookup_function("do_it_f64"), args, results) self.assertEqual(repr(results), "") # Unary special case. results = rt.VmVariantList(1) next_results = 42 context.invoke(m.lookup_function("do_it_unary_i32"), args, results) self.assertEqual(repr(results), "") # Make sure no circular refs and that everything frees. context = None m = None gc.collect() self.assertTrue(iface.destroyed) def testRefArguments(self): values = [] class Methods: def __init__(self, iface): pass def do_it(self, a, b): values.append((a.deref(rt.VmVariantList), b.deref(rt.VmVariantList))) iface = rt.PyModuleInterface("test1", Methods) iface.export("do_it", "0rr", Methods.do_it) m = iface.create() context = rt.VmContext(self._instance, modules=(m,)) # These lists just happen to be reference objects we know how to # create. arg0 = rt.VmVariantList(1) arg0.push_int(42) arg1 = rt.VmVariantList(1) arg1.push_int(84) args = rt.VmVariantList(2) args.push_list(arg0) args.push_list(arg1) results = rt.VmVariantList(2) context.invoke(m.lookup_function("do_it"), args, results) print("REF VALUES:", values) self.assertEqual( repr(values), "[(, )]" ) def testRefResults(self): class Methods: def __init__(self, iface): pass def do_it(self): # These lists just happen to be reference objects we know how to # create. r0 = rt.VmVariantList(1) r0.push_int(42) r1 = rt.VmVariantList(1) r1.push_int(84) return r0.ref, r1.ref iface = rt.PyModuleInterface("test1", Methods) iface.export("do_it", "0v_rr", Methods.do_it) m = iface.create() context = rt.VmContext(self._instance, modules=(m,)) args = rt.VmVariantList(0) results = rt.VmVariantList(2) context.invoke(m.lookup_function("do_it"), args, results) print("REF RESULTS:", results) self.assertEqual(repr(results), "") if __name__ == "__main__": unittest.main()