ResNet50 모델을 이용해 돼지 이미지를 분류
일반적으로 PyTorch에서 image classification을 하기 위해서는 ( 대략적으로 zero-mean, unit variance ) torch.vision.transforms 를 이용해 이미지 변환을 먼저 하는게 대부분이다.
하지만 우리는 원본(unnormalized) 이미지에다가 perturbation을 하는 것이 목적이기 때문에 다른 방식으로 접근해야하고 PyTorch layers 에다가 transformation을 해줘야 한다.
먼저 이미지를 224*224 크기로 resize 해준다. (대부분의 ImageNet 이미지 및 pretrained 된 classifier가 받는 기본 크기)
from PIL import Image
from torchvision import transforms
# read the image, resize to 224 and convert to PyTorch Tensor
pig_img = Image.open("pig.jpg")
preprocess = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
])
pig_tensor = preprocess(pig_img)[None,:,:,:]
# plot image (note that numpy using HWC whereas Pytorch user CHW, so we need to convert)
plt.imshow(pig_tensor[0].numpy().transpose(1,2,0))
필요한 transform을 거친 후 pre-trained 된 ResNet50 모델을 가져와서 이미지에 적용해봅시다.
(한 가지 주의할 점은 PyTorch standards를 준수하기 위해 모듈에 대한 모든 input들은 batch_sizenum_channelsheight*width 형태로 들어가야 한다는 점)
import torch
import torch.nn as nn
from torchvision.models import resnet50
# simple Module to normalize an image
class Normalize(nn.Module):
def __init__(self, mean, std):
super(Normalize, self).__init__()
self.mean = torch.Tensor(mean)
self.std = torch.Tensor(std)
def forward(self, x):
return (x - self.mean.type_as(x)[None,:,None,None]) / self.std.type_as(x)[None,:,None,None]
# values are standard normalization for ImageNet images,
# from <https://github.com/pytorch/examples/blob/master/imagenet/main.py>
norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# load pre-trained ResNet50, and put into evaluation mode (necessary to e.g. turn off batchnorm)
model = resnet50(pretrained=True)
model.eval();
# form predictions
pred = model(norm(pig_tensor))
예측값은 1000개의 imagenet class에 대한 class logit 정보가 담긴 1000 dimension의 vector를 포함하고 있다. (이를 probability vector로 변환하고 싶다면 이 벡터값에다가 softmax를 적용하면 된다.) 가장 높은 likelihood class를 찾기 위해 단순히 최대값의 인덱스를 가져오고 imagenet class에서 해당하는 label을 찾아보기로 했다.
import json
with open("imagenet_class_index.json") as f:
imagenet_classes = {int(i):x[1] for i,x in json.load(f).items()}
print(imagenet_classes[pred.max(dim=1)[1].item()])
hog
ImageNet에서는 pig==hog이다. 결과가 잘 나왔음을 알 수 있다.