venv
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -18,6 +18,7 @@ from warnings import WarningMessage
|
||||
import pprint
|
||||
import sysconfig
|
||||
import concurrent.futures
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
from numpy._core import (
|
||||
@@ -2684,12 +2685,27 @@ _glibcver = _get_glibc_version()
|
||||
_glibc_older_than = lambda x: (_glibcver != '0.0' and _glibcver < x)
|
||||
|
||||
|
||||
def run_threaded(func, iters, pass_count=False):
|
||||
def run_threaded(func, iters=8, pass_count=False, max_workers=8,
|
||||
pass_barrier=False, outer_iterations=1,
|
||||
prepare_args=None):
|
||||
"""Runs a function many times in parallel"""
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
|
||||
if pass_count:
|
||||
futures = [tpe.submit(func, i) for i in range(iters)]
|
||||
else:
|
||||
futures = [tpe.submit(func) for _ in range(iters)]
|
||||
for f in futures:
|
||||
f.result()
|
||||
for _ in range(outer_iterations):
|
||||
with (concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
|
||||
as tpe):
|
||||
if prepare_args is None:
|
||||
args = []
|
||||
else:
|
||||
args = prepare_args()
|
||||
if pass_barrier:
|
||||
if max_workers != iters:
|
||||
raise RuntimeError(
|
||||
"Must set max_workers equal to the number of "
|
||||
"iterations to avoid deadlocks.")
|
||||
barrier = threading.Barrier(max_workers)
|
||||
args.append(barrier)
|
||||
if pass_count:
|
||||
futures = [tpe.submit(func, i, *args) for i in range(iters)]
|
||||
else:
|
||||
futures = [tpe.submit(func, *args) for _ in range(iters)]
|
||||
for f in futures:
|
||||
f.result()
|
||||
|
||||
Reference in New Issue
Block a user