diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml index 639f27498dd9..0547431e3099 100644 --- a/.github/workflows/nv-a6000.yml +++ b/.github/workflows/nv-a6000.yml @@ -40,8 +40,9 @@ jobs: python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - name: Install transformers run: | - git clone --depth=1 https://github.com/huggingface/transformers + git clone https://github.com/huggingface/transformers cd transformers + git checkout v4.47.1 git rev-parse --short HEAD python -m pip install . - name: Install deepspeed diff --git a/.github/workflows/nv-ds-chat.yml b/.github/workflows/nv-ds-chat.yml index 329a1060f5eb..7e209cbe4397 100644 --- a/.github/workflows/nv-ds-chat.yml +++ b/.github/workflows/nv-ds-chat.yml @@ -43,7 +43,7 @@ jobs: - name: Install deepspeed run: | - pip install transformers==4.45.2 + pip install transformers pip install .[dev] ds_report diff --git a/SECURITY.md b/SECURITY.md index 9e9391ee0bac..3061748e610b 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -39,3 +39,7 @@ We prefer all communications to be in English. Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). + +--- + +Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models. diff --git a/accelerator/real_accelerator.py b/accelerator/real_accelerator.py index ced9218d7aca..eb4e17850882 100644 --- a/accelerator/real_accelerator.py +++ b/accelerator/real_accelerator.py @@ -178,7 +178,7 @@ def get_accelerator(): if accelerator_name is None: # borrow this log from PR#5084 if accel_logger is not None: - accel_logger.warn( + accel_logger.warning( "Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.") # cpu added as catch-all when accelerator detection fails accelerator_name = "cpu" diff --git a/blogs/windows/08-2024/README.md b/blogs/windows/08-2024/README.md index 34e11bd47792..8a23372a1d64 100644 --- a/blogs/windows/08-2024/README.md +++ b/blogs/windows/08-2024/README.md @@ -48,7 +48,7 @@ Regardless of the installation choice, you can check that the installation was s We use an image classification model, CIFAR10, and a language model, BERT, to demonstrate pretraining on Windows with DeepSpeed. ## Pretraining CIFAR10 -The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py –deepspeed`. The final output should look something like this: +The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py --deepspeed`. The final output should look something like this:
diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py index 42ffebbc4386..6df61f7c8841 100644 --- a/deepspeed/inference/config.py +++ b/deepspeed/inference/config.py @@ -174,6 +174,15 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): values for :any:`DeepSpeedMoEConfig`. """ + keep_module_on_host: bool = False + """ + When loading checkpoints to model parameters, they are moved to the device. In very large models + this might fill the device and cause OOM. Setting this flag to true, will keep checkpoints on + host and not move them directly to the device (giving an option to quantize checkpoint data before + moving it to the device for example). + Set only for models with injection policies and auto TP. + """ + quant: QuantizationConfig = {} """ NOTE: only works for int8 dtype. diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 2a2e4665c310..837b8ae0dffb 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -79,7 +79,6 @@ def __init__(self, model, config): self.mp_group = config.tensor_parallel.tp_group self.mpu = config.tensor_parallel.mpu - #self._validate_args(self.mpu, config.replace_with_kernel_inject) self.quantize_merge_count = 1 self.quantization_scales = None @@ -169,7 +168,7 @@ def __init__(self, model, config): is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta' if is_meta_device: self.module.to_empty(device=device) - else: + elif not config.keep_module_on_host: self.module.to(device) if config.tensor_parallel.tp_size > 1: @@ -299,29 +298,6 @@ def _init_quantization_setting(self, quantization_setting): f"mlp_extra_grouping = {self.mlp_extra_grouping}, " f"quantize_groups = {self.quantize_groups}", [0]) - # TODO: remove this function and add this functionality to pydantic config checking - def _validate_args(self, mpu, replace_with_kernel_inject): - # TODO: to support SD pipeline we need to avoid this check for now - if replace_with_kernel_inject and not isinstance(self.module, Module): - raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}") - if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1: - raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}") - - if mpu: - methods = ["get_model_parallel_group", "get_data_parallel_group"] - for method in methods: - if not hasattr(mpu, method): - raise ValueError(f"mpu is missing {method}") - if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)): - raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}") - - supported_dtypes = [None, torch.half, torch.int8, torch.float, torch.bfloat16] - if self._config.dtype not in supported_dtypes: - raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}") - - if self.injection_dict is not None and not isinstance(self.injection_dict, dict): - raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}") - def load_model_with_checkpoint(self, r_module): self.mp_replace = ReplaceWithTensorSlicing( mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 05b9a8555ff9..d148c26968b3 100755 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -19,14 +19,14 @@ from deepspeed.module_inject.layers import is_autotp_training_mode -def move(tensor, device): +def move(tensor, device, copy=True): if tensor.is_meta: return torch.empty_like(tensor, device=device) else: # Using new tensors help in freeing memory (after split for example) was done before by calling clone(). # Using copy=True instead of clone() will help in case of cpu --> cpu. # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced. - return tensor.to(device, copy=True) + return tensor.to(device, copy=copy) class ReplaceWithTensorSlicing: @@ -136,7 +136,8 @@ def is_load_module(module): load_layer_names = [ "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear", "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding", - "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm" + "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm", + "DeepseekV2RMSNorm", "DeepseekV2YarnRotaryEmbedding", "MoEGate" ] return module.__class__ in load_layers or module._get_name() in load_layer_names @@ -190,7 +191,14 @@ def load(module, state_dict, prefix, mp_group=None): class AutoTP(): - def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl): + def __init__(self, + module, + all_reduce_linears, + prefix, + state_dict, + linear_layer_setting, + orig_layer_impl, + keep_module_on_host=False): self.module = module self.all_reduce_linears = all_reduce_linears self.prefix = prefix @@ -202,6 +210,7 @@ def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_ self.orig_layer_impl = orig_layer_impl self.linear_policies = None self.conv_linear_layer = False + self.keep_module_on_host = keep_module_on_host def in_module_list(module, module_list): for item in module_list: @@ -340,11 +349,15 @@ def _replace(self, child, name, conv_linear_layer): # and avoid any complex shard-related logic. if getattr(child, "replaced", False) == True: return + device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name() + # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some + # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy. + return_new_copy = not self.keep_module_on_host weight_shape = child.weight.shape mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group) - # For mixtral-7x8b, need to skip MoE gate linear replace. - if name == "block_sparse_moe.gate" or (('mlp.shared_expert_gate' == name or 'mlp.gate' == name) - and 'qwen2_moe' in str(type(self.module))): + # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip + if "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or ( + ('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))): return child # For Yuan model if 'Yuan' in str(self.module): @@ -361,7 +374,11 @@ def _replace(self, child, name, conv_linear_layer): arctic_w2_all_reduce_linear = False if 'Arctic' in str(self.module) and 'w2' in name: arctic_w2_all_reduce_linear = True - if name in self.all_reduce_linears or arctic_w2_all_reduce_linear: + # For MoE MLP model, e.g., deepseek and jamba + down_proj = False + if 'down_proj' in name: + down_proj = True + if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj: setattr(child, "replaced", True) if self.conv_linear_layer: diff --git a/deepspeed/module_inject/containers/bloom.py b/deepspeed/module_inject/containers/bloom.py index a78ac8120346..7a9b9ca2065b 100644 --- a/deepspeed/module_inject/containers/bloom.py +++ b/deepspeed/module_inject/containers/bloom.py @@ -19,6 +19,18 @@ class DS_BloomContainer(MetaTensorContainer, HybridEngineContainer, BaseTransformerContainer): def __init__(self, **kwargs): + # Check transformers version, error if > 4.43.4 (breaks at 4.44.0) + from importlib.metadata import version + v_transformers = version('transformers') + vers = v_transformers.split('.') + major = int(vers[0]) + minor = int(vers[1]) + if major > 4 or (major == 4 and minor > 43): + import sys + sys.exit( + f"Transformers version {v_transformers} exceeds version 4.43.4! After transformers version 4.43.4, BLOOM inference with DeepSpeed is no longer supported." + ) + super().__init__(**kwargs) # All model specific things should be defined here instead of the base class. diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py index 95c9c31ca01c..4782c197592a 100644 --- a/deepspeed/module_inject/layers.py +++ b/deepspeed/module_inject/layers.py @@ -226,6 +226,14 @@ def __deepcopy__(self, memo): memo[id(self)] = new_obj return new_obj + def extra_repr(self): + if self.weight is not None: + out_features, in_features = self.weight.shape if self.weight is not None else (None, None) + dtype = self.weight.dtype if self.weight is not None else None + extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format( + in_features, out_features, self.bias is not None, dtype) + return extra_repr_str + class GatherReplacedLayerParams: """ @@ -681,7 +689,7 @@ def __init__(self, weight_shape=None, weight=None, bias=None): self.offset = 2 super().__init__(weight_shape, weight=weight) - def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): + def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, position_ids: int = 0): """`input_ids_shape` is expected to be [bsz x seqlen].""" attention_mask = attention_mask.long() diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 0f3349b32256..e8928059767f 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -268,7 +268,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): #mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group) # 1. Create AutoTP object - _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl) + _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl, + config.keep_module_on_host) # 2. Set the tensor parallelism config _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group) @@ -342,13 +343,11 @@ def set_lm_head(module): module.lm_head, "weight") and module.lm_head.weight.is_meta: module.lm_head.weight = embedding_weight # enable tensor parallel for the last linear - if hasattr(module, "lm_head") and hasattr(module.lm_head, - "weight") and not module.lm_head.weight.is_meta and isinstance( - module.lm_head, torch.nn.Linear): + if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and isinstance( + module.lm_head, torch.nn.Linear): module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head") - elif hasattr(module, "embed_out") and hasattr(module.embed_out, - "weight") and not module.embed_out.weight.is_meta and isinstance( - module.embed_out, torch.nn.Linear): + elif hasattr(module, "embed_out") and hasattr(module.embed_out, "weight") and isinstance( + module.embed_out, torch.nn.Linear): module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out") elif hasattr(module, "language_model") and hasattr(module.language_model, "lm_head"): module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head") @@ -389,7 +388,6 @@ def conv2d_parallel_shard_weights(model, rank, world_size): checkpoint=checkpoint_file) pbar.update(1) gc.collect() - replaced_module = set_lm_head(replaced_module) # conv2d tp module replace # Now is for yuan model. Add model list and conv policy to decide whether to replace conv. if 'Yuan' in str(replaced_module): @@ -399,6 +397,9 @@ def conv2d_parallel_shard_weights(model, rank, world_size): orig_class=orig_layer_impl, replace_fn=replace_fn, _replace_policy=config.injection_policy_tuple) + # AutoTP default set lm_head tp + if not config.replace_with_kernel_inject: + replaced_module = set_lm_head(replaced_module) quantizer = GroupQuantizer(q_int8=quantize) world_size = dist.get_world_size() if dist.is_initialized() else 1 diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py index 3e6fc2b63ef1..ded262edcf61 100644 --- a/deepspeed/module_inject/tp_shard.py +++ b/deepspeed/module_inject/tp_shard.py @@ -42,11 +42,16 @@ def get_num_attention_heads(): def get_shard_size(total_size, mp_size, name=None, rank=None): global num_kv_heads last_linear = ["lm_head", "embed_out"] + # MoE MLP layer use near even division will get better perf. + moe_mlp_layer = ["gate_proj", "up_proj", "down_proj", "w1", "w2", "w3"] + not_moe_mlp_layer = True + if name != None and any(s in str(name) for s in moe_mlp_layer): + not_moe_mlp_layer = False # When we have num_kv_heads defined, uneven division is possible, otherwise enforce near even division if rank == None: rank = dist.get_rank() if num_kv_heads != None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str( - name) not in last_linear: + name) not in last_linear and not_moe_mlp_layer: my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0) return total_size * my_slices // num_kv_heads else: diff --git a/deepspeed/ops/transformer/inference/triton/matmul_ext.py b/deepspeed/ops/transformer/inference/triton/matmul_ext.py index 412c8740a216..9be4b0098c37 100644 --- a/deepspeed/ops/transformer/inference/triton/matmul_ext.py +++ b/deepspeed/ops/transformer/inference/triton/matmul_ext.py @@ -19,6 +19,9 @@ # ----------------------------------------------------------------------------- # util class/functions for triton def is_nfs_path(path): + if os.name == 'nt': + return False + # Normalize the path to get the absolute path path = os.path.abspath(path) @@ -99,7 +102,7 @@ def put(self, table): with FileLock(self.lock_path): with open(self.file_path + ".tmp", 'wb') as handle: pickle.dump(table, handle) - os.rename(self.file_path + ".tmp", self.file_path) + os.replace(self.file_path + ".tmp", self.file_path) def load(self): if os.path.exists(self.file_path): diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py index b8df7499450d..d2c54155da89 100644 --- a/deepspeed/runtime/base_optimizer.py +++ b/deepspeed/runtime/base_optimizer.py @@ -28,7 +28,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) if self.mpu is None: - logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.") + logger.warning("MPU is not provided, setting tp size to 1 in checkpoint loading.") tp_world_size = 1 else: tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \ diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 612a5b237aee..437e1caf4df3 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3190,7 +3190,7 @@ def _get_all_zero_checkpoints(self, load_dir, tag): if bf16_mode is not self.bfloat16_enabled(): checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16 engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16 - logger.warn(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine') + logger.warning(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine') return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names) return None @@ -3346,7 +3346,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa local_expert_id = None if not m: - logger.warn(f'No expert found in key {key}.') + logger.warning(f'No expert found in key {key}.') else: local_expert_id = m.group(1) diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py index 899358e2c5ef..2ffd0bf9f036 100755 --- a/deepspeed/runtime/lr_schedules.py +++ b/deepspeed/runtime/lr_schedules.py @@ -508,7 +508,7 @@ def _initialize_lr(self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, l def _initialize_momentum(self, optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, last_batch_iteration): if 'betas' not in optimizer.defaults: optimizer_name = type(optimizer).__name__ - logger.warn( + logger.warning( f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults" ) self.cycle_momentum = False diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index ecb2a527f870..0508766f8896 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -614,7 +614,7 @@ def _configure_moe_settings(self): assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE" # NOTE: To run ZeRO stage 1 with MoE, we need to set self.contiguous_gradients to True or ignore the assertion if not self.partition_gradients and not self.contiguous_gradients: - logger.warn( + logger.warning( "ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental.") assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE" diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index 4fab768ce63c..4fa2cc988a19 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -47,7 +47,7 @@ def _update_out_and_lse( block_out = block_out.to(torch.float32) block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) - new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + new_lse = lse + torch.log1p(torch.exp(block_lse - lse)) out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out diff --git a/op_builder/hpu/builder.py b/op_builder/hpu/builder.py index c176a586ba49..11e710a8ee48 100644 --- a/op_builder/hpu/builder.py +++ b/op_builder/hpu/builder.py @@ -32,9 +32,6 @@ def builder(self): def cxx_args(self): args = ['-O3', '-g', '-Wno-reorder'] - CPU_ARCH = self.cpu_arch() - SIMD_WIDTH = self.simd_width() - args += [CPU_ARCH, '-fopenmp', SIMD_WIDTH] return args def libraries_args(self): diff --git a/setup.py b/setup.py index c0452f867b31..cc5eb4a3500c 100755 --- a/setup.py +++ b/setup.py @@ -321,9 +321,9 @@ def op_enabled(op_name): include_package_data=True, scripts=scripts, classifiers=[ - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10' + 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12' ], license='Apache Software License 2.0', ext_modules=ext_modules, diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 9b563523dbeb..df85ed232a2e 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -554,6 +554,7 @@ def test(self, model_w_task, injection_policy, query, inf_kwargs, assert_fn, dty @pytest.mark.seq_inference +@pytest.mark.parametrize('keep_module_on_host', [True, False]) @pytest.mark.parametrize( "model_w_task", [("Helsinki-NLP/opus-mt-en-de", "translation"), ("Salesforce/codegen-350M-mono", "text-generation")], @@ -570,6 +571,7 @@ def test( inf_kwargs, assert_fn, dtype, + keep_module_on_host, ): invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False) if invalid_test_msg: @@ -592,13 +594,20 @@ def test( framework="pt") bs_output = pipe(query, **inf_kwargs) - pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype) + pipe.model = deepspeed.init_inference(pipe.model, + mp_size=world_size, + dtype=dtype, + keep_module_on_host=keep_module_on_host) ds_output = pipe(query, **inf_kwargs) print(local_rank, "baseline", bs_output) print(local_rank, "deepspeed", ds_output) assert assert_fn(bs_output, ds_output) + if keep_module_on_host: + for name, param in model.named_parameters(): + assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu" + @pytest.mark.world_size(3) def test_odd_world_size( self, @@ -607,6 +616,7 @@ def test_odd_world_size( inf_kwargs, assert_fn, dtype, + keep_module_on_host, ): invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False) if invalid_test_msg: @@ -624,13 +634,20 @@ def test_odd_world_size( framework="pt") bs_output = pipe(query, **inf_kwargs) - pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype) + pipe.model = deepspeed.init_inference(pipe.model, + mp_size=world_size, + dtype=dtype, + keep_module_on_host=keep_module_on_host) ds_output = pipe(query, **inf_kwargs) print(local_rank, "baseline", bs_output) print(local_rank, "deepspeed", ds_output) assert assert_fn(bs_output, ds_output) + if keep_module_on_host: + for name, param in model.named_parameters(): + assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu" + @pytest.mark.nightly @pytest.mark.parametrize( diff --git a/tests/unit/ops/transformer/inference/inference_test_utils.py b/tests/unit/ops/transformer/inference/inference_test_utils.py index 9cfcae809f09..d63c51267e51 100644 --- a/tests/unit/ops/transformer/inference/inference_test_utils.py +++ b/tests/unit/ops/transformer/inference/inference_test_utils.py @@ -3,6 +3,8 @@ # DeepSpeed Team +from typing import Tuple + import torch from deepspeed.accelerator import get_accelerator @@ -23,38 +25,22 @@ def get_tolerances(): DTYPES = None -def get_dtypes(): +def get_dtypes(include_float=True): global DTYPES if DTYPES is None: - DTYPES = get_accelerator().supported_dtypes() + DTYPES = [torch.float16, torch.float32] if include_float else [torch.float16] + try: + if get_accelerator().is_bf16_supported(): + DTYPES.append(torch.bfloat16) + except (AssertionError, AttributeError): + pass return DTYPES -def allclose(x, y): +def allclose(x, y, tolerances: Tuple[int, int] = None): assert x.dtype == y.dtype - rtol, atol = get_tolerances()[x.dtype] + if tolerances is None: + rtol, atol = get_tolerances()[x.dtype] + else: + rtol, atol = tolerances return torch.allclose(x, y, rtol=rtol, atol=atol) - - -def assert_almost_equal(x, y, decimal=2, err_msg=''): - import numpy.testing as npt - if isinstance(x, torch.Tensor): - if x.dtype == torch.bfloat16: - x = x.float() - x = x.cpu().detach().numpy() - if isinstance(y, torch.Tensor): - if y.dtype == torch.bfloat16: - y = y.float() - y = y.cpu().detach().numpy() - npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal) - - -def max_diff(a, b): - a = a.to(torch.float32).flatten() - b = b.to(torch.float32).flatten() - diff = torch.abs(a - b) - max_diff_indices = torch.argsort(diff)[-1] - print("Max difference indices:", max_diff_indices) - print("Max difference values:", diff[max_diff_indices]) - print(f"{a[max_diff_indices]} vs {b[max_diff_indices]}") - return max_diff_indices diff --git a/tests/unit/ops/transformer/inference/test_attention.py b/tests/unit/ops/transformer/inference/test_attention.py index ecf681542ff6..cae201d747a3 100644 --- a/tests/unit/ops/transformer/inference/test_attention.py +++ b/tests/unit/ops/transformer/inference/test_attention.py @@ -7,7 +7,7 @@ import torch import deepspeed from deepspeed.accelerator import get_accelerator -from .inference_test_utils import assert_almost_equal +from .inference_test_utils import allclose # reference timplementation @@ -88,4 +88,4 @@ def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float use_triton_flash=False, use_ds_attention=False) tri_out = tri_out.reshape((BATCH, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3) - assert_almost_equal(ref_out, tri_out) + assert (allclose(ref_out, tri_out)) diff --git a/tests/unit/ops/transformer/inference/test_layer_norm.py b/tests/unit/ops/transformer/inference/test_layer_norm.py index 7711daf0d887..4a84add16046 100644 --- a/tests/unit/ops/transformer/inference/test_layer_norm.py +++ b/tests/unit/ops/transformer/inference/test_layer_norm.py @@ -9,7 +9,7 @@ from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp -from .inference_test_utils import allclose, get_dtypes, assert_almost_equal +from .inference_test_utils import allclose, get_dtypes try: import triton # noqa: F401 # type: ignore from deepspeed.ops.transformer.inference.triton import ( @@ -188,4 +188,4 @@ def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device=' y_ref = torch.nn.functional.layer_norm(x + res + (x_bias if input_bias else 0), w_shape, weight, bias, eps).to(dtype) # compare - assert_almost_equal(y_tri, y_ref) + assert (allclose(y_tri, y_ref))