O(N*log(N))のSorting Algorithm - Merge Sort(併合法)と平行プログラミング

O(N*log(N))のSorting Algorithmとして有名なHeap Sort(http://d.hatena.ne.jp/yokolet/20080629#1214800194)とQuick Sort(http://d.hatena.ne.jp/yokolet/20080702#1215024229)を調べたところなので、もう一つ同じ計算量になるMerge Sortについても調べてみました。

Merge Sortというのは集合を左と右に1/2, 左と右に1/2,...に分けていき、最終的に一つだけの集合にしてから、各要素を比較しながら2つの集合を併合(merge)していくsortingです。Quick Sortのように分割するところにある値に影響されることなく、均等に1/2に分けられていくので、計算量は安定してO(N*log(N))で済みます。

例えば、[18, 6, 9, 1, 4, 12, 15, 5, 6, 7, 11]というデータ(集合)をソートする場合は

[18, 6, 9, 1, 4][12, 15, 5, 6, 7, 11]
[18, 6][9, 1, 4][12, 15, 5][6, 7, 11]
[18][6][9][1, 4][12][15, 5][6][7, 11]
[18][6][9][1][4][12][15][5][6][7][11]  2分割終わり
[6, 18][9][1, 4][12][5, 15][6][7, 11] 右と左の組になっているところから併合開始 
[6, 18][1, 4, 9][5, 12, 15][6, 7, 11]
[1, 4, 6, 9, 18][5, 6, 7, 11, 12, 15]
[1, 4, 5, 6, 6, 7, 9, 11, 12, 15, 18] sorting終了

のようになります。

Java言語の場合、java.util.Arraysのうち、オブジェクト型の配列を対象とするsort()メソッドにMerge Sortが採用されています。ただし、配列のサイズが7以下の場合は単なる総当たりでの比較、O(N^2) (big o of N squared)のBubble Sortを行っています。サイズの小さい集合についてはBubble Sortの方が適しているためではないかと思われます。

またしても、途中経過を眺めてみたかったので、Topcoderで解説されているMerge Sortを参考にプログラムを書いてみました。どのくらいの時間がかかるのかも調べてみたかったので、上記の配列の他、ランダムに発生させた100000個の整数をsortingして、時間を計ってみました。

import java.util.ArrayList;

public class SimpleMergeSort {

    private int inputs = {18, 6, 9, 1, 4, 12, 15, 5, 6, 7, 11};

    private SimpleMergeSort() {
        ArrayList src = new ArrayList();
        for (int value : inputs) {
            src.add(value);
        }
        mergeSort(src, true);
        
        src.clear();
        int size = 100000;
        for (int i = 0; i < size; i++) {
            src.add((int)(10000.0 * Math.sin(i)));
        }

        long start = System.currentTimeMillis();
        mergeSort(src, false);
        long end = System.currentTimeMillis();
        System.out.println("Sorted in: " + (end - start) + " ms");
    }

    private ArrayList mergeSort(ArrayList src, boolean verbose) {
        print("src: ", src, verbose);
        if (src.size() <= 1) {
            return src;
        }
        int mid = src.size() / 2;
        ArrayList left = new ArrayList();
        ArrayList right = new ArrayList();
        for (int i = 0; i < src.size(); i++) {
            if (i < mid) {
                left.add(src.get(i));
            } else {
                right.add(src.get(i));
            }
        }
        left = mergeSort(left, verbose);
        right = mergeSort(right, verbose);
        ArrayList dest = SimpleMergeSort.merge(left, right);
        print("dest: ", dest, verbose);
        return dest;
    }

    public static void main(String args) {
        new SimpleMergeSort();
    }

    private static ArrayList merge(ArrayList left, ArrayList right) {
        ArrayList dest = new ArrayList();
        int leftIndex = 0;
        int rightIndex = 0;
        int destLength = left.size()+right.size();
        while (dest.size() < destLength) {
            if (leftIndex == left.size()) {
                dest.add(right.get(rightIndex));
                rightIndex++;
            } else if (rightIndex == right.size()) {
                dest.add(left.get(leftIndex));
                leftIndex++;
            } else if (left.get(leftIndex) < right.get(rightIndex)) {
                dest.add(left.get(leftIndex));
                leftIndex++;
            } else {
                dest.add(right.get(rightIndex));
                rightIndex++;
            }
        }
        return dest;
    }

    private void print(String comment, ArrayList array, boolean verbose) {
        if (verbose) {
            System.out.print(comment);
            for (int i : array) {
                System.out.print(i + "|");
            }
            System.out.println();
        }
    }
}

実行すると、このような結果になりました。

src: 18|6|9|1|4|12|15|5|6|7|11|
src: 18|6|9|1|4|
src: 18|6|
src: 18|
src: 6|
dest: 6|18|
src: 9|1|4|
src: 9|
src: 1|4|
src: 1|
src: 4|
dest: 1|4|
dest: 1|4|9|
dest: 1|4|6|9|18|
src: 12|15|5|6|7|11|
src: 12|15|5|
src: 12|
src: 15|5|
src: 15|
src: 5|
dest: 5|15|
dest: 5|12|15|
src: 6|7|11|
src: 6|
src: 7|11|
src: 7|
src: 11|
dest: 7|11|
dest: 6|7|11|
dest: 5|6|7|11|12|15|
dest: 1|4|5|6|6|7|9|11|12|15|18|
Sorted in: 329 ms

概念的なモデルとは違い、このプログラムでは左半分の並べ替えが終わったあとで右半分を並べ替えて、全体を併合して並べ替えているのがわかります。

では左半分と右半分を並行的に処理したらどうか、ということでQuick Sort(http://d.hatena.ne.jp/yokolet/20080702#1215024229)のところで試してみた、java.util.concurrentパッケージのクラスを使うMerge Sort版を作ってみました。

import java.util.ArrayList;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

public class ConcurrentMergeSort {
    private int inputs = {18, 6, 9, 1, 4, 12, 15, 5, 6, 7, 11};
    private ExecutorService service;

    private ConcurrentMergeSort() throws InterruptedException, ExecutionException {
        ArrayList src = new ArrayList();
        for (int value : inputs) {
            src.add(value);
        }
        service = Executors.newCachedThreadPool();
        service.submit(new MergeSortHandler(src, true));
        service.awaitTermination(10, TimeUnit.MILLISECONDS);
        
        src.clear();
        int size = 10000;
        for (int i = 0; i < size; i++) {
            src.add((int)(10000.0 * Math.sin(i)));
        }

        long start = System.currentTimeMillis();
        service.submit(new MergeSortHandler(src, false));
        service.awaitTermination(1000, TimeUnit.MILLISECONDS);
        long end = System.currentTimeMillis();
        System.out.println("Sorted in: " + (end - start) + " ms");
        service.shutdown();
    }
    
    public static void main(String args)
            throws InterruptedException, ExecutionException {
        new ConcurrentMergeSort();
    }
    
    private static ArrayList merge(ArrayList left, ArrayList right) {
        ArrayList dest = new ArrayList();
        int leftIndex = 0;
        int rightIndex = 0;
        int destLength = left.size()+right.size();
        while (dest.size() < destLength) {
            if (leftIndex == left.size()) {
                dest.add(right.get(rightIndex));
                rightIndex++;
            } else if (rightIndex == right.size()) {
                dest.add(left.get(leftIndex));
                leftIndex++;
            } else if (left.get(leftIndex) < right.get(rightIndex)) {
                dest.add(left.get(leftIndex));
                leftIndex++;
            } else {
                dest.add(right.get(rightIndex));
                rightIndex++;
            }
        }
        return dest;
    }
    
    class MergeSortHandler implements Callable {
        private ArrayList src;
        private boolean verbose;

        MergeSortHandler(ArrayList array, boolean verbose) {
            this.src = array;
            this.verbose = verbose;
        }

        @Override
        public ArrayList call() throws InterruptedException, ExecutionException {
            print("src: ", src, verbose);
            if (src.size() <= 1) {
                return src;
            }
            int mid = src.size() / 2;
            ArrayList left = new ArrayList();
            ArrayList right = new ArrayList();
            for (int i = 0; i < src.size(); i++) {
                if (i < mid) {
                    left.add(src.get(i));
                } else {
                    right.add(src.get(i));
                }
            }
            Future leftResult = service.submit(new MergeSortHandler(left, verbose));
            Future rightResult = service.submit(new MergeSortHandler(right, verbose));
            ArrayList dest = ConcurrentMergeSort.merge(leftResult.get(), rightResult.get());
            print("dest: ", dest, verbose);
            return dest;
        }

        private void print(String comment, ArrayList array, boolean verbose) {
            if (verbose) {
                System.out.print(comment);
                for (int i : array) {
                    System.out.print(i + "|");
                }
                System.out.println();
            }
        }
    }
}

複数のスレッドを使っているから速くなっているはずと期待したのですが、意外にもこの方法では逆に遅くなってしまっていました。実行結果は

src: 18|6|9|1|4|12|15|5|6|7|11|
src: 18|6|9|1|4|
src: 12|15|5|6|7|11|
src: 18|6|
src: 9|1|4|
src: 12|src: 18|
15|5|
src: 6|7|11|
src: 6|
src: 9|
src: 6|
dest: 6|18|
src: 12|
src: 15|5|
src: 7|11|
src: 5|
src: 1|4|
src: 15|
src: 1|
dest: 5|15|
src: 11|
dest: 5|12|15|
src: 7|
src: 4|
dest: 7|11|
dest: 1|4|
dest: 6|7|11|
dest: 1|4|9|
dest: 5|6|7|11|12|15|
dest: 1|4|6|9|18|
dest: 1|4|5|6|6|7|9|11|12|15|18|
Sorted in: 3179 ms

というように、バラバラと実行されていくので複数のスレッドで実行されて入るようなのですが、10000個のsorting(一桁少ない)を行うだけで3179msという予想外の数値。一方、シングルスレッドのsortingは一桁多い100000個でも329msです。どうやら、service.awaitTermination(1000, TimeUnit.MILLISECONDS);と各スレッドの処理が終了するのを待つ時間のパラメータがマジックナンバーで、小さすぎると実行が終わらないうちに強制的に終わりにしてしまうので、十分大きくしないとだめで、その結果、全体的に時間がかかってしまっているようです。チューンアップの方法はいろいろとありそうなので、これからいろいろと試してみようと思います。

さて、Javaのconcurrent programming関係はJDK 5, 6と改善されてきましたが、7でさらなる改善がある模様です。JSRも166(JDK 5), 166x(JDK 6), 166y(JDK 7)のように改善が続いています。(JDK 8は166zか?その先は?) JSR 166yではfork/join型の平行プログラミングができるようになること、より効果的なモニターの仕組みが導入されること、マルチCPUのハードウェアで確実に各CPUにスレッドを分散させて実行させることなどがあるようです。
参照

この新APIは一つ目にあげたDoug Lea氏のサイトからダウンロードできるので、Fork/Join版のMerge Sortも試してみました。まだあまり資料が無いので、もっといい方法がありそうなのですが、とりあえずこんなコードでMerge Sortができるようになりました。

import java.util.ArrayList;
import jsr166y.forkjoin.ForkJoinPool;
import jsr166y.forkjoin.RecursiveTask;

public class ForkJoinMergeSort {

    private int inputs = {18, 6, 9, 1, 4, 12, 15, 5, 6, 7, 11};
    private ForkJoinPool pool;

    private ForkJoinMergeSort() throws InterruptedException {
        ArrayList src = new ArrayList();
        for (int value : inputs) {
            src.add(value);
        }
        pool = new ForkJoinPool();
        System.out.println("pool size: " + pool.getPoolSize());
        pool.invoke(new MergeSortTask(src, true));

        src.clear();
        int size = 100000;
        for (int i = 0; i < size; i++) {
            src.add((int) (10000.0 * Math.sin(i)));
        }

        long start = System.currentTimeMillis();
        pool.invoke(new MergeSortTask(src, false));
        long end = System.currentTimeMillis();
        System.out.println("Sorted in: " + (end - start) + " ms");
        pool.shutdown();
    }

    public static void main(String args) throws InterruptedException {
        new ForkJoinMergeSort();
    }

    private static ArrayList merge(ArrayList left, ArrayList right) {
        ArrayList dest = new ArrayList();
        int leftIndex = 0;
        int rightIndex = 0;
        int destLength = left.size() + right.size();
        while (dest.size() < destLength) {
            if (leftIndex == left.size()) {
                dest.add(right.get(rightIndex));
                rightIndex++;
            } else if (rightIndex == right.size()) {
                dest.add(left.get(leftIndex));
                leftIndex++;
            } else if (left.get(leftIndex) < right.get(rightIndex)) {
                dest.add(left.get(leftIndex));
                leftIndex++;
            } else {
                dest.add(right.get(rightIndex));
                rightIndex++;
            }
        }
        return dest;
    }

    class MergeSortTask extends RecursiveTask {

        private ArrayList src;
        private boolean verbose;
        ArrayList dest;

        MergeSortTask(ArrayList array, boolean verbose) {
            this.src = array;
            this.verbose = verbose;
        }
        
        @Override
        protected ArrayList compute() {
            print("src: ", src, verbose);
            if (src.size() <= 1) {
                return src;
            }
            
            ArrayList left = new ArrayList();
            ArrayList right = new ArrayList();

            int mid = src.size() / 2;
            for (int i = 0; i < src.size(); i++) {
                if (i < mid) {
                    left.add(src.get(i));
                } else {
                    right.add(src.get(i));
                }
            }
            
            MergeSortTask leftTask = new MergeSortTask(left, verbose);
            MergeSortTask rightTask = new MergeSortTask(right, verbose);
            forkJoin(leftTask, rightTask);
            dest = ForkJoinMergeSort.merge*1;
            print("dest: ", dest, verbose);
            return dest;
        }

        private void print(String comment, ArrayList array, boolean verbose) {
            if (verbose) {
                System.out.print(comment);
                for (int i : array) {
                    System.out.print(i + "|");
                }
                System.out.println();
            }
        }
    }
}

実行すると、このような結果になりました。(pool sizeはCPUの数)

pool size: 2
src: 18|6|9|1|4|12|15|5|6|7|11|
src: 18|6|9|1|4|
src: 12|15|5|6|7|11|
src: 12|15|5|
src: 12|
src: 15|5|
src: 15|
src: 5|
dest: src: 18|6|
src: 18|
src: 6|
dest: 6|18|
src: 9|1|4|
src: 9|
src: 1|4|
src: 1|
src: 4|
dest: 1|4|
dest: 1|4|9|
dest: 1|4|6|9|18|
src: 6|7|11|
src: 6|
src: 7|11|
src: 7|
src: 11|
dest: 7|11|
dest: 6|7|11|
5|15|
dest: 5|12|15|
dest: 5|6|7|11|12|15|
dest: 1|4|5|6|6|7|9|11|12|15|18|
Sorted in: 358 ms

こんどもバラバラと実行されていて、複数のスレッドが動いているのがわかります。ただ、100000個のsortingはシングルスレッド版よりも遅くて、358msです。速いはずといわれているにも関わらず、あまり速くありませんでした。ただ、ArrayListではなくて、jsr166y.forkjoinパッケージのPararellArrayなどを使うとかいろいろとチューンアップする余地はありそうです。こちらのパッケージの方が、スレッド実行後の戻り値取得など、扱いやすくなっているかもしれません。

*1:ArrayList)leftTask.rawResult(), (ArrayList)rightTask.rawResult(