三对角(tridiagonal)矩阵,即只有主对角线、主对角线上方相邻的一条对角线和主对角线下方相邻的一条对角线含有非零元素的矩阵。
例如,当 n=5 时,下面就是一个三对角矩阵:
[tex]\begin{pmatrix}1 & 3 & 0 & 0 & 0 \\ 2 & 4 & 6 & 0 & 0 \\ 0 & 7 & 9 & 2 & 0 \\ 0 & 0 & 8 & 1 & 3 \\ 0 & 0 & 0 & 5 & 7\end{pmatrix}[/tex]
非tex公式:
(1 3 0 0 0)
(2 4 6 0 0)
(0 7 9 2 0)
(0 0 8 1 3)
(0 0 0 5 7)
解法原理与 LU 分解法相同,只是利用三对角结构简化了步骤,从而提高了效率(计算量和额外存储都只有 O(n))。
记原矩阵的下对角线为 a、主对角线为 b、上对角线为 c。L 矩阵只在主对角线(元素 m)和主对角线下方一条对角线(元素 n)上有非零元素;U 矩阵的主对角线全为 1,主对角线上方一条对角线为 p,其余元素均为 0。L、U 的计算方法与一般的 LU 分解类似。
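[tex]L=\begin{pmatrix}m_0 & 0 & 0 & 0 & 0 & 0 & 0 \\ n_1 & m_1 & 0 & 0 & 0 & 0 & 0 \\ 0 & n_2 & m_2 & 0 & 0 & 0 & 0 \\ 0 & 0 & n_3 & m_3 & 0 & 0 & 0 \\ 0 & 0 & 0 & n_4 & m_4 & 0 & 0 \\ 0 & 0 & 0 & 0 & n_5 & m_5 & 0 \\ 0 & 0 & 0 & 0 & 0 & n_6 & m_6\end{pmatrix},\quad U=\begin{pmatrix}1 & p_0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 1 & p_1 & 0 & 0 & 0 & 0 \\ 0 & 0 & 1 & p_2 & 0 & 0 & 0 \\ 0 & 0 & 0 & 1 & p_3 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & p_4 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 & p_5 \\ 0 & 0 & 0 & 0 & 0 & 0 & 1\end{pmatrix}[/tex]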
其中 L 和 U 的元素按下式计算(下标从 0 开始,n 为矩阵阶数):
n[i] = a[i],其中 i = 1, 2, ..., n-1;
m[0] = b[0];
p[i] = c[i] / m[i],m[i+1] = b[i+1] - a[i+1] * p[i],其中 i = 0, 1, ..., n-2(按 i 递增交替计算)。
由 A = L * U,求解 A * x = r 可分两步:先由 L * y = r 通过前向代入求出 y,再由 U * x = y 通过回代求出 x。其中 r 为右端项(right-hand side)。
计算公式:
y[0] = r[0] / m[0] = r[0] / b[0]; y[i] = (r[i] - a[i] * y[i-1]) / m[i],其中 i = 1, 2, ..., n-1.
x[n-1] = y[n-1]; x[j] = y[j] - p[j] * x[j+1],其中 j = n-2, n-3, ..., 0.
/** *@file triDiagonal.h */ #ifndef TRIDIAGONAL_H_INCLUDED #define TRIDIAGONAL_H_INCLUDED /** *@brief Solve linear equations P * u = r, which P is a tridiagonal matrix. *@param below_diagonal[0..n-1] the diagonal below the main diagonal in P-- generally starts from 1 to n-1 (won't be modified) *@param main_diagonal[0..n-1] the main diagonal in P-- generally starts from 0 to n-1 (won't be modified) *@param above_diagonal[0..n-1] the diagonal above the main diagonal in P-- generally starts from 0 to n-2 (won't be modified) *@param right_hand_side[0..n-1] the vector r in linear equation P * u = r (won't be modified) *@param solution[0..n-1] the corresponding set of solution vectors *triDiagonal Equation P * u = r, taken n = 7 as an example: * (b0 c0 0 0 0 0 0) (u0) (r0) * (a1 b1 c1 0 0 0 0) (u1) (r1) * (0 a2 b2 c2 0 0 0) (u2) (r2) * (0 0 a3 b3 c3 0 0) * (u3) = (r3) * (0 0 0 a4 b4 c4 0) (u4) (r4) * (0 0 0 0 a5 b5 c5) (u5) (r5) * (0 0 0 0 0 a6 c6) (u6) (r6) * a[] is the below_diagonal, b[] is the main_diagonal, * c[] is the above_diagonal, r[] is the right_hand_side and u[] is the solution *@attention below_diagonal[0] and above_diagonal[n-1] are undefined and are not referenced by the function. */ void triDiagonal( VecDoub_I &below_diagonal, VecDoub_I &main_diagonal, VecDoub_I &above_diagonal, VecDoub_I &right_hand_side, VecDoub_O &solution){ Int j, n = below_diagonal.size(); Doub bet; VecDoub gam(n); //One vector of workspace, gam, is needed. if(main_diagonal[0] == 0.0) NRthrow("Error 1 in triDiagonal: Please Check Your Matrix.\n" "Suggesion: Rewrite your equations as a set of order N-1."); solution[0] = right_hand_side[0]/(bet = main_diagonal[0]); for(j = 1; j < n; ++j){ //Decomposition and forward substitution. 
gam[j] = above_diagonal[j-1] / bet; bet = main_diagonal[j] - below_diagonal[j] * gam[j]; if(bet == 0.0) NRthrow("Error 2 in triDiagonal: Algorithm fails.\n" "Suggesion: Use Other Function to solve the problem."); solution[j] = (right_hand_side[j] - below_diagonal[j] * solution[j-1]) / bet; } //for for(j = (n-2); j >= 0; --j){ //Backsubstitution solution[j] -= gam[j+1] * solution[j+1]; } //for } //triDiagonal(VecDoub_I &below_diagonal, VecDoub_I &main_diagonal, VecDoub_I &above_diagonal, VecDoub_I &right_hand_side, VecDoub_O &solution) #endif // TRIDIAGONAL_H_INCLUDED
函数参数的说明见函数的注释部分。测试程序大致如下:
/** *@file triDiagonal_test.cpp */ #include <iostream> #include "nr3.h" #include "triDiagonal.h" using namespace std; int main(){ int size; VecDoub mainDiagonal, belowDiagonal, aboveDiagonal, rightHandSide, solution; //INPUT DATA: cout << "SOLVE A * x = b with A as tridiagonal matrix specially." << endl; cout << "Please input diagonal size for A, which should be the same as the size of b and x: " << endl; cin >> size; belowDiagonal.resize(size); mainDiagonal.resize(size); aboveDiagonal.resize(size); rightHandSide.resize(size); solution.resize(size); cout << "Please input the below diagonal data for A: " << endl << "(Attention: below_diagonal[0] is undefined and will not be referenced by the function)"<< endl; readVector_console(belowDiagonal); cout << "Please input the main diagonal data for A: " << endl; readVector_console(mainDiagonal); cout << "Please input the above diagonal data for A: " << endl << "(Attention: above_diagonal[n-1] is undefined and will not be referenced by the function)"<< endl;; readVector_console(aboveDiagonal); cout << "Please input the vector data for b: " << endl; readVector_console(rightHandSide); //Do the Linear Solve triDiagonal(belowDiagonal, mainDiagonal, aboveDiagonal, rightHandSide, solution); //PRINT: printVector_console(solution); return 0; }