专栏名称: 机器之心
专业的人工智能媒体和产业服务平台
目录
相关文章推荐
51好读  ›  专栏  ›  机器之心

教程 | PyTorch内部机制解析:如何通过PyTorch实现Tensor

机器之心  · 公众号  · AI  · 2017-08-07 00:00

正文

请到「今天看啥」查看全文



通用构建(第一部分)


我们可以花费大量时间探索 THPTensor 的各个方面,以及它如何与一个新定义 Python 对象相关联。但是我们仍然需要明白 THPTensor_(init)()函数是如何转换成我们在模块初始化中使用的 THPIntTensor_init()函数。我们该如何使用定义「通用」Tensor 的 Tensor.cpp 文件,并使用它来生成所有类型序列的 Python 对象?换句话说,Tensor.cpp 里遍布着如下代码:


return THPTensor_(New)(THTensor_(new)(LIBRARY_STATE_NOARGS));


这说明了我们需要使类型特定的两种情况:


  • 我们的输出代码将调用 THP Tensor_New(...)代替调用 THPTensor_(New)

  • 我们的输出代码将调用 TH Tensor_new(...)代替调用 THTensor_(new)


换句话说,对于所有支持的 Tensor 类型,我们需要「生成」已经完成上述替换的源代码。这是 PyTorch 的「构建」过程的一部分。PyTorch 依赖于配置工具(https://setuptools.readthedocs.io/en/latest/)来构建软件包,我们在顶层目录中定义一个 setup.py 文件来自定义构建过程。


使用配置工具构建扩展模块的一个组件是列出编译中涉及的源文件。但是,我们的 csrc/generic/Tensor.cpp 文件未列出!那么这个文件中的代码最终是如何成为最终产品的一部分呢?


回想前文所述,我们从以上的 generic 目录中调用 THPTensor *函数(如 init)。如果我们来看一下这个目录,会发现一个定义了的 Tensor.cpp 文件。此文件的最后一行很重要:


//generic_include TH torch/csrc/generic/Tensor.cpp


请注意,虽然这个 Tensor.cpp 文件被 setup.py 文件引用,但它被包装在一个叫 Python helper 的名为 split_types 的函数里。这个函数需要输入一个文件,并在该文件内容中寻找「//generic_include」字符串。如果能匹配该字符串,它将会为每个张量类型生成一个具有以下变动的输出文件,:


1. 输出文件重命名为 Tensor .cpp

2. 输出文件小幅修改如下:


// Before:
//generic_include TH torch/csrc/generic/Tensor.cpp

// After:
#define TH_GENERIC_FILE "torch/src/generic/Tensor.cpp"
#include "TH/THGenerateType.h"


引入第二行的头文件有些许弊端,例如,引入了一些额外的上下文中定义的 Tensor.cpp 源代码。让我们看看其中一个头文件:


#ifndef TH_GENERIC_FILE
#error "You must define TH_GENERIC_FILE before including THGenerateFloatType.h"
#endif

#define real float
#define accreal double
#define TH_CONVERT_REAL_TO_ACCREAL(_val) (accreal)(_val)
#define TH_CONVERT_ACCREAL_TO_REAL(_val) (real)(_val)
#define Real Float
#define THInf FLT_MAX
#define TH_REAL_IS_FLOAT
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef accreal
#undef real
#undef Real
#undef THInf
#undef TH_REAL_IS_FLOAT
#undef TH_CONVERT_REAL_TO_ACCREAL
#undef TH_CONVERT_ACCREAL_TO_REAL

#ifndef THGenerateManyTypes
#undef TH_GENERIC_FILE
#endif


这样做的目的是从通用 Tensor.cpp 文件引入代码,并使用后面的宏定义。例如,我们将 real 定义为一个浮点数,所以泛型 Tensor 实现中的任何代码将指向一个 real 对象,实际上 real 被替换为浮点数。在对应的文件 THGenerateIntType.h 中,同样的宏定义将用 int 替换 real。


这些输出文件从 split_types 返回,并添加到源文件列表中,因此我们可以看到不同的类型的.cpp 代码是如何创建的。


这里需要注意以下几点:第一,split_types 函数不是必需的。我们可以将 Tensor.cpp 中的代码包装在一个文件中,然后为每个类型重复使用。我们将代码分割成单独文件的原因是这样可以加快编译速度。第二,当我们谈论类型替换(例如用浮点数代替 real)时,我们的意思是,C 预处理器将在编译期执行这些替换。并且在预处理之前这些嵌入源代码的宏定义都没有什么弊端。


通用构建(第二部分)


我们现在有所有的 Tensor 类型的源文件,我们需要考虑如何创建相应的头文件声明,以及如何将 THTensor_(方法)和 THPTensor_(方法)转化成 TH Tensor_method 和 THP Tensor_method。例如,csrc/generic/Tensor.h 具有如下声明:


THP_API PyObject * THPTensor_(New)(THTensor *ptr);

我们使用相同的策略在头文件的源文件中生成代码。在 csrc/Tensor.h 中,我们执行以下操作:


#include "generic/Tensor.h"
#include 

#include "generic/Tensor.h"
#include 

从通用的头文件中抽取代码和用相同的宏定义包装每个类型具有同样的效果。唯一的区别就是前者编译后的代码包含在同一个头文件中,而不是分为多个源文件。


最后,我们需要考虑如何「转换」或「替代」函数类型。如果我们查看相同的头文件,我们会看到一堆 #define 语句,其中包括:


#define THPTensor_(NAME)            TH_CONCAT_4(THP,Real,Tensor_,NAME)


这个宏表示,源代码中的任何匹配形如 THPTensor_(NAME)的字符串都应该替换为 THPRealTensor_NAME,其中 Real 参数是从符号 Real 所在的 #define 定义的时候派生的。因为我们的头文件代码和源代码都包含所有上述类型的宏定义,所以在预处理器运行之后,生成的代码就是我们想要的。


TH 库中的代码为 THTensor_(NAME)定义了相同的宏,支持这些功能的转移。如此一来,我们最终就会得到带有专用代码的头文件和源文件。


#### 模块对象和类型方法,我们现在已经看到如何在 THP 中封装 TH 的 Tensor 定义,并生成了 THPFloatTensor_init(...)等 THP 方法。现在我们可以从我们创建的模块中了解上面的代码实际上做了什么。THPTensor_(init)中的关键行是:


# THPTensorBaseStr, THPTensorType are also macros that are specific 
# to each type
PyModule_AddObject(module, THPTensorBaseStr, (PyObject *)&THPTensorType);


该函数将 Tensor 对象注册到扩展模块,因此我们可以在我们的 Python 代码中使用 THPFloatTensor,THPIntTensor 等。


只是单纯的创建 Tensors 不是很有用 - 我们需要能够调用 TH 定义的所有方法。以下是一个在 Tensor 上调用就地(in-place)zero_ 方法的简单例子。


x = torch.FloatTensor(10)
x.zero_()


我们先看看如何向新定义的类型中添加方法。「类型对象」中的有一个字段 tp_methods。此字段包含方法定义数组(PyMethodDefs),用于将方法(及其底层 C / C ++实现)与类型相关联。假设我们想在我们的 PyFloatObject 上定义一个替换该值的新方法。我们可以按照下面的步骤来实现这一想法:


static PyObject * replace(PyFloatObject *self, PyObject *args) {
	double val;
	if (!PyArg_ParseTuple(args, "d", &val))
		return NULL;
	self->ob_fval = val;
	Py_RETURN_NONE
}


Python 版本的等价方法


def replace(self, val):
	self.ob_fval = fal






请到「今天看啥」查看全文