使用Rxjava计算圆周率

对于圆周率率的求法有很多,最近看到一个Spark的例子使用了map和reduce的方法来求一个圆周率的近似值。这个算法的思想是这样的:

  1. 半径为r的圆的面积CA = π × r × r
  2. 这个园的外切正方形的面积SA = 4 × r × r
  3. π = CA / r / r = CA × 4 / SA

根据上面的推导,我们只要知道圆形和正方形的面积之比就行了。然后我们在这个正方形的面积内随机生成足够多的点,用落在圆内的点数除以总的点数就可以得到一个近似的比值了。当然随机值的数目越多,得到的结果就会越精确。

具体程序的实现上,我们假设圆心为(1,1)的圆的半径为1,所以正方形的边长就为2. 然后使用map来生成一个随机数并判断这个数是否在圆内,通过reduce来统计圆内的数目。这个算法是使用Spark在集群上进行计算的,所以我们创建多个工作在不同线程上的Observable对象来模拟多个任务,在最后使用zip操作符收集所有任务的计算结果并求平均值。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
private Observable<Double> createObservable(final int num) {
return Observable.range(0, num)
.map(new Func1<Integer, Integer>() {
public Integer call(Integer integer) {
double x = mRandom.nextDouble() * 2 - 1;
double y = mRandom.nextDouble() * 2 - 1;
return (x * x + y * y) < 1 ? 1 : 0;
}
}).reduce(new Func2<Integer, Integer, Integer>() {
public Integer call(Integer integer, Integer integer2) {
int reduce = integer + integer2;
return reduce;
}
})
.map(new Func1<Integer, Double>() {
public Double call(Integer integer) {
double v = 4.0 * integer / num;
System.out.println("V:" + v);
return v;
}
})
.subscribeOn(Schedulers.computation());

}

public Observable<Double> getPi(int workNum, int num) {
ArrayList<Observable<Double>> list = new ArrayList<Observable<Double>>(workNum);
for (int i = 0; i < workNum; i++) {
list.add(createObservable(num));
}
//use zip to get the average value of all workers.
return Observable.zip(list, new FuncN<Double>() {
public Double call(Object... args) {
int len = args.length;
double result = 0;
for (int i = 0; i < len; i++) {
result += (Double) (args[i]);
}
return result / len;
}
});

}

编写testcase来测试一下我们的程序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

@Test
public void test1() {
final CountDownLatch latch = new CountDownLatch(1);
PI pi = new PI();
final double[] result = {0};
pi.getPi(4, 1000000)
.subscribe(new Action1<Double>() {
public void call(Double aDouble) {
System.out.print(aDouble);
result[0] = aDouble;

latch.countDown();
}
});

try {
latch.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
assertEquals(PiValue, result[0], 0.001);

最后的运行结果如下,创建的点数越多,得到的结果越跟真实值相近,当然计算所花费的时间就会越多。

1
2
3
4
5
V:3.143012
V:3.138528
V:3.141844
V:3.144612
3.141999

完整代码