mirror of https://github.com/opencv/opencv.git
Open Source Computer Vision Library
https://opencv.org/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
62 lines
2.3 KiB
62 lines
2.3 KiB
# This file is part of OpenCV project. |
|
# It is subject to the license terms in the LICENSE file found in the top-level directory |
|
# of this distribution and at http://opencv.org/license.html. |
|
# |
|
# Copyright (C) 2017, Intel Corporation, all rights reserved. |
|
# Third party copyrights are property of their respective owners. |
|
import tensorflow as tf |
|
import struct |
|
import argparse |
|
import numpy as np |
|
|
|
# Command-line interface:
#   --input   frozen TensorFlow graph (binary GraphDef) to read
#   --output  destination path for the converted graph
#   --ops     op types whose weight inputs are converted to fp16
parser = argparse.ArgumentParser(
    description='Convert weights of a frozen TensorFlow graph to fp16.'
)
parser.add_argument(
    '--input',
    help='Path to frozen graph.',
    required=True,
)
parser.add_argument(
    '--output',
    help='Path to output graph.',
    required=True,
)
parser.add_argument(
    '--ops',
    nargs='+',
    default=['Conv2D', 'MatMul'],
    help='List of ops which weights are converted.',
)
args = parser.parse_args()
|
|
|
# Numeric values of TensorFlow's DataType enum as stored in the `type` /
# `dtype` fields of a GraphDef: 1 is 32-bit float, 19 is 16-bit float.
DT_FLOAT, DT_HALF = 1, 19
|
|
|
# In frozen graphs, every node that uses weights is connected to a Const node
# through an Identity node, which is usually named after the Const with a
# '/read' suffix. We replace each of those Identity nodes with a Cast node.
|
|
|
# Load the frozen model. A GraphDef is serialized as a binary protobuf, so the
# file must be opened in binary mode ('rb'); in text mode Python 3 would try
# to decode the raw bytes as text and ParseFromString would not get bytes.
with tf.gfile.FastGFile(args.input, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
|
|
|
# Set of names of all nodes that feed one of the selected ops (args.ops).
# A set keeps the membership tests in the following passes O(1).
# NodeDef inputs may be written as "name:idx" (a specific output) or
# "^name" (a control dependency). Control dependencies carry no weights and
# are skipped; the output index is dropped so the bare node name matches.
inputs = set()
for node in graph_def.node:
    if node.op in args.ops:
        inputs.update(
            ref.split(':')[0] for ref in node.input if not ref.startswith('^')
        )
|
|
|
# Names of the fp32 Const nodes whose values will be converted to fp16.
weightsNodes = set()

# Only an Identity that reads directly from a float Const can be turned into a
# half->float Cast; any other source would still produce fp32 and the Cast
# would then have a mismatched input type.
float_consts = {
    n.name for n in graph_def.node
    if n.op == 'Const' and n.attr['dtype'].type == DT_FLOAT
}

for node in graph_def.node:
    # From all collected inputs keep only the float Identity nodes.
    if node.op != 'Identity' or node.name not in inputs:
        continue
    if node.attr['T'].type != DT_FLOAT:
        continue

    # input[0] is the Identity's data input (control inputs come after data
    # inputs); drop any ":idx" suffix to get the node name.
    source = node.input[0].split(':')[0]
    if source not in float_consts:
        continue
    weightsNodes.add(source)

    # Replace Identity with Cast: it will read the fp16 Const and hand fp32
    # back to its consumers, so nothing downstream needs to change.
    # NOTE(review): other consumers of the same Const that do not go through
    # a converted Identity would also see fp16 data after conversion.
    node.op = 'Cast'
    node.attr['DstT'].type = DT_FLOAT
    node.attr['SrcT'].type = DT_HALF
    del node.attr['T']
    # '_class' is optional on Identity nodes; deleting a missing key from a
    # protobuf map raises KeyError, so only delete it when present.
    if '_class' in node.attr:
        del node.attr['_class']
|
|
|
# Convert the selected Const weights from fp32 to fp16 in place.
for node in graph_def.node:
    if node.name not in weightsNodes:
        continue

    tensor = node.attr['value'].tensor
    node.attr['dtype'].type = DT_HALF
    tensor.dtype = DT_HALF

    if tensor.tensor_content:
        # Packed raw fp32 values (native byte order, as the previous
        # struct-based code assumed) -> packed fp16 values. Integer sizing is
        # handled by numpy, avoiding the Python 3 float-division bug of
        # `'f' * (len(content) / 4)`.
        floats = np.frombuffer(tensor.tensor_content, dtype=np.float32)
        tensor.tensor_content = floats.astype(np.float16).tobytes()
    elif tensor.float_val:
        # Small constants keep their values in the repeated `float_val`
        # field instead. For DT_HALF tensors, TensorProto stores values in
        # `half_val` as the raw 16-bit patterns widened to int32.
        halfs = np.asarray(tensor.float_val, dtype=np.float32).astype(np.float16)
        tensor.half_val.extend(halfs.view(np.uint16).tolist())
        tensor.ClearField('float_val')
|
|
|
# Serialize the converted graph as a binary protobuf at the --output path
# (an empty logdir means the path is used as given).
tf.train.write_graph(graph_def, logdir="", name=args.output, as_text=False)
|
|
|