update the code to use the module's __call__

This commit is contained in:
wangzhihong
2024-03-21 10:09:34 +08:00
parent 52afd48b06
commit 1e8375848d

View File

@@ -403,10 +403,10 @@ class TransformerBlock(nn.Module):
torch.Tensor: Output tensor after applying attention and feedforward layers.
"""
h = x + self.attention.forward(
h = x + self.attention(
self.attention_norm(x), start_pos, freqs_cis, mask
)
out = h + self.feed_forward.forward(self.ffn_norm(h))
out = h + self.feed_forward(self.ffn_norm(h))
return out