library packages
This commit is contained in:
@@ -0,0 +1,85 @@
|
||||
"""global_random_seed fixture
|
||||
|
||||
The goal of this fixture is to prevent tests that use it to be sensitive
|
||||
to a specific seed value while still being deterministic by default.
|
||||
|
||||
See the documentation for the SKLEARN_TESTS_GLOBAL_RANDOM_SEED
|
||||
variable for insrtuctions on how to use this fixture.
|
||||
|
||||
https://scikit-learn.org/dev/computing/parallelism.html#sklearn-tests-global-random-seed
|
||||
"""
|
||||
|
||||
from os import environ
|
||||
from random import Random
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# Passes the main worker's random seeds to workers
|
||||
class XDistHooks:
|
||||
def pytest_configure_node(self, node) -> None:
|
||||
random_seeds = node.config.getoption("random_seeds")
|
||||
node.workerinput["random_seeds"] = random_seeds
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
if config.pluginmanager.hasplugin("xdist"):
|
||||
config.pluginmanager.register(XDistHooks())
|
||||
|
||||
RANDOM_SEED_RANGE = list(range(100)) # All seeds in [0, 99] should be valid.
|
||||
random_seed_var = environ.get("SKLEARN_TESTS_GLOBAL_RANDOM_SEED")
|
||||
if hasattr(config, "workerinput") and "random_seeds" in config.workerinput:
|
||||
# Set worker random seed from seed generated from main process
|
||||
random_seeds = config.workerinput["random_seeds"]
|
||||
elif random_seed_var is None:
|
||||
# This is the way.
|
||||
random_seeds = [42]
|
||||
elif random_seed_var == "any":
|
||||
# Pick-up one seed at random in the range of admissible random seeds.
|
||||
random_seeds = [Random().choice(RANDOM_SEED_RANGE)]
|
||||
elif random_seed_var == "all":
|
||||
random_seeds = RANDOM_SEED_RANGE
|
||||
else:
|
||||
if "-" in random_seed_var:
|
||||
start, stop = random_seed_var.split("-")
|
||||
random_seeds = list(range(int(start), int(stop) + 1))
|
||||
else:
|
||||
random_seeds = [int(random_seed_var)]
|
||||
|
||||
if min(random_seeds) < 0 or max(random_seeds) > 99:
|
||||
raise ValueError(
|
||||
"The value(s) of the environment variable "
|
||||
"SKLEARN_TESTS_GLOBAL_RANDOM_SEED must be in the range [0, 99] "
|
||||
f"(or 'any' or 'all'), got: {random_seed_var}"
|
||||
)
|
||||
config.option.random_seeds = random_seeds
|
||||
|
||||
class GlobalRandomSeedPlugin:
|
||||
@pytest.fixture(params=random_seeds)
|
||||
def global_random_seed(self, request):
|
||||
"""Fixture to ask for a random yet controllable random seed.
|
||||
|
||||
All tests that use this fixture accept the contract that they should
|
||||
deterministically pass for any seed value from 0 to 99 included.
|
||||
|
||||
See the documentation for the SKLEARN_TESTS_GLOBAL_RANDOM_SEED
|
||||
variable for insrtuctions on how to use this fixture.
|
||||
|
||||
https://scikit-learn.org/dev/computing/parallelism.html#sklearn-tests-global-random-seed
|
||||
"""
|
||||
yield request.param
|
||||
|
||||
config.pluginmanager.register(GlobalRandomSeedPlugin())
|
||||
|
||||
|
||||
def pytest_report_header(config):
|
||||
random_seed_var = environ.get("SKLEARN_TESTS_GLOBAL_RANDOM_SEED")
|
||||
if random_seed_var == "any":
|
||||
return [
|
||||
"To reproduce this test run, set the following environment variable:",
|
||||
f' SKLEARN_TESTS_GLOBAL_RANDOM_SEED="{config.option.random_seeds[0]}"',
|
||||
(
|
||||
"See: https://scikit-learn.org/dev/computing/parallelism.html"
|
||||
"#sklearn-tests-global-random-seed"
|
||||
),
|
||||
]
|
||||
Reference in New Issue
Block a user