Handle subplots. Close issue https://github.com/Phyks/replot/issues/5
This commit is contained in:
parent
a373f6f106
commit
2ef75513ed
7986
Examples.ipynb
7986
Examples.ipynb
File diff suppressed because one or more lines are too long
@ -19,6 +19,7 @@ import seaborn.apionly as sns
|
|||||||
|
|
||||||
from replot import adaptive_sampling
|
from replot import adaptive_sampling
|
||||||
from replot import exceptions as exc
|
from replot import exceptions as exc
|
||||||
|
from replot import grid_parser
|
||||||
from replot import tools
|
from replot import tools
|
||||||
|
|
||||||
|
|
||||||
@ -36,7 +37,7 @@ class Figure():
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
xlabel="", ylabel="", title="",
|
xlabel="", ylabel="", title="",
|
||||||
palette="hls", max_colors=10,
|
palette="hls", max_colors=10,
|
||||||
legend=None, savepath=None):
|
legend=None, savepath=None, grid=None):
|
||||||
"""
|
"""
|
||||||
Build a :class:`Figure` object.
|
Build a :class:`Figure` object.
|
||||||
|
|
||||||
@ -57,6 +58,17 @@ class Figure():
|
|||||||
position.
|
position.
|
||||||
:param savepath: A path to save the image to (optional). If set, \
|
:param savepath: A path to save the image to (optional). If set, \
|
||||||
the image will be saved on exiting a `with` statement.
|
the image will be saved on exiting a `with` statement.
|
||||||
|
:param grid: A dict containing the width and height of the grid, and \
|
||||||
|
a description of the grid as a list of subplots. Each subplot \
|
||||||
|
is a tuple of \
|
||||||
|
``((y_position, x_position), symbol, (rowspan, colspan))``. \
|
||||||
|
No check for grid validity is being made.
|
||||||
|
|
||||||
|
.. note:: If you use group plotting, ``xlabel``, ``ylabel`` and \
|
||||||
|
``legend`` will be set uniformly for every subplot. If you \
|
||||||
|
wish to set different properties for every subplots, you \
|
||||||
|
should pass a dict for these properties, keys being the \
|
||||||
|
group symbols and values being the value for each subplot.
|
||||||
"""
|
"""
|
||||||
# Set default values for attributes
|
# Set default values for attributes
|
||||||
self.xlabel = xlabel
|
self.xlabel = xlabel
|
||||||
@ -66,6 +78,7 @@ class Figure():
|
|||||||
self.max_colors = max_colors
|
self.max_colors = max_colors
|
||||||
self.legend = legend
|
self.legend = legend
|
||||||
self.plots = collections.defaultdict(list) # keys are groups
|
self.plots = collections.defaultdict(list) # keys are groups
|
||||||
|
self.grid = grid
|
||||||
self.savepath = savepath
|
self.savepath = savepath
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
@ -119,6 +132,9 @@ class Figure():
|
|||||||
.. note:: Groups are a single unicode character. If a group does not \
|
.. note:: Groups are a single unicode character. If a group does not \
|
||||||
contain any plot, the resulting subplot will simply be empty.
|
contain any plot, the resulting subplot will simply be empty.
|
||||||
|
|
||||||
|
.. note:: Note that if you do not include the default group in the \
|
||||||
|
grid description, it will not be shown.
|
||||||
|
|
||||||
>>> with replot.Figure() as fig: fig.apply_grid(["AAA",
|
>>> with replot.Figure() as fig: fig.apply_grid(["AAA",
|
||||||
"BBC"
|
"BBC"
|
||||||
"DEC"])
|
"DEC"])
|
||||||
@ -131,14 +147,20 @@ class Figure():
|
|||||||
if len(row) != len(grid_description[0]):
|
if len(row) != len(grid_description[0]):
|
||||||
raise exc.InvalidParameterError(
|
raise exc.InvalidParameterError(
|
||||||
"All rows must have the same number of elements.")
|
"All rows must have the same number of elements.")
|
||||||
# TODO: Parse grid
|
# Parse the ASCII art grid
|
||||||
|
self.grid = grid_parser.parse_ascii(grid_description)
|
||||||
|
if self.grid is None:
|
||||||
|
# If grid is not valid, raise an exception
|
||||||
|
raise exc.InvalidParameterError(
|
||||||
|
"Invalid grid provided. You did not use rectangular areas " +
|
||||||
|
"for each group.")
|
||||||
|
|
||||||
def gridify(self, height=None, width=None):
|
def gridify(self, height=None, width=None):
|
||||||
"""
|
"""
|
||||||
Apply an automatic grid on the figure, trying to fit best to the \
|
Apply an automatic grid on the figure, trying to fit best to the \
|
||||||
number of plots.
|
number of plots.
|
||||||
|
|
||||||
.. note:: This method must be called after all the groups of plots \
|
.. note:: This method must be called after all the plots \
|
||||||
have been added to the figure, or the grid will miss some \
|
have been added to the figure, or the grid will miss some \
|
||||||
groups.
|
groups.
|
||||||
|
|
||||||
@ -158,9 +180,12 @@ class Figure():
|
|||||||
height, width = _optimal_grid(nb_groups)
|
height, width = _optimal_grid(nb_groups)
|
||||||
|
|
||||||
# Apply the layout
|
# Apply the layout
|
||||||
groups = (
|
groups = sorted([k
|
||||||
sorted([k for k in self.plots.keys() if k != _DEFAULT_GROUP]) +
|
for k in self.plots.keys()
|
||||||
[_DEFAULT_GROUP])
|
if k != _DEFAULT_GROUP and len(self.plots[k]) > 0])
|
||||||
|
if len(self.plots[_DEFAULT_GROUP]) > 0:
|
||||||
|
# Handle default group separately
|
||||||
|
groups.append(_DEFAULT_GROUP)
|
||||||
grid_description = ["".join(batch)
|
grid_description = ["".join(batch)
|
||||||
for batch in tools.batch(groups, width)]
|
for batch in tools.batch(groups, width)]
|
||||||
self.apply_grid(grid_description)
|
self.apply_grid(grid_description)
|
||||||
@ -229,23 +254,28 @@ class Figure():
|
|||||||
if "label" in kwargs and self.legend is None:
|
if "label" in kwargs and self.legend is None:
|
||||||
self.legend = True
|
self.legend = True
|
||||||
|
|
||||||
def _legend(self, axes):
|
def _legend(self, axe, overload_legend=None):
|
||||||
"""
|
"""
|
||||||
Helper function to handle ``legend`` attribute. It places the legend \
|
Helper function to handle ``legend`` attribute. It places the legend \
|
||||||
correctly depending on attributes and required plots.
|
correctly depending on attributes and required plots.
|
||||||
|
|
||||||
:param axes: The :mod:`matplotlib` axes to put the legend on.
|
:param axe: The :mod:`matplotlib` axe to put the legend on.
|
||||||
|
:param overload_legend: An optional legend specification to use \
|
||||||
|
instead of the ``legend`` attribute.
|
||||||
"""
|
"""
|
||||||
|
if overload_legend is None:
|
||||||
|
overload_legend = self.legend
|
||||||
|
|
||||||
# If no legend is required, just pass
|
# If no legend is required, just pass
|
||||||
if self.legend is None or self.legend is False:
|
if overload_legend is None or overload_legend is False:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.legend is True:
|
if overload_legend is True:
|
||||||
# If there should be a legend, but no location provided, put it at
|
# If there should be a legend, but no location provided, put it at
|
||||||
# best location.
|
# best location.
|
||||||
location = "best"
|
location = "best"
|
||||||
else:
|
else:
|
||||||
location = self.legend
|
location = overload_legend
|
||||||
# Create aliases for "upper" / "top" and "lower" / "bottom"
|
# Create aliases for "upper" / "top" and "lower" / "bottom"
|
||||||
location.replace("top ", "upper ")
|
location.replace("top ", "upper ")
|
||||||
location.replace("bottom ", "lower ")
|
location.replace("bottom ", "lower ")
|
||||||
@ -255,7 +285,76 @@ class Figure():
|
|||||||
for plt in group])
|
for plt in group])
|
||||||
if nb_labelled_plots > 0:
|
if nb_labelled_plots > 0:
|
||||||
# Add legend
|
# Add legend
|
||||||
axes.legend(loc=location)
|
axe.legend(loc=location)
|
||||||
|
|
||||||
|
def _grid(self):
|
||||||
|
"""
|
||||||
|
Create subplots according to the grid description.
|
||||||
|
|
||||||
|
:returns: A tuple containing the figure object as first element, and \
|
||||||
|
a dict mapping the symbols of the groups to matplotlib axes \
|
||||||
|
as second element.
|
||||||
|
"""
|
||||||
|
if self.grid is None:
|
||||||
|
# If no grid is provided, return a figure and a dict containing
|
||||||
|
# only the default group, occupying the whole figure.
|
||||||
|
figure, axes = plt.subplots()
|
||||||
|
return figure, {_DEFAULT_GROUP: axes}
|
||||||
|
else:
|
||||||
|
# Axes is a dict associating symbols to matplotlib axes
|
||||||
|
axes = {}
|
||||||
|
figure = plt.figure()
|
||||||
|
# Build all the axes
|
||||||
|
grid_size = (self.grid["height"], self.grid["width"])
|
||||||
|
for subplot in self.grid["grid"]:
|
||||||
|
position, symbol, (rowspan, colspan) = subplot
|
||||||
|
axes[symbol] = plt.subplot2grid(grid_size,
|
||||||
|
position,
|
||||||
|
colspan=colspan,
|
||||||
|
rowspan=rowspan)
|
||||||
|
if _DEFAULT_GROUP not in axes:
|
||||||
|
# Set the default group axe to None if it is not in the grid
|
||||||
|
axes[_DEFAULT_GROUP] = None
|
||||||
|
return figure, axes
|
||||||
|
|
||||||
|
def _set_axes_properties(self, axe, group_):
|
||||||
|
# Set xlabel
|
||||||
|
if isinstance(self.xlabel, dict):
|
||||||
|
try:
|
||||||
|
axe.set_xlabel(self.xlabel[group_])
|
||||||
|
except KeyError:
|
||||||
|
# No entry for this axe in the dict, pass it
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
axe.set_xlabel(self.xlabel)
|
||||||
|
# Set ylabel
|
||||||
|
if isinstance(self.ylabel, dict):
|
||||||
|
try:
|
||||||
|
axe.set_ylabel(self.ylabel[group_])
|
||||||
|
except KeyError:
|
||||||
|
# No entry for this axe in the dict, pass it
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
axe.set_ylabel(self.ylabel)
|
||||||
|
# Set title
|
||||||
|
if isinstance(self.title, dict):
|
||||||
|
try:
|
||||||
|
axe.set_ylabel(self.title[group_])
|
||||||
|
except KeyError:
|
||||||
|
# No entry for this axe in the dict, pass it
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
axe.set_title(self.title)
|
||||||
|
# Set legend
|
||||||
|
if isinstance(self.legend, dict):
|
||||||
|
try:
|
||||||
|
self._legend(axe, overload_legend=self.legend[group_])
|
||||||
|
except KeyError:
|
||||||
|
# No entry for this axe in the dict, use default argument
|
||||||
|
# That is put a legend except if no plots on this axe.
|
||||||
|
self._legend(axe, overload_legend=True)
|
||||||
|
else:
|
||||||
|
self._legend(axe)
|
||||||
|
|
||||||
def _render(self):
|
def _render(self):
|
||||||
"""
|
"""
|
||||||
@ -272,20 +371,27 @@ class Figure():
|
|||||||
with sns.color_palette(palette=self.palette,
|
with sns.color_palette(palette=self.palette,
|
||||||
n_colors=self.max_colors):
|
n_colors=self.max_colors):
|
||||||
# Create figure
|
# Create figure
|
||||||
figure, axes = plt.subplots()
|
figure, axes = self._grid()
|
||||||
# Add plots
|
# Add plots
|
||||||
for group_ in self.plots:
|
for group_ in self.plots:
|
||||||
|
# Get the axe corresponding to current group
|
||||||
|
try:
|
||||||
|
axe = axes[group_]
|
||||||
|
except KeyError:
|
||||||
|
# If not found, plot in the default group
|
||||||
|
axe = axes[_DEFAULT_GROUP]
|
||||||
|
# Skip this plot if the axe is None
|
||||||
|
if axe is None:
|
||||||
|
continue
|
||||||
|
# Plot
|
||||||
for plot_ in self.plots[group_]:
|
for plot_ in self.plots[group_]:
|
||||||
tmp_plots = axes.plot(*(plot_[0]), **(plot_[1]))
|
tmp_plots = axe.plot(*(plot_[0]), **(plot_[1]))
|
||||||
# Do not clip line at the axes boundaries to prevent
|
# Do not clip line at the axes boundaries to prevent
|
||||||
# extremas from being cropped.
|
# extremas from being cropped.
|
||||||
for tmp_plot in tmp_plots:
|
for tmp_plot in tmp_plots:
|
||||||
tmp_plot.set_clip_on(False)
|
tmp_plot.set_clip_on(False)
|
||||||
# Set properties
|
# Set ax properties
|
||||||
axes.set_xlabel(self.xlabel)
|
self._set_axes_properties(axe, group_)
|
||||||
axes.set_ylabel(self.ylabel)
|
|
||||||
axes.set_title(self.title)
|
|
||||||
self._legend(axes)
|
|
||||||
# Do not forget to restore matplotlib state, in order not to
|
# Do not forget to restore matplotlib state, in order not to
|
||||||
# interfere with it.
|
# interfere with it.
|
||||||
sns.reset_orig()
|
sns.reset_orig()
|
||||||
|
@ -14,7 +14,8 @@ def parse_ascii(M):
|
|||||||
:param M: A list of strings, each string representing a row.
|
:param M: A list of strings, each string representing a row.
|
||||||
:returns: A dict containing the width and height of the grid, and a \
|
:returns: A dict containing the width and height of the grid, and a \
|
||||||
description of the grid as a list of subplots. Each subplot is a \
|
description of the grid as a list of subplots. Each subplot is a \
|
||||||
tuple of ((y_position, x_position), symbol, (rowspan, colspan)). \
|
tuple of \
|
||||||
|
``((y_position, x_position), symbol, (rowspan, colspan))``. \
|
||||||
Returns ``None`` if the matrix could not be parsed.
|
Returns ``None`` if the matrix could not be parsed.
|
||||||
|
|
||||||
.. note:: This function expects ``M`` to represent a valid rectangular \
|
.. note:: This function expects ``M`` to represent a valid rectangular \
|
||||||
|
Loading…
Reference in New Issue
Block a user