SML#でJITコンパイラを作る軽い話

このエントリはML Advent Calendar 2015及び言語実装 Advent Calendar 201515日目の記事です。

κeenです。先日流れてきたRustでJITコンパイラを作る話が面白かったのでSML#でもやってみます。

JITコンパイラってなに?

JITはJust In Timeで、日本語にすると「間に合って」になります。誤読されかねませんが「臨時コンパイラ」と訳すのが適当なのでしょうか。

普通のインタプリタがソースコードを直接評価するのに対してJITコンパイラはソースコードを内部で一旦ネイティブコードにコンパイルして実行します。

よく、JITコンパイラ/JITコンパイルのことをJITと略して呼びます。

どこで使うの?

主に、インタプリタの高速化の目的で使われるようです。他にはコンパイラをインタプリタ化させる(REPLの実装とか)でも使えそうな気がしますが、詳しくは知りません。

JITってどうやって作るの?

JIT でない コンパイラが吐いたコードは一旦ディスクからメモリに読まれ、メモリ上で実行されます。

JITコンパイラが吐くコードは直接メモリに吐かれ、メモリ上で実行されます。なので細かい条件を無視すれば

  1. 吐かれる命令のための(特殊な)メモリ領域を用意する
  2. そこに実行したい命令列を書き込む
  3. 書き込んだ命令列を実行する

だけです。 3.の「書き込んだ命令を実行する」が若干ふわっとしてますが、多くの場合は命令の書かれたメモリを関数ポインタとして扱って、その関数を呼び出す形で実行されます。 一応直接プログラムカウンタをいじって書き込んだ命令列を実行させることも出来ます。

一緒に作ってみよう

思ったより簡単そうですね。作ってみましょう。

1. メモリ領域を準備する

先に「特殊な」と付言しましたのでそれについて説明しましょう。プロセス上のメモリは、OSによって保護されています。 特に、そのままではそのメモリ上の命令列を実行出来ません。その保護を外すためにmprotect(2)システムコールを使います。

int mprotect(void *addr, size_t len, int prot);

SML#を使えば簡単に使えるようになります。

val mprotect = _import "mprotect": (unit ptr, word, word) -> int

SML#は2.0.0時点では32bitでしか動かないのでsize_tをwordとしてしまって大丈夫です。また、bit演算をやりたいのでprotは本来はintですがwordにしました。

mprotect(2)はPOSIX(SUSv4)的には引数のaddrがシステムのページの境界上にあることを要求しても良いことになっています。 実際、OS Xが要求するらしいです。また、メモリ保護はページ単位で行われるため、lenもページサイズの倍数である必要があります。 protは保護フラグですね。

メモリの確保、特にページの境界にアラインメントされたメモリ領域を確保するにはSML#では厳しそうなのでこれもCの関数に頼ります。 Linuxならいくつか方法はありますが、ポータブルにはposix_memalign(3)を使うのが良いようです。

int posix_memalign(void **memptr, size_t alignment, size_t size);

これもSML#から簡単に使えます。

val posix_memalign = _import "posix_memalign": (unit ptr ref, word, word) -> int

ついでにメモリの解放も必要ですね。free(3)もインポートしましょう。

val free = _import "free": unit ptr -> ()

さて、メモリを確保するのにいくつかCのマクロで定義された定数を定義しましょう。今のところSML#にCのマクロを使う術はなさそうです。どうにかしてCプリプロセッサを使えれば良いのですが…。

val PROT_READ  = 0wx1
val PROT_WRITE = 0wx2
val PROT_EXEC  = 0wx4
val PROT_NONE  = 0wx0

val PAGE_SIZE = 0w4096

ページサイズは本当はsysconf(3)を使ってsysconf(PAGESIZE)として取るのが作法ですが面倒なので4KiB決め打ちにしてしまいました。

さて、これらを用いてJIT用のメモリは次のように確保します。Pointerモジュールが必要なのでsmiファイルでrequire "ffi.smi"しましょう。

val op orb = Word.orb
infix 5 orb
val size = 0wx1
val msize = size * PAGE_SIZE
val pageRef: unit ptr ref = ref (Pointer.NULL ())
val ret = posix_memalign (pageRef, PAGE_SIZE, msize)
val page = if ret = 0
           then !pageRef
           else raise Fail "memory allocation failed"
val PROT_RWEX = PROT_READ orb PROT_WRITE orb PROT_EXEC
val ret = mprotect (page, msize, PROT_RWEX)
val () = if ret = 0
         then ()
         else raise Fail "memory protection error"

これでpageに実行可能フラグが立ったメモリが確保出来ました。sizeは書き込みたい命令列に応じて変えます。

2. 命令列を書き込む

さて、ここらへんからアセンブラの世界になってきます。

まず、安全のために件のメモリ領域をret命令で埋めます。 ret命令は関数から戻る時に使われる命令で、変なことをしてもretで元の場所に戻ればどうにかなる可能性があるのです。

メモリを一気に埋めるにはmemset(3)が便利です。

void *memset(void *s, int c, size_t n);

これをSMLにインポートして使いましょう。ret命令に対応するマシン語は0xc3です。

val memset = _import "memset": (unit ptr, word, word) -> unit ptr
val _ = memset (page, 0wxc3, msize)

じゃあ、命令を書き込んでいきましょう。最初はただ1を返すだけの関数に相当する命令です。 64bitマシンを使っていても32bitでコンパイルしたSML#を使っている限りはx86(IA32)のアセンブラを使います。 x86は値をeaxレジスタに入れることで返り値とします。

mov 1 eax
ret

このコードをオンラインアセンブラでアセンブルしてみると

0:  b8 01 00 00 00          mov    eax,0x1
5:  c3                      ret

との結果が返ってきました。これを真心込めて1つづつメモリに書き込んでいきます。 SML#からメモリに直接書き込むには内部APIっぽいものを触る必要があります。 タイプ数を減らすためにいくつかエイリアスを定義しておきましょう。

val fromUnitPtr = SMLSharp_Builtin.Pointer.fromUnitPtr
val store = Pointer.store
val advance = Pointer.advance

メモリにword8の値を書き込むのにキャストが必要になるのでfromUnitPtrを使います。 さて、書き込んでいきます。

val mem: word8 ptr = fromUnitPtr page
val () = store page 0wxb8
val mem = advance mem
val () = store page 0wx01
val mem = advance mem
val () = store page 0wx00
val mem = advance mem
val () = store mem 0wx00
val mem = advance mem
val () = store mem 0wx00
val mem = advance mem
val () = store mem 0wxc3
val _ = advance mem

これでpageに1を返すだけの命令列が書き込まれした。

3. 書き込んだ命令列を実行する

書き込んだ命令列は関数にキャストして使うことが多いと言いました。今回もそのようにします。

そのために、また内部APIっぽいものを使います。

val toCodeptr = SMLSharp_Builtin.Pointer.toCodeptr

codeptr型は関数ポインタを表すもののようです。DynamicLinkモジュールなどで使われています。

さて、今回の1を返すだけの命令列は() -> int型になりますね。 それをインポート、実行しましょう。

val freturn1 = toCodePtr page: _import () -> int
val () = print ((Int.toString (freturn1 ())) ^ "\n")
val _ = free page

これで見事1が印字されたら成功です。

さらに

先程のコードは繰り返しが多く、整理されていませんでした。もう少し整理しましょう。

structure Emit = struct
    val posix_memalign = _import "posix_memalign": (unit ptr ref, word, word) -> int
    val mprotect = _import "mprotect": (unit ptr, word, word) -> int
    val memset = _import "memset": (unit ptr, word, word) -> unit ptr
    val free = _import "free": unit ptr -> ()
    val printf = _import "printf": (string, unit ptr) -> ()

    (* 
#define PROT_READ	0x1		/* Page can be read.  */
#define PROT_WRITE	0x2		/* Page can be written.  */
#define PROT_EXEC	0x4		/* Page can be executed.  */
#define PROT_NONE	0x0		/* Page can not be accessed.  */
    *)

    val PROT_READ  = 0wx1
    val PROT_WRITE = 0wx2
    val PROT_EXEC  = 0wx4
    val PROT_NONE  = 0wx0

    val PAGE_SIZE = 0w4096

    type jitptr = word8 ptr
    val fromUnitPtr = SMLSharp_Builtin.Pointer.fromUnitPtr
    val toUnitPtr = SMLSharp_Builtin.Pointer.toUnitPtr
    val toCodeptr = SMLSharp_Builtin.Pointer.toCodeptr
    val store = Pointer.store
    val advance = Pointer.advance


    fun jitMemory size: jitptr = let
        val op orb = Word.orb
        infix 5 orb
        val msize = size * PAGE_SIZE
        val pageRef: unit ptr ref = ref (Pointer.NULL ())
        val ret = posix_memalign (pageRef, PAGE_SIZE, msize)
        val page = if ret = 0
                   then !pageRef
                   else raise Fail "memory allocation failed"
        val PROT_RWEX = PROT_READ orb PROT_WRITE orb PROT_EXEC
        val ret = mprotect (page, msize, PROT_RWEX)
        val () = if ret = 0
                 then ()
                 else raise Fail "memory protection error"
        (* init with ret for safety *)
        val _ = memset (page, 0wxc3, msize)
    in
        fromUnitPtr page
    end

    fun freeJit (jitMem: jitptr) = free (SMLSharp_Builtin.Pointer.toUnitPtr jitMem)

    fun pushWord page (word: word8) = (store (page, word); advance (page, 1))
    fun pushWords (page: jitptr) l = List.foldl (fn(w,page) => pushWord page w) page l

    val import: jitptr -> codeptr = toCodeptr o toUnitPtr

    fun fromMachineCode l = let
        val len = Word.fromInt(List.length l)
        val size = (len + PAGE_SIZE) div PAGE_SIZE
        val page = jitMemory size
        val _ = pushWords page l
    in
        import page
    end
end

次のように使います。一回importした関数はfreeしないことにしましょう。

fun println x = print (x ^ "\n")
val return1  =
    (* 0:  b8 01 00 00 00          mov    eax,0x1  *)
    [
      0wxb8, 0wx01, 0wx00, 0wx00, 0wx00
    ] 
val freturn1 = Emit.fromMachineCode return1 :_import () -> int
val () = println (Int.toString (freturn1 ()))

引数を取る

もうアセンブラの話になります。

x86では引数は右から順にスタックに積まれます。intを1つ取ってそれに1足して返す関数はこうなります。

val add1 = 
    (* 0:  8b 44 24 04             mov    eax,DWORD PTR [esp+0x4] *)
    (* 4:  83 c0 01                add    eax,0x1 *)
    [
      0wx8b, 0wx44, 0wx24, 0wx04,
      0wx83, 0wxc0, 0wx01
    ]
val fadd1 = Emit.fromMachineCode add1 :_import (int) -> int
val () = println (Int.toString (fadd1 3))

正常に動けば4が印字されます。

つらい話

バグ

最初、Pointer.storeでなく、SMLSharp_Builtin.Pointer.storeを使っていたら変なバグ踏みました。バグというか使い方が悪かった。

デバッグ

デバッガがないので非常につらいです。書き出された命令列を見るのに困りました。gdbで見たかったので、

val printf = _import "printf": (string, unit ptr) -> int
val _ = pritnf ("page pointr: %p", page)

してprintfにブレークポイントを張り、そこで止めつつページのアドレスを取得、

(gdb) x/20xh 0x81ca000

などとして見ていました。

もっと

JITコンパイラが楽しかったのでアセンブラのDSLを作ってみました。1、2個の命令吐けるだけですがちゃんと動きました。今のところこのように書けます。Intel記法だとしんどそうだったのでAT&T記法っぽく書けるようにしました。

val freturn1' = Emit.fromInsts [
        xorl eax eax,
        addl ($1) eax,
        ret
    ]:_import () -> int
val () = println (Int.toString (freturn1' ()))

x86は命令フォーマットが1バイトから15バイトまでの可変長で、内部表現をどのようにするか決めるだけでも一苦労でした。x86つらい。

アドレッシングが複雑なのも悩みどころで、複数のアドレッシングを統一的に扱えるようにオーバーロードされた関数を用意したのですがあえなくSML#のバグを踏んで死亡しました。本当はこういう記法が出来る筈だった…

addl (%eax) eax
addl eax (%eax)
addl eax (%(eax, ebx))
addl eax (%(eax, ebx, 4))
addl eax (%(4, eax))
$ /usr/local/bin/smlsharp   -c -o main.o main.sml
uncaught exception: Bug.Bug: InferType: FIXME: user error: invalid instTy at src/compiler/typeinference2/main/InferTypes2.sml:47
Makefile:11: recipe for target 'main.o' failed
make: *** [main.o] Error 1

movlを実装しようとしたらアドレッシングが動かなかったので萎えてまだ実装してません。

よく考えたら%ってミスリーディングだし名前変えよう。

まとめ

  • JITを作るにはメモリ保護をいじれて関数ポインタのインポートが出来ればいいよ
  • SML#でもJIT作れるよ
  • アセンブラっぽいの作ったよ

今回のコードはこちらにあります。

Written by κeen