Add mutability to the MetadataType

pull/22306/head
Mariano Anaya 5 years ago
parent 7d88c61f57
commit 0be36ed606
  1. 28
      src/python/grpcio/grpc/experimental/aio/_metadata.py
  2. 1
      src/python/grpcio_tests/tests/tests.json
  3. 26
      src/python/grpcio_tests/tests/unit/_metadata_test.py

@ -26,15 +26,15 @@ class Metadata(abc.Mapping):
* The order of the values by key is preserved * The order of the values by key is preserved
* Getting by an element by key, retrieves the first mapped value * Getting by an element by key, retrieves the first mapped value
* Supports an immutable view of the data * Supports an immutable view of the data
* Allows partial mutation on the data without recreating the new object from scratch.
""" """
def __init__(self, *args) -> None: def __init__(self, *args: Tuple[str, AnyStr]) -> None:
self._metadata = OrderedDict() self._metadata = OrderedDict()
for md_key, md_value in args: for md_key, md_value in args:
self.add(md_key, md_value) self.add(md_key, md_value)
def add(self, key: str, value: str) -> None: def add(self, key: str, value: str) -> None:
key = key.lower()
self._metadata.setdefault(key, []) self._metadata.setdefault(key, [])
self._metadata[key].append(value) self._metadata[key].append(value)
@ -43,30 +43,36 @@ class Metadata(abc.Mapping):
def __getitem__(self, key: str) -> str: def __getitem__(self, key: str) -> str:
try: try:
first, *_ = self._metadata[key.lower()] return self._metadata[key][0]
return first except (ValueError, IndexError) as e:
except ValueError as e:
raise KeyError("{0!r}".format(key)) from e raise KeyError("{0!r}".format(key)) from e
def __iter__(self) -> Iterator[Tuple[AnyStr, AnyStr]]: def __setitem__(self, key: str, value: AnyStr) -> None:
self._metadata[key] = [value]
def __iter__(self) -> Iterator[Tuple[str, AnyStr]]:
for key, values in self._metadata.items(): for key, values in self._metadata.items():
for value in values: for value in values:
yield (key, value) yield (key, value)
def view(self) -> Tuple[AnyStr, AnyStr]:
return tuple(self)
def get_all(self, key: str) -> List[str]: def get_all(self, key: str) -> List[str]:
"""For compatibility with other Metadata abstraction objects (like in Java), """For compatibility with other Metadata abstraction objects (like in Java),
this would return all items under the desired <key>. this would return all items under the desired <key>.
""" """
return self._metadata.get(key.lower(), []) return self._metadata.get(key, [])
def set_all(self, key: str, values: List[AnyStr]) -> None:
self._metadata[key] = values
def __contains__(self, key: str) -> bool: def __contains__(self, key: str) -> bool:
return key.lower() in self._metadata return key in self._metadata
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented
return self._metadata == other._metadata return self._metadata == other._metadata
def __repr__(self):
view = tuple(self)
return f"{0!r}({1!r})".format(self.__class__.__name__, view)

@ -60,6 +60,7 @@
"unit._metadata_code_details_test.MetadataCodeDetailsTest", "unit._metadata_code_details_test.MetadataCodeDetailsTest",
"unit._metadata_flags_test.MetadataFlagsTest", "unit._metadata_flags_test.MetadataFlagsTest",
"unit._metadata_test.MetadataTest", "unit._metadata_test.MetadataTest",
"unit._metadata_test.MetadataTypeTest",
"unit._reconnect_test.ReconnectTest", "unit._reconnect_test.ReconnectTest",
"unit._resource_exhausted_test.ResourceExhaustedTest", "unit._resource_exhausted_test.ResourceExhaustedTest",
"unit._rpc_test.RPCTest", "unit._rpc_test.RPCTest",

@ -267,10 +267,6 @@ class MetadataTypeTest(unittest.TestCase):
metadata["key not found"] metadata["key not found"]
self.assertIsNone(metadata.get("key not found")) self.assertIsNone(metadata.get("key not found"))
def test_view(self):
self.assertEqual(
Metadata(*self._DEFAULT_DATA).view(), self._DEFAULT_DATA)
def test_add_value(self): def test_add_value(self):
metadata = Metadata() metadata = Metadata()
metadata.add("key", "value") metadata.add("key", "value")
@ -279,7 +275,6 @@ class MetadataTypeTest(unittest.TestCase):
self.assertEqual(metadata["key"], "value") self.assertEqual(metadata["key"], "value")
self.assertEqual(metadata["key2"], "value2") self.assertEqual(metadata["key2"], "value2")
self.assertEqual(metadata["KEY2"], "value2")
def test_get_all_items(self): def test_get_all_items(self):
metadata = Metadata(*self._MULTI_ENTRY_DATA) metadata = Metadata(*self._MULTI_ENTRY_DATA)
@ -290,9 +285,7 @@ class MetadataTypeTest(unittest.TestCase):
def test_container(self): def test_container(self):
metadata = Metadata(*self._MULTI_ENTRY_DATA) metadata = Metadata(*self._MULTI_ENTRY_DATA)
for key in ("key1", "Key1", "KEY1"): self.assertIn("key", metadata)
with self.subTest(case=key):
self.assertIn(key, metadata, "{0!r} not found".format(key))
def test_equals(self): def test_equals(self):
metadata = Metadata() metadata = Metadata()
@ -303,6 +296,23 @@ class MetadataTypeTest(unittest.TestCase):
self.assertEqual(metadata, metadata2) self.assertEqual(metadata, metadata2)
self.assertNotEqual(metadata, "foo") self.assertNotEqual(metadata, "foo")
def test_repr(self):
metadata = Metadata(*self._DEFAULT_DATA)
expected = "Metadata({0!r})".format(self._DEFAULT_DATA)
self.assertEqual(repr(metadata), expected)
def test_set(self):
metadata = Metadata(*self._DEFAULT_DATA)
metadata["key"] = "override value"
self.assertEqual(metadata["key"], "override value")
def test_set_all(self):
metadata = Metadata(self._DEFAULT_DATA)
metadata.set_all("key", ["value1", b"new value 2"])
self.assertEqual(metadata["key"], "value1")
self.assertEqual(metadata.get_all("value1"), ["value1", b"new value 2"])
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig() logging.basicConfig()

Loading…
Cancel
Save