论文阅读:TinyGPT-V 论文阅读及源码梳理对应

这篇具有很好参考价值的文章主要介绍了论文阅读:TinyGPT-V 论文阅读及源码梳理对应。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

引言

TinyGPT-V来自论文:TinyGPT-V: Efficient Multimodal Large Language Model via Small Backbones,是一篇基于较小LLM作为backbone的多模态工作。相关工作已经开源,地址为:Github

之所以选择这篇文章,是因为比较具有落地意义,且便于本地运行,查看和调试。

以下代码只给出核心部分,会省略无关部分。如想查看完整代码,可以移步仓库SWHL/TinyGPT-V

整体结构图

论文阅读:TinyGPT-V 论文阅读及源码梳理对应,论文学习,论文阅读
从以上整体结构图中可以看到,模型主要分为4部分:Visual Encode & Q-Former、MiniGPT-4 Proj、Linear和Phi-2。

推理流程讲解

该部分主要以Stage1-3阶段模型的推理入手,输入是一个图像和对应文本,注重讲述图像和文本是如何被处理,送入模型得到最终输出结果的。以下图的图像和文本(Please write a poem about the image)作为输入。

论文阅读:TinyGPT-V 论文阅读及源码梳理对应,论文学习,论文阅读
运行效果如下所示,小伙伴可自行前往Hugging Face体验:
论文阅读:TinyGPT-V 论文阅读及源码梳理对应,论文学习,论文阅读
为了便于查看,我这里整理了命令行推理的版本(demo_cli.py),更加清晰看到数据走向,仅仅用于学习使用。

我们先来看一张图像,在推理阶段都经过了什么,才到达最终模型面前,截取demo_cli.py中核心代码如下:

# 1. Image读取图像
img_path = "tests/test_files/1.png"
img = Image.open(img_path)
img = img.convert("RGB")

# 初始化对话类
chat_state = CONV_VISION.copy()

# 2. 上传图像,并对图像做预处理
img_list = []
llm_message = chat.upload_img(img, chat_state, img_list)

# 3. 提取图像特征
chat.encode_img(img_list)

# 4. 将用户提问问题加入到对话类中,用于后续拼接prompt
user_msg = "Please write a poem about the image"
chat.ask(user_msg, chat_state)

# 5. 核心,送入到Phi-2中,根据图像回答用户问题
num_beams = 1
temperature = 1.0
llm_message = chat.answer(
    conv=chat_state,
    img_list=img_list,
    num_beams=num_beams,
    temperature=temperature,
    max_new_tokens=300,
    max_length=2000,
)[0]
print(llm_message)

接下来,依次对图像经过流程,做详细解读:

chat_state组成
conv = Conversation(
    system="Give the following image: <Img>ImageContent</Img>. "
    "You will be able to see the image once I provide it to you. Please answer my questions.",
    roles=("Human: ", "Assistant: "),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)
chat.upload_img()

源码位置:link

def upload_img(self, image, conv, img_list):
     # 这里将<Img></Img>添加到了mesages下,便于后续拼接完整prompt
     conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
     img_list.append(image)
     msg = "Received."
     return msg
chat.encode_img(img_list)

源码位置:link

def encode_img(self, img_list):
    image = img_list[0]
    img_list.pop(0)
    if isinstance(image, str):  # is a image path
        raw_image = Image.open(image).convert("RGB")
        image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
    elif isinstance(image, Image.Image):
        # 因为上述代码传入是Image类型的,走这里
        raw_image = image
        image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)

    # 这里进入模型对图像进行编码,得到图像特征向量	
    image_emb, _ = self.model.encode_img(image)
    img_list.append(image_emb)

其中,根据配置文件tinygptv_stage1_2_3_eval.yam,可以知道 self.vis_processor指的是Blip2ImageEvalProcessor类。该类中,对图像做了三个操作:Resize、ToTensor、Normalize。代码如下(link):

@registry.register_processor("blip2_image_eval")
class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
    def __init__(self, image_size=224, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)

图像经过以上transform之后,就变为了Tensor向量,接下来将进入提取特征部分。根据论文中描述,该部分采用的是VIT EVA、Q-Former和全连接层。核心代码如下(源码:link):
(❓ 这里我有一点小小疑问:论文结构图中有一个部分是:MiniGPT-4 Proj,这一部分在源码中并没有发现。)

def encode_img(self, image):
     # 省略部分无关代码... ...

     with self.maybe_autocast():
         # self.visual_encoder就是基于eva_clip_g的模型
         # self.ln_vision指的是LayerNorm层
         image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)

		 # 使用Q-Former,基于bert-base-uncased
         if self.has_qformer:
             image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
             query_output = self.Qformer.bert(
                 query_embeds=query_tokens,
                 encoder_hidden_states=image_embeds,
                 encoder_attention_mask=image_atts,
                 return_dict=True,
             )

             inputs_llama = self.llama_proj(query_output.last_hidden_state)
             inputs_llama = self.llama_proj2(inputs_llama)

         atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
     return inputs_llama, atts_llama
chat.ask()

源码:link

def ask(self, text, conv):
    if (
        len(conv.messages) > 0
        and conv.messages[-1][0] == conv.roles[0]
        and conv.messages[-1][1][-6:] == "</Img>"
    ):  # last message is image.
        conv.messages[-1][1] = " ".join([conv.messages[-1][1], text])
    else:
        conv.append_message(conv.roles[0], text)

将conv实例中message字段值做了拼接。拼接之后的message如下:

print(conv.messages)
[['Human: ', '<Img><ImageHere></Img>Please write a poem about the image']]
chat.answer()

源码:link

该部分是核心部分,以上做的操作都是在准备送入大模型。

def answer(self, conv, img_list, **kargs):
    # self.answer_prepare主要是将Image特征与Text特征做拼接
    generation_dict = self.answer_prepare(conv, img_list, **kargs)

    # 送入模型进行推理
    output_token = self.model_generate(**generation_dict)[0]

    # 解码得到文本
    output_text = self.model.llama_tokenizer.decode(
        output_token, skip_special_tokens=True
    )
    output_text = output_text.split("###")[0]  # remove the stop sign '###'
    output_text = output_text.split("Assistant:")[-1].strip()
    conv.messages[-1][1] = output_text
    return output_text, output_token.cpu().numpy()

接下来,让我们详细看一下self.answer_prepare这个函数具体做了什么?

def answer_prepare(self, ...):
     # 这里在message后追加一个Assitant:部分
     conv.append_message(conv.roles[1], None) 
  
     # 将conv中system内容、图像和文本拼接,获得送入大模型的完整Prompt。示例如下:
     # Give the following image: <Img>ImageContent</Img>. You will be able to see the image once I provide it to you. Please answer my questions.###Human: <Img><ImageHere></Img>###Assistant: 
     prompt = conv.get_prompt()

     # 获得文本的embedding,并与图像特征进行mix
     embs = self.model.get_context_emb(prompt, img_list)

     current_max_len = embs.shape[1] + max_new_tokens
     begin_idx = max(0, current_max_len - max_length)
     embs = embs[:, begin_idx:]

     generation_kwargs = dict(
         inputs_embeds=embs,
         max_new_tokens=max_new_tokens,
         stopping_criteria=self.stopping_criteria,
         num_beams=num_beams,
         do_sample=True,
         min_length=min_length,
         top_p=top_p,
         repetition_penalty=repetition_penalty,
         length_penalty=length_penalty,
         temperature=float(temperature),
         pad_token_id=tokenizer.pad_token_id,
         bos_token_id=tokenizer.bos_token_id,
         eos_token_id=tokenizer.eos_token_id,
     )
     return generation_kwargs

再看如何将文本与图像特征融合的:先将图像转为向量。将prompt除Image部分其他部分依次转为向量。再将两者mix,得到最终向量。

def get_context_emb(self, prompt, img_list):
    device = img_list[0].device
    prompt_segs = prompt.split("<ImageHere>")
    assert (
        len(prompt_segs) == len(img_list) + 1
    ), "Unmatched numbers of image placeholders and images."

    seg_tokens = [
        self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=i == 0)
        .to(device)
        .input_ids  # only add bos to the first seg
        for i, seg in enumerate(prompt_segs)
    ]

    seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]

    # TODO: 这里具体如何混合在一起的,需要Debug查看
    mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [
        seg_embs[-1]
    ]
    mixed_embs = torch.cat(mixed_embs, dim=1)
    return mixed_embs

至此,将准备好的数据送入到 self.model_generate() 函数中,即可得到模型的回答。

写在最后

本篇文章写得还是较为粗糙一些,只是想通过这个来学习多模态中一般处理方法,不仅仅限于模型结构。后续如果时间,还会更新训练各个阶段的具体做法。欢迎持续关注。文章来源地址https://www.toymoban.com/news/detail-792035.html

到了这里,关于论文阅读:TinyGPT-V 论文阅读及源码梳理对应的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 计算卸载论文阅读01-理论梳理

    标题:When Learning Joins Edge: Real-time Proportional Computation Offloading via Deep Reinforcement Learning 会议:ICPADS 2019 问题: 在任务进行卸载时,往往忽略了任务的特定的卸载比例。 模型: 针对上述问题,我们提出了一种创新的强化学习(RL)方法来解决比例计算问题。我们考虑了一种常

    2024年02月09日
    浏览(39)
  • 基于微信小程序的毕业设计——在线阅读系统(附源码+论文)

    大家好!我是职场程序猿,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:微信小程序毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 安卓app毕业设计 🌎Java毕业设计 本文主要还是以在线阅读系统设计和实现为主要的考虑内容,为了能够考虑到信息安全性和系统的数据访问

    2024年02月06日
    浏览(47)
  • 论文阅读---联邦忘却学习研究综述

    论文:联邦忘却学习研究综述 federated unlearning-联邦忘却学习 摘要 联邦忘却学习撤销用户数据对联邦学习模型的训练更新,可以进一步保护联邦学习用户的数据安全。 联邦忘却学习在联邦学习框架的基础上,通过迭代训练,直接删除等方式,撤销用户本地局部模型对全局模型

    2024年03月12日
    浏览(105)
  • TM 学习记录--论文阅读1

    这里可以查看所有论文。由于作者book只更新到第二章剩下的只有从论文中学习,但书中的目录和论文可以由于对应起来。第一二章可以对应到第一篇论文,这里。 尽管单独而言很简单,但人工神经元在深度网络中互连时可提供最先进的性能。 可以说,Tsetlin 自动机是一种更

    2024年02月07日
    浏览(45)
  • 强化学习论文阅读(二)SAC算法

    原文传递:SAC算法原文 作者指出深度强化学习样本效率低下的原因是:策略学习,TRPO、PPO、A3C每次策略更新都需要收集样本。学习有效的策略需要的步骤和样本数量伴随着任务的复杂性呈现增加的趋势。Off-Policy为了重复使用过去产生的经验值,但是在传统的策略公式当中不

    2024年02月06日
    浏览(44)
  • 论文阅读——基于深度学习智能垃圾分类

    B. Fu, S. Li, J. Wei, Q. Li, Q. Wang and J. Tu, “A Novel Intelligent Garbage Classification System Based on Deep Learning and an Embedded Linux System,” in IEEE Access, vol. 9, pp. 131134-131146, 2021, doi: 10.1109/ACCESS.2021.3114496. 垃圾数量的急剧增加和垃圾中物质的复杂多样性带来了严重的环境污染和资源浪费问题。回收

    2024年02月11日
    浏览(43)
  • 对比学习论文阅读:CoCLR算法笔记

    标题:Self-supervised Co-training for Video Representation Learning 会议:NIPS2020 论文地址:https://dl.acm.org/doi/abs/10.5555/3495724.3496201 官方代码:https://www.robots.ox.ac.uk/~vgg/research/CoCLR/ 作者单位:牛津大学 本文的研究目标是纯视觉的自监督视频表征学习。我们做出了以下贡献:①我们研究了在

    2024年02月03日
    浏览(60)
  • 基于微信小程序的在线小说阅读的设计与实现(源码+论文)_v213

    摘要 近年来,随着社会科技的不断发展,人们的生活方方面面进入了信息化时代。计算机的普及,使得我们的生活更加丰富多彩,同时,随着智能手机的普遍使用,不少的微信小程序也应运而生,逐步改变着人们的生活方式。手机作为这个时代的新生产物,具有高效、便携、

    2024年02月03日
    浏览(59)
  • 【论文阅读】基于深度学习的时序预测——FEDformer

    系列文章链接 论文一:2020 Informer:长时序数据预测 论文二:2021 Autoformer:长序列数据预测 论文三:2022 FEDformer:长序列数据预测 论文四:2022 Non-Stationary Transformers:非平稳性时序预测 论文五:2022 Pyraformer:基于金字塔图结构的时序预测 论文六:2023 Crossformer:多变量时序预

    2024年02月13日
    浏览(37)
  • 【论文阅读】基于深度学习的时序预测——Crossformer

    系列文章链接 论文一:2020 Informer:长时序数据预测 论文二:2021 Autoformer:长序列数据预测 论文三:2022 FEDformer:长序列数据预测 论文四:2022 Non-Stationary Transformers:非平稳性时序预测 论文五:2022 Pyraformer:基于金字塔图结构的时序预测 论文六:2023 Crossformer:多变量时序预

    2024年02月13日
    浏览(43)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包