[AbortError] Reapply "[AbortError] And and check AbortError while abort" (#34525)

Reverts grpc/grpc#34515

This PR reapplies AbortError change as the previous one was reverted.

This change was mentioned in this gRFC: [L105: Python Add New Error Types](https://github.com/grpc/proposal/blob/master/L105-python-expose-new-error-types.md)

Closes #34525

COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/34525 from grpc:revert-34515-revert-33969-checkAbortError ce63ab1d6a
PiperOrigin-RevId: 613718295
pull/36076/head
Xuan Wang 9 months ago committed by Copybara-Service
parent d1cea2dd09
commit 675dcccd5e
  1. 6
      src/python/grpcio/grpc/BUILD.bazel
  2. 20
      src/python/grpcio/grpc/__init__.py
  3. 4
      src/python/grpcio/grpc/_channel.py
  4. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  5. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  6. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  7. 24
      src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
  8. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  9. 1
      src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
  10. 57
      src/python/grpcio/grpc/_errors.py
  11. 12
      src/python/grpcio/grpc/_server.py
  12. 8
      src/python/grpcio/grpc/aio/__init__.py
  13. 8
      src/python/grpcio/grpc/aio/_call.py
  14. 3
      src/python/grpcio/grpc/aio/_channel.py
  15. 5
      src/python/grpcio/grpc/aio/_interceptor.py
  16. 28
      src/python/grpcio_tests/tests/unit/_abort_test.py
  17. 2
      src/python/grpcio_tests/tests/unit/_api_test.py
  18. 5
      src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py

@ -99,6 +99,11 @@ py_library(
srcs = ["_observability.py"], srcs = ["_observability.py"],
) )
py_library(
name = "errors",
srcs = ["_errors.py"],
)
py_library( py_library(
name = "grpcio", name = "grpcio",
srcs = ["__init__.py"], srcs = ["__init__.py"],
@ -115,6 +120,7 @@ py_library(
":auth", ":auth",
":channel", ":channel",
":compression", ":compression",
":errors",
":interceptor", ":interceptor",
":plugin_wrapping", ":plugin_wrapping",
":server", ":server",

@ -21,6 +21,9 @@ import sys
from grpc import _compression from grpc import _compression
from grpc._cython import cygrpc as _cygrpc from grpc._cython import cygrpc as _cygrpc
from grpc._errors import AbortError
from grpc._errors import BaseError
from grpc._errors import RpcError
from grpc._runtime_protos import protos from grpc._runtime_protos import protos
from grpc._runtime_protos import protos_and_services from grpc._runtime_protos import protos_and_services
from grpc._runtime_protos import services from grpc._runtime_protos import services
@ -307,13 +310,6 @@ class Status(abc.ABC):
""" """
############################# gRPC Exceptions ################################
class RpcError(Exception):
"""Raised by the gRPC library to indicate non-OK-status RPC termination."""
############################## Shared Context ################################ ############################## Shared Context ################################
@ -1241,8 +1237,8 @@ class ServicerContext(RpcContext, metaclass=abc.ABCMeta):
termination of the RPC. termination of the RPC.
Raises: Raises:
Exception: An exception is always raised to signal the abortion the AbortError: A grpc.AbortError is always raised to signal the abortion
RPC to the gRPC runtime. the RPC to the gRPC runtime.
""" """
raise NotImplementedError() raise NotImplementedError()
@ -1260,8 +1256,8 @@ class ServicerContext(RpcContext, metaclass=abc.ABCMeta):
StatusCode.OK. StatusCode.OK.
Raises: Raises:
Exception: An exception is always raised to signal the abortion the AbortError: A grpc.AbortError is always raised to signal the abortion
RPC to the gRPC runtime. the RPC to the gRPC runtime.
""" """
raise NotImplementedError() raise NotImplementedError()
@ -2273,6 +2269,8 @@ __all__ = (
"ServiceRpcHandler", "ServiceRpcHandler",
"Server", "Server",
"ServerInterceptor", "ServerInterceptor",
"AbortError",
"BaseError",
"unary_unary_rpc_method_handler", "unary_unary_rpc_method_handler",
"unary_stream_rpc_method_handler", "unary_stream_rpc_method_handler",
"stream_unary_rpc_method_handler", "stream_unary_rpc_method_handler",

@ -369,7 +369,9 @@ def _rpc_state_string(class_name: str, rpc_state: _RPCState) -> str:
) )
class _InactiveRpcError(grpc.RpcError, grpc.Call, grpc.Future): class _InactiveRpcError(
grpc.RpcError, grpc.Call, grpc.Future
): # pylint: disable=too-many-ancestors
"""An RPC error not tied to the execution of a particular RPC. """An RPC error not tied to the execution of a particular RPC.
The RPC represented by the state object must not be in-progress or The RPC represented by the state object must not be in-progress or

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from grpc._errors import InternalError
_EMPTY_FLAGS = 0 _EMPTY_FLAGS = 0
_EMPTY_MASK = 0 _EMPTY_MASK = 0

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from grpc._errors import InternalError
cdef class CallbackFailureHandler: cdef class CallbackFailureHandler:

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# #
from grpc._errors import UsageError
class _WatchConnectivityFailed(Exception): class _WatchConnectivityFailed(Exception):
"""Dedicated exception class for watch connectivity failed. """Dedicated exception class for watch connectivity failed.

@ -82,30 +82,6 @@ _COMPRESSION_METADATA_STRING_MAPPING = {
CompressionAlgorithm.gzip: 'gzip', CompressionAlgorithm.gzip: 'gzip',
} }
class BaseError(Exception):
"""The base class for exceptions generated by gRPC AsyncIO stack."""
class UsageError(BaseError):
"""Raised when the usage of API by applications is inappropriate.
For example, trying to invoke RPC on a closed channel, mixing two styles
of streaming API on the client side. This exception should not be
suppressed.
"""
class AbortError(BaseError):
"""Raised when calling abort in servicer methods.
This exception should not be suppressed. Applications may catch it to
perform certain clean-up logic, and then re-raise it.
"""
class InternalError(BaseError):
"""Raised upon unexpected errors in native code."""
def schedule_coro_threadsafe(object coro, object loop): def schedule_coro_threadsafe(object coro, object loop):
try: try:

@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
from grpc._errors import BaseError, AbortError, InternalError, UsageError
import inspect import inspect
import traceback import traceback
import functools import functools

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from grpc._errors import InternalError, UsageError
cdef class Server: cdef class Server:

@ -0,0 +1,57 @@
# Copyright 2024 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
############################# gRPC Exceptions ################################
class BaseError(Exception):
"""
The base class for exceptions generated by gRPC.
"""
class UsageError(BaseError):
"""
Raised when the usage of API by applications is inappropriate.
For example, trying to invoke RPC on a closed channel, mixing two styles
of streaming API on the client side. This exception should not be
suppressed.
"""
class AbortError(BaseError):
"""
Raised when calling abort in servicer methods.
This exception should not be suppressed. Applications may catch it to
perform certain clean-up logic, and then re-raise it.
"""
class InternalError(BaseError):
"""
Raised upon unexpected errors in native code.
"""
class RpcError(BaseError):
"""Raised by the gRPC library to indicate non-OK-status RPC termination."""
__all__ = (
"BaseError",
"UsageError",
"AbortError",
"InternalError",
"RpcError",
)

@ -42,6 +42,7 @@ from grpc import _common # pytype: disable=pyi-error
from grpc import _compression # pytype: disable=pyi-error from grpc import _compression # pytype: disable=pyi-error
from grpc import _interceptor # pytype: disable=pyi-error from grpc import _interceptor # pytype: disable=pyi-error
from grpc._cython import cygrpc from grpc._cython import cygrpc
from grpc._errors import AbortError
from grpc._typing import ArityAgnosticMethodHandler from grpc._typing import ArityAgnosticMethodHandler
from grpc._typing import ChannelArgumentType from grpc._typing import ChannelArgumentType
from grpc._typing import DeserializingFunction from grpc._typing import DeserializingFunction
@ -404,7 +405,7 @@ class _Context(grpc.ServicerContext):
self._state.code = code self._state.code = code
self._state.details = _common.encode(details) self._state.details = _common.encode(details)
self._state.aborted = True self._state.aborted = True
raise Exception() raise AbortError()
def abort_with_status(self, status: grpc.Status) -> None: def abort_with_status(self, status: grpc.Status) -> None:
self._state.trailing_metadata = status.trailing_metadata self._state.trailing_metadata = status.trailing_metadata
@ -557,6 +558,15 @@ def _call_behavior(
except Exception as exception: # pylint: disable=broad-except except Exception as exception: # pylint: disable=broad-except
with state.condition: with state.condition:
if state.aborted: if state.aborted:
if not isinstance(exception, AbortError):
try:
details = f"Exception happened while aborting: {exception}"
except Exception: # pylint: disable=broad-except
details = (
"Calling abort raised unprintable Exception!"
)
traceback.print_exc()
_LOGGER.exception(details)
_abort( _abort(
state, state,
rpc_event.call, rpc_event.call,

@ -20,13 +20,13 @@ created. AsyncIO doesn't provide thread safety for most of its APIs.
from typing import Any, Optional, Sequence, Tuple from typing import Any, Optional, Sequence, Tuple
import grpc import grpc
from grpc._cython.cygrpc import AbortError
from grpc._cython.cygrpc import BaseError
from grpc._cython.cygrpc import EOF from grpc._cython.cygrpc import EOF
from grpc._cython.cygrpc import InternalError
from grpc._cython.cygrpc import UsageError
from grpc._cython.cygrpc import init_grpc_aio from grpc._cython.cygrpc import init_grpc_aio
from grpc._cython.cygrpc import shutdown_grpc_aio from grpc._cython.cygrpc import shutdown_grpc_aio
from grpc._errors import AbortError
from grpc._errors import BaseError
from grpc._errors import InternalError
from grpc._errors import UsageError
from ._base_call import Call from ._base_call import Call
from ._base_call import RpcContext from ._base_call import RpcContext

@ -24,6 +24,8 @@ from typing import Any, AsyncIterator, Generator, Generic, Optional, Tuple
import grpc import grpc
from grpc import _common from grpc import _common
from grpc._cython import cygrpc from grpc._cython import cygrpc
from grpc._errors import InternalError
from grpc._errors import UsageError
from . import _base_call from . import _base_call
from ._metadata import Metadata from ._metadata import Metadata
@ -337,7 +339,7 @@ class _StreamResponseMixin(Call):
if self._response_style is _APIStyle.UNKNOWN: if self._response_style is _APIStyle.UNKNOWN:
self._response_style = style self._response_style = style
elif self._response_style is not style: elif self._response_style is not style:
raise cygrpc.UsageError(_API_STYLE_ERROR) raise UsageError(_API_STYLE_ERROR)
def cancel(self) -> bool: def cancel(self) -> bool:
if super().cancel(): if super().cancel():
@ -418,7 +420,7 @@ class _StreamRequestMixin(Call):
def _raise_for_different_style(self, style: _APIStyle): def _raise_for_different_style(self, style: _APIStyle):
if self._request_style is not style: if self._request_style is not style:
raise cygrpc.UsageError(_API_STYLE_ERROR) raise UsageError(_API_STYLE_ERROR)
def cancel(self) -> bool: def cancel(self) -> bool:
if super().cancel(): if super().cancel():
@ -490,7 +492,7 @@ class _StreamRequestMixin(Call):
) )
try: try:
await self._cython_call.send_serialized_message(serialized_request) await self._cython_call.send_serialized_message(serialized_request)
except cygrpc.InternalError as err: except InternalError as err:
self._cython_call.set_internal_error(str(err)) self._cython_call.set_internal_error(str(err))
await self._raise_for_status() await self._raise_for_status()
except asyncio.CancelledError: except asyncio.CancelledError:

@ -22,6 +22,7 @@ from grpc import _common
from grpc import _compression from grpc import _compression
from grpc import _grpcio_metadata from grpc import _grpcio_metadata
from grpc._cython import cygrpc from grpc._cython import cygrpc
from grpc._errors import InternalError
from . import _base_call from . import _base_call
from . import _base_channel from . import _base_channel
@ -431,7 +432,7 @@ class Channel(_base_channel.Channel):
continue continue
else: else:
# Unidentified Call object # Unidentified Call object
raise cygrpc.InternalError( raise InternalError(
f"Unrecognized call object: {candidate}" f"Unrecognized call object: {candidate}"
) )

@ -30,6 +30,7 @@ from typing import (
import grpc import grpc
from grpc._cython import cygrpc from grpc._cython import cygrpc
from grpc._errors import UsageError
from . import _base_call from . import _base_call
from ._call import AioRpcError from ._call import AioRpcError
@ -562,7 +563,7 @@ class _InterceptedStreamRequestMixin:
# should be expected through an iterators provided # should be expected through an iterators provided
# by the caller. # by the caller.
if self._write_to_iterator_queue is None: if self._write_to_iterator_queue is None:
raise cygrpc.UsageError(_API_STYLE_ERROR) raise UsageError(_API_STYLE_ERROR)
try: try:
call = await self._interceptors_task call = await self._interceptors_task
@ -588,7 +589,7 @@ class _InterceptedStreamRequestMixin:
# should be expected through an iterators provided # should be expected through an iterators provided
# by the caller. # by the caller.
if self._write_to_iterator_queue is None: if self._write_to_iterator_queue is None:
raise cygrpc.UsageError(_API_STYLE_ERROR) raise UsageError(_API_STYLE_ERROR)
try: try:
call = await self._interceptors_task call = await self._interceptors_task

@ -20,11 +20,13 @@ import unittest
import weakref import weakref
import grpc import grpc
from grpc import AbortError
from tests.unit import test_common from tests.unit import test_common
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
_ABORT = "/test/abort" _ABORT = "/test/abort"
_ABORT_WITH_SERVER_CODE = "/test/abortServerCode"
_ABORT_WITH_STATUS = "/test/AbortWithStatus" _ABORT_WITH_STATUS = "/test/AbortWithStatus"
_INVALID_CODE = "/test/InvalidCode" _INVALID_CODE = "/test/InvalidCode"
@ -58,6 +60,20 @@ def abort_unary_unary(request, servicer_context):
raise Exception("This line should not be executed!") raise Exception("This line should not be executed!")
def abort_unary_unary_with_server_error(request, servicer_context):
try:
servicer_context.abort(
grpc.StatusCode.INTERNAL,
_ABORT_DETAILS,
)
except AbortError as err:
servicer_context.abort(
grpc.StatusCode.INTERNAL,
str(type(err).__name__),
)
raise Exception("This line should not be executed!")
def abort_with_status_unary_unary(request, servicer_context): def abort_with_status_unary_unary(request, servicer_context):
servicer_context.abort_with_status( servicer_context.abort_with_status(
_Status( _Status(
@ -80,6 +96,10 @@ class _GenericHandler(grpc.GenericRpcHandler):
def service(self, handler_call_details): def service(self, handler_call_details):
if handler_call_details.method == _ABORT: if handler_call_details.method == _ABORT:
return grpc.unary_unary_rpc_method_handler(abort_unary_unary) return grpc.unary_unary_rpc_method_handler(abort_unary_unary)
elif handler_call_details.method == _ABORT_WITH_SERVER_CODE:
return grpc.unary_unary_rpc_method_handler(
abort_unary_unary_with_server_error
)
elif handler_call_details.method == _ABORT_WITH_STATUS: elif handler_call_details.method == _ABORT_WITH_STATUS:
return grpc.unary_unary_rpc_method_handler( return grpc.unary_unary_rpc_method_handler(
abort_with_status_unary_unary abort_with_status_unary_unary
@ -116,6 +136,14 @@ class AbortTest(unittest.TestCase):
self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
self.assertEqual(rpc_error.details(), _ABORT_DETAILS) self.assertEqual(rpc_error.details(), _ABORT_DETAILS)
def test_server_abort_code(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_ABORT_WITH_SERVER_CODE)(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
self.assertEqual(rpc_error.details(), str(AbortError.__name__))
# This test ensures that abort() does not store the raised exception, which # This test ensures that abort() does not store the raised exception, which
# on Python 3 (via the `__traceback__` attribute) holds a reference to # on Python 3 (via the `__traceback__` attribute) holds a reference to
# all local vars. Storing the raised exception can prevent GC and stop the # all local vars. Storing the raised exception can prevent GC and stop the

@ -59,6 +59,8 @@ class AllTest(unittest.TestCase):
"ServiceRpcHandler", "ServiceRpcHandler",
"Server", "Server",
"ServerInterceptor", "ServerInterceptor",
"AbortError",
"BaseError",
"LocalConnectionType", "LocalConnectionType",
"local_channel_credentials", "local_channel_credentials",
"local_server_credentials", "local_server_credentials",

@ -17,6 +17,7 @@ import logging
import unittest import unittest
import grpc import grpc
from grpc._errors import UsageError
from grpc.experimental import aio from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import messages_pb2
@ -544,10 +545,10 @@ class TestStreamUnaryClientInterceptor(AioTestBase):
call = stub.StreamingInputCall(request_iterator()) call = stub.StreamingInputCall(request_iterator())
with self.assertRaises(grpc._cython.cygrpc.UsageError): with self.assertRaises(UsageError):
await call.write(request) await call.write(request)
with self.assertRaises(grpc._cython.cygrpc.UsageError): with self.assertRaises(UsageError):
await call.done_writing() await call.done_writing()
await channel.close() await channel.close()

Loading…
Cancel
Save