Better handling of legend
Automatically add a legend to the graph if labelled data is found, except if user explicitly disabled the legend.
This commit is contained in:
parent
b3f2b03e8f
commit
81bb851e10
1599
Examples.ipynb
1599
Examples.ipynb
File diff suppressed because one or more lines are too long
16
README.md
16
README.md
@ -38,6 +38,22 @@ design in the API, or required feature!
|
|||||||
abstracts on top of `matplotlib` and do the actual render only when the
|
abstracts on top of `matplotlib` and do the actual render only when the
|
||||||
figure is to be `show`n. Even after having called the `show` method, you
|
figure is to be `show`n. Even after having called the `show` method, you
|
||||||
can still change everything in your figure!</dd>
|
can still change everything in your figure!</dd>
|
||||||
|
|
||||||
|
<dt>Does not interfere with `matplotlib`</dt>
|
||||||
|
<dd>You can still use the default `matplotlib` if you want, as
|
||||||
|
`matplotlib` state and parameters are not directly affected by this
|
||||||
|
module, contrary to what `seaborn` do when you import it for
|
||||||
|
instance.</dd>
|
||||||
|
|
||||||
|
<dt>Useful aliases</dt>
|
||||||
|
<dd>You think `loc="top left"` is easier to remember than `loc="upper
|
||||||
|
left"` in a `matplotlib.pyplot.legend()` call? No worry, this module
|
||||||
|
aliases it for you! (same for "bottom" with respect to "lower")</dd>
|
||||||
|
|
||||||
|
<dt>Automatic legend</dt>
|
||||||
|
<dd>If any of your plots contains a `label` keyword, a legend will be
|
||||||
|
added automatically on your graph (you can still explicitly tell it not to
|
||||||
|
add a legend by setting the `legend` attribute to `False`).</dd>
|
||||||
</dl>
|
</dl>
|
||||||
|
|
||||||
|
|
||||||
|
32
replot.py
32
replot.py
@ -8,6 +8,12 @@ import numpy as np
|
|||||||
import seaborn
|
import seaborn
|
||||||
|
|
||||||
|
|
||||||
|
# Cancel seaborn modifications of matplotlib, we do not want to interfere in
|
||||||
|
# any ways with matplotlib
|
||||||
|
seaborn.reset_orig()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Remove it, this is interfering with matplotlib
|
||||||
plt.rcParams['figure.figsize'] = (10.0, 8.0) # Larger figures by default
|
plt.rcParams['figure.figsize'] = (10.0, 8.0) # Larger figures by default
|
||||||
plt.rcParams['text.usetex'] = True # Use LaTeX rendering
|
plt.rcParams['text.usetex'] = True # Use LaTeX rendering
|
||||||
|
|
||||||
@ -16,16 +22,17 @@ class Figure():
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
xlabel="", ylabel="", title="", palette="hls",
|
xlabel="", ylabel="", title="", palette="hls",
|
||||||
legend=None):
|
legend=None):
|
||||||
|
# TODO: Constants
|
||||||
self.max_colors = 10
|
self.max_colors = 10
|
||||||
self.default_points_number = 1000
|
self.default_points_number = 1000
|
||||||
self.default_x_interval = np.linspace(-10, 10,
|
self.default_x_interval = np.linspace(-10, 10,
|
||||||
self.default_points_number)
|
self.default_points_number)
|
||||||
|
|
||||||
|
# Set default values for attributes
|
||||||
self.xlabel = xlabel
|
self.xlabel = xlabel
|
||||||
self.ylabel = ylabel
|
self.ylabel = ylabel
|
||||||
self.title = title
|
self.title = title
|
||||||
self.palette = palette
|
self.palette = palette
|
||||||
# TODO: Legend should be automatic if labelled data is found
|
|
||||||
self.legend = legend
|
self.legend = legend
|
||||||
self.plots = []
|
self.plots = []
|
||||||
|
|
||||||
@ -37,8 +44,11 @@ class Figure():
|
|||||||
|
|
||||||
def show(self):
|
def show(self):
|
||||||
"""
|
"""
|
||||||
|
Actually render and show the figure.
|
||||||
"""
|
"""
|
||||||
|
# Tweak matplotlib to use seaborn
|
||||||
seaborn.set()
|
seaborn.set()
|
||||||
|
# Plot using specified color palette
|
||||||
with seaborn.color_palette(self.palette, self.max_colors):
|
with seaborn.color_palette(self.palette, self.max_colors):
|
||||||
# Create figure
|
# Create figure
|
||||||
figure, axes = plt.subplots()
|
figure, axes = plt.subplots()
|
||||||
@ -49,10 +59,12 @@ class Figure():
|
|||||||
axes.set_xlabel(self.xlabel)
|
axes.set_xlabel(self.xlabel)
|
||||||
axes.set_ylabel(self.ylabel)
|
axes.set_ylabel(self.ylabel)
|
||||||
axes.set_title(self.title)
|
axes.set_title(self.title)
|
||||||
if self.legend is not None:
|
if self.legend is not None and self.legend is not False:
|
||||||
self._legend(axes, location=self.legend)
|
self._legend(axes, location=self.legend)
|
||||||
# Draw figure
|
# Draw figure
|
||||||
figure.show()
|
figure.show()
|
||||||
|
# Do not forget to restore matplotlib state, in order not to interfere
|
||||||
|
# with it.
|
||||||
seaborn.reset_orig()
|
seaborn.reset_orig()
|
||||||
|
|
||||||
|
|
||||||
@ -75,10 +87,18 @@ class Figure():
|
|||||||
>>> plot([1, 2, 3], [4, 5, 6])
|
>>> plot([1, 2, 3], [4, 5, 6])
|
||||||
"""
|
"""
|
||||||
if hasattr(data, "__call__"):
|
if hasattr(data, "__call__"):
|
||||||
|
# We want to plot a function
|
||||||
self._plot_function(data, *args, **kwargs)
|
self._plot_function(data, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
|
# Else, it is a point series, and we just have to store it for
|
||||||
|
# later plotting.
|
||||||
self.plots.append(((data,) + args, kwargs))
|
self.plots.append(((data,) + args, kwargs))
|
||||||
|
|
||||||
|
# Automatically set the legend if label is found
|
||||||
|
# (only do it if legend is not explicitly suppressed)
|
||||||
|
if "label" in kwargs and self.legend is None:
|
||||||
|
self.legend = True
|
||||||
|
|
||||||
def _plot_function(self, data, *args, **kwargs):
|
def _plot_function(self, data, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
@ -87,8 +107,10 @@ class Figure():
|
|||||||
# No interval specified, using default one
|
# No interval specified, using default one
|
||||||
x_values = self.default_x_interval
|
x_values = self.default_x_interval
|
||||||
elif isinstance(args[0], (list, np.ndarray)):
|
elif isinstance(args[0], (list, np.ndarray)):
|
||||||
|
# List of points specified
|
||||||
x_values = args[0]
|
x_values = args[0]
|
||||||
elif isinstance(args[0], tuple):
|
elif isinstance(args[0], tuple):
|
||||||
|
# Interval specified, generate a list of points
|
||||||
x_values = np.linspace(args[0][0], args[0][1],
|
x_values = np.linspace(args[0][0], args[0][1],
|
||||||
self.default_points_number)
|
self.default_points_number)
|
||||||
else:
|
else:
|
||||||
@ -100,10 +122,12 @@ class Figure():
|
|||||||
def _legend(self, axes, location="best"):
|
def _legend(self, axes, location="best"):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
# If there should be a legend, but no location provided, put it at best
|
|
||||||
# location
|
|
||||||
if location is True:
|
if location is True:
|
||||||
|
# If there should be a legend, but no location provided, put it at
|
||||||
|
# best location.
|
||||||
location = "best"
|
location = "best"
|
||||||
|
# Create aliases for "upper" / "top" and "lower" / "bottom"
|
||||||
location.replace("top ", "upper ")
|
location.replace("top ", "upper ")
|
||||||
location.replace("bottom ", "lower ")
|
location.replace("bottom ", "lower ")
|
||||||
|
# Add legend
|
||||||
axes.legend(loc=location)
|
axes.legend(loc=location)
|
||||||
|
Loading…
Reference in New Issue
Block a user