"""Utility functions for generating protobuf code."""

_PROTO_EXTENSION = ".proto"
_VIRTUAL_IMPORTS = "/_virtual_imports/"

def well_known_proto_libs():
    """Returns the labels of the well-known protobuf proto_library targets.

    Returns:
      A list of label strings inside the @com_google_protobuf repository.
    """
    return [
        "@com_google_protobuf//:any_proto",
        "@com_google_protobuf//:api_proto",
        "@com_google_protobuf//:compiler_plugin_proto",
        "@com_google_protobuf//:descriptor_proto",
        "@com_google_protobuf//:duration_proto",
        "@com_google_protobuf//:empty_proto",
        "@com_google_protobuf//:field_mask_proto",
        "@com_google_protobuf//:source_context_proto",
        "@com_google_protobuf//:struct_proto",
        "@com_google_protobuf//:timestamp_proto",
        "@com_google_protobuf//:type_proto",
        "@com_google_protobuf//:wrappers_proto",
    ]

def get_proto_root(workspace_root):
    """Gets the root protobuf directory.

    Args:
      workspace_root: context.label.workspace_root

    Returns:
      The directory relative to which generated include paths should be:
      "/" + workspace_root when workspace_root is non-empty, "" otherwise.
    """
    if workspace_root:
        return "/{}".format(workspace_root)
    else:
        return ""

def _strip_proto_extension(proto_filename):
    """Returns proto_filename without its trailing ".proto".

    Calls fail() if proto_filename does not end with ".proto".
    """
    if not proto_filename.endswith(_PROTO_EXTENSION):
        fail('"{}" does not end with "{}"'.format(
            proto_filename,
            _PROTO_EXTENSION,
        ))
    return proto_filename[:-len(_PROTO_EXTENSION)]

def proto_path_to_generated_filename(proto_path, fmt_str):
    """Calculates the name of a generated file for a protobuf path.

    For example, "examples/protos/helloworld.proto" might map to
    "helloworld.pb.h".

    Args:
      proto_path: The path to the .proto file.
      fmt_str: A format string used to calculate the generated filename. For
        example, "{}.pb.h" might be used to calculate a C++ header filename.

    Returns:
      The generated filename.
    """
    return fmt_str.format(_strip_proto_extension(proto_path))

def get_include_directory(source_file):
    """Returns the include directory path for the source_file.

    I.e. all of the include statements within the given source_file are
    calculated relative to the directory returned by this method.

    The returned directory path can be used as the "--proto_path=" argument
    value.

    Args:
      source_file: A proto file.

    Returns:
      The include directory path for the source_file.
    """
    directory = source_file.path
    prefix_len = 0

    if is_in_virtual_imports(source_file):
        # Split on the first marker only, so the two-name unpack below always
        # succeeds even if "/_virtual_imports/" appears again deeper in the
        # path. The include dir is "<root>/_virtual_imports/<first component>".
        root, relative = source_file.path.split(_VIRTUAL_IMPORTS, 1)
        return root + _VIRTUAL_IMPORTS + relative.split("/", 1)[0]

    # Generated files live under their output root; skip that prefix (and the
    # following "/") before testing for an "external" repository path.
    if not source_file.is_source and directory.startswith(source_file.root.path):
        prefix_len = len(source_file.root.path) + 1

    if directory.startswith("external", prefix_len):
        # Keep everything up to and including the repository name, i.e.
        # "[<root>/]external/<repo>".
        external_separator = directory.find("/", prefix_len)
        repository_separator = directory.find("/", external_separator + 1)
        return directory[:repository_separator]
    else:
        return source_file.root.path if source_file.root.path else "."

def get_plugin_args(plugin, flags, dir_out, generate_mocks):
    """Returns arguments configuring protoc to use a plugin for a language.

    Args:
      plugin: An executable file to run as the protoc plugin.
      flags: The plugin flags to be passed to protoc.
      dir_out: The output directory for the plugin.
      generate_mocks: A bool indicating whether to generate mocks.

    Returns:
      A list of protoc arguments configuring the plugin.
    """
    # Copy so the caller's flags list is never mutated.
    augmented_flags = list(flags)
    if generate_mocks:
        augmented_flags.append("generate_mock_code=true")
    return [
        "--plugin=protoc-gen-PLUGIN=" + plugin.path,
        "--PLUGIN_out=" + ",".join(augmented_flags) + ":" + dir_out,
    ]

def _get_staged_proto_file(context, source_file):
    """Returns a File for source_file that is usable from the current package.

    If source_file is in the current target's package, or under
    _virtual_imports, it is returned unchanged. Otherwise a new file with the
    same basename is declared for this rule and a "cp" shell action copies
    source_file into it.
    """
    if source_file.dirname == context.label.package or \
       is_in_virtual_imports(source_file):
        # Current target and source_file are in same package
        return source_file
    else:
        # Current target and source_file are in different packages (most
        # probably even in different repositories)
        copied_proto = context.actions.declare_file(source_file.basename)
        context.actions.run_shell(
            inputs = [source_file],
            outputs = [copied_proto],
            command = "cp {} {}".format(source_file.path, copied_proto.path),
            mnemonic = "CopySourceProto",
        )
        return copied_proto

def protos_from_context(context):
    """Copies proto files to the appropriate location.

    Args:
      context: The ctx object for the rule.

    Returns:
      A list of the protos.
    """
    protos = []
    for src in context.attr.deps:
        for file in src[ProtoInfo].direct_sources:
            protos.append(_get_staged_proto_file(context, file))
    return protos

def includes_from_deps(deps):
    """Get includes from rule dependencies.

    Args:
      deps: Targets providing ProtoInfo.

    Returns:
      A flat list of every file in each dep's ProtoInfo.transitive_imports.
      No de-duplication is done, so files shared by several deps may repeat.
    """
    return [
        file
        for src in deps
        for file in src[ProtoInfo].transitive_imports.to_list()
    ]

def get_proto_arguments(protos, genfiles_dir_path):
    """Get the protoc arguments specifying which protos to compile.

    Args:
      protos: The proto Files to compile.
      genfiles_dir_path: Path of the genfiles directory.

    Returns:
      One path per proto. Virtual-import protos are made relative to their
      include directory; other protos under genfiles_dir_path are made
      relative to it; anything else is passed through unchanged.
    """
    arguments = []
    for proto in protos:
        strip_prefix_len = 0
        if is_in_virtual_imports(proto):
            incl_directory = get_include_directory(proto)
            if proto.path.startswith(incl_directory):
                strip_prefix_len = len(incl_directory) + 1
        elif proto.path.startswith(genfiles_dir_path):
            strip_prefix_len = len(genfiles_dir_path) + 1

        arguments.append(proto.path[strip_prefix_len:])

    return arguments

def declare_out_files(protos, context, generated_file_format):
    """Declares and returns the files to be generated.

    Args:
      protos: The proto Files being compiled.
      context: The ctx object for the rule.
      generated_file_format: Format string such as "{}.pb.h" applied to each
        proto path (minus ".proto").

    Returns:
      A list of declared Files. For virtual-import protos the declared path
      keeps everything from "_virtual_imports/" onward; otherwise only the
      proto's basename is used.
    """
    out_file_paths = []
    for proto in protos:
        if not is_in_virtual_imports(proto):
            out_file_paths.append(proto.basename)
        else:
            path = proto.path[proto.path.index(_VIRTUAL_IMPORTS) + 1:]
            out_file_paths.append(path)

    return [
        context.actions.declare_file(
            proto_path_to_generated_filename(
                out_file_path,
                generated_file_format,
            ),
        )
        for out_file_path in out_file_paths
    ]

def get_out_dir(protos, context):
    """ Returns the calculated value for --<lang>_out= protoc argument based on
    the input source proto files and current context.

    Args:
      protos: A list of protos to be used as source files in protoc command
      context: A ctx object for the rule.

    Returns:
      The value of --<lang>_out= argument, as a struct with `path` and
      `import_path` (None when no proto is a virtual import).
    """

    # Mixing is rejected regardless of order: count the virtual protos and
    # require either none or all of them.
    virtual_count = len([p for p in protos if is_in_virtual_imports(p)])
    if virtual_count > 0 and virtual_count < len(protos):
        fail("Proto sources must be either all virtual imports or all real")

    if virtual_count > 0:
        out_dir = get_include_directory(protos[0])
        ws_root = protos[0].owner.workspace_root
        if ws_root and out_dir.find(ws_root) >= 0:
            # Remove the last occurrence of the workspace root from the path.
            out_dir = "".join(out_dir.rsplit(ws_root, 1))
        return struct(
            path = out_dir,
            import_path = out_dir[out_dir.find(_VIRTUAL_IMPORTS) + 1:],
        )
    return struct(path = context.genfiles_dir.path, import_path = None)

def is_in_virtual_imports(source_file, virtual_folder = _VIRTUAL_IMPORTS):
    """Determines if source_file is virtual (is placed in _virtual_imports
    subdirectory). The output of all proto_library targets which use
    import_prefix and/or strip_import_prefix arguments is placed under
    _virtual_imports directory.

    Args:
      source_file: A proto file.
      virtual_folder: The virtual folder marker to look for in the path
        (defaults to "/_virtual_imports/").

    Returns:
      True if source_file is a generated (non-source) file whose path contains
      virtual_folder, False otherwise.
    """
    return not source_file.is_source and virtual_folder in source_file.path