diff --git a/fairscale/nn/model_parallel/layers.py b/fairscale/nn/model_parallel/layers.py index 05af2c741..1931200de 100644 --- a/fairscale/nn/model_parallel/layers.py +++ b/fairscale/nn/model_parallel/layers.py @@ -120,7 +120,7 @@ def __init__( self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index # Allocate weights. - self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim)) + self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim)) # And initialize. _initialize_affine_weight( self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method @@ -187,7 +187,7 @@ def __init__( self.embedding_dim_per_partition = divide_and_check_no_remainder(self.embedding_dim, world_size) # Allocate weights. - self.weight = Parameter(torch.Tensor(self.num_embeddings, self.embedding_dim_per_partition)) + self.weight = Parameter(torch.empty(self.num_embeddings, self.embedding_dim_per_partition)) # And initialize. _initialize_affine_weight( self.weight, @@ -259,9 +259,9 @@ def __init__( # Parameters. # Note: torch.nn.functional.linear performs XA^T + b and as a result # we allocate the transpose. - self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) + self.weight = Parameter(torch.empty(self.output_size_per_partition, self.in_features)) if bias: - self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) + self.bias = Parameter(torch.empty(self.output_size_per_partition)) # Always initialize bias to zero. with torch.no_grad(): self.bias.zero_() @@ -346,9 +346,9 @@ def __init__( # Parameters. # Note: torch.nn.functional.linear performs XA^T + b and as a result # we allocate the transpose. - self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition)) + self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition)) if bias: - self.bias = Parameter(torch.Tensor(self.out_features)) + self.bias = Parameter(torch.empty(self.out_features)) # Always initialize bias to zero. with torch.no_grad(): self.bias.zero_()