Add a save method + bugfixes

* Add a `save` method to save images, and a `savepath` argument.
* Fix a bug with legend, due to the introduction of plot groups.
* Continue work on `gridify` method, still WIP.
This commit is contained in:
Lucas Verney 2016-03-03 12:11:15 +01:00
parent 2067e2008f
commit a91da5e77b
4 changed files with 1102 additions and 66 deletions

File diff suppressed because one or more lines are too long

View File

@ -22,7 +22,7 @@ design in the API, or required feature!
<dd>Ever got tired of having to start any figure with a call to
`matplotlib.pyplot.subplots()`? This module abstracts it using `with`
statement. New figures are defined by a `with` statement, and are `show`n
automatically upon leaving the `with` context.
automatically (or `save`d) upon leaving the `with` context.
<dt>Plot functions</dt>
<dd>Ever got annoyed by the fact that `matplotlib` can only plot point
@ -61,6 +61,9 @@ design in the API, or required feature!
<dt>Handle subplots more easily</dt>
<dd>**TODO**</dd>
<dt>"Gridify"</dt>
<dd>**TODO**</dd>
</dl>

View File

@ -3,6 +3,7 @@ The :mod:`replot` module is a (sane) Python plotting module, abstracting on top
of Matplotlib.
"""
import collections
import math
import shutil
import matplotlib.pyplot as plt
@ -11,10 +12,14 @@ import seaborn.apionly as sns
from replot import adaptive_sampling
from replot import exceptions as exc
from replot import tools
__VERSION__ = "0.0.1"
# Constants
_DEFAULT_GROUP = "_"
class Figure():
"""
@ -24,7 +29,7 @@ class Figure():
def __init__(self,
xlabel="", ylabel="", title="",
palette="hls", max_colors=10,
legend=None):
legend=None, savepath=None):
"""
Build a :class:`Figure` object.
@ -43,6 +48,8 @@ class Figure():
behavior. A string indicating position (:mod:`matplotlib` \
format) to put a legend. ``True`` is a synonym for ``best`` \
position.
:param savepath: A path to save the image to (optional). If set, \
the image will be saved on exiting a `with` statement.
"""
# Set default values for attributes
self.xlabel = xlabel
@ -52,46 +59,46 @@ class Figure():
self.max_colors = max_colors
self.legend = legend
self.plots = collections.defaultdict(list) # keys are groups
self.savepath = savepath
def __enter__(self):
return self
def __exit__(self, exception_type, exception_value, traceback):
# Do not render the figure if an exception was raised
if exception_type is None:
# Do not draw the figure if an exception was raised
self.show()
if self.savepath is not None:
self.save()
else:
self.show()
def save(self, *args, **kwargs):
"""
Render and save the figure. Also show the figure if possible.
.. note:: Signature is the same as the one from \
``matplotlib.pyplot.savefig``. You should refer to its \
documentation for lengthy details.
.. note:: Typical use is:
>>> with replot.Figure() as figure: figure.save("SOME_FILENAME")
.. note:: If a ``savepath`` has been provided to the :class:`Figure` \
object, you can call ``figure.save()`` without argument and \
this path will be used.
"""
if len(args) == 0 and self.savepath is not None:
args = (self.savepath,)
figure = self._render()
figure.savefig(*args, **kwargs)
def show(self):
"""
Actually render and show the :class:`Figure` object.
Render and show the :class:`Figure` object.
"""
# Use custom matplotlib context
with mpl_custom_rc_context():
# Tweak matplotlib to use seaborn
sns.set()
# Plot using specified color palette
with sns.color_palette(palette=self.palette,
n_colors=self.max_colors):
# Create figure
figure, axes = plt.subplots()
# Add plots
for group_ in self.plots:
for plot_ in self.plots[group_]:
tmp_plots = axes.plot(*(plot_[0]), **(plot_[1]))
# Do not clip line at the axes boundaries to prevent
# extremas from being cropped.
for tmp_plot in tmp_plots:
tmp_plot.set_clip_on(False)
# Set properties
axes.set_xlabel(self.xlabel)
axes.set_ylabel(self.ylabel)
axes.set_title(self.title)
self._legend(axes)
# Draw figure
figure.show()
# Do not forget to restore matplotlib state, in order not to
# interfere with it.
sns.reset_orig()
figure = self._render()
figure.show()
def apply_grid(self, grid_description):
"""
@ -119,6 +126,38 @@ class Figure():
"All rows must have the same number of elements.")
# TODO: Parse grid
def gridify(self, height=None, width=None):
"""
Apply an automatic grid on the figure, trying to fit best to the \
number of plots.
.. note:: This method must be called after all the groups of plots \
have been added to the figure, or the grid will miss some \
groups.
.. note:: The grid will be filled by the groups in lexicographic \
order. Unassigned plots go to the last subplot.
:param height: An optional ``height`` for the grid.
:param width: An optional ``height`` for the grid.
"""
# Find the optimal layout
nb_groups = len(self.plots)
if height is None and width is not None:
height = math.ceil(nb_groups / width)
elif width is None and height is not None:
width = math.ceil(nb_groups / height)
else:
height, width = _optimal_grid(nb_groups)
# Apply the layout
groups = (
sorted([k for k in self.plots.keys() if k != _DEFAULT_GROUP]) +
[_DEFAULT_GROUP])
grid_description = ["".join(batch)
for batch in tools.batch(groups, width)]
self.apply_grid(grid_description)
def plot(self, *args, **kwargs):
"""
Plot something on the :class:`Figure` object.
@ -142,7 +181,8 @@ class Figure():
- ``group`` which permits to group plots together, in \
subplots (one unicode character maximum). ``group`` \
keyword will not affect the render unless you state \
:mod:`replot` to use subplots.
:mod:`replot` to use subplots. Note that ``_`` is a \
reserved group name which cannot be used.
>>> with replot.figure() as fig: fig.plot(np.sin, (-1, 1))
>>> with replot.figure() as fig: fig.plot(np.sin, [-1, -0.9, , 1])
@ -175,7 +215,7 @@ class Figure():
if "group" in custom_kwargs:
self.plots[custom_kwargs["group"]].append(plot_)
else:
self.plots["default"].append(plot_)
self.plots[_DEFAULT_GROUP].append(plot_)
# Automatically set the legend if label is found
# (only do it if legend is not explicitly suppressed)
@ -203,11 +243,47 @@ class Figure():
location.replace("top ", "upper ")
location.replace("bottom ", "lower ")
# Avoid warning if no labels were given for plots
nb_labelled_plots = sum(["label" in plt[1] for plt in self.plots])
nb_labelled_plots = sum(["label" in plt[1]
for group in self.plots.values()
for plt in group])
if nb_labelled_plots > 0:
# Add legend
axes.legend(loc=location)
def _render(self):
"""
Actually render the figure.
:returns: A :mod:`matplotlib` figure.
"""
figure = None
# Use custom matplotlib context
with mpl_custom_rc_context():
# Tweak matplotlib to use seaborn
sns.set()
# Plot using specified color palette
with sns.color_palette(palette=self.palette,
n_colors=self.max_colors):
# Create figure
figure, axes = plt.subplots()
# Add plots
for group_ in self.plots:
for plot_ in self.plots[group_]:
tmp_plots = axes.plot(*(plot_[0]), **(plot_[1]))
# Do not clip line at the axes boundaries to prevent
# extremas from being cropped.
for tmp_plot in tmp_plots:
tmp_plot.set_clip_on(False)
# Set properties
axes.set_xlabel(self.xlabel)
axes.set_ylabel(self.ylabel)
axes.set_title(self.title)
self._legend(axes)
# Do not forget to restore matplotlib state, in order not to
# interfere with it.
sns.reset_orig()
return figure
def plot(data, **kwargs):
"""
@ -286,6 +362,9 @@ def _handle_custom_plot_arguments(kwargs):
if len(kwargs["group"]) > 1:
raise exc.InvalidParameterError(
"Group name cannot be longer than one unicode character.")
elif kwargs["group"] == _DEFAULT_GROUP:
raise exc.InvalidParameterError(
"'%s' is a reserved group name." % (_DEFAULT_GROUP,))
custom_kwargs["group"] = kwargs["group"]
del kwargs["group"]
return (kwargs, custom_kwargs)
@ -325,3 +404,37 @@ def _plot_function(data, *args, **kwargs):
"Second parameter in plot command should be a tuple " +
"specifying plotting interval.")
return ((x_values, y_values) + args[1:], kwargs)
def _optimal_grid(nb_items):
"""
(Naive) attempt to find an optimal grid layout for N elements.
:param nb_items: The number of square elements to put on the grid.
:returns: A tuple ``(height, width)`` containing the number of rows and \
the number of cols of the resulting grid.
>>> _optimal_grid(2)
(1, 2)
>>> _optimal_grid(3)
(1, 3)
>>> _optimal_grid(4)
(2, 2)
"""
# Compute first possibility
height1 = math.floor(math.sqrt(nb_items))
width1 = math.ceil(nb_items / height1)
# Compute second possibility
width2 = math.ceil(math.sqrt(nb_items))
height2 = math.ceil(nb_items / width2)
# Minimize the product of height and width
if height1 * width1 < height2 * width2:
height, width = height1, width1
else:
height, width = height2, width2
return (height, width)

34
replot/tools.py Normal file
View File

@ -0,0 +1,34 @@
"""
This file contains various utility functions.
"""
from itertools import islice, chain
def batch(iterable, size):
"""
Get items from a sequence a batch at a time.
.. note:
Adapted from
https://code.activestate.com/recipes/303279-getting-items-in-batches/.
.. note:
All batches must be exhausted immediately.
:params iterable: An iterable to get batches from.
:params size: Size of the batches.
:returns: A new batch of the given size at each time.
>>> [list(i) for i in batch([1, 2, 3, 4, 5], 2)]
[[1, 2], [3, 4], [5]]
"""
item = iter(iterable)
while True:
batch_iterator = islice(item, size)
try:
yield chain([next(batch_iterator)], batch_iterator)
except StopIteration:
return