In what sense do large models trained differently learn similar representations?


Large models trained from different random seeds (and sometimes even with different widths, architectures, data, or objectives) tend to learn similar internal representations. What is the appropriate metric for assessing this similarity? Can we use it to make this statement far more robust and precise?

When an input is passed into a neural network, the (pre)activation vectors at the hidden layers constitute vector “representations” of that input. Understanding the structure and development of these hidden representations is in some sense the whole challenge of the science of deep learning. One useful question to start with is whether different large models learn essentially similar representations, up to minor variations. This is important because if the answer is “yes,” then we can “study one model to study them all,” but if it’s “no,” then we’ll have to take much more ad-hoc, model-specific approaches.

Of course, the devil’s in the phrase “essentially similar.” What exactly do we mean? There will very likely be senses in which the answer to this universality question is “yes” and senses in which it’s “no,” so the important work is less about choosing one answer and more about teasing out which precise versions of the question have the answer “yes” and which have the answer “no.” This will clarify in which senses we can “study one model to study them all” and for which goals we’ll need ad-hoc approaches.

The difficult question here is methodological: how should we quantitatively compare representations across models? To see the challenge, note first that the simplest mathematical measures of similarity, like Euclidean distance, make no sense here. Different models might have hidden representations of different sizes that live in different spaces! Even when the widths match, models trained from different seeds have no reason to share coordinates: relabeling a layer’s hidden units (and permuting the adjacent weights to match) yields a functionally identical network whose representation vectors are nonetheless far apart. A single representation vector only has meaning in the context of the distribution of representations from the same model, so any assessment of representational similarity must take place at the level of (or at least somehow involve) a dataset, not just one pair of representations at a time.
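One symmetry makes the point concrete: permuting a layer’s hidden units, together with the matching rows and columns of the adjacent weight matrices, does not change the function a network computes. The sketch below is a minimal illustration under stated assumptions: a tiny one-hidden-layer NumPy MLP with random weights stands in for a trained model, and the sizes and helper names (`forward`, `pairwise_dists`) are hypothetical. It checks that the relabeled network has identical outputs, that each input’s hidden vector moves far away in Euclidean terms, and that the dataset-level matrix of pairwise distances is unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes; random weights stand in for a trained model.
d_in, width, d_out, n_inputs = 10, 32, 3, 100
W1 = rng.normal(size=(d_in, width))
W2 = rng.normal(size=(width, d_out))


def forward(x, w1, w2):
    """Return (hidden preactivations, outputs) of a one-hidden-layer ReLU MLP."""
    hidden = x @ w1
    return hidden, np.maximum(hidden, 0.0) @ w2


def pairwise_dists(h):
    """n x n matrix of Euclidean distances between the rows of h."""
    return np.linalg.norm(h[:, None, :] - h[None, :, :], axis=-1)


# "Model B": the same network with its hidden units relabeled. Permuting the
# columns of W1 and the rows of W2 together leaves the function unchanged.
perm = rng.permutation(width)
W1_b, W2_b = W1[:, perm], W2[perm, :]

x = rng.normal(size=(n_inputs, d_in))
hidden_a, out_a = forward(x, W1, W2)
hidden_b, out_b = forward(x, W1_b, W2_b)

print(np.allclose(out_a, out_b))  # True: identical input-output behavior
# Per-input Euclidean distance between the two models' hidden vectors: large.
print(np.linalg.norm(hidden_a - hidden_b, axis=1).mean())
# Dataset-level geometry (all pairwise distances) is identical.
print(np.allclose(pairwise_dists(hidden_a), pairwise_dists(hidden_b)))  # True
```

Pairwise distances are just one possible dataset-level summary; they are also invariant to arbitrary rotations of the hidden space, which may or may not be the invariance one wants.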

So how do we measure representational similarity at the level of datasets? This question is deep and difficult not because we have no answer but because we have many. Reasonable-seeming metrics can be built from nearest-neighbor structure, linear-algebraic tools like canonical correlation analysis or kernel alignment, trained linear probes, mutual information, model stitching, and more, and these metrics often give different answers. How should we compare representations across different large models? Once we have this tool, what can we learn with it?
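To make two of these families concrete, here is a minimal sketch assuming NumPy and synthetic Gaussian matrices in place of real activations. `linear_cka` implements linear centered kernel alignment in the form given by Kornblith et al. (2019), and `mutual_knn` is one simple nearest-neighbor metric (the average fraction of shared k-nearest neighbors), chosen here for illustration; both helper names are hypothetical. Each takes two (n_inputs, width) matrices whose rows describe the same inputs, so the widths may differ.

```python
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between (n, d1) and (n, d2) matrices whose rows are the same inputs."""
    x = x - x.mean(axis=0, keepdims=True)  # center each feature
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x) ** 2   # squared Frobenius norm of Y^T X
    return float(cross / (np.linalg.norm(x.T @ x) * np.linalg.norm(y.T @ y)))


def knn_indices(reps: np.ndarray, k: int) -> np.ndarray:
    """Indices of each row's k nearest neighbors (Euclidean), excluding the row itself."""
    sq = (reps ** 2).sum(axis=1)
    dists = sq[:, None] + sq[None, :] - 2.0 * reps @ reps.T  # squared distances
    np.fill_diagonal(dists, np.inf)
    return np.argsort(dists, axis=1)[:, :k]


def mutual_knn(x: np.ndarray, y: np.ndarray, k: int = 10) -> float:
    """Average fraction of k-nearest neighbors shared by the two representations."""
    nbrs_x, nbrs_y = knn_indices(x, k), knn_indices(y, k)
    return float(np.mean([len(set(a) & set(b)) / k for a, b in zip(nbrs_x, nbrs_y)]))


# Synthetic stand-ins for activations on the same 200 inputs.
rng = np.random.default_rng(0)
n_inputs = 200
reps_a = rng.normal(size=(n_inputs, 64))             # model A, width 64
basis, _ = np.linalg.qr(rng.normal(size=(256, 64)))  # 256 x 64, orthonormal columns
reps_b = reps_a @ basis.T                            # A's geometry, rotated into width 256
reps_c = rng.normal(size=(n_inputs, 256))            # unrelated model, width 256

for name, other in [("rotated copy", reps_b), ("unrelated", reps_c)]:
    print(f"{name}: CKA={linear_cka(reps_a, other):.3f}, "
          f"mutual kNN={mutual_knn(reps_a, other):.3f}")
```

Both metrics score the rotated copy as identical. On the unrelated pair they differ in scale: with many features relative to inputs, linear CKA stays well above zero, while the neighbor overlap sits near the chance level of roughly k/(n−1). Raw scores from different metrics are therefore not directly comparable.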

Discussion