torch.argmax

torch.argmax#

torch.argmax(input, dim=None, keepdim=False)

参数说明:

  • input:输入的张量。

  • dim:指定在哪个维度上寻找最大值,默认为None,表示在整个张量中寻找最大值。

  • keepdim:是否保持输出张量的维度和输入张量一致,默认为False。

import torch
a = torch.randn(4, 5)
print(a)
tensor([[-1.2369, -1.6170, -1.4609, -1.5570, -0.6633],
        [-0.2065,  0.1751,  2.0633,  1.2556, -0.3395],
        [-0.8613,  1.0231,  0.3513,  0.5407, -0.1302],
        [ 2.1049,  1.1248, -2.3158, -0.8583, -0.5402]])
b=torch.argmax(a, dim=1)
b
tensor([4, 2, 1, 0])
c=torch.argmax(a)
c
tensor(15)
d=torch.argmax(a, dim=0)
d
tensor([3, 3, 1, 1, 2])