EADST

Attention Net with PyTorch

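In an NLP task, an attention layer can be placed after the LSTM. It scores each time step of the LSTM output, turns the scores into weights with a softmax, and returns the weighted sum of the hidden states as a single vector per sequence.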
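To make the weighting step concrete, here is a minimal dependency-free sketch of softmax-weighted pooling on toy numbers. The helper names `softmax` and `attention_pool` and the example values are made up for illustration; they are not part of the code in this post.

```python
import math


def softmax(scores):
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(states, scores):
    """Return the softmax-weighted sum of per-step state vectors and the weights."""
    alphas = softmax(scores)
    dim = len(states[0])
    pooled = [sum(a * h[d] for a, h in zip(alphas, states)) for d in range(dim)]
    return pooled, alphas


# Three time steps with 2-dimensional hidden states (made-up values).
states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
scores = [0.0, 0.0, 0.0]  # equal scores, so every step gets the same weight
pooled, alphas = attention_pool(states, scores)
```

With equal scores the weights are uniform, so the result is simply the mean of the hidden states. Unequal scores shift the result toward the higher-scoring steps.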

import torch
import torch.nn as nn

# attention layer
def attention_net(lstm_output):
    # lstm_output: (batch_size, sequence_length, hidden_size)
    hidden_size = 300
    attention_size = 2
    # Zero-initialized placeholders: with these values every time step gets the
    # same weight. In a real model they would be trainable parameters.
    w_omega = torch.zeros(hidden_size, attention_size)
    u_omega = torch.zeros(attention_size)
    sequence_length = lstm_output.size(1)
    output_reshape = lstm_output.reshape(-1, hidden_size)
    u = torch.tanh(torch.mm(output_reshape, w_omega))
    attn_hidden_layer = torch.mm(u, u_omega.reshape(-1, 1))
    # Softmax over the time steps of each sequence, not over a size-1 axis.
    alphas = nn.functional.softmax(attn_hidden_layer.reshape(-1, sequence_length), dim=1)
    alphas_reshape = alphas.reshape(-1, sequence_length, 1)
    # The input is already batch-first, so no extra permute is needed here.
    attn_output = torch.sum(lstm_output * alphas_reshape, 1)
    return attn_output


# add attention layer after lstm
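if attention_mode:
    # Assumes lstm_out is (sequence_length, batch_size, hidden_size), the nn.LSTM
    # default layout; permute to (batch_size, sequence_length, hidden_size).
    lstm_out = lstm_out.permute(1, 0, 2)
    lstm_out = attention_net(lstm_out)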
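
For the attention weights to be learned, the projection tensors need to be registered as trainable parameters. The following is a rough sketch of one way to do that with `nn.Module`. The class name `AttentionPool`, the `attention_size` argument, the small random initialization, and the toy dimensions are all assumptions made for illustration, not something this post specifies.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionPool(nn.Module):
    """Hypothetical module: pools (batch, seq_len, hidden) into (batch, hidden)."""

    def __init__(self, hidden_size, attention_size):
        super().__init__()
        # Small random init is an assumption; any reasonable init could be used.
        self.w_omega = nn.Parameter(torch.randn(hidden_size, attention_size) * 0.1)
        self.u_omega = nn.Parameter(torch.randn(attention_size) * 0.1)

    def forward(self, lstm_output):
        # lstm_output: (batch, seq_len, hidden)
        u = torch.tanh(lstm_output @ self.w_omega)   # (batch, seq_len, attention_size)
        scores = u @ self.u_omega                    # (batch, seq_len)
        alphas = F.softmax(scores, dim=1)            # weights over time steps
        # Weighted sum over the time axis gives one vector per sequence.
        return torch.sum(lstm_output * alphas.unsqueeze(-1), dim=1)


# Toy usage with made-up sizes; batch_first=True avoids a manual permute.
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
pool = AttentionPool(hidden_size=16, attention_size=4)
x = torch.randn(3, 5, 8)    # (batch, seq_len, features)
lstm_out, _ = lstm(x)       # (3, 5, 16)
pooled = pool(lstm_out)     # (3, 16)
```

Because `w_omega` and `u_omega` are `nn.Parameter` objects, an optimizer built from `pool.parameters()` would update them together with the rest of the model.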
