如果需要和命令行接口进行交互,可以使用Python中的argparse包,快捷方便,对于Lightning而言,可以利用它,在命令行窗口中,直接配置超参数等操作,但也可以使用LightningCLI的方法,更加轻便简单。
ArgumentParser
ArgumentParser是Python的内置特性,进而构建CLI程序,我们可以使用它在命令行中设置超参数和其他训练设置。
from argparse import ArgumentParser
parser = ArgumentParser()
# 训练方式(GPU or CPU or 其他)
parser.add_argument("--devices", type=int, default=2)
# 超参数
parser.add_argument("--layer_1_dim", type=int, default=128)
# 解析用户输入和默认值 (returns argparse.Namespace)
args = parser.parse_args()
# 在程序中使用解析后的参数
trainer = Trainer(devices=args.devices)
model = MyModel(layer_1_dim=args.layer_1_dim)
然后在命令行中如此调用
python trainer.py --layer_1_dim 64 --devices 1
Python的参数解析器在简单的用例中工作得很好,但在大型项目中维护它可能会变得很麻烦。例如,每次在模型中添加、更改或删除参数时,都必须添加、编辑或删除相应的add_argument。Lightning CLI提供了与Trainer和LightningModule的无缝集成,为您自动生成CLI参数。
LightningCLI
pip install "jsonargparse[signatures]"
执行起来很简单,例如
# main.py
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
def cli_main():
# 只需要写这一行即可,两个参数,对应模型和数据
cli = LightningCLI(DemoModel, BoringDataModule)
# 注意: 别写.fit
if __name__ == "__main__":
cli_main() # 在函数中实现CLI并在主if块中调用它是一种很好的做法
然后在命令行中执行help,进行文档查询
python main.py --help
执行结果
usage: main.py [-h] [-c CONFIG] [--print_config[=flags]]
{fit,validate,test,predict,tune} ...
pytorch-lightning trainer command line tool
optional arguments:
-h, --help Show this help message and exit.
-c CONFIG, --config CONFIG
Path to a configuration file in json or yaml format.
--print_config[=flags]
Print the configuration after applying all other
arguments and exit. The optional flags customizes the
output and are one or more keywords separated by
comma. The supported flags are: comments,
skip_default, skip_null.
subcommands:
For more details of each subcommand, add it as an argument followed by
--help.
{fit,validate,test,predict,tune}
fit Runs the full optimization routine.
validate Perform one evaluation epoch over the validation set.
test Perform one evaluation epoch over the test set.
predict Run inference on your data.
tune Runs routines to tune hyperparameters before training.
因此可以使用如下方法:文章来源:https://www.toymoban.com/news/detail-614035.html
$ python main.py fit # 训练
$ python main.py validate # 验证
$ python main.py test # 测试
$ python main.py predict # 预测
例如训练过程,可以通过以下方法具体调参数文章来源地址https://www.toymoban.com/news/detail-614035.html
# learning_rate
python main.py fit --model.learning_rate 0.1
# output dimensions
python main.py fit --model.out_dim 10 --model.learning_rate 0.1
# trainer 和 data arguments
python main.py fit --model.out_dim 2 --model.learning_rate 0.1 --data.data_dir '~/' --trainer.logger False
到了这里,关于PyTorch Lightning教程四:超参数的使用的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!