diff --git a/.github/release.yml b/.github/release.yml new file mode 100644 index 00000000..570c4fc0 --- /dev/null +++ b/.github/release.yml @@ -0,0 +1,26 @@ +changelog: + categories: + - title: "Breaking changes" + labels: + - "breaking change" + - title: "New features" + labels: + - "new feature" + - title: "Bug fixes" + labels: + - bug + - title: Changes + labels: + - change + - title: Deprecated + labels: + - deprecate + - title: Removed + labels: + - remove + - title: "Update Dependencies" + labels: + - dependencies + - title: Others + labels: + - "*" diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index e63fc81f..4123d4a6 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -1,21 +1,53 @@ +# Based on starter workflow +# https://github.com/actions/starter-workflows/blob/8217436fdee2338da2d6fd02b7c9fcff634c40e7/pages/static.yml +# +# Simple workflow for deploying static content to GitHub Pages name: "GitHub Pages" on: + # Runs on pushes targeting the default branch push: branches: - master + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages +permissions: + contents: read + pages: write + id-token: write + +# Allow one concurrent deployment +concurrency: + group: "pages" + cancel-in-progress: true + jobs: - pages: - runs-on: ubuntu-18.04 + # Single deploy job since we're just deploying + deploy: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: Generate code coverage - run: | - RUSTDOCFLAGS="--html-in-header katex-header.html" cargo doc --no-deps - mv target/doc public - - name: Deploy GitHub Pages - uses: peaceiris/actions-gh-pages@v3 + - name: Checkout + uses: actions/checkout@v4 + + # Generate cargo-doc + - name: Generate documentation + run: cargo doc --no-deps + + - name: Setup Pages 
+ uses: actions/configure-pages@v5 + + # Upload target/doc directory + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: ./public + path: 'target/doc' + + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/intel-mkl.yml b/.github/workflows/intel-mkl.yml index e7e06f2f..14a0b443 100644 --- a/.github/workflows/intel-mkl.yml +++ b/.github/workflows/intel-mkl.yml @@ -5,43 +5,16 @@ on: branches: - master pull_request: {} + workflow_dispatch: jobs: - windows: - runs-on: windows-2019 + intel-mkl: + strategy: + fail-fast: false + matrix: + system: [ubuntu-22.04, windows-latest] + runs-on: ${{ matrix.system }} steps: - - uses: actions/checkout@v1 - - uses: actions-rs/cargo@v1 - with: - command: test - args: > - --manifest-path=ndarray-linalg/Cargo.toml - --no-default-features - --features=intel-mkl-static - - linux: - runs-on: ubuntu-18.04 - steps: - - uses: actions/checkout@v1 - - uses: actions-rs/cargo@v1 - name: cargo test - with: - command: test - args: > - --manifest-path=ndarray-linalg/Cargo.toml - --no-default-features - --features=intel-mkl-static - - linux-container: - runs-on: ubuntu-18.04 - container: rustmath/mkl-rust:1.43.0 - steps: - - uses: actions/checkout@v1 - - uses: actions-rs/cargo@v1 - name: cargo test - with: - command: test - args: > - --manifest-path=ndarray-linalg/Cargo.toml - --no-default-features - --features=intel-mkl-system + - uses: actions/checkout@v4 + - name: cargo test + run: cargo test --manifest-path=ndarray-linalg/Cargo.toml --no-default-features --features=intel-mkl-static --verbose diff --git a/.github/workflows/netlib.yml b/.github/workflows/netlib.yml index f278f0ca..70048503 100644 --- a/.github/workflows/netlib.yml +++ b/.github/workflows/netlib.yml @@ -5,6 +5,7 @@ on: branches: - master pull_request: {} + workflow_dispatch: jobs: linux: @@ -13,17 +14,12 @@ jobs: matrix: feature: - static - 
runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v4 - name: apt install gfortran run: | sudo apt update sudo apt install -y gfortran - - uses: actions-rs/cargo@v1 - with: - command: test - args: > - --manifest-path=ndarray-linalg/Cargo.toml - --no-default-features - --features=netlib-${{ matrix.feature }} + - name: cargo test + run: cargo test --manifest-path=ndarray-linalg/Cargo.toml --no-default-features --features=netlib-${{ matrix.feature }} diff --git a/.github/workflows/openblas.yml b/.github/workflows/openblas.yml index 644bf313..5375be70 100644 --- a/.github/workflows/openblas.yml +++ b/.github/workflows/openblas.yml @@ -5,12 +5,11 @@ on: branches: - master pull_request: {} + workflow_dispatch: jobs: linux: - runs-on: ubuntu-18.04 - container: - image: rust + runs-on: ubuntu-22.04 strategy: fail-fast: false matrix: @@ -18,20 +17,15 @@ jobs: - static - system steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v4 - name: apt install gfortran run: | - apt update - apt install -y gfortran + sudo apt update + sudo apt install -y gfortran - name: Install OpenBLAS by apt run: | - apt update - apt install -y libopenblas-dev + sudo apt update + sudo apt install -y libopenblas-dev if: ${{ contains(matrix.feature, 'system') }} - - uses: actions-rs/cargo@v1 - with: - command: test - args: > - --manifest-path=ndarray-linalg/Cargo.toml - --no-default-features - --features=openblas-${{ matrix.feature }} + - name: cargo test + run: cargo test --manifest-path=ndarray-linalg/Cargo.toml --no-default-features --features=openblas-${{ matrix.feature }} diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index b1c2bc1f..66ca1553 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -5,34 +5,50 @@ on: branches: - master pull_request: {} + workflow_dispatch: jobs: check-format: - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v1 - - uses: 
actions-rs/cargo@v1 - with: - command: fmt - args: -- --check + - uses: actions/checkout@v4 + - name: fmt + run: cargo fmt -- --check + + check: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v4 + - name: cargo check + run: cargo check --all-targets + + check-doc: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v4 + - name: cargo doc + run: cargo doc --no-deps clippy: - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v1 - - uses: actions-rs/cargo@v1 - with: - command: clippy + - uses: actions/checkout@v4 + - name: cargo clippy + run: cargo clippy coverage: - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 container: - image: rustmath/mkl-rust:1.43.0 + image: xd009642/tarpaulin:develop-nightly options: --security-opt seccomp=unconfined steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@nightly + - name: Install Cross + uses: taiki-e/install-action@v2 + with: + tool: cargo-tarpaulin - name: Generate code coverage - run: | - cargo tarpaulin --verbose --features=intel-mkl --out Xml --manifest-path=ndarray-linalg/Cargo.toml + run: cargo +nightly tarpaulin --features=intel-mkl-static --out xml - name: Upload to codecov.io - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v5 diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index d15ccf81..00000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,140 +0,0 @@ -Unreleased ------------ - -0.13.0 - 20 Feb 2021 -===================== - -https://github.com/rust-ndarray/ndarray-linalg/milestone/5 - -Updated dependencies ---------------------- -- ndarray 0.14 https://github.com/rust-ndarray/ndarray-linalg/pull/258 -- cauchy 0.3.0 (num-complex 0.3.1, rand 0.7.3), lapack 0.17.0 https://github.com/rust-ndarray/ndarray-linalg/pull/260 - -### optional dependencies - -- openblas-src 0.10.2 https://github.com/rust-ndarray/ndarray-linalg/pull/253 -- intel-mkl-src 0.6.0 
https://github.com/rust-ndarray/ndarray-linalg/pull/204 - -Added ------- -- Split out `ndarray_linalg::lapack` as "lax" crate https://github.com/rust-ndarray/ndarray-linalg/pull/207 - - cargo-workspace https://github.com/rust-ndarray/ndarray-linalg/pull/209 - -Changed --------- -- Dual license, MIT or Apache-2.0 License https://github.com/rust-ndarray/ndarray-linalg/pull/262 -- Revise tests for least-square problem https://github.com/rust-ndarray/ndarray-linalg/pull/227 -- Support static link to LAPACK backend https://github.com/rust-ndarray/ndarray-linalg/pull/204 -- Drop LAPACKE dependence, and rewrite them in Rust (see below) https://github.com/rust-ndarray/ndarray-linalg/pull/206 -- Named record like `C { row: i32, lda: i32 }` instead of enum for `MatrixLayout` https://github.com/rust-ndarray/ndarray-linalg/pull/211 -- Split LAPACK error into computational failure and invalid values https://github.com/rust-ndarray/ndarray-linalg/pull/210 -- Use thiserror crate https://github.com/rust-ndarray/ndarray-linalg/pull/208 - -### LAPACKE rewrite - -- Cholesky https://github.com/rust-ndarray/ndarray-linalg/pull/225 -- Eigenvalue for general matrix https://github.com/rust-ndarray/ndarray-linalg/pull/212 -- Eigenvalue for symmetric/Hermitian matrix https://github.com/rust-ndarray/ndarray-linalg/pull/217 -- least squares problem https://github.com/rust-ndarray/ndarray-linalg/pull/220 -- QR decomposition https://github.com/rust-ndarray/ndarray-linalg/pull/224 -- LU decomposition https://github.com/rust-ndarray/ndarray-linalg/pull/213 -- LDL decomposition https://github.com/rust-ndarray/ndarray-linalg/pull/216 -- SVD https://github.com/rust-ndarray/ndarray-linalg/pull/218 -- SVD divid-and-conquer https://github.com/rust-ndarray/ndarray-linalg/pull/219 -- Tridiagonal https://github.com/rust-ndarray/ndarray-linalg/pull/235 - -Maintenance ------------ -- Coverage report using codecov https://github.com/rust-ndarray/ndarray-linalg/pull/215 -- Fix for clippy, and add CI check 
https://github.com/rust-ndarray/ndarray-linalg/pull/205 - -0.12.1 - 28 June 2020 -====================== - -Added ------- -- Tridiagonal matrix support https://github.com/rust-ndarray/ndarray-linalg/pull/196 -- KaTeX support in rustdoc https://github.com/rust-ndarray/ndarray-linalg/pull/202 -- Least square problems https://github.com/rust-ndarray/ndarray-linalg/pull/197 -- LOBPCG solver https://github.com/rust-ndarray/ndarray-linalg/pull/184 - -Changed -------- -- Grouping and Plot in benchmark https://github.com/rust-ndarray/ndarray-linalg/pull/200 -- `Clone` trait for `LUFactorized` https://github.com/rust-ndarray/ndarray-linalg/pull/192 - -Maintenance ------------ -- Fix repository URL https://github.com/rust-ndarray/ndarray-linalg/pull/198 -- Use GitHub Actions instead of Azure Pipeline https://github.com/rust-ndarray/ndarray-linalg/pull/193 -- Test cargo-fmt on CI https://github.com/rust-ndarray/ndarray-linalg/pull/194 - -0.12.0 - 14 Oct 2019 -==================== - -Added ------ -- SVD by divide-and-conquer https://github.com/rust-ndarray/ndarray-linalg/pull/164 -- Householder reflection https://github.com/rust-ndarray/ndarray-linalg/pull/154 -- Arnoldi iteration https://github.com/rust-ndarray/ndarray-linalg/pull/155 - -Changed ----------- -- Replace `operator::Operator*` traits by new `LinearOperator trait` https://github.com/rust-ndarray/ndarray-linalg/pull/159 -- ndarray 0.13.0 https://github.com/rust-ndarray/ndarray-linalg/pull/172 -- blas-src 0.4.0, lapack-src 0.4.0, openblas-src 0.7.0 https://github.com/rust-ndarray/ndarray-linalg/pull/174 -- restore `static` feature flag - -0.11.1 - 12 June 2019 -====================== - -- Hotfix for document generation https://github.com/rust-ndarray/ndarray-linalg/pull/153 - -0.11.0 - 12 June 2019 -==================== - -Added --------- -- Dependency to cauchy 0.2 https://github.com/rust-ndarray/ndarray-linalg/pull/139 -- `generate::random_{unitary,regular}` for debug use 
https://github.com/rust-ndarray/ndarray-linalg/pull/140 -- `krylov` submodule - - modified Gram-Schmit https://github.com/rust-ndarray/ndarray-linalg/pull/149 https://github.com/rust-ndarray/ndarray-linalg/pull/150 - - Krylov subspace methods are not implemented yet. - -Removed ----------- -- `static` feature https://github.com/rust-ndarray/ndarray-linalg/pull/136 - - See README for detail -- `accelerate` feature https://github.com/rust-ndarray/ndarray-linalg/pull/141 -- Dependencies to derive-new, procedurals - -Changed ---------- -- Switch CI service: Circle CI -> Azure Pipeline https://github.com/rust-ndarray/ndarray-linalg/pull/141 -- submodule `lapack_traits` is renamed to https://github.com/rust-ndarray/ndarray-linalg/pull/139 -- `ndarray_linalg::Scalar` trait is split into two parts https://github.com/rust-ndarray/ndarray-linalg/pull/139 - - [cauchy::Scalar](https://docs.rs/cauchy/0.2.0/cauchy/trait.Scalar.html) is a refined real/complex common trait - - `lapack::Lapack` is a trait for primitive types which LAPACK supports -- Error type becomes simple https://github.com/rust-ndarray/ndarray-linalg/pull/118 https://github.com/rust-ndarray/ndarray-linalg/pull/127 -- Assertions becomes more verbose https://github.com/rust-ndarray/ndarray-linalg/pull/147 -- blas-src 0.3, lapack-src 0.3 - - intel-mkl-src becomes 0.4, which supports Windows! 
https://github.com/rust-ndarray/ndarray-linalg/pull/146 - -0.10.0 - 2 Sep 2018 -=================== - -Update Dependencies --------------------- - -- ndarray 0.12 -- rand 0.5 -- num-complex 0.2 -- openblas-src 0.6 -- lapacke 0.2 - -See also https://github.com/rust-ndarray/ndarray-linalg/pull/110 - -Added ------- -- serde-1 feature gate https://github.com/rust-ndarray/ndarray-linalg/pull/99, https://github.com/rust-ndarray/ndarray-linalg/pull/116 diff --git a/README.md b/README.md index 4478dca3..85b1428c 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,8 @@ ndarray-linalg =============== -[![Crate](http://meritbadge.herokuapp.com/ndarray-linalg)](https://crates.io/crates/ndarray-linalg) +[![crate](https://img.shields.io/crates/v/ndarray-linalg.svg)](https://crates.io/crates/ndarray-linalg) [![docs.rs](https://docs.rs/ndarray-linalg/badge.svg)](https://docs.rs/ndarray-linalg) +[![master](https://img.shields.io/badge/docs-master-blue)](https://rust-ndarray.github.io/ndarray-linalg/ndarray_linalg/index.html) Linear algebra package for Rust with [ndarray](https://github.com/rust-ndarray/ndarray) based on external LAPACK implementations. @@ -61,7 +62,7 @@ Supported features are following: ### For library developer -If you creating a library depending on this crate, we encourage you not to link any backend: +If you are creating a library depending on this crate, we encourage you not to link any backend: ```toml [dependencies] @@ -85,25 +86,6 @@ Only x86_64 system is supported currently. |Netlib |✔️ |- |- | |Intel MKL|✔️ |✔️ |✔️ | -Generate document with KaTeX ------------------------------- - -You need to set `RUSTDOCFLAGS` explicitly: - -```shell -RUSTDOCFLAGS="--html-in-header katex-header.html" cargo doc --no-deps -``` - -This **only** works for `--no-deps` build because `katex-header.html` does not exists for dependent crates. 
-If you wish to set `RUSTDOCFLAGS` automatically in this crate, you can put [.cargo/config](https://doc.rust-lang.org/cargo/reference/config.html): - -```toml -[build] -rustdocflags = ["--html-in-header", "katex-header.html"] -``` - -But, be sure that this works only for `--no-deps`. `cargo doc` will fail with this `.cargo/config`. - License -------- diff --git a/katex-header.html b/katex-header.html deleted file mode 100644 index 6e10c052..00000000 --- a/katex-header.html +++ /dev/null @@ -1,16 +0,0 @@ - - - - - diff --git a/lax/CHANGELOG.md b/lax/CHANGELOG.md new file mode 100644 index 00000000..68e828b4 --- /dev/null +++ b/lax/CHANGELOG.md @@ -0,0 +1,15 @@ +Unreleased +----------- + +0.2.0 - 17 July 2021 +===================== + +Updated dependencies +--------------------- +- cauchy 0.4 (num-complex 0.4, rand 0.8), lapack 0.18 https://github.com/rust-ndarray/ndarray-linalg/pull/276 + +Fixed +----- +- Fix memory layout of the output of inverse of LUFactorized https://github.com/rust-ndarray/ndarray-linalg/pull/297 +- Fix Eig for column-major arrays with real elements https://github.com/rust-ndarray/ndarray-linalg/pull/298 +- Fix Solve::solve_h_* for complex inputs with standard layout https://github.com/rust-ndarray/ndarray-linalg/pull/296 diff --git a/lax/Cargo.toml b/lax/Cargo.toml index 812ad126..076d5cf2 100644 --- a/lax/Cargo.toml +++ b/lax/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lax" -version = "0.1.0" +version = "0.17.0" authors = ["Toshiki Teramura "] edition = "2018" @@ -25,31 +25,29 @@ netlib-system = ["netlib-src/system"] openblas-static = ["openblas-src/static"] openblas-system = ["openblas-src/system"] -intel-mkl-static = ["intel-mkl-src/mkl-static-lp64-seq", "intel-mkl-src/download"] +intel-mkl-static = ["intel-mkl-src/mkl-static-lp64-seq"] intel-mkl-system = ["intel-mkl-src/mkl-dynamic-lp64-seq"] [dependencies] -thiserror = "1.0.23" -cauchy = "0.3.0" +thiserror = "2.0.0" +cauchy = "0.4.0" num-traits = "0.2.14" -lapack = "0.17.0" +lapack-sys = 
"0.15.0" +katexit = "0.1.2" [dependencies.intel-mkl-src] -version = "0.6.0" +version = "0.8.1" default-features = false optional = true [dependencies.netlib-src] -version = "0.8.0" +version = "0.9.0" optional = true features = ["cblas"] default-features = false [dependencies.openblas-src] -version = "0.10.2" +version = "0.10.4" optional = true default-features = false features = ["cblas"] - -[package.metadata.release] -no-dev-version = true diff --git a/lax/README.md b/lax/README.md index ed563735..9dea6b32 100644 --- a/lax/README.md +++ b/lax/README.md @@ -1,6 +1,9 @@ Linear Algebra eXtension (LAX) =============================== +[![crates.io](https://img.shields.io/badge/crates.io-lax-blue)](https://crates.io/crates/lax) +[![docs.rs](https://docs.rs/lax/badge.svg)](https://docs.rs/lax) + ndarray-free safe Rust wrapper for LAPACK FFI for implementing ndarray-linalg crate. This crate responsibles for diff --git a/lax/src/alloc.rs b/lax/src/alloc.rs new file mode 100644 index 00000000..63458818 --- /dev/null +++ b/lax/src/alloc.rs @@ -0,0 +1,78 @@ +use cauchy::*; +use std::mem::MaybeUninit; + +/// Helper for getting pointer of slice +pub(crate) trait AsPtr: Sized { + type Elem; + fn as_ptr(vec: &[Self]) -> *const Self::Elem; + fn as_mut_ptr(vec: &mut [Self]) -> *mut Self::Elem; +} + +macro_rules! 
impl_as_ptr { + ($target:ty, $elem:ty) => { + impl AsPtr for $target { + type Elem = $elem; + fn as_ptr(vec: &[Self]) -> *const Self::Elem { + vec.as_ptr() as *const _ + } + fn as_mut_ptr(vec: &mut [Self]) -> *mut Self::Elem { + vec.as_mut_ptr() as *mut _ + } + } + }; +} +impl_as_ptr!(i32, i32); +impl_as_ptr!(f32, f32); +impl_as_ptr!(f64, f64); +impl_as_ptr!(c32, lapack_sys::__BindgenComplex); +impl_as_ptr!(c64, lapack_sys::__BindgenComplex); +impl_as_ptr!(MaybeUninit, i32); +impl_as_ptr!(MaybeUninit, f32); +impl_as_ptr!(MaybeUninit, f64); +impl_as_ptr!(MaybeUninit, lapack_sys::__BindgenComplex); +impl_as_ptr!(MaybeUninit, lapack_sys::__BindgenComplex); + +pub(crate) trait VecAssumeInit { + type Elem; + unsafe fn assume_init(self) -> Vec; + + /// A replacement of unstable API + /// https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.slice_assume_init_ref + unsafe fn slice_assume_init_ref(&self) -> &[Self::Elem]; + + /// A replacement of unstable API + /// https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.slice_assume_init_mut + unsafe fn slice_assume_init_mut(&mut self) -> &mut [Self::Elem]; +} + +impl VecAssumeInit for Vec> { + type Elem = T; + unsafe fn assume_init(self) -> Vec { + // FIXME use Vec::into_raw_parts instead after stabilized + // https://doc.rust-lang.org/std/vec/struct.Vec.html#method.into_raw_parts + let mut me = std::mem::ManuallyDrop::new(self); + Vec::from_raw_parts(me.as_mut_ptr() as *mut T, me.len(), me.capacity()) + } + + unsafe fn slice_assume_init_ref(&self) -> &[T] { + std::slice::from_raw_parts(self.as_ptr() as *const T, self.len()) + } + + unsafe fn slice_assume_init_mut(&mut self) -> &mut [T] { + std::slice::from_raw_parts_mut(self.as_mut_ptr() as *mut T, self.len()) + } +} + +/// Create a vector without initialization +/// +/// Safety +/// ------ +/// - Memory is not initialized. Do not read the memory before write. 
+/// +pub(crate) fn vec_uninit(n: usize) -> Vec> { + let mut v = Vec::with_capacity(n); + unsafe { + v.set_len(n); + } + v +} diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index 8305efe5..785f6e5e 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -1,27 +1,25 @@ -//! Cholesky decomposition +//! Factorize positive-definite symmetric/Hermitian matrices using Cholesky algorithm use super::*; use crate::{error::*, layout::*}; use cauchy::*; -pub trait Cholesky_: Sized { - /// Cholesky: wrapper of `*potrf` - /// - /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** +/// Compute Cholesky decomposition according to [UPLO] +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | spotrf | dpotrf | cpotrf | zpotrf | +/// +pub trait CholeskyImpl: Scalar { fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; - - /// Wrapper of `*potri` - /// - /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** - fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; - - /// Wrapper of `*potrs` - fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; } -macro_rules! impl_cholesky { - ($scalar:ty, $trf:path, $tri:path, $trs:path) => { - impl Cholesky_ for $scalar { +macro_rules! impl_cholesky_ { + ($s:ty, $trf:path) => { + impl CholeskyImpl for $s { fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); if matches!(l, MatrixLayout::C { .. }) { @@ -29,7 +27,7 @@ macro_rules! impl_cholesky { } let mut info = 0; unsafe { - $trf(uplo as u8, n, a, n, &mut info); + $trf(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &n, &mut info); } info.as_lapack_result()?; if matches!(l, MatrixLayout::C { .. }) { @@ -37,7 +35,30 @@ macro_rules! 
impl_cholesky { } Ok(()) } + } + }; +} +impl_cholesky_!(c64, lapack_sys::zpotrf_); +impl_cholesky_!(c32, lapack_sys::cpotrf_); +impl_cholesky_!(f64, lapack_sys::dpotrf_); +impl_cholesky_!(f32, lapack_sys::spotrf_); + +/// Compute inverse matrix using Cholesky factorization result +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | spotri | dpotri | cpotri | zpotri | +/// +pub trait InvCholeskyImpl: Scalar { + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; +} +macro_rules! impl_inv_cholesky { + ($s:ty, $tri:path) => { + impl InvCholeskyImpl for $s { fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); if matches!(l, MatrixLayout::C { .. }) { @@ -45,7 +66,7 @@ macro_rules! impl_cholesky { } let mut info = 0; unsafe { - $tri(uplo as u8, n, a, l.lda(), &mut info); + $tri(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &l.lda(), &mut info); } info.as_lapack_result()?; if matches!(l, MatrixLayout::C { .. }) { @@ -53,7 +74,30 @@ macro_rules! impl_cholesky { } Ok(()) } + } + }; +} +impl_inv_cholesky!(c64, lapack_sys::zpotri_); +impl_inv_cholesky!(c32, lapack_sys::cpotri_); +impl_inv_cholesky!(f64, lapack_sys::dpotri_); +impl_inv_cholesky!(f32, lapack_sys::spotri_); +/// Solve linear equation using Cholesky factorization result +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | spotrs | dpotrs | cpotrs | zpotrs | +/// +pub trait SolveCholeskyImpl: Scalar { + fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; +} + +macro_rules! impl_solve_cholesky { + ($s:ty, $trs:path) => { + impl SolveCholeskyImpl for $s { fn solve_cholesky( l: MatrixLayout, mut uplo: UPLO, @@ -70,7 +114,16 @@ macro_rules! 
impl_cholesky { } } unsafe { - $trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info); + $trs( + uplo.as_ptr(), + &n, + &nrhs, + AsPtr::as_ptr(a), + &l.lda(), + AsPtr::as_mut_ptr(b), + &n, + &mut info, + ); } info.as_lapack_result()?; if matches!(l, MatrixLayout::C { .. }) { @@ -82,9 +135,8 @@ macro_rules! impl_cholesky { } } }; -} // end macro_rules - -impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs); -impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs); -impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs); -impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs); +} +impl_solve_cholesky!(c64, lapack_sys::zpotrs_); +impl_solve_cholesky!(c32, lapack_sys::cpotrs_); +impl_solve_cholesky!(f64, lapack_sys::dpotrs_); +impl_solve_cholesky!(f32, lapack_sys::spotrs_); diff --git a/lax/src/eig.rs b/lax/src/eig.rs index 53245de7..f02035bb 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -1,165 +1,297 @@ -//! Eigenvalue decomposition for general matrices +//! Eigenvalue problem for general matricies +//! +//! LAPACK correspondance +//! ---------------------- +//! +//! | f32 | f64 | c32 | c64 | +//! |:------|:------|:------|:------| +//! | sgeev | dgeev | cgeev | zgeev | +//! use crate::{error::*, layout::MatrixLayout, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -/// Wraps `*geev` for general matrices -pub trait Eig_: Scalar { - /// Calculate Right eigenvalue - fn eig( - calc_v: bool, - l: MatrixLayout, - a: &mut [Self], - ) -> Result<(Vec, Vec)>; +#[cfg_attr(doc, katexit::katexit)] +/// Eigenvalue problem for general matrix +/// +/// To manage memory more strictly, use [EigWork]. 
+/// +/// Right and Left eigenvalue problem +/// ---------------------------------- +/// LAPACK can solve both right eigenvalue problem +/// $$ +/// AV_R = V_R \Lambda +/// $$ +/// where $V_R = \left( v_R^1, \cdots, v_R^n \right)$ are right eigenvectors +/// and left eigenvalue problem +/// $$ +/// V_L^\dagger A = V_L^\dagger \Lambda +/// $$ +/// where $V_L = \left( v_L^1, \cdots, v_L^n \right)$ are left eigenvectors +/// and eigenvalues +/// $$ +/// \Lambda = \begin{pmatrix} +/// \lambda_1 & & 0 \\\\ +/// & \ddots & \\\\ +/// 0 & & \lambda_n +/// \end{pmatrix} +/// $$ +/// which satisfies $A v_R^i = \lambda_i v_R^i$ and +/// $\left(v_L^i\right)^\dagger A = \lambda_i \left(v_L^i\right)^\dagger$ +/// for column-major matrices, although row-major matrices are not supported. +/// Since a row-major matrix can be interpreted +/// as a transpose of a column-major matrix, +/// this transforms right eigenvalue problem to left one: +/// +/// $$ +/// A^\dagger V = V Λ ⟺ V^\dagger A = Λ V^\dagger +/// $$ +/// +#[non_exhaustive] +pub struct EigWork { + /// Problem size + pub n: i32, + /// Compute right eigenvectors or not + pub jobvr: JobEv, + /// Compute left eigenvectors or not + pub jobvl: JobEv, + + /// Eigenvalues + pub eigs: Vec>, + /// Real part of eigenvalues used in real routines + pub eigs_re: Option>>, + /// Imaginary part of eigenvalues used in real routines + pub eigs_im: Option>>, + + /// Left eigenvectors + pub vc_l: Option>>, + /// Left eigenvectors used in real routines + pub vr_l: Option>>, + /// Right eigenvectors + pub vc_r: Option>>, + /// Right eigenvectors used in real routines + pub vr_r: Option>>, + + /// Working memory + pub work: Vec>, + /// Working memory with `T::Real` + pub rwork: Option>>, +} + +impl EigWork +where + T: Scalar, + EigWork: EigWorkImpl, +{ + /// Create new working memory for eigenvalues compution. 
+ pub fn new(calc_v: bool, l: MatrixLayout) -> Result { + EigWorkImpl::new(calc_v, l) + } + + /// Compute eigenvalues and vectors on this working memory. + pub fn calc(&mut self, a: &mut [T]) -> Result> { + EigWorkImpl::calc(self, a) + } + + /// Compute eigenvalues and vectors by consuming this working memory. + pub fn eval(self, a: &mut [T]) -> Result> { + EigWorkImpl::eval(self, a) + } +} + +/// Owned result of eigenvalue problem by [EigWork::eval] +#[derive(Debug, Clone, PartialEq)] +pub struct EigOwned { + /// Eigenvalues + pub eigs: Vec, + /// Right eigenvectors + pub vr: Option>, + /// Left eigenvectors + pub vl: Option>, } -macro_rules! impl_eig_complex { - ($scalar:ty, $ev:path) => { - impl Eig_ for $scalar { - fn eig( - calc_v: bool, - l: MatrixLayout, - mut a: &mut [Self], - ) -> Result<(Vec, Vec)> { +/// Reference result of eigenvalue problem by [EigWork::calc] +#[derive(Debug, Clone, PartialEq)] +pub struct EigRef<'work, T: Scalar> { + /// Eigenvalues + pub eigs: &'work [T::Complex], + /// Right eigenvectors + pub vr: Option<&'work [T::Complex]>, + /// Left eigenvectors + pub vl: Option<&'work [T::Complex]>, +} + +/// Helper trait for implementing [EigWork] methods +pub trait EigWorkImpl: Sized { + type Elem: Scalar; + fn new(calc_v: bool, l: MatrixLayout) -> Result; + fn calc<'work>(&'work mut self, a: &mut [Self::Elem]) -> Result>; + fn eval(self, a: &mut [Self::Elem]) -> Result>; +} + +macro_rules! impl_eig_work_c { + ($c:ty, $ev:path) => { + impl EigWorkImpl for EigWork<$c> { + type Elem = $c; + + fn new(calc_v: bool, l: MatrixLayout) -> Result { let (n, _) = l.size(); - // Because LAPACK assumes F-continious array, C-continious array should be taken Hermitian conjugate. - // However, we utilize a fact that left eigenvector of A^H corresponds to the right eigenvector of A let (jobvl, jobvr) = if calc_v { match l { - MatrixLayout::C { .. } => (b'V', b'N'), - MatrixLayout::F { .. } => (b'N', b'V'), + MatrixLayout::C { .. 
} => (JobEv::All, JobEv::None), + MatrixLayout::F { .. } => (JobEv::None, JobEv::All), } } else { - (b'N', b'N') + (JobEv::None, JobEv::None) }; - let mut eigs = unsafe { vec_uninit(n as usize) }; - let mut rwork = unsafe { vec_uninit(2 * n as usize) }; + let mut eigs = vec_uninit(n as usize); + let mut rwork = vec_uninit(2 * n as usize); - let mut vl = if jobvl == b'V' { - Some(unsafe { vec_uninit((n * n) as usize) }) - } else { - None - }; - let mut vr = if jobvr == b'V' { - Some(unsafe { vec_uninit((n * n) as usize) }) - } else { - None - }; + let mut vc_l = jobvl.then(|| vec_uninit((n * n) as usize)); + let mut vc_r = jobvr.then(|| vec_uninit((n * n) as usize)); // calc work size let mut info = 0; - let mut work_size = [Self::zero()]; + let mut work_size = [<$c>::zero()]; unsafe { $ev( - jobvl, - jobvr, - n, - &mut a, - n, - &mut eigs, - &mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work_size, - -1, - &mut rwork, + jobvl.as_ptr(), + jobvr.as_ptr(), + &n, + std::ptr::null_mut(), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(vc_l.as_deref_mut().unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(vc_r.as_deref_mut().unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + AsPtr::as_mut_ptr(&mut rwork), &mut info, ) }; info.as_lapack_result()?; - // actal ev let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let work: Vec> = vec_uninit(lwork); + Ok(Self { + n, + jobvl, + jobvr, + eigs, + eigs_re: None, + eigs_im: None, + rwork: Some(rwork), + vc_l, + vc_r, + vr_l: None, + vr_r: None, + work, + }) + } + + fn calc<'work>( + &'work mut self, + a: &mut [Self::Elem], + ) -> Result> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; unsafe { $ev( - jobvl, - jobvr, - n, - &mut a, - n, - &mut eigs, - &mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut 
vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work, - lwork as i32, - &mut rwork, + self.jobvl.as_ptr(), + self.jobvr.as_ptr(), + &self.n, + AsPtr::as_mut_ptr(a), + &self.n, + AsPtr::as_mut_ptr(&mut self.eigs), + AsPtr::as_mut_ptr(self.vc_l.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(self.vc_r.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()), &mut info, ) }; info.as_lapack_result()?; - // Hermite conjugate - if jobvl == b'V' { - for c in vl.as_mut().unwrap().iter_mut() { - c.im = -c.im + if let Some(vl) = self.vc_l.as_mut() { + for value in vl { + let value = unsafe { value.assume_init_mut() }; + value.im = -value.im; } } + Ok(EigRef { + eigs: unsafe { self.eigs.slice_assume_init_ref() }, + vl: self + .vc_l + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }), + vr: self + .vc_r + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }), + }) + } - Ok((eigs, vr.or(vl).unwrap_or(Vec::new()))) + fn eval(mut self, a: &mut [Self::Elem]) -> Result> { + let _eig_ref = self.calc(a)?; + Ok(EigOwned { + eigs: unsafe { self.eigs.assume_init() }, + vl: self.vc_l.map(|v| unsafe { v.assume_init() }), + vr: self.vc_r.map(|v| unsafe { v.assume_init() }), + }) } } }; } -impl_eig_complex!(c64, lapack::zgeev); -impl_eig_complex!(c32, lapack::cgeev); - -macro_rules! impl_eig_real { - ($scalar:ty, $ev:path) => { - impl Eig_ for $scalar { - fn eig( - calc_v: bool, - l: MatrixLayout, - mut a: &mut [Self], - ) -> Result<(Vec, Vec)> { +impl_eig_work_c!(c32, lapack_sys::cgeev_); +impl_eig_work_c!(c64, lapack_sys::zgeev_); + +macro_rules! impl_eig_work_r { + ($f:ty, $ev:path) => { + impl EigWorkImpl for EigWork<$f> { + type Elem = $f; + + fn new(calc_v: bool, l: MatrixLayout) -> Result { let (n, _) = l.size(); - // Because LAPACK assumes F-continious array, C-continious array should be taken Hermitian conjugate. 
- // However, we utilize a fact that left eigenvector of A^H corresponds to the right eigenvector of A let (jobvl, jobvr) = if calc_v { match l { - MatrixLayout::C { .. } => (b'V', b'N'), - MatrixLayout::F { .. } => (b'N', b'V'), + MatrixLayout::C { .. } => (JobEv::All, JobEv::None), + MatrixLayout::F { .. } => (JobEv::None, JobEv::All), } } else { - (b'N', b'N') - }; - let mut eig_re = unsafe { vec_uninit(n as usize) }; - let mut eig_im = unsafe { vec_uninit(n as usize) }; - - let mut vl = if jobvl == b'V' { - Some(unsafe { vec_uninit((n * n) as usize) }) - } else { - None - }; - let mut vr = if jobvr == b'V' { - Some(unsafe { vec_uninit((n * n) as usize) }) - } else { - None + (JobEv::None, JobEv::None) }; + let mut eigs_re = vec_uninit(n as usize); + let mut eigs_im = vec_uninit(n as usize); + let mut vr_l = jobvl.then(|| vec_uninit((n * n) as usize)); + let mut vr_r = jobvr.then(|| vec_uninit((n * n) as usize)); + let vc_l = jobvl.then(|| vec_uninit((n * n) as usize)); + let vc_r = jobvr.then(|| vec_uninit((n * n) as usize)); // calc work size let mut info = 0; - let mut work_size = [0.0]; + let mut work_size: [$f; 1] = [0.0]; unsafe { $ev( - jobvl, - jobvr, - n, - &mut a, - n, - &mut eig_re, - &mut eig_im, - vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work_size, - -1, + jobvl.as_ptr(), + jobvr.as_ptr(), + &n, + std::ptr::null_mut(), + &n, + AsPtr::as_mut_ptr(&mut eigs_re), + AsPtr::as_mut_ptr(&mut eigs_im), + AsPtr::as_mut_ptr(vr_l.as_deref_mut().unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(vr_r.as_deref_mut().unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), &mut info, ) }; @@ -167,92 +299,153 @@ macro_rules! 
impl_eig_real { // actual ev let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let work = vec_uninit(lwork); + + Ok(Self { + n, + jobvr, + jobvl, + eigs: vec_uninit(n as usize), + eigs_re: Some(eigs_re), + eigs_im: Some(eigs_im), + rwork: None, + vr_l, + vr_r, + vc_l, + vc_r, + work, + }) + } + + fn calc<'work>( + &'work mut self, + a: &mut [Self::Elem], + ) -> Result> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; unsafe { $ev( - jobvl, - jobvr, - n, - &mut a, - n, - &mut eig_re, - &mut eig_im, - vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work, - lwork as i32, + self.jobvl.as_ptr(), + self.jobvr.as_ptr(), + &self.n, + AsPtr::as_mut_ptr(a), + &self.n, + AsPtr::as_mut_ptr(self.eigs_re.as_mut().unwrap()), + AsPtr::as_mut_ptr(self.eigs_im.as_mut().unwrap()), + AsPtr::as_mut_ptr(self.vr_l.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(self.vr_r.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(&mut self.work), + &lwork, &mut info, ) }; info.as_lapack_result()?; - // reconstruct eigenvalues - let eigs: Vec = eig_re - .iter() - .zip(eig_im.iter()) - .map(|(&re, &im)| Self::complex(re, im)) - .collect(); + let eigs_re = self + .eigs_re + .as_ref() + .map(|e| unsafe { e.slice_assume_init_ref() }) + .unwrap(); + let eigs_im = self + .eigs_im + .as_ref() + .map(|e| unsafe { e.slice_assume_init_ref() }) + .unwrap(); + reconstruct_eigs(eigs_re, eigs_im, &mut self.eigs); - if !calc_v { - return Ok((eigs, Vec::new())); + if let Some(v) = self.vr_l.as_ref() { + let v = unsafe { v.slice_assume_init_ref() }; + reconstruct_eigenvectors(true, eigs_im, v, self.vc_l.as_mut().unwrap()); } - - // Reconstruct eigenvectors into complex-array - // -------------------------------------------- - // - // From LAPACK API https://software.intel.com/en-us/node/469230 - // - // - If the j-th eigenvalue 
is real, - // - v(j) = VR(:,j), the j-th column of VR. - // - // - If the j-th and (j+1)-st eigenvalues form a complex conjugate pair, - // - v(j) = VR(:,j) + i*VR(:,j+1) - // - v(j+1) = VR(:,j) - i*VR(:,j+1). - // - // ``` - // j -> <----pair----> <----pair----> - // [ ... (real), (imag), (imag), (imag), (imag), ... ] : eigs - // ^ ^ ^ ^ ^ - // false false true false true : is_conjugate_pair - // ``` - let n = n as usize; - let v = vr.or(vl).unwrap(); - let mut eigvecs = unsafe { vec_uninit(n * n) }; - let mut is_conjugate_pair = false; // flag for check `j` is complex conjugate - for j in 0..n { - if eig_im[j] == 0.0 { - // j-th eigenvalue is real - for i in 0..n { - eigvecs[i + j * n] = Self::complex(v[i + j * n], 0.0); - } - } else { - // j-th eigenvalue is complex - // complex conjugated pair can be `j-1` or `j+1` - if is_conjugate_pair { - let j_pair = j - 1; - assert!(j_pair < n); - for i in 0..n { - eigvecs[i + j * n] = Self::complex(v[i + j_pair * n], v[i + j * n]); - } - } else { - let j_pair = j + 1; - assert!(j_pair < n); - for i in 0..n { - eigvecs[i + j * n] = - Self::complex(v[i + j * n], -v[i + j_pair * n]); - } - } - is_conjugate_pair = !is_conjugate_pair; - } + if let Some(v) = self.vr_r.as_ref() { + let v = unsafe { v.slice_assume_init_ref() }; + reconstruct_eigenvectors(false, eigs_im, v, self.vc_r.as_mut().unwrap()); } - Ok((eigs, eigvecs)) + Ok(EigRef { + eigs: unsafe { self.eigs.slice_assume_init_ref() }, + vl: self + .vc_l + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }), + vr: self + .vc_r + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }), + }) + } + + fn eval(mut self, a: &mut [Self::Elem]) -> Result> { + let _eig_ref = self.calc(a)?; + Ok(EigOwned { + eigs: unsafe { self.eigs.assume_init() }, + vl: self.vc_l.map(|v| unsafe { v.assume_init() }), + vr: self.vc_r.map(|v| unsafe { v.assume_init() }), + }) } } }; } +impl_eig_work_r!(f32, lapack_sys::sgeev_); +impl_eig_work_r!(f64, lapack_sys::dgeev_); 
-impl_eig_real!(f64, lapack::dgeev); -impl_eig_real!(f32, lapack::sgeev); +/// Reconstruct eigenvectors into complex-array +/// +/// From LAPACK API https://software.intel.com/en-us/node/469230 +/// +/// - If the j-th eigenvalue is real, +/// - v(j) = VR(:,j), the j-th column of VR. +/// +/// - If the j-th and (j+1)-st eigenvalues form a complex conjugate pair, +/// - v(j) = VR(:,j) + i*VR(:,j+1) +/// - v(j+1) = VR(:,j) - i*VR(:,j+1). +/// +/// In the C-layout case, we need the conjugates of the left +/// eigenvectors, so the signs should be reversed. +pub(crate) fn reconstruct_eigenvectors( + take_hermite_conjugate: bool, + eig_im: &[T], + vr: &[T], + vc: &mut [MaybeUninit], +) { + let n = eig_im.len(); + assert_eq!(vr.len(), n * n); + assert_eq!(vc.len(), n * n); + + let mut col = 0; + while col < n { + if eig_im[col].is_zero() { + // The corresponding eigenvalue is real. + for row in 0..n { + let re = vr[row + col * n]; + vc[row + col * n].write(T::complex(re, T::zero())); + } + col += 1; + } else { + // This is a complex conjugate pair. + assert!(col + 1 < n); + for row in 0..n { + let re = vr[row + col * n]; + let mut im = vr[row + (col + 1) * n]; + if take_hermite_conjugate { + im = -im; + } + vc[row + col * n].write(T::complex(re, im)); + vc[row + (col + 1) * n].write(T::complex(re, -im)); + } + col += 2; + } + } +} + +/// Create complex eigenvalues from real and imaginary parts. +fn reconstruct_eigs(re: &[T], im: &[T], eigs: &mut [MaybeUninit]) { + let n = eigs.len(); + assert_eq!(re.len(), n); + assert_eq!(im.len(), n); + for i in 0..n { + eigs[i].write(T::complex(re[i], im[i])); + } +} diff --git a/lax/src/eig_generalized.rs b/lax/src/eig_generalized.rs new file mode 100644 index 00000000..ea99dbdb --- /dev/null +++ b/lax/src/eig_generalized.rs @@ -0,0 +1,520 @@ +//! Generalized eigenvalue problem for general matrices +//! +//! LAPACK correspondance +//! ---------------------- +//! +//! | f32 | f64 | c32 | c64 | +//! |:------|:------|:------|:------| +//! 
| sggev | dggev | cggev | zggev | +//! +use std::mem::MaybeUninit; + +use crate::eig::reconstruct_eigenvectors; +use crate::{error::*, layout::MatrixLayout, *}; +use cauchy::*; +use num_traits::{ToPrimitive, Zero}; + +#[derive(Clone, PartialEq, Eq)] +pub enum GeneralizedEigenvalue { + /// Finite generalized eigenvalue: `Finite(α/β, (α, β))` + Finite(T, (T, T)), + + /// Indeterminate generalized eigenvalue: `Indeterminate((α, β))` + Indeterminate((T, T)), +} + +impl std::fmt::Display for GeneralizedEigenvalue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Finite(e, (a, b)) => write!(f, "{e:.3e} ({a:.3e}/{b:.3e})"), + Self::Indeterminate((a, b)) => write!(f, "∞ ({a:.3e}/{b:.3e})"), + } + } +} + +#[non_exhaustive] +pub struct EigGeneralizedWork { + /// Problem size + pub n: i32, + /// Compute right eigenvectors or not + pub jobvr: JobEv, + /// Compute left eigenvectors or not + pub jobvl: JobEv, + + /// Eigenvalues: alpha (numerators) + pub alpha: Vec>, + /// Eigenvalues: beta (denominators) + pub beta: Vec>, + /// Real part of alpha (eigenvalue numerators) used in real routines + pub alpha_re: Option>>, + /// Imaginary part of alpha (eigenvalue numerators) used in real routines + pub alpha_im: Option>>, + /// Real part of beta (eigenvalue denominators) used in real routines + pub beta_re: Option>>, + /// Imaginary part of beta (eigenvalue denominators) used in real routines + pub beta_im: Option>>, + + /// Left eigenvectors + pub vc_l: Option>>, + /// Left eigenvectors used in real routines + pub vr_l: Option>>, + /// Right eigenvectors + pub vc_r: Option>>, + /// Right eigenvectors used in real routines + pub vr_r: Option>>, + + /// Working memory + pub work: Vec>, + /// Working memory with `T::Real` + pub rwork: Option>>, +} + +impl EigGeneralizedWork +where + T: Scalar, + EigGeneralizedWork: EigGeneralizedWorkImpl, +{ + /// Create new working memory for eigenvalues compution. 
+ pub fn new(calc_v: bool, l: MatrixLayout) -> Result { + EigGeneralizedWorkImpl::new(calc_v, l) + } + + /// Compute eigenvalues and vectors on this working memory. + pub fn calc(&mut self, a: &mut [T], b: &mut [T]) -> Result> { + EigGeneralizedWorkImpl::calc(self, a, b) + } + + /// Compute eigenvalues and vectors by consuming this working memory. + pub fn eval(self, a: &mut [T], b: &mut [T]) -> Result> { + EigGeneralizedWorkImpl::eval(self, a, b) + } +} + +/// Owned result of eigenvalue problem by [EigGeneralizedWork::eval] +#[derive(Debug, Clone, PartialEq)] +pub struct EigGeneralizedOwned { + /// Eigenvalues + pub alpha: Vec, + + pub beta: Vec, + + /// Right eigenvectors + pub vr: Option>, + + /// Left eigenvectors + pub vl: Option>, +} + +/// Reference result of eigenvalue problem by [EigGeneralizedWork::calc] +#[derive(Debug, Clone, PartialEq)] +pub struct EigGeneralizedRef<'work, T: Scalar> { + /// Eigenvalues + pub alpha: &'work [T::Complex], + + pub beta: &'work [T::Complex], + + /// Right eigenvectors + pub vr: Option<&'work [T::Complex]>, + + /// Left eigenvectors + pub vl: Option<&'work [T::Complex]>, +} + +/// Helper trait for implementing [EigGeneralizedWork] methods +pub trait EigGeneralizedWorkImpl: Sized { + type Elem: Scalar; + fn new(calc_v: bool, l: MatrixLayout) -> Result; + fn calc<'work>( + &'work mut self, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result>; + fn eval( + self, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result>; +} + +macro_rules! impl_eig_generalized_work_c { + ($f:ty, $c:ty, $ev:path) => { + impl EigGeneralizedWorkImpl for EigGeneralizedWork<$c> { + type Elem = $c; + + fn new(calc_v: bool, l: MatrixLayout) -> Result { + let (n, _) = l.size(); + let (jobvl, jobvr) = if calc_v { + match l { + MatrixLayout::C { .. } => (JobEv::All, JobEv::None), + MatrixLayout::F { .. 
} => (JobEv::None, JobEv::All), + } + } else { + (JobEv::None, JobEv::None) + }; + let mut rwork = vec_uninit(8 * n as usize); + + let mut alpha = vec_uninit(n as usize); + let mut beta = vec_uninit(n as usize); + + let mut vc_l = jobvl.then(|| vec_uninit((n * n) as usize)); + let mut vc_r = jobvr.then(|| vec_uninit((n * n) as usize)); + + // calc work size + let mut info = 0; + let mut work_size = [<$c>::zero()]; + unsafe { + $ev( + jobvl.as_ptr(), + jobvr.as_ptr(), + &n, + std::ptr::null_mut(), + &n, + std::ptr::null_mut(), + &n, + AsPtr::as_mut_ptr(&mut alpha), + AsPtr::as_mut_ptr(&mut beta), + AsPtr::as_mut_ptr(vc_l.as_deref_mut().unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(vc_r.as_deref_mut().unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + AsPtr::as_mut_ptr(&mut rwork), + &mut info, + ) + }; + info.as_lapack_result()?; + + let lwork = work_size[0].to_usize().unwrap(); + let work: Vec> = vec_uninit(lwork); + Ok(Self { + n, + jobvl, + jobvr, + alpha, + beta, + alpha_re: None, + alpha_im: None, + beta_re: None, + beta_im: None, + rwork: Some(rwork), + vc_l, + vc_r, + vr_l: None, + vr_r: None, + work, + }) + } + + fn calc<'work>( + &'work mut self, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + unsafe { + $ev( + self.jobvl.as_ptr(), + self.jobvr.as_ptr(), + &self.n, + AsPtr::as_mut_ptr(a), + &self.n, + AsPtr::as_mut_ptr(b), + &self.n, + AsPtr::as_mut_ptr(&mut self.alpha), + AsPtr::as_mut_ptr(&mut self.beta), + AsPtr::as_mut_ptr(self.vc_l.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(self.vc_r.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()), + &mut info, + ) + }; + info.as_lapack_result()?; + // Hermite conjugate + if let Some(vl) = self.vc_l.as_mut() { + for value in vl { + let value = unsafe { value.assume_init_mut() }; + value.im 
= -value.im; + } + } + Ok(EigGeneralizedRef { + alpha: unsafe { self.alpha.slice_assume_init_ref() }, + beta: unsafe { self.beta.slice_assume_init_ref() }, + vl: self + .vc_l + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }), + vr: self + .vc_r + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }), + }) + } + + fn eval( + mut self, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result> { + let _eig_generalized_ref = self.calc(a, b)?; + Ok(EigGeneralizedOwned { + alpha: unsafe { self.alpha.assume_init() }, + beta: unsafe { self.beta.assume_init() }, + vl: self.vc_l.map(|v| unsafe { v.assume_init() }), + vr: self.vc_r.map(|v| unsafe { v.assume_init() }), + }) + } + } + + impl EigGeneralizedOwned<$c> { + pub fn calc_eigs(&self, thresh_opt: Option<$f>) -> Vec> { + self.alpha + .iter() + .zip(self.beta.iter()) + .map(|(alpha, beta)| { + if let Some(thresh) = thresh_opt { + if beta.abs() < thresh { + GeneralizedEigenvalue::Indeterminate((alpha.clone(), beta.clone())) + } else { + GeneralizedEigenvalue::Finite( + alpha / beta, + (alpha.clone(), beta.clone()), + ) + } + } else { + if beta.is_zero() { + GeneralizedEigenvalue::Indeterminate((alpha.clone(), beta.clone())) + } else { + GeneralizedEigenvalue::Finite( + alpha / beta, + (alpha.clone(), beta.clone()), + ) + } + } + }) + .collect::>() + } + } + }; +} + +impl_eig_generalized_work_c!(f32, c32, lapack_sys::cggev_); +impl_eig_generalized_work_c!(f64, c64, lapack_sys::zggev_); + +macro_rules! impl_eig_generalized_work_r { + ($f:ty, $c:ty, $ev:path) => { + impl EigGeneralizedWorkImpl for EigGeneralizedWork<$f> { + type Elem = $f; + + fn new(calc_v: bool, l: MatrixLayout) -> Result { + let (n, _) = l.size(); + let (jobvl, jobvr) = if calc_v { + match l { + MatrixLayout::C { .. } => (JobEv::All, JobEv::None), + MatrixLayout::F { .. 
} => (JobEv::None, JobEv::All), + } + } else { + (JobEv::None, JobEv::None) + }; + let mut alpha_re = vec_uninit(n as usize); + let mut alpha_im = vec_uninit(n as usize); + let mut beta_re = vec_uninit(n as usize); + let mut vr_l = jobvl.then(|| vec_uninit((n * n) as usize)); + let mut vr_r = jobvr.then(|| vec_uninit((n * n) as usize)); + let vc_l = jobvl.then(|| vec_uninit((n * n) as usize)); + let vc_r = jobvr.then(|| vec_uninit((n * n) as usize)); + + // calc work size + let mut info = 0; + let mut work_size: [$f; 1] = [0.0]; + unsafe { + $ev( + jobvl.as_ptr(), + jobvr.as_ptr(), + &n, + std::ptr::null_mut(), + &n, + std::ptr::null_mut(), + &n, + AsPtr::as_mut_ptr(&mut alpha_re), + AsPtr::as_mut_ptr(&mut alpha_im), + AsPtr::as_mut_ptr(&mut beta_re), + AsPtr::as_mut_ptr(vr_l.as_deref_mut().unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(vr_r.as_deref_mut().unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }; + info.as_lapack_result()?; + + // actual ev + let lwork = work_size[0].to_usize().unwrap(); + let work = vec_uninit(lwork); + + Ok(Self { + n, + jobvr, + jobvl, + alpha: vec_uninit(n as usize), + beta: vec_uninit(n as usize), + alpha_re: Some(alpha_re), + alpha_im: Some(alpha_im), + beta_re: Some(beta_re), + beta_im: None, + rwork: None, + vr_l, + vr_r, + vc_l, + vc_r, + work, + }) + } + + fn calc<'work>( + &'work mut self, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + unsafe { + $ev( + self.jobvl.as_ptr(), + self.jobvr.as_ptr(), + &self.n, + AsPtr::as_mut_ptr(a), + &self.n, + AsPtr::as_mut_ptr(b), + &self.n, + AsPtr::as_mut_ptr(self.alpha_re.as_mut().unwrap()), + AsPtr::as_mut_ptr(self.alpha_im.as_mut().unwrap()), + AsPtr::as_mut_ptr(self.beta_re.as_mut().unwrap()), + AsPtr::as_mut_ptr(self.vr_l.as_deref_mut().unwrap_or(&mut [])), + &self.n, + AsPtr::as_mut_ptr(self.vr_r.as_deref_mut().unwrap_or(&mut [])), + &self.n, + 
AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ) + }; + info.as_lapack_result()?; + + let alpha_re = self + .alpha_re + .as_ref() + .map(|e| unsafe { e.slice_assume_init_ref() }) + .unwrap(); + let alpha_im = self + .alpha_im + .as_ref() + .map(|e| unsafe { e.slice_assume_init_ref() }) + .unwrap(); + let beta_re = self + .beta_re + .as_ref() + .map(|e| unsafe { e.slice_assume_init_ref() }) + .unwrap(); + reconstruct_eigs_optional_im(alpha_re, Some(alpha_im), &mut self.alpha); + reconstruct_eigs_optional_im(beta_re, None, &mut self.beta); + + if let Some(v) = self.vr_l.as_ref() { + let v = unsafe { v.slice_assume_init_ref() }; + reconstruct_eigenvectors(true, alpha_im, v, self.vc_l.as_mut().unwrap()); + } + if let Some(v) = self.vr_r.as_ref() { + let v = unsafe { v.slice_assume_init_ref() }; + reconstruct_eigenvectors(false, alpha_im, v, self.vc_r.as_mut().unwrap()); + } + + Ok(EigGeneralizedRef { + alpha: unsafe { self.alpha.slice_assume_init_ref() }, + beta: unsafe { self.beta.slice_assume_init_ref() }, + vl: self + .vc_l + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }), + vr: self + .vc_r + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }), + }) + } + + fn eval( + mut self, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result> { + let _eig_generalized_ref = self.calc(a, b)?; + Ok(EigGeneralizedOwned { + alpha: unsafe { self.alpha.assume_init() }, + beta: unsafe { self.beta.assume_init() }, + vl: self.vc_l.map(|v| unsafe { v.assume_init() }), + vr: self.vc_r.map(|v| unsafe { v.assume_init() }), + }) + } + } + + impl EigGeneralizedOwned<$f> { + pub fn calc_eigs(&self, thresh_opt: Option<$f>) -> Vec> { + self.alpha + .iter() + .zip(self.beta.iter()) + .map(|(alpha, beta)| { + if let Some(thresh) = thresh_opt { + if beta.abs() < thresh { + GeneralizedEigenvalue::Indeterminate((alpha.clone(), beta.clone())) + } else { + GeneralizedEigenvalue::Finite( + alpha / beta, + (alpha.clone(), beta.clone()), + ) + } + } else { + if 
beta.is_zero() { + GeneralizedEigenvalue::Indeterminate((alpha.clone(), beta.clone())) + } else { + GeneralizedEigenvalue::Finite( + alpha / beta, + (alpha.clone(), beta.clone()), + ) + } + } + }) + .collect::>() + } + } + }; +} +impl_eig_generalized_work_r!(f32, c32, lapack_sys::sggev_); +impl_eig_generalized_work_r!(f64, c64, lapack_sys::dggev_); + +/// Create complex eigenvalues from real and optional imaginary parts. +fn reconstruct_eigs_optional_im( + re: &[T], + im_opt: Option<&[T]>, + eigs: &mut [MaybeUninit], +) { + let n = eigs.len(); + assert_eq!(re.len(), n); + + if let Some(im) = im_opt { + assert_eq!(im.len(), n); + for i in 0..n { + eigs[i].write(T::complex(re[i], im[i])); + } + } else { + for i in 0..n { + eigs[i].write(T::complex(re[i], T::zero())); + } + } +} diff --git a/lax/src/eigh.rs b/lax/src/eigh.rs index 46a3b131..bb3ca500 100644 --- a/lax/src/eigh.rs +++ b/lax/src/eigh.rs @@ -1,159 +1,190 @@ -//! Eigenvalue decomposition for Symmetric/Hermite matrices +//! Eigenvalue problem for symmetric/Hermitian matricies +//! +//! LAPACK correspondance +//! ---------------------- +//! +//! | f32 | f64 | c32 | c64 | +//! |:------|:------|:------|:------| +//! 
| ssyev | dsyev | cheev | zheev | use super::*; use crate::{error::*, layout::MatrixLayout}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -pub trait Eigh_: Scalar { - /// Wraps `*syev` for real and `*heev` for complex - fn eigh( - calc_eigenvec: bool, - layout: MatrixLayout, - uplo: UPLO, - a: &mut [Self], - ) -> Result>; +pub struct EighWork { + pub n: i32, + pub jobz: JobEv, + pub eigs: Vec>, + pub work: Vec>, + pub rwork: Option>>, +} - /// Wraps `*syegv` for real and `*heegv` for complex - fn eigh_generalized( - calc_eigenvec: bool, - layout: MatrixLayout, - uplo: UPLO, - a: &mut [Self], - b: &mut [Self], - ) -> Result>; +pub trait EighWorkImpl: Sized { + type Elem: Scalar; + fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result; + fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) + -> Result<&[::Real]>; + fn eval(self, uplo: UPLO, a: &mut [Self::Elem]) -> Result::Real>>; } -macro_rules! impl_eigh { - (@real, $scalar:ty, $ev:path, $evg:path) => { - impl_eigh!(@body, $scalar, $ev, $evg, ); - }; - (@complex, $scalar:ty, $ev:path, $evg:path) => { - impl_eigh!(@body, $scalar, $ev, $evg, rwork); - }; - (@body, $scalar:ty, $ev:path, $evg:path, $($rwork_ident:ident),*) => { - impl Eigh_ for $scalar { - fn eigh( - calc_v: bool, - layout: MatrixLayout, - uplo: UPLO, - mut a: &mut [Self], - ) -> Result> { +macro_rules! 
impl_eigh_work_c { + ($c:ty, $ev:path) => { + impl EighWorkImpl for EighWork<$c> { + type Elem = $c; + + fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); - let jobz = if calc_v { b'V' } else { b'N' }; - let mut eigs = unsafe { vec_uninit(n as usize) }; - - $( - let mut $rwork_ident = unsafe { vec_uninit(3 * n as usize - 2 as usize) }; - )* - - // calc work size + let jobz = if calc_eigenvectors { + JobEv::All + } else { + JobEv::None + }; + let mut eigs = vec_uninit(n as usize); + let mut rwork = vec_uninit(3 * n as usize - 2 as usize); let mut info = 0; - let mut work_size = [Self::zero()]; + let mut work_size = [Self::Elem::zero()]; unsafe { $ev( - jobz, - uplo as u8, - n, - &mut a, - n, - &mut eigs, - &mut work_size, - -1, - $(&mut $rwork_ident,)* + jobz.as_ptr(), + UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO + &n, + std::ptr::null_mut(), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + AsPtr::as_mut_ptr(&mut rwork), &mut info, ); } info.as_lapack_result()?; - - // actual ev let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let work = vec_uninit(lwork); + Ok(EighWork { + n, + eigs, + jobz, + work, + rwork: Some(rwork), + }) + } + + fn calc( + &mut self, + uplo: UPLO, + a: &mut [Self::Elem], + ) -> Result<&[::Real]> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; unsafe { $ev( - jobz, - uplo as u8, - n, - &mut a, - n, - &mut eigs, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* + self.jobz.as_ptr(), + uplo.as_ptr(), + &self.n, + AsPtr::as_mut_ptr(a), + &self.n, + AsPtr::as_mut_ptr(&mut self.eigs), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()), &mut info, ); } info.as_lapack_result()?; - Ok(eigs) + Ok(unsafe { self.eigs.slice_assume_init_ref() }) } - fn eigh_generalized( - calc_v: bool, - layout: 
MatrixLayout, + fn eval( + mut self, uplo: UPLO, - mut a: &mut [Self], - mut b: &mut [Self], - ) -> Result> { - assert_eq!(layout.len(), layout.lda()); - let n = layout.len(); - let jobz = if calc_v { b'V' } else { b'N' }; - let mut eigs = unsafe { vec_uninit(n as usize) }; + a: &mut [Self::Elem], + ) -> Result::Real>> { + let _eig = self.calc(uplo, a)?; + Ok(unsafe { self.eigs.assume_init() }) + } + } + }; +} +impl_eigh_work_c!(c64, lapack_sys::zheev_); +impl_eigh_work_c!(c32, lapack_sys::cheev_); - $( - let mut $rwork_ident = unsafe { vec_uninit(3 * n as usize - 2) }; - )* +macro_rules! impl_eigh_work_r { + ($f:ty, $ev:path) => { + impl EighWorkImpl for EighWork<$f> { + type Elem = $f; - // calc work size + fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result { + assert_eq!(layout.len(), layout.lda()); + let n = layout.len(); + let jobz = if calc_eigenvectors { + JobEv::All + } else { + JobEv::None + }; + let mut eigs = vec_uninit(n as usize); let mut info = 0; - let mut work_size = [Self::zero()]; + let mut work_size = [Self::Elem::zero()]; unsafe { - $evg( - &[1], - jobz, - uplo as u8, - n, - &mut a, - n, - &mut b, - n, - &mut eigs, - &mut work_size, - -1, - $(&mut $rwork_ident,)* + $ev( + jobz.as_ptr(), + UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO + &n, + std::ptr::null_mut(), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), &mut info, ); } info.as_lapack_result()?; - - // actual evg let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let work = vec_uninit(lwork); + Ok(EighWork { + n, + eigs, + jobz, + work, + rwork: None, + }) + } + + fn calc( + &mut self, + uplo: UPLO, + a: &mut [Self::Elem], + ) -> Result<&[::Real]> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; unsafe { - $evg( - &[1], - jobz, - uplo as u8, - n, - &mut a, - n, - &mut b, - n, - &mut eigs, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* + $ev( + 
self.jobz.as_ptr(), + uplo.as_ptr(), + &self.n, + AsPtr::as_mut_ptr(a), + &self.n, + AsPtr::as_mut_ptr(&mut self.eigs), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, &mut info, ); } info.as_lapack_result()?; - Ok(eigs) + Ok(unsafe { self.eigs.slice_assume_init_ref() }) + } + + fn eval( + mut self, + uplo: UPLO, + a: &mut [Self::Elem], + ) -> Result::Real>> { + let _eig = self.calc(uplo, a)?; + Ok(unsafe { self.eigs.assume_init() }) } } }; -} // impl_eigh! - -impl_eigh!(@real, f64, lapack::dsyev, lapack::dsygv); -impl_eigh!(@real, f32, lapack::ssyev, lapack::ssygv); -impl_eigh!(@complex, c64, lapack::zheev, lapack::zhegv); -impl_eigh!(@complex, c32, lapack::cheev, lapack::chegv); +} +impl_eigh_work_r!(f64, lapack_sys::dsyev_); +impl_eigh_work_r!(f32, lapack_sys::ssyev_); diff --git a/lax/src/eigh_generalized.rs b/lax/src/eigh_generalized.rs new file mode 100644 index 00000000..5d4d83ca --- /dev/null +++ b/lax/src/eigh_generalized.rs @@ -0,0 +1,216 @@ +//! Generalized eigenvalue problem for symmetric/Hermitian matrices +//! +//! LAPACK correspondance +//! ---------------------- +//! +//! | f32 | f64 | c32 | c64 | +//! |:------|:------|:------|:------| +//! | ssygv | dsygv | chegv | zhegv | +//! + +use super::*; +use crate::{error::*, layout::MatrixLayout}; +use cauchy::*; +use num_traits::{ToPrimitive, Zero}; + +pub struct EighGeneralizedWork { + pub n: i32, + pub jobz: JobEv, + pub eigs: Vec>, + pub work: Vec>, + pub rwork: Option>>, +} + +pub trait EighGeneralizedWorkImpl: Sized { + type Elem: Scalar; + fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result; + fn calc( + &mut self, + uplo: UPLO, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result<&[::Real]>; + fn eval( + self, + uplo: UPLO, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result::Real>>; +} + +macro_rules! 
impl_eigh_generalized_work_c { + ($c:ty, $gv:path) => { + impl EighGeneralizedWorkImpl for EighGeneralizedWork<$c> { + type Elem = $c; + + fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result { + assert_eq!(layout.len(), layout.lda()); + let n = layout.len(); + let jobz = if calc_eigenvectors { + JobEv::All + } else { + JobEv::None + }; + let mut eigs = vec_uninit(n as usize); + let mut rwork = vec_uninit(3 * n as usize - 2 as usize); + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + unsafe { + $gv( + &1, // ITYPE A*x = (lambda)*B*x + jobz.as_ptr(), + UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO + &n, + std::ptr::null_mut(), + &n, + std::ptr::null_mut(), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + AsPtr::as_mut_ptr(&mut rwork), + &mut info, + ); + } + info.as_lapack_result()?; + let lwork = work_size[0].to_usize().unwrap(); + let work = vec_uninit(lwork); + Ok(EighGeneralizedWork { + n, + eigs, + jobz, + work, + rwork: Some(rwork), + }) + } + + fn calc( + &mut self, + uplo: UPLO, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result<&[::Real]> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + unsafe { + $gv( + &1, // ITYPE A*x = (lambda)*B*x + self.jobz.as_ptr(), + uplo.as_ptr(), + &self.n, + AsPtr::as_mut_ptr(a), + &self.n, + AsPtr::as_mut_ptr(b), + &self.n, + AsPtr::as_mut_ptr(&mut self.eigs), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()), + &mut info, + ); + } + info.as_lapack_result()?; + Ok(unsafe { self.eigs.slice_assume_init_ref() }) + } + + fn eval( + mut self, + uplo: UPLO, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result::Real>> { + let _eig = self.calc(uplo, a, b)?; + Ok(unsafe { self.eigs.assume_init() }) + } + } + }; +} +impl_eigh_generalized_work_c!(c64, lapack_sys::zhegv_); +impl_eigh_generalized_work_c!(c32, lapack_sys::chegv_); + +macro_rules! 
impl_eigh_generalized_work_r { + ($f:ty, $gv:path) => { + impl EighGeneralizedWorkImpl for EighGeneralizedWork<$f> { + type Elem = $f; + + fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result { + assert_eq!(layout.len(), layout.lda()); + let n = layout.len(); + let jobz = if calc_eigenvectors { + JobEv::All + } else { + JobEv::None + }; + let mut eigs = vec_uninit(n as usize); + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + unsafe { + $gv( + &1, // ITYPE A*x = (lambda)*B*x + jobz.as_ptr(), + UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO + &n, + std::ptr::null_mut(), + &n, + std::ptr::null_mut(), + &n, + AsPtr::as_mut_ptr(&mut eigs), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ); + } + info.as_lapack_result()?; + let lwork = work_size[0].to_usize().unwrap(); + let work = vec_uninit(lwork); + Ok(EighGeneralizedWork { + n, + eigs, + jobz, + work, + rwork: None, + }) + } + + fn calc( + &mut self, + uplo: UPLO, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result<&[::Real]> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + unsafe { + $gv( + &1, // ITYPE A*x = (lambda)*B*x + self.jobz.as_ptr(), + uplo.as_ptr(), + &self.n, + AsPtr::as_mut_ptr(a), + &self.n, + AsPtr::as_mut_ptr(b), + &self.n, + AsPtr::as_mut_ptr(&mut self.eigs), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ); + } + info.as_lapack_result()?; + Ok(unsafe { self.eigs.slice_assume_init_ref() }) + } + + fn eval( + mut self, + uplo: UPLO, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result::Real>> { + let _eig = self.calc(uplo, a, b)?; + Ok(unsafe { self.eigs.assume_init() }) + } + } + }; +} +impl_eigh_generalized_work_r!(f64, lapack_sys::dsygv_); +impl_eigh_generalized_work_r!(f32, lapack_sys::ssygv_); diff --git a/lax/src/error.rs b/lax/src/error.rs index fb4b9838..e1c314ee 100644 --- a/lax/src/error.rs +++ b/lax/src/error.rs @@ -11,7 +11,7 @@ pub enum Error { LapackInvalidValue { 
return_code: i32 }, #[error( - "Comutational failure in LAPACK subroutine: return_code = {}", + "Computational failure in LAPACK subroutine: return_code = {}", return_code )] LapackComputationalFailure { return_code: i32 }, diff --git a/lax/src/flags.rs b/lax/src/flags.rs new file mode 100644 index 00000000..f9dea20d --- /dev/null +++ b/lax/src/flags.rs @@ -0,0 +1,140 @@ +//! Charactor flags, e.g. `'T'`, used in LAPACK API +use core::ffi::c_char; + +/// Upper/Lower specification for seveal usages +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum UPLO { + Upper = b'U', + Lower = b'L', +} + +impl UPLO { + pub fn t(self) -> Self { + match self { + UPLO::Upper => UPLO::Lower, + UPLO::Lower => UPLO::Upper, + } + } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const c_char { + self as *const UPLO as *const c_char + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum Transpose { + No = b'N', + Transpose = b'T', + Hermite = b'C', +} + +impl Transpose { + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const c_char { + self as *const Transpose as *const c_char + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum NormType { + One = b'O', + Infinity = b'I', + Frobenius = b'F', +} + +impl NormType { + pub fn transpose(self) -> Self { + match self { + NormType::One => NormType::Infinity, + NormType::Infinity => NormType::One, + NormType::Frobenius => NormType::Frobenius, + } + } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const c_char { + self as *const NormType as *const c_char + } +} + +/// Flag for calculating eigenvectors or not +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum JobEv { + /// Calculate eigenvectors in addition to eigenvalues + All = b'V', + /// Do not calculate eigenvectors. 
Only calculate eigenvalues. + None = b'N', +} + +impl JobEv { + pub fn is_calc(&self) -> bool { + match self { + JobEv::All => true, + JobEv::None => false, + } + } + + pub fn then T>(&self, f: F) -> Option { + if self.is_calc() { + Some(f()) + } else { + None + } + } + + /// To use Fortran LAPACK API in lapack-sys crate + pub fn as_ptr(&self) -> *const c_char { + self as *const JobEv as *const c_char + } +} + +/// Specifies how many singular vectors are computed +/// +/// For an input matrix $A$ of shape $m \times n$, +/// the following are computed on the singular value decomposition $A = U\Sigma V^T$: +#[cfg_attr(doc, katexit::katexit)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum JobSvd { + /// All $m$ columns of $U$, and/or all $n$ rows of $V^T$. + All = b'A', + /// The first $\min(m, n)$ columns of $U$ and/or the first $\min(m, n)$ rows of $V^T$. + Some = b'S', + /// No columns of $U$ and/or rows of $V^T$. + None = b'N', +} + +impl JobSvd { + pub fn from_bool(calc_uv: bool) -> Self { + if calc_uv { + JobSvd::All + } else { + JobSvd::None + } + } + + pub fn as_ptr(&self) -> *const c_char { + self as *const JobSvd as *const c_char + } +} + +/// Specify whether input triangular matrix is unit or not +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum Diag { + /// Unit triangular matrix, i.e. all diagonal elements of the matrix are `1` + Unit = b'U', + /// Non-unit triangular matrix. Its diagonal elements may be different from `1` + NonUnit = b'N', +} + +impl Diag { + pub fn as_ptr(&self) -> *const c_char { + self as *const Diag as *const c_char + } +} diff --git a/lax/src/layout.rs b/lax/src/layout.rs index e7ab1da4..28b35122 100644 --- a/lax/src/layout.rs +++ b/lax/src/layout.rs @@ -37,7 +37,7 @@ //! This `S` for a matrix `A` is called "leading dimension of the array A" in LAPACK document, and denoted by `lda`. //! 
-use cauchy::Scalar; +use super::*; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MatrixLayout { @@ -153,7 +153,7 @@ impl MatrixLayout { /// ------ /// - If size of `a` and `layout` size mismatch /// -pub fn square_transpose(layout: MatrixLayout, a: &mut [T]) { +pub fn square_transpose(layout: MatrixLayout, a: &mut [T]) { let (m, n) = layout.size(); let n = n as usize; let m = m as usize; @@ -162,23 +162,78 @@ pub fn square_transpose(layout: MatrixLayout, a: &mut [T]) { for j in (i + 1)..n { let a_ij = a[i * n + j]; let a_ji = a[j * m + i]; - a[i * n + j] = a_ji.conj(); - a[j * m + i] = a_ij.conj(); + a[i * n + j] = a_ji; + a[j * m + i] = a_ij; } } } /// Out-place transpose for general matrix /// -/// Inplace transpose of non-square matrices is hard. -/// See also: https://en.wikipedia.org/wiki/In-place_matrix_transposition +/// Examples +/// --------- +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::C { row: 2, lda: 3 }; +/// let a = vec![1., 2., 3., 4., 5., 6.]; +/// let (l, b) = transpose(layout, &a); +/// assert_eq!(l, MatrixLayout::F { col: 3, lda: 2 }); +/// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]); +/// ``` +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::F { col: 2, lda: 3 }; +/// let a = vec![1., 2., 3., 4., 5., 6.]; +/// let (l, b) = transpose(layout, &a); +/// assert_eq!(l, MatrixLayout::C { row: 3, lda: 2 }); +/// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]); +/// ``` +/// +/// Panics +/// ------ +/// - If input array size and `layout` size mismatch +/// +pub fn transpose(layout: MatrixLayout, input: &[T]) -> (MatrixLayout, Vec) { + let (m, n) = layout.size(); + let transposed = layout.resized(n, m).t(); + let m = m as usize; + let n = n as usize; + assert_eq!(input.len(), m * n); + + let mut out: Vec> = vec_uninit(m * n); + + match layout { + MatrixLayout::C { .. } => { + for i in 0..m { + for j in 0..n { + out[j * m + i].write(input[i * n + j]); + } + } + } + MatrixLayout::F { .. 
} => { + for i in 0..m { + for j in 0..n { + out[i * n + j].write(input[j * m + i]); + } + } + } + } + (transposed, unsafe { out.assume_init() }) +} + +/// Out-place transpose for general matrix +/// +/// Examples +/// --------- /// /// ```rust /// # use lax::layout::*; /// let layout = MatrixLayout::C { row: 2, lda: 3 }; /// let a = vec![1., 2., 3., 4., 5., 6.]; /// let mut b = vec![0.0; a.len()]; -/// let l = transpose(layout, &a, &mut b); +/// let l = transpose_over(layout, &a, &mut b); /// assert_eq!(l, MatrixLayout::F { col: 3, lda: 2 }); /// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]); /// ``` @@ -188,16 +243,16 @@ pub fn square_transpose(layout: MatrixLayout, a: &mut [T]) { /// let layout = MatrixLayout::F { col: 2, lda: 3 }; /// let a = vec![1., 2., 3., 4., 5., 6.]; /// let mut b = vec![0.0; a.len()]; -/// let l = transpose(layout, &a, &mut b); +/// let l = transpose_over(layout, &a, &mut b); /// assert_eq!(l, MatrixLayout::C { row: 3, lda: 2 }); /// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]); /// ``` /// /// Panics /// ------ -/// - If size of `a` and `layout` size mismatch +/// - If input array sizes and `layout` size mismatch /// -pub fn transpose(layout: MatrixLayout, from: &[T], to: &mut [T]) -> MatrixLayout { +pub fn transpose_over(layout: MatrixLayout, from: &[T], to: &mut [T]) -> MatrixLayout { let (m, n) = layout.size(); let transposed = layout.resized(n, m).t(); let m = m as usize; diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index fc378aa6..d0bb7def 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -5,154 +5,328 @@ use cauchy::*; use num_traits::{ToPrimitive, Zero}; /// Result of LeastSquares -pub struct LeastSquaresOutput { +pub struct LeastSquaresOwned { /// singular values pub singular_values: Vec, /// The rank of the input matrix A pub rank: i32, } -/// Wraps `*gelsd` -pub trait LeastSquaresSvdDivideConquer_: Scalar { - fn least_squares( - a_layout: MatrixLayout, - a: &mut [Self], - b: &mut [Self], - ) -> 
Result>; - - fn least_squares_nrhs( - a_layout: MatrixLayout, - a: &mut [Self], - b_layout: MatrixLayout, - b: &mut [Self], - ) -> Result>; +/// Result of LeastSquares +pub struct LeastSquaresRef<'work, A: Scalar> { + /// singular values + pub singular_values: &'work [A::Real], + /// The rank of the input matrix A + pub rank: i32, } -macro_rules! impl_least_squares { - (@real, $scalar:ty, $gelsd:path) => { - impl_least_squares!(@body, $scalar, $gelsd, ); - }; - (@complex, $scalar:ty, $gelsd:path) => { - impl_least_squares!(@body, $scalar, $gelsd, rwork); - }; +pub struct LeastSquaresWork { + pub a_layout: MatrixLayout, + pub b_layout: MatrixLayout, + pub singular_values: Vec>, + pub work: Vec>, + pub iwork: Vec>, + pub rwork: Option>>, +} - (@body, $scalar:ty, $gelsd:path, $($rwork:ident),*) => { - impl LeastSquaresSvdDivideConquer_ for $scalar { - fn least_squares( - l: MatrixLayout, - a: &mut [Self], - b: &mut [Self], - ) -> Result> { - let b_layout = l.resized(b.len() as i32, 1); - Self::least_squares_nrhs(l, a, b_layout, b) - } +pub trait LeastSquaresWorkImpl: Sized { + type Elem: Scalar; + fn new(a_layout: MatrixLayout, b_layout: MatrixLayout) -> Result; + fn calc( + &mut self, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result>; + fn eval( + self, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result>; +} + +macro_rules! 
impl_least_squares_work_c { + ($c:ty, $lsd:path) => { + impl LeastSquaresWorkImpl for LeastSquaresWork<$c> { + type Elem = $c; - fn least_squares_nrhs( - a_layout: MatrixLayout, - a: &mut [Self], - b_layout: MatrixLayout, - b: &mut [Self], - ) -> Result> { - // Minimize |b - Ax|_2 - // - // where - // A : (m, n) - // b : (max(m, n), nrhs) // `b` has to store `x` on exit - // x : (n, nrhs) + fn new(a_layout: MatrixLayout, b_layout: MatrixLayout) -> Result { let (m, n) = a_layout.size(); let (m_, nrhs) = b_layout.size(); let k = m.min(n); assert!(m_ >= m); + let rcond = -1.; + let mut singular_values = vec_uninit(k as usize); + let mut rank: i32 = 0; + + // eval work size + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + let mut iwork_size = [0]; + let mut rwork = [::Real::zero()]; + unsafe { + $lsd( + &m, + &n, + &nrhs, + std::ptr::null_mut(), + &m, + std::ptr::null_mut(), + &m_, + AsPtr::as_mut_ptr(&mut singular_values), + &rcond, + &mut rank, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + AsPtr::as_mut_ptr(&mut rwork), + iwork_size.as_mut_ptr(), + &mut info, + ) + }; + info.as_lapack_result()?; + + let lwork = work_size[0].to_usize().unwrap(); + let liwork = iwork_size[0].to_usize().unwrap(); + let lrwork = rwork[0].to_usize().unwrap(); + + let work = vec_uninit(lwork); + let iwork = vec_uninit(liwork); + let rwork = vec_uninit(lrwork); + + Ok(LeastSquaresWork { + a_layout, + b_layout, + work, + iwork, + rwork: Some(rwork), + singular_values, + }) + } + + fn calc( + &mut self, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result> { + let (m, n) = self.a_layout.size(); + let (m_, nrhs) = self.b_layout.size(); + assert!(m_ >= m); + + let lwork = self.work.len().to_i32().unwrap(); + // Transpose if a is C-continuous let mut a_t = None; - let a_layout = match a_layout { + let _ = match self.a_layout { MatrixLayout::C { .. 
} => { - a_t = Some(unsafe { vec_uninit( a.len()) }); - transpose(a_layout, a, a_t.as_mut().unwrap()) + let (layout, t) = transpose(self.a_layout, a); + a_t = Some(t); + layout } - MatrixLayout::F { .. } => a_layout, + MatrixLayout::F { .. } => self.a_layout, }; // Transpose if b is C-continuous let mut b_t = None; - let b_layout = match b_layout { + let b_layout = match self.b_layout { MatrixLayout::C { .. } => { - b_t = Some(unsafe { vec_uninit( b.len()) }); - transpose(b_layout, b, b_t.as_mut().unwrap()) + let (layout, t) = transpose(self.b_layout, b); + b_t = Some(t); + layout } - MatrixLayout::F { .. } => b_layout, + MatrixLayout::F { .. } => self.b_layout, }; - let rcond: Self::Real = -1.; - let mut singular_values: Vec = unsafe { vec_uninit( k as usize) }; + let rcond: ::Real = -1.; + let mut rank: i32 = 0; + + let mut info = 0; + unsafe { + $lsd( + &m, + &n, + &nrhs, + AsPtr::as_mut_ptr(a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a)), + &m, + AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), + &m_, + AsPtr::as_mut_ptr(&mut self.singular_values), + &rcond, + &mut rank, + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()), + AsPtr::as_mut_ptr(&mut self.iwork), + &mut info, + ); + } + info.as_lapack_result()?; + + let singular_values = unsafe { self.singular_values.slice_assume_init_ref() }; + + // Skip a_t -> a transpose because A has been destroyed + // Re-transpose b + if let Some(b_t) = b_t { + transpose_over(b_layout, &b_t, b); + } + + Ok(LeastSquaresRef { + singular_values, + rank, + }) + } + + fn eval( + mut self, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result> { + let LeastSquaresRef { rank, .. 
} = self.calc(a, b)?; + let singular_values = unsafe { self.singular_values.assume_init() }; + Ok(LeastSquaresOwned { + singular_values, + rank, + }) + } + } + }; +} +impl_least_squares_work_c!(c64, lapack_sys::zgelsd_); +impl_least_squares_work_c!(c32, lapack_sys::cgelsd_); + +macro_rules! impl_least_squares_work_r { + ($c:ty, $lsd:path) => { + impl LeastSquaresWorkImpl for LeastSquaresWork<$c> { + type Elem = $c; + + fn new(a_layout: MatrixLayout, b_layout: MatrixLayout) -> Result { + let (m, n) = a_layout.size(); + let (m_, nrhs) = b_layout.size(); + let k = m.min(n); + assert!(m_ >= m); + + let rcond = -1.; + let mut singular_values = vec_uninit(k as usize); let mut rank: i32 = 0; // eval work size let mut info = 0; - let mut work_size = [Self::zero()]; + let mut work_size = [Self::Elem::zero()]; let mut iwork_size = [0]; - $( - let mut $rwork = [Self::Real::zero()]; - )* unsafe { - $gelsd( - m, - n, - nrhs, - a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), - a_layout.lda(), - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - b_layout.lda(), - &mut singular_values, - rcond, + $lsd( + &m, + &n, + &nrhs, + std::ptr::null_mut(), + &m, + std::ptr::null_mut(), + &m_, + AsPtr::as_mut_ptr(&mut singular_values), + &rcond, &mut rank, - &mut work_size, - -1, - $(&mut $rwork,)* - &mut iwork_size, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + iwork_size.as_mut_ptr(), &mut info, ) }; info.as_lapack_result()?; - // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit( lwork) }; let liwork = iwork_size[0].to_usize().unwrap(); - let mut iwork = unsafe { vec_uninit( liwork) }; - $( - let lrwork = $rwork[0].to_usize().unwrap(); - let mut $rwork = unsafe { vec_uninit( lrwork) }; - )* + + let work = vec_uninit(lwork); + let iwork = vec_uninit(liwork); + + Ok(LeastSquaresWork { + a_layout, + b_layout, + work, + iwork, + rwork: None, + singular_values, + }) + } + + fn calc( + &mut self, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + 
) -> Result> { + let (m, n) = self.a_layout.size(); + let (m_, nrhs) = self.b_layout.size(); + assert!(m_ >= m); + + let lwork = self.work.len().to_i32().unwrap(); + + // Transpose if a is C-continuous + let mut a_t = None; + let _ = match self.a_layout { + MatrixLayout::C { .. } => { + let (layout, t) = transpose(self.a_layout, a); + a_t = Some(t); + layout + } + MatrixLayout::F { .. } => self.a_layout, + }; + + // Transpose if b is C-continuous + let mut b_t = None; + let b_layout = match self.b_layout { + MatrixLayout::C { .. } => { + let (layout, t) = transpose(self.b_layout, b); + b_t = Some(t); + layout + } + MatrixLayout::F { .. } => self.b_layout, + }; + + let rcond: ::Real = -1.; + let mut rank: i32 = 0; + + let mut info = 0; unsafe { - $gelsd( - m, - n, - nrhs, - a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), - a_layout.lda(), - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - b_layout.lda(), - &mut singular_values, - rcond, + $lsd( + &m, + &n, + &nrhs, + AsPtr::as_mut_ptr(a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a)), + &m, + AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), + &m_, + AsPtr::as_mut_ptr(&mut self.singular_values), + &rcond, &mut rank, - &mut work, - lwork as i32, - $(&mut $rwork,)* - &mut iwork, + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + AsPtr::as_mut_ptr(&mut self.iwork), &mut info, ); } info.as_lapack_result()?; + let singular_values = unsafe { self.singular_values.slice_assume_init_ref() }; + // Skip a_t -> a transpose because A has been destroyed // Re-transpose b if let Some(b_t) = b_t { - transpose(b_layout, &b_t, b); + transpose_over(b_layout, &b_t, b); } - Ok(LeastSquaresOutput { + Ok(LeastSquaresRef { + singular_values, + rank, + }) + } + + fn eval( + mut self, + a: &mut [Self::Elem], + b: &mut [Self::Elem], + ) -> Result> { + let LeastSquaresRef { rank, .. 
} = self.calc(a, b)?; + let singular_values = unsafe { self.singular_values.assume_init() }; + Ok(LeastSquaresOwned { singular_values, rank, }) @@ -160,8 +334,5 @@ macro_rules! impl_least_squares { } }; } - -impl_least_squares!(@real, f64, lapack::dgelsd); -impl_least_squares!(@real, f32, lapack::sgelsd); -impl_least_squares!(@complex, c64, lapack::zgelsd); -impl_least_squares!(@complex, c32, lapack::cgelsd); +impl_least_squares_work_r!(f64, lapack_sys::dgelsd_); +impl_least_squares_work_r!(f32, lapack_sys::sgelsd_); diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 41c15237..680ff0db 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -1,63 +1,80 @@ -//! Linear Algebra eXtension (LAX) -//! =============================== +//! Safe Rust wrapper for LAPACK without external dependency. //! -//! ndarray-free safe Rust wrapper for LAPACK FFI +//! [Lapack] trait +//! ---------------- //! -//! Linear equation, Inverse matrix, Condition number -//! -------------------------------------------------- +//! This crates provides LAPACK wrapper as a traits. +//! For example, LU decomposition of general matrices is provided like: //! -//! As the property of $A$, several types of triangular factorization are used: +//! ```ignore +//! pub trait Lapack { +//! fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; +//! } +//! ``` //! -//! - LU-decomposition for general matrix -//! - $PA = LU$, where $L$ is lower matrix, $U$ is upper matrix, and $P$ is permutation matrix -//! - Bunch-Kaufman diagonal pivoting method for nonpositive-definite Hermitian matrix -//! - $A = U D U^\dagger$, where $U$ is upper matrix, -//! $D$ is Hermitian and block diagonal with 1-by-1 and 2-by-2 diagonal blocks. +//! see [Lapack] for detail. +//! This trait is implemented for [f32], [f64], [c32] which is an alias to `num::Complex`, +//! and [c64] which is an alias to `num::Complex`. +//! You can use it like `f64::lu`: //! -//! 
| matrix type | Triangler factorization (TRF) | Solve (TRS) | Inverse matrix (TRI) | Reciprocal condition number (CON) | -//! |:--------------------------------|:------------------------------|:------------|:---------------------|:----------------------------------| -//! | General (GE) | [lu] | [solve] | [inv] | [rcond] | -//! | Symmetric (SY) / Hermitian (HE) | [bk] | [solveh] | [invh] | - | +//! ``` +//! use lax::{Lapack, layout::MatrixLayout, Transpose}; //! -//! [lu]: solve/trait.Solve_.html#tymethod.lu -//! [solve]: solve/trait.Solve_.html#tymethod.solve -//! [inv]: solve/trait.Solve_.html#tymethod.inv -//! [rcond]: solve/trait.Solve_.html#tymethod.rcond +//! let mut a = vec![ +//! 1.0, 2.0, +//! 3.0, 4.0 +//! ]; +//! let mut b = vec![1.0, 2.0]; +//! let layout = MatrixLayout::C { row: 2, lda: 2 }; +//! let pivot = f64::lu(layout, &mut a).unwrap(); +//! f64::solve(layout, Transpose::No, &a, &pivot, &mut b).unwrap(); +//! ``` //! -//! [bk]: solveh/trait.Solveh_.html#tymethod.bk -//! [solveh]: solveh/trait.Solveh_.html#tymethod.solveh -//! [invh]: solveh/trait.Solveh_.html#tymethod.invh +//! When you want to write generic algorithm for real and complex matrices, +//! this trait can be used as a trait bound: //! -//! Eigenvalue Problem -//! ------------------- +//! ``` +//! use lax::{Lapack, layout::MatrixLayout, Transpose}; //! -//! Solve eigenvalue problem for a matrix $A$ +//! fn solve_at_once(layout: MatrixLayout, a: &mut [T], b: &mut [T]) -> Result<(), lax::error::Error> { +//! let pivot = T::lu(layout, a)?; +//! T::solve(layout, Transpose::No, a, &pivot, b)?; +//! Ok(()) +//! } +//! ``` //! -//! $$ Av_i = \lambda_i v_i $$ +//! There are several similar traits as described below to keep development easy. +//! They are merged into a single trait, [Lapack]. //! -//! or generalized eigenvalue problem +//! Linear equation, Inverse matrix, Condition number +//! -------------------------------------------------- +//! +//! 
According to the property of the input matrix, several types of triangular decomposition are used: //! -//! $$ Av_i = \lambda_i B v_i $$ +//! - [solve] module provides methods for LU-decomposition for general matrix. +//! - [solveh] module provides methods for Bunch-Kaufman diagonal pivoting method for symmetric/Hermitian indefinite matrix. +//! - [cholesky] module provides methods for Cholesky decomposition for symmetric/Hermitian positive definite matrix. //! -//! | matrix type | Eigenvalue (EV) | Generalized Eigenvalue Problem (EG) | -//! |:--------------------------------|:----------------|:------------------------------------| -//! | General (GE) |[eig] | - | -//! | Symmetric (SY) / Hermitian (HE) |[eigh] |[eigh_generalized] | +//! Eigenvalue Problem +//! ------------------- //! -//! [eig]: eig/trait.Eig_.html#tymethod.eig -//! [eigh]: eigh/trait.Eigh_.html#tymethod.eigh -//! [eigh_generalized]: eigh/trait.Eigh_.html#tymethod.eigh_generalized +//! According to the property of the input matrix, +//! there are several types of eigenvalue problem API //! -//! Singular Value Decomposition (SVD), Least square problem -//! ---------------------------------------------------------- +//! - [eig] module for eigenvalue problem for general matrix. +//! - [eig_generalized] module for generalized eigenvalue problem for general matrix. +//! - [eigh] module for eigenvalue problem for symmetric/Hermitian matrix. +//! - [eigh_generalized] module for generalized eigenvalue problem for symmetric/Hermitian matrix. //! -//! | matrix type | Singular Value Decomposition (SVD) | SVD with divided-and-conquer (SDD) | Least square problem (LSD) | -//! |:-------------|:-----------------------------------|:-----------------------------------|:---------------------------| -//! | General (GE) | [svd] | [svddc] | [least_squares] | +//! Singular Value Decomposition +//! ----------------------------- //! -//! [svd]: svd/trait.SVD_.html#tymethod.svd -//! [svddc]: svddck/trait.SVDDC_.html#tymethod.svddc -//! 
[least_squares]: least_squares/trait.LeastSquaresSvdDivideConquer_.html#tymethod.least_squares +//! - [svd] module for singular value decomposition (SVD) for general matrix +//! - [svddc] module for singular value decomposition (SVD) with divided-and-conquer algorithm for general matrix +//! - [least_squares] module for solving least square problem using SVD +//! + +#![deny(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links)] #[cfg(any(feature = "intel-mkl-system", feature = "intel-mkl-static"))] extern crate intel_mkl_src as _src; @@ -68,115 +85,482 @@ extern crate openblas_src as _src; #[cfg(any(feature = "netlib-system", feature = "netlib-static"))] extern crate netlib_src as _src; +pub mod alloc; +pub mod cholesky; +pub mod eig; +pub mod eig_generalized; +pub mod eigh; +pub mod eigh_generalized; pub mod error; +pub mod flags; pub mod layout; +pub mod least_squares; +pub mod opnorm; +pub mod qr; +pub mod rcond; +pub mod solve; +pub mod solveh; +pub mod svd; +pub mod svddc; +pub mod triangular; +pub mod tridiagonal; -mod cholesky; -mod eig; -mod eigh; -mod least_squares; -mod opnorm; -mod qr; -mod rcond; -mod solve; -mod solveh; -mod svd; -mod svddc; -mod triangular; -mod tridiagonal; - -pub use self::cholesky::*; -pub use self::eig::*; -pub use self::eigh::*; -pub use self::least_squares::*; -pub use self::opnorm::*; -pub use self::qr::*; -pub use self::rcond::*; -pub use self::solve::*; -pub use self::solveh::*; -pub use self::svd::*; -pub use self::svddc::*; -pub use self::triangular::*; -pub use self::tridiagonal::*; +pub use crate::eig_generalized::GeneralizedEigenvalue; +pub use self::flags::*; +pub use self::least_squares::LeastSquaresOwned; +pub use self::svd::{SvdOwned, SvdRef}; +pub use self::tridiagonal::{LUFactorizedTridiagonal, Tridiagonal}; + +use self::{alloc::*, error::*, layout::*}; use cauchy::*; +use std::mem::MaybeUninit; pub type Pivot = Vec; +#[cfg_attr(doc, katexit::katexit)] /// Trait for primitive types which implements 
LAPACK subroutines -pub trait Lapack: - OperatorNorm_ - + QR_ - + SVD_ - + SVDDC_ - + Solve_ - + Solveh_ - + Cholesky_ - + Eig_ - + Eigh_ - + Triangular_ - + Tridiagonal_ - + Rcond_ - + LeastSquaresSvdDivideConquer_ -{ -} +pub trait Lapack: Scalar { + /// Compute right eigenvalue and eigenvectors for a general matrix + fn eig( + calc_v: bool, + l: MatrixLayout, + a: &mut [Self], + ) -> Result<(Vec, Vec)>; -impl Lapack for f32 {} -impl Lapack for f64 {} -impl Lapack for c32 {} -impl Lapack for c64 {} - -/// Upper/Lower specification for seveal usages -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum UPLO { - Upper = b'U', - Lower = b'L', -} + /// Compute right eigenvalue and eigenvectors for a general matrix + fn eig_generalized( + calc_v: bool, + l: MatrixLayout, + a: &mut [Self], + b: &mut [Self], + thresh_opt: Option, + ) -> Result<( + Vec>, + Vec, + )>; -impl UPLO { - pub fn t(self) -> Self { - match self { - UPLO::Upper => UPLO::Lower, - UPLO::Lower => UPLO::Upper, - } - } -} + /// Compute right eigenvalue and eigenvectors for a symmetric or Hermitian matrix + fn eigh( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + a: &mut [Self], + ) -> Result>; -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum Transpose { - No = b'N', - Transpose = b'T', - Hermite = b'C', -} + /// Compute right eigenvalue and eigenvectors for a symmetric or Hermitian matrix + fn eigh_generalized( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + a: &mut [Self], + b: &mut [Self], + ) -> Result>; -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum NormType { - One = b'O', - Infinity = b'I', - Frobenius = b'F', -} + /// Execute Householder reflection as the first step of QR-decomposition + /// + /// For C-continuous array, + /// this will call LQ-decomposition of the transposed matrix $ A^T = LQ^T $ + fn householder(l: MatrixLayout, a: &mut [Self]) -> Result>; -impl NormType { - pub fn transpose(self) -> Self { - match self { - NormType::One => 
NormType::Infinity, - NormType::Infinity => NormType::One, - NormType::Frobenius => NormType::Frobenius, - } - } + /// Reconstruct Q-matrix from Householder-reflectors + fn q(l: MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()>; + + /// Execute QR-decomposition at once + fn qr(l: MatrixLayout, a: &mut [Self]) -> Result>; + + /// Compute singular-value decomposition (SVD) + fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result>; + + /// Compute singular value decomposition (SVD) with divide-and-conquer algorithm + fn svddc(layout: MatrixLayout, jobz: JobSvd, a: &mut [Self]) -> Result>; + + /// Compute a vector $x$ which minimizes Euclidian norm $\| Ax - b\|$ + /// for a given matrix $A$ and a vector $b$. + fn least_squares( + a_layout: MatrixLayout, + a: &mut [Self], + b: &mut [Self], + ) -> Result>; + + /// Solve least square problems $\argmin_X \| AX - B\|$ + fn least_squares_nrhs( + a_layout: MatrixLayout, + a: &mut [Self], + b_layout: MatrixLayout, + b: &mut [Self], + ) -> Result>; + + /// Computes the LU decomposition of a general $m \times n$ matrix + /// with partial pivoting with row interchanges. + /// + /// For a given matrix $A$, LU decomposition is described as $A = PLU$ where: + /// + /// - $L$ is lower matrix + /// - $U$ is upper matrix + /// - $P$ is permutation matrix represented by [Pivot] + /// + /// This is designed as two step computation according to LAPACK API: + /// + /// 1. Factorize input matrix $A$ into $L$, $U$, and $P$. + /// 2. Solve linear equation $Ax = b$ by [Lapack::solve] + /// or compute inverse matrix $A^{-1}$ by [Lapack::inv] using the output of LU decomposition. + /// + /// Output + /// ------- + /// - $U$ and $L$ are stored in `a` after LU decomposition has succeeded. 
+ /// - $P$ is returned as [Pivot] + /// + /// Error + /// ------ + /// - if the matrix is singular + /// - On this case, `return_code` in [Error::LapackComputationalFailure] means + /// `return_code`-th diagonal element of $U$ becomes zero. + /// + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; + + /// Compute inverse matrix $A^{-1}$ from the output of LU-decomposition + fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; + + /// Solve linear equations $Ax = b$ using the output of LU-decomposition + fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; + + /// Factorize symmetric/Hermitian matrix using Bunch-Kaufman diagonal pivoting method + /// + /// For a given symmetric matrix $A$, + /// this method factorizes $A = U^T D U$ or $A = L D L^T$ where + /// + /// - $U$ (or $L$) are is a product of permutation and unit upper (lower) triangular matrices + /// - $D$ is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks. + /// + /// This takes two-step approach based in LAPACK: + /// + /// 1. Factorize given matrix $A$ into upper ($U$) or lower ($L$) form with diagonal matrix $D$ + /// 2. 
Then solve linear equation $Ax = b$, and/or calculate inverse matrix $A^{-1}$ + /// + fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result; + + /// Compute inverse matrix $A^{-1}$ using the result of [Lapack::bk] + fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>; + + /// Solve symmetric/Hermitian linear equation $Ax = b$ using the result of [Lapack::bk] + fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; + + /// Solve symmetric/Hermitian positive-definite linear equations using Cholesky decomposition + /// + /// For a given positive definite matrix $A$, + /// Cholesky decomposition is described as $A = U^T U$ or $A = LL^T$ where + /// + /// - $L$ is lower matrix + /// - $U$ is upper matrix + /// + /// This is designed as two step computation according to LAPACK API + /// + /// 1. Factorize input matrix $A$ into $L$ or $U$ + /// 2. Solve linear equation $Ax = b$ by [Lapack::solve_cholesky] + /// or compute inverse matrix $A^{-1}$ by [Lapack::inv_cholesky] + /// + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + + /// Compute inverse matrix $A^{-1}$ using $U$ or $L$ calculated by [Lapack::cholesky] + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + + /// Solve linear equation $Ax = b$ using $U$ or $L$ calculated by [Lapack::cholesky] + fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; + + /// Estimates the reciprocal of the condition number of the matrix in 1-norm. + /// + /// `anorm` should be the 1-norm of the matrix `a`. 
+ fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result; + + /// Compute norm of matrices + /// + /// For a $n \times m$ matrix + /// $$ + /// A = \begin{pmatrix} + /// a_{11} & \cdots & a_{1m} \\\\ + /// \vdots & \ddots & \vdots \\\\ + /// a_{n1} & \cdots & a_{nm} + /// \end{pmatrix} + /// $$ + /// LAPACK can compute three types of norms: + /// + /// - Operator norm based on 1-norm for its domain linear space: + /// $$ + /// \Vert A \Vert_1 = \sup_{\Vert x \Vert_1 = 1} \Vert Ax \Vert_1 + /// = \max_{1 \le j \le m } \sum_{i=1}^n |a_{ij}| + /// $$ + /// where + /// $\Vert x\Vert_1 = \sum_{j=1}^m |x_j|$ + /// is 1-norm for a vector $x$. + /// + /// - Operator norm based on $\infty$-norm for its domain linear space: + /// $$ + /// \Vert A \Vert_\infty = \sup_{\Vert x \Vert_\infty = 1} \Vert Ax \Vert_\infty + /// = \max_{1 \le i \le n } \sum_{j=1}^m |a_{ij}| + /// $$ + /// where + /// $\Vert x\Vert_\infty = \max_{j=1}^m |x_j|$ + /// is $\infty$-norm for a vector $x$. + /// + /// - Frobenious norm + /// $$ + /// \Vert A \Vert_F = \sqrt{\mathrm{Tr} \left(AA^\dagger\right)} = \sqrt{\sum_{i=1}^n \sum_{j=1}^m |a_{ij}|^2} + /// $$ + /// + fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real; + + fn solve_triangular( + al: MatrixLayout, + bl: MatrixLayout, + uplo: UPLO, + d: Diag, + a: &[Self], + b: &mut [Self], + ) -> Result<()>; + + /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using + /// partial pivoting with row interchanges. + fn lu_tridiagonal(a: Tridiagonal) -> Result>; + + fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result; + + fn solve_tridiagonal( + lu: &LUFactorizedTridiagonal, + bl: MatrixLayout, + t: Transpose, + b: &mut [Self], + ) -> Result<()>; } -/// Create a vector without initialization -/// -/// Safety -/// ------ -/// - Memory is not initialized. Do not read the memory before write. 
-/// -unsafe fn vec_uninit(n: usize) -> Vec { - let mut v = Vec::with_capacity(n); - v.set_len(n); - v +macro_rules! impl_lapack { + ($s:ty) => { + impl Lapack for $s { + fn eig( + calc_v: bool, + l: MatrixLayout, + a: &mut [Self], + ) -> Result<(Vec, Vec)> { + use eig::*; + let work = EigWork::<$s>::new(calc_v, l)?; + let EigOwned { eigs, vr, vl } = work.eval(a)?; + Ok((eigs, vr.or(vl).unwrap_or_default())) + } + + fn eig_generalized( + calc_v: bool, + l: MatrixLayout, + a: &mut [Self], + b: &mut [Self], + thresh_opt: Option, + ) -> Result<( + Vec>, + Vec, + )> { + use eig_generalized::*; + let work = EigGeneralizedWork::<$s>::new(calc_v, l)?; + let eig_generalized_owned = work.eval(a, b)?; + let eigs = eig_generalized_owned.calc_eigs(thresh_opt); + let vr = eig_generalized_owned.vr; + let vl = eig_generalized_owned.vl; + Ok((eigs, vr.or(vl).unwrap_or_default())) + } + + fn eigh( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + a: &mut [Self], + ) -> Result> { + use eigh::*; + let work = EighWork::<$s>::new(calc_eigenvec, layout)?; + work.eval(uplo, a) + } + + fn eigh_generalized( + calc_eigenvec: bool, + layout: MatrixLayout, + uplo: UPLO, + a: &mut [Self], + b: &mut [Self], + ) -> Result> { + use eigh_generalized::*; + let work = EighGeneralizedWork::<$s>::new(calc_eigenvec, layout)?; + work.eval(uplo, a, b) + } + + fn householder(l: MatrixLayout, a: &mut [Self]) -> Result> { + use qr::*; + let work = HouseholderWork::<$s>::new(l)?; + work.eval(a) + } + + fn q(l: MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()> { + use qr::*; + let mut work = QWork::<$s>::new(l)?; + work.calc(a, tau)?; + Ok(()) + } + + fn qr(l: MatrixLayout, a: &mut [Self]) -> Result> { + let tau = Self::householder(l, a)?; + let r = Vec::from(&*a); + Self::q(l, a, &tau)?; + Ok(r) + } + + fn svd( + l: MatrixLayout, + calc_u: bool, + calc_vt: bool, + a: &mut [Self], + ) -> Result> { + use svd::*; + let work = SvdWork::<$s>::new(l, calc_u, calc_vt)?; + work.eval(a) + } + + 
fn svddc(layout: MatrixLayout, jobz: JobSvd, a: &mut [Self]) -> Result> { + use svddc::*; + let work = SvdDcWork::<$s>::new(layout, jobz)?; + work.eval(a) + } + + fn least_squares( + l: MatrixLayout, + a: &mut [Self], + b: &mut [Self], + ) -> Result> { + let b_layout = l.resized(b.len() as i32, 1); + Self::least_squares_nrhs(l, a, b_layout, b) + } + + fn least_squares_nrhs( + a_layout: MatrixLayout, + a: &mut [Self], + b_layout: MatrixLayout, + b: &mut [Self], + ) -> Result> { + use least_squares::*; + let work = LeastSquaresWork::<$s>::new(a_layout, b_layout)?; + work.eval(a, b) + } + + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { + use solve::*; + LuImpl::lu(l, a) + } + + fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()> { + use solve::*; + let mut work = InvWork::<$s>::new(l)?; + work.calc(a, p)?; + Ok(()) + } + + fn solve( + l: MatrixLayout, + t: Transpose, + a: &[Self], + p: &Pivot, + b: &mut [Self], + ) -> Result<()> { + use solve::*; + SolveImpl::solve(l, t, a, p, b) + } + + fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result { + use solveh::*; + let work = BkWork::<$s>::new(l)?; + work.eval(uplo, a) + } + + fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> { + use solveh::*; + let mut work = InvhWork::<$s>::new(l)?; + work.calc(uplo, a, ipiv) + } + + fn solveh( + l: MatrixLayout, + uplo: UPLO, + a: &[Self], + ipiv: &Pivot, + b: &mut [Self], + ) -> Result<()> { + use solveh::*; + SolvehImpl::solveh(l, uplo, a, ipiv, b) + } + + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + use cholesky::*; + CholeskyImpl::cholesky(l, uplo, a) + } + + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + use cholesky::*; + InvCholeskyImpl::inv_cholesky(l, uplo, a) + } + + fn solve_cholesky( + l: MatrixLayout, + uplo: UPLO, + a: &[Self], + b: &mut [Self], + ) -> Result<()> { + use cholesky::*; + SolveCholeskyImpl::solve_cholesky(l, uplo, a, b) + } + + fn rcond(l: 
MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { + use rcond::*; + let mut work = RcondWork::<$s>::new(l); + work.calc(a, anorm) + } + + fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real { + use opnorm::*; + let mut work = OperatorNormWork::<$s>::new(t, l); + work.calc(a) + } + + fn solve_triangular( + al: MatrixLayout, + bl: MatrixLayout, + uplo: UPLO, + d: Diag, + a: &[Self], + b: &mut [Self], + ) -> Result<()> { + use triangular::*; + SolveTriangularImpl::solve_triangular(al, bl, uplo, d, a, b) + } + + fn lu_tridiagonal(a: Tridiagonal) -> Result> { + use tridiagonal::*; + let work = LuTridiagonalWork::<$s>::new(a.l); + work.eval(a) + } + + fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { + use tridiagonal::*; + let mut work = RcondTridiagonalWork::<$s>::new(lu.a.l); + work.calc(lu) + } + + fn solve_tridiagonal( + lu: &LUFactorizedTridiagonal, + bl: MatrixLayout, + t: Transpose, + b: &mut [Self], + ) -> Result<()> { + use tridiagonal::*; + SolveTridiagonalImpl::solve_tridiagonal(lu, bl, t, b) + } + } + }; } +impl_lapack!(c64); +impl_lapack!(c32); +impl_lapack!(f64); +impl_lapack!(f32); diff --git a/lax/src/opnorm.rs b/lax/src/opnorm.rs index dd84f441..1789f385 100644 --- a/lax/src/opnorm.rs +++ b/lax/src/opnorm.rs @@ -1,35 +1,58 @@ -//! Operator norms of matrices +//! Operator norm -use super::NormType; +use super::{AsPtr, NormType}; use crate::{layout::MatrixLayout, *}; use cauchy::*; -pub trait OperatorNorm_: Scalar { - fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real; +pub struct OperatorNormWork { + pub ty: NormType, + pub layout: MatrixLayout, + pub work: Vec>, } -macro_rules! impl_opnorm { - ($scalar:ty, $lange:path) => { - impl OperatorNorm_ for $scalar { - fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real { - let m = l.lda(); - let n = l.len(); - let t = match l { - MatrixLayout::F { .. } => t, - MatrixLayout::C { .. 
} => t.transpose(), +pub trait OperatorNormWorkImpl { + type Elem: Scalar; + fn new(t: NormType, l: MatrixLayout) -> Self; + fn calc(&mut self, a: &[Self::Elem]) -> ::Real; +} + +macro_rules! impl_operator_norm { + ($s:ty, $lange:path) => { + impl OperatorNormWorkImpl for OperatorNormWork<$s> { + type Elem = $s; + + fn new(ty: NormType, layout: MatrixLayout) -> Self { + let m = layout.lda(); + let work = match (ty, layout) { + (NormType::Infinity, MatrixLayout::F { .. }) + | (NormType::One, MatrixLayout::C { .. }) => vec_uninit(m as usize), + _ => Vec::new(), }; - let mut work = if matches!(t, NormType::Infinity) { - unsafe { vec_uninit(m as usize) } - } else { - Vec::new() + OperatorNormWork { ty, layout, work } + } + + fn calc(&mut self, a: &[Self::Elem]) -> ::Real { + let m = self.layout.lda(); + let n = self.layout.len(); + let t = match self.layout { + MatrixLayout::F { .. } => self.ty, + MatrixLayout::C { .. } => self.ty.transpose(), }; - unsafe { $lange(t as u8, m, n, a, m, &mut work) } + unsafe { + $lange( + t.as_ptr(), + &m, + &n, + AsPtr::as_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut self.work), + ) + } } } }; -} // impl_opnorm! 
- -impl_opnorm!(f64, lapack::dlange); -impl_opnorm!(f32, lapack::slange); -impl_opnorm!(c64, lapack::zlange); -impl_opnorm!(c32, lapack::clange); +} +impl_operator_norm!(c64, lapack_sys::zlange_); +impl_operator_norm!(c32, lapack_sys::clange_); +impl_operator_norm!(f64, lapack_sys::dlange_); +impl_operator_norm!(f32, lapack_sys::slange_); diff --git a/lax/src/qr.rs b/lax/src/qr.rs index 6460b8b9..f37bd579 100644 --- a/lax/src/qr.rs +++ b/lax/src/qr.rs @@ -4,152 +4,213 @@ use crate::{error::*, layout::MatrixLayout, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -pub trait QR_: Sized { - /// Execute Householder reflection as the first step of QR-decomposition - /// - /// For C-continuous array, - /// this will call LQ-decomposition of the transposed matrix $ A^T = LQ^T $ - fn householder(l: MatrixLayout, a: &mut [Self]) -> Result>; - - /// Reconstruct Q-matrix from Householder-reflectors - fn q(l: MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()>; +pub struct HouseholderWork { + pub m: i32, + pub n: i32, + pub layout: MatrixLayout, + pub tau: Vec>, + pub work: Vec>, +} - /// Execute QR-decomposition at once - fn qr(l: MatrixLayout, a: &mut [Self]) -> Result>; +pub trait HouseholderWorkImpl: Sized { + type Elem: Scalar; + fn new(l: MatrixLayout) -> Result; + fn calc(&mut self, a: &mut [Self::Elem]) -> Result<&[Self::Elem]>; + fn eval(self, a: &mut [Self::Elem]) -> Result>; } -macro_rules! impl_qr { - ($scalar:ty, $qrf:path, $lqf:path, $gqr:path, $glq:path) => { - impl QR_ for $scalar { - fn householder(l: MatrixLayout, mut a: &mut [Self]) -> Result> { - let m = l.lda(); - let n = l.len(); - let k = m.min(n); - let mut tau = unsafe { vec_uninit(k as usize) }; +macro_rules! 
impl_householder_work { + ($s:ty, $qrf:path, $lqf: path) => { + impl HouseholderWorkImpl for HouseholderWork<$s> { + type Elem = $s; - // eval work size + fn new(layout: MatrixLayout) -> Result { + let m = layout.lda(); + let n = layout.len(); + let k = m.min(n); + let mut tau = vec_uninit(k as usize); let mut info = 0; - let mut work_size = [Self::zero()]; - unsafe { - match l { - MatrixLayout::F { .. } => { - $qrf(m, n, &mut a, m, &mut tau, &mut work_size, -1, &mut info); - } - MatrixLayout::C { .. } => { - $lqf(m, n, &mut a, m, &mut tau, &mut work_size, -1, &mut info); - } - } + let mut work_size = [Self::Elem::zero()]; + match layout { + MatrixLayout::F { .. } => unsafe { + $qrf( + &m, + &n, + std::ptr::null_mut(), + &m, + AsPtr::as_mut_ptr(&mut tau), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }, + MatrixLayout::C { .. } => unsafe { + $lqf( + &m, + &n, + std::ptr::null_mut(), + &m, + AsPtr::as_mut_ptr(&mut tau), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }, } info.as_lapack_result()?; - - // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; - unsafe { - match l { - MatrixLayout::F { .. } => { - $qrf( - m, - n, - &mut a, - m, - &mut tau, - &mut work, - lwork as i32, - &mut info, - ); - } - MatrixLayout::C { .. } => { - $lqf( - m, - n, - &mut a, - m, - &mut tau, - &mut work, - lwork as i32, - &mut info, - ); - } - } + let work = vec_uninit(lwork); + Ok(HouseholderWork { + n, + m, + layout, + tau, + work, + }) + } + + fn calc(&mut self, a: &mut [Self::Elem]) -> Result<&[Self::Elem]> { + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + match self.layout { + MatrixLayout::F { .. } => unsafe { + $qrf( + &self.m, + &self.n, + AsPtr::as_mut_ptr(a), + &self.m, + AsPtr::as_mut_ptr(&mut self.tau), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ); + }, + MatrixLayout::C { .. 
} => unsafe { + $lqf( + &self.m, + &self.n, + AsPtr::as_mut_ptr(a), + &self.m, + AsPtr::as_mut_ptr(&mut self.tau), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ); + }, } info.as_lapack_result()?; + Ok(unsafe { self.tau.slice_assume_init_ref() }) + } - Ok(tau) + fn eval(mut self, a: &mut [Self::Elem]) -> Result> { + let _eig = self.calc(a)?; + Ok(unsafe { self.tau.assume_init() }) } + } + }; +} +impl_householder_work!(c64, lapack_sys::zgeqrf_, lapack_sys::zgelqf_); +impl_householder_work!(c32, lapack_sys::cgeqrf_, lapack_sys::cgelqf_); +impl_householder_work!(f64, lapack_sys::dgeqrf_, lapack_sys::dgelqf_); +impl_householder_work!(f32, lapack_sys::sgeqrf_, lapack_sys::sgelqf_); - fn q(l: MatrixLayout, mut a: &mut [Self], tau: &[Self]) -> Result<()> { - let m = l.lda(); - let n = l.len(); - let k = m.min(n); - assert_eq!(tau.len(), k as usize); +pub struct QWork { + pub layout: MatrixLayout, + pub work: Vec>, +} - // eval work size - let mut info = 0; - let mut work_size = [Self::zero()]; - unsafe { - match l { - MatrixLayout::F { .. } => { - $gqr(m, k, k, &mut a, m, &tau, &mut work_size, -1, &mut info) - } - MatrixLayout::C { .. } => { - $glq(k, n, k, &mut a, m, &tau, &mut work_size, -1, &mut info) - } - } - }; +pub trait QWorkImpl: Sized { + type Elem: Scalar; + fn new(layout: MatrixLayout) -> Result; + fn calc(&mut self, a: &mut [Self::Elem], tau: &[Self::Elem]) -> Result<()>; +} - // calc +macro_rules! impl_q_work { + ($s:ty, $gqr:path, $glq:path) => { + impl QWorkImpl for QWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout) -> Result { + let m = layout.lda(); + let n = layout.len(); + let k = m.min(n); + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + match layout { + MatrixLayout::F { .. } => unsafe { + $gqr( + &m, + &k, + &k, + std::ptr::null_mut(), + &m, + std::ptr::null_mut(), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }, + MatrixLayout::C { .. 
} => unsafe { + $glq( + &k, + &n, + &k, + std::ptr::null_mut(), + &m, + std::ptr::null_mut(), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }, + } let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; - unsafe { - match l { - MatrixLayout::F { .. } => { - $gqr(m, k, k, &mut a, m, &tau, &mut work, lwork as i32, &mut info) - } - MatrixLayout::C { .. } => { - $glq(k, n, k, &mut a, m, &tau, &mut work, lwork as i32, &mut info) - } - } + let work = vec_uninit(lwork); + Ok(QWork { layout, work }) + } + + fn calc(&mut self, a: &mut [Self::Elem], tau: &[Self::Elem]) -> Result<()> { + let m = self.layout.lda(); + let n = self.layout.len(); + let k = m.min(n); + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + match self.layout { + MatrixLayout::F { .. } => unsafe { + $gqr( + &m, + &k, + &k, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_ptr(&tau), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ) + }, + MatrixLayout::C { .. 
} => unsafe { + $glq( + &k, + &n, + &k, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_ptr(&tau), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ) + }, } info.as_lapack_result()?; Ok(()) } - - fn qr(l: MatrixLayout, a: &mut [Self]) -> Result> { - let tau = Self::householder(l, a)?; - let r = Vec::from(&*a); - Self::q(l, a, &tau)?; - Ok(r) - } } }; -} // endmacro +} -impl_qr!( - f64, - lapack::dgeqrf, - lapack::dgelqf, - lapack::dorgqr, - lapack::dorglq -); -impl_qr!( - f32, - lapack::sgeqrf, - lapack::sgelqf, - lapack::sorgqr, - lapack::sorglq -); -impl_qr!( - c64, - lapack::zgeqrf, - lapack::zgelqf, - lapack::zungqr, - lapack::zunglq -); -impl_qr!( - c32, - lapack::cgeqrf, - lapack::cgelqf, - lapack::cungqr, - lapack::cunglq -); +impl_q_work!(c64, lapack_sys::zungqr_, lapack_sys::zunglq_); +impl_q_work!(c32, lapack_sys::cungqr_, lapack_sys::cunglq_); +impl_q_work!(f64, lapack_sys::dorgqr_, lapack_sys::dorglq_); +impl_q_work!(f32, lapack_sys::sorgqr_, lapack_sys::sorglq_); diff --git a/lax/src/rcond.rs b/lax/src/rcond.rs index 91d7458c..4d4a4c92 100644 --- a/lax/src/rcond.rs +++ b/lax/src/rcond.rs @@ -1,85 +1,124 @@ +//! Reciprocal conditional number + use crate::{error::*, layout::MatrixLayout, *}; use cauchy::*; use num_traits::Zero; -pub trait Rcond_: Scalar + Sized { - /// Estimates the the reciprocal of the condition number of the matrix in 1-norm. - /// - /// `anorm` should be the 1-norm of the matrix `a`. - fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result; +pub struct RcondWork { + pub layout: MatrixLayout, + pub work: Vec>, + pub rwork: Option>>, + pub iwork: Option>>, } -macro_rules! 
impl_rcond_real { - ($scalar:ty, $gecon:path) => { - impl Rcond_ for $scalar { - fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { - let (n, _) = l.size(); - let mut rcond = Self::Real::zero(); - let mut info = 0; +pub trait RcondWorkImpl { + type Elem: Scalar; + fn new(l: MatrixLayout) -> Self; + fn calc( + &mut self, + a: &[Self::Elem], + anorm: ::Real, + ) -> Result<::Real>; +} + +macro_rules! impl_rcond_work_c { + ($c:ty, $con:path) => { + impl RcondWorkImpl for RcondWork<$c> { + type Elem = $c; - let mut work = unsafe { vec_uninit(4 * n as usize) }; - let mut iwork = unsafe { vec_uninit(n as usize) }; - let norm_type = match l { + fn new(layout: MatrixLayout) -> Self { + let (n, _) = layout.size(); + let work = vec_uninit(2 * n as usize); + let rwork = vec_uninit(2 * n as usize); + RcondWork { + layout, + work, + rwork: Some(rwork), + iwork: None, + } + } + + fn calc( + &mut self, + a: &[Self::Elem], + anorm: ::Real, + ) -> Result<::Real> { + let (n, _) = self.layout.size(); + let mut rcond = ::Real::zero(); + let mut info = 0; + let norm_type = match self.layout { MatrixLayout::C { .. } => NormType::Infinity, MatrixLayout::F { .. } => NormType::One, - } as u8; + }; unsafe { - $gecon( - norm_type, - n, - a, - l.lda(), - anorm, + $con( + norm_type.as_ptr(), + &n, + AsPtr::as_ptr(a), + &self.layout.lda(), + &anorm, &mut rcond, - &mut work, - &mut iwork, + AsPtr::as_mut_ptr(&mut self.work), + AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()), &mut info, ) }; info.as_lapack_result()?; - Ok(rcond) } } }; } +impl_rcond_work_c!(c64, lapack_sys::zgecon_); +impl_rcond_work_c!(c32, lapack_sys::cgecon_); -impl_rcond_real!(f32, lapack::sgecon); -impl_rcond_real!(f64, lapack::dgecon); +macro_rules! 
impl_rcond_work_r { + ($r:ty, $con:path) => { + impl RcondWorkImpl for RcondWork<$r> { + type Elem = $r; + + fn new(layout: MatrixLayout) -> Self { + let (n, _) = layout.size(); + let work = vec_uninit(4 * n as usize); + let iwork = vec_uninit(n as usize); + RcondWork { + layout, + work, + rwork: None, + iwork: Some(iwork), + } + } -macro_rules! impl_rcond_complex { - ($scalar:ty, $gecon:path) => { - impl Rcond_ for $scalar { - fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { - let (n, _) = l.size(); - let mut rcond = Self::Real::zero(); + fn calc( + &mut self, + a: &[Self::Elem], + anorm: ::Real, + ) -> Result<::Real> { + let (n, _) = self.layout.size(); + let mut rcond = ::Real::zero(); let mut info = 0; - let mut work = unsafe { vec_uninit(2 * n as usize) }; - let mut rwork = unsafe { vec_uninit(2 * n as usize) }; - let norm_type = match l { + let norm_type = match self.layout { MatrixLayout::C { .. } => NormType::Infinity, MatrixLayout::F { .. } => NormType::One, - } as u8; + }; unsafe { - $gecon( - norm_type, - n, - a, - l.lda(), - anorm, + $con( + norm_type.as_ptr(), + &n, + AsPtr::as_ptr(a), + &self.layout.lda(), + &anorm, &mut rcond, - &mut work, - &mut rwork, + AsPtr::as_mut_ptr(&mut self.work), + AsPtr::as_mut_ptr(self.iwork.as_mut().unwrap()), &mut info, ) }; info.as_lapack_result()?; - Ok(rcond) } } }; } - -impl_rcond_complex!(c32, lapack::cgecon); -impl_rcond_complex!(c64, lapack::zgecon); +impl_rcond_work_r!(f64, lapack_sys::dgecon_); +impl_rcond_work_r!(f32, lapack_sys::sgecon_); diff --git a/lax/src/solve.rs b/lax/src/solve.rs index 39498a04..63f69983 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -1,30 +1,25 @@ -//! Solve linear problem using LU decomposition +//! 
Solve linear equations using LU-decomposition use crate::{error::*, layout::MatrixLayout, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -pub trait Solve_: Scalar + Sized { - /// Computes the LU factorization of a general `m x n` matrix `a` using - /// partial pivoting with row interchanges. - /// - /// $ PA = LU $ - /// - /// Error - /// ------ - /// - `LapackComputationalFailure { return_code }` when the matrix is singular - /// - Division by zero will occur if it is used to solve a system of equations - /// because `U[(return_code-1, return_code-1)]` is exactly zero. +/// Helper trait to abstract `*getrf` LAPACK routines for implementing [Lapack::lu] +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | sgetrf | dgetrf | cgetrf | zgetrf | +/// +pub trait LuImpl: Scalar { fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; - - fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; - - fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; } -macro_rules! impl_solve { - ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { - impl Solve_ for $scalar { +macro_rules! impl_lu { + ($scalar:ty, $getrf:path) => { + impl LuImpl for $scalar { fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { let (row, col) = l.size(); assert_eq!(a.len() as i32, row * col); @@ -33,41 +28,71 @@ macro_rules! 
impl_solve { return Ok(Vec::new()); } let k = ::std::cmp::min(row, col); - let mut ipiv = unsafe { vec_uninit(k as usize) }; - let mut info = 0; - unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) }; - info.as_lapack_result()?; - Ok(ipiv) - } - - fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { - let (n, _) = l.size(); - - // calc work size + let mut ipiv = vec_uninit(k as usize); let mut info = 0; - let mut work_size = [Self::zero()]; - unsafe { $getri(n, a, l.lda(), ipiv, &mut work_size, -1, &mut info) }; - info.as_lapack_result()?; - - // actual - let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; unsafe { - $getri( - l.len(), - a, - l.lda(), - ipiv, - &mut work, - lwork as i32, + $getrf( + &l.lda(), + &l.len(), + AsPtr::as_mut_ptr(a), + &l.lda(), + AsPtr::as_mut_ptr(&mut ipiv), &mut info, ) }; info.as_lapack_result()?; - - Ok(()) + let ipiv = unsafe { ipiv.assume_init() }; + Ok(ipiv) } + } + }; +} + +impl_lu!(c64, lapack_sys::zgetrf_); +impl_lu!(c32, lapack_sys::cgetrf_); +impl_lu!(f64, lapack_sys::dgetrf_); +impl_lu!(f32, lapack_sys::sgetrf_); + +#[cfg_attr(doc, katexit::katexit)] +/// Helper trait to abstract `*getrs` LAPACK routines for implementing [Lapack::solve] +/// +/// If the array has C layout, then it needs to be handled +/// specially, since LAPACK expects a Fortran-layout array. +/// Reinterpreting a C layout array as Fortran layout is +/// equivalent to transposing it. So, we can handle the "no +/// transpose" and "transpose" cases by swapping to "transpose" +/// or "no transpose", respectively. 
For the "Hermite" case, we +/// can take advantage of the following: +/// +/// $$ +/// \begin{align*} +/// A^H x &= b \\\\ +/// \Leftrightarrow \overline{A^T} x &= b \\\\ +/// \Leftrightarrow \overline{\overline{A^T} x} &= \overline{b} \\\\ +/// \Leftrightarrow \overline{\overline{A^T}} \overline{x} &= \overline{b} \\\\ +/// \Leftrightarrow A^T \overline{x} &= \overline{b} +/// \end{align*} +/// $$ +/// +/// So, we can handle this case by switching to "no transpose" +/// (which is equivalent to transposing the array since it will +/// be reinterpreted as Fortran layout) and applying the +/// elementwise conjugate to `x` and `b`. +/// +pub trait SolveImpl: Scalar { + /// LAPACK correspondance + /// ---------------------- + /// + /// | f32 | f64 | c32 | c64 | + /// |:-------|:-------|:-------|:-------| + /// | sgetrs | dgetrs | cgetrs | zgetrs | + /// + fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; +} +macro_rules! impl_solve { + ($scalar:ty, $getrs:path) => { + impl SolveImpl for $scalar { fn solve( l: MatrixLayout, t: Transpose, @@ -75,18 +100,41 @@ macro_rules! impl_solve { ipiv: &Pivot, b: &mut [Self], ) -> Result<()> { - let t = match l { + let (t, conj) = match l { MatrixLayout::C { .. } => match t { - Transpose::No => Transpose::Transpose, - Transpose::Transpose | Transpose::Hermite => Transpose::No, + Transpose::No => (Transpose::Transpose, false), + Transpose::Transpose => (Transpose::No, false), + Transpose::Hermite => (Transpose::No, true), }, - _ => t, + MatrixLayout::F { .. 
} => (t, false), }; let (n, _) = l.size(); let nrhs = 1; let ldb = l.lda(); let mut info = 0; - unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) }; + if conj { + for b_elem in &mut *b { + *b_elem = b_elem.conj(); + } + } + unsafe { + $getrs( + t.as_ptr(), + &n, + &nrhs, + AsPtr::as_ptr(a), + &l.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(b), + &ldb, + &mut info, + ) + }; + if conj { + for b_elem in &mut *b { + *b_elem = b_elem.conj(); + } + } info.as_lapack_result()?; Ok(()) } @@ -94,7 +142,83 @@ macro_rules! impl_solve { }; } // impl_solve! -impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs); -impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs); -impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs); -impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs); +impl_solve!(f64, lapack_sys::dgetrs_); +impl_solve!(f32, lapack_sys::sgetrs_); +impl_solve!(c64, lapack_sys::zgetrs_); +impl_solve!(c32, lapack_sys::cgetrs_); + +/// Working memory for computing inverse matrix +pub struct InvWork { + pub layout: MatrixLayout, + pub work: Vec>, +} + +/// Helper trait to abstract `*getri` LAPACK rotuines for implementing [Lapack::inv] +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | sgetri | dgetri | cgetri | zgetri | +/// +pub trait InvWorkImpl: Sized { + type Elem: Scalar; + fn new(layout: MatrixLayout) -> Result; + fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()>; +} + +macro_rules! 
impl_inv_work { + ($s:ty, $tri:path) => { + impl InvWorkImpl for InvWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout) -> Result { + let (n, _) = layout.size(); + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + unsafe { + $tri( + &n, + std::ptr::null_mut(), + &layout.lda(), + std::ptr::null(), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }; + info.as_lapack_result()?; + let lwork = work_size[0].to_usize().unwrap(); + let work = vec_uninit(lwork); + Ok(InvWork { layout, work }) + } + + fn calc(&mut self, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> { + if self.layout.len() == 0 { + return Ok(()); + } + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + unsafe { + $tri( + &self.layout.len(), + AsPtr::as_mut_ptr(a), + &self.layout.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ) + }; + info.as_lapack_result()?; + Ok(()) + } + } + }; +} + +impl_inv_work!(c64, lapack_sys::zgetri_); +impl_inv_work!(c32, lapack_sys::cgetri_); +impl_inv_work!(f64, lapack_sys::dgetri_); +impl_inv_work!(f32, lapack_sys::sgetri_); diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index 1a4d6e3e..abb75cb8 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -1,75 +1,169 @@ -//! Solve symmetric linear problem using the Bunch-Kaufman diagonal pivoting method. +//! Factorize symmetric/Hermitian matrix using [Bunch-Kaufman diagonal pivoting method][BK] +//! +//! [BK]: https://doi.org/10.2307/2005787 //! -//! 
See also [the manual of dsytrf](http://www.netlib.org/lapack/lapack-3.1.1/html/dsytrf.f.html) use crate::{error::*, layout::MatrixLayout, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -pub trait Solveh_: Sized { - /// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf` - fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result; - /// Wrapper of `*sytri` and `*hetri` - fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>; - /// Wrapper of `*sytrs` and `*hetrs` - fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; +pub struct BkWork { + pub layout: MatrixLayout, + pub work: Vec>, + pub ipiv: Vec>, } -macro_rules! impl_solveh { - ($scalar:ty, $trf:path, $tri:path, $trs:path) => { - impl Solveh_ for $scalar { - fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result { - let (n, _) = l.size(); - let mut ipiv = unsafe { vec_uninit(n as usize) }; - if n == 0 { - return Ok(Vec::new()); - } +/// Factorize symmetric/Hermitian matrix using Bunch-Kaufman diagonal pivoting method +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | ssytrf | dsytrf | chetrf | zhetrf | +/// +pub trait BkWorkImpl: Sized { + type Elem: Scalar; + fn new(l: MatrixLayout) -> Result; + fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<&[i32]>; + fn eval(self, uplo: UPLO, a: &mut [Self::Elem]) -> Result; +} - // calc work size +macro_rules! 
impl_bk_work { + ($s:ty, $trf:path) => { + impl BkWorkImpl for BkWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout) -> Result { + let (n, _) = layout.size(); + let ipiv = vec_uninit(n as usize); let mut info = 0; - let mut work_size = [Self::zero()]; + let mut work_size = [Self::Elem::zero()]; unsafe { $trf( - uplo as u8, - n, - a, - l.lda(), - &mut ipiv, - &mut work_size, - -1, + UPLO::Upper.as_ptr(), + &n, + std::ptr::null_mut(), + &layout.lda(), + std::ptr::null_mut(), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), &mut info, ) }; info.as_lapack_result()?; - - // actual let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit(lwork) }; + let work = vec_uninit(lwork); + Ok(BkWork { layout, work, ipiv }) + } + + fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<&[i32]> { + let (n, _) = self.layout.size(); + let lwork = self.work.len().to_i32().unwrap(); + if lwork == 0 { + return Ok(&[]); + } + let mut info = 0; unsafe { $trf( - uplo as u8, - n, - a, - l.lda(), - &mut ipiv, - &mut work, - lwork as i32, + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &self.layout.lda(), + AsPtr::as_mut_ptr(&mut self.ipiv), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, &mut info, ) }; info.as_lapack_result()?; - Ok(ipiv) + Ok(unsafe { self.ipiv.slice_assume_init_ref() }) } - fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> { - let (n, _) = l.size(); + fn eval(mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result { + let _ref = self.calc(uplo, a)?; + Ok(unsafe { self.ipiv.assume_init() }) + } + } + }; +} +impl_bk_work!(c64, lapack_sys::zhetrf_); +impl_bk_work!(c32, lapack_sys::chetrf_); +impl_bk_work!(f64, lapack_sys::dsytrf_); +impl_bk_work!(f32, lapack_sys::ssytrf_); + +pub struct InvhWork { + pub layout: MatrixLayout, + pub work: Vec>, +} + +/// Compute inverse matrix of symmetric/Hermitian matrix +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// 
|:-------|:-------|:-------|:-------| +/// | ssytri | dsytri | chetri | zhetri | +/// +pub trait InvhWorkImpl: Sized { + type Elem; + fn new(layout: MatrixLayout) -> Result; + fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()>; +} + +macro_rules! impl_invh_work { + ($s:ty, $tri:path) => { + impl InvhWorkImpl for InvhWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout) -> Result { + let (n, _) = layout.size(); + let work = vec_uninit(n as usize); + Ok(InvhWork { layout, work }) + } + + fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> { + let (n, _) = self.layout.size(); let mut info = 0; - let mut work = unsafe { vec_uninit(n as usize) }; - unsafe { $tri(uplo as u8, n, a, l.lda(), ipiv, &mut work, &mut info) }; + unsafe { + $tri( + uplo.as_ptr(), + &n, + AsPtr::as_mut_ptr(a), + &self.layout.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(&mut self.work), + &mut info, + ) + }; info.as_lapack_result()?; Ok(()) } + } + }; +} +impl_invh_work!(c64, lapack_sys::zhetri_); +impl_invh_work!(c32, lapack_sys::chetri_); +impl_invh_work!(f64, lapack_sys::dsytri_); +impl_invh_work!(f32, lapack_sys::ssytri_); +/// Solve symmetric/Hermitian linear equation +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | ssytrs | dsytrs | chetrs | zhetrs | +/// +pub trait SolvehImpl: Scalar { + fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; +} + +macro_rules! impl_solveh_ { + ($s:ty, $trs:path) => { + impl SolvehImpl for $s { fn solveh( l: MatrixLayout, uplo: UPLO, @@ -79,15 +173,27 @@ macro_rules! 
impl_solveh { ) -> Result<()> { let (n, _) = l.size(); let mut info = 0; - unsafe { $trs(uplo as u8, n, 1, a, l.lda(), ipiv, b, n, &mut info) }; + unsafe { + $trs( + uplo.as_ptr(), + &n, + &1, + AsPtr::as_ptr(a), + &l.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(b), + &n, + &mut info, + ) + }; info.as_lapack_result()?; Ok(()) } } }; -} // impl_solveh! +} -impl_solveh!(f64, lapack::dsytrf, lapack::dsytri, lapack::dsytrs); -impl_solveh!(f32, lapack::ssytrf, lapack::ssytri, lapack::ssytrs); -impl_solveh!(c64, lapack::zhetrf, lapack::zhetri, lapack::zhetrs); -impl_solveh!(c32, lapack::chetrf, lapack::chetri, lapack::chetrs); +impl_solveh_!(c64, lapack_sys::zhetrs_); +impl_solveh_!(c32, lapack_sys::chetrs_); +impl_solveh_!(f64, lapack_sys::dsytrs_); +impl_solveh_!(f32, lapack_sys::ssytrs_); diff --git a/lax/src/svd.rs b/lax/src/svd.rs index c990cd27..fc695108 100644 --- a/lax/src/svd.rs +++ b/lax/src/svd.rs @@ -1,140 +1,314 @@ //! Singular-value decomposition +//! +//! LAPACK correspondance +//! ---------------------- +//! +//! | f32 | f64 | c32 | c64 | +//! |:-------|:-------|:-------|:-------| +//! | sgesvd | dgesvd | cgesvd | zgesvd | +//! 
-use crate::{error::*, layout::MatrixLayout, *}; +use super::{error::*, layout::*, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -#[repr(u8)] -#[derive(Debug, Copy, Clone)] -enum FlagSVD { - All = b'A', - // OverWrite = b'O', - // Separately = b'S', - No = b'N', +pub struct SvdWork { + pub ju: JobSvd, + pub jvt: JobSvd, + pub layout: MatrixLayout, + pub s: Vec>, + pub u: Option>>, + pub vt: Option>>, + pub work: Vec>, + pub rwork: Option>>, } -impl FlagSVD { - fn from_bool(calc_uv: bool) -> Self { - if calc_uv { - FlagSVD::All - } else { - FlagSVD::No - } - } +#[derive(Debug, Clone)] +pub struct SvdRef<'work, T: Scalar> { + pub s: &'work [T::Real], + pub u: Option<&'work [T]>, + pub vt: Option<&'work [T]>, } -/// Result of SVD -pub struct SVDOutput { - /// diagonal values - pub s: Vec, - /// Unitary matrix for destination space - pub u: Option>, - /// Unitary matrix for departure space - pub vt: Option>, +#[derive(Debug, Clone)] +pub struct SvdOwned { + pub s: Vec, + pub u: Option>, + pub vt: Option>, } -/// Wraps `*gesvd` -pub trait SVD_: Scalar { - /// Calculate singular value decomposition $ A = U \Sigma V^T $ - fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self]) - -> Result>; +pub trait SvdWorkImpl: Sized { + type Elem: Scalar; + fn new(layout: MatrixLayout, calc_u: bool, calc_vt: bool) -> Result; + fn calc(&mut self, a: &mut [Self::Elem]) -> Result>; + fn eval(self, a: &mut [Self::Elem]) -> Result>; } -macro_rules! impl_svd { - (@real, $scalar:ty, $gesvd:path) => { - impl_svd!(@body, $scalar, $gesvd, ); - }; - (@complex, $scalar:ty, $gesvd:path) => { - impl_svd!(@body, $scalar, $gesvd, rwork); - }; - (@body, $scalar:ty, $gesvd:path, $($rwork_ident:ident),*) => { - impl SVD_ for $scalar { - fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, mut a: &mut [Self],) -> Result> { - let ju = match l { - MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u), - MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt), +macro_rules! 
impl_svd_work_c { + ($s:ty, $svd:path) => { + impl SvdWorkImpl for SvdWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout, calc_u: bool, calc_vt: bool) -> Result { + let ju = match layout { + MatrixLayout::F { .. } => JobSvd::from_bool(calc_u), + MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt), }; - let jvt = match l { - MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt), - MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u), + let jvt = match layout { + MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt), + MatrixLayout::C { .. } => JobSvd::from_bool(calc_u), }; - let m = l.lda(); + let m = layout.lda(); let mut u = match ju { - FlagSVD::All => Some(unsafe { vec_uninit( (m * m) as usize) }), - FlagSVD::No => None, + JobSvd::All => Some(vec_uninit((m * m) as usize)), + JobSvd::None => None, + _ => unimplemented!("SVD with partial vector output is not supported yet"), }; - let n = l.len(); + let n = layout.len(); let mut vt = match jvt { - FlagSVD::All => Some(unsafe { vec_uninit( (n * n) as usize) }), - FlagSVD::No => None, + JobSvd::All => Some(vec_uninit((n * n) as usize)), + JobSvd::None => None, + _ => unimplemented!("SVD with partial vector output is not supported yet"), }; let k = std::cmp::min(m, n); - let mut s = unsafe { vec_uninit( k as usize) }; - - $( - let mut $rwork_ident = unsafe { vec_uninit( 5 * k as usize) }; - )* + let mut s = vec_uninit(k as usize); + let mut rwork = vec_uninit(5 * k as usize); // eval work size let mut info = 0; - let mut work_size = [Self::zero()]; + let mut work_size = [Self::Elem::zero()]; unsafe { - $gesvd( - ju as u8, - jvt as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work_size, - -1, - $(&mut $rwork_ident,)* + $svd( + ju.as_ptr(), + jvt.as_ptr(), + &m, + &n, + std::ptr::null_mut(), + &m, + AsPtr::as_mut_ptr(&mut s), + AsPtr::as_mut_ptr(u.as_mut().map(|x| 
x.as_mut_slice()).unwrap_or(&mut [])), + &m, + AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + AsPtr::as_mut_ptr(&mut rwork), &mut info, ); } info.as_lapack_result()?; - - // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit( lwork) }; + let work = vec_uninit(lwork); + Ok(SvdWork { + layout, + ju, + jvt, + s, + u, + vt, + work, + rwork: Some(rwork), + }) + } + + fn calc(&mut self, a: &mut [Self::Elem]) -> Result> { + let m = self.layout.lda(); + let n = self.layout.len(); + let lwork = self.work.len().to_i32().unwrap(); + + let mut info = 0; unsafe { - $gesvd( - ju as u8, - jvt as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - n, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* + $svd( + self.ju.as_ptr(), + self.jvt.as_ptr(), + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut self.s), + AsPtr::as_mut_ptr( + self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + ), + &m, + AsPtr::as_mut_ptr( + self.vt + .as_mut() + .map(|x| x.as_mut_slice()) + .unwrap_or(&mut []), + ), + &n, + AsPtr::as_mut_ptr(&mut self.work), + &(lwork as i32), + AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()), &mut info, ); } info.as_lapack_result()?; - match l { - MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), - MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }), + + let s = unsafe { self.s.slice_assume_init_ref() }; + let u = self + .u + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }); + let vt = self + .vt + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }); + + match self.layout { + MatrixLayout::F { .. } => Ok(SvdRef { s, u, vt }), + MatrixLayout::C { .. 
} => Ok(SvdRef { s, u: vt, vt: u }), + } + } + + fn eval(mut self, a: &mut [Self::Elem]) -> Result> { + let _ref = self.calc(a)?; + let s = unsafe { self.s.assume_init() }; + let u = self.u.map(|v| unsafe { v.assume_init() }); + let vt = self.vt.map(|v| unsafe { v.assume_init() }); + match self.layout { + MatrixLayout::F { .. } => Ok(SvdOwned { s, u, vt }), + MatrixLayout::C { .. } => Ok(SvdOwned { s, u: vt, vt: u }), } } } }; -} // impl_svd! +} +impl_svd_work_c!(c64, lapack_sys::zgesvd_); +impl_svd_work_c!(c32, lapack_sys::cgesvd_); + +macro_rules! impl_svd_work_r { + ($s:ty, $svd:path) => { + impl SvdWorkImpl for SvdWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout, calc_u: bool, calc_vt: bool) -> Result { + let ju = match layout { + MatrixLayout::F { .. } => JobSvd::from_bool(calc_u), + MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt), + }; + let jvt = match layout { + MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt), + MatrixLayout::C { .. } => JobSvd::from_bool(calc_u), + }; + + let m = layout.lda(); + let mut u = match ju { + JobSvd::All => Some(vec_uninit((m * m) as usize)), + JobSvd::None => None, + _ => unimplemented!("SVD with partial vector output is not supported yet"), + }; + + let n = layout.len(); + let mut vt = match jvt { + JobSvd::All => Some(vec_uninit((n * n) as usize)), + JobSvd::None => None, + _ => unimplemented!("SVD with partial vector output is not supported yet"), + }; + + let k = std::cmp::min(m, n); + let mut s = vec_uninit(k as usize); -impl_svd!(@real, f64, lapack::dgesvd); -impl_svd!(@real, f32, lapack::sgesvd); -impl_svd!(@complex, c64, lapack::zgesvd); -impl_svd!(@complex, c32, lapack::cgesvd); + // eval work size + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + unsafe { + $svd( + ju.as_ptr(), + jvt.as_ptr(), + &m, + &n, + std::ptr::null_mut(), + &m, + AsPtr::as_mut_ptr(&mut s), + AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &m, + 
AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &n, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ); + } + info.as_lapack_result()?; + let lwork = work_size[0].to_usize().unwrap(); + let work = vec_uninit(lwork); + Ok(SvdWork { + layout, + ju, + jvt, + s, + u, + vt, + work, + rwork: None, + }) + } + + fn calc(&mut self, a: &mut [Self::Elem]) -> Result> { + let m = self.layout.lda(); + let n = self.layout.len(); + let lwork = self.work.len().to_i32().unwrap(); + + let mut info = 0; + unsafe { + $svd( + self.ju.as_ptr(), + self.jvt.as_ptr(), + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut self.s), + AsPtr::as_mut_ptr( + self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + ), + &m, + AsPtr::as_mut_ptr( + self.vt + .as_mut() + .map(|x| x.as_mut_slice()) + .unwrap_or(&mut []), + ), + &n, + AsPtr::as_mut_ptr(&mut self.work), + &(lwork as i32), + &mut info, + ); + } + info.as_lapack_result()?; + + let s = unsafe { self.s.slice_assume_init_ref() }; + let u = self + .u + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }); + let vt = self + .vt + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }); + + match self.layout { + MatrixLayout::F { .. } => Ok(SvdRef { s, u, vt }), + MatrixLayout::C { .. } => Ok(SvdRef { s, u: vt, vt: u }), + } + } + + fn eval(mut self, a: &mut [Self::Elem]) -> Result> { + let _ref = self.calc(a)?; + let s = unsafe { self.s.assume_init() }; + let u = self.u.map(|v| unsafe { v.assume_init() }); + let vt = self.vt.map(|v| unsafe { v.assume_init() }); + match self.layout { + MatrixLayout::F { .. } => Ok(SvdOwned { s, u, vt }), + MatrixLayout::C { .. } => Ok(SvdOwned { s, u: vt, vt: u }), + } + } + } + }; +} +impl_svd_work_r!(f64, lapack_sys::dgesvd_); +impl_svd_work_r!(f32, lapack_sys::sgesvd_); diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index a94bdace..c16db4bb 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -1,125 +1,307 @@ +//! 
Compute singular value decomposition with divide-and-conquer algorithm +//! +//! LAPACK correspondance +//! ---------------------- +//! +//! | f32 | f64 | c32 | c64 | +//! |:-------|:-------|:-------|:-------| +//! | sgesdd | dgesdd | cgesdd | zgesdd | +//! + use crate::{error::*, layout::MatrixLayout, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned. -/// -/// For an input array of shape *m*×*n*, the following are computed: -#[derive(Clone, Copy, Eq, PartialEq)] -#[repr(u8)] -pub enum UVTFlag { - /// All *m* columns of *U* and all *n* rows of *V*ᵀ. - Full = b'A', - /// The first min(*m*,*n*) columns of *U* and the first min(*m*,*n*) rows of *V*ᵀ. - Some = b'S', - /// No columns of *U* or rows of *V*ᵀ. - None = b'N', +pub struct SvdDcWork { + pub jobz: JobSvd, + pub layout: MatrixLayout, + pub s: Vec>, + pub u: Option>>, + pub vt: Option>>, + pub work: Vec>, + pub iwork: Vec>, + pub rwork: Option>>, } -pub trait SVDDC_: Scalar { - fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; +pub trait SvdDcWorkImpl: Sized { + type Elem: Scalar; + fn new(layout: MatrixLayout, jobz: JobSvd) -> Result; + fn calc(&mut self, a: &mut [Self::Elem]) -> Result>; + fn eval(self, a: &mut [Self::Elem]) -> Result>; } -macro_rules! impl_svddc { - (@real, $scalar:ty, $gesdd:path) => { - impl_svddc!(@body, $scalar, $gesdd, ); - }; - (@complex, $scalar:ty, $gesdd:path) => { - impl_svddc!(@body, $scalar, $gesdd, rwork); - }; - (@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => { - impl SVDDC_ for $scalar { - fn svddc(l: MatrixLayout, jobz: UVTFlag, mut a: &mut [Self],) -> Result> { - let m = l.lda(); - let n = l.len(); - let k = m.min(n); - let mut s = unsafe { vec_uninit( k as usize) }; +macro_rules! 
impl_svd_dc_work_c { + ($s:ty, $sdd:path) => { + impl SvdDcWorkImpl for SvdDcWork<$s> { + type Elem = $s; + fn new(layout: MatrixLayout, jobz: JobSvd) -> Result { + let m = layout.lda(); + let n = layout.len(); + let k = m.min(n); let (u_col, vt_row) = match jobz { - UVTFlag::Full | UVTFlag::None => (m, n), - UVTFlag::Some => (k, k), + JobSvd::All | JobSvd::None => (m, n), + JobSvd::Some => (k, k), }; + + let mut s = vec_uninit(k as usize); let (mut u, mut vt) = match jobz { - UVTFlag::Full => ( - Some(unsafe { vec_uninit( (m * m) as usize) }), - Some(unsafe { vec_uninit( (n * n) as usize) }), + JobSvd::All => ( + Some(vec_uninit((m * m) as usize)), + Some(vec_uninit((n * n) as usize)), ), - UVTFlag::Some => ( - Some(unsafe { vec_uninit( (m * u_col) as usize) }), - Some(unsafe { vec_uninit( (n * vt_row) as usize) }), + JobSvd::Some => ( + Some(vec_uninit((m * u_col) as usize)), + Some(vec_uninit((n * vt_row) as usize)), ), - UVTFlag::None => (None, None), + JobSvd::None => (None, None), }; + let mut iwork = vec_uninit(8 * k as usize); - $( // for complex only let mx = n.max(m) as usize; let mn = n.min(m) as usize; let lrwork = match jobz { - UVTFlag::None => 7 * mn, - _ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn), + JobSvd::None => 7 * mn, + _ => std::cmp::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn), }; - let mut $rwork_ident = unsafe { vec_uninit( lrwork) }; - )* + let mut rwork = vec_uninit(lrwork); - // eval work size let mut info = 0; - let mut iwork = unsafe { vec_uninit( 8 * k as usize) }; - let mut work_size = [Self::zero()]; + let mut work_size = [Self::Elem::zero()]; unsafe { - $gesdd( - jobz as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - vt_row, - &mut work_size, - -1, - $(&mut $rwork_ident,)* - &mut iwork, + $sdd( + jobz.as_ptr(), + &m, + &n, + std::ptr::null_mut(), + &m, + AsPtr::as_mut_ptr(&mut s), + 
AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &m, + AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &vt_row, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + AsPtr::as_mut_ptr(&mut rwork), + AsPtr::as_mut_ptr(&mut iwork), &mut info, ); } info.as_lapack_result()?; - - // do svd let lwork = work_size[0].to_usize().unwrap(); - let mut work = unsafe { vec_uninit( lwork) }; + let work = vec_uninit(lwork); + Ok(SvdDcWork { + layout, + jobz, + iwork, + work, + rwork: Some(rwork), + u, + vt, + s, + }) + } + + fn calc(&mut self, a: &mut [Self::Elem]) -> Result> { + let m = self.layout.lda(); + let n = self.layout.len(); + let k = m.min(n); + let (_, vt_row) = match self.jobz { + JobSvd::All | JobSvd::None => (m, n), + JobSvd::Some => (k, k), + }; + let lwork = self.work.len().to_i32().unwrap(); + + let mut info = 0; unsafe { - $gesdd( - jobz as u8, - m, - n, - &mut a, - m, - &mut s, - u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - m, - vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), - vt_row, - &mut work, - lwork as i32, - $(&mut $rwork_ident,)* - &mut iwork, + $sdd( + self.jobz.as_ptr(), + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut self.s), + AsPtr::as_mut_ptr( + self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + ), + &m, + AsPtr::as_mut_ptr( + self.vt + .as_mut() + .map(|x| x.as_mut_slice()) + .unwrap_or(&mut []), + ), + &vt_row, + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()), + AsPtr::as_mut_ptr(&mut self.iwork), &mut info, ); } info.as_lapack_result()?; - match l { - MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), - MatrixLayout::C { .. 
} => Ok(SVDOutput { s, u: vt, vt: u }), - } + let s = unsafe { self.s.slice_assume_init_ref() }; + let u = self + .u + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }); + let vt = self + .vt + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }); + + Ok(match self.layout { + MatrixLayout::F { .. } => SvdRef { s, u, vt }, + MatrixLayout::C { .. } => SvdRef { s, u: vt, vt: u }, + }) + } + + fn eval(mut self, a: &mut [Self::Elem]) -> Result> { + let _ref = self.calc(a)?; + let s = unsafe { self.s.assume_init() }; + let u = self.u.map(|v| unsafe { v.assume_init() }); + let vt = self.vt.map(|v| unsafe { v.assume_init() }); + Ok(match self.layout { + MatrixLayout::F { .. } => SvdOwned { s, u, vt }, + MatrixLayout::C { .. } => SvdOwned { s, u: vt, vt: u }, + }) } } }; } +impl_svd_dc_work_c!(c64, lapack_sys::zgesdd_); +impl_svd_dc_work_c!(c32, lapack_sys::cgesdd_); + +macro_rules! impl_svd_dc_work_r { + ($s:ty, $sdd:path) => { + impl SvdDcWorkImpl for SvdDcWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout, jobz: JobSvd) -> Result { + let m = layout.lda(); + let n = layout.len(); + let k = m.min(n); + let (u_col, vt_row) = match jobz { + JobSvd::All | JobSvd::None => (m, n), + JobSvd::Some => (k, k), + }; + + let mut s = vec_uninit(k as usize); + let (mut u, mut vt) = match jobz { + JobSvd::All => ( + Some(vec_uninit((m * m) as usize)), + Some(vec_uninit((n * n) as usize)), + ), + JobSvd::Some => ( + Some(vec_uninit((m * u_col) as usize)), + Some(vec_uninit((n * vt_row) as usize)), + ), + JobSvd::None => (None, None), + }; + let mut iwork = vec_uninit(8 * k as usize); -impl_svddc!(@real, f32, lapack::sgesdd); -impl_svddc!(@real, f64, lapack::dgesdd); -impl_svddc!(@complex, c32, lapack::cgesdd); -impl_svddc!(@complex, c64, lapack::zgesdd); + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + unsafe { + $sdd( + jobz.as_ptr(), + &m, + &n, + std::ptr::null_mut(), + &m, + AsPtr::as_mut_ptr(&mut s), + AsPtr::as_mut_ptr(u.as_mut().map(|x| 
x.as_mut_slice()).unwrap_or(&mut [])), + &m, + AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), + &vt_row, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + AsPtr::as_mut_ptr(&mut iwork), + &mut info, + ); + } + info.as_lapack_result()?; + let lwork = work_size[0].to_usize().unwrap(); + let work = vec_uninit(lwork); + Ok(SvdDcWork { + layout, + jobz, + iwork, + work, + rwork: None, + u, + vt, + s, + }) + } + + fn calc(&mut self, a: &mut [Self::Elem]) -> Result> { + let m = self.layout.lda(); + let n = self.layout.len(); + let k = m.min(n); + let (_, vt_row) = match self.jobz { + JobSvd::All | JobSvd::None => (m, n), + JobSvd::Some => (k, k), + }; + let lwork = self.work.len().to_i32().unwrap(); + + let mut info = 0; + unsafe { + $sdd( + self.jobz.as_ptr(), + &m, + &n, + AsPtr::as_mut_ptr(a), + &m, + AsPtr::as_mut_ptr(&mut self.s), + AsPtr::as_mut_ptr( + self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + ), + &m, + AsPtr::as_mut_ptr( + self.vt + .as_mut() + .map(|x| x.as_mut_slice()) + .unwrap_or(&mut []), + ), + &vt_row, + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + AsPtr::as_mut_ptr(&mut self.iwork), + &mut info, + ); + } + info.as_lapack_result()?; + + let s = unsafe { self.s.slice_assume_init_ref() }; + let u = self + .u + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }); + let vt = self + .vt + .as_ref() + .map(|v| unsafe { v.slice_assume_init_ref() }); + + Ok(match self.layout { + MatrixLayout::F { .. } => SvdRef { s, u, vt }, + MatrixLayout::C { .. } => SvdRef { s, u: vt, vt: u }, + }) + } + + fn eval(mut self, a: &mut [Self::Elem]) -> Result> { + let _ref = self.calc(a)?; + let s = unsafe { self.s.assume_init() }; + let u = self.u.map(|v| unsafe { v.assume_init() }); + let vt = self.vt.map(|v| unsafe { v.assume_init() }); + Ok(match self.layout { + MatrixLayout::F { .. } => SvdOwned { s, u, vt }, + MatrixLayout::C { .. 
} => SvdOwned { s, u: vt, vt: u }, + }) + } + } + }; +} +impl_svd_dc_work_r!(f64, lapack_sys::dgesdd_); +impl_svd_dc_work_r!(f32, lapack_sys::sgesdd_); diff --git a/lax/src/triangular.rs b/lax/src/triangular.rs index a48b12b3..da4dc4cf 100644 --- a/lax/src/triangular.rs +++ b/lax/src/triangular.rs @@ -1,17 +1,18 @@ -//! Implement linear solver and inverse matrix +//! Linear problem for triangular matrices use crate::{error::*, layout::*, *}; use cauchy::*; -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum Diag { - Unit = b'U', - NonUnit = b'N', -} - -/// Wraps `*trtri` and `*trtrs` -pub trait Triangular_: Scalar { +/// Solve linear problem for triangular matrices +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | strtrs | dtrtrs | ctrtrs | ztrtrs | +/// +pub trait SolveTriangularImpl: Scalar { fn solve_triangular( al: MatrixLayout, bl: MatrixLayout, @@ -23,8 +24,8 @@ pub trait Triangular_: Scalar { } macro_rules! impl_triangular { - ($scalar:ty, $trtri:path, $trtrs:path) => { - impl Triangular_ for $scalar { + ($scalar:ty, $trtrs:path) => { + impl SolveTriangularImpl for $scalar { fn solve_triangular( a_layout: MatrixLayout, b_layout: MatrixLayout, @@ -37,8 +38,9 @@ macro_rules! impl_triangular { let mut a_t = None; let a_layout = match a_layout { MatrixLayout::C { .. } => { - a_t = Some(unsafe { vec_uninit(a.len()) }); - transpose(a_layout, a, a_t.as_mut().unwrap()) + let (layout, t) = transpose(a_layout, a); + a_t = Some(t); + layout } MatrixLayout::F { .. } => a_layout, }; @@ -47,8 +49,9 @@ macro_rules! impl_triangular { let mut b_t = None; let b_layout = match b_layout { MatrixLayout::C { .. } => { - b_t = Some(unsafe { vec_uninit(b.len()) }); - transpose(b_layout, b, b_t.as_mut().unwrap()) + let (layout, t) = transpose(b_layout, b); + b_t = Some(t); + layout } MatrixLayout::F { .. } => b_layout, }; @@ -60,15 +63,15 @@ macro_rules! 
impl_triangular { let mut info = 0; unsafe { $trtrs( - uplo as u8, - Transpose::No as u8, - diag as u8, - m, - nrhs, - a_t.as_ref().map(|v| v.as_slice()).unwrap_or(a), - a_layout.lda(), - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - b_layout.lda(), + uplo.as_ptr(), + Transpose::No.as_ptr(), + diag.as_ptr(), + &m, + &nrhs, + AsPtr::as_ptr(a_t.as_ref().map(|v| v.as_slice()).unwrap_or(a)), + &a_layout.lda(), + AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), + &b_layout.lda(), &mut info, ); } @@ -76,7 +79,7 @@ macro_rules! impl_triangular { // Re-transpose b if let Some(b_t) = b_t { - transpose(b_layout, &b_t, b); + transpose_over(b_layout, &b_t, b); } Ok(()) } @@ -84,7 +87,7 @@ macro_rules! impl_triangular { }; } // impl_triangular! -impl_triangular!(f64, lapack::dtrtri, lapack::dtrtrs); -impl_triangular!(f32, lapack::strtri, lapack::strtrs); -impl_triangular!(c64, lapack::ztrtri, lapack::ztrtrs); -impl_triangular!(c32, lapack::ctrtri, lapack::ctrtrs); +impl_triangular!(f64, lapack_sys::dtrtrs_); +impl_triangular!(f32, lapack_sys::strtrs_); +impl_triangular!(c64, lapack_sys::ztrtrs_); +impl_triangular!(c32, lapack_sys::ctrtrs_); diff --git a/lax/src/tridiagonal.rs b/lax/src/tridiagonal.rs deleted file mode 100644 index f6cbfc7e..00000000 --- a/lax/src/tridiagonal.rs +++ /dev/null @@ -1,246 +0,0 @@ -//! Implement linear solver using LU decomposition -//! for tridiagonal matrix - -use crate::{error::*, layout::*, *}; -use cauchy::*; -use num_traits::Zero; -use std::ops::{Index, IndexMut}; - -/// Represents a tridiagonal matrix as 3 one-dimensional vectors. -/// -/// ```text -/// [d0, u1, 0, ..., 0, -/// l1, d1, u2, ..., -/// 0, l2, d2, -/// ... ..., u{n-1}, -/// 0, ..., l{n-1}, d{n-1},] -/// ``` -#[derive(Clone, PartialEq)] -pub struct Tridiagonal { - /// layout of raw matrix - pub l: MatrixLayout, - /// (n-1) sub-diagonal elements of matrix. - pub dl: Vec, - /// (n) diagonal elements of matrix. 
- pub d: Vec, - /// (n-1) super-diagonal elements of matrix. - pub du: Vec, -} - -impl Tridiagonal { - fn opnorm_one(&self) -> A::Real { - let mut col_sum: Vec = self.d.iter().map(|val| val.abs()).collect(); - for i in 0..col_sum.len() { - if i < self.dl.len() { - col_sum[i] += self.dl[i].abs(); - } - if i > 0 { - col_sum[i] += self.du[i - 1].abs(); - } - } - let mut max = A::Real::zero(); - for &val in &col_sum { - if max < val { - max = val; - } - } - max - } -} - -/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`. -#[derive(Clone, PartialEq)] -pub struct LUFactorizedTridiagonal { - /// A tridiagonal matrix which consists of - /// - l : layout of raw matrix - /// - dl: (n-1) multipliers that define the matrix L. - /// - d : (n) diagonal elements of the upper triangular matrix U. - /// - du: (n-1) elements of the first super-diagonal of U. - pub a: Tridiagonal, - /// (n-2) elements of the second super-diagonal of U. - pub du2: Vec, - /// The pivot indices that define the permutation matrix `P`. 
- pub ipiv: Pivot, - - a_opnorm_one: A::Real, -} - -impl Index<(i32, i32)> for Tridiagonal { - type Output = A; - #[inline] - fn index(&self, (row, col): (i32, i32)) -> &A { - let (n, _) = self.l.size(); - assert!( - std::cmp::max(row, col) < n, - "ndarray: index {:?} is out of bounds for array of shape {}", - [row, col], - n - ); - match row - col { - 0 => &self.d[row as usize], - 1 => &self.dl[col as usize], - -1 => &self.du[row as usize], - _ => panic!( - "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element", - [row, col] - ), - } - } -} - -impl Index<[i32; 2]> for Tridiagonal { - type Output = A; - #[inline] - fn index(&self, [row, col]: [i32; 2]) -> &A { - &self[(row, col)] - } -} - -impl IndexMut<(i32, i32)> for Tridiagonal { - #[inline] - fn index_mut(&mut self, (row, col): (i32, i32)) -> &mut A { - let (n, _) = self.l.size(); - assert!( - std::cmp::max(row, col) < n, - "ndarray: index {:?} is out of bounds for array of shape {}", - [row, col], - n - ); - match row - col { - 0 => &mut self.d[row as usize], - 1 => &mut self.dl[col as usize], - -1 => &mut self.du[row as usize], - _ => panic!( - "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element", - [row, col] - ), - } - } -} - -impl IndexMut<[i32; 2]> for Tridiagonal { - #[inline] - fn index_mut(&mut self, [row, col]: [i32; 2]) -> &mut A { - &mut self[(row, col)] - } -} - -/// Wraps `*gttrf`, `*gtcon` and `*gttrs` -pub trait Tridiagonal_: Scalar + Sized { - /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using - /// partial pivoting with row interchanges. - fn lu_tridiagonal(a: Tridiagonal) -> Result>; - - fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result; - - fn solve_tridiagonal( - lu: &LUFactorizedTridiagonal, - bl: MatrixLayout, - t: Transpose, - b: &mut [Self], - ) -> Result<()>; -} - -macro_rules! 
impl_tridiagonal { - (@real, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { - impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, iwork); - }; - (@complex, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { - impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, ); - }; - (@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => { - impl Tridiagonal_ for $scalar { - fn lu_tridiagonal(mut a: Tridiagonal) -> Result> { - let (n, _) = a.l.size(); - let mut du2 = unsafe { vec_uninit( (n - 2) as usize) }; - let mut ipiv = unsafe { vec_uninit( n as usize) }; - // We have to calc one-norm before LU factorization - let a_opnorm_one = a.opnorm_one(); - let mut info = 0; - unsafe { $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info,) }; - info.as_lapack_result()?; - Ok(LUFactorizedTridiagonal { - a, - du2, - ipiv, - a_opnorm_one, - }) - } - - fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { - let (n, _) = lu.a.l.size(); - let ipiv = &lu.ipiv; - let mut work = unsafe { vec_uninit( 2 * n as usize) }; - $( - let mut $iwork = unsafe { vec_uninit( n as usize) }; - )* - let mut rcond = Self::Real::zero(); - let mut info = 0; - unsafe { - $gtcon( - NormType::One as u8, - n, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - lu.a_opnorm_one, - &mut rcond, - &mut work, - $(&mut $iwork,)* - &mut info, - ); - } - info.as_lapack_result()?; - Ok(rcond) - } - - fn solve_tridiagonal( - lu: &LUFactorizedTridiagonal, - b_layout: MatrixLayout, - t: Transpose, - b: &mut [Self], - ) -> Result<()> { - let (n, _) = lu.a.l.size(); - let ipiv = &lu.ipiv; - // Transpose if b is C-continuous - let mut b_t = None; - let b_layout = match b_layout { - MatrixLayout::C { .. } => { - b_t = Some(unsafe { vec_uninit( b.len()) }); - transpose(b_layout, b, b_t.as_mut().unwrap()) - } - MatrixLayout::F { .. 
} => b_layout, - }; - let (ldb, nrhs) = b_layout.size(); - let mut info = 0; - unsafe { - $gttrs( - t as u8, - n, - nrhs, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - ldb, - &mut info, - ); - } - info.as_lapack_result()?; - if let Some(b_t) = b_t { - transpose(b_layout, &b_t, b); - } - Ok(()) - } - } - }; -} // impl_tridiagonal! - -impl_tridiagonal!(@real, f64, lapack::dgttrf, lapack::dgtcon, lapack::dgttrs); -impl_tridiagonal!(@real, f32, lapack::sgttrf, lapack::sgtcon, lapack::sgttrs); -impl_tridiagonal!(@complex, c64, lapack::zgttrf, lapack::zgtcon, lapack::zgttrs); -impl_tridiagonal!(@complex, c32, lapack::cgttrf, lapack::cgtcon, lapack::cgttrs); diff --git a/lax/src/tridiagonal/lu.rs b/lax/src/tridiagonal/lu.rs new file mode 100644 index 00000000..e159bec6 --- /dev/null +++ b/lax/src/tridiagonal/lu.rs @@ -0,0 +1,101 @@ +use crate::*; +use cauchy::*; +use num_traits::Zero; + +/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`. +#[derive(Clone, PartialEq)] +pub struct LUFactorizedTridiagonal { + /// A tridiagonal matrix which consists of + /// - l : layout of raw matrix + /// - dl: (n-1) multipliers that define the matrix L. + /// - d : (n) diagonal elements of the upper triangular matrix U. + /// - du: (n-1) elements of the first super-diagonal of U. + pub a: Tridiagonal, + /// (n-2) elements of the second super-diagonal of U. + pub du2: Vec, + /// The pivot indices that define the permutation matrix `P`. 
+ pub ipiv: Pivot, + + pub a_opnorm_one: A::Real, +} + +impl Tridiagonal { + fn opnorm_one(&self) -> A::Real { + let mut col_sum: Vec = self.d.iter().map(|val| val.abs()).collect(); + for i in 0..col_sum.len() { + if i < self.dl.len() { + col_sum[i] += self.dl[i].abs(); + } + if i > 0 { + col_sum[i] += self.du[i - 1].abs(); + } + } + let mut max = A::Real::zero(); + for &val in &col_sum { + if max < val { + max = val; + } + } + max + } +} + +pub struct LuTridiagonalWork { + pub layout: MatrixLayout, + pub du2: Vec>, + pub ipiv: Vec>, +} + +pub trait LuTridiagonalWorkImpl { + type Elem: Scalar; + fn new(layout: MatrixLayout) -> Self; + fn eval(self, a: Tridiagonal) -> Result>; +} + +macro_rules! impl_lu_tridiagonal_work { + ($s:ty, $trf:path) => { + impl LuTridiagonalWorkImpl for LuTridiagonalWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout) -> Self { + let (n, _) = layout.size(); + let du2 = vec_uninit((n - 2) as usize); + let ipiv = vec_uninit(n as usize); + LuTridiagonalWork { layout, du2, ipiv } + } + + fn eval( + mut self, + mut a: Tridiagonal, + ) -> Result> { + let (n, _) = self.layout.size(); + // We have to calc one-norm before LU factorization + let a_opnorm_one = a.opnorm_one(); + let mut info = 0; + unsafe { + $trf( + &n, + AsPtr::as_mut_ptr(&mut a.dl), + AsPtr::as_mut_ptr(&mut a.d), + AsPtr::as_mut_ptr(&mut a.du), + AsPtr::as_mut_ptr(&mut self.du2), + AsPtr::as_mut_ptr(&mut self.ipiv), + &mut info, + ) + }; + info.as_lapack_result()?; + Ok(LUFactorizedTridiagonal { + a, + du2: unsafe { self.du2.assume_init() }, + ipiv: unsafe { self.ipiv.assume_init() }, + a_opnorm_one, + }) + } + } + }; +} + +impl_lu_tridiagonal_work!(c64, lapack_sys::zgttrf_); +impl_lu_tridiagonal_work!(c32, lapack_sys::cgttrf_); +impl_lu_tridiagonal_work!(f64, lapack_sys::dgttrf_); +impl_lu_tridiagonal_work!(f32, lapack_sys::sgttrf_); diff --git a/lax/src/tridiagonal/matrix.rs b/lax/src/tridiagonal/matrix.rs new file mode 100644 index 00000000..47401430 --- /dev/null +++ 
b/lax/src/tridiagonal/matrix.rs @@ -0,0 +1,84 @@ +use crate::layout::*; +use cauchy::*; +use std::ops::{Index, IndexMut}; + +/// Represents a tridiagonal matrix as 3 one-dimensional vectors. +/// +/// ```text +/// [d0, u1, 0, ..., 0, +/// l1, d1, u2, ..., +/// 0, l2, d2, +/// ... ..., u{n-1}, +/// 0, ..., l{n-1}, d{n-1},] +/// ``` +#[derive(Clone, PartialEq, Eq)] +pub struct Tridiagonal { + /// layout of raw matrix + pub l: MatrixLayout, + /// (n-1) sub-diagonal elements of matrix. + pub dl: Vec, + /// (n) diagonal elements of matrix. + pub d: Vec, + /// (n-1) super-diagonal elements of matrix. + pub du: Vec, +} + +impl Index<(i32, i32)> for Tridiagonal { + type Output = A; + #[inline] + fn index(&self, (row, col): (i32, i32)) -> &A { + let (n, _) = self.l.size(); + assert!( + std::cmp::max(row, col) < n, + "ndarray: index {:?} is out of bounds for array of shape {}", + [row, col], + n + ); + match row - col { + 0 => &self.d[row as usize], + 1 => &self.dl[col as usize], + -1 => &self.du[row as usize], + _ => panic!( + "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element", + [row, col] + ), + } + } +} + +impl Index<[i32; 2]> for Tridiagonal { + type Output = A; + #[inline] + fn index(&self, [row, col]: [i32; 2]) -> &A { + &self[(row, col)] + } +} + +impl IndexMut<(i32, i32)> for Tridiagonal { + #[inline] + fn index_mut(&mut self, (row, col): (i32, i32)) -> &mut A { + let (n, _) = self.l.size(); + assert!( + std::cmp::max(row, col) < n, + "ndarray: index {:?} is out of bounds for array of shape {}", + [row, col], + n + ); + match row - col { + 0 => &mut self.d[row as usize], + 1 => &mut self.dl[col as usize], + -1 => &mut self.du[row as usize], + _ => panic!( + "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element", + [row, col] + ), + } + } +} + +impl IndexMut<[i32; 2]> for Tridiagonal { + #[inline] + fn index_mut(&mut self, [row, col]: [i32; 2]) -> &mut A { + &mut self[(row, col)] + } +} diff --git a/lax/src/tridiagonal/mod.rs 
b/lax/src/tridiagonal/mod.rs new file mode 100644 index 00000000..834ddd8a --- /dev/null +++ b/lax/src/tridiagonal/mod.rs @@ -0,0 +1,12 @@ +//! Implement linear solver using LU decomposition +//! for tridiagonal matrix + +mod lu; +mod matrix; +mod rcond; +mod solve; + +pub use lu::*; +pub use matrix::*; +pub use rcond::*; +pub use solve::*; diff --git a/lax/src/tridiagonal/rcond.rs b/lax/src/tridiagonal/rcond.rs new file mode 100644 index 00000000..a309cae4 --- /dev/null +++ b/lax/src/tridiagonal/rcond.rs @@ -0,0 +1,109 @@ +use crate::*; +use cauchy::*; +use num_traits::Zero; + +pub struct RcondTridiagonalWork { + pub work: Vec>, + pub iwork: Option>>, +} + +pub trait RcondTridiagonalWorkImpl { + type Elem: Scalar; + fn new(layout: MatrixLayout) -> Self; + fn calc( + &mut self, + lu: &LUFactorizedTridiagonal, + ) -> Result<::Real>; +} + +macro_rules! impl_rcond_tridiagonal_work_c { + ($c:ty, $gtcon:path) => { + impl RcondTridiagonalWorkImpl for RcondTridiagonalWork<$c> { + type Elem = $c; + + fn new(layout: MatrixLayout) -> Self { + let (n, _) = layout.size(); + let work = vec_uninit(2 * n as usize); + RcondTridiagonalWork { work, iwork: None } + } + + fn calc( + &mut self, + lu: &LUFactorizedTridiagonal, + ) -> Result<::Real> { + let (n, _) = lu.a.l.size(); + let ipiv = &lu.ipiv; + let mut rcond = ::Real::zero(); + let mut info = 0; + unsafe { + $gtcon( + NormType::One.as_ptr(), + &n, + AsPtr::as_ptr(&lu.a.dl), + AsPtr::as_ptr(&lu.a.d), + AsPtr::as_ptr(&lu.a.du), + AsPtr::as_ptr(&lu.du2), + ipiv.as_ptr(), + &lu.a_opnorm_one, + &mut rcond, + AsPtr::as_mut_ptr(&mut self.work), + &mut info, + ); + } + info.as_lapack_result()?; + Ok(rcond) + } + } + }; +} + +impl_rcond_tridiagonal_work_c!(c64, lapack_sys::zgtcon_); +impl_rcond_tridiagonal_work_c!(c32, lapack_sys::cgtcon_); + +macro_rules! 
impl_rcond_tridiagonal_work_r { + ($c:ty, $gtcon:path) => { + impl RcondTridiagonalWorkImpl for RcondTridiagonalWork<$c> { + type Elem = $c; + + fn new(layout: MatrixLayout) -> Self { + let (n, _) = layout.size(); + let work = vec_uninit(2 * n as usize); + let iwork = vec_uninit(n as usize); + RcondTridiagonalWork { + work, + iwork: Some(iwork), + } + } + + fn calc( + &mut self, + lu: &LUFactorizedTridiagonal, + ) -> Result<::Real> { + let (n, _) = lu.a.l.size(); + let mut rcond = ::Real::zero(); + let mut info = 0; + unsafe { + $gtcon( + NormType::One.as_ptr(), + &n, + AsPtr::as_ptr(&lu.a.dl), + AsPtr::as_ptr(&lu.a.d), + AsPtr::as_ptr(&lu.a.du), + AsPtr::as_ptr(&lu.du2), + AsPtr::as_ptr(&lu.ipiv), + &lu.a_opnorm_one, + &mut rcond, + AsPtr::as_mut_ptr(&mut self.work), + AsPtr::as_mut_ptr(self.iwork.as_mut().unwrap()), + &mut info, + ); + } + info.as_lapack_result()?; + Ok(rcond) + } + } + }; +} + +impl_rcond_tridiagonal_work_r!(f64, lapack_sys::dgtcon_); +impl_rcond_tridiagonal_work_r!(f32, lapack_sys::sgtcon_); diff --git a/lax/src/tridiagonal/solve.rs b/lax/src/tridiagonal/solve.rs new file mode 100644 index 00000000..43f7d120 --- /dev/null +++ b/lax/src/tridiagonal/solve.rs @@ -0,0 +1,64 @@ +use crate::{error::*, layout::*, *}; +use cauchy::*; + +pub trait SolveTridiagonalImpl: Scalar { + fn solve_tridiagonal( + lu: &LUFactorizedTridiagonal, + bl: MatrixLayout, + t: Transpose, + b: &mut [Self], + ) -> Result<()>; +} + +macro_rules! impl_solve_tridiagonal { + ($s:ty, $trs:path) => { + impl SolveTridiagonalImpl for $s { + fn solve_tridiagonal( + lu: &LUFactorizedTridiagonal, + b_layout: MatrixLayout, + t: Transpose, + b: &mut [Self], + ) -> Result<()> { + let (n, _) = lu.a.l.size(); + let ipiv = &lu.ipiv; + // Transpose if b is C-continuous + let mut b_t = None; + let b_layout = match b_layout { + MatrixLayout::C { .. } => { + let (layout, t) = transpose(b_layout, b); + b_t = Some(t); + layout + } + MatrixLayout::F { .. 
} => b_layout, + }; + let (ldb, nrhs) = b_layout.size(); + let mut info = 0; + unsafe { + $trs( + t.as_ptr(), + &n, + &nrhs, + AsPtr::as_ptr(&lu.a.dl), + AsPtr::as_ptr(&lu.a.d), + AsPtr::as_ptr(&lu.a.du), + AsPtr::as_ptr(&lu.du2), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), + &ldb, + &mut info, + ); + } + info.as_lapack_result()?; + if let Some(b_t) = b_t { + transpose_over(b_layout, &b_t, b); + } + Ok(()) + } + } + }; +} + +impl_solve_tridiagonal!(c64, lapack_sys::zgttrs_); +impl_solve_tridiagonal!(c32, lapack_sys::cgttrs_); +impl_solve_tridiagonal!(f64, lapack_sys::dgttrs_); +impl_solve_tridiagonal!(f32, lapack_sys::sgttrs_); diff --git a/ndarray-linalg/Cargo.toml b/ndarray-linalg/Cargo.toml index 63fa3e86..8c5c83a8 100644 --- a/ndarray-linalg/Cargo.toml +++ b/ndarray-linalg/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ndarray-linalg" -version = "0.13.0" +version = "0.17.0" authors = ["Toshiki Teramura "] edition = "2018" @@ -13,7 +13,8 @@ readme = "../README.md" categories = ["algorithms", "science"] [features] -default = [] +default = ["blas"] +blas = ["ndarray/blas"] netlib = ["lax/netlib"] openblas = ["lax/openblas"] @@ -29,27 +30,29 @@ intel-mkl-static = ["lax/intel-mkl-static"] intel-mkl-system = ["lax/intel-mkl-system"] [dependencies] -cauchy = "0.3.0" -num-complex = "0.3.1" -num-traits = "0.2.11" -rand = "0.7.3" -thiserror = "1.0.20" +cauchy = "0.4.0" +katexit = "0.1.2" +num-complex = "0.4.0" +num-traits = "0.2.14" +rand = "0.8.3" +thiserror = "2.0.0" [dependencies.ndarray] -version = "0.14" -features = ["blas", "approx"] +version = "0.16.0" +features = ["approx", "std"] default-features = false [dependencies.lax] -version = "0.1.0" +version = "0.17.0" path = "../lax" default-features = false [dev-dependencies] -paste = "1.0" -criterion = "0.3" +paste = "1.0.5" +criterion = "0.5.1" # Keep the same version as ndarray's dependency! 
-approx = { version = "0.4", features = ["num-complex"] } +approx = { version = "0.5", features = ["num-complex"] } +rand_pcg = "0.3.1" [[bench]] name = "truncated_eig" @@ -78,9 +81,3 @@ harness = false [[bench]] name = "solveh" harness = false - -[package.metadata.docs.rs] -rustdoc-args = ["--html-in-header", "katex-header.html"] - -[package.metadata.release] -no-dev-version = true diff --git a/ndarray-linalg/benches/eig_generalized.rs b/ndarray-linalg/benches/eig_generalized.rs new file mode 100644 index 00000000..d1f5621b --- /dev/null +++ b/ndarray-linalg/benches/eig_generalized.rs @@ -0,0 +1,40 @@ +use criterion::*; +use ndarray::*; +use ndarray_linalg::*; + +fn eig_generalized_small(c: &mut Criterion) { + let mut group = c.benchmark_group("eig"); + for &n in &[4, 8, 16, 32, 64, 128] { + group.bench_with_input(BenchmarkId::new("vecs/C/r", n), &n, |c, n| { + let a: Array2 = random((*n, *n)); + let b: Array2 = random((*n, *n)); + c.iter(|| { + let (_e, _vecs) = (a.clone(), b.clone()).eig_generalized(None).unwrap(); + }) + }); + group.bench_with_input(BenchmarkId::new("vecs/F/r", n), &n, |c, n| { + let a: Array2 = random((*n, *n).f()); + let b: Array2 = random((*n, *n).f()); + c.iter(|| { + let (_e, _vecs) = (a.clone(), b.clone()).eig_generalized(None).unwrap(); + }) + }); + group.bench_with_input(BenchmarkId::new("vecs/C/c", n), &n, |c, n| { + let a: Array2 = random((*n, *n)); + let b: Array2 = random((*n, *n)); + c.iter(|| { + let (_e, _vecs) = (a.clone(), b.clone()).eig_generalized(None).unwrap(); + }) + }); + group.bench_with_input(BenchmarkId::new("vecs/F/c", n), &n, |c, n| { + let a: Array2 = random((*n, *n).f()); + let b: Array2 = random((*n, *n).f()); + c.iter(|| { + let (_e, _vecs) = (a.clone(), b.clone()).eig_generalized(None).unwrap(); + }) + }); + } +} + +criterion_group!(eig, eig_generalized_small); +criterion_main!(eig); diff --git a/ndarray-linalg/benches/svd.rs b/ndarray-linalg/benches/svd.rs index a1870a8f..02ea5806 100644 --- 
a/ndarray-linalg/benches/svd.rs +++ b/ndarray-linalg/benches/svd.rs @@ -62,37 +62,37 @@ fn svddc_small(c: &mut Criterion) { group.bench_with_input(BenchmarkId::new("C", n), &n, |b, n| { let a: Array2 = random((*n, *n)); b.iter(|| { - let _ = a.svddc(UVTFlag::None).unwrap(); + let _ = a.svddc(JobSvd::None).unwrap(); }) }); group.bench_with_input(BenchmarkId::new("F", n), &n, |b, n| { let a: Array2 = random((*n, *n).f()); b.iter(|| { - let _ = a.svddc(UVTFlag::None).unwrap(); + let _ = a.svddc(JobSvd::None).unwrap(); }) }); group.bench_with_input(BenchmarkId::new("some/C", n), &n, |b, n| { let a: Array2 = random((*n, *n)); b.iter(|| { - let _ = a.svddc(UVTFlag::Some).unwrap(); + let _ = a.svddc(JobSvd::Some).unwrap(); }) }); group.bench_with_input(BenchmarkId::new("some/F", n), &n, |b, n| { let a: Array2 = random((*n, *n).f()); b.iter(|| { - let _ = a.svddc(UVTFlag::Some).unwrap(); + let _ = a.svddc(JobSvd::Some).unwrap(); }) }); group.bench_with_input(BenchmarkId::new("full/C", n), &n, |b, n| { let a: Array2 = random((*n, *n)); b.iter(|| { - let _ = a.svddc(UVTFlag::Full).unwrap(); + let _ = a.svddc(JobSvd::All).unwrap(); }) }); group.bench_with_input(BenchmarkId::new("full/F", n), &n, |b, n| { let a: Array2 = random((*n, *n).f()); b.iter(|| { - let _ = a.svddc(UVTFlag::Full).unwrap(); + let _ = a.svddc(JobSvd::All).unwrap(); }) }); } diff --git a/ndarray-linalg/examples/eig.rs b/ndarray-linalg/examples/eig.rs index 3e41556a..746ec2e4 100644 --- a/ndarray-linalg/examples/eig.rs +++ b/ndarray-linalg/examples/eig.rs @@ -3,7 +3,7 @@ use ndarray_linalg::*; fn main() { let a = arr2(&[[2.0, 1.0, 2.0], [-2.0, 2.0, 1.0], [1.0, 2.0, -2.0]]); - let (e, vecs) = a.clone().eig().unwrap(); + let (e, vecs) = a.eig().unwrap(); println!("eigenvalues = \n{:?}", e); println!("V = \n{:?}", vecs); let a_c: Array2 = a.map(|f| c64::new(*f, 0.0)); diff --git a/ndarray-linalg/examples/eigh.rs b/ndarray-linalg/examples/eigh.rs index c9bcc941..f4f36841 100644 --- 
a/ndarray-linalg/examples/eigh.rs +++ b/ndarray-linalg/examples/eigh.rs @@ -6,7 +6,7 @@ use ndarray_linalg::*; fn main() { let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]); - let (e, vecs) = a.clone().eigh(UPLO::Upper).unwrap(); + let (e, vecs) = a.eigh(UPLO::Upper).unwrap(); println!("eigenvalues = \n{:?}", e); println!("V = \n{:?}", vecs); let av = a.dot(&vecs); diff --git a/ndarray-linalg/src/assert.rs b/ndarray-linalg/src/assert.rs index 670aabc8..74c1618b 100644 --- a/ndarray-linalg/src/assert.rs +++ b/ndarray-linalg/src/assert.rs @@ -86,7 +86,7 @@ where } macro_rules! generate_assert { - ($assert:ident, $close:path) => { + ($assert:ident, $close:tt) => { #[macro_export] macro_rules! $assert { ($test: expr,$truth: expr,$tol: expr) => { diff --git a/ndarray-linalg/src/cholesky.rs b/ndarray-linalg/src/cholesky.rs index 8ce0da84..58cc5cee 100644 --- a/ndarray-linalg/src/cholesky.rs +++ b/ndarray-linalg/src/cholesky.rs @@ -26,7 +26,7 @@ //! //! // Obtain `L` //! let lower = a.cholesky(UPLO::Lower).unwrap(); -//! assert!(lower.all_close(&array![ +//! assert!(lower.abs_diff_eq(&array![ //! [ 2., 0., 0.], //! [ 6., 1., 0.], //! [-8., 5., 3.] @@ -39,7 +39,7 @@ //! // Solve `A * x = b` //! let b = array![4., 13., -11.]; //! let x = a.solvec(&b).unwrap(); -//! assert!(x.all_close(&array![-2., 1., 0.], 1e-9)); +//! assert!(x.abs_diff_eq(&array![-2., 1., 0.], 1e-9)); //! # } //! 
``` diff --git a/ndarray-linalg/src/convert.rs b/ndarray-linalg/src/convert.rs index e1446e96..c808211e 100644 --- a/ndarray-linalg/src/convert.rs +++ b/ndarray-linalg/src/convert.rs @@ -12,7 +12,7 @@ where S: Data, { let n = a.len(); - a.into_shape((n, 1)).unwrap() + a.into_shape_with_order((n, 1)).unwrap() } pub fn into_row(a: ArrayBase) -> ArrayBase @@ -20,7 +20,7 @@ where S: Data, { let n = a.len(); - a.into_shape((1, n)).unwrap() + a.into_shape_with_order((1, n)).unwrap() } pub fn flatten(a: ArrayBase) -> ArrayBase @@ -28,7 +28,7 @@ where S: Data, { let n = a.len(); - a.into_shape(n).unwrap() + a.into_shape_with_order(n).unwrap() } pub fn into_matrix(l: MatrixLayout, a: Vec) -> Result> @@ -46,21 +46,6 @@ where } } -fn uninitialized(l: MatrixLayout) -> ArrayBase -where - A: Copy, - S: DataOwned, -{ - match l { - MatrixLayout::C { row, lda } => unsafe { - ArrayBase::uninitialized((row as usize, lda as usize)) - }, - MatrixLayout::F { col, lda } => unsafe { - ArrayBase::uninitialized((lda as usize, col as usize).f()) - }, - } -} - pub fn replicate(a: &ArrayBase) -> ArrayBase where A: Copy, @@ -68,9 +53,12 @@ where So: DataOwned + DataMut, D: Dimension, { - let mut b = unsafe { ArrayBase::uninitialized(a.dim()) }; - b.assign(a); - b + unsafe { + let ret = ArrayBase::::build_uninit(a.dim(), |view| { + a.assign_to(view); + }); + ret.assume_init() + } } fn clone_with_layout(l: MatrixLayout, a: &ArrayBase) -> ArrayBase @@ -79,9 +67,16 @@ where Si: Data, So: DataOwned + DataMut, { - let mut b = uninitialized(l); - b.assign(a); - b + let shape_builder = match l { + MatrixLayout::C { row, lda } => (row as usize, lda as usize).set_f(false), + MatrixLayout::F { col, lda } => (lda as usize, col as usize).set_f(true), + }; + unsafe { + let ret = ArrayBase::::build_uninit(shape_builder, |view| { + a.assign_to(view); + }); + ret.assume_init() + } } pub fn transpose_data(a: &mut ArrayBase) -> Result<&mut ArrayBase> @@ -104,9 +99,9 @@ where // 
https://github.com/bluss/rust-ndarray/issues/325 let strides: Vec = a.strides().to_vec(); let new = if a.is_standard_layout() { - ArrayBase::from_shape_vec(a.dim(), a.into_raw_vec()).unwrap() + ArrayBase::from_shape_vec(a.dim(), a.into_raw_vec_and_offset().0).unwrap() } else { - ArrayBase::from_shape_vec(a.dim().f(), a.into_raw_vec()).unwrap() + ArrayBase::from_shape_vec(a.dim().f(), a.into_raw_vec_and_offset().0).unwrap() }; assert_eq!( new.strides(), diff --git a/ndarray-linalg/src/eig.rs b/ndarray-linalg/src/eig.rs index 17f5a1e8..03e3ee03 100644 --- a/ndarray-linalg/src/eig.rs +++ b/ndarray-linalg/src/eig.rs @@ -3,8 +3,10 @@ use crate::error::*; use crate::layout::*; use crate::types::*; +pub use lax::GeneralizedEigenvalue; use ndarray::*; +#[cfg_attr(doc, katexit::katexit)] /// Eigenvalue decomposition of general matrix reference pub trait Eig { /// EigVec is the right eivenvector @@ -66,13 +68,96 @@ pub trait EigVals { impl EigVals for ArrayBase where A: Scalar + Lapack, - S: DataMut, + S: Data, { type EigVal = Array1; fn eigvals(&self) -> Result { let mut a = self.to_owned(); - let (s, _) = A::eig(true, a.square_layout()?, a.as_allocated_mut()?)?; + let (s, _) = A::eig(false, a.square_layout()?, a.as_allocated_mut()?)?; Ok(ArrayBase::from(s)) } } + +#[cfg_attr(doc, katexit::katexit)] +/// Eigenvalue decomposition of general matrix reference +pub trait EigGeneralized { + /// EigVec is the right eivenvector + type EigVal; + type EigVec; + type Real; + /// Calculate eigenvalues with the right eigenvector + /// + /// $$ A u_i = \lambda_i B u_i $$ + /// + /// ``` + /// use ndarray::*; + /// use ndarray_linalg::*; + /// + /// let a: Array2 = array![ + /// [-1.01, 0.86, -4.60, 3.31, -4.81], + /// [ 3.98, 0.53, -7.04, 5.29, 3.55], + /// [ 3.30, 8.26, -3.89, 8.20, -1.51], + /// [ 4.43, 4.96, -7.66, -7.33, 6.18], + /// [ 7.31, -6.43, -6.16, 2.47, 5.58], + /// ]; + /// let b: Array2 = array![ + /// [ 1.23, -4.56, 7.89, 0.12, -3.45], + /// [ 6.78, -9.01, 2.34, -5.67, 
8.90], + /// [-1.11, 3.33, -6.66, 9.99, -2.22], + /// [ 4.44, -7.77, 0.00, 1.11, 5.55], + /// [-8.88, 6.66, -3.33, 2.22, -9.99], + /// ]; + /// let (geneigs, vecs) = (a.clone(), b.clone()).eig_generalized(None).unwrap(); + /// + /// let a = a.map(|v| v.as_c()); + /// let b = b.map(|v| v.as_c()); + /// for (ge, vec) in geneigs.iter().zip(vecs.axis_iter(Axis(1))) { + /// if let GeneralizedEigenvalue::Finite(e, _) = ge { + /// let ebv = b.dot(&vec).map(|v| v * e); + /// let av = a.dot(&vec); + /// assert_close_l2!(&av, &ebv, 1e-5); + /// } + /// } + /// ``` + /// + /// # Arguments + /// + /// * `thresh_opt` - An optional threshold for determining approximate zero |β| values when + /// computing the eigenvalues as α/β. If `None`, no approximate comparisons to zero will be + /// made. + fn eig_generalized( + &self, + thresh_opt: Option, + ) -> Result<(Self::EigVal, Self::EigVec)>; +} + +impl EigGeneralized for (ArrayBase, ArrayBase) +where + A: Scalar + Lapack, + S: Data, +{ + type EigVal = Array1>; + type EigVec = Array2; + type Real = A::Real; + + fn eig_generalized( + &self, + thresh_opt: Option, + ) -> Result<(Self::EigVal, Self::EigVec)> { + let (mut a, mut b) = (self.0.to_owned(), self.1.to_owned()); + let layout = a.square_layout()?; + let (s, t) = A::eig_generalized( + true, + layout, + a.as_allocated_mut()?, + b.as_allocated_mut()?, + thresh_opt, + )?; + let n = layout.len() as usize; + Ok(( + ArrayBase::from(s), + Array2::from_shape_vec((n, n).f(), t).unwrap(), + )) + } +} diff --git a/ndarray-linalg/src/eigh.rs b/ndarray-linalg/src/eigh.rs index 86f1fb46..837f51f8 100644 --- a/ndarray-linalg/src/eigh.rs +++ b/ndarray-linalg/src/eigh.rs @@ -1,4 +1,37 @@ -//! Eigenvalue decomposition for Hermite matrices +//! Eigendecomposition for Hermitian matrices. +//! +//! For a Hermitian matrix `A`, this solves the eigenvalue problem `A V = V D` +//! for `D` and `V`, where `D` is the diagonal matrix of eigenvalues in +//! 
ascending order and `V` is the orthonormal matrix of corresponding +//! eigenvectors. +//! +//! For a pair of Hermitian matrices `A` and `B` where `B` is also positive +//! definite, this solves the generalized eigenvalue problem `A V = B V D`, +//! where `D` is the diagonal matrix of generalized eigenvalues in ascending +//! order and `V` is the matrix of corresponding generalized eigenvectors. The +//! matrix `V` is normalized such that `V^H B V = I`. +//! +//! # Example +//! +//! Find the eigendecomposition of a Hermitian (or real symmetric) matrix. +//! +//! ``` +//! use approx::assert_abs_diff_eq; +//! use ndarray::{array, Array2}; +//! use ndarray_linalg::{Eigh, UPLO}; +//! +//! let a: Array2 = array![ +//! [2., 1.], +//! [1., 2.], +//! ]; +//! let (eigvals, eigvecs) = a.eigh(UPLO::Lower)?; +//! assert_abs_diff_eq!(eigvals, array![1., 3.]); +//! assert_abs_diff_eq!( +//! a.dot(&eigvecs), +//! eigvecs.dot(&Array2::from_diag(&eigvals)), +//! ); +//! # Ok::<(), Box>(()) +//! ``` use ndarray::*; @@ -8,7 +41,6 @@ use crate::layout::*; use crate::operator::LinearOperator; use crate::types::*; use crate::UPLO; -use std::iter::FromIterator; /// Eigenvalue decomposition of Hermite matrix reference pub trait Eigh { @@ -112,7 +144,17 @@ where { type EigVal = Array1; + /// Solves the generalized eigenvalue problem. + /// + /// # Panics + /// + /// Panics if the shapes of the matrices are different. 
fn eigh_inplace(&mut self, uplo: UPLO) -> Result<(Self::EigVal, &mut Self)> { + assert_eq!( + self.0.shape(), + self.1.shape(), + "The shapes of the matrices must be identical.", + ); let layout = self.0.square_layout()?; // XXX Force layout to be Fortran (see #146) match layout { diff --git a/ndarray-linalg/src/generate.rs b/ndarray-linalg/src/generate.rs index 67def14a..5646c808 100644 --- a/ndarray-linalg/src/generate.rs +++ b/ndarray-linalg/src/generate.rs @@ -22,7 +22,10 @@ where a } -/// Generate random array +/// Generate random array with given shape +/// +/// - This function uses [rand::thread_rng]. +/// See [random_using] for using another RNG pub fn random(sh: Sh) -> ArrayBase where A: Scalar, @@ -31,29 +34,77 @@ where Sh: ShapeBuilder, { let mut rng = thread_rng(); - ArrayBase::from_shape_fn(sh, |_| A::rand(&mut rng)) + random_using(sh, &mut rng) +} + +/// Generate random array with given RNG +/// +/// - See [random] for using default RNG +pub fn random_using(sh: Sh, rng: &mut R) -> ArrayBase +where + A: Scalar, + S: DataOwned, + D: Dimension, + Sh: ShapeBuilder, + R: Rng, +{ + ArrayBase::from_shape_fn(sh, |_| A::rand(rng)) } /// Generate random unitary matrix using QR decomposition /// -/// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose. +/// - Be sure that this it **NOT** a uniform distribution. +/// Use it only for test purpose. +/// - This function uses [rand::thread_rng]. +/// See [random_unitary_using] for using another RNG. pub fn random_unitary(n: usize) -> Array2 where A: Scalar + Lapack, { - let a: Array2 = random((n, n)); + let mut rng = thread_rng(); + random_unitary_using(n, &mut rng) +} + +/// Generate random unitary matrix using QR decomposition with given RNG +/// +/// - Be sure that this it **NOT** a uniform distribution. +/// Use it only for test purpose. +/// - See [random_unitary] for using default RNG. 
+pub fn random_unitary_using(n: usize, rng: &mut R) -> Array2 +where + A: Scalar + Lapack, + R: Rng, +{ + let a: Array2 = random_using((n, n), rng); let (q, _r) = a.qr_into().unwrap(); q } /// Generate random regular matrix /// -/// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose. +/// - Be sure that this it **NOT** a uniform distribution. +/// Use it only for test purpose. +/// - This function uses [rand::thread_rng]. +/// See [random_regular_using] for using another RNG. pub fn random_regular(n: usize) -> Array2 where A: Scalar + Lapack, { - let a: Array2 = random((n, n)); + let mut rng = rand::thread_rng(); + random_regular_using(n, &mut rng) +} + +/// Generate random regular matrix with given RNG +/// +/// - Be sure that this it **NOT** a uniform distribution. +/// Use it only for test purpose. +/// - See [random_regular] for using default RNG. +pub fn random_regular_using(n: usize, rng: &mut R) -> Array2 +where + A: Scalar + Lapack, + R: Rng, +{ + let a: Array2 = random_using((n, n), rng); let (q, mut r) = a.qr_into().unwrap(); for i in 0..n { r[(i, i)] = A::one() + A::from_real(r[(i, i)].abs()); @@ -62,12 +113,28 @@ where } /// Random Hermite matrix +/// +/// - This function uses [rand::thread_rng]. +/// See [random_hermite_using] for using another RNG. pub fn random_hermite(n: usize) -> ArrayBase where A: Scalar, S: DataOwned + DataMut, { - let mut a: ArrayBase = random((n, n)); + let mut rng = rand::thread_rng(); + random_hermite_using(n, &mut rng) +} + +/// Random Hermite matrix with given RNG +/// +/// - See [random_hermite] for using default RNG. 
+pub fn random_hermite_using(n: usize, rng: &mut R) -> ArrayBase +where + A: Scalar, + S: DataOwned + DataMut, + R: Rng, +{ + let mut a: ArrayBase = random_using((n, n), rng); for i in 0..n { a[(i, i)] = a[(i, i)] + a[(i, i)].conj(); for j in (i + 1)..n { @@ -80,13 +147,30 @@ where /// Random Hermite Positive-definite matrix /// /// - Eigenvalue of matrix must be larger than 1 (thus non-singular) +/// - This function uses [rand::thread_rng]. +/// See [random_hpd_using] for using another RNG. /// pub fn random_hpd(n: usize) -> ArrayBase where A: Scalar, S: DataOwned + DataMut, { - let a: Array2 = random((n, n)); + let mut rng = rand::thread_rng(); + random_hpd_using(n, &mut rng) +} + +/// Random Hermite Positive-definite matrix with given RNG +/// +/// - Eigenvalue of matrix must be larger than 1 (thus non-singular) +/// - See [random_hpd] for using default RNG. +/// +pub fn random_hpd_using(n: usize, rng: &mut R) -> ArrayBase +where + A: Scalar, + S: DataOwned + DataMut, + R: Rng, +{ + let a: Array2 = random_using((n, n), rng); let ah: Array2 = conjugate(&a); ArrayBase::eye(n) + &ah.dot(&a) } diff --git a/ndarray-linalg/src/krylov/householder.rs b/ndarray-linalg/src/krylov/householder.rs index 844d0e7e..c04d9049 100644 --- a/ndarray-linalg/src/krylov/householder.rs +++ b/ndarray-linalg/src/krylov/householder.rs @@ -34,7 +34,7 @@ where { assert_eq!(w.len(), a.len()); let n = a.len(); - let c = A::from(2.0).unwrap() * w.inner(&a); + let c = A::from(2.0).unwrap() * w.inner(a); for l in 0..n { a[l] -= c * w[l]; } diff --git a/ndarray-linalg/src/krylov/mgs.rs b/ndarray-linalg/src/krylov/mgs.rs index dc0dfba6..67806a2c 100644 --- a/ndarray-linalg/src/krylov/mgs.rs +++ b/ndarray-linalg/src/krylov/mgs.rs @@ -50,7 +50,7 @@ impl Orthogonalizer for MGS { let mut coef = Array1::zeros(self.len() + 1); for i in 0..self.len() { let q = &self.q[i]; - let c = q.inner(&a); + let c = q.inner(a); azip!((a in &mut *a, &q in q) *a -= c * q); coef[i] = c; } @@ -77,12 +77,12 @@ impl 
Orthogonalizer for MGS { self.div_append(&mut a) } - fn div_append(&mut self, mut a: &mut ArrayBase) -> AppendResult + fn div_append(&mut self, a: &mut ArrayBase) -> AppendResult where A: Lapack, S: DataMut, { - let coef = self.decompose(&mut a); + let coef = self.decompose(a); let nrm = coef[coef.len() - 1].re(); if nrm < self.tol { // Linearly dependent diff --git a/ndarray-linalg/src/layout.rs b/ndarray-linalg/src/layout.rs index 9ca772bb..d0d13585 100644 --- a/ndarray-linalg/src/layout.rs +++ b/ndarray-linalg/src/layout.rs @@ -67,9 +67,8 @@ where } fn as_allocated(&self) -> Result<&[A]> { - Ok(self - .as_slice_memory_order() - .ok_or_else(|| LinalgError::MemoryNotCont)?) + self.as_slice_memory_order() + .ok_or(LinalgError::MemoryNotCont) } } @@ -78,8 +77,7 @@ where S: DataMut, { fn as_allocated_mut(&mut self) -> Result<&mut [A]> { - Ok(self - .as_slice_memory_order_mut() - .ok_or_else(|| LinalgError::MemoryNotCont)?) + self.as_slice_memory_order_mut() + .ok_or(LinalgError::MemoryNotCont) } } diff --git a/ndarray-linalg/src/least_squares.rs b/ndarray-linalg/src/least_squares.rs index 03583a25..f376f569 100644 --- a/ndarray-linalg/src/least_squares.rs +++ b/ndarray-linalg/src/least_squares.rs @@ -149,12 +149,13 @@ where /// Solve least squares for immutable references and a single /// column vector as a right-hand side. -/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any -/// valid representation for `ArrayBase`. -impl LeastSquaresSvd for ArrayBase +/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any +/// valid representation for `ArrayBase` (over `E`). +impl LeastSquaresSvd for ArrayBase where E: Scalar + Lapack, - D: Data, + D1: Data, + D2: Data, { /// Solve a least squares problem of the form `Ax = rhs` /// by calling `A.least_squares(&rhs)`, where `rhs` is a @@ -163,7 +164,7 @@ where /// `A` and `rhs` must have the same layout, i.e. 
they must /// be both either row- or column-major format, otherwise a /// `IncompatibleShape` error is raised. - fn least_squares(&self, rhs: &ArrayBase) -> Result> { + fn least_squares(&self, rhs: &ArrayBase) -> Result> { let a = self.to_owned(); let b = rhs.to_owned(); a.least_squares_into(b) @@ -172,12 +173,13 @@ where /// Solve least squares for immutable references and matrix /// (=mulitipe vectors) as a right-hand side. -/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any -/// valid representation for `ArrayBase`. -impl LeastSquaresSvd for ArrayBase +/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any +/// valid representation for `ArrayBase` (over `E`). +impl LeastSquaresSvd for ArrayBase where E: Scalar + Lapack, - D: Data, + D1: Data, + D2: Data, { /// Solve a least squares problem of the form `Ax = rhs` /// by calling `A.least_squares(&rhs)`, where `rhs` is @@ -186,7 +188,7 @@ where /// `A` and `rhs` must have the same layout, i.e. they must /// be both either row- or column-major format, otherwise a /// `IncompatibleShape` error is raised. - fn least_squares(&self, rhs: &ArrayBase) -> Result> { + fn least_squares(&self, rhs: &ArrayBase) -> Result> { let a = self.to_owned(); let b = rhs.to_owned(); a.least_squares_into(b) @@ -199,10 +201,11 @@ where /// /// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any /// valid representation for `ArrayBase`. -impl LeastSquaresSvdInto for ArrayBase +impl LeastSquaresSvdInto for ArrayBase where E: Scalar + Lapack, - D: DataMut, + D1: DataMut, + D2: DataMut, { /// Solve a least squares problem of the form `Ax = rhs` /// by calling `A.least_squares(rhs)`, where `rhs` is a @@ -213,7 +216,7 @@ where /// `IncompatibleShape` error is raised. fn least_squares_into( mut self, - mut rhs: ArrayBase, + mut rhs: ArrayBase, ) -> Result> { self.least_squares_in_place(&mut rhs) } @@ -223,12 +226,13 @@ where /// as a right-hand side. The matrix and the RHS matrix /// are consumed. 
/// -/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any -/// valid representation for `ArrayBase`. -impl LeastSquaresSvdInto for ArrayBase +/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any +/// valid representation for `ArrayBase` (over `E`). +impl LeastSquaresSvdInto for ArrayBase where E: Scalar + Lapack, - D: DataMut, + D1: DataMut, + D2: DataMut, { /// Solve a least squares problem of the form `Ax = rhs` /// by calling `A.least_squares(rhs)`, where `rhs` is a @@ -239,7 +243,7 @@ where /// `IncompatibleShape` error is raised. fn least_squares_into( mut self, - mut rhs: ArrayBase, + mut rhs: ArrayBase, ) -> Result> { self.least_squares_in_place(&mut rhs) } @@ -249,12 +253,13 @@ where /// as a right-hand side. Both values are overwritten in the /// call. /// -/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any -/// valid representation for `ArrayBase`. -impl LeastSquaresSvdInPlace for ArrayBase +/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any +/// valid representation for `ArrayBase` (over `E`). +impl LeastSquaresSvdInPlace for ArrayBase where E: Scalar + Lapack, - D: DataMut, + D1: DataMut, + D2: DataMut, { /// Solve a least squares problem of the form `Ax = rhs` /// by calling `A.least_squares(rhs)`, where `rhs` is a @@ -265,7 +270,7 @@ where /// `IncompatibleShape` error is raised. 
fn least_squares_in_place( &mut self, - rhs: &mut ArrayBase, + rhs: &mut ArrayBase, ) -> Result> { if self.shape()[0] != rhs.shape()[0] { return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into()); @@ -292,19 +297,19 @@ where D1: DataMut, D2: DataMut, { - let LeastSquaresOutput:: { + let LeastSquaresOwned:: { singular_values, rank, } = E::least_squares( a.layout()?, a.as_allocated_mut()?, rhs.as_slice_memory_order_mut() - .ok_or_else(|| LinalgError::MemoryNotCont)?, + .ok_or(LinalgError::MemoryNotCont)?, )?; let (m, n) = (a.shape()[0], a.shape()[1]); let solution = rhs.slice(s![0..n]).to_owned(); - let residual_sum_of_squares = compute_residual_scalar(m, n, rank, &rhs); + let residual_sum_of_squares = compute_residual_scalar(m, n, rank, rhs); Ok(LeastSquaresResult { solution, singular_values: Array::from_shape_vec((singular_values.len(),), singular_values)?, @@ -331,12 +336,13 @@ fn compute_residual_scalar>( /// as a right-hand side. Both values are overwritten in the /// call. /// -/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any -/// valid representation for `ArrayBase`. -impl LeastSquaresSvdInPlace for ArrayBase +/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any +/// valid representation for `ArrayBase` (over `E`). +impl LeastSquaresSvdInPlace for ArrayBase where - E: Scalar + Lapack + LeastSquaresSvdDivideConquer_, - D: DataMut, + E: Scalar + Lapack, + D1: DataMut, + D2: DataMut, { /// Solve a least squares problem of the form `Ax = rhs` /// by calling `A.least_squares(rhs)`, where `rhs` is a @@ -347,7 +353,7 @@ where /// `IncompatibleShape` error is raised. 
fn least_squares_in_place( &mut self, - rhs: &mut ArrayBase, + rhs: &mut ArrayBase, ) -> Result> { if self.shape()[0] != rhs.shape()[0] { return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into()); @@ -380,7 +386,7 @@ where { let a_layout = a.layout()?; let rhs_layout = rhs.layout()?; - let LeastSquaresOutput:: { + let LeastSquaresOwned:: { singular_values, rank, } = E::least_squares_nrhs( @@ -393,7 +399,7 @@ where let solution: Array2 = rhs.slice(s![..a.shape()[1], ..]).to_owned(); let singular_values = Array::from_shape_vec((singular_values.len(),), singular_values)?; let (m, n) = (a.shape()[0], a.shape()[1]); - let residual_sum_of_squares = compute_residual_array1(m, n, rank, &rhs); + let residual_sum_of_squares = compute_residual_array1(m, n, rank, rhs); Ok(LeastSquaresResult { solution, singular_values, @@ -425,7 +431,7 @@ mod tests { use ndarray::*; // - // Test that the different lest squares traits work as intended on the + // Test that the different least squares traits work as intended on the // different array types. 
// // | least_squares | ls_into | ls_in_place | @@ -437,9 +443,9 @@ mod tests { // ArrayViewMut | yes | no | yes | // - fn assert_result>( - a: &ArrayBase, - b: &ArrayBase, + fn assert_result, D2: Data>( + a: &ArrayBase, + b: &ArrayBase, res: &LeastSquaresResult, ) { assert_eq!(res.rank, 2); @@ -487,6 +493,15 @@ mod tests { assert_result(&av, &bv, &res); } + #[test] + fn on_cow_view() { + let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]); + let b: Array1 = array![1., 2., 3.]; + let bv = b.view(); + let res = a.least_squares(&bv).unwrap(); + assert_result(&a, &bv, &res); + } + #[test] fn into_on_owned() { let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; @@ -517,6 +532,16 @@ mod tests { assert_result(&a, &b, &res); } + #[test] + fn into_on_owned_cow() { + let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; + let b = CowArray::from(array![1., 2., 3.]); + let ac = a.clone(); + let b2 = b.clone(); + let res = ac.least_squares_into(b2).unwrap(); + assert_result(&a, &b, &res); + } + #[test] fn in_place_on_owned() { let a = array![[1., 2.], [4., 5.], [3., 4.]]; @@ -549,6 +574,16 @@ mod tests { assert_result(&a, &b, &res); } + #[test] + fn in_place_on_owned_cow() { + let a = array![[1., 2.], [4., 5.], [3., 4.]]; + let b = CowArray::from(array![1., 2., 3.]); + let mut a2 = a.clone(); + let mut b2 = b.clone(); + let res = a2.least_squares_in_place(&mut b2).unwrap(); + assert_result(&a, &b, &res); + } + // // Testing error cases // diff --git a/ndarray-linalg/src/lib.rs b/ndarray-linalg/src/lib.rs index ba9d10c5..784e1dff 100644 --- a/ndarray-linalg/src/lib.rs +++ b/ndarray-linalg/src/lib.rs @@ -44,6 +44,7 @@ clippy::type_complexity, clippy::ptr_arg )] +#![deny(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links)] #[macro_use] extern crate ndarray; @@ -74,26 +75,26 @@ pub mod triangular; pub mod tridiagonal; pub mod types; -pub use assert::*; -pub use cholesky::*; -pub use convert::*; -pub use diagonal::*; -pub use eig::*; -pub use eigh::*; -pub 
use generate::*; -pub use inner::*; -pub use layout::*; -pub use least_squares::*; -pub use lobpcg::{TruncatedEig, TruncatedOrder, TruncatedSvd}; -pub use norm::*; -pub use operator::*; -pub use opnorm::*; -pub use qr::*; -pub use solve::*; -pub use solveh::*; -pub use svd::*; -pub use svddc::*; -pub use trace::*; -pub use triangular::*; -pub use tridiagonal::*; -pub use types::*; +pub use crate::assert::*; +pub use crate::cholesky::*; +pub use crate::convert::*; +pub use crate::diagonal::*; +pub use crate::eig::*; +pub use crate::eigh::*; +pub use crate::generate::*; +pub use crate::inner::*; +pub use crate::layout::*; +pub use crate::least_squares::*; +pub use crate::lobpcg::{TruncatedEig, TruncatedOrder, TruncatedSvd}; +pub use crate::norm::*; +pub use crate::operator::*; +pub use crate::opnorm::*; +pub use crate::qr::*; +pub use crate::solve::*; +pub use crate::solveh::*; +pub use crate::svd::*; +pub use crate::svddc::*; +pub use crate::trace::*; +pub use crate::triangular::*; +pub use crate::tridiagonal::*; +pub use crate::types::*; diff --git a/ndarray-linalg/src/lobpcg/eig.rs b/ndarray-linalg/src/lobpcg/eig.rs index 0b5a60e8..e60adb04 100644 --- a/ndarray-linalg/src/lobpcg/eig.rs +++ b/ndarray-linalg/src/lobpcg/eig.rs @@ -139,9 +139,9 @@ impl Iterator // add the new eigenvector to the internal constrain matrix let new_constraints = if let Some(ref constraints) = self.eig.constraints { let eigvecs_arr: Vec<_> = constraints - .gencolumns() + .columns() .into_iter() - .chain(vecs.gencolumns().into_iter()) + .chain(vecs.columns().into_iter()) .collect(); stack(Axis(1), &eigvecs_arr).unwrap() diff --git a/ndarray-linalg/src/lobpcg/lobpcg.rs b/ndarray-linalg/src/lobpcg/lobpcg.rs index a10e5961..10d12da2 100644 --- a/ndarray-linalg/src/lobpcg/lobpcg.rs +++ b/ndarray-linalg/src/lobpcg/lobpcg.rs @@ -81,14 +81,13 @@ fn apply_constraints( let gram_yv = y.t().dot(&v); let u = gram_yv - .gencolumns() + .columns() .into_iter() - .map(|x| { + .flat_map(|x| { let res = 
cholesky_yy.solvec(&x).unwrap(); res.to_vec() }) - .flatten() .collect::>(); let rows = gram_yv.len_of(Axis(0)); @@ -222,7 +221,7 @@ pub fn lobpcg< // calculate L2 norm of error for every eigenvalue let residual_norms = r - .gencolumns() + .columns() .into_iter() .map(|x| x.norm()) .collect::>(); @@ -461,7 +460,8 @@ mod tests { /// Test the `sorted_eigen` function #[test] fn test_sorted_eigen() { - let matrix: Array2 = generate::random((10, 10)) * 10.0; + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let matrix: Array2 = generate::random_using((10, 10), &mut rng) * 10.0; let matrix = matrix.t().dot(&matrix); // return all eigenvectors with largest first @@ -477,7 +477,8 @@ mod tests { /// Test the masking function #[test] fn test_masking() { - let matrix: Array2 = generate::random((10, 5)) * 10.0; + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let matrix: Array2 = generate::random_using((10, 5), &mut rng) * 10.0; let masked_matrix = ndarray_mask(matrix.view(), &[true, true, false, true, false]); close_l2( &masked_matrix.slice(s![.., 2]), @@ -489,7 +490,8 @@ mod tests { /// Test orthonormalization of a random matrix #[test] fn test_orthonormalize() { - let matrix: Array2 = generate::random((10, 10)) * 10.0; + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let matrix: Array2 = generate::random_using((10, 10), &mut rng) * 10.0; let (n, l) = orthonormalize(matrix.clone()).unwrap(); @@ -510,7 +512,8 @@ mod tests { assert_symmetric(a); let n = a.len_of(Axis(0)); - let x: Array2 = generate::random((n, num)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let x: Array2 = generate::random_using((n, num), &mut rng); let result = lobpcg(|y| a.dot(&y), x, |_| {}, None, 1e-5, n * 2, order); match result { @@ -554,7 +557,8 @@ mod tests { #[test] fn test_eigsolver_constructed() { let n = 50; - let tmp = generate::random((n, n)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let tmp = 
generate::random_using((n, n), &mut rng); //let (v, _) = tmp.qr_square().unwrap(); let (v, _) = orthonormalize(tmp).unwrap(); @@ -571,7 +575,8 @@ mod tests { fn test_eigsolver_constrained() { let diag = arr1(&[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]); let a = Array2::from_diag(&diag); - let x: Array2 = generate::random((10, 1)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let x: Array2 = generate::random_using((10, 1), &mut rng); let y: Array2 = arr2(&[ [1.0, 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 1.0, 0., 0., 0., 0., 0., 0., 0., 0.], diff --git a/ndarray-linalg/src/lobpcg/svd.rs b/ndarray-linalg/src/lobpcg/svd.rs index a796364d..62d18b49 100644 --- a/ndarray-linalg/src/lobpcg/svd.rs +++ b/ndarray-linalg/src/lobpcg/svd.rs @@ -30,7 +30,7 @@ impl + 'static + MagnitudeCorrection> Trunc let mut a = self.eigvals.iter().enumerate().collect::>(); // sort by magnitude - a.sort_by(|(_, x), (_, y)| x.partial_cmp(&y).unwrap().reverse()); + a.sort_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap().reverse()); // calculate cut-off magnitude (borrowed from scipy) let cutoff = A::epsilon() * // float precision @@ -64,7 +64,7 @@ impl + 'static + MagnitudeCorrection> Trunc let mut ularge = self.problem.dot(&vlarge); ularge - .gencolumns_mut() + .columns_mut() .into_iter() .zip(values.iter()) .for_each(|(mut a, b)| a.mapv_inplace(|x| x / *b)); @@ -75,7 +75,7 @@ impl + 'static + MagnitudeCorrection> Trunc let mut vlarge = self.problem.t().dot(&ularge); vlarge - .gencolumns_mut() + .columns_mut() .into_iter() .zip(values.iter()) .for_each(|(mut a, b)| a.mapv_inplace(|x| x / *b)); @@ -214,7 +214,8 @@ mod tests { #[test] fn test_truncated_svd_random() { - let a: Array2 = generate::random((50, 10)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = generate::random_using((50, 10), &mut rng); let res = TruncatedSvd::new(a.clone(), Order::Largest) .precision(1e-5) diff --git a/ndarray-linalg/src/qr.rs b/ndarray-linalg/src/qr.rs index 
441e7cef..4bf2f0ec 100644 --- a/ndarray-linalg/src/qr.rs +++ b/ndarray-linalg/src/qr.rs @@ -135,9 +135,7 @@ where S2: DataMut + DataOwned, { let av = a.slice(s![..n as isize, ..m as isize]); - let mut a = unsafe { ArrayBase::uninitialized((n, m)) }; - a.assign(&av); - a + replicate(&av) } fn take_slice_upper(a: &ArrayBase, n: usize, m: usize) -> ArrayBase @@ -146,10 +144,12 @@ where S1: Data, S2: DataMut + DataOwned, { - let av = a.slice(s![..n as isize, ..m as isize]); - let mut a = unsafe { ArrayBase::uninitialized((n, m)) }; - for ((i, j), val) in a.indexed_iter_mut() { - *val = if i <= j { av[(i, j)] } else { A::zero() }; - } + let av = a.slice(s![..n, ..m]); + let mut a = replicate(&av); + Zip::indexed(&mut a).for_each(|(i, j), elt| { + if i > j { + *elt = A::zero() + } + }); a } diff --git a/ndarray-linalg/src/solve.rs b/ndarray-linalg/src/solve.rs index 2277d227..79df1f77 100644 --- a/ndarray-linalg/src/solve.rs +++ b/ndarray-linalg/src/solve.rs @@ -5,20 +5,13 @@ //! Solve `A * x = b`: //! //! ``` -//! #[macro_use] -//! extern crate ndarray; -//! extern crate ndarray_linalg; -//! //! use ndarray::prelude::*; //! use ndarray_linalg::Solve; -//! # fn main() { //! //! let a: Array2 = array![[3., 2., -1.], [2., -2., 4.], [-2., 1., -2.]]; //! let b: Array1 = array![1., -2., 0.]; //! let x = a.solve_into(b).unwrap(); -//! assert!(x.all_close(&array![1., -2., -2.], 1e-9)); -//! -//! # } +//! assert!(x.abs_diff_eq(&array![1., -2., -2.], 1e-9)); //! ``` //! //! There are also special functions for solving `A^T * x = b` and @@ -29,21 +22,18 @@ //! the beginning than solving directly using `A`: //! //! ``` -//! # extern crate ndarray; -//! # extern crate ndarray_linalg; -//! //! use ndarray::prelude::*; //! use ndarray_linalg::*; -//! # fn main() { //! -//! let a: Array2 = random((3, 3)); +//! /// Use fixed algorithm and seed of PRNG for reproducible test +//! let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); +//! +//! 
let a: Array2 = random_using((3, 3), &mut rng); //! let f = a.factorize_into().unwrap(); // LU factorize A (A is consumed) //! for _ in 0..10 { -//! let b: Array1 = random(3); +//! let b: Array1 = random_using(3, &mut rng); //! let x = f.solve_into(b).unwrap(); // Solve A * x = b using factorized L, U //! } -//! -//! # } //! ``` use ndarray::*; @@ -77,13 +67,24 @@ pub use lax::{Pivot, Transpose}; pub trait Solve { /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b` /// is the argument, and `x` is the successful result. + /// + /// # Panics + /// + /// Panics if the length of `b` is not the equal to the number of columns + /// of `A`. fn solve>(&self, b: &ArrayBase) -> Result> { let mut b = replicate(b); self.solve_inplace(&mut b)?; Ok(b) } + /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b` /// is the argument, and `x` is the successful result. + /// + /// # Panics + /// + /// Panics if the length of `b` is not the equal to the number of columns + /// of `A`. fn solve_into>( &self, mut b: ArrayBase, @@ -91,8 +92,14 @@ pub trait Solve { self.solve_inplace(&mut b)?; Ok(b) } + /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b` /// is the argument, and `x` is the successful result. + /// + /// # Panics + /// + /// Panics if the length of `b` is not the equal to the number of columns + /// of `A`. fn solve_inplace<'a, S: DataMut>( &self, b: &'a mut ArrayBase, @@ -100,13 +107,24 @@ pub trait Solve { /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b` /// is the argument, and `x` is the successful result. + /// + /// # Panics + /// + /// Panics if the length of `b` is not the equal to the number of rows of + /// `A`. fn solve_t>(&self, b: &ArrayBase) -> Result> { let mut b = replicate(b); self.solve_t_inplace(&mut b)?; Ok(b) } + /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b` /// is the argument, and `x` is the successful result. 
+ /// + /// # Panics + /// + /// Panics if the length of `b` is not the equal to the number of rows of + /// `A`. fn solve_t_into>( &self, mut b: ArrayBase, @@ -114,8 +132,14 @@ pub trait Solve { self.solve_t_inplace(&mut b)?; Ok(b) } + /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b` /// is the argument, and `x` is the successful result. + /// + /// # Panics + /// + /// Panics if the length of `b` is not the equal to the number of rows of + /// `A`. fn solve_t_inplace<'a, S: DataMut>( &self, b: &'a mut ArrayBase, @@ -123,6 +147,11 @@ pub trait Solve { /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b` /// is the argument, and `x` is the successful result. + /// + /// # Panics + /// + /// Panics if the length of `b` is not the equal to the number of rows of + /// `A`. fn solve_h>(&self, b: &ArrayBase) -> Result> { let mut b = replicate(b); self.solve_h_inplace(&mut b)?; @@ -130,6 +159,11 @@ pub trait Solve { } /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b` /// is the argument, and `x` is the successful result. + /// + /// # Panics + /// + /// Panics if the length of `b` is not the equal to the number of rows of + /// `A`. fn solve_h_into>( &self, mut b: ArrayBase, @@ -139,6 +173,11 @@ pub trait Solve { } /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b` /// is the argument, and `x` is the successful result. + /// + /// # Panics + /// + /// Panics if the length of `b` is not the equal to the number of rows of + /// `A`. fn solve_h_inplace<'a, S: DataMut>( &self, b: &'a mut ArrayBase, @@ -150,9 +189,9 @@ pub trait Solve { pub struct LUFactorized { /// The factors `L` and `U`; the unit diagonal elements of `L` are not /// stored. - pub a: ArrayBase, + a: ArrayBase, /// The pivot indices that define the permutation matrix `P`. 
- pub ipiv: Pivot, + ipiv: Pivot, } impl Solve for LUFactorized @@ -167,6 +206,11 @@ where where Sb: DataMut, { + assert_eq!( + rhs.len(), + self.a.len_of(Axis(1)), + "The length of `rhs` must be compatible with the shape of the factored matrix.", + ); A::solve( self.a.square_layout()?, Transpose::No, @@ -183,6 +227,11 @@ where where Sb: DataMut, { + assert_eq!( + rhs.len(), + self.a.len_of(Axis(0)), + "The length of `rhs` must be compatible with the shape of the factored matrix.", + ); A::solve( self.a.square_layout()?, Transpose::Transpose, @@ -199,6 +248,11 @@ where where Sb: DataMut, { + assert_eq!( + rhs.len(), + self.a.len_of(Axis(0)), + "The length of `rhs` must be compatible with the shape of the factored matrix.", + ); A::solve( self.a.square_layout()?, Transpose::Hermite, @@ -323,8 +377,15 @@ where type Output = Array2; fn inv(&self) -> Result> { + // Preserve the existing layout. This is required to obtain the correct + // result, because the result of `A::inv` is layout-dependent. + let a = if self.a.is_standard_layout() { + replicate(&self.a) + } else { + replicate(&self.a.t()).reversed_axes() + }; let f = LUFactorized { - a: replicate(&self.a), + a, ipiv: self.ipiv.clone(), }; f.inv_into() @@ -468,7 +529,8 @@ where self.ensure_square()?; match self.factorize() { Ok(fac) => fac.sln_det(), - Err(LinalgError::Lapack(e)) if matches!(e, lax::error::Error::LapackComputationalFailure {..}) => + Err(LinalgError::Lapack(e)) + if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) => { // The determinant is zero. Ok((A::zero(), A::Real::neg_infinity())) @@ -487,7 +549,8 @@ where self.ensure_square()?; match self.factorize_into() { Ok(fac) => fac.sln_det_into(), - Err(LinalgError::Lapack(e)) if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) => + Err(LinalgError::Lapack(e)) + if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) => { // The determinant is zero. 
Ok((A::zero(), A::Real::neg_infinity())) diff --git a/ndarray-linalg/src/solveh.rs b/ndarray-linalg/src/solveh.rs index c70138d4..df0b3a6f 100644 --- a/ndarray-linalg/src/solveh.rs +++ b/ndarray-linalg/src/solveh.rs @@ -8,13 +8,8 @@ //! Solve `A * x = b`, where `A` is a Hermitian (or real symmetric) matrix: //! //! ``` -//! #[macro_use] -//! extern crate ndarray; -//! extern crate ndarray_linalg; -//! //! use ndarray::prelude::*; //! use ndarray_linalg::SolveH; -//! # fn main() { //! //! let a: Array2 = array![ //! [3., 2., -1.], @@ -23,9 +18,7 @@ //! ]; //! let b: Array1 = array![11., -12., 1.]; //! let x = a.solveh_into(b).unwrap(); -//! assert!(x.all_close(&array![1., 3., -2.], 1e-9)); -//! -//! # } +//! assert!(x.abs_diff_eq(&array![1., 3., -2.], 1e-9)); //! ``` //! //! If you are solving multiple systems of linear equations with the same @@ -33,20 +26,18 @@ //! the factorization once at the beginning than solving directly using `A`: //! //! ``` -//! # extern crate ndarray; -//! # extern crate ndarray_linalg; //! use ndarray::prelude::*; //! use ndarray_linalg::*; -//! # fn main() { //! -//! let a: Array2 = random((3, 3)); +//! /// Use fixed algorithm and seed of PRNG for reproducible test +//! let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); +//! +//! let a: Array2 = random_using((3, 3), &mut rng); //! let f = a.factorizeh_into().unwrap(); // Factorize A (A is consumed) //! for _ in 0..10 { -//! let b: Array1 = random(3); +//! let b: Array1 = random_using(3, &mut rng); //! let x = f.solveh_into(b).unwrap(); // Solve A * x = b using the factorization //! } -//! -//! # } //! ``` use ndarray::*; @@ -69,14 +60,25 @@ pub trait SolveH { /// Solves a system of linear equations `A * x = b` with Hermitian (or real /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and /// `x` is the successful result. + /// + /// # Panics + /// + /// Panics if the length of `b` is not the equal to the number of columns + /// of `A`. 
fn solveh>(&self, b: &ArrayBase) -> Result> { let mut b = replicate(b); self.solveh_inplace(&mut b)?; Ok(b) } + /// Solves a system of linear equations `A * x = b` with Hermitian (or real /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and /// `x` is the successful result. + /// + /// # Panics + /// + /// Panics if the length of `b` is not the equal to the number of columns + /// of `A`. fn solveh_into>( &self, mut b: ArrayBase, @@ -84,10 +86,16 @@ pub trait SolveH { self.solveh_inplace(&mut b)?; Ok(b) } + /// Solves a system of linear equations `A * x = b` with Hermitian (or real /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and /// `x` is the successful result. The value of `x` is also assigned to the /// argument. + /// + /// # Panics + /// + /// Panics if the length of `b` is not the equal to the number of columns + /// of `A`. fn solveh_inplace<'a, S: DataMut>( &self, b: &'a mut ArrayBase, @@ -113,6 +121,11 @@ where where Sb: DataMut, { + assert_eq!( + rhs.len(), + self.a.len_of(Axis(1)), + "The length of `rhs` must be compatible with the shape of the factored matrix.", + ); A::solveh( self.a.square_layout()?, UPLO::Upper, @@ -426,7 +439,8 @@ where fn sln_deth(&self) -> Result<(A::Real, A::Real)> { match self.factorizeh() { Ok(fac) => Ok(fac.sln_deth()), - Err(LinalgError::Lapack(e)) if matches!(e, lax::error::Error::LapackComputationalFailure {..}) => + Err(LinalgError::Lapack(e)) + if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) => { // Determinant is zero. Ok((A::Real::zero(), A::Real::neg_infinity())) @@ -451,7 +465,8 @@ where fn sln_deth_into(self) -> Result<(A::Real, A::Real)> { match self.factorizeh_into() { Ok(fac) => Ok(fac.sln_deth_into()), - Err(LinalgError::Lapack(e)) if matches!(e, lax::error::Error::LapackComputationalFailure {..}) => + Err(LinalgError::Lapack(e)) + if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) => { // Determinant is zero. 
Ok((A::Real::zero(), A::Real::neg_infinity())) diff --git a/ndarray-linalg/src/svddc.rs b/ndarray-linalg/src/svddc.rs index 017354c2..0b0ae237 100644 --- a/ndarray-linalg/src/svddc.rs +++ b/ndarray-linalg/src/svddc.rs @@ -3,14 +3,14 @@ use super::{convert::*, error::*, layout::*, types::*}; use ndarray::*; -pub use lax::UVTFlag; +pub use lax::JobSvd; /// Singular-value decomposition of matrix (copying) by divide-and-conquer pub trait SVDDC { type U; type VT; type Sigma; - fn svddc(&self, uvt_flag: UVTFlag) -> Result<(Option, Self::Sigma, Option)>; + fn svddc(&self, uvt_flag: JobSvd) -> Result<(Option, Self::Sigma, Option)>; } /// Singular-value decomposition of matrix by divide-and-conquer @@ -20,7 +20,7 @@ pub trait SVDDCInto { type Sigma; fn svddc_into( self, - uvt_flag: UVTFlag, + uvt_flag: JobSvd, ) -> Result<(Option, Self::Sigma, Option)>; } @@ -31,20 +31,20 @@ pub trait SVDDCInplace { type Sigma; fn svddc_inplace( &mut self, - uvt_flag: UVTFlag, + uvt_flag: JobSvd, ) -> Result<(Option, Self::Sigma, Option)>; } impl SVDDC for ArrayBase where A: Scalar + Lapack, - S: DataMut, + S: Data, { type U = Array2; type VT = Array2; type Sigma = Array1; - fn svddc(&self, uvt_flag: UVTFlag) -> Result<(Option, Self::Sigma, Option)> { + fn svddc(&self, uvt_flag: JobSvd) -> Result<(Option, Self::Sigma, Option)> { self.to_owned().svddc_into(uvt_flag) } } @@ -60,7 +60,7 @@ where fn svddc_into( mut self, - uvt_flag: UVTFlag, + uvt_flag: JobSvd, ) -> Result<(Option, Self::Sigma, Option)> { self.svddc_inplace(uvt_flag) } @@ -77,7 +77,7 @@ where fn svddc_inplace( &mut self, - uvt_flag: UVTFlag, + uvt_flag: JobSvd, ) -> Result<(Option, Self::Sigma, Option)> { let l = self.layout()?; let svd_res = A::svddc(l, uvt_flag, self.as_allocated_mut()?)?; @@ -85,9 +85,9 @@ where let k = m.min(n); let (u_col, vt_row) = match uvt_flag { - UVTFlag::Full => (m, n), - UVTFlag::Some => (k, k), - UVTFlag::None => (0, 0), + JobSvd::All => (m, n), + JobSvd::Some => (k, k), + JobSvd::None => (0, 0), 
}; let u = svd_res diff --git a/ndarray-linalg/src/trace.rs b/ndarray-linalg/src/trace.rs index 3020a9a5..feb119f2 100644 --- a/ndarray-linalg/src/trace.rs +++ b/ndarray-linalg/src/trace.rs @@ -4,7 +4,6 @@ use ndarray::*; use std::iter::Sum; use super::error::*; -use super::layout::*; use super::types::*; pub trait Trace { @@ -20,7 +19,13 @@ where type Output = A; fn trace(&self) -> Result { - let (n, _) = self.square_layout()?.size(); + let n = match self.is_square() { + true => Ok(self.nrows()), + false => Err(LinalgError::NotSquare { + rows: self.nrows() as i32, + cols: self.ncols() as i32, + }), + }?; Ok((0..n as usize).map(|i| self[(i, i)]).sum()) } } diff --git a/ndarray-linalg/src/tridiagonal.rs b/ndarray-linalg/src/tridiagonal.rs index b603a5ee..b5aebbf2 100644 --- a/ndarray-linalg/src/tridiagonal.rs +++ b/ndarray-linalg/src/tridiagonal.rs @@ -272,7 +272,7 @@ where Sb: DataMut, { A::solve_tridiagonal( - &self, + self, rhs.layout()?, Transpose::No, rhs.as_slice_mut().unwrap(), @@ -287,7 +287,7 @@ where Sb: DataMut, { A::solve_tridiagonal( - &self, + self, rhs.layout()?, Transpose::Transpose, rhs.as_slice_mut().unwrap(), @@ -302,7 +302,7 @@ where Sb: DataMut, { A::solve_tridiagonal( - &self, + self, rhs.layout()?, Transpose::Hermite, rhs.as_slice_mut().unwrap(), @@ -622,7 +622,7 @@ where { fn det_tridiagonal(&self) -> Result { let n = self.d.len(); - Ok(rec_rel(&self)[n]) + Ok(rec_rel(self)[n]) } } @@ -671,7 +671,7 @@ where A: Scalar + Lapack, { fn rcond_tridiagonal(&self) -> Result { - Ok(A::rcond_tridiagonal(&self)?) + Ok(A::rcond_tridiagonal(self)?) 
} } diff --git a/ndarray-linalg/tests/arnoldi.rs b/ndarray-linalg/tests/arnoldi.rs index bbc15553..dd56e0a0 100644 --- a/ndarray-linalg/tests/arnoldi.rs +++ b/ndarray-linalg/tests/arnoldi.rs @@ -3,8 +3,9 @@ use ndarray_linalg::{krylov::*, *}; #[test] fn aq_qh_mgs() { - let a: Array2 = random((5, 5)); - let v: Array1 = random(5); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((5, 5), &mut rng); + let v: Array1 = random_using(5, &mut rng); let (q, h) = arnoldi_mgs(a.clone(), v, 1e-9); println!("A = \n{:?}", &a); println!("Q = \n{:?}", &q); @@ -18,8 +19,9 @@ fn aq_qh_mgs() { #[test] fn aq_qh_householder() { - let a: Array2 = random((5, 5)); - let v: Array1 = random(5); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((5, 5), &mut rng); + let v: Array1 = random_using(5, &mut rng); let (q, h) = arnoldi_mgs(a.clone(), v, 1e-9); println!("A = \n{:?}", &a); println!("Q = \n{:?}", &q); @@ -33,8 +35,9 @@ fn aq_qh_householder() { #[test] fn aq_qh_mgs_complex() { - let a: Array2 = random((5, 5)); - let v: Array1 = random(5); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((5, 5), &mut rng); + let v: Array1 = random_using(5, &mut rng); let (q, h) = arnoldi_mgs(a.clone(), v, 1e-9); println!("A = \n{:?}", &a); println!("Q = \n{:?}", &q); @@ -48,8 +51,9 @@ fn aq_qh_mgs_complex() { #[test] fn aq_qh_householder_complex() { - let a: Array2 = random((5, 5)); - let v: Array1 = random(5); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((5, 5), &mut rng); + let v: Array1 = random_using(5, &mut rng); let (q, h) = arnoldi_mgs(a.clone(), v, 1e-9); println!("A = \n{:?}", &a); println!("Q = \n{:?}", &q); diff --git a/ndarray-linalg/tests/cholesky.rs b/ndarray-linalg/tests/cholesky.rs index b45afb5c..d3e9942b 100644 --- a/ndarray-linalg/tests/cholesky.rs +++ b/ndarray-linalg/tests/cholesky.rs @@ -6,7 +6,8 @@ 
macro_rules! cholesky { paste::item! { #[test] fn []() { - let a_orig: Array2<$elem> = random_hpd(3); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a_orig: Array2<$elem> = random_hpd_using(3, &mut rng); println!("a = \n{:?}", a_orig); let upper = a_orig.cholesky(UPLO::Upper).unwrap(); @@ -79,7 +80,8 @@ macro_rules! cholesky_into_lower_upper { paste::item! { #[test] fn []() { - let a: Array2<$elem> = random_hpd(3); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$elem> = random_hpd_using(3, &mut rng); println!("a = \n{:?}", a); let upper = a.cholesky(UPLO::Upper).unwrap(); let fac_upper = a.factorizec(UPLO::Upper).unwrap(); @@ -106,7 +108,8 @@ macro_rules! cholesky_into_inverse { paste::item! { #[test] fn []() { - let a: Array2<$elem> = random_hpd(3); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$elem> = random_hpd_using(3, &mut rng); println!("a = \n{:?}", a); let inv = a.invc().unwrap(); assert_close_l2!(&a.dot(&inv), &Array2::eye(3), $rtol); @@ -134,13 +137,14 @@ macro_rules! cholesky_det { paste::item! { #[test] fn []() { - let a: Array2<$elem> = random_hpd(3); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$elem> = random_hpd_using(3, &mut rng); println!("a = \n{:?}", a); let ln_det = a .eigvalsh(UPLO::Upper) .unwrap() .mapv(|elem| elem.ln()) - .scalar_sum(); + .sum(); let det = ln_det.exp(); assert_aclose!(a.factorizec(UPLO::Upper).unwrap().detc(), det, $atol); assert_aclose!(a.factorizec(UPLO::Upper).unwrap().ln_detc(), ln_det, $atol); @@ -168,8 +172,9 @@ macro_rules! cholesky_solve { paste::item! 
{ #[test] fn []() { - let a: Array2<$elem> = random_hpd(3); - let x: Array1<$elem> = random(3); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$elem> = random_hpd_using(3, &mut rng); + let x: Array1<$elem> = random_using(3, &mut rng); let b = a.dot(&x); println!("a = \n{:?}", a); println!("x = \n{:?}", x); diff --git a/ndarray-linalg/tests/convert.rs b/ndarray-linalg/tests/convert.rs index 3a6155d8..1e20d916 100644 --- a/ndarray-linalg/tests/convert.rs +++ b/ndarray-linalg/tests/convert.rs @@ -3,7 +3,8 @@ use ndarray_linalg::*; #[test] fn generalize() { - let a: Array3 = random((3, 2, 4).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array3 = random_using((3, 2, 4).f(), &mut rng); let ans = a.clone(); let a: Array3 = convert::generalize(a); assert_eq!(a, ans); diff --git a/ndarray-linalg/tests/det.rs b/ndarray-linalg/tests/det.rs index c3986528..40dafd57 100644 --- a/ndarray-linalg/tests/det.rs +++ b/ndarray-linalg/tests/det.rs @@ -94,7 +94,7 @@ fn det_zero_nonsquare() { assert!(a.sln_det_into().is_err()); }; } - for &shape in &[(1, 2).into_shape(), (1, 2).f()] { + for &shape in &[(1, 2).into_shape_with_order(), (1, 2).f()] { det_zero_nonsquare!(f64, shape); det_zero_nonsquare!(f32, shape); det_zero_nonsquare!(c64, shape); @@ -136,15 +136,36 @@ fn det() { assert_rclose!(result.1, ln_det, rtol); } } + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); for rows in 1..5 { - det_impl(random_regular::(rows), 1e-9); - det_impl(random_regular::(rows), 1e-4); - det_impl(random_regular::(rows), 1e-9); - det_impl(random_regular::(rows), 1e-4); - det_impl(random_regular::(rows).t().to_owned(), 1e-9); - det_impl(random_regular::(rows).t().to_owned(), 1e-4); - det_impl(random_regular::(rows).t().to_owned(), 1e-9); - det_impl(random_regular::(rows).t().to_owned(), 1e-4); + det_impl(random_regular_using::(rows, &mut rng), 1e-9); + det_impl(random_regular_using::(rows, &mut rng), 1e-4); + 
det_impl(random_regular_using::(rows, &mut rng), 1e-9); + det_impl(random_regular_using::(rows, &mut rng), 1e-4); + det_impl( + random_regular_using::(rows, &mut rng) + .t() + .to_owned(), + 1e-9, + ); + det_impl( + random_regular_using::(rows, &mut rng) + .t() + .to_owned(), + 1e-4, + ); + det_impl( + random_regular_using::(rows, &mut rng) + .t() + .to_owned(), + 1e-9, + ); + det_impl( + random_regular_using::(rows, &mut rng) + .t() + .to_owned(), + 1e-4, + ); } } @@ -152,7 +173,8 @@ fn det() { fn det_nonsquare() { macro_rules! det_nonsquare { ($elem:ty, $shape:expr) => { - let a: Array2<$elem> = random($shape); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$elem> = random_using($shape, &mut rng); assert!(a.factorize().unwrap().det().is_err()); assert!(a.factorize().unwrap().sln_det().is_err()); assert!(a.factorize().unwrap().det_into().is_err()); @@ -164,7 +186,7 @@ fn det_nonsquare() { }; } for &dims in &[(1, 0), (1, 2), (2, 1), (2, 3)] { - for &shape in &[dims.into_shape(), dims.f()] { + for &shape in &[dims.into_shape_with_order(), dims.f()] { det_nonsquare!(f64, shape); det_nonsquare!(f32, shape); det_nonsquare!(c64, shape); diff --git a/ndarray-linalg/tests/deth.rs b/ndarray-linalg/tests/deth.rs index abd54105..9079c342 100644 --- a/ndarray-linalg/tests/deth.rs +++ b/ndarray-linalg/tests/deth.rs @@ -60,7 +60,7 @@ fn deth_zero_nonsquare() { assert!(a.sln_deth_into().is_err()); }; } - for &shape in &[(1, 2).into_shape(), (1, 2).f()] { + for &shape in &[(1, 2).into_shape_with_order(), (1, 2).f()] { deth_zero_nonsquare!(f64, shape); deth_zero_nonsquare!(f32, shape); deth_zero_nonsquare!(c64, shape); @@ -72,7 +72,8 @@ fn deth_zero_nonsquare() { fn deth() { macro_rules! 
deth { ($elem:ty, $rows:expr, $atol:expr) => { - let a: Array2<$elem> = random_hermite($rows); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$elem> = random_hermite_using($rows, &mut rng); println!("a = \n{:?}", a); // Compute determinant from eigenvalues. @@ -137,7 +138,7 @@ fn deth_nonsquare() { }; } for &dims in &[(1, 0), (1, 2), (2, 1), (2, 3)] { - for &shape in &[dims.into_shape(), dims.f()] { + for &shape in &[dims.into_shape_with_order(), dims.f()] { deth_nonsquare!(f64, shape); deth_nonsquare!(f32, shape); deth_nonsquare!(c64, shape); diff --git a/ndarray-linalg/tests/eig.rs b/ndarray-linalg/tests/eig.rs index 28314b8a..8fb40212 100644 --- a/ndarray-linalg/tests/eig.rs +++ b/ndarray-linalg/tests/eig.rs @@ -1,9 +1,25 @@ use ndarray::*; use ndarray_linalg::*; +fn sorted_eigvals(eigvals: ArrayView1<'_, T>) -> Array1 { + let mut indices: Vec = (0..eigvals.len()).collect(); + indices.sort_by(|&ind1, &ind2| { + let e1 = eigvals[ind1]; + let e2 = eigvals[ind2]; + e1.re() + .partial_cmp(&e2.re()) + .unwrap() + .then(e1.im().partial_cmp(&e2.im()).unwrap()) + }); + indices.iter().map(|&ind| eigvals[ind]).collect() +} + // Test Av_i = e_i v_i for i = 0..n -fn test_eig(a: Array2, eigs: Array1, vecs: Array2) -where +fn test_eig( + a: ArrayView2<'_, T>, + eigs: ArrayView1<'_, T::Complex>, + vecs: ArrayView2<'_, T::Complex>, +) where T::Complex: Lapack, { println!("a\n{:+.4}", &a); @@ -87,7 +103,10 @@ fn test_matrix_real() -> Array2 { } fn test_matrix_real_t() -> Array2 { - test_matrix_real::().t().permuted_axes([1, 0]).to_owned() + let orig = test_matrix_real::(); + let mut out = Array2::zeros(orig.raw_dim().f()); + out.assign(&orig); + out } fn answer_eig_real() -> Array1 { @@ -154,10 +173,10 @@ fn test_matrix_complex() -> Array2 { } fn test_matrix_complex_t() -> Array2 { - test_matrix_complex::() - .t() - .permuted_axes([1, 0]) - .to_owned() + let orig = test_matrix_complex::(); + let mut out = Array2::zeros(orig.raw_dim().f()); + 
out.assign(&orig); + out } fn answer_eig_complex() -> Array1 { @@ -205,29 +224,45 @@ macro_rules! impl_test_real { #[test] fn [<$real _eigvals >]() { let a = test_matrix_real::<$real>(); - let (e, _vecs) = a.eig().unwrap(); - assert_close_l2!(&e, &answer_eig_real::<$real>(), 1.0e-3); + let (e1, _vecs) = a.eig().unwrap(); + let e2 = a.eigvals().unwrap(); + assert_close_l2!(&e1, &answer_eig_real::<$real>(), 1.0e-3); + assert_close_l2!(&e2, &answer_eig_real::<$real>(), 1.0e-3); } #[test] fn [<$real _eigvals_t>]() { let a = test_matrix_real_t::<$real>(); - let (e, _vecs) = a.eig().unwrap(); - assert_close_l2!(&e, &answer_eig_real::<$real>(), 1.0e-3); + let (e1, _vecs) = a.eig().unwrap(); + assert_close_l2!( + &sorted_eigvals(e1.view()), + &sorted_eigvals(answer_eig_real::<$real>().view()), + 1.0e-3 + ); + let e2 = a.eigvals().unwrap(); + assert_close_l2!( + &sorted_eigvals(e2.view()), + &sorted_eigvals(answer_eig_real::<$real>().view()), + 1.0e-3 + ); } #[test] fn [<$real _eig>]() { let a = test_matrix_real::<$real>(); - let (e, vecs) = a.eig().unwrap(); - test_eig(a, e, vecs); + let (e1, vecs) = a.eig().unwrap(); + let e2 = a.eigvals().unwrap(); + test_eig(a.view(), e1.view(), vecs.view()); + test_eig(a.view(), e2.view(), vecs.view()); } #[test] fn [<$real _eig_t>]() { let a = test_matrix_real_t::<$real>(); - let (e, vecs) = a.eig().unwrap(); - test_eig(a, e, vecs); + let (e1, vecs) = a.eig().unwrap(); + let e2 = a.eigvals().unwrap(); + test_eig(a.view(), e1.view(), vecs.view()); + test_eig(a.view(), e2.view(), vecs.view()); } } // paste::item! @@ -243,15 +278,19 @@ macro_rules! 
impl_test_complex { #[test] fn [<$complex _eigvals >]() { let a = test_matrix_complex::<$complex>(); - let (e, _vecs) = a.eig().unwrap(); - assert_close_l2!(&e, &answer_eig_complex::<$complex>(), 1.0e-3); + let (e1, _vecs) = a.eig().unwrap(); + let e2 = a.eigvals().unwrap(); + assert_close_l2!(&e1, &answer_eig_complex::<$complex>(), 1.0e-3); + assert_close_l2!(&e2, &answer_eig_complex::<$complex>(), 1.0e-3); } #[test] fn [<$complex _eigvals_t>]() { let a = test_matrix_complex_t::<$complex>(); - let (e, _vecs) = a.eig().unwrap(); - assert_close_l2!(&e, &answer_eig_complex::<$complex>(), 1.0e-3); + let (e1, _vecs) = a.eig().unwrap(); + let e2 = a.eigvals().unwrap(); + assert_close_l2!(&e1, &answer_eig_complex::<$complex>(), 1.0e-3); + assert_close_l2!(&e2, &answer_eig_complex::<$complex>(), 1.0e-3); } #[test] @@ -271,15 +310,19 @@ macro_rules! impl_test_complex { #[test] fn [<$complex _eig>]() { let a = test_matrix_complex::<$complex>(); - let (e, vecs) = a.eig().unwrap(); - test_eig(a, e, vecs); + let (e1, vecs) = a.eig().unwrap(); + let e2 = a.eigvals().unwrap(); + test_eig(a.view(), e1.view(), vecs.view()); + test_eig(a.view(), e2.view(), vecs.view()); } #[test] fn [<$complex _eig_t>]() { let a = test_matrix_complex_t::<$complex>(); - let (e, vecs) = a.eig().unwrap(); - test_eig(a, e, vecs); + let (e1, vecs) = a.eig().unwrap(); + let e2 = a.eigvals().unwrap(); + test_eig(a.view(), e1.view(), vecs.view()); + test_eig(a.view(), e2.view(), vecs.view()); } } // paste::item! 
}; diff --git a/ndarray-linalg/tests/eig_generalized.rs b/ndarray-linalg/tests/eig_generalized.rs new file mode 100644 index 00000000..06df81ec --- /dev/null +++ b/ndarray-linalg/tests/eig_generalized.rs @@ -0,0 +1,190 @@ +use ndarray::*; +use ndarray_linalg::*; + +#[test] +fn generalized_eigenvalue_fmt() { + let ge0 = GeneralizedEigenvalue::Finite(0.1, (1.0, 10.0)); + assert_eq!(ge0.to_string(), "1.000e-1 (1.000e0/1.000e1)".to_string()); + + let ge1 = GeneralizedEigenvalue::Indeterminate((1.0, 0.0)); + assert_eq!(ge1.to_string(), "∞ (1.000e0/0.000e0)".to_string()); +} + +#[test] +fn real_a_real_b_3x3_full_rank() { + #[rustfmt::skip] + let a = array![ + [ 2.0, 1.0, 8.0], + [-2.0, 0.0, 3.0], + [ 7.0, 6.0, 5.0], + ]; + #[rustfmt::skip] + let b = array![ + [ 1.0, 2.0, -7.0], + [-3.0, 1.0, 6.0], + [ 4.0, -5.0, 1.0], + ]; + let (geneigvals, eigvecs) = (a.clone(), b.clone()).eig_generalized(None).unwrap(); + + let a = a.map(|v| v.as_c()); + let b = b.map(|v| v.as_c()); + for (ge, vec) in geneigvals.iter().zip(eigvecs.columns()) { + if let GeneralizedEigenvalue::Finite(e, _) = ge { + let ebv = b.dot(&vec).map(|v| v * e); + let av = a.dot(&vec); + assert_close_l2!(&av, &ebv, 1e-7); + } + } + + let mut eigvals = geneigvals + .iter() + .filter_map(|ge: &GeneralizedEigenvalue| match ge { + GeneralizedEigenvalue::Finite(e, _) => Some(e.clone()), + GeneralizedEigenvalue::Indeterminate(_) => None, + }) + .collect::>(); + eigvals.sort_by(|a, b| a.re().partial_cmp(&b.re()).unwrap()); + let eigvals = Array1::from_vec(eigvals); + // Reference eigenvalues from Mathematica + assert_close_l2!( + &eigvals, + &array![-0.4415795111, 0.5619249537, 50.87965456].map(c64::from), + 1e-7 + ); +} + +#[test] +fn real_a_real_b_3x3_nullity_1() { + #[rustfmt::skip] + let a = array![ + [ 2.0, 1.0, 8.0], + [-2.0, 0.0, 3.0], + [ 7.0, 6.0, 5.0], + ]; + #[rustfmt::skip] + let b = array![ + [1.0, 2.0, 3.0], + [0.0, 1.0, 1.0], + [1.0, -1.0, 0.0], + ]; + let (geneigvals, eigvecs) = (a.clone(), 
b.clone()).eig_generalized(Some(1e-4)).unwrap(); + + let a = a.map(|v| v.as_c()); + let b = b.map(|v| v.as_c()); + for (ge, vec) in geneigvals.iter().zip(eigvecs.columns()) { + if let GeneralizedEigenvalue::Finite(e, _) = ge { + let ebv = b.dot(&vec).map(|v| v * e); + let av = a.dot(&vec); + assert_close_l2!(&av, &ebv, 1e-7); + } + } + + let mut eigvals = geneigvals + .iter() + .filter_map(|ge: &GeneralizedEigenvalue| match ge { + GeneralizedEigenvalue::Finite(e, _) => Some(e.clone()), + GeneralizedEigenvalue::Indeterminate(_) => None, + }) + .collect::>(); + eigvals.sort_by(|a, b| a.re().partial_cmp(&b.re()).unwrap()); + let eigvals = Array1::from_vec(eigvals); + // Reference eigenvalues from Mathematica + assert_close_l2!( + &eigvals, + &array![-12.91130192, 3.911301921].map(c64::from), + 1e-7 + ); +} + +#[test] +fn complex_a_complex_b_3x3_full_rank() { + #[rustfmt::skip] + let a = array![ + [c64::new(1.0, 2.0), c64::new(-3.0, 0.5), c64::new( 0.0, -1.0)], + [c64::new(2.5, -4.0), c64::new( 1.0, 1.0), c64::new(-1.5, 2.5)], + [c64::new(0.0, 0.0), c64::new( 3.0, -2.0), c64::new( 4.0, 4.0)], + ]; + #[rustfmt::skip] + let b = array![ + [c64::new(-2.0, 1.0), c64::new( 3.5, -1.0), c64::new( 1.0, 1.0)], + [c64::new( 0.0, -3.0), c64::new( 2.0, 2.0), c64::new(-4.0, 0.0)], + [c64::new( 5.0, 5.0), c64::new(-1.5, 1.5), c64::new( 0.0, -2.0)], + ]; + let (geneigvals, eigvecs) = (a.clone(), b.clone()).eig_generalized(None).unwrap(); + + let a = a.map(|v| v.as_c()); + let b = b.map(|v| v.as_c()); + for (ge, vec) in geneigvals.iter().zip(eigvecs.columns()) { + if let GeneralizedEigenvalue::Finite(e, _) = ge { + let ebv = b.dot(&vec).map(|v| v * e); + let av = a.dot(&vec); + assert_close_l2!(&av, &ebv, 1e-7); + } + } + + let mut eigvals = geneigvals + .iter() + .filter_map(|ge: &GeneralizedEigenvalue| match ge { + GeneralizedEigenvalue::Finite(e, _) => Some(e.clone()), + GeneralizedEigenvalue::Indeterminate(_) => None, + }) + .collect::>(); + eigvals.sort_by(|a, b| 
a.re().partial_cmp(&b.re()).unwrap()); + let eigvals = Array1::from_vec(eigvals); + // Reference eigenvalues from Mathematica + assert_close_l2!( + &eigvals, + &array![ + c64::new(-0.701598, -1.71262), + c64::new(-0.67899, -0.0172468), + c64::new(0.59059, 0.276034) + ], + 1e-5 + ); +} + +#[test] +fn complex_a_complex_b_3x3_nullity_1() { + #[rustfmt::skip] + let a = array![ + [c64::new(1.0, 2.0), c64::new(-3.0, 0.5), c64::new( 0.0, -1.0)], + [c64::new(2.5, -4.0), c64::new( 1.0, 1.0), c64::new(-1.5, 2.5)], + [c64::new(0.0, 0.0), c64::new( 3.0, -2.0), c64::new( 4.0, 4.0)], + ]; + #[rustfmt::skip] + let b = array![ + [c64::new(-2.55604, -4.10176), c64::new(9.03944, 3.745000), c64::new(35.4641, 21.1704)], + [c64::new( 7.85029, 7.02144), c64::new(9.23225, -0.479451), c64::new(13.9507, -16.5402)], + [c64::new(-4.47803, 3.98981), c64::new(9.44434, -4.519970), c64::new(40.9006, -23.5060)], + ]; + let (geneigvals, eigvecs) = (a.clone(), b.clone()).eig_generalized(Some(1e-4)).unwrap(); + + let a = a.map(|v| v.as_c()); + let b = b.map(|v| v.as_c()); + for (ge, vec) in geneigvals.iter().zip(eigvecs.columns()) { + if let GeneralizedEigenvalue::Finite(e, _) = ge { + let ebv = b.dot(&vec).map(|v| v * e); + let av = a.dot(&vec); + assert_close_l2!(&av, &ebv, 1e-7); + } + } + + let mut eigvals = geneigvals + .iter() + .filter_map(|ge: &GeneralizedEigenvalue| match ge { + GeneralizedEigenvalue::Finite(e, _) => Some(e.clone()), + GeneralizedEigenvalue::Indeterminate(_) => None, + }) + .collect::>(); + eigvals.sort_by(|a, b| a.re().partial_cmp(&b.re()).unwrap()); + let eigvals = Array1::from_vec(eigvals); + // Reference eigenvalues from Mathematica + assert_close_l2!( + &eigvals, + &array![ + c64::new(-0.0620674, -0.270016), + c64::new(0.0218236, 0.0602709), + ], + 1e-5 + ); +} diff --git a/ndarray-linalg/tests/eigh.rs b/ndarray-linalg/tests/eigh.rs index 77be699b..8d8ce385 100644 --- a/ndarray-linalg/tests/eigh.rs +++ b/ndarray-linalg/tests/eigh.rs @@ -1,6 +1,14 @@ use ndarray::*; use 
ndarray_linalg::*; +#[should_panic] +#[test] +fn eigh_generalized_shape_mismatch() { + let a = Array2::::eye(3); + let b = Array2::::eye(2); + let _ = (a, b).eigh_inplace(UPLO::Upper); +} + #[test] fn fixed() { let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]); @@ -71,7 +79,8 @@ fn fixed_t_lower() { #[test] fn ssqrt() { - let a: Array2 = random_hpd(3); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_hpd_using(3, &mut rng); let ans = a.clone(); let s = a.ssqrt(UPLO::Upper).unwrap(); println!("a = {:?}", &ans); @@ -84,7 +93,8 @@ fn ssqrt() { #[test] fn ssqrt_t() { - let a: Array2 = random_hpd(3).reversed_axes(); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_hpd_using(3, &mut rng).reversed_axes(); let ans = a.clone(); let s = a.ssqrt(UPLO::Upper).unwrap(); println!("a = {:?}", &ans); @@ -97,7 +107,8 @@ fn ssqrt_t() { #[test] fn ssqrt_lower() { - let a: Array2 = random_hpd(3); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_hpd_using(3, &mut rng); let ans = a.clone(); let s = a.ssqrt(UPLO::Lower).unwrap(); println!("a = {:?}", &ans); @@ -110,7 +121,8 @@ fn ssqrt_lower() { #[test] fn ssqrt_t_lower() { - let a: Array2 = random_hpd(3).reversed_axes(); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_hpd_using(3, &mut rng).reversed_axes(); let ans = a.clone(); let s = a.ssqrt(UPLO::Lower).unwrap(); println!("a = {:?}", &ans); diff --git a/ndarray-linalg/tests/householder.rs b/ndarray-linalg/tests/householder.rs index adc4d9b2..83b500f7 100644 --- a/ndarray-linalg/tests/householder.rs +++ b/ndarray-linalg/tests/householder.rs @@ -3,7 +3,8 @@ use ndarray_linalg::{krylov::*, *}; fn over(rtol: A::Real) { const N: usize = 4; - let a: Array2 = random((N, N * 2)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((N, N * 2), &mut rng); // Terminate let (q, 
r) = householder(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate); @@ -45,7 +46,8 @@ fn over_c64() { fn full(rtol: A::Real) { const N: usize = 5; - let a: Array2 = random((N, N)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((N, N), &mut rng); let (q, r) = householder(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate); let qc: Array2 = conjugate(&q); assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol; "Check Q^H Q = I"); @@ -71,7 +73,8 @@ fn full_c64() { fn half(rtol: A::Real) { const N: usize = 4; - let a: Array2 = random((N, N / 2)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((N, N / 2), &mut rng); let (q, r) = householder(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate); let qc: Array2 = conjugate(&q); assert_close_l2!(&qc.dot(&q), &Array::eye(N / 2), rtol; "Check Q^H Q = I"); diff --git a/ndarray-linalg/tests/inner.rs b/ndarray-linalg/tests/inner.rs index 7fc42c83..076b2791 100644 --- a/ndarray-linalg/tests/inner.rs +++ b/ndarray-linalg/tests/inner.rs @@ -19,7 +19,8 @@ fn size_longer() { #[test] fn abs() { - let a: Array1 = random(1); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array1 = random_using(1, &mut rng); let aa = a.inner(&a); assert_aclose!(aa.re(), a.norm().powi(2), 1e-5); assert_aclose!(aa.im(), 0.0, 1e-5); diff --git a/ndarray-linalg/tests/inv.rs b/ndarray-linalg/tests/inv.rs index cbbcffd0..93ff1aad 100644 --- a/ndarray-linalg/tests/inv.rs +++ b/ndarray-linalg/tests/inv.rs @@ -1,27 +1,114 @@ use ndarray::*; use ndarray_linalg::*; +fn test_inv_random(n: usize, set_f: bool, rtol: A::Real) +where + A: Scalar + Lapack, +{ + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using([n; 2].set_f(set_f), &mut rng); + let identity = Array2::eye(n); + assert_close_l2!(&a.inv().unwrap().dot(&a), &identity, rtol); + assert_close_l2!( + &a.factorize().unwrap().inv().unwrap().dot(&a), + 
&identity, + rtol + ); + assert_close_l2!( + &a.clone().factorize_into().unwrap().inv().unwrap().dot(&a), + &identity, + rtol + ); +} + +fn test_inv_into_random(n: usize, set_f: bool, rtol: A::Real) +where + A: Scalar + Lapack, +{ + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using([n; 2].set_f(set_f), &mut rng); + let identity = Array2::eye(n); + assert_close_l2!(&a.clone().inv_into().unwrap().dot(&a), &identity, rtol); + assert_close_l2!( + &a.factorize().unwrap().inv_into().unwrap().dot(&a), + &identity, + rtol + ); + assert_close_l2!( + &a.clone() + .factorize_into() + .unwrap() + .inv_into() + .unwrap() + .dot(&a), + &identity, + rtol + ); +} + +#[test] +fn inv_empty() { + test_inv_random::(0, false, 0.); + test_inv_random::(0, false, 0.); + test_inv_random::(0, false, 0.); + test_inv_random::(0, false, 0.); +} + +#[test] +fn inv_random_float() { + for n in 1..=8 { + for &set_f in &[false, true] { + test_inv_random::(n, set_f, 1e-3); + test_inv_random::(n, set_f, 1e-9); + } + } +} + +#[test] +fn inv_random_complex() { + for n in 1..=8 { + for &set_f in &[false, true] { + test_inv_random::(n, set_f, 1e-3); + test_inv_random::(n, set_f, 1e-9); + } + } +} + +#[test] +fn inv_into_empty() { + test_inv_into_random::(0, false, 0.); + test_inv_into_random::(0, false, 0.); + test_inv_into_random::(0, false, 0.); + test_inv_into_random::(0, false, 0.); +} + #[test] -fn inv_random() { - let a: Array2 = random((3, 3)); - let ai: Array2<_> = (&a).inv().unwrap(); - let id = Array::eye(3); - assert_close_l2!(&ai.dot(&a), &id, 1e-7); +fn inv_into_random_float() { + for n in 1..=8 { + for &set_f in &[false, true] { + test_inv_into_random::(n, set_f, 1e-3); + test_inv_into_random::(n, set_f, 1e-9); + } + } } #[test] -fn inv_random_t() { - let a: Array2 = random((3, 3).f()); - let ai: Array2<_> = (&a).inv().unwrap(); - let id = Array::eye(3); - assert_close_l2!(&ai.dot(&a), &id, 1e-7); +fn inv_into_random_complex() { + for n in 1..=8 { 
+ for &set_f in &[false, true] { + test_inv_into_random::(n, set_f, 1e-3); + test_inv_into_random::(n, set_f, 1e-9); + } + } } #[test] #[should_panic] fn inv_error() { // do not have inverse - let a = Array::::zeros(9).into_shape((3, 3)).unwrap(); + let a = Array::::zeros(9) + .into_shape_with_order((3, 3)) + .unwrap(); let a_inv = a.inv().unwrap(); println!("{:?}", a_inv); } diff --git a/ndarray-linalg/tests/least_squares.rs b/ndarray-linalg/tests/least_squares.rs index 33e20ca7..17f62b64 100644 --- a/ndarray-linalg/tests/least_squares.rs +++ b/ndarray-linalg/tests/least_squares.rs @@ -5,7 +5,8 @@ use ndarray_linalg::*; /// A is square. `x = A^{-1} b`, `|b - Ax| = 0` fn test_exact(a: Array2) { - let b: Array1 = random(3); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array1 = random_using(3, &mut rng); let result = a.least_squares(&b).unwrap(); // unpack result let x = result.solution; @@ -27,13 +28,15 @@ macro_rules! impl_exact { paste::item! { #[test] fn []() { - let a: Array2<$scalar> = random((3, 3)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((3, 3), &mut rng); test_exact(a) } #[test] fn []() { - let a: Array2<$scalar> = random((3, 3).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((3, 3).f(), &mut rng); test_exact(a) } } @@ -51,7 +54,8 @@ fn test_overdetermined(a: Array2) where T::Real: AbsDiffEq, { - let b: Array1 = random(4); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array1 = random_using(4, &mut rng); let result = a.least_squares(&b).unwrap(); // unpack result let x = result.solution; @@ -73,13 +77,15 @@ macro_rules! impl_overdetermined { paste::item! 
{ #[test] fn []() { - let a: Array2<$scalar> = random((4, 3)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((4, 3), &mut rng); test_overdetermined(a) } #[test] fn []() { - let a: Array2<$scalar> = random((4, 3).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((4, 3).f(), &mut rng); test_overdetermined(a) } } @@ -94,7 +100,8 @@ impl_overdetermined!(c64); /// #column > #row case. /// Linear problem is underdetermined, `|b - Ax| = 0` and `x` is not unique fn test_underdetermined(a: Array2) { - let b: Array1 = random(3); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array1 = random_using(3, &mut rng); let result = a.least_squares(&b).unwrap(); assert_eq!(result.rank, 3); assert!(result.residual_sum_of_squares.is_none()); @@ -110,13 +117,15 @@ macro_rules! impl_underdetermined { paste::item! { #[test] fn []() { - let a: Array2<$scalar> = random((3, 4)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((3, 4), &mut rng); test_underdetermined(a) } #[test] fn []() { - let a: Array2<$scalar> = random((3, 4).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((3, 4).f(), &mut rng); test_underdetermined(a) } } diff --git a/ndarray-linalg/tests/least_squares_nrhs.rs b/ndarray-linalg/tests/least_squares_nrhs.rs index dd7d283c..bcf6d013 100644 --- a/ndarray-linalg/tests/least_squares_nrhs.rs +++ b/ndarray-linalg/tests/least_squares_nrhs.rs @@ -32,29 +32,33 @@ macro_rules! impl_exact { paste::item! 
{ #[test] fn []() { - let a: Array2<$scalar> = random((3, 3)); - let b: Array2<$scalar> = random((3, 2)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((3, 3), &mut rng); + let b: Array2<$scalar> = random_using((3, 2), &mut rng); test_exact(a, b) } #[test] fn []() { - let a: Array2<$scalar> = random((3, 3)); - let b: Array2<$scalar> = random((3, 2).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((3, 3), &mut rng); + let b: Array2<$scalar> = random_using((3, 2).f(), &mut rng); test_exact(a, b) } #[test] fn []() { - let a: Array2<$scalar> = random((3, 3).f()); - let b: Array2<$scalar> = random((3, 2)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((3, 3).f(), &mut rng); + let b: Array2<$scalar> = random_using((3, 2), &mut rng); test_exact(a, b) } #[test] fn []() { - let a: Array2<$scalar> = random((3, 3).f()); - let b: Array2<$scalar> = random((3, 2).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((3, 3).f(), &mut rng); + let b: Array2<$scalar> = random_using((3, 2).f(), &mut rng); test_exact(a, b) } } @@ -100,29 +104,33 @@ macro_rules! impl_overdetermined { paste::item! 
{ #[test] fn []() { - let a: Array2<$scalar> = random((4, 3)); - let b: Array2<$scalar> = random((4, 2)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((4, 3), &mut rng); + let b: Array2<$scalar> = random_using((4, 2), &mut rng); test_overdetermined(a, b) } #[test] fn []() { - let a: Array2<$scalar> = random((4, 3).f()); - let b: Array2<$scalar> = random((4, 2)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((4, 3).f(), &mut rng); + let b: Array2<$scalar> = random_using((4, 2), &mut rng); test_overdetermined(a, b) } #[test] fn []() { - let a: Array2<$scalar> = random((4, 3)); - let b: Array2<$scalar> = random((4, 2).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((4, 3), &mut rng); + let b: Array2<$scalar> = random_using((4, 2).f(), &mut rng); test_overdetermined(a, b) } #[test] fn []() { - let a: Array2<$scalar> = random((4, 3).f()); - let b: Array2<$scalar> = random((4, 2).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((4, 3).f(), &mut rng); + let b: Array2<$scalar> = random_using((4, 2).f(), &mut rng); test_overdetermined(a, b) } } @@ -155,29 +163,33 @@ macro_rules! impl_underdetermined { paste::item! 
{ #[test] fn []() { - let a: Array2<$scalar> = random((3, 4)); - let b: Array2<$scalar> = random((3, 2)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((3, 4), &mut rng); + let b: Array2<$scalar> = random_using((3, 2), &mut rng); test_underdetermined(a, b) } #[test] fn []() { - let a: Array2<$scalar> = random((3, 4).f()); - let b: Array2<$scalar> = random((3, 2)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((3, 4).f(), &mut rng); + let b: Array2<$scalar> = random_using((3, 2), &mut rng); test_underdetermined(a, b) } #[test] fn []() { - let a: Array2<$scalar> = random((3, 4)); - let b: Array2<$scalar> = random((3, 2).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((3, 4), &mut rng); + let b: Array2<$scalar> = random_using((3, 2).f(), &mut rng); test_underdetermined(a, b) } #[test] fn []() { - let a: Array2<$scalar> = random((3, 4).f()); - let b: Array2<$scalar> = random((3, 2).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$scalar> = random_using((3, 4).f(), &mut rng); + let b: Array2<$scalar> = random_using((3, 2).f(), &mut rng); test_underdetermined(a, b) } } diff --git a/ndarray-linalg/tests/mgs.rs b/ndarray-linalg/tests/mgs.rs index 35c860de..9e9aa29e 100644 --- a/ndarray-linalg/tests/mgs.rs +++ b/ndarray-linalg/tests/mgs.rs @@ -5,7 +5,8 @@ fn qr_full() { const N: usize = 5; let rtol: A::Real = A::real(1e-9); - let a: Array2 = random((N, N)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((N, N), &mut rng); let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate); assert_close_l2!(&q.dot(&r), &a, rtol); @@ -27,7 +28,8 @@ fn qr() { const N: usize = 4; let rtol: A::Real = A::real(1e-9); - let a: Array2 = random((N, N / 2)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + 
let a: Array2 = random_using((N, N / 2), &mut rng); let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate); assert_close_l2!(&q.dot(&r), &a, rtol); @@ -49,7 +51,8 @@ fn qr_over() { const N: usize = 4; let rtol: A::Real = A::real(1e-9); - let a: Array2 = random((N, N * 2)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((N, N * 2), &mut rng); // Terminate let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate); diff --git a/ndarray-linalg/tests/normalize.rs b/ndarray-linalg/tests/normalize.rs index ca50912e..8d71a009 100644 --- a/ndarray-linalg/tests/normalize.rs +++ b/ndarray-linalg/tests/normalize.rs @@ -3,14 +3,16 @@ use ndarray_linalg::*; #[test] fn n_columns() { - let a: Array2 = random((3, 2)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((3, 2), &mut rng); let (n, v) = normalize(a.clone(), NormalizeAxis::Column); assert_close_l2!(&n.dot(&from_diag(&v)), &a, 1e-7); } #[test] fn n_rows() { - let a: Array2 = random((3, 2)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((3, 2), &mut rng); let (n, v) = normalize(a.clone(), NormalizeAxis::Row); assert_close_l2!(&from_diag(&v).dot(&n), &a, 1e-7); } diff --git a/ndarray-linalg/tests/opnorm.rs b/ndarray-linalg/tests/opnorm.rs index abd41748..cd45d258 100644 --- a/ndarray-linalg/tests/opnorm.rs +++ b/ndarray-linalg/tests/opnorm.rs @@ -14,11 +14,13 @@ fn gen(i: usize, j: usize, rev: bool) -> Array2 { let n = (i * j + 1) as f64; if rev { Array::range(1., n, 1.) - .into_shape((j, i)) + .into_shape_with_order((j, i)) .unwrap() .reversed_axes() } else { - Array::range(1., n, 1.).into_shape((i, j)).unwrap() + Array::range(1., n, 1.) 
+ .into_shape_with_order((i, j)) + .unwrap() } } diff --git a/ndarray-linalg/tests/qr.rs b/ndarray-linalg/tests/qr.rs index a69d89e2..702ed060 100644 --- a/ndarray-linalg/tests/qr.rs +++ b/ndarray-linalg/tests/qr.rs @@ -26,48 +26,56 @@ fn test_square(a: &Array2, n: usize, m: usize) { #[test] fn qr_sq() { - let a = random((3, 3)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using((3, 3), &mut rng); test_square(&a, 3, 3); } #[test] fn qr_sq_t() { - let a = random((3, 3).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using((3, 3).f(), &mut rng); test_square(&a, 3, 3); } #[test] fn qr_3x3() { - let a = random((3, 3)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using((3, 3), &mut rng); test(&a, 3, 3); } #[test] fn qr_3x3_t() { - let a = random((3, 3).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using((3, 3).f(), &mut rng); test(&a, 3, 3); } #[test] fn qr_3x4() { - let a = random((3, 4)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using((3, 4), &mut rng); test(&a, 3, 4); } #[test] fn qr_3x4_t() { - let a = random((3, 4).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using((3, 4).f(), &mut rng); test(&a, 3, 4); } #[test] fn qr_4x3() { - let a = random((4, 3)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using((4, 3), &mut rng); test(&a, 4, 3); } #[test] fn qr_4x3_t() { - let a = random((4, 3).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using((4, 3).f(), &mut rng); test(&a, 4, 3); } diff --git a/ndarray-linalg/tests/solve.rs b/ndarray-linalg/tests/solve.rs index d069ec7a..074350ce 100644 --- a/ndarray-linalg/tests/solve.rs +++ b/ndarray-linalg/tests/solve.rs @@ -1,49 +1,241 @@ -use ndarray::*; -use ndarray_linalg::*; +use ndarray::prelude::*; +use 
ndarray_linalg::{ + assert_aclose, assert_close_l2, c32, c64, random_hpd_using, random_using, solve::*, + OperationNorm, Scalar, +}; +macro_rules! test_solve { + ( + [$($elem_type:ty => $rtol:expr),*], + $a_ident:ident = $a:expr, + $x_ident:ident = $x:expr, + b = $b:expr, + $solve:ident, + ) => { + $({ + let $a_ident: Array2<$elem_type> = $a; + let $x_ident: Array1<$elem_type> = $x; + let b: Array1<$elem_type> = $b; + let a = $a_ident; + let x = $x_ident; + let rtol = $rtol; + assert_close_l2!(&a.$solve(&b).unwrap(), &x, rtol); + assert_close_l2!(&a.factorize().unwrap().$solve(&b).unwrap(), &x, rtol); + assert_close_l2!(&a.factorize_into().unwrap().$solve(&b).unwrap(), &x, rtol); + })* + }; +} + +macro_rules! test_solve_into { + ( + [$($elem_type:ty => $rtol:expr),*], + $a_ident:ident = $a:expr, + $x_ident:ident = $x:expr, + b = $b:expr, + $solve_into:ident, + ) => { + $({ + let $a_ident: Array2<$elem_type> = $a; + let $x_ident: Array1<$elem_type> = $x; + let b: Array1<$elem_type> = $b; + let a = $a_ident; + let x = $x_ident; + let rtol = $rtol; + assert_close_l2!(&a.$solve_into(b.clone()).unwrap(), &x, rtol); + assert_close_l2!(&a.factorize().unwrap().$solve_into(b.clone()).unwrap(), &x, rtol); + assert_close_l2!(&a.factorize_into().unwrap().$solve_into(b.clone()).unwrap(), &x, rtol); + })* + }; +} + +macro_rules! 
test_solve_inplace { + ( + [$($elem_type:ty => $rtol:expr),*], + $a_ident:ident = $a:expr, + $x_ident:ident = $x:expr, + b = $b:expr, + $solve_inplace:ident, + ) => { + $({ + let $a_ident: Array2<$elem_type> = $a; + let $x_ident: Array1<$elem_type> = $x; + let b: Array1<$elem_type> = $b; + let a = $a_ident; + let x = $x_ident; + let rtol = $rtol; + { + let mut b = b.clone(); + assert_close_l2!(&a.$solve_inplace(&mut b).unwrap(), &x, rtol); + assert_close_l2!(&b, &x, rtol); + } + { + let mut b = b.clone(); + assert_close_l2!(&a.factorize().unwrap().$solve_inplace(&mut b).unwrap(), &x, rtol); + assert_close_l2!(&b, &x, rtol); + } + { + let mut b = b.clone(); + assert_close_l2!(&a.factorize_into().unwrap().$solve_inplace(&mut b).unwrap(), &x, rtol); + assert_close_l2!(&b, &x, rtol); + } + })* + }; +} + +macro_rules! test_solve_all { + ( + [$($elem_type:ty => $rtol:expr),*], + $a_ident:ident = $a:expr, + $x_ident:ident = $x:expr, + b = $b:expr, + [$solve:ident, $solve_into:ident, $solve_inplace:ident], + ) => { + test_solve!([$($elem_type => $rtol),*], $a_ident = $a, $x_ident = $x, b = $b, $solve,); + test_solve_into!([$($elem_type => $rtol),*], $a_ident = $a, $x_ident = $x, b = $b, $solve_into,); + test_solve_inplace!([$($elem_type => $rtol),*], $a_ident = $a, $x_ident = $x, b = $b, $solve_inplace,); + }; +} + +#[test] +fn solve_random_float() { + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + for n in 0..=8 { + for &set_f in &[false, true] { + test_solve_all!( + [f32 => 1e-3, f64 => 1e-9], + a = random_using([n; 2].set_f(set_f), &mut rng), + x = random_using(n, &mut rng), + b = a.dot(&x), + [solve, solve_into, solve_inplace], + ); + } + } +} + +#[test] +fn solve_random_complex() { + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + for n in 0..=8 { + for &set_f in &[false, true] { + test_solve_all!( + [c32 => 1e-3, c64 => 1e-9], + a = random_using([n; 2].set_f(set_f), &mut rng), + x = random_using(n, &mut rng), + b = a.dot(&x), + 
[solve, solve_into, solve_inplace], + ); + } + } +} + +#[should_panic] #[test] -fn solve_random() { - let a: Array2 = random((3, 3)); - let x: Array1 = random(3); - let b = a.dot(&x); - let y = a.solve_into(b).unwrap(); - assert_close_l2!(&x, &y, 1e-7); +fn solve_shape_mismatch() { + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((3, 3), &mut rng); + let b: Array1 = random_using(2, &mut rng); + let _ = a.solve_into(b); } #[test] -fn solve_random_t() { - let a: Array2 = random((3, 3).f()); - let x: Array1 = random(3); - let b = a.dot(&x); - let y = a.solve_into(b).unwrap(); - assert_close_l2!(&x, &y, 1e-7); +fn solve_t_random_float() { + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + for n in 0..=8 { + for &set_f in &[false, true] { + test_solve_all!( + [f32 => 1e-3, f64 => 1e-9], + a = random_using([n; 2].set_f(set_f), &mut rng), + x = random_using(n, &mut rng), + b = a.t().dot(&x), + [solve_t, solve_t_into, solve_t_inplace], + ); + } + } +} + +#[should_panic] +#[test] +fn solve_t_shape_mismatch() { + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((3, 3).f(), &mut rng); + let b: Array1 = random_using(4, &mut rng); + let _ = a.solve_into(b); +} + +#[test] +fn solve_t_random_complex() { + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + for n in 0..=8 { + for &set_f in &[false, true] { + test_solve_all!( + [c32 => 1e-3, c64 => 1e-9], + a = random_using([n; 2].set_f(set_f), &mut rng), + x = random_using(n, &mut rng), + b = a.t().dot(&x), + [solve_t, solve_t_into, solve_t_inplace], + ); + } + } } +#[should_panic] #[test] -fn solve_factorized() { - let a: Array2 = random((3, 3)); - let ans: Array1 = random(3); - let b = a.dot(&ans); +fn solve_factorized_shape_mismatch() { + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((3, 3), &mut rng); + let b: Array1 = random_using(4, &mut rng); let f = 
a.factorize_into().unwrap(); - let x = f.solve_into(b).unwrap(); - assert_close_l2!(&x, &ans, 1e-7); + let _ = f.solve_into(b); } #[test] -fn solve_factorized_t() { - let a: Array2 = random((3, 3).f()); - let ans: Array1 = random(3); - let b = a.dot(&ans); +fn solve_h_random_float() { + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + for n in 0..=8 { + for &set_f in &[false, true] { + test_solve_all!( + [f32 => 1e-3, f64 => 1e-9], + a = random_using([n; 2].set_f(set_f), &mut rng), + x = random_using(n, &mut rng), + b = a.t().mapv(|x| x.conj()).dot(&x), + [solve_h, solve_h_into, solve_h_inplace], + ); + } + } +} + +#[should_panic] +#[test] +fn solve_factorized_t_shape_mismatch() { + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((3, 3).f(), &mut rng); + let b: Array1 = random_using(4, &mut rng); let f = a.factorize_into().unwrap(); - let x = f.solve_into(b).unwrap(); - assert_close_l2!(&x, &ans, 1e-7); + let _ = f.solve_into(b); +} + +#[test] +fn solve_h_random_complex() { + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + for n in 0..=8 { + for &set_f in &[false, true] { + test_solve_all!( + [c32 => 1e-3, c64 => 1e-9], + a = random_using([n; 2].set_f(set_f), &mut rng), + x = random_using(n, &mut rng), + b = a.t().mapv(|x| x.conj()).dot(&x), + [solve_h, solve_h_into, solve_h_inplace], + ); + } + } } #[test] fn rcond() { macro_rules! rcond { ($elem:ty, $rows:expr, $atol:expr) => { - let a: Array2<$elem> = random_hpd($rows); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2<$elem> = random_hpd_using($rows, &mut rng); let rcond = 1. / (a.opnorm_one().unwrap() * a.inv().unwrap().opnorm_one().unwrap()); assert_aclose!(a.rcond().unwrap(), rcond, $atol); assert_aclose!(a.rcond_into().unwrap(), rcond, $atol); @@ -62,7 +254,7 @@ fn rcond_hilbert() { macro_rules! 
rcond_hilbert { ($elem:ty, $rows:expr, $atol:expr) => { let a = Array2::<$elem>::from_shape_fn(($rows, $rows), |(i, j)| { - 1. / (i as $elem + j as $elem - 1.) + 1. / (i as $elem + j as $elem + 1.) }); assert_aclose!(a.rcond().unwrap(), 0., $atol); assert_aclose!(a.rcond_into().unwrap(), 0., $atol); diff --git a/ndarray-linalg/tests/solveh.rs b/ndarray-linalg/tests/solveh.rs index 1074057f..25513551 100644 --- a/ndarray-linalg/tests/solveh.rs +++ b/ndarray-linalg/tests/solveh.rs @@ -1,10 +1,30 @@ use ndarray::*; use ndarray_linalg::*; +#[should_panic] +#[test] +fn solveh_shape_mismatch() { + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_hpd_using(3, &mut rng); + let b: Array1 = random_using(2, &mut rng); + let _ = a.solveh_into(b); +} + +#[should_panic] +#[test] +fn factorizeh_solveh_shape_mismatch() { + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_hpd_using(3, &mut rng); + let b: Array1 = random_using(2, &mut rng); + let f = a.factorizeh_into().unwrap(); + let _ = f.solveh_into(b); +} + #[test] fn solveh_random() { - let a: Array2 = random_hpd(3); - let x: Array1 = random(3); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_hpd_using(3, &mut rng); + let x: Array1 = random_using(3, &mut rng); let b = a.dot(&x); let y = a.solveh_into(b).unwrap(); assert_close_l2!(&x, &y, 1e-7); @@ -15,10 +35,30 @@ fn solveh_random() { assert_close_l2!(&x, &y, 1e-7); } +#[should_panic] +#[test] +fn solveh_t_shape_mismatch() { + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_hpd_using(3, &mut rng).reversed_axes(); + let b: Array1 = random_using(2, &mut rng); + let _ = a.solveh_into(b); +} + +#[should_panic] +#[test] +fn factorizeh_solveh_t_shape_mismatch() { + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_hpd_using(3, &mut rng).reversed_axes(); + let b: Array1 = random_using(2, &mut 
rng); + let f = a.factorizeh_into().unwrap(); + let _ = f.solveh_into(b); +} + #[test] fn solveh_random_t() { - let a: Array2 = random_hpd(3).reversed_axes(); - let x: Array1 = random(3); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_hpd_using(3, &mut rng).reversed_axes(); + let x: Array1 = random_using(3, &mut rng); let b = a.dot(&x); let y = a.solveh_into(b).unwrap(); assert_close_l2!(&x, &y, 1e-7); diff --git a/ndarray-linalg/tests/svd.rs b/ndarray-linalg/tests/svd.rs index c83885e1..0eac35ea 100644 --- a/ndarray-linalg/tests/svd.rs +++ b/ndarray-linalg/tests/svd.rs @@ -53,13 +53,15 @@ macro_rules! test_svd_impl { paste::item! { #[test] fn []() { - let a = random(($n, $m)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using(($n, $m), &mut rng); $test::<$type>(&a); } #[test] fn []() { - let a = random(($n, $m).f()); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using(($n, $m).f(), &mut rng); $test::<$type>(&a); } } diff --git a/ndarray-linalg/tests/svddc.rs b/ndarray-linalg/tests/svddc.rs index fb26c8d5..60c6bd66 100644 --- a/ndarray-linalg/tests/svddc.rs +++ b/ndarray-linalg/tests/svddc.rs @@ -1,16 +1,16 @@ use ndarray::*; use ndarray_linalg::*; -fn test(a: &Array2, flag: UVTFlag) { +fn test(a: &Array2, flag: JobSvd) { let (n, m) = a.dim(); let k = n.min(m); let answer = a.clone(); println!("a = \n{:?}", a); let (u, s, vt): (_, Array1<_>, _) = a.svddc(flag).unwrap(); let mut sm: Array2 = match flag { - UVTFlag::Full => Array::zeros((n, m)), - UVTFlag::Some => Array::zeros((k, k)), - UVTFlag::None => { + JobSvd::All => Array::zeros((n, m)), + JobSvd::Some => Array::zeros((k, k)), + JobSvd::None => { assert!(u.is_none()); assert!(vt.is_none()); return; @@ -32,38 +32,44 @@ macro_rules! test_svd_impl { paste::item! 
{ #[test] fn []() { - let a = random(($n, $m)); - test::<$scalar>(&a, UVTFlag::Full); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using(($n, $m), &mut rng); + test::<$scalar>(&a, JobSvd::All); } #[test] fn []() { - let a = random(($n, $m)); - test::<$scalar>(&a, UVTFlag::Some); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using(($n, $m), &mut rng); + test::<$scalar>(&a, JobSvd::Some); } #[test] fn []() { - let a = random(($n, $m)); - test::<$scalar>(&a, UVTFlag::None); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using(($n, $m), &mut rng); + test::<$scalar>(&a, JobSvd::None); } #[test] fn []() { - let a = random(($n, $m).f()); - test::<$scalar>(&a, UVTFlag::Full); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using(($n, $m).f(), &mut rng); + test::<$scalar>(&a, JobSvd::All); } #[test] fn []() { - let a = random(($n, $m).f()); - test::<$scalar>(&a, UVTFlag::Some); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using(($n, $m).f(), &mut rng); + test::<$scalar>(&a, JobSvd::Some); } #[test] fn []() { - let a = random(($n, $m).f()); - test::<$scalar>(&a, UVTFlag::None); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a = random_using(($n, $m).f(), &mut rng); + test::<$scalar>(&a, JobSvd::None); } } }; diff --git a/ndarray-linalg/tests/trace.rs b/ndarray-linalg/tests/trace.rs index 9127be39..f93bb52b 100644 --- a/ndarray-linalg/tests/trace.rs +++ b/ndarray-linalg/tests/trace.rs @@ -3,6 +3,7 @@ use ndarray_linalg::*; #[test] fn trace() { - let a: Array2 = random((3, 3)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let a: Array2 = random_using((3, 3), &mut rng); assert_rclose!(a.trace().unwrap(), a[(0, 0)] + a[(1, 1)] + a[(2, 2)], 1e-7); } diff --git a/ndarray-linalg/tests/triangular.rs b/ndarray-linalg/tests/triangular.rs index ca609c5c..4ddf0a8b 100644 
--- a/ndarray-linalg/tests/triangular.rs +++ b/ndarray-linalg/tests/triangular.rs @@ -34,87 +34,99 @@ where #[test] fn triangular_1d_upper() { let n = 3; - let b: Array1 = random(n); - let a: Array2 = random((n, n)).into_triangular(UPLO::Upper); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array1 = random_using(n, &mut rng); + let a: Array2 = random_using((n, n), &mut rng).into_triangular(UPLO::Upper); test1d(UPLO::Upper, &a, &b, 1e-7); } #[test] fn triangular_1d_lower() { let n = 3; - let b: Array1 = random(n); - let a: Array2 = random((n, n)).into_triangular(UPLO::Lower); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array1 = random_using(n, &mut rng); + let a: Array2 = random_using((n, n), &mut rng).into_triangular(UPLO::Lower); test1d(UPLO::Lower, &a, &b, 1e-7); } #[test] fn triangular_1d_upper_t() { let n = 3; - let b: Array1 = random(n); - let a: Array2 = random((n, n).f()).into_triangular(UPLO::Upper); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array1 = random_using(n, &mut rng); + let a: Array2 = random_using((n, n).f(), &mut rng).into_triangular(UPLO::Upper); test1d(UPLO::Upper, &a, &b, 1e-7); } #[test] fn triangular_1d_lower_t() { let n = 3; - let b: Array1 = random(n); - let a: Array2 = random((n, n).f()).into_triangular(UPLO::Lower); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array1 = random_using(n, &mut rng); + let a: Array2 = random_using((n, n).f(), &mut rng).into_triangular(UPLO::Lower); test1d(UPLO::Lower, &a, &b, 1e-7); } #[test] fn triangular_2d_upper() { - let b: Array2 = random((3, 4)); - let a: Array2 = random((3, 3)).into_triangular(UPLO::Upper); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array2 = random_using((3, 4), &mut rng); + let a: Array2 = random_using((3, 3), &mut rng).into_triangular(UPLO::Upper); test2d(UPLO::Upper, &a, &b, 1e-7); } #[test] fn triangular_2d_lower() { - let b: Array2 = 
random((3, 4)); - let a: Array2 = random((3, 3)).into_triangular(UPLO::Lower); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array2 = random_using((3, 4), &mut rng); + let a: Array2 = random_using((3, 3), &mut rng).into_triangular(UPLO::Lower); test2d(UPLO::Lower, &a, &b, 1e-7); } #[test] fn triangular_2d_lower_t() { - let b: Array2 = random((3, 4)); - let a: Array2 = random((3, 3).f()).into_triangular(UPLO::Lower); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array2 = random_using((3, 4), &mut rng); + let a: Array2 = random_using((3, 3).f(), &mut rng).into_triangular(UPLO::Lower); test2d(UPLO::Lower, &a, &b, 1e-7); } #[test] fn triangular_2d_upper_t() { - let b: Array2 = random((3, 4)); - let a: Array2 = random((3, 3).f()).into_triangular(UPLO::Upper); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array2 = random_using((3, 4), &mut rng); + let a: Array2 = random_using((3, 3).f(), &mut rng).into_triangular(UPLO::Upper); test2d(UPLO::Upper, &a, &b, 1e-7); } #[test] fn triangular_2d_upper_bt() { - let b: Array2 = random((3, 4).f()); - let a: Array2 = random((3, 3)).into_triangular(UPLO::Upper); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array2 = random_using((3, 4).f(), &mut rng); + let a: Array2 = random_using((3, 3), &mut rng).into_triangular(UPLO::Upper); test2d(UPLO::Upper, &a, &b, 1e-7); } #[test] fn triangular_2d_lower_bt() { - let b: Array2 = random((3, 4).f()); - let a: Array2 = random((3, 3)).into_triangular(UPLO::Lower); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array2 = random_using((3, 4).f(), &mut rng); + let a: Array2 = random_using((3, 3), &mut rng).into_triangular(UPLO::Lower); test2d(UPLO::Lower, &a, &b, 1e-7); } #[test] fn triangular_2d_lower_t_bt() { - let b: Array2 = random((3, 4).f()); - let a: Array2 = random((3, 3).f()).into_triangular(UPLO::Lower); + let mut rng = 
rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array2 = random_using((3, 4).f(), &mut rng); + let a: Array2 = random_using((3, 3).f(), &mut rng).into_triangular(UPLO::Lower); test2d(UPLO::Lower, &a, &b, 1e-7); } #[test] fn triangular_2d_upper_t_bt() { - let b: Array2 = random((3, 4).f()); - let a: Array2 = random((3, 3).f()).into_triangular(UPLO::Upper); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let b: Array2 = random_using((3, 4).f(), &mut rng); + let a: Array2 = random_using((3, 3).f(), &mut rng).into_triangular(UPLO::Upper); test2d(UPLO::Upper, &a, &b, 1e-7); } diff --git a/ndarray-linalg/tests/tridiagonal.rs b/ndarray-linalg/tests/tridiagonal.rs index 38278951..513d625b 100644 --- a/ndarray-linalg/tests/tridiagonal.rs +++ b/ndarray-linalg/tests/tridiagonal.rs @@ -28,7 +28,8 @@ fn tridiagonal_index() { #[test] fn opnorm_tridiagonal() { - let mut a: Array2 = random((4, 4)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let mut a: Array2 = random_using((4, 4), &mut rng); a[[0, 2]] = 0.0; a[[0, 3]] = 0.0; a[[1, 3]] = 0.0; @@ -129,10 +130,11 @@ fn solve_tridiagonal_c64() { #[test] fn solve_tridiagonal_random() { - let mut a: Array2 = random((3, 3)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let mut a: Array2 = random_using((3, 3), &mut rng); a[[0, 2]] = 0.0; a[[2, 0]] = 0.0; - let x: Array1 = random(3); + let x: Array1 = random_using(3, &mut rng); let b1 = a.dot(&x); let b2 = b1.clone(); let y1 = a.solve_tridiagonal_into(b1).unwrap(); @@ -143,10 +145,11 @@ fn solve_tridiagonal_random() { #[test] fn solve_tridiagonal_random_t() { - let mut a: Array2 = random((3, 3)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let mut a: Array2 = random_using((3, 3), &mut rng); a[[0, 2]] = 0.0; a[[2, 0]] = 0.0; - let x: Array1 = random(3); + let x: Array1 = random_using(3, &mut rng); let at = a.t(); let b1 = at.dot(&x); let b2 = b1.clone(); @@ -158,11 +161,12 @@ fn 
solve_tridiagonal_random_t() { #[test] fn extract_tridiagonal_solve_random() { - let mut a: Array2 = random((3, 3)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let mut a: Array2 = random_using((3, 3), &mut rng); a[[0, 2]] = 0.0; a[[2, 0]] = 0.0; let tridiag = a.extract_tridiagonal().unwrap(); - let x: Array1 = random(3); + let x: Array1 = random_using(3, &mut rng); let b1 = a.dot(&x); let b2 = b1.clone(); let y1 = tridiag.solve_tridiagonal_into(b1).unwrap(); @@ -180,7 +184,8 @@ fn det_tridiagonal_f64() { #[test] fn det_tridiagonal_random() { - let mut a: Array2 = random((3, 3)); + let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); + let mut a: Array2 = random_using((3, 3), &mut rng); a[[0, 2]] = 0.0; a[[2, 0]] = 0.0; assert_aclose!(a.det_tridiagonal().unwrap(), a.det().unwrap(), 1e-7);