
TrajPreModel


Trajectory prediction model

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

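#######################################
# NOTE: MultiSelfAttention is not defined in the original post. The class below is a
# minimal stand-in (standard multi-head scaled dot-product self-attention) added so the
# listing runs end to end; the author's actual implementation may differ.
class MultiSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""
    def __init__(self, heads, d_model, dropout=0.1):
        super(MultiSelfAttention, self).__init__()
        assert d_model % heads == 0
        self.heads = heads
        self.d_k = d_model // heads              # per-head dimension
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        bs = q.size(0)
        # project and split into heads: [batch_size, heads, seq_len, d_k]
        q = self.q_linear(q).view(bs, -1, self.heads, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(bs, -1, self.heads, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(bs, -1, self.heads, self.d_k).transpose(1, 2)
        # scaled dot-product attention weights
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = self.dropout(F.softmax(scores, dim=-1))
        # weighted sum of values, then merge heads back to [batch_size, seq_len, d_model]
        context = torch.matmul(attn, v).transpose(1, 2).contiguous().view(bs, -1, self.heads * self.d_k)
        return self.out(context)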
#######################################
class TrajPreModel(nn.Module):
    """self-attention model"""
    def __init__(self, loc_size=528, loc_emb_size=128, hidden_size=32, head_num=1, dropout_p=0):
        super(TrajPreModel, self).__init__()
        self.loc_size = loc_size
        self.loc_emb_size = loc_emb_size
        self.hidden_size = hidden_size
        self.heads = head_num
        self.dropout_p = dropout_p
        # embedding layer for location ids
        self.emb_loc = nn.Embedding(self.loc_size, self.loc_emb_size)
        self.weight = self.emb_loc.weight  # kept for optional weight sharing with the output layer
              
        #-------------model---------------
        self.attention = MultiSelfAttention(self.heads, self.loc_emb_size, dropout=self.dropout_p)
        self.fc = nn.Linear(self.loc_emb_size, self.loc_size)
        self.is_weight_sharing = False  # tie the output layer to the embedding weights when True
        self.init_weights()
        self.dropout = nn.Dropout(p=dropout_p)

    def init_weights(self):
        # the ih/hh patterns target RNN-style parameters (weight_ih / weight_hh);
        # with no recurrent layer in this model, only the bias initialization takes effect
        ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
        hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
        b = (param.data for name, param in self.named_parameters() if 'bias' in name)
        for t in ih:
            nn.init.xavier_uniform_(t)
        for t in hh:
            nn.init.orthogonal_(t)
        for t in b:
            nn.init.constant_(t, 0)

    def forward(self, x):
        
        seq = x[1] # [batch_size, seq_len]
        loc_emb = self.emb_loc(seq) 
        output = self.dropout(loc_emb)
        # self-attention over the embedded sequence (query = key = value)
        output = self.attention(output, output, output)
        output = self.dropout(output)

        if not self.is_weight_sharing:
            y = self.fc(output)
        else:
            y = F.linear(output, self.weight)
        
        score = F.log_softmax(y, dim=-1)
        return score.view(-1, self.loc_size)  # [batch_size * seq_len, loc_size]
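
A quick smoke test, assuming the batch is packed as a pair whose second element is the location index tensor (the original forward only reads x[1]; the first element is a placeholder here):

if __name__ == '__main__':
    model = TrajPreModel(loc_size=528, loc_emb_size=128, head_num=1, dropout_p=0.1)
    batch_size, seq_len = 4, 10
    seq = torch.randint(0, 528, (batch_size, seq_len))   # random location ids
    x = (None, seq)     # x[0] stands in for whatever else the pipeline packs (assumption)
    scores = model(x)
    print(scores.shape)  # torch.Size([40, 528]), i.e. [batch_size * seq_len, loc_size]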


Original post: https://www.cnblogs.com/lixyuan/p/12919950.html
