Source code for rankeval.visualization.feature

"""
This package provides support for feature analysis visualizations.
"""

from __future__ import print_function
import six

import matplotlib.pyplot as plt

import numpy as np

try:
    xrange
except NameError:
    # Python3's range is Python2's xrange
    xrange = range


[docs]def plot_feature_importance(feature_perf, max_features=10, sort_by="gain", feature_names=None): """ Shows the most important features as a bar plot. Parameters ---------- feature_perf : xarray.DataArray Feature importance stats of the model to be visualized max_features : int or None Maximul number of features to be visualized. If None is passed, it will show all the features sort_by : 'gain' or 'count' The method to use for selecting the top features to display. 'gain' method selects the top features by importance, 'count' selects the top features by usage (i.e., number of times it has been used by a split node). feature_names : list of string The name of the features to use for plotting. If None, their index is used in place of the name (starting from 1). Returns ------- : matplotlib.figure.Figure The matpotlib Figure """ feature_importance = feature_perf.sel(type='importance').data feature_count = feature_perf.sel(type='count').data.astype(np.uint16) # figure fig, ax1 = plt.subplots(figsize=(16, 5)) ax2 = ax1.twinx() if sort_by == "gain": idx_sorted = np.argsort(feature_importance)[::-1] title_by = "Importance" elif sort_by == "count": idx_sorted = np.argsort(feature_count)[::-1] title_by = "Count" else: raise RuntimeError("Sorting of features for visualization " "not supported!") if isinstance(max_features, six.integer_types): idx_sorted = idx_sorted[:max_features] else: max_features = len(feature_importance) top_features = idx_sorted top_importances = feature_importance[idx_sorted] top_counts = feature_count[idx_sorted] index = np.arange(max_features) bar_width = 0.35 opacity = 0.7 bar1 = ax1.bar(index, top_importances, bar_width, alpha=opacity, color='r', align='center', zorder=5, edgecolor='black') bar2 = ax2.bar(index + bar_width, top_counts, bar_width, alpha=opacity, color='b', align='center', zorder=5, edgecolor='black') ax1.set_title('Top-k Features by %s' % title_by) ax1.set_xlabel("Features") if feature_names is not None: feature_names_f = np.array(["%16s" % f for f in feature_names]) ax1.set_xticks(index + bar_width / 2 + 0.15) ax1.set_xticklabels(feature_names_f[idx_sorted], rotation=45, ha="right") else: ax1.set_xticks(index + bar_width / 2) ax1.set_xticklabels(top_features + 1) ax1.set_xlim(-bar_width/2 - bar_width, max_features - 1 + bar_width*5/2) step_y = np.ceil(top_importances.max() * 10) / 100 align_y_axis(ax1, ax2, step_y, 100, num_ticks=6) ax1.set_ylabel("Importance Gain") ax2.set_ylabel("Usage Count") ax1.grid(False) ax2.grid(False) ax1.yaxis.grid(True, ls='--', zorder=0) ax1.legend((bar1, bar2), ("Importance", "Count"), loc='best', shadow=True, frameon=True, fancybox=True) return fig
[docs]def align_y_axis(ax1, ax2, minresax1, minresax2, num_ticks=7): """ Sets tick marks of twinx axes to line up with num_ticks total tick marks ax1 and ax2 are matplotlib axes Spacing between tick marks will be a factor of minresax1 and minresax2""" ax1ylims = ax1.get_ybound() ax2ylims = ax2.get_ybound() ax1factor = minresax1 * (num_ticks - 1) ax2factor = minresax2 * (num_ticks - 1) ax1.set_yticks(np.linspace(ax1ylims[0], ax1ylims[1]+(ax1factor - (ax1ylims[1]-ax1ylims[0]) % ax1factor) % ax1factor, num_ticks)) ax2.set_yticks(np.linspace(ax2ylims[0], ax2ylims[1]+(ax2factor - (ax2ylims[1]-ax2ylims[0]) % ax2factor) % ax2factor, num_ticks))