# -*- coding: utf-8 -*-
import copy
import multiprocessing as mp
from datetime import datetime
from pathlib import Path

import numpy as np
from scipy.interpolate import interp1d
from astropy import units as u
from tqdm import tqdm
from synphot import SourceSpectrum, Empirical1D
from synphot.units import PHOTLAM

from .optics_manager import OpticsManager
from .fov_manager import FOVManager
from .image_plane import ImagePlane
from ..commands.user_commands import UserCommands
from ..detector import DetectorArray
from ..effects import ExtraFitsKeywords
from ..utils import from_currsys, top_level_catch, get_logger
from .. import rc, __version__

logger = get_logger(__name__)
# Number of worker processes for the optional multiprocessing path in
# ``OpticalTrain.observe``: leave one core free, but never go below one,
# because ``multiprocessing.Pool(processes=0)`` raises ValueError.
N_PROCESSES = max(1, mp.cpu_count() - 1)
# Off by default; set to True to use the multiprocessing path in
# ``OpticalTrain.observe`` instead of the single-core loop.
USE_MULTIPROCESSING = False
def view_fov(fov, hdu_type):
    """Call ``fov.view(hdu_type)`` and hand the same FOV object back.

    Module-level helper so it can be passed to ``multiprocessing.Pool``
    (``starmap``) as a worker callable. The return value of ``fov.view``
    itself is discarded.
    """
    target = fov
    target.view(hdu_type)
    return target
def apply_fov_effects(fov, fov_effects):
    """Thread ``fov`` through every effect in ``fov_effects``, in order.

    Each effect's ``apply_to`` receives the result of the previous one; the
    final result is returned. With no effects, ``fov`` is returned unchanged.
    Module-level so it can be used as a ``multiprocessing.Pool`` worker.
    """
    result = fov
    for fov_effect in fov_effects:
        result = fov_effect.apply_to(result)
    return result
class OpticalTrain:
    """
    The main class for controlling a simulation.

    Parameters
    ----------
    cmds : UserCommands, str
        If the name of an instrument is passed, OpticalTrain tries to find the
        instrument package, and internally creates the UserCommands object

    Examples
    --------
    Create an optical train::

        >>> import scopesim as sim
        >>> cmd = sim.UserCommands("MICADO")
        >>> opt = sim.OpticalTrain(cmd)

    Observe a Source object::

        >>> src = sim.source.source_templates.empty_sky()
        >>> opt.observe(src)
        >>> hdus = opt.readout()

    List the effects modelled in an OpticalTrain::

        >>> print(opt.effects)

    Effects can be accessed by using the name of the effect::

        >>> print(opt["dark_current"])

    To include or exclude an effect during a simulation run, use the
    ``.include`` attribute of the effect::

        >>> opt["dark_current"].include = False

    Data used by an Effect object is contained in the ``.data`` attribute,
    while other information is contained in the ``.meta`` attribute::

        >>> opt["dark_current"].data
        >>> opt["dark_current"].meta

    Meta data values can be set by either using the ``.meta`` attribute
    directly::

        >>> opt["dark_current"].meta["value"] = 0.5

    or by passing a dictionary (with one or multiple entries) to the
    OpticalTrain object::

        >>> opt["dark_current"] = {"value": 0.75, "dit": 30}

    """

    @top_level_catch
    def __init__(self, cmds=None):
        self.cmds = cmds
        self._description = self.__repr__()
        self.optics_manager = None
        self.fov_manager = None
        self.image_planes = []
        self.detector_arrays = []
        self.yaml_dicts = None
        self._last_source = None
        # Filled in by ``observe``; initialised here so the attribute always
        # exists, mirroring ``_last_source``.
        self._last_fovs = None

        if cmds is not None:
            self.load(cmds)

    def load(self, user_commands):
        """
        (Re)Load an OpticalTrain with a new set of UserCommands.

        Parameters
        ----------
        user_commands : UserCommands or str
            An instrument name (a new UserCommands is built from it) or a
            UserCommands object (a deep copy is stored, so the caller's
            object is not modified).

        Raises
        ------
        ValueError
            If ``user_commands`` is neither a str nor a UserCommands.

        """
        if isinstance(user_commands, str):
            user_commands = UserCommands(use_instrument=user_commands)
        elif isinstance(user_commands, UserCommands):
            user_commands = copy.deepcopy(user_commands)
        else:
            raise ValueError("user_commands must be a UserCommands or str object "
                             f"but is {type(user_commands)}")

        self.cmds = user_commands
        # FIXME: Setting rc.__currsys__ to user_commands causes many problems:
        #        UserCommands uses NestedMapping internally, but is itself not
        #        an instance or subclass thereof. So rc.__currsys__ actually
        #        changes type as a result of this line. On one hand, some other
        #        code relies on this change, i.e. uses attributes from
        #        UserCommands via rc.__currsys__, but on the other hand some
        #        tests (now with proper patching) fail because of this type
        #        change. THIS IS A PROBLEM!
        rc.__currsys__ = user_commands
        self.yaml_dicts = self.cmds.yaml_dicts
        self.optics_manager = OpticsManager(self.yaml_dicts, self.cmds)
        self.update()

    def update(self, **kwargs):
        """
        Update the user-defined parameters and remake main internal classes.

        Rebuilds the FOV manager, the image planes (one per image plane
        header of the optics manager) and the detector arrays (one per
        detector setup effect list).

        Parameters
        ----------
        kwargs : expanded dict
            Any keyword-value pairs from a config file

        """
        self.optics_manager.update(**kwargs)
        opt_man = self.optics_manager

        self.fov_manager = FOVManager(opt_man.fov_setup_effects, cmds=self.cmds,
                                      **kwargs)
        self.image_planes = [ImagePlane(hdr, **kwargs)
                             for hdr in opt_man.image_plane_headers]
        self.detector_arrays = [DetectorArray(det_list, cmds=self.cmds, **kwargs)
                                for det_list in opt_man.detector_setup_effects]

    @top_level_catch
    def observe(self, orig_source, update=True, **kwargs):
        """
        Main controlling method for observing ``Source`` objects.

        Parameters
        ----------
        orig_source : Source
            Not modified; a copy is made via ``make_copy``.
        update : bool
            Reload optical system
        kwargs : expanded dict
            Any keyword-value pairs from a config file

        Notes
        -----
        How the list of Effects is split between the 5 main tasks:

        - Make a FOV list - z_order = 0..99
        - Make a image plane - z_order = 100..199
        - Apply Source altering effects - z_order = 200..299
        - Apply FOV specific (3D) effects - z_order = 300..399
        - Apply FOV-independent (2D) effects - z_order = 400..499
        - [Apply detector plane (0D, 2D) effects - z_order = 500..599]

        .. todo:: List is out of date - update

        """
        if update:
            self.update(**kwargs)

        # self.set_focus(**kwargs)    # put focus back on current instrument package

        # Make a copy of the Source and prepare for observation (convert to
        # internally used units, sample to internal wavelength grid)
        source = orig_source.make_copy()
        source = self.prepare_source(source)

        # [1D - transmission curves]
        for effect in self.optics_manager.source_effects:
            source = effect.apply_to(source)

        # [3D - Atmospheric shifts, PSF, NCPAs, Grating shift/distortion]
        fovs = self.fov_manager.fovs
        fov_effects = self.optics_manager.fov_effects
        hdu_type = "cube" if self.fov_manager.is_spectroscope else "image"

        if USE_MULTIPROCESSING:
            # One pool runs the whole per-FOV pipeline. max(1, ...) guards
            # against a pool size of 0 on single-core machines.
            with mp.Pool(processes=max(1, N_PROCESSES)) as pool:
                fovs = pool.starmap(
                    self._observe_single_fov,
                    [(fov, source, hdu_type, fov_effects) for fov in fovs])
            # Workers operate on pickled copies, so the processed FOVs must be
            # added to the image planes here, in the parent process.
            for fov in fovs:
                self.image_planes[fov.image_plane_id].add(fov.hdu,
                                                          wcs_suffix="D")
        else:
            for fov in tqdm(fovs, desc=" FOVs", position=0):
                # .. todo: possible bug with bg flux not using plate_scale
                #          see fov_utils.combine_imagehdu_fields
                fov.extract_from(source)
                fov.view(hdu_type)
                for effect in tqdm(fov_effects, desc=" FOV effects",
                                   position=1, leave=False):
                    fov = effect.apply_to(fov)

                fov.flatten()
                self.image_planes[fov.image_plane_id].add(fov.hdu,
                                                          wcs_suffix="D")
                # ..todo: finish off the multiple image plane stuff

        # [2D - Vibration, flat fielding, chopping+nodding]
        for effect in tqdm(self.optics_manager.image_plane_effects,
                           desc=" Image Plane effects"):
            for ii, image_plane in enumerate(self.image_planes):
                self.image_planes[ii] = effect.apply_to(image_plane)

        self._last_fovs = fovs
        self._last_source = source

    @staticmethod
    def _observe_single_fov(fov, source, hdu_type, fov_effects):
        """Extract, view, apply FOV effects to and flatten a single FOV.

        Worker for the multiprocessing path of ``observe``. It is a
        staticmethod (no ``self``) so that ``Pool.starmap`` only has to
        pickle the function reference and its arguments.
        """
        fov.extract_from(source)
        fov.view(hdu_type)
        for effect in fov_effects:
            fov = effect.apply_to(fov)
        fov.flatten()
        return fov

    def _get_fov_wave_grid(self):
        """Return ``(wave_min, wave_unit, dwave, fov_waveset)`` for all FOVs.

        ``wave_min``/``wave_max`` are the extremes over all FOVs of the FOV
        manager. The grid step ``dwave`` is ``!SIM.spectral.spectral_bin_width``
        (a plain number, applied in units of ``!SIM.spectral.wave_unit``).
        ``fov_waveset`` is returned converted to um.
        """
        fovs = self.fov_manager.fovs
        wave_min = min(fov.meta["wave_min"] for fov in fovs)
        wave_max = max(fov.meta["wave_max"] for fov in fovs)
        wave_unit = u.Unit(from_currsys("!SIM.spectral.wave_unit", self.cmds))
        dwave = from_currsys("!SIM.spectral.spectral_bin_width", self.cmds)  # Not a quantity
        # NOTE(review): assumes wave_min/wave_max are Quantities whose values
        # are already expressed in wave_unit (only ``.value`` is used).
        fov_waveset = np.arange(wave_min.value, wave_max.value, dwave) * wave_unit
        return wave_min, wave_unit, dwave, fov_waveset.to(u.um)

    def prepare_source(self, source):
        """
        Prepare source for observation.

        The source data are converted to internally used units (PHOTLAM;
        for cube fields PHOTLAM per arcsec2).

        The source data are interpolated to the waveset used by the FieldOfView
        This is necessary when the source data are sampled on a coarser grid
        than used internally, or if the source data are sampled on irregular
        wavelengths.
        For cube fields, the method assumes that the wavelengths at which the
        cube is sampled is provided explicitly as attribute `wave` of the cube
        ImageHDU.

        Spectra are resampled onto the FOV wavelength grid; cube fields are
        unit-converted, resampled and get their WCS rewritten (spatial axes in
        deg, spectral axis describing the new grid). The (copied) source is
        modified in place and returned.
        """
        # The FOV wavelength grid is identical for all spectra and cubes, so
        # it is computed once, on first use (it requires at least one FOV).
        wave_grid = None

        # Convert to PHOTLAM per arcsec2
        # ..todo: this is not sufficiently general

        for ispec, spec in enumerate(source.spectra):
            # Put on fov wavegrid
            if wave_grid is None:
                wave_grid = self._get_fov_wave_grid()
            fov_waveset = wave_grid[3]
            source.spectra[ispec] = SourceSpectrum(Empirical1D,
                                                   points=fov_waveset,
                                                   lookup_table=spec(fov_waveset))

        for cube in source.cube_fields:
            header, data, wave = cube.header, cube.data, cube.wave

            # Need to check whether BUNIT is per arcsec2 or per pixel
            inunit = u.Unit(header["BUNIT"])
            data = data.astype(np.float32) * inunit
            factor = 1
            for base, power in zip(inunit.bases, inunit.powers):
                if (base**power).is_equivalent(u.arcsec**(-2)):
                    # Strip the spatial part (re-applied via ``factor`` below)
                    conversion = (base**power).to(u.arcsec**(-2)) / base**power
                    data *= conversion
                    factor = u.arcsec**(-2)

            data = data.to(PHOTLAM,
                           equivalencies=u.spectral_density(wave[:, None, None]))

            if factor == 1:    # Normalise to 1 arcsec2 if not a spatial density
                # ..todo: lower needed because "DEG" is not understood, this is ugly
                pixarea = (header["CDELT1"] * u.Unit(header["CUNIT1"].lower()) *
                           header["CDELT2"] * u.Unit(header["CUNIT2"].lower())).to(u.arcsec**2)
                data = data / pixarea.value    # cube is per arcsec2

            data = (data * factor).value

            cube.header["BUNIT"] = "PHOTLAM/arcsec2"    # ..todo: make this more explicit?

            # The imageplane_utils like to have the spatial WCS in units of "deg". Ensure
            # that the cube is passed on accordingly
            cube.header["CDELT1"] = header["CDELT1"] * u.Unit(header["CUNIT1"].lower()).to(u.deg)
            cube.header["CDELT2"] = header["CDELT2"] * u.Unit(header["CUNIT2"].lower()).to(u.deg)
            cube.header["CUNIT1"] = "deg"
            cube.header["CUNIT2"] = "deg"

            # Put on fov wavegrid
            if wave_grid is None:
                wave_grid = self._get_fov_wave_grid()
            wave_min, wave_unit, dwave, fov_waveset = wave_grid

            # Interpolate into new data cube.
            # This is done layer by layer for memory reasons.
            new_data = np.zeros((fov_waveset.shape[0], data.shape[1], data.shape[2]),
                                dtype=np.float32)
            for j in range(data.shape[1]):
                cube_interp = interp1d(wave.to(u.um).value, data[:, j, :],
                                       axis=0, kind="linear",
                                       bounds_error=False, fill_value=0)
                new_data[:, j, :] = cube_interp(fov_waveset.value)

            cube.data = new_data
            cube.header["CTYPE3"] = "WAVE"
            cube.header["CRPIX3"] = 1
            cube.header["CRVAL3"] = wave_min.value
            cube.header["CDELT3"] = dwave
            cube.header["CUNIT3"] = wave_unit.name

        return source

    @top_level_catch
    def readout(self, filename=None, **kwargs):
        """
        Produce detector readouts for the observed image.

        Parameters
        ----------
        filename : str, optional
            Where to save the FITS file. If there is more than one detector
            array, the array index is prepended to the file *name* (not to a
            directory part), e.g. ``out/sim.fits`` -> ``out/0_sim.fits``.
            Non-str values are ignored (nothing is written).
        kwargs
            Passed on to each ``DetectorArray.readout``.

        Returns
        -------
        hduls : list of fits.HDUList
            One HDUList per detector array.

        Notes
        -----
        - Apply detector plane (0D, 2D) effects - z_order = 500..599

        """
        # These lookups do not depend on the detector array, so do them once.
        array_effects = self.optics_manager.detector_array_effects
        dtcr_effects = self.optics_manager.detector_effects
        fits_effects = self.optics_manager.get_all(ExtraFitsKeywords)
        prefix_index = len(self.detector_arrays) > 1

        hduls = []
        for i, detector_array in enumerate(self.detector_arrays):
            hdul = detector_array.readout(self.image_planes, array_effects,
                                          dtcr_effects, **kwargs)

            if len(fits_effects) > 0:
                for effect in fits_effects:
                    hdul = effect.apply_to(hdul, optical_train=self)
            else:
                # NOTE(review): confirm ``write_header`` is defined on this
                # class; any failure here is only logged, not raised.
                try:
                    hdul = self.write_header(hdul)
                except Exception:
                    logger.exception("Header update failed, data will be "
                                     "saved with incomplete header. See stack "
                                     "trace for details.")

            if isinstance(filename, str):
                fname = filename
                if prefix_index:
                    path = Path(filename)
                    fname = str(path.with_name(f"{i}_{path.name}"))
                hdul.writeto(fname, overwrite=True)

            hduls.append(hdul)

        return hduls

    # def set_focus(self, **kwargs):
    #     self.cmds.update(**kwargs)
    #     dy = self.cmds.default_yamls
    #     if len(dy) > 0 and "packages" in dy:
    #         self.cmds.update(packages=self.default_yamls[0]["packages"])
    #     rc.__currsys__ = self.cmds

    def shutdown(self):
        """
        Shut down the instrument.

        This method closes all open file handles and should be called when the
        optical train is no longer needed.

        Effects without a ``_file`` attribute are skipped.
        """
        for effect_name in self.effects["name"]:
            try:
                self[effect_name]._file.close()
            except AttributeError:
                # Best-effort: not every effect holds an open file.
                pass

        self._description = "The instrument has been shut down."

    @property
    def effects(self):
        """Table of effects as returned by ``OpticsManager.list_effects``."""
        return self.optics_manager.list_effects()

    def __repr__(self):
        return f"{self.__class__.__name__}({self.cmds!r})"

    def __str__(self):
        return self._description

    def _repr_pretty_(self, p, cycle):
        """For ipython."""
        if cycle:
            p.text(f"{self.__class__.__name__}(...)")
        else:
            p.text(f"{self.__class__.__name__} ")
            p.text(f"for {self.cmds['!OBS.instrument']} ")
            p.text(f"@ {self.cmds['!TEL.telescope']}:")
            p.breakable()
            p.text("UserCommands:")
            p.breakable()
            p.pretty(self.cmds)
            p.breakable()
            p.text("OpticalElements:")
            with p.indent(2):
                # No __iter__ is defined, so this uses the legacy sequence
                # protocol (self[0], self[1], ... until IndexError) via
                # __getitem__ -> optics_manager; presumably int indices are
                # supported there -- TODO confirm.
                for item in self:
                    p.breakable()
                    p.pretty(item)
            p.breakable()
            p.text("DetectorArrays:")
            with p.indent(2):
                for item in self.detector_arrays:
                    p.breakable()
                    p.pretty(item)
            p.breakable()
            p.text("Effects:")
            p.breakable()
            with p.indent(2):
                p.pretty(self.effects)

    def __getitem__(self, item):
        return self.optics_manager[item]

    def __setitem__(self, key, value):
        self.optics_manager[key] = value
# user commands report
# package dependencies
# modes names
# default modes
# yaml hierarchy
#
# optics_manager
# derived properties
# system transmission curve
# list of effects
#
# etc
# limiting magnitudes
#