Posted by Ancestry Team on December 18, 2017 in TechRoots

Understanding Machine Learning: XGBoost

As machine learning spreads through industry, so does the need to understand, explain and define what machine learning models do. For classification problems outside of deep learning, it’s hard to find a more popular library than XGBoost. XGBoost is particularly useful in a commercial setting because it scales well to large data and supports many languages. For example, it’s easy to train your models in Python and deploy them in a Java production environment.
While XGBoost can be quite accurate, this accuracy comes with somewhat decreased visibility into why XGBoost makes its decisions. When delivering results directly to customers, especially results as powerful and emotional as those we deliver at Ancestry, this can be a major drawback. It’s very useful to understand why things are happening. Companies that turn to machine learning to understand their data also need ways to understand the predictions from their models, and this need is growing. For example, you wouldn’t want credit agencies using machine learning models to predict creditworthiness without being able to understand why the prediction was made.

Another example comes from Ancestry. Suppose our machine learning model tells us that a marriage record and a birth record refer to the same person (the task of record linking), but the dates on the records imply a marriage between a very old and a very young person. We might question why the model linked them. In a case like this, it’s extremely valuable to get insights into why the model made that prediction. It may turn out that the model took into account the uniqueness of the names and locations and made a correct prediction. It may also turn out, though, that our features didn’t correctly account for age differences between records. In that case, understanding the prediction would help us find ways to improve our model.

In this post, we will walk through some techniques to better understand XGBoost predictions. These techniques have allowed us to leverage the power of gradient boosting while still understanding the decisions being made.
To illustrate these techniques, we will use the Titanic dataset. This dataset includes information about each of the passengers as well as whether that passenger survived. Our goal is not only to predict whether a passenger would survive but also to understand why each prediction is made. Even with these data, you can see the importance of understanding the model. Imagine we had data on a recent cruise ship accident. The goal in building a predictive model would not really be the predictions themselves, but understanding them well enough to learn how to maximize survivors in an accident.

In [1]:
import pandas as pd
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import operator
import matplotlib.pyplot as plt
import seaborn as sns
import lime.lime_tabular
from sklearn.pipeline import Pipeline
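# Imputer was removed from sklearn.preprocessing in scikit-learn 0.22; SimpleImputer replaces it
from sklearn.impute import SimpleImputer as Imputer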
import numpy as np
from sklearn.model_selection import GridSearchCV
%matplotlib inline

The first thing we will do is read in our data. You can find the data on Kaggle. After reading in the data, we will do some very simple cleaning of the data. Namely:

  • Removing name and passenger id
  • Converting categorical variables to dummy variables
  • Imputing any missing data with the median

These cleaning techniques are very simple. The goal of this post is not to discuss data cleaning but to explain XGBoost, so these are just fast, reasonable steps to get us to model training.

In [2]:
data = pd.read_csv("./data/titantic/train.csv")
y = data.Survived
X = data.drop(["Survived", "Name", "PassengerId"], axis=1)
X = pd.get_dummies(X)
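
To make the two less obvious cleaning steps concrete, here is a minimal sketch on a made-up four-row table (the `toy` frame and its values are hypothetical, not taken from the Titanic file). It shows what dummy encoding does to a text column and what a median fill does to a missing value.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: one numeric column with a gap, one categorical column.
toy = pd.DataFrame({
    "Age": [22.0, np.nan, 35.0, 58.0],
    "Sex": ["male", "female", "female", "male"],
})

# Dummy encoding: each category becomes its own 0/1 column.
encoded = pd.get_dummies(toy, dtype=float)

# Median imputation: the missing Age is replaced by the median of the known ages.
filled = encoded.fillna(encoded.median())

print(encoded.columns.tolist())
print(filled["Age"].tolist())
```

In the post itself the median fill happens inside the training pipeline rather than up front, so the medians are learned from the training data only.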

Now let’s split into training and testing for our model.

In [4]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

And build a training pipeline with a small amount of hyper-parameter testing:

In [5]:
pipeline = Pipeline([('imputer', Imputer(strategy='median')), ('model', XGBClassifier())])
In [6]:
parameters = dict(model__max_depth=[3, 5, 7],
                  model__learning_rate=[.01, .1],
                  model__n_estimators=[100, 500])

cv = GridSearchCV(pipeline, param_grid=parameters)
cv.fit(X_train, y_train)
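
To get a feel for how much work this "small amount" of testing is, the sketch below enumerates the same grid with scikit-learn's `ParameterGrid`. It only counts and lists the candidate settings and trains nothing. The `model__` prefix is how a pipeline routes each parameter to the step named `model`.

```python
from sklearn.model_selection import ParameterGrid

# Same grid as in the post; "model__" targets the pipeline step called "model".
parameters = dict(model__max_depth=[3, 5, 7],
                  model__learning_rate=[.01, .1],
                  model__n_estimators=[100, 500])

grid = ParameterGrid(parameters)
n_candidates = len(grid)  # 3 depths * 2 learning rates * 2 tree counts
print(n_candidates)

# Peek at the first few candidate settings the search would try.
for params in list(grid)[:3]:
    print(params)
```

Each candidate is trained once per cross-validation fold, so the total number of fits is the candidate count multiplied by the number of folds.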

Then look at our test results. For simplicity, we will use the same metric as Kaggle: accuracy.

In [7]:
test_predictions = cv.predict(X_test)
print("Test Accuracy: {}".format(accuracy_score(y_test, test_predictions)))
Test Accuracy: 0.8101694915254237
So, we have achieved a semi-decent accuracy that would place us around the top 500 out of about 9000 on Kaggle. There is clearly still room to improve, but we will leave that as an exercise for the reader. 🙂
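
Accuracy is simply the share of test passengers predicted correctly, so the reported number can be unpacked with a little arithmetic. The sketch below assumes the standard 891-row Kaggle training file and relies on scikit-learn rounding the test share up. Both are stated as assumptions, since the post does not print the split sizes.

```python
import math

n_rows = 891        # assumption: number of rows in Kaggle's Titanic train.csv
test_size = 0.33    # from the train_test_split call in the post

# scikit-learn rounds the test share up to a whole number of rows.
n_test = math.ceil(n_rows * test_size)

reported_accuracy = 0.8101694915254237
n_correct = round(reported_accuracy * n_test)

print(n_test, n_correct, n_correct / n_test)
print("one passenger is worth", 1 / n_test, "accuracy")
```

Under those assumptions each test passenger moves the accuracy by roughly a third of a percentage point, which is worth keeping in mind when comparing models on a dataset this small.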

Let’s move on to actually understanding what our model has learned. A very common method is to use the feature importances provided by XGBoost. At a high level, the importance of a feature is just how much that feature contributed to making the model better. Don’t worry too much about the actual number. Rather, let us use the importances to rank our features and compare them relative to one another.

In [8]:
fi = list(zip(X.columns, cv.best_estimator_.named_steps['model'].feature_importances_))
fi.sort(key = operator.itemgetter(1), reverse=True)
top_10 = fi[:10]
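# Use distinct names so the target vector y defined earlier is not overwritten
feature_names = [name for name, _ in top_10]
importance_values = [score for _, score in top_10]
In [9]:
top_10_chart = sns.barplot(x=feature_names, y=importance_values)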
plt.setp(top_10_chart.get_xticklabels(), rotation=90)
From the above graph, we can see that Fare and Age are very important features. To dig a bit deeper, we can look at how Fare is distributed across our classes:
In [10]:
sns.barplot(x=y_train, y=X_train['Fare'])

We can see pretty clearly that those who survived had a much higher average Fare than those who did not, so it seems reasonable that this feature is so important. Great!
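
The bar heights in a plot like this are per-class averages, so the same comparison can be read off as numbers with a group-by. This is a minimal sketch on hypothetical fares (the `toy` values are made up for illustration, not taken from the dataset):

```python
import pandas as pd

# Hypothetical fares for five passengers; 0 = did not survive, 1 = survived.
toy = pd.DataFrame({
    "Survived": [0, 0, 0, 1, 1],
    "Fare": [7.25, 8.05, 13.00, 71.28, 53.10],
})

# Average Fare within each class, which is what each bar summarizes.
mean_fare = toy.groupby("Survived")["Fare"].mean()
print(mean_fare)
```

On the real training split, the equivalent would be grouping `X_train['Fare']` by `y_train`.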

Feature importances are a good way to understand which features matter to the model overall, but what if our model makes a specific prediction that we want to understand? For example, maybe we predict that someone with a high Fare will not survive. That goes against our general belief that those with high Fares survived. Clearly, not everyone with a high Fare survived, so we might want to understand what other features led the model to believe this person would not survive.

This type of individual-level analysis can be extremely useful for production machine learning systems. Imagine instead that we were predicting whether someone should be given a loan. We know that credit score is an important feature in our model, but then a customer with a high credit score is denied by our model. How do we explain this to the customer? To regulators?

Fortunately for us, there has been some recent work from the University of Washington on explaining the predictions of any classifier. Their method is called LIME and is available on GitHub. I won’t spend time here explaining their work, but you should definitely take a look at the paper.

Let’s take a look at hooking up LIME to our model. Basically, you first define an explainer that takes in the training data (we need to make sure we pass it the imputed training data set as that is what we trained on):

In [11]:
X_train_imputed = cv.best_estimator_.named_steps['imputer'].transform(X_train)
explainer = lime.lime_tabular.LimeTabularExplainer(X_train_imputed,
    feature_names=list(X_train.columns),  # assumption: the original call is cut off after class_names
    class_names=["Not Survived", "Survived"])

Then you have to have a function defined that takes in a row of features and returns an array with the probabilities for each class:

In [12]:
model = cv.best_estimator_.named_steps['model']
def xgb_prediction(X_array_in):
    if len(X_array_in.shape) < 2:
        X_array_in = np.expand_dims(X_array_in, 0)
    return model.predict_proba(X_array_in)
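
The `expand_dims` guard is there because the model expects a two-dimensional batch (rows are examples, columns are features), even when we only have a single row. A small sketch of just that shape handling, using a hypothetical helper `as_batch` and made-up feature values:

```python
import numpy as np

def as_batch(x):
    """Return x as a 2-D array: a single row becomes a batch of one."""
    x = np.asarray(x)
    if x.ndim < 2:
        x = np.expand_dims(x, 0)  # shape (n_features,) -> (1, n_features)
    return x

single_row = np.array([3.0, 22.0, 7.25])      # hypothetical feature values
many_rows = np.vstack([single_row, single_row])

print(as_batch(single_row).shape)  # a lone row gains a leading batch axis
print(as_batch(many_rows).shape)   # an existing batch is left untouched
```

The other half of the contract is the return value: `predict_proba` gives one row per example and one column per class, which is the format the explainer needs.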

Lastly, you pass an example that you want to be explained to the explainer with your function and the number of features and labels you want to be shown:

In [18]:
X_test_imputed = cv.best_estimator_.named_steps['imputer'].transform(X_test)
exp = explainer.explain_instance(X_test_imputed[1], xgb_prediction, num_features=5, top_labels=1)
exp.show_in_notebook(show_table=True, show_all=False)
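
For intuition about what an explainer of this kind is doing, here is a heavily simplified sketch of the general idea: perturb the example, ask the model for probabilities, weight the perturbed points by how close they are to the original, and fit a weighted linear model whose coefficients act as local feature contributions. This is not the lime library's actual implementation. `black_box_proba`, `local_linear_explanation` and all the numbers are hypothetical stand-ins so the block runs on its own.

```python
import numpy as np

def black_box_proba(X):
    """Hypothetical stand-in for a trained model: class-1 probability depends
    strongly on feature 0, weakly on feature 1, and not at all on feature 2."""
    X = np.atleast_2d(X)
    p1 = 1.0 / (1.0 + np.exp(-(3.0 * X[:, 0] - 0.5 * X[:, 1])))
    return np.column_stack([1.0 - p1, p1])

def local_linear_explanation(instance, predict_proba, n_samples=2000,
                             scale=0.5, kernel_width=1.0, seed=0):
    rng = np.random.default_rng(seed)
    # 1. Sample perturbed points around the instance we want to explain.
    Z = instance + rng.normal(0.0, scale, size=(n_samples, instance.size))
    # 2. Ask the black box for the class-1 probability of each perturbed point.
    p = predict_proba(Z)[:, 1]
    # 3. Weight each point by its proximity to the original instance.
    dist = np.linalg.norm(Z - instance, axis=1)
    w = np.exp(-(dist ** 2) / kernel_width ** 2)
    # 4. Fit a weighted least-squares line (intercept plus one slope per feature).
    A = np.column_stack([np.ones(n_samples), Z - instance])
    sw = np.sqrt(w)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], p * sw, rcond=None)
    return coef[1:]  # drop the intercept, keep per-feature local weights

instance = np.array([0.2, -0.1, 1.0])
weights = local_linear_explanation(instance, black_box_proba)
for name, value in zip(["feature_0", "feature_1", "feature_2"], weights):
    print(f"{name}: {value:+.3f}")
```

A positive weight pushes the prediction toward class 1 near this particular example and a negative weight pushes away from it, which is how the bars in the LIME output should be read.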

Here we have an example that has a 76% probability of Not Surviving. We are then shown which features contributed to which class and how important they were. For this example, having Sex = Female suggested surviving. Let’s look at our bar plot for Sex:

In [14]:
sns.barplot(x=X_train['Sex_female'], y=y_train)

Okay – that seems pretty reasonable. Being a female did greatly increase your chance of surviving in our training data. So why the “not survived” prediction? Well, it looks like Pclass = 2.0 contributes a lot to not surviving. Let’s take a look:

In [15]:
sns.barplot(x=X_train['Pclass'], y=y_train)

Alright – it does look like Pclass equal to 2 shows a lower likelihood of surviving, so we are starting to get more understanding around our prediction. Looking at the top 5 features shown by LIME, though, it still does look like maybe this should be a survivor. Let’s take a look at the label:

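In [16]:
y_test.iloc[1]  # assumption: this cell's code is missing; this reads the true label of the explained row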

This person did survive, so our model got it wrong! But thanks to LIME we have some sense of why: it looks like the Pclass might be throwing it off. This can help us go back and hopefully find ways of improving the model.

That’s it! Some pretty easy and effective ways to better understand XGBoost. Hopefully, these methods can help you leverage the power of XGBoost while still being able to understand and explain the model’s predictions.


  1. caith

    Well, that is a giant leap! What about the nothingness in between, or even the rudiments of the game. Just start on the bottom tier, clean up the navigational mess, the tools that do not work (not just improperly, but not at all), and then commence with your edification.

    Priority, priority. A man without feet, has no need for shoes.

  2. Mary D. Taffet

    Interesting to see a post about machine learning here in this blog. I’m trying to learn about deep learning myself, due to its applications in my field — Natural Language Processing/Information Extraction — but haven’t really gotten started on that effort just yet. I’ll have to come back here and read this post fully when I get a chance. Thanks for posting it!

  3. Robyn Sharpe

    Hi, I am confused about why Ancestry wants such predictive past or possibly future data of our families. I can understand perhaps if this predictive model could help narrow a search or aid a genealogist who is employed by a client searching for a missing relative. But I would be concerned that this type of data analysis would work within the computer system & match possible scenarios within my family to my close relatives that I knew in my life, but a part of my family tree held in Ancestry. I am aware that the computer system already narrows searches for me which sometimes closes access to me through the search process, so searching can be made more difficult for me when that happens, so I have to go through the archives looking at the documents directly because for some reason the search would not bring up my relative that I know is there. So while I am thankful that you are informing me of this program capability, I would like to say; my granny would have not liked me searching our families history/stories and so out of respect for my granny I will not put something in my tree that would have embarrassed her or caused great shame. What we see as light hearted now, was different then & had emotional impact to members of my family up to my fathers generation. I am sorry that I have waffled on a bit, but I hope you understand my feedback, thank you.

  4. Stephanie

    I think this blog entry is above the technical level of most ancestry blog readers. It could do with a couple of paragraphs at the beginning explaining the relevance in simple terms.

  5. Patricia

    I hope that integrating tools of this caliber will assist with developing better content, search result and historically relevant timelines, based upon my user and DNA profile. I use/benefit from similar tools in my profession and hope to see some positive changes as the tool learns and new models are implemented.

  6. Sofia Kogkalidou

    There are some days i am trying to order the kit but it cancels my order. I don’t know the reason, all informations are correct. Whats the wrong? Can anyone please fix it.

    • Member Services Social Support Team

      @Sofia: We’re sorry to hear you have experienced this issue. Please feel free to give us a call on 1-800-262-3787 between the hours of 9am to 11pm EST, seven days a week and we would be happy to place your order for you over the phone.
