torch.argmax#
torch.argmax(input, dim=None, keepdim=False)
参数说明:
input:输入的张量。
dim:指定在哪个维度上寻找最大值,默认为None,表示在整个张量中寻找最大值。
keepdim:是否保持输出张量的维度和输入张量一致,默认为False。
import torch
a = torch.randn(4, 5)
print(a)
tensor([[ 1.3138, 0.8097, 0.6892, -0.6997, 1.1975],
[ 0.9036, 0.9653, 1.2415, -1.0621, -0.6339],
[ 0.7873, -0.7691, 0.8263, 1.8420, 0.2670],
[-0.1668, -1.1357, 1.2815, 0.3972, -0.6037]])
b=torch.argmax(a, dim=1)
b
tensor([0, 2, 3, 2])
c=torch.argmax(a)
c
tensor(13)
d=torch.argmax(a, dim=0)
d
tensor([0, 1, 3, 2, 0])