企业营销类专业网站,网站建设工作室 怎么样,网站浏览器兼容性通用,商务网站建设与维护试题目录
模型格式介绍
加载以及保存
- 加载.safetensors文件#xff1a;
- 保存/加载.pth文件#xff1a;
- 保存/加载.ckpt文件#xff1a;
- 处理.bin文件#xff1a;
模型之间的互相转换
pytorch-lightning 和 pytorch
ckpt和safetensors 模型格式介绍
在大型深度…目录
模型格式介绍
加载以及保存
- 加载.safetensors文件
- 保存/加载.pth文件
- 保存/加载.ckpt文件
- 处理.bin文件
模型之间的互相转换
pytorch-lightning 和 pytorch
ckpt和safetensors 模型格式介绍
在大型深度学习模型的上下文中.safetensors、.bin 和 .pth ckpt 文件的用途和区别如下 .safetensors 文件 这是由 Hugging Face 推出的一种新型安全模型存储格式特别关注模型安全性、隐私保护和快速加载。它仅包含模型的权重参数而不包括执行代码这样可以减少模型文件大小提高加载速度。加载方式使用 Hugging Face 提供的相关API来加载 .safetensors 文件例如 safetensors.torch.load_file() 函数。 ckpt文件 ckpt 文件是 PyTorch Lightning 框架采用的模型存储格式它不仅包含了模型参数还包括优化器状态以及可能的训练元数据信息使得用户可以无缝地恢复训练或执行推理。 .bin 文件 通常是一种通用的二进制格式文件它可以用来存储任意类型的数据。在机器学习领域.bin 文件有时用于存储模型权重或其他二进制数据但并不特指PyTorch的官方标准格式。对于PyTorch而言如果用户自己选择将模型权重以二进制格式保存可能会使用 .bin 扩展名加载时需要自定义逻辑读取和应用这些权重到模型结构中。 .pth 文件 是 PyTorch 中用于保存模型状态的标准格式。主要用于保存模型的 state_dict包含了模型的所有可学习参数或者整个模型包括结构和参数。加载方式使用 PyTorch 的 torch.load() 函数直接加载 .pth 文件并通过调用 model.load_state_dict() 将加载的字典应用于模型实例。
总结起来
.safetensors 侧重于安全性和效率适合于那些希望快速部署且对安全有较高要求的场景尤其在Hugging Face生态中。.ckpt 文件是 PyTorch Lightning 框架采用的模型存储格式它不仅包含了模型参数还包括优化器状态以及可能的训练元数据信息使得用户可以无缝地恢复训练或执行推理。.bin 文件不是标准化的模型保存格式但在某些情况下可用于存储原始二进制权重数据加载时需额外处理。.pth 是PyTorch的标准模型保存格式方便模型的持久化和复用支持完整模型结构和参数的保存与恢复。
加载以及保存
- 加载.safetensors文件
# 用SDXL举例
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_filebase stabilityai/stable-diffusion-xl-base-1.0
repo ByteDance/SDXL-Lightning
ckpt /home/bino/svul/models/sdxl/sdxl_lightning_2step_unet.safetensors # Use the correct ckpt for your step setting!# Load model.
unet UNet2DConditionModel.from_config(base, subfolderunet).to(cuda, torch.float16)
unet.load_state_dict(load_file(ckpt, devicecuda))
# unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), devicecuda))
pipe StableDiffusionXLPipeline.from_pretrained(base, unetunet, torch_dtypetorch.float16, variantfp16).to(cuda)# Ensure sampler uses trailing timesteps.
pipe.scheduler EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacingtrailing)# Ensure using the same inference steps as the loaded model and CFG set to 0.
pipe(A girl smiling, num_inference_steps4, guidance_scale0).images[0].save(output.png)
- 保存/加载.pth文件 # 保存模型状态字典torch.save(model.state_dict(), model.pth)# 加载模型状态字典到已有模型结构中model TheModelClass(*args, **kwargs)model.load_state_dict(torch.load(model.pth))# 或者保存整个模型包括结构torch.save(model, model.pth)# 加载整个模型model torch.load(model.pth, map_locationdevice)
- 保存/加载.ckpt文件
import pytorch_lightning as pl# 定义一个 PyTorch Lightning 训练模块
class MyLightningModel(pl.LightningModule):def __init__(self):super().__init__()self.linear_layer nn.Linear(10, 1)self.loss_function nn.MSELoss()def forward(self, inputs):return self.linear_layer(inputs)def training_step(self, batch, batch_idx):features, targets batchpredictions self(features)loss self.loss_function(predictions, targets)self.log(train_loss, loss)return loss# 初始化 PyTorch Lightning 模型
lightning_model MyLightningModel()# 配置 ModelCheckpoint 回调以定期保存最佳模型至 .ckpt 文件
checkpoint_callback pl.callbacks.ModelCheckpoint(monitorval_loss,filenamebest-model-{epoch:02d}-{val_loss:.2f},save_top_k3,modemin
)# 创建训练器并启动模型训练
trainer pl.Trainer(callbacks[checkpoint_callback],max_epochs10
)
trainer.fit(lightning_model)# 从 .ckpt 文件加载最优模型权重
best_model MyLightningModel.load_from_checkpoint(checkpoint_pathbest-model.ckpt)# 使用加载的 .ckpt 文件中的模型进行预测
sample_input torch.randn(1, 10)
predicted_output best_model(sample_input)
print(predicted_output)
在此示例中我们首先定义了一个 PyTorch Lightning 模块该模块集成了模型训练的逻辑。然后我们配置了 ModelCheckpoint 回调函数在训练过程中按照验证损失自动保存最佳模型至 .ckpt 文件。接着我们展示了如何加载 .ckpt 文件中的最优模型权重并利用加载后的模型对随机输入数据进行预测同样输出预测结果。值得注意的是由于 .ckpt 文件完整记录了训练状态它在实际应用中常被用于模型微调和进一步训练。
- 处理.bin文件
如果.bin文件是纯二进制权重文件加载时需要知道模型结构并且手动将权重加载到对应的层中例如 # 假设已经从.bin文件中读取到了模型权重数据weights_data load_binary_weights(weights.bin)# 手动初始化模型并加载权重model TheModelClass(*args, **kwargs)for name, param in model.named_parameters():if name in weights_mapping: # 需要预先知道权重映射关系param.data.copy_(weights_data[weights_mapping[name]])
模型之间的互相转换
pytorch-lightning 和 pytorch
由于 PyTorch Lightning 模型本身就是 PyTorch 模型因此不存在严格意义上的转换过程。你可以直接通过 LightningModule 中定义的神经网络层来进行保存和加载就像普通的 PyTorch 模型一样
# 假设 model 是一个 PyTorch Lightning 模型实例
model MyLightningModel()# 保存模型权重
torch.save(model.state_dict(), lightning_model.pth)# 加载到一个新的 PyTorch 模型实例
new_model MyLightningModel()
new_model.load_state_dict(torch.load(lightning_model.pth))# 或者加载到一个普通的 PyTorch Module 实例假设结构一致
plain_pytorch_model MyPlainPytorchModel()
plain_pytorch_model.load_state_dict(torch.load(lightning_model.pth))
ckpt和safetensors
转换后的模型在stable-diffussion-webui中使用过没有问题不知道有没有错误或者没转换成功
import torch
import os
import safetensors
from typing import Dict, List, Optional, Set, Tuple
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_filedef ckpt2safetensors():loaded torch.load(v1-5-pruned-emaonly.ckpt)if state_dict in loaded:loaded loaded[state_dict]safetensors.torch.save_file(loaded, v1-5-pruned-emaonly.safetensors)def st2ckpt():# 加载 .safetensors 文件data safetensors.torch.load_file(v1-5-pruned-emaonly.safetensors.bk)data[state_dict] data# 将数据保存为 .ckpt 文件torch.save(data, os.path.splitext(v1-5-pruned-emaonly.safetensors)[0] .ckpt)