EADST

Pytorch GPTQ Dequantizing Function

Pytorch GPTQ Dequantizing Function

Below is Python code that dequantizes the weights of a GPTQ-quantized model back to PyTorch FP16 format.

import torch

# Function: Dequantize quantized weights
def dequantization(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):
    """Unpack GPTQ int32-packed weights into a dense (float) weight matrix.

    Args:
        qweight: int32 tensor of packed quantized weights,
            shape (in_features * bits // 32, out_features).
        qzeros: int32 tensor of packed zero points,
            shape (in_features // group_size, out_features * bits // 32).
        scales: float tensor of per-group scales,
            shape (in_features // group_size, out_features); its dtype
            (typically float16) determines the output dtype.
        g_idx: per-input-row group index. NOTE(review): currently unused —
            this implementation assumes sequential groups, so checkpoints
            quantized with act-order (desc_act=True) will not dequantize
            correctly. TODO: confirm source models were quantized without
            act-order.
        bits: quantization bit width; must divide 32 (e.g. 2, 4, 8).
        group_size: number of input rows sharing one scale / zero point.
        device: legacy parameter, unused; computation now follows
            qweight.device so CPU and CUDA inputs both work.

    Returns:
        Dense weight tensor of shape (out_features, in_features).
    """
    # Shift amounts (0, bits, 2*bits, ...) used to extract each packed value
    # from an int32. Created on the inputs' device — the original CPU-only
    # tensor raised a device-mismatch error against CUDA inputs.
    wf = torch.arange(0, 32, bits, dtype=torch.int32, device=qweight.device).unsqueeze(0)

    # Unpack zero points: each int32 in qzeros holds 32 // bits values.
    zeros = torch.bitwise_right_shift(
        torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
        wf.unsqueeze(0),
    ).to(torch.int16 if bits == 8 else torch.int8)
    torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)

    # GPTQ stores zero points off by one; restore the true value.
    zeros = zeros + 1
    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])

    # One scale row per group, broadcastable against the unpacked weights.
    scales = scales.reshape(-1, 1, scales.shape[-1])

    # Unpack weights the same way: each int32 column entry expands into
    # 32 // bits consecutive input rows.
    weight = torch.bitwise_right_shift(
        torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
        wf.unsqueeze(-1),
    ).to(torch.int16 if bits == 8 else torch.int8)
    torch.bitwise_and(weight, (2 ** bits) - 1, out=weight)
    weight = weight.reshape(-1, group_size, weight.shape[2])

    # Dequantize: w = scale * (q - zero), broadcast per group.
    weight = (scales * (weight - zeros))
    weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

    # (in_features, out_features) -> (out_features, in_features)
    return weight.transpose(0, 1)

# Function: Load quantized model and perform dequantization
def get_pytorch_bin(path="./your_model_folder/gptq_model-4bit-128g.bin",
                    out_path="./your_model_folder/pytorch_model.bin"):
    """Convert a GPTQ checkpoint into a plain dequantized state dict.

    Args:
        path: input GPTQ checkpoint (contains .qweight / .qzeros /
            .scales / .g_idx entries alongside regular tensors).
        out_path: destination file for the dequantized state dict.
    """
    # Dictionary to store processed weights.
    tensors = {}

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # trusted checkpoint files here.
    state = torch.load(path, map_location="cpu")

    for key in state.keys():
        # Quantization metadata is consumed together with its .qweight
        # entry below, so it never appears in the output on its own.
        if ".qzeros" in key or ".scales" in key or ".g_idx" in key:
            continue
        # These projection layers carry no bias in the source model; skip.
        if ("o_proj.bias" in key or "up_proj.bias" in key
                or "down_proj.bias" in key or "gate_proj.bias" in key):
            continue

        if ".qweight" in key:
            # Reassemble the dense weight from the packed quantized pieces
            # and store it under the conventional ".weight" name.
            qweight = state[key]
            qzeros = state[key.replace(".qweight", ".qzeros")]
            scales = state[key.replace(".qweight", ".scales")]
            g_idx = state[key.replace(".qweight", ".g_idx")]
            tensors[key.replace(".qweight", ".weight")] = dequantization(
                qweight, qzeros, scales, g_idx)
        else:
            # Unquantized entries (embeddings, norms, ...) pass through.
            tensors[key] = state[key]

    # Report how many tensors survived, then save the converted model.
    print(len(tensors))
    torch.save(tensors, out_path)

# Entry point: convert the default checkpoint only when executed as a
# script; importing this module performs no work.
if __name__ == '__main__':
    get_pytorch_bin()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
多线程 FP16 Bin Tensor Hotel Django Plate Jupyter LoRA SVR Review SQL Streamlit NLP git Windows Jetson Augmentation Land Distillation WAN Bipartite Plotly Freesound FastAPI 公式 Dataset BeautifulSoup Domain uwsgi Video 净利润 Pandas Git Proxy Claude transformers EXCEL Web C++ 财报 LLM Qwen2 关于博主 LaTeX Firewall Image2Text BF16 InvalidArgumentError FlashAttention Tiktoken Animate Search 强化学习 LLAMA 算法题 Vim Website Linux Ubuntu Translation CEIR Nginx ResNet-50 GGML GoogLeNet Transformers Cloudreve Breakpoint Agent DeepSeek Template CTC scipy TSV Interview BTC Zip Safetensors HaggingFace Tracking TensorFlow Paddle Rebuttal Pickle Baidu Hilton SPIE 证件照 IndexTTS2 Password 图形思考法 Michelin 腾讯云 Food ChatGPT Shortcut torchinfo LeetCode SAM 云服务器 TTS Crawler Gemma NameSilo Paper OpenAI UI Pillow CV DeepStream Algorithm VGG-16 Qwen2.5 阿里云 CUDA Disk RAR API Card Statistics CAM OpenCV git-lfs Use Input CC 第一性原理 Magnet VSCode ModelScope Diagram Permission GPT4 PyTorch icon Attention OCR Llama Clash MD5 Python PDB XML COCO FP64 FP8 顶会 llama.cpp 版权 YOLO NLTK printf Quantization Knowledge Markdown Math 飞书 WebCrawler Anaconda PIP Ptyhon Bert Bitcoin Pytorch 图标 HuggingFace Mixtral Data Hungarian ONNX Excel PDF Random 音频 tqdm RGB Color SQLite Datetime 签证 News 搞笑 Miniforge v2ray Conda mmap uWSGI FP32 diffusers Vmess VPN Github logger JSON PyCharm tar UNIX TensorRT Logo 多进程 CLAP 报税 继承 Qwen Base64 QWEN Quantize Docker 递归学习法 GPTQ Sklearn Google XGBoost Numpy hf GIT 域名 CSV AI v0.dev Heatmap
站点统计

本站现有博文323篇,共被浏览795765

本站已经建立2493天!

热门文章
文章归档
回到顶部