EADST

Pytorch GPTQ Dequantizing Function

Pytorch GPTQ Dequantizing Function

Here is Python code that dequantizes the packed weights of a GPTQ-quantized model back into PyTorch FP16 tensors.

import torch

# Function: Dequantize quantized weights
def dequantization(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):
    """Unpack GPTQ-packed integer weights and rescale them to floating point.

    Every 32-bit word holds ``32 // bits`` quantized values.  ``qweight`` is
    unpacked along dim 0 and ``qzeros`` along dim 1; the stored zero points
    are incremented by one before use.  Each unpacked row ``r`` is then
    dequantized as ``scales[g] * (weight[r] - zeros[g])`` where ``g`` is the
    group of that row.

    Args:
        qweight: 2-D packed weight tensor (expected to be int32), shape
            ``(in_features * bits // 32, out_features)``.
        qzeros: 2-D packed zero-point tensor (expected to be int32), shape
            ``(num_groups, out_features * bits // 32)``.
        scales: Per-group scales; reshaped to ``(num_groups, out_features)``.
        g_idx: 1-D integer tensor giving the group index of every unpacked
            weight row, or ``None``.  When it is ``None``, or all zeros while
            ``scales`` holds more than one group (treated as an unfilled
            placeholder), rows are assigned to groups contiguously in chunks
            of ``group_size``.
        bits: Width of one quantized value; must be 1, 2, 4 or 8.
        group_size: Rows per group for the contiguous layout.  Ignored when
            ``g_idx`` is used.
        device: Unused.  Kept for backward compatibility; the computation
            runs on the device of ``qweight`` and all tensor arguments are
            expected to live on that same device.

    Returns:
        The dequantized weight transposed to ``(out_features, in_features)``
        (a non-contiguous view).  Its dtype follows PyTorch type promotion
        with ``scales``.

    Raises:
        ValueError: If ``bits`` is unsupported or ``g_idx`` does not have one
            entry per unpacked weight row.
    """
    if bits not in (1, 2, 4, 8):
        raise ValueError(f"bits must be one of 1, 2, 4 or 8, got {bits}")

    pack_factor = 32 // bits
    mask = (2 ** bits) - 1
    # int8 cannot hold the 8-bit zero point after the +1 offset (up to 256).
    unpacked_dtype = torch.int16 if bits == 8 else torch.int8

    # Shift amounts for each value inside a 32-bit word.  Built on the same
    # device as the packed data so that GPU inputs do not hit a device
    # mismatch.
    wf = torch.tensor(
        list(range(0, 32, bits)), dtype=torch.int32, device=qweight.device
    ).unsqueeze(0)

    # Unpack zero points: (groups, packed_cols) -> (groups, out_features).
    zeros = torch.bitwise_right_shift(
        torch.unsqueeze(qzeros, 2).expand(-1, -1, pack_factor), wf.unsqueeze(0)
    ).to(unpacked_dtype)
    torch.bitwise_and(zeros, mask, out=zeros)
    zeros = zeros + 1
    zeros = zeros.reshape(zeros.shape[0], zeros.shape[1] * zeros.shape[2])

    scales = scales.reshape(-1, scales.shape[-1])

    # Unpack weights: (packed_rows, out_features) -> (in_features, out_features).
    weight = torch.bitwise_right_shift(
        torch.unsqueeze(qweight, 1).expand(-1, pack_factor, -1), wf.unsqueeze(-1)
    ).to(unpacked_dtype)
    torch.bitwise_and(weight, mask, out=weight)
    weight = weight.reshape(-1, weight.shape[2])

    # Decide whether g_idx carries a usable row -> group mapping.
    row_groups = None
    if g_idx is not None:
        # With several groups an all-zero index would map every row to
        # group 0; treat that as a placeholder and use the contiguous layout.
        is_placeholder = scales.shape[0] > 1 and not bool(torch.any(g_idx != 0))
        if not is_placeholder:
            if g_idx.numel() != weight.shape[0]:
                raise ValueError(
                    f"g_idx has {g_idx.numel()} entries but the unpacked "
                    f"weight has {weight.shape[0]} rows"
                )
            row_groups = g_idx.reshape(-1).to(device=weight.device, dtype=torch.long)

    if row_groups is not None:
        # Gather the scale / zero point of each row's group; this is what
        # makes act-order (permuted group) checkpoints dequantize correctly.
        weight = scales[row_groups] * (weight - zeros[row_groups])
    else:
        # Contiguous layout: rows [k * group_size, (k + 1) * group_size)
        # belong to group k; broadcast the per-group parameters.
        weight = weight.reshape(-1, group_size, weight.shape[1])
        weight = scales.unsqueeze(1) * (weight - zeros.unsqueeze(1))
        weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

    # Return the transposed weight
    return weight.transpose(0, 1)

# Function: Load quantized model and perform dequantization
def get_pytorch_bin(path="./your_model_folder/gptq_model-4bit-128g.bin",
                    output_path="./your_model_folder/pytorch_model.bin",
                    bits=4, group_size=128):
    """Convert a GPTQ checkpoint into a plain dequantized state dict on disk.

    Every ``*.qweight`` entry is dequantized with ``dequantization`` and
    stored under the matching ``*.weight`` key.  The auxiliary ``.qzeros``,
    ``.scales`` and ``.g_idx`` entries and the ``o_proj`` / ``up_proj`` /
    ``down_proj`` / ``gate_proj`` biases are dropped.  All other entries are
    copied through unchanged.  The number of stored tensors is printed
    before the result is saved.

    Args:
        path: Checkpoint to read with ``torch.load``.
        output_path: Destination file written with ``torch.save``.
        bits: Quantization bit width forwarded to ``dequantization``.
        group_size: Group size forwarded to ``dequantization``.
    """
    skipped_aux = (".qzeros", ".scales", ".g_idx")
    skipped_bias = ("o_proj.bias", "up_proj.bias", "down_proj.bias", "gate_proj.bias")

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code.  Only run this on checkpoints from a trusted source.
    state = torch.load(path, map_location="cpu")

    # Dictionary to store processed weights
    tensors = {}

    for key, value in state.items():
        # Quantization side tensors are consumed alongside their .qweight
        # entry below, so they are not copied on their own.
        if any(marker in key for marker in skipped_aux):
            continue
        if any(marker in key for marker in skipped_bias):
            continue

        out_key = key
        if ".qweight" in key:
            qzeros = state[key.replace(".qweight", ".qzeros")]
            scales = state[key.replace(".qweight", ".scales")]
            # Some checkpoints carry no group index; pass None instead of
            # failing with a KeyError.
            g_idx = state.get(key.replace(".qweight", ".g_idx"))
            value = dequantization(
                value, qzeros, scales, g_idx, bits=bits, group_size=group_size
            )
            out_key = key.replace(".qweight", ".weight")

        tensors[out_key] = value

    # Print the number of processed weights and save as a new model file
    print(len(tensors))
    torch.save(tensors, output_path)

# Script entry point: run the conversion only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    get_pytorch_bin()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
CSV Vim 报税 tar Attention Firewall 顶会 Breakpoint FP8 CC PyCharm Zip FastAPI Review 论文速读 VSCode Website TSV HuggingFace Freesound Google TTS uWSGI Llama Video GGML Github QWEN CTC ONNX Bipartite FP32 PyTorch CUDA Translation git-lfs Diagram Docker Bitcoin Permission 搞笑 Color WebCrawler Markdown Hotel tqdm Bin ResNet-50 Knowledge Python SAM SPIE FP16 公式 IndexTTS2 Data Animate Quantize Plotly Nginx Algorithm Ptyhon JSON PDB SQL Plate uwsgi Hungarian Sklearn DeepSeek Crawler 域名 Cloudreve LoRA 证件照 printf Pytorch Safetensors HaggingFace GIT ChatGPT Password Vmess InvalidArgumentError News RAR MD5 C++ Disk Pillow transformers BF16 飞书 UI Domain CLAP VPN mmap logger 图形思考法 GPTQ Michelin OpenAI Tiktoken BeautifulSoup 算法题 SQLite NameSilo Input Streamlit diffusers Anaconda v2ray FP64 LLM BTC Qwen2 Linux Image2Text Git torchinfo Interview Tracking DeepStream 论文 Windows GoogLeNet UNIX 签证 Claude llama.cpp Land Ubuntu Web Quantization WAN Heatmap Paddle Transformers LaTeX CAM YOLO 音频 继承 PIP 关于博主 Hilton Datetime Base64 ModelScope Search Math CV VGG-16 Distillation Magnet OCR Baidu Jetson Shortcut icon 云服务器 XGBoost 强化学习 Django XML Jupyter Food PDF SVR Rebuttal LLAMA AI Use hf Conda NLP NLTK 净利润 Proxy Tensor scipy git GPT4 Excel Clash CEIR 多进程 Statistics TensorRT 版权 Qwen2.5 v0.dev API Agent Qwen 多线程 Dataset Paper COCO 第一性原理 TensorFlow Mixtral 阿里云 Logo 财报 腾讯云 RGB Pickle Bert Miniforge Random Pandas LeetCode 递归学习法 EXCEL Numpy Gemma Card 图标 Template OpenCV FlashAttention Augmentation
站点统计

本站现有博文327篇,共被浏览833174

本站已经建立2538天!

热门文章
文章归档
回到顶部