Skip to content

Commit e2fd7a0

Browse files
Add [clear_and_]parse_dont_enforce_required() to Rust protobuf.
PiperOrigin-RevId: 755791325
1 parent 69eab2b commit e2fd7a0

File tree

4 files changed

+74
-13
lines changed

4 files changed

+74
-13
lines changed

rust/codegen_traits.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ pub(crate) mod create {
7373
use super::SealedInternal;
7474
pub trait Parse: SealedInternal + Sized {
7575
fn parse(serialized: &[u8]) -> Result<Self, crate::ParseError>;
76+
fn parse_dont_enforce_required(serialized: &[u8]) -> Result<Self, crate::ParseError>;
7677
}
7778
}
7879

@@ -100,6 +101,10 @@ pub(crate) mod write {
100101

101102
pub trait ClearAndParse: SealedInternal {
102103
fn clear_and_parse(&mut self, data: &[u8]) -> Result<(), crate::ParseError>;
104+
fn clear_and_parse_dont_enforce_required(
105+
&mut self,
106+
data: &[u8],
107+
) -> Result<(), crate::ParseError>;
103108
}
104109

105110
/// Copies the contents from `src` into `self`.

rust/test/shared/serialization_test.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use protobuf::View;
1212
use paste::paste;
1313
use unittest_proto3_optional_rust_proto::TestProto3Optional;
1414
use unittest_proto3_rust_proto::TestAllTypes as TestAllTypesProto3;
15-
use unittest_rust_proto::TestAllTypes;
15+
use unittest_rust_proto::{TestAllTypes, TestRequired};
1616

1717
macro_rules! generate_parameterized_serialization_test {
1818
($(($type: ident, $name_ext: ident)),*) => {
@@ -147,3 +147,24 @@ generate_parameterized_int32_byte_size_test!(
147147
* presence" semantics and setting it to 0 (default
148148
* value) will cause it to not be serialized */
149149
);
150+
151+
#[gtest]
152+
fn test_required_field_enforced() {
153+
// Empty bytes slice is a valid binaryproto with no fields set -- therefore it should not parse
154+
// as a message with required fields.
155+
expect_that!(TestRequired::parse(&[]), err(anything()));
156+
157+
let mut msg = TestRequired::new();
158+
expect_that!(msg.clear_and_parse(&[]), err(anything()));
159+
}
160+
161+
#[gtest]
162+
fn test_required_field_not_enforced() {
163+
// Empty bytes slice is a valid binaryproto with no fields set.
164+
let mut msg = TestRequired::parse_dont_enforce_required(&[]).unwrap();
165+
expect_that!(msg.has_a(), eq(false));
166+
167+
msg.set_a(1);
168+
msg.clear_and_parse_dont_enforce_required(&[]).unwrap();
169+
expect_that!(msg.has_a(), eq(false));
170+
}

rust/upb/wire.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ pub unsafe fn decode_with_options(
111111
// - `mini_table` is the one associated with `msg`
112112
// - `buf` is legally readable for at least `buf_size` bytes.
113113
// - `extreg` is null.
114-
// - `decode_options` is a valid DecodeOptions, so contains only allowed bits.
114+
// - `decode_options_bitmask` is a bitmask of constants from the `decode_options` module.
115115
let status = unsafe {
116116
upb_Decode(
117117
buf,

src/google/protobuf/compiler/rust/message.cc

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,14 @@ void MessageMutClear(Context& ctx, const Descriptor& msg) {
111111
}
112112
}
113113

114-
void MessageMutClearAndParse(Context& ctx, const Descriptor& msg) {
114+
void MessageMutClearAndParse(Context& ctx, const Descriptor& msg,
115+
bool enforce_required) {
115116
switch (ctx.opts().kernel) {
116-
case Kernel::kCpp:
117-
ctx.Emit({},
117+
case Kernel::kCpp: {
118+
absl::string_view parse_function =
119+
enforce_required ? "proto2_rust_Message_parse"
120+
: "proto2_rust_Message_parse_dont_enforce_required";
121+
ctx.Emit({{"parse_function", parse_function}},
118122
R"rs(
119123
let success = unsafe {
120124
// SAFETY: `data.as_ptr()` is valid to read for `data.len()`.
@@ -123,34 +127,41 @@ void MessageMutClearAndParse(Context& ctx, const Descriptor& msg) {
123127
data.len(),
124128
);
125129
126-
$pbr$::proto2_rust_Message_parse(self.raw_msg(), data)
130+
$pbr$::$parse_function$(self.raw_msg(), data)
127131
};
128132
success.then_some(()).ok_or($pb$::ParseError)
129133
)rs");
130134
return;
131-
132-
case Kernel::kUpb:
133-
ctx.Emit(
134-
R"rs(
135+
}
136+
137+
case Kernel::kUpb: {
138+
absl::string_view decode_options =
139+
enforce_required ? "$pbr$::wire::decode_options::CHECK_REQUIRED"
140+
: "0";
141+
ctx.Emit({{"decode_options",
142+
[&ctx, decode_options] { ctx.Emit(decode_options); }}},
143+
R"rs(
135144
$pb$::Clear::clear(self);
136145
137146
// SAFETY:
138147
// - `data.as_ptr()` is valid to read for `data.len()`
139148
// - `mini_table` is the one used to construct `msg.raw_msg()`
140149
// - `msg.arena().raw()` is held for the same lifetime as `msg`.
141150
let status = unsafe {
142-
$pbr$::wire::decode(
151+
$pbr$::wire::decode_with_options(
143152
data,
144153
self.raw_msg(),
145154
<Self as $pbr$::AssociatedMiniTable>::mini_table(),
146-
self.arena())
155+
self.arena(),
156+
$decode_options$)
147157
};
148158
match status {
149159
Ok(_) => Ok(()),
150160
Err(_) => Err($pb$::ParseError),
151161
}
152162
)rs");
153163
return;
164+
}
154165
}
155166

156167
ABSL_LOG(FATAL) << "unreachable";
@@ -700,7 +711,13 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
700711
{"Msg::serialize", [&] { MessageSerialize(ctx, msg); }},
701712
{"MsgMut::clear", [&] { MessageMutClear(ctx, msg); }},
702713
{"MsgMut::clear_and_parse",
703-
[&] { MessageMutClearAndParse(ctx, msg); }},
714+
[&] {
715+
MessageMutClearAndParse(ctx, msg, /*enforce_required=*/true);
716+
}},
717+
{"MsgMut::clear_and_parse_dont_enforce_required",
718+
[&] {
719+
MessageMutClearAndParse(ctx, msg, /*enforce_required=*/false);
720+
}},
704721
{"Msg::drop", [&] { MessageDrop(ctx, msg); }},
705722
{"Msg::debug", [&] { MessageDebug(ctx, msg); }},
706723
{"MsgMut::take_copy_merge_from",
@@ -830,6 +847,10 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
830847
fn parse(serialized: &[u8]) -> $Result$<Self, $pb$::ParseError> {
831848
Self::parse(serialized)
832849
}
850+
851+
fn parse_dont_enforce_required(serialized: &[u8]) -> $Result$<Self, $pb$::ParseError> {
852+
Self::parse_dont_enforce_required(serialized)
853+
}
833854
}
834855
835856
impl $std$::fmt::Debug for $Msg$ {
@@ -877,6 +898,11 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
877898
let mut m = self.as_mut();
878899
$pb$::ClearAndParse::clear_and_parse(&mut m, data)
879900
}
901+
902+
fn clear_and_parse_dont_enforce_required(&mut self, data: &[u8]) -> $Result$<(), $pb$::ParseError> {
903+
let mut m = self.as_mut();
904+
$pb$::ClearAndParse::clear_and_parse_dont_enforce_required(&mut m, data)
905+
}
880906
}
881907
882908
// SAFETY:
@@ -1014,6 +1040,10 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
10141040
fn clear_and_parse(&mut self, data: &[u8]) -> $Result$<(), $pb$::ParseError> {
10151041
$MsgMut::clear_and_parse$
10161042
}
1043+
1044+
fn clear_and_parse_dont_enforce_required(&mut self, data: &[u8]) -> $Result$<(), $pb$::ParseError> {
1045+
$MsgMut::clear_and_parse_dont_enforce_required$
1046+
}
10171047
}
10181048
10191049
$MsgMut::take_copy_merge_from$
@@ -1116,6 +1146,11 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
11161146
$pb$::ClearAndParse::clear_and_parse(&mut msg, data).map(|_| msg)
11171147
}
11181148
1149+
pub fn parse_dont_enforce_required(data: &[u8]) -> $Result$<Self, $pb$::ParseError> {
1150+
let mut msg = Self::new();
1151+
$pb$::ClearAndParse::clear_and_parse_dont_enforce_required(&mut msg, data).map(|_| msg)
1152+
}
1153+
11191154
pub fn as_view(&self) -> $Msg$View {
11201155
$Msg$View::new($pbi$::Private, self.inner.msg)
11211156
}

0 commit comments

Comments
 (0)