EADST

Print Transformers PyTorch Model Information

import os
import re
import torch
from safetensors import safe_open
from safetensors.torch import load_file
import glob
from collections import defaultdict
import numpy as np

# Default directory inspected by main(); point this at the folder holding the model files.
model_dir = "/dfs/data/model_path_folder/"

def _classify_layer(key):
    """Return a coarse layer category for a parameter name.

    Matching is plain substring search and the checks run in order, so the
    first matching category wins (a key containing both "attn" and "norm" is
    counted as "attention"). "ln" is a short substring, so it may also match
    names that are not normalization layers.
    """
    if "attention" in key or "attn" in key:
        return "attention"
    if "mlp" in key or "ffn" in key:
        return "feed_forward"
    if "embed" in key:
        return "embedding"
    if "norm" in key or "ln" in key:
        return "normalization"
    return "other"


def _percent(part, total):
    """Return ``part / total`` as a percentage, or 0.0 when ``total`` is 0."""
    return part / total * 100 if total else 0.0


def _load_weights(file_path):
    """Load a checkpoint file onto the CPU; exceptions propagate to the caller.

    ``.bin`` files are read with ``torch.load``; every other file is read with
    safetensors' ``load_file``.
    """
    if file_path.endswith('.bin'):
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only inspect checkpoints from trusted sources
        # (torch >= 1.13 accepts weights_only=True to restrict unpickling).
        return torch.load(file_path, map_location='cpu')
    return load_file(file_path)


def inspect_model_weights(directory_path, *, max_listed=10, large_tensor_threshold=1_000_000):
    """Find all .bin / .safetensors files in a folder and print weight statistics.

    For each file, the number of tensors is printed along with a listing of
    tensors. Then a summary is printed: total file size, total parameter
    count, parameters per coarse layer category, parameters per dtype, and
    the 10 largest tensor shapes.

    Args:
        directory_path (str): Folder containing the model files. Only the
            folder itself is searched, not its subfolders.
        max_listed (int): Number of leading tensors listed for each file.
        large_tensor_threshold (int): Tensors with more elements than this
            are listed even past ``max_listed``.

    Files that fail to load, or whose content is not a dict (e.g. a pickled
    non-dict object), are reported and skipped. Non-tensor values inside a
    dict are ignored.
    """
    # Sorted so sharded checkpoints are reported in a stable order.
    bin_files = sorted(glob.glob(os.path.join(directory_path, "*.bin")))
    safetensors_files = sorted(glob.glob(os.path.join(directory_path, "*.safetensors")))
    all_files = bin_files + safetensors_files

    if not all_files:
        print(f"在 {directory_path} 中没有找到bin或safetensors文件")
        return

    print(f"找到 {len(all_files)} 个模型文件:")
    for idx, file_path in enumerate(all_files, start=1):
        print(f"{idx}. {os.path.basename(file_path)}")

    total_size = 0
    param_count = 0
    layer_stats = defaultdict(int)   # layer category -> parameter count
    tensor_types = defaultdict(int)  # torch dtype -> parameter count
    shape_info = defaultdict(list)   # torch.Size -> names of tensors with that shape

    for file_path in all_files:
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
        total_size += file_size
        print(f"\n检查文件: {os.path.basename(file_path)} ({file_size:.2f} MB)")

        try:
            weights = _load_weights(file_path)
        except Exception as e:
            print(f"  无法加载 {file_path}: {e}")
            continue

        # A .bin file may hold an arbitrary pickled object rather than a state dict.
        if not isinstance(weights, dict):
            print(f"  跳过: 内容不是张量字典 ({type(weights).__name__})")
            continue

        tensors = [(key, value) for key, value in weights.items()
                   if isinstance(value, torch.Tensor)]
        print(f"  包含 {len(tensors)} 个张量")

        for idx, (key, tensor) in enumerate(tensors):
            # numel() is an exact int (1 for 0-dim tensors), unlike np.prod(()) == 1.0.
            num_params = tensor.numel()
            param_count += num_params
            layer_stats[_classify_layer(key)] += num_params
            tensor_types[tensor.dtype] += num_params
            # torch.Size is a hashable tuple subclass, so it can key the dict directly.
            shape_info[tensor.shape].append(key)

            # List the first max_listed tensors of this file, plus any large tensor.
            if idx < max_listed or num_params > large_tensor_threshold:
                print(f"  - {key}: 形状={tensor.shape}, 类型={tensor.dtype}, 参数数={num_params:,}")

    print("\n模型权重汇总:")
    print(f"总文件大小: {total_size:.2f} MB")
    print(f"总参数数量: {param_count:,}")

    print("\n按层类型划分的参数:")
    for layer_type, count in layer_stats.items():
        print(f"  {layer_type}: {count:,} 参数 ({_percent(count, param_count):.2f}%)")

    print("\n张量数据类型分布:")
    for dtype, count in tensor_types.items():
        print(f"  {dtype}: {count:,} 参数 ({_percent(count, param_count):.2f}%)")

    print("\n常见张量形状:")
    # Ordered by per-tensor element count (largest first), not by frequency.
    sorted_shapes = sorted(shape_info.items(), key=lambda item: item[0].numel(), reverse=True)
    for i, (shape, keys) in enumerate(sorted_shapes[:10]):
        num_params = shape.numel()
        percentage = _percent(num_params * len(keys), param_count)
        print(f"  {shape}: {len(keys)} 个张量, 每个 {num_params:,} 参数 (总共占 {percentage:.2f}%)")
        if i < 3:  # show example names for the first 3 shapes only
            print(f"    例如: {', '.join(keys[:3])}" + ("..." if len(keys) > 3 else ""))

def main(argv=None):
    """Script entry point: inspect a model folder and print its weight statistics.

    Args:
        argv (list[str] | None): Command-line arguments without the program
            name; defaults to ``sys.argv[1:]``. The first argument, if given,
            is the folder to inspect; otherwise the module-level ``model_dir``
            is used.
    """
    import sys

    args = sys.argv[1:] if argv is None else argv
    directory = args[0] if args else model_dir
    inspect_model_weights(directory)

# Run the inspection only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
Crawler SQLite Qwen2 IndexTTS2 ONNX OpenCV GoogLeNet hf Image2Text Docker Website Mixtral Bipartite Ubuntu scipy Plate git-lfs Freesound tqdm Transformers Firewall DeepStream ModelScope Vmess LoRA Qwen2.5 OCR FastAPI Tracking ChatGPT SVR WAN Interview Paddle Template Quantize 报税 Video QWEN 签证 Cloudreve C++ PDB SQL Conda Math git Git HuggingFace Miniforge TensorFlow COCO Random Sklearn Dataset 继承 llama.cpp 域名 Clash 多线程 BF16 强化学习 WebCrawler Input Google CEIR PIP XML Django Quantization Magnet EXCEL Ptyhon 顶会 净利润 云服务器 Pickle Rebuttal Review SAM 第一性原理 PDF Streamlit API icon Heatmap LeetCode TensorRT 多进程 Agent PyCharm Pytorch UNIX Color VPN BeautifulSoup Datetime Translation CV Pillow Gemma tar GPT4 Knowledge Diagram YOLO Algorithm v0.dev 证件照 FP64 版权 LLM Markdown 图形思考法 图标 Nginx InvalidArgumentError Python Augmentation Safetensors Anaconda Bitcoin Michelin Proxy FP16 Hotel MD5 Breakpoint Search Web Shortcut printf Password Plotly LLAMA Domain 财报 BTC News 飞书 递归学习法 HaggingFace Hilton Vim logger mmap FlashAttention TTS DeepSeek AI 公式 CTC JSON transformers uwsgi Baidu Tensor Pandas LaTeX Permission 关于博主 GGML CC SPIE NLTK 音频 腾讯云 RGB Statistics XGBoost Card VSCode CAM PyTorch Hungarian Disk OpenAI Distillation Bert Zip FP32 Claude UI Bin Data NLP Linux CLAP VGG-16 Land ResNet-50 GPTQ torchinfo Excel 搞笑 Attention Github Food GIT CUDA Jetson Base64 FP8 uWSGI 阿里云 Numpy 算法题 Jupyter TSV Animate Use Windows CSV Qwen v2ray Tiktoken Llama diffusers NameSilo Paper RAR Logo
站点统计

本站现有博文323篇,共被浏览796088

本站已经建立2493天!

热门文章
文章归档
回到顶部