背景

我們知道PyTorch的的代碼主要由C10、ATen、torch三大部分組成的。其中:

1,C10,來自於Core Tensor Library的縮寫。這裡存放的都是最基礎的Tensor庫的代碼,可以運行在服務端和移動端。PyTorch目前正在將代碼從ATen/core目錄下遷移到C10中。C10的代碼有一些特殊性,體現在這裡的代碼除了服務端外還要運行在移動端,因此編譯後的二進位文件大小也很關鍵,因此C10目前存放的都是最核心、精簡的、基礎的Tensor函數和介面。

C10目前最具代表性的一個class就是TensorImpl了,它實現了Tensor的最基礎框架。繼承者和使用者有:

VariableVariable::Impl
SparseTensorImpl
detail::make_tensor<TensorImpl>(storage_impl, CUDATensorId(), false)
Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>

值得一提的是,C10中還使用/修改了來自llvm的SmallVector,在vector元素比較少的時候用以代替std::vector,用以提升性能;

2,ATen,來自於 A TENsor library for C++11的縮寫;PyTorch的C++ tensor library。ATen部分有大量的代碼是來聲明和定義Tensor運算相關的邏輯的,除此之外,PyTorch還使用了aten/src/ATen/gen.py來動態生成一些ATen相關的代碼。Gemfield本文討論的正是這部分

3,Torch,部分代碼仍然在使用以前的快要進入歷史博物館的Torch開源項目,比如具有下面這些文件名格式的文件:

TH* = TorcH
THC* = TorcH Cuda
THCS* = TorcH Cuda Sparse (now defunct)
THCUNN* = TorcH CUda Neural Network (see cunn)
THD* = TorcH Distributed
THNN* = TorcH Neural Network
THS* = TorcH Sparse (now defunct)
THP* = TorcH Python

PyTorch會使用tools/setup_helpers/generate_code.py來動態生成Torch層面相關的一些代碼,這部分動態生成的邏輯將不在本文闡述,你可以關注Gemfield專欄的後續文章。

4,Gemfield在本文將討論ATen部分動態生成的代碼

這也可以看作是gemfield專欄文章Gemfield:PyTorch的編譯系統的一部分。

ATen動態代碼生成的入口

在編譯期,Cmake腳本會在cmake/Codegen.cmake中檢測BUILD_ATEN_MOBILE標誌(該標誌在編譯Android和iOS的時候會設置為True,否則是False),如果該標誌為False的話,才會運行gen.py(也就是說,gen.py生成了ATen模塊的一部分源代碼):這正好說明了,gen.py生成的代碼是為Server端準備的。

SET(GEN_COMMAND
${PYCMD} ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/gen.py
--source-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen
--install_dir ${CMAKE_BINARY_DIR}/aten/src/ATen
${GEN_ROCM_FLAG}
${cwrap_files}
)

EXECUTE_PROCESS(
COMMAND ${GEN_COMMAND}
--output-dependencies ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_cpp.txt
--install_dir ${CMAKE_BINARY_DIR}/aten/src/ATen
RESULT_VARIABLE RETURN_VALUE
)

展開就是

python3 pytorch/aten/src/ATen/gen.py
--source-path pytorch/aten/src/ATen
--install_dir pytorch/build/aten/src/ATen
pytorch/aten/src/ATen/Declarations.cwrap
pytorch/aten/src/THNN/generic/THNN.h
pytorch/aten/src/THCUNN/generic/THCUNN.h
pytorch/aten/src/ATen/nn.yaml
pytorch/aten/src/ATen/native/native_functions.yaml
--output-dependencies pytorch/aten/src/ATen/generated_cpp.txt

那麼什麼樣的代碼是server端需要而移動端又不能要的呢?(不妨猜測下)

上述gen.py的命令會被執行兩次,第一次是帶output-dependencies參數的,第二次則不帶。區別是,前者僅僅會把要生成的文件的文件名寫入以下三個文件:

build/aten/src/ATen/generated_cpp.txt-cuda
build/aten/src/ATen/generated_cpp.txt-core
build/aten/src/ATen/generated_cpp.txt

舉例,在generated_cpp.txt-core文件中,有以下內容:

gemfield@ubuntu:~/github/pytorch$ cat ./build/aten/src/ATen/generated_cpp.txt-core
/home/gemfield/github/pytorch/build/aten/src/ATen/core_tmp/Tensor.h;/home/gemfield/github/pytorch/build/aten/src/ATen/core_tmp/TensorMethods.h;/home/gemfield/github/pytorch/build/aten/src/ATen/core_tmp/Type.h;

第二次不帶output-dependencies參數,則動態代碼生成邏輯會真正的去生成那些文件:

build/aten/src/ATen/CPUByteType.cpp
build/aten/src/ATen/CPUByteType.h
......
build/aten/src/ATen/CPUShortType.cpp
build/aten/src/ATen/CPUShortType.h
build/aten/src/ATen/Declarations.yaml
build/aten/src/ATen/ExtensionBackendRegistration.h
build/aten/src/ATen/Functions.h
build/aten/src/ATen/LegacyTHCPUByteDispatcher.cpp
build/aten/src/ATen/LegacyTHCPUByteDispatcher.h
......
build/aten/src/ATen/LegacyTHDispatcher.cpp
build/aten/src/ATen/LegacyTHDispatcher.h
build/aten/src/ATen/LegacyTHFunctions.h
build/aten/src/ATen/MSNPUType.cpp
build/aten/src/ATen/MSNPUType.h
build/aten/src/ATen/NativeFunctions.h
build/aten/src/ATen/RegisterCPU.cpp
build/aten/src/ATen/RegisterCPU.h
......
build/aten/src/ATen/SparseCPUShortType.cpp
build/aten/src/ATen/SparseCPUShortType.h
build/aten/src/ATen/TypeDefault.cpp
build/aten/src/ATen/TypeDefault.h
build/aten/src/ATen/TypeExtendedInterface.h
build/aten/src/ATen/XLAType.cpp
build/aten/src/ATen/XLAType.h

以及CUDA backend的:

build/aten/src/ATen/CUDAByteType.cpp;
build/aten/src/ATen/CUDAByteType.h;
......
build/aten/src/ATen/CUDAShortType.cpp
build/aten/src/ATen/CUDAShortType.h
build/aten/src/ATen/LegacyTHCUDAByteDispatcher.cpp
build/aten/src/ATen/LegacyTHCUDAByteDispatcher.h
......
build/aten/src/ATen/RegisterCUDA.cpp
build/aten/src/ATen/RegisterCUDA.h
......
build/aten/src/ATen/SparseCUDAShortType.cpp
build/aten/src/ATen/SparseCUDAShortType.h

ATen動態代碼如何生成

在cmake的調用過程中,gen.py從命令行上接收的參數除了source-path、install_dir、output-dependencies之外,就是files:

/home/gemfield/github/pytorch/aten/src/ATen/Declarations.cwrap
/home/gemfield/github/pytorch/aten/src/THNN/generic/THNN.h
/home/gemfield/github/pytorch/aten/src/THCUNN/generic/THCUNN.h
/home/gemfield/github/pytorch/aten/src/ATen/nn.yaml
/home/gemfield/github/pytorch/aten/src/ATen/native/native_functions.yaml

這些files被劃分了cwrap_files、nn_files、native_files;

cwrap_files

/home/gemfield/github/pytorch/aten/src/ATen/Declarations.cwrap

cwrap_files是包含了大量的"[["和"]]"括起來的yaml結構體,使用cwrap_parser進行解析。得到的信息主要是_th_前綴的函數,這些都是TH Tensor的函數。這些信息會放入top env,在後續步驟中用於替換各template文件中的佔位符。另外,經過後續腳本的處理,還會以yaml形式寫入torch/share/ATen/Declarations.yaml文件中。

nn_files

/home/gemfield/github/pytorch/aten/src/THNN/generic/THNN.h
/home/gemfield/github/pytorch/aten/src/THCUNN/generic/THCUNN.h
/home/gemfield/github/pytorch/aten/src/ATen/nn.yaml

使用nn_parse進行解析,得到_thnn_前綴的函數的信息。這些信息會放入top env,在後續步驟中用於替換各template文件中的佔位符。另外,經過後續腳本的處理,還會以yaml形式寫入torch/share/ATen/Declarations.yaml文件中。

native_files

/home/gemfield/github/pytorch/aten/src/ATen/native/native_functions.yaml

使用native_parse進行解析,得到的是現代PyTorch(區別於老舊的Torch)的一些正統函數的信息。這些信息會放入top env,在後續步驟中用於替換template中的佔位符。另外,經過後續腳本的處理,還會以yaml形式寫入torch/share/ATen/Declarations.yaml文件中。

所有解析完的信息會放入python list里(當然含有很多嵌套的list和dict),也就是屢次提到的top env。

解析完legacy的cwrap_files、nn_files和modern的native_files後,我們得到了top env信息,下面就要開始使用各個template文件結合top env來渲染出真正的源文件了。這個渲染過程分10步:

1,生成CPUGenerator.h、CUDAGenerator.h

輸入為模板文件aten/src/ATen/templates/GeneratorDerived.h:

#pragma once
//${generated_comment}
#include <$header>
#include <ATen/core/Generator.h>
namespace at {
class Context;
struct ${name}Generator : public Generator {
CAFFE2_API ${name}Generator(Context * context);
CAFFE2_API virtual ~${name}Generator();
CAFFE2_API virtual ${name}Generator& copy(const Generator& from) override;
CAFFE2_API virtual ${name}Generator& free() override;
CAFFE2_API virtual uint64_t seed() override;
CAFFE2_API virtual uint64_t initialSeed() override;
CAFFE2_API virtual ${name}Generator& manualSeed(uint64_t seed) override;
CAFFE2_API virtual ${name}Generator& manualSeedAll(uint64_t seed) override;
CAFFE2_API virtual void * unsafeGetTH() override;
public:
Context * context;
${th_generator}
};
}

輸出為$INSTALL_DIR/CPUGenerator.h和$INSTALL_DIR/CUDAGenerator.h,會把其中的${name}替換為CPU或CUDA;${th_generator}替換為THGenerator * generator; 或者空;$header替換為TH/TH.h或THC/THC.h。

2,生成build/aten/src/ATen/Declarations.yaml(同時會被拷貝到torch/share/ATen/Declarations.yaml)

這是根據cwrap_files、nn_files、native_files解析出來的信息生成的。Declarations.yaml會在Torch模塊的動態代碼生成過程中被使用。

3,生成Type繼承體系的源文件

首先的一個問題是:Type繼承體系是什麼樣子呢?

Type的整個繼承體系是這樣的,Type類派生出了TypeExtendedInterface類,TypeExtendedInterface類又派生出了TypeDefault類。TypeDefault又派生出了CUDATypeDefault、CPUTypeDefault、UndefinedType等(以後還會擴展)。其中,根據density和scaler type的不同,CUDATypeDefault派生出了下面這些類:

CUDAIntType
CUDAShortType
SparseCUDACharType
CUDADoubleType
CUDAByteType
CUDACharType
SparseCUDAByteType
CUDAFloatType
SparseCUDALongType
CUDALongType
CUDAHalfType
SparseCUDAShortType
SparseCUDADoubleType
SparseCUDAIntType
SparseCUDAFloatType

CPUTypeDefault派生除了下面這些類:

SparseCPUShortType
CPUFloatType
CPUHalfType
CPUDoubleType
CPUByteType
SparseCPUFloatType
SparseCPUIntType
SparseCPUDoubleType
CPUCharType
SparseCPUByteType
CPUIntType
CPULongType
SparseCPULongType
SparseCPUCharType
CPUShortType

其次的一個問題是:Type繼承體系的作用是什麼呢?

Type類里聲明了426個純虛函數;TypeExtendedInterface除了繼承之外額外聲明了1287個純虛函數;而TypeDefault類繼承了這些函數,並實現了部分函數(主要是一些中繼方面的),正如default這一名字的含義一樣。而TypeDefault沒有實現的部分,將被30個具體的子類分別實現。所有PyTorch中的op函數,都將由Tensor介面dispatch到Type家族(比如使用detail::infer_type函數可以得到一個Tensor對應的Type,繼而使用Type繼承體系中相應的子類的函數調用),再由Type家族轉發到legacy的TH/THC實現抑或是modern的native實現上。

最後一個問題是:這30個具體的類都是Type,那它們之間的區別是什麼呢?

在aten/src/ATen/core/Type.h(該文件是由模板Type.h生成)中定義了Type類,這個類主要就是聲明了一大堆和tensor相關的純虛函數;接著是TypeExtendedInterface.h中定義的TypeExtendedInterface類繼承了Type類,和Type類一樣,TypeExtendedInterface繼續擴展了更多的純虛函數(由模板中的${pure_virtual_extended_type_method_declarations}根據parser得到的信息生成的),這些擴充的純虛函數擁有下面這樣的格式:

virtual ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const = 0;

virtual ${return_type} ${api_name}(${type_method_formals}) const = 0;

接著是TypeDefault.h中定義的TypeDefault類,它繼承了TypeExtendedInterface類。和父類們一樣,TypeDefault又擴充了一堆函數(也是由模板中的${type_method_declarations}生成的,這些擴充的函數擁有下面這樣的格式:

${return_type} ${api_name}(${type_method_formals}) const override;

TypeDefault類派生出了CUDATypeDefault、CPUTypeDefault。由上文得知,CUDATypeDefault和CPUTypeDefault將會分別派生出15個具體的類,這些類所在的30個頭文件就是根據TypeDerived.h、TypeDerived.cpp、SparseTypeDerived.cpp這三個模板文件生成的。

Type繼承體系中頭文件對應的輸入的模板是./aten/src/ATen/templates/TypeDerived.h:

#pragma once

// ${generated_comment}

#include <ATen/CPUTypeDefault.h>
#include <ATen/Context.h>
#include <ATen/CheckGenerator.h>

$extra_cuda_headers

#ifdef _MSC_VER
#ifdef Type
#undef Type
#endif
#endif

namespace at {

struct ${Type} final : public ${DenseBackend}TypeDefault {
explicit ${Type}();
virtual ScalarType scalarType() const override;
virtual caffe2::TypeMeta typeMeta() const override;
virtual Backend backend() const override;
virtual const char * toString() const override;
virtual size_t elementSizeInBytes() const override;
virtual TypeID ID() const override;

// example
// virtual Tensor * add(Tensor & a, Tensor & b) override;
${type_derived_method_declarations}
};

} // namespace at

根據density是dense還是sparse、backend是CPU還是CUDA、scalar_types是 Byte/Char/Long/Double/Float/Int/Short/Half(sparse沒有Half),我們可以從這一個模板文件生成30個不同的頭文件。它們是:

CPUByteType.h
CPUCharType.h
CPUDoubleType.h
CPUFloatType.h
CPUHalfType.h
CPUIntType.h
CPULongType.h
CPUShortType.h
CUDAByteType.h
CUDACharType.h
CUDADoubleType.h
CUDAFloatType.h
CUDAHalfType.h
CUDAIntType.h
CUDALongType.h
CUDAShortType.h
SparseCPUByteType.h
SparseCPUCharType.h
SparseCPUDoubleType.h
SparseCPUFloatType.h
SparseCPUIntType.h
SparseCPULongType.h
SparseCPUShortType.h
SparseCUDAByteType.h
SparseCUDACharType.h
SparseCUDADoubleType.h
SparseCUDAFloatType.h
SparseCUDAIntType.h
SparseCUDALongType.h
SparseCUDAShortType.h

其中,${type_derived_method_declarations}將會被替換為一大堆函數聲明。CUDA比CPU的多很多、Sparse的比dense的少很多。如果backend和density都一樣,則不同的scalar_types之間的頭文件的${type_derived_method_declarations}將是一樣的。

Type繼承體系中cpp實現文件對應的輸入模板文件是aten/src/ATen/templates/TypeDerived.cpp(Dense的)和aten/src/ATen/templates/SparseTypeDerived.cpp(Sparse的)。這兩個模板文件基本一樣:

#define __STDC_FORMAT_MACROS

#include <ATen/${Type}.h>

// ${generated_comment}

$th_headers
$storage_tensor_headers
#include <ATen/${Generator}.h>
#include <c10/core/Allocator.h>
......
${type_derived_method_definitions}
}

輸出則是30個cpp:

CPUByteType.cpp
CPUCharType.cpp
CPUDoubleType.cpp
CPUFloatType.cpp
CPUHalfType.cpp
CPUIntType.cpp
CPULongType.cpp
CPUShortType.cpp
CUDAByteType.cpp
CUDACharType.cpp
CUDADoubleType.cpp
CUDAFloatType.cpp
CUDAHalfType.cpp
CUDAIntType.cpp
CUDALongType.cpp
CUDAShortType.cpp
SparseCPUByteType.cpp
SparseCPUCharType.cpp
SparseCPUDoubleType.cpp
SparseCPUFloatType.cpp
SparseCPUIntType.cpp
SparseCPULongType.cpp
SparseCPUShortType.cpp
SparseCUDAByteType.cpp
SparseCUDACharType.cpp
SparseCUDADoubleType.cpp
SparseCUDAFloatType.cpp
SparseCUDAIntType.cpp
SparseCUDALongType.cpp
SparseCUDAShortType.cpp

4,生成legacy TH dispatcher

根據模板LegacyTHDispatcherDerived.cpp:

#include "ATen/${Dispatcher}.h"

// ${generated_comment}

namespace at {

${Dispatcher}::${Dispatcher}()
: LegacyTHDispatcher(${Backend}TensorId()) {}

}

和LegacyTHDispatcherDerived.h:

#pragma once

// ${generated_comment}

#include "ATen/LegacyTHDispatcher.h"

namespace at {

struct ${Dispatcher} final : public LegacyTHDispatcher {
explicit ${Dispatcher}();

};

} // namespace at

生成以下的class 類:

LegacyTHCPUByteDispatcher
LegacyTHCPUCharDispatcher
LegacyTHCPUDoubleDispatcher
LegacyTHCPUFloatDispatcher
LegacyTHCPUHalfDispatcher
LegacyTHCPUIntDispatcher
LegacyTHCPULongDispatcher
LegacyTHCPUShortDispatcher
LegacyTHCUDAByteDispatcher
LegacyTHCUDACharDispatcher
LegacyTHCUDADoubleDispatcher
LegacyTHCUDAFloatDispatcher
LegacyTHCUDAHalfDispatcher
LegacyTHCUDAIntDispatcher
LegacyTHCUDALongDispatcher
LegacyTHCUDAShortDispatcher

5,生成Type.h、Tensor.h、TensorMethods.h

根據模板Type.h、Tensor.h、TensorMethods.h生成源文件Type.h、Tensor.h、TensorMethods.h,其中Type.h在上文已經提到過。Tensor.h模板中就是一個 ${tensor_method_declarations}佔位符,該佔位符將由下面的語句結合實際的env來替換:

${return_type} ${api_name}(${method_formals_with_defaults})${const_mark}

而TensorMethods.h模板中就是一個${tensor_method_definitions}佔位符,該佔位符將由下面的語句結合實際的env來替換:

inline ${return_type} Tensor::${api_name}(${method_formals})${const_mark} {
return type().${api_name}(${method_actuals});
}

可見Tensor.h負責聲明Tensor類,而TensorMethods.h負責inline實現Tensor類的一些方法。這樣一來,Tensor類的函數就通過tensor的種類和dispatcher的中繼,進而調用到了Type繼承體系的代碼。

6,生成TypeExtendedInterface.h、TypeDefault.h、TypeDefault.cpp

根據模板TypeExtendedInterface.h、TypeDefault.h、TypeDefault.cpp生成源文件TypeExtendedInterface.h、TypeDefault.h、TypeDefault.cpp。上文已經提到過了。

7,生成LegacyTHDispatcher.h、LegacyTHDispatcher.cpp

根據模板LegacyTHDispatcher.h和LegacyTHDispatcher.cpp生成,對於這兩者來說,模板里也沒啥要替換的。

#pragma once

// ${generated_comment}

#include <c10/core/TensorTypeIdRegistration.h>

namespace at {

struct CAFFE2_API LegacyTHDispatcher {
explicit LegacyTHDispatcher(TensorTypeId type_id)
: type_id_(type_id) {}

virtual ~LegacyTHDispatcher() {}

protected:
TensorTypeId type_id_;
};

} // namespace th

8,生成RegisterCPU.h、RegisterCPU.cpp、RegisterCUDA.h、RegisterCUDA.cpp

其中,register_cpu_types、register_cuda_types會被hooks調用,完成Type繼承體系的初始化和註冊:

#cpu
void register_cpu_types(Context * context) {
context->registerType(Backend::CPU, ScalarType::Byte, new CPUByteType());
context->registerType(Backend::CPU, ScalarType::Char, new CPUCharType());
context->registerType(Backend::CPU, ScalarType::Double, new CPUDoubleType());
context->registerType(Backend::CPU, ScalarType::Float, new CPUFloatType());
context->registerType(Backend::CPU, ScalarType::Int, new CPUIntType());
context->registerType(Backend::CPU, ScalarType::Long, new CPULongType());
context->registerType(Backend::CPU, ScalarType::Short, new CPUShortType());
context->registerType(Backend::CPU, ScalarType::Half, new CPUHalfType());
context->registerType(Backend::SparseCPU, ScalarType::Byte, new SparseCPUByteType());
context->registerType(Backend::SparseCPU, ScalarType::Char, new SparseCPUCharType());
context->registerType(Backend::SparseCPU, ScalarType::Double, new SparseCPUDoubleType());
context->registerType(Backend::SparseCPU, ScalarType::Float, new SparseCPUFloatType());
context->registerType(Backend::SparseCPU, ScalarType::Int, new SparseCPUIntType());
context->registerType(Backend::SparseCPU, ScalarType::Long, new SparseCPULongType());
context->registerType(Backend::SparseCPU, ScalarType::Short, new SparseCPUShortType());
context->registerType(Backend::MSNPU, ScalarType::Byte, new MSNPUType());
context->registerType(Backend::MSNPU, ScalarType::Char, new MSNPUType());
context->registerType(Backend::MSNPU, ScalarType::Double, new MSNPUType());
context->registerType(Backend::MSNPU, ScalarType::Float, new MSNPUType());
context->registerType(Backend::MSNPU, ScalarType::Int, new MSNPUType());
context->registerType(Backend::MSNPU, ScalarType::Long, new MSNPUType());
context->registerType(Backend::MSNPU, ScalarType::Short, new MSNPUType());
context->registerType(Backend::MSNPU, ScalarType::Half, new MSNPUType());
context->registerType(Backend::XLA, ScalarType::Byte, new XLAType());
context->registerType(Backend::XLA, ScalarType::Char, new XLAType());
context->registerType(Backend::XLA, ScalarType::Double, new XLAType());
context->registerType(Backend::XLA, ScalarType::Float, new XLAType());
context->registerType(Backend::XLA, ScalarType::Int, new XLAType());
context->registerType(Backend::XLA, ScalarType::Long, new XLAType());
context->registerType(Backend::XLA, ScalarType::Short, new XLAType());
context->registerType(Backend::XLA, ScalarType::Half, new XLAType());
context->registerType(Backend::Undefined, ScalarType::Undefined, new UndefinedType());
}
#cuda
void register_cuda_types(Context * context) {
context->registerType(Backend::CUDA, ScalarType::Byte, new CUDAByteType());
context->registerType(Backend::CUDA, ScalarType::Char, new CUDACharType());
context->registerType(Backend::CUDA, ScalarType::Double, new CUDADoubleType());
context->registerType(Backend::CUDA, ScalarType::Float, new CUDAFloatType());
context->registerType(Backend::CUDA, ScalarType::Int, new CUDAIntType());
context->registerType(Backend::CUDA, ScalarType::Long, new CUDALongType());
context->registerType(Backend::CUDA, ScalarType::Short, new CUDAShortType());
context->registerType(Backend::CUDA, ScalarType::Half, new CUDAHalfType());
context->registerType(Backend::SparseCUDA, ScalarType::Byte, new SparseCUDAByteType());
context->registerType(Backend::SparseCUDA, ScalarType::Char, new SparseCUDACharType());
context->registerType(Backend::SparseCUDA, ScalarType::Double, new SparseCUDADoubleType());
context->registerType(Backend::SparseCUDA, ScalarType::Float, new SparseCUDAFloatType());
context->registerType(Backend::SparseCUDA, ScalarType::Int, new SparseCUDAIntType());
context->registerType(Backend::SparseCUDA, ScalarType::Long, new SparseCUDALongType());
context->registerType(Backend::SparseCUDA, ScalarType::Short, new SparseCUDAShortType());
}

Context類中實現的registerType如下所示:

void registerType(Backend b, ScalarType s, Type* t) {
globalLegacyTypeDispatch().registerType(b, s,
LegacyTypeDispatch::TypeUniquePtr{t, LegacyTypeDeleter([](Type* p) { delete p; }) });
}

其是通過LegacyTypeDispatch類里的2維數組type_registry完成的mapping。

9,生成Functions.h、LegacyTHFunctions.h

Functions.h是legacy的TH代碼的函數聲明,在未來,這部分將要重寫為modern的native函數。

10,生成NativeFunctions.h

根據模板NativeFunctions.h生成源文件NativeFunctions.h,其中佔位符 ${native_function_declarations}將由下面的內容替換(真正的內容來自於native_functions.yaml):

CAFFE2_API ${return_type} ${native_type_method_dispatch}(${formals_with_defaults});

這個文件聲明了所有的native函數。ATen的native函數才是PyTorch的函數應該有的樣子,區別於歷史上老舊的TH/THC cwrap函數。目前PyTorch這部分的遷移工作還未結束,有些老舊的來自Torch歷史博物館的TH/THC函數仍然在使用中。 Native 函數都會被聲明在native_functions.yaml中,然後實現在aten/src/ATen/native/目錄下的cpp文件中。

native函數可以通過C++和Python介面調用。在C++中,可以通過Tensor類的方法和at namespace中的函數來進行訪問;在Python中,可以通過Variable類或者torch._C._FunctionBase來訪問。

總結

在本文中,gemfield描述了Python中ATen的代碼動態生成。ATen中動態生成的代碼主要有:

1,Type繼承體系,包含頭文件和源文件;Type繼承體系/家族是聯繫Tensor op 與 legacy的TH/現代的native函數 之間的紐帶;

2,Declarations.yaml,會被Torch模塊動態生成代碼調用;

3,生成Tensor類;

4,生成Type家族註冊初始化的代碼;

5,生成legacy的TH/THC的函數的聲明;

6,生成modern PyTorch的native 函數的聲明。

推薦閱讀:

相关文章