LLMs as classifiers
Feb 9, 2024
Lefteris Loukas, Ilias Stogiannidis, Odysseas Diamantopoulos, Prodromos Malakasiotis, Stavros Vassos | 2023 | Paper
When I’ve heard folks talking about AI strategy recently, a common trope has been that things are moving so fast that it is better to hold off on product investments until the pace of change slows and the stack starts to stabilise. Instead, the argument goes, we should be focussing on the low-hanging fruit: the productivity lifts from using chat assistants or GPTs. Another common argument is to focus on collecting data now and finetuning improved base models later.
I think these viewpoints are partly true but perhaps miss something worth reflecting on: how does the advent of highly capable LLMs like GPT-4 alter product discovery and experimentation around problems you might previously have solved with classical ML approaches? If you use an LLM as a fast and scrappy alternative to, say, finetuning, how much worse are your results, and at what point should you shift your approach to reduce your costs and improve your accuracy? And is collecting all that data always that valuable anyway?
If we can get this transition right as we move along the product maturity curve, we can maximise the value that we’re getting from LLMs at each point of the product lifecycle. This gives us all the benefits of starting today while making execution that bit easier and enabling us to keep our costs down.
Making LLMs worth every penny
A recent paper, “Making LLMs Worth Every Penny”, walks through this thought process by deep diving into one common problem: classifying support queries when you have a weak or incomplete dataset. If you’re trying to launch a scrappy product experiment, it’s unlikely that you have the time to collect, curate and label a dataset with the thousands of examples a classical natural language processing (NLP) technique needs. You could fill the gap manually until you had enough learning (or data) to justify the investment, but that just sounds like too much hard work.
To explore this problem, the paper tests the performance of a few different ML and AI approaches on Banking77, an open source dataset of about 10,000 short (~11 words) online banking queries. The dataset is made up of queries like this one: “What do I do if I have still not received my new card?”. Each query has a label, in this case card_arrival, and there are 77 different labels (or classes). The name of the game is to work out the intent behind each query: what does the user need?
To work out what our product maturity curve should look like, we can examine each approach the paper uses for classifying, then compare how accurate each method is at labelling support queries.
Finetuned MLM
Finetuning means taking a pretrained base language model and then adjusting all of its parameters to make it great at a specific task (a parameter is something like a weight or bias in a neural network layer). In this case, the team used all-mpnet-base-v2, a top-performing base model. MPNet unifies two language modelling approaches: permuted language modelling (PLM) and masked language modelling (MLM).
In a masked language modelling (MLM) approach, we blank out certain words in the sentence and then guess what the missing words are (“Let’s _ to the beach!”). The model gets better at this guessing task by learning about the context words around the blank, and it uses this knowledge to get better at inferring meaning in sentences. A permuted language modelling (PLM) approach pursues the same objective (understanding meaning) a little differently: we take a sentence (“Let’s go to the beach!”), pick one of the n! possible orders in which to visit its words (“to the beach! Let’s go”, “go to Let’s the beach!” …), and then predict each word in turn from the words that came before it in that shuffled order. The words keep their original positions; only the order of prediction is shuffled.
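To make the two objectives concrete, here is a toy sketch that builds one MLM training example and one PLM prediction order for the beach sentence. It works on whole words rather than real subword tokens, and the helper names (`mlm_example`, `plm_steps`) are hypothetical; real pretraining masks many positions at once across a huge corpus.

```python
import random

def mlm_example(tokens, mask_index, mask_token="_"):
    """Blank out one word and return (masked sentence, word to guess)."""
    masked = list(tokens)
    target = masked[mask_index]
    masked[mask_index] = mask_token
    return masked, target

def plm_steps(tokens, seed=0):
    """Sample one of the n! prediction orders.

    For each step, record the position to predict, the true word, and the
    words already visible (those earlier in the shuffled order). Words keep
    their original positions; only the prediction order changes.
    """
    rng = random.Random(seed)
    order = list(range(len(tokens)))
    rng.shuffle(order)
    steps = []
    for step, pos in enumerate(order):
        visible = {p: tokens[p] for p in order[:step]}
        steps.append((pos, tokens[pos], visible))
    return order, steps

sentence = ["Let's", "go", "to", "the", "beach!"]
masked, target = mlm_example(sentence, 1)
print(" ".join(masked), "->", target)  # Let's _ to the beach! -> go

order, steps = plm_steps(sentence, seed=42)
for pos, word, visible in steps:
    print(f"predict position {pos} ({word!r}) given {visible}")
```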
The MPNet model unifies MLM and PLM by taking the strengths of both approaches. It permutes the sentence (“to the beach! Let’s go”), splits it in two (“to the beach!”, “Let’s go”), and masks the words in the right-hand part (“to the beach! _ _”). It then predicts the masked words one after another, using the left-hand part as context and, as in PLM, the words it has already predicted. Unlike plain PLM, the model also sees the positions of the masked slots, so, much like an MLM, it always knows how long the full sentence is and where the gaps sit. This combination means the model gets very good at working out how missing words and word order affect meaning. Scale this approach up to 160GB of text and you get a very capable model that’s great for things like semantic search and creating sentence embeddings.
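A toy sketch of MPNet-style input construction follows. It is a simplification for illustration, not the MPNet implementation: `mpnet_style_example` is a hypothetical helper that works on whole words, and the attention machinery the real model uses to predict the masked words is not shown.

```python
import random

def mpnet_style_example(tokens, num_predicted, seed=0, mask_token="[MASK]"):
    """Permute positions, keep the left part as context, mask the right part."""
    rng = random.Random(seed)
    order = list(range(len(tokens)))
    rng.shuffle(order)                        # a random permutation of positions
    split = len(tokens) - num_predicted
    context_positions = order[:split]         # left part: visible context
    predicted_positions = order[split:]       # right part: to be predicted

    # Each entry keeps its ORIGINAL position, and the masked slots are
    # included, so the model always knows how long the sentence is.
    model_input = [(p, tokens[p]) for p in context_positions]
    model_input += [(p, mask_token) for p in predicted_positions]
    # Targets are predicted one after another in this permuted order.
    targets = [(p, tokens[p]) for p in predicted_positions]
    return model_input, targets

sentence = ["Let's", "go", "to", "the", "beach!"]
model_input, targets = mpnet_style_example(sentence, num_predicted=2, seed=1)
print("input:  ", model_input)
print("targets:", targets)
```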
You can take models like this and then finetune them on a specific task, in our case labelling financial support queries. This finetuning is basically a second prediction task: given our Banking77 dataset of customer support queries (“What do I do if I have still not received my new card?”), we train the model to predict the right label for each query (card_arrival here), nudging its parameters as it goes. We hold back a sample of labelled queries as a validation dataset and vary the training settings across runs to find the ones that give the best results (best in this case is judged using micro F1).
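Since micro F1 is the yardstick here (and macro F1 shows up in the results later), here is a small pure-Python sketch of both. Micro F1 pools true positives, false positives and false negatives across all labels before computing a single score; for single-label classification like this it works out equal to accuracy. Macro F1 computes a score per label and averages them, so rare labels count as much as common ones. The predictions in the usage lines are made up; in practice you would likely reach for scikit-learn's `f1_score` with `average="micro"` or `average="macro"`.

```python
from collections import Counter

def f1_scores(y_true, y_pred):
    """Return (micro_f1, macro_f1) for single-label multiclass predictions."""
    labels = set(y_true) | set(y_pred)
    tp, fp, fn = Counter(), Counter(), Counter()
    for truth, pred in zip(y_true, y_pred):
        if truth == pred:
            tp[truth] += 1
        else:
            fp[pred] += 1   # predicted label got a wrong hit
            fn[truth] += 1  # true label was missed

    def f1(tp_, fp_, fn_):
        # F1 = 2TP / (2TP + FP + FN), the harmonic mean of precision and recall
        denom = 2 * tp_ + fp_ + fn_
        return 2 * tp_ / denom if denom else 0.0

    # Micro: pool the counts across every label, then compute one F1.
    micro = f1(sum(tp.values()), sum(fp.values()), sum(fn.values()))
    # Macro: compute F1 per label, then take the unweighted mean.
    macro = sum(f1(tp[l], fp[l], fn[l]) for l in labels) / len(labels)
    return micro, macro

y_true = ["card_arrival", "card_arrival", "change_pin", "age_limit"]
y_pred = ["card_arrival", "change_pin", "change_pin", "age_limit"]
micro, macro = f1_scores(y_true, y_pred)
print(f"micro={micro:.3f} macro={macro:.3f}")  # micro=0.750 macro=0.778
```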
This type of finetuning is fairly intensive. You’ll (probably) get great results, but you’ll need a chunky amount of GPU time and a large labelled dataset to get there. Both of these are expensive in terms of wallet and effort.
Few shot contrastive learning with SetFit
Built by the team at Hugging Face, SetFit fine-tunes a Sentence Transformer model on a small number of labelled text pairs, built from just a handful of examples per class (normally 8 or 16; think of the card_arrival label we discussed before). SetFit uses a contrastive learning approach where positive and negative sentence pairs are constructed by in-class and out-of-class sampling. This means we pair up queries that share a label (“I’ve not received my new card” with another card_arrival query) as positive pairs, and queries with different labels (“I’ve not received my new card” with “How do I change my PIN?”) as negative pairs, then train the sentence transformer to place positive pairs close together and negative pairs far apart. The sentence embeddings (vectors, or arrays of numbers) produced by this fine-tuned model then get fed into a classification head, which learns the relationship between each embedding and its class label. At runtime, when we’re given lots of unseen data (“Erm, my cards still not arrived, what do I do?”), we pass each query through the sentence transformer, generate an embedding that captures the meaning of the sentence, then map that embedding to one of our class labels using the classification head. The advantage of this approach is that it’s super fast and cheap to train as the dataset is so small; Hugging Face claim a 28x speed-up in training time. Even better, it’s possible to train the model on a CPU if you are patient. There’s a quick start here if you want to get a feel for how simple this is in code.
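The pair-building step can be sketched in a few lines of plain Python. This is an illustration of in-class and out-of-class sampling, not SetFit's own code: `make_contrastive_pairs` is a hypothetical helper, and the setfit library does its own pair sampling inside its trainer before fitting the classification head (by default a simple logistic regression). Two of the four example queries are made up for illustration.

```python
import itertools
import random

def make_contrastive_pairs(examples, seed=0):
    """Build (text_a, text_b, target_similarity) pairs from labelled examples.

    Positive pairs (1.0) share a label; negative pairs (0.0) do not.
    For each positive pair we also draw one out-of-class partner.
    """
    rng = random.Random(seed)
    by_label = {}
    for text, label in examples:
        by_label.setdefault(label, []).append(text)

    pairs = []
    for label, texts in by_label.items():
        out_of_class = [t for other, ts in by_label.items() if other != label for t in ts]
        for a, b in itertools.combinations(texts, 2):
            pairs.append((a, b, 1.0))  # in-class: pull embeddings together
            if out_of_class:
                # out-of-class: push embeddings apart
                pairs.append((a, rng.choice(out_of_class), 0.0))
    return pairs

# Two examples per label; the second of each is made up for illustration.
few_shot = [
    ("I've not received my new card", "card_arrival"),
    ("When will my card arrive?", "card_arrival"),
    ("How do I change my PIN?", "change_pin"),
    ("I want to set a new PIN", "change_pin"),
]
for pair in make_contrastive_pairs(few_shot):
    print(pair)
```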
The downside is that it takes a little time to understand and groom your data into the right format, and you need to do a tiny bit of training before you get started. You also need to host your model somewhere and call it at inference time. This is straightforward, but maybe you just don’t want to write any Python to build a working prototype to show your team. That’s cool too.
In-context learning
Given all of this, an alternative approach is in-context learning. This involves prompting an LLM with a context and asking it to complete our classification task without any fine-tuning. The context usually includes a brief task description, some labelled examples (the “shots”), and the instance to be classified. Something a bit like this:
You are an expert assistant in the field of customer service.
Your task is to help workers in the customer service department of a company. Your task is to classify the customer's question in order to help the customer service worker to answer the question.
In order to help the worker, you MUST respond with the number and the name of one of the following classes you know. If you cannot answer the question, respond: "-1 Unknown".
In case you reply with something else, you will be penalized.
The classes are:
0 activate_my_card
1 age_limit
2 apple_pay_or_google_pay
3 atm_support
...
75 wrong_amount_of_cash_received
76 wrong_exchange_rate_for_cash_withdrawal
Here are some examples of questions and their classes:
How do I top-up while traveling? automatic_top_up
How do I set up auto top-up? automatic_top_up
...
It declined my transfer. declined_transfer
If you’re a fan of prompt engineering, note the threats, the capital letters and the specific instructions for handling the error condition. This is straightforward. You can refine your prompt in ChatGPT with no setup required, just a bit of domain knowledge and (in this case) 231 examples (3 for each of our 77 different labels).
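Once the prompt works in ChatGPT, assembling it programmatically keeps the label list and examples in one place. Here is a minimal sketch that produces a condensed version of the prompt above; `build_system_prompt` is a hypothetical helper, and the label list is truncated for illustration, so its class numbers don't match the real 77-label list. The customer’s query itself would be sent separately as the user message.

```python
def build_system_prompt(labels, examples):
    """Assemble a condensed version of the classification prompt above.

    labels:   label names in class-number order (0, 1, 2, ...).
    examples: (question, label) few-shot pairs to show the model.
    """
    lines = [
        "You are an expert assistant in the field of customer service.",
        "Your task is to classify the customer's question in order to help "
        "the customer service worker to answer the question.",
        "You MUST respond with the number and the name of one of the following "
        'classes. If you cannot answer the question, respond: "-1 Unknown".',
        "The classes are:",
    ]
    lines += [f"{i} {name}" for i, name in enumerate(labels)]
    lines.append("Here are some examples of questions and their classes:")
    lines += [f"{question} {label}" for question, label in examples]
    return "\n".join(lines)

# Truncated label list, for illustration only.
labels = ["activate_my_card", "age_limit", "automatic_top_up", "declined_transfer"]
examples = [
    ("How do I top-up while traveling?", "automatic_top_up"),
    ("How do I set up auto top-up?", "automatic_top_up"),
    ("It declined my transfer.", "declined_transfer"),
]
print(build_system_prompt(labels, examples))
```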
The results
The team found that using GPT-4 with 3 examples per label (“3-shot”), putting the prompt in the system context and using carefully selected representative examples gets you roughly equivalent performance to a 5-shot SetFit approach when working out query intent on the Banking77 dataset. If you’re lazy or time poor, the difference between 3-shot and 1-shot with GPT-4 is only about 2% for both micro and macro F1 🤯.
Finetuning leads the field (about 11-12% better for micro and macro F1 than 3-shot GPT-4). That said, you can get pretty close to finetuning-level performance with a 20-shot SetFit approach (about a 3% drop-off in performance, but with substantially less data and training time).
What this means is that if you want to quickly start classifying data like these support queries, all you have to do is come up with a set of labels (just under 80 in this case, and they can even overlap slightly), write 3 or so good quality representative examples for each label, then adopt something like the prompt given above, and you’ll (probably; this is only proven on one dataset and task, after all) get very good classification performance. That’s around a morning’s worth of focus time if you’re skilled in the domain.
What’s more, if you can write a prompt, you can do this. Stick it in the chat history initially and try it out, then get ChatGPT to write the tiny bit of boilerplate and give you the instructions you need to move the prompt to the system context, call the OpenAI API and host this as a lambda/cloud function. Persist each response you generate somewhere so you build up your example data over time, and you’ve set yourself up to quickly get to top-decile performance on the task. This feels incredible to me.
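If you do wire this up, two small pieces are worth getting right: parsing the model’s “number name” reply defensively, and persisting every answer so your labelled dataset grows as you go. Here is a minimal sketch of both, assuming JSON Lines as the storage format. `call_llm` is a hypothetical function standing in for your OpenAI (or other) API call, so no particular client library is assumed.

```python
import json
import re

REPLY_PATTERN = re.compile(r"\s*(-?\d+)\s+(\S+)")

def parse_reply(reply, labels):
    """Map a reply like '1 card_arrival' to (class_id, label).

    Anything that doesn't name a known class (including the prompt's
    '-1 Unknown' escape hatch) comes back as (-1, 'Unknown').
    """
    match = REPLY_PATTERN.match(reply)
    if match:
        class_id, name = int(match.group(1)), match.group(2)
        # Only accept replies where the number and the name agree.
        if 0 <= class_id < len(labels) and labels[class_id] == name:
            return class_id, name
    return -1, "Unknown"

def classify_and_log(query, call_llm, labels, log_path):
    """Classify one query and append the result to a JSON Lines file.

    call_llm is a hypothetical stand-in for whatever wraps your chat API
    call (system prompt + query in, reply text out).
    """
    class_id, name = parse_reply(call_llm(query), labels)
    record = {"query": query, "class_id": class_id, "label": name}
    with open(log_path, "a", encoding="utf-8") as log:
        log.write(json.dumps(record) + "\n")
    return class_id, name
```

Rejecting replies where the number and name disagree is a deliberately strict choice: those rows land as Unknown in the log, which makes them easy to find and relabel by hand later.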
Getting the costs down
There’s one catch: the paper shows that making roughly 3,000 queries with GPT-4 costs $740. I assume this is based on OpenAI’s October-ish ’23 pricing (there’s been a price drop since then), but still, eeeeeshh.
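To put that number in per-query terms, here’s a quick back-of-the-envelope calculation using the paper’s figures ($740 for roughly 3,000 queries, so the result is approximate). The 100,000-query volume is a hypothetical, just to show how quickly a linear per-query cost adds up; it ignores the price drops mentioned above.

```python
def cost_per_query(total_cost, num_queries):
    """Average spend per classified query."""
    return total_cost / num_queries

def projected_cost(total_cost, num_queries, expected_queries):
    """Scale a measured spend linearly to a different query volume."""
    return cost_per_query(total_cost, num_queries) * expected_queries

per_query = cost_per_query(740, 3_000)
print(f"${per_query:.3f} per query")  # $0.247 per query
print(f"${projected_cost(740, 3_000, 100_000):,.0f} per 100,000 queries")  # $24,667 per 100,000 queries
```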
The prompt we’re sending to GPT-4 is quite large. At inference time we’re sending a lot of unneeded context alongside each query; we could take a quick look at the incoming query and work out which of our labels and examples are the right fit before building the prompt. We’d also like to use some of this saved context to ramp up the number of examples we’re using for each label.
The team behind the paper say this is the point at which we should reach for Retrieval Augmented Generation (RAG). If you stick with GPT-4 and use RAG to select the 20 most similar examples for the prompt context, you can get a 4% boost in performance while cutting the cost by two-thirds. Switching out GPT-4 for Claude while sticking with the RAG approach decreases your costs further, to about $42, while only losing 2% of performance versus GPT-4. You do all this with the same 3-shot dataset (231 datapoints across our 77 labels). Not bad.
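The retrieval step is simple at heart: score every labelled example against the incoming query and keep the top k (20 in the paper). A real system would score with sentence embeddings (something like the MPNet model described earlier); to keep this sketch dependency-free, it swaps in a plain bag-of-words cosine similarity, which is a much cruder stand-in. `top_k_examples` is a hypothetical helper name, and the retrieved examples would replace the full example list in the prompt.

```python
import math
import re
from collections import Counter

def bag_of_words(text):
    """Lowercased word counts (a crude stand-in for a sentence embedding)."""
    return Counter(re.findall(r"[a-z']+", text.lower()))

def cosine(a, b):
    """Cosine similarity between two word-count vectors."""
    dot = sum(count * b[word] for word, count in a.items() if word in b)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

def top_k_examples(query, examples, k=20):
    """Return the k (text, label) examples most similar to the query."""
    q = bag_of_words(query)
    ranked = sorted(examples, key=lambda ex: cosine(q, bag_of_words(ex[0])), reverse=True)
    return ranked[:k]

examples = [
    ("How do I top-up while traveling?", "automatic_top_up"),
    ("How do I set up auto top-up?", "automatic_top_up"),
    ("It declined my transfer.", "declined_transfer"),
    ("What do I do if I have still not received my new card?", "card_arrival"),
]
for text, label in top_k_examples("my new card has still not arrived", examples, k=2):
    print(label, "|", text)
```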
For me though, I would be moving to the SetFit approach mentioned earlier. The costs are tiny (inference endpoint hosting on Hugging Face starts at about $0.06 an hour, but you could just roll your own and only pay per call), and if you’re persisting your responses as you go along you should have a much richer labelled dataset by this point. With a tiny bit more effort you get top-decile 20-shot SetFit performance and lower costs still; this feels worth the investment.
The product maturity curve
So we started out by asking a few questions: if you use an LLM as an alternative to finetuning, how much worse are your results, and at what point should you shift your approach to reduce your costs and improve your accuracy? How do we maximise the value that we’re getting from LLMs at each point of the product lifecycle?
The results in the paper walk us through how our approach should shift as we move along the product maturity curve. Prototype, experiment and validate with an LLM; then, once you’re ready to ramp up and need to increase accuracy and cut costs, move to a 10-, 15- or 20-shot SetFit model. At scale, move to finetuning if you need to crank out those last few percentage points of performance. You can move along this curve, at least for this use case, with far less data than you think, using the maturing state of your product to capture and improve the quality of your validation data as you go. This helps you execute much faster, today.
One last thought to leave you with. Google’s Gemini 1.5 Pro launched in the last couple of weeks with a staggering 1 million token context window (10 million in research 🤯). What makes this impressive is not just the size of the window but the recall across that context: in a Needle in a Haystack evaluation it retrieves a required fact accurately 99.7% of the time (in Google’s research, it maintains this performance in text mode up to a 7 million word context length). The technical report also outlines a task similar to the one we’ve just walked through: learning Kalamang.
With only instructional materials (500 pages of linguistic documentation, a dictionary, and ≈ 400 parallel sentences) all provided in context, Gemini 1.5 Pro is capable of learning to translate from English to Kalamang, a language spoken by fewer than 200 speakers in western New Guinea in the east of Indonesian Papua, and therefore almost no online presence. Moreover, we find that the quality of its translations is comparable to that of a person who has learned from the same materials.
This is quite similar to the use case we’ve been working through here: learning a new domain in depth while data poor. So I wonder: if you can dump the fullest labelled dataset you have at any point in time into the prompt, then pass the model unseen data to label, how would that change the results? At a guess, I’d say the gap between the in-context approach and finetuning is likely to keep narrowing over a near-term (1-2 year) horizon. Interesting times.
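As a rough sense check on whether that’s even feasible, here’s a back-of-the-envelope estimate of how many tokens the whole Banking77 dataset (about 10,000 queries of ~11 words each) would take up in a prompt. It assumes the common rule of thumb that one token is roughly 0.75 English words, plus an assumed allowance of 3 tokens per line for the label name; real counts depend on the tokenizer.

```python
def estimate_prompt_tokens(num_examples, words_per_example,
                           words_per_token=0.75, label_tokens_per_example=3):
    """Rough token count for a prompt listing every labelled example.

    words_per_token=0.75 is a rule of thumb for English text;
    label_tokens_per_example is an assumed allowance for each label name.
    """
    text_tokens = num_examples * words_per_example / words_per_token
    label_tokens = num_examples * label_tokens_per_example
    return int(text_tokens + label_tokens)

tokens = estimate_prompt_tokens(10_000, 11)
print(f"~{tokens:,} tokens")  # ~176,666 tokens
print(tokens < 1_000_000)     # True: comfortably inside a 1M-token window
```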