• 大小: 3KB
    文件类型: .py
    金币: 2
    下载: 1 次
    发布日期: 2021-06-04
  • 语言: Python
  • 标签: gibbs  sampling  

资源简介

Gibbs Sampling代码,python实现,代码仅有对一元线性回归的简单情况,配合博文【ML】线性回归的吉布斯采样(Gibbs Sampling)实现(python)使用

资源截图

代码片段和文件信息

import numpy  as np
import pandas as pd
from pandas import Dataframe
import seaborn as sns
import matplotlib.pyplot as plt
from numpy import random

plt.rcParams[‘figure.figsize‘]=(10 5)

def sample_beta_0(y x beta_1 tau mu_0 tau_0):
    assert len(x)==len(y)
    N = len(y)
    precision=tau_0 + tau*N # 精度
    mean = tau_0*mu_0+tau*np.sum(y-beta_1*x) # 期望
    mean /= precision
    return random.normal(mean 1/np.sqrt(precision)) # 得到beta_0采样


def sample_beta_1(y x beta_0 tau mu_1 tau_1):
    assert len(x)==len(y)
    precision=tau_1+tau*np.sum(x*x)
    mean=tau_1*mu_1+tau*np.sum((y-beta_0)*x)
    mean/=precision
    return random.normal(mean 1/np.sqrt(precision))    

def sample_tau(y x beta_0 beta_1 alpha beta N):
    assert len(x)==len(y)
    alpha_new=alpha+N/2
    resid=y-beta_0-beta_1*x
    beta_new=beta+np.sum(resid**2)/2
    return random.gamma(alpha_new 1/beta_new)


def synthetic_data():
    beta_0_true=-1
    beta_1_true=2
    tau_true=1

    N=50
    x=random.uniform(low=0 high=4 size=N)
    y=random.normal(beta_0_true+beta_1_true*x 1/np.sqrt(tau_true))

    # syn_plt=plt.plot(x y ‘o‘)
    # plt.xlabel(‘x(uni. dist.)‘)
    # plt.ylabel(‘y(normal dist.)‘)
    # plt.grid(True)
    # plt.show()

    return x y N

x y N = synthetic_data()

# 设置参数起点
init={‘beta_0‘:0 ‘beta_1‘:0 ‘tau‘:2}
# 超参数
hypers={‘mu_0‘: 0 ‘tau_0‘:1 ‘mu_1‘:0 ‘tau_1‘:1 ‘alpha‘:2 

评论

共有 条评论