• 大小: 133KB
    文件类型: .rar
    金币: 2
    下载: 2 次
    发布日期: 2021-08-07
  • 语言: Python
  • 标签: Fisher  LDA  python  

资源简介

Fisher算法python实现含训练集和测试集,Fisher算法也叫Fisher判别分析,或线性判别分析(LDA)

资源截图

代码片段和文件信息

import numpy as np
import matplotlib.pyplot as plt
import operator
from os import listdir

#读取数据,并根据类别分类
def readdata(filename):
    fr = open(filename)
    numberOfLines = len(fr.readlines())         #获取数据行数
    data = np.zeros((numberOfLines2))        
    label = []                          
    fr = open(filename)
    index = 0
    for line in fr.readlines():
        line = line.strip()
        listFromLine = line.split()
        data[index0] = float(listFromLine[0])
        data[index1] = float(listFromLine[1])
        label.append(float(listFromLine[-1]))
        index += 1
    #分类
    index1 = np.array([index for (index value) in enumerate(label) if value == -1.0])  
    index2 = np.array([index for (index value) in enumerate(label) if value == 1.0]) 
    data0=data[index1]
    data1=data[index2]
    return data0data1

def calculatesi(dataiui):
    si = np.zeros((datai.shape[1] datai.shape[1]))
    for xi in datai:
        m = xi - ui
        si += m*m.reshape(2 1)
    return si

def fish(data0data1):
    #计算均值向量ui
    u0 = np.mean(data0 axis=0)
    u1 = np.mean(data1 axis=0)
    #计算类内离散度矩阵si
    s0 = calculatesi(data0u0)
    s1 = calculatesi(data1u1)
    #总类内离散度矩阵
    sw = s0 + s1
    #求逆
    sw_inv = np.linalg.inv(sw)
    #计算投影w
    w = np.dot(sw_inv(u0-u1))
    w0 = (np.dot(w.Tu0)+np.dot(w.Tu0))/2
    return wu0u1

def judge(filenamewu0u1):
    #读取数据
    fr = open(filename)
    numberOfLines = len(fr.readlines())         #获取数据行数
    test_data = np.zeros((numberOfLines2))    
    label = []                             
    fr = open(filename)
    index = 0
    for line in fr.readlines():
        line = line.strip()
        listFromLine = line.split()
        test_data[index0] = float(listFromLine[0])
        test_data[index1] = float(listFromLine[1])
        index += 1
    #判断类别
    center_0 = np.dot(w.Tu0)
    center_1 = np.dot(w.Tu1)
    for s in test_data:
        y = np.dot(w.Ts)
        #print(y0)
        if abs(y - center_0) < abs(y - center_1):
            label.append(-1.0)
        else:
            label.append(1.0)

    #分类
    index1 = np.array([index for (index value) in enumerate(label) if value == -1.0])  
    index2 = np.array([index for (index value) in enumerate(label) if value == 1.0]) 
    test_data0=test_data[index1]
    test_data1=test_data[index2]
    return test_data0test_data1   
    

def draw(data0data1w):
    plt.scatter(data0[: 0] data0[: 1] c=‘red‘marker=‘x‘)
    plt.scatter(data1[: 0] data1[: 1] c=‘blue‘marker=‘x‘)
    plt.show()

if __name__ == ‘__main__‘:  
    #读取数据集并根据数据类别分类 
    data0data1 = readdata(“train_data.txt“)
    #计算最佳投影w
    wu0u1=fish(data0data1)
    #判断测试集
    test_data0test_data1 = judge(“test_data.txt“wu0u1)
    #绘图
    draw(test_data0test_data1w)






 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----

     文件       3098  2020-03-18 10:58  Fisher算法\fish.py

     文件     308000  2017-08-16 18:42  Fisher算法\test_data.txt

     文件     154000  2017-08-16 18:42  Fisher算法\train_data.txt

     目录          0  2020-04-01 17:48  Fisher算法

----------- ---------  ---------- -----  ----

               465098                    4


评论

共有 条评论