replot/replot/__init__.py

815 lines
31 KiB
Python

"""
The :mod:`replot` module is a (sane) Python plotting module, abstracting on top
of Matplotlib.
"""
import collections
import math
import os
import shutil
import cycler
import matplotlib as mpl
# Use "agg" backend automatically if no display is available.
if "DISPLAY" not in os.environ:
    # No X display exported in the environment: select the "agg" backend
    # before pyplot is imported.
    mpl.use("agg")
import matplotlib.pyplot as plt
import numpy as np
import palettable
from replot import adaptive_sampling
from replot import exceptions as exc
from replot import grid_parser
from replot import tools
# Version string of the replot package.
__VERSION__ = "0.0.1"
#############
# CONSTANTS #
#############
# Group symbol under which plots are stored when no ``group`` keyword is
# passed to ``Figure.plot``. It is reserved: it is rejected as a user-supplied
# group name.
_DEFAULT_GROUP = "_"
# Default palette is a CubeHelix palette with as many colors as requested
def _default_palette(n):
    """
    Build the default color palette: a CubeHelix perceptual rainbow palette
    holding as many colors as requested.

    :param n: The number of colors in the palette.
    :returns: The palette as a list of colors (as RGB tuples).
    """
    cubehelix = palettable.cubehelix.Cubehelix.make(
        start_hue=240.,
        end_hue=-300.,
        min_sat=1.,
        max_sat=2.5,
        min_light=0.3,
        max_light=0.8,
        gamma=.9,
        n=n)
    return cubehelix.mpl_colors
class Figure():
    """
    The main class from :mod:`replot`, representing a figure. Can be used
    directly or in a ``with`` statement.
    """

    # Sentinel returned by ``_group_setting`` when a per-group dict has no
    # entry for the requested group (``None`` is a legitimate setting value,
    # so it cannot be used for that purpose).
    _MISSING = object()

    def __init__(self,
                 xlabel="", ylabel="", title="",
                 xrange=None, yrange=None,
                 palette=_default_palette,
                 legend=None, savepath=None, grid=None,
                 custom_mpl_rc=None):
        """
        Build a :class:`Figure` object.

        :param xlabel: Label for the X axis (optional).
        :param ylabel: Label for the Y axis (optional).
        :param title: Title of the figure (optional).
        :param xrange: Range of the X axis (optional), as a tuple
            representing the interval.
        :param yrange: Range of the Y axis (optional), as a tuple
            representing the interval.
        :param palette: Color palette to use (optional). Defaults to a safe
            palette with compatibility with colorblindness and black and
            white printing.
        :type palette: Either a list of colors (as RGB tuples) or a function
            to call with number of plots as parameter and which returns a
            list of colors (as RGB tuples). You can also pass a Seaborn
            palette directly, or use a Palettable Palette.mpl_colors.
        :param legend: Whether to use a legend or not (optional). Defaults to
            no legend, except if labels are found on provided plots.
            ``False`` to disable completely. ``None`` for default
            behavior. A string indicating position (:mod:`matplotlib`
            format) to put a legend. ``True`` is a synonym for ``best``
            position.
        :param savepath: A path to save the image to (optional). If set,
            the image will be saved on exiting a `with` statement.
        :param grid: A dict containing the width and height of the grid, and
            a description of the grid as a list of subplots. Each subplot
            is a tuple of
            ``((y_position, x_position), symbol, (rowspan, colspan))``.
            No check for grid validity is being made. You can set it to
            ``False`` to disable it completely.
        :param custom_mpl_rc: An optional dict to overload some
            :mod:`matplotlib` rc params.

        .. note:: If you use group plotting, ``xlabel``, ``ylabel``,
            ``title``, ``legend``, ``xrange`` and ``yrange`` will be
            set uniformly for every subplot. If you wish to set
            different properties for every subplots, you
            should pass a dict for these properties, keys being the
            group symbols and values being the value for each subplot.
        """
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.title = title
        self.xrange = xrange
        self.yrange = yrange
        self.palette = palette
        self.legend = legend
        self.plots = collections.defaultdict(list)  # keys are groups
        self.grid = grid
        self.savepath = savepath
        # Keep the user-provided rc overrides so that rendering applies them.
        self.custom_mpl_rc = custom_mpl_rc

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Do not render the figure if an exception was raised
        if exception_type is None:
            if self.savepath is not None:
                self.save()
            else:
                self.show()

    def save(self, *args, **kwargs):
        """
        Render and save the figure.

        .. note:: Signature is the same as the one from
            ``matplotlib.pyplot.savefig``. You should refer to its
            documentation for lengthy details.

        .. note:: Typical use is:

            >>> with replot.Figure() as figure: figure.save("SOME_FILENAME")

        .. note:: If a ``savepath`` has been provided to the :class:`Figure`
            object, you can call ``figure.save()`` without argument and
            this path will be used.
        """
        if len(args) == 0 and self.savepath is not None:
            args = (self.savepath,)
        figure = self._render()
        figure.savefig(*args, **kwargs)

    def show(self):
        """
        Render and show the :class:`Figure` object.
        """
        figure = self._render()
        figure.show()

    def set_grid(self, grid_description=None,
                 height=None, width=None, ignore_groups=False,
                 auto=False):
        """
        Apply a grid layout on the figure (subplots). Subplots are based on
        defined groups (see ``group`` keyword to ``replot.Figure.plot``).

        :param grid_description: A list of rows. Each row is a string
            containing the groups to display (can be seen as ASCII art).
            Can be a single string in case of a single row. ``False``
            disables the grid.
        :param height: An optional ``height`` for the grid, implies
            ``auto=True``.
        :param width: An optional ``width`` for the grid, implies
            ``auto=True``.
        :param ignore_groups: (optional, implies ``auto=True``) By default,
            ``set_grid`` will use groups to organize plots in different
            subplots. If you want to put every plot in a different
            subplot, regardless of their groups, you can set this
            to ``True``.
        :param auto: Whether the grid should be guessed automatically from
            groups or not (optional).
        :raises exc.InvalidParameterError: If the description is empty, if
            rows have different lengths or if the grid parser rejects it.

        .. note:: Groups are a single unicode character. If a group does not
            contain any plot, the resulting subplot will simply be empty.

        .. note:: Note that if you do not include the default group in the
            grid description, it will not be shown.

        >>> with replot.Figure() as fig: fig.set_grid(["AAA",
                                                       "BBC",
                                                       "DEC"])
        """
        # Handle auto gridifying
        if auto or height is not None or width is not None or ignore_groups:
            self._set_auto_grid(height, width, ignore_groups)
            return

        # Default parameters
        if grid_description is None:
            grid_description = []
        elif grid_description is False:
            # Disable the grid and return
            self.grid = False
            return
        elif isinstance(grid_description, str):
            # If a single string is provided, enclose it in a list.
            grid_description = [grid_description]

        # Check that grid is not empty
        if len(grid_description) == 0:
            raise exc.InvalidParameterError("Grid cannot be an empty list.")
        # Check that all rows have the same number of elements
        for row in grid_description:
            if len(row) != len(grid_description[0]):
                raise exc.InvalidParameterError(
                    "All rows must have the same number of elements.")

        # Parse the ASCII art grid
        parsed_grid = grid_parser.parse_ascii(grid_description)
        if parsed_grid is None:
            # If grid is not valid, raise an exception
            raise exc.InvalidParameterError(
                "Invalid grid provided. You did not use rectangular areas " +
                "for each group.")
        # Set the grid
        self.grid = parsed_grid

    def _set_auto_grid(self, height=None, width=None, ignore_groups=False):
        """
        Apply an automatic grid on the figure, trying to fit best to the
        number of plots.

        .. note:: This method must be called after all the plots
            have been added to the figure, or the grid will miss some
            groups.

        .. note:: The grid will be filled by the groups in lexicographic
            order, the default group coming last.

        :param height: An optional ``height`` for the grid.
        :param width: An optional ``width`` for the grid. When both
            ``height`` and ``width`` are given, ``width`` drives the layout.
        :param ignore_groups: By default, ``set_grid`` will use groups to
            organize plots in different subplots. If you want to put
            every plot in a different subplot, regardless of their
            groups, you can set this to ``True``.
        :raises exc.InvalidParameterError: If there is no plot to lay out
            (raised by ``set_grid`` on the empty description).
        """
        if ignore_groups:
            # If we want to ignore groups, we start by creating a new group
            # for every existing plot.
            existing_plots = [plot_
                              for group_plots in self.plots.values()
                              for plot_ in group_plots]
            self.plots = collections.defaultdict(
                list,
                {chr(i): [plot_] for i, plot_ in enumerate(existing_plots)})

        # Only groups that actually hold plots take part in the layout.
        groups = sorted(symbol
                        for symbol, group_plots in self.plots.items()
                        if symbol != _DEFAULT_GROUP and len(group_plots) > 0)
        if len(self.plots[_DEFAULT_GROUP]) > 0:
            # Handle default group separately, it always comes last
            groups.append(_DEFAULT_GROUP)

        if not groups:
            # Nothing to lay out: let set_grid report the empty grid instead
            # of failing on a division by zero below.
            self.set_grid([])
            return

        # Find the layout. Only the width is needed to cut the groups into
        # rows, and an explicitly requested width is always honored.
        nb_groups = len(groups)
        if width is None:
            if height is None:
                _, width = _optimal_grid(nb_groups)
            else:
                width = math.ceil(nb_groups / height)

        # Apply the layout
        grid_description = ["".join(batch)
                            for batch in tools.batch(groups, width)]
        self.set_grid(grid_description)

    def plot(self, *args, **kwargs):
        """
        Plot something on the :class:`Figure` object.

        .. note:: This function expects ``args`` and ``kwargs`` to support
            every possible case. You can either pass it (see examples):

            - A single argument, being a series of points or a function.
            - Two series of points representing X values and Y values
              (standard :mod:`matplotlib` behavior).
            - Two arguments being a function and a list of points at which
              it should be evaluated (X values).
            - Two arguments being a function and an interval represented by
              a tuple of its bounds.

        .. note:: ``kwargs`` arguments are passed to
            ``matplotlib.pyplot.plot``.

        .. note:: You can use some :mod:`replot` specific keyword arguments:

            - ``group`` which permits to group plots together, in
              subplots (one unicode character maximum). ``group``
              keyword will not affect the render unless you state
              :mod:`replot` to use subplots. Note that ``_`` is a
              reserved group name which cannot be used.

        .. note:: Note that this API call considers list of tuples as
            list of (x, y) coordinates to plot, contrary to standard
            matplotlib API which considers it is two different plots.

        :raises exc.InvalidParameterError: If no positional argument is
            given.

        >>> with replot.Figure() as fig: fig.plot(np.sin, (-1, 1))
        >>> with replot.Figure() as fig: fig.plot(np.sin, [-1, -0.9, …, 1])
        >>> with replot.Figure() as fig: fig.plot([1, 2, 3], [4, 5, 6])
        >>> with replot.Figure() as fig: fig.plot([1, 2, 3],
                                                  [4, 5, 6], linewidth=2.0)
        >>> with replot.Figure() as fig: fig.plot([1, 2, 3],
                                                  [4, 5, 6], group="a")
        """
        if len(args) == 0:
            raise exc.InvalidParameterError(
                "You should pass at least one argument to this function.")

        # Extract custom kwargs (the ones from replot but not matplotlib)
        # from kwargs
        kwargs, custom_kwargs = _handle_custom_plot_arguments(kwargs)

        if callable(args[0]):
            # We want to plot a function
            plot_ = _plot_function(args[0], *(args[1:]), **kwargs)
        else:
            # Else, it is a point series, and we just have to store it for
            # later plotting.
            if hasattr(args[0], "__iter__"):
                try:
                    # If we pass it a list of tuples, consider it as a list
                    # of (x, y) coordinates contrary to the standard
                    # matplotlib behavior
                    x_list, y_list = zip(*args[0])
                except (TypeError, ValueError,
                        StopIteration, AssertionError):
                    # Not a series of (x, y) pairs (scalars, empty series,
                    # items that are not pairs): keep the arguments as is.
                    pass
                else:
                    args = (list(x_list), list(y_list)) + args[1:]
            plot_ = (args, kwargs)

        # Keep track of the custom kwargs
        plot_ += (custom_kwargs,)

        # Add the plot to the correct group
        group = custom_kwargs.get("group", _DEFAULT_GROUP)
        self.plots[group].append(plot_)

        # Automatically set the legend if label is found
        # (only do it if legend is not explicitly suppressed)
        if "label" in kwargs and self.legend is None:
            self.legend = True

    def logplot(self, *args, **kwargs):
        """
        Plot something on the :class:`Figure` object, in log scale.

        .. note:: See :func:`replot.Figure.plot` for the full documentation.

        >>> with replot.Figure() as fig: fig.logplot(np.log, (1, 10))
        """
        kwargs["logscale"] = "log"
        self.plot(*args, **kwargs)

    def loglogplot(self, *args, **kwargs):
        """
        Plot something on the :class:`Figure` object, in log-log scale.

        .. note:: See :func:`replot.Figure.plot` for the full documentation.

        >>> with replot.Figure() as fig: fig.loglogplot(np.log, (1, 10))
        """
        kwargs["logscale"] = "loglog"
        self.plot(*args, **kwargs)

    def _legend(self, axis, overload_legend=None):
        """
        Helper function to handle ``legend`` attribute. It places the legend
        correctly depending on attributes and required plots.

        :param axis: The :mod:`matplotlib` axis to put the legend on.
        :param overload_legend: An optional legend specification to use
            instead of the ``legend`` attribute.
        """
        if overload_legend is None:
            overload_legend = self.legend

        # If no legend is required, just pass
        if overload_legend is None or overload_legend is False:
            return

        if overload_legend is True:
            # If there should be a legend, but no location provided, put it
            # at best location.
            location = "best"
        else:
            location = overload_legend

        if isinstance(location, str):
            # Create aliases for "upper" / "top" and "lower" / "bottom".
            # str.replace returns a new string, so the result must be kept.
            location = location.replace("top ", "upper ")
            location = location.replace("bottom ", "lower ")

        # Avoid warning if no labels were given for plots
        has_labelled_plots = any("label" in plot_[1]
                                 for group in self.plots.values()
                                 for plot_ in group)
        if has_labelled_plots:
            # Add legend
            axis.legend(loc=location)

    def _grid(self):
        """
        Create subplots according to the grid description.

        :returns: A tuple containing the figure object as first element, and
            a dict mapping the symbols of the groups to matplotlib axes
            as second element.
        """
        if self.grid is False:
            # If grid is disabled, we plot every group in the same subplot.
            axes = {}
            figure, axis = plt.subplots()
            # Set the palette for the subplot, sized on the total number of
            # plots across all groups.
            nb_plots = sum(len(group_plots)
                           for group_plots in self.plots.values())
            axis.set_prop_cycle(self._build_cycler_palette(nb_plots))
            # Set the axis for every subplot
            for subplot in self.plots:
                axes[subplot] = axis
            return figure, axes
        elif self.grid is None:
            # If no grid is provided, create an auto grid for the figure.
            self._set_auto_grid()

        # Axes is a dict associating symbols to matplotlib axes
        axes = {}
        figure = plt.figure()
        # Build all the axes
        grid_size = (self.grid["height"], self.grid["width"])
        for subplot in self.grid["grid"]:
            position, symbol, (rowspan, colspan) = subplot
            axes[symbol] = plt.subplot2grid(grid_size,
                                            position,
                                            colspan=colspan,
                                            rowspan=rowspan)
            # Set the palette for the subplot
            axes[symbol].set_prop_cycle(
                self._build_cycler_palette(len(self.plots[symbol])))
        if _DEFAULT_GROUP not in axes:
            # Set the default group axis to None if it is not in the grid
            axes[_DEFAULT_GROUP] = None
        return figure, axes

    def _group_setting(self, setting, group_):
        """
        Resolve a setting which may be given per group.

        :param setting: Either a plain value applying to every subplot, or
            a dict mapping group symbols to values.
        :param group_: The symbol of the group being rendered.
        :returns: The value to use for ``group_``, or ``Figure._MISSING`` if
            ``setting`` is a dict without an entry for this group.
        """
        if not isinstance(setting, dict):
            return setting
        try:
            return setting[group_]
        except KeyError:
            return self._MISSING

    def _set_axes_properties(self, axis, group_):
        """
        Apply labels, title, legend and ranges on an axis.

        Every property can either be a plain value or a dict indexed by
        group symbols. Apart from the legend, a dict without an entry for
        ``group_`` leaves the matching property untouched.

        :param axis: The :mod:`matplotlib` axis to configure.
        :param group_: The symbol of the group rendered on this axis.
        """
        missing = self._MISSING

        # Set xlabel
        xlabel = self._group_setting(self.xlabel, group_)
        if xlabel is not missing:
            axis.set_xlabel(xlabel)

        # Set ylabel
        ylabel = self._group_setting(self.ylabel, group_)
        if ylabel is not missing:
            axis.set_ylabel(ylabel)

        # Set title
        title = self._group_setting(self.title, group_)
        if title is not missing:
            axis.set_title(title)

        # Set legend
        if isinstance(self.legend, dict):
            legend = self._group_setting(self.legend, group_)
            if legend is missing or legend is None:
                # No usable entry for this axis in the dict, use default
                # behavior: put a legend except if there is no labelled plot.
                legend = True
            self._legend(axis, overload_legend=legend)
        else:
            self._legend(axis)

        # Set xrange
        xrange_ = self._group_setting(self.xrange, group_)
        if xrange_ is not missing and xrange_ is not None:
            axis.set_xlim(*xrange_)

        # Set yrange
        yrange_ = self._group_setting(self.yrange, group_)
        if yrange_ is not missing and yrange_ is not None:
            axis.set_ylim(*yrange_)

    def _build_cycler_palette(self, n):
        """
        Build a cycler palette for the selected subplot.

        :param n: number of colors in the palette.
        :returns: a cycler object for the palette.
        """
        if callable(self.palette):
            return cycler.cycler("color", self.palette(n))
        return cycler.cycler("color", self.palette)

    def _render(self):
        """
        Actually render the figure.

        :returns: A :mod:`matplotlib` figure.
        """
        figure = None
        # Use custom matplotlib context
        with mpl_custom_rc_context(rc=self.custom_mpl_rc):
            # Create figure
            figure, axes = self._grid()
            # Add plots
            for group_ in self.plots:
                # Get the axis corresponding to current group
                try:
                    axis = axes[group_]
                except KeyError:
                    # If not found, plot in the default group
                    axis = axes[_DEFAULT_GROUP]
                # Skip this plot if the axis is None
                if axis is None:
                    continue
                # Plot
                for plot_ in self.plots[group_]:
                    tmp_plots = axis.plot(*(plot_[0]), **(plot_[1]))
                    # Handle custom kwargs at plotting time
                    if "logscale" in plot_[2]:
                        if plot_[2]["logscale"] == "log":
                            axis.set_xscale("log")
                        elif plot_[2]["logscale"] == "loglog":
                            axis.set_xscale("log")
                            axis.set_yscale("log")
                    # Do not clip line at the axes boundaries to prevent
                    # extremas from being cropped.
                    for tmp_plot in tmp_plots:
                        tmp_plot.set_clip_on(False)
                # Set ax properties
                self._set_axes_properties(axis, group_)
        return figure
def plot(data, **kwargs):
    """
    Helper function to make one-liner plots. Typical use case is:

    >>> replot.plot([range(10),
                     (np.sin, (-5, 5)),
                     (lambda x: np.sin(x) + 4, (-10, 10), {"linewidth": 10}),
                     (lambda x: np.sin(x) - 4, (-10, 10), {"linewidth": 10}),
                     ([-i for i in range(5)], {"linewidth": 10})],
                    xlabel="some x label",
                    ylabel="some y label",
                    title="A title for the figure",
                    legend="best",
                    palette=seaborn.color_palette("husl", 2))

    :param data: A list of plotting commands. Each command is either a
        single object to plot, or a tuple whose first item is the object to
        plot, optionally followed by a positional argument (an interval as a
        tuple, or X values as a list or array) and/or a dict of keyword
        arguments for ``Figure.plot``.
    :param kwargs: Keyword arguments passed to the :class:`Figure`
        constructor.
    """
    # Init new figure
    figure = Figure(**kwargs)

    # data is a list of plotting commands
    for plot_ in data:
        if not isinstance(plot_, tuple):
            figure.plot(plot_)
            continue

        # A tuple was provided, split it into positional and keyword
        # arguments for Figure.plot.
        plot_args = ()
        plot_kwargs = {}
        if len(plot_) == 2:
            # Only two items: the second one is either keyword arguments or
            # a positional argument (interval or X values).
            if isinstance(plot_[1], dict):
                plot_kwargs = plot_[1]
            elif isinstance(plot_[1], (tuple, list, np.ndarray)):
                plot_args = (plot_[1],)
        elif len(plot_) > 2:
            # At least three items: args and kwargs are well defined
            plot_args = (plot_[1],)
            plot_kwargs = plot_[2]
        # Pass the correct arguments to plot function
        figure.plot(plot_[0], *plot_args, **plot_kwargs)
    figure.show()
def _mpl_custom_rc_scaling():
"""
Scale the elements of the figure to get a better rendering.
Settings borrowed from
[Seaborn](https://github.com/mwaskom/seaborn/blob/master/seaborn/rcmod.py#L344).
:returns: a :mod:`matplotlib` ``rcParams``-like dict.
"""
rc_params = {
"figure.figsize": np.array([8, 5.5]),
# Set misc font sizes
"font.size": 12,
"axes.labelsize": 11,
"axes.titlesize": 12,
"xtick.labelsize": 10,
"ytick.labelsize": 10,
"legend.fontsize": 10,
# Set misc linewidth
"grid.linewidth": 1,
"lines.linewidth": 1.75,
"patch.linewidth": .3,
"lines.markersize": 7,
"lines.markeredgewidth": 1.75,
# Disable ticks
"xtick.major.width": 0,
"ytick.major.width": 0,
"xtick.minor.width": 0,
"ytick.minor.width": 0,
# Set ticks padding
"xtick.major.pad": 7,
"ytick.major.pad": 7,
}
return rc_params
def _mpl_custom_rc_axes_style():
"""
Set the style of the plot and the axes. Things like set a grid etc.
Settings borrowed from
[Seaborn](https://github.com/mwaskom/seaborn/blob/master/seaborn/rcmod.py#L344).
:returns: a :mod:`matplotlib` ``rcParams``-like dict.
"""
# Use dark gray instead of black for better readability on screen
dark_gray = ".15"
rc_params = {
# Colors
"figure.facecolor": "white",
"text.color": dark_gray,
# Legend
"legend.frameon": False, # No frame around legend
"legend.numpoints": 1,
"legend.scatterpoints": 1,
# Ticks
"xtick.direction": "out",
"ytick.direction": "out",
"xtick.color": dark_gray,
"ytick.color": dark_gray,
"lines.solid_capstyle": "round",
# Axes
"axes.axisbelow": True,
"axes.linewidth": 0,
"axes.labelcolor": dark_gray,
"axes.grid": True,
"axes.facecolor": "EAEAF2",
"axes.edgecolor": "white",
# Grid
"grid.linestyle": "-",
"grid.color": "white",
# Image
"image.cmap": "Greys"
}
return rc_params
def mpl_custom_rc_context(rc=None):
    """
    Overload ``matplotlib.rcParams`` to enable advanced features if
    available. In particular, use LaTeX if available.

    :param rc: An optional dict to overload some :mod:`matplotlib` rc params.
        It is applied last, so it takes precedence over the defaults built
        here.
    :returns: A ``matplotlib.rc_context`` object to use in a ``with``
        statement.
    """
    custom_rc = {}

    # Add LaTeX in rc if all its dependencies are found on the PATH
    latex_dependencies = ("latex", "gs", "dvipng")
    if all(shutil.which(tool) is not None for tool in latex_dependencies):
        custom_rc["text.usetex"] = True
        # ``text.latex.unicode`` was removed from recent matplotlib releases
        # and setting an unknown rc param raises a KeyError: only set it
        # where it still exists.
        if "text.latex.unicode" in mpl.rcParams:
            custom_rc["text.latex.unicode"] = True
        # Use LaTeX default font family
        # See https://stackoverflow.com/questions/17958485/matplotlib-not-using-latex-font-while-text-usetex-true
        custom_rc["font.family"] = "serif"
        custom_rc["font.serif"] = "cm"

    # Scale everything
    custom_rc.update(_mpl_custom_rc_scaling())
    # Set axes style
    custom_rc.update(_mpl_custom_rc_axes_style())
    # Overload if necessary
    if rc is not None:
        custom_rc.update(rc)
    # Return a context object
    return plt.rc_context(rc=custom_rc)
def _handle_custom_plot_arguments(kwargs):
"""
This method handles custom keyword arguments from plot in \
:mod:`replot` which are not in :mod:`matplotlib` function.
:param kwargs: A dictionary of keyword arguments to handle.
:return: A tuple of :mod:`matplotlib` compatible keyword arguments \
and of extra :mod:`replot` keyword arguments, both returned \
as ``dict``.
"""
custom_kwargs = {}
# Handle "group" argument
if "group" in kwargs:
if len(kwargs["group"]) > 1:
raise exc.InvalidParameterError(
"Group name cannot be longer than one unicode character.")
elif kwargs["group"] == _DEFAULT_GROUP:
raise exc.InvalidParameterError(
"'%s' is a reserved group name." % (_DEFAULT_GROUP,))
custom_kwargs["group"] = kwargs["group"]
del kwargs["group"]
# Handle "line" argument
if "line" in kwargs:
if not kwargs["line"]: # If should not draw lines, set kwargs for it
kwargs["linestyle"] = "None"
kwargs["marker"] = "x"
del kwargs["line"]
# Handle "logscale" argument
if "logscale" in kwargs:
custom_kwargs["logscale"] = kwargs["logscale"]
del kwargs["logscale"]
return (kwargs, custom_kwargs)
def _plot_function(data, *args, **kwargs):
"""
Helper function to handle plotting of unevaluated functions (trying \
to evaluate it nicely and rendering the plot).
:param data: The function to plot.
:returns: A tuple of ``(args, kwargs)`` representing the plot.
.. seealso:: The documentation of the ``replot.Figure.plot`` method.
.. note:: ``args`` is used to handle the interval or point series on \
which the function should be evaluated. ``kwargs`` are passed \
directly to ``matplotlib.pyplot.plot`.
"""
if len(args) == 0:
# If no interval specified, raise an issue
raise exc.InvalidParameterError(
"You should pass a plotting interval to the plot command.")
elif isinstance(args[0], tuple):
# Interval specified, use it and adaptive plotting
x_values, y_values = adaptive_sampling.sample_function(
data,
args[0],
tol=1e-3)
elif isinstance(args[0], (list, np.ndarray)):
# List of points specified, use them and compute values of the
# function
x_values = args[0]
y_values = [data(i) for i in x_values]
else:
raise exc.InvalidParameterError(
"Second parameter in plot command should be a tuple " +
"specifying plotting interval.")
return ((x_values, y_values) + args[1:], kwargs)
def _optimal_grid(nb_items):
"""
(Naive) attempt to find an optimal grid layout for N elements.
:param nb_items: The number of square elements to put on the grid.
:returns: A tuple ``(height, width)`` containing the number of rows and \
the number of cols of the resulting grid.
>>> _optimal_grid(2)
(1, 2)
>>> _optimal_grid(3)
(1, 3)
>>> _optimal_grid(4)
(2, 2)
"""
# Compute first possibility
height1 = math.floor(math.sqrt(nb_items))
width1 = math.ceil(nb_items / height1)
# Compute second possibility
width2 = math.ceil(math.sqrt(nb_items))
height2 = math.ceil(nb_items / width2)
# Minimize the product of height and width
if height1 * width1 < height2 * width2:
height, width = height1, width1
else:
height, width = height2, width2
return (height, width)