diff --git a/src/python/grpcio/grpc/experimental/aio/_metadata.py b/src/python/grpcio/grpc/experimental/aio/_metadata.py index 03c39ba9b7a..d470979d5a4 100644 --- a/src/python/grpcio/grpc/experimental/aio/_metadata.py +++ b/src/python/grpcio/grpc/experimental/aio/_metadata.py @@ -26,15 +26,15 @@ class Metadata(abc.Mapping): * The order of the values by key is preserved * Getting by an element by key, retrieves the first mapped value * Supports an immutable view of the data + * Allows partial mutation on the data without recreating the new object from scratch. """ - def __init__(self, *args) -> None: + def __init__(self, *args: Tuple[str, AnyStr]) -> None: self._metadata = OrderedDict() for md_key, md_value in args: self.add(md_key, md_value) def add(self, key: str, value: str) -> None: - key = key.lower() self._metadata.setdefault(key, []) self._metadata[key].append(value) @@ -43,30 +43,36 @@ class Metadata(abc.Mapping): def __getitem__(self, key: str) -> str: try: - first, *_ = self._metadata[key.lower()] - return first - except ValueError as e: + return self._metadata[key][0] + except (ValueError, IndexError) as e: raise KeyError("{0!r}".format(key)) from e - def __iter__(self) -> Iterator[Tuple[AnyStr, AnyStr]]: + def __setitem__(self, key: str, value: AnyStr) -> None: + self._metadata[key] = [value] + + def __iter__(self) -> Iterator[Tuple[str, AnyStr]]: for key, values in self._metadata.items(): for value in values: yield (key, value) - def view(self) -> Tuple[AnyStr, AnyStr]: - return tuple(self) - def get_all(self, key: str) -> List[str]: """For compatibility with other Metadata abstraction objects (like in Java), this would return all items under the desired . """ - return self._metadata.get(key.lower(), []) + return self._metadata.get(key, []) + + def set_all(self, key: str, values: List[AnyStr]) -> None: + self._metadata[key] = values def __contains__(self, key: str) -> bool: - return key.lower() in self._metadata + return key in self._metadata def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return NotImplemented return self._metadata == other._metadata + + def __repr__(self): + view = tuple(self) + return f"{0!r}({1!r})".format(self.__class__.__name__, view) diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index 196e9f08b0a..c595a4d1bab 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -60,6 +60,7 @@ "unit._metadata_code_details_test.MetadataCodeDetailsTest", "unit._metadata_flags_test.MetadataFlagsTest", "unit._metadata_test.MetadataTest", + "unit._metadata_test.MetadataTypeTest", "unit._reconnect_test.ReconnectTest", "unit._resource_exhausted_test.ResourceExhaustedTest", "unit._rpc_test.RPCTest", diff --git a/src/python/grpcio_tests/tests/unit/_metadata_test.py b/src/python/grpcio_tests/tests/unit/_metadata_test.py index e19da5f3867..7325d8b12ba 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_test.py @@ -267,10 +267,6 @@ class MetadataTypeTest(unittest.TestCase): metadata["key not found"] self.assertIsNone(metadata.get("key not found")) - def test_view(self): - self.assertEqual( - Metadata(*self._DEFAULT_DATA).view(), self._DEFAULT_DATA) - def test_add_value(self): metadata = Metadata() metadata.add("key", "value") @@ -279,7 +275,6 @@ class MetadataTypeTest(unittest.TestCase): self.assertEqual(metadata["key"], "value") self.assertEqual(metadata["key2"], "value2") - self.assertEqual(metadata["KEY2"], "value2") def test_get_all_items(self): metadata = Metadata(*self._MULTI_ENTRY_DATA) @@ -290,9 +285,7 @@ class MetadataTypeTest(unittest.TestCase): def test_container(self): metadata = Metadata(*self._MULTI_ENTRY_DATA) - for key in ("key1", "Key1", "KEY1"): - with self.subTest(case=key): - self.assertIn(key, metadata, "{0!r} not found".format(key)) + self.assertIn("key", metadata) def test_equals(self): metadata = Metadata() @@ -303,6 +296,23 @@ class MetadataTypeTest(unittest.TestCase): self.assertEqual(metadata, metadata2) self.assertNotEqual(metadata, "foo") + def test_repr(self): + metadata = Metadata(*self._DEFAULT_DATA) + expected = "Metadata({0!r})".format(self._DEFAULT_DATA) + self.assertEqual(repr(metadata), expected) + + def test_set(self): + metadata = Metadata(*self._DEFAULT_DATA) + metadata["key"] = "override value" + self.assertEqual(metadata["key"], "override value") + + def test_set_all(self): + metadata = Metadata(self._DEFAULT_DATA) + metadata.set_all("key", ["value1", b"new value 2"]) + + self.assertEqual(metadata["key"], "value1") + self.assertEqual(metadata.get_all("value1"), ["value1", b"new value 2"]) + if __name__ == '__main__': logging.basicConfig()