Grouped Bar Plot

1 minute read

Published:

I often need to draw a grouped bar plot. Matplotlib provides an example of this, which is helpful but not quite general enough for my needs, as it only shows how to group two categories together. Here is a generalization of that tutorial that handles any number of categories; it has been very helpful for me, and I hope it is helpful for others as well.

import matplotlib.pyplot as plt
import numpy as np

from typing import List, Optional


def grouped_barplot(
    data,
    clabels: List[str],
    xlabels: List[str],
    gap: float = 0.3,
    show_legend: bool = True,
    show_bar_labels: bool = True,
    ax: "Optional[plt.Axes]" = None,  # quoted: not evaluated at definition time
):
    """
    Draw a grouped bar plot: one group of bars per x label, one bar per category.

    Each group is centred on its tick and occupies ``1 - gap`` of the unit
    spacing between neighbouring ticks.

    Parameters
    ----------
    data: array-like
        size=(len(clabels), len(xlabels)). Row ``i`` holds the bar heights of
        category ``clabels[i]``, one value per x label.
    clabels: list(str)
        Category names, one per row of ``data``; used as the legend labels.
    xlabels: list(str)
        Tick labels, one per group of bars.
    gap: float
        Gap between neighbouring groups, as a fraction of the tick spacing.
        Default = 0.3
    show_legend: bool
        Show legend. Default = True
    show_bar_labels: bool
        Show data values above each bar. Default = True
    ax: plt.Axes
        If not provided, a new figure will be created.

    Returns
    -------
    ax, all_rects
        The axes that were drawn on, and a list holding the return value of
        ``ax.bar`` for each category, in ``clabels`` order.

    Raises
    ------
    ValueError
        If ``clabels`` is empty, or ``data`` does not have exactly one row per
        category.
    """
    series = list(data)
    n_categories = len(clabels)
    if n_categories == 0:
        raise ValueError("clabels must contain at least one category")
    if len(series) != n_categories:
        raise ValueError(
            f"data has {len(series)} rows but clabels has {n_categories} entries"
        )

    if ax is None:
        _, ax = plt.subplots()

    x = np.arange(len(xlabels))  # the label locations
    width = (1 - gap) / n_categories  # the width of the bars

    all_rects = []
    for i, (cdata, clabel) in enumerate(zip(series, clabels)):
        # ax.bar centres each bar on the given position by default, so the
        # positions must be bar centres. Spreading them symmetrically around
        # the tick keeps the whole group centred on its label.
        centres = x + (i - (n_categories - 1) / 2) * width
        rects = ax.bar(centres, cdata, width, label=clabel)
        if show_bar_labels:
            ax.bar_label(rects, padding=3)
        all_rects.append(rects)

    # One tick per group, labelled with the matching entry of xlabels.
    ax.set_xticks(x, xlabels)
    if show_legend:
        ax.legend()

    return ax, all_rects

Example usage:

# Three categories (one row of data each), plotted across four x labels.
grouped_barplot(
    xlabels=["x", "labels", "go", "here"],
    clabels=["there", "are", "categories"],
    data=[
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [4, 5, 6, 7],
    ],
)

grouped_barplot