Add a save
method + bugfixes
* Add a `save` method to save images, and a `savepath` argument. * Fix a bug with legend, due to the introduction of plot groups. * Continue work on `gridify` method, still WIP.
This commit is contained in:
parent
2067e2008f
commit
a91da5e77b
948
Examples.ipynb
948
Examples.ipynb
File diff suppressed because one or more lines are too long
@ -22,7 +22,7 @@ design in the API, or required feature!
|
||||
<dd>Ever got tired of having to start any figure with a call to
|
||||
`matplotlib.pyplot.subplots()`? This module abstracts it using `with`
|
||||
statement. New figures are defined by a `with` statement, and are `show`n
|
||||
automatically upon leaving the `with` context.
|
||||
automatically (or `save`d) upon leaving the `with` context.
|
||||
|
||||
<dt>Plot functions</dt>
|
||||
<dd>Ever got annoyed by the fact that `matplotlib` can only plot point
|
||||
@ -61,6 +61,9 @@ design in the API, or required feature!
|
||||
|
||||
<dt>Handle subplots more easily</dt>
|
||||
<dd>**TODO**</dd>
|
||||
|
||||
<dt>"Gridify"</dt>
|
||||
<dd>**TODO**</dd>
|
||||
</dl>
|
||||
|
||||
|
||||
|
@ -3,6 +3,7 @@ The :mod:`replot` module is a (sane) Python plotting module, abstracting on top
|
||||
of Matplotlib.
|
||||
"""
|
||||
import collections
|
||||
import math
|
||||
import shutil
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@ -11,10 +12,14 @@ import seaborn.apionly as sns
|
||||
|
||||
from replot import adaptive_sampling
|
||||
from replot import exceptions as exc
|
||||
from replot import tools
|
||||
|
||||
|
||||
__VERSION__ = "0.0.1"
|
||||
|
||||
# Constants
|
||||
_DEFAULT_GROUP = "_"
|
||||
|
||||
|
||||
class Figure():
|
||||
"""
|
||||
@ -24,7 +29,7 @@ class Figure():
|
||||
def __init__(self,
|
||||
xlabel="", ylabel="", title="",
|
||||
palette="hls", max_colors=10,
|
||||
legend=None):
|
||||
legend=None, savepath=None):
|
||||
"""
|
||||
Build a :class:`Figure` object.
|
||||
|
||||
@ -43,6 +48,8 @@ class Figure():
|
||||
behavior. A string indicating position (:mod:`matplotlib` \
|
||||
format) to put a legend. ``True`` is a synonym for ``best`` \
|
||||
position.
|
||||
:param savepath: A path to save the image to (optional). If set, \
|
||||
the image will be saved on exiting a `with` statement.
|
||||
"""
|
||||
# Set default values for attributes
|
||||
self.xlabel = xlabel
|
||||
@ -52,46 +59,46 @@ class Figure():
|
||||
self.max_colors = max_colors
|
||||
self.legend = legend
|
||||
self.plots = collections.defaultdict(list) # keys are groups
|
||||
self.savepath = savepath
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exception_type, exception_value, traceback):
|
||||
# Do not render the figure if an exception was raised
|
||||
if exception_type is None:
|
||||
# Do not draw the figure if an exception was raised
|
||||
self.show()
|
||||
if self.savepath is not None:
|
||||
self.save()
|
||||
else:
|
||||
self.show()
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
"""
|
||||
Render and save the figure. Also show the figure if possible.
|
||||
|
||||
.. note:: Signature is the same as the one from \
|
||||
``matplotlib.pyplot.savefig``. You should refer to its \
|
||||
documentation for lengthy details.
|
||||
|
||||
.. note:: Typical use is:
|
||||
>>> with replot.Figure() as figure: figure.save("SOME_FILENAME")
|
||||
|
||||
.. note:: If a ``savepath`` has been provided to the :class:`Figure` \
|
||||
object, you can call ``figure.save()`` without argument and \
|
||||
this path will be used.
|
||||
"""
|
||||
if len(args) == 0 and self.savepath is not None:
|
||||
args = (self.savepath,)
|
||||
|
||||
figure = self._render()
|
||||
figure.savefig(*args, **kwargs)
|
||||
|
||||
def show(self):
|
||||
"""
|
||||
Actually render and show the :class:`Figure` object.
|
||||
Render and show the :class:`Figure` object.
|
||||
"""
|
||||
# Use custom matplotlib context
|
||||
with mpl_custom_rc_context():
|
||||
# Tweak matplotlib to use seaborn
|
||||
sns.set()
|
||||
# Plot using specified color palette
|
||||
with sns.color_palette(palette=self.palette,
|
||||
n_colors=self.max_colors):
|
||||
# Create figure
|
||||
figure, axes = plt.subplots()
|
||||
# Add plots
|
||||
for group_ in self.plots:
|
||||
for plot_ in self.plots[group_]:
|
||||
tmp_plots = axes.plot(*(plot_[0]), **(plot_[1]))
|
||||
# Do not clip line at the axes boundaries to prevent
|
||||
# extremas from being cropped.
|
||||
for tmp_plot in tmp_plots:
|
||||
tmp_plot.set_clip_on(False)
|
||||
# Set properties
|
||||
axes.set_xlabel(self.xlabel)
|
||||
axes.set_ylabel(self.ylabel)
|
||||
axes.set_title(self.title)
|
||||
self._legend(axes)
|
||||
# Draw figure
|
||||
figure.show()
|
||||
# Do not forget to restore matplotlib state, in order not to
|
||||
# interfere with it.
|
||||
sns.reset_orig()
|
||||
figure = self._render()
|
||||
figure.show()
|
||||
|
||||
def apply_grid(self, grid_description):
|
||||
"""
|
||||
@ -119,6 +126,38 @@ class Figure():
|
||||
"All rows must have the same number of elements.")
|
||||
# TODO: Parse grid
|
||||
|
||||
def gridify(self, height=None, width=None):
|
||||
"""
|
||||
Apply an automatic grid on the figure, trying to fit best to the \
|
||||
number of plots.
|
||||
|
||||
.. note:: This method must be called after all the groups of plots \
|
||||
have been added to the figure, or the grid will miss some \
|
||||
groups.
|
||||
|
||||
.. note:: The grid will be filled by the groups in lexicographic \
|
||||
order. Unassigned plots go to the last subplot.
|
||||
|
||||
:param height: An optional ``height`` for the grid.
|
||||
:param width: An optional ``height`` for the grid.
|
||||
"""
|
||||
# Find the optimal layout
|
||||
nb_groups = len(self.plots)
|
||||
if height is None and width is not None:
|
||||
height = math.ceil(nb_groups / width)
|
||||
elif width is None and height is not None:
|
||||
width = math.ceil(nb_groups / height)
|
||||
else:
|
||||
height, width = _optimal_grid(nb_groups)
|
||||
|
||||
# Apply the layout
|
||||
groups = (
|
||||
sorted([k for k in self.plots.keys() if k != _DEFAULT_GROUP]) +
|
||||
[_DEFAULT_GROUP])
|
||||
grid_description = ["".join(batch)
|
||||
for batch in tools.batch(groups, width)]
|
||||
self.apply_grid(grid_description)
|
||||
|
||||
def plot(self, *args, **kwargs):
|
||||
"""
|
||||
Plot something on the :class:`Figure` object.
|
||||
@ -142,7 +181,8 @@ class Figure():
|
||||
- ``group`` which permits to group plots together, in \
|
||||
subplots (one unicode character maximum). ``group`` \
|
||||
keyword will not affect the render unless you state \
|
||||
:mod:`replot` to use subplots.
|
||||
:mod:`replot` to use subplots. Note that ``_`` is a \
|
||||
reserved group name which cannot be used.
|
||||
|
||||
>>> with replot.figure() as fig: fig.plot(np.sin, (-1, 1))
|
||||
>>> with replot.figure() as fig: fig.plot(np.sin, [-1, -0.9, …, 1])
|
||||
@ -175,7 +215,7 @@ class Figure():
|
||||
if "group" in custom_kwargs:
|
||||
self.plots[custom_kwargs["group"]].append(plot_)
|
||||
else:
|
||||
self.plots["default"].append(plot_)
|
||||
self.plots[_DEFAULT_GROUP].append(plot_)
|
||||
|
||||
# Automatically set the legend if label is found
|
||||
# (only do it if legend is not explicitly suppressed)
|
||||
@ -203,11 +243,47 @@ class Figure():
|
||||
location.replace("top ", "upper ")
|
||||
location.replace("bottom ", "lower ")
|
||||
# Avoid warning if no labels were given for plots
|
||||
nb_labelled_plots = sum(["label" in plt[1] for plt in self.plots])
|
||||
nb_labelled_plots = sum(["label" in plt[1]
|
||||
for group in self.plots.values()
|
||||
for plt in group])
|
||||
if nb_labelled_plots > 0:
|
||||
# Add legend
|
||||
axes.legend(loc=location)
|
||||
|
||||
def _render(self):
|
||||
"""
|
||||
Actually render the figure.
|
||||
|
||||
:returns: A :mod:`matplotlib` figure.
|
||||
"""
|
||||
figure = None
|
||||
# Use custom matplotlib context
|
||||
with mpl_custom_rc_context():
|
||||
# Tweak matplotlib to use seaborn
|
||||
sns.set()
|
||||
# Plot using specified color palette
|
||||
with sns.color_palette(palette=self.palette,
|
||||
n_colors=self.max_colors):
|
||||
# Create figure
|
||||
figure, axes = plt.subplots()
|
||||
# Add plots
|
||||
for group_ in self.plots:
|
||||
for plot_ in self.plots[group_]:
|
||||
tmp_plots = axes.plot(*(plot_[0]), **(plot_[1]))
|
||||
# Do not clip line at the axes boundaries to prevent
|
||||
# extremas from being cropped.
|
||||
for tmp_plot in tmp_plots:
|
||||
tmp_plot.set_clip_on(False)
|
||||
# Set properties
|
||||
axes.set_xlabel(self.xlabel)
|
||||
axes.set_ylabel(self.ylabel)
|
||||
axes.set_title(self.title)
|
||||
self._legend(axes)
|
||||
# Do not forget to restore matplotlib state, in order not to
|
||||
# interfere with it.
|
||||
sns.reset_orig()
|
||||
return figure
|
||||
|
||||
|
||||
def plot(data, **kwargs):
|
||||
"""
|
||||
@ -286,6 +362,9 @@ def _handle_custom_plot_arguments(kwargs):
|
||||
if len(kwargs["group"]) > 1:
|
||||
raise exc.InvalidParameterError(
|
||||
"Group name cannot be longer than one unicode character.")
|
||||
elif kwargs["group"] == _DEFAULT_GROUP:
|
||||
raise exc.InvalidParameterError(
|
||||
"'%s' is a reserved group name." % (_DEFAULT_GROUP,))
|
||||
custom_kwargs["group"] = kwargs["group"]
|
||||
del kwargs["group"]
|
||||
return (kwargs, custom_kwargs)
|
||||
@ -325,3 +404,37 @@ def _plot_function(data, *args, **kwargs):
|
||||
"Second parameter in plot command should be a tuple " +
|
||||
"specifying plotting interval.")
|
||||
return ((x_values, y_values) + args[1:], kwargs)
|
||||
|
||||
|
||||
def _optimal_grid(nb_items):
|
||||
"""
|
||||
(Naive) attempt to find an optimal grid layout for N elements.
|
||||
|
||||
:param nb_items: The number of square elements to put on the grid.
|
||||
:returns: A tuple ``(height, width)`` containing the number of rows and \
|
||||
the number of cols of the resulting grid.
|
||||
|
||||
>>> _optimal_grid(2)
|
||||
(1, 2)
|
||||
|
||||
>>> _optimal_grid(3)
|
||||
(1, 3)
|
||||
|
||||
>>> _optimal_grid(4)
|
||||
(2, 2)
|
||||
"""
|
||||
# Compute first possibility
|
||||
height1 = math.floor(math.sqrt(nb_items))
|
||||
width1 = math.ceil(nb_items / height1)
|
||||
|
||||
# Compute second possibility
|
||||
width2 = math.ceil(math.sqrt(nb_items))
|
||||
height2 = math.ceil(nb_items / width2)
|
||||
|
||||
# Minimize the product of height and width
|
||||
if height1 * width1 < height2 * width2:
|
||||
height, width = height1, width1
|
||||
else:
|
||||
height, width = height2, width2
|
||||
|
||||
return (height, width)
|
||||
|
34
replot/tools.py
Normal file
34
replot/tools.py
Normal file
@ -0,0 +1,34 @@
|
||||
"""
|
||||
This file contains various utility functions.
|
||||
"""
|
||||
from itertools import islice, chain
|
||||
|
||||
|
||||
def batch(iterable, size):
|
||||
"""
|
||||
Get items from a sequence a batch at a time.
|
||||
|
||||
.. note:
|
||||
|
||||
Adapted from
|
||||
https://code.activestate.com/recipes/303279-getting-items-in-batches/.
|
||||
|
||||
|
||||
.. note:
|
||||
|
||||
All batches must be exhausted immediately.
|
||||
|
||||
:params iterable: An iterable to get batches from.
|
||||
:params size: Size of the batches.
|
||||
:returns: A new batch of the given size at each time.
|
||||
|
||||
>>> [list(i) for i in batch([1, 2, 3, 4, 5], 2)]
|
||||
[[1, 2], [3, 4], [5]]
|
||||
"""
|
||||
item = iter(iterable)
|
||||
while True:
|
||||
batch_iterator = islice(item, size)
|
||||
try:
|
||||
yield chain([next(batch_iterator)], batch_iterator)
|
||||
except StopIteration:
|
||||
return
|
Loading…
Reference in New Issue
Block a user