Start to work on animations and use tight_layout()

This commit is contained in:
Lucas Verney 2016-03-18 15:15:56 +01:00
parent 5e633743cf
commit e424b049ba
3 changed files with 971 additions and 65 deletions

File diff suppressed because one or more lines are too long

View File

@ -14,6 +14,7 @@ try:
os.environ["DISPLAY"] os.environ["DISPLAY"]
except KeyError: except KeyError:
mpl.use("agg") mpl.use("agg")
import matplotlib.animation as animation
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import palettable import palettable
@ -113,6 +114,9 @@ class Figure():
# Working attributes # Working attributes
self.figure = None self.figure = None
self.axes = None self.axes = None
self.animation = {"type": False,
"args": (), "kwargs": {},
"persist": []}
def __enter__(self): def __enter__(self):
return self return self
@ -143,15 +147,21 @@ class Figure():
if len(args) == 0 and self.savepath is not None: if len(args) == 0 and self.savepath is not None:
args = (self.savepath,) args = (self.savepath,)
figure = self.render() self.render()
figure.savefig(*args, **kwargs) if self.figure is not None:
self.figure.savefig(*args, **kwargs)
else:
raise exc.InvalidFigure("Invalid figure.")
def show(self): def show(self):
""" """
Render and show the :class:`Figure` object. Render and show the :class:`Figure` object.
""" """
figure = self.render() self.render()
figure.show() if self.figure is not None:
self.figure.show()
else:
raise exc.InvalidFigure("Invalid figure.")
def set_grid(self, grid_description=None, def set_grid(self, grid_description=None,
height=None, width=None, ignore_groups=False, height=None, width=None, ignore_groups=False,
@ -306,6 +316,10 @@ class Figure():
the axes labels as well. the axes labels as well.
- ``rotate`` (angle in degrees) rotate the plot by the angle in \ - ``rotate`` (angle in degrees) rotate the plot by the angle in \
degrees. Leave the labels untouched. degrees. Leave the labels untouched.
- ``frame`` to specify a frame on which the plot should appear \
when calling ``animate`` afterwards. Default behavior is \
to increase the frame number between each plots. Frame \
count starts at 0.
.. note:: Note that this API call considers list of tuples as \ .. note:: Note that this API call considers list of tuples as \
list of (x, y) coordinates to plot, contrary to standard \ list of (x, y) coordinates to plot, contrary to standard \
@ -376,9 +390,10 @@ class Figure():
# Add the plot to the correct group # Add the plot to the correct group
if "group" in custom_kwargs: if "group" in custom_kwargs:
self.plots[custom_kwargs["group"]].append(plot_) group_ = custom_kwargs["group"]
else: else:
self.plots[_DEFAULT_GROUP].append(plot_) group_ = _DEFAULT_GROUP
self.plots[group_].append(plot_)
# Automatically set the legend if label is found # Automatically set the legend if label is found
# (only do it if legend is not explicitly suppressed) # (only do it if legend is not explicitly suppressed)
@ -407,6 +422,20 @@ class Figure():
kwargs["logscale"] = "loglog" kwargs["logscale"] = "loglog"
self.plot(*args, **kwargs) self.plot(*args, **kwargs)
def animate(self, *args, **kwargs):
"""
Create an animation.
You can either:
- pass it a function TODO
- use it directly without arguments to create an animation from \
the previously plot commands.
"""
# TODO
self.animation["type"] = "gif"
self.animation["args"] = args
self.animation["kwargs"] = kwargs
def _legend(self, axis, overload_legend=None): def _legend(self, axis, overload_legend=None):
""" """
Helper function to handle ``legend`` attribute. It places the legend \ Helper function to handle ``legend`` attribute. It places the legend \
@ -454,11 +483,18 @@ class Figure():
figure, axis = plt.subplots() figure, axis = plt.subplots()
# Set the palette for the subplot # Set the palette for the subplot
axis.set_prop_cycle( axis.set_prop_cycle(
self._build_cycler_palette(sum([len(i) for i in self.plots]))) self._build_cycler_palette(sum([len(i)
for i in self.plots.values()]))
)
# Set the axis for every subplot # Set the axis for every subplot
for subplot in self.plots: for subplot in self.plots:
axes[subplot] = axis axes[subplot] = axis
return figure, axes
# Set attributes
self.figure = figure
self.axes = axes
# Return
return
elif self.grid is None: elif self.grid is None:
# If no grid is provided, create an auto grid for the figure. # If no grid is provided, create an auto grid for the figure.
self._set_auto_grid() self._set_auto_grid()
@ -590,39 +626,90 @@ class Figure():
if self.figure is None or self.axes is None: if self.figure is None or self.axes is None:
# Create figure if necessary # Create figure if necessary
self._grid() self._grid()
# Add plots
for group_ in self.plots: if self.animation["type"] is False:
# Get the axis corresponding to current group self._render_no_animation()
try: elif self.animation["type"] == "gif":
axis = self.axes[group_] self._render_gif_animation()
except KeyError: elif self.animation["type"] == "animation":
# If not found, plot in the default group # TODO
axis = self.axes[_DEFAULT_GROUP] return None
# Skip this plot if the axis is None else:
if axis is None: return None
continue self.figure.tight_layout(pad=1)
# Plot
for plot_ in self.plots[group_]: def _render_gif_animation(self):
tmp_plots = axis.plot(*(plot_[0]), **(plot_[1])) """
# Handle custom kwargs at plotting time Handle the render of a GIF-like animation, cycling through the plots.
if "logscale" in plot_[2]:
if plot_[2]["logscale"] == "log": :returns: A :mod:`matplotlib` figure.
axis.set_xscale("log") """
elif plot_[2]["logscale"] == "loglog": # Init
axis.set_xscale("log") # TODO
axis.set_yscale("log") self.axes[_DEFAULT_GROUP].set_xlim((-2, 2))
if "orthonormal" in plot_[2] and plot_[2]["orthonormal"]: self.axes[_DEFAULT_GROUP].set_ylim((-2, 2))
axis.set_aspect("equal") line, = self.axes[_DEFAULT_GROUP].plot([], [])
if "xlim" in plot_[2]: # Define an animation function (closure)
axis.set_xlim(*plot_[2]["xlim"]) def animate(i):
if "ylim" in plot_[2]: # TODO
axis.set_ylim(*plot_[2]["ylim"]) x = np.linspace(0, 2, 1000)
# Do not clip line at the axes boundaries to prevent y = np.sin(2 * np.pi * (x - 0.01 * i))
# extremas from being cropped. line.set_data(x, y)
for tmp_plot in tmp_plots: return line,
tmp_plot.set_clip_on(False) # Set default kwargs
# Set ax properties default_args = (self.figure, animate)
self._set_axes_properties(axis, group_) default_kwargs = {
"frames": 200,
"interval": 20,
"blit": True,
}
# Update with overloaded arguments
default_args = default_args + self.animation["args"]
default_kwargs.update(self.animation["kwargs"])
# Keep track of animation object, as it has to persist
self.animation["persist"] = [
animation.FuncAnimation(*default_args, **default_kwargs)]
return self.figure
def _render_no_animation(self):
"""
Handle the render of the figure when no animation is used.
:returns: A :mod:`matplotlib` figure.
"""
# Add plots
for group_ in self.plots:
# Get the axis corresponding to current group
try:
axis = self.axes[group_]
except KeyError:
# If not found, plot in the default group
axis = self.axes[_DEFAULT_GROUP]
# Skip this plot if the axis is None
if axis is None:
continue
# Plot
for plot_ in self.plots[group_]:
tmp_plots = axis.plot(*(plot_[0]), **(plot_[1]))
# Handle custom kwargs at plotting time
if "logscale" in plot_[2]:
if plot_[2]["logscale"] == "log":
axis.set_xscale("log")
elif plot_[2]["logscale"] == "loglog":
axis.set_xscale("log")
axis.set_yscale("log")
if "orthonormal" in plot_[2] and plot_[2]["orthonormal"]:
axis.set_aspect("equal")
if "xlim" in plot_[2]:
axis.set_xlim(*plot_[2]["xlim"])
if "ylim" in plot_[2]:
axis.set_ylim(*plot_[2]["ylim"])
# Do not clip line at the axes boundaries to prevent
# extremas from being cropped.
for tmp_plot in tmp_plots:
tmp_plot.set_clip_on(False)
# Set ax properties
self._set_axes_properties(axis, group_)
return self.figure return self.figure
@ -829,6 +916,12 @@ def _handle_custom_plot_arguments(kwargs):
# Convert angle to radians # Convert angle to radians
custom_kwargs["rotate"] = kwargs["rotate"] * np.pi / 180 custom_kwargs["rotate"] = kwargs["rotate"] * np.pi / 180
del kwargs["rotate"] del kwargs["rotate"]
# Handle "frame" argument
if "frame" in kwargs:
custom_kwargs["frame"] = kwargs["frame"]
del kwargs["frame"]
else:
custom_kwargs["frame"] = 0
return (kwargs, custom_kwargs) return (kwargs, custom_kwargs)

View File

@ -15,3 +15,10 @@ class InvalidParameterError(BaseException):
Exception raised when an invalid parameter is provided. Exception raised when an invalid parameter is provided.
""" """
pass pass
class InvalidFigure(BaseException):
"""
Exception raised when a figure is invalid and cannot be shown.
"""
pass