diff --git a/.gitignore b/.gitignore index b3e5e211..29ac9edf 100644 --- a/.gitignore +++ b/.gitignore @@ -7,5 +7,8 @@ # Local integration test datasets integration/testdata/local/ +# Local benchmark comparison output +.bench/ + # Local test and metric artifacts .coverage/ diff --git a/Makefile b/Makefile index 81dd8363..71b69d9e 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,12 @@ THIS_FILE := $(lastword $(MAKEFILE_LIST)) # Go configuration GO_CMD ?= go CGO_ENABLED ?= 0 +BENCH ?= . +BENCH_COUNT ?= 10 +BENCH_TIME ?= 1s +BENCH_BASE ?= main +BENCH_TARGET ?= HEAD +BENCH_KIND ?= all # Main packages to test/build MAIN_PACKAGES := $(shell $(GO_CMD) list ./...) @@ -92,6 +98,14 @@ test_integration: @echo "Running all integration tests..." @$(GO_CMD) test -tags 'manual_integration integration' -race -cover -count=1 -p=1 -parallel=1 $(MAIN_PACKAGES) +test_bench: + @echo "Running benchmarks..." + @$(GO_CMD) test -run '^$$' -bench '$(BENCH)' -benchmem -count=$(BENCH_COUNT) -benchtime=$(BENCH_TIME) $(MAIN_PACKAGES) + +bench_diff: + @echo "Running benchmark diff..." + @$(GO_CMD) run ./cmd/benchdiff --base '$(BENCH_BASE)' --target '$(BENCH_TARGET)' --kind '$(BENCH_KIND)' --bench '$(BENCH)' --bench-count '$(BENCH_COUNT)' --benchtime '$(BENCH_TIME)' $(BENCHDIFF_ARGS) + test_neo4j: @echo "Running Neo4j integration tests..." 
@$(GO_CMD) test -tags integration -race -cover -count=1 -p=1 -parallel=1 $(MAIN_PACKAGES) @@ -216,6 +230,7 @@ help: @echo " test_all - Run all tests including integration tests" @echo " test_integration - Run all integration tests" @echo " test_bench - Run benchmark test" + @echo " bench_diff - Compare benchmarks between commits" @echo " test_neo4j - Run Neo4j integration tests" @echo " test_pg - Run PostgreSQL integration tests" @echo " test_update - Update test cases" diff --git a/README.md b/README.md index e1aad7bb..f2eef241 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,35 @@ export CONNECTION_STRING="neo4j://neo4j:weneedbetterpasswords@localhost:7687" Use `make test` for unit tests only and `make test_integration` for integration tests only. +### Benchmarking + +Run the package benchmark suite with: + +```bash +make test_bench +``` + +Use `cmd/benchdiff` to compare benchmarks between two committed refs without changing the active worktree: + +```bash +go run ./cmd/benchdiff -base main -target HEAD -kind unit +``` + +For integration benchmark comparisons, provide the same `CONNECTION_STRING` used by integration tests: + +```bash +export CONNECTION_STRING="postgresql://dawgs:weneedbetterpasswords@localhost:65432/dawgs" +go run ./cmd/benchdiff -base main -target HEAD -kind all -driver pg -fail-regression 10% +``` + +The harness writes raw outputs and a Markdown report under `.bench/runs/` by default. The report begins with comparison +findings, includes the raw `benchstat` output for each benchmark suite, and ends with a table of all captured benchmark +numbers. + +The integration benchmark runner includes committed `base` and `traversal_shapes` datasets by default. The traversal +shape suite checks expected result counts for chain, fanout, bounded cycle, disconnected, edge-kind-selective, and +multi-path shortest-path scenarios before recording timings. 
+ ### Test Metrics `make test` writes unit test coverage artifacts under `.coverage/`: diff --git a/cmd/benchdiff/README.md b/cmd/benchdiff/README.md new file mode 100644 index 00000000..41774194 --- /dev/null +++ b/cmd/benchdiff/README.md @@ -0,0 +1,52 @@ +# Benchdiff + +Compares the existing benchmark suites between two committed git refs without changing the active worktree. + +## Usage + +```bash +# Unit Go benchmarks only +go run ./cmd/benchdiff -base main -target HEAD -kind unit + +# Unit and integration benchmarks +export CONNECTION_STRING="postgresql://dawgs:weneedbetterpasswords@localhost:65432/dawgs" +go run ./cmd/benchdiff -base main -target HEAD -kind all -driver pg + +# Fail if a benchmark median regresses by more than 10% +go run ./cmd/benchdiff -base main -target HEAD -kind unit -fail-regression 10% +``` + +`benchdiff` creates detached worktrees under `.bench/`, runs each selected benchmark suite, writes raw output, and +produces a Markdown report. The report starts with comparison findings, including median regressions, improvements, +unchanged counts, and benchmark names that only appeared in one ref. It ends with an `All Executed Benchmark Numbers` +section that lists the median, percent change, and sample counts for every benchmark captured in either ref. Worktrees +are removed by default after the run; pass `-keep-worktrees` to preserve them. 
+ +## Flags + +| Flag | Default | Description | +|------|---------|-------------| +| `-base` | `main` | Base git ref | +| `-target` | `HEAD` | Target git ref | +| `-kind` | `all` | Benchmark kind (`all`, `unit`, `integration`) | +| `-packages` | `./...` | Package list for Go benchmarks | +| `-bench` | `.` | Go benchmark regexp | +| `-bench-count` | `10` | Go benchmark repetition count | +| `-benchtime` | `1s` | Go benchmark benchtime | +| `-driver` | `pg` | Integration benchmark database driver | +| `-connection` | | Integration connection string (or `CONNECTION_STRING`) | +| `-dataset` | | Run only this integration dataset | +| `-local-dataset` | | Add a local integration dataset | +| `-dataset-dir` | `integration/testdata` | Integration testdata directory | +| `-integration-iterations` | `10` | Timed iterations per integration scenario | +| `-out` | `.bench/runs/<base>..<target>-<timestamp>` | Output directory | +| `-benchstat` | `auto` | `benchstat` command, `auto`, or `none` | +| `-fail-regression` | `0` | Median regression percentage that fails the command | +| `-keep-worktrees` | `false` | Preserve temporary worktrees | + +If `benchstat` is not on `PATH` and `-benchstat auto` is used, the harness falls back to +`go run golang.org/x/perf/cmd/benchstat@latest`. + +Integration comparisons use native `cmd/benchmark -format benchfmt` when both refs support it. If either ref predates +that flag, the harness runs both refs in Markdown compatibility mode and compares each scenario's median as a single +`ns/op` sample. diff --git a/cmd/benchdiff/benchfmt.go b/cmd/benchdiff/benchfmt.go new file mode 100644 index 00000000..cdc5a68d --- /dev/null +++ b/cmd/benchdiff/benchfmt.go @@ -0,0 +1,325 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. 
// benchmarkLinePattern matches one Go benchmark result line. It captures the
// benchmark name (including any "-N" GOMAXPROCS suffix) and the ns/op value.
var benchmarkLinePattern = regexp.MustCompile(`^(Benchmark\S+)\s+\d+\s+([0-9]+(?:\.[0-9]+)?)\s+ns/op\b`)

// benchmarkSamples maps a benchmark name to every ns/op sample read for it, in
// input order.
type benchmarkSamples map[string][]float64

// comparisonFindings is the outcome of comparing base and target samples:
// how many benchmarks were compared, which regressed or improved, how many
// were unchanged, which names appeared on one side only, and a per-benchmark
// result list.
type comparisonFindings struct {
	Compared     int
	Regressions  []benchmarkFinding
	Improvements []benchmarkFinding
	Unchanged    int
	OnlyBase     []string
	OnlyTarget   []string
	Results      []benchmarkResult
}

// benchmarkFinding describes one benchmark whose median changed between base
// and target. DeltaPercent is relative to the base median.
type benchmarkFinding struct {
	Name           string
	BaseMedianNS   float64
	TargetMedianNS float64
	DeltaPercent   float64
}

// benchmarkResult carries the medians and sample counts of one benchmark on
// both sides. HasBase and HasTarget report whether any sample was seen.
type benchmarkResult struct {
	Name           string
	BaseMedianNS   float64
	TargetMedianNS float64
	DeltaPercent   float64
	BaseSamples    int
	TargetSamples  int
	HasBase        bool
	HasTarget      bool
}

// regression is a finding whose slowdown exceeded the configured threshold.
type regression struct {
	Name           string
	BaseMedianNS   float64
	TargetMedianNS float64
	Percent        float64
}

// parseBenchfmtNS extracts ns/op samples from Go benchmark output.
//
// Lines that do not match benchmarkLinePattern, or whose value does not parse
// as a float, are ignored. The returned map is never nil.
func parseBenchfmtNS(data []byte) benchmarkSamples {
	samples := benchmarkSamples{}

	// A default Scanner stops at the first line longer than
	// bufio.MaxScanTokenSize, which would silently drop every benchmark after
	// it. Captured output may contain arbitrarily long log lines, so allow a
	// token as large as the whole input.
	limit := len(data) + 1
	if limit < bufio.MaxScanTokenSize {
		limit = bufio.MaxScanTokenSize
	}
	scanner := bufio.NewScanner(bytes.NewReader(data))
	scanner.Buffer(make([]byte, 0, 4096), limit)

	for scanner.Scan() {
		matches := benchmarkLinePattern.FindStringSubmatch(scanner.Text())
		if len(matches) != 3 {
			continue
		}

		ns, err := strconv.ParseFloat(matches[2], 64)
		if err != nil {
			continue
		}

		samples[matches[1]] = append(samples[matches[1]], ns)
	}

	return samples
}
+ return parseBenchfmtNS(data), nil +} + +func summarizeFindings(base, target benchmarkSamples) comparisonFindings { + findings := comparisonFindings{} + names := map[string]struct{}{} + + for name := range base { + names[name] = struct{}{} + } + for name := range target { + names[name] = struct{}{} + } + + for name := range names { + baseValues := base[name] + targetValues := target[name] + result := benchmarkResult{ + Name: name, + BaseSamples: len(baseValues), + TargetSamples: len(targetValues), + HasBase: len(baseValues) > 0, + HasTarget: len(targetValues) > 0, + } + + switch { + case len(baseValues) == 0: + result.TargetMedianNS = median(targetValues) + findings.Results = append(findings.Results, result) + findings.OnlyTarget = append(findings.OnlyTarget, name) + continue + case len(targetValues) == 0: + result.BaseMedianNS = median(baseValues) + findings.Results = append(findings.Results, result) + findings.OnlyBase = append(findings.OnlyBase, name) + continue + } + + baseMedian := median(baseValues) + targetMedian := median(targetValues) + if baseMedian <= 0 { + findings.Results = append(findings.Results, result) + continue + } + + result.BaseMedianNS = baseMedian + result.TargetMedianNS = targetMedian + findings.Compared++ + deltaPercent := ((targetMedian - baseMedian) / baseMedian) * 100 + result.DeltaPercent = deltaPercent + findings.Results = append(findings.Results, result) + + finding := benchmarkFinding{ + Name: name, + BaseMedianNS: baseMedian, + TargetMedianNS: targetMedian, + DeltaPercent: deltaPercent, + } + + switch { + case deltaPercent > 0: + findings.Regressions = append(findings.Regressions, finding) + case deltaPercent < 0: + findings.Improvements = append(findings.Improvements, finding) + default: + findings.Unchanged++ + } + } + + sort.Slice(findings.Regressions, func(i, j int) bool { + return findings.Regressions[i].DeltaPercent > findings.Regressions[j].DeltaPercent + }) + sort.Slice(findings.Improvements, func(i, j int) bool { + return 
findings.Improvements[i].DeltaPercent < findings.Improvements[j].DeltaPercent + }) + sort.Strings(findings.OnlyBase) + sort.Strings(findings.OnlyTarget) + sort.Slice(findings.Results, func(i, j int) bool { + return findings.Results[i].Name < findings.Results[j].Name + }) + + return findings +} + +func findingsForFiles(baseFile, targetFile string) (comparisonFindings, error) { + base, err := parseBenchfmtNSFile(baseFile) + if err != nil { + return comparisonFindings{}, err + } + target, err := parseBenchfmtNSFile(targetFile) + if err != nil { + return comparisonFindings{}, err + } + + return summarizeFindings(base, target), nil +} + +func (findings comparisonFindings) regressionsOver(threshold float64) []regression { + if threshold <= 0 { + return nil + } + + var regressions []regression + for _, finding := range findings.Regressions { + if finding.DeltaPercent <= threshold { + continue + } + + regressions = append(regressions, regression{ + Name: finding.Name, + BaseMedianNS: finding.BaseMedianNS, + TargetMedianNS: finding.TargetMedianNS, + Percent: finding.DeltaPercent, + }) + } + + return regressions +} + +func median(values []float64) float64 { + sorted := append([]float64(nil), values...) 
// parseBenchmarkDuration parses a duration cell such as "0.14ms" or "464ms".
//
// The value is a non-negative decimal number followed by one unit: ns, us
// (also spelled µs or μs), ms or s. Whitespace around the value and between
// number and unit is ignored. The result is rounded to the nearest nanosecond.
// An empty cell, a "-" placeholder, a malformed number or any other unit
// yields an error.
func parseBenchmarkDuration(value string) (time.Duration, error) {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" || trimmed == "-" {
		return 0, fmt.Errorf("empty benchmark duration")
	}

	// The number is the leading run of digits and dots; the rest is the unit.
	unitStart := strings.IndexFunc(trimmed, func(r rune) bool {
		return (r < '0' || r > '9') && r != '.'
	})
	if unitStart < 0 {
		unitStart = len(trimmed)
	}

	number, err := strconv.ParseFloat(strings.TrimSpace(trimmed[:unitStart]), 64)
	if err != nil {
		return 0, err
	}

	var scale time.Duration
	switch unit := strings.TrimSpace(trimmed[unitStart:]); unit {
	case "ns":
		scale = time.Nanosecond
	case "us", "µs", "μs":
		// time.Duration formats microseconds with the micro sign, so accept
		// both it and the Greek mu alongside the ASCII spelling.
		scale = time.Microsecond
	case "ms":
		scale = time.Millisecond
	case "s":
		scale = time.Second
	default:
		return 0, fmt.Errorf("unsupported benchmark duration unit %q", unit)
	}

	return time.Duration(math.Round(number * float64(scale))), nil
}
+// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "bytes" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestParseBenchfmtNS(t *testing.T) { + samples := parseBenchfmtNS([]byte(` +goos: linux +BenchmarkThing-12 10 100.5 ns/op 1 B/op +BenchmarkThing-12 10 120 ns/op 1 B/op +BenchmarkOther/sub-12 1 200 ns/op +`)) + + require.Equal(t, []float64{100.5, 120}, samples["BenchmarkThing-12"]) + require.Equal(t, []float64{200}, samples["BenchmarkOther/sub-12"]) +} + +func TestSummarizeFindings(t *testing.T) { + base := benchmarkSamples{ + "BenchmarkRegression-12": {100, 110, 120}, + "BenchmarkImprovement-12": {200, 200, 200}, + "BenchmarkSame-12": {100, 100, 100}, + "BenchmarkOnlyBase-12": {50}, + } + target := benchmarkSamples{ + "BenchmarkRegression-12": {140, 150, 160}, + "BenchmarkImprovement-12": {100, 100, 100}, + "BenchmarkSame-12": {100, 100, 100}, + "BenchmarkOnlyTarget-12": {75}, + } + + findings := summarizeFindings(base, target) + + require.Equal(t, 3, findings.Compared) + require.Equal(t, 1, findings.Unchanged) + require.Equal(t, []string{"BenchmarkOnlyBase-12"}, findings.OnlyBase) + require.Equal(t, []string{"BenchmarkOnlyTarget-12"}, findings.OnlyTarget) + require.Len(t, findings.Results, 5) + require.Len(t, findings.Regressions, 1) + require.Equal(t, "BenchmarkRegression-12", findings.Regressions[0].Name) + require.InDelta(t, 36.36, findings.Regressions[0].DeltaPercent, 0.01) + require.Len(t, findings.Improvements, 1) + require.Equal(t, "BenchmarkImprovement-12", findings.Improvements[0].Name) + require.Equal(t, -50.0, findings.Improvements[0].DeltaPercent) + + regressions := findings.regressionsOver(10) + require.Len(t, regressions, 1) + require.Equal(t, "BenchmarkRegression-12", regressions[0].Name) + require.InDelta(t, 36.36, regressions[0].Percent, 0.01) + + onlyBase := requireBenchmarkResult(t, findings.Results, "BenchmarkOnlyBase-12") + require.True(t, onlyBase.HasBase) + require.False(t, 
onlyBase.HasTarget) + require.Equal(t, 1, onlyBase.BaseSamples) + require.Equal(t, 50.0, onlyBase.BaseMedianNS) + + onlyTarget := requireBenchmarkResult(t, findings.Results, "BenchmarkOnlyTarget-12") + require.False(t, onlyTarget.HasBase) + require.True(t, onlyTarget.HasTarget) + require.Equal(t, 1, onlyTarget.TargetSamples) + require.Equal(t, 75.0, onlyTarget.TargetMedianNS) +} + +func TestParseBenchmarkMarkdown(t *testing.T) { + rows := parseBenchmarkMarkdown([]byte(` +| Query | Dataset | Median | P95 | Max | +|-------|---------|-------:|----:|----:| +| Match Nodes | base | 0.14ms | 0.22ms | 0.31ms | +| Match Edges | base | 464ms | 604ms | 604ms | +`)) + + require.Equal(t, []markdownBenchmarkRow{ + {Query: "Match Nodes", Dataset: "base", Median: 140 * time.Microsecond}, + {Query: "Match Edges", Dataset: "base", Median: 464 * time.Millisecond}, + }, rows) +} + +func TestWriteIntegrationBenchfmt(t *testing.T) { + var out bytes.Buffer + rows := []markdownBenchmarkRow{{ + Query: "Shortest Paths / n1 -> n3", + Dataset: "base", + Median: time.Millisecond, + }} + + require.NoError(t, writeIntegrationBenchfmt(&out, "pg", rows)) + require.Contains(t, out.String(), "BenchmarkDawgsIntegration/pg/base/Shortest_Paths_/_n1_-_n3-") + require.Contains(t, out.String(), "\t1\t1000000 ns/op") +} + +func TestParseRegressionThreshold(t *testing.T) { + threshold, err := parseRegressionThreshold("10%") + require.NoError(t, err) + require.Equal(t, 10.0, threshold) + + threshold, err = parseRegressionThreshold("2.5") + require.NoError(t, err) + require.Equal(t, 2.5, threshold) + + _, err = parseRegressionThreshold("-1") + require.Error(t, err) +} + +func TestValidateBenchtime(t *testing.T) { + require.NoError(t, validateBenchtime("1s")) + require.NoError(t, validateBenchtime("100x")) + require.Error(t, validateBenchtime("0x")) + require.Error(t, validateBenchtime("soon")) +} + +func requireBenchmarkResult(t *testing.T, results []benchmarkResult, name string) benchmarkResult { + 
t.Helper() + + for _, result := range results { + if result.Name == name { + return result + } + } + + require.Failf(t, "benchmark result not found", "%s", name) + return benchmarkResult{} +} diff --git a/cmd/benchdiff/command.go b/cmd/benchdiff/command.go new file mode 100644 index 00000000..eae51723 --- /dev/null +++ b/cmd/benchdiff/command.go @@ -0,0 +1,91 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "strings" +) + +func gitOutput(ctx context.Context, dir string, args ...string) (string, error) { + output, err := runCommand(ctx, dir, nil, "git", args...) + if err != nil { + return "", err + } + + return strings.TrimSpace(string(output)), nil +} + +func runCommand(ctx context.Context, dir string, env []string, name string, args ...string) ([]byte, error) { + cmd := exec.CommandContext(ctx, name, args...) + cmd.Dir = dir + cmd.Env = os.Environ() + if len(env) > 0 { + cmd.Env = append(cmd.Env, env...) 
// commandError reports a failed external command with its name, arguments,
// the underlying error and the output it produced.
type commandError struct {
	Name   string
	Args   []string
	Err    error
	Output []byte
}

// Error renders "<name> <args>: <cause>", followed on a new line by the tail
// of the command output when there is any.
func (err commandError) Error() string {
	var builder strings.Builder
	builder.WriteString(err.Name)
	if len(err.Args) > 0 {
		builder.WriteByte(' ')
		builder.WriteString(strings.Join(err.Args, " "))
	}
	builder.WriteString(": ")
	builder.WriteString(err.Err.Error())

	output := bytes.TrimSpace(err.Output)
	if len(output) > 0 {
		builder.WriteString("\n")
		builder.Write(outputTail(output, 4096))
	}

	return builder.String()
}

// Unwrap exposes the underlying error so callers can use errors.Is and
// errors.As, for example to detect an *exec.ExitError or context cancellation.
func (err commandError) Unwrap() error {
	return err.Err
}

// outputTail returns output unchanged when it is at most limit bytes long.
// Otherwise it returns the last limit bytes, prefixed with a line stating how
// many bytes were dropped.
func outputTail(output []byte, limit int) []byte {
	if len(output) <= limit {
		return output
	}

	prefix := []byte(fmt.Sprintf("... truncated %d bytes ...\n", len(output)-limit))
	return append(prefix, output[len(output)-limit:]...)
}
+// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" +) + +func runUnitComparison(ctx context.Context, cfg resolvedConfig, baseWorktree, targetWorktree string) (comparison, error) { + outDir := filepath.Join(cfg.OutDirAbs, "unit") + if err := os.MkdirAll(outDir, 0755); err != nil { + return comparison{}, err + } + + baseFile := filepath.Join(outDir, "base.txt") + targetFile := filepath.Join(outDir, "target.txt") + + if err := runGoBenchmarks(ctx, cfg, baseWorktree, baseFile); err != nil { + return comparison{}, err + } + if err := runGoBenchmarks(ctx, cfg, targetWorktree, targetFile); err != nil { + return comparison{}, err + } + + benchstatOutput, err := runBenchstat(ctx, cfg, baseFile, targetFile) + if err != nil { + return comparison{}, err + } + benchstatFile := filepath.Join(outDir, "benchstat.txt") + if err := os.WriteFile(benchstatFile, benchstatOutput, 0644); err != nil { + return comparison{}, err + } + + findings, err := findingsForFiles(baseFile, targetFile) + if err != nil { + return comparison{}, err + } + regressions := findings.regressionsOver(cfg.Threshold) + + return comparison{ + Name: "Unit Benchmarks", + BaseFile: baseFile, + TargetFile: targetFile, + BenchstatFile: benchstatFile, + Benchstat: string(benchstatOutput), + Findings: findings, + Regressions: regressions, + }, nil +} + +func runGoBenchmarks(ctx context.Context, cfg resolvedConfig, worktree, outputPath string) error { + args := []string{ + "test", + "-run", "^$", + "-bench", cfg.Bench, + "-benchmem", + "-count", strconv.Itoa(cfg.BenchCount), + "-benchtime", cfg.Benchtime, + } + args = append(args, strings.Fields(cfg.Packages)...) + + output, err := runCommand(ctx, worktree, nil, "go", args...) 
+ if writeErr := os.WriteFile(outputPath, output, 0644); writeErr != nil { + return writeErr + } + if err != nil { + return fmt.Errorf("run Go benchmarks in %s: %w", worktree, err) + } + + return nil +} + +func runIntegrationComparison(ctx context.Context, cfg resolvedConfig, baseWorktree, targetWorktree string) (comparison, error) { + outDir := filepath.Join(cfg.OutDirAbs, "integration") + binDir := filepath.Join(cfg.OutDirAbs, "bin") + if err := os.MkdirAll(outDir, 0755); err != nil { + return comparison{}, err + } + if err := os.MkdirAll(binDir, 0755); err != nil { + return comparison{}, err + } + + baseBinary := filepath.Join(binDir, "benchmark-base") + targetBinary := filepath.Join(binDir, "benchmark-target") + if err := buildBenchmarkBinary(ctx, baseWorktree, baseBinary); err != nil { + return comparison{}, err + } + if err := buildBenchmarkBinary(ctx, targetWorktree, targetBinary); err != nil { + return comparison{}, err + } + + baseSupportsBenchfmt := benchmarkBinarySupportsFormat(ctx, baseBinary) + targetSupportsBenchfmt := benchmarkBinarySupportsFormat(ctx, targetBinary) + useNativeBenchfmt := baseSupportsBenchfmt && targetSupportsBenchfmt + + baseFile := filepath.Join(outDir, "base.bench") + targetFile := filepath.Join(outDir, "target.bench") + var notes []string + if useNativeBenchfmt { + if err := runBenchmarkBinary(ctx, cfg, baseWorktree, baseBinary, baseFile, true); err != nil { + return comparison{}, err + } + if err := runBenchmarkBinary(ctx, cfg, targetWorktree, targetBinary, targetFile, true); err != nil { + return comparison{}, err + } + notes = append(notes, "Used native benchfmt output from cmd/benchmark.") + } else { + baseMarkdown := filepath.Join(outDir, "base.md") + targetMarkdown := filepath.Join(outDir, "target.md") + + if err := runBenchmarkBinary(ctx, cfg, baseWorktree, baseBinary, baseMarkdown, false); err != nil { + return comparison{}, err + } + if err := runBenchmarkBinary(ctx, cfg, targetWorktree, targetBinary, targetMarkdown, 
false); err != nil { + return comparison{}, err + } + + if err := markdownFileToBenchfmt(baseMarkdown, baseFile, cfg.Driver); err != nil { + return comparison{}, err + } + if err := markdownFileToBenchfmt(targetMarkdown, targetFile, cfg.Driver); err != nil { + return comparison{}, err + } + + notes = append(notes, "Used Markdown compatibility mode because at least one ref does not support cmd/benchmark -format benchfmt.") + } + + benchstatOutput, err := runBenchstat(ctx, cfg, baseFile, targetFile) + if err != nil { + return comparison{}, err + } + benchstatFile := filepath.Join(outDir, "benchstat.txt") + if err := os.WriteFile(benchstatFile, benchstatOutput, 0644); err != nil { + return comparison{}, err + } + + findings, err := findingsForFiles(baseFile, targetFile) + if err != nil { + return comparison{}, err + } + regressions := findings.regressionsOver(cfg.Threshold) + + return comparison{ + Name: "Integration Benchmarks", + BaseFile: baseFile, + TargetFile: targetFile, + BenchstatFile: benchstatFile, + Benchstat: string(benchstatOutput), + Notes: notes, + Findings: findings, + Regressions: regressions, + }, nil +} + +func buildBenchmarkBinary(ctx context.Context, worktree, outputPath string) error { + output, err := runCommand(ctx, worktree, nil, "go", "build", "-o", outputPath, "./cmd/benchmark") + if err != nil { + return fmt.Errorf("build cmd/benchmark in %s: %w", worktree, err) + } + if len(bytes.TrimSpace(output)) > 0 { + logPath := outputPath + ".log" + if writeErr := os.WriteFile(logPath, output, 0644); writeErr != nil { + return writeErr + } + } + + return nil +} + +func benchmarkBinarySupportsFormat(ctx context.Context, binaryPath string) bool { + output, err := runCommand(ctx, "", nil, binaryPath, "-h") + if err != nil { + return false + } + + return bytes.Contains(output, []byte("-format")) +} + +func runBenchmarkBinary(ctx context.Context, cfg resolvedConfig, worktree, binaryPath, outputPath string, benchfmt bool) error { + args := []string{ + 
"-driver", cfg.Driver, + "-iterations", strconv.Itoa(cfg.IntegrationIterations), + "-dataset-dir", cfg.DatasetDirAbs, + "-output", outputPath, + } + if benchfmt { + args = append(args, "-format", "benchfmt") + } + if cfg.Dataset != "" { + args = append(args, "-dataset", cfg.Dataset) + } + if cfg.LocalDataset != "" { + args = append(args, "-local-dataset", cfg.LocalDataset) + } + + output, err := runCommand(ctx, worktree, []string{"CONNECTION_STRING=" + cfg.Connection}, binaryPath, args...) + logPath := outputPath + ".log" + if writeErr := os.WriteFile(logPath, output, 0644); writeErr != nil { + return writeErr + } + if err != nil { + return fmt.Errorf("run integration benchmark in %s: %w", worktree, err) + } + + return nil +} + +func markdownFileToBenchfmt(markdownPath, benchfmtPath, driver string) error { + data, err := os.ReadFile(markdownPath) + if err != nil { + return err + } + + var output bytes.Buffer + if err := writeIntegrationBenchfmt(&output, driver, parseBenchmarkMarkdown(data)); err != nil { + return err + } + + return os.WriteFile(benchfmtPath, output.Bytes(), 0644) +} + +func runBenchstat(ctx context.Context, cfg resolvedConfig, baseFile, targetFile string) ([]byte, error) { + if cfg.Benchstat == "none" { + return []byte("benchstat skipped\n"), nil + } + + if cfg.Benchstat == "" || cfg.Benchstat == "auto" { + if benchstatPath, err := exec.LookPath("benchstat"); err == nil { + return runCommand(ctx, cfg.Root, nil, benchstatPath, baseFile, targetFile) + } + + return runCommand(ctx, cfg.Root, nil, "go", "run", "golang.org/x/perf/cmd/benchstat@latest", baseFile, targetFile) + } + + fields := strings.Fields(cfg.Benchstat) + if len(fields) == 0 { + return nil, fmt.Errorf("empty benchstat command") + } + + args := append(fields[1:], baseFile, targetFile) + return runCommand(ctx, cfg.Root, nil, fields[0], args...) 
+} diff --git a/cmd/benchdiff/main.go b/cmd/benchdiff/main.go new file mode 100644 index 00000000..df921705 --- /dev/null +++ b/cmd/benchdiff/main.go @@ -0,0 +1,180 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "flag" + "fmt" + "os" + "strconv" + "strings" + "time" +) + +const ( + benchKindAll = "all" + benchKindIntegration = "integration" + benchKindUnit = "unit" +) + +type config struct { + BaseRef string + TargetRef string + Kind string + Packages string + Bench string + BenchCount int + Benchtime string + Driver string + Connection string + Dataset string + LocalDataset string + DatasetDir string + IntegrationIterations int + OutDir string + Benchstat string + FailRegression string + KeepWorktrees bool +} + +type resolvedConfig struct { + config + Root string + BaseSHA string + TargetSHA string + BaseShortSHA string + TargetShortSHA string + DatasetDirAbs string + OutDirAbs string + Threshold float64 +} + +func main() { + cfg, err := parseConfig(os.Args[1:]) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(2) + } + + if err := run(context.Background(), cfg); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func parseConfig(args []string) (config, error) { + cfg := config{} + flags := flag.NewFlagSet("benchdiff", flag.ContinueOnError) + flags.SetOutput(os.Stderr) + + flags.StringVar(&cfg.BaseRef, "base", "main", "base 
git ref to benchmark") + flags.StringVar(&cfg.TargetRef, "target", "HEAD", "target git ref to benchmark") + flags.StringVar(&cfg.Kind, "kind", benchKindAll, "benchmark kind: all, unit, integration") + flags.StringVar(&cfg.Packages, "packages", "./...", "package list for Go benchmarks") + flags.StringVar(&cfg.Bench, "bench", ".", "Go benchmark regexp") + flags.IntVar(&cfg.BenchCount, "bench-count", 10, "Go benchmark repetition count") + flags.StringVar(&cfg.Benchtime, "benchtime", "1s", "Go benchmark benchtime") + flags.StringVar(&cfg.Driver, "driver", "pg", "integration benchmark database driver") + flags.StringVar(&cfg.Connection, "connection", "", "integration database connection string (or CONNECTION_STRING)") + flags.StringVar(&cfg.Dataset, "dataset", "", "run only this integration dataset") + flags.StringVar(&cfg.LocalDataset, "local-dataset", "", "additional local integration dataset") + flags.StringVar(&cfg.DatasetDir, "dataset-dir", "integration/testdata", "integration testdata directory") + flags.IntVar(&cfg.IntegrationIterations, "integration-iterations", 10, "timed iterations per integration scenario") + flags.StringVar(&cfg.OutDir, "out", "", "output directory") + flags.StringVar(&cfg.Benchstat, "benchstat", "auto", "benchstat command, auto, or none") + flags.StringVar(&cfg.FailRegression, "fail-regression", "0", "fail when median ns/op regression exceeds this percent, e.g. 
// parseRegressionThreshold converts a -fail-regression value such as "10",
// "10%", or " 2.5% " into a percentage.
//
// Surrounding whitespace and a single trailing "%" are ignored. The result
// must be a finite, non-negative number; a value of 0 disables the regression
// gate. An error is returned for empty, unparsable, negative, or non-finite
// input.
func parseRegressionThreshold(value string) (float64, error) {
	trimmed := strings.TrimSuffix(strings.TrimSpace(value), "%")
	if trimmed == "" {
		return 0, fmt.Errorf("fail-regression must be a non-negative percent")
	}

	threshold, err := strconv.ParseFloat(trimmed, 64)
	if err != nil {
		return 0, fmt.Errorf("invalid fail-regression %q: %w", value, err)
	}

	// strconv.ParseFloat accepts "NaN", "Inf" and "Infinity" without error.
	// NaN would slip past the sign check below and silently disable the gate,
	// while +Inf would be a threshold that can never be exceeded. NaN and ±Inf
	// are the only float64 values for which x-x is not zero.
	if threshold-threshold != 0 {
		return 0, fmt.Errorf("invalid fail-regression %q: must be a finite number", value)
	}
	if threshold < 0 {
		return 0, fmt.Errorf("fail-regression must be a non-negative percent")
	}

	return threshold, nil
}
// splitMarkdownTableRow breaks a single Markdown table line into its cells.
//
// The line is whitespace-trimmed first and must both start and end with a
// pipe; anything else yields nil. The outer pipes are dropped, the remainder
// is split on "|", and every cell is trimmed of surrounding whitespace.
func splitMarkdownTableRow(line string) []string {
	row := strings.TrimSpace(line)
	if !(strings.HasPrefix(row, "|") && strings.HasSuffix(row, "|")) {
		return nil
	}

	inner := strings.TrimSuffix(strings.TrimPrefix(row, "|"), "|")

	cells := strings.Split(inner, "|")
	for idx := range cells {
		cells[idx] = strings.TrimSpace(cells[idx])
	}

	return cells
}
+// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "time" +) + +const maxFindingRows = 10 + +func writeRunReport(path string, summary runSummary) error { + var out bytes.Buffer + cfg := summary.Config + + fmt.Fprintln(&out, "# Benchmark Diff") + fmt.Fprintln(&out) + fmt.Fprintln(&out, "| Field | Value |") + fmt.Fprintln(&out, "|-------|-------|") + fmt.Fprintf(&out, "| Base | `%s` (`%s`) |\n", cfg.BaseRef, cfg.BaseShortSHA) + fmt.Fprintf(&out, "| Target | `%s` (`%s`) |\n", cfg.TargetRef, cfg.TargetShortSHA) + fmt.Fprintf(&out, "| Started | %s |\n", summary.StartedAt.UTC().Format(time.RFC3339)) + fmt.Fprintf(&out, "| Finished | %s |\n", summary.FinishedAt.UTC().Format(time.RFC3339)) + fmt.Fprintf(&out, "| Go | %s |\n", summary.GoVersion) + fmt.Fprintf(&out, "| Platform | %s/%s |\n", runtime.GOOS, runtime.GOARCH) + fmt.Fprintf(&out, "| Kind | `%s` |\n", cfg.Kind) + fmt.Fprintf(&out, "| Output | `%s` |\n", cfg.OutDirAbs) + if cfg.runsIntegrationBenchmarks() { + fmt.Fprintf(&out, "| Driver | `%s` |\n", cfg.Driver) + fmt.Fprintf(&out, "| Dataset Dir | `%s` |\n", cfg.DatasetDirAbs) + fmt.Fprintf(&out, "| Integration Iterations | %d |\n", cfg.IntegrationIterations) + } + if cfg.runsUnitBenchmarks() { + fmt.Fprintf(&out, "| Packages | `%s` |\n", cfg.Packages) + fmt.Fprintf(&out, "| Bench | `%s` |\n", cfg.Bench) + fmt.Fprintf(&out, "| Bench Count | %d |\n", cfg.BenchCount) + fmt.Fprintf(&out, "| Benchtime | `%s` |\n", cfg.Benchtime) + } + if cfg.Threshold > 0 { + fmt.Fprintf(&out, "| Regression Failure Threshold | %.2f%% |\n", cfg.Threshold) + } + fmt.Fprintln(&out) + + writeFindingsSummary(&out, summary) + + for _, comparison := range summary.Comparisons { + fmt.Fprintf(&out, "## %s\n\n", comparison.Name) + for _, note := range comparison.Notes { + fmt.Fprintf(&out, "- %s\n", note) + } + if len(comparison.Notes) > 0 { + fmt.Fprintln(&out) + } + + fmt.Fprintf(&out, "- Base raw: 
`%s`\n", relOrAbs(cfg.OutDirAbs, comparison.BaseFile)) + fmt.Fprintf(&out, "- Target raw: `%s`\n", relOrAbs(cfg.OutDirAbs, comparison.TargetFile)) + fmt.Fprintf(&out, "- Benchstat: `%s`\n\n", relOrAbs(cfg.OutDirAbs, comparison.BenchstatFile)) + + fmt.Fprintln(&out, "```text") + fmt.Fprint(&out, comparison.Benchstat) + if len(comparison.Benchstat) == 0 || comparison.Benchstat[len(comparison.Benchstat)-1] != '\n' { + fmt.Fprintln(&out) + } + fmt.Fprintln(&out, "```") + fmt.Fprintln(&out) + + writeRegressionSection(&out, comparison, cfg.Threshold) + } + + writeAllExecutedNumbers(&out, summary) + + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return err + } + + return os.WriteFile(path, out.Bytes(), 0644) +} + +func writeAllExecutedNumbers(out *bytes.Buffer, summary runSummary) { + fmt.Fprintln(out, "## All Executed Benchmark Numbers") + fmt.Fprintln(out) + + if len(summary.Comparisons) == 0 { + fmt.Fprintln(out, "No benchmark numbers were captured.") + fmt.Fprintln(out) + return + } + + for _, comparison := range summary.Comparisons { + fmt.Fprintf(out, "### %s\n\n", comparison.Name) + if len(comparison.Findings.Results) == 0 { + fmt.Fprintln(out, "No benchmark numbers were captured.") + fmt.Fprintln(out) + continue + } + + fmt.Fprintln(out, "| Benchmark | Base Median | Target Median | Change | Base Samples | Target Samples |") + fmt.Fprintln(out, "|-----------|------------:|--------------:|-------:|-------------:|---------------:|") + for _, result := range comparison.Findings.Results { + fmt.Fprintf(out, "| `%s` | %s | %s | %s | %s | %s |\n", + result.Name, + formatOptionalNS(result.HasBase, result.BaseMedianNS), + formatOptionalNS(result.HasTarget, result.TargetMedianNS), + formatOptionalPercent(result.HasBase && result.HasTarget, result.DeltaPercent), + formatOptionalInt(result.HasBase, result.BaseSamples), + formatOptionalInt(result.HasTarget, result.TargetSamples), + ) + } + fmt.Fprintln(out) + } +} + +func writeFindingsSummary(out 
*bytes.Buffer, summary runSummary) { + fmt.Fprintln(out, "## Findings") + fmt.Fprintln(out) + + if len(summary.Comparisons) == 0 { + fmt.Fprintln(out, "No benchmark comparisons were run.") + fmt.Fprintln(out) + return + } + + for _, comparison := range summary.Comparisons { + findings := comparison.Findings + fmt.Fprintf(out, "### %s\n\n", comparison.Name) + fmt.Fprintf(out, "- Compared %d matching benchmark%s.\n", findings.Compared, pluralSuffix(findings.Compared)) + fmt.Fprintf(out, "- Median regressions: %d; median improvements: %d; unchanged: %d.\n", + len(findings.Regressions), + len(findings.Improvements), + findings.Unchanged, + ) + if len(findings.OnlyBase) > 0 { + fmt.Fprintf(out, "- Only in base: %s.\n", inlineBenchmarkList(findings.OnlyBase, maxFindingRows)) + } + if len(findings.OnlyTarget) > 0 { + fmt.Fprintf(out, "- Only in target: %s.\n", inlineBenchmarkList(findings.OnlyTarget, maxFindingRows)) + } + fmt.Fprintln(out) + + writeFindingTable(out, "Top Median Regressions", findings.Regressions, maxFindingRows) + writeFindingTable(out, "Top Median Improvements", findings.Improvements, maxFindingRows) + } +} + +func writeFindingTable(out *bytes.Buffer, title string, findings []benchmarkFinding, limit int) { + fmt.Fprintf(out, "#### %s\n\n", title) + if len(findings) == 0 { + fmt.Fprintln(out, "None.") + fmt.Fprintln(out) + return + } + + fmt.Fprintln(out, "| Benchmark | Base Median | Target Median | Change |") + fmt.Fprintln(out, "|-----------|------------:|--------------:|-------:|") + + for idx, finding := range findings { + if idx >= limit { + break + } + + fmt.Fprintf(out, "| `%s` | %s | %s | %+.2f%% |\n", + finding.Name, + formatNS(finding.BaseMedianNS), + formatNS(finding.TargetMedianNS), + finding.DeltaPercent, + ) + } + if len(findings) > limit { + fmt.Fprintf(out, "\n_%d more not shown._\n", len(findings)-limit) + } + fmt.Fprintln(out) +} + +func pluralSuffix(count int) string { + if count == 1 { + return "" + } + + return "s" +} + +func 
// formatNS renders a duration given in nanoseconds using the largest unit
// (s, ms, us) that the value reaches, with two decimals. Values below one
// microsecond are printed as whole nanoseconds.
func formatNS(value float64) string {
	const (
		second      = float64(time.Second)
		millisecond = float64(time.Millisecond)
		microsecond = float64(time.Microsecond)
	)

	if value >= second {
		return fmt.Sprintf("%.2fs", value/second)
	}
	if value >= millisecond {
		return fmt.Sprintf("%.2fms", value/millisecond)
	}
	if value >= microsecond {
		return fmt.Sprintf("%.2fus", value/microsecond)
	}

	return fmt.Sprintf("%.0fns", value)
}
// TestWriteRunReportIncludesTopLevelFindings renders a report for a single
// hand-built unit-benchmark comparison and checks that the findings summary,
// the per-suite section, and the table of all executed numbers appear in the
// written Markdown file.
func TestWriteRunReportIncludesTopLevelFindings(t *testing.T) {
	outputPath := filepath.Join(t.TempDir(), "report.md")
	summary := runSummary{
		Config: resolvedConfig{
			config: config{
				BaseRef:        "main",
				TargetRef:      "HEAD",
				Kind:           benchKindUnit,
				Packages:       "./...",
				Bench:          ".",
				BenchCount:     3,
				Benchtime:      "1s",
				FailRegression: "10%",
			},
			BaseShortSHA:   "abc1234",
			TargetShortSHA: "def5678",
			OutDirAbs:      filepath.Dir(outputPath),
			Threshold:      10,
		},
		GoVersion:  "go-test",
		StartedAt:  time.Date(2026, 5, 11, 1, 2, 3, 0, time.UTC),
		FinishedAt: time.Date(2026, 5, 11, 1, 2, 4, 0, time.UTC),
		Comparisons: []comparison{{
			Name:          "Unit Benchmarks",
			BaseFile:      filepath.Join(filepath.Dir(outputPath), "unit", "base.txt"),
			TargetFile:    filepath.Join(filepath.Dir(outputPath), "unit", "target.txt"),
			BenchstatFile: filepath.Join(filepath.Dir(outputPath), "unit", "benchstat.txt"),
			Benchstat:     "benchstat output\n",
			Findings: comparisonFindings{
				Compared:  3,
				Unchanged: 1,
				Regressions: []benchmarkFinding{{
					Name:           "BenchmarkSlow-12",
					BaseMedianNS:   100,
					TargetMedianNS: 150,
					DeltaPercent:   50,
				}},
				Improvements: []benchmarkFinding{{
					Name:           "BenchmarkFast-12",
					BaseMedianNS:   200,
					TargetMedianNS: 100,
					DeltaPercent:   -50,
				}},
				OnlyBase:   []string{"BenchmarkRemoved-12"},
				OnlyTarget: []string{"BenchmarkAdded-12"},
				// One result of each shape: target-only, improved, base-only,
				// and regressed.
				Results: []benchmarkResult{
					{
						Name:           "BenchmarkAdded-12",
						TargetMedianNS: 75,
						TargetSamples:  1,
						HasTarget:      true,
					},
					{
						Name:           "BenchmarkFast-12",
						BaseMedianNS:   200,
						TargetMedianNS: 100,
						DeltaPercent:   -50,
						BaseSamples:    3,
						TargetSamples:  3,
						HasBase:        true,
						HasTarget:      true,
					},
					{
						Name:         "BenchmarkRemoved-12",
						BaseMedianNS: 50,
						BaseSamples:  1,
						HasBase:      true,
					},
					{
						Name:           "BenchmarkSlow-12",
						BaseMedianNS:   100,
						TargetMedianNS: 150,
						DeltaPercent:   50,
						BaseSamples:    3,
						TargetSamples:  3,
						HasBase:        true,
						HasTarget:      true,
					},
				},
			},
			Regressions: []regression{{
				Name:           "BenchmarkSlow-12",
				BaseMedianNS:   100,
				TargetMedianNS: 150,
				Percent:        50,
			}},
		}},
	}

	require.NoError(t, writeRunReport(outputPath, summary))

	report, err := os.ReadFile(outputPath)
	require.NoError(t, err)
	output := string(report)

	// Findings summary at the top of the report.
	require.Contains(t, output, "## Findings")
	require.Contains(t, output, "### Unit Benchmarks")
	require.Contains(t, output, "- Compared 3 matching benchmarks.")
	require.Contains(t, output, "- Median regressions: 1; median improvements: 1; unchanged: 1.")
	require.Contains(t, output, "- Only in base: `BenchmarkRemoved-12`.")
	require.Contains(t, output, "- Only in target: `BenchmarkAdded-12`.")
	require.Contains(t, output, "#### Top Median Regressions")
	require.Contains(t, output, "| `BenchmarkSlow-12` | 100ns | 150ns | +50.00% |")
	require.Contains(t, output, "#### Top Median Improvements")
	require.Contains(t, output, "| `BenchmarkFast-12` | 200ns | 100ns | -50.00% |")
	// Per-suite section with the raw benchstat output.
	require.Contains(t, output, "## Unit Benchmarks")
	require.Contains(t, output, "benchstat output")
	// Closing table listing every captured benchmark; a missing side is shown as "-".
	require.Contains(t, output, "## All Executed Benchmark Numbers")
	require.Contains(t, output, "| Benchmark | Base Median | Target Median | Change | Base Samples | Target Samples |")
	require.Contains(t, output, "| `BenchmarkSlow-12` | 100ns | 150ns | +50.00% | 3 | 3 |")
	require.Contains(t, output, "| `BenchmarkFast-12` | 200ns | 100ns | -50.00% | 3 | 3 |")
	require.Contains(t, output, "| `BenchmarkRemoved-12` | 50ns | - | - | 1 | - |")
	require.Contains(t, output, "| `BenchmarkAdded-12` | - | 75ns | - | - | 1 |")
}
+// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "time" +) + +type comparison struct { + Name string + BaseFile string + TargetFile string + BenchstatFile string + Benchstat string + Notes []string + Findings comparisonFindings + Regressions []regression +} + +type runSummary struct { + Config resolvedConfig + GoVersion string + StartedAt time.Time + FinishedAt time.Time + Comparisons []comparison + ReportPath string +} + +func run(ctx context.Context, cfg config) error { + resolved, err := resolveConfig(ctx, cfg) + if err != nil { + return err + } + + summary := runSummary{ + Config: resolved, + GoVersion: runtime.Version(), + StartedAt: time.Now(), + } + + if err := os.MkdirAll(resolved.OutDirAbs, 0755); err != nil { + return fmt.Errorf("create output directory: %w", err) + } + + worktreeRoot := filepath.Join(resolved.OutDirAbs, "worktrees") + baseWorktree := filepath.Join(worktreeRoot, "base") + targetWorktree := filepath.Join(worktreeRoot, "target") + + fmt.Fprintf(os.Stderr, "preparing benchmark worktrees for %s and %s...\n", resolved.BaseShortSHA, resolved.TargetShortSHA) + + if err := addWorktree(ctx, resolved.Root, baseWorktree, resolved.BaseSHA); err != nil { + return err + } + removeBase := true + defer func() { + if removeBase && !resolved.KeepWorktrees { + _ = removeWorktree(context.Background(), resolved.Root, baseWorktree) + } + }() + + if err := addWorktree(ctx, resolved.Root, targetWorktree, resolved.TargetSHA); err != nil { + return err + } + removeTarget := true + defer func() { + if removeTarget && !resolved.KeepWorktrees { + _ = removeWorktree(context.Background(), resolved.Root, targetWorktree) + } + }() + + if resolved.runsUnitBenchmarks() { + fmt.Fprintln(os.Stderr, "running unit benchmarks...") + unitComparison, err := runUnitComparison(ctx, resolved, baseWorktree, targetWorktree) + if err != nil { + return err + } + summary.Comparisons = 
append(summary.Comparisons, unitComparison) + } + + if resolved.runsIntegrationBenchmarks() { + fmt.Fprintln(os.Stderr, "running integration benchmarks...") + integrationComparison, err := runIntegrationComparison(ctx, resolved, baseWorktree, targetWorktree) + if err != nil { + return err + } + summary.Comparisons = append(summary.Comparisons, integrationComparison) + } + + removeBase = false + removeTarget = false + if !resolved.KeepWorktrees { + if err := removeWorktree(ctx, resolved.Root, baseWorktree); err != nil { + return err + } + if err := removeWorktree(ctx, resolved.Root, targetWorktree); err != nil { + return err + } + } + + summary.FinishedAt = time.Now() + summary.ReportPath = filepath.Join(resolved.OutDirAbs, "report.md") + if err := writeRunReport(summary.ReportPath, summary); err != nil { + return err + } + + fmt.Fprintf(os.Stderr, "wrote benchmark diff report: %s\n", summary.ReportPath) + + if regressions := summary.regressions(); len(regressions) > 0 && resolved.Threshold > 0 { + return fmt.Errorf("%d benchmark regressions exceeded %.2f%%; see %s", len(regressions), resolved.Threshold, summary.ReportPath) + } + + return nil +} + +func resolveConfig(ctx context.Context, cfg config) (resolvedConfig, error) { + root, err := gitOutput(ctx, "", "rev-parse", "--show-toplevel") + if err != nil { + return resolvedConfig{}, err + } + + baseSHA, err := resolveCommit(ctx, root, cfg.BaseRef) + if err != nil { + return resolvedConfig{}, err + } + targetSHA, err := resolveCommit(ctx, root, cfg.TargetRef) + if err != nil { + return resolvedConfig{}, err + } + + baseShortSHA, err := shortCommit(ctx, root, baseSHA) + if err != nil { + return resolvedConfig{}, err + } + targetShortSHA, err := shortCommit(ctx, root, targetSHA) + if err != nil { + return resolvedConfig{}, err + } + + threshold, err := parseRegressionThreshold(cfg.FailRegression) + if err != nil { + return resolvedConfig{}, err + } + + if cfg.runsIntegrationBenchmarks() && cfg.Connection == "" { + 
// resolvePath turns value into a cleaned absolute path. An absolute value is
// only cleaned; a relative value is joined onto root and then made absolute.
func resolvePath(root, value string) (string, error) {
	if !filepath.IsAbs(value) {
		joined := filepath.Join(root, value)
		return filepath.Abs(joined)
	}

	return filepath.Clean(value), nil
}
%w", path, err) + } + + return nil +} + +func removeWorktree(ctx context.Context, root, path string) error { + _, err := runCommand(ctx, root, nil, "git", "worktree", "remove", "--force", path) + if err != nil && !strings.Contains(err.Error(), "is not a working tree") { + return fmt.Errorf("remove worktree %s: %w", path, err) + } + + return nil +} + +func (summary runSummary) regressions() []regression { + var regressions []regression + + for _, comparison := range summary.Comparisons { + regressions = append(regressions, comparison.Regressions...) + } + + return regressions +} diff --git a/cmd/benchmark/README.md b/cmd/benchmark/README.md index 741118f5..c05d4c7a 100644 --- a/cmd/benchmark/README.md +++ b/cmd/benchmark/README.md @@ -1,13 +1,16 @@ # Benchmark -Runs query scenarios against a real database and outputs a markdown timing table. +Runs query scenarios against a real database and outputs markdown, JSON, or benchfmt timing data. ## Usage ```bash -# Default dataset (base) +# Default datasets (base and traversal_shapes) go run ./cmd/benchmark -connection "postgresql://dawgs:dawgs@localhost:5432/dawgs" +# Traversal shape dataset only +go run ./cmd/benchmark -connection "..." -dataset traversal_shapes + # Local dataset (not committed to repo) go run ./cmd/benchmark -connection "..." -dataset local/phantom @@ -20,6 +23,9 @@ go run ./cmd/benchmark -driver neo4j -connection "neo4j://neo4j:password@localho # Save to file go run ./cmd/benchmark -connection "..." -output report.md +# Emit benchfmt for benchstat +go run ./cmd/benchmark -connection "..." -format benchfmt -output report.bench + # Save markdown and JSON for quality baseline comparison go run ./cmd/benchmark -connection "..." -output report.md -json-output report.json ``` @@ -34,9 +40,17 @@ go run ./cmd/benchmark -connection "..." 
-output report.md -json-output report.j | `-dataset` | | Run only this dataset | | `-local-dataset` | | Add a local dataset to the default set | | `-dataset-dir` | `integration/testdata` | Path to testdata directory | -| `-output` | stdout | Markdown output file | +| `-format` | `markdown` | Output format (`markdown`, `json`, `benchfmt`) | +| `-output` | stdout | Output file | | `-json-output` | | JSON output file for baseline comparison | +Use `-format benchfmt` when comparing scenario timings with `benchstat`. Each timed scenario iteration is emitted as a +separate `ns/op` sample so two benchmark runs can be compared directly. + +The committed default datasets are `base` and `traversal_shapes`. `traversal_shapes` covers chain, fanout, bounded +cycle, disconnected, edge-kind-selective, and multi-path shortest-path traversal shapes. Scenarios with declared +expected row counts fail before reporting timings if a query returns the wrong result shape. + ## Example: Neo4j on local/phantom ``` diff --git a/cmd/benchmark/main.go b/cmd/benchmark/main.go index bb857d5b..dbb0e64b 100644 --- a/cmd/benchmark/main.go +++ b/cmd/benchmark/main.go @@ -40,7 +40,8 @@ func main() { driver = flag.String("driver", "pg", "database driver (pg, neo4j)") connStr = flag.String("connection", "", "database connection string (or CONNECTION_STRING)") iterations = flag.Int("iterations", 10, "timed iterations per scenario") - output = flag.String("output", "", "markdown output file (default: stdout)") + output = flag.String("output", "", "output file (default: stdout)") + format = flag.String("format", reportFormatMarkdown, "output format (markdown, json, benchfmt)") jsonOutput = flag.String("json-output", "", "JSON output file for baseline comparison") datasetDir = flag.String("dataset-dir", "integration/testdata", "path to testdata directory") localDataset = flag.String("local-dataset", "", "additional local dataset (e.g. 
local/phantom)") @@ -50,6 +51,13 @@ func main() { flag.Parse() + if *iterations < 1 { + fatal("iterations must be at least 1") + } + if !isReportFormat(*format) { + fatal("unsupported output format %q", *format) + } + conn := *connStr if conn == "" { conn = os.Getenv("CONNECTION_STRING") @@ -154,21 +162,21 @@ func main() { } } - // Write markdown - var mdOut *os.File + // Write primary report. + var out *os.File if *output != "" { var err error - mdOut, err = os.Create(*output) + out, err = os.Create(*output) if err != nil { fatal("failed to create output: %v", err) } - defer mdOut.Close() + defer out.Close() } else { - mdOut = os.Stdout + out = os.Stdout } - if err := writeMarkdown(mdOut, report); err != nil { - fatal("failed to write markdown: %v", err) + if err := writeReport(out, report, *format); err != nil { + fatal("failed to write report: %v", err) } if *output != "" { diff --git a/cmd/benchmark/report.go b/cmd/benchmark/report.go index dacab7dd..841baaa1 100644 --- a/cmd/benchmark/report.go +++ b/cmd/benchmark/report.go @@ -20,7 +20,16 @@ import ( "encoding/json" "fmt" "io" + "runtime" + "strings" "time" + "unicode" +) + +const ( + reportFormatBenchfmt = "benchfmt" + reportFormatJSON = "json" + reportFormatMarkdown = "markdown" ) // Report holds all benchmark results and metadata. 
@@ -32,6 +41,30 @@ type Report struct { Results []Result `json:"results"` } +func writeReport(w io.Writer, r Report, format string) error { + if !isReportFormat(format) { + return fmt.Errorf("unsupported output format %q", format) + } + + switch format { + case reportFormatBenchfmt: + return writeBenchfmt(w, r) + case reportFormatJSON: + return writeJSON(w, r) + default: + return writeMarkdown(w, r) + } +} + +func isReportFormat(format string) bool { + switch format { + case reportFormatBenchfmt, reportFormatJSON, reportFormatMarkdown: + return true + default: + return false + } +} + func writeJSON(w io.Writer, r Report) error { encoder := json.NewEncoder(w) encoder.SetIndent("", " ") @@ -62,6 +95,77 @@ func writeMarkdown(w io.Writer, r Report) error { return nil } +func writeBenchfmt(w io.Writer, r Report) error { + goos := runtime.GOOS + goarch := runtime.GOARCH + procs := runtime.GOMAXPROCS(0) + + fmt.Fprintf(w, "goos: %s\n", goos) + fmt.Fprintf(w, "goarch: %s\n", goarch) + fmt.Fprintf(w, "pkg: github.com/specterops/dawgs/cmd/benchmark\n") + + for _, res := range r.Results { + benchName := benchName(r.Driver, res) + + for _, sample := range res.Samples { + fmt.Fprintf(w, "%s-%d\t1\t%d ns/op\n", benchName, procs, sample.Nanoseconds()) + } + } + + return nil +} + +func benchName(driver string, res Result) string { + parts := []string{ + "BenchmarkDawgsIntegration", + sanitizeBenchNamePart(driver), + sanitizeBenchNamePart(res.Dataset), + sanitizeBenchNamePart(res.Section), + sanitizeBenchNamePart(res.Label), + } + + return strings.Join(parts, "/") +} + +func sanitizeBenchNamePart(value string) string { + var builder strings.Builder + lastUnderscore := false + + for _, char := range value { + switch { + case char == '/' || char == '-' || char == '_': + if char == '_' { + if !lastUnderscore { + builder.WriteRune(char) + } + lastUnderscore = true + } else { + builder.WriteRune(char) + lastUnderscore = false + } + case unicode.IsLetter(char) || unicode.IsDigit(char): + 
builder.WriteRune(char) + lastUnderscore = false + case unicode.IsSpace(char): + if !lastUnderscore { + builder.WriteByte('_') + } + lastUnderscore = true + default: + if !lastUnderscore { + builder.WriteByte('_') + } + lastUnderscore = true + } + } + + if builder.Len() == 0 { + return "unknown" + } + + return builder.String() +} + func fmtDuration(d time.Duration) string { ms := float64(d.Microseconds()) / 1000.0 if ms < 1 { diff --git a/cmd/benchmark/report_test.go b/cmd/benchmark/report_test.go index 2d72ed4d..c4bea53f 100644 --- a/cmd/benchmark/report_test.go +++ b/cmd/benchmark/report_test.go @@ -21,8 +21,82 @@ import ( "strings" "testing" "time" + + "github.com/stretchr/testify/require" ) +func TestWriteReportRejectsUnknownFormat(t *testing.T) { + err := writeReport(&bytes.Buffer{}, Report{}, "xml") + require.ErrorContains(t, err, "unsupported output format") +} + +func TestWriteJSON(t *testing.T) { + report := testReport() + var out bytes.Buffer + + require.NoError(t, writeReport(&out, report, reportFormatJSON)) + + require.Contains(t, out.String(), `"driver": "pg"`) + require.Contains(t, out.String(), `"samples": [`) + require.Contains(t, out.String(), `1000000`) +} + +func TestWriteBenchfmt(t *testing.T) { + report := testReport() + var out bytes.Buffer + + require.NoError(t, writeReport(&out, report, reportFormatBenchfmt)) + + output := out.String() + require.Contains(t, output, "goos: ") + require.Contains(t, output, "goarch: ") + require.Contains(t, output, "pkg: github.com/specterops/dawgs/cmd/benchmark") + require.Contains(t, output, "BenchmarkDawgsIntegration/pg/base/Match_Nodes/base-") + require.Contains(t, output, "\t1\t1000000 ns/op") + require.Contains(t, output, "\t1\t2000000 ns/op") +} + +func TestSanitizeBenchNamePart(t *testing.T) { + require.Equal(t, "Shortest_Paths", sanitizeBenchNamePart("Shortest Paths")) + require.Equal(t, "n1_-_n3", sanitizeBenchNamePart("n1 -> n3")) + require.Equal(t, "local/phantom", 
sanitizeBenchNamePart("local/phantom")) + require.Equal(t, "unknown", sanitizeBenchNamePart("")) +} + +func TestWriteMarkdownOmitsSamples(t *testing.T) { + report := testReport() + var out bytes.Buffer + + require.NoError(t, writeReport(&out, report, reportFormatMarkdown)) + + output := out.String() + require.Contains(t, output, "| Match Nodes | base | 2.0ms | 2.0ms | 2.0ms |") + require.False(t, strings.Contains(output, "1000000")) +} + +func testReport() Report { + return Report{ + Driver: "pg", + GitRef: "abcdef0", + Date: "2026-05-11", + Iterations: 2, + Results: []Result{{ + Section: "Match Nodes", + Dataset: "base", + Label: "base", + Stats: Stats{ + Median: 2 * time.Millisecond, + P95: 2 * time.Millisecond, + Max: 2 * time.Millisecond, + }, + Samples: []time.Duration{ + time.Millisecond, + 2 * time.Millisecond, + }, + }}, + } +} + func TestWriteJSONEmitsBaselineFriendlyReport(t *testing.T) { report := Report{ Driver: "pg", diff --git a/cmd/benchmark/runner.go b/cmd/benchmark/runner.go index b146f11d..d74eb698 100644 --- a/cmd/benchmark/runner.go +++ b/cmd/benchmark/runner.go @@ -18,6 +18,7 @@ package main import ( "context" + "fmt" "sort" "time" @@ -33,16 +34,17 @@ type Stats struct { // Result is one row in the report. type Result struct { - Section string `json:"section"` - Dataset string `json:"dataset"` - Label string `json:"label"` - Stats Stats `json:"stats"` + Section string `json:"section"` + Dataset string `json:"dataset"` + Label string `json:"label"` + Stats Stats `json:"stats"` + Samples []time.Duration `json:"samples,omitempty"` } // runScenario executes a scenario N times and returns timing stats. func runScenario(ctx context.Context, db graph.Database, s Scenario, iterations int) (Result, error) { // Warm-up: one untimed run. 
- if err := db.ReadTransaction(ctx, s.Query); err != nil { + if err := runScenarioOnce(ctx, db, s); err != nil { return Result{}, err } @@ -50,7 +52,7 @@ func runScenario(ctx context.Context, db graph.Database, s Scenario, iterations for i := range iterations { start := time.Now() - if err := db.ReadTransaction(ctx, s.Query); err != nil { + if err := runScenarioOnce(ctx, db, s); err != nil { return Result{}, err } durations[i] = time.Since(start) @@ -61,9 +63,29 @@ func runScenario(ctx context.Context, db graph.Database, s Scenario, iterations Dataset: s.Dataset, Label: s.Label, Stats: computeStats(durations), + Samples: durations, }, nil } +func runScenarioOnce(ctx context.Context, db graph.Database, s Scenario) error { + return db.ReadTransaction(ctx, func(tx graph.Transaction) error { + rows, err := s.Query(tx) + if err != nil { + return err + } + + return validateScenarioRows(s, rows) + }) +} + +func validateScenarioRows(s Scenario, actualRows int) error { + if s.ExpectedRows == nil || *s.ExpectedRows == actualRows { + return nil + } + + return fmt.Errorf("%s/%s on %s expected %d rows, got %d", s.Section, s.Label, s.Dataset, *s.ExpectedRows, actualRows) +} + func computeStats(durations []time.Duration) Stats { sort.Slice(durations, func(i, j int) bool { return durations[i] < durations[j] }) diff --git a/cmd/benchmark/scenarios.go b/cmd/benchmark/scenarios.go index 217ae63d..c445d966 100644 --- a/cmd/benchmark/scenarios.go +++ b/cmd/benchmark/scenarios.go @@ -25,20 +25,25 @@ import ( // Scenario defines a single benchmark query to run against a loaded dataset. type Scenario struct { - Section string // grouping key in the report (e.g. "Match Nodes") - Dataset string - Label string // human-readable row label - Query func(tx graph.Transaction) error + Section string // grouping key in the report (e.g. 
"Match Nodes") + Dataset string + Label string // human-readable row label + ExpectedRows *int + Query func(tx graph.Transaction) (int, error) } +const traversalShapesDataset = "traversal_shapes" + // defaultDatasets is the set of datasets committed to the repo. -var defaultDatasets = []string{"base"} +var defaultDatasets = []string{"base", traversalShapesDataset} // scenariosForDataset returns all benchmark scenarios for a given dataset and its loaded ID map. func scenariosForDataset(dataset string, idMap opengraph.IDMap) []Scenario { switch dataset { case "base": return baseScenarios(idMap) + case traversalShapesDataset: + return traversalShapesScenarios(idMap) case "local/phantom": return phantomScenarios(idMap) default: @@ -46,23 +51,31 @@ func scenariosForDataset(dataset string, idMap opengraph.IDMap) []Scenario { } } -func countNodes(tx graph.Transaction) error { - _, err := tx.Nodes().Count() - return err +func expectRows(rows int) *int { + return &rows +} + +func countNodes(tx graph.Transaction) (int, error) { + count, err := tx.Nodes().Count() + return int(count), err } -func countEdges(tx graph.Transaction) error { - _, err := tx.Relationships().Count() - return err +func countEdges(tx graph.Transaction) (int, error) { + count, err := tx.Relationships().Count() + return int(count), err } -func cypherQuery(cypher string) func(tx graph.Transaction) error { - return func(tx graph.Transaction) error { +func cypherQuery(cypher string) func(tx graph.Transaction) (int, error) { + return func(tx graph.Transaction) (int, error) { result := tx.Query(cypher, nil) defer result.Close() + + rows := 0 for result.Next() { + rows++ } - return result.Error() + + return rows, result.Error() } } @@ -71,22 +84,80 @@ func cypherQuery(cypher string) func(tx graph.Transaction) error { func baseScenarios(idMap opengraph.IDMap) []Scenario { ds := "base" return []Scenario{ - {Section: "Match Nodes", Dataset: ds, Label: ds, Query: countNodes}, - {Section: "Match Edges", Dataset: ds, 
Label: ds, Query: countEdges}, - {Section: "Shortest Paths", Dataset: ds, Label: "n1 -> n3", Query: cypherQuery(fmt.Sprintf( + {Section: "Match Nodes", Dataset: ds, Label: ds, ExpectedRows: expectRows(3), Query: countNodes}, + {Section: "Match Edges", Dataset: ds, Label: ds, ExpectedRows: expectRows(2), Query: countEdges}, + {Section: "Shortest Paths", Dataset: ds, Label: "n1 -> n3", ExpectedRows: expectRows(1), Query: cypherQuery(fmt.Sprintf( "MATCH p = allShortestPaths((s)-[*1..]->(e)) WHERE id(s) = %d AND id(e) = %d RETURN p", idMap["n1"], idMap["n3"], ))}, - {Section: "Traversal", Dataset: ds, Label: "n1", Query: cypherQuery(fmt.Sprintf( + {Section: "Traversal", Dataset: ds, Label: "n1", ExpectedRows: expectRows(2), Query: cypherQuery(fmt.Sprintf( "MATCH (s)-[*1..]->(e) WHERE id(s) = %d RETURN e", idMap["n1"], ))}, - {Section: "Match Return", Dataset: ds, Label: "n1", Query: cypherQuery(fmt.Sprintf( + {Section: "Match Return", Dataset: ds, Label: "n1", ExpectedRows: expectRows(1), Query: cypherQuery(fmt.Sprintf( "MATCH (s)-[]->(e) WHERE id(s) = %d RETURN e", idMap["n1"], ))}, - {Section: "Filter By Kind", Dataset: ds, Label: "NodeKind1", Query: cypherQuery("MATCH (n:NodeKind1) RETURN n")}, - {Section: "Filter By Kind", Dataset: ds, Label: "NodeKind2", Query: cypherQuery("MATCH (n:NodeKind2) RETURN n")}, + {Section: "Filter By Kind", Dataset: ds, Label: "NodeKind1", ExpectedRows: expectRows(2), Query: cypherQuery("MATCH (n:NodeKind1) RETURN n")}, + {Section: "Filter By Kind", Dataset: ds, Label: "NodeKind2", ExpectedRows: expectRows(2), Query: cypherQuery("MATCH (n:NodeKind2) RETURN n")}, + } +} + +// --- Traversal shape scenarios --- + +func traversalShapesScenarios(idMap opengraph.IDMap) []Scenario { + ds := traversalShapesDataset + return []Scenario{ + {Section: "Match Nodes", Dataset: ds, Label: ds, ExpectedRows: expectRows(45), Query: countNodes}, + {Section: "Match Edges", Dataset: ds, Label: ds, ExpectedRows: expectRows(41), Query: countEdges}, + 
{Section: "Traversal Depth", Dataset: ds, Label: "chain depth 1", ExpectedRows: expectRows(1), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:ChainEdge*1..1]->(e) WHERE id(s) = %d RETURN e", + idMap["c0"], + ))}, + {Section: "Traversal Depth", Dataset: ds, Label: "chain depth 3", ExpectedRows: expectRows(3), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:ChainEdge*1..3]->(e) WHERE id(s) = %d RETURN e", + idMap["c0"], + ))}, + {Section: "Traversal Depth", Dataset: ds, Label: "chain depth 10", ExpectedRows: expectRows(10), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:ChainEdge*1..10]->(e) WHERE id(s) = %d RETURN e", + idMap["c0"], + ))}, + {Section: "Traversal Depth", Dataset: ds, Label: "fanout depth 1", ExpectedRows: expectRows(3), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:FanoutEdge*1..1]->(e) WHERE id(s) = %d RETURN e", + idMap["f0"], + ))}, + {Section: "Traversal Depth", Dataset: ds, Label: "fanout depth 2", ExpectedRows: expectRows(9), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:FanoutEdge*1..2]->(e) WHERE id(s) = %d RETURN e", + idMap["f0"], + ))}, + {Section: "Traversal Depth", Dataset: ds, Label: "fanout depth 3", ExpectedRows: expectRows(15), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:FanoutEdge*1..3]->(e) WHERE id(s) = %d RETURN e", + idMap["f0"], + ))}, + {Section: "Traversal Cycle", Dataset: ds, Label: "bounded cycle", ExpectedRows: expectRows(4), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:CycleEdge*1..4]->(e) WHERE id(s) = %d RETURN e", + idMap["y0"], + ))}, + {Section: "Traversal Dead End", Dataset: ds, Label: "chain terminal", ExpectedRows: expectRows(0), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:ChainEdge*1..]->(e) WHERE id(s) = %d RETURN e", + idMap["c10"], + ))}, + {Section: "Edge Kind Traversal", Dataset: ds, Label: "Allowed", ExpectedRows: expectRows(3), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:Allowed*1..]->(e) WHERE id(s) = %d RETURN e", + idMap["s0"], + ))}, + {Section: "Edge Kind Traversal", Dataset: ds, Label: 
"all kinds", ExpectedRows: expectRows(6), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[*1..]->(e) WHERE id(s) = %d RETURN e", + idMap["s0"], + ))}, + {Section: "Shortest Paths", Dataset: ds, Label: "diamond many paths", ExpectedRows: expectRows(3), Query: cypherQuery(fmt.Sprintf( + "MATCH p = allShortestPaths((s)-[*1..]->(e)) WHERE id(s) = %d AND id(e) = %d RETURN p", + idMap["d0"], idMap["d4"], + ))}, + {Section: "Shortest Paths", Dataset: ds, Label: "disconnected", ExpectedRows: expectRows(0), Query: cypherQuery(fmt.Sprintf( + "MATCH p = allShortestPaths((s)-[*1..]->(e)) WHERE id(s) = %d AND id(e) = %d RETURN p", + idMap["x0"], idMap["x1"], + ))}, } } diff --git a/cmd/benchmark/scenarios_test.go b/cmd/benchmark/scenarios_test.go new file mode 100644 index 00000000..647ee8ae --- /dev/null +++ b/cmd/benchmark/scenarios_test.go @@ -0,0 +1,120 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "os" + "testing" + + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/opengraph" + "github.com/stretchr/testify/require" +) + +func TestBaseScenariosDeclareExpectedRows(t *testing.T) { + scenarios := baseScenarios(opengraph.IDMap{ + "n1": graph.ID(1), + "n2": graph.ID(2), + "n3": graph.ID(3), + }) + + requireExpectedRows(t, scenarios, "Match Nodes", "base", 3) + requireExpectedRows(t, scenarios, "Match Edges", "base", 2) + requireExpectedRows(t, scenarios, "Shortest Paths", "n1 -> n3", 1) + requireExpectedRows(t, scenarios, "Traversal", "n1", 2) + requireExpectedRows(t, scenarios, "Match Return", "n1", 1) + requireExpectedRows(t, scenarios, "Filter By Kind", "NodeKind1", 2) + requireExpectedRows(t, scenarios, "Filter By Kind", "NodeKind2", 2) +} + +func TestTraversalShapesDatasetIsValid(t *testing.T) { + file, err := os.Open("../../integration/testdata/traversal_shapes.json") + require.NoError(t, err) + defer file.Close() + + doc, err := opengraph.ParseDocument(file) + require.NoError(t, err) + require.Len(t, doc.Graph.Nodes, 45) + require.Len(t, doc.Graph.Edges, 41) +} + +func TestTraversalShapesScenariosDeclareExpectedRows(t *testing.T) { + scenarios := traversalShapesScenarios(traversalShapesIDMap()) + + requireExpectedRows(t, scenarios, "Match Nodes", traversalShapesDataset, 45) + requireExpectedRows(t, scenarios, "Match Edges", traversalShapesDataset, 41) + requireExpectedRows(t, scenarios, "Traversal Depth", "chain depth 1", 1) + requireExpectedRows(t, scenarios, "Traversal Depth", "chain depth 3", 3) + requireExpectedRows(t, scenarios, "Traversal Depth", "chain depth 10", 10) + requireExpectedRows(t, scenarios, "Traversal Depth", "fanout depth 1", 3) + requireExpectedRows(t, scenarios, "Traversal Depth", "fanout depth 2", 9) + requireExpectedRows(t, scenarios, "Traversal Depth", "fanout depth 3", 15) + requireExpectedRows(t, scenarios, "Traversal Cycle", "bounded 
cycle", 4) + requireExpectedRows(t, scenarios, "Traversal Dead End", "chain terminal", 0) + requireExpectedRows(t, scenarios, "Edge Kind Traversal", "Allowed", 3) + requireExpectedRows(t, scenarios, "Edge Kind Traversal", "all kinds", 6) + requireExpectedRows(t, scenarios, "Shortest Paths", "diamond many paths", 3) + requireExpectedRows(t, scenarios, "Shortest Paths", "disconnected", 0) +} + +func TestDefaultDatasetsIncludeTraversalShapes(t *testing.T) { + require.Contains(t, defaultDatasets, traversalShapesDataset) +} + +func TestValidateScenarioRows(t *testing.T) { + scenario := Scenario{ + Section: "Traversal", + Dataset: "base", + Label: "n1", + ExpectedRows: expectRows(2), + } + + require.NoError(t, validateScenarioRows(scenario, 2)) + require.ErrorContains(t, validateScenarioRows(scenario, 1), "Traversal/n1 on base expected 2 rows, got 1") +} + +func traversalShapesIDMap() opengraph.IDMap { + ids := []string{ + "c0", "c10", + "f0", + "d0", "d4", + "y0", + "x0", "x1", + "s0", + } + + idMap := opengraph.IDMap{} + for idx, id := range ids { + idMap[id] = graph.ID(idx + 1) + } + + return idMap +} + +func requireExpectedRows(t *testing.T, scenarios []Scenario, section, label string, expectedRows int) { + t.Helper() + + for _, scenario := range scenarios { + if scenario.Section == section && scenario.Label == label { + require.NotNil(t, scenario.ExpectedRows) + require.Equal(t, expectedRows, *scenario.ExpectedRows) + return + } + } + + require.Failf(t, "scenario not found", "%s/%s", section, label) +} diff --git a/cypher/models/cypher/format/format.go b/cypher/models/cypher/format/format.go index 495cf806..a05d7b1a 100644 --- a/cypher/models/cypher/format/format.go +++ b/cypher/models/cypher/format/format.go @@ -80,14 +80,11 @@ func (s Emitter) formatNodePattern(output io.Writer, nodePattern *cypher.NodePat func (s Emitter) formatRelationshipPattern(output io.Writer, relationshipPattern *cypher.RelationshipPattern) error { switch relationshipPattern.Direction { - 
case graph.DirectionOutbound: + case graph.DirectionOutbound, graph.DirectionBoth: if _, err := io.WriteString(output, "-["); err != nil { return err } - case graph.DirectionBoth: - fallthrough - case graph.DirectionInbound: if _, err := io.WriteString(output, "<-["); err != nil { return err @@ -147,14 +144,11 @@ func (s Emitter) formatRelationshipPattern(output io.Writer, relationshipPattern } switch relationshipPattern.Direction { - case graph.DirectionInbound: + case graph.DirectionInbound, graph.DirectionBoth: if _, err := io.WriteString(output, "]-"); err != nil { return err } - case graph.DirectionBoth: - fallthrough - case graph.DirectionOutbound: if _, err := io.WriteString(output, "]->"); err != nil { return err @@ -296,7 +290,7 @@ func (s Emitter) formatProjection(output io.Writer, projection *cypher.Projectio } func (s Emitter) formatReturn(output io.Writer, returnClause *cypher.Return) error { - if _, err := io.WriteString(output, " return "); err != nil { + if _, err := io.WriteString(output, "return "); err != nil { return err } @@ -1095,6 +1089,12 @@ func (s Emitter) formatSinglePartQuery(writer io.Writer, singlePartQuery *cypher } if singlePartQuery.Return != nil { + if len(singlePartQuery.ReadingClauses) > 0 || len(singlePartQuery.UpdatingClauses) > 0 { + if _, err := io.WriteString(writer, " "); err != nil { + return err + } + } + return s.formatReturn(writer, singlePartQuery.Return) } diff --git a/cypher/models/cypher/format/format_test.go b/cypher/models/cypher/format/format_test.go index 327f65d4..8a02d2e9 100644 --- a/cypher/models/cypher/format/format_test.go +++ b/cypher/models/cypher/format/format_test.go @@ -6,6 +6,7 @@ import ( "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/cypher/models/cypher/format" + "github.com/specterops/dawgs/graph" "github.com/specterops/dawgs/cypher/frontend" "github.com/stretchr/testify/require" @@ -27,6 +28,55 @@ func TestCypherEmitter_StripLiterals(t *testing.T) { 
require.Equal(t, "match (n {value: $STRIPPED}) where n.other = $STRIPPED and n.number = $STRIPPED return n.name, n", buffer.String()) } +func TestCypherEmitter_RelationshipDirections(t *testing.T) { + testCases := []struct { + name string + direction graph.Direction + expected string + }{ + { + name: "outbound", + direction: graph.DirectionOutbound, + expected: "match (a)-[r]->(b) return r", + }, + { + name: "inbound", + direction: graph.DirectionInbound, + expected: "match (a)<-[r]-(b) return r", + }, + { + name: "both", + direction: graph.DirectionBoth, + expected: "match (a)-[r]-(b) return r", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + regularQuery, singlePartQuery := cypher.NewRegularQueryWithSingleQuery() + match := singlePartQuery.NewReadingClause().NewMatch(false) + match.NewPatternPart().AddPatternElements( + &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol("a"), + }, + &cypher.RelationshipPattern{ + Variable: cypher.NewVariableWithSymbol("r"), + Direction: testCase.direction, + }, + &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol("b"), + }, + ) + + singlePartQuery.NewProjection(false).AddItem(cypher.NewProjectionItemWithExpr(cypher.NewVariableWithSymbol("r"))) + + rendered, err := format.RegularQuery(regularQuery, false) + require.NoError(t, err) + require.Equal(t, testCase.expected, rendered) + }) + } +} + func TestCypherEmitter_HappyPath(t *testing.T) { test.LoadFixture(t, test.MutationTestCases).Run(t) test.LoadFixture(t, test.PositiveTestCases).Run(t) diff --git a/cypher/models/pgsql/test/query_test.go b/cypher/models/pgsql/test/query_test.go index 38ec85ab..ac39ca9f 100644 --- a/cypher/models/pgsql/test/query_test.go +++ b/cypher/models/pgsql/test/query_test.go @@ -5,12 +5,11 @@ import ( "slices" "testing" - "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/cypher/models/pgsql" "github.com/specterops/dawgs/cypher/models/pgsql/translate" 
"github.com/specterops/dawgs/cypher/models/walk" "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" + v2 "github.com/specterops/dawgs/query/v2" ) var ( @@ -24,21 +23,20 @@ var ( func TestQuery_KindGeneratesInclusiveKindMatcher(t *testing.T) { mapper := newKindMapper() - queries := []*cypher.Where{ - query.Where(query.KindIn(query.Node(), NodeKind1)), - query.Where(query.Kind(query.Node(), NodeKind2)), + queries := []v2.QueryBuilder{ + v2.New().Where(v2.KindIn(v2.Node(), NodeKind1)).Return(v2.Node()), + v2.New().Where(v2.Kind(v2.Node(), NodeKind2)).Return(v2.Node()), } - for _, nodeQuery := range queries { - builder := query.NewBuilderWithCriteria(nodeQuery) - builtQuery, err := builder.Build(false) + for _, queryBuilder := range queries { + builtQuery, err := queryBuilder.Build() if err != nil { - t.Errorf("could not build query: %v", err) + t.Fatalf("could not build query: %v", err) } - translatedQuery, err := translate.Translate(context.Background(), builtQuery, mapper, nil, translate.DefaultGraphID) + translatedQuery, err := translate.Translate(context.Background(), builtQuery.Query, mapper, builtQuery.Parameters, translate.DefaultGraphID) if err != nil { - t.Errorf("could not translate query: %#v: %v", builtQuery, err) + t.Fatalf("could not translate query: %#v: %v", builtQuery, err) } walk.PgSQL(translatedQuery.Statement, walk.NewSimpleVisitor(func(node pgsql.SyntaxNode, visitorHandler walk.VisitorHandler) { @@ -47,7 +45,7 @@ func TestQuery_KindGeneratesInclusiveKindMatcher(t *testing.T) { switch leftTyped := typedNode.LOperand.(type) { case pgsql.CompoundIdentifier: if slices.Equal(leftTyped, pgsql.AsCompoundIdentifier("n0", "kind_ids")) && typedNode.Operator != pgsql.OperatorPGArrayOverlap { - t.Errorf("query did not generate an array overlap operator (&&): %#v", nodeQuery) + t.Errorf("query did not generate an array overlap operator (&&): %#v", builtQuery) } } } diff --git a/cypher/test/cases/mutation_tests.json 
b/cypher/test/cases/mutation_tests.json index 7893439b..dc73b031 100644 --- a/cypher/test/cases/mutation_tests.json +++ b/cypher/test/cases/mutation_tests.json @@ -4,7 +4,7 @@ "name": "Multipart query with mutation", "type": "string_match", "details": { - "query": "match (s:Ship {name: 'Nebuchadnezzar'}) with s as ship merge p = (c:Crew {name: 'Neo'})\u003c-[:CrewOf]-\u003e(ship) set c.title = 'The One' return p", + "query": "match (s:Ship {name: 'Nebuchadnezzar'}) with s as ship merge p = (c:Crew {name: 'Neo'})-[:CrewOf]-(ship) set c.title = 'The One' return p", "fitness": 7 } }, diff --git a/cypher/test/cases/positive_tests.json b/cypher/test/cases/positive_tests.json index b3941c04..cedfccdc 100644 --- a/cypher/test/cases/positive_tests.json +++ b/cypher/test/cases/positive_tests.json @@ -189,7 +189,7 @@ "name": "Specify bi-directional relationship", "type": "string_match", "details": { - "query": "match (p:Person)\u003c-[]-\u003e(m:Movie) return m", + "query": "match (p:Person)-[]-(m:Movie) return m", "fitness": 0 } }, @@ -437,7 +437,7 @@ "name": "built-in shortestPaths()", "type": "string_match", "details": { - "query": "match p = shortestPath((p1:Person)\u003c-[*]-\u003e(p2:Person)) where p1.name = 'tom' and p2.name = 'jerry' return p", + "query": "match p = shortestPath((p1:Person)-[*]-(p2:Person)) where p1.name = 'tom' and p2.name = 'jerry' return p", "fitness": 17 } }, @@ -453,7 +453,7 @@ "name": "Find nodes with relationships", "type": "string_match", "details": { - "query": "match (b) where (b)\u003c-[]-\u003e() return b", + "query": "match (b) where (b)-[]-() return b", "fitness": -4 } }, @@ -461,7 +461,7 @@ "name": "Find nodes with no relationships", "type": "string_match", "details": { - "query": "match (b) where not ((b)\u003c-[]-\u003e()) return b", + "query": "match (b) where not ((b)-[]-()) return b", "fitness": -5 } }, @@ -898,4 +898,4 @@ } } ] -} \ No newline at end of file +} diff --git a/integration/testdata/traversal_shapes.json 
b/integration/testdata/traversal_shapes.json new file mode 100644 index 00000000..d1041096 --- /dev/null +++ b/integration/testdata/traversal_shapes.json @@ -0,0 +1,94 @@ +{ + "graph": { + "nodes": [ + {"id": "c0", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c1", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c2", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c3", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c4", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c5", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c6", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c7", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c8", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c9", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c10", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "f0", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f1", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f2", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f3", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f1a", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f1b", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f2a", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f2b", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f3a", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f3b", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f1a1", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f1b1", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f2a1", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f2b1", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f3a1", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f3b1", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "d0", "kinds": ["TraversalNode", "DiamondNode"]}, + {"id": "d1", "kinds": ["TraversalNode", "DiamondNode"]}, + {"id": "d2", "kinds": ["TraversalNode", "DiamondNode"]}, + {"id": "d3", "kinds": ["TraversalNode", 
"DiamondNode"]}, + {"id": "d4", "kinds": ["TraversalNode", "DiamondNode"]}, + {"id": "y0", "kinds": ["TraversalNode", "CycleNode"]}, + {"id": "y1", "kinds": ["TraversalNode", "CycleNode"]}, + {"id": "y2", "kinds": ["TraversalNode", "CycleNode"]}, + {"id": "y3", "kinds": ["TraversalNode", "CycleNode"]}, + {"id": "x0", "kinds": ["TraversalNode", "DisconnectedNode"]}, + {"id": "x1", "kinds": ["TraversalNode", "DisconnectedNode"]}, + {"id": "s0", "kinds": ["TraversalNode", "SelectiveNode"]}, + {"id": "s1", "kinds": ["TraversalNode", "SelectiveNode"]}, + {"id": "s2", "kinds": ["TraversalNode", "SelectiveNode"]}, + {"id": "s3", "kinds": ["TraversalNode", "SelectiveNode"]}, + {"id": "t1", "kinds": ["TraversalNode", "SelectiveNode"]}, + {"id": "t2", "kinds": ["TraversalNode", "SelectiveNode"]}, + {"id": "t3", "kinds": ["TraversalNode", "SelectiveNode"]} + ], + "edges": [ + {"start_id": "c0", "end_id": "c1", "kind": "ChainEdge"}, + {"start_id": "c1", "end_id": "c2", "kind": "ChainEdge"}, + {"start_id": "c2", "end_id": "c3", "kind": "ChainEdge"}, + {"start_id": "c3", "end_id": "c4", "kind": "ChainEdge"}, + {"start_id": "c4", "end_id": "c5", "kind": "ChainEdge"}, + {"start_id": "c5", "end_id": "c6", "kind": "ChainEdge"}, + {"start_id": "c6", "end_id": "c7", "kind": "ChainEdge"}, + {"start_id": "c7", "end_id": "c8", "kind": "ChainEdge"}, + {"start_id": "c8", "end_id": "c9", "kind": "ChainEdge"}, + {"start_id": "c9", "end_id": "c10", "kind": "ChainEdge"}, + {"start_id": "f0", "end_id": "f1", "kind": "FanoutEdge"}, + {"start_id": "f0", "end_id": "f2", "kind": "FanoutEdge"}, + {"start_id": "f0", "end_id": "f3", "kind": "FanoutEdge"}, + {"start_id": "f1", "end_id": "f1a", "kind": "FanoutEdge"}, + {"start_id": "f1", "end_id": "f1b", "kind": "FanoutEdge"}, + {"start_id": "f2", "end_id": "f2a", "kind": "FanoutEdge"}, + {"start_id": "f2", "end_id": "f2b", "kind": "FanoutEdge"}, + {"start_id": "f3", "end_id": "f3a", "kind": "FanoutEdge"}, + {"start_id": "f3", "end_id": "f3b", "kind": 
"FanoutEdge"}, + {"start_id": "f1a", "end_id": "f1a1", "kind": "FanoutEdge"}, + {"start_id": "f1b", "end_id": "f1b1", "kind": "FanoutEdge"}, + {"start_id": "f2a", "end_id": "f2a1", "kind": "FanoutEdge"}, + {"start_id": "f2b", "end_id": "f2b1", "kind": "FanoutEdge"}, + {"start_id": "f3a", "end_id": "f3a1", "kind": "FanoutEdge"}, + {"start_id": "f3b", "end_id": "f3b1", "kind": "FanoutEdge"}, + {"start_id": "d0", "end_id": "d1", "kind": "DiamondEdge"}, + {"start_id": "d0", "end_id": "d2", "kind": "DiamondEdge"}, + {"start_id": "d0", "end_id": "d3", "kind": "DiamondEdge"}, + {"start_id": "d1", "end_id": "d4", "kind": "DiamondEdge"}, + {"start_id": "d2", "end_id": "d4", "kind": "DiamondEdge"}, + {"start_id": "d3", "end_id": "d4", "kind": "DiamondEdge"}, + {"start_id": "y0", "end_id": "y1", "kind": "CycleEdge"}, + {"start_id": "y1", "end_id": "y2", "kind": "CycleEdge"}, + {"start_id": "y2", "end_id": "y0", "kind": "CycleEdge"}, + {"start_id": "y2", "end_id": "y3", "kind": "CycleEdge"}, + {"start_id": "s0", "end_id": "s1", "kind": "Allowed"}, + {"start_id": "s1", "end_id": "s2", "kind": "Allowed"}, + {"start_id": "s2", "end_id": "s3", "kind": "Allowed"}, + {"start_id": "s0", "end_id": "t1", "kind": "Blocked"}, + {"start_id": "t1", "end_id": "t2", "kind": "Blocked"}, + {"start_id": "t2", "end_id": "t3", "kind": "Blocked"} + ] + } +} diff --git a/query/neo4j/neo4j.go b/query/neo4j/neo4j.go index 0689f74f..7299d004 100644 --- a/query/neo4j/neo4j.go +++ b/query/neo4j/neo4j.go @@ -53,6 +53,14 @@ func (s *QueryBuilder) rewriteParameters() error { return nil } +func hasPreparedMatchPattern(readingClause *cypher.ReadingClause) bool { + if readingClause == nil || readingClause.Match == nil { + return false + } + + return len(readingClause.Match.Pattern) > 0 +} + func (s *QueryBuilder) Apply(criteria graph.Criteria) { switch typedCriteria := criteria.(type) { case *cypher.Where: @@ -201,6 +209,10 @@ func (s *QueryBuilder) prepareMatch() error { return ErrAmbiguousQueryVariables } + 
if firstReadingClause := query.GetFirstReadingClause(s.query); hasPreparedMatchPattern(firstReadingClause) { + return nil + } + if singleNodeBound && !creatingSingleNode { patternPart.AddPatternElements(&cypher.NodePattern{ Variable: cypher.NewVariableWithSymbol(query.NodeSymbol), diff --git a/query/neo4j/neo4j_test.go b/query/neo4j/neo4j_test.go index 2efc3435..05117347 100644 --- a/query/neo4j/neo4j_test.go +++ b/query/neo4j/neo4j_test.go @@ -422,7 +422,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where (n)<-[]->() return n")) + ), "match (n) where (n)-[]-() return n")) t.Run("Node has Relationships Order by Node Item", assertQueryResult(query.SinglePartQuery( query.Where( @@ -436,7 +436,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.OrderBy( query.Order(query.NodeProperty("value"), query.Ascending()), ), - ), "match (n) where (n)<-[]->() return n order by n.value asc")) + ), "match (n) where (n)-[]-() return n order by n.value asc")) t.Run("Node has Relationships Order by Node Item", assertQueryResult(query.SinglePartQuery( query.Where( @@ -451,7 +451,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Order(query.NodeProperty("value_1"), query.Ascending()), query.Order(query.NodeProperty("value_2"), query.Descending()), ), - ), "match (n) where (n)<-[]->() return n order by n.value_1 asc, n.value_2 desc")) + ), "match (n) where (n)-[]-() return n order by n.value_1 asc, n.value_2 desc")) t.Run("Node has Relationships Order by Node Item with Limit and Offset", assertQueryResult(query.SinglePartQuery( query.Where( @@ -469,7 +469,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Limit(10), query.Offset(20), - ), "match (n) where (n)<-[]->() return n order by n.value_1 asc, n.value_2 desc skip 20 limit 10")) + ), "match (n) where (n)-[]-() return n order by n.value_1 asc, n.value_2 desc skip 20 limit 10")) t.Run("Node has no Relationships", assertQueryResult(query.SinglePartQuery( 
query.Where( @@ -479,7 +479,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where not ((n)<-[]->()) return n")) + ), "match (n) where not ((n)-[]-()) return n")) t.Run("Node Datetime Before", assertQueryResult(query.SinglePartQuery( query.Where( diff --git a/query/v2/backend_test.go b/query/v2/backend_test.go new file mode 100644 index 00000000..475e4e45 --- /dev/null +++ b/query/v2/backend_test.go @@ -0,0 +1,359 @@ +package v2_test + +import ( + "context" + "strings" + "testing" + + "github.com/specterops/dawgs/cypher/models/pgsql/translate" + "github.com/specterops/dawgs/drivers/pg/pgutil" + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/query/neo4j" + v2 "github.com/specterops/dawgs/query/v2" + "github.com/stretchr/testify/require" +) + +func testKindMapper(kinds ...graph.Kind) *pgutil.InMemoryKindMapper { + mapper := pgutil.NewInMemoryKindMapper() + + for _, kind := range kinds { + mapper.Put(kind) + } + + return mapper +} + +func TestBackendParityNeo4jPrepare(t *testing.T) { + cases := map[string]struct { + builder v2.QueryBuilder + expectedCypher string + expectedParams map[string]any + }{ + "node read": { + builder: v2.New().Where( + v2.Node().Kinds().Has(graph.StringKind("User")), + v2.Node().Property("name").Contains("admin"), + ).Return( + v2.Node(), + ).OrderBy( + v2.Node().Property("name"), + ), + expectedCypher: "match (n) where n:User and n.name contains $p0 return n order by n.name asc", + expectedParams: map[string]any{"p0": "admin"}, + }, + "relationship read": { + builder: v2.New().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + ).Return( + v2.Start().ID(), + v2.Relationship().ID(), + v2.End().ID(), + ), + expectedCypher: "match (s)-[r:MemberOf]->(e) where id(s) = $p0 return id(s), id(r), id(e)", + expectedParams: map[string]any{"p0": 1}, + }, + "shortest path": { + builder: v2.New().WithShortestPaths().Where( + 
v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ), + expectedCypher: "match p = shortestPath((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", + expectedParams: map[string]any{"p0": 1, "p1": 2}, + }, + "all shortest paths": { + builder: v2.New().WithAllShortestPaths().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ), + expectedCypher: "match p = allShortestPaths((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", + expectedParams: map[string]any{"p0": 1, "p1": 2}, + }, + "recursive traversal": { + builder: v2.New().WithTraversalDepth(v2.DepthRange(1, 2)).Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + ).Return( + v2.Path(), + v2.End().ID(), + ), + expectedCypher: "match p = (s)-[r:MemberOf*1..2]->(e) where id(s) = $p0 return p, id(e)", + expectedParams: map[string]any{"p0": 1}, + }, + "create node": { + builder: v2.New().Create( + v2.NodePattern(graph.Kinds{graph.StringKind("User")}, v2.NamedParameter("props", map[string]any{"name": "u"})), + ).Return( + v2.Node().ID(), + ), + expectedCypher: "create (n:User $p0) return id(n)", + expectedParams: map[string]any{"p0": map[string]any{"name": "u"}}, + }, + "update node": { + builder: v2.New().Where( + v2.Node().ID().Equals(1), + ).Update( + v2.SetProperty(v2.Node().Property("name"), "updated"), + ), + expectedCypher: "match (n) where id(n) = $p0 set n.name = $p1", + expectedParams: map[string]any{"p0": 1, "p1": "updated"}, + }, + "delete relationship": { + builder: v2.New().Where( + v2.Relationship().ID().Equals(1), + ).Delete( + v2.Relationship(), + ), + expectedCypher: "match ()-[r]->() where id(r) = $p0 delete r", + expectedParams: map[string]any{"p0": 1}, + }, + "delete node": { + builder: v2.New().Where( + v2.Node().ID().Equals(1), + 
).Delete( + v2.Node(), + ), + expectedCypher: "match (n) where id(n) = $p0 detach delete n", + expectedParams: map[string]any{"p0": 1}, + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + preparedQuery, err := testCase.builder.Build() + require.NoError(t, err) + + queryBuilder := neo4j.NewQueryBuilder(preparedQuery.Query) + require.NoError(t, queryBuilder.Prepare()) + + rendered, err := queryBuilder.Render() + require.NoError(t, err) + require.Equal(t, testCase.expectedCypher, rendered) + require.Equal(t, testCase.expectedParams, queryBuilder.Parameters) + }) + } +} + +func TestBackendParityPGTranslateTraversalDepth(t *testing.T) { + edgeKind := graph.StringKind("MemberOf") + mapper := testKindMapper(edgeKind) + + cases := map[string]struct { + builder v2.QueryBuilder + expectedSQLContains []string + }{ + "path": { + builder: v2.New().WithTraversalDepth(v2.DepthRange(1, 2)).Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + ).Return( + v2.Path(), + ), + expectedSQLContains: []string{ + "with recursive", + "ordered_edges_to_path", + "n0.id = @pi0::int8", + "e0.kind_id = any (array [1]::int2[])", + "depth < 2", + }, + }, + "endpoints": { + builder: v2.New().WithTraversalDepth(v2.DepthRange(1, 2)).Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + ).Return( + v2.Start().ID(), + v2.End().ID(), + ), + expectedSQLContains: []string{ + "with recursive", + "n0.id = @pi0::int8", + "e0.kind_id = any (array [1]::int2[])", + "depth < 2", + "select (s0.n0).id, (s0.n1).id from s0", + }, + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + preparedQuery, err := testCase.builder.Build() + require.NoError(t, err) + + translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters, translate.DefaultGraphID) + require.NoError(t, err) + + sql, err := translate.Translated(translation) + require.NoError(t, err) + + 
for _, expected := range testCase.expectedSQLContains { + require.Contains(t, sql, expected) + } + }) + } +} + +func TestBackendParityPGTranslate(t *testing.T) { + userKind := graph.StringKind("User") + edgeKind := graph.StringKind("MemberOf") + mapper := testKindMapper(userKind, edgeKind) + + cases := map[string]struct { + builder v2.QueryBuilder + expectedSQL string + expectedParams map[string]any + }{ + "node read": { + builder: v2.New().Where( + v2.Node().Kinds().Has(userKind), + v2.Node().Property("name").Contains("admin"), + ).Return( + v2.Node().ID(), + v2.Node().Kinds(), + ), + expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and cypher_contains((n0.properties ->> 'name'), (@pi0::text)::text)::bool)) select (s0.n0).id, (array(select _kind.name from generate_subscripts((s0.n0).kind_ids, 1) as _kind_idx, kind _kind where _kind.id = ((s0.n0).kind_ids)[_kind_idx] order by _kind_idx))::text[] from s0;", + expectedParams: map[string]any{"pi0": "admin"}, + }, + "relationship read": { + builder: v2.New().Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + ).Return( + v2.Start().ID(), + v2.Relationship().ID(), + v2.End().ID(), + ), + expectedSQL: "with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on (n0.id = @pi0::int8) and n0.id = e0.start_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [2]::int2[])) select (s0.n0).id, (s0.e0).id, (s0.n1).id from s0;", + expectedParams: map[string]any{"pi0": 1}, + }, + "update node": { + builder: v2.New().Where( + v2.Node().ID().Equals(1), + ).Update( + v2.SetProperty(v2.Node().Property("name"), "updated"), + ), + expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, 
n0.properties)::nodecomposite as n0 from node n0 where (n0.id = @pi0::int8)), s1 as (update node n1 set properties = n1.properties || jsonb_build_object('name', @pi1::text)::jsonb from s0 where (s0.n0).id = n1.id returning (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n0) select 1;", + expectedParams: map[string]any{"pi0": 1, "pi1": "updated"}, + }, + "delete relationship": { + builder: v2.New().Where( + v2.Relationship().ID().Equals(1), + ).Delete( + v2.Relationship(), + ), + expectedSQL: "with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0 from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id where (e0.id = @pi0::int8)), s1 as (delete from edge e1 using s0 where (s0.e0).id = e1.id) select 1;", + expectedParams: map[string]any{"pi0": 1}, + }, + "delete node": { + builder: v2.New().Where( + v2.Node().ID().Equals(1), + ).Delete( + v2.Node(), + ), + expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.id = @pi0::int8)), s1 as (delete from node n1 using s0 where (s0.n0).id = n1.id) select 1;", + expectedParams: map[string]any{"pi0": 1}, + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + preparedQuery, err := testCase.builder.Build() + require.NoError(t, err) + + translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters, translate.DefaultGraphID) + require.NoError(t, err) + + sql, err := translate.Translated(translation) + require.NoError(t, err) + require.Equal(t, testCase.expectedSQL, sql) + require.Equal(t, testCase.expectedParams, translation.Parameters) + }) + } +} + +func TestBackendParityPGTranslateShortestPaths(t *testing.T) { + edgeKind := graph.StringKind("MemberOf") + mapper := testKindMapper(edgeKind) + + cases := map[string]struct { + builder v2.QueryBuilder + expectedHarness string + }{ + "shortest path": { + builder: 
v2.New().WithShortestPaths().Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ), + expectedHarness: "bidirectional_sp_harness", + }, + "all shortest paths": { + builder: v2.New().WithAllShortestPaths().Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ), + expectedHarness: "bidirectional_asp_harness", + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + preparedQuery, err := testCase.builder.Build() + require.NoError(t, err) + + translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters, translate.DefaultGraphID) + require.NoError(t, err) + + sql, err := translate.Translated(translation) + require.NoError(t, err) + require.Contains(t, sql, testCase.expectedHarness) + require.Contains(t, sql, "ordered_edges_to_path") + require.Contains(t, sql, "n0.id = 1") + require.Contains(t, sql, "n1.id = 2") + + serializedHarnessQueryHasKindConstraint := false + for _, parameterValue := range translation.Parameters { + if serializedQuery, typeOK := parameterValue.(string); typeOK && strings.Contains(serializedQuery, "array [1]::int2[]") { + serializedHarnessQueryHasKindConstraint = true + break + } + } + require.True(t, serializedHarnessQueryHasKindConstraint, "expected serialized shortest-path harness query to contain edge kind constraint: %#v", translation.Parameters) + }) + } +} + +func TestBackendParityPGCreate(t *testing.T) { + edgeKind := graph.StringKind("MemberOf") + mapper := testKindMapper(edgeKind) + + preparedQuery, err := v2.New().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Create( + v2.RelationshipPattern(edgeKind, nil, graph.DirectionOutbound), + ).Return( + v2.Relationship().ID(), + ).Build() + require.NoError(t, err) + + translation, err := translate.Translate(context.Background(), 
preparedQuery.Query, mapper, preparedQuery.Parameters, translate.DefaultGraphID) + require.NoError(t, err) + + sql, err := translate.Translated(translation) + require.NoError(t, err) + require.Contains(t, sql, "insert into edge") + require.Contains(t, sql, "graph_id") + require.Contains(t, sql, "kind_id") +} diff --git a/query/v2/compat.go b/query/v2/compat.go new file mode 100644 index 00000000..fb6be7b3 --- /dev/null +++ b/query/v2/compat.go @@ -0,0 +1,303 @@ +package v2 + +import ( + "fmt" + "strings" + "time" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/graph" +) + +func Variable(name string) *cypher.Variable { + return cypher.NewVariableWithSymbol(name) +} + +func Identity(reference any) *cypher.FunctionInvocation { + return cypher.NewSimpleFunctionInvocation(cypher.IdentityFunction, expressionOrError(reference)) +} + +func NodeID() *cypher.FunctionInvocation { + return Identity(Identifiers.Node()) +} + +func RelationshipID() *cypher.FunctionInvocation { + return Identity(Identifiers.Relationship()) +} + +func StartID() *cypher.FunctionInvocation { + return Identity(Identifiers.Start()) +} + +func EndID() *cypher.FunctionInvocation { + return Identity(Identifiers.End()) +} + +func Count(reference any) *cypher.FunctionInvocation { + return cypher.NewSimpleFunctionInvocation(cypher.CountFunction, expressionOrError(reference)) +} + +func CountDistinct(reference any) *cypher.FunctionInvocation { + return &cypher.FunctionInvocation{ + Name: cypher.CountFunction, + Distinct: true, + Arguments: []cypher.Expression{expressionOrError(reference)}, + } +} + +func Size(expression any) *cypher.FunctionInvocation { + return cypher.NewSimpleFunctionInvocation(cypher.ListSizeFunction, expressionOrError(expression)) +} + +func KindsOf(reference any) *cypher.FunctionInvocation { + if scopedReference, typeOK := reference.(scopedExpression); typeOK { + if variable, typeOK := scopedReference.qualifier().(*cypher.Variable); !typeOK { + 
return invalidExpression(fmt.Errorf("expected variable reference, got %T", scopedReference.qualifier())) + } else if expression, err := kindProjectionExpression(scopedReference.roleName(), variable); err != nil { + return invalidExpression(err) + } else if invocation, typeOK := expression.(*cypher.FunctionInvocation); !typeOK { + return invalidExpression(fmt.Errorf("expected kind projection function, got %T", expression)) + } else { + return invocation + } + } + + expression := expressionOrError(reference) + + switch typedExpression := expression.(type) { + case *cypher.Variable: + switch typedExpression.Symbol { + case Identifiers.node, Identifiers.start, Identifiers.end: + return cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, typedExpression) + + case Identifiers.relationship: + return cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, typedExpression) + } + } + + return cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, expression) +} + +func Kind(reference any, kinds ...graph.Kind) *cypher.KindMatcher { + return &cypher.KindMatcher{ + Reference: expressionOrError(reference), + Kinds: kinds, + } +} + +func KindIn(reference any, kinds ...graph.Kind) *cypher.KindMatcher { + return Kind(reference, kinds...) 
+} + +func AddKind(reference any, kind graph.Kind) *cypher.SetItem { + return AddKinds(reference, graph.Kinds{kind}) +} + +func AddKinds(reference any, kinds graph.Kinds) *cypher.SetItem { + return cypher.NewSetItem(expressionOrError(reference), cypher.OperatorLabelAssignment, kinds) +} + +func DeleteKind(reference any, kind graph.Kind) *cypher.RemoveItem { + return DeleteKinds(reference, graph.Kinds{kind}) +} + +func DeleteKinds(reference any, kinds graph.Kinds) *cypher.RemoveItem { + return cypher.RemoveKindsByMatcher(cypher.NewKindMatcher(expressionOrError(reference), kinds, false)) +} + +func SetProperty(reference any, value any) *cypher.SetItem { + return cypher.NewSetItem(expressionOrError(reference), cypher.OperatorAssignment, valueExpression(value)) +} + +func SetProperties(reference any, properties map[string]any) *cypher.Set { + set := &cypher.Set{} + + for _, key := range sortedPropertyKeys(properties) { + set.Items = append(set.Items, cypher.NewSetItem( + propertyLookupOrError(reference, key), + cypher.OperatorAssignment, + valueExpression(properties[key]), + )) + } + + return set +} + +func DeleteProperty(reference any) *cypher.RemoveItem { + return cypher.RemoveProperty(expressionOrError(reference)) +} + +func DeleteProperties(reference any, propertyNames ...string) *cypher.Remove { + remove := &cypher.Remove{} + + for _, propertyName := range propertyNames { + remove.Items = append(remove.Items, cypher.RemoveProperty(propertyLookupOrError(reference, propertyName))) + } + + return remove +} + +func NodePattern(kinds graph.Kinds, properties cypher.Expression) *cypher.NodePattern { + return &cypher.NodePattern{ + Variable: Identifiers.Node(), + Kinds: kinds, + Properties: properties, + } +} + +func StartNodePattern(kinds graph.Kinds, properties cypher.Expression) *cypher.NodePattern { + return &cypher.NodePattern{ + Variable: Identifiers.Start(), + Kinds: kinds, + Properties: properties, + } +} + +func EndNodePattern(kinds graph.Kinds, properties 
cypher.Expression) *cypher.NodePattern { + return &cypher.NodePattern{ + Variable: Identifiers.End(), + Kinds: kinds, + Properties: properties, + } +} + +func RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) *cypher.RelationshipPattern { + return &cypher.RelationshipPattern{ + Variable: Identifiers.Relationship(), + Kinds: graph.Kinds{kind}, + Direction: direction, + Properties: properties, + } +} + +func Equals(reference any, value any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorEquals, valueExpression(value)) +} + +func GreaterThan(reference any, value any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorGreaterThan, valueExpression(value)) +} + +func After(reference any, value any) cypher.Expression { + return GreaterThan(reference, value) +} + +func GreaterThanOrEqualTo(reference any, value any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorGreaterThanOrEqualTo, valueExpression(value)) +} + +func GreaterThanOrEquals(reference any, value any) cypher.Expression { + return GreaterThanOrEqualTo(reference, value) +} + +func LessThan(reference any, value any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorLessThan, valueExpression(value)) +} + +func LessThanGraphQuery(reference any, other any) cypher.Expression { + return LessThan(reference, other) +} + +func Before(reference any, value time.Time) cypher.Expression { + return LessThan(reference, value) +} + +func BeforeGraphQuery(reference any, other any) cypher.Expression { + return LessThan(reference, other) +} + +func LessThanOrEqualTo(reference any, value any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorLessThanOrEqualTo, valueExpression(value)) +} + +func LessThanOrEquals(reference any, value any) cypher.Expression { + return 
LessThanOrEqualTo(reference, value) +} + +func In(reference any, value any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorIn, valueExpression(value)) +} + +func InInverted(reference any, value any) cypher.Expression { + return cypher.NewComparison(valueExpression(value), cypher.OperatorIn, expressionOrError(reference)) +} + +func InIDs(reference any, ids ...graph.ID) cypher.Expression { + expression := expressionOrError(reference) + + if variable, typeOK := expression.(*cypher.Variable); typeOK { + expression = Identity(variable) + } + + return cypher.NewComparison(expression, cypher.OperatorIn, Parameter(ids)) +} + +func StringContains(reference any, value string) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorContains, Parameter(value)) +} + +func StringStartsWith(reference any, value string) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorStartsWith, Parameter(value)) +} + +func StringEndsWith(reference any, value string) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorEndsWith, Parameter(value)) +} + +func CaseInsensitiveStringContains(reference any, value string) cypher.Expression { + return cypher.NewComparison( + cypher.NewSimpleFunctionInvocation("toLower", expressionOrError(reference)), + cypher.OperatorContains, + Parameter(strings.ToLower(value)), + ) +} + +func CaseInsensitiveStringStartsWith(reference any, value string) cypher.Expression { + return cypher.NewComparison( + cypher.NewSimpleFunctionInvocation("toLower", expressionOrError(reference)), + cypher.OperatorStartsWith, + Parameter(strings.ToLower(value)), + ) +} + +func CaseInsensitiveStringEndsWith(reference any, value string) cypher.Expression { + return cypher.NewComparison( + cypher.NewSimpleFunctionInvocation("toLower", expressionOrError(reference)), + cypher.OperatorEndsWith, + 
Parameter(strings.ToLower(value)), + ) +} + +func Exists(reference any) cypher.Expression { + return IsNotNull(reference) +} + +func IsNull(reference any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorIs, Literal(nil)) +} + +func IsNotNull(reference any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorIsNot, Literal(nil)) +} + +func HasRelationships(reference any) *cypher.PatternPredicate { + patternPredicate := cypher.NewPatternPredicate() + + if variable, err := variableReference(reference); err != nil { + patternPredicate.AddElement(&cypher.NodePattern{ + Properties: invalidExpression(err), + }) + } else { + patternPredicate.AddElement(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(variable.Symbol), + }) + } + + patternPredicate.AddElement(&cypher.RelationshipPattern{ + Direction: graph.DirectionBoth, + }) + + patternPredicate.AddElement(&cypher.NodePattern{}) + + return patternPredicate +} diff --git a/query/v2/doc.go b/query/v2/doc.go new file mode 100644 index 00000000..a21a5004 --- /dev/null +++ b/query/v2/doc.go @@ -0,0 +1,6 @@ +// Package v2 contains the experimental fluent Cypher query builder. +// +// It is intentionally isolated from the stable query package so callers can +// opt in without pulling the current graph query APIs through a compatibility +// layer. 
+package v2 diff --git a/query/v2/legacy_parity_test.go b/query/v2/legacy_parity_test.go new file mode 100644 index 00000000..e8fe1bee --- /dev/null +++ b/query/v2/legacy_parity_test.go @@ -0,0 +1,213 @@ +package v2_test + +import ( + "testing" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/graph" + legacyquery "github.com/specterops/dawgs/query" + "github.com/specterops/dawgs/query/neo4j" + v2 "github.com/specterops/dawgs/query/v2" + "github.com/stretchr/testify/require" +) + +func renderNeo4jQuery(t *testing.T, regularQuery *cypher.RegularQuery, prepareAllShortestPaths bool) (string, map[string]any) { + t.Helper() + + queryBuilder := neo4j.NewQueryBuilder(regularQuery) + + if prepareAllShortestPaths { + require.NoError(t, queryBuilder.PrepareAllShortestPaths()) + } else { + require.NoError(t, queryBuilder.Prepare()) + } + + rendered, err := queryBuilder.Render() + require.NoError(t, err) + + return rendered, queryBuilder.Parameters +} + +func assertLegacyNeo4jParity(t *testing.T, legacyQuery *cypher.RegularQuery, v2Builder v2.QueryBuilder, prepareLegacyAllShortestPaths bool) { + t.Helper() + + preparedQuery, err := v2Builder.Build() + require.NoError(t, err) + + legacyRendered, legacyParameters := renderNeo4jQuery(t, legacyQuery, prepareLegacyAllShortestPaths) + v2Rendered, v2Parameters := renderNeo4jQuery(t, preparedQuery.Query, false) + + require.Equal(t, legacyRendered, v2Rendered) + require.Equal(t, legacyParameters, v2Parameters) +} + +func TestLegacyNeo4jParity(t *testing.T) { + userKind := graph.StringKind("User") + edgeKind := graph.StringKind("MemberOf") + + t.Run("node count by kind", func(t *testing.T) { + assertLegacyNeo4jParity(t, + legacyquery.SinglePartQuery( + legacyquery.Where( + legacyquery.KindIn(legacyquery.Node(), userKind), + ), + legacyquery.Returning( + legacyquery.Count(legacyquery.Node()), + ), + ), + v2.New().Where( + v2.Node().Kinds().Has(userKind), + ).Return( + v2.Node().Count(), + ), + false, + 
) + }) + + t.Run("node read with pagination", func(t *testing.T) { + assertLegacyNeo4jParity(t, + legacyquery.SinglePartQuery( + legacyquery.Where( + legacyquery.And( + legacyquery.StringContains(legacyquery.NodeProperty("name"), "admin"), + legacyquery.IsNotNull(legacyquery.NodeProperty("enabled")), + ), + ), + legacyquery.Returning( + legacyquery.Node(), + legacyquery.OrderBy(legacyquery.Order(legacyquery.NodeProperty("name"), legacyquery.Ascending())), + legacyquery.Offset(0), + legacyquery.Limit(0), + ), + ), + v2.New().Where( + v2.And( + v2.Node().Property("name").Contains("admin"), + v2.Node().Property("enabled").IsNotNull(), + ), + ).Return( + v2.Node(), + ).OrderBy( + v2.Asc(v2.Node().Property("name")), + ).Skip(0).Limit(0), + false, + ) + }) + + t.Run("node read with or and adjacent predicate", func(t *testing.T) { + assertLegacyNeo4jParity(t, + legacyquery.SinglePartQuery( + legacyquery.Where( + legacyquery.And( + legacyquery.Or( + legacyquery.Equals(legacyquery.NodeProperty("name"), "alice"), + legacyquery.Equals(legacyquery.NodeProperty("name"), "bob"), + ), + legacyquery.IsNotNull(legacyquery.NodeProperty("enabled")), + ), + ), + legacyquery.Returning( + legacyquery.Node(), + ), + ), + v2.New().Where( + v2.Or( + v2.Node().Property("name").Equals("alice"), + v2.Node().Property("name").Equals("bob"), + ), + v2.Node().Property("enabled").IsNotNull(), + ).Return( + v2.Node(), + ), + false, + ) + }) + + t.Run("relationship read", func(t *testing.T) { + assertLegacyNeo4jParity(t, + legacyquery.SinglePartQuery( + legacyquery.Where( + legacyquery.And( + legacyquery.KindIn(legacyquery.Relationship(), edgeKind), + legacyquery.Equals(legacyquery.StartID(), 1), + legacyquery.Equals(legacyquery.EndID(), 2), + ), + ), + legacyquery.Returning( + legacyquery.StartID(), + legacyquery.RelationshipID(), + legacyquery.EndID(), + ), + ), + v2.New().Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + 
v2.Start().ID(), + v2.Relationship().ID(), + v2.End().ID(), + ), + false, + ) + }) + + t.Run("create relationship with matched endpoints", func(t *testing.T) { + properties := map[string]any{"name": "rel"} + + assertLegacyNeo4jParity(t, + legacyquery.SinglePartQuery( + legacyquery.Where( + legacyquery.And( + legacyquery.Equals(legacyquery.StartID(), 1), + legacyquery.Equals(legacyquery.EndID(), 2), + ), + ), + legacyquery.Create( + legacyquery.Start(), + legacyquery.RelationshipPattern(edgeKind, legacyquery.Parameter(properties), graph.DirectionOutbound), + legacyquery.End(), + ), + legacyquery.Returning( + legacyquery.RelationshipID(), + ), + ), + v2.New().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Create( + v2.Start(), + v2.RelationshipPattern(edgeKind, v2.Parameter(properties), graph.DirectionOutbound), + v2.End(), + ).Return( + v2.Relationship().ID(), + ), + false, + ) + }) + + t.Run("all shortest paths", func(t *testing.T) { + assertLegacyNeo4jParity(t, + legacyquery.SinglePartQuery( + legacyquery.Where( + legacyquery.And( + legacyquery.KindIn(legacyquery.Relationship(), edgeKind), + legacyquery.Equals(legacyquery.StartID(), 1), + legacyquery.Equals(legacyquery.EndID(), 2), + ), + ), + legacyquery.Returning( + legacyquery.Path(), + ), + ), + v2.New().WithAllShortestPaths().Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ), + true, + ) + }) +} diff --git a/query/v2/query.go b/query/v2/query.go new file mode 100644 index 00000000..66ecb179 --- /dev/null +++ b/query/v2/query.go @@ -0,0 +1,1560 @@ +package v2 + +import ( + "errors" + "fmt" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/graph" +) + +type runtimeIdentifiers struct { + path string + node string + start string + relationship string + end string +} + +type TraversalDepth struct { + patternRange *cypher.PatternRange + err error +} + +func 
traversalDepthBound(value int64) *int64 { + return &value +} + +func newTraversalDepth(start, end *int64) TraversalDepth { + if start != nil && *start < 0 { + return TraversalDepth{ + err: fmt.Errorf("traversal depth minimum must be non-negative: %d", *start), + } + } + + if end != nil && *end < 0 { + return TraversalDepth{ + err: fmt.Errorf("traversal depth maximum must be non-negative: %d", *end), + } + } + + if start != nil && end != nil && *end < *start { + return TraversalDepth{ + err: fmt.Errorf("traversal depth maximum %d is less than minimum %d", *end, *start), + } + } + + return TraversalDepth{ + patternRange: cypher.NewPatternRange(start, end), + } +} + +func (s TraversalDepth) rangePattern() *cypher.PatternRange { + if s.patternRange == nil { + return &cypher.PatternRange{} + } + + return cypher.Copy(s.patternRange) +} + +func AnyDepth() TraversalDepth { + return newTraversalDepth(nil, nil) +} + +func MinDepth(min int64) TraversalDepth { + return newTraversalDepth(traversalDepthBound(min), nil) +} + +func MaxDepth(max int64) TraversalDepth { + return newTraversalDepth(nil, traversalDepthBound(max)) +} + +func DepthRange(min, max int64) TraversalDepth { + return newTraversalDepth(traversalDepthBound(min), traversalDepthBound(max)) +} + +func ExactDepth(depth int64) TraversalDepth { + depthBound := traversalDepthBound(depth) + return newTraversalDepth(depthBound, depthBound) +} + +// Accessors return fresh variables; compare symbols rather than pointer identity. 
+func (s runtimeIdentifiers) Path() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.path) +} + +func (s runtimeIdentifiers) Node() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.node) +} + +func (s runtimeIdentifiers) Start() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.start) +} + +func (s runtimeIdentifiers) Relationship() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.relationship) +} + +func (s runtimeIdentifiers) End() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.end) +} + +var Identifiers = runtimeIdentifiers{ + path: "p", + node: "n", + start: "s", + relationship: "r", + end: "e", +} + +type Scope struct { + identifiers runtimeIdentifiers + errors []error +} + +func DefaultScope() Scope { + return Scope{ + identifiers: Identifiers, + } +} + +func NewScope(path, node, start, relationship, end string) Scope { + identifiers := runtimeIdentifiers{ + path: path, + node: node, + start: start, + relationship: relationship, + end: end, + } + + return Scope{ + identifiers: identifiers, + errors: validateRuntimeIdentifiers(identifiers), + } +} + +func validateRuntimeIdentifiers(identifiers runtimeIdentifiers) []error { + aliases := []struct { + role string + value string + }{ + {role: "path", value: identifiers.path}, + {role: "node", value: identifiers.node}, + {role: "start", value: identifiers.start}, + {role: "relationship", value: identifiers.relationship}, + {role: "end", value: identifiers.end}, + } + + var ( + errs []error + seen = map[string]string{} + ) + + for _, alias := range aliases { + if err := validateCypherSymbol(alias.value, "scope alias "+alias.role); err != nil { + errs = append(errs, err) + continue + } + + if existingRole, exists := seen[alias.value]; exists { + errs = append(errs, fmt.Errorf("scope aliases %s and %s both use %q", existingRole, alias.role, alias.value)) + } else { + seen[alias.value] = alias.role + } + } + + return errs +} + +func (s Scope) New() QueryBuilder { + 
return newBuilder(s.identifiers, s.errors...) +} + +func (s Scope) Node() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: s.identifiers.Node(), + role: Identifiers.node, + } +} + +func (s Scope) Path() PathContinuation { + return &entity[PathContinuation]{ + identifier: s.identifiers.Path(), + role: Identifiers.path, + } +} + +func (s Scope) Start() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: s.identifiers.Start(), + role: Identifiers.start, + } +} + +func (s Scope) Relationship() RelationshipContinuation { + return &entity[RelationshipContinuation]{ + identifier: s.identifiers.Relationship(), + role: Identifiers.relationship, + } +} + +func (s Scope) End() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: s.identifiers.End(), + role: Identifiers.end, + } +} + +func Literal(value any) *cypher.Literal { + if value == nil { + return cypher.NewLiteral(nil, true) + } + + if strValue, typeOK := value.(string); typeOK { + return cypher.NewStringLiteral(strValue) + } + + return cypher.NewLiteral(value, false) +} + +func Parameter(value any) *cypher.Parameter { + if parameter, typeOK := value.(*cypher.Parameter); typeOK { + return parameter + } + + return &cypher.Parameter{ + Value: value, + } +} + +func NamedParameter(symbol string, value any) *cypher.Parameter { + return cypher.NewParameter(symbol, value) +} + +func valueExpression(value any) cypher.Expression { + switch typedValue := value.(type) { + case *cypher.Parameter: + return typedValue + case *cypher.Literal: + return typedValue + case *cypher.Variable: + return typedValue + case *cypher.PropertyLookup: + return typedValue + case *cypher.FunctionInvocation: + return typedValue + case *cypher.Parenthetical: + return typedValue + case *cypher.Comparison: + return typedValue + case *cypher.Negation: + return typedValue + case *cypher.Conjunction: + return typedValue + case *cypher.Disjunction: + return typedValue + case *cypher.ExclusiveDisjunction: + 
return typedValue + case *cypher.KindMatcher: + return typedValue + case *cypher.ListLiteral: + return typedValue + case cypher.MapLiteral: + return typedValue + case *cypher.PatternPredicate: + return typedValue + case *cypher.ArithmeticExpression: + return typedValue + case *cypher.UnaryAddOrSubtractExpression: + return typedValue + case *cypher.FilterExpression: + return typedValue + case *cypher.IDInCollection: + return typedValue + case QualifiedExpression: + if expression, err := projectionExpression(typedValue); err != nil { + return invalidExpression(err) + } else { + return expression + } + default: + return Parameter(value) + } +} + +func joinedExpressionList(operator cypher.Operator, operands []cypher.SyntaxNode) ([]cypher.Expression, cypher.SyntaxNode) { + if len(operands) == 0 { + return nil, invalidExpression(fmt.Errorf("%s requires at least one operand", operator)) + } + + expressions := make([]cypher.Expression, len(operands)) + for idx, operand := range operands { + expressions[idx] = operand + } + + return expressions, nil +} + +func comparisonHasLogicalOperator(comparison *cypher.Comparison) bool { + if comparison == nil { + return false + } + + for _, partial := range comparison.Partials { + switch partial.Operator { + case cypher.OperatorAnd, cypher.OperatorOr: + return true + } + } + + return false +} + +func parenthesizeDisjunctiveExpression(expression cypher.Expression) cypher.Expression { + switch typedExpression := expression.(type) { + case *cypher.Parenthetical: + return typedExpression + case *cypher.Disjunction, *cypher.ExclusiveDisjunction: + return cypher.NewParenthetical(typedExpression) + case *cypher.Comparison: + if comparisonHasLogicalOperator(typedExpression) { + return cypher.NewParenthetical(typedExpression) + } + } + + return expression +} + +func parenthesizeLogicalExpression(expression cypher.Expression) cypher.Expression { + switch typedExpression := expression.(type) { + case *cypher.Parenthetical: + return 
typedExpression + case *cypher.Conjunction, *cypher.Disjunction, *cypher.ExclusiveDisjunction: + return cypher.NewParenthetical(typedExpression) + case *cypher.Comparison: + if comparisonHasLogicalOperator(typedExpression) { + return cypher.NewParenthetical(typedExpression) + } + } + + return expression +} + +func Not(operand cypher.Expression) cypher.Expression { + return cypher.NewNegation(parenthesizeLogicalExpression(operand)) +} + +func And(operands ...cypher.SyntaxNode) cypher.SyntaxNode { + expressions, errExpression := joinedExpressionList(cypher.OperatorAnd, operands) + if errExpression != nil { + return errExpression + } + + for idx, expression := range expressions { + expressions[idx] = parenthesizeDisjunctiveExpression(expression) + } + + return cypher.NewConjunction(expressions...) +} + +func Or(operands ...cypher.SyntaxNode) cypher.SyntaxNode { + expressions, errExpression := joinedExpressionList(cypher.OperatorOr, operands) + if errExpression != nil { + return errExpression + } + + for idx, expression := range expressions { + expressions[idx] = parenthesizeLogicalExpression(expression) + } + + return cypher.NewParenthetical(cypher.NewDisjunction(expressions...)) +} + +type SortDirection int + +const ( + SortAscending SortDirection = iota + SortDescending +) + +func Asc(expression any) *cypher.SortItem { + return Order(expression, SortAscending) +} + +func Desc(expression any) *cypher.SortItem { + return Order(expression, SortDescending) +} + +func validateSortDirection(direction SortDirection) error { + switch direction { + case SortAscending, SortDescending: + return nil + default: + return fmt.Errorf("unsupported sort direction: %d", direction) + } +} + +func Order(expression any, direction SortDirection) *cypher.SortItem { + expressionValue := expressionOrError(expression) + if err := validateSortDirection(direction); err != nil { + expressionValue = invalidExpression(err) + } + + return &cypher.SortItem{ + Ascending: direction != SortDescending, 
// As returns a projection item that exposes expression under alias.
func As(expression any, alias string) *cypher.ProjectionItem {
	return &cypher.ProjectionItem{
		Expression: expressionOrError(expression),
		Alias:      cypher.NewVariableWithSymbol(alias),
	}
}

// Node returns the node continuation of the default scope.
func Node() NodeContinuation {
	return DefaultScope().Node()
}

// Path returns the path continuation of the default scope.
func Path() PathContinuation {
	return DefaultScope().Path()
}

// Start returns the start-node continuation of the default scope.
func Start() NodeContinuation {
	return DefaultScope().Start()
}

// Relationship returns the relationship continuation of the default scope.
func Relationship() RelationshipContinuation {
	return DefaultScope().Relationship()
}

// End returns the end-node continuation of the default scope.
func End() NodeContinuation {
	return DefaultScope().End()
}

// QualifiedExpression is implemented by builder values that can produce the
// cypher expression they stand for.
type QualifiedExpression interface {
	qualifier() cypher.Expression
}

// scopedExpression is a QualifiedExpression that also reports its role name.
type scopedExpression interface {
	QualifiedExpression

	roleName() string
}

// deleteTarget marks the QualifiedExpression values that Delete accepts as
// deletion targets.
type deleteTarget interface {
	QualifiedExpression

	deleteTarget()
}

// EntityContinuation is the set of operations shared by node and relationship
// continuations.
type EntityContinuation interface {
	QualifiedExpression

	Count() cypher.Expression
	ID() IdentityContinuation
	Property(name string) PropertyContinuation
}

// KindContinuation builds kind-matching expressions.
type KindContinuation interface {
	Is(kind graph.Kind) cypher.Expression
	IsOneOf(kinds graph.Kinds) cypher.Expression
}

// KindsContinuation builds kind-matching expressions as well as set and
// remove items that add or remove kinds.
type KindsContinuation interface {
	Has(kind graph.Kind) cypher.Expression
	HasOneOf(kinds graph.Kinds) cypher.Expression
	Add(kinds graph.Kinds) *cypher.SetItem
	Remove(kinds graph.Kinds) *cypher.RemoveItem
}

// Comparable builds comparison expressions against a value.
type Comparable interface {
	In(value any) cypher.Expression
	Contains(value any) cypher.Expression
	StartsWith(value any) cypher.Expression
	EndsWith(value any) cypher.Expression
	Equals(value any) cypher.Expression
	GreaterThan(value any) cypher.Expression
	GreaterThanOrEqualTo(value any) cypher.Expression
	LessThan(value any) cypher.Expression
	LessThanOrEqualTo(value any) cypher.Expression
	IsNull() cypher.Expression
	IsNotNull() cypher.Expression
}

// PropertyContinuation compares, sets, or removes a single property.
type PropertyContinuation interface {
	QualifiedExpression
	Comparable

	Set(value any) *cypher.SetItem
	Remove() *cypher.RemoveItem
}

// IdentityContinuation compares an entity's identity.
type IdentityContinuation interface {
	QualifiedExpression
	Comparable
}
interface { + QualifiedExpression + Comparable +} + +type comparisonContinuation struct { + qualifierExpression cypher.Expression +} + +func (s *comparisonContinuation) qualifier() cypher.Expression { + return s.qualifierExpression +} + +func (s *comparisonContinuation) asComparison(operator cypher.Operator, rOperand any) cypher.Expression { + return cypher.NewComparison( + s.qualifier(), + operator, + valueExpression(rOperand), + ) +} + +func (s *comparisonContinuation) In(value any) cypher.Expression { + return s.asComparison(cypher.OperatorIn, value) +} + +func (s *comparisonContinuation) Contains(value any) cypher.Expression { + return s.asComparison(cypher.OperatorContains, value) +} + +func (s *comparisonContinuation) StartsWith(value any) cypher.Expression { + return s.asComparison(cypher.OperatorStartsWith, value) +} + +func (s *comparisonContinuation) EndsWith(value any) cypher.Expression { + return s.asComparison(cypher.OperatorEndsWith, value) +} + +func (s *comparisonContinuation) Equals(value any) cypher.Expression { + return s.asComparison(cypher.OperatorEquals, value) +} + +func (s *comparisonContinuation) GreaterThan(value any) cypher.Expression { + return s.asComparison(cypher.OperatorGreaterThan, value) +} + +func (s *comparisonContinuation) GreaterThanOrEqualTo(value any) cypher.Expression { + return s.asComparison(cypher.OperatorGreaterThanOrEqualTo, value) +} + +func (s *comparisonContinuation) LessThan(value any) cypher.Expression { + return s.asComparison(cypher.OperatorLessThan, value) +} + +func (s *comparisonContinuation) LessThanOrEqualTo(value any) cypher.Expression { + return s.asComparison(cypher.OperatorLessThanOrEqualTo, value) +} + +func (s *comparisonContinuation) IsNull() cypher.Expression { + return cypher.NewComparison(s.qualifier(), cypher.OperatorIs, Literal(nil)) +} + +func (s *comparisonContinuation) IsNotNull() cypher.Expression { + return cypher.NewComparison(s.qualifier(), cypher.OperatorIsNot, Literal(nil)) +} + +type 
propertyContinuation struct { + comparisonContinuation +} + +func (s *propertyContinuation) Set(value any) *cypher.SetItem { + return cypher.NewSetItem( + s.qualifier(), + cypher.OperatorAssignment, + valueExpression(value), + ) +} + +func (s *propertyContinuation) Remove() *cypher.RemoveItem { + return cypher.RemoveProperty(s.qualifier()) +} + +type entity[T any] struct { + identifier *cypher.Variable + role string +} + +func (s *entity[T]) Kind() KindContinuation { + return kindContinuation{ + identifier: s.identifier, + role: s.role, + } +} + +func (s *entity[T]) Kinds() KindsContinuation { + return kindsContinuation{ + identifier: s.identifier, + role: s.role, + } +} + +func (s *entity[T]) Count() cypher.Expression { + return cypher.NewSimpleFunctionInvocation(cypher.CountFunction, s.identifier) +} + +func (s *entity[T]) SetProperties(properties map[string]any) *cypher.Set { + set := &cypher.Set{} + + for _, key := range sortedPropertyKeys(properties) { + set.Items = append(set.Items, s.Property(key).Set(properties[key])) + } + + return set +} + +func (s *entity[T]) RemoveProperties(properties []string) *cypher.Remove { + remove := &cypher.Remove{} + + for _, key := range properties { + remove.Items = append(remove.Items, s.Property(key).Remove()) + } + + return remove +} + +func (s *entity[T]) RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) cypher.Expression { + return &cypher.RelationshipPattern{ + Variable: s.identifier, + Kinds: graph.Kinds{kind}, + Direction: direction, + Properties: properties, + } +} + +func (s *entity[T]) NodePattern(kinds graph.Kinds, properties cypher.Expression) cypher.Expression { + return &cypher.NodePattern{ + Variable: s.identifier, + Kinds: kinds, + Properties: properties, + } +} + +func (s *entity[T]) qualifier() cypher.Expression { + return s.identifier +} + +func (s *entity[T]) deleteTarget() {} + +func (s *entity[T]) roleName() string { + return s.role +} + +func (s *entity[T]) 
ID() IdentityContinuation { + return &comparisonContinuation{ + qualifierExpression: &cypher.FunctionInvocation{ + Distinct: false, + Name: cypher.IdentityFunction, + Arguments: []cypher.Expression{s.identifier}, + }, + } +} + +func (s *entity[T]) Property(propertyName string) PropertyContinuation { + return &propertyContinuation{ + comparisonContinuation: comparisonContinuation{ + qualifierExpression: cypher.NewPropertyLookup(s.identifier.Symbol, propertyName), + }, + } +} + +type kindContinuation struct { + identifier *cypher.Variable + role string +} + +func (s kindContinuation) qualifier() cypher.Expression { + return s.identifier +} + +func (s kindContinuation) roleName() string { + return s.role +} + +func (s kindContinuation) Is(kind graph.Kind) cypher.Expression { + return s.IsOneOf(graph.Kinds{kind}) +} + +func (s kindContinuation) IsOneOf(kinds graph.Kinds) cypher.Expression { + return &cypher.KindMatcher{ + Reference: s.identifier, + Kinds: kinds, + } +} + +type kindsContinuation struct { + identifier *cypher.Variable + role string +} + +func (s kindsContinuation) qualifier() cypher.Expression { + return s.identifier +} + +func (s kindsContinuation) roleName() string { + return s.role +} + +func (s kindsContinuation) Has(kind graph.Kind) cypher.Expression { + return s.HasOneOf(graph.Kinds{kind}) +} + +func (s kindsContinuation) HasOneOf(kinds graph.Kinds) cypher.Expression { + return &cypher.KindMatcher{ + Reference: s.identifier, + Kinds: kinds, + } +} + +func (s kindsContinuation) Add(kinds graph.Kinds) *cypher.SetItem { + return cypher.NewSetItem( + s.identifier, + cypher.OperatorLabelAssignment, + kinds, + ) +} + +func (s kindsContinuation) Remove(kinds graph.Kinds) *cypher.RemoveItem { + return cypher.RemoveKindsByMatcher(cypher.NewKindMatcher(s.identifier, kinds, false)) +} + +type PathContinuation interface { + QualifiedExpression + + Count() cypher.Expression +} + +type RelationshipContinuation interface { + EntityContinuation + + 
// RelationshipContinuation is the set of operations available on a
// relationship.
type RelationshipContinuation interface {
	EntityContinuation

	RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) cypher.Expression

	Kind() KindContinuation
	SetProperties(properties map[string]any) *cypher.Set
	RemoveProperties(properties []string) *cypher.Remove
}

// NodeContinuation is the set of operations available on a node.
type NodeContinuation interface {
	EntityContinuation

	NodePattern(kinds graph.Kinds, properties cypher.Expression) cypher.Expression

	Kinds() KindsContinuation
	SetProperties(properties map[string]any) *cypher.Set
	RemoveProperties(properties []string) *cypher.Remove
}

// QueryBuilder accumulates the parts of a query through chained calls and
// renders them with Build. Each method other than Build returns a
// QueryBuilder so calls can be chained.
type QueryBuilder interface {
	Where(constraints ...cypher.SyntaxNode) QueryBuilder
	OrderBy(sortItems ...any) QueryBuilder
	Skip(offset int) QueryBuilder
	// Limit accepts zero, which renders LIMIT 0 and returns an empty result set.
	Limit(limit int) QueryBuilder
	Return(projections ...any) QueryBuilder
	ReturnDistinct(projections ...any) QueryBuilder
	Update(updatingClauses ...any) QueryBuilder
	Create(creationClauses ...any) QueryBuilder
	Delete(expressions ...any) QueryBuilder
	WithShortestPaths() QueryBuilder
	WithAllShortestPaths() QueryBuilder
	WithTraversalDepth(depth TraversalDepth) QueryBuilder
	WithRelationshipDirection(direction graph.Direction) QueryBuilder
	Build() (*PreparedQuery, error)
}

// updatingClauseKind identifies which kind of updating clause a
// pendingUpdatingClause renders as.
type updatingClauseKind int

const (
	updatingClauseSet updatingClauseKind = iota
	updatingClauseRemove
	updatingClauseDelete
	updatingClauseCreate
)

// pendingUpdatingClause is one updating clause awaiting rendering. Only the
// field group matching kind is populated; detach applies to delete clauses.
type pendingUpdatingClause struct {
	kind        updatingClauseKind
	creates     []any
	setItems    []*cypher.SetItem
	removeItems []*cypher.RemoveItem
	deleteItems []cypher.Expression
	detach      bool
}

// builder is the QueryBuilder implementation. updatingClauses keeps updating
// clauses in the order they were requested, while creates, setItems,
// removeItems, and deleteItems hold flat lists of everything accepted by
// Create, Update, and Delete. skip and limit stay nil until set.
//
// NOTE(review): allShorestPathsQuery looks like a misspelling of
// "allShortestPathsQuery"; renaming it needs every use updated together.
type builder struct {
	errors                []error
	constraints           []cypher.SyntaxNode
	sortItems             []any
	projections           []any
	distinct              bool
	identifiers           runtimeIdentifiers
	updatingClauses       []pendingUpdatingClause
	creates               []any
	setItems              []*cypher.SetItem
	removeItems           []*cypher.RemoveItem
	deleteItems           []cypher.Expression
	detachDelete          bool
	relationshipDirection graph.Direction
	traversalDepth        *cypher.PatternRange
	shortestPathQuery     bool
	allShorestPathsQuery  bool
	skip                  *int
	limit                 *int
}
graph.Direction + traversalDepth *cypher.PatternRange + shortestPathQuery bool + allShorestPathsQuery bool + skip *int + limit *int +} + +func New() QueryBuilder { + return DefaultScope().New() +} + +func newBuilder(identifiers runtimeIdentifiers, errs ...error) QueryBuilder { + return &builder{ + identifiers: identifiers, + errors: append([]error(nil), errs...), + relationshipDirection: graph.DirectionOutbound, + } +} + +func (s *builder) WithShortestPaths() QueryBuilder { + s.shortestPathQuery = true + return s +} + +func (s *builder) WithAllShortestPaths() QueryBuilder { + s.allShorestPathsQuery = true + return s +} + +func (s *builder) WithTraversalDepth(depth TraversalDepth) QueryBuilder { + if depth.err != nil { + s.trackError(depth.err) + } else { + s.traversalDepth = depth.rangePattern() + } + + return s +} + +func (s *builder) WithRelationshipDirection(direction graph.Direction) QueryBuilder { + if err := validateRelationshipDirection(direction); err != nil { + s.trackError(err) + } else { + s.relationshipDirection = direction + } + + return s +} + +func (s *builder) OrderBy(sortItems ...any) QueryBuilder { + s.sortItems = append(s.sortItems, sortItems...) + return s +} + +func (s *builder) Skip(skip int) QueryBuilder { + if skip < 0 { + s.trackError(fmt.Errorf("skip must be non-negative: %d", skip)) + return s + } + + s.skip = &skip + return s +} + +func (s *builder) Limit(limit int) QueryBuilder { + if limit < 0 { + s.trackError(fmt.Errorf("limit must be non-negative: %d", limit)) + return s + } + + s.limit = &limit + return s +} + +func (s *builder) Return(projections ...any) QueryBuilder { + s.projections = append(s.projections, projections...) + return s +} + +func (s *builder) ReturnDistinct(projections ...any) QueryBuilder { + s.distinct = true + s.projections = append(s.projections, projections...) 
+ return s +} + +func (s *builder) appendSetItems(items ...*cypher.SetItem) { + if len(items) == 0 { + return + } + + lastClauseIdx := len(s.updatingClauses) - 1 + if lastClauseIdx >= 0 && s.updatingClauses[lastClauseIdx].kind == updatingClauseSet { + s.updatingClauses[lastClauseIdx].setItems = append(s.updatingClauses[lastClauseIdx].setItems, items...) + } else { + s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{ + kind: updatingClauseSet, + setItems: items, + }) + } +} + +func (s *builder) appendRemoveItems(items ...*cypher.RemoveItem) { + if len(items) == 0 { + return + } + + lastClauseIdx := len(s.updatingClauses) - 1 + if lastClauseIdx >= 0 && s.updatingClauses[lastClauseIdx].kind == updatingClauseRemove { + s.updatingClauses[lastClauseIdx].removeItems = append(s.updatingClauses[lastClauseIdx].removeItems, items...) + } else { + s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{ + kind: updatingClauseRemove, + removeItems: items, + }) + } +} + +func (s *builder) appendDeleteItems(detach bool, items ...cypher.Expression) { + if len(items) == 0 { + return + } + + // Consecutive deletes share one clause; any node delete makes the whole clause DETACH DELETE. + lastClauseIdx := len(s.updatingClauses) - 1 + if lastClauseIdx >= 0 && s.updatingClauses[lastClauseIdx].kind == updatingClauseDelete { + s.updatingClauses[lastClauseIdx].detach = s.updatingClauses[lastClauseIdx].detach || detach + s.updatingClauses[lastClauseIdx].deleteItems = append(s.updatingClauses[lastClauseIdx].deleteItems, items...) + } else { + s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{ + kind: updatingClauseDelete, + deleteItems: items, + detach: detach, + }) + } +} + +func (s *builder) Create(creationClauses ...any) QueryBuilder { + s.creates = append(s.creates, creationClauses...) 
+ + if len(creationClauses) > 0 { + s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{ + kind: updatingClauseCreate, + creates: creationClauses, + }) + } + + return s +} + +func (s *builder) Update(updates ...any) QueryBuilder { + for _, nextUpdate := range updates { + switch typedNextUpdate := nextUpdate.(type) { + case *cypher.Set: + if setItems, err := setItemsFromSet(typedNextUpdate); err != nil { + s.trackError(err) + } else { + s.setItems = append(s.setItems, setItems...) + s.appendSetItems(setItems...) + } + + case *cypher.SetItem: + if setItem, err := setItemFromValue(typedNextUpdate); err != nil { + s.trackError(err) + } else { + s.setItems = append(s.setItems, setItem) + s.appendSetItems(setItem) + } + + case *cypher.Remove: + if removeItems, err := removeItemsFromRemove(typedNextUpdate); err != nil { + s.trackError(err) + } else { + s.removeItems = append(s.removeItems, removeItems...) + s.appendRemoveItems(removeItems...) + } + + case *cypher.RemoveItem: + if removeItem, err := removeItemFromValue(typedNextUpdate); err != nil { + s.trackError(err) + } else { + s.removeItems = append(s.removeItems, removeItem) + s.appendRemoveItems(removeItem) + } + + default: + s.trackError(fmt.Errorf("unknown update type: %T", nextUpdate)) + } + } + + return s +} + +func (s *builder) Delete(deleteItems ...any) QueryBuilder { + var pendingDeleteItems []cypher.Expression + pendingDetachDelete := false + + for _, nextDelete := range deleteItems { + switch typedNextUpdate := nextDelete.(type) { + case deleteTarget: + if isNilPointer(typedNextUpdate) { + s.trackError(fmt.Errorf("delete target is nil")) + continue + } + + deleteItem, detach, err := deleteItemFromExpression(typedNextUpdate.qualifier(), s.identifiers) + if err != nil { + s.trackError(err) + continue + } + + if detach { + s.detachDelete = true + pendingDetachDelete = true + } + + s.deleteItems = append(s.deleteItems, deleteItem) + pendingDeleteItems = append(pendingDeleteItems, deleteItem) + + 
case *cypher.Variable: + deleteItem, detach, err := deleteItemFromExpression(typedNextUpdate, s.identifiers) + if err != nil { + s.trackError(err) + continue + } + + if detach { + s.detachDelete = true + pendingDetachDelete = true + } + + s.deleteItems = append(s.deleteItems, deleteItem) + pendingDeleteItems = append(pendingDeleteItems, deleteItem) + + case *cypher.PropertyLookup: + if err := validateExpressionValue(typedNextUpdate, "delete expression"); err != nil { + s.trackError(err) + continue + } + + s.trackError(fmt.Errorf("delete target must be a node, relationship, or variable; use remove for properties")) + + case QualifiedExpression: + if isNilPointer(typedNextUpdate) { + s.trackError(fmt.Errorf("delete target is nil")) + continue + } + + if err := validateExpressionValue(typedNextUpdate.qualifier(), "delete expression"); err != nil { + s.trackError(err) + continue + } + + s.trackError(fmt.Errorf("delete target must be a node, relationship, or variable; got %T", nextDelete)) + + default: + s.trackError(fmt.Errorf("unknown delete type: %T", nextDelete)) + } + } + + s.appendDeleteItems(pendingDetachDelete, pendingDeleteItems...) 
+ return s +} + +func deleteItemFromExpression(expression cypher.Expression, identifiers runtimeIdentifiers) (cypher.Expression, bool, error) { + if err := validateExpressionValue(expression, "delete expression"); err != nil { + return nil, false, err + } + + variable, typeOK := expression.(*cypher.Variable) + if !typeOK || variable == nil { + return nil, false, fmt.Errorf("delete target must resolve to a variable, got %T", expression) + } + + if variable.Symbol == identifiers.path { + return nil, false, fmt.Errorf("delete target must be a node or relationship variable, got path variable %q", variable.Symbol) + } + + return copyExpression(variable), isDetachDeleteQualifier(variable, identifiers), nil +} + +func (s *builder) trackError(err error) { + s.errors = append(s.errors, err) +} + +func (s *builder) Where(constraints ...cypher.SyntaxNode) QueryBuilder { + s.constraints = append(s.constraints, constraints...) + return s +} + +func patternEndsWithNodePattern(pattern *cypher.PatternPart) bool { + numElements := len(pattern.PatternElements) + if numElements == 0 { + return false + } + + return pattern.PatternElements[numElements-1].IsNodePattern() +} + +func isCreateNodeValue(value any, identifiers runtimeIdentifiers) bool { + switch typedValue := value.(type) { + case QualifiedExpression: + if variable, typeOK := typedValue.qualifier().(*cypher.Variable); typeOK { + switch variable.Symbol { + case identifiers.node, identifiers.start, identifiers.end: + return true + } + } + + case *cypher.NodePattern: + return typedValue != nil + } + + return false +} + +func isCreateRelationshipValue(value any) bool { + _, typeOK := value.(*cypher.RelationshipPattern) + return typeOK +} + +func nextCreateValueIsNode(creates []any, idx int, identifiers runtimeIdentifiers) bool { + nextIdx := idx + 1 + return nextIdx < len(creates) && isCreateNodeValue(creates[nextIdx], identifiers) +} + +func newCreatePatternPart(createClause *cypher.Create) *cypher.PatternPart { + pattern := 
&cypher.PatternPart{} + createClause.Pattern = append(createClause.Pattern, pattern) + return pattern +} + +func createPatternHasElements(pattern *cypher.PatternPart) bool { + return pattern != nil && len(pattern.PatternElements) > 0 +} + +func shouldStartNewCreatePattern(pattern *cypher.PatternPart, nextCreate any, patternClosed bool, identifiers runtimeIdentifiers) bool { + if !createPatternHasElements(pattern) { + return false + } + + if isCreateNodeValue(nextCreate, identifiers) && patternEndsWithNodePattern(pattern) { + return true + } + + return patternClosed && isCreateRelationshipValue(nextCreate) +} + +func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeIdentifiers, creates []any) error { + if len(creates) == 0 { + return nil + } + + var ( + createClause = &cypher.Create{ + Unique: false, + } + pattern = newCreatePatternPart(createClause) + patternClosed bool + ) + + for idx, nextCreate := range creates { + if shouldStartNewCreatePattern(pattern, nextCreate, patternClosed, identifiers) { + pattern = newCreatePatternPart(createClause) + patternClosed = false + } + + switch typedNextCreate := nextCreate.(type) { + case QualifiedExpression: + switch typedExpression := typedNextCreate.qualifier().(type) { + case *cypher.Variable: + if typedExpression == nil { + return fmt.Errorf("invalid variable reference for create: ") + } + + switch typedExpression.Symbol { + case identifiers.node, identifiers.start, identifiers.end: + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(typedExpression.Symbol), + }) + patternClosed = false + + default: + return fmt.Errorf("invalid variable reference for create: %s", typedExpression.Symbol) + } + + default: + return fmt.Errorf("invalid qualified expression for create: %T", typedExpression) + } + + case *cypher.NodePattern: + if err := validateNodePattern(typedNextCreate); err != nil { + return err + } + + pattern.AddPatternElements(cypher.Copy(typedNextCreate)) + 
patternClosed = false + + case *cypher.RelationshipPattern: + if err := validateRelationshipPattern(typedNextCreate); err != nil { + return err + } + + if !patternEndsWithNodePattern(pattern) { + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.Start(), + }) + } + + pattern.AddPatternElements(cypher.Copy(typedNextCreate)) + + if !nextCreateValueIsNode(creates, idx, identifiers) { + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.End(), + }) + patternClosed = true + } else { + patternClosed = false + } + + default: + return fmt.Errorf("invalid type for create: %T", nextCreate) + } + } + + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause(createClause)) + return nil +} + +func (s *builder) buildUpdatingClauses(singlePartQuery *cypher.SinglePartQuery) error { + for _, updatingClause := range s.updatingClauses { + switch updatingClause.kind { + case updatingClauseSet: + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( + cypher.NewSet(updatingClause.setItems), + )) + + case updatingClauseRemove: + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( + cypher.NewRemove(updatingClause.removeItems), + )) + + case updatingClauseDelete: + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( + cypher.NewDelete( + updatingClause.detach, + updatingClause.deleteItems, + ), + )) + + case updatingClauseCreate: + if err := buildCreates(singlePartQuery, s.identifiers, updatingClause.creates); err != nil { + return err + } + } + } + + return nil +} + +func (s *builder) buildProjectionOrder() (*cypher.Order, error) { + var orderByNode *cypher.Order + + if len(s.sortItems) > 0 { + orderByNode = &cypher.Order{} + + for _, untypedSortItem := range s.sortItems { + switch typedSortItem := untypedSortItem.(type) { + case *cypher.Order: + if sortItems, 
err := sortItemsFromOrder(typedSortItem); err != nil { + return nil, err + } else { + orderByNode.Items = append(orderByNode.Items, sortItems...) + } + + case *cypher.SortItem: + if sortItem, err := sortItemFromValue(typedSortItem); err != nil { + return nil, err + } else { + orderByNode.Items = append(orderByNode.Items, sortItem) + } + + default: + if sortItem, err := sortItemFromValue(typedSortItem); err != nil { + return nil, err + } else { + orderByNode.Items = append(orderByNode.Items, sortItem) + } + } + } + } + + return orderByNode, nil +} + +func appendProjectionOrder(projection *cypher.Projection, sortItems ...*cypher.SortItem) { + if len(sortItems) == 0 { + return + } + + if projection.Order == nil { + projection.Order = &cypher.Order{} + } + + projection.Order.Items = append(projection.Order.Items, sortItems...) +} + +func applyReturnProjection(projection *cypher.Projection, returnClause *cypher.Return) error { + if projectionItems, err := projectionItemsFromReturn(returnClause); err != nil { + return err + } else { + projection.Distinct = projection.Distinct || returnClause.Projection.Distinct + projection.All = projection.All || returnClause.Projection.All + + for _, projectionItem := range projectionItems { + projection.AddItem(projectionItem) + } + } + + if returnClause.Projection.Order != nil { + if sortItems, err := sortItemsFromOrder(returnClause.Projection.Order); err != nil { + return err + } else { + appendProjectionOrder(projection, sortItems...) 
+ } + } + + if returnClause.Projection.Skip != nil { + projection.Skip = copySkip(returnClause.Projection.Skip) + } + + if returnClause.Projection.Limit != nil { + projection.Limit = copyLimit(returnClause.Projection.Limit) + } + + return nil +} + +func (s *builder) buildProjection(singlePartQuery *cypher.SinglePartQuery) error { + var ( + hasProjectedItems = len(s.projections) > 0 + hasSkip = s.skip != nil + hasLimit = s.limit != nil + requiresProjection = hasProjectedItems || hasSkip || hasLimit + ) + + if requiresProjection { + if !hasProjectedItems { + return fmt.Errorf("query expected projected items") + } + + projection := singlePartQuery.NewProjection(s.distinct) + + for _, nextProjection := range s.projections { + switch typedNextProjection := nextProjection.(type) { + case *cypher.Return: + if err := applyReturnProjection(projection, typedNextProjection); err != nil { + return err + } + + default: + if projectionItem, err := projectionItemFromValue(typedNextProjection); err != nil { + return err + } else { + projection.AddItem(projectionItem) + } + } + } + + if s.skip != nil { + projection.Skip = cypher.NewSkip(*s.skip) + } + + if s.limit != nil { + projection.Limit = cypher.NewLimit(*s.limit) + } + + if projectionOrder, err := s.buildProjectionOrder(); err != nil { + return err + } else if projectionOrder != nil { + appendProjectionOrder(projection, projectionOrder.Items...) 
+ } + } + + return nil +} + +func countRelationshipKindMatchers(constraints []cypher.SyntaxNode, identifiers runtimeIdentifiers) (int, error) { + var count int + + for _, nextConstraint := range constraints { + if kindMatcher, typeOK := nextConstraint.(*cypher.KindMatcher); typeOK { + if identifier, typeOK := kindMatcher.Reference.(*cypher.Variable); !typeOK { + return 0, fmt.Errorf("expected type *cypher.Variable, got %T", kindMatcher.Reference) + } else if identifier.Symbol == identifiers.relationship { + count++ + } + } + } + + return count, nil +} + +type PreparedQuery struct { + Query *cypher.RegularQuery + Parameters map[string]any +} + +func (s *builder) hasActions() bool { + return len(s.projections) > 0 || len(s.setItems) > 0 || len(s.removeItems) > 0 || len(s.creates) > 0 || len(s.deleteItems) > 0 +} + +func (s *builder) wantsShortestPathPattern() bool { + return s.shortestPathQuery || s.allShorestPathsQuery +} + +func (s *builder) wantsTraversalPattern() bool { + return s.traversalDepth != nil +} + +func (s *builder) usesRangedRelationshipPattern() bool { + return s.wantsTraversalPattern() || s.wantsShortestPathPattern() +} + +func (s *builder) Build() (*PreparedQuery, error) { + if len(s.errors) > 0 { + return nil, errors.Join(s.errors...) + } + + if !s.hasActions() { + return nil, fmt.Errorf("query has no action specified") + } + + if err := collectModelErrorsFromKnownValues(s.constraints, s.creates, s.setItems, s.removeItems, s.deleteItems, s.projections, s.sortItems); err != nil { + return nil, err + } + + var ( + regularQuery, singlePartQuery = cypher.NewRegularQueryWithSingleQuery() + match = &cypher.Match{} + readIdentifiers = newIdentifierSet() + relationshipKinds graph.Kinds + ) + + createScope, err := collectCreateScope(s.identifiers, s.creates...) 
+ if err != nil { + return nil, err + } + + if err := s.buildUpdatingClauses(singlePartQuery); err != nil { + return nil, err + } + + if err := s.buildProjection(singlePartQuery); err != nil { + return nil, err + } + + if len(s.constraints) > 0 { + var ( + whereClause = match.NewWhere() + constraints = cypher.NewConjunction() + numRelationshipKindMatchers, err = countRelationshipKindMatchers(s.constraints, s.identifiers) + ) + if err != nil { + return nil, err + } + + for _, nextConstraint := range s.constraints { + switch typedNextConstraint := nextConstraint.(type) { + case *cypher.KindMatcher: + if identifier, typeOK := typedNextConstraint.Reference.(*cypher.Variable); !typeOK { + return nil, fmt.Errorf("expected type *cypher.Variable, got %T", typedNextConstraint.Reference) + } else if identifier.Symbol == s.identifiers.relationship && numRelationshipKindMatchers == 1 { + relationshipKinds = relationshipKinds.Add(typedNextConstraint.Kinds...) + readIdentifiers.Add(s.identifiers.relationship) + continue + } + } + + constraintCopy := cypher.Copy(nextConstraint) + constraints.Add(parenthesizeDisjunctiveExpression(constraintCopy)) + } + + if constraints.Len() > 0 { + whereClause.Add(constraints) + + whereIdentifiers := newIdentifierSet() + if err := whereIdentifiers.CollectFromExpression(whereClause); err != nil { + return nil, err + } + + if s.usesRangedRelationshipPattern() && whereIdentifiers.Contains(s.identifiers.relationship) { + return nil, fmt.Errorf("ranged relationship patterns only support top-level relationship kind constraints") + } + + readIdentifiers.Or(whereIdentifiers) + } + } + + actionIdentifiers, err := collectIdentifiersFromValues(s.setItems, s.removeItems, s.deleteItems, s.projections, s.sortItems) + if err != nil { + return nil, err + } + + actionIdentifiers.Remove(createScope.identifiers) + + if s.usesRangedRelationshipPattern() && actionIdentifiers.Contains(s.identifiers.relationship) { + return nil, fmt.Errorf("ranged relationship patterns 
do not support relationship projections or mutations; return the path instead") + } + + matchIdentifiers := readIdentifiers.Clone() + matchIdentifiers.Or(actionIdentifiers) + + if err := validateKnownIdentifiers(matchIdentifiers, s.identifiers); err != nil { + return nil, err + } + + if s.wantsTraversalPattern() && !isRelationshipPattern(matchIdentifiers, s.identifiers) && !matchIdentifiers.Contains(s.identifiers.path) { + return nil, fmt.Errorf("recursive traversal query requires relationship query identifiers") + } + + if s.wantsShortestPathPattern() && !isRelationshipPattern(matchIdentifiers, s.identifiers) { + return nil, fmt.Errorf("shortest path query requires relationship query identifiers") + } + + if len(s.constraints) > 0 || matchIdentifiers.Len() > 0 { + if isNodePattern(matchIdentifiers, s.identifiers) { + if err := prepareNodePattern(match, matchIdentifiers, s.identifiers); err != nil { + return nil, err + } + } else if createScope.createsRelationship && !matchIdentifiers.Contains(s.identifiers.relationship) { + if err := prepareCreateRelationshipMatch(match, matchIdentifiers, s.identifiers); err != nil { + return nil, err + } + } else if isRelationshipPattern(matchIdentifiers, s.identifiers) || (s.wantsTraversalPattern() && matchIdentifiers.Contains(s.identifiers.path)) { + if err := prepareRelationshipPattern(match, matchIdentifiers, s.identifiers, relationshipKinds, s.traversalDepth, s.relationshipDirection, s.shortestPathQuery, s.allShorestPathsQuery); err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("query has no node and relationship query identifiers specified") + } + } + + if len(match.Pattern) > 0 { + newReadingClause := cypher.NewReadingClause() + newReadingClause.Match = match + + singlePartQuery.ReadingClauses = append(singlePartQuery.ReadingClauses, newReadingClause) + } + + if err := collectModelErrors(regularQuery); err != nil { + return nil, err + } + + if parameters, err := materializeParameters(regularQuery); err 
!= nil { + return nil, err + } else { + return &PreparedQuery{ + Query: regularQuery, + Parameters: parameters, + }, nil + } +} diff --git a/query/v2/query_test.go b/query/v2/query_test.go new file mode 100644 index 00000000..61ea0d32 --- /dev/null +++ b/query/v2/query_test.go @@ -0,0 +1,985 @@ +package v2_test + +import ( + "testing" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/cypher/format" + "github.com/specterops/dawgs/graph" + v2 "github.com/specterops/dawgs/query/v2" + "github.com/stretchr/testify/require" +) + +func renderPrepared(t *testing.T, preparedQuery *v2.PreparedQuery) string { + t.Helper() + + cypherQueryStr, err := format.RegularQuery(preparedQuery.Query, false) + require.NoError(t, err) + + return cypherQueryStr +} + +func firstCreateClause(t *testing.T, preparedQuery *v2.PreparedQuery) *cypher.Create { + t.Helper() + + updatingClauses := preparedQuery.Query.SingleQuery.SinglePartQuery.UpdatingClauses + require.NotEmpty(t, updatingClauses) + + updatingClause, typeOK := updatingClauses[0].(*cypher.UpdatingClause) + require.True(t, typeOK) + + createClause, typeOK := updatingClause.Clause.(*cypher.Create) + require.True(t, typeOK) + + return createClause +} + +func TestQuery(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Not(v2.Relationship().Kind().Is(graph.StringKind("test"))), + v2.Not(v2.Relationship().Kind().IsOneOf(graph.Kinds{graph.StringKind("A"), graph.StringKind("B")})), + v2.Relationship().Property("rel_prop").LessThanOrEqualTo(1234), + v2.Relationship().Property("other_prop").Equals(5678), + v2.Start().Kinds().HasOneOf(graph.Kinds{graph.StringKind("test")}), + ).Update( + v2.Start().Property("this_prop").Set(1234), + v2.End().Kinds().Remove(graph.Kinds{graph.StringKind("A"), graph.StringKind("B")}), + ).Delete( + v2.Start(), + ).Return( + v2.Relationship(), + v2.Start().Property("node_prop"), + ).Skip(10).Limit(10).Build() + require.NoError(t, err) + + cypherQueryStr, 
err := format.RegularQuery(preparedQuery.Query, false) + require.NoError(t, err) + + require.Equal(t, "match (s)-[r]->(e) where not r:test and not (r:A or r:B) and r.rel_prop <= $p0 and r.other_prop = $p1 and s:test set s.this_prop = $p2 remove e:A:B detach delete s return r, s.node_prop skip 10 limit 10", cypherQueryStr) + require.Equal(t, map[string]any{ + "p0": 1234, + "p1": 5678, + "p2": 1234, + }, preparedQuery.Parameters) + + preparedQuery, err = v2.New().Create( + v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, cypher.NewParameter("props", map[string]any{})), + ).Build() + + require.NoError(t, err) + + cypherQueryStr, err = format.RegularQuery(preparedQuery.Query, false) + require.NoError(t, err) + + require.Equal(t, "create (n:A $props)", cypherQueryStr) + require.Equal(t, map[string]any{ + "props": map[string]any{}, + }, preparedQuery.Parameters) +} + +func TestCreateRelationshipWithMatchedEndpoints(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Create( + v2.Relationship().RelationshipPattern(graph.StringKind("A"), v2.NamedParameter("props", map[string]any{"name": "rel"}), graph.DirectionOutbound), + ).Return( + v2.Relationship().ID(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (s), (e) where id(s) = $p0 and id(e) = $p1 create (s)-[r:A $props]->(e) return id(r)", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": 2, + "props": map[string]any{"name": "rel"}, + }, preparedQuery.Parameters) +} + +func TestCreateRelationshipWithExplicitEndpoints(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Create( + v2.Start(), + v2.RelationshipPattern(graph.StringKind("A"), v2.NamedParameter("props", map[string]any{"name": "rel"}), graph.DirectionOutbound), + v2.End(), + ).Return( + v2.Relationship().ID(), + ).Build() + require.NoError(t, err) + + require.Equal(t, 
"match (s), (e) where id(s) = $p0 and id(e) = $p1 create (s)-[r:A $props]->(e) return id(r)", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": 2, + "props": map[string]any{"name": "rel"}, + }, preparedQuery.Parameters) +} + +func TestCreateSplitsDisjointNodePatterns(t *testing.T) { + preparedQuery, err := v2.New().Create( + v2.NodePattern(graph.Kinds{graph.StringKind("A")}, nil), + v2.NodePattern(graph.Kinds{graph.StringKind("B")}, nil), + ).Build() + require.NoError(t, err) + + require.Equal(t, "create (n:A), (n:B)", renderPrepared(t, preparedQuery)) + require.Len(t, firstCreateClause(t, preparedQuery).Pattern, 2) +} + +func TestCreateSplitsBackToBackRelationshipPatterns(t *testing.T) { + preparedQuery, err := v2.New().Create( + v2.RelationshipPattern(graph.StringKind("A"), nil, graph.DirectionOutbound), + v2.RelationshipPattern(graph.StringKind("B"), nil, graph.DirectionOutbound), + ).Build() + require.NoError(t, err) + + require.Equal(t, "create (s)-[r:A]->(e), (s)-[r:B]->(e)", renderPrepared(t, preparedQuery)) + require.Len(t, firstCreateClause(t, preparedQuery).Pattern, 2) +} + +func TestCreateNodeReturnDoesNotCreateMatch(t *testing.T) { + preparedQuery, err := v2.New().Create( + v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, v2.NamedParameter("props", map[string]any{"name": "node"})), + ).Return( + v2.Node().ID(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "create (n:A $props) return id(n)", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "props": map[string]any{"name": "node"}, + }, preparedQuery.Parameters) +} + +func TestLogicalHelpersPreservePrecedence(t *testing.T) { + a := v2.Node().Property("a").Equals("a") + b := v2.Node().Property("b").Equals("b") + c := v2.Node().Property("c").Equals("c") + + testCases := []struct { + name string + builder v2.QueryBuilder + expected string + }{ + { + name: "or is parenthesized in isolation", + builder: v2.New().Where( + v2.Or(a, 
b), + ).Return(v2.Node()), + expected: "match (n) where (n.a = $p0 or n.b = $p1) return n", + }, + { + name: "or is parenthesized when where and-chains constraints", + builder: v2.New().Where( + v2.Or(a, b), + c, + ).Return(v2.Node()), + expected: "match (n) where (n.a = $p0 or n.b = $p1) and n.c = $p2 return n", + }, + { + name: "nested or is parenthesized inside and", + builder: v2.New().Where( + v2.And(a, v2.Or(b, c)), + ).Return(v2.Node()), + expected: "match (n) where n.a = $p0 and (n.b = $p1 or n.c = $p2) return n", + }, + { + name: "nested and is parenthesized inside or", + builder: v2.New().Where( + v2.Or(v2.And(a, b), c), + ).Return(v2.Node()), + expected: "match (n) where ((n.a = $p0 and n.b = $p1) or n.c = $p2) return n", + }, + { + name: "not wraps or", + builder: v2.New().Where( + v2.Not(v2.Or(a, b)), + ).Return(v2.Node()), + expected: "match (n) where not (n.a = $p0 or n.b = $p1) return n", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + preparedQuery, err := testCase.builder.Build() + require.NoError(t, err) + require.Equal(t, testCase.expected, renderPrepared(t, preparedQuery)) + }) + } +} + +func TestInvalidCreateQualifiedExpressionReturnsError(t *testing.T) { + _, err := v2.New().Create(v2.Node().Property("name")).Build() + require.ErrorContains(t, err, "invalid qualified expression for create: *cypher.PropertyLookup") +} + +func TestUpdatingClausesPreserveFluentOrder(t *testing.T) { + preparedQuery, err := v2.New().Create( + v2.NodePattern(graph.Kinds{graph.StringKind("User")}, nil), + ).Update( + v2.SetProperty(v2.Node().Property("name"), "created"), + ).Return( + v2.Node().Property("name"), + ).Build() + require.NoError(t, err) + + require.Equal(t, "create (n:User) set n.name = $p0 return n.name", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": "created", + }, preparedQuery.Parameters) + + preparedQuery, err = v2.New().Where( + v2.Node().ID().Equals(1), + ).Update( + 
v2.DeleteProperties(v2.Node(), "old"), + ).Update( + v2.SetProperty(v2.Node().Property("new"), "value"), + ).Return( + v2.Node(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) where id(n) = $p0 remove n.old set n.new = $p1 return n", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": "value", + }, preparedQuery.Parameters) +} + +func TestScopedRelationshipPatternControls(t *testing.T) { + scope := v2.NewScope("path", "person", "source", "edge", "target") + + preparedQuery, err := scope.New().WithRelationshipDirection(graph.DirectionInbound).Where( + scope.Relationship().Kind().Is(graph.StringKind("MemberOf")), + scope.Start().ID().Equals(1), + ).Return( + scope.Relationship().Kind(), + scope.End().Kinds(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (source)<-[edge:MemberOf]-(target) where id(source) = $p0 return type(edge), labels(target)", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + }, preparedQuery.Parameters) +} + +func TestScopedKindsOfCompatibilityHelper(t *testing.T) { + scope := v2.NewScope("path", "person", "source", "edge", "target") + + preparedQuery, err := scope.New().Return( + v2.KindsOf(scope.Relationship()), + v2.KindsOf(scope.End()), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match ()-[edge]->(target) return type(edge), labels(target)", renderPrepared(t, preparedQuery)) +} + +func TestInvalidScopeAliasesReturnBuildErrors(t *testing.T) { + emptyAliasScope := v2.NewScope("", "node", "start", "relationship", "end") + _, err := emptyAliasScope.New().Return(emptyAliasScope.Node()).Build() + require.ErrorContains(t, err, "scope alias path is empty") + + duplicateAliasScope := v2.NewScope("path", "node", "node", "relationship", "end") + _, err = duplicateAliasScope.New().Return(duplicateAliasScope.Start()).Build() + require.ErrorContains(t, err, `scope aliases node and start both use "node"`) + + invalidAliasScope := 
v2.NewScope("path", "bad name", "start", "relationship", "end") + _, err = invalidAliasScope.New().Return(invalidAliasScope.Node()).Build() + require.ErrorContains(t, err, `scope alias node has invalid symbol "bad name"`) +} + +func TestUnicodeCypherSymbols(t *testing.T) { + scope := v2.NewScope("路径", "节点", "起点", "关系", "终点") + + preparedQuery, err := scope.New().Where( + scope.Node().Property("name").Equals(v2.NamedParameter("名字", "alice")), + ).Return( + v2.As(scope.Node().ID(), "标识"), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (节点) where 节点.name = $名字 return id(节点) as 标识", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "名字": "alice", + }, preparedQuery.Parameters) +} + +func TestInvalidRelationshipDirectionReturnsError(t *testing.T) { + _, err := v2.New().WithRelationshipDirection(graph.Direction(99)).Return(v2.Relationship()).Build() + require.ErrorContains(t, err, "unsupported relationship direction: invalid") +} + +func TestRelationshipDirectionBoth(t *testing.T) { + preparedQuery, err := v2.New().WithRelationshipDirection(graph.DirectionBoth).Return(v2.Relationship()).Build() + require.NoError(t, err) + + require.Equal(t, "match ()-[r]-() return r", renderPrepared(t, preparedQuery)) +} + +func TestShortestPathControls(t *testing.T) { + preparedQuery, err := v2.New().WithShortestPaths().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ).Build() + require.NoError(t, err) + require.Equal(t, "match p = shortestPath((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": 2, + }, preparedQuery.Parameters) + + preparedQuery, err = v2.New().WithAllShortestPaths().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ).Build() 
+ require.NoError(t, err) + require.Equal(t, "match p = allShortestPaths((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery)) + + preparedQuery, err = v2.New().WithShortestPaths().WithTraversalDepth(v2.MinDepth(1)).Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ).Build() + require.NoError(t, err) + require.Equal(t, "match p = shortestPath((s)-[r:MemberOf*1..]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery)) + + _, err = v2.New().WithShortestPaths().WithAllShortestPaths().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ).Build() + require.ErrorContains(t, err, "query is requesting both all shortest paths and shortest paths") + + _, err = v2.New().WithShortestPaths().Return(v2.Node()).Build() + require.ErrorContains(t, err, "shortest path query requires relationship query identifiers") + + _, err = v2.New().WithAllShortestPaths().Return(v2.As(v2.Literal(1), "one")).Build() + require.ErrorContains(t, err, "shortest path query requires relationship query identifiers") +} + +func TestTraversalDepthControls(t *testing.T) { + cases := map[string]struct { + builder v2.QueryBuilder + expectedCypher string + expectedParams map[string]any + }{ + "any depth": { + builder: v2.New().WithTraversalDepth(v2.AnyDepth()).Return(v2.End()), + expectedCypher: "match ()-[*]->(e) return e", + }, + "minimum depth": { + builder: v2.New().WithTraversalDepth(v2.MinDepth(1)).Return(v2.End()), + expectedCypher: "match ()-[*1..]->(e) return e", + }, + "maximum depth": { + builder: v2.New().WithTraversalDepth(v2.MaxDepth(5)).Return(v2.End()), + expectedCypher: "match ()-[*..5]->(e) return e", + }, + "depth range": { + builder: v2.New().WithTraversalDepth(v2.DepthRange(1, 5)).Where( + v2.Relationship().Kind().IsOneOf(graph.Kinds{graph.StringKind("KindA"), 
graph.StringKind("KindB")}), + v2.Start().ID().Equals(1), + v2.End().Kinds().Has(graph.StringKind("User")), + ).Return( + v2.Path(), + v2.End(), + ), + expectedCypher: "match p = (s)-[r:KindA|KindB*1..5]->(e) where id(s) = $p0 and e:User return p, e", + expectedParams: map[string]any{"p0": 1}, + }, + "exact depth": { + builder: v2.New().WithTraversalDepth(v2.ExactDepth(3)).Return(v2.End()), + expectedCypher: "match ()-[*3..3]->(e) return e", + }, + "inbound depth range": { + builder: v2.New().WithTraversalDepth(v2.DepthRange(2, 5)).WithRelationshipDirection(graph.DirectionInbound).Return(v2.Start()), + expectedCypher: "match (s)<-[*2..5]-() return s", + }, + "path only": { + builder: v2.New().WithTraversalDepth(v2.AnyDepth()).Return(v2.Path()), + expectedCypher: "match p = ()-[*]->() return p", + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + preparedQuery, err := testCase.builder.Build() + require.NoError(t, err) + require.Equal(t, testCase.expectedCypher, renderPrepared(t, preparedQuery)) + + if testCase.expectedParams == nil { + require.Empty(t, preparedQuery.Parameters) + } else { + require.Equal(t, testCase.expectedParams, preparedQuery.Parameters) + } + }) + } +} + +func TestInvalidTraversalDepthControls(t *testing.T) { + _, err := v2.New().WithTraversalDepth(v2.MinDepth(-1)).Return(v2.End()).Build() + require.ErrorContains(t, err, "traversal depth minimum must be non-negative: -1") + + _, err = v2.New().WithTraversalDepth(v2.MaxDepth(-1)).Return(v2.End()).Build() + require.ErrorContains(t, err, "traversal depth maximum must be non-negative: -1") + + _, err = v2.New().WithTraversalDepth(v2.DepthRange(3, 1)).Return(v2.End()).Build() + require.ErrorContains(t, err, "traversal depth maximum 1 is less than minimum 3") + + _, err = v2.New().WithTraversalDepth(v2.AnyDepth()).Return(v2.Node()).Build() + require.ErrorContains(t, err, "recursive traversal query requires relationship query identifiers") + + _, err = 
v2.New().WithTraversalDepth(v2.AnyDepth()).Where( + v2.Relationship().Property("enabled").Equals(true), + ).Return( + v2.End(), + ).Build() + require.ErrorContains(t, err, "ranged relationship patterns only support top-level relationship kind constraints") + + _, err = v2.New().WithTraversalDepth(v2.AnyDepth()).Return(v2.Relationship()).Build() + require.ErrorContains(t, err, "ranged relationship patterns do not support relationship projections or mutations") + + _, err = v2.New().WithTraversalDepth(v2.AnyDepth()).Where( + v2.Start().ID().Equals(1), + ).Delete( + v2.Relationship(), + ).Build() + require.ErrorContains(t, err, "ranged relationship patterns do not support relationship projections or mutations") +} + +func TestMixedNodeAndRelationshipIdentifiersReturnError(t *testing.T) { + _, err := v2.New().Where( + v2.Node().ID().Equals(1), + v2.Relationship().ID().Equals(2), + ).Return( + v2.Node(), + ).Build() + require.ErrorContains(t, err, "query mixes node and relationship query identifiers") +} + +func TestRawIdentifiersMustBeKnownToScope(t *testing.T) { + cases := map[string]v2.QueryBuilder{ + "delete": v2.New().Where( + v2.Node().ID().Equals(1), + ).Delete( + v2.Variable("x"), + ), + "projection": v2.New().Return( + v2.Node(), + v2.Variable("x"), + ), + "sort": v2.New().Return( + v2.Node(), + ).OrderBy( + v2.Asc(v2.Variable("x")), + ), + } + + for name, builder := range cases { + t.Run(name, func(t *testing.T) { + _, err := builder.Build() + require.ErrorContains(t, err, `query contains unknown identifier "x"`) + }) + } +} + +func TestPathIdentifierRequiresShortestPathMatch(t *testing.T) { + _, err := v2.New().Return( + v2.Node(), + v2.Path(), + ).Build() + require.ErrorContains(t, err, `query contains unbound identifier "p"`) + + _, err = v2.New().Return( + v2.Relationship(), + v2.Path(), + ).Build() + require.ErrorContains(t, err, `query contains unbound identifier "p"`) +} + +func TestCreatedRawIdentifiersDoNotRequireMatch(t *testing.T) { + preparedQuery, 
err := v2.New().Create(&cypher.NodePattern{ + Variable: v2.Variable("created"), + }).Return( + v2.Variable("created"), + ).Build() + require.NoError(t, err) + + require.Equal(t, "create (created) return created", renderPrepared(t, preparedQuery)) +} + +func TestMultipleRelationshipKindMatchersRemainConjunctive(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Relationship().Kind().Is(graph.StringKind("A")), + v2.Relationship().Kind().Is(graph.StringKind("B")), + ).Return( + v2.Relationship(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match ()-[r]->() where r:A and r:B return r", renderPrepared(t, preparedQuery)) +} + +func TestEmptyLogicalHelpersReturnBuildErrors(t *testing.T) { + _, err := v2.New().Where(v2.And()).Return(v2.Node()).Build() + require.ErrorContains(t, err, "and requires at least one operand") + + _, err = v2.New().Where(v2.Or()).Return(v2.Node()).Build() + require.ErrorContains(t, err, "or requires at least one operand") +} + +func TestExplicitRelationshipPatternDirectionBoth(t *testing.T) { + preparedQuery, err := v2.New().Create( + v2.RelationshipPattern(graph.StringKind("Edge"), nil, graph.DirectionBoth), + ).Build() + require.NoError(t, err) + + require.Equal(t, "create (s)-[r:Edge]-(e)", renderPrepared(t, preparedQuery)) +} + +func TestInvalidExplicitRelationshipPatternDirectionReturnsError(t *testing.T) { + _, err := v2.New().Create( + v2.Relationship().RelationshipPattern(graph.StringKind("Edge"), nil, graph.Direction(99)), + ).Build() + require.ErrorContains(t, err, "unsupported relationship direction: invalid") +} + +func TestProjectionAndOrderHelpers(t *testing.T) { + preparedQuery, err := v2.New().ReturnDistinct( + v2.As(v2.Node().ID(), "node_id"), + ).OrderBy( + v2.Node().Property("name"), + v2.Desc(v2.Node().ID()), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return distinct id(n) as node_id order by n.name asc, id(n) desc", renderPrepared(t, preparedQuery)) +} + +func 
TestInvalidSortDirectionReturnsError(t *testing.T) { + _, err := v2.New().Return(v2.Node()).OrderBy( + v2.Order(v2.Node().Property("name"), v2.SortDirection(99)), + ).Build() + require.ErrorContains(t, err, "unsupported sort direction: 99") +} + +func TestPaginationZeroValuesAndNegativeValidation(t *testing.T) { + preparedQuery, err := v2.New().Return(v2.Node()).Skip(0).Limit(0).Build() + require.NoError(t, err) + require.Equal(t, "match (n) return n skip 0 limit 0", renderPrepared(t, preparedQuery)) + + _, err = v2.New().Return(v2.Node()).Skip(-1).Build() + require.ErrorContains(t, err, "skip must be non-negative: -1") + + _, err = v2.New().Return(v2.Node()).Limit(-1).Build() + require.ErrorContains(t, err, "limit must be non-negative: -1") +} + +func TestProjectionAliasDoesNotCreateMatchInference(t *testing.T) { + preparedQuery, err := v2.New().Return( + v2.As(v2.Literal(1), "one"), + ).Build() + require.NoError(t, err) + + require.Equal(t, "return 1 as one", renderPrepared(t, preparedQuery)) + require.Empty(t, preparedQuery.Parameters) +} + +func TestAliasedProjectionCreatesMatchInference(t *testing.T) { + preparedQuery, err := v2.New().Return( + v2.As(v2.Node().ID(), "node_id"), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return id(n) as node_id", renderPrepared(t, preparedQuery)) + + preparedQuery, err = v2.New().Return( + v2.As(v2.Node(), "alias"), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return n as alias", renderPrepared(t, preparedQuery)) +} + +func TestInvalidProjectionAliasReturnsBuildError(t *testing.T) { + _, err := v2.New().Return(v2.As(v2.Literal(1), "bad alias")).Build() + require.ErrorContains(t, err, `projection alias has invalid symbol "bad alias"`) + + _, err = v2.New().Return(&cypher.ProjectionItem{ + Expression: v2.Literal(1), + Alias: cypher.NewVariableWithSymbol("1bad"), + }).Build() + require.ErrorContains(t, err, `projection alias has invalid symbol "1bad"`) +} + +func 
TestUnsupportedOrderByTypeReturnsError(t *testing.T) { + _, err := v2.New().Return(v2.Node()).OrderBy(123).Build() + require.ErrorContains(t, err, "unsupported expression type: int") +} + +func TestRawProjectionAndOrderInputsAreValidated(t *testing.T) { + _, err := v2.New().Return(&cypher.Return{}).Build() + require.ErrorContains(t, err, "return clause has nil projection") + + returnClause := cypher.NewReturn() + returnClause.NewProjection(false).Items = append(returnClause.Projection.Items, &cypher.ProjectionItem{}) + _, err = v2.New().Return(returnClause).Build() + require.ErrorContains(t, err, "projection item has nil expression") + + _, err = v2.New().Return(v2.Node()).OrderBy(&cypher.SortItem{}).Build() + require.ErrorContains(t, err, "sort item has nil expression") + + _, err = v2.New().Return(v2.Node()).OrderBy(&cypher.Order{ + Items: []*cypher.SortItem{{}}, + }).Build() + require.ErrorContains(t, err, "sort item has nil expression") +} + +func TestRawProjectionAndOrderInputsAreNormalized(t *testing.T) { + returnClause := cypher.NewReturn() + returnClause.NewProjection(false).Items = append(returnClause.Projection.Items, v2.Node().ID()) + + preparedQuery, err := v2.New().Return(returnClause).OrderBy(&cypher.Order{ + Items: []*cypher.SortItem{ + v2.Desc(v2.Node().Property("name")), + }, + }).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return id(n) order by n.name desc", renderPrepared(t, preparedQuery)) +} + +func TestRawReturnInputPreservesProjectionMetadata(t *testing.T) { + returnClause := cypher.NewReturn() + projection := returnClause.NewProjection(true) + projection.Items = append(projection.Items, v2.Node().ID()) + projection.Order = &cypher.Order{ + Items: []*cypher.SortItem{ + v2.Desc(v2.Node().Property("name")), + }, + } + projection.Skip = cypher.NewSkip(5) + projection.Limit = cypher.NewLimit(10) + + preparedQuery, err := v2.New().Return(returnClause).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return 
distinct id(n) order by n.name desc skip 5 limit 10", renderPrepared(t, preparedQuery)) +} + +func TestRawReturnMetadataCreatesMatchInference(t *testing.T) { + returnClause := cypher.NewReturn() + projection := returnClause.NewProjection(false) + projection.Items = append(projection.Items, v2.Literal(1)) + projection.Order = &cypher.Order{ + Items: []*cypher.SortItem{ + v2.Desc(v2.Node().Property("name")), + }, + } + + preparedQuery, err := v2.New().Return(returnClause).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return 1 order by n.name desc", renderPrepared(t, preparedQuery)) +} + +func TestRawReturnInputMergesWithBuilderProjectionControls(t *testing.T) { + returnClause := cypher.NewReturn() + projection := returnClause.NewProjection(true) + projection.Items = append(projection.Items, v2.Node().ID()) + projection.Order = &cypher.Order{ + Items: []*cypher.SortItem{ + v2.Desc(v2.Node().Property("name")), + }, + } + projection.Skip = cypher.NewSkip(5) + projection.Limit = cypher.NewLimit(10) + + preparedQuery, err := v2.New().Return(returnClause).OrderBy( + v2.Asc(v2.Node().Property("created_at")), + ).Skip(15).Limit(20).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return distinct id(n) order by n.name desc, n.created_at asc skip 15 limit 20", renderPrepared(t, preparedQuery)) +} + +func TestRawUpdatingInputsAreValidated(t *testing.T) { + var setClause *cypher.Set + _, err := v2.New().Update(setClause).Build() + require.ErrorContains(t, err, "set clause is nil") + + _, err = v2.New().Update(&cypher.Set{Items: []*cypher.SetItem{nil}}).Build() + require.ErrorContains(t, err, "set item is nil") + + _, err = v2.New().Update(&cypher.SetItem{}).Build() + require.ErrorContains(t, err, "set item left has nil expression") + + var removeClause *cypher.Remove + _, err = v2.New().Update(removeClause).Build() + require.ErrorContains(t, err, "remove clause is nil") + + _, err = v2.New().Update(&cypher.Remove{Items: 
[]*cypher.RemoveItem{nil}}).Build() + require.ErrorContains(t, err, "remove item is nil") + + _, err = v2.New().Update(&cypher.RemoveItem{}).Build() + require.ErrorContains(t, err, "remove item has no target") + + var deleteVariable *cypher.Variable + _, err = v2.New().Delete(deleteVariable).Build() + require.ErrorContains(t, err, "delete expression has nil expression") + + var nodePattern *cypher.NodePattern + _, err = v2.New().Create(nodePattern).Build() + require.ErrorContains(t, err, "node pattern is nil") + + var relationshipPattern *cypher.RelationshipPattern + _, err = v2.New().Create(relationshipPattern).Build() + require.ErrorContains(t, err, "relationship pattern is nil") +} + +func TestDeleteRejectsNonTargetQualifiedExpressions(t *testing.T) { + cases := map[string]any{ + "property continuation": v2.Node().Property("name"), + "raw property lookup": cypher.NewPropertyLookup("n", "name"), + "id": v2.Node().ID(), + "kinds": v2.Node().Kinds(), + "kind": v2.Relationship().Kind(), + } + + for name, target := range cases { + t.Run(name, func(t *testing.T) { + _, err := v2.New().Delete(target).Build() + require.ErrorContains(t, err, "delete target must be a node, relationship, or variable") + }) + } +} + +func TestDeleteRejectsPathTargets(t *testing.T) { + _, err := v2.New().Delete(v2.Path()).Build() + require.ErrorContains(t, err, `delete target must be a node or relationship variable, got path variable "p"`) + + _, err = v2.New().Delete(v2.Variable("p")).Build() + require.ErrorContains(t, err, `delete target must be a node or relationship variable, got path variable "p"`) +} + +func TestInvalidHelperInputsReturnBuildErrors(t *testing.T) { + cases := map[string]struct { + builder v2.QueryBuilder + err string + }{ + "aliased projection": { + builder: v2.New().Return(v2.As(123, "bad")), + err: "unsupported expression type: int", + }, + "sort item": { + builder: v2.New().Return(v2.Node()).OrderBy(v2.Desc(123)), + err: "unsupported expression type: int", + }, + 
"set properties": { + builder: v2.New().Update(v2.SetProperties(123, map[string]any{"name": "bad"})), + err: "unsupported expression type: int", + }, + "delete properties": { + builder: v2.New().Update(v2.DeleteProperties(123, "name")), + err: "unsupported expression type: int", + }, + "pattern predicate": { + builder: v2.New().Where(v2.HasRelationships(123)).Return(v2.Node()), + err: "unsupported expression type: int", + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + _, err := testCase.builder.Build() + require.ErrorContains(t, err, testCase.err) + }) + } +} + +func TestNamedParameterMaterialization(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Node().Property("first").Equals("auto"), + v2.Node().Property("second").Equals(v2.NamedParameter("p0", "named")), + ).Return( + v2.Node(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) where n.first = $p1 and n.second = $p0 return n", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": "named", + "p1": "auto", + }, preparedQuery.Parameters) + + _, err = v2.New().Where( + v2.Node().Property("name").Equals(v2.NamedParameter("bad name", "value")), + ).Return( + v2.Node(), + ).Build() + require.ErrorContains(t, err, `parameter has invalid symbol "bad name"`) + + _, err = v2.New().Where( + v2.Node().Property("first").Equals(v2.NamedParameter("same", "first")), + v2.Node().Property("second").Equals(v2.NamedParameter("same", "second")), + ).Return( + v2.Node(), + ).Build() + require.ErrorContains(t, err, "parameter same is bound to multiple values") +} + +func TestQualifiedExpressionValuesUseProjectionSemantics(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Node().Property("copy").Equals(v2.Node().Property("source")), + v2.Node().Property("kinds").Equals(v2.Node().Kinds()), + ).Return( + v2.Node(), + ).Build() + require.NoError(t, err) + require.Equal(t, "match (n) where n.copy = n.source and n.kinds = labels(n) 
return n", renderPrepared(t, preparedQuery)) + + preparedQuery, err = v2.New().Where( + v2.Relationship().Property("kind").Equals(v2.Relationship().Kind()), + ).Return( + v2.Relationship(), + ).Build() + require.NoError(t, err) + require.Equal(t, "match ()-[r]->() where r.kind = type(r) return r", renderPrepared(t, preparedQuery)) +} + +func TestBuildDoesNotMutateCallerOwnedAST(t *testing.T) { + constraint := v2.Node().Property("name").Equals("alice") + constraintParameter := constraint.(*cypher.Comparison).FirstPartial().Right.(*cypher.Parameter) + + preparedQuery, err := v2.New().Where(constraint).Return(v2.Node()).Build() + require.NoError(t, err) + require.Equal(t, map[string]any{"p0": "alice"}, preparedQuery.Parameters) + require.Empty(t, constraintParameter.Symbol) + + setItem := v2.SetProperty(v2.Node().Property("status"), "active") + setParameter := setItem.Right.(*cypher.Parameter) + + preparedQuery, err = v2.New().Where(v2.Node().ID().Equals(1)).Update(setItem).Build() + require.NoError(t, err) + require.Equal(t, map[string]any{"p0": 1, "p1": "active"}, preparedQuery.Parameters) + require.Empty(t, setParameter.Symbol) + + createPattern := v2.NodePattern(graph.Kinds{graph.StringKind("User")}, v2.Parameter(map[string]any{"name": "node"})) + createParameter := createPattern.Properties.(*cypher.Parameter) + + preparedQuery, err = v2.New().Create(createPattern).Build() + require.NoError(t, err) + require.Equal(t, map[string]any{"p0": map[string]any{"name": "node"}}, preparedQuery.Parameters) + require.Empty(t, createParameter.Symbol) + + rawReturn := cypher.NewReturn() + rawReturn.NewProjection(false).AddItem(cypher.NewProjectionItemWithExpr(v2.Parameter("projected"))) + rawReturnParameter := rawReturn.Projection.Items[0].(*cypher.ProjectionItem).Expression.(*cypher.Parameter) + + preparedQuery, err = v2.New().Return(rawReturn).Build() + require.NoError(t, err) + require.Equal(t, map[string]any{"p0": "projected"}, preparedQuery.Parameters) + require.Empty(t, 
rawReturnParameter.Symbol) +} + +func TestCompatibilityHelpers(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.And( + v2.InIDs(v2.NodeID(), 1, 2), + v2.KindIn(v2.Node(), graph.StringKind("User")), + v2.CaseInsensitiveStringContains(v2.Node().Property("name"), "ADMIN"), + v2.IsNotNull(v2.Node().Property("enabled")), + ), + ).Return( + v2.CountDistinct(v2.Node()), + v2.KindsOf(v2.Node()), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) where id(n) in $p0 and n:User and toLower(n.name) contains $p1 and n.enabled is not null return count(distinct n), labels(n)", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": []graph.ID{1, 2}, + "p1": "admin", + }, preparedQuery.Parameters) +} + +func TestUpdateCompatibilityHelpers(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Node().ID().Equals(1), + ).Update( + v2.AddKind(v2.Node(), graph.StringKind("Enabled")), + v2.SetProperties(v2.Node(), map[string]any{"name": "updated"}), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) where id(n) = $p0 set n:Enabled, n.name = $p1", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": "updated", + }, preparedQuery.Parameters) +} + +func TestFluentMutationHelpersReturnConcreteMutationTypes(t *testing.T) { + kinds := graph.Kinds{graph.StringKind("Enabled")} + + var addItem *cypher.SetItem = v2.Node().Kinds().Add(kinds) + require.NotNil(t, addItem) + + var removeItem *cypher.RemoveItem = v2.Node().Kinds().Remove(kinds) + require.NotNil(t, removeItem) + + var nodeSet *cypher.Set = v2.Node().SetProperties(map[string]any{"name": "updated"}) + require.Len(t, nodeSet.Items, 1) + + var nodeRemove *cypher.Remove = v2.Node().RemoveProperties([]string{"stale"}) + require.Len(t, nodeRemove.Items, 1) + + var relationshipSet *cypher.Set = v2.Relationship().SetProperties(map[string]any{"name": "updated"}) + require.Len(t, relationshipSet.Items, 1) + + var relationshipRemove 
*cypher.Remove = v2.Relationship().RemoveProperties([]string{"stale"}) + require.Len(t, relationshipRemove.Items, 1) +} + +func TestSetPropertiesSortsKeys(t *testing.T) { + properties := map[string]any{ + "zeta": 3, + "alpha": 1, + "mid": 2, + } + + preparedQuery, err := v2.New().Update( + v2.SetProperties(v2.Node(), properties), + ).Build() + require.NoError(t, err) + require.Equal(t, "match (n) set n.alpha = $p0, n.mid = $p1, n.zeta = $p2", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": 2, + "p2": 3, + }, preparedQuery.Parameters) + + preparedQuery, err = v2.New().Update( + v2.Node().SetProperties(properties), + ).Build() + require.NoError(t, err) + require.Equal(t, "match (n) set n.alpha = $p0, n.mid = $p1, n.zeta = $p2", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": 2, + "p2": 3, + }, preparedQuery.Parameters) +} diff --git a/query/v2/util.go b/query/v2/util.go new file mode 100644 index 00000000..03b5fb10 --- /dev/null +++ b/query/v2/util.go @@ -0,0 +1,1277 @@ +package v2 + +import ( + "errors" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + "unicode" + "unicode/utf8" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/walk" + "github.com/specterops/dawgs/graph" +) + +func isNodePattern(seen *identifierSet, identifiers runtimeIdentifiers) bool { + return seen.Contains(identifiers.node) +} + +func isRelationshipPattern(seen *identifierSet, identifiers runtimeIdentifiers) bool { + var ( + hasStart = seen.Contains(identifiers.start) + hasRelationship = seen.Contains(identifiers.relationship) + hasEnd = seen.Contains(identifiers.end) + ) + + return hasStart || hasRelationship || hasEnd +} + +func runtimeIdentifierSet(identifiers runtimeIdentifiers) *identifierSet { + return newIdentifierSet( + identifiers.path, + identifiers.node, + identifiers.start, + identifiers.relationship, + identifiers.end, + ) +} + +func 
nodePatternIdentifierSet(identifiers runtimeIdentifiers) *identifierSet { + return newIdentifierSet(identifiers.node) +} + +func relationshipPatternIdentifierSet(identifiers runtimeIdentifiers, includePath bool) *identifierSet { + allowedIdentifiers := newIdentifierSet( + identifiers.start, + identifiers.relationship, + identifiers.end, + ) + + if includePath { + allowedIdentifiers.Add(identifiers.path) + } + + return allowedIdentifiers +} + +func createRelationshipMatchIdentifierSet(identifiers runtimeIdentifiers) *identifierSet { + return newIdentifierSet(identifiers.start, identifiers.end) +} + +func validateKnownIdentifiers(seen *identifierSet, identifiers runtimeIdentifiers) error { + if identifier, hasIdentifier := seen.FirstOutside(runtimeIdentifierSet(identifiers)); hasIdentifier { + return fmt.Errorf("query contains unknown identifier %q", identifier) + } + + return nil +} + +func validateBoundIdentifiers(seen, bound *identifierSet) error { + if identifier, hasIdentifier := seen.FirstOutside(bound); hasIdentifier { + return fmt.Errorf("query contains unbound identifier %q", identifier) + } + + return nil +} + +func prepareNodePattern(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers) error { + if isRelationshipPattern(seen, identifiers) { + return fmt.Errorf("query mixes node and relationship query identifiers") + } + + if err := validateBoundIdentifiers(seen, nodePatternIdentifierSet(identifiers)); err != nil { + return err + } + + match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.Node(), + }) + + return nil +} + +func validateRelationshipDirection(direction graph.Direction) error { + switch direction { + case graph.DirectionInbound, graph.DirectionOutbound, graph.DirectionBoth: + return nil + default: + return fmt.Errorf("unsupported relationship direction: %s", direction) + } +} + +func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers, 
relationshipKinds graph.Kinds, relationshipRange *cypher.PatternRange, direction graph.Direction, shortestPaths, allShortestPaths bool) error { + if shortestPaths && allShortestPaths { + return errors.New("query is requesting both all shortest paths and shortest paths") + } + + if err := validateRelationshipDirection(direction); err != nil { + return err + } + + hasRangedRelationshipPattern := relationshipRange != nil || shortestPaths || allShortestPaths + if err := validateBoundIdentifiers(seen, relationshipPatternIdentifierSet(identifiers, hasRangedRelationshipPattern)); err != nil { + return err + } + + var ( + newPatternPart = match.NewPatternPart() + startNodeSeen = seen.Contains(identifiers.start) + relationshipSeen = seen.Contains(identifiers.relationship) + endNodeSeen = seen.Contains(identifiers.end) + ) + + newPatternPart.ShortestPathPattern = shortestPaths + newPatternPart.AllShortestPathsPattern = allShortestPaths + + if startNodeSeen { + newPatternPart.AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.Start(), + }) + } else { + newPatternPart.AddPatternElements(&cypher.NodePattern{}) + } + + relationshipPattern := &cypher.RelationshipPattern{ + Kinds: relationshipKinds, + Direction: direction, + } + + if relationshipSeen { + relationshipPattern.Variable = identifiers.Relationship() + } + + if shortestPaths || allShortestPaths || (relationshipRange != nil && seen.Contains(identifiers.path)) { + newPatternPart.Variable = identifiers.Path() + } + + if relationshipRange != nil { + relationshipPattern.Range = cypher.Copy(relationshipRange) + } else if shortestPaths || allShortestPaths { + relationshipPattern.Range = &cypher.PatternRange{} + } + + newPatternPart.AddPatternElements(relationshipPattern) + + if endNodeSeen { + newPatternPart.AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.End(), + }) + } else { + newPatternPart.AddPatternElements(&cypher.NodePattern{}) + } + + return nil +} + +func 
prepareCreateRelationshipMatch(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers) error { + if err := validateBoundIdentifiers(seen, createRelationshipMatchIdentifierSet(identifiers)); err != nil { + return err + } + + if seen.Contains(identifiers.start) { + match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.Start(), + }) + } + + if seen.Contains(identifiers.end) { + match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.End(), + }) + } + + return nil +} + +func isDetachDeleteQualifier(qualifier cypher.Expression, identifiers runtimeIdentifiers) bool { + variable, typeOK := qualifier.(*cypher.Variable) + if !typeOK || variable == nil { + return false + } + + switch variable.Symbol { + case identifiers.node, identifiers.start, identifiers.end: + return true + default: + return false + } +} + +func kindProjectionExpression(role string, identifier *cypher.Variable) (cypher.Expression, error) { + switch role { + case Identifiers.node, Identifiers.start, Identifiers.end: + return cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, identifier), nil + + case Identifiers.relationship: + return cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, identifier), nil + + default: + return nil, fmt.Errorf("invalid kind projection reference: %s", identifier.Symbol) + } +} + +func invalidExpression(err error) *cypher.FunctionInvocation { + return cypher.WithErrors(cypher.NewSimpleFunctionInvocation("__invalid_expression__"), err) +} + +func isNilPointer(value any) bool { + if value == nil { + return true + } + + reflectValue := reflect.ValueOf(value) + return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() +} + +func expressionOrError(value any) cypher.Expression { + if expression, err := projectionExpression(value); err != nil { + return invalidExpression(err) + } else { + return expression + } +} + +func variableReference(value any) (*cypher.Variable, error) { + 
expression, err := projectionExpression(value) + if err != nil { + return nil, err + } + + if variable, typeOK := expression.(*cypher.Variable); !typeOK { + return nil, fmt.Errorf("expected variable reference, got %T", expression) + } else { + return variable, nil + } +} + +func propertyLookupOrError(reference any, propertyName string) cypher.Expression { + if variable, err := variableReference(reference); err != nil { + return invalidExpression(err) + } else { + return cypher.NewPropertyLookup(variable.Symbol, propertyName) + } +} + +func sortedPropertyKeys(properties map[string]any) []string { + keys := make([]string, 0, len(properties)) + + for key := range properties { + keys = append(keys, key) + } + + sort.Strings(keys) + return keys +} + +func isCypherSymbolStart(char rune) bool { + return char == '_' || unicode.IsLetter(char) || unicode.In(char, unicode.Nl, unicode.Pc) +} + +func isCypherSymbolPart(char rune) bool { + return isCypherSymbolStart(char) || unicode.IsDigit(char) || unicode.In(char, unicode.Mark, unicode.Sc) +} + +func validateCypherSymbol(symbol, context string) error { + if strings.TrimSpace(symbol) == "" { + return fmt.Errorf("%s is empty", context) + } + + if !utf8.ValidString(symbol) { + return fmt.Errorf("%s has invalid symbol %q", context, symbol) + } + + for idx, char := range symbol { + if idx == 0 { + if !isCypherSymbolStart(char) { + return fmt.Errorf("%s has invalid symbol %q", context, symbol) + } + } else if !isCypherSymbolPart(char) { + return fmt.Errorf("%s has invalid symbol %q", context, symbol) + } + } + + return nil +} + +func projectionExpression(value any) (cypher.Expression, error) { + if isNilPointer(value) { + return nil, fmt.Errorf("expression is nil: %T", value) + } + + switch typedValue := value.(type) { + case kindContinuation: + return kindProjectionExpression(typedValue.role, typedValue.identifier) + + case kindsContinuation: + return kindProjectionExpression(typedValue.role, typedValue.identifier) + + case 
QualifiedExpression: + return typedValue.qualifier(), nil + + case *cypher.ProjectionItem: + if typedValue.Expression == nil { + return nil, fmt.Errorf("projection item has nil expression") + } + + return typedValue.Expression, nil + + case *cypher.Parameter: + return typedValue, nil + + case *cypher.Literal: + return typedValue, nil + + case *cypher.Variable: + return typedValue, nil + + case *cypher.PropertyLookup: + return typedValue, nil + + case *cypher.FunctionInvocation: + return typedValue, nil + + case *cypher.Parenthetical: + return typedValue, nil + + case *cypher.Comparison: + return typedValue, nil + + case *cypher.Negation: + return typedValue, nil + + case *cypher.Conjunction: + return typedValue, nil + + case *cypher.Disjunction: + return typedValue, nil + + case *cypher.ExclusiveDisjunction: + return typedValue, nil + + case *cypher.KindMatcher: + return typedValue, nil + + case *cypher.ListLiteral: + return typedValue, nil + + case cypher.MapLiteral: + return typedValue, nil + + case *cypher.PatternPredicate: + return typedValue, nil + + case *cypher.ArithmeticExpression: + return typedValue, nil + + case *cypher.UnaryAddOrSubtractExpression: + return typedValue, nil + + case *cypher.FilterExpression: + return typedValue, nil + + case *cypher.IDInCollection: + return typedValue, nil + + default: + return nil, fmt.Errorf("unsupported expression type: %T", value) + } +} + +func copyExpression(expression cypher.Expression) cypher.Expression { + return cypher.Copy(expression) +} + +func copyProjectionItem(item *cypher.ProjectionItem) *cypher.ProjectionItem { + return cypher.Copy(item) +} + +func copySortItem(item *cypher.SortItem) *cypher.SortItem { + return cypher.Copy(item) +} + +func copySetItem(item *cypher.SetItem) *cypher.SetItem { + return cypher.Copy(item) +} + +func copyRemoveItem(item *cypher.RemoveItem) *cypher.RemoveItem { + return cypher.Copy(item) +} + +func copySkip(skip *cypher.Skip) *cypher.Skip { + return cypher.Copy(skip) +} + +func 
copyLimit(limit *cypher.Limit) *cypher.Limit { + return cypher.Copy(limit) +} + +func validateExpressionValue(expression cypher.Expression, context string) error { + if isNilPointer(expression) { + return fmt.Errorf("%s has nil expression", context) + } + + return collectModelErrors(expression) +} + +func validateAssignmentOperator(operator cypher.AssignmentOperator) error { + switch operator { + case cypher.OperatorAssignment, cypher.OperatorAdditionAssignment, cypher.OperatorLabelAssignment: + return nil + default: + return fmt.Errorf("unsupported set item operator: %s", operator) + } +} + +func setItemFromValue(setItem *cypher.SetItem) (*cypher.SetItem, error) { + if setItem == nil { + return nil, fmt.Errorf("set item is nil") + } + + if err := validateExpressionValue(setItem.Left, "set item left"); err != nil { + return nil, err + } + + if err := validateAssignmentOperator(setItem.Operator); err != nil { + return nil, err + } + + if err := validateExpressionValue(setItem.Right, "set item right"); err != nil { + return nil, err + } + + if err := collectModelErrors(setItem); err != nil { + return nil, err + } + + return copySetItem(setItem), nil +} + +func setItemsFromSet(setClause *cypher.Set) ([]*cypher.SetItem, error) { + if setClause == nil { + return nil, fmt.Errorf("set clause is nil") + } + + setItems := make([]*cypher.SetItem, 0, len(setClause.Items)) + + for _, setItem := range setClause.Items { + if normalizedSetItem, err := setItemFromValue(setItem); err != nil { + return nil, err + } else { + setItems = append(setItems, normalizedSetItem) + } + } + + return setItems, nil +} + +func removeItemFromValue(removeItem *cypher.RemoveItem) (*cypher.RemoveItem, error) { + if removeItem == nil { + return nil, fmt.Errorf("remove item is nil") + } + + hasKindMatcher := removeItem.KindMatcher != nil + hasProperty := !isNilPointer(removeItem.Property) + + switch { + case hasKindMatcher && hasProperty: + return nil, fmt.Errorf("remove item has multiple targets") + + 
case hasKindMatcher: + if err := collectModelErrors(removeItem.KindMatcher); err != nil { + return nil, err + } + + case hasProperty: + if err := validateExpressionValue(removeItem.Property, "remove item property"); err != nil { + return nil, err + } + + default: + return nil, fmt.Errorf("remove item has no target") + } + + if err := collectModelErrors(removeItem); err != nil { + return nil, err + } + + return copyRemoveItem(removeItem), nil +} + +func removeItemsFromRemove(removeClause *cypher.Remove) ([]*cypher.RemoveItem, error) { + if removeClause == nil { + return nil, fmt.Errorf("remove clause is nil") + } + + removeItems := make([]*cypher.RemoveItem, 0, len(removeClause.Items)) + + for _, removeItem := range removeClause.Items { + if normalizedRemoveItem, err := removeItemFromValue(removeItem); err != nil { + return nil, err + } else { + removeItems = append(removeItems, normalizedRemoveItem) + } + } + + return removeItems, nil +} + +func validateNodePattern(nodePattern *cypher.NodePattern) error { + if nodePattern == nil { + return fmt.Errorf("node pattern is nil") + } + + return collectModelErrors(nodePattern) +} + +func validateRelationshipPattern(relationshipPattern *cypher.RelationshipPattern) error { + if relationshipPattern == nil { + return fmt.Errorf("relationship pattern is nil") + } + + if err := validateRelationshipDirection(relationshipPattern.Direction); err != nil { + return err + } + + return collectModelErrors(relationshipPattern) +} + +func projectionItemFromValue(value any) (*cypher.ProjectionItem, error) { + if projectionItem, typeOK := value.(*cypher.ProjectionItem); typeOK { + if projectionItem == nil { + return nil, fmt.Errorf("projection item is nil") + } + + if err := validateExpressionValue(projectionItem.Expression, "projection item"); err != nil { + return nil, err + } + + if projectionItem.Alias != nil { + if err := validateCypherSymbol(projectionItem.Alias.Symbol, "projection alias"); err != nil { + return nil, err + } + } + + 
if err := collectModelErrors(projectionItem); err != nil { + return nil, err + } + + return copyProjectionItem(projectionItem), nil + } + + if expression, err := projectionExpression(value); err != nil { + return nil, err + } else { + return cypher.NewProjectionItemWithExpr(copyExpression(expression)), nil + } +} + +func sortItemFromValue(value any) (*cypher.SortItem, error) { + if sortItem, typeOK := value.(*cypher.SortItem); typeOK { + if sortItem == nil { + return nil, fmt.Errorf("sort item is nil") + } + + if err := validateExpressionValue(sortItem.Expression, "sort item"); err != nil { + return nil, err + } + + if err := collectModelErrors(sortItem); err != nil { + return nil, err + } + + return copySortItem(sortItem), nil + } + + if expression, err := projectionExpression(value); err != nil { + return nil, err + } else { + return &cypher.SortItem{ + Ascending: true, + Expression: copyExpression(expression), + }, nil + } +} + +func projectionItemsFromReturn(returnClause *cypher.Return) ([]*cypher.ProjectionItem, error) { + if returnClause == nil { + return nil, fmt.Errorf("return clause is nil") + } + + if returnClause.Projection == nil { + return nil, fmt.Errorf("return clause has nil projection") + } + + if err := validateProjectionMetadata(returnClause.Projection); err != nil { + return nil, err + } + + projectionItems := make([]*cypher.ProjectionItem, 0, len(returnClause.Projection.Items)) + + for _, returnItem := range returnClause.Projection.Items { + if projectionItem, err := projectionItemFromValue(returnItem); err != nil { + return nil, err + } else { + projectionItems = append(projectionItems, projectionItem) + } + } + + return projectionItems, nil +} + +func validateProjectionMetadata(projection *cypher.Projection) error { + if projection.Order != nil { + if _, err := sortItemsFromOrder(projection.Order); err != nil { + return err + } + } + + if projection.Skip != nil { + if err := validateExpressionValue(projection.Skip.Value, "projection skip"); 
err != nil { + return err + } + } + + if projection.Limit != nil { + if err := validateExpressionValue(projection.Limit.Value, "projection limit"); err != nil { + return err + } + } + + return nil +} + +func sortItemsFromOrder(order *cypher.Order) ([]*cypher.SortItem, error) { + if order == nil { + return nil, fmt.Errorf("order is nil") + } + + sortItems := make([]*cypher.SortItem, 0, len(order.Items)) + + for _, sortItem := range order.Items { + if normalizedSortItem, err := sortItemFromValue(sortItem); err != nil { + return nil, err + } else { + sortItems = append(sortItems, normalizedSortItem) + } + } + + return sortItems, nil +} + +type identifierSet struct { + identifiers map[string]struct{} +} + +func newIdentifierSet(identifiers ...string) *identifierSet { + set := &identifierSet{ + identifiers: map[string]struct{}{}, + } + + for _, identifier := range identifiers { + set.Add(identifier) + } + + return set +} + +func (s *identifierSet) Add(identifier string) { + s.identifiers[identifier] = struct{}{} +} + +func (s *identifierSet) Len() int { + return len(s.identifiers) +} + +func (s *identifierSet) Clone() *identifierSet { + clone := newIdentifierSet() + clone.Or(s) + return clone +} + +func (s *identifierSet) Or(other *identifierSet) { + if s == nil || other == nil { + return + } + + for otherIdentifier := range other.identifiers { + s.identifiers[otherIdentifier] = struct{}{} + } +} + +func (s *identifierSet) Remove(other *identifierSet) { + if s == nil || other == nil { + return + } + + for otherIdentifier := range other.identifiers { + delete(s.identifiers, otherIdentifier) + } +} + +func (s *identifierSet) Contains(identifier string) bool { + _, containsIdentifier := s.identifiers[identifier] + return containsIdentifier +} + +func (s *identifierSet) FirstOutside(allowed *identifierSet) (string, bool) { + var identifiers []string + + for identifier := range s.identifiers { + if !allowed.Contains(identifier) { + identifiers = append(identifiers, 
identifier) + } + } + + sort.Strings(identifiers) + + if len(identifiers) == 0 { + return "", false + } + + return identifiers[0], true +} + +func (s *identifierSet) CollectFromExpression(expr cypher.Expression) error { + if exprIdentifiers, err := extractCypherIdentifiers(expr); err != nil { + return err + } else { + s.Or(exprIdentifiers) + return nil + } +} + +func (s *identifierSet) CollectFromValue(value any) error { + switch typedValue := value.(type) { + case nil: + return nil + + case QualifiedExpression: + return s.CollectFromExpression(typedValue.qualifier()) + + case *cypher.ProjectionItem: + if projectionItem, err := projectionItemFromValue(typedValue); err != nil { + return err + } else { + return s.CollectFromExpression(projectionItem) + } + + case *cypher.Return: + if projectionItems, err := projectionItemsFromReturn(typedValue); err != nil { + return err + } else { + for _, projectionItem := range projectionItems { + if err := s.CollectFromExpression(projectionItem); err != nil { + return err + } + } + } + + if err := s.CollectFromProjectionMetadata(typedValue.Projection); err != nil { + return err + } + + case *cypher.Order: + if sortItems, err := sortItemsFromOrder(typedValue); err != nil { + return err + } else { + for _, sortItem := range sortItems { + if err := s.CollectFromExpression(sortItem); err != nil { + return err + } + } + } + + case *cypher.SortItem: + if sortItem, err := sortItemFromValue(typedValue); err != nil { + return err + } else { + return s.CollectFromExpression(sortItem) + } + + case *cypher.Set: + if setItems, err := setItemsFromSet(typedValue); err != nil { + return err + } else { + for _, setItem := range setItems { + if err := s.CollectFromExpression(setItem); err != nil { + return err + } + } + } + + case *cypher.SetItem: + if setItem, err := setItemFromValue(typedValue); err != nil { + return err + } else { + return s.CollectFromExpression(setItem) + } + + case *cypher.Remove: + if removeItems, err := 
removeItemsFromRemove(typedValue); err != nil { + return err + } else { + for _, removeItem := range removeItems { + if err := s.CollectFromValue(removeItem); err != nil { + return err + } + } + } + + case *cypher.RemoveItem: + if removeItem, err := removeItemFromValue(typedValue); err != nil { + return err + } else if removeItem.KindMatcher != nil { + return s.CollectFromExpression(removeItem.KindMatcher) + } else { + return s.CollectFromExpression(removeItem) + } + + case *cypher.NodePattern: + if err := validateNodePattern(typedValue); err != nil { + return err + } + + return s.CollectFromExpression(typedValue) + + case *cypher.RelationshipPattern: + if err := validateRelationshipPattern(typedValue); err != nil { + return err + } + + return s.CollectFromExpression(typedValue) + + case *cypher.Variable: + return s.CollectFromExpression(typedValue) + + case *cypher.FunctionInvocation: + return s.CollectFromExpression(typedValue) + + case *cypher.PropertyLookup: + return s.CollectFromExpression(typedValue) + + case []any: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + case []cypher.SyntaxNode: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + case []cypher.Expression: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + case []*cypher.SetItem: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + case []*cypher.RemoveItem: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + default: + return nil + } + + return nil +} + +func (s *identifierSet) CollectFromProjectionMetadata(projection *cypher.Projection) error { + if projection == nil { + return nil + } + + if projection.Order != nil { + if err := s.CollectFromValue(projection.Order); err != nil { + return err + } + } 
+ + if projection.Skip != nil { + if err := s.CollectFromExpression(projection.Skip.Value); err != nil { + return err + } + } + + if projection.Limit != nil { + if err := s.CollectFromExpression(projection.Limit.Value); err != nil { + return err + } + } + + return nil +} + +func collectIdentifiersFromValues(values ...any) (*identifierSet, error) { + identifiers := newIdentifierSet() + + for _, value := range values { + if err := identifiers.CollectFromValue(value); err != nil { + return nil, err + } + } + + return identifiers, nil +} + +type createScope struct { + identifiers *identifierSet + createsRelationship bool +} + +func collectCreateScope(identifiers runtimeIdentifiers, values ...any) (*createScope, error) { + scope := &createScope{ + identifiers: newIdentifierSet(), + } + + for _, value := range values { + switch typedValue := value.(type) { + case *cypher.RelationshipPattern: + if err := validateRelationshipPattern(typedValue); err != nil { + return nil, err + } + + scope.createsRelationship = true + scope.identifiers.Add(identifiers.start) + scope.identifiers.Add(identifiers.end) + + if typedValue.Variable != nil { + scope.identifiers.Add(typedValue.Variable.Symbol) + } + + default: + if err := scope.identifiers.CollectFromValue(value); err != nil { + return nil, err + } + } + } + + return scope, nil +} + +type identifierExtractor struct { + walk.Visitor[cypher.SyntaxNode] + + seen *identifierSet +} + +func newIdentifierExtractor() *identifierExtractor { + return &identifierExtractor{ + Visitor: walk.NewVisitor[cypher.SyntaxNode](), + seen: newIdentifierSet(), + } +} + +func (s *identifierExtractor) Enter(node cypher.SyntaxNode) { + switch typedNode := node.(type) { + case *cypher.Variable: + s.seen.Add(typedNode.Symbol) + + case *cypher.NodePattern: + if typedNode.Variable != nil { + s.seen.Add(typedNode.Variable.Symbol) + } + + case *cypher.RelationshipPattern: + if typedNode.Variable != nil { + s.seen.Add(typedNode.Variable.Symbol) + } + + case 
*cypher.PatternPart: + if typedNode.Variable != nil { + s.seen.Add(typedNode.Variable.Symbol) + } + } +} + +func extractCypherIdentifiers(expression cypher.Expression) (*identifierSet, error) { + var ( + identifierExtractorVisitor = newIdentifierExtractor() + err = walk.Cypher(expression, identifierExtractorVisitor) + ) + + return identifierExtractorVisitor.seen, err +} + +func collectModelErrors(node cypher.SyntaxNode) error { + var modelErrors []error + + if err := walk.Cypher(node, walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, _ walk.VisitorHandler) { + if errorNode, typeOK := node.(cypher.Fallible); typeOK { + modelErrors = append(modelErrors, errorNode.Errors()...) + } + })); err != nil { + modelErrors = append(modelErrors, err) + } + + return errors.Join(modelErrors...) +} + +func collectModelErrorsFromKnownValues(values ...any) error { + var modelErrors []error + + for _, value := range values { + switch typedValue := value.(type) { + case nil: + continue + + case []any: + if err := collectModelErrorsFromKnownValues(typedValue...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []cypher.SyntaxNode: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []cypher.Expression: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []*cypher.SetItem: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []*cypher.RemoveItem: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []*cypher.ProjectionItem: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []*cypher.SortItem: + if err := 
collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.NodePattern: + if err := validateNodePattern(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.Order: + if _, err := sortItemsFromOrder(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.ProjectionItem: + if _, err := projectionItemFromValue(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.Return: + if _, err := projectionItemsFromReturn(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.RelationshipPattern: + if err := validateRelationshipPattern(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.Remove: + if _, err := removeItemsFromRemove(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.RemoveItem: + if _, err := removeItemFromValue(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.Set: + if _, err := setItemsFromSet(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.SetItem: + if _, err := setItemFromValue(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.SortItem: + if _, err 
// anySlice copies values into a new []any of the same length, boxing each
// element. The result is never nil, even for a nil or empty input.
func anySlice[T any](values []T) []any {
	boxed := make([]any, 0, len(values))

	for _, value := range values {
		boxed = append(boxed, value)
	}

	return boxed
}
parameter.Value) { + s.SetErrorf("parameter %s is bound to multiple values", parameter.Symbol) + return + } + + s.parameters[parameter.Symbol] = parameter.Value +} + +type namedParameterCollector struct { + walk.Visitor[cypher.SyntaxNode] + + parameters map[string]any +} + +func newNamedParameterCollector() *namedParameterCollector { + return &namedParameterCollector{ + Visitor: walk.NewVisitor[cypher.SyntaxNode](), + parameters: map[string]any{}, + } +} + +func (s *namedParameterCollector) Enter(node cypher.SyntaxNode) { + parameter, typeOK := node.(*cypher.Parameter) + if !typeOK || parameter.Symbol == "" { + return + } + + if err := validateCypherSymbol(parameter.Symbol, "parameter"); err != nil { + s.SetError(err) + return + } + + if existingValue, exists := s.parameters[parameter.Symbol]; exists && !reflect.DeepEqual(existingValue, parameter.Value) { + s.SetErrorf("parameter %s is bound to multiple values", parameter.Symbol) + return + } + + s.parameters[parameter.Symbol] = parameter.Value +} + +func collectNamedParameters(query *cypher.RegularQuery) (map[string]any, error) { + collector := newNamedParameterCollector() + + if err := walk.Cypher(query, collector); err != nil { + return nil, err + } + + return collector.parameters, nil +} + +func materializeParameters(query *cypher.RegularQuery) (map[string]any, error) { + namedParameters, err := collectNamedParameters(query) + if err != nil { + return nil, err + } + + materializer := newParameterMaterializer(namedParameters) + + if err := walk.Cypher(query, materializer); err != nil { + return nil, err + } + + return materializer.parameters, nil +} diff --git a/query/v2/util_internal_test.go b/query/v2/util_internal_test.go new file mode 100644 index 00000000..1ae62236 --- /dev/null +++ b/query/v2/util_internal_test.go @@ -0,0 +1,20 @@ +package v2 + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIdentifierSetOrAndRemoveNilSafe(t *testing.T) { + var nilSet *identifierSet + + 
nilSet.Or(newIdentifierSet("ignored")) + nilSet.Remove(newIdentifierSet("ignored")) + + set := newIdentifierSet("kept") + set.Or(nil) + set.Remove(nil) + + require.True(t, set.Contains("kept")) +}