/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include #include #include #include #include #include namespace torch::executor { /** * A facade class for loading programs and executing methods within them. */ class Module final { public: /** * Enum to define memory locking behavior. */ enum class MlockConfig { /// Do not use memory locking. NoMlock, /// Use memory locking and handle errors. UseMlock, /// Use memory locking and ignore errors. UseMlockIgnoreErrors, }; /** * Constructs an instance by loading a program from a file with specified * memory locking behavior. * * @param[in] file_path The path to the ExecuTorch program file to load. * @param[in] mlock_config The memory locking configuration to use. */ explicit Module( const std::string& file_path, const MlockConfig mlock_config = MlockConfig::UseMlock, std::unique_ptr event_tracer = nullptr); /** * Constructs an instance with the provided data loader and memory allocator. * * @param[in] data_loader A DataLoader used for loading program data. * @param[in] memory_allocator A MemoryAllocator used for memory management. * @param[in] event_tracer A EventTracer used for tracking and logging events. */ explicit Module( std::unique_ptr data_loader, std::unique_ptr memory_allocator = nullptr, std::unique_ptr event_tracer = nullptr); Module(const Module&) = delete; Module& operator=(const Module&) = delete; Module(Module&&) = default; Module& operator=(Module&&) = default; /** * Loads the program using the specified data loader and memory allocator. * * @param[in] verification The type of verification to do before returning * success. * * @returns An Error to indicate success or failure of the loading process. */ __ET_NODISCARD Error load( const Program::Verification verification = Program::Verification::Minimal); /** * Checks if the program is loaded. * * @returns true if the program is loaded, false otherwise. */ bool is_loaded() const; /** * Get a list of method names available in the loaded program. * Loads the program and method if needed. * * @returns A set of strings containing the names of the methods, or an error * if the program or method failed to load. */ Result> method_names(); /** * Load a specific method from the program and set up memory management if * needed. The loaded method is cached to reuse the next time it's executed. * * @param[in] method_name The name of the method to load. * * @returns An Error to indicate success or failure. */ __ET_NODISCARD Error load_method(const std::string& method_name); /** * Checks if a specific method is loaded. * * @param[in] method_name The name of the method to check. * * @returns true if the method specified by method_name is loaded, false * otherwise. */ bool is_method_loaded(const std::string& method_name) const; /** * Get a method metadata struct by method name. * Loads the program and method if needed. * * @param[in] method_name The name of the method to get the metadata for. * * @returns A method metadata, or an error if the program or method failed to * load. */ Result method_meta(const std::string& method_name); /** * Execute a specific method with the given input and retrieve output. * Loads the program and method before executing if needed. * * @param[in] method_name The name of the method to execute. * @param[in] input A vector of input values to be passed to the method. * * @returns A Result object containing either a vector of output values * from the method or an error to indicate failure. */ __ET_NODISCARD Result> execute( const std::string& method_name, const std::vector& input); /** * Execute a specific method without any input values. * Loads the program and method before executing if needed. * * @param[in] method_name The name of the method to execute. * * @returns A Result object containing either a vector of output values * from the method or an error to indicate failure. */ __ET_NODISCARD Result> execute(const std::string& method_name) { return execute(method_name, {}); } /** * Execute the 'forward' method with the given input and retrieve output. * Loads the program and method before executing if needed. * * @param[in] input A vector of input values for the 'forward' method. * * @returns A Result object containing either a vector of output values * from the 'forward' method or an error to indicate failure. */ __ET_NODISCARD Result> forward(const std::vector& input) { return execute("forward", input); } /** * Execute the 'forward' method without any input values. * Loads the program and method before executing if needed. * * @returns A Result object containing either a vector of output values * from the 'forward' method or an error to indicate failure. */ __ET_NODISCARD Result> forward() { return forward({}); } /** * Retrieves the EventTracer instance being used by the Module. * EventTracer is used for tracking and logging events during the execution * of methods. * * @returns A pointer to the EventTracer instance. Returns nullptr if no * EventTracer is set. */ EventTracer* event_tracer() const { return event_tracer_.get(); } private: struct MethodHolder { std::vector> planned_buffers; std::vector> planned_spans; std::unique_ptr planned_memory; std::unique_ptr memory_manager; std::unique_ptr method; }; private: std::string file_path_; MlockConfig mlock_config_{MlockConfig::NoMlock}; std::unique_ptr data_loader_; std::unique_ptr memory_allocator_; std::unique_ptr event_tracer_; std::unique_ptr program_; std::unordered_map methods_; }; } // namespace torch::executor