diff --git a/.github/workflows/test_upb.yml b/.github/workflows/test_upb.yml index 4357b11507..1812336470 100644 --- a/.github/workflows/test_upb.yml +++ b/.github/workflows/test_upb.yml @@ -146,7 +146,7 @@ jobs: which python3 && mv `which python3` /tmp && ! which python3 && - bazel test $BAZEL_FLAGS --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 //python/... -- -//python/dist:source_wheel + bazel test $BAZEL_FLAGS --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 //python/... -- -//python/dist:source_wheel -//python:aarch64_test -//python:x86_64_test -//python:google/protobuf/pyext/_message.so -//python:proto_api build_wheels: name: Build Wheels diff --git a/python/build_targets.bzl b/python/build_targets.bzl index ee765ab4c6..49d642eafb 100644 --- a/python/build_targets.bzl +++ b/python/build_targets.bzl @@ -433,9 +433,14 @@ def build_targets(name): native.cc_library( name = "proto_api", + srcs = ["google/protobuf/proto_api.cc"], hdrs = ["google/protobuf/proto_api.h"], + strip_include_prefix = "/python", visibility = ["//visibility:public"], deps = [ + "//src/google/protobuf", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", "@system_python//:python_headers", ], ) diff --git a/python/google/protobuf/proto_api.cc b/python/google/protobuf/proto_api.cc new file mode 100644 index 0000000000..50277a3f02 --- /dev/null +++ b/python/google/protobuf/proto_api.cc @@ -0,0 +1,57 @@ +#include "google/protobuf/proto_api.h" + +#include + +#include "absl/log/absl_check.h" +#include "google/protobuf/message.h" +namespace google { +namespace protobuf { +namespace python { + +PythonMessageMutator::PythonMessageMutator(Message* owned_msg, Message* message, + PyObject* py_msg) + : owned_msg_(owned_msg), message_(message), py_msg_(py_msg) { + ABSL_DCHECK(py_msg != nullptr); + ABSL_DCHECK(message != nullptr); + Py_INCREF(py_msg_); +} + +PythonMessageMutator::PythonMessageMutator(PythonMessageMutator&& other) + : owned_msg_(other.owned_msg_ == nullptr ? nullptr + : other.owned_msg_.release()), + message_(other.message_), + py_msg_(other.py_msg_) { + other.message_ = nullptr; + other.py_msg_ = nullptr; +} + +PythonMessageMutator::~PythonMessageMutator() { + if (py_msg_ == nullptr) { + return; + } + + // PyErr_Occurred check is required because PyObject_CallMethod need this + // check. + if (!PyErr_Occurred() && owned_msg_ != nullptr) { + std::string wire; + message_->SerializeToString(&wire); + PyObject* py_wire = PyBytes_FromStringAndSize( + wire.data(), static_cast(wire.size())); + PyObject* parse = + PyObject_CallMethod(py_msg_, "ParseFromString", "O", py_wire); + Py_DECREF(py_wire); + if (parse != nullptr) { + Py_DECREF(parse); + } + } + Py_DECREF(py_msg_); +} + +PythonMessageMutator PyProto_API::CreatePythonMessageMutator( + Message* owned_msg, Message* msg, PyObject* py_msg) const { + return PythonMessageMutator(owned_msg, msg, py_msg); +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/proto_api.h b/python/google/protobuf/proto_api.h index 558fff084f..60b7ba5809 100644 --- a/python/google/protobuf/proto_api.h +++ b/python/google/protobuf/proto_api.h @@ -22,9 +22,12 @@ #ifndef GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__ #define GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__ +#include +#include #define PY_SSIZE_T_CLEAN #include +#include "absl/status/status.h" #include "google/protobuf/descriptor_database.h" #include "google/protobuf/message.h" @@ -32,6 +35,8 @@ namespace google { namespace protobuf { namespace python { +class PythonMessageMutator; + // Note on the implementation: // This API is designed after // https://docs.python.org/3/extending/extending.html#providing-a-c-api-for-an-extension-module @@ -45,6 +50,19 @@ struct PyProto_API { // Operations on Messages. + // Returns a PythonMessageMutator which the python message has been cleared. + // This API works with UPB, Cpp Extension and Pure Python. + // Side-effect: The message will definitely be cleared. *When* the message + // gets cleared is undefined (C++ will clear it up-front, python/upb will + // clear it on destruction). Nothing should rely on the python message + // during the lifetime of this object + // User should not hold onto the returned PythonMessageMutator while + // calling back into Python + // Warning: there is a risk of deadlock with Python/C++ if users use the + // returned message->GetDescriptor()->file->pool() + virtual absl::StatusOr GetClearedMessageMutator( + PyObject* msg) const = 0; + // If the passed object is a Python Message, returns its internal pointer. // Otherwise, returns NULL with an exception set. virtual const Message* GetMessagePointer(PyObject* msg) const = 0; @@ -54,6 +72,9 @@ struct PyProto_API { // This function will succeed only if there are no other Python objects // pointing to the message, like submessages or repeated containers. // With the current implementation, only empty messages are in this case. + [[deprecated( + "GetMutableMessagePointer() only work with Cpp Extension, " + "please migrate to GetClearedMessageMutator().")]] virtual Message* GetMutableMessagePointer(PyObject* msg) const = 0; // If the passed object is a Python Message Descriptor, returns its internal @@ -107,6 +128,37 @@ struct PyProto_API { // can work and return their Python counterparts. virtual PyObject* DescriptorPool_FromPool( const google::protobuf::DescriptorPool* pool) const = 0; + + protected: + PythonMessageMutator CreatePythonMessageMutator(Message* owned_msg, + Message* msg, + PyObject* py_msg) const; +}; + +// User should not hold onto this object while calling back into Python +class PythonMessageMutator { + public: + PythonMessageMutator(PythonMessageMutator&& other); + ~PythonMessageMutator(); + + Message* get() { return message_; } + Message* operator->() { return message_; } + const Message& operator*() { return *message_; } + + private: + friend struct google::protobuf::python::PyProto_API; + PythonMessageMutator(Message* owned_msg, Message* message, PyObject* py_msg); + // owned_msg_ is set for UPB/Pure Python. Cpp + // Extension should not set owned_msg_. + // owned_msg_ is a new Message for UPB/Pure Python. + // owned_msg_ is nullptr for Cpp Extension. + std::unique_ptr owned_msg_; + // message_ points to owned_msg_ for UPB/Pure Python. + // message_ points to in-place Message* for Cpp Extension. + Message* message_; + // py_msg_ points to the python message. message_ content will be serialized + // to py_msg_ at destructor for UPB/Pure Python, CPP Extension won't. + PyObject* py_msg_; }; inline const char* PyProtoAPICapsuleName() { diff --git a/python/google/protobuf/pyext/message_module.cc b/python/google/protobuf/pyext/message_module.cc index c3beb8253d..9fc76acb40 100644 --- a/python/google/protobuf/pyext/message_module.cc +++ b/python/google/protobuf/pyext/message_module.cc @@ -8,17 +8,195 @@ #define PY_SSIZE_T_CLEAN #include +#include "google/protobuf/descriptor.pb.h" +#include "absl/log/absl_log.h" +#include "google/protobuf/dynamic_message.h" #include "google/protobuf/message_lite.h" +#include "google/protobuf/proto_api.h" #include "google/protobuf/pyext/descriptor.h" #include "google/protobuf/pyext/descriptor_pool.h" #include "google/protobuf/pyext/message.h" #include "google/protobuf/pyext/message_factory.h" -#include "google/protobuf/proto_api.h" +#include "google/protobuf/stubs/status_macros.h" + +// Must be included last. +#include "google/protobuf/port_def.inc" namespace { +class ProtoAPIDescriptorDatabase : public google::protobuf::DescriptorDatabase { + public: + ProtoAPIDescriptorDatabase() { + PyObject* descriptor_pool = + PyImport_ImportModule("google.protobuf.descriptor_pool"); + if (descriptor_pool == nullptr) { + ABSL_LOG(ERROR) + << "Failed to import google.protobuf.descriptor_pool module."; + } + + pool_ = PyObject_CallMethod(descriptor_pool, "Default", nullptr); + if (pool_ == nullptr) { + ABSL_LOG(ERROR) << "Failed to get python Default pool."; + } + Py_DECREF(descriptor_pool); + }; + + ~ProtoAPIDescriptorDatabase() { + // Objects of this class are meant to be `static`ally initialized and + // never destroyed. This is a commonly used approach, because the order + // in which destructors of static objects run is unpredictable. In + // particular, it is possible that the Python interpreter may have been + // finalized already. + ABSL_DLOG(ERROR) << "MEANT TO BE UNREACHABLE."; + }; + + bool FindFileByName(const std::string& filename, + google::protobuf::FileDescriptorProto* output) override { + PyObject* pyfile_name = + PyUnicode_FromStringAndSize(filename.data(), filename.size()); + if (pyfile_name == nullptr) { + PyErr_Format(PyExc_TypeError, "Fail to convert proto file name"); + return false; + } + + PyObject* pyfile = + PyObject_CallMethod(pool_, "FindFileByName", "O", pyfile_name); + Py_DECREF(pyfile_name); + if (pyfile == nullptr) { + PyErr_Format(PyExc_TypeError, "Default python pool fail to find %s", + filename.data()); + return false; + } + + PyObject* pyfile_serialized = + PyObject_GetAttrString(pyfile, "serialized_pb"); + Py_DECREF(pyfile); + if (pyfile_serialized == nullptr) { + PyErr_Format(PyExc_TypeError, + "Python file has no attribute 'serialized_pb'"); + return false; + } + + bool ok = output->ParseFromArray( + reinterpret_cast(PyBytes_AS_STRING(pyfile_serialized)), + PyBytes_GET_SIZE(pyfile_serialized)); + if (!ok) { + ABSL_LOG(ERROR) << "Failed to parse descriptor for " << filename; + } + Py_DECREF(pyfile_serialized); + return ok; + } + + bool FindFileContainingSymbol(const std::string& symbol_name, + google::protobuf::FileDescriptorProto* output) override { + return false; + } + + bool FindFileContainingExtension( + const std::string& containing_type, int field_number, + google::protobuf::FileDescriptorProto* output) override { + return false; + } + + PyObject* pool() { return pool_; } + + private: + PyObject* pool_; +}; + +absl::StatusOr FindMessageDescriptor( + PyObject* pyfile, const char* descriptor_full_name) { + static auto* database = new ProtoAPIDescriptorDatabase(); + static auto* pool = new google::protobuf::DescriptorPool(database); + PyObject* pyfile_name = PyObject_GetAttrString(pyfile, "name"); + if (pyfile_name == nullptr) { + return absl::InvalidArgumentError("FileDescriptor has no attribute 'name'"); + } + PyObject* pyfile_pool = PyObject_GetAttrString(pyfile, "pool"); + if (pyfile_pool == nullptr) { + Py_DECREF(pyfile_name); + return absl::InvalidArgumentError("FileDescriptor has no attribute 'pool'"); + } + // Check the file descriptor is from generated pool. + bool is_from_generated_pool = database->pool() == pyfile_pool; + Py_DECREF(pyfile_pool); + const char* pyfile_name_char_ptr = PyUnicode_AsUTF8(pyfile_name); + if (pyfile_name_char_ptr == nullptr) { + Py_DECREF(pyfile_name); + return absl::InvalidArgumentError( + "FileDescriptor 'name' PyUnicode_AsUTF8() failure."); + } + if (!is_from_generated_pool) { + std::string error_msg = pyfile_name_char_ptr; + error_msg += " is not from generated pool"; + Py_DECREF(pyfile_name); + return absl::InvalidArgumentError(error_msg); + } + const google::protobuf::FileDescriptor* file_descriptor = + pool->FindFileByName(pyfile_name_char_ptr); + Py_DECREF(pyfile_name); + if (file_descriptor == nullptr) { + // Already checked the file is from generated pool above, this + // error should never be reached. + ABSL_DLOG(ERROR) << "MEANT TO BE UNREACHABLE."; + std::string error_msg = "Fail to find/build file "; + error_msg += pyfile_name_char_ptr; + return absl::InternalError(error_msg); + } + + const google::protobuf::Descriptor* descriptor = + pool->FindMessageTypeByName(descriptor_full_name); + if (descriptor == nullptr) { + return absl::InternalError("Fail to find descriptor by name."); + } + return descriptor; +} + +google::protobuf::DynamicMessageFactory* GetFactory() { + static google::protobuf::DynamicMessageFactory* factory = + new google::protobuf::DynamicMessageFactory; + return factory; +} + // C++ API. Clients get at this via proto_api.h struct ApiImplementation : google::protobuf::python::PyProto_API { + absl::StatusOr GetClearedMessageMutator( + PyObject* py_msg) const override { + if (PyObject_TypeCheck(py_msg, google::protobuf::python::CMessage_Type)) { + google::protobuf::Message* message = + google::protobuf::python::PyMessage_GetMutableMessagePointer(py_msg); + message->Clear(); + return CreatePythonMessageMutator(nullptr, message, py_msg); + } + PyObject* pyd = PyObject_GetAttrString(py_msg, "DESCRIPTOR"); + if (pyd == nullptr) { + return absl::InvalidArgumentError("py_msg has no attribute 'DESCRIPTOR'"); + } + + PyObject* fn = PyObject_GetAttrString(pyd, "full_name"); + if (fn == nullptr) { + return absl::InvalidArgumentError( + "DESCRIPTOR has no attribute 'full_name'"); + } + + const char* descriptor_full_name = PyUnicode_AsUTF8(fn); + if (descriptor_full_name == nullptr) { + return absl::InternalError("Fail to convert descriptor full name"); + } + + PyObject* pyfile = PyObject_GetAttrString(pyd, "file"); + Py_DECREF(pyd); + if (pyfile == nullptr) { + return absl::InvalidArgumentError("DESCRIPTOR has no attribute 'file'"); + } + auto d = FindMessageDescriptor(pyfile, descriptor_full_name); + Py_DECREF(pyfile); + RETURN_IF_ERROR(d.status()); + Py_DECREF(fn); + google::protobuf::Message* msg = GetFactory()->GetPrototype(*d)->New(); + return CreatePythonMessageMutator(msg, msg, py_msg); + } + const google::protobuf::Message* GetMessagePointer(PyObject* msg) const override { return google::protobuf::python::PyMessage_GetMessagePointer(msg); }