
如何高效爬取全球新聞網站 – 整合Scrapy、Selenium與Mediastack API實現自動化新聞采集
高斯混合模型(GMM)?是由多個高斯分布混合而成的。假設數據集是由?k?個高斯分布組成的混合模型,那么給定數據點?x,它的概率分布可以表示為每個分布的加權和:
由于我們不知道每個點屬于哪個高斯分布,因此 GMM 采用?EM算法(期望最大化算法)來迭代估計參數。
通過多次迭代,EM算法可以讓這些參數收斂到合適的值。
我們會從 Kaggle 下載一個數據集,使用 GMM 對數據進行分類。為了演示原理,我們使用一個簡單的二維數據集。并且根據原理進行代碼的編寫。
選擇一個簡單的、適合分類的二維數據集,例如 Kaggle 上的?Iris 數據集?或?Mall Customers 數據集。我們將以 Mall Customers 為例,用顧客的年收入和消費得分來進行聚類分析。
我們要畫出 4 個及以上的分析圖,來逐步理解數據和模型效果。
下面是?Mall Customers 數據集?的完整 Python 實現:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
# 讀取數據集
data = pd.read_csv("Mall_Customers.csv")
X = data[['Annual Income (k$)', 'Spending Score (1-100)']].values
# 初始化參數
def initialize_params_fixed(X, K):
n, d = X.shape
pi = np.ones(K) / K # 初始化每個混合成分的權重
mu = X[np.random.choice(n, K, False), :] # 隨機選擇K個初始均值
sigma = np.array([np.eye(d) for _ in range(K)]) # 初始化協方差矩陣為單位矩陣
return pi, mu, sigma
# 計算多元正態分布
def multivariate_gaussian(X, mu, sigma):
return multivariate_normal(mean=mu, cov=sigma).pdf(X)
# E 步:計算每個點屬于每個成分的責任值 (gamma)
def expectation_step_stable(X, pi, mu, sigma):
N = X.shape[0]
K = len(pi)
gamma = np.zeros((N, K))
for k in range(K):
try:
gamma[:, k] = pi[k] * multivariate_gaussian(X, mu[k], sigma[k])
except np.linalg.LinAlgError:
# 如果協方差矩陣是奇異矩陣,加入微小正則化項以確保正定性
sigma[k] += np.eye(X.shape[1]) * 1e-6
gamma[:, k] = pi[k] * multivariate_gaussian(X, mu[k], sigma[k])
# 防止零除錯誤,保證數值穩定性
gamma_sum = np.sum(gamma, axis=1, keepdims=True)
gamma_sum[gamma_sum == 0] = 1e-10 # 防止除以零
gamma = gamma / gamma_sum
return gamma
# M 步:更新GMM的參數
def maximization_step(X, gamma):
N, d = X.shape
K = gamma.shape[1]
Nk = np.sum(gamma, axis=0) # 計算每個聚類的總責任值
pi = Nk / N # 更新混合系數
mu = np.dot(gamma.T, X) / Nk[:, np.newaxis] # 更新均值
sigma = np.zeros((K, d, d)) # 更新協方差矩陣
for k in range(K):
X_centered = X - mu[k]
gamma_diag = np.diag(gamma[:, k])
sigma[k] = np.dot(X_centered.T, np.dot(gamma_diag, X_centered)) / Nk[k]
return pi, mu, sigma
# 計算對數似然
def compute_log_likelihood(X, pi, mu, sigma):
N = X.shape[0]
K = len(pi)
log_likelihood = 0
for n in range(N):
tmp = 0
for k in range(K):
tmp += pi[k] * multivariate_gaussian(X[n], mu[k], sigma[k])
log_likelihood += np.log(tmp)
return log_likelihood
# GMM 實現,包含數值穩定性修復
def gmm_fixed_stable(X, K, max_iter=100, tol=1e-6):
pi, mu, sigma = initialize_params_fixed(X, K)
log_likelihoods = []
for i in range(max_iter):
# E 步
gamma = expectation_step_stable(X, pi, mu, sigma)
# M 步
pi, mu, sigma = maximization_step(X, gamma)
# 添加小的正則化項,確保協方差矩陣為正定
sigma += np.eye(sigma.shape[1]) * 1e-6
# 計算對數似然
log_likelihood = compute_log_likelihood(X, pi, mu, sigma)
log_likelihoods.append(log_likelihood)
# 檢查是否收斂
if i > 0 and abs(log_likelihoods[-1] - log_likelihoods[-2]) < tol:
break
return pi, mu, sigma, log_likelihoods, gamma
# 數據可視化:原始數據分布
def plot_original_data(X):
plt.scatter(X[:, 0], X[:, 1], c='blue', label='Data points', alpha=0.5)
plt.title('Original Data Distribution')
plt.xlabel('Annual Income (k$)')
plt.ylabel('Spending Score (1-100)')
plt.show()
# 分類結果展示
def plot_clusters(X, gamma, mu):
K = gamma.shape[1]
colors = ['r', 'g', 'b', 'y', 'm']
for k in range(K):
plt.scatter(X[:, 0], X[:, 1], c=gamma[:, k], cmap='viridis', label=f'Cluster {k+1}', alpha=0.6)
plt.scatter(mu[:, 0], mu[:, 1], c='black', marker='x', s=100, label='Centroids')
plt.title('GMM Clustering')
plt.xlabel('Annual Income (k$)')
plt.ylabel('Spending Score (1-100)')
plt.legend()
plt.show()
# 對數似然收斂圖
def plot_log_likelihood(log_likelihoods):
plt.plot(log_likelihoods)
plt.title('Log Likelihood Convergence')
plt.xlabel('Iterations')
plt.ylabel('Log Likelihood')
plt.show()
# 各類別概率分布圖
def plot_probability_distributions(gamma):
K = gamma.shape[1]
for k in range(K):
plt.hist(gamma[:, k], bins=20, alpha=0.5, label=f'Cluster {k+1}')
plt.title('Probability Distributions for Each Cluster')
plt.xlabel('Probability')
plt.ylabel('Number of Points')
plt.legend()
plt.show()
# 運行 GMM 算法
K = 3 # 假設數據有 3 個聚類
pi, mu, sigma, log_likelihoods, gamma = gmm_fixed_stable(X, K)
# 繪制圖形
plot_original_data(X) # 原始數據分布圖
plot_clusters(X, gamma, mu) # 分類結果圖
plot_log_likelihood(log_likelihoods) # 對數似然收斂圖
plot_probability_distributions(gamma) # 各類別概率分布圖
代碼部分細節解釋:
1. 原始數據分布圖:展示客戶年收入和消費得分的散點圖,幫助我們直觀理解數據分布情況。
2. 分類結果圖:展示 GMM 分類后每個客戶所屬的類別,以及每個類別的均值點(質心)。
3. 對數似然收斂圖:展示對數似然值的收斂過程,判斷模型是否收斂。
4. 各類別概率分布圖:展示不同類別的概率分布,幫助理解分類的置信度。
通過高斯混合模型(GMM)的推導與 Python 實現,咱們完成了從基礎原理到實際應用的完整過程。有問題,大家可以評論區討論~