torch.argmax#
torch.argmax(input, dim=None, keepdim=False)
参数说明:
input:输入的张量。
dim:指定在哪个维度上寻找最大值,默认为None,表示在整个张量中寻找最大值。
keepdim:是否保持输出张量的维度和输入张量一致,默认为False。
import torch
a = torch.randn(4, 5)
print(a)
tensor([[-1.2369, -1.6170, -1.4609, -1.5570, -0.6633],
[-0.2065, 0.1751, 2.0633, 1.2556, -0.3395],
[-0.8613, 1.0231, 0.3513, 0.5407, -0.1302],
[ 2.1049, 1.1248, -2.3158, -0.8583, -0.5402]])
b=torch.argmax(a, dim=1)
b
tensor([4, 2, 1, 0])
c=torch.argmax(a)
c
tensor(15)
d=torch.argmax(a, dim=0)
d
tensor([3, 3, 1, 1, 2])