Skip to content

Commit 542f4fb

Browse files
committed
Replace CLAPACK library with pure swift
1 parent c18e498 commit 542f4fb

File tree

7 files changed

+147
-137
lines changed

7 files changed

+147
-137
lines changed

Package.resolved

Lines changed: 0 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Package.swift

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,16 @@ let package = Package(
1212
],
1313
dependencies: [
1414
.package(url: "https://github.com/apple/swift-collections", from: "1.0.0"),
15-
.package(url: "https://github.com/goodnotes/CLAPACK", branch: "eigen-support"),
1615
],
1716
targets: [
1817
.target(
1918
name: "pocketFFT"
2019
),
21-
.target(
22-
name: "CLAPACKHelper",
23-
dependencies: [
24-
.product(name: "CLAPACK", package: "CLAPACK"),
25-
],
26-
publicHeadersPath: "include"
27-
),
2820
.target(
2921
name: "Matft",
3022
dependencies: [
3123
.product(name: "Collections", package: "swift-collections"),
3224
"pocketFFT",
33-
.target(name: "CLAPACKHelper", condition: .when(platforms: [.wasi])),
3425
]),
3526
.testTarget(
3627
name: "MatftTests",

Sources/CLAPACKHelper/clapack_helper.c

Lines changed: 0 additions & 48 deletions
This file was deleted.

Sources/CLAPACKHelper/include/clapack_helper.h

Lines changed: 0 additions & 32 deletions
This file was deleted.

Sources/Matft/library/lapack.swift

Lines changed: 144 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import Foundation
1010

1111
// MARK: - Pure Swift Eigenvalue Implementation
12-
// This implementation is used on WASI where CLAPACK doesn't include dgeev.
1312
// It's also available on other platforms for testing purposes.
1413

1514
/// Computes the Euclidean norm of a vector segment using SIMD for performance
@@ -1231,10 +1230,8 @@ internal func svd_by_lapack<T: MfStorable>(_ mfarray: MfArray, _ full_matrices:
12311230
return (v.swapaxes(axis1: -1, axis2: -2), s, rt.swapaxes(axis1: -1, axis2: -2))
12321231
}
12331232
#else
1234-
// MARK: - WASI Implementation using CLAPACK
1235-
// Note: The CLAPACK eigen-support branch only provides dgetrf and dgetri (double-precision).
1236-
// Other LAPACK operations will throw fatalError on WASI.
1237-
import CLAPACKHelper
1233+
// MARK: - WASI Implementation (Pure Swift)
1234+
// All LAPACK operations are implemented in pure Swift for WASI compatibility.
12381235

12391236
public typealias __CLPK_integer = Int32
12401237

@@ -1252,12 +1249,7 @@ internal typealias lapack_LU<T> = (UnsafeMutablePointer<__CLPK_integer>, UnsafeM
12521249

12531250
internal typealias lapack_inv<T> = (UnsafeMutablePointer<__CLPK_integer>, UnsafeMutablePointer<T>, UnsafeMutablePointer<__CLPK_integer>, UnsafeMutablePointer<__CLPK_integer>, UnsafeMutablePointer<T>, UnsafeMutablePointer<__CLPK_integer>, UnsafeMutablePointer<__CLPK_integer>) -> Int32
12541251

1255-
// MARK: - CLAPACK Wrapper Functions for WASI
1256-
// dgeev_ is implemented in pure Swift (swiftEigenDecomposition) above
1257-
// Note: CLAPACK uses "integer" = long int, which on wasm32 is 32-bit (CLong)
1258-
// Note: The CLAPACK header declares functions as returning int, but the f2c-generated
1259-
// implementation returns void. This causes linker warnings but doesn't affect correctness
1260-
// since we don't rely on the return value from the C function (we use info parameter instead).
1252+
// MARK: - Pure Swift LAPACK-compatible Functions for WASI
12611253

12621254
@inline(__always)
12631255
internal func sgesv_(_ n: UnsafeMutablePointer<__CLPK_integer>, _ nrhs: UnsafeMutablePointer<__CLPK_integer>, _ a: UnsafeMutablePointer<Float>, _ lda: UnsafeMutablePointer<__CLPK_integer>, _ ipiv: UnsafeMutablePointer<__CLPK_integer>, _ b: UnsafeMutablePointer<Float>, _ ldb: UnsafeMutablePointer<__CLPK_integer>, _ info: UnsafeMutablePointer<__CLPK_integer>) -> Int32 {
@@ -1274,47 +1266,165 @@ internal func sgetrf_(_ m: UnsafeMutablePointer<__CLPK_integer>, _ n: UnsafeMuta
12741266
fatalError("LAPACK sgetrf_ is not available on WASI (single-precision LU decomposition not supported)")
12751267
}
12761268

1269+
/// Pure Swift implementation of dgetrf_ (LU decomposition with partial pivoting)
1270+
/// Computes an LU factorization of a general M-by-N matrix A using partial pivoting with row interchanges.
1271+
/// The factorization has the form: A = P * L * U
1272+
/// where P is a permutation matrix, L is lower triangular with unit diagonal, and U is upper triangular.
12771273
@inline(__always)
12781274
internal func dgetrf_(_ m: UnsafeMutablePointer<__CLPK_integer>, _ n: UnsafeMutablePointer<__CLPK_integer>, _ a: UnsafeMutablePointer<Double>, _ lda: UnsafeMutablePointer<__CLPK_integer>, _ ipiv: UnsafeMutablePointer<__CLPK_integer>, _ info: UnsafeMutablePointer<__CLPK_integer>) -> Int32 {
1279-
// Call CLAPACK via helper wrapper that uses correct void return type
1280-
var mLong = clapack_int(m.pointee)
1281-
var nLong = clapack_int(n.pointee)
1282-
var ldaLong = clapack_int(lda.pointee)
1283-
var infoLong: clapack_int = 0
1284-
let minMN = min(Int(m.pointee), Int(n.pointee))
1275+
let mVal = Int(m.pointee)
1276+
let nVal = Int(n.pointee)
1277+
let ldaVal = Int(lda.pointee)
1278+
let minMN = min(mVal, nVal)
1279+
1280+
info.pointee = 0
1281+
1282+
// Quick return if possible
1283+
if mVal == 0 || nVal == 0 {
1284+
return 0
1285+
}
1286+
1287+
// LU decomposition with partial pivoting (Doolittle algorithm)
1288+
for k in 0..<minMN {
1289+
// Find pivot - largest absolute value in column k from row k to m-1
1290+
var maxVal = abs(a[k * ldaVal + k])
1291+
var maxIdx = k
1292+
for i in (k + 1)..<mVal {
1293+
let val = abs(a[k * ldaVal + i])
1294+
if val > maxVal {
1295+
maxVal = val
1296+
maxIdx = i
1297+
}
1298+
}
1299+
1300+
// Store pivot index (1-based for LAPACK compatibility)
1301+
ipiv[k] = __CLPK_integer(maxIdx + 1)
1302+
1303+
// Check for singularity
1304+
if a[k * ldaVal + maxIdx] == 0.0 {
1305+
if info.pointee == 0 {
1306+
info.pointee = __CLPK_integer(k + 1)
1307+
}
1308+
continue
1309+
}
12851310

1286-
// On wasm32, clapack_int (long) and __CLPK_integer (Int32) are both 32-bit,
1287-
// so we can use withMemoryRebound to avoid array allocation and copying
1288-
ipiv.withMemoryRebound(to: clapack_int.self, capacity: minMN) { ipivRebound in
1289-
clapack_dgetrf_wrapper(&mLong, &nLong, a, &ldaLong, ipivRebound, &infoLong)
1311+
// Swap rows k and maxIdx
1312+
if maxIdx != k {
1313+
for j in 0..<nVal {
1314+
let temp = a[j * ldaVal + k]
1315+
a[j * ldaVal + k] = a[j * ldaVal + maxIdx]
1316+
a[j * ldaVal + maxIdx] = temp
1317+
}
1318+
}
1319+
1320+
// Compute multipliers (elements of L below diagonal)
1321+
let pivot = a[k * ldaVal + k]
1322+
for i in (k + 1)..<mVal {
1323+
a[k * ldaVal + i] /= pivot
1324+
}
1325+
1326+
// Update trailing submatrix
1327+
for j in (k + 1)..<nVal {
1328+
let colVal = a[j * ldaVal + k]
1329+
for i in (k + 1)..<mVal {
1330+
a[j * ldaVal + i] -= a[k * ldaVal + i] * colVal
1331+
}
1332+
}
12901333
}
12911334

1292-
info.pointee = __CLPK_integer(infoLong)
1293-
return Int32(infoLong)
1335+
return info.pointee
12941336
}
12951337

12961338
@inline(__always)
12971339
internal func sgetri_(_ n: UnsafeMutablePointer<__CLPK_integer>, _ a: UnsafeMutablePointer<Float>, _ lda: UnsafeMutablePointer<__CLPK_integer>, _ ipiv: UnsafeMutablePointer<__CLPK_integer>, _ work: UnsafeMutablePointer<Float>, _ lwork: UnsafeMutablePointer<__CLPK_integer>, _ info: UnsafeMutablePointer<__CLPK_integer>) -> Int32 {
12981340
fatalError("LAPACK sgetri_ is not available on WASI (single-precision matrix inverse not supported)")
12991341
}
13001342

1343+
/// Pure Swift implementation of dgetri_ (matrix inversion using LU factorization)
1344+
/// Computes the inverse of a matrix using the LU factorization computed by dgetrf_.
13011345
@inline(__always)
13021346
internal func dgetri_(_ n: UnsafeMutablePointer<__CLPK_integer>, _ a: UnsafeMutablePointer<Double>, _ lda: UnsafeMutablePointer<__CLPK_integer>, _ ipiv: UnsafeMutablePointer<__CLPK_integer>, _ work: UnsafeMutablePointer<Double>, _ lwork: UnsafeMutablePointer<__CLPK_integer>, _ info: UnsafeMutablePointer<__CLPK_integer>) -> Int32 {
1303-
// Call CLAPACK via helper wrapper that uses correct void return type
1304-
var nLong = clapack_int(n.pointee)
1305-
var ldaLong = clapack_int(lda.pointee)
1306-
var lworkLong = clapack_int(lwork.pointee)
1307-
var infoLong: clapack_int = 0
13081347
let nVal = Int(n.pointee)
1348+
let ldaVal = Int(lda.pointee)
1349+
let lworkVal = Int(lwork.pointee)
1350+
1351+
// Workspace query
1352+
if lworkVal == -1 {
1353+
work.pointee = Double(nVal)
1354+
info.pointee = 0
1355+
return 0
1356+
}
1357+
1358+
info.pointee = 0
13091359

1310-
// On wasm32, clapack_int (long) and __CLPK_integer (Int32) are both 32-bit,
1311-
// so we can use withMemoryRebound to avoid array allocation and copying
1312-
ipiv.withMemoryRebound(to: clapack_int.self, capacity: nVal) { ipivRebound in
1313-
clapack_dgetri_wrapper(&nLong, a, &ldaLong, ipivRebound, work, &lworkLong, &infoLong)
1360+
// Quick return if possible
1361+
if nVal == 0 {
1362+
return 0
1363+
}
1364+
1365+
// Check for singularity (zero diagonal in U)
1366+
for i in 0..<nVal {
1367+
if a[i * ldaVal + i] == 0.0 {
1368+
info.pointee = __CLPK_integer(i + 1)
1369+
return info.pointee
1370+
}
1371+
}
1372+
1373+
// Step 1: Invert U (upper triangular part) in place
1374+
for j in 0..<nVal {
1375+
a[j * ldaVal + j] = 1.0 / a[j * ldaVal + j]
1376+
let ajj = -a[j * ldaVal + j]
1377+
1378+
// Compute elements 0:j-1 of j-th column
1379+
for k in 0..<j {
1380+
work[k] = a[j * ldaVal + k]
1381+
a[j * ldaVal + k] = 0.0
1382+
}
1383+
1384+
for k in 0..<j {
1385+
let workK = work[k]
1386+
for i in 0...k {
1387+
a[j * ldaVal + i] += workK * a[k * ldaVal + i]
1388+
}
1389+
}
1390+
1391+
for i in 0..<j {
1392+
a[j * ldaVal + i] *= ajj
1393+
}
1394+
}
1395+
1396+
// Step 2: Solve the equation inv(A)*L = inv(U) for inv(A)
1397+
// Process columns from right to left
1398+
for j in stride(from: nVal - 2, through: 0, by: -1) {
1399+
// Copy lower triangular part of column j to work array
1400+
for i in (j + 1)..<nVal {
1401+
work[i] = a[j * ldaVal + i]
1402+
a[j * ldaVal + i] = 0.0
1403+
}
1404+
1405+
// Update column j of inverse
1406+
for k in (j + 1)..<nVal {
1407+
let workK = work[k]
1408+
for i in 0..<nVal {
1409+
a[j * ldaVal + i] -= workK * a[k * ldaVal + i]
1410+
}
1411+
}
1412+
}
1413+
1414+
// Step 3: Apply column interchanges (reverse order of pivots)
1415+
for j in stride(from: nVal - 2, through: 0, by: -1) {
1416+
let jp = Int(ipiv[j]) - 1 // Convert from 1-based to 0-based
1417+
if jp != j {
1418+
// Swap columns j and jp
1419+
for i in 0..<nVal {
1420+
let temp = a[j * ldaVal + i]
1421+
a[j * ldaVal + i] = a[jp * ldaVal + i]
1422+
a[jp * ldaVal + i] = temp
1423+
}
1424+
}
13141425
}
13151426

1316-
info.pointee = __CLPK_integer(infoLong)
1317-
return Int32(infoLong)
1427+
return 0
13181428
}
13191429

13201430
@inline(__always)

Tests/MatftTests/LinAlgTest.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ final class LinAlgTests: XCTestCase {
4747
[ 1.5, -0.5]], mftype: .Float))
4848
}
4949
#endif
50-
// Double test - uses dgetrf_/dgetri_ (available on WASM via CLAPACK)
50+
// Double test - uses dgetrf_/dgetri_
5151
do{
5252
let a = MfArray([[[1.0, 2.0],
5353
[3.0, 4.0]],
@@ -70,7 +70,7 @@ final class LinAlgTests: XCTestCase {
7070
}
7171
#endif
7272

73-
// Double test - uses dgetrf_ (available on WASM via CLAPACK)
73+
// Double test - uses dgetrf_
7474
do{
7575
let a = MfArray([[[1.0, 2.0],
7676
[3.0, 4.0]],

Tests/MatftTests/WASIFallbackTests.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -465,9 +465,7 @@ final class WASIFallbackTests: XCTestCase {
465465
XCTAssertEqual(scalar, 42)
466466
}
467467

468-
// MARK: - Linear Algebra Tests (CLAPACK-backed for WASI)
469-
// These tests validate the CLAPACK wrapper implementations used on WASI
470-
// Only Double-precision is supported via CLAPACK
468+
// MARK: - Linear Algebra Tests
471469

472470
func testMatrixInverseDouble() {
473471
// Test 2x2 matrix inverse (Double precision - supported on WASI)

0 commit comments

Comments
 (0)