diff --git a/pyat/at/latticetools/response_matrix.py b/pyat/at/latticetools/response_matrix.py index 28f12c3a3..d6fbaa186 100644 --- a/pyat/at/latticetools/response_matrix.py +++ b/pyat/at/latticetools/response_matrix.py @@ -101,19 +101,18 @@ from __future__ import annotations __all__ = [ - "split", + "sequence_split", "ResponseMatrix", "OrbitResponseMatrix", "TrajectoryResponseMatrix", ] import os -import math import copy import multiprocessing import concurrent.futures import abc -from collections.abc import Sequence +from collections.abc import Sequence, Generator from itertools import chain from functools import partial @@ -121,7 +120,7 @@ from .observables import TrajectoryObservable, OrbitObservable, LatticeObservable from .observablelist import ObservableList -from ..lattice import AtError, Lattice, Refpts, Orbit, AxisDef, plane_ +from ..lattice import AtError, Lattice, Refpts, AxisDef, plane_ from ..lattice import Monitor, checkattr from ..lattice.lattice_variables import RefptsVariable from ..lattice.variables import VariableList @@ -132,21 +131,32 @@ _globobs: ObservableList | None = None -def split(ary: Sequence, nslices): +def sequence_split(seq: Sequence, nslices: int) -> Generator[Sequence, None, None]: + """Split a sequence into multiple sub-sequences. - def _spl(lll): + The length of *seq* does not have to be a multiple of *nslices*. + + Args: + seq: sequence to split + nslices: number of sub-sequences + + Returns: + subseqs: Iterator over sub-sequences + """ + + def _split(seqsizes): beg = 0 - for k in lll: - end = beg + k - yield ary[beg:end] + for size in seqsizes: + end = beg + size + yield seq[beg:end] beg = end - lna = len(ary) - sz = math.trunc(lna/nslices) - lsl = [sz] * nslices - for k in range(lna - sz * nslices): - lsl[k] += 1 - return list(_spl(lsl)) + lna = len(seq) + sz, rem = divmod(lna, nslices) + lsubseqs = [sz] * nslices + for k in range(rem): + lsubseqs[k] += 1 + return _split(lsubseqs) def _resp( @@ -431,7 +441,7 @@ def build( ctx = multiprocessing.get_context(start_method) if pool_size is None: pool_size = min(len(self.variables), os.cpu_count()) - obschunks = np.array_split(self.variables, pool_size) + obschunks = sequence_split(self.variables, pool_size) if ctx.get_start_method() == "fork": _globring = self.ring _globobs = self.observables