From 6891169b58987064e77bdfd2b9591bcbcb7dcc1c Mon Sep 17 00:00:00 2001 From: Mariano Anaya Date: Thu, 19 Mar 2020 17:53:48 +0100 Subject: [PATCH] Apply PR feedback * Fix the length of the object to account for all the keys it holds (consider it's a multi-mapping) * Add support for deleting items * Make test stricter --- .../grpcio/grpc/experimental/aio/_metadata.py | 28 +++++++++++++- .../grpcio_tests/tests/unit/_metadata_test.py | 37 ++++++++++++++++--- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_metadata.py b/src/python/grpcio/grpc/experimental/aio/_metadata.py index 6ba04d5e07a..4aa0b830a81 100644 --- a/src/python/grpcio/grpc/experimental/aio/_metadata.py +++ b/src/python/grpcio/grpc/experimental/aio/_metadata.py @@ -39,16 +39,40 @@ class Metadata(abc.Mapping): self._metadata[key].append(value) def __len__(self) -> int: - return len(self._metadata) + """Return the total number of elements that there are in the metadata, + including multiple values for the same key. + """ + return sum(map(len, self._metadata.values())) def __getitem__(self, key: str) -> str: + """When calling [], the first element of all those + mapped for is returned. + """ try: return self._metadata[key][0] except (ValueError, IndexError) as e: raise KeyError("{0!r}".format(key)) from e def __setitem__(self, key: str, value: AnyStr) -> None: - self._metadata[key] = [value] + """Calling metadata[] = + Maps to the first instance of . + """ + if key not in self: + self._metadata[key] = [value] + else: + current_values = self.get_all(key) + self._metadata[key] = [value, *current_values[1:]] + + def __delitem__(self, key: str) -> None: + """``del metadata[]`` deletes the first mapping for .""" + current_values = self.get_all(key) + if not current_values: + raise KeyError(repr(key)) + self._metadata[key] = current_values[1:] + + def delete_all(self, key: str) -> None: + """Delete all mappings for .""" + del self._metadata[key] def __iter__(self) -> Iterator[Tuple[str, AnyStr]]: for key, values in self._metadata.items(): diff --git a/src/python/grpcio_tests/tests/unit/_metadata_test.py b/src/python/grpcio_tests/tests/unit/_metadata_test.py index 75eb61346cd..054d54cd932 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_test.py @@ -248,7 +248,8 @@ class MetadataTypeTest(unittest.TestCase): def test_init_metadata(self): test_cases = { "emtpy": (), - "with-data": self._DEFAULT_DATA, + "with-single-data": self._DEFAULT_DATA, + "with-multi-data": self._MULTI_ENTRY_DATA, } for case, args in test_cases.items(): with self.subTest(case=case): @@ -301,17 +302,43 @@ class MetadataTypeTest(unittest.TestCase): self.assertEqual(repr(metadata), expected) def test_set(self): - metadata = Metadata(*self._DEFAULT_DATA) - metadata["key"] = "override value" - self.assertEqual(metadata["key"], "override value") + metadata = Metadata(*self._MULTI_ENTRY_DATA) + override_value = "override value" + for _ in range(3): + metadata["key1"] = override_value + + self.assertEqual(metadata["key1"], override_value) + self.assertEqual(metadata.get_all("key1"), + [override_value, "other value 1"]) + + empty_metadata = Metadata() + for _ in range(3): + empty_metadata["key"] = override_value + + self.assertEqual(empty_metadata["key"], override_value) + self.assertEqual(empty_metadata.get_all("key"), [override_value]) def test_set_all(self): - metadata = Metadata(self._DEFAULT_DATA) + metadata = Metadata(*self._DEFAULT_DATA) metadata.set_all("key", ["value1", b"new value 2"]) self.assertEqual(metadata["key"], "value1") self.assertEqual(metadata.get_all("key"), ["value1", b"new value 2"]) + def test_delete_values(self): + metadata = Metadata(*self._MULTI_ENTRY_DATA) + del metadata["key1"] + self.assertEqual(metadata.get("key1"), "other value 1") + + metadata.delete_all("key1") + self.assertNotIn("key1", metadata) + + metadata.delete_all("key2") + self.assertEqual(len(metadata), 0) + + with self.assertRaises(KeyError): + del metadata["other key"] + if __name__ == '__main__': logging.basicConfig()