资源简介
Gibbs Sampling（吉布斯采样）代码，Python 实现。代码仅涵盖一元线性回归的简单情况，需配合博文《【ML】线性回归的吉布斯采样(Gibbs Sampling)实现(python)》使用。
代码片段和文件信息
import numpy as np
import pandas as pd
from pandas import DataFrame
import seaborn as sns
import matplotlib.pyplot as plt
from numpy import random
# Default width/height (in inches) for every matplotlib figure this script draws.
plt.rcParams['figure.figsize'] = (10, 5)
def sample_beta_0(y, x, beta_1, tau, mu_0, tau_0, *, rng=None):
    """Draw the intercept beta_0 from its full conditional distribution.

    This is the conjugate Normal update: given the current slope ``beta_1``
    and noise precision ``tau``, beta_0 is Normal with

        precision = tau_0 + tau * N
        mean      = (tau_0 * mu_0 + tau * sum(y - beta_1 * x)) / precision

    Args:
        y: Observed responses (array-like of length N).
        x: Observed predictors (array-like, same length as ``y``).
        beta_1: Current value of the slope.
        tau: Current noise precision.
        mu_0: Prior mean of beta_0.
        tau_0: Prior precision of beta_0.
        rng: Optional ``numpy.random.Generator``. When omitted, the global
            ``numpy.random`` state is used (the original behaviour).

    Returns:
        One sample of beta_0.

    Raises:
        ValueError: If ``x`` and ``y`` have different lengths.
    """
    # asarray lets plain lists work: ``beta_1 * list`` would raise TypeError.
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    n = len(y)

    precision = tau_0 + tau * n
    mean = (tau_0 * mu_0 + tau * np.sum(y - beta_1 * x)) / precision

    draw = random if rng is None else rng
    # NumPy's normal takes a standard deviation, i.e. 1 / sqrt(precision).
    return draw.normal(mean, 1 / np.sqrt(precision))
def sample_beta_1(y, x, beta_0, tau, mu_1, tau_1, *, rng=None):
    """Draw the slope beta_1 from its full conditional distribution.

    This is the conjugate Normal update: given the current intercept
    ``beta_0`` and noise precision ``tau``, beta_1 is Normal with

        precision = tau_1 + tau * sum(x ** 2)
        mean      = (tau_1 * mu_1 + tau * sum((y - beta_0) * x)) / precision

    Args:
        y: Observed responses (array-like of length N).
        x: Observed predictors (array-like, same length as ``y``).
        beta_0: Current value of the intercept.
        tau: Current noise precision.
        mu_1: Prior mean of beta_1.
        tau_1: Prior precision of beta_1.
        rng: Optional ``numpy.random.Generator``. When omitted, the global
            ``numpy.random`` state is used (the original behaviour).

    Returns:
        One sample of beta_1.

    Raises:
        ValueError: If ``x`` and ``y`` have different lengths.
    """
    # asarray lets plain lists work: ``x * x`` on a list would raise TypeError.
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")

    precision = tau_1 + tau * np.sum(x * x)
    mean = (tau_1 * mu_1 + tau * np.sum((y - beta_0) * x)) / precision

    draw = random if rng is None else rng
    # NumPy's normal takes a standard deviation, i.e. 1 / sqrt(precision).
    return draw.normal(mean, 1 / np.sqrt(precision))
def sample_tau(y, x, beta_0, beta_1, alpha, beta, N=None, *, rng=None):
    """Draw the noise precision tau from its full conditional distribution.

    This is the conjugate Gamma update: given beta_0 and beta_1, tau is Gamma
    with

        shape = alpha + N / 2
        rate  = beta + sum((y - beta_0 - beta_1 * x) ** 2) / 2

    Args:
        y: Observed responses (array-like of length N).
        x: Observed predictors (array-like, same length as ``y``).
        beta_0: Current value of the intercept.
        beta_1: Current value of the slope.
        alpha: Prior Gamma shape.
        beta: Prior Gamma rate.
        N: Number of observations. Defaults to ``len(y)`` when omitted.
        rng: Optional ``numpy.random.Generator``. When omitted, the global
            ``numpy.random`` state is used (the original behaviour).

    Returns:
        One sample of tau.

    Raises:
        ValueError: If ``x`` and ``y`` have different lengths.
    """
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    if N is None:
        N = len(y)

    resid = y - beta_0 - beta_1 * x
    shape = alpha + N / 2
    rate = beta + np.sum(resid ** 2) / 2

    draw = random if rng is None else rng
    # NumPy parameterises the Gamma by scale, which is 1 / rate.
    return draw.gamma(shape, 1 / rate)
def synthetic_data(beta_0_true=-1, beta_1_true=2, tau_true=1, N=50, *,
                   low=0, high=4, rng=None):
    """Generate a toy data set from y = beta_0 + beta_1 * x + noise.

    The predictors x are drawn Uniform[low, high). The responses y are drawn
    Normal with mean ``beta_0_true + beta_1_true * x`` and standard deviation
    ``1 / sqrt(tau_true)``.

    Args:
        beta_0_true: True intercept.
        beta_1_true: True slope.
        tau_true: True noise precision.
        N: Number of observations.
        low: Lower bound of the uniform predictor range.
        high: Upper bound (exclusive) of the uniform predictor range.
        rng: Optional ``numpy.random.Generator``. When omitted, the global
            ``numpy.random`` state is used (the original behaviour).

    Returns:
        Tuple ``(x, y, N)`` of two length-N arrays and the sample size.
    """
    draw = random if rng is None else rng
    x = draw.uniform(low=low, high=high, size=N)
    y = draw.normal(beta_0_true + beta_1_true * x, 1 / np.sqrt(tau_true))
    return x, y, N
# Simulated observations the sampler will be fitted to.
x, y, N = synthetic_data()
# Starting values of the Gibbs chain for each parameter.
init = {'beta_0': 0, 'beta_1': 0, 'tau': 2}
# 超参数
hypers={‘mu_0‘: 0 ‘tau_0‘:1 ‘mu_1‘:0 ‘tau_1‘:1 ‘alpha‘:2
- 上一篇:python_16to8
- 下一篇:python爬虫 抓取页面图片
评论
共有 条评论