libraries
This commit is contained in:
940
.venv/lib/python3.12/site-packages/seaborn/regression.py
Normal file
940
.venv/lib/python3.12/site-packages/seaborn/regression.py
Normal file
@@ -0,0 +1,940 @@
|
||||
"""Plotting functions for linear models (broadly construed)."""
|
||||
import copy
|
||||
from textwrap import dedent
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
try:
|
||||
import statsmodels
|
||||
assert statsmodels
|
||||
_has_statsmodels = True
|
||||
except ImportError:
|
||||
_has_statsmodels = False
|
||||
|
||||
from . import utils
|
||||
from . import algorithms as algo
|
||||
from .axisgrid import FacetGrid, _facet_docs
|
||||
|
||||
|
||||
__all__ = ["lmplot", "regplot", "residplot"]
|
||||
|
||||
|
||||
class _LinearPlotter:
|
||||
"""Base class for plotting relational data in tidy format.
|
||||
|
||||
To get anything useful done you'll have to inherit from this, but setup
|
||||
code that can be abstracted out should be put here.
|
||||
|
||||
"""
|
||||
def establish_variables(self, data, **kws):
|
||||
"""Extract variables from data or use directly."""
|
||||
self.data = data
|
||||
|
||||
# Validate the inputs
|
||||
any_strings = any([isinstance(v, str) for v in kws.values()])
|
||||
if any_strings and data is None:
|
||||
raise ValueError("Must pass `data` if using named variables.")
|
||||
|
||||
# Set the variables
|
||||
for var, val in kws.items():
|
||||
if isinstance(val, str):
|
||||
vector = data[val]
|
||||
elif isinstance(val, list):
|
||||
vector = np.asarray(val)
|
||||
else:
|
||||
vector = val
|
||||
if vector is not None and vector.shape != (1,):
|
||||
vector = np.squeeze(vector)
|
||||
if np.ndim(vector) > 1:
|
||||
err = "regplot inputs must be 1d"
|
||||
raise ValueError(err)
|
||||
setattr(self, var, vector)
|
||||
|
||||
def dropna(self, *vars):
|
||||
"""Remove observations with missing data."""
|
||||
vals = [getattr(self, var) for var in vars]
|
||||
vals = [v for v in vals if v is not None]
|
||||
not_na = np.all(np.column_stack([pd.notnull(v) for v in vals]), axis=1)
|
||||
for var in vars:
|
||||
val = getattr(self, var)
|
||||
if val is not None:
|
||||
setattr(self, var, val[not_na])
|
||||
|
||||
def plot(self, ax):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _RegressionPlotter(_LinearPlotter):
    """Plotter for numeric independent variables with regression model.

    This does the computations and drawing for the `regplot` function, and
    is thus also used indirectly by `lmplot`.
    """
    def __init__(self, x, y, data=None, x_estimator=None, x_bins=None,
                 x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
                 units=None, seed=None, order=1, logistic=False, lowess=False,
                 robust=False, logx=False, x_partial=None, y_partial=None,
                 truncate=False, dropna=True, x_jitter=None, y_jitter=None,
                 color=None, label=None):
        """Store the options, resolve the variables and prepare the data.

        Raises
        ------
        ValueError
            If more than one of ``order > 1``, ``logistic``, ``robust``,
            ``lowess`` and ``logx`` is requested.

        """
        # Set member attributes
        self.x_estimator = x_estimator
        self.ci = ci
        # "ci" means: reuse the regression confidence level for the bins
        self.x_ci = ci if x_ci == "ci" else x_ci
        self.n_boot = n_boot
        self.seed = seed
        self.scatter = scatter
        self.fit_reg = fit_reg
        self.order = order
        self.logistic = logistic
        self.lowess = lowess
        self.robust = robust
        self.logx = logx
        self.truncate = truncate
        self.x_jitter = x_jitter
        self.y_jitter = y_jitter
        self.color = color
        self.label = label

        # Validate the regression options:
        if sum((order > 1, logistic, robust, lowess, logx)) > 1:
            raise ValueError("Mutually exclusive regression options.")

        # Extract the data vals from the arguments or passed dataframe
        self.establish_variables(data, x=x, y=y, units=units,
                                 x_partial=x_partial, y_partial=y_partial)

        # Drop null observations
        if dropna:
            self.dropna("x", "y", "units", "x_partial", "y_partial")

        # Regress nuisance variables out of the data
        if self.x_partial is not None:
            self.x = self.regress_out(self.x, self.x_partial)
        if self.y_partial is not None:
            self.y = self.regress_out(self.y, self.y_partial)

        # Possibly bin the predictor variable, which implies a point estimate
        if x_bins is not None:
            self.x_estimator = np.mean if x_estimator is None else x_estimator
            x_discrete, x_bins = self.bin_predictor(x_bins)
            self.x_discrete = x_discrete
        else:
            self.x_discrete = self.x

        # Disable regression in case of singleton inputs
        if len(self.x) <= 1:
            self.fit_reg = False

        # Save the range of the x variable for the grid later
        if self.fit_reg:
            self.x_range = self.x.min(), self.x.max()

    @property
    def scatter_data(self):
        """Data where each observation is a point.

        Uniform noise in ``[-jitter, jitter]`` is added to a coordinate when
        the matching ``x_jitter`` / ``y_jitter`` option is set; the stored
        vectors themselves are not modified.
        """
        x_j = self.x_jitter
        if x_j is None:
            x = self.x
        else:
            x = self.x + np.random.uniform(-x_j, x_j, len(self.x))

        y_j = self.y_jitter
        if y_j is None:
            y = self.y
        else:
            y = self.y + np.random.uniform(-y_j, y_j, len(self.y))

        return x, y

    @property
    def estimate_data(self):
        """Data with a point estimate and CI for each discrete x value.

        Returns the sorted unique x values, the ``x_estimator`` of ``y`` at
        each of them, and a matching list of intervals (``None`` entries
        when ``x_ci`` is ``None``).
        """
        x, y = self.x_discrete, self.y
        vals = sorted(np.unique(x))
        points, cis = [], []

        for val in vals:

            # Get the point estimate of the y variable
            _y = y[x == val]
            est = self.x_estimator(_y)
            points.append(est)

            # Compute the confidence interval for this estimate
            if self.x_ci is None:
                cis.append(None)
            else:
                units = None
                if self.x_ci == "sd":
                    sd = np.std(_y)
                    _ci = est - sd, est + sd
                else:
                    if self.units is not None:
                        units = self.units[x == val]
                    boots = algo.bootstrap(_y,
                                           func=self.x_estimator,
                                           n_boot=self.n_boot,
                                           units=units,
                                           seed=self.seed)
                    _ci = utils.ci(boots, self.x_ci)
                cis.append(_ci)

        return vals, points, cis

    def _check_statsmodels(self):
        """Check whether statsmodels is installed if any boolean options require it."""
        options = "logistic", "robust", "lowess"
        err = "`{}=True` requires statsmodels, an optional dependency, to be installed."
        for option in options:
            if getattr(self, option) and not _has_statsmodels:
                raise RuntimeError(err.format(option))

    def fit_regression(self, ax=None, x_range=None, grid=None):
        """Fit the regression model.

        When ``grid`` is not given, 100 evenly spaced points are used,
        spanning the data range (``truncate``), else ``x_range`` when ``ax``
        is ``None``, else the x limits of ``ax``.

        Returns
        -------
        grid, yhat, err_bands
            Evaluation points, fitted values and the confidence band;
            ``err_bands`` is ``None`` when ``ci`` is ``None`` or for lowess.

        """
        self._check_statsmodels()

        # Create the grid for the regression
        if grid is None:
            if self.truncate:
                x_min, x_max = self.x_range
            else:
                if ax is None:
                    x_min, x_max = x_range
                else:
                    x_min, x_max = ax.get_xlim()
            grid = np.linspace(x_min, x_max, 100)
        ci = self.ci

        # Fit the regression
        if self.order > 1:
            yhat, yhat_boots = self.fit_poly(grid, self.order)
        elif self.logistic:
            from statsmodels.genmod.generalized_linear_model import GLM
            from statsmodels.genmod.families import Binomial
            yhat, yhat_boots = self.fit_statsmodels(grid, GLM,
                                                    family=Binomial())
        elif self.lowess:
            # Lowess supplies its own grid and has no bootstrap results
            ci = None
            grid, yhat = self.fit_lowess()
        elif self.robust:
            from statsmodels.robust.robust_linear_model import RLM
            yhat, yhat_boots = self.fit_statsmodels(grid, RLM)
        elif self.logx:
            yhat, yhat_boots = self.fit_logx(grid)
        else:
            yhat, yhat_boots = self.fit_fast(grid)

        # Compute the confidence interval at each grid point
        if ci is None:
            err_bands = None
        else:
            err_bands = utils.ci(yhat_boots, ci, axis=0)

        return grid, yhat, err_bands

    def fit_fast(self, grid):
        """Low-level regression and prediction using linear algebra.

        Returns ``(yhat, yhat_boots)``; ``yhat_boots`` is ``None`` when
        ``ci`` is ``None``.
        """
        def reg_func(_x, _y):
            return np.linalg.pinv(_x).dot(_y)

        # Prepend a column of ones so the fit includes an intercept
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), grid]
        yhat = grid.dot(reg_func(X, y))
        if self.ci is None:
            return yhat, None

        beta_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed).T
        yhat_boots = grid.dot(beta_boots).T
        return yhat, yhat_boots

    def fit_poly(self, grid, order):
        """Regression using numpy polyfit for higher-order trends.

        Returns ``(yhat, yhat_boots)``; ``yhat_boots`` is ``None`` when
        ``ci`` is ``None``.
        """
        def reg_func(_x, _y):
            return np.polyval(np.polyfit(_x, _y, order), grid)

        x, y = self.x, self.y
        yhat = reg_func(x, y)
        if self.ci is None:
            return yhat, None

        yhat_boots = algo.bootstrap(x, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed)
        return yhat, yhat_boots

    def fit_statsmodels(self, grid, model, **kwargs):
        """More general regression function using statsmodels objects.

        ``model`` is called as ``model(y, X, **kwargs)`` and its fitted
        result is used to predict on ``grid``. A fit that hits perfect
        separation yields an all-NaN prediction instead of raising.
        """
        import statsmodels.tools.sm_exceptions as sme
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), grid]

        def reg_func(_x, _y):
            err_classes = (sme.PerfectSeparationError,)
            try:
                with warnings.catch_warnings():
                    if hasattr(sme, "PerfectSeparationWarning"):
                        # statsmodels>=0.14.0
                        warnings.simplefilter("error", sme.PerfectSeparationWarning)
                        err_classes = (*err_classes, sme.PerfectSeparationWarning)
                    yhat = model(_y, _x, **kwargs).fit().predict(grid)
            except err_classes:
                yhat = np.empty(len(grid))
                yhat.fill(np.nan)
            return yhat

        yhat = reg_func(X, y)
        if self.ci is None:
            return yhat, None

        yhat_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed)
        return yhat, yhat_boots

    def fit_lowess(self):
        """Fit a locally-weighted regression, which returns its own grid."""
        from statsmodels.nonparametric.smoothers_lowess import lowess
        grid, yhat = lowess(self.y, self.x).T
        return grid, yhat

    def fit_logx(self, grid):
        """Fit the model in log-space.

        Returns ``(yhat, yhat_boots)``; ``yhat_boots`` is ``None`` when
        ``ci`` is ``None``.
        """
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), np.log(grid)]

        def reg_func(_x, _y):
            # Log-transform the predictor column, keep the intercept column
            _x = np.c_[_x[:, 0], np.log(_x[:, 1])]
            return np.linalg.pinv(_x).dot(_y)

        yhat = grid.dot(reg_func(X, y))
        if self.ci is None:
            return yhat, None

        beta_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed).T
        yhat_boots = grid.dot(beta_boots).T
        return yhat, yhat_boots

    def bin_predictor(self, bins):
        """Discretize a predictor by assigning value to closest bin.

        A scalar ``bins`` is turned into that many interior percentiles of
        ``x``; anything else is flattened and used as the bin positions.
        Returns the binned x values and the bin positions.
        """
        x = np.asarray(self.x)
        if np.isscalar(bins):
            percentiles = np.linspace(0, 100, bins + 2)[1:-1]
            bins = np.percentile(x, percentiles)
        else:
            bins = np.ravel(bins)

        dist = np.abs(np.subtract.outer(x, bins))
        x_binned = bins[np.argmin(dist, axis=1)].ravel()

        return x_binned, bins

    def regress_out(self, a, b):
        """Regress b from a keeping a's original mean."""
        a_mean = a.mean()
        a = a - a_mean
        b = b - b.mean()
        b = np.c_[b]
        a_prime = a - b.dot(np.linalg.pinv(b).dot(a))
        return np.asarray(a_prime + a_mean).reshape(a.shape)

    def plot(self, ax, scatter_kws, line_kws):
        """Draw the full plot.

        ``scatter_kws`` and ``line_kws`` are modified in place (label and
        default color are inserted).
        """
        # Insert the plot label into the correct set of keyword arguments
        if self.scatter:
            scatter_kws["label"] = self.label
        else:
            line_kws["label"] = self.label

        # Use the current color cycle state as a default
        if self.color is None:
            lines, = ax.plot([], [])
            color = lines.get_color()
            lines.remove()
        else:
            color = self.color

        # Ensure that color is hex to avoid matplotlib weirdness
        color = mpl.colors.rgb2hex(mpl.colors.colorConverter.to_rgb(color))

        # Let color in keyword arguments override overall plot color
        scatter_kws.setdefault("color", color)
        line_kws.setdefault("color", color)

        # Draw the constituent plots
        if self.scatter:
            self.scatterplot(ax, scatter_kws)

        if self.fit_reg:
            self.lineplot(ax, line_kws)

        # Label the axes
        if hasattr(self.x, "name"):
            ax.set_xlabel(self.x.name)
        if hasattr(self.y, "name"):
            ax.set_ylabel(self.y.name)

    def scatterplot(self, ax, kws):
        """Draw the data.

        Without an ``x_estimator`` the raw (possibly jittered) observations
        are drawn; otherwise one point per discrete x value is drawn,
        together with its interval when one was computed.
        """
        # Treat the line-based markers specially, explicitly setting larger
        # linewidth than is provided by the seaborn style defaults.
        # This would ideally be handled better in matplotlib (i.e., distinguish
        # between edgewidth for solid glyphs and linewidth for line glyphs
        # but this should do for now.
        line_markers = ["1", "2", "3", "4", "+", "x", "|", "_"]
        if self.x_estimator is None:
            if "marker" in kws and kws["marker"] in line_markers:
                lw = mpl.rcParams["lines.linewidth"]
            else:
                lw = mpl.rcParams["lines.markeredgewidth"]
            kws.setdefault("linewidths", lw)

            # Only add a default alpha when the color does not already carry
            # one. Inspect the last axis so that a single 1-D RGBA array is
            # handled as well as an (n, 4) array; indexing ``shape[1]``
            # raised IndexError for 1-D inputs.
            shape = getattr(kws["color"], "shape", None)
            carries_alpha = bool(shape) and shape[-1] >= 4
            if not carries_alpha:
                kws.setdefault("alpha", .8)

            x, y = self.scatter_data
            ax.scatter(x, y, **kws)
        else:
            # TODO abstraction
            ci_kws = {"color": kws["color"]}
            if "alpha" in kws:
                ci_kws["alpha"] = kws["alpha"]
            ci_kws["linewidth"] = mpl.rcParams["lines.linewidth"] * 1.75
            kws.setdefault("s", 50)

            xs, ys, cis = self.estimate_data
            if any(ci is not None for ci in cis):
                for x, ci in zip(xs, cis):
                    ax.plot([x, x], ci, **ci_kws)
            ax.scatter(xs, ys, **kws)

    def lineplot(self, ax, kws):
        """Draw the model.

        Plots the fitted line and, when available, a translucent band for
        the confidence interval.
        """
        # Fit the regression model
        grid, yhat, err_bands = self.fit_regression(ax)
        edges = grid[0], grid[-1]

        # Get set default aesthetics
        fill_color = kws["color"]
        lw = kws.pop("lw", mpl.rcParams["lines.linewidth"] * 1.5)
        kws.setdefault("linewidth", lw)

        # Draw the regression line and confidence interval
        line, = ax.plot(grid, yhat, **kws)
        if not self.truncate:
            line.sticky_edges.x[:] = edges  # Prevent mpl from adding margin
        if err_bands is not None:
            ax.fill_between(grid, *err_bands, facecolor=fill_color, alpha=.15)
|
||||
|
||||
|
||||
_regression_docs = dict(
|
||||
|
||||
model_api=dedent("""\
|
||||
There are a number of mutually exclusive options for estimating the
|
||||
regression model. See the :ref:`tutorial <regression_tutorial>` for more
|
||||
information.\
|
||||
"""),
|
||||
regplot_vs_lmplot=dedent("""\
|
||||
The :func:`regplot` and :func:`lmplot` functions are closely related, but
|
||||
the former is an axes-level function while the latter is a figure-level
|
||||
function that combines :func:`regplot` and :class:`FacetGrid`.\
|
||||
"""),
|
||||
x_estimator=dedent("""\
|
||||
x_estimator : callable that maps vector -> scalar, optional
|
||||
Apply this function to each unique value of ``x`` and plot the
|
||||
resulting estimate. This is useful when ``x`` is a discrete variable.
|
||||
If ``x_ci`` is given, this estimate will be bootstrapped and a
|
||||
confidence interval will be drawn.\
|
||||
"""),
|
||||
x_bins=dedent("""\
|
||||
x_bins : int or vector, optional
|
||||
Bin the ``x`` variable into discrete bins and then estimate the central
|
||||
tendency and a confidence interval. This binning only influences how
|
||||
the scatterplot is drawn; the regression is still fit to the original
|
||||
data. This parameter is interpreted either as the number of
|
||||
evenly-sized (not necessary spaced) bins or the positions of the bin
|
||||
centers. When this parameter is used, it implies that the default of
|
||||
``x_estimator`` is ``numpy.mean``.\
|
||||
"""),
|
||||
x_ci=dedent("""\
|
||||
x_ci : "ci", "sd", int in [0, 100] or None, optional
|
||||
Size of the confidence interval used when plotting a central tendency
|
||||
for discrete values of ``x``. If ``"ci"``, defer to the value of the
|
||||
``ci`` parameter. If ``"sd"``, skip bootstrapping and show the
|
||||
standard deviation of the observations in each bin.\
|
||||
"""),
|
||||
scatter=dedent("""\
|
||||
scatter : bool, optional
|
||||
If ``True``, draw a scatterplot with the underlying observations (or
|
||||
the ``x_estimator`` values).\
|
||||
"""),
|
||||
fit_reg=dedent("""\
|
||||
fit_reg : bool, optional
|
||||
If ``True``, estimate and plot a regression model relating the ``x``
|
||||
and ``y`` variables.\
|
||||
"""),
|
||||
ci=dedent("""\
|
||||
ci : int in [0, 100] or None, optional
|
||||
Size of the confidence interval for the regression estimate. This will
|
||||
be drawn using translucent bands around the regression line. The
|
||||
confidence interval is estimated using a bootstrap; for large
|
||||
datasets, it may be advisable to avoid that computation by setting
|
||||
this parameter to None.\
|
||||
"""),
|
||||
n_boot=dedent("""\
|
||||
n_boot : int, optional
|
||||
Number of bootstrap resamples used to estimate the ``ci``. The default
|
||||
value attempts to balance time and stability; you may want to increase
|
||||
this value for "final" versions of plots.\
|
||||
"""),
|
||||
units=dedent("""\
|
||||
units : variable name in ``data``, optional
|
||||
If the ``x`` and ``y`` observations are nested within sampling units,
|
||||
those can be specified here. This will be taken into account when
|
||||
computing the confidence intervals by performing a multilevel bootstrap
|
||||
that resamples both units and observations (within unit). This does not
|
||||
otherwise influence how the regression is estimated or drawn.\
|
||||
"""),
|
||||
seed=dedent("""\
|
||||
seed : int, numpy.random.Generator, or numpy.random.RandomState, optional
|
||||
Seed or random number generator for reproducible bootstrapping.\
|
||||
"""),
|
||||
order=dedent("""\
|
||||
order : int, optional
|
||||
If ``order`` is greater than 1, use ``numpy.polyfit`` to estimate a
|
||||
polynomial regression.\
|
||||
"""),
|
||||
logistic=dedent("""\
|
||||
logistic : bool, optional
|
||||
If ``True``, assume that ``y`` is a binary variable and use
|
||||
``statsmodels`` to estimate a logistic regression model. Note that this
|
||||
is substantially more computationally intensive than linear regression,
|
||||
so you may wish to decrease the number of bootstrap resamples
|
||||
(``n_boot``) or set ``ci`` to None.\
|
||||
"""),
|
||||
lowess=dedent("""\
|
||||
lowess : bool, optional
|
||||
If ``True``, use ``statsmodels`` to estimate a nonparametric lowess
|
||||
model (locally weighted linear regression). Note that confidence
|
||||
intervals cannot currently be drawn for this kind of model.\
|
||||
"""),
|
||||
robust=dedent("""\
|
||||
robust : bool, optional
|
||||
If ``True``, use ``statsmodels`` to estimate a robust regression. This
|
||||
will de-weight outliers. Note that this is substantially more
|
||||
computationally intensive than standard linear regression, so you may
|
||||
wish to decrease the number of bootstrap resamples (``n_boot``) or set
|
||||
``ci`` to None.\
|
||||
"""),
|
||||
logx=dedent("""\
|
||||
logx : bool, optional
|
||||
If ``True``, estimate a linear regression of the form y ~ log(x), but
|
||||
plot the scatterplot and regression model in the input space. Note that
|
||||
``x`` must be positive for this to work.\
|
||||
"""),
|
||||
xy_partial=dedent("""\
|
||||
{x,y}_partial : strings in ``data`` or matrices
|
||||
Confounding variables to regress out of the ``x`` or ``y`` variables
|
||||
before plotting.\
|
||||
"""),
|
||||
truncate=dedent("""\
|
||||
truncate : bool, optional
|
||||
If ``True``, the regression line is bounded by the data limits. If
|
||||
``False``, it extends to the ``x`` axis limits.
|
||||
"""),
|
||||
xy_jitter=dedent("""\
|
||||
{x,y}_jitter : floats, optional
|
||||
Add uniform random noise of this size to either the ``x`` or ``y``
|
||||
variables. The noise is added to a copy of the data after fitting the
|
||||
regression, and only influences the look of the scatterplot. This can
|
||||
be helpful when plotting variables that take discrete values.\
|
||||
"""),
|
||||
scatter_line_kws=dedent("""\
|
||||
{scatter,line}_kws : dictionaries
|
||||
Additional keyword arguments to pass to ``plt.scatter`` and
|
||||
``plt.plot``.\
|
||||
"""),
|
||||
)
|
||||
# Add the entries of ``_facet_docs`` (imported from ``.axisgrid``) to the
# mapping; presumably these supply the FacetGrid-related fields such as
# ``{data}`` and ``{palette}`` used by the docstring templates below.
_regression_docs.update(_facet_docs)
|
||||
|
||||
|
||||
def lmplot(
    data, *,
    x=None, y=None, hue=None, col=None, row=None,
    palette=None, col_wrap=None, height=5, aspect=1, markers="o",
    sharex=None, sharey=None, hue_order=None, col_order=None, row_order=None,
    legend=True, legend_out=None, x_estimator=None, x_bins=None,
    x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
    units=None, seed=None, order=1, logistic=False, lowess=False,
    robust=False, logx=False, x_partial=None, y_partial=None,
    truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None,
    line_kws=None, facet_kws=None,
):
    # NOTE: the user-facing docstring is attached to ``lmplot.__doc__`` right
    # after this definition, so only implementation comments live here.

    if facet_kws is None:
        facet_kws = {}

    def facet_kw_deprecation(key, val):
        """Warn about a deprecated FacetGrid argument and forward it."""
        msg = (
            f"{key} is deprecated from the `lmplot` function signature. "
            "Please update your code to pass it using `facet_kws`."
        )
        if val is not None:
            warnings.warn(msg, UserWarning)
            facet_kws[key] = val

    facet_kw_deprecation("sharex", sharex)
    facet_kw_deprecation("sharey", sharey)
    facet_kw_deprecation("legend_out", legend_out)

    if data is None:
        raise TypeError("Missing required keyword argument `data`.")

    # Reduce the dataframe to only needed columns.
    # De-duplicate while keeping the labels as given: ``np.unique`` would
    # coerce labels of mixed types (e.g. int and str) to strings, after
    # which the column lookup fails.
    need_cols = [x, y, hue, col, row, units, x_partial, y_partial]
    cols = list(dict.fromkeys(a for a in need_cols if a is not None))
    data = data[cols]

    # Initialize the grid
    facets = FacetGrid(
        data, row=row, col=col, hue=hue,
        palette=palette,
        row_order=row_order, col_order=col_order, hue_order=hue_order,
        height=height, aspect=aspect, col_wrap=col_wrap,
        **facet_kws,
    )

    # Add the markers here as FacetGrid has figured out how many levels of the
    # hue variable are needed and we don't want to duplicate that process
    if facets.hue_names is None:
        n_markers = 1
    else:
        n_markers = len(facets.hue_names)
    if not isinstance(markers, list):
        markers = [markers] * n_markers
    if len(markers) != n_markers:
        raise ValueError("markers must be a singleton or a list of markers "
                         "for each level of the hue variable")
    facets.hue_kws = {"marker": markers}

    def update_datalim(data, x, y, ax, **kws):
        """Extend the x data limits of ``ax`` to cover this facet's data."""
        xys = data[[x, y]].to_numpy().astype(float)
        ax.update_datalim(xys, updatey=False)
        ax.autoscale_view(scaley=False)

    facets.map_dataframe(update_datalim, x=x, y=y)

    # Draw the regression plot on each facet
    regplot_kws = dict(
        x_estimator=x_estimator, x_bins=x_bins, x_ci=x_ci,
        scatter=scatter, fit_reg=fit_reg, ci=ci, n_boot=n_boot, units=units,
        seed=seed, order=order, logistic=logistic, lowess=lowess,
        robust=robust, logx=logx, x_partial=x_partial, y_partial=y_partial,
        truncate=truncate, x_jitter=x_jitter, y_jitter=y_jitter,
        scatter_kws=scatter_kws, line_kws=line_kws,
    )
    facets.map_dataframe(regplot, x=x, y=y, **regplot_kws)
    facets.set_axis_labels(x, y)

    # Add a legend
    if legend and (hue is not None) and (hue not in [col, row]):
        facets.add_legend()
    return facets
|
||||
|
||||
|
||||
lmplot.__doc__ = dedent("""\
|
||||
Plot data and regression model fits across a FacetGrid.
|
||||
|
||||
This function combines :func:`regplot` and :class:`FacetGrid`. It is
|
||||
intended as a convenient interface to fit regression models across
|
||||
conditional subsets of a dataset.
|
||||
|
||||
When thinking about how to assign variables to different facets, a general
|
||||
rule is that it makes sense to use ``hue`` for the most important
|
||||
comparison, followed by ``col`` and ``row``. However, always think about
|
||||
your particular dataset and the goals of the visualization you are
|
||||
creating.
|
||||
|
||||
{model_api}
|
||||
|
||||
The parameters to this function span most of the options in
|
||||
:class:`FacetGrid`, although there may be occasional cases where you will
|
||||
want to use that class and :func:`regplot` directly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
{data}
|
||||
x, y : strings, optional
|
||||
Input variables; these should be column names in ``data``.
|
||||
hue, col, row : strings
|
||||
Variables that define subsets of the data, which will be drawn on
|
||||
separate facets in the grid. See the ``*_order`` parameters to control
|
||||
the order of levels of this variable.
|
||||
{palette}
|
||||
{col_wrap}
|
||||
{height}
|
||||
{aspect}
|
||||
markers : matplotlib marker code or list of marker codes, optional
|
||||
Markers for the scatterplot. If a list, each marker in the list will be
|
||||
used for each level of the ``hue`` variable.
|
||||
{share_xy}
|
||||
|
||||
.. deprecated:: 0.12.0
|
||||
Pass using the `facet_kws` dictionary.
|
||||
|
||||
{{hue,col,row}}_order : lists, optional
|
||||
Order for the levels of the faceting variables. By default, this will
|
||||
be the order that the levels appear in ``data`` or, if the variables
|
||||
are pandas categoricals, the category order.
|
||||
legend : bool, optional
|
||||
If ``True`` and there is a ``hue`` variable, add a legend.
|
||||
{legend_out}
|
||||
|
||||
.. deprecated:: 0.12.0
|
||||
Pass using the `facet_kws` dictionary.
|
||||
|
||||
{x_estimator}
|
||||
{x_bins}
|
||||
{x_ci}
|
||||
{scatter}
|
||||
{fit_reg}
|
||||
{ci}
|
||||
{n_boot}
|
||||
{units}
|
||||
{seed}
|
||||
{order}
|
||||
{logistic}
|
||||
{lowess}
|
||||
{robust}
|
||||
{logx}
|
||||
{xy_partial}
|
||||
{truncate}
|
||||
{xy_jitter}
|
||||
{scatter_line_kws}
|
||||
facet_kws : dict
|
||||
Dictionary of keyword arguments for :class:`FacetGrid`.
|
||||
|
||||
See Also
|
||||
--------
|
||||
regplot : Plot data and a conditional model fit.
|
||||
FacetGrid : Subplot grid for plotting conditional relationships.
|
||||
pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
|
||||
``kind="reg"``).
|
||||
|
||||
Notes
|
||||
-----
|
||||
|
||||
{regplot_vs_lmplot}
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
.. include:: ../docstrings/lmplot.rst
|
||||
|
||||
""").format(**_regression_docs)
|
||||
|
||||
|
||||
def regplot(
    data=None, *, x=None, y=None,
    x_estimator=None, x_bins=None, x_ci="ci",
    scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None,
    seed=None, order=1, logistic=False, lowess=False, robust=False,
    logx=False, x_partial=None, y_partial=None,
    truncate=True, dropna=True, x_jitter=None, y_jitter=None,
    label=None, color=None, marker="o",
    scatter_kws=None, line_kws=None, ax=None
):
    # NOTE: the user-facing docstring is attached to ``regplot.__doc__``
    # right after this definition.

    # Build the plotter first: it validates the options and prepares the data
    plotter = _RegressionPlotter(
        x, y,
        data=data,
        x_estimator=x_estimator, x_bins=x_bins, x_ci=x_ci,
        scatter=scatter, fit_reg=fit_reg,
        ci=ci, n_boot=n_boot, units=units, seed=seed,
        order=order, logistic=logistic, lowess=lowess,
        robust=robust, logx=logx,
        x_partial=x_partial, y_partial=y_partial,
        truncate=truncate, dropna=dropna,
        x_jitter=x_jitter, y_jitter=y_jitter,
        color=color, label=label,
    )

    # Fall back to the current Axes when none was supplied
    target_ax = plt.gca() if ax is None else ax

    # Work on shallow copies so the caller's dictionaries stay untouched
    if scatter_kws is None:
        point_kws = {}
    else:
        point_kws = copy.copy(scatter_kws)
    point_kws["marker"] = marker

    if line_kws is None:
        fit_kws = {}
    else:
        fit_kws = copy.copy(line_kws)

    plotter.plot(target_ax, point_kws, fit_kws)
    return target_ax
|
||||
|
||||
|
||||
# The docstring is built from a template whose ``{name}`` fields are filled
# from ``_regression_docs``.
regplot.__doc__ = dedent("""\
    Plot data and a linear regression model fit.

    {model_api}

    Parameters
    ----------
    x, y : string, series, or vector array
        Input variables. If strings, these should correspond with column names
        in ``data``. When pandas objects are used, axes will be labeled with
        the series name.
    {data}
    {x_estimator}
    {x_bins}
    {x_ci}
    {scatter}
    {fit_reg}
    {ci}
    {n_boot}
    {units}
    {seed}
    {order}
    {logistic}
    {lowess}
    {robust}
    {logx}
    {xy_partial}
    {truncate}
    {xy_jitter}
    label : string
        Label to apply to either the scatterplot or regression line (if
        ``scatter`` is ``False``) for use in a legend.
    color : matplotlib color
        Color to apply to all plot elements; will be superseded by colors
        passed in ``scatter_kws`` or ``line_kws``.
    marker : matplotlib marker code
        Marker to use for the scatterplot glyphs.
    {scatter_line_kws}
    ax : matplotlib Axes, optional
        Axes object to draw the plot onto, otherwise uses the current Axes.

    Returns
    -------
    ax : matplotlib Axes
        The Axes object containing the plot.

    See Also
    --------
    lmplot : Combine :func:`regplot` and :class:`FacetGrid` to plot multiple
             linear relationships in a dataset.
    jointplot : Combine :func:`regplot` and :class:`JointGrid` (when used with
                ``kind="reg"``).
    pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
               ``kind="reg"``).
    residplot : Plot the residuals of a linear regression model.

    Notes
    -----

    {regplot_vs_lmplot}


    It's also easy to combine :func:`regplot` and :class:`JointGrid` or
    :class:`PairGrid` through the :func:`jointplot` and :func:`pairplot`
    functions, although these do not directly accept all of :func:`regplot`'s
    parameters.

    Examples
    --------

    .. include:: ../docstrings/regplot.rst

    """).format(**_regression_docs)
|
||||
|
||||
|
||||
def residplot(
    data=None, *, x=None, y=None,
    x_partial=None, y_partial=None, lowess=False,
    order=1, robust=False, dropna=True, label=None, color=None,
    scatter_kws=None, line_kws=None, ax=None
):
    """Plot the residuals of a linear regression.

    This function will regress y on x (possibly as a robust or polynomial
    regression) and then draw a scatterplot of the residuals. You can
    optionally fit a lowess smoother to the residual plot, which can
    help in determining if there is structure to the residuals.

    Parameters
    ----------
    data : DataFrame, optional
        DataFrame to use if `x` and `y` are column names.
    x : vector or string
        Data or column name in `data` for the predictor variable.
    y : vector or string
        Data or column name in `data` for the response variable.
    {x, y}_partial : vectors or string(s), optional
        These variables are treated as confounding and are removed from
        the `x` or `y` variables before plotting.
    lowess : boolean, optional
        Fit a lowess smoother to the residual scatterplot.
    order : int, optional
        Order of the polynomial to fit when calculating the residuals.
    robust : boolean, optional
        Fit a robust linear regression when calculating the residuals.
    dropna : boolean, optional
        If True, ignore observations with missing data when fitting and
        plotting.
    label : string, optional
        Label that will be used in any plot legends.
    color : matplotlib color, optional
        Color to use for all elements of the plot.
    {scatter, line}_kws : dictionaries, optional
        Additional keyword arguments passed to scatter() and plot() for drawing
        the components of the plot.
    ax : matplotlib axis, optional
        Plot into this axis, otherwise grab the current axis or make a new
        one if not existing.

    Returns
    -------
    ax : matplotlib axes
        Axes with the residual plot.

    See Also
    --------
    regplot : Plot a simple linear regression model.
    jointplot : Draw a :func:`residplot` with univariate marginal distributions
                (when used with ``kind="resid"``).

    Examples
    --------

    .. include:: ../docstrings/residplot.rst

    """
    plotter = _RegressionPlotter(x, y, data, ci=None,
                                 order=order, robust=robust,
                                 x_partial=x_partial, y_partial=y_partial,
                                 dropna=dropna, color=color, label=label)

    if ax is None:
        ax = plt.gca()

    # Calculate the residual from a linear regression
    _, yhat, _ = plotter.fit_regression(grid=plotter.x)
    plotter.y = plotter.y - yhat

    # Set the regression option on the plotter
    if lowess:
        plotter.lowess = True
        # ``order`` only applies to the residual computation above. The
        # plotter tests ``order > 1`` before ``lowess`` when fitting, so
        # leaving it set would draw a polynomial through the residuals
        # instead of the requested lowess smoother.
        plotter.order = 1
    else:
        plotter.fit_reg = False

    # Plot a horizontal line at 0
    ax.axhline(0, ls=":", c=".2")

    # Draw the scatterplot
    scatter_kws = {} if scatter_kws is None else scatter_kws.copy()
    line_kws = {} if line_kws is None else line_kws.copy()
    plotter.plot(ax, scatter_kws, line_kws)
    return ax
|
||||
Reference in New Issue
Block a user