from ..core._imperative_rt.core2 import Const
from ..jit.tracing import is_tracing

# Module-level cache of scalar constants built outside of tracing.
# Keys are (type(value), value, dtype, device); values are whatever
# ``Const`` returned for that combination.
small_tensor_cache = {}


def _get_scalar_tensor_with_value(value, dtype=None, device=None):
    """Return a ``Const`` built from *value*, reusing a cached one if possible.

    While ``is_tracing()`` is true a fresh ``Const`` is created on every call
    and the cache is neither read nor written (presumably so that a traced
    graph owns its constants -- TODO confirm against the tracing module).
    Otherwise the result is memoized in ``small_tensor_cache``.

    Args:
        value: the scalar to wrap; must be hashable when the cache is used.
        dtype: forwarded unchanged to ``Const``.
        device: forwarded unchanged to ``Const``.

    Returns:
        The object produced by ``Const(value, dtype, device, None)``.
    """
    if is_tracing():
        return Const(value, dtype, device, None)

    # ``1``, ``1.0`` and ``True`` compare equal and hash identically in
    # Python, so keying on the bare value would let e.g. a float request
    # pick up a constant that was created from an int.  Including the
    # value's type keeps those entries apart.
    cache_key = (type(value), value, dtype, device)
    ret = small_tensor_cache.get(cache_key)
    if ret is None:
        ret = Const(value, dtype, device, None)
        small_tensor_cache[cache_key] = ret
    return ret


def get_scalar_zero(dtype=None, device=None):
    """Return the (possibly cached) scalar constant ``0``."""
    return _get_scalar_tensor_with_value(0, dtype, device)


def get_scalar_zero_point_five(dtype=None, device=None):
    """Return the (possibly cached) scalar constant ``0.5``."""
    return _get_scalar_tensor_with_value(0.5, dtype, device)


def get_scalar_one(dtype=None, device=None):
    """Return the (possibly cached) scalar constant ``1``."""
    return _get_scalar_tensor_with_value(1, dtype, device)


def get_scalar_two(dtype=None, device=None):
    """Return the (possibly cached) scalar constant ``2``."""
    return _get_scalar_tensor_with_value(2, dtype, device)