torch.argmax#

torch.argmax(input, dim=None, keepdim=False)

参数说明:

  • input:输入的张量。

  • dim:指定在哪个维度上寻找最大值,默认为None,表示在整个张量中寻找最大值。

  • keepdim:是否保持输出张量的维度和输入张量一致,默认为False。

import torch
a = torch.randn(4, 5)
print(a)
tensor([[-0.8898,  0.5368,  0.7786,  0.6075,  1.0010],
        [-0.3952, -1.5348, -1.9316, -0.0043,  0.4604],
        [ 1.1868, -0.0518,  1.0266, -1.8229, -0.7412],
        [ 0.4409, -0.1297, -0.1610,  0.2229, -1.0035]])
b=torch.argmax(a, dim=1)
b
tensor([4, 4, 0, 0])
c=torch.argmax(a)
c
tensor(10)
d=torch.argmax(a, dim=0)
d
tensor([2, 0, 2, 0, 0])