|
|
|
@ -72,7 +72,8 @@ TEST(Torch_Importer, simple_read) |
|
|
|
|
importer->populateNet(net); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static void runTorchNet(String prefix, String outLayerName, bool isBinary) |
|
|
|
|
static void runTorchNet(String prefix, String outLayerName = "", |
|
|
|
|
bool check2ndBlob = false, bool isBinary = false) |
|
|
|
|
{ |
|
|
|
|
String suffix = (isBinary) ? ".dat" : ".txt"; |
|
|
|
|
|
|
|
|
@ -92,52 +93,69 @@ static void runTorchNet(String prefix, String outLayerName, bool isBinary) |
|
|
|
|
Blob out = net.getBlob(outLayerName); |
|
|
|
|
|
|
|
|
|
normAssert(outRef, out); |
|
|
|
|
|
|
|
|
|
if (check2ndBlob) |
|
|
|
|
{ |
|
|
|
|
Blob out2 = net.getBlob(outLayerName + ".1"); |
|
|
|
|
Blob ref2 = readTorchBlob(_tf(prefix + "_output_2" + suffix), isBinary); |
|
|
|
|
normAssert(out2, ref2); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST(Torch_Importer, run_convolution) |
|
|
|
|
{ |
|
|
|
|
runTorchNet("net_conv", "l1_Convolution", false); |
|
|
|
|
runTorchNet("net_conv"); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST(Torch_Importer, run_pool_max) |
|
|
|
|
{ |
|
|
|
|
runTorchNet("net_pool_max", "l1_Pooling", false); |
|
|
|
|
runTorchNet("net_pool_max", "", true); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST(Torch_Importer, run_pool_ave) |
|
|
|
|
{ |
|
|
|
|
runTorchNet("net_pool_ave", "l1_Pooling", false); |
|
|
|
|
runTorchNet("net_pool_ave"); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST(Torch_Importer, run_reshape) |
|
|
|
|
{ |
|
|
|
|
runTorchNet("net_reshape", "l1_Reshape", false); |
|
|
|
|
runTorchNet("net_reshape_batch", "l1_Reshape", false); |
|
|
|
|
runTorchNet("net_reshape"); |
|
|
|
|
runTorchNet("net_reshape_batch"); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST(Torch_Importer, run_linear) |
|
|
|
|
{ |
|
|
|
|
runTorchNet("net_linear_2d", "l1_InnerProduct", false); |
|
|
|
|
runTorchNet("net_linear_2d"); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST(Torch_Importer, run_paralel) |
|
|
|
|
{ |
|
|
|
|
runTorchNet("net_parallel", "l2_torchMerge", false); |
|
|
|
|
runTorchNet("net_parallel", "l2_torchMerge"); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST(Torch_Importer, run_concat) |
|
|
|
|
{ |
|
|
|
|
runTorchNet("net_concat", "l2_torchMerge", false); |
|
|
|
|
runTorchNet("net_concat", "l2_torchMerge"); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST(Torch_Importer, run_deconv) |
|
|
|
|
{ |
|
|
|
|
runTorchNet("net_deconv", "", false); |
|
|
|
|
runTorchNet("net_deconv"); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST(Torch_Importer, run_batch_norm) |
|
|
|
|
{ |
|
|
|
|
runTorchNet("net_batch_norm", "", false); |
|
|
|
|
runTorchNet("net_batch_norm"); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST(Torch_Importer, net_prelu) |
|
|
|
|
{ |
|
|
|
|
runTorchNet("net_prelu"); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST(Torch_Importer, net_cadd_table) |
|
|
|
|
{ |
|
|
|
|
runTorchNet("net_cadd_table"); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#if defined(ENABLE_TORCH_ENET_TESTS) |
|
|
|
|