torch.argmax#
torch.argmax(input, dim=None, keepdim=False)
参数说明:
input:输入的张量。
dim:指定在哪个维度上寻找最大值,默认为None,表示在整个张量中寻找最大值。
keepdim:是否保持输出张量的维度和输入张量一致,默认为False。
import torch
a = torch.randn(4, 5)
print(a)
tensor([[-0.8898, 0.5368, 0.7786, 0.6075, 1.0010],
[-0.3952, -1.5348, -1.9316, -0.0043, 0.4604],
[ 1.1868, -0.0518, 1.0266, -1.8229, -0.7412],
[ 0.4409, -0.1297, -0.1610, 0.2229, -1.0035]])
b=torch.argmax(a, dim=1)
b
tensor([4, 4, 0, 0])
c=torch.argmax(a)
c
tensor(10)
d=torch.argmax(a, dim=0)
d
tensor([2, 0, 2, 0, 0])