Add a save
method + bugfixes
* Add a `save` method to save images, and a `savepath` argument. * Fix a bug with legend, due to the introduction of plot groups. * Continue work on `gridify` method, still WIP.
This commit is contained in:
parent
2067e2008f
commit
a91da5e77b
948
Examples.ipynb
948
Examples.ipynb
File diff suppressed because one or more lines are too long
@ -22,7 +22,7 @@ design in the API, or required feature!
|
|||||||
<dd>Ever got tired of having to start any figure with a call to
|
<dd>Ever got tired of having to start any figure with a call to
|
||||||
`matplotlib.pyplot.subplots()`? This module abstracts it using `with`
|
`matplotlib.pyplot.subplots()`? This module abstracts it using `with`
|
||||||
statement. New figures are defined by a `with` statement, and are `show`n
|
statement. New figures are defined by a `with` statement, and are `show`n
|
||||||
automatically upon leaving the `with` context.
|
automatically (or `save`d) upon leaving the `with` context.
|
||||||
|
|
||||||
<dt>Plot functions</dt>
|
<dt>Plot functions</dt>
|
||||||
<dd>Ever got annoyed by the fact that `matplotlib` can only plot point
|
<dd>Ever got annoyed by the fact that `matplotlib` can only plot point
|
||||||
@ -61,6 +61,9 @@ design in the API, or required feature!
|
|||||||
|
|
||||||
<dt>Handle subplots more easily</dt>
|
<dt>Handle subplots more easily</dt>
|
||||||
<dd>**TODO**</dd>
|
<dd>**TODO**</dd>
|
||||||
|
|
||||||
|
<dt>"Gridify"</dt>
|
||||||
|
<dd>**TODO**</dd>
|
||||||
</dl>
|
</dl>
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ The :mod:`replot` module is a (sane) Python plotting module, abstracting on top
|
|||||||
of Matplotlib.
|
of Matplotlib.
|
||||||
"""
|
"""
|
||||||
import collections
|
import collections
|
||||||
|
import math
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@ -11,10 +12,14 @@ import seaborn.apionly as sns
|
|||||||
|
|
||||||
from replot import adaptive_sampling
|
from replot import adaptive_sampling
|
||||||
from replot import exceptions as exc
|
from replot import exceptions as exc
|
||||||
|
from replot import tools
|
||||||
|
|
||||||
|
|
||||||
__VERSION__ = "0.0.1"
|
__VERSION__ = "0.0.1"
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
_DEFAULT_GROUP = "_"
|
||||||
|
|
||||||
|
|
||||||
class Figure():
|
class Figure():
|
||||||
"""
|
"""
|
||||||
@ -24,7 +29,7 @@ class Figure():
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
xlabel="", ylabel="", title="",
|
xlabel="", ylabel="", title="",
|
||||||
palette="hls", max_colors=10,
|
palette="hls", max_colors=10,
|
||||||
legend=None):
|
legend=None, savepath=None):
|
||||||
"""
|
"""
|
||||||
Build a :class:`Figure` object.
|
Build a :class:`Figure` object.
|
||||||
|
|
||||||
@ -43,6 +48,8 @@ class Figure():
|
|||||||
behavior. A string indicating position (:mod:`matplotlib` \
|
behavior. A string indicating position (:mod:`matplotlib` \
|
||||||
format) to put a legend. ``True`` is a synonym for ``best`` \
|
format) to put a legend. ``True`` is a synonym for ``best`` \
|
||||||
position.
|
position.
|
||||||
|
:param savepath: A path to save the image to (optional). If set, \
|
||||||
|
the image will be saved on exiting a `with` statement.
|
||||||
"""
|
"""
|
||||||
# Set default values for attributes
|
# Set default values for attributes
|
||||||
self.xlabel = xlabel
|
self.xlabel = xlabel
|
||||||
@ -52,46 +59,46 @@ class Figure():
|
|||||||
self.max_colors = max_colors
|
self.max_colors = max_colors
|
||||||
self.legend = legend
|
self.legend = legend
|
||||||
self.plots = collections.defaultdict(list) # keys are groups
|
self.plots = collections.defaultdict(list) # keys are groups
|
||||||
|
self.savepath = savepath
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exception_type, exception_value, traceback):
|
def __exit__(self, exception_type, exception_value, traceback):
|
||||||
|
# Do not render the figure if an exception was raised
|
||||||
if exception_type is None:
|
if exception_type is None:
|
||||||
# Do not draw the figure if an exception was raised
|
if self.savepath is not None:
|
||||||
self.show()
|
self.save()
|
||||||
|
else:
|
||||||
|
self.show()
|
||||||
|
|
||||||
|
def save(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Render and save the figure. Also show the figure if possible.
|
||||||
|
|
||||||
|
.. note:: Signature is the same as the one from \
|
||||||
|
``matplotlib.pyplot.savefig``. You should refer to its \
|
||||||
|
documentation for lengthy details.
|
||||||
|
|
||||||
|
.. note:: Typical use is:
|
||||||
|
>>> with replot.Figure() as figure: figure.save("SOME_FILENAME")
|
||||||
|
|
||||||
|
.. note:: If a ``savepath`` has been provided to the :class:`Figure` \
|
||||||
|
object, you can call ``figure.save()`` without argument and \
|
||||||
|
this path will be used.
|
||||||
|
"""
|
||||||
|
if len(args) == 0 and self.savepath is not None:
|
||||||
|
args = (self.savepath,)
|
||||||
|
|
||||||
|
figure = self._render()
|
||||||
|
figure.savefig(*args, **kwargs)
|
||||||
|
|
||||||
def show(self):
|
def show(self):
|
||||||
"""
|
"""
|
||||||
Actually render and show the :class:`Figure` object.
|
Render and show the :class:`Figure` object.
|
||||||
"""
|
"""
|
||||||
# Use custom matplotlib context
|
figure = self._render()
|
||||||
with mpl_custom_rc_context():
|
figure.show()
|
||||||
# Tweak matplotlib to use seaborn
|
|
||||||
sns.set()
|
|
||||||
# Plot using specified color palette
|
|
||||||
with sns.color_palette(palette=self.palette,
|
|
||||||
n_colors=self.max_colors):
|
|
||||||
# Create figure
|
|
||||||
figure, axes = plt.subplots()
|
|
||||||
# Add plots
|
|
||||||
for group_ in self.plots:
|
|
||||||
for plot_ in self.plots[group_]:
|
|
||||||
tmp_plots = axes.plot(*(plot_[0]), **(plot_[1]))
|
|
||||||
# Do not clip line at the axes boundaries to prevent
|
|
||||||
# extremas from being cropped.
|
|
||||||
for tmp_plot in tmp_plots:
|
|
||||||
tmp_plot.set_clip_on(False)
|
|
||||||
# Set properties
|
|
||||||
axes.set_xlabel(self.xlabel)
|
|
||||||
axes.set_ylabel(self.ylabel)
|
|
||||||
axes.set_title(self.title)
|
|
||||||
self._legend(axes)
|
|
||||||
# Draw figure
|
|
||||||
figure.show()
|
|
||||||
# Do not forget to restore matplotlib state, in order not to
|
|
||||||
# interfere with it.
|
|
||||||
sns.reset_orig()
|
|
||||||
|
|
||||||
def apply_grid(self, grid_description):
|
def apply_grid(self, grid_description):
|
||||||
"""
|
"""
|
||||||
@ -119,6 +126,38 @@ class Figure():
|
|||||||
"All rows must have the same number of elements.")
|
"All rows must have the same number of elements.")
|
||||||
# TODO: Parse grid
|
# TODO: Parse grid
|
||||||
|
|
||||||
|
def gridify(self, height=None, width=None):
|
||||||
|
"""
|
||||||
|
Apply an automatic grid on the figure, trying to fit best to the \
|
||||||
|
number of plots.
|
||||||
|
|
||||||
|
.. note:: This method must be called after all the groups of plots \
|
||||||
|
have been added to the figure, or the grid will miss some \
|
||||||
|
groups.
|
||||||
|
|
||||||
|
.. note:: The grid will be filled by the groups in lexicographic \
|
||||||
|
order. Unassigned plots go to the last subplot.
|
||||||
|
|
||||||
|
:param height: An optional ``height`` for the grid.
|
||||||
|
:param width: An optional ``height`` for the grid.
|
||||||
|
"""
|
||||||
|
# Find the optimal layout
|
||||||
|
nb_groups = len(self.plots)
|
||||||
|
if height is None and width is not None:
|
||||||
|
height = math.ceil(nb_groups / width)
|
||||||
|
elif width is None and height is not None:
|
||||||
|
width = math.ceil(nb_groups / height)
|
||||||
|
else:
|
||||||
|
height, width = _optimal_grid(nb_groups)
|
||||||
|
|
||||||
|
# Apply the layout
|
||||||
|
groups = (
|
||||||
|
sorted([k for k in self.plots.keys() if k != _DEFAULT_GROUP]) +
|
||||||
|
[_DEFAULT_GROUP])
|
||||||
|
grid_description = ["".join(batch)
|
||||||
|
for batch in tools.batch(groups, width)]
|
||||||
|
self.apply_grid(grid_description)
|
||||||
|
|
||||||
def plot(self, *args, **kwargs):
|
def plot(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Plot something on the :class:`Figure` object.
|
Plot something on the :class:`Figure` object.
|
||||||
@ -142,7 +181,8 @@ class Figure():
|
|||||||
- ``group`` which permits to group plots together, in \
|
- ``group`` which permits to group plots together, in \
|
||||||
subplots (one unicode character maximum). ``group`` \
|
subplots (one unicode character maximum). ``group`` \
|
||||||
keyword will not affect the render unless you state \
|
keyword will not affect the render unless you state \
|
||||||
:mod:`replot` to use subplots.
|
:mod:`replot` to use subplots. Note that ``_`` is a \
|
||||||
|
reserved group name which cannot be used.
|
||||||
|
|
||||||
>>> with replot.figure() as fig: fig.plot(np.sin, (-1, 1))
|
>>> with replot.figure() as fig: fig.plot(np.sin, (-1, 1))
|
||||||
>>> with replot.figure() as fig: fig.plot(np.sin, [-1, -0.9, …, 1])
|
>>> with replot.figure() as fig: fig.plot(np.sin, [-1, -0.9, …, 1])
|
||||||
@ -175,7 +215,7 @@ class Figure():
|
|||||||
if "group" in custom_kwargs:
|
if "group" in custom_kwargs:
|
||||||
self.plots[custom_kwargs["group"]].append(plot_)
|
self.plots[custom_kwargs["group"]].append(plot_)
|
||||||
else:
|
else:
|
||||||
self.plots["default"].append(plot_)
|
self.plots[_DEFAULT_GROUP].append(plot_)
|
||||||
|
|
||||||
# Automatically set the legend if label is found
|
# Automatically set the legend if label is found
|
||||||
# (only do it if legend is not explicitly suppressed)
|
# (only do it if legend is not explicitly suppressed)
|
||||||
@ -203,11 +243,47 @@ class Figure():
|
|||||||
location.replace("top ", "upper ")
|
location.replace("top ", "upper ")
|
||||||
location.replace("bottom ", "lower ")
|
location.replace("bottom ", "lower ")
|
||||||
# Avoid warning if no labels were given for plots
|
# Avoid warning if no labels were given for plots
|
||||||
nb_labelled_plots = sum(["label" in plt[1] for plt in self.plots])
|
nb_labelled_plots = sum(["label" in plt[1]
|
||||||
|
for group in self.plots.values()
|
||||||
|
for plt in group])
|
||||||
if nb_labelled_plots > 0:
|
if nb_labelled_plots > 0:
|
||||||
# Add legend
|
# Add legend
|
||||||
axes.legend(loc=location)
|
axes.legend(loc=location)
|
||||||
|
|
||||||
|
def _render(self):
|
||||||
|
"""
|
||||||
|
Actually render the figure.
|
||||||
|
|
||||||
|
:returns: A :mod:`matplotlib` figure.
|
||||||
|
"""
|
||||||
|
figure = None
|
||||||
|
# Use custom matplotlib context
|
||||||
|
with mpl_custom_rc_context():
|
||||||
|
# Tweak matplotlib to use seaborn
|
||||||
|
sns.set()
|
||||||
|
# Plot using specified color palette
|
||||||
|
with sns.color_palette(palette=self.palette,
|
||||||
|
n_colors=self.max_colors):
|
||||||
|
# Create figure
|
||||||
|
figure, axes = plt.subplots()
|
||||||
|
# Add plots
|
||||||
|
for group_ in self.plots:
|
||||||
|
for plot_ in self.plots[group_]:
|
||||||
|
tmp_plots = axes.plot(*(plot_[0]), **(plot_[1]))
|
||||||
|
# Do not clip line at the axes boundaries to prevent
|
||||||
|
# extremas from being cropped.
|
||||||
|
for tmp_plot in tmp_plots:
|
||||||
|
tmp_plot.set_clip_on(False)
|
||||||
|
# Set properties
|
||||||
|
axes.set_xlabel(self.xlabel)
|
||||||
|
axes.set_ylabel(self.ylabel)
|
||||||
|
axes.set_title(self.title)
|
||||||
|
self._legend(axes)
|
||||||
|
# Do not forget to restore matplotlib state, in order not to
|
||||||
|
# interfere with it.
|
||||||
|
sns.reset_orig()
|
||||||
|
return figure
|
||||||
|
|
||||||
|
|
||||||
def plot(data, **kwargs):
|
def plot(data, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -286,6 +362,9 @@ def _handle_custom_plot_arguments(kwargs):
|
|||||||
if len(kwargs["group"]) > 1:
|
if len(kwargs["group"]) > 1:
|
||||||
raise exc.InvalidParameterError(
|
raise exc.InvalidParameterError(
|
||||||
"Group name cannot be longer than one unicode character.")
|
"Group name cannot be longer than one unicode character.")
|
||||||
|
elif kwargs["group"] == _DEFAULT_GROUP:
|
||||||
|
raise exc.InvalidParameterError(
|
||||||
|
"'%s' is a reserved group name." % (_DEFAULT_GROUP,))
|
||||||
custom_kwargs["group"] = kwargs["group"]
|
custom_kwargs["group"] = kwargs["group"]
|
||||||
del kwargs["group"]
|
del kwargs["group"]
|
||||||
return (kwargs, custom_kwargs)
|
return (kwargs, custom_kwargs)
|
||||||
@ -325,3 +404,37 @@ def _plot_function(data, *args, **kwargs):
|
|||||||
"Second parameter in plot command should be a tuple " +
|
"Second parameter in plot command should be a tuple " +
|
||||||
"specifying plotting interval.")
|
"specifying plotting interval.")
|
||||||
return ((x_values, y_values) + args[1:], kwargs)
|
return ((x_values, y_values) + args[1:], kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _optimal_grid(nb_items):
|
||||||
|
"""
|
||||||
|
(Naive) attempt to find an optimal grid layout for N elements.
|
||||||
|
|
||||||
|
:param nb_items: The number of square elements to put on the grid.
|
||||||
|
:returns: A tuple ``(height, width)`` containing the number of rows and \
|
||||||
|
the number of cols of the resulting grid.
|
||||||
|
|
||||||
|
>>> _optimal_grid(2)
|
||||||
|
(1, 2)
|
||||||
|
|
||||||
|
>>> _optimal_grid(3)
|
||||||
|
(1, 3)
|
||||||
|
|
||||||
|
>>> _optimal_grid(4)
|
||||||
|
(2, 2)
|
||||||
|
"""
|
||||||
|
# Compute first possibility
|
||||||
|
height1 = math.floor(math.sqrt(nb_items))
|
||||||
|
width1 = math.ceil(nb_items / height1)
|
||||||
|
|
||||||
|
# Compute second possibility
|
||||||
|
width2 = math.ceil(math.sqrt(nb_items))
|
||||||
|
height2 = math.ceil(nb_items / width2)
|
||||||
|
|
||||||
|
# Minimize the product of height and width
|
||||||
|
if height1 * width1 < height2 * width2:
|
||||||
|
height, width = height1, width1
|
||||||
|
else:
|
||||||
|
height, width = height2, width2
|
||||||
|
|
||||||
|
return (height, width)
|
||||||
|
34
replot/tools.py
Normal file
34
replot/tools.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
"""
|
||||||
|
This file contains various utility functions.
|
||||||
|
"""
|
||||||
|
from itertools import islice, chain
|
||||||
|
|
||||||
|
|
||||||
|
def batch(iterable, size):
|
||||||
|
"""
|
||||||
|
Get items from a sequence a batch at a time.
|
||||||
|
|
||||||
|
.. note:
|
||||||
|
|
||||||
|
Adapted from
|
||||||
|
https://code.activestate.com/recipes/303279-getting-items-in-batches/.
|
||||||
|
|
||||||
|
|
||||||
|
.. note:
|
||||||
|
|
||||||
|
All batches must be exhausted immediately.
|
||||||
|
|
||||||
|
:params iterable: An iterable to get batches from.
|
||||||
|
:params size: Size of the batches.
|
||||||
|
:returns: A new batch of the given size at each time.
|
||||||
|
|
||||||
|
>>> [list(i) for i in batch([1, 2, 3, 4, 5], 2)]
|
||||||
|
[[1, 2], [3, 4], [5]]
|
||||||
|
"""
|
||||||
|
item = iter(iterable)
|
||||||
|
while True:
|
||||||
|
batch_iterator = islice(item, size)
|
||||||
|
try:
|
||||||
|
yield chain([next(batch_iterator)], batch_iterator)
|
||||||
|
except StopIteration:
|
||||||
|
return
|
Loading…
Reference in New Issue
Block a user