알고리즘/baekjoon(boj)

[diamond5][ 풀이, 분류 미확인] boj18444 우체국3

끄응끄응 2026. 3. 29. 05:27

https://www.acmicpc.net/problem/18444

 

 

 

아이디어도 어려웠고, 특히 한번도 구현해보진 않은 희소배열의 아이디어를 기억 깊숙한 곳에서 꺼내기가 힘들었다.. 

+ 구현도 상당히 복잡했고,, 제일 큰 이슈는 java로는 시간제한을 지킬 수 없는 문제여서 수없는 삽질을 하였다...

Wow

 

 

처음 생각한 아이디어는 누적합을 이용한 dp 였다.

 

위와 같이 dp는 시작지점,  끝지점, 그 사이의 우체국 갯수 (시작 , 끝 점은 우체국 반드시 포함) 을

state로 한다.

 dp[2][6][3] =  Math.min(dp[2][5][2], dp[2][4][2], dp[2][3][2])  의 점화식을 가질 것이다.

end 지점이 e번이고 시작지점이 s번,  우체국 갯수가 p개인 dp를 구하고자 할 때  , n-1, n- 2... 가 end지점, 우체국갯수가 p - 1 개, start지점은 s번인 dp들을 참고해야 할 것이다.       

dp[s][e][p] 를 구할 때  dp[s][e - 1][p],   dp[s][e - 2][p]... dp[s][e -  p  + 1][p]...  만큼 참고해야 하므로 최대, e번 계산해야하고, 최대 N번이 된다.  이 과정을 N(s에 대해) * N (e에 대해)* P(우체국갯수)에 반복하므로,  시간복잡도는 거의 O(N^4)이다.

어떻게 하면 이걸 줄일 수 있을 까...

 한참을 고민하다가, 희소배열의 개념이 떠올랐다. 사실 문제를 보자마자 생김새가 예전에 풀었던 감시카메라 문제와 유사해서 희소배열도 잠깐 떠올린 것 같지만 , dp에 꽂혀서 금세 잊혀졌었다. 그러다가 dp[s][e][p] 에 대해,

dp[s][e][p] 를 dp[s][m][p1] 와 dp[m][e][p2] (p1 + p2 = p ) 로 나타낼 수 있지 않을 까 생각하던중, 거듭제곱을 활용한 희소배열이 생각이 났다. 희소배열을 사용하면서 분할 정복도 사용이 될 것이다.  

우선, 희소배열을 생각하기 전 분할정복은 다음과 같다.    dp[1][10][5] 를 구성할 수 있는 요소로는 dp[1][5][3], dp[5][10][3] 이 있을 것이다.

 

또한 아래와 같이 dp[1][3][3], dp[3][10][3] 과 같은 경우도 있을 것이다.

 

dp[1][10][5] = Math.min(dp[1][5][3] + dp[5][10][3], dp[1][6][3] + dp[6][10][3],        dp[1][7][3] + dp[7][10][3]  ....

dp[1][8][10]  + dp[8][10][3]) 으로 생각 할 수 있을 것이다.

그런데 여기서 의문이 들 수도 있다. 왜 dp[1][4][2] + dp[4][10][4]와 같은 경우는 생각하지 않는가? 

아래와 같이 말이다.   결국   dp[1][4][2] + dp[4][10][4]  는 범위를 조정하면 dp[1][5][3] + dp[5][10][3] 로도 나타낼 수 있다.

따라서 우체국 수는 3으로 고정시킬 수 있다.

 

분할정복은 대충 이렇게 하는 거라는 걸 알았지만, 위와같이 dp[s][e][cnt] = dp[s][m][cnt1] + dp[m][e][cnt2]  의 형태로 구하려한다면, 여기서도 어김없이 O(N ^4)의 시간복잡도이다.

여기서 희소배열의 아이디어가 쓰인다.

위에서 dp[s][e][5] 은 dp[s][m][3] + dp[m][e][3] 에서 m을 조정하며 얻은 최솟값이라는 걸 알았다.

그렇다면, dp[s][e][9] 는?  dp[s][m][5] + dp[m][e][5] 를 조정하며 얻은 최솟값이다.

마찬가지로 dp[s][e][17] 은  dp[s][m][9] + dp[m][e][9]에서 계산할수 있다.

이런식으로 2 * p - 1  씩 증가시키며 계산하면,  구간안의 우체국 수가 2 , 3 , 5 , 9 , 17, 33 , 65 , 129 ... 일때의 dp을 구할 수 있다.

이때 모든 구간 s, e에 대해 구한다고 하면,   s, e, m 을 조정하면서 구하므로,  이때 O(N^3), 이걸 2,3,5...129,257 까지 구한다 하면 O( N^3 * log2(257)) 이 된다.

왜 이렇게 구했을까?  

  dp[s][e][111] 을 구해야 한다고 생각해보자.  우선 위에서 구한 2 , 3 , 5 , 9 , 17, 33 , 65 , 129 ...  중  65와 33 을 결합한다고 생각해보자.   dp[s][e][97] 은 dp[s][m][65] + dp[m][e][33] 에서 s,e,m을 조정하며 모든 구간에 대한 우체국 수가 97일 때를 구할 수 있다. 97까지 구했으면 111까지 14개의 우체국이 남았다. 

여기서 dp[s][e][97 + 9 - 1] 은  dp[s][m][97] + dp[m][e][9]  로 구할 수 있다.

dp[s][e][105 + 5 - 1]  은 dp[s][m][105] + dp[m][e][5] 

dp[s][e][109 + 3 - 1]  은 dp[s][m][109] + dp[m][e][3]   이런식으로 구할 수 있다.

 

65 + 33 -> 97

97 + 9 -> 105

105 + 5 -> 109

109 + 3 -> 111

 

 

결국 여기서 111까지 구하기 위해선  65, 33 , 9 , 5 , 3  만 필요했다. 

즉 우체국수를 2,3,5,9... 257 까지 구해놓으면 여기서 적절하게 조합해 모든 수를 만들 수 있다.

각각의 우체국수에 대해 O(N^3)의 시간복잡도이므로, 결국 O(N^3 * (log2(257))로 구할 수 있다.

즉 모든 dp[s][e][p]에서 p를 1,2,3,4,... 우체국 최대갯수 까지 구하는게 아니라, 2, 3,5,9, ... 식으로 띄엄띄엄 구해서, 이 수들을 조합하는 식으로 구할 수 있는 것이다.

 

최대 우체국수가 111개 였으면, 위와 같이 모든 s,e에 대해 dp[s][e][111]을 구했다. 그런데 마을은 원형을 이루고 있다.

위와 같이 dp[s][e][111]에 대해 구했으면, 빨간 부분은 이제 우체국이 들어설 수 없는 구역이다. 이 구역에는 양끝 s, e 에만 우체국이 있으므로, 누적합을 이용해 빨간색 구간 마을의 우체국 거리 합을 구할 수 있다.

dp[s][e][111] + 빨간구간 거리합 까지 합하면, 이제 원형의 모든 마을에 대해 우체국까지 최소거리를 구할 수 있다.

 

모든 s,e에 대해 이뤄지므로 추가적으로 O(N^2 ) 이 이뤄지게 된다. (누적합을 통한 계산은 거의 O(1)이다).

 

역추적은 적당히 재귀 dfs를 이용해 돌리면 된다..

 

 

이렇게 아이디어를 내보았다.

 

아이디어를 떠올린 것 까지는 분위기가 좋았다. 그러나 구현이 매우 까다로웠고, 구현 성공후, java로는 도저히 안돌려지길래 gpt한테 그대로 c++로 옮겨달라고 해서 해결하였다.  java로 안되서 얼마나 오래 삽질했는지 정신이 아찔해진다.

그렇다고 합니다. 사실 c++로 돌려서 채점을 돌리다가 75프로에서 WA가 나서, 디버깅까지 하면 내 멘탈이 못버티겠다 하고, 구글링을 해서 반례를 찾았는데, 이 반례는 알고보니 java에서 한참 삽질하면서 이미 고쳤지만, c++로 옮기면서 미쳐 놓친 부분이였다. 마을이 같은 곳에 있을 때 마을의 위치대신 인덱스로 역추적을 하게 만드는 부분이였다.

 

 

 

 

 

 

 

 

 

 

 

  package boj18444;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
public class Main2 {
static int N;
static int posts;
static long L;

static long[] P;
static long [] sum;
static long [][][][] trace; // s, e , cnt    :    mid, cnt1 , cnt2    
static long [][][] dp; // s , e , cnt
static List<Long> ansL = new ArrayList<Long>();
  public static void main(String[] args) throws IOException{
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
N = Integer.parseInt(st.nextToken());posts = Integer.parseInt(st.nextToken());L = Long.parseLong(st.nextToken());
P = new long [N + 1]; sum = new long[N + 1]; trace = new long[N + 1][N + 1][N + 1][3]; dp = new long[N + 1][N + 1][N + 1];
st = new StringTokenizer(br.readLine());
for(int i = 1 ; i <= N; i++) {
P[i] = Long.parseLong(st.nextToken());
sum[i] = P[i] + sum [i - 1];
}
if(posts == 1) {
long ans1 = Long.MAX_VALUE; int ansi = 0;
  for(int n = 1 ; n <= N ; n++) {
long res = (P[n] + L / 2);
  long mid = res >= L ?  res - L : res;
  int mi = 0 ; int l = 1 ; int r = N;
    while(l <= r) {
    int m = (l + r) / 2;
    if(P[m] <= mid) {
    mi = m;
    l = m + 1;
    }else {
    r = m - 1;
    }
    }
 
  long sum1 = 0 ; long sum2 = 0 ; long sum3 = 0;
if(n <= mi) {
     sum1 = sum[mi] - sum[n - 1] - (mi - n + 1) * P[n];
     sum2 =   (N- (mi + 1) + 1) * P[n] + (L * (N - mi) - (sum[N] - sum[mi]));   // mi + 1 ~ N
     sum3 =  n * P[n] - sum[n];   // 1 ~ n2
  //   System.out.println(sum1 + " " + sum2 + " " + sum3);
 
  }else {
    sum1 = sum[N] - sum[n - 1] - (N - n + 1) * P[n];
    sum2 = mi * (L - P[n]) + sum[mi];
    sum3 = (n - (mi + 1) + 1) * P[n] - (sum[n] - sum[mi]);  // mi + 1 ~ n2
  }
if(sum1 + sum2 + sum3 < ans1) {
ans1 =  sum1 + sum2 + sum3;
ansi = n;
}
 

 
   

}

System.out.println(ans1);System.out.println(P[ansi]);

return;
}
for(int i = 0 ; i <= N ; i++) {
for(int j = 0 ; j <= N ; j++) {
Arrays.fill(dp[i][j], Long.MAX_VALUE / 2);
}
}

for(int i = 1 ; i <= N ; i++) {
for(int j = i + 1 ; j <= N ; j++) {
dp[i][j][2] = getS(i,j);
//System.out.println (dp[i][j][2] + " " +  i   + "  " +  j );
}
}
  List<Integer> pows = new ArrayList<Integer>();
 
  for(int pow = 2 ; pow <= 300 ; pow = 2 * pow - 1) {    
  pows.add(pow);
  }
 
 
for(int pow = 2 ; pow <= 150 ; pow = 2 * pow - 1) {        // N ^ 3 * log2N

for(int i = 1 ; i <= N ; i++) {
for(int j = i + 1 ; j <= N ; j++) {
for(int m = pow + i - 1 ; m <= j + 1 - pow ; m++) {
// if(m - i + 1 < pow ||   j - m + 1 < pow ) continue;
 


if(dp[i][m][pow] + dp[m][j][pow] < dp[i][j][2 * pow - 1]) {
trace[i][j][2 * pow - 1] = new long[] {m, pow , pow};
dp[i][j][2 * pow - 1] = dp[i][m][pow] + dp[m][j][pow];
}
 
}
}
}
}


 
int pow = 0;
int psum = 0;
for(int i = pows.size() - 1 ; i>= 0 ; i--) {
if(pows.get(i) <= posts) {
 psum = pows.get(i) ;    break;
}
}
while(psum < posts) {
for(int i = pows.size() - 1 ; i>= 0 ; i--) {
if(pows.get(i) + psum - 1 <= posts) {
 pow = pows.get(i); break;
}
}
 
for(int i = 1 ; i <= N ; i++) {
for(int j = i + 1 ; j <= N ; j++) {
for(int m = pow + i - 1 ; m <= j + 1 - psum ; m++) {
//if(m - i + 1 < pow ||   j - m + 1 < psum ) continue;
if( dp[i][m][pow] + dp[m][j][psum] < dp[i][j][pow + psum - 1]) {
dp[i][j][pow + psum - 1] = dp[i][m][pow] + dp[m][j][psum];
trace[i][j][pow + psum - 1] = new long[] {m, pow , psum};

}
 
}
}
}
psum = psum + pow - 1;
 

}
 
long ans = Long.MAX_VALUE ;

 
int anss = 0 ; int anse = 0 ; 
for(int i = 1 ; i <= N ; i++) {
for(int j = i + 1 ; j <= N ; j++) {
long res =  getS2(j,i)+ dp[i][j][posts];
if(res < ans) {
ans = res;
anss = i ; anse = j;

}
 
 
}
}

System.out.println(ans);
  makeList(anss , anse , posts);
 
  Set<Long> ansS = new HashSet<Long>();
  ansS.addAll(ansL);
  ansL = new ArrayList<Long>(ansS);
  Collections.sort(ansL);
StringBuilder sb = new StringBuilder();
for(long idx : ansL) sb.append(idx + " ");
System.out.println(sb);





}
   
 
  public static long getS(int n1 , int n2) {
  long mid = (P[n1] + P[n2]  ) /2;
  if(posts == 1) {
  long res = (P[n1] + L / 2);
  mid = res >= L ?  res - L : res;
  }
  int mi = 0 ; int l = 1 ; int r = N;
  while(l <= r) {
  int m = (l + r) / 2;
  if(P[m] <= mid) {
  mi = m;
  l = m + 1;
  }else {
  r = m - 1;
  }
  }
  //System.out.println(  n1 + "  " + n2 + " m " + mi);
  long sum1 =  sum[mi] - sum[n1 - 1] - (mi - n1 + 1) * P[n1] ;
  long sum2 =  (n2 - (mi + 1) + 1) * L - (sum[n2] - sum[mi] )      - (n2 - (mi + 1) + 1) * (L - P[n2]) ;
  return sum1 + sum2;
   
 
 
 
 
 
  }
 
  public static long getS2(int n1 , int n2) {
  long dist = L - P[n1] + P[n2]; long mid = 0;
  if(P[n1] + dist / 2 >= L) { 
  mid = P[n1] + dist / 2 - L;
  }
  else mid = P[n1 ] + dist  / 2;
 
  int mi = 0 ; int l= 1 ; int r = N;
  while(l <= r) {
  int m = (l + r) / 2;
  if(P[m] <= mid) {
  mi = m;
  l = m + 1;
  }else {
  r = m - 1;
  }
  }
  long sum1 = 0 ; long sum2 = 0 ; long sum3 = 0;
  if(n1 <= mi) {
     sum1 = sum[mi] - sum[n1 - 1] - (mi - n1 + 1) * P[n1];
     sum2 =   (N- (mi + 1) + 1) * P[n2] + (L * (N - mi) - (sum[N] - sum[mi]));   // mi + 1 ~ N
     sum3 =  n2 * P[n2] - sum[n2];   // 1 ~ n2
  //   System.out.println(sum1 + " " + sum2 + " " + sum3);
 
  }else {
    sum1 = sum[N] - sum[n1 - 1] - (N - n1 + 1) * P[n1];
    sum2 = mi * (L - P[n1]) + sum[mi];
    sum3 = (n2 - (mi + 1) + 1) * P[n2] - (sum[n2] - sum[mi]);  // mi + 1 ~ n2
 
 
  }
  // System.out.println(n1 + " " + n2 + " mi "  + mi + " mid " + mid +  " res " + (sum1 + sum2 + sum3));
  return sum1 + sum2 + sum3;
   
 
 
 
 
 
  }
    static void makeList(int s , int e , int cnt) {
     if(cnt == 2) {
     ansL.add(P[s]); ansL.add(P[e]);
     return ;
     }
     int mid = (int)trace[s][e][cnt][0];
     int cnt1 = (int)trace[s][e][cnt][1];
      int cnt2= (int)trace[s][e][cnt][2];
      makeList(s, mid , cnt1); makeList(mid, e , cnt2);
      
    }





}