# Source code for spectuner.slm_factory

from __future__ import annotations
from typing import Optional, Callable, Literal
from dataclasses import dataclass
from copy import deepcopy
from collections import defaultdict

import numpy as np

from .preprocess import load_preprocess, get_freq_data, get_T_data
from .sl_model import (
    create_spectral_line_model_state,
    create_spectral_line_db,
    SpectralLineDB,
    SpectralLineModel,
    ParameterManager,
)
from .peaks import PeakManager
from .config import Config
from .utils import pick_default_kwargs


def combine_specie_lists(specie_lists, params_list):
    """Merge several specie lists together with their parameter arrays.

    Args:
        specie_lists (list): Sequence of specie lists; their entries are
            joined in the given order.
        params_list (list): Sequence of parameter arrays, joined in the same
            order with ``np.concatenate``.

    Returns:
        list: All specie list entries, in input order.
        array: The concatenated parameters.
    """
    merged_species = [entry for sub_list in specie_lists for entry in sub_list]
    merged_params = np.concatenate(params_list)
    return merged_species, merged_params


def sum_T_single_data(T_single_dict, T_back=0., key=None):
    """Sum per-species spectra into one spectrum per segment.

    Args:
        T_single_dict (dict): Nested mapping
            ``{id: {specie name: T_single_data}}``. Each ``T_single_data`` is
            a sequence with one entry per segment; ``None`` entries are
            skipped.
        T_back (float): Background level. The total is
            ``T_back + sum(T_single - T_back)``, so the background is counted
            once rather than once per contribution.
        key: If given, only ``T_single_dict[key]`` is summed.

    Returns:
        list: Summed value per segment. A segment that received no
        contribution stays ``None``. Returns ``[]`` if ``T_single_dict``
        contains no spectra at all.

    Raises:
        KeyError: If ``key`` is given but not present in ``T_single_dict``.
    """
    if key is not None:
        targets = [T_single_dict[key]]
    else:
        targets = list(T_single_dict.values())

    # Take the number of segments from the first available spectrum in any
    # sub-dict. Empty sub-dicts are skipped; previously an empty first
    # sub-dict left the template unbound and raised UnboundLocalError.
    template = next(
        (data for sub_dict in T_single_dict.values() for data in sub_dict.values()),
        None,
    )
    if template is None:
        return []
    T_ret_data = [None for _ in template]

    for sub_dict in targets:
        for T_single_data in sub_dict.values():
            for i_segment, T_single in enumerate(T_single_data):
                if T_single is None:
                    continue
                # Start from the background the first time a segment is hit,
                # then add only the excess of each contribution over it.
                if T_ret_data[i_segment] is None:
                    T_ret_data[i_segment] = T_back
                T_ret_data[i_segment] = T_ret_data[i_segment] + T_single - T_back
    return T_ret_data


def compute_T_single_data(slm_factory: SpectralLineModelFactory,
                          obs_info: list,
                          specie_list: list,
                          params: np.ndarray) -> dict:
    """Evaluate the model spectrum of every species on its own.

    For each specie name in ``specie_list``, a one-species specie list and
    the matching parameters are extracted. A spectral line model is then
    built for that species and evaluated.

    Args:
        slm_factory: Factory used to build parameter managers and models.
        obs_info: List of information of each spectral window.
        specie_list: List of species; each item has ``"id"`` and
            ``"species"`` keys.
        params: Parameters for the full ``specie_list``.

    Returns:
        dict: Nested mapping ``{item["id"]: {specie name: model output}}``.
        Items with an empty ``"species"`` list do not appear.
    """
    result = {}
    for entry in specie_list:
        for specie_name in entry["species"]:
            sub_list, sub_params = derive_sub_specie_list_with_params(
                slm_factory, obs_info, specie_list, [specie_name], params
            )
            model = slm_factory.create_sl_model(obs_info, sub_list)
            result.setdefault(entry["id"], {})[specie_name] = model(sub_params)
    return result


def derive_sub_specie_list(specie_list, species):
    """Filter a specie list down to the given specie names.

    Items are deep-copied, so the input list is never modified. Items left
    with no matching species are dropped.

    Args:
        specie_list (list): Specie list.
        species (list): A list of specie names that should be included.

    Returns:
        list: Filtered specie list.
    """
    filtered = []
    for entry in specie_list:
        kept = [name for name in entry["species"] if name in species]
        if not kept:
            continue
        entry_copy = deepcopy(entry)
        entry_copy["species"] = kept
        filtered.append(entry_copy)
    return filtered


def derive_sub_specie_list_with_params(slm_factory, obs_info, specie_list, species, params):
    """Extract a sub specie list and corresponding parameters.

    Args:
        slm_factory (SpectralLineModelFactory): Factory used to build the
            parameter manager for the full ``specie_list``.
        obs_info (list): List of information of each spectral window.
        specie_list (list): Specie list.
        species (list): A list of specie names that should be included.
        params (array): Parameters of the full ``specie_list``.

    Returns:
        list: Filtered specie list.
        array: Filtered parameters.
    """
    specie_list_sub = derive_sub_specie_list(specie_list, species)
    # The parameter manager must describe the *full* list, because that is
    # the list ``params`` corresponds to.
    param_mgr = slm_factory.create_parameter_mgr(specie_list, obs_info)
    params_sub = param_mgr.get_subset_params(species, params)
    return specie_list_sub, params_sub


def jit_fitting_model(model):
    """Evaluate the fitting model once so that jit compilation happens now.

    The model is called with ``np.mean(model.bounds, axis=1)``. This is
    the midpoint of each bound, assuming ``bounds`` holds one
    ``(lower, upper)`` row per parameter (TODO confirm). The result is
    discarded.
    """
    midpoint = np.mean(model.bounds, axis=1)
    model(midpoint)


class SpectralLineModelFactory:
    """Factory class to create objects related to spectral line models.

    Args:
        config: ``Config`` instance.
        sl_db: Spectral line database. If this is provided, the code will use
            this database instead of the one defined in the config.
    """
    def __init__(self, config: Config, sl_db: Optional[SpectralLineDB]=None) -> None:
        self._config = config
        if sl_db is None:
            self._sl_db = create_spectral_line_db(config["sl_model"]["fname_db"])
        else:
            self._sl_db = sl_db

    def create_parameter_mgr(self, specie_list: list, obs_info: list):
        """Create a parameter manager.

        This uses ``param_info`` in the config.

        Args:
            specie_list: List of species.
            obs_info: List of information of each spectral window.
        """
        param_info = self._config["param_info"]
        return ParameterManager(specie_list, param_info, obs_info)

    def create_sl_model(self, obs_info: list, specie_list: list,
                        sl_dict_list: Optional[list]=None) -> SpectralLineModel:
        """Create a callable for computing model spectra.

        This uses ``param_info`` and ``sl_model`` in the config.

        Args:
            obs_info: List of information of each spectral window.
            specie_list: List of species.
            sl_dict_list: List of molecular transition properties. If this is
                provided, the code will use this list instead of querying the
                database.
        """
        obs_data = load_preprocess(obs_info)
        freq_data = get_freq_data(obs_data)
        if sl_dict_list is None:
            # One query per specie, in specie-list order.
            sl_dict_list_ = [
                self._sl_db.query_sl_dict(specie, freq_data)
                for item in specie_list for specie in item["species"]
            ]
        else:
            sl_dict_list_ = sl_dict_list
        # Forward only the ``sl_model`` config entries that
        # ``create_spectral_line_model_state`` picks via ``pick_default_kwargs``.
        kwargs = pick_default_kwargs(
            create_spectral_line_model_state, self._config["sl_model"]
        )
        slm_state = create_spectral_line_model_state(
            sl_data_list=sl_dict_list_,
            freq_list=freq_data,
            obs_info=obs_info,
            **kwargs,
        )
        param_mgr = self.create_parameter_mgr(specie_list, obs_info)
        return SpectralLineModel(slm_state, param_mgr)

    def create_peak_mgr(self, obs_info: list,
                        T_base_data: Optional[list]=None) -> PeakManager:
        """Create a peak manager.

        This uses ``peak_manager`` in the config. If ``noise_factor`` is set,
        the prominence of each window is ``noise_factor * item["noise"]``.
        Otherwise ``prominence`` from the config is used as is.

        Args:
            obs_info: List of information of each spectral window.
            T_base_data: Optional base intensity data, forwarded to
                ``PeakManager``.
        """
        obs_data = load_preprocess(obs_info)
        peak_config = self._config["peak_manager"]
        if "noise_factor" in peak_config:
            noise_factor = peak_config["noise_factor"]
            prominence = [noise_factor*item["noise"] for item in obs_info]
        else:
            prominence = peak_config["prominence"]
        return PeakManager(
            obs_data,
            prominence=prominence,
            rel_height=peak_config["rel_height"],
            freqs_exclude=peak_config["freqs_exclude"],
            T_base_data=T_base_data
        )

    def create_fitting_model(self, obs_info: list, specie_list: list,
                             loss_fn: Literal["pm", "chi2", "chi2_ls"]="pm",
                             sl_dict_list: Optional[list]=None,
                             T_base_data: Optional[list]=None) -> FittingModel:
        """Create a callable for fitting.

        This uses ``param_info``, ``sl_model`` and ``peak_manager`` in the
        config.

        Args:
            obs_info: List of information of each spectral window.
            specie_list: List of species.
            loss_fn: Loss function for fitting.

                - ``"pm"``: Peak matching.
                - ``"chi2"``: Chi-square.
                - ``"chi2_ls"``: This should be used for fitting with
                  ``scipy.optimize.least_squares``.

            sl_dict_list: List of molecular transition properties. If this is
                provided, the code will use this list instead of querying the
                database.
            T_base_data: Optional base intensity data, passed to the loss
                function.

        Raises:
            ValueError: If ``loss_fn`` is unknown or any parameter in
                ``param_info`` has no bound.
        """
        # Validate cheap inputs first, before loading observations or
        # querying the database.
        if loss_fn not in ("pm", "chi2", "chi2_ls"):
            raise ValueError(f"Unknown fitting loss {loss_fn}.")
        fails = [name for name, item in self._config["param_info"].items()
                 if item["bound"] is None]
        if fails:
            raise ValueError(f"Set the bounds for {fails}")

        sl_model = self.create_sl_model(obs_info, specie_list, sl_dict_list)
        # TODO: allow to have different loss functions
        if loss_fn == "pm":
            loss = self.create_peak_mgr(obs_info, T_base_data)
        else:
            obs_data = load_preprocess(obs_info)
            loss = ChiSquare(obs_data, T_base_data, use_ls=(loss_fn == "chi2_ls"))
        bounds = sl_model.param_mgr.derive_bounds()
        return FittingModel(obs_info, sl_model, loss, bounds)
@dataclass(frozen=True)
class FittingModel:
    """Objective callable that chains a spectral line model and a loss.

    Calling the instance with ``params`` computes model spectra with
    ``sl_model`` and passes them to ``loss_fn``. If the loss returns a tuple,
    its first element is the loss value and the rest is extra information.
    """
    obs_info: list  # Information of each spectral window.
    sl_model: SpectralLineModel  # Maps parameters to model spectra.
    loss_fn: Callable  # Maps model spectra to a value or a (value, extra) tuple.
    bounds: np.ndarray  # Parameter bounds (from the parameter manager).
    blob: bool = False  # If True, always return a ``(value, extra)`` pair.

    def __call__(self, params):
        value = self.loss_fn(self.sl_model(params))
        is_tuple = isinstance(value, tuple)
        if self.blob:
            # Normalise to a pair so callers can always unpack.
            return value if is_tuple else (value, None)
        return value[0] if is_tuple else value


class ChiSquare:
    """Chi-square fitting loss function.

    The loss is ``0.5 * sum((T_obs - T_pred) ** 2)`` over all windows. The
    residuals are not weighted.

    Args:
        obs_data (list): List of observation data.
        T_base_data (list): List of background intensity, one entry per
            spectral window. It is subtracted from the observed intensity.
        use_ls (bool): This must be ``True`` if the optimizer is implemented
            with ``scipy.optimize.least_squares``. In that case the call
            returns the concatenated residual vector instead of a scalar.

    Raises:
        ValueError: If ``T_base_data`` does not have one entry per spectral
            window.
    """
    def __init__(self, obs_data: list, T_base_data: Optional[list]=None,
                 use_ls: bool=False):
        T_obs_data = get_T_data(obs_data)
        if T_base_data is None:
            self.T_obs_data = T_obs_data
        else:
            T_obs_list = list(T_obs_data)
            T_base_list = list(T_base_data)
            # zip() would silently drop unmatched windows from the loss.
            if len(T_base_list) != len(T_obs_list):
                raise ValueError(
                    f"T_base_data has {len(T_base_list)} entries but there "
                    f"are {len(T_obs_list)} spectral windows."
                )
            self.T_obs_data = [T_obs - T_base
                               for T_obs, T_base in zip(T_obs_list, T_base_list)]
        self._use_ls = use_ls

    def __call__(self, T_pred_data):
        delta = np.concatenate([
            T_obs - T_pred for T_obs, T_pred in zip(self.T_obs_data, T_pred_data)
        ])
        if self._use_ls:
            return delta
        return .5*np.sum(np.square(delta))