背景

在使用PyTorch深度學習框架的時候,不管是訓練還是測試,代碼中引入PyTorch的第一句總是:

import torch

在Gemfield前述專欄文章里,我們已經得知,torch/csrc/stub.cpp鏈接libshm.so、libtorch_python.so、libcaffe2_gpu.so生成了_C.cpython-37m-x86_64-linux-gnu.so庫,而像前述方式import torch的時候,按照python規範,會找到torch package目錄下的__init__.py,在這個文件中進一步會調用:

from torch._C import *

其中torch._C就是_C.cpython-37m-x86_64-linux-gnu.so。因為(以Python3為例)按照Python規範,由於默認的引擎都是CPython,而CPython的C/C++擴展是一個共享庫,並且這個共享庫安裝在PYTHONPATH目錄下,並且文件名(不包含後綴)要和module的名字一樣,並且這個共享庫中要實現PyInit_modulename符號來作為import時候的邏輯入口。

對於PyTorch來說這個modulename 是_C,因此我們可以揣測,在torch/csrc/stub.cpp中一定實現了PyInit_C這個函數。是的,PyTorch就是這麼做的,torch/csrc/stub.cpp中的代碼就是下面這樣:

#include <Python.h>

extern PyObject* initModule();
PyMODINIT_FUNC PyInit__C()
{
return initModule();
}

本文將從initModule函數展開,全面闡述PyTorch框架的初始化工作。initModule就是PyTorch初始化時候的第一層調用棧了,因為所有的初始化工作都是在這個函數內完成的,內容比較多,gemfield將其劃分為7部分:

1,torch._C的誕生:

這一步就是產生torch._C類,並在這個python類上面註冊眾多函數:

PyObject* initModule() {
//openmp的設置
THInferNumThreads();
THPUtils_addPyMethodDefs(methods, TorchMethods);
THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions());
THPUtils_addPyMethodDefs(methods, THCPModule_methods());
THPUtils_addPyMethodDefs(methods, THCUDNN_methods());
THPUtils_addPyMethodDefs(methods, THDPModule_methods());
THPUtils_addPyMethodDefs(methods, torch::distributed::c10d::python_functions());

module = Py_InitModule("torch._C", methods.data());
......
}

其中TorchMethods註冊了29個方法,都是THPModule_前綴的函數;DataLoaderMethods註冊了4個方法,都是THPModule_前綴的函數;torch::autograd::python_functions註冊了4個方法;torch::multiprocessing::python_functions註冊了1個方法;THCPModule_methods註冊了37個CUDA相關的函數,前綴都是THCPModule_;THCUDNN_methods註冊了1個方法;THDPModule_methods註冊了28個方法;torch::distributed::c10d::python_functions註冊了1個方法。

總而言之,在這一小步,我們達到了這樣一個里程碑,torch._C符號誕生,並且向torch._C註冊了一百餘個函數,涉及torch、dataloader、autograd、multiprocess、cuda、cudnn、distribute、c10d方面。

2,一些關鍵類型

以下代碼先後初始化了torch._C._PtrWrapper、torch._C.Generator(含5個方法)、FatalError、torch.Size、torch.dtype、torch.iinfo、torch.layout、torch.device:

PyObject* initModule() {
......
THPWrapper_init(module);
THPGenerator_init(module);
THPException_init(module);
THPSize_init(module);
THPDtype_init(module);
THPDTypeInfo_init(module);
THPLayout_init(module);
THPDevice_init(module);
THPVariable_initModule(module);
THPFunction_initModule(module);
THPEngine_initModule(module);
......
}

3,torch._C._TensorBase的誕生

Gemfield將以下三個初始化函數歸為這一小節:

PyObject* initModule() {
......
THPVariable_initModule(module);
THPFunction_initModule(module);
THPEngine_initModule(module);
......
}

為什麼呢?因為地位太顯赫了。

THPVariable_initModule(module) 創建了torch._C._TensorBase,這是一切Tensor的基類,在Gemfield的其它專欄文章里將單獨解釋;

THPFunction_initModule(module)創建了torch._C._FunctionBase,在torch/autograd/function.py中,以下兩個類以torch._C._FunctionBase為基類:

class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin))
class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin)

這個Function繼承體系就構成了DAG的基礎。

THPEngine_initModule(module)創建了torch._C._EngineBase,_EngineBase這個類負責動態圖執行之前的preprocess,_EngineBase會將torch.autograd的backward之類的請求預處理後送給真正的Engine去執行。

4,pybind11綁定

這一小節的初始化內容都是和pybind11相關的:

PyObject* initModule() {
......
// NOTE: We need to be able to access OperatorExportTypes from ONNX for use in
// the export side of JIT, so this ONNX init needs to appear before the JIT
// init.
torch::onnx::initONNXBindings(module);
torch::jit::initJITBindings(module);
torch::autograd::initNNFunctions(module);
torch::autograd::init_legacy_variable(module);
torch::python::init_bindings(module);
torch::cuda::initModule(module);
......
}

initONNXBindings是ONNX的python binding:torch._C._onnx.TensorProtoDataType和torch._C._onnx.OperatorExportTypes:

>>> dir(torch._C._onnx.TensorProtoDataType)
[BOOL, COMPLEX128, COMPLEX64, DOUBLE, FLOAT, FLOAT16, INT16, INT32, INT64, INT8, STRING, UINT16, UINT32, UINT64, UINT8, UNDEFINED, __class__, __delattr__, __dir__, __doc__, __eq__, __format__, __ge__, __getattribute__, __getstate__, __gt__, __hash__, __init__, __int__, __le__, __lt__, __members__, __module__, __ne__, __new__, __reduce__, __reduce_ex__, __repr__, __setattr__, __setstate__, __sizeof__, __str__, __subclasshook__, name]
>>> dir(torch._C._onnx.OperatorExportTypes)
[ONNX, ONNX_ATEN, ONNX_ATEN_FALLBACK, RAW, __class__, __delattr__, __dir__, __doc__, __eq__, __format__, __ge__, __getattribute__, __getstate__, __gt__, __hash__, __init__, __int__, __le__, __lt__, __members__, __module__, __ne__, __new__, __reduce__, __reduce_ex__, __repr__, __setattr__, __setstate__, __sizeof__, __str__, __subclasshook__, name]

initJITBindings則是通過pybind11往torch._C上註冊了一堆和JIT相關的C++函數/對象;

initNNFunctions初始化了一個torch._C._nn 對象,並註冊了一些nn相關的函數:

>>> dir(torch._C._nn)
[__doc__, __loader__, __name__, __package__, __spec__, _parse_to, adaptive_avg_pool2d, adaptive_avg_pool3d, adaptive_max_pool2d, adaptive_max_pool3d, avg_pool2d, avg_pool3d, binary_cross_entropy, elu, elu_,
fractional_max_pool2d, glu, hardtanh, hardtanh_, l1_loss, leaky_relu, leaky_relu_, log_sigmoid, max_pool2d_with_indices, max_pool3d_with_indices, max_unpool2d, max_unpool3d, mse_loss, multi_margin_loss,
multilabel_margin_loss, nll_loss, nll_loss2d, reflection_pad1d, reflection_pad2d, replication_pad1d, replication_pad2d, replication_pad3d, rrelu_with_noise, rrelu_with_noise_, smooth_l1_loss, soft_margin_loss,
softplus, softshrink, thnn_conv2d, thnn_conv3d, thnn_conv_depthwise2d, thnn_conv_dilated2d, thnn_conv_dilated3d, thnn_conv_transpose2d, thnn_conv_transpose3d, upsample_bilinear2d, upsample_linear1d, upsample_nearest1d,
upsample_nearest2d, upsample_nearest3d, upsample_trilinear3d]

init_legacy_variable註冊了torch._C._LegacyVariableBase:

>>> dir(torch._C._LegacyVariableBase)
[__class__, __delattr__, __dir__, __doc__, __eq__, __format__,
__ge__, __getattribute__, __gt__, __hash__, __init__, __le__,
__lt__, __ne__, __new__, __reduce__, __reduce_ex__, __repr__,
__setattr__, __sizeof__, __str__, __subclasshook__]

_LegacyVariableBase類會派生出Variable類(該類的_execution_engine會初始化為torch._C._EngineBase):

class Variable(with_metaclass(VariableMeta, torch._C._LegacyVariableBase))

init_bindings是通過pybind11往torch._C上註冊一些函數,torch::cuda::initModule類似,也是通過pybind11往torch._C上註冊一些函數,只不過內容是和cuda相關的。

5,在torch._C上註冊StorageBase類

PyObject* initModule() {
......
THPDoubleStorage_init(module);
THPFloatStorage_init(module);
THPHalfStorage_init(module);
THPLongStorage_init(module);
THPIntStorage_init(module);
THPShortStorage_init(module);
THPCharStorage_init(module);
THPByteStorage_init(module);
THCPDoubleStorage_init(module);
THCPFloatStorage_init(module);
THCPHalfStorage_init(module);
THCPLongStorage_init(module);
THCPIntStorage_init(module);
THCPShortStorage_init(module);
THCPCharStorage_init(module);
THCPByteStorage_init(module);
THCPStream_init(module);
......
}

這些初始化工作主要就是往torch._C上註冊了以下類:

CudaByteStorageBase
CudaCharStorageBase
CudaDoubleStorageBase
CudaFloatStorageBase
CudaHalfStorageBase
CudaIntStorageBase
CudaLongStorageBase
CudaShortStorageBase

ByteStorageBase
CharStorageBase
DoubleStorageBase
FloatStorageBase
HalfStorageBase
IntStorageBase
LongStorageBase
ShortStorageBase

比如以FloatStorageBase為例的話,我們可以這樣查看它註冊的方法:

>>> dir(torch._C.FloatStorageBase)
[__class__, __delattr__, __delitem__, __dir__, __doc__, __eq__, __format__, __ge__, __getattribute__, __getitem__, __gt__, __hash__, __init__, __le__, __len__, __lt__,
__ne__, __new__, __reduce__, __reduce_ex__, __repr__, __setattr__, __setitem__, __sizeof__, __str__, __subclasshook__, _cdata, _expired, _free_weak_ref,
_get_shared_fd, _new_shared_fd, _new_shared_filename, _new_using_fd, _new_using_filename, _new_with_file, _new_with_weak_ptr, _set_cdata, _set_from_file, _share_fd_,
_share_filename_, _shared_decref, _shared_incref, _weak_ref, _write_file, copy_, data_ptr, element_size, fill_, from_buffer, from_file, is_pinned, is_shared, new,
resize_, size]

這些類會在python體系中被繼承:

class FloatStorage(_C.FloatStorageBase, _StorageBase)

另外注意下這塊代碼使用了一些宏來複用不同storage的代碼,如下所示:

aten/src/TH/THGenerateLongType.h:10:#define Real Long
aten/src/TH/THGenerateHalfType.h:10:#define Real Half
aten/src/TH/THGenerateIntType.h:10:#define Real Int
aten/src/TH/THGenerateFloatType.h:9:#define Real Float
aten/src/TH/THGenerateShortType.h:10:#define Real Short
aten/src/TH/THGenerateCharType.h:8:#define Real Char
aten/src/TH/THGenerateByteType.h:8:#define Real Byte
aten/src/TH/THGenerateDoubleType.h:9:#define Real Double
aten/src/THC/THCGenerateIntType.h:7:#define Real Int
aten/src/THC/THCGenerateLongType.h:7:#define Real Long
aten/src/THC/THCGenerateCharType.h:7:#define Real Char
aten/src/THC/THCGenerateFloatType.h:9:#define Real Float
aten/src/THC/THCGenerateDoubleType.h:7:#define Real Double
aten/src/THC/THCGenerateHalfType.h:9:#define Real Half
aten/src/THC/THCGenerateShortType.h:7:#define Real Short
aten/src/THC/THCGenerateByteType.h:7:#define Real Byte

6,ATen的初始化

本小節會進行ATen的global context的初始化,然後使用at::globalContext().defaultGenerator(at::kCPU)進行generator的初始化。

另外,PyTorch會根據編譯環境和用戶配置,然後向torch._C上註冊一些flag。這些flag有has_cudnn、has_mkl、has_lapack、_GLIBCXX_USE_CXX11_ABI:

PyObject* initModule() {
......
PyObject *has_cudnn = Py_True;
set_module_attr("has_cudnn", has_cudnn);

at::init();
py::reinterpret_borrow<py::module>(module).def("_demangle", &c10::demangle);
::c10::Warning::set_warning_handler(&warning_handler);

set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False);
set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False);
set_module_attr("_GLIBCXX_USE_CXX11_ABI", _GLIBCXX_USE_CXX11_ABI ? Py_True : Py_False);
auto& defaultGenerator = at::globalContext().defaultGenerator(at::kCPU);
THPDefaultGenerator = (THPGenerator*)THPGenerator_NewWithGenerator(defaultGenerator);
set_module_attr("default_generator", (PyObject*)THPDefaultGenerator, /* incref= */ false);

7,torch._C._THNN和torch._C._THCUNN的初始化

PyTorch在這一小節里註冊了torch._C._THNN和torch._C._THCUNN類:

PyObject* initModule() {
......
torch::nn::init__THNN(module);
torch::nn::init__THCUNN(module);
......
}

這兩個類都擁有數量巨大的op函數,一個是CPU版的,一個是CUDA版的。

總結

在PyTorch的初始化階段,(python)torch模塊先後初始化產生torch._C、torch._C._TensorBase、pybind11綁定、torch._C.*StorageBase、torch._C._THNN、torch._C._THCUNN,並進行了ATen context的初始化。


推薦閱讀:
相关文章