387 changes: 387 additions & 0 deletions .claude/knowledge/w2-arrayview-migration.md

Large diffs are not rendered by default.

38 changes: 38 additions & 0 deletions .claude/knowledge/w2-blas-statistics-audit.md
@@ -0,0 +1,38 @@
# W2-3 + W2-4 Audit — BLAS levels + statistics ArrayView compliance

## Verdict
**CLEAN.** No follow-up wave needed for `blas_level{1,2,3}.rs` or `statistics.rs`. All four files are already ArrayView-shaped via trait impls on `ArrayBase`.

## Per-file findings

### `src/hpc/blas_level1.rs`
- Trait impl on ArrayBase: **yes, L47** — `impl<A, S> BlasLevel1<A> for ArrayBase<S, Ix1>`
- Bonus trait impls (not in the original migration doc, but clean): `ScalarArith` (L196), `VecArith` (L242) — both on `ArrayBase<S, Ix1>`
- Slice-taking pub fns: **1** regex hit, **0** real. `blas_rotg` (L152) is **OK-as-is**: its signature is `(a: A, b: A)` (scalars), not slices. The regex `^pub fn .*&\[` matched a `&[` in the doc-comment example, not the signature.
- `axis_iter` misuse: **0**
- Bridge pattern: verified present in trait methods — `blas_dot`, `blas_axpy`, `blas_scal`, `blas_nrm2`, `blas_asum` all dispatch through `as_slice()` hot path + stride-aware cold path.
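
The bridge pattern named above can be sketched roughly as follows. This is a std-only illustration, not the code in `blas_level1.rs`: `StridedView` and `dot` are hypothetical stand-ins for `ArrayBase<S, Ix1>` and `blas_dot`, and the hot path here is a plain loop where the real code would presumably call a SIMD kernel.

```rust
/// Hypothetical stand-in for a 1-D array view: a backing buffer plus a
/// length and an element stride. This is NOT ndarray's `ArrayBase`.
pub struct StridedView<'a> {
    data: &'a [f32],
    len: usize,
    stride: usize,
}

impl<'a> StridedView<'a> {
    /// Build a view over `data` that takes every `stride`-th element.
    pub fn new(data: &'a [f32], stride: usize) -> Self {
        assert!(stride >= 1, "stride must be at least 1");
        // Number of elements reachable at this stride.
        let len = if data.is_empty() { 0 } else { (data.len() - 1) / stride + 1 };
        StridedView { data, len, stride }
    }

    /// Plays the role of `as_slice()`: `Some` only when the elements are
    /// contiguous in memory, otherwise `None`.
    pub fn as_slice(&self) -> Option<&'a [f32]> {
        if self.stride == 1 { Some(&self.data[..self.len]) } else { None }
    }

    /// Stride-aware element iterator, usable for any layout.
    pub fn iter(&self) -> impl Iterator<Item = f32> + 'a {
        self.data.iter().step_by(self.stride).copied()
    }
}

/// Dot product with a contiguous hot path and a stride-aware cold path.
pub fn dot(x: &StridedView<'_>, y: &StridedView<'_>) -> f32 {
    assert_eq!(x.len, y.len, "length mismatch");
    match (x.as_slice(), y.as_slice()) {
        // Hot path: both operands are plain slices, so a tight loop
        // (or a SIMD kernel) can run over them directly.
        (Some(xs), Some(ys)) => xs.iter().zip(ys).map(|(a, b)| a * b).sum(),
        // Cold path: at least one operand is strided; walk element by element.
        _ => x.iter().zip(y.iter()).map(|(a, b)| a * b).sum(),
    }
}
```

Both paths must return the same value for the same logical elements; only the memory walk differs.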

### `src/hpc/blas_level2.rs`
- Trait impl on ArrayBase: **yes, L97** — `impl<A, S> BlasLevel2<A> for ArrayBase<S, Ix2>`
- Slice-taking pub fns: **0**
- `axis_iter` misuse: **0**

### `src/hpc/blas_level3.rs`
- Trait impl on ArrayBase: **yes, L59** — `impl<A, S> BlasLevel3<A> for ArrayBase<S, Ix2>`
- Slice-taking pub fns: **0**
- `axis_iter` misuse: **0**

### `src/hpc/statistics.rs`
- Trait impl on ArrayBase: **yes, L65** — `impl<A, S, D> Statistics<A> for ArrayBase<S, D>` (note: generic-D, unlike BLAS L1/L2/L3 which fix `Ix1`/`Ix2`)
- Slice-taking pub fns: **0**
- `axis_iter` misuse: **0**
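
The fixed-dimension versus generic-dimension distinction noted above can be illustrated with a small std-only sketch. Everything here is hypothetical: `View`, `Mean`, and `Dot` are not the crate's types, and the const parameter `N` only mimics the role of ndarray's dimension type `D`.

```rust
/// Hypothetical N-dimensional view: a flat buffer plus a shape.
/// The const parameter `N` stands in for ndarray's dimension type `D`.
pub struct View<'a, const N: usize> {
    pub data: &'a [f64],
    pub shape: [usize; N],
}

/// Implemented once for every dimensionality, in the spirit of a
/// generic-`D` impl such as the `Statistics` one.
pub trait Mean {
    fn mean(&self) -> Option<f64>;
}

impl<'a, const N: usize> Mean for View<'a, N> {
    fn mean(&self) -> Option<f64> {
        // Element count is the product of all axis lengths.
        let n: usize = self.shape.iter().product();
        if n == 0 {
            return None;
        }
        Some(self.data[..n].iter().sum::<f64>() / n as f64)
    }
}

/// Implemented only for 1-D views, in the spirit of a BLAS level-1
/// trait that fixes the dimension to `Ix1`.
pub trait Dot {
    fn dot(&self, other: &Self) -> f64;
}

impl<'a> Dot for View<'a, 1> {
    fn dot(&self, other: &Self) -> f64 {
        assert_eq!(self.shape, other.shape, "shape mismatch");
        let n = self.shape[0];
        self.data[..n].iter().zip(&other.data[..n]).map(|(a, b)| a * b).sum()
    }
}
```

With these definitions, `mean` is callable on a view of any rank, while calling `dot` on a 2-D view is rejected at compile time.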

## Build verification
`cargo check -p ndarray --no-default-features --features std` → clean (31.82s, no warnings).

## Surprises
- `blas_level1.rs` carries two extra trait impls (`ScalarArith`, `VecArith`) on `ArrayBase<S, Ix1>` beyond `BlasLevel1` itself. Not mentioned in the original migration doc but clean and consistent with the two-layer rule.
- `blas_rotg` regex match was a false positive (doc-comment `&[` in an example, not in the signature).
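
One way to avoid this kind of false positive is to inspect only the parameter list of a signature and ignore comment text. A minimal sketch, assuming single-line signatures; `takes_slice_param` is a hypothetical helper, not part of the audit tooling, and it does not handle generic bounds containing parentheses or slices with explicit lifetimes.

```rust
/// Returns true when `line` is a single-line `pub fn` signature whose
/// parameter list mentions a slice type (`&[` or `&mut [`).
/// Comment text is ignored.
pub fn takes_slice_param(line: &str) -> bool {
    let trimmed = line.trim_start();
    // Comment lines (including `///` doc comments) are never signatures.
    if trimmed.starts_with("//") {
        return false;
    }
    // Drop a trailing `// ...` comment before inspecting the signature.
    let code = trimmed.split("//").next().unwrap_or("");
    if !code.starts_with("pub fn ") {
        return false;
    }
    // Find the parameter list: the first '(' and its matching ')'.
    let open = match code.find('(') {
        Some(i) => i,
        None => return false,
    };
    let mut depth = 0usize;
    let mut close = None;
    for (i, ch) in code[open..].char_indices() {
        match ch {
            '(' => depth += 1,
            ')' => {
                depth -= 1;
                if depth == 0 {
                    close = Some(open + i);
                    break;
                }
            }
            _ => {}
        }
    }
    // Only the text between the parentheses counts; return types do not.
    match close {
        Some(end) => {
            let params = &code[open + 1..end];
            params.contains("&[") || params.contains("&mut [")
        }
        None => false,
    }
}
```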

## Follow-up needed
**None.** W2-3 and W2-4 require no code changes. The W2 sprint scope reduces to the three converter waves: W2-1 (reductions), W2-2a (vml), W2-2b (activations).
3 changes: 3 additions & 0 deletions .gitignore
@@ -7,3 +7,6 @@ target/

 # Apple details
 **/.DS_Store
+
+# Claude Code: agent isolation worktrees (temporary, per-agent)
+.claude/worktrees/
4 changes: 3 additions & 1 deletion crates/burn/src/ops/activation.rs
@@ -27,7 +27,9 @@ where
     if view.is_standard_layout() {
         if let Some(input) = view.as_slice() {
             let mut output = alloc::vec![0.0f32; input.len()];
-            ndarray::hpc::activations::sigmoid_f32(input, &mut output);
+            let in_view = ndarray::ArrayView::from(input);
+            let out_view = ndarray::ArrayViewMut::from(&mut output[..]);
+            ndarray::hpc::activations::sigmoid_f32(in_view, out_view);
             let shape: alloc::vec::Vec<usize> = view.shape().to_vec();
             let array = ndarray::Array::from_shape_vec(ndarray::IxDyn(&shape), output)
                 .expect("sigmoid output shape mismatch");
30 changes: 17 additions & 13 deletions crates/burn/src/ops/tensor.rs
@@ -34,26 +34,30 @@ use libm::erf;

 /// Try to accelerate a unary f32 operation via ndarray's hpc::vml (F32x16 SIMD).
 ///
-/// VML signature: `fn(input: &[f32], output: &mut [f32])`.
-/// Uses crate::simd::F32x16 internally. Consumer never sees hardware details.
+/// VML signature (post W2-2a): generic over dimension, takes
+/// `ArrayView<f32, D> / ArrayViewMut<f32, D>`. We pass the dyn-D views from
+/// the burn tensor directly; ndarray's vml routes to the F32x16 SIMD
+/// primitive on the contiguous hot path and falls back to a stride-aware
+/// `Zip` on the cold path. Consumer never sees hardware details.
 #[cfg(feature = "simd")]
 fn try_vml_unary(
     tensor: NdArrayTensor,
-    vml_fn: fn(&[f32], &mut [f32]),
+    vml_fn: fn(ndarray::ArrayView<'_, f32, ndarray::IxDyn>, ndarray::ArrayViewMut<'_, f32, ndarray::IxDyn>),
 ) -> Result<NdArrayTensor, NdArrayTensor> {
     if let NdArrayTensor::F32(storage) = tensor {
         let shared = storage.into_shared();
         if shared.is_standard_layout() {
-            if let Some(input) = shared.as_slice() {
-                let mut output = vec![0.0f32; input.len()];
-                vml_fn(input, &mut output);
-                let shape = shared.shape().to_vec();
-                let array = ndarray::Array::from_shape_vec(ndarray::IxDyn(&shape), output)
-                    .expect("vml output shape mismatch");
-                return Ok(NdArrayTensor::F32(
-                    crate::NdArrayStorage::Owned(array.into_shared()),
-                ));
-            }
+            let shape = shared.shape().to_vec();
+            let len = shared.len();
+            let mut output = ndarray::Array::from_shape_vec(
+                ndarray::IxDyn(&shape),
+                vec![0.0f32; len],
+            )
+            .expect("vml output shape mismatch");
+            vml_fn(shared.view(), output.view_mut());
+            return Ok(NdArrayTensor::F32(
+                crate::NdArrayStorage::Owned(output.into_shared()),
+            ));
         }
         return Err(NdArrayTensor::F32(crate::NdArrayStorage::Owned(shared)));
     }
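
The `Result<NdArrayTensor, NdArrayTensor>` return type in this hunk hands the input back to the caller when the fast path does not apply. A std-only sketch of that shape, under loud assumptions: `Tensor`, `try_fast_unary`, `square`, and `square_tensor` are hypothetical names, and the kernel takes plain slices here only to keep the example free of external crates.

```rust
/// Hypothetical tensor stand-in: a flat buffer plus a flag saying whether
/// the data is in standard (contiguous, row-major) layout.
#[derive(Debug, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
    pub standard_layout: bool,
}

/// Try the fast kernel. On success return the new tensor; otherwise hand
/// the untouched input back in `Err` so the caller can run a fallback.
pub fn try_fast_unary(
    t: Tensor,
    kernel: fn(&[f32], &mut [f32]),
) -> Result<Tensor, Tensor> {
    if t.standard_layout {
        let mut out = vec![0.0f32; t.data.len()];
        kernel(&t.data, &mut out);
        return Ok(Tensor { data: out, standard_layout: true });
    }
    Err(t)
}

/// Example kernel: element-wise square.
pub fn square(input: &[f32], output: &mut [f32]) {
    for (o, i) in output.iter_mut().zip(input) {
        *o = i * i;
    }
}

/// Caller side: use the fast path when it applies, else a generic fallback
/// that reuses the tensor returned in `Err`.
pub fn square_tensor(t: Tensor) -> Tensor {
    match try_fast_unary(t, square) {
        Ok(fast) => fast,
        Err(original) => Tensor {
            data: original.data.iter().map(|x| x * x).collect(),
            standard_layout: original.standard_layout,
        },
    }
}
```

Returning the input in `Err` avoids cloning the tensor before attempting acceleration, since ownership moves into the function and comes back only on the slow path.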