资源简介
用 Python3 编写的 EM 算法,包含两个程序:一个是 EM 分类,一个是 GMM 应用。
EM 算法推导(收敛性证明及其在 GMM 中的应用)见我的博客:https://blog.csdn.net/kevinoop/article/details/80522477
代码片段和文件信息
# !/usr/bin/python
# -*- coding:utf-8 -*-
import numpy as np
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture
from mpl_toolkits.mplot3d import Axes3D
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances_argmin
# Select the SimHei font so the Chinese text used in this script's plot
# titles can be rendered, and disable the Unicode minus sign so negative tick
# labels fall back to an ASCII hyphen (presumably because SimHei lacks the
# Unicode minus glyph -- confirm on the target machine).
mpl.rcParams['font.sans-serif'] = ['SimHei']
mpl.rcParams['axes.unicode_minus'] = False
def fit_two_gaussians_em(data, num_iter=100):
    """Fit a two-component Gaussian mixture to ``data`` with plain EM.

    The start point is deterministic: the first mean is the per-column
    minimum of ``data``, the second mean is the per-column maximum, both
    covariances are the identity matrix and the mixing weight is 0.5.
    The current means are printed after every iteration.

    Args:
        data: 2-D array; ``data.shape`` is unpacked as ``(n, d)``.
        num_iter: Number of EM iterations to run (no convergence test).

    Returns:
        Tuple ``(pi, mu1, mu2, sigma1, sigma2)`` where ``pi`` is the weight
        of the first component, ``mu*`` are the component means and
        ``sigma*`` the component covariance matrices.
    """
    n, d = data.shape
    # Alternative start that was disabled in the original script: draw both
    # means at random, e.g. ``np.random.standard_normal(d)``.
    mu1 = data.min(axis=0)
    mu2 = data.max(axis=0)
    sigma1 = np.identity(d)
    sigma2 = np.identity(d)
    pi = 0.5

    for i in range(num_iter):
        # E step: responsibility of component 1 for every sample.
        norm1 = multivariate_normal(mu1, sigma1)
        norm2 = multivariate_normal(mu2, sigma2)
        tau1 = pi * norm1.pdf(data)
        tau2 = (1 - pi) * norm2.pdf(data)
        gamma = tau1 / (tau1 + tau2)

        # M step: responsibility-weighted means, covariances and weight.
        mu1 = np.dot(gamma, data) / np.sum(gamma)
        mu2 = np.dot((1 - gamma), data) / np.sum((1 - gamma))
        # gamma broadcasts over the columns of the (d, n) transposed
        # residuals, so each product is a weighted scatter matrix (d, d).
        sigma1 = np.dot(gamma * (data - mu1).T, data - mu1) / np.sum(gamma)
        sigma2 = np.dot((1 - gamma) * (data - mu2).T, data - mu2) / np.sum(1 - gamma)
        pi = np.sum(gamma) / n
        print(i, ":\t", mu1, mu2)

    return pi, mu1, mu2, sigma1, sigma2


if __name__ == '__main__':
    # 'sklearn' uses GaussianMixture; any other value runs the hand-written
    # EM loop above.
    style = 'myself'

    # Synthetic data: 400 samples from the first Gaussian, 100 from the second.
    np.random.seed(0)
    mu1_fact = (0, 0, 0)
    cov1_fact = np.diag((1, 2, 3))
    data1 = np.random.multivariate_normal(mu1_fact, cov1_fact, 400)
    mu2_fact = (2, 2, 1)
    # NOTE(review): this matrix is not symmetric, so it is not a valid
    # covariance matrix (NumPy's sampler expects a symmetric
    # positive-semidefinite one). Values are kept as in the original --
    # confirm the intended matrix.
    cov2_fact = np.array(((1, 1, 3), (1, 2, 1), (0, 0, 1)))
    data2 = np.random.multivariate_normal(mu2_fact, cov2_fact, 100)
    data = np.vstack((data1, data2))
    # Ground-truth labels: True for the first component's samples.
    y = np.array([True] * 400 + [False] * 100)

    if style == 'sklearn':
        g = GaussianMixture(n_components=2, covariance_type='full',
                            tol=1e-6, max_iter=1000)
        g.fit(data)
        print('类别概率:\t', g.weights_[0])
        print('均值:\n', g.means_, '\n')
        print('方差:\n', g.covariances_, '\n')
        mu1, mu2 = g.means_
        sigma1, sigma2 = g.covariances_
    else:
        num_iter = 100
        pi, mu1, mu2, sigma1, sigma2 = fit_two_gaussians_em(data, num_iter)
        print('类别概率:\t', pi)
        print('均值:\t', mu1, mu2)
        print('方差:\n', sigma1, '\n\n', sigma2, '\n')

    # Classification: evaluate both fitted densities on every sample
    # (unweighted by the mixing proportion, as in the original).
    norm1 = multivariate_normal(mu1, sigma1)
    norm2 = multivariate_normal(mu2, sigma2)
    tau1 = norm1.pdf(data)
    tau2 = norm2.pdf(data)

    # Left panel: the raw samples in 3-D.
    fig = plt.figure(figsize=(10, 5), facecolor='w')
    ax = fig.add_subplot(121, projection='3d')
    ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='b', s=30, marker='o',
               edgecolors='k', depthshade=True)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('原始数据', fontsize=15)
    # Right panel (the rest of its setup lies beyond this excerpt).
    ax = fig.add_subplot(122, projection='3d')
order = pairwise_distances_argmin([mu1_fact mu2_fact] [mu1 mu2] metri
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
文件 2784 2017-07-04 10:01 20.6.GMM_pdf.py
文件 3737 2017-07-04 09:33 20.1.EM.py
评论
共有 条评论