杰瑞科技汇

Python squeeze函数如何使用?

这个函数的名字很形象,"squeeze" 意为“挤压”,在 NumPy 中,它的作用就是“挤压”掉数组中维度大小为 1 的维度。

Python squeeze函数如何使用?-图1
(图片来源网络,侵删)

核心功能

numpy.squeeze(a, axis=None) 的主要功能是:

从给定的 NumPy 数组 a 中移除所有维度大小为 1 的轴。

  • 输入:一个 NumPy 数组。
  • 输出:一个新数组,其中所有长度为 1 的维度都被移除。
  • 重要:它不会修改原始数组,而是返回一个新的数组。

语法和参数

numpy.squeeze(a, axis=None)

参数详解:

  • a: 输入的 NumPy 数组。
  • axis (可选): 这是一个关键参数。
    • 不提供 axis (或设为 None),那么所有维度大小为 1 的维度都会被移除。
    • 提供了 axis,它必须是一个整数或一个整数元组,函数会尝试移除指定的轴,但前提是该轴的大小必须为 1,如果指定的轴大小不为 1,则会抛出 ValueError 错误。

工作原理与示例

让我们通过一系列例子来理解 squeeze() 是如何工作的。

示例 1:基本情况(移除所有大小为 1 的维度)

这是最常见的用法。

import numpy as np
# 创建一个形状为 (1, 3, 1, 5) 的数组
arr = np.arange(15).reshape(1, 3, 1, 5)
print("原始数组形状:", arr.shape)
# 输出: 原始数组形状: (1, 3, 1, 5)
# 使用 squeeze() 移除所有大小为 1 的维度
squeezed_arr = np.squeeze(arr)
print("挤压后数组形状:", squeezed_arr.shape)
# 输出: 挤压后数组形状: (3, 5)
# 原始数组保持不变
print("原始数组形状:", arr.shape)
# 输出: 原始数组形状: (1, 3, 1, 5)

解释: 原始数组有 4 个维度,其中第 0 个维度(大小为 1)和第 2 个维度(大小为 1)是“可以被挤压”的。squeeze() 移除了这两个维度,最终得到了一个二维 (3, 5) 的数组。


示例 2:axis 参数的使用

当你只想移除特定的某个维度时,axis 参数就派上用场了。

情况 A:指定一个有效的 axis(大小为 1)

import numpy as np
arr = np.arange(15).reshape(1, 3, 1, 5)
print("原始数组形状:", arr.shape) # (1, 3, 1, 5)
# 只想移除第 0 个维度 (axis=0)
squeezed_axis0 = np.squeeze(arr, axis=0)
print("挤压 axis=0 后的形状:", squeezed_axis0.shape)
# 输出: 挤压 axis=0 后的形状: (3, 1, 5)
# 只想移除第 2 个维度 (axis=2)
squeezed_axis2 = np.squeeze(arr, axis=2)
print("挤压 axis=2 后的形状:", squeezed_axis2.shape)
# 输出: 挤压 axis=2 后的形状: (1, 3, 5)

情况 B:指定一个无效的 axis(大小不为 1)

如果你试图移除一个大小不为 1 的维度,NumPy 会报错。

import numpy as np
arr = np.arange(15).reshape(1, 3, 1, 5)
print("原始数组形状:", arr.shape) # (1, 3, 1, 5)
# 尝试移除大小为 3 的维度 (axis=1)
try:
    np.squeeze(arr, axis=1)
except ValueError as e:
    print(f"错误: {e}")
# 输出: 错误: cannot select an axis to squeeze out which has size not equal to one

情况 C:使用元组指定多个 axis

你也可以一次性移除多个特定的维度,只要它们的大小都为 1。

import numpy as np
arr = np.arange(15).reshape(1, 3, 1, 5)
print("原始数组形状:", arr.shape) # (1, 3, 1, 5)
# 同时移除 axis=0 和 axis=2
squeezed_axes = np.squeeze(arr, axis=(0, 2))
print("挤压 axis=(0, 2) 后的形状:", squeezed_axes.shape)
# 输出: 挤压 axis=(0, 2) 后的形状: (3, 5)

这个结果和示例 1 中不指定 axis 的结果是一样的,但这里我们明确指出了要移除哪些维度。


示例 3:对没有可挤压维度的数组使用 squeeze()

如果数组的所有维度大小都不为 1,squeeze() 不会做任何事,直接返回原数组的副本。

import numpy as np
arr = np.array([[1, 2, 3], [4, 5, 6]]) # 形状为 (2, 3)
print("原始数组形状:", arr.shape)
squeezed_arr = np.squeeze(arr)
print("挤压后数组形状:", squeezed_arr.shape)
# 输出: 原始数组形状: (2, 3)
# 输出: 挤压后数组形状: (2, 3)

实际应用场景

squeeze() 在数据科学和机器学习中非常实用,尤其是在处理不同来源的数据时,因为它们的维度可能不一致。

场景 1:处理单张图片数据

假设你有一个灰度图像,其数据维度通常是 (height, width),但为了兼容某些模型的输入要求,你可能会将其形状调整为 (1, height, width, 1)

  • 1 是批处理大小(batch size)
  • 1 是通道数(对于灰度图是 1)
import numpy as np
# 模拟一张 28x28 的灰度图
# 形状为 (1, 28, 28, 1) -> (batch, height, width, channels)
image_batch = np.random.rand(1, 28, 28, 1)
print("图像批次形状:", image_batch.shape) # (1, 28, 28, 1)
# 如果某个模型要求输入是 (height, width, channels)
processed_image = np.squeeze(image_batch)
print("处理后图像形状:", processed_image.shape) # (28, 28, 1)
# 如果模型甚至只需要 (height, width)
final_image = np.squeeze(image_batch, axis=(0, 3))
print("最终图像形状:", final_image.shape) # (28, 28)

场景 2:与神经网络模型交互

很多深度学习框架(如 TensorFlow, PyTorch)在计算损失或进行预测时,可能会返回带有冗余维度的张量。

一个分类模型在处理单个样本时,输出可能是 [[0.1, 0.8, 0.1]],形状为 (1, 3),为了得到最终的类别概率 [0.1, 0.8, 0.1],就可以使用 squeeze()

import numpy as np
# 模型对单个样本的预测输出
# 形状为 (1, 3),表示 1 个样本,3 个类别的概率
model_output = np.array([[0.1, 0.8, 0.1]])
print("模型输出形状:", model_output.shape) # (1, 3)
# 移除批处理维度,得到概率向量
probabilities = np.squeeze(model_output)
print("概率向量形状:", probabilities.shape) # (3,)
# 现在可以轻松找到最大概率的类别
predicted_class = np.argmax(probabilities)
print(f"预测的类别是: {predicted_class}") # 预测的类别是: 1

重要注意事项

  1. 不修改原数组squeeze() 总是返回一个新数组,原始数组不会被改变。
  2. 数据视图(View)与副本(Copy):在大多数情况下,squeeze() 返回的是原始数据的视图(view),而不是副本,这意味着修改 squeezed_arr 中的元素也会影响原始 arr 中的元素,反之亦然,因为它们共享内存空间,这对于节省内存非常有用,但在某些特殊情况下(数组在内存中不是连续的),它可能会返回一个副本。
  3. reshape()expand_dims() 的关系
    • squeeze()reshape() 是互逆操作的一种形式。squeeze() 移除大小为 1 的维度,而 reshape() 可以改变维度,但不能凭空创建或移除数据。
    • squeeze()expand_dims() 是完全相反的操作。expand_dims 用于在指定位置增加一个大小为 1 的新维度。

特性 描述
功能 移除数组中大小为 1 的维度。
语法 numpy.squeeze(a, axis=None)
axis=None 移除所有大小为 1 的维度。
axis=N 只移除指定位置 N 且大小为 1 的维度,如果指定维度大小不为 1,则报错。
返回值 一个新的 NumPy 数组。
原数组 保持不变。
主要用途 清理数据维度,使其符合特定函数或模型的输入要求,尤其是在处理批处理数据和图像数据时。

当你看到一个数组形状里有很多 1,而这些 1 又没有实际意义时,squeeze() 就是你最好的工具之一。

分享:
扫描分享到社交APP
上一篇
下一篇