EADST

Print Transformers PyTorch Model Information

import os
import re
import torch
from safetensors import safe_open
from safetensors.torch import load_file
import glob
from collections import defaultdict
import numpy as np

# Directory that main() passes to inspect_model_weights(); edit this path to
# point at the folder containing the *.bin / *.safetensors weight files.
model_dir = "/dfs/data/model_path_folder/"

def _load_weights(file_path):
    """Load one weight file; return its contents, or None if loading fails.

    Files ending in ``.bin`` are read with ``torch.load`` (mapped to CPU);
    everything else is read as safetensors via ``load_file``. Any failure is
    reported on stdout and swallowed so the remaining files are still scanned.
    """
    try:
        if file_path.endswith('.bin'):
            # NOTE(security): torch.load unpickles the file, which can run
            # arbitrary code. Only inspect checkpoints from a trusted source,
            # or pass weights_only=True if the installed torch supports it.
            return torch.load(file_path, map_location='cpu')
        return load_file(file_path)
    except Exception as e:
        print(f"  无法加载 {file_path}: {e}")
        return None


def _classify_layer(key):
    """Map a parameter name to a coarse layer category by substring match."""
    if "attention" in key or "attn" in key:
        return "attention"
    if "mlp" in key or "ffn" in key:
        return "feed_forward"
    if "embed" in key:
        return "embedding"
    if "norm" in key or "ln" in key:
        return "normalization"
    return "other"


def _percent(part, total):
    """Return ``part`` as a percentage of ``total`` (0.0 when total is 0)."""
    return (part / total) * 100 if total else 0.0


def inspect_model_weights(directory_path):
    """
    Scan a folder for ``*.bin`` / ``*.safetensors`` files and print weight info.

    For every file found, the tensors are loaded on CPU and the function
    prints per-tensor details followed by a summary: total file size, total
    parameter count, parameters per layer category, parameters per dtype and
    the ten largest tensor shapes.

    Args:
        directory_path (str): Folder containing the model weight files.

    Returns:
        None. All results are written to stdout.
    """
    # Collect every .bin and .safetensors file directly inside the folder.
    bin_files = glob.glob(os.path.join(directory_path, "*.bin"))
    safetensors_files = glob.glob(os.path.join(directory_path, "*.safetensors"))
    all_files = bin_files + safetensors_files

    if not all_files:
        print(f"在 {directory_path} 中没有找到bin或safetensors文件")
        return

    print(f"找到 {len(all_files)} 个模型文件:")
    for idx, file_path in enumerate(all_files, start=1):
        print(f"{idx}. {os.path.basename(file_path)}")

    total_size = 0
    param_count = 0
    layer_stats = defaultdict(int)    # layer category -> parameter count
    tensor_types = defaultdict(int)   # dtype -> parameter count
    shape_info = defaultdict(list)    # torch.Size -> names of tensors with it

    for file_path in all_files:
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
        total_size += file_size

        print(f"\n检查文件: {os.path.basename(file_path)} ({file_size:.2f} MB)")

        weights = _load_weights(file_path)
        if weights is None:
            continue

        print(f"  包含 {len(weights)} 个张量")
        for key, tensor in weights.items():
            # Checkpoints may carry non-tensor metadata; it has no shape/dtype.
            if not isinstance(tensor, torch.Tensor):
                print(f"  跳过非张量条目: {key}")
                continue

            # numel() is an exact Python int, including 1 for 0-dim tensors,
            # so the running totals never turn into floats or overflow.
            num_params = tensor.numel()
            param_count += num_params

            layer_stats[_classify_layer(key)] += num_params
            tensor_types[tensor.dtype] += num_params

            # torch.Size is hashable, so it can key the dict directly and no
            # string round-trip is needed to recover the dimensions later.
            shape_info[tensor.shape].append(key)

            # Print details while at most 10 distinct shapes have been seen,
            # and always for tensors with more than one million parameters.
            if len(shape_info) <= 10 or num_params > 1_000_000:
                print(f"  - {key}: 形状={tensor.shape}, 类型={tensor.dtype}, 参数数={num_params:,}")

    # Summary section.
    print("\n模型权重汇总:")
    print(f"总文件大小: {total_size:.2f} MB")
    print(f"总参数数量: {param_count:,}")

    print("\n按层类型划分的参数:")
    for layer_type, count in layer_stats.items():
        percentage = _percent(count, param_count)
        print(f"  {layer_type}: {count:,} 参数 ({percentage:.2f}%)")

    print("\n张量数据类型分布:")
    for dtype, count in tensor_types.items():
        percentage = _percent(count, param_count)
        print(f"  {dtype}: {count:,} 参数 ({percentage:.2f}%)")

    print("\n常见张量形状:")
    # Shapes are ordered by per-tensor element count, largest first; ties keep
    # their first-seen order. Only the top 10 are listed.
    sorted_shapes = sorted(shape_info.items(), key=lambda item: item[0].numel(), reverse=True)
    for i, (shape, keys) in enumerate(sorted_shapes[:10]):
        num_params = shape.numel()
        percentage = _percent(num_params * len(keys), param_count)
        print(f"  {shape}: {len(keys)} 个张量, 每个 {num_params:,} 参数 (总共占 {percentage:.2f}%)")
        if i < 3:  # show example tensor names only for the first 3 listed shapes
            print(f"    例如: {', '.join(keys[:3])}" + ("..." if len(keys) > 3 else ""))

def main():
    """Inspect the weight files in the module-level ``model_dir`` folder."""
    # The directory is fixed by the module constant; swap in an input() prompt
    # here if an interactive path is preferred.
    target_dir = model_dir
    inspect_model_weights(target_dir)

# Run the inspection only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
XML FP64 Algorithm Translation Use Video Bert mmap LLAMA Random WebCrawler Bitcoin git Knowledge Miniforge 顶会 Paddle C++ Color PyTorch Plate 域名 Land Template Paper Quantization Baidu Qwen2 Conda LoRA HaggingFace Base64 VPN VSCode API Hilton 证件照 EXCEL Interview Zip Hungarian Disk MD5 Magnet Data BF16 LLM Safetensors Docker Tiktoken Linux Augmentation DeepStream ONNX Bin Numpy ResNet-50 tqdm Permission tar 签证 Clash Python Freesound TensorFlow Search IndexTTS2 Markdown 腾讯云 Quantize Transformers 图标 CSV Rebuttal Streamlit Attention CUDA torchinfo 继承 Cloudreve XGBoost Statistics Tensor PIP Distillation Vim UI Diagram Firewall Llama InvalidArgumentError FlashAttention Claude 财报 搞笑 云服务器 GIT Proxy 音频 Domain 多进程 UNIX PDF QWEN Card diffusers NLP OpenAI Mixtral News Sklearn Heatmap YOLO Git Qwen Vmess Crawler 关于博主 算法题 Anaconda logger Ubuntu Agent Image2Text transformers HuggingFace 强化学习 Logo Web uwsgi FastAPI NLTK Password Animate GPT4 Pickle CTC Excel GPTQ 报税 净利润 Hotel Website Datetime Math Google PyCharm 飞书 SQL Jetson CEIR Jupyter Dataset v2ray CAM FP32 CLAP GoogLeNet OpenCV 递归学习法 ChatGPT Github BeautifulSoup Pytorch Input PDB git-lfs Django ModelScope SAM 图形思考法 TensorRT CV Pillow Nginx llama.cpp Pandas SQLite LeetCode Review Plotly LaTeX OCR 阿里云 SPIE NameSilo TTS Windows 第一性原理 Breakpoint BTC Qwen2.5 SVR RAR v0.dev FP16 JSON printf 多线程 FP8 Ptyhon Food CC Gemma Tracking GGML uWSGI RGB VGG-16 公式 Michelin AI hf WAN DeepSeek icon COCO 版权 Bipartite scipy Shortcut TSV
站点统计

本站现有博文323篇,共被浏览796018次。

本站已经建立2493天!

热门文章
文章归档
回到顶部