论文复现——AutoRec: Autoencoders Meet Collaborative Filtering
时间:2021-08-16
本文章向大家介绍论文复现——AutoRec: Autoencoders Meet Collaborative Filtering,主要包括论文复现——AutoRec: Autoencoders Meet Collaborative Filtering使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
《AutoRec: Autoencoders Meet Collaborative Filtering》是2015年Suvash等人发表在“The Web Conference”会议上的一篇论文,作者提出用自编码器预测用户对电影的评分。论文比较短,只有两页,可以说是深度学习在推荐系统领域应用的开端。
ABSTRACT
本文提出了一个新颖的基于自编码器的协同过滤框架——AutoRec。实验表明,AutoRec在Movielens数据集上的表现优于目前最好的方法(矩阵分解、受限玻尔兹曼机、LLORMA)。
THE AUTOREC MODEL
假设有\(m\)个用户,\(n\)个商品,并且有用户对商品的评分矩阵\(R\in \mathbb{R}^{m\times n}\),则用户\(u\)对所有商品的评分可以用不完全的向量\(r^{(u)}={R_{u1},...,R_{u2}}\)表示(不完全意思是,\(r^{(u)}\)中的元素有的是真实的评分数据,有的是需要我们预测的)。自编码器的作用就是将\(r^{(u)}\)作为输入数据,经过编码器将向量映射维一个低维的向量,然后通过解码器重构输出向量,使输出向量趋近于输入向量,同时能够补全原始输入向量中的缺失值。自编码器模型可以表示为:
\[min\sum_{r\in S}^{}\left \| r-h(r;\theta ) \right \|^{2}_{2}
\]
代码复现
完整代码及数据集已上传至github
import os
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.utils.data as Data
import matplotlib.pyplot as plt
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
col_name = ["userid", "movieid", "rating", "timestrap"]
u1_base_path = "data/u1.base"
u1_base = pd.read_table(u1_base_path, sep='\t', header=None, names=col_name)
# print(u1_base.head(5))
u1_test_path = "data/u1.test"
u1_test = pd.read_table(u1_test_path, sep='\t', header=None, names=col_name)
# print(u1_test.head(5))
# 将数据转换为 user-item 交互矩阵
def TranslateData(data):
user_num = data.userid.nunique() # 用户的个数
movie_num = 1682 # 电影个数(数据中标明的所有电影数)
data_mat = np.zeros(user_num * movie_num).reshape((-1, movie_num)) + 3
k = 0
for i in range(data.shape[0]):
data_mat[k][data.iloc[i, 1] - 1] = data.iloc[i, 2]
if i > 0 and data.iloc[i, 0] != data.iloc[i - 1, 0]:
k += 1
return data_mat
class AutoRec(nn.Module):
def __init__(self, input_num, hidden_num):
super(AutoRec, self).__init__()
self.input_num = input_num
self.hidden_num = hidden_num
self.encoder = nn.Linear(self.input_num, self.hidden_num, bias=True)
self.relu = nn.ReLU()
self.decoder = nn.Linear(self.hidden_num, self.input_num, bias=True)
def forward(self, x):
hidden = self.encoder(x)
hidden = self.relu(hidden)
out = self.decoder(hidden)
return out
def GetData(data_mat):
dataset = Data.TensorDataset(torch.tensor(data_mat, dtype=torch.float32),
torch.zeros(data_mat.shape[0], 1).view(-1, 1))
loader = Data.DataLoader(
dataset=dataset,
batch_size=64,
shuffle=False
)
return loader
epochs = 100
input_num, hidden_num = 1682, 200
model = AutoRec(input_num, hidden_num)
learning_rate = 0.0003
optimizer = torch.optim.Adam([
{'params': (p for name, p in model.named_parameters() if 'bias' not in name)},
{'params': (p for name, p in model.named_parameters() if 'bias' in name), 'weight_decay': 0.}
], lr=learning_rate, weight_decay=0.001)
loss_func = torch.nn.MSELoss()
loss_train_set = []
loss_test_set = []
def run():
train()
draw(loss_train_set)
def train():
train_data_mat = TranslateData(u1_base)
r = train_data_mat[0]
train_loader = GetData(train_data_mat)
for epoch in range(epochs):
rmse_loss = 0
for step, (X, y) in enumerate(train_loader):
out = model(X)
rmse_loss = torch.sqrt(loss_func(out, X))
rmse_loss.backward()
optimizer.step()
loss_train_set.append(rmse_loss)
if epoch % 100 == 0:
print("epoch %d" % (epoch + 1))
test()
def test():
test_data_mat = TranslateData(u1_test)
test_loader = GetData(test_data_mat)
with torch.no_grad():
rmse_loss = 0
for step, (X, y) in enumerate(test_loader):
out = model(X)
rmse_loss += torch.sqrt(loss_func(out, X))
print("test_loss: %f" % (rmse_loss / test_data_mat.shape[0]))
def draw(loss_train_set):
x = [i for i in range(len(loss_train_set))]
plt.plot(x, loss_train_set, label="Training loss")
plt.xlabel("epochs")
plt.ylabel("rmse")
plt.legend()
plt.show()
if __name__ == "__main__":
run()
原文地址:https://www.cnblogs.com/foghorn/p/15130098.html
- linux学习第四十五篇:Nginx访问日志,Nginx日志切割,静态文件不记录日志和过期时间
- 合格的配置中心应有的素养No.76
- linux学习第四十六篇:Nginx防盗链,Nginx访问控制,Nginx解析php相关配置,Nginx代理
- linux学习第四十七篇:Nginx负载均衡,ssl原理,生产ssl密钥对,Nginx配置ssl
- linux学习第四十八篇:php-fpm的pool,php-fpm慢执行日志,定义open_basedir,php-fpm进程管理
- linux学习第五十一篇:NFS介绍,NFS服务端安装配置,NFS配置选项
- linux学习第五十二篇: exportfs命令,NFS客户端问题,FTP介绍,使用vsftpd搭建ftp服务
- linux学习第五十四篇:Tomcat介绍,安装jdk,安装Tomcat
- linux学习第五十九篇:LVS DR模式搭建,keepalived lvs
- linux学习第五十四篇:配置Tomcat监听80端口,配置Tomcat的虚拟主机,Tomcat日志
- linux学习第五十六篇:集群介绍,keepalived介绍,用keepalived配置高可用集群
- linux学习第五十八篇: 负载均衡集群介绍,LVS介绍,LVS的调度算法,LVS NAT模式搭建
- Python中eval带来的潜在风险,你知道吗?
- React Native自定义导航条
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Shiro学习笔记 三(认证授权)
- Shiro学习笔记四(Shiro集成WEB)
- Shiro学习笔记五(Shiro标签,及通配符)
- Shiro学习笔记六(自定义Reaml-使用数据库设置 user roles permissions)
- Luncene学习 第一天 《入门程序》
- Luncene学习二《搜索索引》
- JavaWeb--简单分页技术
- 使用Python制作第一个爬虫程序
- 使用BeautifulSoup 爬取一个页面上的所有的超链接
- 使用PlaceHolder,测试碰见的问题
- 隐藏MySQL InnoDB Cluster / ReplicaSet实例
- MySQL8.0.21——错误日志中的组复制系统消息
- 【一】、搭建Hadoop环境----本地、伪分布式
- 在组复制中指定恢复IP地址
- START GROUP_REPLICATION可以将恢复凭据作为参数