libraries

This commit is contained in:
2024-09-28 22:52:53 -07:00
parent 5cdaf1f76b
commit 4929d1fa66
7378 changed files with 1550978 additions and 14 deletions

View File

@@ -0,0 +1,170 @@
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
import numpy as np
import matplotlib as mpl
from seaborn._marks.base import (
Mark,
Mappable,
MappableBool,
MappableFloat,
MappableColor,
MappableStyle,
resolve_properties,
resolve_color,
document_properties,
)
class AreaBase:
    """Polygon-drawing logic shared by the area-style marks.

    Subclasses customize behavior through the two hooks
    `_standardize_coordinate_parameters` and `_postprocess_artist`, which
    do nothing in this base implementation.
    """

    def _plot(self, split_gen, scales, orient):
        """Create one polygon per data subset and attach it to its axes.

        Parameters
        ----------
        split_gen : callable
            Called with no arguments; yields ``(keys, data, ax)`` triples.
        scales : dict
            Forwarded to the property / color resolvers.
        orient : str
            ``"x"`` or ``"y"``.
        """
        polygons = defaultdict(list)

        for keys, subset, ax in split_gen():
            subset = self._standardize_coordinate_parameters(subset, orient)
            props = resolve_properties(self, keys, scales)

            outline = self._get_verts(subset, orient)
            ax.update_datalim(outline)

            # TODO should really move this logic into resolve_color
            facecolor = resolve_color(self, keys, "", scales)
            if not props["fill"]:
                # Keep the color but make an unfilled face fully transparent
                facecolor = mpl.colors.to_rgba(facecolor, 0)

            style = {
                "facecolor": facecolor,
                "edgecolor": resolve_color(self, keys, "edge", scales),
                "linewidth": props["edgewidth"],
                "linestyle": props["edgestyle"],
            }
            polygons[ax].append(mpl.patches.Polygon(outline, **style))

        # Artists are only attached once every subset has been processed
        for ax, ax_polygons in polygons.items():
            for polygon in ax_polygons:
                self._postprocess_artist(polygon, ax, orient)
                ax.add_patch(polygon)

    def _standardize_coordinate_parameters(self, data, orient):
        """Hook for renaming / deriving coordinate columns; identity here."""
        return data

    def _postprocess_artist(self, artist, ax, orient):
        """Hook for tweaking an artist before it is added; no-op here."""
        pass

    def _get_verts(self, data, orient):
        """Return the closed polygon outline as an array of (x, y) rows.

        The outline walks along the ``min`` column in increasing order of the
        `orient` variable and comes back along the ``max`` column.
        """
        dep = {"x": "y", "y": "x"}[orient]
        ordered = data.sort_values(orient, kind="mergesort")

        lower = ordered[[orient, f"{dep}min"]].to_numpy()
        upper = ordered[[orient, f"{dep}max"]].to_numpy()
        outline = np.concatenate([lower, upper[::-1]])

        # Columns were built as (orient, value); swap them when orient is y
        return outline[:, ::-1] if orient == "y" else outline

    def _legend_artist(self, variables, value, scales):
        """Return a patch showing how `value` is encoded for `variables`."""
        keys = {var: value for var in variables}
        props = resolve_properties(self, keys, scales)

        facecolor = resolve_color(self, keys, "", scales)
        if not props["fill"]:
            facecolor = mpl.colors.to_rgba(facecolor, 0)

        edgecolor = resolve_color(self, keys, "edge", scales)
        return mpl.patches.Patch(
            facecolor=facecolor,
            edgecolor=edgecolor,
            linewidth=props["edgewidth"],
            linestyle=props["edgestyle"],
            **self.artist_kws,
        )
@document_properties
@dataclass
class Area(AreaBase, Mark):
    """
    A fill mark drawn from a baseline to data values.

    See also
    --------
    Band : A fill mark representing an interval between values.

    Examples
    --------
    .. include:: ../docstrings/objects.Area.rst

    """
    color: MappableColor = Mappable("C0")
    alpha: MappableFloat = Mappable(.2)
    fill: MappableBool = Mappable(True)
    edgecolor: MappableColor = Mappable(depend="color")
    edgealpha: MappableFloat = Mappable(1)
    edgewidth: MappableFloat = Mappable(rc="patch.linewidth")
    edgestyle: MappableStyle = Mappable("-")

    # TODO should this be settable / mappable?
    baseline: MappableFloat = Mappable(0, grouping=False)

    def _standardize_coordinate_parameters(self, data, orient):
        """Rename ``baseline`` and the value column to ``<v>min`` / ``<v>max``."""
        value_var = {"x": "y", "y": "x"}[orient]
        renames = {"baseline": f"{value_var}min", value_var: f"{value_var}max"}
        return data.rename(columns=renames)

    def _postprocess_artist(self, artist, ax, orient):
        """Double the stroke, clip the artist to its outline, set sticky edges."""
        # TODO copying a lot of code from Bar, let's abstract this
        # (the rationale for each step is documented in Bar._plot)
        artist.set_linewidth(artist.get_linewidth() * 2)

        dash_spec = artist.get_linestyle()
        if dash_spec[1]:
            # Dash lengths are halved to offset the doubled line width
            halved = tuple(x / 2 for x in dash_spec[1])
            dash_spec = (dash_spec[0], halved)
        artist.set_linestyle(dash_spec)

        outline_transform = artist.get_transform() + ax.transData
        artist.set_clip_path(artist.get_path(), outline_transform)
        if self.artist_kws.get("clip_on", True):
            artist.set_clip_box(ax.bbox)

        value_axis = ["y", "x"].index(orient)
        artist.sticky_edges[value_axis][:] = (0, np.inf)
@document_properties
@dataclass
class Band(AreaBase, Mark):
    """
    A fill mark representing an interval between values.

    See also
    --------
    Area : A fill mark drawn from a baseline to data values.

    Examples
    --------
    .. include:: ../docstrings/objects.Band.rst

    """
    color: MappableColor = Mappable("C0", )
    alpha: MappableFloat = Mappable(.2, )
    fill: MappableBool = Mappable(True, )
    edgecolor: MappableColor = Mappable(depend="color", )
    edgealpha: MappableFloat = Mappable(1, )
    edgewidth: MappableFloat = Mappable(0, )
    # A dash style, not a number: annotated like Area.edgestyle
    edgestyle: MappableStyle = Mappable("-", )

    def _standardize_coordinate_parameters(self, data, orient):
        """Make sure the data carries ``<v>min`` / ``<v>max`` columns.

        When neither column is present, both are computed by aggregating the
        value variable (min and max) within each level of `orient`. Data that
        already has at least one of the columns is returned unchanged.
        """
        # TODO assert that all(ymax >= ymin)?
        # TODO what if only one exist?
        other = {"x": "y", "y": "x"}[orient]
        if not set(data.columns) & {f"{other}min", f"{other}max"}:
            agg = {f"{other}min": (other, "min"), f"{other}max": (other, "max")}
            data = data.groupby(orient).agg(**agg).reset_index()
        return data

View File

@@ -0,0 +1,252 @@
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
import numpy as np
import matplotlib as mpl
from seaborn._marks.base import (
Mark,
Mappable,
MappableBool,
MappableColor,
MappableFloat,
MappableStyle,
resolve_properties,
resolve_color,
document_properties
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any
from matplotlib.artist import Artist
from seaborn._core.scales import Scale
class BarBase(Mark):
    """Patch construction and legend handling shared by the bar marks."""

    def _make_patches(self, data, scales, orient):
        """Build one rectangle per row of `data` that has a non-zero value.

        Returns
        -------
        bars : list of matplotlib Rectangle patches
        vals : list with the extent of each bar along the value dimension
        """
        axis_transform = scales[orient]._matplotlib_scale.get_transform()
        to_scaled = axis_transform.transform
        from_scaled = axis_transform.inverted().transform

        value_var = {"x": "y", "y": "x"}[orient]

        # Bar widths are applied in scaled space and mapped back to data space
        center = to_scaled(data[orient])
        half_width = data["width"] / 2
        pos = from_scaled(center - half_width)
        width = from_scaled(center + half_width) - pos

        base = data["baseline"].to_numpy()
        val = (data[value_var] - data["baseline"]).to_numpy()

        kws = self._resolve_properties(data, scales)
        if orient == "x":
            geometry = dict(x=pos, y=base, w=width, h=val)
        else:
            geometry = dict(x=base, y=pos, w=val, h=width)
        kws.update(geometry)

        for consumed in ("width", "baseline"):
            kws.pop(consumed, None)

        val_dim = {"x": "h", "y": "w"}[orient]
        bars = []
        vals = []
        for i in range(len(data)):
            row = {key: column[i] for key, column in kws.items()}

            # Skip bars with no value. It's possible we'll want to make this
            # an option (i.e so you have an artist for animating or annotating),
            # but let's keep things simple for now.
            extent = row[val_dim]
            if not np.nan_to_num(extent):
                continue

            bars.append(
                mpl.patches.Rectangle(
                    xy=(row["x"], row["y"]),
                    width=row["w"],
                    height=row["h"],
                    facecolor=row["facecolor"],
                    edgecolor=row["edgecolor"],
                    linestyle=row["edgestyle"],
                    linewidth=row["edgewidth"],
                    **self.artist_kws,
                )
            )
            vals.append(extent)

        return bars, vals

    def _resolve_properties(self, data, scales):
        """Resolve mark properties plus face / edge colors (with ``fill``)."""
        resolved = resolve_properties(self, data, scales)
        resolved["facecolor"] = resolve_color(self, data, "", scales)
        resolved["edgecolor"] = resolve_color(self, data, "edge", scales)

        face = resolved["facecolor"]
        fill = resolved["fill"]
        if isinstance(face, tuple):
            # Single RGBA tuple: scale its alpha channel by the fill flag
            face = face[0], face[1], face[2], face[3] * fill
        else:
            face[:, 3] = face[:, 3] * fill  # TODO Is inplace mod a problem?
        resolved["facecolor"] = face

        return resolved

    def _legend_artist(
        self, variables: list[str], value: Any, scales: dict[str, Scale],
    ) -> Artist:
        """Return a patch showing how `value` is encoded for `variables`."""
        # TODO return some sensible default?
        props = self._resolve_properties({v: value for v in variables}, scales)
        return mpl.patches.Patch(
            facecolor=props["facecolor"],
            edgecolor=props["edgecolor"],
            linewidth=props["edgewidth"],
            linestyle=props["edgestyle"],
        )
@document_properties
@dataclass
class Bar(BarBase):
    """
    A bar mark drawn between baseline and data values.

    See also
    --------
    Bars : A faster bar mark with defaults more suitable for histograms.

    Examples
    --------
    .. include:: ../docstrings/objects.Bar.rst

    """
    color: MappableColor = Mappable("C0", grouping=False)
    alpha: MappableFloat = Mappable(.7, grouping=False)
    fill: MappableBool = Mappable(True, grouping=False)
    edgecolor: MappableColor = Mappable(depend="color", grouping=False)
    edgealpha: MappableFloat = Mappable(1, grouping=False)
    edgewidth: MappableFloat = Mappable(rc="patch.linewidth", grouping=False)
    edgestyle: MappableStyle = Mappable("-", grouping=False)
    # pattern: MappableString = Mappable(None)  # TODO no Property yet

    width: MappableFloat = Mappable(.8, grouping=False)
    baseline: MappableFloat = Mappable(0, grouping=False)  # TODO *is* this mappable?

    def _plot(self, split_gen, scales, orient):
        """Draw each bar as its own patch and register a BarContainer."""
        value_axis = ["y", "x"].index(orient)
        orientation = {"x": "vertical", "y": "horizontal"}[orient]

        for _, subset, ax in split_gen():

            bars, vals = self._make_patches(subset, scales, orient)

            for patch in bars:

                # Because we are clipping the artist (see below), the edges end up
                # looking half as wide as they actually are. I don't love this clumsy
                # workaround, which is going to cause surprises if you work with the
                # artists directly. We may need to revisit after feedback.
                patch.set_linewidth(patch.get_linewidth() * 2)
                dash_spec = patch.get_linestyle()
                if dash_spec[1]:
                    halved = tuple(x / 2 for x in dash_spec[1])
                    dash_spec = (dash_spec[0], halved)
                patch.set_linestyle(dash_spec)

                # This is a bit of a hack to handle the fact that the edge lines are
                # centered on the actual extents of the bar, and overlap when bars are
                # stacked or dodged. We may discover that this causes problems and needs
                # to be revisited at some point. Also it should be faster to clip with
                # a bbox than a path, but I can't work out how to get the intersection
                # with the axes bbox.
                outline_transform = patch.get_transform() + ax.transData
                patch.set_clip_path(patch.get_path(), outline_transform)
                if self.artist_kws.get("clip_on", True):
                    # It seems the above hack undoes the default axes clipping
                    patch.set_clip_box(ax.bbox)

                patch.sticky_edges[value_axis][:] = (0, np.inf)
                ax.add_patch(patch)

            # Add a container which is useful for, e.g. Axes.bar_label
            container = mpl.container.BarContainer(
                bars, datavalues=vals, orientation=orientation,
            )
            ax.add_container(container)
@document_properties
@dataclass
class Bars(BarBase):
    """
    A faster bar mark with defaults more suitable for histograms.

    See also
    --------
    Bar : A bar mark drawn between baseline and data values.

    Examples
    --------
    .. include:: ../docstrings/objects.Bars.rst

    """
    color: MappableColor = Mappable("C0", grouping=False)
    alpha: MappableFloat = Mappable(.7, grouping=False)
    fill: MappableBool = Mappable(True, grouping=False)
    edgecolor: MappableColor = Mappable(rc="patch.edgecolor", grouping=False)
    edgealpha: MappableFloat = Mappable(1, grouping=False)
    edgewidth: MappableFloat = Mappable(auto=True, grouping=False)
    edgestyle: MappableStyle = Mappable("-", grouping=False)
    # pattern: MappableString = Mappable(None)  # TODO no Property yet

    width: MappableFloat = Mappable(1, grouping=False)
    baseline: MappableFloat = Mappable(0, grouping=False)  # TODO *is* this mappable?

    def _plot(self, split_gen, scales, orient):
        """Draw all bars of each axes as a single PatchCollection.

        When ``edgewidth`` is neither mapped nor set directly, a common line
        width is derived afterwards: a tenth of the narrowest bar (in points),
        capped at the ``patch.linewidth`` rcParam.
        """
        ori_idx = ["x", "y"].index(orient)
        val_idx = ["y", "x"].index(orient)

        patches = defaultdict(list)
        for _, data, ax in split_gen():
            bars, _ = self._make_patches(data, scales, orient)
            patches[ax].extend(bars)

        collections = {}
        for ax, ax_patches in patches.items():

            col = mpl.collections.PatchCollection(ax_patches, match_original=True)
            col.sticky_edges[val_idx][:] = (0, np.inf)
            ax.add_collection(col, autolim=False)
            collections[ax] = col

            # Workaround for matplotlib autoscaling bug
            # https://github.com/matplotlib/matplotlib/issues/11898
            # https://github.com/matplotlib/matplotlib/issues/23129
            paths = col.get_paths()
            if paths:
                # np.vstack raises on an empty list, which happens when every
                # bar on this axes was skipped for having no value
                xys = np.vstack([path.vertices for path in paths])
                ax.update_datalim(xys)

        if "edgewidth" not in scales and isinstance(self.edgewidth, Mappable):

            for ax in collections:
                ax.autoscale_view()

            def get_dimensions(collection):
                """Return the lower edge and width of each bar along `orient`."""
                edges, widths = [], []
                for verts in (path.vertices for path in collection.get_paths()):
                    edges.append(min(verts[:, ori_idx]))
                    widths.append(np.ptp(verts[:, ori_idx]))
                return np.array(edges), np.array(widths)

            min_width = np.inf
            for ax, col in collections.items():
                edges, widths = get_dimensions(col)
                if not len(edges):
                    # Nothing to measure on this axes
                    continue
                # transform() wants (n, 2) rows. Repeating each coordinate in
                # both columns lets column `ori_idx` hold the display position
                # of every bar (a (2, n) input would be reshaped and scrambled).
                lower = np.column_stack([edges, edges])
                upper = np.column_stack([edges + widths, edges + widths])
                points = 72 / ax.figure.dpi * abs(
                    ax.transData.transform(upper)
                    - ax.transData.transform(lower)
                )
                min_width = min(min_width, min(points[:, ori_idx]))

            linewidth = min(.1 * min_width, mpl.rcParams["patch.linewidth"])
            for col in collections.values():
                col.set_linewidth(linewidth)

View File

@@ -0,0 +1,317 @@
from __future__ import annotations
from dataclasses import dataclass, fields, field
import textwrap
from typing import Any, Callable, Union
from collections.abc import Generator
import numpy as np
import pandas as pd
import matplotlib as mpl
from numpy import ndarray
from pandas import DataFrame
from matplotlib.artist import Artist
from seaborn._core.scales import Scale
from seaborn._core.properties import (
PROPERTIES,
Property,
RGBATuple,
DashPattern,
DashPatternWithOffset,
)
from seaborn._core.exceptions import PlotSpecError
class Mappable:
    """Placeholder default for a mark property that may be mapped or set."""

    def __init__(
        self,
        val: Any = None,
        depend: str | None = None,
        rc: str | None = None,
        auto: bool = False,
        grouping: bool = True,
    ):
        """
        Property that can be mapped from data or set directly, with flexible defaults.

        Parameters
        ----------
        val : Any
            Use this value as the default.
        depend : str
            Use the value of this feature as the default.
        rc : str
            Use the value of this rcParam as the default.
        auto : bool
            The default value will depend on other parameters at compile time.
        grouping : bool
            If True, use the mapped variable to define groups.

        """
        # Sanity checks on the names that refer to other registries
        if depend is not None:
            assert depend in PROPERTIES
        if rc is not None:
            assert rc in mpl.rcParams

        self._val = val
        self._rc = rc
        self._depend = depend
        self._auto = auto
        self._grouping = grouping

    def __repr__(self):
        """Nice formatting for when object appears in Mark init signature."""
        # The first configured source wins: value, dependency, rcParam, auto
        if self._val is not None:
            return f"<{repr(self._val)}>"
        if self._depend is not None:
            return f"<depend:{self._depend}>"
        if self._rc is not None:
            return f"<rc:{self._rc}>"
        if self._auto:
            return "<auto>"
        return "<undefined>"

    @property
    def depend(self) -> Any:
        """Return the name of the feature to source a default value from."""
        return self._depend

    @property
    def grouping(self) -> bool:
        """Whether a variable mapped to this property defines groups."""
        return self._grouping

    @property
    def default(self) -> Any:
        """Get the default value for this feature, or access the relevant rcParam."""
        if self._val is not None:
            return self._val
        if self._rc is not None:
            return mpl.rcParams.get(self._rc)
        # Neither an explicit value nor an rcParam was configured
        return None
# TODO where is the right place to put this kind of type aliasing?

# Annotation aliases for mark fields: each property accepts either a directly
# specified value of the listed type(s) or a `Mappable` placeholder default.
MappableBool = Union[bool, Mappable]
MappableString = Union[str, Mappable]
MappableFloat = Union[float, Mappable]
MappableColor = Union[str, tuple, Mappable]
MappableStyle = Union[str, DashPattern, DashPatternWithOffset, Mappable]
@dataclass
class Mark:
    """Base class for objects that visually represent data."""

    artist_kws: dict = field(default_factory=dict)

    @property
    def _mappable_props(self):
        """Current value of every field whose declared default is a Mappable."""
        props = {}
        for f in fields(self):
            if isinstance(f.default, Mappable):
                props[f.name] = getattr(self, f.name)
        return props

    @property
    def _grouping_props(self):
        """Names of the mappable fields whose default is flagged as grouping."""
        # TODO does it make sense to have variation within a Mark's
        # properties about whether they are grouping?
        names = []
        for f in fields(self):
            if isinstance(f.default, Mappable) and f.default.grouping:
                names.append(f.name)
        return names

    # TODO make this method private? Would extenders ever need to call directly?
    def _resolve(
        self,
        data: DataFrame | dict[str, Any],
        name: str,
        scales: dict[str, Scale] | None = None,
    ) -> Any:
        """Obtain default, specified, or mapped value for a named feature.

        Parameters
        ----------
        data : DataFrame or dict with scalar values
            Container with data values for features that will be semantically mapped.
        name : string
            Identity of the feature / semantic.
        scales: dict
            Mapping from variable to corresponding scale object.

        Returns
        -------
        value or array of values
            Outer return type depends on whether `data` is a dict (implying that
            we want a single value) or DataFrame (implying that we want an array
            of values with matching length).

        """
        feature = self._mappable_props[name]
        prop = PROPERTIES.get(name, Property(name))
        directly_specified = not isinstance(feature, Mappable)
        return_multiple = isinstance(data, pd.DataFrame)
        return_array = return_multiple and not name.endswith("style")

        def expand(value):
            # Repeat a single value to match the length of a DataFrame input
            if return_multiple:
                value = [value] * len(data)
            if return_array:
                value = np.array(value)
            return value

        # Special case width because it needs to be resolved and added to the dataframe
        # during layer prep (so the Move operations use it properly).
        # TODO how does width *scaling* work, e.g. for violin width by count?
        if name == "width":
            directly_specified = directly_specified and name not in data

        # Case 1: the user set the property directly on the mark
        if directly_specified:
            return expand(prop.standardize(feature))

        # Case 2: the property is present in the data, possibly with a scale
        if name in data:
            raw = data[name]
            if scales is None or name not in scales:
                # TODO Might this obviate the identity scale? Just don't add a scale?
                mapped = raw
            else:
                scale = scales[name]
                try:
                    mapped = scale(raw)
                except Exception as err:
                    raise PlotSpecError._during("Scaling operation", name) from err

            if return_array:
                mapped = np.asarray(mapped)
            return mapped

        # Case 3: the default is borrowed from another property
        if feature.depend is not None:
            # TODO add source_func or similar to transform the source value?
            # e.g. set linewidth as a proportion of pointsize?
            return self._resolve(data, feature.depend, scales)

        # Case 4: fall back on the declared default
        return expand(prop.standardize(feature.default))

    def _infer_orient(self, scales: dict) -> str:  # TODO type scales
        """Return "y" when the y scale has strictly higher priority, else "x"."""
        # TODO The original version of this (in seaborn._base) did more checking.
        # Paring that down here for the prototype to see what restrictions make sense.

        # TODO rethink this to map from scale type to "DV priority" and use that?
        # e.g. Nominal > Discrete > Continuous

        x_priority = scales["x"]._priority if "x" in scales else 0
        y_priority = scales["y"]._priority if "y" in scales else 0
        return "y" if y_priority > x_priority else "x"

    def _plot(
        self,
        split_generator: Callable[[], Generator],
        scales: dict[str, Scale],
        orient: str,
    ) -> None:
        """Main interface for creating a plot."""
        raise NotImplementedError()

    def _legend_artist(
        self, variables: list[str], value: Any, scales: dict[str, Scale],
    ) -> Artist | None:
        """Return an artist for the legend; the base mark provides none."""
        return None
def resolve_properties(
    mark: Mark, data: DataFrame, scales: dict[str, Scale]
) -> dict[str, Any]:
    """Resolve every mappable property of `mark` against `data` and `scales`.

    Returns a dict keyed by property name, in the order that
    ``mark._mappable_props`` lists them.
    """
    resolved = {}
    for name in mark._mappable_props:
        resolved[name] = mark._resolve(data, name, scales)
    return resolved
def resolve_color(
    mark: Mark,
    data: DataFrame | dict,
    prefix: str = "",
    scales: dict[str, Scale] | None = None,
) -> RGBATuple | ndarray:
    """
    Obtain a default, specified, or mapped value for a color feature.

    This method exists separately to support the relationship between a
    color and its corresponding alpha. We want to respect alpha values that
    are passed in specified (or mapped) color values but also make use of a
    separate `alpha` variable, which can be mapped. This approach may also
    be extended to support mapping of specific color channels (i.e.
    luminance, chroma) in the future.

    Parameters
    ----------
    mark :
        Mark with the color property.
    data :
        Container with data values for features that will be semantically mapped.
    prefix :
        Support "color", "fillcolor", etc.
    scales :
        Mapping from variable to scale object, forwarded to ``mark._resolve``.

    """
    color = mark._resolve(data, f"{prefix}color", scales)

    # Prefer the prefixed alpha property when the mark defines one
    alpha_name = f"{prefix}alpha"
    if alpha_name not in mark._mappable_props:
        alpha_name = "alpha"
    alpha = mark._resolve(data, alpha_name, scales)

    def visible(x, axis=None):
        """Detect "invisible" colors to set alpha appropriately."""
        # TODO First clause only needed to handle non-rgba arrays,
        # which we are trying to handle upstream
        return np.array(x).dtype.kind != "f" or np.isfinite(x).all(axis)

    # Second check here catches vectors of strings with identity scale
    # It could probably be handled better upstream. This is a tricky problem
    single_color = (
        np.ndim(color) < 2 and all(isinstance(x, float) for x in color)
    )

    if single_color:
        if len(color) == 4:
            # The color already carries its own alpha channel
            return mpl.colors.to_rgba(color)
        alpha = alpha if visible(color) else np.nan
        return mpl.colors.to_rgba(color, alpha)

    if np.ndim(color) == 2 and color.shape[1] == 4:
        return mpl.colors.to_rgba_array(color)
    alpha = np.where(visible(color, axis=1), alpha, np.nan)
    return mpl.colors.to_rgba_array(color, alpha)

    # TODO should we be implementing fill here too?
    # (i.e. set fillalpha to 0 when fill=False)
def document_properties(mark):
    """Class decorator that lists a mark's mappable properties in its docstring.

    The listing is inserted after the first two docstring lines (the opening
    line and the one-line summary); the remaining lines follow unchanged.

    Parameters
    ----------
    mark : dataclass
        Mark class whose fields are inspected.

    Returns
    -------
    The same class, with ``__doc__`` replaced.

    """
    properties = [f.name for f in fields(mark) if isinstance(f.default, Mappable)]
    text = [
        "",
        # Indented like the docstring body (4 spaces) so dedenting the
        # docstring keeps this header flush with the summary line
        "    This mark defines the following properties:",
        textwrap.fill(
            ", ".join([f"|{p}|" for p in properties]),
            width=78, initial_indent=" " * 8, subsequent_indent=" " * 8,
        ),
    ]

    # Tolerate classes without a docstring rather than failing on None
    docstring_lines = (mark.__doc__ or "").split("\n")
    new_docstring = "\n".join([
        *docstring_lines[:2],
        *text,
        *docstring_lines[2:],
    ])
    mark.__doc__ = new_docstring
    return mark

View File

@@ -0,0 +1,200 @@
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import matplotlib as mpl
from seaborn._marks.base import (
Mark,
Mappable,
MappableBool,
MappableFloat,
MappableString,
MappableColor,
MappableStyle,
resolve_properties,
resolve_color,
document_properties,
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any
from matplotlib.artist import Artist
from seaborn._core.scales import Scale
class DotBase(Mark):
    """Marker-path resolution, drawing and legend logic shared by dot marks."""

    def _resolve_paths(self, data):
        """Return the transformed path for each resolved marker.

        A single ``MarkerStyle`` yields one path; otherwise a list of paths is
        returned, computing the path only once per distinct marker.
        """
        def get_transformed_path(m):
            return m.get_path().transformed(m.get_transform())

        marker = data["marker"]
        if isinstance(marker, mpl.markers.MarkerStyle):
            return get_transformed_path(marker)

        path_cache = {}
        paths = []
        for m in marker:
            if m not in path_cache:
                path_cache[m] = get_transformed_path(m)
            paths.append(path_cache[m])
        return paths

    def _resolve_properties(self, data, scales):
        """Resolve mark properties and derive ``path``, ``size`` and ``fill``."""
        resolved = resolve_properties(self, data, scales)
        resolved["path"] = self._resolve_paths(resolved)
        resolved["size"] = resolved["pointsize"] ** 2

        markers = resolved["marker"]
        if isinstance(data, dict):  # Properties for single dot
            filled_marker = markers.is_filled()
        else:
            filled_marker = [m.is_filled() for m in markers]

        # Markers that cannot be filled override the fill property
        resolved["fill"] = resolved["fill"] * filled_marker

        return resolved

    def _plot(self, split_gen, scales, orient):
        """Add one PathCollection per data subset to its axes."""
        # TODO Not backcompat with allowed (but nonfunctional) univariate plots
        # (That should be solved upstream by defaulting to "" for unset x/y?)
        # (Be mindful of xmin/xmax, etc!)

        for _, subset, ax in split_gen():

            # Positions must be read before the data is swapped for properties
            offsets = np.column_stack([subset["x"], subset["y"]])
            props = self._resolve_properties(subset, scales)

            points = mpl.collections.PathCollection(
                offsets=offsets,
                paths=props["path"],
                sizes=props["size"],
                facecolors=props["facecolor"],
                edgecolors=props["edgecolor"],
                linewidths=props["linewidth"],
                linestyles=props["edgestyle"],
                transOffset=ax.transData,
                transform=mpl.transforms.IdentityTransform(),
                **self.artist_kws,
            )
            ax.add_collection(points)

    def _legend_artist(
        self, variables: list[str], value: Any, scales: dict[str, Scale],
    ) -> Artist:
        """Return a single-dot collection showing how `value` is encoded."""
        props = self._resolve_properties({v: value for v in variables}, scales)

        return mpl.collections.PathCollection(
            paths=[props["path"]],
            sizes=[props["size"]],
            facecolors=[props["facecolor"]],
            edgecolors=[props["edgecolor"]],
            linewidths=[props["linewidth"]],
            linestyles=[props["edgestyle"]],
            transform=mpl.transforms.IdentityTransform(),
            **self.artist_kws,
        )
@document_properties
@dataclass
class Dot(DotBase):
    """
    A mark suitable for dot plots or less-dense scatterplots.

    See also
    --------
    Dots : A dot mark defined by strokes to better handle overplotting.

    Examples
    --------
    .. include:: ../docstrings/objects.Dot.rst

    """
    marker: MappableString = Mappable("o", grouping=False)
    pointsize: MappableFloat = Mappable(6, grouping=False)  # TODO rcParam?
    stroke: MappableFloat = Mappable(.75, grouping=False)  # TODO rcParam?
    color: MappableColor = Mappable("C0", grouping=False)
    alpha: MappableFloat = Mappable(1, grouping=False)
    fill: MappableBool = Mappable(True, grouping=False)
    edgecolor: MappableColor = Mappable(depend="color", grouping=False)
    edgealpha: MappableFloat = Mappable(depend="alpha", grouping=False)
    edgewidth: MappableFloat = Mappable(.5, grouping=False)  # TODO rcParam?
    edgestyle: MappableStyle = Mappable("-", grouping=False)

    def _resolve_properties(self, data, scales):
        """Add ``linewidth``, ``edgecolor`` and ``facecolor`` to the base result.

        Filled dots use the edge properties for their outline; unfilled dots
        use ``stroke`` and the main color, and get a fully transparent face.
        """
        resolved = super()._resolve_properties(data, scales)
        filled = resolved["fill"]

        stroke_width = resolved["stroke"]
        edge_width = resolved["edgewidth"]
        resolved["linewidth"] = np.where(filled, edge_width, stroke_width)

        main_color = resolve_color(self, data, "", scales)
        edge_color = resolve_color(self, data, "edge", scales)

        if not np.isscalar(filled):
            # Expand dims to use in np.where with rgba arrays
            filled = filled[:, None]
        resolved["edgecolor"] = np.where(filled, edge_color, main_color)

        filled = np.squeeze(filled)
        if isinstance(main_color, tuple):
            # TODO handle this in resolve_color
            rgb = main_color[:3]
            face = tuple([*rgb, main_color[3] * filled])
        else:
            face = np.c_[main_color[:, :3], main_color[:, 3] * filled]
        resolved["facecolor"] = face

        return resolved
@document_properties
@dataclass
class Dots(DotBase):
    """
    A dot mark defined by strokes to better handle overplotting.

    See also
    --------
    Dot : A mark suitable for dot plots or less-dense scatterplots.

    Examples
    --------
    .. include:: ../docstrings/objects.Dots.rst

    """
    # TODO retype marker as MappableMarker
    marker: MappableString = Mappable(rc="scatter.marker", grouping=False)
    pointsize: MappableFloat = Mappable(4, grouping=False)  # TODO rcParam?
    stroke: MappableFloat = Mappable(.75, grouping=False)  # TODO rcParam?
    color: MappableColor = Mappable("C0", grouping=False)
    alpha: MappableFloat = Mappable(1, grouping=False)  # TODO auto alpha?
    fill: MappableBool = Mappable(True, grouping=False)
    fillcolor: MappableColor = Mappable(depend="color", grouping=False)
    fillalpha: MappableFloat = Mappable(.2, grouping=False)

    def _resolve_properties(self, data, scales):
        """Add ``linewidth``, colors and a default ``edgestyle`` to the base result."""
        resolved = super()._resolve_properties(data, scales)

        # The stroke is drawn as the marker outline
        resolved["linewidth"] = resolved.pop("stroke")
        resolved["facecolor"] = resolve_color(self, data, "fill", scales)
        resolved["edgecolor"] = resolve_color(self, data, "", scales)
        resolved.setdefault("edgestyle", (0, None))

        face = resolved["facecolor"]
        fill = resolved["fill"]
        if isinstance(face, tuple):
            # Single RGBA tuple: scale its alpha channel by the fill flag
            face = face[0], face[1], face[2], face[3] * fill
        else:
            face[:, 3] = face[:, 3] * fill  # TODO Is inplace mod a problem?
        resolved["facecolor"] = face

        return resolved

View File

@@ -0,0 +1,285 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import ClassVar
import numpy as np
import matplotlib as mpl
from seaborn._marks.base import (
Mark,
Mappable,
MappableFloat,
MappableString,
MappableColor,
resolve_properties,
resolve_color,
document_properties,
)
@document_properties
@dataclass
class Path(Mark):
    """
    A mark connecting data points in the order they appear.

    See also
    --------
    Line : A mark connecting data points with sorting along the orientation axis.
    Paths : A faster but less-flexible mark for drawing many paths.

    Examples
    --------
    .. include:: ../docstrings/objects.Path.rst

    """
    color: MappableColor = Mappable("C0")
    alpha: MappableFloat = Mappable(1)
    linewidth: MappableFloat = Mappable(rc="lines.linewidth")
    linestyle: MappableString = Mappable(rc="lines.linestyle")
    marker: MappableString = Mappable(rc="lines.marker")
    pointsize: MappableFloat = Mappable(rc="lines.markersize")
    fillcolor: MappableColor = Mappable(depend="color")
    edgecolor: MappableColor = Mappable(depend="color")
    edgewidth: MappableFloat = Mappable(rc="lines.markeredgewidth")

    # Subclasses flip this to sort each subset along the orient variable
    _sort: ClassVar[bool] = False

    def _plot(self, split_gen, scales, orient):
        """Add one Line2D per data subset to its axes."""
        for keys, subset, ax in split_gen(keep_na=not self._sort):

            props = resolve_properties(self, keys, scales)
            props["color"] = resolve_color(self, keys, scales=scales)
            props["fillcolor"] = resolve_color(self, keys, prefix="fill", scales=scales)
            props["edgecolor"] = resolve_color(self, keys, prefix="edge", scales=scales)

            if self._sort:
                subset = subset.sort_values(orient, kind="mergesort")

            # Work on a copy so the mark's own kwargs are never modified
            extra_kws = self.artist_kws.copy()
            self._handle_capstyle(extra_kws, props)

            line_kws = dict(
                color=props["color"],
                linewidth=props["linewidth"],
                linestyle=props["linestyle"],
                marker=props["marker"],
                markersize=props["pointsize"],
                markerfacecolor=props["fillcolor"],
                markeredgecolor=props["edgecolor"],
                markeredgewidth=props["edgewidth"],
            )
            xs = subset["x"].to_numpy()
            ys = subset["y"].to_numpy()
            ax.add_line(mpl.lines.Line2D(xs, ys, **line_kws, **extra_kws))

    def _legend_artist(self, variables, value, scales):
        """Return an empty Line2D styled like the lines drawn for `value`."""
        keys = {v: value for v in variables}
        props = resolve_properties(self, keys, scales)
        props["color"] = resolve_color(self, keys, scales=scales)
        props["fillcolor"] = resolve_color(self, keys, prefix="fill", scales=scales)
        props["edgecolor"] = resolve_color(self, keys, prefix="edge", scales=scales)

        extra_kws = self.artist_kws.copy()
        self._handle_capstyle(extra_kws, props)

        line_kws = dict(
            color=props["color"],
            linewidth=props["linewidth"],
            linestyle=props["linestyle"],
            marker=props["marker"],
            markersize=props["pointsize"],
            markerfacecolor=props["fillcolor"],
            markeredgecolor=props["edgecolor"],
            markeredgewidth=props["edgewidth"],
        )
        return mpl.lines.Line2D([], [], **line_kws, **extra_kws)

    def _handle_capstyle(self, kws, vals):
        """Copy the solid capstyle to ``dash_capstyle`` for undashed lines.

        `kws` is modified in place.
        """
        # Work around for this matplotlib issue:
        # https://github.com/matplotlib/matplotlib/issues/23437
        if vals["linestyle"][1] is not None:
            return
        kws["dash_capstyle"] = kws.get(
            "solid_capstyle", mpl.rcParams["lines.solid_capstyle"]
        )
@document_properties
@dataclass
class Line(Path):
    """
    A mark connecting data points with sorting along the orientation axis.

    See also
    --------
    Path : A mark connecting data points in the order they appear.
    Lines : A faster but less-flexible mark for drawing many lines.

    Examples
    --------
    .. include:: ../docstrings/objects.Line.rst

    """
    # Same drawing code as Path, with sorting along the orient variable enabled
    _sort: ClassVar[bool] = True
@document_properties
@dataclass
class Paths(Mark):
    """
    A faster but less-flexible mark for drawing many paths.

    See also
    --------
    Path : A mark connecting data points in the order they appear.

    Examples
    --------
    .. include:: ../docstrings/objects.Paths.rst

    """
    color: MappableColor = Mappable("C0")
    alpha: MappableFloat = Mappable(1)
    linewidth: MappableFloat = Mappable(rc="lines.linewidth")
    linestyle: MappableString = Mappable(rc="lines.linestyle")

    # Subclasses flip this to sort each subset along the orient variable
    _sort: ClassVar[bool] = False

    def __post_init__(self):
        """Default the collection capstyle to the ``lines.solid_capstyle`` rcParam."""
        # LineCollection artists have a capstyle property but don't source its value
        # from the rc, so we do that manually here. Unfortunately, because we add
        # only one LineCollection, we have to use the same capstyle for all lines
        # even when they are dashed. It's a slight inconsistency, but looks fine IMO.
        self.artist_kws.setdefault("capstyle", mpl.rcParams["lines.solid_capstyle"])

    def _plot(self, split_gen, scales, orient):
        """Collect the segments of every subset into one LineCollection per axes."""
        line_data = {}
        for keys, data, ax in split_gen(keep_na=not self._sort):

            ax_data = line_data.setdefault(ax, {
                "segments": [],
                "colors": [],
                "linewidths": [],
                "linestyles": [],
            })

            segments = self._setup_segments(data, orient)
            ax_data["segments"].extend(segments)

            # Properties are resolved once per subset and repeated per segment
            n = len(segments)
            vals = resolve_properties(self, keys, scales)
            vals["color"] = resolve_color(self, keys, scales=scales)

            ax_data["colors"].extend([vals["color"]] * n)
            ax_data["linewidths"].extend([vals["linewidth"]] * n)
            ax_data["linestyles"].extend([vals["linestyle"]] * n)

        for ax, ax_data in line_data.items():
            lines = mpl.collections.LineCollection(**ax_data, **self.artist_kws)
            # Handle datalim update manually
            # https://github.com/matplotlib/matplotlib/issues/23129
            ax.add_collection(lines, autolim=False)
            if ax_data["segments"]:
                xy = np.concatenate(ax_data["segments"])
                ax.update_datalim(xy)

    def _legend_artist(self, variables, value, scales):
        """Return an empty Line2D styled like the lines drawn for `value`."""
        keys = {v: value for v in variables}
        key = resolve_properties(self, keys, scales)
        # Go through resolve_color, as _plot does, so that the legend handle
        # reflects the alpha property and not just the bare color
        key["color"] = resolve_color(self, keys, scales=scales)

        # Line2D has no single "capstyle" argument; apply the collection's
        # capstyle to both the solid and the dashed variants
        artist_kws = self.artist_kws.copy()
        capstyle = artist_kws.pop("capstyle")
        artist_kws["solid_capstyle"] = capstyle
        artist_kws["dash_capstyle"] = capstyle

        return mpl.lines.Line2D(
            [], [],
            color=key["color"],
            linewidth=key["linewidth"],
            linestyle=key["linestyle"],
            **artist_kws,
        )

    def _setup_segments(self, data, orient):
        """Return the subset as a one-element list holding its (x, y) array."""
        if self._sort:
            data = data.sort_values(orient, kind="mergesort")

        # Column stack to avoid block consolidation
        xy = np.column_stack([data["x"], data["y"]])

        return [xy]
@document_properties
@dataclass
class Lines(Paths):
    """
    A faster but less-flexible mark for drawing many lines.

    See also
    --------
    Line : A mark connecting data points with sorting along the orientation axis.

    Examples
    --------
    .. include:: ../docstrings/objects.Lines.rst

    """
    # Same drawing code as Paths, with sorting along the orient variable enabled
    _sort: ClassVar[bool] = True
@document_properties
@dataclass
class Range(Paths):
    """
    An oriented line mark drawn between min/max values.

    Examples
    --------
    .. include:: ../docstrings/objects.Range.rst

    """
    def _setup_segments(self, data, orient):
        """Return one min-to-max segment per level of the orient variable.

        When the data has neither a ``<v>min`` nor a ``<v>max`` column, both
        are computed by aggregating the value variable within each level.
        """
        # TODO better checks on what variables we have
        # TODO what if only one exist?
        value_var = {"x": "y", "y": "x"}[orient]
        lo_col = f"{value_var}min"
        hi_col = f"{value_var}max"

        if {lo_col, hi_col}.isdisjoint(data.columns):
            spec = {lo_col: (value_var, "min"), hi_col: (value_var, "max")}
            data = data.groupby(orient).agg(**spec).reset_index()

        # Stack the min and max columns into long form: two rows per level
        wide = data[[orient, lo_col, hi_col]]
        long = wide.melt(orient, value_name=value_var)[["x", "y"]]

        segments = []
        for _, endpoints in long.groupby(orient):
            segments.append(endpoints.to_numpy())
        return segments
@document_properties
@dataclass
class Dash(Paths):
    """
    A line mark drawn as an oriented segment for each datapoint.

    Examples
    --------
    .. include:: ../docstrings/objects.Dash.rst

    """
    width: MappableFloat = Mappable(.8, grouping=False)

    def _setup_segments(self, data, orient):
        """Return an array of segments, each centered on a data point.

        Every segment spans ``width`` along the orient axis; the result has
        one (start, end) pair of (x, y) coordinates per row of `data`.
        """
        axis = ["x", "y"].index(orient)
        centers = data[["x", "y"]].to_numpy().astype(float)
        half_width = data["width"] / 2

        # Duplicate each point, then push the two copies apart along `axis`
        segments = np.stack([centers, centers], axis=1)
        segments[:, 0, axis] -= half_width
        segments[:, 1, axis] += half_width
        return segments

View File

@@ -0,0 +1,76 @@
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
import numpy as np
import matplotlib as mpl
from matplotlib.transforms import ScaledTranslation
from seaborn._marks.base import (
Mark,
Mappable,
MappableFloat,
MappableString,
MappableColor,
resolve_properties,
resolve_color,
document_properties,
)
@document_properties
@dataclass
class Text(Mark):
    """
    A textual mark to annotate or represent data values.

    Examples
    --------
    .. include:: ../docstrings/objects.Text.rst

    """
    text: MappableString = Mappable("")
    color: MappableColor = Mappable("k")
    alpha: MappableFloat = Mappable(1)
    fontsize: MappableFloat = Mappable(rc="font.size")
    halign: MappableString = Mappable("center")
    valign: MappableString = Mappable("center_baseline")
    offset: MappableFloat = Mappable(4)

    def _plot(self, split_gen, scales, orient):
        """Add one text artist per data row and update the axes data limits."""
        positions = defaultdict(list)

        for keys, subset, ax in split_gen():

            props = resolve_properties(self, keys, scales)
            color = resolve_color(self, keys, "", scales)

            halign = props["halign"]
            valign = props["valign"]
            fontsize = props["fontsize"]

            # The offset is divided by 72 (points per inch) because the
            # translation below is expressed through figure.dpi_scale_trans
            offset = props["offset"] / 72
            x_shift = {"right": -offset, "left": +offset}.get(halign, 0)
            y_shift = {
                "top": -offset, "bottom": +offset, "baseline": +offset,
            }.get(valign, 0)
            offset_trans = ScaledTranslation(
                x_shift, y_shift, ax.figure.dpi_scale_trans,
            )

            for row in subset.to_dict("records"):
                # A mapped text value in the row wins over the resolved default
                label = str(row.get("text", props["text"]))
                artist = mpl.text.Text(
                    x=row["x"],
                    y=row["y"],
                    text=label,
                    color=color,
                    fontsize=fontsize,
                    horizontalalignment=halign,
                    verticalalignment=valign,
                    transform=ax.transData + offset_trans,
                    **self.artist_kws,
                )
                ax.add_artist(artist)
                positions[ax].append([row["x"], row["y"]])

        # The data limits are updated explicitly from the anchor positions
        for ax, ax_positions in positions.items():
            ax.update_datalim(np.array(ax_positions))