三对角线线性方程组解法 -- Tridiagonal

Adam posted @ Thu, 08 Mar 2012 01:19:10 +0800 in Numerical Recipes with tags c++ NumericalAlgebra , 7725 readers

Tridiagonal矩阵(三对角矩阵),即只有主对角线、紧邻主对角线上方的一条对角线以及紧邻主对角线下方的一条对角线含有非零元素的矩阵。

以n=5为例,下面就是一个tridiagonal矩阵:

[tex]\begin{pmatrix}1 & 3 & 0 & 0 & 0 \\ 2 & 4 & 6 & 0 & 0 \\ 0 & 7 & 9 & 2 & 0 \\ 0 & 0 & 8 & 1 & 3 \\ 0 & 0 & 0 & 5 & 7\end{pmatrix}[/tex]

非tex公式:

(1 3 0 0 0)
(2 4 6 0 0)
(0 7 9 2 0)
(0 0 8 1 3)
(0 0 0 5 7)

解法原理与LU分解法相同,只是利用了矩阵中绝大部分元素为零的结构,省去了对零元素的运算,因此计算量从一般LU分解的O(n^3)降为O(n)。

(b0 c0 0  0  0  0  0 )  (u0)  (r0)
(a1 b1 c1 0  0  0  0 )  (u1)  (r1)
(0  a2 b2 c2 0  0  0 )  (u2)  (r2)
(0  0  a3 b3 c3 0  0 )*(u3)=(r3)
(0  0  0  a4 b4 c4 0 )  (u4)  (r4)
(0  0  0  0  a5 b5 c5)  (u5)  (r5)
(0  0  0  0  0  a6 b6)  (u6)  (r6)

L矩阵为下二对角矩阵:只有主对角线(元素m)和主对角线下方一条对角线(元素n)非零;U矩阵为单位上二对角矩阵:主对角线上皆为1,主对角线上方一条对角线为元素p(由c计算得出)。这种令U的对角线全为1的分解即Crout分解,计算方法与一般的LU分解类似。

    (m0 0  0  0  0  0  0 )     (1  p0 0  0  0  0  0 )
    (n1 m1 0  0  0  0  0 )     (0  1  p1 0  0  0  0 )
    (0  n2 m2 0  0  0  0 )     (0  0  1  p2 0  0  0 )
L =(0  0  n3 m3 0  0  0 ) U=(0  0  0  1  p3 0  0 )
    (0  0  0  n4 m4 0  0 )     (0  0  0  0  1  p4 0 )
    (0  0  0  0  n5 m5 0 )     (0  0  0  0  0  1  p5)
    (0  0  0  0  0  n6 m6)     (0  0  0  0  0  0  1 )

其中L和U的计算有:

(下标从0开始,n为矩阵阶数)
n[i] = a[i],其中i = 1, 2, ..., n-1;
m[0] = b[0]; p[i] = c[i] / m[i],其中i = 0, 1, ..., n-2;
m[i+1] = b[i+1] - a[i+1] * p[i],其中i = 0, 1, ..., n-2.

与LU分解法相同,由于A = L * U,方程A * x = r可以分两步求解:先由L * y = r前代求出y,再由U * x = y回代求出x。其中r为右端项(right-hand side)。

计算公式:

y[0] = r[0] / m[0] = r[0] / b[0]; y[i] = (r[i] - a[i] * y[i-1]) / m[i],其中i = 1, 2, ..., n-1.
x[n-1] = y[n-1]; x[j] = y[j] - p[j] * x[j+1],其中j = n-2, n-3, ..., 0.
(在下方代码中,bet对应m[i],gam[i]存储p[i-1];solution先存放y,回代时再被原地覆盖为x。)

/**
*@file triDiagonal.h
*/
#ifndef TRIDIAGONAL_H_INCLUDED
#define TRIDIAGONAL_H_INCLUDED

/**
*@brief Solve the linear system P * u = r, where P is a tridiagonal matrix.
*
*Uses the Thomas algorithm (LU decomposition specialised to tridiagonal form):
*one forward sweep performs the decomposition together with forward substitution,
*one backward sweep performs back substitution. O(n) time, one O(n) workspace vector.
*
*@param below_diagonal[0..n-1] sub-diagonal a[] of P; only a[1..n-1] is read.
*       Its size defines the order n of the system. (won't be modified)
*@param main_diagonal[0..n-1] main diagonal b[] of P; must hold at least n entries. (won't be modified)
*@param above_diagonal[0..n-2] super-diagonal c[] of P; must hold at least n-1 entries,
*       c[n-1] (if present) is not read. (won't be modified)
*@param right_hand_side[0..n-1] the vector r in P * u = r; must hold at least n entries. (won't be modified)
*@param solution[0..n-1] receives u; must already hold at least n entries.
*triDiagonal Equation P * u = r, taken n = 7 as an example:
*	(b0	c0	0	0	0	0	0)		(u0)		(r0)
*	(a1	b1	c1	0	0	0	0)		(u1)		(r1)
*	(0	a2	b2	c2	0	0	0)		(u2)		(r2)
*	(0	0	a3	b3	c3	0	0)	*	(u3)	=	(r3)
*	(0	0	0	a4	b4	c4	0)		(u4)		(r4)
*	(0	0	0	0	a5	b5	c5)		(u5)		(r5)
*	(0	0	0	0	0	a6	b6)		(u6)		(r6)
*	a[] is the below_diagonal, b[] is the main_diagonal,
*	c[] is the above_diagonal, r[] is the right_hand_side and u[] is the solution
*@attention below_diagonal[0] and above_diagonal[n-1] are undefined and are not referenced by the function.
*@attention No pivoting is performed: the function reports an error through NRthrow whenever a
*           zero pivot is met, even if P itself is nonsingular. A strictly diagonally dominant P
*           never produces a zero pivot.
*@note An empty system (n == 0) is a no-op. Vectors shorter than required are reported through NRthrow.
*@note Declared inline because it is defined in a header; this avoids multiple-definition
*      link errors when the header is included from more than one translation unit.
*/
inline void triDiagonal(
		VecDoub_I &below_diagonal,
		VecDoub_I &main_diagonal,
		VecDoub_I &above_diagonal,
		VecDoub_I &right_hand_side,
		VecDoub_O &solution){
	const Int n = below_diagonal.size();
	if(n <= 0)
		return; //Empty system: nothing to solve, and main_diagonal[0] must not be read.
	//Every index used below must be in range for all vectors.
	if(static_cast<Int>(main_diagonal.size()) < n
			|| static_cast<Int>(above_diagonal.size()) < n - 1
			|| static_cast<Int>(right_hand_side.size()) < n
			|| static_cast<Int>(solution.size()) < n)
		NRthrow("Error 0 in triDiagonal: Input vectors are shorter than below_diagonal.");
	if(main_diagonal[0] == 0.0)
		NRthrow("Error 1 in triDiagonal: Please Check Your Matrix.\n"
				"Suggesion: Rewrite your equations as a set of order N-1.");
	//Workspace: gam[j] = c[j-1] / m[j-1], the super-diagonal entry of U in row j-1.
	//gam[0] is never used.
	VecDoub gam(n);
	Doub bet = main_diagonal[0]; //Current pivot m[j] of L.
	solution[0] = right_hand_side[0] / bet;
	for(Int j = 1; j < n; ++j){
		//Decomposition and forward substitution (solution temporarily holds y).
		gam[j] = above_diagonal[j-1] / bet;
		bet = main_diagonal[j] - below_diagonal[j] * gam[j];
		if(bet == 0.0)
			NRthrow("Error 2 in triDiagonal: Algorithm fails.\n"
					"Suggesion: Use Other Function to solve the problem.");
		solution[j] = (right_hand_side[j] - below_diagonal[j] * solution[j-1]) / bet;
	} //for
	for(Int j = n - 2; j >= 0; --j){
		//Backsubstitution: overwrite y with u in place.
		solution[j] -= gam[j+1] * solution[j+1];
	} //for
} //triDiagonal

#endif // TRIDIAGONAL_H_INCLUDED

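函数参数的含义见函数注释。注意:triDiagonal.h本身没有包含nr3.h,其中使用的VecDoub_I、Int、Doub等类型需在包含它之前引入(如下方测试程序先包含nr3.h);NRthrow、readVector_console和printVector_console未在本文中定义,需由所用的nr3.h或其他辅助文件提供。测试程序大致如下: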

/**
*@file triDiagonal_test.cpp
*/
#include <iostream>
#include "nr3.h"
#include "triDiagonal.h"
using namespace std;

/**
*@brief Console driver: read a tridiagonal system A * x = b from stdin, solve it with
*triDiagonal and print x.
*@return 0 on success, 1 if the size is invalid or the solver reports an error.
*/
int main(){
	int size = 0;
	VecDoub	mainDiagonal, belowDiagonal, aboveDiagonal, rightHandSide, solution;
	//INPUT DATA:
	cout << "SOLVE A * x = b with A as tridiagonal matrix specially." << endl;
	cout << "Please input diagonal size for A, which should be the same as the size of b and x: " << endl;
	//Reject unreadable or non-positive sizes before they reach resize() and the solver.
	if(!(cin >> size) || size <= 0){
		cerr << "Invalid size: a positive integer is required." << endl;
		return 1;
	}
	belowDiagonal.resize(size);
	mainDiagonal.resize(size);
	aboveDiagonal.resize(size);
	rightHandSide.resize(size);
	solution.resize(size);
	cout << "Please input the below diagonal data for A: " << endl
		<< "(Attention: below_diagonal[0] is undefined and will not be referenced by the function)"<< endl;
	readVector_console(belowDiagonal);
	cout << "Please input the main diagonal data for A: " << endl;
	readVector_console(mainDiagonal);
	cout << "Please input the above diagonal data for A: " << endl
		<< "(Attention: above_diagonal[n-1] is undefined and will not be referenced by the function)"<< endl;
	readVector_console(aboveDiagonal);
	cout << "Please input the vector data for b: " << endl;
	readVector_console(rightHandSide);
	//Do the Linear Solve. triDiagonal reports zero pivots via NRthrow; the exception type
	//depends on how NRthrow is configured, so catch everything and exit with an error code.
	try{
		triDiagonal(belowDiagonal, mainDiagonal, aboveDiagonal, rightHandSide, solution);
	}catch(...){
		cerr << "triDiagonal failed: the system could not be solved." << endl;
		return 1;
	}
	//PRINT:
	printVector_console(solution);
	return 0;
}

 


Login *


loading captcha image...
(type the code from the image)
or Ctrl+Enter