# Source code for lsdb.core.crossmatch.abstract_crossmatch_algorithm

from __future__ import annotations

from abc import ABC
from typing import TYPE_CHECKING

import nested_pandas as npd
import numpy as np
import numpy.typing as npt
import pandas as pd

from lsdb.core.crossmatch.crossmatch_args import CrossmatchArgs
from lsdb.dask.merge_catalog_functions import apply_suffixes

if TYPE_CHECKING:
    from lsdb.catalog import Catalog


def _na_series_for_dtype(dtype, length):
    """Build an all-missing Series of the requested dtype.

    Helper for left-joins: the unmatched left rows need placeholder values
    for every right-hand column, and those placeholders must carry the same
    dtype as the real column. All dtypes will be PyArrow based, even for
    catalogs imported using `lsdb.from_dataframe`.

    Parameters
    ----------
    dtype
        The dtype the resulting Series should have.
    length : int
        The number of missing entries to generate.

    Returns
    -------
    pd.Series
        A Series of ``length`` entries, all missing, with dtype ``dtype``.
    """
    # Pandas turns the ``None`` placeholders into the NA value that is
    # appropriate for ``dtype``, so no per-dtype handling is needed here.
    placeholders = [None for _ in range(length)]
    return pd.Series(data=placeholders, dtype=dtype)


# pylint: disable=too-many-instance-attributes, too-many-arguments
class AbstractCrossmatchAlgorithm(ABC):
    """Abstract class used to write a crossmatch algorithm.

    To specify a custom algorithm, write a class that subclasses the
    `AbstractCrossmatchAlgorithm` class, and either overwrite the `crossmatch`
    or the `perform_crossmatch` function.

    The function should be able to perform a crossmatch on two pandas DataFrames
    from a partition from each catalog. It should return two 1d numpy arrays of equal lengths
    with the indices of the matching rows from the left and right dataframes, and a dataframe
    with any extra columns generated by the crossmatch algorithm, also with the same length.
    These columns are specified in {AbstractCrossmatchAlgorithm.extra_columns}, with
    their respective data types, by means of an empty pandas dataframe. As an example,
    the KdTreeCrossmatch algorithm outputs a "_dist_arcsec" column with the distance between
    data points. Its extra_columns attribute is specified as follows::

        pd.DataFrame({"_dist_arcsec": pd.Series(dtype=np.dtype("float64"))})

    The `crossmatch`/`perform_crossmatch` methods will receive an instance of `CrossmatchArgs`
    which includes the partitions and respective pixel information::

        - left_df: npd.NestedFrame
        - right_df: npd.NestedFrame
        - left_order: int
        - left_pixel: int
        - right_order: int
        - right_pixel: int
        - left_catalog_info: hc.catalog.TableProperties
        - right_catalog_info: hc.catalog.TableProperties
        - right_margin_catalog_info: hc.catalog.TableProperties

    Include any algorithm-specific parameters in the initialization of your object.
    These parameters should be validated in `AbstractCrossmatchAlgorithm.validate`,
    by overwriting the method.
    """

    extra_columns: pd.DataFrame | None = None
    """The metadata for the columns generated by the crossmatch algorithm"""

    def crossmatch(
        self,
        crossmatch_args: CrossmatchArgs,
        how: str,
        suffixes: tuple[str, str],
        suffix_method: str = "all_columns",
    ) -> npd.NestedFrame:
        """Perform a crossmatch.

        Parameters
        ----------
        crossmatch_args : CrossmatchArgs
            The partitions and respective pixel information.
        how : str
            One of {'inner', 'left'}
        suffixes : tuple[str,str]
            A pair of suffixes to be appended to the end of each column name,
            with the first appended to the left columns and the second to the right columns.
        suffix_method : str, default 'all_columns'
            The suffix method to use.

        Returns
        -------
        npd.NestedFrame
            The dataframe containing the results of the crossmatch.

        Raises
        ------
        ValueError
            If the left indices, right indices and extra columns do not all
            have the same length.
        """
        # If there's no right data, return empty arrays (e.g., for left-join with no matching right partition)
        if crossmatch_args.right_df is None or len(crossmatch_args.right_df) == 0:
            # NOTE(review): a ``None`` right_df is still forwarded to `_create_crossmatch_df`, which
            # indexes it with ``.iloc`` unless `apply_suffixes` substitutes a frame - confirm that
            # callers always provide an (empty) frame here.
            l_inds = np.array([], dtype=np.int64)
            r_inds = np.array([], dtype=np.int64)
            extra_cols = self.extra_columns.copy() if self.extra_columns is not None else pd.DataFrame()
        else:
            l_inds, r_inds, extra_cols = self.perform_crossmatch(crossmatch_args)
        if not len(l_inds) == len(r_inds) == len(extra_cols):
            raise ValueError(
                "Crossmatch algorithm must return left and right indices and extra columns with same length"
            )
        return self._create_crossmatch_df(
            crossmatch_args.left_df,
            crossmatch_args.right_df,
            l_inds,
            r_inds,
            extra_cols,
            how,
            suffixes,
            suffix_method,
        )

    def crossmatch_nested(
        self, crossmatch_args: CrossmatchArgs, nested_column_name: str, how: str
    ) -> npd.NestedFrame:
        """Perform a crossmatch and store results in nested column.

        Parameters
        ----------
        crossmatch_args : CrossmatchArgs
            The partitions and respective pixel information.
        nested_column_name : str
            The name of the column where the matches should be stored.
        how : str
            How to handle the crossmatch of the two catalogs.
            One of {'left', 'inner'}.

        Returns
        -------
        npd.NestedFrame
            The dataframe containing the results of the crossmatch.

        Raises
        ------
        ValueError
            If the left indices, right indices and extra columns do not all
            have the same length.
        """
        # NOTE(review): unlike `crossmatch`, an empty/missing right partition is not
        # short-circuited here; `perform_crossmatch` is always invoked - confirm this is intended.
        l_inds, r_inds, extra_cols = self.perform_crossmatch(crossmatch_args)
        if not len(l_inds) == len(r_inds) == len(extra_cols):
            raise ValueError(
                "Crossmatch algorithm must return left and right indices and extra columns with same length"
            )
        return self._create_nested_crossmatch_df(
            crossmatch_args.left_df,
            crossmatch_args.right_df,
            l_inds,
            r_inds,
            extra_cols,
            nested_column_name,
            how,
        )

    def perform_crossmatch(
        self, crossmatch_args: CrossmatchArgs
    ) -> tuple[np.ndarray, np.ndarray, pd.DataFrame]:
        """Performs a crossmatch to get the indices of the matching rows and any extra columns

        Parameters
        ----------
        crossmatch_args : CrossmatchArgs
            The partitions and respective pixel information.

        Returns
        -------
        tuple[np.ndarray, np.ndarray, pd.DataFrame]
            - a numpy array with the indices of the matching rows from the left table
            - a numpy array with the indices of the matching rows from the right table
            - a pandas dataframe with any additional columns generated by the algorithm

            These all must have the same lengths.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses are expected to override
            this method (or `crossmatch`).
        """
        raise NotImplementedError(
            "Crossmatch algorithm must either implement `perform_crossmatch` or overwrite `crossmatch`"
        )

    # pylint: disable=protected-access
    def validate(self, left: Catalog, right: Catalog):
        """Validate the metadata and arguments.

        This method will be called **once**, after the algorithm object has
        been initialized, during the lazy construction of the execution graph.
        This can be used to catch simple errors without waiting for an
        expensive ``.compute()`` call.

        This should validate any parameters used to initialize the
        crossmatching algorithm.

        Parameters
        ----------
        left: Catalog
            The left catalog for the crossmatch.
        right: Catalog
            The right catalog for the crossmatch.

        Raises
        ------
        ValueError
            If either catalog lacks its declared RA or DEC column.
        """
        # Check that we have the appropriate columns in our dataset.
        for side, catalog in (("left", left), ("right", right)):
            column_names = catalog._ddf.columns
            catalog_info = catalog.hc_structure.catalog_info
            if catalog_info.ra_column not in column_names:
                raise ValueError(f"{side} table must have column {catalog_info.ra_column}")
            if catalog_info.dec_column not in column_names:
                raise ValueError(f"{side} table must have column {catalog_info.dec_column}")

    @staticmethod
    def _rename_columns_with_suffix(dataframe, suffix):
        """Append ``suffix`` to every column name of ``dataframe``, in place."""
        columns_renamed = {name: name + suffix for name in dataframe.columns}
        dataframe.rename(columns=columns_renamed, inplace=True)

    @classmethod
    def _append_extra_columns(cls, dataframe: npd.NestedFrame, extra_columns: pd.DataFrame | None = None):
        """Adds crossmatch extra columns to the resulting Dataframe.

        The provided columns are checked against the class-level
        ``extra_columns`` specification (names and dtypes) and then written
        into ``dataframe`` in place, re-labelled with ``dataframe``'s index.
        Nothing happens when the class declares no extra columns.

        Parameters
        ----------
        dataframe : npd.NestedFrame
            The crossmatch result to which the columns are added (modified in place).
        extra_columns : pd.DataFrame | None
            The values of the extra columns, one row per row of ``dataframe``.

        Raises
        ------
        ValueError
            If no values are provided although the class declares extra
            columns, if a provided column is not declared, if a declared
            column is missing, or if a column's dtype differs from the
            declared one.
        """
        if cls.extra_columns is None:
            return
        if extra_columns is None:
            raise ValueError("No extra column values were provided")
        # Check if the provided columns are in the specification
        for col in extra_columns.columns:
            if col not in cls.extra_columns.columns:
                raise ValueError(f"Provided extra column '{col}' not found in definition")
        # Validate every declared column before touching `dataframe`, so a
        # failure never leaves it partially updated.
        columns_to_update = []
        for col, col_type in cls.extra_columns.dtypes.items():
            if col not in extra_columns:
                raise ValueError(f"Missing extra column '{col}' of type {col_type}")
            if col_type != extra_columns[col].dtype:
                raise ValueError(f"Invalid type '{col_type}' for extra column '{col}'")
            columns_to_update.append(col)
        for col in columns_to_update:
            # Re-label a copy rather than assigning `.index` on the caller's
            # Series, so the input frame is not modified as a side effect.
            dataframe[col] = extra_columns[col].set_axis(dataframe.index)

    # pylint: disable=too-many-locals
    def _create_crossmatch_df(
        self,
        left_df: npd.NestedFrame,
        right_df: npd.NestedFrame,
        left_idx: npt.NDArray[np.int64],
        right_idx: npt.NDArray[np.int64],
        extra_cols: pd.DataFrame,
        how: str,
        suffixes: tuple[str, str],
        suffix_method: str = "all_columns",
    ) -> npd.NestedFrame:
        """Creates a df containing the crossmatch result from matching indices and additional columns

        Parameters
        ----------
        left_df : npd.NestedFrame
            The left partition.
        right_df : npd.NestedFrame
            The right partition.
        left_idx : np.ndarray
            indices of the matching rows from the left table
        right_idx : np.ndarray
            indices of the matching rows from the right table
        extra_cols : pd.DataFrame
            dataframe containing additional columns from crossmatching
        how : str
            The join type. With 'left', left rows without any match are kept
            once each, with NA values in the right-hand and extra columns;
            any other value yields only the matched pairs.
        suffixes : tuple[str,str]
            A pair of suffixes to be appended to the end of each column name,
            with the first appended to the left columns and the second to the right columns.
        suffix_method, default 'all_columns'
            The suffix method to use.

        Returns
        -------
        npd.NestedFrame
            A dataframe with the matching rows from the left and right table concatenated
            together, with the additional columns added.
        """
        # rename columns so no same names during merging
        left_df, right_df = apply_suffixes(left_df, right_df, suffixes, suffix_method, log_changes=False)
        index_name = left_df.index.name if left_df.index.name is not None else "index"

        # Matched rows: replicate left rows for each match (one row per left-right pair).
        # This part is shared by every join type, so it is built exactly once.
        left_matched = left_df.iloc[left_idx].reset_index()
        right_matched = right_df.iloc[right_idx].reset_index(drop=True)
        matched_out = pd.concat([left_matched, right_matched], axis=1)

        if how == "left":
            # Unmatched left rows: keep each left row once, with NA values for right columns
            # Create a mask of matched positions (not index values, to handle non-unique indices)
            matched_mask = np.zeros(len(left_df), dtype=bool)
            matched_mask[left_idx] = True
            left_unmatched = left_df.iloc[~matched_mask].reset_index()

            # Build empty right-side columns (same names and dtypes as right_matched) filled with NA.
            # We know that both left_df and right_df have RA and DEC columns, and left_matched
            # and right_matched are derived from these. Hence we do not need to check for empty
            # columns.
            unmatched_right = pd.DataFrame(
                {
                    col: _na_series_for_dtype(right_df[col].dtype, len(left_unmatched))
                    for col in right_df.columns
                }
            )
            unmatched_out = pd.concat([left_unmatched, unmatched_right], axis=1)

            # Combine matched (one row per match) and unmatched (one row per non-match)
            out = pd.concat([matched_out, unmatched_out], axis=0, ignore_index=True)

            # Restore the original left index as the DataFrame index
            out.set_index(index_name, inplace=True)

            # Ensure extra_cols has the same number of rows as `out` for left-joins.
            # For left-joins we may have unmatched left rows (one per left row) so
            # extra_cols (which only contains rows for matched pairs) must be
            # expanded with NAs to cover unmatched rows before assigning the index.
            n_out = len(out)
            n_extra = len(extra_cols)
            # Raised explicitly (instead of an `assert` statement) so the check
            # is still enforced when Python runs with optimizations enabled.
            if n_extra > n_out:
                raise AssertionError(
                    f"Logic error: extra_cols ({n_extra} rows) cannot exceed output ({n_out} rows). "
                    f"This indicates a bug in the crossmatch algorithm or the join logic."
                )

            if n_extra == n_out:
                full_extra = extra_cols.reset_index(drop=True)
            else:
                # n_unmatched = number of rows in `out` that correspond to left rows without a match
                n_unmatched = n_out - n_extra
                tail = pd.DataFrame(
                    {c: _na_series_for_dtype(extra_cols[c].dtype, n_unmatched) for c in extra_cols.columns}
                )
                full_extra = pd.concat([extra_cols.reset_index(drop=True), tail], ignore_index=True)
            extra_cols = full_extra
        else:
            # 'inner' (and any other value): the result is exactly the matched pairs
            out = matched_out
            out.set_index(index_name, inplace=True)

        # align index
        extra_cols.index = out.index
        self._append_extra_columns(out, extra_cols)
        return npd.NestedFrame(out)

    def _create_nested_crossmatch_df(
        self,
        left_df: npd.NestedFrame,
        right_df: npd.NestedFrame,
        left_idx: npt.NDArray[np.int64],
        right_idx: npt.NDArray[np.int64],
        extra_cols: pd.DataFrame,
        nested_column_name: str,
        how: str,
    ) -> npd.NestedFrame:
        """Creates a df containing the crossmatch result from matching indices and additional columns

        Parameters
        ----------
        left_df : npd.NestedFrame
            The left partition.
        right_df : npd.NestedFrame
            The right partition.
        left_idx : np.ndarray
            Indices of the matching rows from the left table
        right_idx : np.ndarray
            Indices of the matching rows from the right table
        extra_cols : pd.DataFrame
            Dataframe containing additional columns from crossmatching
        nested_column_name : str
            The name of the column where the matches should be stored.
        how : str
            How to handle the crossmatch of the two catalogs.
            One of {'left', 'inner'}.

        Returns
        -------
        npd.NestedFrame
            A dataframe with the matching rows from the left and right table concatenated
            together, with the additional columns added.
        """
        # concat dataframes together
        index_name = left_df.index.name if left_df.index.name is not None else "index"
        # After reset_index the left rows are labelled by position, which is
        # what `left_idx` refers to.
        left_join_part = left_df.reset_index()
        right_join_part = right_df.iloc[right_idx].copy()
        # Key every matched right row by the position of its left counterpart,
        # using a temporary column that becomes the index.
        right_join_part["new_index_col"] = left_idx
        right_join_part = right_join_part.set_index("new_index_col")
        self._append_extra_columns(right_join_part, extra_cols)
        out = left_join_part.join_nested(right_join_part, nested_column_name, how=how)
        out.set_index(index_name, inplace=True)
        return out