Model Finetuning for Fun and Profit

September 18, 2018 / Data Science, Developers, Machine Learning, Text Data Use Case, Tutorials

In our last blog post, “Effective Transfer Learning for NLP,” we walked through the technological advancements that have made model finetuning practical for natural language processing. In this post, we’ll dive into the details and look at how you can start using Indico’s Python library, finetune, to try out model finetuning on your own tasks.

Example Usage

First, let’s take a brief peek at what training a model in finetune looks like.

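from finetune import Classifier
# text: a list of training strings; labels: one class label per string
model = Classifier()
model.fit(text, labels)
# test_text: a list of new strings to classify
predictions = model.predict(test_text)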

If you’re familiar with the Python machine learning library scikit-learn, working with finetune will be a breeze: we’ve designed finetune to mimic scikit-learn’s API to make it easy to get up and running.


Installation

To start using finetune on your own data, however, you’ll first need a working installation. We offer three main installation options. If you already have working copies of CUDA 9.0 and cuDNN 7, finetune is just a pip install away:

pip install finetune

If you don’t have a local copy of CUDA, or your CUDA version is out of date, your easiest option is to run the Docker container that ships with finetune.

git clone -b master https://github.com/IndicoDataSolutions/finetune 
cd finetune
./docker/build_docker.sh      # builds docker image
./docker/start_docker.sh      # starts finetune docker container
docker exec -it finetune bash # starts bash session in the docker container

Finally, if you’d prefer to install finetune on your host machine and stay up to date with the development branch of finetune, you can install from source:

git clone https://github.com/IndicoDataSolutions/finetune
cd finetune
pip install tensorflow-gpu --upgrade
python3 setup.py develop

Overview + Documentation

Although the base finetune model was trained on a language modeling objective, a small architectural addition to the end of the model lets us repurpose the language model to solve a variety of different tasks. The paper on which finetune is based, “Improving Language Understanding by Generative Pre-Training” by Radford et al., demonstrated that this approach works well for tasks like classification, entailment, similarity, and multiple-choice question answering.

[Figure: Transformer model architecture diagram]

We’ve also added support for multi-label classification, sequence labeling tasks like named-entity recognition, and multi-field text inputs, and packaged everything up into a scikit-learn-style interface. The documentation for these model types is available at finetune.indico.io.


Example Usage: Airline Comment Classification

Now that you have a working copy of finetune, let’s take the library for a test drive on a real-world comment classification task. We’ll use pandas throughout these examples; it’s a handy library for data manipulation in Python. Below is some brief code for inspecting our dataset.

import pandas as pd

# Dataset available for download at
# http://s3.amazonaws.com/enso-data/AirlineNegativity.csv
df = pd.read_csv("AirlineNegativity.csv")
 
print(df[:10])
print(df.Target.value_counts())

This should give us a peek at the first 10 rows of the dataset we’ll be working with and an overview of the class balance of the data.

                                                Text                  Target
0  @VirginAmerica it's really aggressive to blast...              Bad Flight
1  @VirginAmerica and it's a really big bad thing...              Can't Tell
2  @VirginAmerica seriously would pay $30 a fligh...              Can't Tell
3      @VirginAmerica SFO-PDX schedule is still MIA.             Late Flight
4  @VirginAmerica  I flew from NYC to SFO last we...              Bad Flight
5  @VirginAmerica why are your first fares in May...              Can't Tell
6  @VirginAmerica you guys messed up my seating.....  Customer Service Issue
7  @VirginAmerica status match program.  I applie...  Customer Service Issue
8  @VirginAmerica What happened 2 ur vegan food o...              Can't Tell
9  @VirginAmerica amazing to me that we can't get...              Bad Flight
Customer Service Issue         2910
Late Flight                    1665
Can't Tell                     1190
Cancelled Flight                847
Lost Luggage                    724
Bad Flight                      580
Flight Booking Problems         529
Flight Attendant Complaints     481
longlines                       178
Damaged Luggage                  74
Name: Target, dtype: int64
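To put numbers on that imbalance, here is a small back-of-the-envelope calculation that uses only the counts printed above (no finetune required). After dropping the “Can’t Tell” class, the largest class is roughly 39 times the size of the smallest, and holding out everything except a 500-example training sample leaves 7,488 comments, which matches the support totals in the evaluation reports further down.

```python
# Class counts copied from the value_counts() output above.
counts = {
    "Customer Service Issue": 2910,
    "Late Flight": 1665,
    "Can't Tell": 1190,
    "Cancelled Flight": 847,
    "Lost Luggage": 724,
    "Bad Flight": 580,
    "Flight Booking Problems": 529,
    "Flight Attendant Complaints": 481,
    "longlines": 178,
    "Damaged Luggage": 74,
}

# Drop the "Can't Tell" class, as the training code below does.
usable = {label: n for label, n in counts.items() if label != "Can't Tell"}
total = sum(usable.values())
train_size = 500
test_size = total - train_size  # everything not sampled for training is held out
imbalance_ratio = max(usable.values()) / min(usable.values())

# Share of the usable data that each class represents.
for label, n in usable.items():
    print(f"{label:<28} {n:>5} {n / total:6.1%}")
print(total, test_size, round(imbalance_ratio, 1))  # → 7988 7488 39.3
```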

Now that we have a better idea of what we’re working with, let’s train a model on a subset of the data. First, we’ll drop the examples where the topic of the comment isn’t known (the “Can’t Tell” class). Because the classes are imbalanced, we’ll ask finetune to oversample rare classes, and we’ll train on a random sample of 500 comments. After training, we’ll evaluate on the remaining held-out comments to see how well the model performs.

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from finetune import Classifier

# Drop comments whose topic is unknown
df = df[df.Target != "Can't Tell"]
# Train on a random sample of 500 comments; hold out the rest for evaluation
train_df, test_df = train_test_split(df, train_size=500)

# oversample=True asks finetune to oversample rare classes during training
model = Classifier(oversample=True)
model.fit(train_df.Text.values, train_df.Target.values)
model.save('airline.model')
predictions = model.predict(test_df.Text.values)

print(classification_report(test_df.Target.values, predictions))
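Passing oversample=True is all finetune needs; the library handles the details internally. As a rough illustration of the general idea only (this is not finetune’s implementation), plain random oversampling duplicates randomly chosen minority-class examples until every class is as large as the biggest one. The helper random_oversample below is hypothetical.

```python
import random
from collections import defaultdict


def random_oversample(texts, labels, seed=0):
    """Duplicate randomly chosen examples of every smaller class until each
    class has as many examples as the largest class."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for text, label in zip(texts, labels):
        by_label[label].append(text)

    target = max(len(items) for items in by_label.values())
    pairs = []
    for label, items in by_label.items():
        # Sample (with replacement) the extra copies this class needs.
        extra = [rng.choice(items) for _ in range(target - len(items))]
        pairs.extend((text, label) for text in items + extra)

    rng.shuffle(pairs)  # interleave classes instead of leaving them grouped
    return [text for text, _ in pairs], [label for _, label in pairs]


texts = ["late again", "still waiting", "delayed 3 hours", "bag never arrived"]
labels = ["Late Flight", "Late Flight", "Late Flight", "Lost Luggage"]
new_texts, new_labels = random_oversample(texts, labels)
print(len(new_texts))  # → 6
```

Oversampling only the training split, as in the code above, matters: duplicating examples before splitting would leak copies of training comments into the evaluation set.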

Finetune will record training progress in your terminal as you wait.

Epoch 0: 34%|███████████▌ | 27/79 [00:40<01:17, 1.49s/it]

Once training is done, you’ll see a nicely formatted table of precision, recall, and F1-scores thanks to scikit-learn’s classification_report function.

                             precision    recall  f1-score   support

                 Bad Flight       0.41      0.27      0.33       549
           Cancelled Flight       0.75      0.69      0.72       796
     Customer Service Issue       0.72      0.79      0.75      2744
            Damaged Luggage       0.00      0.00      0.00        70
Flight Attendant Complaints       0.43      0.21      0.28       455
    Flight Booking Problems       0.34      0.48      0.40       499
                Late Flight       0.70      0.75      0.72      1534
               Lost Luggage       0.67      0.77      0.72       676
                  longlines       0.04      0.01      0.01       165

                avg / total       0.63      0.65      0.63      7488
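The avg / total row is a support-weighted average, so large classes such as Customer Service Issue dominate it while rare ones like Damaged Luggage barely register. Recomputing from the per-class rows (which are themselves rounded, so the results are approximate) shows how much lower an unweighted macro average is:

```python
# Per-class F1 scores and supports copied from the finetune report above.
f1 = [0.33, 0.72, 0.75, 0.00, 0.28, 0.40, 0.72, 0.72, 0.01]
support = [549, 796, 2744, 70, 455, 499, 1534, 676, 165]

# Support-weighted mean: what the "avg / total" row reports.
weighted_f1 = sum(f * s for f, s in zip(f1, support)) / sum(support)
# Unweighted (macro) mean: every class counts equally, however rare.
macro_f1 = sum(f1) / len(f1)

print(round(weighted_f1, 2), round(macro_f1, 2))  # → 0.63 0.44
```

Newer scikit-learn releases print separate “macro avg” and “weighted avg” rows, and f1_score(y_true, y_pred, average="macro") computes the macro figure directly.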

Let’s compare this to two baseline models. First up, a cross-validated logistic regression model on top of TF-IDF features:

                             precision    recall  f1-score   support

                 Bad Flight       0.37      0.31      0.33       538
           Cancelled Flight       0.73      0.67      0.70       805
     Customer Service Issue       0.63      0.71      0.67      2723
            Damaged Luggage       0.00      0.00      0.00        72
Flight Attendant Complaints       0.24      0.12      0.16       457
    Flight Booking Problems       0.32      0.30      0.31       494
                Late Flight       0.57      0.68      0.62      1556
               Lost Luggage       0.60      0.57      0.59       675
                  longlines       0.17      0.05      0.08       168

                avg / total       0.55      0.57      0.55      7488
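The post doesn’t show the baseline code. A minimal sketch of a comparable TF-IDF baseline, assuming scikit-learn’s TfidfVectorizer and LogisticRegressionCV with default settings (the original configuration isn’t specified), might look like this; build_tfidf_baseline and the toy data are illustrative only.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegressionCV
from sklearn.pipeline import make_pipeline


def build_tfidf_baseline(cv=5):
    """TF-IDF features fed to a logistic regression whose regularization
    strength is chosen by cross-validation on the training data."""
    return make_pipeline(
        TfidfVectorizer(),  # default tokenization and weighting (an assumption)
        LogisticRegressionCV(cv=cv, max_iter=1000),
    )


# Tiny illustrative dataset; in the post this would be train_df / test_df.
train_texts = [
    "flight delayed two hours", "still waiting at the gate", "late departure again",
    "delayed flight missed connection", "plane is late", "gate delay no update",
    "bag never arrived", "lost my luggage", "where is my suitcase",
    "luggage missing after flight", "lost bag claim", "suitcase lost in transit",
]
train_labels = ["Late Flight"] * 6 + ["Lost Luggage"] * 6

baseline = build_tfidf_baseline(cv=3)
baseline.fit(train_texts, train_labels)
print(baseline.predict(["my suitcase was lost"]))
```

To compare against finetune, fit the pipeline on train_df.Text.values and train_df.Target.values and pass its predictions for test_df.Text.values to classification_report, just as in the finetune example.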

And second, a cross-validated logistic regression model trained on the mean of each comment’s word vectors:

                             precision    recall  f1-score   support

                 Bad Flight       0.22      0.22      0.22       546
           Cancelled Flight       0.48      0.63      0.54       778
     Customer Service Issue       0.66      0.55      0.60      2727
            Damaged Luggage       0.06      0.10      0.08        68
Flight Attendant Complaints       0.21      0.26      0.23       450
    Flight Booking Problems       0.27      0.34      0.30       505
                Late Flight       0.54      0.53      0.53      1553
               Lost Luggage       0.55      0.61      0.58       689
                  longlines       0.09      0.02      0.04       172

                avg / total       0.50      0.49      0.49      7488
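As with the TF-IDF baseline, the word-vector baseline’s code isn’t included in the post, and the post doesn’t say which pretrained vectors were used. A minimal sketch, assuming whitespace tokenization and a pretrained embedding lookup loaded into a dict (for example, word2vec or GloVe vectors), averages each comment’s word vectors into one fixed-length feature vector. The helpers mean_word_vector and featurize and the toy embeddings are hypothetical.

```python
import numpy as np


def mean_word_vector(text, embeddings, dim):
    """Average the vectors of the in-vocabulary tokens in `text`.
    Returns a zero vector when none of the tokens are known."""
    vectors = [embeddings[token] for token in text.lower().split() if token in embeddings]
    if not vectors:
        return np.zeros(dim)
    return np.mean(vectors, axis=0)


def featurize(texts, embeddings, dim):
    """Stack one averaged vector per text into an (n_texts, dim) matrix."""
    return np.vstack([mean_word_vector(text, embeddings, dim) for text in texts])


# Toy 2-dimensional "embeddings" standing in for real pretrained vectors.
toy_embeddings = {
    "late": np.array([1.0, 0.0]),
    "flight": np.array([0.0, 1.0]),
}
features = featurize(["Late flight again", "no known words"], toy_embeddings, dim=2)
# First row averages "late" and "flight"; second row is all zeros.
print(features)
```

The resulting matrix can be passed to LogisticRegressionCV in place of the TF-IDF features.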

Even with only 500 training examples, finetuning measurably improves on these common baselines: its weighted-average F1 of 0.63 compares with 0.55 for TF-IDF and 0.49 for averaged word vectors, and we expect the gap to grow as more training data becomes available. Importantly, minimal configuration was required; we’re striving to make finetune “just work” whenever possible.

Example Usage: Named Entity Recognition

Part of the magic of model finetuning is its versatility and resilience to changes in task type. Let’s take a look at an entirely different natural language processing problem, named entity recognition. For this example, we’ll be working with a subset of the Reuters corpus. Let’s start by taking a peek at this new dataset.

                                               texts                                        annotations
0  Paxar Corp said it has acquired Thermo-Print G...  [{"start": 0, "end": 10, "label": "Named Entit...
1  Key Tronic corp said it has received contracts...  [{"start": 0, "end": 15, "label": "Named Entit...
2  Canadian Bashaw Leduc Oil and Gas Ltd said it ...  [{"start": 9, "end": 37, "label": "Named Entit...
3  Entourage International Inc said it had a firs...  [{"start": 0, "end": 27, "label": "Named Entit...
4  Digital Communications Associates Inc said its...  [{"start": 0, "end": 37, "label": "Named Entit...
5  Teradyne Inc said Digital Equipment Corp signe...  [{"start": 0, "end": 12, "label": "Named Entit...
6  Home Intensive Care Inc said it has opened a D...  [{"start": 0, "end": 23, "label": "Named Entit...
7  Bache Securities Inc, 80 pct owned by Prudenti...  [{"start": 0, "end": 20, "label": "Named Entit...
8  Roughly half of this years expected 130,000 he...  [{"start": 116, "end": 127, "label": "Named En...
9  Demand for shares in state-owned engine maker ...  [{"start": 46, "end": 61, "label": "Named Enti...

You’ll notice that the target format has changed from our previous dataset. Because we’re annotating named entities, we now need to provide the location in the text where each named entity appears. Targets for SequenceLabeler models in finetune take the format [{'start': start_char_location, 'end': end_char_location, 'text': phrase, 'label': label}, {...}], with one such list per document.
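Character offsets are easy to get wrong by one, so it can help to build them from the phrase itself and check that slicing the text recovers it. The sketch below uses the visible beginning of row 2 from the preview above, whose first annotation runs from character 9 to 37; make_annotation is a hypothetical helper, not part of finetune.

```python
def make_annotation(text, phrase, label, start_search=0):
    """Build one SequenceLabeler-style target dict for the first occurrence
    of `phrase` in `text` at or after `start_search`."""
    start = text.find(phrase, start_search)
    if start == -1:
        raise ValueError(f"{phrase!r} not found in text")
    end = start + len(phrase)  # end offset is exclusive, like a Python slice
    return {"start": start, "end": end, "text": phrase, "label": label}


text = "Canadian Bashaw Leduc Oil and Gas Ltd said it"
annotation = make_annotation(text, "Bashaw Leduc Oil and Gas Ltd", "Named Entity")
# Sanity check: the offsets slice back to exactly the annotated phrase.
assert text[annotation["start"]:annotation["end"]] == annotation["text"]
print(annotation)
# → {'start': 9, 'end': 37, 'text': 'Bashaw Leduc Oil and Gas Ltd', 'label': 'Named Entity'}
```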

The subset of the Reuters dataset we’re using contains labeled examples. We’ll set aside 30% of these as a test set and train on the remaining 70%.

import json
import pandas as pd
from sklearn.model_selection import train_test_split
from finetune import SequenceLabeler
from finetune.metrics import sequence_labeling_token_precision, sequence_labeling_token_recall

# available for download at https://s3.amazonaws.com/enso-data/Reuters.csv
dataset = pd.read_csv('Reuters.csv')

# annotations were stored as json strings when writing to CSV
dataset['annotations'] = [json.loads(annotation) for annotation in dataset['annotations']]

trainX, testX, trainY, testY = train_test_split(
    dataset.texts.values, 
    dataset.annotations.values, 
    test_size=0.3,
    random_state=42
)

model = SequenceLabeler(batch_size=2, val_size=0.)
model.fit(trainX, trainY)
predictions = model.predict(testX)
print("Precision: {}".format(sequence_labeling_token_precision(testY, predictions)))
print("Recall: {}".format(sequence_labeling_token_recall(testY, predictions)))
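“Token-level” means a prediction earns credit per token rather than per whole entity, so a partially correct span still gets partial credit. The sketch below is a simplified, single-document illustration using whitespace tokens; finetune’s sequence_labeling_token_precision and sequence_labeling_token_recall may tokenize and aggregate differently (their output above reports one score per label), so treat token_labels and token_precision_recall as hypothetical helpers.

```python
import re


def token_labels(text, annotations):
    """Map each whitespace-delimited token, keyed by its (start, end) character
    offsets, to the label of the first annotated span that overlaps it."""
    labels = {}
    for match in re.finditer(r"\S+", text):
        for ann in annotations:
            if match.start() < ann["end"] and ann["start"] < match.end():
                labels[(match.start(), match.end())] = ann["label"]
                break
    return labels


def token_precision_recall(text, true_annotations, predicted_annotations):
    """Token-level precision and recall for a single document."""
    true = set(token_labels(text, true_annotations).items())
    pred = set(token_labels(text, predicted_annotations).items())
    correct = len(true & pred)
    precision = correct / len(pred) if pred else 0.0
    recall = correct / len(true) if true else 0.0
    return precision, recall


text = "Canadian Bashaw Leduc Oil and Gas Ltd said it"
gold = [{"start": 9, "end": 37, "label": "Named Entity"}]       # 6 tokens
predicted = [{"start": 0, "end": 25, "label": "Named Entity"}]  # 4 tokens, 3 correct
print(token_precision_recall(text, gold, predicted))  # → (0.75, 0.5)
```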

With fewer than 100 labeled training examples, finetune exceeds 80% token-level precision and recall, a strong indication that the features learned by the base language model are already well suited to tasks like entity extraction.

Precision: {'Named Entity': 0.8113496932515337}                                 
Recall: {'Named Entity': 0.8252730109204368}

Example Usage: Text Generation

One of the more unusual and entertaining features of finetune is more a happy accident than an intentional design decision. Because the base model is a language model (i.e., trained to predict the next word in a sequence), we can take any trained finetune model and investigate what it has learned by asking it to generate text. For this experiment, we’ll use the MetroLyrics dataset from Kaggle and train a model to predict a song’s artist from its lyrics.

Let’s take a look at our training code. We’ll use a few more configuration options for this model. The parameter lr_warmup specifies that the model should begin with a low learning rate and gradually increase it during the early part of training. The parameter lm_loss_coef stands for “Language Model Loss Coefficient” and indicates that the model should be penalized not only for failing to predict the correct author, but also for failing to predict the next word in the sequence. By learning to better model which word comes next, the model will hopefully also learn to better predict the author of each song.

from collections import Counter
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from finetune import Classifier
import pandas as pd
import numpy as np

df = shuffle(pd.read_csv("lyrics.csv"))
df = df.dropna(subset=['genre', 'lyrics', 'artist'])

# Filter down to the top 30 most prolific hip hop artists
df = df[df.genre == "Hip-Hop"]
prolific_artists = [artist for artist, _ in Counter(df.artist.values).most_common(30)]
df = shuffle(df[df.artist.isin(prolific_artists)])

train_df, test_df = train_test_split(df, test_size=0.1)
test_df.to_csv('test.csv')

model = Classifier(
    oversample=True,
    n_epochs=5,
    lr_warmup=0.1, 
    tensorboard_folder='.lyrics-tensorboard', 
    lm_loss_coef=0.5
)
model.fit(train_df.lyrics.values, train_df.artist.values)
model.save('lyrics.model')
predictions = model.predict(test_df.lyrics.values)

print(classification_report(test_df.artist.values, predictions))
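The two new options in this snippet are easier to picture with numbers. The sketch below is a generic illustration, not finetune’s internal code: learning_rate shows a linear warmup over the first 10% of steps (mirroring lr_warmup=0.1 if that value is read as a fraction of training, which is an assumption) followed by an assumed linear decay, and joint_loss shows how a coefficient like lm_loss_coef typically weights the language-model loss against the classification loss. Both helpers are hypothetical.

```python
def learning_rate(step, total_steps, base_lr, warmup_frac=0.1):
    """Linear warmup up to base_lr over the first warmup_frac of training,
    then (an assumption) linear decay back toward zero."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    remaining = max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / remaining)


def joint_loss(classifier_loss, lm_loss, lm_loss_coef=0.5):
    """Classification loss plus a weighted next-word (language model) loss."""
    return classifier_loss + lm_loss_coef * lm_loss


schedule = [learning_rate(step, total_steps=100, base_lr=1.0) for step in range(100)]
print(schedule[0], schedule[9], schedule[55])  # → 0.1 1.0 0.5
# With lm_loss_coef=0.5, a next-word loss of 3.0 adds 1.5 to the total.
print(joint_loss(1.2, 3.0))
```

Setting the coefficient to zero would recover a pure classification objective; larger values push the model harder toward modeling the lyrics themselves, which is what makes the text generation below possible.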

After 5 epochs of training, we have some initial results:

                      precision    recall  f1-score   support                   

                2pac       0.78      0.93      0.84        41
             50-cent       0.78      0.86      0.82        42
            ace-hood       0.89      0.53      0.67        15
     andre-nickatina       0.90      0.90      0.90        20
          atmosphere       0.72      0.72      0.72        18
               b-o-b       0.45      0.50      0.48        10
        beastie-boys       0.74      0.87      0.80        23
            big-sean       0.82      0.82      0.82        22
          bizzy-bone       0.89      0.74      0.81        23
     black-eyed-peas       0.92      0.75      0.83        16
bone-thugs-n-harmony       0.84      0.84      0.84        32
             bow-wow       0.79      0.71      0.75        21
        busta-rhymes       0.86      0.86      0.86        28
             cam-ron       0.78      0.84      0.81        25
      chamillionaire       0.88      0.92      0.90        39
         chris-brown       0.62      0.82      0.71        50
              common       0.83      0.79      0.81        19
        cypress-hill       0.90      0.86      0.88        22
        daddy-yankee       1.00      1.00      1.00        18
                 dmx       0.86      0.64      0.73        28
               drake       0.76      0.67      0.72        43
                e-40       0.89      0.81      0.85        21
              eminem       0.77      0.77      0.77        57
               esham       0.91      0.91      0.91        22
            fabolous       0.81      0.63      0.71        27
             fat-joe       0.80      0.94      0.86        17
              future       0.76      0.76      0.76        21
                game       0.84      0.79      0.82        39
           gangstarr       1.00      0.94      0.97        18
    ghostface-killah       0.88      0.88      0.88        26

         avg / total       0.82      0.81      0.81       803

More interesting than the numbers, however, is this side effect of model training. By jointly learning to predict the next word in song lyrics and classify the author, we’ve unintentionally trained our model to generate novel song lyrics. We can demonstrate this by calling model.generate_text().

from finetune import Classifier

model = Classifier.load('lyrics.model')
# controls tradeoff between greedy and random sampling
model.config.lm_temp = 0.7 
print(model.generate_text(max_length=100))
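The lm_temp setting is a sampling temperature. The sketch below shows the standard technique, assuming finetune follows it: next-token scores are divided by the temperature before the softmax, so low temperatures concentrate probability on the top-scoring token (approaching greedy decoding) while higher temperatures spread it out (more random sampling). The helpers and the logits are made up for illustration.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Turn raw next-token scores into probabilities, sharpened (T < 1) or
    flattened (T > 1) by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next_token(logits, temperature, rng=random):
    """Draw one token index from the temperature-adjusted distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.1]  # made-up scores for a 3-token vocabulary
for temperature in (0.1, 0.7, 1.0):
    probs = softmax_with_temperature(logits, temperature)
    print(temperature, [round(p, 3) for p in probs])
print(sample_next_token(logits, 0.7, random.Random(0)))
```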

Below are some of the more entertaining (and reasonably clean) lyrics our model generated. I’ve limited the generated lyrics to 100 tokens, done a small amount of post-processing to clean up whitespace, and seeded the generated text with the first word, but these samples are otherwise entirely generated by the model.

Most of them riff off of a single consistent theme:

who would of thought 
that you could turn your back on us 
what's the purpose of being free? 
if you're stuck in the cycle of a slave? 
you're a slave for life
but you chose to be a slave to be free 
and i got no doubt 
dreams of freedom are realer than a dream of riches 
but i'm poison with the words that i've spoken 
i'm a slave to this life 
but i'm still free to be

…and some of them even contain elements of rhyme.

who's that? (who's that?) 
me and my brothers have been running this city since it was first recorded 
ya'll y'all n***** know the deal 
lil' mama, you know the deal 
that i signed to get my pro, had to do it for real 
and i'm still working on my self esteem, i'm a star 
gotta keep my friends close, they hug me like i'm a stranger

With only a few thousand examples, we’ve taken a model trained on books and adapted it to produce song lyrics — nonsensical song lyrics, but song lyrics nonetheless.

who told you that you couldn't dance?
oh, all the women around here like you
and all the men know how to party like you
we see your man in the back and we know what to do
we got a party going on so we always go
come on girl, let me introduce you
to the good ol' boys
we call him mack, and he's super good
i hit him with all my love
that's why i'm here tonight

In Conclusion

We’ve highlighted just a few of the ways you can start applying finetune to your own tasks, but finetune also supports many use cases we couldn’t cover in the span of a single blog post.

Check out finetune’s source code and documentation for advanced usage and configuration options. Have a use case you’d like to see supported that isn’t currently covered? Raise an issue on the finetune GitHub repo or submit a pull request. Finetune is still in alpha, so check back for frequent updates and new features.

Interested in becoming a regular contributor? Send me a message at madison@indico.io — we’d love to have your help.

Until next time,

–Madison
