Peter Gao
January 21, 2021

You Should Try Active Learning!

[Figure: The sight of this curve can strike fear into the hearts of machine learning practitioners. (Source)]

Whenever an ML team discusses what they should do to improve their models, there’s inevitably a point at which someone throws up their hands and says, “Well hey, let’s just get more data and retrain the model. Maybe that will help.”

There’s some promise to this idea. Holding the model code fixed and iterating on data can be the most effective way to improve your models. Randomly sampling data to retrain on yields great improvements up to a certain point. But after that point, you start to get less and less improvement to model performance as you add more data.

To improve model performance on a reasonable budget, you can’t just label more data; you also need to label the right data. This concept is commonly known as active learning.

At Aquarium, we build tooling for ML teams to help them improve their model performance more efficiently. In this post, we go into why diminishing returns happen, how to counter this with active learning techniques, and how Aquarium helps implement simple but effective active learning workflows.

Diminishing Returns On Data

In the early days of a project, adding more data can be pretty effective. When your dataset is very small, it doesn’t represent the production domain well enough for a model to learn from, so the model quickly overfits to the training set and does poorly on a test set or in production. Naively sampling more data and retraining your model can therefore lead to significant performance improvements.

However, as dataset size increases, the model tends to get very good at handling easy or common cases, but still struggles on difficult or rare cases. At this point, blindly adding more data yields diminishing improvements to model performance.

To illustrate this phenomenon, Allegro AI trained convnet object detectors on two popular datasets, COCO and BDD, and was able to quantify the diminishing returns to model performance from adding more data.

This pattern happens across almost every learning task, but the exact shape of the curve changes depending on the dimensionality of the data, the size of the model, and the difficulty of the task.

As an ML practitioner, it’s extremely useful to run an ablation study to check whether you’re hitting diminishing returns on your task. First, carve out a fixed test set for evaluation. Then train the same model code on training subsets of different sizes, evaluate each resulting model on the test set, and plot model performance against dataset size to see whether the gains from adding more data are shrinking.
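
A minimal sketch of that ablation, assuming you supply two hypothetical callables: `train_fn`, which trains your model code on a list of examples, and `eval_fn`, which scores a trained model on the fixed test set. Shuffling once and taking prefixes keeps the subsets nested, so each larger run only adds data.

```python
import random
from typing import Callable, Dict, List, Sequence, TypeVar

T = TypeVar("T")

def learning_curve(
    train_set: Sequence[T],
    test_set: Sequence[T],
    train_fn: Callable[[Sequence[T]], object],        # hypothetical: trains a model on a subset
    eval_fn: Callable[[object, Sequence[T]], float],  # hypothetical: scores a model on the test set
    fractions: Sequence[float] = (0.125, 0.25, 0.5, 1.0),
    seed: int = 0,
) -> Dict[int, float]:
    """Train the same model code on nested random subsets; score each on one fixed test set."""
    rng = random.Random(seed)
    shuffled = list(train_set)
    rng.shuffle(shuffled)  # shuffle once so smaller subsets are prefixes of larger ones
    curve = {}
    for frac in sorted(fractions):
        n = max(1, int(len(shuffled) * frac))
        model = train_fn(shuffled[:n])
        curve[n] = eval_fn(model, test_set)
    return curve

def marginal_gains(curve: Dict[int, float]) -> List[float]:
    """Score change between consecutive subset sizes; shrinking values suggest diminishing returns."""
    sizes = sorted(curve)
    return [curve[b] - curve[a] for a, b in zip(sizes, sizes[1:])]
```

With the default fractions each step doubles the data, so a shrinking sequence from `marginal_gains` is exactly the flattening curve described above. Plotting `curve` with dataset size on a log axis makes the trend easier to read.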

Now that we’ve identified the problem, let’s talk about ways to fix it.

Brute Forcing The Curve

One solution is to scale your data acquisition faster than your returns diminish.

In 2017, Google published a paper describing their solution to the problem of diminishing returns: label and train on way more data. While the largest public imagery dataset available in 2017 was ImageNet, with 1 million examples, Google built their own dataset with 300 million examples.

To do so, they leveraged an automated labeling algorithm based on a “complex mixture of raw web signals, connections between web-pages and user feedback,” though without much quality checking by human annotators.

They observed that the model performance improves logarithmically with the size of the training dataset, which matches our expectations on diminishing returns. However, they were still able to get ~5% improvement in their model performance by adding 290 million examples.
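
A logarithmic relationship means each multiplicative increase in data buys the same fixed gain. As an illustration (P, N, a, and b are symbols introduced here, not values from the paper), let P(N) be model performance after training on N examples, with a and b fitted constants:

```latex
P(N) \approx a + b \log_{10} N
\qquad\Longrightarrow\qquad
P(kN) - P(N) \approx b \log_{10} k
```

So every 10x increase in data adds roughly the same amount b. Going from 10 million to 300 million examples (the 290-million-example increase above) is a 30x jump, worth only about log10(30) ≈ 1.5 times the gain of a single 10x step.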

In 2018, Facebook published a paper detailing how they were able to achieve state-of-the-art performance on ImageNet by training on billions of examples.

They built this dataset by using Instagram hashtags to label it automatically. They pretrained a model on the hashtags aligned to ImageNet classes and then fine-tuned the model on ImageNet, achieving a significant performance improvement over the previous state of the art.

Again, there was a logarithmic relationship between model performance and training dataset size. At the end of the day, the Facebook team achieved a ~20% improvement in model performance by pretraining on billions of examples.

The moral of this story is: if you have petabytes of data to draw weak labels from, hundreds of GPUs to scale model training with, and millions of dollars to spend on research, you can push through diminishing returns with brute force by throwing more data and compute at the problem.

Tackling The Long Tail

Of course, most ML teams don’t have the mind-bending resources that FAANG teams have, but they still have to contend with the problem of diminishing returns.

Which raises the question: why do model improvements diminish as dataset size increases? If our model errors were uniformly distributed, we’d expect that sampling data randomly would continually improve model performance at a more linear rate.

It turns out the vast majority of ML use cases deal with long-tailed distributions of data. In these scenarios, randomly sampled datasets tend to draw from frequent or easy scenarios, so the model does well on common scenarios and fails on the various edge cases that make up the long tail of the distribution.
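
A small simulation makes this concrete. It assumes a power-law class distribution chosen purely for illustration (class i has weight proportional to 1 / (i + 1)^1.5, over 1,000 classes) and counts how many classes still have fewer than 10 examples as the random sample grows. The helper names are hypothetical.

```python
import random

def zipf_weights(num_classes: int, exponent: float = 1.5) -> list[float]:
    """Illustrative long-tailed distribution: weight of class i is proportional to 1 / (i + 1)**exponent."""
    raw = [1.0 / (i + 1) ** exponent for i in range(num_classes)]
    total = sum(raw)
    return [w / total for w in raw]

def rare_class_count(sample_size: int, weights: list[float], min_examples: int = 10, seed: int = 0) -> int:
    """Draw a random sample of labels and count classes with fewer than `min_examples` examples."""
    rng = random.Random(seed)
    labels = rng.choices(range(len(weights)), weights=weights, k=sample_size)
    counts = [0] * len(weights)
    for label in labels:
        counts[label] += 1
    return sum(1 for c in counts if c < min_examples)

weights = zipf_weights(1000)
for n in (1_000, 10_000, 100_000):
    # Under this assumed distribution, 100x more random data still leaves most classes underrepresented.
    print(n, rare_class_count(n, weights))
```

Under this distribution, most of the extra samples land on classes the model already sees plenty of, which is the mechanism behind the diminishing-returns curve.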

Martin Casado and Matt Bornstein at A16Z analyzed the economic implications of the long tail for ML applications. They pointed out that long-tail distributions are very common in machine learning datasets, and that there are a variety of things you can do across product, engineering, and operations to help handle the long tail.

But when you can’t change your expectations for the product and you don’t want to invest inordinate amounts of engineering time, you need to be more sophisticated with your data collection strategy.

At one of my previous jobs, we trained an object detector to find construction cones on city streets. We randomly sampled images of cones to train on and discovered that the model did very well on orange cones but failed to detect green cones, because green cones were much rarer. If we had continued to randomly sample data to label, we’d have collected a lot of orange cones and very few green cones, so our model wouldn’t have gotten much better. Instead, we targeted our data collection and labeling on green cones, retrained our model on the new dataset, and the new model did much better on green cones. After that, we realized we weren’t doing well on the rare rainy days in San Francisco, so we collected more images on rainy days, retrained, and the new model performed better. The same went for scooters, school buses, emergency vehicles, and so on. Most importantly, we could repeat this procedure to counter the effects of diminishing returns and meaningfully improve our model performance every time we retrained.

The takeaway is: when you start hitting diminishing returns, don’t randomly sample data to label. Instead, collect and label the data that would most improve your model performance. This concept is known as active learning.

By implementing active learning techniques in your pipeline, you make your training process significantly more data-efficient. You can either get the same model improvement at lower labeling cost or get more model improvement for the same labeling cost!

Active Learning For Fun And Profit

The biggest players in the self-driving space have all discovered the value of active learning feedback loops. Why? Because the long tail of driving is brutally long. (Figure, from left to right: Waymo, Cruise, and Tesla.)

Active learning is a fancy name for a fairly simple concept, but it turns out there are a lot of ways to implement it. Here are some things to try, in order of increasing difficulty:

Balance your data

Even before you’ve trained a model, you can identify rare areas in your dataset that you should collect more examples of.

The most obvious example of this is class imbalance. If your dataset consists of 90% class A and 10% class B, it’s no surprise that your model might do worse on class B, since it has significantly fewer examples of it to train on. There are many ways to address class imbalance, but often the simplest solution is to collect more of class B!

This principle can also apply to other metadata that’s present in the dataset. For example, if most of your dataset consists of images collected during the day and very few examples from nighttime, then your model will probably not do so well at night. So collect more night images!
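
A minimal sketch of this check, assuming each example is a dict carrying a class label and metadata fields (the field names `label` and `time` below are hypothetical). For any field, it reports how many more examples each value needs to reach a chosen fraction of the most common value:

```python
import math
from collections import Counter
from typing import Dict, Iterable, Mapping

def collection_targets(examples: Iterable[Mapping[str, object]], key: str, target_ratio: float = 1.0) -> Dict[object, int]:
    """For each value of `key` (a class label or a metadata field such as time of day),
    return how many more examples to collect so it reaches `target_ratio` times the size
    of the most common value. target_ratio=1.0 means fully balanced.
    Values that never appear in the data cannot show up here; check for those separately."""
    counts = Counter(ex[key] for ex in examples)
    if not counts:
        return {}
    goal = math.ceil(target_ratio * max(counts.values()))
    return {value: max(0, goal - n) for value, n in counts.items()}

# The 90% / 10% example from above, with an illustrative day/night metadata field.
data = [{"label": "A", "time": "day"}] * 90 + [{"label": "B", "time": "night"}] * 10
print(collection_targets(data, "label", target_ratio=0.5))  # B needs 35 more to reach 45
```

A ratio below 1.0 is often a practical compromise when rare cases are expensive to collect.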

Set up a feedback loop for error mining

Once you have a model deployed, you can now get a very valuable signal — where is the model doing badly in production? At this point, it’s extremely helpful to have a feedback system where you can identify places where the model isn’t doing well and label those datapoints.

Look at the places where your model is making mistakes on your training and test sets. Simply looking over these examples can reveal labeling errors, or rare edge cases where you need to collect more data for your model to improve. Better yet, you can catch these issues before your model gets into production.
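
One way to make this review more systematic is to rank metadata slices by error rate so the worst ones get inspected first. This sketch assumes each evaluation result is a dict with a boolean `correct` field plus metadata fields such as a cone `color` (a hypothetical schema):

```python
from collections import defaultdict
from typing import Iterable, List, Mapping, Tuple

def error_rate_by_slice(results: Iterable[Mapping[str, object]], key: str, min_count: int = 1) -> List[Tuple[object, float, int]]:
    """Group evaluation results by a metadata field and sort slices from worst to best.

    Returns (slice_value, error_rate, count) tuples. Slices with fewer than
    `min_count` examples are skipped because their error rates are too noisy."""
    totals: dict = defaultdict(int)
    errors: dict = defaultdict(int)
    for r in results:
        value = r[key]
        totals[value] += 1
        if not r["correct"]:
            errors[value] += 1
    rows = [(v, errors[v] / totals[v], totals[v]) for v in totals if totals[v] >= min_count]
    return sorted(rows, key=lambda row: row[1], reverse=True)
```

A rare slice with a high error rate, like the green cones above, is a strong candidate for targeted data collection.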

Leverage human inspection of production model inferences to find mistakes. Your end users may complain about errors they notice, which provides a coarse but important clue about where to focus labeling. You can also sample subsets of your production data and have human labelers mark mistakes. In robotics contexts, this can even take the form of a human watching a robot operate and pressing a button whenever the robot makes a mistake, providing a stream of failure cases that you can sample, label, and retrain your model on.
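
For the button-press setup, a simple way to turn presses into data to label is to grab the frames recorded shortly before each press, since the operator reacts after the mistake happens. This sketch assumes a hypothetical log format: a time-sorted list of `(timestamp_seconds, frame_id)` pairs and a list of press timestamps; the 2-second window is an arbitrary illustrative default.

```python
import bisect
from typing import List, Sequence, Tuple

def frames_near_flags(frames: Sequence[Tuple[float, str]], presses: Sequence[float], window_s: float = 2.0) -> List[str]:
    """Return IDs of frames recorded within `window_s` seconds before (or at) any button press.

    `frames` must be sorted by timestamp. Overlapping windows are de-duplicated."""
    times = [t for t, _ in frames]
    selected = set()
    for press in presses:
        lo = bisect.bisect_left(times, press - window_s)  # first frame inside the window
        hi = bisect.bisect_right(times, press)            # one past the last frame at or before the press
        selected.update(range(lo, hi))
    return [frames[i][1] for i in sorted(selected)]
```

Binary search keeps each lookup logarithmic, which matters when a robot logs many frames per second over long runs.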

Take advantage of domain knowledge to automatically label data and find mistakes. For example, recommendation models (like those on e-commerce or social media websites) can measure their own error by tracking the difference between what they recommend and what users end up clicking or not clicking. Prediction or forecasting models can track the difference between what they predicted would happen and what actually happened as time went on. Hedge funds use this for stock price forecasting, and self-driving companies heavily leverage it for object prediction. Other ML products may already have built-in mechanisms to deal with failures (human review, cross-referencing, etc.) that can give you very high quality error mining for “free”. Either way, you can then pick the n highest-error examples and add them back to your dataset for retraining.
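
A minimal sketch of that last step, assuming predictions and later-observed outcomes are stored as dicts keyed by a shared example ID (a hypothetical layout). Absolute error works for forecasts, and also for click prediction if the outcome is recorded as 1 for a click and 0 otherwise.

```python
import heapq
from typing import Hashable, List, Mapping, Tuple

def highest_error_examples(predictions: Mapping[Hashable, float], outcomes: Mapping[Hashable, float], n: int) -> List[Tuple[Hashable, float]]:
    """Compare each prediction with the outcome observed later and return the n examples
    with the largest absolute error, largest first. Examples whose outcome hasn't
    arrived yet are skipped."""
    errors = ((key, abs(pred - outcomes[key])) for key, pred in predictions.items() if key in outcomes)
    return heapq.nlargest(n, errors, key=lambda item: item[1])
```

`heapq.nlargest` avoids sorting the whole stream when you only need the top n.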

Listen to the model

You can also make models that are able to understand what data they need to improve themselves.

Start by taking advantage of the model’s confidence outputs: sample the examples with the lowest confidence, or use margin sampling or entropy sampling. This can be simple to implement, but in my experience, it doesn’t work super well. The model’s confidences can simply be miscalibrated, so it may be highly confident and wrong, or systematically overconfident or underconfident.
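
The three scores named above can be computed from a single softmax output per example. This is a minimal sketch with hypothetical function names; it inherits the calibration caveat, since it trusts the probabilities as given.

```python
import math
from typing import List, Sequence

def least_confidence(probs: Sequence[float]) -> float:
    """1 minus the top class probability: higher means less confident."""
    return 1.0 - max(probs)

def margin(probs: Sequence[float]) -> float:
    """Gap between the two most likely classes: SMALLER means more uncertain. Needs 2+ classes."""
    top1, top2 = sorted(probs, reverse=True)[:2]
    return top1 - top2

def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy of the predicted distribution, in nats: higher means more uncertain."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def select_for_labeling(batch: Sequence[Sequence[float]], k: int) -> List[int]:
    """Indices of the k examples with the highest predictive entropy."""
    return sorted(range(len(batch)), key=lambda i: entropy(batch[i]), reverse=True)[:k]

batch = [[0.98, 0.01, 0.01], [0.40, 0.35, 0.25], [0.70, 0.20, 0.10]]
print(select_for_labeling(batch, 2))  # the two least certain predictions
```

Swapping `entropy` for `least_confidence`, or for a negated `margin`, in `select_for_labeling` gives the other two strategies.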

Next, make models that explicitly understand when they are uncertain. Models can be uncertain because of missing information or noisy inputs (aleatoric uncertainty), or because they have seen little or no training data resembling an edge case (epistemic uncertainty). Bayesian neural networks let users quantify aleatoric and epistemic uncertainty for each datapoint, so they can find and label the examples with high epistemic uncertainty. Here is a guide on how to implement a Bayesian neural network using dropout layers, and a number of papers show that active learning techniques built on these uncertainties can significantly improve data efficiency.
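
The dropout approach is often called Monte Carlo dropout: keep dropout switched on at inference, run several forward passes, and measure how much the passes disagree. The sketch below uses a toy linear softmax classifier with hand-picked weights as a stand-in for a trained network (all names are hypothetical). It scores disagreement with mutual information, also known as the BALD score, which approximates epistemic uncertainty.

```python
import math
import random
from typing import List, Sequence, Tuple

def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def entropy(probs: Sequence[float]) -> float:
    return -sum(p * math.log(p) for p in probs if p > 0)

def mc_dropout_scores(x: Sequence[float], weights: Sequence[Sequence[float]],
                      p_drop: float = 0.5, passes: int = 50, seed: int = 0) -> Tuple[float, float]:
    """Run `passes` stochastic forward passes of a toy linear classifier (one weight row
    per class) with input dropout left ON. Requires 0 <= p_drop < 1.

    Returns (predictive_entropy, mutual_information). Mutual information is high when
    individual passes disagree with each other."""
    rng = random.Random(seed)
    samples = []
    for _ in range(passes):
        # Inverted dropout: zero each input with probability p_drop, rescale the survivors.
        kept = [xi / (1 - p_drop) if rng.random() >= p_drop else 0.0 for xi in x]
        logits = [sum(w * xi for w, xi in zip(row, kept)) for row in weights]
        samples.append(softmax(logits))
    mean = [sum(s[c] for s in samples) / passes for c in range(len(weights))]
    predictive = entropy(mean)                                # total uncertainty
    expected = sum(entropy(s) for s in samples) / passes      # average per-pass uncertainty
    return predictive, predictive - expected
```

In a real network you would enable dropout at inference time in your framework and average over the same kind of repeated passes; the scoring step stays the same.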

[Figure: An active learning policy on a Bayesian CNN can produce more model performance gains with less data. (Source)]

Listening to your model is a very powerful strategy, but it is not commonly used in industry. These techniques require deep changes to the model, which take strong theoretical knowledge and practical ML engineering skills, whereas most companies start with off-the-shelf models that are easier to use but lack these capabilities.

Active Learning With Aquarium

We build Aquarium, an ML data management system that makes it easier to do active learning.

We take a very data-focused approach to active learning that is quite similar to Tesla’s edge case engine. Our tooling focuses on understanding the performance of the model in relation to the dataset, and then gives users ways to edit or collect more data to address model failures.

Aquarium provides a variety of basic health checks on your dataset, like plotting distributions of metadata to find areas with class or metadata imbalance. You can then use this information to collect more examples of underrepresented classes / metadata for your model to train on.

Aquarium also lets you implement an active learning feedback loop: examine which difficult edge cases your model is failing on, then collect more of that data. Aquarium’s active learning workflow looks roughly like this:

  • Upload your latest labeled dataset and corresponding model inferences to Aquarium.
  • Find patterns of edge cases in your data where your model performs poorly.
  • Find more examples of these edge cases in your production (unlabeled) data stream, label them, and then add them to your training dataset.
  • Train your model on the updated training dataset and measure the change in performance of your new model vs your old model.
  • Repeat!
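
The same loop can be written down generically. In this sketch every callable (`train`, `evaluate`, `find_hard_cases`, `label`) is a hypothetical stand-in for your own tooling, not Aquarium’s API, and `find_hard_cases` is assumed to return indices into the unlabeled pool.

```python
from typing import Callable, List, Sequence

def active_learning_loop(
    labeled: List[object],
    unlabeled: List[object],
    train: Callable[[Sequence[object]], object],
    evaluate: Callable[[object], float],
    find_hard_cases: Callable[[object, Sequence[object]], List[int]],
    label: Callable[[Sequence[object]], List[object]],
    rounds: int = 5,
) -> List[float]:
    """Train, find failures, label more of them, retrain. Returns the score after each round."""
    model = train(labeled)
    history = [evaluate(model)]
    for _ in range(rounds):
        # Find edge cases the current model struggles with, as indices into the pool.
        picked = find_hard_cases(model, unlabeled)
        if not picked:
            break  # nothing left that looks worth labeling
        chosen = set(picked)
        # Label the chosen examples and move them from the pool into the training set.
        labeled = labeled + label([unlabeled[i] for i in picked])
        unlabeled = [x for i, x in enumerate(unlabeled) if i not in chosen]
        # Retrain and record the score so the new model can be compared with the old one.
        model = train(labeled)
        history.append(evaluate(model))
    return history
```

Keeping the per-round scores in `history` is what lets you see whether each round still pays for its labeling cost.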

Aquarium speeds up the active learning workflow by analyzing neural network embeddings. An embedding represents what a network “thought” about a datapoint: it encodes the input data into a relatively short vector of floats, generated either by our users’ own neural networks or by our stable of pretrained nets. By analyzing the distances between these embeddings, one can identify outliers and group similar examples together.
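
As one illustration of distance-based outlier detection (not necessarily what Aquarium does internally), this sketch scores each embedding by its mean distance to its k nearest neighbors; the function names are hypothetical.

```python
import math
from typing import List, Sequence

def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def knn_outlier_scores(embeddings: Sequence[Sequence[float]], k: int = 3) -> List[float]:
    """Score each embedding by its mean distance to its k nearest neighbors.
    High scores mark points far from everything else, i.e. candidate outliers.
    Brute force O(n^2): fine for a sketch, too slow for large datasets."""
    if len(embeddings) < 2:
        raise ValueError("need at least two embeddings")
    scores = []
    for i, a in enumerate(embeddings):
        dists = sorted(euclidean(a, b) for j, b in enumerate(embeddings) if j != i)
        nearest = dists[:k]
        scores.append(sum(nearest) / len(nearest))
    return scores
```

At scale, an approximate nearest-neighbor index would replace the brute-force inner loop.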

By doing similarity search on these embeddings, companies can implement content search on complex data like imagery or audio. In the self driving field, many ML teams have used embedding analysis to speed up active learning workflows.

For example, many practitioners find patterns in model errors by manually scanning through thousands of datapoints. By reducing the high-dimensional embedding space to a two-dimensional visualization, Aquarium can cluster similar datapoints together and color them by whether the model performed correctly or incorrectly on them. This helps users quickly find patterns of model failures without having to manually scan through the entire dataset.
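
A non-visual variant of the same idea can be sketched in code. It swaps in plain k-means clustering on the full embeddings (not the 2D projection described above) and ranks clusters by the model’s error rate. The function names, the deterministic farthest-point initialization, and the boolean `correct` flags are illustrative assumptions.

```python
import math
from typing import List, Sequence, Tuple

def _dist(a: Sequence[float], b: Sequence[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def kmeans(points: Sequence[Sequence[float]], k: int, iters: int = 20) -> List[int]:
    """Plain k-means with deterministic farthest-point initialization; returns a cluster id per point."""
    centers = [list(points[0])]
    while len(centers) < k:
        # Next center: the point farthest from all current centers.
        far = max(points, key=lambda p: min(_dist(p, c) for c in centers))
        centers.append(list(far))
    assign = [0] * len(points)
    for _ in range(iters):
        assign = [min(range(k), key=lambda c: _dist(p, centers[c])) for p in points]
        for c in range(k):
            members = [p for p, a in zip(points, assign) if a == c]
            if members:
                centers[c] = [sum(dim) / len(members) for dim in zip(*members)]
    return assign

def cluster_error_rates(points: Sequence[Sequence[float]], correct: Sequence[bool], k: int) -> List[Tuple[int, float, int]]:
    """Cluster embeddings, then return (cluster_id, error_rate, size), worst cluster first."""
    assign = kmeans(points, k)
    rows = []
    for c in range(k):
        flags = [ok for ok, a in zip(correct, assign) if a == c]
        if flags:
            rows.append((c, flags.count(False) / len(flags), len(flags)))
    return sorted(rows, key=lambda r: r[1], reverse=True)
```

A cluster with a high error rate is a candidate failure pattern worth inspecting by eye.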

It can also be challenging to collect more data on a rare edge case, even once it is identified. The default option for many teams is to manually trawl through massive amounts of data to find more examples of a rare scenario to label (e.g., green cones). However, using embedding search, it’s possible to find more examples of rare edge cases automatically, essentially treating active learning as a content retrieval problem.

Aquarium offers similar functionality for finding more examples of rare edge cases. Once a user has found examples of an “issue” they’d like to collect more of, Aquarium can trawl through vast amounts of unlabeled data and find more datapoints that are similar, or “close together” in the embedding space, to the issue. A user can then send these datapoints to a labeling provider, add the labeled data to their dataset, and retrain a model that performs better.
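
Treating this as retrieval can be sketched as brute-force cosine similarity between the issue’s example embeddings and an unlabeled pool. This is an illustration, not Aquarium’s implementation. The pool is assumed to be a dict from item ID to embedding, and each candidate is scored by its best match to any issue example.

```python
import heapq
import math
from typing import Hashable, List, Mapping, Sequence, Tuple

def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0  # treat zero vectors as dissimilar to everything

def find_similar(issue: Sequence[Sequence[float]], pool: Mapping[Hashable, Sequence[float]], top_k: int) -> List[Tuple[Hashable, float]]:
    """Rank unlabeled items by their highest cosine similarity to any example of the issue."""
    scored = ((item_id, max(cosine(q, emb) for q in issue)) for item_id, emb in pool.items())
    return heapq.nlargest(top_k, scored, key=lambda pair: pair[1])
```

The top results are the candidates to send for labeling. For very large pools, an approximate nearest-neighbor index would replace the linear scan.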

Our platform is not the most algorithmically complex way of doing active learning, but it implements a fairly general technique that has been proven to work in a number of practical applications. By wrapping this workflow in a nice web interface, Aquarium aims to let nontechnical users continually improve the dataset and model performance without needing an ML engineer to be actively involved.

Check out our site if you want to learn more. If you’d like to try Aquarium out for your own dataset, let us know!
