From 564dc771dc5be432eff721959f69e2c8221edf6c Mon Sep 17 00:00:00 2001 From: vam-google Date: Tue, 15 Oct 2019 17:47:02 -0700 Subject: [PATCH] Return PyInfo provider with imports from _gen rules and pass that as deps in py_library. This allows hiding _virtual_imports include path from the surface. --- bazel/protobuf.bzl | 23 +++--- bazel/python_rules.bzl | 48 +++++++++--- bazel/test/python_test_repo/BUILD | 17 +---- bazel/test/python_test_repo/helloworld.py | 23 +++--- .../test/python_test_repo/helloworld_moved.py | 76 +++++++++++++++++++ 5 files changed, 140 insertions(+), 47 deletions(-) create mode 100644 bazel/test/python_test_repo/helloworld_moved.py diff --git a/bazel/protobuf.bzl b/bazel/protobuf.bzl index 30b733c5b51..5ea2bbc8f00 100644 --- a/bazel/protobuf.bzl +++ b/bazel/protobuf.bzl @@ -3,7 +3,6 @@ _PROTO_EXTENSION = ".proto" _VIRTUAL_IMPORTS = "/_virtual_imports/" - def well_known_proto_libs(): return [ "@com_google_protobuf//:any_proto", @@ -111,8 +110,8 @@ def get_plugin_args(plugin, flags, dir_out, generate_mocks): ] def _get_staged_proto_file(context, source_file): - if source_file.dirname == context.label.package \ - or is_in_virtual_imports(source_file): + if source_file.dirname == context.label.package or \ + is_in_virtual_imports(source_file): # Current target and source_file are in same package return source_file else: @@ -175,12 +174,8 @@ def declare_out_files(protos, context, generated_file_format): out_file_paths.append(proto.basename) else: path = proto.path[proto.path.index(_VIRTUAL_IMPORTS) + 1:] - # TODO: uncomment if '.' path is chosen over - # `_virtual_imports/proto_library_target_name` as the output - # path = proto.path.split(_VIRTUAL_IMPORTS)[1].split("/", 1)[1] out_file_paths.append(path) - return [ context.actions.declare_file( proto_path_to_generated_filename( @@ -208,11 +203,15 @@ def get_out_dir(protos, context): elif at_least_one_virtual: fail("Proto sources must be either all virtual imports or all real") if at_least_one_virtual: - return get_include_directory(protos[0]) - # TODO: uncomment if '.' path is chosen over - # `_virtual_imports/proto_library_target_name` as the output path - # return "{}/{}".format(context.genfiles_dir.path, context.label.package) - return context.genfiles_dir.path + out_dir = get_include_directory(protos[0]) + ws_root = protos[0].owner.workspace_root + if ws_root and out_dir.find(ws_root) >= 0: + out_dir = "".join(out_dir.rsplit(ws_root, 1)) + return struct( + path = out_dir, + import_path = out_dir[out_dir.find(_VIRTUAL_IMPORTS) + 1:], + ) + return struct(path = context.genfiles_dir.path, import_path = None) def is_in_virtual_imports(source_file, virtual_folder = _VIRTUAL_IMPORTS): """Determines if source_file is virtual (is placed in _virtual_imports diff --git a/bazel/python_rules.bzl b/bazel/python_rules.bzl index 1e72b39c37f..2709d32e830 100644 --- a/bazel/python_rules.bzl +++ b/bazel/python_rules.bzl @@ -4,7 +4,6 @@ load( "//bazel:protobuf.bzl", "get_include_directory", "get_plugin_args", - "get_proto_root", "protos_from_context", "includes_from_deps", "get_proto_arguments", @@ -18,12 +17,12 @@ _GENERATED_GRPC_PROTO_FORMAT = "{}_pb2_grpc.py" def _generate_py_impl(context): protos = protos_from_context(context) includes = includes_from_deps(context.attr.deps) - proto_root = get_proto_root(context.label.workspace_root) out_files = declare_out_files(protos, context, _GENERATED_PROTO_FORMAT) - tools = [context.executable._protoc] + + out_dir = get_out_dir(protos, context) arguments = ([ - "--python_out={}".format(get_out_dir(protos, context)), + "--python_out={}".format(out_dir.path), ] + [ "--proto_path={}".format(get_include_directory(i)) for i in includes @@ -40,7 +39,18 @@ def _generate_py_impl(context): arguments = arguments, mnemonic = "ProtocInvocation", ) - return struct(files = depset(out_files)) + + imports = [] + if out_dir.import_path: + imports.append("__main__/%s" % out_dir.import_path) + + return [ + DefaultInfo(files = depset(direct = out_files)), + PyInfo( + transitive_sources = depset(), + imports = depset(direct = imports), + ), + ] _generate_pb2_src = rule( attrs = { @@ -83,24 +93,27 @@ def py_proto_library( native.py_library( name = name, srcs = [":{}".format(codegen_target)], - deps = ["@com_google_protobuf//:protobuf_python"], + deps = [ + "@com_google_protobuf//:protobuf_python", + ":{}".format(codegen_target), + ], **kwargs ) def _generate_pb2_grpc_src_impl(context): protos = protos_from_context(context) includes = includes_from_deps(context.attr.deps) - proto_root = get_proto_root(context.label.workspace_root) out_files = declare_out_files(protos, context, _GENERATED_GRPC_PROTO_FORMAT) plugin_flags = ["grpc_2_0"] + context.attr.strip_prefixes arguments = [] tools = [context.executable._protoc, context.executable._plugin] + out_dir = get_out_dir(protos, context) arguments += get_plugin_args( context.executable._plugin, plugin_flags, - get_out_dir(protos, context), + out_dir.path, False, ) @@ -119,7 +132,18 @@ def _generate_pb2_grpc_src_impl(context): arguments = arguments, mnemonic = "ProtocInvocation", ) - return struct(files = depset(out_files)) + + imports = [] + if out_dir.import_path: + imports.append("__main__/%s" % out_dir.import_path) + + return [ + DefaultInfo(files = depset(direct = out_files)), + PyInfo( + transitive_sources = depset(), + imports = depset(direct = imports), + ), + ] _generate_pb2_grpc_src = rule( attrs = { @@ -185,7 +209,11 @@ def py_grpc_library( srcs = [ ":{}".format(codegen_grpc_target), ], - deps = [Label("//src/python/grpcio/grpc:grpcio")] + deps, + deps = [ + Label("//src/python/grpcio/grpc:grpcio"), + ] + deps + [ + ":{}".format(codegen_grpc_target) + ], **kwargs ) diff --git a/bazel/test/python_test_repo/BUILD b/bazel/test/python_test_repo/BUILD index c09766131ba..0127e34e276 100644 --- a/bazel/test/python_test_repo/BUILD +++ b/bazel/test/python_test_repo/BUILD @@ -88,26 +88,13 @@ py_grpc_library( py_test( name = "import_moved_test", - main = "helloworld.py", - srcs = ["helloworld.py"], + main = "helloworld_moved.py", + srcs = ["helloworld_moved.py"], deps = [ ":helloworld_moved_py_pb2", ":helloworld_moved_py_pb2_grpc", ":duration_py_pb2", ":timestamp_py_pb2", ], - imports = [ - "_virtual_imports/helloworld_moved_proto", - # The following line allows us to keep helloworld.py file same for both - # test cases ("import_test" and "import_moved_test") and reduce the code - # duplication. - # - # Without this line, the actual imports in hellowold.py should look - # like the following: - # import google.cloud.helloworld_pb2 as helloworld_pb2 - # instead of: - # import helloworld_pb2 - "_virtual_imports/helloworld_moved_proto/google/cloud" - ], python_version = "PY3", ) \ No newline at end of file diff --git a/bazel/test/python_test_repo/helloworld.py b/bazel/test/python_test_repo/helloworld.py index deee36a8f71..3f87191efb4 100644 --- a/bazel/test/python_test_repo/helloworld.py +++ b/bazel/test/python_test_repo/helloworld.py @@ -20,7 +20,9 @@ import unittest import grpc -import duration_pb2 +from google.protobuf import duration_pb2 +from google.protobuf import timestamp_pb2 +from concurrent import futures import helloworld_pb2 import helloworld_pb2_grpc @@ -31,12 +33,13 @@ _SERVER_ADDRESS = '{}:0'.format(_HOST) class Greeter(helloworld_pb2_grpc.GreeterServicer): def SayHello(self, request, context): - request_in_flight = datetime.now() - request.request_initation.ToDatetime() + request_in_flight = datetime.datetime.now() - \ + request.request_initiation.ToDatetime() request_duration = duration_pb2.Duration() request_duration.FromTimedelta(request_in_flight) return helloworld_pb2.HelloReply( - message='Hello, %s!' % request.name, - request_duration=request_duration, + message='Hello, %s!' % request.name, + request_duration=request_duration, ) @@ -53,19 +56,19 @@ def _listening_server(): class ImportTest(unittest.TestCase): - def run(): + def test_import(self): with _listening_server() as port: with grpc.insecure_channel('{}:{}'.format(_HOST, port)) as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) request_timestamp = timestamp_pb2.Timestamp() request_timestamp.GetCurrentTime() response = stub.SayHello(helloworld_pb2.HelloRequest( - name='you', - request_initiation=request_timestamp, - ), - wait_for_ready=True) + name='you', + request_initiation=request_timestamp, + ), + wait_for_ready=True) self.assertEqual(response.message, "Hello, you!") - self.assertGreater(response.request_duration.microseconds, 0) + self.assertGreater(response.request_duration.nanos, 0) if __name__ == '__main__': diff --git a/bazel/test/python_test_repo/helloworld_moved.py b/bazel/test/python_test_repo/helloworld_moved.py new file mode 100644 index 00000000000..b32042cdfa9 --- /dev/null +++ b/bazel/test/python_test_repo/helloworld_moved.py @@ -0,0 +1,76 @@ +# Copyright 2019 the gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The Python implementation of the GRPC helloworld.Greeter client.""" + +import contextlib +import datetime +import logging +import unittest + +import grpc + +from google.protobuf import duration_pb2 +from google.protobuf import timestamp_pb2 +from concurrent import futures +from google.cloud import helloworld_pb2 +from google.cloud import helloworld_pb2_grpc + +_HOST = 'localhost' +_SERVER_ADDRESS = '{}:0'.format(_HOST) + + +class Greeter(helloworld_pb2_grpc.GreeterServicer): + + def SayHello(self, request, context): + request_in_flight = datetime.datetime.now() - \ + request.request_initiation.ToDatetime() + request_duration = duration_pb2.Duration() + request_duration.FromTimedelta(request_in_flight) + return helloworld_pb2.HelloReply( + message='Hello, %s!' % request.name, + request_duration=request_duration, + ) + + +@contextlib.contextmanager +def _listening_server(): + server = grpc.server(futures.ThreadPoolExecutor()) + helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) + port = server.add_insecure_port(_SERVER_ADDRESS) + server.start() + try: + yield port + finally: + server.stop(0) + + +class ImportTest(unittest.TestCase): + def test_import(self): + with _listening_server() as port: + with grpc.insecure_channel('{}:{}'.format(_HOST, port)) as channel: + stub = helloworld_pb2_grpc.GreeterStub(channel) + request_timestamp = timestamp_pb2.Timestamp() + request_timestamp.GetCurrentTime() + response = stub.SayHello(helloworld_pb2.HelloRequest( + name='you', + request_initiation=request_timestamp, + ), + wait_for_ready=True) + self.assertEqual(response.message, "Hello, you!") + self.assertGreater(response.request_duration.nanos, 0) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main()