from __future__ import annotations
from abc import ABC
from typing import TYPE_CHECKING
import nested_pandas as npd
import numpy as np
import numpy.typing as npt
import pandas as pd
from lsdb.core.crossmatch.crossmatch_args import CrossmatchArgs
from lsdb.dask.merge_catalog_functions import apply_suffixes
if TYPE_CHECKING:
from lsdb.catalog import Catalog
def _na_series_for_dtype(dtype, length: int) -> pd.Series:
    """Build a Series of ``length`` missing values typed as ``dtype``.

    This is the filler used for the columns of rows that found no match in a
    left-join. Per the original author's note, every dtype reaching this helper
    is expected to be PyArrow based, even for catalogs imported using
    `lsdb.from_dataframe`.

    Parameters
    ----------
    dtype
        The dtype the resulting Series must carry.
    length : int
        The number of missing entries to generate.

    Returns
    -------
    pd.Series
        A Series of ``length`` entries, all missing, with a default integer index.
    """
    # Pandas turns the ``None`` placeholders into the NA value of ``dtype``.
    placeholders = [None for _ in range(length)]
    return pd.Series(placeholders, dtype=dtype)
# pylint: disable=too-many-instance-attributes, too-many-arguments
# [docs]  (Sphinx "view source" link marker left over from HTML extraction; not part of the module)
class AbstractCrossmatchAlgorithm(ABC):
    """Abstract class used to write a crossmatch algorithm.

    To specify a custom algorithm, write a class that subclasses the
    `AbstractCrossmatchAlgorithm` class, and either overwrite the `crossmatch`
    or the `perform_crossmatch` function.

    The function should be able to perform a crossmatch on two pandas DataFrames
    from a partition from each catalog. It should return two 1d numpy arrays of equal lengths
    with the indices of the matching rows from the left and right dataframes, and a dataframe
    with any extra columns generated by the crossmatch algorithm, also with the same length.
    These columns are specified in `AbstractCrossmatchAlgorithm.extra_columns`, with
    their respective data types, by means of an empty pandas dataframe. As an example,
    the KdTreeCrossmatch algorithm outputs a "_dist_arcsec" column with the distance between
    data points. Its extra_columns attribute is specified as follows::

        pd.DataFrame({"_dist_arcsec": pd.Series(dtype=np.dtype("float64"))})

    The `crossmatch`/`perform_crossmatch` methods will receive an instance of `CrossmatchArgs`
    which includes the partitions and respective pixel information::

        - left_df: npd.NestedFrame
        - right_df: npd.NestedFrame
        - left_order: int
        - left_pixel: int
        - right_order: int
        - right_pixel: int
        - left_catalog_info: hc.catalog.TableProperties
        - right_catalog_info: hc.catalog.TableProperties
        - right_margin_catalog_info: hc.catalog.TableProperties

    Include any algorithm-specific parameters in the initialization of your object.
    These parameters should be validated in `AbstractCrossmatchAlgorithm.validate`,
    by overwriting the method.
    """

    extra_columns: pd.DataFrame | None = None
    """The metadata for the columns generated by the crossmatch algorithm"""

    def crossmatch(
        self,
        crossmatch_args: CrossmatchArgs,
        how: str,
        suffixes: tuple[str, str],
        suffix_method: str = "all_columns",
    ) -> npd.NestedFrame:
        """Perform a crossmatch.

        Parameters
        ----------
        crossmatch_args : CrossmatchArgs
            The partitions and respective pixel information.
        how : str
            One of {'inner', 'left'}
        suffixes : tuple[str,str]
            A pair of suffixes to be appended to the end of each column name, with the first
            appended to the left columns and the second to the right columns.
        suffix_method : str, default 'all_columns'
            The suffix method to use.

        Returns
        -------
        npd.NestedFrame
            The dataframe containing the results of the crossmatch.

        Raises
        ------
        ValueError
            If `perform_crossmatch` returns indices and extra columns of different lengths.
        """
        right_df = crossmatch_args.right_df
        if right_df is None or len(right_df) == 0:
            # If there's no right data, there are no matches by construction
            # (e.g., for left-join with no matching right partition), so the
            # algorithm is not invoked at all.
            l_inds = np.array([], dtype=np.int64)
            r_inds = np.array([], dtype=np.int64)
            # Copy so the class-level column specification is never altered downstream.
            extra_cols = self.extra_columns.copy() if self.extra_columns is not None else pd.DataFrame()
        else:
            l_inds, r_inds, extra_cols = self.perform_crossmatch(crossmatch_args)
            self._check_match_lengths(l_inds, r_inds, extra_cols)
        return self._create_crossmatch_df(
            crossmatch_args.left_df,
            right_df,
            l_inds,
            r_inds,
            extra_cols,
            how,
            suffixes,
            suffix_method,
        )

    def crossmatch_nested(
        self, crossmatch_args: CrossmatchArgs, nested_column_name: str, how: str
    ) -> npd.NestedFrame:
        """Perform a crossmatch and store results in nested column.

        Parameters
        ----------
        crossmatch_args : CrossmatchArgs
            The partitions and respective pixel information.
        nested_column_name : str
            The name of the column where the matches should be stored.
        how : str
            How to handle the crossmatch of the two catalogs.
            One of {'left', 'inner'}.

        Returns
        -------
        npd.NestedFrame
            The dataframe containing the results of the crossmatch.

        Raises
        ------
        ValueError
            If `perform_crossmatch` returns indices and extra columns of different lengths.
        """
        l_inds, r_inds, extra_cols = self.perform_crossmatch(crossmatch_args)
        self._check_match_lengths(l_inds, r_inds, extra_cols)
        return self._create_nested_crossmatch_df(
            crossmatch_args.left_df,
            crossmatch_args.right_df,
            l_inds,
            r_inds,
            extra_cols,
            nested_column_name,
            how,
        )

    def perform_crossmatch(
        self, crossmatch_args: CrossmatchArgs
    ) -> tuple[np.ndarray, np.ndarray, pd.DataFrame]:
        """Performs a crossmatch to get the indices of the matching rows and any extra columns

        Parameters
        ----------
        crossmatch_args : CrossmatchArgs
            The partitions and respective pixel information.

        Returns
        -------
        tuple[np.ndarray, np.ndarray, pd.DataFrame]
            - a numpy array with the indices of the matching rows from the left table
            - a numpy array with the indices of the matching rows from the right table
            - a pandas dataframe with any additional columns generated by the algorithm

            These all must have the same lengths.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses provide the implementation.
        """
        raise NotImplementedError(
            "Crossmatch algorithm must either implement `perform_crossmatch` or overwrite `crossmatch`"
        )

    # pylint: disable=protected-access
    def validate(self, left: Catalog, right: Catalog):
        """Validate the metadata and arguments.

        This method will be called **once**, after the algorithm object has
        been initialized, during the lazy construction of the execution graph.
        This can be used to catch simple errors without waiting for an
        expensive ``.compute()`` call.

        This should validate any parameters used to initialize the crossmatching algorithm.

        Parameters
        ----------
        left: Catalog
            The left catalog for the crossmatch.
        right: Catalog
            The right catalog for the crossmatch.

        Raises
        ------
        ValueError
            If either catalog lacks the RA or DEC column named in its catalog info.
        """
        # Check that we have the appropriate columns in our dataset.
        # Order matters for the reported error: left RA, left DEC, right RA, right DEC.
        for side, catalog in (("left", left), ("right", right)):
            column_names = catalog._ddf.columns
            catalog_info = catalog.hc_structure.catalog_info
            for required in (catalog_info.ra_column, catalog_info.dec_column):
                if required not in column_names:
                    raise ValueError(f"{side} table must have column {required}")

    @staticmethod
    def _check_match_lengths(l_inds, r_inds, extra_cols) -> None:
        """Raise ``ValueError`` unless indices and extra columns all have the same length."""
        if not len(l_inds) == len(r_inds) == len(extra_cols):
            raise ValueError(
                "Crossmatch algorithm must return left and right indices and extra columns with same length"
            )

    @staticmethod
    def _rename_columns_with_suffix(dataframe, suffix):
        """Append ``suffix`` to every column name of ``dataframe``, modifying it in place."""
        columns_renamed = {name: name + suffix for name in dataframe.columns}
        dataframe.rename(columns=columns_renamed, inplace=True)

    @classmethod
    def _append_extra_columns(cls, dataframe: npd.NestedFrame, extra_columns: pd.DataFrame | None = None):
        """Adds crossmatch extra columns to the resulting Dataframe.

        The columns are written into ``dataframe`` in place, positionally aligned
        to its index. ``extra_columns`` itself is left untouched. Nothing happens
        when the class declares no `extra_columns` specification.

        Parameters
        ----------
        dataframe : npd.NestedFrame
            The crossmatch result to which the columns are added.
        extra_columns : pd.DataFrame, optional
            The values of the extra columns, one row per row of ``dataframe``.

        Raises
        ------
        ValueError
            If values are missing, contain a column that is not in the
            specification, lack a specified column, or have a dtype that differs
            from the specification.
        """
        specification = cls.extra_columns
        if specification is None:
            return
        if extra_columns is None:
            raise ValueError("No extra column values were provided")
        # Check if the provided columns are in the specification
        for col in extra_columns.columns:
            if col not in specification.columns:
                raise ValueError(f"Provided extra column '{col}' not found in definition")
        # Validate every specified column before touching `dataframe`, so a
        # failure never leaves it partially updated.
        for col, col_type in specification.dtypes.items():
            if col not in extra_columns:
                raise ValueError(f"Missing extra column '{col}' of type {col_type}")
            if col_type != extra_columns[col].dtype:
                raise ValueError(f"Invalid type '{col_type}' for extra column '{col}'")
        for col in specification.columns:
            # `set_axis` returns a re-labelled Series instead of overwriting the
            # index of the Series held by the caller's frame.
            dataframe[col] = extra_columns[col].set_axis(dataframe.index)

    @staticmethod
    def _pad_extra_columns(extra_cols: pd.DataFrame, n_rows: int) -> pd.DataFrame:
        """Return ``extra_cols`` with a fresh integer index, extended with NA rows up to ``n_rows``.

        Used for left-joins, where ``extra_cols`` only has rows for the matched
        pairs while the output also contains the unmatched left rows.
        """
        n_extra = len(extra_cols)
        assert n_extra <= n_rows, (
            f"Logic error: extra_cols ({n_extra} rows) cannot exceed output ({n_rows} rows). "
            f"This indicates a bug in the crossmatch algorithm or the join logic."
        )
        padded = extra_cols.reset_index(drop=True)
        if n_extra == n_rows:
            return padded
        # n_unmatched = number of output rows that correspond to left rows without a match
        n_unmatched = n_rows - n_extra
        # The explicit index keeps the tail `n_unmatched` rows long even when there
        # are no extra columns at all; otherwise the padded frame would stay too short.
        tail = pd.DataFrame(
            {c: _na_series_for_dtype(extra_cols[c].dtype, n_unmatched) for c in extra_cols.columns},
            index=pd.RangeIndex(n_unmatched),
        )
        return pd.concat([padded, tail], ignore_index=True)

    # pylint: disable=too-many-locals
    def _create_crossmatch_df(
        self,
        left_df: npd.NestedFrame,
        right_df: npd.NestedFrame,
        left_idx: npt.NDArray[np.int64],
        right_idx: npt.NDArray[np.int64],
        extra_cols: pd.DataFrame,
        how: str,
        suffixes: tuple[str, str],
        suffix_method: str = "all_columns",
    ) -> npd.NestedFrame:
        """Creates a df containing the crossmatch result from matching indices and additional columns

        Parameters
        ----------
        left_df : npd.NestedFrame
            The left partition.
        right_df : npd.NestedFrame
            The right partition.
        left_idx : np.ndarray
            indices of the matching rows from the left table
        right_idx : np.ndarray
            indices of the matching rows from the right table
        extra_cols : pd.DataFrame
            dataframe containing additional columns from crossmatching
        how : str
            'left' keeps the left rows without a match, with NA values in the right
            and extra columns. Any other value yields only the matched pairs
            (the 'inner' result).
        suffixes : tuple[str,str]
            A pair of suffixes to be appended to the end of each column name, with the first
            appended to the left columns and the second to the right columns.
        suffix_method : str, default 'all_columns'
            The suffix method to use.

        Returns
        -------
        npd.NestedFrame
            A dataframe with the matching rows from the left and right table
            concatenated together, with the additional columns added.
        """
        # rename columns so no same names during merging
        left_df, right_df = apply_suffixes(left_df, right_df, suffixes, suffix_method, log_changes=False)
        index_name = left_df.index.name if left_df.index.name is not None else "index"

        # Matched rows: one output row per (left, right) pair. The left index is
        # moved into a column so it survives the positional concatenation.
        left_matched = left_df.iloc[left_idx].reset_index()
        right_matched = right_df.iloc[right_idx].reset_index(drop=True)
        out = pd.concat([left_matched, right_matched], axis=1)

        if how == "left":
            # Unmatched left rows: keep each left row once, with NA values for right columns.
            # Work with positions (not index values) to handle non-unique indices.
            matched_mask = np.zeros(len(left_df), dtype=bool)
            matched_mask[left_idx] = True
            left_unmatched = left_df.iloc[~matched_mask].reset_index()
            # Build right-side columns (same names and dtypes as the right partition) filled with NA.
            unmatched_right = pd.DataFrame(
                {
                    col: _na_series_for_dtype(right_df[col].dtype, len(left_unmatched))
                    for col in right_df.columns
                }
            )
            unmatched_out = pd.concat([left_unmatched, unmatched_right], axis=1)
            # Combine matched (one row per match) and unmatched (one row per non-match)
            out = pd.concat([out, unmatched_out], axis=0, ignore_index=True)
            # extra_cols only covers the matched pairs; extend it with NAs so it
            # lines up with the unmatched rows appended above.
            extra_cols = self._pad_extra_columns(extra_cols, len(out))

        # Restore the original left index as the DataFrame index
        out.set_index(index_name, inplace=True)
        # Align on the output index without modifying the caller's dataframe.
        extra_cols = extra_cols.set_axis(out.index, axis=0)
        self._append_extra_columns(out, extra_cols)
        return npd.NestedFrame(out)

    def _create_nested_crossmatch_df(
        self,
        left_df: npd.NestedFrame,
        right_df: npd.NestedFrame,
        left_idx: npt.NDArray[np.int64],
        right_idx: npt.NDArray[np.int64],
        extra_cols: pd.DataFrame,
        nested_column_name: str,
        how: str,
    ) -> npd.NestedFrame:
        """Creates a df containing the crossmatch result from matching indices and additional columns

        Parameters
        ----------
        left_df : npd.NestedFrame
            The left partition.
        right_df : npd.NestedFrame
            The right partition.
        left_idx : np.ndarray
            Indices of the matching rows from the left table
        right_idx : np.ndarray
            Indices of the matching rows from the right table
        extra_cols : pd.DataFrame
            Dataframe containing additional columns from crossmatching
        nested_column_name : str
            The name of the column where the matches should be stored.
        how : str
            How to handle the crossmatch of the two catalogs.
            One of {'left', 'inner'}.

        Returns
        -------
        npd.NestedFrame
            The left rows, with the matching right rows and the additional columns
            stored in the nested column `nested_column_name`.
        """
        index_name = left_df.index.name if left_df.index.name is not None else "index"
        # After reset_index the left frame is indexed by row position, which is
        # what the matched right rows are keyed on below.
        left_join_part = left_df.reset_index()
        right_join_part = right_df.iloc[right_idx].copy()
        # Key every matched right row by the position of its left counterpart.
        # Assigning the index directly (rather than via a temporary column) cannot
        # clobber a right column that happens to be named "new_index_col".
        right_join_part.index = pd.Index(left_idx, name="new_index_col")
        self._append_extra_columns(right_join_part, extra_cols)
        out = left_join_part.join_nested(right_join_part, nested_column_name, how=how)
        out.set_index(index_name, inplace=True)
        return out