from copy import deepcopy
from itertools import product
from more_itertools import pairwise
import numpy as np
from astropy import units as u
from . import image_plane_utils as imp_utils
from .fov import FieldOfView
from .. import effects as efs
from ..effects.effects_utils import get_all_effects
from ..utils import check_keys, get_logger
# TODO: Where are all these functions used??
# Module-level logger, obtained from ..utils.get_logger for this module's name.
logger = get_logger(__name__)
def get_3d_shifts(effects, **kwargs):
    """
    Combine the wavelength-dependent shifts of all ``Shift3D`` effects.

    The shift curves of every ``Shift3D`` effect are interpolated onto a
    common wavelength grid and summed. That grid is then thinned down to the
    first and last wavelength plus those at which the total (radial) shift
    crosses a multiple of ``pixel_scale * sub_pixel_fraction``.

    Parameters
    ----------
    effects : list of Effect objects
        Only the ``Shift3D`` effects among them are used.
    **kwargs
        Expected keys: ``wave_min``, ``wave_mid``, ``wave_max``,
        ``sub_pixel_fraction``, ``pixel_scale``. They are checked with
        ``action="warning"`` only. However, ``pixel_scale`` and
        ``sub_pixel_fraction`` are indexed directly when Shift3D effects are
        present, and ``wave_min``/``wave_max`` when none are. All kwargs are
        forwarded to each effect's ``fov_grid``.

    Returns
    -------
    shift_dict : dict
        ``"wavelengths"``: edge wavelengths of the spectral layers [um].
        ``"x_shifts"``, ``"y_shifts"``: total shifts at those wavelengths
        [deg].

    Notes
    -----
    ``fov_grid(which="shifts")`` is expected to return
    ``[wavelengths, x_shifts, y_shifts]`` with the shifts in [arcsec]. They
    are converted to [deg] before being returned.
    """
    check_keys(kwargs,
               ["wave_min", "wave_mid", "wave_max",
                "sub_pixel_fraction", "pixel_scale"],
               action="warning")

    shift_effects = get_all_effects(effects, efs.Shift3D)
    if len(shift_effects) == 0:
        # No Shift3D effects: a single, unshifted layer over the full range
        edges = np.array([kwargs["wave_min"], kwargs["wave_max"]])
        total_dx = np.zeros(2)
        total_dy = np.zeros(2)
    else:
        # Each grid is [waves, x_shifts, y_shifts], shifts in [arcsec]
        grids = [eff.fov_grid(which="shifts", **kwargs)
                 for eff in shift_effects]

        # TODO: could this use combine_wavesets?
        # np.unique already returns its values sorted.
        common_waves = np.unique(np.concatenate(
            [grid[0] for grid in grids if len(grid[0]) >= 2]))

        total_dx = np.zeros(len(common_waves))
        total_dy = np.zeros(len(common_waves))
        # .. todo:: replace the 1e-7 with a variable in !SIM
        tolerance = 1e-7
        for grid in grids:
            # Skip curves that are zero everywhere (within tolerance)
            if not np.all(np.abs(grid[1]) < tolerance):
                total_dx += np.interp(common_waves, grid[0], grid[1])
            if not np.all(np.abs(grid[2]) < tolerance):
                total_dy += np.interp(common_waves, grid[0], grid[2])

        # Whole number of sub-pixel steps covered by the radial shift
        radial = (total_dx**2 + total_dy**2)**0.5  # in arcsec
        step_size = kwargs["pixel_scale"] * kwargs["sub_pixel_fraction"]
        n_steps = (radial / step_size).astype(int)
        # Indices right before the step count changes
        change_idx = np.flatnonzero(np.diff(n_steps))

        def _thin(values):
            # Keep the first point, every change point and the last point
            return np.array([values[0]] + list(values[change_idx])
                            + [values[-1]])

        total_dx = _thin(total_dx)
        total_dy = _thin(total_dy)
        edges = _thin(np.copy(common_waves))

    return {"wavelengths": edges,
            "x_shifts": total_dx / 3600.,  # [arcsec] -> [deg]
            "y_shifts": total_dy / 3600.}
def get_imaging_waveset(effects_list, **kwargs):
    """
    Return the edge wavelengths for the spectral layers needed for simulation.

    The usable range is first narrowed to the overlap of all filter
    (``FilterCurve``/``FilterWheel``) ranges. Any extra bin edges requested
    by ``PSF`` effects are then merged in.

    Parameters
    ----------
    effects_list : list of Effect objects
    **kwargs
        Must contain ``wave_min`` and ``wave_max`` [um]. All kwargs are
        forwarded to each effect's ``fov_grid``. When filters are present,
        ``wave_min``/``wave_max`` are replaced by the filter overlap before
        the PSF effects are queried.

    Returns
    -------
    wave_bin_edges : list
        [um] list of wavelengths

    Raises
    ------
    ValueError
        If the filter ranges do not overlap, or, without filters, if
        ``wave_min`` is larger than ``wave_max``.
    """
    required_keys = ["wave_min", "wave_max"]
    check_keys(kwargs, required_keys, action="error")

    # Get the filter wavelengths first to set (wave_min, wave_max): the
    # usable range is the intersection of all filter ranges.
    filters = get_all_effects(effects_list, (efs.FilterCurve, efs.FilterWheel))
    filter_ranges = [filt.fov_grid(which="waveset", **kwargs)
                     for filt in filters]
    wave_bin_edges = []
    if filter_ranges:
        kwargs["wave_min"] = max(wave[0].value for wave in filter_ranges)
        kwargs["wave_max"] = min(wave[1].value for wave in filter_ranges)
        wave_bin_edges.append([kwargs["wave_min"], kwargs["wave_max"]])

    if kwargs["wave_min"] > kwargs["wave_max"]:
        # Report the individual filter ranges, not the (empty) intersection
        if filter_ranges:
            raise ValueError("Filter wavelength ranges do not overlap: "
                             f"{filter_ranges}.")
        raise ValueError(f"wave_min ({kwargs['wave_min']}) is larger than "
                         f"wave_max ({kwargs['wave_max']}).")

    # ..todo: add in Atmospheric dispersion and ADC here
    for effect_class in [efs.PSF]:
        for eff in get_all_effects(effects_list, effect_class):
            waveset = eff.fov_grid(which="waveset", **kwargs)
            if waveset is not None:
                wave_bin_edges.append(waveset)

    wave_bin_edges = combine_wavesets(*wave_bin_edges)
    if not wave_bin_edges:
        # Reached when neither a filter nor a PSF contributed any
        # wavelengths (e.g. no such effects): fall back to the requested
        # range, which is unmodified in that case.
        wave_bin_edges = [kwargs["wave_min"], kwargs["wave_max"]]

    return wave_bin_edges
def get_imaging_fovs(headers, waveset, shifts, **kwargs):
    """
    Return a generator of ``FieldOfView`` objects.

    One FOV is produced for every combination of spectral layer (pair of
    consecutive wavelengths in ``waveset``) and header.

    Parameters
    ----------
    headers : iterable of fits.Header objects
        Headers giving spatial extent of each FOV region
    waveset : iterable of floats
        [um] N+1 wavelengths for N spectral layers. May be a one-shot
        iterator; it is materialised before use.
    shifts : dict
        Same keys as the dict returned by ``get_3d_shifts``:
        ``"wavelengths"`` [um], ``"x_shifts"`` and ``"y_shifts"`` [deg].
        Shift wavelengths strictly inside the waveset range are added as
        extra layer edges. The shifts are linearly interpolated at each
        layer's central wavelength and added to the header's CRVAL1/CRVAL2.
    **kwargs
        Passed on to ``FieldOfView``.

    Yields
    ------
    fov : ``FieldOfView``
    """
    # Materialise the inputs first: min()/max() below and the log message
    # would otherwise exhaust one-shot iterators before any FOV is built.
    waveset = list(waveset)
    headers = list(headers)

    # Ensure array for later indexing
    shift_waves = np.array(shifts["wavelengths"])  # in [um]
    shift_dx = shifts["x_shifts"]  # in [deg]
    shift_dy = shifts["y_shifts"]

    # combine the wavelength bins from 1D spectral effects and 3D shift effects
    if shift_waves.size:
        inside = (shift_waves > min(waveset)) & (shift_waves < max(waveset))
        waveset = list(combine_wavesets(waveset, shift_waves[inside]))

    logger.info("Preparing %d FieldOfViews", (len(waveset) - 1) * len(headers))

    combos = product(pairwise(waveset), headers)
    for fov_id, ((wave_min, wave_max), hdr) in enumerate(combos):
        # add any pre-instrument shifts to the FOV sky coords
        wave_mid = 0.5 * (wave_min + wave_max)
        x_shift = np.interp(wave_mid, shift_waves, shift_dx)
        y_shift = np.interp(wave_mid, shift_waves, shift_dy)

        fov_hdr = deepcopy(hdr)
        fov_hdr["CRVAL1"] += x_shift  # headers are in [deg]
        fov_hdr["CRVAL2"] += y_shift

        # define the wavelength range for the FOV
        waverange = [wave_min, wave_max]
        yield FieldOfView(fov_hdr, waverange, id=fov_id, **kwargs)

    # TODO: check that each header is not larger than chunk_size
    #       that's already done in get_imaging_headers, isn't it?
def get_spectroscopy_fovs(headers, shifts, effects=None, **kwargs):
    """
    Return a generator of ``FieldOfView`` objects, one per header.

    Parameters
    ----------
    headers : sequence of fits.Header objects
        Must support ``len()``. Each header is read for the keywords
        WAVE_MID, WAVE_MIN, WAVE_MAX, ROTANGD, SKEWANGD, IMG_CONS, APERTURE,
        CRVAL1 and CRVAL2.
    shifts : dict
        Same keys as the dict returned by ``get_3d_shifts``:
        ``"wavelengths"`` [um], ``"x_shifts"`` and ``"y_shifts"`` [deg].
        The shifts are interpolated at each header's WAVE_MID and added to
        its CRVAL1/CRVAL2.
    effects : list of Effect objects, optional
        ``ApertureList``/``ApertureMask`` effects are queried for masks.
        Those masks are not yet applied to the FOVs.
    **kwargs
        Passed on to ``FieldOfView``.

    Yields
    ------
    fov : ``FieldOfView``
        Its ``meta`` gets "distortion" (rotation, shear), "conserve_image",
        "fov_id" and "aperture_id" filled in from the header.
    """
    effects = [] if effects is None else effects

    wave_nodes = shifts["wavelengths"]  # in [um]
    dx_nodes = shifts["x_shifts"]       # in [deg]
    dy_nodes = shifts["y_shifts"]

    logger.info("Preparing %d FieldOfViews", len(headers))

    # Collect aperture masks: dict results are merged as-is, bare arrays get
    # the next integer key.
    # .. todo: get these masks working
    #    there needs to be fov_grid(which="mask") in ApertureList/Mask
    #    The collected masks are currently not used below.
    aperture_classes = (efs.ApertureList, efs.ApertureMask)
    found_masks = [aperture.fov_grid(which="masks")
                   for aperture in get_all_effects(effects, aperture_classes)]
    aperture_masks = {}
    for found in found_masks:
        if isinstance(found, dict):
            aperture_masks.update(found)
        elif isinstance(found, np.ndarray):
            aperture_masks[len(aperture_masks)] = found

    for fov_id, hdr in enumerate(headers):
        # Pre-instrument shift of the sky coords at this layer's centre
        wave_mid = hdr["WAVE_MID"]
        fov_hdr = deepcopy(hdr)
        fov_hdr["CRVAL1"] += np.interp(wave_mid, wave_nodes, dx_nodes)  # [deg]
        fov_hdr["CRVAL2"] += np.interp(wave_mid, wave_nodes, dy_nodes)

        fov = FieldOfView(fov_hdr,
                          waverange=[hdr["WAVE_MIN"], hdr["WAVE_MAX"]],
                          **kwargs)
        distortion = fov.meta["distortion"]
        distortion["rotation"] = hdr["ROTANGD"]
        distortion["shear"] = hdr["SKEWANGD"]
        fov.meta["conserve_image"] = hdr["IMG_CONS"]
        # TODO: get_imaging_fovs passes the id via the constructor (id=...).
        #       What's the difference?
        fov.meta["fov_id"] = fov_id
        fov.meta["aperture_id"] = hdr["APERTURE"]
        # fov.mask = aperture_masks[hdr["APERTURE"]]
        yield fov
# FIXME: This function doesn't seem to be covered by any separate unit test.
def combine_wavesets(*wavesets, decimals=7):
    """
    Join and sort several sets of wavelengths into a single sorted list.

    Values are rounded to ``decimals`` decimal places before duplicates are
    removed, so wavelengths that differ only beyond that precision are merged.

    Parameters
    ----------
    *wavesets : iterables of float or ``astropy.units.Quantity``
        Groups of wavelengths. Quantities are rounded and then stripped of
        their unit. A scalar is treated as a one-element waveset.
    decimals : int, optional
        Number of decimal places to round to before de-duplication.
        Default is 7.

    Returns
    -------
    wave_set : list
        Sorted, de-duplicated wavelengths. The list is empty if no
        wavelengths were given.

    Note
    ----
    This assumes that all wavesets are given in the same unit! Units are
    not converted.
    """
    # TODO: take the default for ``decimals`` from !SIM.computing

    def _get_waves(waves):
        for wave in waves:
            if isinstance(wave, u.Quantity):
                round_wave = wave.round(decimals).value
            else:
                round_wave = np.round(wave, decimals)
            # atleast_1d lets scalar wavesets pass through ``yield from``
            yield from np.atleast_1d(round_wave)

    # NOTE: This function previously used np.sort(wave_set, kind="stable").
    #       If any issues occur with the builtin sorted, go back to that!
    return sorted(set(_get_waves(wavesets)))