pytorch的维度变换函数¶
约 419 个字 6 行代码 1 张图片 预计阅读时间 2 分钟
维度转换函数¶
torch.unsqueeze(input, dim)
:在指定维度dim
上增加一个新的维度。如果dim
已经存在,则在其前面添加新的维度。torch.squeeze(input, dim=None)
:移除所有长度为1的维度。如果指定了dim
,则只移除该维度。torch.flatten(input, start_dim=0, end_dim=-1)
:将输入张量从start_dim
到end_dim
的所有维度展平。torch.view(input, size)
或input.view(size)
:重新调整张量的形状,不改变数据。torch.reshape(input, shape)
:与view
类似,用于改变张量的形状,但reshape
可以处理更复杂的维度变换,如增加或减少维度。torch.permute(input, dims)
:重新排列输入张量的维度,dims
是一个维度索引的元组。torch.transpose(input, dim0, dim1)
:交换输入张量的两个维度。torch.expand(input, size)
:将输入张量沿指定的维度复制扩展。torch.cat(tensors, dim)
:沿指定维度dim
连接多个张量。torch.stack(tensors, dim)
:沿新的维度dim
堆叠多个张量,与cat
不同的是,stack
会增加一个新的维度。-
torch.reapeat
-
torch.tile
Python
positon_embedding = torch.tile(positon_embedding_table[:seq_len],[token_embedding.shape[0],1,1])
# positon_embedding_table[:seq_len] = positon_embedding_table[:5] 取前5个8维
# [:5] 表示 对 第一维 索引
# positon_embedding_table[:seq_len] = 5,8
# [token_embedding.shape[0],1,1] = [1,1,1]
# positon_embedding = 1,5,8
理解张量¶
假如你有一个篮子,里面装满了各种颜色的小球。每个小球代表一个数字。现在,如果我们想把这些小球按照一定的顺序排列,比如一行或者一列,这就是一个一维数组。如果你把几行这样的小球排列起来,就形成了一个二维数组,就像一个表格一样。如果你再把这些表格堆叠起来,就形成了一个三维数组。在PyTorch中,张量就是一种用来表示这些不同维度数组的数据结构。
2024-11-15 17:55:43 2024-12-19 22:38:58