Bruno Arine

How to get embeddings from an image classification model in PyTorch

One easy way to get feature vectors is to remove the last layer from a pre-trained image classification neural network. That way, instead of class scores, we get image embeddings that can serve as input to a clustering algorithm.
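The idea does not depend on any particular architecture. The sketch below uses a small, hypothetical classifier built with `nn.Sequential` (its layer sizes are made up for illustration) to show that dropping the final layer turns per-class scores into hidden-layer feature vectors.

```python
import torch
from torch import nn

# Hypothetical toy classifier: 8-dim input -> 16-dim hidden layer -> 3 class scores
classifier = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Linear(16, 3),  # final layer: one score per class
)
classifier.eval()

# Slicing an nn.Sequential returns a new nn.Sequential that shares the same layers;
# dropping the last entry leaves a network that stops at the hidden layer
embedder = classifier[:-1]

x = torch.randn(5, 8)  # a batch of 5 hypothetical inputs
with torch.no_grad():
    scores = classifier(x)
    embeddings = embedder(x)

print(scores.shape)      # torch.Size([5, 3])
print(embeddings.shape)  # torch.Size([5, 16])
```

Because the sliced network reuses the original layers, no weights are copied or retrained; the embedding is simply the activation that used to feed the classification layer.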

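Let’s take the VGG-16 network as an example. It’s a simple and widely used CNN, although with roughly 138 million parameters it is not particularly small or fast. The Python code below loads torchvision’s pre-trained VGG-16 model and discards its last fully connected layer, which maps 4096-dimensional features to the 1,000 ImageNet class scores (torchvision’s VGG-16 has no separate softmax layer).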

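import numpy as np
import torch
from torchvision import models

class Vgg16FeatureExtractor:
    def __init__(self):
        # Load VGG-16 with pre-trained ImageNet weights
        self.model = models.vgg16(weights="DEFAULT")
        # Drop the last Linear layer (4096 -> 1000 class scores),
        # so the network outputs a 4096-dimensional vector instead
        self.model.classifier = self.model.classifier[:-1]
        # Evaluation mode disables dropout
        self.model.eval()

    def compute_features(self, img):
        # img: H x W x 3 array, assumed to be already resized and
        # normalized the same way as the ImageNet training images
        img = np.moveaxis(img, 2, 0)  # HWC -> CHW
        img_tensor = torch.from_numpy(img).unsqueeze(0).float()  # add batch dimension

        with torch.no_grad():  # no gradients are needed for inference
            features_tensor = self.model(img_tensor)

        return features_tensor.numpy()  # shape (1, 4096)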
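`compute_features` passes the array to the network as-is, while torchvision's pre-trained classification models expect RGB images scaled to [0, 1] and normalized with the ImageNet channel mean and standard deviation. A minimal sketch of that normalization step in NumPy follows; the `preprocess` name is hypothetical, and resizing (e.g. to 224 x 224) is assumed to have been done beforehand.

```python
import numpy as np

# ImageNet channel statistics used to normalize inputs for torchvision's
# pre-trained classification models
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(img_uint8):
    """Scale an H x W x 3 uint8 RGB image to [0, 1] and normalize each channel.

    Hypothetical helper: resizing is assumed to have happened already.
    """
    img = img_uint8.astype(np.float32) / 255.0
    return (img - IMAGENET_MEAN) / IMAGENET_STD  # broadcasts over the channel axis


# Hypothetical mid-gray 224 x 224 image
img = np.full((224, 224, 3), 128, dtype=np.uint8)
x = preprocess(img)
print(x.shape, x.dtype)  # (224, 224, 3) float32
```

The result keeps the H x W x 3 layout that `compute_features` expects. Recent torchvision versions also bundle the full resize, crop, and normalization pipeline as `models.VGG16_Weights.DEFAULT.transforms()`, which works on PIL images or tensors rather than NumPy arrays.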