diff --git a/src_cpp/include/py_connection.h b/src_cpp/include/py_connection.h index 97c6ca9..6dc6b52 100644 --- a/src_cpp/include/py_connection.h +++ b/src_cpp/include/py_connection.h @@ -57,6 +57,10 @@ class PyConnection { py::object arrowTable); std::unique_ptr createArrowRelTable(const std::string& tableName, py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName); + std::unique_ptr createArrowCsrRelTable(const std::string& tableName, + const std::string& srcTableName, const std::string& dstTableName, py::object fwdIndices, + py::object fwdIndptr, py::object bwdIndices = py::none(), + py::object bwdIndptr = py::none()); std::unique_ptr dropArrowTable(const std::string& tableName); static Value transformPythonValue(const py::handle& val); diff --git a/src_cpp/py_connection.cpp b/src_cpp/py_connection.cpp index 56a7c20..401e986 100644 --- a/src_cpp/py_connection.cpp +++ b/src_cpp/py_connection.cpp @@ -53,6 +53,10 @@ void PyConnection::initialize(py::handle& m) { py::arg("arrow_table")) .def("create_arrow_rel_table", &PyConnection::createArrowRelTable, py::arg("table_name"), py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name")) + .def("create_arrow_csr_rel_table", &PyConnection::createArrowCsrRelTable, + py::arg("table_name"), py::arg("src_table_name"), py::arg("dst_table_name"), + py::arg("fwd_indices"), py::arg("fwd_indptr"), py::arg("bwd_indices") = py::none(), + py::arg("bwd_indptr") = py::none()) .def("drop_arrow_table", &PyConnection::dropArrowTable, py::arg("table_name")); PyDateTime_IMPORT; } @@ -1093,6 +1097,89 @@ std::unique_ptr PyConnection::createArrowRelTable(const std::stri return checkAndWrapQueryResult(result.queryResult, state); } +static std::pair> exportPyArrowTable( + py::object& tbl) { + ArrowSchemaWrapper schema; + tbl.attr("schema").attr("_export_to_c")(reinterpret_cast(&schema)); + std::vector arrays; + py::list batches = tbl.attr("to_batches")(); + for (auto& batch : batches) { + arrays.emplace_back(); + batch.attr("_export_to_c")(reinterpret_cast(&arrays.back())); + } + return {std::move(schema), std::move(arrays)}; +} + +static py::object toPyArrow(const py::object& obj, + const std::shared_ptr& importCache) { + if (PyConnection::isPandasDataframe(obj)) { + return importCache->pyarrow.lib.Table.from_pandas()(obj); + } + + if (PyConnection::isPolarsDataframe(obj)) { + return obj.attr("to_arrow")(); + } + + if (PyConnection::isPyArrowTable(obj)) { + return obj; + } + + throw RuntimeException("Expected a pyarrow Table, polars DataFrame, or pandas DataFrame"); +} + +std::unique_ptr PyConnection::createArrowCsrRelTable(const std::string& tableName, + const std::string& srcTableName, const std::string& dstTableName, py::object fwdIndices, + py::object fwdIndptr, py::object bwdIndices, py::object bwdIndptr) { + auto& stateRef = refState(); + py::gil_scoped_acquire acquire; + + bool hasBwd = !bwdIndices.is_none(); + if (hasBwd != !bwdIndptr.is_none()) { + throw RuntimeException("bwd_indices and bwd_indptr must both be provided or both be None"); + } + + fwdIndices = toPyArrow(fwdIndices, importCache); + fwdIndptr = toPyArrow(fwdIndptr, importCache); + + py::list keepAlive; + keepAlive.append(fwdIndices); + keepAlive.append(fwdIndices.attr("to_batches")()); + keepAlive.append(fwdIndptr); + keepAlive.append(fwdIndptr.attr("to_batches")()); + + auto [fwdIdxSchema, fwdIdxArrays] = exportPyArrowTable(fwdIndices); + auto [fwdIpSchema, fwdIpArrays] = exportPyArrowTable(fwdIndptr); + + std::optional bwdIdxSchema; + std::optional> bwdIdxArrays; + std::optional bwdIpSchema; + std::optional> bwdIpArrays; + if (hasBwd) { + bwdIndices = toPyArrow(bwdIndices, importCache); + bwdIndptr = toPyArrow(bwdIndptr, importCache); + keepAlive.append(bwdIndices); + keepAlive.append(bwdIndices.attr("to_batches")()); + keepAlive.append(bwdIndptr); + keepAlive.append(bwdIndptr.attr("to_batches")()); + auto [bis, bia] = exportPyArrowTable(bwdIndices); + auto [bps, bpa] = exportPyArrowTable(bwdIndptr); + bwdIdxSchema = std::move(bis); + bwdIdxArrays = std::move(bia); + bwdIpSchema = std::move(bps); + bwdIpArrays = std::move(bpa); + } + + auto result = ArrowTableSupport::createArrowCsrRelTable(stateRef.ref(), tableName, srcTableName, + dstTableName, std::move(fwdIdxSchema), std::move(fwdIdxArrays), std::move(fwdIpSchema), + std::move(fwdIpArrays), std::move(bwdIdxSchema), std::move(bwdIdxArrays), + std::move(bwdIpSchema), std::move(bwdIpArrays)); + if (result.queryResult && result.queryResult->isSuccess()) { + stateRef.arrowTableRefs[tableName] = std::move(keepAlive); + } + + return checkAndWrapQueryResult(result.queryResult, state); +} + std::unique_ptr PyConnection::dropArrowTable(const std::string& tableName) { auto& stateRef = refState(); auto result = ArrowTableSupport::unregisterArrowTable(stateRef.ref(), tableName); diff --git a/src_py/_lbug_capi.py b/src_py/_lbug_capi.py index eb464e2..c6d6e9b 100644 --- a/src_py/_lbug_capi.py +++ b/src_py/_lbug_capi.py @@ -339,6 +339,27 @@ def _setup_signatures() -> None: ] _LIB.lbug_connection_drop_arrow_table.restype = ctypes.c_int + _LIB.lbug_connection_create_arrow_csr_rel_table.argtypes = [ + ctypes.POINTER(_LbugConnection), # connection + ctypes.c_char_p, # table_name + ctypes.c_char_p, # src_table_name + ctypes.c_char_p, # dst_table_name + ctypes.POINTER(_ArrowSchema), # fwd_indices_schema + ctypes.POINTER(_ArrowArray), # fwd_indices_arrays + ctypes.c_uint64, # fwd_indices_num_arrays + ctypes.POINTER(_ArrowSchema), # fwd_indptr_schema + ctypes.POINTER(_ArrowArray), # fwd_indptr_arrays + ctypes.c_uint64, # fwd_indptr_num_arrays + ctypes.POINTER(_ArrowSchema), # bwd_indices_schema (nullable) + ctypes.POINTER(_ArrowArray), # bwd_indices_arrays (nullable) + ctypes.c_uint64, # bwd_indices_num_arrays + ctypes.POINTER(_ArrowSchema), # bwd_indptr_schema (nullable) + ctypes.POINTER(_ArrowArray), # bwd_indptr_arrays (nullable) + ctypes.c_uint64, # bwd_indptr_num_arrays + ctypes.POINTER(_LbugQueryResult), # out_query_result + ] + _LIB.lbug_connection_create_arrow_csr_rel_table.restype = ctypes.c_int + _LIB.lbug_prepared_statement_destroy.argtypes = [ ctypes.POINTER(_LbugPreparedStatement) ] @@ -2340,3 +2361,62 @@ def create_arrow_rel_table( if state != _LBUG_SUCCESS and not result._query_result: _check_state(state, "Failed to create Arrow relationship table") return QueryResult(result) + + def create_arrow_csr_rel_table( + self, + table_name: str, + src_table_name: str, + dst_table_name: str, + fwd_indices: Any, + fwd_indptr: Any, + bwd_indices: Any = None, + bwd_indptr: Any = None, + ) -> QueryResult: + has_bwd = bwd_indices is not None + if has_bwd != (bwd_indptr is not None): + msg = "bwd_indices and bwd_indptr must both be provided or both be None" + raise ValueError(msg) + + _fi_tbl, fi_schema, fi_arrays, _fi_b = self._export_arrow_table(fwd_indices) + _fp_tbl, fp_schema, fp_arrays, _fp_b = self._export_arrow_table(fwd_indptr) + + if has_bwd: + _bi_tbl, bi_schema, bi_arrays, _bi_b = self._export_arrow_table(bwd_indices) + _bp_tbl, bp_schema, bp_arrays, _bp_b = self._export_arrow_table(bwd_indptr) + bi_schema_ref = ctypes.byref(bi_schema) + bi_arrays_ref = bi_arrays + bi_num = len(bi_arrays) + bp_schema_ref = ctypes.byref(bp_schema) + bp_arrays_ref = bp_arrays + bp_num = len(bp_arrays) + else: + bi_schema_ref = None + bi_arrays_ref = None + bi_num = 0 + bp_schema_ref = None + bp_arrays_ref = None + bp_num = 0 + + result = _LbugQueryResult() + state = _LIB.lbug_connection_create_arrow_csr_rel_table( + ctypes.byref(self._connection), + table_name.encode("utf-8"), + src_table_name.encode("utf-8"), + dst_table_name.encode("utf-8"), + ctypes.byref(fi_schema), + fi_arrays, + len(fi_arrays), + ctypes.byref(fp_schema), + fp_arrays, + len(fp_arrays), + bi_schema_ref, + bi_arrays_ref, + bi_num, + bp_schema_ref, + bp_arrays_ref, + bp_num, + ctypes.byref(result), + ) + if state != _LBUG_SUCCESS and not result._query_result: + _check_state(state, "Failed to create Arrow CSR relationship table") + return QueryResult(result) diff --git a/src_py/connection.py b/src_py/connection.py index 4fe6665..83dc9f6 100644 --- a/src_py/connection.py +++ b/src_py/connection.py @@ -857,3 +857,75 @@ def create_arrow_rel_table( if not query_result_internal.isSuccess(): raise RuntimeError(query_result_internal.getErrorMessage()) return QueryResult(self, query_result_internal) + + def create_arrow_csr_rel_table( + self, + table_name: str, + src_table_name: str, + dst_table_name: str, + fwd_indices: Any, + fwd_indptr: Any, + bwd_indices: Any = None, + bwd_indptr: Any = None, + ) -> QueryResult: + """ + Create an Arrow CSR memory-backed relationship table. + + Parameters + ---------- + table_name : str + Name of the relationship table to create. + src_table_name : str + Source node table name. + dst_table_name : str + Destination node table name. + fwd_indices : Any + Forward adjacency indices table (struct array: child[0] = UINT64 dst offsets, + optional further children are edge properties). Accepts pandas, polars, or pyarrow. + fwd_indptr : Any + Forward adjacency indptr table (struct array: child[0] = UINT64 row pointers). + bwd_indices : Any, optional + Backward adjacency indices table. Must be provided together with bwd_indptr. + bwd_indptr : Any, optional + Backward adjacency indptr table. Must be provided together with bwd_indices. + + Returns + ------- + QueryResult + Result of the table creation query. + """ + has_bwd = bwd_indices is not None + if has_bwd != (bwd_indptr is not None): + msg = "bwd_indices and bwd_indptr must both be provided or both be None" + raise ValueError(msg) + + self.init_connection() + try: + query_result_internal = self._connection.create_arrow_csr_rel_table( + table_name, + src_table_name, + dst_table_name, + fwd_indices, + fwd_indptr, + bwd_indices, + bwd_indptr, + ) + except NotImplementedError: + py_connection = self._get_pybind_connection() + if py_connection is None: + raise + self._prefer_pybind = True + query_result_internal = py_connection.create_arrow_csr_rel_table( + table_name, + src_table_name, + dst_table_name, + fwd_indices, + fwd_indptr, + bwd_indices, + bwd_indptr, + ) + + if not query_result_internal.isSuccess(): + raise RuntimeError(query_result_internal.getErrorMessage()) + + return QueryResult(self, query_result_internal) diff --git a/test/test_arrow_memory_backed_table.py b/test/test_arrow_memory_backed_table.py index abd4680..e425336 100644 --- a/test/test_arrow_memory_backed_table.py +++ b/test/test_arrow_memory_backed_table.py @@ -1,4 +1,5 @@ import polars as pl +import pyarrow as pa import pytest from type_aliases import ConnDB @@ -402,3 +403,170 @@ def test_arrow_memory_backed_native_node_and_arrow_rel_table( ] conn.drop_arrow_table("native_people_arrow_knows") + + +# ───────────────────────────────────────────────────────────────────────────── +# Helpers for CSR tests +# ───────────────────────────────────────────────────────────────────────────── + + +def _make_csr_node_table(conn, table_name: str) -> None: + """4 Arrow nodes: offsets 0=A, 1=B, 2=C, 3=D (id=0..3).""" + df = pl.DataFrame({"id": [0, 1, 2, 3]}) + conn.create_arrow_table(table_name, df) + + +def _make_csr_fwd_indices() -> pa.Table: + """FWD adjacency: (dst=1,w=10),(dst=2,w=20),(dst=2,w=30),(dst=3,w=40).""" + return pa.table( + { + "dst_offset": pa.array([1, 2, 2, 3], type=pa.uint64()), + "weight": pa.array([10, 20, 30, 40], type=pa.int64()), + } + ) + + +def _make_csr_fwd_indptr() -> pa.Table: + """FWD indptr: node0→edges[0..1], node1→edges[2], node2→edges[3], node3→(empty).""" + return pa.table({"v": pa.array([0, 2, 3, 4, 4], type=pa.uint64())}) + + +def _make_csr_bwd_indices() -> pa.Table: + """BWD adjacency: (src=0,w=10),(src=0,w=20),(src=1,w=30),(src=2,w=40).""" + return pa.table( + { + "src_offset": pa.array([0, 0, 1, 2], type=pa.uint64()), + "weight": pa.array([10, 20, 30, 40], type=pa.int64()), + } + ) + + +def _make_csr_bwd_indptr() -> pa.Table: + """BWD indptr: node0→(empty), node1→edges[0], node2→edges[1..2], node3→edges[3].""" + return pa.table({"v": pa.array([0, 0, 1, 3, 4], type=pa.uint64())}) + + +# ───────────────────────────────────────────────────────────────────────────── +# CSR tests +# ───────────────────────────────────────────────────────────────────────────── + + +def test_arrow_csr_rel_table_fwd_only(conn_db_empty: ConnDB) -> None: + """FWD-only CSR table: count=4, sum=100.""" + conn, _ = conn_db_empty + _make_csr_node_table(conn, "csr_node") + + conn.create_arrow_csr_rel_table( + "csr_rel", + "csr_node", + "csr_node", + _make_csr_fwd_indices(), + _make_csr_fwd_indptr(), + ) + + result = conn.execute("MATCH (:csr_node)-[:csr_rel]->(:csr_node) RETURN count(*)") + assert result.get_next()[0] == 4 + + result = conn.execute( + "MATCH (:csr_node)-[e:csr_rel]->(:csr_node) RETURN sum(e.weight)" + ) + assert result.get_next()[0] == 100 + + conn.drop_arrow_table("csr_rel") + conn.drop_arrow_table("csr_node") + + +def test_arrow_csr_rel_table_with_bwd(conn_db_empty: ConnDB) -> None: + """CSR table with explicit BWD adjacency: BWD count=4, BWD sum=100.""" + conn, _ = conn_db_empty + _make_csr_node_table(conn, "csr_node") + + conn.create_arrow_csr_rel_table( + "csr_rel", + "csr_node", + "csr_node", + _make_csr_fwd_indices(), + _make_csr_fwd_indptr(), + _make_csr_bwd_indices(), + _make_csr_bwd_indptr(), + ) + + result = conn.execute("MATCH (:csr_node)<-[:csr_rel]-(:csr_node) RETURN count(*)") + assert result.get_next()[0] == 4 + + result = conn.execute( + "MATCH (:csr_node)<-[e:csr_rel]-(:csr_node) RETURN sum(e.weight)" + ) + assert result.get_next()[0] == 100 + + conn.drop_arrow_table("csr_rel") + conn.drop_arrow_table("csr_node") + + +def test_arrow_csr_rel_table_bwd_fallback(conn_db_empty: ConnDB) -> None: + """CSR table without BWD data: fallback full-scan returns correct BWD count.""" + conn, _ = conn_db_empty + _make_csr_node_table(conn, "csr_node") + + conn.create_arrow_csr_rel_table( + "csr_rel", + "csr_node", + "csr_node", + _make_csr_fwd_indices(), + _make_csr_fwd_indptr(), + ) + + result = conn.execute("MATCH (:csr_node)<-[:csr_rel]-(:csr_node) RETURN count(*)") + assert result.get_next()[0] == 4 + + conn.drop_arrow_table("csr_rel") + conn.drop_arrow_table("csr_node") + + +def test_arrow_csr_rel_table_partial_bwd_raises(conn_db_empty: ConnDB) -> None: + """Providing only bwd_indices or only bwd_indptr must raise ValueError.""" + conn, _ = conn_db_empty + _make_csr_node_table(conn, "csr_node") + + with pytest.raises((ValueError, RuntimeError)): + conn.create_arrow_csr_rel_table( + "csr_rel", + "csr_node", + "csr_node", + _make_csr_fwd_indices(), + _make_csr_fwd_indptr(), + bwd_indices=_make_csr_bwd_indices(), + bwd_indptr=None, + ) + + with pytest.raises((ValueError, RuntimeError)): + conn.create_arrow_csr_rel_table( + "csr_rel2", + "csr_node", + "csr_node", + _make_csr_fwd_indices(), + _make_csr_fwd_indptr(), + bwd_indices=None, + bwd_indptr=_make_csr_bwd_indptr(), + ) + + conn.drop_arrow_table("csr_node") + + +def test_arrow_csr_over_native_node_table_raises(conn_db_empty: ConnDB) -> None: + """CSR rel table over a native (non-Arrow) node table must fail.""" + conn, _ = conn_db_empty + conn.execute( + "CREATE NODE TABLE native_node(id INT64, PRIMARY KEY(id));" + "CREATE (:native_node {id: 0});" + "CREATE (:native_node {id: 1});" + ) + + with pytest.raises((RuntimeError, Exception)): + conn.create_arrow_csr_rel_table( + "csr_rel", + "native_node", + "native_node", + _make_csr_fwd_indices(), + _make_csr_fwd_indptr(), + )