import numpy as np
from scipy.signal import convolve
from scipy.interpolate import RectBivariateSpline
from astropy import units as u
from astropy.io import fits
from astropy.convolution import Gaussian2DKernel
from astropy.wcs import WCS
import anisocado as aniso
from .effects import Effect
from . import ter_curves_utils as tu
from . import psf_utils as pu
from ..base_classes import ImagePlaneBase, FieldOfViewBase, FOVSetupBase
from ..utils import (figure_grid_factory, from_currsys, quantify, figure_factory,
check_keys, get_logger)
import warnings
# Module-level logger, obtained through scopesim's ``get_logger`` helper.
logger = get_logger(__name__)
class PoorMansFOV:
    """Minimal stand-in for a FieldOfView, used to render PSF kernels in plots.

    It provides the few attributes that the ``get_kernel`` methods in this
    module read from a field of view: ``header`` (pixel scale and a fixed
    128 x 128 image size), ``meta``, ``wavelength`` and a nested ``hdu`` that
    carries the same ``header`` and ``meta``.

    Parameters
    ----------
    pixel_scale : float
        Divided by 3600 to give ``CDELT1``/``CDELT2``, i.e. presumably
        [arcsec] -> [deg].
    spec_dict : dict
        Spectral settings; must contain ``"wave_mid"`` [um].
    recursion_call : bool, optional
        Internal flag. When False (default) a nested ``PoorMansFOV`` is built
        and stored as ``self.hdu``; the nested instance does not recurse.
    """

    def __init__(self, pixel_scale, spec_dict, recursion_call=False):
        self.header = {
            "CDELT1": pixel_scale / 3600.,
            "CDELT2": pixel_scale / 3600.,
            "NAXIS": 2,
            "NAXIS1": 128,
            "NAXIS2": 128,
        }
        self.meta = spec_dict
        self.wavelength = spec_dict["wave_mid"] * u.um
        if not recursion_call:
            # spec_dict is a required positional argument; it was missing
            # here, which made every construction raise a TypeError.
            self.hdu = PoorMansFOV(pixel_scale, spec_dict,
                                   recursion_call=True)
class PSF(Effect):
    """Base class for point-spread-function effects.

    Subclasses provide the kernel through :meth:`get_kernel`. This class
    splits the field-of-view setup along the PSF waveset and convolves the
    images of field-of-view / image-plane objects with the kernel.
    """

    def __init__(self, **kwargs):
        self.kernel = None
        self.valid_waverange = None
        self._waveset = []
        super().__init__(**kwargs)
        params = {
            "flux_accuracy": "!SIM.computing.flux_accuracy",
            "sub_pixel_flag": "!SIM.sub_pixel.flag",
            "z_order": [40, 640],
            "convolve_mode": "same",  # "full", "same"
            "bkg_width": -1,
            "wave_key": "WAVE0",
            "normalise_kernel": True,
            "rotational_blur_angle": 0,
            "report_plot_include": True,
            "report_table_include": False,
        }
        self.meta.update(params)
        self.meta.update(kwargs)
        self.meta = from_currsys(self.meta, self.cmds)
        self.convolution_classes = (FieldOfViewBase, ImagePlaneBase)

    def apply_to(self, obj, **kwargs):
        """Apply the PSF.

        1. ``obj`` is a ``FOVSetupBase``: the setup is split in wavelength at
           the midpoints between consecutive entries of ``self._waveset``.
        2. ``obj`` is one of ``self.convolution_classes`` and has fields or an
           hdu: ``obj.hdu.data`` is background-subtracted, convolved with the
           kernel from :meth:`get_kernel` (optionally rotationally blurred),
           and the background is added back. ``CRPIX`` keywords are shifted
           if the image size changed.

        Raises
        ------
        ValueError
            If the image/kernel dimensionality combination is not supported
            (2D/2D, 3D/2D and 3D/3D are).

        Returns
        -------
        obj
        """
        # 1. During setup of the FieldOfViews
        if isinstance(obj, FOVSetupBase) and self._waveset is not None:
            waveset = self._waveset
            if len(waveset) != 0:
                waveset_edges = 0.5 * (waveset[:-1] + waveset[1:])
                obj.split("wave", quantify(waveset_edges, u.um).value)

        # 2. During observe: convolution
        elif isinstance(obj, self.convolution_classes):
            if ((hasattr(obj, "fields") and len(obj.fields) > 0) or
                    (obj.hdu is not None)):
                kernel = self.get_kernel(obj).astype(float)

                # apply rotational blur for field-tracking observations
                rot_blur_angle = self.meta["rotational_blur_angle"]
                if abs(rot_blur_angle) > 0:
                    # makes a copy of kernel
                    kernel = pu.rotational_blur(kernel, rot_blur_angle)

                # The kernel is expected to be normalised by get_kernel().
                image = obj.hdu.data.astype(float)

                # subtract background level before convolving, re-add afterwards
                bkg_level = pu.get_bkg_level(image, self.meta["bkg_width"])

                # do the convolution
                mode = from_currsys(self.meta["convolve_mode"], self.cmds)
                if image.ndim == 2 and kernel.ndim == 2:
                    new_image = convolve(image - bkg_level, kernel, mode=mode)
                elif image.ndim == 3 and kernel.ndim == 2:
                    kernel = kernel[None, :, :]
                    bkg_level = bkg_level[:, None, None]
                    new_image = convolve(image - bkg_level, kernel, mode=mode)
                elif image.ndim == 3 and kernel.ndim == 3:
                    bkg_level = bkg_level[:, None, None]
                    # Convolve plane by plane and stack the results; unlike
                    # filling a pre-allocated image-shaped array this also
                    # works for mode="full", where every plane grows.
                    planes = [convolve(image[iplane] - bkg_level[iplane],
                                       kernel[iplane], mode=mode)
                              for iplane in range(image.shape[0])]
                    new_image = (np.array(planes) if planes
                                 else np.zeros(image.shape))
                else:
                    raise ValueError(
                        f"Cannot convolve a {image.ndim}D image with a "
                        f"{kernel.ndim}D kernel.")

                obj.hdu.data = new_image + bkg_level

                # Shift the reference pixel by half the change in size
                # (non-zero only for mode="full"). Numpy axis -1 is x
                # (CRPIX1), axis -2 is y (CRPIX2).
                d_x = new_image.shape[-1] - image.shape[-1]
                d_y = new_image.shape[-2] - image.shape[-2]
                for wcsid in ["", "D"]:
                    if "CRPIX1" + wcsid in obj.hdu.header:
                        obj.hdu.header["CRPIX1" + wcsid] += d_x / 2
                        obj.hdu.header["CRPIX2" + wcsid] += d_y / 2

        return obj

    def fov_grid(self, which="waveset", **kwargs):
        """See parent docstring.

        For ``which="waveset"`` returns the unique, sorted wavelengths made of
        ``wave_min``, the midpoints of ``self._waveset`` lying strictly between
        ``wave_min`` and ``wave_max``, and ``wave_max``. ``wave_min`` /
        ``wave_max`` default to the extremes of ``self._waveset``. Returns an
        empty list otherwise, or when the waveset is None or empty.
        """
        waveset = []
        # An empty waveset (the base-class default) would make the eagerly
        # evaluated np.min/np.max defaults below raise a ValueError.
        if (which == "waveset" and self._waveset is not None
                and len(self._waveset) > 0):
            _waveset = self._waveset
            waves = 0.5 * (np.array(_waveset)[1:] +
                           np.array(_waveset)[:-1])
            wave_min = kwargs.get("wave_min", np.min(_waveset))
            wave_max = kwargs.get("wave_max", np.max(_waveset))
            mask = (wave_min < waves) * (waves < wave_max)
            waveset = np.unique([wave_min] + list(waves[mask]) +
                                [wave_max])

        return waveset

    def get_kernel(self, obj):
        """Return the PSF kernel; the base class uses a 1x1 unit kernel."""
        self.valid_waverange = None
        if self.kernel is None:
            self.kernel = np.ones((1, 1))

        return self.kernel

    def plot(self, obj=None, **kwargs):
        """Show the kernel from ``get_kernel(obj)`` with a logarithmic norm.

        ``kwargs`` are passed on to ``imshow``. Returns the figure.
        """
        from matplotlib.colors import LogNorm

        fig, axes = figure_factory()

        kernel = self.get_kernel(obj)
        axes.imshow(kernel, norm=LogNorm(), origin="lower", **kwargs)

        return fig
##############################################################################
# Analytical PSFs - Vibration, Seeing, NCPAs
class AnalyticalPSF(PSF):
    """Base class for PSFs described by an analytical model.

    Only FieldOfView objects are convolved; the default z_order is
    ``[41, 641]``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Narrow the base class's (FieldOfViewBase, ImagePlaneBase) tuple.
        self.convolution_classes = FieldOfViewBase
        self.meta["z_order"] = [41, 641]
class Vibration(AnalyticalPSF):
    """Creates a wavelength independent kernel image.

    The kernel is a normalised Gaussian whose FWHM in pixels is
    ``fwhm / pixel_scale`` (both required meta keys, in the same angular
    unit). Its side length is ``width_n_fwhms`` FWHMs (default 4, at least
    one pixel). Only ImagePlane objects are convolved.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.meta["z_order"] = [244, 744]
        # Only a default: keep a user-supplied width_n_fwhms.
        if "width_n_fwhms" not in kwargs:
            self.meta["width_n_fwhms"] = 4
        self.convolution_classes = ImagePlaneBase

        self.required_keys = ["fwhm", "pixel_scale"]
        check_keys(self.meta, self.required_keys, action="error")
        self.kernel = None

    def get_kernel(self, obj):
        """Return the (cached) Gaussian kernel as a float array.

        ``obj`` is not used; the kernel only depends on the meta values.
        """
        if self.kernel is None:
            # from_currsys returns the resolved dict; its result used to be
            # discarded, so the call had no effect.
            self.meta = from_currsys(self.meta, self.cmds)
            fwhm_pix = self.meta["fwhm"] / self.meta["pixel_scale"]
            # 2.35 ~ 2 * sqrt(2 * ln 2), the FWHM-to-sigma ratio of a Gaussian
            sigma = fwhm_pix / 2.35
            width = max(1, int(fwhm_pix * self.meta["width_n_fwhms"]))
            self.kernel = Gaussian2DKernel(sigma, x_size=width, y_size=width,
                                           mode="center").array
            self.kernel /= np.sum(self.kernel)

        return self.kernel.astype(float)
class NonCommonPathAberration(AnalyticalPSF):
    """
    Non-common-path aberrations as a wavelength-dependent kernel.

    The total wavefront error (WFE) comes from the effect's table via
    ``pu.get_total_wfe_from_table`` (0 if there is no table). The kernel is
    built with ``pu.wfe2gauss`` for the mean wavelength of each field of view
    and only rebuilt when the Strehl ratio drifts by more than
    ``strehl_drift``.

    Needed: pixel_scale
    Accepted: kernel_width (default None), strehl_drift (default 0.02)
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.meta["z_order"] = [241, 641]
        # Defaults only: values passed by the user (already resolved by the
        # base class) must not be overwritten.
        defaults = {
            "kernel_width": None,
            "strehl_drift": 0.02,
            "wave_min": "!SIM.spectral.wave_min",
            "wave_max": "!SIM.spectral.wave_max",
        }
        for key, value in defaults.items():
            if key not in kwargs:
                self.meta[key] = value

        self._total_wfe = None

        # Placeholder range; get_kernel() builds the kernel on first use.
        self.valid_waverange = [0.1 * u.um, 0.2 * u.um]

        self.convolution_classes = FieldOfViewBase
        self.required_keys = ["pixel_scale"]
        check_keys(self.meta, self.required_keys, action="error")

    def fov_grid(self, which="waveset", **kwargs):
        """See parent docstring."""
        warnings.warn("The fov_grid method is deprecated and will be removed "
                      "in a future release.", DeprecationWarning, stacklevel=2)
        if which == "waveset":
            self.meta.update(kwargs)
            self.meta = from_currsys(self.meta, self.cmds)

            min_sr = pu.wfe2strehl(self.total_wfe, self.meta["wave_min"])
            max_sr = pu.wfe2strehl(self.total_wfe, self.meta["wave_max"])

            srs = np.arange(min_sr, max_sr, self.meta["strehl_drift"])
            waves = 6.2831853 * self.total_wfe * (-np.log(srs))**-0.5
            waves = quantify(waves, u.um).value
            waves = (list(waves) + [self.meta["wave_max"]]) * u.um
        else:
            waves = [] * u.um

        return waves

    def get_kernel(self, obj):
        """Return the kernel for the mean wavelength of ``obj``.

        The kernel is (re)computed when none exists yet, or when the Strehl
        ratio at the new mean wavelength differs from the one of the current
        ``valid_waverange`` by more than ``strehl_drift`` (relative).
        """
        waves = obj.meta["wave_min"], obj.meta["wave_max"]
        wave_mid_new = 0.5 * (waves[0] + waves[1])

        # Without the "kernel is None" check, a first FOV within the drift
        # tolerance (e.g. total WFE == 0) returned None.
        needs_update = self.kernel is None
        if not needs_update:
            old_waves = self.valid_waverange
            wave_mid_old = 0.5 * (old_waves[0] + old_waves[1])
            strehl_old = pu.wfe2strehl(wfe=self.total_wfe, wave=wave_mid_old)
            strehl_new = pu.wfe2strehl(wfe=self.total_wfe, wave=wave_mid_new)
            needs_update = (np.abs(1 - strehl_old / strehl_new) >
                            self.meta["strehl_drift"])

        if needs_update:
            self.valid_waverange = waves
            self.kernel = pu.wfe2gauss(wfe=self.total_wfe, wave=wave_mid_new,
                                       width=self.meta["kernel_width"])
            self.kernel /= np.sum(self.kernel)

        return self.kernel

    @property
    def total_wfe(self):
        """Total wavefront error from the table (0 if no table), cached."""
        if self._total_wfe is None:
            if self.table is not None:
                self._total_wfe = pu.get_total_wfe_from_table(self.table)
            else:
                self._total_wfe = 0
        return self._total_wfe

    def plot(self):
        """Plot the Strehl ratio against wavelength for the total WFE."""
        fig, axes = figure_factory()

        wave_min, wave_max = from_currsys([self.meta["wave_min"],
                                           self.meta["wave_max"]], self.cmds)
        waves = np.linspace(wave_min, wave_max, 1001) * u.um
        wfe = self.total_wfe
        strehl = pu.wfe2strehl(wfe=wfe, wave=waves)

        axes.plot(waves, strehl)
        axes.set_xlabel(f"Wavelength [{waves.unit}]")
        axes.set_ylabel(f"Strehl Ratio \n[Total WFE = {wfe}]")

        return fig
class SeeingPSF(AnalyticalPSF):
    """
    Seeing PSF; currently just a Gaussian kernel with a given ``fwhm``.

    Parameters
    ----------
    fwhm : float
        [arcsec] Default 1.5. Full width at half maximum of the Gaussian.
    """

    def __init__(self, fwhm=1.5, **kwargs):
        super().__init__(**kwargs)
        self.meta.update({"fwhm": fwhm, "z_order": [242, 642]})

    def get_kernel(self, fov):
        """Return a normalised Gaussian kernel sampled on the FOV pixel grid.

        Called by ``apply_to`` of the PSF base class. ``fov.header["CDELT1"]``
        is taken to be in degrees.
        """
        # TODO: derive the fwhm from seeing and wavelength.
        arcsec_per_pix = quantify(fov.header["CDELT1"] * u.deg.to(u.arcsec),
                                  u.arcsec)
        fwhm_pix = (from_currsys(self.meta["fwhm"], self.cmds) * u.arcsec
                    / arcsec_per_pix)
        gauss = Gaussian2DKernel(fwhm_pix.value / 2.35, mode="center").array
        return gauss / np.sum(gauss)

    def plot(self):
        """Plot the kernel for the instrument pixel scale and spectral setup."""
        stand_in_fov = PoorMansFOV(from_currsys("!INST.pixel_scale", self.cmds),
                                   from_currsys("!SIM.spectral", self.cmds))
        return super().plot(stand_in_fov)
class GaussianDiffractionPSF(AnalyticalPSF):
    """Gaussian approximation of a diffraction-limited PSF.

    The kernel width is derived from ``1.22 * lambda / diameter`` at the mean
    wavelength of each field of view.

    Parameters
    ----------
    diameter : float or Quantity
        Telescope diameter, [m] if given without unit.
    """

    def __init__(self, diameter, **kwargs):
        super().__init__(**kwargs)
        self.meta.update({"diameter": diameter, "z_order": [242, 642]})

    def fov_grid(self, which="waveset", **kwargs):
        """See parent docstring."""
        warnings.warn("The fov_grid method is deprecated and will be removed "
                      "in a future release.", DeprecationWarning, stacklevel=2)
        needed = ("waverange", "pixel_scale")
        if which != "waveset" or not all(key in kwargs for key in needed):
            return []

        # .. todo: check that this is actually correct
        waverange = quantify(kwargs["waverange"], u.um)
        diameter_um = quantify(self.meta["diameter"], u.m).to(u.um)
        fwhm_rad = 1.22 * (waverange / diameter_um).value
        step_rad = quantify(kwargs["pixel_scale"], u.deg).to(u.rad).value
        fwhm_steps = np.arange(fwhm_rad[0], fwhm_rad[1], step_rad)
        return list(fwhm_steps / 1.22 * diameter_um.to(u.m))

    def update(self, **kwargs):
        """Replace the telescope ``diameter`` if one is given in ``kwargs``."""
        try:
            self.meta["diameter"] = kwargs["diameter"]
        except KeyError:
            pass

    def get_kernel(self, fov):
        """Return a normalised Gaussian kernel for the FOV's mean wavelength.

        Called by ``apply_to`` of the PSF base class. ``fov.header["CDELT1"]``
        is taken to be in degrees.
        """
        arcsec_per_pix = quantify(fov.header["CDELT1"] * u.deg.to(u.arcsec),
                                  u.arcsec)
        wave_mid = quantify(0.5 * (fov.meta["wave_max"] + fov.meta["wave_min"]),
                            u.um)
        diameter_um = quantify(self.meta["diameter"], u.m).to(u.um)

        # NOTE(review): 1.22 lambda/D is usually quoted as the radius of the
        # first Airy minimum rather than the FWHM; kept unchanged here.
        fwhm_pix = (1.22 * (wave_mid / diameter_um) * u.rad.to(u.arcsec)
                    / arcsec_per_pix)
        gauss = Gaussian2DKernel(fwhm_pix.value / 2.35, mode="center").array
        return gauss / np.sum(gauss)

    def plot(self):
        """Plot the kernel for the instrument pixel scale and spectral setup."""
        stand_in_fov = PoorMansFOV(from_currsys("!INST.pixel_scale", self.cmds),
                                   from_currsys("!SIM.spectral", self.cmds))
        return super().plot(stand_in_fov)
##############################################################################
# Semi-analytical PSFs - AnisoCADO PSFs
class SemiAnalyticalPSF(PSF):
    """Base class for semi-analytical PSFs such as the AnisoCADO ones.

    Only FieldOfView objects are convolved; the default z_order is ``[42]``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Alternative that was considered: ImagePlaneBase
        self.convolution_classes = FieldOfViewBase
        self.meta["z_order"] = [42]
class AnisocadoConstPSF(SemiAnalyticalPSF):
    """
    Makes a SCAO on-axis PSF with a desired Strehl ratio at a given wavelength.

    To make the PSFs, a map connecting Strehl ratio, wavelength and residual
    wavefront error is required.

    Parameters
    ----------
    filename : str
        Path to Strehl map. As read by :meth:`plot`, the image axes are
        (x, y) = (wavefront error, wavelength).

    strehl : float
        Desired Strehl ratio. Either percentage [1, 100] or fractional
        [1e-3, 1].

    wavelength : float
        [um] The given strehl is valid for this wavelength.

    psf_side_length : int
        [pixel] Default is 512. Side length of the kernel images.

    offset : tuple
        [arcsec] SCAO guide star offset from centre (dx, dy).

    rounded_edges : bool
        Default is True. Sets all halo values below a threshold to zero.
        The threshold is determined from the max values of the edge rows of the
        kernel image.

    Other Parameters
    ----------------
    convolve_mode : str
        ["same", "full"] convolution keywords from scipy.signal.convolve

    Examples
    --------
    Add an AnisocadoConstPSF with code::

        from scopesim.effects import AnisocadoConstPSF
        psf = AnisocadoConstPSF(filename="test_AnisoCADO_rms_map.fits",
                                strehl=0.5,
                                wavelength=2.15,
                                convolve_mode="same",
                                psf_side_length=512)

    Add an AnisocadoConstPSF to a yaml file::

        effects:
        -   name: Ks_Strehl_50_PSF
            description: A 50% Strehl PSF over the field of view
            class: AnisocadoConstPSF
            kwargs:
                filename: "test_AnisoCADO_rms_map.fits"
                strehl: 0.5
                wavelength: 2.15
                convolve_mode: full
                psf_side_length: 512

    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        params = {
            "z_order": [42, 652],
            "psf_side_length": 512,
            "offset": (0, 0),
            "rounded_edges": True,
        }
        self.meta.update(params)
        self.meta.update(kwargs)

        self.required_keys = ["filename", "strehl", "wavelength"]
        check_keys(self.meta, self.required_keys, action="error")
        self.nmRms  # check to see if it throws an error

        self._psf_object = None
        self._kernel = None

    def get_kernel(self, fov):
        """Return the (cached) AnisoCADO kernel.

        Called by ``apply_to`` of the PSF base class.

        Parameters
        ----------
        fov : FieldOfView or float
            A FieldOfView, whose ``CDELT1`` [deg] sets the pixel scale, or the
            pixel scale itself in [arcsec]. Only used when no kernel has been
            made yet.

        Raises
        ------
        TypeError
            If ``fov`` is neither a FieldOfView nor a real number.
        """
        if self._kernel is None:
            if isinstance(fov, FieldOfViewBase):
                pixel_scale = fov.header["CDELT1"] * u.deg.to(u.arcsec)
            elif isinstance(fov, (int, float, np.integer, np.floating)):
                pixel_scale = fov
            else:
                # Previously pixel_scale was left unbound here.
                raise TypeError("Expected a FieldOfView or a pixel scale "
                                f"[arcsec], got {type(fov).__name__}.")

            n = self.meta["psf_side_length"]
            wave = self.wavelength
            self._psf_object = aniso.AnalyticalScaoPsf(pixelSize=pixel_scale,
                                                       N=n, wavelength=wave,
                                                       nmRms=self.nmRms)
            if np.any(self.meta["offset"]):
                self._psf_object.shift_off_axis(self.meta["offset"][0],
                                                self.meta["offset"][1])

            self._kernel = self._psf_object.psf_latest
            self._kernel /= np.sum(self._kernel)
            if self.meta["rounded_edges"]:
                self._kernel = pu.round_kernel_edges(self._kernel)

        return self._kernel

    def remake_kernel(self, x):
        """
        Remake the kernel based on either a pixel_scale or a FieldOfView.

        Parameters
        ----------
        x: float, FieldOfView
            [arcsec] if float

        """
        self._kernel = None
        return self.get_kernel(x)

    @property
    def wavelength(self):
        """Reference wavelength [um]; filter names are resolved to their
        effective wavelength."""
        wave = from_currsys(self.meta["wavelength"], self.cmds)
        if isinstance(wave, str) and wave in tu.FILTER_DEFAULTS:
            wave = tu.get_filter_effective_wavelength(wave)
        wave = quantify(wave, u.um).value

        return wave

    @property
    def strehl_ratio(self):
        """Strehl ratio of the current AnisoCADO PSF object, or None."""
        strehl = None
        if self._psf_object is not None:
            strehl = self._psf_object.strehl_ratio

        return strehl

    @property
    def nmRms(self):
        """Residual wavefront error for ``strehl`` at ``wavelength``, looked
        up in the Strehl map in extension 0 of the file."""
        strehl = from_currsys(self.meta["strehl"], self.cmds)
        wave = self.wavelength
        hdu = self._file[0]
        nm_rms = pu.nmrms_from_strehl_and_wavelength(strehl, wave, hdu)

        return nm_rms

    @staticmethod
    def _imshow_kernel(ax, im, pixel_scale, **kwargs):
        """Show ``im`` on a log scale with arcsec axes labelled at the top."""
        from matplotlib.colors import LogNorm

        # imshow's extent is [left, right, bottom, top]. The image spans
        # shape * pixel_scale in total, i.e. half of that on either side of
        # the centre (the full side length was used before, doubling it).
        half_x = pixel_scale * im.shape[1] / 2
        half_y = pixel_scale * im.shape[0] / 2
        ax.imshow(im, norm=LogNorm(), origin="lower",
                  extent=[-half_x, half_x, -half_y, half_y], **kwargs)
        ax.set_aspect("equal")
        ax.set_xlabel("[arcsec]")
        ax.set_ylabel("[arcsec]")
        ax.xaxis.set_ticks_position("top")
        ax.xaxis.set_label_position("top")

    def plot(self, obj=None, **kwargs):
        """Plot the kernel, a zoom on its core and the Strehl map curves.

        ``obj`` is unused (kept for compatibility with :meth:`PSF.plot`);
        ``kwargs`` are passed on to ``imshow``. Returns the figure.
        """
        fig, gs = figure_grid_factory(
            2, 2, height_ratios=(3, 2),
            left=0.3, right=0.7, bottom=0.15, top=0.85,
            wspace=0.05, hspace=0.05)
        # or no height_ratios and bottom=0.1, top=0.9

        pixel_scale = from_currsys("!INST.pixel_scale", self.cmds)
        kernel = self.get_kernel(pixel_scale)

        # Top left: the full kernel
        ax = fig.add_subplot(gs[0, 0])
        self._imshow_kernel(ax, kernel, pixel_scale, **kwargs)

        # Top right: the central 32 x 32 pixels
        ax = fig.add_subplot(gs[0, 1])
        x = kernel.shape[1] // 2
        y = kernel.shape[0] // 2
        r = 16
        self._imshow_kernel(ax, kernel[y-r:y+r, x-r:x+r], pixel_scale,
                            **kwargs)
        ax.yaxis.set_ticks_position("right")
        ax.yaxis.set_label_position("right")

        # Bottom: Strehl ratio vs. wavefront error, one curve per wavelength
        ax = fig.add_subplot(gs[1, :])
        hdr = self._file[0].header
        data = self._file[0].data

        wfes = np.arange(hdr["NAXIS1"]) * hdr["CDELT1"] + hdr["CRVAL1"]
        waves = np.arange(hdr["NAXIS2"]) * hdr["CDELT2"] + hdr["CRVAL2"]

        # TODO: Get unit dynamically? Then again, it's hardcoded elsewhere in
        #       this module...
        unit_str = u.Unit("um").to_string("latex")

        for strehl, wav in reversed(list(zip(data, waves))):
            ax.plot(wfes, strehl, label=f"{wav:.3f} {unit_str}")

        ax.set_xlabel(f"RMS Wavefront Error [{unit_str}]")
        ax.set_ylabel("Strehl Ratio")
        ax.legend()
        fig.align_labels()

        return fig
################################################################################
# Discrete PSFs - MORFEO and co PSFs
class DiscretePSF(PSF):
    """Base class for PSFs given as discrete kernel images.

    Only FieldOfView objects are convolved; the default z_order is ``[43]``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Alternative that was considered: ImagePlaneBase
        self.convolution_classes = FieldOfViewBase
        self.meta["z_order"] = [43]
class FieldConstantPSF(DiscretePSF):
    """A PSF that is constant across the field.

    For spectroscopy, a wavelength-dependent PSF cube is built, where for each
    wavelength the reference PSF is scaled proportional to wavelength.
    """

    def __init__(self, **kwargs):
        # sub_pixel_flag and flux_accuracy are taken care of in PSF base class
        super().__init__(**kwargs)

        self.required_keys = ["filename"]
        check_keys(self.meta, self.required_keys, action="error")

        self.meta["z_order"] = [262, 662]
        self._waveset, self.kernel_indexes = pu.get_psf_wave_exts(
            self._file, self.meta["wave_key"])
        self.current_layer_id = None
        self.current_ext = None
        self.current_data = None
        self.kernel = None

    def get_kernel(self, fov):
        """Find nearest wavelength and build PSF kernel from file.

        For a 3D field of view (``NAXIS == 3``) a kernel cube is built with
        :meth:`make_psf_cube`. Otherwise the 2D kernel of the nearest
        wavelength extension is normalised, rescaled to the FOV pixel scale
        if needed, and cut out to the FOV size. The result is cached until a
        different extension is selected.
        """
        idx = pu.nearest_index(fov.wavelength, self._waveset)
        ext = self.kernel_indexes[idx]
        if ext != self.current_layer_id:
            if fov.hdu.header["NAXIS"] == 3:
                self.current_layer_id = ext
                self.make_psf_cube(fov)
            else:
                # Work on a float copy: normalising in place modified the
                # data held in self._file (and failed for integer data).
                self.kernel = self._file[ext].data.astype(float)
                self.current_layer_id = ext
                hdr = self._file[ext].header

                self.kernel /= np.sum(self.kernel)

                # compare kernel and fov pixel scales (in deg), rescale if
                # needed
                if "CUNIT1" in hdr:
                    unit_factor = u.Unit(hdr["CUNIT1"].lower()).to(u.deg)
                else:
                    unit_factor = 1

                kernel_pixel_scale = hdr["CDELT1"] * unit_factor
                fov_pixel_scale = fov.header["CDELT1"]

                # Done only when the extension changes, not for every fov.
                pix_ratio = kernel_pixel_scale / fov_pixel_scale
                if abs(pix_ratio - 1) > self.meta["flux_accuracy"]:
                    self.kernel = pu.rescale_kernel(self.kernel, pix_ratio)

                if ((fov.header["NAXIS1"] < hdr["NAXIS1"]) or
                        (fov.header["NAXIS2"] < hdr["NAXIS2"])):
                    self.kernel = pu.cutout_kernel(self.kernel, fov.header,
                                                   kernel_header=hdr)

        return self.kernel

    def make_psf_cube(self, fov):
        """Create a wavelength-dependent psf cube.

        The normalised reference PSF of extension ``self.current_layer_id`` is
        resampled (bilinear interpolation) onto the FOV pixel grid for every
        wavelength ``lam`` of the FOV's third axis, with its pixel scale
        stretched by ``lam / refwave`` and the values rescaled by the pixel
        area ratio. The cube, of shape ``(n_wave, nypsf, nxpsf)`` with
        ``min(512, 2 * n_fov + 1)`` pixels per spatial axis, is stored in
        ``self.kernel``.
        """
        # Some data from the fov
        fov_hdr = fov.hdu.header
        nxfov, nyfov = fov_hdr["NAXIS1"], fov_hdr["NAXIS2"]
        fov_pixel_scale = fov_hdr["CDELT1"]
        fov_pixel_unit = fov_hdr["CUNIT1"].lower()

        # FITS pixel -> world for the spectral axis (pixels are 1-based)
        lam = fov_hdr["CDELT3"] * (1 + np.arange(fov_hdr["NAXIS3"])
                                   - fov_hdr["CRPIX3"]) \
            + fov_hdr["CRVAL3"]

        # adapt the size of the output cube to the FOV's spatial shape
        nxpsf = min(512, 2 * nxfov + 1)
        nypsf = min(512, 2 * nyfov + 1)

        # Some data from the psf file
        ext = self.current_layer_id
        hdr = self._file[ext].header
        refwave = hdr[self.meta["wave_key"]]

        if "CUNIT1" in hdr:
            unit_factor = u.Unit(hdr["CUNIT1"].lower()).to(u.Unit(fov_pixel_unit))
        else:
            unit_factor = 1
        ref_pixel_scale = hdr["CDELT1"] * unit_factor

        psfwcs = WCS(hdr)
        psf = self._file[ext].data
        psf = psf / psf.sum()  # normalisation of the input psf
        # numpy arrays are indexed (row, column) = (y, x). The previous
        # (nx, ny) unpacking only worked for square PSFs.
        nyin, nxin = psf.shape

        # We need linear interpolation to preserve positivity. Might think of
        # more elaborate positivity-preserving schemes.
        # Note: According to some basic profiling, this line is one of the
        #       single largest hits on performance.
        ipsf = RectBivariateSpline(np.arange(nyin), np.arange(nxin), psf,
                                   kx=1, ky=1)

        xcube, ycube = np.meshgrid(np.arange(nxpsf), np.arange(nypsf))
        cubewcs = WCS(naxis=2)
        cubewcs.wcs.ctype = ["LINEAR", "LINEAR"]
        cubewcs.wcs.crval = [0., 0.]
        cubewcs.wcs.crpix = [(nxpsf + 1) / 2, (nypsf + 1) / 2]
        cubewcs.wcs.cdelt = [fov_pixel_scale, fov_pixel_scale]
        cubewcs.wcs.cunit = [fov_pixel_unit, fov_pixel_unit]

        # xcube/ycube are 0-based indices, hence origin=0, matching the
        # origin=0 used for the PSF grid below. origin=1 shifted the kernel
        # centre by one pixel.
        xworld, yworld = cubewcs.all_pix2world(xcube, ycube, 0)
        outcube = np.zeros((lam.shape[0], nypsf, nxpsf), dtype=np.float32)
        for i, wave in enumerate(lam):
            psf_wave_pixscale = ref_pixel_scale * wave / refwave
            psfwcs.wcs.cdelt = [psf_wave_pixscale,
                                psf_wave_pixscale]
            xpsf, ypsf = psfwcs.all_world2pix(xworld, yworld, 0)
            outcube[i,] = (ipsf(ypsf, xpsf, grid=False)
                           * fov_pixel_scale**2 / psf_wave_pixscale**2)

        self.kernel = outcube

    def plot(self):
        """Plot the kernel for the instrument pixel scale and spectral setup."""
        pixel_scale = from_currsys("!INST.pixel_scale", self.cmds)
        spec_dict = from_currsys("!SIM.spectral", self.cmds)
        return super().plot(PoorMansFOV(pixel_scale, spec_dict))
class FieldVaryingPSF(DiscretePSF):
    """
    A PSF that varies across the field of view.

    The file holds kernel cubes per wavelength, plus an extension (named by
    the ``ECAT`` keyword of the primary header) that maps positions to kernel
    layer ids.

    Parameters
    ----------
    sub_pixel_flag : bool, optional
    flux_accuracy : float, optional
        Default 1e-3. Level of flux conservation during rescaling of kernel
    """

    def __init__(self, **kwargs):
        # sub_pixel_flag and flux_accuracy are taken care of in PSF base class
        super().__init__(**kwargs)

        self.required_keys = ["filename"]
        check_keys(self.meta, self.required_keys, action="error")

        self.meta["z_order"] = [261, 661]

        ws, ki = pu.get_psf_wave_exts(self._file, self.meta["wave_key"])
        self._waveset, self.kernel_indexes = ws, ki
        self.current_ext = None
        self.current_data = None
        self._strehl_imagehdu = None

    def apply_to(self, fov, **kwargs):
        """See parent docstring.

        The FOV image is convolved (mode="same") with each kernel from
        :meth:`get_kernel`. Where several kernels apply, each result is
        weighted by its convolved mask and the weighted images are summed.
        Non-FieldOfView objects and FOVs without fields are returned as is.
        """
        # .. todo: add in field rotation
        # .. todo: add in 3D cubes
        # accept "full", "dit", "none"

        # check if there are any fov.fields to apply a psf to
        if not isinstance(fov, FieldOfViewBase) or len(fov.fields) == 0:
            return fov

        if fov.image is None:
            fov.image = fov.make_image_hdu().data

        old_shape = fov.image.shape
        # Loop-invariant: convert the image once, not once per kernel.
        image = fov.image.astype(float)

        # Get kernels that cover this fov, and their respective masks.
        # Kernels and masks are returned by .get_kernel as [kernel, mask].
        canvas = None
        kernels_masks = self.get_kernel(fov)
        for kernel, mask in kernels_masks:

            # renormalise the kernel if needs be
            kernel[kernel < 0.] = 0.
            sum_kernel = np.sum(kernel)
            if abs(sum_kernel - 1) > self.meta["flux_accuracy"]:
                kernel /= sum_kernel

            # image convolution
            kernel = kernel.astype(float)
            new_image = convolve(image, kernel, mode="same")
            if canvas is None:
                canvas = np.zeros(new_image.shape)

            # mask convolution + combine with convolved image
            if mask is not None:
                new_mask = convolve(mask, kernel, mode="same")
                canvas += new_image * new_mask
            else:
                canvas = new_image

        if canvas is None:
            # No kernels: leave the fov untouched.
            return fov

        # reset WCS header info
        new_shape = canvas.shape
        fov.image = canvas

        # Numpy axis -1 is x (CRPIX1), axis -2 is y (CRPIX2). With
        # mode="same" the shape is unchanged and the shifts are 0.
        d_x = new_shape[-1] - old_shape[-1]
        d_y = new_shape[-2] - old_shape[-2]
        for wcsid in ["", "D"]:
            if "CRPIX1" + wcsid in fov.header:
                fov.header["CRPIX1" + wcsid] += d_x / 2
                fov.header["CRPIX2" + wcsid] += d_y / 2

        return fov

    def get_kernel(self, fov):
        """Return ``[kernel, mask]`` pairs for the PSF layers inside ``fov``.

        0. pick the file extension nearest to the fov's mean wavelength
        1. cut the layer-id ("strehl") map out for the fov header
        2. find the unique layer ids in the cutout
        3. pull out those kernels (as float copies)
        4. if more than one, make a boolean mask per layer id; otherwise the
           mask is None
        5. rescale the kernels to the fov pixel scale if needed, normalise

        The list is also stored in ``self.kernel``.
        """
        # find which file extension to use - keep pointer in self.current_data
        fov_wave = 0.5 * (fov.meta["wave_min"] + fov.meta["wave_max"])
        jj = pu.nearest_index(fov_wave, self._waveset)
        ext = self.kernel_indexes[jj]
        if ext != self.current_ext:
            self.current_ext = ext
            self.current_data = self._file[ext].data

        # compare the fov and psf pixel scales (both in deg), converting the
        # kernel scale like FieldConstantPSF does
        hdr = self._file[ext].header
        if "CUNIT1" in hdr:
            unit_factor = u.Unit(hdr["CUNIT1"].lower()).to(u.deg)
        else:
            unit_factor = 1
        kernel_pixel_scale = hdr["CDELT1"] * unit_factor
        fov_pixel_scale = fov.header["CDELT1"]

        # get the spatial map of the kernel cube layers
        strl_hdu = self.strehl_imagehdu
        strl_cutout = pu.get_strehl_cutout(fov.header, strl_hdu)

        # Round before taking unique values: rounding afterwards could give
        # duplicate ids (e.g. 0.9999 and 1.0001) and apply a kernel twice.
        # The masks use the same rounded map.
        layer_map = np.round(strl_cutout.data).astype(int)
        layer_ids = np.unique(layer_map)

        # Float copies, so the in-place edits here and in apply_to() never
        # modify the data held in self._file.
        if len(layer_ids) > 1:
            self.kernel = [[self.current_data[ii].astype(float),
                            layer_map == ii] for ii in layer_ids]
        else:
            self.kernel = [[self.current_data[layer_ids[0]].astype(float),
                            None]]

        # .. todo: can this be put somewhere else to save on iterations?
        # .. todo: should the mask also be rescaled?
        # rescale the kernels to the fov pixel scale; same ratio convention
        # as FieldConstantPSF (the inverse ratio was used here before)
        pix_ratio = kernel_pixel_scale / fov_pixel_scale
        if abs(pix_ratio - 1) > self.meta["flux_accuracy"]:
            for pair in self.kernel:
                pair[0] = pu.rescale_kernel(pair[0], pix_ratio)

        for pair in self.kernel:
            pair[0] = pair[0] / np.sum(pair[0])

        return self.kernel

    @property
    def strehl_imagehdu(self):
        """The HDU containing the positional info for kernel layers.

        Taken from the extension named by ``ECAT`` in the primary header: an
        ImageHDU is used as is, a BinTableHDU is converted with
        ``pu.make_strehl_map_from_table``. Any other type leaves it None.
        """
        if self._strehl_imagehdu is None:
            ecat = self._file[0].header["ECAT"]
            if isinstance(self._file[ecat], fits.ImageHDU):
                self._strehl_imagehdu = self._file[ecat]

            # ..todo: implement this case
            elif isinstance(self._file[ecat], fits.BinTableHDU):
                cat = self._file[ecat]
                self._strehl_imagehdu = pu.make_strehl_map_from_table(cat)

        return self._strehl_imagehdu

    def plot(self):
        """Plot the kernel for the instrument pixel scale and spectral setup."""
        pixel_scale = from_currsys("!INST.pixel_scale", self.cmds)
        spec_dict = from_currsys("!SIM.spectral", self.cmds)
        return super().plot(PoorMansFOV(pixel_scale, spec_dict))