diff --git a/crates/core/src/codec.rs b/crates/core/src/codec.rs index 088532df2..c95d8cb19 100644 --- a/crates/core/src/codec.rs +++ b/crates/core/src/codec.rs @@ -19,11 +19,11 @@ //! //! Datafusion-python plans can carry references to Python-defined //! objects that the upstream protobuf codecs do not know how to -//! serialize: pure-Python scalar / aggregate / window UDFs, Python -//! query-planning extensions, and so on. Their state lives inside -//! `Py` callables and closures rather than being recoverable -//! from a name in the receiver's function registry. To ship a plan -//! across a process boundary (pickle, `multiprocessing`, Ray actor, +//! serialize: pure-Python scalar UDFs, Python query-planning +//! extensions, and so on. Their state lives inside `Py` +//! callables and closures rather than being recoverable from a name +//! in the receiver's function registry. To ship a plan across a +//! process boundary (pickle, `multiprocessing`, Ray actor, //! `datafusion-distributed`, etc.) those payloads have to be encoded //! into the proto wire format itself. //! @@ -48,52 +48,121 @@ //! plans to survive a serialization round-trip. Both codecs share //! the same payload framing for that reason. //! -//! Payloads emitted by these codecs are tagged with an 8-byte magic -//! prefix so the decoder can distinguish them from arbitrary bytes -//! (empty `fun_definition` from the default codec, user FFI payloads -//! that picked a non-colliding prefix). Dispatch precedence on -//! decode: **Python-inline payload (magic prefix match) → `inner` -//! codec → caller's `FunctionRegistry` fallback.** +//! Payloads emitted by these codecs are framed as +//! ` `. The +//! family magic identifies the UDF flavor; the version byte lets the +//! decoder reject too-new or too-old payloads with a clean error +//! instead of falling into an opaque `cloudpickle` tuple-unpack +//! failure when the tuple shape changes. Dispatch precedence on +//! decode: **family match + supported version → `inner` codec → +//! caller's `FunctionRegistry` fallback.** //! -//! ## Wire-format magic prefix registry +//! ## Wire-format family registry //! -//! | Layer + kind | Magic prefix | -//! | ----------------------------- | ------------ | -//! | `PythonLogicalCodec` scalar | `DFPYUDF1` | -//! | `PythonLogicalCodec` agg | `DFPYUDA1` | -//! | `PythonLogicalCodec` window | `DFPYUDW1` | -//! | `PythonPhysicalCodec` scalar | `DFPYUDF1` | -//! | `PythonPhysicalCodec` agg | `DFPYUDA1` | -//! | `PythonPhysicalCodec` window | `DFPYUDW1` | -//! | `PythonPhysicalCodec` expr | `DFPYPE1` | -//! | User FFI extension codec | user-chosen | -//! | Default codec | (none) | +//! | Layer + kind | Family prefix | +//! | ----------------------------- | ------------- | +//! | `PythonLogicalCodec` scalar | `DFPYUDF` | +//! | `PythonPhysicalCodec` scalar | `DFPYUDF` | +//! | User FFI extension codec | user-chosen | +//! | Default codec | (none) | //! -//! Downstream FFI codecs should pick non-colliding prefixes (use a -//! `DF` namespace plus a crate-specific suffix). The codec +//! Aggregate and window UDF families are reserved for follow-on work. +//! +//! Current wire-format version is [`WIRE_VERSION_CURRENT`]; supported +//! receive range is `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`. +//! Bump [`WIRE_VERSION_CURRENT`] whenever the cloudpickle tuple shape +//! changes; raise [`WIRE_VERSION_MIN_SUPPORTED`] when dropping support +//! for an older shape. +//! +//! Downstream FFI codecs should pick non-colliding family prefixes +//! 
(use a `DF` namespace plus a crate-specific suffix). The codec
 //! implementations in this module currently delegate every method to
 //! `inner`; the encoder/decoder hooks for each kind are added as the
 //! corresponding Python-side type becomes serializable.
 
 use std::sync::Arc;
 
-use arrow::datatypes::SchemaRef;
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use arrow::ipc::reader::StreamReader;
+use arrow::ipc::writer::StreamWriter;
 use datafusion::common::{Result, TableReference};
 use datafusion::datasource::TableProvider;
 use datafusion::datasource::file_format::FileFormatFactory;
 use datafusion::execution::TaskContext;
-use datafusion::logical_expr::{AggregateUDF, Extension, LogicalPlan, ScalarUDF, WindowUDF};
+use datafusion::logical_expr::{
+    AggregateUDF, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
+    Volatility, WindowUDF,
+};
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion_proto::logical_plan::{DefaultLogicalExtensionCodec, LogicalExtensionCodec};
 use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec};
+use pyo3::prelude::*;
+use pyo3::sync::PyOnceLock;
+use pyo3::types::{PyBytes, PyTuple};
 
-/// Wire-format prefix that tags a `fun_definition` payload as an
-/// inlined Python scalar UDF (cloudpickled tuple of name, callable,
-/// input schema, return field, volatility). Defined once here so
-/// the encoder and decoder cannot drift.
-#[allow(dead_code)]
-pub(crate) const PY_SCALAR_UDF_MAGIC: &[u8] = b"DFPYUDF1";
+use crate::udf::PythonFunctionScalarUDF;
+
+// Wire-format framing for inlined Python UDF payloads.
+//
+// Layout: `<family magic><version byte><cloudpickle payload>`.
+// The family magic identifies the UDF flavor; the version byte lets
+// the decoder reject too-new or too-old payloads with a clean error
+// instead of falling into an opaque `cloudpickle` tuple-unpack failure
+// when the tuple shape changes. Bump [`WIRE_VERSION_CURRENT`] whenever
+// the tuple shape changes; raise [`WIRE_VERSION_MIN_SUPPORTED`] when
+// dropping support for an older shape.
+
+/// Family prefix for an inlined Python scalar UDF
+/// (cloudpickled tuple of name, callable, input schema, return field,
+/// volatility).
+pub(crate) const PY_SCALAR_UDF_FAMILY: &[u8] = b"DFPYUDF";
+
+/// Wire-format version this build emits.
+pub(crate) const WIRE_VERSION_CURRENT: u8 = 1;
+
+/// Oldest wire-format version this build still decodes. Bump when
+/// retiring support for an older payload shape.
+pub(crate) const WIRE_VERSION_MIN_SUPPORTED: u8 = 1;
+
+/// Tag `buf` with the framing header for `family` at the current
+/// wire-format version. Append-only — the caller writes the
+/// cloudpickle payload after.
+fn write_wire_header(buf: &mut Vec<u8>, family: &[u8]) {
+    buf.extend_from_slice(family);
+    buf.push(WIRE_VERSION_CURRENT);
+}
+
+/// Inspect the framing on `buf`.
+///
+/// * `Ok(None)` — `buf` does not carry `family`. The caller should
+///   delegate to its `inner` codec.
+/// * `Ok(Some(payload))` — `buf` carries `family` at a version this
+///   build accepts; `payload` is the cloudpickle blob.
+/// * `Err(_)` — `buf` carries `family` but at a version outside
+///   `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`. The error
+///   names the version and the supported range so an operator can
+///   diagnose sender/receiver version drift instead of seeing an
+///   opaque cloudpickle tuple-unpack failure.
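+///
+/// For example, a version-1 scalar-UDF payload is framed as the seven
+/// bytes `b"DFPYUDF"`, then the single version byte `1`, then the
+/// cloudpickle blob; on a family match this helper hands back just
+/// the blob slice.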
+fn strip_wire_header<'a>(buf: &'a [u8], family: &[u8], kind: &str) -> Result<Option<&'a [u8]>> {
+    if !buf.starts_with(family) {
+        return Ok(None);
+    }
+    let version_idx = family.len();
+    let Some(&version) = buf.get(version_idx) else {
+        return Err(datafusion::error::DataFusionError::Execution(format!(
+            "Truncated inline Python {kind} payload: missing wire-format version byte"
+        )));
+    };
+    if !(WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT).contains(&version) {
+        return Err(datafusion::error::DataFusionError::Execution(format!(
+            "Inline Python {kind} payload wire-format version v{version}; \
+             this build supports v{WIRE_VERSION_MIN_SUPPORTED}..=v{WIRE_VERSION_CURRENT}. \
+             Align datafusion-python versions on sender and receiver."
+        )));
+    }
+    Ok(Some(&buf[version_idx + 1..]))
+}
 
 /// `LogicalExtensionCodec` parked on every `SessionContext`. Holds
 /// the Python-aware encoding hooks for logical-layer types
@@ -177,10 +246,16 @@ impl LogicalExtensionCodec for PythonLogicalCodec {
     }
 
     fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> {
+        if try_encode_python_scalar_udf(node, buf)? {
+            return Ok(());
+        }
         self.inner.try_encode_udf(node, buf)
     }
 
     fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> {
+        if let Some(udf) = try_decode_python_scalar_udf(buf)? {
+            return Ok(udf);
+        }
         self.inner.try_decode_udf(name, buf)
     }
 
@@ -212,7 +287,7 @@ impl LogicalExtensionCodec for PythonLogicalCodec {
 /// encoding on this layer too — otherwise a plan with a Python UDF
 /// would round-trip at the logical level but break at the physical
 /// level. Both layers reuse the shared payload framing
-/// ([`PY_SCALAR_UDF_MAGIC`] et al.) so the wire format is identical.
+/// ([`PY_SCALAR_UDF_FAMILY`]) so the wire format is identical.
 #[derive(Debug)]
 pub struct PythonPhysicalCodec {
     inner: Arc<dyn PhysicalExtensionCodec>,
 }
@@ -249,10 +324,16 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec {
     }
 
     fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> {
+        if try_encode_python_scalar_udf(node, buf)? {
+            return Ok(());
+        }
         self.inner.try_encode_udf(node, buf)
     }
 
     fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> {
+        if let Some(udf) = try_decode_python_scalar_udf(buf)? {
+            return Ok(udf);
+        }
         self.inner.try_decode_udf(name, buf)
     }
 
@@ -284,3 +365,282 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec {
         self.inner.try_decode_udwf(name, buf)
     }
 }
+
+// =============================================================================
+// Shared Python scalar UDF encode / decode helpers
+//
+// Both `PythonLogicalCodec` and `PythonPhysicalCodec` consult these on
+// every `try_encode_udf` / `try_decode_udf` call. Same wire format on
+// both layers — a Python `ScalarUDF` referenced inside a `LogicalPlan`
+// or an `ExecutionPlan` round-trips identically.
+// =============================================================================
+
+/// Encode a Python scalar UDF inline if `node` is one. Returns
+/// `Ok(true)` when the payload (`DFPYUDF` family prefix, version byte,
+/// cloudpickled tuple) was written and the caller should skip its
+/// inner codec. Returns `Ok(false)` for any non-Python UDF, signalling
+/// the caller to delegate to its `inner`.
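+///
+/// Both codecs consult this before their `inner` codec, so a
+/// Python-defined scalar UDF is inlined here rather than falling
+/// through to the default codec's empty `fun_definition`, which would
+/// force a registry lookup by name on the receiver.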
+pub(crate) fn try_encode_python_scalar_udf(node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<bool> {
+    let Some(py_udf) = node
+        .inner()
+        .as_any()
+        .downcast_ref::<PythonFunctionScalarUDF>()
+    else {
+        return Ok(false);
+    };
+
+    Python::attach(|py| -> Result<bool> {
+        let bytes = encode_python_scalar_udf(py, py_udf)
+            .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
+        write_wire_header(buf, PY_SCALAR_UDF_FAMILY);
+        buf.extend_from_slice(&bytes);
+        Ok(true)
+    })
+}
+
+/// Decode an inline Python scalar UDF payload. Returns `Ok(None)`
+/// when `buf` does not carry the `DFPYUDF` family prefix, signalling
+/// the caller to delegate to its `inner` codec (and eventually the
+/// `FunctionRegistry`).
+pub(crate) fn try_decode_python_scalar_udf(buf: &[u8]) -> Result<Option<Arc<ScalarUDF>>> {
+    let Some(payload) = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF")? else {
+        return Ok(None);
+    };
+
+    Python::attach(|py| -> Result<Option<Arc<ScalarUDF>>> {
+        let udf = decode_python_scalar_udf(py, payload)
+            .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
+        Ok(Some(Arc::new(ScalarUDF::new_from_impl(udf))))
+    })
+}
+
+/// Build the cloudpickle payload for a `PythonFunctionScalarUDF`.
+///
+/// Layout: `cloudpickle.dumps((name, func, input_schema_bytes,
+/// return_schema_bytes, volatility_str))`. Schema blobs are produced
+/// by arrow-rs's native IPC stream writer (no pyarrow round-trip) and
+/// decoded with the matching stream reader on the receiver. See
+/// [`build_input_schema_bytes`] for what the input blob carries.
+fn encode_python_scalar_udf(py: Python<'_>, udf: &PythonFunctionScalarUDF) -> PyResult<Vec<u8>> {
+    let signature = udf.signature();
+    let input_dtypes = signature_input_dtypes(signature, "PythonFunctionScalarUDF")?;
+    let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?;
+    let return_schema_bytes = build_single_field_schema_bytes(udf.return_field().as_ref())?;
+    let volatility = volatility_wire_str(signature.volatility);
+
+    let payload = PyTuple::new(
+        py,
+        [
+            udf.name().into_pyobject(py)?.into_any(),
+            udf.func().bind(py).clone().into_any(),
+            PyBytes::new(py, &input_schema_bytes).into_any(),
+            PyBytes::new(py, &return_schema_bytes).into_any(),
+            volatility.into_pyobject(py)?.into_any(),
+        ],
+    )?;
+
+    cloudpickle(py)?
+        .call_method1("dumps", (payload,))?
+        .extract::<Vec<u8>>()
+}
+
+/// Inverse of [`encode_python_scalar_udf`].
+fn decode_python_scalar_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionScalarUDF> {
+    let tuple = cloudpickle(py)?
+        .call_method1("loads", (PyBytes::new(py, payload),))?
+        .cast_into::<PyTuple>()?;
+
+    let name: String = tuple.get_item(0)?.extract()?;
+    let func: Py<PyAny> = tuple.get_item(1)?.unbind();
+    let input_schema_bytes: Vec<u8> = tuple.get_item(2)?.extract()?;
+    let return_schema_bytes: Vec<u8> = tuple.get_item(3)?.extract()?;
+    let volatility_str: String = tuple.get_item(4)?.extract()?;
+
+    let input_types = read_input_dtypes(&input_schema_bytes)?;
+    let return_field = read_single_return_field(&return_schema_bytes, "PythonFunctionScalarUDF")?;
+    let volatility = parse_volatility_str(&volatility_str)?;
+
+    Ok(PythonFunctionScalarUDF::from_parts(
+        name,
+        func,
+        input_types,
+        return_field,
+        volatility,
+    ))
+}
+
+/// Serialize a `Schema` to a self-contained IPC stream containing
+/// only the schema message (no record batches). Inverse:
+/// [`schema_from_ipc_bytes`].
+fn schema_to_ipc_bytes(schema: &Schema) -> arrow::error::Result<Vec<u8>> {
+    let mut buf: Vec<u8> = Vec::new();
+    {
+        let mut writer = StreamWriter::try_new(&mut buf, schema)?;
+        writer.finish()?;
+    }
+    Ok(buf)
+}
+
+/// Decode an IPC stream containing only a schema message back into a
+/// `Schema`. Inverse: [`schema_to_ipc_bytes`].
+fn schema_from_ipc_bytes(bytes: &[u8]) -> arrow::error::Result<Schema> {
+    let reader = StreamReader::try_new(std::io::Cursor::new(bytes), None)?;
+    Ok(reader.schema().as_ref().clone())
+}
+
+/// Extract the per-arg `DataType`s from a `Signature` known to be
+/// `TypeSignature::Exact` (all Python-defined UDFs are constructed
+/// with `Signature::exact`). Any other variant indicates the impl was
+/// not built by this crate's UDF/UDAF/UDWF constructors.
+fn signature_input_dtypes(signature: &Signature, kind: &str) -> PyResult<Vec<DataType>> {
+    match &signature.type_signature {
+        TypeSignature::Exact(types) => Ok(types.clone()),
+        other => Err(pyo3::exceptions::PyValueError::new_err(format!(
+            "{kind} expected Signature::Exact, got {other:?}"
+        ))),
+    }
+}
+
+/// Wrap per-arg `DataType`s in synthetic `arg_{i}` fields and emit
+/// the IPC schema blob the encoder writes into the cloudpickle tuple.
+///
+/// The names and `nullable: true` are arbitrary: the underlying
+/// `TypeSignature::Exact` carries no per-input nullability or
+/// metadata, and the receiver collapses these fields back to
+/// `Vec<DataType>` via [`read_input_dtypes`], so anything set here
+/// beyond the data type is discarded on decode.
+fn build_input_schema_bytes(dtypes: &[DataType]) -> PyResult<Vec<u8>> {
+    let fields: Vec<Field> = dtypes
+        .iter()
+        .enumerate()
+        .map(|(i, dt)| Field::new(format!("arg_{i}"), dt.clone(), true))
+        .collect();
+    schema_to_ipc_bytes(&Schema::new(fields)).map_err(arrow_to_py_err)
+}
+
+/// Emit a single-field IPC schema blob. Used for return-type and
+/// state-field payloads where the receiver needs to recover field
+/// metadata (names, nullability, key/value attributes) verbatim.
+fn build_single_field_schema_bytes(field: &Field) -> PyResult<Vec<u8>> {
+    schema_to_ipc_bytes(&Schema::new(vec![field.clone()])).map_err(arrow_to_py_err)
+}
+
+/// Decode the per-arg `DataType`s the encoder wrote via
+/// [`build_input_schema_bytes`].
+fn read_input_dtypes(bytes: &[u8]) -> PyResult<Vec<DataType>> {
+    let schema = schema_from_ipc_bytes(bytes).map_err(arrow_to_py_err)?;
+    Ok(schema
+        .fields()
+        .iter()
+        .map(|f| f.data_type().clone())
+        .collect())
+}
+
+/// Decode a single-field IPC schema blob and return that field by
+/// value. `kind` names the UDF flavor in the error message produced
+/// when the blob is empty (should be unreachable for sender-side
+/// payloads built via [`build_single_field_schema_bytes`]).
+fn read_single_return_field(bytes: &[u8], kind: &str) -> PyResult<Field> {
+    let schema = schema_from_ipc_bytes(bytes).map_err(arrow_to_py_err)?;
+    let field = schema.fields().first().ok_or_else(|| {
+        pyo3::exceptions::PyValueError::new_err(format!(
+            "{kind} return schema must contain exactly one field"
+        ))
+    })?;
+    Ok(field.as_ref().clone())
+}
+
+fn arrow_to_py_err(e: arrow::error::ArrowError) -> PyErr {
+    pyo3::exceptions::PyValueError::new_err(format!("{e}"))
+}
+
+fn parse_volatility_str(s: &str) -> PyResult<Volatility> {
+    datafusion_python_util::parse_volatility(s)
+        .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))
+}
+
+/// Stable wire-format string for a `Volatility`.
Pinned to the three +/// tokens [`datafusion_python_util::parse_volatility`] accepts, so an +/// upstream change to `Volatility`'s `Debug` repr cannot silently +/// produce bytes the decoder rejects. +fn volatility_wire_str(v: Volatility) -> &'static str { + match v { + Volatility::Immutable => "immutable", + Volatility::Stable => "stable", + Volatility::Volatile => "volatile", + } +} + +/// Cached handle to the `cloudpickle` module. +/// +/// The encode/decode helpers above would otherwise re-resolve the +/// module on every call. `py.import` is backed by `sys.modules` and +/// therefore cheap, but each call still walks a dict and re-binds the +/// result; a plan with many Python UDFs pays that cost per UDF. +/// +/// `PyOnceLock` scopes the cached `Py` to the current +/// interpreter, so the slot drops cleanly on interpreter teardown +/// (relevant under CPython subinterpreters, PEP 684) instead of +/// resurrecting a `Py` rooted in a dead interpreter on the next call. +fn cloudpickle<'py>(py: Python<'py>) -> PyResult> { + static CLOUDPICKLE: PyOnceLock> = PyOnceLock::new(); + CLOUDPICKLE + .get_or_try_init(py, || Ok(py.import("cloudpickle")?.unbind().into_any())) + .map(|cached| cached.bind(py).clone()) +} + +#[cfg(test)] +mod wire_header_tests { + use super::*; + + #[test] + fn strip_returns_none_when_family_absent() { + let buf = b"OTHER_PAYLOAD"; + assert!(matches!( + strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF"), + Ok(None) + )); + } + + #[test] + fn strip_errors_on_truncated_version_byte() { + let buf = PY_SCALAR_UDF_FAMILY; + let err = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").unwrap_err(); + assert!(format!("{err}").contains("missing wire-format version byte")); + } + + #[test] + fn strip_errors_on_too_new_version() { + let mut buf = PY_SCALAR_UDF_FAMILY.to_vec(); + buf.push(WIRE_VERSION_CURRENT.saturating_add(1)); + buf.extend_from_slice(b"payload"); + let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").unwrap_err(); + let msg = format!("{err}"); + assert!(msg.contains("wire-format version v")); + assert!(msg.contains("supports")); + assert!(msg.contains("Align datafusion-python versions")); + } + + #[test] + fn strip_errors_on_too_old_version() { + if WIRE_VERSION_MIN_SUPPORTED == 0 { + return; + } + let mut buf = PY_SCALAR_UDF_FAMILY.to_vec(); + buf.push(WIRE_VERSION_MIN_SUPPORTED - 1); + buf.extend_from_slice(b"payload"); + assert!(strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").is_err()); + } + + #[test] + fn write_then_strip_round_trips_payload() { + let mut buf = Vec::new(); + write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY); + buf.extend_from_slice(b"scalar-payload"); + + let payload = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF") + .unwrap() + .unwrap(); + assert_eq!(payload, b"scalar-payload"); + } +} diff --git a/crates/core/src/udf.rs b/crates/core/src/udf.rs index c0a39cb47..72cdddba1 100644 --- a/crates/core/src/udf.rs +++ b/crates/core/src/udf.rs @@ -43,7 +43,7 @@ use crate::expr::PyExpr; /// This struct holds the Python written function that is a /// ScalarUDF. #[derive(Debug)] -struct PythonFunctionScalarUDF { +pub(crate) struct PythonFunctionScalarUDF { name: String, func: Py, signature: Signature, @@ -67,6 +67,37 @@ impl PythonFunctionScalarUDF { return_field: Arc::new(return_field), } } + + /// Stored Python callable. Consumed by the codec to cloudpickle + /// the function body across process boundaries. 
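+    /// cloudpickle serializes the callable by value (code object plus
+    /// any captured closure state), so lambdas and closures defined in
+    /// `__main__` round-trip without being importable on the receiver.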
+ pub(crate) fn func(&self) -> &Py { + &self.func + } + + pub(crate) fn return_field(&self) -> &FieldRef { + &self.return_field + } + + /// Reconstruct a `PythonFunctionScalarUDF` from the parts emitted + /// by the codec. Inputs collapse to `Vec` because + /// `Signature::exact` cannot carry per-input nullability or + /// metadata — the encoder is free to discard that side of the + /// schema. `return_field` is kept as a `Field` so the post-decode + /// nullability and metadata match the sender's instance. + pub(crate) fn from_parts( + name: String, + func: Py, + input_types: Vec, + return_field: Field, + volatility: Volatility, + ) -> Self { + Self { + name, + func, + signature: Signature::exact(input_types, volatility), + return_field: Arc::new(return_field), + } + } } impl Eq for PythonFunctionScalarUDF {} @@ -75,21 +106,51 @@ impl PartialEq for PythonFunctionScalarUDF { self.name == other.name && self.signature == other.signature && self.return_field == other.return_field - && Python::attach(|py| self.func.bind(py).eq(other.func.bind(py)).unwrap_or(false)) + // Identical pointers ⇒ same Python object. Most equality + // checks compare `Arc`-shared clones of the same UDF + // (e.g. expression rewriting), so the pointer match short- + // circuits before touching the GIL. + && (self.func.as_ptr() == other.func.as_ptr() + || Python::attach(|py| { + // Rust's `PartialEq` cannot return `Result`, so we + // have to pick a side when Python `__eq__` raises. + // `false` is the conservative choice — better to + // report two UDFs as distinct than to wrongly + // merge them — but the silent miss can still + // surface as expression-dedup or cache-lookup + // anomalies. Log at `debug` so the failure is + // observable without flooding production logs. + // FIXME: revisit if upstream `ScalarUDFImpl` + // exposes a fallible `PartialEq`. + self.func + .bind(py) + .eq(other.func.bind(py)) + .unwrap_or_else(|e| { + log::debug!( + target: "datafusion_python::udf", + "PythonFunctionScalarUDF {:?} __eq__ raised; treating as unequal: {e}", + self.name, + ); + false + }) + })) } } impl Hash for PythonFunctionScalarUDF { fn hash(&self, state: &mut H) { + // Hash only the identifying header (name + signature + return + // field). Skipping `func` is intentional: the Rust `Hash` + // contract requires `a == b ⇒ hash(a) == hash(b)`, not the + // converse, so a coarser hash is sound — `PartialEq` still + // disambiguates two UDFs with the same header but distinct + // callables. Falling back to a sentinel on `py_hash` failure + // (as a prior revision did) silently mapped every unhashable + // closure to the same bucket; that is the worst case for a + // hashmap and is what this rewrite avoids. 
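+        // A collision now only costs extra `eq` probes in maps keyed
+        // by the UDF; correctness is unaffected.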
self.name.hash(state); self.signature.hash(state); self.return_field.hash(state); - - Python::attach(|py| { - let py_hash = self.func.bind(py).hash().unwrap_or(0); // Handle unhashable objects - - state.write_isize(py_hash); - }); } } @@ -220,4 +281,9 @@ impl PyScalarUDF { fn __repr__(&self) -> PyResult { Ok(format!("ScalarUDF({})", self.function.name())) } + + #[getter] + fn name(&self) -> &str { + self.function.name() + } } diff --git a/pyproject.toml b/pyproject.toml index 951f7adc3..a02f4608a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,13 @@ classifiers = [ "Programming Language :: Rust", ] dependencies = [ + # cloudpickle is invoked by the Rust-side PythonLogicalCodec / + # PythonPhysicalCodec via pyo3 to serialize Python UDF callables — + # scalar, aggregate, and window — into the proto wire format. + # Lazy-imported on the encode / decode hot paths (and cached after + # the first import), so users who never serialize a plan or + # expression incur no runtime cost beyond the install footprint. + "cloudpickle>=2.0", "pyarrow>=16.0.0;python_version<'3.14'", "pyarrow>=22.0.0;python_version>='3.14'", "typing-extensions;python_version<'3.13'", diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index f08b464bb..dfdeef07e 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -65,7 +65,7 @@ import importlib_metadata # type: ignore[import] # Public submodules -from . import functions, object_store, substrait, unparser +from . import functions, ipc, object_store, substrait, unparser # The following imports are okay to remain as opaque to the user. from ._internal import Config @@ -142,6 +142,7 @@ "configure_formatter", "expr", "functions", + "ipc", "lit", "literal", "object_store", diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index e0135e3ed..cdb8377de 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -46,7 +46,7 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence from typing import TYPE_CHECKING, Any, ClassVar import pyarrow as pa @@ -434,23 +434,104 @@ def variant_name(self) -> str: return self.expr.variant_name() def to_bytes(self, ctx: SessionContext | None = None) -> bytes: - """Serialize this expression to protobuf bytes. + """Serialize this expression to bytes for shipping to another process. - When ``ctx`` is supplied, encoding routes through the session's - installed :class:`LogicalExtensionCodec`. Without ``ctx`` a - default codec is used. + Use this — or :func:`pickle.dumps` — to send an expression to a + worker process for distributed evaluation. + + When ``ctx`` is supplied, encoding routes through that session's + installed :class:`LogicalExtensionCodec`. When ``ctx`` is + ``None``, the default codec is used. + + Built-in functions and Python scalar UDFs travel inside the + returned bytes; the worker does not need to pre-register them. + UDFs imported via the FFI capsule protocol travel by name only + and must be registered on the worker. + + .. warning:: + Bytes returned here may embed a cloudpickled Python + callable (when the expression carries a Python scalar UDF). + Reconstructing them via :meth:`from_bytes` or + :func:`pickle.loads` executes arbitrary Python on the + receiver. Only accept payloads from trusted sources. 
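+
+        A minimal sketch of shipping an expression to a
+        ``multiprocessing`` worker (``evaluate_on_worker`` is an
+        illustrative helper, not part of this API):
+
+        .. code-block:: python
+
+            from multiprocessing import Pool
+
+            from datafusion import Expr, col, lit
+
+            def evaluate_on_worker(blob: bytes) -> str:
+                # Built-ins and Python scalar UDFs are self-contained
+                # in the payload; nothing to register on the worker.
+                return Expr.from_bytes(blob).canonical_name()
+
+            if __name__ == "__main__":
+                blob = (col("a") + lit(1)).to_bytes()
+                with Pool(1) as pool:
+                    print(pool.apply(evaluate_on_worker, (blob,)))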
+ + Examples: + >>> from datafusion import col, lit + >>> blob = (col("a") + lit(1)).to_bytes() + >>> isinstance(blob, bytes) + True """ ctx_arg = ctx.ctx if ctx is not None else None return self.expr.to_bytes(ctx_arg) - @staticmethod - def from_bytes(ctx: SessionContext, data: bytes) -> Expr: - """Decode an expression from serialized protobuf bytes. + @classmethod + def from_bytes(cls, buf: bytes, ctx: SessionContext | None = None) -> Expr: + """Reconstruct an expression from serialized bytes. + + Accepts output of :meth:`to_bytes` or :func:`pickle.dumps`. + ``ctx`` is the :class:`SessionContext` used to resolve any + function references that travel by name (e.g. FFI UDFs). When + ``ctx`` is ``None`` the worker context installed via + :func:`datafusion.ipc.set_worker_ctx` is consulted; if no worker + context is installed, the global :class:`SessionContext` is used + (sufficient for built-ins and Python scalar UDFs, plus any UDFs + registered on the global context). + + .. warning:: + Decoding may invoke ``cloudpickle.loads`` on bytes embedded + in the payload, which executes arbitrary Python code. Treat + ``buf`` as code, not data — only decode bytes you produced + yourself or received from a trusted sender. + + Examples: + >>> from datafusion import Expr, col, lit + >>> blob = (col("a") + lit(1)).to_bytes() + >>> Expr.from_bytes(blob).canonical_name() + 'a + Int64(1)' + """ + from datafusion.ipc import _resolve_ctx - ``ctx`` provides the function registry for resolving UDF - references and the logical codec for in-band Python payloads. + resolved = _resolve_ctx(ctx) + return cls(expr_internal.RawExpr.from_bytes(resolved.ctx, buf)) + + def __reduce__(self) -> tuple[Callable[[bytes], Expr], tuple[bytes]]: + """Pickle protocol hook. + + Lets expressions be shipped to worker processes via + :func:`pickle.dumps` / :func:`pickle.loads`. Built-in functions + and Python scalar UDFs travel inside the pickle bytes; only + FFI-capsule UDFs require pre-registration on the worker. The + worker's :class:`SessionContext` for resolving those references + is looked up via :func:`datafusion.ipc.set_worker_ctx`, falling + back to the global :class:`SessionContext` if none has been + installed on the worker. + + .. warning:: + :func:`pickle.loads` on the returned tuple executes + arbitrary Python on the receiver, including any + cloudpickled UDF callable embedded in the payload. Only + unpickle expressions from trusted sources. + + Examples: + >>> import pickle + >>> from datafusion import col, lit + >>> e = col("a") * lit(2) + >>> pickle.loads(pickle.dumps(e)).canonical_name() + 'a * Int64(2)' + """ + return (Expr._reconstruct, (self.to_bytes(),)) + + @classmethod + def _reconstruct(cls, proto_bytes: bytes) -> Expr: + """Internal entry point used by :meth:`__reduce__` on unpickle. + + Examples: + >>> from datafusion import Expr, col, lit + >>> blob = (col("a") + lit(1)).to_bytes() + >>> Expr._reconstruct(blob).canonical_name() + 'a + Int64(1)' """ - return Expr(expr_internal.RawExpr.from_bytes(ctx.ctx, data)) + return cls.from_bytes(proto_bytes) def __richcmp__(self, other: Expr, op: int) -> Expr: """Comparison operator.""" diff --git a/python/datafusion/ipc.py b/python/datafusion/ipc.py new file mode 100644 index 000000000..16e68c4d0 --- /dev/null +++ b/python/datafusion/ipc.py @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Worker-side setup for distributing DataFusion expressions. + +When a :class:`Expr` is shipped to a worker process (e.g. through +:func:`multiprocessing.Pool` or a Ray actor), the worker reconstructs the +expression against a :class:`SessionContext`. If the expression references +UDFs imported via the FFI capsule protocol — or any UDF the worker would +otherwise resolve from its registered functions rather than from inside +the shipped expression — install a configured :class:`SessionContext` +once per worker: + +.. code-block:: python + + from datafusion import SessionContext + from datafusion.ipc import set_worker_ctx + + def init_worker(): + ctx = SessionContext() + ctx.register_udaf(my_ffi_aggregate) + set_worker_ctx(ctx) + +Built-in functions and Python scalar UDFs travel inside the shipped +expression itself and do not need pre-registration on the worker. +""" + +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from datafusion.context import SessionContext + + +__all__ = [ + "clear_worker_ctx", + "get_worker_ctx", + "set_worker_ctx", +] + + +_local = threading.local() + + +def set_worker_ctx(ctx: SessionContext) -> None: + """Install this worker's :class:`SessionContext` for shipped expressions. + + Call once per worker — typically from a ``multiprocessing.Pool`` + initializer or a Ray actor ``__init__``. Idempotent: overwrites any + previous value. Stored in a thread-local slot, so each thread within a + worker may install its own context independently. + + Examples: + >>> from datafusion import SessionContext + >>> from datafusion.ipc import set_worker_ctx, get_worker_ctx, clear_worker_ctx + >>> set_worker_ctx(SessionContext()) + >>> get_worker_ctx() is not None + True + >>> clear_worker_ctx() + """ + _local.ctx = ctx + + +def clear_worker_ctx() -> None: + """Remove this worker's installed :class:`SessionContext`. + + After clearing, expressions reconstructed in this worker fall back to + the global :class:`SessionContext` — adequate for built-ins and Python + scalar UDFs, but anything imported via the FFI capsule protocol must + be registered on the global context to resolve. + + Examples: + >>> from datafusion import SessionContext + >>> from datafusion.ipc import set_worker_ctx, clear_worker_ctx, get_worker_ctx + >>> set_worker_ctx(SessionContext()) + >>> clear_worker_ctx() + >>> get_worker_ctx() is None + True + """ + if hasattr(_local, "ctx"): + del _local.ctx + + +def get_worker_ctx() -> SessionContext | None: + """Return this worker's installed :class:`SessionContext`, or ``None``. 
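+
+    Primarily useful in tests and diagnostics; ``Expr.from_bytes``
+    resolves its context through the same lookup before falling back
+    to the global :class:`SessionContext`.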
+ + Examples: + >>> from datafusion.ipc import get_worker_ctx, clear_worker_ctx + >>> clear_worker_ctx() + >>> get_worker_ctx() is None + True + """ + return getattr(_local, "ctx", None) + + +def _resolve_ctx( + explicit_ctx: SessionContext | None = None, +) -> SessionContext: + """Resolve a context for Expr reconstruction. + + Priority: explicit argument > worker context > global context. + Falling back to the global :class:`SessionContext` (instead of a + freshly constructed one) preserves any registrations the user has + installed on it. + + Examples: + >>> from datafusion import SessionContext + >>> from datafusion.ipc import _resolve_ctx, clear_worker_ctx + >>> clear_worker_ctx() + >>> isinstance(_resolve_ctx(), SessionContext) + True + >>> ctx = SessionContext() + >>> _resolve_ctx(ctx) is ctx + True + """ + if explicit_ctx is not None: + return explicit_ctx + worker = get_worker_ctx() + if worker is not None: + return worker + # Lazy import: `datafusion/__init__.py` imports `datafusion.ipc` + # before `datafusion.context`, so a module-top import would force + # `datafusion.context` to load mid-init of `datafusion.ipc`. The + # cycle is benign today (context.py only pulls expr.py at module + # scope, neither pulls ipc.py back), but a single new import in + # context.py's transitive deps could turn it into a real cycle. + # Deferring keeps `datafusion.ipc` import-order-independent. + from datafusion.context import SessionContext # noqa: PLC0415 + + return SessionContext.global_ctx() diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 848ab4cee..d79cf22e8 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -141,6 +141,29 @@ def __init__( name, func, input_fields, return_field, str(volatility) ) + @property + def name(self) -> str: + """Return the registered name of this UDF. + + For UDFs imported via the FFI capsule protocol, this is the + name the capsule itself reports — not the ``name`` argument + passed to the constructor (which is ignored on the FFI path). + + Examples: + >>> import pyarrow as pa + >>> from datafusion import udf + >>> double = udf( + ... lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]), + ... [pa.int64()], + ... pa.int64(), + ... volatility="immutable", + ... name="double", + ... ) + >>> double.name + 'double' + """ + return self._udf.name + def __repr__(self) -> str: """Print a string representation of the Scalar UDF.""" return self._udf.__repr__() diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 6a466f6f2..e1fdeab44 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -1186,7 +1186,7 @@ def test_expr_to_bytes_roundtrip(ctx: SessionContext) -> None: original = col("a") + lit(1) blob = original.to_bytes(ctx) - restored = Expr.from_bytes(ctx, blob) + restored = Expr.from_bytes(blob, ctx=ctx) # Canonical name preserves the structure of the expression even # though the underlying PyExpr instances are different. 
@@ -1201,6 +1201,6 @@ def test_expr_to_bytes_no_ctx_default_codec() -> None: fresh = SessionContext() original = col("a") * lit(2) blob = original.to_bytes() # encode side: default codec - restored = Expr.from_bytes(fresh, blob) + restored = Expr.from_bytes(blob, ctx=fresh) assert restored.canonical_name() == original.canonical_name() diff --git a/python/tests/test_pickle_expr.py b/python/tests/test_pickle_expr.py new file mode 100644 index 000000000..3e30ca14a --- /dev/null +++ b/python/tests/test_pickle_expr.py @@ -0,0 +1,157 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""In-process pickle round-trip tests for :class:`Expr`. + +Built-in functions and Python scalar UDFs travel with the pickled +expression and do not need worker-side pre-registration. The worker +context (:mod:`datafusion.ipc`) is only consulted for UDFs imported +via the FFI capsule protocol. +""" + +from __future__ import annotations + +import pickle + +import pyarrow as pa +import pytest +from datafusion import Expr, SessionContext, col, lit, udf +from datafusion.ipc import ( + clear_worker_ctx, + set_worker_ctx, +) + + +@pytest.fixture(autouse=True) +def _reset_worker_ctx(): + """Ensure every test starts with no worker context installed.""" + clear_worker_ctx() + yield + clear_worker_ctx() + + +def _double_udf(): + return udf( + lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]), + [pa.int64()], + pa.int64(), + volatility="immutable", + name="double", + ) + + +class TestProtoRoundTrip: + def test_builtin_round_trip(self): + e = col("a") + lit(1) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert decoded.canonical_name() == e.canonical_name() + + def test_to_bytes_from_bytes(self): + e = col("x") * lit(7) + blob = e.to_bytes() + assert isinstance(blob, bytes) + decoded = Expr.from_bytes(blob) + assert decoded.canonical_name() == e.canonical_name() + + def test_explicit_ctx_used(self, ctx): + e = col("a") + lit(1) + decoded = Expr.from_bytes(e.to_bytes(), ctx=ctx) + assert decoded.canonical_name() == e.canonical_name() + + +class TestUDFCodec: + """Python scalar UDFs ride inside the proto blob via the Rust codec. + + No worker context needed on the receiver — the cloudpickled callable is + embedded in ``fun_definition`` and reconstructed automatically. + """ + + def test_udf_self_contained_blob(self): + e = _double_udf()(col("a")) + blob = pickle.dumps(e) + # The codec inlines the callable, so the blob is much bigger than a + # pure built-in blob but doesn't depend on receiver-side registration. 
+ assert len(blob) > 200 + + def test_udf_decodes_into_fresh_ctx(self): + e = _double_udf()(col("a")) + blob = e.to_bytes() + fresh = SessionContext() + decoded = Expr.from_bytes(blob, ctx=fresh) + assert "double" in decoded.canonical_name() + + def test_udf_decodes_via_pickle_with_no_worker_ctx(self): + e = _double_udf()(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert "double" in decoded.canonical_name() + + def test_udf_decodes_via_pickle_with_worker_ctx(self): + set_worker_ctx(SessionContext()) + e = _double_udf()(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert "double" in decoded.canonical_name() + + def test_closure_capturing_udf_names_match(self): + captured_multiplier = 7 + + def fn(arr): + return pa.array([(v.as_py() or 0) * captured_multiplier for v in arr]) + + u = udf( + fn, + [pa.int64()], + pa.int64(), + volatility="immutable", + name="times_seven", + ) + e = u(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert decoded.canonical_name() == e.canonical_name() + + def test_multi_arg_udf_round_trip(self): + """Wire format builds synthetic `arg_{i}` fields per input — exercise + with a 2-arg UDF spanning two distinct DataTypes.""" + add_scaled = udf( + lambda a, b: pa.array( + [ + (x.as_py() or 0) + (y.as_py() or 0.0) + for x, y in zip(a, b, strict=False) + ] + ), + [pa.int64(), pa.float64()], + pa.float64(), + volatility="immutable", + name="add_scaled", + ) + e = add_scaled(col("a"), col("b")) + decoded = pickle.loads(pickle.dumps(e)) # noqa: S301 + assert decoded.canonical_name() == e.canonical_name() + assert "add_scaled" in decoded.canonical_name() + + +class TestErrorPaths: + def test_from_bytes_rejects_garbage(self): + with pytest.raises(Exception): # noqa: B017 + Expr.from_bytes(b"not a valid protobuf payload") + + def test_from_bytes_rejects_empty(self): + with pytest.raises(Exception): # noqa: B017 + Expr.from_bytes(b"") diff --git a/uv.lock b/uv.lock index 3b7135e32..3fd3eec4b 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14'", @@ -257,6 +257,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767, upload-time = "2024-12-24T18:12:32.852Z" }, ] +[[package]] +name = "cloudpickle" +version = "3.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/27/fb/576f067976d320f5f0114a8d9fa1215425441bb35627b1993e5afd8111e5/cloudpickle-3.1.2.tar.gz", hash = "sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414", size = 22330, upload-time = "2025-11-03T09:25:26.604Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl", hash = "sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a", size = 22228, upload-time = "2025-11-03T09:25:25.534Z" }, +] + [[package]] name = "codespell" version = "2.4.1" @@ -316,6 +325,7 @@ wheels = [ name = "datafusion" source = { editable = "." 
} dependencies = [ + { name = "cloudpickle" }, { name = "pyarrow" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] @@ -351,6 +361,7 @@ docs = [ [package.metadata] requires-dist = [ + { name = "cloudpickle", specifier = ">=2.0" }, { name = "pyarrow", marker = "python_full_version < '3.14'", specifier = ">=16.0.0" }, { name = "pyarrow", marker = "python_full_version >= '3.14'", specifier = ">=22.0.0" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" },