Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ env:
jobs:
test-matrix:
runs-on: ubuntu-latest
# Backstop: a hung multiprocessing worker (e.g. during a pickle regression)
# should not block CI longer than this.
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
Expand Down
754 changes: 723 additions & 31 deletions crates/core/src/codec.rs

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions crates/core/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1407,6 +1407,22 @@ impl PySessionContext {
physical_codec,
})
}

/// Return a copy of this session context whose logical and physical
/// codecs have Python UDF inlining toggled to `enabled`.
///
/// The underlying `SessionContext` is shared (`Arc`-cloned); only the
/// two codec wrappers are rebuilt around the existing inner codecs.
pub fn with_python_udf_inlining(&self, enabled: bool) -> Self {
    let inner_logical = Arc::clone(self.logical_codec.inner());
    let inner_physical = Arc::clone(self.physical_codec.inner());
    Self {
        ctx: Arc::clone(&self.ctx),
        logical_codec: Arc::new(
            PythonLogicalCodec::new(inner_logical).with_python_udf_inlining(enabled),
        ),
        physical_codec: Arc::new(
            PythonPhysicalCodec::new(inner_physical).with_python_udf_inlining(enabled),
        ),
    }
}
}

impl PySessionContext {
Expand Down
186 changes: 170 additions & 16 deletions crates/core/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@
// specific language governing permissions and limitations
// under the License.

use std::any::Any;
use std::ptr::NonNull;
use std::sync::Arc;

use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::datatypes::{DataType, Field, FieldRef};
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::common::ScalarValue;
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::{
Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, create_udaf,
Accumulator, AggregateUDF, AggregateUDFImpl, Signature, Volatility,
};
use datafusion_ffi::udaf::FFI_AggregateUDF;
use datafusion_python_util::parse_volatility;
Expand Down Expand Up @@ -144,15 +146,161 @@ impl Accumulator for RustAccumulator {
}
}

pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
Arc::new(move |_args| -> Result<Box<dyn Accumulator>> {
let accum = Python::attach(|py| {
accum
.call0(py)
.map_err(|e| DataFusionError::Execution(format!("{e}")))
})?;
Ok(Box::new(RustAccumulator::new(accum)))
})
/// Create a fresh accumulator by invoking the stored Python factory.
///
/// Acquires the GIL, calls the factory with no arguments, and wraps the
/// resulting Python object in a `RustAccumulator`. A raised Python
/// exception is surfaced as `DataFusionError::Execution`.
fn instantiate_accumulator(accum: &Py<PyAny>) -> Result<Box<dyn Accumulator>> {
    match Python::attach(|py| accum.call0(py)) {
        Ok(instance) => Ok(Box::new(RustAccumulator::new(instance))),
        Err(e) => Err(DataFusionError::Execution(format!("{e}"))),
    }
}

/// Named-struct `AggregateUDFImpl` for Python-defined aggregate UDFs.
/// Holds the Python accumulator factory directly so the codec can
/// downcast and cloudpickle it across process boundaries.
#[derive(Debug)]
pub(crate) struct PythonFunctionAggregateUDF {
    // UDF name as registered with DataFusion.
    name: String,
    // Python callable that produces a fresh accumulator instance each
    // time it is invoked (called with no arguments under the GIL).
    accumulator: Py<PyAny>,
    // Exact-match signature built from the input types + volatility.
    signature: Signature,
    // Data type of the final aggregate value.
    return_type: DataType,
    // Schema of the intermediate (partial-aggregation) state.
    state_fields: Vec<FieldRef>,
}

impl PythonFunctionAggregateUDF {
    /// Build a UDF from bare `Vec<DataType>` state types (the Python
    /// constructor path). State field names are synthesized as
    /// `state_{i}` and every state field is marked nullable.
    fn new(
        name: String,
        accumulator: Py<PyAny>,
        input_types: Vec<DataType>,
        return_type: DataType,
        state_types: Vec<DataType>,
        volatility: Volatility,
    ) -> Self {
        let state_fields = state_types
            .into_iter()
            .enumerate()
            .map(|(i, t)| Arc::new(Field::new(format!("state_{i}"), t, true)))
            .collect();
        // Delegate to `from_parts` so the `Signature` construction
        // lives in exactly one place instead of being duplicated here.
        Self::from_parts(
            name,
            accumulator,
            input_types,
            return_type,
            state_fields,
            volatility,
        )
    }

    /// Stored Python callable that returns a fresh accumulator instance
    /// per partition. Consumed by the codec to cloudpickle the factory
    /// across process boundaries.
    pub(crate) fn accumulator(&self) -> &Py<PyAny> {
        &self.accumulator
    }

    /// Data type of the final aggregate value.
    pub(crate) fn return_type(&self) -> &DataType {
        &self.return_type
    }

    /// Full state schema (names, data types, nullability, metadata).
    pub(crate) fn state_fields_ref(&self) -> &[FieldRef] {
        &self.state_fields
    }

    /// Reconstruct a `PythonFunctionAggregateUDF` from the parts emitted
    /// by the codec. `state_fields` carries the full state schema
    /// (names, data types, nullability, metadata) — the codec extracts
    /// it from the IPC payload, so the post-decode state schema is
    /// identical to the pre-encode one. Use [`Self::new`] when only
    /// `Vec<DataType>` is available (e.g. the Python constructor path,
    /// where field names are synthesized).
    pub(crate) fn from_parts(
        name: String,
        accumulator: Py<PyAny>,
        input_types: Vec<DataType>,
        return_type: DataType,
        state_fields: Vec<FieldRef>,
        volatility: Volatility,
    ) -> Self {
        Self {
            name,
            accumulator,
            signature: Signature::exact(input_types, volatility),
            return_type,
            state_fields,
        }
    }
}

impl Eq for PythonFunctionAggregateUDF {}
impl PartialEq for PythonFunctionAggregateUDF {
    fn eq(&self, other: &Self) -> bool {
        // Cheap header comparison first; only candidates that agree on
        // name, signature, return type and state schema are worth a
        // GIL roundtrip.
        if self.name != other.name
            || self.signature != other.signature
            || self.return_type != other.return_type
            || self.state_fields != other.state_fields
        {
            return false;
        }
        // Pointer-identity fast path: `Arc`-shared clones of the same
        // UDF point at the same Python object and skip the GIL.
        if self.accumulator.as_ptr() == other.accumulator.as_ptr() {
            return true;
        }
        // Two distinct callables: defer to Python `__eq__`. See
        // `PythonFunctionScalarUDF::eq` for the rationale on swallowing
        // the exception as `false` and logging at `debug`. FIXME:
        // revisit if upstream `AggregateUDFImpl` exposes a fallible
        // `PartialEq`.
        Python::attach(|py| {
            match self.accumulator.bind(py).eq(other.accumulator.bind(py)) {
                Ok(equal) => equal,
                Err(e) => {
                    log::debug!(
                        target: "datafusion_python::udaf",
                        "PythonFunctionAggregateUDF {:?} __eq__ raised; treating as unequal: {e}",
                        self.name,
                    );
                    false
                }
            }
        })
    }
}

impl std::hash::Hash for PythonFunctionAggregateUDF {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // See `PythonFunctionScalarUDF`'s `Hash` impl for the
        // rationale: hash the identifying header only and let
        // `PartialEq` disambiguate callables. The Python object is
        // deliberately not hashed here.
        self.name.hash(state);
        self.signature.hash(state);
        self.return_type.hash(state);
        // Hash each state field in order, element-wise (no length
        // prefix), matching the header-only scheme above.
        self.state_fields.iter().for_each(|field| field.hash(state));
    }
}

impl AggregateUDFImpl for PythonFunctionAggregateUDF {
    // Enables downcasting back to the concrete type; the codec relies
    // on this to reach the stored Python factory (see the struct docs).
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    // The return type is fixed at construction time; the argument types
    // passed by the planner are ignored.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(self.return_type.clone())
    }

    // Each call creates a fresh Python accumulator instance via the
    // stored factory (GIL acquired inside `instantiate_accumulator`).
    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        instantiate_accumulator(&self.accumulator)
    }

    // State schema is fixed at construction time; `_args` is ignored.
    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        Ok(self.state_fields.clone())
    }
}

fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
Expand Down Expand Up @@ -190,14 +338,15 @@ impl PyAggregateUDF {
state_type: PyArrowType<Vec<DataType>>,
volatility: &str,
) -> PyResult<Self> {
let function = create_udaf(
name,
let py_udf = PythonFunctionAggregateUDF::new(
name.to_string(),
accumulator,
input_type.0,
Arc::new(return_type.0),
return_type.0,
state_type.0,
parse_volatility(volatility)?,
to_rust_accumulator(accumulator),
Arc::new(state_type.0),
);
let function = AggregateUDF::new_from_impl(py_udf);
Ok(Self { function })
}

Expand Down Expand Up @@ -231,4 +380,9 @@ impl PyAggregateUDF {
/// `repr()` shown in Python: `AggregateUDF(<name>)`.
fn __repr__(&self) -> PyResult<String> {
    let repr = format!("AggregateUDF({})", self.function.name());
    Ok(repr)
}

// Python-visible read-only `name` attribute, mirroring the underlying
// DataFusion UDF's registered name.
#[getter]
fn name(&self) -> &str {
    self.function.name()
}
}
82 changes: 74 additions & 8 deletions crates/core/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use crate::expr::PyExpr;
/// This struct holds the Python written function that is a
/// ScalarUDF.
#[derive(Debug)]
struct PythonFunctionScalarUDF {
pub(crate) struct PythonFunctionScalarUDF {
name: String,
func: Py<PyAny>,
signature: Signature,
Expand All @@ -67,6 +67,37 @@ impl PythonFunctionScalarUDF {
return_field: Arc::new(return_field),
}
}

/// Stored Python callable. Consumed by the codec to cloudpickle
/// the function body across process boundaries.
///
/// Returns a borrow of the stored handle; the caller clones it if
/// it needs ownership.
pub(crate) fn func(&self) -> &Py<PyAny> {
    &self.func
}

/// Declared return field of the UDF (carries nullability and
/// metadata in addition to the data type — see [`Self::from_parts`]).
pub(crate) fn return_field(&self) -> &FieldRef {
    &self.return_field
}

/// Reconstruct a `PythonFunctionScalarUDF` from the parts emitted
/// by the codec. Inputs collapse to `Vec<DataType>` because
/// `Signature::exact` cannot carry per-input nullability or
/// metadata — the encoder is free to discard that side of the
/// schema. `return_field` is kept as a `Field` so the post-decode
/// nullability and metadata match the sender's instance.
pub(crate) fn from_parts(
    name: String,
    func: Py<PyAny>,
    input_types: Vec<DataType>,
    return_field: Field,
    volatility: Volatility,
) -> Self {
    let signature = Signature::exact(input_types, volatility);
    let return_field = Arc::new(return_field);
    Self {
        name,
        func,
        signature,
        return_field,
    }
}
}

// Marker impl on top of `PartialEq`; equality ultimately defers to
// Python `__eq__` for distinct callables (see the `PartialEq` impl).
impl Eq for PythonFunctionScalarUDF {}
Expand All @@ -75,21 +106,51 @@ impl PartialEq for PythonFunctionScalarUDF {
self.name == other.name
&& self.signature == other.signature
&& self.return_field == other.return_field
&& Python::attach(|py| self.func.bind(py).eq(other.func.bind(py)).unwrap_or(false))
// Identical pointers ⇒ same Python object. Most equality
// checks compare `Arc`-shared clones of the same UDF
// (e.g. expression rewriting), so the pointer match short-
// circuits before touching the GIL.
&& (self.func.as_ptr() == other.func.as_ptr()
|| Python::attach(|py| {
// Rust's `PartialEq` cannot return `Result`, so we
// have to pick a side when Python `__eq__` raises.
// `false` is the conservative choice — better to
// report two UDFs as distinct than to wrongly
// merge them — but the silent miss can still
// surface as expression-dedup or cache-lookup
// anomalies. Log at `debug` so the failure is
// observable without flooding production logs.
// FIXME: revisit if upstream `ScalarUDFImpl`
// exposes a fallible `PartialEq`.
self.func
.bind(py)
.eq(other.func.bind(py))
.unwrap_or_else(|e| {
log::debug!(
target: "datafusion_python::udf",
"PythonFunctionScalarUDF {:?} __eq__ raised; treating as unequal: {e}",
self.name,
);
false
})
}))
}
}

impl Hash for PythonFunctionScalarUDF {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Hash only the identifying header (name + signature + return
        // field). Skipping `func` is intentional: the Rust `Hash`
        // contract requires `a == b ⇒ hash(a) == hash(b)`, not the
        // converse, so a coarser hash is sound — `PartialEq` still
        // disambiguates two UDFs with the same header but distinct
        // callables. Falling back to a sentinel on `py_hash` failure
        // (as a prior revision did) silently mapped every unhashable
        // closure to the same bucket; that is the worst case for a
        // hashmap and is what this rewrite avoids.
        self.name.hash(state);
        self.signature.hash(state);
        self.return_field.hash(state);

        // NOTE(review): the comment above states that `func` is NOT
        // hashed, yet the `Python::attach` block below still hashes the
        // callable with an `unwrap_or(0)` sentinel — exactly what the
        // comment says was removed. This looks like pre-rewrite diff
        // lines rendered without removal markers; confirm against the
        // actual file which side is intended, since comment and code
        // contradict each other as shown here.
        Python::attach(|py| {
            let py_hash = self.func.bind(py).hash().unwrap_or(0); // Handle unhashable objects

            state.write_isize(py_hash);
        });
    }
}

Expand Down Expand Up @@ -220,4 +281,9 @@ impl PyScalarUDF {
// `repr()` shown in Python: `ScalarUDF(<name>)`.
fn __repr__(&self) -> PyResult<String> {
    Ok(format!("ScalarUDF({})", self.function.name()))
}

// Python-visible read-only `name` attribute, mirroring the underlying
// DataFusion UDF's registered name.
#[getter]
fn name(&self) -> &str {
    self.function.name()
}
}
Loading
Loading