EADST

Print Transformers PyTorch Model Information

import os
import re
import torch
from safetensors import safe_open
from safetensors.torch import load_file
import glob
from collections import defaultdict
import numpy as np

# Folder that main() passes to inspect_model_weights(); edit this path to
# point at the directory holding the *.bin / *.safetensors files to inspect.
model_dir = "/dfs/data/model_path_folder/"

def _classify_layer(key):
    """Map a tensor name to a coarse layer category by substring match."""
    if "attention" in key or "attn" in key:
        return "attention"
    if "mlp" in key or "ffn" in key:
        return "feed_forward"
    if "embed" in key:
        return "embedding"
    if "norm" in key or "ln" in key:
        return "normalization"
    return "other"


def _load_weights(file_path):
    """Load one weight file: ``*.bin`` through torch.load, anything else through safetensors."""
    if file_path.endswith('.bin'):
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only point this at trusted checkpoints, or pass
        # weights_only=True on PyTorch versions that support it.
        return torch.load(file_path, map_location='cpu')
    return load_file(file_path)


def _percent(part, whole):
    """Return ``part`` as a percentage of ``whole``; 0.0 when ``whole`` is zero."""
    return (part / whole) * 100 if whole else 0.0


def inspect_model_weights(directory_path):
    """
    Scan a folder for ``*.bin`` / ``*.safetensors`` files and print weight statistics.

    For every file found (non-recursive) the function prints its size, the
    number of entries it contains and per-tensor details, then prints a
    summary: total file size, total parameter count, parameters per layer
    category, parameters per dtype, and the ten largest tensor shapes.

    Files that fail to load, or that do not load as a dict, are reported and
    skipped. Entries that are not tensors are reported and skipped.

    Args:
        directory_path (str): Path of the folder containing the model files.

    Returns:
        None. All results are written to stdout.
    """
    # Collect candidate files; .bin files are listed before .safetensors files.
    bin_files = glob.glob(os.path.join(directory_path, "*.bin"))
    safetensors_files = glob.glob(os.path.join(directory_path, "*.safetensors"))

    all_files = bin_files + safetensors_files

    if not all_files:
        print(f"在 {directory_path} 中没有找到bin或safetensors文件")
        return

    print(f"找到 {len(all_files)} 个模型文件:")
    for idx, file_path in enumerate(all_files, start=1):
        print(f"{idx}. {os.path.basename(file_path)}")

    total_size = 0
    param_count = 0
    layer_stats = defaultdict(int)    # layer category -> parameter count
    tensor_types = defaultdict(int)   # dtype -> parameter count
    shape_info = defaultdict(list)    # torch.Size -> names of tensors with that shape

    for file_path in all_files:
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
        total_size += file_size

        print(f"\n检查文件: {os.path.basename(file_path)} ({file_size:.2f} MB)")

        # Best-effort: a broken file must not abort the scan of the others.
        try:
            weights = _load_weights(file_path)
        except Exception as e:
            print(f"  无法加载 {file_path}: {e}")
            continue

        # A pickled .bin can hold any object; only a name -> tensor mapping
        # can be analysed below.
        if not isinstance(weights, dict):
            print(f"  无法加载 {file_path}: 不是张量字典 ({type(weights).__name__})")
            continue

        print(f"  包含 {len(weights)} 个张量")
        for key, tensor in weights.items():
            if not isinstance(tensor, torch.Tensor):
                print(f"  - {key}: 跳过非张量条目 ({type(tensor).__name__})")
                continue

            # numel() is always a Python int. np.prod() on an empty shape
            # (0-dim tensor) returns the float 1.0 and would turn every
            # running total into a float.
            num_params = tensor.numel()
            param_count += num_params

            layer_stats[_classify_layer(str(key))] += num_params
            tensor_types[tensor.dtype] += num_params

            # torch.Size is a hashable tuple subclass, so it is used as the
            # key directly instead of being stringified and eval()'d later.
            shape_info[tensor.shape].append(key)

            # Print details while at most 10 distinct shapes have been seen
            # so far (across all files), and always for tensors above 1M
            # parameters.
            if len(shape_info) <= 10 or num_params > 1_000_000:
                print(f"  - {key}: 形状={tensor.shape}, 类型={tensor.dtype}, 参数数={num_params:,}")

    # Summary
    print("\n模型权重汇总:")
    print(f"总文件大小: {total_size:.2f} MB")
    print(f"总参数数量: {param_count:,}")

    print("\n按层类型划分的参数:")
    for layer_type, count in layer_stats.items():
        print(f"  {layer_type}: {count:,} 参数 ({_percent(count, param_count):.2f}%)")

    print("\n张量数据类型分布:")
    for dtype, count in tensor_types.items():
        print(f"  {dtype}: {count:,} 参数 ({_percent(count, param_count):.2f}%)")

    print("\n常见张量形状:")
    # Ordered by per-tensor parameter count (largest first), not by how often
    # the shape occurs; ties keep first-seen order because sorted() is stable.
    sorted_shapes = sorted(shape_info.items(), key=lambda item: item[0].numel(), reverse=True)
    for i, (shape, keys) in enumerate(sorted_shapes[:10]):
        num_params = shape.numel()
        percentage = _percent(num_params * len(keys), param_count)
        print(f"  {shape}: {len(keys)} 个张量, 每个 {num_params:,} 参数 (总共占 {percentage:.2f}%)")
        if i < 3:  # show example tensor names for the first three shapes only
            print(f"    例如: {', '.join(keys[:3])}" + ("..." if len(keys) > 3 else ""))

def main():
    """Inspect the weight files in the folder named by the module-level ``model_dir``."""
    # To ask for the folder interactively instead, read it with input() here.
    target_dir = model_dir
    inspect_model_weights(target_dir)

# Run the inspection only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
Python ModelScope UNIX AI HuggingFace GIT Qwen2.5 CLAP Plate FP64 Diagram Linux Hotel llama.cpp Excel XGBoost FP32 Quantization uWSGI Ptyhon Jetson 关于博主 Password Pandas Algorithm Food FastAPI Baidu PIP Shortcut Git Permission PDF logger VSCode Claude OpenCV Clash Django 音频 Data Docker Video FlashAttention Streamlit Quantize NLTK Vmess v0.dev mmap TSV Math UI BeautifulSoup LoRA 净利润 NLP 飞书 ONNX 财报 hf Random DeepStream CSV Plotly Search CEIR FP8 BTC 顶会 SPIE Markdown Card scipy Mixtral 报税 Llama Animate 多进程 图形思考法 IndexTTS2 Tiktoken Use NameSilo News 论文 Bipartite Hungarian PyTorch Sklearn 公式 Zip 论文速读 YOLO Attention Hilton Review CTC API Website LLAMA Paddle Domain ResNet-50 JSON InvalidArgumentError Disk PDB Gemma OCR ChatGPT Ubuntu SQL Pytorch SQLite v2ray Pillow Image2Text Rebuttal 多线程 Datetime VPN Statistics CC LLM Numpy 签证 Vim WAN MD5 TensorRT Firewall SAM 阿里云 版权 XML Qwen2 Transformers TTS 强化学习 Miniforge git-lfs TensorFlow GGML Distillation VGG-16 Web FP16 RAR Logo RGB Anaconda tar Tracking Breakpoint QWEN PyCharm Knowledge Nginx 证件照 OpenAI Color DeepSeek Pickle printf Paper Tensor WebCrawler CAM 算法题 Magnet torchinfo icon Freesound HaggingFace Windows LaTeX Bitcoin Template SVR GPT4 云服务器 Input Land Cloudreve 搞笑 uwsgi LeetCode EXCEL Google BF16 CV Jupyter COCO GPTQ C++ Qwen Michelin Bin Conda Translation 递归学习法 Bert Dataset 腾讯云 diffusers GoogLeNet transformers 第一性原理 域名 Base64 Crawler Agent Interview 图标 CUDA Safetensors Proxy 继承 Augmentation git tqdm Github Heatmap
站点统计

本站现有博文327篇,共被浏览833540次

本站已经建立2538天!

热门文章
文章归档
回到顶部