Skip to content

Commit

Permalink
Merge branch 'master' into loadams/cpu-runner-debug
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Jan 17, 2025
2 parents 800f5de + f97f088 commit 989bcd6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
21 changes: 21 additions & 0 deletions deepspeed/module_inject/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,13 @@ def forward(self, input):
output += self.bias
return output

def extra_repr(self):
out_features, in_features = self.weight.shape if self.weight is not None else (None, None)
dtype = self.weight.dtype if self.weight is not None else None
extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
in_features, out_features, self.bias is not None, dtype)
return extra_repr_str


class LmHeadLinearAllreduce(nn.Module):

Expand Down Expand Up @@ -120,6 +127,13 @@ def forward(self, input):
output += self.bias
return output

def extra_repr(self):
out_features, in_features = self.weight.shape if self.weight is not None else (None, None)
dtype = self.weight.dtype if self.weight is not None else None
extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
in_features, out_features, self.bias is not None, dtype)
return extra_repr_str


class LinearLayer(nn.Module):

Expand All @@ -144,6 +158,13 @@ def forward(self, input):
output += self.bias
return output

def extra_repr(self):
out_features, in_features = self.weight.shape
dtype = self.weight.dtype
extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
in_features, out_features, self.bias is not None, dtype)
return extra_repr_str


class Normalize(nn.Module):

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/alexnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def cast_to_half(x):

def cifar_trainset(fp16=False):
torchvision = pytest.importorskip("torchvision", minversion="0.5.0")
import torchvision.transforms as transforms
from torchvision import transforms

transform_list = [
transforms.ToTensor(),
Expand Down

0 comments on commit 989bcd6

Please sign in to comment.