テストケース生成スクリプト

最近「テスト可能なコードが良いコード」と教わり、そういう観点でコーディングを進めてみようと思ってます。

今考えてる進め方はこんな感じです。
まず、実際に使うためのコードを大まかに書きます。この時点では、ほとんどメソッド名だけ。頭の中でなるべく細かくメソッドに分けていきます。メソッドのつながりを考えているうちに短いコードを書いたりもします。
ある程度まで書けたら、スクリプトを使って空のテストメソッドを自動生成します。このスクリプトは、カレントディレクトリ以下の*.pyから関数定義、メソッド定義を探してきて、対応するテストメソッドを作ります。自分用なので適当なところもありますが晒しときます。

make_tests.py

#!/usr/bin/python
import os, sys
import inspect, _ast

excluded_dir_names = ["tests"]
excluded_file_names = ["setup.py"]

def get_method_names(clsdef):
    method_names = [ obj.name for obj in clsdef.body
        if isinstance(obj, _ast.FunctionDef) ]
    return method_names
    
def get_funcs_of_module(module):
    source = inspect.getsource(module)
    ast = compile(source, "dont_care", "exec", _ast.PyCF_ONLY_AST)

    class_dict = dict()
    classes = [ obj for obj in ast.body
        if isinstance(obj, _ast.ClassDef) ]

    for cls in classes:
        methods = get_method_names(cls)
        class_dict.setdefault(cls.name, methods)

    func_names = [ obj.name for obj in ast.body
        if isinstance(obj, _ast.FunctionDef) ]

    return class_dict, func_names

def make_test_code(test_dir, module_name):
    def write(f, old, new):
        if new not in old:
            f.write(new)

    from_list = module_name.split(".")[:-1]
    module = __import__(module_name, globals(), locals(), from_list)
    class_dict, func_names = get_funcs_of_module(module)

    if not class_dict and not func_names:
        raise Exception("no classes or functions are defined.")
    methods = [ value for value in class_dict.values() if value ]
    if not methods:
        raise Exception("no methods are defined.")

    test_file_name = os.path.join(test_dir, 
        "test_%s.py" % module_name.replace(".", "_"))
    test_file = file(test_file_name, "a+")

    old_code = test_file.read()
    code = ("import unittest\n\n"
        "class Test(unittest.TestCase):\n\n")
    write(test_file, old_code, code)

    wrote_test_method = False
    if class_dict:
        for class_name, method_names in class_dict.iteritems():
            for method_name in method_names:
                code = ("    def test_%s_%s(self):\n"
                    "        pass\n\n") % (class_name.lower(), method_name)
                write(test_file, old_code, code)
                wrote_test_method = True

    if func_names:
        for func_name in func_names:
            code = ("    def test_%s(self):\n"
                "        pass\n\n") % func_name
            write(test_file, old_code, code)
            wrote_test_method = True

    if not wrote_test_method:
        test_file.write("    pass\n")
    test_file.close()

    return test_file_name
    
def proc_module(test_dir, module_name):
    try:
        test_file_name = make_test_code(test_dir, module_name)
        print "[created] %s: %s" % (module_name, test_file_name)
    except Exception, e:
        print "[ignore] %s: %s" % (module_name, e)

def get_module_name(walk_root, root, filename):
    file_path = os.path.join(root, filename)
    file_path = file_path.replace(walk_root, "")
    if filename == "__init__.py":
        module_name = file_path[:-12]
    else:
        module_name = file_path[:-3]
    module_name = module_name.replace(os.path.sep, ".")
    return module_name

def is_target_file(root, filename):
    if (not filename.endswith(".py") or
        filename in excluded_file_names):
        return False
    path_dirs = root.split(os.path.sep)
    ignored_dir = [ excluded 
        for excluded in excluded_dir_names
        if excluded in path_dirs ]
    return (False if ignored_dir else True)

def main():
    target_dir = os.getcwd()
    excluded_file_names.append(os.path.basename(sys.argv[0]))
    test_dir = sys.argv[1] if 1 < len(sys.argv) else "tests"
    walk_root = os.path.join(os.getcwd(), "")
    for root, dirs, files in os.walk(target_dir):
        for filename in files:
            if is_target_file(root, filename):
                module_name = get_module_name(walk_root, root, filename)
                proc_module(test_dir, module_name)

if __name__ == "__main__":
    main()

プロジェクトルートに置いて、以下のように実行します。

$ python make_tests.py [test_cases_dir]

TurboGearsプロジェクトでの実行例。

$ tg-admin quickstart myproj
$ cp make_tests.py myproj
$ cd myproj
# もとのテストケースを削除
$ rm -f myproj/tests/test_*
# 実行。モジュール名と結果が表示されます
$ python make_tests.py myproj/tests
[ignore] start-myproj: no classes or functions are defined.
[created] myproj.controllers: myproj/tests/test_myproj_controllers.py
[ignore] myproj: could not get source code
[ignore] myproj.release: no classes or functions are defined.
[ignore] myproj.commands: no methods are defined.
[ignore] myproj.model: no classes or functions are defined.
[ignore] myproj.json: no classes or functions are defined.
[ignore] myproj.templates: could not get source code
[ignore] myproj.config: could not get source code
$ ls myproj/tests
__init__.py  test_myproj_controllers.py

中身はこんな感じ。「test_クラス名_メソッド名」というメソッドができてます。

$ cat myproj/tests/test_myproj_controllers.py
import unittest

class Test(unittest.TestCase):

    def test_root_index(self):
        pass

できたテストメソッドを眺めて、メソッド単位はこんなもんでいいかなーと考えます。もとのコードを適当に修正したり追加したりします。
上のスクリプトはメソッドが増えた場合でも追記で対応できるようにしてあります。ある程度の機能単位で上記を繰り返します。
変更が大きい場合は、マージが面倒ですがテストケースをいったん作り直したほうがいいかもしれません。もっとちゃんと管理できるようにしたいところです。