Return PyInfo provider with imports from _gen rules and pass that as deps in py_library. This allows hiding _virtual_imports include path from the surface.

pull/20150/head
vam-google 6 years ago
parent 9e6e57bfba
commit 564dc771dc
  1. 23
      bazel/protobuf.bzl
  2. 48
      bazel/python_rules.bzl
  3. 17
      bazel/test/python_test_repo/BUILD
  4. 23
      bazel/test/python_test_repo/helloworld.py
  5. 76
      bazel/test/python_test_repo/helloworld_moved.py

@ -3,7 +3,6 @@
_PROTO_EXTENSION = ".proto"
_VIRTUAL_IMPORTS = "/_virtual_imports/"
def well_known_proto_libs():
return [
"@com_google_protobuf//:any_proto",
@ -111,8 +110,8 @@ def get_plugin_args(plugin, flags, dir_out, generate_mocks):
]
def _get_staged_proto_file(context, source_file):
if source_file.dirname == context.label.package \
or is_in_virtual_imports(source_file):
if source_file.dirname == context.label.package or \
is_in_virtual_imports(source_file):
# Current target and source_file are in same package
return source_file
else:
@ -175,12 +174,8 @@ def declare_out_files(protos, context, generated_file_format):
out_file_paths.append(proto.basename)
else:
path = proto.path[proto.path.index(_VIRTUAL_IMPORTS) + 1:]
# TODO: uncomment if '.' path is chosen over
# `_virtual_imports/proto_library_target_name` as the output
# path = proto.path.split(_VIRTUAL_IMPORTS)[1].split("/", 1)[1]
out_file_paths.append(path)
return [
context.actions.declare_file(
proto_path_to_generated_filename(
@ -208,11 +203,15 @@ def get_out_dir(protos, context):
elif at_least_one_virtual:
fail("Proto sources must be either all virtual imports or all real")
if at_least_one_virtual:
return get_include_directory(protos[0])
# TODO: uncomment if '.' path is chosen over
# `_virtual_imports/proto_library_target_name` as the output path
# return "{}/{}".format(context.genfiles_dir.path, context.label.package)
return context.genfiles_dir.path
out_dir = get_include_directory(protos[0])
ws_root = protos[0].owner.workspace_root
if ws_root and out_dir.find(ws_root) >= 0:
out_dir = "".join(out_dir.rsplit(ws_root, 1))
return struct(
path = out_dir,
import_path = out_dir[out_dir.find(_VIRTUAL_IMPORTS) + 1:],
)
return struct(path = context.genfiles_dir.path, import_path = None)
def is_in_virtual_imports(source_file, virtual_folder = _VIRTUAL_IMPORTS):
"""Determines if source_file is virtual (is placed in _virtual_imports

@ -4,7 +4,6 @@ load(
"//bazel:protobuf.bzl",
"get_include_directory",
"get_plugin_args",
"get_proto_root",
"protos_from_context",
"includes_from_deps",
"get_proto_arguments",
@ -18,12 +17,12 @@ _GENERATED_GRPC_PROTO_FORMAT = "{}_pb2_grpc.py"
def _generate_py_impl(context):
protos = protos_from_context(context)
includes = includes_from_deps(context.attr.deps)
proto_root = get_proto_root(context.label.workspace_root)
out_files = declare_out_files(protos, context, _GENERATED_PROTO_FORMAT)
tools = [context.executable._protoc]
out_dir = get_out_dir(protos, context)
arguments = ([
"--python_out={}".format(get_out_dir(protos, context)),
"--python_out={}".format(out_dir.path),
] + [
"--proto_path={}".format(get_include_directory(i))
for i in includes
@ -40,7 +39,18 @@ def _generate_py_impl(context):
arguments = arguments,
mnemonic = "ProtocInvocation",
)
return struct(files = depset(out_files))
imports = []
if out_dir.import_path:
imports.append("__main__/%s" % out_dir.import_path)
return [
DefaultInfo(files = depset(direct = out_files)),
PyInfo(
transitive_sources = depset(),
imports = depset(direct = imports),
),
]
_generate_pb2_src = rule(
attrs = {
@ -83,24 +93,27 @@ def py_proto_library(
native.py_library(
name = name,
srcs = [":{}".format(codegen_target)],
deps = ["@com_google_protobuf//:protobuf_python"],
deps = [
"@com_google_protobuf//:protobuf_python",
":{}".format(codegen_target),
],
**kwargs
)
def _generate_pb2_grpc_src_impl(context):
protos = protos_from_context(context)
includes = includes_from_deps(context.attr.deps)
proto_root = get_proto_root(context.label.workspace_root)
out_files = declare_out_files(protos, context, _GENERATED_GRPC_PROTO_FORMAT)
plugin_flags = ["grpc_2_0"] + context.attr.strip_prefixes
arguments = []
tools = [context.executable._protoc, context.executable._plugin]
out_dir = get_out_dir(protos, context)
arguments += get_plugin_args(
context.executable._plugin,
plugin_flags,
get_out_dir(protos, context),
out_dir.path,
False,
)
@ -119,7 +132,18 @@ def _generate_pb2_grpc_src_impl(context):
arguments = arguments,
mnemonic = "ProtocInvocation",
)
return struct(files = depset(out_files))
imports = []
if out_dir.import_path:
imports.append("__main__/%s" % out_dir.import_path)
return [
DefaultInfo(files = depset(direct = out_files)),
PyInfo(
transitive_sources = depset(),
imports = depset(direct = imports),
),
]
_generate_pb2_grpc_src = rule(
attrs = {
@ -185,7 +209,11 @@ def py_grpc_library(
srcs = [
":{}".format(codegen_grpc_target),
],
deps = [Label("//src/python/grpcio/grpc:grpcio")] + deps,
deps = [
Label("//src/python/grpcio/grpc:grpcio"),
] + deps + [
":{}".format(codegen_grpc_target)
],
**kwargs
)

@ -88,26 +88,13 @@ py_grpc_library(
py_test(
name = "import_moved_test",
main = "helloworld.py",
srcs = ["helloworld.py"],
main = "helloworld_moved.py",
srcs = ["helloworld_moved.py"],
deps = [
":helloworld_moved_py_pb2",
":helloworld_moved_py_pb2_grpc",
":duration_py_pb2",
":timestamp_py_pb2",
],
imports = [
"_virtual_imports/helloworld_moved_proto",
# The following line allows us to keep helloworld.py file same for both
# test cases ("import_test" and "import_moved_test") and reduce the code
# duplication.
#
# Without this line, the actual imports in hellowold.py should look
# like the following:
# import google.cloud.helloworld_pb2 as helloworld_pb2
# instead of:
# import helloworld_pb2
"_virtual_imports/helloworld_moved_proto/google/cloud"
],
python_version = "PY3",
)

@ -20,7 +20,9 @@ import unittest
import grpc
import duration_pb2
from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
from concurrent import futures
import helloworld_pb2
import helloworld_pb2_grpc
@ -31,12 +33,13 @@ _SERVER_ADDRESS = '{}:0'.format(_HOST)
class Greeter(helloworld_pb2_grpc.GreeterServicer):
def SayHello(self, request, context):
request_in_flight = datetime.now() - request.request_initation.ToDatetime()
request_in_flight = datetime.datetime.now() - \
request.request_initiation.ToDatetime()
request_duration = duration_pb2.Duration()
request_duration.FromTimedelta(request_in_flight)
return helloworld_pb2.HelloReply(
message='Hello, %s!' % request.name,
request_duration=request_duration,
message='Hello, %s!' % request.name,
request_duration=request_duration,
)
@ -53,19 +56,19 @@ def _listening_server():
class ImportTest(unittest.TestCase):
def run():
def test_import(self):
with _listening_server() as port:
with grpc.insecure_channel('{}:{}'.format(_HOST, port)) as channel:
stub = helloworld_pb2_grpc.GreeterStub(channel)
request_timestamp = timestamp_pb2.Timestamp()
request_timestamp.GetCurrentTime()
response = stub.SayHello(helloworld_pb2.HelloRequest(
name='you',
request_initiation=request_timestamp,
),
wait_for_ready=True)
name='you',
request_initiation=request_timestamp,
),
wait_for_ready=True)
self.assertEqual(response.message, "Hello, you!")
self.assertGreater(response.request_duration.microseconds, 0)
self.assertGreater(response.request_duration.nanos, 0)
if __name__ == '__main__':

@ -0,0 +1,76 @@
# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Python implementation of the GRPC helloworld.Greeter client."""
import contextlib
import datetime
import logging
import unittest
import grpc
from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
from concurrent import futures
from google.cloud import helloworld_pb2
from google.cloud import helloworld_pb2_grpc
_HOST = 'localhost'
_SERVER_ADDRESS = '{}:0'.format(_HOST)
class Greeter(helloworld_pb2_grpc.GreeterServicer):
def SayHello(self, request, context):
request_in_flight = datetime.datetime.now() - \
request.request_initiation.ToDatetime()
request_duration = duration_pb2.Duration()
request_duration.FromTimedelta(request_in_flight)
return helloworld_pb2.HelloReply(
message='Hello, %s!' % request.name,
request_duration=request_duration,
)
@contextlib.contextmanager
def _listening_server():
server = grpc.server(futures.ThreadPoolExecutor())
helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server)
port = server.add_insecure_port(_SERVER_ADDRESS)
server.start()
try:
yield port
finally:
server.stop(0)
class ImportTest(unittest.TestCase):
def test_import(self):
with _listening_server() as port:
with grpc.insecure_channel('{}:{}'.format(_HOST, port)) as channel:
stub = helloworld_pb2_grpc.GreeterStub(channel)
request_timestamp = timestamp_pb2.Timestamp()
request_timestamp.GetCurrentTime()
response = stub.SayHello(helloworld_pb2.HelloRequest(
name='you',
request_initiation=request_timestamp,
),
wait_for_ready=True)
self.assertEqual(response.message, "Hello, you!")
self.assertGreater(response.request_duration.nanos, 0)
if __name__ == '__main__':
logging.basicConfig()
unittest.main()
Loading…
Cancel
Save