一般来说,从FATE框架中获得数据使用get_component('name').get_output_data()
。
但是这样子在目前的1.x的FATE中,只能以分类、回归的格式输出才能获得。
如果是图片、文本、token embedding等,用这种方式根本拿不到模型的输出。
经过跟FATE社区人员交涉,社区肯定了这种方法拿不出。并且给了个方法,在自定义的trainer中的predict
函数,直接保存输出。不在通过上述方法获得。
只能说现在只能先这样用了。
如何自定义trainer,在官方文档有。
trainer中的predict部分部分原代码如下,直接在这里面添加save model prediction就行:文章来源:https://www.toymoban.com/news/detail-527895.html
def _predict(self, dataset: Dataset):
pred_result = []
# switch eval mode
dataset.eval()
self.model.eval()
labels = []
# 直接在这里save prediction
pred = self.model(images)
torch.save('./xxxx',pred)
length=len(dataset.get_sample_ids())
ret_rs = torch.rand(length,1)
ret_label = torch.rand(length, 1).int()
return dataset.get_sample_ids(), ret_rs, ret_label
def predict(self, dataset: Dataset):
ids, ret_rs, ret_label=self._predict(dataset)
if self.fed_mode:
return self.format_predict_result(
ids, ret_rs, ret_label, task_type=self.task_type)
else:
return ret_rs, ret_label
在上述代码我返回了一些假的数据,因为如果返回数据的格式不符合,Fateboard会直接报错,无法进入到下一步。所以放在那里,没用。文章来源地址https://www.toymoban.com/news/detail-527895.html
到了这里,关于【FATE联邦学习】非分类、回归任务,如何获得联邦模型的输出?的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!