[BUG] Mismatch of model parameters when using Sequence Parallel #6868

Open
chetwin-character opened this issue Dec 13, 2024 · 0 comments
Labels: bug, training


chetwin-character commented Dec 13, 2024

Describe the bug
We are running a very simple script with sequence parallelism and no data parallelism. After a single training step, we expect the model parameters to be identical on every rank of the sequence-parallel group, but we see a mismatch across ranks.

To Reproduce
Run CUDA_VISIBLE_DEVICES=0,1,2,3 deepspeed train_sp.py --config-file=config.yaml --deepspeed

train_sp.py is

import os  # needed by the (commented-out) parameter-saving block in main()
import random
import deepspeed
from omegaconf import OmegaConf
import torch
import torch.nn as nn

import torch.distributed as dist
import argparse
from sys import argv

def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-file",
                        type=str,
                        required=True)
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    
    return parser


def get_arguments(args=argv[1:]):
    parser = get_argument_parser()
    # Include DeepSpeed configuration arguments
    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args(args)

    # no cuda mode is not supported
    args.no_cuda = False

    return args

def seed_everything(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

def initialize_sequence_parallelism(sequence_parallel_size, world_size, world_rank):
    assert int(world_size) % sequence_parallel_size == 0
    sequence_parallel_num_groups = int(world_size) // sequence_parallel_size
    global _SEQUENCE_PARALLEL_GROUP
    for i in range(sequence_parallel_num_groups):
        ranks = range(i * sequence_parallel_size,
                    (i + 1) * sequence_parallel_size)
        group = torch.distributed.new_group(ranks)
        if int(world_rank) in ranks:
            _SEQUENCE_PARALLEL_GROUP = group


def get_sequence_parallel_group():
    """Get the sequence parallel group the caller rank belongs to."""
    return _SEQUENCE_PARALLEL_GROUP

class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(640, 16)
    
    def forward(self, x):
        return self.linear(x)


def main(config, args):
    deepspeed.init_distributed(dist_backend='nccl')

    seed_everything(12345)

    # Set rank/device info for DeepSpeed (assumes at most 8 GPUs per node)
    device = dist.get_rank() % 8
    world_size = dist.get_world_size()
    global_rank = dist.get_rank()
    local_rank = device

    args.local_rank = local_rank
    args.device = device

    # set up stuff related to SP
    # Convert to a plain nested dict so nested sections (e.g. optimizer) are not DictConfig objects
    ds_config = OmegaConf.to_container(config.ds_config, resolve=True)
    sequence_parallel_size = ds_config.get('sequence_parallel_size', 1)
    using_sequence_parallel = sequence_parallel_size > 1
    if using_sequence_parallel:
        initialize_sequence_parallelism(sequence_parallel_size, world_size=world_size, world_rank=global_rank)
        ds_config['data_parallel_size'] = world_size // sequence_parallel_size

    print(f"Running distributed training with world_size: {world_size}, this is rank {global_rank}, sequence_parallel: {using_sequence_parallel}")

    model = SimpleMLP()

    model.to(device)

    # optimizer (deepspeed style)
    param_optimizer = list(model.named_parameters())
    optimizer_grouped_parameters = [{'params': [p for n, p in param_optimizer]}]

    model, optimizer, _, _ = deepspeed.initialize(
        model=model, model_parameters=optimizer_grouped_parameters, args=args, config = ds_config
    )

    x = torch.randn((12000, 640)).to(device)
    local_x = torch.chunk(x, sequence_parallel_size, dim=0)[local_rank]
    local_out = local_x[:, :16]
    out = model(local_x)

    ## dummy MAE loss
    loss = (local_out - out).abs().mean()

    model.backward(loss)

    model.step()
    

    """
    save model params from each local gpu for comparison
    os.makedirs("saved_model", exist_ok=True)

    param_dict = {}
    for name, param in model.named_parameters():
        param_dict[name] = param.data

    torch.save(param_dict, f"saved_model/model_{local_rank}.pt")
    """

if __name__ == "__main__":
    args = get_arguments()
    config = OmegaConf.load(args.config_file)
    main(config=config,args=args)

And config.yaml is

ds_config:
  train_micro_batch_size_per_gpu: 4
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  optimizer:
    type: AdamW
    params:
      lr: 1e-4
      betas: 
        - 0.9
        - 0.999
      eps: 1e-8
      weight_decay: 0.01
  sequence_parallel_size: 4
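For reference, the rank layout that `initialize_sequence_parallelism` builds for this run, and the slice of `x` each rank trains on, can be worked out without launching anything. The pure-Python sketch below mirrors the group loop and the `torch.chunk(x, sequence_parallel_size, dim=0)[local_rank]` slicing from `train_sp.py`. The helper names `sequence_parallel_groups` and `local_row_range` are hypothetical (not part of DeepSpeed or PyTorch), and the row arithmetic assumes `torch.chunk`'s documented rule that each chunk holds `ceil(n / chunks)` rows, with the last chunk possibly smaller.

```python
import math


def sequence_parallel_groups(world_size: int, sequence_parallel_size: int) -> list[list[int]]:
    """Hypothetical helper mirroring the loop in initialize_sequence_parallelism."""
    assert world_size % sequence_parallel_size == 0
    num_groups = world_size // sequence_parallel_size
    return [
        list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size))
        for i in range(num_groups)
    ]


def local_row_range(num_rows: int, chunks: int, rank: int) -> tuple[int, int]:
    """Hypothetical helper: rows [start, end) selected by torch.chunk(x, chunks, dim=0)[rank]."""
    chunk_size = math.ceil(num_rows / chunks)  # torch.chunk uses ceil-sized chunks
    start = rank * chunk_size
    if start >= num_rows:
        # torch.chunk returns fewer chunks in this case, so indexing would fail
        raise IndexError(f"torch.chunk would return fewer than {rank + 1} chunks")
    return start, min(start + chunk_size, num_rows)


if __name__ == "__main__":
    world_size, sp_size = 4, 4  # CUDA_VISIBLE_DEVICES=0,1,2,3 and sequence_parallel_size: 4
    print(sequence_parallel_groups(world_size, sp_size))
    print(world_size // sp_size)  # value the script stores as data_parallel_size
    for rank in range(sp_size):
        print(rank, local_row_range(12000, sp_size, rank))
```

With the values in this report, all four ranks form a single sequence-parallel group (the script sets `data_parallel_size` to 4 // 4 = 1), and each rank runs its forward and backward pass on a disjoint 3000-row slice of the same 12000-row `x`. Since the slices differ, so do the local gradients, so the parameters would only stay identical after `model.step()` if those gradients are reduced across the group.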

Expected behavior
We expect the model parameters at the end of the full training script to be identical across all ranks in the sequence-parallel group. However, some of the parameter values differ.

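Here are the parameters of the model on rank 0 and rank 1. They are not identical: several entries differ by about 2e-4 (for example, 0.0148 vs 0.0150 in the first row of module.linear.weight).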

In [6]: one
Out[6]: 
{'module.linear.weight': tensor([[ 0.0382,  0.0299,  0.0388,  ...,  0.0148,  0.0271, -0.0111],
         [-0.0137, -0.0059,  0.0370,  ...,  0.0140, -0.0097,  0.0391],
         [-0.0125, -0.0030, -0.0279,  ...,  0.0323, -0.0291,  0.0200],
         ...,
         [ 0.0282,  0.0249, -0.0326,  ..., -0.0072, -0.0136, -0.0192],
         [-0.0138,  0.0360,  0.0177,  ...,  0.0012,  0.0037, -0.0271],
         [-0.0019,  0.0249,  0.0101,  ..., -0.0148, -0.0341, -0.0256]]),
 'module.linear.bias': tensor([-0.0369, -0.0284, -0.0305,  0.0252, -0.0189,  0.0334, -0.0197,  0.0271,
         -0.0387, -0.0345, -0.0239, -0.0095, -0.0123,  0.0021, -0.0123, -0.0251])}

In [7]: two
Out[7]: 
{'module.linear.weight': tensor([[ 0.0382,  0.0299,  0.0388,  ...,  0.0150,  0.0271, -0.0111],
         [-0.0135, -0.0059,  0.0370,  ...,  0.0140, -0.0099,  0.0391],
         [-0.0125, -0.0030, -0.0279,  ...,  0.0323, -0.0291,  0.0200],
         ...,
         [ 0.0284,  0.0247, -0.0326,  ..., -0.0072, -0.0134, -0.0192],
         [-0.0138,  0.0360,  0.0177,  ...,  0.0014,  0.0037, -0.0271],
         [-0.0021,  0.0249,  0.0103,  ..., -0.0146, -0.0341, -0.0256]]),
 'module.linear.bias': tensor([-0.0371, -0.0286, -0.0303,  0.0254, -0.0189,  0.0334, -0.0195,  0.0271,
         -0.0387, -0.0345, -0.0239, -0.0095, -0.0123,  0.0019, -0.0123, -0.0249])}
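To quantify the mismatch instead of eyeballing the truncated printouts, the per-rank files written by the commented-out block in `train_sp.py` (`saved_model/model_{local_rank}.pt`) can be compared entry by entry. This is a minimal sketch, assuming those files exist and contain a `name -> tensor` dict as saved there; `load_rank_params` and `max_abs_diff` are hypothetical helpers, and NumPy is used so the comparison itself needs no GPU.

```python
import numpy as np


def load_rank_params(path: str) -> dict:
    """Hypothetical helper: load one rank's saved parameters as NumPy arrays."""
    import torch  # imported lazily; only needed to read the .pt files

    state = torch.load(path, map_location="cpu")  # tensors were saved from GPU
    return {name: t.detach().float().numpy() for name, t in state.items()}


def max_abs_diff(a: dict, b: dict) -> dict:
    """Hypothetical helper: per-parameter max absolute difference of two name -> array dicts."""
    if a.keys() != b.keys():
        raise KeyError(f"parameter names differ: {sorted(a.keys() ^ b.keys())}")
    return {
        name: float(np.max(np.abs(np.asarray(a[name]) - np.asarray(b[name]))))
        for name in a
    }
```

For example, `max_abs_diff(load_rank_params("saved_model/model_0.pt"), load_rank_params("saved_model/model_1.pt"))` returns one number per parameter name; a value of 0.0 for every name would mean the two ranks agree exactly.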

ds_report output

[2024-12-13 22:31:48,877] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Warning: The cache directory for DeepSpeed Triton autotune, /home/chetwinlow/.triton/autotune, appears to be on an NFS system. While this is generally acceptable, if you experience slowdowns or hanging when DeepSpeed exits, it is recommended to set the TRITON_CACHE_DIR environment variable to a non-NFS path.
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
 [WARNING]  FP Quantizer is using an untested triton version (3.1.0), only 2.3.(0, 1) and 3.0.0 are known to be compatible with these kernels
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.5
 [WARNING]  using untested triton version (3.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/chetwinlow/.local/lib/python3.12/site-packages/torch']
torch version .................... 2.5.1+cu124
deepspeed install path ........... ['/home/chetwinlow/.local/lib/python3.12/site-packages/deepspeed']
deepspeed info ................... 0.16.0, unknown, unknown
torch cuda version ............... 12.4
torch hip version ................ None
nvcc version ..................... 12.4
deepspeed wheel compiled w. ...... torch 0.0, cuda 0.0
shared memory (/dev/shm) size .... 1.76 TB

System info (please complete the following information):

  • OS: Ubuntu 22.04.5 LTS
  • GPU count and types: 1 machine with 8x H100
  • Python version: 3.12.4

Launcher context
deepspeed launcher
