"""
Utility classes and functions for ``SpectralTraceList``.
This module contains
- the definition of the `SpectralTrace` class. The visible effect should
always be a `SpectralTraceList`, even if that contains only one
`SpectralTrace`.
- the definition of the `XiLamImage` class
- utility functions for use with spectral traces
"""
import warnings
import numpy as np
from scipy.interpolate import RectBivariateSpline
from scipy.interpolate import interp1d
from astropy.table import Table, vstack
from astropy.modeling import fitting
from astropy.io import fits
from astropy import units as u
from astropy.wcs import WCS
from astropy.modeling.models import Polynomial2D
from ..utils import (power_vector, quantify, from_currsys, close_loop,
figure_factory, get_logger)
logger = get_logger(__name__)
[docs]class SpectralTrace:
"""Definition of one spectral trace.
A SpectralTrace describes the mapping of spectral slit coordinates
to the focal plane. The class reads an order layout and fits several
functions to describe the geometry of the trace.
Slit coordinates are:
- xi : spatial position along the slit [arcsec]
- lam : Wavelength [um]
Focal plane coordinates are:
- x, y : [mm]
"""
_class_params = {
"x_colname": "x",
"y_colname": "y",
"s_colname": "s",
"wave_colname": "wavelength",
"dwave": 0.002,
"aperture_id": 0,
"image_plane_id": 0,
"extension_id": 2,
"spline_order": 4,
"pixel_size": None,
"description": "<no description>",
}
def __init__(self, trace_tbl, cmds=None, **kwargs):
# Within scopesim, the actual parameter values are
# passed as kwargs from the SpectralTraceList.
# The values need to be here for stand-alone use.
self.meta = {}
self.meta.update(self._class_params)
self.meta.update(kwargs)
self.cmds = cmds
if isinstance(trace_tbl, (fits.BinTableHDU, fits.TableHDU)):
self.table = Table.read(trace_tbl)
self.meta["trace_id"] = trace_tbl.header.get("EXTNAME",
"<unknown trace id>")
self.dispersion_axis = trace_tbl.header.get("DISPDIR", "unknown")
elif isinstance(trace_tbl, Table):
self.table = trace_tbl
self.dispersion_axis = "unknown"
else:
raise ValueError("trace_tbl must be one of (fits.BinTableHDU, "
"fits.TableHDU, astropy.Table) but is "
f"{type(trace_tbl)}")
self.compute_interpolation_functions()
# Declaration of other attributes
self._xilamimg = None
self.dlam_per_pix = None
@property
def trace_id(self):
"""Return the name of the trace."""
return self.meta["trace_id"]
[docs] def fov_grid(self):
"""
Provide information on the source space volume required by the effect.
Returns
-------
A dictionary with entries `wave_min` and `wave_max`.
Spatial limits are determined by the `ApertureMask` effect
and are not returned here.
"""
warnings.warn("The fov_grid method is deprecated and will be removed "
"in a future release.", DeprecationWarning, stacklevel=2)
aperture_id = self.meta["aperture_id"]
lam_arr = self.table[self.meta["wave_colname"]]
wave_max = np.max(lam_arr)
wave_min = np.min(lam_arr)
return {"wave_min": wave_min, "wave_max": wave_max,
"trace_id": self.trace_id, "aperture_id": aperture_id}
[docs] def compute_interpolation_functions(self):
"""
Compute various interpolation functions between slit and focal plane.
Focal plane coordinates are `x` and `y`, in mm. Slit coordinates are
`xi` (spatial coordinate along the slit, in arcsec) and `lam`
(wavelength, in um).
"""
x_arr = self.table[self.meta["x_colname"]]
y_arr = self.table[self.meta["y_colname"]]
xi_arr = self.table[self.meta["s_colname"]]
lam_arr = self.table[self.meta["wave_colname"]]
self.wave_min = quantify(np.min(lam_arr), u.um).value
self.wave_max = quantify(np.max(lam_arr), u.um).value
self.xy2xi = Transform2D.fit(x_arr, y_arr, xi_arr)
self.xy2lam = Transform2D.fit(x_arr, y_arr, lam_arr)
self.xilam2x = Transform2D.fit(xi_arr, lam_arr, x_arr)
self.xilam2y = Transform2D.fit(xi_arr, lam_arr, y_arr)
self._xiy2x = Transform2D.fit(xi_arr, y_arr, x_arr)
self._xiy2lam = Transform2D.fit(xi_arr, y_arr, lam_arr)
if self.dispersion_axis == "unknown":
dlam_dx, dlam_dy = self.xy2lam.gradient()
wave_mid = 0.5 * (self.wave_min + self.wave_max)
xi_mid = np.mean(xi_arr)
x_mid = self.xilam2x(xi_mid, wave_mid)
y_mid = self.xilam2y(xi_mid, wave_mid)
if dlam_dx(x_mid, y_mid) > dlam_dy(x_mid, y_mid):
self.dispersion_axis = "x"
else:
self.dispersion_axis = "y"
logger.info(
"Dispersion axis determined to be %s", self.dispersion_axis)
[docs] def map_spectra_to_focal_plane(self, fov):
"""
Apply the spectral trace mapping to a spectral cube.
The cube is contained in a FieldOfView object, which also has
world coordinate systems for the Source (sky coordinates and
wavelengths) and for the focal plane.
The method returns a section of the fov image along with info on
where this image lies in the focal plane.
"""
logger.info("Mapping %s", fov.trace_id)
# Initialise the image based on the footprint of the spectral
# trace and the focal plane WCS
wave_min = fov.meta["wave_min"].value # [um]
wave_max = fov.meta["wave_max"].value # [um]
xi_min = fov.meta["xi_min"].value # [arcsec]
xi_max = fov.meta["xi_max"].value # [arcsec]
xlim_mm, ylim_mm = self.footprint(wave_min=wave_min, wave_max=wave_max,
xi_min=xi_min, xi_max=xi_max)
if xlim_mm is None:
logger.warning("xlim_mm is None")
return None
fov_header = fov.header
det_header = fov.detector_header
# WCSD from the FieldOfView - this is the full detector plane
pixsize = fov_header["CDELT1D"] * u.Unit(fov_header["CUNIT1D"])
pixsize = pixsize.to(u.mm).value
pixscale = fov_header["CDELT1"] * u.Unit(fov_header["CUNIT1"])
pixscale = pixscale.to(u.arcsec).value
fpa_wcsd = WCS(det_header, key="D")
naxis1d, naxis2d = det_header["NAXIS1"], det_header["NAXIS2"]
xlim_px, ylim_px = fpa_wcsd.all_world2pix(xlim_mm, ylim_mm, 0)
xmin = np.floor(xlim_px.min()).astype(int)
xmax = np.ceil(xlim_px.max()).astype(int)
ymin = np.floor(ylim_px.min()).astype(int)
ymax = np.ceil(ylim_px.max()).astype(int)
# Check if spectral trace footprint is outside FoV
if xmax < 0 or xmin > naxis1d or ymax < 0 or ymin > naxis2d:
logger.info(
"Spectral trace %d: footprint is outside FoV", fov.trace_id)
return None
# Only work on parts within the FoV
xmin = max(xmin, 0)
xmax = min(xmax, naxis1d)
ymin = max(ymin, 0)
ymax = min(ymax, naxis2d)
# Create header for the subimage - I think this only needs the DET one,
# but we'll do both. The WCSs are initialised from the full fpa WCS and
# then shifted accordingly.
det_wcs = WCS(det_header, key="D")
det_wcs.wcs.crpix -= np.array([xmin, ymin])
sub_naxis1 = xmax - xmin
sub_naxis2 = ymax - ymin
# initialise the subimage
image = np.zeros((sub_naxis2, sub_naxis1), dtype=np.float32)
# Adjust the limits of the subimage in millimeters in the focal plane
# This takes the adjustment to integer pixels into account
xmin_mm, ymin_mm = fpa_wcsd.all_pix2world(xmin, ymin, 0)
xmax_mm, ymax_mm = fpa_wcsd.all_pix2world(xmax, ymax, 0)
self._set_dispersion(wave_min, wave_max, pixsize=pixsize)
try:
xilam = XiLamImage(fov, self.dlam_per_pix)
self._xilamimg = xilam # ..todo: remove or make available with a debug flag?
except ValueError:
logger.warning(" ---> %d gave ValueError", self.trace_id)
npix_xi, npix_lam = xilam.npix_xi, xilam.npix_lam
xilam_wcs = xilam.wcs
# focal-plane coordinate images
ximg_fpa, yimg_fpa = np.meshgrid(np.linspace(xmin_mm, xmax_mm,
sub_naxis1,
dtype=np.float32),
np.linspace(ymin_mm, ymax_mm,
sub_naxis2,
dtype=np.float32))
# Image mapping (xi, lambda) on the focal plane
xi_fpa = self.xy2xi(ximg_fpa, yimg_fpa).astype(np.float32)
lam_fpa = self.xy2lam(ximg_fpa, yimg_fpa).astype(np.float32)
# mask everything outside the wavelength range
mask = (xi_fpa >= xi_min) & (xi_fpa <= xi_max)
xi_fpa *= mask
lam_fpa *= mask
# mask everything outside the wavelength range
# mask = (lam_fpa >= wave_min) & (lam_fpa <= wave_max)
# xi_fpa *= mask
# lam_fpa *= mask
# Convert to pixel images
# These are the pixel coordinates in the image corresponding to
# xi, lambda
# It is much quicker to do the linear transformation by hand
# than to use the astropy.wcs functions for conversion.
i_img = ((lam_fpa - xilam_wcs.wcs.crval[0])
/ xilam_wcs.wcs.cdelt[0]).astype(int)
j_img = ((xi_fpa - xilam_wcs.wcs.crval[1])
/ xilam_wcs.wcs.cdelt[1]).astype(int)
# truncate images to remove pixel coordinates outside the image
ijmask = ((i_img >= 0) * (i_img < npix_lam)
* (j_img >= 0) * (j_img < npix_xi))
# do the actual interpolation
# image is in [ph/s/um/arcsec]
image = xilam.interp(xi_fpa, lam_fpa, grid=False) * ijmask
# Scale to ph / s / pixel
dlam_by_dx, dlam_by_dy = self.xy2lam.gradient()
dlam_per_pix = pixsize * np.sqrt(dlam_by_dx(ximg_fpa, yimg_fpa)**2 +
dlam_by_dy(ximg_fpa, yimg_fpa)**2)
image *= pixscale * dlam_per_pix # [arcsec/pix] * [um/pix]
# img_header = sub_wcs.to_header()
# img_header.update(det_wcs.to_header())
img_header = det_wcs.to_header()
img_header["XMIN"] = xmin
img_header["XMAX"] = xmax
img_header["YMIN"] = ymin
img_header["YMAX"] = ymax
if np.any(image < 0):
logger.warning("map_spectra_to_focal_plane: %d negative pixels",
np.sum(image < 0))
image_hdu = fits.ImageHDU(header=img_header, data=image)
return image_hdu
[docs] def rectify(self, hdulist, interps=None, wcs=None, **kwargs):
"""Create 2D spectrum for a trace.
Parameters
----------
hdulist : HDUList
The result of scopesim readout
interps : list of interpolation functions
If provided, there must be one for each image extension in `hdulist`.
The functions go from pixels to the images and can be created with,
e.g., RectBivariateSpline.
wcs : The WCS describing the rectified XiLamImage. This can be created
in a simple way from the fov included in the `OpticalTrain` used in
the simulation run producing `hdulist`.
The WCS can also be set up via the following keywords:
bin_width : float [um]
The spectral bin width. This is best computed automatically from the
spectral dispersion of the trace.
wave_min, wave_max : float [um]
Limits of the wavelength range to extract. The default is the
the full range on which the `SpectralTrace` is defined. This may
extend significantly beyond the filter window.
xi_min, xi_max : float [arcsec]
Spatial limits of the slit on the sky. This should be taken from
the header of the hdulist, but this is not yet provided by scopesim
"""
logger.info("Rectifying %s", self.trace_id)
wave_min = kwargs.get("wave_min",
self.wave_min)
wave_max = kwargs.get("wave_max",
self.wave_max)
if wave_max < self.wave_min or wave_min > self.wave_max:
logger.info(" Outside filter range")
return None
wave_min = max(wave_min, self.wave_min)
wave_max = min(wave_max, self.wave_max)
logger.info(" %.02f .. %.02f um", wave_min, wave_max)
# bin_width is taken as the minimum dispersion of the trace
# ..todo: if wcs is given take bin width from cdelt1
bin_width = kwargs.get("bin_width", None)
if bin_width is None:
self._set_dispersion(wave_min, wave_max)
bin_width = np.abs(self.dlam_per_pix.y).min()
logger.info(" Bin width %.02g um", bin_width)
pixscale = from_currsys(self.meta["pixel_scale"], self.cmds)
# Temporary solution to get slit length
xi_min = kwargs.get("xi_min", None)
if xi_min is None:
try:
xi_min = hdulist[0].header["HIERARCH INS SLIT XIMIN"]
except KeyError:
logger.error("xi_min not found")
return None
xi_max = kwargs.get("xi_max", None)
if xi_max is None:
try:
xi_max = hdulist[0].header["HIERARCH INS SLIT XIMAX"]
except KeyError:
logger.error("xi_max not found")
return None
if wcs is None:
wcs = WCS(naxis=2)
wcs.wcs.ctype = ["WAVE", "LINEAR"]
wcs.wcs.cunit = ["um", "arcsec"]
wcs.wcs.crpix = [1, 1]
wcs.wcs.cdelt = [bin_width, pixscale] # PIXSCALE
# crval set to wave_min to catch explicitely set value
wcs.wcs.crval = [wave_min, xi_min] # XIMIN
nlam = int((wave_max - wave_min) / bin_width) + 1
nxi = int((xi_max - xi_min) / pixscale) + 1
# Create interpolation functions if not provided
if interps is None:
logger.info("Computing image interpolations")
interps = make_image_interpolations(hdulist, kx=1, ky=1)
# Create Xi, Lam images (do I need Iarr and Jarr or can I build Xi, Lam directly?)
Iarr, Jarr = np.meshgrid(np.arange(nlam, dtype=np.float32),
np.arange(nxi, dtype=np.float32))
Lam, Xi = wcs.all_pix2world(Iarr, Jarr, 0)
# Make sure that we do have microns
Lam = Lam * u.Unit(wcs.wcs.cunit[0]).to(u.um)
# Convert Xi, Lam to focal plane units
Xarr = self.xilam2x(Xi, Lam)
Yarr = self.xilam2y(Xi, Lam)
rect_spec = np.zeros_like(Xarr, dtype=np.float32)
ihdu = 0
# TODO: make this more iteratory
for hdu in hdulist:
if not isinstance(hdu, fits.ImageHDU):
continue
wcs_fp = WCS(hdu.header, key="D")
n_x = hdu.header["NAXIS1"]
n_y = hdu.header["NAXIS2"]
iarr, jarr = wcs_fp.all_world2pix(Xarr, Yarr, 0)
mask = (iarr > 0) * (iarr < n_x) * (jarr > 0) * (jarr < n_y)
if np.any(mask):
specpart = interps[ihdu](jarr, iarr, grid=False)
rect_spec += specpart * mask
ihdu += 1
header = wcs.to_header()
header["EXTNAME"] = self.trace_id
return fits.ImageHDU(data=rect_spec, header=header)
[docs] def plot(self, wave_min=None, wave_max=None, xi_min=None, xi_max=None, *,
c="r", axes=None, plot_footprint=True, plot_wave=True,
plot_ctrlpnts=True, plot_outline=False, plot_trace_id=False):
"""Plot control points (and/or footprint) of the SpectralTrace.
Parameters
----------
wave_min : float, optional
Minimum wavelength, if any.
wave_max : float, optional
Maximum wavelength, if any.
xi_min : float, optional
Minimum slit, if any.
xi_max : float, optional
Maximum slit, if any.
c : str, optional
Colour, any valid matplotlib colour string. The default is "r".
axes : matplotlib axes, optional
The axes object to use for the plot. If None (default), a new
figure with one axes will be created.
Returns
-------
axes : matplotlib axes
The axes object containing the plot.
Other Parameters
----------------
plot_footprint : bool, optional
Plot a rectangle encompassing all control points, which may be
larger than the area actually covered by the trace, if the trace is
not exactly perpendicular to the detector. The default is True.
plot_wave : bool, optional
Annotate the wavelength points. The default is True.
plot_ctrlpnts : bool, optional
Plot the individual control points as makers. The default is True.
plot_outline : bool, optional
Plot the smallest tetragon encompassing all control points.
The default is False.
plot_trace_id : bool, optional
Write the trace ID in the middle of the trace.
The default is False.
"""
if axes is None:
_, axes = figure_factory()
# Footprint (rectangle enclosing the trace)
xlim, ylim = self.footprint(wave_min=wave_min, wave_max=wave_max)
if xlim is None:
return axes
if plot_footprint:
axes.plot(list(close_loop(xlim)), list(close_loop(ylim)))
# for convenience...
xname = self.meta["x_colname"]
yname = self.meta["y_colname"]
wname = self.meta["wave_colname"]
sname = self.meta["s_colname"]
# Control points
waves = self.table[wname]
if wave_min is None:
wave_min = waves.min()
if wave_max is None:
wave_max = waves.max()
xis = self.table[sname]
if xi_min is None:
xi_min = xis.min()
if xi_max is None:
xi_max = xis.max()
mask = ((waves >= wave_min) & (waves <= wave_max)
& (xis >= xi_min) & (xis <= xi_max))
if sum(mask) <= 2:
return axes
w = waves[mask]
x = self.table[xname][mask]
y = self.table[yname][mask]
if plot_ctrlpnts:
axes.plot(x, y, "o", c=c)
if plot_outline:
blue_end = self.table[mask][w == w.min()]
red_end = self.table[mask][w == w.max()]
blue_end.sort(sname)
red_end.sort(sname)
corners = vstack([blue_end[[0, -1]][xname, yname],
red_end[[-1, 0]][xname, yname],
blue_end[0][xname, yname]])
axes.plot(corners[xname], corners[yname], c=c)
if plot_trace_id:
axes.text(corners[xname][:-1].mean(), corners[yname][:-1].mean(),
self.trace_id, c=c, rotation="vertical",
ha="center", va="center")
for wave in np.unique(w):
xx = x[w == wave]
xx.sort()
dx = xx[-1] - xx[-2]
if plot_wave:
axes.text(x[w == wave].max() + 0.5 * dx,
y[w == wave].mean(),
str(wave), va="center", ha="left")
axes.set_aspect("equal")
return axes
def _set_dispersion(self, wave_min, wave_max, pixsize=None):
"""Compute of dispersion dlam_per_pix along xi=0."""
# ..todo: This may have to be generalised - xi=0 is at the centre
# of METIS slits and the short MICADO slit.
xi = np.array([0] * 1001)
lam = np.linspace(wave_min, wave_max, 1001)
x_mm = self.xilam2x(xi, lam)
y_mm = self.xilam2y(xi, lam)
if self.dispersion_axis == "x":
dlam_grad = self.xy2lam.gradient()[0] # dlam_by_dx
else:
dlam_grad = self.xy2lam.gradient()[1] # dlam_by_dy
pixsize = (from_currsys(self.meta["pixel_scale"], self.cmds) /
from_currsys(self.meta["plate_scale"], self.cmds))
self.dlam_per_pix = interp1d(lam,
dlam_grad(x_mm, y_mm) * pixsize,
fill_value="extrapolate")
def __repr__(self):
return f"{self.__class__.__name__}({self.table!r}, **{self.meta!r})"
def __str__(self):
msg = (f"<SpectralTrace> \"{self.trace_id}\" : "
f"[{self.wave_min:.4f}, {self.wave_max:.4f}]um : "
f"Ext {self.meta['extension_id']} : "
f"Aperture {self.meta['aperture_id']} : "
f"ImagePlane {self.meta['image_plane_id']}")
return msg
[docs]class XiLamImage():
"""
Class to compute a rectified 2D spectrum.
The class produces and holds an image of xi (relative position along
the spatial slit direction) and wavelength lambda.
Parameters
----------
fov : FieldOfView
dlam_per_pix : a 1-D interpolation function from wavelength (in um) to
dispersion (in um/pixel); alternatively a number giving an average
dispersion
"""
def __init__(self, fov, dlam_per_pix):
# ..todo: we assume that we always have a cube. We use SpecCADO's
# add_cube_layer method
cube_wcs = WCS(fov.cube.header, key=" ")
wcs_lam = cube_wcs.sub([3])
d_xi = fov.cube.header["CDELT1"]
d_xi *= u.Unit(fov.cube.header["CUNIT1"]).to(u.arcsec)
d_eta = fov.cube.header["CDELT2"]
d_eta *= u.Unit(fov.cube.header["CUNIT2"]).to(u.arcsec)
d_lam = fov.cube.header["CDELT3"]
d_lam *= u.Unit(fov.cube.header["CUNIT3"]).to(u.um)
# This is based on the cube shape and assumes that the cube's spatial
# dimensions are set by the slit aperture
(n_lam, n_eta, n_xi) = fov.cube.data.shape
# arrays of cube coordinates
cube_xi = d_xi * np.arange(n_xi) + fov.meta["xi_min"].value
cube_eta = d_eta * (np.arange(n_eta) - (n_eta - 1) / 2)
cube_lam = wcs_lam.all_pix2world(np.arange(n_lam), 1)[0]
cube_lam *= u.Unit(wcs_lam.wcs.cunit[0]).to(u.um)
# Initialise the array to hold the xi-lambda image
self.image = np.zeros((n_xi, n_lam), dtype=np.float32)
self.lam = cube_lam
try:
dlam_per_pix_val = dlam_per_pix(np.asarray(self.lam))
except TypeError:
dlam_per_pix_val = dlam_per_pix
logger.warning("Using scalar dlam_per_pix = %.2g",
dlam_per_pix_val)
for i, eta in enumerate(cube_eta):
lam0 = self.lam + dlam_per_pix_val * eta / d_eta
# lam0 is the target wavelength. We need to check that this
# overlaps with the wavelength range covered by the cube
if lam0.min() < cube_lam.max() and lam0.max() > cube_lam.min():
plane = fov.cube.data[:, i, :].T
plane_interp = RectBivariateSpline(cube_xi, cube_lam, plane,
kx=1, ky=1)
self.image += plane_interp(cube_xi, lam0)
self.image *= d_eta # ph/s/um/arcsec2 --> ph/s/um/arcsec
# WCS for the xi-lambda image, i.e. the rectified 2D spectrum
# Default WCS with xi in arcsec
self.wcs = WCS(naxis=2)
self.wcs.wcs.crpix = [1, 1]
self.wcs.wcs.crval = [self.lam[0], fov.meta["xi_min"].value]
self.wcs.wcs.pc = [[1, 0], [0, 1]]
self.wcs.wcs.cdelt = [d_lam, d_xi]
self.wcs.wcs.ctype = ["LINEAR", "LINEAR"]
self.wcs.wcs.cname = ["WAVELEN", "SLITPOS"]
self.wcs.wcs.cunit = ["um", "arcsec"]
# Alternative: xi = [0, 1], dimensionless
self.wcsa = WCS(naxis=2)
self.wcsa.wcs.crpix = [1, 1]
self.wcsa.wcs.crval = [self.lam[0], 0]
self.wcsa.wcs.pc = [[1, 0], [0, 1]]
self.wcsa.wcs.cdelt = [d_lam, 1./n_xi]
self.wcsa.wcs.ctype = ["LINEAR", "LINEAR"]
self.wcsa.wcs.cname = ["WAVELEN", "SLITPOS"]
self.wcs.wcs.cunit = ["um", ""]
self.xi = self.wcs.all_pix2world(self.lam[0], np.arange(n_xi), 0)[1]
self.npix_xi = n_xi
self.npix_lam = n_lam
# ..todo: cubic spline introduces negative values, linear does not.
# Alternative might be to cubic-spline interpolate on sqrt(image),
# with subsequent squaring of the result. This would require
# wrapping RectBivariateSpline in a new (sub)class.
spline_order = (1, 1)
self.interp = RectBivariateSpline(self.xi, self.lam, self.image,
kx=spline_order[0],
ky=spline_order[1])
[docs]def fit2matrix(fit):
"""
Return coefficients from a polynomial fit as a matrix.
The Polynomial2D fits of degree n have coefficients for all i, j
with i + j <= n.
How would one rearrange those?
"""
coeffs = dict(zip(fit.param_names, fit.parameters))
deg = fit.degree
mat = np.zeros((deg + 1, deg + 1), dtype=np.float32)
for i in range(deg + 1):
for j in range(deg + 1):
try:
mat[j, i] = coeffs[f"c{i}_{j}"]
except KeyError:
pass
return mat
# ..todo: should the next three functions be combined and return a dictionary of fits?
[docs]def xilam2xy_fit(layout, params):
"""
Determine polynomial fits of FPA position.
Fits are of degree 4 as a function of slit position and wavelength.
"""
xi_arr = layout[params["s_colname"]]
lam_arr = layout[params["wave_colname"]]
x_arr = layout[params["x_colname"]]
y_arr = layout[params["y_colname"]]
# Filter the lists: remove any points with x==0
# ..todo: this may not be necessary after sanitising the table
# good = x != 0
# xi = xi[good]
# lam = lam[good]
# x = x[good]
# y = y[good]
# compute the fits
pinit_x = Polynomial2D(degree=4)
pinit_y = Polynomial2D(degree=4)
fitter = fitting.LinearLSQFitter()
xilam2x = fitter(pinit_x, xi_arr, lam_arr, x_arr)
xilam2y = fitter(pinit_y, xi_arr, lam_arr, y_arr)
return xilam2x, xilam2y
[docs]def xy2xilam_fit(layout, params):
"""
Determine polynomial fits of wavelength/slit position.
Fits are of degree 4 as a function of focal plane position
"""
xi_arr = layout[params["s_colname"]]
lam_arr = layout[params["wave_colname"]]
x_arr = layout[params["x_colname"]]
y_arr = layout[params["y_colname"]]
pinit_xi = Polynomial2D(degree=4)
pinit_lam = Polynomial2D(degree=4)
fitter = fitting.LinearLSQFitter()
xy2xi = fitter(pinit_xi, x_arr, y_arr, xi_arr)
xy2lam = fitter(pinit_lam, x_arr, y_arr, lam_arr)
return xy2xi, xy2lam
def _xiy2xlam_fit(layout, params):
"""Determine polynomial fits of wavelength/slit position.
Fits are of degree 4 as a function of focal plane position
"""
# These are helper functions to allow fitting of left/right edges
# for the purpose of checking whether a trace is on a chip or not.
xi_arr = layout[params["s_colname"]]
lam_arr = layout[params["wave_colname"]]
x_arr = layout[params["x_colname"]]
y_arr = layout[params["y_colname"]]
pinit_x = Polynomial2D(degree=4)
pinit_lam = Polynomial2D(degree=4)
fitter = fitting.LinearLSQFitter()
xiy2x = fitter(pinit_x, xi_arr, y_arr, x_arr)
xiy2lam = fitter(pinit_lam, xi_arr, y_arr, lam_arr)
return xiy2x, xiy2lam
[docs]def make_image_interpolations(hdulist, **kwargs):
"""Create 2D interpolation functions for images."""
interps = []
for hdu in hdulist:
if isinstance(hdu, fits.ImageHDU):
interps.append(
RectBivariateSpline(np.arange(hdu.header["NAXIS1"]),
np.arange(hdu.header["NAXIS2"]),
hdu.data, **kwargs)
)
return interps
# ..todo: Check whether the following functions are actually used
[docs]def fill_zeros(x):
"""Fill in zeros in a sequence with the previous non-zero number."""
for i in range(1, len(x)):
if x[i] == 0:
x[i] = x[i-1]
return x
[docs]def get_affine_parameters(coords):
"""
Return rotation and shear for each MTC point along a ``SpectralTrace``.
.. note: Restrictions of this method:
* only uses the left most coordinates for the shear
* rotation angle is calculated using the trace extremes
Parameters
----------
coords : dict of 2D arrays
Each dict entry ["x", "y", "s"] contains a [N, M] 2D array of
coordinates, where:
* N is the number of points along the slit (e.g. ~5), and
* M is the number of positions along the trace (e.g. >100)
Returns
-------
rotations : array
[deg] Rotation angles for M positions along the Trace
shears : array
[deg] Shear angles for M positions along the Trace
"""
rad2deg = 180 / np.pi
dxs = coords["x"][-1, :] - coords["x"][0, :]
dys = coords["y"][-1, :] - coords["y"][0, :]
rotations = np.arctan2(dys, dxs) * rad2deg
dxs = np.diff(coords["x"], axis=1)
dys = np.diff(coords["y"], axis=1)
shears = np.array([np.arctan2(dys[i], dxs[i])
for i in range(dxs.shape[0])])
shears = np.array(list(shears.T) + [shears.T[-1]]).T
shears = (np.average(shears, axis=0) * rad2deg) - (90 + rotations)
return rotations, shears