Nextgen Proto Pythonic API: Timestamp/Duration assignment, creation and calculation

Timestamp and Duration are now have more support with datetime and timedelta:
- Allows assign python datetime to protobuf DateTime field in addition to current FromDatetime/ToDatetime (Note: will throw exceptions for the differences in supported ranges)
- Allows assign python timedelta to protobuf Duration field in addition to current FromTimedelta/ToTimedelta
- Calculation between Timestamp, Duration, datetime and timedelta will also be supported.

example usage:

from datetime import datetime, timedelta
from event_pb2 import Event
e = Event(start_time=datetime(year=2112, month=2, day=3),
          duration=timedelta(hours=10))
duration = timedelta(hours=10))
end_time = e.start_time + timedelta(hours=4)
e.duration = end_time - e.start_time
PiperOrigin-RevId: 640639168
pull/16987/head
Jie Luo 6 months ago committed by Copybara-Service
parent a450c9cad0
commit b690e729eb
  1. 10
      python/google/protobuf/internal/descriptor_pool_test.py
  2. 8
      python/google/protobuf/internal/more_messages.proto
  3. 43
      python/google/protobuf/internal/python_message.py
  4. 51
      python/google/protobuf/internal/well_known_types.py
  5. 236
      python/google/protobuf/internal/well_known_types_test.py
  6. 71
      python/google/protobuf/pyext/message.cc
  7. 57
      python/message.c

@ -29,6 +29,8 @@ from google.protobuf.internal import more_messages_pb2
from google.protobuf.internal import no_package_pb2
from google.protobuf.internal import testing_refleaks
from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
from google.protobuf import unittest_features_pb2
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_import_public_pb2
@ -435,6 +437,8 @@ class DescriptorPoolTestBase(object):
self.assertEqual(file2.name,
'google/protobuf/internal/factory_test2.proto')
self.testFindMessageTypeByName()
self.pool.AddSerializedFile(timestamp_pb2.DESCRIPTOR.serialized_pb)
self.pool.AddSerializedFile(duration_pb2.DESCRIPTOR.serialized_pb)
file_json = self.pool.AddSerializedFile(
more_messages_pb2.DESCRIPTOR.serialized_pb)
field = file_json.message_types_by_name['class'].fields_by_name['int_field']
@ -542,12 +546,18 @@ class DescriptorPoolTestBase(object):
# that uses a DescriptorDatabase.
# TODO: Fix python and cpp extension diff.
return
timestamp_desc = descriptor_pb2.FileDescriptorProto.FromString(
timestamp_pb2.DESCRIPTOR.serialized_pb)
duration_desc = descriptor_pb2.FileDescriptorProto.FromString(
duration_pb2.DESCRIPTOR.serialized_pb)
more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString(
more_messages_pb2.DESCRIPTOR.serialized_pb)
test1_desc = descriptor_pb2.FileDescriptorProto.FromString(
descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
test2_desc = descriptor_pb2.FileDescriptorProto.FromString(
descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb)
self.pool.Add(timestamp_desc)
self.pool.Add(duration_desc)
self.pool.Add(more_messages_desc)
self.pool.Add(test1_desc)
self.pool.Add(test2_desc)

@ -13,6 +13,9 @@ syntax = "proto2";
package google.protobuf.internal;
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
// A message where tag numbers are listed out of order, to allow us to test our
// canonicalization of serialized output, which should always be in tag order.
// We also mix in some extensions for extra fun.
@ -348,3 +351,8 @@ message ConflictJsonName {
optional int32 value = 1 [json_name = "old_value"];
optional int32 new_value = 2 [json_name = "value"];
}
message WKTMessage {
optional Timestamp optional_timestamp = 1;
optional Duration optional_duration = 2;
}

@ -27,6 +27,7 @@ this file*.
__author__ = 'robinson@google.com (Will Robinson)'
import datetime
from io import BytesIO
import struct
import sys
@ -536,13 +537,30 @@ def _AddInitMethod(message_descriptor, cls):
self._fields[field] = copy
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
copy = field._default_constructor(self)
new_val = field_value
if isinstance(field_value, dict):
new_val = None
if isinstance(field_value, message_mod.Message):
new_val = field_value
elif isinstance(field_value, dict):
new_val = field.message_type._concrete_class(**field_value)
try:
copy.MergeFrom(new_val)
except TypeError:
_ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
elif field.message_type.full_name == 'google.protobuf.Timestamp':
copy.FromDatetime(field_value)
elif field.message_type.full_name == 'google.protobuf.Duration':
copy.FromTimedelta(field_value)
else:
raise TypeError(
'Message field {0}.{1} must be initialized with a '
'dict or instance of same class, got {2}.'.format(
message_descriptor.name,
field_name,
type(field_value).__name__,
)
)
if new_val:
try:
copy.MergeFrom(new_val)
except TypeError:
_ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
self._fields[field] = copy
else:
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
@ -753,8 +771,17 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
# We define a setter just so we can throw an exception with a more
# helpful error message.
def setter(self, new_value):
raise AttributeError('Assignment not allowed to composite field '
'"%s" in protocol message object.' % proto_field_name)
if field.message_type.full_name == 'google.protobuf.Timestamp':
getter(self)
self._fields[field].FromDatetime(new_value)
elif field.message_type.full_name == 'google.protobuf.Duration':
getter(self)
self._fields[field].FromTimedelta(new_value)
else:
raise AttributeError(
'Assignment not allowed to composite field '
'"%s" in protocol message object.' % proto_field_name
)
# Add a property to encapsulate the getter.
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name

@ -21,8 +21,8 @@ import calendar
import collections.abc
import datetime
import warnings
from google.protobuf.internal import field_mask
from typing import Union
FieldMask = field_mask.FieldMask
@ -271,12 +271,35 @@ class Timestamp(object):
# manipulated into a long value of seconds. During the conversion from
# struct_time to long, the source date in UTC, and so it follows that the
# correct transformation is calendar.timegm()
seconds = calendar.timegm(dt.utctimetuple())
nanos = dt.microsecond * _NANOS_PER_MICROSECOND
try:
seconds = calendar.timegm(dt.utctimetuple())
nanos = dt.microsecond * _NANOS_PER_MICROSECOND
except AttributeError as e:
raise AttributeError(
'Fail to convert to Timestamp. Expected a datetime like '
'object got {0} : {1}'.format(type(dt).__name__, e)
) from e
_CheckTimestampValid(seconds, nanos)
self.seconds = seconds
self.nanos = nanos
def __add__(self, value) -> datetime.datetime:
if isinstance(value, Duration):
return self.ToDatetime() + value.ToTimedelta()
return self.ToDatetime() + value
__radd__ = __add__
def __sub__(self, value) -> Union[datetime.datetime, datetime.timedelta]:
if isinstance(value, Timestamp):
return self.ToDatetime() - value.ToDatetime()
elif isinstance(value, Duration):
return self.ToDatetime() - value.ToTimedelta()
return self.ToDatetime() - value
def __rsub__(self, dt) -> datetime.timedelta:
return dt - self.ToDatetime()
def _CheckTimestampValid(seconds, nanos):
if seconds < _TIMESTAMP_SECONDS_MIN or seconds > _TIMESTAMP_SECONDS_MAX:
@ -408,8 +431,16 @@ class Duration(object):
def FromTimedelta(self, td):
"""Converts timedelta to Duration."""
self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
td.microseconds * _NANOS_PER_MICROSECOND)
try:
self._NormalizeDuration(
td.seconds + td.days * _SECONDS_PER_DAY,
td.microseconds * _NANOS_PER_MICROSECOND,
)
except AttributeError as e:
raise AttributeError(
'Fail to convert to Duration. Expected a timedelta like '
'object got {0}: {1}'.format(type(td).__name__, e)
) from e
def _NormalizeDuration(self, seconds, nanos):
"""Set Duration by seconds and nanos."""
@ -420,6 +451,16 @@ class Duration(object):
self.seconds = seconds
self.nanos = nanos
def __add__(self, value) -> Union[datetime.datetime, datetime.timedelta]:
if isinstance(value, Timestamp):
return self.ToTimedelta() + value.ToDatetime()
return self.ToTimedelta() + value
__radd__ = __add__
def __rsub__(self, dt) -> Union[datetime.datetime, datetime.timedelta]:
return dt - self.ToTimedelta()
def _CheckDurationValid(seconds, nanos):
if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX:

@ -13,13 +13,15 @@ import collections.abc as collections_abc
import datetime
import unittest
from google.protobuf import any_pb2
from google.protobuf import text_format
from google.protobuf.internal import any_test_pb2
from google.protobuf.internal import more_messages_pb2
from google.protobuf.internal import well_known_types
from google.protobuf import any_pb2
from google.protobuf import duration_pb2
from google.protobuf import struct_pb2
from google.protobuf import timestamp_pb2
from google.protobuf.internal import well_known_types
from google.protobuf import text_format
from google.protobuf.internal import _parameterized
from google.protobuf import unittest_pb2
@ -351,6 +353,123 @@ class TimeUtilTest(TimeUtilTestBase):
tz_aware_min_datetime, ts.ToDatetime(datetime.timezone.utc)
)
# Two hours after the Unix Epoch, around the world.
@_parameterized.named_parameters(
('London', [1970, 1, 1, 2], datetime.timezone.utc),
('Tokyo', [1970, 1, 1, 11], _TZ_JAPAN),
('LA', [1969, 12, 31, 18], _TZ_PACIFIC),
)
def testTimestampAssignment(self, date_parts, tzinfo):
original_datetime = datetime.datetime(*date_parts, tzinfo=tzinfo) # pylint:disable=g-tzinfo-datetime
msg = more_messages_pb2.WKTMessage()
msg.optional_timestamp = original_datetime
self.assertEqual(7200, msg.optional_timestamp.seconds)
self.assertEqual(0, msg.optional_timestamp.nanos)
# Two hours after the Unix Epoch, around the world.
@_parameterized.named_parameters(
('London', [1970, 1, 1, 2], datetime.timezone.utc),
('Tokyo', [1970, 1, 1, 11], _TZ_JAPAN),
('LA', [1969, 12, 31, 18], _TZ_PACIFIC),
)
def testTimestampCreation(self, date_parts, tzinfo):
original_datetime = datetime.datetime(*date_parts, tzinfo=tzinfo) # pylint:disable=g-tzinfo-datetime
msg = more_messages_pb2.WKTMessage(optional_timestamp=original_datetime)
self.assertEqual(7200, msg.optional_timestamp.seconds)
self.assertEqual(0, msg.optional_timestamp.nanos)
msg2 = more_messages_pb2.WKTMessage(
optional_timestamp=msg.optional_timestamp
)
self.assertEqual(7200, msg2.optional_timestamp.seconds)
self.assertEqual(0, msg2.optional_timestamp.nanos)
@_parameterized.named_parameters(
(
'tz_aware_min_dt',
datetime.datetime(1, 1, 1, tzinfo=datetime.timezone.utc),
datetime.timedelta(hours=9),
-62135564400,
0,
),
(
'no_change',
datetime.datetime(1970, 1, 1, 11, tzinfo=_TZ_JAPAN),
datetime.timedelta(hours=0),
7200,
0,
),
)
def testTimestampAdd(self, old_time, time_delta, expected_sec, expected_nano):
msg = more_messages_pb2.WKTMessage()
msg.optional_timestamp = old_time
# Timestamp + timedelta
new_msg1 = more_messages_pb2.WKTMessage()
new_msg1.optional_timestamp = msg.optional_timestamp + time_delta
self.assertEqual(expected_sec, new_msg1.optional_timestamp.seconds)
self.assertEqual(expected_nano, new_msg1.optional_timestamp.nanos)
# timedelta + Timestamp
new_msg2 = more_messages_pb2.WKTMessage()
new_msg2.optional_timestamp = time_delta + msg.optional_timestamp
self.assertEqual(expected_sec, new_msg2.optional_timestamp.seconds)
self.assertEqual(expected_nano, new_msg2.optional_timestamp.nanos)
# Timestamp + Duration
msg.optional_duration.FromTimedelta(time_delta)
new_msg3 = more_messages_pb2.WKTMessage()
new_msg3.optional_timestamp = msg.optional_timestamp + msg.optional_duration
self.assertEqual(expected_sec, new_msg3.optional_timestamp.seconds)
self.assertEqual(expected_nano, new_msg3.optional_timestamp.nanos)
@_parameterized.named_parameters(
(
'test1',
datetime.datetime(999, 1, 1, tzinfo=datetime.timezone.utc),
datetime.timedelta(hours=9),
-30641792400,
0,
),
(
'no_change',
datetime.datetime(1970, 1, 1, 11, tzinfo=_TZ_JAPAN),
datetime.timedelta(hours=0),
7200,
0,
),
)
def testTimestampSub(self, old_time, time_delta, expected_sec, expected_nano):
msg = more_messages_pb2.WKTMessage()
msg.optional_timestamp = old_time
# Timestamp - timedelta
new_msg1 = more_messages_pb2.WKTMessage()
new_msg1.optional_timestamp = msg.optional_timestamp - time_delta
self.assertEqual(expected_sec, new_msg1.optional_timestamp.seconds)
self.assertEqual(expected_nano, new_msg1.optional_timestamp.nanos)
# Timestamp - Duration
msg.optional_duration = time_delta
new_msg2 = more_messages_pb2.WKTMessage()
new_msg2.optional_timestamp = msg.optional_timestamp - msg.optional_duration
self.assertEqual(expected_sec, new_msg2.optional_timestamp.seconds)
self.assertEqual(expected_nano, new_msg2.optional_timestamp.nanos)
result_msg = more_messages_pb2.WKTMessage()
result_msg.optional_timestamp = old_time - time_delta
# Timestamp - Timestamp
td = msg.optional_timestamp - result_msg.optional_timestamp
self.assertEqual(time_delta, td)
# Timestamp - datetime
td1 = msg.optional_timestamp - result_msg.optional_timestamp.ToDatetime()
self.assertEqual(time_delta, td1)
# datetime - Timestamp
td2 = msg.optional_timestamp.ToDatetime() - result_msg.optional_timestamp
self.assertEqual(time_delta, td2)
def testNanosOneSecond(self):
tz = _TZ_PACIFIC
ts = timestamp_pb2.Timestamp(nanos=1_000_000_000)
@ -413,6 +532,18 @@ class TimeUtilTest(TimeUtilTestBase):
message.ToJsonString)
self.assertRaisesRegex(ValueError, 'Timestamp is not valid',
message.FromSeconds, -62135596801)
msg = more_messages_pb2.WKTMessage()
with self.assertRaises(AttributeError):
msg.optional_timestamp = 1
with self.assertRaises(AttributeError):
msg2 = more_messages_pb2.WKTMessage(optional_timestamp=1)
with self.assertRaises(TypeError):
msg.optional_timestamp + ''
with self.assertRaises(TypeError):
msg.optional_timestamp - 123
def testInvalidDuration(self):
message = duration_pb2.Duration()
@ -446,6 +577,105 @@ class TimeUtilTest(TimeUtilTestBase):
self.assertRaisesRegex(ValueError,
r'Duration is not valid\: Sign mismatch.',
message.ToJsonString)
msg = more_messages_pb2.WKTMessage()
with self.assertRaises(AttributeError):
msg.optional_duration = 1
with self.assertRaises(AttributeError):
msg2 = more_messages_pb2.WKTMessage(optional_duration=1)
with self.assertRaises(TypeError):
msg.optional_duration + ''
with self.assertRaises(TypeError):
123 - msg.optional_duration
@_parameterized.named_parameters(
('test1', -1999999, -1, -999999000), ('test2', 1999999, 1, 999999000)
)
def testDurationAssignment(self, microseconds, expected_sec, expected_nano):
message = more_messages_pb2.WKTMessage()
expected_td = datetime.timedelta(microseconds=microseconds)
message.optional_duration = expected_td
self.assertEqual(expected_td, message.optional_duration.ToTimedelta())
self.assertEqual(expected_sec, message.optional_duration.seconds)
self.assertEqual(expected_nano, message.optional_duration.nanos)
@_parameterized.named_parameters(
('test1', -1999999, -1, -999999000), ('test2', 1999999, 1, 999999000)
)
def testDurationCreation(self, microseconds, expected_sec, expected_nano):
message = more_messages_pb2.WKTMessage(
optional_duration=datetime.timedelta(microseconds=microseconds)
)
expected_td = datetime.timedelta(microseconds=microseconds)
self.assertEqual(expected_td, message.optional_duration.ToTimedelta())
self.assertEqual(expected_sec, message.optional_duration.seconds)
self.assertEqual(expected_nano, message.optional_duration.nanos)
@_parameterized.named_parameters(
(
'tz_aware_min_dt',
datetime.datetime(1, 1, 1, tzinfo=datetime.timezone.utc),
datetime.timedelta(hours=9),
-62135564400,
0,
),
(
'no_change',
datetime.datetime(1970, 1, 1, 11, tzinfo=_TZ_JAPAN),
datetime.timedelta(hours=0),
7200,
0,
),
)
def testDurationAdd(self, old_time, time_delta, expected_sec, expected_nano):
msg = more_messages_pb2.WKTMessage()
msg.optional_duration = time_delta
msg.optional_timestamp = old_time
# Duration + datetime
msg1 = more_messages_pb2.WKTMessage()
msg1.optional_timestamp = msg.optional_duration + old_time
self.assertEqual(expected_sec, msg1.optional_timestamp.seconds)
self.assertEqual(expected_nano, msg1.optional_timestamp.nanos)
# datetime + Duration
msg2 = more_messages_pb2.WKTMessage()
msg2.optional_timestamp = old_time + msg.optional_duration
self.assertEqual(expected_sec, msg2.optional_timestamp.seconds)
self.assertEqual(expected_nano, msg2.optional_timestamp.nanos)
# Duration + Timestamp
msg3 = more_messages_pb2.WKTMessage()
msg3.optional_timestamp = msg.optional_duration + msg.optional_timestamp
self.assertEqual(expected_sec, msg3.optional_timestamp.seconds)
self.assertEqual(expected_nano, msg3.optional_timestamp.nanos)
@_parameterized.named_parameters(
(
'test1',
datetime.datetime(999, 1, 1, tzinfo=datetime.timezone.utc),
datetime.timedelta(hours=9),
-30641792400,
0,
),
(
'no_change',
datetime.datetime(1970, 1, 1, 11, tzinfo=_TZ_JAPAN),
datetime.timedelta(hours=0),
7200,
0,
),
)
def testDurationSub(self, old_time, time_delta, expected_sec, expected_nano):
msg = more_messages_pb2.WKTMessage()
msg.optional_duration = time_delta
# datetime - Duration
msg.optional_timestamp = old_time - msg.optional_duration
self.assertEqual(expected_sec, msg.optional_timestamp.seconds)
self.assertEqual(expected_nano, msg.optional_timestamp.nanos)
class StructTest(unittest.TestCase):

@ -1092,9 +1092,41 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
return -1;
}
} else {
ScopedPyObjectPtr merged(MergeFrom(cmessage, value));
if (merged == nullptr) {
return -1;
if (PyObject_TypeCheck(value, CMessage_Type)) {
ScopedPyObjectPtr merged(MergeFrom(cmessage, value));
if (merged == nullptr) {
return -1;
}
} else {
switch (descriptor->message_type()->well_known_type()) {
case Descriptor::WELLKNOWNTYPE_TIMESTAMP: {
AssureWritable(cmessage);
ScopedPyObjectPtr ok(
PyObject_CallMethod(reinterpret_cast<PyObject*>(cmessage),
"FromDatetime", "O", value));
if (ok.get() == nullptr) {
return -1;
}
break;
}
case Descriptor::WELLKNOWNTYPE_DURATION: {
AssureWritable(cmessage);
ScopedPyObjectPtr ok(
PyObject_CallMethod(reinterpret_cast<PyObject*>(cmessage),
"FromTimedelta", "O", value));
if (ok.get() == nullptr) {
return -1;
}
break;
}
default:
PyErr_Format(
PyExc_TypeError,
"Parameter to initialize message field must be "
"dict or instance of same class: expected %s got %s.",
descriptor->full_name().c_str(), Py_TYPE(value)->tp_name);
return -1;
}
}
}
} else {
@ -2561,11 +2593,34 @@ int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor,
field_descriptor->name().c_str());
return -1;
} else if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
PyErr_Format(PyExc_AttributeError,
"Assignment not allowed to "
"field \"%s\" in protocol message object.",
field_descriptor->name().c_str());
return -1;
switch (field_descriptor->message_type()->well_known_type()) {
case Descriptor::WELLKNOWNTYPE_TIMESTAMP: {
AssureWritable(self);
PyObject* sub_message = GetFieldValue(self, field_descriptor);
ScopedPyObjectPtr ok(
PyObject_CallMethod(sub_message, "FromDatetime", "O", value));
if (ok.get() == nullptr) {
return -1;
}
return 0;
}
case Descriptor::WELLKNOWNTYPE_DURATION: {
AssureWritable(self);
PyObject* sub_message = GetFieldValue(self, field_descriptor);
ScopedPyObjectPtr ok(
PyObject_CallMethod(sub_message, "FromTimedelta", "O", value));
if (ok.get() == nullptr) {
return -1;
}
return 0;
}
default:
PyErr_Format(PyExc_AttributeError,
"Assignment not allowed to "
"field \"%s\" in protocol message object.",
field_descriptor->name().c_str());
return -1;
}
} else {
AssureWritable(self);
return InternalSetScalar(self, field_descriptor, value);

@ -433,6 +433,7 @@ err:
}
static bool PyUpb_Message_InitMessageAttribute(PyObject* _self, PyObject* name,
const upb_FieldDef* field,
PyObject* value) {
PyObject* submsg = PyUpb_Message_GetAttr(_self, name);
if (!submsg) return -1;
@ -446,10 +447,24 @@ static bool PyUpb_Message_InitMessageAttribute(PyObject* _self, PyObject* name,
assert(!PyErr_Occurred());
ok = PyUpb_Message_InitAttributes(submsg, NULL, value) >= 0;
} else {
const upb_MessageDef* m = PyUpb_Message_GetMsgdef(_self);
PyErr_Format(PyExc_TypeError, "Message must be initialized with a dict: %s",
upb_MessageDef_FullName(m));
ok = false;
const upb_MessageDef* msgdef = upb_FieldDef_MessageSubDef(field);
switch (upb_MessageDef_WellKnownType(msgdef)) {
case kUpb_WellKnown_Timestamp: {
ok = PyObject_CallMethod(submsg, "FromDatetime", "O", value);
break;
}
case kUpb_WellKnown_Duration: {
ok = PyObject_CallMethod(submsg, "FromTimedelta", "O", value);
break;
}
default: {
const upb_MessageDef* m = PyUpb_Message_GetMsgdef(_self);
PyErr_Format(PyExc_TypeError,
"Message must be initialized with a dict: %s",
upb_MessageDef_FullName(m));
ok = false;
}
}
}
Py_DECREF(submsg);
return ok;
@ -502,7 +517,7 @@ int PyUpb_Message_InitAttributes(PyObject* _self, PyObject* args,
} else if (upb_FieldDef_IsRepeated(f)) {
if (!PyUpb_Message_InitRepeatedAttribute(_self, name, value)) return -1;
} else if (upb_FieldDef_IsSubMessage(f)) {
if (!PyUpb_Message_InitMessageAttribute(_self, name, value)) return -1;
if (!PyUpb_Message_InitMessageAttribute(_self, name, f, value)) return -1;
} else {
if (!PyUpb_Message_InitScalarAttribute(msg, f, value, arena)) return -1;
}
@ -935,9 +950,9 @@ int PyUpb_Message_SetFieldValue(PyObject* _self, const upb_FieldDef* field,
PyUpb_Message* self = (void*)_self;
assert(value);
if (upb_FieldDef_IsSubMessage(field) || upb_FieldDef_IsRepeated(field)) {
if (upb_FieldDef_IsRepeated(field)) {
PyErr_Format(exc,
"Assignment not allowed to message, map, or repeated "
"Assignment not allowed to map, or repeated "
"field \"%s\" in protocol message object.",
upb_FieldDef_Name(field));
return -1;
@ -945,6 +960,34 @@ int PyUpb_Message_SetFieldValue(PyObject* _self, const upb_FieldDef* field,
PyUpb_Message_EnsureReified(self);
if (upb_FieldDef_IsSubMessage(field)) {
const upb_MessageDef* msgdef = upb_FieldDef_MessageSubDef(field);
switch (upb_MessageDef_WellKnownType(msgdef)) {
case kUpb_WellKnown_Timestamp: {
PyObject* sub_message = PyUpb_Message_GetFieldValue(_self, field);
PyObject* ok =
PyObject_CallMethod(sub_message, "FromDatetime", "O", value);
if (!ok) return -1;
Py_DECREF(ok);
return 0;
}
case kUpb_WellKnown_Duration: {
PyObject* sub_message = PyUpb_Message_GetFieldValue(_self, field);
PyObject* ok =
PyObject_CallMethod(sub_message, "FromTimedelta", "O", value);
if (!ok) return -1;
Py_DECREF(ok);
return 0;
}
default:
PyErr_Format(exc,
"Assignment not allowed to message "
"field \"%s\" in protocol message object.",
upb_FieldDef_Name(field));
return -1;
}
}
upb_MessageValue val;
upb_Arena* arena = PyUpb_Arena_Get(self->arena);
if (!PyUpb_PyToUpb(value, field, &val, arena)) {

Loading…
Cancel
Save