
Confusion about TimeDistributed #3

Open
PuHaoran opened this issue Sep 14, 2018 · 2 comments

Comments

@PuHaoran

PuHaoran commented Sep 14, 2018

Thanks! In the fastText model there is a TimeDistributed layer right after the Embedding layer, and I'm not sure what it does. I also noticed that this step appears in all three of your models. I'm new to PyTorch — could you explain this point in detail? (I don't quite follow the dimension changes here either.) Much appreciated!

self.tdfc1 = nn.Linear(D, 512)
self.td1 = TimeDistributed(self.tdfc1)
self.tdbn1 = nn.BatchNorm2d(1)

self.tdfc2 = nn.Linear(D, 512)
self.td2 = TimeDistributed(self.tdfc2)
self.tdbn2 = nn.BatchNorm2d(1)

self.fc1 = nn.Linear(1024, 512)
self.bn1 = nn.BatchNorm1d(512)
self.fc2 = nn.Linear(512, C)

....
def forward(self, x, y):
    if self.opt['use_char_word']:
        x = self.embed_char(x.long())
        y = self.embed_word(y.long())
    elif self.opt['use_word_char']:
        x = self.embed_word(x.long())
        y = self.embed_char(y.long())
    else:
        x = self.embed(x.long())
        y = self.embed(y.long())

    if self.opt['static']:
        x = x.detach()
    x = F.relu(self.tdbn1(self.td1(x).unsqueeze(1))).squeeze(1)

    if self.opt['static']:
        y = y.detach()
    y = F.relu(self.tdbn2(self.td2(y).unsqueeze(1))).squeeze(1)

    x = x.mean(1).squeeze(1)
    y = y.mean(1).squeeze(1)

    x = torch.cat((x, y), 1)

    x = F.relu(self.bn1(self.fc1(x)))
    logit = self.fc2(x)
    return logit

......

class TimeDistributed(nn.Module):
    def __init__(self, module):
        super(TimeDistributed, self).__init__()
        self.module = module

    def forward(self, x):
        if len(x.size()) <= 2:
            return self.module(x)
        n, t = x.size(0), x.size(1)
        # merge batch and seq dimensions: (n, t, d) -> (n*t, d)
        x_reshape = x.contiguous().view(n * t, x.size(2))
        # apply the wrapped module to every timestep at once
        y = self.module(x_reshape)
        # restore the batch/seq split: (n*t, out) -> (n, t, out)
        y = y.contiguous().view(n, t, y.size(1))
        return y
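For concreteness, here is a small sketch of how the dimensions evolve around `td1`/`tdbn1` in the forward pass above. The sizes (batch 4, 30 timesteps, embedding dim 256) are illustrative assumptions, not taken from the repo; note that `nn.Linear` already maps the last dimension of a 3-D tensor per timestep, so it is used directly here in place of the `TimeDistributed` wrapper:

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions): batch n=4, timesteps t=30, embedding dim D=256
x = torch.randn(4, 30, 256)

tdfc1 = nn.Linear(256, 512)
h = tdfc1(x)            # (4, 30, 512): the Linear layer transforms each timestep

tdbn1 = nn.BatchNorm2d(1)
h = h.unsqueeze(1)      # (4, 1, 30, 512): insert a channel dim so BatchNorm2d applies
h = torch.relu(tdbn1(h))
h = h.squeeze(1)        # (4, 30, 512): drop the channel dim again

pooled = h.mean(1)      # (4, 512): average over the timestep dimension
print(h.shape, pooled.shape)
```

So `unsqueeze(1)`/`squeeze(1)` exist only to give `BatchNorm2d` the 4-D `(N, C, H, W)` layout it expects, with a single channel.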
@PuHaoran PuHaoran changed the title Very confused about TimeDistributed Confusion about TimeDistributed Sep 14, 2018
@Magic-Bubble
Owner

You can think of it as a transform applied to the original embeddings: a fully-connected layer applied at each timestep, which tends to improve model performance. The dimension is originally 256 and can be transformed to another size, e.g. 512 here.
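A minimal, self-contained demo of that 256 → 512 per-timestep transform (the batch/sequence sizes are hypothetical, chosen only to show the shapes):

```python
import torch
import torch.nn as nn

class TimeDistributed(nn.Module):
    """Apply `module` to every timestep of a (batch, time, features) tensor."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        if x.dim() <= 2:
            return self.module(x)
        n, t = x.size(0), x.size(1)
        # flatten (n, t, d) -> (n*t, d), apply the module, then restore (n, t, out)
        y = self.module(x.contiguous().view(n * t, x.size(2)))
        return y.contiguous().view(n, t, y.size(1))

# hypothetical sizes: batch 4, sequence length 30, embedding dim 256 -> 512
td = TimeDistributed(nn.Linear(256, 512))
x = torch.randn(4, 30, 256)
out = td(x)
print(out.shape)   # torch.Size([4, 30, 512])

# every timestep goes through the SAME Linear layer (shared weights):
assert torch.allclose(out[0, 0], td.module(x[0, 0]), atol=1e-5)
```

The key point is that the wrapped `nn.Linear` sees a 2-D batch of `n*t` rows, so one set of weights transforms all timesteps identically.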

