|
|
@ -64,36 +64,51 @@ removedNodes = [] |
|
|
|
|
|
|
|
|
|
|
|
# Detect unfused batch normalization nodes and fuse them. |
|
|
|
# Detect unfused batch normalization nodes and fuse them. |
|
|
|
def fuse_batch_normalization(): |
|
|
|
def fuse_batch_normalization(): |
|
|
|
pattern = ['Add', 'Rsqrt', 'Mul', 'Mul', 'Mul', 'Sub', 'Add'] |
|
|
|
# Add_0 <-- moving_variance, add_y |
|
|
|
candidates = [] |
|
|
|
# Rsqrt <-- Add_0 |
|
|
|
|
|
|
|
# Mul_0 <-- Rsqrt, gamma |
|
|
|
for node in graph_def.node: |
|
|
|
# Mul_1 <-- input, Mul_0 |
|
|
|
if node.op == pattern[len(candidates)]: |
|
|
|
# Mul_2 <-- moving_mean, Mul_0 |
|
|
|
candidates.append(node) |
|
|
|
# Sub_0 <-- beta, Mul_2 |
|
|
|
|
|
|
|
# Add_1 <-- Mul_1, Sub_0 |
|
|
|
|
|
|
|
nodesMap = {node.name: node for node in graph_def.node} |
|
|
|
|
|
|
|
subgraph = ['Add', |
|
|
|
|
|
|
|
['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']], |
|
|
|
|
|
|
|
['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]] |
|
|
|
|
|
|
|
def checkSubgraph(node, targetNode, inputs, fusedNodes): |
|
|
|
|
|
|
|
op = targetNode[0] |
|
|
|
|
|
|
|
if node.op == op and (len(node.input) >= len(targetNode) - 1): |
|
|
|
|
|
|
|
fusedNodes.append(node) |
|
|
|
|
|
|
|
for i, inpOp in enumerate(targetNode[1:]): |
|
|
|
|
|
|
|
if isinstance(inpOp, list): |
|
|
|
|
|
|
|
if not node.input[i] in nodesMap or \ |
|
|
|
|
|
|
|
not checkSubgraph(nodesMap[node.input[i]], inpOp, inputs, fusedNodes): |
|
|
|
|
|
|
|
return False |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
inputs[inpOp] = node.input[i] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return True |
|
|
|
else: |
|
|
|
else: |
|
|
|
candidates = [] |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
if len(candidates) == len(pattern): |
|
|
|
|
|
|
|
inp = candidates[3].input[0] |
|
|
|
|
|
|
|
gamma = candidates[2].input[1] |
|
|
|
|
|
|
|
beta = candidates[5].input[0] |
|
|
|
|
|
|
|
moving_mean = candidates[4].input[0] |
|
|
|
|
|
|
|
moving_variance = candidates[0].input[0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nodesToRemove = [] |
|
|
|
|
|
|
|
for node in graph_def.node: |
|
|
|
|
|
|
|
inputs = {} |
|
|
|
|
|
|
|
fusedNodes = [] |
|
|
|
|
|
|
|
if checkSubgraph(node, subgraph, inputs, fusedNodes): |
|
|
|
name = node.name |
|
|
|
name = node.name |
|
|
|
node.Clear() |
|
|
|
node.Clear() |
|
|
|
node.name = name |
|
|
|
node.name = name |
|
|
|
node.op = 'FusedBatchNorm' |
|
|
|
node.op = 'FusedBatchNorm' |
|
|
|
node.input.append(inp) |
|
|
|
node.input.append(inputs['input']) |
|
|
|
node.input.append(gamma) |
|
|
|
node.input.append(inputs['gamma']) |
|
|
|
node.input.append(beta) |
|
|
|
node.input.append(inputs['beta']) |
|
|
|
node.input.append(moving_mean) |
|
|
|
node.input.append(inputs['moving_mean']) |
|
|
|
node.input.append(moving_variance) |
|
|
|
node.input.append(inputs['moving_variance']) |
|
|
|
text_format.Merge('f: 0.001', node.attr["epsilon"]) |
|
|
|
text_format.Merge('f: 0.001', node.attr["epsilon"]) |
|
|
|
|
|
|
|
nodesToRemove += fusedNodes[1:] |
|
|
|
for candidate in candidates[:-1]: |
|
|
|
for node in nodesToRemove: |
|
|
|
graph_def.node.remove(candidate) |
|
|
|
graph_def.node.remove(node) |
|
|
|
candidates = [] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fuse_batch_normalization() |
|
|
|
fuse_batch_normalization() |
|
|
|
|
|
|
|
|
|
|
|