Peter Gao
March 16, 2021
8 min

The Unreasonable Effectiveness Of Neural Network Embeddings

An example of an interactive embedding visualization generated in Aquarium.

Machine learning on unstructured data (like images or audio) is much harder than machine learning on structured data. Fortunately, deep neural networks can produce structured representations known as “embeddings” that are remarkably effective in organizing and wrangling large sets of unstructured data.

At Aquarium, we heavily utilize embeddings to make it easier for our users to work with unstructured data. In this article, we’re going to explore why these embeddings are so effective and how to use these embeddings to speed up common ML workflows.

Structured and Unstructured Data


Structured data can be thought of as “organized” data that can be described with traditional data schemas. Structured data can be explored with spreadsheets or SQL databases. It can be histogrammed to show the distribution of the dataset and to spot outliers, and it can be filtered to interesting slices with SQL-like queries.
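As a minimal illustration of those structured-data tools, the sketch below uses Python's built-in sqlite3 module on a small hypothetical `people` table (the table, columns, and values are made up for this example). It histograms one column with a GROUP BY and filters to a slice with an ordinary WHERE clause.

```python
import sqlite3

# Hypothetical structured dataset: every row follows the same schema.
rows = [
    ("alice", 34, 165.0),
    ("bob", 41, 180.5),
    ("carol", 29, 158.2),
    ("dave", 52, 175.0),
    ("erin", 38, 170.1),
]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE people (name TEXT, age INTEGER, height_cm REAL)")
conn.executemany("INSERT INTO people VALUES (?, ?, ?)", rows)

# Histogram: count rows per 10-year age bucket (integer division in SQLite).
histogram = conn.execute(
    "SELECT (age / 10) * 10 AS bucket, COUNT(*) FROM people "
    "GROUP BY bucket ORDER BY bucket"
).fetchall()

# Slice: filter to an interesting subset with an ordinary query.
tall_over_35 = conn.execute(
    "SELECT name FROM people WHERE age > 35 AND height_cm > 172 ORDER BY name"
).fetchall()

print(histogram)     # [(20, 1), (30, 2), (40, 1), (50, 1)]
print(tall_over_35)  # [('bob',), ('dave',)]
```

Nothing comparable exists for a folder of raw images: there is no column to bucket or filter on, which is the gap embeddings help fill.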

Unstructured data can’t be easily organized into nice schemas with rows, columns, and tables. Examples of this type of data include imagery, audio, text, and pointclouds. Unstructured data is very high dimensional: there are hundreds of thousands of possible words in English, and an 800x600 RGB image has ~1.44 million inputs (800 × 600 pixels × 3 color channels). There is obviously important information in this data, but it’s hard to index, query, or compare it in a human-interpretable way, since the tried-and-true techniques that work well on structured data don’t work as well here.

Enter neural network embeddings.

An Intro To Embeddings

An embedding is a low-dimensional vector representation that captures relationships in higher-dimensional input data. Distances between embedding vectors capture the similarity between datapoints, and the vectors themselves can encode essential concepts from the original input.

Embeddings can be generated in various ways, but there’s a lot of excitement around using neural networks to learn embeddings on unstructured data. One of the most famous examples of this technique is in natural language processing. A pretty cool demonstration of how these embeddings capture higher-level concepts comes from adding and subtracting embedding vectors extracted from a neural network trained on a large text dataset: if you take the embedding vector for “king,” subtract the embedding for “man,” and add the embedding for “woman,” you get a vector that strongly resembles the embedding for “queen.”
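The analogy can be reproduced in miniature with hand-made vectors. The sketch below uses a hypothetical 3-dimensional vocabulary whose values were picked by hand for illustration; real word embeddings are learned, have hundreds of dimensions, and have no hand-labeled axes. The closest word to king − man + woman is found by cosine similarity, excluding the three input words, as is commonly done for this kind of analogy query.

```python
import math

# Toy, hand-made "embeddings" (hypothetical; real ones are learned from text).
vectors = {
    "king":  [0.9, 0.8, 0.1],
    "queen": [0.9, 0.1, 0.8],
    "man":   [0.1, 0.9, 0.1],
    "woman": [0.1, 0.1, 0.9],
    "apple": [0.0, 0.2, 0.2],
}

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors (1.0 = same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def analogy(a, b, c):
    """Return the word whose vector is most similar to vec(a) - vec(b) + vec(c)."""
    target = [x - y + z for x, y, z in zip(vectors[a], vectors[b], vectors[c])]
    candidates = [w for w in vectors if w not in (a, b, c)]
    return max(candidates, key=lambda w: cosine_similarity(vectors[w], target))

print(analogy("king", "man", "woman"))  # queen
```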


Neural network embeddings are easy to produce across a variety of task types and data types. They are learned during normal supervised training as an intermediate representation that is most helpful for the net to accomplish that supervised task. These embeddings are therefore a learned feature vector that can be generated from imagery, audio, text, and even structured data.

Neural network embeddings are a useful byproduct of the training process, yet most practitioners don’t know about them, let alone take advantage of them!

Getting Started With Embeddings

A visualization of various layer activations in a convolutional neural network trained on MNIST. The activations from the second-to-last layer (second row from the top) can be extracted and used as a learned embedding for the input image. Source

It’s usually simple to extract embeddings since they are available as a byproduct of the normal supervised training process and don’t require any changes to your models. Some task-specific architectures need a bit more care, but a good rule of thumb is to grab embeddings from the activations of the layer right before your task-specific prediction layers.
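To make that rule of thumb concrete without tying it to a framework, here is a toy two-layer classifier in plain Python. The weights, sizes, and function names are all hypothetical. The hidden-layer activations computed just before the final prediction layer are returned alongside the logits, and those activations are what you would keep as the embedding.

```python
# Hypothetical toy classifier: 3 input features -> 2 hidden units -> 2 classes.
W_HIDDEN = [[0.5, -0.2, 0.1],   # weights for hidden unit 0
            [-0.3, 0.8, 0.4]]   # weights for hidden unit 1
W_OUTPUT = [[1.0, -1.0],        # weights for the class-0 logit
            [-0.5, 1.5]]        # weights for the class-1 logit

def relu(v):
    return [max(0.0, x) for x in v]

def matvec(matrix, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in matrix]

def forward(x):
    """Return (embedding, logits) for one input vector."""
    embedding = relu(matvec(W_HIDDEN, x))  # activations right before the prediction layer
    logits = matvec(W_OUTPUT, embedding)   # task-specific prediction layer
    return embedding, logits

embedding, logits = forward([1.0, 2.0, 0.5])
```

A real framework does the same thing at scale: you run the trained model as usual but read out the penultimate layer instead of (or in addition to) the final output.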

For example, to extract embeddings from a ResNet50 image classification model in Keras, simply run:

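from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image
import numpy as np
# include_top=False drops the ImageNet classification head; pooling='avg'
# averages the final 7x7x2048 feature map into a single 2048-d vector per image
model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))  # ResNet50's default input size
x = image.img_to_array(img)     # (224, 224, 3) float array
x = np.expand_dims(x, axis=0)   # add a batch dimension: (1, 224, 224, 3)
x = preprocess_input(x)         # same preprocessing the ImageNet weights were trained with
embedding = model.predict(x)    # shape (1, 2048)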

However, embedding vectors are still fairly long: in this case, 2048 elements long. This is much more manageable than the pixel values of the input image (150,528 at the 224 × 224 × 3 input size used here), but the embedding is hard to understand because its values are not human-interpretable features like the age, height, or timestamp fields you might see in structured data. Instead, embeddings are “neural network interpretable” features that are hard for humans to understand.

A good visualization can be very helpful for building intuition about embeddings. We can use a dimensionality reduction technique like PCA, t-SNE, or UMAP to further reduce the embeddings to 2 dimensions, producing 2D visualizations that largely preserve the patterns present in the original embedding vectors.
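Of the three techniques, PCA is the simplest to write out by hand. The sketch below is a minimal, dependency-free version that assumes a small list of equal-length vectors (the function name `pca_2d` is hypothetical): it centers the data, builds the covariance matrix, finds the top two eigenvectors by power iteration with deflation, and projects each vector onto them. In practice you would reach for a library implementation such as scikit-learn's `PCA`; also note that PCA is a linear projection, whereas t-SNE and UMAP can preserve nonlinear neighborhood structure.

```python
import math
import random

def pca_2d(vectors, iterations=200, seed=0):
    """Project equal-length vectors onto their top-2 principal components."""
    n, d = len(vectors), len(vectors[0])
    mean = [sum(v[j] for v in vectors) / n for j in range(d)]
    centered = [[v[j] - mean[j] for j in range(d)] for v in vectors]
    # Covariance matrix (d x d) of the centered data.
    cov = [[sum(c[i] * c[j] for c in centered) / n for j in range(d)] for i in range(d)]

    rng = random.Random(seed)
    components = []
    for _ in range(2):
        vec = [rng.random() for _ in range(d)]
        # Power iteration: repeatedly multiply by cov and renormalize.
        for _ in range(iterations):
            vec = [sum(cov[i][j] * vec[j] for j in range(d)) for i in range(d)]
            norm = math.sqrt(sum(x * x for x in vec))
            vec = [x / norm for x in vec]
        eigenvalue = sum(vec[i] * sum(cov[i][j] * vec[j] for j in range(d)) for i in range(d))
        components.append(vec)
        # Deflation: subtract this component so the next pass finds the next one.
        cov = [[cov[i][j] - eigenvalue * vec[i] * vec[j] for j in range(d)] for i in range(d)]

    # Each output row is (coordinate on PC1, coordinate on PC2).
    return [[sum(c[j] * comp[j] for j in range(d)) for comp in components] for c in centered]
```

This sketch assumes the data has at least two directions of nonzero variance; otherwise the second power iteration would divide by a zero norm.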

Using UMAP is as easy as installing the UMAP library (published on PyPI as umap-learn) and calling a function on an array of embeddings. Here, we reduce the 2048-dimensional embeddings to 2 dimensions:

import umap  # pip install umap-learn
# `embeddings`: array of shape (num_images, 2048), one row per image
embeddings_2d = umap.UMAP(n_components=2).fit_transform(embeddings)

Now that we have a set of UMAP-ed embeddings, we can visualize them. The easiest way to do this is with a library like Matplotlib. For example, you can train a model to generate embeddings for Wikipedia books, reduce the embeddings to 2D with UMAP, and then visualize them in a Matplotlib scatter plot.


Note that there is no inherent semantic meaning to the 2 dimensions generated by UMAP (i.e., the x and y axes of the plot). What’s more important is to understand which examples get grouped together and to reason about why some are close together and others are far apart.

This is where the limitations of a simple plot become apparent. It’s hard to explore which datapoints belong to which cluster, which makes it difficult to build intuition for why certain data are distributed the way they are in the embedding space. What we really want is an interactive visualization that shows the UMAP distribution but also lets us zoom into individual clusters and explore what type of data they contain.

In another example, we trained a model to do pet classification on the Oxford-IIIT Pet Dataset, extracted embeddings from the model, and used Aquarium to generate an interactive UMAP visualization. We were then able to explore the distribution of the dataset to find interesting clusters and outliers.

Here’s what the UMAP looks like for the entire dataset.

Here’s the UMAP filtered down to just images labeled as Shiba Inu dogs.
While the largest cluster consists of the most common red-furred Shiba Inus…
There are also Shiba Inus with white coats.
As well as Shiba Inus with black coats.
Some images labeled as Shiba Inus are actually Australian Shepherds!

Not only do the embeddings work well on image data, they also reveal the variation within certain label classes. There are Shiba Inus with red fur, as well as some with white or black coats. When we examine the outliers among the Shiba Inus, we also notice some pictures of dogs that are clearly not Shiba Inus, likely the result of labeling mistakes.

Since we extracted the embeddings from a supervised classification model, we can actually compare the inferences from our model to the labels on the dataset to find where the model disagrees with the labels. We can then color the embedding pointcloud by agreement / disagreement and look for clusters of similar disagreements that might indicate a specific edge case that the model performs poorly on.
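The comparison itself is simple bookkeeping. The sketch below assumes you already have, for each datapoint, a label, a model prediction, and a 2D (e.g. UMAP-reduced) embedding; the record format and all values are hypothetical. It flags each point as agreeing or disagreeing (the color key for the scatter plot) and counts which (label, prediction) pairs disagree most often, a quick way to spot a systematic confusion before looking at where those points sit in the plot.

```python
from collections import Counter

# Hypothetical records: (label, model prediction, 2D embedding coordinates).
records = [
    ("shiba_inu", "shiba_inu", (1.0, 2.0)),
    ("shiba_inu", "samoyed", (5.1, 0.9)),
    ("shiba_inu", "samoyed", (5.3, 1.1)),
    ("samoyed", "samoyed", (5.8, 1.4)),
    ("shiba_inu", "shiba_inu", (1.2, 2.2)),
]

# Color key for the scatter plot: True where the model agrees with the label.
agrees = [label == pred for label, pred, _ in records]

# Most frequent kinds of disagreement, as (label, prediction) pairs.
confusions = Counter((label, pred) for label, pred, _ in records if label != pred)

# Embedding coordinates of the disagreements, to inspect as a cluster.
disagreement_points = [xy for (_, _, xy), ok in zip(records, agrees) if not ok]

print(confusions.most_common(1))  # [(('shiba_inu', 'samoyed'), 2)]
```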

By comparing model inferences to labels, we can color the UMAP cloud by model agreement / disagreement.
Filtering the cloud down to only Shiba Inu labels.
We see that the model seems to confuse Shiba Inus with white coats for Samoyeds!

As we can see here, our model confuses Shiba Inus that have white coats with Samoyeds, another breed with a white coat. We should probably collect and label more images of white-coated Shiba Inus for the model to retrain on, so it does better on these Shibas in the future.

ML Workflows With Embeddings

Embeddings clearly make for a great visualization and debugging tool. Beyond that, we can build workflows around embeddings to improve ML model performance efficiently.

Embeddings are especially powerful in workflows for finding and fixing model errors. As we saw previously with Shiba Inus, interactive embedding visualizations allow users to find patterns of model errors, while running clustering algorithms on the embeddings can surface those patterns automatically.
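One common choice of clustering algorithm is k-means. Below is a minimal, dependency-free k-means sketch (Lloyd's algorithm), assuming embeddings are equal-length sequences of floats and that you choose the number of clusters `k` yourself; it is a generic illustration, not necessarily what Aquarium runs, and production code would more likely use a library such as scikit-learn's `KMeans`. Running it on only the embeddings of misclassified examples, for instance, groups similar failures together.

```python
import math
import random

def kmeans(points, k, iterations=50, seed=0):
    """Cluster equal-length vectors with Lloyd's algorithm; returns (centroids, assignments)."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]  # start from k random points
    assignments = [0] * len(points)
    for _ in range(iterations):
        # Assignment step: each point joins its nearest centroid (Euclidean distance).
        assignments = [
            min(range(k), key=lambda c: math.dist(p, centroids[c])) for p in points
        ]
        # Update step: move each centroid to the mean of its members.
        for c in range(k):
            members = [p for p, a in zip(points, assignments) if a == c]
            if members:  # keep the old centroid if a cluster ends up empty
                centroids[c] = [sum(dim) / len(members) for dim in zip(*members)]
    return centroids, assignments
```

Each resulting cluster of failures can then be reviewed as a unit, which is much faster than inspecting errors one at a time.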

These patterns of model failures are often edge cases where a model needs more data of that scenario in order to improve its performance. This is a very common problem that surfaces across different tasks and problem domains. As a result, ML practitioners often do targeted data collection on these interesting scenarios to label and retrain on in order to produce a model that does better on these edge cases.

Examples of variations in stop signs that Tesla’s stop sign detector must handle. Source

In most domains, ML teams can store far more unlabeled data than they can feasibly label, so they have to choose what to label. The naive approach is to sift through large amounts of data manually to find more examples of a specific edge case (such as the white-coated Shiba Inus). This can take the form of scrolling through grids of visualizations or spreadsheets of data URLs to find the “needle” of interesting data they’d like to collect within the “haystack” of mostly irrelevant data they’ve already stored.

Targeted data mining can be thought of as a search and retrieval problem that can be solved with neural network embeddings. Embedding-based search systems work by comparing a “query” embedding against a large corpus of embeddings and returning the entries from the corpus that are most similar to the query. Similarity can be defined, for example, as having the lowest cosine distance between the query and result embedding vectors.
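A brute-force version of that search fits in a few lines. The sketch below assumes the corpus is a dict mapping hypothetical item IDs to embedding vectors and ranks every item by cosine distance to the query. Large corpora would typically use an approximate nearest-neighbor index instead of scanning everything, but the ranking rule is the same.

```python
import heapq
import math

def cosine_distance(a, b):
    """1 - cosine similarity; 0 means the vectors point the same way (assumes nonzero vectors)."""
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norms

def search(query, corpus, k=3):
    """Return the k (distance, item_id) pairs closest to the query embedding."""
    return heapq.nsmallest(
        k, ((cosine_distance(query, vec), item_id) for item_id, vec in corpus.items())
    )

# Hypothetical unlabeled pool, and a query embedding taken from a known failure case.
corpus = {
    "img_001": [0.9, 0.1, 0.0],
    "img_002": [0.1, 0.9, 0.2],
    "img_003": [0.8, 0.2, 0.1],
    "img_004": [0.0, 0.1, 0.9],
}
query = [1.0, 0.0, 0.0]
nearest_ids = [item_id for _, item_id in search(query, corpus, k=2)]
print(nearest_ids)  # ['img_001', 'img_003']
```

The returned items are the candidates to send for labeling.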

The ML team at Waymo utilizes embedding search to mine examples from unlabeled data that are similar to difficult edge cases / model failure cases. They can then label and retrain on these datapoints and produce a model that has better performance. This is an example of using unsupervised learning to improve a supervised learning system!

Using a visual search system to mine more cactuses. Source


Neural network embeddings are a powerful byproduct of normal supervised training that allow ML practitioners to more easily work with unstructured datasets. It’s pretty easy to get started with generating and visualizing your own embeddings since they don’t require any changes to your model training process. By using these embeddings properly, one can speed up or automate a lot of work in typical model improvement workflows.

At Aquarium, we invest heavily in using embeddings for dataset exploration, data mining, and more. ML models are only as good as the data they’re trained on, and our goal is to make it easy for non-technical users to ship better models faster by improving their datasets. If you’re interested in trying Aquarium for yourself, let us know!
