PyPyの仕様でハマったこと

いつものように苦役と贖罪やってたら、型推論で初見殺しにあったので記録として残しておく

 

題材とする問題

 

DPまとめコンテストのJ - Sushi

 

atcoder.jp

 

次のコードでTLE。正直これ以上何変えるねんと思ったので、潔く何がいけないかネットの海に旅に出た

 

はてなは(を2つ重ねると勝手に注釈扱いにされるので、ちょっと空白入れて不自然さあるけどスルーしてほしい

 

import sys

sys.setrecursionlimit(100000000)

from collections import Counter

N = int(input())
A = list(map(int , input().split()))
B = Counter(A)
n1 = B[1]
n2 = B[2]
n3 = B[3]
inf = -1
dp =[[[inf for _ in range(n3+10)] for _ in range(n2+n3+10)]for _ in range(n1+n2+n3+10)]
dp[0][0][0] = 0


def f(x,y,z):
    if dp[x][y][z] != inf:
        return dp[x][y][z]
    c1 = ( (x/(x+y+z))*f(x-1,y,z) if x >= 1 else 0)
    c2 = ( (y/(x+y+z))*f(x+1,y-1,z) if y >= 1 else 0)
    c3 = ( (z/(x+y+z))*f(x,y+1,z-1) if z >= 1 else 0)
    c4 = (N/(x+y+z))
    dp[x][y][z] = c1+c2+c3+c4
    return dp[x][y][z]
   
   
print(f(n1,n2,n3))

 

2文字追加するとAC

import sys

sys.setrecursionlimit(100000000)

from collections import Counter

N = int(input())
A = list(map(int , input().split()))
B = Counter(A)
n1 = B[1]
n2 = B[2]
n3 = B[3]
inf = -1.0
dp =[[[inf for _ in range(n3+10)] for _ in range(n2+n3+10)]for _ in range(n1+n2+n3+10)]
dp[0][0][0] = 0


def f(x,y,z):
    if dp[x][y][z] != inf:
        return dp[x][y][z]
    c1 = ( (x/(x+y+z))*f(x-1,y,z) if x >= 1 else 0)
    c2 = ( (y/(x+y+z))*f(x+1,y-1,z) if y >= 1 else 0)
    c3 = ( (z/(x+y+z))*f(x,y+1,z-1) if z >= 1 else 0)
    c4 = (N/(x+y+z))
    dp[x][y][z] = c1+c2+c3+c4
    return dp[x][y][z]
   
   
print(f(n1,n2,n3))

 

一番早いからPyPyで提出しているのだが、こいつが速いのがPythonの変数の型をうまいこと推論してくれてるおかげなので、型をごちゃごちゃ混ぜるとPyPyの良さが失われる

 

なので整数型の問題ばかりやってると忘れがちだが、小数が混じる時はdpの初期化をきちんとfloatで行わないと、うそ。。。わたしのPyPy遅すぎ。。。となってしまう。

 

初見殺しすぎて完全にやられたので記録に残しておく。