행렬 곱셈의 순서를 알기 위해서는 모든 경우를 세기만 하면 되는 것 같습니다.
행렬의 수가 최대 500개이므로 O(n^3)까지 시간 초과되지 않습니다.
사실 원리만 알면 아주 간단한 dp 알고리즘이다.
for(int i=1; i<n; i++){
for(int j=1; j<=n-i; j++){
dp(j)(i+j)=2147483647;
for(int k=j; k<=i+j; k++)
dp(j)(i+j) = min(
dp(j)(i+j),
dp(j)(k) + dp(k+1)(i+j) + (v(j).first * v(k).second * v(i+j).second));
}
}
위의 코드를 설명하면 i+1은 곱할 행렬의 수이고 j는 곱할 행렬입니다.
k는 j에 i+j 행렬을 곱할 때 잘림(?) 위치입니다.
dp(i)(j)는 i번째 행렬에서 j번째 행렬로 곱할 때의 최소 연산 수입니다.
(5, 3) * (3, 2) * (2, 6) * (6, 1)을 계산한다고 가정합니다.
i = 1이면 인접 행렬을 계산할 때 작업 수가 계산됩니다.
i = 2이면 3개의 연속 행렬을 계산할 때 최소 연산 수가 계산됩니다.
i = n – 1이면 n개의 연속 행렬을 계산할 때 최소 연산 수가 계산됩니다.
어떤 느낌인지 잘 모르겠다면 직접 배열을 그려보세요. 곧 이해하게 될 것입니다.
(사실 설명이 어렵네요…)
답변
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
int n;
vector<pair<int, int>> v; //행, 열
int dp(501)(501);
void solve(){
for(int i=1; i<n; i++){
for(int j=1; j<=n-i; j++){
dp(j)(i+j)=2147483647;
for(int k=j; k<=i+j; k++)
dp(j)(i+j) = min(
dp(j)(i+j),
dp(j)(k) + dp(k+1)(i+j) + (v(j).first * v(k).second * v(i+j).second));
}
}
cout << dp(1)(n);
}
int main(){
cin.tie(NULL);
cout.tie(NULL);
ios_base::sync_with_stdio(false);
cin >> n;
v.push_back({0, 0});
for(int i=0; i<n; i++){
int a, b;
cin >> a >> b;
v.push_back({a, b});
}
solve();
}