PyTorch DDP 多进程训练在 Kaggle 笔记本中的正确启动方式
技术百科
霞舞
发布时间:2026-01-01
浏览: 次 在 kaggle 等基于 jupyter 的环境中直接运行 pytorch ddp(distributeddataparallel)多进程代码会因 `__main__` 模块序列化失败而报错;根本解决方案是将 ddp 主逻辑写入独立 `.py` 文件,并通过命令行方式执行,避开 notebook 的模块上下文限制。
PyTorch 的 torch.multiprocessing.spawn 要求被启动的函数(如 main)必须可被子进程通过 pickle 反序列化——这在标准 Python 脚本中自然成立,因为 if __name__ == "__main__": 块内定义的函数属于顶层模块 __main__。但在 Kaggle 或 Jupyter Notebook 中,整个 cell 代码实际运行在
AttributeError: Can't get attribute 'main' on
✅ 正确做法:分离定义与执行
将 DDP 训练逻辑封装为标准 .py 文件,而非在 notebook cell 中直接调用 mp.spawn()。
✅ 实施步骤(Kaggle 环境)
-
使用 %%writefile 魔法命令创建独立脚本
在 notebook 新建 cell,粘贴并保存完整 DDP 代码(参考 PyTorch 官方示例),顶部添加 %%writefile ddp.py:
%%writefile ddp.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import torch.multiprocessing as mp
from torchvision import datasets, transforms
import os
def main(rank, world_size, epochs=5, batch_size=32, lr=1e-3):
# 初始化进程组
dist.init_process_group(
backend="nccl",
init_method="env://",
world_size=world_size,
rank=rank
)
# 设置设备
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
# 构建模型、数据集、优化器等(此处省略细节)
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)).to(device)
model = DDP(model, device_ids=[rank])
train_dataset = datasets.MNIST("./data", train=True, download=True,
transform=transforms.ToTensor())
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
optimizer = optim.Adam(model.parameters(), lr=lr)
for epoch in range(epochs):
model.train()
sampler.set_epoch(epoch) # 关键:确保每个 epoch 数据打乱一致
for data, target in tr
ain_loader:
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data.view(data.size(0), -1))
loss = nn.functional.cross_entropy(output, target)
loss.backward()
optimizer.step()
dist.destroy_process_group()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--world_size", type=int, default=torch.cuda.device_count())
args = parser.parse_args()
# 注意:Kaggle 中需显式设置环境变量(spawn 自动读取)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
mp.spawn(main, args=(args.world_size, 5, 32, 1e-3), nprocs=args.world_size, join=True)-
在另一个 cell 中执行脚本
使用系统命令运行,绕过 notebook 解释器上下文:
!python -W ignore ddp.py
⚠️ 注意事项:务必设置 MASTER_ADDR 和 MASTER_PORT:spawn 依赖这些环境变量初始化 NCCL 后端,Kaggle 默认未设置。避免在 notebook 中直接调用 mp.spawn():即使加了 if __name__ == "__main__":,notebook 的 __main__ 仍不可序列化。-W ignore 是可选的:用于抑制 PyTorch 分布式警告(如 UserWarning: ... is deprecated),提升日志可读性。单节点多卡适用:本方案专为 Kaggle 提供的 2×T4 场景设计;跨节点需额外配置 MASTER_ADDR 和网络互通。
该方法严格遵循 Python 多进程的“spawn”启动方式语义,确保每个子进程从干净的 .py 文件入口重新导入模块,彻底规避 AttributeError。这是在受限 notebook 环境中安全启用 PyTorch DDP 的工业级实践。
# ai
# 后端
# python
# 环境变量
# pytorch
相关栏目:
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
AI推广<?muma echo $count; ?>
】
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
SEO优化<?muma echo $count; ?>
】
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
技术百科<?muma echo $count; ?>
】
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
谷歌推广<?muma echo $count; ?>
】
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
百度推广<?muma echo $count; ?>
】
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
网络营销<?muma echo $count; ?>
】
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
案例网站<?muma echo $count; ?>
】
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
精选文章<?muma echo $count; ?>
】
相关推荐
- Win11截图快捷键是什么_Win11自带截图工具
- php与c语言在嵌入式中有何区别_对比两者在硬件控
- 如何在Golang中实现服务熔断与限流_Golan
- Win11如何设置ipv6 Win11开启IPv6
- 如何使用Golang处理静态文件缓存_提高页面加载
- 如何在Golang中实现邮件发送功能_Golang
- 如何优化Golang内存分配与GC调度_Golan
- Win11色盲模式怎么开_Win11屏幕颜色滤镜设
- Win11怎么关闭资讯和兴趣_Windows11任
- Windows10电脑怎么设置文件权限_Win10
- Golang如何遍历目录文件_Golang fil
- 电脑无法识别U盘怎么办 Windows磁盘管理与驱
- 如何在 Python 中将 ISO 8601 时间
- 如何在 Go 中高效缓存与分发网络视频流
- Win10路由器怎么隐藏ssid Win10隐藏w
- Win11怎么清理C盘OneDrive缓存_Win
- Windows10怎么卸载预装软件_Windows
- LINUX如何开放防火墙端口_Linux fire
- 如何在Golang中实现文件下载_Golang文件
- php8.4匿名类怎么用_php8.4匿名类创建与
- Python文件和流处理指南_高效读写大体积数据文
- Win11怎么开启远程桌面_Win11系统远程桌面
- C++如何编写函数模板?(泛型编程入门)
- 如何关闭Win10自动更新更新_Win10系统自动
- windows如何备份注册表_windows导出和
- 为什么Go需要go mod文件_Go go mod
- Win11怎么设置声音输出设备_Windows11
- Win11声音忽大忽小怎么办 Win11音频增强功
- PhpStorm怎么调试PHP代码_PhpStor
- Win11怎么关闭系统透明度_Windows11个
- 如何在Golang中验证模块完整性_Golangg
- Win11无法识别耳机怎么办_解决Win11插耳机
- php增删改查报错1054怎么办_字段名错误排查修
- Win11怎么关闭粘滞键_彻底禁用Windows
- Python lxml的etree和Element
- Windows10如何更改开机密码_Win10登录
- Win10怎么限制单程序CPU占用上限_Win10
- c++怎么使用std::tuple存储多元组数据_
- Windows10系统怎么查看IP地址_Win10
- Mac版Final Cut Pro入门_Mac视频
- Windows10怎么查看硬件信息_Windows
- Win11怎么关闭防火墙通知_屏蔽Win11安全中
- Win11如何设置文件权限 Win11 NTFS文
- php做exe支持多线程吗_并发处理实现方式【详解
- c++如何连接Redis c++ hiredis库
- 一文详解网站被黑客入侵挂马解决办法
- Windows7如何安装系统镜像_Windows7
- php中::能访问全局变量吗_全局作用域与类作用域
- 如何使用Golang构建基础消息队列模拟_Gola
- Python脚本参数接收_sys与argparse

ain_loader:
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data.view(data.size(0), -1))
loss = nn.functional.cross_entropy(output, target)
loss.backward()
optimizer.step()
dist.destroy_process_group()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--world_size", type=int, default=torch.cuda.device_count())
args = parser.parse_args()
# 注意:Kaggle 中需显式设置环境变量(spawn 自动读取)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
mp.spawn(main, args=(args.world_size, 5, 32, 1e-3), nprocs=args.world_size, join=True)
QQ客服