代码:

import torch

class_num = 10

batch_size = 4

label = torch.LongTensor(batch_size, 1).random_() % class_num

print(label.size())

one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)

print(one_hot)

输出:

torch.Size([4, 1])

tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],

[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],

[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],

[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

注意:

label的形状必须是[n,1]的,也就是必须是二维的,且第二个维度长度为1,如果是一维度的,则需要升维度,代码如下:

import torch

class_num = 10

batch_size = 4

label = torch.LongTensor(batch_size).random_() % class_num

print(label.size())

label = torch.unsqueeze(label,dim=1)

print(label.size())

点赞(0) 打赏

评论列表 共有 0 条评论

暂无评论

微信公众账号

微信扫一扫加关注

发表
评论
返回
顶部