EADST

Pytorch GPTQ Dequantizing Function

Pytorch GPTQ Dequantizing Function

Here is Python code that dequantizes the weights of a GPTQ-quantized model back into a regular PyTorch FP16 checkpoint.

import torch

# Function: Dequantize quantized weights
def dequantization(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):
    """Unpack GPTQ int32-packed weights and rescale them to floating point.

    Every int32 in ``qweight`` and ``qzeros`` carries ``32 // bits`` quantized
    values. They are extracted with right shifts and a bit mask, and each
    input row is then dequantized as ``scale[group] * (q - (zero + 1))``.

    Args:
        qweight: 2-D int32 tensor of packed weights, packed along dim 0
            (presumably ``(in_features // (32 // bits), out_features)`` as
            written by GPTQ tooling -- confirm against the checkpoint).
        qzeros: 2-D int32 tensor of packed zero points, packed along dim 1,
            one row per quantization group.
        scales: Per-group scales whose last dimension is ``out_features``.
            The result takes its floating-point dtype from this tensor.
        g_idx: Optional 1-D integer tensor giving the group index of every
            input row. It is honoured when it has one entry per unpacked row;
            this is what makes act-order (non-contiguous group) checkpoints
            dequantize correctly. ``None``, a tensor of the wrong length, or
            an all-zero tensor while several groups exist (assumed to be a
            placeholder -- confirm for your checkpoints) falls back to
            contiguous groups of ``group_size`` rows.
        bits: Quantization width; must divide 32 (2, 4 or 8).
        group_size: Rows per group, used only for the contiguous fallback.
            A non-positive value maps every row to group 0.
        device: Unused. Kept so existing callers keep working; the
            computation runs on the device of ``qweight``.

    Returns:
        A ``(out_features, in_features)`` tensor (a transposed view).
    """
    pack_factor = 32 // bits
    mask = (2 ** bits) - 1
    # 8-bit zero points reach 256 after the +1 below, which overflows int8.
    unpacked_dtype = torch.int16 if bits == 8 else torch.int8

    # Bit offsets of the packed values, created next to the packed data so
    # GPU-resident inputs do not hit a CPU/GPU device mismatch.
    shifts = torch.arange(0, 32, bits, dtype=torch.int32, device=qweight.device)

    # Unpack zero points: (groups, out // pack) -> (groups, out).
    zeros = torch.bitwise_right_shift(
        qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
        shifts.reshape(1, 1, -1),
    ).to(unpacked_dtype)
    zeros = torch.bitwise_and(zeros, mask) + 1
    zeros = zeros.reshape(zeros.shape[0], -1)

    # Unpack weights: (in // pack, out) -> (in, out).
    weight = torch.bitwise_right_shift(
        qweight.unsqueeze(1).expand(-1, pack_factor, -1),
        shifts.reshape(1, -1, 1),
    ).to(unpacked_dtype)
    weight = torch.bitwise_and(weight, mask)
    weight = weight.reshape(-1, weight.shape[2])
    in_features = weight.shape[0]

    scales = scales.reshape(-1, scales.shape[-1])
    n_groups = scales.shape[0]

    # Decide which group each input row belongs to.
    g_idx_usable = (
        g_idx is not None
        and g_idx.numel() == in_features
        and (n_groups == 1 or bool(g_idx.any()))
    )
    if g_idx_usable:
        row_group = g_idx.reshape(-1).to(device=weight.device, dtype=torch.long)
    elif group_size > 0:
        row_group = torch.div(
            torch.arange(in_features, device=weight.device),
            group_size,
            rounding_mode="floor",
        )
    else:
        row_group = torch.zeros(in_features, dtype=torch.long, device=weight.device)

    # Apply the dequantization formula row by row via the group lookup.
    weight = scales[row_group] * (weight - zeros[row_group])

    # Return the transposed weight
    return weight.transpose(0, 1)

# Function: Load quantized model and perform dequantization
def get_pytorch_bin(path="./your_model_folder/gptq_model-4bit-128g.bin",
                    save_path="./your_model_folder/pytorch_model.bin"):
    """Convert a GPTQ checkpoint into a plain PyTorch state dict on disk.

    Every ``*.qweight`` entry is dequantized with ``dequantization`` and
    stored under the matching ``*.weight`` key. The ``.qzeros``, ``.scales``
    and ``.g_idx`` entries are consumed during that step and not copied. The
    ``o_proj``/``up_proj``/``down_proj``/``gate_proj`` biases are dropped
    (presumably placeholders added by the quantized layers -- confirm for
    your model). All other entries are copied unchanged.

    Args:
        path: Location of the quantized checkpoint to read.
        save_path: Location the converted checkpoint is written to.

    Raises:
        KeyError: If a ``.qweight`` entry has no matching ``.qzeros`` or
            ``.scales`` entry.
    """
    # Keys containing any of these substrings never reach the output.
    skipped_markers = (
        ".qzeros", ".scales", ".g_idx",
        "o_proj.bias", "up_proj.bias", "down_proj.bias", "gate_proj.bias",
    )

    # NOTE(review): torch.load unpickles the file; only use it on checkpoints
    # from a trusted source.
    state_dict = torch.load(path, map_location="cpu")

    # Dictionary to store processed weights
    tensors = {}
    for key, value in state_dict.items():
        if any(marker in key for marker in skipped_markers):
            continue

        if ".qweight" in key:
            qzeros = state_dict[key.replace(".qweight", ".qzeros")]
            scales = state_dict[key.replace(".qweight", ".scales")]
            # Some checkpoints carry no group index; pass None in that case
            # instead of failing with a KeyError.
            g_idx = state_dict.get(key.replace(".qweight", ".g_idx"))
            value = dequantization(value, qzeros, scales, g_idx)
            key = key.replace(".qweight", ".weight")

        tensors[key] = value

    # Print the number of processed weights and save as a new model file
    print(len(tensors))
    torch.save(tensors, save_path)

# Script entry point: run the conversion only when executed directly
if __name__ == "__main__":
    get_pytorch_bin()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
公式 ONNX RAR Paddle Agent Augmentation Mixtral Linux Paper DeepSeek SQL Website WebCrawler Animate Python RGB Streamlit Use Base64 LoRA Image2Text CV Docker Logo PIP Bitcoin Hungarian FP32 Jupyter XML Domain XGBoost SQLite 递归学习法 Disk WAN Hotel VSCode BeautifulSoup Transformers Tensor GGML tar Bin Llama QWEN Ubuntu Input 飞书 SAM Data 图形思考法 Statistics CSV Excel InvalidArgumentError Ptyhon LaTeX 证件照 Diagram Interview Dataset uwsgi Markdown FP64 DeepStream Translation Clash Safetensors 报税 PyCharm Quantization Claude Attention 财报 EXCEL Numpy Zip Tiktoken Hilton Anaconda CUDA Quantize CAM Pytorch FP16 tqdm FastAPI Heatmap Git Password AI CEIR GPT4 hf ModelScope PyTorch Math Django Michelin Qwen2 BF16 Card YOLO Magnet v0.dev Knowledge Color Gemma API Github Miniforge llama.cpp C++ 净利润 Bipartite scipy Random Plate 图标 Nginx CLAP git-lfs LLM uWSGI IndexTTS2 Bert MD5 SVR 多线程 Algorithm Windows News 第一性原理 GIT TTS Search Pillow COCO JSON Breakpoint 腾讯云 Vmess PDF Freesound Rebuttal NLTK TSV logger ChatGPT Proxy mmap Sklearn transformers OpenCV v2ray ResNet-50 Datetime Conda Web 关于博主 FP8 顶会 继承 音频 Template Video 阿里云 Permission icon 云服务器 SPIE 强化学习 git Qwen2.5 CTC diffusers Cloudreve Distillation FlashAttention GPTQ NLP NameSilo Pandas 签证 HaggingFace OCR VPN GoogLeNet 版权 Shortcut TensorFlow BTC LLAMA HuggingFace UNIX 多进程 TensorRT VGG-16 Google Tracking PDB Review Land CC Qwen Baidu 域名 Plotly printf 搞笑 Vim LeetCode torchinfo Food UI OpenAI Jetson Firewall Pickle Crawler 算法题
站点统计

本站现有博文323篇,共被浏览795412

本站已经建立2493天!

热门文章
文章归档
回到顶部