Skip to content

Commit

Permalink
Fix the energy loss computation (#877)
Browse files Browse the repository at this point in the history
* keep only element-by-element tracking in get_energy_lss
* update tests
* Update Matlab tests
  • Loading branch information
lfarv authored Dec 18, 2024
1 parent b90e44d commit 3b5566e
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 580 deletions.
21 changes: 8 additions & 13 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,19 @@ jobs:

strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12', '3.13']
os: [macos-latest, ubuntu-latest, windows-latest]
exclude:
- os: windows-latest
python-version: '3.7'
- os: macos-latest
python-version: '3.7'
- os: macos-latest
python-version: '3.8'
- os: macos-latest
python-version: '3.9'
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
os: [macos-13, macos-latest, ubuntu-latest, windows-latest]
include:
- os: macos-13
python-version: '3.7'
- os: macos-13
python-version: '3.8'
- os: macos-13
python-version: '3.9'
- os: ubuntu-22.04
python-version: '3.7'
- os: ubuntu-22.04
python-version: '3.8'
- os: windows-latest
python-version: '3.8'


steps:
Expand Down
14 changes: 8 additions & 6 deletions atmat/atphysics/Radiation/atgetU0.m
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,14 @@
function U0=tracking(ring)
% Ensure 6d is enabled
check_6d(ring,true);
% Turn cavities off
ringtmp=atdisable_6d(ring,'allpass','','cavipass','auto',...
'quantdiffpass','auto','simplequantdiffpass','auto');
o0=zeros(6,1);
o6=ringpass(ringtmp,o0);
U0=-o6(5)*energy;
radiating=atgetcells(ring,'PassMethod','*RadPass');
sumd=sum(cellfun(@comp, ring(radiating)));
U0=-sumd*energy;

function delta = comp(elem)
rout=elempass(elem,zeros(6,1),'Energy',energy);
delta=rout(5);
end
end

end
4 changes: 2 additions & 2 deletions atmat/attests/pytests.m
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ function tunechrom6(testCase,lat2,dp)
[mtune,mchrom]=tunechrom(mlat,'get_chrom');
ptune=double(plat.get_tune());
pchrom=double(plat.get_chrom());
testCase.verifyEqual(mod(mtune*periodicity,1),ptune,AbsTol=1.e-9);
testCase.verifyEqual(mchrom*periodicity,pchrom,RelTol=1.e-4,AbsTol=3.e-4);
testCase.verifyEqual(mod(mtune*periodicity,1),ptune,AbsTol=2.5e-9);
testCase.verifyEqual(mchrom*periodicity,pchrom,RelTol=3.e-4,AbsTol=2.e-4);
end

function linopt1(testCase,dp)
Expand Down
184 changes: 100 additions & 84 deletions pyat/at/physics/energy_loss.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
from __future__ import annotations

__all__ = ["get_energy_loss", "set_cavity_phase", "ELossMethod", "get_timelag_fromU0"]

from enum import Enum
from warnings import warn
from math import pi
from typing import Optional, Tuple
import numpy
from collections.abc import Sequence

import numpy as np
from scipy.optimize import least_squares

from at.constants import clight, Cgamma
from at.lattice import Lattice, Dipole, Wiggler, RFCavity, Refpts, EnergyLoss
from at.lattice import check_radiation, AtError, AtWarning
from at.lattice import QuantumDiffusion, Collective, SimpleQuantDiff
from at.lattice import get_bool_index, set_value_refpts
from at.constants import clight, Cgamma
from at.tracking import internal_lpass

__all__ = ['get_energy_loss', 'set_cavity_phase', 'ELossMethod',
'get_timelag_fromU0']


class ELossMethod(Enum):
"""methods for the computation of energy losses"""

#: The losses are obtained from
#: :math:`E_{loss}=C_\gamma/2\pi . E^4 . I_2`.
#: Takes into account bending magnets and wigglers.
Expand All @@ -26,9 +27,9 @@ class ELossMethod(Enum):
TRACKING = 2


def get_energy_loss(ring: Lattice,
method: Optional[ELossMethod] = ELossMethod.INTEGRAL
) -> float:
def get_energy_loss(
ring: Lattice, method: ELossMethod | None = ELossMethod.INTEGRAL
) -> float:
"""Computes the energy loss per turn
Parameters:
Expand All @@ -42,153 +43,168 @@ def get_energy_loss(ring: Lattice,

# noinspection PyShadowingNames
def integral(ring):
"""Losses = Cgamma / 2pi * EGeV^4 * i2
"""
"""Losses = Cgamma / 2pi * EGeV^4 * i2"""

def wiggler_i2(wiggler: Wiggler):
rhoinv = wiggler.Bmax / ring.BRho
coefh = wiggler.By[1, :]
coefv = wiggler.Bx[1, :]
return wiggler.Length * (numpy.sum(coefh * coefh) + numpy.sum(
coefv*coefv)) * rhoinv ** 2 / 2
return (
wiggler.Length
* (np.sum(coefh * coefh) + np.sum(coefv * coefv))
* rhoinv**2
/ 2
)

def dipole_i2(dipole: Dipole):
return dipole.BendingAngle ** 2 / dipole.Length
return dipole.BendingAngle**2 / dipole.Length

def eloss_i2(eloss: EnergyLoss):
return eloss.EnergyLoss / coef

i2 = 0.0
coef = Cgamma / 2.0 / pi * ring.energy ** 4
coef = Cgamma / 2.0 / np.pi * ring.energy**4
for el in ring:
if isinstance(el, Dipole):
i2 += dipole_i2(el)
elif isinstance(el, Wiggler) and el.PassMethod != 'DriftPass':
elif isinstance(el, Wiggler) and el.PassMethod != "DriftPass":
i2 += wiggler_i2(el)
elif isinstance(el, EnergyLoss) and el.PassMethod != 'IdentityPass':
elif isinstance(el, EnergyLoss) and el.PassMethod != "IdentityPass":
i2 += eloss_i2(el)
e_loss = coef * i2
return e_loss

# noinspection PyShadowingNames
@check_radiation(True)
def tracking(ring):
"""Losses from tracking
"""
ringtmp = ring.disable_6d(RFCavity, QuantumDiffusion, Collective,
SimpleQuantDiff, copy=True)
o6 = numpy.squeeze(internal_lpass(ringtmp, numpy.zeros(6),
refpts=len(ringtmp)))
if numpy.isnan(o6[0]):
dp = 0
for e in ringtmp:
ot = numpy.squeeze(internal_lpass([e], numpy.zeros(6)))
dp += -ot[4] * ring.energy
return dp
else:
return -o6[4] * ring.energy
"""Losses from tracking"""
energy = ring.energy
particle = ring.particle
delta = 0.0
for e in ring:
if e.PassMethod.endswith("RadPass"):
ot = e.track(np.zeros(6), energy=energy, particle=particle)
delta += ot[4]
return -delta * energy

if isinstance(method, str):
method = ELossMethod[method.upper()]
warn(FutureWarning('You should use {0!s}'.format(method)))
warn(FutureWarning(f"You should use {method!s}"), stacklevel=2)
if method is ELossMethod.INTEGRAL:
return ring.periodicity * integral(ring)
elif method == ELossMethod.TRACKING:
return ring.periodicity * tracking(ring)
else:
raise AtError('Invalid method: {}'.format(method))
raise AtError(f"Invalid method: {method}")


# noinspection PyPep8Naming
def get_timelag_fromU0(ring: Lattice,
method: Optional[ELossMethod] = ELossMethod.TRACKING,
cavpts: Optional[Refpts] = None,
divider: Optional[int] = 4,
ts_tol: Optional[float] = 1.0e-9) -> Tuple[float, float]:
def get_timelag_fromU0(
ring: Lattice,
*,
method: ELossMethod | None = ELossMethod.TRACKING,
cavpts: Refpts | None = None,
divider: int | None = 4,
ts_tol: float | None = 1.0e-9,
) -> tuple[Sequence[float], float]:
"""
Get the TimeLag attribute of RF cavities based on frequency,
voltage and energy loss per turn, so that the synchronous phase is zero.
An error occurs if all cavities do not have the same frequency.
Used in set_cavity_phase()
Parameters:
ring: Lattice description
method: Method for energy loss computation.
See :py:class:`ELossMethod`.
cavpts: Cavity location. If None, use all cavities.
This allows to ignore harmonic cavities.
divider: number of segments to search for ts
phis_tol: relative tolerance for ts calculation
ts_tol: relative tolerance for ts calculation
Returns:
timelag (float): Timelag
timelag (float): (ncav,) array of *Timelag* values
ts (float): Time difference with the present value
"""

def singlev(values):
vals = numpy.unique(values)
vals = np.unique(values)
if len(vals) > 1:
raise AtError('values not equal for all cavities')
raise AtError("values not equal for all cavities")
return vals[0]

def eq(x, freq, rfv, tl0, u0):
omf = 2*numpy.pi*freq/clight
omf = 2 * np.pi * freq / clight
if u0 > 0.0:
eq1 = (numpy.sum(-rfv*numpy.sin(omf*(x-tl0)))-u0)/u0
eq1 = (np.sum(-rfv * np.sin(omf * (x - tl0))) - u0) / u0
else:
eq1 = numpy.sum(-rfv * numpy.sin(omf * (x - tl0)))
eq2 = numpy.sum(-omf*rfv*numpy.cos(omf*(x-tl0)))
eq1 = np.sum(-rfv * np.sin(omf * (x - tl0)))
eq2 = np.sum(-omf * rfv * np.cos(omf * (x - tl0)))
if eq2 > 0:
return numpy.sqrt(eq1**2+eq2**2)
return np.sqrt(eq1**2 + eq2**2)
else:
return abs(eq1)

if cavpts is None:
cavpts = get_bool_index(ring, RFCavity)
u0 = get_energy_loss(ring, method=method) / ring.periodicity
freq = numpy.array([cav.Frequency for cav in ring.select(cavpts)])
rfv = numpy.array([cav.Voltage for cav in ring.select(cavpts)])
tl0 = numpy.array([cav.TimeLag for cav in ring.select(cavpts)])
freq = np.array([cav.Frequency for cav in ring.select(cavpts)])
rfv = np.array([cav.Voltage for cav in ring.select(cavpts)])
tl0 = np.array([cav.TimeLag for cav in ring.select(cavpts)])
try:
frf = singlev(freq)
tml = singlev(tl0)
except AtError:
ctmax = clight/numpy.amin(freq)/2
tt0 = tl0[numpy.argmin(freq)]
ctmax = clight / np.amin(freq) / 2
tt0 = tl0[np.argmin(freq)]
bounds = (-ctmax, ctmax)
args = (freq, rfv, tl0, u0)
r = []
for i in range(divider):
fact = (i+1)/divider
r.append(least_squares(eq, bounds[0]*fact+tt0,
args=args, bounds=bounds+tt0))
r.append(least_squares(eq, bounds[1]*fact+tt0,
args=args, bounds=bounds+tt0))
res = numpy.array([ri.fun[0] for ri in r])
fact = (i + 1) / divider
r.append(
least_squares(
eq, bounds[0] * fact + tt0, args=args, bounds=bounds + tt0
)
)
r.append(
least_squares(
eq, bounds[1] * fact + tt0, args=args, bounds=bounds + tt0
)
)
res = np.array([ri.fun[0] for ri in r])
ok = res < ts_tol
vals = numpy.array([abs(ri.x[0]).round(decimals=6) for ri in r])
if not numpy.any(ok):
raise AtError('No solution found for Phis, please check '
'RF settings')
if len(numpy.unique(vals[ok])) > 1:
warn(AtWarning('More than one solution found for Phis: use '
'best fit, please check RF settings'))
ts = -r[numpy.argmin(res)].x[0]
timelag = ts+tl0
vals = np.array([abs(ri.x[0]).round(decimals=6) for ri in r])
if not np.any(ok):
raise AtError("No solution found for Phis: check RF settings") from None
if len(np.unique(vals[ok])) > 1:
warn(
AtWarning("More than one solution found for Phis: check RF settings"),
stacklevel=2,
)
ts = -r[np.argmin(res)].x[0]
timelag = ts + tl0
else:
if u0 > numpy.sum(rfv):
raise AtError('Not enough RF voltage: unstable ring')
vrf = numpy.sum(rfv)
timelag = clight/(2*numpy.pi*frf)*numpy.arcsin(u0/vrf)
vrf = np.sum(rfv)
if u0 > vrf:
v1 = ring.periodicity * vrf
v2 = ring.periodicity * u0
raise AtError(
f"The RF voltage ({v1:.3e} eV) is lower than "
f"the radiation losses ({v2:.3e} eV)."
)
timelag = clight / (2 * np.pi * frf) * np.arcsin(u0 / vrf)
ts = timelag - tml
timelag *= numpy.ones(ring.refcount(cavpts))
timelag *= np.ones(ring.refcount(cavpts))
return timelag, ts


def set_cavity_phase(ring: Lattice,
method: ELossMethod = ELossMethod.TRACKING,
refpts: Optional[Refpts] = None,
cavpts: Optional[Refpts] = None,
copy: bool = False) -> None:
def set_cavity_phase(
ring: Lattice,
*,
method: ELossMethod = ELossMethod.TRACKING,
refpts: Refpts | None = None,
cavpts: Refpts | None = None,
copy: bool = False,
) -> None:
"""
Adjust the TimeLag attribute of RF cavities based on frequency,
voltage and energy loss per turn, so that the synchronous phase is zero.
Expand All @@ -209,12 +225,12 @@ def set_cavity_phase(ring: Lattice,
"""
# refpts is kept for backward compatibility
if cavpts is None and refpts is not None:
warn(FutureWarning('You should use "cavpts" instead of "refpts"'))
warn(FutureWarning('You should use "cavpts" instead of "refpts"'), stacklevel=2)
cavpts = refpts
elif cavpts is None:
cavpts = get_bool_index(ring, RFCavity)
timelag, _ = get_timelag_fromU0(ring, method=method, cavpts=cavpts)
set_value_refpts(ring, cavpts, 'TimeLag', timelag, copy=copy)
set_value_refpts(ring, cavpts, "TimeLag", timelag, copy=copy)


Lattice.get_energy_loss = get_energy_loss
Expand Down
9 changes: 4 additions & 5 deletions pyat/at/physics/orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,11 @@ def _orbit6(ring: Lattice, cavpts=None, guess=None, keep_lattice=False,
harm_number = round(f_rf*l0/ring.beta/clight)

if guess is None:
_, dt = get_timelag_fromU0(ring, method=method, cavpts=cavpts)
# Getting timelag by tracking uses a different lattice,
# so we cannot now use the same one again.
if method is ELossMethod.TRACKING:
keep_lattice = False
ref_in = numpy.zeros((6,), order='F')
try:
_, dt = get_timelag_fromU0(ring, method=method, cavpts=cavpts)
except AtError as exc:
raise AtError("Could not determine the initial synchronous phase") from exc
ref_in[5] = -dt
else:
ref_in = numpy.copy(guess)
Expand Down
Loading

0 comments on commit 3b5566e

Please sign in to comment.