2009年04月13日
3D版!「K-Means法」ビジュアライズしてみた
前回のエントリ『クラスタリングの定番アルゴリズム「K-means法」をビジュアライズしてみた』が好評だったので、3D空間でもK平均法(K-means法)をビジュアライズしてみました。
ちょっと重めなのでクリックすると始まります。さらにクリックして、1ステップずつ動かしてみてください。3次元にしてみると、宇宙空間のように見えてきて美しいです。
数式の上では前回からほとんど変わっておらず、2次元空間上のユークリッド距離が3次元空間上の距離になったぐらいです。K平均法そのものについては、前回のエントリ『クラスタリングの定番アルゴリズム「K-means法」をビジュアライズしてみた』をご覧ください。
ちなみに、3次元座標の計算には Flash Player 10 から導入された Matrix3D を使っています。コードは [NUTSU] の記事『[as]FP10をそろそろ。Matrix3Dとか』を参考にさせてもらいました。Zソート(奥行き順の並べ替え)をしていないので、たまに奥のものが手前に表示されたりしますが、あまり気にならないのでそのままにしています。
以下、ソースコードです(364行)。
// K-Means algorithm visualization
// - requires sketchbook
// http://sketchbook.libspark.org/
// - requires tweener
// http://code.google.com/p/tweener/
package{
import flash.display.*;
import flash.events.Event;
import flash.text.TextField;
import flash.geom.*;
import sketchbook.colors.ColorSB;
import caurina.transitions.Tweener;
import flash.system.Capabilities;
/**
 * Interactive 3D visualization of K-means clustering.
 *
 * n random points are placed in a SIZE-wide cube around the origin and
 * randomly assigned to one of k groups. The first click on the canvas (or
 * the "Step" button) starts the rotation/render loop; each later click
 * alternates between the two K-means steps:
 *   - moveCenter():   move each group's center to the centroid of its points
 *   - updateGroups(): reassign every point to its nearest center
 * Points occupy dotsPos[0 .. n*3) and centers dotsPos[n*3 .. (n+k)*3), so a
 * single Matrix3D.transformVectors() call per frame rotates all of them.
 */
[SWF(backgroundColor="#223344", frameRate=18, width=400, height=400)]
public class KMeans3D extends Sprite{
    private var k:int;                    // number of clusters
    private var n:int;                    // number of points
    private var colors:Array;             // one RGB color per group
    private var dots:Vector.<Dot>;        // point sprites
    private var dotsPos:Vector.<Number>;  // model-space x,y,z of points, then centers
    private var dotsView:Vector.<Number>; // dotsPos after the current rotation
    private var groups:Array;             // groups[g] = indices of the points in group g
    private var centers:Vector.<Center>;  // null until a group's center is first placed
    private var changed:Boolean;          // did the last reassignment recolor any point?
    private var canvas:Sprite = new Sprite();
    private var lineCanvas:Sprite = new Sprite();
    private var started:Boolean = false;
    private var _matrix:Matrix3D = new Matrix3D;
    private var rotateAxis:Vector3D = new Vector3D( 0.2, 1.0, 0.1 );
    private const WIDTH:int = 400;
    private const HEIGHT:int = 300;
    private const SIZE:int = 200;
    private const ANIMATE:Number = .4;    // center tween duration (Tweener "time")

    /**
     * Builds the view, the N/K inputs and the Step/Restart buttons, then
     * performs an initial Restart. Shows a message and stops if the player
     * is older than Flash Player 10.
     */
    public function KMeans3D():void{
        stage.scaleMode = "noScale";
        // Capabilities.version looks like "WIN 10,0,12,36": take the major number.
        var version:int = parseInt(Capabilities.version.split(" ")[1].split(",")[0]);
        if (version < 10){
            var tf:TextField = new TextField();
            tf.textColor = 0xffffff;
            tf.autoSize = "left";
            tf.text ="Flash Player 10 or later required.";
            addChild(tf);
            return;
        }
        // init canvas: center both layers so the 3D origin maps to the view center
        canvas.x = lineCanvas.x = WIDTH / 2;
        canvas.y = lineCanvas.y = HEIGHT / 2;
        // fully transparent fill so the whole view area receives clicks
        canvas.graphics.beginFill(0x000000, 0);
        canvas.graphics.drawRect(-WIDTH / 2, -HEIGHT / 2, WIDTH, HEIGHT);
        canvas.graphics.endFill();
        // buttonMode must be enabled on canvas itself for its hand cursor to show
        canvas.buttonMode = true;
        canvas.useHandCursor = true;
        canvas.mouseChildren = false;
        addChild(lineCanvas);
        addChild(canvas);
        // 0: next step moves the centers, 1: next step reassigns the points
        var state:int = 0;
        canvas.addEventListener("click", function(event:Event):void{
            if (!started){
                addEventListener("enterFrame", render);
                started = true;
                return;
            }
            if(state == 0){
                moveCenter();
            }else{
                updateGroups();
            }
            state = (state + 1) % 2;
        });
        // init inputs
        var nInput:Input = new Input("N (the number of node):", "100");
        nInput.y = HEIGHT + 5;
        addChild(nInput);
        var kInput:Input = new Input("K (the number of cluster):", "5");
        kInput.y = nInput.y + nInput.height + 5;
        addChild(kInput);
        var nextButton:Button = new Button("Step");
        nextButton.y = kInput.y + kInput.height + 5;
        addChild(nextButton);
        // forward Step clicks to the canvas (dispatchEvent clones an in-flight event)
        nextButton.addEventListener("click", canvas.dispatchEvent);
        var resetButton:Button = new Button("Restart");
        resetButton.x = nextButton.width + 5;
        resetButton.y = nextButton.y;
        addChild(resetButton);
        resetButton.addEventListener("click", function(event:Event):void{
            changed = true;
            state = 0;
            // empty/non-numeric input reads as 0; keep at least one cluster and
            // never a negative point count (Vector constructors reject it)
            k = Math.max(1, kInput.value);
            n = Math.max(0, nInput.value);
            init();
            // without the enterFrame loop nothing would position the new dots
            if (!started) render();
        });
        resetButton.dispatchEvent(new Event("click"));
    }

    /**
     * Rebuilds the scene for the current n and k: removes the previous dot
     * and center sprites, creates k colors (ColorSB.createHSB with hue
     * i * 360 / k), and places n dots at random positions, each in a random
     * group. Center slots stay null until the first moveCenter() call.
     */
    private function init():void{
        // remove previous sprites (dots/centers are still null on the first call)
        for each(var dot:Dot in dots){
            canvas.removeChild(dot);
        }
        for each(var center:Center in centers){
            if(center) canvas.removeChild(center);
        }
        // init colors
        colors = [];
        for(var i:int = 0; i < k; i++){
            colors.push(ColorSB.createHSB(i * 360 / k, 90, 100).value);
        }
        // init dots; the trailing k*3 slots hold the centers' coordinates
        dots = new Vector.<Dot>(n);
        dotsPos = new Vector.<Number>((n + k) * 3);
        dotsView = new Vector.<Number>((n + k) * 3);
        groups = [];
        centers = new Vector.<Center>(k);
        for(i = 0; i < n; i++){
            var group:int = Math.floor(Math.random() * k);
            dots[i] = new Dot(colors[group]);
            canvas.addChild(dots[i]);
            dotsPos[i * 3 + 0] = Math.random() * SIZE - SIZE / 2;
            dotsPos[i * 3 + 1] = Math.random() * SIZE - SIZE / 2;
            dotsPos[i * 3 + 2] = Math.random() * SIZE - SIZE / 2;
            if(!groups[group]) groups[group] = [];
            groups[group].push(i);
        }
    }

    /**
     * Per-frame update: rotates the scene 1 degree about rotateAxis, projects
     * every point and center, and draws a line from each point to its
     * group's center.
     */
    private function render(event:Event = null):void{
        _matrix.appendRotation( 1, rotateAxis );
        _matrix.transformVectors(dotsPos, dotsView);
        // points
        for (var i:int = 0; i < n; i++){
            dots[i].update(dotsView[i * 3],
                           dotsView[i * 3 + 1],
                           dotsView[i * 3 + 2]);
        }
        // centers and membership lines
        var g:Graphics = lineCanvas.graphics;
        g.clear();
        for (i = 0; i < k; i++){
            var center:Center = centers[i];
            if (!center) continue;
            var base:int = (n + i) * 3;
            center.update(dotsView[base], dotsView[base + 1], dotsView[base + 2]);
            var cx:Number = center.x;
            var cy:Number = center.y;
            // one line style per group instead of one per line
            g.lineStyle(0, colors[i], .5);
            for each(var index:int in groups[i]){
                g.moveTo(dots[index].x, dots[index].y);
                g.lineTo(cx, cy);
            }
        }
        g.lineStyle();
    }

    /**
     * K-means "update" step: moves each non-empty group's center to the
     * centroid of its points. A center is created the first time its group
     * is seen; existing centers are tweened to the new position. Does
     * nothing but clear the glow highlights when the previous reassignment
     * changed nothing (converged).
     */
    private function moveCenter():void{
        for each(var dot:Dot in dots) dot.glow = false;
        if(!changed) return;
        for(var i:int = 0; i < groups.length; i++){
            var members:Array = groups[i];
            // skip missing or empty groups (avoids dividing by zero below)
            if(!members || members.length == 0){
                continue;
            }
            // get center of gravity
            var gx:Number = 0, gy:Number = 0, gz:Number = 0;
            for each(var index:int in members){
                gx += dotsPos[index * 3];
                gy += dotsPos[index * 3 + 1];
                gz += dotsPos[index * 3 + 2];
            }
            var count:int = members.length;
            gx /= count;
            gy /= count;
            gz /= count;
            var base:int = (n + i) * 3;
            if(centers[i]){
                // ax/ay/az write straight into dotsPos, so the tween moves the center
                Tweener.addTween(centers[i], {
                    ax: gx, ay: gy, az: gz, time: ANIMATE
                });
            }else{
                centers[i] = new Center(colors[i], dotsPos, base);
                dotsPos[base + 0] = gx;
                dotsPos[base + 1] = gy;
                dotsPos[base + 2] = gz;
                centers[i].update(gx, gy, gz);
                canvas.addChild(centers[i]);
            }
        }
    }

    /**
     * K-means "assignment" step: assigns every point to the group of its
     * nearest existing center (Euclidean distance in model space), recolors
     * and highlights points whose color changed, and sets `changed`
     * accordingly.
     */
    private function updateGroups():void{
        changed = false;
        groups = [];
        for (var i:int = 0; i < n; i++){
            var px:Number = dotsPos[i * 3 + 0];
            var py:Number = dotsPos[i * 3 + 1];
            var pz:Number = dotsPos[i * 3 + 2];
            // find the nearest group; squared distance has the same minimum
            var min:Number = Infinity;
            var group:int = -1;
            for(var j:int = 0; j < k; j++){
                if(!centers[j]) continue;
                var base:int = (n + j) * 3;
                var dx:Number = dotsPos[base + 0] - px;
                var dy:Number = dotsPos[base + 1] - py;
                var dz:Number = dotsPos[base + 2] - pz;
                var d:Number = dx * dx + dy * dy + dz * dz;
                if(d < min){
                    min = d;
                    group = j;
                }
            }
            // no center has been placed yet: leave this point unassigned
            if(group < 0) continue;
            // update group
            var dot:Dot = dots[i];
            if(!groups[group]) groups[group] = [];
            groups[group].push(i);
            if(dot.color != colors[group]){
                dot.color = colors[group];
                dot.glow = true;
                changed = true;
            }
        }
    }
}
}
import flash.display.*;
import flash.text.*;
import flash.filters.GlowFilter;
// Perspective distance used by Sprite3D.update(): a point at view depth z
// is drawn scaled by F / (z + F).
const F:Number = 400;
/**
 * Sprite positioned from 3D view coordinates via a simple perspective
 * divide by (z + F).
 */
class Sprite3D extends Sprite{
    /**
     * Projects (_x, _y, _z) onto the screen plane: position is multiplied by
     * F / (F + _z) and the sprite is scaled by that factor minus 0.5, so a
     * larger _z gives a smaller sprite nearer the origin.
     */
    public function update(_x:Number, _y:Number, _z:Number):void{
        var perspective:Number = F / (F + _z);
        var s:Number = perspective - .5;
        this.x = perspective * _x;
        this.y = perspective * _y;
        this.scaleY = s;
        this.scaleX = s;
    }
}
/**
 * Filled circle of radius 5 representing one data point, drawn in its
 * current color.
 */
class Dot extends Sprite3D{
    private var _color:uint;

    public function Dot(col:uint){
        color = col;
    }

    /** Fill color; assigning a new value redraws the circle. */
    public function get color():uint{
        return _color;
    }
    public function set color(v:uint):void{
        _color = v;
        draw();
    }

    /** Write-only: true applies a white glow filter, false removes all filters. */
    public function set glow(v:Boolean):void{
        filters = v ? [new GlowFilter(0xffffff, 1, 5, 5)] : [];
    }

    private function draw():void{
        var g:Graphics = graphics;
        g.clear();
        g.beginFill(_color);
        g.drawCircle(0, 0, 5);
        g.endFill();
    }
}
/**
 * "X"-shaped marker for a cluster center. Its model coordinates are not
 * stored in the sprite: ax/ay/az read and write the three consecutive
 * entries of a shared Number vector starting at `index`, so anything that
 * sets them (e.g. a tween) writes straight into that vector.
 */
class Center extends Sprite3D{
    private var dots:Vector.<Number>;
    private var index:int;

    public function Center(col:uint, dots:Vector.<Number>, index:int){
        this.dots = dots;
        this.index = index;
        // 3px white stroke underneath, then a 2px stroke in the group color
        strokeCross(3, 0xffffff);
        strokeCross(2, col);
    }

    public function get ax():Number{ return dots[index + 0]; }
    public function set ax(v:Number):void{ dots[index + 0] = v; }
    public function get ay():Number{ return dots[index + 1]; }
    public function set ay(v:Number):void{ dots[index + 1] = v; }
    public function get az():Number{ return dots[index + 2]; }
    public function set az(v:Number):void{ dots[index + 2] = v; }

    /** Draws the two diagonals of a 10x10 square centered on the origin. */
    private function strokeCross(thickness:Number, col:uint):void{
        var g:Graphics = graphics;
        g.lineStyle(thickness, col);
        g.moveTo(-5, -5);
        g.lineTo(5, 5);
        g.moveTo(5, -5);
        g.lineTo(-5, 5);
        g.endFill();
    }
}
/** Clickable text label on a light-gray rectangle with 5px padding. */
class Button extends Sprite{
    public function Button(label:String){
        buttonMode = true;
        useHandCursor = true;
        mouseChildren = false;

        var caption:TextField = new TextField();
        caption.text = label;
        caption.autoSize = "left";
        caption.selectable = false;
        caption.y = 5;
        caption.x = 5;
        addChild(caption);

        // background sized to the auto-sized caption plus padding on all sides
        var w:Number = caption.width + 10;
        var h:Number = caption.height + 10;
        graphics.beginFill(0xcccccc);
        graphics.drawRect(0, 0, w, h);
        graphics.endFill();
    }
}
/**
 * A white 20pt label followed by a bordered editable text field at x = 220,
 * whose contents are read back as an integer.
 */
class Input extends Sprite{
    private var input:TextField;

    /** Field text parsed as base-10; non-numeric text becomes 0 once coerced to int. */
    public function get value():int{
        var parsed:Number = parseInt(input.text, 10);
        return parsed;
    }

    public function Input(labelStr:String, valueStr:String):void{
        var format:TextFormat = new TextFormat();
        format.size = 20;

        var label:TextField = new TextField();
        input = new TextField();

        // label: white, 20pt, auto-sized
        label.textColor = 0xffffff;
        label.defaultTextFormat = format;
        label.text = labelStr;
        label.autoSize = "left";
        addChild(label);

        // editable field with a gray border, placed to the right of the label
        input.textColor = 0xffffff;
        input.defaultTextFormat = format;
        input.border = true;
        input.borderColor = 0x999999;
        input.type = "input";
        input.text = valueStr;
        input.height = 22;
        input.x = 220;
        addChild(input);
    }
}