Skip to content

Commit

Permalink
added sequence_split
Browse files Browse the repository at this point in the history
  • Loading branch information
lfarv committed Dec 22, 2024
1 parent f832ff7 commit a91db22
Showing 1 changed file with 26 additions and 16 deletions.
42 changes: 26 additions & 16 deletions pyat/at/latticetools/response_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,27 +101,26 @@
from __future__ import annotations

__all__ = [
"split",
"sequence_split",
"ResponseMatrix",
"OrbitResponseMatrix",
"TrajectoryResponseMatrix",
]

import os
import math
import copy
import multiprocessing
import concurrent.futures
import abc
from collections.abc import Sequence
from collections.abc import Sequence, Generator
from itertools import chain
from functools import partial

import numpy as np

from .observables import TrajectoryObservable, OrbitObservable, LatticeObservable
from .observablelist import ObservableList
from ..lattice import AtError, Lattice, Refpts, Orbit, AxisDef, plane_
from ..lattice import AtError, Lattice, Refpts, AxisDef, plane_
from ..lattice import Monitor, checkattr
from ..lattice.lattice_variables import RefptsVariable
from ..lattice.variables import VariableList
Expand All @@ -132,21 +131,32 @@
_globobs: ObservableList | None = None


def split(ary: Sequence, nslices):
def sequence_split(seq: Sequence, nslices: int) -> Generator[Sequence, None, None]:
"""Split a sequence into multiple sub-sequences.
def _spl(lll):
The length of *seq* does not have to be a multiple of *nslices*.
Args:
seq: sequence to split
nslices: number of sub-sequences
Returns:
subseqs: Iterator over sub-sequences
"""

def _split(seqsizes):
beg = 0
for k in lll:
end = beg + k
yield ary[beg:end]
for size in seqsizes:
end = beg + size
yield seq[beg:end]
beg = end

lna = len(ary)
sz = math.trunc(lna/nslices)
lsl = [sz] * nslices
for k in range(lna - sz * nslices):
lsl[k] += 1
return list(_spl(lsl))
lna = len(seq)
sz, rem = divmod(lna, nslices)
lsubseqs = [sz] * nslices
for k in range(rem):
lsubseqs[k] += 1
return _split(lsubseqs)


def _resp(
Expand Down Expand Up @@ -431,7 +441,7 @@ def build(
ctx = multiprocessing.get_context(start_method)
if pool_size is None:
pool_size = min(len(self.variables), os.cpu_count())
obschunks = np.array_split(self.variables, pool_size)
obschunks = sequence_split(self.variables, pool_size)
if ctx.get_start_method() == "fork":
_globring = self.ring
_globobs = self.observables
Expand Down

0 comments on commit a91db22

Please sign in to comment.