diff --git a/src/LinSolveAlgo.jl b/src/LinSolveAlgo.jl index b3992965350257b6efc75dd85a57f6c4bb82cb00..b1eef9717423d03441f15345b97eefc0fcfddd67 100644 --- a/src/LinSolveAlgo.jl +++ b/src/LinSolveAlgo.jl @@ -36,7 +36,7 @@ for (getrf, getrs, elty) in mm,nn = size(a) m = Ref{BlasInt}(mm) n = Ref{BlasInt}(nn) - lda = Ref{BlasInt}(max(1,stride(a,2))) + lda = Ref{BlasInt}(max(1,stride(ws.lu,2))) info = Ref{BlasInt}(0) ccall((@blasfunc($getrf), liblapack), Cvoid, (Ref{BlasInt},Ref{BlasInt},Ptr{$elty},Ref{BlasInt}, @@ -72,7 +72,7 @@ for (getrf, getrs, elty) in m = Ref{BlasInt}(mm) n = Ref{BlasInt}(nn) nhrs = Ref{BlasInt}(size(b,2)) - lda = Ref{BlasInt}(max(1,stride(a,2))) + lda = Ref{BlasInt}(max(1,stride(ws.lu,2))) ldb = Ref{BlasInt}(max(1,stride(b,2))) info = Ref{BlasInt}(0) @@ -90,22 +90,8 @@ for (getrf, getrs, elty) in b::StridedVecOrMat{$elty}, ws::LinSolveWs) - mm,nn = size(a) - m = Ref{BlasInt}(mm) - n = Ref{BlasInt}(nn) - nhrs = Ref{BlasInt}(size(b,2)) - lda = Ref{BlasInt}(max(1,stride(a,2))) - ldb = Ref{BlasInt}(max(1,stride(b,2))) - info = Ref{BlasInt}(0) - lu!(a, ws) - ccall((@blasfunc($getrs), liblapack), Cvoid, - (Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{$elty},Ref{BlasInt}, - Ptr{BlasInt},Ptr{$elty},Ref{BlasInt},Ref{BlasInt}), - $transchar, n, nhrs, ws.lu, lda, ws.ipiv, b, ldb, info) - if info[] != 0 - chklapackerror(info[]) - end + linsolve_core_no_lu!(a, b, ws) end end end