diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..5405dbe --- /dev/null +++ b/.editorconfig @@ -0,0 +1,21 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true +indent_style = space +indent_size = 2 + +[*.java] +indent_size = 4 + +[*.go] +indent_style = tab + +[Makefile] +indent_style = tab + +[*.md] +trim_trailing_whitespace = false diff --git a/.gitattributes b/.gitattributes index 4c1ba34..68b490b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,13 +33,11 @@ CODEOWNERS text eol=lf mvnw text eol=lf mvnw.cmd text eol=crlf -# --- Model artifacts (Git LFS) --- -*.onnx filter=lfs diff=lfs merge=lfs -text -*.gguf filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text +# --- Model artifacts --- +# Model weights are NOT committed to source — fetched at build time by +# scripts/fetch_models.py and embedded in the published Maven JAR. See +# docs/ARCHITECTURE.md §3.1 (model JAR layout) and CONTRIBUTING.md +# (Air-gapped contributors). No Git LFS. # --- Native binary libraries (not LFS — small and version-coupled to source) --- *.dll binary @@ -47,6 +45,9 @@ mvnw.cmd text eol=crlf *.dylib binary *.lib binary *.a binary +*.bin binary +*.onnx binary +*.gguf binary # --- Images / fonts --- *.png binary diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..0c312ec --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,11 @@ +# inference-sdk — code ownership +# https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners +# +# Default owner — every file requires aksOps review unless overridden below. +* @aksOps + +# Forked native binding (Tier 0) — high-blast-radius, security-sensitive +native/kherud-fork/ @aksOps + +# CI workflows + GitHub config +.github/ @aksOps diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..3a115c1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,51 @@ +--- +name: Bug report +about: Report a defect in inference-sdk (Java) +title: "bug: " +labels: ["bug", "triage"] +assignees: [] +--- + + + +## Summary + + + +## Reproducible example + + + +```java +// Embedder / Generator setup, inputs, expected vs actual +``` + +## Expected behavior + + + +## Actual behavior + + + +``` + +``` + +## Environment + +- inference-sdk version: +- JDK version (`java -version`): +- OS / kernel / arch (`uname -a`): +- glibc version (`ldd --version`): +- Maven version (`./mvnw -version`): +- Container / VM (yes/no, image): +- Model artifact id and SHA-256: + +## Additional context + + diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..e3c667c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,13 @@ +blank_issues_enabled: false +contact_links: + - name: Security vulnerability + url: https://github.com/RandomCodeSpace/inference-sdk/security/advisories/new + about: >- + Report security issues PRIVATELY via GitHub Security Advisories. + Do NOT open a public issue. See SECURITY.md for the full policy. + - name: Discussions / questions + url: https://github.com/RandomCodeSpace/inference-sdk/discussions + about: >- + Open-ended questions, integration help, design discussion. For + reproducible bugs and concrete feature requests, please use the + issue templates instead. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..9b89c1a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,42 @@ +--- +name: Feature request +about: Suggest a new capability for inference-sdk +title: "feat: " +labels: ["enhancement", "triage"] +assignees: [] +--- + +## Problem + + + +## Proposed solution + + + +```java +// Builder change, new method, new record, new exception, etc. +``` + +## Alternatives considered + + + +## Compatibility + +- [ ] This change is fully backward compatible +- [ ] This change requires a major version bump (breaks public API) +- [ ] This change affects the wire format (`docs/WIRE_FORMAT.md`) +- [ ] This change affects the model registry (`docs/MODEL_REGISTRY.md`) + +## Phase fit + +- [ ] Phase 1 (library only) +- [ ] Phase 1.5 (additional native targets / models) +- [ ] Phase 2 (HTTP layer / OpenAI-compatible) +- [ ] Future / unscheduled + +## Additional context + + diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..1af73b0 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,45 @@ + + +## Summary + + + +## Design / spec reference + + + +## Type of change + +- [ ] feat — new user-facing capability +- [ ] fix — bug fix +- [ ] refactor — internal restructuring, no behavior change +- [ ] chore — build / CI / tooling +- [ ] docs — documentation only +- [ ] test — test-only change + +## Verification checklist + +- [ ] `./mvnw -B verify` passes locally +- [ ] `./mvnw dependency-check:check` clean (no CVSS >= 7) +- [ ] `./mvnw spotless:check spotbugs:check` clean +- [ ] New code has unit tests; integration tests where boundaries are crossed +- [ ] JaCoCo line >= 75% / branch >= 70% on touched modules +- [ ] JavaDoc on every public type / method introduced +- [ ] No runtime network calls added (offline guarantee preserved) +- [ ] No new `*.md` agent-generated artifacts committed (planning, scratch, etc.) +- [ ] No force-pushes to shared branches; no direct commits to `main` + +## Forward-compat (Phase 1 -> Phase 2) + +- [ ] Reserved fields (`Message.toolCalls/toolCallId/name`, + `GenerateRequest.tools/toolChoice/responseFormat`) still throw + `FeatureNotSupportedException` when non-null +- [ ] Wire format unchanged or `docs/WIRE_FORMAT.md` updated in lockstep + +## Notes for reviewer + + diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..caf4414 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,74 @@ +# Dependabot — keep CI actions, Maven deps, and Python script deps in sync. +# Spec §0 + §10. +# +# Ecosystems supported: github-actions, maven, pip. +# Standard Maven dependabot at `/java` will surface `de.kherud:llama` version +# bumps naturally — no separate watcher needed. +version: 2 + +updates: + # ----- GitHub Actions ----- + - package-ecosystem: github-actions + directory: "/" + schedule: + interval: daily + open-pull-requests-limit: 5 + labels: + - dependencies + - github-actions + commit-message: + prefix: "chore(actions)" + include: scope + + # ----- Maven (Java SDK) ----- + - package-ecosystem: maven + directory: "/java" + schedule: + interval: weekly + day: monday + open-pull-requests-limit: 10 + labels: + - dependencies + - java + commit-message: + prefix: "chore(deps)" + include: scope + # Group safe semver bumps to reduce PR noise. + groups: + java-test-deps: + patterns: + - "org.junit*" + - "org.assertj*" + - "org.mockito*" + - "org.awaitility*" + - "nl.jqno.equalsverifier*" + - "net.jqwik*" + - "ch.qos.logback*" + update-types: + - patch + - minor + maven-plugins: + patterns: + - "org.apache.maven.plugins*" + - "com.diffplug.spotless*" + - "com.github.spotbugs*" + - "org.jacoco*" + - "org.codehaus.mojo*" + - "org.owasp*" + update-types: + - patch + - minor + + # ----- Python (scripts) ----- + - package-ecosystem: pip + directory: "/scripts" + schedule: + interval: weekly + day: monday + open-pull-requests-limit: 5 + labels: + - dependencies + - python + commit-message: + prefix: "chore(scripts)" + include: scope diff --git a/.github/workflows/java-ci.yml b/.github/workflows/java-ci.yml new file mode 100644 index 0000000..2176433 --- /dev/null +++ b/.github/workflows/java-ci.yml @@ -0,0 +1,149 @@ +# Java CI — inference-sdk Phase 1 +# +# Target wall time: under 5 minutes for default tiny models (spec §12). +# Triggers on java/, scripts/, and root POM changes. +name: java-ci + +on: + push: + branches: [main] + paths: + - "java/**" + - "scripts/**" + - "pom.xml" + - "**/*.xml" + - ".github/workflows/java-ci.yml" + pull_request: + paths: + - "java/**" + - "scripts/**" + - "pom.xml" + - "**/*.xml" + - ".github/workflows/java-ci.yml" + +concurrency: + group: java-ci-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + packages: read + +jobs: + verify: + name: verify (${{ matrix.runner }}) + strategy: + fail-fast: false + matrix: + include: + - runner: ubuntu-latest + arch: amd64 + - runner: ubuntu-22.04-arm + arch: arm64 + runs-on: ${{ matrix.runner }} + timeout-minutes: 15 + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Set up JDK 25 (Temurin) + uses: actions/setup-java@v5 + with: + distribution: temurin + java-version: "25" + cache: maven + + # PR verify is FAST: skip the heavy fetch+convert+quantize pipeline + # (Qwen GGUF conversion needs torch + numpy + sentencepiece + transformers, + # llama.cpp clone + cmake + gcc — minutes-to-hours of work that does NOT + # validate SDK code). The full build, including model fetch and the + # @Tag("model") integration tests, runs in `package-artifacts.yml` + # (manual + scheduled). PR verify exercises: + # - all 165 unit tests + 29 non-model IT tests across the reactor + # - spotless / spotbugs / OWASP / javadoc / network-isolation gates + # - module wiring for the model-bundle JARs (POM-only, no weights) + - name: Maven verify (skip model fetch) + run: ./mvnw -f java/pom.xml -B -ntp -e -Dfetch.models.skip=true verify + + - name: OWASP dependency-check + run: >- + ./mvnw -f java/pom.xml -B -ntp + -pl inference-sdk-parent,inference-sdk-core,inference-sdk-embed + org.owasp:dependency-check-maven:12.2.2:check + continue-on-error: false + + - name: Spotless check + run: >- + ./mvnw -f java/pom.xml -B -ntp + -pl inference-sdk-parent,inference-sdk-core,inference-sdk-embed + com.diffplug.spotless:spotless-maven-plugin:3.4.0:check + + - name: SpotBugs check + run: >- + ./mvnw -f java/pom.xml -B -ntp + -pl inference-sdk-parent,inference-sdk-core,inference-sdk-embed + com.github.spotbugs:spotbugs-maven-plugin:4.9.8.3:check + + - name: Upload test reports + if: always() + uses: actions/upload-artifact@v4 + with: + name: test-reports-${{ matrix.arch }} + path: | + **/target/surefire-reports/** + **/target/failsafe-reports/** + retention-days: 7 + + network-isolation: + name: network-isolation (linux/amd64) + runs-on: ubuntu-latest + timeout-minutes: 10 + needs: verify + steps: + - uses: actions/checkout@v5 + + - uses: actions/setup-java@v5 + with: + distribution: temurin + java-version: "25" + cache: maven + + # Resolve all deps BEFORE we cut egress so the offline run has + # everything it needs locally. Model fetch is skipped here (same + # rationale as the verify job — full fetch runs in package-artifacts.yml). + - name: Resolve Maven dependencies (online) + run: ./mvnw -f java/pom.xml -B -ntp dependency:go-offline + + - name: Block egress and run verify offline + # Drop OUTPUT egress except loopback. Any runtime network call from the + # JVM under test will fail, exercising spec §11.2 case 47. + run: | + set -euo pipefail + sudo iptables -I OUTPUT -o lo -j ACCEPT + sudo iptables -A OUTPUT -m owner --uid-owner $(id -u) -j REJECT + ./mvnw -f java/pom.xml -B -ntp -o -Dfetch.models.skip=true verify -Pnetwork-isolation + sudo iptables -D OUTPUT -m owner --uid-owner $(id -u) -j REJECT || true + + javadoc: + name: javadoc (linux/amd64) + runs-on: ubuntu-latest + timeout-minutes: 10 + needs: verify + steps: + - uses: actions/checkout@v5 + + - uses: actions/setup-java@v5 + with: + distribution: temurin + java-version: "25" + cache: maven + + - name: Aggregate JavaDoc + run: ./mvnw -f java/pom.xml -B -ntp javadoc:aggregate + + - name: Upload JavaDoc artifact + uses: actions/upload-artifact@v4 + with: + name: javadoc + path: java/target/site/apidocs/ + retention-days: 14 diff --git a/.github/workflows/package-artifacts.yml b/.github/workflows/package-artifacts.yml new file mode 100644 index 0000000..93be695 --- /dev/null +++ b/.github/workflows/package-artifacts.yml @@ -0,0 +1,130 @@ +# Package artifacts — full model fetch + convert + quantize + IT +# +# This is the SLOW path that PR verify deliberately skips. It exercises +# the hybrid model-distribution pipeline end-to-end: +# - scripts/fetch_models.py downloads safetensors from HuggingFace +# - clones llama.cpp at the pinned LLAMA_CPP_TAG +# - converts HF -> GGUF f16 (needs torch + numpy + sentencepiece + transformers) +# - quantizes to q4_K_M (needs gcc + cmake + llama-quantize) +# - SHA-256 verifies against scripts/checksums/models.sha256 +# - runs the @Tag("model") integration tests with real model weights +# +# Triggers: +# - Manual (workflow_dispatch) — for ad-hoc artifact builds before release +# - Nightly (schedule) — catches upstream model/llama.cpp drift +# - On a release tag — produces the publishable artifact +# +# Per spec §12 the full pipeline target is "minutes" not "PR-verify-fast". +# This workflow has a 60-minute timeout; PR verify (java-ci.yml) stays +# under 5 minutes by skipping the fetch via -Dfetch.models.skip=true. +name: package-artifacts + +on: + workflow_dispatch: + schedule: + # 02:30 UTC daily — well clear of GHA peak load. + - cron: "30 2 * * *" + push: + tags: + - "v*" + +concurrency: + group: package-artifacts-${{ github.ref }} + cancel-in-progress: false + +permissions: + contents: read + packages: read + +jobs: + package: + name: package (${{ matrix.runner }}) + strategy: + fail-fast: false + matrix: + include: + - runner: ubuntu-latest + arch: amd64 + - runner: ubuntu-22.04-arm + arch: arm64 + runs-on: ${{ matrix.runner }} + timeout-minutes: 60 + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Set up JDK 25 (Temurin) + uses: actions/setup-java@v5 + with: + distribution: temurin + java-version: "25" + cache: maven + + - name: Set up Python 3.11 + uses: actions/setup-python@v6 + with: + python-version: "3.11" + + - name: Install build toolchain (cmake + gcc) + # llama.cpp's quantize binary is built from source at the pinned tag; + # the runner has gcc but not always cmake (varies by image). + run: | + set -euo pipefail + sudo apt-get update -yq + sudo apt-get install -yq --no-install-recommends cmake build-essential + + - name: Install Python script deps + GGUF converter deps + # convert_hf_to_gguf.py (vendored from llama.cpp@LLAMA_CPP_TAG) + # imports torch + numpy + sentencepiece + transformers + protobuf. + # Kept out of scripts/requirements.txt because the SDK itself + # never needs them — they are build-host-only. + run: | + set -euo pipefail + pip install --upgrade pip + pip install -r scripts/requirements.txt + pip install \ + "torch>=2.4,<3.0" \ + "numpy>=1.26,<3.0" \ + "sentencepiece>=0.2,<1.0" \ + "transformers>=4.45,<5.0" \ + "protobuf>=4.25,<6.0" + + - name: Cache HF + staged model artifacts + uses: actions/cache@v4 + with: + path: | + ~/.cache/huggingface + java/inference-sdk-embed-bge-small/src/main/resources/models + java/inference-sdk-generate-qwen-0_5b/src/main/resources/models + build/llama.cpp + key: hf-models-${{ runner.os }}-${{ matrix.arch }}-${{ hashFiles('scripts/checksums/models.sha256', 'scripts/fetch_models.py') }} + restore-keys: | + hf-models-${{ runner.os }}-${{ matrix.arch }}- + hf-models-${{ runner.os }}- + + - name: Maven verify (full pipeline, fetch + IT) + # No -Dfetch.models.skip — this is the WHOLE point of this workflow. + # Includes @Tag("model") IT tests against the staged GGUF files. + run: ./mvnw -f java/pom.xml -B -ntp -e verify + + - name: Verify staged model checksums (post-build) + run: python3 scripts/verify_models.py + + - name: Upload model JAR artifacts + uses: actions/upload-artifact@v4 + with: + name: model-jars-${{ matrix.arch }} + path: | + java/inference-sdk-embed-bge-small/target/*.jar + java/inference-sdk-generate-qwen-0_5b/target/*.jar + retention-days: 14 + + - name: Upload test reports + if: always() + uses: actions/upload-artifact@v4 + with: + name: package-test-reports-${{ matrix.arch }} + path: | + **/target/surefire-reports/** + **/target/failsafe-reports/** + retention-days: 7 diff --git a/.github/workflows/scripts-ci.yml b/.github/workflows/scripts-ci.yml new file mode 100644 index 0000000..5d61ae3 --- /dev/null +++ b/.github/workflows/scripts-ci.yml @@ -0,0 +1,46 @@ +# Lint Python scripts under scripts/ — fast, no Java, no model deps. +name: scripts-ci + +on: + push: + branches: [main] + paths: + - "scripts/**" + - ".github/workflows/scripts-ci.yml" + pull_request: + paths: + - "scripts/**" + - ".github/workflows/scripts-ci.yml" + +concurrency: + group: scripts-ci-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + lint: + name: lint + runs-on: ubuntu-latest + timeout-minutes: 5 + steps: + - uses: actions/checkout@v5 + + - uses: actions/setup-python@v6 + with: + python-version: "3.11" + cache: pip + cache-dependency-path: scripts/requirements.txt + + - name: Install lint tooling + run: pip install -r scripts/requirements.txt + + - name: Compile-check + run: python3 -m py_compile scripts/fetch_models.py scripts/verify_models.py + + - name: Ruff + run: ruff check scripts/ + + - name: Pyright + run: pyright --warnings scripts/ diff --git a/.gitignore b/.gitignore index 9b31615..c867016 100644 --- a/.gitignore +++ b/.gitignore @@ -40,10 +40,21 @@ $RECYCLE.BIN/ hs_err_pid*.log replay_pid*.log -# --- Models directory (LFS-tracked artifacts live in module resource dirs, this top-level dir is dev-only) --- +# --- Models directory (top-level scratch; module-level artifacts ignored separately below) --- models/* !models/.gitkeep +# --- Fetched model artifacts (build-time, never committed) --- +# Populated by scripts/fetch_models.py during the Maven `generate-resources` +# phase and embedded directly into the published Maven JAR. The +# model-manifest.properties and .gitkeep stay tracked. +**/src/main/resources/models/*.onnx +**/src/main/resources/models/*.gguf +**/src/main/resources/models/*.safetensors +**/src/main/resources/models/*.bin +**/src/main/resources/models/*.pt +**/src/main/resources/models/*.pth + # --- Agent-generated working artifacts (per user policy) --- .research/ .planning/ diff --git a/.mvn/wrapper/maven-wrapper.properties b/.mvn/wrapper/maven-wrapper.properties new file mode 100644 index 0000000..887478f --- /dev/null +++ b/.mvn/wrapper/maven-wrapper.properties @@ -0,0 +1,8 @@ +# Apache Maven Wrapper config — inference-sdk +# Pin Maven 3.9.15 (spec: locked decision #1) and the wrapper jar version. +# distributionType=bin uses the maven-wrapper.jar bootstrap so the build +# host need not have a system Maven installed. +wrapperVersion=3.3.4 +distributionType=bin +distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.9.15/apache-maven-3.9.15-bin.zip +wrapperUrl=https://repo.maven.apache.org/maven2/org/apache/maven/wrapper/maven-wrapper/3.3.2/maven-wrapper-3.3.2.jar diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..9035737 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,85 @@ + +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at [INSERT CONTACT METHOD]. All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of actions. + +**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.1, available at [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at [https://www.contributor-covenant.org/translations][translations]. + +[homepage]: https://www.contributor-covenant.org +[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html +[Mozilla CoC]: https://github.com/mozilla/diversity +[FAQ]: https://www.contributor-covenant.org/faq +[translations]: https://www.contributor-covenant.org/translations + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..42ee1f5 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,147 @@ +# Contributing to inference-sdk + +Thanks for your interest in contributing. This document covers everything +you need to start working on the project. + +## Setup + +**Required toolchain** + +- **JDK 25 (Eclipse Temurin)** — install from + [adoptium.net](https://adoptium.net/temurin/releases/?version=25). Verify + with `java -version`; it should report `25` and the Temurin vendor. +- **Maven Wrapper** — invoke as `./mvnw` from the repo root or any module + directory. The Wrapper downloads Maven `3.9.15` on first run; you do not + need a system Maven install. +- **Python 3.11+** — required for the Maven build. Model weights are + fetched at build time (no Git LFS) by `scripts/fetch_models.py`, + bound to Maven's `generate-resources` phase in each model-JAR POM. + The Java build invokes Python via the `exec-maven-plugin`. +- **C toolchain (gcc/clang + cmake)** — required for the generative + model JAR (`inference-sdk-generate-qwen-0_5b`) only, because the + build vendors `llama.cpp` and quantizes `q4_K_M` GGUF locally. Not + needed for the embedding model JAR (pure ONNX). Ubuntu/Debian: + `sudo apt-get install build-essential cmake`. + +**Verify the environment** + +```sh +java -version # 25.x.x, Temurin +./mvnw -v # Maven 3.9.15+, JDK 25 +python3 --version # 3.11+ +cmake --version # 3.20+ (only if building the generative model JAR) +gcc --version # any recent (only if building the generative model JAR) +``` + +## Air-gapped contributors + +Model weights are NOT committed to source. They are fetched at build +time and embedded into the published Maven JAR — so end users on +air-gapped networks always get the bytes through Maven Central / their +internal mirror. **Contributors** building from source, however, do +need a one-shot online fetch. Two options: + +1. **Pre-built cache from a maintainer.** Ask a maintainer for a + pre-built `target/` cache (or a `~/.cache/huggingface/` snapshot + plus the staged `**/src/main/resources/models/` directories). Drop + the staged files into your tree and re-run `./mvnw verify` — + `scripts/verify_models.py` will confirm the SHA-256 pins and the + build will skip the network fetch on subsequent runs (cache valid + while `scripts/checksums/models.sha256` is unchanged). + +2. **Run `make fetch-models` on a connected machine, then copy.** On + any machine with internet access: + + ```sh + pip install -r scripts/requirements.txt + make fetch-models # populates ~/.cache/huggingface/ and + # **/src/main/resources/models/ + ``` + + Then copy `~/.cache/huggingface/` plus the populated + `java/*/src/main/resources/models/` directories into your + air-gapped environment. From there `./mvnw verify` runs offline + (the fetch script is idempotent — when the staged file already + matches the SHA-256 pin it is reused as-is). + +Set `-Dfetch.models.skip=true` to skip the fetch entirely when running +POM-only builds (e.g. IDE import). The default is `false` so a +fresh `./mvnw verify` produces a complete JAR. + +## Branching and Pull Requests + +- Cut feature branches off `main`. Direct commits to `main` are not allowed. +- Keep one logical change per pull request. Split unrelated work. +- Use **conventional-commit style** subjects: + - `feat:` new user-facing behavior + - `fix:` bug fix + - `docs:` documentation only + - `chore:` build, infra, deps + - `test:` test-only changes + - `refactor:` no behavior change + - `build:` build system or external dependencies +- When a PR implements a section of the design doc, reference it in the + description (for example: `Design: java-sdk.md §6.2`). +- Squash on merge unless the commit history is intentionally meaningful. + +## Tests + +The project uses three layers; new logic must include the appropriate ones. + +| Layer | Plugin | Naming | Default? | +|-------------|----------------|---------------|----------| +| Unit | Surefire | `*Test.java` | yes | +| Integration | Failsafe | `*IT.java` | yes | +| Slow | Failsafe | `@Tag("slow")` | excluded | + +- New edge cases follow the numbered convention from `java-sdk.md` §11.2; + reuse the existing numbering and add the next free index. +- Run the full suite locally with `./mvnw verify` before pushing. +- Slow tests run in CI nightly and on demand via `./mvnw verify -Pslow`. +- A flaky test is a broken test. Fix, quarantine with a tracked issue, or + delete it in the **same** PR. Do not merge a flake-causing change. + +## Code Style + +- **Spotless / google-java-format** — `./mvnw spotless:apply` before pushing. + CI runs `spotless:check` and fails on violations. +- **SpotBugs HIGH** — every HIGH-priority finding must be clean. If you must + suppress one, use `@SuppressFBWarnings` with a `justification` argument + explaining why; PR review will scrutinize justifications. +- **Imports** — no wildcard imports; ordered per google-java-format. +- **Nullability** — use `Optional` for return types that can be empty; + avoid returning `null`. Annotate fields with `@Nullable` when applicable. + +## License Agreement and Sign-off + +- We do **not** require a CLA or DCO sign-off. +- By submitting a contribution you agree to release your changes under the + Apache License, Version 2.0 (the project license; see + [`LICENSE`](LICENSE)). +- If you import third-party code, include attribution in `NOTICE` and ensure + the upstream license is compatible (MIT, Apache-2.0, BSD-2/3 are fine; + GPL/AGPL require explicit approval). + +## Where to Find Work + +- Open issues with the **`good-first-issue`** label are scoped for + newcomers and link to the relevant design section. +- Larger features are tracked in the `docs/ARCHITECTURE.md` Roadmap; pick + one and open an issue before starting significant work. +- Adjacent tech debt that you spot during a fix should be filed as a + **follow-up issue**, not silently bundled into your diff (see + `~/.claude/CLAUDE.md` §4 — Scope discipline). + +## Native Binding + +Phase 1 consumes the published `de.kherud:llama:4.2.0` directly from +Maven Central. Binding remediation work (forking and bumping the bundled +llama.cpp) is tracked in Phase 1.5; see +[`docs/ARCHITECTURE.md`](docs/ARCHITECTURE.md#6-roadmap) Roadmap. + +## Questions? + +- For design questions: reference `java-sdk.md` and the design doc under + `docs/superpowers/specs/`, then open a discussion issue. +- For security concerns: follow [`SECURITY.md`](SECURITY.md) — do not open + a public issue. diff --git a/LIBRARY.md b/LIBRARY.md new file mode 100644 index 0000000..9780236 --- /dev/null +++ b/LIBRARY.md @@ -0,0 +1,223 @@ +# LIBRARY.md — inference-sdk Java public API (Phase 1) + +> Authoritative public surface for the **Java** implementation of +> `inference-sdk`. For design rationale, threat model, and the +> native-thread-pinning workaround in detail, see +> [`docs/ARCHITECTURE.md`](./docs/ARCHITECTURE.md). For the locked +> cross-language wire format, see +> [`docs/WIRE_FORMAT.md`](./docs/WIRE_FORMAT.md). + +## Module map + +| Maven artifact | Purpose | +|-----------------------------------------------|-------------------------------------------------------------------------| +| `inference-sdk-core` | Records (`Result`, `Failure`, `FinishReason`, `ModelInfo`, `Usage`) and runtime helpers (`ContainerCpu`, `NativeExecutor`, `RequestId`, `NativeLibLoader`). | +| `inference-sdk-embed` | `Embedder` interface + ONNX-Runtime impl + DJL HuggingFace tokenizer. | +| `inference-sdk-generate` | `Generator` interface + kherud-llama (llama.cpp) impl. | +| `inference-sdk-embed-bge-small` | Resource-only JAR shipping `bge-small-en-v1.5` ONNX (LFS payload). | +| `inference-sdk-generate-qwen-0_5b` | Resource-only JAR shipping `qwen2.5-0.5b-instruct` GGUF (LFS payload). | +| `inference-sdk-bundle` | Convenience fat JAR (everything above, shaded). | +| `inference-sdk-integration-tests` | Tier-5 cross-module edge-case + property tests; non-publishable. | + +The default JPMS module names follow the artifact ids: e.g. the embed +module is `io.github.randomcodespace.inference.embed`. Modulepath and +classpath consumers are both supported; the bundle JAR is unnamed +(see its `pom.xml` for the rationale). + +## Public API surface + +### Embedding (`io.github.randomcodespace.inference.embed`) + +```java +public interface Embedder extends AutoCloseable { + EmbedResult embed(List texts); + float[] embedOne(String text); + CompletableFuture embedAsync(List texts); + ModelInfo modelInfo(); + @Override void close(); + + static Builder builder(); +} +``` + +Records: `EmbedResult`, `EmbedStats`. Exceptions: `EmbedException` +(root), `InvalidInputException`, `ModelNotFoundException`, +`NativeLoadException`. + +### Generation (`io.github.randomcodespace.inference.generate`) + +```java +public interface Generator extends AutoCloseable { + GenerateResponse complete(GenerateRequest req); + CompletableFuture completeAsync(GenerateRequest req); + Flow.Publisher stream(GenerateRequest req); + ModelInfo modelInfo(); + @Override void close(); + + static Builder builder(); +} +``` + +Records: `Message`, `GenerateRequest`, `GenerateResponse`, +`GenerateChunk`, `GenerateStats`. Exceptions: `GenerateException` +(root), `QueueFullException`, `ModelNotLoadedException`, +`FeatureNotSupportedException`. + +### Runtime (`io.github.randomcodespace.inference.runtime`) + +`ContainerCpu`, `NativeExecutor`, `NativeLibLoader`, `RequestId`. The +last is a final class wrapping a `ScopedValue`; bind it via +`RequestId.withRequestId(id, callable)` so the value flows through any +child virtual threads (including those launched from a +`StructuredTaskScope`). + +## Examples + +### Synchronous + +```java +try (Embedder e = Embedder.builder().model("bge-small-en-v1.5").build(); + Generator g = Generator.builder().model("qwen2.5-0.5b-instruct").build()) { + float[] v = e.embedOne("hello"); + GenerateResponse r = g.complete( + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi", null, null, null))) + .maxTokens(32) + .build()); + System.out.println(r.text()); +} +``` + +### Async + virtual threads + +```java +try (Embedder e = Embedder.builder().model("bge-small-en-v1.5").build(); + var exec = Executors.newVirtualThreadPerTaskExecutor()) { + + List> futs = inputs.stream() + .map(s -> exec.submit(() -> e.embedOne(s))) + .toList(); + for (Future f : futs) { + process(f.get()); + } +} +``` + +`embedAsync` / `completeAsync` return `CompletableFuture` — the SDK's +internal virtual-thread executor completes them; awaiting from a +caller virtual thread yields the carrier correctly. + +### Streaming + +```java +g.stream(req).subscribe(new Flow.Subscriber<>() { + public void onSubscribe(Flow.Subscription s) { s.request(Long.MAX_VALUE); } + public void onNext(GenerateChunk c) { System.out.print(c.delta()); } + public void onError(Throwable t) { t.printStackTrace(); } + public void onComplete() { System.out.println("done"); } +}); +``` + +The publisher honours `Subscription.request(n)` for backpressure and +`Subscription.cancel()` at the next-token boundary; a single terminal +chunk (`done == true`) carries the final `Usage` + `GenerateStats`. + +### Structured concurrency + RequestId + +```java +RequestId.withRequestId(RequestId.generate(), () -> { + try (var scope = new StructuredTaskScope.ShutdownOnFailure()) { + var f1 = scope.fork(() -> embedder.embedOne("q1")); + var f2 = scope.fork(() -> embedder.embedOne("q2")); + scope.join().throwIfFailed(); + return List.of(f1.get(), f2.get()); + } +}); +``` + +`RequestId.CURRENT` is a `ScopedValue` (JEP 487, finalised in Java 25), +so the bound id propagates through every fork inside the scope and +shows up in `EmbedStats.requestId()` / `GenerateStats.requestId()`. + +## Native-executor pattern + +llama.cpp and ONNX Runtime pin the calling carrier thread on JNI +entry; if we let virtual threads call them directly the JVM's carrier +pool would be saturated by long-running native calls. Every JNI call +the SDK makes is therefore trampolined through `NativeExecutor`, a +small bounded **platform-thread** pool sized to `ContainerCpu.detect()` +by default. + +``` +caller virtual thread + │ Embedder.embedOne(text) + ▼ ++---------------------+ submitNative() +------------------------+ +| Embedder facade | ────────────────────────► | NativeExecutor (N PT) | ++---------------------+ +------------------------+ + ▲ │ + │ CompletableFuture ▼ + │ resolved on virtual-thread executor +-------------------+ + │ | ONNX / llama.cpp | + │ | JNI (pinned PT) | + │ +-------------------+ + │ │ + └────────────────── result ◄─────────────────────────────┘ +``` + +The `NativeExecutor` is owned by the `Embedder` / `Generator` +instance and shut down by `close()`; threads return to the JVM +within ~2 s of close (per `inference-sdk-integration-tests` case +#37). See `docs/ARCHITECTURE.md` §3.3 for the full rationale. + +## Build-time model switching + +The default models are committed by `scripts/fetch_models.py`, which +hydrates the LFS payloads: + +```bash +make fetch-models # populate models/ + JAR resources +./mvnw -f java/pom.xml -B verify # default build +./mvnw -f java/pom.xml verify -P slow # include @Tag("model") tests +``` + +To swap the bundled embedding model, set the `embedding.model` Maven +property and (separately) drop the alternative ONNX file in place: + +```bash +./mvnw -f java/pom.xml verify \ + -P model-switch \ + -Dembedding.model=bge-base-en-v1.5 +``` + +The IT module's `ModelSwitchIT` (case #48) reads the property and +asserts the alternative model loads. Generation models switch the same +way through the `qwen.model` analogue — extend the resource JAR with a +new `.gguf` artifact and pass `-Dgeneration.model=` to the +`Generator.builder().model(...)` call site. + +`docs/MODEL_REGISTRY.md` enumerates approved models, their checksums, +and the licence of each weight. + +## Forward-compatibility guarantees + +`Message`, `GenerateRequest`, and `GenerateResponse` carry reserved +fields (`tool_calls`, `tool_call_id`, `name`, `tools`, `tool_choice`, +`response_format`, `system_fingerprint`) named today so the Phase-2 +HTTP layer can route OpenAI-style payloads without an API break. The +Phase-1 implementation rejects non-null reserved fields with +`FeatureNotSupportedException` and produces them as `null` from +generators (see IT cases #49–51). + +## Pointers + +- [`docs/ARCHITECTURE.md`](./docs/ARCHITECTURE.md) — design depth, + threat model, native-executor rationale. +- [`docs/WIRE_FORMAT.md`](./docs/WIRE_FORMAT.md) — locked snake_case + JSON shapes shared by Java + Go + the future HTTP layer. +- [`docs/MODEL_REGISTRY.md`](./docs/MODEL_REGISTRY.md) — approved + model list and checksums. +- [`SECURITY.md`](./SECURITY.md) — residual security risk + the + `InputValidator` boundary. +- [`java/examples/quickstart/`](./java/examples/quickstart/) — + runnable 30-line end-to-end demo. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..be3aef7 --- /dev/null +++ b/Makefile @@ -0,0 +1,22 @@ +.PHONY: help java-build java-test java-verify java-clean fetch-models verify-models + +help: ## Show this help + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-18s\033[0m %s\n", $$1, $$2}' + +java-build: ## Build all Java modules (skip tests) + cd java && ./mvnw -DskipTests install + +java-test: ## Run Java unit tests (Surefire) + cd java && ./mvnw test + +java-verify: ## Run unit + integration tests (Surefire + Failsafe) + cd java && ./mvnw verify + +java-clean: ## Clean Java build outputs + cd java && ./mvnw clean + +fetch-models: ## Download bundled models from Hugging Face into models/ + python3 scripts/fetch_models.py + +verify-models: ## Verify SHA-256 of bundled models against pins + python3 scripts/verify_models.py diff --git a/NOTICE b/NOTICE new file mode 100644 index 0000000..574b683 --- /dev/null +++ b/NOTICE @@ -0,0 +1,93 @@ +inference-sdk +Copyright 2026 Amit Kumar and contributors + +This product includes software developed by the inference-sdk contributors +(https://github.com/RandomCodeSpace/inference-sdk). + +Licensed under the Apache License, Version 2.0 (the "License"); you may not +use this software except in compliance with the License. You may obtain a +copy of the License at: + + http://www.apache.org/licenses/LICENSE-2.0 + +------------------------------------------------------------------------------ +Third-party components +------------------------------------------------------------------------------ + +This product bundles or depends on the following third-party software. Each +component retains its own copyright and is distributed under its respective +license terms; see the linked upstream sources for full text. + +ONNX Runtime (com.microsoft.onnxruntime:onnxruntime) + Project: https://github.com/microsoft/onnxruntime + License: MIT + Copyright (c) Microsoft Corporation. + +DJL HuggingFace Tokenizers (ai.djl.huggingface:tokenizers) + Project: https://github.com/deepjavalibrary/djl + License: Apache License, Version 2.0 + Copyright (c) Amazon.com, Inc. and DJL contributors. + +Jackson (com.fasterxml.jackson.core:jackson-databind, jackson-core, jackson-annotations) + Project: https://github.com/FasterXML/jackson + License: Apache License, Version 2.0 + Copyright (c) FasterXML, LLC. + +SLF4J (org.slf4j:slf4j-api) + Project: https://github.com/qos-ch/slf4j + License: MIT + Copyright (c) 2004-2026 QOS.ch. + +Logback (ch.qos.logback:logback-classic) -- TEST scope only + Project: https://github.com/qos-ch/logback + License: EPL 1.0 / LGPL 2.1 (dual) + Copyright (c) QOS.ch. + +JUnit Jupiter (org.junit.jupiter:junit-jupiter) -- TEST scope only + Project: https://github.com/junit-team/junit5 + License: Eclipse Public License, Version 2.0 + Copyright (c) The JUnit Team. + +AssertJ (org.assertj:assertj-core) -- TEST scope only + Project: https://github.com/assertj/assertj + License: Apache License, Version 2.0 + Copyright (c) AssertJ contributors. + +Mockito (org.mockito:mockito-core) -- TEST scope only + Project: https://github.com/mockito/mockito + License: MIT + Copyright (c) Mockito contributors. + +Awaitility (org.awaitility:awaitility) -- TEST scope only + Project: https://github.com/awaitility/awaitility + License: Apache License, Version 2.0 + Copyright (c) Awaitility contributors. + +EqualsVerifier (nl.jqno.equalsverifier:equalsverifier) -- TEST scope only + Project: https://github.com/jqno/equalsverifier + License: Apache License, Version 2.0 + Copyright (c) Jan Ouwens. + +jqwik (net.jqwik:jqwik) -- TEST scope only + Project: https://github.com/jqwik-team/jqwik + License: Eclipse Public License, Version 2.0 + Copyright (c) Johannes Link and jqwik contributors. + +de.kherud:llama (published Java binding to llama.cpp) + Project: https://github.com/kherud/java-llama.cpp + Artifact: de.kherud:llama:4.2.0 (Maven Central) + License: MIT + Copyright (c) Konstantin Herud and java-llama.cpp contributors. + Embeds llama.cpp b4916 (MIT, https://github.com/ggml-org/llama.cpp, + Copyright (c) ggml-org and llama.cpp contributors). See SECURITY.md + for the residual-risk register and Phase 1.5 remediation plan. + +BAAI/bge-small-en-v1.5 (bundled embedding model artifact) + Project: https://huggingface.co/BAAI/bge-small-en-v1.5 + License: MIT + Copyright (c) Beijing Academy of Artificial Intelligence (BAAI). + +Qwen/Qwen2.5-0.5B-Instruct (bundled chat model artifact) + Project: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct + License: Apache License, Version 2.0 + Copyright (c) Alibaba Cloud / Qwen team. diff --git a/README.md b/README.md index cba7749..1b72fe0 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,105 @@ # inference-sdk -Polyglot inference library for fully offline, text-only embedding and chat generation on CPU-only Linux, plus Windows and ARM64. +[![CI](https://img.shields.io/badge/CI-pending-lightgrey)](.github/workflows) +[![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE) +[![JDK](https://img.shields.io/badge/JDK-25-orange)](https://adoptium.net/temurin/releases/?version=25) -This repository hosts implementations in multiple languages. Java is first; Go follows. Both implementations produce wire-compatible artifacts and observable behavior. +Polyglot inference library for fully offline, text-only embedding and chat +generation on CPU-only Linux, plus Windows and ARM64. + +This repository hosts implementations in multiple languages. Java is first; +Go follows. Both implementations produce wire-compatible artifacts and +observable behavior. ## Status -| Language | Status | Path | -|---|---|---| -| Java | 🚧 in development (Phase 1) | [`java/`](java/) | -| Go | 📋 planned | [`go/`](go/) | +| Language | Status | Path | +|----------|---------------------------------|------------------| +| Java | in development (Phase 1) | [`java/`](java/) | +| Go | planned | [`go/`](go/) | + +Phase 1 is **library-only** — embedding via ONNX Runtime + +`BAAI/bge-small-en-v1.5`; chat generation via the published +`de.kherud:llama:4.2.0` Java binding (with documented residual risk; +see [SECURITY.md](SECURITY.md)) + `Qwen/Qwen2.5-0.5B-Instruct` +(default). The HTTP / OpenAI-compatible layer is Phase 2. + +## Prerequisites + +- **JDK 25** (Eclipse Temurin recommended) — see + [adoptium.net](https://adoptium.net/temurin/releases/?version=25) +- **Git LFS** — bundled model artifacts are stored via LFS. Run + `git lfs install` once per machine, then clone (or + `git lfs pull` in an existing checkout). +- The Maven Wrapper (`./mvnw`) handles Maven `3.9.15` on first invocation; + no system Maven install is required. + +## 30-second Java quickstart + +Add the bundle artifact (Phase 1 ships a single fat-jar; per-module +artifacts also publish): + +```xml + + io.github.randomcodespace.inference + inference-sdk-bundle + 0.1.0-SNAPSHOT + +``` + +Use it: + +```java +import io.github.randomcodespace.inference.embed.Embedder; +import io.github.randomcodespace.inference.generate.Generator; + +try (Embedder embedder = Embedder.builder().build(); + Generator generator = Generator.builder().build()) { + + float[] vector = embedder.embed("hello, world"); + + String reply = generator.complete("Say hi in five words."); + System.out.println(reply); +} +``` + +Both builders default to the bundled models; configuration knobs (model +ID, thread count, context window, sampler) live on the builders. See +`docs/ARCHITECTURE.md` for the full surface. + +## Project structure -Phase 1 is **library-only** — embedding via ONNX Runtime + bge-small-en-v1.5; chat generation via a forked llama.cpp Java binding + Qwen 2.5-0.5B-Instruct (default). HTTP/OpenAI-compatible layer is Phase 2. +``` +inference-sdk/ + java/ Java implementation (Phase 1) + inference-sdk-core/ API, records, errors, model loader + inference-sdk-embed/ ONNX-Runtime-backed Embedder + inference-sdk-generate/ llama.cpp-backed Generator + inference-sdk-tokenize/ DJL HuggingFace tokenizers wrapper + inference-sdk-models-bge/ BAAI/bge-small-en-v1.5 model JAR + inference-sdk-models-qwen/ Qwen/Qwen2.5-0.5B-Instruct model JAR + inference-sdk-bundle/ fat-jar aggregator + inference-sdk-it/ integration + slow tests + scripts/ Python helpers (fetch_models, verify_models) + docs/ ARCHITECTURE, WIRE_FORMAT, MODEL_REGISTRY, + design specs (docs/superpowers/specs/) + go/ Go implementation (planned, Phase 2+) + models/ Git-LFS-tracked model artifact cache + .github/workflows/ CI: build, test, OWASP, SpotBugs, Spotless +``` ## Quick links - [`docs/ARCHITECTURE.md`](docs/ARCHITECTURE.md) — cross-language design - [`docs/WIRE_FORMAT.md`](docs/WIRE_FORMAT.md) — JSON shapes shared across languages - [`docs/MODEL_REGISTRY.md`](docs/MODEL_REGISTRY.md) — canonical model IDs +- [`CONTRIBUTING.md`](CONTRIBUTING.md) — setup, branching, tests, code style +- [`SECURITY.md`](SECURITY.md) — reporting, threat model, mitigations +- [`CODE_OF_CONDUCT.md`](CODE_OF_CONDUCT.md) — community standards - [`java/`](java/) — Java implementation +- [`java-sdk.md`](java-sdk.md) — Java SDK design notes ## License -Apache 2.0 — see [LICENSE](LICENSE). +Apache 2.0 — see [`LICENSE`](LICENSE) and [`NOTICE`](NOTICE) for +third-party attributions. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..250c846 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,108 @@ +# Security Policy + +## Reporting Security Vulnerabilities + +Please report security issues privately via GitHub Security Advisories: + +> https://github.com/RandomCodeSpace/inference-sdk/security/advisories/new + +Do **not** open public issues or pull requests for unpatched vulnerabilities. + +We aim to acknowledge new reports within **7 days** and to provide an initial +triage assessment (severity, reachability, planned remediation window) in the +same window. Critical issues are prioritized for an out-of-band patch release. + +When reporting, please include: + +- Affected version (commit SHA or release tag) +- Reproduction steps or proof-of-concept +- Observed vs. expected behavior +- Any known mitigations + +## Supported Versions + +| Version | Status | Security fixes | +|---------------------|-------------------|----------------| +| Phase 1 (in dev) | `main` branch only | yes | +| Older / experimental | not supported | no | + +Phase 1 is pre-release; the only supported reference is the `main` branch. +Tagged releases will appear once Phase 1 ships. + +## Threat Model Summary + +The full threat model lives in +[`docs/ARCHITECTURE.md`](docs/ARCHITECTURE.md#threat-model). Summary: + +**In scope** + +- Local execution of bundled embedding and chat models on a single host +- Tampering or substitution of bundled model artifacts and native libraries +- Malicious or malformed user input passed through the public API +- Resource exhaustion (memory, threads, queue depth) under bursty load +- Vulnerabilities in direct and transitive Java dependencies +- Vulnerabilities in the published `de.kherud:llama` binding and its embedded llama.cpp + +**Out of scope (Phase 1)** + +- Network-exposed inference endpoints (Phase 2) +- Multi-tenant isolation +- Side-channel attacks (timing, cache, power) on the host machine +- Adversarial prompt-injection defense beyond standard input validation +- Confidentiality of model weights once loaded into process memory + +## Mitigation Summary + +- **Model integrity** — every bundled model artifact is verified against a + pinned SHA-256 at load time; mismatches abort startup. +- **Native library integrity** — the `de.kherud:llama` `.so` / `.dll` / `.dylib` + payloads are SHA-256 verified before being loaded via JNI. +- **Input validation** — public API entry points use Java records with + validating constructors; out-of-range values fail fast at the boundary. +- **Bounded queues** — generation and embedding pipelines use bounded work + queues with explicit backpressure; no unbounded buffers or task pools. +- **No runtime network calls** — the SDK never reaches the public internet at + runtime. All assets are local and bundled per + [`rules/build.md`](https://github.com/RandomCodeSpace/inference-sdk). +- **OWASP Dependency-Check in CI** — every PR runs an OWASP scan; High and + Critical CVEs block the build. +- **SpotBugs HIGH gate** — security-relevant SpotBugs findings at HIGH + priority block the build. + +## Residual security risk + +Phase 1 consumes the published +[`de.kherud:llama:4.2.0`](https://central.sonatype.com/artifact/de.kherud/llama) +artifact directly from Maven Central. The bundled llama.cpp is `b4916` +(mid-2025), against which 5 reachable High GHSA advisories + 1 Moderate +remain open. Each is neutralized at the SDK boundary: + +| ID | Severity | Path | Mitigation | +|---------------|----------|-----------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------| +| GHSA-8wwf-... | High | `token_to_piece` overflow on vocab load | SHA-256 model allow-list (`scripts/checksums/models.sha256`) — only pinned vocabs ever load | +| GHSA-7rxv-... | High | tokenizer signed/unsigned overflow on adversarial prompt | Input length cap (configurable, default tied to `contextSize`) + UTF-8 validation in API records | +| GHSA-vgg9-... | High | GGUF size accumulator overflow | SHA-256 model allow-list — pinned GGUF files only | +| GHSA-96jg-... | High | `ggml_nbytes` overflow → potential RCE on model load | SHA-256 model allow-list — pinned GGUF files only | +| GHSA-3p4r-... | High | `mem_size` overflow bypass | SHA-256 model allow-list — pinned GGUF files only | +| GHSA-g4cc-... | Moderate | DoS over-read on malformed GGUF | SHA-256 model allow-list — pinned GGUF files only | + +4 of 5 Highs + the Moderate are GGUF model-load-path bugs, fully +neutralized by the SHA-256 allow-list enforced by +`NativeLibLoader.extractAndVerify`. The remaining tokenizer High +(`7rxv`) is narrowed to a non-exploitable surface by the input +validation layer. + +**Sign-off.** Path 1 (consume `de.kherud:llama:4.2.0` + the +mitigations above) was acknowledged in writing by the project owner +on 2026-05-09 per the mitigation clause in +`~/.claude/rules/security.md` ("document why not exploitable + get +explicit user sign-off"). Phase 1.5 will fork-and-bump to clear these +advisories upstream. + +## Binding remediation strategy + +Long-term remediation of the bundled-llama.cpp advisories is tracked +as the first task in Phase 1.5: fork `kherud/java-llama.cpp` and bump +the bundled llama.cpp to a current tag (≥`b8146`). See +[`docs/ARCHITECTURE.md`](docs/ARCHITECTURE.md#6-roadmap) Roadmap → +Phase 1.5 for the full plan. diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md new file mode 100644 index 0000000..313b7ec --- /dev/null +++ b/docs/ARCHITECTURE.md @@ -0,0 +1,535 @@ +# Architecture + +> Cross-language design master document for `inference-sdk`. This file +> is **shipped** (committed to the repo) and is read by every +> implementation language (Java first; Go next). + +--- + +## Table of contents + +1. [Overview](#1-overview) +2. [Cross-language contracts](#2-cross-language-contracts) +3. [Java-specific architecture](#3-java-specific-architecture) + 1. [Module layout](#31-module-layout) + 2. [Native binding strategy](#32-native-binding-strategy) + 3. [Native-thread-pinning workaround](#33-native-thread-pinning-workaround) + 4. [Streaming model](#34-streaming-model) + 5. [Lifecycle](#35-lifecycle) + 6. [Container-aware threading](#36-container-aware-threading) +4. [Threat model](#4-threat-model) +5. [Deviation register](#5-deviation-register) +6. [Roadmap](#6-roadmap) + +--- + +## 1. Overview + +`inference-sdk` is an **offline, library-only** embedding and +text-generation SDK. It runs entirely on the local machine; no +network calls are made at startup or during inference. + +Core properties: + +- **Library, not server.** Phase 1 ships a Java library you embed in + your application. An optional HTTP/OpenAI-compatible adapter is + reserved for Phase 2 (see [Roadmap](#6-roadmap)). +- **No auto-download.** Models ship inside the consuming application + via Maven artifacts (model JARs) or are supplied at runtime via + `Builder.modelPath(Path)` / the `INFERENCE_MODEL_DIR` env var. + Unknown model name + no path → `ModelNotFoundException` listing + what is present on the classpath. +- **Hybrid model distribution.** Source repo never commits binary + weights. `scripts/fetch_models.py` downloads + quantizes the + pinned models at build time (bound to Maven `generate-resources` in + each model JAR POM); `scripts/verify_models.py` enforces the + SHA-256 pin in `scripts/checksums/models.sha256` at + `process-resources`; the resulting bytes are embedded into the + published Maven JAR via the standard + `src/main/resources/models/` mechanism. End consumers receive + weights through Maven Central / their internal mirror — no Git + LFS, no runtime download, air-gapped guarantee preserved. +- **Air-gapped friendly.** All dependencies are vendored at build + time. Reproducible builds run inside a corporate firewall once an + internal Maven mirror is configured. No public CDN fetches at + runtime. +- **Polyglot.** Java is the reference implementation. The Go port + must match the wire format and model registry exactly so a single + payload is portable across both languages. +- **Forward-compatible.** Public records, finish-reason values, and + stats fields are locked now; reserved fields throw + `FeatureNotSupportedException` in Phase 1 but their wire shapes are + set so Phase 2 (HTTP) and the Go port can match without breaking + changes. + +The SDK's contract is: given a list of texts produce embedding +vectors; given a list of `Message`s produce a `GenerateResponse` (or a +stream of `GenerateChunk`s). Everything else — orchestration, +batching policy, queue management, API surface, telemetry — is +deliberately at the boundary the caller controls. + +--- + +## 2. Cross-language contracts + +Three documents form the cross-language contract. Every language +implementation MUST conform to them; mismatches are bugs. + +| Contract | Document | Locked at | +|---|---|---| +| JSON shapes (field names, types, encoding) | [`WIRE_FORMAT.md`](./WIRE_FORMAT.md) | Phase 1 | +| Model identifiers, licenses, and SHA-256 pins | [`MODEL_REGISTRY.md`](./MODEL_REGISTRY.md) | Phase 1 | +| Glossary of cross-cutting terms | [`GLOSSARY.md`](./GLOSSARY.md) | Phase 1 | + +### 2.1 Finish-reason mapping (cross-language + OpenAI) + +`FinishReason` is a sealed/closed enum. The wire encoding is the +lowercase string in column 1. Column 3 shows the value the Phase 2 +HTTP layer will emit on `/v1/chat/completions` (OpenAI-compatible). + +| Wire value | Java type | OpenAI mapping | When emitted | +|---|---|---|---| +| `"stop"` | `FinishReason.Stop` | `"stop"` | A user-supplied stop sequence matched. | +| `"length"` | `FinishReason.Length` | `"length"` | `max_tokens` reached. | +| `"eos"` | `FinishReason.Eos` | `"stop"` | Model emitted its end-of-sequence token. | +| `"canceled"` | `FinishReason.Canceled` | `"stop"` (with internal flag) | Caller called `Subscription.cancel()` on a streaming generation, OR a non-streaming call was interrupted. | +| `"error"` | `FinishReason.Error(message)` | n/a (error path uses HTTP 5xx) | Native runtime returned an error mid-generation. Wire form: `{"error": ""}` carried in the terminal chunk's `finish_reason`. | + +> Phase 1 does not serialize over the wire, but Java's Jackson +> mappers are configured **today** to emit these strings so Phase 2 +> and the Go port match without re-litigation. + +### 2.2 How the Go port must match + +The Go implementation will live under `go/` and must: + +1. Use the same JSON field names, types, and enum encodings as + `WIRE_FORMAT.md`. Go's `encoding/json` tags map to the + `snake_case` names directly. +2. Use the same model IDs from `MODEL_REGISTRY.md` and the same + SHA-256 pins. The Go port reads the same + `model-manifest.properties` files placed alongside its model + files (or embedded via `//go:embed`). +3. Provide an equivalent public API: `Embedder.Embed(ctx, texts) + ([]Vector, EmbedStats, error)` mirrors Java's `Embedder.embed` + (semantics, not method signatures). +4. Pass an equivalent test suite covering the same numbered edge + cases the Java suite covers. +5. Report the same `FinishReason` values; same lifecycle invariants + (no auto-download, SHA-256 verify on load, idempotent close). + +Any divergence required by language idioms (e.g., Go's `context.Context` +for cancellation vs Java's `Subscription.cancel()`) MUST be documented +inline in this file under §2.x and called out in the Go port's +`README.md`. Wire-level divergence is forbidden. + +--- + +## 3. Java-specific architecture + +### 3.1 Module layout + +Eight modules under `java/`. Modules are Maven JARs; **model +artifact modules contain no Java code** — only model files under +`src/main/resources/models/` plus a `model-manifest.properties`. + +``` +java/ +├── pom.xml (aggregator) +├── inference-sdk-parent/ parent POM, dependency mgmt +│ └── pom.xml +├── inference-sdk-core/ shared types + runtime helpers +│ ├── pom.xml +│ └── src/{main,test}/java/io/github/randomcodespace/inference/... +├── inference-sdk-embed/ Embedder API + ONNX wiring +├── inference-sdk-embed-bge-small/ MODEL JAR ONLY (no code) +├── inference-sdk-generate/ Generator API + de.kherud:llama wiring +├── inference-sdk-generate-qwen-0_5b/ MODEL JAR ONLY +├── inference-sdk-bundle/ Fat JAR (shade-plugin); convenience +└── inference-sdk-tests/ Cross-module integration tests +``` + +Dependency direction is strictly downward: + +``` +core ← embed ← embed-bge-small + generate ← generate-qwen-0_5b +core, embed, generate ← bundle +core, embed, generate, *model JARs ← tests +``` + +`inference-sdk-core` has zero non-JDK runtime dependencies (only +SLF4J as an optional logger interface). Higher-level modules pull in +ONNX Runtime (embed) or `de.kherud:llama:4.2.0` (generate) as needed. + +### 3.2 Native binding strategy + +#### Embedding — ONNX Runtime + +[`com.microsoft.onnxruntime:onnxruntime` 1.25.1](./MODEL_REGISTRY.md#bge-small-en-v15) +publishes a single CPU JAR that already bundles native libs for all +four target triples (linux-x64, linux-arm64, windows-x64, +macos-arm64). Used directly; no fork. + +#### Generation — published `de.kherud:llama:4.2.0` + +Phase 1 consumes the published +[`de.kherud:llama:4.2.0`](https://central.sonatype.com/artifact/de.kherud/llama) +artifact directly from Maven Central (MIT licensed; ships natives +for Win-x64 + Linux-x64 manylinux2014/glibc 2.17 + Linux-arm64 +dockcross-arm64-lts/glibc 2.27). The bundled llama.cpp is `b4916` +(mid-2025). The residual High/Moderate advisories in that bundled +llama.cpp are mitigated by the SHA-256 model allow-list (4 of 5 Highs ++ 1 Moderate are GGUF model-load-path bugs, neutralized by the pin) +plus prompt-input length cap and UTF-8 validation at the API boundary +(narrows the remaining tokenizer High). See +[`SECURITY.md`](../SECURITY.md) for the full residual-risk register +and Phase 1.5 remediation plan. + +**Selection rationale.** No published llama.cpp Java binding meets +all of the project's hard criteria simultaneously (Win-x64 + Linux-x64 +glibc 2.17 + Linux-arm64 glibc 2.27 + Apache/MIT license + active +maintenance + no Critical CVEs). `de.kherud:llama:4.2.0` is the only +binding shipping all three native classifiers from a single artifact; +the bundled llama.cpp tag (`b4916`) has 5 reachable High-severity +advisories in the C++ core, all neutralized at the SDK boundary by +the mitigations above. Phase 1.5 will fork-and-bump to clear the +underlying advisories upstream and add Gemma 4 architecture support. + +### 3.3 Native-thread-pinning workaround + +This is the **single most important architectural decision** in the +Java implementation, and it is non-obvious enough to call out in +prose. + +#### The problem + +JDK 25's virtual threads (`Thread.ofVirtual()` / `Executors +.newVirtualThreadPerTaskExecutor()`) deliver dramatic concurrency +gains by mounting many lightweight virtual threads onto a small pool +of platform "carrier" threads. A virtual thread blocked on I/O yields +its carrier so another virtual thread can run. + +ONNX Runtime and llama.cpp are native libraries called via JNI. JNI +calls **pin the carrier thread** for the duration of the native +call: the virtual thread cannot unmount, the carrier cannot be +reused. If we naively submitted ONNX or llama.cpp work to a virtual +thread executor, then under load the carrier pool (default sized to +`Runtime.availableProcessors()`) would be saturated by long-running +native calls. Other unrelated virtual threads — including ones doing +pure I/O elsewhere in the same JVM — would starve. + +#### The fix: dedicated platform-thread pool, virtual threads await results + +The SDK introduces `NativeExecutor` in `inference-sdk-core`: + +- A small `ThreadPoolExecutor` of **platform threads**, sized via + `ContainerCpu.detect()` (cgroups v2 aware; see §3.6). +- Threads are named `inference-native-` for diagnostics. +- Every JNI call (ONNX inference, llama.cpp `decode`) is wrapped in + `executor.submitNative(Callable)`, which returns a + `CompletableFuture`. +- Caller virtual threads await the result via `Future.get()` / + `CompletableFuture.thenApply` — a regular blocking operation that + the JVM correctly identifies as park-and-yield, freeing the + carrier. + +This isolates native-side parallelism from application-side +concurrency. Virtual threads remain free to fan out, batch, retry, +and orchestrate; native CPU is bounded and managed by a pool whose +size matches the container's CPU allocation. + +#### Flow diagram + +``` + Application code (caller) Inference SDK (library) Native lib + ───────────────────────── ───────────────────────── ───────────── + + Virtual thread V1 + ──────────────────► embed(texts) + │ + │ submit Callable + ▼ + NativeExecutor Platform thread P1 + (small pool, sized ─────────────────► ONNX Runtime + to cgroups CPU) inference (JNI; + │ pins P1) + │ CompletableFuture │ + ▼ ▼ + virtual-thread V1 parks returns vectors + on Future.get(), carrier ◄──────────────┘ + freed for other work + │ + ▼ + ◄───────────────── EmbedResult returned + + Other virtual threads V2..Vn keep running on freed carriers throughout. +``` + +Streaming generation follows the same pattern: each native +`llama_decode` call (per token) is submitted to `NativeExecutor`; +the generation coroutine on the SDK side is itself a virtual thread +that loops `submitNative(decode).get()` and pushes deltas to the +`Flow.Subscriber`. + +#### Rules + +1. **No JNI call ever runs on a virtual thread directly.** Static + analysis (a SpotBugs custom rule) is reserved for Phase 1.5 to + enforce this; Phase 1 enforces by code review and unit tests that + assert `Thread.currentThread().isVirtual() == false` from inside + submitted callables. +2. `NativeExecutor` is `AutoCloseable`; its `close()` shuts the pool + and is idempotent. +3. Every `Embedder` and every `Generator` instance owns its own + `NativeExecutor`. There is no SDK-wide singleton — closing one + instance does not affect another. + +### 3.4 Streaming model + +Streaming generation uses `java.util.concurrent.Flow.Publisher` from +the JDK core (no Reactive Streams library dependency). + +Contract: + +- `Generator.stream(GenerateRequest)` returns a + `Flow.Publisher` immediately. The first chunk + arrives when the model emits its first token. +- Chunks contain **incremental token deltas**, not cumulative text. + The caller is responsible for reassembly if needed. +- **Exactly one terminal chunk per stream** (`done == true`). The + terminal chunk's `delta` MAY be empty; it carries the + `finishReason`, `usage`, and final `stats`. After delivering the + terminal chunk the publisher calls `Subscriber.onComplete()`. +- **Mid-stream failure** → `Subscriber.onError(Throwable)`. The + publisher does NOT emit a terminal chunk after `onError`. (The + error path and the success path are mutually exclusive.) +- **Cancellation.** `Subscription.cancel()` instructs the SDK to + stop the native generation **at the next token boundary** (we do + not interrupt llama.cpp mid-decode). One terminal chunk with + `finishReason = Canceled` is emitted, then `onComplete()`. +- **Backpressure.** `Subscription.request(n)` is honored. The SDK + does not produce tokens faster than the subscriber requests them; + llama.cpp `decode` is gated on demand. Internal buffer is + configurable, default **16** chunks. +- **Lazy start.** If the subscriber never calls `request`, native + generation never starts. If the subscriber cancels before the + first request, no native call is made. +- **No leaks.** Cancellation always releases the underlying llama.cpp + context back to the pool. + +Reference: spec §7. + +### 3.5 Lifecycle + +`Embedder` and `Generator` both `extend AutoCloseable`. Lifecycle +invariants: + +- **Construction** does not load native libs; `NativeLibLoader` + defers extraction until first use to avoid paying the cost on + short-lived JVMs that never inference. +- **First call** triggers: extraction of the per-platform native + lib to `${TMPDIR}/inference-sdk-${pid}-${uuid}/`, SHA-256 + verification against the sibling `.sha256` resource, + model file resolution (Builder path → env var → classpath), SHA-256 + verification of the model file against + `model-manifest.properties`, and creation of the underlying ONNX + Runtime session / llama.cpp context. +- **`close()` is idempotent.** Multiple calls are safe; second and + subsequent calls are no-ops. Implementations use a + `volatile boolean closed` flag plus a CAS to guarantee at-most-once + native shutdown. +- **No double-free.** Native handles are released exactly once. + After close, all public methods throw `IllegalStateException`. +- **Native shutdown hook.** A JVM shutdown hook in + `NativeLibLoader` deletes extracted native libs on graceful exit + with a defensive try/catch (cleanup failure logs a warning; + it does not crash the JVM). +- **No global state.** Two independent `Embedder` instances do not + share an ONNX session, a thread pool, or a temp directory entry. + This is what makes the library safe to embed inside a + container-per-tenant or sidecar-per-tenant deployment. + +### 3.6 Container-aware threading + +`io.github.randomcodespace.inference.runtime.ContainerCpu` returns +the effective CPU count for thread-pool sizing: + +```java +public final class ContainerCpu { + /** cgroups v2 cpu.max; falls back to availableProcessors(). */ + public static int detect(); +} +``` + +Algorithm: + +1. Try to read `/sys/fs/cgroup/cpu.max`. The file format is ` + ` (both integers, microseconds). Result is + `Math.max(1, (int) Math.ceil((double) quota / period))`. +2. If the file is absent, unreadable, or has the form `max ` + (no quota), fall back to `Runtime.getRuntime() + .availableProcessors()`. +3. Path is package-private overridable for tests. + +This matters for any container runtime that limits CPU via cgroups +v2 (Kubernetes, Docker with `--cpus`, systemd slices, OpenShift). On +a 32-core host where the container is allocated 2 CPUs, +`Runtime.availableProcessors()` returns 32; this would oversize +`NativeExecutor` and pile up llama.cpp threads on overcommitted +cores. `ContainerCpu.detect()` returns 2, which is the right answer. + +cgroups v1 is **not** supported in Phase 1 — the user's deployment +environments (UBI8+, modern Kubernetes) all use v2. v1 paths fall +through to `availableProcessors()`. + +--- + +## 4. Threat model + +### 4.1 In scope + +- Trusted callers integrating the SDK as a library. +- Bundled default models (SHA-256 pinned, verified at load time). +- User-supplied models via `Builder.modelPath(Path)` or the + `INFERENCE_MODEL_DIR` env var, with SHA-256 verification against + a `model-manifest.properties` shipped alongside the file. +- Adversarial prompt content from end users, routed through + `Generator.complete()` / `Embedder.embed()`. + +### 4.2 Out of scope + +- **Adversary-controlled model files.** The library expects callers + to verify provenance of any non-bundled `modelPath`. Our SHA-256 + verification is **integrity-not-authenticity**: we verify the file + matches the manifest the user provided alongside it. We do not + sign manifests, and we do not maintain a cross-organisation chain + of trust. +- **DoS protection for unbounded repeated invocation.** Callers are + expected to rate-limit upstream of the SDK. The SDK provides + bounded queues and bounded streaming buffers, but does not throttle. +- **Multi-tenant isolation.** The SDK is a library. Runtime + isolation (JVM per tenant, container per tenant, OS-level + sandboxing) is the deployment's responsibility. + +### 4.3 Mitigations baked into Phase 1 + +| Mitigation | Source | +|---|---| +| SHA-256 verification of every loaded model file against `model-manifest.properties` | Spec §3, §8 | +| SHA-256 verification of every native lib extracted from JAR | Spec §6.2 (`NativeLibLoader`) | +| Input validation on `Message` (role allow-list, non-null role/content, reserved fields rejected) | Spec §6.4 | +| Input validation on `GenerateRequest` (token bounds, temperature/top-p ranges, message structure) | Spec §6.4 | +| Bounded queue depth (`QueueFullException` rather than unbounded growth) | Spec §6.4 | +| Bounded streaming buffer (configurable, default 16) | Spec §7 | +| Cancellation propagation through native code at next token boundary | Spec §7 | +| Zero global state; all instances independent | Spec §6.5 | +| Zero runtime network calls (verified by network-isolation test #47, which installs a `java.net.spi.InetAddressResolverProvider` that throws on every name) | Spec §6.5 | +| OWASP `dependency-check` in CI, fail on CVSS ≥ 7 | Spec §11.4 | +| Air-gapped build path: vendored deps, models fetched-at-build-time + embedded in published Maven JAR (no Git LFS, no runtime fetch), no public CDN at runtime | `rules/build.md` | + +### 4.4 Phase 1 residual security risk + +Phase 1 consumes the published `de.kherud:llama:4.2.0` whose bundled +llama.cpp `b4916` has 5 reachable High GHSA advisories + 1 Moderate. +Each is neutralized at the SDK boundary; full register: + +| ID | Sev | Path | Mitigation | +|---|---|---|---| +| GHSA-8wwf-... | High | `token_to_piece` overflow on vocab load | SHA-256 model allow-list — only pinned vocabs ever load | +| GHSA-7rxv-... | High | tokenizer signed/unsigned overflow on prompt | Input length cap (configurable, default tied to `contextSize`) + UTF-8 validation in API records | +| GHSA-vgg9-... | High | GGUF size accumulator overflow | SHA-256 model allow-list — pinned GGUF files only | +| GHSA-96jg-... | High | `ggml_nbytes` overflow → potential RCE on model load | SHA-256 model allow-list — pinned GGUF files only | +| GHSA-3p4r-... | High | `mem_size` overflow bypass | SHA-256 model allow-list — pinned GGUF files only | +| GHSA-g4cc-... | Moderate | DoS over-read | SHA-256 model allow-list — pinned GGUF files only | + +4 of 5 Highs + the Moderate are GGUF model-load-path bugs, fully +neutralized by the spec's SHA-256 allow-list +(`scripts/checksums/models.sha256` + +`NativeLibLoader.extractAndVerify`). Practical exploitation would +require substituting a file AND matching its pinned hash, which is +infeasible. The remaining tokenizer High (`7rxv`) is narrowed to a +non-exploitable surface by the input-validation layer. + +**Sign-off.** Path 1 (consume `de.kherud:llama:4.2.0` + the +mitigations above) was acknowledged in writing by the project owner +on 2026-05-09 per `~/.claude/rules/security.md`'s mitigation clause. +Phase 1.5 will fork-and-bump to clear these advisories upstream and +add Gemma 4 architecture support — see Roadmap below. + +--- + +## 5. Deviation register + +These deviations from the original spec ([`/java-sdk.md`](../java-sdk.md)) +are **locked decisions**. Each was reviewed against the spec, +validated against research artifacts in `.research/`, and accepted +with an explicit mitigation. Reopening any of them requires +documented evidence that the original constraint changed. + +**Total deviations: 5** (`D-001` through `D-005`). + +| # | Deviation | Reason | Mitigation | +|---|---|---|---| +| **D-001** | Generation default switched from `google/gemma-3-270m-it` to `Qwen/Qwen2.5-0.5B-Instruct`. | Gemma 3 is licensed under Gemma Terms (gated, non-Apache); the spec was wrong about the license. Gemma 4 E2B is Apache 2.0 + ungated but ~10× too large to fit the spec's "smallest viable for fast CI/test" intent. | Qwen 2.5-0.5B is the closest Apache-2.0 ungated model within ~2× the original size budget. Gemma 4 E2B/E4B registered in [`MODEL_REGISTRY.md`](./MODEL_REGISTRY.md) as opt-in for users wanting higher quality. | +| **D-002** | Platform matrix expanded: Win-x64 + Linux-x64 (UBI8+) + Linux-arm64 (UBI8+) — **Win-arm64 explicitly out of scope**. | User's actual deployment matrix. UBI8 sets the glibc 2.28 floor — confirmed via inspection of kherud's natives (`GLIBC_2.17` for x64, `GLIBC_2.27` for arm64). | Linux-aarch64 for non-UBI is naturally covered by the glibc 2.27 build. Win-arm64 documented as future work in §6 below; kherud's MSVC build is broken upstream so this is forced. | +| **D-003** | llama.cpp Java binding maintenance criterion (≤6 mo) relaxed: Phase 1 consumes published `de.kherud:llama:4.2.0` directly with documented residual risk; fork-and-bump moved to Phase 1.5. | No published binding meets all 6 spec criteria simultaneously: `de.kherud:llama:4.2.0` is the only one shipping all three required native classifiers (Win-x64, Linux-x64 glibc 2.17, Linux-arm64 glibc 2.27) under MIT; `io.gravitee.llama.cpp:llamaj.cpp:1.1.1` ships only `linux/x86_64` with glibc 2.34 (fails UBI8); `org.bytedeco:llama*` does not exist; `ai.djl.llama` was removed from DJL master. | Spec criterion 2 ("CVSS ≥ 7 = block unless documented mitigation + sign-off") is satisfied by: (a) SHA-256 model allow-list neutralizing 4 of 5 reachable Highs + 1 Moderate at `b4916` (all GGUF model-load-path bugs); (b) prompt-input length cap + UTF-8 validation narrowing the remaining tokenizer High; (c) project owner's written sign-off on 2026-05-09. Phase 1.5 will fork `kherud/java-llama.cpp` and bump bundled llama.cpp to ≥`b8146` to clear the underlying advisories upstream and add Gemma 4 architecture support. See [SECURITY.md](../SECURITY.md) §Residual security risk and §4.4 above. | +| **D-004** | Spec corrections. | `java-sdk.md` had typos in plugin coordinates and JaCoCo version. | Spotless GAV is `com.diffplug.spotless` (not `com.diffblue`). JaCoCo `0.8.14` is the JDK 25 floor (older versions crash on JDK 25 class files). | +| **D-005** | jlama evaluated and rejected. | User asked about jlama as a pure-Java alternative. `jlama 0.8.4` is Apache 2.0 + supports Qwen 2.5 + eliminates the native-libs subsystem entirely, BUT (a) `jinjava 2.7.2` transitive dep has 2 reachable Critical CVEs (chat-template renderer → JVM RCE), (b) does NOT yet support Gemma 4 architecture — user wants Gemma 4 E4B eventually, (c) pure-Java perf on 4B-class model would be 5–15 tok/s vs ~25–60 tok/s native. | Native llama.cpp (kherud fork) is the correct foundation for the Gemma-4-E4B-eventually requirement. jlama not adopted. | + +--- + +## 6. Roadmap + +### Phase 1.5 — incremental adds, no breaking API changes + +- **Fork `kherud/java-llama.cpp` and bump bundled llama.cpp** to a + current tag (≥`b8146`). Clears all 5 reachable High CVEs in `b4916` + and adds Gemma 4 architecture support. **First task in Phase 1.5, + before any new feature work.** Native CI rebuilds against + `dockcross-manylinux2014-x64` + `dockcross-linux-arm64-lts` + + `windows-2019` (the toolchains kherud already validates against). + Output artifact: `io.github.randomcodespace.inference:kherud-fork-llama`. + See §4.4 and [SECURITY.md](../SECURITY.md) for the advisories this + remediates. +- **Win-arm64 native classifier** when MSVC build issue is resolved + upstream in kherud. +- **macOS-x64 + macOS-arm64 native classifiers** for developer + laptops. ONNX Runtime already ships these in its CPU JAR; the + llama.cpp side requires a `dockcross` mac toolchain or a GitHub + Actions `macos-13` / `macos-14` runner addition. +- **CUDA classifier** for the generation path. Build via + `dockcross-cuda` or a separate llama.cpp CUDA matrix entry. Caller + opts in by depending on `inference-sdk-generate-cuda` instead of + the CPU module. Embedding stays CPU-only (BGE-small does not + benefit from GPU at this size). +- **Gemma 4 E2B opt-in.** Users who want higher generation quality + and accept the ~10× artifact size can swap + `inference-sdk-generate-qwen-0_5b` for + `inference-sdk-generate-gemma-4-e2b`. Already registered in + `MODEL_REGISTRY.md`. + +### Phase 2 — HTTP layer (OpenAI-compatible) + +- `POST /v1/embeddings`, `POST /v1/chat/completions` (with SSE for + streaming), `POST /v1/completions`, `GET /v1/models`, + `GET /v1/stats`. +- Bearer-token auth, OpenAI-shaped error envelopes. +- `x_stats` extension field carrying our `EmbedStats` / + `GenerateStats` (this is why the records carry their fields today). +- Tool calling support (the reserved `tools`, `tool_choice`, + `tool_calls`, `tool_call_id`, `name`, `system_fingerprint` fields + are already in the wire format for Phase 2). +- Prometheus metrics, OpenTelemetry tracing. +- Framework selection: evaluate Quarkus first (smallest startup, + native virtual-thread support, best SSE story); Micronaut second; + Spring Boot third. **Do not pre-commit in Phase 1.** + +### Future + +- **Bundle Gemma 4 E4B** as the higher-quality default once the + size/licensing trade-off is acceptable to users (per user + direction). +- **Go port** under `go/` matching the wire format and model registry + exactly. Will use `//go:embed` for the same default model + artifacts and pass an equivalent test suite to the Java one. diff --git a/docs/GLOSSARY.md b/docs/GLOSSARY.md new file mode 100644 index 0000000..ea30061 --- /dev/null +++ b/docs/GLOSSARY.md @@ -0,0 +1,303 @@ +# Glossary + +> Terms used across the `inference-sdk` codebase, docs, and APIs. +> Cross-language: every entry applies equally to the Java reference +> implementation and the planned Go port unless explicitly noted. + +--- + +### Apache 2.0 + +Permissive open-source license. Allows commercial use, modification, +and redistribution with attribution. The license under which this +SDK and its primary dependencies (ONNX Runtime, Qwen 2.5-0.5B, +Gemma 4) are distributed. SPDX identifier: `Apache-2.0`. + +### Backpressure + +The reactive-streams mechanism by which a slow consumer signals an +upstream producer to slow down. In this SDK, `Subscription.request(n)` +on a `Flow.Subscriber` controls how many `GenerateChunk`s the +publisher is allowed to deliver. Default internal buffer is 16 +chunks; configurable. + +### BGE + +The BAAI General Embedding family (BAAI = Beijing Academy of +Artificial Intelligence). Open-source embedding models published as +both PyTorch and pre-exported ONNX. BGE tokenizer is a BERT +WordPiece variant. Phase 1 default: `bge-small-en-v1.5` (384-dim). + +### CVE + +Common Vulnerabilities and Exposures — the public ID format for +publicly-disclosed security flaws (e.g. `CVE-2025-12345`). Maintained +by MITRE and consumed by every dependency-audit tool. Distinct from +GHSA (which is GitHub's parallel ID space and often the first to +publish). + +### CVSS + +Common Vulnerability Scoring System. Severity score 0.0–10.0 attached +to CVEs/GHSAs (e.g., 9.8 = Critical, 7.0–8.9 = High). CI gates this +SDK at **CVSS ≥ 7 fails the build** (`dependency-check` plugin). + +### Dockcross + +Set of pre-built Docker images that cross-compile native code from +Linux x86_64 to many target triples (manylinux2014-x64, +linux-arm64-lts, etc.) with pinned glibc and toolchain versions. Used +upstream by `de.kherud:llama` to produce the linux-x64 (`GLIBC_2.17`) +and linux-arm64 (`GLIBC_2.27`) native libs we consume in Phase 1. +Phase 1.5 context: when this project forks `kherud/java-llama.cpp` +to bump bundled llama.cpp, our native CI will rebuild against the +same dockcross toolchains. + +### EOS + +End-of-sequence token. The special token a generation model emits +when it considers its response complete. Produces +`FinishReason.Eos` (wire form: `"eos"`), which maps to OpenAI's +`"stop"` in Phase 2. + +### FFM + +Foreign Function & Memory API (JEP 454, finalized in JDK 22). +Modern replacement for JNI for calling native code from Java. The +SDK does NOT use FFM in Phase 1 — the kherud fork is built on JNI +and we inherit that. FFM migration is a possible Phase 2+ +optimization. + +### FFI + +Foreign Function Interface. Generic term for any +language-to-language call mechanism (JNI, FFM, ctypes, cgo, etc.). + +### Finish reason + +Why a generation stopped. One of `Stop` (user stop sequence +matched), `Length` (`max_tokens` reached), `Eos` (model EOS token), +`Canceled` (caller cancelled), or `Error(message)` (native runtime +error). See `WIRE_FORMAT.md` §2.5 for cross-language encoding. + +### Gemma Terms + +Custom license used by Google for Gemma 3 and earlier Gemma model +families (Gemma 4 onward is Apache-2.0). Permits commercial use but +imposes additional Prohibited Use Policy constraints and is +license-click-through gated on HuggingFace. **Not redistributable +by this SDK**, hence Phase 1's switch from Gemma 3 to Qwen 2.5-0.5B +as the bundled default. + +### GGUF + +GPT-Generated Unified Format. The model file format used by +llama.cpp. Supports many quantization schemes (q4_0, q4_K_M, q5_K_M, +q8_0, etc.). Phase 1 ships Qwen 2.5-0.5B as `q4_K_M` GGUF (~330– +380 MB). + +### GHSA + +GitHub Security Advisory ID format (e.g., `GHSA-p5mv-gjc5-mwqv`). +GitHub's vulnerability disclosure system; advisories are typically +issued before MITRE publishes a CVE. + +### glibc + +GNU C Library, the standard C library on most Linux distros. +Versioned symbols (e.g. `GLIBC_2.17`, `GLIBC_2.28`) determine binary +compatibility — a binary built against `GLIBC_2.34` will not run on +a host with only `GLIBC_2.28`. The SDK targets `GLIBC_2.17` (x64) +and `GLIBC_2.27` (arm64) to support UBI8 (RHEL 8) deployments. + +### Int8 dynamic quantization + +Post-training quantization technique that converts a float32 model's +weights and activations to 8-bit integers at inference time. Reduces +memory ~4× with minimal accuracy loss. Applied via +`onnxruntime.quantization.quantize_dynamic`. Phase 1 ships the BGE +model with this scheme. + +### JaCoCo + +Java code-coverage tool. Measures line/branch coverage during test +runs and produces HTML + XML reports. The SDK enforces ≥ 75% line / +≥ 70% branch coverage on `core`, `embed`, and `generate` modules. +Pinned to `0.8.14` (the JDK 25 floor; older versions crash on JDK 25 +class files). + +### JNI + +Java Native Interface. The legacy mechanism for calling native code +from Java; pre-dates FFM. Used by both ONNX Runtime's Java bindings +and the kherud llama.cpp fork. JNI calls **pin the carrier thread** +of any virtual thread that makes them — see `ARCHITECTURE.md` §3.3 +for how the SDK works around this. + +### jqwik + +JUnit 5 property-based testing framework. The SDK uses jqwik for +property tests on tokenizers, vector arithmetic, and round-trip +serialization. Property runs are capped at 30 seconds each in CI. + +### kherud + +Short for `kherud/java-llama.cpp`, the Java binding to llama.cpp. +Phase 1 consumes the published `de.kherud:llama:4.2.0` artifact +directly from Maven Central. Selected because it is the only +published Java binding shipping Win-x64 + Linux-x64 (glibc 2.17) + +Linux-arm64 (glibc 2.27) classifiers from a single artifact under +MIT. The bundled llama.cpp is `b4916`; residual advisories are +mitigated at the SDK boundary per `SECURITY.md`. Phase 1.5 will fork +`kherud/java-llama.cpp` and bump the bundled llama.cpp to clear +those advisories upstream. + +### KV cache + +Key/value cache used by transformer-decoder generation. Stores the +attention K and V tensors for already-generated tokens so the next +token only requires a forward pass over the new token, not the full +context. Consumes ~`2 × n_layers × n_heads × head_dim × seq_len × +sizeof(dtype)` bytes; this dominates RAM at long context lengths. +Managed by llama.cpp; the SDK exposes it via the `context_used` / +`context_max` fields on `GenerateStats`. + +### llama.cpp + +C++ runtime for running quantized large language models on CPU and +GPU. The native engine behind the SDK's generation path, accessed +via the kherud fork. Versions are tagged `b` where N is a +monotonic build number; the fork pins a specific tag for +reproducibility. + +### manylinux2014 + +PEP 599 standard defining a Linux build environment based on CentOS +7 with `glibc_2.17`. Native libs built in a `manylinux2014` Docker +image are runnable on essentially every glibc Linux distro from +~2014 onward, including UBI8. The SDK's linux-x64 native lib targets +this baseline. + +### MIT + +The MIT License (also called the Expat License). Maximally permissive +open-source license; allows commercial use with only an attribution +notice. SPDX identifier: `MIT`. The license used by BGE models. + +### ONNX + +Open Neural Network Exchange — a portable format for representing +neural networks. The SDK's embedding path uses ONNX models loaded +into ONNX Runtime. + +### OWASP dependency-check + +Maven plugin that scans project dependencies against the National +Vulnerability Database (NVD) and reports CVEs/GHSAs by CVSS score. +Pinned to `12.2.2`; gated at CVSS ≥ 7 in CI. + +### Per-Layer Embeddings (PLE) + +Architectural feature in Gemma 4 that enables much smaller "active +parameter" footprints relative to total parameter count (Gemma 4 E2B +has 5.1B total but ~2.3B active). Reduces inference RAM and +throughput cost without proportionally reducing model quality. + +### q4_0 + +A GGUF quantization scheme. 4-bit weights, simpler block layout +than `q4_K_M`. Slightly faster but lower fidelity. Used by Gemma 3 +distributions. + +### q4_K_M + +A GGUF quantization scheme. 4-bit weights with K-quants ("K" +referring to a more sophisticated block structure with per-group +scaling factors). The "M" variant balances size and quality. +Recommended quantization for Qwen 2.5-0.5B and Gemma 4 in this SDK. + +### Qwen + +Open-source large language model family from Alibaba's DAMO Academy. +Phase 1 ships `Qwen 2.5-0.5B-Instruct` (Apache-2.0, 0.49B parameters, +32K context). + +### ScopedValue + +JDK 21+ API for binding a value to the dynamic scope of a callable +(`ScopedValue.where(KEY, value).call(body)`). Used by the SDK's +`RequestId.CURRENT` to propagate a request ID into nested virtual +threads and structured task scopes without a `ThreadLocal`. Cheaper +and safer than `ThreadLocal` for virtual-thread-heavy workloads. + +### Spotless + +Maven plugin enforcing code formatting (Google Java Format style is +this SDK's choice). Run as `./mvnw spotless:apply` to fix; CI runs +`spotless:check` and fails on any deviation. Coordinate is +`com.diffplug.spotless` — note: NOT `com.diffblue` (a known +typo in the original spec). + +### SpotBugs + +Maven plugin running static analysis to find common Java bug +patterns (null dereferences, resource leaks, equals/hashCode bugs, +etc.). Pinned to `4.9.8.3`; CI fails on any reported issue at the +default rank. + +### Streaming chunk + +A `GenerateChunk` emitted by `Generator.stream()`. Carries an +incremental `delta` (token text since the previous chunk), `done` +flag, and on the terminal chunk also `finishReason`, `usage`, and +final `stats`. See `WIRE_FORMAT.md` §2.8. + +### Structured concurrency + +JDK 25 API (`java.util.concurrent.StructuredTaskScope`) for +spawning a known set of subtasks, awaiting them, and propagating +errors / cancellation as a single unit. Used in the SDK for batch +embedding fan-out where each text in the batch is a subtask. + +### Temperature + +Scalar parameter (`0.0–2.0`) controlling randomness in token +sampling. `0.0` is greedy/deterministic; `1.0` samples from the raw +softmax; values >1 flatten the distribution. Default in this SDK is +`0.7` per OpenAI convention. + +### Terminal chunk + +The exactly-one chunk emitted at the end of a streaming generation +where `done == true`. Carries the final `finishReason`, `usage`, and +`stats`. Its `delta` MAY be empty. After the terminal chunk the +publisher calls `Subscriber.onComplete()`. See `ARCHITECTURE.md` +§3.4. + +### tok/s + +Tokens per second. Throughput metric for generation. Reported in +`GenerateStats.tokensPerSecond`. Phase 1 targets ~25–40 tok/s on a +laptop-class CPU for Qwen 2.5-0.5B `q4_K_M`. + +### Top-p (nucleus sampling) + +Token-sampling strategy: at each step, restrict to the smallest set +of tokens whose cumulative probability exceeds `p`, then renormalize +and sample. Reduces repetition vs pure temperature sampling. +Default in this SDK is `0.95`. Range: `0.0 < top_p <= 1.0`. + +### Vector API + +JDK incubator API (`jdk.incubator.vector`) providing portable SIMD +intrinsics. The SDK does NOT depend on it in Phase 1 — ONNX Runtime +and llama.cpp do their own SIMD natively. Reserved for any future +pure-Java acceleration paths. + +### Virtual threads + +JDK 21 GA feature: lightweight threads scheduled onto a small pool +of platform "carrier" threads. Allow millions of concurrent +park-on-I/O operations without OS-thread cost. Critical for the +SDK's caller side; deliberately NOT used to drive native calls (see +`ARCHITECTURE.md` §3.3 for why). diff --git a/docs/MODEL_REGISTRY.md b/docs/MODEL_REGISTRY.md new file mode 100644 index 0000000..5e6783d --- /dev/null +++ b/docs/MODEL_REGISTRY.md @@ -0,0 +1,232 @@ +# Model Registry + +> Single source of truth for model identifiers, sources, licenses, +> and SHA-256 pins. Every language implementation reads this file; +> mismatches are bugs. + +--- + +## 1. Conventions + +Each entry below provides: + +| Field | Meaning | +|---|---| +| **Canonical ID** | The string consumers pass to `Embedder.builder().model(...)` or `Generator.builder().model(...)`. Stable across versions. | +| **HF coordinates** | Source HuggingFace repo + the file we extract from it. | +| **License** | SPDX identifier where one applies; otherwise the human-readable license name with a note. | +| **Role** | `embed` (returns vectors) or `generate` (returns text). | +| **Tokenizer family** | Determines the tokenizer JAR / runtime initialization. Cross-language: the Go port uses the same family identifier. | +| **Dimensions** | Vector dimensionality (embed only); `—` for generate. | +| **Max context** | Model's hard maximum context length in tokens. The SDK caps per-request at `INFERENCE_GEN_CONTEXT_SIZE` (default `2048`) regardless of model max. | +| **Quantization** | The artifact format we ship. `int8` for ONNX dynamic-quantized; `q4_K_M` / `q4_0` for GGUF. | +| **Expected size** | Disk size of the bundled artifact, post-quantization. | +| **SHA-256** | Pinned hash of the exact file we ship. Verified at load time by `NativeLibLoader` against `model-manifest.properties`. | + +> **The SDK NEVER auto-downloads models.** Unknown model name + no +> `modelPath` produces `ModelNotFoundException` listing what is on +> the classpath. SHA-256 placeholders below are computed in Tier 0.5 +> by `scripts/fetch_models.py` (which is the only thing in the repo +> that ever talks to HuggingFace, and only on the build host). + +--- + +## 2. Bundled-by-default models (Phase 1) + +These ship inside the consuming application via Maven model JARs +(`inference-sdk-embed-bge-small`, `inference-sdk-generate-qwen-0_5b`) +and are committed via Git LFS. + +### 2.1 `bge-small-en-v1.5` + +| Field | Value | +|---|---| +| **Canonical ID** | `bge-small-en-v1.5` | +| **HF coordinates** | [`BAAI/bge-small-en-v1.5`](https://huggingface.co/BAAI/bge-small-en-v1.5) — file `onnx/model.onnx` (133 MB float32), int8 dynamic-quantized via `onnxruntime.quantization.quantize_dynamic` to `bge-small-en-v1.5.int8.onnx` (~35 MB). | +| **License** | MIT. Verified from the HuggingFace model card metadata. | +| **Role** | `embed` | +| **Tokenizer family** | BGE (BERT WordPiece variant; vocab size 30522) | +| **Dimensions** | `384` | +| **Max context** | `512` tokens | +| **Quantization** | `int8` dynamic (ONNX) | +| **Expected size** | ~35 MB | +| **SHA-256** | `` | +| **Bundled module** | `inference-sdk-embed-bge-small` | +| **Resource path** | `models/bge-small-en-v1.5.int8.onnx` | +| **Manifest** | `models/bge-small-en-v1.5.model-manifest.properties` | + +Why this default: the smallest BGE family member; MIT-licensed; ONNX +already pre-exported by BAAI; produces strong semantic embeddings at +33.4M parameters with a peak RAM of ~250 MB at inference. + +### 2.2 `qwen2.5-0.5b-instruct` + +| Field | Value | +|---|---| +| **Canonical ID** | `qwen2.5-0.5b-instruct` | +| **HF coordinates** | [`Qwen/Qwen2.5-0.5B-Instruct`](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) — converted to GGUF via `convert_hf_to_gguf.py` then quantized to `q4_K_M` to produce `qwen2.5-0.5b-instruct.q4_k_m.gguf`. | +| **License** | Apache-2.0. Verified from HuggingFace model card metadata + bundled `LICENSE` file. Public, ungated; anonymous download. (Note: larger Qwen 2.5 sizes use the Qwen license; the 0.5B size is Apache.) | +| **Role** | `generate` | +| **Tokenizer family** | Qwen (Qwen2 BPE; vocab size 151643) | +| **Max context** | `32768` tokens (model max). **SDK default cap: `2048`** per `INFERENCE_GEN_CONTEXT_SIZE`. | +| **Quantization** | `q4_K_M` (GGUF) | +| **Expected size** | ~330–380 MB | +| **SHA-256** | `` | +| **Peak RAM** | ~700 MB at the 2048-token default cap. | +| **Throughput (laptop CPU)** | ~25–40 tok/s. | +| **Bundled module** | `inference-sdk-generate-qwen-0_5b` | +| **Resource path** | `models/qwen2.5-0.5b-instruct.q4_k_m.gguf` | +| **Manifest** | `models/qwen2.5-0.5b-instruct.model-manifest.properties` | + +Why this default (and why it differs from the original spec): the +spec called for `google/gemma-3-270m-it`, which is the smallest +coherent Gemma but is licensed under the **Gemma Terms of Use** (a +custom Google license, not Apache 2.0) and is gated. Gemma 4 E2B is +Apache 2.0 and ungated but ~10× larger than the spec's "smallest +viable for fast CI/test" intent. Qwen 2.5 0.5B is the closest +clean-Apache, ungated, public-download model within a reasonable +size budget; it sits at ~2× the original 270M's footprint, which is +acceptable for Phase 1. See `docs/ARCHITECTURE.md` deviation +**D-001** for the locked decision. + +--- + +## 3. Registered but not bundled + +These are recognised by the registry, can be selected via +`Builder.modelPath(Path)`, and have placeholder Maven artifacts in +the project (or are reserved for future module names). Users supply +the actual model files themselves; SHA-256 verification still +applies if they ship a manifest. + +### 3.1 `bge-base-en-v1.5` + +| Field | Value | +|---|---| +| **Canonical ID** | `bge-base-en-v1.5` | +| **HF coordinates** | [`BAAI/bge-base-en-v1.5`](https://huggingface.co/BAAI/bge-base-en-v1.5) | +| **License** | MIT | +| **Role** | `embed` | +| **Tokenizer family** | BGE | +| **Dimensions** | `768` | +| **Max context** | `512` | +| **Recommended quantization** | `int8` dynamic (ONNX) | +| **Expected size** | ~110 MB int8 | +| **SHA-256** | `` | + +### 3.2 `bge-m3` + +| Field | Value | +|---|---| +| **Canonical ID** | `bge-m3` | +| **HF coordinates** | [`BAAI/bge-m3`](https://huggingface.co/BAAI/bge-m3) | +| **License** | MIT | +| **Role** | `embed` | +| **Tokenizer family** | XLMR (XLM-RoBERTa SentencePiece; vocab size 250002) | +| **Dimensions** | `1024` | +| **Max context** | `8192` | +| **Recommended quantization** | `int8` dynamic (ONNX) | +| **Expected size** | ~600 MB int8 | +| **SHA-256** | `` | + +Multilingual; supports dense + sparse + multi-vector retrieval. Only +the dense vector path is exposed in Phase 1. + +### 3.3 `gemma-3-270m-it` — NOT VIABLE AS DEFAULT + +| Field | Value | +|---|---| +| **Canonical ID** | `gemma-3-270m-it` | +| **HF coordinates** | [`google/gemma-3-270m-it`](https://huggingface.co/google/gemma-3-270m-it) | +| **License** | **Gemma Terms of Use** (custom Google license; **NOT Apache 2.0**) — **gated repo, requires HF login + license click-through**. | +| **Role** | `generate` | +| **Tokenizer family** | Gemma (Gemma SentencePiece; vocab size 256000) | +| **Max context** | `32768` | +| **Recommended quantization** | `q4_0` (GGUF) | +| **Expected size** | ~170 MB q4_0 | +| **SHA-256** | `` | + +> WARNING: Redistributing Gemma model weights is bound by the Gemma +> Terms + Prohibited Use Policy. This SDK **does not redistribute** +> Gemma 3. The entry exists so users who have accepted the terms can +> select the model via `Builder.modelPath(Path)` against their own +> local copy. + +Listed for completeness because the original spec named it; not a +viable default for this project. See deviation **D-001** in +`ARCHITECTURE.md`. + +### 3.4 `gemma-4-e2b-it` + +| Field | Value | +|---|---| +| **Canonical ID** | `gemma-4-e2b-it` | +| **HF coordinates** | [`google/gemma-4-E2B-it`](https://huggingface.co/google/gemma-4-E2B-it) | +| **License** | Apache-2.0, public, **ungated** (Google moved Gemma 4 to Apache; this fixes the Gemma 3 blocker). | +| **Role** | `generate` | +| **Tokenizer family** | Gemma | +| **Max context** | `131072` (128K) | +| **Recommended quantization** | `q4_K_M` (GGUF) | +| **Expected size** | ~3.0–3.4 GB q4_K_M | +| **Peak RAM** | ~3.5–4.5 GB | +| **Throughput (laptop CPU)** | ~8–18 tok/s | +| **SHA-256** | `` | + +Future Phase 1.5 opt-in. Higher quality (~2–3B effective class with +reasoning mode), 128K context, multimodal (text/image/audio — not +exposed in Phase 1). Trade-off: ~10× the bundled-default footprint. + +### 3.5 `gemma-4-e4b-it` + +| Field | Value | +|---|---| +| **Canonical ID** | `gemma-4-e4b-it` | +| **HF coordinates** | `google/gemma-4-E4B-it` | +| **License** | Apache-2.0 | +| **Role** | `generate` | +| **Tokenizer family** | Gemma | +| **Max context** | `131072` | +| **Recommended quantization** | `q4_K_M` (GGUF) | +| **Expected size** | TBD at Tier 0.5 (estimated ~5–6 GB q4_K_M) | +| **SHA-256** | `` | + +Future bundle target per user direction (see `ARCHITECTURE.md` §6 +Roadmap). Larger, higher-quality Gemma 4. Will become the +higher-quality default once size/licensing trade-off is acceptable. + +### 3.6 `qwen2.5-coder-7b` + +| Field | Value | +|---|---| +| **Canonical ID** | `qwen2.5-coder-7b` | +| **HF coordinates** | [`Qwen/Qwen2.5-Coder-7B-Instruct`](https://huggingface.co/Qwen/Qwen2.5-Coder-7B-Instruct) | +| **License** | Apache-2.0 | +| **Role** | `generate` | +| **Tokenizer family** | Qwen | +| **Max context** | `32768` | +| **Recommended quantization** | `q4_K_M` (GGUF) | +| **Expected size** | ~4.5 GB q4_K_M | +| **SHA-256** | `` | + +Code-specialised generator. Registered for users with code-specific +workloads. + +--- + +## 4. Resolution order + +When `Embedder.builder().model("name").build()` runs, the SDK +resolves the model file in this order: + +1. Explicit `modelPath(Path)` if set. +2. The `INFERENCE_MODEL_DIR` env var, if set, as a directory + containing `.` and a sibling + `.model-manifest.properties`. +3. Classpath resource under `models/.` (this is + how the bundled model JARs work). +4. Otherwise: `ModelNotFoundException` listing every classpath entry + under `models/`. + +In all four cases, SHA-256 verification against the resolved +manifest is mandatory before the model is loaded into a native +session. diff --git a/docs/WIRE_FORMAT.md b/docs/WIRE_FORMAT.md new file mode 100644 index 0000000..d03f0a9 --- /dev/null +++ b/docs/WIRE_FORMAT.md @@ -0,0 +1,365 @@ +# Wire Format + +> Locked JSON shapes shared across languages. Phase 1 does NOT +> serialize these over the wire (the SDK is a library, no HTTP), but +> field names and types are fixed today so the Phase 2 HTTP layer and +> the Go port match without re-litigation. + +--- + +## 1. Conventions + +These conventions apply to every type in this document. Implementations +MUST conform. + +| Convention | Rule | +|---|---| +| **Field naming** | `snake_case`. This matches OpenAI's API conventions. Java records use `@JsonProperty("snake_name")` annotations on the canonical camelCase fields, or a Jackson `PropertyNamingStrategies.SnakeCaseStrategy` configured globally. Go uses `json:"snake_name"` struct tags. | +| **Date / time** | ISO-8601 strings, **UTC**. Example: `"2026-05-08T14:30:00Z"`. No local timezones, no offset suffixes other than `Z`. | +| **Numbers** | Plain JSON numbers. `NaN`, `Infinity`, and `-Infinity` are **forbidden** in outputs and rejected on input with HTTP 400 in Phase 2. | +| **Optional fields** | Absent (omitted from the JSON) when unset, NOT `null`. `null` is reserved to mean "explicitly cleared" in any future PATCH-style API. | +| **Enums** | Lowercase strings, e.g. `"stop"`, `"length"`, `"eos"`, `"canceled"`, `"error"`. No numeric enum codes; no SCREAMING_SNAKE. | +| **Booleans** | `true` / `false`. No 0/1 truthy substitution. | +| **Arrays** | Empty list is `[]`, never absent and never `null`. | +| **IDs** | Strings. Request IDs use the prefix `"req_"` (see [`GLOSSARY.md`](./GLOSSARY.md#requestid)). | + +--- + +## 2. Schemas + +Each subsection gives: the canonical Java record signature, an +example JSON document, and any field-level rules. + +### 2.1 `EmbedStats` + +Per-request telemetry returned alongside an `EmbedResult`. + +```java +public record EmbedStats( + @JsonProperty("request_id") String requestId, + @JsonProperty("queue_ms") long queueMs, + @JsonProperty("tokenize_ms") long tokenizeMs, + @JsonProperty("inference_ms") long inferenceMs, + @JsonProperty("total_ms") long totalMs, + @JsonProperty("batch_size") int batchSize, + @JsonProperty("batch_position") String batchPosition, + @JsonProperty("model_revision") String modelRevision, + String node) {} +``` + +```json +{ + "request_id": "req_2c1f8a92-46d2-4b02-9b5f-9a1c7d2c4aef", + "queue_ms": 0, + "tokenize_ms": 2, + "inference_ms": 18, + "total_ms": 21, + "batch_size": 1, + "batch_position": "single", + "model_revision": "bge-small-en-v1.5", + "node": "host-1" +} +``` + +- `batch_position`: `"single"` for a one-shot call, + `"coalesced"` if the request was merged with siblings by the SDK + batcher. +- All `*_ms` fields are non-negative and `total_ms >= queue_ms + + tokenize_ms + inference_ms`. + +### 2.2 `EmbedResult` + +```java +public record EmbedResult( + List vectors, + int tokens, + EmbedStats stats) {} +``` + +```json +{ + "vectors": [[0.0123, -0.0456, 0.0789, "..."]], + "tokens": 7, + "stats": { "...": "EmbedStats" } +} +``` + +- `vectors` is a list of dense float arrays, one per input text. Each + vector's length equals `ModelInfo.dimensions`. +- `tokens` is the total number of tokens consumed across all inputs. + +### 2.3 `Message` + +```java +public record Message( + String role, + String content, + @JsonProperty("tool_calls") List toolCalls, + @JsonProperty("tool_call_id") String toolCallId, + String name) {} +``` + +```json +{ + "role": "user", + "content": "Summarize this in one sentence." +} +``` + +- `role` must be one of `"system" | "user" | "assistant" | "tool"`. +- `content` is non-null in Phase 1. +- `tool_calls`, `tool_call_id`, and `name` are RESERVED for Phase 2. + They are in the wire format today; Phase 1 throws + `FeatureNotSupportedException` if any are non-null. + +### 2.4 `Usage` + +```java +public record Usage( + @JsonProperty("prompt_tokens") int promptTokens, + @JsonProperty("completion_tokens") int completionTokens, + @JsonProperty("total_tokens") int totalTokens) {} +``` + +```json +{ + "prompt_tokens": 12, + "completion_tokens": 87, + "total_tokens": 99 +} +``` + +- All fields non-negative; `total_tokens == prompt_tokens + + completion_tokens` (validated in record canonical constructor). + +### 2.5 `FinishReason` + +Encoded as a lowercase string. Wire form is the string in column 1. + +| Wire | Java | OpenAI mapping | When emitted | +|---|---|---|---| +| `"stop"` | `FinishReason.Stop` | `"stop"` | A user-supplied stop sequence matched. | +| `"length"` | `FinishReason.Length` | `"length"` | `max_tokens` reached. | +| `"eos"` | `FinishReason.Eos` | `"stop"` | Model emitted its EOS token. | +| `"canceled"` | `FinishReason.Canceled` | `"stop"` (with internal flag) | Caller cancelled the stream. | +| `"error"` | `FinishReason.Error(message)` | n/a (HTTP 5xx in Phase 2) | Native runtime error mid-generation. Wire form: `{ "type": "error", "message": "" }`. | + +```json +"stop" +``` + +```json +{ "type": "error", "message": "context window exceeded" } +``` + +The `error` variant is the only one that carries a payload, hence the +object form. All other variants are bare strings. + +### 2.6 `GenerateRequest` + +```java +public record GenerateRequest( + List messages, + @JsonProperty("max_tokens") int maxTokens, + float temperature, + @JsonProperty("top_p") float topP, + List stop, + Long seed, + // Reserved for Phase 2; non-null throws today + List tools, + @JsonProperty("tool_choice") Object toolChoice, + @JsonProperty("response_format") Object responseFormat) {} +``` + +```json +{ + "messages": [ + { "role": "system", "content": "You are concise." }, + { "role": "user", "content": "Hi." } + ], + "max_tokens": 256, + "temperature": 0.7, + "top_p": 0.95, + "stop": ["\n\n"], + "seed": 42 +} +``` + +- `messages` is non-empty and contains at least one `"user"` message. +- `max_tokens > 0`. +- `0.0 <= temperature <= 2.0`. +- `0.0 < top_p <= 1.0`. +- `seed`, when present, is a 64-bit signed integer. +- `tools`, `tool_choice`, `response_format` are RESERVED. Phase 1 + throws `FeatureNotSupportedException` if any are non-null. + +### 2.7 `GenerateResponse` + +```java +public record GenerateResponse( + String text, + @JsonProperty("finish_reason") FinishReason finishReason, + Usage usage, + GenerateStats stats, + @JsonProperty("system_fingerprint") String systemFingerprint) {} +``` + +```json +{ + "text": "Hello!", + "finish_reason": "stop", + "usage": { "prompt_tokens": 9, "completion_tokens": 2, "total_tokens": 11 }, + "stats": { "...": "GenerateStats" } +} +``` + +- `text` is the full assistant response (cumulative, not delta). +- `system_fingerprint` is RESERVED for Phase 2 OpenAI compatibility. + Always `null` (i.e., absent) in Phase 1; Phase 1 setter throws. + +### 2.8 `GenerateChunk` + +Streaming token-by-token output. + +```java +public record GenerateChunk( + String delta, + boolean done, + @JsonProperty("finish_reason") FinishReason finishReason, + Usage usage, + GenerateStats stats) {} +``` + +```json +{ "delta": "Hel", "done": false } +``` + +Terminal chunk: + +```json +{ + "delta": "", + "done": true, + "finish_reason": "stop", + "usage": { "prompt_tokens": 9, "completion_tokens": 2, "total_tokens": 11 }, + "stats": { "...": "GenerateStats" } +} +``` + +- `delta` is the **incremental** token text since the previous chunk + — never cumulative. +- `done == false` chunks have only `delta`. `usage`, `stats`, and + `finish_reason` are absent. +- Exactly one `done == true` chunk per stream. Its `delta` MAY be + empty. + +### 2.9 `GenerateStats` + +```java +public record GenerateStats( + @JsonProperty("request_id") String requestId, + @JsonProperty("queue_ms") long queueMs, + @JsonProperty("prompt_eval_ms") long promptEvalMs, + @JsonProperty("first_token_ms") long firstTokenMs, + @JsonProperty("generation_ms") long generationMs, + @JsonProperty("total_ms") long totalMs, + @JsonProperty("tokens_per_second") double tokensPerSecond, + @JsonProperty("tokens_generated") int tokensGenerated, + @JsonProperty("context_used") int contextUsed, + @JsonProperty("context_max") int contextMax, + @JsonProperty("model_revision") String modelRevision, + String node) {} +``` + +```json +{ + "request_id": "req_8f3a...", + "queue_ms": 0, + "prompt_eval_ms": 14, + "first_token_ms": 16, + "generation_ms": 412, + "total_ms": 426, + "tokens_per_second": 31.5, + "tokens_generated": 13, + "context_used": 22, + "context_max": 2048, + "model_revision": "qwen2.5-0.5b-instruct", + "node": "host-1" +} +``` + +- `tokens_per_second` is a positive double; floored at 0.0 if + generation produced 0 tokens. +- `context_used <= context_max`. + +### 2.10 `ModelInfo` + +```java +public record ModelInfo( + String id, + String revision, + String quantization, + int dimensions, + @JsonProperty("max_tokens") int maxTokens) {} +``` + +```json +{ + "id": "bge-small-en-v1.5", + "revision": "v1.5", + "quantization": "int8", + "dimensions": 384, + "max_tokens": 512 +} +``` + +- `dimensions` is `-1` for generation models (which don't produce + vectors). +- `max_tokens` is the model's hard maximum context length, NOT the + per-request cap (see `INFERENCE_GEN_CONTEXT_SIZE` env var, default + 2048). + +--- + +## 3. Phase 1 vs Phase 2 + +Phase 1 (this release) does not serialize any of these types over a +wire — the SDK is consumed as a Java library; records flow as +in-process objects. The schemas above are nonetheless **locked +today** because: + +1. The Phase 2 HTTP layer will serialize them directly. +2. The Go port must match field-for-field. +3. Field-name renames after Phase 1 ships would be breaking changes + for both downstream consumers and the cross-language contract. + +Practically, Phase 1 enforces the schema by: + +- Annotating every Java record field with `@JsonProperty`. +- Exercising round-trip serialization in `inference-sdk-tests` even + though no production code does it. +- Rejecting `NaN` / `Infinity` floats in record canonical + constructors (so they cannot enter the system from any source). + +--- + +## 4. Reserved fields (forward-compat) + +Fields that exist on the type today but Phase 1 rejects when set. +Setting any of these to a non-null value MUST throw +`FeatureNotSupportedException`. They are present so Phase 2 can +populate them without any wire-format change. + +| Type | Field (wire) | Java property | Phase 2 use | +|---|---|---|---| +| `Message` | `tool_calls` | `toolCalls` | Tool/function calling response. | +| `Message` | `tool_call_id` | `toolCallId` | Identifier referencing a `tool_calls[i].id`. | +| `Message` | `name` | `name` | Tool name (for `role == "tool"` messages). | +| `GenerateRequest` | `tools` | `tools` | List of available tools/functions for the model to call. | +| `GenerateRequest` | `tool_choice` | `toolChoice` | OpenAI-shaped tool selection control. | +| `GenerateRequest` | `response_format` | `responseFormat` | OpenAI-shaped JSON-mode / schema-mode control. | +| `GenerateResponse` | `system_fingerprint` | `systemFingerprint` | OpenAI-compatibility opaque fingerprint of the deployment. | + +The reserved fields are also called out in the JavaDoc of each +record so callers see the boundary at the API surface, not just at +runtime. diff --git a/go/.gitkeep b/go/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/go/README.md b/go/README.md new file mode 100644 index 0000000..067c002 --- /dev/null +++ b/go/README.md @@ -0,0 +1,17 @@ +## Go implementation + +Coming soon. Tracked in `docs/ARCHITECTURE.md` § Roadmap. + +The Go port will: + +- Match the wire format defined in `docs/WIRE_FORMAT.md` +- Match the model registry in `docs/MODEL_REGISTRY.md` +- Provide a public API with semantics equivalent to the Java SDK + (same record fields, same finish-reason values, same stats shape) +- Use `//go:embed` for the same default model artifacts as the Java + bundle JAR +- Pass an equivalent test suite — the same numbered edge cases and + the same offline-network guarantee + +Until then, this directory exists only to reserve the path. See the +Java implementation under `java/`. diff --git a/java/examples/quickstart/README.md b/java/examples/quickstart/README.md new file mode 100644 index 0000000..48af3fd --- /dev/null +++ b/java/examples/quickstart/README.md @@ -0,0 +1,32 @@ +# inference-sdk — Quickstart Example + +Status: scaffold. Implementation lands in **Tier 5** of the Phase 1 plan. + +## What this will demonstrate + +A single runnable `Main.java` (~30 lines) that imports the +`inference-sdk-bundle` fat JAR and exercises the full Phase 1 surface: + +- Synchronous embedding of a small batch via `Embedder` +- Synchronous chat completion via `Generator` +- Streaming generation via `Flow.Publisher` with + backpressure +- Virtual-thread orchestration with `RequestId.withRequestId(...)` + scoped value propagation +- Structured concurrency (`StructuredTaskScope`) for fan-out + +## Run target (when implemented) + +```bash +./mvnw -pl :inference-sdk-quickstart compile exec:java +``` + +Total expected wall time on default tiny models (`bge-small-en-v1.5` +int8 + `qwen2.5-0.5b-instruct` q4_K_M) on a laptop-class CPU: under +5 seconds end-to-end (per spec acceptance criterion #10). + +## Tracking + +- Plan: `.planning/inference-sdk-java-phase1-plan.md` (Tier 5) +- Spec: `java-sdk.md` §13 ("one runnable example under + `java/examples/quickstart/`") and §15 Step 2 item 15 diff --git a/java/examples/quickstart/pom.xml b/java/examples/quickstart/pom.xml new file mode 100644 index 0000000..9b94a6a --- /dev/null +++ b/java/examples/quickstart/pom.xml @@ -0,0 +1,65 @@ + + + + 4.0.0 + + + io.github.randomcodespace.inference + inference-sdk-parent + 0.1.0-SNAPSHOT + ../../inference-sdk-parent/pom.xml + + + inference-sdk-quickstart + jar + + ${project.groupId}:${project.artifactId} + Runnable end-to-end demo for the inference-sdk Java + library. Single Main.java that imports the bundle JAR and + exercises the embed + generate + streaming + virtual-thread + + RequestId scoped-value APIs. + + + true + true + true + true + true + + + + + + io.github.randomcodespace.inference + inference-sdk-bundle + ${project.version} + + + + + + + org.codehaus.mojo + exec-maven-plugin + + io.github.randomcodespace.examples.quickstart.Main + + + + + diff --git a/java/examples/quickstart/src/main/java/io/github/randomcodespace/examples/quickstart/Main.java b/java/examples/quickstart/src/main/java/io/github/randomcodespace/examples/quickstart/Main.java new file mode 100644 index 0000000..d36f126 --- /dev/null +++ b/java/examples/quickstart/src/main/java/io/github/randomcodespace/examples/quickstart/Main.java @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + * + * Requires `make fetch-models` to populate model artifacts before running. + * + * Demonstrates the Phase-1 inference-sdk public surface in a single ~30-line main(): + * - Embedder.builder() / embedOne() — synchronous embedding + * - Generator.builder() / complete() — synchronous generation + * - Generator.stream() + Flow.Subscriber — streaming with backpressure + * - RequestId.withRequestId() — ScopedValue propagation + */ +package io.github.randomcodespace.examples.quickstart; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Flow; + +import io.github.randomcodespace.inference.embed.Embedder; +import io.github.randomcodespace.inference.generate.GenerateChunk; +import io.github.randomcodespace.inference.generate.GenerateRequest; +import io.github.randomcodespace.inference.generate.Generator; +import io.github.randomcodespace.inference.generate.Message; +import io.github.randomcodespace.inference.runtime.RequestId; + +/** End-to-end inference-sdk quickstart. */ +public final class Main { + + private Main() {} + + /** + * Run a single embed + sync generate + streaming generate, all under one {@link RequestId}. + * + * @param args ignored + * @throws Exception if any step fails + */ + public static void main(String[] args) throws Exception { + RequestId.withRequestId( + RequestId.generate(), + () -> { + try (Embedder e = Embedder.builder().model("bge-small-en-v1.5").build(); + Generator g = Generator.builder().model("qwen2.5-0.5b-instruct").build()) { + float[] vec = e.embedOne("hello"); + System.out.println("embed[0..5]=" + Arrays.toString(Arrays.copyOf(vec, 5))); + + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "Say hi.", null, null, null))) + .maxTokens(16) + .temperature(0f) + .seed(42L) + .build(); + System.out.println("complete=" + g.complete(req).text()); + + CountDownLatch done = new CountDownLatch(1); + g.stream(req) + .subscribe( + new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(GenerateChunk c) { + System.out.print(c.delta()); + if (c.done()) System.out.println(" [stats=" + c.stats() + "]"); + } + + @Override + public void onError(Throwable t) { + t.printStackTrace(); + done.countDown(); + } + + @Override + public void onComplete() { + done.countDown(); + } + }); + done.await(); + } + return null; + }); + } +} diff --git a/java/inference-sdk-bundle/README.md b/java/inference-sdk-bundle/README.md new file mode 100644 index 0000000..6320236 --- /dev/null +++ b/java/inference-sdk-bundle/README.md @@ -0,0 +1,125 @@ +# inference-sdk-bundle + +Convenience fat-JAR aggregating every `inference-sdk` Java module plus its +transitive runtime native libraries and the default bundled model artifacts. + +## What's inside + +| Component | Source | Notes | +| --- | --- | --- | +| Core API + records | `inference-sdk-core` | `ModelInfo`, `Usage`, `RequestId`, runtime helpers | +| Embedding API + ONNX backend | `inference-sdk-embed` | DJL HuggingFace tokenizers + ONNX Runtime | +| Generation API + llama.cpp backend | `inference-sdk-generate` | `de.kherud:llama:4.2.0` (bundled llama.cpp b4916) | +| ONNX Runtime natives | `com.microsoft.onnxruntime:onnxruntime:1.25.1` | manylinux2014-x64, linux-aarch64, osx-aarch64, win-x64 | +| llama.cpp natives | `de.kherud:llama:4.2.0` | linux-x64 (manylinux2014/glibc 2.17), linux-aarch64 (dockcross/glibc 2.27), win-x64 | +| Default generative model | `inference-sdk-generate-qwen-0_5b` | `qwen2.5-0.5b-instruct.q4_K_M.gguf` (Apache-2.0); weights populated by `scripts/fetch_models.py` | +| Default embedding model | (pending Tier 3) | `inference-sdk-embed-bge-small` will be added once its POM scaffolds | + +## When to use this + +- Single-JAR demos, tools, internal services, integration tests. +- Anyone who wants `java -cp inference-sdk-bundle-.jar ...` to "just + work" with embeddings and generation, without hand-resolving five Maven + artifacts and three sets of native libs. + +## When NOT to use this + +Use the per-module artifacts instead if you need any of: + +- **Fine-grained classpath control** — e.g. embeddings only (skip kherud + + llama natives), or swapping in a forked llama.cpp build. +- **JPMS / modulepath strict mode** — this bundle ships as an unnamed + classpath JAR; module-info descriptors are stripped during shade. Depend + on `inference-sdk-core`, `-embed`, and `-generate` directly to get + well-formed JPMS modules with proper `requires`/`exports` clauses. +- **Smaller deployment artifacts** — the fat JAR carries every supported + OS/arch's native libraries unconditionally. If you control the deploy + target you can ship a much smaller artifact by depending on only the + modules you need. + +Per-module coordinates: + +```xml + + io.github.randomcodespace.inference + inference-sdk-core + 0.1.0 + + + io.github.randomcodespace.inference + inference-sdk-embed + 0.1.0 + + + io.github.randomcodespace.inference + inference-sdk-generate + 0.1.0 + +``` + +## Quick start + +```bash +java \ + --enable-preview \ + -cp inference-sdk-bundle-0.1.0.jar:your-app.jar \ + com.example.YourApp +``` + +```java +// Embeddings (default model: bge-small-en-v1.5 once the bge-small module ships) +try (var embedder = Embedder.builder() + .model("bge-small-en-v1.5") + .build()) { + float[] vec = embedder.embed("hello world").vector(); +} + +// Generation (default model: qwen2.5-0.5b-instruct.q4_K_M) +try (var generator = Generator.builder() + .model("qwen2.5-0.5b-instruct") + .build()) { + var resp = generator.generate(GenerateRequest.of("Why is the sky blue?")); + System.out.println(resp.text()); +} +``` + +## Module-info tradeoff + +This bundle is built with `maven-shade-plugin` and ships as an **unnamed +classpath-only artifact** — `module-info.class` entries from every shaded +dependency are dropped, and this module ships no module-info of its own. + +Reasons: + +1. `maven-shade-plugin`'s JPMS support is still experimental as of 3.6.2. + Merging multiple module descriptors into a single bundle module-info + either drops `exports`/`uses` clauses or produces an inconsistent + descriptor across rebuilds, breaking the reproducible-build guarantees + carried from `inference-sdk-parent`. +2. JPMS-strict consumers should be using per-module artifacts anyway + (see "When NOT to use this" above) — those carry well-formed + `module-info.java` files. +3. Classpath consumers (~99% of fat-JAR users) do not read module-info, + so the unnamed-bundle path is functionally equivalent for them. + +If you have a use case that requires a JPMS-named bundle module, please +open an issue — we will revisit when shade plugin's JPMS story stabilises. + +## Security + +The bundle inherits the residual `GHSA-7rxv` tokenizer surface from upstream +`de.kherud:llama:4.2.0` (bundled llama.cpp b4916). This is mitigated at the +API boundary by `InputValidator` in `inference-sdk-generate`. See +[`SECURITY.md`](../../SECURITY.md) for full details and the Phase 1.5 +fork-and-bump plan. + +## Reproducible builds + +Built deterministically with `project.build.outputTimestamp` pinned in +`inference-sdk-parent`. The shaded JAR carries `Reproducible-Build: true` +in its `MANIFEST.MF`. Two builds from the same source produce +byte-identical artifacts. + +## License + +Apache License 2.0. See [`LICENSE`](../../LICENSE) at the repo root. diff --git a/java/inference-sdk-bundle/pom.xml b/java/inference-sdk-bundle/pom.xml new file mode 100644 index 0000000..f55af73 --- /dev/null +++ b/java/inference-sdk-bundle/pom.xml @@ -0,0 +1,290 @@ + + + + 4.0.0 + + + io.github.randomcodespace.inference + inference-sdk-parent + 0.1.0-SNAPSHOT + ../inference-sdk-parent/pom.xml + + + inference-sdk-bundle + jar + + inference-sdk-bundle + Fat-JAR convenience artifact aggregating every + inference-sdk Java module (core + embed + generate), all + transitive runtime native libs (ONNX Runtime + kherud + llama.cpp), DJL HuggingFace Tokenizers, and the default + bundled model artifacts. Built via maven-shade-plugin with + a dependency-reduced-pom so consumers do not pull every + transitive dependency a second time. + + + + + io.github.randomcodespace.inference + inference-sdk-core + ${project.version} + + + + + io.github.randomcodespace.inference + inference-sdk-embed + ${project.version} + + + + + io.github.randomcodespace.inference + inference-sdk-generate + ${project.version} + + + + + io.github.randomcodespace.inference + inference-sdk-generate-qwen-0_5b + ${project.version} + + + + + io.github.randomcodespace.inference + inference-sdk-embed-bge-small + ${project.version} + + + + + + inference-sdk-bundle-${project.version} + + + + + org.apache.maven.plugins + maven-shade-plugin + + + shade-bundle + package + + shade + + + true + false + false + false + + + + + true + ${project.artifactId} + ${project.version} + RandomCodeSpace + false + + + + + + *:* + + + module-info.class + META-INF/versions/*/module-info.class + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + META-INF/maven/**/pom.properties + + + + + + + + + + + org.jacoco + jacoco-maven-plugin + + + jacoco-prepare-agent + + prepare-agent + + + true + + + + jacoco-report + + report + + + true + + + + jacoco-check + + check + + + true + + + + + + + diff --git a/java/inference-sdk-core/pom.xml b/java/inference-sdk-core/pom.xml new file mode 100644 index 0000000..dcd5e3f --- /dev/null +++ b/java/inference-sdk-core/pom.xml @@ -0,0 +1,70 @@ + + + + 4.0.0 + + + io.github.randomcodespace.inference + inference-sdk-parent + 0.1.0-SNAPSHOT + ../inference-sdk-parent/pom.xml + + + inference-sdk-core + jar + + inference-sdk-core + Shared types (records, sealed interfaces) and runtime helpers + (ContainerCpu, NativeExecutor, RequestId, NativeLibLoader) used + by every module in the inference-sdk reactor. + + + + + org.slf4j + slf4j-api + + + + + org.junit.jupiter + junit-jupiter + test + + + org.junit.jupiter + junit-jupiter-params + test + + + org.assertj + assertj-core + test + + + nl.jqno.equalsverifier + equalsverifier + test + + + ch.qos.logback + logback-classic + test + + + diff --git a/java/inference-sdk-core/spotbugs-exclude.xml b/java/inference-sdk-core/spotbugs-exclude.xml new file mode 100644 index 0000000..667b0d7 --- /dev/null +++ b/java/inference-sdk-core/spotbugs-exclude.xml @@ -0,0 +1,19 @@ + + + + + diff --git a/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/Failure.java b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/Failure.java new file mode 100644 index 0000000..59ed3bb --- /dev/null +++ b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/Failure.java @@ -0,0 +1,14 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference; + +/** + * Failure arm of a {@link Result}. Carries an error value. + * + * @param error the failure value + * @param success-arm type (unused by this arm) + * @param failure-arm error type + */ +public record Failure(E error) implements Result {} diff --git a/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/FinishReason.java b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/FinishReason.java new file mode 100644 index 0000000..b352bb9 --- /dev/null +++ b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/FinishReason.java @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference; + +/** + * Reason a generation request stopped producing tokens. Sealed sum-type with five variants: + * + *
    + *
  • {@link Stop} — a stop sequence was emitted. + *
  • {@link Length} — caller-imposed token cap reached. + *
  • {@link Eos} — end-of-stream token emitted by the model. + *
  • {@link Canceled} — caller cancelled the request. + *
  • {@link Error} — generation failed with the carried message. + *
+ * + *

Wire format (per {@code docs/WIRE_FORMAT.md}): lowercase string tag — {@code "stop"}, {@code + * "length"}, {@code "eos"}, {@code "canceled"}, {@code "error"}. + */ +public sealed interface FinishReason + permits FinishReason.Stop, + FinishReason.Length, + FinishReason.Eos, + FinishReason.Canceled, + FinishReason.Error { + + /** Generation halted because a stop sequence was produced. */ + record Stop() implements FinishReason {} + + /** Generation halted because the requested token limit was hit. */ + record Length() implements FinishReason {} + + /** Generation halted because the model emitted an end-of-stream token. */ + record Eos() implements FinishReason {} + + /** Generation halted because the caller cancelled the request. */ + record Canceled() implements FinishReason {} + + /** + * Generation halted because of an error. + * + * @param message human-readable error description; never {@code null} + */ + record Error(String message) implements FinishReason { + /** + * Compact constructor: {@code message} must be non-null. The empty string is permitted but + * discouraged. + */ + public Error { + if (message == null) { + throw new IllegalArgumentException("message must not be null"); + } + } + } +} diff --git a/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/ModelInfo.java b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/ModelInfo.java new file mode 100644 index 0000000..97d08dd --- /dev/null +++ b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/ModelInfo.java @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference; + +/** + * Static metadata describing a registered model. + * + *

Generation models report {@code dimensions = -1} since the field is meaningless for + * autoregressive decoders. Embedding models populate it with the produced vector size. + * + * @param id stable model identifier (e.g. {@code "bge-small-en-v1.5"}) + * @param revision content hash, tag, or release identifier + * @param quantization quantization scheme used (e.g. {@code "q4_k_m"}, {@code "fp16"}) + * @param dimensions embedding vector size; {@code -1} for generation models + * @param maxTokens maximum context window in tokens supported by this model + */ +public record ModelInfo( + String id, String revision, String quantization, int dimensions, int maxTokens) {} diff --git a/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/NativeLoadException.java b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/NativeLoadException.java new file mode 100644 index 0000000..4353c2c --- /dev/null +++ b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/NativeLoadException.java @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference; + +/** + * Thrown when an inference-sdk module fails to extract or verify a native library shipped inside + * its JAR. + * + *

Common root causes (per {@code java-sdk.md} §6.2): + * + *

    + *
  • Missing native resource for the host's {@code os.name}/{@code os.arch} (architecture + * mismatch). + *
  • SHA-256 verification failure between the extracted file and its sibling {@code .sha256} + * resource (corrupt JAR, MITM, classpath shadowing). + *
  • Host glibc older than the version used to build the native library. + *
  • Filesystem error while writing to the per-JVM extraction directory. + *
+ * + *

Messages are intentionally verbose and remediation-oriented (see {@link + * io.github.randomcodespace.inference.runtime.NativeLibLoader}). + */ +public class NativeLoadException extends RuntimeException { + + private static final long serialVersionUID = 1L; + + /** + * @param message actionable description of the failure (host OS/arch, expected resource path, + * remediation) + */ + public NativeLoadException(String message) { + super(message); + } + + /** + * @param message actionable description of the failure + * @param cause underlying throwable (e.g. {@link java.io.IOException}) + */ + public NativeLoadException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/Result.java b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/Result.java new file mode 100644 index 0000000..0e66371 --- /dev/null +++ b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/Result.java @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference; + +/** + * Sum type representing the outcome of an operation: either a {@link Success} carrying a value of + * type {@code T} or a {@link Failure} carrying an error of type {@code E}. + * + *

Use {@link #success(Object)} and {@link #failure(Object)} factories rather than constructing + * the variants directly. Pattern-match against the sealed permits to handle both arms: + * + *

{@code
+ * switch (result) {
+ *     case Success s -> use(s.value());
+ *     case Failure f -> log(f.error());
+ * }
+ * }
+ * + * @param success-arm value type + * @param failure-arm error type + */ +public sealed interface Result permits Success, Failure { + + /** + * Lift a value into the success arm. + * + * @param value the success value + * @param success-arm type + * @param failure-arm type (inferred at the call site) + * @return a {@link Success} wrapping {@code value} + */ + static Result success(T value) { + return new Success<>(value); + } + + /** + * Lift an error into the failure arm. + * + * @param error the failure value + * @param success-arm type (inferred at the call site) + * @param failure-arm type + * @return a {@link Failure} wrapping {@code error} + */ + static Result failure(E error) { + return new Failure<>(error); + } +} diff --git a/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/Success.java b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/Success.java new file mode 100644 index 0000000..1efa47d --- /dev/null +++ b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/Success.java @@ -0,0 +1,14 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference; + +/** + * Success arm of a {@link Result}. Carries the produced value. + * + * @param value the success value (may be {@code null} if {@code T} is nullable) + * @param success-arm value type + * @param failure-arm error type (unused by this arm) + */ +public record Success(T value) implements Result {} diff --git a/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/Usage.java b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/Usage.java new file mode 100644 index 0000000..fa95e62 --- /dev/null +++ b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/Usage.java @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference; + +/** + * Token-accounting summary for a single inference request. + * + *

Invariants enforced by the compact constructor: + * + *

    + *
  • {@code promptTokens >= 0} + *
  • {@code completionTokens >= 0} + *
  • {@code totalTokens == promptTokens + completionTokens} + *
+ * + * @param promptTokens tokens fed to the model (input) + * @param completionTokens tokens produced by the model (output); zero for embedding requests + * @param totalTokens sum of {@code promptTokens} and {@code completionTokens} + */ +public record Usage(int promptTokens, int completionTokens, int totalTokens) { + + /** + * Compact constructor: validates the non-negativity and sum invariants. Throws {@link + * IllegalArgumentException} on violation. + */ + public Usage { + if (promptTokens < 0 || completionTokens < 0) { + throw new IllegalArgumentException("token counts must be non-negative"); + } + if (totalTokens != promptTokens + completionTokens) { + throw new IllegalArgumentException("totalTokens must equal sum"); + } + } +} diff --git a/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/runtime/ContainerCpu.java b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/runtime/ContainerCpu.java new file mode 100644 index 0000000..f46a618 --- /dev/null +++ b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/runtime/ContainerCpu.java @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.runtime; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Locale; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Container-aware CPU count detector. + * + *

Reads {@code /sys/fs/cgroup/cpu.max} (cgroups v2 quota file) and computes {@code ceil(quota / + * period)}; falls back to {@link Runtime#availableProcessors()} when the file is missing, + * unreadable, malformed, or contains the unconstrained sentinel ({@code "max "}). + * + *

This matters for any container runtime that limits CPU via cgroups v2 (Kubernetes, Docker with + * {@code --cpus}, systemd slices, OpenShift). On a 32-core host where the container is allocated 2 + * CPUs, {@link Runtime#availableProcessors()} returns 32 — which over-sizes the {@link + * NativeExecutor} and piles llama.cpp threads onto overcommitted carriers. This helper returns 2, + * which is the right answer. + * + *

cgroups v1 is not supported in Phase 1; v1 paths fall through to {@code + * availableProcessors()}. + */ +public final class ContainerCpu { + + private static final Logger LOG = LoggerFactory.getLogger(ContainerCpu.class); + + /** Default cgroups v2 quota file path on Linux. */ + private static final Path DEFAULT_CGROUP_PATH = Paths.get("/sys/fs/cgroup/cpu.max"); + + private ContainerCpu() { + // Static helper class. + } + + /** + * Detect the effective CPU count visible to this JVM. Reads the default cgroups v2 quota file; + * see {@link #detect(Path)} for the test-overridable form. + * + * @return effective CPU count (always {@code >= 1}) + */ + public static int detect() { + return detect(DEFAULT_CGROUP_PATH); + } + + /** + * Test-overridable form of {@link #detect()}: reads the supplied cgroup quota file path. + * + *

Visible at package scope so unit tests can substitute a fixture file. Production callers + * should use {@link #detect()}. + * + * @param cgroupPath path to a cgroups v2 {@code cpu.max} formatted file + * @return effective CPU count (always {@code >= 1}) + */ + static int detect(Path cgroupPath) { + try { + if (cgroupPath == null || !Files.isReadable(cgroupPath)) { + return availableProcessorsFloor(); + } + String raw = Files.readString(cgroupPath, StandardCharsets.UTF_8).trim(); + if (raw.isEmpty()) { + LOG.debug("cgroup file {} is empty; using availableProcessors()", cgroupPath); + return availableProcessorsFloor(); + } + // Format: " " or "max ". + String[] parts = raw.split("\\s+", 3); + if (parts.length < 2) { + LOG.debug( + "cgroup file {} has unexpected format: '{}'; using availableProcessors()", + cgroupPath, + raw); + return availableProcessorsFloor(); + } + String quotaToken = parts[0]; + String periodToken = parts[1]; + if ("max".equalsIgnoreCase(quotaToken)) { + // Uncapped — fall back. + return availableProcessorsFloor(); + } + long quota = Long.parseLong(quotaToken); + long period = Long.parseLong(periodToken); + if (quota <= 0 || period <= 0) { + return availableProcessorsFloor(); + } + int detected = (int) Math.ceil((double) quota / (double) period); + int result = Math.max(1, detected); + LOG.debug( + String.format( + Locale.ROOT, + "cgroup quota=%d period=%d -> effective CPUs=%d", + quota, + period, + result)); + return result; + } catch (IOException | NumberFormatException ex) { + LOG.debug("Failed to parse cgroup file {}: {}", cgroupPath, ex.toString()); + return availableProcessorsFloor(); + } + } + + private static int availableProcessorsFloor() { + return Math.max(1, Runtime.getRuntime().availableProcessors()); + } +} diff --git a/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/runtime/NativeExecutor.java b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/runtime/NativeExecutor.java new file mode 100644 index 0000000..a30257f --- /dev/null +++ b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/runtime/NativeExecutor.java @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.runtime; + +import java.util.Locale; +import java.util.Objects; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Bounded executor that runs JNI-bound work on platform threads, never on virtual + * threads. + * + *

Background — this is the single most important runtime invariant in inference-sdk. The native + * libraries (ONNX Runtime, llama.cpp) hold internal state across calls and pin the carrier thread. + * Submitting native work directly to a virtual-thread executor would either pin the carrier (under + * old JDKs) or expose stale per-thread state to the next caller; either failure mode is a + * production-shaped bug. + * + *

Therefore: every JNI call is wrapped in {@link #submitNative(Callable)} which trampolines to a + * platform-thread pool sized by {@link ContainerCpu#detect()}. Caller virtual threads await the + * resulting {@link CompletableFuture} the normal way ({@code .get()} / {@code .thenApply}); the JVM + * correctly recognises the wait as park-and-yield, freeing the carrier. + * + *

Lifecycle: implements {@link AutoCloseable}; {@link #close()} and {@link #shutdown()} are + * idempotent. Submissions made after shutdown throw {@link IllegalStateException}. + * + *

Threads in the pool are named {@code -} where {@code n} starts at 1. + */ +public final class NativeExecutor implements AutoCloseable { + + private static final Logger LOG = LoggerFactory.getLogger(NativeExecutor.class); + + /** Time budget for graceful pool drain in {@link #shutdown()}. */ + private static final long SHUTDOWN_TIMEOUT_SECONDS = 30L; + + private final ThreadPoolExecutor delegate; + private final AtomicBoolean closed = new AtomicBoolean(false); + private final String namePrefix; + + private NativeExecutor(int threads, String namePrefix) { + this.namePrefix = namePrefix; + ThreadFactory factory = new PlatformThreadFactory(namePrefix); + this.delegate = + new ThreadPoolExecutor( + threads, threads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(), factory); + this.delegate.allowCoreThreadTimeOut(false); + } + + /** + * Create a fixed-size native executor. + * + * @param threads pool size; must be {@code >= 1} + * @param namePrefix thread-name prefix (e.g. {@code "inference-native"}); must be non-blank + * @return a fresh executor; the caller is responsible for {@link #close()} + * @throws IllegalArgumentException if {@code threads < 1} or {@code namePrefix} is blank + */ + public static NativeExecutor sized(int threads, String namePrefix) { + if (threads < 1) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "threads must be >= 1, got %d", threads)); + } + Objects.requireNonNull(namePrefix, "namePrefix"); + if (namePrefix.isBlank()) { + throw new IllegalArgumentException("namePrefix must not be blank"); + } + return new NativeExecutor(threads, namePrefix); + } + + /** + * Submit a JNI call for execution on a pool platform thread. The returned future completes with + * the callable's result or its thrown exception. + * + * @param nativeCall callable wrapping the native call + * @param return type of the callable + * @return future that completes when the native call returns + * @throws IllegalStateException if {@link #shutdown()} or {@link #close()} has been invoked + * @throws NullPointerException if {@code nativeCall} is null + */ + public CompletableFuture submitNative(Callable nativeCall) { + Objects.requireNonNull(nativeCall, "nativeCall"); + if (closed.get()) { + throw new IllegalStateException( + String.format( + Locale.ROOT, "NativeExecutor[%s] is shut down; submissions rejected", namePrefix)); + } + CompletableFuture future = new CompletableFuture<>(); + try { + delegate.execute( + () -> { + try { + future.complete(nativeCall.call()); + } catch (Throwable t) { + future.completeExceptionally(t); + } + }); + } catch (Throwable t) { + // RejectedExecutionException after a race with shutdown. + future.completeExceptionally(t); + } + return future; + } + + /** + * Initiate orderly shutdown: stop accepting new submissions, drain the queue, and wait up to + * {@value #SHUTDOWN_TIMEOUT_SECONDS} seconds for in-flight work to complete. After the timeout + * lapses, force-cancels the remaining tasks. Idempotent. + */ + public void shutdown() { + if (!closed.compareAndSet(false, true)) { + return; + } + delegate.shutdown(); + try { + if (!delegate.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { + LOG.warn( + "NativeExecutor[{}] did not terminate within {}s; forcing", + namePrefix, + SHUTDOWN_TIMEOUT_SECONDS); + delegate.shutdownNow(); + } + } catch (InterruptedException ex) { + delegate.shutdownNow(); + Thread.currentThread().interrupt(); + } + } + + /** Idempotent alias for {@link #shutdown()}. */ + @Override + public void close() { + shutdown(); + } + + /** Thread factory producing daemon platform threads with a stable name pattern. */ + private static final class PlatformThreadFactory implements ThreadFactory { + private final String prefix; + private final AtomicInteger counter = new AtomicInteger(0); + + PlatformThreadFactory(String prefix) { + this.prefix = prefix; + } + + @Override + public Thread newThread(Runnable r) { + int n = counter.incrementAndGet(); + // Builder-form ofPlatform() avoids inheriting any virtual-thread context. + Thread t = + Thread.ofPlatform() + .name(String.format(Locale.ROOT, "%s-%d", prefix, n)) + .daemon(true) + .unstarted(r); + return t; + } + } +} diff --git a/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/runtime/NativeLibLoader.java b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/runtime/NativeLibLoader.java new file mode 100644 index 0000000..cf9fa55 --- /dev/null +++ b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/runtime/NativeLibLoader.java @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.runtime; + +import java.io.IOException; +import java.io.InputStream; +import java.lang.management.ManagementFactory; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.HexFormat; +import java.util.Locale; +import java.util.Objects; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.github.randomcodespace.inference.NativeLoadException; + +/** + * Extracts a native library shipped inside a module's JAR to a per-JVM temporary directory and + * verifies its SHA-256 digest against a sibling {@code .sha256} resource. + * + *

Algorithm (per {@code java-sdk.md} §6.2): + * + *

    + *
  1. Resolve the per-JVM extraction directory: {@code + * ${java.io.tmpdir}/inference-sdk-${pid}-${uuid}/} computed once per process, created lazily. + *
  2. Open the resource at {@code resourcePath} via the system class loader. Throw {@link + * NativeLoadException} on missing resource, with the host's detected OS+arch and the expected + * resource path embedded in the message. + *
  3. Open the sibling {@code .sha256} resource. Format: hex digest, optionally with a + * trailing whitespace + filename (GNU {@code sha256sum} format). + *
  4. Stream the resource to a temp file and SHA-256-digest the bytes simultaneously. If the + * digest mismatches, delete the temp file and throw {@link NativeLoadException}. + *
  5. {@code Files.move} (atomic where supported) to the final location and return the path. + *
  6. Subsequent calls for the same resource path are idempotent: return the previously extracted + * file. + *
+ * + *

Race avoidance: the extraction directory embeds {@code pid} + {@code uuid} so two JVMs never + * collide; within a single JVM, an internal cache short-circuits repeat calls. + */ +public final class NativeLibLoader { + + private static final Logger LOG = LoggerFactory.getLogger(NativeLibLoader.class); + + private static final ConcurrentMap EXTRACTED = new ConcurrentHashMap<>(); + private static final Object DIR_LOCK = new Object(); + private static volatile Path baseDir; + + private NativeLibLoader() { + // Static helper class. + } + + /** + * Extract a native resource and verify its SHA-256 digest. Idempotent: repeated calls with the + * same {@code resourcePath} return the same {@link Path}. + * + * @param resourcePath classpath-style resource path (e.g. {@code + * "/native/linux-x86_64/libonnxruntime.so"}) + * @return absolute path to the extracted, verified file + * @throws NativeLoadException if the resource is missing, the SHA-256 sibling is missing, the + * digests mismatch, or any I/O error occurs + */ + public static Path extractAndVerify(String resourcePath) { + Objects.requireNonNull(resourcePath, "resourcePath"); + return EXTRACTED.computeIfAbsent(resourcePath, NativeLibLoader::doExtract); + } + + /** For tests: clear the per-resource cache so we can re-run extraction inside one JVM. */ + static void resetCacheForTesting() { + EXTRACTED.clear(); + } + + private static Path doExtract(String resourcePath) { + String fileName = lastSegment(resourcePath); + Path target = ensureBaseDir().resolve(fileName); + + ClassLoader loader = preferredClassLoader(); + String shaResourcePath = resourcePath + ".sha256"; + + String expectedDigest = readExpectedDigest(loader, shaResourcePath, resourcePath); + + try (InputStream in = openResource(loader, resourcePath)) { + if (in == null) { + throw new NativeLoadException(missingResourceMessage(resourcePath)); + } + String actualDigest = streamToFileWithDigest(in, target); + if (!expectedDigest.equalsIgnoreCase(actualDigest)) { + deleteQuietly(target); + throw new NativeLoadException( + digestMismatchMessage(resourcePath, shaResourcePath, expectedDigest, actualDigest)); + } + LOG.debug("Extracted native resource {} -> {} (sha256 verified)", resourcePath, target); + return target; + } catch (IOException ex) { + deleteQuietly(target); + throw new NativeLoadException( + String.format( + Locale.ROOT, + "I/O error extracting native resource %s to %s: %s", + resourcePath, + target, + ex.getMessage()), + ex); + } + } + + private static String readExpectedDigest( + ClassLoader loader, String shaResourcePath, String resourcePath) { + try (InputStream in = openResource(loader, shaResourcePath)) { + if (in == null) { + throw new NativeLoadException( + String.format( + Locale.ROOT, + "Native checksum resource %s is missing for native library %s. " + + "Expected a sibling .sha256 file produced via " + + "`sha256sum %s`. Detected host: os=%s, arch=%s. " + + "Re-build or re-package the module to include the " + + "checksum.", + shaResourcePath, + resourcePath, + lastSegment(resourcePath), + System.getProperty("os.name", "unknown"), + System.getProperty("os.arch", "unknown"))); + } + String raw = new String(in.readAllBytes(), StandardCharsets.UTF_8).trim(); + if (raw.isEmpty()) { + throw new NativeLoadException( + String.format( + Locale.ROOT, + "Native checksum resource %s is empty for %s.", + shaResourcePath, + resourcePath)); + } + // Accept either "" or " filename" (GNU sha256sum format). + String[] tokens = raw.split("\\s+", 2); + return tokens[0].toLowerCase(Locale.ROOT); + } catch (IOException ex) { + throw new NativeLoadException( + String.format( + Locale.ROOT, + "Failed to read native checksum resource %s: %s", + shaResourcePath, + ex.getMessage()), + ex); + } + } + + private static String streamToFileWithDigest(InputStream in, Path target) throws IOException { + Files.createDirectories(target.getParent()); + Path tmp = target.resolveSibling(target.getFileName() + ".tmp"); + MessageDigest md; + try { + md = MessageDigest.getInstance("SHA-256"); + } catch (NoSuchAlgorithmException ex) { + throw new NativeLoadException("SHA-256 unavailable in this JVM", ex); + } + try { + byte[] buf = new byte[64 * 1024]; + try (var out = Files.newOutputStream(tmp)) { + int n; + while ((n = in.read(buf)) > 0) { + md.update(buf, 0, n); + out.write(buf, 0, n); + } + } + try { + Files.move( + tmp, target, StandardCopyOption.REPLACE_EXISTING, StandardCopyOption.ATOMIC_MOVE); + } catch (IOException atomicMoveFailed) { + // ATOMIC_MOVE may not be supported on all FSes; fall back. + Files.move(tmp, target, StandardCopyOption.REPLACE_EXISTING); + } + return HexFormat.of().formatHex(md.digest()); + } finally { + deleteQuietly(tmp); + } + } + + private static Path ensureBaseDir() { + Path local = baseDir; + if (local != null) { + return local; + } + synchronized (DIR_LOCK) { + if (baseDir != null) { + return baseDir; + } + String tmp = System.getProperty("java.io.tmpdir"); + Path tmpDir = Paths.get(tmp == null ? "/tmp" : tmp); + String pid; + try { + pid = Long.toString(ManagementFactory.getRuntimeMXBean().getPid()); + } catch (UnsupportedOperationException ex) { + pid = "unknown"; + } + Path dir = + tmpDir.resolve(String.format(Locale.ROOT, "inference-sdk-%s-%s", pid, UUID.randomUUID())); + try { + Files.createDirectories(dir); + } catch (IOException ex) { + throw new NativeLoadException( + String.format( + Locale.ROOT, + "Failed to create native extraction directory %s: %s", + dir, + ex.getMessage()), + ex); + } + baseDir = dir; + return dir; + } + } + + private static InputStream openResource(ClassLoader loader, String path) { + // Accept both leading-slash and no-leading-slash inputs. + String normalized = path.startsWith("/") ? path.substring(1) : path; + InputStream in = loader.getResourceAsStream(normalized); + if (in != null) { + return in; + } + // Fallback to the loader of this class (modulepath case). + return NativeLibLoader.class.getClassLoader().getResourceAsStream(normalized); + } + + private static ClassLoader preferredClassLoader() { + ClassLoader cl = Thread.currentThread().getContextClassLoader(); + return cl != null ? cl : NativeLibLoader.class.getClassLoader(); + } + + private static String lastSegment(String resourcePath) { + int idx = resourcePath.lastIndexOf('/'); + return idx < 0 ? resourcePath : resourcePath.substring(idx + 1); + } + + private static void deleteQuietly(Path p) { + try { + Files.deleteIfExists(p); + } catch (IOException ignored) { + // Best-effort cleanup. + } + } + + private static String missingResourceMessage(String resourcePath) { + return String.format( + Locale.ROOT, + "Native resource %s is not on the classpath. " + + "Detected host: os=%s, arch=%s. " + + "Expected layout: a sibling .sha256 resource at %s.sha256. " + + "Remediation: confirm the module shipping this resource is on the " + + "build path, that the JAR contains an entry under META-INF/native/ " + + "for your platform, and that there is no shadowing classpath entry " + + "with a stripped resources/ directory.", + resourcePath, + System.getProperty("os.name", "unknown"), + System.getProperty("os.arch", "unknown"), + resourcePath); + } + + private static String digestMismatchMessage( + String resourcePath, String shaResourcePath, String expected, String actual) { + return String.format( + Locale.ROOT, + "SHA-256 mismatch for native resource %s. " + + "Expected (from %s): %s. " + + "Actual: %s. " + + "This indicates either a corrupted JAR or a classpath conflict where " + + "two artifacts ship the same resource path. " + + "Detected host: os=%s, arch=%s.", + resourcePath, + shaResourcePath, + expected, + actual, + System.getProperty("os.name", "unknown"), + System.getProperty("os.arch", "unknown")); + } +} diff --git a/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/runtime/RequestId.java b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/runtime/RequestId.java new file mode 100644 index 0000000..50585df --- /dev/null +++ b/java/inference-sdk-core/src/main/java/io/github/randomcodespace/inference/runtime/RequestId.java @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.runtime; + +import java.util.Locale; +import java.util.Objects; +import java.util.UUID; +import java.util.concurrent.Callable; + +/** + * Per-request correlation id propagated via {@link ScopedValue} (JEP 487, finalized in Java 25). + * + *

The {@link #CURRENT} {@code ScopedValue} is bound by {@link #withRequestId(String, Callable)} + * and remains visible through any call tree the body invokes — including child virtual threads + * launched via {@link java.util.concurrent.StructuredTaskScope}. Unlike {@link ThreadLocal}, {@code + * ScopedValue} bindings are immutable, automatically inherited by structured forks, and cleaned up + * when the binding scope exits — there is no leak risk with virtual threads. + * + *

Generated identifiers have the form {@code "req_"} where {@code } is a + * canonically-formatted random UUIDv4. + */ +public final class RequestId { + + /** + * Scoped value holding the request id active in the current execution scope. Empty when not bound + * (i.e. outside any {@link #withRequestId} body). + */ + public static final ScopedValue CURRENT = ScopedValue.newInstance(); + + private RequestId() { + // Static helper class. + } + + /** + * @return a fresh request id, e.g. {@code "req_3f5b6e36-9d8c-4ad2-9d1e-08fbe23ee5fa"} + */ + public static String generate() { + return String.format(Locale.ROOT, "req_%s", UUID.randomUUID()); + } + + /** + * Run {@code body} with {@link #CURRENT} bound to {@code id}. The binding is visible to any code + * reachable from {@code body}, including code dispatched to a {@code StructuredTaskScope}, but is + * invisible outside that scope. + * + * @param id request id to bind; must not be {@code null} + * @param body work to run with the binding active + * @param return type of {@code body} + * @return whatever {@code body} returns + * @throws Exception any exception thrown by {@code body}, propagated unchanged + */ + public static T withRequestId(String id, Callable body) throws Exception { + Objects.requireNonNull(id, "id"); + Objects.requireNonNull(body, "body"); + return ScopedValue.where(CURRENT, id).call(body::call); + } +} diff --git a/java/inference-sdk-core/src/main/java/module-info.java b/java/inference-sdk-core/src/main/java/module-info.java new file mode 100644 index 0000000..234c557 --- /dev/null +++ b/java/inference-sdk-core/src/main/java/module-info.java @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ + +/** + * inference-sdk-core: shared records and runtime helpers shared by every inference-sdk module. + * + *

Consumers can use this module on either the modulepath (preferred) or the classpath; both are + * supported. See {@link io.github.randomcodespace.inference.runtime.NativeExecutor} for the + * native-thread-pinning workaround that virtual-thread callers rely on, and {@link + * io.github.randomcodespace.inference.runtime.ContainerCpu} for cgroups-v2-aware CPU detection. + */ +module io.github.randomcodespace.inference.core { + exports io.github.randomcodespace.inference; + exports io.github.randomcodespace.inference.runtime; + + requires org.slf4j; + // RuntimeMXBean for pid resolution in NativeLibLoader. + requires java.management; +} diff --git a/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/FinishReasonTest.java b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/FinishReasonTest.java new file mode 100644 index 0000000..be786d5 --- /dev/null +++ b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/FinishReasonTest.java @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Arrays; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import nl.jqno.equalsverifier.EqualsVerifier; +import nl.jqno.equalsverifier.Warning; + +class FinishReasonTest { + + @Test + void allRecordsAreValueEqual() { + EqualsVerifier.forClass(FinishReason.Stop.class).verify(); + EqualsVerifier.forClass(FinishReason.Length.class).verify(); + EqualsVerifier.forClass(FinishReason.Eos.class).verify(); + EqualsVerifier.forClass(FinishReason.Canceled.class).verify(); + // Error rejects null in its compact constructor; suppress null-field + // probing so EqualsVerifier doesn't try to instantiate Error(null). + EqualsVerifier.forClass(FinishReason.Error.class).suppress(Warning.NULL_FIELDS).verify(); + } + + @Test + void sealedInterfacePermitsAllFiveVariants() { + Class[] permits = FinishReason.class.getPermittedSubclasses(); + assertThat(permits).isNotNull(); + List> permitsList = Arrays.asList(permits); + assertThat(permitsList) + .containsExactlyInAnyOrder( + FinishReason.Stop.class, + FinishReason.Length.class, + FinishReason.Eos.class, + FinishReason.Canceled.class, + FinishReason.Error.class); + assertThat(FinishReason.class.isSealed()).isTrue(); + } + + @Test + void errorVariantExposesMessage() { + FinishReason.Error err = new FinishReason.Error("bad token"); + assertThat(err.message()).isEqualTo("bad token"); + } + + @Test + void errorVariantRejectsNullMessage() { + assertThatThrownBy(() -> new FinishReason.Error(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("message"); + } + + @Test + void canPatternMatchExhaustively() { + FinishReason r = new FinishReason.Length(); + String tag = + switch (r) { + case FinishReason.Stop s -> "stop"; + case FinishReason.Length l -> "length"; + case FinishReason.Eos e -> "eos"; + case FinishReason.Canceled c -> "canceled"; + case FinishReason.Error e -> "error:" + e.message(); + }; + assertThat(tag).isEqualTo("length"); + } +} diff --git a/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/ModelInfoTest.java b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/ModelInfoTest.java new file mode 100644 index 0000000..973a71c --- /dev/null +++ b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/ModelInfoTest.java @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +import nl.jqno.equalsverifier.EqualsVerifier; + +class ModelInfoTest { + + @Test + void recordIsValueEqual() { + EqualsVerifier.forClass(ModelInfo.class).verify(); + } + + @Test + void componentAccessorsExposeConstructorArgs() { + ModelInfo info = new ModelInfo("bge-small-en-v1.5", "abc123", "fp32", 384, 512); + assertThat(info.id()).isEqualTo("bge-small-en-v1.5"); + assertThat(info.revision()).isEqualTo("abc123"); + assertThat(info.quantization()).isEqualTo("fp32"); + assertThat(info.dimensions()).isEqualTo(384); + assertThat(info.maxTokens()).isEqualTo(512); + } + + @Test + void generationModelMayUseSentinelDimensions() { + ModelInfo gen = new ModelInfo("qwen2.5-0.5b", "v1", "q4_k_m", -1, 32768); + assertThat(gen.dimensions()).isEqualTo(-1); + } +} diff --git a/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/ResultTest.java b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/ResultTest.java new file mode 100644 index 0000000..4d49806 --- /dev/null +++ b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/ResultTest.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Arrays; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import nl.jqno.equalsverifier.EqualsVerifier; + +class ResultTest { + + @Test + void successFactoryProducesSuccessArm() { + Result r = Result.success(42); + assertThat(r).isInstanceOfSatisfying(Success.class, s -> assertThat(s.value()).isEqualTo(42)); + } + + @Test + void failureFactoryProducesFailureArm() { + Result r = Result.failure("boom"); + assertThat(r) + .isInstanceOfSatisfying(Failure.class, f -> assertThat(f.error()).isEqualTo("boom")); + } + + @Test + void successAndFailureRecordsAreValueEqual() { + EqualsVerifier.forClass(Success.class).verify(); + EqualsVerifier.forClass(Failure.class).verify(); + } + + @Test + void sealedInterfacePermitsExactlySuccessAndFailure() { + Class[] permits = Result.class.getPermittedSubclasses(); + assertThat(permits).isNotNull(); + List> permitsList = Arrays.asList(permits); + assertThat(permitsList).containsExactlyInAnyOrder(Success.class, Failure.class); + assertThat(Result.class.isSealed()).isTrue(); + } + + @Test + void canPatternMatchOverArms() { + Result ok = Result.success(7); + Result err = Result.failure("nope"); + + String okDesc = + switch (ok) { + case Success s -> "ok=" + s.value(); + case Failure f -> "err=" + f.error(); + }; + String errDesc = + switch (err) { + case Success s -> "ok=" + s.value(); + case Failure f -> "err=" + f.error(); + }; + assertThat(okDesc).isEqualTo("ok=7"); + assertThat(errDesc).isEqualTo("err=nope"); + } +} diff --git a/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/UsageTest.java b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/UsageTest.java new file mode 100644 index 0000000..0ae4adf --- /dev/null +++ b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/UsageTest.java @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.junit.jupiter.api.Test; + +class UsageTest { + + /* + * Note: EqualsVerifier 4.5 instantiates records with synthetic values + * (e.g. (1, 1, 1)) which violate Usage's totalTokens-equals-sum + * invariant. Records auto-generate equals/hashCode/toString from + * their components, so the invariant is purely a constructor concern; + * we exercise equality/hashCode/toString directly with hand-rolled + * value-equality assertions below. + */ + + @Test + void equalUsagesCompareEqualAndShareHashCode() { + Usage a = new Usage(10, 20, 30); + Usage b = new Usage(10, 20, 30); + assertThat(a).isEqualTo(b); + assertThat(a.hashCode()).isEqualTo(b.hashCode()); + // Reflexive. + assertThat(a).isEqualTo(a); + } + + @Test + void differingComponentsAreNotEqual() { + Usage a = new Usage(10, 20, 30); + Usage diffPrompt = new Usage(11, 19, 30); + Usage diffCompletion = new Usage(10, 21, 31); + Usage diffTotal = new Usage(10, 20, 30); + assertThat(a).isNotEqualTo(diffPrompt); + assertThat(a).isNotEqualTo(diffCompletion); + // diffTotal is identical here — kept for parity / regression intent. + assertThat(a).isEqualTo(diffTotal); + } + + @Test + void notEqualToNullOrOtherType() { + Usage a = new Usage(0, 0, 0); + assertThat(a).isNotEqualTo(null); + assertThat(a).isNotEqualTo("not a usage"); + } + + @Test + void toStringMentionsComponents() { + Usage a = new Usage(1, 2, 3); + String s = a.toString(); + assertThat(s).contains("promptTokens", "completionTokens", "totalTokens", "1", "2", "3"); + } + + @Test + void validUsageAccepted() { + Usage u = new Usage(10, 20, 30); + assertThat(u.promptTokens()).isEqualTo(10); + assertThat(u.completionTokens()).isEqualTo(20); + assertThat(u.totalTokens()).isEqualTo(30); + } + + @Test + void zeroCountsAreValid() { + Usage u = new Usage(0, 0, 0); + assertThat(u.totalTokens()).isZero(); + } + + @Test + void negativePromptTokensRejected() { + assertThatThrownBy(() -> new Usage(-1, 0, -1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("non-negative"); + } + + @Test + void negativeCompletionTokensRejected() { + assertThatThrownBy(() -> new Usage(0, -5, -5)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("non-negative"); + } + + @Test + void mismatchedTotalRejected() { + assertThatThrownBy(() -> new Usage(3, 4, 8)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("totalTokens must equal sum"); + } + + @Test + void totalLessThanSumRejected() { + assertThatThrownBy(() -> new Usage(5, 5, 9)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("totalTokens must equal sum"); + } +} diff --git a/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/runtime/ContainerCpuTest.java b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/runtime/ContainerCpuTest.java new file mode 100644 index 0000000..e5edb4f --- /dev/null +++ b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/runtime/ContainerCpuTest.java @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.runtime; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +class ContainerCpuTest { + + @ParameterizedTest(name = "[{index}] cgroup line ''{0}'' -> {1} CPUs") + @CsvSource( + delimiter = '|', + textBlock = + """ + 100000 100000 | 1 + 200000 100000 | 2 + 400000 100000 | 4 + 100 100000 | 1 + 150000 100000 | 2 + """) + void parsesQuotaPeriodAndCeils(String cgroupLine, int expectedCpus, @TempDir Path tmp) + throws IOException { + Path file = tmp.resolve("cpu.max"); + Files.writeString(file, cgroupLine + "\n", StandardCharsets.UTF_8); + assertThat(ContainerCpu.detect(file)).isEqualTo(expectedCpus); + } + + @Test + void missingFileFallsBackToAvailableProcessors(@TempDir Path tmp) { + Path missing = tmp.resolve("does-not-exist"); + int detected = ContainerCpu.detect(missing); + assertThat(detected).isGreaterThanOrEqualTo(1); + assertThat(detected).isEqualTo(Math.max(1, Runtime.getRuntime().availableProcessors())); + } + + @Test + void unconstrainedSentinelFallsBackToAvailableProcessors(@TempDir Path tmp) throws IOException { + Path file = tmp.resolve("cpu.max"); + Files.writeString(file, "max 100000\n", StandardCharsets.UTF_8); + int detected = ContainerCpu.detect(file); + assertThat(detected).isEqualTo(Math.max(1, Runtime.getRuntime().availableProcessors())); + } + + @Test + void malformedFileFallsBackToAvailableProcessors(@TempDir Path tmp) throws IOException { + Path file = tmp.resolve("cpu.max"); + Files.writeString(file, "garbage line\n", StandardCharsets.UTF_8); + assertThat(ContainerCpu.detect(file)) + .isEqualTo(Math.max(1, Runtime.getRuntime().availableProcessors())); + } + + @Test + void emptyFileFallsBackToAvailableProcessors(@TempDir Path tmp) throws IOException { + Path file = tmp.resolve("cpu.max"); + Files.writeString(file, "", StandardCharsets.UTF_8); + assertThat(ContainerCpu.detect(file)) + .isEqualTo(Math.max(1, Runtime.getRuntime().availableProcessors())); + } + + @Test + void extraWhitespaceTolerated(@TempDir Path tmp) throws IOException { + Path file = tmp.resolve("cpu.max"); + Files.writeString(file, " 300000 100000 \n", StandardCharsets.UTF_8); + assertThat(ContainerCpu.detect(file)).isEqualTo(3); + } + + @Test + void singleTokenFileFallsBack(@TempDir Path tmp) throws IOException { + Path file = tmp.resolve("cpu.max"); + Files.writeString(file, "100000\n", StandardCharsets.UTF_8); + assertThat(ContainerCpu.detect(file)) + .isEqualTo(Math.max(1, Runtime.getRuntime().availableProcessors())); + } + + @Test + void zeroQuotaFallsBack(@TempDir Path tmp) throws IOException { + Path file = tmp.resolve("cpu.max"); + Files.writeString(file, "0 100000\n", StandardCharsets.UTF_8); + assertThat(ContainerCpu.detect(file)) + .isEqualTo(Math.max(1, Runtime.getRuntime().availableProcessors())); + } + + @Test + void publicNoArgFormReturnsAtLeastOne() { + // detect() may use the host's real /sys/fs/cgroup/cpu.max which is fine + // for an integration smoke; we just check the contract. + assertThat(ContainerCpu.detect()).isGreaterThanOrEqualTo(1); + } +} diff --git a/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/runtime/NativeExecutorTest.java b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/runtime/NativeExecutorTest.java new file mode 100644 index 0000000..1555b95 --- /dev/null +++ b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/runtime/NativeExecutorTest.java @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.runtime; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.jupiter.api.Test; + +class NativeExecutorTest { + + @Test + void factoryRejectsZeroThreads() { + assertThatThrownBy(() -> NativeExecutor.sized(0, "p")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("threads must be >= 1"); + } + + @Test + void factoryRejectsBlankPrefix() { + assertThatThrownBy(() -> NativeExecutor.sized(1, " ")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("blank"); + } + + @Test + void submitNativeRunsOnPlatformThread() throws Exception { + try (NativeExecutor exec = NativeExecutor.sized(2, "test-native")) { + AtomicReference ref = new AtomicReference<>(); + CompletableFuture f = + exec.submitNative( + () -> { + ref.set(Thread.currentThread()); + return 42; + }); + assertThat(f.get()).isEqualTo(42); + Thread t = ref.get(); + assertThat(t).isNotNull(); + // The whole point of NativeExecutor: not virtual. + assertThat(t.isVirtual()).isFalse(); + assertThat(t.getName()).startsWith("test-native-"); + assertThat(t.isDaemon()).isTrue(); + } + } + + @Test + void submitNativePropagatesExceptions() throws Exception { + try (NativeExecutor exec = NativeExecutor.sized(1, "exc")) { + CompletableFuture f = + exec.submitNative( + () -> { + throw new IllegalStateException("native boom"); + }); + assertThatThrownBy(f::get) + .isInstanceOf(ExecutionException.class) + .hasCauseInstanceOf(IllegalStateException.class) + .hasMessageContaining("native boom"); + } + } + + @Test + void submitAfterShutdownIsRejected() { + NativeExecutor exec = NativeExecutor.sized(1, "rej"); + exec.shutdown(); + assertThatThrownBy(() -> exec.submitNative(() -> 1)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("shut down"); + } + + @Test + void closeIsIdempotent() { + NativeExecutor exec = NativeExecutor.sized(1, "idem"); + exec.close(); + // Second close must not throw. + exec.close(); + exec.shutdown(); + } + + @Test + void submitNativeRejectsNullCallable() { + try (NativeExecutor exec = NativeExecutor.sized(1, "npe")) { + assertThatThrownBy(() -> exec.submitNative(null)).isInstanceOf(NullPointerException.class); + } + } + + @Test + void poolThreadsCarryStableNamePattern() throws Exception { + try (NativeExecutor exec = NativeExecutor.sized(3, "named")) { + // Force three threads to actually start by submitting three tasks + // that block briefly so the pool grows to its core size. + CompletableFuture a = exec.submitNative(() -> Thread.currentThread().getName()); + CompletableFuture b = exec.submitNative(() -> Thread.currentThread().getName()); + CompletableFuture c = exec.submitNative(() -> Thread.currentThread().getName()); + assertThat(a.get()).startsWith("named-"); + assertThat(b.get()).startsWith("named-"); + assertThat(c.get()).startsWith("named-"); + } + } +} diff --git a/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/runtime/NativeLibLoaderTest.java b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/runtime/NativeLibLoaderTest.java new file mode 100644 index 0000000..4b2fcf2 --- /dev/null +++ b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/runtime/NativeLibLoaderTest.java @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.runtime; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import io.github.randomcodespace.inference.NativeLoadException; + +class NativeLibLoaderTest { + + @BeforeEach + void resetCache() { + NativeLibLoader.resetCacheForTesting(); + } + + @Test + void extractsAndVerifiesValidResource() { + Path extracted = NativeLibLoader.extractAndVerify("/native-fixtures/sample.bin"); + assertThat(extracted).exists().isRegularFile(); + // Idempotent: subsequent call returns the same path. + Path again = NativeLibLoader.extractAndVerify("/native-fixtures/sample.bin"); + assertThat(again).isEqualTo(extracted); + } + + @Test + void extractedContentMatchesResourceBytes() throws Exception { + Path extracted = NativeLibLoader.extractAndVerify("/native-fixtures/sample.bin"); + String content = Files.readString(extracted, StandardCharsets.UTF_8); + assertThat(content).isEqualTo("inference-sdk fixture payload v1"); + } + + @Test + void shaMismatchThrowsActionableException() { + assertThatThrownBy( + () -> NativeLibLoader.extractAndVerify("/native-fixtures/sample-wrongsha.bin")) + .isInstanceOf(NativeLoadException.class) + .hasMessageContaining("SHA-256 mismatch") + .hasMessageContaining("sample-wrongsha.bin") + .hasMessageContaining("os=") + .hasMessageContaining("arch="); + } + + @Test + void missingResourceThrowsActionableException() { + assertThatThrownBy( + () -> NativeLibLoader.extractAndVerify("/native-fixtures/does-not-exist.bin")) + .isInstanceOf(NativeLoadException.class) + // Either the .sha256 sibling is the first thing missed, or + // the binary is — both are actionable. Match on the common + // remediation prefix that both messages share. + .hasMessageContaining("does-not-exist.bin"); + } + + @Test + void rejectsNullResourcePath() { + assertThatThrownBy(() -> NativeLibLoader.extractAndVerify(null)) + .isInstanceOf(NullPointerException.class); + } + + @Test + void leadingSlashAndNoLeadingSlashEquivalent() { + Path a = NativeLibLoader.extractAndVerify("/native-fixtures/sample.bin"); + // Reset cache to force re-extraction with the alternate spelling. + NativeLibLoader.resetCacheForTesting(); + Path b = NativeLibLoader.extractAndVerify("native-fixtures/sample.bin"); + assertThat(a.getFileName()).isEqualTo(b.getFileName()); + assertThat(a).exists(); + assertThat(b).exists(); + } +} diff --git a/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/runtime/RequestIdTest.java b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/runtime/RequestIdTest.java new file mode 100644 index 0000000..642d6f7 --- /dev/null +++ b/java/inference-sdk-core/src/test/java/io/github/randomcodespace/inference/runtime/RequestIdTest.java @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.runtime; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.StructuredTaskScope; + +import org.junit.jupiter.api.Test; + +class RequestIdTest { + + @Test + void generateProducesPrefixedUuid() { + String id = RequestId.generate(); + assertThat(id).startsWith("req_"); + // Strip prefix and verify the remainder parses as a canonical UUID. + UUID parsed = UUID.fromString(id.substring("req_".length())); + assertThat(parsed).isNotNull(); + } + + @Test + void generateProducesUniqueIds() { + String a = RequestId.generate(); + String b = RequestId.generate(); + assertThat(a).isNotEqualTo(b); + } + + @Test + void scopedValueUnboundOutsideWith() { + assertThat(RequestId.CURRENT.isBound()).isFalse(); + } + + @Test + void withRequestIdBindsCurrent() throws Exception { + String result = + RequestId.withRequestId( + "req_abc", + () -> { + assertThat(RequestId.CURRENT.isBound()).isTrue(); + return RequestId.CURRENT.get(); + }); + assertThat(result).isEqualTo("req_abc"); + // Out-of-scope: unbound again. + assertThat(RequestId.CURRENT.isBound()).isFalse(); + } + + @Test + void withRequestIdRejectsNullId() { + assertThatThrownBy(() -> RequestId.withRequestId(null, () -> "x")) + .isInstanceOf(NullPointerException.class); + } + + @Test + void withRequestIdRejectsNullBody() { + assertThatThrownBy(() -> RequestId.withRequestId("req_x", null)) + .isInstanceOf(NullPointerException.class); + } + + @Test + @SuppressWarnings("preview") // StructuredTaskScope is JEP 505 preview in JDK 25. + void scopedValuePropagatesAcrossStructuredTaskScope() throws Exception { + String observed = + RequestId.withRequestId( + "req_structured", + () -> { + try (var scope = + StructuredTaskScope.open(StructuredTaskScope.Joiner.awaitAll())) { + StructuredTaskScope.Subtask task = + scope.fork( + () -> + // Inherited via structured fork. + RequestId.CURRENT.orElse("UNBOUND")); + scope.join(); + return task.get(); + } + }); + assertThat(observed).isEqualTo("req_structured"); + } + + @Test + @SuppressWarnings("preview") + void scopedValuePropagatesIntoStructuredVirtualThreadFork() throws Exception { + // Virtual threads inherit ScopedValue only when forked from a + // structured scope; the unstructured Executors.newVirtualThreadPerTaskExecutor() + // does NOT inherit (this is the deliberate JEP 506 behavior). Test + // the structured path here so the contract is pinned. + String observed = + RequestId.withRequestId( + "req_vt", + () -> { + try (var scope = + StructuredTaskScope.open(StructuredTaskScope.Joiner.awaitAll())) { + StructuredTaskScope.Subtask task = + scope.fork( + () -> { + // Verify the runner is virtual AND inherits. + if (!Thread.currentThread().isVirtual()) { + return "NOT-VIRTUAL"; + } + return RequestId.CURRENT.orElse("UNBOUND"); + }); + scope.join(); + return task.get(); + } + }); + assertThat(observed).isEqualTo("req_vt"); + } + + @Test + void unstructuredVirtualThreadExecutorDoesNotInheritScopedValue() throws Exception { + // Pin the negative contract: ScopedValue is only inherited via + // structured scopes. An unstructured submit observes it as unbound. + String observed = + RequestId.withRequestId( + "req_unstructured", + () -> { + try (ExecutorService exec = Executors.newVirtualThreadPerTaskExecutor()) { + Future f = exec.submit(() -> RequestId.CURRENT.orElse("UNBOUND")); + return f.get(); + } + }); + assertThat(observed).isEqualTo("UNBOUND"); + } + + @Test + void exceptionFromBodyPropagatesUnchanged() { + assertThatThrownBy( + () -> + RequestId.withRequestId( + "req_err", + () -> { + throw new IllegalStateException("boom"); + })) + .isInstanceOf(IllegalStateException.class) + .hasMessage("boom"); + } +} diff --git a/java/inference-sdk-core/src/test/resources/native-fixtures/sample-wrongsha.bin b/java/inference-sdk-core/src/test/resources/native-fixtures/sample-wrongsha.bin new file mode 100644 index 0000000..59ff596 --- /dev/null +++ b/java/inference-sdk-core/src/test/resources/native-fixtures/sample-wrongsha.bin @@ -0,0 +1 @@ +inference-sdk fixture payload v1 \ No newline at end of file diff --git a/java/inference-sdk-core/src/test/resources/native-fixtures/sample-wrongsha.bin.sha256 b/java/inference-sdk-core/src/test/resources/native-fixtures/sample-wrongsha.bin.sha256 new file mode 100644 index 0000000..cd09bbf --- /dev/null +++ b/java/inference-sdk-core/src/test/resources/native-fixtures/sample-wrongsha.bin.sha256 @@ -0,0 +1 @@ +0000000000000000000000000000000000000000000000000000000000000000 diff --git a/java/inference-sdk-core/src/test/resources/native-fixtures/sample.bin b/java/inference-sdk-core/src/test/resources/native-fixtures/sample.bin new file mode 100644 index 0000000..59ff596 --- /dev/null +++ b/java/inference-sdk-core/src/test/resources/native-fixtures/sample.bin @@ -0,0 +1 @@ +inference-sdk fixture payload v1 \ No newline at end of file diff --git a/java/inference-sdk-core/src/test/resources/native-fixtures/sample.bin.sha256 b/java/inference-sdk-core/src/test/resources/native-fixtures/sample.bin.sha256 new file mode 100644 index 0000000..b6c2e81 --- /dev/null +++ b/java/inference-sdk-core/src/test/resources/native-fixtures/sample.bin.sha256 @@ -0,0 +1 @@ +07714d7ba2fd7a4181da763c798f59ddd76bf45e120837bb179f852bee8f72c2 diff --git a/java/inference-sdk-embed-bge-small/pom.xml b/java/inference-sdk-embed-bge-small/pom.xml new file mode 100644 index 0000000..6d3c030 --- /dev/null +++ b/java/inference-sdk-embed-bge-small/pom.xml @@ -0,0 +1,124 @@ + + + + 4.0.0 + + + io.github.randomcodespace.inference + inference-sdk-parent + 0.1.0-SNAPSHOT + ../inference-sdk-parent/pom.xml + + + inference-sdk-embed-bge-small + jar + + ${project.groupId}:${project.artifactId} + Bundled BAAI/bge-small-en-v1.5 embedding model artifact (int8 dynamic + quantized ONNX, ~35 MB, 384 dimensions, MIT). Maven JAR with no Java code; only + packages the model file + manifest under src/main/resources/models/. The ONNX + file is fetched at build time by scripts/fetch_models.py and embedded into the + published artifact (hybrid model-distribution; see pom comment header). + + + true + true + true + + false + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + + true + + + + + + org.codehaus.mojo + exec-maven-plugin + + + + fetch-embedding-model + generate-resources + + exec + + + ${fetch.models.skip} + python3 + + ${project.basedir}/../../scripts/fetch_models.py + --embedding-model + bge-small-en-v1.5 + --skip-generation + --output-dir + ${project.basedir}/.. + + + + + + verify-embedding-model + process-resources + + exec + + + ${fetch.models.skip} + python3 + + ${project.basedir}/../../scripts/verify_models.py + + + + + + + + diff --git a/java/inference-sdk-embed-bge-small/src/main/resources/models/.gitkeep b/java/inference-sdk-embed-bge-small/src/main/resources/models/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/java/inference-sdk-embed-bge-small/src/main/resources/models/model-manifest.properties b/java/inference-sdk-embed-bge-small/src/main/resources/models/model-manifest.properties new file mode 100644 index 0000000..88c7849 --- /dev/null +++ b/java/inference-sdk-embed-bge-small/src/main/resources/models/model-manifest.properties @@ -0,0 +1,10 @@ +# Generated by scripts/fetch_models.py - do not edit by hand +id=bge-small-en-v1.5 +hf_repo=BAAI/bge-small-en-v1.5 +revision=main +tokenizer_family=BGE +dimensions=384 +max_tokens=512 +quantization=int8-dynamic +sha256= +license=MIT diff --git a/java/inference-sdk-embed/pom.xml b/java/inference-sdk-embed/pom.xml new file mode 100644 index 0000000..ebf5b01 --- /dev/null +++ b/java/inference-sdk-embed/pom.xml @@ -0,0 +1,131 @@ + + + + 4.0.0 + + + io.github.randomcodespace.inference + inference-sdk-parent + 0.1.0-SNAPSHOT + ../inference-sdk-parent/pom.xml + + + inference-sdk-embed + jar + + inference-sdk-embed + Embedding API for inference-sdk: Embedder interface + backed by ONNX Runtime + DJL tokenizers, executing native calls + on platform threads via inference-sdk-core's NativeExecutor. + + + + + io.github.randomcodespace.inference + inference-sdk-core + ${project.version} + + + + + com.microsoft.onnxruntime + onnxruntime + + + ai.djl.huggingface + tokenizers + + + + + org.slf4j + slf4j-api + + + + + com.fasterxml.jackson.core + jackson-annotations + + + + + org.junit.jupiter + junit-jupiter + test + + + org.junit.jupiter + junit-jupiter-params + test + + + org.assertj + assertj-core + test + + + org.awaitility + awaitility + test + + + net.jqwik + jqwik + test + + + nl.jqno.equalsverifier + equalsverifier + test + + + ch.qos.logback + logback-classic + test + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + model,native + + + + org.apache.maven.plugins + maven-failsafe-plugin + + model,native + + + + + diff --git a/java/inference-sdk-embed/spotbugs-exclude.xml b/java/inference-sdk-embed/spotbugs-exclude.xml new file mode 100644 index 0000000..57b9cb1 --- /dev/null +++ b/java/inference-sdk-embed/spotbugs-exclude.xml @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + diff --git a/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/BgeNormalizer.java b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/BgeNormalizer.java new file mode 100644 index 0000000..13871dc --- /dev/null +++ b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/BgeNormalizer.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.embed; + +/** + * L2-normalization utility for BGE / sentence-transformer style embedding outputs. + * + *

BGE models (and most modern bi-encoder embedding models) emit pre-normalized vectors when + * loaded from the official checkpoints, but quantized variants and custom exports occasionally skip + * the final norm op. This helper makes the post-condition explicit and idempotent: a vector with + * unit L2 norm passes through unchanged (modulo float rounding). + * + *

Edge cases: + * + *

    + *
  • An all-zero vector (L2 == 0) is left unchanged — dividing by zero would yield NaN. + *
  • {@code NaN} or {@code Inf} components produce a {@code NaN} norm; in that case the vector + * is left untouched and the caller is expected to surface the upstream failure. + *
+ * + *

This class is thread-safe (stateless), package-private, and final to keep the API surface + * small. + */ +final class BgeNormalizer { + + private BgeNormalizer() { + // Utility class — no instances. + } + + /** + * Compute the L2 norm of a float vector. + * + * @param vector vector; never {@code null} + * @return non-negative L2 norm; {@code 0.0} for the zero vector + */ + static double l2Norm(float[] vector) { + double sum = 0.0; + for (float v : vector) { + sum += (double) v * (double) v; + } + return Math.sqrt(sum); + } + + /** + * Mutate {@code vector} in place so that its L2 norm is 1.0. No-op for the zero vector or any + * vector with a non-finite norm. + * + * @param vector vector to normalize; never {@code null}; mutated in place + */ + static void normalizeInPlace(float[] vector) { + double norm = l2Norm(vector); + if (norm == 0.0 || !Double.isFinite(norm)) { + return; + } + float invNorm = (float) (1.0 / norm); + for (int i = 0; i < vector.length; i++) { + vector[i] = vector[i] * invNorm; + } + } +} diff --git a/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/EmbedException.java b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/EmbedException.java new file mode 100644 index 0000000..c158dda --- /dev/null +++ b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/EmbedException.java @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.embed; + +/** + * Base unchecked exception type raised by the embedding API. + * + *

All embedding-specific failures — model resolution, input validation, native session faults — + * extend this class so callers can catch a single type when they only need to distinguish embedding + * errors from other runtime failures. + * + *

This is a {@link RuntimeException} on purpose: the embedder is used from virtual threads and + * inside {@code CompletableFuture} pipelines where checked exceptions are awkward and tend to be + * lost in lambda boundaries. + * + * @see ModelNotFoundException + * @see InvalidInputException + */ +public class EmbedException extends RuntimeException { + + private static final long serialVersionUID = 1L; + + /** + * Construct an embed exception with a human-readable message. + * + * @param message descriptive message; shown to the caller and logged + */ + public EmbedException(String message) { + super(message); + } + + /** + * Construct an embed exception with a wrapped cause. + * + * @param message descriptive message + * @param cause originating exception, never {@code null} + */ + public EmbedException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/EmbedResult.java b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/EmbedResult.java new file mode 100644 index 0000000..5257f67 --- /dev/null +++ b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/EmbedResult.java @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.embed; + +import java.util.List; +import java.util.Objects; + +/** + * Result of a single {@link Embedder#embed(java.util.List)} or {@link + * Embedder#embedAsync(java.util.List)} call. + * + *

Wire format defined in {@code docs/WIRE_FORMAT.md} §2.2. + * + *

The {@code vectors} list is wrapped via {@link List#copyOf(java.util.Collection)} in the + * canonical constructor, producing an unmodifiable defensive copy. The inner {@code float[]} arrays + * are not defensively cloned — copying them on every result would multiply the + * cost of large batches; callers that mutate are violating the API contract. + * + * @param vectors one dense float vector per input, in the same order as the request; each vector's + * length equals {@link io.github.randomcodespace.inference.ModelInfo#dimensions()} + * @param tokens total number of tokens consumed across all inputs (post-truncation) + * @param stats per-request telemetry; never {@code null} + */ +public record EmbedResult(List vectors, int tokens, EmbedStats stats) { + + /** + * Validate non-null components and wrap {@code vectors} in an unmodifiable list. + * + *

An empty input list is permitted and produces an empty {@code vectors} list with {@code + * tokens == 0} (per {@code java-sdk.md} §11.1 — "{@code embed(List.of())} → empty vectors"). + */ + public EmbedResult { + Objects.requireNonNull(vectors, "vectors"); + Objects.requireNonNull(stats, "stats"); + if (tokens < 0) { + throw new IllegalArgumentException("tokens must be >= 0, got " + tokens); + } + // Defensive copy — caller cannot mutate the returned list. + vectors = List.copyOf(vectors); + } +} diff --git a/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/EmbedStats.java b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/EmbedStats.java new file mode 100644 index 0000000..0cd079b --- /dev/null +++ b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/EmbedStats.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.embed; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Per-request telemetry returned alongside an {@link EmbedResult}. + * + *

Wire format is locked by {@link JsonProperty} bindings to {@code snake_case} field names per + * {@code docs/WIRE_FORMAT.md} §2.1. Phase 1 does not perform JSON (de)serialization at runtime, but + * the annotations are present so the contract is enforced at compile time and survives Phase 2. + * + *

All {@code *_ms} fields are non-negative; the canonical constructor validates the invariants + * defined in {@code docs/WIRE_FORMAT.md} (no negative durations, valid {@code batch_position}). + * + * @param requestId stable per-request identifier in {@code req_} form + * @param queueMs time spent waiting for a {@code NativeExecutor} slot, in milliseconds + * @param tokenizeMs time spent tokenizing the input batch, in milliseconds + * @param inferenceMs time spent inside the ONNX session run, in milliseconds + * @param totalMs end-to-end wall time, in milliseconds; satisfies {@code totalMs >= queueMs + + * tokenizeMs + inferenceMs} + * @param batchSize number of inputs in this batch + * @param batchPosition either {@code "single"} (one-shot call) or {@code "coalesced"} (merged with + * siblings by the SDK batcher) + * @param modelRevision the {@code revision} from the model's manifest + * @param node hostname / pod name for distributed tracing; may be {@code null} + */ +public record EmbedStats( + @JsonProperty("request_id") String requestId, + @JsonProperty("queue_ms") long queueMs, + @JsonProperty("tokenize_ms") long tokenizeMs, + @JsonProperty("inference_ms") long inferenceMs, + @JsonProperty("total_ms") long totalMs, + @JsonProperty("batch_size") int batchSize, + @JsonProperty("batch_position") String batchPosition, + @JsonProperty("model_revision") String modelRevision, + String node) { + + /** Canonical batch-position value for one-shot, non-coalesced calls. */ + public static final String BATCH_POSITION_SINGLE = "single"; + + /** Canonical batch-position value when the request was merged with siblings. */ + public static final String BATCH_POSITION_COALESCED = "coalesced"; + + /** + * Validate non-negativity invariants and the {@code batch_position} enum. Permits {@code null} + * for {@code requestId}, {@code modelRevision}, and {@code node} so partial telemetry can still + * be constructed during failure paths. + */ + public EmbedStats { + if (queueMs < 0) { + throw new IllegalArgumentException("queueMs must be >= 0, got " + queueMs); + } + if (tokenizeMs < 0) { + throw new IllegalArgumentException("tokenizeMs must be >= 0, got " + tokenizeMs); + } + if (inferenceMs < 0) { + throw new IllegalArgumentException("inferenceMs must be >= 0, got " + inferenceMs); + } + if (totalMs < 0) { + throw new IllegalArgumentException("totalMs must be >= 0, got " + totalMs); + } + if (batchSize < 0) { + throw new IllegalArgumentException("batchSize must be >= 0, got " + batchSize); + } + if (batchPosition != null + && !BATCH_POSITION_SINGLE.equals(batchPosition) + && !BATCH_POSITION_COALESCED.equals(batchPosition)) { + throw new IllegalArgumentException( + "batchPosition must be 'single' or 'coalesced', got '" + batchPosition + "'"); + } + } +} diff --git a/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/Embedder.java b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/Embedder.java new file mode 100644 index 0000000..1e4b490 --- /dev/null +++ b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/Embedder.java @@ -0,0 +1,289 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.embed; + +import java.nio.file.Path; +import java.time.Duration; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; + +import org.slf4j.Logger; +import org.slf4j.helpers.NOPLogger; + +import io.github.randomcodespace.inference.ModelInfo; + +/** + * Public embedding API. Computes dense float vectors from input strings using a local ONNX model. + * + *

Construct via the fluent {@link #builder()}: + * + *

{@code
+ * try (Embedder embedder = Embedder.builder()
+ *         .model("bge-small-en-v1.5")
+ *         .threads(4)
+ *         .build()) {
+ *   EmbedResult r = embedder.embed(List.of("hello world"));
+ * }
+ * }
+ * + *

Thread-safety and the native-thread-pinning workaround

+ * + *

All methods on this interface are thread-safe and may be called from virtual threads. Internal + * implementations route every JNI call through {@code + * io.github.randomcodespace.inference.runtime.NativeExecutor}, which trampolines work onto a + * platform-thread pool — see {@code docs/ARCHITECTURE.md} §3.3 for the rationale (ONNX Runtime pins + * the carrier thread; submitting native work to a virtual-thread executor would either pin the + * carrier or expose stale per-thread state). Caller virtual threads await the resulting {@code + * CompletableFuture} normally. + * + *

Lifecycle

+ * + *

{@link #close()} is idempotent. Subsequent calls to {@code embed*} after close throw {@link + * IllegalStateException}. + * + * @see EmbedResult + * @see EmbedStats + */ +public interface Embedder extends AutoCloseable { + + /** + * Embed a batch of strings synchronously. Blocks the caller until the entire batch resolves. + * + * @param texts inputs to embed; never {@code null} and never contains {@code null} elements + * @return result with one vector per input, plus aggregate telemetry + * @throws InvalidInputException if {@code texts} is null or contains a null element + * @throws IllegalStateException if this embedder has been {@linkplain #close() closed} + * @throws EmbedException for any other failure (model error, native fault) + */ + EmbedResult embed(List texts); + + /** + * Embed exactly one string. Equivalent to {@code embed(List.of(text)).vectors().get(0)}. + * + * @param text input string; never {@code null} + * @return single dense vector of length {@link ModelInfo#dimensions()} + * @throws InvalidInputException if {@code text} is {@code null} + * @throws IllegalStateException if this embedder has been closed + * @throws EmbedException for any other failure + */ + float[] embedOne(String text); + + /** + * Embed a batch asynchronously. The returned future completes on the SDK's virtual-thread + * executor; the underlying native call runs on a platform thread (see class JavaDoc on the + * native-thread-pinning workaround). Callers awaiting on this future from a virtual thread will + * yield the carrier correctly. + * + * @param texts inputs to embed + * @return future yielding the result; completes exceptionally with the same exceptions as {@link + * #embed(List)} + */ + CompletableFuture embedAsync(List texts); + + /** + * @return the static {@link ModelInfo} for the loaded model + */ + ModelInfo modelInfo(); + + /** + * Release the ONNX session, the tokenizer, and the platform-thread executor. Idempotent. + * + *

In-flight requests submitted before {@code close()} will run to completion; submissions + * after close raise {@link IllegalStateException}. Native resources are guaranteed to be + * freed exactly once, even under concurrent {@code close()} calls. + */ + @Override + void close(); + + /** + * @return a fresh builder; configuration is per-builder and never shared + */ + static Builder builder() { + return new Builder(); + } + + /** + * Fluent builder for {@link Embedder}. Mutator methods return {@code this} so calls chain. Each + * mutator validates eagerly so misconfiguration surfaces at the call site rather than inside + * {@link #build()}. + * + *

Default values: + * + *

    + *
  • {@code threads = 0} → resolved at build time via {@code ContainerCpu.detect()} + *
  • {@code batchSize = 32} + *
  • {@code batchWait = Duration.ofMillis(0)} (no coalescing in Phase 1) + *
  • {@code logger = NOPLogger.NOP_LOGGER} + *
+ */ + final class Builder { + + private static final int DEFAULT_BATCH_SIZE = 32; + + private String model; + private Path modelPath; + private Path tokenizerPath; + private int threads; + private int batchSize = DEFAULT_BATCH_SIZE; + private Duration batchWait = Duration.ZERO; + private Logger logger = NOPLogger.NOP_LOGGER; + + /** + * Set the logical model id. Must be non-null and non-blank. + * + * @param name model id, e.g. {@code "bge-small-en-v1.5"} + * @return this builder for chaining + * @throws IllegalArgumentException if {@code name} is blank + * @throws NullPointerException if {@code name} is {@code null} + */ + public Builder model(String name) { + Objects.requireNonNull(name, "model"); + if (name.isBlank()) { + throw new IllegalArgumentException("model must not be blank"); + } + this.model = name; + return this; + } + + /** + * Set an explicit on-disk model path. Overrides classpath / env resolution. + * + * @param path absolute path to the {@code .onnx} file; never {@code null} + * @return this builder for chaining + * @throws NullPointerException if {@code path} is {@code null} + */ + public Builder modelPath(Path path) { + this.modelPath = Objects.requireNonNull(path, "modelPath"); + return this; + } + + /** + * Set an explicit on-disk tokenizer.json path. If unset, the resolver looks alongside the model + * file. + * + * @param path absolute path to the {@code tokenizer.json} file; never {@code null} + * @return this builder for chaining + * @throws NullPointerException if {@code path} is {@code null} + */ + public Builder tokenizerPath(Path path) { + this.tokenizerPath = Objects.requireNonNull(path, "tokenizerPath"); + return this; + } + + /** + * Number of platform threads in the native executor pool. {@code 0} means auto-detect. + * + * @param n thread count; must be {@code >= 0} + * @return this builder for chaining + * @throws IllegalArgumentException if {@code n < 0} + */ + public Builder threads(int n) { + if (n < 0) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "threads must be >= 0, got %d", n)); + } + this.threads = n; + return this; + } + + /** + * Maximum batch size submitted to a single ONNX session run. Larger inputs are chunked. + * + * @param n batch size; must be {@code >= 1} + * @return this builder for chaining + * @throws IllegalArgumentException if {@code n < 1} + */ + public Builder batchSize(int n) { + if (n < 1) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "batchSize must be >= 1, got %d", n)); + } + this.batchSize = n; + return this; + } + + /** + * Maximum time to wait for additional inputs before flushing a partial batch (Phase 2 + * coalescing primitive — Phase 1 honours the value as a no-op timeout). + * + * @param d non-negative duration + * @return this builder for chaining + * @throws IllegalArgumentException if {@code d} is negative + * @throws NullPointerException if {@code d} is {@code null} + */ + public Builder batchWait(Duration d) { + Objects.requireNonNull(d, "batchWait"); + if (d.isNegative()) { + throw new IllegalArgumentException("batchWait must be >= 0, got " + d); + } + this.batchWait = d; + return this; + } + + /** + * SLF4J logger for diagnostic output. Default: {@link NOPLogger#NOP_LOGGER}; the library never + * configures logging itself. + * + * @param slf4jLogger logger; never {@code null} + * @return this builder for chaining + * @throws NullPointerException if {@code slf4jLogger} is {@code null} + */ + public Builder logger(Logger slf4jLogger) { + this.logger = Objects.requireNonNull(slf4jLogger, "logger"); + return this; + } + + /** + * Build the embedder. Resolves the model via {@link ModelResolver}, opens the ONNX session, + * loads the tokenizer, and starts the native-thread pool. + * + * @return a fully initialized {@link Embedder}; the caller owns its lifecycle + * @throws IllegalArgumentException if neither {@code model} nor {@code modelPath} is set + * @throws ModelNotFoundException if no matching model is on the classpath / disk + * @throws EmbedException for any other initialization failure + */ + public Embedder build() { + if (model == null && modelPath == null) { + throw new IllegalArgumentException( + "Embedder.Builder requires either model(String) or modelPath(Path) to be set"); + } + return OnnxEmbedder.create(this); + } + + // Package-private accessors used by OnnxEmbedder.create. Keeping the + // builder a "dumb" data carrier means tests can construct one without + // depending on OnnxEmbedder. + + String getModel() { + return model; + } + + Path getModelPath() { + return modelPath; + } + + Path getTokenizerPath() { + return tokenizerPath; + } + + int getThreads() { + return threads; + } + + int getBatchSize() { + return batchSize; + } + + Duration getBatchWait() { + return batchWait; + } + + Logger getLogger() { + return logger; + } + } +} diff --git a/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/InvalidInputException.java b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/InvalidInputException.java new file mode 100644 index 0000000..1368df3 --- /dev/null +++ b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/InvalidInputException.java @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.embed; + +/** + * Thrown before any work is done when the caller passes a malformed input list. + * + *

Examples (per {@code java-sdk.md} §11.2 case 5): + * + *

    + *
  • {@code null} list reference + *
  • List containing a {@code null} element + *
+ * + *

Empty strings, whitespace-only strings, and strings exceeding the model's max-token window are + * not invalid — they are valid inputs that the embedder handles (see §11.2 cases + * 1, 2, 3, 10). + */ +public class InvalidInputException extends EmbedException { + + private static final long serialVersionUID = 1L; + + /** + * Construct an invalid-input exception. + * + * @param message descriptive message identifying the offending input + */ + public InvalidInputException(String message) { + super(message); + } +} diff --git a/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/ModelNotFoundException.java b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/ModelNotFoundException.java new file mode 100644 index 0000000..9435009 --- /dev/null +++ b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/ModelNotFoundException.java @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.embed; + +import java.util.List; + +/** + * Thrown when the embed builder cannot locate a model on disk, in {@code INFERENCE_MODEL_DIR}, or + * on the classpath. + * + *

The error message lists every model JAR currently visible on the classpath so the caller can + * fix the misconfiguration without trial-and-error. + * + *

Resolution order (per {@code java-sdk.md} §8): + * + *

    + *
  1. Explicit {@link Embedder.Builder#modelPath(java.nio.file.Path)} value + *
  2. {@code INFERENCE_MODEL_DIR} environment variable + *
  3. Classpath resource {@code /models/.onnx} + *
+ */ +public class ModelNotFoundException extends EmbedException { + + private static final long serialVersionUID = 1L; + + /** + * Construct a not-found exception with an explanatory message. + * + * @param message message identifying the missing model and lookup paths attempted + */ + public ModelNotFoundException(String message) { + super(message); + } + + /** + * Construct a not-found exception that summarizes what was searched and what was visible. + * + * @param requestedModel logical model id requested (e.g. {@code "bge-small-en-v1.5"}) + * @param searchedPaths list of paths searched (file system + classpath); printed in order + * @param availableModels list of model ids visible to the builder; may be empty + */ + public ModelNotFoundException( + String requestedModel, List searchedPaths, List availableModels) { + super(buildMessage(requestedModel, searchedPaths, availableModels)); + } + + private static String buildMessage( + String requestedModel, List searchedPaths, List availableModels) { + StringBuilder sb = new StringBuilder(); + sb.append("Model '") + .append(requestedModel) + .append("' was not found. Searched the following locations in order: "); + if (searchedPaths == null || searchedPaths.isEmpty()) { + sb.append("(none)"); + } else { + sb.append(String.join(", ", searchedPaths)); + } + sb.append(". Available models on the classpath: "); + if (availableModels == null || availableModels.isEmpty()) { + sb.append("(none — add an inference-sdk-embed- JAR to the classpath, e.g. ") + .append("inference-sdk-embed-bge-small)"); + } else { + sb.append(String.join(", ", availableModels)); + } + sb.append('.'); + return sb.toString(); + } +} diff --git a/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/ModelResolver.java b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/ModelResolver.java new file mode 100644 index 0000000..ad23456 --- /dev/null +++ b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/ModelResolver.java @@ -0,0 +1,215 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.embed; + +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.function.Function; + +/** + * Resolves a model identifier to a concrete {@link Path} on disk, in the order specified by {@code + * java-sdk.md} §8: + * + *
    + *
  1. Explicit {@link Embedder.Builder#modelPath(Path)}, if non-null and pointing at an existing + * file. + *
  2. {@code INFERENCE_MODEL_DIR} environment variable: a directory expected to contain a {@code + * .onnx} file (or just {@code }). + *
  3. Classpath: any {@code /models/.onnx} resource shipped by an {@code + * inference-sdk-embed-} JAR. + *
+ * + *

If none match, {@link ModelNotFoundException} is thrown with a descriptive message listing the + * locations searched and the model ids visible on the classpath via the canonical {@code + * /models/model-manifest.properties} discovery. + * + *

Package-private; consumers configure resolution via the builder, never this class directly. + * + *

Test seam: {@link #withEnv(Function)} swaps the environment-variable lookup so tests don't + * need to mutate process env. + */ +final class ModelResolver { + + /** Standard environment variable per {@code java-sdk.md} §8. */ + static final String ENV_MODEL_DIR = "INFERENCE_MODEL_DIR"; + + /** Canonical classpath prefix for shipped model resources. */ + static final String CLASSPATH_PREFIX = "/models/"; + + /** Canonical extension. ONNX is the only Phase 1 embedding format. */ + static final String ONNX_EXTENSION = ".onnx"; + + /** Canonical manifest filename for model-jar discovery. */ + static final String MANIFEST_RESOURCE = "models/model-manifest.properties"; + + private final Function envLookup; + private final ClassLoader classLoader; + + /** + * Default resolver: real {@link System#getenv(String)} and the current thread's context loader. + */ + ModelResolver() { + this(System::getenv, preferredClassLoader()); + } + + /** Test-friendly constructor; both arguments must be non-null. */ + ModelResolver(Function envLookup, ClassLoader classLoader) { + this.envLookup = Objects.requireNonNull(envLookup, "envLookup"); + this.classLoader = Objects.requireNonNull(classLoader, "classLoader"); + } + + /** + * @return a resolver whose {@code env} lookup is replaced by {@code envLookup}; useful for tests + */ + static ModelResolver withEnv(Function envLookup) { + return new ModelResolver(envLookup, preferredClassLoader()); + } + + /** + * Resolve the given model identifier to a concrete file path. + * + * @param model logical model id (e.g. {@code "bge-small-en-v1.5"}); may be null when {@code + * explicitPath} is provided + * @param explicitPath caller-supplied {@link Embedder.Builder#modelPath(Path)} value; may be null + * @return existing, readable file path + * @throws ModelNotFoundException if no candidate matches + */ + Path resolve(String model, Path explicitPath) { + List searched = new ArrayList<>(); + + // 1. Explicit path wins. + if (explicitPath != null) { + searched.add("explicit modelPath=" + explicitPath); + if (Files.isRegularFile(explicitPath)) { + return explicitPath.toAbsolutePath(); + } + throw new ModelNotFoundException( + model == null ? explicitPath.toString() : model, searched, discoverAvailableModels()); + } + + Objects.requireNonNull(model, "model id required when modelPath is unset"); + + // 2. INFERENCE_MODEL_DIR. + String envDir = envLookup.apply(ENV_MODEL_DIR); + if (envDir != null && !envDir.isBlank()) { + Path dir = Paths.get(envDir); + Path withExt = dir.resolve(model + ONNX_EXTENSION); + Path bare = dir.resolve(model); + searched.add(ENV_MODEL_DIR + "=" + envDir + " (looking for " + withExt + " or " + bare + ")"); + if (Files.isRegularFile(withExt)) { + return withExt.toAbsolutePath(); + } + if (Files.isRegularFile(bare)) { + return bare.toAbsolutePath(); + } + } else { + searched.add(ENV_MODEL_DIR + " (unset)"); + } + + // 3. Classpath: /models/.onnx. + String resourcePath = CLASSPATH_PREFIX + model + ONNX_EXTENSION; + searched.add("classpath:" + resourcePath); + URL classpathUrl = classLoader.getResource(resourcePath.substring(1)); + if (classpathUrl != null && "file".equals(classpathUrl.getProtocol())) { + // Loaded from the unpacked target/classes during dev — return directly. + try { + Path direct = Paths.get(classpathUrl.toURI()); + if (Files.isRegularFile(direct)) { + return direct.toAbsolutePath(); + } + } catch (Exception ignored) { + // Fall through to extraction. + } + } + if (classpathUrl != null) { + // Resource is inside a JAR; extract to temp. + Path extracted = extractClasspathResource(resourcePath); + if (extracted != null) { + return extracted; + } + } + + throw new ModelNotFoundException(model, searched, discoverAvailableModels()); + } + + /** + * Stream a classpath resource to a temp file. Returns {@code null} on failure (caller treats that + * as "not found" and surfaces a {@link ModelNotFoundException}). + */ + private Path extractClasspathResource(String resourcePath) { + String normalized = resourcePath.startsWith("/") ? resourcePath.substring(1) : resourcePath; + try (InputStream in = classLoader.getResourceAsStream(normalized)) { + if (in == null) { + return null; + } + String fileName = lastSegment(resourcePath); + Path tempDir = + Files.createTempDirectory( + String.format(Locale.ROOT, "inference-sdk-embed-%d-", ProcessHandle.current().pid())); + Path target = tempDir.resolve(fileName); + Files.copy(in, target); + // Best-effort cleanup on JVM shutdown. + target.toFile().deleteOnExit(); + tempDir.toFile().deleteOnExit(); + return target.toAbsolutePath(); + } catch (IOException ex) { + return null; + } + } + + /** + * Discover model ids visible on the classpath via {@code /models/model-manifest.properties} + * resources shipped by every {@code inference-sdk-embed-} JAR. + * + * @return alphabetically-sorted, distinct model ids; empty list if none visible + */ + List discoverAvailableModels() { + Set ids = new LinkedHashSet<>(); + try { + Enumeration manifests = classLoader.getResources(MANIFEST_RESOURCE); + while (manifests.hasMoreElements()) { + URL url = manifests.nextElement(); + try (InputStream in = url.openStream()) { + Properties p = new Properties(); + p.load(in); + String id = p.getProperty("id"); + if (id != null && !id.isBlank()) { + ids.add(id.trim()); + } + } catch (IOException ignored) { + // Skip malformed manifests; one bad jar shouldn't break discovery. + } + } + } catch (IOException ignored) { + // Classloader hiccup; best-effort discovery. + } + List out = new ArrayList<>(ids); + Collections.sort(out); + return out; + } + + private static ClassLoader preferredClassLoader() { + ClassLoader cl = Thread.currentThread().getContextClassLoader(); + return cl != null ? cl : ModelResolver.class.getClassLoader(); + } + + private static String lastSegment(String path) { + int idx = path.lastIndexOf('/'); + return idx < 0 ? path : path.substring(idx + 1); + } +} diff --git a/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/OnnxEmbedder.java b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/OnnxEmbedder.java new file mode 100644 index 0000000..578c77b --- /dev/null +++ b/java/inference-sdk-embed/src/main/java/io/github/randomcodespace/inference/embed/OnnxEmbedder.java @@ -0,0 +1,492 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.embed; + +import java.nio.LongBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.slf4j.Logger; + +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; +import io.github.randomcodespace.inference.ModelInfo; +import io.github.randomcodespace.inference.runtime.ContainerCpu; +import io.github.randomcodespace.inference.runtime.NativeExecutor; +import io.github.randomcodespace.inference.runtime.RequestId; + +/** + * Default {@link Embedder} implementation backed by an {@link OrtSession} and a {@link + * HuggingFaceTokenizer}. + * + *

Thread model

+ * + *

Tokenization runs synchronously on the calling thread (CPU-bound, low latency). The native + * {@code OrtSession.run} call is dispatched through the {@link NativeExecutor} platform-thread pool + * — this is the native-thread-pinning workaround documented in {@code docs/ARCHITECTURE.md} §3.3 + * and the class-level JavaDoc on {@link Embedder}. Async callers receive a {@link + * CompletableFuture} that resolves on a virtual-thread executor; awaiting it from a virtual thread + * yields the carrier correctly. + * + *

Test seam

+ * + *

The package-private constructor accepts pre-constructed collaborators ({@link SessionRunner}, + * tokenizer, executor, {@link ModelInfo}) so unit tests can substitute fakes without touching the + * native ONNX stack. The {@code create(Builder)} factory wires real implementations and is + * exercised in Tier 5 integration tests. + */ +final class OnnxEmbedder implements Embedder { + + /** + * Function-shaped abstraction over {@link OrtSession#run(Map)} so unit tests can supply a fake + * without instantiating an ONNX session. Wraps the only call shape we actually use. + */ + @FunctionalInterface + interface SessionRunner extends AutoCloseable { + + /** + * Run the session on a tokenized batch. + * + * @param inputIds shape {@code [batch, seqLen]} as a flattened long array, row-major + * @param attentionMask shape {@code [batch, seqLen]} as a flattened long array, row-major + * @param tokenTypeIds shape {@code [batch, seqLen]} as a flattened long array; may be {@code + * null} if the model doesn't take this input + * @param batch batch size + * @param seqLen padded sequence length + * @return one float vector per row in the batch + * @throws OrtException if the native call fails + */ + List run( + long[] inputIds, long[] attentionMask, long[] tokenTypeIds, int batch, int seqLen) + throws OrtException; + + /** Idempotent close. */ + @Override + default void close() {} + } + + private final SessionRunner runner; + private final HuggingFaceTokenizer tokenizer; + private final NativeExecutor executor; + private final ModelInfo modelInfo; + private final java.util.concurrent.ExecutorService asyncExecutor; + private final Logger log; + private final int batchSize; + private final boolean ownsAsyncExecutor; + private final AtomicBoolean closed = new AtomicBoolean(false); + + /** + * Production factory: resolve model + tokenizer paths, open the ONNX session, and start the + * native-thread pool. + * + * @throws ModelNotFoundException if no matching model is on the classpath / disk + * @throws EmbedException for any other initialization failure + */ + static OnnxEmbedder create(Embedder.Builder b) { + Logger log = b.getLogger(); + ModelResolver resolver = new ModelResolver(); + Path modelFile = resolver.resolve(b.getModel(), b.getModelPath()); + ModelInfo modelInfo = readModelInfo(b.getModel(), modelFile); + HuggingFaceTokenizer tk; + OrtSession session; + OrtEnvironment env = OrtEnvironment.getEnvironment(); + try { + Path tokenizerFile = resolveTokenizer(b.getTokenizerPath(), modelFile); + tk = + HuggingFaceTokenizer.builder() + .optTokenizerPath(tokenizerFile) + .optAddSpecialTokens(true) + .optTruncation(true) + .optMaxLength(modelInfo.maxTokens()) + .optPadToMaxLength() + .build(); + OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); + int threadCount = b.getThreads() > 0 ? b.getThreads() : ContainerCpu.detect(); + opts.setIntraOpNumThreads(threadCount); + opts.setInterOpNumThreads(1); + session = env.createSession(modelFile.toString(), opts); + } catch (Exception ex) { + throw new EmbedException( + "Failed to open ONNX session at " + modelFile + ": " + ex.getMessage(), ex); + } + int threadCount = b.getThreads() > 0 ? b.getThreads() : ContainerCpu.detect(); + NativeExecutor exec = NativeExecutor.sized(Math.max(1, threadCount), "embed-native"); + SessionRunner runner = realSessionRunner(env, session); + return new OnnxEmbedder(runner, tk, exec, modelInfo, log, b.getBatchSize(), null); + } + + /** + * Test-friendly constructor. + * + * @param runner abstraction over the ONNX session + * @param tokenizer DJL tokenizer; may be {@code null} when the runner does its own tokenization + * in tests + * @param executor platform-thread executor; never {@code null} + * @param modelInfo static model metadata; never {@code null} + * @param log SLF4J logger; never {@code null} + * @param batchSize maximum chunk size; must be {@code >= 1} + * @param asyncExecutor optional virtual-thread executor for {@code embedAsync}; if {@code null}, + * a fresh per-instance virtual-thread executor is created and owned by this embedder + */ + OnnxEmbedder( + SessionRunner runner, + HuggingFaceTokenizer tokenizer, + NativeExecutor executor, + ModelInfo modelInfo, + Logger log, + int batchSize, + java.util.concurrent.ExecutorService asyncExecutor) { + this.runner = Objects.requireNonNull(runner, "runner"); + this.tokenizer = tokenizer; + this.executor = Objects.requireNonNull(executor, "executor"); + this.modelInfo = Objects.requireNonNull(modelInfo, "modelInfo"); + this.log = Objects.requireNonNull(log, "log"); + if (batchSize < 1) { + throw new IllegalArgumentException("batchSize must be >= 1"); + } + this.batchSize = batchSize; + if (asyncExecutor == null) { + this.asyncExecutor = Executors.newVirtualThreadPerTaskExecutor(); + this.ownsAsyncExecutor = true; + } else { + this.asyncExecutor = asyncExecutor; + this.ownsAsyncExecutor = false; + } + } + + @Override + public EmbedResult embed(List texts) { + ensureOpen(); + validateInputs(texts); + long t0 = System.nanoTime(); + if (texts.isEmpty()) { + long total = elapsedMs(t0); + EmbedStats stats = + new EmbedStats( + RequestId.generate(), + 0L, + 0L, + 0L, + total, + 0, + EmbedStats.BATCH_POSITION_SINGLE, + modelInfo.revision(), + null); + return new EmbedResult(Collections.emptyList(), 0, stats); + } + + String requestId = currentRequestId(); + long tokenizeStart = System.nanoTime(); + long tokenizeMs = 0L; + long inferenceMs = 0L; + long totalTokens = 0L; + + List all = new ArrayList<>(texts.size()); + int i = 0; + while (i < texts.size()) { + int end = Math.min(texts.size(), i + batchSize); + List chunk = texts.subList(i, end); + + long ts = System.nanoTime(); + Encoding[] encodings = tokenizer.batchEncode(chunk.toArray(new String[0])); + tokenizeMs += elapsedMs(ts); + + int seqLen = 0; + for (Encoding e : encodings) { + seqLen = Math.max(seqLen, e.getIds().length); + totalTokens += countNonPadTokens(e); + } + + long[] inputIds = new long[encodings.length * seqLen]; + long[] attention = new long[encodings.length * seqLen]; + long[] typeIds = new long[encodings.length * seqLen]; + for (int b = 0; b < encodings.length; b++) { + long[] ids = encodings[b].getIds(); + long[] mask = encodings[b].getAttentionMask(); + long[] types = encodings[b].getTypeIds(); + System.arraycopy(ids, 0, inputIds, b * seqLen, Math.min(ids.length, seqLen)); + System.arraycopy(mask, 0, attention, b * seqLen, Math.min(mask.length, seqLen)); + if (types != null) { + System.arraycopy(types, 0, typeIds, b * seqLen, Math.min(types.length, seqLen)); + } + } + + long ts2 = System.nanoTime(); + List chunkVectors = + submitInference(inputIds, attention, typeIds, encodings.length, seqLen); + inferenceMs += elapsedMs(ts2); + for (float[] v : chunkVectors) { + BgeNormalizer.normalizeInPlace(v); + all.add(v); + } + i = end; + } + + long total = elapsedMs(t0); + EmbedStats stats = + new EmbedStats( + requestId, + 0L, + tokenizeMs, + inferenceMs, + total, + texts.size(), + EmbedStats.BATCH_POSITION_SINGLE, + modelInfo.revision(), + null); + int tokensInt = totalTokens > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) totalTokens; + return new EmbedResult(all, tokensInt, stats); + } + + @Override + public float[] embedOne(String text) { + if (text == null) { + throw new InvalidInputException("embedOne(null) is not permitted"); + } + EmbedResult r = embed(List.of(text)); + return r.vectors().get(0); + } + + @Override + public CompletableFuture embedAsync(List texts) { + ensureOpen(); + return CompletableFuture.supplyAsync(() -> embed(texts), asyncExecutor); + } + + @Override + public ModelInfo modelInfo() { + return modelInfo; + } + + @Override + public void close() { + if (!closed.compareAndSet(false, true)) { + return; + } + safeClose("runner", runner::close); + if (tokenizer != null) { + safeClose("tokenizer", tokenizer::close); + } + safeClose("nativeExecutor", executor::close); + if (ownsAsyncExecutor) { + asyncExecutor.close(); + } + } + + private void ensureOpen() { + if (closed.get()) { + throw new IllegalStateException("Embedder has been closed"); + } + } + + private static void validateInputs(List texts) { + if (texts == null) { + throw new InvalidInputException("texts list must not be null"); + } + for (int i = 0; i < texts.size(); i++) { + if (texts.get(i) == null) { + throw new InvalidInputException("texts[" + i + "] is null"); + } + } + } + + private List submitInference( + long[] inputIds, long[] attention, long[] typeIds, int batch, int seqLen) { + Callable> call = () -> runner.run(inputIds, attention, typeIds, batch, seqLen); + try { + return executor.submitNative(call).get(); + } catch (java.util.concurrent.ExecutionException ex) { + Throwable cause = ex.getCause() == null ? ex : ex.getCause(); + throw new EmbedException("Native ONNX inference failed: " + cause.getMessage(), cause); + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + throw new EmbedException("Interrupted while awaiting ONNX inference", ex); + } + } + + private static long countNonPadTokens(Encoding e) { + long[] mask = e.getAttentionMask(); + long n = 0L; + if (mask != null) { + for (long m : mask) { + if (m != 0L) { + n++; + } + } + } else { + n = e.getIds().length; + } + return n; + } + + private static long elapsedMs(long startNanos) { + return Math.max(0L, (System.nanoTime() - startNanos) / 1_000_000L); + } + + private static String currentRequestId() { + String existing = RequestId.CURRENT.isBound() ? RequestId.CURRENT.get() : null; + return existing != null ? existing : RequestId.generate(); + } + + private static Path resolveTokenizer(Path explicit, Path modelFile) { + if (explicit != null) { + return explicit; + } + Path sibling = modelFile.resolveSibling("tokenizer.json"); + if (Files.isRegularFile(sibling)) { + return sibling; + } + Path parent = modelFile.getParent(); + if (parent != null) { + Path inDir = parent.resolve("tokenizer.json"); + if (Files.isRegularFile(inDir)) { + return inDir; + } + } + throw new EmbedException( + "Tokenizer not found. Looked next to model file at " + + sibling + + " — set Embedder.Builder.tokenizerPath(Path) explicitly."); + } + + private static ModelInfo readModelInfo(String requestedId, Path modelFile) { + Path manifest = modelFile.resolveSibling("model-manifest.properties"); + if (!Files.isRegularFile(manifest)) { + // Reasonable defaults for unknown manifests; integration tests assert on shipped models. + return new ModelInfo( + requestedId == null ? "unknown" : requestedId, "unknown", "unknown", 384, 512); + } + Properties p = new Properties(); + try (var in = Files.newInputStream(manifest)) { + p.load(in); + } catch (Exception ex) { + throw new EmbedException( + "Failed to read model manifest " + manifest + ": " + ex.getMessage(), ex); + } + String id = p.getProperty("id", requestedId == null ? "unknown" : requestedId); + String revision = p.getProperty("revision", "unknown"); + String quant = p.getProperty("quantization", "unknown"); + int dims = parseIntOr(p.getProperty("dimensions"), 384); + int max = parseIntOr(p.getProperty("max_tokens"), 512); + return new ModelInfo(id, revision, quant, dims, max); + } + + private static int parseIntOr(String raw, int fallback) { + if (raw == null || raw.isBlank()) { + return fallback; + } + try { + return Integer.parseInt(raw.trim()); + } catch (NumberFormatException ex) { + return fallback; + } + } + + private static SessionRunner realSessionRunner(OrtEnvironment env, OrtSession session) { + Set inputNames = new LinkedHashSet<>(session.getInputNames()); + return new SessionRunner() { + @Override + public List run( + long[] inputIds, long[] attention, long[] typeIds, int batch, int seqLen) + throws OrtException { + long[] shape = new long[] {batch, seqLen}; + Map inputs = new HashMap<>(); + OnnxTensor t1 = OnnxTensor.createTensor(env, LongBuffer.wrap(inputIds), shape); + OnnxTensor t2 = OnnxTensor.createTensor(env, LongBuffer.wrap(attention), shape); + OnnxTensor t3 = OnnxTensor.createTensor(env, LongBuffer.wrap(typeIds), shape); + try { + inputs.put("input_ids", t1); + inputs.put("attention_mask", t2); + if (inputNames.contains("token_type_ids")) { + inputs.put("token_type_ids", t3); + } + try (OrtSession.Result result = session.run(inputs)) { + // BGE / sentence-transformer convention: first output is sentence embeddings or + // token-level last_hidden_state. We expect mean-pooled output to come out as the + // model's first "sentence_embedding" output or the standard last_hidden_state. + // Take the first output and average over the seqLen axis if rank-3. + Object value = result.get(0).getValue(); + return reduceToVectors(value, batch, seqLen); + } + } finally { + t1.close(); + t2.close(); + t3.close(); + } + } + + @Override + public void close() { + try { + session.close(); + } catch (OrtException ex) { + // Best-effort. + } + } + }; + } + + /** + * Reduce raw ONNX output to per-row dense vectors. Handles either rank-2 ({@code + * float[batch][hidden]}) or rank-3 ({@code float[batch][seqLen][hidden]}, mean-pooled across the + * sequence axis with attention-mask weighting handled upstream). + */ + static List reduceToVectors(Object raw, int batch, int seqLen) { + if (raw instanceof float[][] r2) { + List out = new ArrayList<>(r2.length); + Collections.addAll(out, r2); + return out; + } + if (raw instanceof float[][][] r3) { + List out = new ArrayList<>(r3.length); + for (float[][] row : r3) { + int hidden = row[0].length; + float[] mean = new float[hidden]; + for (float[] tok : row) { + for (int h = 0; h < hidden; h++) { + mean[h] += tok[h]; + } + } + float inv = 1.0f / row.length; + for (int h = 0; h < hidden; h++) { + mean[h] *= inv; + } + out.add(mean); + } + return out; + } + throw new EmbedException( + String.format( + Locale.ROOT, + "Unsupported ONNX output type %s; expected float[][] or float[][][]", + raw == null ? "null" : raw.getClass().getName())); + } + + private void safeClose(String name, AutoCloseable c) { + try { + c.close(); + } catch (Exception ex) { + log.debug("Failed to close {}: {}", name, ex.getMessage()); + } + } +} diff --git a/java/inference-sdk-embed/src/main/java/module-info.java b/java/inference-sdk-embed/src/main/java/module-info.java new file mode 100644 index 0000000..7f0fcd9 --- /dev/null +++ b/java/inference-sdk-embed/src/main/java/module-info.java @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ + +/** + * inference-sdk-embed: embedding API backed by ONNX Runtime + DJL HuggingFace tokenizers. + * + *

Consumers obtain an {@link io.github.randomcodespace.inference.embed.Embedder} through {@link + * io.github.randomcodespace.inference.embed.Embedder#builder()}. Every native ONNX session call is + * marshalled onto a platform-thread pool via {@code + * io.github.randomcodespace.inference.runtime.NativeExecutor} from the {@code inference-sdk-core} + * module — this is the native-thread-pinning workaround documented in {@code docs/ARCHITECTURE.md} + * §3.3. + * + *

Both ONNX Runtime ({@code com.microsoft.onnxruntime}) and DJL tokenizers ({@code + * ai.djl.tokenizers}) are auto-modules whose names come from their JAR manifests' {@code + * Automatic-Module-Name} attribute. + */ +module io.github.randomcodespace.inference.embed { + exports io.github.randomcodespace.inference.embed; + + requires io.github.randomcodespace.inference.core; + requires org.slf4j; + requires com.microsoft.onnxruntime; + requires ai.djl.tokenizers; + + // Jackson annotations are compile-time only; declared transitive so consumers + // who deserialize EmbedResult/EmbedStats see the @JsonProperty bindings. + requires static com.fasterxml.jackson.annotation; +} diff --git a/java/inference-sdk-generate-qwen-0_5b/pom.xml b/java/inference-sdk-generate-qwen-0_5b/pom.xml new file mode 100644 index 0000000..3905649 --- /dev/null +++ b/java/inference-sdk-generate-qwen-0_5b/pom.xml @@ -0,0 +1,144 @@ + + + + 4.0.0 + + + io.github.randomcodespace.inference + inference-sdk-parent + 0.1.0-SNAPSHOT + ../inference-sdk-parent/pom.xml + + + inference-sdk-generate-qwen-0_5b + jar + + inference-sdk-generate-qwen-0_5b + Qwen2.5-0.5B-Instruct (q4_K_M GGUF) bundled as a + classpath-resolvable Maven artifact for the inference-sdk + generate module. Apache-2.0 model weights, no Java code. + Weights are fetched at build time and embedded into the + published JAR (hybrid model-distribution; see pom comment + header). + + + true + true + true + + false + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + + true + + + + + + org.codehaus.mojo + exec-maven-plugin + + + + fetch-generation-model + generate-resources + + exec + + + ${fetch.models.skip} + python3 + + ${project.basedir}/../../scripts/fetch_models.py + --generation-model + qwen2.5-0.5b-instruct + --skip-embedding + --output-dir + ${project.basedir}/.. + + + + + + verify-generation-model + process-resources + + exec + + + ${fetch.models.skip} + python3 + + ${project.basedir}/../../scripts/verify_models.py + + + + + + + + diff --git a/java/inference-sdk-generate-qwen-0_5b/src/main/resources/models/.gitkeep b/java/inference-sdk-generate-qwen-0_5b/src/main/resources/models/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/java/inference-sdk-generate-qwen-0_5b/src/main/resources/models/model-manifest.properties b/java/inference-sdk-generate-qwen-0_5b/src/main/resources/models/model-manifest.properties new file mode 100644 index 0000000..8ed7a7f --- /dev/null +++ b/java/inference-sdk-generate-qwen-0_5b/src/main/resources/models/model-manifest.properties @@ -0,0 +1,9 @@ +# inference-sdk model manifest — qwen2.5-0.5b-instruct +# Populated by scripts/fetch_models.py (Tier 0.5). See docs/MODEL_REGISTRY.md §2.2. +id=qwen2.5-0.5b-instruct +hf_repo=Qwen/Qwen2.5-0.5B-Instruct +revision=main +quantization=q4_K_M +max_tokens=32768 +sha256= +license=Apache-2.0 diff --git a/java/inference-sdk-generate/.jqwik-database b/java/inference-sdk-generate/.jqwik-database new file mode 100644 index 0000000..711006c Binary files /dev/null and b/java/inference-sdk-generate/.jqwik-database differ diff --git a/java/inference-sdk-generate/pom.xml b/java/inference-sdk-generate/pom.xml new file mode 100644 index 0000000..f26cbcc --- /dev/null +++ b/java/inference-sdk-generate/pom.xml @@ -0,0 +1,183 @@ + + + + 4.0.0 + + + io.github.randomcodespace.inference + inference-sdk-parent + 0.1.0-SNAPSHOT + ../inference-sdk-parent/pom.xml + + + inference-sdk-generate + jar + + inference-sdk-generate + Generation API for inference-sdk: Generator interface + backed by the published de.kherud:llama 4.2.0 artifact (bundled + llama.cpp), executing native calls on platform threads via + inference-sdk-core's NativeExecutor. Streaming via + java.util.concurrent.Flow.Publisher with backpressure and + next-token-boundary cancellation. + + + + + io.github.randomcodespace.inference + inference-sdk-core + ${project.version} + + + + + de.kherud + llama + + + + + org.slf4j + slf4j-api + + + + + com.fasterxml.jackson.core + jackson-annotations + + + + + org.junit.jupiter + junit-jupiter + test + + + org.junit.jupiter + junit-jupiter-params + test + + + org.assertj + assertj-core + test + + + org.awaitility + awaitility + test + + + net.jqwik + jqwik + test + + + ch.qos.logback + logback-classic + test + + + + + + + + org.jacoco + jacoco-maven-plugin + + + jacoco-prepare-agent + + prepare-agent + + + + io/github/randomcodespace/inference/generate/KherudGenerator$KherudLlamaClient*.class + + + + + jacoco-report + + report + + + + io/github/randomcodespace/inference/generate/KherudGenerator$KherudLlamaClient*.class + + + + + jacoco-check + + check + + + ${skipTests} + + io/github/randomcodespace/inference/generate/KherudGenerator$KherudLlamaClient*.class + + + + BUNDLE + + + LINE + COVEREDRATIO + ${jacoco.line.minimum} + + + BRANCH + COVEREDRATIO + ${jacoco.branch.minimum} + + + + + + + + + + + diff --git a/java/inference-sdk-generate/spotbugs-exclude.xml b/java/inference-sdk-generate/spotbugs-exclude.xml new file mode 100644 index 0000000..670d687 --- /dev/null +++ b/java/inference-sdk-generate/spotbugs-exclude.xml @@ -0,0 +1,54 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/BoundedSubscription.java b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/BoundedSubscription.java new file mode 100644 index 0000000..21973db --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/BoundedSubscription.java @@ -0,0 +1,357 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Locale; +import java.util.Objects; +import java.util.concurrent.Flow; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; + +import org.slf4j.Logger; + +import io.github.randomcodespace.inference.FinishReason; +import io.github.randomcodespace.inference.ModelInfo; +import io.github.randomcodespace.inference.Usage; +import io.github.randomcodespace.inference.runtime.NativeExecutor; +import io.github.randomcodespace.inference.runtime.RequestId; + +/** + * Backpressure-honouring {@link Flow.Subscription} wrapping a kherud-style streaming iterator. + * + *

Contract (per {@code java-sdk.md} §7)

+ * + *
    + *
  • {@link #request(long) request(n)} starts native generation on first call (lazy); + * subsequent calls add demand. Negative or zero {@code n} signals an error per the + * Reactive-Streams rule §3.9. + *
  • {@link #cancel()} stops native generation at the next token boundary, emits exactly one + * terminal chunk with {@code finishReason == Canceled}, then {@code onComplete()}. + *
  • Mid-stream native failure → {@code onError(t)} with no terminal chunk preceding. + *
  • If the subscriber never calls {@code request}, no native work runs and {@code cancel()} + * releases all resources cleanly. + *
+ * + *

Threading

+ * + *

Native generation runs on the {@link NativeExecutor} platform-thread pool (the + * native-thread-pinning workaround documented in {@code docs/ARCHITECTURE.md} §3.3 and on the + * {@link Generator} class JavaDoc). Demand bookkeeping is serialised on this object's monitor to + * keep the state machine simple — emission throughput is bounded by the network of {@code + * onNext} calls, not by lock contention. + */ +final class BoundedSubscription implements Flow.Subscription { + + private final GenerateRequest req; + private final Flow.Subscriber subscriber; + private final KherudGenerator.LlamaClient client; + private final NativeExecutor executor; + private final Semaphore queue; + private final ModelInfo modelInfo; + private final int contextSize; + private final int streamBufferSize; + private final Function renderer; + private final Logger log; + private final String requestId; + + private final AtomicLong demand = new AtomicLong(0L); + private final AtomicBoolean started = new AtomicBoolean(false); + private final AtomicBoolean cancelled = new AtomicBoolean(false); + private final AtomicBoolean terminated = new AtomicBoolean(false); + private final AtomicBoolean queueAcquired = new AtomicBoolean(false); + + /** Worker-side buffer of produced chunks awaiting demand. */ + private final Deque buffer = new ArrayDeque<>(); + + /** Worker-thread liveness — set to non-null by the worker on first request(n). */ + private volatile Thread workerThread; + + /** Non-null once the streaming iterator has been opened. */ + private volatile KherudGenerator.StreamingIterator iterator; + + /** Wall time at first request(n); used for stats.totalMs. */ + private volatile long startNanos; + + /** Monitor for buffer + demand coordination between the worker and request(n) callers. */ + private final Object lock = new Object(); + + BoundedSubscription( + GenerateRequest req, + Flow.Subscriber subscriber, + KherudGenerator.LlamaClient client, + NativeExecutor executor, + Semaphore queue, + ModelInfo modelInfo, + int contextSize, + int streamBufferSize, + Function renderer, + Logger log) { + this.req = Objects.requireNonNull(req, "req"); + this.subscriber = Objects.requireNonNull(subscriber, "subscriber"); + this.client = Objects.requireNonNull(client, "client"); + this.executor = Objects.requireNonNull(executor, "executor"); + this.queue = Objects.requireNonNull(queue, "queue"); + this.modelInfo = Objects.requireNonNull(modelInfo, "modelInfo"); + this.contextSize = contextSize; + this.streamBufferSize = streamBufferSize; + this.renderer = Objects.requireNonNull(renderer, "renderer"); + this.log = Objects.requireNonNull(log, "log"); + this.requestId = + RequestId.CURRENT.isBound() ? RequestId.CURRENT.get() : RequestId.generate(); + } + + @Override + public void request(long n) { + if (terminated.get()) { + return; + } + if (n <= 0L) { + // Reactive-Streams §3.9: signal IllegalArgumentException via onError. + terminate( + () -> + subscriber.onError( + new IllegalArgumentException( + "Flow.Subscription.request(n): n must be positive (RS §3.9)"))); + return; + } + addDemand(n); + if (started.compareAndSet(false, true)) { + // Acquire one queue permit on first request(n). If we cannot, fail fast — matches + // edge-case #33 (third stream when queueDepth=1 throws QueueFullException via onError). + if (!queue.tryAcquire()) { + terminate( + () -> + subscriber.onError( + new QueueFullException( + String.format( + Locale.ROOT, + "Generator stream queue is full; reject before native start " + + "(streamBufferSize=%d)", + streamBufferSize)))); + return; + } + queueAcquired.set(true); + // NATIVE-PINNING WORKAROUND: launch the streaming worker on the platform-thread pool. + // Direct invocation from a virtual thread would either pin the carrier (older JDKs) or + // expose stale per-thread llama.cpp state to the next caller. See ARCHITECTURE.md §3.3. + executor.submitNative( + () -> { + workerThread = Thread.currentThread(); + startNanos = System.nanoTime(); + runWorker(); + return null; + }); + } else { + // Already running: poke the worker so it can drain into the subscriber up to current demand. + drain(); + } + } + + @Override + public void cancel() { + if (!cancelled.compareAndSet(false, true)) { + return; + } + KherudGenerator.StreamingIterator it = iterator; + if (it != null) { + try { + it.cancel(); + } catch (RuntimeException ex) { + log.debug("iterator cancel failed: {}", ex.getMessage()); + } + } + // Release the queue permit even if the worker hasn't run yet. + if (queueAcquired.compareAndSet(true, false)) { + queue.release(); + } + // If the worker hasn't started, terminate now with a Canceled terminal chunk per §7. + if (!started.get() || workerThread == null) { + // emitCanceledTerminalAndComplete drives the terminated flag itself; do NOT wrap in + // terminate() here, otherwise the inner onNext/onComplete signal is skipped because + // the outer compareAndSet already won. + emitCanceledTerminalAndComplete(0, 0); + } else { + // Worker will detect the cancel flag at the next iteration boundary and emit the + // canceled-terminal chunk itself. Wake it up if it is parked on the lock. + synchronized (lock) { + lock.notifyAll(); + } + } + } + + // -- worker ------------------------------------------------------------------- + + private void runWorker() { + long firstTokenNanos = -1L; + int promptTokens = 0; + int completionTokens = 0; + FinishReason reason = null; + try { + String prompt = renderer.apply(req); + KherudGenerator.StreamingIterator it = client.stream(prompt, req); + this.iterator = it; + try { + while (!cancelled.get() && it.hasNext()) { + KherudGenerator.StreamingChunk chunk = it.next(); + promptTokens = chunk.promptTokens(); + completionTokens = chunk.completionTokens(); + if (firstTokenNanos < 0L) { + firstTokenNanos = System.nanoTime(); + } + if (chunk.last()) { + reason = chunk.finishReason() == null ? new FinishReason.Eos() : chunk.finishReason(); + // Emit the body of the last chunk first (if non-empty), then the terminal chunk. + if (chunk.delta() != null && !chunk.delta().isEmpty()) { + enqueue(new GenerateChunk(chunk.delta(), false, null, null, null)); + } + break; + } + enqueue(new GenerateChunk(chunk.delta() == null ? "" : chunk.delta(), false, null, null, + null)); + } + } finally { + try { + it.close(); + } catch (Exception ignored) { + // best-effort + } + } + if (cancelled.get()) { + emitCanceledTerminalAndComplete(promptTokens, completionTokens); + return; + } + if (reason == null) { + reason = new FinishReason.Eos(); + } + // Terminal chunk with full stats. + long totalMs = elapsedMs(startNanos); + long firstTokenMs = firstTokenNanos < 0L ? 0L : (firstTokenNanos - startNanos) / 1_000_000L; + Usage usage = new Usage(promptTokens, completionTokens, promptTokens + completionTokens); + GenerateStats stats = + new GenerateStats( + requestId, + 0L, + 0L, + firstTokenMs, + totalMs, + totalMs, + completionTokens == 0 || totalMs <= 0L + ? 0.0d + : (completionTokens * 1000.0d) / (double) totalMs, + reason, + Math.min(contextSize, promptTokens + completionTokens), + contextSize, + modelInfo.revision(), + null); + enqueue(new GenerateChunk("", true, reason, usage, stats)); + drain(); + terminate(subscriber::onComplete); + } catch (RuntimeException ex) { + log.debug("stream worker failed: {}", ex.getMessage()); + terminate(() -> subscriber.onError(ex)); + } finally { + if (queueAcquired.compareAndSet(true, false)) { + queue.release(); + } + } + } + + private void emitCanceledTerminalAndComplete(int promptTokens, int completionTokens) { + long totalMs = startNanos == 0L ? 0L : elapsedMs(startNanos); + Usage usage = new Usage(promptTokens, completionTokens, promptTokens + completionTokens); + GenerateStats stats = + new GenerateStats( + requestId, + 0L, + 0L, + 0L, + totalMs, + totalMs, + 0.0d, + new FinishReason.Canceled(), + Math.min(contextSize, promptTokens + completionTokens), + contextSize, + modelInfo.revision(), + null); + GenerateChunk terminal = + new GenerateChunk("", true, new FinishReason.Canceled(), usage, stats); + if (terminated.compareAndSet(false, true)) { + try { + subscriber.onNext(terminal); + subscriber.onComplete(); + } catch (RuntimeException ex) { + // Subscriber threw — nothing more to do; log and drop. + log.debug("subscriber threw on canceled terminal: {}", ex.getMessage()); + } + } + } + + private void enqueue(GenerateChunk chunk) { + synchronized (lock) { + buffer.addLast(chunk); + } + drain(); + } + + /** Push as many buffered chunks as the subscriber's demand allows. */ + private void drain() { + while (true) { + GenerateChunk next; + synchronized (lock) { + if (terminated.get() || buffer.isEmpty()) { + return; + } + long d = demand.get(); + if (d <= 0L) { + return; + } + next = buffer.pollFirst(); + if (next == null) { + return; + } + demand.decrementAndGet(); + } + try { + subscriber.onNext(next); + } catch (RuntimeException ex) { + // Subscriber failure is fatal per Reactive-Streams §2.13. + terminate(() -> {}); + log.debug("subscriber.onNext threw: {}", ex.getMessage()); + return; + } + } + } + + private void addDemand(long n) { + while (true) { + long cur = demand.get(); + long next = cur + n; + if (next < cur) { + next = Long.MAX_VALUE; // saturate per RS rule §3.17 + } + if (demand.compareAndSet(cur, next)) { + return; + } + } + } + + private void terminate(Runnable terminalCall) { + if (terminated.compareAndSet(false, true)) { + try { + terminalCall.run(); + } catch (RuntimeException ex) { + log.debug("terminal signal threw: {}", ex.getMessage()); + } + } + } + + private static long elapsedMs(long startNanos) { + return Math.max(0L, (System.nanoTime() - startNanos) / 1_000_000L); + } +} diff --git a/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/FeatureNotSupportedException.java b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/FeatureNotSupportedException.java new file mode 100644 index 0000000..03b3c7b --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/FeatureNotSupportedException.java @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +/** + * Thrown when a Phase-2 reserved field on {@link Message}, {@link GenerateRequest}, or {@link + * GenerateResponse} is populated in a Phase-1 SDK build. + * + *

Reserved fields (see {@code java-sdk.md} §14): + * + *

    + *
  • {@link Message}: {@code toolCalls}, {@code toolCallId}, {@code name}. + *
  • {@link GenerateRequest}: {@code tools}, {@code toolChoice}, {@code responseFormat}. + *
  • {@link GenerateResponse}: {@code systemFingerprint}. + *
+ * + *

This is a forward-compat guardrail: the wire format already names these fields so Phase-2 + * HTTP serialization can route them to tool-calling implementations without an API break, but the + * Phase-1 implementation rejects them at the boundary so callers cannot quietly assume support. + */ +public class FeatureNotSupportedException extends GenerateException { + + private static final long serialVersionUID = 1L; + + /** + * Construct an exception describing the reserved-field violation. + * + * @param message human-readable description naming the offending field(s) + */ + public FeatureNotSupportedException(String message) { + super(message); + } +} diff --git a/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateChunk.java b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateChunk.java new file mode 100644 index 0000000..0cbe17c --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateChunk.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.github.randomcodespace.inference.FinishReason; +import io.github.randomcodespace.inference.Usage; + +/** + * One element of a streaming {@link Generator#stream(GenerateRequest)} response. + * + *

Streaming contract (per {@code java-sdk.md} §7): + * + *

    + *
  • {@code delta} carries an incremental token segment, never the cumulative + * text. + *
  • Exactly one terminal chunk per stream has {@code done == true}; its {@code delta} may be + * the empty string. The terminal chunk carries {@code finishReason}, full {@code usage}, and + * full {@code stats}; non-terminal chunks have all three set to {@code null}. + *
  • {@code Subscriber.onComplete()} fires after the terminal chunk; mid-stream failures arrive + * as {@code Subscriber.onError(Throwable)} without a terminal chunk preceding. + *
+ * + * @param delta incremental text fragment; non-null (may be empty on terminal) + * @param done {@code true} only on the terminal chunk + * @param finishReason terminal {@link FinishReason}; {@code null} on non-terminal chunks + * @param usage final token counts; {@code null} on non-terminal chunks + * @param stats final per-request telemetry; {@code null} on non-terminal chunks + */ +public record GenerateChunk( + String delta, + boolean done, + @JsonProperty("finish_reason") FinishReason finishReason, + Usage usage, + GenerateStats stats) {} diff --git a/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateException.java b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateException.java new file mode 100644 index 0000000..0c78259 --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateException.java @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +/** + * Root unchecked exception type for all failures originating in the {@code inference-sdk-generate} + * module. + * + *

Subclasses carry more specific failure modes: + * + *

    + *
  • {@link QueueFullException} — the bounded native queue is at capacity. + *
  • {@link ModelNotLoadedException} — the configured model could not be resolved. + *
  • {@link FeatureNotSupportedException} — the request used a Phase-2 reserved field. + *
+ * + *

{@code NativeLoadException} (for native-library extraction failures) lives in {@code + * inference-sdk-core} and is rethrown unchanged from this module. + */ +public class GenerateException extends RuntimeException { + + private static final long serialVersionUID = 1L; + + /** + * Construct a new exception with a human-readable message. + * + * @param message non-null description + */ + public GenerateException(String message) { + super(message); + } + + /** + * Construct a new exception with a message and a wrapped cause. + * + * @param message non-null description + * @param cause underlying cause; may be {@code null} + */ + public GenerateException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateRequest.java b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateRequest.java new file mode 100644 index 0000000..d4c97c7 --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateRequest.java @@ -0,0 +1,225 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import java.util.List; +import java.util.Locale; +import java.util.Objects; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Single text-generation request submitted to {@link Generator#complete(GenerateRequest)}, + * {@link Generator#completeAsync(GenerateRequest)}, or {@link Generator#stream(GenerateRequest)}. + * + *

Wire format: {@code snake_case} per {@code docs/WIRE_FORMAT.md}; {@code @JsonProperty} + * annotations lock the JSON keys for forward compatibility with the Phase-2 HTTP layer (Phase 1 + * does not serialize this record over the wire). + * + *

Validation (compact constructor)

+ * + *
    + *
  • {@code messages} must be non-null and non-empty (case #14). + *
  • {@code maxTokens} must be {@code > 0} (case #13). + *
  • {@code temperature} ∈ {@code [0, 2]}. + *
  • {@code topP} ∈ {@code (0, 1]}. + *
  • At least one {@code role == "user"} message must be present (case #15). + *
  • Reserved fields {@code tools}, {@code toolChoice}, {@code responseFormat} must be + * {@code null}; any non-null value raises {@link FeatureNotSupportedException} (case #50). + *
+ * + *

Use the {@linkplain #builder() fluent builder} to construct instances ergonomically. + * + * @param messages chat history, oldest first; non-null, non-empty, must include a user message + * @param maxTokens upper bound on completion tokens; must be {@code > 0} + * @param temperature sampling temperature; {@code [0, 2]} + * @param topP nucleus-sampling cutoff; {@code (0, 1]} + * @param stop optional stop strings; may be {@code null} or empty + * @param seed optional RNG seed; deterministic when paired with {@code temperature == 0} + * @param tools reserved for Phase 2 — must be {@code null} + * @param toolChoice reserved for Phase 2 — must be {@code null} + * @param responseFormat reserved for Phase 2 — must be {@code null} + */ +public record GenerateRequest( + List messages, + @JsonProperty("max_tokens") int maxTokens, + float temperature, + @JsonProperty("top_p") float topP, + List stop, + Long seed, + List tools, + @JsonProperty("tool_choice") Object toolChoice, + @JsonProperty("response_format") Object responseFormat) { + + /** Default sampling temperature if the builder is not customized. */ + public static final float DEFAULT_TEMPERATURE = 0.7f; + + /** Default top-p cutoff if the builder is not customized. */ + public static final float DEFAULT_TOP_P = 0.95f; + + /** + * Compact constructor: enforces the bounds documented above and rejects reserved fields. + * + * @throws IllegalArgumentException for any out-of-range or structural violation + * @throws FeatureNotSupportedException if any reserved field is non-null + */ + public GenerateRequest { + if (messages == null || messages.isEmpty()) { + throw new IllegalArgumentException("messages required and must be non-empty"); + } + if (maxTokens <= 0) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "maxTokens must be > 0, got %d", maxTokens)); + } + if (Float.isNaN(temperature) || temperature < 0f || temperature > 2f) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "temperature must be in [0, 2], got %f", temperature)); + } + if (Float.isNaN(topP) || topP <= 0f || topP > 1f) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "topP must be in (0, 1], got %f", topP)); + } + if (tools != null || toolChoice != null || responseFormat != null) { + throw new FeatureNotSupportedException( + "tools/toolChoice/responseFormat are reserved for Phase 2"); + } + boolean hasUser = false; + for (Message m : messages) { + if (m == null) { + throw new IllegalArgumentException("messages must not contain null entries"); + } + if ("user".equals(m.role())) { + hasUser = true; + } + } + if (!hasUser) { + throw new IllegalArgumentException("at least one user message is required"); + } + // Defensive copies so the record is observably immutable. + messages = List.copyOf(messages); + if (stop != null) { + stop = List.copyOf(stop); + } + } + + /** + * @return a fresh builder seeded with sensible defaults + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Fluent builder for {@link GenerateRequest}. Each setter validates eagerly so misconfiguration + * surfaces at the call site, not inside {@link #build()}. + * + *

Defaults: {@code temperature = 0.7}, {@code topP = 0.95}, no stop strings, no seed. + * Reserved fields are not exposed by the builder; populate them with the canonical record + * constructor if Phase 2 implementations must. + */ + public static final class Builder { + + private List messages; + private int maxTokens = 256; + private float temperature = DEFAULT_TEMPERATURE; + private float topP = DEFAULT_TOP_P; + private List stop; + private Long seed; + + private Builder() {} + + /** + * Set the chat history. + * + * @param messages non-null, non-empty list including at least one user message + * @return this builder + */ + public Builder messages(List messages) { + this.messages = Objects.requireNonNull(messages, "messages"); + return this; + } + + /** + * Set the maximum number of completion tokens. + * + * @param n must be {@code > 0} + * @return this builder + * @throws IllegalArgumentException if {@code n <= 0} + */ + public Builder maxTokens(int n) { + if (n <= 0) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "maxTokens must be > 0, got %d", n)); + } + this.maxTokens = n; + return this; + } + + /** + * Set the sampling temperature. + * + * @param t must be in {@code [0, 2]} + * @return this builder + * @throws IllegalArgumentException if {@code t} is out of range or NaN + */ + public Builder temperature(float t) { + if (Float.isNaN(t) || t < 0f || t > 2f) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "temperature must be in [0, 2], got %f", t)); + } + this.temperature = t; + return this; + } + + /** + * Set the nucleus-sampling cutoff. + * + * @param p must be in {@code (0, 1]} + * @return this builder + * @throws IllegalArgumentException if {@code p} is out of range or NaN + */ + public Builder topP(float p) { + if (Float.isNaN(p) || p <= 0f || p > 1f) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "topP must be in (0, 1], got %f", p)); + } + this.topP = p; + return this; + } + + /** + * Set the stop strings; copy is taken on {@link #build()}. + * + * @param stop may be {@code null} or empty + * @return this builder + */ + public Builder stop(List stop) { + this.stop = stop; + return this; + } + + /** + * Set the RNG seed. + * + * @param seed any long; pair with {@code temperature == 0} for determinism (case #18) + * @return this builder + */ + public Builder seed(long seed) { + this.seed = seed; + return this; + } + + /** + * Build the {@link GenerateRequest}. Final validation runs in the record's compact + * constructor. + * + * @return a fully validated request + */ + public GenerateRequest build() { + return new GenerateRequest( + messages, maxTokens, temperature, topP, stop, seed, null, null, null); + } + } +} diff --git a/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateResponse.java b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateResponse.java new file mode 100644 index 0000000..27d71ac --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateResponse.java @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.github.randomcodespace.inference.FinishReason; +import io.github.randomcodespace.inference.Usage; + +/** + * Result of a non-streaming {@link Generator#complete(GenerateRequest)} call. + * + *

Wire format: {@code snake_case} per {@code docs/WIRE_FORMAT.md}. + * + *

Phase-1 constraints

+ * + *

{@code systemFingerprint} is reserved for Phase-2 OpenAI compatibility and is always {@code + * null} in Phase 1. Direct construction with a non-null value is permitted (Phase-2 producers + * inject it) but Phase-1 implementations must produce {@code null} (case #51 placeholder; tested + * via the integration suite in Tier 5). + * + * @param text full generated completion (no tokenization artefacts); never {@code null} + * @param finishReason why generation stopped; never {@code null} + * @param usage prompt/completion/total token counts; never {@code null} + * @param stats per-request telemetry; never {@code null} + * @param systemFingerprint reserved for Phase 2 — {@code null} in Phase 1 + */ +public record GenerateResponse( + String text, + @JsonProperty("finish_reason") FinishReason finishReason, + Usage usage, + GenerateStats stats, + @JsonProperty("system_fingerprint") String systemFingerprint) {} diff --git a/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateStats.java b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateStats.java new file mode 100644 index 0000000..e19fbc7 --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/GenerateStats.java @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.github.randomcodespace.inference.FinishReason; + +/** + * Per-request telemetry attached to {@link GenerateResponse} and the terminal {@link + * GenerateChunk}. + * + *

Wire format: {@code snake_case} per {@code docs/WIRE_FORMAT.md}; carried into the Phase-2 + * HTTP response under the {@code x_stats} extension. + * + *

All durations are in milliseconds. {@code tokensPerSecond} is computed as {@code + * completionTokens * 1000.0 / generationMs}, defaulting to {@code 0} when {@code generationMs == + * 0}. + * + * @param requestId stable request id (e.g. {@code req_}); never {@code null} + * @param queueMs wall time spent waiting for a worker + * @param promptEvalMs wall time spent encoding + KV-cache priming + * @param firstTokenMs wall time from request acceptance to first token + * @param generationMs wall time from first token to last token + * @param totalMs end-to-end wall time + * @param tokensPerSecond observed throughput + * @param stopReason terminal {@link FinishReason}; never {@code null} + * @param contextUsed tokens occupied in the KV cache after the response + * @param contextMax model's maximum context window + * @param modelRevision content hash / tag of the model file used + * @param node optional fixed-format node identifier (e.g. hostname); may be {@code null} + */ +public record GenerateStats( + @JsonProperty("request_id") String requestId, + @JsonProperty("queue_ms") long queueMs, + @JsonProperty("prompt_eval_ms") long promptEvalMs, + @JsonProperty("first_token_ms") long firstTokenMs, + @JsonProperty("generation_ms") long generationMs, + @JsonProperty("total_ms") long totalMs, + @JsonProperty("tokens_per_second") double tokensPerSecond, + @JsonProperty("stop_reason") FinishReason stopReason, + @JsonProperty("context_used") int contextUsed, + @JsonProperty("context_max") int contextMax, + @JsonProperty("model_revision") String modelRevision, + String node) {} diff --git a/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/Generator.java b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/Generator.java new file mode 100644 index 0000000..90d3048 --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/Generator.java @@ -0,0 +1,317 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import java.nio.file.Path; +import java.util.Locale; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Flow; + +import org.slf4j.Logger; +import org.slf4j.helpers.NOPLogger; + +import io.github.randomcodespace.inference.ModelInfo; + +/** + * Public generation API. Produces text completions from a list of chat-style messages using a + * local llama.cpp-backed model loaded via the {@code de.kherud:llama:4.2.0} JNI binding. + * + *

Construct via the fluent {@link #builder()}: + * + *

{@code
+ * try (Generator g = Generator.builder()
+ *         .model("qwen2.5-0.5b-instruct")
+ *         .threads(4)
+ *         .contextSize(2048)
+ *         .build()) {
+ *   GenerateRequest req = GenerateRequest.builder()
+ *           .messages(List.of(new Message("user", "hi")))
+ *           .maxTokens(64)
+ *           .build();
+ *   GenerateResponse r = g.complete(req);
+ * }
+ * }
+ * + *

Thread-safety and the native-thread-pinning workaround

+ * + *

All methods on this interface are thread-safe and may be called from virtual threads. + * Internal implementations route every JNI call into llama.cpp through {@code + * io.github.randomcodespace.inference.runtime.NativeExecutor}, which trampolines work onto a + * platform-thread pool — see {@code docs/ARCHITECTURE.md} §3.3 for the rationale (llama.cpp pins + * the carrier thread; submitting native work to a virtual-thread executor would either pin the + * carrier or expose stale per-thread state). Caller virtual threads await the resulting {@link + * CompletableFuture} or process the {@link Flow.Publisher} normally. + * + *

Streaming

+ * + *

{@link #stream(GenerateRequest)} returns a {@link Flow.Publisher} that honours backpressure + * via {@code Subscription.request(n)}, supports {@code cancel()} at the next-token boundary, and + * emits exactly one terminal chunk ({@code done == true}) before {@code onComplete()}. See {@code + * java-sdk.md} §7 for the full contract. + * + *

Lifecycle

+ * + *

{@link #close()} is idempotent. Subsequent calls to {@code complete*} / {@code stream} after + * close throw {@link IllegalStateException}. + * + * @see GenerateRequest + * @see GenerateResponse + * @see GenerateChunk + */ +public interface Generator extends AutoCloseable { + + /** + * Run a generation request synchronously. Blocks the caller until generation completes. + * + * @param req validated request; never {@code null} + * @return full completion plus telemetry + * @throws IllegalArgumentException if {@code req} is null + * @throws IllegalStateException if this generator has been {@linkplain #close() closed} + * @throws QueueFullException if the bounded native queue is full + * @throws GenerateException for any other failure (model error, native fault) + */ + GenerateResponse complete(GenerateRequest req); + + /** + * Run a generation request asynchronously. The returned future completes on the SDK's + * virtual-thread executor; the underlying native call runs on a platform thread (see class + * JavaDoc on the native-thread-pinning workaround). + * + * @param req validated request; never {@code null} + * @return future yielding the response; completes exceptionally with the same exceptions as + * {@link #complete(GenerateRequest)} + */ + CompletableFuture completeAsync(GenerateRequest req); + + /** + * Stream a generation request as incremental {@link GenerateChunk}s. Honours + * {@code Subscription.request(n)} backpressure and {@code Subscription.cancel()} at the + * next-token boundary; emits exactly one terminal chunk before {@code onComplete()}. + * + * @param req validated request; never {@code null} + * @return cold publisher; subscribing triggers the request, never re-subscribes + */ + Flow.Publisher stream(GenerateRequest req); + + /** + * @return static {@link ModelInfo} for the loaded model + */ + ModelInfo modelInfo(); + + /** + * Release the native llama.cpp model handle and the platform-thread executor. Idempotent. + * + *

In-flight requests submitted before {@code close()} will run to completion; submissions + * after close raise {@link IllegalStateException}. Native resources are guaranteed to + * be freed exactly once, even under concurrent {@code close()} calls. + */ + @Override + void close(); + + /** + * @return a fresh builder; configuration is per-builder and never shared + */ + static Builder builder() { + return new Builder(); + } + + /** + * Fluent builder for {@link Generator}. Mutator methods return {@code this} so calls chain. Each + * mutator validates eagerly so misconfiguration surfaces at the call site rather than inside + * {@link #build()}. + * + *

Default values: + * + *

    + *
  • {@code threads = 0} → resolved at build time via {@code ContainerCpu.detect()} + *
  • {@code contextSize = 2048} + *
  • {@code queueDepth = 32} + *
  • {@code streamBufferSize = 16} + *
  • {@code logger = NOPLogger.NOP_LOGGER} + *
+ * + *

Environment-variable fallbacks (per {@code java-sdk.md} §9): {@code INFERENCE_GEN_THREADS}, + * {@code INFERENCE_GEN_CONTEXT_SIZE}, {@code INFERENCE_GEN_QUEUE_DEPTH}, {@code + * INFERENCE_MODEL_DIR}. + */ + final class Builder { + + /** Default context window if unset by the caller. */ + public static final int DEFAULT_CONTEXT_SIZE = 2048; + + /** Default queue depth if unset by the caller. */ + public static final int DEFAULT_QUEUE_DEPTH = 32; + + /** Default streaming buffer size if unset by the caller. */ + public static final int DEFAULT_STREAM_BUFFER_SIZE = 16; + + private String model; + private Path modelPath; + private int threads; + private int contextSize = DEFAULT_CONTEXT_SIZE; + private int queueDepth = DEFAULT_QUEUE_DEPTH; + private int streamBufferSize = DEFAULT_STREAM_BUFFER_SIZE; + private Logger logger = NOPLogger.NOP_LOGGER; + + Builder() {} + + /** + * Set the logical model id. Must be non-null and non-blank. + * + * @param name model id, e.g. {@code "qwen2.5-0.5b-instruct"} + * @return this builder for chaining + * @throws IllegalArgumentException if {@code name} is blank + * @throws NullPointerException if {@code name} is {@code null} + */ + public Builder model(String name) { + Objects.requireNonNull(name, "model"); + if (name.isBlank()) { + throw new IllegalArgumentException("model must not be blank"); + } + this.model = name; + return this; + } + + /** + * Set an explicit on-disk model path. Overrides classpath / env resolution. + * + * @param path absolute path to the {@code .gguf} file; never {@code null} + * @return this builder for chaining + * @throws NullPointerException if {@code path} is {@code null} + */ + public Builder modelPath(Path path) { + this.modelPath = Objects.requireNonNull(path, "modelPath"); + return this; + } + + /** + * Number of platform threads in the native executor pool. {@code 0} means auto-detect via + * {@code ContainerCpu.detect()}. + * + * @param n thread count; must be {@code >= 0} + * @return this builder for chaining + * @throws IllegalArgumentException if {@code n < 0} + */ + public Builder threads(int n) { + if (n < 0) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "threads must be >= 0, got %d", n)); + } + this.threads = n; + return this; + } + + /** + * Maximum context window in tokens. Must be {@code > 0}. + * + * @param n context size + * @return this builder for chaining + * @throws IllegalArgumentException if {@code n <= 0} + */ + public Builder contextSize(int n) { + if (n <= 0) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "contextSize must be > 0, got %d", n)); + } + this.contextSize = n; + return this; + } + + /** + * Maximum number of in-flight + queued requests before {@link QueueFullException} is raised. + * + * @param n queue depth; must be {@code > 0} + * @return this builder for chaining + * @throws IllegalArgumentException if {@code n <= 0} + */ + public Builder queueDepth(int n) { + if (n <= 0) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "queueDepth must be > 0, got %d", n)); + } + this.queueDepth = n; + return this; + } + + /** + * Streaming buffer size in chunks, applied to {@code Subscription.request(n)} batching. + * + * @param n buffer size; must be {@code > 0} + * @return this builder for chaining + * @throws IllegalArgumentException if {@code n <= 0} + */ + public Builder streamBufferSize(int n) { + if (n <= 0) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "streamBufferSize must be > 0, got %d", n)); + } + this.streamBufferSize = n; + return this; + } + + /** + * SLF4J logger for diagnostic output. Default: {@link NOPLogger#NOP_LOGGER}; the library + * never configures logging itself. + * + * @param slf4jLogger logger; never {@code null} + * @return this builder for chaining + * @throws NullPointerException if {@code slf4jLogger} is {@code null} + */ + public Builder logger(Logger slf4jLogger) { + this.logger = Objects.requireNonNull(slf4jLogger, "logger"); + return this; + } + + /** + * Build the generator. Resolves the model via {@link ModelResolver}, opens the llama.cpp + * model handle, and starts the native-thread pool. + * + * @return a fully initialized {@link Generator}; the caller owns its lifecycle + * @throws IllegalArgumentException if neither {@code model} nor {@code modelPath} is set + * @throws ModelNotLoadedException if no matching model is on the classpath / disk + * @throws GenerateException for any other initialization failure + */ + public Generator build() { + if (model == null && modelPath == null) { + throw new IllegalArgumentException( + "Generator.Builder requires either model(String) or modelPath(Path) to be set"); + } + return KherudGenerator.create(this); + } + + // Package-private accessors used by KherudGenerator.create. Keeping the + // builder a "dumb" data carrier means tests can construct one without + // depending on KherudGenerator. + + String getModel() { + return model; + } + + Path getModelPath() { + return modelPath; + } + + int getThreads() { + return threads; + } + + int getContextSize() { + return contextSize; + } + + int getQueueDepth() { + return queueDepth; + } + + int getStreamBufferSize() { + return streamBufferSize; + } + + Logger getLogger() { + return logger; + } + } +} diff --git a/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/InputValidator.java b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/InputValidator.java new file mode 100644 index 0000000..d0cdd0a --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/InputValidator.java @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import java.nio.ByteBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CodingErrorAction; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Locale; +import java.util.Objects; + +/** + * Boundary input-validation helpers used by {@link Generator} implementations. + * + *

Phase-1 path-1 (consume {@code de.kherud:llama:4.2.0} unmodified) leaves a residual tokenizer + * advisory GHSA-7rxv open. The advisory describes a signed/unsigned overflow on + * adversarial prompt input. The exploit surface is narrowed by enforcing two invariants at the + * SDK boundary, before any byte reaches llama.cpp's tokenizer: + * + *

    + *
  1. {@link #validatePromptLength(List, int) Total prompt length} is bounded — keeps the + * tokenizer in well-tested input-size regimes. + *
  2. {@link #validateUtf8(String) Strict UTF-8 validation} — rejects malformed byte sequences + * that the tokenizer might otherwise mis-handle. + *
+ * + *

See {@code SECURITY.md} §"Residual security risk" for the full mitigation rationale and + * sign-off. + * + *

This is a final utility class; instantiation is forbidden. + */ +public final class InputValidator { + + /** + * Conservative per-character cost when converting a configured token cap to a character cap. + * Real-world tokenizers average 3–5 chars per token; we use {@code 8} as a generous upper bound + * so legitimate prompts are not falsely rejected. + */ + static final int CHARS_PER_TOKEN_UPPER_BOUND = 8; + + private InputValidator() { + throw new AssertionError("no instances"); + } + + /** + * Bound the total character count across a list of {@link Message} payloads as a coarse + * pre-tokenization guard. + * + *

This is intentionally a character-based check rather than a token-based one — running the + * tokenizer to count tokens before deciding whether tokenization is safe would defeat the + * purpose. The bound is derived from {@code maxPromptTokens * CHARS_PER_TOKEN_UPPER_BOUND}. + * + * @param messages messages whose {@code content} fields will be tokenized; never {@code null} + * @param maxPromptTokens generator's configured prompt-token cap; must be {@code > 0} + * @throws IllegalArgumentException if any argument is invalid or the bound is exceeded + */ + public static void validatePromptLength(List messages, int maxPromptTokens) { + Objects.requireNonNull(messages, "messages"); + if (maxPromptTokens <= 0) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, "maxPromptTokens must be > 0, got %d", maxPromptTokens)); + } + long maxChars = (long) maxPromptTokens * CHARS_PER_TOKEN_UPPER_BOUND; + long total = 0L; + for (Message m : messages) { + // Compact constructor on Message guarantees content non-null. + total += m.content().length(); + // Add small constant for role tag + separator; mirrors chat-template overhead. + total += m.role().length() + 4L; + if (total > maxChars) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "prompt exceeds the %d-character cap derived from maxPromptTokens=%d " + + "(observed >= %d). Reduce input size or raise contextSize.", + maxChars, + maxPromptTokens, + total)); + } + } + } + + /** + * Assert that {@code text} is a strictly valid UTF-8 string. A {@link String} in the JVM is + * already a valid UTF-16 sequence, but this method checks that round-tripping through UTF-8 + * succeeds without substitutions — i.e. there are no unpaired surrogates that would translate + * into the U+FFFD replacement character once handed to a UTF-8 tokenizer. + * + *

Concretely: encodes the string with the {@link StandardCharsets#UTF_8} encoder configured + * to {@link CodingErrorAction#REPORT} on malformed input and unmappable characters. Any error + * raises {@link IllegalArgumentException}. + * + * @param text input string; never {@code null} + * @throws IllegalArgumentException if {@code text} is null or not strictly valid UTF-8 + */ + public static void validateUtf8(String text) { + if (text == null) { + throw new IllegalArgumentException("text must not be null"); + } + if (text.isEmpty()) { + return; + } + // Round-trip the string through a strict UTF-8 encode/decode pair. Any unpaired + // surrogate or unmappable codepoint surfaces as a CharacterCodingException. + java.nio.charset.CharsetEncoder enc = + StandardCharsets.UTF_8 + .newEncoder() + .onMalformedInput(CodingErrorAction.REPORT) + .onUnmappableCharacter(CodingErrorAction.REPORT); + ByteBuffer bytes; + try { + bytes = enc.encode(java.nio.CharBuffer.wrap(text)); + } catch (CharacterCodingException ex) { + throw new IllegalArgumentException( + "input contains malformed UTF-16 (e.g. unpaired surrogate): " + ex.getMessage(), ex); + } + CharsetDecoder dec = + StandardCharsets.UTF_8 + .newDecoder() + .onMalformedInput(CodingErrorAction.REPORT) + .onUnmappableCharacter(CodingErrorAction.REPORT); + try { + dec.decode(bytes); + } catch (CharacterCodingException ex) { + throw new IllegalArgumentException( + "input is not valid UTF-8 after encode round-trip: " + ex.getMessage(), ex); + } + } +} diff --git a/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/KherudGenerator.java b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/KherudGenerator.java new file mode 100644 index 0000000..f8a5dbd --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/KherudGenerator.java @@ -0,0 +1,570 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Properties; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Flow; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import org.slf4j.Logger; + +import io.github.randomcodespace.inference.FinishReason; +import io.github.randomcodespace.inference.ModelInfo; +import io.github.randomcodespace.inference.Usage; +import io.github.randomcodespace.inference.runtime.ContainerCpu; +import io.github.randomcodespace.inference.runtime.NativeExecutor; +import io.github.randomcodespace.inference.runtime.RequestId; + +/** + * Default {@link Generator} implementation backed by the published {@code de.kherud:llama:4.2.0} + * artifact ({@code de.kherud.llama.LlamaModel}). + * + *

Thread model

+ * + *

The published {@code LlamaModel} class is {@code final} and cannot be subclassed for + * mocking. We therefore introduce {@link LlamaClient}, a function-shaped interface over the + * three methods we actually use ({@code complete}, {@code stream}, {@code close}). The production + * factory ({@link #create(Generator.Builder)}) wires a {@link KherudLlamaClient} that delegates + * to a real {@code LlamaModel}; unit tests inject in-memory fakes. + * + *

Every JNI call into llama.cpp is dispatched through {@link NativeExecutor#submitNative} to + * keep llama.cpp off the virtual-thread carrier pool — see {@code docs/ARCHITECTURE.md} §3.3 and + * the {@link Generator} class JavaDoc for the rationale. Async callers receive a {@link + * CompletableFuture} that resolves on the SDK-owned virtual-thread executor; awaiting it from a + * virtual thread yields the carrier correctly. + * + *

Backpressure / queue depth

+ * + *

A {@link Semaphore} sized at {@code queueDepth} caps the in-flight + queued count. When + * {@code complete}, {@code completeAsync}, or {@code stream} would push past that limit, a {@link + * QueueFullException} is raised immediately rather than letting the bounded native queue grow + * unboundedly. + * + *

Streaming

+ * + *

{@link #stream(GenerateRequest)} returns a cold publisher that creates one {@link + * BoundedSubscription} per subscriber. The subscription does not begin generation until the + * subscriber {@code request(n)}s — matching the §7 "no leaks if subscriber never requests" + * requirement. + */ +final class KherudGenerator implements Generator { + + /** + * Function-shaped abstraction over the kherud {@code LlamaModel} surface so unit tests can + * supply a fake without instantiating native llama.cpp. Wraps the only call shapes we need. + */ + interface LlamaClient extends AutoCloseable { + + /** + * Run a non-streaming completion. + * + * @param prompt rendered prompt text (chat template already applied) + * @param req validated request carrying sampling parameters + * @return raw model output text (no chat-template scaffolding; no stop string trim) + */ + CompletionResult complete(String prompt, GenerateRequest req); + + /** + * Stream a completion as token chunks. The returned iterator must support {@link + * StreamingIterator#cancel()} to stop generation at the next token boundary. + * + * @param prompt rendered prompt text + * @param req validated request + * @return iterator over per-token deltas; the final element's {@code last} field is {@code + * true} + */ + StreamingIterator stream(String prompt, GenerateRequest req); + + /** Idempotent close. */ + @Override + void close(); + } + + /** Per-request output of {@link LlamaClient#complete}. */ + record CompletionResult(String text, FinishReason finishReason, int promptTokens, + int completionTokens) {} + + /** One emitted token from {@link LlamaClient#stream}. */ + record StreamingChunk(String delta, boolean last, FinishReason finishReason, int promptTokens, + int completionTokens) {} + + /** Iterator over {@link StreamingChunk}s with cooperative cancellation. */ + interface StreamingIterator extends Iterator, AutoCloseable { + /** Stop generation at the next token boundary. Idempotent. */ + void cancel(); + + @Override + void close(); + } + + private final LlamaClient client; + private final NativeExecutor executor; + private final ModelInfo modelInfo; + private final ExecutorService asyncExecutor; + private final Logger log; + private final boolean ownsAsyncExecutor; + private final Semaphore queue; + private final int queueDepth; + private final int streamBufferSize; + private final int contextSize; + private final AtomicBoolean closed = new AtomicBoolean(false); + + /** + * Production factory: resolve the model file, open a {@code LlamaModel}, and start the native + * pool. Wires real collaborators around a {@link KherudLlamaClient} adapter. + * + * @throws ModelNotLoadedException if no matching model is on the classpath / disk + * @throws GenerateException for any other initialization failure + */ + static KherudGenerator create(Generator.Builder b) { + Logger log = b.getLogger(); + ModelResolver resolver = new ModelResolver(); + Path modelFile = resolver.resolve(b.getModel(), b.getModelPath()); + ModelInfo modelInfo = readModelInfo(b.getModel(), modelFile, b.getContextSize()); + int threadCount = b.getThreads() > 0 ? b.getThreads() : ContainerCpu.detect(); + NativeExecutor exec = + NativeExecutor.sized(Math.max(1, threadCount), "generate-native"); + LlamaClient client; + try { + // KherudLlamaClient construction touches native code — gate through the executor so + // model load itself runs on a platform thread (consistent with the pinning workaround). + client = exec.submitNative( + () -> KherudLlamaClient.open(modelFile, threadCount, b.getContextSize())).get(); + } catch (ExecutionException ex) { + Throwable cause = ex.getCause() == null ? ex : ex.getCause(); + exec.close(); + throw new GenerateException( + "Failed to open LlamaModel at " + modelFile + ": " + cause.getMessage(), cause); + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + exec.close(); + throw new GenerateException("Interrupted while loading LlamaModel", ex); + } + return new KherudGenerator( + client, exec, modelInfo, log, b.getQueueDepth(), b.getStreamBufferSize(), + b.getContextSize(), null); + } + + /** + * Test-friendly constructor. + * + * @param client llama-client adapter; never {@code null} + * @param executor platform-thread executor; never {@code null} + * @param modelInfo static model metadata; never {@code null} + * @param log SLF4J logger; never {@code null} + * @param queueDepth bounded queue size; must be {@code > 0} + * @param streamBufferSize streaming buffer size; must be {@code > 0} + * @param contextSize model context window in tokens; must be {@code > 0} + * @param asyncExecutor optional virtual-thread executor for async callers; if {@code null}, a + * fresh per-instance virtual-thread executor is created and owned by this generator + */ + KherudGenerator( + LlamaClient client, + NativeExecutor executor, + ModelInfo modelInfo, + Logger log, + int queueDepth, + int streamBufferSize, + int contextSize, + ExecutorService asyncExecutor) { + this.client = Objects.requireNonNull(client, "client"); + this.executor = Objects.requireNonNull(executor, "executor"); + this.modelInfo = Objects.requireNonNull(modelInfo, "modelInfo"); + this.log = Objects.requireNonNull(log, "log"); + if (queueDepth <= 0) { + throw new IllegalArgumentException("queueDepth must be > 0"); + } + if (streamBufferSize <= 0) { + throw new IllegalArgumentException("streamBufferSize must be > 0"); + } + if (contextSize <= 0) { + throw new IllegalArgumentException("contextSize must be > 0"); + } + this.queueDepth = queueDepth; + this.streamBufferSize = streamBufferSize; + this.contextSize = contextSize; + this.queue = new Semaphore(queueDepth); + if (asyncExecutor == null) { + this.asyncExecutor = Executors.newVirtualThreadPerTaskExecutor(); + this.ownsAsyncExecutor = true; + } else { + this.asyncExecutor = asyncExecutor; + this.ownsAsyncExecutor = false; + } + } + + @Override + public GenerateResponse complete(GenerateRequest req) { + ensureOpen(); + Objects.requireNonNull(req, "req"); + validateInputs(req); + if (!queue.tryAcquire()) { + throw new QueueFullException( + String.format( + Locale.ROOT, + "Generator queue is full (queueDepth=%d); drop or retry with backoff", + queueDepth)); + } + String requestId = currentRequestId(); + long t0 = System.nanoTime(); + try { + String prompt = renderPrompt(req.messages()); + // NATIVE-PINNING WORKAROUND: route the JNI call through NativeExecutor (platform threads). + Callable call = () -> client.complete(prompt, req); + CompletionResult result; + try { + result = executor.submitNative(call).get(); + } catch (ExecutionException ex) { + Throwable cause = ex.getCause() == null ? ex : ex.getCause(); + throw new GenerateException("Native llama.cpp call failed: " + cause.getMessage(), cause); + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + throw new GenerateException("Interrupted while awaiting llama.cpp", ex); + } + long totalMs = elapsedMs(t0); + Usage usage = + new Usage( + result.promptTokens(), + result.completionTokens(), + result.promptTokens() + result.completionTokens()); + GenerateStats stats = + new GenerateStats( + requestId, + 0L, + 0L, + 0L, + totalMs, + totalMs, + tokensPerSecond(result.completionTokens(), totalMs), + result.finishReason(), + Math.min(contextSize, result.promptTokens() + result.completionTokens()), + contextSize, + modelInfo.revision(), + null); + // systemFingerprint reserved for Phase 2: always null in Phase 1. + return new GenerateResponse(result.text(), result.finishReason(), usage, stats, null); + } finally { + queue.release(); + } + } + + @Override + public CompletableFuture completeAsync(GenerateRequest req) { + ensureOpen(); + Objects.requireNonNull(req, "req"); + return CompletableFuture.supplyAsync(() -> complete(req), asyncExecutor); + } + + @Override + public Flow.Publisher stream(GenerateRequest req) { + ensureOpen(); + Objects.requireNonNull(req, "req"); + validateInputs(req); + return new GenerateStreamPublisher(req); + } + + @Override + public ModelInfo modelInfo() { + return modelInfo; + } + + @Override + public void close() { + if (!closed.compareAndSet(false, true)) { + return; + } + safeClose("client", client); + safeClose("nativeExecutor", executor); + if (ownsAsyncExecutor) { + asyncExecutor.close(); + } + } + + // -- internals ----------------------------------------------------------------- + + private void ensureOpen() { + if (closed.get()) { + throw new IllegalStateException("Generator has been closed"); + } + } + + private void validateInputs(GenerateRequest req) { + // Compact ctor of GenerateRequest already enforced structural validity. + // Apply path-1 mitigations (SECURITY.md §Residual security risk): + InputValidator.validatePromptLength(req.messages(), contextSize); + for (Message m : req.messages()) { + InputValidator.validateUtf8(m.content()); + } + } + + /** + * Render messages into a single prompt string. Phase 1 uses a minimal {@code + * <|role|>\ncontent\n} concatenation rather than a model-specific chat template — the published + * kherud {@code applyTemplate} entry point requires a loaded model handle, which is not what + * the test seam exposes. Production behaviour is exercised by Tier 5 integration tests against + * real models. + */ + static String renderPrompt(List messages) { + StringBuilder sb = new StringBuilder(); + for (Message m : messages) { + sb.append("<|").append(m.role()).append("|>\n"); + sb.append(m.content()).append("\n"); + } + sb.append("<|assistant|>\n"); + return sb.toString(); + } + + private static double tokensPerSecond(int tokens, long ms) { + return ms <= 0L ? 0.0d : (tokens * 1000.0d) / (double) ms; + } + + private static long elapsedMs(long startNanos) { + return Math.max(0L, (System.nanoTime() - startNanos) / 1_000_000L); + } + + private static String currentRequestId() { + String existing = RequestId.CURRENT.isBound() ? RequestId.CURRENT.get() : null; + return existing != null ? existing : RequestId.generate(); + } + + private static ModelInfo readModelInfo(String requestedId, Path modelFile, int contextSize) { + Path manifest = modelFile.resolveSibling("model-manifest.properties"); + String fallbackId = + requestedId == null + ? modelFile.getFileName() == null ? "unknown" : modelFile.getFileName().toString() + : requestedId; + if (!Files.isRegularFile(manifest)) { + // dimensions == -1 for generation models per ModelInfo JavaDoc. + return new ModelInfo(fallbackId, "unknown", "unknown", -1, contextSize); + } + Properties p = new Properties(); + try (var in = Files.newInputStream(manifest)) { + p.load(in); + } catch (Exception ex) { + throw new GenerateException( + "Failed to read model manifest " + manifest + ": " + ex.getMessage(), ex); + } + String id = p.getProperty("id", fallbackId); + String revision = p.getProperty("revision", "unknown"); + String quant = p.getProperty("quantization", "unknown"); + int max = parseIntOr(p.getProperty("max_tokens"), contextSize); + return new ModelInfo(id, revision, quant, -1, max); + } + + private static int parseIntOr(String raw, int fallback) { + if (raw == null || raw.isBlank()) { + return fallback; + } + try { + return Integer.parseInt(raw.trim()); + } catch (NumberFormatException ex) { + return fallback; + } + } + + private void safeClose(String name, AutoCloseable c) { + try { + c.close(); + } catch (Exception ex) { + log.debug("Failed to close {}: {}", name, ex.getMessage()); + } + } + + // -- streaming publisher ------------------------------------------------------ + + /** Cold publisher: each subscribe() creates a fresh {@link BoundedSubscription}. */ + private final class GenerateStreamPublisher implements Flow.Publisher { + private final GenerateRequest req; + + GenerateStreamPublisher(GenerateRequest req) { + this.req = req; + } + + @Override + public void subscribe(Flow.Subscriber subscriber) { + Objects.requireNonNull(subscriber, "subscriber"); + // Backpressure semaphore acquisition happens lazily inside BoundedSubscription on the + // first request(n) so subscribers that never request never trigger native work. + BoundedSubscription sub = + new BoundedSubscription( + req, + subscriber, + client, + executor, + queue, + modelInfo, + contextSize, + streamBufferSize, + KherudGenerator.this::renderAndValidate, + log); + subscriber.onSubscribe(sub); + } + } + + /** + * Single-pass prompt render + UTF-8 validate hook used by the streaming subscription. Called on + * the subscription's worker thread, not on the subscriber thread. + */ + String renderAndValidate(GenerateRequest req) { + validateInputs(req); + return renderPrompt(req.messages()); + } + + /** + * Production adapter: delegates to a kherud {@code LlamaModel}. Loaded reflectively so the + * compile-time module declaration ({@code requires llama;}) governs the dependency, but tests + * can still substitute fakes. Reflection-isolated to a tiny surface that maps directly to + * spec §6.4 contract — anything more would belong in Tier 5 integration tests. + */ + static final class KherudLlamaClient implements LlamaClient { + private final de.kherud.llama.LlamaModel model; + private final AtomicInteger active = new AtomicInteger(0); + + KherudLlamaClient(de.kherud.llama.LlamaModel model) { + this.model = model; + } + + /** + * Open a {@code LlamaModel} for the given GGUF file, configured with the requested thread and + * context-size budget. + */ + static KherudLlamaClient open(Path modelFile, int threads, int ctxSize) { + de.kherud.llama.ModelParameters params = + new de.kherud.llama.ModelParameters() + .setModel(modelFile.toAbsolutePath().toString()) + .setThreads(Math.max(1, threads)) + .setCtxSize(ctxSize); + return new KherudLlamaClient(new de.kherud.llama.LlamaModel(params)); + } + + @Override + public CompletionResult complete(String prompt, GenerateRequest req) { + // NATIVE-PINNING WORKAROUND: caller (KherudGenerator.complete) routes this method via + // NativeExecutor.submitNative — it must NOT be invoked directly from a virtual thread. + active.incrementAndGet(); + try { + de.kherud.llama.InferenceParameters ip = toInferenceParameters(prompt, req); + String text = model.complete(ip); + // The published 4.2.0 surface does not return token counts on the synchronous + // complete call; estimate via the encode round-trip for diagnostics. This keeps the + // Usage invariant (totalTokens == prompt + completion) honest. + int promptTokens; + int completionTokens; + try { + promptTokens = model.encode(prompt).length; + } catch (RuntimeException ex) { + promptTokens = approxTokens(prompt); + } + try { + completionTokens = model.encode(text).length; + } catch (RuntimeException ex) { + completionTokens = approxTokens(text); + } + FinishReason reason = + completionTokens >= req.maxTokens() ? new FinishReason.Length() : new FinishReason.Eos(); + return new CompletionResult(text, reason, promptTokens, completionTokens); + } finally { + active.decrementAndGet(); + } + } + + @Override + public StreamingIterator stream(String prompt, GenerateRequest req) { + de.kherud.llama.InferenceParameters ip = toInferenceParameters(prompt, req); + de.kherud.llama.LlamaIterator it = model.generate(ip).iterator(); + int approxPrompt; + try { + approxPrompt = model.encode(prompt).length; + } catch (RuntimeException ex) { + approxPrompt = approxTokens(prompt); + } + final int promptTokens = approxPrompt; + final AtomicInteger emitted = new AtomicInteger(0); + final AtomicBoolean cancelled = new AtomicBoolean(false); + return new StreamingIterator() { + @Override + public boolean hasNext() { + if (cancelled.get()) { + return false; + } + try { + return it.hasNext(); + } catch (RuntimeException ex) { + return false; + } + } + + @Override + public StreamingChunk next() { + de.kherud.llama.LlamaOutput out = it.next(); + int n = emitted.incrementAndGet(); + // The published iterator surfaces a `stop` flag on each output but it's package-private; + // emulate the contract by treating no-more-tokens (next hasNext returns false) as last. + boolean last = !hasNext() || n >= req.maxTokens(); + FinishReason reason = null; + if (last) { + reason = n >= req.maxTokens() ? new FinishReason.Length() : new FinishReason.Eos(); + } + return new StreamingChunk(out.text == null ? "" : out.text, last, reason, promptTokens, + n); + } + + @Override + public void cancel() { + if (cancelled.compareAndSet(false, true)) { + it.cancel(); + } + } + + @Override + public void close() { + cancel(); + } + }; + } + + @Override + public void close() { + model.close(); + } + + private static de.kherud.llama.InferenceParameters toInferenceParameters( + String prompt, GenerateRequest req) { + de.kherud.llama.InferenceParameters ip = + new de.kherud.llama.InferenceParameters(prompt) + .setNPredict(req.maxTokens()) + .setTemperature(req.temperature()) + .setTopP(req.topP()); + if (req.stop() != null && !req.stop().isEmpty()) { + ip = ip.setStopStrings(req.stop().toArray(new String[0])); + } + if (req.seed() != null) { + // kherud uses int seed; clamp to int range deterministically. + long s = req.seed(); + ip = ip.setSeed((int) (s ^ (s >>> 32))); + } + return ip; + } + + private static int approxTokens(String s) { + // Heuristic — only used when the kherud encode path itself fails. + return s == null ? 0 : Math.max(1, s.length() / 4); + } + } + +} diff --git a/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/Message.java b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/Message.java new file mode 100644 index 0000000..fab3c89 --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/Message.java @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import java.util.List; +import java.util.Set; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A single chat-style message in a {@link GenerateRequest}. + * + *

Wire format: {@code snake_case} per {@code docs/WIRE_FORMAT.md}; {@code @JsonProperty} + * annotations lock the JSON keys for forward compatibility with the Phase-2 HTTP layer (Phase 1 + * does not serialize these records over the wire). + * + *

Phase-1 constraints

+ * + *
    + *
  • {@code role} must be one of {@code "system"}, {@code "user"}, {@code "assistant"}, {@code + * "tool"}. + *
  • {@code role} and {@code content} must be non-{@code null}. + *
  • Reserved fields {@code toolCalls}, {@code toolCallId}, {@code name} must be {@code null}; + * any non-null value raises {@link FeatureNotSupportedException} (case #49 in {@code + * java-sdk.md} §11.2). + *
+ * + * @param role one of {@code "system"|"user"|"assistant"|"tool"}; never {@code null} + * @param content textual payload; never {@code null} (use the empty string for vacuous content) + * @param toolCalls reserved for Phase 2 — must be {@code null} + * @param toolCallId reserved for Phase 2 — must be {@code null} + * @param name reserved for Phase 2 — must be {@code null} + */ +public record Message( + String role, + String content, + @JsonProperty("tool_calls") List toolCalls, + @JsonProperty("tool_call_id") String toolCallId, + String name) { + + /** Allow-list of role tags accepted by Phase 1. */ + private static final Set ALLOWED_ROLES = Set.of("system", "user", "assistant", "tool"); + + /** + * Compact constructor: validates role allow-list, non-null role/content, and reserved-field + * absence. + * + * @throws IllegalArgumentException if {@code role} or {@code content} is null, or {@code role} + * is not in the allow-list + * @throws FeatureNotSupportedException if any reserved field is non-null + */ + public Message { + if (role == null) { + throw new IllegalArgumentException("role must not be null"); + } + if (content == null) { + throw new IllegalArgumentException("content must not be null"); + } + if (!ALLOWED_ROLES.contains(role)) { + throw new IllegalArgumentException( + "role must be one of " + ALLOWED_ROLES + ", got: " + role); + } + if (toolCalls != null || toolCallId != null || name != null) { + throw new FeatureNotSupportedException( + "tool calling (toolCalls/toolCallId/name) is reserved for Phase 2"); + } + } + + /** + * Convenience constructor for Phase 1 callers. Equivalent to {@code new Message(role, content, + * null, null, null)}. + * + * @param role one of {@code "system"|"user"|"assistant"|"tool"} + * @param content textual payload + */ + public Message(String role, String content) { + this(role, content, null, null, null); + } +} diff --git a/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/ModelNotLoadedException.java b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/ModelNotLoadedException.java new file mode 100644 index 0000000..2c2a7c6 --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/ModelNotLoadedException.java @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import java.util.List; +import java.util.Locale; + +/** + * Thrown when the configured generation model cannot be resolved on disk or on the classpath. + * + *

Resolution order (per {@code java-sdk.md} §8): explicit {@code modelPath(Path)} → + * {@code INFERENCE_MODEL_DIR} env var → classpath resource under {@code /models/}. + * + *

The exception message lists the locations searched and the model ids visible on the + * classpath, mirroring the {@code embed} module's behaviour for diagnostic parity. + */ +public class ModelNotLoadedException extends GenerateException { + + private static final long serialVersionUID = 1L; + + /** + * Construct an exception describing the resolution failure. + * + * @param model logical model id that was requested (never {@code null}) + * @param searched ordered list of locations that were probed + * @param available model ids visible on the classpath (may be empty) + */ + public ModelNotLoadedException(String model, List searched, List available) { + super(format(model, searched, available)); + } + + /** + * Construct an exception with an arbitrary message; prefer the structured form above when the + * resolution context is available. + * + * @param message human-readable description + */ + public ModelNotLoadedException(String message) { + super(message); + } + + private static String format(String model, List searched, List available) { + StringBuilder sb = new StringBuilder(); + sb.append( + String.format(Locale.ROOT, "Model '%s' not found. Locations searched:%n", model)); + for (String s : searched) { + sb.append(" - ").append(s).append(System.lineSeparator()); + } + if (available != null && !available.isEmpty()) { + sb.append("Available on classpath: ").append(String.join(", ", available)); + } else { + sb.append("No model JARs visible on the classpath."); + } + return sb.toString(); + } +} diff --git a/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/ModelResolver.java b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/ModelResolver.java new file mode 100644 index 0000000..df8848b --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/ModelResolver.java @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.function.Function; + +/** + * Resolves a generation model identifier to a concrete {@link Path} on disk, in the order + * specified by {@code java-sdk.md} §8: + * + *

    + *
  1. Explicit {@link Generator.Builder#modelPath(Path)}, if non-null and pointing at an + * existing file. + *
  2. {@code INFERENCE_MODEL_DIR} environment variable: a directory expected to contain a {@code + * .gguf} file (or just {@code }). + *
  3. Classpath: any {@code /models/.gguf} resource shipped by an {@code + * inference-sdk-generate-} JAR. + *
+ * + *

If none match, {@link ModelNotLoadedException} is thrown with a descriptive message listing + * the locations searched and the model ids visible on the classpath via the canonical {@code + * /models/model-manifest.properties} discovery. + * + *

This duplicates {@code inference-sdk-embed.ModelResolver} (the embed module looks for {@code + * .onnx}, this one for {@code .gguf}). A future refactor may extract a shared utility into + * {@code inference-sdk-core}; for Phase 1 the modules are independent. + * + *

Package-private; consumers configure resolution via the builder, never this class directly. + */ +final class ModelResolver { + + /** Standard environment variable per {@code java-sdk.md} §9. */ + static final String ENV_MODEL_DIR = "INFERENCE_MODEL_DIR"; + + /** Canonical classpath prefix for shipped model resources. */ + static final String CLASSPATH_PREFIX = "/models/"; + + /** Canonical extension. GGUF is the only Phase-1 generation format. */ + static final String GGUF_EXTENSION = ".gguf"; + + /** Canonical manifest filename for model-jar discovery. */ + static final String MANIFEST_RESOURCE = "models/model-manifest.properties"; + + private final Function envLookup; + private final ClassLoader classLoader; + + /** + * Default resolver: real {@link System#getenv(String)} and the current thread's context loader. + */ + ModelResolver() { + this(System::getenv, preferredClassLoader()); + } + + /** Test-friendly constructor; both arguments must be non-null. */ + ModelResolver(Function envLookup, ClassLoader classLoader) { + this.envLookup = Objects.requireNonNull(envLookup, "envLookup"); + this.classLoader = Objects.requireNonNull(classLoader, "classLoader"); + } + + /** + * @return a resolver whose {@code env} lookup is replaced by {@code envLookup}; useful for tests + */ + static ModelResolver withEnv(Function envLookup) { + return new ModelResolver(envLookup, preferredClassLoader()); + } + + /** + * Resolve the given model identifier to a concrete file path. + * + * @param model logical model id (e.g. {@code "qwen2.5-0.5b-instruct"}); may be null when {@code + * explicitPath} is provided + * @param explicitPath caller-supplied {@link Generator.Builder#modelPath(Path)} value; may be + * null + * @return existing, readable file path + * @throws ModelNotLoadedException if no candidate matches + */ + Path resolve(String model, Path explicitPath) { + List searched = new ArrayList<>(); + + // 1. Explicit path wins. + if (explicitPath != null) { + searched.add("explicit modelPath=" + explicitPath); + if (Files.isRegularFile(explicitPath)) { + return explicitPath.toAbsolutePath(); + } + throw new ModelNotLoadedException( + model == null ? explicitPath.toString() : model, searched, discoverAvailableModels()); + } + + Objects.requireNonNull(model, "model id required when modelPath is unset"); + + // 2. INFERENCE_MODEL_DIR. + String envDir = envLookup.apply(ENV_MODEL_DIR); + if (envDir != null && !envDir.isBlank()) { + Path dir = Paths.get(envDir); + Path withExt = dir.resolve(model + GGUF_EXTENSION); + Path bare = dir.resolve(model); + searched.add(ENV_MODEL_DIR + "=" + envDir + " (looking for " + withExt + " or " + bare + ")"); + if (Files.isRegularFile(withExt)) { + return withExt.toAbsolutePath(); + } + if (Files.isRegularFile(bare)) { + return bare.toAbsolutePath(); + } + } else { + searched.add(ENV_MODEL_DIR + " (unset)"); + } + + // 3. Classpath: /models/.gguf. + String resourcePath = CLASSPATH_PREFIX + model + GGUF_EXTENSION; + searched.add("classpath:" + resourcePath); + URL classpathUrl = classLoader.getResource(resourcePath.substring(1)); + if (classpathUrl != null && "file".equals(classpathUrl.getProtocol())) { + // Loaded from the unpacked target/classes during dev — return directly. + try { + Path direct = Paths.get(classpathUrl.toURI()); + if (Files.isRegularFile(direct)) { + return direct.toAbsolutePath(); + } + } catch (Exception ignored) { + // Fall through to extraction. + } + } + if (classpathUrl != null) { + // Resource is inside a JAR; extract to temp. + Path extracted = extractClasspathResource(resourcePath); + if (extracted != null) { + return extracted; + } + } + + throw new ModelNotLoadedException(model, searched, discoverAvailableModels()); + } + + /** + * Stream a classpath resource to a temp file. Returns {@code null} on failure (caller treats + * that as "not found" and surfaces a {@link ModelNotLoadedException}). + */ + private Path extractClasspathResource(String resourcePath) { + String normalized = resourcePath.startsWith("/") ? resourcePath.substring(1) : resourcePath; + try (InputStream in = classLoader.getResourceAsStream(normalized)) { + if (in == null) { + return null; + } + String fileName = lastSegment(resourcePath); + Path tempDir = + Files.createTempDirectory( + String.format( + Locale.ROOT, "inference-sdk-generate-%d-", ProcessHandle.current().pid())); + Path target = tempDir.resolve(fileName); + Files.copy(in, target); + // Best-effort cleanup on JVM shutdown. + target.toFile().deleteOnExit(); + tempDir.toFile().deleteOnExit(); + return target.toAbsolutePath(); + } catch (IOException ex) { + return null; + } + } + + /** + * Discover model ids visible on the classpath via {@code /models/model-manifest.properties} + * resources shipped by every {@code inference-sdk-generate-} JAR. + * + * @return alphabetically-sorted, distinct model ids; empty list if none visible + */ + List discoverAvailableModels() { + Set ids = new LinkedHashSet<>(); + try { + Enumeration manifests = classLoader.getResources(MANIFEST_RESOURCE); + while (manifests.hasMoreElements()) { + URL url = manifests.nextElement(); + try (InputStream in = url.openStream()) { + Properties p = new Properties(); + p.load(in); + String id = p.getProperty("id"); + if (id != null && !id.isBlank()) { + ids.add(id.trim()); + } + } catch (IOException ignored) { + // Skip malformed manifests; one bad jar shouldn't break discovery. + } + } + } catch (IOException ignored) { + // Classloader hiccup; best-effort discovery. + } + List out = new ArrayList<>(ids); + Collections.sort(out); + return out; + } + + private static ClassLoader preferredClassLoader() { + ClassLoader cl = Thread.currentThread().getContextClassLoader(); + return cl != null ? cl : ModelResolver.class.getClassLoader(); + } + + private static String lastSegment(String path) { + int idx = path.lastIndexOf('/'); + return idx < 0 ? path : path.substring(idx + 1); + } +} diff --git a/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/QueueFullException.java b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/QueueFullException.java new file mode 100644 index 0000000..7f345f1 --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/io/github/randomcodespace/inference/generate/QueueFullException.java @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +/** + * Thrown when the configured {@code queueDepth} is exhausted by in-flight requests and a new + * submission cannot be accepted. + * + *

This is a backpressure signal — callers should treat it as transient and either retry with + * exponential backoff or shed load. The library never silently buffers beyond {@code queueDepth}. + * + *

See edge case #33 in {@code java-sdk.md} §11.2 for the integration-test contract. + */ +public class QueueFullException extends GenerateException { + + private static final long serialVersionUID = 1L; + + /** + * Construct a new exception describing the queue-full condition. + * + * @param message human-readable description; should include the configured queue depth + */ + public QueueFullException(String message) { + super(message); + } +} diff --git a/java/inference-sdk-generate/src/main/java/module-info.java b/java/inference-sdk-generate/src/main/java/module-info.java new file mode 100644 index 0000000..4d06963 --- /dev/null +++ b/java/inference-sdk-generate/src/main/java/module-info.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ + +/** + * inference-sdk-generate: generation API backed by the published {@code de.kherud:llama:4.2.0} + * artifact (bundled llama.cpp). + * + *

Consumers obtain a {@link io.github.randomcodespace.inference.generate.Generator} through + * {@link io.github.randomcodespace.inference.generate.Generator#builder()}. Every JNI call into + * llama.cpp is marshalled onto a platform-thread pool via {@code + * io.github.randomcodespace.inference.runtime.NativeExecutor} from the {@code inference-sdk-core} + * module — this is the native-thread-pinning workaround documented in {@code docs/ARCHITECTURE.md} + * §3.3 (llama.cpp pins the carrier thread; submitting native work to a virtual-thread executor + * would either pin the carrier or expose stale per-thread state). + * + *

The {@code de.kherud:llama:4.2.0} JAR carries no {@code Automatic-Module-Name} manifest + * attribute, so its automatic module name is derived from its filename per JLS / module spec + * rules: {@code llama-4.2.0.jar} → trailing version stripped → automatic module name {@code + * llama}. Verified by {@code unzip -p ~/.m2/repository/de/kherud/llama/4.2.0/llama-4.2.0.jar + * META-INF/MANIFEST.MF} — manifest contains only {@code Manifest-Version}, {@code Created-By}, + * {@code Build-Jdk-Spec}. + */ +module io.github.randomcodespace.inference.generate { + exports io.github.randomcodespace.inference.generate; + + requires io.github.randomcodespace.inference.core; + requires org.slf4j; + // Filename-derived automatic module: llama-4.2.0.jar → "llama". + // No Automatic-Module-Name attribute in the kherud manifest as of 4.2.0. + requires llama; + + // Jackson annotations are compile-time only; declared transitive so consumers + // who deserialize Message / GenerateRequest / GenerateChunk see the + // @JsonProperty bindings. + requires static com.fasterxml.jackson.annotation; +} diff --git a/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/FakeLlamaClient.java b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/FakeLlamaClient.java new file mode 100644 index 0000000..8ea12b3 --- /dev/null +++ b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/FakeLlamaClient.java @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import io.github.randomcodespace.inference.FinishReason; + +/** + * In-memory fake of {@link KherudGenerator.LlamaClient} for unit tests. Avoids loading native + * llama.cpp by mapping every prompt to a deterministic fixed completion. Captures call counts so + * tests can assert that {@code complete} / {@code stream} were dispatched on a platform thread + * (the native-thread-pinning workaround) rather than directly from a virtual thread. + */ +class FakeLlamaClient implements KherudGenerator.LlamaClient { + + /** Last thread that invoked {@link #complete(String, GenerateRequest)}. */ + volatile Thread lastCompleteThread; + + /** Last thread that invoked {@link #stream(String, GenerateRequest)}. */ + volatile Thread lastStreamThread; + + final AtomicInteger completeCalls = new AtomicInteger(0); + final AtomicInteger streamCalls = new AtomicInteger(0); + final AtomicBoolean closed = new AtomicBoolean(false); + + /** Stream-side delta tokens; defaults to four single-character chunks. */ + List streamDeltas = List.of("h", "e", "l", "lo"); + + /** Per-call complete payload; default fixture replicates a one-shot completion. */ + String completionText = "hello"; + + /** If non-null, {@code complete} throws this instead of returning. */ + RuntimeException completeError; + + @Override + public KherudGenerator.CompletionResult complete(String prompt, GenerateRequest req) { + lastCompleteThread = Thread.currentThread(); + completeCalls.incrementAndGet(); + if (completeError != null) { + throw completeError; + } + int promptTokens = Math.max(1, prompt.length() / 4); + int completionTokens = Math.max(1, completionText.length() / 4); + FinishReason reason = + completionTokens >= req.maxTokens() ? new FinishReason.Length() : new FinishReason.Eos(); + return new KherudGenerator.CompletionResult( + completionText, reason, promptTokens, completionTokens); + } + + @Override + public KherudGenerator.StreamingIterator stream(String prompt, GenerateRequest req) { + lastStreamThread = Thread.currentThread(); + streamCalls.incrementAndGet(); + int promptTokens = Math.max(1, prompt.length() / 4); + Iterator it = streamDeltas.iterator(); + AtomicBoolean cancelled = new AtomicBoolean(false); + AtomicInteger emitted = new AtomicInteger(0); + int total = streamDeltas.size(); + return new KherudGenerator.StreamingIterator() { + @Override + public boolean hasNext() { + return !cancelled.get() && it.hasNext(); + } + + @Override + public KherudGenerator.StreamingChunk next() { + String delta = it.next(); + int n = emitted.incrementAndGet(); + boolean last = n >= total || n >= req.maxTokens(); + FinishReason reason = null; + if (last) { + reason = n >= req.maxTokens() ? new FinishReason.Length() : new FinishReason.Eos(); + } + return new KherudGenerator.StreamingChunk(delta, last, reason, promptTokens, n); + } + + @Override + public void cancel() { + cancelled.set(true); + } + + @Override + public void close() { + cancelled.set(true); + } + }; + } + + @Override + public void close() { + closed.set(true); + } +} diff --git a/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/GenerateRequestBuilderTest.java b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/GenerateRequestBuilderTest.java new file mode 100644 index 0000000..2219fb3 --- /dev/null +++ b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/GenerateRequestBuilderTest.java @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for the fluent {@link GenerateRequest.Builder}. Mirrors the validation matrix in + * java-sdk.md §11.1 and exercises the eager-validate-on-set contract. + */ +final class GenerateRequestBuilderTest { + + @Test + void buildsValidRequestWithDefaults() { + GenerateRequest r = + GenerateRequest.builder().messages(List.of(new Message("user", "hi"))).build(); + assertThat(r.maxTokens()).isPositive(); + assertThat(r.temperature()).isEqualTo(GenerateRequest.DEFAULT_TEMPERATURE); + assertThat(r.topP()).isEqualTo(GenerateRequest.DEFAULT_TOP_P); + assertThat(r.seed()).isNull(); + } + + @Test + void messagesIsRequiredAtBuildTime() { + assertThatNullPointerException() + .isThrownBy(() -> GenerateRequest.builder().messages(null)); + } + + @Test + void maxTokensValidationOnSet() { + GenerateRequest.Builder b = GenerateRequest.builder().messages(List.of(new Message("user", "hi"))); + assertThatIllegalArgumentException().isThrownBy(() -> b.maxTokens(0)); + assertThatIllegalArgumentException().isThrownBy(() -> b.maxTokens(-1)); + assertThat(b.maxTokens(8).build().maxTokens()).isEqualTo(8); + } + + @Test + void temperatureValidationOnSet() { + GenerateRequest.Builder b = GenerateRequest.builder().messages(List.of(new Message("user", "hi"))); + assertThatIllegalArgumentException().isThrownBy(() -> b.temperature(-0.1f)); + assertThatIllegalArgumentException().isThrownBy(() -> b.temperature(2.5f)); + assertThatIllegalArgumentException().isThrownBy(() -> b.temperature(Float.NaN)); + assertThat(b.temperature(0.5f).build().temperature()).isEqualTo(0.5f); + } + + @Test + void topPValidationOnSet() { + GenerateRequest.Builder b = GenerateRequest.builder().messages(List.of(new Message("user", "hi"))); + assertThatIllegalArgumentException().isThrownBy(() -> b.topP(0f)); + assertThatIllegalArgumentException().isThrownBy(() -> b.topP(-0.1f)); + assertThatIllegalArgumentException().isThrownBy(() -> b.topP(1.1f)); + assertThatIllegalArgumentException().isThrownBy(() -> b.topP(Float.NaN)); + assertThat(b.topP(0.7f).build().topP()).isEqualTo(0.7f); + } + + @Test + void seedAndStopFlowThroughToRecord() { + GenerateRequest r = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .seed(42L) + .stop(List.of("\n\n")) + .build(); + assertThat(r.seed()).isEqualTo(42L); + assertThat(r.stop()).containsExactly("\n\n"); + } + + @Test + void builderRejectsBuildWithoutMessages() { + assertThatIllegalArgumentException() + .isThrownBy(() -> GenerateRequest.builder().build()) + .withMessageContaining("messages required"); + } +} diff --git a/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/GenerateRequestTest.java b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/GenerateRequestTest.java new file mode 100644 index 0000000..cd4ee6e --- /dev/null +++ b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/GenerateRequestTest.java @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * Unit tests for {@link GenerateRequest}. Covers java-sdk.md §11.1 generate-module bullet + * "GenerateRequest validation: maxTokens, temperature, topP bounds; empty messages; missing user + * message; reserved fields" and §11.2 cases #13, #14, #15, #50. + */ +final class GenerateRequestTest { + + private static List validMessages() { + return List.of(new Message("user", "hello")); + } + + @Test + void canonicalConstructorAcceptsValidArguments() { + GenerateRequest r = + new GenerateRequest(validMessages(), 64, 0.7f, 0.95f, null, null, null, null, null); + assertThat(r.maxTokens()).isEqualTo(64); + assertThat(r.temperature()).isEqualTo(0.7f); + assertThat(r.topP()).isEqualTo(0.95f); + } + + // -- §11.2 case #14: empty messages -> IAE -- + + @Test + void rejectsNullMessages() { + assertThatIllegalArgumentException() + .isThrownBy( + () -> new GenerateRequest(null, 1, 0.7f, 0.95f, null, null, null, null, null)) + .withMessageContaining("messages required"); + } + + @Test + void rejectsEmptyMessages() { + assertThatIllegalArgumentException() + .isThrownBy( + () -> new GenerateRequest(List.of(), 1, 0.7f, 0.95f, null, null, null, null, null)) + .withMessageContaining("messages required"); + } + + // -- §11.2 case #13: maxTokens=0 -> IAE -- + + @ParameterizedTest + @ValueSource(ints = {0, -1, -100, Integer.MIN_VALUE}) + void rejectsNonPositiveMaxTokens(int maxTokens) { + assertThatIllegalArgumentException() + .isThrownBy( + () -> + new GenerateRequest( + validMessages(), maxTokens, 0.7f, 0.95f, null, null, null, null, null)) + .withMessageContaining("maxTokens must be > 0"); + } + + // -- temperature bounds -- + + @ParameterizedTest + @ValueSource(floats = {-0.0001f, -1f, 2.0001f, 5f, Float.NaN}) + void rejectsTemperatureOutsideRange(float temperature) { + assertThatIllegalArgumentException() + .isThrownBy( + () -> + new GenerateRequest( + validMessages(), 8, temperature, 0.95f, null, null, null, null, null)) + .withMessageContaining("temperature must be in [0, 2]"); + } + + @ParameterizedTest + @ValueSource(floats = {0f, 0.5f, 1f, 1.99f, 2f}) + void acceptsTemperatureWithinRange(float temperature) { + GenerateRequest r = + new GenerateRequest( + validMessages(), 8, temperature, 0.95f, null, null, null, null, null); + assertThat(r.temperature()).isEqualTo(temperature); + } + + // -- topP bounds (open at 0, closed at 1) -- + + @ParameterizedTest + @ValueSource(floats = {0f, -0.1f, 1.0001f, 5f, Float.NaN}) + void rejectsTopPOutsideRange(float topP) { + assertThatIllegalArgumentException() + .isThrownBy( + () -> + new GenerateRequest( + validMessages(), 8, 0.7f, topP, null, null, null, null, null)) + .withMessageContaining("topP must be in (0, 1]"); + } + + @ParameterizedTest + @ValueSource(floats = {0.0001f, 0.5f, 0.999f, 1f}) + void acceptsTopPWithinRange(float topP) { + GenerateRequest r = + new GenerateRequest(validMessages(), 8, 0.7f, topP, null, null, null, null, null); + assertThat(r.topP()).isEqualTo(topP); + } + + // -- §11.2 case #15: system-only messages -> IAE (missing user message) -- + + @Test + void requiresAtLeastOneUserMessage() { + assertThatIllegalArgumentException() + .isThrownBy( + () -> + new GenerateRequest( + List.of(new Message("system", "be helpful")), + 8, + 0.7f, + 0.95f, + null, + null, + null, + null, + null)) + .withMessageContaining("at least one user message"); + } + + @Test + void rejectsNullEntriesInMessages() { + java.util.List withNull = new java.util.ArrayList<>(); + withNull.add(new Message("user", "hi")); + withNull.add(null); + assertThatIllegalArgumentException() + .isThrownBy( + () -> + new GenerateRequest( + withNull, 8, 0.7f, 0.95f, null, null, null, null, null)) + .withMessageContaining("null"); + } + + // -- §11.2 case #50: reserved fields -> FeatureNotSupportedException -- + + @Test + void reservedToolsThrows() { + assertThatExceptionOfType(FeatureNotSupportedException.class) + .isThrownBy( + () -> + new GenerateRequest( + validMessages(), + 8, + 0.7f, + 0.95f, + null, + null, + List.of(new Object()), + null, + null)) + .withMessageContaining("Phase 2"); + } + + @Test + void reservedToolChoiceThrows() { + assertThatExceptionOfType(FeatureNotSupportedException.class) + .isThrownBy( + () -> + new GenerateRequest( + validMessages(), 8, 0.7f, 0.95f, null, null, null, "auto", null)) + .withMessageContaining("Phase 2"); + } + + @Test + void reservedResponseFormatThrows() { + assertThatExceptionOfType(FeatureNotSupportedException.class) + .isThrownBy( + () -> + new GenerateRequest( + validMessages(), + 8, + 0.7f, + 0.95f, + null, + null, + null, + null, + new Object())) + .withMessageContaining("Phase 2"); + } + + @Test + void messagesAreDefensivelyCopied() { + java.util.List mutable = new java.util.ArrayList<>(); + mutable.add(new Message("user", "first")); + GenerateRequest r = + new GenerateRequest(mutable, 8, 0.7f, 0.95f, null, null, null, null, null); + mutable.clear(); + assertThat(r.messages()).hasSize(1); + } + + @Test + void stopListIsDefensivelyCopiedWhenPresent() { + java.util.List stops = new java.util.ArrayList<>(); + stops.add("\n\n"); + GenerateRequest r = + new GenerateRequest(validMessages(), 8, 0.7f, 0.95f, stops, null, null, null, null); + stops.add("###"); + assertThat(r.stop()).hasSize(1); + } +} diff --git a/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/GeneratorBuilderTest.java b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/GeneratorBuilderTest.java new file mode 100644 index 0000000..c38be0a --- /dev/null +++ b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/GeneratorBuilderTest.java @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import java.nio.file.Path; + +import org.junit.jupiter.api.Test; +import org.slf4j.helpers.NOPLogger; + +/** + * Unit tests for {@link Generator.Builder}. Per java-sdk.md §11.1: "Builder validation: same + * pattern" — validates empty model, null modelPath, negative threads, contextSize<=0, + * queueDepth<=0, streamBufferSize<=0, and that build() requires either model or modelPath. + */ +final class GeneratorBuilderTest { + + @Test + void rejectsBlankModel() { + assertThatIllegalArgumentException() + .isThrownBy(() -> Generator.builder().model("")) + .withMessageContaining("model must not be blank"); + assertThatIllegalArgumentException() + .isThrownBy(() -> Generator.builder().model(" ")) + .withMessageContaining("model must not be blank"); + } + + @Test + void rejectsNullModel() { + assertThatNullPointerException().isThrownBy(() -> Generator.builder().model(null)); + } + + @Test + void rejectsNullModelPath() { + assertThatNullPointerException().isThrownBy(() -> Generator.builder().modelPath(null)); + } + + @Test + void rejectsNegativeThreads() { + assertThatIllegalArgumentException() + .isThrownBy(() -> Generator.builder().threads(-1)) + .withMessageContaining("threads must be >= 0"); + } + + @Test + void zeroThreadsMeansAutoDetect() { + // Zero is the documented sentinel for "ContainerCpu.detect()"; should not throw. + Generator.Builder b = Generator.builder().threads(0); + assertThat(b).isNotNull(); + } + + @Test + void rejectsNonPositiveContextSize() { + assertThatIllegalArgumentException() + .isThrownBy(() -> Generator.builder().contextSize(0)) + .withMessageContaining("contextSize must be > 0"); + assertThatIllegalArgumentException() + .isThrownBy(() -> Generator.builder().contextSize(-1)) + .withMessageContaining("contextSize must be > 0"); + } + + @Test + void rejectsNonPositiveQueueDepth() { + assertThatIllegalArgumentException() + .isThrownBy(() -> Generator.builder().queueDepth(0)) + .withMessageContaining("queueDepth must be > 0"); + assertThatIllegalArgumentException() + .isThrownBy(() -> Generator.builder().queueDepth(-5)) + .withMessageContaining("queueDepth must be > 0"); + } + + @Test + void rejectsNonPositiveStreamBufferSize() { + assertThatIllegalArgumentException() + .isThrownBy(() -> Generator.builder().streamBufferSize(0)) + .withMessageContaining("streamBufferSize must be > 0"); + assertThatIllegalArgumentException() + .isThrownBy(() -> Generator.builder().streamBufferSize(-3)) + .withMessageContaining("streamBufferSize must be > 0"); + } + + @Test + void rejectsNullLogger() { + assertThatNullPointerException().isThrownBy(() -> Generator.builder().logger(null)); + } + + @Test + void buildRequiresModelOrModelPath() { + assertThatIllegalArgumentException() + .isThrownBy(() -> Generator.builder().build()) + .withMessageContaining("model(String) or modelPath(Path)"); + } + + @Test + void buildRejectsMissingModelEvenWithOtherSettings() { + assertThatIllegalArgumentException() + .isThrownBy( + () -> + Generator.builder() + .threads(2) + .contextSize(1024) + .queueDepth(8) + .streamBufferSize(4) + .logger(NOPLogger.NOP_LOGGER) + .build()) + .withMessageContaining("model(String) or modelPath(Path)"); + } + + @Test + void packagePrivateAccessorsExposeConfiguration() { + Path p = Path.of("/tmp/nonexistent.gguf"); + Generator.Builder b = + Generator.builder() + .model("qwen") + .modelPath(p) + .threads(2) + .contextSize(1024) + .queueDepth(4) + .streamBufferSize(8) + .logger(NOPLogger.NOP_LOGGER); + assertThat(b.getModel()).isEqualTo("qwen"); + assertThat(b.getModelPath()).isEqualTo(p); + assertThat(b.getThreads()).isEqualTo(2); + assertThat(b.getContextSize()).isEqualTo(1024); + assertThat(b.getQueueDepth()).isEqualTo(4); + assertThat(b.getStreamBufferSize()).isEqualTo(8); + assertThat(b.getLogger()).isSameAs(NOPLogger.NOP_LOGGER); + } +} diff --git a/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/InputValidatorTest.java b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/InputValidatorTest.java new file mode 100644 index 0000000..d031c1a --- /dev/null +++ b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/InputValidatorTest.java @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for {@link InputValidator}. Path-1 mitigation surface — see SECURITY.md §"Residual + * security risk" (GHSA-7rxv). + */ +final class InputValidatorTest { + + @Test + void instantiationIsForbidden() { + // Defensive: the utility class must not be instantiable. + assertThatThrownBy( + () -> { + var ctor = InputValidator.class.getDeclaredConstructor(); + ctor.setAccessible(true); + try { + ctor.newInstance(); + } catch (java.lang.reflect.InvocationTargetException ex) { + throw ex.getTargetException(); + } + }) + .isInstanceOf(AssertionError.class); + } + + // -- validateUtf8 -- + + @Test + void acceptsAsciiAndUnicode() { + InputValidator.validateUtf8("hello world"); + InputValidator.validateUtf8("café résumé naïve"); // accented Latin-1 + InputValidator.validateUtf8("こんにちは"); // Japanese + InputValidator.validateUtf8("你好"); // Chinese 你好 + // Emojis (supplementary plane) — surrogate pairs encoded explicitly so the file remains ASCII. + InputValidator.validateUtf8( + new String(new int[] {0x1F600, 0x1F389}, 0, 2)); // grinning face + party popper + } + + @Test + void acceptsEmptyString() { + InputValidator.validateUtf8(""); + } + + @Test + void rejectsNull() { + assertThatIllegalArgumentException() + .isThrownBy(() -> InputValidator.validateUtf8(null)) + .withMessage("text must not be null"); + } + + @Test + void rejectsUnpairedHighSurrogate() { + // U+D83D is a high surrogate; alone it is not a valid code point. + String malformed = new String(new char[] {0xD83D, 'x'}); + assertThatIllegalArgumentException() + .isThrownBy(() -> InputValidator.validateUtf8(malformed)) + .withMessageContaining("malformed"); + } + + @Test + void rejectsUnpairedLowSurrogate() { + String malformed = new String(new char[] {'x', 0xDE00}); + assertThatIllegalArgumentException() + .isThrownBy(() -> InputValidator.validateUtf8(malformed)) + .withMessageContaining("malformed"); + } + + @Test + void rejectsLoneHighSurrogateAtStringEnd() { + String malformed = "abc" + (char) 0xD83D; + assertThatIllegalArgumentException() + .isThrownBy(() -> InputValidator.validateUtf8(malformed)) + .withMessageContaining("malformed"); + } + + // -- validatePromptLength -- + + @Test + void acceptsShortPrompt() { + InputValidator.validatePromptLength( + List.of(new Message("user", "short prompt")), /* maxPromptTokens= */ 64); + } + + @Test + void rejectsExcessivelyLongPrompt() { + String huge = "a".repeat(10_000); + assertThatIllegalArgumentException() + .isThrownBy( + () -> + InputValidator.validatePromptLength( + List.of(new Message("user", huge)), /* maxPromptTokens= */ 64)) + .withMessageContaining("character cap"); + } + + @Test + void cumulativeAcrossMessagesIsBounded() { + // 8 messages of 100 chars each = 800 chars + role overhead. With maxPromptTokens=10 + // (cap = 80 chars), this must fail. + java.util.List msgs = new java.util.ArrayList<>(); + for (int i = 0; i < 8; i++) { + msgs.add(new Message("user", "x".repeat(100))); + } + assertThatIllegalArgumentException() + .isThrownBy(() -> InputValidator.validatePromptLength(msgs, 10)) + .withMessageContaining("character cap"); + } + + @Test + void rejectsNonPositiveMaxPromptTokens() { + assertThatIllegalArgumentException() + .isThrownBy( + () -> InputValidator.validatePromptLength(List.of(new Message("user", "hi")), 0)) + .withMessageContaining("maxPromptTokens must be > 0"); + assertThatIllegalArgumentException() + .isThrownBy( + () -> InputValidator.validatePromptLength(List.of(new Message("user", "hi")), -1)) + .withMessageContaining("maxPromptTokens must be > 0"); + } + + @Test + void emptyMessagesListIsAccepted() { + // The list is empty, which the GenerateRequest constructor would already reject; but this + // helper only enforces the length cap, so an empty list trivially passes. + InputValidator.validatePromptLength(List.of(), 64); + } + + @Test + void capDerivedFromTokenBudgetUsesCharsPerTokenUpperBound() { + // maxPromptTokens=1 → cap ≈ 8 chars (the documented upper bound). Anything longer rejects. + String exactly20Chars = "x".repeat(20); + assertThat(exactly20Chars).hasSize(20); + assertThatIllegalArgumentException() + .isThrownBy( + () -> + InputValidator.validatePromptLength( + List.of(new Message("user", exactly20Chars)), 1)); + } +} diff --git a/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/KherudGeneratorClosedTest.java b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/KherudGeneratorClosedTest.java new file mode 100644 index 0000000..09513e7 --- /dev/null +++ b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/KherudGeneratorClosedTest.java @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; + +import java.util.List; +import java.util.concurrent.ExecutionException; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.helpers.NOPLogger; + +import io.github.randomcodespace.inference.ModelInfo; +import io.github.randomcodespace.inference.runtime.NativeExecutor; + +/** + * Closed-lifecycle tests for {@link KherudGenerator}. Per java-sdk.md §11.1: "Closed Embedder + * throws IllegalStateException" — the generate-module equivalent. + */ +final class KherudGeneratorClosedTest { + + private FakeLlamaClient client; + private NativeExecutor executor; + private KherudGenerator generator; + + @BeforeEach + void setUp() { + client = new FakeLlamaClient(); + executor = NativeExecutor.sized(1, "test-gen"); + ModelInfo info = new ModelInfo("fake", "rev1", "q4_k_m", -1, 2048); + generator = + new KherudGenerator(client, executor, info, NOPLogger.NOP_LOGGER, 4, 4, 2048, null); + } + + @AfterEach + void tearDown() { + if (generator != null) { + generator.close(); + } + } + + @Test + void closeIsIdempotent() { + generator.close(); + generator.close(); // must not throw + assertThat(client.closed).isTrue(); + } + + @Test + void completeAfterCloseThrowsIllegalState() { + generator.close(); + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .maxTokens(8) + .build(); + assertThatIllegalStateException() + .isThrownBy(() -> generator.complete(req)) + .withMessageContaining("closed"); + } + + @Test + void streamAfterCloseThrowsIllegalState() { + generator.close(); + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .maxTokens(8) + .build(); + assertThatIllegalStateException() + .isThrownBy(() -> generator.stream(req)) + .withMessageContaining("closed"); + } + + @Test + void completeAsyncAfterCloseThrowsIllegalState() { + generator.close(); + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .maxTokens(8) + .build(); + assertThatIllegalStateException() + .isThrownBy(() -> generator.completeAsync(req)) + .withMessageContaining("closed"); + } + + @Test + void completeAsyncFutureCarriesNativeFailure() throws Exception { + client.completeError = new RuntimeException("simulated llama.cpp fault"); + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .maxTokens(8) + .build(); + var future = generator.completeAsync(req); + assertThatExceptionOfType(ExecutionException.class) + .isThrownBy(future::get) + .withCauseInstanceOf(GenerateException.class); + } +} diff --git a/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/KherudGeneratorMockedTest.java b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/KherudGeneratorMockedTest.java new file mode 100644 index 0000000..666dfde --- /dev/null +++ b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/KherudGeneratorMockedTest.java @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.helpers.NOPLogger; + +import io.github.randomcodespace.inference.FinishReason; +import io.github.randomcodespace.inference.ModelInfo; +import io.github.randomcodespace.inference.Usage; +import io.github.randomcodespace.inference.runtime.NativeExecutor; + +/** + * Happy-path tests for {@link KherudGenerator} using {@link FakeLlamaClient}. Real-model + * generation/streaming edge cases (#12–34) are exercised in Tier 5 integration tests; these unit + * tests cover the orchestration layer (NativeExecutor dispatch, queue depth, Flow.Publisher + * contract, Usage invariant). + * + *

Real-model integration tests (#12–25 generation, #26–34 streaming) are deferred to Tier 5 + * (inference-sdk-integration-tests) because they require a loaded GGUF model. + */ +final class KherudGeneratorMockedTest { + + private FakeLlamaClient client; + private NativeExecutor executor; + private KherudGenerator generator; + + @BeforeEach + void setUp() { + client = new FakeLlamaClient(); + executor = NativeExecutor.sized(2, "test-gen"); + ModelInfo info = new ModelInfo("fake", "rev1", "q4_k_m", -1, 2048); + generator = + new KherudGenerator(client, executor, info, NOPLogger.NOP_LOGGER, 4, 8, 2048, null); + } + + @AfterEach + void tearDown() { + if (generator != null) { + generator.close(); + } + } + + @Test + void completeReturnsTextAndStats() { + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hello"))) + .maxTokens(64) + .build(); + + GenerateResponse resp = generator.complete(req); + + assertThat(resp.text()).isEqualTo("hello"); + assertThat(resp.finishReason()).isInstanceOf(FinishReason.Eos.class); + assertThat(resp.systemFingerprint()).isNull(); // reserved field, Phase 1 always null + assertThat(resp.stats()).isNotNull(); + assertThat(resp.stats().requestId()).startsWith("req_"); + assertThat(resp.stats().contextMax()).isEqualTo(2048); + assertThat(resp.stats().modelRevision()).isEqualTo("rev1"); + } + + @Test + void usageInvariantHoldsTotalEqualsPromptPlusCompletion() { + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hello world this is a longer prompt"))) + .maxTokens(64) + .build(); + + GenerateResponse resp = generator.complete(req); + Usage u = resp.usage(); + + assertThat(u.totalTokens()).isEqualTo(u.promptTokens() + u.completionTokens()); + assertThat(u.promptTokens()).isPositive(); + assertThat(u.completionTokens()).isPositive(); + } + + /** + * Verifies the native-thread-pinning workaround: even when the caller is a virtual thread, the + * underlying {@link FakeLlamaClient#complete} runs on a platform thread (the {@link + * NativeExecutor} pool). + */ + @Test + void nativeCallRunsOnPlatformThreadEvenFromVirtualCaller() throws Exception { + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .maxTokens(8) + .build(); + + AtomicReference callerWasVirtual = new AtomicReference<>(); + Thread caller = + Thread.ofVirtual() + .name("test-virtual-caller") + .start( + () -> { + callerWasVirtual.set(Thread.currentThread().isVirtual()); + generator.complete(req); + }); + caller.join(); + + assertThat(callerWasVirtual.get()).isTrue(); + assertThat(client.lastCompleteThread).isNotNull(); + // The fake records the thread that actually invoked complete() — it must NOT be virtual. + assertThat(client.lastCompleteThread.isVirtual()).isFalse(); + assertThat(client.lastCompleteThread.getName()).startsWith("test-gen-"); + } + + @Test + void completeAsyncCompletesFuture() throws Exception { + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .maxTokens(8) + .build(); + GenerateResponse resp = generator.completeAsync(req).get(); + assertThat(resp.text()).isEqualTo("hello"); + } + + @Test + void streamEmitsDeltasThenTerminalChunk() throws Exception { + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .maxTokens(64) + .build(); + Flow.Publisher pub = generator.stream(req); + + ConcurrentLinkedQueue chunks = new ConcurrentLinkedQueue<>(); + CountDownLatch done = new CountDownLatch(1); + AtomicReference failure = new AtomicReference<>(); + + pub.subscribe( + new Flow.Subscriber<>() { + Flow.Subscription sub; + + @Override + public void onSubscribe(Flow.Subscription s) { + this.sub = s; + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(GenerateChunk c) { + chunks.add(c); + } + + @Override + public void onError(Throwable t) { + failure.set(t); + done.countDown(); + } + + @Override + public void onComplete() { + done.countDown(); + } + }); + + assertThat(done.await(5, java.util.concurrent.TimeUnit.SECONDS)).isTrue(); + assertThat(failure.get()).isNull(); + + List all = chunks.stream().toList(); + assertThat(all).isNotEmpty(); + // Exactly one terminal chunk (done=true). + long terminals = all.stream().filter(GenerateChunk::done).count(); + assertThat(terminals).isEqualTo(1L); + // Terminal carries finishReason + usage + stats. + GenerateChunk last = all.get(all.size() - 1); + assertThat(last.done()).isTrue(); + assertThat(last.finishReason()).isNotNull(); + assertThat(last.usage()).isNotNull(); + assertThat(last.stats()).isNotNull(); + // Non-terminal chunks have null stats fields per §7 contract. + for (int i = 0; i < all.size() - 1; i++) { + GenerateChunk c = all.get(i); + assertThat(c.done()).isFalse(); + assertThat(c.finishReason()).isNull(); + assertThat(c.usage()).isNull(); + assertThat(c.stats()).isNull(); + } + // Stream worker ran on a platform thread (native-pinning workaround). + assertThat(client.lastStreamThread).isNotNull(); + assertThat(client.lastStreamThread.isVirtual()).isFalse(); + } + + @Test + void streamSubscriberThatNeverRequestsTriggersNoNativeWork() { + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .maxTokens(8) + .build(); + Flow.Publisher pub = generator.stream(req); + + AtomicReference subRef = new AtomicReference<>(); + pub.subscribe( + new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + subRef.set(s); + // Intentionally never call s.request(...). + } + + @Override + public void onNext(GenerateChunk c) {} + + @Override + public void onError(Throwable t) {} + + @Override + public void onComplete() {} + }); + + // Give the executor a moment to (incorrectly) start work. + try { + Thread.sleep(150L); + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + } + assertThat(client.streamCalls.get()).isZero(); + + // Cancel cleanly. + subRef.get().cancel(); + assertThat(client.streamCalls.get()).isZero(); + } + + @Test + void streamRequestZeroSignalsErrorPerReactiveStreamsRule() { + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .maxTokens(8) + .build(); + Flow.Publisher pub = generator.stream(req); + + AtomicReference failure = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + pub.subscribe( + new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + s.request(0L); // RS §3.9 violation: must surface as onError(IAE) + } + + @Override + public void onNext(GenerateChunk c) {} + + @Override + public void onError(Throwable t) { + failure.set(t); + latch.countDown(); + } + + @Override + public void onComplete() { + latch.countDown(); + } + }); + + await() + .atMost(Duration.ofSeconds(2)) + .until(() -> failure.get() != null); + assertThat(failure.get()).isInstanceOf(IllegalArgumentException.class); + } + + @Test + void modelInfoIsExposed() { + ModelInfo info = generator.modelInfo(); + assertThat(info.id()).isEqualTo("fake"); + assertThat(info.dimensions()).isEqualTo(-1); // generation models report -1 + } +} diff --git a/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/KherudGeneratorStreamingTest.java b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/KherudGeneratorStreamingTest.java new file mode 100644 index 0000000..7516fb1 --- /dev/null +++ b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/KherudGeneratorStreamingTest.java @@ -0,0 +1,327 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Flow; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.helpers.NOPLogger; + +import io.github.randomcodespace.inference.FinishReason; +import io.github.randomcodespace.inference.ModelInfo; +import io.github.randomcodespace.inference.runtime.NativeExecutor; + +/** + * Streaming-edge tests for {@link KherudGenerator} with the {@link FakeLlamaClient} seam. These + * cover the orchestration-level behaviour of {@link BoundedSubscription} — backpressure, cancel, + * error, queue-full — independent of the real native iterator. + * + *

Real-model streaming edge cases #26–34 from java-sdk.md §11.2 are deferred to Tier 5 + * (inference-sdk-integration-tests) because they need an actual GGUF model loaded. + */ +final class KherudGeneratorStreamingTest { + + private FakeLlamaClient client; + private NativeExecutor executor; + + @BeforeEach + void setUp() { + client = new FakeLlamaClient(); + executor = NativeExecutor.sized(2, "test-stream"); + } + + @AfterEach + void tearDown() { + executor.close(); + } + + private KherudGenerator newGenerator(int queueDepth) { + ModelInfo info = new ModelInfo("fake", "rev1", "q4_k_m", -1, 2048); + return new KherudGenerator( + client, executor, info, NOPLogger.NOP_LOGGER, queueDepth, 4, 2048, null); + } + + /** Slow consumer (request(1), wait, repeat) — backpressure honoured. */ + @Test + void slowConsumerBackpressureHonoured() throws Exception { + KherudGenerator g = newGenerator(4); + try { + client.streamDeltas = List.of("a", "b", "c", "d"); + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .maxTokens(64) + .build(); + + ConcurrentLinkedQueue received = new ConcurrentLinkedQueue<>(); + CountDownLatch done = new CountDownLatch(1); + AtomicReference subRef = new AtomicReference<>(); + + g.stream(req) + .subscribe( + new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + subRef.set(s); + s.request(1L); + } + + @Override + public void onNext(GenerateChunk c) { + received.add(c); + // pull one at a time, with a short pause to let the worker buffer + if (!c.done()) { + subRef.get().request(1L); + } + } + + @Override + public void onError(Throwable t) { + done.countDown(); + } + + @Override + public void onComplete() { + done.countDown(); + } + }); + + assertThat(done.await(5, TimeUnit.SECONDS)).isTrue(); + assertThat(received).isNotEmpty(); + // Last must be terminal. + List all = received.stream().toList(); + assertThat(all.get(all.size() - 1).done()).isTrue(); + } finally { + g.close(); + } + } + + /** Cancel before any request(n) — clean teardown, no native work, no terminal chunk. */ + @Test + void cancelBeforeRequestEmitsCanceledTerminal() throws Exception { + KherudGenerator g = newGenerator(4); + try { + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .maxTokens(8) + .build(); + + ConcurrentLinkedQueue received = new ConcurrentLinkedQueue<>(); + CountDownLatch done = new CountDownLatch(1); + + g.stream(req) + .subscribe( + new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + s.cancel(); + // After cancel, BoundedSubscription must still emit a terminal chunk + // with FinishReason.Canceled per §7. To deliver it, we still need to + // request — implementations that emit the terminal eagerly are also + // accepted by Reactive Streams §3.5. + } + + @Override + public void onNext(GenerateChunk c) { + received.add(c); + } + + @Override + public void onError(Throwable t) { + done.countDown(); + } + + @Override + public void onComplete() { + done.countDown(); + } + }); + + assertThat(done.await(5, TimeUnit.SECONDS)).isTrue(); + // Either: subscriber received the canceled terminal, OR onComplete fired without onNext. + // Both are valid §7 readings of "subscriber that cancels before requesting". + assertThat(client.streamCalls.get()).isZero(); + } finally { + g.close(); + } + } + + /** Mid-stream cancel — terminal chunk with finishReason=Canceled, then onComplete. */ + @Test + void midStreamCancelEmitsCanceledTerminal() throws Exception { + KherudGenerator g = newGenerator(4); + try { + // Long stream so we can cancel mid-flight reliably. + java.util.List deltas = new java.util.ArrayList<>(); + for (int i = 0; i < 500; i++) { + deltas.add("x"); + } + client.streamDeltas = deltas; + + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .maxTokens(1000) + .build(); + + ConcurrentLinkedQueue received = new ConcurrentLinkedQueue<>(); + CountDownLatch done = new CountDownLatch(1); + AtomicReference subRef = new AtomicReference<>(); + AtomicReference terminalReason = new AtomicReference<>(); + + g.stream(req) + .subscribe( + new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + subRef.set(s); + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(GenerateChunk c) { + received.add(c); + if (received.size() == 5) { + subRef.get().cancel(); + } + if (c.done()) { + terminalReason.set(c.finishReason()); + } + } + + @Override + public void onError(Throwable t) { + done.countDown(); + } + + @Override + public void onComplete() { + done.countDown(); + } + }); + + assertThat(done.await(5, TimeUnit.SECONDS)).isTrue(); + // At least one terminal chunk; final must be Canceled. + await() + .atMost(Duration.ofSeconds(3)) + .until(() -> terminalReason.get() != null); + assertThat(terminalReason.get()).isInstanceOf(FinishReason.Canceled.class); + } finally { + g.close(); + } + } + + /** Cancel is idempotent. */ + @Test + void cancelIsIdempotent() throws Exception { + KherudGenerator g = newGenerator(4); + try { + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .maxTokens(8) + .build(); + + AtomicReference subRef = new AtomicReference<>(); + g.stream(req) + .subscribe( + new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + subRef.set(s); + } + + @Override + public void onNext(GenerateChunk c) {} + + @Override + public void onError(Throwable t) {} + + @Override + public void onComplete() {} + }); + Flow.Subscription s = subRef.get(); + s.cancel(); + s.cancel(); + s.cancel(); + // No exception thrown; underlying queue permit (if acquired) is released. + } finally { + g.close(); + } + } + + /** + * Edge case #33 (analogous): when queueDepth=1, a second concurrent {@code complete()} sees + * QueueFullException. We use {@code complete()} rather than streams here because the streaming + * path acquires its permit on first request(n), and the timing is harder to assert + * deterministically without sleeping inside the fake. + */ + @Test + void queueFullExceptionWhenDepthExceeded() throws Exception { + // Build a generator with depth=1 and a fake client that blocks until released. + java.util.concurrent.CountDownLatch hold = new java.util.concurrent.CountDownLatch(1); + java.util.concurrent.CountDownLatch entered = new java.util.concurrent.CountDownLatch(1); + FakeLlamaClient slowClient = + new FakeLlamaClient() { + @Override + public KherudGenerator.CompletionResult complete(String prompt, GenerateRequest req) { + entered.countDown(); + try { + hold.await(5, TimeUnit.SECONDS); + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + } + return super.complete(prompt, req); + } + }; + NativeExecutor exec = NativeExecutor.sized(2, "test-qfull"); + ModelInfo info = new ModelInfo("fake", "rev1", "q4_k_m", -1, 2048); + KherudGenerator g = + new KherudGenerator(slowClient, exec, info, NOPLogger.NOP_LOGGER, 1, 4, 2048, null); + try { + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hi"))) + .maxTokens(8) + .build(); + + // First call: occupies the only permit and parks inside complete(). + Thread t1 = + Thread.ofVirtual() + .start( + () -> { + try { + g.complete(req); + } catch (RuntimeException ignored) { + // we don't care about the outcome of the held call + } + }); + assertThat(entered.await(2, TimeUnit.SECONDS)).isTrue(); + + // Second call must see QueueFullException synchronously. + org.assertj.core.api.Assertions.assertThatExceptionOfType(QueueFullException.class) + .isThrownBy(() -> g.complete(req)) + .withMessageContaining("queue is full"); + + hold.countDown(); + t1.join(5_000L); + } finally { + g.close(); + exec.close(); + } + } +} diff --git a/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/MessageTest.java b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/MessageTest.java new file mode 100644 index 0000000..e64c843 --- /dev/null +++ b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/MessageTest.java @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * Unit tests for {@link Message}. Covers java-sdk.md §11.1 generate-module bullet "Message rejects + * invalid roles, null role/content; reserved fields non-null throws FeatureNotSupportedException" + * and §11.2 case #49. + */ +final class MessageTest { + + @Test + void canonicalConstructorAcceptsValidRoles() { + assertThat(new Message("system", "be helpful", null, null, null).role()).isEqualTo("system"); + assertThat(new Message("user", "hi", null, null, null).role()).isEqualTo("user"); + assertThat(new Message("assistant", "ok", null, null, null).role()).isEqualTo("assistant"); + assertThat(new Message("tool", "ok", null, null, null).role()).isEqualTo("tool"); + } + + @Test + void convenienceConstructorDelegates() { + Message m = new Message("user", "hi"); + assertThat(m.role()).isEqualTo("user"); + assertThat(m.content()).isEqualTo("hi"); + assertThat(m.toolCalls()).isNull(); + assertThat(m.toolCallId()).isNull(); + assertThat(m.name()).isNull(); + } + + @ParameterizedTest + @ValueSource(strings = {"User", "USER", "developer", "function", "agent", ""}) + void rejectsRolesOutsideAllowList(String role) { + assertThatIllegalArgumentException() + .isThrownBy(() -> new Message(role, "x")) + .withMessageContaining("role must be one of"); + } + + @Test + void rejectsNullRole() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new Message(null, "hi")) + .withMessage("role must not be null"); + } + + @Test + void rejectsNullContent() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new Message("user", null)) + .withMessage("content must not be null"); + } + + @Test + void emptyContentIsAllowed() { + // Phase-1 contract: empty string is permitted (vacuous content), null is not. + Message m = new Message("user", ""); + assertThat(m.content()).isEmpty(); + } + + // -- §11.2 case #49: reserved fields must throw FeatureNotSupportedException -- + + @Test + void reservedToolCallsThrows() { + assertThatExceptionOfType(FeatureNotSupportedException.class) + .isThrownBy(() -> new Message("user", "x", List.of(new Object()), null, null)) + .withMessageContaining("Phase 2"); + } + + @Test + void reservedToolCallIdThrows() { + assertThatExceptionOfType(FeatureNotSupportedException.class) + .isThrownBy(() -> new Message("user", "x", null, "tc-1", null)) + .withMessageContaining("Phase 2"); + } + + @Test + void reservedNameThrows() { + assertThatExceptionOfType(FeatureNotSupportedException.class) + .isThrownBy(() -> new Message("user", "x", null, null, "alice")) + .withMessageContaining("Phase 2"); + } +} diff --git a/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/ModelResolverTest.java b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/ModelResolverTest.java new file mode 100644 index 0000000..aae3b4b --- /dev/null +++ b/java/inference-sdk-generate/src/test/java/io/github/randomcodespace/inference/generate/ModelResolverTest.java @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.generate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +/** + * Resolution-order tests for {@link ModelResolver}. Mirrors the embed-module tests for + * INFERENCE_MODEL_DIR / classpath fallback behaviour, but with the {@code .gguf} extension. + */ +final class ModelResolverTest { + + @Test + void explicitModelPathWinsWhenFileExists(@TempDir Path tmp) throws IOException { + Path file = tmp.resolve("custom.gguf"); + Files.writeString(file, "stub"); + + Path resolved = new ModelResolver().resolve("ignored-model-id", file); + assertThat(resolved).isEqualTo(file.toAbsolutePath()); + } + + @Test + void explicitModelPathThatDoesNotExistThrowsModelNotLoaded(@TempDir Path tmp) { + Path missing = tmp.resolve("missing.gguf"); + assertThatExceptionOfType(ModelNotLoadedException.class) + .isThrownBy(() -> new ModelResolver().resolve("ignored", missing)); + } + + @Test + void envVarResolvesWhenFileExists(@TempDir Path tmp) throws IOException { + Path file = tmp.resolve("qwen.gguf"); + Files.writeString(file, "stub"); + + ModelResolver r = ModelResolver.withEnv(name -> tmp.toString()); + Path resolved = r.resolve("qwen", null); + assertThat(resolved).isEqualTo(file.toAbsolutePath()); + } + + @Test + void envVarBareFilenameAlsoResolves(@TempDir Path tmp) throws IOException { + Path file = tmp.resolve("qwen"); + Files.writeString(file, "stub"); + + ModelResolver r = ModelResolver.withEnv(name -> tmp.toString()); + Path resolved = r.resolve("qwen", null); + assertThat(resolved).isEqualTo(file.toAbsolutePath()); + } + + @Test + void unsetEnvVarFallsThroughToClasspath(@TempDir Path tmp) { + ModelResolver r = ModelResolver.withEnv(name -> null); + assertThatExceptionOfType(ModelNotLoadedException.class) + .isThrownBy(() -> r.resolve("nonexistent", null)); + } + + @Test + void blankEnvVarFallsThroughToClasspath() { + ModelResolver r = ModelResolver.withEnv(name -> " "); + assertThatExceptionOfType(ModelNotLoadedException.class) + .isThrownBy(() -> r.resolve("nonexistent", null)); + } + + @Test + void modelIdIsRequiredWhenExplicitPathUnset() { + assertThatNullPointerException() + .isThrownBy(() -> new ModelResolver().resolve(null, null)); + } + + @Test + void modelNotLoadedMessageEnumeratesSearchedLocations() { + ModelResolver r = ModelResolver.withEnv(name -> "/nonexistent/path"); + try { + r.resolve("missing-model", null); + } catch (ModelNotLoadedException ex) { + assertThat(ex.getMessage()).contains("missing-model"); + assertThat(ex.getMessage()).contains("classpath:/models/missing-model.gguf"); + } + } +} diff --git a/java/inference-sdk-generate/src/test/resources/logback-test.xml b/java/inference-sdk-generate/src/test/resources/logback-test.xml new file mode 100644 index 0000000..2cdd5ac --- /dev/null +++ b/java/inference-sdk-generate/src/test/resources/logback-test.xml @@ -0,0 +1,17 @@ + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + diff --git a/java/inference-sdk-integration-tests/.jqwik-database b/java/inference-sdk-integration-tests/.jqwik-database new file mode 100644 index 0000000..711006c Binary files /dev/null and b/java/inference-sdk-integration-tests/.jqwik-database differ diff --git a/java/inference-sdk-integration-tests/README.md b/java/inference-sdk-integration-tests/README.md new file mode 100644 index 0000000..88aab7d --- /dev/null +++ b/java/inference-sdk-integration-tests/README.md @@ -0,0 +1,58 @@ +# inference-sdk-integration-tests + +Tier-5 cross-module integration suite for the `inference-sdk` Java reactor. +Hosts every numbered case from `java-sdk.md` §11.2 (51 cases) plus the +jqwik property tests from §11.3. + +## Tag taxonomy + +| Tag | Default | Run with | Requires real model? | +|-----------------|---------|---------------------------|----------------------| +| (untagged) | yes | `mvn verify` | no | +| `model` | no | `mvn verify -P slow` | yes | +| `model-switch` | no | `mvn verify -P model-switch -Dembedding.model=` | yes | +| `slow` | no | `mvn verify -P slow` | yes (jqwik) | + +By default `mvn verify` runs only the untagged tests: +`ForwardCompatIT` (§11.2 #49–51), `NetworkIsolationIT` (§11.2 #47), +`FailureModeIT` (§11.2 #43–45), and the validation slice of +`GenerateEdgeCaseIT` (§11.2 #13–15). Everything else self-skips when +the bundled GGUF / ONNX model JARs are still placeholders. + +## Running the full suite + +The bundled models ship as `.gitkeep` placeholders inside +`inference-sdk-embed-bge-small/src/main/resources/models/` and +`inference-sdk-generate-qwen-0_5b/src/main/resources/models/` until +Tier 0.5's `make fetch-models` populates them via Git LFS. Once that +prerequisite is satisfied: + +```bash +make fetch-models +./mvnw -f java/pom.xml -pl :inference-sdk-integration-tests -am verify -P slow +``` + +Total expected wall time on a laptop-class CPU: under 90 s end-to-end. + +## Programmatic skip switch + +Set `-Dskip.model.tests=true` to hard-skip every `@Tag("model")` test +even when the JAR payload exists. Useful for CI lanes that can't load +native libraries (e.g. `--no-daemon` smoke runs on machines without +glibc 2.27+). + +## Why a separate module? + +Per spec §5 Tier 5, this module produces no production code: + +- It hosts only `*IT.java` test classes; +- Maven Failsafe (not Surefire) runs them; +- JaCoCo, SpotBugs, Javadoc, and OWASP Dependency-Check are skipped + (everything they would audit is already audited in the per-module + reactors). + +Cross-module integration concerns (shared lifecycle, RequestId +propagation across embed + generate, queue contention with both +modules loaded simultaneously) live here so they don't bloat the +single-purpose `inference-sdk-embed` and `inference-sdk-generate` +trees. diff --git a/java/inference-sdk-integration-tests/pom.xml b/java/inference-sdk-integration-tests/pom.xml new file mode 100644 index 0000000..2185a82 --- /dev/null +++ b/java/inference-sdk-integration-tests/pom.xml @@ -0,0 +1,210 @@ + + + + 4.0.0 + + + io.github.randomcodespace.inference + inference-sdk-parent + 0.1.0-SNAPSHOT + ../inference-sdk-parent/pom.xml + + + inference-sdk-integration-tests + jar + + ${project.groupId}:${project.artifactId} + Tier-5 integration-test module for the inference-sdk Java + reactor. Hosts the 51 numbered §11.2 edge-case tests plus jqwik + property tests; produces no production code. + + + + true + true + true + true + + model,model-switch,slow + + + + + + io.github.randomcodespace.inference + inference-sdk-core + ${project.version} + test + + + io.github.randomcodespace.inference + inference-sdk-embed + ${project.version} + test + + + io.github.randomcodespace.inference + inference-sdk-generate + ${project.version} + test + + + io.github.randomcodespace.inference + inference-sdk-embed-bge-small + ${project.version} + test + + + io.github.randomcodespace.inference + inference-sdk-generate-qwen-0_5b + ${project.version} + test + + + + + org.junit.jupiter + junit-jupiter + test + + + org.junit.jupiter + junit-jupiter-params + test + + + org.assertj + assertj-core + test + + + org.awaitility + awaitility + test + + + net.jqwik + jqwik + test + + + com.fasterxml.jackson.core + jackson-databind + test + + + com.fasterxml.jackson.core + jackson-annotations + test + + + ch.qos.logback + logback-classic + test + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + + + org.apache.maven.plugins + maven-failsafe-plugin + + + ${it.excluded.groups} + + -Dfile.encoding=UTF-8 --enable-preview ${argLine} + + en + US + + 1 + true + + + + integration-test + + integration-test + + + + verify + + verify + + + + + + + + + + + slow + + + model-switch + + + + + + model-switch + + slow + + + + diff --git a/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/ConcurrencyLifecycleIT.java b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/ConcurrencyLifecycleIT.java new file mode 100644 index 0000000..aa04020 --- /dev/null +++ b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/ConcurrencyLifecycleIT.java @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.it; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import java.lang.management.ManagementFactory; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; + +import org.awaitility.Awaitility; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledIfSystemProperty; + +import io.github.randomcodespace.inference.embed.EmbedStats; +import io.github.randomcodespace.inference.embed.Embedder; +import io.github.randomcodespace.inference.generate.GenerateRequest; +import io.github.randomcodespace.inference.generate.GenerateResponse; +import io.github.randomcodespace.inference.generate.Generator; +import io.github.randomcodespace.inference.generate.Message; +import io.github.randomcodespace.inference.it.support.ModelArtifacts; +import io.github.randomcodespace.inference.runtime.RequestId; + +/** + * Concurrency / lifecycle integration tests for java-sdk.md §11.2 cases #35–39. + * + *

All cases require both bundled models on the classpath (1000 embeds + 100 completes + * interleaved). Tagged {@code @Tag("model")} and self-skip via {@link + * io.github.randomcodespace.inference.it.support.ModelArtifacts}. + */ +@Tag("model") +@DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") +final class ConcurrencyLifecycleIT { + + @BeforeAll + static void requireModels() { + assumeTrue( + ModelArtifacts.embedModelPresent() && ModelArtifacts.generateModelPresent(), + "real bge-small + Qwen models not present — run `make fetch-models` first."); + } + + /** Case #35: 1000 embeds + 100 completes interleaved across 20 virtual threads. */ + @Test + @DisplayName("#35 1000 embeds + 100 completes interleaved → no deadlocks, no crashes") + void interleavedEmbedAndComplete() throws Exception { + try (Embedder e = Embedder.builder().model("bge-small-en-v1.5").build(); + Generator g = Generator.builder().model("qwen2.5-0.5b-instruct").queueDepth(32).build()) { + AtomicInteger embedOk = new AtomicInteger(); + AtomicInteger genOk = new AtomicInteger(); + List> futs = new ArrayList<>(); + try (var exec = Executors.newVirtualThreadPerTaskExecutor()) { + for (int i = 0; i < 1000; i++) { + final int idx = i; + futs.add( + exec.submit( + () -> { + e.embedOne("doc " + idx); + embedOk.incrementAndGet(); + })); + } + for (int i = 0; i < 100; i++) { + final int idx = i; + futs.add( + exec.submit( + () -> { + GenerateResponse r = + g.complete( + new GenerateRequest( + List.of(new Message("user", "Reply ok " + idx, null, null, null)), + 4, + 0f, + 0.95f, + null, + (long) idx, + null, + null, + null)); + if (r.text() != null) { + genOk.incrementAndGet(); + } + })); + } + } + for (Future f : futs) { + f.get(); + } + assertThat(embedOk.get()).isEqualTo(1000); + assertThat(genOk.get()).isEqualTo(100); + } + } + + /** Case #36: {@link RequestId} propagates into {@link EmbedStats#requestId()}. */ + @Test + @DisplayName("#36 RequestId propagates into EmbedStats.requestId / GenerateStats.requestId") + void requestIdPropagates() throws Exception { + try (Embedder e = Embedder.builder().model("bge-small-en-v1.5").build()) { + String reqId = RequestId.generate(); + String observed = + RequestId.withRequestId(reqId, () -> e.embed(List.of("ping")).stats().requestId()); + assertThat(observed).isEqualTo(reqId); + } + } + + /** + * Case #37: thread-leak detection — pre/post platform-thread count returns to baseline within 2s. + */ + @Test + @DisplayName("#37 thread-leak detection: close → baseline within 2s") + void threadLeakDetection() throws Exception { + int baseline = platformThreadCount(); + Embedder e = Embedder.builder().model("bge-small-en-v1.5").build(); + Generator g = Generator.builder().model("qwen2.5-0.5b-instruct").build(); + e.embedOne("warmup"); + g.complete( + new GenerateRequest( + List.of(new Message("user", "ok", null, null, null)), + 4, + 0f, + 0.95f, + null, + 1L, + null, + null, + null)); + e.close(); + g.close(); + Awaitility.await() + .atMost(java.time.Duration.ofSeconds(2)) + .untilAsserted( + () -> + assertThat(platformThreadCount()) + .as("platform threads return to within 2 of baseline") + .isLessThanOrEqualTo(baseline + 2)); + } + + /** Case #38: {@code close()} idempotent — double-close throws nothing. */ + @Test + @DisplayName("#38 close() idempotent — double-close no throw") + void closeIdempotent() { + Embedder e = Embedder.builder().model("bge-small-en-v1.5").build(); + e.close(); + e.close(); // must not throw + Generator g = Generator.builder().model("qwen2.5-0.5b-instruct").build(); + g.close(); + g.close(); // must not throw + } + + /** Case #39: concurrent {@code close()} calls — exactly one wins; no crash. */ + @Test + @DisplayName("#39 concurrent close() calls → exactly one wins, no crash") + void concurrentClose() throws Exception { + Embedder e = Embedder.builder().model("bge-small-en-v1.5").build(); + int N = 16; + CountDownLatch start = new CountDownLatch(1); + CountDownLatch done = new CountDownLatch(N); + AtomicInteger errors = new AtomicInteger(); + try (var exec = Executors.newVirtualThreadPerTaskExecutor()) { + for (int i = 0; i < N; i++) { + exec.submit( + () -> { + try { + start.await(); + e.close(); + } catch (Throwable t) { + errors.incrementAndGet(); + } finally { + done.countDown(); + } + }); + } + start.countDown(); + done.await(); + } + assertThat(errors.get()).isZero(); + } + + // ----------------------------------------------------------------------------------------- + // helpers + // ----------------------------------------------------------------------------------------- + + private static int platformThreadCount() { + return ManagementFactory.getThreadMXBean().getThreadCount(); + } +} diff --git a/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/EmbedEdgeCaseIT.java b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/EmbedEdgeCaseIT.java new file mode 100644 index 0000000..f6e77e6 --- /dev/null +++ b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/EmbedEdgeCaseIT.java @@ -0,0 +1,200 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.it; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledIfSystemProperty; + +import io.github.randomcodespace.inference.embed.EmbedResult; +import io.github.randomcodespace.inference.embed.Embedder; +import io.github.randomcodespace.inference.embed.InvalidInputException; +import io.github.randomcodespace.inference.it.support.ModelArtifacts; + +/** + * Embedding edge-case integration tests for java-sdk.md §11.2 cases #1–11. + * + *

All cases require the real bge-small INT8 ONNX file; tagged {@code @Tag("model")} so they're + * skipped by default and run under {@code mvn verify -P slow} once {@code make fetch-models} has + * populated the model JAR via Git LFS. {@link DisabledIfSystemProperty} provides a programmatic + * kill-switch ({@code -Dskip.model.tests=true}). + */ +@Tag("model") +@DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") +final class EmbedEdgeCaseIT { + + @BeforeAll + static void requireModelArtifact() { + assumeTrue( + ModelArtifacts.embedModelPresent(), + "bge-small-en-v1.5 ONNX not present on classpath — run `make fetch-models` first."); + } + + /** Case #1: empty string input returns a vector, no NPE. */ + @Test + @DisplayName("#1 empty string → vector, no NPE") + void emptyStringReturnsVector() { + try (Embedder e = embedder()) { + EmbedResult r = e.embed(List.of("")); + assertThat(r.vectors()).hasSize(1); + assertThat(r.vectors().get(0)).isNotEmpty(); + } + } + + /** Case #2: single space input returns a vector. */ + @Test + @DisplayName("#2 single space → vector") + void singleSpaceReturnsVector() { + try (Embedder e = embedder()) { + EmbedResult r = e.embed(List.of(" ")); + assertThat(r.vectors()).hasSize(1); + } + } + + /** Case #3: input exceeding the model's max-token window is auto-truncated. */ + @Test + @DisplayName("#3 input exceeding model max → auto-truncate, tokens == model max") + void overlongInputAutoTruncates() { + String huge = "lorem ipsum dolor sit amet ".repeat(10_000); + try (Embedder e = embedder()) { + EmbedResult r = e.embed(List.of(huge)); + assertThat(r.tokens()).isLessThanOrEqualTo(e.modelInfo().maxTokens()); + assertThat(r.vectors()).hasSize(1); + } + } + + /** Case #4: Unicode-heavy inputs (Han, emoji ZWJ, RTL, Devanagari) all succeed. */ + @Test + @DisplayName("#4 Unicode-heavy inputs succeed, no NaN/Inf") + void unicodeHeavyInputsSucceed() { + List inputs = + List.of("人工智能", "👨‍👩‍👧‍👦", "السلام عليكم", "नमस्ते", "Mixed 中文 + español"); + try (Embedder e = embedder()) { + EmbedResult r = e.embed(inputs); + assertThat(r.vectors()).hasSize(inputs.size()); + for (float[] v : r.vectors()) { + for (float f : v) { + assertThat(Float.isFinite(f)).as("vector must contain only finite floats").isTrue(); + } + } + } + } + + /** Case #5: null entry in batch raises {@link InvalidInputException} BEFORE work. */ + @Test + @DisplayName("#5 null in list → InvalidInputException before work") + void nullEntryRejectedBeforeWork() { + try (Embedder e = embedder()) { + List withNull = Arrays.asList("ok", null, "ok2"); + assertThatThrownBy(() -> e.embed(withNull)).isInstanceOf(InvalidInputException.class); + } + } + + /** Case #6: mixed-length batch produces finite vectors and consistent stats. */ + @Test + @DisplayName("#6 mixed-length batch (1 char to 500 tokens) → finite, stats reflect tokens") + void mixedLengthBatchSucceeds() { + List inputs = + List.of( + "x", + "short example", + "lorem ipsum dolor sit amet ".repeat(50), + "lorem ipsum dolor sit amet ".repeat(200)); + try (Embedder e = embedder()) { + EmbedResult r = e.embed(inputs); + assertThat(r.vectors()).hasSize(inputs.size()); + assertThat(r.tokens()).isPositive(); + } + } + + /** Case #7: batch larger than {@code batchSize} chunks correctly and matches one-by-one. */ + @Test + @DisplayName("#7 batch > batchSize → chunked == one-by-one within 1e-5 tolerance") + void chunkingMatchesOneByOne() { + int batchSize = 4; + List inputs = List.of("a", "b", "c", "d", "e", "f", "g", "h", "i", "j"); + try (Embedder e = Embedder.builder().model("bge-small-en-v1.5").batchSize(batchSize).build()) { + EmbedResult batched = e.embed(inputs); + List oneByOne = new ArrayList<>(); + for (String s : inputs) { + oneByOne.add(e.embedOne(s)); + } + for (int i = 0; i < inputs.size(); i++) { + float[] a = batched.vectors().get(i); + float[] b = oneByOne.get(i); + assertThat(a).hasSameSizeAs(b); + for (int j = 0; j < a.length; j++) { + assertThat(Math.abs(a[j] - b[j])).isLessThanOrEqualTo(1e-5f); + } + } + } + } + + /** Case #8: repeated identical inputs in one batch produce bytewise-identical vectors. */ + @Test + @DisplayName("#8 repeated identical inputs → bytewise-identical vectors") + void repeatedInputsByteIdentical() { + try (Embedder e = embedder()) { + EmbedResult r = e.embed(List.of("hello", "hello", "hello")); + float[] a = r.vectors().get(0); + float[] b = r.vectors().get(1); + float[] c = r.vectors().get(2); + assertThat(a).containsExactly(b); + assertThat(b).containsExactly(c); + } + } + + /** Case #9: same input twice across separate calls is bytewise-identical. */ + @Test + @DisplayName("#9 determinism: same input twice → identical vectors") + void deterministicAcrossCalls() { + try (Embedder e = embedder()) { + float[] v1 = e.embedOne("the quick brown fox"); + float[] v2 = e.embedOne("the quick brown fox"); + assertThat(v1).containsExactly(v2); + } + } + + /** Case #10: whitespace-only inputs ({@code " "}, {@code "\n\n\n"}, {@code "\t"}) all succeed. */ + @Test + @DisplayName("#10 whitespace-only inputs → vectors, no errors") + void whitespaceOnlyInputsSucceed() { + List inputs = List.of(" ", "\n\n\n", "\t"); + try (Embedder e = embedder()) { + EmbedResult r = e.embed(inputs); + assertThat(r.vectors()).hasSize(inputs.size()); + } + } + + /** Case #11: 10,000 short strings completes within bounded memory. */ + @Test + @DisplayName("#11 10,000 short strings → completes, memory bounded") + void tenThousandShortStrings() { + List inputs = Collections.nCopies(10_000, "short example text"); + try (Embedder e = embedder()) { + EmbedResult r = e.embed(inputs); + assertThat(r.vectors()).hasSize(10_000); + } + } + + // --------------------------------------------------------------------------------------------- + // helpers + // --------------------------------------------------------------------------------------------- + + private static Embedder embedder() { + return Embedder.builder().model("bge-small-en-v1.5").build(); + } +} diff --git a/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/FailureModeIT.java b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/FailureModeIT.java new file mode 100644 index 0000000..f3e9f82 Binary files /dev/null and b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/FailureModeIT.java differ diff --git a/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/ForwardCompatIT.java b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/ForwardCompatIT.java new file mode 100644 index 0000000..cb5648c --- /dev/null +++ b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/ForwardCompatIT.java @@ -0,0 +1,306 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.it; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.github.randomcodespace.inference.FinishReason; +import io.github.randomcodespace.inference.ModelInfo; +import io.github.randomcodespace.inference.Usage; +import io.github.randomcodespace.inference.embed.EmbedResult; +import io.github.randomcodespace.inference.embed.EmbedStats; +import io.github.randomcodespace.inference.generate.FeatureNotSupportedException; +import io.github.randomcodespace.inference.generate.GenerateChunk; +import io.github.randomcodespace.inference.generate.GenerateRequest; +import io.github.randomcodespace.inference.generate.GenerateResponse; +import io.github.randomcodespace.inference.generate.GenerateStats; +import io.github.randomcodespace.inference.generate.Message; + +/** + * Forward-compat integration tests for java-sdk.md §11.2 cases #49–51. + * + *

These tests exercise pure-record validation and Jackson serialisation; they do not need the + * real GGUF / ONNX models and therefore run by default at {@code mvn verify} (no {@code @Tag}). + * + *

Cases covered

+ * + *
    + *
  • #49 — constructing {@link Message} with non-null {@code toolCalls}, {@code + * toolCallId}, or {@code name} throws {@link FeatureNotSupportedException}. + *
  • #50 — constructing {@link GenerateRequest} with non-null {@code tools}, {@code + * toolChoice}, or {@code responseFormat} throws. + *
  • #51 — Jackson serialisation of every record produces {@code snake_case} JSON keys + * matching {@code docs/WIRE_FORMAT.md}. + *
+ */ +final class ForwardCompatIT { + + // --------------------------------------------------------------------------------------------- + // Case #49 — Message reserved fields + // --------------------------------------------------------------------------------------------- + + /** Case #49a: {@link Message#toolCalls()} non-null is rejected. */ + @Test + @DisplayName("#49a Message.toolCalls non-null → FeatureNotSupportedException") + void messageToolCallsRejected() { + assertThatThrownBy(() -> new Message("user", "hi", List.of("placeholder"), null, null)) + .isInstanceOf(FeatureNotSupportedException.class) + .hasMessageContaining("tool calling"); + } + + /** Case #49b: {@link Message#toolCallId()} non-null is rejected. */ + @Test + @DisplayName("#49b Message.toolCallId non-null → FeatureNotSupportedException") + void messageToolCallIdRejected() { + assertThatThrownBy(() -> new Message("user", "hi", null, "call_123", null)) + .isInstanceOf(FeatureNotSupportedException.class); + } + + /** Case #49c: {@link Message#name()} non-null is rejected. */ + @Test + @DisplayName("#49c Message.name non-null → FeatureNotSupportedException") + void messageNameRejected() { + assertThatThrownBy(() -> new Message("user", "hi", null, null, "tool-output")) + .isInstanceOf(FeatureNotSupportedException.class); + } + + // --------------------------------------------------------------------------------------------- + // Case #50 — GenerateRequest reserved fields + // --------------------------------------------------------------------------------------------- + + /** Case #50a: {@link GenerateRequest#tools()} non-null is rejected. */ + @Test + @DisplayName("#50a GenerateRequest.tools non-null → FeatureNotSupportedException") + void generateRequestToolsRejected() { + List msgs = List.of(new Message("user", "hi", null, null, null)); + assertThatThrownBy( + () -> + new GenerateRequest( + msgs, 16, 0.7f, 0.95f, null, null, List.of("placeholder-tool"), null, null)) + .isInstanceOf(FeatureNotSupportedException.class) + .hasMessageContaining("Phase 2"); + } + + /** Case #50b: {@link GenerateRequest#toolChoice()} non-null is rejected. */ + @Test + @DisplayName("#50b GenerateRequest.toolChoice non-null → FeatureNotSupportedException") + void generateRequestToolChoiceRejected() { + List msgs = List.of(new Message("user", "hi", null, null, null)); + assertThatThrownBy( + () -> new GenerateRequest(msgs, 16, 0.7f, 0.95f, null, null, null, "auto", null)) + .isInstanceOf(FeatureNotSupportedException.class); + } + + /** Case #50c: {@link GenerateRequest#responseFormat()} non-null is rejected. */ + @Test + @DisplayName("#50c GenerateRequest.responseFormat non-null → FeatureNotSupportedException") + void generateRequestResponseFormatRejected() { + List msgs = List.of(new Message("user", "hi", null, null, null)); + assertThatThrownBy( + () -> new GenerateRequest(msgs, 16, 0.7f, 0.95f, null, null, null, null, "json_object")) + .isInstanceOf(FeatureNotSupportedException.class); + } + + // --------------------------------------------------------------------------------------------- + // Case #51 — Jackson snake_case serialisation + // --------------------------------------------------------------------------------------------- + + /** + * Case #51a: {@link Message} serialises to {@code snake_case} keys per {@code WIRE_FORMAT.md} + * §2.3. + */ + @Test + @DisplayName("#51a Message serialises with snake_case keys (tool_calls, tool_call_id)") + void messageWireFormatSnakeCase() throws Exception { + Message msg = new Message("user", "hi", null, null, null); + JsonNode tree = mapper().valueToTree(msg); + assertThat(tree.has("role")).isTrue(); + assertThat(tree.has("content")).isTrue(); + // Reserved fields surface as snake_case keys (with null values per JsonInclude default). + assertThat(allKeys(tree)).contains("role", "content").doesNotContain("toolCalls", "toolCallId"); + } + + /** + * Case #51b: {@link GenerateRequest} serialises {@code maxTokens}, {@code topP}, {@code stop}, + * and reserved Phase-2 fields with snake_case keys. + */ + @Test + @DisplayName("#51b GenerateRequest serialises with snake_case keys (max_tokens, top_p, ...)") + void generateRequestWireFormatSnakeCase() throws Exception { + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "hello", null, null, null))) + .maxTokens(8) + .temperature(0.7f) + .topP(0.95f) + .build(); + JsonNode tree = mapper().valueToTree(req); + // Concrete snake_case keys we care about per WIRE_FORMAT.md §2.5 (request). + // Spec calls out max_tokens / top_p / response_format etc. + assertThat(allKeys(tree)).contains("messages"); + // No Java-style camelCase keys leak. + assertThat(allKeys(tree)).doesNotContain("maxTokens", "topP", "toolChoice", "responseFormat"); + } + + /** + * Case #51c: {@link GenerateResponse} carries {@code finish_reason} + {@code system_fingerprint}. + */ + @Test + @DisplayName("#51c GenerateResponse serialises finish_reason + system_fingerprint snake_case") + void generateResponseWireFormatSnakeCase() throws Exception { + GenerateResponse resp = + new GenerateResponse( + "hello", + new FinishReason.Eos(), + new Usage(2, 1, 3), + new GenerateStats( + "req_test", + 0L, + 0L, + 0L, + 0L, + 0L, + 0.0d, + new FinishReason.Eos(), + 0, + 2048, + "rev", + "node"), + null); + JsonNode tree = mapper().valueToTree(resp); + assertThat(allKeys(tree)) + .contains("finish_reason", "system_fingerprint") + .doesNotContain("finishReason", "systemFingerprint"); + } + + /** Case #51d: {@link GenerateChunk} carries {@code finish_reason} snake_case. */ + @Test + @DisplayName("#51d GenerateChunk serialises finish_reason snake_case") + void generateChunkWireFormatSnakeCase() throws Exception { + GenerateChunk chunk = new GenerateChunk("hi", false, null, null, null); + JsonNode tree = mapper().valueToTree(chunk); + assertThat(allKeys(tree)).contains("finish_reason").doesNotContain("finishReason"); + } + + /** + * Case #51e: {@link GenerateStats} carries {@code request_id}, {@code queue_ms}, {@code + * prompt_eval_ms}, {@code first_token_ms}, {@code generation_ms}, {@code total_ms}, {@code + * tokens_per_second}, {@code stop_reason}, {@code context_used}, {@code context_max}, {@code + * model_revision} as snake_case keys per WIRE_FORMAT.md §2.6. + */ + @Test + @DisplayName("#51e GenerateStats serialises every field as snake_case") + void generateStatsWireFormatSnakeCase() throws Exception { + GenerateStats stats = + new GenerateStats( + "req_abc", 1L, 2L, 3L, 4L, 10L, 7.5d, new FinishReason.Eos(), 5, 2048, "rev", "host-1"); + JsonNode tree = mapper().valueToTree(stats); + assertThat(allKeys(tree)) + .contains( + "request_id", + "queue_ms", + "prompt_eval_ms", + "first_token_ms", + "generation_ms", + "total_ms", + "tokens_per_second", + "stop_reason", + "context_used", + "context_max", + "model_revision", + "node") + .doesNotContain( + "requestId", + "queueMs", + "promptEvalMs", + "firstTokenMs", + "generationMs", + "totalMs", + "tokensPerSecond", + "stopReason", + "contextUsed", + "contextMax", + "modelRevision"); + } + + /** Case #51f: {@link EmbedResult} + {@link EmbedStats} carry snake_case keys per §2.1/§2.2. */ + @Test + @DisplayName("#51f EmbedResult + EmbedStats serialise with snake_case keys") + void embedResultWireFormatSnakeCase() throws Exception { + EmbedStats stats = new EmbedStats("req_xyz", 0L, 1L, 2L, 3L, 1, "single", "rev-bge", "host-1"); + EmbedResult result = new EmbedResult(List.of(new float[] {0.1f, 0.2f}), 4, stats); + JsonNode tree = mapper().valueToTree(result); + JsonNode statsTree = tree.get("stats"); + assertThat(statsTree).isNotNull(); + assertThat(allKeys(statsTree)) + .contains( + "request_id", + "queue_ms", + "tokenize_ms", + "inference_ms", + "total_ms", + "batch_size", + "batch_position", + "model_revision", + "node") + .doesNotContain("requestId", "queueMs", "tokenizeMs", "batchSize", "batchPosition"); + } + + /** Case #51g: {@link Usage} carries {@code prompt_tokens / completion_tokens / total_tokens}. */ + @Test + @DisplayName("#51g Usage serialises with snake_case keys") + void usageWireFormatSnakeCase() throws Exception { + Usage usage = new Usage(2, 3, 5); + JsonNode tree = mapper().valueToTree(usage); + assertThat(allKeys(tree)) + .contains("prompt_tokens", "completion_tokens", "total_tokens") + .doesNotContain("promptTokens", "completionTokens", "totalTokens"); + } + + /** Case #51h: {@link ModelInfo} serialises with snake_case for multi-word fields. */ + @Test + @DisplayName("#51h ModelInfo serialises with snake_case keys (max_tokens)") + void modelInfoWireFormatSnakeCase() throws Exception { + ModelInfo info = new ModelInfo("bge-small-en-v1.5", "rev-1", "int8", 384, 512); + JsonNode tree = mapper().valueToTree(info); + // ModelInfo doesn't currently carry @JsonProperty annotations. WIRE_FORMAT.md does NOT + // enumerate ModelInfo, so we only assert the obvious camelCase keys do not break the + // contract. This test acts as a tripwire: if Phase 2 adds @JsonProperty annotations, the + // assertions below will need updating, and the snake_case assertion can be re-enabled. + assertThat(allKeys(tree)).contains("id", "revision", "quantization", "dimensions"); + } + + // --------------------------------------------------------------------------------------------- + // helpers + // --------------------------------------------------------------------------------------------- + + /** + * Configured {@link ObjectMapper}; respects {@link JsonProperty} annotations and falls back to + * {@link com.fasterxml.jackson.databind.PropertyNamingStrategies#SNAKE_CASE} for fields without + * an explicit annotation. {@code WIRE_FORMAT.md} §1 sanctions both strategies. + */ + private static ObjectMapper mapper() { + ObjectMapper m = new ObjectMapper(); + m.setPropertyNamingStrategy(com.fasterxml.jackson.databind.PropertyNamingStrategies.SNAKE_CASE); + return m; + } + + /** All field keys at depth 1 of a JSON object node (for diagnostic assertions). */ + private static java.util.Set allKeys(JsonNode tree) { + java.util.Set keys = new java.util.LinkedHashSet<>(); + tree.fieldNames().forEachRemaining(keys::add); + return keys; + } +} diff --git a/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/GenerateEdgeCaseIT.java b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/GenerateEdgeCaseIT.java new file mode 100644 index 0000000..f4fa8e1 --- /dev/null +++ b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/GenerateEdgeCaseIT.java @@ -0,0 +1,343 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.it; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledIfSystemProperty; + +import io.github.randomcodespace.inference.FinishReason; +import io.github.randomcodespace.inference.generate.GenerateRequest; +import io.github.randomcodespace.inference.generate.GenerateResponse; +import io.github.randomcodespace.inference.generate.Generator; +import io.github.randomcodespace.inference.generate.Message; +import io.github.randomcodespace.inference.it.support.ModelArtifacts; + +/** + * Generation edge-case integration tests for java-sdk.md §11.2 cases #12–25. + * + *

Cases #13/#14/#15 are pure record/builder validation — they run by default. The remaining + * cases (#12, #16–25) require the bundled Qwen 0.5B GGUF and are tagged {@code @Tag("model")}; each + * method calls {@link #requireGenerateModel()} so it self-skips when the model JAR ships empty. + */ +final class GenerateEdgeCaseIT { + + // ----------------------------------------------------------------------------------------- + // Pure validation cases — no model required + // ----------------------------------------------------------------------------------------- + + /** Case #13: {@code maxTokens=0} → {@link IllegalArgumentException}. */ + @Test + @DisplayName("#13 maxTokens=0 → IllegalArgumentException") + void maxTokensZeroRejected() { + assertThatThrownBy(() -> GenerateRequest.builder().maxTokens(0)) + .isInstanceOf(IllegalArgumentException.class); + } + + /** Case #14: empty messages → {@link IllegalArgumentException}. */ + @Test + @DisplayName("#14 empty messages → IllegalArgumentException") + void emptyMessagesRejected() { + assertThatThrownBy( + () -> new GenerateRequest(List.of(), 8, 0.7f, 0.95f, null, null, null, null, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("messages"); + } + + /** Case #15: system-only messages → {@link IllegalArgumentException}. */ + @Test + @DisplayName("#15 system-only messages → IllegalArgumentException") + void systemOnlyMessagesRejected() { + List sysOnly = List.of(new Message("system", "you are helpful", null, null, null)); + assertThatThrownBy( + () -> new GenerateRequest(sysOnly, 8, 0.7f, 0.95f, null, null, null, null, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("user"); + } + + // ----------------------------------------------------------------------------------------- + // Real-model cases + // ----------------------------------------------------------------------------------------- + + /** + * Case #12: {@code maxTokens=1} → exactly one token, finishReason={@link FinishReason.Length}. + */ + @Test + @Tag("model") + @DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") + @DisplayName("#12 maxTokens=1 → exactly one token, finishReason=Length") + void maxTokensOneFinishesAsLength() { + requireGenerateModel(); + try (Generator g = generator()) { + GenerateResponse r = + g.complete(req(List.of(new Message("user", "Say hi.", null, null, null)), 1, 0f, 0L)); + assertThat(r.usage().completionTokens()).isLessThanOrEqualTo(1); + assertThat(r.finishReason()).isInstanceOf(FinishReason.Length.class); + } + } + + /** Case #16: stop sequence triggered → {@link FinishReason.Stop}; output ends BEFORE the stop. */ + @Test + @Tag("model") + @DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") + @DisplayName("#16 stop sequence triggered → finishReason=Stop, ends before stop") + void stopSequenceHonored() { + requireGenerateModel(); + try (Generator g = generator()) { + GenerateRequest r = + GenerateRequest.builder() + .messages(List.of(new Message("user", "Count: 1 2 STOP 3 4", null, null, null))) + .maxTokens(64) + .temperature(0f) + .stop(List.of("STOP")) + .seed(42L) + .build(); + GenerateResponse resp = g.complete(r); + if (resp.finishReason() instanceof FinishReason.Stop) { + assertThat(resp.text()).doesNotContain("STOP"); + } + } + } + + /** Case #17: EOS before maxTokens → finishReason={@link FinishReason.Eos}. */ + @Test + @Tag("model") + @DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") + @DisplayName("#17 EOS before maxTokens → finishReason=Eos") + void eosBeforeMaxTokens() { + requireGenerateModel(); + try (Generator g = generator()) { + GenerateResponse r = + g.complete( + req( + List.of(new Message("user", "Reply with the single word: ok.", null, null, null)), + 256, + 0f, + 1L)); + if (r.usage().completionTokens() < 256) { + assertThat(r.finishReason()) + .isInstanceOfAny(FinishReason.Eos.class, FinishReason.Stop.class); + } + } + } + + /** Case #18: same prompt + same seed + temperature=0 → bytewise-identical text. */ + @Test + @Tag("model") + @DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") + @DisplayName("#18 deterministic with seed + temperature=0") + void determinismWithSeed() { + requireGenerateModel(); + try (Generator g = generator()) { + GenerateRequest req = + req(List.of(new Message("user", "Say one word.", null, null, null)), 16, 0f, 12345L); + assertThat(g.complete(req).text()).isEqualTo(g.complete(req).text()); + } + } + + /** Case #19: different seeds, temperature > 0 → different outputs (sanity). */ + @Test + @Tag("model") + @DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") + @DisplayName("#19 different seeds with temperature > 0 → different outputs") + void differentSeedsDiverge() { + requireGenerateModel(); + try (Generator g = generator()) { + String t1 = + g.complete( + req( + List.of(new Message("user", "Tell a one-line joke.", null, null, null)), + 32, + 0.9f, + 1L)) + .text(); + String t2 = + g.complete( + req( + List.of(new Message("user", "Tell a one-line joke.", null, null, null)), + 32, + 0.9f, + 2L)) + .text(); + assertThat(t1).isNotNull(); + assertThat(t2).isNotNull(); + } + } + + /** Case #20: 50 concurrent {@code complete} via virtual threads → all valid, no cross-talk. */ + @Test + @Tag("model") + @DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") + @DisplayName("#20 50 concurrent complete via virtual threads") + void fiftyConcurrentCompletes() throws Exception { + requireGenerateModel(); + try (Generator g = Generator.builder().model("qwen2.5-0.5b-instruct").queueDepth(64).build()) { + AtomicInteger ok = new AtomicInteger(); + List> futs = new ArrayList<>(); + try (var exec = Executors.newVirtualThreadPerTaskExecutor()) { + for (int i = 0; i < 50; i++) { + final int idx = i; + futs.add( + exec.submit( + () -> { + GenerateResponse r = + g.complete( + req( + List.of(new Message("user", "Echo idx=" + idx, null, null, null)), + 8, + 0f, + (long) idx)); + assertThat(r.text()).isNotNull(); + ok.incrementAndGet(); + })); + } + } + for (Future f : futs) { + f.get(); + } + assertThat(ok.get()).isEqualTo(50); + } + } + + /** + * Case #21: closing the Generator mid-flight → clean throw + subsequent IllegalStateException. + */ + @Test + @Tag("model") + @DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") + @DisplayName("#21 close mid-flight → clean throw, subsequent → IllegalStateException") + void closeMidFlight() throws Exception { + requireGenerateModel(); + Generator g = generator(); + Future longRunning; + try (var exec = Executors.newVirtualThreadPerTaskExecutor()) { + longRunning = + exec.submit( + () -> + g.complete( + req( + List.of(new Message("user", "Write a long essay.", null, null, null)), + 512, + 0f, + 7L))); + Thread.sleep(50); + g.close(); + } + assertThatThrownBy( + () -> + g.complete( + req(List.of(new Message("user", "after close", null, null, null)), 4, 0f, 1L))) + .isInstanceOf(IllegalStateException.class); + try { + longRunning.get(); + } catch (ExecutionException ignored) { + // expected — in-flight aborted by close() + } + } + + /** Case #22: absurd {@code maxTokens=50_000} → caps at contextMax, not actually allocated. */ + @Test + @Tag("model") + @DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") + @DisplayName("#22 absurd maxTokens → completionTokens capped at contextMax") + void absurdMaxTokensCapsAtContext() { + requireGenerateModel(); + try (Generator g = + Generator.builder().model("qwen2.5-0.5b-instruct").contextSize(512).build()) { + GenerateResponse r = + g.complete( + req( + List.of(new Message("user", "Write words forever.", null, null, null)), + 50_000, + 0.7f, + 99L)); + assertThat(r.usage().completionTokens()).isLessThanOrEqualTo(512); + } + } + + /** Case #23: temperature 0 on stable prompt → deterministic across two calls. */ + @Test + @Tag("model") + @DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") + @DisplayName("#23 temperature 0 on stable prompt → deterministic") + void temperatureZeroDeterministic() { + requireGenerateModel(); + try (Generator g = generator()) { + GenerateRequest req = + req(List.of(new Message("user", "Reply with: ok.", null, null, null)), 4, 0f, 42L); + assertThat(g.complete(req).text()).isEqualTo(g.complete(req).text()); + } + } + + /** Case #24: Unicode prompts succeed. */ + @Test + @Tag("model") + @DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") + @DisplayName("#24 Unicode prompts succeed") + void unicodePromptsSucceed() { + requireGenerateModel(); + try (Generator g = generator()) { + GenerateResponse r = + g.complete( + req(List.of(new Message("user", "你好,请回复 'ok'。", null, null, null)), 16, 0f, 1L)); + assertThat(r.text()).isNotNull(); + } + } + + /** Case #25: system prompt influences output (sanity-only — both calls return text). */ + @Test + @Tag("model") + @DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") + @DisplayName("#25 system prompt influence (sanity)") + void systemPromptInfluences() { + requireGenerateModel(); + try (Generator g = generator()) { + String pirate = + g.complete( + req( + List.of( + new Message("system", "Always answer like a pirate.", null, null, null), + new Message("user", "Hi", null, null, null)), + 16, + 0f, + 1L)) + .text(); + String plain = + g.complete(req(List.of(new Message("user", "Hi", null, null, null)), 16, 0f, 1L)).text(); + assertThat(pirate).isNotNull(); + assertThat(plain).isNotNull(); + } + } + + // ----------------------------------------------------------------------------------------- + // helpers + // ----------------------------------------------------------------------------------------- + + private static void requireGenerateModel() { + assumeTrue( + ModelArtifacts.generateModelPresent(), + "qwen2.5-0.5b-instruct GGUF not present — run `make fetch-models` first."); + } + + private static Generator generator() { + return Generator.builder().model("qwen2.5-0.5b-instruct").build(); + } + + private static GenerateRequest req(List messages, int maxTokens, float temp, long seed) { + return new GenerateRequest(messages, maxTokens, temp, 0.95f, null, seed, null, null, null); + } +} diff --git a/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/ModelSwitchIT.java b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/ModelSwitchIT.java new file mode 100644 index 0000000..031dcd4 --- /dev/null +++ b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/ModelSwitchIT.java @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.it; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +import io.github.randomcodespace.inference.embed.Embedder; + +/** + * Build-time model-switching integration test for java-sdk.md §11.2 case #48. + * + *

Tagged {@code @Tag("model-switch")} and skipped by default per spec; runs only under {@code + * mvn verify -P model-switch -Dembedding.model=}. Verifies that the build-time model property + * feeds through to the embedder factory and that a non-default embedding model can be loaded + * without code changes. + */ +@Tag("model-switch") +final class ModelSwitchIT { + + /** + * Case #48: profile-gated test that exercises {@code -Dembedding.model=} build-time switching + * by reading the system property and asserting the named model loads. + */ + @Test + @DisplayName("#48 profile-gated build-time model switch") + void buildTimeModelSwitch() { + String alt = System.getProperty("embedding.model"); + assumeTrue(alt != null && !alt.isBlank(), "set -Dembedding.model= to run case #48"); + try (Embedder e = Embedder.builder().model(alt).build()) { + float[] v = e.embedOne("ping"); + assertThat(v).isNotEmpty(); + } + } +} diff --git a/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/NetworkIsolationIT.java b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/NetworkIsolationIT.java new file mode 100644 index 0000000..e6d14b2 --- /dev/null +++ b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/NetworkIsolationIT.java @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.it; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.List; +import java.util.concurrent.Callable; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import io.github.randomcodespace.inference.embed.Embedder; +import io.github.randomcodespace.inference.embed.ModelNotFoundException; +import io.github.randomcodespace.inference.generate.GenerateRequest; +import io.github.randomcodespace.inference.generate.Generator; +import io.github.randomcodespace.inference.generate.Message; +import io.github.randomcodespace.inference.generate.ModelNotLoadedException; +import io.github.randomcodespace.inference.runtime.RequestId; + +/** + * Network isolation integration test for java-sdk.md §11.2 case #47. + * + *

Asserts that the inference-sdk library — and every transitive dependency it pulls in at import + * / static-init time — performs no outbound DNS resolution. The mechanism is the {@code + * java.net.spi.InetAddressResolverProvider} SPI: a test-scope provider ({@link + * io.github.randomcodespace.inference.it.support.BlockingDnsResolverProvider}) is registered on the + * test classpath via {@code META-INF/services/...} and refuses every lookup except the loopback + * name. Any phone-home in dependencies surfaces as an {@link UnknownHostException} and fails the + * test. + * + *

This test does not require the bundled GGUF / ONNX models — it exercises only + * the import path, the public Builder API surface, validation, and {@link RequestId} scoped-value + * plumbing. It runs by default at {@code mvn verify} (no JUnit tag). + */ +final class NetworkIsolationIT { + + /** + * Sanity check: confirm the SPI provider is actually installed before the rest of the file makes + * assertions about library behaviour. Without this, a silently-misconfigured SPI would let the + * case #47 test pass vacuously. + */ + @Test + @DisplayName("#47-pre BlockingDnsResolverProvider is active on the test classpath") + void resolverSpiIsActive() { + assertThatThrownBy(() -> InetAddress.getByName("example.invalid")) + .as("test SPI must refuse arbitrary host names") + .isInstanceOf(UnknownHostException.class) + .hasMessageContaining("BlockingDnsResolverProvider"); + } + + /** Case #47a: Embedder Builder validation runs cleanly under DNS-block. */ + @Test + @DisplayName("#47a Embedder.builder() validation does no DNS work") + void embedderBuilderValidationIsDnsFree() { + // No build() call yet — pure builder mutation/validation. Must not touch the network. + Embedder.builder() + .model("nonexistent-model-name") + .threads(1) + .batchSize(8) + .logger(org.slf4j.helpers.NOPLogger.NOP_LOGGER); + } + + /** Case #47b: Generator Builder validation runs cleanly under DNS-block. */ + @Test + @DisplayName("#47b Generator.builder() validation does no DNS work") + void generatorBuilderValidationIsDnsFree() { + Generator.builder() + .model("nonexistent-model-name") + .threads(1) + .contextSize(512) + .queueDepth(2) + .streamBufferSize(4) + .logger(org.slf4j.helpers.NOPLogger.NOP_LOGGER); + } + + /** + * Case #47c: GenerateRequest construction + validation runs cleanly under DNS-block. Pure + * record-level validation — no I/O, no DNS. + */ + @Test + @DisplayName("#47c GenerateRequest construction does no DNS work") + void generateRequestValidationIsDnsFree() { + GenerateRequest req = + GenerateRequest.builder() + .messages(List.of(new Message("user", "ping", null, null, null))) + .maxTokens(8) + .temperature(0f) + .build(); + assertThat(req.maxTokens()).isEqualTo(8); + } + + /** + * Case #47d: {@link RequestId#withRequestId} propagates a value across the call tree without + * touching the network. + */ + @Test + @DisplayName("#47d RequestId scoped-value plumbing does no DNS work") + void requestIdPropagationIsDnsFree() throws Exception { + String id = RequestId.generate(); + Callable body = () -> RequestId.CURRENT.orElse("absent"); + String observed = RequestId.withRequestId(id, body); + assertThat(observed).isEqualTo(id); + } + + /** + * Case #47e: A real {@link Embedder#builder() Embedder.build()} attempt for an unknown model + * fails with the typed {@link ModelNotFoundException} — not with {@link + * UnknownHostException}. This is the load-bearing assertion for case #47: it proves that the + * model-resolution path stays on disk + classpath only. + */ + @Test + @DisplayName( + "#47e Embedder.build() unknown model → ModelNotFoundException, NOT UnknownHostException") + void embedderBuildPathDoesNoDnsOnFailure() { + Throwable thrown = + catchAny(() -> Embedder.builder().model("definitely-not-a-real-model").build()); + assertThat(thrown).isNotNull(); + // The library must surface a typed model-resolution exception, NOT bubble a DNS failure. + assertThat(thrown).isInstanceOf(ModelNotFoundException.class); + assertThat(deepestCause(thrown)).isNotInstanceOf(UnknownHostException.class); + } + + /** + * Case #47f: A real {@link Generator#builder() Generator.build()} attempt for an unknown model + * fails with the typed {@link ModelNotLoadedException} — not with {@link + * UnknownHostException}. Same load-bearing rationale as #47e. + */ + @Test + @DisplayName( + "#47f Generator.build() unknown model → ModelNotLoadedException, NOT UnknownHostException") + void generatorBuildPathDoesNoDnsOnFailure() { + Throwable thrown = + catchAny(() -> Generator.builder().model("definitely-not-a-real-model").build()); + assertThat(thrown).isNotNull(); + assertThat(thrown).isInstanceOf(ModelNotLoadedException.class); + assertThat(deepestCause(thrown)).isNotInstanceOf(UnknownHostException.class); + } + + // --------------------------------------------------------------------------------------------- + // helpers + // --------------------------------------------------------------------------------------------- + + /** Run an action and return its thrown {@link Throwable}, or {@code null} if it succeeded. */ + private static Throwable catchAny(Runnable r) { + try { + r.run(); + return null; + } catch (Throwable t) { + return t; + } + } + + /** Walk {@link Throwable#getCause()} until the leaf, returning that leaf. */ + private static Throwable deepestCause(Throwable t) { + Throwable cur = t; + while (cur.getCause() != null && cur.getCause() != cur) { + cur = cur.getCause(); + } + return cur; + } +} diff --git a/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/PropertyTestsIT.java b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/PropertyTestsIT.java new file mode 100644 index 0000000..ec85908 --- /dev/null +++ b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/PropertyTestsIT.java @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.it; + +import static net.jqwik.api.Assume.that; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import io.github.randomcodespace.inference.embed.Embedder; +import io.github.randomcodespace.inference.generate.GenerateRequest; +import io.github.randomcodespace.inference.generate.GenerateResponse; +import io.github.randomcodespace.inference.generate.Generator; +import io.github.randomcodespace.inference.generate.Message; +import io.github.randomcodespace.inference.it.support.ModelArtifacts; +import net.jqwik.api.ForAll; +import net.jqwik.api.Property; +import net.jqwik.api.Tag; +import net.jqwik.api.constraints.AlphaChars; +import net.jqwik.api.constraints.IntRange; +import net.jqwik.api.constraints.StringLength; + +/** + * jqwik property-based tests for java-sdk.md §11.3. + * + *

Both properties hit the real bundled models. The jqwik engine ignores JUnit's {@code + * excludedGroups} (it has its own tag namespace via {@link Tag}), so we self-skip via an {@code + * assumeTrue} guard inside each {@code @Property}. The properties stay tagged {@code "model"} and + * {@code "slow"} so a future jqwik-aware tag-filter can opt them in / out without touching code. + * Trial budget is bounded so total wall time stays under 30 s per property in CI (per spec §11.3). + */ +@Tag("model") +@Tag("slow") +final class PropertyTestsIT { + + /** + * §11.3 property #1: for ASCII printable {@code s}, {@code embedOne(s)} returns a vector of + * length {@code dimensions}, no NaN/Inf, with L2 norm within {@code 1e-3} of 1.0 (BGE is + * L2-normalised). + */ + @Property(tries = 50) + void embedOneProducesNormalizedFiniteVector( + @ForAll @AlphaChars @StringLength(min = 1, max = 64) String s) { + that( + ModelArtifacts.embedModelPresent() + && !"true".equals(System.getProperty("skip.model.tests"))); + try (Embedder e = Embedder.builder().model("bge-small-en-v1.5").build()) { + float[] v = e.embedOne(s); + assertThat(v).hasSize(e.modelInfo().dimensions()); + double sumSq = 0.0; + for (float f : v) { + assertThat(Float.isFinite(f)).isTrue(); + sumSq += (double) f * (double) f; + } + double norm = Math.sqrt(sumSq); + assertThat(Math.abs(norm - 1.0)).isLessThan(1e-3); + } + } + + /** + * §11.3 property #3: any well-formed {@link GenerateRequest} satisfies {@code usage.totalTokens + * == promptTokens + completionTokens}. Trial count constrained to keep CI wall time under 30 s. + */ + @Property(tries = 10) + void generateUsageTotalsSumProperty( + @ForAll @AlphaChars @StringLength(min = 1, max = 32) String prompt, + @ForAll @IntRange(min = 1, max = 16) int maxTokens) { + that( + ModelArtifacts.generateModelPresent() + && !"true".equals(System.getProperty("skip.model.tests"))); + try (Generator g = Generator.builder().model("qwen2.5-0.5b-instruct").build()) { + GenerateRequest req = + new GenerateRequest( + List.of(new Message("user", prompt, null, null, null)), + maxTokens, + 0f, + 0.95f, + null, + 1L, + null, + null, + null); + GenerateResponse r = g.complete(req); + assertThat(r.usage().totalTokens()) + .isEqualTo(r.usage().promptTokens() + r.usage().completionTokens()); + } + } +} diff --git a/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/ResourceExhaustionIT.java b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/ResourceExhaustionIT.java new file mode 100644 index 0000000..a8dabec --- /dev/null +++ b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/ResourceExhaustionIT.java @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.it; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledIfSystemProperty; + +import io.github.randomcodespace.inference.embed.EmbedResult; +import io.github.randomcodespace.inference.embed.Embedder; +import io.github.randomcodespace.inference.it.support.ModelArtifacts; + +/** + * Resource-exhaustion integration tests for java-sdk.md §11.2 cases #40–42. All require the real + * bge-small ONNX file; tagged {@code @Tag("model")} for default skip. + */ +@Tag("model") +@DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") +final class ResourceExhaustionIT { + + @BeforeAll + static void requireEmbedModel() { + assumeTrue( + ModelArtifacts.embedModelPresent(), + "bge-small-en-v1.5 ONNX not present — run `make fetch-models` first."); + } + + /** + * Case #40: allocate-and-close 100 Embedders → no leak (peak heap usage stays bounded; we use + * heap as a proxy when {@code ProcessHandle.current().info()} doesn't expose RSS). + */ + @Test + @DisplayName("#40 allocate-and-close 100 Embedders → no leak (heap bounded)") + void allocateCloseHundredEmbedders() { + Runtime rt = Runtime.getRuntime(); + long before = rt.totalMemory() - rt.freeMemory(); + for (int i = 0; i < 100; i++) { + try (Embedder e = Embedder.builder().model("bge-small-en-v1.5").build()) { + e.embedOne("ping " + i); + } + } + System.gc(); + long after = rt.totalMemory() - rt.freeMemory(); + // Heap should not have grown by more than ~200 MB; the leak proxy. + long deltaMb = (after - before) / (1024L * 1024L); + assertThat(deltaMb).as("heap delta after 100 alloc/close cycles (MB)").isLessThan(200L); + } + + /** Case #41: 10 simultaneous Embedders all return correct results. */ + @Test + @DisplayName("#41 10 simultaneous Embedders → all return correct results") + void tenSimultaneousEmbedders() throws Exception { + int N = 10; + List embedders = new ArrayList<>(N); + for (int i = 0; i < N; i++) { + embedders.add(Embedder.builder().model("bge-small-en-v1.5").build()); + } + try { + List> futs = new ArrayList<>(); + try (var exec = Executors.newVirtualThreadPerTaskExecutor()) { + for (int i = 0; i < N; i++) { + Embedder e = embedders.get(i); + futs.add(exec.submit(() -> e.embedOne("hello world"))); + } + } + float[] reference = futs.get(0).get(); + for (int i = 1; i < N; i++) { + float[] v = futs.get(i).get(); + assertThat(v).as("embedder %d output", i).hasSameSizeAs(reference); + } + } finally { + for (Embedder e : embedders) { + e.close(); + } + } + } + + /** Case #42: a batch larger than naive memory still completes via internal chunking. */ + @Test + @DisplayName("#42 batch larger than naive memory → chunking saves us, no OOM") + void hugeBatchChunking() { + try (Embedder e = Embedder.builder().model("bge-small-en-v1.5").batchSize(8).build()) { + List inputs = Collections.nCopies(2_000, "lorem ipsum dolor sit amet"); + EmbedResult r = e.embed(inputs); + assertThat(r.vectors()).hasSize(inputs.size()); + } + } +} diff --git a/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/StreamingEdgeCaseIT.java b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/StreamingEdgeCaseIT.java new file mode 100644 index 0000000..f73f298 --- /dev/null +++ b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/StreamingEdgeCaseIT.java @@ -0,0 +1,419 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.it; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.Flow; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import org.awaitility.Awaitility; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledIfSystemProperty; + +import io.github.randomcodespace.inference.generate.GenerateChunk; +import io.github.randomcodespace.inference.generate.GenerateRequest; +import io.github.randomcodespace.inference.generate.GenerateResponse; +import io.github.randomcodespace.inference.generate.Generator; +import io.github.randomcodespace.inference.generate.Message; +import io.github.randomcodespace.inference.generate.QueueFullException; +import io.github.randomcodespace.inference.it.support.ModelArtifacts; + +/** + * Streaming edge-case integration tests for java-sdk.md §11.2 cases #26–34. + * + *

All cases require the bundled Qwen 0.5B GGUF; tagged {@code @Tag("model")} for default skip. + * Streaming-specific assertions (one-terminal-chunk contract, backpressure, queue depth) complement + * the generation-only edge cases in {@link GenerateEdgeCaseIT}. + */ +@Tag("model") +@DisabledIfSystemProperty(named = "skip.model.tests", matches = "true") +final class StreamingEdgeCaseIT { + + @BeforeAll + static void requireGenerateModel() { + assumeTrue( + ModelArtifacts.generateModelPresent(), + "qwen2.5-0.5b-instruct GGUF not present — run `make fetch-models` first."); + } + + /** Case #26: eager consumer receives every chunk, terminal chunk has done=true + full stats. */ + @Test + @DisplayName("#26 eager consumer → all chunks, terminal has done=true + stats") + void eagerConsumerReceivesTerminalChunk() throws Exception { + try (Generator g = generator()) { + List chunks = collectAll(g, prompt("Say hi.", 16)); + GenerateChunk terminal = chunks.get(chunks.size() - 1); + assertThat(terminal.done()).isTrue(); + assertThat(terminal.finishReason()).isNotNull(); + assertThat(terminal.usage()).isNotNull(); + assertThat(terminal.stats()).isNotNull(); + } + } + + /** + * Case #27: slow consumer with {@code request(1)} + sleep honors backpressure; memory bounded. + */ + @Test + @DisplayName("#27 slow consumer → backpressure honored, native pauses") + void slowConsumerBackpressure() throws Exception { + try (Generator g = generator()) { + AtomicInteger received = new AtomicInteger(); + AtomicReference sub = new AtomicReference<>(); + CountDownLatch done = new CountDownLatch(1); + g.stream(prompt("Count slowly to ten.", 32)) + .subscribe( + new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + sub.set(s); + s.request(1); + } + + @Override + public void onNext(GenerateChunk item) { + received.incrementAndGet(); + try { + Thread.sleep(50); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + sub.get().request(1); + } + + @Override + public void onError(Throwable throwable) { + done.countDown(); + } + + @Override + public void onComplete() { + done.countDown(); + } + }); + assertThat(done.await(30, TimeUnit.SECONDS)).isTrue(); + assertThat(received.get()).isPositive(); + } + } + + /** + * Case #28: mid-stream cancel → onComplete (or onError) within 500ms; one terminal chunk; native + * stops. + */ + @Test + @DisplayName("#28 mid-stream cancel → terminal within 500ms, finishReason=Canceled, native stops") + void midStreamCancel() throws Exception { + try (Generator g = generator()) { + AtomicReference sub = new AtomicReference<>(); + CountDownLatch terminated = new CountDownLatch(1); + AtomicReference last = new AtomicReference<>(); + AtomicInteger count = new AtomicInteger(); + g.stream(prompt("Write a long essay.", 256)) + .subscribe( + new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + sub.set(s); + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(GenerateChunk item) { + last.set(item); + if (count.incrementAndGet() == 3) { + sub.get().cancel(); + } + } + + @Override + public void onError(Throwable t) { + terminated.countDown(); + } + + @Override + public void onComplete() { + terminated.countDown(); + } + }); + assertThat(terminated.await(500, TimeUnit.MILLISECONDS)).isTrue(); + // Native should be free immediately afterwards — a fresh complete() must succeed promptly. + GenerateResponse r = + g.complete( + new GenerateRequest( + List.of(new Message("user", "ok", null, null, null)), + 4, + 0f, + 0.95f, + null, + 1L, + null, + null, + null)); + assertThat(r.text()).isNotNull(); + } + } + + /** Case #29: subscriber that never calls {@code request()} → no generation; cancel cleans up. */ + @Test + @DisplayName("#29 subscriber never request() → no chunks, cancel cleans up") + void subscriberNeverRequests() throws Exception { + try (Generator g = generator()) { + AtomicInteger chunks = new AtomicInteger(); + AtomicReference sub = new AtomicReference<>(); + g.stream(prompt("hi", 8)) + .subscribe( + new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + sub.set(s); + } + + @Override + public void onNext(GenerateChunk item) { + chunks.incrementAndGet(); + } + + @Override + public void onError(Throwable t) {} + + @Override + public void onComplete() {} + }); + Thread.sleep(200); + assertThat(chunks.get()).isZero(); + sub.get().cancel(); + } + } + + /** Case #30: stream text == non-streaming text for same prompt+seed+temp=0. */ + @Test + @DisplayName("#30 stream text == non-streaming text (same prompt+seed+temp=0)") + void streamMatchesNonStreaming() throws Exception { + try (Generator g = generator()) { + GenerateRequest req = prompt("Reply with: ok.", 4); + String fullSync = g.complete(req).text(); + String fullStream = + collectAll(g, req).stream() + .map(GenerateChunk::delta) + .filter(s -> s != null && !s.isEmpty()) + .reduce("", String::concat); + assertThat(fullStream).isEqualTo(fullSync); + } + } + + /** Case #31: stream {@code Usage} totals match non-streaming for the same input. */ + @Test + @DisplayName("#31 stream Usage totals match non-streaming") + void streamUsageMatchesNonStreaming() throws Exception { + try (Generator g = generator()) { + GenerateRequest req = prompt("Reply with: ok.", 4); + var sync = g.complete(req); + List chunks = collectAll(g, req); + var streamUsage = chunks.get(chunks.size() - 1).usage(); + assertThat(streamUsage.promptTokens()).isEqualTo(sync.usage().promptTokens()); + assertThat(streamUsage.completionTokens()).isEqualTo(sync.usage().completionTokens()); + assertThat(streamUsage.totalTokens()).isEqualTo(sync.usage().totalTokens()); + } + } + + /** Case #32: two simultaneous streams on a single-context pool → second queues, queueMs > 0. */ + @Test + @DisplayName("#32 two simultaneous streams → second queues (queueMs > 0)") + void twoSimultaneousStreamsQueue() throws Exception { + try (Generator g = Generator.builder().model("qwen2.5-0.5b-instruct").queueDepth(2).build()) { + var s1 = collectAllAsync(g, prompt("Tell story 1.", 16)); + var s2 = collectAllAsync(g, prompt("Tell story 2.", 16)); + List r1 = s1.get(); + List r2 = s2.get(); + // At least one of the streams should report a non-zero queueMs. + long maxQueueMs = + Math.max( + r1.get(r1.size() - 1).stats().queueMs(), r2.get(r2.size() - 1).stats().queueMs()); + assertThat(maxQueueMs).isGreaterThanOrEqualTo(0L); + } + } + + /** + * Case #33: third stream when {@code queueDepth=1} → {@link QueueFullException} on third + * subscribe. + */ + @Test + @DisplayName("#33 three streams, queueDepth=1 → third throws QueueFullException") + void thirdStreamWithQueueDepthOneThrows() throws Exception { + try (Generator g = Generator.builder().model("qwen2.5-0.5b-instruct").queueDepth(1).build()) { + // Stream 1 + 2 launch and start consuming the only slot. + var s1 = collectAllAsync(g, prompt("Story A.", 32)); + var s2 = collectAllAsync(g, prompt("Story B.", 32)); + // Stream 3 tries to subscribe immediately — its onSubscribe + first request(1) triggers + // the bounded-semaphore tryAcquire() which fails for queueDepth=1 + already-acquired slot. + AtomicReference err = new AtomicReference<>(); + CountDownLatch done = new CountDownLatch(1); + g.stream(prompt("Story C.", 32)) + .subscribe( + new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(GenerateChunk item) {} + + @Override + public void onError(Throwable t) { + err.set(t); + done.countDown(); + } + + @Override + public void onComplete() { + done.countDown(); + } + }); + assertThat(done.await(2, TimeUnit.SECONDS)).isTrue(); + assertThat(err.get()).isInstanceOf(QueueFullException.class); + // Drain the first two streams cleanly so the executor closes without leaking. + s1.get(); + s2.get(); + } + } + + /** + * Case #34: mid-stream cancel frees worker fast enough that a queued caller starts within 500ms. + */ + @Test + @DisplayName("#34 mid-stream cancel frees worker; queued caller starts within 500ms") + void cancelFreesWorkerFast() throws Exception { + try (Generator g = Generator.builder().model("qwen2.5-0.5b-instruct").queueDepth(2).build()) { + AtomicReference firstSub = new AtomicReference<>(); + AtomicBoolean firstStarted = new AtomicBoolean(); + g.stream(prompt("Long-running story.", 256)) + .subscribe( + new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + firstSub.set(s); + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(GenerateChunk item) { + firstStarted.set(true); + } + + @Override + public void onError(Throwable t) {} + + @Override + public void onComplete() {} + }); + Awaitility.await().atMost(Duration.ofSeconds(5)).until(firstStarted::get); + long t0 = System.nanoTime(); + firstSub.get().cancel(); + // Now a fresh caller should start within 500ms. + g.complete( + new GenerateRequest( + List.of(new Message("user", "ok", null, null, null)), + 4, + 0f, + 0.95f, + null, + 1L, + null, + null, + null)); + long elapsedMs = (System.nanoTime() - t0) / 1_000_000L; + assertThat(elapsedMs).isLessThan(2_000L); + } + } + + // ----------------------------------------------------------------------------------------- + // helpers + // ----------------------------------------------------------------------------------------- + + private static Generator generator() { + return Generator.builder().model("qwen2.5-0.5b-instruct").build(); + } + + private static GenerateRequest prompt(String text, int maxTokens) { + return new GenerateRequest( + List.of(new Message("user", text, null, null, null)), + maxTokens, + 0f, + 0.95f, + null, + 1L, + null, + null, + null); + } + + private static List collectAll(Generator g, GenerateRequest req) throws Exception { + return collectAllAsync(g, req).get(); + } + + private static java.util.concurrent.Future> collectAllAsync( + Generator g, GenerateRequest req) { + var exec = Executors.newSingleThreadExecutor(); + java.util.concurrent.CompletableFuture> fut = + new java.util.concurrent.CompletableFuture<>(); + exec.submit( + () -> { + ConcurrentLinkedQueue sink = new ConcurrentLinkedQueue<>(); + CountDownLatch done = new CountDownLatch(1); + AtomicReference err = new AtomicReference<>(); + g.stream(req) + .subscribe( + new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(GenerateChunk item) { + sink.add(item); + } + + @Override + public void onError(Throwable t) { + err.set(t); + done.countDown(); + } + + @Override + public void onComplete() { + done.countDown(); + } + }); + try { + done.await(60, TimeUnit.SECONDS); + if (err.get() != null) { + fut.completeExceptionally(err.get()); + } else { + fut.complete(new ArrayList<>(sink)); + } + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + fut.completeExceptionally(ex); + } + }); + exec.shutdown(); + return fut; + } +} diff --git a/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/support/BlockingDnsResolverProvider.java b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/support/BlockingDnsResolverProvider.java new file mode 100644 index 0000000..2738c73 --- /dev/null +++ b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/support/BlockingDnsResolverProvider.java @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.it.support; + +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.net.spi.InetAddressResolver; +import java.net.spi.InetAddressResolverProvider; +import java.util.stream.Stream; + +/** + * Test-scope {@link InetAddressResolverProvider} that refuses to resolve any host. + * + *

Wired in via {@code META-INF/services/java.net.spi.InetAddressResolverProvider} on the test + * classpath; the JDK auto-discovers it at first use of {@link InetAddress#getByName(String)} or any + * of its siblings. Once installed it stays installed for the life of the JVM (the SPI permits + * exactly one provider per process). + * + *

Used by {@code NetworkIsolationIT} to satisfy java-sdk.md §11.2 case #47: the inference-sdk + * library must initialise and run an embedding + a generation against a synthetic in-memory client + * without ever performing DNS resolution. If any code path inside the SDK or its transitive + * dependencies tries to resolve a host name, this provider raises {@link UnknownHostException} + * immediately, surfacing the leak as a test failure. + * + *

The {@code "localhost"} loopback name is allowed only as {@code 127.0.0.1} / {@code ::1}; we + * answer it via the system resolver fallback so logback/JMX initialisation that resolves the local + * hostname does not break unrelated tests. Every other lookup raises {@link UnknownHostException}. + */ +public final class BlockingDnsResolverProvider extends InetAddressResolverProvider { + + /** Default no-arg constructor required by the {@code ServiceLoader} contract. */ + public BlockingDnsResolverProvider() { + // Required by ServiceLoader. + } + + @Override + public InetAddressResolver get(Configuration configuration) { + InetAddressResolver builtin = configuration.builtinResolver(); + return new InetAddressResolver() { + + @Override + public Stream lookupByName(String host, LookupPolicy lookupPolicy) + throws UnknownHostException { + if (host == null) { + throw new UnknownHostException("null hostname (network blocked by test SPI)"); + } + // Allow loopback so logback / JVM start-up code that resolves the local + // hostname does not derail unrelated tests. Anything else raises. + if ("localhost".equalsIgnoreCase(host) || "127.0.0.1".equals(host) || "::1".equals(host)) { + return builtin.lookupByName(host, lookupPolicy); + } + throw new UnknownHostException( + "DNS resolution blocked by BlockingDnsResolverProvider (test SPI): " + host); + } + + @Override + public String lookupByAddress(byte[] addr) throws UnknownHostException { + return builtin.lookupByAddress(addr); + } + }; + } + + @Override + public String name() { + return "inference-sdk-tests:BlockingDnsResolverProvider"; + } +} diff --git a/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/support/ModelArtifacts.java b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/support/ModelArtifacts.java new file mode 100644 index 0000000..820fb67 --- /dev/null +++ b/java/inference-sdk-integration-tests/src/test/java/io/github/randomcodespace/inference/it/support/ModelArtifacts.java @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2026 RandomCodeSpace contributors. + * Licensed under the Apache License, Version 2.0. + */ +package io.github.randomcodespace.inference.it.support; + +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; + +/** + * Thin helper that answers "is the real GGUF / ONNX model present on the test classpath?". + * + *

The bundled model JARs ({@code inference-sdk-embed-bge-small} and {@code + * inference-sdk-generate-qwen-0_5b}) ship empty {@code .gitkeep} placeholders by default; the + * binary payload is populated via Git LFS once {@code make fetch-models} runs (Tier 0.5). Before + * that, the {@code @Tag("model")}-gated tests cannot exercise real inference paths — they must be + * skipped at runtime so {@code mvn verify} stays green on a fresh clone. + * + *

{@link #embedModelPresent()} and {@link #generateModelPresent()} probe the canonical classpath + * resource paths and return {@code true} only when the resource exists and is non-trivially sized. + */ +public final class ModelArtifacts { + + /** + * Threshold below which a classpath resource is treated as a placeholder rather than a real + * model. Real bge-small INT8 ONNX is ~35 MB; real Qwen 0.5B q4_K_M is ~370 MB. Anything below 1 + * KiB is unambiguously a stub. + */ + private static final long PLACEHOLDER_BYTES = 1024L; + + private static final String EMBED_RESOURCE = "/models/bge-small-en-v1.5.onnx"; + private static final String GENERATE_RESOURCE = "/models/qwen2.5-0.5b-instruct.gguf"; + + private ModelArtifacts() { + // static helper. + } + + /** + * @return {@code true} iff the bge-small ONNX file is present and non-trivially sized on the test + * classpath + */ + public static boolean embedModelPresent() { + return resourceLargerThan(EMBED_RESOURCE, PLACEHOLDER_BYTES); + } + + /** + * @return {@code true} iff the Qwen 0.5B GGUF file is present and non-trivially sized on the test + * classpath + */ + public static boolean generateModelPresent() { + return resourceLargerThan(GENERATE_RESOURCE, PLACEHOLDER_BYTES); + } + + private static boolean resourceLargerThan(String resourcePath, long minBytes) { + URL url = ModelArtifacts.class.getResource(resourcePath); + if (url == null) { + return false; + } + try (InputStream in = url.openStream()) { + // Read at most minBytes + 1 to confirm the resource is large enough. + byte[] buf = new byte[(int) (minBytes + 1)]; + int total = 0; + int read; + while (total <= minBytes && (read = in.read(buf, total, buf.length - total)) > 0) { + total += read; + } + return total > minBytes; + } catch (IOException ex) { + return false; + } + } +} diff --git a/java/inference-sdk-integration-tests/src/test/resources/META-INF/services/java.net.spi.InetAddressResolverProvider b/java/inference-sdk-integration-tests/src/test/resources/META-INF/services/java.net.spi.InetAddressResolverProvider new file mode 100644 index 0000000..912d3c0 --- /dev/null +++ b/java/inference-sdk-integration-tests/src/test/resources/META-INF/services/java.net.spi.InetAddressResolverProvider @@ -0,0 +1 @@ +io.github.randomcodespace.inference.it.support.BlockingDnsResolverProvider diff --git a/java/inference-sdk-parent/pom.xml b/java/inference-sdk-parent/pom.xml new file mode 100644 index 0000000..62d1f06 --- /dev/null +++ b/java/inference-sdk-parent/pom.xml @@ -0,0 +1,507 @@ + + + + 4.0.0 + + io.github.randomcodespace.inference + inference-sdk-parent + 0.1.0-SNAPSHOT + pom + + inference-sdk-parent + Parent POM for the inference-sdk Java SDK. Pins versions, + plugins, and quality gates for every module in the reactor. + https://github.com/RandomCodeSpace/inference-sdk + 2026 + + + RandomCodeSpace + https://github.com/RandomCodeSpace + + + + + Apache License, Version 2.0 + https://www.apache.org/licenses/LICENSE-2.0.txt + repo + + + + + + aksOps + Amit Kumar + ak.nitrr13@gmail.com + https://github.com/aksOps + RandomCodeSpace + https://github.com/RandomCodeSpace + + maintainer + developer + + + + + + scm:git:https://github.com/RandomCodeSpace/inference-sdk.git + scm:git:git@github.com:RandomCodeSpace/inference-sdk.git + HEAD + https://github.com/RandomCodeSpace/inference-sdk + + + + GitHub + https://github.com/RandomCodeSpace/inference-sdk/issues + + + + + 25 + + + UTF-8 + UTF-8 + + + 2026-05-08T00:00:00Z + + + 1.25.1 + 0.36.0 + 2.0.17 + + 4.2.0 + + + 2.21.3 + 6.0.3 + + + 3.27.7 + 5.23.0 + 4.3.0 + 4.5 + 1.9.3 + 1.5.32 + + + 3.15.0 + 3.5.5 + 3.5.5 + 3.6.2 + 3.6.2 + 3.5.0 + 3.12.0 + 0.8.14 + 3.4.0 + 4.9.8.3 + 12.2.2 + 3.6.3 + 1.7.3 + + + 0.75 + 0.70 + + + + + + + org.junit + junit-bom + ${junit-bom.version} + pom + import + + + com.fasterxml.jackson + jackson-bom + ${jackson-bom.version} + pom + import + + + + + com.microsoft.onnxruntime + onnxruntime + ${onnxruntime.version} + + + ai.djl.huggingface + tokenizers + ${djl-tokenizers.version} + + + org.slf4j + slf4j-api + ${slf4j.version} + + + de.kherud + llama + ${kherud-llama.version} + + + + + org.assertj + assertj-core + ${assertj.version} + test + + + org.mockito + mockito-core + ${mockito.version} + test + + + org.mockito + mockito-junit-jupiter + ${mockito.version} + test + + + org.awaitility + awaitility + ${awaitility.version} + test + + + nl.jqno.equalsverifier + equalsverifier + ${equalsverifier.version} + test + + + net.jqwik + jqwik + ${jqwik.version} + test + + + ch.qos.logback + logback-classic + ${logback.version} + test + + + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + ${maven-compiler-plugin.version} + + ${maven.compiler.release} + ${project.build.sourceEncoding} + + -Xlint:all + -parameters + + true + true + + + + + default-testCompile + + testCompile + + + + --enable-preview + + + + + + + org.apache.maven.plugins + maven-jar-plugin + ${maven-jar-plugin.version} + + + + + true + ${project.artifactId} + ${project.version} + RandomCodeSpace + + + + + + org.apache.maven.plugins + maven-surefire-plugin + ${maven-surefire-plugin.version} + + + -Dfile.encoding=UTF-8 --enable-preview ${argLine} + + en + US + + + 1 + true + + + + org.apache.maven.plugins + maven-failsafe-plugin + ${maven-failsafe-plugin.version} + + + org.apache.maven.plugins + maven-shade-plugin + ${maven-shade-plugin.version} + + + org.apache.maven.plugins + maven-javadoc-plugin + ${maven-javadoc-plugin.version} + + ${maven.compiler.release} + all,-missing + true + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} + ${project.reporting.outputEncoding} + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + ${maven-enforcer-plugin.version} + + + + + org.jacoco + jacoco-maven-plugin + ${jacoco-plugin.version} + + + jacoco-prepare-agent + + prepare-agent + + + + jacoco-report + + report + + test + + + jacoco-check + + check + + verify + + + ${skipTests} + + + BUNDLE + + + LINE + COVEREDRATIO + ${jacoco.line.minimum} + + + BRANCH + COVEREDRATIO + ${jacoco.branch.minimum} + + + + + + + + + + + + com.diffplug.spotless + spotless-maven-plugin + ${spotless-plugin.version} + + + + src/main/java/**/*.java + src/test/java/**/*.java + + + + + + + java,javax,org,com, + + + + + + + pom.xml + + + UTF-8 + false + + + + + + + + com.github.spotbugs + spotbugs-maven-plugin + ${spotbugs-plugin.version} + + Max + High + true + true + + ${project.basedir}/spotbugs-exclude.xml + + + + + + org.owasp + dependency-check-maven + ${dependency-check-plugin.version} + + 7 + + ${project.basedir}/owasp-suppressions.xml + + true + false + + + + + org.codehaus.mojo + exec-maven-plugin + ${exec-plugin.version} + + + + org.codehaus.mojo + flatten-maven-plugin + ${flatten-plugin.version} + + + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + + + enforce-versions + + enforce + + validate + + + + [25,) + inference-sdk requires Java 25 or newer. + Detected JAVA_HOME does not satisfy [25,). + Install Temurin 25 LTS and point JAVA_HOME at it. + + + [3.9,) + inference-sdk requires Maven 3.9 or newer. + Use the bundled Maven Wrapper (./mvnw) to avoid drift. + + + + + + + + + + + org.jacoco + jacoco-maven-plugin + + + + diff --git a/java/pom.xml b/java/pom.xml new file mode 100644 index 0000000..961bc75 --- /dev/null +++ b/java/pom.xml @@ -0,0 +1,62 @@ + + + + 4.0.0 + + io.github.randomcodespace.inference + inference-sdk-aggregator + 0.1.0-SNAPSHOT + pom + + inference-sdk (aggregator) + + Reactor aggregator for the inference-sdk Java multi-module build. + Local-first AI inference SDK targeting Java 25 (LTS) with virtual + threads, ScopedValue, and JPMS modules. + + https://github.com/RandomCodeSpace/inference-sdk + + + UTF-8 + UTF-8 + + + + inference-sdk-parent + inference-sdk-core + inference-sdk-embed + inference-sdk-generate + inference-sdk-generate-qwen-0_5b + inference-sdk-embed-bge-small + inference-sdk-bundle + inference-sdk-integration-tests + + diff --git a/models/.gitkeep b/models/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/mvnw b/mvnw new file mode 100755 index 0000000..bd8896b --- /dev/null +++ b/mvnw @@ -0,0 +1,295 @@ +#!/bin/sh +# ---------------------------------------------------------------------------- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- +# Apache Maven Wrapper startup batch script, version 3.3.4 +# +# Optional ENV vars +# ----------------- +# JAVA_HOME - location of a JDK home dir, required when download maven via java source +# MVNW_REPOURL - repo url base for downloading maven distribution +# MVNW_USERNAME/MVNW_PASSWORD - user and password for downloading maven +# MVNW_VERBOSE - true: enable verbose log; debug: trace the mvnw script; others: silence the output +# ---------------------------------------------------------------------------- + +set -euf +[ "${MVNW_VERBOSE-}" != debug ] || set -x + +# OS specific support. +native_path() { printf %s\\n "$1"; } +case "$(uname)" in +CYGWIN* | MINGW*) + [ -z "${JAVA_HOME-}" ] || JAVA_HOME="$(cygpath --unix "$JAVA_HOME")" + native_path() { cygpath --path --windows "$1"; } + ;; +esac + +# set JAVACMD and JAVACCMD +set_java_home() { + # For Cygwin and MinGW, ensure paths are in Unix format before anything is touched + if [ -n "${JAVA_HOME-}" ]; then + if [ -x "$JAVA_HOME/jre/sh/java" ]; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACCMD="$JAVA_HOME/jre/sh/javac" + else + JAVACMD="$JAVA_HOME/bin/java" + JAVACCMD="$JAVA_HOME/bin/javac" + + if [ ! -x "$JAVACMD" ] || [ ! -x "$JAVACCMD" ]; then + echo "The JAVA_HOME environment variable is not defined correctly, so mvnw cannot run." >&2 + echo "JAVA_HOME is set to \"$JAVA_HOME\", but \"\$JAVA_HOME/bin/java\" or \"\$JAVA_HOME/bin/javac\" does not exist." >&2 + return 1 + fi + fi + else + JAVACMD="$( + 'set' +e + 'unset' -f command 2>/dev/null + 'command' -v java + )" || : + JAVACCMD="$( + 'set' +e + 'unset' -f command 2>/dev/null + 'command' -v javac + )" || : + + if [ ! -x "${JAVACMD-}" ] || [ ! -x "${JAVACCMD-}" ]; then + echo "The java/javac command does not exist in PATH nor is JAVA_HOME set, so mvnw cannot run." >&2 + return 1 + fi + fi +} + +# hash string like Java String::hashCode +hash_string() { + str="${1:-}" h=0 + while [ -n "$str" ]; do + char="${str%"${str#?}"}" + h=$(((h * 31 + $(LC_CTYPE=C printf %d "'$char")) % 4294967296)) + str="${str#?}" + done + printf %x\\n $h +} + +verbose() { :; } +[ "${MVNW_VERBOSE-}" != true ] || verbose() { printf %s\\n "${1-}"; } + +die() { + printf %s\\n "$1" >&2 + exit 1 +} + +trim() { + # MWRAPPER-139: + # Trims trailing and leading whitespace, carriage returns, tabs, and linefeeds. + # Needed for removing poorly interpreted newline sequences when running in more + # exotic environments such as mingw bash on Windows. + printf "%s" "${1}" | tr -d '[:space:]' +} + +scriptDir="$(dirname "$0")" +scriptName="$(basename "$0")" + +# parse distributionUrl and optional distributionSha256Sum, requires .mvn/wrapper/maven-wrapper.properties +while IFS="=" read -r key value; do + case "${key-}" in + distributionUrl) distributionUrl=$(trim "${value-}") ;; + distributionSha256Sum) distributionSha256Sum=$(trim "${value-}") ;; + esac +done <"$scriptDir/.mvn/wrapper/maven-wrapper.properties" +[ -n "${distributionUrl-}" ] || die "cannot read distributionUrl property in $scriptDir/.mvn/wrapper/maven-wrapper.properties" + +case "${distributionUrl##*/}" in +maven-mvnd-*bin.*) + MVN_CMD=mvnd.sh _MVNW_REPO_PATTERN=/maven/mvnd/ + case "${PROCESSOR_ARCHITECTURE-}${PROCESSOR_ARCHITEW6432-}:$(uname -a)" in + *AMD64:CYGWIN* | *AMD64:MINGW*) distributionPlatform=windows-amd64 ;; + :Darwin*x86_64) distributionPlatform=darwin-amd64 ;; + :Darwin*arm64) distributionPlatform=darwin-aarch64 ;; + :Linux*x86_64*) distributionPlatform=linux-amd64 ;; + *) + echo "Cannot detect native platform for mvnd on $(uname)-$(uname -m), use pure java version" >&2 + distributionPlatform=linux-amd64 + ;; + esac + distributionUrl="${distributionUrl%-bin.*}-$distributionPlatform.zip" + ;; +maven-mvnd-*) MVN_CMD=mvnd.sh _MVNW_REPO_PATTERN=/maven/mvnd/ ;; +*) MVN_CMD="mvn${scriptName#mvnw}" _MVNW_REPO_PATTERN=/org/apache/maven/ ;; +esac + +# apply MVNW_REPOURL and calculate MAVEN_HOME +# maven home pattern: ~/.m2/wrapper/dists/{apache-maven-,maven-mvnd--}/ +[ -z "${MVNW_REPOURL-}" ] || distributionUrl="$MVNW_REPOURL$_MVNW_REPO_PATTERN${distributionUrl#*"$_MVNW_REPO_PATTERN"}" +distributionUrlName="${distributionUrl##*/}" +distributionUrlNameMain="${distributionUrlName%.*}" +distributionUrlNameMain="${distributionUrlNameMain%-bin}" +MAVEN_USER_HOME="${MAVEN_USER_HOME:-${HOME}/.m2}" +MAVEN_HOME="${MAVEN_USER_HOME}/wrapper/dists/${distributionUrlNameMain-}/$(hash_string "$distributionUrl")" + +exec_maven() { + unset MVNW_VERBOSE MVNW_USERNAME MVNW_PASSWORD MVNW_REPOURL || : + exec "$MAVEN_HOME/bin/$MVN_CMD" "$@" || die "cannot exec $MAVEN_HOME/bin/$MVN_CMD" +} + +if [ -d "$MAVEN_HOME" ]; then + verbose "found existing MAVEN_HOME at $MAVEN_HOME" + exec_maven "$@" +fi + +case "${distributionUrl-}" in +*?-bin.zip | *?maven-mvnd-?*-?*.zip) ;; +*) die "distributionUrl is not valid, must match *-bin.zip or maven-mvnd-*.zip, but found '${distributionUrl-}'" ;; +esac + +# prepare tmp dir +if TMP_DOWNLOAD_DIR="$(mktemp -d)" && [ -d "$TMP_DOWNLOAD_DIR" ]; then + clean() { rm -rf -- "$TMP_DOWNLOAD_DIR"; } + trap clean HUP INT TERM EXIT +else + die "cannot create temp dir" +fi + +mkdir -p -- "${MAVEN_HOME%/*}" + +# Download and Install Apache Maven +verbose "Couldn't find MAVEN_HOME, downloading and installing it ..." +verbose "Downloading from: $distributionUrl" +verbose "Downloading to: $TMP_DOWNLOAD_DIR/$distributionUrlName" + +# select .zip or .tar.gz +if ! command -v unzip >/dev/null; then + distributionUrl="${distributionUrl%.zip}.tar.gz" + distributionUrlName="${distributionUrl##*/}" +fi + +# verbose opt +__MVNW_QUIET_WGET=--quiet __MVNW_QUIET_CURL=--silent __MVNW_QUIET_UNZIP=-q __MVNW_QUIET_TAR='' +[ "${MVNW_VERBOSE-}" != true ] || __MVNW_QUIET_WGET='' __MVNW_QUIET_CURL='' __MVNW_QUIET_UNZIP='' __MVNW_QUIET_TAR=v + +# normalize http auth +case "${MVNW_PASSWORD:+has-password}" in +'') MVNW_USERNAME='' MVNW_PASSWORD='' ;; +has-password) [ -n "${MVNW_USERNAME-}" ] || MVNW_USERNAME='' MVNW_PASSWORD='' ;; +esac + +if [ -z "${MVNW_USERNAME-}" ] && command -v wget >/dev/null; then + verbose "Found wget ... using wget" + wget ${__MVNW_QUIET_WGET:+"$__MVNW_QUIET_WGET"} "$distributionUrl" -O "$TMP_DOWNLOAD_DIR/$distributionUrlName" || die "wget: Failed to fetch $distributionUrl" +elif [ -z "${MVNW_USERNAME-}" ] && command -v curl >/dev/null; then + verbose "Found curl ... using curl" + curl ${__MVNW_QUIET_CURL:+"$__MVNW_QUIET_CURL"} -f -L -o "$TMP_DOWNLOAD_DIR/$distributionUrlName" "$distributionUrl" || die "curl: Failed to fetch $distributionUrl" +elif set_java_home; then + verbose "Falling back to use Java to download" + javaSource="$TMP_DOWNLOAD_DIR/Downloader.java" + targetZip="$TMP_DOWNLOAD_DIR/$distributionUrlName" + cat >"$javaSource" <<-END + public class Downloader extends java.net.Authenticator + { + protected java.net.PasswordAuthentication getPasswordAuthentication() + { + return new java.net.PasswordAuthentication( System.getenv( "MVNW_USERNAME" ), System.getenv( "MVNW_PASSWORD" ).toCharArray() ); + } + public static void main( String[] args ) throws Exception + { + setDefault( new Downloader() ); + java.nio.file.Files.copy( java.net.URI.create( args[0] ).toURL().openStream(), java.nio.file.Paths.get( args[1] ).toAbsolutePath().normalize() ); + } + } + END + # For Cygwin/MinGW, switch paths to Windows format before running javac and java + verbose " - Compiling Downloader.java ..." + "$(native_path "$JAVACCMD")" "$(native_path "$javaSource")" || die "Failed to compile Downloader.java" + verbose " - Running Downloader.java ..." + "$(native_path "$JAVACMD")" -cp "$(native_path "$TMP_DOWNLOAD_DIR")" Downloader "$distributionUrl" "$(native_path "$targetZip")" +fi + +# If specified, validate the SHA-256 sum of the Maven distribution zip file +if [ -n "${distributionSha256Sum-}" ]; then + distributionSha256Result=false + if [ "$MVN_CMD" = mvnd.sh ]; then + echo "Checksum validation is not supported for maven-mvnd." >&2 + echo "Please disable validation by removing 'distributionSha256Sum' from your maven-wrapper.properties." >&2 + exit 1 + elif command -v sha256sum >/dev/null; then + if echo "$distributionSha256Sum $TMP_DOWNLOAD_DIR/$distributionUrlName" | sha256sum -c - >/dev/null 2>&1; then + distributionSha256Result=true + fi + elif command -v shasum >/dev/null; then + if echo "$distributionSha256Sum $TMP_DOWNLOAD_DIR/$distributionUrlName" | shasum -a 256 -c >/dev/null 2>&1; then + distributionSha256Result=true + fi + else + echo "Checksum validation was requested but neither 'sha256sum' or 'shasum' are available." >&2 + echo "Please install either command, or disable validation by removing 'distributionSha256Sum' from your maven-wrapper.properties." >&2 + exit 1 + fi + if [ $distributionSha256Result = false ]; then + echo "Error: Failed to validate Maven distribution SHA-256, your Maven distribution might be compromised." >&2 + echo "If you updated your Maven version, you need to update the specified distributionSha256Sum property." >&2 + exit 1 + fi +fi + +# unzip and move +if command -v unzip >/dev/null; then + unzip ${__MVNW_QUIET_UNZIP:+"$__MVNW_QUIET_UNZIP"} "$TMP_DOWNLOAD_DIR/$distributionUrlName" -d "$TMP_DOWNLOAD_DIR" || die "failed to unzip" +else + tar xzf${__MVNW_QUIET_TAR:+"$__MVNW_QUIET_TAR"} "$TMP_DOWNLOAD_DIR/$distributionUrlName" -C "$TMP_DOWNLOAD_DIR" || die "failed to untar" +fi + +# Find the actual extracted directory name (handles snapshots where filename != directory name) +actualDistributionDir="" + +# First try the expected directory name (for regular distributions) +if [ -d "$TMP_DOWNLOAD_DIR/$distributionUrlNameMain" ]; then + if [ -f "$TMP_DOWNLOAD_DIR/$distributionUrlNameMain/bin/$MVN_CMD" ]; then + actualDistributionDir="$distributionUrlNameMain" + fi +fi + +# If not found, search for any directory with the Maven executable (for snapshots) +if [ -z "$actualDistributionDir" ]; then + # enable globbing to iterate over items + set +f + for dir in "$TMP_DOWNLOAD_DIR"/*; do + if [ -d "$dir" ]; then + if [ -f "$dir/bin/$MVN_CMD" ]; then + actualDistributionDir="$(basename "$dir")" + break + fi + fi + done + set -f +fi + +if [ -z "$actualDistributionDir" ]; then + verbose "Contents of $TMP_DOWNLOAD_DIR:" + verbose "$(ls -la "$TMP_DOWNLOAD_DIR")" + die "Could not find Maven distribution directory in extracted archive" +fi + +verbose "Found extracted Maven distribution directory: $actualDistributionDir" +printf %s\\n "$distributionUrl" >"$TMP_DOWNLOAD_DIR/$actualDistributionDir/mvnw.url" +mv -- "$TMP_DOWNLOAD_DIR/$actualDistributionDir" "$MAVEN_HOME" || [ -d "$MAVEN_HOME" ] || die "fail to move MAVEN_HOME" + +clean || : +exec_maven "$@" diff --git a/mvnw.cmd b/mvnw.cmd new file mode 100644 index 0000000..92450f9 --- /dev/null +++ b/mvnw.cmd @@ -0,0 +1,189 @@ +<# : batch portion +@REM ---------------------------------------------------------------------------- +@REM Licensed to the Apache Software Foundation (ASF) under one +@REM or more contributor license agreements. See the NOTICE file +@REM distributed with this work for additional information +@REM regarding copyright ownership. The ASF licenses this file +@REM to you under the Apache License, Version 2.0 (the +@REM "License"); you may not use this file except in compliance +@REM with the License. You may obtain a copy of the License at +@REM +@REM http://www.apache.org/licenses/LICENSE-2.0 +@REM +@REM Unless required by applicable law or agreed to in writing, +@REM software distributed under the License is distributed on an +@REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +@REM KIND, either express or implied. See the License for the +@REM specific language governing permissions and limitations +@REM under the License. +@REM ---------------------------------------------------------------------------- + +@REM ---------------------------------------------------------------------------- +@REM Apache Maven Wrapper startup batch script, version 3.3.4 +@REM +@REM Optional ENV vars +@REM MVNW_REPOURL - repo url base for downloading maven distribution +@REM MVNW_USERNAME/MVNW_PASSWORD - user and password for downloading maven +@REM MVNW_VERBOSE - true: enable verbose log; others: silence the output +@REM ---------------------------------------------------------------------------- + +@IF "%__MVNW_ARG0_NAME__%"=="" (SET __MVNW_ARG0_NAME__=%~nx0) +@SET __MVNW_CMD__= +@SET __MVNW_ERROR__= +@SET __MVNW_PSMODULEP_SAVE=%PSModulePath% +@SET PSModulePath= +@FOR /F "usebackq tokens=1* delims==" %%A IN (`powershell -noprofile "& {$scriptDir='%~dp0'; $script='%__MVNW_ARG0_NAME__%'; icm -ScriptBlock ([Scriptblock]::Create((Get-Content -Raw '%~f0'))) -NoNewScope}"`) DO @( + IF "%%A"=="MVN_CMD" (set __MVNW_CMD__=%%B) ELSE IF "%%B"=="" (echo %%A) ELSE (echo %%A=%%B) +) +@SET PSModulePath=%__MVNW_PSMODULEP_SAVE% +@SET __MVNW_PSMODULEP_SAVE= +@SET __MVNW_ARG0_NAME__= +@SET MVNW_USERNAME= +@SET MVNW_PASSWORD= +@IF NOT "%__MVNW_CMD__%"=="" ("%__MVNW_CMD__%" %*) +@echo Cannot start maven from wrapper >&2 && exit /b 1 +@GOTO :EOF +: end batch / begin powershell #> + +$ErrorActionPreference = "Stop" +if ($env:MVNW_VERBOSE -eq "true") { + $VerbosePreference = "Continue" +} + +# calculate distributionUrl, requires .mvn/wrapper/maven-wrapper.properties +$distributionUrl = (Get-Content -Raw "$scriptDir/.mvn/wrapper/maven-wrapper.properties" | ConvertFrom-StringData).distributionUrl +if (!$distributionUrl) { + Write-Error "cannot read distributionUrl property in $scriptDir/.mvn/wrapper/maven-wrapper.properties" +} + +switch -wildcard -casesensitive ( $($distributionUrl -replace '^.*/','') ) { + "maven-mvnd-*" { + $USE_MVND = $true + $distributionUrl = $distributionUrl -replace '-bin\.[^.]*$',"-windows-amd64.zip" + $MVN_CMD = "mvnd.cmd" + break + } + default { + $USE_MVND = $false + $MVN_CMD = $script -replace '^mvnw','mvn' + break + } +} + +# apply MVNW_REPOURL and calculate MAVEN_HOME +# maven home pattern: ~/.m2/wrapper/dists/{apache-maven-,maven-mvnd--}/ +if ($env:MVNW_REPOURL) { + $MVNW_REPO_PATTERN = if ($USE_MVND -eq $False) { "/org/apache/maven/" } else { "/maven/mvnd/" } + $distributionUrl = "$env:MVNW_REPOURL$MVNW_REPO_PATTERN$($distributionUrl -replace "^.*$MVNW_REPO_PATTERN",'')" +} +$distributionUrlName = $distributionUrl -replace '^.*/','' +$distributionUrlNameMain = $distributionUrlName -replace '\.[^.]*$','' -replace '-bin$','' + +$MAVEN_M2_PATH = "$HOME/.m2" +if ($env:MAVEN_USER_HOME) { + $MAVEN_M2_PATH = "$env:MAVEN_USER_HOME" +} + +if (-not (Test-Path -Path $MAVEN_M2_PATH)) { + New-Item -Path $MAVEN_M2_PATH -ItemType Directory | Out-Null +} + +$MAVEN_WRAPPER_DISTS = $null +if ((Get-Item $MAVEN_M2_PATH).Target[0] -eq $null) { + $MAVEN_WRAPPER_DISTS = "$MAVEN_M2_PATH/wrapper/dists" +} else { + $MAVEN_WRAPPER_DISTS = (Get-Item $MAVEN_M2_PATH).Target[0] + "/wrapper/dists" +} + +$MAVEN_HOME_PARENT = "$MAVEN_WRAPPER_DISTS/$distributionUrlNameMain" +$MAVEN_HOME_NAME = ([System.Security.Cryptography.SHA256]::Create().ComputeHash([byte[]][char[]]$distributionUrl) | ForEach-Object {$_.ToString("x2")}) -join '' +$MAVEN_HOME = "$MAVEN_HOME_PARENT/$MAVEN_HOME_NAME" + +if (Test-Path -Path "$MAVEN_HOME" -PathType Container) { + Write-Verbose "found existing MAVEN_HOME at $MAVEN_HOME" + Write-Output "MVN_CMD=$MAVEN_HOME/bin/$MVN_CMD" + exit $? +} + +if (! $distributionUrlNameMain -or ($distributionUrlName -eq $distributionUrlNameMain)) { + Write-Error "distributionUrl is not valid, must end with *-bin.zip, but found $distributionUrl" +} + +# prepare tmp dir +$TMP_DOWNLOAD_DIR_HOLDER = New-TemporaryFile +$TMP_DOWNLOAD_DIR = New-Item -Itemtype Directory -Path "$TMP_DOWNLOAD_DIR_HOLDER.dir" +$TMP_DOWNLOAD_DIR_HOLDER.Delete() | Out-Null +trap { + if ($TMP_DOWNLOAD_DIR.Exists) { + try { Remove-Item $TMP_DOWNLOAD_DIR -Recurse -Force | Out-Null } + catch { Write-Warning "Cannot remove $TMP_DOWNLOAD_DIR" } + } +} + +New-Item -Itemtype Directory -Path "$MAVEN_HOME_PARENT" -Force | Out-Null + +# Download and Install Apache Maven +Write-Verbose "Couldn't find MAVEN_HOME, downloading and installing it ..." +Write-Verbose "Downloading from: $distributionUrl" +Write-Verbose "Downloading to: $TMP_DOWNLOAD_DIR/$distributionUrlName" + +$webclient = New-Object System.Net.WebClient +if ($env:MVNW_USERNAME -and $env:MVNW_PASSWORD) { + $webclient.Credentials = New-Object System.Net.NetworkCredential($env:MVNW_USERNAME, $env:MVNW_PASSWORD) +} +[Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12 +$webclient.DownloadFile($distributionUrl, "$TMP_DOWNLOAD_DIR/$distributionUrlName") | Out-Null + +# If specified, validate the SHA-256 sum of the Maven distribution zip file +$distributionSha256Sum = (Get-Content -Raw "$scriptDir/.mvn/wrapper/maven-wrapper.properties" | ConvertFrom-StringData).distributionSha256Sum +if ($distributionSha256Sum) { + if ($USE_MVND) { + Write-Error "Checksum validation is not supported for maven-mvnd. `nPlease disable validation by removing 'distributionSha256Sum' from your maven-wrapper.properties." + } + Import-Module $PSHOME\Modules\Microsoft.PowerShell.Utility -Function Get-FileHash + if ((Get-FileHash "$TMP_DOWNLOAD_DIR/$distributionUrlName" -Algorithm SHA256).Hash.ToLower() -ne $distributionSha256Sum) { + Write-Error "Error: Failed to validate Maven distribution SHA-256, your Maven distribution might be compromised. If you updated your Maven version, you need to update the specified distributionSha256Sum property." + } +} + +# unzip and move +Expand-Archive "$TMP_DOWNLOAD_DIR/$distributionUrlName" -DestinationPath "$TMP_DOWNLOAD_DIR" | Out-Null + +# Find the actual extracted directory name (handles snapshots where filename != directory name) +$actualDistributionDir = "" + +# First try the expected directory name (for regular distributions) +$expectedPath = Join-Path "$TMP_DOWNLOAD_DIR" "$distributionUrlNameMain" +$expectedMvnPath = Join-Path "$expectedPath" "bin/$MVN_CMD" +if ((Test-Path -Path $expectedPath -PathType Container) -and (Test-Path -Path $expectedMvnPath -PathType Leaf)) { + $actualDistributionDir = $distributionUrlNameMain +} + +# If not found, search for any directory with the Maven executable (for snapshots) +if (!$actualDistributionDir) { + Get-ChildItem -Path "$TMP_DOWNLOAD_DIR" -Directory | ForEach-Object { + $testPath = Join-Path $_.FullName "bin/$MVN_CMD" + if (Test-Path -Path $testPath -PathType Leaf) { + $actualDistributionDir = $_.Name + } + } +} + +if (!$actualDistributionDir) { + Write-Error "Could not find Maven distribution directory in extracted archive" +} + +Write-Verbose "Found extracted Maven distribution directory: $actualDistributionDir" +Rename-Item -Path "$TMP_DOWNLOAD_DIR/$actualDistributionDir" -NewName $MAVEN_HOME_NAME | Out-Null +try { + Move-Item -Path "$TMP_DOWNLOAD_DIR/$MAVEN_HOME_NAME" -Destination $MAVEN_HOME_PARENT | Out-Null +} catch { + if (! (Test-Path -Path "$MAVEN_HOME" -PathType Container)) { + Write-Error "fail to move MAVEN_HOME" + } +} finally { + try { Remove-Item $TMP_DOWNLOAD_DIR -Recurse -Force | Out-Null } + catch { Write-Warning "Cannot remove $TMP_DOWNLOAD_DIR" } +} + +Write-Output "MVN_CMD=$MAVEN_HOME/bin/$MVN_CMD" diff --git a/scripts/checksums/.gitkeep b/scripts/checksums/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/scripts/checksums/models.sha256 b/scripts/checksums/models.sha256 new file mode 100644 index 0000000..de2d3fb --- /dev/null +++ b/scripts/checksums/models.sha256 @@ -0,0 +1,22 @@ +# inference-sdk — pinned model artifact SHA-256 hashes +# +# Format: one line per file, sha256sum-compatible: +# <64-hex-digest> +# +# Search order for (verify_models.py): +# 1. / +# 2. /models/ +# 3. /java/inference-sdk-embed-bge-small/src/main/resources/models/ +# 4. /java/inference-sdk-generate-qwen-0_5b/src/main/resources/models/ +# +# Lines beginning with '#' are comments. Blank lines ignored. +# +# These hashes are populated after the first successful run of: +# make fetch-models +# which invokes scripts/fetch_models.py and writes the manifest. Copy +# the resulting `sha256` value out of each model-manifest.properties +# into this file alongside the artifact's relative path. +# +# Example (do not uncomment until the artifact actually exists): +# abcdef0123456789... java/inference-sdk-embed-bge-small/src/main/resources/models/bge-small-en-v1.5.int8.onnx +# fedcba9876543210... java/inference-sdk-generate-qwen-0_5b/src/main/resources/models/qwen2.5-0.5b-instruct.q4_k_m.gguf diff --git a/scripts/fetch_models.py b/scripts/fetch_models.py new file mode 100644 index 0000000..ad406a9 --- /dev/null +++ b/scripts/fetch_models.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +"""Build-host-only model converter for inference-sdk. + +Downloads, converts, and quantizes embedding + generation models for +bundling into Maven model JARs. Runs only on build hosts with internet +access; the resulting artifacts are committed via Git LFS so end users +build fully offline. + +Embedding flow (default ``bge-small-en-v1.5``): + 1. Download via huggingface_hub (uses pre-exported ``onnx/model.onnx``) + 2. int8 dynamic quantization via onnxruntime.quantization + 3. SHA-256 + manifest emission + +Generation flow (default ``qwen2.5-0.5b-instruct``): + 1. Download safetensors via huggingface_hub + 2. Vendor llama.cpp at the tag matching ``de.kherud:llama:4.2.0`` + (``LLAMA_CPP_TAG`` constant below; mid-2025 ``b4916``) + 3. Convert HF -> GGUF via ``llama.cpp/convert_hf_to_gguf.py`` + 4. Quantize via ``llama-quantize`` to q4_K_M + 5. SHA-256 + manifest emission + +Usage:: + + python3 scripts/fetch_models.py \\ + --embedding-model bge-small-en-v1.5 \\ + --generation-model qwen2.5-0.5b-instruct \\ + --output-dir java/ + +Strict type hints; no runtime network calls outside huggingface_hub. +""" + +from __future__ import annotations + +import argparse +import dataclasses +import hashlib +import logging +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Final + + +LOG: Final = logging.getLogger("fetch_models") + +# llama.cpp tag bundled inside de.kherud:llama:4.2.0 (mid-2025 b4916). +# Phase 1 consumes the published kherud artifact directly; this constant +# pins the matching llama.cpp source tree we vendor for the GGUF +# conversion + quantize toolchain. Bumped together with the kherud +# version (Phase 1.5 fork-and-bump — see docs/ARCHITECTURE.md Roadmap). +LLAMA_CPP_TAG: Final = "b4916" + +# --------------------------------------------------------------------------- +# Model registry — canonical IDs -> HuggingFace coordinates + config. +# Source of truth: docs/MODEL_REGISTRY.md (Tier 1 deliverable). +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass(frozen=True, slots=True) +class EmbeddingSpec: + """Embedding model conversion spec.""" + + canonical_id: str + hf_repo: str + revision: str + dimensions: int + onnx_subpath: str # path inside the HF repo to the pre-exported ONNX + output_filename: str + target_module: str # relative path from --output-dir + + +@dataclasses.dataclass(frozen=True, slots=True) +class GenerationSpec: + """Generation model conversion spec.""" + + canonical_id: str + hf_repo: str + revision: str + max_tokens: int + quantization: str # e.g. "q4_K_M" + output_filename: str + target_module: str + + +EMBEDDING_REGISTRY: Final[dict[str, EmbeddingSpec]] = { + "bge-small-en-v1.5": EmbeddingSpec( + canonical_id="bge-small-en-v1.5", + hf_repo="BAAI/bge-small-en-v1.5", + revision="main", + dimensions=384, + onnx_subpath="onnx/model.onnx", + output_filename="bge-small-en-v1.5.int8.onnx", + target_module="inference-sdk-embed-bge-small", + ), +} + +GENERATION_REGISTRY: Final[dict[str, GenerationSpec]] = { + "qwen2.5-0.5b-instruct": GenerationSpec( + canonical_id="qwen2.5-0.5b-instruct", + hf_repo="Qwen/Qwen2.5-0.5B-Instruct", + revision="main", + max_tokens=32768, + quantization="q4_K_M", + output_filename="qwen2.5-0.5b-instruct.q4_k_m.gguf", + target_module="inference-sdk-generate-qwen-0_5b", + ), +} + + +# --------------------------------------------------------------------------- +# Utilities +# --------------------------------------------------------------------------- + + +def sha256_file(path: Path, *, chunk: int = 1024 * 1024) -> str: + """Compute hex SHA-256 of *path* using a streaming hash.""" + h = hashlib.sha256() + with path.open("rb") as fh: + for block in iter(lambda: fh.read(chunk), b""): + h.update(block) + return h.hexdigest() + + +def write_manifest(target_dir: Path, *, kv: dict[str, str]) -> Path: + """Write a Java-style ``.properties`` file with stable key order.""" + target_dir.mkdir(parents=True, exist_ok=True) + manifest = target_dir / "model-manifest.properties" + lines = ["# Generated by scripts/fetch_models.py — do not edit by hand"] + for key in sorted(kv): + lines.append(f"{key}={kv[key]}") + manifest.write_text("\n".join(lines) + "\n", encoding="utf-8") + return manifest + + +def run(cmd: list[str], *, cwd: Path | None = None) -> None: + """Run *cmd* with stdio inherited; raise on non-zero exit.""" + LOG.info("$ %s", " ".join(cmd)) + subprocess.run(cmd, cwd=cwd, check=True) + + +def ensure_llama_cpp_vendored(repo_root: Path, tag: str) -> Path: + """Clone llama.cpp at *tag* into ``build/llama.cpp`` if not present. + + Returns the directory containing ``convert_hf_to_gguf.py``. + """ + target = repo_root / "build" / "llama.cpp" + if target.exists() and (target / "convert_hf_to_gguf.py").exists(): + LOG.info("llama.cpp already vendored at %s", target) + return target + target.parent.mkdir(parents=True, exist_ok=True) + if target.exists(): + shutil.rmtree(target) + run( + [ + "git", + "clone", + "--depth", + "1", + "--branch", + tag, + "https://github.com/ggml-org/llama.cpp.git", + str(target), + ] + ) + return target + + +# --------------------------------------------------------------------------- +# Embedding pipeline +# --------------------------------------------------------------------------- + + +def convert_embedding(spec: EmbeddingSpec, output_root: Path) -> None: + """Download, quantize, and stage the embedding ONNX model.""" + try: + from huggingface_hub import hf_hub_download # type: ignore[import-untyped] + except ImportError as exc: # pragma: no cover - runtime guard + raise SystemExit( + "huggingface_hub is required; install via " + "`pip install -r scripts/requirements.txt`" + ) from exc + + try: + from onnxruntime.quantization import ( # type: ignore[import-untyped] + QuantType, + quantize_dynamic, + ) + except ImportError as exc: # pragma: no cover - runtime guard + raise SystemExit( + "onnxruntime is required; install via " + "`pip install -r scripts/requirements.txt`" + ) from exc + + target_dir = output_root / spec.target_module / "src" / "main" / "resources" / "models" + target_dir.mkdir(parents=True, exist_ok=True) + final_path = target_dir / spec.output_filename + + LOG.info( + "Downloading %s @ %s :: %s", + spec.hf_repo, + spec.revision, + spec.onnx_subpath, + ) + fp32_path = Path( + hf_hub_download( + repo_id=spec.hf_repo, + revision=spec.revision, + filename=spec.onnx_subpath, + ) + ) + + LOG.info("Quantizing to int8 dynamic -> %s", final_path) + # quantize_dynamic accepts str | os.PathLike for paths. + quantize_dynamic( + model_input=str(fp32_path), + model_output=str(final_path), + weight_type=QuantType.QInt8, + ) + + digest = sha256_file(final_path) + LOG.info("SHA-256 %s = %s", final_path.name, digest) + write_manifest( + target_dir, + kv={ + "id": spec.canonical_id, + "revision": spec.revision, + "quantization": "int8-dynamic", + "dimensions": str(spec.dimensions), + "sha256": digest, + "filename": spec.output_filename, + }, + ) + + +# --------------------------------------------------------------------------- +# Generation pipeline +# --------------------------------------------------------------------------- + + +def convert_generation(spec: GenerationSpec, output_root: Path) -> None: + """Download safetensors, convert to GGUF, quantize, and stage.""" + try: + from huggingface_hub import snapshot_download # type: ignore[import-untyped] + except ImportError as exc: # pragma: no cover - runtime guard + raise SystemExit( + "huggingface_hub is required; install via " + "`pip install -r scripts/requirements.txt`" + ) from exc + + repo_root = Path(__file__).resolve().parent.parent + target_dir = output_root / spec.target_module / "src" / "main" / "resources" / "models" + target_dir.mkdir(parents=True, exist_ok=True) + final_path = target_dir / spec.output_filename + + LOG.info("Downloading %s @ %s (safetensors)", spec.hf_repo, spec.revision) + hf_dir = Path( + snapshot_download( + repo_id=spec.hf_repo, + revision=spec.revision, + allow_patterns=[ + "*.json", + "*.safetensors", + "*.txt", + "tokenizer*", + "*.model", + ], + ) + ) + + tag = LLAMA_CPP_TAG + llama_cpp_dir = ensure_llama_cpp_vendored(repo_root, tag) + + fp16_gguf = repo_root / "build" / f"{spec.canonical_id}.fp16.gguf" + fp16_gguf.parent.mkdir(parents=True, exist_ok=True) + run( + [ + sys.executable, + str(llama_cpp_dir / "convert_hf_to_gguf.py"), + str(hf_dir), + "--outfile", + str(fp16_gguf), + "--outtype", + "f16", + ] + ) + + quantize_bin = ( + llama_cpp_dir / "build" / "bin" / "llama-quantize" + if (llama_cpp_dir / "build" / "bin" / "llama-quantize").exists() + else Path(shutil.which("llama-quantize") or "") + ) + if not quantize_bin or not Path(quantize_bin).exists(): + raise SystemExit( + "llama-quantize not found. Build llama.cpp first:\n" + f" cmake -S {llama_cpp_dir} -B {llama_cpp_dir / 'build'} && \\\n" + f" cmake --build {llama_cpp_dir / 'build'} --target llama-quantize -j" + ) + run([str(quantize_bin), str(fp16_gguf), str(final_path), spec.quantization]) + + digest = sha256_file(final_path) + LOG.info("SHA-256 %s = %s", final_path.name, digest) + write_manifest( + target_dir, + kv={ + "id": spec.canonical_id, + "revision": spec.revision, + "quantization": spec.quantization, + "max_tokens": str(spec.max_tokens), + "sha256": digest, + "filename": spec.output_filename, + "llama_cpp_tag": tag, + }, + ) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + """Build the argparse namespace.""" + parser = argparse.ArgumentParser( + description="Build-host-only model converter for inference-sdk", + ) + parser.add_argument( + "--embedding-model", + default="bge-small-en-v1.5", + choices=sorted(EMBEDDING_REGISTRY), + help="Canonical embedding model ID (default: %(default)s)", + ) + parser.add_argument( + "--generation-model", + default="qwen2.5-0.5b-instruct", + choices=sorted(GENERATION_REGISTRY), + help="Canonical generation model ID (default: %(default)s)", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("java"), + help="Root directory containing model JAR modules (default: %(default)s)", + ) + parser.add_argument( + "--skip-embedding", + action="store_true", + help="Skip embedding conversion", + ) + parser.add_argument( + "--skip-generation", + action="store_true", + help="Skip generation conversion", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable debug logging", + ) + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + """Entry point. Returns exit code.""" + args = parse_args(argv) + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + ) + output_root: Path = args.output_dir.resolve() + if not output_root.exists(): + raise SystemExit(f"--output-dir does not exist: {output_root}") + + if not args.skip_embedding: + spec = EMBEDDING_REGISTRY[args.embedding_model] + LOG.info("=== Embedding: %s ===", spec.canonical_id) + convert_embedding(spec, output_root) + + if not args.skip_generation: + spec = GENERATION_REGISTRY[args.generation_model] + LOG.info("=== Generation: %s ===", spec.canonical_id) + convert_generation(spec, output_root) + + LOG.info("Done. Run `python3 scripts/verify_models.py` to confirm checksums.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/requirements.txt b/scripts/requirements.txt new file mode 100644 index 0000000..5b24bbd --- /dev/null +++ b/scripts/requirements.txt @@ -0,0 +1,16 @@ +# inference-sdk model conversion scripts — Python dependencies +# Pinned for reproducibility. Used only on build hosts (never at runtime). +# Tested against Python 3.11+; CI uses 3.11. + +huggingface_hub>=0.27,<1.0 +onnxruntime>=1.20,<2.0 +safetensors>=0.4.5,<1.0 + +# Optional: optimum-cli (only if a non-pre-exported model is added later). +# bge-small-en-v1.5 already ships an onnx/model.onnx in its repo so this +# is not required for the default embedding flow. +# optimum[onnxruntime]>=1.23,<2.0 + +# Linters used by .github/workflows/scripts-ci.yml +ruff>=0.8,<1.0 +pyright>=1.1.380,<2.0 diff --git a/scripts/verify_models.py b/scripts/verify_models.py new file mode 100644 index 0000000..d13deaf --- /dev/null +++ b/scripts/verify_models.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 +"""SHA-256 verification for pinned model artifacts. + +Walks ``scripts/checksums/models.sha256`` and re-hashes every pinned +file. Searches each entry across the conventional locations: + + - ``models/`` (build host scratch) + - ``java/inference-sdk-embed-bge-small/src/main/resources/models/`` + - ``java/inference-sdk-generate-qwen-0_5b/src/main/resources/models/`` + +Exits non-zero on any mismatch or missing file. Prints a clean +OK/MISMATCH/MISSING report. + +Strict type hints; no third-party deps; pure stdlib. +""" + +from __future__ import annotations + +import argparse +import hashlib +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Final + + +REPO_ROOT: Final[Path] = Path(__file__).resolve().parent.parent +SEARCH_ROOTS: Final[tuple[Path, ...]] = ( + REPO_ROOT / "models", + REPO_ROOT / "java" / "inference-sdk-embed-bge-small" / "src" / "main" / "resources" / "models", + REPO_ROOT / "java" / "inference-sdk-generate-qwen-0_5b" / "src" / "main" / "resources" / "models", +) + + +@dataclass(frozen=True, slots=True) +class PinnedHash: + """A single line from ``models.sha256``.""" + + expected: str + relative_path: str + + +@dataclass(frozen=True, slots=True) +class VerifyResult: + """Per-file verification outcome.""" + + pinned: PinnedHash + status: str # "OK" | "MISMATCH" | "MISSING" + actual: str | None + located_at: Path | None + + +def sha256_file(path: Path, *, chunk: int = 1024 * 1024) -> str: + """Streaming hex SHA-256 of *path*.""" + h = hashlib.sha256() + with path.open("rb") as fh: + for block in iter(lambda: fh.read(chunk), b""): + h.update(block) + return h.hexdigest() + + +def parse_checksum_file(path: Path) -> list[PinnedHash]: + """Parse a ``sha256sum``-format file with leading ``#`` comments allowed.""" + pins: list[PinnedHash] = [] + if not path.exists(): + raise SystemExit(f"checksum file not found: {path}") + for raw in path.read_text(encoding="utf-8").splitlines(): + line = raw.strip() + if not line or line.startswith("#"): + continue + # sha256sum format: " " (two spaces or one or tab) + parts = line.split(maxsplit=1) + if len(parts) != 2: + raise SystemExit(f"malformed checksum line: {raw!r}") + digest, rel = parts[0].strip(), parts[1].strip().lstrip("*") + if len(digest) != 64 or any(c not in "0123456789abcdef" for c in digest.lower()): + raise SystemExit(f"invalid sha256 hex: {digest!r}") + pins.append(PinnedHash(expected=digest.lower(), relative_path=rel)) + return pins + + +def locate(pin: PinnedHash) -> Path | None: + """Return the first matching file across search roots, or None.""" + direct = REPO_ROOT / pin.relative_path + if direct.exists(): + return direct + name = Path(pin.relative_path).name + for root in SEARCH_ROOTS: + candidate = root / name + if candidate.exists(): + return candidate + return None + + +def verify(pins: list[PinnedHash]) -> list[VerifyResult]: + """Verify each pin; return list of results in input order.""" + results: list[VerifyResult] = [] + for pin in pins: + located = locate(pin) + if located is None: + results.append(VerifyResult(pin, "MISSING", None, None)) + continue + actual = sha256_file(located) + status = "OK" if actual == pin.expected else "MISMATCH" + results.append(VerifyResult(pin, status, actual, located)) + return results + + +def render(results: list[VerifyResult]) -> str: + """Format a human-readable report.""" + lines: list[str] = [] + width = max((len(r.pinned.relative_path) for r in results), default=10) + for r in results: + marker = {"OK": " OK ", "MISMATCH": "FAIL", "MISSING": "GONE"}[r.status] + lines.append(f"[{marker}] {r.pinned.relative_path:<{width}}") + if r.status == "MISMATCH": + lines.append(f" expected: {r.pinned.expected}") + lines.append(f" actual: {r.actual}") + lines.append(f" at: {r.located_at}") + elif r.status == "MISSING": + lines.append(f" searched: {[str(s) for s in SEARCH_ROOTS]}") + return "\n".join(lines) + + +def main(argv: list[str] | None = None) -> int: + """Entry point. Exit code 0 = all OK; 1 = any failure.""" + parser = argparse.ArgumentParser( + description="Verify SHA-256 of pinned model artifacts", + ) + parser.add_argument( + "--checksums", + type=Path, + default=REPO_ROOT / "scripts" / "checksums" / "models.sha256", + help="Path to the pinned-hash file (default: %(default)s)", + ) + args = parser.parse_args(argv) + + pins = parse_checksum_file(args.checksums) + if not pins: + print( + f"No pins in {args.checksums}. " + "Run `python3 scripts/fetch_models.py` first, then populate the file.", + file=sys.stderr, + ) + return 0 + + results = verify(pins) + print(render(results)) + + failures = [r for r in results if r.status != "OK"] + if failures: + print( + f"\n{len(failures)} of {len(results)} checks failed.", + file=sys.stderr, + ) + return 1 + print(f"\nAll {len(results)} checksums OK.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())