This commit is contained in:
2025-01-26 19:24:23 -08:00
parent 32cd60e92b
commit d1dde0dbc6
4155 changed files with 29170 additions and 216373 deletions

View File

@@ -18,6 +18,7 @@ from warnings import WarningMessage
import pprint
import sysconfig
import concurrent.futures
import threading
import numpy as np
from numpy._core import (
@@ -2684,12 +2685,27 @@ _glibcver = _get_glibc_version()
_glibc_older_than = lambda x: (_glibcver != '0.0' and _glibcver < x)
def run_threaded(func, iters, pass_count=False):
def run_threaded(func, iters=8, pass_count=False, max_workers=8,
pass_barrier=False, outer_iterations=1,
prepare_args=None):
"""Runs a function many times in parallel"""
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
if pass_count:
futures = [tpe.submit(func, i) for i in range(iters)]
else:
futures = [tpe.submit(func) for _ in range(iters)]
for f in futures:
f.result()
for _ in range(outer_iterations):
with (concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
as tpe):
if prepare_args is None:
args = []
else:
args = prepare_args()
if pass_barrier:
if max_workers != iters:
raise RuntimeError(
"Must set max_workers equal to the number of "
"iterations to avoid deadlocks.")
barrier = threading.Barrier(max_workers)
args.append(barrier)
if pass_count:
futures = [tpe.submit(func, i, *args) for i in range(iters)]
else:
futures = [tpe.submit(func, *args) for _ in range(iters)]
for f in futures:
f.result()