前言
矩阵乘法可以采用分治的策略。
这里提供了两个分治策略的解决
n
∗
n
n*n
n∗n矩阵之间乘法的算法
1.矩阵乘法的普通递归方法
2.矩阵乘法的Strassen(斯特拉森)方法
但是着两个方法的缺点是只能是两个 n ∗ n n*n n∗n矩阵的乘法,同时n必须为2的幂
之后也对这两个算法进行了时间复杂度上的分析
一、矩阵乘法的普通递归方法
1.C语言代码实现
#include<stdio.h>
#include<stdlib.h>
#include<time.h>
//宏定义改变矩阵的大小
//此处size指的是n*n矩阵的宽度或者长度n
//由于Strassen算法本身的局限性
//两个相乘的矩阵只能是n*n,且n只能是2的幂
//即size为1,2,4,8,16...
#define size 4
//矩阵的合并
//就是将A11,A12,A21,A22合并为
// |A11 A12|
// |A21 A22|
//这样的形式
void Merge_Matrix(int *a,int *b,int *c,int *d,int* c0,int rows)
{
//rows是子矩阵的宽度,那么合并后矩阵的宽度就是2*rows
int i=0;//子矩阵遍历索引
int j=0;//合并矩阵遍历索引,此处先合并A11
for(i=0;i<(rows*rows);i++)
{
//如果执行了rows次就需要换行,要锁定到合并后矩阵的第二行,所以加上2*rows即可
if((i%rows==0)&&i!=0)
{
j=(rows*2)*(i/rows);//(i/rows)代表需要换多少行;(rows*2)就是行数
c0[j]=a[i];
j++;
}
else
{
c0[j]=a[i];
j++;
}
}
//此处是A12的首元素,令索引等于子矩阵的那个函数
j=rows;
for(i=0;i<(rows*rows);i++)
{
//这里的条件和之前的一样的,其它四部分也是一样的,底层逻辑是一致的
if((i%rows==0)&&i!=0)
{
j=(rows*2)*(i/rows)+rows;//换行要换到第二部分,所以需要再加一个rows
c0[j]=b[i];
j++;
}
else
{
c0[j]=b[i];
j++;
}
}
//此处是A21的首元素,令索引等于子矩阵的那个函数
j=rows*2*rows;
for(i=0;i<(rows*rows);i++)
{
if((i%rows==0)&&i!=0)
{
j=rows*2*rows+(rows*2)*(i/rows);
c0[j]=c[i];
j++;
}
else
{
c0[j]=c[i];
j++;
}
}
//此处是A22的首元素,令索引等于子矩阵的那个函数
j=rows*2*rows+rows;
for(i=0;i<(rows*rows);i++)
{
if((i%rows==0)&&i!=0)
{
j=rows*2*rows+rows+(rows*2)*(i/rows);
c0[j]=d[i];
j++;
}
else
{
c0[j]=d[i];
j++;
}
}
}
//两个大小相同的矩阵的加法
void Matrix_Add(int* x,int* y,int* c0,int rows)
{
int i=0;
for(i=0;i<(rows*rows);i++)
{
c0[i]=x[i]+y[i];
}
}
//矩阵的乘法递归函数
void Matrix_Multiply(int* x,int* y,int* c0,int loca,int locb,int rows)
{
//进行索引的申明
//递归过程为了节约时间,没有进行原矩阵的切割赋值
//直接采用索引的方法进行矩阵的伪切割
int loca11,loca12,loca21,loca22;
int locb11,locb12,locb21,locb22;
//结果矩阵的四分子矩阵申明
int c11[(rows/2)*(rows/2)];
int c12[(rows/2)*(rows/2)];
int c21[(rows/2)*(rows/2)];
int c22[(rows/2)*(rows/2)];
//中间暂存矩阵申明
int temp1[(rows/2)*(rows/2)];
int temp2[(rows/2)*(rows/2)];
//四分矩阵后的子矩阵宽度
int newrows=rows/2;
//如果矩阵的宽度为1,那么就直接元素相乘,递归结束
if(rows==1)
{
c0[0]=x[(loca)]*y[(locb)];
}
else
{
//索引的计算,进行矩阵伪分割
//因为n为2的幂,所以直接二分即可
loca11=loca;
locb11=locb;
loca12=loca+(rows/2);
locb12=locb+(rows/2);
loca21=loca+(size*(rows/2));
locb21=locb+(size*(rows/2));
loca22=loca+(size*(rows/2))+(rows/2);
locb22=locb+(size*(rows/2))+(rows/2);
//根据算法公式进行运算
//涉及矩阵相乘的进入递归
Matrix_Multiply(x,y,temp1,loca11,locb11,newrows);
Matrix_Multiply(x,y,temp2,loca12,locb21,newrows);
Matrix_Add(temp1,temp2,c11,newrows);
Matrix_Multiply(x,y,temp1,loca11,locb12,newrows);
Matrix_Multiply(x,y,temp2,loca12,locb22,newrows);
Matrix_Add(temp1,temp2,c12,newrows);
Matrix_Multiply(x,y,temp1,loca21,locb11,newrows);
Matrix_Multiply(x,y,temp2,loca22,locb21,newrows);
Matrix_Add(temp1,temp2,c21,newrows);
Matrix_Multiply(x,y,temp1,loca21,locb12,newrows);
Matrix_Multiply(x,y,temp2,loca22,locb22,newrows);
Matrix_Add(temp1,temp2,c22,newrows);
//递归后,进行矩阵的合并
Merge_Matrix(c11,c12,c21,c22,c0,newrows);
}
}
int main()
{
int i=0;
//动态生成指定大小的矩阵
int* a=NULL;
int* b=NULL;
int* c=NULL;
a=(int *)malloc(size*size*sizeof(int));
b=(int *)malloc(size*size*sizeof(int));
c=(int *)malloc(size*size*sizeof(int));
//时间随机数种子生成
srand((unsigned)time(NULL));
//给A和B矩阵赋随机值 (0~10)
for(i=0;i<(size*size);i++)
{
a[i]=rand()%10;
b[i]=rand()%10;
}
//打印矩阵A和B
printf("A matrix is:\n\n");
for(i=0;i<(size*size);i++)
{
if(((i+1)%size)==0)
{
printf("%5d\n",a[i]);
}
else
{
printf("%5d",a[i]);
}
}
printf("B matrix is:\n\n");
for(i=0;i<(size*size);i++)
{
if(((i+1)%size)==0)
{
printf("%5d\n",b[i]);
}
else
{
printf("%5d",b[i]);
}
}
//进行矩阵乘法
Matrix_Multiply(a,b,c,0,0,size);
//打印结果矩阵C
printf("C matrix is:\n\n");
for(i=0;i<(size*size);i++)
{
if(((i+1)%size)==0)
{
printf("%5d\n",c[i]);
}
else
{
printf("%5d",c[i]);
}
}
//释放动态内存
free(a);
free(b);
free(c);
return 0;
}
2.算法原理分析
对于
n
∗
n
n*n
n∗n矩阵A,B,C(n为2的幂)
我们可以进行如下矩阵分割:
A
=
[
A
11
A
12
A
21
A
22
]
,
B
=
[
B
11
B
12
B
21
B
22
]
,
C
=
[
C
11
C
12
C
21
C
22
]
A= \begin{bmatrix} A_{11}& A_{12}\\ A_{21}& A_{22}\\ \end{bmatrix} , B= \begin{bmatrix} B_{11}& B_{12}\\ B_{21}& B_{22}\\ \end{bmatrix} , C= \begin{bmatrix} C_{11}& C_{12}\\ C_{21}& C_{22}\\ \end{bmatrix}
A=[A11A21A12A22],B=[B11B21B12B22],C=[C11C21C12C22]
对于矩阵运算式
C
=
A
⋅
B
C=A\cdot B
C=A⋅B,我们可以改写为:
[
C
11
C
12
C
21
C
22
]
=
[
A
11
A
12
A
21
A
22
]
⋅
[
B
11
B
12
B
21
B
22
]
\begin{bmatrix} C_{11}& C_{12}\\ C_{21}& C_{22}\\ \end{bmatrix}= \begin{bmatrix} A_{11}& A_{12}\\ A_{21}& A_{22}\\ \end{bmatrix} \cdot \begin{bmatrix} B_{11}& B_{12}\\ B_{21}& B_{22}\\ \end{bmatrix}
[C11C21C12C22]=[A11A21A12A22]⋅[B11B21B12B22]
可以进一步写成:
C
11
=
A
11
⋅
B
11
+
A
12
⋅
B
21
C_{11}=A_{11}\cdot B_{11}+A_{12}\cdot B_{21}
C11=A11⋅B11+A12⋅B21
C
12
=
A
11
⋅
B
12
+
A
12
⋅
B
22
C_{12}=A_{11}\cdot B_{12}+A_{12}\cdot B_{22}
C12=A11⋅B12+A12⋅B22
C
21
=
A
21
⋅
B
11
+
A
22
⋅
B
21
C_{21}=A_{21}\cdot B_{11}+A_{22}\cdot B_{21}
C21=A21⋅B11+A22⋅B21
C
22
=
A
21
⋅
B
12
+
A
22
⋅
B
22
C_{22}=A_{21}\cdot B_{12}+A_{22}\cdot B_{22}
C22=A21⋅B12+A22⋅B22
这样我们就可以实现两个
n
∗
n
n*n
n∗n矩阵乘法的分割,分次运算。
再结合n为2的幂的条件,我们就可以进行算法的递归实现。
3.编程细节
(1)用索引的方式进行伪切割
根据算法,程序需要对原矩阵进行切割,将其分为四个大小一样的子矩阵,然后一直这样细分下去,直到递归结束。那么这个分割的操作就是很关键的。有两种可行的分割方式:
1.将矩阵实际分割,并且进行赋值,得到四个子矩阵
2.将矩阵进行伪分割,并没有去得到四个子矩阵,而是通过索引来让程序辨别子矩阵的位置,实际上一直对原矩阵操作
在进行索引的赋值时,关键就是要明确:该方法本质上是对原矩阵进行操作,那么换行时是原矩阵的宽度:即size
(2)编写递归结构
该方法的递归是针对矩阵乘法,那么我们只需要在乘法函数中明确递归结束条件,并且在函数实现时,将涉及矩阵乘法的部分,用调用函数本身的方法完成,即可完成递归结构了。
二、矩阵乘法的Strassen(斯特拉森)方法
1.C语言代码实现
#include<stdio.h>
#include<stdlib.h>
#include<time.h>
//宏定义改变矩阵的大小
//此处size指的是n*n矩阵的宽度或者长度n
//由于Strassen算法本身的局限性
//两个相乘的矩阵只能是n*n,且n只能是2的幂
//即size为1,2,4,8,16...
#define size 32
//矩阵的合并
//就是将A11,A12,A21,A22合并为
// |A11 A12|
// |A21 A22|
//这样的形式
void Merge_Matrix(int *a,int *b,int *c,int *d,int* c0,int rows)
{
//rows是子矩阵的宽度,那么合并后矩阵的宽度就是2*rows
int i=0;//子矩阵遍历索引
int j=0;//合并矩阵遍历索引,此处先合并A11
for(i=0;i<(rows*rows);i++)
{
//如果执行了rows次就需要换行,要锁定到合并后矩阵的第二行,所以加上2*rows即可
if((i%rows==0)&&i!=0)
{
j=(rows*2)*(i/rows);//(i/rows)代表需要换多少行;(rows*2)就是行数
c0[j]=a[i];
j++;
}
else
{
c0[j]=a[i];
j++;
}
}
//此处是A12的首元素,令索引等于子矩阵的那个函数
j=rows;
for(i=0;i<(rows*rows);i++)
{
//这里的条件和之前的一样的,其它四部分也是一样的,底层逻辑是一致的
if((i%rows==0)&&i!=0)
{
j=(rows*2)*(i/rows)+rows;//换行要换到第二部分,所以需要再加一个rows
c0[j]=b[i];
j++;
}
else
{
c0[j]=b[i];
j++;
}
}
//此处是A21的首元素,令索引等于子矩阵的那个函数
j=rows*2*rows;
for(i=0;i<(rows*rows);i++)
{
if((i%rows==0)&&i!=0)
{
j=rows*2*rows+(rows*2)*(i/rows);
c0[j]=c[i];
j++;
}
else
{
c0[j]=c[i];
j++;
}
}
//此处是A22的首元素,令索引等于子矩阵的那个函数
j=rows*2*rows+rows;
for(i=0;i<(rows*rows);i++)
{
if((i%rows==0)&&i!=0)
{
j=rows*2*rows+rows+(rows*2)*(i/rows);
c0[j]=d[i];
j++;
}
else
{
c0[j]=d[i];
j++;
}
}
}
//两个n*n矩阵的减法,是x-y
void Matrix_SUB(int* x,int* y,int* c0,int rows)
{
int i=0;
for(i=0;i<(rows*rows);i++)
{
c0[i]=x[i]-y[i];
}
}
//两个n*n矩阵的加法,是x+y
void Matrix_ADD(int* x,int* y,int* c0,int rows)
{
int i=0;
for(i=0;i<(rows*rows);i++)
{
c0[i]=x[i]+y[i];
}
}
//将一个矩阵
// |A11 A12|
// |A21 A22|
//分解为四个矩阵
//A11,A12,A21,A22
//这个函数其实是Matrix_merge的一个逆运算
//逻辑是一致的,只需要交换赋值位置即可
void Matrix_division(int* a,int* b,int* c,int* d,int* c0,int rows)
{
int i=0;
int j=0;
for(i=0;i<(rows*rows);i++)
{
if((i%rows==0)&&i!=0)
{
j=(rows*2)*(i/rows);
a[i]=c0[j];
j++;
}
else
{
a[i]=c0[j];
j++;
}
}
j=rows;
for(i=0;i<(rows*rows);i++)
{
if((i%rows==0)&&i!=0)
{
j=(rows*2)*(i/rows)+rows;
b[i]=c0[j];
j++;
}
else
{
b[i]=c0[j];
j++;
}
}
j=rows*2*rows;
for(i=0;i<(rows*rows);i++)
{
if((i%rows==0)&&i!=0)
{
j=rows*2*rows+(rows*2)*(i/rows);
c[i]=c0[j];
j++;
}
else
{
c[i]=c0[j];
j++;
}
}
j=rows*2*rows+rows;
for(i=0;i<(rows*rows);i++)
{
if((i%rows==0)&&i!=0)
{
j=rows*2*rows+rows+(rows*2)*(i/rows);
d[i]=c0[j];
j++;
}
else
{
d[i]=c0[j];
j++;
}
}
}
//矩阵乘法的函数
void Matrix_Multiply(int* x,int* y,int* c0,int rows)
{
//定义八个子矩阵,两个乘数矩阵各四个
int x11[(rows/2)*(rows/2)];
int x12[(rows/2)*(rows/2)];
int x21[(rows/2)*(rows/2)];
int x22[(rows/2)*(rows/2)];
int y11[(rows/2)*(rows/2)];
int y12[(rows/2)*(rows/2)];
int y21[(rows/2)*(rows/2)];
int y22[(rows/2)*(rows/2)];
//定义四个子矩阵,结果矩阵的子矩阵
int c11[(rows/2)*(rows/2)];
int c12[(rows/2)*(rows/2)];
int c21[(rows/2)*(rows/2)];
int c22[(rows/2)*(rows/2)];
//Strassen方法需要的中间矩阵
int s1[(rows/2)*(rows/2)];
int s2[(rows/2)*(rows/2)];
int s3[(rows/2)*(rows/2)];
int s4[(rows/2)*(rows/2)];
int s5[(rows/2)*(rows/2)];
int s6[(rows/2)*(rows/2)];
int s7[(rows/2)*(rows/2)];
int s8[(rows/2)*(rows/2)];
int s9[(rows/2)*(rows/2)];
int s10[(rows/2)*(rows/2)];
int p1[(rows/2)*(rows/2)];
int p2[(rows/2)*(rows/2)];
int p3[(rows/2)*(rows/2)];
int p4[(rows/2)*(rows/2)];
int p5[(rows/2)*(rows/2)];
int p6[(rows/2)*(rows/2)];
int p7[(rows/2)*(rows/2)];
//代码实现需要的暂存矩阵
int temp1[(rows/2)*(rows/2)];
int temp2[(rows/2)*(rows/2)];
//rows==1说明矩阵只有一个元素,是递归的结束,递归树的终点
if(rows==1)
{
c0[0]=x[0]*y[0];
}
else
{
//得到新的大小,进入下次递归
int newrows=rows/2;
//将矩阵分割
Matrix_division(y11,y12,y21,y22,y,newrows);
Matrix_division(x11,x12,x21,x22,x,newrows);
//按照Strassen方法,进行矩阵预处理
Matrix_SUB(y12,y22,s1,newrows);
Matrix_ADD(x11,x12,s2,newrows);
Matrix_ADD(x21,x22,s3,newrows);
Matrix_SUB(y21,y11,s4,newrows);
Matrix_ADD(x11,x22,s5,newrows);
Matrix_ADD(y11,y22,s6,newrows);
Matrix_SUB(x12,x22,s7,newrows);
Matrix_ADD(y21,y22,s8,newrows);
Matrix_SUB(x11,x21,s9,newrows);
Matrix_ADD(y11,y12,s10,newrows);
//按照Strassen方法,进行矩阵乘法上的递归
Matrix_Multiply(x11,s1,p1,newrows);
Matrix_Multiply(s2,y22,p2,newrows);
Matrix_Multiply(s3,y11,p3,newrows);
Matrix_Multiply(x22,s4,p4,newrows);
Matrix_Multiply(s5,s6,p5,newrows);
Matrix_Multiply(s7,s8,p6,newrows);
Matrix_Multiply(s9,s10,p7,newrows);
//按照Strassen方法,进行结果矩阵的计算
//得到c11,c12,c21,c22
Matrix_ADD(p5,p6,temp1,newrows);
Matrix_SUB(p4,p2,temp2,newrows);
Matrix_ADD(temp1,temp2,c11,newrows);
Matrix_ADD(p1,p2,c12,newrows);
Matrix_ADD(p4,p3,c21,newrows);
Matrix_SUB(p5,p3,temp1,newrows);
Matrix_SUB(p1,p7,temp2,newrows);
Matrix_ADD(temp1,temp2,c22,newrows);
//最后将c11,c12,c21,c22合并为C
Merge_Matrix(c11,c12,c21,c22,c0,newrows);
}
}
int main()
{
int i=0;
int* a=NULL;
int* b=NULL;
int* c=NULL;
//动态生成三个n*n矩阵
a=(int *)malloc(size*size*sizeof(int));
b=(int *)malloc(size*size*sizeof(int));
c=(int *)malloc(size*size*sizeof(int));
//随机数时间种子生成
srand((unsigned)time(NULL));
//生成-10~10的随机数
for(i=0;i<(size*size);i++)
{
a[i]=(rand()%10)-(rand()%10);
b[i]=rand()%10-(rand()%10);
}
//打印乘数矩阵A和B
printf("A matrix is:\n\n");
for(i=0;i<(size*size);i++)
{
if(((i+1)%size)==0)
{
printf("%6d\n",a[i]);
}
else
{
printf("%6d",a[i]);
}
}
printf("B matrix is:\n\n");
for(i=0;i<(size*size);i++)
{
if(((i+1)%size)==0)
{
printf("%6d\n",b[i]);
}
else
{
printf("%6d",b[i]);
}
}
//进行矩阵乘法
Matrix_Multiply(a,b,c,size);
//打印结果矩阵
printf("C matrix is:\n\n");
for(i=0;i<(size*size);i++)
{
if(((i+1)%size)==0)
{
printf("%6d\n",c[i]);
}
else
{
printf("%6d",c[i]);
}
}
//释放动态生成的矩阵
free(a);
free(b);
free(c);
return 0;
}
2.算法原理分析
对于
n
∗
n
n*n
n∗n矩阵A,B,C(n为2的幂)
我们还是可以进行如下矩阵分割:
A
=
[
A
11
A
12
A
21
A
22
]
,
B
=
[
B
11
B
12
B
21
B
22
]
,
C
=
[
C
11
C
12
C
21
C
22
]
A= \begin{bmatrix} A_{11}& A_{12}\\ A_{21}& A_{22}\\ \end{bmatrix} , B= \begin{bmatrix} B_{11}& B_{12}\\ B_{21}& B_{22}\\ \end{bmatrix} , C= \begin{bmatrix} C_{11}& C_{12}\\ C_{21}& C_{22}\\ \end{bmatrix}
A=[A11A21A12A22],B=[B11B21B12B22],C=[C11C21C12C22]
分割之后我们先进行如下一些预处理运算
S
1
=
B
12
−
B
22
S_1=B_{12}-B_{22}
S1=B12−B22
S
2
=
A
11
+
A
12
S_2=A_{11}+A_{12}
S2=A11+A12
S
3
=
A
21
+
A
22
S_3=A_{21}+A_{22}
S3=A21+A22
S
4
=
B
21
−
B
11
S_4=B_{21}-B_{11}
S4=B21−B11
S
5
=
A
11
+
A
22
S_5=A_{11}+A_{22}
S5=A11+A22
S
6
=
B
11
+
B
22
S_6=B_{11}+B_{22}
S6=B11+B22
S
7
=
A
12
−
A
22
S_7=A_{12}-A_{22}
S7=A12−A22
S
8
=
B
21
+
B
22
S_8=B_{21}+B_{22}
S8=B21+B22
S
9
=
A
11
−
A
21
S_9=A_{11}-A_{21}
S9=A11−A21
S
10
=
B
11
+
B
12
S_{10}=B_{11}+B_{12}
S10=B11+B12
然后我们根据S1~S2的值,进行7次乘法运算,如下:
P
1
=
A
11
⋅
S
1
P_{1}=A_{11}\cdot S_1
P1=A11⋅S1
P
2
=
S
2
⋅
B
22
P_{2}=S_{2}\cdot B_{22}
P2=S2⋅B22
P
3
=
S
3
⋅
B
11
P_{3}=S_3\cdot B_{11}
P3=S3⋅B11
P
4
=
A
22
⋅
S
4
P_{4}=A_{22}\cdot S_4
P4=A22⋅S4
P
5
=
S
5
⋅
S
6
P_{5}=S_5\cdot S_6
P5=S5⋅S6
P
6
=
S
7
⋅
S
8
P_{6}=S_7\cdot S_8
P6=S7⋅S8
P
7
=
S
9
⋅
S
10
P_{7}=S_9\cdot S_{10}
P7=S9⋅S10
然后再通过简单的加法运算,就可以得到结果矩阵的四个子矩阵:
C
11
=
P
5
+
P
4
−
P
2
+
P
6
C_{11}=P_5+P_4-P_2+P_6
C11=P5+P4−P2+P6
C
12
=
P
1
+
P
2
C_{12}=P_1+P_2
C12=P1+P2
C
21
=
P
3
+
P
4
C_{21}=P_3+P_4
C21=P3+P4
C
22
=
P
5
+
P
1
−
P
3
−
P
7
C_{22}=P_5+P_1-P_3-P_7
C22=P5+P1−P3−P7
最后将子矩阵合并即可
3.编程细节
(1)分割矩阵
我们在普通的递归方法中,对于分割矩阵采用了利用索引的方法。但是在Strassen方法中,涉及了大量的中间矩阵,造成在递归调用的时候,原矩阵和中间矩阵会一起进入递归,如果用统一的索引方法,就会造成数据空间上的越界,故只能采用传统的对子矩阵赋值的方法。
当然我们可以对合并矩阵的代码进行逆向工程,稍微修改就可以得到正确的分割矩阵的代码。
三、算法的时间复杂度分析
1.两个方法的时间复杂度
1.普通递归方法的时间复杂度为: Ω ( n 3 ) \Omega(n^3) Ω(n3)
2.Strassen方法的时间复杂度为: O ( n 2.18 ) O(n^{2.18}) O(n2.18)
具体的计算过程将在下一篇文章中讲解。
链接: 《算法导论》学习(五)---- 分治策略的时间复杂度求解
2.两个方法时间上的比较
1.从渐进复杂性上,Strassen方法是要优于普通的递归方法的,归根结底是因为Strassen的递归数更小,因为只需要7次乘法,而普通递归方法需要8次乘法。文章来源:https://www.toymoban.com/news/detail-406127.html
2.但是由于Strassen方法的中间过程更加繁琐,所以当输入规模小的时候,普通的递归方法反而要快不少。文章来源地址https://www.toymoban.com/news/detail-406127.html
到了这里,关于《算法导论》学习(四)---- 矩阵乘法的Strassen(斯特拉森)算法的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!