Source code for lsdb.core.crossmatch.kdtree_match
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
import numpy.typing as npt
import pandas as pd
import pyarrow as pa
from hats.pixel_math.validators import validate_radius
from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm
from lsdb.core.crossmatch.crossmatch_args import CrossmatchArgs
from lsdb.core.crossmatch.kdtree_utils import _find_crossmatch_indices, _get_chord_distance, _lon_lat_to_xyz
if TYPE_CHECKING:
from lsdb.catalog import Catalog
[docs]
class KdTreeCrossmatch(AbstractCrossmatchAlgorithm):
    """Nearest neighbor crossmatch using a 3D k-D tree.

    Sky positions are converted to cartesian coordinates so that angular separations can be
    searched as chord lengths in 3D space.
    """

    # Schema of the additional column(s) this algorithm appends to each matched pair.
    extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(dtype=pd.ArrowDtype(pa.float64()))})

    def __init__(self, n_neighbors: int = 1, radius_arcsec: float = 1.0, min_radius_arcsec: float = 0.0):
        """Initialize the KDTree crossmatch algorithm.

        Parameters
        ----------
        n_neighbors : int, default 1
            The maximum number of neighbors to find for each point in the left catalog.
        radius_arcsec : float, default 1.0
            The threshold distance in arcseconds beyond which neighbors are not added.
        min_radius_arcsec : float, default 0.0
            The threshold distance in arcseconds below which neighbors are not added.
        """
        super().__init__()
        self.n_neighbors = n_neighbors
        self.radius_arcsec = radius_arcsec
        self.min_radius_arcsec = min_radius_arcsec

    def validate(self, left: Catalog, right: Catalog):
        """Validate the arguments for the crossmatch.

        Parameters
        ----------
        left : Catalog
            The left catalog of the crossmatch.
        right : Catalog
            The right catalog of the crossmatch. If it has a margin, its margin threshold
            must not be smaller than ``radius_arcsec``.

        Raises
        ------
        ValueError
            If ``n_neighbors`` is less than 1, if the right catalog's margin threshold is
            smaller than ``radius_arcsec``, if ``min_radius_arcsec`` is negative, or if
            ``radius_arcsec`` is not strictly greater than ``min_radius_arcsec``. The base
            class validation and ``validate_radius`` may raise as well.
        """
        super().validate(left, right)
        validate_radius(self.radius_arcsec)
        # A single neighbor is valid, so only values below 1 are rejected.
        if self.n_neighbors < 1:
            raise ValueError("n_neighbors must be greater than or equal to 1")
        # The right margin must be at least as wide as the search radius (presumably so that
        # matches lying just across a pixel boundary are not missed).
        if (
            right.margin is not None
            and right.margin.hc_structure.catalog_info.margin_threshold < self.radius_arcsec
        ):
            raise ValueError("Cross match radius is greater than margin threshold")
        if self.min_radius_arcsec < 0:
            raise ValueError("The minimum radius must be non-negative")
        if self.radius_arcsec <= self.min_radius_arcsec:
            raise ValueError("Cross match maximum radius must be greater than cross match minimum radius")

    def perform_crossmatch(
        self, crossmatch_args: CrossmatchArgs
    ) -> tuple[np.ndarray, np.ndarray, pd.DataFrame]:
        """Perform a cross-match between the data from two HEALPix pixels.

        Finds the n closest neighbors in the right catalog for each point in the left catalog
        that are within the threshold distances, using a K-D Tree.

        Parameters
        ----------
        crossmatch_args : CrossmatchArgs
            The partitions and respective pixel information.

        Returns
        -------
        tuple[np.ndarray, np.ndarray, pd.DataFrame]
            - a numpy array with the indices of the matching rows from the left table
            - a numpy array with the indices of the matching rows from the right table
            - a pandas dataframe with any additional columns generated by the algorithm
              (here ``_dist_arcsec``, the angular separation of each pair in arcseconds)

            These all must have the same lengths.
        """
        # Convert the angular bounds to chord lengths in 3-D space for the unit sphere,
        # which is the metric the k-D tree searches in.
        max_d_chord = _get_chord_distance(self.radius_arcsec)
        min_d_chord = _get_chord_distance(self.min_radius_arcsec)

        # Calculate the cartesian coordinates of the points.
        left_xyz, right_xyz = self._get_point_coordinates(
            crossmatch_args.left_df,
            crossmatch_args.left_catalog_info,
            crossmatch_args.right_df,
            crossmatch_args.right_catalog_info,
        )

        # Get matching indices (and their chord distances) for cross-matched rows.
        chord_distances, left_idx, right_idx = _find_crossmatch_indices(
            left_xyz,
            right_xyz,
            n_neighbors=self.n_neighbors,
            min_distance=min_d_chord,
            max_distance=max_d_chord,
        )

        # Convert chord length c back to an angle: theta = 2 * arcsin(c / 2). Floating point
        # rounding can push c / 2 marginally above 1 for (nearly) antipodal points, and arcsin
        # would then return NaN, so the argument is clipped to arcsin's domain.
        half_chord = np.clip(0.5 * chord_distances, 0.0, 1.0)
        arc_distances = np.degrees(2.0 * np.arcsin(half_chord)) * 3600

        extra_columns = pd.DataFrame(
            {"_dist_arcsec": pd.Series(arc_distances, dtype=pd.ArrowDtype(pa.float64()))}
        )
        return left_idx, right_idx, extra_columns

    def _get_point_coordinates(
        self, left_df, left_catalog_info, right_df, right_catalog_info
    ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
        """Convert the RA/Dec columns of both partitions into cartesian coordinates.

        Parameters
        ----------
        left_df, right_df
            The partition data for the left and right catalogs, indexable by column name.
        left_catalog_info, right_catalog_info
            Catalog info objects providing the ``ra_column`` and ``dec_column`` names.

        Returns
        -------
        tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]
            The coordinates of the left and right points, as produced by ``_lon_lat_to_xyz``.
        """
        left_xyz = _lon_lat_to_xyz(
            lon=left_df[left_catalog_info.ra_column].to_numpy(),
            lat=left_df[left_catalog_info.dec_column].to_numpy(),
        )
        right_xyz = _lon_lat_to_xyz(
            lon=right_df[right_catalog_info.ra_column].to_numpy(),
            lat=right_df[right_catalog_info.dec_column].to_numpy(),
        )
        return left_xyz, right_xyz