Hatena::Grouptopcoder

(iwi) { 反省します

TopCoder: [[iwi]] / Twitter: @iwiwi

 | 

2012-12-19

SMT Solver (z3) で探索問題を解く

23:01 | SMT Solver (z3) で探索問題を解く - (iwi) { 反省します を含むブックマーク はてなブックマーク - SMT Solver (z3) で探索問題を解く - (iwi) { 反省します SMT Solver (z3) で探索問題を解く - (iwi) { 反省します のブックマークコメント

ちょっと触る機会があったのでメモ程度に書いておく.Python は普段使わないのでコードはあまりよくないかもしれない.


はじめに

SMT Solver?

SMT Solver とは SAT Solver より色々できるひとたち. SMT って,SAT に比べてあまり名前を聞かないという人も多いかと思うが,SAT Solver の界隈では,実は SAT と SMT は並列にタイトルに入るぐらいには話題沸騰なのである.


z3?

最強の SMT Solver の 1 つっぽい. Microsoft Research で開発されている. Python から簡単に呼べる z3py というのがあって,ウェブからすぐ使ってみることのできるデモまである.すごい.これを見れば大体 SMT Solver でどんなことができるかの雰囲気がよく分かると思う.

最新版を git から clone してビルドしてインストールしたらスグに z3 が python から使えた.ビルドの際,dos2unix がないと駄々をこねていたので,ln -s todos unix2dos してやった.



書いてみたコード

Captain Q's Treasure (ICPC福岡2011 G)

公式データセットに対して 4 秒で走る.とはいえそもそも IP だから SMT Solver で解く意義は薄い.とりあえず定式化が簡単そうなのでやってみた.

from z3 import *

while 1:
    #
    # Input
    #
    h, w = map(int, raw_input().split())
    if h == 0 and w == 0:
        break
    field = [raw_input() for i in range(h)]

    #
    # Constraints
    #
    var_x = [ [ Int("x_%s_%s" % (i + 1, j + 1)) for j in range(w) ] for i in range(h) ]

    solver = Solver()
    solver.add([var_x[y][x] == 0 if field[y][x] == '.' else And(0 <= var_x[y][x], var_x[y][x] <= 1)
                for x in range(w) for y in range(h) ])

    for y in range(h):
        for x in range(w):
            if '0' <= field[y][x] and field[y][x] <= '9':
                xs = []
                for dy in [-1, 0, 1]:
                    for dx in [-1, 0, 1]:
                        cx = x + dx
                        cy = y + dy
                        if cx >= 0 and cx < w and cy >= 0 and cy < h:
                            xs += [var_x[cy][cx]]
                solver.add(Sum(xs) == int(field[y][x]))

    xs = [var_x[y][x] for y in range(h) for x in range(w)]

    #
    # Solve
    #
    low = -1
    high = h * w
    while high - low > 1:
        mid = (low + high) / 2
        solver.push()
        solver.add(Sum(xs) <= mid)
        if solver.check() == sat:
            high = mid
        else:
            solver.pop()
            low = mid
    print high

15 パズル

次に挑戦するものとしてはタフすぎた気配がするw 定式化大変だった.

で,性能だが,3x3 の 8 パズルだとわりと何とかなるぽいが,4x4 の 15 パズルだと 10 手以下のものは解けるけど増えるにしたがってやばくなっていく.

しょんぼりやな.

from z3 import *

god = 20  # limit...

def solve(field, goal):
    #goal = ((1,2,3),(4,5,6),(7,8,0))
    h = len(field)
    w = len(field[0])

    print [h, w]

    x0, y0 = -1, -1
    for x in range(w):
        for y in range(h):
            if field[y][x] == 0:
                x0 = x
                y0 = y

    print [x0, y0]

    var_f = [ [ [ Int("f_%s_%s_%s" % (g, y + 1, x + 1)) for x in range(w) ] for y in range(h) ] for g in range(god + 1) ]
    var_d = [ Int("d_%s" % (g)) for g in range(god) ]
    var_x = [ Int("x_%s" % (g)) for g in range(god + 1) ]
    var_y = [ Int("y_%s" % (g)) for g in range(god + 1) ]

    #
    # Placement
    #
    constraint_init = []

    for g in range(god):
        constraint_init += [var_f[0][y][x] == field[y][x] for x in range(w) for y in range(h) ]
        constraint_init += [var_x[0] == x0]
        constraint_init += [var_y[0] == y0]
        constraint_init += [var_f[god - 1][y][x] == goal[y][x] for x in range(w) for y in range(h) ]

    #
    # Movement
    #
    constraint_move = []
    dx = [1, 0, -1, 0, 0]
    dy = [0, 1, 0, -1, 0]

    for g in range(god):
        cx = var_x[g]
        cy = var_y[g]
        d = var_d[g]
        nx = var_x[g + 1]
        ny = var_y[g + 1]
        constraint_move += [Or([And(d == i, nx == cx + dx[i], ny == cy + dy[i]) for i in range(5)])]
        constraint_move += [0 <= nx, nx < w, 0 <= ny, ny < h]

    #
    # Board
    #
    constraint_board = []

    def gen_move(cf, nf, x, y, d, r, i = 0):
        if i == 4:
            return nf[y][x] == cf[y][x]
        else:
            j = (4 if i == 4 else (i + 2) % 4) if r else i
            tx = min(max(x + dx[j], 0), w - 1)
            ty = min(max(y + dy[j], 0), h - 1)
            return If(d == i, nf[ty][tx] == cf[y][x], gen_move(cf, nf, x, y, d, r, i + 1))

    for g in range(god):
        cx = var_x[g]
        cy = var_y[g]
        cf = var_f[g]
        nx = var_x[g + 1]
        ny = var_y[g + 1]
        nf = var_f[g + 1]
        d = var_d[g]
        for x in range(w):
            for y in range(h):
                constraint_board += [
                    If(And(x == cx, y == cy), gen_move(cf, nf, x, y, d, False),
                       If(And(x == nx, y == ny), gen_move(cf, nf, x, y, d, True),
                          nf[y][x] == cf[y][x]))]

    #
    # Solve!
    #
    solver = Solver()
    solver.add(constraint_init + constraint_move + constraint_board)
    for target in reversed(xrange(god)):
        solver.add(var_d[target] == 4)
        print [target, solver.check()]

solve(((1,2,3,4),(9,5,6,7),(13,10,11,8),(0,14,15,12)),
      ((1,2,3,4),(5,6,7,8),(9,10,11,12),(13,14,15,0)))
 |