// 导读:K-Means算法Java代码实现(基于weka的二次开发)。代码节选:package com.hyman.dmtools.clusters; import static com.hyman.dmtools.utils.EqualsForInst.equalsInstance; import static com.hyman.dmtools.utils.EqualsForInst.equalsInstances; import java.io.Serial...
// K-Means算法Java代码实现(基于weka的二次开发)
package com.hyman.dmtools.clusters;
import static com.hyman.dmtools.utils.EqualsForInst.equalsInstance;
import static com.hyman.dmtools.utils.EqualsForInst.equalsInstances;
import java.io.Serializable;
import java.util.Enumeration;
import java.util.Random;
import com.hyman.dmtools.utils.distance.DistanceFunctionFactory;
import com.hyman.dmtools.utils.distance.NormalizableDistance;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
/**
 * K-Means clusterer built on the Weka {@code Instances}/{@code Instance}
 * data structures.
 *
 * <p>Typical use: configure with {@link #setNumClusters(int)} and
 * {@link #setSeed(int)}, then call {@link #getKMeansResult(Instances)}.
 *
 * <p>Layout of the returned clusters: element 0 of every cluster is the
 * centroid that cluster was built around; the instances assigned to the
 * cluster follow from index 1 on.
 */
public class K_Means implements Serializable {

    private static final long serialVersionUID = 6969463296146523940L;

    /** Number of clusters to build. */
    private int numClusters = 2;

    /**
     * Centroids of the current iteration. Kept per instance: a static field
     * would be shared by every clusterer and skipped by serialization.
     */
    private Instances clusterCentroids;

    /** Seed applied to a freshly constructed clusterer. */
    private int seedDefault = 10;

    /** Seed for the random choice of the initial centroids. */
    private int seed = seedDefault;

    /** Name handed to {@link DistanceFunctionFactory} to obtain the distance function. */
    private String distanceFunctionType = "EuclideanDistance";

    /** Distance function used to assign instances to centroids. */
    private NormalizableDistance distanceFunction = DistanceFunctionFactory
            .getDistanceFunction(distanceFunctionType);

    /** Clusters produced by the most recent assignment step. */
    private Instances[] kmeansResult;

    /** Upper bound (exclusive) of the iteration counter. */
    private int maxIterations = 500;

    /** Start value of the iteration counter. */
    private int beginIterations = 0;

    /** Creates a clusterer with 2 clusters and the default seed (10). */
    public K_Means() {
        super();
        // Assign directly instead of calling the overridable setSeed() from a constructor.
        seed = seedDefault;
    }

    /**
     * Randomly picks the initial centroids from {@code data} and stores them
     * as the current centroids.
     *
     * <p>Centroids are drawn without replacement (partial Fisher-Yates
     * shuffle on a copy of the data), so the same row is never chosen twice
     * and {@code data} itself is not reordered. When fewer instances than
     * requested clusters are available, every instance becomes a centroid.
     *
     * @param data the data set to draw from
     * @return the chosen centroids (sharing the header of {@code data})
     * @throws Exception if {@code data} contains no instances
     */
    public Instances getCentroid(Instances data) throws Exception {
        if (data.numInstances() == 0) {
            throw new Exception("输入数据为空值!请检查数据集文件");
        }
        Random random = new Random(seed);
        Instances candidates = new Instances(data); // working copy; swaps below must not touch the caller's data
        int wanted = Math.min(numClusters, candidates.numInstances());
        clusterCentroids = new Instances(data, numClusters);
        System.out.println("-------------聚类中心-----------");
        for (int last = candidates.numInstances() - 1; clusterCentroids.numInstances() < wanted; last--) {
            // nextInt(last + 1) covers [0, last]; nextInt(last) would throw once last == 0.
            int pick = random.nextInt(last + 1);
            clusterCentroids.add(candidates.instance(pick));
            // Move the chosen row behind the shrinking window so it cannot be drawn again.
            candidates.swap(pick, last);
        }
        System.out.println(clusterCentroids);
        System.out.println("------------------------------------");
        return clusterCentroids;
    }

    /**
     * Assigns every instance of {@code data} to its nearest centroid.
     *
     * @param data      the instances to assign
     * @param centroids the cluster centres
     * @return one {@code Instances} per centroid; element 0 is the centroid
     *         itself, followed by the instances assigned to it
     */
    public Instances[] createCluster(Instances data, Instances centroids) {
        int k = centroids.numInstances();
        Instances[] clusters = new Instances[k];
        for (int c = 0; c < k; c++) {
            // Header of data, initial capacity 0; the centroid is stored first.
            clusters[c] = new Instances(data, 0);
            clusters[c].add(centroids.instance(c));
        }
        if (data.numInstances() > 0) {
            // Loop-invariant: register the data source once instead of once per
            // (instance, centroid) pair.
            distanceFunction.setInstances(data);
        }
        for (int i = 0; i < data.numInstances(); i++) {
            Instance current = data.instance(i);
            double[] distances = new double[k];
            for (int j = 0; j < k; j++) {
                Instance centroid = centroids.instance(j);
                // An instance equal to the centroid keeps the default distance 0.
                if (!equalsInstance(current, centroid)) {
                    distances[j] = distanceFunction.distance(current, centroid);
                }
            }
            clusters[Utils.minIndex(distances)].add(current);
        }
        return clusters;
    }

    /**
     * Computes the attribute-wise arithmetic mean of {@code data}.
     *
     * <p>Sums are accumulated as {@code double}; integer accumulators would
     * truncate every attribute value and the resulting mean.
     *
     * @param data the instances to average; must not be empty
     * @return a new instance holding the mean of every attribute
     * @throws IllegalArgumentException if {@code data} contains no instances
     */
    public Instance meanCentroid(Instances data) {
        int count = data.numInstances();
        if (count == 0) {
            throw new IllegalArgumentException("Cannot compute the mean of an empty set of instances");
        }
        Instance meanIns = new Instance(data.numAttributes());
        for (int i = 0; i < data.numAttributes(); i++) {
            double sum = 0.0;
            for (int j = 0; j < count; j++) {
                sum += data.instance(j).value(i);
            }
            meanIns.setValue(data.attribute(i), sum / count);
        }
        return meanIns;
    }

    /**
     * Replaces the centroid at position {@code index} with {@code instance}:
     * the new instance is appended, swapped into place, and the old centroid
     * (now last) is deleted.
     *
     * @return the same {@code centroids} object, modified in place
     */
    private Instances updateCentroids(Instances centroids, Instance instance,
            int index) {
        centroids.add(instance);
        int lastIndex = centroids.numInstances() - 1;
        centroids.swap(index, lastIndex);
        centroids.delete(lastIndex);
        return centroids;
    }

    /**
     * Runs K-Means on {@code data}: picks initial centroids, then alternates
     * between recomputing centroids and reassigning instances until the
     * centroids stop changing or the iteration limit is reached. Progress and
     * the final clusters are printed to standard output.
     *
     * @param data the data set to cluster
     * @return the final clusters, laid out as described in
     *         {@link #createCluster(Instances, Instances)}
     * @throws Exception if {@code data} contains no instances
     */
    public Instances[] getKMeansResult(Instances data) throws Exception {
        clusterCentroids = getCentroid(data); // (1) initial centroids
        kmeansResult = createCluster(data, clusterCentroids); // (2) first assignment
        for (int i = beginIterations; i < maxIterations; i++) {
            Instances newCentroids = new Instances(clusterCentroids);
            for (int j = 0; j < kmeansResult.length; j++) {
                Instances cluster = kmeansResult[j];
                // Element 0 is the old centroid, not a member: averaging it in
                // would drag the new centroid towards the previous one.
                int memberCount = cluster.numInstances() - 1;
                if (memberCount <= 0) {
                    continue; // nothing assigned: keep the previous centroid
                }
                Instances members = new Instances(cluster, 1, memberCount);
                newCentroids = updateCentroids(newCentroids, meanCentroid(members), j);
            }
            System.out.println("迭代......" + (i + 1));
            if (equalsInstances(newCentroids, clusterCentroids)) {
                System.out.println("中心点集合不再变化!迭代次数" + (i + 1));
                break;
            }
            clusterCentroids = newCentroids;
            kmeansResult = createCluster(data, clusterCentroids);
        }
        System.out.println("聚类结果:");
        for (Instances cluster : kmeansResult) {
            for (int k = 0; k < cluster.numInstances(); k++) {
                System.out.println(cluster.instance(k));
            }
            System.out.println();
        }
        return kmeansResult;
    }

    /** Placeholder, not implemented yet: always returns 0. */
    private double getSquaredError() {
        // TODO Auto-generated method stub
        return 0;
    }

    /** Placeholder, not implemented yet: always returns 0. */
    private int clusterProcessedInstance(Instance instance, boolean updateErrors) {
        return 0;
    }

    /** @return the current centroids, or {@code null} before any have been computed */
    public Instances getClusterCentroids() {
        return clusterCentroids;
    }

    private void setClusterCentroids(Instances clusterCentroids) {
        this.clusterCentroids = clusterCentroids;
    }

    /** @return the number of clusters to build */
    public int getNumClusters() {
        return numClusters;
    }

    /**
     * Sets the number of clusters to build.
     *
     * @param numClusters the cluster count, at least 1
     * @throws IllegalArgumentException if {@code numClusters} is smaller than 1
     */
    public void setNumClusters(int numClusters) {
        if (numClusters < 1) {
            throw new IllegalArgumentException("numClusters must be at least 1, got " + numClusters);
        }
        this.numClusters = numClusters;
    }

    /** @return the seed used when picking the initial centroids */
    public int getSeed() {
        return seed;
    }

    /** @param seed the seed used when picking the initial centroids */
    public void setSeed(int seed) {
        this.seed = seed;
    }

    public static void main(String[] args) {
        new K_Means().setNumClusters(7);
    }

    /** @return the name of the configured distance function */
    public String getDistanceFunctionType() {
        return distanceFunctionType;
    }

    /**
     * Selects the distance function by name and re-creates the function
     * object through {@link DistanceFunctionFactory}, so the new type is
     * actually used by later clustering runs.
     *
     * @param distanceFunctionType name understood by the factory
     * @throws IllegalArgumentException if the factory returns {@code null}
     *         for the given name
     */
    public void setDistanceFunctionType(String distanceFunctionType) {
        NormalizableDistance resolved = DistanceFunctionFactory
                .getDistanceFunction(distanceFunctionType);
        if (resolved == null) {
            throw new IllegalArgumentException(
                    "Unsupported distance function type: " + distanceFunctionType);
        }
        this.distanceFunctionType = distanceFunctionType;
        this.distanceFunction = resolved;
    }
}