这个函数的名字很形象,"squeeze" 意为“挤压”,在 NumPy 中,它的作用就是“挤压”掉数组中维度大小为 1 的维度。

核心功能
numpy.squeeze(a, axis=None) 的主要功能是:
从给定的 NumPy 数组 a 中移除所有维度大小为 1 的轴。
- 输入:一个 NumPy 数组。
- 输出:一个新数组,其中所有长度为 1 的维度都被移除。
- 重要:它不会修改原始数组,而是返回一个新的数组。
语法和参数
numpy.squeeze(a, axis=None)
参数详解:
a: 输入的 NumPy 数组。axis(可选): 这是一个关键参数。- 不提供
axis(或设为None),那么所有维度大小为 1 的维度都会被移除。 - 提供了
axis,它必须是一个整数或一个整数元组,函数会尝试移除指定的轴,但前提是该轴的大小必须为 1,如果指定的轴大小不为 1,则会抛出ValueError错误。
- 不提供
工作原理与示例
让我们通过一系列例子来理解 squeeze() 是如何工作的。
示例 1:基本情况(移除所有大小为 1 的维度)
这是最常见的用法。
import numpy as np
# 创建一个形状为 (1, 3, 1, 5) 的数组
arr = np.arange(15).reshape(1, 3, 1, 5)
print("原始数组形状:", arr.shape)
# 输出: 原始数组形状: (1, 3, 1, 5)
# 使用 squeeze() 移除所有大小为 1 的维度
squeezed_arr = np.squeeze(arr)
print("挤压后数组形状:", squeezed_arr.shape)
# 输出: 挤压后数组形状: (3, 5)
# 原始数组保持不变
print("原始数组形状:", arr.shape)
# 输出: 原始数组形状: (1, 3, 1, 5)
解释:
原始数组有 4 个维度,其中第 0 个维度(大小为 1)和第 2 个维度(大小为 1)是“可以被挤压”的。squeeze() 移除了这两个维度,最终得到了一个二维 (3, 5) 的数组。
示例 2:axis 参数的使用
当你只想移除特定的某个维度时,axis 参数就派上用场了。
情况 A:指定一个有效的 axis(大小为 1)
import numpy as np
arr = np.arange(15).reshape(1, 3, 1, 5)
print("原始数组形状:", arr.shape) # (1, 3, 1, 5)
# 只想移除第 0 个维度 (axis=0)
squeezed_axis0 = np.squeeze(arr, axis=0)
print("挤压 axis=0 后的形状:", squeezed_axis0.shape)
# 输出: 挤压 axis=0 后的形状: (3, 1, 5)
# 只想移除第 2 个维度 (axis=2)
squeezed_axis2 = np.squeeze(arr, axis=2)
print("挤压 axis=2 后的形状:", squeezed_axis2.shape)
# 输出: 挤压 axis=2 后的形状: (1, 3, 5)
情况 B:指定一个无效的 axis(大小不为 1)
如果你试图移除一个大小不为 1 的维度,NumPy 会报错。
import numpy as np
arr = np.arange(15).reshape(1, 3, 1, 5)
print("原始数组形状:", arr.shape) # (1, 3, 1, 5)
# 尝试移除大小为 3 的维度 (axis=1)
try:
np.squeeze(arr, axis=1)
except ValueError as e:
print(f"错误: {e}")
# 输出: 错误: cannot select an axis to squeeze out which has size not equal to one
情况 C:使用元组指定多个 axis
你也可以一次性移除多个特定的维度,只要它们的大小都为 1。
import numpy as np
arr = np.arange(15).reshape(1, 3, 1, 5)
print("原始数组形状:", arr.shape) # (1, 3, 1, 5)
# 同时移除 axis=0 和 axis=2
squeezed_axes = np.squeeze(arr, axis=(0, 2))
print("挤压 axis=(0, 2) 后的形状:", squeezed_axes.shape)
# 输出: 挤压 axis=(0, 2) 后的形状: (3, 5)
这个结果和示例 1 中不指定 axis 的结果是一样的,但这里我们明确指出了要移除哪些维度。
示例 3:对没有可挤压维度的数组使用 squeeze()
如果数组的所有维度大小都不为 1,squeeze() 不会做任何事,直接返回原数组的副本。
import numpy as np
arr = np.array([[1, 2, 3], [4, 5, 6]]) # 形状为 (2, 3)
print("原始数组形状:", arr.shape)
squeezed_arr = np.squeeze(arr)
print("挤压后数组形状:", squeezed_arr.shape)
# 输出: 原始数组形状: (2, 3)
# 输出: 挤压后数组形状: (2, 3)
实际应用场景
squeeze() 在数据科学和机器学习中非常实用,尤其是在处理不同来源的数据时,因为它们的维度可能不一致。
场景 1:处理单张图片数据
假设你有一个灰度图像,其数据维度通常是 (height, width),但为了兼容某些模型的输入要求,你可能会将其形状调整为 (1, height, width, 1),
1是批处理大小(batch size)1是通道数(对于灰度图是 1)
import numpy as np
# 模拟一张 28x28 的灰度图
# 形状为 (1, 28, 28, 1) -> (batch, height, width, channels)
image_batch = np.random.rand(1, 28, 28, 1)
print("图像批次形状:", image_batch.shape) # (1, 28, 28, 1)
# 如果某个模型要求输入是 (height, width, channels)
processed_image = np.squeeze(image_batch)
print("处理后图像形状:", processed_image.shape) # (28, 28, 1)
# 如果模型甚至只需要 (height, width)
final_image = np.squeeze(image_batch, axis=(0, 3))
print("最终图像形状:", final_image.shape) # (28, 28)
场景 2:与神经网络模型交互
很多深度学习框架(如 TensorFlow, PyTorch)在计算损失或进行预测时,可能会返回带有冗余维度的张量。
一个分类模型在处理单个样本时,输出可能是 [[0.1, 0.8, 0.1]],形状为 (1, 3),为了得到最终的类别概率 [0.1, 0.8, 0.1],就可以使用 squeeze()。
import numpy as np
# 模型对单个样本的预测输出
# 形状为 (1, 3),表示 1 个样本,3 个类别的概率
model_output = np.array([[0.1, 0.8, 0.1]])
print("模型输出形状:", model_output.shape) # (1, 3)
# 移除批处理维度,得到概率向量
probabilities = np.squeeze(model_output)
print("概率向量形状:", probabilities.shape) # (3,)
# 现在可以轻松找到最大概率的类别
predicted_class = np.argmax(probabilities)
print(f"预测的类别是: {predicted_class}") # 预测的类别是: 1
重要注意事项
- 不修改原数组:
squeeze()总是返回一个新数组,原始数组不会被改变。 - 数据视图(View)与副本(Copy):在大多数情况下,
squeeze()返回的是原始数据的视图(view),而不是副本,这意味着修改squeezed_arr中的元素也会影响原始arr中的元素,反之亦然,因为它们共享内存空间,这对于节省内存非常有用,但在某些特殊情况下(数组在内存中不是连续的),它可能会返回一个副本。 - 与
reshape()和expand_dims()的关系:squeeze()和reshape()是互逆操作的一种形式。squeeze()移除大小为 1 的维度,而reshape()可以改变维度,但不能凭空创建或移除数据。squeeze()和expand_dims()是完全相反的操作。expand_dims用于在指定位置增加一个大小为 1 的新维度。
| 特性 | 描述 |
|---|---|
| 功能 | 移除数组中大小为 1 的维度。 |
| 语法 | numpy.squeeze(a, axis=None) |
axis=None |
移除所有大小为 1 的维度。 |
axis=N |
只移除指定位置 N 且大小为 1 的维度,如果指定维度大小不为 1,则报错。 |
| 返回值 | 一个新的 NumPy 数组。 |
| 原数组 | 保持不变。 |
| 主要用途 | 清理数据维度,使其符合特定函数或模型的输入要求,尤其是在处理批处理数据和图像数据时。 |
当你看到一个数组形状里有很多 1,而这些 1 又没有实际意义时,squeeze() 就是你最好的工具之一。
