Tridiagonal(三对角)矩阵,即只有主对角线、紧邻主对角线上方的一条对角线和紧邻主对角线下方的一条对角线含有非零元素的矩阵。
以n=5为例,以下就是一个tridiagonal矩阵:
[tex]\begin{pmatrix}1 & 3 & 0 & 0 & 0 \\ 2 & 4 & 6 & 0 & 0 \\ 0 & 7 & 9 & 2 & 0 \\ 0 & 0 & 8 & 1 & 3 \\ 0 & 0 & 0 & 5 & 7\end{pmatrix}[/tex]
非tex公式:
(1 3 0 0 0)
(2 4 6 0 0)
(0 7 9 2 0)
(0 0 8 1 3)
(0 0 0 5 7)
解法原理与LU分解法相同,只是利用三对角结构简化了计算步骤,从而提高了效率。
将系数矩阵分解为 A = L * U:L矩阵只在主对角线(m)和主对角线下方一条对角线(n)上有非零元素,其中下对角线直接取自a;U矩阵的主对角线全为1,只在主对角线上方一条对角线(p)上有非零元素,p由c和m计算得到。L和U的计算方法与一般LU分解法类似。
以n=7为例:
    (m0 0  0  0  0  0  0 )       (1  p0 0  0  0  0  0 )
    (n1 m1 0  0  0  0  0 )       (0  1  p1 0  0  0  0 )
    (0  n2 m2 0  0  0  0 )       (0  0  1  p2 0  0  0 )
L = (0  0  n3 m3 0  0  0 )   U = (0  0  0  1  p3 0  0 )
    (0  0  0  n4 m4 0  0 )       (0  0  0  0  1  p4 0 )
    (0  0  0  0  n5 m5 0 )       (0  0  0  0  0  1  p5)
    (0  0  0  0  0  n6 m6)       (0  0  0  0  0  0  1 )
其中L和U的计算有(下标从0开始,n为矩阵阶数;a、b、c分别为主对角线下方、主对角线、主对角线上方的元素,与下方代码注释一致):
n[i] = a[i],其中i = 1, 2, ..., n-1;
m[0] = b[0];
p[i] = c[i] / m[i],m[i+1] = b[i+1] - a[i+1] * p[i],其中i = 0, 1, ..., n-2.
现在同样有:A * x = L * (U * x) = r,令 U * x = y,则先解 L * y = r,再解 U * x = y。其中r为右端项(right-hand side)。
计算公式:
前向代入(解 L * y = r):y[0] = r[0] / m[0];y[i] = (r[i] - a[i] * y[i-1]) / m[i],其中i = 1, 2, ..., n-1.
回代(解 U * x = y):x[n-1] = y[n-1];x[j] = y[j] - p[j] * x[j+1],其中j = n-2, n-3, ..., 0.
/**
*@file triDiagonal.h
*/
#ifndef TRIDIAGONAL_H_INCLUDED
#define TRIDIAGONAL_H_INCLUDED
/**
*@brief Solve the linear system P * u = r, where P is a tridiagonal matrix.
*
* Uses an LU decomposition without pivoting, specialised to three diagonals:
* L keeps the main diagonal and the diagonal below it, U has a unit main diagonal
* plus the diagonal above it. The solve is one forward sweep (decomposition and
* forward substitution) followed by one backward sweep (back substitution).
*
*@param below_diagonal[0..n-1] the diagonal below the main diagonal in P -- entries 1..n-1 are used (won't be modified)
*@param main_diagonal[0..n-1] the main diagonal in P -- entries 0..n-1 are used; its size defines n (won't be modified)
*@param above_diagonal[0..n-1] the diagonal above the main diagonal in P -- entries 0..n-2 are used (won't be modified)
*@param right_hand_side[0..n-1] the vector r in linear equation P * u = r (won't be modified)
*@param solution[0..n-1] receives the solution vector u; it is resized to n if it holds fewer than n elements
*
*triDiagonal Equation P * u = r, taking n = 7 as an example:
* (b0 c0 0  0  0  0  0 )   (u0)   (r0)
* (a1 b1 c1 0  0  0  0 )   (u1)   (r1)
* (0  a2 b2 c2 0  0  0 )   (u2)   (r2)
* (0  0  a3 b3 c3 0  0 ) * (u3) = (r3)
* (0  0  0  a4 b4 c4 0 )   (u4)   (r4)
* (0  0  0  0  a5 b5 c5)   (u5)   (r5)
* (0  0  0  0  0  a6 b6)   (u6)   (r6)
* a[] is the below_diagonal, b[] is the main_diagonal,
* c[] is the above_diagonal, r[] is the right_hand_side and u[] is the solution
*
*@attention below_diagonal[0] and above_diagonal[n-1] are undefined and are not referenced by the function.
*@attention No pivoting is performed: if main_diagonal[0] or any later pivot is exactly 0.0 the
*           function reports an error through NRthrow, even though P itself might be nonsingular.
*@note Reports an error through NRthrow if below_diagonal or right_hand_side has fewer than n
*      elements, or above_diagonal has fewer than n-1. An empty main_diagonal (n == 0) is a no-op.
*@note Defined inline so this header can be included from several translation units
*      without multiple-definition link errors. Requires "nr3.h" to be included beforehand.
*/
inline void triDiagonal(
    VecDoub_I &below_diagonal,
    VecDoub_I &main_diagonal,
    VecDoub_I &above_diagonal,
    VecDoub_I &right_hand_side,
    VecDoub_O &solution){
    // The main diagonal defines the order of the system.
    const Int n = main_diagonal.size();
    if(n == 0)
        return; // nothing to solve; also avoids reading main_diagonal[0] of an empty vector
    if(below_diagonal.size() < n || right_hand_side.size() < n || above_diagonal.size() < n - 1)
        NRthrow("Error 3 in triDiagonal: Input vectors are shorter than main_diagonal.");
    // Only grow the output, never shrink it: solution may alias an input vector of size >= n.
    if(solution.size() < n)
        solution.resize(n);

    VecDoub gam(n); // workspace: gam[j] is U's superdiagonal entry in row j-1 (gam[0] unused)
    Doub bet = main_diagonal[0]; // current pivot, i.e. the diagonal entry of L in the current row
    if(bet == 0.0)
        NRthrow("Error 1 in triDiagonal: Please Check Your Matrix.\n"
                "Suggesion: Rewrite your equations as a set of order N-1.");
    solution[0] = right_hand_side[0] / bet;

    // Decomposition and forward substitution: solve L * y = r, storing y in solution.
    for(Int j = 1; j < n; ++j){
        gam[j] = above_diagonal[j-1] / bet;
        bet = main_diagonal[j] - below_diagonal[j] * gam[j];
        if(bet == 0.0)
            NRthrow("Error 2 in triDiagonal: Algorithm fails.\n"
                    "Suggesion: Use Other Function to solve the problem.");
        solution[j] = (right_hand_side[j] - below_diagonal[j] * solution[j-1]) / bet;
    }

    // Back substitution: solve U * u = y in place.
    for(Int j = n - 2; j >= 0; --j)
        solution[j] -= gam[j+1] * solution[j+1];
} //triDiagonal
#endif // TRIDIAGONAL_H_INCLUDED
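函数参数的含义请参见函数的注释部分。注意triDiagonal.h本身没有包含nr3.h,使用前需先#include "nr3.h"(如下方例程所示)。下面给出一个大致的测试例程(其中readVector_console和printVector_console为辅助函数,本文未给出其定义,需另行提供):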
/**
*@file triDiagonal_test.cpp
*/
#include <iostream>
#include "nr3.h"
#include "triDiagonal.h"
using namespace std;
/**
 * Interactive driver: reads a tridiagonal system A * x = b from the console,
 * solves it with triDiagonal and prints x.
 * Returns 1 (without solving) if the requested size is not a positive integer.
 */
int main(){
    int size = 0;
    VecDoub mainDiagonal, belowDiagonal, aboveDiagonal, rightHandSide, solution;
    //INPUT DATA:
    cout << "SOLVE A * x = b with A as tridiagonal matrix specially." << endl;
    cout << "Please input diagonal size for A, which should be the same as the size of b and x: " << endl;
    if(!(cin >> size) || size <= 0){
        // A failed read or a negative size must not reach resize(), and a
        // zero-sized system has nothing to solve.
        std::cerr << "Invalid size: please input a positive integer." << std::endl;
        return 1;
    }
    belowDiagonal.resize(size);
    mainDiagonal.resize(size);
    aboveDiagonal.resize(size);
    rightHandSide.resize(size);
    solution.resize(size);
    cout << "Please input the below diagonal data for A: " << endl
         << "(Attention: below_diagonal[0] is undefined and will not be referenced by the function)" << endl;
    readVector_console(belowDiagonal);
    cout << "Please input the main diagonal data for A: " << endl;
    readVector_console(mainDiagonal);
    cout << "Please input the above diagonal data for A: " << endl
         << "(Attention: above_diagonal[n-1] is undefined and will not be referenced by the function)" << endl;
    readVector_console(aboveDiagonal);
    cout << "Please input the vector data for b: " << endl;
    readVector_console(rightHandSide);
    //Do the Linear Solve
    triDiagonal(belowDiagonal, mainDiagonal, aboveDiagonal, rightHandSide, solution);
    //PRINT:
    printVector_console(solution);
    return 0;
}