This commit is contained in:
Lucas Verney 2016-03-04 16:10:46 +01:00
parent 30edf021b6
commit cdb85bc23f

View File

@ -298,12 +298,12 @@ class Figure():
if "label" in kwargs and self.legend is None: if "label" in kwargs and self.legend is None:
self.legend = True self.legend = True
def _legend(self, axe, overload_legend=None): def _legend(self, axis, overload_legend=None):
""" """
Helper function to handle ``legend`` attribute. It places the legend \ Helper function to handle ``legend`` attribute. It places the legend \
correctly depending on attributes and required plots. correctly depending on attributes and required plots.
:param axe: The :mod:`matplotlib` axe to put the legend on. :param axis: The :mod:`matplotlib` axis to put the legend on.
:param overload_legend: An optional legend specification to use \ :param overload_legend: An optional legend specification to use \
instead of the ``legend`` attribute. instead of the ``legend`` attribute.
""" """
@ -329,7 +329,7 @@ class Figure():
for plt in group]) for plt in group])
if nb_labelled_plots > 0: if nb_labelled_plots > 0:
# Add legend # Add legend
axe.legend(loc=location) axis.legend(loc=location)
def _grid(self): def _grid(self):
""" """
@ -342,9 +342,9 @@ class Figure():
if self.grid is False: if self.grid is False:
# If grid is disabled, we plot every group in the same sublot. # If grid is disabled, we plot every group in the same sublot.
axes = {} axes = {}
figure, axe = plt.subplots() figure, axis = plt.subplots()
for subplot in self.plots: for subplot in self.plots:
axes[subplot] = axe axes[subplot] = axis
return figure, axes return figure, axes
elif self.grid is None: elif self.grid is None:
# If no grid is provided, create an auto grid for the figure. # If no grid is provided, create an auto grid for the figure.
@ -362,48 +362,48 @@ class Figure():
colspan=colspan, colspan=colspan,
rowspan=rowspan) rowspan=rowspan)
if _DEFAULT_GROUP not in axes: if _DEFAULT_GROUP not in axes:
# Set the default group axe to None if it is not in the grid # Set the default group axis to None if it is not in the grid
axes[_DEFAULT_GROUP] = None axes[_DEFAULT_GROUP] = None
return figure, axes return figure, axes
def _set_axes_properties(self, axe, group_): def _set_axes_properties(self, axis, group_):
# Set xlabel # Set xlabel
if isinstance(self.xlabel, dict): if isinstance(self.xlabel, dict):
try: try:
axe.set_xlabel(self.xlabel[group_]) axis.set_xlabel(self.xlabel[group_])
except KeyError: except KeyError:
# No entry for this axe in the dict, pass it # No entry for this axis in the dict, pass it
pass pass
else: else:
axe.set_xlabel(self.xlabel) axis.set_xlabel(self.xlabel)
# Set ylabel # Set ylabel
if isinstance(self.ylabel, dict): if isinstance(self.ylabel, dict):
try: try:
axe.set_ylabel(self.ylabel[group_]) axis.set_ylabel(self.ylabel[group_])
except KeyError: except KeyError:
# No entry for this axe in the dict, pass it # No entry for this axis in the dict, pass it
pass pass
else: else:
axe.set_ylabel(self.ylabel) axis.set_ylabel(self.ylabel)
# Set title # Set title
if isinstance(self.title, dict): if isinstance(self.title, dict):
try: try:
axe.set_ylabel(self.title[group_]) axis.set_ylabel(self.title[group_])
except KeyError: except KeyError:
# No entry for this axe in the dict, pass it # No entry for this axis in the dict, pass it
pass pass
else: else:
axe.set_title(self.title) axis.set_title(self.title)
# Set legend # Set legend
if isinstance(self.legend, dict): if isinstance(self.legend, dict):
try: try:
self._legend(axe, overload_legend=self.legend[group_]) self._legend(axis, overload_legend=self.legend[group_])
except KeyError: except KeyError:
# No entry for this axe in the dict, use default argument # No entry for this axis in the dict, use default argument
# That is put a legend except if no plots on this axe. # That is put a legend except if no plots on this axis.
self._legend(axe, overload_legend=True) self._legend(axis, overload_legend=True)
else: else:
self._legend(axe) self._legend(axis)
def _render(self): def _render(self):
""" """
@ -423,24 +423,24 @@ class Figure():
figure, axes = self._grid() figure, axes = self._grid()
# Add plots # Add plots
for group_ in self.plots: for group_ in self.plots:
# Get the axe corresponding to current group # Get the axis corresponding to current group
try: try:
axe = axes[group_] axis = axes[group_]
except KeyError: except KeyError:
# If not found, plot in the default group # If not found, plot in the default group
axe = axes[_DEFAULT_GROUP] axis = axes[_DEFAULT_GROUP]
# Skip this plot if the axe is None # Skip this plot if the axis is None
if axe is None: if axis is None:
continue continue
# Plot # Plot
for plot_ in self.plots[group_]: for plot_ in self.plots[group_]:
tmp_plots = axe.plot(*(plot_[0]), **(plot_[1])) tmp_plots = axis.plot(*(plot_[0]), **(plot_[1]))
# Do not clip line at the axes boundaries to prevent # Do not clip line at the axes boundaries to prevent
# extremas from being cropped. # extremas from being cropped.
for tmp_plot in tmp_plots: for tmp_plot in tmp_plots:
tmp_plot.set_clip_on(False) tmp_plot.set_clip_on(False)
# Set ax properties # Set ax properties
self._set_axes_properties(axe, group_) self._set_axes_properties(axis, group_)
# Do not forget to restore matplotlib state, in order not to # Do not forget to restore matplotlib state, in order not to
# interfere with it. # interfere with it.
sns.reset_orig() sns.reset_orig()