百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术资源 > 正文

【Python深度学习系列】网格搜索神经网络超参数:激活函数

off999 2024-10-23 12:42 38 浏览 0 评论

这是我的第267篇原创文章。

一、引言

在深度学习中,超参数是指在训练模型时需要手动设置的参数,它们通常不能通过训练数据自动学习得到。超参数的选择对于模型的性能至关重要,因此在进行深度学习实验时,超参数调优通常是一个重要的步骤。常见的超参数包括:

  • model.add()
    • neurons(隐含层神经元数量)
    • init_mode(初始权重方法)
    • activation(激活函数)
    • dropout(丢弃率)
  • model.compile()
    • loss(损失函数)
    • optimizer(优化器)
      • learning rate(学习率)
      • momentum(动量)
      • weight decay(权重衰减系数)
  • model.fit()
    • batch size(批量大小)
    • epochs(迭代次数)

一般来说,可以通过手动调优、网格搜索(Grid Search)、随机搜索(Random Search)、自动调参算法方式进行超参数调优,本文采用网格搜索选择激活函数。

二、实现过程

2.1 准备数据

dataset:

dataset = pd.read_csv("data.csv", header=None)
dataset = pd.DataFrame(dataset)
print(dataset)

2.2 数据划分

# 切分数据为输入 X 和输出 Y
X = dataset.iloc[:,0:8]
Y = dataset.iloc[:,8]
# 为了复现,设置随机种子
seed = 7
np.random.seed(seed)
random.set_seed(seed)

2.3 创建模型

需要定义个网格的架构函数create_model,create_model里面的参数要在KerasClassifier这个对象里面存在而且参数名要一致。

def create_model(activation):
    # 创建模型
    model = Sequential()
    model.add(Dense(50, input_shape=(8, ), kernel_initializer='uniform', activation=activation))
    model.add(Dropout(0.2))
    model.add(Dense(1, kernel_initializer='uniform', activation=activation))


    # 编译模型
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
model = KerasClassifier(model=create_model, epochs=100, batch_size=80, verbose=0, activation='relu')

这里使用了scikeras库的KerasClassifier类来定义一个分类器,这里由于KerasClassifier没有定义初始化权重的参数,需要自定义一个表示激活函数的参数activtion,并赋默认值为'relu'。

2.4 定义网格搜索参数

param_grid = {'activation': ['softmax', 'softplus', 'softsign', 'relu',
              'tanh', 'sigmoid', 'hard_sigmoid', 'linear']}

param_grid是一个字典,key是超参数名称,这里的名称必须要在KerasClassifier这个对象里面存在而且参数名要一致。value是key可取的值,也就是要尝试的方案。

2.5 进行参数搜索

from sklearn.model_selection import GridSearchCV
grid = GridSearchCV(estimator=model,  param_grid=param_grid)
grid_result = grid.fit(X, Y)

使用sklearn里面的GridSearchCV类进行参数搜索,传入模型和网格参数。

2.6 总结搜索结果

print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
    print("%f (%f) with: %r" % (mean, stdev, param))

结果:

经过网格搜索,各层激活函数最优的选择是softplus。

作者简介: 读研期间发表6篇SCI数据算法相关论文,目前在某研究院从事数据算法相关研究工作,结合自身科研实践经历持续分享关于Python、数据分析、特征工程、机器学习、深度学习、人工智能系列基础知识与案例。关注gzh:数据杂坛,获取数据和源码学习更多内容。

原文链接:
【Python深度学习系列】网格搜索神经网络超参数:激活函数(案例+源码)

相关推荐

查看电脑ip地址的命令(查看电脑ip地址用什么命令)
查看电脑ip地址的命令(查看电脑ip地址用什么命令)

1、在“本地连接”的状态中查看。2、使用“ipconfig/all”命令查看。3、打开电脑网页,输入IP地址,点击确定,就能看到本机IP。扩展资料IP地址(InternetProtocolAddress),全称为网际协议地址,是一种在...

2025-12-30 10:03 off999

ie浏览器9(IE浏览器9.0如何升级)

1、首先,我们点击开始菜单,找到控制面板,点击一下。2、之后,找到程序和功能选项,点击一下。3、点击进入后,我们找到左边的打开或关闭windows功能,点击一下。4、点击进入后,找到Internet...

hp1020打印机驱动怎么下载(hp1020打印机驱动怎么下载)

惠普1020打印机驱动怎么安装:  1.首先到下载软件名称:惠普1020打印机驱动程序官方版(支持win7/8)32位/64位软件大小:5.09MB更新时间:2014-09-05立即下载  2.然后...

win2003是windows7系统(win2003哪个版本好)

win2003是专门用于服务器的操作系统,现在最主流的windows服务器系统主要是win2003server和win2008server,winXP是个人电脑专用的操作系统,现在微软已经不再提供XP...

路由器账号和密码忘了怎么办

你好,如果你忘记了路由器的用户名和密码,你可以尝试重置一下路由器,大多数路由器都配备了一个复位按键。在重置路由器之后,用户名和密码将被还原为默认值,你可以在路由器的用户手册或厂家网站上找到默认的用户名...

win10永久禁止自动更新(win10禁止自动更新彻底)

阻止Windows10自动更新的方法如下:使用“本地组策略编辑器”:按下“Win+R”键,输入“gpedit.msc”打开本地组策略编辑器,找到“计算机配置”>“管理模板”>“W...

联想笔记本怎么看配置和型号

联想笔记本看配置的方法如下1、打开电脑,点击桌面的计算机,右键菜单里选择【属性】;打开后,即可看到电脑系统的大概信息;2、如果要看比较详细的设备相关信息,点击桌面的计算机,点击右键,在菜单里选择【系统...

怎样把打印机连接到电脑上(怎么把打印机连接电脑上)
  • 怎样把打印机连接到电脑上(怎么把打印机连接电脑上)
  • 怎样把打印机连接到电脑上(怎么把打印机连接电脑上)
  • 怎样把打印机连接到电脑上(怎么把打印机连接电脑上)
  • 怎样把打印机连接到电脑上(怎么把打印机连接电脑上)
photoshop6序列号(photoshop8.01序列号)
  • photoshop6序列号(photoshop8.01序列号)
  • photoshop6序列号(photoshop8.01序列号)
  • photoshop6序列号(photoshop8.01序列号)
  • photoshop6序列号(photoshop8.01序列号)
win10下载应用商店(win10应用商店打不开)

1、点击Win10系统的开始菜单,然后在点击应用商店;2、打开Win10应用商店后,在搜索框里输入想要搜索的应用软件,然后点击检索;3、点击搜索到的应用,点击安装;4、点击安装后,系统会提示要切换到这...

dell电脑重装系统win10(dell 重装win10系统)

戴尔笔记本重装系统win10的步骤如下:制作好wepe启动盘之后,将win10系统iso镜像直接复制到U盘。在需要重装系统的戴尔电脑上插入pe启动盘,重启后不停按F12启动快捷键,调出启动菜单对话框,...

android升级包下载安装(android 升级包)

打开手机系统更新升级,前提是官方有新系统推送才能更新  哪个大不一定,但一般规律如下:  1、小版本的更新,通常越更新越大。比如3.1更新到3.2,通常是修复bug,代码量通常会增大,体积就会增大。 ...

hdd硬盘和ssd(ssd硬盘和hdd硬盘是什么意思)

HDD硬盘和SSD硬盘是两种不同类型的电脑存储设备,它们有着以下区别:1.工作原理:HDD硬盘使用机械旋转的磁盘和读写磁头来存储和读取数据,而SSD硬盘则使用闪存存储数据,类似于USB闪存盘。2....

电脑免费软件下载大全(电脑上免费的下载软件)

正常情况下,如果我们想要在自己的电脑上面下载一个不要钱的单机游戏,那么我们是可以直接在我们的软件管理中心进行一个下载的,这个时候我们只需要通过一个权限就能够正常的下载,当然我们也是可以在一些小游戏的软...

mpp文件转换excel(mpp转换成pdf)

要将Excel表格转换为MPP格式,您可以按照以下步骤操作:1.打开Excel表格并确保数据按照项目的不同阶段或任务进行组织。2.将Excel表格中的数据复制到一个新的MicrosoftProj...

取消回复欢迎 发表评论: