• PyTorch中的python_torch_functions_i.cpp檔案生成機制


    前言

    編譯PyTorch後,torch/csrc/autograd/generated/目錄下會有python_torch_functions_0.cpp、python_torch_functions_1.cpp和python_torch_functions_2.cpp等檔案,本文便從setup.py依次來探討這些檔案是如何生成的。

    setup.py

    main

    setup.py

    ################################################################################
    # Parameters parsed from environment
    ################################################################################
    
    VERBOSE_SCRIPT = True
    RUN_BUILD_DEPS = True
    
    filtered_args = []
    for i, arg in enumerate(sys.argv):
        # ...
        if arg in ['clean', 'egg_info', 'sdist']:
            RUN_BUILD_DEPS = False
    
    # ...
    
    def main():
        # ...
        if RUN_BUILD_DEPS:
            build_deps()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    RUN_BUILD_DEPS預設為True,如果RUN_BUILD_DEPS為True,則運行build_deps函數,推測是用於建構PyTorch的dependencies。

    build_deps

    setup.py

    from tools.build_pytorch_libs import build_caffe2
    # ...
    
    # all the work we need to do _before_ setup runs
    def build_deps():
        #...
    
        build_caffe2(version=version,
                     cmake_python_library=cmake_python_library,
                     build_python=True,
                     rerun_cmake=RERUN_CMAKE,
                     cmake_only=CMAKE_ONLY,
                     cmake=cmake)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    build_deps函數中最主要的部份便是調用build_caffe2

    tools/build_pytorch_libs.py

    build_caffe2

    tools/build_pytorch_libs.py

    def build_caffe2(
        version: Optional[str],
        cmake_python_library: Optional[str],
        build_python: bool,
        rerun_cmake: bool,
        cmake_only: bool,
        cmake: CMake,
    ) -> None:
        my_env = _create_build_env()
        build_test = not check_negative_env_flag("BUILD_TEST")
        cmake.generate(
            version, cmake_python_library, build_python, build_test, my_env, rerun_cmake
        )
        if cmake_only:
            return
        cmake.build(my_env)
        if build_python:
            caffe2_proto_dir = os.path.join(cmake.build_dir, "caffe2", "proto")
            for proto_file in glob(os.path.join(caffe2_proto_dir, "*.py")):
                if proto_file != os.path.join(caffe2_proto_dir, "__init__.py"):
                    shutil.copy(proto_file, os.path.join("caffe2", "proto"))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    當中有運行caffe2/CMakeLists.txt?

    caffe2/CMakeLists.txt

    caffe2/CMakeLists.txt

    #...
    file(GLOB_RECURSE autograd_python "${TOOLS_PATH}/autograd/*.py")
    file(GLOB_RECURSE autograd_yaml "${TOOLS_PATH}/autograd/*.yaml")
    file(GLOB_RECURSE autograd_templates "${TOOLS_PATH}/autograd/templates/*")
    add_custom_command(
      OUTPUT
      ${TORCH_GENERATED_CODE}
      COMMAND
      "${PYTHON_EXECUTABLE}" tools/setup_helpers/generate_code.py
        --native-functions-path "aten/src/ATen/native/native_functions.yaml"
        --tags-path "aten/src/ATen/native/tags.yaml"
        $<$<BOOL:${INTERN_DISABLE_AUTOGRAD}>:--disable-autograd>
        $<$<BOOL:${SELECTED_OP_LIST}>:--selected-op-list-path="${SELECTED_OP_LIST}">
        --force_schema_registration
        --gen_lazy_ts_backend
        ${GEN_PER_OPERATOR_FLAG}
      DEPENDS
        "${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml"
        "${TORCH_ROOT}/aten/src/ATen/native/tags.yaml"
        "${TORCH_ROOT}/aten/src/ATen/native/ts_native_functions.yaml"
        "${TORCH_ROOT}/torch/csrc/lazy/core/shape_inference.h"
        "${TORCH_ROOT}/torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
        "${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.h"
        "${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp"
        "${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
        "${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h"
        "${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
        ${autograd_python}
        ${autograd_yaml}
        ${autograd_templates}
        ${torchgen_python}
      WORKING_DIRECTORY "${TORCH_ROOT}")
    #...
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33

    add_custom_command這一段是由COMMAND利用DEPENDS中所列出的檔案生成OUTPUT

    先來看一下DEPENDS當中的${autograd_templates}

    file(GLOB_RECURSE autograd_templates "${TOOLS_PATH}/autograd/templates/*")
    
    • 1

    當中的${TOOLS_PATH}是:

    # Generate files
    set(TOOLS_PATH "${TORCH_ROOT}/tools")
    
    • 1
    • 2

    所以${autograd_templates}指的是tools/autograd/templates/目錄下的所有檔案,其中就包含了tools/autograd/templates/python_torch_functions.cpp

    python_torch_functions.cpp中最核心的一段代碼如下:

    static PyMethodDef torch_functions_shard[] = {
      ${py_method_defs}
    };
    
    • 1
    • 2
    • 3

    其中${py_method_defs}的位置便是為了待會自動生成代碼時預留的空位。注意到torch_functions_shard的型別是PyMethodDef的陣列,詳見PyMethodDef

    接著看COMMAND,它會調用tools/setup_helpers/generate_code.py,由DEPENDS(包含aten/src/ATen/native/native_functions.yaml和tools/autograd/templates/目錄下的所有檔案)生成OUTPUT,即${TORCH_GENERATED_CODE}:

    set(TORCH_GENERATED_CODE
      ${GENERATED_CXX_TORCH}
      ${GENERATED_H_TORCH}
      ${GENERATED_CXX_PYTHON}
      ${GENERATED_H_PYTHON}
      ${GENERATED_TESTING_PYTHON}
      )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    當中的GENERATED_CXX_PYTHON如下:

    set(GENERATED_CXX_PYTHON
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_functions_0.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_functions_1.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_functions_2.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_functions_3.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_functions_4.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_variable_methods.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_torch_functions_0.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_torch_functions_1.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_torch_functions_2.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_nn_functions.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_fft_functions.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_linalg_functions.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_nested_functions.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_sparse_functions.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
      "${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
      )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    可以看到當中便包含了python_torch_functions_0.cpp、python_torch_functions_1.cpp和python_torch_functions_2.cpp等三個檔案。

    所以這一段command就是調用tools/setup_helpers/generate_code.py,由python_torch_functions.cpp和native_functions.yaml生成python_torch_functions_0.cpp、python_torch_functions_1.cpp和python_torch_functions_2.cpp。

    接著深入generate_code.py,看看python_torch_functions_i.cpp具體是如何生成的。

    tools/setup_helpers/generate_code.py

    main

    tools/setup_helpers/generate_code.py

    def main() -> None:
        parser = argparse.ArgumentParser(description="Autogenerate code")
        parser.add_argument("--native-functions-path")
        parser.add_argument("--tags-path")
        parser.add_argument(
            "--gen-dir",
            type=pathlib.Path,
            default=pathlib.Path("."),
            help="Root directory where to install files. Defaults to the current working directory.",
        )
        parser.add_argument(
            "--install_dir",
            help=(
                "Deprecated. Use --gen-dir instead. The semantics are different, do not change "
                "blindly."
            ),
        )
        parser.add_argument(
            "--subset",
            help='Subset of source files to generate. Can be "libtorch" or "pybindings". Generates both when omitted.',
        )
        parser.add_argument(
            "--disable-autograd",
            default=False,
            action="store_true",
            help="It can skip generating autograd related code when the flag is set",
        )
        parser.add_argument(
            "--selected-op-list-path",
            help="Path to the YAML file that contains the list of operators to include for custom build.",
        )
        parser.add_argument(
            "--operators_yaml_path",
            help="Path to the model YAML file that contains the list of operators to include for custom build.",
        )
        parser.add_argument(
            "--force_schema_registration",
            action="store_true",
            help="force it to generate schema-only registrations for ops that are not"
            "listed on --selected-op-list",
        )
        parser.add_argument(
            "--gen_lazy_ts_backend",
            action="store_true",
            help="Enable generation of the torch::lazy TorchScript backend",
        )
        parser.add_argument(
            "--per_operator_headers",
            action="store_true",
            help="Build lazy tensor ts backend with per-operator ATen headers, must match how ATen was built",
        )
        options = parser.parse_args()
    
        generate_code(
            options.gen_dir,
            options.native_functions_path,
            options.tags_path,
            options.install_dir,
            options.subset,
            options.disable_autograd,
            options.force_schema_registration,
            # options.selected_op_list
            operator_selector=get_selector(
                options.selected_op_list_path, options.operators_yaml_path
            ),
        )
    
        if options.gen_lazy_ts_backend:
            aten_path = os.path.dirname(os.path.dirname(options.native_functions_path))
            ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml")
            ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
            ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h"
            install_dir = options.install_dir or os.fspath(options.gen_dir / "torch/csrc")
            lazy_install_dir = os.path.join(install_dir, "lazy/generated")
            os.makedirs(lazy_install_dir, exist_ok=True)
    
            assert os.path.isfile(
                ts_backend_yaml
            ), f"Unable to access ts_backend_yaml: {ts_backend_yaml}"
            assert os.path.isfile(
                ts_native_functions
            ), f"Unable to access {ts_native_functions}"
            from torchgen.dest.lazy_ir import GenTSLazyIR
            from torchgen.gen_lazy_tensor import run_gen_lazy_tensor
    
            run_gen_lazy_tensor(
                aten_path=aten_path,
                source_yaml=ts_backend_yaml,
                backend_name="TorchScript",
                output_dir=lazy_install_dir,
                dry_run=False,
                impl_path=ts_native_functions,
                node_base="TsNode",
                node_base_hdr=ts_node_base,
                build_in_tree=True,
                lazy_ir_generator=GenTSLazyIR,
                per_operator_headers=options.per_operator_headers,
                gen_forced_fallback_code=True,
            )
    
    
    if __name__ == "__main__":
        main()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103

    裡面最關鍵的便是generate_code函數。

    • options.gen_dir:預設是’.’
    • options.native_functions_path:從CMakeLists.txt傳入,是為aten/src/ATen/native/native_functions.yaml
    • options.tags_path:從CMakeLists.txt傳入,是為aten/src/ATen/native/tags.yaml

    generate_code

    tools/setup_helpers/generate_code.py

    def generate_code(
        gen_dir: pathlib.Path,
        native_functions_path: Optional[str] = None,
        tags_path: Optional[str] = None,
        install_dir: Optional[str] = None,
        subset: Optional[str] = None,
        disable_autograd: bool = False,
        force_schema_registration: bool = False,
        operator_selector: Any = None,
    ) -> None:
        from torchgen.selective_build.selector import SelectiveBuilder
    
        from tools.autograd.gen_annotated_fn_args import gen_annotated
        from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python
    
        # Build ATen based Variable classes
        if install_dir is None:
            install_dir = os.fspath(gen_dir / "torch/csrc")
            python_install_dir = os.fspath(gen_dir / "torch/testing/_internal/generated")
        else:
            python_install_dir = install_dir
        autograd_gen_dir = os.path.join(install_dir, "autograd", "generated")
        for d in (autograd_gen_dir, python_install_dir):
            os.makedirs(d, exist_ok=True)
        autograd_dir = os.fspath(pathlib.Path(__file__).parent.parent / "autograd")
    
        if subset == "pybindings" or not subset:
            gen_autograd_python(
                native_functions_path or NATIVE_FUNCTIONS_PATH,
                tags_path or TAGS_PATH,
                autograd_gen_dir,
                autograd_dir,
            )
    
        if operator_selector is None:
            operator_selector = SelectiveBuilder.get_nop_selector()
    
        if subset == "libtorch" or not subset:
    
            gen_autograd(
                native_functions_path or NATIVE_FUNCTIONS_PATH,
                tags_path or TAGS_PATH,
                autograd_gen_dir,
                autograd_dir,
                disable_autograd=disable_autograd,
                operator_selector=operator_selector,
            )
    
        if subset == "python" or not subset:
            gen_annotated(
                native_functions_path or NATIVE_FUNCTIONS_PATH,
                tags_path or TAGS_PATH,
                python_install_dir,
                autograd_dir,
            )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55

    首先設定出以下變數:

    • gen_dir:‘.’
    • native_functions_path :aten/src/ATen/native/native_functions.yaml
    • tags_path:aten/src/ATen/native/tags.yaml
    • install_dir:gen_dir/torch/csrc → ‘./torch/csrc’
    • autograd_gen_dir:install_dir/autograd/generated → ‘./torch/csrc/autograd/generated’
    • autograd_dir: tools/autograd

    此處關注的是python_torch_functions_i.cpp,所以接著進入gen_autograd_python函數。

    tools/autograd/gen_autograd.py

    gen_autograd_python

    tools/autograd/gen_autograd.py

    def gen_autograd_python(
        native_functions_path: str,
        tags_path: str,
        out: str,
        autograd_dir: str,
    ) -> None:
        differentiability_infos, _ = load_derivatives(
            os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
        )
    
        template_path = os.path.join(autograd_dir, "templates")
    
        # Generate Functions.h/cpp
        gen_autograd_functions_python(out, differentiability_infos, template_path)
    
        # Generate Python bindings
        deprecated_path = os.path.join(autograd_dir, "deprecated.yaml")
        gen_python_functions.gen(
            out, native_functions_path, tags_path, deprecated_path, template_path
        )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    首先設定出以下變數:

    • out:./torch/csrc/autograd/generated
    • native_functions_path:aten/src/ATen/native/native_functions.yaml
    • tags_path:aten/src/ATen/native/tags.yaml
    • deprecated_path:tools/autograd/deprecated.yaml
    • template_path:tools/autograd/templates

    其中gen_autograd_functions_python函數生成torch/csrc/autograd/generated資料夾下的python_functionsEverything.cpp和python_functions_0.cpp至python_functions_4.cpp。

    gen_python_functions.gen則生成其它許多檔案,包括我們所關注的torch/csrc/autograd/generated資料夾下的python_torch_functionsEverything.cpp和python_torch_functions_0.cpp至python_torch_functions_2.cpp。

    接著繼續深入gen_python_functions.gen函數。

    tools/autograd/gen_python_functions.py

    gen_python_functions.gen

    tools/autograd/gen_python_functions.py

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
    #
    #                            Main Function
    #
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
    
    
    def gen(
        out: str,
        native_yaml_path: str,
        tags_yaml_path: str,
        deprecated_yaml_path: str,
        template_path: str,
        *,
        symint: bool = True,
    ) -> None:
        fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
        native_functions = parse_native_yaml(
            native_yaml_path, tags_yaml_path
        ).native_functions
        native_functions = list(filter(should_generate_py_binding, native_functions))
    
        methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
        create_python_bindings(
            fm,
            methods,
            is_py_variable_method,
            None,
            "python_variable_methods.cpp",
            method=True,
            symint=symint,
        )
    
        # NOTE: num_shards here must be synced with gatherTorchFunctions in
        #       torch/csrc/autograd/python_torch_functions_manual.cpp
        functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
        create_python_bindings_sharded(
            fm,
            functions,
            is_py_torch_function,
            "torch",
            "python_torch_functions.cpp",
            method=False,
            num_shards=3,
            symint=symint,
        )
    
        create_python_bindings(
            fm,
            functions,
            is_py_nn_function,
            "torch.nn",
            "python_nn_functions.cpp",
            method=False,
            symint=symint,
        )
    
        create_python_bindings(
            fm,
            functions,
            is_py_fft_function,
            "torch.fft",
            "python_fft_functions.cpp",
            method=False,
            symint=symint,
        )
    
        create_python_bindings(
            fm,
            functions,
            is_py_linalg_function,
            "torch.linalg",
            "python_linalg_functions.cpp",
            method=False,
            symint=symint,
        )
    
        create_python_bindings(
            fm,
            functions,
            is_py_nested_function,
            "torch.nested",
            "python_nested_functions.cpp",
            method=False,
        )
    
        create_python_bindings(
            fm,
            functions,
            is_py_sparse_function,
            "torch.sparse",
            "python_sparse_functions.cpp",
            method=False,
            symint=symint,
        )
    
        create_python_bindings(
            fm,
            functions,
            is_py_special_function,
            "torch.special",
            "python_special_functions.cpp",
            method=False,
            symint=symint,
        )
    
        # Currently, we only use `functions` to generate `return_types` bindings.
        # All methods which return namedtuple have function variant at this point.
        # If any method only operator with namedtuple is added in the future,
        # we will have to address that.
        create_python_return_type_bindings(
            fm, functions, lambda fn: True, "python_return_types.cpp"
        )
    
        valid_tags = parse_tags_yaml(tags_yaml_path)
    
        def gen_tags_enum() -> Dict[str, str]:
            return {
                "enum_of_valid_tags": (
                    "".join([f'\n.value("{tag}", at::Tag::{tag})' for tag in valid_tags])
                )
            }
    
        fm.write("python_enum_tag.cpp", gen_tags_enum)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124

    首先創建一個FileManager變數fm,它的前兩個參數如下(詳見FileManager建構子):

    • install_dir:./torch/csrc/autograd/generated
    • template_dir:tools/autograd/templates

    此處解析native_functions.yaml後得到native_functions

        native_functions = parse_native_yaml(
            native_yaml_path, tags_yaml_path
        ).native_functions
        native_functions = list(filter(should_generate_py_binding, native_functions))
    
    • 1
    • 2
    • 3
    • 4

    接著載入它們的函數簽名,得到functions

        # NOTE: num_shards here must be synced with gatherTorchFunctions in
        #       torch/csrc/autograd/python_torch_functions_manual.cpp
        functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
    
    • 1
    • 2
    • 3

    後續將解析出來的函數簽名傳入create_python_bindings_sharded,生成python_torch_functions_0.cpp、python_torch_functions_1.cpp和python_torch_functions_2.cpp等三個檔案。

        create_python_bindings_sharded(
            fm,
            functions,
            is_py_torch_function,
            "torch",
            "python_torch_functions.cpp",
            method=False,
            num_shards=3,
            symint=symint,
        )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    繼續深入create_python_bindings_sharded

    create_python_bindings_sharded

    tools/autograd/gen_python_functions.py

    def create_python_bindings_sharded(
        fm: FileManager,
        pairs: Sequence[PythonSignatureNativeFunctionPair],
        pred: Callable[[NativeFunction], bool],
        module: Optional[str],
        filename: str,
        *,
        method: bool,
        num_shards: int,
        symint: bool = True,
    ) -> None:
        """Generates Python bindings to ATen functions"""
        grouped = group_filter_overloads(pairs, pred)
    
        def key_func(
            kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
        ) -> str:
            return kv[0].base
    
        def env_func(
            kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
        ) -> Dict[str, List[str]]:
            name, fn_pairs = kv
            return {
                "ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
                "py_forwards": list(forward_decls(name, fn_pairs, method=method)),
                "py_methods": [
                    method_impl(name, module, fn_pairs, method=method, symint=symint)
                ],
                "py_method_defs": [method_def(name, module, fn_pairs, method=method)],
            }
    
        fm.write_sharded(
            filename,
            grouped.items(),
            base_env={
                "generated_comment": "@"
                + f"generated from {fm.template_dir_for_comments()}/{filename}",
            },
            key_fn=key_func,
            env_callable=env_func,
            num_shards=num_shards,
            sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"},
        )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44

    write_sharded的參數如下:

    • filename:python_torch_functions.cpp
    • grouped.items():來自一個Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]],也就是將運算子名稱對應到(函數簽名,函數)對列表的一個字典

    繼續深入write_sharded

    torchgen/utils.py

    write_sharded

    torchgen/utils.py

        def write_sharded(
            self,
            filename: str,
            items: Iterable[T],
            *,
            key_fn: Callable[[T], str],
            env_callable: Callable[[T], Dict[str, List[str]]],
            num_shards: int,
            base_env: Optional[Dict[str, Any]] = None,
            sharded_keys: Set[str],
        ) -> None:
    
            everything: Dict[str, Any] = {"shard_id": "Everything"}
            shards: List[Dict[str, Any]] = [
                {"shard_id": f"_{i}"} for i in range(num_shards)
            ]
            all_shards = [everything] + shards
    
            if base_env is not None:
                for shard in all_shards:
                    shard.update(base_env)
    
            for key in sharded_keys:
                for shard in all_shards:
                    if key in shard:
                        assert isinstance(
                            shard[key], list
                        ), "sharded keys in base_env must be a list"
                        shard[key] = shard[key].copy()
                    else:
                        shard[key] = []
    
            def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
                for k, v in from_.items():
                    assert k in sharded_keys, f"undeclared sharded key {k}"
                    into[k] += v
    
            if self.dry_run:
                # Dry runs don't write any templates, so incomplete environments are fine
                items = ()
    
            for item in items:
                key = key_fn(item)
                sid = string_stable_hash(key) % num_shards
                env = env_callable(item)
    
                merge_env(shards[sid], env)
                merge_env(everything, env)
    
            dot_pos = filename.rfind(".")
            if dot_pos == -1:
                dot_pos = len(filename)
            base_filename = filename[:dot_pos]
            extension = filename[dot_pos:]
    
            for shard in all_shards:
                shard_id = shard["shard_id"]
                self.write_with_template(
                    f"{base_filename}{shard_id}{extension}", filename, lambda: shard
                )
    
            # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
            self.filenames.discard(
                f"{self.install_dir}/{base_filename}Everything{extension}"
            )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65

    注意到all_shards是由[everything]和shards所組成,是一個List[Dict[str, Any]]。

    而以下這一段會把env字典裡的內容併入shards[sid]和everything,所以會間接更新all_shards:

            for item in items:
                key = key_fn(item)
                sid = string_stable_hash(key) % num_shards
                env = env_callable(item)
    
                merge_env(shards[sid], env)
                merge_env(everything, env)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    因此all_shards中的內容除了下面的之外:

    [{'shard_id': 'Everything'}, {'shard_id': '_0'}, {'shard_id': '_1'}, {'shard_id': '_2'}]
    
    • 1

    當中的各字典還包括了env裡的內容。

    最後一段for迴圈遍歷all_shards,一一調用write_with_template,會調用lambda: shard函數由filename生成f"{base_filename}{shard_id}{extension}",也就是由python_torch_functions.cpp生成python_torch_functionsEverything.cpp、python_torch_functions_0.cpp、python_torch_functions_1.cpp和python_torch_functions_2.cpp。

    write_with_template函數已獨立成篇,詳見PyTorch檔案生成機制中的FileManager.write_with_template

    生成結果

    回顧create_python_bindings_sharded,那裡列出了generated_comment, ops_headers, py_forwards, py_method_defs, py_methods等key。

    python_torch_functions.cpp中各key會被替換成:

    • generated_comment

      // ${generated_comment}
      
      • 1

      會被替換成:

      // @generated from ../tools/autograd/templates/python_torch_functions.cpp
      
      • 1
    • ops_headers

      #ifndef AT_PER_OPERATOR_HEADERS
      #include <ATen/Functions.h>
      #else
      $ops_headers
      #endif
      
      • 1
      • 2
      • 3
      • 4
      • 5

      被替換成:

      #ifndef AT_PER_OPERATOR_HEADERS
      #include <ATen/Functions.h>
      #else
      #include <ATen/ops/_cast_Byte.h>
      // ...
      #include <ATen/ops/(其它算子名稱).h>
      #endif
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7

      表示如果沒有定義AT_PER_OPERATOR_HEADERS這個巨集就會include torch/include/ATen/Functions.h(或build/aten/src/ATen/Functions.h),否則include各算子專屬的headers。筆者的環境中沒有定義AT_PER_OPERATOR_HEADERS,所以torch/include/ATen/ops資料夾下只有from_blob.h和tensor.h兩個檔案。

    • py_forwards

      // generated forward declarations start here
      
      ${py_forwards}
      
      • 1
      • 2
      • 3

      生成結果如:

      static PyObject * THPVariable__cast_Byte(PyObject* self_, PyObject* args, PyObject* kwargs);
      
      • 1

      是為C++與Python介接函數的宣告。

    • py_method_defs

      static PyMethodDef torch_functions_shard[] = {
        ${py_method_defs}
      };
      
      • 1
      • 2
      • 3

      生成結果如:

        {"_cast_Byte", castPyCFunctionWithKeywords(THPVariable__cast_Byte), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      
      • 1

      是為PyMethodDef結構體,詳見「撰寫自己的Python C擴展!」和「Python/C API - 模組,型別,Tuple,例外和引用計數」。

    • py_methods

      // generated methods start here
      
      ${py_methods}
      
      • 1
      • 2
      • 3

      生成結果如下:

      // _cast_Byte
      static PyObject * THPVariable__cast_Byte(PyObject* self_, PyObject* args, PyObject* kwargs)
      {
        HANDLE_TH_ERRORS
        static PythonArgParser parser({
          "_cast_Byte(Tensor input, bool non_blocking=False)",
        }, /*traceable=*/true);
      
        ParsedArgs<2> parsed_args;
        auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
        if(_r.has_torch_function()) {
          return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
        }
        // aten::_cast_Byte(Tensor self, bool non_blocking=False) -> Tensor
        
        auto dispatch__cast_Byte = [](const at::Tensor & self, bool non_blocking) -> at::Tensor {
          pybind11::gil_scoped_release no_gil;
          return at::_cast_Byte(self, non_blocking);
        };
        return wrap(dispatch__cast_Byte(_r.tensor(0), _r.toBool(1)));
        Py_RETURN_NONE;
        END_HANDLE_TH_ERRORS
      }
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      • 16
      • 17
      • 18
      • 19
      • 20
      • 21
      • 22
      • 23

      是為介接函數的定義,詳見撰寫自己的Python C擴展!

  • 相关阅读:
    C++特性——auto关键字、范围for、指针空值nullptr
    洛谷P5904 树形dp,长链剖分优化
    蔡司光学:儿童近视眼镜的匠心之选
    PX4模块设计之二十八:RCInput模块
    python:list和dict的基本操作实例
    如何在 Java 中实现无向环和有向环的检测
    详解:程序部署在服务器上,localhost可以访问Tomcat,但是外网ip无法访问
    数学建模学习笔记(4):模糊综合评价
    数字取证对有效企业事件响应的重要性
    使用 AgileConfig 动态配置 NLog
  • 原文地址:https://blog.csdn.net/keineahnung2345/article/details/132783587