diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml
index 639f27498dd9..0547431e3099 100644
--- a/.github/workflows/nv-a6000.yml
+++ b/.github/workflows/nv-a6000.yml
@@ -40,8 +40,9 @@ jobs:
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Install transformers
run: |
- git clone --depth=1 https://github.com/huggingface/transformers
+ git clone https://github.com/huggingface/transformers
cd transformers
+ git checkout v4.47.1
git rev-parse --short HEAD
python -m pip install .
- name: Install deepspeed
diff --git a/.github/workflows/nv-ds-chat.yml b/.github/workflows/nv-ds-chat.yml
index 329a1060f5eb..7e209cbe4397 100644
--- a/.github/workflows/nv-ds-chat.yml
+++ b/.github/workflows/nv-ds-chat.yml
@@ -43,7 +43,7 @@ jobs:
- name: Install deepspeed
run: |
- pip install transformers==4.45.2
+ pip install transformers
pip install .[dev]
ds_report
diff --git a/SECURITY.md b/SECURITY.md
index 9e9391ee0bac..3061748e610b 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -39,3 +39,7 @@ We prefer all communications to be in English.
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
+
+---
+
+Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models.
diff --git a/accelerator/real_accelerator.py b/accelerator/real_accelerator.py
index ced9218d7aca..eb4e17850882 100644
--- a/accelerator/real_accelerator.py
+++ b/accelerator/real_accelerator.py
@@ -178,7 +178,7 @@ def get_accelerator():
if accelerator_name is None:
# borrow this log from PR#5084
if accel_logger is not None:
- accel_logger.warn(
+ accel_logger.warning(
"Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.")
# cpu added as catch-all when accelerator detection fails
accelerator_name = "cpu"
diff --git a/blogs/windows/08-2024/README.md b/blogs/windows/08-2024/README.md
index 34e11bd47792..8a23372a1d64 100644
--- a/blogs/windows/08-2024/README.md
+++ b/blogs/windows/08-2024/README.md
@@ -48,7 +48,7 @@ Regardless of the installation choice, you can check that the installation was s
We use an image classification model, CIFAR10, and a language model, BERT, to demonstrate pretraining on Windows with DeepSpeed.
## Pretraining CIFAR10
-The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py –deepspeed`. The final output should look something like this:
+The scripts and code required for the CIFAR10 pretraining example are available at the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py --deepspeed`. The final output should look something like this:
diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py
index 42ffebbc4386..6df61f7c8841 100644
--- a/deepspeed/inference/config.py
+++ b/deepspeed/inference/config.py
@@ -174,6 +174,15 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
values for :any:`DeepSpeedMoEConfig`.
"""
+ keep_module_on_host: bool = False
+ """
+ When loading checkpoints into model parameters, they are normally moved to the device. For very large
+ models this can fill the device memory and cause OOM. Setting this flag to true keeps the checkpoints
+ on the host and does not move them to the device (which, for example, allows quantizing the checkpoint
+ data before moving it to the device).
+ Only applies to models with injection policies and auto TP.
+ """
+
quant: QuantizationConfig = {}
"""
NOTE: only works for int8 dtype.
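For context, a minimal sketch of how the new `keep_module_on_host` option is consumed through the public `deepspeed.init_inference` API; the model name and tensor-parallel degree below are placeholders chosen for illustration, not part of this change:

```python
# Sketch: keep sharded checkpoint parameters on the host during auto-TP inference.
# "facebook/opt-1.3b" and mp_size=2 are illustrative placeholders.
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16)

engine = deepspeed.init_inference(
    model,
    mp_size=2,                          # tensor-parallel degree
    dtype=torch.float16,
    replace_with_kernel_inject=False,   # auto-TP path, where the flag applies
    keep_module_on_host=True,           # parameters stay on CPU instead of moving to the device
)
```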
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 2a2e4665c310..837b8ae0dffb 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -79,7 +79,6 @@ def __init__(self, model, config):
self.mp_group = config.tensor_parallel.tp_group
self.mpu = config.tensor_parallel.mpu
- #self._validate_args(self.mpu, config.replace_with_kernel_inject)
self.quantize_merge_count = 1
self.quantization_scales = None
@@ -169,7 +168,7 @@ def __init__(self, model, config):
is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta'
if is_meta_device:
self.module.to_empty(device=device)
- else:
+ elif not config.keep_module_on_host:
self.module.to(device)
if config.tensor_parallel.tp_size > 1:
@@ -299,29 +298,6 @@ def _init_quantization_setting(self, quantization_setting):
f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
f"quantize_groups = {self.quantize_groups}", [0])
- # TODO: remove this function and add this functionality to pydantic config checking
- def _validate_args(self, mpu, replace_with_kernel_inject):
- # TODO: to support SD pipeline we need to avoid this check for now
- if replace_with_kernel_inject and not isinstance(self.module, Module):
- raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")
- if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1:
- raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}")
-
- if mpu:
- methods = ["get_model_parallel_group", "get_data_parallel_group"]
- for method in methods:
- if not hasattr(mpu, method):
- raise ValueError(f"mpu is missing {method}")
- if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
- raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")
-
- supported_dtypes = [None, torch.half, torch.int8, torch.float, torch.bfloat16]
- if self._config.dtype not in supported_dtypes:
- raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
-
- if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
- raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}")
-
def load_model_with_checkpoint(self, r_module):
self.mp_replace = ReplaceWithTensorSlicing(
mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 05b9a8555ff9..d148c26968b3 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -19,14 +19,14 @@
from deepspeed.module_inject.layers import is_autotp_training_mode
-def move(tensor, device):
+def move(tensor, device, copy=True):
if tensor.is_meta:
return torch.empty_like(tensor, device=device)
else:
# Using new tensors help in freeing memory (after split for example) was done before by calling clone().
# Using copy=True instead of clone() will help in case of cpu --> cpu.
# Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
- return tensor.to(device, copy=True)
+ return tensor.to(device, copy=copy)
class ReplaceWithTensorSlicing:
@@ -136,7 +136,8 @@ def is_load_module(module):
load_layer_names = [
"LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
"MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
- "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm"
+ "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm",
+ "DeepseekV2RMSNorm", "DeepseekV2YarnRotaryEmbedding", "MoEGate"
]
return module.__class__ in load_layers or module._get_name() in load_layer_names
@@ -190,7 +191,14 @@ def load(module, state_dict, prefix, mp_group=None):
class AutoTP():
- def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl):
+ def __init__(self,
+ module,
+ all_reduce_linears,
+ prefix,
+ state_dict,
+ linear_layer_setting,
+ orig_layer_impl,
+ keep_module_on_host=False):
self.module = module
self.all_reduce_linears = all_reduce_linears
self.prefix = prefix
@@ -202,6 +210,7 @@ def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_
self.orig_layer_impl = orig_layer_impl
self.linear_policies = None
self.conv_linear_layer = False
+ self.keep_module_on_host = keep_module_on_host
def in_module_list(module, module_list):
for item in module_list:
@@ -340,11 +349,15 @@ def _replace(self, child, name, conv_linear_layer):
# and avoid any complex shard-related logic.
if getattr(child, "replaced", False) == True:
return
+ device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
+ # keep_module_on_host keeps the module on the host. Checkpoints are loaded to the host first (in some
+ # cases they can even be loaded straight from disk to avoid filling host memory), so no new copy is needed.
+ return_new_copy = not self.keep_module_on_host
weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
- # For mixtral-7x8b, need to skip MoE gate linear replace.
- if name == "block_sparse_moe.gate" or (('mlp.shared_expert_gate' == name or 'mlp.gate' == name)
- and 'qwen2_moe' in str(type(self.module))):
+ # Skip TP for certain layers, e.g., MoE gates and the DeepSeek low-rank projections (q_a_proj, kv_a_proj_with_mqa)
+ if "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or (
+ ('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))):
return child
# For Yuan model
if 'Yuan' in str(self.module):
@@ -361,7 +374,11 @@ def _replace(self, child, name, conv_linear_layer):
arctic_w2_all_reduce_linear = False
if 'Arctic' in str(self.module) and 'w2' in name:
arctic_w2_all_reduce_linear = True
- if name in self.all_reduce_linears or arctic_w2_all_reduce_linear:
+ # For MoE MLP models, e.g., DeepSeek and Jamba, treat down_proj as an all-reduce (row-parallel) linear
+ down_proj = False
+ if 'down_proj' in name:
+ down_proj = True
+ if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:
setattr(child, "replaced", True)
if self.conv_linear_layer:
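The `copy` parameter added to `move()` matters because `Tensor.to()` with the default `copy=False` returns the input tensor unchanged when it is already on the target device, so a shard that is a view of the full weight keeps the full tensor alive. A small standalone illustration of the difference (not DeepSpeed code):

```python
import torch

full = torch.randn(4096, 4096)
shard = full[:, :2048]                      # a view; it keeps the full tensor's storage alive

same = shard.to("cpu")                      # already on cpu, copy=False: returns `shard` itself
print(same is shard)                        # True  -> the full storage cannot be freed

fresh = shard.to("cpu", copy=True)          # forces a new allocation holding only the shard
print(fresh.data_ptr() == full.data_ptr())  # False -> `full` can now be garbage collected
```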
diff --git a/deepspeed/module_inject/containers/bloom.py b/deepspeed/module_inject/containers/bloom.py
index a78ac8120346..7a9b9ca2065b 100644
--- a/deepspeed/module_inject/containers/bloom.py
+++ b/deepspeed/module_inject/containers/bloom.py
@@ -19,6 +19,18 @@
class DS_BloomContainer(MetaTensorContainer, HybridEngineContainer, BaseTransformerContainer):
def __init__(self, **kwargs):
+ # Check transformers version, error if > 4.43.4 (breaks at 4.44.0)
+ from importlib.metadata import version
+ v_transformers = version('transformers')
+ vers = v_transformers.split('.')
+ major = int(vers[0])
+ minor = int(vers[1])
+ if major > 4 or (major == 4 and minor > 43):
+ import sys
+ sys.exit(
+ f"Transformers version {v_transformers} exceeds version 4.43.4! After transformers version 4.43.4, BLOOM inference with DeepSpeed is no longer supported."
+ )
+
super().__init__(**kwargs)
# All model specific things should be defined here instead of the base class.
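The manual major/minor parse above works for plain `x.y.z` version strings; a hedged alternative sketch using `packaging.version`, which also handles dev/rc suffixes, could look like this (not part of the patch):

```python
from importlib.metadata import version
from packaging.version import Version

v_transformers = version("transformers")
if Version(v_transformers) > Version("4.43.4"):
    raise RuntimeError(f"transformers {v_transformers} is newer than 4.43.4; "
                       "BLOOM inference with DeepSpeed is not supported past 4.43.4.")
```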
diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index 95c9c31ca01c..4782c197592a 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -226,6 +226,14 @@ def __deepcopy__(self, memo):
memo[id(self)] = new_obj
return new_obj
+ def extra_repr(self):
+ out_features, in_features = self.weight.shape if self.weight is not None else (None, None)
+ dtype = self.weight.dtype if self.weight is not None else None
+ extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
+ in_features, out_features, self.bias is not None, dtype)
+ return extra_repr_str
+
class GatherReplacedLayerParams:
"""
@@ -681,7 +689,7 @@ def __init__(self, weight_shape=None, weight=None, bias=None):
self.offset = 2
super().__init__(weight_shape, weight=weight)
- def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
+ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, position_ids: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()
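`extra_repr` is the standard `torch.nn.Module` hook that `print(module)` folds into the module's repr. A toy example of what the added method surfaces; the class name and shapes below are made up for illustration:

```python
import torch
import torch.nn as nn

class ExampleLinear(nn.Module):          # illustrative stand-in, not a DeepSpeed class
    def __init__(self, in_features, out_features, dtype=torch.float16):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
        self.bias = None

    def extra_repr(self):
        out_features, in_features = self.weight.shape
        return "in_features={}, out_features={}, bias={}, dtype={}".format(
            in_features, out_features, self.bias is not None, self.weight.dtype)

print(ExampleLinear(1024, 4096))
# ExampleLinear(in_features=1024, out_features=4096, bias=False, dtype=torch.float16)
```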
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 0f3349b32256..e8928059767f 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -268,7 +268,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
#mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group)
# 1. Create AutoTP object
- _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl)
+ _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl,
+ config.keep_module_on_host)
# 2. Set the tensor parallelism config
_autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)
@@ -342,13 +343,11 @@ def set_lm_head(module):
module.lm_head, "weight") and module.lm_head.weight.is_meta:
module.lm_head.weight = embedding_weight
# enable tensor parallel for the last linear
- if hasattr(module, "lm_head") and hasattr(module.lm_head,
- "weight") and not module.lm_head.weight.is_meta and isinstance(
- module.lm_head, torch.nn.Linear):
+ if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and isinstance(
+ module.lm_head, torch.nn.Linear):
module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
- elif hasattr(module, "embed_out") and hasattr(module.embed_out,
- "weight") and not module.embed_out.weight.is_meta and isinstance(
- module.embed_out, torch.nn.Linear):
+ elif hasattr(module, "embed_out") and hasattr(module.embed_out, "weight") and isinstance(
+ module.embed_out, torch.nn.Linear):
module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
elif hasattr(module, "language_model") and hasattr(module.language_model, "lm_head"):
module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head")
@@ -389,7 +388,6 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
checkpoint=checkpoint_file)
pbar.update(1)
gc.collect()
- replaced_module = set_lm_head(replaced_module)
# conv2d tp module replace
# Now is for yuan model. Add model list and conv policy to decide whether to replace conv.
if 'Yuan' in str(replaced_module):
@@ -399,6 +397,9 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
orig_class=orig_layer_impl,
replace_fn=replace_fn,
_replace_policy=config.injection_policy_tuple)
+ # By default, AutoTP applies tensor parallelism to lm_head
+ if not config.replace_with_kernel_inject:
+ replaced_module = set_lm_head(replaced_module)
quantizer = GroupQuantizer(q_int8=quantize)
world_size = dist.get_world_size() if dist.is_initialized() else 1
diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py
index 3e6fc2b63ef1..ded262edcf61 100644
--- a/deepspeed/module_inject/tp_shard.py
+++ b/deepspeed/module_inject/tp_shard.py
@@ -42,11 +42,16 @@ def get_num_attention_heads():
def get_shard_size(total_size, mp_size, name=None, rank=None):
global num_kv_heads
last_linear = ["lm_head", "embed_out"]
+ # MoE MLP layers get better performance with near-even division.
+ moe_mlp_layer = ["gate_proj", "up_proj", "down_proj", "w1", "w2", "w3"]
+ not_moe_mlp_layer = True
+ if name is not None and any(s in str(name) for s in moe_mlp_layer):
+ not_moe_mlp_layer = False
# When we have num_kv_heads defined, uneven division is possible, otherwise enforce near even division
if rank == None:
rank = dist.get_rank()
if num_kv_heads != None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str(
- name) not in last_linear:
+ name) not in last_linear and not_moe_mlp_layer:
my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
return total_size * my_slices // num_kv_heads
else:
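A worked numeric illustration of the branch above (numbers chosen for illustration): with `num_kv_heads=8` and `mp_size=3`, a 4096-wide projection is split 1536/1536/1024 so each rank owns a whole number of kv heads, while MoE MLP projections now fall through to the near-even `else` branch instead:

```python
num_kv_heads, mp_size, total_size = 8, 3, 4096   # made-up sizes

for rank in range(mp_size):
    my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
    print(rank, total_size * my_slices // num_kv_heads)   # 1536, 1536, 1024
```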
diff --git a/deepspeed/ops/transformer/inference/triton/matmul_ext.py b/deepspeed/ops/transformer/inference/triton/matmul_ext.py
index 412c8740a216..9be4b0098c37 100644
--- a/deepspeed/ops/transformer/inference/triton/matmul_ext.py
+++ b/deepspeed/ops/transformer/inference/triton/matmul_ext.py
@@ -19,6 +19,9 @@
# -----------------------------------------------------------------------------
# util class/functions for triton
def is_nfs_path(path):
+ if os.name == 'nt':
+ return False
+
# Normalize the path to get the absolute path
path = os.path.abspath(path)
@@ -99,7 +102,7 @@ def put(self, table):
with FileLock(self.lock_path):
with open(self.file_path + ".tmp", 'wb') as handle:
pickle.dump(table, handle)
- os.rename(self.file_path + ".tmp", self.file_path)
+ os.replace(self.file_path + ".tmp", self.file_path)
def load(self):
if os.path.exists(self.file_path):
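`os.rename` raises `FileExistsError` on Windows when the destination already exists, whereas `os.replace` overwrites it atomically on both POSIX and Windows. A generic sketch of the write-to-temp-then-swap pattern the cache relies on (the helper name is illustrative):

```python
import os
import pickle

def atomic_pickle_dump(obj, path):
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as handle:
        pickle.dump(obj, handle)
    os.replace(tmp_path, path)   # atomic overwrite; os.rename would fail on Windows if path exists

atomic_pickle_dump({"kernel": "autotune-table"}, "cache.pkl")
```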
diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py
index b8df7499450d..d2c54155da89 100644
--- a/deepspeed/runtime/base_optimizer.py
+++ b/deepspeed/runtime/base_optimizer.py
@@ -28,7 +28,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
if self.mpu is None:
- logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.")
+ logger.warning("MPU is not provided, setting tp size to 1 in checkpoint loading.")
tp_world_size = 1
else:
tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 612a5b237aee..437e1caf4df3 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -3190,7 +3190,7 @@ def _get_all_zero_checkpoints(self, load_dir, tag):
if bf16_mode is not self.bfloat16_enabled():
checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16
engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16
- logger.warn(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine')
+ logger.warning(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine')
return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names)
return None
@@ -3346,7 +3346,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
local_expert_id = None
if not m:
- logger.warn(f'No expert found in key {key}.')
+ logger.warning(f'No expert found in key {key}.')
else:
local_expert_id = m.group(1)
diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py
index 899358e2c5ef..2ffd0bf9f036 100755
--- a/deepspeed/runtime/lr_schedules.py
+++ b/deepspeed/runtime/lr_schedules.py
@@ -508,7 +508,7 @@ def _initialize_lr(self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, l
def _initialize_momentum(self, optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, last_batch_iteration):
if 'betas' not in optimizer.defaults:
optimizer_name = type(optimizer).__name__
- logger.warn(
+ logger.warning(
f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults"
)
self.cycle_momentum = False
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index ecb2a527f870..0508766f8896 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -614,7 +614,7 @@ def _configure_moe_settings(self):
assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
# NOTE: To run ZeRO stage 1 with MoE, we need to set self.contiguous_gradients to True or ignore the assertion
if not self.partition_gradients and not self.contiguous_gradients:
- logger.warn(
+ logger.warning(
"ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental.")
assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py
index 4fab768ce63c..4fa2cc988a19 100644
--- a/deepspeed/sequence/fpdt_layer.py
+++ b/deepspeed/sequence/fpdt_layer.py
@@ -47,7 +47,7 @@ def _update_out_and_lse(
block_out = block_out.to(torch.float32)
block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
- new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
+ new_lse = lse + torch.log1p(torch.exp(block_lse - lse))
out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
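`torch.log1p(x)` stays accurate for small `x`, where `torch.log(1 + x)` loses the low-order bits to rounding before the log is taken; a two-line check:

```python
import torch

x = torch.tensor(1e-8, dtype=torch.float32)
print(torch.log(1 + x))   # tensor(0.)         -- 1 + 1e-8 rounds to 1.0 in float32
print(torch.log1p(x))     # tensor(1.0000e-08) -- the small value survives
```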
diff --git a/op_builder/hpu/builder.py b/op_builder/hpu/builder.py
index c176a586ba49..11e710a8ee48 100644
--- a/op_builder/hpu/builder.py
+++ b/op_builder/hpu/builder.py
@@ -32,9 +32,6 @@ def builder(self):
def cxx_args(self):
args = ['-O3', '-g', '-Wno-reorder']
- CPU_ARCH = self.cpu_arch()
- SIMD_WIDTH = self.simd_width()
- args += [CPU_ARCH, '-fopenmp', SIMD_WIDTH]
return args
def libraries_args(self):
diff --git a/setup.py b/setup.py
index c0452f867b31..cc5eb4a3500c 100755
--- a/setup.py
+++ b/setup.py
@@ -321,9 +321,9 @@ def op_enabled(op_name):
include_package_data=True,
scripts=scripts,
classifiers=[
- 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9',
- 'Programming Language :: Python :: 3.10'
+ 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11',
+ 'Programming Language :: Python :: 3.12'
],
license='Apache Software License 2.0',
ext_modules=ext_modules,
diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index 9b563523dbeb..df85ed232a2e 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -554,6 +554,7 @@ def test(self, model_w_task, injection_policy, query, inf_kwargs, assert_fn, dty
@pytest.mark.seq_inference
+@pytest.mark.parametrize('keep_module_on_host', [True, False])
@pytest.mark.parametrize(
"model_w_task",
[("Helsinki-NLP/opus-mt-en-de", "translation"), ("Salesforce/codegen-350M-mono", "text-generation")],
@@ -570,6 +571,7 @@ def test(
inf_kwargs,
assert_fn,
dtype,
+ keep_module_on_host,
):
invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
if invalid_test_msg:
@@ -592,13 +594,20 @@ def test(
framework="pt")
bs_output = pipe(query, **inf_kwargs)
- pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
+ pipe.model = deepspeed.init_inference(pipe.model,
+ mp_size=world_size,
+ dtype=dtype,
+ keep_module_on_host=keep_module_on_host)
ds_output = pipe(query, **inf_kwargs)
print(local_rank, "baseline", bs_output)
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)
+ if keep_module_on_host:
+ for name, param in model.named_parameters():
+ assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu"
+
@pytest.mark.world_size(3)
def test_odd_world_size(
self,
@@ -607,6 +616,7 @@ def test_odd_world_size(
inf_kwargs,
assert_fn,
dtype,
+ keep_module_on_host,
):
invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
if invalid_test_msg:
@@ -624,13 +634,20 @@ def test_odd_world_size(
framework="pt")
bs_output = pipe(query, **inf_kwargs)
- pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
+ pipe.model = deepspeed.init_inference(pipe.model,
+ mp_size=world_size,
+ dtype=dtype,
+ keep_module_on_host=keep_module_on_host)
ds_output = pipe(query, **inf_kwargs)
print(local_rank, "baseline", bs_output)
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)
+ if keep_module_on_host:
+ for name, param in model.named_parameters():
+ assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu"
+
@pytest.mark.nightly
@pytest.mark.parametrize(
diff --git a/tests/unit/ops/transformer/inference/inference_test_utils.py b/tests/unit/ops/transformer/inference/inference_test_utils.py
index 9cfcae809f09..d63c51267e51 100644
--- a/tests/unit/ops/transformer/inference/inference_test_utils.py
+++ b/tests/unit/ops/transformer/inference/inference_test_utils.py
@@ -3,6 +3,8 @@
# DeepSpeed Team
+from typing import Tuple
+
import torch
from deepspeed.accelerator import get_accelerator
@@ -23,38 +25,22 @@ def get_tolerances():
DTYPES = None
-def get_dtypes():
+def get_dtypes(include_float=True):
global DTYPES
if DTYPES is None:
- DTYPES = get_accelerator().supported_dtypes()
+ DTYPES = [torch.float16, torch.float32] if include_float else [torch.float16]
+ try:
+ if get_accelerator().is_bf16_supported():
+ DTYPES.append(torch.bfloat16)
+ except (AssertionError, AttributeError):
+ pass
return DTYPES
-def allclose(x, y):
+def allclose(x, y, tolerances: Tuple[float, float] = None):
assert x.dtype == y.dtype
- rtol, atol = get_tolerances()[x.dtype]
+ if tolerances is None:
+ rtol, atol = get_tolerances()[x.dtype]
+ else:
+ rtol, atol = tolerances
return torch.allclose(x, y, rtol=rtol, atol=atol)
-
-
-def assert_almost_equal(x, y, decimal=2, err_msg=''):
- import numpy.testing as npt
- if isinstance(x, torch.Tensor):
- if x.dtype == torch.bfloat16:
- x = x.float()
- x = x.cpu().detach().numpy()
- if isinstance(y, torch.Tensor):
- if y.dtype == torch.bfloat16:
- y = y.float()
- y = y.cpu().detach().numpy()
- npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)
-
-
-def max_diff(a, b):
- a = a.to(torch.float32).flatten()
- b = b.to(torch.float32).flatten()
- diff = torch.abs(a - b)
- max_diff_indices = torch.argsort(diff)[-1]
- print("Max difference indices:", max_diff_indices)
- print("Max difference values:", diff[max_diff_indices])
- print(f"{a[max_diff_indices]} vs {b[max_diff_indices]}")
- return max_diff_indices
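A hedged sketch of how a test might call the refactored helper, either relying on the per-dtype defaults or overriding them with an explicit `(rtol, atol)` pair; the relative import is shown as it would appear in a sibling test module:

```python
import torch
from .inference_test_utils import allclose   # relative import inside the test package

a = torch.randn(4, 8, dtype=torch.float16)
b = a + 1e-3

assert allclose(a, a)                           # uses get_tolerances()[a.dtype]
assert allclose(a, b, tolerances=(1e-2, 1e-2))  # explicit (rtol, atol) override
```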
diff --git a/tests/unit/ops/transformer/inference/test_attention.py b/tests/unit/ops/transformer/inference/test_attention.py
index ecf681542ff6..cae201d747a3 100644
--- a/tests/unit/ops/transformer/inference/test_attention.py
+++ b/tests/unit/ops/transformer/inference/test_attention.py
@@ -7,7 +7,7 @@
import torch
import deepspeed
from deepspeed.accelerator import get_accelerator
-from .inference_test_utils import assert_almost_equal
+from .inference_test_utils import allclose
# reference timplementation
@@ -88,4 +88,4 @@ def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float
use_triton_flash=False,
use_ds_attention=False)
tri_out = tri_out.reshape((BATCH, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3)
- assert_almost_equal(ref_out, tri_out)
+ assert allclose(ref_out, tri_out)
diff --git a/tests/unit/ops/transformer/inference/test_layer_norm.py b/tests/unit/ops/transformer/inference/test_layer_norm.py
index 7711daf0d887..4a84add16046 100644
--- a/tests/unit/ops/transformer/inference/test_layer_norm.py
+++ b/tests/unit/ops/transformer/inference/test_layer_norm.py
@@ -9,7 +9,7 @@
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp
-from .inference_test_utils import allclose, get_dtypes, assert_almost_equal
+from .inference_test_utils import allclose, get_dtypes
try:
import triton # noqa: F401 # type: ignore
from deepspeed.ops.transformer.inference.triton import (
@@ -188,4 +188,4 @@ def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device='
y_ref = torch.nn.functional.layer_norm(x + res + (x_bias if input_bias else 0), w_shape, weight, bias,
eps).to(dtype)
# compare
- assert_almost_equal(y_tri, y_ref)
+ assert allclose(y_tri, y_ref)