# Source code for fairensics.data.decision_boundary

"""Decision boundary plots for 2D data sets.

The code for the scatter function is adapted from fair-classification:
    https://github.com/mbilalzafar/fair-classification.

Usage and interpretation are explained in the example Jupyter notebook.
"""

import matplotlib.pyplot as plt
import numpy as np
from aif360.datasets import BinaryLabelDataset
from sklearn.decomposition import PCA

from ..fairensics_utils import get_unprotected_attributes


class DecisionBoundary:
    """Class for plotting decision boundaries against two axes.

    The data may be down-sampled to two dimensions before plotting.

    The decision boundary plots are generated using a mesh grid and the
    following procedure:

        1. If necessary, the data is down-sampled to two dimensions
        2. Minimum and maximum values for each axis are extracted
        3. A mesh grid is created
        4. If necessary, the mesh grid is up-sampled again
        5. Predictions are made on the mesh grid
        6. Predictions are plotted against the (maybe down-sampled) axes

    TODO: add option to scale data to [0,1]
    """

    # Legend labels for the four (group, label) combinations drawn by scatter().
    _UNPRIVILEGED_GROUP_NEGATIVE_LABEL = (
        "Negative label and unprivileged group"
    )
    _UNPRIVILEGED_GROUP_POSITIVE_LABEL = (
        "Positive label and unprivileged group"
    )
    _PRIVILEGED_GROUP_NEGATIVE_LABEL = "Negative label and privileged group"
    _PRIVILEGED_GROUP_POSITIVE_LABEL = "Positive label and privileged group"

    def __init__(
        self,
        colors=("k", "c", "m", "b", "g", "r", "y"),
        downsampler=None,
    ):
        """
        Args:
            colors: iterable of colors for the decision boundaries. One color
                is consumed per call to ``add_boundary``.
            downsampler: object used to down-sample data points; must
                implement ``fit_transform`` and ``inverse_transform``.
                Defaults to a fresh ``PCA(n_components=2)`` per instance.
        """
        # Build the default here rather than in the signature: a default
        # argument is evaluated once, so every instance would otherwise share
        # (and re-fit) the very same PCA object.
        if downsampler is None:
            downsampler = PCA(n_components=2)

        self._colors = iter(colors)
        self._downsampler = downsampler
        self._downsampled = False

    def _maybe_downsample(self, dataset, only_unprotected):
        """Down-sample data to 2D for plotting.

        If only_unprotected is true, the protected features are removed both
        before down-sampling and when the raw features are returned.

        Args:
            dataset (StructuredDataset): aif dataset with features and labels.
            only_unprotected (bool): protected features are ignored if true.

        Returns:
            (np.ndarray): 2D array of maybe down-sampled features.
        """
        if only_unprotected:
            features = get_unprotected_attributes(dataset)
        else:
            features = dataset.features

        # Recompute the flag on every call so that a previous high-dimensional
        # data set does not make _maybe_upsample inverse-transform a mesh that
        # was built on data that never went through the down-sampler.
        self._downsampled = features.shape[1] > 2
        if self._downsampled:
            return self._downsampler.fit_transform(features)
        return features

    def _maybe_upsample(self, mesh):
        """Up-sample mesh for prediction if the last data set was down-sampled.

        Args:
            mesh (np.array): 2D mesh grid.

        Returns:
            (np.ndarray): Up-sampled mesh grid, or ``mesh`` unchanged.
        """
        if self._downsampled:
            return self._downsampler.inverse_transform(mesh)
        return mesh

    def scatter(
        self,
        dataset: "BinaryLabelDataset",
        protected_attribute_ind=0,
        only_unprotected=True,
        num_to_draw=100,
    ):
        """Scatter plot the points in dataset.

        Protected and unprotected individuals and positive and negative
        labels are distinguished. Only one protected attribute is considered
        for plotting.

        Args:
            dataset (BinaryLabelDataset): data set to plot.
            protected_attribute_ind (int): index of the protected attribute to
                consider.
            only_unprotected (bool): if true, the classifier only uses the
                unprotected attributes.
            num_to_draw (int): number of points to draw.
        """
        x_draw = self._maybe_downsample(dataset, only_unprotected)[
            :num_to_draw, :
        ]
        y_draw = dataset.labels[:num_to_draw, 0]
        x_protected_draw = dataset.protected_attributes[
            :num_to_draw, protected_attribute_ind
        ]

        # Each group is described by a collection of attribute values, so use
        # a membership test; a plain == only works for single-valued groups.
        unprivileged_mask = np.isin(
            x_protected_draw,
            dataset.unprivileged_protected_attributes[protected_attribute_ind],
        )
        privileged_mask = np.isin(
            x_protected_draw,
            dataset.privileged_protected_attributes[protected_attribute_ind],
        )
        favorable_mask = y_draw == dataset.favorable_label
        unfavorable_mask = y_draw == dataset.unfavorable_label

        # Unprivileged points are crosses, privileged points hollow circles;
        # green marks the favorable label, red the unfavorable one.
        unprivileged_style = {"marker": "x"}
        privileged_style = {"facecolors": "none"}
        groups = (
            (
                unprivileged_mask & favorable_mask,
                "green",
                unprivileged_style,
                self._UNPRIVILEGED_GROUP_POSITIVE_LABEL,
            ),
            (
                unprivileged_mask & unfavorable_mask,
                "red",
                unprivileged_style,
                self._UNPRIVILEGED_GROUP_NEGATIVE_LABEL,
            ),
            (
                privileged_mask & favorable_mask,
                "green",
                privileged_style,
                self._PRIVILEGED_GROUP_POSITIVE_LABEL,
            ),
            (
                privileged_mask & unfavorable_mask,
                "red",
                privileged_style,
                self._PRIVILEGED_GROUP_NEGATIVE_LABEL,
            ),
        )
        for mask, color, style, legend_label in groups:
            points = x_draw[mask]
            plt.scatter(
                points[:, 0],
                points[:, 1],
                color=color,
                label=legend_label,
                **style,
            )

    def add_boundary(
        self,
        dataset: "BinaryLabelDataset",
        clf,
        label="",
        only_unprotected=True,
        num_points=100,
        cmap=None,
    ):
        """Add a decision boundary to the current plot.

        If the data set is two dimensional, the boundary is directly plotted
        using a mesh grid. Otherwise, a mesh grid is generated on the
        down-sampled points and up-sampled again for prediction.

        Args:
            dataset (BinaryLabelDataset): the labeled data set.
            clf (object): the classifier object (must implement a predict
                function).
            label (str): the label for the decision boundary.
            only_unprotected (bool): if true, the classifier only uses the
                unprotected attributes.
            num_points (int): number of points per axis in the mesh grid.
            cmap (str): colormap from matplotlib. If provided, the background
                of the plot is colored.

        Raises:
            StopIteration: if all colors passed to the constructor have
                already been used by earlier boundaries.
        """
        points = self._maybe_downsample(dataset, only_unprotected)

        # Pad the plotted area by one unit on every side of the data.
        x1_min, x1_max = points[:, 0].min() - 1, points[:, 0].max() + 1
        x2_min, x2_max = points[:, 1].min() - 1, points[:, 1].max() + 1

        # linspace yields exactly num_points samples per axis; arange with a
        # float step may produce one more or fewer due to rounding.
        xx, yy = np.meshgrid(
            np.linspace(x1_min, x1_max, num_points, endpoint=False),
            np.linspace(x2_min, x2_max, num_points, endpoint=False),
        )

        mesh = self._maybe_upsample(np.c_[xx.ravel(), yy.ravel()])
        Z = np.asarray(clf.predict(mesh)).reshape(xx.shape)

        color = next(self._colors)
        plt.contour(xx, yy, Z, colors=color)
        # Register the legend entry through an empty proxy line instead of
        # ContourSet.collections, which newer Matplotlib releases no longer
        # provide.
        plt.plot([], [], color=color, label=label)

        if cmap is not None:
            plt.contourf(xx, yy, Z, cmap=cmap)

    @staticmethod
    def show(title="", xlabel="", ylabel=""):
        """Show the current plot with legend, title and axis labels.

        Args:
            title (str): title of the plot.
            xlabel (str): label of the x axis.
            ylabel (str): label of the y axis.
        """
        plt.legend()
        # These are pyplot functions: they must be called. Assigning to them
        # would silently replace the functions with strings.
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.show()