Building on scikit-learn


Cambridge Analytica is always finding innovative ways to maintain our position at the forefront of data analytics. While our data scientists work with the very best implementations of machine learning algorithms and technologies, they often need to go beyond the basic toolkits. In this blog post, we discuss the process of building on the commonly used Python scikit-learn library.  

The scikit-learn package is an industry-standard machine learning library in Python, widely appreciated both for the breadth of machine learning algorithms and data analysis tools it offers and for the simplicity and elegance of its API. Nevertheless, machine learning as a research field is developing at a rapid pace, and working at the cutting edge will often require going beyond the toolkit provided by scikit-learn and other mature machine learning libraries. Here we explore one straightforward way to build on the scikit-learn library by creating our own estimator, whilst preserving complete compatibility with the scikit-learn ecosystem.

The scikit-learn API

Scikit-learn has an elegant API (well, we think so) that is common across all of its objects. This API is widely imitated by other machine learning libraries both in Python and other programming languages due to its simplicity and transparency. The API can be broken down into three main objects: estimators, predictors and transformers, each of which provides a distinct interface. Classes in scikit-learn will then provide at least one of these interfaces (often two or sometimes three), sometimes via single or multiple inheritance.  

Estimators are objects that implement a fit() method:

        fit(X[, y, **kwargs])

This method encapsulates the learning or training part. The first argument, X, is the training data, and the second, optional argument, y, is the target vector, which is provided in supervised learning settings. Some estimators provide additional keyword arguments, e.g. sample_weight.

Predictors are objects that implement a predict() method:

        predict(X)

This method generates a set of predictions from some input data, X. Predictors for classification algorithms sometimes also provide predict_proba() and/or decision_function() methods that quantify, for each sample, how certain the output of predict() is.

Transformers are objects that implement a transform() method:

        transform(X)

This method returns a transformation of the input data, X. Classes that additionally provide the fit() method will often provide a fit_transform() method if this is computationally more efficient than chaining the fit() and transform() methods.
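The three interfaces can be seen side by side in a short sketch. It uses two estimators that ship with scikit-learn, StandardScaler and LogisticRegression; the toy data is made up purely for illustration.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# Toy data: two well-separated clusters (invented for illustration).
X = np.array([[0.0, 1.0], [0.2, 0.8], [0.1, 1.1],
              [3.0, 4.0], [3.2, 3.9], [2.9, 4.2]])
y = np.array([0, 0, 0, 1, 1, 1])

# StandardScaler is an estimator (fit) and a transformer (transform).
scaler = StandardScaler()
X_scaled = scaler.fit(X).transform(X)

# fit_transform() gives the same result as chaining fit() and transform().
X_scaled_again = StandardScaler().fit_transform(X)

# LogisticRegression is an estimator (fit) and a predictor (predict).
clf = LogisticRegression()
clf.fit(X_scaled, y)
labels = clf.predict(X_scaled)                 # hard class labels
probabilities = clf.predict_proba(X_scaled)    # one column per class
scores = clf.decision_function(X_scaled)       # signed confidence scores
```

Because fit() returns the estimator itself, the scaler's fit() and transform() calls can be chained on one line.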

Creating a scikit-learn compatible estimator 

Let’s say you want to use a machine learning algorithm that is not included in the scikit-learn library. You can build it yourself (or use an implementation from elsewhere), but you still want to combine it with other pieces of scikit-learn functionality, for example in a scikit-learn Pipeline. In that case, what you want is a scikit-learn compatible estimator.

As an example, let’s say we want to build an ordinal classifier. This is a multi-class classifier that takes into account the natural order amongst the class labels. An example application would be in modeling levels of preference as indicated from a survey (e.g. on a scale from 1 to 5 to mean, say, “very poor” to “very good”). One approach to this problem involves partitioning the data according to the various class labels and learning a series of binary classifiers, the predictions of which are combined to provide a ranking.

To build this as a scikit-learn estimator, the minimum needed to construct the class is as follows:

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.linear_model import LogisticRegression
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted


# Mixins are listed before BaseEstimator, as the scikit-learn
# developer guide recommends.
class OrdinalClassifier(ClassifierMixin, BaseEstimator):

    def __init__(self, base_classifier=LogisticRegression()):
        self.base_classifier = base_classifier

    def fit(self, X, y, **kwargs):
        # Check that X and y have the correct shape:
        X, y = check_X_y(X, y)

        # Store the classes seen during the fit:
        self.classes_ = unique_labels(y)

        # Store a list of fitted binary classifiers that
        # are initially cloned from that provided:
        self.classifiers_ = []

        # Fit the various binary classifiers and append
        # to the list:
        for i in range(len(self.classes_) - 1):
            # The algorithm goes here...
            pass  # placeholder so that the skeleton parses

        return self

    def predict(self, X):
        # Check that fit() has been called:
        check_is_fitted(self, ['classes_', 'classifiers_'])

        # Input validation:
        X = check_array(X)

        # Compute the predictions using the fitted classifiers
        # Return the result…

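The skeleton leaves the algorithm itself open. As one possible way to fill it in, here is a minimal sketch that assumes the threshold decomposition described by Frank and Hall (2001), which this post does not name: one binary classifier per threshold estimates P(y > class k), and differences between neighbouring estimates score each class. The class name ThresholdOrdinalClassifier and the synthetic data are hypothetical. The sketch also assumes at least two classes, a base classifier that provides predict_proba(), and labels whose sorted order is their ordinal order.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.linear_model import LogisticRegression
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y


class ThresholdOrdinalClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical ordinal classifier built from K - 1 binary classifiers."""

    def __init__(self, base_classifier=LogisticRegression()):
        self.base_classifier = base_classifier

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        # Assumption: the sorted label order is the ordinal order.
        self.classes_ = unique_labels(y)
        self.classifiers_ = []
        for threshold in self.classes_[:-1]:
            # Binary target: is the label strictly above this threshold?
            binary_target = (y > threshold).astype(int)
            self.classifiers_.append(
                clone(self.base_classifier).fit(X, binary_target)
            )
        return self

    def predict(self, X):
        check_is_fitted(self, ["classes_", "classifiers_"])
        X = check_array(X)
        # Column k estimates P(y > classes_[k]).
        above = np.column_stack(
            [clf.predict_proba(X)[:, 1] for clf in self.classifiers_]
        )
        ones = np.ones((X.shape[0], 1))
        zeros = np.zeros((X.shape[0], 1))
        # Score for class k: P(y > classes_[k-1]) - P(y > classes_[k]),
        # padded with 1 on the left and 0 on the right.
        scores = np.hstack([ones, above]) - np.hstack([above, zeros])
        return self.classes_[np.argmax(scores, axis=1)]


# Synthetic survey-style data: four ordered labels (1 to 4) cut from one feature.
rng = np.random.RandomState(0)
X = rng.uniform(0, 10, size=(300, 1))
y = np.digitize(X[:, 0], bins=[2.5, 5.0, 7.5]) + 1

model = ThresholdOrdinalClassifier().fit(X, y)
y_pred = model.predict(X)
accuracy = model.score(X, y)  # score() is inherited from ClassifierMixin
```

The base classifier is cloned for every threshold, so the object passed to the constructor is never fitted or modified. Since the binary classifiers are trained independently, the class scores are not guaranteed to be valid probabilities, which is why this sketch only implements predict().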

So, what does this class do? Let’s consider the component parts.

Class inheritance and API

All estimators in scikit-learn should inherit from BaseEstimator, which provides the get_params() and set_params() methods. Being a classifier, OrdinalClassifier also inherits from ClassifierMixin via the mixin pattern, which provides the score() method appropriate for classifiers. For regression algorithms, one should inherit from RegressorMixin, and for transformer objects one should inherit from TransformerMixin. OrdinalClassifier is an estimator, so the fit() method must be provided, and since this is a supervised learning algorithm we also implement the predictor API by providing a predict() method. We could also have provided the predict_proba() and decision_function() methods here.

The class constructor

It is important that the class constructor (the __init__ method) does nothing except store the input parameters as attributes of the instance, under the same names as the arguments. There should be no logic on top of this that transforms the input. Not even input validation is allowed. This is to ensure that the inherited get_params() and set_params() methods work correctly. The arguments to the constructor should just be hyperparameters describing the model. Training data is passed to the fit() method only, and unlabelled data is passed to the predict() method only. All arguments to the constructor should also have default values, so that it is possible to initialise the class without any arguments.
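To see why the constructor rule matters, consider a small sketch with two hypothetical estimators, WellBehaved and Misbehaved (neither is part of scikit-learn). It relies on sklearn.base.clone(), which rebuilds an estimator from the output of get_params() and raises an error if the constructor altered any parameter.

```python
from sklearn.base import BaseEstimator, clone


class WellBehaved(BaseEstimator):
    """Hypothetical estimator whose constructor only stores its argument."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha


class Misbehaved(BaseEstimator):
    """Hypothetical estimator that transforms its argument in __init__."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha * 2  # logic in the constructor: avoid this


# get_params(), set_params() and clone() round-trip cleanly.
estimator = WellBehaved(alpha=0.5)
copy = clone(estimator)
params_before = copy.get_params()
estimator.set_params(alpha=0.1)
params_after = estimator.get_params()

# get_params() reports the stored (already doubled) value, so clone()
# detects that the rebuilt object does not hold what it was given.
try:
    clone(Misbehaved(alpha=1.0))
    clone_failed = False
except RuntimeError:
    clone_failed = True
```

Tools such as Pipeline and the model selection utilities depend on this cloning behaviour, which is why validation and any derived quantities belong in fit() rather than in the constructor.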

The class methods

The fit() method should return self to enable chained expressions like the following:

        y_pred = OrdinalClassifier().fit(X_train, y_train).predict(X_test)

Given that this is a supervised learning algorithm, the fit() method takes both X and y as arguments, though even unsupervised learning algorithms should take y as an argument with a default value of None. This is to ensure that unsupervised and supervised learning algorithms can be chained as part of a scikit-learn Pipeline. By convention, attributes that have been estimated or derived from the data have names ending with a trailing underscore, and these are always set as part of the fit() method.
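The y=None convention can be illustrated with a small sketch that chains an unsupervised step and a supervised step. MeanCenterer is a hypothetical transformer written for this example; Pipeline and LogisticRegression are the standard scikit-learn classes, and the data is invented.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline


class MeanCenterer(TransformerMixin, BaseEstimator):
    """Hypothetical unsupervised transformer that subtracts column means."""

    def fit(self, X, y=None):
        # y is accepted and ignored so that a Pipeline can pass it along.
        self.mean_ = np.asarray(X, dtype=float).mean(axis=0)
        return self

    def transform(self, X):
        return np.asarray(X, dtype=float) - self.mean_


X = np.array([[1.0, 10.0], [2.0, 11.0], [8.0, 20.0], [9.0, 21.0]])
y = np.array([0, 0, 1, 1])

pipeline = Pipeline([
    ("center", MeanCenterer()),
    ("classify", LogisticRegression()),
])

# The Pipeline hands both X and y to every step during fit().
y_pred = pipeline.fit(X, y).predict(X)
learned_mean = pipeline.named_steps["center"].mean_
```

The attribute mean_ follows the trailing-underscore convention: it is derived from the data and only exists once fit() has been called.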

Notice the use of the sklearn.utils module in both the fit() and predict() methods for input validation and data inspection tools. These are provided by scikit-learn as public API for the purpose of implementing custom compatible objects.

Once you have finished building your estimator, you can check that it is compatible using sklearn.utils.estimator_checks.check_estimator(). For a deep dive on this topic, we recommend reading the scikit-learn contributing guide.