pytorch逻辑回归实现垃圾邮件检测

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np# 增强的数据集:更多的垃圾邮件与正常邮件样本
X = ["Congratulations! You've won a $1000 gift card. Claim it now!","Dear friend, I hope you are doing well. Let's catch up soon.","Urgent: Your bank account has been compromised. Please contact support immediately.","Hello, just wanted to confirm our meeting at 2 PM today.","You have a new message from your friend. Click here to read.","Get a free iPhone now! Limited offer, click here.","Last chance to claim your prize, you won $500!","Meeting scheduled for tomorrow. Please confirm.","Hello! You are invited to an exclusive event!","Click here to get free lottery tickets. Hurry up!","Reminder: Your subscription will expire soon, renew now.","Don't forget to submit your report by end of day today."
]
y = [1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0]  # 1 为垃圾邮件,0 为正常邮件# 使用 TfidfVectorizer 进行文本向量化
vectorizer = TfidfVectorizer(stop_words='english')  # 去除停用词
X_vec = vectorizer.fit_transform(X).toarray()# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_vec, y, test_size=0.33, random_state=42)# 定义逻辑回归模型
class LogisticRegressionModel(nn.Module):def __init__(self, input_dim):super(LogisticRegressionModel, self).__init__()self.fc = nn.Linear(input_dim, 1)  # 线性层,输入维度是特征的数量,输出是1def forward(self, x):return torch.sigmoid(self.fc(x))  # 使用sigmoid激活函数输出0到1之间的概率# 定义训练过程
def train_model(model, X_train, y_train, num_epochs=200, learning_rate=0.001):criterion = nn.BCELoss()  # 二分类交叉熵损失optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # 使用Adam优化器X_train_tensor = torch.tensor(X_train, dtype=torch.float32)y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train_tensor)loss = criterion(outputs, y_train_tensor)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 测试模型
def evaluate_model(model, X_test, y_test):model.eval()X_test_tensor = torch.tensor(X_test, dtype=torch.float32)y_test_tensor = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)with torch.no_grad():outputs = model(X_test_tensor)predictions = (outputs >= 0.5).float()  # 阈值设为0.5accuracy = accuracy_score(y_test, predictions.numpy())print(f'Accuracy: {accuracy * 100:.2f}%')# 训练并评估模型
input_dim = X_train.shape[1]  # 输入特征的数量
model = LogisticRegressionModel(input_dim)
train_model(model, X_train, y_train, num_epochs=200, learning_rate=0.001)
evaluate_model(model, X_test, y_test)# 预测新邮件
def predict(model, new_email):model.eval()new_email_vec = vectorizer.transform([new_email]).toarray()new_email_tensor = torch.tensor(new_email_vec, dtype=torch.float32)with torch.no_grad():prediction = model(new_email_tensor)return "Spam" if prediction >= 0.5 else "Not Spam"# 检测新邮件
email_1 = "Congratulations! You have a limited time offer for a free cruise."
email_2 = "Hi, let's discuss the project updates tomorrow."print(f"Email 1: {predict(model, email_1)}")  # 可能输出:Spam
print(f"Email 2: {predict(model, email_2)}")  # 可能输出:Not Spam
1. 数据预处理
  • 准备数据集:包含垃圾邮件(Spam)和正常邮件(Not Spam)。
  • 文本向量化:使用 TfidfVectorizer 将文本转换为数值特征,使模型能够处理。
  • 去除停用词:排除无意义的常见词(如 "the", "is", "and"),提高模型性能。
2. 训练集与测试集划分
  • 将数据集拆分为训练集和测试集,以 67% 训练,33% 测试,保证模型有足够数据训练,同时可以评估其泛化能力。
3. 逻辑回归模型
  • 搭建 PyTorch 逻辑回归模型
    • 采用 nn.Linear() 构建一个单层神经网络(输入为文本特征,输出为 1 个数值)。
    • 使用 sigmoid 作为激活函数,将输出转换为 0-1 之间的概率值。
4. 训练模型
  • 定义损失函数:使用二元交叉熵损失 (BCELoss),适用于二分类问题。
  • 优化器:采用 Adam 优化器,以 0.001 学习率进行参数优化。
  • 训练流程
    1. 计算前向传播的输出。
    2. 计算损失值,衡量预测结果与真实标签的差距。
    3. 进行反向传播,更新权重参数。
    4. 迭代多轮(如 200 轮),不断优化模型。
5. 评估模型
  • 将测试数据输入模型,预测结果并与真实标签进行对比。
  • 计算准确率,评估模型在未见过的数据上的表现。
6. 预测新邮件
  • 将新邮件转换为数值特征(与训练时相同的方法)。
  • 使用训练好的模型进行预测
  • 阈值判断:如果输出概率 ≥ 0.5,则判断为垃圾邮件,否则为正常邮件。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/9346.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【漫话机器学习系列】064.梯度下降小口诀(Gradient Descent rule of thume)

梯度下降小口诀 为了帮助记忆梯度下降的核心原理和关键注意事项,可以用以下简单口诀来总结: 1. 基本原理 损失递减,梯度为引:目标是让损失函数减少,依靠梯度指引方向。负梯度,反向最短:沿着负…

Autogen_core 测试代码:test_cache_store.py

目录 原始代码测试代码代码中用到的typing注解 原始代码 from typing import Dict, Generic, Optional, Protocol, TypeVarT TypeVar("T")class CacheStore(Protocol, Generic[T]):"""This protocol defines the basic interface for store/cache o…

文件上传2

BUUCTF 你传你🐎呢 先上传.htaccess 修改格式 即可上传成功 返回上传图片格式的木马 用蚁剑连接 5ecf1cca-59a1-408b-b616-090edf124db5.node5.buuoj.cn:81/upload/7d8511a847edeacb5385299396a96d91/rao.jpg 即可得到flag [GXYCTF2019]BabyUpload

挂载mount

文章目录 1.挂载的概念(1)挂载命令:mount -t nfs(2)-t 选项:指定要挂载的文件系统类型(3)-o选项 2.挂载的目的和作用(1)跨操作系统访问:将Windows系统内容挂载到Linux系统下(2)访问外部存储设备(3)整合不同的存储设备 3.文件系统挂载要做的事…

UE求职Demo开发日志#15 思路与任务梳理、找需要的资源

1 思路梳理 因为有点无从下手,就梳理下最终形态. 基地的建设我是想单独一个场景,同一个关卡中小怪会每次来都会刷,小解密一次性的,关键的Boss和精英怪不会重复刷,同时场景里放一些资源可收集,基地建设锁定区…

vulfocus/thinkphp:6.0.12 命令执行

本次测试是在vulfocus靶场上进行 漏洞介绍 在其6.0.13版本及以前,存在一处本地文件包含漏洞。当多语言特性被开启时,攻击者可以使用lang参数来包含任意PHP文件。 虽然只能包含本地PHP文件,但在开启了register_argc_argv且安装了pcel/pear的环境下,可以包含/usr/local/lib/…

洛谷P3884 [JLOI2009] 二叉树问题(详解)c++

题目链接:P3884 [JLOI2009] 二叉树问题 - 洛谷 | 计算机科学教育新生态 1.题目解析 1:从8走向6的最短路径,向根节点就是向上走,从8到1会经过三条边,向叶节点就是向下走,从1走到6需要经过两条边&#xff0c…

如何获取小程序的code在uniapp开发中

如何获取小程序的code在uniapp开发中,也就是本地环境,微信开发者工具中获取code,这里的操作是页面一进入就获取code登录,没有登录页面的交互,所以写在了APP.vue中,也就是小程序一打开就获取用户的code APP.…

k8s支持自定义field-selector spec.hostNetwork过滤

好久没写博客啦,年前写一个博客就算混过去啦😂 写一个小功能,对于 Pod,在没有 label 的情况下,支持 --field-selector spec.hostNetwork 查询 Pod 是否为 hostNetwork 类型,只为了熟悉 APIServer 是如何构…

olloama下载deepseek-r1大模型本地部署

1.登录olloama,选择models,选择deepseek-r1模型,选择1.5b(核显电脑) 2.选择1.5b,复制命令,打开CMD控制台; 3.控制台输入ollama run deepseek-r1:1.5b自动下载 4.部署完成 5.退出【Ctrl d】or 【/bye】 …

C语言初阶力扣刷题——349. 两个数组的交集【难度:简单】

1. 题目描述 力扣在线OJ题目 给定两个数组,编写一个函数来计算它们的交集。 示例: 输入:nums1 [1,2,2,1], nums2 [2,2] 输出:[2] 输入:nums1 [4,9,5], nums2 [9,4,9,8,4] 输出:[9,4] 2. 思路 直接暴力…

python学opencv|读取图像(四十九)使用cv2.bitwise()系列函数实现图像按位运算

【0】基础定义 按位与运算:两个等长度二进制数上下对齐,全1取1,其余取0。 按位或运算:两个等长度二进制数上下对齐,有1取1,其余取0。 按位异或运算: 两个等长度二进制数上下对齐,相…

ZZNUOJ(C/C++)基础练习1011——1020(详解版)

1011 : 圆柱体表面积 题目描述 输入圆柱体的底面半径r和高h,计算圆柱体的表面积并输出到屏幕上。要求定义圆周率为如下宏常量 #define PI 3.14159 输入 输入两个实数,表示圆柱体的底面半径r和高h。 输出 输出一个实数,即圆柱体的表面积&…

HTML特殊符号的使用示例

目录 一、基本特殊符号的使用 1、空格符号: 2、小于号 和 大于号: 3、引号: 二、版权、注册商标符号的使用 1、版权符号:© 2、注册商标符号: 三、数学符号的使用 四、箭头符号的使用 五、货币符号的使用…

java基础-容器

一、集合基础 1、集合 Collection接口下,主要用于存放单一元素Map接口下,用于存放键值对 2、常见集合的比较 List 存储的元素是有序的、可重复的。Set: 存储的元素不可重复的。Queue: 按特定的排队规则来确定先后顺序,存储的元素是有序的、…

嵌入式知识点总结 ARM体系与架构 专题提升(三)-中断与异常

针对于嵌入式软件杂乱的知识点总结起来,提供给读者学习复习对下述内容的强化。 目录 1.中断与异常有何区别? 2.中断与DMA有何区别? 3.中断能不能睡眠,为什么?下半部能不能睡眠? 4.中断的响应执行流程是什么&#…

从替代到覆盖:暴雨信创服务器打开市场新局面

进入2025年,全球局势更加变幻莫测,高科技领域越来越受到全球局势影响。美国前任总统拜登在卸任前,特别颁布限制GPU产品出口法案。新任总统特朗普上任第一天,废除了多项之前法案,但显示技术交流的内容一条没变。 在如此艰难的局面下,我国信创市场的发展显得尤为重要,国家也从政策…

机器人抓取与操作经典规划算法(深蓝)——2

1 经典规划算法 位姿估计:(1)相机系位姿 (2)机器人系位姿 抓取位姿:(1)抓取位姿计算 (2)抓取评估和优化 路径规划:(1)笛卡…

C++二叉树进阶

1.二叉搜索树 1.1二叉搜索树概念 二叉搜索树又称二叉排序树,它或者是一颗空树,或者具有以下性质的二叉树 若它的左子树不为空,则左子树上所有结点的值小于根节点的值若它的右子树不为空,则右子树上所有节点的值都大于根节点的值…

“AI视频智能分析系统:让每一帧视频都充满智慧

嘿,大家好!今天咱们来聊聊一个特别厉害的东西——AI视频智能分析系统。想象一下,如果你有一个超级聪明的“视频助手”,它不仅能自动识别视频中的各种元素,还能根据内容生成详细的分析报告,是不是感觉特别酷…