Peter Gao
February 22, 2023 · 11 min

A Guide To Few-Shot Learning With Embeddings

Judging from this robot’s exasperated body language, it must not have a lot of training examples to learn from. Generated from Stable Diffusion using the prompt, “A humanoid robot inspecting a plant among a row of plants.”

Motivation

Machine learning practitioners often perform tasks to improve their training datasets, like:

  • Evaluating the frequency of a certain model failure. For example, understanding how often an aerial imagery model confuses green roofs with forests.
  • Collecting more data of a rare / difficult example. For example, finding more examples of a glass bottle that a model commonly confuses with plastic.
  • Splitting a coarse class into more fine-grained subclasses. For example, creating a “fire truck” class from a broader “truck” class.

Sometimes it’s possible to accomplish this by writing a query or one-off script to filter on metadata — for example, querying for examples above a certain height or width to fill out a subclass that is correlated with size. In other situations, the description of the desired data can’t be captured in code, so the only workable option is to manually look through data with an operations team to find the right datapoints. Neither of these options is great, and there is room to do better.

One way to look at the problem is to realize that metadata scripts or manual inspection are ways of doing human labeling for a category of data. However, to train a good classifier for a modern deep learning model, teams often need thousands to tens of thousands of labeled datapoints, and it’s expensive to label so much data.

However, we can use some prior knowledge to significantly simplify this problem. In the examples above, we observe that it’s often easy to find small quantities of interesting data but hard to find and collect larger sets of data of the same category. As a result, we can apply a technique called few-shot learning, where the goal is to classify a large amount of unlabeled data using a small amount of training examples, perhaps even a single example.

In this post, we are going to walk through an example of doing few-shot learning with neural network embeddings and show how some light processing on embeddings can make common data tasks very easy.

A short disclaimer before we start: I am a cofounder at Aquarium, and I wrote this post because we believe that combining machine learning techniques and interactive user interfaces could make it far more efficient for users to perform tasks on machine learning data. Instead of relying on large operations workforces or skilled ML engineering time, non-technical domain experts with the right tools can efficiently understand and improve machine learning data. We will visualize many of the experiments in this blog post with Aquarium.

The CropAndWeed dataset

Top: an overview of plant species in the dataset. Bottom: a sample frame in the dataset with annotations.

For this post, we will use the CropAndWeed dataset, which consists of images with human-annotated bounding boxes and semantic masks of various crops and weeds across a range of weather conditions, lighting conditions, soil types, etc. There are 24 broad plant classes that are broken into 74 finer-grained classes, of which 16 are crops and 58 are weeds, based on attributes like number of leaves and stage of growth.

Top: a young sugar beet with only two straight leaves. Bottom: an older sugar beet with more leaves that are starting to curl.

We will start with the broad class of sugar beet plants. It turns out that there’s a lot of variation in this class. Some sugar beet labels are blurry or cut off by the side of the image. Younger sugar beets may only have two small straight-looking leaves, while older sugar beets have many more larger “curly” leaves.

The medium-developed sugar beet that we would like to make a subclass around. Notice the 4 semi-curved leaves.

We will start with a single example of a sugar beet plant in an intermediate growth stage, when it has around 3–4 leaves that are starting to curl. We would like to collect more examples of these beets to turn into a finer-grained subclass.

Using embeddings from a pre-trained model

Generating neural network embeddings from a pretrained model

To start, we’re going to use something called a neural network embedding. A neural network embedding is a low-dimensional vector representation of a high-dimensional datapoint — it represents what a model “thought” about a particular example. Distances between embedding vectors capture similarity between datapoints and can capture higher-level concepts in the original input. We can compare and manipulate embeddings to do few-shot learning effectively.

For image classification convnets, embeddings can be extracted from the second to last layer, circled in the red oval. Modified from source.

The easiest way to get started with generating embeddings is to run a pretrained model on a set of data. For computer vision use cases, Pytorch and Keras offer a suite of image classification models pretrained on ImageNet.

It’s fairly simple to run the pretrained model on an image or image crop and then grab embeddings from the activations of the second-to-last layer in that model. For example, to extract embeddings from a Resnet50 image classification model in Pytorch, we can swap out the final classification layer so that the forward pass returns the pooled features:

import torch
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

img = read_image("image.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
# Replace the final classification layer with an identity so the forward
# pass returns the 2048-d activations of the second-to-last layer instead
# of ImageNet class logits.
model.fc = torch.nn.Identity()
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Run the model to extract the embedding vector
with torch.no_grad():
    embeddings = model(batch)

An example of cosine distance used to do similarity comparisons on text embeddings.

Now that we have some embeddings, we can do few-shot learning by treating it as a similarity search problem, where we find examples in our unlabeled data that are similar to our small set of training / query images.

This is implemented by calculating cosine distance between the embeddings of our query images and the embeddings of images in the unlabeled dataset to search within. The lower the cosine distance between the query image and an image in the search set, the more visually / semantically similar they are. This can be easily calculated using the cosine distance function in scikit-learn.

We can then sort the unlabeled dataset by cosine distance to the query examples and then visualize them. In classifier terms, think of these as the images in the unlabeled dataset that are the most likely to belong to the same subclass as the query image.
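
As a minimal sketch, assume we already have a query_embedding vector for the query crop and an unlabeled_embeddings matrix with one row per unlabeled crop (the load_embeddings helper is hypothetical and stands in for however those are stored). The search then boils down to a distance computation and a sort:

import numpy as np
from sklearn.metrics.pairwise import cosine_distances

# query_embedding: shape (1, dim); unlabeled_embeddings: shape (num_examples, dim).
query_embedding, unlabeled_embeddings = load_embeddings()

# Cosine distance between the query and every unlabeled example.
distances = cosine_distances(query_embedding, unlabeled_embeddings)[0]

# Sort the unlabeled dataset from most to least similar to the query.
ranked_indices = np.argsort(distances)
top_matches = ranked_indices[:50]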

Top: the query sugar beet. Bottom: the results from running a similarity search on ImageNet embeddings. The quality of the search results is not great.

In the figure above, we see that the similarity search results are not what we are looking for. The retrieved plants look similar to an untrained eye, but all of the results are actually different species of plants (denoted by the different colored boxes), not sugar beets.

This is because the embedding generation model was trained on ImageNet, which consists of images scraped from the internet, while the agricultural dataset we are using here looks very visually different. As a result, the embeddings we generated highlight surface-level similarity — similar size, number of leaves, etc. However, they do not capture the kinds of similarity that we want our few-shot learning algorithm to capture — similar species, stage of development, number of leaves, and so on.

Using embeddings from a fine-tuned model

Luckily, there is a whole field devoted to training models that generate better quality embeddings. This is known as representation learning or metric learning.

To create embeddings that are more suited for this domain, we can fine-tune a model on our current dataset — that is, start with pretrained model weights and then train the model with a lower learning rate on the dataset. That way, we can leverage features that the pretrained model has learned on a larger dataset like ImageNet that are general to most images (textures, corners, edges, etc.), but also teach the model features that are specific to our data domain (leaf type, shape, etc.).

SimCLR is a popular framework for contrastive learning that can be fine-tuned on datasets from different domains.

We can fine-tune the model using a technique called contrastive learning. This is a methodology that explicitly trains a model to group together similar datapoints and disperse dissimilar datapoints in the embedding space it learns. A popular contrastive learning framework is SimCLR, which applies data augmentation to datapoints (cropping, resizing, and recoloring of a base image) and then optimizes a contrastive loss function that causes the model to learn that these augmented versions of the original image are “similar.”
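
To make this concrete, here is a minimal sketch of the kind of contrastive loss SimCLR optimizes (the NT-Xent loss), written in TensorFlow to match the fine-tuning code below. The function name, temperature, and batch handling are simplifying assumptions for illustration rather than the official implementation:

import tensorflow as tf

def nt_xent_loss(z_a, z_b, temperature=0.5):
  # z_a and z_b are [batch, dim] embeddings of two augmented views
  # (e.g. random crop + color jitter) of the same batch of images.
  batch_size = tf.shape(z_a)[0]
  # L2-normalize so that dot products become cosine similarities.
  z_a = tf.math.l2_normalize(z_a, axis=1)
  z_b = tf.math.l2_normalize(z_b, axis=1)
  z = tf.concat([z_a, z_b], axis=0)                      # [2B, dim]
  sim = tf.matmul(z, z, transpose_b=True) / temperature  # [2B, 2B]
  # Mask out each example's similarity with itself.
  sim -= tf.eye(2 * batch_size) * 1e9
  # The positive "target" for row i is the other augmented view of the same image.
  labels = tf.concat(
      [tf.range(batch_size, 2 * batch_size), tf.range(batch_size)], axis=0)
  loss = tf.keras.losses.sparse_categorical_crossentropy(
      labels, sim, from_logits=True)
  return tf.reduce_mean(loss)

Minimizing this loss pulls the two augmented views of each image together in embedding space while pushing apart views of different images, which is exactly the “group similar, disperse dissimilar” behavior described above.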

Here’s some sample code on how to fine-tune a SimCLR model using the official TF2 implementation:

import tensorflow as tf

num_classes = 74  # e.g. the number of fine-grained CropAndWeed classes

class Model(tf.keras.Model):
  def __init__(self, path):
    super(Model, self).__init__()
    # Load a pretrained SimCLR model.
    self.saved_model = tf.saved_model.load(path)
    # Linear head.
    self.dense_layer = tf.keras.layers.Dense(units=num_classes,
        name="head_supervised_new")
    # Optimizer for training the linear head.
    self.optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)

  def call(self, x):
    with tf.GradientTape() as tape:
      # Use `trainable=False` since we do not wish to update batch norm
      # statistics of the loaded model. If finetuning everything, set this to
      # True.
      outputs = self.saved_model(x['image'], trainable=False)
      logits_t = self.dense_layer(outputs['final_avg_pool'])
      loss_t = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
          labels=tf.one_hot(x['label'], num_classes), logits=logits_t))
      dense_layer_weights = self.dense_layer.trainable_weights
      print('Variables to train:', dense_layer_weights)
      # Note: We only compute gradients wrt the linear head. To finetune all
      # weights use self.trainable_weights instead.
      grads = tape.gradient(loss_t, dense_layer_weights)
      self.optimizer.apply_gradients(zip(grads, dense_layer_weights))
    return loss_t, x["image"], logits_t, x["label"]

model = Model("gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r50_1x_sk0/saved_model/")

# Use tf.function to speed up training. Remove this when debugging intermediate
# model activations.
@tf.function
def train_step(x):
  return model(x)

ds = build_dataset(...)
iterator = iter(ds)
for _ in range(num_steps):
  train_step(next(iterator))

We can fine-tune a contrastive model on our dataset to learn better embeddings, run the trained model on the entire dataset to generate new embeddings, and run the same similarity search query as before.
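
As a rough sketch of that embedding-generation pass, reusing the fine-tuned model object from the snippet above (the iterate_crop_batches helper is hypothetical and stands in for however you batch up the image crops):

import numpy as np

all_embeddings = []
for batch in iterate_crop_batches():  # hypothetical helper yielding batches of crops
  # The SimCLR saved model exposes its pooled features under 'final_avg_pool',
  # the same output that feeds the linear head during fine-tuning.
  outputs = model.saved_model(batch, trainable=False)
  all_embeddings.append(outputs['final_avg_pool'].numpy())
all_embeddings = np.concatenate(all_embeddings, axis=0)

# These embeddings can now be fed into the same cosine-distance similarity
# search as before.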

Top: the query sugar beet. Bottom: the results from running a similarity search on fine-tuned SimCLR embeddings. The quality of the search results is much better.

As you can see, the quality of the search results is markedly improved compared to the ImageNet embeddings. Most of the results are indeed sugar beets (as denoted by the same color boxes) at a similar stage in development, when they are smaller with around four leaves.

Incorporating user feedback with a “linear probe”

Say we want to do even more specific / fine-grained data collection. Instead of just looking for sugar beets at a similar stage of development, say we want to specifically find:

  • Sugar beets
  • With 4–6 leaves
  • Where all leaves are clearly visible (i.e., not cut off by the side of the image)

Examples of undesirable results. Top: plants of a different species. Middle: sugar beets that are cut off by the side of an image. Bottom: sugar beets that have fewer than 4 clearly visible leaves.

However, when we scroll through the search results from the fine-tuned embeddings, not all of the results satisfy these criteria. With this more restrictive definition, 8 of the top 50 most similar search results do not satisfy our criteria; examples of the undesirable results are visualized above.

While the similarity search captures many of the salient features of plants in the dataset, plants can be similar in different ways. Two examples may be similar to each other because they have a similar shape, color, texture, species, number of leaves, etc. A user may be using few-shot learning to find a very specific type of data, like the more restrictive criteria we described above. To filter for a specific type of similarity, we need a framework that provides user feedback to the search and refines the search results over time as the user marks results as “relevant” or “not relevant.”

Again, machine learning comes to the rescue! We can use a linear probe to incorporate positive and negative feedback from the user. A linear probe is a simple linear model that is trained on embeddings (or more generally, any intermediate features in a neural net) to produce classifications.

In machine learning jargon, embeddings are higher-level features generated by neural networks in which classes and important attributes are often linearly separable, and thus easier to distinguish with a simple classifier, even when the lower-level input space is not linearly separable.

Most image classifiers have to be trained on thousands to millions of examples in order to learn the correct features to perform their tasks well. However, in the previous step we already trained our contrastive learning model to learn rich features in the form of embeddings. These embeddings are a high-level representation that capture the most important nuances in the dataset.

Training a model on high level embedding features drastically simplifies the learning task vs training on low level raw inputs like pixels. By training on these embeddings, we can train a simple linear probe on only a few datapoints (tens to hundreds of examples) and still get good performance.

We use the Aquarium UI to mark datapoints as relevant (green thumbs up) or not relevant (red thumbs down). Datapoints are not relevant if they are of different classes, have too few leaves, or have leaves that are not fully visible.

For this post, we use Aquarium’s UI flow to label search results as relevant or not relevant. Using those labels, we can then train a linear probe that takes in an embedding from a datapoint and produces a confidence output of whether that datapoint is relevant or not. We can use this confidence value to filter out irrelevant datapoints and rank higher-confidence results higher in the search results.

This classifier training can be implemented very simply with the logistic regression classifier from scikit-learn:

from sklearn.linear_model import LogisticRegression

# assume labels are 0 for irrelevant and 1 for relevant examples
embeddings, relevance_labels = load_training_data()

# train on a set of embeddings with attached relevance markers
clf = LogisticRegression(random_state=0).fit(embeddings, relevance_labels)

# run a prediction on the first example and get its confidence outputs
# (predict_proba expects a 2D array, hence the [:1, :] slice)
confs = clf.predict_proba(embeddings[:1, :])

# get confidence of this datapoint being relevant (label == 1)
relevance_score = confs[0, 1]
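
To refine the search, we can then re-rank the current search results by the probe’s relevance confidence. A minimal sketch, where search_result_embeddings is a hypothetical array holding one embedding row per search result:

import numpy as np

# Relevance confidence for every current search result.
relevance_scores = clf.predict_proba(search_result_embeddings)[:, 1]

# Rank results from most to least likely to be relevant, filtering out
# anything the probe scores below 0.5.
order = np.argsort(-relevance_scores)
reranked = [i for i in order if relevance_scores[i] > 0.5]
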
Top: the query sugar beet. Middle and bottom: the most relevant search results sorted by relevance from the linear probe. All of them match what we are looking for.

After approximately 10 minutes of work, we’ve marked 108 datapoints as relevant and 32 as irrelevant, which gives us enough data to train a linear probe and then sort search results by relevance score from the probe.

As a result, now all 50 of the top search results are indeed sugar beets with 4–6 leaves that are clearly visible! This is a significant improvement when compared to the search results without the linear probe.

Conclusion

A lot of common tasks on machine learning data are difficult to accomplish with traditional code or require large amounts of human time. In this post, we saw how we can do targeted data collection in a far more efficient way by applying machine learning concepts.

Overall, much of applied machine learning still relies on supervised learning, where a human has to define the desired behavior of a model by labeling the data it trains on. It’s hard to remove this human dependency.

However, as we’ve seen in this post, this human supervision can be more efficient than it is today. Instead of relying on a skilled ML engineer or large groups of operations personnel, a single domain expert user with the right interface can efficiently improve machine learning data.

Get in touch

Schedule time to get started with Aquarium