################################################################################################# # # Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # 1. Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # 3. Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
#################################################################################################

"""
Definition of CuTe Layouts and functions to manipulate them
"""

from itertools import chain
from typing import Union

from .int_tuple import *


class LayoutBase:
    """Marker base class; instances of any subclass are recognised by is_layout()."""
    pass


def is_layout(x):
    """Return True if ``x`` is an instance of LayoutBase (or a subclass)."""
    return isinstance(x, LayoutBase)


class Layout(LayoutBase):
    """A CuTe layout: a (possibly nested) shape paired with a congruent stride.

    Calling a layout with a coordinate maps it to a linear index; calling it with
    a coordinate containing ``None`` slices it (see ``__call__``).
    """

    def __init__(self, _shape, _stride=None):
        """
        Args:
            _shape: int or (nested) tuple of ints.
            _stride: int or (nested) tuple with the same structure as ``_shape``.
                When omitted it is computed as ``prefix_product(_shape)``
                (defined in ``int_tuple``; presumably the compact column-major
                strides -- confirm there).
        """
        self.shape = _shape
        if _stride is None:
            self.stride = prefix_product(self.shape)
        else:
            self.stride = _stride

    # operator ==
    def __eq__(self, other):
        """Layouts are equal when both their shapes and their strides compare equal."""
        # Comparing against a non-Layout (None, an int, ...) used to raise
        # AttributeError; NotImplemented lets Python fall back to identity (False).
        if not isinstance(other, Layout):
            return NotImplemented
        return self.shape == other.shape and self.stride == other.stride

    # operator len(L) (len [rank] like tuples)
    def __len__(self):
        """Rank of the layout: number of top-level modes (1 for an integral shape)."""
        if is_tuple(self.shape):
            return len(self.shape)
        return 1

    # operator () (map coord to idx)
    def __call__(self, *args):
        """
        Map a logical coordinate to a linear index (Coord has no Underscore slice operators)
        OR
        Slice the layout and return the sublayout (Coord has an Underscore slice op)

        A ``None`` entry in the coordinate plays the role of CuTe's Underscore.
        A single argument is used as the coordinate itself; several arguments are
        treated as a tuple coordinate.

        Follow the same behavior of `Layout::operator(Coord const&)` in cute C++
        """
        if has_none(args):
            if len(args) == 1:
                return Layout(slice_(args[0], self.shape), slice_(args[0], self.stride))
            else:
                return Layout(slice_(args, self.shape), slice_(args, self.stride))
        else:
            if len(args) == 1:
                return crd2idx(args[0], self.shape, self.stride)
            else:
                return crd2idx(args, self.shape, self.stride)

    # operator [] (get-i like tuples)
    def __getitem__(self, i):
        """Return mode ``i`` as a Layout.

        A layout with an integral shape has exactly one mode (itself). Any other
        index raises IndexError, which also lets ``for mode in layout`` terminate.
        The previous ``assert`` was stripped under ``python -O``, so iteration never
        ended; without ``-O`` it raised AssertionError.
        """
        if is_tuple(self.shape):
            return Layout(self.shape[i], self.stride[i])
        if i not in (0, -1):
            raise IndexError(f"Layout index {i} out of range for rank-1 layout")
        return Layout(self.shape, self.stride)

    # size(layout)   Size of the domain
    def size(self):
        """Number of coordinates in the domain: the product of all shape entries."""
        return product(self.shape)

    # cosize(layout)   Size of the codomain
    def cosize(self):
        """One past the index the last coordinate maps to.

        NOTE(review): this equals the codomain size only when the last coordinate
        maps to the largest index (e.g. non-negative strides) -- confirm for callers
        that use negative strides.
        """
        return self(self.size() - 1) + 1

    # print and str
    def __str__(self):
        return f"{self.shape}:{self.stride}"

    # error msgs and representation
    def __repr__(self):
        return f"Layout({self.shape},{self.stride})"


# Make Layout from a list of layouts (each layout its own mode in the result)
def make_layout(*layouts):
    """Concatenate layouts so that each becomes one mode of the result.

    Accepts either several Layout arguments or a single iterable of Layouts.
    Calling it with no layouts (or an empty iterable) yields the empty layout
    ``Layout((), ())`` instead of crashing on tuple unpacking.
    """
    if len(layouts) == 1 and not is_layout(layouts[0]):
        layouts = layouts[0]
    # Materialise so that generators can be checked for emptiness.
    layouts = tuple(layouts)
    if not layouts:
        return Layout((), ())
    shape, stride = zip(*((a.shape, a.stride) for a in layouts))
    return Layout(shape, stride)


# Size of the domain
def size(layout):
    """Domain size of a Layout, or the product of a plain (int-)tuple shape."""
    if is_layout(layout):
        return layout.size()
    return product(layout)


# Size of the codomain
def cosize(layout):
    """Codomain size of ``layout`` (delegates to ``Layout.cosize``)."""
    return layout.cosize()


# Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function
def coalesce(layout, profile=None):
    """Flatten ``layout`` and merge adjacent modes where shape*stride continues.

    Shape-1 modes are dropped. A mode whose stride equals the previous mode's
    shape*stride is merged into it.

    Args:
        layout: the Layout to coalesce.
        profile: if a tuple, coalesce mode-by-mode (mode i of ``layout`` with
            ``profile[i]``); modes beyond ``len(profile)`` are passed through.

    Returns:
        A Layout with an integral shape if a single mode remains, otherwise a flat
        tuple shape. An all-shape-1 input yields ``Layout(1, 0)``.
    """
    if is_tuple(profile):
        assert len(layout) >= len(profile)
        return make_layout(chain(
            (coalesce(layout[i], profile[i]) for i in range(len(profile))),
            (layout[i] for i in range(len(profile), len(layout)))))

    # Start with a placeholder 1:0 mode that the first real mode replaces.
    result_shape = [1]
    result_stride = [0]
    for (shape, stride) in zip(flatten(layout.shape), flatten(layout.stride)):
        # skip their shape-1s
        if shape == 1:
            continue
        # replace our shape-1 with anything
        elif result_shape[-1] == 1:
            result_shape[-1] = shape
            result_stride[-1] = stride
        # merge modes if the shape*stride match
        elif result_shape[-1] * result_stride[-1] == stride:
            result_shape[-1] = result_shape[-1] * shape
        # append a new mode
        else:
            result_shape.append(shape)
            result_stride.append(stride)

    if len(result_shape) == 1:
        return Layout(result_shape[0], result_stride[0])
    return Layout(tuple(result_shape), tuple(result_stride))


# Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them
def filter(layout, profile=None):
    """Drop every shape-1 and stride-0 mode of ``layout``, then coalesce the rest.

    Args:
        layout: the Layout to filter.
        profile: if a tuple, filter mode-by-mode; modes beyond ``len(profile)``
            are passed through unchanged.

    Returns:
        The coalesced remaining modes, or ``Layout(1, 0)`` if none remain.
    """
    if is_tuple(profile):
        assert len(layout) >= len(profile)
        return make_layout(chain(
            (filter(layout[i], profile[i]) for i in range(len(profile))),
            (layout[i] for i in range(len(profile), len(layout)))))

    result_shape = []
    result_stride = []
    for (shape, stride) in zip(flatten(layout.shape), flatten(layout.stride)):
        # skip their shape-1s and stride-0s
        if not (shape == 1 or stride == 0):
            result_shape.append(shape)
            result_stride.append(stride)

    if len(result_shape) == 0:
        return Layout(1, 0)
    return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
# Layout composition
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
def composition(layoutA, layoutB):
    """Compose two layouts: the result maps a coordinate c to ``layoutA(layoutB(c))``.

    Args:
        layoutA: the outer Layout.
        layoutB: ``None`` (no-op, returns ``layoutA``), an int (treated as
            ``Layout(layoutB)``), a tuple of per-mode tilers (composed with the
            matching mode of ``layoutA``; extra modes of ``layoutA`` pass through),
            or a Layout.
    """
    if layoutB is None:
        return layoutA
    elif is_int(layoutB):
        return composition(layoutA, Layout(layoutB))
    elif is_tuple(layoutB):
        assert len(layoutA) >= len(layoutB)
        return make_layout(chain(
            (composition(layoutA[i], layoutB[i]) for i in range(len(layoutB))),
            (layoutA[i] for i in range(len(layoutB), len(layoutA)))))
    elif is_tuple(layoutB.shape):
        # A hierarchical B is composed mode-by-mode; each mode becomes a result mode.
        return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB)

    if layoutB.stride == 0:
        return Layout(layoutB.shape, 0)
    else:
        result_shape = []
        result_stride = []
        rest_shape = layoutB.shape
        rest_stride = layoutB.stride
        # Walk all but the last flattened mode of A, carving B's (shape, stride)
        # out of each A mode; whatever remains lands in A's last mode below.
        for (s, d) in zip(flatten(layoutA.shape)[:-1], flatten(layoutA.stride)[:-1]):
            s1 = shape_div(s, rest_stride)
            result_shape.append(min(s1, rest_shape))
            result_stride.append(rest_stride * d)
            rest_shape = shape_div(rest_shape, abs(s1))
            rest_stride = shape_div(rest_stride, s)

        result_shape.append(rest_shape)
        result_stride.append(rest_stride * flatten(layoutA.stride)[-1])

        return coalesce(Layout(tuple(result_shape), tuple(result_stride)))


# Layout complement
def complement(layout, max_idx=1):
    """Return the layout that fills the index gaps left by ``layout``.

    Modes are visited in increasing stride order (stride-0 and shape-1 modes are
    skipped). For each gap below a mode a new mode is emitted. A final mode is
    appended whose extent reaches at least ``max_idx`` (ceil-division).

    Args:
        layout: a Layout or an int (treated as ``Layout(layout)``).
        max_idx: the index bound the complement should cover; forwarded for the
            int case too (it used to be silently reset to 1).
    """
    if is_int(layout):
        return complement(Layout(layout), max_idx)

    result_shape = []
    result_stride = []
    current_idx = 1

    sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape)))
    for (stride, shape) in sorted_DS:
        if stride == 0 or shape == 1:
            continue

        in_bound = current_idx <= shape * stride
        # To support symbolic value which can't be evaluated now
        assert (type(in_bound) is not bool) or in_bound

        result_shape.append(stride // current_idx)
        result_stride.append(current_idx)
        current_idx = shape * stride

    result_shape.append((max_idx + current_idx - 1) // current_idx)  # ceil_div
    result_stride.append(current_idx)

    return coalesce(Layout(tuple(result_shape), tuple(result_stride)))


# Layout right inverse
def right_inverse(layout):
    """Return a layout R such that ``layout(R(i)) == i`` over R's domain.

    Flattened modes are visited in increasing stride order. Only the contiguous
    run is collected: each stride must equal the running product of the modes
    taken so far, and the walk stops at the first gap. Each collected mode is
    paired with its ``prefix_product`` position stride.

    ``None`` returns ``None``; an int ``n`` returns ``Layout(n)``.
    """
    if layout is None:
        return None
    elif is_int(layout):
        return Layout(layout)

    result_shape = []
    result_stride = []
    current_idx = 1

    flat_shape = flatten(layout.shape)
    flat_stride = flatten(layout.stride)
    sorted_DSA = sorted(zip(flat_stride, flat_shape, prefix_product(flat_shape)))
    for (stride, shape, rstride) in sorted_DSA:
        if shape == 1:
            continue
        if current_idx != stride:
            break

        result_shape.append(shape)
        result_stride.append(rstride)
        current_idx = shape * stride

    return coalesce(Layout(tuple(result_shape), tuple(result_stride)))


# Layout left inverse
def left_inverse(layout):
    """Return the right inverse of ``(layout, complement(layout))``.

    ``None`` returns ``None``; an int ``n`` returns ``Layout(n)``.
    """
    if layout is None:
        return None
    elif is_int(layout):
        return Layout(layout)
    return right_inverse(make_layout(layout, complement(layout)))


# Split a layout by the composition of B and the "rest"
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
def logical_divide(layoutA, layoutB):
    """Compose ``layoutA`` with ``(layoutB, complement(layoutB, size(layoutA)))``.

    The result's first mode holds B's selection of A; the second holds the rest.
    ``None``, int and tuple ``layoutB`` behave as in ``composition``.
    """
    if layoutB is None:
        return layoutA
    elif is_int(layoutB):
        return logical_divide(layoutA, Layout(layoutB))
    elif is_tuple(layoutB):
        assert len(layoutA) >= len(layoutB)
        return make_layout(chain(
            (logical_divide(layoutA[i], layoutB[i]) for i in range(len(layoutB))),
            (layoutA[i] for i in range(len(layoutB), len(layoutA)))))

    return composition(layoutA, make_layout(layoutB, complement(layoutB, size(layoutA))))


# Reproduce a layoutA over a layoutB
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
def logical_product(layoutA, layoutB):
    """Return ``(layoutA, complement(layoutA, size(A)*cosize(B)) o layoutB)``.

    ``None`` returns ``layoutA``; a tuple ``layoutB`` is applied mode-by-mode.
    An int ``layoutB`` is treated as ``Layout(layoutB)``. It previously dispatched
    to ``logical_divide`` by mistake.
    """
    if layoutB is None:
        return layoutA
    elif is_int(layoutB):
        return logical_product(layoutA, Layout(layoutB))
    elif is_tuple(layoutB):
        assert len(layoutA) >= len(layoutB)
        return make_layout(chain(
            (logical_product(layoutA[i], layoutB[i]) for i in range(len(layoutB))),
            (layoutA[i] for i in range(len(layoutB), len(layoutA)))))

    return make_layout(layoutA, composition(complement(layoutA, size(layoutA) * cosize(layoutB)), layoutB))


# Gather the modes from a hierarchical logical_divide or logical_product
def hier_unzip(splitter, layoutA, layoutB):
    """Apply ``splitter`` hierarchically and gather the result into two modes.

    Args:
        splitter: a function ``(layoutA, layoutB) -> rank-2 Layout``, e.g.
            ``logical_divide`` or ``logical_product``.
        layoutA: the Layout being split.
        layoutB: ``None`` (result is ``(1:0, layoutA)``), a tuple of per-mode
            tilers, or anything ``splitter`` accepts.

    Returns:
        A rank-2 Layout ``((A,B,C,...), (a,b,c,...,y,z))``. The first mode gathers
        mode 0 of every per-mode split. The second gathers mode 1 of every split,
        followed by the untouched trailing modes of ``layoutA``.
    """
    if layoutB is None:
        return make_layout(Layout(1, 0), layoutA)
    elif is_tuple(layoutB):
        assert len(layoutA) >= len(layoutB)
        # A layout with shape ((A,a),(B,b),(C,c))
        split = make_layout(hier_unzip(splitter, layoutA[i], layoutB[i]) for i in range(len(layoutB)))
        # Gather to shape ((A,B,C,...),(a,b,c,...,y,z))
        return make_layout(
            make_layout(split[i][0] for i in range(len(layoutB))),
            make_layout(chain(
                (split[i][1] for i in range(len(layoutB))),
                (layoutA[i] for i in range(len(layoutB), len(layoutA))))))

    # splitter must return a rank-2 layout
    return splitter(layoutA, layoutB)


# Apply logical divide hierarchically and gather the split modes into two modes
def zipped_divide(layoutA, layoutB):
    """``hier_unzip`` with ``logical_divide``: returns ``(tile, rest)`` modes."""
    return hier_unzip(logical_divide, layoutA, layoutB)


# Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode
def tiled_divide(layoutA, layoutB):
    """Like ``zipped_divide``, but the rest modes are unpacked to top level."""
    result = zipped_divide(layoutA, layoutB)
    return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))])


# Apply logical product hierarchically and gather the split modes into two modes
def zipped_product(layoutA, layoutB):
    """``hier_unzip`` with ``logical_product``: returns a rank-2 layout."""
    return hier_unzip(logical_product, layoutA, layoutB)


# Perform logical product hierarchically and gather tiles (B-layouts) into a new mode
def tiled_product(layoutA, layoutB):
    """Like ``zipped_product``, but the second mode's sub-modes are unpacked to top level."""
    result = zipped_product(layoutA, layoutB)
    return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))])


def slice_and_offset(crd: tuple, layout: Layout):
    """Return ``(sublayout, offset)`` for coordinate ``crd``.

    The sublayout is ``layout`` sliced by ``crd`` (``None`` entries are kept).
    The offset is ``crd2idx`` of ``crd`` in ``layout``.
    """
    return (Layout(slice_(crd, layout.shape), slice_(crd, layout.stride)),
            crd2idx(crd, layout.shape, layout.stride))