# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import asyncio
import timeit
import unittest

from iree.runtime import (
    get_device,
    HalDeviceLoopBridge,
)


class HalDeviceLoopBridgeTest(unittest.TestCase):
    """Tests that HAL semaphore signals resolve asyncio futures via the bridge."""

    def testBridge(self):
        """Chain semaphore waits through ``HalDeviceLoopBridge.on_semaphore``.

        Signaling ``sem1`` resolves ``f1``. The done-callbacks of ``f1`` and
        ``f2`` then signal ``sem2`` to 1 and 2, which resolve ``f2`` and
        ``f2_again`` in turn. Each future must resolve to the object passed to
        ``on_semaphore``. The scenario runs once with logging, then several
        more times silently to report an average per-iteration time.
        """
        loop = asyncio.new_event_loop()
        bridge = HalDeviceLoopBridge(self.device, loop)
        # Rebound by run_iter() before every iteration; read by main().
        sem1 = None
        sem2 = None
        report = None

        async def main():
            def done_1(x):
                report("PYTHON: sem2.signal(1)")
                sem2.signal(1)

            def done_2(x):
                report("PYTHON: sem2.signal(2)")
                sem2.signal(2)

            # Register every wait before the first signal so that the whole
            # chain is driven solely by semaphore signals and done-callbacks.
            f1 = bridge.on_semaphore(sem1, 1, "Semaphore 1 Signaled")
            f1.add_done_callback(done_1)
            f2 = bridge.on_semaphore(sem2, 1, "Semaphore 2 Signaled")
            f2.add_done_callback(done_2)
            f2_again = bridge.on_semaphore(sem2, 2, "Semaphore 2 Signaled Again")
            sem1.signal(1)

            f1_result = await f1
            report("PYTHON: await f1 =", f1_result)
            f2_result = await f2
            report("PYTHON: await f2 =", f2_result)
            f2_again_result = await f2_again
            report("PYTHON: await f2_again =", f2_again_result)

            # Assertion failures raised inside the coroutine propagate out of
            # loop.run_until_complete() and fail the test.
            self.assertEqual(f1_result, "Semaphore 1 Signaled")
            self.assertEqual(f2_result, "Semaphore 2 Signaled")
            self.assertEqual(f2_again_result, "Semaphore 2 Signaled Again")
            report("PYTHON: ASYNC MAIN() COMPLETE")

        def run_iter(with_report):
            """Run main() once on fresh semaphores; log only if with_report."""
            nonlocal sem1
            nonlocal sem2
            nonlocal report
            # Fresh semaphores starting at 0 so that the same payload values
            # (1 and 2) can be waited on and signaled again on every iteration.
            sem1 = self.device.create_semaphore(0)
            sem2 = self.device.create_semaphore(0)
            report = print if with_report else (lambda *args: None)
            loop.run_until_complete(main())

        try:
            run_iter(True)
            iterations = 10
            # timeit.timeit() returns the TOTAL time across `number` runs, so
            # divide to report an actual per-iteration figure.
            total_time = timeit.timeit(lambda: run_iter(False), number=iterations)
            print(f"Time/iter = {total_time / iterations}s")
        finally:
            bridge.stop()
            # The loop is owned by this test; close it so it is not leaked
            # (an unclosed loop emits a ResourceWarning when collected).
            loop.close()

    def setUp(self):
        super().setUp()
        self.device = get_device("local-task")
        self.allocator = self.device.allocator


if __name__ == "__main__":
    unittest.main()