Recreating Steffl’s Flatfields

This notebook focuses on the row-to-row flatfield.

⚠️ Demo cells below are marked #| eval: false — they require live PDS index fetches, cached data products, or interactive plotting backends that are unavailable during a clean docs render. Open this notebook in JupyterLab to step through it cell-by-cell against your local data.

import itertools
from functools import cached_property
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr
from astropy import modeling
from tqdm.auto import tqdm, trange
from IPython.display import HTML

import holoviews as hv
import hvplot.xarray  # noqa: F401

from planetarypy.datetime_format_converters import (
    fromdoyformat,  # replaces removed nasa_time_to_iso
    doyformat,      # replaces removed iso_to_nasa_time
)
from pyuvis import UVPDS, UVISObs, CatalogFilter
from pyuvis.calib.steffl import (
    steffl_spica_dates,
    steffl_spica_doy_dates,
    Row2Row,
    Col2Col,
    create_detector_stack,
)

# NOTE: `StarCalibs` (referenced by some demo cells below) was in the
# original nbdev draft but never made it into pyuvis 0.9.0. Cells that
# need it are eval:false; reintroduce the class or rewrite those cells
# against `Row2Row` / `Col2Col` if you want them executable.

# `param` is only needed if you're re-defining the row-to-row
# class interactively; the canonical Row2Row lives in
# src/pyuvis/calib/steffl.py.
# import param
# Activate the Bokeh plotting backend used by the HoloViews / hvPlot output
# throughout this notebook.
hv.extension("bokeh", logo=False, inline=False)
def get_pids(date, det='fuv'):
    """Collect star-calibration product_ids for ``date`` and the day after.

    Uses ``CatalogFilter`` to query the PDS catalog for rows whose
    OBSERVATION_TYPE is 'CALIB' or 'USTARE' and whose TARGET_NAME is 'STAR'.

    Parameters
    ----------
    date
        Passed unchanged to ``CatalogFilter``.
    det : str
        Detector name; selects the ``get_<det>_date`` catalog accessor.

    Returns
    -------
    list
        Matching index labels for ``date``, followed by those for the next day.
    """
    catalog = CatalogFilter(date)
    get_table = getattr(catalog, f"get_{det}_date")
    query = "OBSERVATION_TYPE in ['CALIB', 'USTARE'] and TARGET_NAME=='STAR'"

    def matching():
        # Re-read the catalog table for whatever day `catalog` is set to.
        return list(get_table().query(query).index)

    found = matching()
    catalog.set_next_day()
    found += matching()
    return found
# Two alternative ways to collect the Steffl star-calibration product_ids;
# the second assignment to `pids` overwrites the first.
# 1) From the PDS catalog, for the third entry of steffl_spica_dates.
pids = get_pids(steffl_spica_dates[2], det='fuv')
# 2) From Greg's table via StarCalibs.
# NOTE(review): StarCalibs is not imported or defined in this notebook (see
# the NOTE in the import cell), so these lines fail as written.
calibs = StarCalibs()  # getting pids from Greg's Table
calibs.steffl_years  # as defined by magic number = 332
pids = calibs.fuv_steffls_by_year(2014).tolist()
pids
len(pids)

AlongSlit data

Valid scan numbers for FUV vary, so they are checked against the sensor-integrated signal below.

# Inspect a single along-slit scan (the sixth product_id in `pids`).
# NOTE(review): AlongSlit is neither imported nor defined in this notebook;
# presumably it lives in pyuvis.calib.steffl or an earlier draft — confirm.
along = AlongSlit(pids[5])
along.data
along.plot_set
along.plot_averaged
along.plot_column_std
along.summed.sum(dim='spectral').hvplot()
along.plot_integrated
along.plot_ff(vmax=2)
# Histogram of the `ff` values (presumably the flatfield) between 0 and 2.
along.ff.hvplot.hist(bins=np.linspace(0, 2, 100)).options(active_tools=['box_zoom'])

Across-slit stack

Each product_id scans along the slit, while the full set of product_ids scans across it (hence the spectral signatures shift from one observation to the next within the set). Below, I stack the “along-slit” observations.

Universal detector stack xarray creator

This function creates detector-shaped stacks of data for different purposes, such as:

* along-slit scans
* across-slit scans
* sets of corrections for the Steffl flatfield

class AcrossSlit:
    """Stack of along-slit scans that together scan across the slit.

    Each product_id is one along-slit scan. ``self.arr`` concatenates
    ``AlongSlit(pid).integrated`` for every pid, in sorted product_id order,
    along a new ``m`` dimension labelled 0..len(pids)-1.
    """

    # Positional spectral range kept by `illuminated` (applied with ``isel``).
    spectral_slice = slice(8, 1016)

    @classmethod
    def from_year(cls, year):
        """Build the stack from the Steffl product_ids of ``year``.

        NOTE(review): relies on a module-level ``calibs`` (a ``StarCalibs``
        instance), which is not part of pyuvis 0.9.0.
        """
        pids = calibs.fuv_steffls_by_year(year).tolist()
        return cls(pids, year=year)

    def __init__(self, pids, year=None):
        # Keep the pids in the same (sorted) order used to build the ``m``
        # dimension, so that ``self.pids[i]`` is the product_id of swath m=i.
        self.pids = sorted(pids)
        self.year = year

        stacked = [AlongSlit(pid).integrated for pid in tqdm(self.pids)]

        self.arr = xr.concat(stacked, xr.DataArray(range(len(self.pids)), dims='m'))
        if self.year is not None:
            self.arr.name = str(self.year)

    def get_single_obs(self, i_pid):
        "Get a single UVIS obs via the index number of the stored product_ids"
        return UVPDS(self.pids[i_pid])

    @property
    def detector(self):
        """First three characters of the first stored product_id."""
        return self.pids[0][:3]

    def good_swaths(self, mslice=slice(3, 6)):
        """Select the swaths whose ``m`` label falls inside ``mslice``.

        ``sel`` slices are label-based and inclusive on both ends, so the
        default ``slice(3, 6)`` keeps m = 3, 4, 5 and 6.
        """
        return self.arr.sel(m=mslice)

    def illuminated(self, mslice=slice(3, 6)):
        """Good swaths restricted to ``spectral_slice`` (by position) and to
        ``AlongSlit.spatial_slice`` (by label)."""
        return (
            self.good_swaths(mslice)
            .isel(spectral=self.spectral_slice)
            .sel(spatial=AlongSlit.spatial_slice)
        )

    @property
    def illum_summed(self):
        """`illuminated` (default ``mslice``) summed over ``m``."""
        # `illuminated` is a method, so it has to be called here.
        return self.illuminated().sum(dim='m')

    @property
    def illum_summed_df(self):
        """`illum_summed` as a DataFrame, with the last index level unstacked
        into the columns and the result transposed.

        Side effect: caches the coordinate attrs and the array name that the
        DataFrame round-trip in `median_smooth` needs to restore.
        """
        summed = self.illum_summed
        self.spectral_attrs = summed.spectral.attrs
        self.spatial_attrs = summed.spatial.attrs
        self._summed_name = summed.name
        return summed.to_dataframe().unstack().T

    def median_smooth_df(self, win_size=5):
        """Centered rolling median of `illum_summed_df` across its columns."""
        df = self.illum_summed_df
        # `DataFrame.rolling(axis=...)` is deprecated; rolling along the index
        # of the transposed frame and transposing back is equivalent.
        return df.T.rolling(window=win_size, center=True).median().T

    def median_smooth(self, win_size=5):
        """`median_smooth_df` converted back to a DataArray, with the
        spectral/spatial coordinate attrs restored."""
        df = self.median_smooth_df(win_size)  # also refreshes the cached attrs/name
        # The data variable is named after the summed array ('Counts', or the
        # year string when built via `from_year`), so don't hard-code it.
        arr = df.T.stack().to_xarray()[self._summed_name]
        arr.spectral.attrs = self.spectral_attrs
        arr.spatial.attrs = self.spatial_attrs
        return arr

    def smooth_each_m(self):
        # Placeholder: not implemented yet.
        pass

    def plot_reduced_over_m(self, func='median'):
        """Plot ``func`` (an xarray reduction name) of ``self.arr`` over the
        spectral and spatial dimensions, as a curve versus ``m``."""
        f = getattr(self.arr, func)
        return f(dim=['spectral', 'spatial']).hvplot(
            label=f"{self.year}",
            title=f'Detector {func} over swath m',
        )
# Build the across-slit stack for 2013 and inspect it.
# NOTE(review): needs the `calibs` object created in the StarCalibs cell above.
calibs.steffl_years
across = AcrossSlit.from_year(2013)
obs = across.get_single_obs(0)
obs
across.illuminated()
across.plot_reduced_over_m('median')

MultiYearSteffl

def corrank(X: pd.DataFrame) -> pd.DataFrame:
    """Rank every unordered pair of columns of ``X`` by correlation.

    Parameters
    ----------
    X : pd.DataFrame
        Each column is one series to compare.

    Returns
    -------
    pd.DataFrame
        Columns ``pairs`` (a tuple of two column labels) and ``corr``
        (``X.corr()`` value for that pair), one row per pair, sorted by
        ``corr`` in descending order.
    """
    # Compute the correlation matrix once instead of once per pair.
    corr = X.corr()
    rows = [
        [(i, j), corr.loc[i, j]]
        for i, j in itertools.combinations(corr.columns, 2)
    ]
    df = pd.DataFrame(rows, columns=['pairs', 'corr'])
    return df.sort_values(by='corr', ascending=False)
class MultiYearSteffl:
    """Bright Steffl swaths of several years stacked into one DataArray.

    ``self.arr`` holds, for every year, the `AcrossSlit.illuminated`
    selection concatenated along a new ``year`` dimension. It is either
    rebuilt from the data (`create_datastack`) or loaded from a zarr store.
    """

    def __init__(self, start=2013, end=2017, win_size=5, load_saved=True,
                 zarr_path="bright_swaths.zarr"):
        """
        Parameters
        ----------
        start, end : int
            Inclusive range of years stacked by `create_datastack`.
        win_size : int
            Window length of the rolling median in `spectral_smoothed`.
        load_saved : bool
            If True, read the stack from ``zarr_path`` instead of rebuilding it.
        zarr_path : str or Path
            Zarr store containing a ``bright_swaths`` variable.
        """
        self.start = start
        self.end = end
        self.win_size = win_size

        if load_saved:
            self.arr = xr.open_zarr(zarr_path)['bright_swaths']
        else:
            self.create_datastack()

    def create_datastack(self):
        """Rebuild ``self.arr`` with one `AcrossSlit` stack per year.

        Also sets ``self.total_signal_plots``, an overlay of each year's
        median signal per swath ``m``.
        """
        plots = []
        xarrs = []
        years = range(self.start, self.end + 1)

        for year in years:
            # Swaths kept per year (inclusive `sel` labels): m=4..7 for
            # 2013/2014, m=3..6 otherwise.
            mslice = slice(4, 7) if year in (2013, 2014) else slice(3, 6)
            stack = AcrossSlit.from_year(year)
            xarrs.append(stack.illuminated(mslice))
            plots.append(stack.plot_reduced_over_m('median'))

        self.total_signal_plots = hv.Overlay(plots)
        # NOTE(review): the m labels differ between years, so xr.concat aligns
        # them with its default outer join (NaN where a year lacks an m).
        self.arr = xr.concat(xarrs, xr.DataArray(years, dims='year'))
        self.arr.name = "bright_swaths"

    def plot_arr(self):
        """Image of ``self.arr`` (spectral vs. spatial), colors clipped at 10 000."""
        return self.arr.hvplot(x='spectral', y='spatial', clim=(None, 10_000))

    @property
    def spatial_avg(self):
        """Median of ``self.arr`` over the spatial dimension."""
        return self.arr.median(dim='spatial')

    def plot_spatial_avg(self):
        """Plot `spatial_avg` versus spectral."""
        return self.spatial_avg.hvplot(x='spectral', grid=True)

    def plot_spectral_scrub(self, m, smoothed=True):
        """Spectral/spatial image of swath ``m`` with a scrubber widget over
        the remaining dimension(s), e.g. ``year``.

        Parameters
        ----------
        m
            Label of the swath to show.
        smoothed : bool
            Show `spectral_smoothed` instead of the raw ``self.arr``.
        """
        # The smoothed stack is exposed as `spectral_smoothed`; there is no
        # `smoothed` attribute on this class.
        data = self.spectral_smoothed if smoothed else self.arr
        return data.sel(m=m).hvplot(
            x='spectral',
            y='spatial',
            framewise=False,
            clim=(0, 8000),
            widget_type='scrubber',
            widget_location='bottom',
            label=f'{m=}'
        )

    @property
    def spectral_smoothed(self):
        """Centered rolling median of ``self.arr`` along ``spectral`` with
        window ``win_size``."""
        rolled = self.arr.rolling(spectral=self.win_size, center=True)
        return rolled.median()

    @property
    def spectra_smoothed_spatial_avg(self):
        """`spectral_smoothed` reduced by its median over ``spatial``."""
        return self.spectral_smoothed.median(dim='spatial')

    def plot_smoothed(self):
        """Image of `spectral_smoothed` (spectral vs. spatial)."""
        return self.spectral_smoothed.hvplot(x='spectral', y='spatial', grid=True)

    @property
    def spectra_as_df(self):
        """`spectra_smoothed_spatial_avg` as a DataFrame with all-NaN rows and
        columns dropped (presumably one column per (year, m), as the demo
        cells index it)."""
        df = self.spectra_smoothed_spatial_avg.to_dataframe().unstack().T
        return df.dropna(how='all').dropna(how='all', axis=1)

    @property
    def ranked_corr(self):
        """Column pairs of `spectra_as_df` ranked by correlation (`corrank`)."""
        return corrank(self.spectra_as_df)
# Load the saved multi-year stack (MultiYearSteffl defaults to
# load_saved=True, i.e. the bright_swaths.zarr store).
mission = MultiYearSteffl()
from IPython.display import HTML  # duplicate of the top-level import
# NOTE(review): total_plots.html is not written anywhere in this notebook;
# this relies on IPython treating an existing file path as a filename —
# confirm the file is present.
HTML("total_plots.html")
df = mission.spectra_as_df
df
# Look for shifted copies of the (2013, m=4) spectrum: for every later year
# and swath, shift the spectrum by -5..5 samples and print the combinations
# whose correlation with the base exceeds 0.999.
year_base = 2013
m_base = 4
base = df[(year_base, m_base)]
years = range(year_base+1, 2018)
for year in years:
    for m in df[year].columns:
        to_check = df[(year, m)]
        for lag in range(-5,6):
            corr = base.corr(to_check.shift(lag))
            if corr > 0.999:
                print(year, m, lag, corr)
base.name = "(2013, 4)"

# Visual check: base vs. (2014, 4) and a few shifted versions, zoomed in.
t = (2014, 4)
to_check = df[t]
to_check.name = str(t)
overlays = []
overlays.append(base.hvplot(color='red', xlim=(130, 132), ylim=(4000, 5500)))
overlays.append(to_check.hvplot(color='blue'))
for lag in [-2, -1, 1, 2]:
    overlays.append(to_check.shift(lag).hvplot(label=str(lag), line_dash='dotted'))
hv.Overlay(overlays)
# Distribution of all pairwise correlations, then the top-ranked pairs.
ranked = mission.ranked_corr
ranked['corr'].hvplot.hist(bins=np.linspace(0.995, 1, 50))
ranked.head()

Define the “sensitivity” of pixels. A big binned pixel would simply be the sum (or average?) of the individual pixels’ sensitivities. What would these look like after applying the FF procedure for rows? Try a small forward model that calculates the effects of these operations.

Bob: There is an issue with the FF for binned pixels. How many counts would you get? The FF might become source-dependent: imagine a bright spectral line falling on a low-sensitivity pixel, while a different source has a bright line next to it on a high-sensitivity pixel. The FF for a binned image would then look different (maybe model this as a demonstration of the issue).

def get_pair_plot(pair):
    """Overlay the smoothed, spatially averaged spectra of two (year, m) keys.

    Parameters
    ----------
    pair
        Two keys; element 0 of each key is the year and element 1 is the swath m.
        Each key's repr is used as its curve label.

    Returns
    -------
    The product (overlay) of the two hvplot curves, first key first.
    """
    first, second = pair
    keys = (first, second)
    spectra = mission.spectra_smoothed_spatial_avg
    selected = [spectra.sel(year=key[0], m=key[1]) for key in keys]
    curves = [spec.hvplot(label=f"{key}") for spec, key in zip(selected, keys)]
    return curves[0] * curves[1]
# Collect the distinct (year, m) keys that appear in the five top-ranked
# pairs, in order of first appearance, and overlay their spectra.
pairs = ranked.head()['pairs']
pairs
spectra = mission.spectra_smoothed_spatial_avg
fitting_specs = []
selected = []
plots = []
for pair in ranked.head()['pairs']:
    for p in pair:
        if p in selected:
            # already collected from a higher-ranked pair
            continue
        spec = spectra.sel(year=p[0], m=p[1])
        fitting_specs.append(spec)
        plots.append(spec.hvplot(label=f"{p}"))
        selected.append(p)
hv.Overlay(plots).opts(title="Best corr pairs")
ranked.head()
selected
# Stack the raw (unsmoothed) swaths of the selected keys along a new
# `corr_set` dimension and measure their spread.
arrays = []
for s in selected:
    arrays.append(mission.arr.sel(year=s[0], m=s[1]))
arrays[0]
corr_set = xr.concat(arrays, xr.DataArray(range(len(selected)), dims="corr_set"))
corr_set
# Only the first three members enter the statistics below.
data = corr_set.isel(corr_set=slice(None,3))
data
std = data.std(dim="corr_set")
std
avg = data.mean(dim='corr_set')
# Relative standard deviation (std / mean) per pixel.
relstd = std/avg
relstd.sel(spectral=slice(None,None)).hvplot()
std.sel(spectral=slice(None, None)).hvplot()
# Scatter of the per-pixel std against the pixel values of one swath, with a
# sqrt(x) curve for comparison (presumably the shot-noise expectation).
# NOTE(review): `data` is reassigned to the (2015, m=6) swath here, while
# `std` comes from the corr_set subset above; both ravel()-ed arrays must
# have the same length for plt.plot — confirm this pairing is intended.
data = mission.arr.sel(year=2015, m=6, spectral=slice(None, None))
data.values.shape
%matplotlib widget
xvals = np.arange(1, 5000)
sqrroots = np.sqrt(xvals)
plt.figure()
plt.plot(data.values.ravel(), std.values.ravel(), '.', alpha=0.1)
plt.plot(xvals, sqrroots)
import seaborn as sns  # duplicate of the top-level import
# Compare two of the smoothed, spatially averaged spectra directly.
spec1 = spectra.sel(year=2014, m=4)
spec1
spec2 = spectra.sel(year=2013, m=5)
spec1.hvplot() * spec2.hvplot()
# NOTE(review): MultiYearSteffl defines `spectral_smoothed` and
# `spectra_smoothed_spatial_avg`, but not `spectral_smooth()`, `smoothed` or
# `smoothed_spatial_avg`; the calls using those names look like leftovers
# from an older API and would raise AttributeError — confirm.
mission.spectral_smooth()
mission.plot_smoothed()
mission.smoothed.sel(spectral=slice(135, 145),year=[2013, 2015, 2016]).hvplot(x='spectral', y='spatial')
mission.smoothed.median(dim='spatial').hvplot(x='spectral')
mission.plot_arr()
mission.spatial_avg
(mission.plot_spatial_avg() * mission.smoothed_spatial_avg.hvplot(x='spectral')).opts(show_grid=True)
mission.plot_spectral_scrub(0)
# Swath m=0 of the smoothed stack: a reference spectrum (median over years
# and spatial rows) and the per-year median relative to it.
# NOTE(review): create_datastack only keeps m labels within 3..7, so m=0 may
# not exist in the stack — confirm against the saved zarr store.
arr = mission.smoothed
arr
m0 = arr.sel(m=0, drop=True)
m0
m0_averaged = m0.median(dim=['year', 'spatial'])
m0_averaged.hvplot()
m0
# Year-to-year scatter relative to the median over years.
m0_std = m0.std(dim=['year'])
m0_error = m0_std / m0.median(dim='year')
m0_error.hvplot()
ratio = m0.median(dim='year') / m0_averaged 
ratio.hvplot(label='m0_ratio')
ratio.hvplot.hist(bins=100)
ratio
# NOTE(review): `stack` and `averaged` are not defined anywhere in this
# notebook section (leftovers from an earlier draft?); this cell fails as-is.
ratio = stack.illum_summed / averaged
ratio.name = 'ratio'
ratio.hvplot(cmap='gray', clim=(0, 2), width=800)
ratio.hvplot.hist(bins=200)
# Fit a Gaussian (initial mean=1) to the ratio histogram, skipping the bins
# below 0.65.
hist, bin_edges = np.histogram(ratio, bins=200, range=(0, 3))
bin_centers = (bin_edges[1:] + bin_edges[:-1])/2
from astropy import modeling  # duplicate of the top-level import
model = modeling.models.Gaussian1D(mean=1)
fitter = modeling.fitting.LevMarLSQFitter()
bin_edges.shape
# Index of the first bin edge above 0.65 (used to slice bin_centers/hist).
# NOTE(review): int() of the 1-element array from np.argwhere(...)[0] is
# deprecated since NumPy 1.25; np.argmax(bin_edges > 0.65) yields the same
# index as a scalar.
break_index = int(np.argwhere(bin_edges > 0.65)[0])
break_index
fitted = fitter(model, bin_centers[break_index:], hist[break_index:])
fitted
histplot = hv.Curve((bin_centers, hist), 'Counts', 'Frequency', label='Histogram')
fitplot = hv.Curve((bin_centers, fitted(bin_centers)), label='Fit')
(histplot.opts(tools=['hover']) * fitplot).opts(width=500,
                                                ylim=(0, None),
                                                xlim=(-0.5, None))
fitted
# Map of the pixels whose ratio falls below the 0.65 cut.
ratio.where(ratio < 0.65).hvplot(cmap='magma')
# Scrub through the illuminated stack, then look at the spread across
# swaths `m`.
# NOTE(review): `stack` is not defined in this section, and
# AcrossSlit.illuminated is a method — `stack.illuminated()` was presumably
# intended; as written `arr` would be a bound method, not an array.
arr = stack.illuminated
arr.hvplot(
    y='spatial',
    x='spectral',
    logz=False,
    clim=(
        10,
        5000),
    cmap='magma',
    widget_type='scrubber',
    widget_location='bottom')
arr.std(dim='m').hvplot(cmap='magma', clim=(None, 500))
# Despite the names, `var` is the standard deviation over m and `relvar` is
# std / mean.
var = arr.std(dim='m')
relvar = var / arr.mean(dim='m')
relvar.hvplot(cmap='magma')
relvar.hvplot.hist(bins=200)