あるエンジニアのAtCoder奮闘記

東京都港区にあるアミフィアブル株式会社のエンジニアが、AtCoderで解いた問題について振り返ったりしていく会社公認のブログです。

AtCoder ABC 123 D - Cake 123 をNumpyを用いてスマートに解きたかった話 (Python)

irisruneです。Numpyを有効に使って計算時間を短くしたかったのですが、なかなか難しいですね。

問題

atcoder.jp

公式解説でも解法が複数説明されていますが、どう解くとしても手法に工夫が必要でかなり難しいです。AtCoder ProblemsのDifficulty684しかないって本当ですか。

また、解説ACのためアルゴリズム面の解説はあまり行いません。

提出1(公式解説:解法2)

import sys
import numpy as np


def main():
    x, y, z, k = [int(n) for n in input().split()]
    a = sorted([int(n) for n in input().split()], reverse=True)
    b = sorted([int(n) for n in input().split()], reverse=True)
    c = sorted([int(n) for n in input().split()], reverse=True)
    va = np.array(a)[:, np.newaxis, np.newaxis]
    vb = np.array(b)[np.newaxis, :, np.newaxis]
    vc = np.array(c)[np.newaxis, np.newaxis, :]
    vx = np.arange(1, x + 1)[:, np.newaxis, np.newaxis]
    vy = np.arange(1, y + 1)[np.newaxis, :, np.newaxis]
    vz = np.arange(1, z + 1)[np.newaxis, np.newaxis, :]
    print(*(sorted((va + vb + vc)[(vx * vy * vz) <= k], reverse=True))[:k], sep="\n")


main()

結論から言えばこれはREでした。

方針としてはすべての組み合わせについてのケーキの美味しさの合計を3次元行列の形で持ち、最終的な出力に含まれる可能性のあるケーキのみをブーリアンマスクで抽出して1次元リストに変換、それをソートして出力するというものです。

まあ、1000\times1000\times1000の3次元行列を作ったらREも起こしますよね。

提出2(公式解説:解法1)

import sys
import numpy as np


def main():
    x, y, z, k = [int(n) for n in input().split()]
    va = np.array(sorted([int(n) for n in input().split()], reverse=True))[:, np.newaxis]
    vb = np.array(sorted([int(n) for n in input().split()], reverse=True))[np.newaxis, :]
    xy = min(x*y, k)
    vac = np.array(sorted((va + vb).reshape(-1), reverse=True)[:xy]).reshape(xy, 1)
    vc = np.array(sorted([int(n) for n in input().split()], reverse=True))[np.newaxis, :]
    xyz = min(x*y*z, k)
    print(*(sorted((vac + vc).reshape(-1), reverse=True)[:xyz]), sep="\n")


main()

REを起こす原因として異常に大きな行列という推測が立てられたので、最大でも3000(\leq K)\times1000の2次元行列で済むように解いてみました。方針としては2つのケーキの美味しさの合計を2次元行列の形で持ち、大きい方からK個を1次元リスト(ベクトル)の形に置き換えた後に残り1つのケーキの美味しさとブロードキャスト計算を行うといったものです。

しかしこちらはTLEになってしまいました。ブロードキャスト演算自体は高速で行われると思われるので、X\times Y(あるいはK\times Z)の2次元行列を1次元ベクトルに変換する処理に時間がかかっているものと推測されますが、実際の所はわかりません。

提出3(公式解説:解法1)

import sys


def main():
    x, y, z, k = [int(n) for n in input().split()]
    va = sorted([int(n) for n in input().split()], reverse=True)
    vb = sorted([int(n) for n in input().split()], reverse=True)
    xy = min(x*y, k)
    vab = (sorted([a + b for a in va for b in vb], reverse=True))[:xy]
    vc = sorted([int(n) for n in input().split()], reverse=True)
    xyz = min(x*y*z, k)
    print(*((sorted([ab + c for ab in vab for c in vc], reverse=True))[:xyz]), sep="\n")


main()

結局、Numpyを用いずに解いたところACできました(1757msかかっていますが)。リスト内法表記を使えばNumpyを用いるメリットも薄い…んでしょうか?まだよくわかっていないですね。

雑記

  • 提出3のコードをPyPy3で提出すると727msしかかかりませんでした。Numpy…というよりPythonって難しいですね。