
Commit

minor refactor loss scaler (#261)
chunyang-wen authored Jun 12, 2020
1 parent f502550 commit 96c4daa
Showing 1 changed file with 28 additions and 31 deletions.
59 changes: 28 additions & 31 deletions deepspeed/pt/loss_scaler.py
@@ -31,7 +31,29 @@ def to_python_float(t):
         return t[0]


-class LossScaler:
+class LossScalerBase:
+    """LossScalarBase
+    Base class for a loss scaler
+    """
+    def __init__(self, cur_scale):
+        self.cur_scale = cur_scale
+
+    @property
+    def loss_scale(self):
+        return self.cur_scale
+
+    def scale_gradient(self, module, grad_in, grad_out):
+        return tuple(self.loss_scale * g for g in grad_in)
+
+    def update_scale(self, overflow):
+        pass
+
+    def backward(self, loss, retain_graph=False):
+        scaled_loss = loss * self.loss_scale
+        scaled_loss.backward(retain_graph=retain_graph)
+
+
+class LossScaler(LossScalerBase):
     """
     Class that manages a static loss scale. This class is intended to interact with
     :class:`FP16_Optimizer`, and should not be directly manipulated by the user.
@@ -43,7 +65,7 @@ class LossScaler:
         scale (float, optional, default=1.0): The loss scale.
     """
     def __init__(self, scale=1):
-        self.cur_scale = scale
+        super(LossScaler, self).__init__(scale)

     # `params` is a list / generator of torch.Variable
     def has_overflow(self, params):
@@ -53,22 +75,8 @@ def has_overflow(self, params):
     def _has_inf_or_nan(x):
         return False

-    def update_scale(self, overflow):
-        pass
-
-    @property
-    def loss_scale(self):
-        return self.cur_scale
-
-    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
-
-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss * self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
-

-class DynamicLossScaler:
+class DynamicLossScaler(LossScalerBase):
     """
     Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
     indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
@@ -100,7 +108,7 @@ def __init__(self,
                  min_scale=1,
                  delayed_shift=1,
                  consecutive_hysteresis=False):
-        self.cur_scale = init_scale
+        super(DynamicLossScaler, self).__init__(init_scale)
         self.cur_iter = 0
         self.last_overflow_iter = -1
         self.scale_factor = scale_factor
@@ -113,7 +121,7 @@ def __init__(self,
     # `params` is a list / generator of torch.Variable
     def has_overflow_serial(self, params):
         for p in params:
-            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
+            if p.grad is not None and self._has_inf_or_nan(p.grad.data):
                 return True

         return False
@@ -135,7 +143,7 @@ def _has_inf_or_nan(x):
                 raise
             return True
         else:
-            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
+            if cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum:
                 return True
             return False

@@ -157,17 +165,6 @@ def update_scale(self, overflow):
                 self.cur_scale *= self.scale_factor
         self.cur_iter += 1
-
-    @property
-    def loss_scale(self):
-        return self.cur_scale
-
-    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
-
-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss * self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)


##############################################################
# Example usage below here -- assuming it's in a separate file
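
The example-usage section that this comment introduces is collapsed in the diff above. As a rough, illustrative sketch (not code from this commit) of how the refactored classes could be driven directly, the snippet below uses a placeholder model and optimizer and exercises only the API shown in the diff: backward() scales the loss before backpropagation, has_overflow_serial() checks gradients for inf/NaN, update_scale() adjusts the dynamic scale, and the loss_scale property is used to unscale gradients before stepping (a step DeepSpeed's FP16 optimizer would normally handle internally).

import torch

from deepspeed.pt.loss_scaler import DynamicLossScaler

# Toy model/optimizer; real FP16 master-weight handling is out of scope here.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = DynamicLossScaler(init_scale=2**16)

for step in range(10):
    inputs = torch.randn(8, 4)
    loss = model(inputs).pow(2).mean()

    optimizer.zero_grad()
    # LossScalerBase.backward: multiplies the loss by loss_scale, then backprops.
    scaler.backward(loss)

    # DynamicLossScaler.has_overflow_serial: True if any gradient has inf/NaN.
    overflow = scaler.has_overflow_serial(model.parameters())
    # Shrinks the scale on overflow, grows it after scale_window clean steps.
    scaler.update_scale(overflow)

    if not overflow:
        # Gradients were computed on the scaled loss, so unscale before stepping
        # (assumed direct-use step; DeepSpeed's FP16 optimizer does this for you).
        for p in model.parameters():
            if p.grad is not None:
                p.grad.data.div_(scaler.loss_scale)
        optimizer.step()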

