torch.argmax#
torch.argmax(input, dim=None, keepdim=False)
参数说明:
input:输入的张量。
dim:指定在哪个维度上寻找最大值,默认为None,表示在整个张量中寻找最大值。
keepdim:是否保持输出张量的维度和输入张量一致,默认为False。
import torch
a = torch.randn(4, 5)
print(a)
tensor([[ 1.3138,  0.8097,  0.6892, -0.6997,  1.1975],
        [ 0.9036,  0.9653,  1.2415, -1.0621, -0.6339],
        [ 0.7873, -0.7691,  0.8263,  1.8420,  0.2670],
        [-0.1668, -1.1357,  1.2815,  0.3972, -0.6037]])
b=torch.argmax(a, dim=1)
b
tensor([0, 2, 3, 2])
c=torch.argmax(a)
c
tensor(13)
d=torch.argmax(a, dim=0)
d
tensor([0, 1, 3, 2, 0])