多目标运动轨迹预测HiVT代码跑通

这篇具有很好参考价值的文章主要介绍了多目标运动轨迹预测HiVT代码跑通。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

先介绍一下学者使用的运动轨迹预测数据集

Argoverse Motion Forecasting Dataset v1.1

现在Argoverse数据集已经出到v2版本,可以支持windows系统,但大多学者都是用2019发布的Argoverse v1.1,这个版本的api没有提供windows系统的支持,数据集作者说应该是转义字符的问题。(Argoverse v2的Motion Forecasting Dataset更大,全部下载完要50+g)。

https://github.com/argoai/argoverse-api

可以根据上面链接下载Argoverse api,这里都是使用v1.1版本的。Argoverse api v1.1仅支持MacOS和Linux,下载完后,对应下面步骤进行安装:

  1. 创建虚拟环境 python版本为3.8。(不懂的可以找我之前发的利用anaconda配置虚拟环境)

  1. 进入虚拟环境后,进入当前下载的Argoverse api文件夹中,执行

  pip install -e ./
  1. 也可以选择性安装mayavi、ffmpeg、Stereo tutorial dependencies。我只安装了mayavi,在安装这个之前先安装pip install PyQt5,再安装pip install mayavi。(这个应该也不用安装的,毕竟是视频流的相关库)

  1. 可以从数据集官网下载以下三个文件,用于使用Jupyter Notebook测试是否安装成功。可以打开Argoverse api中的Usage文件夹,打开你想要测试的,比如你想要测试轨迹预测,那就点击关于“Forecasting”的文件,逐个执行就可以了。

目标轨迹预测,python,深度学习,计算机视觉,Powered by 金山文档

再介绍HiVT代码

先上代码链接

https://github.com/ZikangZhou/HiVT

配置环境

conda create -n HiVT python=3.8
conda activate HiVT

# CUDA 10.2
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch

# CUDA 11.1
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge

# CPU Only
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cpuonly -c pytorch

conda install pytorch-geometric==1.7.2 
conda install pytorch-lightning==1.5.2

根据自己CUDA版本下载对用的pytorch,个人觉得pytorch在1.8.0-1.10.0左右版本应该都差不多,这个CUDA只要你大于11.1都可以向下兼容安装这两个版本的cudatoolkit。

你会发现pytorch-geometric有可能直接安装不了,这是很正常,进阶安装如下:

  1. 先进入轮子地址:https://pytorch-geometric.com/whl/

  1. 再寻找你的对应pytorch版本和cudatoolkit版本,点进去找到下面whl

目标轨迹预测,python,深度学习,计算机视觉,Powered by 金山文档
  1. 然后再 pip install torch-geometric==1.7.2

  1. 到此,应该安装完毕了

准备数据集

回到我们的数据集官网,下载

目标轨迹预测,python,深度学习,计算机视觉,Powered by 金山文档

将这些数据集解压放置新建dataset文件夹中,格式如下

目标轨迹预测,python,深度学习,计算机视觉,Powered by 金山文档

Training

训练小一点的模型 HiVT-64,可以适当修改batchsize或者其他参数,实测batchsize=32时候,占用显存才6g+。

python train.py --root dataset/ --embed_dim 64

训练大一点的模型HiVT-128文章来源地址https://www.toymoban.com/news/detail-613673.html

python train.py --root dataset/ --embed_dim 128

Evaluation

python eval.py --root dataset/ --batch_size 32 --ckpt_path /path/to/your_checkpoint.ckpt

到了这里,关于多目标运动轨迹预测HiVT代码跑通的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包