EADST

Attention Net with PyTorch

In NLP tasks, an attention layer can be placed after the LSTM to pool its per-time-step outputs into a single weighted vector.

import torch
import torch.nn as nn

# attention layer
def attention_net(lstm_output):
    # lstm_output: (batch, sequence_length, hidden_size)
    hidden_size = 300  # must match the feature size of the LSTM output
    # Zero weights give uniform attention; in a real model, register these
    # as nn.Parameter on a module so that they are trained.
    w_omega = torch.zeros(hidden_size, 2, device=lstm_output.device)
    u_omega = torch.zeros(2, device=lstm_output.device)
    output_reshape = lstm_output.reshape(-1, hidden_size)
    u = torch.tanh(torch.mm(output_reshape, w_omega))
    attn_hidden_layer = torch.mm(u, u_omega.reshape(-1, 1))
    sequence_length = lstm_output.size(1)
    # Normalise the scores over the time steps of each sequence.
    scores = attn_hidden_layer.reshape(-1, sequence_length)
    alphas = nn.functional.softmax(scores, dim=1)
    alphas_reshape = alphas.reshape(-1, sequence_length, 1)
    # Weighted sum over time gives a (batch, hidden_size) tensor.
    attn_output = torch.sum(lstm_output * alphas_reshape, 1)
    return attn_output


# add attention layer after lstm
if attention_mode:
    # (seq_len, batch, hidden) -> (batch, seq_len, hidden), assuming batch_first=False
    lstm_out = lstm_out.permute(1, 0, 2)
    lstm_out = attention_net(lstm_out)
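
The snippet above can also be wrapped in a module so that the attention weights are trained together with the LSTM. What follows is only a minimal sketch under stated assumptions, not the original model: the class name AttentionLSTM, input_size=50, the attention_size argument, the small random initialisation, and the last-time-step fallback when attention is off are all illustrative choices that do not appear in the post.

```python
import torch
import torch.nn as nn


class AttentionLSTM(nn.Module):
    """LSTM followed by attention pooling (hypothetical wrapper for illustration)."""

    def __init__(self, input_size=50, hidden_size=300, attention_size=2):
        super().__init__()
        # Default nn.LSTM layout: input and output are (seq_len, batch, features).
        self.lstm = nn.LSTM(input_size, hidden_size)
        # Trainable attention weights; random init is an assumption, chosen so
        # that the initial scores are not all identical.
        self.w_omega = nn.Parameter(torch.randn(hidden_size, attention_size) * 0.1)
        self.u_omega = nn.Parameter(torch.randn(attention_size) * 0.1)

    def attention_net(self, lstm_output):
        # lstm_output: (batch, seq_len, hidden_size)
        u = torch.tanh(lstm_output @ self.w_omega)   # (batch, seq_len, attention_size)
        scores = u @ self.u_omega                    # (batch, seq_len)
        alphas = torch.softmax(scores, dim=1)        # weights over time steps
        attn_output = torch.sum(lstm_output * alphas.unsqueeze(-1), dim=1)
        return attn_output, alphas

    def forward(self, x, attention_mode=True):
        lstm_out, _ = self.lstm(x)                   # (seq_len, batch, hidden_size)
        lstm_out = lstm_out.permute(1, 0, 2)         # (batch, seq_len, hidden_size)
        if attention_mode:
            return self.attention_net(lstm_out)
        # Assumed fallback: use the last time step when attention is disabled.
        return lstm_out[:, -1, :], None


torch.manual_seed(0)
model = AttentionLSTM()
x = torch.randn(7, 4, 50)        # (seq_len=7, batch=4, input_size=50)
pooled, alphas = model(x)
print(pooled.shape)              # torch.Size([4, 300])
```

Returning alphas alongside the pooled vector makes it easy to check that each row sums to 1 and to inspect which time steps the layer attends to.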
