Skip to content

数据操作 + 数据预处理

N 维矩阵

  • 标量:0 维
  • 向量:1 维
  • 矩阵:2 维
  • RGB 图片:3 维(宽、高、通道)
  • RGB 图片批量:4 维(就是多张 RGB 图片一起传)
  • 一个视频批量:5 维(多张 RGB 图片 + 时间维度)

创建张量

python
x = torch.tensor([1, 2, 3, 4])

张量:表示一个数值组成的数组,这个数组可能有多个维度。

可以用 Python 的列表来进行类型转化。

访问元素

冒号代表切片,可以访问一列、一行、一个子矩阵

也可以跳着访问 ,比如 [::3, ::2] 代表每三行取一行,每两列取一列

也可以用这种方式修改元素,区域赋值。

数据操作

这里我们使用 Pytorch 框架来使用。

张量

  • 形状
python
x.shape
  • 总数
python
x.numel()
  • 调整形状
python
x = x.reshape(2, 2)

上面的一维向量就会变成矩阵

  • 运算

加减乘除求幂等,元素间运算

python
y = torch.tensor([1, 2, 3, 4])
z = x + y
z = x - y
z = x * y
z = x / y
z = x ** y
torch.exp(x)
  • 对所有元素求和
python
x.sum()

结果一定是一个标量。

是降维的操作

  • 指定求和汇总张量的轴
python
x.shape # x 的形状(维度)
x.sum(axis=0) # 按列求和
x.sum(axis=1) # 按行求和
x.sum(axis=(0, 1)) # 按行按列求和
  • 如果不想丢掉维度,可以使用 keepdim=True

  • 这样和原本的维度一样,就可以使用广播机制了。例如:x/x.sum(axis=1, keepdim=True)

  • 张量间拼接
python
torch.cat((x, y), dim=0) # 第 0 维拼接
torch.cat((x, y), dim=1) # 第 1 维拼接
  • 张量间堆叠
python
torch.stack((x, y), dim=0) # 第 0 维连接
torch.stack((x, y), dim=1) # 第 1 维连接
torch.cat (拼接)torch.stack (堆叠)
功能本质合并现有维度上的数据创建新维度后堆叠数据
输入要求除拼接维度外,其他维度必须相同所有维度必须完全相同
输出维度保持维度数量不变维度数量+1
类比场景书本并排放在书架上(厚度增加)书本叠放在一起(高度增加)
  • cat = Combine Along Thickness (沿厚度合并)

  • stack = Start Totally A Completely new Kdimension (开启新维度)

  • 矩阵乘法

python
A = torch.arange(20).reshape(5, 4)
B = torch.arange(12).reshape(4, 3)
C = torch.matmul(A, B)
  • 广播机制

两个形状不同的数组进行运算时,可以自动扩展形状以匹配。

维度扩展的规则是:如果两个数组的形状在某个维度上的值相同,或者其中一个为 1,那么这两个数组在该维度上是兼容的。

WARNING

虽然很方便,但是要小心维度的增加可能会不如我们所愿。

赋值

python
before = id(x)
x = x + y
id(x) == before

返回结果是 False,证明在赋值时 python 创建了一个新的对象。

原地操作

有时我们为了节省内存,会进行原地操作。

python
Z = torch.zeros_like(x) # 创建形状和 x 一样的全 0 矩阵

Z[:] = x + y  # 原地操作,对 Z 的所有元素进行改写

如果后续不会使用到 x,那么可以直接使用 x[:] = x + yx += y 来节省内存。

python
# 低效方式(产生中间变量)
x = x + y + z  # 创建两个临时张量

# 高效方式(原地操作)
x += y
x += z         # 零额外内存分配

这里我对 x +=y 产生了疑问,因为在我的固有认知里, x += y 的结果和 x = x + y 的结果是一样的,可是上面的结果却截然相反。一种是原地操作,而另一种却是创建了新对象。

WARNING

下方部分内容采用了 Deepseek AI 生成讲解,并做了一定的修改,请读者注意鉴别内容真实性。

这里我拿 Pytorch 的张量和 Python 的列表都进行了实验,结果如下:

操作张量行为列表行为本质原因
x = x + y创建新对象创建新列表操作符重载(__add__方法)
x += y原地修改原地修改操作符重载(__iadd__方法)
x[:] = ...原地修改原地修改索引赋值的特殊语法

PyTorch 张量是可变对象,但具体是否原地修改取决于操作符的实现方式:

python
# 案例对比
x = torch.tensor([1,2,3])
y = torch.tensor([4,5,6])

print(id(x))        # 原始地址:0x7f8a1c
x = x + y           # 创建新对象(调用 __add__)
print(id(x))        # 新地址:0x7f8b2d ❌

x += y              # 原地修改(调用 __iadd__)
print(id(x))        # 地址不变:0x7f8b2d ✅

与其他语言的对比

语言数组类型+= 行为设计理念
Pythonlist原地修改可变对象标准行为
Pythontorch.Tensor原地修改深度学习内存优化需求
C++原生数组原地修改直接内存操作
JavaScriptArray原地修改动态数组特性

📚 浅拷贝 vs 深拷贝

类型特点张量对应操作内存影响
浅拷贝创建新对象引用原数据y = x.view(...)共享存储(危险!)
深拷贝完全复制数据y = x.clone()独立存储(安全)

经典案例

python
a = torch.tensor([1,2,3])
b = a          # 浅拷贝(同一对象)
c = a.clone()  # 深拷贝

a[0] = 999
print(b[0])    # 输出 999 ❗
print(c[0])    # 输出 1 ✅

🔧 最佳实践指南

1. 显式控制内存

python
# 推荐方式
x.add_(y)                # 明确表达原地操作
torch.add(x, y, out=x)   # 指定输出位置

# 不推荐方式
x = x + y                # 产生不可控的中间变量

2. 内存优化技巧

python
# 重用预分配内存
buffer = torch.empty_like(x)
buffer[:] = x + y + z    # 单次内存分配

# 链式操作优化
(x + y).add_(z)          # 比 x + y + z 少一次分配

3. 检测工具

python
def is_inplace(op):
    return hasattr(op, '__iadd__')  # 检测是否支持原地操作

print(is_inplace(torch.Tensor))     # 输出 True

🌰 现实类比理解

  • 创建新对象:像复印文件后修改复印件(原文件不变)
  • 原地操作:直接在原文件上批注修改
  • 浅拷贝:同一份文件投向多个显示屏(一处修改处处可见)
  • 深拷贝:复印文件后各自独立保存

操作符重载

  • + 操作符:对应 __add__ 方法,总是返回新对象
  • += 操作符:对应 __iadd__ 方法,优先尝试原地修改

🌍 现实世界类比

想象你有一个 万能遥控器

  • 普通模式:按「+」键只是调高电视音量
  • 空调模式:同一个「+」键变成调高温度
  • 游戏机模式:「+」键又变成了切换武器

操作符重载就像这个遥控器的智能切换—— 同一个操作符(如 +)在不同场景下触发不同行为,具体行为由操作对象的类型决定。

💻 代码世界中的操作符重载

在编程中,操作符重载允许我们 为自定义类型(如类)定义操作符的行为。通过实现特定的魔法方法(如 __add__),我们可以控制 + 等操作符的具体功能。

基础示例(数值 vs 字符串)

python
# 数值加法(数学运算)
print(3 + 5)        # 输出 8 → 执行算术加法

# 字符串拼接(连接操作)
print("Hello" + "World")  # 输出 HelloWorld → 执行连接操作

这里 + 操作符的 实际行为取决于操作数类型,这正是操作符重载的体现。

🔧 自定义类的操作符重载

假设我们创建一个「购物车」类,希望用 + 直接合并商品:

1. 定义类及重载方法

python
class ShoppingCart:
    def __init__(self, items):
        self.items = items.copy()  # 防止浅拷贝问题
    
    # 重载 + 操作符(对应 __add__ 方法)
    def __add__(self, other_cart):
        # 合并两个购物车的商品
        new_items = self.items + other_cart.items
        return ShoppingCart(new_items)  # 返回新对象
    
    # 重载 += 操作符(对应 __iadd__ 方法)
    def __iadd__(self, other_cart):
        # 原地合并,不创建新对象
        self.items.extend(other_cart.items)
        return self  # 必须返回自身引用

    def __repr__(self):
        return f"购物车商品:{self.items}"

2. 使用效果对比

python
# 初始化购物车
cart_a = ShoppingCart(["苹果", "牛奶"])
cart_b = ShoppingCart(["面包", "鸡蛋"])

# 使用 + 操作符(调用 __add__)
combined_cart = cart_a + cart_b
print(combined_cart)  # 输出:购物车商品:['苹果', '牛奶', '面包', '鸡蛋']
print(id(cart_a) == id(combined_cart))  # False → 新对象

# 使用 += 操作符(调用 __iadd__)
cart_a += cart_b
print(cart_a)  # 输出:购物车商品:['苹果', '牛奶', '面包', '鸡蛋']
print(id(cart_a))  # 与原始id相同 → 原地修改

📚 操作符重载核心机制

操作符对应方法典型行为是否需要返回新对象
+__add__合并/相加是 ✅
+=__iadd__原地扩展/累加否 ❌(返回self)
==__eq__比较内容是否相等返回布尔值
[]__getitem__实现索引访问-

设计原则

  1. 保持直观:重载后的行为要符合直觉(如 Vector(1,2) + Vector(3,4) 应得到 Vector(4,6))
  2. 区分创建与修改+ 应产生新对象,+= 应原地修改
  3. 类型安全:处理不同类型操作时的兼容性(如购物车不能直接加数字)

🚀 PyTorch 中的应用场景

在深度学习框架中,操作符重载被大量使用以实现直观的数学表达:

张量运算示例

python
import torch

# 创建张量
a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([3.0, 4.0], requires_grad=True)

# 通过重载的操作符进行自动微分计算
c = a + b       # 等价于 torch.add(a, b)
d = c * 2       # 等价于 torch.mul(c, 2)
loss = d.sum()  # 等价于 torch.sum(d)

loss.backward() # 自动计算梯度
print(a.grad)   # 输出:tensor([2., 2.])

优势体现

  • 代码简洁a + btorch.add(a, b) 更直观
  • 兼容自动微分:重载的操作符能记录计算图
  • 性能优化:底层通过C++实现高效运算

⚠️ 注意事项(新手常见坑)

  1. 不要滥用重载:确保操作符行为符合普遍认知

  2. 深浅拷贝问题

    python
    # 危险示例(浅拷贝)
    class Matrix:
        def __init__(self, data):
            self.data = data  # 直接引用外部列表
        
        def __add__(self, other):
            return Matrix(self.data + other.data)  # 浅拷贝隐患!
    
    # 安全做法:
    def __init__(self, data):
        self.data = data.copy()  # 创建副本
  3. 类型检查:处理不同类型操作时的容错机制

    python
    def __add__(self, other):
        if not isinstance(other, ShoppingCart):
            raise TypeError("只能合并购物车对象!")
        # ...合并逻辑...

🔍 检测重载方法

可以通过 dir() 查看对象支持的重载方法:

python
print(dir(torch.Tensor))  # 会显示 __add__, __iadd__ 等方法

通过理解操作符重载,你就能真正掌握像 PyTorch 这样的框架如何实现既直观又高效的张量运算啦!(๑•̀ㅂ•́)و✧

数据预处理

使用 Pandas 库来进行数据预处理。

处理缺失的数据

可以考虑插值或是删除。

插值

使用 fillna() 方法进行插值。

64 位浮点数对深度学习比较慢,我们之后会使用 32 位浮点数。

如果有字符串,最简单的方式是将字符串转换为数字。

例如老师的视频中,地点列中除了第一个是 某某胡同,其他行都是 NaN,那么我们可以分为两类,分成是 NaN某某胡同 的。分别用 0 和 1 来表示,这样我们的数据就变成了纯数字了。