# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This is a set of function wrappers that override the default numpy versions.

Interoperability functions for pytorch and Faiss: Importing this will allow
pytorch Tensors (CPU or GPU) to be used as arguments to Faiss indexes and
other functions. Torch GPU tensors can only be used with Faiss GPU indexes.
If this is imported with a package that supports Faiss GPU, the necessary
stream synchronization with the current pytorch stream will be automatically
performed.

Numpy ndarrays can continue to be used in the Faiss python interface after
importing this file. All arguments must be uniformly either numpy ndarrays
or Torch tensors; no mixing is allowed.
"""

import contextlib
import inspect
import sys

import numpy as np
import torch

import faiss


def swig_ptr_from_UInt8Tensor(x):
    """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.uint8
    # uint8 elements are 1 byte each, so the storage offset is already a
    # byte offset
    return faiss.cast_integer_to_uint8_ptr(
        x.storage().data_ptr() + x.storage_offset())


def swig_ptr_from_HalfTensor(x):
    """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.float16
    # no canonical half type in C/C++
    # BUGFIX: float16 elements are 2 bytes wide; the storage offset must be
    # scaled by 2 (the previous code scaled by 4, producing a wrong pointer
    # for any tensor view with a non-zero storage offset)
    return faiss.cast_integer_to_void_ptr(
        x.storage().data_ptr() + x.storage_offset() * 2)


def swig_ptr_from_FloatTensor(x):
    """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.float32
    # float32 elements are 4 bytes wide
    return faiss.cast_integer_to_float_ptr(
        x.storage().data_ptr() + x.storage_offset() * 4)


def swig_ptr_from_IntTensor(x):
    """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.int32, 'dtype=%s' % x.dtype
    # BUGFIX: int32 elements are 4 bytes wide; the storage offset must be
    # scaled by 4 (the previous code scaled by 8, producing a wrong pointer
    # for any tensor view with a non-zero storage offset)
    return faiss.cast_integer_to_int_ptr(
        x.storage().data_ptr() + x.storage_offset() * 4)


def swig_ptr_from_IndicesTensor(x):
    """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
    # int64 elements are 8 bytes wide
    return faiss.cast_integer_to_idx_t_ptr(
        x.storage().data_ptr() + x.storage_offset() * 8)


@contextlib.contextmanager
def using_stream(res, pytorch_stream=None):
    """ Creates a scoping object to make Faiss GPU use the same stream
        as pytorch, based on torch.cuda.current_stream().
        Or, a specific pytorch stream can be passed in as a second
        argument, in which case we will use that stream.
    """

    if pytorch_stream is None:
        pytorch_stream = torch.cuda.current_stream()

    # This is the cudaStream_t that we wish to use
    cuda_stream_s = faiss.cast_integer_to_cudastream_t(pytorch_stream.cuda_stream)

    # So we can revert GpuResources stream state upon exit
    prior_dev = torch.cuda.current_device()
    prior_stream = res.getDefaultStream(torch.cuda.current_device())

    res.setDefaultStream(torch.cuda.current_device(), cuda_stream_s)

    # Do the user work
    try:
        yield
    finally:
        res.setDefaultStream(prior_dev, prior_stream)


def torch_replace_method(the_class, name, replacement,
                         ignore_missing=False, ignore_no_base=False):
    """ Patches `name` on `the_class` with `replacement`, keeping the
        previously-installed (numpy) method reachable as `<name>_numpy`.

        ignore_missing: silently skip classes that lack the method.
        ignore_no_base: allow patching a method that was not already
        wrapped by the numpy replacement machinery.
    """
    try:
        orig_method = getattr(the_class, name)
    except AttributeError:
        if ignore_missing:
            return
        raise
    if orig_method.__name__ == 'torch_replacement_' + name:
        # replacement was done in parent class
        return
    # We should already have the numpy replacement methods patched
    assert ignore_no_base or (orig_method.__name__ == 'replacement_' + name)
    setattr(the_class, name + '_numpy', orig_method)
    setattr(the_class, name, replacement)


def handle_torch_Index(the_class):
    """ Installs torch-aware replacements for the standard Index methods
        on `the_class`. Each replacement forwards numpy inputs to the
        original numpy wrapper and handles torch CPU/GPU tensors itself.
    """

    def torch_replacement_add(self, x):
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.add_numpy(x)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.add_c(n, x_ptr)
        else:
            # CPU torch
            self.add_c(n, x_ptr)

    def torch_replacement_add_with_ids(self, x, ids):
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.add_with_ids_numpy(x, ids)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        assert type(ids) is torch.Tensor
        assert ids.shape == (n, ), 'not same number of vectors as ids'
        ids_ptr = swig_ptr_from_IndicesTensor(ids)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.add_with_ids_c(n, x_ptr, ids_ptr)
        else:
            # CPU torch
            self.add_with_ids_c(n, x_ptr, ids_ptr)

    def torch_replacement_assign(self, x, k, labels=None):
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.assign_numpy(x, k, labels)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if labels is None:
            labels = torch.empty(n, k, device=x.device, dtype=torch.int64)
        else:
            assert type(labels) is torch.Tensor
            assert labels.shape == (n, k)
        L_ptr = swig_ptr_from_IndicesTensor(labels)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.assign_c(n, x_ptr, L_ptr, k)
        else:
            # CPU torch
            self.assign_c(n, x_ptr, L_ptr, k)

        return labels

    def torch_replacement_train(self, x):
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.train_numpy(x)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.train_c(n, x_ptr)
        else:
            # CPU torch
            self.train_c(n, x_ptr)

    def torch_replacement_search(self, x, k, D=None, I=None):
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.search_numpy(x, k, D, I)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if D is None:
            D = torch.empty(n, k, device=x.device, dtype=torch.float32)
        else:
            assert type(D) is torch.Tensor
            assert D.shape == (n, k)
        D_ptr = swig_ptr_from_FloatTensor(D)

        if I is None:
            I = torch.empty(n, k, device=x.device, dtype=torch.int64)
        else:
            assert type(I) is torch.Tensor
            assert I.shape == (n, k)
        I_ptr = swig_ptr_from_IndicesTensor(I)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.search_c(n, x_ptr, k, D_ptr, I_ptr)
        else:
            # CPU torch
            self.search_c(n, x_ptr, k, D_ptr, I_ptr)

        return D, I

    def torch_replacement_search_and_reconstruct(self, x, k, D=None, I=None, R=None):
        if type(x) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.search_and_reconstruct_numpy(x, k, D, I, R)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if D is None:
            D = torch.empty(n, k, device=x.device, dtype=torch.float32)
        else:
            assert type(D) is torch.Tensor
            assert D.shape == (n, k)
        D_ptr = swig_ptr_from_FloatTensor(D)

        if I is None:
            I = torch.empty(n, k, device=x.device, dtype=torch.int64)
        else:
            assert type(I) is torch.Tensor
            assert I.shape == (n, k)
        I_ptr = swig_ptr_from_IndicesTensor(I)

        if R is None:
            R = torch.empty(n, k, d, device=x.device, dtype=torch.float32)
        else:
            assert type(R) is torch.Tensor
            assert R.shape == (n, k, d)
        R_ptr = swig_ptr_from_FloatTensor(R)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.search_and_reconstruct_c(n, x_ptr, k, D_ptr, I_ptr, R_ptr)
        else:
            # CPU torch
            self.search_and_reconstruct_c(n, x_ptr, k, D_ptr, I_ptr, R_ptr)

        return D, I, R

    def torch_replacement_remove_ids(self, x):
        # Not yet implemented
        assert type(x) is not torch.Tensor, 'remove_ids not yet implemented for torch'
        return self.remove_ids_numpy(x)

    def torch_replacement_reconstruct(self, key, x=None):
        # No tensor inputs are required, but with importing this module, we
        # assume that the default should be torch tensors. If we are passed a
        # numpy array, however, assume that the user is overriding this
        # default
        if (x is not None) and (type(x) is np.ndarray):
            # Forward to faiss __init__.py base method
            return self.reconstruct_numpy(key, x)

        # If the index is a CPU index, the default device is CPU, otherwise we
        # produce a GPU tensor
        device = torch.device('cpu')
        if hasattr(self, 'getDevice'):
            # same device as the index
            device = torch.device('cuda', self.getDevice())

        if x is None:
            x = torch.empty(self.d, device=device, dtype=torch.float32)
        else:
            assert type(x) is torch.Tensor
            assert x.shape == (self.d, )
        x_ptr = swig_ptr_from_FloatTensor(x)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.reconstruct_c(key, x_ptr)
        else:
            # CPU torch
            self.reconstruct_c(key, x_ptr)

        return x

    def torch_replacement_reconstruct_n(self, n0, ni, x=None):
        # No tensor inputs are required, but with importing this module, we
        # assume that the default should be torch tensors. If we are passed a
        # numpy array, however, assume that the user is overriding this
        # default
        if (x is not None) and (type(x) is np.ndarray):
            # Forward to faiss __init__.py base method
            return self.reconstruct_n_numpy(n0, ni, x)

        # If the index is a CPU index, the default device is CPU, otherwise we
        # produce a GPU tensor
        device = torch.device('cpu')
        if hasattr(self, 'getDevice'):
            # same device as the index
            device = torch.device('cuda', self.getDevice())

        if x is None:
            x = torch.empty(ni, self.d, device=device, dtype=torch.float32)
        else:
            assert type(x) is torch.Tensor
            assert x.shape == (ni, self.d)
        x_ptr = swig_ptr_from_FloatTensor(x)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.reconstruct_n_c(n0, ni, x_ptr)
        else:
            # CPU torch
            self.reconstruct_n_c(n0, ni, x_ptr)

        return x

    def torch_replacement_update_vectors(self, keys, x):
        if type(keys) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.update_vectors_numpy(keys, x)

        assert type(keys) is torch.Tensor
        (n, ) = keys.shape
        keys_ptr = swig_ptr_from_IndicesTensor(keys)

        assert type(x) is torch.Tensor
        assert x.shape == (n, self.d)
        x_ptr = swig_ptr_from_FloatTensor(x)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.update_vectors_c(n, keys_ptr, x_ptr)
        else:
            # CPU torch
            self.update_vectors_c(n, keys_ptr, x_ptr)

    # Until the GPU version is implemented, we do not support pre-allocated
    # output buffers
    def torch_replacement_range_search(self, x, thresh):
        if type(x) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.range_search_numpy(x, thresh)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        assert not x.is_cuda, 'Range search using GPU tensor not yet implemented'
        assert not hasattr(self, 'getDevice'), 'Range search on GPU index not yet implemented'

        res = faiss.RangeSearchResult(n)
        self.range_search_c(n, x_ptr, thresh, res)

        # get pointers and copy them
        # FIXME: no rev_swig_ptr equivalent for torch.Tensor, just convert
        # np to torch
        # NOTE: torch does not support np.uint64, just np.int64
        lims = torch.from_numpy(faiss.rev_swig_ptr(res.lims, n + 1).copy().astype('int64'))
        nd = int(lims[-1])
        D = torch.from_numpy(faiss.rev_swig_ptr(res.distances, nd).copy())
        I = torch.from_numpy(faiss.rev_swig_ptr(res.labels, nd).copy())
        return lims, D, I

    def torch_replacement_sa_encode(self, x, codes=None):
        if type(x) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.sa_encode_numpy(x, codes)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if codes is None:
            codes = torch.empty(n, self.sa_code_size(), dtype=torch.uint8)
        else:
            assert codes.shape == (n, self.sa_code_size())
        codes_ptr = swig_ptr_from_UInt8Tensor(codes)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.sa_encode_c(n, x_ptr, codes_ptr)
        else:
            # CPU torch
            self.sa_encode_c(n, x_ptr, codes_ptr)

        return codes

    def torch_replacement_sa_decode(self, codes, x=None):
        if type(codes) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.sa_decode_numpy(codes, x)

        assert type(codes) is torch.Tensor
        n, cs = codes.shape
        assert cs == self.sa_code_size()
        codes_ptr = swig_ptr_from_UInt8Tensor(codes)

        if x is None:
            x = torch.empty(n, self.d, dtype=torch.float32)
        else:
            assert type(x) is torch.Tensor
            assert x.shape == (n, self.d)
        x_ptr = swig_ptr_from_FloatTensor(x)

        if codes.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.sa_decode_c(n, codes_ptr, x_ptr)
        else:
            # CPU torch
            self.sa_decode_c(n, codes_ptr, x_ptr)

        return x

    torch_replace_method(the_class, 'add', torch_replacement_add)
    torch_replace_method(the_class, 'add_with_ids', torch_replacement_add_with_ids)
    torch_replace_method(the_class, 'assign', torch_replacement_assign)
    torch_replace_method(the_class, 'train', torch_replacement_train)
    torch_replace_method(the_class, 'search', torch_replacement_search)
    torch_replace_method(the_class, 'remove_ids', torch_replacement_remove_ids)
    torch_replace_method(the_class, 'reconstruct', torch_replacement_reconstruct)
    torch_replace_method(the_class, 'reconstruct_n', torch_replacement_reconstruct_n)
    torch_replace_method(the_class, 'range_search', torch_replacement_range_search)
    torch_replace_method(the_class, 'update_vectors', torch_replacement_update_vectors,
                         ignore_missing=True)
    torch_replace_method(the_class, 'search_and_reconstruct',
                         torch_replacement_search_and_reconstruct, ignore_missing=True)
    torch_replace_method(the_class, 'sa_encode', torch_replacement_sa_encode)
    torch_replace_method(the_class, 'sa_decode', torch_replacement_sa_decode)


faiss_module = sys.modules['faiss']

# Re-patch anything that inherits from faiss.Index to add the torch bindings
for symbol in dir(faiss_module):
    obj = getattr(faiss_module, symbol)
    if inspect.isclass(obj):
        the_class = obj
        if issubclass(the_class, faiss.Index):
            handle_torch_Index(the_class)


# allows torch tensor usage with bfKnn
def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2):
    """ Brute-force k-nearest-neighbor search on the GPU via faiss.bfKnn,
        accepting torch tensors (f32 or f16, row- or column-major) for the
        query and database matrices. Forwards numpy inputs to the numpy
        wrapper. Returns (D, I): distances (f32) and indices (i64 or i32).
    """
    if type(xb) is np.ndarray:
        # Forward to faiss __init__.py base method
        return faiss.knn_gpu_numpy(res, xq, xb, k, D, I, metric)

    nb, d = xb.size()
    if xb.is_contiguous():
        xb_row_major = True
    elif xb.t().is_contiguous():
        xb = xb.t()
        xb_row_major = False
    else:
        raise TypeError('matrix should be row or column-major')

    if xb.dtype == torch.float32:
        xb_type = faiss.DistanceDataType_F32
        xb_ptr = swig_ptr_from_FloatTensor(xb)
    elif xb.dtype == torch.float16:
        xb_type = faiss.DistanceDataType_F16
        xb_ptr = swig_ptr_from_HalfTensor(xb)
    else:
        raise TypeError('xb must be f32 or f16')

    nq, d2 = xq.size()
    assert d2 == d
    if xq.is_contiguous():
        xq_row_major = True
    elif xq.t().is_contiguous():
        xq = xq.t()
        xq_row_major = False
    else:
        raise TypeError('matrix should be row or column-major')

    if xq.dtype == torch.float32:
        xq_type = faiss.DistanceDataType_F32
        xq_ptr = swig_ptr_from_FloatTensor(xq)
    elif xq.dtype == torch.float16:
        xq_type = faiss.DistanceDataType_F16
        xq_ptr = swig_ptr_from_HalfTensor(xq)
    else:
        raise TypeError('xq must be f32 or f16')

    if D is None:
        D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
    else:
        assert D.shape == (nq, k)
        # interface takes void*, we need to check this
        assert (D.dtype == torch.float32)

    if I is None:
        I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
    else:
        assert I.shape == (nq, k)

    if I.dtype == torch.int64:
        I_type = faiss.IndicesDataType_I64
        I_ptr = swig_ptr_from_IndicesTensor(I)
    # BUGFIX: was `I.dtype == I.dtype == torch.int32` — the self-comparison
    # was a vacuous typo
    elif I.dtype == torch.int32:
        I_type = faiss.IndicesDataType_I32
        I_ptr = swig_ptr_from_IntTensor(I)
    else:
        raise TypeError('I must be i64 or i32')

    D_ptr = swig_ptr_from_FloatTensor(D)

    args = faiss.GpuDistanceParams()
    args.metric = metric
    args.k = k
    args.dims = d
    args.vectors = xb_ptr
    args.vectorsRowMajor = xb_row_major
    args.vectorType = xb_type
    args.numVectors = nb
    args.queries = xq_ptr
    args.queriesRowMajor = xq_row_major
    args.queryType = xq_type
    args.numQueries = nq
    args.outDistances = D_ptr
    args.outIndices = I_ptr
    args.outIndicesType = I_type

    with using_stream(res):
        faiss.bfKnn(res, args)

    return D, I

torch_replace_method(faiss_module, 'knn_gpu', torch_replacement_knn_gpu, True, True)


# allows torch tensor usage with bfKnn for all pairwise distances
def torch_replacement_pairwise_distance_gpu(res, xq, xb, D=None, metric=faiss.METRIC_L2):
    """ All-pairs distance computation on the GPU via faiss.bfKnn (k = -1),
        accepting torch tensors (f32 or f16, row- or column-major) for the
        query and database matrices. Forwards numpy inputs to the numpy
        wrapper. Returns D, the (nq, nb) f32 distance matrix.
    """
    if type(xb) is np.ndarray:
        # Forward to faiss __init__.py base method
        return faiss.pairwise_distance_gpu_numpy(res, xq, xb, D, metric)

    nb, d = xb.size()
    if xb.is_contiguous():
        xb_row_major = True
    elif xb.t().is_contiguous():
        xb = xb.t()
        xb_row_major = False
    else:
        raise TypeError('xb matrix should be row or column-major')

    if xb.dtype == torch.float32:
        xb_type = faiss.DistanceDataType_F32
        xb_ptr = swig_ptr_from_FloatTensor(xb)
    elif xb.dtype == torch.float16:
        xb_type = faiss.DistanceDataType_F16
        xb_ptr = swig_ptr_from_HalfTensor(xb)
    else:
        raise TypeError('xb must be float32 or float16')

    nq, d2 = xq.size()
    assert d2 == d
    if xq.is_contiguous():
        xq_row_major = True
    elif xq.t().is_contiguous():
        xq = xq.t()
        xq_row_major = False
    else:
        raise TypeError('xq matrix should be row or column-major')

    if xq.dtype == torch.float32:
        xq_type = faiss.DistanceDataType_F32
        xq_ptr = swig_ptr_from_FloatTensor(xq)
    elif xq.dtype == torch.float16:
        xq_type = faiss.DistanceDataType_F16
        xq_ptr = swig_ptr_from_HalfTensor(xq)
    else:
        raise TypeError('xq must be float32 or float16')

    if D is None:
        D = torch.empty(nq, nb, device=xb.device, dtype=torch.float32)
    else:
        assert D.shape == (nq, nb)
        # interface takes void*, we need to check this
        assert (D.dtype == torch.float32)

    D_ptr = swig_ptr_from_FloatTensor(D)

    args = faiss.GpuDistanceParams()
    args.metric = metric
    args.k = -1  # selects all pairwise distance
    args.dims = d
    args.vectors = xb_ptr
    args.vectorsRowMajor = xb_row_major
    args.vectorType = xb_type
    args.numVectors = nb
    args.queries = xq_ptr
    args.queriesRowMajor = xq_row_major
    args.queryType = xq_type
    args.numQueries = nq
    args.outDistances = D_ptr

    with using_stream(res):
        faiss.bfKnn(res, args)

    return D

torch_replace_method(faiss_module, 'pairwise_distance_gpu',
                     torch_replacement_pairwise_distance_gpu, True, True)