use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; use crate::{ device::Device, dim::{larger_shape, DimDyn, DimTrait}, matrix::{Matrix, Owned, Ref, Repr}, num::Num, }; macro_rules! call_on_self { ($self:ident, $F:ident, $($args:expr),*) => { $self.$F($($args),*) }; } macro_rules! impl_arithmetic_ops { ( $trait:ident, $trait_method:ident, $assign_trait:ident, $assign_trait_method:ident, $scalr:ident, $scalar_assign:ident, $array:ident, $array_assign:ident ) => { // Add for Matrix impl, D: Device> $trait for Matrix { type Output = Matrix, DimDyn, D>; fn $trait_method(self, rhs: T) -> Self::Output { let s = self.to_ref().into_dyn_dim(); let mut owned = Matrix::alloc_like(&self).into_dyn_dim(); { let mut ref_mut = owned.to_ref_mut(); call_on_self!(ref_mut, $scalr, &s, rhs); } owned } } impl, D: Device> $trait for &Matrix { type Output = Matrix, DimDyn, D>; fn $trait_method(self, rhs: T) -> Self::Output { let mut owned = Matrix::alloc_like(self); let s = self.to_ref().into_dyn_dim(); { let mut ref_mut = owned.to_ref_mut(); call_on_self!(ref_mut, $scalr, &s, rhs); } owned } } // Add> for Matrix impl< T: Num, RS: Repr, SS: DimTrait, RO: Repr, SO: DimTrait, D: Device, > $trait> for Matrix { type Output = Matrix, DimDyn, D>; fn $trait_method(self, rhs: Matrix) -> Self::Output { let larger = if self.shape().len() == rhs.shape().len() { DimDyn::from(larger_shape(self.shape(), rhs.shape())) } else if self.shape().len() > rhs.shape().len() { DimDyn::from(self.shape().slice()) } else { DimDyn::from(rhs.shape().slice()) }; let mut owned: Matrix, DimDyn, D> = Matrix::alloc(larger.slice()); { let mut ref_mut = owned.to_ref_mut(); let s = self.to_ref().into_dyn_dim(); let rhs = rhs.to_ref().into_dyn_dim(); call_on_self!(ref_mut, $array, &s, &rhs); } owned } } impl< T: Num, RS: Repr, SS: DimTrait, RO: Repr, SO: DimTrait, D: Device, > $trait<&Matrix> for Matrix { type Output = Matrix, DimDyn, D>; fn $trait_method(self, rhs: &Matrix) -> Self::Output { let larger = if self.shape().len() == rhs.shape().len() { DimDyn::from(larger_shape(self.shape(), rhs.shape())) } else if self.shape().len() > rhs.shape().len() { DimDyn::from(self.shape().slice()) } else { DimDyn::from(rhs.shape().slice()) }; let mut owned: Matrix, DimDyn, D> = Matrix::alloc(larger.slice()); { let mut ref_mut = owned.to_ref_mut(); let s = self.to_ref().into_dyn_dim(); let rhs = rhs.to_ref().into_dyn_dim(); call_on_self!(ref_mut, $array, &s, &rhs); } owned } } impl< T: Num, RS: Repr, SS: DimTrait, RO: Repr, SO: DimTrait, D: Device, > $trait> for &Matrix { type Output = Matrix, DimDyn, D>; fn $trait_method(self, rhs: Matrix) -> Self::Output { let larger = if self.shape().len() == rhs.shape().len() { DimDyn::from(larger_shape(self.shape(), rhs.shape())) } else if self.shape().len() > rhs.shape().len() { DimDyn::from(self.shape().slice()) } else { DimDyn::from(rhs.shape().slice()) }; let mut owned: Matrix, DimDyn, D> = Matrix::alloc(larger.slice()); { let mut ref_mut = owned.to_ref_mut(); let s = self.to_ref().into_dyn_dim(); let rhs = rhs.to_ref().into_dyn_dim(); call_on_self!(ref_mut, $array, &s, &rhs); } owned } } impl< T: Num, RS: Repr, SS: DimTrait, RO: Repr, SO: DimTrait, D: Device, > $trait<&Matrix> for &Matrix { type Output = Matrix, DimDyn, D>; fn $trait_method(self, rhs: &Matrix) -> Self::Output { let larger = if self.shape().len() == rhs.shape().len() { DimDyn::from(larger_shape(self.shape(), rhs.shape())) } else if self.shape().len() > rhs.shape().len() { DimDyn::from(self.shape().slice()) } else { DimDyn::from(rhs.shape().slice()) }; let mut owned: Matrix, DimDyn, D> = Matrix::alloc(larger.slice()); { let mut ref_mut = owned.to_ref_mut(); let s = self.to_ref().into_dyn_dim(); let rhs = rhs.to_ref().into_dyn_dim(); call_on_self!(ref_mut, $array, &s, &rhs); } owned } } // AddAssign for Matrix, S, D> // impl $assign_trait for Matrix, S, D> { impl $assign_trait for Matrix, DimDyn, D> { fn $assign_trait_method(&mut self, rhs: T) { call_on_self!(self, $scalar_assign, rhs); } } // AddAssign for Matrix, S, D> impl $assign_trait for Matrix, S, D> { fn $assign_trait_method(&mut self, rhs: T) { let mut ref_mut = self.to_ref_mut().into_dyn_dim(); call_on_self!(ref_mut, $scalar_assign, rhs); } } // AddAssign> for Matrix, SS, D> impl, SO: DimTrait, D: Device> $assign_trait> for Matrix, SS, D> { fn $assign_trait_method(&mut self, rhs: Matrix) { let mut ref_mut = self.to_ref_mut().into_dyn_dim(); let rhs = rhs.to_ref().into_dyn_dim(); call_on_self!(ref_mut, $array_assign, &rhs); } } // AddAssign> for Matrix, SS, D> // impl, SO: DimTrait, SS: DimTrait, D: Device> impl, SO: DimTrait, D: Device> $assign_trait> for Matrix, DimDyn, D> { fn $assign_trait_method(&mut self, rhs: Matrix) { let rhs = rhs.to_ref().into_dyn_dim(); call_on_self!(self, $array_assign, &rhs); } } }; } impl_arithmetic_ops!( Add, add, AddAssign, add_assign, add_scalar, add_scalar_assign, add_array, add_assign ); impl_arithmetic_ops!( Sub, sub, SubAssign, sub_assign, sub_scalar, sub_scalar_assign, sub_array, sub_assign ); impl_arithmetic_ops!( Mul, mul, MulAssign, mul_assign, mul_scalar, mul_scalar_assign, mul_array, mul_assign ); impl_arithmetic_ops!( Div, div, DivAssign, div_assign, div_scalar, div_scalar_assign, div_array, div_assign );