diff --git a/.gitignore b/.gitignore
index b3e5e211..29ac9edf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,5 +7,8 @@
# Local integration test datasets
integration/testdata/local/
+# Local benchmark comparison output
+.bench/
+
# Local test and metric artifacts
.coverage/
diff --git a/Makefile b/Makefile
index 81dd8363..71b69d9e 100644
--- a/Makefile
+++ b/Makefile
@@ -3,6 +3,12 @@ THIS_FILE := $(lastword $(MAKEFILE_LIST))
# Go configuration
GO_CMD ?= go
CGO_ENABLED ?= 0
+BENCH ?= .
+BENCH_COUNT ?= 10
+BENCH_TIME ?= 1s
+BENCH_BASE ?= main
+BENCH_TARGET ?= HEAD
+BENCH_KIND ?= all
# Main packages to test/build
MAIN_PACKAGES := $(shell $(GO_CMD) list ./...)
@@ -92,6 +98,14 @@ test_integration:
@echo "Running all integration tests..."
@$(GO_CMD) test -tags 'manual_integration integration' -race -cover -count=1 -p=1 -parallel=1 $(MAIN_PACKAGES)
+test_bench:
+ @echo "Running benchmarks..."
+ @$(GO_CMD) test -run '^$$' -bench '$(BENCH)' -benchmem -count=$(BENCH_COUNT) -benchtime=$(BENCH_TIME) $(MAIN_PACKAGES)
+
+bench_diff:
+ @echo "Running benchmark diff..."
+ @$(GO_CMD) run ./cmd/benchdiff --base '$(BENCH_BASE)' --target '$(BENCH_TARGET)' --kind '$(BENCH_KIND)' --bench '$(BENCH)' --bench-count '$(BENCH_COUNT)' --benchtime '$(BENCH_TIME)' $(BENCHDIFF_ARGS)
+
test_neo4j:
@echo "Running Neo4j integration tests..."
@$(GO_CMD) test -tags integration -race -cover -count=1 -p=1 -parallel=1 $(MAIN_PACKAGES)
@@ -216,6 +230,7 @@ help:
@echo " test_all - Run all tests including integration tests"
@echo " test_integration - Run all integration tests"
@echo " test_bench - Run benchmark test"
+ @echo " bench_diff - Compare benchmarks between commits"
@echo " test_neo4j - Run Neo4j integration tests"
@echo " test_pg - Run PostgreSQL integration tests"
@echo " test_update - Update test cases"
diff --git a/README.md b/README.md
index e1aad7bb..f2eef241 100644
--- a/README.md
+++ b/README.md
@@ -56,6 +56,35 @@ export CONNECTION_STRING="neo4j://neo4j:weneedbetterpasswords@localhost:7687"
Use `make test` for unit tests only and `make test_integration` for integration tests only.
+### Benchmarking
+
+Run the package benchmark suite with:
+
+```bash
+make test_bench
+```
+
+Use `cmd/benchdiff` to compare benchmarks between two committed refs without changing the active worktree:
+
+```bash
+go run ./cmd/benchdiff -base main -target HEAD -kind unit
+```
+
+For integration benchmark comparisons, provide the same `CONNECTION_STRING` used by integration tests:
+
+```bash
+export CONNECTION_STRING="postgresql://dawgs:weneedbetterpasswords@localhost:65432/dawgs"
+go run ./cmd/benchdiff -base main -target HEAD -kind all -driver pg -fail-regression 10%
+```
+
+The harness writes raw outputs and a Markdown report under `.bench/runs/` by default. The report begins with comparison
+findings, includes the raw `benchstat` output for each benchmark suite, and ends with a table of all captured benchmark
+numbers.
+
+The integration benchmark runner includes committed `base` and `traversal_shapes` datasets by default. The traversal
+shape suite checks expected result counts for chain, fanout, bounded cycle, disconnected, edge-kind-selective, and
+multi-path shortest-path scenarios before recording timings.
+
### Test Metrics
`make test` writes unit test coverage artifacts under `.coverage/`:
diff --git a/cmd/benchdiff/README.md b/cmd/benchdiff/README.md
new file mode 100644
index 00000000..41774194
--- /dev/null
+++ b/cmd/benchdiff/README.md
@@ -0,0 +1,52 @@
+# Benchdiff
+
+Compares the existing benchmark suites between two committed git refs without changing the active worktree.
+
+## Usage
+
+```bash
+# Unit Go benchmarks only
+go run ./cmd/benchdiff -base main -target HEAD -kind unit
+
+# Unit and integration benchmarks
+export CONNECTION_STRING="postgresql://dawgs:weneedbetterpasswords@localhost:65432/dawgs"
+go run ./cmd/benchdiff -base main -target HEAD -kind all -driver pg
+
+# Fail if a benchmark median regresses by more than 10%
+go run ./cmd/benchdiff -base main -target HEAD -kind unit -fail-regression 10%
+```
+
+`benchdiff` creates detached worktrees under `.bench/`, runs each selected benchmark suite, writes raw output, and
+produces a Markdown report. The report starts with comparison findings, including median regressions, improvements,
+unchanged counts, and benchmark names that only appeared in one ref. It ends with an `All Executed Benchmark Numbers`
+section that lists the median, percent change, and sample counts for every benchmark captured in either ref. Worktrees
+are removed by default after the run; pass `-keep-worktrees` to preserve them.
+
+## Flags
+
+| Flag | Default | Description |
+|------|---------|-------------|
+| `-base` | `main` | Base git ref |
+| `-target` | `HEAD` | Target git ref |
+| `-kind` | `all` | Benchmark kind (`all`, `unit`, `integration`) |
+| `-packages` | `./...` | Package list for Go benchmarks |
+| `-bench` | `.` | Go benchmark regexp |
+| `-bench-count` | `10` | Go benchmark repetition count |
+| `-benchtime` | `1s` | Go benchmark benchtime |
+| `-driver` | `pg` | Integration benchmark database driver |
+| `-connection` | | Integration connection string (or `CONNECTION_STRING`) |
+| `-dataset` | | Run only this integration dataset |
+| `-local-dataset` | | Add a local integration dataset |
+| `-dataset-dir` | `integration/testdata` | Integration testdata directory |
+| `-integration-iterations` | `10` | Timed iterations per integration scenario |
+| `-out` | `.bench/runs/..-` | Output directory |
+| `-benchstat` | `auto` | `benchstat` command, `auto`, or `none` |
+| `-fail-regression` | `0` | Median regression percentage that fails the command |
+| `-keep-worktrees` | `false` | Preserve temporary worktrees |
+
+If `benchstat` is not on `PATH` and `-benchstat auto` is used, the harness falls back to
+`go run golang.org/x/perf/cmd/benchstat@latest`.
+
+Integration comparisons use native `cmd/benchmark -format benchfmt` when both refs support it. If either ref predates
+that flag, the harness runs both refs in Markdown compatibility mode and compares each scenario's median as a single
+`ns/op` sample.
diff --git a/cmd/benchdiff/benchfmt.go b/cmd/benchdiff/benchfmt.go
new file mode 100644
index 00000000..cdc5a68d
--- /dev/null
+++ b/cmd/benchdiff/benchfmt.go
@@ -0,0 +1,325 @@
+// Copyright 2026 Specter Ops, Inc.
+//
+// Licensed under the Apache License, Version 2.0
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package main
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "math"
+ "os"
+ "regexp"
+ "runtime"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+ "unicode"
+)
+
+var benchmarkLinePattern = regexp.MustCompile(`^(Benchmark\S+)\s+\d+\s+([0-9]+(?:\.[0-9]+)?)\s+ns/op\b`)
+
+type benchmarkSamples map[string][]float64
+
+type comparisonFindings struct {
+ Compared int
+ Regressions []benchmarkFinding
+ Improvements []benchmarkFinding
+ Unchanged int
+ OnlyBase []string
+ OnlyTarget []string
+ Results []benchmarkResult
+}
+
+type benchmarkFinding struct {
+ Name string
+ BaseMedianNS float64
+ TargetMedianNS float64
+ DeltaPercent float64
+}
+
+type benchmarkResult struct {
+ Name string
+ BaseMedianNS float64
+ TargetMedianNS float64
+ DeltaPercent float64
+ BaseSamples int
+ TargetSamples int
+ HasBase bool
+ HasTarget bool
+}
+
+type regression struct {
+ Name string
+ BaseMedianNS float64
+ TargetMedianNS float64
+ Percent float64
+}
+
+func parseBenchfmtNS(data []byte) benchmarkSamples {
+ samples := benchmarkSamples{}
+ scanner := bufio.NewScanner(bytes.NewReader(data))
+
+ for scanner.Scan() {
+ matches := benchmarkLinePattern.FindStringSubmatch(scanner.Text())
+ if len(matches) != 3 {
+ continue
+ }
+
+ ns, err := strconv.ParseFloat(matches[2], 64)
+ if err != nil {
+ continue
+ }
+
+ samples[matches[1]] = append(samples[matches[1]], ns)
+ }
+
+ return samples
+}
+
+func parseBenchfmtNSFile(path string) (benchmarkSamples, error) {
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+
+ return parseBenchfmtNS(data), nil
+}
+
+func summarizeFindings(base, target benchmarkSamples) comparisonFindings {
+ findings := comparisonFindings{}
+ names := map[string]struct{}{}
+
+ for name := range base {
+ names[name] = struct{}{}
+ }
+ for name := range target {
+ names[name] = struct{}{}
+ }
+
+ for name := range names {
+ baseValues := base[name]
+ targetValues := target[name]
+ result := benchmarkResult{
+ Name: name,
+ BaseSamples: len(baseValues),
+ TargetSamples: len(targetValues),
+ HasBase: len(baseValues) > 0,
+ HasTarget: len(targetValues) > 0,
+ }
+
+ switch {
+ case len(baseValues) == 0:
+ result.TargetMedianNS = median(targetValues)
+ findings.Results = append(findings.Results, result)
+ findings.OnlyTarget = append(findings.OnlyTarget, name)
+ continue
+ case len(targetValues) == 0:
+ result.BaseMedianNS = median(baseValues)
+ findings.Results = append(findings.Results, result)
+ findings.OnlyBase = append(findings.OnlyBase, name)
+ continue
+ }
+
+ baseMedian := median(baseValues)
+ targetMedian := median(targetValues)
+ if baseMedian <= 0 {
+ findings.Results = append(findings.Results, result)
+ continue
+ }
+
+ result.BaseMedianNS = baseMedian
+ result.TargetMedianNS = targetMedian
+ findings.Compared++
+ deltaPercent := ((targetMedian - baseMedian) / baseMedian) * 100
+ result.DeltaPercent = deltaPercent
+ findings.Results = append(findings.Results, result)
+
+ finding := benchmarkFinding{
+ Name: name,
+ BaseMedianNS: baseMedian,
+ TargetMedianNS: targetMedian,
+ DeltaPercent: deltaPercent,
+ }
+
+ switch {
+ case deltaPercent > 0:
+ findings.Regressions = append(findings.Regressions, finding)
+ case deltaPercent < 0:
+ findings.Improvements = append(findings.Improvements, finding)
+ default:
+ findings.Unchanged++
+ }
+ }
+
+ sort.Slice(findings.Regressions, func(i, j int) bool {
+ return findings.Regressions[i].DeltaPercent > findings.Regressions[j].DeltaPercent
+ })
+ sort.Slice(findings.Improvements, func(i, j int) bool {
+ return findings.Improvements[i].DeltaPercent < findings.Improvements[j].DeltaPercent
+ })
+ sort.Strings(findings.OnlyBase)
+ sort.Strings(findings.OnlyTarget)
+ sort.Slice(findings.Results, func(i, j int) bool {
+ return findings.Results[i].Name < findings.Results[j].Name
+ })
+
+ return findings
+}
+
+func findingsForFiles(baseFile, targetFile string) (comparisonFindings, error) {
+ base, err := parseBenchfmtNSFile(baseFile)
+ if err != nil {
+ return comparisonFindings{}, err
+ }
+ target, err := parseBenchfmtNSFile(targetFile)
+ if err != nil {
+ return comparisonFindings{}, err
+ }
+
+ return summarizeFindings(base, target), nil
+}
+
+func (findings comparisonFindings) regressionsOver(threshold float64) []regression {
+ if threshold <= 0 {
+ return nil
+ }
+
+ var regressions []regression
+ for _, finding := range findings.Regressions {
+ if finding.DeltaPercent <= threshold {
+ continue
+ }
+
+ regressions = append(regressions, regression{
+ Name: finding.Name,
+ BaseMedianNS: finding.BaseMedianNS,
+ TargetMedianNS: finding.TargetMedianNS,
+ Percent: finding.DeltaPercent,
+ })
+ }
+
+ return regressions
+}
+
+func median(values []float64) float64 {
+ sorted := append([]float64(nil), values...)
+ sort.Float64s(sorted)
+
+ mid := len(sorted) / 2
+ if len(sorted)%2 == 0 {
+ return (sorted[mid-1] + sorted[mid]) / 2
+ }
+
+ return sorted[mid]
+}
+
+func writeIntegrationBenchfmt(w io.Writer, driver string, rows []markdownBenchmarkRow) error {
+ fmt.Fprintf(w, "goos: %s\n", runtime.GOOS)
+ fmt.Fprintf(w, "goarch: %s\n", runtime.GOARCH)
+ fmt.Fprintln(w, "pkg: github.com/specterops/dawgs/cmd/benchmark")
+
+ procs := runtime.GOMAXPROCS(0)
+ for _, row := range rows {
+ fmt.Fprintf(w, "%s-%d\t1\t%d ns/op\n", integrationBenchmarkName(driver, row.Dataset, row.Query), procs, row.Median.Nanoseconds())
+ }
+
+ return nil
+}
+
+func integrationBenchmarkName(driver, dataset, query string) string {
+ return strings.Join([]string{
+ "BenchmarkDawgsIntegration",
+ sanitizeBenchNamePart(driver),
+ sanitizeBenchNamePart(dataset),
+ sanitizeBenchNamePart(query),
+ }, "/")
+}
+
+func sanitizeBenchNamePart(value string) string {
+ var builder strings.Builder
+ lastUnderscore := false
+
+ for _, char := range value {
+ switch {
+ case char == '/' || char == '-' || char == '_':
+ if char == '_' {
+ if !lastUnderscore {
+ builder.WriteRune(char)
+ }
+ lastUnderscore = true
+ } else {
+ builder.WriteRune(char)
+ lastUnderscore = false
+ }
+ case unicode.IsLetter(char) || unicode.IsDigit(char):
+ builder.WriteRune(char)
+ lastUnderscore = false
+ case unicode.IsSpace(char):
+ if !lastUnderscore {
+ builder.WriteByte('_')
+ }
+ lastUnderscore = true
+ default:
+ if !lastUnderscore {
+ builder.WriteByte('_')
+ }
+ lastUnderscore = true
+ }
+ }
+
+ if builder.Len() == 0 {
+ return "unknown"
+ }
+
+ return builder.String()
+}
+
+func parseBenchmarkDuration(value string) (time.Duration, error) {
+ trimmed := strings.TrimSpace(value)
+ if trimmed == "" || trimmed == "-" {
+ return 0, fmt.Errorf("empty benchmark duration")
+ }
+
+ unitStart := len(trimmed)
+ for idx, char := range trimmed {
+ if (char < '0' || char > '9') && char != '.' {
+ unitStart = idx
+ break
+ }
+ }
+
+ number, err := strconv.ParseFloat(strings.TrimSpace(trimmed[:unitStart]), 64)
+ if err != nil {
+ return 0, err
+ }
+
+ unit := strings.TrimSpace(trimmed[unitStart:])
+ switch unit {
+ case "ns":
+ return time.Duration(math.Round(number)), nil
+ case "us":
+ return time.Duration(math.Round(number * float64(time.Microsecond))), nil
+ case "ms":
+ return time.Duration(math.Round(number * float64(time.Millisecond))), nil
+ case "s":
+ return time.Duration(math.Round(number * float64(time.Second))), nil
+ default:
+ return 0, fmt.Errorf("unsupported benchmark duration unit %q", unit)
+ }
+}
diff --git a/cmd/benchdiff/benchfmt_test.go b/cmd/benchdiff/benchfmt_test.go
new file mode 100644
index 00000000..0b692352
--- /dev/null
+++ b/cmd/benchdiff/benchfmt_test.go
@@ -0,0 +1,143 @@
+// Copyright 2026 Specter Ops, Inc.
+//
+// Licensed under the Apache License, Version 2.0
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package main
+
+import (
+ "bytes"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseBenchfmtNS(t *testing.T) {
+ samples := parseBenchfmtNS([]byte(`
+goos: linux
+BenchmarkThing-12 10 100.5 ns/op 1 B/op
+BenchmarkThing-12 10 120 ns/op 1 B/op
+BenchmarkOther/sub-12 1 200 ns/op
+`))
+
+ require.Equal(t, []float64{100.5, 120}, samples["BenchmarkThing-12"])
+ require.Equal(t, []float64{200}, samples["BenchmarkOther/sub-12"])
+}
+
+func TestSummarizeFindings(t *testing.T) {
+ base := benchmarkSamples{
+ "BenchmarkRegression-12": {100, 110, 120},
+ "BenchmarkImprovement-12": {200, 200, 200},
+ "BenchmarkSame-12": {100, 100, 100},
+ "BenchmarkOnlyBase-12": {50},
+ }
+ target := benchmarkSamples{
+ "BenchmarkRegression-12": {140, 150, 160},
+ "BenchmarkImprovement-12": {100, 100, 100},
+ "BenchmarkSame-12": {100, 100, 100},
+ "BenchmarkOnlyTarget-12": {75},
+ }
+
+ findings := summarizeFindings(base, target)
+
+ require.Equal(t, 3, findings.Compared)
+ require.Equal(t, 1, findings.Unchanged)
+ require.Equal(t, []string{"BenchmarkOnlyBase-12"}, findings.OnlyBase)
+ require.Equal(t, []string{"BenchmarkOnlyTarget-12"}, findings.OnlyTarget)
+ require.Len(t, findings.Results, 5)
+ require.Len(t, findings.Regressions, 1)
+ require.Equal(t, "BenchmarkRegression-12", findings.Regressions[0].Name)
+ require.InDelta(t, 36.36, findings.Regressions[0].DeltaPercent, 0.01)
+ require.Len(t, findings.Improvements, 1)
+ require.Equal(t, "BenchmarkImprovement-12", findings.Improvements[0].Name)
+ require.Equal(t, -50.0, findings.Improvements[0].DeltaPercent)
+
+ regressions := findings.regressionsOver(10)
+ require.Len(t, regressions, 1)
+ require.Equal(t, "BenchmarkRegression-12", regressions[0].Name)
+ require.InDelta(t, 36.36, regressions[0].Percent, 0.01)
+
+ onlyBase := requireBenchmarkResult(t, findings.Results, "BenchmarkOnlyBase-12")
+ require.True(t, onlyBase.HasBase)
+ require.False(t, onlyBase.HasTarget)
+ require.Equal(t, 1, onlyBase.BaseSamples)
+ require.Equal(t, 50.0, onlyBase.BaseMedianNS)
+
+ onlyTarget := requireBenchmarkResult(t, findings.Results, "BenchmarkOnlyTarget-12")
+ require.False(t, onlyTarget.HasBase)
+ require.True(t, onlyTarget.HasTarget)
+ require.Equal(t, 1, onlyTarget.TargetSamples)
+ require.Equal(t, 75.0, onlyTarget.TargetMedianNS)
+}
+
+func TestParseBenchmarkMarkdown(t *testing.T) {
+ rows := parseBenchmarkMarkdown([]byte(`
+| Query | Dataset | Median | P95 | Max |
+|-------|---------|-------:|----:|----:|
+| Match Nodes | base | 0.14ms | 0.22ms | 0.31ms |
+| Match Edges | base | 464ms | 604ms | 604ms |
+`))
+
+ require.Equal(t, []markdownBenchmarkRow{
+ {Query: "Match Nodes", Dataset: "base", Median: 140 * time.Microsecond},
+ {Query: "Match Edges", Dataset: "base", Median: 464 * time.Millisecond},
+ }, rows)
+}
+
+func TestWriteIntegrationBenchfmt(t *testing.T) {
+ var out bytes.Buffer
+ rows := []markdownBenchmarkRow{{
+ Query: "Shortest Paths / n1 -> n3",
+ Dataset: "base",
+ Median: time.Millisecond,
+ }}
+
+ require.NoError(t, writeIntegrationBenchfmt(&out, "pg", rows))
+ require.Contains(t, out.String(), "BenchmarkDawgsIntegration/pg/base/Shortest_Paths_/_n1_-_n3-")
+ require.Contains(t, out.String(), "\t1\t1000000 ns/op")
+}
+
+func TestParseRegressionThreshold(t *testing.T) {
+ threshold, err := parseRegressionThreshold("10%")
+ require.NoError(t, err)
+ require.Equal(t, 10.0, threshold)
+
+ threshold, err = parseRegressionThreshold("2.5")
+ require.NoError(t, err)
+ require.Equal(t, 2.5, threshold)
+
+ _, err = parseRegressionThreshold("-1")
+ require.Error(t, err)
+}
+
+func TestValidateBenchtime(t *testing.T) {
+ require.NoError(t, validateBenchtime("1s"))
+ require.NoError(t, validateBenchtime("100x"))
+ require.Error(t, validateBenchtime("0x"))
+ require.Error(t, validateBenchtime("soon"))
+}
+
+func requireBenchmarkResult(t *testing.T, results []benchmarkResult, name string) benchmarkResult {
+ t.Helper()
+
+ for _, result := range results {
+ if result.Name == name {
+ return result
+ }
+ }
+
+ require.Failf(t, "benchmark result not found", "%s", name)
+ return benchmarkResult{}
+}
diff --git a/cmd/benchdiff/command.go b/cmd/benchdiff/command.go
new file mode 100644
index 00000000..eae51723
--- /dev/null
+++ b/cmd/benchdiff/command.go
@@ -0,0 +1,91 @@
+// Copyright 2026 Specter Ops, Inc.
+//
+// Licensed under the Apache License, Version 2.0
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package main
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "os"
+ "os/exec"
+ "strings"
+)
+
+func gitOutput(ctx context.Context, dir string, args ...string) (string, error) {
+ output, err := runCommand(ctx, dir, nil, "git", args...)
+ if err != nil {
+ return "", err
+ }
+
+ return strings.TrimSpace(string(output)), nil
+}
+
+func runCommand(ctx context.Context, dir string, env []string, name string, args ...string) ([]byte, error) {
+ cmd := exec.CommandContext(ctx, name, args...)
+ cmd.Dir = dir
+ cmd.Env = os.Environ()
+ if len(env) > 0 {
+ cmd.Env = append(cmd.Env, env...)
+ }
+
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ return output, commandError{
+ Name: name,
+ Args: args,
+ Err: err,
+ Output: output,
+ }
+ }
+
+ return output, nil
+}
+
+type commandError struct {
+ Name string
+ Args []string
+ Err error
+ Output []byte
+}
+
+func (err commandError) Error() string {
+ var builder strings.Builder
+ builder.WriteString(err.Name)
+ if len(err.Args) > 0 {
+ builder.WriteByte(' ')
+ builder.WriteString(strings.Join(err.Args, " "))
+ }
+ builder.WriteString(": ")
+ builder.WriteString(err.Err.Error())
+
+ output := bytes.TrimSpace(err.Output)
+ if len(output) > 0 {
+ builder.WriteString("\n")
+ builder.Write(outputTail(output, 4096))
+ }
+
+ return builder.String()
+}
+
+func outputTail(output []byte, limit int) []byte {
+ if len(output) <= limit {
+ return output
+ }
+
+ prefix := []byte(fmt.Sprintf("... truncated %d bytes ...\n", len(output)-limit))
+ return append(prefix, output[len(output)-limit:]...)
+}
diff --git a/cmd/benchdiff/compare.go b/cmd/benchdiff/compare.go
new file mode 100644
index 00000000..10792f8b
--- /dev/null
+++ b/cmd/benchdiff/compare.go
@@ -0,0 +1,263 @@
+// Copyright 2026 Specter Ops, Inc.
+//
+// Licensed under the Apache License, Version 2.0
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package main
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strconv"
+ "strings"
+)
+
+func runUnitComparison(ctx context.Context, cfg resolvedConfig, baseWorktree, targetWorktree string) (comparison, error) {
+ outDir := filepath.Join(cfg.OutDirAbs, "unit")
+ if err := os.MkdirAll(outDir, 0755); err != nil {
+ return comparison{}, err
+ }
+
+ baseFile := filepath.Join(outDir, "base.txt")
+ targetFile := filepath.Join(outDir, "target.txt")
+
+ if err := runGoBenchmarks(ctx, cfg, baseWorktree, baseFile); err != nil {
+ return comparison{}, err
+ }
+ if err := runGoBenchmarks(ctx, cfg, targetWorktree, targetFile); err != nil {
+ return comparison{}, err
+ }
+
+ benchstatOutput, err := runBenchstat(ctx, cfg, baseFile, targetFile)
+ if err != nil {
+ return comparison{}, err
+ }
+ benchstatFile := filepath.Join(outDir, "benchstat.txt")
+ if err := os.WriteFile(benchstatFile, benchstatOutput, 0644); err != nil {
+ return comparison{}, err
+ }
+
+ findings, err := findingsForFiles(baseFile, targetFile)
+ if err != nil {
+ return comparison{}, err
+ }
+ regressions := findings.regressionsOver(cfg.Threshold)
+
+ return comparison{
+ Name: "Unit Benchmarks",
+ BaseFile: baseFile,
+ TargetFile: targetFile,
+ BenchstatFile: benchstatFile,
+ Benchstat: string(benchstatOutput),
+ Findings: findings,
+ Regressions: regressions,
+ }, nil
+}
+
+func runGoBenchmarks(ctx context.Context, cfg resolvedConfig, worktree, outputPath string) error {
+ args := []string{
+ "test",
+ "-run", "^$",
+ "-bench", cfg.Bench,
+ "-benchmem",
+ "-count", strconv.Itoa(cfg.BenchCount),
+ "-benchtime", cfg.Benchtime,
+ }
+ args = append(args, strings.Fields(cfg.Packages)...)
+
+ output, err := runCommand(ctx, worktree, nil, "go", args...)
+ if writeErr := os.WriteFile(outputPath, output, 0644); writeErr != nil {
+ return writeErr
+ }
+ if err != nil {
+ return fmt.Errorf("run Go benchmarks in %s: %w", worktree, err)
+ }
+
+ return nil
+}
+
+func runIntegrationComparison(ctx context.Context, cfg resolvedConfig, baseWorktree, targetWorktree string) (comparison, error) {
+ outDir := filepath.Join(cfg.OutDirAbs, "integration")
+ binDir := filepath.Join(cfg.OutDirAbs, "bin")
+ if err := os.MkdirAll(outDir, 0755); err != nil {
+ return comparison{}, err
+ }
+ if err := os.MkdirAll(binDir, 0755); err != nil {
+ return comparison{}, err
+ }
+
+ baseBinary := filepath.Join(binDir, "benchmark-base")
+ targetBinary := filepath.Join(binDir, "benchmark-target")
+ if err := buildBenchmarkBinary(ctx, baseWorktree, baseBinary); err != nil {
+ return comparison{}, err
+ }
+ if err := buildBenchmarkBinary(ctx, targetWorktree, targetBinary); err != nil {
+ return comparison{}, err
+ }
+
+ baseSupportsBenchfmt := benchmarkBinarySupportsFormat(ctx, baseBinary)
+ targetSupportsBenchfmt := benchmarkBinarySupportsFormat(ctx, targetBinary)
+ useNativeBenchfmt := baseSupportsBenchfmt && targetSupportsBenchfmt
+
+ baseFile := filepath.Join(outDir, "base.bench")
+ targetFile := filepath.Join(outDir, "target.bench")
+ var notes []string
+ if useNativeBenchfmt {
+ if err := runBenchmarkBinary(ctx, cfg, baseWorktree, baseBinary, baseFile, true); err != nil {
+ return comparison{}, err
+ }
+ if err := runBenchmarkBinary(ctx, cfg, targetWorktree, targetBinary, targetFile, true); err != nil {
+ return comparison{}, err
+ }
+ notes = append(notes, "Used native benchfmt output from cmd/benchmark.")
+ } else {
+ baseMarkdown := filepath.Join(outDir, "base.md")
+ targetMarkdown := filepath.Join(outDir, "target.md")
+
+ if err := runBenchmarkBinary(ctx, cfg, baseWorktree, baseBinary, baseMarkdown, false); err != nil {
+ return comparison{}, err
+ }
+ if err := runBenchmarkBinary(ctx, cfg, targetWorktree, targetBinary, targetMarkdown, false); err != nil {
+ return comparison{}, err
+ }
+
+ if err := markdownFileToBenchfmt(baseMarkdown, baseFile, cfg.Driver); err != nil {
+ return comparison{}, err
+ }
+ if err := markdownFileToBenchfmt(targetMarkdown, targetFile, cfg.Driver); err != nil {
+ return comparison{}, err
+ }
+
+ notes = append(notes, "Used Markdown compatibility mode because at least one ref does not support cmd/benchmark -format benchfmt.")
+ }
+
+ benchstatOutput, err := runBenchstat(ctx, cfg, baseFile, targetFile)
+ if err != nil {
+ return comparison{}, err
+ }
+ benchstatFile := filepath.Join(outDir, "benchstat.txt")
+ if err := os.WriteFile(benchstatFile, benchstatOutput, 0644); err != nil {
+ return comparison{}, err
+ }
+
+ findings, err := findingsForFiles(baseFile, targetFile)
+ if err != nil {
+ return comparison{}, err
+ }
+ regressions := findings.regressionsOver(cfg.Threshold)
+
+ return comparison{
+ Name: "Integration Benchmarks",
+ BaseFile: baseFile,
+ TargetFile: targetFile,
+ BenchstatFile: benchstatFile,
+ Benchstat: string(benchstatOutput),
+ Notes: notes,
+ Findings: findings,
+ Regressions: regressions,
+ }, nil
+}
+
+func buildBenchmarkBinary(ctx context.Context, worktree, outputPath string) error {
+ output, err := runCommand(ctx, worktree, nil, "go", "build", "-o", outputPath, "./cmd/benchmark")
+ if err != nil {
+ return fmt.Errorf("build cmd/benchmark in %s: %w", worktree, err)
+ }
+ if len(bytes.TrimSpace(output)) > 0 {
+ logPath := outputPath + ".log"
+ if writeErr := os.WriteFile(logPath, output, 0644); writeErr != nil {
+ return writeErr
+ }
+ }
+
+ return nil
+}
+
+func benchmarkBinarySupportsFormat(ctx context.Context, binaryPath string) bool {
+ output, err := runCommand(ctx, "", nil, binaryPath, "-h")
+ if err != nil {
+ return false
+ }
+
+ return bytes.Contains(output, []byte("-format"))
+}
+
+func runBenchmarkBinary(ctx context.Context, cfg resolvedConfig, worktree, binaryPath, outputPath string, benchfmt bool) error {
+ args := []string{
+ "-driver", cfg.Driver,
+ "-iterations", strconv.Itoa(cfg.IntegrationIterations),
+ "-dataset-dir", cfg.DatasetDirAbs,
+ "-output", outputPath,
+ }
+ if benchfmt {
+ args = append(args, "-format", "benchfmt")
+ }
+ if cfg.Dataset != "" {
+ args = append(args, "-dataset", cfg.Dataset)
+ }
+ if cfg.LocalDataset != "" {
+ args = append(args, "-local-dataset", cfg.LocalDataset)
+ }
+
+ output, err := runCommand(ctx, worktree, []string{"CONNECTION_STRING=" + cfg.Connection}, binaryPath, args...)
+ logPath := outputPath + ".log"
+ if writeErr := os.WriteFile(logPath, output, 0644); writeErr != nil {
+ return writeErr
+ }
+ if err != nil {
+ return fmt.Errorf("run integration benchmark in %s: %w", worktree, err)
+ }
+
+ return nil
+}
+
+func markdownFileToBenchfmt(markdownPath, benchfmtPath, driver string) error {
+ data, err := os.ReadFile(markdownPath)
+ if err != nil {
+ return err
+ }
+
+ var output bytes.Buffer
+ if err := writeIntegrationBenchfmt(&output, driver, parseBenchmarkMarkdown(data)); err != nil {
+ return err
+ }
+
+ return os.WriteFile(benchfmtPath, output.Bytes(), 0644)
+}
+
+func runBenchstat(ctx context.Context, cfg resolvedConfig, baseFile, targetFile string) ([]byte, error) {
+ if cfg.Benchstat == "none" {
+ return []byte("benchstat skipped\n"), nil
+ }
+
+ if cfg.Benchstat == "" || cfg.Benchstat == "auto" {
+ if benchstatPath, err := exec.LookPath("benchstat"); err == nil {
+ return runCommand(ctx, cfg.Root, nil, benchstatPath, baseFile, targetFile)
+ }
+
+ return runCommand(ctx, cfg.Root, nil, "go", "run", "golang.org/x/perf/cmd/benchstat@latest", baseFile, targetFile)
+ }
+
+ fields := strings.Fields(cfg.Benchstat)
+ if len(fields) == 0 {
+ return nil, fmt.Errorf("empty benchstat command")
+ }
+
+ args := append(fields[1:], baseFile, targetFile)
+ return runCommand(ctx, cfg.Root, nil, fields[0], args...)
+}
diff --git a/cmd/benchdiff/main.go b/cmd/benchdiff/main.go
new file mode 100644
index 00000000..df921705
--- /dev/null
+++ b/cmd/benchdiff/main.go
@@ -0,0 +1,180 @@
+// Copyright 2026 Specter Ops, Inc.
+//
+// Licensed under the Apache License, Version 2.0
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package main
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "os"
+ "strconv"
+ "strings"
+ "time"
+)
+
+const (
+ benchKindAll = "all"
+ benchKindIntegration = "integration"
+ benchKindUnit = "unit"
+)
+
+type config struct {
+ BaseRef string
+ TargetRef string
+ Kind string
+ Packages string
+ Bench string
+ BenchCount int
+ Benchtime string
+ Driver string
+ Connection string
+ Dataset string
+ LocalDataset string
+ DatasetDir string
+ IntegrationIterations int
+ OutDir string
+ Benchstat string
+ FailRegression string
+ KeepWorktrees bool
+}
+
+type resolvedConfig struct {
+ config
+ Root string
+ BaseSHA string
+ TargetSHA string
+ BaseShortSHA string
+ TargetShortSHA string
+ DatasetDirAbs string
+ OutDirAbs string
+ Threshold float64
+}
+
+func main() {
+ cfg, err := parseConfig(os.Args[1:])
+ if err != nil {
+ fmt.Fprintln(os.Stderr, err)
+ os.Exit(2)
+ }
+
+ if err := run(context.Background(), cfg); err != nil {
+ fmt.Fprintln(os.Stderr, err)
+ os.Exit(1)
+ }
+}
+
+func parseConfig(args []string) (config, error) {
+ cfg := config{}
+ flags := flag.NewFlagSet("benchdiff", flag.ContinueOnError)
+ flags.SetOutput(os.Stderr)
+
+ flags.StringVar(&cfg.BaseRef, "base", "main", "base git ref to benchmark")
+ flags.StringVar(&cfg.TargetRef, "target", "HEAD", "target git ref to benchmark")
+ flags.StringVar(&cfg.Kind, "kind", benchKindAll, "benchmark kind: all, unit, integration")
+ flags.StringVar(&cfg.Packages, "packages", "./...", "package list for Go benchmarks")
+ flags.StringVar(&cfg.Bench, "bench", ".", "Go benchmark regexp")
+ flags.IntVar(&cfg.BenchCount, "bench-count", 10, "Go benchmark repetition count")
+ flags.StringVar(&cfg.Benchtime, "benchtime", "1s", "Go benchmark benchtime")
+ flags.StringVar(&cfg.Driver, "driver", "pg", "integration benchmark database driver")
+ flags.StringVar(&cfg.Connection, "connection", "", "integration database connection string (or CONNECTION_STRING)")
+ flags.StringVar(&cfg.Dataset, "dataset", "", "run only this integration dataset")
+ flags.StringVar(&cfg.LocalDataset, "local-dataset", "", "additional local integration dataset")
+ flags.StringVar(&cfg.DatasetDir, "dataset-dir", "integration/testdata", "integration testdata directory")
+ flags.IntVar(&cfg.IntegrationIterations, "integration-iterations", 10, "timed iterations per integration scenario")
+ flags.StringVar(&cfg.OutDir, "out", "", "output directory")
+ flags.StringVar(&cfg.Benchstat, "benchstat", "auto", "benchstat command, auto, or none")
+ flags.StringVar(&cfg.FailRegression, "fail-regression", "0", "fail when median ns/op regression exceeds this percent, e.g. 10%")
+ flags.BoolVar(&cfg.KeepWorktrees, "keep-worktrees", false, "keep temporary git worktrees")
+
+ if err := flags.Parse(args); err != nil {
+ return config{}, err
+ }
+ if flags.NArg() != 0 {
+ return config{}, fmt.Errorf("unexpected positional arguments: %s", strings.Join(flags.Args(), " "))
+ }
+ if !isBenchKind(cfg.Kind) {
+ return config{}, fmt.Errorf("unsupported benchmark kind %q", cfg.Kind)
+ }
+ if cfg.BenchCount < 1 {
+ return config{}, fmt.Errorf("bench-count must be at least 1")
+ }
+ if cfg.IntegrationIterations < 1 {
+ return config{}, fmt.Errorf("integration-iterations must be at least 1")
+ }
+ if err := validateBenchtime(cfg.Benchtime); err != nil {
+ return config{}, err
+ }
+ if _, err := parseRegressionThreshold(cfg.FailRegression); err != nil {
+ return config{}, err
+ }
+
+ return cfg, nil
+}
+
+func isBenchKind(kind string) bool {
+ switch kind {
+ case benchKindAll, benchKindIntegration, benchKindUnit:
+ return true
+ default:
+ return false
+ }
+}
+
+func (cfg config) runsUnitBenchmarks() bool {
+ return cfg.Kind == benchKindAll || cfg.Kind == benchKindUnit
+}
+
+func (cfg config) runsIntegrationBenchmarks() bool {
+ return cfg.Kind == benchKindAll || cfg.Kind == benchKindIntegration
+}
+
+func parseRegressionThreshold(value string) (float64, error) {
+ trimmed := strings.TrimSpace(value)
+ trimmed = strings.TrimSuffix(trimmed, "%")
+
+ if trimmed == "" {
+ return 0, fmt.Errorf("fail-regression must be a non-negative percent")
+ }
+
+ threshold, err := strconv.ParseFloat(trimmed, 64)
+ if err != nil {
+ return 0, fmt.Errorf("invalid fail-regression %q: %w", value, err)
+ }
+ if threshold < 0 {
+ return 0, fmt.Errorf("fail-regression must be a non-negative percent")
+ }
+
+ return threshold, nil
+}
+
+func validateBenchtime(value string) error {
+ trimmed := strings.TrimSpace(value)
+ if strings.HasSuffix(trimmed, "x") {
+ count, err := strconv.Atoi(strings.TrimSuffix(trimmed, "x"))
+ if err != nil || count < 1 {
+ return fmt.Errorf("invalid benchtime %q", value)
+ }
+
+ return nil
+ }
+
+ if _, err := time.ParseDuration(trimmed); err != nil {
+ return fmt.Errorf("invalid benchtime %q: %w", value, err)
+ }
+
+ return nil
+}
diff --git a/cmd/benchdiff/markdown.go b/cmd/benchdiff/markdown.go
new file mode 100644
index 00000000..b472b3b3
--- /dev/null
+++ b/cmd/benchdiff/markdown.go
@@ -0,0 +1,76 @@
+// Copyright 2026 Specter Ops, Inc.
+//
+// Licensed under the Apache License, Version 2.0
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package main
+
+import (
+ "bufio"
+ "bytes"
+ "strings"
+ "time"
+)
+
+type markdownBenchmarkRow struct {
+ Query string
+ Dataset string
+ Median time.Duration
+}
+
+func parseBenchmarkMarkdown(data []byte) []markdownBenchmarkRow {
+ var rows []markdownBenchmarkRow
+ scanner := bufio.NewScanner(bytes.NewReader(data))
+
+ for scanner.Scan() {
+ columns := splitMarkdownTableRow(scanner.Text())
+ if len(columns) < 5 {
+ continue
+ }
+ if columns[0] == "Query" || strings.HasPrefix(columns[0], "---") {
+ continue
+ }
+
+ median, err := parseBenchmarkDuration(columns[2])
+ if err != nil {
+ continue
+ }
+
+ rows = append(rows, markdownBenchmarkRow{
+ Query: columns[0],
+ Dataset: columns[1],
+ Median: median,
+ })
+ }
+
+ return rows
+}
+
+func splitMarkdownTableRow(line string) []string {
+ trimmed := strings.TrimSpace(line)
+ if !strings.HasPrefix(trimmed, "|") || !strings.HasSuffix(trimmed, "|") {
+ return nil
+ }
+
+ trimmed = strings.TrimPrefix(trimmed, "|")
+ trimmed = strings.TrimSuffix(trimmed, "|")
+
+ rawColumns := strings.Split(trimmed, "|")
+ columns := make([]string, 0, len(rawColumns))
+ for _, column := range rawColumns {
+ columns = append(columns, strings.TrimSpace(column))
+ }
+
+ return columns
+}
diff --git a/cmd/benchdiff/report.go b/cmd/benchdiff/report.go
new file mode 100644
index 00000000..416c69b3
--- /dev/null
+++ b/cmd/benchdiff/report.go
@@ -0,0 +1,292 @@
+// Copyright 2026 Specter Ops, Inc.
+//
+// Licensed under the Apache License, Version 2.0
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "time"
+)
+
+const maxFindingRows = 10
+
+func writeRunReport(path string, summary runSummary) error {
+ var out bytes.Buffer
+ cfg := summary.Config
+
+ fmt.Fprintln(&out, "# Benchmark Diff")
+ fmt.Fprintln(&out)
+ fmt.Fprintln(&out, "| Field | Value |")
+ fmt.Fprintln(&out, "|-------|-------|")
+ fmt.Fprintf(&out, "| Base | `%s` (`%s`) |\n", cfg.BaseRef, cfg.BaseShortSHA)
+ fmt.Fprintf(&out, "| Target | `%s` (`%s`) |\n", cfg.TargetRef, cfg.TargetShortSHA)
+ fmt.Fprintf(&out, "| Started | %s |\n", summary.StartedAt.UTC().Format(time.RFC3339))
+ fmt.Fprintf(&out, "| Finished | %s |\n", summary.FinishedAt.UTC().Format(time.RFC3339))
+ fmt.Fprintf(&out, "| Go | %s |\n", summary.GoVersion)
+ fmt.Fprintf(&out, "| Platform | %s/%s |\n", runtime.GOOS, runtime.GOARCH)
+ fmt.Fprintf(&out, "| Kind | `%s` |\n", cfg.Kind)
+ fmt.Fprintf(&out, "| Output | `%s` |\n", cfg.OutDirAbs)
+ if cfg.runsIntegrationBenchmarks() {
+ fmt.Fprintf(&out, "| Driver | `%s` |\n", cfg.Driver)
+ fmt.Fprintf(&out, "| Dataset Dir | `%s` |\n", cfg.DatasetDirAbs)
+ fmt.Fprintf(&out, "| Integration Iterations | %d |\n", cfg.IntegrationIterations)
+ }
+ if cfg.runsUnitBenchmarks() {
+ fmt.Fprintf(&out, "| Packages | `%s` |\n", cfg.Packages)
+ fmt.Fprintf(&out, "| Bench | `%s` |\n", cfg.Bench)
+ fmt.Fprintf(&out, "| Bench Count | %d |\n", cfg.BenchCount)
+ fmt.Fprintf(&out, "| Benchtime | `%s` |\n", cfg.Benchtime)
+ }
+ if cfg.Threshold > 0 {
+ fmt.Fprintf(&out, "| Regression Failure Threshold | %.2f%% |\n", cfg.Threshold)
+ }
+ fmt.Fprintln(&out)
+
+ writeFindingsSummary(&out, summary)
+
+ for _, comparison := range summary.Comparisons {
+ fmt.Fprintf(&out, "## %s\n\n", comparison.Name)
+ for _, note := range comparison.Notes {
+ fmt.Fprintf(&out, "- %s\n", note)
+ }
+ if len(comparison.Notes) > 0 {
+ fmt.Fprintln(&out)
+ }
+
+ fmt.Fprintf(&out, "- Base raw: `%s`\n", relOrAbs(cfg.OutDirAbs, comparison.BaseFile))
+ fmt.Fprintf(&out, "- Target raw: `%s`\n", relOrAbs(cfg.OutDirAbs, comparison.TargetFile))
+ fmt.Fprintf(&out, "- Benchstat: `%s`\n\n", relOrAbs(cfg.OutDirAbs, comparison.BenchstatFile))
+
+ fmt.Fprintln(&out, "```text")
+ fmt.Fprint(&out, comparison.Benchstat)
+ if len(comparison.Benchstat) == 0 || comparison.Benchstat[len(comparison.Benchstat)-1] != '\n' {
+ fmt.Fprintln(&out)
+ }
+ fmt.Fprintln(&out, "```")
+ fmt.Fprintln(&out)
+
+ writeRegressionSection(&out, comparison, cfg.Threshold)
+ }
+
+ writeAllExecutedNumbers(&out, summary)
+
+ if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
+ return err
+ }
+
+ return os.WriteFile(path, out.Bytes(), 0644)
+}
+
+func writeAllExecutedNumbers(out *bytes.Buffer, summary runSummary) {
+ fmt.Fprintln(out, "## All Executed Benchmark Numbers")
+ fmt.Fprintln(out)
+
+ if len(summary.Comparisons) == 0 {
+ fmt.Fprintln(out, "No benchmark numbers were captured.")
+ fmt.Fprintln(out)
+ return
+ }
+
+ for _, comparison := range summary.Comparisons {
+ fmt.Fprintf(out, "### %s\n\n", comparison.Name)
+ if len(comparison.Findings.Results) == 0 {
+ fmt.Fprintln(out, "No benchmark numbers were captured.")
+ fmt.Fprintln(out)
+ continue
+ }
+
+ fmt.Fprintln(out, "| Benchmark | Base Median | Target Median | Change | Base Samples | Target Samples |")
+ fmt.Fprintln(out, "|-----------|------------:|--------------:|-------:|-------------:|---------------:|")
+ for _, result := range comparison.Findings.Results {
+ fmt.Fprintf(out, "| `%s` | %s | %s | %s | %s | %s |\n",
+ result.Name,
+ formatOptionalNS(result.HasBase, result.BaseMedianNS),
+ formatOptionalNS(result.HasTarget, result.TargetMedianNS),
+ formatOptionalPercent(result.HasBase && result.HasTarget, result.DeltaPercent),
+ formatOptionalInt(result.HasBase, result.BaseSamples),
+ formatOptionalInt(result.HasTarget, result.TargetSamples),
+ )
+ }
+ fmt.Fprintln(out)
+ }
+}
+
+func writeFindingsSummary(out *bytes.Buffer, summary runSummary) {
+ fmt.Fprintln(out, "## Findings")
+ fmt.Fprintln(out)
+
+ if len(summary.Comparisons) == 0 {
+ fmt.Fprintln(out, "No benchmark comparisons were run.")
+ fmt.Fprintln(out)
+ return
+ }
+
+ for _, comparison := range summary.Comparisons {
+ findings := comparison.Findings
+ fmt.Fprintf(out, "### %s\n\n", comparison.Name)
+ fmt.Fprintf(out, "- Compared %d matching benchmark%s.\n", findings.Compared, pluralSuffix(findings.Compared))
+ fmt.Fprintf(out, "- Median regressions: %d; median improvements: %d; unchanged: %d.\n",
+ len(findings.Regressions),
+ len(findings.Improvements),
+ findings.Unchanged,
+ )
+ if len(findings.OnlyBase) > 0 {
+ fmt.Fprintf(out, "- Only in base: %s.\n", inlineBenchmarkList(findings.OnlyBase, maxFindingRows))
+ }
+ if len(findings.OnlyTarget) > 0 {
+ fmt.Fprintf(out, "- Only in target: %s.\n", inlineBenchmarkList(findings.OnlyTarget, maxFindingRows))
+ }
+ fmt.Fprintln(out)
+
+ writeFindingTable(out, "Top Median Regressions", findings.Regressions, maxFindingRows)
+ writeFindingTable(out, "Top Median Improvements", findings.Improvements, maxFindingRows)
+ }
+}
+
+func writeFindingTable(out *bytes.Buffer, title string, findings []benchmarkFinding, limit int) {
+ fmt.Fprintf(out, "#### %s\n\n", title)
+ if len(findings) == 0 {
+ fmt.Fprintln(out, "None.")
+ fmt.Fprintln(out)
+ return
+ }
+
+ fmt.Fprintln(out, "| Benchmark | Base Median | Target Median | Change |")
+ fmt.Fprintln(out, "|-----------|------------:|--------------:|-------:|")
+
+ for idx, finding := range findings {
+ if idx >= limit {
+ break
+ }
+
+ fmt.Fprintf(out, "| `%s` | %s | %s | %+.2f%% |\n",
+ finding.Name,
+ formatNS(finding.BaseMedianNS),
+ formatNS(finding.TargetMedianNS),
+ finding.DeltaPercent,
+ )
+ }
+ if len(findings) > limit {
+ fmt.Fprintf(out, "\n_%d more not shown._\n", len(findings)-limit)
+ }
+ fmt.Fprintln(out)
+}
+
+func pluralSuffix(count int) string {
+ if count == 1 {
+ return ""
+ }
+
+ return "s"
+}
+
+func inlineBenchmarkList(names []string, limit int) string {
+ var builder strings.Builder
+
+ for idx, name := range names {
+ if idx >= limit {
+ break
+ }
+ if idx > 0 {
+ builder.WriteString(", ")
+ }
+ builder.WriteByte('`')
+ builder.WriteString(name)
+ builder.WriteByte('`')
+ }
+
+ if len(names) > limit {
+ fmt.Fprintf(&builder, ", and %d more", len(names)-limit)
+ }
+
+ return builder.String()
+}
+
+func formatNS(value float64) string {
+ switch {
+ case value >= float64(time.Second):
+ return fmt.Sprintf("%.2fs", value/float64(time.Second))
+ case value >= float64(time.Millisecond):
+ return fmt.Sprintf("%.2fms", value/float64(time.Millisecond))
+ case value >= float64(time.Microsecond):
+ return fmt.Sprintf("%.2fus", value/float64(time.Microsecond))
+ default:
+ return fmt.Sprintf("%.0fns", value)
+ }
+}
+
+func formatOptionalNS(ok bool, value float64) string {
+ if !ok {
+ return "-"
+ }
+
+ return formatNS(value)
+}
+
+func formatOptionalPercent(ok bool, value float64) string {
+ if !ok {
+ return "-"
+ }
+
+ return fmt.Sprintf("%+.2f%%", value)
+}
+
+func formatOptionalInt(ok bool, value int) string {
+ if !ok {
+ return "-"
+ }
+
+ return fmt.Sprintf("%d", value)
+}
+
+func writeRegressionSection(out *bytes.Buffer, comparison comparison, threshold float64) {
+ if threshold <= 0 {
+ return
+ }
+
+ fmt.Fprintf(out, "### Regressions Over %.2f%%\n\n", threshold)
+ if len(comparison.Regressions) == 0 {
+ fmt.Fprintln(out, "None.")
+ fmt.Fprintln(out)
+ return
+ }
+
+ fmt.Fprintln(out, "| Benchmark | Base Median | Target Median | Change |")
+ fmt.Fprintln(out, "|-----------|------------:|--------------:|-------:|")
+ for _, regression := range comparison.Regressions {
+ fmt.Fprintf(out, "| `%s` | %s | %s | +%.2f%% |\n",
+ regression.Name,
+ formatNS(regression.BaseMedianNS),
+ formatNS(regression.TargetMedianNS),
+ regression.Percent,
+ )
+ }
+ fmt.Fprintln(out)
+}
+
+func relOrAbs(base, path string) string {
+ rel, err := filepath.Rel(base, path)
+ if err != nil || rel == "." || len(rel) >= len(path) {
+ return path
+ }
+
+ return rel
+}
diff --git a/cmd/benchdiff/report_test.go b/cmd/benchdiff/report_test.go
new file mode 100644
index 00000000..1e09694d
--- /dev/null
+++ b/cmd/benchdiff/report_test.go
@@ -0,0 +1,141 @@
+// Copyright 2026 Specter Ops, Inc.
+//
+// Licensed under the Apache License, Version 2.0
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package main
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestWriteRunReportIncludesTopLevelFindings(t *testing.T) {
+ outputPath := filepath.Join(t.TempDir(), "report.md")
+ summary := runSummary{
+ Config: resolvedConfig{
+ config: config{
+ BaseRef: "main",
+ TargetRef: "HEAD",
+ Kind: benchKindUnit,
+ Packages: "./...",
+ Bench: ".",
+ BenchCount: 3,
+ Benchtime: "1s",
+ FailRegression: "10%",
+ },
+ BaseShortSHA: "abc1234",
+ TargetShortSHA: "def5678",
+ OutDirAbs: filepath.Dir(outputPath),
+ Threshold: 10,
+ },
+ GoVersion: "go-test",
+ StartedAt: time.Date(2026, 5, 11, 1, 2, 3, 0, time.UTC),
+ FinishedAt: time.Date(2026, 5, 11, 1, 2, 4, 0, time.UTC),
+ Comparisons: []comparison{{
+ Name: "Unit Benchmarks",
+ BaseFile: filepath.Join(filepath.Dir(outputPath), "unit", "base.txt"),
+ TargetFile: filepath.Join(filepath.Dir(outputPath), "unit", "target.txt"),
+ BenchstatFile: filepath.Join(filepath.Dir(outputPath), "unit", "benchstat.txt"),
+ Benchstat: "benchstat output\n",
+ Findings: comparisonFindings{
+ Compared: 3,
+ Unchanged: 1,
+ Regressions: []benchmarkFinding{{
+ Name: "BenchmarkSlow-12",
+ BaseMedianNS: 100,
+ TargetMedianNS: 150,
+ DeltaPercent: 50,
+ }},
+ Improvements: []benchmarkFinding{{
+ Name: "BenchmarkFast-12",
+ BaseMedianNS: 200,
+ TargetMedianNS: 100,
+ DeltaPercent: -50,
+ }},
+ OnlyBase: []string{"BenchmarkRemoved-12"},
+ OnlyTarget: []string{"BenchmarkAdded-12"},
+ Results: []benchmarkResult{
+ {
+ Name: "BenchmarkAdded-12",
+ TargetMedianNS: 75,
+ TargetSamples: 1,
+ HasTarget: true,
+ },
+ {
+ Name: "BenchmarkFast-12",
+ BaseMedianNS: 200,
+ TargetMedianNS: 100,
+ DeltaPercent: -50,
+ BaseSamples: 3,
+ TargetSamples: 3,
+ HasBase: true,
+ HasTarget: true,
+ },
+ {
+ Name: "BenchmarkRemoved-12",
+ BaseMedianNS: 50,
+ BaseSamples: 1,
+ HasBase: true,
+ },
+ {
+ Name: "BenchmarkSlow-12",
+ BaseMedianNS: 100,
+ TargetMedianNS: 150,
+ DeltaPercent: 50,
+ BaseSamples: 3,
+ TargetSamples: 3,
+ HasBase: true,
+ HasTarget: true,
+ },
+ },
+ },
+ Regressions: []regression{{
+ Name: "BenchmarkSlow-12",
+ BaseMedianNS: 100,
+ TargetMedianNS: 150,
+ Percent: 50,
+ }},
+ }},
+ }
+
+ require.NoError(t, writeRunReport(outputPath, summary))
+
+ report, err := os.ReadFile(outputPath)
+ require.NoError(t, err)
+ output := string(report)
+
+ require.Contains(t, output, "## Findings")
+ require.Contains(t, output, "### Unit Benchmarks")
+ require.Contains(t, output, "- Compared 3 matching benchmarks.")
+ require.Contains(t, output, "- Median regressions: 1; median improvements: 1; unchanged: 1.")
+ require.Contains(t, output, "- Only in base: `BenchmarkRemoved-12`.")
+ require.Contains(t, output, "- Only in target: `BenchmarkAdded-12`.")
+ require.Contains(t, output, "#### Top Median Regressions")
+ require.Contains(t, output, "| `BenchmarkSlow-12` | 100ns | 150ns | +50.00% |")
+ require.Contains(t, output, "#### Top Median Improvements")
+ require.Contains(t, output, "| `BenchmarkFast-12` | 200ns | 100ns | -50.00% |")
+ require.Contains(t, output, "## Unit Benchmarks")
+ require.Contains(t, output, "benchstat output")
+ require.Contains(t, output, "## All Executed Benchmark Numbers")
+ require.Contains(t, output, "| Benchmark | Base Median | Target Median | Change | Base Samples | Target Samples |")
+ require.Contains(t, output, "| `BenchmarkSlow-12` | 100ns | 150ns | +50.00% | 3 | 3 |")
+ require.Contains(t, output, "| `BenchmarkFast-12` | 200ns | 100ns | -50.00% | 3 | 3 |")
+ require.Contains(t, output, "| `BenchmarkRemoved-12` | 50ns | - | - | 1 | - |")
+ require.Contains(t, output, "| `BenchmarkAdded-12` | - | 75ns | - | - | 1 |")
+}
diff --git a/cmd/benchdiff/run.go b/cmd/benchdiff/run.go
new file mode 100644
index 00000000..12afb66a
--- /dev/null
+++ b/cmd/benchdiff/run.go
@@ -0,0 +1,260 @@
+// Copyright 2026 Specter Ops, Inc.
+//
+// Licensed under the Apache License, Version 2.0
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "time"
+)
+
+type comparison struct {
+ Name string
+ BaseFile string
+ TargetFile string
+ BenchstatFile string
+ Benchstat string
+ Notes []string
+ Findings comparisonFindings
+ Regressions []regression
+}
+
+type runSummary struct {
+ Config resolvedConfig
+ GoVersion string
+ StartedAt time.Time
+ FinishedAt time.Time
+ Comparisons []comparison
+ ReportPath string
+}
+
+func run(ctx context.Context, cfg config) error {
+ resolved, err := resolveConfig(ctx, cfg)
+ if err != nil {
+ return err
+ }
+
+ summary := runSummary{
+ Config: resolved,
+ GoVersion: runtime.Version(),
+ StartedAt: time.Now(),
+ }
+
+ if err := os.MkdirAll(resolved.OutDirAbs, 0755); err != nil {
+ return fmt.Errorf("create output directory: %w", err)
+ }
+
+ worktreeRoot := filepath.Join(resolved.OutDirAbs, "worktrees")
+ baseWorktree := filepath.Join(worktreeRoot, "base")
+ targetWorktree := filepath.Join(worktreeRoot, "target")
+
+ fmt.Fprintf(os.Stderr, "preparing benchmark worktrees for %s and %s...\n", resolved.BaseShortSHA, resolved.TargetShortSHA)
+
+ if err := addWorktree(ctx, resolved.Root, baseWorktree, resolved.BaseSHA); err != nil {
+ return err
+ }
+ removeBase := true
+ defer func() {
+ if removeBase && !resolved.KeepWorktrees {
+ _ = removeWorktree(context.Background(), resolved.Root, baseWorktree)
+ }
+ }()
+
+ if err := addWorktree(ctx, resolved.Root, targetWorktree, resolved.TargetSHA); err != nil {
+ return err
+ }
+ removeTarget := true
+ defer func() {
+ if removeTarget && !resolved.KeepWorktrees {
+ _ = removeWorktree(context.Background(), resolved.Root, targetWorktree)
+ }
+ }()
+
+ if resolved.runsUnitBenchmarks() {
+ fmt.Fprintln(os.Stderr, "running unit benchmarks...")
+ unitComparison, err := runUnitComparison(ctx, resolved, baseWorktree, targetWorktree)
+ if err != nil {
+ return err
+ }
+ summary.Comparisons = append(summary.Comparisons, unitComparison)
+ }
+
+ if resolved.runsIntegrationBenchmarks() {
+ fmt.Fprintln(os.Stderr, "running integration benchmarks...")
+ integrationComparison, err := runIntegrationComparison(ctx, resolved, baseWorktree, targetWorktree)
+ if err != nil {
+ return err
+ }
+ summary.Comparisons = append(summary.Comparisons, integrationComparison)
+ }
+
+ removeBase = false
+ removeTarget = false
+ if !resolved.KeepWorktrees {
+ if err := removeWorktree(ctx, resolved.Root, baseWorktree); err != nil {
+ return err
+ }
+ if err := removeWorktree(ctx, resolved.Root, targetWorktree); err != nil {
+ return err
+ }
+ }
+
+ summary.FinishedAt = time.Now()
+ summary.ReportPath = filepath.Join(resolved.OutDirAbs, "report.md")
+ if err := writeRunReport(summary.ReportPath, summary); err != nil {
+ return err
+ }
+
+ fmt.Fprintf(os.Stderr, "wrote benchmark diff report: %s\n", summary.ReportPath)
+
+ if regressions := summary.regressions(); len(regressions) > 0 && resolved.Threshold > 0 {
+ return fmt.Errorf("%d benchmark regressions exceeded %.2f%%; see %s", len(regressions), resolved.Threshold, summary.ReportPath)
+ }
+
+ return nil
+}
+
+func resolveConfig(ctx context.Context, cfg config) (resolvedConfig, error) {
+ root, err := gitOutput(ctx, "", "rev-parse", "--show-toplevel")
+ if err != nil {
+ return resolvedConfig{}, err
+ }
+
+ baseSHA, err := resolveCommit(ctx, root, cfg.BaseRef)
+ if err != nil {
+ return resolvedConfig{}, err
+ }
+ targetSHA, err := resolveCommit(ctx, root, cfg.TargetRef)
+ if err != nil {
+ return resolvedConfig{}, err
+ }
+
+ baseShortSHA, err := shortCommit(ctx, root, baseSHA)
+ if err != nil {
+ return resolvedConfig{}, err
+ }
+ targetShortSHA, err := shortCommit(ctx, root, targetSHA)
+ if err != nil {
+ return resolvedConfig{}, err
+ }
+
+ threshold, err := parseRegressionThreshold(cfg.FailRegression)
+ if err != nil {
+ return resolvedConfig{}, err
+ }
+
+ if cfg.runsIntegrationBenchmarks() && cfg.Connection == "" {
+ cfg.Connection = os.Getenv("CONNECTION_STRING")
+ }
+ if cfg.runsIntegrationBenchmarks() && cfg.Connection == "" {
+ return resolvedConfig{}, fmt.Errorf("integration benchmarks require -connection or CONNECTION_STRING")
+ }
+
+ datasetDirAbs, err := resolvePath(root, cfg.DatasetDir)
+ if err != nil {
+ return resolvedConfig{}, err
+ }
+
+ outDir := cfg.OutDir
+ if outDir == "" {
+ outDir = filepath.Join(".bench", "runs", fmt.Sprintf("%s..%s-%s", baseShortSHA, targetShortSHA, time.Now().UTC().Format("20060102T150405Z")))
+ }
+ outDirAbs, err := resolvePath(root, outDir)
+ if err != nil {
+ return resolvedConfig{}, err
+ }
+
+ return resolvedConfig{
+ config: cfg,
+ Root: root,
+ BaseSHA: baseSHA,
+ TargetSHA: targetSHA,
+ BaseShortSHA: baseShortSHA,
+ TargetShortSHA: targetShortSHA,
+ DatasetDirAbs: datasetDirAbs,
+ OutDirAbs: outDirAbs,
+ Threshold: threshold,
+ }, nil
+}
+
+func resolvePath(root, value string) (string, error) {
+ if filepath.IsAbs(value) {
+ return filepath.Clean(value), nil
+ }
+
+ return filepath.Abs(filepath.Join(root, value))
+}
+
+func resolveCommit(ctx context.Context, root, ref string) (string, error) {
+ sha, err := gitOutput(ctx, root, "rev-parse", "--verify", ref+"^{commit}")
+ if err != nil {
+ return "", fmt.Errorf("resolve git ref %q: %w", ref, err)
+ }
+
+ return sha, nil
+}
+
+func shortCommit(ctx context.Context, root, sha string) (string, error) {
+ shortSHA, err := gitOutput(ctx, root, "rev-parse", "--short", sha)
+ if err != nil {
+ return "", err
+ }
+
+ return shortSHA, nil
+}
+
+func addWorktree(ctx context.Context, root, path, sha string) error {
+ if _, err := os.Stat(path); err == nil {
+ return fmt.Errorf("worktree path already exists: %s", path)
+ } else if !os.IsNotExist(err) {
+ return err
+ }
+
+ if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
+ return err
+ }
+
+ _, err := runCommand(ctx, root, nil, "git", "worktree", "add", "--detach", path, sha)
+ if err != nil {
+ return fmt.Errorf("add worktree %s: %w", path, err)
+ }
+
+ return nil
+}
+
+func removeWorktree(ctx context.Context, root, path string) error {
+ _, err := runCommand(ctx, root, nil, "git", "worktree", "remove", "--force", path)
+ if err != nil && !strings.Contains(err.Error(), "is not a working tree") {
+ return fmt.Errorf("remove worktree %s: %w", path, err)
+ }
+
+ return nil
+}
+
+func (summary runSummary) regressions() []regression {
+ var regressions []regression
+
+ for _, comparison := range summary.Comparisons {
+ regressions = append(regressions, comparison.Regressions...)
+ }
+
+ return regressions
+}
diff --git a/cmd/benchmark/README.md b/cmd/benchmark/README.md
index 741118f5..c05d4c7a 100644
--- a/cmd/benchmark/README.md
+++ b/cmd/benchmark/README.md
@@ -1,13 +1,16 @@
# Benchmark
-Runs query scenarios against a real database and outputs a markdown timing table.
+Runs query scenarios against a real database and outputs markdown, JSON, or benchfmt timing data.
## Usage
```bash
-# Default dataset (base)
+# Default datasets (base and traversal_shapes)
go run ./cmd/benchmark -connection "postgresql://dawgs:dawgs@localhost:5432/dawgs"
+# Traversal shape dataset only
+go run ./cmd/benchmark -connection "..." -dataset traversal_shapes
+
# Local dataset (not committed to repo)
go run ./cmd/benchmark -connection "..." -dataset local/phantom
@@ -20,6 +23,9 @@ go run ./cmd/benchmark -driver neo4j -connection "neo4j://neo4j:password@localho
# Save to file
go run ./cmd/benchmark -connection "..." -output report.md
+# Emit benchfmt for benchstat
+go run ./cmd/benchmark -connection "..." -format benchfmt -output report.bench
+
# Save markdown and JSON for quality baseline comparison
go run ./cmd/benchmark -connection "..." -output report.md -json-output report.json
```
@@ -34,9 +40,17 @@ go run ./cmd/benchmark -connection "..." -output report.md -json-output report.j
| `-dataset` | | Run only this dataset |
| `-local-dataset` | | Add a local dataset to the default set |
| `-dataset-dir` | `integration/testdata` | Path to testdata directory |
-| `-output` | stdout | Markdown output file |
+| `-format` | `markdown` | Output format (`markdown`, `json`, `benchfmt`) |
+| `-output` | stdout | Output file |
| `-json-output` | | JSON output file for baseline comparison |
+Use `-format benchfmt` when comparing scenario timings with `benchstat`. Each timed scenario iteration is emitted as a
+separate `ns/op` sample so two benchmark runs can be compared directly.
+
+The committed default datasets are `base` and `traversal_shapes`. `traversal_shapes` covers chain, fanout, bounded
+cycle, disconnected, edge-kind-selective, and multi-path shortest-path traversal shapes. Scenarios with declared
+expected row counts fail before reporting timings if a query returns the wrong result shape.
+
## Example: Neo4j on local/phantom
```
diff --git a/cmd/benchmark/main.go b/cmd/benchmark/main.go
index bb857d5b..dbb0e64b 100644
--- a/cmd/benchmark/main.go
+++ b/cmd/benchmark/main.go
@@ -40,7 +40,8 @@ func main() {
driver = flag.String("driver", "pg", "database driver (pg, neo4j)")
connStr = flag.String("connection", "", "database connection string (or CONNECTION_STRING)")
iterations = flag.Int("iterations", 10, "timed iterations per scenario")
- output = flag.String("output", "", "markdown output file (default: stdout)")
+ output = flag.String("output", "", "output file (default: stdout)")
+ format = flag.String("format", reportFormatMarkdown, "output format (markdown, json, benchfmt)")
jsonOutput = flag.String("json-output", "", "JSON output file for baseline comparison")
datasetDir = flag.String("dataset-dir", "integration/testdata", "path to testdata directory")
localDataset = flag.String("local-dataset", "", "additional local dataset (e.g. local/phantom)")
@@ -50,6 +51,13 @@ func main() {
flag.Parse()
+ if *iterations < 1 {
+ fatal("iterations must be at least 1")
+ }
+ if !isReportFormat(*format) {
+ fatal("unsupported output format %q", *format)
+ }
+
conn := *connStr
if conn == "" {
conn = os.Getenv("CONNECTION_STRING")
@@ -154,21 +162,21 @@ func main() {
}
}
- // Write markdown
- var mdOut *os.File
+ // Write primary report.
+ var out *os.File
if *output != "" {
var err error
- mdOut, err = os.Create(*output)
+ out, err = os.Create(*output)
if err != nil {
fatal("failed to create output: %v", err)
}
- defer mdOut.Close()
+ defer out.Close()
} else {
- mdOut = os.Stdout
+ out = os.Stdout
}
- if err := writeMarkdown(mdOut, report); err != nil {
- fatal("failed to write markdown: %v", err)
+ if err := writeReport(out, report, *format); err != nil {
+ fatal("failed to write report: %v", err)
}
if *output != "" {
diff --git a/cmd/benchmark/report.go b/cmd/benchmark/report.go
index dacab7dd..841baaa1 100644
--- a/cmd/benchmark/report.go
+++ b/cmd/benchmark/report.go
@@ -20,7 +20,16 @@ import (
"encoding/json"
"fmt"
"io"
+ "runtime"
+ "strings"
"time"
+ "unicode"
+)
+
+const (
+ reportFormatBenchfmt = "benchfmt"
+ reportFormatJSON = "json"
+ reportFormatMarkdown = "markdown"
)
// Report holds all benchmark results and metadata.
@@ -32,6 +41,30 @@ type Report struct {
Results []Result `json:"results"`
}
+func writeReport(w io.Writer, r Report, format string) error {
+ if !isReportFormat(format) {
+ return fmt.Errorf("unsupported output format %q", format)
+ }
+
+ switch format {
+ case reportFormatBenchfmt:
+ return writeBenchfmt(w, r)
+ case reportFormatJSON:
+ return writeJSON(w, r)
+ default:
+ return writeMarkdown(w, r)
+ }
+}
+
+func isReportFormat(format string) bool {
+ switch format {
+ case reportFormatBenchfmt, reportFormatJSON, reportFormatMarkdown:
+ return true
+ default:
+ return false
+ }
+}
+
func writeJSON(w io.Writer, r Report) error {
encoder := json.NewEncoder(w)
encoder.SetIndent("", " ")
@@ -62,6 +95,77 @@ func writeMarkdown(w io.Writer, r Report) error {
return nil
}
+func writeBenchfmt(w io.Writer, r Report) error {
+ goos := runtime.GOOS
+ goarch := runtime.GOARCH
+ procs := runtime.GOMAXPROCS(0)
+
+ fmt.Fprintf(w, "goos: %s\n", goos)
+ fmt.Fprintf(w, "goarch: %s\n", goarch)
+ fmt.Fprintf(w, "pkg: github.com/specterops/dawgs/cmd/benchmark\n")
+
+ for _, res := range r.Results {
+ benchName := benchName(r.Driver, res)
+
+ for _, sample := range res.Samples {
+ fmt.Fprintf(w, "%s-%d\t1\t%d ns/op\n", benchName, procs, sample.Nanoseconds())
+ }
+ }
+
+ return nil
+}
+
+func benchName(driver string, res Result) string {
+ parts := []string{
+ "BenchmarkDawgsIntegration",
+ sanitizeBenchNamePart(driver),
+ sanitizeBenchNamePart(res.Dataset),
+ sanitizeBenchNamePart(res.Section),
+ sanitizeBenchNamePart(res.Label),
+ }
+
+ return strings.Join(parts, "/")
+}
+
+func sanitizeBenchNamePart(value string) string {
+ var builder strings.Builder
+ lastUnderscore := false
+
+ for _, char := range value {
+ switch {
+ case char == '/' || char == '-' || char == '_':
+ if char == '_' {
+ if !lastUnderscore {
+ builder.WriteRune(char)
+ }
+ lastUnderscore = true
+ } else {
+ builder.WriteRune(char)
+ lastUnderscore = false
+ }
+ case unicode.IsLetter(char) || unicode.IsDigit(char):
+ builder.WriteRune(char)
+ lastUnderscore = false
+ case unicode.IsSpace(char):
+ if !lastUnderscore {
+ builder.WriteByte('_')
+ }
+ lastUnderscore = true
+ default:
+ if !lastUnderscore {
+ builder.WriteByte('_')
+ }
+ lastUnderscore = true
+ }
+ }
+
+ if builder.Len() == 0 {
+ return "unknown"
+ }
+
+ return builder.String()
+}
+
func fmtDuration(d time.Duration) string {
ms := float64(d.Microseconds()) / 1000.0
if ms < 1 {
diff --git a/cmd/benchmark/report_test.go b/cmd/benchmark/report_test.go
index 2d72ed4d..c4bea53f 100644
--- a/cmd/benchmark/report_test.go
+++ b/cmd/benchmark/report_test.go
@@ -21,8 +21,82 @@ import (
"strings"
"testing"
"time"
+
+ "github.com/stretchr/testify/require"
)
+func TestWriteReportRejectsUnknownFormat(t *testing.T) {
+ err := writeReport(&bytes.Buffer{}, Report{}, "xml")
+ require.ErrorContains(t, err, "unsupported output format")
+}
+
+func TestWriteJSON(t *testing.T) {
+ report := testReport()
+ var out bytes.Buffer
+
+ require.NoError(t, writeReport(&out, report, reportFormatJSON))
+
+ require.Contains(t, out.String(), `"driver": "pg"`)
+ require.Contains(t, out.String(), `"samples": [`)
+ require.Contains(t, out.String(), `1000000`)
+}
+
+func TestWriteBenchfmt(t *testing.T) {
+ report := testReport()
+ var out bytes.Buffer
+
+ require.NoError(t, writeReport(&out, report, reportFormatBenchfmt))
+
+ output := out.String()
+ require.Contains(t, output, "goos: ")
+ require.Contains(t, output, "goarch: ")
+ require.Contains(t, output, "pkg: github.com/specterops/dawgs/cmd/benchmark")
+ require.Contains(t, output, "BenchmarkDawgsIntegration/pg/base/Match_Nodes/base-")
+ require.Contains(t, output, "\t1\t1000000 ns/op")
+ require.Contains(t, output, "\t1\t2000000 ns/op")
+}
+
+func TestSanitizeBenchNamePart(t *testing.T) {
+ require.Equal(t, "Shortest_Paths", sanitizeBenchNamePart("Shortest Paths"))
+ require.Equal(t, "n1_-_n3", sanitizeBenchNamePart("n1 -> n3"))
+ require.Equal(t, "local/phantom", sanitizeBenchNamePart("local/phantom"))
+ require.Equal(t, "unknown", sanitizeBenchNamePart(""))
+}
+
+func TestWriteMarkdownOmitsSamples(t *testing.T) {
+ report := testReport()
+ var out bytes.Buffer
+
+ require.NoError(t, writeReport(&out, report, reportFormatMarkdown))
+
+ output := out.String()
+ require.Contains(t, output, "| Match Nodes | base | 2.0ms | 2.0ms | 2.0ms |")
+ require.False(t, strings.Contains(output, "1000000"))
+}
+
+func testReport() Report {
+ return Report{
+ Driver: "pg",
+ GitRef: "abcdef0",
+ Date: "2026-05-11",
+ Iterations: 2,
+ Results: []Result{{
+ Section: "Match Nodes",
+ Dataset: "base",
+ Label: "base",
+ Stats: Stats{
+ Median: 2 * time.Millisecond,
+ P95: 2 * time.Millisecond,
+ Max: 2 * time.Millisecond,
+ },
+ Samples: []time.Duration{
+ time.Millisecond,
+ 2 * time.Millisecond,
+ },
+ }},
+ }
+}
+
func TestWriteJSONEmitsBaselineFriendlyReport(t *testing.T) {
report := Report{
Driver: "pg",
diff --git a/cmd/benchmark/runner.go b/cmd/benchmark/runner.go
index b146f11d..d74eb698 100644
--- a/cmd/benchmark/runner.go
+++ b/cmd/benchmark/runner.go
@@ -18,6 +18,7 @@ package main
import (
"context"
+ "fmt"
"sort"
"time"
@@ -33,16 +34,17 @@ type Stats struct {
// Result is one row in the report.
type Result struct {
- Section string `json:"section"`
- Dataset string `json:"dataset"`
- Label string `json:"label"`
- Stats Stats `json:"stats"`
+ Section string `json:"section"`
+ Dataset string `json:"dataset"`
+ Label string `json:"label"`
+ Stats Stats `json:"stats"`
+ Samples []time.Duration `json:"samples,omitempty"`
}
// runScenario executes a scenario N times and returns timing stats.
func runScenario(ctx context.Context, db graph.Database, s Scenario, iterations int) (Result, error) {
// Warm-up: one untimed run.
- if err := db.ReadTransaction(ctx, s.Query); err != nil {
+ if err := runScenarioOnce(ctx, db, s); err != nil {
return Result{}, err
}
@@ -50,7 +52,7 @@ func runScenario(ctx context.Context, db graph.Database, s Scenario, iterations
for i := range iterations {
start := time.Now()
- if err := db.ReadTransaction(ctx, s.Query); err != nil {
+ if err := runScenarioOnce(ctx, db, s); err != nil {
return Result{}, err
}
durations[i] = time.Since(start)
@@ -61,9 +63,29 @@ func runScenario(ctx context.Context, db graph.Database, s Scenario, iterations
Dataset: s.Dataset,
Label: s.Label,
Stats: computeStats(durations),
+ Samples: durations,
}, nil
}
+func runScenarioOnce(ctx context.Context, db graph.Database, s Scenario) error {
+ return db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+ rows, err := s.Query(tx)
+ if err != nil {
+ return err
+ }
+
+ return validateScenarioRows(s, rows)
+ })
+}
+
+func validateScenarioRows(s Scenario, actualRows int) error {
+ if s.ExpectedRows == nil || *s.ExpectedRows == actualRows {
+ return nil
+ }
+
+ return fmt.Errorf("%s/%s on %s expected %d rows, got %d", s.Section, s.Label, s.Dataset, *s.ExpectedRows, actualRows)
+}
+
func computeStats(durations []time.Duration) Stats {
sort.Slice(durations, func(i, j int) bool { return durations[i] < durations[j] })
diff --git a/cmd/benchmark/scenarios.go b/cmd/benchmark/scenarios.go
index 217ae63d..c445d966 100644
--- a/cmd/benchmark/scenarios.go
+++ b/cmd/benchmark/scenarios.go
@@ -25,20 +25,25 @@ import (
// Scenario defines a single benchmark query to run against a loaded dataset.
type Scenario struct {
- Section string // grouping key in the report (e.g. "Match Nodes")
- Dataset string
- Label string // human-readable row label
- Query func(tx graph.Transaction) error
+ Section string // grouping key in the report (e.g. "Match Nodes")
+ Dataset string
+ Label string // human-readable row label
+ ExpectedRows *int
+ Query func(tx graph.Transaction) (int, error)
}
+const traversalShapesDataset = "traversal_shapes"
+
// defaultDatasets is the set of datasets committed to the repo.
-var defaultDatasets = []string{"base"}
+var defaultDatasets = []string{"base", traversalShapesDataset}
// scenariosForDataset returns all benchmark scenarios for a given dataset and its loaded ID map.
func scenariosForDataset(dataset string, idMap opengraph.IDMap) []Scenario {
switch dataset {
case "base":
return baseScenarios(idMap)
+ case traversalShapesDataset:
+ return traversalShapesScenarios(idMap)
case "local/phantom":
return phantomScenarios(idMap)
default:
@@ -46,23 +51,31 @@ func scenariosForDataset(dataset string, idMap opengraph.IDMap) []Scenario {
}
}
-func countNodes(tx graph.Transaction) error {
- _, err := tx.Nodes().Count()
- return err
+func expectRows(rows int) *int {
+ return &rows
+}
+
+func countNodes(tx graph.Transaction) (int, error) {
+ count, err := tx.Nodes().Count()
+ return int(count), err
}
-func countEdges(tx graph.Transaction) error {
- _, err := tx.Relationships().Count()
- return err
+func countEdges(tx graph.Transaction) (int, error) {
+ count, err := tx.Relationships().Count()
+ return int(count), err
}
-func cypherQuery(cypher string) func(tx graph.Transaction) error {
- return func(tx graph.Transaction) error {
+func cypherQuery(cypher string) func(tx graph.Transaction) (int, error) {
+ return func(tx graph.Transaction) (int, error) {
result := tx.Query(cypher, nil)
defer result.Close()
+
+ rows := 0
for result.Next() {
+ rows++
}
- return result.Error()
+
+ return rows, result.Error()
}
}
@@ -71,22 +84,80 @@ func cypherQuery(cypher string) func(tx graph.Transaction) error {
func baseScenarios(idMap opengraph.IDMap) []Scenario {
ds := "base"
return []Scenario{
- {Section: "Match Nodes", Dataset: ds, Label: ds, Query: countNodes},
- {Section: "Match Edges", Dataset: ds, Label: ds, Query: countEdges},
- {Section: "Shortest Paths", Dataset: ds, Label: "n1 -> n3", Query: cypherQuery(fmt.Sprintf(
+ {Section: "Match Nodes", Dataset: ds, Label: ds, ExpectedRows: expectRows(3), Query: countNodes},
+ {Section: "Match Edges", Dataset: ds, Label: ds, ExpectedRows: expectRows(2), Query: countEdges},
+ {Section: "Shortest Paths", Dataset: ds, Label: "n1 -> n3", ExpectedRows: expectRows(1), Query: cypherQuery(fmt.Sprintf(
"MATCH p = allShortestPaths((s)-[*1..]->(e)) WHERE id(s) = %d AND id(e) = %d RETURN p",
idMap["n1"], idMap["n3"],
))},
- {Section: "Traversal", Dataset: ds, Label: "n1", Query: cypherQuery(fmt.Sprintf(
+ {Section: "Traversal", Dataset: ds, Label: "n1", ExpectedRows: expectRows(2), Query: cypherQuery(fmt.Sprintf(
"MATCH (s)-[*1..]->(e) WHERE id(s) = %d RETURN e",
idMap["n1"],
))},
- {Section: "Match Return", Dataset: ds, Label: "n1", Query: cypherQuery(fmt.Sprintf(
+ {Section: "Match Return", Dataset: ds, Label: "n1", ExpectedRows: expectRows(1), Query: cypherQuery(fmt.Sprintf(
"MATCH (s)-[]->(e) WHERE id(s) = %d RETURN e",
idMap["n1"],
))},
- {Section: "Filter By Kind", Dataset: ds, Label: "NodeKind1", Query: cypherQuery("MATCH (n:NodeKind1) RETURN n")},
- {Section: "Filter By Kind", Dataset: ds, Label: "NodeKind2", Query: cypherQuery("MATCH (n:NodeKind2) RETURN n")},
+ {Section: "Filter By Kind", Dataset: ds, Label: "NodeKind1", ExpectedRows: expectRows(2), Query: cypherQuery("MATCH (n:NodeKind1) RETURN n")},
+ {Section: "Filter By Kind", Dataset: ds, Label: "NodeKind2", ExpectedRows: expectRows(2), Query: cypherQuery("MATCH (n:NodeKind2) RETURN n")},
+ }
+}
+
+// --- Traversal shape scenarios ---
+
+func traversalShapesScenarios(idMap opengraph.IDMap) []Scenario {
+ ds := traversalShapesDataset
+ return []Scenario{
+ {Section: "Match Nodes", Dataset: ds, Label: ds, ExpectedRows: expectRows(45), Query: countNodes},
+ {Section: "Match Edges", Dataset: ds, Label: ds, ExpectedRows: expectRows(41), Query: countEdges},
+ {Section: "Traversal Depth", Dataset: ds, Label: "chain depth 1", ExpectedRows: expectRows(1), Query: cypherQuery(fmt.Sprintf(
+ "MATCH (s)-[:ChainEdge*1..1]->(e) WHERE id(s) = %d RETURN e",
+ idMap["c0"],
+ ))},
+ {Section: "Traversal Depth", Dataset: ds, Label: "chain depth 3", ExpectedRows: expectRows(3), Query: cypherQuery(fmt.Sprintf(
+ "MATCH (s)-[:ChainEdge*1..3]->(e) WHERE id(s) = %d RETURN e",
+ idMap["c0"],
+ ))},
+ {Section: "Traversal Depth", Dataset: ds, Label: "chain depth 10", ExpectedRows: expectRows(10), Query: cypherQuery(fmt.Sprintf(
+ "MATCH (s)-[:ChainEdge*1..10]->(e) WHERE id(s) = %d RETURN e",
+ idMap["c0"],
+ ))},
+ {Section: "Traversal Depth", Dataset: ds, Label: "fanout depth 1", ExpectedRows: expectRows(3), Query: cypherQuery(fmt.Sprintf(
+ "MATCH (s)-[:FanoutEdge*1..1]->(e) WHERE id(s) = %d RETURN e",
+ idMap["f0"],
+ ))},
+ {Section: "Traversal Depth", Dataset: ds, Label: "fanout depth 2", ExpectedRows: expectRows(9), Query: cypherQuery(fmt.Sprintf(
+ "MATCH (s)-[:FanoutEdge*1..2]->(e) WHERE id(s) = %d RETURN e",
+ idMap["f0"],
+ ))},
+ {Section: "Traversal Depth", Dataset: ds, Label: "fanout depth 3", ExpectedRows: expectRows(15), Query: cypherQuery(fmt.Sprintf(
+ "MATCH (s)-[:FanoutEdge*1..3]->(e) WHERE id(s) = %d RETURN e",
+ idMap["f0"],
+ ))},
+ {Section: "Traversal Cycle", Dataset: ds, Label: "bounded cycle", ExpectedRows: expectRows(4), Query: cypherQuery(fmt.Sprintf(
+ "MATCH (s)-[:CycleEdge*1..4]->(e) WHERE id(s) = %d RETURN e",
+ idMap["y0"],
+ ))},
+ {Section: "Traversal Dead End", Dataset: ds, Label: "chain terminal", ExpectedRows: expectRows(0), Query: cypherQuery(fmt.Sprintf(
+ "MATCH (s)-[:ChainEdge*1..]->(e) WHERE id(s) = %d RETURN e",
+ idMap["c10"],
+ ))},
+ {Section: "Edge Kind Traversal", Dataset: ds, Label: "Allowed", ExpectedRows: expectRows(3), Query: cypherQuery(fmt.Sprintf(
+ "MATCH (s)-[:Allowed*1..]->(e) WHERE id(s) = %d RETURN e",
+ idMap["s0"],
+ ))},
+ {Section: "Edge Kind Traversal", Dataset: ds, Label: "all kinds", ExpectedRows: expectRows(6), Query: cypherQuery(fmt.Sprintf(
+ "MATCH (s)-[*1..]->(e) WHERE id(s) = %d RETURN e",
+ idMap["s0"],
+ ))},
+ {Section: "Shortest Paths", Dataset: ds, Label: "diamond many paths", ExpectedRows: expectRows(3), Query: cypherQuery(fmt.Sprintf(
+ "MATCH p = allShortestPaths((s)-[*1..]->(e)) WHERE id(s) = %d AND id(e) = %d RETURN p",
+ idMap["d0"], idMap["d4"],
+ ))},
+ {Section: "Shortest Paths", Dataset: ds, Label: "disconnected", ExpectedRows: expectRows(0), Query: cypherQuery(fmt.Sprintf(
+ "MATCH p = allShortestPaths((s)-[*1..]->(e)) WHERE id(s) = %d AND id(e) = %d RETURN p",
+ idMap["x0"], idMap["x1"],
+ ))},
}
}
diff --git a/cmd/benchmark/scenarios_test.go b/cmd/benchmark/scenarios_test.go
new file mode 100644
index 00000000..647ee8ae
--- /dev/null
+++ b/cmd/benchmark/scenarios_test.go
@@ -0,0 +1,120 @@
+// Copyright 2026 Specter Ops, Inc.
+//
+// Licensed under the Apache License, Version 2.0
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package main
+
+import (
+ "os"
+ "testing"
+
+ "github.com/specterops/dawgs/graph"
+ "github.com/specterops/dawgs/opengraph"
+ "github.com/stretchr/testify/require"
+)
+
+func TestBaseScenariosDeclareExpectedRows(t *testing.T) {
+ scenarios := baseScenarios(opengraph.IDMap{
+ "n1": graph.ID(1),
+ "n2": graph.ID(2),
+ "n3": graph.ID(3),
+ })
+
+ requireExpectedRows(t, scenarios, "Match Nodes", "base", 3)
+ requireExpectedRows(t, scenarios, "Match Edges", "base", 2)
+ requireExpectedRows(t, scenarios, "Shortest Paths", "n1 -> n3", 1)
+ requireExpectedRows(t, scenarios, "Traversal", "n1", 2)
+ requireExpectedRows(t, scenarios, "Match Return", "n1", 1)
+ requireExpectedRows(t, scenarios, "Filter By Kind", "NodeKind1", 2)
+ requireExpectedRows(t, scenarios, "Filter By Kind", "NodeKind2", 2)
+}
+
+func TestTraversalShapesDatasetIsValid(t *testing.T) {
+ file, err := os.Open("../../integration/testdata/traversal_shapes.json")
+ require.NoError(t, err)
+ defer file.Close()
+
+ doc, err := opengraph.ParseDocument(file)
+ require.NoError(t, err)
+ require.Len(t, doc.Graph.Nodes, 45)
+ require.Len(t, doc.Graph.Edges, 41)
+}
+
+func TestTraversalShapesScenariosDeclareExpectedRows(t *testing.T) {
+ scenarios := traversalShapesScenarios(traversalShapesIDMap())
+
+ requireExpectedRows(t, scenarios, "Match Nodes", traversalShapesDataset, 45)
+ requireExpectedRows(t, scenarios, "Match Edges", traversalShapesDataset, 41)
+ requireExpectedRows(t, scenarios, "Traversal Depth", "chain depth 1", 1)
+ requireExpectedRows(t, scenarios, "Traversal Depth", "chain depth 3", 3)
+ requireExpectedRows(t, scenarios, "Traversal Depth", "chain depth 10", 10)
+ requireExpectedRows(t, scenarios, "Traversal Depth", "fanout depth 1", 3)
+ requireExpectedRows(t, scenarios, "Traversal Depth", "fanout depth 2", 9)
+ requireExpectedRows(t, scenarios, "Traversal Depth", "fanout depth 3", 15)
+ requireExpectedRows(t, scenarios, "Traversal Cycle", "bounded cycle", 4)
+ requireExpectedRows(t, scenarios, "Traversal Dead End", "chain terminal", 0)
+ requireExpectedRows(t, scenarios, "Edge Kind Traversal", "Allowed", 3)
+ requireExpectedRows(t, scenarios, "Edge Kind Traversal", "all kinds", 6)
+ requireExpectedRows(t, scenarios, "Shortest Paths", "diamond many paths", 3)
+ requireExpectedRows(t, scenarios, "Shortest Paths", "disconnected", 0)
+}
+
+func TestDefaultDatasetsIncludeTraversalShapes(t *testing.T) {
+ require.Contains(t, defaultDatasets, traversalShapesDataset)
+}
+
+func TestValidateScenarioRows(t *testing.T) {
+ scenario := Scenario{
+ Section: "Traversal",
+ Dataset: "base",
+ Label: "n1",
+ ExpectedRows: expectRows(2),
+ }
+
+ require.NoError(t, validateScenarioRows(scenario, 2))
+ require.ErrorContains(t, validateScenarioRows(scenario, 1), "Traversal/n1 on base expected 2 rows, got 1")
+}
+
+func traversalShapesIDMap() opengraph.IDMap {
+ ids := []string{
+ "c0", "c10",
+ "f0",
+ "d0", "d4",
+ "y0",
+ "x0", "x1",
+ "s0",
+ }
+
+ idMap := opengraph.IDMap{}
+ for idx, id := range ids {
+ idMap[id] = graph.ID(idx + 1)
+ }
+
+ return idMap
+}
+
+func requireExpectedRows(t *testing.T, scenarios []Scenario, section, label string, expectedRows int) {
+ t.Helper()
+
+ for _, scenario := range scenarios {
+ if scenario.Section == section && scenario.Label == label {
+ require.NotNil(t, scenario.ExpectedRows)
+ require.Equal(t, expectedRows, *scenario.ExpectedRows)
+ return
+ }
+ }
+
+ require.Failf(t, "scenario not found", "%s/%s", section, label)
+}
diff --git a/cypher/models/cypher/format/format.go b/cypher/models/cypher/format/format.go
index 495cf806..a05d7b1a 100644
--- a/cypher/models/cypher/format/format.go
+++ b/cypher/models/cypher/format/format.go
@@ -80,14 +80,11 @@ func (s Emitter) formatNodePattern(output io.Writer, nodePattern *cypher.NodePat
func (s Emitter) formatRelationshipPattern(output io.Writer, relationshipPattern *cypher.RelationshipPattern) error {
switch relationshipPattern.Direction {
- case graph.DirectionOutbound:
+ case graph.DirectionOutbound, graph.DirectionBoth:
if _, err := io.WriteString(output, "-["); err != nil {
return err
}
- case graph.DirectionBoth:
- fallthrough
-
case graph.DirectionInbound:
if _, err := io.WriteString(output, "<-["); err != nil {
return err
@@ -147,14 +144,11 @@ func (s Emitter) formatRelationshipPattern(output io.Writer, relationshipPattern
}
switch relationshipPattern.Direction {
- case graph.DirectionInbound:
+ case graph.DirectionInbound, graph.DirectionBoth:
if _, err := io.WriteString(output, "]-"); err != nil {
return err
}
- case graph.DirectionBoth:
- fallthrough
-
case graph.DirectionOutbound:
if _, err := io.WriteString(output, "]->"); err != nil {
return err
@@ -296,7 +290,7 @@ func (s Emitter) formatProjection(output io.Writer, projection *cypher.Projectio
}
func (s Emitter) formatReturn(output io.Writer, returnClause *cypher.Return) error {
- if _, err := io.WriteString(output, " return "); err != nil {
+ if _, err := io.WriteString(output, "return "); err != nil {
return err
}
@@ -1095,6 +1089,12 @@ func (s Emitter) formatSinglePartQuery(writer io.Writer, singlePartQuery *cypher
}
if singlePartQuery.Return != nil {
+ if len(singlePartQuery.ReadingClauses) > 0 || len(singlePartQuery.UpdatingClauses) > 0 {
+ if _, err := io.WriteString(writer, " "); err != nil {
+ return err
+ }
+ }
+
return s.formatReturn(writer, singlePartQuery.Return)
}
diff --git a/cypher/models/cypher/format/format_test.go b/cypher/models/cypher/format/format_test.go
index 327f65d4..8a02d2e9 100644
--- a/cypher/models/cypher/format/format_test.go
+++ b/cypher/models/cypher/format/format_test.go
@@ -6,6 +6,7 @@ import (
"github.com/specterops/dawgs/cypher/models/cypher"
"github.com/specterops/dawgs/cypher/models/cypher/format"
+ "github.com/specterops/dawgs/graph"
"github.com/specterops/dawgs/cypher/frontend"
"github.com/stretchr/testify/require"
@@ -27,6 +28,55 @@ func TestCypherEmitter_StripLiterals(t *testing.T) {
require.Equal(t, "match (n {value: $STRIPPED}) where n.other = $STRIPPED and n.number = $STRIPPED return n.name, n", buffer.String())
}
+func TestCypherEmitter_RelationshipDirections(t *testing.T) {
+ testCases := []struct {
+ name string
+ direction graph.Direction
+ expected string
+ }{
+ {
+ name: "outbound",
+ direction: graph.DirectionOutbound,
+ expected: "match (a)-[r]->(b) return r",
+ },
+ {
+ name: "inbound",
+ direction: graph.DirectionInbound,
+ expected: "match (a)<-[r]-(b) return r",
+ },
+ {
+ name: "both",
+ direction: graph.DirectionBoth,
+ expected: "match (a)-[r]-(b) return r",
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.name, func(t *testing.T) {
+ regularQuery, singlePartQuery := cypher.NewRegularQueryWithSingleQuery()
+ match := singlePartQuery.NewReadingClause().NewMatch(false)
+ match.NewPatternPart().AddPatternElements(
+ &cypher.NodePattern{
+ Variable: cypher.NewVariableWithSymbol("a"),
+ },
+ &cypher.RelationshipPattern{
+ Variable: cypher.NewVariableWithSymbol("r"),
+ Direction: testCase.direction,
+ },
+ &cypher.NodePattern{
+ Variable: cypher.NewVariableWithSymbol("b"),
+ },
+ )
+
+ singlePartQuery.NewProjection(false).AddItem(cypher.NewProjectionItemWithExpr(cypher.NewVariableWithSymbol("r")))
+
+ rendered, err := format.RegularQuery(regularQuery, false)
+ require.NoError(t, err)
+ require.Equal(t, testCase.expected, rendered)
+ })
+ }
+}
+
func TestCypherEmitter_HappyPath(t *testing.T) {
test.LoadFixture(t, test.MutationTestCases).Run(t)
test.LoadFixture(t, test.PositiveTestCases).Run(t)
diff --git a/cypher/models/pgsql/test/query_test.go b/cypher/models/pgsql/test/query_test.go
index 38ec85ab..ac39ca9f 100644
--- a/cypher/models/pgsql/test/query_test.go
+++ b/cypher/models/pgsql/test/query_test.go
@@ -5,12 +5,11 @@ import (
"slices"
"testing"
- "github.com/specterops/dawgs/cypher/models/cypher"
"github.com/specterops/dawgs/cypher/models/pgsql"
"github.com/specterops/dawgs/cypher/models/pgsql/translate"
"github.com/specterops/dawgs/cypher/models/walk"
"github.com/specterops/dawgs/graph"
- "github.com/specterops/dawgs/query"
+ v2 "github.com/specterops/dawgs/query/v2"
)
var (
@@ -24,21 +23,20 @@ var (
func TestQuery_KindGeneratesInclusiveKindMatcher(t *testing.T) {
mapper := newKindMapper()
- queries := []*cypher.Where{
- query.Where(query.KindIn(query.Node(), NodeKind1)),
- query.Where(query.Kind(query.Node(), NodeKind2)),
+ queries := []v2.QueryBuilder{
+ v2.New().Where(v2.KindIn(v2.Node(), NodeKind1)).Return(v2.Node()),
+ v2.New().Where(v2.Kind(v2.Node(), NodeKind2)).Return(v2.Node()),
}
- for _, nodeQuery := range queries {
- builder := query.NewBuilderWithCriteria(nodeQuery)
- builtQuery, err := builder.Build(false)
+ for _, queryBuilder := range queries {
+ builtQuery, err := queryBuilder.Build()
if err != nil {
- t.Errorf("could not build query: %v", err)
+ t.Fatalf("could not build query: %v", err)
}
- translatedQuery, err := translate.Translate(context.Background(), builtQuery, mapper, nil, translate.DefaultGraphID)
+ translatedQuery, err := translate.Translate(context.Background(), builtQuery.Query, mapper, builtQuery.Parameters, translate.DefaultGraphID)
if err != nil {
- t.Errorf("could not translate query: %#v: %v", builtQuery, err)
+ t.Fatalf("could not translate query: %#v: %v", builtQuery, err)
}
walk.PgSQL(translatedQuery.Statement, walk.NewSimpleVisitor(func(node pgsql.SyntaxNode, visitorHandler walk.VisitorHandler) {
@@ -47,7 +45,7 @@ func TestQuery_KindGeneratesInclusiveKindMatcher(t *testing.T) {
switch leftTyped := typedNode.LOperand.(type) {
case pgsql.CompoundIdentifier:
if slices.Equal(leftTyped, pgsql.AsCompoundIdentifier("n0", "kind_ids")) && typedNode.Operator != pgsql.OperatorPGArrayOverlap {
- t.Errorf("query did not generate an array overlap operator (&&): %#v", nodeQuery)
+ t.Errorf("query did not generate an array overlap operator (&&): %#v", builtQuery)
}
}
}
diff --git a/cypher/test/cases/mutation_tests.json b/cypher/test/cases/mutation_tests.json
index 7893439b..dc73b031 100644
--- a/cypher/test/cases/mutation_tests.json
+++ b/cypher/test/cases/mutation_tests.json
@@ -4,7 +4,7 @@
"name": "Multipart query with mutation",
"type": "string_match",
"details": {
- "query": "match (s:Ship {name: 'Nebuchadnezzar'}) with s as ship merge p = (c:Crew {name: 'Neo'})\u003c-[:CrewOf]-\u003e(ship) set c.title = 'The One' return p",
+ "query": "match (s:Ship {name: 'Nebuchadnezzar'}) with s as ship merge p = (c:Crew {name: 'Neo'})-[:CrewOf]-(ship) set c.title = 'The One' return p",
"fitness": 7
}
},
diff --git a/cypher/test/cases/positive_tests.json b/cypher/test/cases/positive_tests.json
index b3941c04..cedfccdc 100644
--- a/cypher/test/cases/positive_tests.json
+++ b/cypher/test/cases/positive_tests.json
@@ -189,7 +189,7 @@
"name": "Specify bi-directional relationship",
"type": "string_match",
"details": {
- "query": "match (p:Person)\u003c-[]-\u003e(m:Movie) return m",
+ "query": "match (p:Person)-[]-(m:Movie) return m",
"fitness": 0
}
},
@@ -437,7 +437,7 @@
"name": "built-in shortestPaths()",
"type": "string_match",
"details": {
- "query": "match p = shortestPath((p1:Person)\u003c-[*]-\u003e(p2:Person)) where p1.name = 'tom' and p2.name = 'jerry' return p",
+ "query": "match p = shortestPath((p1:Person)-[*]-(p2:Person)) where p1.name = 'tom' and p2.name = 'jerry' return p",
"fitness": 17
}
},
@@ -453,7 +453,7 @@
"name": "Find nodes with relationships",
"type": "string_match",
"details": {
- "query": "match (b) where (b)\u003c-[]-\u003e() return b",
+ "query": "match (b) where (b)-[]-() return b",
"fitness": -4
}
},
@@ -461,7 +461,7 @@
"name": "Find nodes with no relationships",
"type": "string_match",
"details": {
- "query": "match (b) where not ((b)\u003c-[]-\u003e()) return b",
+ "query": "match (b) where not ((b)-[]-()) return b",
"fitness": -5
}
},
@@ -898,4 +898,4 @@
}
}
]
-}
\ No newline at end of file
+}
diff --git a/integration/testdata/traversal_shapes.json b/integration/testdata/traversal_shapes.json
new file mode 100644
index 00000000..d1041096
--- /dev/null
+++ b/integration/testdata/traversal_shapes.json
@@ -0,0 +1,94 @@
+{
+ "graph": {
+ "nodes": [
+ {"id": "c0", "kinds": ["TraversalNode", "ChainNode"]},
+ {"id": "c1", "kinds": ["TraversalNode", "ChainNode"]},
+ {"id": "c2", "kinds": ["TraversalNode", "ChainNode"]},
+ {"id": "c3", "kinds": ["TraversalNode", "ChainNode"]},
+ {"id": "c4", "kinds": ["TraversalNode", "ChainNode"]},
+ {"id": "c5", "kinds": ["TraversalNode", "ChainNode"]},
+ {"id": "c6", "kinds": ["TraversalNode", "ChainNode"]},
+ {"id": "c7", "kinds": ["TraversalNode", "ChainNode"]},
+ {"id": "c8", "kinds": ["TraversalNode", "ChainNode"]},
+ {"id": "c9", "kinds": ["TraversalNode", "ChainNode"]},
+ {"id": "c10", "kinds": ["TraversalNode", "ChainNode"]},
+ {"id": "f0", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f1", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f2", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f3", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f1a", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f1b", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f2a", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f2b", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f3a", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f3b", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f1a1", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f1b1", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f2a1", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f2b1", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f3a1", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "f3b1", "kinds": ["TraversalNode", "FanoutNode"]},
+ {"id": "d0", "kinds": ["TraversalNode", "DiamondNode"]},
+ {"id": "d1", "kinds": ["TraversalNode", "DiamondNode"]},
+ {"id": "d2", "kinds": ["TraversalNode", "DiamondNode"]},
+ {"id": "d3", "kinds": ["TraversalNode", "DiamondNode"]},
+ {"id": "d4", "kinds": ["TraversalNode", "DiamondNode"]},
+ {"id": "y0", "kinds": ["TraversalNode", "CycleNode"]},
+ {"id": "y1", "kinds": ["TraversalNode", "CycleNode"]},
+ {"id": "y2", "kinds": ["TraversalNode", "CycleNode"]},
+ {"id": "y3", "kinds": ["TraversalNode", "CycleNode"]},
+ {"id": "x0", "kinds": ["TraversalNode", "DisconnectedNode"]},
+ {"id": "x1", "kinds": ["TraversalNode", "DisconnectedNode"]},
+ {"id": "s0", "kinds": ["TraversalNode", "SelectiveNode"]},
+ {"id": "s1", "kinds": ["TraversalNode", "SelectiveNode"]},
+ {"id": "s2", "kinds": ["TraversalNode", "SelectiveNode"]},
+ {"id": "s3", "kinds": ["TraversalNode", "SelectiveNode"]},
+ {"id": "t1", "kinds": ["TraversalNode", "SelectiveNode"]},
+ {"id": "t2", "kinds": ["TraversalNode", "SelectiveNode"]},
+ {"id": "t3", "kinds": ["TraversalNode", "SelectiveNode"]}
+ ],
+ "edges": [
+ {"start_id": "c0", "end_id": "c1", "kind": "ChainEdge"},
+ {"start_id": "c1", "end_id": "c2", "kind": "ChainEdge"},
+ {"start_id": "c2", "end_id": "c3", "kind": "ChainEdge"},
+ {"start_id": "c3", "end_id": "c4", "kind": "ChainEdge"},
+ {"start_id": "c4", "end_id": "c5", "kind": "ChainEdge"},
+ {"start_id": "c5", "end_id": "c6", "kind": "ChainEdge"},
+ {"start_id": "c6", "end_id": "c7", "kind": "ChainEdge"},
+ {"start_id": "c7", "end_id": "c8", "kind": "ChainEdge"},
+ {"start_id": "c8", "end_id": "c9", "kind": "ChainEdge"},
+ {"start_id": "c9", "end_id": "c10", "kind": "ChainEdge"},
+ {"start_id": "f0", "end_id": "f1", "kind": "FanoutEdge"},
+ {"start_id": "f0", "end_id": "f2", "kind": "FanoutEdge"},
+ {"start_id": "f0", "end_id": "f3", "kind": "FanoutEdge"},
+ {"start_id": "f1", "end_id": "f1a", "kind": "FanoutEdge"},
+ {"start_id": "f1", "end_id": "f1b", "kind": "FanoutEdge"},
+ {"start_id": "f2", "end_id": "f2a", "kind": "FanoutEdge"},
+ {"start_id": "f2", "end_id": "f2b", "kind": "FanoutEdge"},
+ {"start_id": "f3", "end_id": "f3a", "kind": "FanoutEdge"},
+ {"start_id": "f3", "end_id": "f3b", "kind": "FanoutEdge"},
+ {"start_id": "f1a", "end_id": "f1a1", "kind": "FanoutEdge"},
+ {"start_id": "f1b", "end_id": "f1b1", "kind": "FanoutEdge"},
+ {"start_id": "f2a", "end_id": "f2a1", "kind": "FanoutEdge"},
+ {"start_id": "f2b", "end_id": "f2b1", "kind": "FanoutEdge"},
+ {"start_id": "f3a", "end_id": "f3a1", "kind": "FanoutEdge"},
+ {"start_id": "f3b", "end_id": "f3b1", "kind": "FanoutEdge"},
+ {"start_id": "d0", "end_id": "d1", "kind": "DiamondEdge"},
+ {"start_id": "d0", "end_id": "d2", "kind": "DiamondEdge"},
+ {"start_id": "d0", "end_id": "d3", "kind": "DiamondEdge"},
+ {"start_id": "d1", "end_id": "d4", "kind": "DiamondEdge"},
+ {"start_id": "d2", "end_id": "d4", "kind": "DiamondEdge"},
+ {"start_id": "d3", "end_id": "d4", "kind": "DiamondEdge"},
+ {"start_id": "y0", "end_id": "y1", "kind": "CycleEdge"},
+ {"start_id": "y1", "end_id": "y2", "kind": "CycleEdge"},
+ {"start_id": "y2", "end_id": "y0", "kind": "CycleEdge"},
+ {"start_id": "y2", "end_id": "y3", "kind": "CycleEdge"},
+ {"start_id": "s0", "end_id": "s1", "kind": "Allowed"},
+ {"start_id": "s1", "end_id": "s2", "kind": "Allowed"},
+ {"start_id": "s2", "end_id": "s3", "kind": "Allowed"},
+ {"start_id": "s0", "end_id": "t1", "kind": "Blocked"},
+ {"start_id": "t1", "end_id": "t2", "kind": "Blocked"},
+ {"start_id": "t2", "end_id": "t3", "kind": "Blocked"}
+ ]
+ }
+}
diff --git a/query/neo4j/neo4j.go b/query/neo4j/neo4j.go
index 0689f74f..7299d004 100644
--- a/query/neo4j/neo4j.go
+++ b/query/neo4j/neo4j.go
@@ -53,6 +53,14 @@ func (s *QueryBuilder) rewriteParameters() error {
return nil
}
+func hasPreparedMatchPattern(readingClause *cypher.ReadingClause) bool {
+ if readingClause == nil || readingClause.Match == nil {
+ return false
+ }
+
+ return len(readingClause.Match.Pattern) > 0
+}
+
func (s *QueryBuilder) Apply(criteria graph.Criteria) {
switch typedCriteria := criteria.(type) {
case *cypher.Where:
@@ -201,6 +209,10 @@ func (s *QueryBuilder) prepareMatch() error {
return ErrAmbiguousQueryVariables
}
+ if firstReadingClause := query.GetFirstReadingClause(s.query); hasPreparedMatchPattern(firstReadingClause) {
+ return nil
+ }
+
if singleNodeBound && !creatingSingleNode {
patternPart.AddPatternElements(&cypher.NodePattern{
Variable: cypher.NewVariableWithSymbol(query.NodeSymbol),
diff --git a/query/neo4j/neo4j_test.go b/query/neo4j/neo4j_test.go
index 2efc3435..05117347 100644
--- a/query/neo4j/neo4j_test.go
+++ b/query/neo4j/neo4j_test.go
@@ -422,7 +422,7 @@ func TestQueryBuilder_Render(t *testing.T) {
query.Returning(
query.Node(),
),
- ), "match (n) where (n)<-[]->() return n"))
+ ), "match (n) where (n)-[]-() return n"))
t.Run("Node has Relationships Order by Node Item", assertQueryResult(query.SinglePartQuery(
query.Where(
@@ -436,7 +436,7 @@ func TestQueryBuilder_Render(t *testing.T) {
query.OrderBy(
query.Order(query.NodeProperty("value"), query.Ascending()),
),
- ), "match (n) where (n)<-[]->() return n order by n.value asc"))
+ ), "match (n) where (n)-[]-() return n order by n.value asc"))
t.Run("Node has Relationships Order by Node Item", assertQueryResult(query.SinglePartQuery(
query.Where(
@@ -451,7 +451,7 @@ func TestQueryBuilder_Render(t *testing.T) {
query.Order(query.NodeProperty("value_1"), query.Ascending()),
query.Order(query.NodeProperty("value_2"), query.Descending()),
),
- ), "match (n) where (n)<-[]->() return n order by n.value_1 asc, n.value_2 desc"))
+ ), "match (n) where (n)-[]-() return n order by n.value_1 asc, n.value_2 desc"))
t.Run("Node has Relationships Order by Node Item with Limit and Offset", assertQueryResult(query.SinglePartQuery(
query.Where(
@@ -469,7 +469,7 @@ func TestQueryBuilder_Render(t *testing.T) {
query.Limit(10),
query.Offset(20),
- ), "match (n) where (n)<-[]->() return n order by n.value_1 asc, n.value_2 desc skip 20 limit 10"))
+ ), "match (n) where (n)-[]-() return n order by n.value_1 asc, n.value_2 desc skip 20 limit 10"))
t.Run("Node has no Relationships", assertQueryResult(query.SinglePartQuery(
query.Where(
@@ -479,7 +479,7 @@ func TestQueryBuilder_Render(t *testing.T) {
query.Returning(
query.Node(),
),
- ), "match (n) where not ((n)<-[]->()) return n"))
+ ), "match (n) where not ((n)-[]-()) return n"))
t.Run("Node Datetime Before", assertQueryResult(query.SinglePartQuery(
query.Where(
diff --git a/query/v2/backend_test.go b/query/v2/backend_test.go
new file mode 100644
index 00000000..475e4e45
--- /dev/null
+++ b/query/v2/backend_test.go
@@ -0,0 +1,359 @@
+package v2_test
+
+import (
+ "context"
+ "strings"
+ "testing"
+
+ "github.com/specterops/dawgs/cypher/models/pgsql/translate"
+ "github.com/specterops/dawgs/drivers/pg/pgutil"
+ "github.com/specterops/dawgs/graph"
+ "github.com/specterops/dawgs/query/neo4j"
+ v2 "github.com/specterops/dawgs/query/v2"
+ "github.com/stretchr/testify/require"
+)
+
+func testKindMapper(kinds ...graph.Kind) *pgutil.InMemoryKindMapper {
+ mapper := pgutil.NewInMemoryKindMapper()
+
+ for _, kind := range kinds {
+ mapper.Put(kind)
+ }
+
+ return mapper
+}
+
+func TestBackendParityNeo4jPrepare(t *testing.T) {
+ cases := map[string]struct {
+ builder v2.QueryBuilder
+ expectedCypher string
+ expectedParams map[string]any
+ }{
+ "node read": {
+ builder: v2.New().Where(
+ v2.Node().Kinds().Has(graph.StringKind("User")),
+ v2.Node().Property("name").Contains("admin"),
+ ).Return(
+ v2.Node(),
+ ).OrderBy(
+ v2.Node().Property("name"),
+ ),
+ expectedCypher: "match (n) where n:User and n.name contains $p0 return n order by n.name asc",
+ expectedParams: map[string]any{"p0": "admin"},
+ },
+ "relationship read": {
+ builder: v2.New().Where(
+ v2.Relationship().Kind().Is(graph.StringKind("MemberOf")),
+ v2.Start().ID().Equals(1),
+ ).Return(
+ v2.Start().ID(),
+ v2.Relationship().ID(),
+ v2.End().ID(),
+ ),
+ expectedCypher: "match (s)-[r:MemberOf]->(e) where id(s) = $p0 return id(s), id(r), id(e)",
+ expectedParams: map[string]any{"p0": 1},
+ },
+ "shortest path": {
+ builder: v2.New().WithShortestPaths().Where(
+ v2.Relationship().Kind().Is(graph.StringKind("MemberOf")),
+ v2.Start().ID().Equals(1),
+ v2.End().ID().Equals(2),
+ ).Return(
+ v2.Path(),
+ ),
+ expectedCypher: "match p = shortestPath((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p",
+ expectedParams: map[string]any{"p0": 1, "p1": 2},
+ },
+ "all shortest paths": {
+ builder: v2.New().WithAllShortestPaths().Where(
+ v2.Relationship().Kind().Is(graph.StringKind("MemberOf")),
+ v2.Start().ID().Equals(1),
+ v2.End().ID().Equals(2),
+ ).Return(
+ v2.Path(),
+ ),
+ expectedCypher: "match p = allShortestPaths((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p",
+ expectedParams: map[string]any{"p0": 1, "p1": 2},
+ },
+ "recursive traversal": {
+ builder: v2.New().WithTraversalDepth(v2.DepthRange(1, 2)).Where(
+ v2.Relationship().Kind().Is(graph.StringKind("MemberOf")),
+ v2.Start().ID().Equals(1),
+ ).Return(
+ v2.Path(),
+ v2.End().ID(),
+ ),
+ expectedCypher: "match p = (s)-[r:MemberOf*1..2]->(e) where id(s) = $p0 return p, id(e)",
+ expectedParams: map[string]any{"p0": 1},
+ },
+ "create node": {
+ builder: v2.New().Create(
+ v2.NodePattern(graph.Kinds{graph.StringKind("User")}, v2.NamedParameter("props", map[string]any{"name": "u"})),
+ ).Return(
+ v2.Node().ID(),
+ ),
+ expectedCypher: "create (n:User $p0) return id(n)",
+ expectedParams: map[string]any{"p0": map[string]any{"name": "u"}},
+ },
+ "update node": {
+ builder: v2.New().Where(
+ v2.Node().ID().Equals(1),
+ ).Update(
+ v2.SetProperty(v2.Node().Property("name"), "updated"),
+ ),
+ expectedCypher: "match (n) where id(n) = $p0 set n.name = $p1",
+ expectedParams: map[string]any{"p0": 1, "p1": "updated"},
+ },
+ "delete relationship": {
+ builder: v2.New().Where(
+ v2.Relationship().ID().Equals(1),
+ ).Delete(
+ v2.Relationship(),
+ ),
+ expectedCypher: "match ()-[r]->() where id(r) = $p0 delete r",
+ expectedParams: map[string]any{"p0": 1},
+ },
+ "delete node": {
+ builder: v2.New().Where(
+ v2.Node().ID().Equals(1),
+ ).Delete(
+ v2.Node(),
+ ),
+ expectedCypher: "match (n) where id(n) = $p0 detach delete n",
+ expectedParams: map[string]any{"p0": 1},
+ },
+ }
+
+ for name, testCase := range cases {
+ t.Run(name, func(t *testing.T) {
+ preparedQuery, err := testCase.builder.Build()
+ require.NoError(t, err)
+
+ queryBuilder := neo4j.NewQueryBuilder(preparedQuery.Query)
+ require.NoError(t, queryBuilder.Prepare())
+
+ rendered, err := queryBuilder.Render()
+ require.NoError(t, err)
+ require.Equal(t, testCase.expectedCypher, rendered)
+ require.Equal(t, testCase.expectedParams, queryBuilder.Parameters)
+ })
+ }
+}
+
+func TestBackendParityPGTranslateTraversalDepth(t *testing.T) {
+ edgeKind := graph.StringKind("MemberOf")
+ mapper := testKindMapper(edgeKind)
+
+ cases := map[string]struct {
+ builder v2.QueryBuilder
+ expectedSQLContains []string
+ }{
+ "path": {
+ builder: v2.New().WithTraversalDepth(v2.DepthRange(1, 2)).Where(
+ v2.Relationship().Kind().Is(edgeKind),
+ v2.Start().ID().Equals(1),
+ ).Return(
+ v2.Path(),
+ ),
+ expectedSQLContains: []string{
+ "with recursive",
+ "ordered_edges_to_path",
+ "n0.id = @pi0::int8",
+ "e0.kind_id = any (array [1]::int2[])",
+ "depth < 2",
+ },
+ },
+ "endpoints": {
+ builder: v2.New().WithTraversalDepth(v2.DepthRange(1, 2)).Where(
+ v2.Relationship().Kind().Is(edgeKind),
+ v2.Start().ID().Equals(1),
+ ).Return(
+ v2.Start().ID(),
+ v2.End().ID(),
+ ),
+ expectedSQLContains: []string{
+ "with recursive",
+ "n0.id = @pi0::int8",
+ "e0.kind_id = any (array [1]::int2[])",
+ "depth < 2",
+ "select (s0.n0).id, (s0.n1).id from s0",
+ },
+ },
+ }
+
+ for name, testCase := range cases {
+ t.Run(name, func(t *testing.T) {
+ preparedQuery, err := testCase.builder.Build()
+ require.NoError(t, err)
+
+ translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters, translate.DefaultGraphID)
+ require.NoError(t, err)
+
+ sql, err := translate.Translated(translation)
+ require.NoError(t, err)
+
+ for _, expected := range testCase.expectedSQLContains {
+ require.Contains(t, sql, expected)
+ }
+ })
+ }
+}
+
+func TestBackendParityPGTranslate(t *testing.T) {
+ userKind := graph.StringKind("User")
+ edgeKind := graph.StringKind("MemberOf")
+ mapper := testKindMapper(userKind, edgeKind)
+
+ cases := map[string]struct {
+ builder v2.QueryBuilder
+ expectedSQL string
+ expectedParams map[string]any
+ }{
+ "node read": {
+ builder: v2.New().Where(
+ v2.Node().Kinds().Has(userKind),
+ v2.Node().Property("name").Contains("admin"),
+ ).Return(
+ v2.Node().ID(),
+ v2.Node().Kinds(),
+ ),
+ expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and cypher_contains((n0.properties ->> 'name'), (@pi0::text)::text)::bool)) select (s0.n0).id, (array(select _kind.name from generate_subscripts((s0.n0).kind_ids, 1) as _kind_idx, kind _kind where _kind.id = ((s0.n0).kind_ids)[_kind_idx] order by _kind_idx))::text[] from s0;",
+ expectedParams: map[string]any{"pi0": "admin"},
+ },
+ "relationship read": {
+ builder: v2.New().Where(
+ v2.Relationship().Kind().Is(edgeKind),
+ v2.Start().ID().Equals(1),
+ ).Return(
+ v2.Start().ID(),
+ v2.Relationship().ID(),
+ v2.End().ID(),
+ ),
+ expectedSQL: "with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on (n0.id = @pi0::int8) and n0.id = e0.start_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [2]::int2[])) select (s0.n0).id, (s0.e0).id, (s0.n1).id from s0;",
+ expectedParams: map[string]any{"pi0": 1},
+ },
+ "update node": {
+ builder: v2.New().Where(
+ v2.Node().ID().Equals(1),
+ ).Update(
+ v2.SetProperty(v2.Node().Property("name"), "updated"),
+ ),
+ expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.id = @pi0::int8)), s1 as (update node n1 set properties = n1.properties || jsonb_build_object('name', @pi1::text)::jsonb from s0 where (s0.n0).id = n1.id returning (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n0) select 1;",
+ expectedParams: map[string]any{"pi0": 1, "pi1": "updated"},
+ },
+ "delete relationship": {
+ builder: v2.New().Where(
+ v2.Relationship().ID().Equals(1),
+ ).Delete(
+ v2.Relationship(),
+ ),
+ expectedSQL: "with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0 from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id where (e0.id = @pi0::int8)), s1 as (delete from edge e1 using s0 where (s0.e0).id = e1.id) select 1;",
+ expectedParams: map[string]any{"pi0": 1},
+ },
+ "delete node": {
+ builder: v2.New().Where(
+ v2.Node().ID().Equals(1),
+ ).Delete(
+ v2.Node(),
+ ),
+ expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.id = @pi0::int8)), s1 as (delete from node n1 using s0 where (s0.n0).id = n1.id) select 1;",
+ expectedParams: map[string]any{"pi0": 1},
+ },
+ }
+
+ for name, testCase := range cases {
+ t.Run(name, func(t *testing.T) {
+ preparedQuery, err := testCase.builder.Build()
+ require.NoError(t, err)
+
+ translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters, translate.DefaultGraphID)
+ require.NoError(t, err)
+
+ sql, err := translate.Translated(translation)
+ require.NoError(t, err)
+ require.Equal(t, testCase.expectedSQL, sql)
+ require.Equal(t, testCase.expectedParams, translation.Parameters)
+ })
+ }
+}
+
+func TestBackendParityPGTranslateShortestPaths(t *testing.T) {
+ edgeKind := graph.StringKind("MemberOf")
+ mapper := testKindMapper(edgeKind)
+
+ cases := map[string]struct {
+ builder v2.QueryBuilder
+ expectedHarness string
+ }{
+ "shortest path": {
+ builder: v2.New().WithShortestPaths().Where(
+ v2.Relationship().Kind().Is(edgeKind),
+ v2.Start().ID().Equals(1),
+ v2.End().ID().Equals(2),
+ ).Return(
+ v2.Path(),
+ ),
+ expectedHarness: "bidirectional_sp_harness",
+ },
+ "all shortest paths": {
+ builder: v2.New().WithAllShortestPaths().Where(
+ v2.Relationship().Kind().Is(edgeKind),
+ v2.Start().ID().Equals(1),
+ v2.End().ID().Equals(2),
+ ).Return(
+ v2.Path(),
+ ),
+ expectedHarness: "bidirectional_asp_harness",
+ },
+ }
+
+ for name, testCase := range cases {
+ t.Run(name, func(t *testing.T) {
+ preparedQuery, err := testCase.builder.Build()
+ require.NoError(t, err)
+
+ translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters, translate.DefaultGraphID)
+ require.NoError(t, err)
+
+ sql, err := translate.Translated(translation)
+ require.NoError(t, err)
+ require.Contains(t, sql, testCase.expectedHarness)
+ require.Contains(t, sql, "ordered_edges_to_path")
+ require.Contains(t, sql, "n0.id = 1")
+ require.Contains(t, sql, "n1.id = 2")
+
+ serializedHarnessQueryHasKindConstraint := false
+ for _, parameterValue := range translation.Parameters {
+ if serializedQuery, typeOK := parameterValue.(string); typeOK && strings.Contains(serializedQuery, "array [1]::int2[]") {
+ serializedHarnessQueryHasKindConstraint = true
+ break
+ }
+ }
+ require.True(t, serializedHarnessQueryHasKindConstraint, "expected serialized shortest-path harness query to contain edge kind constraint: %#v", translation.Parameters)
+ })
+ }
+}
+
+func TestBackendParityPGCreate(t *testing.T) {
+ edgeKind := graph.StringKind("MemberOf")
+ mapper := testKindMapper(edgeKind)
+
+ preparedQuery, err := v2.New().Where(
+ v2.Start().ID().Equals(1),
+ v2.End().ID().Equals(2),
+ ).Create(
+ v2.RelationshipPattern(edgeKind, nil, graph.DirectionOutbound),
+ ).Return(
+ v2.Relationship().ID(),
+ ).Build()
+ require.NoError(t, err)
+
+ translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters, translate.DefaultGraphID)
+ require.NoError(t, err)
+
+ sql, err := translate.Translated(translation)
+ require.NoError(t, err)
+ require.Contains(t, sql, "insert into edge")
+ require.Contains(t, sql, "graph_id")
+ require.Contains(t, sql, "kind_id")
+}
diff --git a/query/v2/compat.go b/query/v2/compat.go
new file mode 100644
index 00000000..fb6be7b3
--- /dev/null
+++ b/query/v2/compat.go
@@ -0,0 +1,303 @@
+package v2
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/specterops/dawgs/cypher/models/cypher"
+ "github.com/specterops/dawgs/graph"
+)
+
+func Variable(name string) *cypher.Variable {
+ return cypher.NewVariableWithSymbol(name)
+}
+
+func Identity(reference any) *cypher.FunctionInvocation {
+ return cypher.NewSimpleFunctionInvocation(cypher.IdentityFunction, expressionOrError(reference))
+}
+
+func NodeID() *cypher.FunctionInvocation {
+ return Identity(Identifiers.Node())
+}
+
+func RelationshipID() *cypher.FunctionInvocation {
+ return Identity(Identifiers.Relationship())
+}
+
+func StartID() *cypher.FunctionInvocation {
+ return Identity(Identifiers.Start())
+}
+
+func EndID() *cypher.FunctionInvocation {
+ return Identity(Identifiers.End())
+}
+
+func Count(reference any) *cypher.FunctionInvocation {
+ return cypher.NewSimpleFunctionInvocation(cypher.CountFunction, expressionOrError(reference))
+}
+
+func CountDistinct(reference any) *cypher.FunctionInvocation {
+ return &cypher.FunctionInvocation{
+ Name: cypher.CountFunction,
+ Distinct: true,
+ Arguments: []cypher.Expression{expressionOrError(reference)},
+ }
+}
+
+func Size(expression any) *cypher.FunctionInvocation {
+ return cypher.NewSimpleFunctionInvocation(cypher.ListSizeFunction, expressionOrError(expression))
+}
+
+func KindsOf(reference any) *cypher.FunctionInvocation {
+ if scopedReference, typeOK := reference.(scopedExpression); typeOK {
+ if variable, typeOK := scopedReference.qualifier().(*cypher.Variable); !typeOK {
+ return invalidExpression(fmt.Errorf("expected variable reference, got %T", scopedReference.qualifier()))
+ } else if expression, err := kindProjectionExpression(scopedReference.roleName(), variable); err != nil {
+ return invalidExpression(err)
+ } else if invocation, typeOK := expression.(*cypher.FunctionInvocation); !typeOK {
+ return invalidExpression(fmt.Errorf("expected kind projection function, got %T", expression))
+ } else {
+ return invocation
+ }
+ }
+
+ expression := expressionOrError(reference)
+
+ switch typedExpression := expression.(type) {
+ case *cypher.Variable:
+ switch typedExpression.Symbol {
+ case Identifiers.node, Identifiers.start, Identifiers.end:
+ return cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, typedExpression)
+
+ case Identifiers.relationship:
+ return cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, typedExpression)
+ }
+ }
+
+ return cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, expression)
+}
+
+func Kind(reference any, kinds ...graph.Kind) *cypher.KindMatcher {
+ return &cypher.KindMatcher{
+ Reference: expressionOrError(reference),
+ Kinds: kinds,
+ }
+}
+
+func KindIn(reference any, kinds ...graph.Kind) *cypher.KindMatcher {
+ return Kind(reference, kinds...)
+}
+
+func AddKind(reference any, kind graph.Kind) *cypher.SetItem {
+ return AddKinds(reference, graph.Kinds{kind})
+}
+
+func AddKinds(reference any, kinds graph.Kinds) *cypher.SetItem {
+ return cypher.NewSetItem(expressionOrError(reference), cypher.OperatorLabelAssignment, kinds)
+}
+
+func DeleteKind(reference any, kind graph.Kind) *cypher.RemoveItem {
+ return DeleteKinds(reference, graph.Kinds{kind})
+}
+
+func DeleteKinds(reference any, kinds graph.Kinds) *cypher.RemoveItem {
+ return cypher.RemoveKindsByMatcher(cypher.NewKindMatcher(expressionOrError(reference), kinds, false))
+}
+
+func SetProperty(reference any, value any) *cypher.SetItem {
+ return cypher.NewSetItem(expressionOrError(reference), cypher.OperatorAssignment, valueExpression(value))
+}
+
+func SetProperties(reference any, properties map[string]any) *cypher.Set {
+ set := &cypher.Set{}
+
+ for _, key := range sortedPropertyKeys(properties) {
+ set.Items = append(set.Items, cypher.NewSetItem(
+ propertyLookupOrError(reference, key),
+ cypher.OperatorAssignment,
+ valueExpression(properties[key]),
+ ))
+ }
+
+ return set
+}
+
+func DeleteProperty(reference any) *cypher.RemoveItem {
+ return cypher.RemoveProperty(expressionOrError(reference))
+}
+
+func DeleteProperties(reference any, propertyNames ...string) *cypher.Remove {
+ remove := &cypher.Remove{}
+
+ for _, propertyName := range propertyNames {
+ remove.Items = append(remove.Items, cypher.RemoveProperty(propertyLookupOrError(reference, propertyName)))
+ }
+
+ return remove
+}
+
+func NodePattern(kinds graph.Kinds, properties cypher.Expression) *cypher.NodePattern {
+ return &cypher.NodePattern{
+ Variable: Identifiers.Node(),
+ Kinds: kinds,
+ Properties: properties,
+ }
+}
+
+func StartNodePattern(kinds graph.Kinds, properties cypher.Expression) *cypher.NodePattern {
+ return &cypher.NodePattern{
+ Variable: Identifiers.Start(),
+ Kinds: kinds,
+ Properties: properties,
+ }
+}
+
+func EndNodePattern(kinds graph.Kinds, properties cypher.Expression) *cypher.NodePattern {
+ return &cypher.NodePattern{
+ Variable: Identifiers.End(),
+ Kinds: kinds,
+ Properties: properties,
+ }
+}
+
+func RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) *cypher.RelationshipPattern {
+ return &cypher.RelationshipPattern{
+ Variable: Identifiers.Relationship(),
+ Kinds: graph.Kinds{kind},
+ Direction: direction,
+ Properties: properties,
+ }
+}
+
+func Equals(reference any, value any) cypher.Expression {
+ return cypher.NewComparison(expressionOrError(reference), cypher.OperatorEquals, valueExpression(value))
+}
+
+func GreaterThan(reference any, value any) cypher.Expression {
+ return cypher.NewComparison(expressionOrError(reference), cypher.OperatorGreaterThan, valueExpression(value))
+}
+
+func After(reference any, value any) cypher.Expression {
+ return GreaterThan(reference, value)
+}
+
+func GreaterThanOrEqualTo(reference any, value any) cypher.Expression {
+ return cypher.NewComparison(expressionOrError(reference), cypher.OperatorGreaterThanOrEqualTo, valueExpression(value))
+}
+
+func GreaterThanOrEquals(reference any, value any) cypher.Expression {
+ return GreaterThanOrEqualTo(reference, value)
+}
+
+func LessThan(reference any, value any) cypher.Expression {
+ return cypher.NewComparison(expressionOrError(reference), cypher.OperatorLessThan, valueExpression(value))
+}
+
+func LessThanGraphQuery(reference any, other any) cypher.Expression {
+ return LessThan(reference, other)
+}
+
+func Before(reference any, value time.Time) cypher.Expression {
+ return LessThan(reference, value)
+}
+
+func BeforeGraphQuery(reference any, other any) cypher.Expression {
+ return LessThan(reference, other)
+}
+
+func LessThanOrEqualTo(reference any, value any) cypher.Expression {
+ return cypher.NewComparison(expressionOrError(reference), cypher.OperatorLessThanOrEqualTo, valueExpression(value))
+}
+
+func LessThanOrEquals(reference any, value any) cypher.Expression {
+ return LessThanOrEqualTo(reference, value)
+}
+
+func In(reference any, value any) cypher.Expression {
+ return cypher.NewComparison(expressionOrError(reference), cypher.OperatorIn, valueExpression(value))
+}
+
+func InInverted(reference any, value any) cypher.Expression {
+ return cypher.NewComparison(valueExpression(value), cypher.OperatorIn, expressionOrError(reference))
+}
+
+func InIDs(reference any, ids ...graph.ID) cypher.Expression {
+ expression := expressionOrError(reference)
+
+ if variable, typeOK := expression.(*cypher.Variable); typeOK {
+ expression = Identity(variable)
+ }
+
+ return cypher.NewComparison(expression, cypher.OperatorIn, Parameter(ids))
+}
+
+func StringContains(reference any, value string) cypher.Expression {
+ return cypher.NewComparison(expressionOrError(reference), cypher.OperatorContains, Parameter(value))
+}
+
+func StringStartsWith(reference any, value string) cypher.Expression {
+ return cypher.NewComparison(expressionOrError(reference), cypher.OperatorStartsWith, Parameter(value))
+}
+
+func StringEndsWith(reference any, value string) cypher.Expression {
+ return cypher.NewComparison(expressionOrError(reference), cypher.OperatorEndsWith, Parameter(value))
+}
+
+func CaseInsensitiveStringContains(reference any, value string) cypher.Expression {
+ return cypher.NewComparison(
+ cypher.NewSimpleFunctionInvocation("toLower", expressionOrError(reference)),
+ cypher.OperatorContains,
+ Parameter(strings.ToLower(value)),
+ )
+}
+
+func CaseInsensitiveStringStartsWith(reference any, value string) cypher.Expression {
+ return cypher.NewComparison(
+ cypher.NewSimpleFunctionInvocation("toLower", expressionOrError(reference)),
+ cypher.OperatorStartsWith,
+ Parameter(strings.ToLower(value)),
+ )
+}
+
+func CaseInsensitiveStringEndsWith(reference any, value string) cypher.Expression {
+ return cypher.NewComparison(
+ cypher.NewSimpleFunctionInvocation("toLower", expressionOrError(reference)),
+ cypher.OperatorEndsWith,
+ Parameter(strings.ToLower(value)),
+ )
+}
+
+func Exists(reference any) cypher.Expression {
+ return IsNotNull(reference)
+}
+
+func IsNull(reference any) cypher.Expression {
+ return cypher.NewComparison(expressionOrError(reference), cypher.OperatorIs, Literal(nil))
+}
+
+func IsNotNull(reference any) cypher.Expression {
+ return cypher.NewComparison(expressionOrError(reference), cypher.OperatorIsNot, Literal(nil))
+}
+
+func HasRelationships(reference any) *cypher.PatternPredicate {
+ patternPredicate := cypher.NewPatternPredicate()
+
+ if variable, err := variableReference(reference); err != nil {
+ patternPredicate.AddElement(&cypher.NodePattern{
+ Properties: invalidExpression(err),
+ })
+ } else {
+ patternPredicate.AddElement(&cypher.NodePattern{
+ Variable: cypher.NewVariableWithSymbol(variable.Symbol),
+ })
+ }
+
+ patternPredicate.AddElement(&cypher.RelationshipPattern{
+ Direction: graph.DirectionBoth,
+ })
+
+ patternPredicate.AddElement(&cypher.NodePattern{})
+
+ return patternPredicate
+}
diff --git a/query/v2/doc.go b/query/v2/doc.go
new file mode 100644
index 00000000..a21a5004
--- /dev/null
+++ b/query/v2/doc.go
@@ -0,0 +1,6 @@
+// Package v2 contains the experimental fluent Cypher query builder.
+//
+// It is intentionally isolated from the stable query package so callers can
+// opt in without pulling the current graph query APIs through a compatibility
+// layer.
+package v2
diff --git a/query/v2/legacy_parity_test.go b/query/v2/legacy_parity_test.go
new file mode 100644
index 00000000..e8fe1bee
--- /dev/null
+++ b/query/v2/legacy_parity_test.go
@@ -0,0 +1,213 @@
+package v2_test
+
+import (
+ "testing"
+
+ "github.com/specterops/dawgs/cypher/models/cypher"
+ "github.com/specterops/dawgs/graph"
+ legacyquery "github.com/specterops/dawgs/query"
+ "github.com/specterops/dawgs/query/neo4j"
+ v2 "github.com/specterops/dawgs/query/v2"
+ "github.com/stretchr/testify/require"
+)
+
+func renderNeo4jQuery(t *testing.T, regularQuery *cypher.RegularQuery, prepareAllShortestPaths bool) (string, map[string]any) {
+ t.Helper()
+
+ queryBuilder := neo4j.NewQueryBuilder(regularQuery)
+
+ if prepareAllShortestPaths {
+ require.NoError(t, queryBuilder.PrepareAllShortestPaths())
+ } else {
+ require.NoError(t, queryBuilder.Prepare())
+ }
+
+ rendered, err := queryBuilder.Render()
+ require.NoError(t, err)
+
+ return rendered, queryBuilder.Parameters
+}
+
+func assertLegacyNeo4jParity(t *testing.T, legacyQuery *cypher.RegularQuery, v2Builder v2.QueryBuilder, prepareLegacyAllShortestPaths bool) {
+ t.Helper()
+
+ preparedQuery, err := v2Builder.Build()
+ require.NoError(t, err)
+
+ legacyRendered, legacyParameters := renderNeo4jQuery(t, legacyQuery, prepareLegacyAllShortestPaths)
+ v2Rendered, v2Parameters := renderNeo4jQuery(t, preparedQuery.Query, false)
+
+ require.Equal(t, legacyRendered, v2Rendered)
+ require.Equal(t, legacyParameters, v2Parameters)
+}
+
+func TestLegacyNeo4jParity(t *testing.T) {
+ userKind := graph.StringKind("User")
+ edgeKind := graph.StringKind("MemberOf")
+
+ t.Run("node count by kind", func(t *testing.T) {
+ assertLegacyNeo4jParity(t,
+ legacyquery.SinglePartQuery(
+ legacyquery.Where(
+ legacyquery.KindIn(legacyquery.Node(), userKind),
+ ),
+ legacyquery.Returning(
+ legacyquery.Count(legacyquery.Node()),
+ ),
+ ),
+ v2.New().Where(
+ v2.Node().Kinds().Has(userKind),
+ ).Return(
+ v2.Node().Count(),
+ ),
+ false,
+ )
+ })
+
+ t.Run("node read with pagination", func(t *testing.T) {
+ assertLegacyNeo4jParity(t,
+ legacyquery.SinglePartQuery(
+ legacyquery.Where(
+ legacyquery.And(
+ legacyquery.StringContains(legacyquery.NodeProperty("name"), "admin"),
+ legacyquery.IsNotNull(legacyquery.NodeProperty("enabled")),
+ ),
+ ),
+ legacyquery.Returning(
+ legacyquery.Node(),
+ legacyquery.OrderBy(legacyquery.Order(legacyquery.NodeProperty("name"), legacyquery.Ascending())),
+ legacyquery.Offset(0),
+ legacyquery.Limit(0),
+ ),
+ ),
+ v2.New().Where(
+ v2.And(
+ v2.Node().Property("name").Contains("admin"),
+ v2.Node().Property("enabled").IsNotNull(),
+ ),
+ ).Return(
+ v2.Node(),
+ ).OrderBy(
+ v2.Asc(v2.Node().Property("name")),
+ ).Skip(0).Limit(0),
+ false,
+ )
+ })
+
+ t.Run("node read with or and adjacent predicate", func(t *testing.T) {
+ assertLegacyNeo4jParity(t,
+ legacyquery.SinglePartQuery(
+ legacyquery.Where(
+ legacyquery.And(
+ legacyquery.Or(
+ legacyquery.Equals(legacyquery.NodeProperty("name"), "alice"),
+ legacyquery.Equals(legacyquery.NodeProperty("name"), "bob"),
+ ),
+ legacyquery.IsNotNull(legacyquery.NodeProperty("enabled")),
+ ),
+ ),
+ legacyquery.Returning(
+ legacyquery.Node(),
+ ),
+ ),
+ v2.New().Where(
+ v2.Or(
+ v2.Node().Property("name").Equals("alice"),
+ v2.Node().Property("name").Equals("bob"),
+ ),
+ v2.Node().Property("enabled").IsNotNull(),
+ ).Return(
+ v2.Node(),
+ ),
+ false,
+ )
+ })
+
+ t.Run("relationship read", func(t *testing.T) {
+ assertLegacyNeo4jParity(t,
+ legacyquery.SinglePartQuery(
+ legacyquery.Where(
+ legacyquery.And(
+ legacyquery.KindIn(legacyquery.Relationship(), edgeKind),
+ legacyquery.Equals(legacyquery.StartID(), 1),
+ legacyquery.Equals(legacyquery.EndID(), 2),
+ ),
+ ),
+ legacyquery.Returning(
+ legacyquery.StartID(),
+ legacyquery.RelationshipID(),
+ legacyquery.EndID(),
+ ),
+ ),
+ v2.New().Where(
+ v2.Relationship().Kind().Is(edgeKind),
+ v2.Start().ID().Equals(1),
+ v2.End().ID().Equals(2),
+ ).Return(
+ v2.Start().ID(),
+ v2.Relationship().ID(),
+ v2.End().ID(),
+ ),
+ false,
+ )
+ })
+
+ t.Run("create relationship with matched endpoints", func(t *testing.T) {
+ properties := map[string]any{"name": "rel"}
+
+ assertLegacyNeo4jParity(t,
+ legacyquery.SinglePartQuery(
+ legacyquery.Where(
+ legacyquery.And(
+ legacyquery.Equals(legacyquery.StartID(), 1),
+ legacyquery.Equals(legacyquery.EndID(), 2),
+ ),
+ ),
+ legacyquery.Create(
+ legacyquery.Start(),
+ legacyquery.RelationshipPattern(edgeKind, legacyquery.Parameter(properties), graph.DirectionOutbound),
+ legacyquery.End(),
+ ),
+ legacyquery.Returning(
+ legacyquery.RelationshipID(),
+ ),
+ ),
+ v2.New().Where(
+ v2.Start().ID().Equals(1),
+ v2.End().ID().Equals(2),
+ ).Create(
+ v2.Start(),
+ v2.RelationshipPattern(edgeKind, v2.Parameter(properties), graph.DirectionOutbound),
+ v2.End(),
+ ).Return(
+ v2.Relationship().ID(),
+ ),
+ false,
+ )
+ })
+
+ t.Run("all shortest paths", func(t *testing.T) {
+ assertLegacyNeo4jParity(t,
+ legacyquery.SinglePartQuery(
+ legacyquery.Where(
+ legacyquery.And(
+ legacyquery.KindIn(legacyquery.Relationship(), edgeKind),
+ legacyquery.Equals(legacyquery.StartID(), 1),
+ legacyquery.Equals(legacyquery.EndID(), 2),
+ ),
+ ),
+ legacyquery.Returning(
+ legacyquery.Path(),
+ ),
+ ),
+ v2.New().WithAllShortestPaths().Where(
+ v2.Relationship().Kind().Is(edgeKind),
+ v2.Start().ID().Equals(1),
+ v2.End().ID().Equals(2),
+ ).Return(
+ v2.Path(),
+ ),
+ true,
+ )
+ })
+}
diff --git a/query/v2/query.go b/query/v2/query.go
new file mode 100644
index 00000000..66ecb179
--- /dev/null
+++ b/query/v2/query.go
@@ -0,0 +1,1560 @@
+package v2
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/specterops/dawgs/cypher/models/cypher"
+ "github.com/specterops/dawgs/graph"
+)
+
+type runtimeIdentifiers struct {
+ path string
+ node string
+ start string
+ relationship string
+ end string
+}
+
+type TraversalDepth struct {
+ patternRange *cypher.PatternRange
+ err error
+}
+
+func traversalDepthBound(value int64) *int64 {
+ return &value
+}
+
+func newTraversalDepth(start, end *int64) TraversalDepth {
+ if start != nil && *start < 0 {
+ return TraversalDepth{
+ err: fmt.Errorf("traversal depth minimum must be non-negative: %d", *start),
+ }
+ }
+
+ if end != nil && *end < 0 {
+ return TraversalDepth{
+ err: fmt.Errorf("traversal depth maximum must be non-negative: %d", *end),
+ }
+ }
+
+ if start != nil && end != nil && *end < *start {
+ return TraversalDepth{
+ err: fmt.Errorf("traversal depth maximum %d is less than minimum %d", *end, *start),
+ }
+ }
+
+ return TraversalDepth{
+ patternRange: cypher.NewPatternRange(start, end),
+ }
+}
+
+func (s TraversalDepth) rangePattern() *cypher.PatternRange {
+ if s.patternRange == nil {
+ return &cypher.PatternRange{}
+ }
+
+ return cypher.Copy(s.patternRange)
+}
+
+func AnyDepth() TraversalDepth {
+ return newTraversalDepth(nil, nil)
+}
+
+func MinDepth(min int64) TraversalDepth {
+ return newTraversalDepth(traversalDepthBound(min), nil)
+}
+
+func MaxDepth(max int64) TraversalDepth {
+ return newTraversalDepth(nil, traversalDepthBound(max))
+}
+
+func DepthRange(min, max int64) TraversalDepth {
+ return newTraversalDepth(traversalDepthBound(min), traversalDepthBound(max))
+}
+
+func ExactDepth(depth int64) TraversalDepth {
+ depthBound := traversalDepthBound(depth)
+ return newTraversalDepth(depthBound, depthBound)
+}
+
+// Accessors return fresh variables; compare symbols rather than pointer identity.
+func (s runtimeIdentifiers) Path() *cypher.Variable {
+ return cypher.NewVariableWithSymbol(s.path)
+}
+
+func (s runtimeIdentifiers) Node() *cypher.Variable {
+ return cypher.NewVariableWithSymbol(s.node)
+}
+
+func (s runtimeIdentifiers) Start() *cypher.Variable {
+ return cypher.NewVariableWithSymbol(s.start)
+}
+
+func (s runtimeIdentifiers) Relationship() *cypher.Variable {
+ return cypher.NewVariableWithSymbol(s.relationship)
+}
+
+func (s runtimeIdentifiers) End() *cypher.Variable {
+ return cypher.NewVariableWithSymbol(s.end)
+}
+
+var Identifiers = runtimeIdentifiers{
+ path: "p",
+ node: "n",
+ start: "s",
+ relationship: "r",
+ end: "e",
+}
+
+type Scope struct {
+ identifiers runtimeIdentifiers
+ errors []error
+}
+
+func DefaultScope() Scope {
+ return Scope{
+ identifiers: Identifiers,
+ }
+}
+
+func NewScope(path, node, start, relationship, end string) Scope {
+ identifiers := runtimeIdentifiers{
+ path: path,
+ node: node,
+ start: start,
+ relationship: relationship,
+ end: end,
+ }
+
+ return Scope{
+ identifiers: identifiers,
+ errors: validateRuntimeIdentifiers(identifiers),
+ }
+}
+
+func validateRuntimeIdentifiers(identifiers runtimeIdentifiers) []error {
+ aliases := []struct {
+ role string
+ value string
+ }{
+ {role: "path", value: identifiers.path},
+ {role: "node", value: identifiers.node},
+ {role: "start", value: identifiers.start},
+ {role: "relationship", value: identifiers.relationship},
+ {role: "end", value: identifiers.end},
+ }
+
+ var (
+ errs []error
+ seen = map[string]string{}
+ )
+
+ for _, alias := range aliases {
+ if err := validateCypherSymbol(alias.value, "scope alias "+alias.role); err != nil {
+ errs = append(errs, err)
+ continue
+ }
+
+ if existingRole, exists := seen[alias.value]; exists {
+ errs = append(errs, fmt.Errorf("scope aliases %s and %s both use %q", existingRole, alias.role, alias.value))
+ } else {
+ seen[alias.value] = alias.role
+ }
+ }
+
+ return errs
+}
+
+func (s Scope) New() QueryBuilder {
+ return newBuilder(s.identifiers, s.errors...)
+}
+
+func (s Scope) Node() NodeContinuation {
+ return &entity[NodeContinuation]{
+ identifier: s.identifiers.Node(),
+ role: Identifiers.node,
+ }
+}
+
+func (s Scope) Path() PathContinuation {
+ return &entity[PathContinuation]{
+ identifier: s.identifiers.Path(),
+ role: Identifiers.path,
+ }
+}
+
+func (s Scope) Start() NodeContinuation {
+ return &entity[NodeContinuation]{
+ identifier: s.identifiers.Start(),
+ role: Identifiers.start,
+ }
+}
+
+func (s Scope) Relationship() RelationshipContinuation {
+ return &entity[RelationshipContinuation]{
+ identifier: s.identifiers.Relationship(),
+ role: Identifiers.relationship,
+ }
+}
+
+func (s Scope) End() NodeContinuation {
+ return &entity[NodeContinuation]{
+ identifier: s.identifiers.End(),
+ role: Identifiers.end,
+ }
+}
+
+func Literal(value any) *cypher.Literal {
+ if value == nil {
+ return cypher.NewLiteral(nil, true)
+ }
+
+ if strValue, typeOK := value.(string); typeOK {
+ return cypher.NewStringLiteral(strValue)
+ }
+
+ return cypher.NewLiteral(value, false)
+}
+
+func Parameter(value any) *cypher.Parameter {
+ if parameter, typeOK := value.(*cypher.Parameter); typeOK {
+ return parameter
+ }
+
+ return &cypher.Parameter{
+ Value: value,
+ }
+}
+
+func NamedParameter(symbol string, value any) *cypher.Parameter {
+ return cypher.NewParameter(symbol, value)
+}
+
+func valueExpression(value any) cypher.Expression {
+ switch typedValue := value.(type) {
+ case *cypher.Parameter:
+ return typedValue
+ case *cypher.Literal:
+ return typedValue
+ case *cypher.Variable:
+ return typedValue
+ case *cypher.PropertyLookup:
+ return typedValue
+ case *cypher.FunctionInvocation:
+ return typedValue
+ case *cypher.Parenthetical:
+ return typedValue
+ case *cypher.Comparison:
+ return typedValue
+ case *cypher.Negation:
+ return typedValue
+ case *cypher.Conjunction:
+ return typedValue
+ case *cypher.Disjunction:
+ return typedValue
+ case *cypher.ExclusiveDisjunction:
+ return typedValue
+ case *cypher.KindMatcher:
+ return typedValue
+ case *cypher.ListLiteral:
+ return typedValue
+ case cypher.MapLiteral:
+ return typedValue
+ case *cypher.PatternPredicate:
+ return typedValue
+ case *cypher.ArithmeticExpression:
+ return typedValue
+ case *cypher.UnaryAddOrSubtractExpression:
+ return typedValue
+ case *cypher.FilterExpression:
+ return typedValue
+ case *cypher.IDInCollection:
+ return typedValue
+ case QualifiedExpression:
+ if expression, err := projectionExpression(typedValue); err != nil {
+ return invalidExpression(err)
+ } else {
+ return expression
+ }
+ default:
+ return Parameter(value)
+ }
+}
+
+func joinedExpressionList(operator cypher.Operator, operands []cypher.SyntaxNode) ([]cypher.Expression, cypher.SyntaxNode) {
+ if len(operands) == 0 {
+ return nil, invalidExpression(fmt.Errorf("%s requires at least one operand", operator))
+ }
+
+ expressions := make([]cypher.Expression, len(operands))
+ for idx, operand := range operands {
+ expressions[idx] = operand
+ }
+
+ return expressions, nil
+}
+
+func comparisonHasLogicalOperator(comparison *cypher.Comparison) bool {
+ if comparison == nil {
+ return false
+ }
+
+ for _, partial := range comparison.Partials {
+ switch partial.Operator {
+ case cypher.OperatorAnd, cypher.OperatorOr:
+ return true
+ }
+ }
+
+ return false
+}
+
+func parenthesizeDisjunctiveExpression(expression cypher.Expression) cypher.Expression {
+ switch typedExpression := expression.(type) {
+ case *cypher.Parenthetical:
+ return typedExpression
+ case *cypher.Disjunction, *cypher.ExclusiveDisjunction:
+ return cypher.NewParenthetical(typedExpression)
+ case *cypher.Comparison:
+ if comparisonHasLogicalOperator(typedExpression) {
+ return cypher.NewParenthetical(typedExpression)
+ }
+ }
+
+ return expression
+}
+
+func parenthesizeLogicalExpression(expression cypher.Expression) cypher.Expression {
+ switch typedExpression := expression.(type) {
+ case *cypher.Parenthetical:
+ return typedExpression
+ case *cypher.Conjunction, *cypher.Disjunction, *cypher.ExclusiveDisjunction:
+ return cypher.NewParenthetical(typedExpression)
+ case *cypher.Comparison:
+ if comparisonHasLogicalOperator(typedExpression) {
+ return cypher.NewParenthetical(typedExpression)
+ }
+ }
+
+ return expression
+}
+
+func Not(operand cypher.Expression) cypher.Expression {
+ return cypher.NewNegation(parenthesizeLogicalExpression(operand))
+}
+
+func And(operands ...cypher.SyntaxNode) cypher.SyntaxNode {
+ expressions, errExpression := joinedExpressionList(cypher.OperatorAnd, operands)
+ if errExpression != nil {
+ return errExpression
+ }
+
+ for idx, expression := range expressions {
+ expressions[idx] = parenthesizeDisjunctiveExpression(expression)
+ }
+
+ return cypher.NewConjunction(expressions...)
+}
+
+func Or(operands ...cypher.SyntaxNode) cypher.SyntaxNode {
+ expressions, errExpression := joinedExpressionList(cypher.OperatorOr, operands)
+ if errExpression != nil {
+ return errExpression
+ }
+
+ for idx, expression := range expressions {
+ expressions[idx] = parenthesizeLogicalExpression(expression)
+ }
+
+ return cypher.NewParenthetical(cypher.NewDisjunction(expressions...))
+}
+
+type SortDirection int
+
+const (
+ SortAscending SortDirection = iota
+ SortDescending
+)
+
+func Asc(expression any) *cypher.SortItem {
+ return Order(expression, SortAscending)
+}
+
+func Desc(expression any) *cypher.SortItem {
+ return Order(expression, SortDescending)
+}
+
+func validateSortDirection(direction SortDirection) error {
+ switch direction {
+ case SortAscending, SortDescending:
+ return nil
+ default:
+ return fmt.Errorf("unsupported sort direction: %d", direction)
+ }
+}
+
+func Order(expression any, direction SortDirection) *cypher.SortItem {
+ expressionValue := expressionOrError(expression)
+ if err := validateSortDirection(direction); err != nil {
+ expressionValue = invalidExpression(err)
+ }
+
+ return &cypher.SortItem{
+ Ascending: direction != SortDescending,
+ Expression: expressionValue,
+ }
+}
+
+func As(expression any, alias string) *cypher.ProjectionItem {
+ return &cypher.ProjectionItem{
+ Expression: expressionOrError(expression),
+ Alias: cypher.NewVariableWithSymbol(alias),
+ }
+}
+
+func Node() NodeContinuation {
+ return DefaultScope().Node()
+}
+
+func Path() PathContinuation {
+ return DefaultScope().Path()
+}
+
+func Start() NodeContinuation {
+ return DefaultScope().Start()
+}
+
+func Relationship() RelationshipContinuation {
+ return DefaultScope().Relationship()
+}
+
+func End() NodeContinuation {
+ return DefaultScope().End()
+}
+
+type QualifiedExpression interface {
+ qualifier() cypher.Expression
+}
+
+type scopedExpression interface {
+ QualifiedExpression
+
+ roleName() string
+}
+
+type deleteTarget interface {
+ QualifiedExpression
+
+ deleteTarget()
+}
+
+type EntityContinuation interface {
+ QualifiedExpression
+
+ Count() cypher.Expression
+ ID() IdentityContinuation
+ Property(name string) PropertyContinuation
+}
+
+type KindContinuation interface {
+ Is(kind graph.Kind) cypher.Expression
+ IsOneOf(kinds graph.Kinds) cypher.Expression
+}
+
+type KindsContinuation interface {
+ Has(kind graph.Kind) cypher.Expression
+ HasOneOf(kinds graph.Kinds) cypher.Expression
+ Add(kinds graph.Kinds) *cypher.SetItem
+ Remove(kinds graph.Kinds) *cypher.RemoveItem
+}
+
+type Comparable interface {
+ In(value any) cypher.Expression
+ Contains(value any) cypher.Expression
+ StartsWith(value any) cypher.Expression
+ EndsWith(value any) cypher.Expression
+ Equals(value any) cypher.Expression
+ GreaterThan(value any) cypher.Expression
+ GreaterThanOrEqualTo(value any) cypher.Expression
+ LessThan(value any) cypher.Expression
+ LessThanOrEqualTo(value any) cypher.Expression
+ IsNull() cypher.Expression
+ IsNotNull() cypher.Expression
+}
+
+type PropertyContinuation interface {
+ QualifiedExpression
+ Comparable
+
+ Set(value any) *cypher.SetItem
+ Remove() *cypher.RemoveItem
+}
+
+type IdentityContinuation interface {
+ QualifiedExpression
+ Comparable
+}
+
+type comparisonContinuation struct {
+ qualifierExpression cypher.Expression
+}
+
+func (s *comparisonContinuation) qualifier() cypher.Expression {
+ return s.qualifierExpression
+}
+
+func (s *comparisonContinuation) asComparison(operator cypher.Operator, rOperand any) cypher.Expression {
+ return cypher.NewComparison(
+ s.qualifier(),
+ operator,
+ valueExpression(rOperand),
+ )
+}
+
+func (s *comparisonContinuation) In(value any) cypher.Expression {
+ return s.asComparison(cypher.OperatorIn, value)
+}
+
+func (s *comparisonContinuation) Contains(value any) cypher.Expression {
+ return s.asComparison(cypher.OperatorContains, value)
+}
+
+func (s *comparisonContinuation) StartsWith(value any) cypher.Expression {
+ return s.asComparison(cypher.OperatorStartsWith, value)
+}
+
+func (s *comparisonContinuation) EndsWith(value any) cypher.Expression {
+ return s.asComparison(cypher.OperatorEndsWith, value)
+}
+
+func (s *comparisonContinuation) Equals(value any) cypher.Expression {
+ return s.asComparison(cypher.OperatorEquals, value)
+}
+
+func (s *comparisonContinuation) GreaterThan(value any) cypher.Expression {
+ return s.asComparison(cypher.OperatorGreaterThan, value)
+}
+
+func (s *comparisonContinuation) GreaterThanOrEqualTo(value any) cypher.Expression {
+ return s.asComparison(cypher.OperatorGreaterThanOrEqualTo, value)
+}
+
+func (s *comparisonContinuation) LessThan(value any) cypher.Expression {
+ return s.asComparison(cypher.OperatorLessThan, value)
+}
+
+func (s *comparisonContinuation) LessThanOrEqualTo(value any) cypher.Expression {
+ return s.asComparison(cypher.OperatorLessThanOrEqualTo, value)
+}
+
+func (s *comparisonContinuation) IsNull() cypher.Expression {
+ return cypher.NewComparison(s.qualifier(), cypher.OperatorIs, Literal(nil))
+}
+
+func (s *comparisonContinuation) IsNotNull() cypher.Expression {
+ return cypher.NewComparison(s.qualifier(), cypher.OperatorIsNot, Literal(nil))
+}
+
+type propertyContinuation struct {
+ comparisonContinuation
+}
+
+func (s *propertyContinuation) Set(value any) *cypher.SetItem {
+ return cypher.NewSetItem(
+ s.qualifier(),
+ cypher.OperatorAssignment,
+ valueExpression(value),
+ )
+}
+
+func (s *propertyContinuation) Remove() *cypher.RemoveItem {
+ return cypher.RemoveProperty(s.qualifier())
+}
+
+type entity[T any] struct {
+ identifier *cypher.Variable
+ role string
+}
+
+func (s *entity[T]) Kind() KindContinuation {
+ return kindContinuation{
+ identifier: s.identifier,
+ role: s.role,
+ }
+}
+
+func (s *entity[T]) Kinds() KindsContinuation {
+ return kindsContinuation{
+ identifier: s.identifier,
+ role: s.role,
+ }
+}
+
+func (s *entity[T]) Count() cypher.Expression {
+ return cypher.NewSimpleFunctionInvocation(cypher.CountFunction, s.identifier)
+}
+
+func (s *entity[T]) SetProperties(properties map[string]any) *cypher.Set {
+ set := &cypher.Set{}
+
+ for _, key := range sortedPropertyKeys(properties) {
+ set.Items = append(set.Items, s.Property(key).Set(properties[key]))
+ }
+
+ return set
+}
+
+func (s *entity[T]) RemoveProperties(properties []string) *cypher.Remove {
+ remove := &cypher.Remove{}
+
+ for _, key := range properties {
+ remove.Items = append(remove.Items, s.Property(key).Remove())
+ }
+
+ return remove
+}
+
+func (s *entity[T]) RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) cypher.Expression {
+ return &cypher.RelationshipPattern{
+ Variable: s.identifier,
+ Kinds: graph.Kinds{kind},
+ Direction: direction,
+ Properties: properties,
+ }
+}
+
+func (s *entity[T]) NodePattern(kinds graph.Kinds, properties cypher.Expression) cypher.Expression {
+ return &cypher.NodePattern{
+ Variable: s.identifier,
+ Kinds: kinds,
+ Properties: properties,
+ }
+}
+
+func (s *entity[T]) qualifier() cypher.Expression {
+ return s.identifier
+}
+
+func (s *entity[T]) deleteTarget() {}
+
+func (s *entity[T]) roleName() string {
+ return s.role
+}
+
+func (s *entity[T]) ID() IdentityContinuation {
+ return &comparisonContinuation{
+ qualifierExpression: &cypher.FunctionInvocation{
+ Distinct: false,
+ Name: cypher.IdentityFunction,
+ Arguments: []cypher.Expression{s.identifier},
+ },
+ }
+}
+
+func (s *entity[T]) Property(propertyName string) PropertyContinuation {
+ return &propertyContinuation{
+ comparisonContinuation: comparisonContinuation{
+ qualifierExpression: cypher.NewPropertyLookup(s.identifier.Symbol, propertyName),
+ },
+ }
+}
+
+type kindContinuation struct {
+ identifier *cypher.Variable
+ role string
+}
+
+func (s kindContinuation) qualifier() cypher.Expression {
+ return s.identifier
+}
+
+func (s kindContinuation) roleName() string {
+ return s.role
+}
+
+func (s kindContinuation) Is(kind graph.Kind) cypher.Expression {
+ return s.IsOneOf(graph.Kinds{kind})
+}
+
+func (s kindContinuation) IsOneOf(kinds graph.Kinds) cypher.Expression {
+ return &cypher.KindMatcher{
+ Reference: s.identifier,
+ Kinds: kinds,
+ }
+}
+
+type kindsContinuation struct {
+ identifier *cypher.Variable
+ role string
+}
+
+func (s kindsContinuation) qualifier() cypher.Expression {
+ return s.identifier
+}
+
+func (s kindsContinuation) roleName() string {
+ return s.role
+}
+
+func (s kindsContinuation) Has(kind graph.Kind) cypher.Expression {
+ return s.HasOneOf(graph.Kinds{kind})
+}
+
+func (s kindsContinuation) HasOneOf(kinds graph.Kinds) cypher.Expression {
+ return &cypher.KindMatcher{
+ Reference: s.identifier,
+ Kinds: kinds,
+ }
+}
+
+func (s kindsContinuation) Add(kinds graph.Kinds) *cypher.SetItem {
+ return cypher.NewSetItem(
+ s.identifier,
+ cypher.OperatorLabelAssignment,
+ kinds,
+ )
+}
+
+func (s kindsContinuation) Remove(kinds graph.Kinds) *cypher.RemoveItem {
+ return cypher.RemoveKindsByMatcher(cypher.NewKindMatcher(s.identifier, kinds, false))
+}
+
+type PathContinuation interface {
+ QualifiedExpression
+
+ Count() cypher.Expression
+}
+
+type RelationshipContinuation interface {
+ EntityContinuation
+
+ RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) cypher.Expression
+
+ Kind() KindContinuation
+ SetProperties(properties map[string]any) *cypher.Set
+ RemoveProperties(properties []string) *cypher.Remove
+}
+
+type NodeContinuation interface {
+ EntityContinuation
+
+ NodePattern(kinds graph.Kinds, properties cypher.Expression) cypher.Expression
+
+ Kinds() KindsContinuation
+ SetProperties(properties map[string]any) *cypher.Set
+ RemoveProperties(properties []string) *cypher.Remove
+}
+
+type QueryBuilder interface {
+ Where(constraints ...cypher.SyntaxNode) QueryBuilder
+ OrderBy(sortItems ...any) QueryBuilder
+ Skip(offset int) QueryBuilder
+ // Limit accepts zero, which renders LIMIT 0 and returns an empty result set.
+ Limit(limit int) QueryBuilder
+ Return(projections ...any) QueryBuilder
+ ReturnDistinct(projections ...any) QueryBuilder
+ Update(updatingClauses ...any) QueryBuilder
+ Create(creationClauses ...any) QueryBuilder
+ Delete(expressions ...any) QueryBuilder
+ WithShortestPaths() QueryBuilder
+ WithAllShortestPaths() QueryBuilder
+ WithTraversalDepth(depth TraversalDepth) QueryBuilder
+ WithRelationshipDirection(direction graph.Direction) QueryBuilder
+ Build() (*PreparedQuery, error)
+}
+
+type updatingClauseKind int
+
+const (
+ updatingClauseSet updatingClauseKind = iota
+ updatingClauseRemove
+ updatingClauseDelete
+ updatingClauseCreate
+)
+
+type pendingUpdatingClause struct {
+ kind updatingClauseKind
+ creates []any
+ setItems []*cypher.SetItem
+ removeItems []*cypher.RemoveItem
+ deleteItems []cypher.Expression
+ detach bool
+}
+
+type builder struct {
+ errors []error
+ constraints []cypher.SyntaxNode
+ sortItems []any
+ projections []any
+ distinct bool
+ identifiers runtimeIdentifiers
+ updatingClauses []pendingUpdatingClause
+ creates []any
+ setItems []*cypher.SetItem
+ removeItems []*cypher.RemoveItem
+ deleteItems []cypher.Expression
+ detachDelete bool
+ relationshipDirection graph.Direction
+ traversalDepth *cypher.PatternRange
+ shortestPathQuery bool
+ allShorestPathsQuery bool
+ skip *int
+ limit *int
+}
+
+func New() QueryBuilder {
+ return DefaultScope().New()
+}
+
+func newBuilder(identifiers runtimeIdentifiers, errs ...error) QueryBuilder {
+ return &builder{
+ identifiers: identifiers,
+ errors: append([]error(nil), errs...),
+ relationshipDirection: graph.DirectionOutbound,
+ }
+}
+
+func (s *builder) WithShortestPaths() QueryBuilder {
+ s.shortestPathQuery = true
+ return s
+}
+
+func (s *builder) WithAllShortestPaths() QueryBuilder {
+ s.allShorestPathsQuery = true
+ return s
+}
+
+func (s *builder) WithTraversalDepth(depth TraversalDepth) QueryBuilder {
+ if depth.err != nil {
+ s.trackError(depth.err)
+ } else {
+ s.traversalDepth = depth.rangePattern()
+ }
+
+ return s
+}
+
+func (s *builder) WithRelationshipDirection(direction graph.Direction) QueryBuilder {
+ if err := validateRelationshipDirection(direction); err != nil {
+ s.trackError(err)
+ } else {
+ s.relationshipDirection = direction
+ }
+
+ return s
+}
+
+func (s *builder) OrderBy(sortItems ...any) QueryBuilder {
+ s.sortItems = append(s.sortItems, sortItems...)
+ return s
+}
+
+func (s *builder) Skip(skip int) QueryBuilder {
+ if skip < 0 {
+ s.trackError(fmt.Errorf("skip must be non-negative: %d", skip))
+ return s
+ }
+
+ s.skip = &skip
+ return s
+}
+
+func (s *builder) Limit(limit int) QueryBuilder {
+ if limit < 0 {
+ s.trackError(fmt.Errorf("limit must be non-negative: %d", limit))
+ return s
+ }
+
+ s.limit = &limit
+ return s
+}
+
+func (s *builder) Return(projections ...any) QueryBuilder {
+ s.projections = append(s.projections, projections...)
+ return s
+}
+
+func (s *builder) ReturnDistinct(projections ...any) QueryBuilder {
+ s.distinct = true
+ s.projections = append(s.projections, projections...)
+ return s
+}
+
+func (s *builder) appendSetItems(items ...*cypher.SetItem) {
+ if len(items) == 0 {
+ return
+ }
+
+ lastClauseIdx := len(s.updatingClauses) - 1
+ if lastClauseIdx >= 0 && s.updatingClauses[lastClauseIdx].kind == updatingClauseSet {
+ s.updatingClauses[lastClauseIdx].setItems = append(s.updatingClauses[lastClauseIdx].setItems, items...)
+ } else {
+ s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{
+ kind: updatingClauseSet,
+ setItems: items,
+ })
+ }
+}
+
+func (s *builder) appendRemoveItems(items ...*cypher.RemoveItem) {
+ if len(items) == 0 {
+ return
+ }
+
+ lastClauseIdx := len(s.updatingClauses) - 1
+ if lastClauseIdx >= 0 && s.updatingClauses[lastClauseIdx].kind == updatingClauseRemove {
+ s.updatingClauses[lastClauseIdx].removeItems = append(s.updatingClauses[lastClauseIdx].removeItems, items...)
+ } else {
+ s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{
+ kind: updatingClauseRemove,
+ removeItems: items,
+ })
+ }
+}
+
+func (s *builder) appendDeleteItems(detach bool, items ...cypher.Expression) {
+ if len(items) == 0 {
+ return
+ }
+
+ // Consecutive deletes share one clause; any node delete makes the whole clause DETACH DELETE.
+ lastClauseIdx := len(s.updatingClauses) - 1
+ if lastClauseIdx >= 0 && s.updatingClauses[lastClauseIdx].kind == updatingClauseDelete {
+ s.updatingClauses[lastClauseIdx].detach = s.updatingClauses[lastClauseIdx].detach || detach
+ s.updatingClauses[lastClauseIdx].deleteItems = append(s.updatingClauses[lastClauseIdx].deleteItems, items...)
+ } else {
+ s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{
+ kind: updatingClauseDelete,
+ deleteItems: items,
+ detach: detach,
+ })
+ }
+}
+
+func (s *builder) Create(creationClauses ...any) QueryBuilder {
+ s.creates = append(s.creates, creationClauses...)
+
+ if len(creationClauses) > 0 {
+ s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{
+ kind: updatingClauseCreate,
+ creates: creationClauses,
+ })
+ }
+
+ return s
+}
+
+func (s *builder) Update(updates ...any) QueryBuilder {
+ for _, nextUpdate := range updates {
+ switch typedNextUpdate := nextUpdate.(type) {
+ case *cypher.Set:
+ if setItems, err := setItemsFromSet(typedNextUpdate); err != nil {
+ s.trackError(err)
+ } else {
+ s.setItems = append(s.setItems, setItems...)
+ s.appendSetItems(setItems...)
+ }
+
+ case *cypher.SetItem:
+ if setItem, err := setItemFromValue(typedNextUpdate); err != nil {
+ s.trackError(err)
+ } else {
+ s.setItems = append(s.setItems, setItem)
+ s.appendSetItems(setItem)
+ }
+
+ case *cypher.Remove:
+ if removeItems, err := removeItemsFromRemove(typedNextUpdate); err != nil {
+ s.trackError(err)
+ } else {
+ s.removeItems = append(s.removeItems, removeItems...)
+ s.appendRemoveItems(removeItems...)
+ }
+
+ case *cypher.RemoveItem:
+ if removeItem, err := removeItemFromValue(typedNextUpdate); err != nil {
+ s.trackError(err)
+ } else {
+ s.removeItems = append(s.removeItems, removeItem)
+ s.appendRemoveItems(removeItem)
+ }
+
+ default:
+ s.trackError(fmt.Errorf("unknown update type: %T", nextUpdate))
+ }
+ }
+
+ return s
+}
+
+func (s *builder) Delete(deleteItems ...any) QueryBuilder {
+ var pendingDeleteItems []cypher.Expression
+ pendingDetachDelete := false
+
+ for _, nextDelete := range deleteItems {
+ switch typedNextUpdate := nextDelete.(type) {
+ case deleteTarget:
+ if isNilPointer(typedNextUpdate) {
+ s.trackError(fmt.Errorf("delete target is nil"))
+ continue
+ }
+
+ deleteItem, detach, err := deleteItemFromExpression(typedNextUpdate.qualifier(), s.identifiers)
+ if err != nil {
+ s.trackError(err)
+ continue
+ }
+
+ if detach {
+ s.detachDelete = true
+ pendingDetachDelete = true
+ }
+
+ s.deleteItems = append(s.deleteItems, deleteItem)
+ pendingDeleteItems = append(pendingDeleteItems, deleteItem)
+
+ case *cypher.Variable:
+ deleteItem, detach, err := deleteItemFromExpression(typedNextUpdate, s.identifiers)
+ if err != nil {
+ s.trackError(err)
+ continue
+ }
+
+ if detach {
+ s.detachDelete = true
+ pendingDetachDelete = true
+ }
+
+ s.deleteItems = append(s.deleteItems, deleteItem)
+ pendingDeleteItems = append(pendingDeleteItems, deleteItem)
+
+ case *cypher.PropertyLookup:
+ if err := validateExpressionValue(typedNextUpdate, "delete expression"); err != nil {
+ s.trackError(err)
+ continue
+ }
+
+ s.trackError(fmt.Errorf("delete target must be a node, relationship, or variable; use remove for properties"))
+
+ case QualifiedExpression:
+ if isNilPointer(typedNextUpdate) {
+ s.trackError(fmt.Errorf("delete target is nil"))
+ continue
+ }
+
+ if err := validateExpressionValue(typedNextUpdate.qualifier(), "delete expression"); err != nil {
+ s.trackError(err)
+ continue
+ }
+
+ s.trackError(fmt.Errorf("delete target must be a node, relationship, or variable; got %T", nextDelete))
+
+ default:
+ s.trackError(fmt.Errorf("unknown delete type: %T", nextDelete))
+ }
+ }
+
+ s.appendDeleteItems(pendingDetachDelete, pendingDeleteItems...)
+ return s
+}
+
+func deleteItemFromExpression(expression cypher.Expression, identifiers runtimeIdentifiers) (cypher.Expression, bool, error) {
+ if err := validateExpressionValue(expression, "delete expression"); err != nil {
+ return nil, false, err
+ }
+
+ variable, typeOK := expression.(*cypher.Variable)
+ if !typeOK || variable == nil {
+ return nil, false, fmt.Errorf("delete target must resolve to a variable, got %T", expression)
+ }
+
+ if variable.Symbol == identifiers.path {
+ return nil, false, fmt.Errorf("delete target must be a node or relationship variable, got path variable %q", variable.Symbol)
+ }
+
+ return copyExpression(variable), isDetachDeleteQualifier(variable, identifiers), nil
+}
+
+func (s *builder) trackError(err error) {
+ s.errors = append(s.errors, err)
+}
+
+func (s *builder) Where(constraints ...cypher.SyntaxNode) QueryBuilder {
+ s.constraints = append(s.constraints, constraints...)
+ return s
+}
+
+func patternEndsWithNodePattern(pattern *cypher.PatternPart) bool {
+ numElements := len(pattern.PatternElements)
+ if numElements == 0 {
+ return false
+ }
+
+ return pattern.PatternElements[numElements-1].IsNodePattern()
+}
+
+func isCreateNodeValue(value any, identifiers runtimeIdentifiers) bool {
+ switch typedValue := value.(type) {
+ case QualifiedExpression:
+ if variable, typeOK := typedValue.qualifier().(*cypher.Variable); typeOK {
+ switch variable.Symbol {
+ case identifiers.node, identifiers.start, identifiers.end:
+ return true
+ }
+ }
+
+ case *cypher.NodePattern:
+ return typedValue != nil
+ }
+
+ return false
+}
+
+func isCreateRelationshipValue(value any) bool {
+ _, typeOK := value.(*cypher.RelationshipPattern)
+ return typeOK
+}
+
+func nextCreateValueIsNode(creates []any, idx int, identifiers runtimeIdentifiers) bool {
+ nextIdx := idx + 1
+ return nextIdx < len(creates) && isCreateNodeValue(creates[nextIdx], identifiers)
+}
+
+func newCreatePatternPart(createClause *cypher.Create) *cypher.PatternPart {
+ pattern := &cypher.PatternPart{}
+ createClause.Pattern = append(createClause.Pattern, pattern)
+ return pattern
+}
+
+func createPatternHasElements(pattern *cypher.PatternPart) bool {
+ return pattern != nil && len(pattern.PatternElements) > 0
+}
+
+func shouldStartNewCreatePattern(pattern *cypher.PatternPart, nextCreate any, patternClosed bool, identifiers runtimeIdentifiers) bool {
+ if !createPatternHasElements(pattern) {
+ return false
+ }
+
+ if isCreateNodeValue(nextCreate, identifiers) && patternEndsWithNodePattern(pattern) {
+ return true
+ }
+
+ return patternClosed && isCreateRelationshipValue(nextCreate)
+}
+
+func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeIdentifiers, creates []any) error {
+ if len(creates) == 0 {
+ return nil
+ }
+
+ var (
+ createClause = &cypher.Create{
+ Unique: false,
+ }
+ pattern = newCreatePatternPart(createClause)
+ patternClosed bool
+ )
+
+ for idx, nextCreate := range creates {
+ if shouldStartNewCreatePattern(pattern, nextCreate, patternClosed, identifiers) {
+ pattern = newCreatePatternPart(createClause)
+ patternClosed = false
+ }
+
+ switch typedNextCreate := nextCreate.(type) {
+ case QualifiedExpression:
+ switch typedExpression := typedNextCreate.qualifier().(type) {
+ case *cypher.Variable:
+ if typedExpression == nil {
+ return fmt.Errorf("invalid variable reference for create: ")
+ }
+
+ switch typedExpression.Symbol {
+ case identifiers.node, identifiers.start, identifiers.end:
+ pattern.AddPatternElements(&cypher.NodePattern{
+ Variable: cypher.NewVariableWithSymbol(typedExpression.Symbol),
+ })
+ patternClosed = false
+
+ default:
+ return fmt.Errorf("invalid variable reference for create: %s", typedExpression.Symbol)
+ }
+
+ default:
+ return fmt.Errorf("invalid qualified expression for create: %T", typedExpression)
+ }
+
+ case *cypher.NodePattern:
+ if err := validateNodePattern(typedNextCreate); err != nil {
+ return err
+ }
+
+ pattern.AddPatternElements(cypher.Copy(typedNextCreate))
+ patternClosed = false
+
+ case *cypher.RelationshipPattern:
+ if err := validateRelationshipPattern(typedNextCreate); err != nil {
+ return err
+ }
+
+ if !patternEndsWithNodePattern(pattern) {
+ pattern.AddPatternElements(&cypher.NodePattern{
+ Variable: identifiers.Start(),
+ })
+ }
+
+ pattern.AddPatternElements(cypher.Copy(typedNextCreate))
+
+ if !nextCreateValueIsNode(creates, idx, identifiers) {
+ pattern.AddPatternElements(&cypher.NodePattern{
+ Variable: identifiers.End(),
+ })
+ patternClosed = true
+ } else {
+ patternClosed = false
+ }
+
+ default:
+ return fmt.Errorf("invalid type for create: %T", nextCreate)
+ }
+ }
+
+ singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause(createClause))
+ return nil
+}
+
+func (s *builder) buildUpdatingClauses(singlePartQuery *cypher.SinglePartQuery) error {
+ for _, updatingClause := range s.updatingClauses {
+ switch updatingClause.kind {
+ case updatingClauseSet:
+ singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause(
+ cypher.NewSet(updatingClause.setItems),
+ ))
+
+ case updatingClauseRemove:
+ singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause(
+ cypher.NewRemove(updatingClause.removeItems),
+ ))
+
+ case updatingClauseDelete:
+ singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause(
+ cypher.NewDelete(
+ updatingClause.detach,
+ updatingClause.deleteItems,
+ ),
+ ))
+
+ case updatingClauseCreate:
+ if err := buildCreates(singlePartQuery, s.identifiers, updatingClause.creates); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+func (s *builder) buildProjectionOrder() (*cypher.Order, error) {
+ var orderByNode *cypher.Order
+
+ if len(s.sortItems) > 0 {
+ orderByNode = &cypher.Order{}
+
+ for _, untypedSortItem := range s.sortItems {
+ switch typedSortItem := untypedSortItem.(type) {
+ case *cypher.Order:
+ if sortItems, err := sortItemsFromOrder(typedSortItem); err != nil {
+ return nil, err
+ } else {
+ orderByNode.Items = append(orderByNode.Items, sortItems...)
+ }
+
+ case *cypher.SortItem:
+ if sortItem, err := sortItemFromValue(typedSortItem); err != nil {
+ return nil, err
+ } else {
+ orderByNode.Items = append(orderByNode.Items, sortItem)
+ }
+
+ default:
+ if sortItem, err := sortItemFromValue(typedSortItem); err != nil {
+ return nil, err
+ } else {
+ orderByNode.Items = append(orderByNode.Items, sortItem)
+ }
+ }
+ }
+ }
+
+ return orderByNode, nil
+}
+
+func appendProjectionOrder(projection *cypher.Projection, sortItems ...*cypher.SortItem) {
+ if len(sortItems) == 0 {
+ return
+ }
+
+ if projection.Order == nil {
+ projection.Order = &cypher.Order{}
+ }
+
+ projection.Order.Items = append(projection.Order.Items, sortItems...)
+}
+
+func applyReturnProjection(projection *cypher.Projection, returnClause *cypher.Return) error {
+ if projectionItems, err := projectionItemsFromReturn(returnClause); err != nil {
+ return err
+ } else {
+ projection.Distinct = projection.Distinct || returnClause.Projection.Distinct
+ projection.All = projection.All || returnClause.Projection.All
+
+ for _, projectionItem := range projectionItems {
+ projection.AddItem(projectionItem)
+ }
+ }
+
+ if returnClause.Projection.Order != nil {
+ if sortItems, err := sortItemsFromOrder(returnClause.Projection.Order); err != nil {
+ return err
+ } else {
+ appendProjectionOrder(projection, sortItems...)
+ }
+ }
+
+ if returnClause.Projection.Skip != nil {
+ projection.Skip = copySkip(returnClause.Projection.Skip)
+ }
+
+ if returnClause.Projection.Limit != nil {
+ projection.Limit = copyLimit(returnClause.Projection.Limit)
+ }
+
+ return nil
+}
+
+func (s *builder) buildProjection(singlePartQuery *cypher.SinglePartQuery) error {
+ var (
+ hasProjectedItems = len(s.projections) > 0
+ hasSkip = s.skip != nil
+ hasLimit = s.limit != nil
+ requiresProjection = hasProjectedItems || hasSkip || hasLimit
+ )
+
+ if requiresProjection {
+ if !hasProjectedItems {
+ return fmt.Errorf("query expected projected items")
+ }
+
+ projection := singlePartQuery.NewProjection(s.distinct)
+
+ for _, nextProjection := range s.projections {
+ switch typedNextProjection := nextProjection.(type) {
+ case *cypher.Return:
+ if err := applyReturnProjection(projection, typedNextProjection); err != nil {
+ return err
+ }
+
+ default:
+ if projectionItem, err := projectionItemFromValue(typedNextProjection); err != nil {
+ return err
+ } else {
+ projection.AddItem(projectionItem)
+ }
+ }
+ }
+
+ if s.skip != nil {
+ projection.Skip = cypher.NewSkip(*s.skip)
+ }
+
+ if s.limit != nil {
+ projection.Limit = cypher.NewLimit(*s.limit)
+ }
+
+ if projectionOrder, err := s.buildProjectionOrder(); err != nil {
+ return err
+ } else if projectionOrder != nil {
+ appendProjectionOrder(projection, projectionOrder.Items...)
+ }
+ }
+
+ return nil
+}
+
+func countRelationshipKindMatchers(constraints []cypher.SyntaxNode, identifiers runtimeIdentifiers) (int, error) {
+ var count int
+
+ for _, nextConstraint := range constraints {
+ if kindMatcher, typeOK := nextConstraint.(*cypher.KindMatcher); typeOK {
+ if identifier, typeOK := kindMatcher.Reference.(*cypher.Variable); !typeOK {
+ return 0, fmt.Errorf("expected type *cypher.Variable, got %T", kindMatcher.Reference)
+ } else if identifier.Symbol == identifiers.relationship {
+ count++
+ }
+ }
+ }
+
+ return count, nil
+}
+
+type PreparedQuery struct {
+ Query *cypher.RegularQuery
+ Parameters map[string]any
+}
+
+func (s *builder) hasActions() bool {
+ return len(s.projections) > 0 || len(s.setItems) > 0 || len(s.removeItems) > 0 || len(s.creates) > 0 || len(s.deleteItems) > 0
+}
+
+func (s *builder) wantsShortestPathPattern() bool {
+ return s.shortestPathQuery || s.allShorestPathsQuery
+}
+
+func (s *builder) wantsTraversalPattern() bool {
+ return s.traversalDepth != nil
+}
+
+func (s *builder) usesRangedRelationshipPattern() bool {
+ return s.wantsTraversalPattern() || s.wantsShortestPathPattern()
+}
+
+func (s *builder) Build() (*PreparedQuery, error) {
+ if len(s.errors) > 0 {
+ return nil, errors.Join(s.errors...)
+ }
+
+ if !s.hasActions() {
+ return nil, fmt.Errorf("query has no action specified")
+ }
+
+ if err := collectModelErrorsFromKnownValues(s.constraints, s.creates, s.setItems, s.removeItems, s.deleteItems, s.projections, s.sortItems); err != nil {
+ return nil, err
+ }
+
+ var (
+ regularQuery, singlePartQuery = cypher.NewRegularQueryWithSingleQuery()
+ match = &cypher.Match{}
+ readIdentifiers = newIdentifierSet()
+ relationshipKinds graph.Kinds
+ )
+
+ createScope, err := collectCreateScope(s.identifiers, s.creates...)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := s.buildUpdatingClauses(singlePartQuery); err != nil {
+ return nil, err
+ }
+
+ if err := s.buildProjection(singlePartQuery); err != nil {
+ return nil, err
+ }
+
+ if len(s.constraints) > 0 {
+ var (
+ whereClause = match.NewWhere()
+ constraints = cypher.NewConjunction()
+ numRelationshipKindMatchers, err = countRelationshipKindMatchers(s.constraints, s.identifiers)
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ for _, nextConstraint := range s.constraints {
+ switch typedNextConstraint := nextConstraint.(type) {
+ case *cypher.KindMatcher:
+ if identifier, typeOK := typedNextConstraint.Reference.(*cypher.Variable); !typeOK {
+ return nil, fmt.Errorf("expected type *cypher.Variable, got %T", typedNextConstraint.Reference)
+ } else if identifier.Symbol == s.identifiers.relationship && numRelationshipKindMatchers == 1 {
+ relationshipKinds = relationshipKinds.Add(typedNextConstraint.Kinds...)
+ readIdentifiers.Add(s.identifiers.relationship)
+ continue
+ }
+ }
+
+ constraintCopy := cypher.Copy(nextConstraint)
+ constraints.Add(parenthesizeDisjunctiveExpression(constraintCopy))
+ }
+
+ if constraints.Len() > 0 {
+ whereClause.Add(constraints)
+
+ whereIdentifiers := newIdentifierSet()
+ if err := whereIdentifiers.CollectFromExpression(whereClause); err != nil {
+ return nil, err
+ }
+
+ if s.usesRangedRelationshipPattern() && whereIdentifiers.Contains(s.identifiers.relationship) {
+ return nil, fmt.Errorf("ranged relationship patterns only support top-level relationship kind constraints")
+ }
+
+ readIdentifiers.Or(whereIdentifiers)
+ }
+ }
+
+ actionIdentifiers, err := collectIdentifiersFromValues(s.setItems, s.removeItems, s.deleteItems, s.projections, s.sortItems)
+ if err != nil {
+ return nil, err
+ }
+
+ actionIdentifiers.Remove(createScope.identifiers)
+
+ if s.usesRangedRelationshipPattern() && actionIdentifiers.Contains(s.identifiers.relationship) {
+ return nil, fmt.Errorf("ranged relationship patterns do not support relationship projections or mutations; return the path instead")
+ }
+
+ matchIdentifiers := readIdentifiers.Clone()
+ matchIdentifiers.Or(actionIdentifiers)
+
+ if err := validateKnownIdentifiers(matchIdentifiers, s.identifiers); err != nil {
+ return nil, err
+ }
+
+ if s.wantsTraversalPattern() && !isRelationshipPattern(matchIdentifiers, s.identifiers) && !matchIdentifiers.Contains(s.identifiers.path) {
+ return nil, fmt.Errorf("recursive traversal query requires relationship query identifiers")
+ }
+
+ if s.wantsShortestPathPattern() && !isRelationshipPattern(matchIdentifiers, s.identifiers) {
+ return nil, fmt.Errorf("shortest path query requires relationship query identifiers")
+ }
+
+ if len(s.constraints) > 0 || matchIdentifiers.Len() > 0 {
+ if isNodePattern(matchIdentifiers, s.identifiers) {
+ if err := prepareNodePattern(match, matchIdentifiers, s.identifiers); err != nil {
+ return nil, err
+ }
+ } else if createScope.createsRelationship && !matchIdentifiers.Contains(s.identifiers.relationship) {
+ if err := prepareCreateRelationshipMatch(match, matchIdentifiers, s.identifiers); err != nil {
+ return nil, err
+ }
+ } else if isRelationshipPattern(matchIdentifiers, s.identifiers) || (s.wantsTraversalPattern() && matchIdentifiers.Contains(s.identifiers.path)) {
+ if err := prepareRelationshipPattern(match, matchIdentifiers, s.identifiers, relationshipKinds, s.traversalDepth, s.relationshipDirection, s.shortestPathQuery, s.allShorestPathsQuery); err != nil {
+ return nil, err
+ }
+ } else {
+ return nil, fmt.Errorf("query has no node and relationship query identifiers specified")
+ }
+ }
+
+ if len(match.Pattern) > 0 {
+ newReadingClause := cypher.NewReadingClause()
+ newReadingClause.Match = match
+
+ singlePartQuery.ReadingClauses = append(singlePartQuery.ReadingClauses, newReadingClause)
+ }
+
+ if err := collectModelErrors(regularQuery); err != nil {
+ return nil, err
+ }
+
+ if parameters, err := materializeParameters(regularQuery); err != nil {
+ return nil, err
+ } else {
+ return &PreparedQuery{
+ Query: regularQuery,
+ Parameters: parameters,
+ }, nil
+ }
+}
diff --git a/query/v2/query_test.go b/query/v2/query_test.go
new file mode 100644
index 00000000..61ea0d32
--- /dev/null
+++ b/query/v2/query_test.go
@@ -0,0 +1,985 @@
+package v2_test
+
+import (
+ "testing"
+
+ "github.com/specterops/dawgs/cypher/models/cypher"
+ "github.com/specterops/dawgs/cypher/models/cypher/format"
+ "github.com/specterops/dawgs/graph"
+ v2 "github.com/specterops/dawgs/query/v2"
+ "github.com/stretchr/testify/require"
+)
+
+func renderPrepared(t *testing.T, preparedQuery *v2.PreparedQuery) string {
+ t.Helper()
+
+ cypherQueryStr, err := format.RegularQuery(preparedQuery.Query, false)
+ require.NoError(t, err)
+
+ return cypherQueryStr
+}
+
+func firstCreateClause(t *testing.T, preparedQuery *v2.PreparedQuery) *cypher.Create {
+ t.Helper()
+
+ updatingClauses := preparedQuery.Query.SingleQuery.SinglePartQuery.UpdatingClauses
+ require.NotEmpty(t, updatingClauses)
+
+ updatingClause, typeOK := updatingClauses[0].(*cypher.UpdatingClause)
+ require.True(t, typeOK)
+
+ createClause, typeOK := updatingClause.Clause.(*cypher.Create)
+ require.True(t, typeOK)
+
+ return createClause
+}
+
+func TestQuery(t *testing.T) {
+ preparedQuery, err := v2.New().Where(
+ v2.Not(v2.Relationship().Kind().Is(graph.StringKind("test"))),
+ v2.Not(v2.Relationship().Kind().IsOneOf(graph.Kinds{graph.StringKind("A"), graph.StringKind("B")})),
+ v2.Relationship().Property("rel_prop").LessThanOrEqualTo(1234),
+ v2.Relationship().Property("other_prop").Equals(5678),
+ v2.Start().Kinds().HasOneOf(graph.Kinds{graph.StringKind("test")}),
+ ).Update(
+ v2.Start().Property("this_prop").Set(1234),
+ v2.End().Kinds().Remove(graph.Kinds{graph.StringKind("A"), graph.StringKind("B")}),
+ ).Delete(
+ v2.Start(),
+ ).Return(
+ v2.Relationship(),
+ v2.Start().Property("node_prop"),
+ ).Skip(10).Limit(10).Build()
+ require.NoError(t, err)
+
+ cypherQueryStr, err := format.RegularQuery(preparedQuery.Query, false)
+ require.NoError(t, err)
+
+ require.Equal(t, "match (s)-[r]->(e) where not r:test and not (r:A or r:B) and r.rel_prop <= $p0 and r.other_prop = $p1 and s:test set s.this_prop = $p2 remove e:A:B detach delete s return r, s.node_prop skip 10 limit 10", cypherQueryStr)
+ require.Equal(t, map[string]any{
+ "p0": 1234,
+ "p1": 5678,
+ "p2": 1234,
+ }, preparedQuery.Parameters)
+
+ preparedQuery, err = v2.New().Create(
+ v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, cypher.NewParameter("props", map[string]any{})),
+ ).Build()
+
+ require.NoError(t, err)
+
+ cypherQueryStr, err = format.RegularQuery(preparedQuery.Query, false)
+ require.NoError(t, err)
+
+ require.Equal(t, "create (n:A $props)", cypherQueryStr)
+ require.Equal(t, map[string]any{
+ "props": map[string]any{},
+ }, preparedQuery.Parameters)
+}
+
+func TestCreateRelationshipWithMatchedEndpoints(t *testing.T) {
+ preparedQuery, err := v2.New().Where(
+ v2.Start().ID().Equals(1),
+ v2.End().ID().Equals(2),
+ ).Create(
+ v2.Relationship().RelationshipPattern(graph.StringKind("A"), v2.NamedParameter("props", map[string]any{"name": "rel"}), graph.DirectionOutbound),
+ ).Return(
+ v2.Relationship().ID(),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (s), (e) where id(s) = $p0 and id(e) = $p1 create (s)-[r:A $props]->(e) return id(r)", renderPrepared(t, preparedQuery))
+ require.Equal(t, map[string]any{
+ "p0": 1,
+ "p1": 2,
+ "props": map[string]any{"name": "rel"},
+ }, preparedQuery.Parameters)
+}
+
+func TestCreateRelationshipWithExplicitEndpoints(t *testing.T) {
+ preparedQuery, err := v2.New().Where(
+ v2.Start().ID().Equals(1),
+ v2.End().ID().Equals(2),
+ ).Create(
+ v2.Start(),
+ v2.RelationshipPattern(graph.StringKind("A"), v2.NamedParameter("props", map[string]any{"name": "rel"}), graph.DirectionOutbound),
+ v2.End(),
+ ).Return(
+ v2.Relationship().ID(),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (s), (e) where id(s) = $p0 and id(e) = $p1 create (s)-[r:A $props]->(e) return id(r)", renderPrepared(t, preparedQuery))
+ require.Equal(t, map[string]any{
+ "p0": 1,
+ "p1": 2,
+ "props": map[string]any{"name": "rel"},
+ }, preparedQuery.Parameters)
+}
+
+func TestCreateSplitsDisjointNodePatterns(t *testing.T) {
+ preparedQuery, err := v2.New().Create(
+ v2.NodePattern(graph.Kinds{graph.StringKind("A")}, nil),
+ v2.NodePattern(graph.Kinds{graph.StringKind("B")}, nil),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "create (n:A), (n:B)", renderPrepared(t, preparedQuery))
+ require.Len(t, firstCreateClause(t, preparedQuery).Pattern, 2)
+}
+
+func TestCreateSplitsBackToBackRelationshipPatterns(t *testing.T) {
+ preparedQuery, err := v2.New().Create(
+ v2.RelationshipPattern(graph.StringKind("A"), nil, graph.DirectionOutbound),
+ v2.RelationshipPattern(graph.StringKind("B"), nil, graph.DirectionOutbound),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "create (s)-[r:A]->(e), (s)-[r:B]->(e)", renderPrepared(t, preparedQuery))
+ require.Len(t, firstCreateClause(t, preparedQuery).Pattern, 2)
+}
+
+func TestCreateNodeReturnDoesNotCreateMatch(t *testing.T) {
+ preparedQuery, err := v2.New().Create(
+ v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, v2.NamedParameter("props", map[string]any{"name": "node"})),
+ ).Return(
+ v2.Node().ID(),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "create (n:A $props) return id(n)", renderPrepared(t, preparedQuery))
+ require.Equal(t, map[string]any{
+ "props": map[string]any{"name": "node"},
+ }, preparedQuery.Parameters)
+}
+
+func TestLogicalHelpersPreservePrecedence(t *testing.T) {
+ a := v2.Node().Property("a").Equals("a")
+ b := v2.Node().Property("b").Equals("b")
+ c := v2.Node().Property("c").Equals("c")
+
+ testCases := []struct {
+ name string
+ builder v2.QueryBuilder
+ expected string
+ }{
+ {
+ name: "or is parenthesized in isolation",
+ builder: v2.New().Where(
+ v2.Or(a, b),
+ ).Return(v2.Node()),
+ expected: "match (n) where (n.a = $p0 or n.b = $p1) return n",
+ },
+ {
+ name: "or is parenthesized when where and-chains constraints",
+ builder: v2.New().Where(
+ v2.Or(a, b),
+ c,
+ ).Return(v2.Node()),
+ expected: "match (n) where (n.a = $p0 or n.b = $p1) and n.c = $p2 return n",
+ },
+ {
+ name: "nested or is parenthesized inside and",
+ builder: v2.New().Where(
+ v2.And(a, v2.Or(b, c)),
+ ).Return(v2.Node()),
+ expected: "match (n) where n.a = $p0 and (n.b = $p1 or n.c = $p2) return n",
+ },
+ {
+ name: "nested and is parenthesized inside or",
+ builder: v2.New().Where(
+ v2.Or(v2.And(a, b), c),
+ ).Return(v2.Node()),
+ expected: "match (n) where ((n.a = $p0 and n.b = $p1) or n.c = $p2) return n",
+ },
+ {
+ name: "not wraps or",
+ builder: v2.New().Where(
+ v2.Not(v2.Or(a, b)),
+ ).Return(v2.Node()),
+ expected: "match (n) where not (n.a = $p0 or n.b = $p1) return n",
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.name, func(t *testing.T) {
+ preparedQuery, err := testCase.builder.Build()
+ require.NoError(t, err)
+ require.Equal(t, testCase.expected, renderPrepared(t, preparedQuery))
+ })
+ }
+}
+
+func TestInvalidCreateQualifiedExpressionReturnsError(t *testing.T) {
+ _, err := v2.New().Create(v2.Node().Property("name")).Build()
+ require.ErrorContains(t, err, "invalid qualified expression for create: *cypher.PropertyLookup")
+}
+
+func TestUpdatingClausesPreserveFluentOrder(t *testing.T) {
+ preparedQuery, err := v2.New().Create(
+ v2.NodePattern(graph.Kinds{graph.StringKind("User")}, nil),
+ ).Update(
+ v2.SetProperty(v2.Node().Property("name"), "created"),
+ ).Return(
+ v2.Node().Property("name"),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "create (n:User) set n.name = $p0 return n.name", renderPrepared(t, preparedQuery))
+ require.Equal(t, map[string]any{
+ "p0": "created",
+ }, preparedQuery.Parameters)
+
+ preparedQuery, err = v2.New().Where(
+ v2.Node().ID().Equals(1),
+ ).Update(
+ v2.DeleteProperties(v2.Node(), "old"),
+ ).Update(
+ v2.SetProperty(v2.Node().Property("new"), "value"),
+ ).Return(
+ v2.Node(),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (n) where id(n) = $p0 remove n.old set n.new = $p1 return n", renderPrepared(t, preparedQuery))
+ require.Equal(t, map[string]any{
+ "p0": 1,
+ "p1": "value",
+ }, preparedQuery.Parameters)
+}
+
+func TestScopedRelationshipPatternControls(t *testing.T) {
+ scope := v2.NewScope("path", "person", "source", "edge", "target")
+
+ preparedQuery, err := scope.New().WithRelationshipDirection(graph.DirectionInbound).Where(
+ scope.Relationship().Kind().Is(graph.StringKind("MemberOf")),
+ scope.Start().ID().Equals(1),
+ ).Return(
+ scope.Relationship().Kind(),
+ scope.End().Kinds(),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (source)<-[edge:MemberOf]-(target) where id(source) = $p0 return type(edge), labels(target)", renderPrepared(t, preparedQuery))
+ require.Equal(t, map[string]any{
+ "p0": 1,
+ }, preparedQuery.Parameters)
+}
+
+func TestScopedKindsOfCompatibilityHelper(t *testing.T) {
+ scope := v2.NewScope("path", "person", "source", "edge", "target")
+
+ preparedQuery, err := scope.New().Return(
+ v2.KindsOf(scope.Relationship()),
+ v2.KindsOf(scope.End()),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match ()-[edge]->(target) return type(edge), labels(target)", renderPrepared(t, preparedQuery))
+}
+
+func TestInvalidScopeAliasesReturnBuildErrors(t *testing.T) {
+ emptyAliasScope := v2.NewScope("", "node", "start", "relationship", "end")
+ _, err := emptyAliasScope.New().Return(emptyAliasScope.Node()).Build()
+ require.ErrorContains(t, err, "scope alias path is empty")
+
+ duplicateAliasScope := v2.NewScope("path", "node", "node", "relationship", "end")
+ _, err = duplicateAliasScope.New().Return(duplicateAliasScope.Start()).Build()
+ require.ErrorContains(t, err, `scope aliases node and start both use "node"`)
+
+ invalidAliasScope := v2.NewScope("path", "bad name", "start", "relationship", "end")
+ _, err = invalidAliasScope.New().Return(invalidAliasScope.Node()).Build()
+ require.ErrorContains(t, err, `scope alias node has invalid symbol "bad name"`)
+}
+
+func TestUnicodeCypherSymbols(t *testing.T) {
+ scope := v2.NewScope("路径", "节点", "起点", "关系", "终点")
+
+ preparedQuery, err := scope.New().Where(
+ scope.Node().Property("name").Equals(v2.NamedParameter("名字", "alice")),
+ ).Return(
+ v2.As(scope.Node().ID(), "标识"),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (节点) where 节点.name = $名字 return id(节点) as 标识", renderPrepared(t, preparedQuery))
+ require.Equal(t, map[string]any{
+ "名字": "alice",
+ }, preparedQuery.Parameters)
+}
+
+func TestInvalidRelationshipDirectionReturnsError(t *testing.T) {
+ _, err := v2.New().WithRelationshipDirection(graph.Direction(99)).Return(v2.Relationship()).Build()
+ require.ErrorContains(t, err, "unsupported relationship direction: invalid")
+}
+
+func TestRelationshipDirectionBoth(t *testing.T) {
+ preparedQuery, err := v2.New().WithRelationshipDirection(graph.DirectionBoth).Return(v2.Relationship()).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match ()-[r]-() return r", renderPrepared(t, preparedQuery))
+}
+
+func TestShortestPathControls(t *testing.T) {
+ preparedQuery, err := v2.New().WithShortestPaths().Where(
+ v2.Relationship().Kind().Is(graph.StringKind("MemberOf")),
+ v2.Start().ID().Equals(1),
+ v2.End().ID().Equals(2),
+ ).Return(
+ v2.Path(),
+ ).Build()
+ require.NoError(t, err)
+ require.Equal(t, "match p = shortestPath((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery))
+ require.Equal(t, map[string]any{
+ "p0": 1,
+ "p1": 2,
+ }, preparedQuery.Parameters)
+
+ preparedQuery, err = v2.New().WithAllShortestPaths().Where(
+ v2.Relationship().Kind().Is(graph.StringKind("MemberOf")),
+ v2.Start().ID().Equals(1),
+ v2.End().ID().Equals(2),
+ ).Return(
+ v2.Path(),
+ ).Build()
+ require.NoError(t, err)
+ require.Equal(t, "match p = allShortestPaths((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery))
+
+ preparedQuery, err = v2.New().WithShortestPaths().WithTraversalDepth(v2.MinDepth(1)).Where(
+ v2.Relationship().Kind().Is(graph.StringKind("MemberOf")),
+ v2.Start().ID().Equals(1),
+ v2.End().ID().Equals(2),
+ ).Return(
+ v2.Path(),
+ ).Build()
+ require.NoError(t, err)
+ require.Equal(t, "match p = shortestPath((s)-[r:MemberOf*1..]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery))
+
+ _, err = v2.New().WithShortestPaths().WithAllShortestPaths().Where(
+ v2.Start().ID().Equals(1),
+ v2.End().ID().Equals(2),
+ ).Return(
+ v2.Path(),
+ ).Build()
+ require.ErrorContains(t, err, "query is requesting both all shortest paths and shortest paths")
+
+ _, err = v2.New().WithShortestPaths().Return(v2.Node()).Build()
+ require.ErrorContains(t, err, "shortest path query requires relationship query identifiers")
+
+ _, err = v2.New().WithAllShortestPaths().Return(v2.As(v2.Literal(1), "one")).Build()
+ require.ErrorContains(t, err, "shortest path query requires relationship query identifiers")
+}
+
+func TestTraversalDepthControls(t *testing.T) {
+ cases := map[string]struct {
+ builder v2.QueryBuilder
+ expectedCypher string
+ expectedParams map[string]any
+ }{
+ "any depth": {
+ builder: v2.New().WithTraversalDepth(v2.AnyDepth()).Return(v2.End()),
+ expectedCypher: "match ()-[*]->(e) return e",
+ },
+ "minimum depth": {
+ builder: v2.New().WithTraversalDepth(v2.MinDepth(1)).Return(v2.End()),
+ expectedCypher: "match ()-[*1..]->(e) return e",
+ },
+ "maximum depth": {
+ builder: v2.New().WithTraversalDepth(v2.MaxDepth(5)).Return(v2.End()),
+ expectedCypher: "match ()-[*..5]->(e) return e",
+ },
+ "depth range": {
+ builder: v2.New().WithTraversalDepth(v2.DepthRange(1, 5)).Where(
+ v2.Relationship().Kind().IsOneOf(graph.Kinds{graph.StringKind("KindA"), graph.StringKind("KindB")}),
+ v2.Start().ID().Equals(1),
+ v2.End().Kinds().Has(graph.StringKind("User")),
+ ).Return(
+ v2.Path(),
+ v2.End(),
+ ),
+ expectedCypher: "match p = (s)-[r:KindA|KindB*1..5]->(e) where id(s) = $p0 and e:User return p, e",
+ expectedParams: map[string]any{"p0": 1},
+ },
+ "exact depth": {
+ builder: v2.New().WithTraversalDepth(v2.ExactDepth(3)).Return(v2.End()),
+ expectedCypher: "match ()-[*3..3]->(e) return e",
+ },
+ "inbound depth range": {
+ builder: v2.New().WithTraversalDepth(v2.DepthRange(2, 5)).WithRelationshipDirection(graph.DirectionInbound).Return(v2.Start()),
+ expectedCypher: "match (s)<-[*2..5]-() return s",
+ },
+ "path only": {
+ builder: v2.New().WithTraversalDepth(v2.AnyDepth()).Return(v2.Path()),
+ expectedCypher: "match p = ()-[*]->() return p",
+ },
+ }
+
+ for name, testCase := range cases {
+ t.Run(name, func(t *testing.T) {
+ preparedQuery, err := testCase.builder.Build()
+ require.NoError(t, err)
+ require.Equal(t, testCase.expectedCypher, renderPrepared(t, preparedQuery))
+
+ if testCase.expectedParams == nil {
+ require.Empty(t, preparedQuery.Parameters)
+ } else {
+ require.Equal(t, testCase.expectedParams, preparedQuery.Parameters)
+ }
+ })
+ }
+}
+
+func TestInvalidTraversalDepthControls(t *testing.T) {
+ _, err := v2.New().WithTraversalDepth(v2.MinDepth(-1)).Return(v2.End()).Build()
+ require.ErrorContains(t, err, "traversal depth minimum must be non-negative: -1")
+
+ _, err = v2.New().WithTraversalDepth(v2.MaxDepth(-1)).Return(v2.End()).Build()
+ require.ErrorContains(t, err, "traversal depth maximum must be non-negative: -1")
+
+ _, err = v2.New().WithTraversalDepth(v2.DepthRange(3, 1)).Return(v2.End()).Build()
+ require.ErrorContains(t, err, "traversal depth maximum 1 is less than minimum 3")
+
+ _, err = v2.New().WithTraversalDepth(v2.AnyDepth()).Return(v2.Node()).Build()
+ require.ErrorContains(t, err, "recursive traversal query requires relationship query identifiers")
+
+ _, err = v2.New().WithTraversalDepth(v2.AnyDepth()).Where(
+ v2.Relationship().Property("enabled").Equals(true),
+ ).Return(
+ v2.End(),
+ ).Build()
+ require.ErrorContains(t, err, "ranged relationship patterns only support top-level relationship kind constraints")
+
+ _, err = v2.New().WithTraversalDepth(v2.AnyDepth()).Return(v2.Relationship()).Build()
+ require.ErrorContains(t, err, "ranged relationship patterns do not support relationship projections or mutations")
+
+ _, err = v2.New().WithTraversalDepth(v2.AnyDepth()).Where(
+ v2.Start().ID().Equals(1),
+ ).Delete(
+ v2.Relationship(),
+ ).Build()
+ require.ErrorContains(t, err, "ranged relationship patterns do not support relationship projections or mutations")
+}
+
+func TestMixedNodeAndRelationshipIdentifiersReturnError(t *testing.T) {
+ _, err := v2.New().Where(
+ v2.Node().ID().Equals(1),
+ v2.Relationship().ID().Equals(2),
+ ).Return(
+ v2.Node(),
+ ).Build()
+ require.ErrorContains(t, err, "query mixes node and relationship query identifiers")
+}
+
+func TestRawIdentifiersMustBeKnownToScope(t *testing.T) {
+ cases := map[string]v2.QueryBuilder{
+ "delete": v2.New().Where(
+ v2.Node().ID().Equals(1),
+ ).Delete(
+ v2.Variable("x"),
+ ),
+ "projection": v2.New().Return(
+ v2.Node(),
+ v2.Variable("x"),
+ ),
+ "sort": v2.New().Return(
+ v2.Node(),
+ ).OrderBy(
+ v2.Asc(v2.Variable("x")),
+ ),
+ }
+
+ for name, builder := range cases {
+ t.Run(name, func(t *testing.T) {
+ _, err := builder.Build()
+ require.ErrorContains(t, err, `query contains unknown identifier "x"`)
+ })
+ }
+}
+
+func TestPathIdentifierRequiresShortestPathMatch(t *testing.T) {
+ _, err := v2.New().Return(
+ v2.Node(),
+ v2.Path(),
+ ).Build()
+ require.ErrorContains(t, err, `query contains unbound identifier "p"`)
+
+ _, err = v2.New().Return(
+ v2.Relationship(),
+ v2.Path(),
+ ).Build()
+ require.ErrorContains(t, err, `query contains unbound identifier "p"`)
+}
+
+func TestCreatedRawIdentifiersDoNotRequireMatch(t *testing.T) {
+ preparedQuery, err := v2.New().Create(&cypher.NodePattern{
+ Variable: v2.Variable("created"),
+ }).Return(
+ v2.Variable("created"),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "create (created) return created", renderPrepared(t, preparedQuery))
+}
+
+func TestMultipleRelationshipKindMatchersRemainConjunctive(t *testing.T) {
+ preparedQuery, err := v2.New().Where(
+ v2.Relationship().Kind().Is(graph.StringKind("A")),
+ v2.Relationship().Kind().Is(graph.StringKind("B")),
+ ).Return(
+ v2.Relationship(),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match ()-[r]->() where r:A and r:B return r", renderPrepared(t, preparedQuery))
+}
+
+func TestEmptyLogicalHelpersReturnBuildErrors(t *testing.T) {
+ _, err := v2.New().Where(v2.And()).Return(v2.Node()).Build()
+ require.ErrorContains(t, err, "and requires at least one operand")
+
+ _, err = v2.New().Where(v2.Or()).Return(v2.Node()).Build()
+ require.ErrorContains(t, err, "or requires at least one operand")
+}
+
+func TestExplicitRelationshipPatternDirectionBoth(t *testing.T) {
+ preparedQuery, err := v2.New().Create(
+ v2.RelationshipPattern(graph.StringKind("Edge"), nil, graph.DirectionBoth),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "create (s)-[r:Edge]-(e)", renderPrepared(t, preparedQuery))
+}
+
+func TestInvalidExplicitRelationshipPatternDirectionReturnsError(t *testing.T) {
+ _, err := v2.New().Create(
+ v2.Relationship().RelationshipPattern(graph.StringKind("Edge"), nil, graph.Direction(99)),
+ ).Build()
+ require.ErrorContains(t, err, "unsupported relationship direction: invalid")
+}
+
+func TestProjectionAndOrderHelpers(t *testing.T) {
+ preparedQuery, err := v2.New().ReturnDistinct(
+ v2.As(v2.Node().ID(), "node_id"),
+ ).OrderBy(
+ v2.Node().Property("name"),
+ v2.Desc(v2.Node().ID()),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (n) return distinct id(n) as node_id order by n.name asc, id(n) desc", renderPrepared(t, preparedQuery))
+}
+
+func TestInvalidSortDirectionReturnsError(t *testing.T) {
+ _, err := v2.New().Return(v2.Node()).OrderBy(
+ v2.Order(v2.Node().Property("name"), v2.SortDirection(99)),
+ ).Build()
+ require.ErrorContains(t, err, "unsupported sort direction: 99")
+}
+
+func TestPaginationZeroValuesAndNegativeValidation(t *testing.T) {
+ preparedQuery, err := v2.New().Return(v2.Node()).Skip(0).Limit(0).Build()
+ require.NoError(t, err)
+ require.Equal(t, "match (n) return n skip 0 limit 0", renderPrepared(t, preparedQuery))
+
+ _, err = v2.New().Return(v2.Node()).Skip(-1).Build()
+ require.ErrorContains(t, err, "skip must be non-negative: -1")
+
+ _, err = v2.New().Return(v2.Node()).Limit(-1).Build()
+ require.ErrorContains(t, err, "limit must be non-negative: -1")
+}
+
+func TestProjectionAliasDoesNotCreateMatchInference(t *testing.T) {
+ preparedQuery, err := v2.New().Return(
+ v2.As(v2.Literal(1), "one"),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "return 1 as one", renderPrepared(t, preparedQuery))
+ require.Empty(t, preparedQuery.Parameters)
+}
+
+func TestAliasedProjectionCreatesMatchInference(t *testing.T) {
+ preparedQuery, err := v2.New().Return(
+ v2.As(v2.Node().ID(), "node_id"),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (n) return id(n) as node_id", renderPrepared(t, preparedQuery))
+
+ preparedQuery, err = v2.New().Return(
+ v2.As(v2.Node(), "alias"),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (n) return n as alias", renderPrepared(t, preparedQuery))
+}
+
+func TestInvalidProjectionAliasReturnsBuildError(t *testing.T) {
+ _, err := v2.New().Return(v2.As(v2.Literal(1), "bad alias")).Build()
+ require.ErrorContains(t, err, `projection alias has invalid symbol "bad alias"`)
+
+ _, err = v2.New().Return(&cypher.ProjectionItem{
+ Expression: v2.Literal(1),
+ Alias: cypher.NewVariableWithSymbol("1bad"),
+ }).Build()
+ require.ErrorContains(t, err, `projection alias has invalid symbol "1bad"`)
+}
+
+func TestUnsupportedOrderByTypeReturnsError(t *testing.T) {
+ _, err := v2.New().Return(v2.Node()).OrderBy(123).Build()
+ require.ErrorContains(t, err, "unsupported expression type: int")
+}
+
+func TestRawProjectionAndOrderInputsAreValidated(t *testing.T) {
+ _, err := v2.New().Return(&cypher.Return{}).Build()
+ require.ErrorContains(t, err, "return clause has nil projection")
+
+ returnClause := cypher.NewReturn()
+ returnClause.NewProjection(false).Items = append(returnClause.Projection.Items, &cypher.ProjectionItem{})
+ _, err = v2.New().Return(returnClause).Build()
+ require.ErrorContains(t, err, "projection item has nil expression")
+
+ _, err = v2.New().Return(v2.Node()).OrderBy(&cypher.SortItem{}).Build()
+ require.ErrorContains(t, err, "sort item has nil expression")
+
+ _, err = v2.New().Return(v2.Node()).OrderBy(&cypher.Order{
+ Items: []*cypher.SortItem{{}},
+ }).Build()
+ require.ErrorContains(t, err, "sort item has nil expression")
+}
+
+func TestRawProjectionAndOrderInputsAreNormalized(t *testing.T) {
+ returnClause := cypher.NewReturn()
+ returnClause.NewProjection(false).Items = append(returnClause.Projection.Items, v2.Node().ID())
+
+ preparedQuery, err := v2.New().Return(returnClause).OrderBy(&cypher.Order{
+ Items: []*cypher.SortItem{
+ v2.Desc(v2.Node().Property("name")),
+ },
+ }).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (n) return id(n) order by n.name desc", renderPrepared(t, preparedQuery))
+}
+
+func TestRawReturnInputPreservesProjectionMetadata(t *testing.T) {
+ returnClause := cypher.NewReturn()
+ projection := returnClause.NewProjection(true)
+ projection.Items = append(projection.Items, v2.Node().ID())
+ projection.Order = &cypher.Order{
+ Items: []*cypher.SortItem{
+ v2.Desc(v2.Node().Property("name")),
+ },
+ }
+ projection.Skip = cypher.NewSkip(5)
+ projection.Limit = cypher.NewLimit(10)
+
+ preparedQuery, err := v2.New().Return(returnClause).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (n) return distinct id(n) order by n.name desc skip 5 limit 10", renderPrepared(t, preparedQuery))
+}
+
+func TestRawReturnMetadataCreatesMatchInference(t *testing.T) {
+ returnClause := cypher.NewReturn()
+ projection := returnClause.NewProjection(false)
+ projection.Items = append(projection.Items, v2.Literal(1))
+ projection.Order = &cypher.Order{
+ Items: []*cypher.SortItem{
+ v2.Desc(v2.Node().Property("name")),
+ },
+ }
+
+ preparedQuery, err := v2.New().Return(returnClause).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (n) return 1 order by n.name desc", renderPrepared(t, preparedQuery))
+}
+
+func TestRawReturnInputMergesWithBuilderProjectionControls(t *testing.T) {
+ returnClause := cypher.NewReturn()
+ projection := returnClause.NewProjection(true)
+ projection.Items = append(projection.Items, v2.Node().ID())
+ projection.Order = &cypher.Order{
+ Items: []*cypher.SortItem{
+ v2.Desc(v2.Node().Property("name")),
+ },
+ }
+ projection.Skip = cypher.NewSkip(5)
+ projection.Limit = cypher.NewLimit(10)
+
+ preparedQuery, err := v2.New().Return(returnClause).OrderBy(
+ v2.Asc(v2.Node().Property("created_at")),
+ ).Skip(15).Limit(20).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (n) return distinct id(n) order by n.name desc, n.created_at asc skip 15 limit 20", renderPrepared(t, preparedQuery))
+}
+
+func TestRawUpdatingInputsAreValidated(t *testing.T) {
+ var setClause *cypher.Set
+ _, err := v2.New().Update(setClause).Build()
+ require.ErrorContains(t, err, "set clause is nil")
+
+ _, err = v2.New().Update(&cypher.Set{Items: []*cypher.SetItem{nil}}).Build()
+ require.ErrorContains(t, err, "set item is nil")
+
+ _, err = v2.New().Update(&cypher.SetItem{}).Build()
+ require.ErrorContains(t, err, "set item left has nil expression")
+
+ var removeClause *cypher.Remove
+ _, err = v2.New().Update(removeClause).Build()
+ require.ErrorContains(t, err, "remove clause is nil")
+
+ _, err = v2.New().Update(&cypher.Remove{Items: []*cypher.RemoveItem{nil}}).Build()
+ require.ErrorContains(t, err, "remove item is nil")
+
+ _, err = v2.New().Update(&cypher.RemoveItem{}).Build()
+ require.ErrorContains(t, err, "remove item has no target")
+
+ var deleteVariable *cypher.Variable
+ _, err = v2.New().Delete(deleteVariable).Build()
+ require.ErrorContains(t, err, "delete expression has nil expression")
+
+ var nodePattern *cypher.NodePattern
+ _, err = v2.New().Create(nodePattern).Build()
+ require.ErrorContains(t, err, "node pattern is nil")
+
+ var relationshipPattern *cypher.RelationshipPattern
+ _, err = v2.New().Create(relationshipPattern).Build()
+ require.ErrorContains(t, err, "relationship pattern is nil")
+}
+
+func TestDeleteRejectsNonTargetQualifiedExpressions(t *testing.T) {
+ cases := map[string]any{
+ "property continuation": v2.Node().Property("name"),
+ "raw property lookup": cypher.NewPropertyLookup("n", "name"),
+ "id": v2.Node().ID(),
+ "kinds": v2.Node().Kinds(),
+ "kind": v2.Relationship().Kind(),
+ }
+
+ for name, target := range cases {
+ t.Run(name, func(t *testing.T) {
+ _, err := v2.New().Delete(target).Build()
+ require.ErrorContains(t, err, "delete target must be a node, relationship, or variable")
+ })
+ }
+}
+
+func TestDeleteRejectsPathTargets(t *testing.T) {
+ _, err := v2.New().Delete(v2.Path()).Build()
+ require.ErrorContains(t, err, `delete target must be a node or relationship variable, got path variable "p"`)
+
+ _, err = v2.New().Delete(v2.Variable("p")).Build()
+ require.ErrorContains(t, err, `delete target must be a node or relationship variable, got path variable "p"`)
+}
+
+func TestInvalidHelperInputsReturnBuildErrors(t *testing.T) {
+ cases := map[string]struct {
+ builder v2.QueryBuilder
+ err string
+ }{
+ "aliased projection": {
+ builder: v2.New().Return(v2.As(123, "bad")),
+ err: "unsupported expression type: int",
+ },
+ "sort item": {
+ builder: v2.New().Return(v2.Node()).OrderBy(v2.Desc(123)),
+ err: "unsupported expression type: int",
+ },
+ "set properties": {
+ builder: v2.New().Update(v2.SetProperties(123, map[string]any{"name": "bad"})),
+ err: "unsupported expression type: int",
+ },
+ "delete properties": {
+ builder: v2.New().Update(v2.DeleteProperties(123, "name")),
+ err: "unsupported expression type: int",
+ },
+ "pattern predicate": {
+ builder: v2.New().Where(v2.HasRelationships(123)).Return(v2.Node()),
+ err: "unsupported expression type: int",
+ },
+ }
+
+ for name, testCase := range cases {
+ t.Run(name, func(t *testing.T) {
+ _, err := testCase.builder.Build()
+ require.ErrorContains(t, err, testCase.err)
+ })
+ }
+}
+
+func TestNamedParameterMaterialization(t *testing.T) {
+ preparedQuery, err := v2.New().Where(
+ v2.Node().Property("first").Equals("auto"),
+ v2.Node().Property("second").Equals(v2.NamedParameter("p0", "named")),
+ ).Return(
+ v2.Node(),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (n) where n.first = $p1 and n.second = $p0 return n", renderPrepared(t, preparedQuery))
+ require.Equal(t, map[string]any{
+ "p0": "named",
+ "p1": "auto",
+ }, preparedQuery.Parameters)
+
+ _, err = v2.New().Where(
+ v2.Node().Property("name").Equals(v2.NamedParameter("bad name", "value")),
+ ).Return(
+ v2.Node(),
+ ).Build()
+ require.ErrorContains(t, err, `parameter has invalid symbol "bad name"`)
+
+ _, err = v2.New().Where(
+ v2.Node().Property("first").Equals(v2.NamedParameter("same", "first")),
+ v2.Node().Property("second").Equals(v2.NamedParameter("same", "second")),
+ ).Return(
+ v2.Node(),
+ ).Build()
+ require.ErrorContains(t, err, "parameter same is bound to multiple values")
+}
+
+func TestQualifiedExpressionValuesUseProjectionSemantics(t *testing.T) {
+ preparedQuery, err := v2.New().Where(
+ v2.Node().Property("copy").Equals(v2.Node().Property("source")),
+ v2.Node().Property("kinds").Equals(v2.Node().Kinds()),
+ ).Return(
+ v2.Node(),
+ ).Build()
+ require.NoError(t, err)
+ require.Equal(t, "match (n) where n.copy = n.source and n.kinds = labels(n) return n", renderPrepared(t, preparedQuery))
+
+ preparedQuery, err = v2.New().Where(
+ v2.Relationship().Property("kind").Equals(v2.Relationship().Kind()),
+ ).Return(
+ v2.Relationship(),
+ ).Build()
+ require.NoError(t, err)
+ require.Equal(t, "match ()-[r]->() where r.kind = type(r) return r", renderPrepared(t, preparedQuery))
+}
+
+func TestBuildDoesNotMutateCallerOwnedAST(t *testing.T) {
+ constraint := v2.Node().Property("name").Equals("alice")
+ constraintParameter := constraint.(*cypher.Comparison).FirstPartial().Right.(*cypher.Parameter)
+
+ preparedQuery, err := v2.New().Where(constraint).Return(v2.Node()).Build()
+ require.NoError(t, err)
+ require.Equal(t, map[string]any{"p0": "alice"}, preparedQuery.Parameters)
+ require.Empty(t, constraintParameter.Symbol)
+
+ setItem := v2.SetProperty(v2.Node().Property("status"), "active")
+ setParameter := setItem.Right.(*cypher.Parameter)
+
+ preparedQuery, err = v2.New().Where(v2.Node().ID().Equals(1)).Update(setItem).Build()
+ require.NoError(t, err)
+ require.Equal(t, map[string]any{"p0": 1, "p1": "active"}, preparedQuery.Parameters)
+ require.Empty(t, setParameter.Symbol)
+
+ createPattern := v2.NodePattern(graph.Kinds{graph.StringKind("User")}, v2.Parameter(map[string]any{"name": "node"}))
+ createParameter := createPattern.Properties.(*cypher.Parameter)
+
+ preparedQuery, err = v2.New().Create(createPattern).Build()
+ require.NoError(t, err)
+ require.Equal(t, map[string]any{"p0": map[string]any{"name": "node"}}, preparedQuery.Parameters)
+ require.Empty(t, createParameter.Symbol)
+
+ rawReturn := cypher.NewReturn()
+ rawReturn.NewProjection(false).AddItem(cypher.NewProjectionItemWithExpr(v2.Parameter("projected")))
+ rawReturnParameter := rawReturn.Projection.Items[0].(*cypher.ProjectionItem).Expression.(*cypher.Parameter)
+
+ preparedQuery, err = v2.New().Return(rawReturn).Build()
+ require.NoError(t, err)
+ require.Equal(t, map[string]any{"p0": "projected"}, preparedQuery.Parameters)
+ require.Empty(t, rawReturnParameter.Symbol)
+}
+
+func TestCompatibilityHelpers(t *testing.T) {
+ preparedQuery, err := v2.New().Where(
+ v2.And(
+ v2.InIDs(v2.NodeID(), 1, 2),
+ v2.KindIn(v2.Node(), graph.StringKind("User")),
+ v2.CaseInsensitiveStringContains(v2.Node().Property("name"), "ADMIN"),
+ v2.IsNotNull(v2.Node().Property("enabled")),
+ ),
+ ).Return(
+ v2.CountDistinct(v2.Node()),
+ v2.KindsOf(v2.Node()),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (n) where id(n) in $p0 and n:User and toLower(n.name) contains $p1 and n.enabled is not null return count(distinct n), labels(n)", renderPrepared(t, preparedQuery))
+ require.Equal(t, map[string]any{
+ "p0": []graph.ID{1, 2},
+ "p1": "admin",
+ }, preparedQuery.Parameters)
+}
+
+func TestUpdateCompatibilityHelpers(t *testing.T) {
+ preparedQuery, err := v2.New().Where(
+ v2.Node().ID().Equals(1),
+ ).Update(
+ v2.AddKind(v2.Node(), graph.StringKind("Enabled")),
+ v2.SetProperties(v2.Node(), map[string]any{"name": "updated"}),
+ ).Build()
+ require.NoError(t, err)
+
+ require.Equal(t, "match (n) where id(n) = $p0 set n:Enabled, n.name = $p1", renderPrepared(t, preparedQuery))
+ require.Equal(t, map[string]any{
+ "p0": 1,
+ "p1": "updated",
+ }, preparedQuery.Parameters)
+}
+
+func TestFluentMutationHelpersReturnConcreteMutationTypes(t *testing.T) {
+ kinds := graph.Kinds{graph.StringKind("Enabled")}
+
+ var addItem *cypher.SetItem = v2.Node().Kinds().Add(kinds)
+ require.NotNil(t, addItem)
+
+ var removeItem *cypher.RemoveItem = v2.Node().Kinds().Remove(kinds)
+ require.NotNil(t, removeItem)
+
+ var nodeSet *cypher.Set = v2.Node().SetProperties(map[string]any{"name": "updated"})
+ require.Len(t, nodeSet.Items, 1)
+
+ var nodeRemove *cypher.Remove = v2.Node().RemoveProperties([]string{"stale"})
+ require.Len(t, nodeRemove.Items, 1)
+
+ var relationshipSet *cypher.Set = v2.Relationship().SetProperties(map[string]any{"name": "updated"})
+ require.Len(t, relationshipSet.Items, 1)
+
+ var relationshipRemove *cypher.Remove = v2.Relationship().RemoveProperties([]string{"stale"})
+ require.Len(t, relationshipRemove.Items, 1)
+}
+
+func TestSetPropertiesSortsKeys(t *testing.T) {
+ properties := map[string]any{
+ "zeta": 3,
+ "alpha": 1,
+ "mid": 2,
+ }
+
+ preparedQuery, err := v2.New().Update(
+ v2.SetProperties(v2.Node(), properties),
+ ).Build()
+ require.NoError(t, err)
+ require.Equal(t, "match (n) set n.alpha = $p0, n.mid = $p1, n.zeta = $p2", renderPrepared(t, preparedQuery))
+ require.Equal(t, map[string]any{
+ "p0": 1,
+ "p1": 2,
+ "p2": 3,
+ }, preparedQuery.Parameters)
+
+ preparedQuery, err = v2.New().Update(
+ v2.Node().SetProperties(properties),
+ ).Build()
+ require.NoError(t, err)
+ require.Equal(t, "match (n) set n.alpha = $p0, n.mid = $p1, n.zeta = $p2", renderPrepared(t, preparedQuery))
+ require.Equal(t, map[string]any{
+ "p0": 1,
+ "p1": 2,
+ "p2": 3,
+ }, preparedQuery.Parameters)
+}
diff --git a/query/v2/util.go b/query/v2/util.go
new file mode 100644
index 00000000..03b5fb10
--- /dev/null
+++ b/query/v2/util.go
@@ -0,0 +1,1277 @@
+package v2
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "sort"
+ "strconv"
+ "strings"
+ "unicode"
+ "unicode/utf8"
+
+ "github.com/specterops/dawgs/cypher/models/cypher"
+ "github.com/specterops/dawgs/cypher/models/walk"
+ "github.com/specterops/dawgs/graph"
+)
+
+func isNodePattern(seen *identifierSet, identifiers runtimeIdentifiers) bool {
+ return seen.Contains(identifiers.node)
+}
+
+func isRelationshipPattern(seen *identifierSet, identifiers runtimeIdentifiers) bool {
+ var (
+ hasStart = seen.Contains(identifiers.start)
+ hasRelationship = seen.Contains(identifiers.relationship)
+ hasEnd = seen.Contains(identifiers.end)
+ )
+
+ return hasStart || hasRelationship || hasEnd
+}
+
+func runtimeIdentifierSet(identifiers runtimeIdentifiers) *identifierSet {
+ return newIdentifierSet(
+ identifiers.path,
+ identifiers.node,
+ identifiers.start,
+ identifiers.relationship,
+ identifiers.end,
+ )
+}
+
+func nodePatternIdentifierSet(identifiers runtimeIdentifiers) *identifierSet {
+ return newIdentifierSet(identifiers.node)
+}
+
+func relationshipPatternIdentifierSet(identifiers runtimeIdentifiers, includePath bool) *identifierSet {
+ allowedIdentifiers := newIdentifierSet(
+ identifiers.start,
+ identifiers.relationship,
+ identifiers.end,
+ )
+
+ if includePath {
+ allowedIdentifiers.Add(identifiers.path)
+ }
+
+ return allowedIdentifiers
+}
+
+func createRelationshipMatchIdentifierSet(identifiers runtimeIdentifiers) *identifierSet {
+ return newIdentifierSet(identifiers.start, identifiers.end)
+}
+
+func validateKnownIdentifiers(seen *identifierSet, identifiers runtimeIdentifiers) error {
+ if identifier, hasIdentifier := seen.FirstOutside(runtimeIdentifierSet(identifiers)); hasIdentifier {
+ return fmt.Errorf("query contains unknown identifier %q", identifier)
+ }
+
+ return nil
+}
+
+func validateBoundIdentifiers(seen, bound *identifierSet) error {
+ if identifier, hasIdentifier := seen.FirstOutside(bound); hasIdentifier {
+ return fmt.Errorf("query contains unbound identifier %q", identifier)
+ }
+
+ return nil
+}
+
+func prepareNodePattern(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers) error {
+ if isRelationshipPattern(seen, identifiers) {
+ return fmt.Errorf("query mixes node and relationship query identifiers")
+ }
+
+ if err := validateBoundIdentifiers(seen, nodePatternIdentifierSet(identifiers)); err != nil {
+ return err
+ }
+
+ match.NewPatternPart().AddPatternElements(&cypher.NodePattern{
+ Variable: identifiers.Node(),
+ })
+
+ return nil
+}
+
+func validateRelationshipDirection(direction graph.Direction) error {
+ switch direction {
+ case graph.DirectionInbound, graph.DirectionOutbound, graph.DirectionBoth:
+ return nil
+ default:
+ return fmt.Errorf("unsupported relationship direction: %s", direction)
+ }
+}
+
+func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers, relationshipKinds graph.Kinds, relationshipRange *cypher.PatternRange, direction graph.Direction, shortestPaths, allShortestPaths bool) error {
+ if shortestPaths && allShortestPaths {
+ return errors.New("query is requesting both all shortest paths and shortest paths")
+ }
+
+ if err := validateRelationshipDirection(direction); err != nil {
+ return err
+ }
+
+ hasRangedRelationshipPattern := relationshipRange != nil || shortestPaths || allShortestPaths
+ if err := validateBoundIdentifiers(seen, relationshipPatternIdentifierSet(identifiers, hasRangedRelationshipPattern)); err != nil {
+ return err
+ }
+
+ var (
+ newPatternPart = match.NewPatternPart()
+ startNodeSeen = seen.Contains(identifiers.start)
+ relationshipSeen = seen.Contains(identifiers.relationship)
+ endNodeSeen = seen.Contains(identifiers.end)
+ )
+
+ newPatternPart.ShortestPathPattern = shortestPaths
+ newPatternPart.AllShortestPathsPattern = allShortestPaths
+
+ if startNodeSeen {
+ newPatternPart.AddPatternElements(&cypher.NodePattern{
+ Variable: identifiers.Start(),
+ })
+ } else {
+ newPatternPart.AddPatternElements(&cypher.NodePattern{})
+ }
+
+ relationshipPattern := &cypher.RelationshipPattern{
+ Kinds: relationshipKinds,
+ Direction: direction,
+ }
+
+ if relationshipSeen {
+ relationshipPattern.Variable = identifiers.Relationship()
+ }
+
+ if shortestPaths || allShortestPaths || (relationshipRange != nil && seen.Contains(identifiers.path)) {
+ newPatternPart.Variable = identifiers.Path()
+ }
+
+ if relationshipRange != nil {
+ relationshipPattern.Range = cypher.Copy(relationshipRange)
+ } else if shortestPaths || allShortestPaths {
+ relationshipPattern.Range = &cypher.PatternRange{}
+ }
+
+ newPatternPart.AddPatternElements(relationshipPattern)
+
+ if endNodeSeen {
+ newPatternPart.AddPatternElements(&cypher.NodePattern{
+ Variable: identifiers.End(),
+ })
+ } else {
+ newPatternPart.AddPatternElements(&cypher.NodePattern{})
+ }
+
+ return nil
+}
+
+func prepareCreateRelationshipMatch(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers) error {
+ if err := validateBoundIdentifiers(seen, createRelationshipMatchIdentifierSet(identifiers)); err != nil {
+ return err
+ }
+
+ if seen.Contains(identifiers.start) {
+ match.NewPatternPart().AddPatternElements(&cypher.NodePattern{
+ Variable: identifiers.Start(),
+ })
+ }
+
+ if seen.Contains(identifiers.end) {
+ match.NewPatternPart().AddPatternElements(&cypher.NodePattern{
+ Variable: identifiers.End(),
+ })
+ }
+
+ return nil
+}
+
+func isDetachDeleteQualifier(qualifier cypher.Expression, identifiers runtimeIdentifiers) bool {
+ variable, typeOK := qualifier.(*cypher.Variable)
+ if !typeOK || variable == nil {
+ return false
+ }
+
+ switch variable.Symbol {
+ case identifiers.node, identifiers.start, identifiers.end:
+ return true
+ default:
+ return false
+ }
+}
+
+func kindProjectionExpression(role string, identifier *cypher.Variable) (cypher.Expression, error) {
+ switch role {
+ case Identifiers.node, Identifiers.start, Identifiers.end:
+ return cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, identifier), nil
+
+ case Identifiers.relationship:
+ return cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, identifier), nil
+
+ default:
+ return nil, fmt.Errorf("invalid kind projection reference: %s", identifier.Symbol)
+ }
+}
+
+func invalidExpression(err error) *cypher.FunctionInvocation {
+ return cypher.WithErrors(cypher.NewSimpleFunctionInvocation("__invalid_expression__"), err)
+}
+
+func isNilPointer(value any) bool {
+ if value == nil {
+ return true
+ }
+
+ reflectValue := reflect.ValueOf(value)
+ return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil()
+}
+
+func expressionOrError(value any) cypher.Expression {
+ if expression, err := projectionExpression(value); err != nil {
+ return invalidExpression(err)
+ } else {
+ return expression
+ }
+}
+
+func variableReference(value any) (*cypher.Variable, error) {
+ expression, err := projectionExpression(value)
+ if err != nil {
+ return nil, err
+ }
+
+ if variable, typeOK := expression.(*cypher.Variable); !typeOK {
+ return nil, fmt.Errorf("expected variable reference, got %T", expression)
+ } else {
+ return variable, nil
+ }
+}
+
+func propertyLookupOrError(reference any, propertyName string) cypher.Expression {
+ if variable, err := variableReference(reference); err != nil {
+ return invalidExpression(err)
+ } else {
+ return cypher.NewPropertyLookup(variable.Symbol, propertyName)
+ }
+}
+
+func sortedPropertyKeys(properties map[string]any) []string {
+ keys := make([]string, 0, len(properties))
+
+ for key := range properties {
+ keys = append(keys, key)
+ }
+
+ sort.Strings(keys)
+ return keys
+}
+
+func isCypherSymbolStart(char rune) bool {
+ return char == '_' || unicode.IsLetter(char) || unicode.In(char, unicode.Nl, unicode.Pc)
+}
+
+func isCypherSymbolPart(char rune) bool {
+ return isCypherSymbolStart(char) || unicode.IsDigit(char) || unicode.In(char, unicode.Mark, unicode.Sc)
+}
+
+func validateCypherSymbol(symbol, context string) error {
+ if strings.TrimSpace(symbol) == "" {
+ return fmt.Errorf("%s is empty", context)
+ }
+
+ if !utf8.ValidString(symbol) {
+ return fmt.Errorf("%s has invalid symbol %q", context, symbol)
+ }
+
+ for idx, char := range symbol {
+ if idx == 0 {
+ if !isCypherSymbolStart(char) {
+ return fmt.Errorf("%s has invalid symbol %q", context, symbol)
+ }
+ } else if !isCypherSymbolPart(char) {
+ return fmt.Errorf("%s has invalid symbol %q", context, symbol)
+ }
+ }
+
+ return nil
+}
+
+func projectionExpression(value any) (cypher.Expression, error) {
+ if isNilPointer(value) {
+ return nil, fmt.Errorf("expression is nil: %T", value)
+ }
+
+ switch typedValue := value.(type) {
+ case kindContinuation:
+ return kindProjectionExpression(typedValue.role, typedValue.identifier)
+
+ case kindsContinuation:
+ return kindProjectionExpression(typedValue.role, typedValue.identifier)
+
+ case QualifiedExpression:
+ return typedValue.qualifier(), nil
+
+ case *cypher.ProjectionItem:
+ if typedValue.Expression == nil {
+ return nil, fmt.Errorf("projection item has nil expression")
+ }
+
+ return typedValue.Expression, nil
+
+ case *cypher.Parameter:
+ return typedValue, nil
+
+ case *cypher.Literal:
+ return typedValue, nil
+
+ case *cypher.Variable:
+ return typedValue, nil
+
+ case *cypher.PropertyLookup:
+ return typedValue, nil
+
+ case *cypher.FunctionInvocation:
+ return typedValue, nil
+
+ case *cypher.Parenthetical:
+ return typedValue, nil
+
+ case *cypher.Comparison:
+ return typedValue, nil
+
+ case *cypher.Negation:
+ return typedValue, nil
+
+ case *cypher.Conjunction:
+ return typedValue, nil
+
+ case *cypher.Disjunction:
+ return typedValue, nil
+
+ case *cypher.ExclusiveDisjunction:
+ return typedValue, nil
+
+ case *cypher.KindMatcher:
+ return typedValue, nil
+
+ case *cypher.ListLiteral:
+ return typedValue, nil
+
+ case cypher.MapLiteral:
+ return typedValue, nil
+
+ case *cypher.PatternPredicate:
+ return typedValue, nil
+
+ case *cypher.ArithmeticExpression:
+ return typedValue, nil
+
+ case *cypher.UnaryAddOrSubtractExpression:
+ return typedValue, nil
+
+ case *cypher.FilterExpression:
+ return typedValue, nil
+
+ case *cypher.IDInCollection:
+ return typedValue, nil
+
+ default:
+ return nil, fmt.Errorf("unsupported expression type: %T", value)
+ }
+}
+
+func copyExpression(expression cypher.Expression) cypher.Expression {
+ return cypher.Copy(expression)
+}
+
+func copyProjectionItem(item *cypher.ProjectionItem) *cypher.ProjectionItem {
+ return cypher.Copy(item)
+}
+
+func copySortItem(item *cypher.SortItem) *cypher.SortItem {
+ return cypher.Copy(item)
+}
+
+func copySetItem(item *cypher.SetItem) *cypher.SetItem {
+ return cypher.Copy(item)
+}
+
+func copyRemoveItem(item *cypher.RemoveItem) *cypher.RemoveItem {
+ return cypher.Copy(item)
+}
+
+func copySkip(skip *cypher.Skip) *cypher.Skip {
+ return cypher.Copy(skip)
+}
+
+func copyLimit(limit *cypher.Limit) *cypher.Limit {
+ return cypher.Copy(limit)
+}
+
+func validateExpressionValue(expression cypher.Expression, context string) error {
+ if isNilPointer(expression) {
+ return fmt.Errorf("%s has nil expression", context)
+ }
+
+ return collectModelErrors(expression)
+}
+
+func validateAssignmentOperator(operator cypher.AssignmentOperator) error {
+ switch operator {
+ case cypher.OperatorAssignment, cypher.OperatorAdditionAssignment, cypher.OperatorLabelAssignment:
+ return nil
+ default:
+ return fmt.Errorf("unsupported set item operator: %s", operator)
+ }
+}
+
+func setItemFromValue(setItem *cypher.SetItem) (*cypher.SetItem, error) {
+ if setItem == nil {
+ return nil, fmt.Errorf("set item is nil")
+ }
+
+ if err := validateExpressionValue(setItem.Left, "set item left"); err != nil {
+ return nil, err
+ }
+
+ if err := validateAssignmentOperator(setItem.Operator); err != nil {
+ return nil, err
+ }
+
+ if err := validateExpressionValue(setItem.Right, "set item right"); err != nil {
+ return nil, err
+ }
+
+ if err := collectModelErrors(setItem); err != nil {
+ return nil, err
+ }
+
+ return copySetItem(setItem), nil
+}
+
+func setItemsFromSet(setClause *cypher.Set) ([]*cypher.SetItem, error) {
+ if setClause == nil {
+ return nil, fmt.Errorf("set clause is nil")
+ }
+
+ setItems := make([]*cypher.SetItem, 0, len(setClause.Items))
+
+ for _, setItem := range setClause.Items {
+ if normalizedSetItem, err := setItemFromValue(setItem); err != nil {
+ return nil, err
+ } else {
+ setItems = append(setItems, normalizedSetItem)
+ }
+ }
+
+ return setItems, nil
+}
+
+func removeItemFromValue(removeItem *cypher.RemoveItem) (*cypher.RemoveItem, error) {
+ if removeItem == nil {
+ return nil, fmt.Errorf("remove item is nil")
+ }
+
+ hasKindMatcher := removeItem.KindMatcher != nil
+ hasProperty := !isNilPointer(removeItem.Property)
+
+ switch {
+ case hasKindMatcher && hasProperty:
+ return nil, fmt.Errorf("remove item has multiple targets")
+
+ case hasKindMatcher:
+ if err := collectModelErrors(removeItem.KindMatcher); err != nil {
+ return nil, err
+ }
+
+ case hasProperty:
+ if err := validateExpressionValue(removeItem.Property, "remove item property"); err != nil {
+ return nil, err
+ }
+
+ default:
+ return nil, fmt.Errorf("remove item has no target")
+ }
+
+ if err := collectModelErrors(removeItem); err != nil {
+ return nil, err
+ }
+
+ return copyRemoveItem(removeItem), nil
+}
+
+func removeItemsFromRemove(removeClause *cypher.Remove) ([]*cypher.RemoveItem, error) {
+ if removeClause == nil {
+ return nil, fmt.Errorf("remove clause is nil")
+ }
+
+ removeItems := make([]*cypher.RemoveItem, 0, len(removeClause.Items))
+
+ for _, removeItem := range removeClause.Items {
+ if normalizedRemoveItem, err := removeItemFromValue(removeItem); err != nil {
+ return nil, err
+ } else {
+ removeItems = append(removeItems, normalizedRemoveItem)
+ }
+ }
+
+ return removeItems, nil
+}
+
+func validateNodePattern(nodePattern *cypher.NodePattern) error {
+ if nodePattern == nil {
+ return fmt.Errorf("node pattern is nil")
+ }
+
+ return collectModelErrors(nodePattern)
+}
+
+func validateRelationshipPattern(relationshipPattern *cypher.RelationshipPattern) error {
+ if relationshipPattern == nil {
+ return fmt.Errorf("relationship pattern is nil")
+ }
+
+ if err := validateRelationshipDirection(relationshipPattern.Direction); err != nil {
+ return err
+ }
+
+ return collectModelErrors(relationshipPattern)
+}
+
+func projectionItemFromValue(value any) (*cypher.ProjectionItem, error) {
+ if projectionItem, typeOK := value.(*cypher.ProjectionItem); typeOK {
+ if projectionItem == nil {
+ return nil, fmt.Errorf("projection item is nil")
+ }
+
+ if err := validateExpressionValue(projectionItem.Expression, "projection item"); err != nil {
+ return nil, err
+ }
+
+ if projectionItem.Alias != nil {
+ if err := validateCypherSymbol(projectionItem.Alias.Symbol, "projection alias"); err != nil {
+ return nil, err
+ }
+ }
+
+ if err := collectModelErrors(projectionItem); err != nil {
+ return nil, err
+ }
+
+ return copyProjectionItem(projectionItem), nil
+ }
+
+ if expression, err := projectionExpression(value); err != nil {
+ return nil, err
+ } else {
+ return cypher.NewProjectionItemWithExpr(copyExpression(expression)), nil
+ }
+}
+
+func sortItemFromValue(value any) (*cypher.SortItem, error) {
+ if sortItem, typeOK := value.(*cypher.SortItem); typeOK {
+ if sortItem == nil {
+ return nil, fmt.Errorf("sort item is nil")
+ }
+
+ if err := validateExpressionValue(sortItem.Expression, "sort item"); err != nil {
+ return nil, err
+ }
+
+ if err := collectModelErrors(sortItem); err != nil {
+ return nil, err
+ }
+
+ return copySortItem(sortItem), nil
+ }
+
+ if expression, err := projectionExpression(value); err != nil {
+ return nil, err
+ } else {
+ return &cypher.SortItem{
+ Ascending: true,
+ Expression: copyExpression(expression),
+ }, nil
+ }
+}
+
+func projectionItemsFromReturn(returnClause *cypher.Return) ([]*cypher.ProjectionItem, error) {
+ if returnClause == nil {
+ return nil, fmt.Errorf("return clause is nil")
+ }
+
+ if returnClause.Projection == nil {
+ return nil, fmt.Errorf("return clause has nil projection")
+ }
+
+ if err := validateProjectionMetadata(returnClause.Projection); err != nil {
+ return nil, err
+ }
+
+ projectionItems := make([]*cypher.ProjectionItem, 0, len(returnClause.Projection.Items))
+
+ for _, returnItem := range returnClause.Projection.Items {
+ if projectionItem, err := projectionItemFromValue(returnItem); err != nil {
+ return nil, err
+ } else {
+ projectionItems = append(projectionItems, projectionItem)
+ }
+ }
+
+ return projectionItems, nil
+}
+
+func validateProjectionMetadata(projection *cypher.Projection) error {
+ if projection.Order != nil {
+ if _, err := sortItemsFromOrder(projection.Order); err != nil {
+ return err
+ }
+ }
+
+ if projection.Skip != nil {
+ if err := validateExpressionValue(projection.Skip.Value, "projection skip"); err != nil {
+ return err
+ }
+ }
+
+ if projection.Limit != nil {
+ if err := validateExpressionValue(projection.Limit.Value, "projection limit"); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func sortItemsFromOrder(order *cypher.Order) ([]*cypher.SortItem, error) {
+ if order == nil {
+ return nil, fmt.Errorf("order is nil")
+ }
+
+ sortItems := make([]*cypher.SortItem, 0, len(order.Items))
+
+ for _, sortItem := range order.Items {
+ if normalizedSortItem, err := sortItemFromValue(sortItem); err != nil {
+ return nil, err
+ } else {
+ sortItems = append(sortItems, normalizedSortItem)
+ }
+ }
+
+ return sortItems, nil
+}
+
+type identifierSet struct {
+ identifiers map[string]struct{}
+}
+
+func newIdentifierSet(identifiers ...string) *identifierSet {
+ set := &identifierSet{
+ identifiers: map[string]struct{}{},
+ }
+
+ for _, identifier := range identifiers {
+ set.Add(identifier)
+ }
+
+ return set
+}
+
+func (s *identifierSet) Add(identifier string) {
+ s.identifiers[identifier] = struct{}{}
+}
+
+func (s *identifierSet) Len() int {
+ return len(s.identifiers)
+}
+
+func (s *identifierSet) Clone() *identifierSet {
+ clone := newIdentifierSet()
+ clone.Or(s)
+ return clone
+}
+
+func (s *identifierSet) Or(other *identifierSet) {
+ if s == nil || other == nil {
+ return
+ }
+
+ for otherIdentifier := range other.identifiers {
+ s.identifiers[otherIdentifier] = struct{}{}
+ }
+}
+
+func (s *identifierSet) Remove(other *identifierSet) {
+ if s == nil || other == nil {
+ return
+ }
+
+ for otherIdentifier := range other.identifiers {
+ delete(s.identifiers, otherIdentifier)
+ }
+}
+
+func (s *identifierSet) Contains(identifier string) bool {
+ _, containsIdentifier := s.identifiers[identifier]
+ return containsIdentifier
+}
+
+func (s *identifierSet) FirstOutside(allowed *identifierSet) (string, bool) {
+ var identifiers []string
+
+ for identifier := range s.identifiers {
+ if !allowed.Contains(identifier) {
+ identifiers = append(identifiers, identifier)
+ }
+ }
+
+ sort.Strings(identifiers)
+
+ if len(identifiers) == 0 {
+ return "", false
+ }
+
+ return identifiers[0], true
+}
+
+func (s *identifierSet) CollectFromExpression(expr cypher.Expression) error {
+ if exprIdentifiers, err := extractCypherIdentifiers(expr); err != nil {
+ return err
+ } else {
+ s.Or(exprIdentifiers)
+ return nil
+ }
+}
+
+func (s *identifierSet) CollectFromValue(value any) error {
+ switch typedValue := value.(type) {
+ case nil:
+ return nil
+
+ case QualifiedExpression:
+ return s.CollectFromExpression(typedValue.qualifier())
+
+ case *cypher.ProjectionItem:
+ if projectionItem, err := projectionItemFromValue(typedValue); err != nil {
+ return err
+ } else {
+ return s.CollectFromExpression(projectionItem)
+ }
+
+ case *cypher.Return:
+ if projectionItems, err := projectionItemsFromReturn(typedValue); err != nil {
+ return err
+ } else {
+ for _, projectionItem := range projectionItems {
+ if err := s.CollectFromExpression(projectionItem); err != nil {
+ return err
+ }
+ }
+ }
+
+ if err := s.CollectFromProjectionMetadata(typedValue.Projection); err != nil {
+ return err
+ }
+
+ case *cypher.Order:
+ if sortItems, err := sortItemsFromOrder(typedValue); err != nil {
+ return err
+ } else {
+ for _, sortItem := range sortItems {
+ if err := s.CollectFromExpression(sortItem); err != nil {
+ return err
+ }
+ }
+ }
+
+ case *cypher.SortItem:
+ if sortItem, err := sortItemFromValue(typedValue); err != nil {
+ return err
+ } else {
+ return s.CollectFromExpression(sortItem)
+ }
+
+ case *cypher.Set:
+ if setItems, err := setItemsFromSet(typedValue); err != nil {
+ return err
+ } else {
+ for _, setItem := range setItems {
+ if err := s.CollectFromExpression(setItem); err != nil {
+ return err
+ }
+ }
+ }
+
+ case *cypher.SetItem:
+ if setItem, err := setItemFromValue(typedValue); err != nil {
+ return err
+ } else {
+ return s.CollectFromExpression(setItem)
+ }
+
+ case *cypher.Remove:
+ if removeItems, err := removeItemsFromRemove(typedValue); err != nil {
+ return err
+ } else {
+ for _, removeItem := range removeItems {
+ if err := s.CollectFromValue(removeItem); err != nil {
+ return err
+ }
+ }
+ }
+
+ case *cypher.RemoveItem:
+ if removeItem, err := removeItemFromValue(typedValue); err != nil {
+ return err
+ } else if removeItem.KindMatcher != nil {
+ return s.CollectFromExpression(removeItem.KindMatcher)
+ } else {
+ return s.CollectFromExpression(removeItem)
+ }
+
+ case *cypher.NodePattern:
+ if err := validateNodePattern(typedValue); err != nil {
+ return err
+ }
+
+ return s.CollectFromExpression(typedValue)
+
+ case *cypher.RelationshipPattern:
+ if err := validateRelationshipPattern(typedValue); err != nil {
+ return err
+ }
+
+ return s.CollectFromExpression(typedValue)
+
+ case *cypher.Variable:
+ return s.CollectFromExpression(typedValue)
+
+ case *cypher.FunctionInvocation:
+ return s.CollectFromExpression(typedValue)
+
+ case *cypher.PropertyLookup:
+ return s.CollectFromExpression(typedValue)
+
+ case []any:
+ for _, item := range typedValue {
+ if err := s.CollectFromValue(item); err != nil {
+ return err
+ }
+ }
+
+ case []cypher.SyntaxNode:
+ for _, item := range typedValue {
+ if err := s.CollectFromValue(item); err != nil {
+ return err
+ }
+ }
+
+ case []cypher.Expression:
+ for _, item := range typedValue {
+ if err := s.CollectFromValue(item); err != nil {
+ return err
+ }
+ }
+
+ case []*cypher.SetItem:
+ for _, item := range typedValue {
+ if err := s.CollectFromValue(item); err != nil {
+ return err
+ }
+ }
+
+ case []*cypher.RemoveItem:
+ for _, item := range typedValue {
+ if err := s.CollectFromValue(item); err != nil {
+ return err
+ }
+ }
+
+ default:
+ return nil
+ }
+
+ return nil
+}
+
+func (s *identifierSet) CollectFromProjectionMetadata(projection *cypher.Projection) error {
+ if projection == nil {
+ return nil
+ }
+
+ if projection.Order != nil {
+ if err := s.CollectFromValue(projection.Order); err != nil {
+ return err
+ }
+ }
+
+ if projection.Skip != nil {
+ if err := s.CollectFromExpression(projection.Skip.Value); err != nil {
+ return err
+ }
+ }
+
+ if projection.Limit != nil {
+ if err := s.CollectFromExpression(projection.Limit.Value); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func collectIdentifiersFromValues(values ...any) (*identifierSet, error) {
+ identifiers := newIdentifierSet()
+
+ for _, value := range values {
+ if err := identifiers.CollectFromValue(value); err != nil {
+ return nil, err
+ }
+ }
+
+ return identifiers, nil
+}
+
+type createScope struct {
+ identifiers *identifierSet
+ createsRelationship bool
+}
+
+func collectCreateScope(identifiers runtimeIdentifiers, values ...any) (*createScope, error) {
+ scope := &createScope{
+ identifiers: newIdentifierSet(),
+ }
+
+ for _, value := range values {
+ switch typedValue := value.(type) {
+ case *cypher.RelationshipPattern:
+ if err := validateRelationshipPattern(typedValue); err != nil {
+ return nil, err
+ }
+
+ scope.createsRelationship = true
+ scope.identifiers.Add(identifiers.start)
+ scope.identifiers.Add(identifiers.end)
+
+ if typedValue.Variable != nil {
+ scope.identifiers.Add(typedValue.Variable.Symbol)
+ }
+
+ default:
+ if err := scope.identifiers.CollectFromValue(value); err != nil {
+ return nil, err
+ }
+ }
+ }
+
+ return scope, nil
+}
+
+type identifierExtractor struct {
+ walk.Visitor[cypher.SyntaxNode]
+
+ seen *identifierSet
+}
+
+func newIdentifierExtractor() *identifierExtractor {
+ return &identifierExtractor{
+ Visitor: walk.NewVisitor[cypher.SyntaxNode](),
+ seen: newIdentifierSet(),
+ }
+}
+
+func (s *identifierExtractor) Enter(node cypher.SyntaxNode) {
+ switch typedNode := node.(type) {
+ case *cypher.Variable:
+ s.seen.Add(typedNode.Symbol)
+
+ case *cypher.NodePattern:
+ if typedNode.Variable != nil {
+ s.seen.Add(typedNode.Variable.Symbol)
+ }
+
+ case *cypher.RelationshipPattern:
+ if typedNode.Variable != nil {
+ s.seen.Add(typedNode.Variable.Symbol)
+ }
+
+ case *cypher.PatternPart:
+ if typedNode.Variable != nil {
+ s.seen.Add(typedNode.Variable.Symbol)
+ }
+ }
+}
+
+func extractCypherIdentifiers(expression cypher.Expression) (*identifierSet, error) {
+ var (
+ identifierExtractorVisitor = newIdentifierExtractor()
+ err = walk.Cypher(expression, identifierExtractorVisitor)
+ )
+
+ return identifierExtractorVisitor.seen, err
+}
+
+func collectModelErrors(node cypher.SyntaxNode) error {
+ var modelErrors []error
+
+ if err := walk.Cypher(node, walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, _ walk.VisitorHandler) {
+ if errorNode, typeOK := node.(cypher.Fallible); typeOK {
+ modelErrors = append(modelErrors, errorNode.Errors()...)
+ }
+ })); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ return errors.Join(modelErrors...)
+}
+
+func collectModelErrorsFromKnownValues(values ...any) error {
+ var modelErrors []error
+
+ for _, value := range values {
+ switch typedValue := value.(type) {
+ case nil:
+ continue
+
+ case []any:
+ if err := collectModelErrorsFromKnownValues(typedValue...); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case []cypher.SyntaxNode:
+ if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case []cypher.Expression:
+ if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case []*cypher.SetItem:
+ if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case []*cypher.RemoveItem:
+ if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case []*cypher.ProjectionItem:
+ if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case []*cypher.SortItem:
+ if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case *cypher.NodePattern:
+ if err := validateNodePattern(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case *cypher.Order:
+ if _, err := sortItemsFromOrder(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ } else if err := collectModelErrors(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case *cypher.ProjectionItem:
+ if _, err := projectionItemFromValue(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ } else if err := collectModelErrors(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case *cypher.Return:
+ if _, err := projectionItemsFromReturn(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case *cypher.RelationshipPattern:
+ if err := validateRelationshipPattern(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case *cypher.Remove:
+ if _, err := removeItemsFromRemove(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ } else if err := collectModelErrors(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case *cypher.RemoveItem:
+ if _, err := removeItemFromValue(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ } else if err := collectModelErrors(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case *cypher.Set:
+ if _, err := setItemsFromSet(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ } else if err := collectModelErrors(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case *cypher.SetItem:
+ if _, err := setItemFromValue(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ } else if err := collectModelErrors(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case *cypher.SortItem:
+ if _, err := sortItemFromValue(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ } else if err := collectModelErrors(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+
+ case *cypher.ArithmeticExpression,
+ *cypher.Comparison,
+ *cypher.Conjunction,
+ *cypher.Create,
+ *cypher.Delete,
+ *cypher.Disjunction,
+ *cypher.ExclusiveDisjunction,
+ *cypher.FilterExpression,
+ *cypher.FunctionInvocation,
+ *cypher.IDInCollection,
+ *cypher.KindMatcher,
+ *cypher.ListLiteral,
+ *cypher.Negation,
+ *cypher.Parenthetical,
+ *cypher.PatternPredicate,
+ *cypher.PropertyLookup,
+ *cypher.UnaryAddOrSubtractExpression,
+ *cypher.UpdatingClause,
+ *cypher.Variable:
+ if err := collectModelErrors(typedValue); err != nil {
+ modelErrors = append(modelErrors, err)
+ }
+ }
+ }
+
+ return errors.Join(modelErrors...)
+}
+
+func anySlice[T any](values []T) []any {
+ items := make([]any, len(values))
+
+ for idx, value := range values {
+ items[idx] = value
+ }
+
+ return items
+}
+
+type parameterMaterializer struct {
+ walk.Visitor[cypher.SyntaxNode]
+
+ parameters map[string]any
+ nextIndex int
+}
+
+func newParameterMaterializer(parameters map[string]any) *parameterMaterializer {
+ materializedParameters := map[string]any{}
+
+ for symbol, value := range parameters {
+ materializedParameters[symbol] = value
+ }
+
+ return ¶meterMaterializer{
+ Visitor: walk.NewVisitor[cypher.SyntaxNode](),
+ parameters: materializedParameters,
+ }
+}
+
+func (s *parameterMaterializer) nextSymbol() string {
+ for {
+ symbol := "p" + strconv.Itoa(s.nextIndex)
+ s.nextIndex++
+
+ if _, taken := s.parameters[symbol]; !taken {
+ return symbol
+ }
+ }
+}
+
+func (s *parameterMaterializer) Enter(node cypher.SyntaxNode) {
+ parameter, typeOK := node.(*cypher.Parameter)
+ if !typeOK {
+ return
+ }
+
+ if parameter.Symbol == "" {
+ parameter.Symbol = s.nextSymbol()
+ }
+
+ if existingValue, exists := s.parameters[parameter.Symbol]; exists && !reflect.DeepEqual(existingValue, parameter.Value) {
+ s.SetErrorf("parameter %s is bound to multiple values", parameter.Symbol)
+ return
+ }
+
+ s.parameters[parameter.Symbol] = parameter.Value
+}
+
+type namedParameterCollector struct {
+ walk.Visitor[cypher.SyntaxNode]
+
+ parameters map[string]any
+}
+
+func newNamedParameterCollector() *namedParameterCollector {
+ return &namedParameterCollector{
+ Visitor: walk.NewVisitor[cypher.SyntaxNode](),
+ parameters: map[string]any{},
+ }
+}
+
+func (s *namedParameterCollector) Enter(node cypher.SyntaxNode) {
+ parameter, typeOK := node.(*cypher.Parameter)
+ if !typeOK || parameter.Symbol == "" {
+ return
+ }
+
+ if err := validateCypherSymbol(parameter.Symbol, "parameter"); err != nil {
+ s.SetError(err)
+ return
+ }
+
+ if existingValue, exists := s.parameters[parameter.Symbol]; exists && !reflect.DeepEqual(existingValue, parameter.Value) {
+ s.SetErrorf("parameter %s is bound to multiple values", parameter.Symbol)
+ return
+ }
+
+ s.parameters[parameter.Symbol] = parameter.Value
+}
+
+func collectNamedParameters(query *cypher.RegularQuery) (map[string]any, error) {
+ collector := newNamedParameterCollector()
+
+ if err := walk.Cypher(query, collector); err != nil {
+ return nil, err
+ }
+
+ return collector.parameters, nil
+}
+
+func materializeParameters(query *cypher.RegularQuery) (map[string]any, error) {
+ namedParameters, err := collectNamedParameters(query)
+ if err != nil {
+ return nil, err
+ }
+
+ materializer := newParameterMaterializer(namedParameters)
+
+ if err := walk.Cypher(query, materializer); err != nil {
+ return nil, err
+ }
+
+ return materializer.parameters, nil
+}
diff --git a/query/v2/util_internal_test.go b/query/v2/util_internal_test.go
new file mode 100644
index 00000000..1ae62236
--- /dev/null
+++ b/query/v2/util_internal_test.go
@@ -0,0 +1,20 @@
+package v2
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestIdentifierSetOrAndRemoveNilSafe(t *testing.T) {
+ var nilSet *identifierSet
+
+ nilSet.Or(newIdentifierSet("ignored"))
+ nilSet.Remove(newIdentifierSet("ignored"))
+
+ set := newIdentifierSet("kept")
+ set.Or(nil)
+ set.Remove(nil)
+
+ require.True(t, set.Contains("kept"))
+}