# -*- coding: utf-8 -*- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import abc import collections from functools import lru_cache from typing import Union import numpy as np from .. import _config from .._imperative_rt.common import CompNode from .._imperative_rt.core2 import ( SymbolVar, Tensor, apply, astype_cpp, broadcast_cpp, dtype_promotion, getitem_cpp, ) from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar from .._imperative_rt.core2 import reshape_cpp, setitem_cpp, squeeze_cpp, transpose_cpp from ..ops import builtin from . import amp from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph _ElwMod = builtin.Elemwise.Mode def _elwise_apply(args, mode): op = builtin.Elemwise(mode) (result,) = apply(op, *args) return result def _elwise(*args, mode): return _elwise_apply(args, mode) @lru_cache(maxsize=None) def _get_extentedMatrixMulOp( device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy, ): @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2) def extentedMatrixMulOp(inputs, f, c): assert len(inputs) == 2 inp1, inp2 = inputs _dim1, _dim2 = dim1, dim2 def build_shape_head(shape, idx=-1): # shape[:idx] return f( builtin.Subtensor(items=[[0, False, True, False, False]]), shape, c(idx, "int32"), ) def build_shape_tail(shape, idx=-1): # shape[idx:] return f( builtin.Subtensor(items=[[0, True, False, False, False]]), shape, c(idx, "int32"), ) remove_row, remove_col = False, False if _dim1 == 1: _dim1 = 2 remove_row = True if _dim2 == 1: _dim2 = 2 remove_col = True if remove_row: inp1 = f(builtin.AddAxis(axis=[0,]), inp1) if remove_col: inp2 = f(builtin.AddAxis(axis=[1,]), inp2) shape1 = f(builtin.GetVarShape(), inp1) shape2 = f(builtin.GetVarShape(), inp2) if _dim1 > 2: inp1 = f( builtin.Reshape(), inp1, f( builtin.Concat(axis=0, comp_node=device), f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)), build_shape_tail(shape1), ), ) if _dim2 > 2: inp2 = f( builtin.Reshape(), inp2, f( builtin.Concat(axis=0, comp_node=device), f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)), build_shape_tail(shape2), ), ) op = builtin.MatrixMul( transposeA=transpose_a, transposeB=transpose_b, compute_mode=compute_mode, format=format, strategy=strategy.value, ) result = f(op, inp1, inp2) result_shape = f(builtin.GetVarShape(), result) if _dim1 > 2: result = f( builtin.Reshape(), result, f( builtin.Concat(axis=0, comp_node=device), build_shape_head(shape1), build_shape_tail(result_shape), ), ) if _dim2 > 2: result = f( builtin.Reshape(), result, f( builtin.Concat(axis=0, comp_node=device), build_shape_head(shape2), build_shape_tail(result_shape), ), ) maxdim = _dim1 if _dim1 > _dim2 else _dim2 if remove_row: result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result) if remove_col: result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result) return (result,), (True,) return extentedMatrixMulOp @lru_cache(maxsize=None) def _get_extentedBatchedMatrixMulOp( device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy, ): @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2) def extentedBatchedMatrixMulOp(inputs, f, c): assert len(inputs) == 2 inp1, inp2 = inputs _dim1, _dim2 = dim1, dim2 def build_shape_head(shape, idx=-2): # shape[:idx] return f( builtin.Subtensor(items=[[0, False, True, False, False]]), shape, c(idx, "int32"), ) def build_shape_tail(shape, idx=-2): # shape[idx:] return f( builtin.Subtensor(items=[[0, True, False, False, False]]), shape, c(idx, "int32"), ) remove_row, remove_col = False, False if _dim1 == 1: _dim1 = 2 remove_row = True if _dim2 == 1: _dim2 = 2 remove_col = True if remove_row: inp1 = f(builtin.AddAxis(axis=[0,]), inp1) if remove_col: inp2 = f(builtin.AddAxis(axis=[1,]), inp2) shape1 = f(builtin.GetVarShape(), inp1) shape2 = f(builtin.GetVarShape(), inp2) maxdim = _dim1 if _dim1 > _dim2 else _dim2 if _dim1 > _dim2: # broadcast shape2 = f( builtin.Concat(axis=0, comp_node=device), build_shape_head(shape1, idx=-_dim2), # shape1[:-_dim2] shape2, ) inp2 = f(builtin.Broadcast(), inp2, shape2) batch_shape = build_shape_head(shape1) if _dim2 > _dim1: # broadcast shape1 = f( builtin.Concat(axis=0, comp_node=device), build_shape_head(shape2, idx=-_dim1), # shape2[:-_dim1] shape1, ) inp1 = f(builtin.Broadcast(), inp1, shape1) batch_shape = build_shape_head(shape2) if _dim1 == _dim2: batch_shape = build_shape_head(shape1) # compress inputs to 3d if maxdim > 3: inp1 = f( builtin.Reshape(), inp1, f( builtin.Concat(axis=0, comp_node=device), f(builtin.Reduce(mode="product", axis=0), batch_shape), build_shape_tail(shape1), ), ) inp2 = f( builtin.Reshape(), inp2, f( builtin.Concat(axis=0, comp_node=device), f(builtin.Reduce(mode="product", axis=0), batch_shape), build_shape_tail(shape2), ), ) op = builtin.BatchedMatrixMul( transposeA=transpose_a, transposeB=transpose_b, compute_mode=compute_mode, format=format, strategy=strategy.value, ) result = f(op, inp1, inp2) if maxdim > 3: result = f( builtin.Reshape(), result, f( builtin.Concat(axis=0, comp_node=device), batch_shape, build_shape_tail(f(builtin.GetVarShape(), result)), ), ) if remove_row: result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result) if remove_col: result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result) return (result,), (True,) return extentedBatchedMatrixMulOp class _Hashable: def __init__(self, value) -> None: self.value = value def __hash__(self) -> int: return hash(str(self.value)) def __eq__(self, o: object) -> bool: if not isinstance(o, _Hashable): return False return self.value == o.value def _matmul( inp1, inp2, transpose_a=False, transpose_b=False, compute_mode="default", format="default", ): if amp._enabled: compute_mode = "float32" inp1, inp2 = cast_tensors(inp1, inp2) else: dtype = dtype_promotion(inp1, inp2) if inp1.dtype != dtype: inp1 = inp1.astype(dtype) if inp2.dtype != dtype: inp2 = inp2.astype(dtype) dim1, dim2 = inp1.ndim, inp2.ndim assert dim1 > 0 and dim2 > 0 maxdim = dim1 if dim1 > dim2 else dim2 compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) Strategy = builtin.ops.MatrixMul.Strategy strategy = Strategy(0) if _config._benchmark_kernel: strategy |= Strategy.PROFILE else: strategy |= Strategy.HEURISTIC if _config._deterministic_kernel: strategy |= Strategy.REPRODUCIBLE if dim1 == 1 and dim2 == 1: # dispatch to Dot (result,) = apply(builtin.Dot(), inp1, inp2) return result elif maxdim <= 2 or dim2 <= 2: # dispath to MatrixMul extentedMatrixMulOp = _get_extentedMatrixMulOp( inp1.device, inp1.dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy=_Hashable(strategy), ) (result,) = apply(extentedMatrixMulOp(), inp1, inp2) return result else: # dispath to BatchedMatrixMul extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp( inp1.device, inp1.dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy=_Hashable(strategy), ) (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2) return result def _unary_elwise(mode): def f(self): return _elwise(self, mode=mode) return f def _binary_elwise(mode, rev=False): if not rev: def f(self, value): return _elwise(self, value, mode=mode) else: def f(self, value): return _elwise(value, self, mode=mode) return f def _logical_unary_elwise(mode, rev=False): def f(self): if self.dtype != np.bool_: raise TypeError("{} requires a bool tensor".format(mode)) return _elwise(self, mode=mode) return f def _logical_binary_elwise(mode, rev=False): if not rev: def f(self, value): if self.dtype != np.bool_ or value.dtype != np.bool_: raise TypeError("{} requires 2 bool tensors".format(mode)) return _elwise(self, value, mode=mode) else: def f(self, value): if self.dtype != np.bool_ or value.dtype != np.bool_: raise TypeError("{} requires 2 bool tensors".format(mode)) return _elwise(value, self, mode=mode) return f def _reduce(mode): def f(self, axis=None, keepdims: bool = False): data = self if axis is None: assert not keepdims, "can not set axis=None and keepdims=True" result = _reduce_to_scalar(builtin.Reduce(mode=mode), data) elif isinstance(axis, collections.abc.Iterable): axis = _normalize_axis(self.ndim, axis, reverse=True) for ai in axis: op = builtin.Reduce(mode=mode, axis=ai) (data,) = apply(op, data) if not keepdims: data = squeeze_cpp(data, ai) result = data else: # builtin.Reduce already accept negtive axis op = builtin.Reduce(mode=mode, axis=axis) (result,) = apply(op, data) if not keepdims: result = squeeze_cpp(result, axis) return result return f def _inplace(f): def g(self, value): result = f(self, value) if result is NotImplemented: raise NotImplementedError self._reset(result) return self return g def _todo(*_): raise NotImplementedError def _expand_args(args): if len(args) == 1: if isinstance( args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray), ): args = args[0] return args class ArrayMethodMixin(abc.ABC): # enable tensor to be converted to numpy array __array_priority__ = 1001 def __array__(self, dtype=None): if dtype == None: return self.numpy() return self.numpy().astype(dtype) def __array_wrap__(self, array): Wrapper = type(self) return Wrapper(array, dtype=array.dtype, device=self.device) @abc.abstractmethod def _reset(self, other): pass @abc.abstractproperty def dtype(self) -> np.dtype: pass @abc.abstractproperty def shape(self) -> Union[tuple, Tensor]: pass @abc.abstractproperty def _tuple_shape(self) -> tuple: pass @abc.abstractmethod def numpy(self) -> np.ndarray: pass __hash__ = None # due to __eq__ diviates from python convention __lt__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LT).astype("bool") __le__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LEQ).astype("bool") __gt__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LT).astype("bool") __ge__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LEQ).astype("bool") __eq__ = lambda self, value: _elwise(self, value, mode=_ElwMod.EQ).astype("bool") __ne__ = lambda self, value: _elwise( _elwise(self, value, mode=_ElwMod.EQ).astype("bool"), mode=_ElwMod.NOT, ) __neg__ = _unary_elwise(_ElwMod.NEGATE) __pos__ = lambda self: self __abs__ = _unary_elwise(_ElwMod.ABS) __invert__ = _logical_unary_elwise(_ElwMod.NOT) __round__ = _unary_elwise(_ElwMod.ROUND) __trunc__ = _todo __floor__ = _unary_elwise(_ElwMod.FLOOR) __ceil__ = _unary_elwise(_ElwMod.CEIL) __add__ = _binary_elwise(_ElwMod.ADD) __sub__ = _binary_elwise(_ElwMod.SUB) __mul__ = _binary_elwise(_ElwMod.MUL) __matmul__ = lambda self, other: _matmul(self, other) __truediv__ = _binary_elwise(_ElwMod.TRUE_DIV) __floordiv__ = _binary_elwise(_ElwMod.FLOOR_DIV) __mod__ = _binary_elwise(_ElwMod.MOD) # __divmode__ __pow__ = _binary_elwise(_ElwMod.POW) __lshift__ = _binary_elwise(_ElwMod.SHL) __rshift__ = _binary_elwise(_ElwMod.SHR) __and__ = _logical_binary_elwise(_ElwMod.AND) __or__ = _logical_binary_elwise(_ElwMod.OR) __xor__ = _logical_binary_elwise(_ElwMod.XOR) __radd__ = _binary_elwise(_ElwMod.ADD, rev=1) __rsub__ = _binary_elwise(_ElwMod.SUB, rev=1) __rmul__ = _binary_elwise(_ElwMod.MUL, rev=1) __rmatmul__ = lambda self, other: _matmul(other, self) __rtruediv__ = _binary_elwise(_ElwMod.TRUE_DIV, rev=1) __rfloordiv__ = _binary_elwise(_ElwMod.FLOOR_DIV, rev=1) __rmod__ = _binary_elwise(_ElwMod.MOD, rev=1) # __rdivmode__ __rpow__ = _binary_elwise(_ElwMod.POW, rev=1) __rlshift__ = _binary_elwise(_ElwMod.SHL, rev=1) __rrshift__ = _binary_elwise(_ElwMod.SHR, rev=1) __rand__ = _logical_binary_elwise(_ElwMod.AND, rev=1) __ror__ = _logical_binary_elwise(_ElwMod.OR, rev=1) __rxor__ = _logical_binary_elwise(_ElwMod.XOR, rev=1) __iadd__ = _inplace(__add__) __isub__ = _inplace(__sub__) __imul__ = _inplace(__mul__) __imatmul__ = _inplace(__matmul__) __itruediv__ = _inplace(__truediv__) __ifloordiv__ = _inplace(__floordiv__) __imod__ = _inplace(__mod__) __ipow__ = _inplace(__pow__) __ilshift__ = _inplace(__lshift__) __irshift__ = _inplace(__rshift__) __iand__ = _inplace(__and__) __ior__ = _inplace(__or__) __ixor__ = _inplace(__xor__) __index__ = lambda self: self.item().__index__() __bool__ = lambda self: bool(self.item()) __int__ = lambda self: int(self.item()) __float__ = lambda self: float(self.item()) __complex__ = lambda self: complex(self.item()) def __len__(self): shape = self._tuple_shape if shape: return int(shape[0]) raise TypeError("ndim is 0") def __iter__(self): for i in range(len(self)): yield self[i] def __getitem__(self, index): return getitem_cpp(self, index) def __setitem__(self, index, value): if index is not Ellipsis: value = setitem_cpp(self, index, value) self._reset(value) __contains__ = _todo @property def ndim(self): r"""Returns the number of dimensions of self :class:`~.Tensor`.""" shape = self._tuple_shape if shape is None: raise ValueError("unkown ndim") return len(shape) @property def size(self): r"""Returns the size of the self :class:`~.Tensor`. The returned value is a subclass of :class:`tuple`. """ shape = self.shape if shape.__class__ is tuple: return np.prod(self.shape).item() return shape.prod() @property def T(self): r"""alias of :attr:`~.Tensor.transpose`.""" return self.transpose() def item(self, *args): r"""Returns the value of this :class:`~.Tensor` as a standard Python :class:`numbers.Number`. This only works for tensors with one element. For other cases, see :meth:`~.tolist`. """ if not args: if isinstance(self.size, int): assert self.size == 1 return self.numpy().item() return self[args].item() def tolist(self): r"""Returns the tensor as a (nested) list. For scalars, a standard Python number is returned, just like with :meth:`~.item`. Tensors are automatically moved to the CPU first if necessary. This operation is not differentiable. """ return self.numpy().tolist() def astype(self, dtype): r"""Returns a :class:`Tensor` with the same data and number of elements with the specified :attr:`~.Tensor.dtype`. """ return astype_cpp(self, dtype) def reshape(self, *args): r"""See :func:`~.reshape`.""" return reshape_cpp(self, args) # FIXME: remove this method def _broadcast(self, *args): return broadcast_cpp(self, args) def transpose(self, *args): r"""See :func:`~.transpose`.""" return transpose_cpp(self, args) def flatten(self): r"""See :func:`~.flatten`.""" return reshape_cpp(self, (-1,)) def sum(self, axis=None, keepdims: bool = False): r"""Returns the sum of each row of the input tensor in the given dimension ``axis``. If ``axis`` is a list of axises, reduce over all of them. If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`). Args: axis: the dimension or dimensions to reduce. keepdims: whether the output tensor has ndim retained or not. Returns: output tensor. Examples: .. testcode:: from megengine import tensor a = tensor([False, True, True, False]) b = tensor([1.0, 2.0, 3.0, 4.0]) print(a.sum().numpy()) print(b.sum().numpy()) Outputs: .. testoutput:: 2 10.0 """ return _reduce("sum")(self, axis, keepdims) def prod(self, axis=None, keepdims: bool = False): r"""Returns the product of each row of the input tensor in the given dimension ``axis``. If ``axis`` is a list of axises, reduce over all of them. If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`). Args: axis: the dimension or dimensions to reduce. keepdims: whether the output tensor has ndim retained or not. Returns: output tensor. Examples: .. testcode:: from megengine import tensor a = tensor([False, True, True, False]) b = tensor([1.0, 2.0, 3.0, 4.0]) print(a.prod().numpy()) print(b.prod().numpy()) Outputs: .. testoutput:: 0 24.0 """ return _reduce("product")(self, axis, keepdims) def min(self, axis=None, keepdims: bool = False): r"""Returns the min value of each row of the input tensor in the given dimension ``axis``. If ``axis`` is a list of axises, reduce over all of them. If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`). Args: axis: the dimension or dimensions to reduce. keepdims: whether the output tensor has ndim retained or not. Returns: output tensor. Examples: .. testcode:: from megengine import tensor a = tensor([False, True, True, False]) b = tensor([1.0, 2.0, 3.0, 4.0]) print(a.min().numpy()) print(b.min().numpy()) Outputs: .. testoutput:: False 1.0 """ return _reduce("min")(self, axis, keepdims) def max(self, axis=None, keepdims: bool = False): r"""Returns the max value of each row of the input tensor in the given dimension ``axis``. If ``axis`` is a list of axises, reduce over all of them. If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`). Args: axis: the dimension or dimensions to reduce. keepdims: whether the output tensor has ndim retained or not. Returns: output tensor. Examples: .. testcode:: from megengine import tensor a = tensor([False, True, True, False]) b = tensor([1.0, 2.0, 3.0, 4.0]) print(a.max().numpy()) print(b.max().numpy()) Outputs: .. testoutput:: True 4.0 """ return _reduce("max")(self, axis, keepdims) def mean(self, axis=None, keepdims: bool = False): r"""Returns the mean value of each row of the input tensor in the given dimension ``axis``. If ``axis`` is a list of axises, reduce over all of them. If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`). Args: axis: the dimension or dimensions to reduce. keepdims: whether the output tensor has ndim retained or not. Returns: output tensor. Examples: .. testcode:: from megengine import tensor a = tensor([False, True, True, False]) b = tensor([1.0, 2.0, 3.0, 4.0]) print(a.mean().numpy()) print(b.mean().numpy()) Outputs: .. testoutput:: 0.5 2.5 """ return _reduce("mean")(self, axis, keepdims)