|
|
|
@ -127,6 +127,9 @@ class DescriptorPool(object): |
|
|
|
|
self._service_descriptors = {} |
|
|
|
|
self._file_descriptors = {} |
|
|
|
|
self._toplevel_extensions = {} |
|
|
|
|
# TODO(jieluo): Remove _file_desc_by_toplevel_extension when |
|
|
|
|
# FieldDescriptor.file is added in code gen. |
|
|
|
|
self._file_desc_by_toplevel_extension = {} |
|
|
|
|
# We store extensions in two two-level mappings: The first key is the |
|
|
|
|
# descriptor of the message being extended, the second key is the extension |
|
|
|
|
# full name or its tag number. |
|
|
|
@ -170,7 +173,7 @@ class DescriptorPool(object): |
|
|
|
|
raise TypeError('Expected instance of descriptor.Descriptor.') |
|
|
|
|
|
|
|
|
|
self._descriptors[desc.full_name] = desc |
|
|
|
|
self.AddFileDescriptor(desc.file) |
|
|
|
|
self._AddFileDescriptor(desc.file) |
|
|
|
|
|
|
|
|
|
def AddEnumDescriptor(self, enum_desc): |
|
|
|
|
"""Adds an EnumDescriptor to the pool. |
|
|
|
@ -185,7 +188,7 @@ class DescriptorPool(object): |
|
|
|
|
raise TypeError('Expected instance of descriptor.EnumDescriptor.') |
|
|
|
|
|
|
|
|
|
self._enum_descriptors[enum_desc.full_name] = enum_desc |
|
|
|
|
self.AddFileDescriptor(enum_desc.file) |
|
|
|
|
self._AddFileDescriptor(enum_desc.file) |
|
|
|
|
|
|
|
|
|
def AddServiceDescriptor(self, service_desc): |
|
|
|
|
"""Adds a ServiceDescriptor to the pool. |
|
|
|
@ -251,6 +254,23 @@ class DescriptorPool(object): |
|
|
|
|
file_desc: A FileDescriptor. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
self._AddFileDescriptor(file_desc) |
|
|
|
|
# TODO(jieluo): This is a temporary solution for FieldDescriptor.file. |
|
|
|
|
# Remove it when FieldDescriptor.file is added in code gen. |
|
|
|
|
for extension in file_desc.extensions_by_name.itervalues(): |
|
|
|
|
self._file_desc_by_toplevel_extension[ |
|
|
|
|
extension.full_name] = file_desc |
|
|
|
|
|
|
|
|
|
def _AddFileDescriptor(self, file_desc): |
|
|
|
|
"""Adds a FileDescriptor to the pool, non-recursively. |
|
|
|
|
|
|
|
|
|
If the FileDescriptor contains messages or enums, the caller must explicitly |
|
|
|
|
register them. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
file_desc: A FileDescriptor. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
if not isinstance(file_desc, descriptor.FileDescriptor): |
|
|
|
|
raise TypeError('Expected instance of descriptor.FileDescriptor.') |
|
|
|
|
self._file_descriptors[file_desc.name] = file_desc |
|
|
|
@ -313,12 +333,18 @@ class DescriptorPool(object): |
|
|
|
|
except KeyError: |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
return self._file_desc_by_toplevel_extension[symbol] |
|
|
|
|
except KeyError: |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
# Try nested extensions inside a message. |
|
|
|
|
message_name, _, extension_name = symbol.rpartition('.') |
|
|
|
|
try: |
|
|
|
|
scope = self.FindMessageTypeByName(message_name) |
|
|
|
|
assert scope.extensions_by_name[extension_name] |
|
|
|
|
return scope.file |
|
|
|
|
message = self.FindMessageTypeByName(message_name) |
|
|
|
|
assert message.extensions_by_name[extension_name] |
|
|
|
|
return message.file |
|
|
|
|
|
|
|
|
|
except KeyError: |
|
|
|
|
raise KeyError('Cannot find a file containing %s' % symbol) |
|
|
|
|
|
|
|
|
|