📜  矩阵链乘法(AO(N ^ 2)解决方案)(1)

📅  最后修改于: 2023-12-03 15:11:23.539000             🧑  作者: Mango

矩阵链乘法(AO(N^2)解决方案)

什么是矩阵链乘法?

矩阵链乘法问题指的是给定一系列矩阵,如何安排它们的乘法顺序可以让计算的数量最少。比如说,有三个矩阵 A、B、C,它们的维度分别是 $10\times 30$、$30\times 5$ 和 $5\times 60$,如果按照 $(A\times B)\times C$ 的顺序来计算乘积的话,需要进行 $10\times 30\times 5+10\times 5\times 60=1500+3000=4500$ 次标量乘法和 $10\times 5+5\times 60=250$ 次矩阵乘法,总共 $4750$ 次计算。而如果按照 $A\times (B\times C)$ 的顺序计算乘积的话,需要进行 $30\times 5\times 60+10\times 30\times 60=9000+18000=27000$ 次标量乘法和 $10\times 30+30\times 5=450$ 次矩阵乘法,总共 $27450$ 次计算。显然前者更优。

问题的求解

矩阵链乘法问题可以使用动态规划来求解,在 $O(N^3)$ 的时间内得出最优解。但是,针对该问题,有一种 $O(N^2)$ 的解法,称为 AO(Approximate Optimal)算法,可以作为求解矩阵链乘法问题的一个实际可行的解法。

AO 算法

AO 算法是一种启发式的贪心算法,其基本思想是:通过所有的括号方案中的局部最优,和贪心法中的邻项最优性,得到问题的近似最优解。AO 算法的核心是确定每一个子问题的估计值,然后列表计算。估计值可以是任意一个低于实际计算量 $c(i,j)$ 的函数 $w(i,j)$,比如两个矩阵的乘积,可以使用矩阵的行数和列数的乘积来表示,具有鲁棒性。

启发式算法的核心是设计好寻求最优解的策略,因此 AO 算法的核心是递归从外部来的寻找最优解的策略。具体流程如下:

  1. 初始化一个 $N\times N$ 的二维数组 $W$,用来存放各个子问题的估计值。
  2. 对于所有的 $i$,都有 $w(i,i)=0$,因为相邻矩阵的乘积不需要做任何乘法。
  3. 从 ${1,2,\cdots,N}$ 中找出两个数 $i$ 和 $j$,使得 $w(i,j)$ 最小。
  4. 将两个矩阵 $A_i$ 和 $A_j$ 加上括号,形成一个新的矩阵序列 ${A_1,A_2,\cdots,A_{i-1},(A_iA_{i+1}\cdots A_j),A_{j+1},\cdots,A_N}$。
  5. 从新的矩阵序列的第一个矩阵开始,重复步骤 2 至 4,直到只剩下一个矩阵。
AO 算法的实现

下面给出 AO 算法的具体实现。首先是计算估价函数 $W$ 的代码:

def calc_w(mtx_sizes):
    n = len(mtx_sizes) - 1
    w = [[float('inf')] * n for _ in range(n)]
    for i in range(n):
        w[i][i] = 0
    for p in range(1, n):  # 计算 w(p, i+j)
        for i in range(n - p):
            j = i + p
            for k in range(i, j):
                w[i][j] = min(w[i][j], mtx_sizes[i] * mtx_sizes[k + 1] * mtx_sizes[j + 1] + w[i][k] + w[k + 1][j])
    return w

其中 mtx_sizes 是一个数组,表示一连串矩阵的大小,其中第 $i$ 个矩阵的行数为 mtx_sizes[i],列数为 mtx_sizes[i+1]

接下来是根据估价函数计算出乘积的括号方案的代码:

def get_parentheses(w, start, end):
    if start == end:
        return [start]
    for k in range(start, end):
        if w[start][end] == w[start][k] + w[k + 1][end] + mtx_sizes[start] * mtx_sizes[k + 1] * mtx_sizes[end + 1]:
            return [get_parentheses(w, start, k), get_parentheses(w, k + 1, end)]
    raise Exception('invalid parameters')

其中 startend 分别表示从矩阵序列中的第 start 个矩阵到第 end 个矩阵,这部分序列的乘积要加括号。函数递归地寻找括号的位置,最终得到括号方案的一个树状表示。

最后,可以使用如下代码,将括号方案输出为字符串:

def get_parentheses_str(parens):
    if isinstance(parens, int):
        return str(parens)
    elif len(parens) == 1:
        return get_parentheses_str(parens[0])
    else:
        return '(' + get_parentheses_str(parens[0]) + ',' + get_parentheses_str(parens[1]) + ')'
结语

AO 算法虽然不能保证得到最优解,但是它的实现简单,而且时间复杂度为 $O(N^2)$,在实际应用中可以得到较好的效果。在实际问题中,如果矩阵大小分布比较均匀,那么 AO 算法的贪心策略很可能会找到准确的最优解。