Pythonを改造して多重ループを一度に抜けるbreakを実装してみた

 東京大学電子情報工学科・電気電子工学科(EEIC)では、後期選択実験として、「大規模ソフトウェアを手探る」があります。私たち第7班「だいがくいも」は、Pythonをテーマにこの実験に取り組みました。本ブログはその最終レポートとして書かれたものです。

Pythonとは

 Pythonプログラミング言語の一つです。2020年の調査Pythonプログラミング言語の人気2位になっており、特に機械学習ディープラーニングの分野で注目を集めています。また、文法が簡素で分かりやすいので、初学者にも優しいプログラミング言語でもあります。実は筆者が大学に入って初めて学習したのもPythonでした。


目標

 Pythonで多重ループを一度に抜けたいなと思ったとき、もし実装するなら例えば以下のようになります。

for i in range(10):
    for j in range(10):
        if i == 5 and j == 5:
            break
    else:
        continue
    break

もちろんこれでもいいんですが、break1回で抜けられた方が気持ちいい!!ということで、以下のような感じで多重ループを一度で抜けることのできるように、Pythonを改造してみました。

for i in range(10):
    for j in range(10):
        if i == 5 and j == 5:
            break 2 #ループを2つ抜ける


環境

  • OS:Ubuntu20.04
  • CPythonのバージョン:Python3.10.0
  • Cpythonの変更に用いたPythonのバージョン:Python3.8.4
  • エディタ:Emacs, VScode
  • デバッガ:gdb

後述しますが、CPythonはバージョンによって構文解析のあたりがかなり異なっているので、CPythonのバージョンを選ぶ際には注意が必要です。また、CPythonに変更に加えるときにPythonを使うので、もともとPCにインストールされているPythonのバージョンが古いとエラーが出たりします。少なくともPython3.6.9ではうまくいきませんでした。


ビルド

 Pythonの公式サイトからPython 3.10.0をダウンロードして、以下のようにコマンドを打ちました。

$ tar xvf Python-3.10.0.tar.gz
$ cd Python-3.10.0/
$ CFLAGS="-O0 -g" ./configure --prefix=インストールするディレクトリのパス
$ make
$ make install

これでインストールは無事完了です。実行する際は、以下のようにディレクトリを移動してから実行します。

$ cd インストールしたディレクトリ
$ ./python3 実行するファイル


breakのバイトコードを探る

 disモジュールを用いて、バイトコードを逆アセンブルして人が見て分かる形で表示させることができます。以下のようにして、breakバイトコードを調べます。

#test.py
for i in range(5):
    for j in range(5):
        if j == 1:
            break
    print(i)
print("fin")
$ ./python3 -m dis test.py
  1           0 LOAD_NAME                0 (range)
              2 LOAD_CONST               0 (5)
              4 CALL_FUNCTION            1
              6 GET_ITER
        >>    8 FOR_ITER                19 (to 48)
             10 STORE_NAME               1 (i)

  2          12 LOAD_NAME                0 (range)
             14 LOAD_CONST               0 (5)
             16 CALL_FUNCTION            1
             18 GET_ITER
        >>   20 FOR_ITER                 8 (to 38)
             22 STORE_NAME               2 (j)

  3          24 LOAD_NAME                2 (j)
             26 LOAD_CONST               1 (0)
             28 COMPARE_OP               2 (==)
             30 POP_JUMP_IF_FALSE       18 (to 36)

  4          32 POP_TOP
             34 JUMP_ABSOLUTE           19 (to 38)

  3     >>   36 JUMP_ABSOLUTE           10 (to 20)

  5     >>   38 LOAD_NAME                3 (print)
             40 LOAD_NAME                1 (i)
             42 CALL_FUNCTION            1
             44 POP_TOP
             46 JUMP_ABSOLUTE            4 (to 8)

  6     >>   48 LOAD_NAME                3 (print)
             50 LOAD_CONST               2 ('fin')
             52 CALL_FUNCTION            1
             54 POP_TOP
             56 LOAD_CONST               3 (None)
             58 RETURN_VALUE

一番左の番号が行を表しているので、4行目に対応するPOP_TOPJUMP_ABSOLUTEbreakから生成されたバイトコードです。

POP_TOP

スタックから1つ取り出します。この場合は、forから抜ける後処理として、ループを回すためのイテレータを取り出しています。(ちなみにこの処理をしないと、このコードの場合は無限ループになるはずです。スタックに無限にイテレータがたまっていくので、最終的には異常終了しそうです。)ここで行われる処理はループの種類によって変わります。

JUMP_ABSOLUTE

breakの本体である、ジャンプ命令のバイトコードです。ジャンプのターゲット((to 38)のところ)を変更することで、多重ループを抜けるbreakを実装できそうです。


CPythonのコンパイラの設計

 以下に、ソースコードからバイトコードが生成されるまでの過程を簡単にまとめました。Python開発者ガイドにあったドキュメントを参考にしているので、より詳細に知りたい方はそちらを確認してください。また、Python3.8からPython3.9になる際に、パーサー(構文解析)の仕組みが大きく変わっており、以下の説明はPython3.8以前のバージョンには必ずしも適用できないので注意してください。

f:id:doss2021_7:20211104011255p:plain
バイトコードが生成されるまでの流れ

ソースコードからASTが生成されるまで

 まず、ソースコードParser/tokenizer.cによってトークンに分解されます。このときトークンの定義はGrammar/Tokensにあり、新しいトークンを定義したい場合はここを変更する必要がありますが、今回の場合は特に書き換える必要はありません。
 次に、トークン列はParser/parser.cによって抽象構文木(AST)に変換されます。ASTとは、ソースコードバイトコードの中間状態であり、プログラムの構造をASTノード(ステートメントや式など)からなる木構造で表したものです。またこのとき、Python/Python-ast.cにより、各ASTノードに対応したC言語の構造体が生成されます。
 今回は、break 2のような構文を正しく認識させるためにParser/parser.cPython/Python-ast.cに変更を加える必要があります。ただし、直接書き換えるのではなく、以下のような手順で変更を加えます。

  1. Grammar/python.gramParser/Python.asdlを手動で書き換える。
  2. この2つのファイルから、それぞれParser/parser.cPython/Python-ast.cを再生成する。

再生成の手順については実装のときに後述します。Grammar/python.gramとは、トークンとASTの対応付け(いわゆる文法) が記述されたファイルです。対して、 Parser/Python.asdlにはASTノードの定義が記述されています。

ASTのデバッグ・最適化

 さらにASTはPython/ast.cによってデバッグが行われます。このファイルはPython3.8まではASTを生成する際に使われていたようですが、パーサーの仕組みが変わってからはただデバッグするファイルになっています。またPython/ast_opt.cによりASTに対して最適化が行われます。この2つのファイルにも変更を加えます。

ASTからバイトコードが生成されるまで

 ASTからは、まず制御フローグラフ(CFG)が生成されます。CFGとは、ブロックという単位で分けられており、そのブロックがノードとなっている有効グラフです。さらにこのCFGからバイトコードが生成されます。これらの一連の流れは、主にPython/compile.cの中で行われます。このとき、変数などを管理するための名前空間Python/symtable.cによって生成されており、これがPython/compile.cで利用されます。これらの2つのファイルにも変更を加えます。

手動で変更が必要なファイル

 以上をまとめると、今回のbreakの改造にあたって、手動で書き換える必要があるファイルは以下の通りになります。

  • Grammar/python.gram

  • Parser/Python.asdl

  • Python/ast.c

  • Python/ast_opt.c

  • Python/symtable.c

  • Python/compile.c


実装

 チェックリストを参考にして実装していきます。

Grammar/python.gram

 文法を規定するファイルです。breakが引数を取れるように変更します。"break" で検索すると、以下のブロックが見つかります。

Grammar/python.gram:L67

simple_stmt[stmt_ty] (memo):
    | assignment
    | e=star_expressions { _PyAST_Expr(e, EXTRA) }
    | &'return' return_stmt
    | &('import' | 'from') import_stmt
    | &'raise' raise_stmt
    | 'pass' { _PyAST_Pass(EXTRA) }
    | &'del' del_stmt
    | &'yield' yield_stmt
    | &'assert' assert_stmt
    | 'break' { _PyAST_Break(EXTRA) }
    | 'continue' { _PyAST_Continue(EXTRA) }
    | &'global' global_stmt
    | &'nonlocal' nonlocal_stmt

独特の記法で書かれているためわかりづらいですが、引数を取るもの(return, yieldなど)と取らないもの(break, continueなど)で書かれ方が違うみたいです。周りにならってbreakの行を以下のように変更します。

    | &'break' breaknew_stmt 

breaknew_stmtは新しく定めたものであり、その中身は以下のようにしました。

Grammar/python.gram:L394

breaknew_stmt[stmt_ty]:
    | 'break' a=NUMBER { _PyAST_Breaknew(a, EXTRA) }
    | 'break' { _PyAST_Break(EXTRA) }

ここで、上は引数a(NUMBER型)を受けた場合の動作、下は引数を受けなかった場合の動作を表しています。

変更を終えたら、以下のコマンドでParser/parser.cを再生成します。

$ make regen-pegen

Parser/Python.asdl

 ASTノードの定義を記述するファイルです。Breaknewという新たなASTノードを作成してしまったので、その定義を追加します。

Parser/Python.asdl:L52

    | Breaknew(expr value)

変更を終えたら、以下のコマンドでPython/Python-ast.cを再生成します。

$ make regen-ast

Python/ast.c

 ASTのデバッグ用のファイルです。"break" で検索すると、以下の関数が見つかります。

Python/ast.c:L674

static int
validate_stmt(struct validator *state, stmt_ty stmt)
{
    int ret = -1;
    Py_ssize_t i;
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        return 0;
    }
    switch (stmt->kind) {
    // 中略
    case Expr_kind:
        ret = validate_expr(state, stmt->v.Expr.value, Load);
        break;
    // 中略
    case Break_kind:
    // 中略
   }
    if (ret < 0) {
        PyErr_SetString(PyExc_SystemError, "unexpected statement");
        ret = 0;
    }
    state->recursion_depth--;
    return ret;
}

validate_stmtという関数名から、文(statement)の記述が妥当であるかを確認している関数であると推測できます。周りの記述を参考にして、switch文のcaseに以下を追加します。

    case Breaknew_kind:
        ret = validate_expr(state, stmt->v.Breaknew.value, Load);
        break;

Python/ast_opt.c

 ASTの最適化用のファイルです。"break" で検索すると、以下の関数が見つかります。

Python/ast_opt.c:L652

static int
astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        return 0;
    }
    switch (node_->kind) {
    // 中略
    case Expr_kind:
        CALL(astfold_expr, expr_ty, node_->v.Expr.value);
        break;
    // 中略
    case Break_kind:
    case Continue_kind:
        break;
    // No default case, so the compiler will emit a warning if new statement
    // kinds are added without being handled here
    }
    state->recursion_depth--;
    return 1;
}

astfold_stmtという関数名から、文(statement)の「畳み込み」なるものを行う関数であると推測できます。

しかし、その実態はよくわからないので、再び周りの記述を真似して、switch文のcaseに以下を追加する。

    case Breaknew_kind:
        CALL(astfold_expr, expr_ty, node_->v.Breaknew.value);
        break;

Python/symtable.c

 ASTからバイトコードに変換するときに使用するファイルです。"break" で検索すると、以下の関数が見つかります。

Python/symtable.c:L1192

static int
symtable_visit_stmt(struct symtable *st, stmt_ty s)
{
    if (++st->recursion_depth > st->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        VISIT_QUIT(st, 0);
    }
    switch (s->kind) {
    // 中略
    case Expr_kind:
        VISIT(st, expr, s->v.Expr.value);
        break;
    // 中略
    case Break_kind:
    // 中略
    }
    VISIT_QUIT(st, 1);
}

今まで通り、周りにならってBreaknewの動作を追加します。

    case Breaknew_kind:
        VISIT(st, expr, s->v.Breaknew.value);
        break;

しかし、これでは make するときにエラーが発生しました。

$ make
...
Segmentation fault (core dumped)
generate-posix-vars failed
make: *** [Makefile:616: pybuilddir.txt] エラー 1

ここでの変更は最小限にすることで、エラーを解消しました。

    case Breaknew_kind:

Python/compile.cを探る

 compile.cは実際にソースコードを生成するファイルなので、今回の実装の要になります。まずはcompile.cの中でbreakがどのように処理されているのかを調べてみます。検索機能を用いてbreakの処理に関係している関数を探すと、compiler_breakが見つかります。

compiler_breakを探る

compile.c:L3024

static int
compiler_break(struct compiler *c)
{
    struct fblockinfo *loop = NULL;
    /* Emit instruction with line number */
    ADDOP(c, NOP);
    if (!compiler_unwind_fblock_stack(c, 0, &loop)) {
        return 0;
    }
    if (loop == NULL) {
        return compiler_error(c, "'break' outside loop");
    }
    if (!compiler_unwind_fblock(c, loop, 0)) {
        return 0;
    }
    ADDOP_JUMP(c, JUMP_ABSOLUTE, loop->fb_exit);
    NEXT_BLOCK(c);
    return 1;
}

以下がcompiler_breakの中で行われている処理になります。

  1. compiler_unwind_fblock_stackloopに一番内側のループの情報をコピーしてきます。
  2. compiler_error:ループに入っていなかった場合のエラー処理をします。
  3. compiler_unwind_fblock:ループから抜けるときに必要な後処理をします。
  4. ADDOP_JUMPloopの出口にジャンプするバイトコードを生成します。
  5. NEXT_BLOCK:次のブロックに移動します。

よって、compiler_unwind_fblock_stackで外側のループの情報を取り出してくることができれば、多重ループを抜けるbreakが実装できることになります。

compiler_unwind_fblock_stackを探る

compile.c: L1871

/** Unwind block stack. If loop is not NULL, then stop when the first loop is encountered. */
static int
compiler_unwind_fblock_stack(struct compiler *c, int preserve_tos, struct fblockinfo **loop) {
    if (c->u->u_nfblocks == 0) {
        return 1;
    }
    struct fblockinfo *top = &c->u->u_fblock[c->u->u_nfblocks-1];
    if (loop != NULL && (top->fb_type == WHILE_LOOP || top->fb_type == FOR_LOOP)) {
        *loop = top;
        return 1;
    }
    struct fblockinfo copy = *top;
    c->u->u_nfblocks--;
    if (!compiler_unwind_fblock(c, &copy, preserve_tos)) {
        return 0;
    }
    if (!compiler_unwind_fblock_stack(c, preserve_tos, loop)) {
        return 0;
    }
    c->u->u_fblock[c->u->u_nfblocks] = copy;
    c->u->u_nfblocks++;
    return 1;
}

 この関数では、frame blockを抜ける時に必要な後処理をします。 まず変数の説明です。

  • c->u->u_nfblocks:現在入っているframe blockの数
  • c->u->u_fblock:現在入っているframe blockが格納されているスタック

ここでframe blockとはforwhiletryなどの情報を持っているブロックです。


そして、以下がcompiler_unwind_fblock_stackの中で行われている動作です。

  1. c->u->u_fblockの中を探索し終えたら、loopNULLのまま終了。
  2. 今いるframe blockがループ(forwhile)なら、loopにそのループの情報を渡して終了。
  3. 今いるframe blockを抜けるための後処理をする。
  4. スタックからframe blockを1つ取り出して、再帰処理をする。(1に戻る)
  5. スタックを元に戻す。

よって、上の処理2のところで終了する条件を、「今いるframe blockがn個目のループなら」のように変更すれば、ちょうどn個のループを抜けるbreakが実装できます。

compiler_unwind_fblockを探る

compile.c:L1769

/* Unwind a frame block.  If preserve_tos is true, the TOS before
 * popping the blocks will be restored afterwards, unless another
 * return, break or continue is found. In which case, the TOS will
 * be popped.
 */
static int
compiler_unwind_fblock(struct compiler *c, struct fblockinfo *info,
                       int preserve_tos)
{
    switch (info->fb_type) {
        case WHILE_LOOP:
        case EXCEPTION_HANDLER:
        case ASYNC_COMPREHENSION_GENERATOR:
            return 1;

        case FOR_LOOP:
            /* Pop the iterator */
            if (preserve_tos) {
                ADDOP(c, ROT_TWO);
            }
            ADDOP(c, POP_TOP);
            return 1;

        // 中略
       
    }
    Py_UNREACHABLE();
}

 ここでは、frame blockから抜ける時に、その種類に応じて後処理をするバイトコードを生成します。例えば、forループでは、ループを回すためのイテレータがスタックに入っているので、それを取り出すためにPOP_TOPバイトコードを生成します。この関数については、今回は変更する必要はありません。


compile.cの変更

f:id:doss2021_7:20211104234257p:plain
breakの処理の流れ

 上に示したのは、変更後のcompile.c内のbreakの処理の大まかな流れです。この図にそって処理するように、compile.cに変更・追加をしました。詳細は以下に記します。

compiler_unwind_fblock_stack_countの追加

 compiler_unwind_fblock_stackを参考にして作成したcompiler_unwind_fblock_stack_countを以下に示します。

static int
compiler_unwind_fblock_stack_count(struct compiler *c, int count, struct fblockinfo **loop) {
    if (c->u->u_nfblocks == 0 || count <= 0) {
        return 1;
    }

    struct fblockinfo *top = &c->u->u_fblock[c->u->u_nfblocks-1];
    if (loop != NULL && (top->fb_type == WHILE_LOOP || top->fb_type == FOR_LOOP)) {
        count--;
        if (count <= 0){
            *loop = top;
            return 1;
        }
    }
    struct fblockinfo copy = *top;
    c->u->u_nfblocks--;
    if (!compiler_unwind_fblock(c, &copy, 0)) {
        return 0;
    }
    if (!compiler_unwind_fblock_stack_count(c, count, loop)) {
        return 0;
    }
    c->u->u_fblock[c->u->u_nfblocks] = copy;
    c->u->u_nfblocks++;
    return 1;
}

この関数により、countの数だけループを抜けることができます。参考にしたcompiler_unwind_fblock_stackとの違いは、「ループ(forwhile)を見つけるまで再帰呼び出し」だったのが、「ループ(forwhile)をcountの数だけ見つけるまで再帰呼び出し」になっているところです。また、引数にあったpreserve_tosのフラグは、使わないので引数から外しました。

compiler_unwind_fblock_stack_allの追加

 今回の実装では、breakの引数が0の場合は全てのループを抜ける処理にしたいので、compiler_unwind_fblock_stack_allという関数を別に作りました。それを以下に示します。

static int
compiler_unwind_fblock_stack_all(struct compiler *c, struct fblockinfo **loop) {
    if (c->u->u_nfblocks == 0) {
        return 1;
    }
    int count = 0;
    for (int i = 0; i < (c->u->u_nfblocks); ++i){
        if (c->u->u_fblock[i].fb_type == WHILE_LOOP || c->u->u_fblock[i].fb_type == FOR_LOOP){
            count++;
        }
    }
    return compiler_unwind_fblock_stack_count(c, count, loop);
}

ここでは、c->u->u_fblockの中でwhileforの数を数え上げ、それをcompiler_unwind_fblock_stack_countcountとして渡しています。こうすることで、全てのループを抜ける処理ができます。

compiler_breaknewの追加

 breakに引数が与えられたときの処理をする関数であるcompiler_breaknewを追加したので、以下に示します。

static int
compiler_breaknew(struct compiler *c, stmt_ty s){
    struct fblockinfo *loop = NULL;
    ADDOP(c, NOP);
    int count = PyLong_AS_LONG(s->v.Breaknew.value->v.Constant.value);
    if (count == 0) {
        if (!compiler_unwind_fblock_stack_all(c, &loop)) {
            return 0;
        }
    } else {
        if (!compiler_unwind_fblock_stack_count(c, count, &loop)) {
            return 0;
        }
    }
    if (loop == NULL) {
        return compiler_error(c, "'break' not properly in loop");
    }
    if (!compiler_unwind_fblock(c, loop, 0)) {
        return 0;
    }
    ADDOP_JUMP(c, JUMP_ABSOLUTE, loop->fb_exit);
    NEXT_BLOCK(c);
    return 1;
}

breakの引数を取り出してcountに入れ、countの値によって、compiler_unwind_fblock_stack_countcompiler_unwind_fblock_stack_allを呼び出しています。

compiler_visit_stmtの変更

 この関数は、compiler_breakなどの呼び出し元になっている関数です。以下ように、compiler_breaknewを呼び出す分岐を付け加えます。

static int
compiler_visit_stmt(struct compiler *c, stmt_ty s)
{
    Py_ssize_t i, n;

    /* Always assign a lineno to the next instruction for a stmt. */
    SET_LOC(c, s);

    switch (s->kind) {
    // 中略
    case Break_kind:
        return compiler_break(c);
    case Breaknew_kind:
        return compiler_breaknew(c, s);
    // 中略
    }

    return 1;
}

完成!!

 実装の結果、breakは以下のように使えるようになりました。

  • break     ⇒ 既存のbreakと同じ動作(ループを1つ抜ける)

  • break 自然数 ⇒ 自然数の数だけループを抜ける。

  • break 0    ⇒ 全部のループを抜ける。

以下は実際に動かしてみた結果です。

# test01.py
n = 5
for i in range(n):
    for j in range(n):
        for k in range(n):
            print(i, j, k)
            if i == 0 and j == 0 and k == 2:
                break 0 #全部break
print("fin")
$ ./python3 test01.py
0 0 0
0 0 1
0 0 2
fin
# test02.py
n = 5
for i in range(n):
    for j in range(n):
        for k in range(n):
            print(i, j, k)
            if j == 0 and k == 0:
                break 2 #ループ2つ分break
print("fin")
$ ./python3 test02.py
0 0 0
1 0 0
2 0 0
3 0 0
4 0 0
fin

ちゃんと多重ループを一度に抜けることができました。


課題

 今回の実装だと、変数を用いて抜けるループの数を指定することができません。例えば以下のようにすると、エラーが吐かれます。

n = 5
for i in range(n):
    for j in range(n):
        for k in range(n):
            print(i, j, k)
            if j == 0 and k == 0:
                m = 2
                break m
print("fin")
    break m
          ^
SyntaxError: invalid syntax

これは、python.gramの中で、breakの引数の型がNUMBERになっているからです。よく探せば、もっと適した型を選択できたと思いますが、残念ながら型を吟味してる時間が確保できませんでした。

 一方で、変数で引数を指定されては、どこにジャンプするか分からなくて大変だという意見もあったので、変数を引数にとれるようにした方がいいかは微妙なところではあります。


感想

 最後には無事実装に成功しましたが、そこに至るまでにはかなり紆余曲折がありました。一日になんの成果も得られないことはざらにあり、果たして期日までに実装できるのか不安になることも多かったです。それでも、何かのきっかけで大きく進捗が生まれた時には興奮したし、最後に完成した時はその分大きな達成感を得ることができました。

 私たちの班はPythonをいじりましたが、Pythonには公式ドキュメントや先輩方の書いた記事など参考になる資料が多かったので、その点では他のソフトウェアをいじった班よりも恵まれていたと思います。しかし、過去にPythonをいじった先輩にもcompile.cの中で引数を取り出すことを試みた例は(調べた限りでは)なく、その方法を探すのにとても苦労しました。(なんと、s->v.Breaknew.value->v.Constant.valueという、構造体のとても深いところにあった!)

 今回の実験で、自分たちの手だけでPythonという大きなソフトウェアの改造を成し遂げられたのは、本当に得難い経験であり、とても大きな自信になりました。この実験を通じて得られた多くのものを、今後活かしていけたらと思っています。


参考