A Cautionary Tale
A machine learning team has just finished deploying the first iteration of a deep learning model to production. The model is not perfect, but it’s good enough to work as a V1. Once the euphoria of launch fades, the hard work begins: turning an MVP into a genuinely valuable product, which means improving the model’s accuracy until it is actually useful, from, say, 80% to 99%.
So the team holds a brainstorm to figure out what to do next. At this point, something very strange happens. One ML engineer perks up and says, “We should try stochastic AdamGradProp with momentum! I just read a paper on arXiv about it.” Another says, “Let’s try collecting more data! If we just randomly sample data and double the dataset size, it’ll definitely get better.” Yet another suggests, “We should implement the model that’s state-of-the-art on the COCO benchmark.”
After spending anywhere from a week to three months on these experiments (and burning through hundreds of hours of GPU compute time), the team comes back and finds that performance is only incrementally better (say, a 0.05% improvement). Perhaps the model even performs worse than before! What happened?
Problem Solving In Machine Learning Development
For some reason, many ML teams adopt an attitude towards development that is best described as “throwing things at the wall and seeing what sticks.” That doesn’t work well, because each solution is meant to solve a specific problem. Blindly trying solutions doesn’t address the big problems that are actually holding the model’s performance back, and can even make them worse. Ideally, an ML team does a guided search instead: identify the most impactful work to invest in, then work through that list in order.
How do we brainstorm and prioritize what to try? As with so many other topics, we can turn to the teachings of Richard Feynman. The so-called Feynman problem-solving algorithm is:
- Write down the problem.
- Think real hard.
- Write down the solution.
We can apply this problem-solving algorithm to machine learning development:
- Understand the biggest problems in the ML system
- Think about a solution that would solve those biggest problems with the least effort
- Try a solution
- Measure the before-and-after difference in performance to see if the solution worked
ML teams tend to jump straight to brainstorming and implementing solutions without a firm understanding of why the system is failing. This waterfall-esque development paradigm works well when a software system’s requirements are extremely clear and easy to plan in advance. Machine learning system requirements, however, are complex and hard to know up front, so it’s better to develop with an agile workflow of quick iterations. With a tight loop of debugging, experimenting with a solution, and measuring the results, it’s much faster to test individual hypotheses and to attribute each improvement in model performance to the solution that produced it.
In addition, much of the existing advice on training models has focused on improvements to model code or training parameters, known as model-centric machine learning. In applied settings, however, most model improvement comes from improvements to the data, known as data-centric machine learning. In recent years, progress in foundation models has also reduced the need for applied practitioners to tweak model code for optimal performance, making improvements to the data comparatively more impactful.
As a disclaimer, I’m a founder at Aquarium, and we see our customers run into these same problems again and again. This post lays out a modern, efficient approach to improving model performance, with a focus on data-centric techniques, compiled from our experience working with many deep learning teams.
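A minimal sketch of that loop's bookkeeping in Python, assuming a single scalar metric where higher is better and a fixed evaluation set shared by every experiment. The `Experiment` record and `run_experiment` helper are hypothetical names, and the lambdas are placeholders for real training runs; the point is that each experiment changes one thing and logs its before-and-after delta, so every improvement stays attributable.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Experiment:
    """One hypothesis tested in isolation: a single change versus the baseline."""
    name: str
    baseline: float   # metric before the change, on the fixed eval set
    candidate: float  # metric after the change, on the same eval set

    @property
    def delta(self) -> float:
        # Positive means the change helped (assumes higher is better).
        return self.candidate - self.baseline


def run_experiment(name: str, baseline: float,
                   train_and_evaluate: Callable[[], float],
                   log: List[Experiment]) -> Experiment:
    """Hypothetical helper: apply one change, evaluate it, and record the result."""
    exp = Experiment(name, baseline, train_and_evaluate())
    log.append(exp)
    return exp


log: List[Experiment] = []
baseline = 0.80
# Placeholder lambdas standing in for real training + evaluation runs.
run_experiment("relabel high-loss examples", baseline, lambda: 0.84, log)
run_experiment("swap optimizer", baseline, lambda: 0.7995, log)

for exp in sorted(log, key=lambda e: e.delta, reverse=True):
    print(f"{exp.name}: {exp.delta:+.4f}")
```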
Improving Machine Learning Systems
There are many ways to debug machine learning systems, but there are some basic techniques that everyone should try.
Look at the literature
It always makes sense to sanity check your overall approach before optimizing it. For example, it doesn’t make much sense to optimize a logistic regression model trained on raw pixels for an image classification task when almost all modern techniques utilize transformers or convnets.
Most ML applications are not solving fundamentally new problems, so it’s better to start with techniques that have demonstrated good performance on similar problems and then iterate on them over time. Without understanding prior art, you may end up in the unenviable position of doing research work under the time and resource constraints of an industry environment. Worse, developing machine learning models from scratch invites subtle, hard-to-catch bugs (improper preprocessing, poorly chosen loss functions, mistakes introduced while modifying the model architecture, etc.) that can waste a lot of time.
Evaluation metrics are another area where prior art helps. For example, it’s tempting to simply look at overall accuracy when evaluating a classification model. But if a dataset is heavily imbalanced, say 90% class A and 10% class B, then a model can reach 90% accuracy by always predicting class A! It’s therefore better to measure precision, recall, and F1 score for each class and then average these metrics across classes (a macro average) to get an overall performance metric. This problem has already been solved by standard metrics tools like scikit-learn’s classification report.
However, there are also diminishing returns to implementing experimental state-of-the-art algorithms versus using more standard techniques with well-tested implementations in the PyTorch or Keras model zoos. As one moves up the state-of-the-art benchmarks, accuracy improvements become incremental, higher accuracy often comes at the cost of runtime speed, and the implementation is frequently either not open source or of low quality (also known as “research code”). In some cases, one can even get better performance by fine-tuning a standard model with pretrained weights than by training a state-of-the-art model from scratch. I’ve found it useful to look at Papers With Code for a given machine learning task or a similar dataset and then implement a technique from the “most implemented papers” section.
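To make the imbalanced-class example concrete, here is a small dependency-free sketch that computes per-class precision, recall, and F1 plus their macro average for a degenerate model that always predicts class A. It is written in plain Python for illustration only; scikit-learn’s `classification_report(y_true, y_pred, zero_division=0)` should report the same per-class numbers and macro average.

```python
from typing import Dict, Sequence, Tuple


def per_class_report(y_true: Sequence[str], y_pred: Sequence[str]):
    """Per-class (precision, recall, F1) and the macro-averaged F1 across classes."""
    classes = sorted(set(y_true) | set(y_pred))
    report: Dict[str, Tuple[float, float, float]] = {}
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        # Define undefined ratios (0 / 0) as 0.0 so a never-predicted class scores 0.
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[c] = (precision, recall, f1)
    # Macro average: every class counts equally, regardless of how common it is.
    macro_f1 = sum(f1 for _, _, f1 in report.values()) / len(report)
    return report, macro_f1


y_true = ["A"] * 90 + ["B"] * 10
y_pred = ["A"] * 100  # degenerate model that always predicts the majority class

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
report, macro_f1 = per_class_report(y_true, y_pred)
print(f"accuracy = {accuracy:.2f}")  # accuracy = 0.90
for c, (p, r, f1) in report.items():
    print(f"class {c}: precision={p:.2f} recall={r:.2f} f1={f1:.2f}")
print(f"macro F1 = {macro_f1:.2f}")  # macro F1 = 0.47
```

The 90% accuracy hides the fact that class B has zero recall; the macro-averaged F1 of roughly 0.47 does not.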
Check if your model is overfitting
A basic concept taught in every introductory machine learning class is the bias-variance tradeoff. In practical terms, an ML practitioner needs to compare the error their model has on the training set with its error on a held-out validation set.
- The model is overfitting (high variance) when it has low error on the training set but high error on the validation set.
- The model is underfitting (high bias) when it has high error on both the training set and the validation set.
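The two rules above can be written as a rough triage function. This is a sketch: `target_error` (the error you need to reach for the product to be useful) and `gap_tolerance` are hypothetical thresholds you would choose per task, and the function allows for both problems occurring at once.

```python
from typing import List


def diagnose(train_error: float, val_error: float,
             target_error: float, gap_tolerance: float = 0.02) -> List[str]:
    """Rough bias/variance triage from aggregate error rates (thresholds are hypothetical)."""
    issues = []
    if train_error > target_error:
        # Can't even fit the training data well enough.
        issues.append("underfitting (high bias)")
    if val_error - train_error > gap_tolerance:
        # Fits the training data but fails to generalize to held-out data.
        issues.append("overfitting (high variance)")
    return issues or ["neither: move on to error analysis"]


print(diagnose(train_error=0.01, val_error=0.15, target_error=0.05))
# ['overfitting (high variance)']
print(diagnose(train_error=0.20, val_error=0.21, target_error=0.05))
# ['underfitting (high bias)']
```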
If the model is overfitting, the problem is simpler to fix. Overfitting often means that the training set is not representative of the domain the model is supposed to run in. An ML practitioner can address this by collecting more data or by using data augmentation to increase the variety of examples in the training dataset. Another possibility is that the dataset is representative but too small, in which case an effective remedy is early stopping: halting training once the model’s validation error stops improving.
If the model is underfitting, there’s a wider set of solutions that you can try. A very basic sanity check, though, is to make sure you are training your model to convergence. Many ML engineers stop training too early, while the model is still productively “learning” and would significantly improve if trained for longer. The fix is to keep training as long as the model’s validation error continues to decrease (bounded by time and compute constraints) rather than stopping after a fixed number of iterations.
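One loop can cover both pieces of advice: keep training while validation error improves, and stop once it has failed to improve for `patience` consecutive epochs. A minimal sketch, assuming you supply `train_one_epoch` and `validation_error` callables for your own framework (both hypothetical names here); `max_epochs` acts as the compute budget rather than the stopping criterion. Keras ships an `EarlyStopping` callback built on the same idea, with `patience` and `min_delta` arguments.

```python
import math
from typing import Callable, List, Tuple


def train_until_plateau(train_one_epoch: Callable[[], None],
                        validation_error: Callable[[], float],
                        patience: int = 3,
                        min_delta: float = 1e-4,
                        max_epochs: int = 1000) -> Tuple[int, float, List[float]]:
    """Train until validation error stops improving for `patience` epochs in a row."""
    best_error, best_epoch = math.inf, -1
    epochs_without_improvement = 0
    history: List[float] = []
    for epoch in range(max_epochs):
        train_one_epoch()
        error = validation_error()
        history.append(error)
        if error < best_error - min_delta:
            # Meaningful improvement: the model is still learning, so keep going.
            best_error, best_epoch = error, epoch
            epochs_without_improvement = 0
            # A real loop would checkpoint the weights here and restore them at the end.
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # No improvement for `patience` epochs: stop, keep the best checkpoint.
    return best_epoch, best_error, history


# Simulated validation curve: improves, then starts to rise as the model overfits.
curve = iter([0.30, 0.25, 0.22, 0.21, 0.215, 0.22, 0.23, 0.24])
best_epoch, best_error, history = train_until_plateau(
    train_one_epoch=lambda: None,  # placeholder for a real training epoch
    validation_error=lambda: next(curve),
    patience=3,
)
print(best_epoch, best_error, len(history))  # 3 0.21 7
```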
After ruling out basic issues, we have to do a more detailed diagnosis to implement the right solution.
Inspect high-loss failure cases
Now we need to get into error analysis. The previous section dealt with basic troubleshooting using aggregate accuracy / error metrics. To take the proper corrective action, however, it helps to understand the specific instances in the data where the model fails.
An easy check is to inspect the failure cases with the most disagreement between the model’s inferences and the labels, i.e., the “high-loss examples.” It helps to visualize these datapoints so a human can quickly look over a few hundred examples and build intuition for what’s going on in the data. This is useful for finding issues like duplicate examples, malformed data, and errors in data preprocessing. It can also expose obvious patterns of failure that can be addressed systematically.
In particular, looking at high-loss examples tends to surface many cases where the model is correct and the label is wrong! Bad labels and data introduce a lot of noise into the training process, so fixing these errors can greatly improve the model for little effort.
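A minimal sketch of ranking examples by loss, assuming a classifier that outputs per-class probabilities and integer labels (the values below are made up for illustration). Per-example cross-entropy is the negative log of the probability the model assigned to the labeled class, so a confident prediction that disagrees with its label scores highest, which is exactly where label errors tend to show up.

```python
import math
from typing import List, Sequence, Tuple


def cross_entropy(probs: Sequence[float], label: int, eps: float = 1e-12) -> float:
    # Negative log-probability of the labeled class; eps avoids log(0).
    return -math.log(max(probs[label], eps))


def top_k_high_loss(all_probs: Sequence[Sequence[float]], labels: Sequence[int],
                    k: int = 3) -> List[Tuple[int, float]]:
    """(example index, loss) pairs for the k examples with the highest loss."""
    losses = [(i, cross_entropy(p, y)) for i, (p, y) in enumerate(zip(all_probs, labels))]
    return sorted(losses, key=lambda pair: pair[1], reverse=True)[:k]


probs = [
    [0.95, 0.05],  # confident, agrees with label 0
    [0.02, 0.98],  # confident class 1 but labeled 0: a candidate label error
    [0.60, 0.40],  # uncertain, labeled 1
    [0.10, 0.90],  # confident, agrees with label 1
]
labels = [0, 0, 1, 1]
for idx, loss in top_k_high_loss(probs, labels, k=2):
    print(idx, round(loss, 3))
```

In practice you would feed the returned indices into whatever viewer lets a human look at the underlying images, audio, or text.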
Find patterns of failures
Looking at individual failures works for finding obvious problems. Past those, however, model errors tend to follow a “long tail” distribution: a few categories of issues account for a large share of the failures, and the rest of the distribution consists of a smattering of different categories with relatively few failures each. The right move is to triage model failures by finding the most egregious categories or patterns of failure and prioritizing those for resolution first.
A simple but inefficient approach is to extend the previous technique: instead of inspecting only the top failures, inspect all of them and uncover the patterns manually! This is useful for building intuition about the major failures in the dataset, but not only is it extremely time-consuming, it’s also difficult to quantitatively track which failures show up the most unless you manually keep count as you go.
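Once failures have been tagged with categories (by manual review or otherwise), tallying them makes the long tail visible and shows where to start. A sketch, with hypothetical tag names and a hypothetical `coverage` cutoff: taking categories from most to least common, it returns the smallest set of categories that together account for at least that share of failures.

```python
from collections import Counter
from typing import List, Sequence, Tuple


def triage(failure_tags: Sequence[str],
           coverage: float = 0.8) -> Tuple[List[Tuple[str, int]], float]:
    """Most common failure categories that together cover `coverage` of all failures."""
    if not failure_tags:
        return [], 0.0
    counts = Counter(failure_tags)
    total = sum(counts.values())
    selected: List[Tuple[str, int]] = []
    covered = 0
    for tag, n in counts.most_common():  # largest categories first
        if covered / total >= coverage:
            break
        selected.append((tag, n))
        covered += n
    return selected, covered / total


# Hypothetical tags for 100 failures of a road-sign detector.
tags = (["occluded"] * 40 + ["night"] * 30 + ["glare"] * 10
        + ["motion_blur"] * 5 + ["rain"] * 5 + ["snow"] * 4
        + ["fog"] * 3 + ["other"] * 3)
top, share = triage(tags)
print(top, f"{share:.0%}")
```

Here three categories cover 80% of the failures, and the remaining five share the last 20%.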
A more efficient but less reliable approach is to find correlations between metadata and model performance. Metadata is information that is available alongside the data but is not necessarily part of the model’s input: for example, timestamps captured with an audio / visual sensor input, the ID of the customer the data came from, geo data, etc. It’s quite useful to know that the model tends to fail on a certain class or at a certain time of day, because then you can focus your efforts on improving performance on that slice of metadata. These patterns are also efficient to uncover, because metadata is easy to manipulate with code or SQL queries.
Metadata analysis is less reliable, though, because it can surface spurious correlations (the model may not be using that metadata as an input at all), and because the metadata may not capture an attribute of the data that actually does affect model performance (for example, there may be no metadata distinguishing audio with a British vs. American accent, or dark vs. bright imagery).
A solution that is both efficient and reliable automatically finds patterns of failure in the input data itself, without requiring much manual inspection. Neural network embeddings, the internal representations a network produces when it runs inference on each datapoint, capture patterns of similarity in the underlying data and can be computed more quickly and cheaply than a human can inspect the data. Running clustering on these embeddings exposes groups of similar datapoints. The clusters can then be ranked by how many datapoints each contains and how badly the model performs on them proportionately, so a human only needs to inspect the most severe clusters rather than every datapoint in the dataset. For example, if a model makes 500 mistakes in a cluster with 10,000 examples (a 5% error rate), this is not necessarily cause for concern. However, if the model makes 500 mistakes in a cluster with 1,000 examples (a 50% error rate), that’s probably worth looking at.
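Slicing error rate by a metadata field is a simple aggregation. A sketch in plain Python, assuming each prediction record is a dict with `label`, `prediction`, and metadata fields (the `time_of_day` field and its values are hypothetical); in SQL the same thing is a `GROUP BY` on the metadata column. The hypothetical `min_count` cutoff keeps tiny slices with noisy error rates out of the ranking.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def error_rate_by(records: Iterable[Dict], key: str,
                  min_count: int = 1) -> List[Tuple[object, float]]:
    """Error rate for each value of a metadata field, worst slice first."""
    totals: Dict[object, int] = defaultdict(int)
    errors: Dict[object, int] = defaultdict(int)
    for r in records:
        value = r[key]
        totals[value] += 1
        errors[value] += int(r["prediction"] != r["label"])
    rates = {v: errors[v] / totals[v] for v in totals if totals[v] >= min_count}
    return sorted(rates.items(), key=lambda kv: kv[1], reverse=True)


# Hypothetical prediction log: 1 error in 8 daytime images, 2 errors in 4 nighttime images.
records = (
    [{"time_of_day": "day", "label": "stop", "prediction": "stop"}] * 7
    + [{"time_of_day": "day", "label": "stop", "prediction": "yield"}]
    + [{"time_of_day": "night", "label": "stop", "prediction": "stop"}] * 2
    + [{"time_of_day": "night", "label": "stop", "prediction": "yield"}] * 2
)
print(error_rate_by(records, "time_of_day"))  # [('night', 0.5), ('day', 0.125)]
```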
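The ranking step can be sketched as follows, assuming clustering has already been run on the embeddings (for example with scikit-learn’s `KMeans`) so that each datapoint has a cluster ID, and that you know which datapoints the model got wrong. The `min_size` cutoff is a hypothetical parameter that keeps very small, noisy clusters from dominating the ranking. The example reproduces the 500-of-10,000 versus 500-of-1,000 comparison above.

```python
from collections import Counter
from typing import List, Sequence, Tuple


def rank_clusters(cluster_ids: Sequence[int], is_error: Sequence[bool],
                  min_size: int = 50) -> List[Tuple[int, int, int, float]]:
    """Return (cluster, size, errors, error_rate) rows, most severe first."""
    sizes = Counter(cluster_ids)
    errors = Counter(c for c, wrong in zip(cluster_ids, is_error) if wrong)
    rows = [(c, sizes[c], errors[c], errors[c] / sizes[c])
            for c in sizes if sizes[c] >= min_size]
    # Highest error rate first; ties broken by the absolute number of errors.
    rows.sort(key=lambda row: (row[3], row[2]), reverse=True)
    return rows


# The example from the text: 500 mistakes among 10,000 vs 500 among 1,000.
cluster_ids = [0] * 10_000 + [1] * 1_000
is_error = [True] * 500 + [False] * 9_500 + [True] * 500 + [False] * 500
for cluster, size, n_errors, rate in rank_clusters(cluster_ids, is_error):
    print(f"cluster {cluster}: {n_errors}/{size} errors ({rate:.0%})")
```

Ranking by error rate rather than raw error count is what pushes the smaller cluster to the top, even though both clusters contain the same number of mistakes.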
Fix patterns of failures
Once you have an idea of the most severe patterns of failure in a dataset, the next step is to plan and implement a solution for them. There is a range of possible solutions at this point.
In some cases, issues can be addressed through one-off changes to the model or preprocessing code: for example, data augmentation (randomly adjusting image rotation / color parameters), hyperparameter adjustments (prior box tuning for object detectors), or additional modeling steps (adding an OCR and text model to a road sign classification pipeline).
But in the vast majority of circumstances, it’s more effective to add more examples of the cases the model struggles on. The best way to do this is to simply collect and label more real data from production that is similar to the failures. Random sampling sometimes picks up these cases, but for uncommon ones, more targeted data collection is necessary. If a model struggles to detect occluded stop signs, then adding more images of occluded stop signs to the dataset is the best way to improve performance. As in the error analysis sections, this can be done through metadata queries when possible, or through brute-force human inspection of data if not, but the more elegant solution is to use embeddings to run a similarity search over unlabeled data for targeted data collection.
In rare circumstances, it’s not possible to collect this data because it’s too expensive or too dangerous. In these cases, synthetic data can be a solution, though it comes at the higher cost of paying level artists / simulation engineers to generate data that may not match the conditions of real data.
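A brute-force sketch of embedding similarity search, assuming you already have embeddings for the known failure cases and for a pool of unlabeled data (the IDs and three-dimensional vectors below are toy values). Each unlabeled item is scored by its highest cosine similarity to any known failure, and the top matches become candidates for labeling. At production scale you would likely replace the linear scan with an approximate nearest-neighbor index.

```python
import math
from typing import Dict, List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def find_similar(failures: List[Sequence[float]],
                 unlabeled: Dict[str, Sequence[float]], k: int = 2) -> List[str]:
    """IDs of the k unlabeled items most similar to any known failure."""
    scores = {
        item_id: max(cosine_similarity(f, emb) for f in failures)
        for item_id, emb in unlabeled.items()
    }
    return sorted(scores, key=scores.get, reverse=True)[:k]


# Toy embeddings: one known failure (say, an occluded stop sign) and four unlabeled images.
failures = [[1.0, 0.1, 0.0]]
unlabeled = {
    "img_a": [0.9, 0.2, 0.0],
    "img_b": [0.0, 1.0, 0.0],
    "img_c": [0.8, 0.0, 0.1],
    "img_d": [0.0, 0.0, 1.0],
}
print(find_similar(failures, unlabeled, k=2))  # ['img_a', 'img_c']
```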
Addendum: Why Do ML Teams Jump To Solutions?
I’ve worked on ML teams at various well-known tech companies and I’ve found it’s been extremely common for ML engineers to fall into the trap of “try this and see what happens.”
If I were to guess, many ML engineers are former software engineers who recently picked up ML (as opposed to ML PhDs or seasoned industry ML practitioners). A lot of traditional software engineering doesn’t require thinking too hard about what to build, since many of those decisions are made by product managers (what the right thing to build is) and product designers (how it should work). Most of the work is therefore implementation and execution, and most of the creative freedom lies in how to architect the solution correctly. In addition, there are many established best practices in ML engineering that software engineers may not be exposed to unless they’ve taken grad-level ML courses.
When these engineers are given the freedom to experiment for improving an ML model, they quickly end up burrowing into rabbit holes that don’t yield much improvement for the time spent.
The closest analog to ML development in traditional software engineering is performance optimization: speeding up code execution. Blindly applying optimizations is generally an inefficient approach.
The most successful approaches involve using a profiler to understand what code takes the most time to run, optimizing those pieces (refactoring particularly slow functions, unrolling tight loops, etc.), and then benchmarking the before-and-after difference in performance. This diagnose, experiment, and validate loop is quite similar to the ML improvement workflow!
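The profiling loop can be sketched with Python’s standard `cProfile`, `pstats`, and `timeit` modules. The two deduplication functions are a toy example chosen for illustration, not a recommendation for any particular codebase: profile to find the hot spot, apply a targeted fix, confirm the output is unchanged, then benchmark before and after.

```python
import cProfile
import io
import pstats
import timeit


def slow_unique(items):
    # O(n * m): `x not in seen` scans the whole list on every iteration.
    seen = []
    for x in items:
        if x not in seen:
            seen.append(x)
    return seen


def fast_unique(items):
    # O(n): dicts preserve insertion order, so the result matches slow_unique.
    return list(dict.fromkeys(items))


data = [i % 500 for i in range(10_000)]

# 1. Diagnose: profile the slow path to see where the time goes.
profiler = cProfile.Profile()
profiler.enable()
slow_unique(data)
profiler.disable()
report = io.StringIO()
pstats.Stats(profiler, stream=report).sort_stats("cumulative").print_stats(5)
print(report.getvalue())

# 2. Experiment and 3. validate: same output, then benchmark before and after.
assert slow_unique(data) == fast_unique(data)
before = timeit.timeit(lambda: slow_unique(data), number=5)
after = timeit.timeit(lambda: fast_unique(data), number=5)
print(f"before={before:.3f}s after={after:.3f}s")
```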