AtCoder ABC 123 D - Cake 123 をNumpyを用いてスマートに解きたかった話 (Python)
irisruneです。Numpyを有効に使って計算時間を短くしたかったのですが、なかなか難しいですね。
問題
公式解説でも解法が複数説明されていますが、どう解くとしても手法に工夫が必要でかなり難しいです。AtCoder ProblemsのDifficulty684しかないって本当ですか。
また、解説ACのためアルゴリズム面の解説はあまり行いません。
提出1(公式解説:解法2)
import sys import numpy as np def main(): x, y, z, k = [int(n) for n in input().split()] a = sorted([int(n) for n in input().split()], reverse=True) b = sorted([int(n) for n in input().split()], reverse=True) c = sorted([int(n) for n in input().split()], reverse=True) va = np.array(a)[:, np.newaxis, np.newaxis] vb = np.array(b)[np.newaxis, :, np.newaxis] vc = np.array(c)[np.newaxis, np.newaxis, :] vx = np.arange(1, x + 1)[:, np.newaxis, np.newaxis] vy = np.arange(1, y + 1)[np.newaxis, :, np.newaxis] vz = np.arange(1, z + 1)[np.newaxis, np.newaxis, :] print(*(sorted((va + vb + vc)[(vx * vy * vz) <= k], reverse=True))[:k], sep="\n") main()
結論から言えばこれはREでした。
方針としてはすべての組み合わせについてのケーキの美味しさの合計を3次元行列の形で持ち、最終的な出力に含まれる可能性のあるケーキのみをブーリアンマスクで抽出して1次元リストに変換、それをソートして出力するというものです。
まあ、の3次元行列を作ったらREも起こしますよね。
提出2(公式解説:解法1)
import sys import numpy as np def main(): x, y, z, k = [int(n) for n in input().split()] va = np.array(sorted([int(n) for n in input().split()], reverse=True))[:, np.newaxis] vb = np.array(sorted([int(n) for n in input().split()], reverse=True))[np.newaxis, :] xy = min(x*y, k) vac = np.array(sorted((va + vb).reshape(-1), reverse=True)[:xy]).reshape(xy, 1) vc = np.array(sorted([int(n) for n in input().split()], reverse=True))[np.newaxis, :] xyz = min(x*y*z, k) print(*(sorted((vac + vc).reshape(-1), reverse=True)[:xyz]), sep="\n") main()
REを起こす原因として異常に大きな行列という推測が立てられたので、最大でもの2次元行列で済むように解いてみました。方針としては2つのケーキの美味しさの合計を2次元行列の形で持ち、大きい方から個を1次元リスト(ベクトル)の形に置き換えた後に残り1つのケーキの美味しさとブロードキャスト計算を行うといったものです。
しかしこちらはTLEになってしまいました。ブロードキャスト演算自体は高速で行われると思われるので、(あるいは)の2次元行列を1次元ベクトルに変換する処理に時間がかかっているものと推測されますが、実際の所はわかりません。
提出3(公式解説:解法1)
import sys def main(): x, y, z, k = [int(n) for n in input().split()] va = sorted([int(n) for n in input().split()], reverse=True) vb = sorted([int(n) for n in input().split()], reverse=True) xy = min(x*y, k) vab = (sorted([a + b for a in va for b in vb], reverse=True))[:xy] vc = sorted([int(n) for n in input().split()], reverse=True) xyz = min(x*y*z, k) print(*((sorted([ab + c for ab in vab for c in vc], reverse=True))[:xyz]), sep="\n") main()
結局、Numpyを用いずに解いたところACできました(1757msかかっていますが)。リスト内法表記を使えばNumpyを用いるメリットも薄い…んでしょうか?まだよくわかっていないですね。
雑記
- 提出3のコードをPyPy3で提出すると727msしかかかりませんでした。Numpy…というよりPythonって難しいですね。