# torch 中保证可复现性（控制随机性）的最佳实践

PyTorch 训练中的随机性主要有两个来源：

• 随机数据：比如 torch.random 生成的随机数，在随机初始化张量、随机采样、随机数据增强等过程中会用到。这种随机性可以通过设置随机种子来控制。
• 随机算法：一些算法本身具有随机性（多次运行结果不完全一致），可以通过启用确定性算法（torch.use_deterministic_algorithms(True)）来回避。注意，一些算法没有确定性实现，强行启用 use_deterministic_algorithms 后再调用它们会抛出 RuntimeError。

## 野生方案

# Seed every RNG the training script uses.
# NOTE(review): `seed` is assumed to be an int defined earlier by the reader.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Seed all visible GPUs rather than only the current device, matching the
# Lightning implementation (torch.cuda.manual_seed_all).
torch.cuda.manual_seed_all(seed)
# Use the fully qualified path: a bare `cudnn` name only exists if the script
# did `from torch.backends import cudnn`, which this snippet never does.
torch.backends.cudnn.deterministic = True
# Benchmark mode autotunes kernels per run, which can change results.
torch.backends.cudnn.benchmark = False

def worker_init_fn(worker_id):
    """Seed NumPy's and Python's global RNGs inside a DataLoader worker.

    ``worker_id`` is required by the DataLoader callback signature but unused
    here: the worker's torch seed (``torch.initial_seed()``) is already
    per-worker, so it is reused as the source of the seed.
    """
    # np.random.seed only accepts values in [0, 2**32), so keep the low 32 bits.
    seed_32 = torch.initial_seed() & 0xFFFF_FFFF
    np.random.seed(seed_32)
    random.seed(seed_32)

# A dedicated generator pins the base seed the DataLoader draws for its workers,
# so unrelated draws from the global torch RNG cannot shift the workers' seeds.
# NOTE(review): `seed` is assumed to be the same int used for the global seeding above.
g = torch.Generator()
g.manual_seed(seed)
ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn, generator=g)

# https://github.com/pytorch/pytorch/pull/56488#issuecomment-825128350
def worker_init_fn(worker_id):
    """Derive a 32-bit seed from this worker's torch seed and apply it to NumPy, then ``random``."""
    # np.random.seed rejects values >= 2**32, so only the remainder is used.
    _, low_bits = divmod(torch.initial_seed(), 2**32)
    for reseed in (np.random.seed, random.seed):
        reseed(low_bits)

# Same loader as before, spelled with keyword arguments (batch_size is DataLoader's 2nd parameter).
ds = DataLoader(
    ds,
    batch_size=10,
    shuffle=False,
    num_workers=4,
    worker_init_fn=worker_init_fn,
)

# https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
def worker_init_fn(id):
    """Re-seed NumPy in a DataLoader worker from a 128-bit SeedSequence state.

    Unlike the ``% 2**32`` variants, this keeps every bit of the base seed and
    mixes in the worker index through ``np.random.SeedSequence``.
    """
    worker_index = id  # `id` name kept for signature compatibility (shadows the builtin)
    # Undo the per-worker offset to recover the full-width base seed.
    recovered_base = torch.initial_seed() - worker_index
    seq = np.random.SeedSequence([worker_index, recovered_base])
    # Four 32-bit words (128 bits) of state; more would be overkill.
    np.random.seed(seq.generate_state(4))

# Loader wired to the SeedSequence-based worker_init_fn defined above.
ds = DataLoader(
    ds,
    batch_size=10,
    shuffle=False,
    num_workers=4,
    worker_init_fn=worker_init_fn,
)

## Lightning 的方案

from lightning.pytorch import Trainer, seed_everything

# One call seeds python.random, numpy and torch; workers=True additionally asks
# Lightning to give dataloaders without a worker_init_fn a seeding one.
seed_everything(seed=42, workers=True)
model = Model()

# Request deterministic training from the Trainer.
trainer = Trainer(deterministic=True)

# https://github.com/Lightning-AI/lightning/blob/017262e5e0c65215e9e75121d155d7a07cd9e7bf/src/lightning/fabric/utilities/seed.py#L19
def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
    """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random. In addition,
    sets the following environment variables:

    - ``PL_GLOBAL_SEED``: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
    - ``PL_SEED_WORKERS``: (optional) is set to 1 if ``workers=True``.

    Args:
        seed: the integer value seed for global random state in Lightning.
            If ``None``, will read seed from ``PL_GLOBAL_SEED`` env variable
            or select it randomly.
        workers: if set to ``True``, will properly configure all dataloaders passed to the
            Trainer with a ``worker_init_fn``. If the user already provides such a function
            for their dataloaders, setting this argument will have no influence. See also:
            :func:`~lightning.fabric.utilities.seed.pl_worker_init_function`.

    Returns:
        The seed that was actually applied. It can differ from the argument when
        ``seed`` is ``None``, the env value is invalid, or the value is out of bounds.
    """
    # Step 1: resolve the seed (explicit argument > PL_GLOBAL_SEED env var > random choice).
    if seed is None:
        env_seed = os.environ.get("PL_GLOBAL_SEED")
        if env_seed is None:
            seed = _select_seed_randomly(min_seed_value, max_seed_value)
            rank_zero_warn(f"No seed found, seed set to {seed}")
        else:
            try:
                seed = int(env_seed)
            except ValueError:
                # Unparseable env value: fall back to a random seed and warn.
                seed = _select_seed_randomly(min_seed_value, max_seed_value)
                rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}")
    elif not isinstance(seed, int):
        seed = int(seed)

    # Out-of-range seeds are replaced by a random in-range one.
    # NOTE(review): the warning is emitted before the replacement is drawn, so it
    # does not report the seed that is actually used (the info log below does).
    if not (min_seed_value <= seed <= max_seed_value):
        rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
        seed = _select_seed_randomly(min_seed_value, max_seed_value)

    # Step 2: apply the resolved seed everywhere.
    log.info(rank_prefixed_message(f"Global seed set to {seed}", _get_rank()))
    # 1) Export the global seed so spawned processes pick up the same value.
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    # 2) Python's stdlib `random`.
    random.seed(seed)
    # 3) NumPy's legacy global RNG.
    np.random.seed(seed)
    # 4) torch CPU RNG.
    torch.manual_seed(seed)
    # 5) torch CUDA RNGs on all devices.
    torch.cuda.manual_seed_all(seed)
    # 6) Flag read by _auto_add_worker_init_fn to decide whether to attach
    #    pl_worker_init_function to DataLoaders.
    os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"

    return seed

# https://github.com/Lightning-AI/lightning/blob/017262e5e0c65215e9e75121d155d7a07cd9e7bf/src/lightning/fabric/utilities/data.py#L247C15-L247C29
def _auto_add_worker_init_fn(dataloader: object, rank: int) -> None:
if not hasattr(dataloader, "worker_init_fn"):
return
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)

# https://github.com/Lightning-AI/lightning/blob/017262e5e0c65215e9e75121d155d7a07cd9e7bf/src/lightning/fabric/utilities/seed.py#L81
def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:  # pragma: no cover
    """Seed NumPy, PyTorch and ``random`` inside a DataLoader worker process.

    This is the ``worker_init_fn`` Lightning adds to dataloaders when the seed was
    set with ``seed_everything(seed, workers=True)``. See the PyTorch notes on
    `randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_.

    Args:
        worker_id: index of the worker within its DataLoader.
        rank: global process rank; ``rank_zero_only.rank`` is used when ``None``.
    """
    # Design notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
    global_rank = rank_zero_only.rank if rank is None else rank
    # The worker's torch seed is base_seed + worker_id; subtracting recovers the
    # full-width base seed instead of a truncated 32-bit one.
    base_seed = torch.initial_seed() - worker_id
    log.debug(
        f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}"
    )
    root_seq = np.random.SeedSequence([base_seed, worker_id, global_rank])
    # NumPy's legacy global RNG receives 128 bits (4 x 32-bit words) of state.
    np.random.seed(root_seq.generate_state(4))
    # Two independent child sequences: the first feeds torch, the second stdlib `random`.
    torch_child, stdlib_child = root_seq.spawn(2)
    torch.manual_seed(int(torch_child.generate_state(1, dtype=np.uint64)[0]))
    # Pack two 64-bit words into a single 128-bit Python int for random.seed().
    high, low = (int(word) for word in stdlib_child.generate_state(2, dtype=np.uint64))
    random.seed((high << 64) | low)

下面的 _set_torch_flags 在 deterministic=True 时主要做了三件事：

• torch.backends.cudnn.benchmark = False
• torch.use_deterministic_algorithms(True)
• os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
def _set_torch_flags(
*, deterministic: Optional[Union[bool, _LITERAL_WARN]] = None, benchmark: Optional[bool] = None
) -> None:
if deterministic:
if benchmark is None:
# Set benchmark to False to ensure determinism
benchmark = False
elif benchmark:
rank_zero_warn(
"You passed deterministic=True and benchmark=True. Note that PyTorch ignores"
" torch.backends.cudnn.deterministic=True when torch.backends.cudnn.benchmark=True.",
)
if benchmark is not None:
torch.backends.cudnn.benchmark = benchmark

if deterministic == "warn":
torch.use_deterministic_algorithms(True, warn_only=True)
elif isinstance(deterministic, bool):
# do not call this if deterministic wasn't passed
torch.use_deterministic_algorithms(deterministic)
if deterministic:
# https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"