目录
引言
基本原理
损失函数
参数估计
优缺点
应用
Logistic回归优化算法
具体案例
引言
逻辑回归(Logistic Regression)是一种广泛用于分类问题的统计方法,尤其是二分类问题。尽管名字中有“回归”二字,但它实际上是一种分类算法,主要用于估计一个样本属于某个类别的概率。逻辑回归通过逻辑函数(sigmoid函数)将线性回归模型的输出(通常是一个实数值)映射到(0,1)区间,从而得到属于某个类别的概率。
取定特征组 X∈ R ,称模型为一个Logistic模型,其中w∈ 为模型参数
Logistic回归算法就是以Logistic模型为模型假设,以对数损失为损失函数的经验损失最小化算法。
基本原理
线性回归:首先,逻辑回归通过一组自变量(特征)的线性组合来预测一个因变量(通常是连续值)。但在逻辑回归中,我们不直接使用这个连续值作为预测结果。
Sigmoid函数:线性回归的输出被用作Sigmoid函数的输入。Sigmoid函数是一个S形曲线,它将任意实数值压缩到(0,1)区间内,其公式为:
二分类:在二分类问题中,我们设定一个阈值(通常是0.5),如果Sigmoid函数的输出大于这个阈值,则认为样本属于正类(标签为1的类别),否则属于负类(标签为0的类别)
损失函数
逻辑回归使用对数损失(log loss)或交叉熵损失(cross-entropy loss)作为损失函数。这个损失函数衡量了模型预测的概率分布与真实概率分布之间的差异。损失函数越小,模型的预测越准确。
Logistic回归问题的目标函数称为交叉熵:
交叉熵统计意义:
在线性回归算法中,假设标签分布是以模型预测值为期望的正态分布
在Logistic回归算法中,假设标签分布是以模型概率预测值 为期望的伯努利分布
参数估计
逻辑回归的参数(即线性组合中的系数和截距)通常通过最大似然估计(MLE)或梯度下降等优化算法来求解。最大似然估计的目标是找到一组参数,使得在这组参数下,观测到当前样本集的概率最大。
优缺点
- 优点:
- 实现简单,易于理解和实现。
- 计算代价不高,容易扩展到大规模数据集上。
- 输出结果是一个概率值,可以辅助决策过程。
- 缺点:
- 对数据和场景的适应能力有局限,例如不适合处理非线性问题。
- 对多重共线性数据敏感,可能导致过拟合。
- 分类的类别数不宜过多,主要用于二分类问题。
应用
逻辑回归因其简单、易实现且可解释性强而被广泛应用于各种领域,如:
- 垃圾邮件检测
- 预测用户是否会点击广告
- 医疗诊断(如预测病人是否患有某种疾病)
- 风险评估(如信用评分)
总之,逻辑回归是一种简单而强大的分类算法,特别适用于处理二分类问题,并且其输出具有概率意义,便于理解和应用。
Logistic回归优化算法
具体案例
数据集是著名的鸢尾花(Iris)数据集,它常被用于分类算法的测试和教学。数据集包含了150个样本,每个样本都有4个特征(花萼长度Sepal.Length、花萼宽度Sepal.Width、花瓣长度Petal.Length、花瓣宽度Petal.Width)和一个目标变量(Species),即鸢尾花的种类。在这个数据集中,鸢尾花被分为三种类型:Setosa、Versicolour和Virginica。
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np# 加载数据
data = pd.read_csv('iris.csv')# 假设最后一列是目标变量
X = data.iloc[:, :4]
y = data.iloc[:, 4]# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 使用逻辑回归模型
model = LogisticRegression(random_state=42, max_iter=200) # 增加max_iter以避免收敛警告
model.fit(X_train, y_train)# 预测测试集
y_pred = model.predict(X_test)# 计算准确率、精确率、召回率
accuracy = model.score(X_test, y_test)
precision, recall, _, _ = precision_recall_fscore_support(y_test, y_pred, average='weighted')
print("Accuracy: {:.2%}".format(accuracy))
print("Precision: {:.2%}".format(precision))
print("Recall: {:.2%}".format(recall))# 绘制混淆矩阵
unique_labels = y.unique() # 获取目标列的唯一值
conf_mat = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(10, 7))
plt.imshow(conf_mat, cmap='Blues', interpolation='nearest')
plt.colorbar()
tick_marks = np.arange(len(unique_labels))
plt.xticks(tick_marks, unique_labels, rotation=45)
plt.yticks(tick_marks, unique_labels)
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()