In the previous episode of Private AI’s ML Speaker Series, Patricia Thaine (CEO of Private AI) sat down with Arvid Frydenlund (PhD candidate in the Department of Computer Science at the University of Toronto and the Vector Institute) to discuss his latest paper, Language Modelling via Learning to Rank, presented at AAAI-2022 and published at the 6th Workshop on Structured Prediction for NLP at ACL 2022.
In this paper, co-authored by Gagandeep Singh from Nuance Communications and Professor Frank Rudzicz, the authors discuss how they trained a language model using multiple targets per time-step via a structured loss and rank-based distillation.
If you missed this last discussion, scroll down to find a recap of Patricia and Arvid’s chat or watch the full session below.
Watch the full session:
PAI: Could you start with an overview of your paper for us?
Arvid: Sure. The main motivation of the work is that we think the current way we train language models, using cross entropy, is poor because it only tries to maximize the probability of a single ground-truth word in a given context. It doesn’t account for the fact that there may be many other words that could validly continue that context.
Let’s call the set of words that could form a valid continuation of the context at a given time step the branching set for that time step. We thought label smoothing might solve this problem, since it would allow us to have multiple targets per time step. However, to implement label smoothing for language modelling, you’d want the smoothing weights to change dynamically at every time step, distributed over that time step’s branching set.
To do this dynamically, you would end up using a pre-trained language model to generate the label smoothing weights, which is really just a form of knowledge distillation from that pre-trained language model. We really wanted to avoid doing that.
So instead we propose that we can forgo label weights entirely, so long as we can create an ordering over the labels in the branching set. We then developed a novel N-gram algorithm to create these branching sets along with their rank orderings, and incorporated this information into the training procedure as a structured ranking loss function.
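To make the label smoothing idea concrete, here is a minimal Python sketch contrasting standard label smoothing, which spreads a small mass ε uniformly over the whole vocabulary, with a per-time-step variant that spreads it over a branching set. The function names and the `branch_weights` argument are hypothetical illustrations, not code from the paper. Note that the sketch simply takes those weights as an input; producing them dynamically for every time step is exactly the part that would require a pre-trained teacher.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def cross_entropy(logits, target):
    """Cross entropy between a target distribution and the model's softmax."""
    return -sum(t * lp for t, lp in zip(target, log_softmax(logits)))

def one_hot(vocab_size, gold):
    """Plain cross-entropy target: all mass on the single ground-truth word."""
    return [1.0 if i == gold else 0.0 for i in range(vocab_size)]

def uniform_label_smoothing(vocab_size, gold, eps=0.1):
    """Standard label smoothing: eps is spread uniformly over every word."""
    return [(1.0 - eps) * t + eps / vocab_size for t in one_hot(vocab_size, gold)]

def branching_set_smoothing(vocab_size, gold, branch_weights, eps=0.1):
    """Hypothetical per-time-step variant: eps is spread over this time step's
    branching set in proportion to branch_weights (word id -> weight).
    Where these weights come from is deliberately left open."""
    total = sum(branch_weights.values())
    target = [(1.0 - eps) * t for t in one_hot(vocab_size, gold)]
    for word, weight in branch_weights.items():
        target[word] += eps * weight / total
    return target

# Toy vocabulary of 5 word ids: word 2 is the ground truth, and words 3 and 4
# are assumed (for illustration only) to be other valid continuations.
target = branching_set_smoothing(5, gold=2, branch_weights={2: 0.5, 3: 0.3, 4: 0.2})
print([round(p, 3) for p in target])  # → [0.0, 0.0, 0.95, 0.03, 0.02]
print(cross_entropy([0.0, 0.0, 2.0, 1.0, 0.5], target))
```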
PAI: Can you tell us about the most common ways to train language models?
Arvid: Sure. The way you train a language model really depends on the type of language model you are using, and there are really three types. There are N-gram language models, there are neural autoregressive language models, and then BERT is really the third. When you train N-gram language models, you’re really just collecting N-gram counts. The first type of neural language model, which is the focus of our paper, is the traditional autoregressive language model, trained using cross entropy to predict the next word at a given time step given the previous words as context.
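As a toy illustration of the two training styles just described (not code from the paper), a bigram model, i.e. an N-gram model with N = 2, is trained purely by counting, while the autoregressive objective averages the negative log-probability of each word given its left context. The corpus is made up, the counts are unsmoothed (so an unseen bigram would fail), and in a neural language model the probability function would be the network’s softmax output instead.

```python
import math
from collections import Counter, defaultdict

def train_bigram(tokens):
    """'Training' a bigram (N = 2) language model is just counting."""
    counts = defaultdict(Counter)
    for prev, nxt in zip(tokens, tokens[1:]):
        counts[prev][nxt] += 1
    return counts

def bigram_prob(counts, prev, word):
    """Maximum-likelihood estimate of p(word | prev), with no smoothing."""
    total = sum(counts[prev].values())
    return counts[prev][word] / total if total else 0.0

def next_word_cross_entropy(prob_fn, tokens):
    """Autoregressive objective: mean negative log-probability of each token
    given the tokens before it. A neural LM would supply prob_fn instead."""
    nll = [-math.log(prob_fn(tokens[:t], tokens[t])) for t in range(1, len(tokens))]
    return sum(nll) / len(nll)

corpus = "the cat sat on the mat".split()
counts = train_bigram(corpus)

def bigram_prob_fn(context, word):
    # A bigram model only looks at the last word of the context.
    return bigram_prob(counts, context[-1], word)

# p(cat | the) = 1/2 and p(sat | cat) = 1, so the mean NLL is ln(2) / 2.
print(round(next_word_cross_entropy(bigram_prob_fn, ["the", "cat", "sat"]), 4))  # → 0.3466
```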
Then finally there’s BERT and its derivatives, which are really bi-directional language representation models. BERT uses a so-called masked language model objective, where a certain percentage of the input sequence is masked out and the objective is to predict the masked-out words. So you’re still using cross entropy, but the input and target pairs are different from those of autoregressive language models. And there might be minor variations on these. For example, you might use label smoothing in conjunction with cross entropy for autoregressive language modeling, which is pretty standard in, say, machine translation. And there are a bunch of variations for BERT, both on the training objective and on the auxiliary costs.
Just to go off on a bit of a tangent, I personally avoid calling BERT a language model, even though everyone does, since it can lead to confusing results if you think it’s directly comparable to autoregressive language models, which it isn’t.
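To show how the input and target pairs differ, the sketch below builds both kinds of training example from the same sentence. It is a simplification under stated assumptions: it uses the 15% masking rate from the original BERT recipe but omits BERT’s rule of sometimes substituting a random or unchanged token instead of `[MASK]`, and it works on whole words rather than subword ids. The helper names are illustrative.

```python
import random

MASK = "[MASK]"

def autoregressive_pairs(tokens):
    """(left context, next word) pairs: each token is predicted from its prefix."""
    return [(tokens[:t], tokens[t]) for t in range(1, len(tokens))]

def masked_lm_example(tokens, mask_rate=0.15, seed=0):
    """Simplified masked-LM example: hide a fraction of positions and predict
    them from the whole (bidirectional) input. BERT's 80/10/10 replacement
    rule is deliberately omitted; every chosen position becomes [MASK]."""
    rng = random.Random(seed)
    n_mask = max(1, round(mask_rate * len(tokens)))
    positions = sorted(rng.sample(range(len(tokens)), n_mask))
    inputs = list(tokens)
    for p in positions:
        inputs[p] = MASK
    targets = {p: tokens[p] for p in positions}
    return inputs, targets

tokens = "the museum opened in 1990".split()
print(autoregressive_pairs(tokens)[0])  # → (['the'], 'museum')
inputs, targets = masked_lm_example(tokens)
# One of the five words is replaced by [MASK]; its target is the original word
# at that position, predicted using the words on both sides of it.
print(inputs, targets)
```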
Source: Page 07 of Language Modelling via Learning to Rank
In one of the tables, we have a comparison between BERT and GPT-2. You look at the perplexity and think, well, the perplexity is really low for BERT and the accuracy is really high compared to GPT-2, so BERT should be a much stronger teacher than GPT-2. The problem is these values are not comparable at all. So it’s actually kind of wrong of us to put them in the same table, which we did for space.
But if you think of them both as language models, you would look at this comparison table and think that they’re comparable. Then the outcome of the paper would be very surprising to you when you find out that BERT doesn’t act as a good teacher whereas GPT-2 does. The problem is that BERT really isn’t a language model, at least not one comparable to autoregressive language models.
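The reason the perplexities are not comparable can be seen from the definition: perplexity is the exponential of the mean per-token negative log-likelihood. For an autoregressive model each term is log p(w_t | w_<t), so by the chain rule the terms add up to the log-probability of the whole sequence. A masked model’s per-token scores condition on words on both sides, so feeding them into the same formula (sometimes called pseudo-perplexity) does not give a sequence likelihood. A minimal sketch with made-up numbers:

```python
import math

def perplexity(token_log_probs):
    """exp of the mean per-token negative log-likelihood (natural logs)."""
    return math.exp(-sum(token_log_probs) / len(token_log_probs))

# Autoregressive case, with made-up values: each entry is log p(w_t | w_<t).
ar_log_probs = [math.log(0.25), math.log(0.5), math.log(0.125)]

# Chain rule: the sum is log p(w_1, w_2, w_3) = log(1/64), so the perplexity
# is the inverse geometric-mean probability, 64 ** (1/3).
print(round(perplexity(ar_log_probs), 3))  # → 4.0
```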
PAI: Is there anything that you would prefer transformer models to be called instead?
Arvid: Not really. Transformer models, I think, is okay. I like the name transformers; it’s fine. My problem is with BERT being called a language model. I think it leads to some confusion. The issue is that, really, 80% of it is a language model; it’s just that the other 20% is a big change.
PAI: Can you explain knowledge distillation to us?
Arvid: Sure. Knowledge distillation is a method for training models where you have a pre-trained teacher model and you want to distill the teacher’s performance into a new student model. In general, the use case is that the student model is smaller than the teacher model and hence more manageable as a production model. But because it’s smaller, it would be less capable if you didn’t help train it using the teacher’s knowledge. So really, knowledge distillation is a kind of training procedure. The way it’s done is that the teacher produces a distribution at every time step, and the student tries to match that distribution using a KL-divergence loss. So now the target for the student is the distribution over the classes, or words in the case of language modeling.
You can think of this as trying to match the teacher’s probabilities (or logit matching) instead of just trying to match a single ground-truth class. However, in our work we’re not using knowledge distillation for the general use case; we’re using it to solve an annotation problem. Basically, we need multiple potential words for every time step, the branching sets, but we don’t have these annotations. So we use a pre-trained language model in lieu of a human annotator to generate these branching sets for us.
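The KL objective described above can be sketched at a single time step in plain Python. This is a minimal illustration, not the paper’s implementation: batching over time steps and any mixing with the ordinary cross-entropy term are omitted, and the optional `temperature` argument is a common addition from the general distillation literature that the interview does not say the paper uses.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax with an optional temperature (an assumption, see lead-in)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    z = sum(exps)
    return [e / z for e in exps]

def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def kd_loss(teacher_logits, student_logits, temperature=1.0):
    """Word-level distillation at one time step: the student's target is the
    teacher's full distribution over the vocabulary, not a single gold word."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    return kl_divergence(p_teacher, p_student)

teacher = [2.0, 1.0, 0.1, -1.0]  # hypothetical teacher logits, 4-word vocabulary
student = [0.5, 0.4, 0.3, 0.2]   # hypothetical student logits
print(kd_loss(teacher, student))  # positive: the distributions differ
print(kd_loss(teacher, teacher))  # → 0.0
```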
PAI: What is rank based knowledge distillation?
Arvid: Rank based knowledge distillation is an intuitive extension of traditional knowledge distillation. Instead of matching a distribution, we convert that distribution into a ranked list from the most probable to the least probable words, and then we just try to match the order of that list. In particular, in our work we truncate the list to the top-k words and then try to match it using the Plackett-Luce rank loss, which is a structured ranking loss.
So just to take it back to the beginning for a little bit: with our original training loss, cross entropy, at a given time step we’re trying to predict a single ground-truth word. In traditional KL based knowledge distillation, the teacher generates a distribution over the entire vocabulary, which we’re trying to match. In rank based knowledge distillation, however, the teacher generates a top-k list, and we’re trying to match the ordering of that list.
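The rank based variant can be sketched the same way: the teacher’s scores are reduced to an ordered top-k list of word ids, and the student is trained to maximize the Plackett-Luce likelihood of that ordering under its own logits. This is the textbook top-k Plackett-Luce negative log-likelihood written as a plain loop, with illustrative names; the paper compares several Plackett-Luce variants and uses an efficient GPU implementation, neither of which is reproduced here. With k = 1 the loss reduces to ordinary cross entropy on the top-ranked word, which ties it back to the original training loss.

```python
import math

def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def teacher_top_k(teacher_scores, k):
    """Turn teacher scores (probabilities or logits) into the ids of its k
    highest-scoring words, best first. Only the order survives, not the values."""
    order = sorted(range(len(teacher_scores)), key=lambda i: teacher_scores[i], reverse=True)
    return order[:k]

def plackett_luce_nll(student_logits, ranking):
    """Top-k Plackett-Luce negative log-likelihood of `ranking` under the
    student's logits. At step j the j-th ranked word competes against every
    word not yet placed, including words outside the top-k."""
    remaining = set(range(len(student_logits)))
    nll = 0.0
    for word in ranking:
        nll -= student_logits[word] - logsumexp([student_logits[i] for i in remaining])
        remaining.remove(word)
    return nll

teacher_probs = [0.05, 0.50, 0.30, 0.15]    # hypothetical teacher distribution
ranking = teacher_top_k(teacher_probs, k=3)  # → [1, 2, 3]
good_student = [0.0, 3.0, 2.0, 1.0]          # orders words the same way as the teacher
bad_student = [0.0, 1.0, 2.0, 3.0]           # reverses the teacher's order
print(plackett_luce_nll(good_student, ranking) < plackett_luce_nll(bad_student, ranking))  # → True
```

Because only the order of `ranking` is used, the same loss applies to any teacher that can produce a ranked list, including one with no probabilities at all.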
PAI: Why is that important? What is it capturing that other methods are not?
Arvid: This is a really good question, and one that we don’t actually fully explore in the paper. It’s important for us because the N-gram algorithm that we use is incompatible with traditional methods, so we need it in order to use our N-gram method. But it also captures some information about the structure of the problem. What happened is that we have this annotation problem, where we’re trying to use a pre-trained language model as a substitute for having humans annotate these branching sets. However, this means we have to use a language model to train a language model, and that’s not a constraint we want.
So then we thought that N-gram language models would help us with this problem, which led to this N-gram algorithm. The problem is that N-gram language models produce a fairly weak distribution. If we just try to match that distribution using traditional knowledge distillation, we’re going to get really poor students. So we rethought what actually goes into an N-gram language model, stripped it to its bare bones, and arrived at this novel N-gram algorithm. The catch is that this novel N-gram algorithm is a non-probabilistic model: we don’t have probabilities that we can match using traditional knowledge distillation. So we have to use rank based knowledge distillation to be able to use that method at all. We’re using it out of necessity for our method.
However, it does capture things that other methods don’t, and we can see this in prior work. The first paper I know of that does rank based knowledge distillation is an AISTATS paper from last year, which applies it to actual ranking tasks, so the teacher is a ranking model and the student is a ranking model. They show that rank based distillation to the student does better than traditional KL based distillation when your task is actually a ranking problem. That makes a lot of sense, and in general, ranking is just a form of structured prediction. There are various kinds of structures, and lots of other papers do structured knowledge distillation, where they try to capture this structured information.
So to answer the question, there are really two reasons to do this: either you have a structured problem and you’re trying to capture that structured information, or you want to apply knowledge distillation to problems where you otherwise couldn’t use traditional knowledge distillation.
PAI: That’s very exciting. You also use an N-gram algorithm to create a non-probabilistic teacher. I think you touched upon this quite a bit. But is there anything else you want to add? And can you tell us about how this helps with training and why it’s efficient?
Arvid: Yeah. Really, it’s just getting us the annotations that we lack. Again, you could have a human annotate them for us, but that would be far too expensive and time-consuming. So we’re bootstrapping this information using the N-gram algorithm. We could also use a pre-trained language model to do this, but we don’t want to, because then we’d be using a trained language model to train a language model.
As for why it’s efficient, there are various ways to unpack what being efficient means. The algorithm we created is efficient, and it only needs to be run once, to create the N-gram branching sets before training our student model. We then incorporate this information into training our student language model using an efficient GPU implementation of Plackett-Luce. So it’s not costing us anything in terms of computation. But there’s actually a more interesting idea in terms of efficiency that we don’t really touch upon in the paper: the N-grams are collected from global statistics over the training data set.
In comparison, when you train using regular cross entropy, you’re only ever doing a local procedure, updating at each time step using only local information. But when you introduce this ranking information, either from the N-grams or from the pre-trained language models, you’re introducing global information about the training data set into your local training procedure. So there could actually be a training-efficiency gain there, but it’s not something we explore in the paper, though we do mention it.
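The interview does not spell out the N-gram algorithm, so the following is only a hypothetical illustration of the general idea and should not be read as the paper’s method. It builds an index, from global statistics over the training tokens, of which words followed each short context; for a given position it then ranks candidates by the length of the context they share (longest first, backing off to shorter contexts) and by frequency within a context. The function names, the `max_order` and `k` parameters, and the ranking rule are all assumptions.

```python
from collections import Counter, defaultdict

def build_continuation_index(tokens, max_order=3):
    """Global statistics: for every context of 1 .. max_order-1 words, count
    the words that followed it anywhere in the training tokens."""
    index = defaultdict(Counter)
    for t in range(1, len(tokens)):
        for n in range(1, max_order):
            if t - n < 0:
                break
            index[tuple(tokens[t - n:t])][tokens[t]] += 1
    return index

def branching_set(index, tokens, t, k=5, max_order=3):
    """Hypothetical rule-based ranking of candidate words for position t.
    Words seen after the longest matching context rank first; we then back off
    to shorter contexts. Within one context, more frequent words rank higher.
    Only an ordering is returned: there are no probabilities to match."""
    ranked = []
    for n in range(max_order - 1, 0, -1):
        if t - n < 0:
            continue
        for word, _ in index.get(tuple(tokens[t - n:t]), Counter()).most_common():
            if word not in ranked:
                ranked.append(word)
        if len(ranked) >= k:
            break
    return ranked[:k]

tokens = "the cat sat on the mat . the cat ran".split()
index = build_continuation_index(tokens)
print(branching_set(index, tokens, t=9))  # gold 'ran' → ['sat', 'ran']
print(branching_set(index, tokens, t=5))  # gold 'mat' → ['mat', 'cat']
```

In this toy rule, ties keep first-seen order, which is why ‘sat’ precedes the observed word ‘ran’ at position 9; how the ground-truth word should be placed in the list is a design choice the sketch does not address.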
PAI: Okay, awesome. And I guess you also don’t have to spend millions of dollars training these gigantic language models.
Arvid: Yeah, exactly. That’s also something we talk about as future work. One of the efficiencies there is that you can bootstrap yourself using these N-grams, which are really just a rule based system. So it could help in cases where you don’t have a lot of data: the rule based system produces extra targets you wouldn’t otherwise have, so you don’t need an annotator to annotate your otherwise small data sets.
PAI: You found that GPT-2 is a better teacher than BERT and also Born-Again models. Can you describe the difference between these three models and why GPT-2 makes for a better teacher?
Arvid: Sure. The Born-Again models are kind of a strange thing. Basically, they’re the case where your teacher model has the exact same model specification as your student. So it’s a kind of self-distillation, and the surprising thing is that people have found this self-distillation actually improves student performance. You train one model, then go back and use that same model as the teacher. That’s what a Born-Again model is. GPT-2 is an autoregressive language model, like the student models we’re trying to train. BERT, again, is not really a language model; like I said, it’s a bi-directional language representation model. The reason we chose these three models is a data leakage issue. GPT-2 and BERT are trained on a lot of extra data that the Born-Again models and our N-gram models are not trained on. The Born-Again models and the N-gram models are trained only on our actual training data set, whereas GPT-2 and BERT use a whole bunch of auxiliary data. That’s why we want to use the Born-Again model here: it allows for a better comparison with the N-gram model. But because of this self-distillation, the Born-Again models are not going to be very strong teachers. It’s the exact same model as the student, so you’d expect them to be much worse than GPT-2. We also wanted to use both GPT-2 and BERT because they are different types of models. BERT uses future information, so we thought it might produce better ranking ground truths. This turned out to be a problem, and it’s why BERT performed fairly badly.
Table 8: Example top-k from bert-wwm-24, gpt2-774, and N-grams using the Wiki02 training partition at block 554 with ground-truth text ‘Museum . The DVD released ranked No. 1’.
Source: Page 23 of Language Modelling via Learning to Rank
In the appendix, here’s a good example. The ground-truth sentence ends in ‘Museum’ followed by a period, so the sentence actually stops there. Because BERT sees future information, it knows that the sentence has ended, and it doesn’t try to continue it. It predicts an end of sentence or tries to start something like a new header, using the header marker of our dataset, whereas GPT-2 just tries to continue the sentence. That creates a mismatch with our student models, which are autoregressive and will try to continue the sentence, and with the downstream task we’re actually evaluating our students on. GPT-2 matches this task, so it doesn’t have this problem. There are other issues with BERT as well. It had an orthography issue, where it liked to match the spelling of a word instead of proposing words with a similar function. So where GPT-2 proposes similar words such as synonyms, BERT matches ‘the’ with a capitalized ‘The’, even though that would be very ungrammatical in the middle of a sentence. It’s matching how the word looks instead of the function of the word, which was unexpected. Similarly, for ‘in’, it proposes ‘Iny’ with a capital ‘I’, which is actually the name of an Egyptian goddess and only matches because it looks like ‘in’, whereas GPT-2 gets the grammar of the word. You can also see that the N-grams pick this up: they get the second-ranked word, which also matches GPT-2, so we don’t actually need a pre-trained teacher model to get it. And the same goes for numbers.
PAI: Would you expect GPT-3 as a teacher to give about the same results or better?
Arvid: It would probably do better than GPT-2, probably because it’s a better model in itself and because it’s trained on better data. But this kind of question is not really relevant to our work. In our work, we’re talking about how you can actually distill something, not what’s being distilled. So yes, our students might do better, but only because they’re distilling better information. That doesn’t tell us anything about the distillation process itself, which is what we’re actually experimenting on. But in general, if you want to apply our methods to a real application, yes, you’d probably want to pick the best teacher, which would probably be GPT-3.
PAI: Can you describe the experiments you ran to come to this conclusion? You already described quite a few of them. But is there anything else that you want to highlight?
Arvid: We have this really big table that’s kind of confusing. We didn’t break it up, for paper-space reasons, but it’s really trying to show two ideas.
Source: Page 08 of Language Modelling via Learning to Rank
So, going back to the motivation of the paper, we want a better training method for language models, and we think we can get one by making use of multiple targets per time step. That leaves us with two problems.
The first problem is how we’re going to get these multiple targets, which we do using either the pre-trained language models or our N-gram models. The second problem is how we actually incorporate these multiple targets into the training procedure, which we do using the Plackett-Luce rank loss.
So then we have two questions. How well do our pre-trained language models and N-grams do in comparison to each other, i.e., how good are the branching sets they produce? And how well does our rank based knowledge distillation, which uses the rank loss, do compared to traditional knowledge distillation and to traditional cross entropy? In the experimental setup, we compare everything to our cross entropy baseline, and basically everything, or at least all the cases we actually care about, does better than the cross entropy baseline.
Then there are multiple Plackett-Luce losses here. That’s a detail you can look up in the paper. In general, you can think of them as all being the same kind of Plackett-Luce loss. It might not be the case for every variant, but at least one version of Plackett-Luce will do better than cross entropy. We also find that knowledge distillation always does better than cross entropy, whether it’s rank based knowledge distillation or traditional KL based knowledge distillation.
And so then we have the other question: does rank based knowledge distillation do better than KL based knowledge distillation? This is actually a surprising result. We find that in the majority of cases it does do better than our traditional KL based knowledge distillation, but this wasn’t something that we actually expected to find.
I was expecting it to actually do worse. I wanted to quantify how much worse it was, so that when I applied it to the N-grams I could say, yes, you’re losing a little bit of information using rank based knowledge distillation, but not that much. Then we could actually do the comparison between the N-grams and the other teacher models.
So it was a happy accident that we found it actually did better than traditional KL based knowledge distillation, at least for our specific tasks. Now that we know we can use this rank based knowledge distillation, we can apply it to the N-grams and show that the N-grams perform better than the Born-Again models, better than or comparably to the BERT models, and not as well as GPT-2.
But again, this is really to be expected. The N-gram method is a very naive method. So you’d expect it to do worse than the very strong language model, GPT-2. And so the fact that it actually does as well as it does is the surprising result.
PAI: What do you find the most exciting about the findings you presented?
Arvid: The major finding is that rank based knowledge distillation might be applicable to tasks that aren’t usually considered ranking problems, and with teachers that aren’t really models. Our N-gram algorithm isn’t a model in the traditional sense; it’s a set of rules for generating targets. So the major contribution of this paper is that we’re taking the problem of language modeling and reframing it as a ranking task. When we combine this with the N-gram algorithm, what we’re really doing is distilling a rule based system. We have a rule based inductive bias that we think will help with training, and we use this rule based system to generate structured targets that augment or support our original training targets. You can then incorporate this informative set of rules into the model via the training procedure, without having to modify the model at all. I think this is a fairly general setup that might be applicable to a lot of problems outside of language modeling, and that’s probably what I’m most excited about: it’s a general framework for incorporating rule based systems into the training of neural models.
PAI: Do you have any future work planned on that aspect?
Arvid: Yes. One of the things I didn’t mention in the paper is that there’s a problem called the representation degeneration problem, which happens when you train a very skewed multi-class system using cross entropy. I think training language models via learning to rank might actually help with this issue when it appears in language modeling.
Since I’m talking to you, one of the other future uses I mentioned is black box knowledge distillation. Say you expose a model via an API that just gives you the top-k outputs. You could now do knowledge distillation from what would otherwise be a black box model. And I imagine that would raise some privacy concerns. Not that I know a lot about privacy, but it’s an interesting idea.