axis参数是理解数组操作的关键,它决定了函数沿着哪个轴进行计算。很多初学者对此感到困惑,这里进行一个系统的总结。
一、axis的基本概念
核心思想:axis参数指定了操作进行的方向,或者说沿着哪个轴的方向进行压缩。
二维数组(矩阵)的情况:
import numpy as np
arr = np.array([[1, 2, 3],
[4, 5, 6]])
# shape = (2, 3)
axis=0:跨行操作(垂直方向),压缩行维度
axis=1:跨列操作(水平方向),压缩列维度
高维数组的情况:
对于shape为(a, b, c, d)的数组:
axis=0:操作后维度变为(b, c, d)(去掉最外层维度)
axis=1:操作后维度变为(a, c, d)
axis=2:操作后维度变为(a, b, d)
axis=3:操作后维度变为(a, b, c)
二、常见函数的axis用法
1. 聚合函数(求和、均值等)
arr = np.array([[1, 2, 3],
[4, 5, 6]])
# axis=0:跨行(垂直)求和,每列求和
np.sum(arr, axis=0) # 结果:[5, 7, 9] shape: (3,)
# axis=1:跨列(水平)求和,每行求和
np.sum(arr, axis=1) # 结果:[6, 15] shape: (2,)
# 不指定axis:所有元素求和
np.sum(arr) # 结果:21
2. 连接函数(concatenate, stack等)
a = np.array([[1, 2], [3, 4]])
b = np.array([[5, 6], [7, 8]])
# axis=0:垂直堆叠(增加行)
np.concatenate([a, b], axis=0) # shape: (4, 2)
# axis=1:水平堆叠(增加列)
np.concatenate([a, b], axis=1) # shape: (2, 4)
# stack:创建新轴
np.stack([a, b], axis=0) # shape: (2, 2, 2),在最外层加维度
np.stack([a, b], axis=2) # shape: (2, 2, 2),在最内层加维度
3. 统计函数(argmax, sort等)
arr = np.array([[3, 1, 4],
[1, 5, 9]])
# axis=0:每列找最大值的位置
np.argmax(arr, axis=0) # [0, 1, 1]
# axis=1:每行找最大值的位置
np.argmax(arr, axis=1) # [2, 2]
# 排序
np.sort(arr, axis=0) # 每列排序
np.sort(arr, axis=1) # 每行排序
三、实用记忆技巧
1. "压缩轴"思维
口诀:axis=n表示去掉第n个维度(沿着这个方向压缩)
arr = np.ones((3, 4, 5))
# shape: (3, 4, 5)
result = np.sum(arr, axis=1)
# 沿着axis=1(第二个维度)压缩
# 结果shape: (3, 5) # 去掉了中间的4
2. 维度变化规律
原始shape:(d0, d1, d2, ..., dn)
沿axis=k操作后:(d0, d1, ..., d(k-1), d(k+1), ..., dn)
3. 可视化理解
对于二维数组:
axis=1
----->
[1, 2, 3] ↑
[4, 5, 6] | axis=0
↓
axis=0:从上到下操作(跨行)
axis=1:从左到右操作(跨列)
四、常见易错点
1. 维度减少
大多数聚合操作会减少维度:
arr = np.array([[1, 2], [3, 4]]) # shape: (2, 2)
result = np.sum(arr, axis=0) # shape: (2,)
# 从2D变成了1D
2. keepdims参数
如果想保持维度不变:
arr = np.array([[1, 2], [3, 4]])
np.sum(arr, axis=0, keepdims=True) # shape: (1, 2)
np.sum(arr, axis=1, keepdims=True) # shape: (2, 1)
3. 广播中的axis
a = np.array([[1, 2, 3]]) # shape: (1, 3)
b = np.array([[1], [2]]) # shape: (2, 1)
a + b # shape: (2, 3),自动广播
五、实战练习
理解axis最好的方式是多练习:
# 创建一个3D数组
arr_3d = np.arange(24).reshape(2, 3, 4)
# 练习不同的axis操作
print("原始shape:", arr_3d.shape) # (2, 3, 4)
# 沿着不同轴求和
print("axis=0 sum shape:", np.sum(arr_3d, axis=0).shape) # (3, 4)
print("axis=1 sum shape:", np.sum(arr_3d, axis=1).shape) # (2, 4)
print("axis=2 sum shape:", np.sum(arr_3d, axis=2).shape) # (2, 3)
# 沿着多个轴操作
print("axis=(0,1) sum shape:", np.sum(arr_3d, axis=(0, 1)).shape) # (4,)
总结
axis指定操作方向:沿着哪个轴计算
维度变化:操作通常会压缩(减少)指定的轴
从外到内:axis=0是最外层,axis=-1是最内层
多练习:使用不同维度的数组进行实验
记住:当不确定时,创建一个简单的测试数组,亲自尝试不同axis参数的效果,这是最有效的学习方法。