How is contrastive learning done with one model, exactly?
I agree only one is used in inference, but two are needed for training (otherwise how do you calculate a meaningful loss function?). Notice that in the original CLIP paper there's an image encoder and a text encoder, even though only the text encoder is used during inference. [0]
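For concreteness, here is a toy sketch of why both encoders show up in the training step; the encoder shapes and data below are made-up stand-ins, not CLIP's actual implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-ins for CLIP's two towers (real CLIP uses a ViT/ResNet and a transformer).
image_encoder = nn.Linear(2048, 512)
text_encoder = nn.Linear(768, 512)

images = torch.randn(8, 2048)   # fake image features for a batch of 8 pairs
texts = torch.randn(8, 768)     # fake text features for the same 8 pairs

img_emb = F.normalize(image_encoder(images), dim=-1)
txt_emb = F.normalize(text_encoder(texts), dim=-1)

logits = img_emb @ txt_emb.T / 0.07   # every image scored against every text
labels = torch.arange(8)              # matching pairs lie on the diagonal
loss = (F.cross_entropy(logits, labels) +
        F.cross_entropy(logits.T, labels)) / 2
loss.backward()                       # both encoders receive gradients
```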
There are two submodules in our model (a contrastive submodule and a diffusion prior submodule), but they still form one model because they are trained end-to-end. In the final architecture we picked, there is a common backbone that maps fMRI data to an intermediate space, an MLP projector that produces the retrieval embeddings, and a diffusion prior that produces the Stable Diffusion embeddings.
Both the prior and the MLP projector make use of the same intermediate space, and the backbone + projector + prior are all trained end-to-end (the contrastive loss on the projector outputs and the MSE loss on the prior outputs are simply added together).
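In (hypothetical PyTorch) pseudocode, a training step looks roughly like this; the layer sizes, the simple MLP standing in for the diffusion prior, and the CLIP image embeddings as targets are illustrative assumptions rather than our exact implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BrainToCLIP(nn.Module):
    def __init__(self, n_voxels=15000, hidden=4096, clip_dim=768):
        super().__init__()
        # shared backbone: fMRI voxels -> intermediate space
        self.backbone = nn.Sequential(nn.Linear(n_voxels, hidden), nn.GELU())
        # MLP projector: intermediate space -> retrieval embeddings
        self.projector = nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(),
                                       nn.Linear(hidden, clip_dim))
        # stand-in for the diffusion prior: intermediate space -> diffusion conditioning embeddings
        self.prior = nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(),
                                   nn.Linear(hidden, clip_dim))

    def forward(self, voxels):
        h = self.backbone(voxels)
        return self.projector(h), self.prior(h)

def contrastive_loss(brain_emb, target_emb, temperature=0.07):
    # symmetric InfoNCE between fMRI embeddings and target image embeddings
    brain_emb = F.normalize(brain_emb, dim=-1)
    target_emb = F.normalize(target_emb, dim=-1)
    logits = brain_emb @ target_emb.T / temperature
    labels = torch.arange(len(logits), device=logits.device)
    return (F.cross_entropy(logits, labels) +
            F.cross_entropy(logits.T, labels)) / 2

model = BrainToCLIP()
voxels = torch.randn(8, 15000)   # a batch of (fake) fMRI patterns
targets = torch.randn(8, 768)    # embeddings of the images the subject saw

retrieval_emb, prior_out = model(voxels)
loss = contrastive_loss(retrieval_emb, targets) + F.mse_loss(prior_out, targets)
loss.backward()   # one backward pass; gradients reach backbone, projector, and prior
```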
We found that this works better than first training a contrastive model, then freezing it and training a diffusion prior on its outputs (similar to CLIP + DALL-E 2). That is, the retrieval objective improves reconstruction and the reconstruction objective slightly improves retrieval.
[0] https://arxiv.org/pdf/2103.00020.pdf