[NPU] Add NaiveSyncBatchNorm for NPU #13183

Open · wants to merge 1 commit into base: release/2.6.1
115 changes: 56 additions & 59 deletions ppocr/modeling/architectures/__init__.py
@@ -20,8 +20,9 @@

from .base_model import BaseModel
from .distillation_model import DistillationModel
from .custom_device_layers import NaiveSyncBatchNorm

__all__ = ["build_model", "apply_to_static"]
__all__ = ["build_model", "apply_to_static", "NaiveSyncBatchNorm"]


def build_model(config):
@@ -38,81 +39,77 @@ def build_model(config):
def apply_to_static(model, config, logger):
if config["Global"].get("to_static", False) is not True:
return model
assert "d2s_train_image_shape" in config[
"Global"], "d2s_train_image_shape must be assigned for static training mode..."
supported_list = [
"DB", "SVTR_LCNet", "TableMaster", "LayoutXLM", "SLANet", "SVTR"
]
assert (
"d2s_train_image_shape" in config["Global"]
), "d2s_train_image_shape must be assigned for static training mode..."
supported_list = ["DB", "SVTR_LCNet", "TableMaster", "LayoutXLM", "SLANet", "SVTR"]
if config["Architecture"]["algorithm"] in ["Distillation"]:
algo = list(config["Architecture"]["Models"].values())[0]["algorithm"]
else:
algo = config["Architecture"]["algorithm"]
assert algo in supported_list, f"algorithms that supports static training must in in {supported_list} but got {algo}"
assert (
algo in supported_list
), f"algorithms that supports static training must in in {supported_list} but got {algo}"

specs = [
InputSpec(
[None] + config["Global"]["d2s_train_image_shape"], dtype='float32')
InputSpec([None] + config["Global"]["d2s_train_image_shape"], dtype="float32")
]

if algo == "SVTR_LCNet":
specs.append([
InputSpec(
[None, config["Global"]["max_text_length"]],
dtype='int64'), InputSpec(
[None, config["Global"]["max_text_length"]], dtype='int64'),
InputSpec(
[None], dtype='int64'), InputSpec(
[None], dtype='float64')
])
specs.append(
[
InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
InputSpec([None], dtype="int64"),
InputSpec([None], dtype="float64"),
]
)
elif algo == "TableMaster":
specs.append(
[
InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
InputSpec(
[None, config["Global"]["max_text_length"]], dtype='int64'),
InputSpec(
[None, config["Global"]["max_text_length"], 4],
dtype='float32'),
[None, config["Global"]["max_text_length"], 4], dtype="float32"
),
InputSpec(
[None, config["Global"]["max_text_length"], 1],
dtype='float32'),
InputSpec(
[None, 6], dtype='float32'),
])
[None, config["Global"]["max_text_length"], 1], dtype="float32"
),
InputSpec([None, 6], dtype="float32"),
]
)
elif algo == "LayoutXLM":
specs = [[
InputSpec(
shape=[None, 512], dtype="int64"), # input_ids
InputSpec(
shape=[None, 512, 4], dtype="int64"), # bbox
InputSpec(
shape=[None, 512], dtype="int64"), # attention_mask
InputSpec(
shape=[None, 512], dtype="int64"), # token_type_ids
InputSpec(
shape=[None, 3, 224, 224], dtype="float32"), # image
InputSpec(
shape=[None, 512], dtype="int64"), # label
]]
specs = [
[
InputSpec(shape=[None, 512], dtype="int64"), # input_ids
InputSpec(shape=[None, 512, 4], dtype="int64"), # bbox
InputSpec(shape=[None, 512], dtype="int64"), # attention_mask
InputSpec(shape=[None, 512], dtype="int64"), # token_type_ids
InputSpec(shape=[None, 3, 224, 224], dtype="float32"), # image
InputSpec(shape=[None, 512], dtype="int64"), # label
]
]
elif algo == "SLANet":
specs.append([
InputSpec(
[None, config["Global"]["max_text_length"] + 2], dtype='int64'),
InputSpec(
[None, config["Global"]["max_text_length"] + 2, 4],
dtype='float32'),
InputSpec(
[None, config["Global"]["max_text_length"] + 2, 1],
dtype='float32'),
InputSpec(
[None, 6], dtype='float64'),
])
specs.append(
[
InputSpec(
[None, config["Global"]["max_text_length"] + 2], dtype="int64"
),
InputSpec(
[None, config["Global"]["max_text_length"] + 2, 4], dtype="float32"
),
InputSpec(
[None, config["Global"]["max_text_length"] + 2, 1], dtype="float32"
),
InputSpec([None, 6], dtype="float64"),
]
)
elif algo == "SVTR":
specs.append([
InputSpec(
[None, config["Global"]["max_text_length"]], dtype='int64'),
InputSpec(
[None], dtype='int64')
])
specs.append(
[
InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
InputSpec([None], dtype="int64"),
]
)
model = to_static(model, input_spec=specs)
logger.info("Successfully to apply @to_static with specs: {}".format(specs))
return model
125 changes: 125 additions & 0 deletions ppocr/modeling/architectures/custom_device_layers.py
@@ -0,0 +1,125 @@
import paddle
import paddle.nn as nn
import paddle.distributed as dist

__all__ = ["NaiveSyncBatchNorm"]


class _AllReduce(paddle.autograd.PyLayer):

@staticmethod
def forward(ctx, input):
input_list = [paddle.zeros_like(input) for k in range(dist.get_world_size())]
        # Use all_gather instead of all_reduce to avoid relying on in-place collective updates.
dist.all_gather(input_list, input, sync_op=True)
inputs = paddle.stack(input_list, axis=0)
return paddle.sum(inputs, axis=0)

@staticmethod
def backward(ctx, grad_output):
dist.all_reduce(grad_output, sync_op=True)
return grad_output


def differentiable_all_reduce(input):
"""
Differentiable counterpart of `dist.all_reduce`.
"""
if (
not dist.is_available()
or not dist.is_initialized()
or dist.get_world_size() == 1
):
return input
return _AllReduce.apply(input)


class NaiveSyncBatchNorm(nn.BatchNorm2D):

def __init__(self, *args, stats_mode="", **kwargs):
super().__init__(*args, **kwargs)
assert stats_mode in ["", "N"]
self._stats_mode = stats_mode

def forward(self, input):
if dist.get_world_size() == 1 or not self.training:
return super().forward(input)

B, C = input.shape[0], input.shape[1]

mean = paddle.mean(input, axis=[0, 2, 3])
meansqr = paddle.mean(input * input, axis=[0, 2, 3])

if self._stats_mode == "":
assert (
B > 0
), 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
vec = paddle.concat([mean, meansqr], axis=0)
vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())
mean, meansqr = paddle.split(vec, [C, C])
            momentum = (
                1 - self._momentum
            )  # NOTE: Paddle defines BatchNorm momentum with the opposite convention
else:
if B == 0:
vec = paddle.zeros([2 * C + 1], dtype=mean.dtype)
vec = vec + input.sum() # make sure there is gradient w.r.t input
else:
vec = paddle.concat(
[
mean,
meansqr,
paddle.ones([1], dtype=mean.dtype),
],
axis=0,
)
vec = differentiable_all_reduce(vec * B)

total_batch = vec[-1].detach()
momentum = total_batch.clip(max=1) * (
1 - self._momentum
) # no update if total_batch is 0
mean, meansqr, _ = paddle.split(
vec / total_batch.clip(min=1), [C, C, int(vec.shape[0] - 2 * C)]
) # avoid div-by-zero

var = meansqr - mean * mean
invstd = paddle.rsqrt(var + self._epsilon)
scale = self.weight * invstd
bias = self.bias - mean * scale
scale = scale.reshape([1, -1, 1, 1])
bias = bias.reshape([1, -1, 1, 1])

tmp_mean = self._mean + momentum * (mean.detach() - self._mean)
self._mean.set_value(tmp_mean)
tmp_variance = self._variance + (momentum * (var.detach() - self._variance))
self._variance.set_value(tmp_variance)
ret = input * scale + bias
return ret

@classmethod
def convert_sync_batchnorm(cls, layer):
layer_output = layer
if isinstance(layer, nn.BatchNorm2D):

layer_output = NaiveSyncBatchNorm(
layer._num_features,
layer._momentum,
layer._epsilon,
layer._weight_attr,
layer._bias_attr,
layer._data_format,
layer._name,
)

if layer._weight_attr is not False and layer._bias_attr is not False:
with paddle.no_grad():
layer_output.weight = layer.weight
layer_output.bias = layer.bias
layer_output._mean = layer._mean
layer_output._variance = layer._variance

for name, sublayer in layer.named_children():
layer_output.add_sublayer(name, cls.convert_sync_batchnorm(sublayer))
del layer
return layer_output
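As a sanity check on the statistics aggregation above, here is a minimal single-process sketch (not part of this PR; it simulates two cards of equal batch size by splitting one tensor). It shows the identity the layer relies on: averaging the per-card E[x] and E[x^2] and taking var = E[x^2] - E[x]^2 reproduces the global batch statistics that the all-reduce path computes across real devices.

# Single-process sketch of NaiveSyncBatchNorm's statistic aggregation
# (illustrative only; assumes equal per-card batch sizes).
import paddle

x = paddle.rand([8, 4, 16, 16])        # "global" batch, C = 4
chunks = paddle.split(x, 2, axis=0)    # pretend these live on 2 cards

means = [c.mean(axis=[0, 2, 3]) for c in chunks]           # per-card E[x]
meansqrs = [(c * c).mean(axis=[0, 2, 3]) for c in chunks]  # per-card E[x^2]

# What differentiable_all_reduce(vec) * (1 / world_size) yields in stats_mode="":
mean = paddle.add_n(means) / len(chunks)
meansqr = paddle.add_n(meansqrs) / len(chunks)
var = meansqr - mean * mean

ref_mean = x.mean(axis=[0, 2, 3])
ref_var = x.var(axis=[0, 2, 3], unbiased=False)
print(paddle.allclose(mean, ref_mean, atol=1e-6).item(),
      paddle.allclose(var, ref_var, atol=1e-6).item())  # expect: True True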
9 changes: 7 additions & 2 deletions tools/train.py
@@ -36,6 +36,7 @@
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import set_seed
from ppocr.modeling.architectures import apply_to_static
from ppocr.modeling.architectures import NaiveSyncBatchNorm
import tools.program as program

dist.get_world_size()
@@ -138,8 +139,12 @@ def main(config, device, logger, vdl_writer):

use_sync_bn = config["Global"].get("use_sync_bn", False)
if use_sync_bn:
        if "npu" in paddle.get_device() and dist.ParallelEnv().nranks > 1:
            model = NaiveSyncBatchNorm.convert_sync_batchnorm(model)
            logger.info("convert_sync_batchnorm for NPU")
        else:
            model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            logger.info("convert_sync_batchnorm")

model = apply_to_static(model, config, logger)

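For reviewers who want to exercise the new branch: it is only reached when Global.use_sync_bn is enabled in the training config and the job runs on more than one NPU rank. Below is a hedged sketch of the equivalent dispatch; the convert_sync_bn helper name is illustrative and not part of this PR.

# Illustrative sketch (not part of this PR) of the dispatch tools/train.py now
# performs when Global.use_sync_bn is true; convert_sync_bn is a hypothetical
# helper name used only for this example.
import paddle
import paddle.distributed as dist

from ppocr.modeling.architectures import NaiveSyncBatchNorm


def convert_sync_bn(model):
    # Multi-card NPU runs use the all_gather/all_reduce based implementation;
    # every other case keeps the stock paddle.nn.SyncBatchNorm conversion.
    if "npu" in paddle.get_device() and dist.ParallelEnv().nranks > 1:
        return NaiveSyncBatchNorm.convert_sync_batchnorm(model)
    return paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)

From the command line, the flag can be switched on with the usual config override, e.g. appending -o Global.use_sync_bn=True to the distributed launch of tools/train.py.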