python の内包表記で list の flatten を書く

(以下, Python 2.6.4 でやってます.)

python には内包表記という filter, map, reduce の代替ができるような表記方法があります.
詳しくは検索してもらうとして, 今回のネタはその内包表記で flatten という関数を書こうというものです.

flatten というのは, ネストしたリスト構造を単一のリストにまとめる, というものです.

# flatten
# [[1, 2, 3], [4, 5, 6], [7, 8, 9]] -> [1, 2, 3, 4, 5, 6, 7, 8, 9]

ちょっと仕事で書く場面があり, 時間の関係で綺麗に内包表記で書くのは断念してしまったので, 代わりにこの日記で拘泥ってみようという魂胆です.

まず簡単に1段階だけ.

(ここのパクリです.)

>>> mat = [
...     [1, 2, 3],
...     [4, 5, 6],
...     [7, 8, 9],
... ]
>>> [row[i] for row in mat for i in xrange(len(row))]
[1, 2, 3, 4, 5, 6, 7, 8, 9]

2つの for... を逆にすると怒られました.

>>> [row[i] for i in xrange(len(row)) for row in mat]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: list index out of range

for... の部分は前の方から評価されていくみたいです. 順序で言えば C とかの for ループと同じく, 大きいループを先に書くんですね.
ただ, これはちょっと直感的じゃないなぁ? 内側に行くほど小さいループになるほうが自然じゃないかなぁ?(不満) まだ Python に慣れていないだけなのかなぁ?

またこんな解答も見付けました.

sum(mat, [])

空リストを初期値として, mat をイテレートしながら足していくそうです.
けっこう綺麗ですが, sum という名前からはイメージが遠そうですね. (我儘??)

さて, 今日の仕事で出喰わした面倒なパターン

# ['a b c', 'd e f', 'g h i'] -> ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']

ということをしたかったのですが, なかなか面倒なようです.

まずは, ネストしたリストを出してみます.

>>> lst = ['a b c', 'd e f', 'g h i']
>>> [[char for char in elem.split()] for elem in lst]
[['a', 'b', 'c'], ['d', 'e', 'f'], ['g', 'h', 'i']]

うまく処理ができました. さて, これを上でやったように処理するには,

>>> [row[i] for row in [[char for char in elem.split()] for elem in lst] for i in xrange(len(row))]
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']

となりますが, なんと言うか, 魔術書の様相を呈してきましたね……. まだまだだよ, とツッコマれそうですが, 初見でこの処理の意味は私には分かりそうにはありません.

何とか処理を2段階に分けずに, 一気に目的のリストが作れないか考えてみました.

>>> [elem.split()[i] for elem in lst for i in xrange(len(elem.split()))]
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']

ふぅ, これならなんとか許容できそうです.
split の呼び出しを2回行っているのが気に入らないですが, elem.split() の結果を変数に割り当てておく方法を知らないので如何ともし難いですなぁ.
(Python では代入は式ではなく文であるので, 無理じゃないかなぁ.)

このパターンでは以下のように sum を使う方が綺麗に見えますね.

>>> sum([elem.split() for elem in lst], [])
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']

そして仕事で本当に使いたかったのは,

>>> ','.join(sum([elem.split() for elem in lst], []))
'a,b,c,d,e,f,g,h,i'

このような処理だったので, 一旦はこれを答えとしていいのかなぁ, という気持ちになりました.

実際の仕事では,

>>> ','.join([','.join(elem.split()) for elem in lst])
'a,b,c,d,e,f,g,h,i'

と書いて解決しました. それなりに意味の分かる形で書けたとは思うのですが, この join を2つ書くのが悔しくて悔しくて, 今こんな時間になっても家で検討しています.

多段階の場合

さて, さっきまでのはネストの深さが決まってるものでした.
ネストの深さが任意のものに対してはどうすればいいのでしょうか?

そこらへんは他のブログで試みられています. 詳細はそっちを参考にしてもらうとして, 以下のような解答があります.

def flatten(x):
    """flatten(sequence) -> list

    Returns a single, flat list which contains all elements retrieved
    from the sequence and all recursively contained sub-sequences
    (iterables).

    Examples:
    >>> [1, 2, [3,4], (5,6)]
    [1, 2, [3, 4], (5, 6)]
    >>> flatten([[[1,2,3], (42,None)], [4,5], [6], 7, MyVector(8,9,10)])
    [1, 2, 3, 42, None, 4, 5, 6, 7, 8, 9, 10]"""

    result = []
    for el in x:
        #if isinstance(el, (list, tuple)):
        if hasattr(el, "__iter__") and not isinstance(el, basestring):
            result.extend(flatten(el))
        else:
            result.append(el)
    return result
http://kogs-www.informatik.uni-hamburg.de/~meine/python_tricks