How to get embeddings from an image classification model in PyTorch
One easy way to get feature vectors is to remove the last layer of a pre-trained image classification network. Instead of class predictions, the network then outputs image embeddings that can serve as input to a clustering algorithm.
Let’s take the VGG-16 network as an example. Its architecture is simple and well understood, and feature extraction with it scales to image sets with a large number of samples. The Python code below loads PyTorch’s pre-trained VGG-16 model and discards the final fully connected layer of its classifier, so each forward pass returns the 4096-dimensional activations of the penultimate layer instead of class scores.
import numpy as np
import torch
from torchvision import models


class Vgg16FeatureExtractor:
    def __init__(self):
        # Load the pre-trained VGG-16 and drop the last fully connected layer,
        # so the classifier outputs 4096-dimensional feature vectors.
        self.model = models.vgg16(weights="DEFAULT")
        self.model.classifier = self.model.classifier[:-1]
        self.model.eval()

    def compute_features(self, img):
        # img is an HxWxC array; convert it to the CxHxW layout PyTorch expects
        # and add a batch dimension.
        img = np.moveaxis(img, 2, 0)
        img_tensor = torch.from_numpy(img).unsqueeze(0).float()
        with torch.no_grad():
            features_tensor = self.model(img_tensor)
        return features_tensor.numpy()
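As a quick usage sketch, not part of the original snippet, the extractor can be combined with scikit-learn’s KMeans to cluster a folder of images. The "images" folder, the resize to 224x224, and the choice of 5 clusters are placeholder assumptions:

from pathlib import Path

import numpy as np
from PIL import Image
from sklearn.cluster import KMeans

extractor = Vgg16FeatureExtractor()

# Hypothetical image folder; VGG-16 expects 224x224 RGB inputs.
image_paths = sorted(Path("images").glob("*.jpg"))
features = []
for path in image_paths:
    img = Image.open(path).convert("RGB").resize((224, 224))
    img = np.asarray(img, dtype=np.float32) / 255.0  # scale pixels to [0, 1]
    features.append(extractor.compute_features(img).squeeze(0))
features = np.stack(features)  # shape: (num_images, 4096)

# Cluster the embeddings; 5 clusters is just an example value.
kmeans = KMeans(n_clusters=5, random_state=0)
labels = kmeans.fit_predict(features)

In practice, applying the standard ImageNet mean/std normalization (for example via torchvision’s transforms) before calling compute_features tends to produce embeddings closer to what the network saw during training.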