Java Fork/Join 模型

它本质上是一种基于多线程的递归（分治）模型：

凡是能够拆分成若干独立子任务的问题，都可以把子任务分发到多个线程中并行执行，最后再合并各子任务的结果。

典型的递归问题，比如汉诺塔。

请参考 http://lizhe.name.csdn.net/node/82

用 Fork/Join 重写之后如下（注意：汉诺塔步数的递推式每一层只有一个子问题，所以下面的代码 fork 之后立刻 join，并不会真正并行，只是演示 Fork/Join API 的用法）：

package testgc;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class ForkjoinHanoi {

    /**
     * Counts the Tower of Hanoi moves for 10 disks on a 4-thread fork/join pool, then prints the
     * result followed by "ffff" and the elapsed time in milliseconds.
     *
     * @throws InterruptedException if interrupted while waiting for the task
     * @throws ExecutionException if the task itself failed
     */
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        ForkJoinPool forkJoinPool = new ForkJoinPool(4);
        try {
            // nanoTime is monotonic, so unlike currentTimeMillis it is safe for elapsed time.
            long start = System.nanoTime();
            ForkJoinTask<Integer> taskResult = forkJoinPool.submit(new HanoiTasklet(10));
            System.out.println(taskResult.get());
            long elapsedMillis = (System.nanoTime() - start) / 1_000_000L;
            System.out.println("ffff" + elapsedMillis);
        } finally {
            // Release the pool's worker threads even if the task failed.
            forkJoinPool.shutdown();
        }
    }

}

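/*
 * Reference recursion (Python) for the number of moves needed to solve Tower of
 * Hanoi with n disks: moves(0) = 0, moves(n) = 2 * moves(n - 1) + 1, whose closed
 * form is 2^n - 1. Python integers are unbounded, so countStep(100) works there;
 * a Java int can only hold the result up to n = 31.
 *
 *  def countStep(n):
 *      if n == 0:
 *          return 0
 *      else:
 *          return 2 * countStep(n - 1) + 1
 *
 *  print(countStep(100))
 */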

/**
 * Fork/join task that counts the moves needed to solve Tower of Hanoi with {@link #step} disks,
 * using the recurrence {@code moves(0) = 0}, {@code moves(n) = 2 * moves(n - 1) + 1}
 * (closed form {@code 2^n - 1}).
 *
 * <p>Each level has exactly one subproblem, so forking it and then joining immediately gives no
 * real parallelism; the task mainly demonstrates the fork/join API.
 */
class HanoiTasklet extends RecursiveTask<Integer> {

    /** Number of disks; always non-negative. */
    int step = 0;

    /**
     * @param step number of disks; must be {@code >= 0}
     * @throws IllegalArgumentException if {@code step} is negative, because the recursion would
     *     never reach its base case
     */
    public HanoiTasklet(int step) {
        if (step < 0) {
            throw new IllegalArgumentException("step must be non-negative: " + step);
        }
        this.step = step;
    }

    /**
     * @return the number of moves, {@code 2^step - 1}
     * @throws ArithmeticException if the result does not fit in an {@code int} (step > 31)
     */
    @Override
    protected Integer compute() {
        if (step == 0) {
            return 0;
        }
        HanoiTasklet subtask = new HanoiTasklet(step - 1);
        subtask.fork();
        int subMoves = subtask.join();
        // Exact arithmetic: plain 2 * n + 1 would silently wrap around once step exceeds 31.
        return Math.addExact(Math.multiplyExact(2, subMoves), 1);
    }

}

下面是一个排序的例子。实际上，有了 Stream（例如 `list.stream().parallel().sorted()`）之后，基本用不着手写这么复杂的代码了。

package testgc;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class TestForkJoin {

    /** Nanoseconds per millisecond, for converting {@link System#nanoTime()} deltas. */
    private static final long NANOS_PER_MILLI = 1_000_000L;

    /**
     * Benchmarks three ways of sorting 10,000,000 random integers in [0, 100):
     * <ol>
     *   <li>a sequential stream sort, printed as "ssss" plus the elapsed milliseconds;
     *   <li>a parallel stream sort, also printed as "ssss" plus the elapsed milliseconds;
     *   <li>the hand-written fork/join {@code Tasklet} on a 4-thread pool, printed as "ffff" plus
     *       the elapsed milliseconds.
     * </ol>
     *
     * @throws InterruptedException if interrupted while waiting for the fork/join task
     * @throws ExecutionException if the fork/join task itself failed
     */
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        // Math.random() always returns a value in [0, 1), so no filtering is needed before
        // scaling to [0, 100) and truncating to int.
        List<Integer> list = Stream.generate(Math::random).limit(10000000).map(i -> i * 100)
                .map(Double::intValue).collect(Collectors.toList());

        // nanoTime is monotonic, so unlike currentTimeMillis it is safe for elapsed time.
        long sequentialStart = System.nanoTime();
        list.stream().sorted().collect(Collectors.toList());
        System.out.println("ssss" + (System.nanoTime() - sequentialStart) / NANOS_PER_MILLI);

        long parallelStart = System.nanoTime();
        list.stream().parallel().sorted().collect(Collectors.toList());
        System.out.println("ssss" + (System.nanoTime() - parallelStart) / NANOS_PER_MILLI);

        ForkJoinPool forkJoinPool = new ForkJoinPool(4);
        try {
            long forkJoinStart = System.nanoTime();
            ForkJoinTask<List<Integer>> taskResult = forkJoinPool.submit(new Tasklet(list));
            taskResult.get();
            System.out.println("ffff" + (System.nanoTime() - forkJoinStart) / NANOS_PER_MILLI);
        } finally {
            // Release the pool's worker threads even if the sort failed.
            forkJoinPool.shutdown();
        }
    }

}

/**
 * Fork/join quicksort for a list of integers.
 *
 * <p>{@link #compute()} returns a new list holding the same elements as {@link #list} in
 * ascending order; the input list itself is never modified. Lists with at most
 * {@link #SEQUENTIAL_THRESHOLD} elements are sorted directly in the current thread. Larger lists
 * are split with a three-way partition around their middle element. The "less than" part is
 * forked while the "greater than" part is sorted in the current thread, so the two halves can
 * actually run in parallel.
 */
class Tasklet extends RecursiveTask<List<Integer>> {

    /**
     * Lists at or below this size are sorted sequentially; splitting them further would cost more
     * in task overhead than it could save.
     */
    private static final int SEQUENTIAL_THRESHOLD = 1_000;

    /** The list to sort; it is read but never modified. */
    List<Integer> list;

    /**
     * @param list the list to sort; it is read but never modified
     */
    public Tasklet(List<Integer> list) {
        this.list = list;
    }

    /**
     * Sorts {@link #list}.
     *
     * @return a new, mutable, ascending-sorted list containing the same elements as the input
     */
    @Override
    protected List<Integer> compute() {
        if (list.size() <= SEQUENTIAL_THRESHOLD) {
            // Copy first so the caller's list is never reordered in place.
            List<Integer> sorted = new ArrayList<>(list);
            Collections.sort(sorted);
            return sorted;
        }

        // Three-way partition. Every copy of the pivot lands in `equal` and is already in its
        // final position. Because the pivot itself is in `equal`, both recursive parts are
        // strictly smaller than the input, so the recursion always makes progress even with
        // many duplicates.
        Integer pivot = list.get(list.size() / 2);
        List<Integer> less = new ArrayList<>();
        List<Integer> equal = new ArrayList<>();
        List<Integer> greater = new ArrayList<>();
        for (Integer value : list) {
            int order = value.compareTo(pivot);
            if (order < 0) {
                less.add(value);
            } else if (order > 0) {
                greater.add(value);
            } else {
                equal.add(value);
            }
        }

        // Fork one part and sort the other in this thread before joining. Forking and
        // immediately joining each part in turn would run them one after the other, with no
        // parallelism.
        Tasklet lessTask = new Tasklet(less);
        lessTask.fork();
        List<Integer> sortedGreater = new Tasklet(greater).compute();
        List<Integer> sortedLess = lessTask.join();

        List<Integer> result = new ArrayList<>(list.size());
        result.addAll(sortedLess);
        result.addAll(equal);
        result.addAll(sortedGreater);
        return result;
    }

}