注:个人向李沐课程笔记
所有代码建议自己运行体会
更新日志
- 2023年4月5日:撰写关于python的原地操作(in-place)与拷贝
- 2023年4月6日:撰写一些基础知识
关于python的原地操作(in-place)与拷贝
以下类型以tensor为例,不同数据类型情况可能并不相同
初始化
import torch
t1 = torch.tensor([1, 2, 3, 4])
t2 = torch.tensor([5, 6, 7, 8])
原地操作
print(id(t1))
t1.add_(t2)
print(id(t1))
t1+=t2
print(id(t1))
t1[:]=t2
print(id(t1))
'''
输出结果:
2127245686720
2127245686720
2127245686720
2127245686720
'''
在tensor中,后缀为_的运算函数,如.add_()或.scatter_()等,+=、*=等运算符以及使用[:]=赋值所产生的操作为原地操作
原地操作在原对象所指向的数据内存中直接进行修改,不会发生拷贝,所以操作前后的对象id一致
其中t1[:]=t2表示只对t1的值即所指向的数据内存里存储的数据进行了修改,对t1本身不做修改,此处与t1=t2需做区分
原地操作在对元素较多的矩阵进行运算时,由于不发生析构和构造等操作,可以节省这方面在空间与时间上的开销从而加快运行速度
但是原地操作在pytorch的Autograd自动计算张量的梯度时可能会发生异常,在这种情况下不推荐使用
非原地操作
print('t1 id=',id(t1))
print('t1 data id = ',t1.storage().data_ptr())
t1=t1+t2
print('t1 id=',id(t1))
print('t1 data id = ',t1.storage().data_ptr())
'''
输出结果:
t1 id= 2127322448976
t1 data id = 2127292453376
t1 id= 2127245685440
t1 data id = 2127292454912
'''
除了上一小节外其它大多运算符为非原地操作。如上述代码所示,+运算符不仅改变了t1的头信息,还改变了t1所指向的数据内存,即在赋值的过程中t1+t2所产生的新的结果先存储在一块新的内存区域里,然后原来的t1被析构,再重构一个新的t1,同时将t1+t2的数据内存地址赋给t1
在进行大规模矩阵运算时,非原地操作会使运算在时间上的开销增大,理论上极限空间开销也会变大
浅拷贝
print(id(t2))
t2 = t1[:] # $1
print(id(t2))
print(t1)
print(t2)
t2.add_(1)
print(t1)
print(t2)
t2 = t2+1
print(t1)
print(t2)
'''
输出结果:
2127245687760
2127245685200
tensor([6, 7, 8, 9])
tensor([6, 7, 8, 9])
tensor([ 7, 8, 9, 10])
tensor([ 7, 8, 9, 10])
tensor([ 7, 8, 9, 10])
tensor([ 8, 9, 10, 11])
'''
$1处发生了浅拷贝,原t2所指向的内存被析构,并创建新的头信息内存,同时将t1所指向的数据内存地址赋给了t2,即此时t1和t2拥有同一块数据内存。当t1或t2数据发生原地操作时,对方的数据也会同样地发生变化;相反地,当t1或t2数据发生非原地操作时,则如上一小节所提到的双方数据将产生分歧
发生浅拷贝的方式有很多,但事实上很多方式由于各种原因最后导致结果和深拷贝一致,一般和tensor的结构有关。当要确保两个变量是浅拷贝时,最好先用.storage().data_ptr()判断一下,否则可能会发生异常
此外我们可以使用copy.copy()的方法来进行浅拷贝
import copy
print('t1 data id = ',t1.storage().data_ptr())
print('t2 data id = ',t2.storage().data_ptr())
t2 = copy.copy(t1)
print('t1 data id = ',t1.storage().data_ptr())
print('t2 data id = ',t2.storage().data_ptr())
'''
输出结果:
t1 data id = 2127292454016
t2 data id = 2127292453568
t1 data id = 2127292454016
t2 data id = 2127292454016
'''
tensor结构
参考:[PyTorch:view() 与 reshape() 区别详解] ,内含较为丰富的图片与代码解析
想要深入了解拷贝机制就必须先了解tensor的结构
tensor的结构分为头信息区(Tensor)和存储区 (Storage),头信息区存储了tensor的size、stride、grad_fn、数据索引、存储区地址等信息,存储区即为存储数据的内存区域首地址
tensor在内存中的数据都是以一维数组的形式存储,这里有个概念称为view,和数据库中的view相似,可以理解为头信息区不同但所指向存储区相同,从而使表达出来的数据不同,如下举例:
print('t1 data id = ',t1.storage().data_ptr())
print('t2 data id = ',t2.storage().data_ptr())
print(t1)
print(t2)
t1=t2[1:3]
print('t1 data id = ',t1.storage().data_ptr())
print('t2 data id = ',t2.storage().data_ptr())
print(t1)
print(t2)
'''
输出结果:
t1 data id = 2127292449856
t2 data id = 2127292449536
tensor([1, 2, 3, 4])
tensor([5, 6, 7, 8])
t1 data id = 2127292449536
t2 data id = 2127292449536
tensor([6, 7])
tensor([5, 6, 7, 8])
'''
这里可以看到,虽说t1和t2输出的数据结构不一样,但他们指向的是同一块数据内存,正是因为二者头信息区的索引不同
tensor的连续性
因为有view的存在,所以就有了连续性(contiguous)的概念。一个tensor是连续的当且仅当它的每个维度上的步长都是后一个维度上步长与该维度大小的乘积,且最后一个维度的步长为1。对于一维tensor而言,步长为1时即为连续;对于(2,3)大小的tensor而言,步长为(3,1)时即为连续
对于.reshape()而言,其返回可能是浅拷贝(即view),也可能是深拷贝(即copy),按官方文档的说法是,当tensor满足连续性时.reshape()返回一个view,否则将其强制变为连续并返回一个copy,但同时官方表示不要依赖于这个“规律”…
而鄙人在测试过程中也发现了非连续tensor进行reshape时亦可能返回一个view
p1 = torch.arange(12)
print('p1 stride = ',p1.stride())
print('p1 data id = ',p1.storage().data_ptr())
print(p1.is_contiguous())
p2 = p1.reshape(3,4)
print('p2 stride = ',p2.stride())
print('p2 data id = ',p2.storage().data_ptr())
print(p2.is_contiguous())
p3 = p2[::2,::2]
print('p3 stride = ',p3.stride())
print('p3 data id = ',p3.storage().data_ptr())
print(p3.is_contiguous())
p4 = p3.reshape(2,2)
print('p4 stride = ',p4.stride())
print('p4 data id = ',p4.storage().data_ptr())
print(p4.is_contiguous())
p5 = p2.t()
print('p5 stride = ',p5.stride())
print('p5 data id = ',p5.storage().data_ptr())
print(p5.is_contiguous())
p6 = p5.reshape(6,2)
print('p6 stride = ',p6.stride())
print('p6 data id = ',p6.storage().data_ptr())
print(p6.is_contiguous())
'''
输出结果:
p1 stride = (1,)
p1 data id = 2127293126720
True
p2 stride = (4, 1)
p2 data id = 2127293126720
True
p3 stride = (8, 2)
p3 data id = 2127293126720
False
p4 stride = (8, 2)
p4 data id = 2127293126720
False
p5 stride = (1, 4)
p5 data id = 2127293126720
False
p6 stride = (2, 1)
p6 data id = 2127293124992
True
'''
如上所示,t3改变了步长使之不连续,但其reshape返回的依旧是一个不连续view,同样地p5进行了转置使之不连续,此时返回的是一个连续的copy
深拷贝
本应浅拷贝的操作返回的不一定是view,但本应深拷贝的操作返回的一定是copy,通常可以使用copy.deepcopy()、.clone()、.detach()等,这其中的区别在写完梯度相关内容后补充