Inverse Modeling of Quadrupole Gradient Errors by Matching the Orbit Response Matrix
This example introduces errors to quadrupole gradient strengths. The goal of the differentiable simulation is to infer these errors by matching the Orbit Response Matrix (ORM) as well as the tunes of the resulting lattice. MADX simulations provide the reference data corresponding to the lattice with errors.
Running the example script will perform the following steps:
1. Define the lattice.
2. Assign a random field error to the third quadrupole of each triplet and write the errors to file = "errors".
3. Compute Twiss of the sequence with errors and write the result to file = "twiss".
Here we assign only one error per triplet because the magnets that form a triplet are located very close together. Neighboring field errors can therefore compensate each other quite effectively, which considerably slows down the convergence of the optimization process. Assigning one error per triplet is equivalent to having only one free variable per triplet (e.g. as if all magnets of a triplet shared the same power supply).
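The compensation effect can be illustrated with a minimal thin-lens sketch. It is not part of the example script: the helper names, the 0.1 m gap and the 0.01 1/m error are assumptions chosen for illustration, and the integrated strengths are only loosely modelled on the K1L values of this lattice. Two neighboring quadrupoles with opposite errors change the combined transfer matrix far less than a single error of the same size, so the data constrain such error pairs only weakly.

```python
# Thin-lens toy model (illustrative assumption, not the dipas lattice).

def matmul(a, b):
    """Product of two 2x2 matrices given as nested tuples."""
    return tuple(
        tuple(sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2))
        for i in range(2)
    )

def thin_quad(k1l):
    """Transfer matrix of a thin quadrupole with integrated strength k1l [1/m]."""
    return ((1.0, 0.0), (-k1l, 1.0))

def drift(length):
    """Transfer matrix of a drift of the given length [m]."""
    return ((1.0, length), (0.0, 1.0))

def doublet(k1l_first, k1l_second, gap):
    """First quadrupole, then a drift of length `gap`, then the second quadrupole."""
    return matmul(thin_quad(k1l_second), matmul(drift(gap), thin_quad(k1l_first)))

def max_deviation(a, b):
    """Largest absolute element-wise difference between two 2x2 matrices."""
    return max(abs(a[i][j] - b[i][j]) for i in range(2) for j in range(2))

K1L_A, K1L_B = 0.5, -0.65  # roughly the magnitude of the K1L values in this lattice
GAP = 0.1                  # hypothetical distance between the two magnets [m]
ERR = 0.01                 # hypothetical integrated gradient error [1/m]

nominal = doublet(K1L_A, K1L_B, GAP)
single = doublet(K1L_A + ERR, K1L_B, GAP)        # error on one magnet only
paired = doublet(K1L_A + ERR, K1L_B - ERR, GAP)  # neighbor carries the opposite error

dev_single = max_deviation(single, nominal)
dev_paired = max_deviation(paired, nominal)
print(f'single error: {dev_single:.2e}, compensating pair: {dev_paired:.2e}')
```

For these numbers the compensating pair perturbs the transfer matrix by roughly an order of magnitude less than the single error, and the remaining perturbation shrinks further as the gap is reduced.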
[1]:
import os.path
from dipas.madx import run_file
result = run_file('example.madx', results=['twiss+meta', 'errors'],
                  madx=os.path.expanduser('~/bin/madx'))
twiss_ref = result['twiss']
errors = result['errors']
twiss_ref[0].set_index('NAME', inplace=True) # [0] is the twiss data, [1] is the meta data ("@"-prefixed in the TFS file)
errors.set_index('NAME', inplace=True)
Let’s check the K1L values and associated errors for all magnets. As mentioned above, only the third magnet in each triplet (*QS3) has been assigned an error:
[2]:
import pandas as pd
k1_values = pd.DataFrame({
    'K1L': twiss_ref[0]['K1L'].loc[errors.index],
    'Errors': errors['K1L'],
})
print(k1_values)
K1L Errors
NAME
YR02QS1 0.508655 0.000000
YR02QS2 -0.651115 0.000000
YR02QS3 0.508655 0.011836
YR04QS1 0.508655 0.000000
YR04QS2 -0.651115 0.000000
YR04QS3 0.508655 0.011628
YR06QS1 0.508655 0.000000
YR06QS2 -0.651115 0.000000
YR06QS3 0.508655 -0.008112
YR08QS1 0.508655 0.000000
YR08QS2 -0.651115 0.000000
YR08QS3 0.508655 -0.003652
YR10QS1 0.508655 0.000000
YR10QS2 -0.651115 0.000000
YR10QS3 0.508655 0.011761
YR12QS1 0.508655 0.000000
YR12QS2 -0.651115 0.000000
YR12QS3 0.508655 0.006400
Now we load the lattice from the MADX file and declare the relevant quadrupole’s k1-errors as optimization parameters in order to infer the actual values:
[3]:
from dipas.build import from_file
from dipas.elements import Quadrupole
import torch
lattice = from_file('example.madx', errors=False) # use `errors=False` to load the nominal optics
for quad in lattice['yr*qs3']:
    quad.dk1 = torch.nn.Parameter(quad.dk1)
    quad.update_transfer_map()  # make changes to `dk1` effective
print('# parameters: ', len(list(lattice.parameters())))
# parameters: 6
With the utility function dipas.madx.run_orm we can have MADX compute the Orbit Response Matrix for the given script file. Here we only consider the vertical component of the ORM. This will serve as the reference data against which the model will be matched.
[4]:
from dipas.elements import VKicker, VMonitor
from dipas.madx import run_orm
kicker_labels = [x.label for x in lattice[VKicker]]
monitor_labels = [x.label for x in lattice[VMonitor]]
orm_ref = run_orm('example.madx',
                  kickers=kicker_labels,
                  monitors=monitor_labels,
                  madx=os.path.expanduser('~/bin/madx'))
orm_ref = orm_ref.loc[:, 'Y'] # only consider the vertical component
print(orm_ref) # rows are kickers, columns are monitors
yr02dx2 yr03dx2 yr03dx3 yr06dx2 yr07dx2 yr08dx2 yr10dx2 \
yr02kv 1.115240 1.983728 1.972641 -2.891744 1.823880 3.705735 -3.214685
yr04kv 0.809330 1.921067 1.954103 1.436141 -2.135845 -2.999211 3.510664
yr07kv -0.233959 -1.277466 -1.348631 2.203000 1.410669 2.204617 -2.361074
yr08kv 3.705474 0.733856 0.197256 1.247780 1.900128 1.201006 1.044550
yr10kv -3.056570 0.464584 0.998505 -3.012275 -0.807507 1.309568 1.132972
yr12kv 1.180532 -1.503768 -1.822981 3.617948 -0.628670 -3.078852 1.449289
yr11dx2 yr12dx2
yr02kv -0.971380 1.443970
yr04kv -0.022372 -2.643008
yr07kv -0.178819 1.586732
yr08kv -2.170915 -2.916487
yr10kv 2.065269 1.186817
yr12kv 1.893791 1.050390
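For orientation, an ORM entry can be related to the Twiss functions through the textbook closed-orbit formula for an uncoupled ring, neglecting any dispersion contribution. The sketch below is an assumption-level illustration: the function name and the numerical values are hypothetical, and it is not claimed to be how MADX or dipas compute the ORM.

```python
import math

def orm_entry(beta_monitor, mu_monitor, beta_kicker, mu_kicker, tune):
    """Closed-orbit response [m/rad] at a monitor to a kick (uncoupled ring, no dispersion term).

    beta_* are beta functions [m]; mu_* are phase advances in units of 2*pi
    (the convention of the MU columns in MADX Twiss tables); `tune` is the
    tune of the plane considered.
    """
    phase_difference = 2 * math.pi * abs(mu_monitor - mu_kicker)
    amplitude = math.sqrt(beta_monitor * beta_kicker) / (2 * math.sin(math.pi * tune))
    return amplitude * math.cos(phase_difference - math.pi * tune)

# Hypothetical Twiss values, for illustration only.
value = orm_entry(beta_monitor=4.0, mu_monitor=0.0, beta_kicker=4.0, mu_kicker=0.0, tune=0.25)
print(value)
```

Because both the beta functions and the phase advances depend on the quadrupole gradients, every ORM entry carries information about the gradient errors, which is what makes the ORM a useful matching target.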
Using dipas.compute.orm we can compute the ORM for the given lattice as a function of the quadrupole gradient errors which we have previously declared as parameters:
[5]:
import dipas.compute as compute
orm_x, orm_y = compute.orm(lattice, kickers=VKicker, monitors=VMonitor)
orm_y = pd.DataFrame(data=orm_y.detach().numpy(), index=orm_ref.index, columns=orm_ref.columns)
print(orm_y)
yr02dx2 yr03dx2 yr03dx3 yr06dx2 yr07dx2 yr08dx2 yr10dx2 \
yr02kv 1.051808 1.992815 1.989703 -2.950579 1.775919 3.689564 -3.108786
yr04kv 0.859713 1.864128 1.881479 1.546831 -2.106485 -3.043805 3.452004
yr07kv -0.295694 -1.264483 -1.323158 2.149441 1.414799 2.254200 -2.355583
yr08kv 3.689564 0.712763 0.171017 1.415990 1.853362 1.051808 1.156167
yr10kv -2.950579 0.484488 1.006893 -3.108786 -0.747051 1.415990 1.051808
yr12kv 1.156167 -1.508437 -1.824636 3.689564 -0.626480 -3.108786 1.415990
yr11dx2 yr12dx2
yr02kv -0.946485 1.415990
yr04kv -0.022300 -2.625416
yr07kv -0.165343 1.614370
yr08kv -2.118423 -2.950579
yr10kv 1.986812 1.156167
yr12kv 1.886741 1.051808
Since the above lattice has no gradient errors so far, the result differs noticeably from the reference ORM. The goal is to adjust the gradient errors so that the two ORMs match.
Similarly we can compute the tunes via dipas.compute.twiss:
[6]:
from dipas.elements import Kicker
twiss = compute.twiss(lattice.makethin({Kicker: 2}, style={Kicker: 'edge'})) # MADX uses 'edge' style
print(f'Tunes: Q1 = {twiss["Q1"]:.3f}, Q2 = {twiss["Q2"]:.3f}')
print(f'Reference: Q1 = {twiss_ref[1]["Q1"]:.3f}, Q2 = {twiss_ref[1]["Q2"]:.3f}')
Tunes: Q1 = 2.420, Q2 = 2.420
Reference: Q1 = 2.439, Q2 = 2.411
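As background, the fractional tune of a periodic cell follows from the trace of its transfer matrix, and no tune exists if the motion is unstable. The sketch below shows this textbook relation for a thin-lens FODO cell; the function names and numbers are hypothetical, and it is not claimed to be how dipas.compute.twiss is implemented.

```python
import math

def fodo_cell_matrix(focal_length, half_cell_length):
    """2x2 matrix of a thin-lens FODO cell: focusing quad, drift, defocusing quad, drift."""
    f, length = focal_length, half_cell_length
    return (
        (1 - length / f - (length / f) ** 2, 2 * length + length ** 2 / f),
        (-length / f ** 2, 1 + length / f),
    )

def fractional_tune(matrix):
    """Phase advance of a periodic cell in units of 2*pi, from the matrix trace.

    Only the value in [0, 0.5] is returned; resolving the sign ambiguity would
    additionally require the sign of the off-diagonal element.
    """
    trace = matrix[0][0] + matrix[1][1]
    if abs(trace) >= 2:
        raise ValueError('unstable cell: |trace| >= 2, no real phase advance')
    return math.acos(trace / 2) / (2 * math.pi)

# Hypothetical cell with 90 degrees phase advance: sin(mu / 2) = length / (2 * f).
tune = fractional_tune(fodo_cell_matrix(focal_length=2 ** -0.5, half_cell_length=1.0))
print(f'{tune:.3f}')
```

The stability condition on the trace is the reason a tune term can only enter a cost function while the lattice remains stable.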
In the following we set up and run the optimization process. For that purpose we need to define an optimizer and compute the necessary quantities during each step of the optimization.
[ ]:
import itertools as it
from dipas.elements import tensor
optimizer = torch.optim.Adam(lattice.parameters(), lr=1.8e-3, betas=(0.51, 0.96))
quadrupoles = lattice['yr*qs3']
Q1 = twiss_ref[1]["Q1"]
Q2 = twiss_ref[1]["Q2"]
orm_ref_y = torch.from_numpy(orm_ref.to_numpy())
cost_history = []
dk1_history = []
for step in it.count(1):
    optimizer.zero_grad()

    orm_y = compute.orm(lattice, kickers=VKicker, monitors=VMonitor)[1]
    cost1 = torch.nn.functional.mse_loss(orm_y, orm_ref_y)
    try:
        twiss = compute.twiss(lattice.makethin({Kicker: 2}, style={Kicker: 'edge'}))
    except compute.UnstableLatticeError:
        cost2 = tensor(0.)  # no tunes available for an unstable lattice; skip the tune term
    else:
        cost2 = (twiss['Q1'] - Q1)**2 + (twiss['Q2'] - Q2)**2
    cost = cost1 + cost2
    cost.backward(retain_graph=True)

    cost_history.append(cost.item())
    dk1_history.append([quad.dk1.item() for quad in quadrupoles])
    print(f'[Step {step:03d}] cost = {cost_history[-1]:.3e}')

    optimizer.step()
    if cost_history[-1] < 1e-12:  # if converged
        break
    for quad in quadrupoles:
        quad.update_transfer_map()  # make changes from `optimizer.step()` effective
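Written out, the cost minimized by the loop above reads as follows. The symbols are introduced here for illustration: \(R_{ij}\) is the vertical ORM entry for kicker \(i\) and monitor \(j\), \(N_k\) and \(N_m\) are the number of kickers and monitors (6 and 9 in this example), and the superscript "ref" marks the MADX reference values. The prefactor reflects the default mean reduction of torch.nn.functional.mse_loss, and the tune terms are dropped for steps in which the lattice is unstable.

\[
C = \frac{1}{N_k N_m} \sum_{i=1}^{N_k} \sum_{j=1}^{N_m} \left( R_{ij} - R_{ij}^{\mathrm{ref}} \right)^2
  + \left( Q_1 - Q_1^{\mathrm{ref}} \right)^2
  + \left( Q_2 - Q_2^{\mathrm{ref}} \right)^2
\]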
We can check the k1-error values during the optimization in order to assess the convergence:
[8]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
dk1_history = np.array(dk1_history)
fig, ax = plt.subplots(figsize=(9.6, 7.2))
ax.set(xlabel='Iteration', ylabel='K1L error [1/m]')
for i, quad in enumerate(quadrupoles):
    ax.plot(dk1_history[:, i]*quad.l.item(), label=quad.label)
    ax.axhline(errors.loc[quad.label.upper(), 'K1L'], lw=0.5, ls='--', color='black', zorder=-100)
ax.legend()
[8]:
<matplotlib.legend.Legend at 0x7f84cd51f110>
The small wiggles towards the end of the yr06qs3 and yr08qs3 lines come from the particular structure of the parameter space close to the target values. The considered quadrupoles can, to some extent, compensate each other’s over- or underestimation of the true parameter values. This creates a region of strong compensation where the considered cost function (ORM + tunes) barely changes, resulting in a very slow, asymptotic convergence, as can be seen from iteration 40 onwards. Perpendicular to that region, however, the cost increases very rapidly, so a small misalignment of the optimizer momentum with respect to that region can lead to a deviation from the “optimal” route, causing transverse oscillations in parameter space which are eventually damped away. Since these oscillations are mostly perpendicular to the direction towards the target, they usually don’t hinder convergence. Nevertheless, the convergence properties largely depend on the optimizer and its settings, so a systematic screening of the available options is recommended.
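The qualitative behavior described above can be reproduced with a toy problem. The sketch below is an illustration under stated assumptions: it uses plain gradient descent with momentum (not Adam) on a hypothetical two-parameter quadratic cost whose curvature is 100 times larger across the valley (coordinate u) than along it (coordinate v). All names and numbers are made up for this purpose.

```python
def run_momentum_descent(steps=200, learning_rate=0.015, momentum=0.5):
    """Gradient descent with momentum on cost(u, v) = 0.5 * (100 * u**2 + v**2)."""
    steep, flat = 100.0, 1.0   # curvature across and along the valley
    u, v = 1.0, 1.0            # start away from the minimum at (0, 0)
    vel_u, vel_v = 0.0, 0.0
    history = [(u, v)]
    for _ in range(steps):
        grad_u, grad_v = steep * u, flat * v
        vel_u = momentum * vel_u - learning_rate * grad_u
        vel_v = momentum * vel_v - learning_rate * grad_v
        u, v = u + vel_u, v + vel_v
        history.append((u, v))
    return history

history = run_momentum_descent()
pairs = list(zip(history, history[1:]))
sign_changes_u = sum(1 for (u0, _), (u1, _) in pairs if u0 * u1 < 0)
sign_changes_v = sum(1 for (_, v0), (_, v1) in pairs if v0 * v1 < 0)
print(f'sign changes across the valley: {sign_changes_u}, along the valley: {sign_changes_v}')
print(f'final |u| = {abs(history[-1][0]):.1e}, final |v| = {abs(history[-1][1]):.1e}')
```

In this toy case the steep coordinate overshoots and oscillates around zero while being damped quickly, whereas the flat coordinate approaches the minimum monotonically but slowly, so the slow direction determines the number of iterations.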
We can also check the cost function during the optimization, which reflects the wiggles observed above. Nevertheless, extrapolating the cost trend line beyond iteration 75 gives approximately the same number of iterations as was actually needed to reach the \(10^{-12}\) MSE level (i.e. the wiggles don’t hinder the convergence process).
[9]:
fig, ax = plt.subplots(figsize=(8, 6))
ax.set(xlabel='Iteration', ylabel='Cost (MSE)')
ax.set_yscale('log')
ax.plot(cost_history)
ax.axhline(1e-12, lw=0.5, ls='--', color='black', zorder=-100)
[9]:
<matplotlib.lines.Line2D at 0x7f84c63a3c10>
Finally we run a crosscheck with MADX, using the derived k1-error values, in order to confirm that the ORM which MADX computes for the inferred errors does indeed match the reference ORM:
[10]:
from dipas.build import create_script
script = create_script(
    beam=dict(particle='proton', energy=1),
    sequence=lattice,
    errors=True)
with open('crosscheck_orm.madx', 'w') as fh:
    fh.write(script)
orm_cc = run_orm('crosscheck_orm.madx',
                 kickers=kicker_labels,
                 monitors=monitor_labels,
                 madx=os.path.expanduser('~/bin/madx'))
orm_cc = orm_cc.loc[:, 'Y']
print('Deviation between computed and crosscheck ORM:', end='\n\n')
print(orm_cc - orm_ref)
Deviation between computed and crosscheck ORM:
yr02dx2 yr03dx2 yr03dx3 yr06dx2 yr07dx2 \
yr02kv -6.660000e-06 -2.938000e-06 -2.112000e-06 6.800000e-07 -1.357000e-06
yr04kv 4.068900e-06 1.354000e-06 8.110000e-07 -4.800000e-08 1.835000e-06
yr07kv -1.316000e-06 3.540000e-07 5.970000e-07 -1.382000e-06 -2.987000e-06
yr08kv -9.660000e-07 7.280000e-07 9.459000e-07 -1.028000e-06 -3.304000e-06
yr10kv -4.969000e-06 -3.296800e-06 -2.774300e-06 1.558000e-06 1.244100e-06
yr12kv 7.332000e-06 3.797000e-06 2.936000e-06 -1.186000e-06 1.344000e-07
yr08dx2 yr10dx2 yr11dx2 yr12dx2
yr02kv -9.540000e-07 -4.630000e-06 3.543900e-06 7.109000e-06
yr04kv 1.531000e-06 1.902000e-06 -2.322250e-06 -3.865000e-06
yr07kv -2.521000e-06 9.580000e-07 9.788000e-07 3.590000e-07
yr08kv -3.103000e-06 1.396000e-06 7.440000e-07 -3.190000e-07
yr10kv 1.284000e-06 -5.831000e-06 2.325000e-06 6.553000e-06
yr12kv -2.680000e-07 6.340000e-06 -3.647000e-06 -8.422000e-06