Source code for spectuner.spectral_plot

from typing import Optional
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, AutoMinorLocator, ScalarFormatter

from .config import Config
from .preprocess import load_preprocess, get_freq_data, get_T_data
from .identify import IdentResult


class PeakPlot:
    """Multi-window plot for visualizing the spectrum.

    This class makes sub-plots around the given frequencies and automatically
    merges overlapping windows.

    Args:
        freqs: Central frequencies for each window. The order does not matter;
            the windows are sorted by frequency before merging.
        delta_v: Velocity width in km/s of each window. Windows with overlaps
            will be merged.
        n_col: Number of windows per row.
        plot_width: Figure width of each sub-plot.
        plot_height: Figure height of each sub-plot.

    Raises:
        ValueError: If ``freqs`` is empty.
    """

    def __init__(self,
                 freqs: np.ndarray,
                 delta_v: float=100.,
                 n_col: int=4,
                 plot_width: float=4,
                 plot_height: float=3):
        # Velocity width relative to the (approximate) speed of light in km/s.
        rel_width = delta_v/3e5

        # Merge overlapping windows. The frequencies are sorted first so that
        # a window only ever needs to be compared with the previous one;
        # without sorting, a lower window following a higher one would be
        # swallowed by the merge step.
        bounds = []
        for freq in sorted(freqs):
            lower = freq*(1 - rel_width)
            upper = freq*(1 + rel_width)
            if bounds and bounds[-1][-1] >= lower:
                bounds[-1][-1] = max(bounds[-1][-1], upper)
            else:
                bounds.append([lower, upper])
        self._bounds = bounds

        n_plot = len(bounds)
        if n_plot == 0:
            raise ValueError("At least one frequency is required.")
        # Never create more columns than there are windows.
        n_col = min(n_col, n_plot)
        n_row = n_plot//n_col + int(n_plot%n_col != 0)
        self.n_plot = n_plot

        self._fig, self._axes = plt.subplots(
            figsize=(n_col*plot_width, n_row*plot_height),
            nrows=n_row, ncols=n_col
        )
        if n_row == 1 and n_col == 1:
            # ``plt.subplots`` returns a bare Axes here; wrap it in an array
            # so that ``.flat`` works like in the multi-axes case.
            self._axes = np.ravel(self._axes)

        for ax in self._axes.flat:
            # Use one formatter per axis instead of sharing a single instance.
            formatter = ScalarFormatter(useOffset=False)
            formatter.set_scientific(False)
            ax.xaxis.set_major_locator(MaxNLocator('auto', integer=True))
            ax.xaxis.set_major_formatter(formatter)
        # Hide the unused axes of the last row.
        for ax in np.ravel(self._axes)[n_plot:]:
            ax.axis("off")

    @property
    def fig(self):
        """Figure."""
        return self._fig

    @property
    def axes(self):
        """Axes."""
        return self._axes

    @property
    def bounds(self):
        """X-limits of each plot."""
        return self._bounds

    def _iter_windows(self):
        """Iterate over ``(ax, (lower, upper))`` for the axes that show a window."""
        # ``zip`` stops at the shorter input, so the unused axes are skipped.
        return zip(self._axes.flat, self.bounds)

    def plot_spec_from_config(self,
                              config: "Config",
                              step_plot: bool=True,
                              ylim_factor: Optional[float]=None,
                              y_top_min: float=0.,
                              color="k",
                              **kwargs):
        """Plot the spectrum defined in the config dict.

        Args:
            config: Config dict.
            step_plot: Whether to use ``plt.step`` instead of ``plt.plot``.
            ylim_factor: Factor to multiply the maximum value of the spectrum
                to set the y-axis limit.
            y_top_min: Minimum value of the y-axis limit.
            color: Color of the spectrum.
            **kwargs: Keyword arguments passed to ``plt.plot`` or ``plt.step``.
        """
        obs_data = load_preprocess(config["obs_info"], clip=False)
        freq_data = get_freq_data(obs_data)
        T_data = get_T_data(obs_data)
        self.plot_spec(
            freq_data, T_data,
            step_plot=step_plot,
            ylim_factor=ylim_factor,
            y_top_min=y_top_min,
            color=color,
            **kwargs
        )

    def plot_spec(self,
                  freq_data: list,
                  spec_data: list,
                  step_plot: bool=False,
                  ylim_factor: Optional[float]=None,
                  y_top_min: float=0.,
                  **kwargs):
        """Plot a spectrum.

        Args:
            freq_data: List of 1D arrays that specifies the frequency values.
            spec_data: List of 1D arrays that specifies the intensity values.
                Entries that are ``None`` are skipped.
            step_plot: Whether to use ``plt.step`` instead of ``plt.plot``.
            ylim_factor: Factor to multiply the maximum value of the spectrum
                to set the y-axis limit.
            y_top_min: Minimum value of the y-axis limit.
            **kwargs: Keyword arguments passed to ``plt.plot`` or ``plt.step``.
        """
        for ax, (lower, upper) in self._iter_windows():
            y_max = 0.
            for freq, spec in zip(freq_data, spec_data):
                if spec is None:
                    continue
                cond = (freq >= lower) & (freq <= upper)
                if not np.any(cond):
                    continue

                T_data = spec[cond]
                if step_plot:
                    ax.step(freq[cond], T_data, where="mid", **kwargs)
                else:
                    ax.plot(freq[cond], T_data, **kwargs)
                y_max = max(y_max, np.max(T_data))

            if ylim_factor is not None:
                y_top = max(ylim_factor*y_max, y_top_min)
                ax.set_ylim(top=y_top)

    def plot_prominence(self, freq_data, prom_list):
        """Plot one horizontal grey line per spectral window.

        Args:
            freq_data: List of 1D arrays that specifies the frequency values.
            prom_list: Y value of the horizontal line for each item in
                ``freq_data``.
        """
        for ax, (lower, upper) in self._iter_windows():
            for freq, prom in zip(freq_data, prom_list):
                # Skip segments that do not overlap with this window.
                if freq[0] > upper or freq[-1] < lower:
                    continue
                # Clip the line to the part of the segment inside the window.
                x_min = max(freq[0], lower)
                x_max = min(freq[-1], upper)
                ax.hlines(prom, x_min, x_max, "grey")

    def vlines(self, freqs: np.ndarray, **kwargs):
        """Plot vertical lines.

        The lines span the current y limits of each sub-plot, and the y limits
        are restored afterwards.

        Args:
            freqs: Frequencies of the vertical lines.
            **kwargs: Keyword arguments passed to ``plt.vlines``.
        """
        kwargs_ = {"linestyle": "--", "color": "k"}
        kwargs_.update(kwargs)
        for ax, (lower, upper) in self._iter_windows():
            y_min_, y_max_ = ax.get_ylim()
            for freq in freqs:
                if freq >= lower and freq <= upper:
                    ax.vlines(freq, y_min_, y_max_, **kwargs_)
            # Restore the limits read before the lines were drawn.
            ax.set_ylim(y_min_, y_max_)

    def vtexts(self,
               freqs: np.ndarray,
               texts: np.ndarray,
               h_txt_offset: float=1.5e-2,
               v_txt_offset: float=.95,
               **kwargs):
        """Plot vertical texts.

        Args:
            freqs: Frequencies to plot the texts.
            texts: Texts to show.
            h_txt_offset: Horizontal offset of the texts, as a fraction of the
                x range of the sub-plot.
            v_txt_offset: Vertical position of the texts, as a fraction of the
                y range of the sub-plot.
            **kwargs: Keyword arguments passed to ``plt.text``.
        """
        for ax, (lower, upper) in self._iter_windows():
            x_min, x_max = ax.get_xlim()
            y_min_, y_max_ = ax.get_ylim()
            y_show = y_min_ + v_txt_offset*(y_max_ - y_min_)
            x_shift = h_txt_offset*(x_max - x_min)
            for freq, text in zip(freqs, texts):
                if freq >= lower and freq <= upper:
                    ax.text(
                        freq + x_shift, y_show, text,
                        rotation="vertical", va="top", **kwargs
                    )
class SpectralPlot:
    """Multi-row plot for visualizing the spectrum of multiple spectral windows.

    Args:
        freq_data: List of 1D arrays that specifies the frequency values for
            each spectral window.
        freq_per_row: Frequency range to show in each row. The unit should be
            the same as ``freq_data``.
        width: Figure width.
        height: Figure height of each row.
        n_minor_tick: Number of minor ticks. Set to 0 to disable minor ticks.
        axes: Axes to plot. If ``None``, create a new figure.
    """

    def __init__(self,
                 freq_data: list,
                 freq_per_row: float=1000.,
                 width: float=20.,
                 height: float=3.,
                 n_minor_tick: int=10,
                 axes: Optional[np.ndarray]=None):
        bounds = self._derive_bounds(freq_data, freq_per_row)
        n_axe = len(bounds)
        if axes is None:
            fig, axes = plt.subplots(figsize=(width, n_axe*height), nrows=n_axe)
        else:
            fig = None
        # Always work with a flat array so that a single Axes is accepted too.
        axes = np.ravel(axes)
        assert len(axes) == n_axe, f"Number of input axes must be equal to {n_axe}."

        for ax, (lower, upper) in zip(axes, bounds):
            ax.set_xlim(lower, upper)
            if n_minor_tick > 0:
                ax.xaxis.set_minor_locator(AutoMinorLocator(n_minor_tick))

        self._fig = fig
        self._axes = axes
        self._bounds = bounds
        # Limits given to ``set_ylim``; ``None`` until it is called.
        self._y_min = None
        self._y_max = None
        self._freq_per_row = freq_per_row

    def _derive_bounds(self, freq_data, freq_per_row):
        """Split the covered frequency range into rows of ``freq_per_row``.

        Rows that would contain at most one data point are dropped.

        Args:
            freq_data: Sequence of 1D sorted arrays, one per spectral window.
            freq_per_row: Frequency range of each row.

        Returns:
            List of ``(lower, upper)`` tuples, one per row.
        """
        segments = sorted(freq_data, key=lambda item: item[0])
        freq_min = segments[0][0]
        # The segment that starts last does not necessarily end last.
        freq_max = max(freq[-1] for freq in segments)
        if freq_max - freq_min < freq_per_row:
            return [(freq_min, freq_max)]

        # Walk through the segments and the rows simultaneously. ``freq_curr``
        # is the upper edge of the current row and ``idx_b`` is the index of
        # the first point of the current segment not yet assigned to a row.
        bounds_dict = {}
        i_segment = 0
        i_ax = 0
        idx_b = 0
        freq_curr = freq_min + freq_per_row
        while i_segment < len(segments) and freq_curr < freq_max + freq_per_row:
            freq = segments[i_segment]
            idx_e = np.searchsorted(freq, freq_curr)
            if idx_e != 0 and idx_e - idx_b > 1:
                bounds_dict[i_ax] = (freq_curr - freq_per_row, freq_curr)
            if idx_e != len(freq):
                # The segment continues beyond this row; go to the next row.
                freq_curr += freq_per_row
                i_ax += 1
                idx_b = idx_e
            else:
                # The segment ends in this row; go to the next segment.
                i_segment += 1
                idx_b = 0
        return list(bounds_dict.values())

    def _get_axe_idx(self, freq):
        """Return the index of the first row that contains ``freq``, or ``None``."""
        for idx, (lower, upper) in enumerate(self._bounds):
            if freq >= lower and freq <= upper:
                return idx
        return None

    @classmethod
    def from_config(cls,
                    config: "Config",
                    freq_per_row: float=1000.,
                    width: float=20.,
                    height: float=3.,
                    n_minor_tick: int=10,
                    axes: Optional[np.ndarray]=None,
                    color: str="k",
                    **kwargs):
        """Create a plot from a ``Config`` instance.

        Args:
            config: ``Config`` instance.
            freq_per_row: Frequency range to show in each row. The unit should
                be the same as defined in ``config``.
            width: Figure width.
            height: Figure height of each row.
            n_minor_tick: Number of minor ticks. Set to 0 to disable minor
                ticks.
            axes: Axes to plot. If ``None``, create a new figure.
            color: Color of the spectrum defined in ``config``.
            **kwargs: Other arguments passed to ``plt.plot`` to plot the
                spectrum defined in ``config``.

        Returns:
            The created plot. The y limits are set to -10 and 100 times the
            mean of the ``noise`` values in ``config["obs_info"]``.
        """
        obs_data = load_preprocess(config["obs_info"], clip=False)
        freq_data = get_freq_data(obs_data)
        plot = cls(
            freq_data=freq_data,
            freq_per_row=freq_per_row,
            width=width,
            height=height,
            n_minor_tick=n_minor_tick,
            axes=axes,
        )
        plot.plot_spec(freq_data, get_T_data(obs_data), color=color, **kwargs)
        noise = np.mean([item["noise"] for item in config["obs_info"]])
        plot.set_ylim(-10.*noise, 100.*noise)
        for ax in plot.axes:
            ax.set_xlabel("Frequency [MHz]")
            ax.set_ylabel("Intensity [K]")

        # Adjust the figure that owns the axes rather than whichever figure
        # happens to be the current one in pyplot.
        fig = plot.fig
        if fig is None and len(plot.axes) > 0:
            fig = plot.axes[0].figure
        if fig is not None:
            fig.subplots_adjust(hspace=.3)
        return plot

    @property
    def fig(self):
        """Figure."""
        return self._fig

    @property
    def axes(self):
        """Axes."""
        return self._axes

    @property
    def bounds(self):
        """X-limits of each plot."""
        return self._bounds

    def plot_ident_result(self,
                          ident_result: "IdentResult",
                          key: Optional[int]=None,
                          name: Optional[str]=None,
                          show_lines: bool=True,
                          color: str="k",
                          color_blen: str="r",
                          color_fp: str="b",
                          h_txt_offset: float=2.5e-3,
                          v_txt_offset: float=.95,
                          fontsize: float=12,
                          T_base_data: Optional[list]=None,
                          kwargs_spec: Optional[dict]=None):
        """Plot an identification result.

        This method is a combination of ``plot_spec`` and ``plot_names``.
        This method reads the current y limits to plot the lines. Therefore,
        this method should be called after ``set_ylim``.

        Args:
            ident_result: Identification result.
            key: Molecule ID. This is used to plot the result of a single
                molecule in a combined result.
            name: Molecule name. This is used to plot the result of a single
                molecule in a combined result.
            show_lines: Whether to show the vertical lines that indicate the
                molecules.
            color: Line color of the peaks that match the observed spectrum.
                Set ``color='none'`` to hide the lines.
            color_blen: Line color of the peaks that match the observed
                spectrum but contributed by multiple species. Set
                ``color_blen='none'`` to hide the lines.
            color_fp: Color of the peaks found in the fitted spectrum but
                missing from the observed spectrum. Set ``color_fp='none'``
                to hide the lines.
            h_txt_offset: Horizontal text offset. Larger values mean farther
                from the line.
            v_txt_offset: Vertical text offset. Smaller values mean farther
                from the top.
            fontsize: Font size of the molecules.
            T_base_data: Base intensity data. Entries that are ``None`` leave
                the corresponding spectrum unchanged.
            kwargs_spec: Keyword arguments passed to ``plt.plot`` to plot the
                spectrum.
        """
        T_data = ident_result.get_T_pred(key, name)
        if T_base_data is not None:
            # Work on a copy so the list returned by ``get_T_pred`` is left
            # untouched (it may be shared with ``ident_result``).
            T_data = list(T_data)
            for i_segment, T_base in enumerate(T_base_data):
                if T_base is None or T_data[i_segment] is None:
                    continue
                T_data[i_segment] = T_data[i_segment] \
                    + T_base - ident_result.T_back
        if kwargs_spec is None:
            kwargs_spec = {}
        self.plot_spec(ident_result.freq_data, T_data, **kwargs_spec)

        if not show_lines:
            return

        # Matched lines first, then the false positive lines, which use
        # ``color_fp`` for both single and blended peaks.
        tables = (
            (ident_result.line_table, color, color_blen),
            (ident_result.line_table_fp, color_fp, color_fp),
        )
        if key is None:
            name_set = None
        elif name is None:
            name_set = set(ident_result.T_single_dict[key])
        else:
            name_set = set((name,))

        for table, c_single, c_blen in tables:
            if name_set is None:
                freqs = table.freq
                name_list = table.name
            else:
                # Only keep the lines of the requested molecule(s).
                inds = ident_result.filter_name_list(name_set, table.name)
                freqs = table.freq[inds]
                name_list = np.array(table.name, dtype=object)[inds]
            self.plot_names(
                freqs, name_list,
                color=c_single, color_blen=c_blen,
                h_txt_offset=h_txt_offset,
                v_txt_offset=v_txt_offset,
                fontsize=fontsize
            )

    def plot_spec(self,
                  freq_data: list,
                  spec_data: list,
                  *args,
                  color: str="C0",
                  **kwargs):
        """Plot a spectrum.

        Args:
            freq_data: List of 1D arrays specifying the frequency of each
                spectral window.
            spec_data: List of 1D arrays specifying the intensity of each
                spectral window. Entries that are ``None`` are skipped.
            *args: Arguments passed to ``plt.plot``.
            color: Color of the spectrum.
            **kwargs: Keyword arguments passed to ``plt.plot``.
        """
        pairs = sorted(zip(freq_data, spec_data), key=lambda item: item[0][0])
        if not pairs:
            return

        # Walk through the segments and the rows simultaneously; ``idx_b`` is
        # the index of the first point of the segment not yet plotted.
        i_segment = 0
        i_ax = 0
        idx_b = 0
        while i_segment < len(pairs) and i_ax < len(self.axes):
            freq, spec = pairs[i_segment]
            idx_e = np.searchsorted(freq, self.bounds[i_ax][-1])
            if idx_e - idx_b > 1 and spec is not None:
                self.axes[i_ax].plot(
                    freq[idx_b:idx_e], spec[idx_b:idx_e],
                    *args, color=color, **kwargs
                )
                # Only use one label
                if "label" in kwargs:
                    kwargs["label"] = None
            if idx_e != len(freq):
                i_ax += 1
                idx_b = idx_e
            else:
                i_segment += 1
                idx_b = 0

    def plot_names(self,
                   freqs: np.ndarray,
                   name_list: list,
                   key: Optional[int]=None,
                   color: str="k",
                   color_blen: str="r",
                   linestyles: str="--",
                   h_txt_offset: float=2.5e-3,
                   v_txt_offset: float=.95,
                   fontsize: float=12):
        """Plot the identified names of the lines.

        Args:
            freqs: Frequencies of the lines.
            name_list: Names of the lines. Each item is a collection of names
                or ``None``; ``None`` items are skipped.
            key: If given, only plot the items of ``name_list`` that contain
                ``key``.
            color: Line color of the peaks that match the observed spectrum.
            color_blen: Line color of the peaks that match the observed
                spectrum but contributed by multiple species.
            linestyles: Line styles.
            h_txt_offset: Horizontal text offset. Larger values mean farther
                from the line.
            v_txt_offset: Vertical text offset. Smaller values mean farther
                from the top.
            fontsize: Font size of the molecules.
        """
        for freq_c, names in zip(freqs, name_list):
            if names is None or (key is not None and key not in names):
                continue
            idx_ax = self._get_axe_idx(freq_c)
            if idx_ax is None:
                continue

            ax = self.axes[idx_ax]
            y_min, y_max = self.get_ylim(ax)
            # Lines with more than one name use the blended color.
            c = color if len(names) == 1 else color_blen
            ax.vlines(freq_c, y_min, y_max, c, linestyles)
            y_show = y_min + v_txt_offset*(y_max - y_min)
            x_show = freq_c + h_txt_offset*self._freq_per_row
            ax.text(
                x_show, y_show, "\n".join(names),
                rotation="vertical",
                verticalalignment="top",
                fontsize=fontsize,
                c=c
            )

    def plot_unknown_lines(self,
                           ident_result: "IdentResult",
                           color: str="grey",
                           linestyle: str="-",
                           alpha: float=0.5):
        """Plot unidentified lines.

        Args:
            ident_result: Identification result.
            color: Color of the lines.
            linestyle: Line style.
            alpha: Transparency.
        """
        freqs = ident_result.get_unknown_lines()
        self.vlines(freqs, colors=color, linestyles=linestyle, alpha=alpha)

    def vlines(self, freqs: np.ndarray, **kwargs):
        """Plot vertical lines.

        Args:
            freqs: Frequencies of the lines. Frequencies outside all rows are
                ignored.
            **kwargs: Keyword arguments passed to ``plt.vlines``.
        """
        for freq_c in freqs:
            idx_ax = self._get_axe_idx(freq_c)
            if idx_ax is None:
                continue
            ax = self.axes[idx_ax]
            y_min, y_max = self.get_ylim(ax)
            ax.vlines(freq_c, y_min, y_max, **kwargs)

    def set_ylim(self, y_min: float, y_max: float, **kwargs):
        """Set the y limits for each plot.

        Args:
            y_min: Minimum y value.
            y_max: Maximum y value.
            **kwargs: Keyword arguments passed to ``Axes.set_ylim``.
        """
        for ax in self.axes:
            ax.set_ylim(y_min, y_max, **kwargs)
        self._y_min = y_min
        self._y_max = y_max

    def get_ylim(self, ax):
        """Get the y limits for the given axis.

        The limits given to ``set_ylim`` take precedence over the current
        limits of ``ax``.

        Args:
            ax: Axis.

        Returns:
            Tuple ``(y_min, y_max)``.
        """
        y_min, y_max = ax.get_ylim()
        if self._y_min is not None:
            y_min = self._y_min
        if self._y_max is not None:
            y_max = self._y_max
        return y_min, y_max