FrontPage

2008/11/04からのアクセス回数 15038

集合知は、Amazonで有名になったレコメンドシステムやGoogle SearchのPageRank アルゴリズムなどをPythonを使って、簡潔にかつ分かりやすく説明した素晴らしい本です。

ここでは、11章で紹介されている「遺伝的プログラミング」をjavaに移植しながら、その アルゴリズムをトレースしてみます。

遺伝的プログラミング

遺伝的プログラミングでは、プログラムの集合が互いに競り合い、進化します。

この進化には、

があります。

突然変異
プログラムを少しずつランダムに変化させる方法
交叉
良いプログラムの一部を自分のプログラムと一部と入れ替える方法

です。

ツリー構造のプログラム

プログラムを進化するために、プログラムをツリー構造で表現します。

11-2.jpg

集合知の図11-2から引用

このツリーに対応するjavaの関数は、

int func(int x, int y) {
	if (x > 3)
		return y + 5;
	else
		return y - 2;
}

となります。

Javaでツリーを表現する

Gp.javaというファイルに、

Fwrapper
関数ノードで利用される関数のラーパー。
Node
関数ノードのクラス(子ノードを持つ)、evaluateコールによって子ノードを評価し、関数にその結果を渡し、その結果を返す。
ParamNode
プログラムに渡されたパラメータノード、evaluateコールによってパラメータのインデックスを返す。
ConstNode
定数を返すノード。(javaへの移植では値をdoubleとしました)

ノード関連クラスの実装

javaでは、関数を引数に渡すことができないので、IFunctionインタフェースを持つオブジェクトを渡すことにしました。

interface IFunction {
	Object eval(List l);
}

class Node {
	IFunction	function;
	String		name;
	List		children;
	
	Node() {
		this(null, null);
	}
	
	Node(Fwrapper fw, List children) {
		if (fw != null) {
			this.function = fw.function;
			this.name = fw.name;
		}
		this.children = children;		
	}
	
	protected Object evaluate(List inp) {
		List results = new ArrayList();
		for (int i = 0; i < children.size(); i++) {
			Node node = (Node)children.get(i);
			results.add(node.evaluate(inp));
		}
		return this.function.eval(results);
	}
}

class Fwrapper extends Node {
	int	childCount;
	
	Fwrapper(IFunction function, int childCount, String name) {
		this.function = function;
		this.childCount = childCount;
		this.name = name;
	}
}

class ParamNode extends Node {
	int	idx;
	
	ParamNode(int idx) {
		this.idx = idx;
	}
	
	protected Object evaluate(List inp) {
		return (inp.get(idx));
	}	
}

class ConstNode extends Node {
	Double	v;
	
	ConstNode(double v) {
		this.v = v;
	}
	
	protected Object evaluate(List inp) {
		return (v);
	}	
}

関数の定義

Fwapperに渡す関数として、

を用意します。

class AddW implements IFunction {
	public Object eval(List l) {
		return (Double)l.get(0) + (Double)l.get(1);
	}	
}

class SubW implements IFunction {
	public Object eval(List l) {
		return (Double)l.get(0) - (Double)l.get(1);
	}	
}

class MulW implements IFunction {
	public Object eval(List l) {
		return (Double)l.get(0) * (Double)l.get(1);
	}	
}

class IfFunc implements IFunction {
	public Object eval(List l) {
		if ((Double)l.get(0) > 0) 
			return l.get(1);
		else
			return l.get(2);
	}	
}

class IsGreater implements IFunction {
	public Object eval(List l) {
		if ((Double)l.get(0) > (Double)l.get(1)) 
			return new Double(1);
		else
			return new Double(0);
	}	
}

ツリーの評価

AddW, SubW, MulW, Ifw, Gtwの関数を保持するラッパノードを持つリストflistを作成、サンプルツリーを作成する exampletreeメソッドを追加します。

public class Gp {
	static List		flist;
	static Fwrapper	addw;
	static Fwrapper	subw;
	static Fwrapper	mulw;
	static Fwrapper	ifw;
	static Fwrapper	gtw;
	
	public Gp() {
		addw = new Fwrapper(new AddW(), 2, "add");
		subw = new Fwrapper(new SubW(), 2, "subtract");
		mulw = new Fwrapper(new MulW(), 2, "multiply");
		ifw  = new Fwrapper(new IfFunc(), 3, "if");
		gtw  = new Fwrapper(new IsGreater(), 2, "isgreater");
		flist = new ArrayList();
		flist.add(addw);
		flist.add(mulw);
		flist.add(ifw);
		flist.add(gtw);
		flist.add(subw);
	}
	
	Node exampletree() {
		List gtwParList  = new ArrayList();
		gtwParList.add(new ParamNode(0));
		gtwParList.add(new ConstNode(3));
		
		List addwParList = new ArrayList();
		addwParList.add(new ParamNode(1));
		addwParList.add(new ConstNode(5));
		
		List subwParList = new ArrayList();
		subwParList.add(new ParamNode(1));
		subwParList.add(new ConstNode(2));
		
		List ifwParList  = new ArrayList();
		ifwParList.add(new Node(gtw, gtwParList));
		ifwParList.add(new Node(addw, addwParList));
		ifwParList.add(new Node(subw, subwParList));
		
		return (new Node(ifw, ifwParList));
	}
}

動作確認

main関数に記述する代わりに、test1メソッドに最初のテストメソッドをまとめました。

	void test1() {
		Node 	exampleTree = exampletree();
		List	para1 = new ArrayList();
		para1.add(new Double(2));
		para1.add(new Double(3));
		System.out.println(exampleTree.evaluate(para1));	
		
		List	para2 = new ArrayList();
		para2.add(new Double(5));
		para2.add(new Double(3));
		System.out.println(exampleTree.evaluate(para2));	
		
	}
	
	public static void main(String[] args) {
		Gp		gp = new Gp();
		gp.test1();
	}

実行すると

1.0
8.0

のように表示されます。

プログラムの表示

ツリー構造がどのようになっているかを表示するために、Nodeクラスにdisplayメソッドを追加します。

Nodeクラスには、

	protected void display(int indent) {
		for (int i = 0; i < indent; i++)
			System.out.print(" ");
		System.out.println(name);
		for (int i = 0; i < children.size(); i++) {
			children.get(i).display(indent + 1);
		}
	}

ParamNodeクラスには、

	protected void display(int indent) {
		for (int i = 0; i < indent; i++)
			System.out.print(" ");
		System.out.format("p%d\n", idx);
	}

ConstNodeクラスには、

	protected void display(int indent) {
		for (int i = 0; i < indent; i++)
			System.out.print(" ");
		System.out.format("%.1f\n", v);
	}

のようにdisplayメソッドを追加します。

動作確認

表示の方法は、トップのノードにdisplayメソッドを渡すだけです。

	void test2() {
		Node 	exampleTree = exampletree();
		exampleTree.display(0);
	}

では、実際にexamletreeで生成されたツリーを表示してみます。 mainメソッドで先ほどのtest1メソッドをtest2に変えて実行すると

if
 isgreater
  p0
  3.0
 add
  p1
  5.0
 subtract
  p1
  2.0

最初の集団を作る

ランダムな集合を作成するために、makerandomtreeメソッドを追加します。

	Node makeRandamTree(int pc) {
		return (makeRandamTree(pc, 4, 0.5, 0.6));
	}
	
	Node makeRandamTree(int pc, int maxdepth, double fpr, double ppr) {
		if (Math.random() < fpr && maxdepth >0) {
			Fwrapper f = (Fwrapper)Randam.choid(flist);
			List children = new ArrayList<Node>();
			for (int i = 0; i < f.childCount; i++ ) {
				Node node = makeRandamTree(pc, maxdepth-1, fpr, ppr);
				children.add(node);
			}
			return (new Node(f, children));
		}
		else if (Math.random() < ppr) {
			return (new ParamNode(Randam.randint(0, pc-1)));
		}
		else
			return (new ConstNode(Randam.randint(0, 10)));
	}

動作確認

makeRandamTreeメソッドのテストとして、test3を作成しました。

	void test3() {
		Node 	random1 = makeRandamTree(2);
		System.out.println("random1");
		random1.display(0);
		System.out.println("evaluate");
		
		List	para1 = new ArrayList();
		para1.add(new Double(7));
		para1.add(new Double(1));
		System.out.println(random1.evaluate(para1));			
		List	para2 = new ArrayList();
		para2.add(new Double(2));
		para2.add(new Double(4));
		System.out.println(random1.evaluate(para2));	
		
		Node 	random2 = makeRandamTree(2);
		System.out.println("random2");
		random2.display(0);
		System.out.println("evaluate");
		
		List	para3 = new ArrayList();
		para3.add(new Double(5));
		para3.add(new Double(3));
		System.out.println(random2.evaluate(para3));			
		List	para4 = new ArrayList();
		para4.add(new Double(5));
		para4.add(new Double(20));
		System.out.println(random2.evaluate(para4));		
	}

実行するたびに、異なる結果が得られます。 例として、

random1
subtract
 add
  if
   2.0
   multiply
    4.0
    3.0
   subtract
    10.0
    p1
  9.0
 p1
evaluate
20.0
17.0
random2
p1
evaluate
3.0
20.0

のような出力になります。

単純な数学的テスト

遺伝的プログラミングのテストとして、単純な数学的関数を推定するテストをします。

ここで、未知の関数として、

	Double hiddenFunction(int x, int y) {
		double val = x*x + 2*y + 3*x + 5;
		return (new Double(val));
	}

を定義します。

この関数を推定するために、200個のデータセットを用意します。 各行は、[Xの値、Yの値、関数値]の3つ組です。

	List buildHiddenSet() {
		List rows = new ArrayList();
		for (int i = 0; i < 200; i++) {
			int x = Randam.randint(0, 40);
			int y = Randam.randint(0, 40);
			List cols = new ArrayList();
			cols.add(new Double(x));
			cols.add(new Double(y));
			cols.add(hiddenFunction(x, y));
			rows.add(cols);
		}
		return (rows);
	}

遺伝的プログラミングがどの程度正しく推定したかを算出するために、scoreFunctionを以下の様に定義します。

	int scoreFunction(Node tree, List s) {
		double dif = 0;
		for (int i = 0; i < s.size(); i++) {
			List cols = (List)s.get(i);
			double v = Double.parseDouble(tree.evaluate(cols).toString());
			dif += Math.abs(v - ((Double)cols.get(2)).doubleValue());		
		}
		return ((int)dif);
	}

動作確認

スコアが計算を確認するtest4メソッドを以下のように追加します。

	void test4() {
		List hiddenset = buildHiddenSet();
		Node 	random1 = makeRandamTree(2);
		System.out.println("random1");
		random1.display(0);
		System.out.println("score");
		System.out.println(scoreFunction(random1, hiddenset));
		Node 	random2 = makeRandamTree(2);
		System.out.println("random2");
		random2.display(0);
		System.out.println("score");
		System.out.println(scoreFunction(random2, hiddenset));
	}

実行結果は、

random1
add
 5.0
 if
  add
   p0
   multiply
    9.0
    8.0
  6.0
  add
   p0
   8.0
score
133800
random2
add
 if
  p0
  5.0
  subtract
   p1
   add
    p1
    p1
 6.0
score
134010

とかなり大きな値になります。

プログラムの突然変異

これから、遺伝的プログラミングの突然変異をmutateメソッドに実装します。

mutateでは、

11-4.jpg

のようにツリーの一部に突然変異を適応します(集合知の図11-4から引用)。

mutateメソッドは以下のようになります。

	Node mutate(Node t, int pc) {
		return (mutate(t, pc, 0.1));
	}
	
	Node mutate(Node t, int pc, double probchange) {
		if (Math.random() < probchange) {
			return (makeRandamTree(pc));
		}
		else {
			Node result = (Node)t.clone();
			if (t.children != null) {
				List children = new ArrayList();
				for (int i = 0; i < t.children.size(); i++) {
					Node c = t.children.get(i);
					children.add(mutate(c, pc, probchange));
				}
				result.children = children;
			}
			return (result);
		}
	}

動作確認

では、mutateの動作を確認してみましょう。

以下のtest5を追加します。

	void test5() {
		List hiddenset = buildHiddenSet();
		Node 	random2 = makeRandamTree(2);
		System.out.println("random2");
		random2.display(0);
		System.out.println("random2 socre=" + scoreFunction(random2, hiddenset));
		System.out.println("mutate");
		Node	muttree = mutate(random2, 2);
		muttree.display(0);
		System.out.println("muttree socre=" + scoreFunction(muttree, hiddenset));
	}

実行結果は、毎回異なります。ほとんど変わらないときと以下のように変化する場合が あります。

random2
subtract
 p0
 8.0
random2 socre=133814
mutate
add
 isgreater
  0.0
  if
   add
    4.0
    p1
   p1
   8.0
 add
  p1
  subtract
   multiply
    p1
    p0
   subtract
    9.0
    p1
muttree socre=67170

交叉(Crossover)

いよいよ最後の交叉を実装します。

crossoverメソッドは以下のようになります(ほとんどpythonのまま)。

	Node crossover(Node t1, Node t2) {
		return (crossover(t1, t2, 0.7, 1));
	}
	
	Node crossover(Node t1, Node t2, double probswap, int top) {
		if (Math.random() < probswap && top == 0) {
			return ((Node)t2.clone());
		}
		else {
			Node result = (Node)t1.clone();
			if (t1.children != null && t2.children != null) {
				List children = new ArrayList<Node>();
				for (int i = 0; i < t1.children.size(); i++) {
					Node c =  t1.children.get(i);
					children.add(crossover(c, (Node)Randam.choid(t2.children), probswap, 0));
				}
				result.children = children;
			}
			return (result);
		}
	}

deepCopy対応

交叉では、ツリーの完全コピーを必要としますので、各ノードクラスにcloneメソッドを追加し、NodeクラスにはdeepCopyメソッドを追加しました。

Nodeクラスの追加

	protected void deepCopy(Node dst, Node src) {
		dst.function = src.function;
		dst.name = src.name;
		if (src.children != null)
			dst.children = new ArrayList<Node>(src.children);		
	}
	
	public Object clone() {
		Node dst = new Node();
		dst.deepCopy(dst, this);
		return (dst);
	}

ParamNodeクラスの追加

	public Object clone() {
		ParamNode dst = new ParamNode(idx);
		dst.deepCopy(dst, this);
		return (dst);
	}

ConstNodeクラスの追加

	public Object clone() {
		ConstNode dst = new ConstNode(v);
		dst.deepCopy(dst, this);
		return (dst);
	}

Fwrapperクラスの追加

	public Object clone() {
		Fwrapper dst = new Fwrapper(function, childCount, name);
		dst.deepCopy(dst, this);
		return (dst);
	}

Randam.choice対応

また、リストから1個をランダムに抽出するchoiceメソッドをRandam.javaに追加しました。

	static Object choid(List list) {
		if (list != null) {
			int idx = (int)(list.size()*Math.random());
			return (list.get(idx));
		}
		else
			return (null);
	}

動作確認

crossoverの動作確認をするために、test6メソッドを追加します。

	void test6() {
		List hiddenset = buildHiddenSet();
		Node 	randam1 = makeRandamTree(2);
		System.out.println("random1");
		randam1.display(0);

		Node 	randam2 = makeRandamTree(2);
		System.out.println("randam2");
		randam2.display(0);
		
		Node	cross = crossover(randam1, randam2);
		System.out.println("cross");
		cross.display(0);
	}

何度か動作すると結構おもしろい結果がでます。 以下は一例です。

random1
subtract
 p0
 p1

randam2
add
 10.0
 add
  subtract
   p1
   8.0
  p1

cross
subtract
 p0
 10.0

環境を作り上げる

必要なメソッドがすべて揃ったので、最後に進化するための環境を整えます。

evolveメソッドは、以下の通りです。 終了条件は、

となっています。

	void evolve(int pc, int popsize, IRankingFunction raunkingFunction, int maxgen, 
			double mutationRate, double breedingRate, double pexp, double pnew) {
		List population = new ArrayList();
		for (int i = 0; i < popsize; i++) {
			population.add(makeRandamTree(pc));
		}
		List scores = null;
		List first;
		for (int i = 0; i < maxgen; i++) {
			scores = raunkingFunction.ranking(population);			
			first = (List)scores.get(0);
			System.out.println(first.get(0).toString());
			int score = ((Double)first.get(0)).intValue();
			if (score == 0)
				break;
			
			// add top 2 nodes
			List newop = new ArrayList();
			newop.add(first.get(1));
			List second = (List)scores.get(1);
			newop.add(second.get(1));
			
			// generate nest genration.
			while (newop.size() < popsize) {
				if (Math.random() > pnew) {
					List l1 = (List)scores.get(selectindex(pexp));
					List l2 = (List)scores.get(selectindex(pexp));
					newop.add(mutate(
							crossover((Node)l1.get(1), (Node)l2.get(1), breedingRate, 1),
							pc, mutationRate));
				}
				else
					newop.add(makeRandamTree(pc));
			}
			population = newop;
		}
		first = (List)scores.get(0);
		Node best = (Node)first.get(1);
		best.display(0);
	}

pythonではランキング関数をメソッド渡ししているので、IFunctionと同様にインタフェースに変更しました。

interface IRankingFunction {
	List	ranking(List population);
}

	IRankingFunction getRankFunction(List dataSet) {		
		class RankingFunction implements IRankingFunction {
			List	dataSet;
			RankingFunction(List dataSet) {
				this.dataSet = dataSet;
			}
			public List ranking(List population) {
				List scores = new ArrayList();
				for (int i = 0; i < population.size(); i++) {
					Node t = (Node)population.get(i);
					List taple = new ArrayList();
					taple.add(new Double(scoreFunction(t, dataSet)));
					taple.add(t);
					scores.add(taple);
				}
				Collections.sort(scores, new ScoreComparator());
				return (scores);
			}			
		}
		
		return (new RankingFunction(dataSet));
	}

動作確認

いよいよ、遺伝的プログラミングの実力を見るときがきました。

	void test7() {
		List hiddenset = buildHiddenSet();
		IRankingFunction rk = getRankFunction(hiddenset);
		evolve(2, 500, rk, 500, 0.2, 0.1, 0.7, 0.1);
	}

解は、1つではありませんが、例を以下に示します。

13229.0
5388.0
2490.0
2242.0
778.0
396.0
196.0
186.0
186.0
186.0
186.0
29.0
29.0
8.0
8.0
8.0
0.0
add
 add
  add
   add
    p1
    p1
   multiply
    p0
    p0
  add
   p0
   5.0
 add
  p0
  p0

ここで、p0がX、p1がYです。

わずか17回で収束し、解として

Y+Y+X*X+X+5+X+X

となり、 hiddenFunctionの

X**2+2*Y+3*X+5

を見つけることができました。

完全なソース

以下に完全なjavaのソースを添付します。

fileGp.java
fileRandam.java

コメント

この記事は、

選択肢 投票
おもしろかった 8  
そうでもない 0  
わかりずらい 1  

皆様のご意見、ご希望をお待ちしております。


(Input image string)

トップ   新規 一覧 単語検索 最終更新   ヘルプ   最終更新のRSS
SmartDoc