Source code for lsdb.core.crossmatch.kdtree_match

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import numpy.typing as npt
import pandas as pd
import pyarrow as pa
from hats.pixel_math.validators import validate_radius

from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm
from lsdb.core.crossmatch.crossmatch_args import CrossmatchArgs
from lsdb.core.crossmatch.kdtree_utils import _find_crossmatch_indices, _get_chord_distance, _lon_lat_to_xyz

if TYPE_CHECKING:
    from lsdb.catalog import Catalog


class KdTreeCrossmatch(AbstractCrossmatchAlgorithm):
    """Nearest neighbor crossmatch using a 3D k-D tree"""

    # Empty, typed template of the extra column this algorithm emits; it has the same
    # column name and dtype as the frame built in `perform_crossmatch`.
    extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(dtype=pd.ArrowDtype(pa.float64()))})

    def __init__(self, n_neighbors: int = 1, radius_arcsec: float = 1.0, min_radius_arcsec: float = 0.0):
        """Initialize the KDTree crossmatch algorithm.

        The arguments are only stored here; they are checked in `validate`.

        Parameters
        ----------
        n_neighbors : int
            The number of neighbors to find within each point. Must be at least 1.
        radius_arcsec : float, default 1.0
            The threshold distance in arcseconds beyond which neighbors are not added.
        min_radius_arcsec : float, default 0.0
            The threshold distance in arcseconds beyond which neighbors are added.
            Must be non-negative and strictly smaller than `radius_arcsec`.
        """
        super().__init__()
        self.n_neighbors = n_neighbors
        self.radius_arcsec = radius_arcsec
        self.min_radius_arcsec = min_radius_arcsec

    def validate(self, left: Catalog, right: Catalog):
        """Validate the arguments for the crossmatch

        Runs the parent class validation first, then checks the k-D tree specific
        arguments stored on this instance.

        Parameters
        ----------
        left : Catalog
            The left catalog of the crossmatch.
        right : Catalog
            The right catalog of the crossmatch. When it has a margin, the margin
            threshold must not be smaller than `radius_arcsec`.

        Raises
        ------
        ValueError
            If `n_neighbors` is smaller than 1, if the crossmatch radius exceeds the
            margin threshold of the right catalog, if `min_radius_arcsec` is negative,
            or if `radius_arcsec` is not greater than `min_radius_arcsec`.
            The radius itself is checked by `validate_radius`.
        """
        super().validate(left, right)
        validate_radius(self.radius_arcsec)
        if self.n_neighbors < 1:
            # The check accepts n_neighbors == 1 (the default), so the message says so.
            raise ValueError("n_neighbors must be at least 1")
        if (
            right.margin is not None
            and right.margin.hc_structure.catalog_info.margin_threshold < self.radius_arcsec
        ):
            raise ValueError("Cross match radius is greater than margin threshold")
        if self.min_radius_arcsec < 0:
            raise ValueError("The minimum radius must be non-negative")
        if self.radius_arcsec <= self.min_radius_arcsec:
            raise ValueError("Cross match maximum radius must be greater than cross match minimum radius")

    def perform_crossmatch(
        self, crossmatch_args: CrossmatchArgs
    ) -> tuple[np.ndarray, np.ndarray, pd.DataFrame]:
        """Perform a cross-match between the data from two HEALPix pixels

        Finds the n closest neighbors in the right catalog for each point in the left
        catalog that are within a threshold distance by using a K-D Tree.

        Parameters
        ----------
        crossmatch_args : CrossmatchArgs
            The partitions and respective pixel information.

        Returns
        -------
        tuple[np.ndarray, np.ndarray, pd.DataFrame]
            - a numpy array with the indices of the matching rows from the left table
            - a numpy array with the indices of the matching rows from the right table
            - a pandas dataframe with any additional columns generated by the algorithm;
              here a single `_dist_arcsec` column with the match distance in arcseconds

            These all must have the same lengths.
        """
        # Distance in 3-D space for unit sphere
        max_d_chord = _get_chord_distance(self.radius_arcsec)
        min_d_chord = _get_chord_distance(self.min_radius_arcsec)

        # calculate the cartesian coordinates of the points
        left_xyz, right_xyz = self._get_point_coordinates(
            crossmatch_args.left_df,
            crossmatch_args.left_catalog_info,
            crossmatch_args.right_df,
            crossmatch_args.right_catalog_info,
        )

        # get matching indices for cross-matched rows
        chord_distances, left_idx, right_idx = _find_crossmatch_indices(
            left_xyz,
            right_xyz,
            n_neighbors=self.n_neighbors,
            min_distance=min_d_chord,
            max_distance=max_d_chord,
        )

        # On the unit sphere chord = 2 * sin(angle / 2); invert that to get the angle
        # in radians, then convert radians -> degrees -> arcseconds.
        arc_distances = np.degrees(2.0 * np.arcsin(0.5 * chord_distances)) * 3600
        extra_columns = pd.DataFrame(
            {"_dist_arcsec": pd.Series(arc_distances, dtype=pd.ArrowDtype(pa.float64()))}
        )
        return left_idx, right_idx, extra_columns

    def _get_point_coordinates(
        self, left_df, left_catalog_info, right_df, right_catalog_info
    ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
        """Compute the cartesian coordinates of the points of both partitions

        Reads the columns named by `ra_column` and `dec_column` of each catalog info
        from the matching dataframe and passes them to `_lon_lat_to_xyz`.

        Parameters
        ----------
        left_df, right_df
            The left and right partition dataframes.
        left_catalog_info, right_catalog_info
            Objects exposing the `ra_column` and `dec_column` names for each partition.

        Returns
        -------
        tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]
            The coordinates of the left points and of the right points, in that order.
        """
        left_xyz = _lon_lat_to_xyz(
            lon=left_df[left_catalog_info.ra_column].to_numpy(),
            lat=left_df[left_catalog_info.dec_column].to_numpy(),
        )
        right_xyz = _lon_lat_to_xyz(
            lon=right_df[right_catalog_info.ra_column].to_numpy(),
            lat=right_df[right_catalog_info.dec_column].to_numpy(),
        )
        return left_xyz, right_xyz