Python Program to Solve Matrix-Chain Multiplication using Dynamic Programming with Memoization

This is a Python program that solves the matrix-chain multiplication problem using top-down dynamic programming (memoization).

Problem Description

In the matrix-chain multiplication problem, we are given a sequence of matrices A(1), A(2), …, A(n). The aim is to compute the product A(1)…A(n) with the minimum number of scalar multiplications. Thus, we have to find an optimal parenthesization of the matrix product A(1)…A(n) such that the cost of computing the product is minimized.
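To see why the parenthesization matters: with the standard algorithm, multiplying a p x q matrix by a q x r matrix takes p*q*r scalar multiplications. The short sketch below (the helper name mult_cost is hypothetical) compares both possible orders for three matrices of dimensions 10 x 100, 100 x 5 and 5 x 50, the same dimensions used in Case 1 of the test cases.

```python
def mult_cost(p, q, r):
    """Scalar multiplications for a (p x q) times (q x r) product, standard algorithm."""
    return p * q * r


# A1 is 10 x 100, A2 is 100 x 5, A3 is 5 x 50.
# ((A1 A2) A3): A1 A2 yields a 10 x 5 matrix, which is then multiplied by A3.
left_first = mult_cost(10, 100, 5) + mult_cost(10, 5, 50)
# (A1 (A2 A3)): A2 A3 yields a 100 x 50 matrix, which A1 then multiplies.
right_first = mult_cost(100, 5, 50) + mult_cost(10, 100, 50)

print(left_first)   # 7500
print(right_first)  # 75000
```

Both orders produce the same 10 x 50 result, but one needs ten times as many scalar multiplications as the other.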

Problem Solution

1. Three functions are defined: matrix_product, matrix_product_helper and print_parenthesization.
2. matrix_product_helper takes as arguments a list p, two indexes start and end, and two 2D tables m and s.
3. The function stores the minimum number of scalar multiplications needed to compute the product A(i) x A(i + 1) x … x A(j) in m[i][j].
4. The index of the matrix after which the above product is split in an optimal parenthesization is stored in s[i][j].
5. p[0...n] is a list such that matrix A(i) has dimensions p[i - 1] x p[i].
6. The function returns m[start][end].
7. That is, it returns the minimum computations needed to evaluate A(start) x … x A(end).
8. This is done by finding a k, with start <= k < end, such that m[start][k] + m[k + 1][end] + p[start - 1]*p[k]*p[end] is minimized. The last term is the cost of multiplying the two products formed by splitting the matrix-chain after matrix k (a p[start - 1] x p[k] matrix times a p[k] x p[end] matrix).
9. The function is recursive: each time it computes a minimum cost, it stores that cost in m and the index of the corresponding split in s.
10. If a minimum cost has already been calculated and stored in m (that is, the entry is no longer -1), it is returned immediately instead of being recomputed.
11. The function matrix_product takes the list p as argument, which contains the dimensions of the matrices in the matrix-chain.
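12. It initializes the 2D tables m and s as lists of lists with every entry set to -1 (meaning "not yet computed") and calls matrix_product_helper(p, 1, n, m, s), where n = len(p) - 1 is the number of matrices.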
13. It then returns m and s.
14. The function print_parenthesization takes as argument a 2D table s as generated above.
15. It also takes two indexes start and end as arguments.
16. It prints the optimal parenthesization of the matrix-chain product A(start) x … x A(end).
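The recurrence described in steps 8 to 10 can be written compactly as follows, using the same p, m and s as above. This is only a restatement of what matrix_product_helper computes, not additional logic; ties in the minimum go to the smallest k, because the program updates s[start][end] only on a strict improvement.

$$
m[i][j] =
\begin{cases}
0, & i = j,\\
\displaystyle\min_{i \le k < j} \bigl( m[i][k] + m[k+1][j] + p[i-1]\,p[k]\,p[j] \bigr), & i < j,
\end{cases}
\qquad
s[i][j] = \operatorname*{arg\,min}_{i \le k < j} \bigl( m[i][k] + m[k+1][j] + p[i-1]\,p[k]\,p[j] \bigr).
$$

Each of the n(n + 1)/2 pairs (i, j) with i <= j is solved once and tries at most n - 1 split points, so the memoized program runs in O(n^3) time and uses O(n^2) space for m and s.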

Program/Source Code

Here is the source code of a Python program that solves the matrix-chain multiplication problem using dynamic programming with memoization. Sample runs are shown in the test cases further below.

def matrix_product(p):
    """Return m and s.
 
    m[i][j] is the minimum number of scalar multiplications needed to compute the
    product of matrices A(i), A(i + 1), ..., A(j).
 
    s[i][j] is the index of the matrix after which the product is split in an
    optimal parenthesization of the matrix product.
 
    p[0... n] is a list such that matrix A(i) has dimensions p[i - 1] x p[i].
    """
    length = len(p) # len(p) = number of matrices + 1
 
    # m[i][j] is the minimum number of multiplications needed to compute the
    # product of matrices A(i), A(i+1), ..., A(j); -1 means "not computed yet"
    # s[i][j] is the index k of the matrix after which that product is split
    # in an optimal parenthesization
    m = [[-1]*length for _ in range(length)]
    s = [[-1]*length for _ in range(length)]
 
    matrix_product_helper(p, 1, length - 1, m, s)
 
    return m, s
 
 
def matrix_product_helper(p, start, end, m, s):
    """Return minimum number of scalar multiplications needed to compute the
    product of matrices A(start), A(start + 1), ..., A(end).
 
    The minimum number of scalar multiplications needed to compute the
    product of matrices A(i), A(i + 1), ..., A(j) is stored in m[i][j].
 
    The index of the matrix after which the above product is split in an optimal
    parenthesization is stored in s[i][j].
 
    p[0... n] is a list such that matrix A(i) has dimensions p[i - 1] x p[i].
    """
    if m[start][end] >= 0:
        return m[start][end]
 
    if start == end:
        q = 0
    else:
        q = float('inf')
        for k in range(start, end):
            temp = matrix_product_helper(p, start, k, m, s) \
                   + matrix_product_helper(p, k + 1, end, m, s) \
                   + p[start - 1]*p[k]*p[end]
            if q > temp:
                q = temp
                s[start][end] = k
 
    m[start][end] = q
    return q
 
 
def print_parenthesization(s, start, end):
    """Print the optimal parenthesization of the matrix product A(start) x
    A(start + 1) x ... x A(end).
 
    s[i][j] is the index of the matrix after which the product is split in an
    optimal parenthesization of the matrix product.
    """
    if start == end:
        print('A[{}]'.format(start), end='')
        return
 
    k = s[start][end]
 
    print('(', end='')
    print_parenthesization(s, start, k)
    print_parenthesization(s, k + 1, end)
    print(')', end='')
 
 
n = int(input('Enter number of matrices: '))
p = []
for i in range(n):
    temp = int(input('Enter number of rows in matrix {}: '.format(i + 1)))
    p.append(temp)
temp = int(input('Enter number of columns in matrix {}: '.format(n)))
p.append(temp)
 
m, s = matrix_product(p)
print('The number of scalar multiplications needed:', m[1][n])
print('Optimal parenthesization: ', end='')
print_parenthesization(s, 1, n)
print()  # end the output line after the parenthesization

Program Explanation

1. The user is prompted to enter the number of matrices, n.
2. The user is then asked to enter the number of rows of each matrix, followed by the number of columns of the last matrix; these n + 1 values form the list p.
3. matrix_product is called to get the tables m and s.
4. m[1][n] is the minimum cost of computing the matrix product.
5. print_parenthesization is then called to display the optimal way to parenthesize the matrix product.
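Because the program reads its input with input() and prints the parenthesization piece by piece, it is awkward to call from other code. The following is a minimal alternative sketch, assuming you would rather pass the dimension list directly and get the results back as values. It swaps the explicit m and s tables for functools.lru_cache (still top-down memoization) and builds the parenthesization as a string. The names min_cost_and_order, best and order are hypothetical and not part of the original program.

```python
from functools import lru_cache


def min_cost_and_order(p):
    """Return (minimum scalar multiplications, parenthesization string).

    p has length n + 1 and matrix A[i] (1-indexed) has dimensions p[i - 1] x p[i].
    """
    n = len(p) - 1

    @lru_cache(maxsize=None)
    def best(i, j):
        # Returns (cost, split) for A[i..j]; split is None for a single matrix.
        if i == j:
            return 0, None
        candidates = (
            (best(i, k)[0] + best(k + 1, j)[0] + p[i - 1] * p[k] * p[j], k)
            for k in range(i, j)
        )
        # min on tuples compares cost first, then k, so ties go to the smallest k,
        # matching the strict comparison in matrix_product_helper.
        return min(candidates)

    def order(i, j):
        # Rebuild the parenthesization from the memoized split points.
        if i == j:
            return 'A[{}]'.format(i)
        k = best(i, j)[1]
        return '(' + order(i, k) + order(k + 1, j) + ')'

    return best(1, n)[0], order(1, n)


print(min_cost_and_order([5, 10, 8, 15, 20, 4]))
# (2200, '(A[1](A[2](A[3](A[4]A[5]))))')
```

With the dimensions from Case 2 this reproduces the cost and parenthesization printed by the interactive program.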

Runtime Test Cases
Case 1:
Enter number of matrices: 3
Enter number of rows in matrix 1: 10
Enter number of rows in matrix 2: 100
Enter number of rows in matrix 3: 5
Enter number of columns in matrix 3: 50
The number of scalar multiplications needed: 7500
Optimal parenthesization: ((A[1]A[2])A[3])
 
Case 2:
Enter number of matrices: 5
Enter number of rows in matrix 1: 5
Enter number of rows in matrix 2: 10
Enter number of rows in matrix 3: 8
Enter number of rows in matrix 4: 15
Enter number of rows in matrix 5: 20
Enter number of columns in matrix 5: 4
The number of scalar multiplications needed: 2200
Optimal parenthesization: (A[1](A[2](A[3](A[4]A[5]))))
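As an independent sanity check on Case 2, a plain recursive search with no memoization tries every split point and should agree with the memoized result. This is a sketch; the function name brute_force_cost is hypothetical, and its running time grows exponentially with the number of matrices, so it is only practical for short chains.

```python
def brute_force_cost(p, i, j):
    """Minimum scalar multiplications for A[i..j], recomputing every subchain.

    Matrix A[i] has dimensions p[i - 1] x p[i]; no results are cached.
    """
    if i == j:
        return 0
    return min(
        brute_force_cost(p, i, k) + brute_force_cost(p, k + 1, j) + p[i - 1] * p[k] * p[j]
        for k in range(i, j)
    )


p = [5, 10, 8, 15, 20, 4]  # dimensions entered in Case 2
print(brute_force_cost(p, 1, len(p) - 1))  # 2200
```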
 
Case 3:
Enter number of matrices: 1
Enter number of rows in matrix 1: 5
Enter number of columns in matrix 1: 7
The number of scalar multiplications needed: 0
Optimal parenthesization: A[1]
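The test cases above are small enough that memoization barely matters. One rough way to see what step 10 buys is to count work: the sketch below counts the calls a plain recursive solver (no caching) would make on A[1..n], and the number of distinct (start, end) subproblems the memoized solver actually computes. Both counting functions are hypothetical instrumentation, not part of the program, and the counts do not depend on the matrix dimensions.

```python
def count_naive_calls(i, j):
    """Number of calls a plain recursive solver makes on A[i..j] (no caching)."""
    if i == j:
        return 1
    # One call for (i, j) itself plus both halves for every split point k.
    return 1 + sum(count_naive_calls(i, k) + count_naive_calls(k + 1, j)
                   for k in range(i, j))


def count_memo_entries(n):
    """Number of distinct (start, end) subproblems the memoized solver fills in."""
    return sum(1 for i in range(1, n + 1) for j in range(i, n + 1))


for n in (5, 10):
    print(n, count_naive_calls(1, n), count_memo_entries(n))
# 5 81 15
# 10 19683 55
```

The naive count works out to 3^(n - 1), while the memoized solver computes only n(n + 1)/2 table entries.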

Sanfoundry Global Education & Learning Series – Python Programs.

Manish Bhojasia - Founder & CTO at Sanfoundry
Manish Bhojasia, a technology veteran with 20+ years @ Cisco & Wipro, is Founder and CTO at Sanfoundry. He lives in Bangalore, and focuses on development of Linux Kernel, SAN Technologies, Advanced C, Data Structures & Algorithms.