Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add streaming support for Wire request bodies #4267

Merged
merged 1 commit into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

**New**

- Nothing yet!
- First-party converters now support deferring serialization to happen when the request body is written (i.e., during HTTP execution) rather than when the HTTP request is created. In some cases this moves conversion from a calling thread to a background thread, such as in the case when using `Call.enqueue` directly.

The following converters support this feature through a new `createStreaming()` factory:
- Wire

**Changed**

Expand Down
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,4 @@ robovm = { module = "com.mobidevelop.robovm:robovm-rt", version.ref = "robovm" }
googleJavaFormat = "com.google.googlejavaformat:google-java-format:1.25.0"
ktlint = "com.pinterest.ktlint:ktlint-cli:1.5.0"
compileTesting = "com.google.testing.compile:compile-testing:0.21.0"
testParameterInjector = "com.google.testparameterinjector:test-parameter-injector:1.18"
1 change: 1 addition & 0 deletions retrofit-converters/wire/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies {
testImplementation libs.junit
testImplementation libs.truth
testImplementation libs.okhttp.mockwebserver
testImplementation libs.testParameterInjector
}

jar {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import javax.annotation.Nullable;
import okhttp3.RequestBody;
import okhttp3.ResponseBody;
import retrofit2.Call;
import retrofit2.Converter;
import retrofit2.Retrofit;

Expand All @@ -31,11 +32,30 @@
* <p>This converter only applies for types which extend from {@link Message}.
*/
public final class WireConverterFactory extends Converter.Factory {
/**
* Create an instance which serializes request messages to bytes eagerly on the caller thread
* when either {@link Call#execute()} or {@link Call#enqueue} is called. Response bytes are
* always converted to message instances on one of OKHttp's background threads.
*/
public static WireConverterFactory create() {
return new WireConverterFactory();
return new WireConverterFactory(false);
}

private WireConverterFactory() {}
/**
* Create an instance which streams serialization of request messages to bytes on the HTTP thread
* This is either the calling thread for {@link Call#execute()}, or one of OKHttp's background
* threads for {@link Call#enqueue}. Response bytes are always converted to message instances on
* one of OKHttp's background threads.
*/
public static WireConverterFactory createStreaming() {
return new WireConverterFactory(true);
}

private final boolean streaming;

private WireConverterFactory(boolean streaming) {
this.streaming = streaming;
}

@Override
public @Nullable Converter<ResponseBody, ?> responseBodyConverter(
Expand Down Expand Up @@ -67,6 +87,6 @@ private WireConverterFactory() {}
}
//noinspection unchecked
ProtoAdapter<? extends Message> adapter = ProtoAdapter.get((Class<? extends Message>) c);
return new WireRequestBodyConverter<>(adapter);
return new WireRequestBodyConverter<>(adapter, streaming);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,22 @@
import retrofit2.Converter;

final class WireRequestBodyConverter<T extends Message<T, ?>> implements Converter<T, RequestBody> {
private static final MediaType MEDIA_TYPE = MediaType.get("application/x-protobuf");
static final MediaType MEDIA_TYPE = MediaType.get("application/x-protobuf");

private final ProtoAdapter<T> adapter;
private final boolean streaming;

WireRequestBodyConverter(ProtoAdapter<T> adapter) {
WireRequestBodyConverter(ProtoAdapter<T> adapter, boolean streaming) {
this.adapter = adapter;
this.streaming = streaming;
}

@Override
public RequestBody convert(T value) throws IOException {
if (streaming) {
return new WireStreamingRequestBody<>(adapter, value);
}

Buffer buffer = new Buffer();
adapter.encode(buffer, value);
return RequestBody.create(MEDIA_TYPE, buffer.snapshot());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (C) 2015 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package retrofit2.converter.wire;

import static retrofit2.converter.wire.WireRequestBodyConverter.MEDIA_TYPE;

import com.squareup.wire.Message;
import com.squareup.wire.ProtoAdapter;
import java.io.IOException;
import okhttp3.MediaType;
import okhttp3.RequestBody;
import okio.BufferedSink;

final class WireStreamingRequestBody<T extends Message<T, ?>> extends RequestBody {
private final ProtoAdapter<T> adapter;
private final T value;

WireStreamingRequestBody(ProtoAdapter<T> adapter, T value) {
this.adapter = adapter;
this.value = value;
}

@Override
public MediaType contentType() {
return MEDIA_TYPE;
}

@Override
public void writeTo(BufferedSink sink) throws IOException {
adapter.encode(sink, value);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Code generated by Wire protocol buffer compiler, do not edit.
// Source file: phone.proto at 6:1
package retrofit2.converter.wire;

import com.squareup.wire.FieldEncoding;
import com.squareup.wire.Message;
import com.squareup.wire.ProtoAdapter;
import com.squareup.wire.ProtoReader;
import com.squareup.wire.ProtoWriter;
import com.squareup.wire.WireField;
import com.squareup.wire.internal.Internal;
import java.io.EOFException;
import java.io.IOException;
import okio.ByteString;

public final class CrashingPhone extends Message<CrashingPhone, CrashingPhone.Builder> {
public static final ProtoAdapter<CrashingPhone> ADAPTER = new ProtoAdapter_CrashingPhone();

private static final long serialVersionUID = 0L;

public static final String DEFAULT_NUMBER = "";

@WireField(tag = 1, adapter = "com.squareup.wire.ProtoAdapter#STRING")
public final String number;

public CrashingPhone(String number) {
this(number, ByteString.EMPTY);
}

public CrashingPhone(String number, ByteString unknownFields) {
super(ADAPTER, unknownFields);
this.number = number;
}

@Override
public Builder newBuilder() {
Builder builder = new Builder();
builder.number = number;
builder.addUnknownFields(unknownFields());
return builder;
}

@Override
public boolean equals(Object other) {
if (other == this) return true;
if (!(other instanceof CrashingPhone)) return false;
CrashingPhone o = (CrashingPhone) other;
return Internal.equals(unknownFields(), o.unknownFields()) && Internal.equals(number, o.number);
}

@Override
public int hashCode() {
int result = super.hashCode;
if (result == 0) {
result = unknownFields().hashCode();
result = result * 37 + (number != null ? number.hashCode() : 0);
super.hashCode = result;
}
return result;
}

@Override
public String toString() {
StringBuilder builder = new StringBuilder();
if (number != null) builder.append(", number=").append(number);
return builder.replace(0, 2, "Phone{").append('}').toString();
}

public static final class Builder extends Message.Builder<CrashingPhone, Builder> {
public String number;

public Builder() {}

public Builder number(String number) {
this.number = number;
return this;
}

@Override
public CrashingPhone build() {
return new CrashingPhone(number, buildUnknownFields());
}
}

private static final class ProtoAdapter_CrashingPhone extends ProtoAdapter<CrashingPhone> {
ProtoAdapter_CrashingPhone() {
super(FieldEncoding.LENGTH_DELIMITED, CrashingPhone.class);
}

@Override
public int encodedSize(CrashingPhone value) {
return (value.number != null ? ProtoAdapter.STRING.encodedSizeWithTag(1, value.number) : 0)
+ value.unknownFields().size();
}

@Override
public void encode(ProtoWriter writer, CrashingPhone value) throws IOException {
throw new EOFException("oops!");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat!

}

@Override
public CrashingPhone decode(ProtoReader reader) throws IOException {
Builder builder = new Builder();
long token = reader.beginMessage();
for (int tag; (tag = reader.nextTag()) != -1; ) {
switch (tag) {
case 1:
builder.number(ProtoAdapter.STRING.decode(reader));
break;
default:
{
FieldEncoding fieldEncoding = reader.peekFieldEncoding();
Object value = fieldEncoding.rawProtoAdapter().decode(reader);
builder.addUnknownField(tag, fieldEncoding, value);
}
}
}
reader.endMessage(token);
return builder.build();
}

@Override
public CrashingPhone redact(CrashingPhone value) {
Builder builder = value.newBuilder();
builder.clearUnknownFields();
return builder.build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,32 @@

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeTrue;

import com.google.testing.junit.testparameterinjector.TestParameter;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
import java.io.EOFException;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import okio.Buffer;
import okio.ByteString;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import retrofit2.Call;
import retrofit2.Callback;
import retrofit2.Response;
import retrofit2.Retrofit;
import retrofit2.http.Body;
import retrofit2.http.GET;
import retrofit2.http.POST;

@RunWith(TestParameterInjector.class)
public final class WireConverterFactoryTest {
interface Service {
@GET("/")
Expand All @@ -44,6 +51,9 @@ interface Service {
@POST("/")
Call<Phone> post(@Body Phone impl);

@POST("/")
Call<Void> postCrashing(@Body CrashingPhone impl);

@GET("/")
Call<String> wrongClass();

Expand All @@ -53,14 +63,17 @@ interface Service {

@Rule public final MockWebServer server = new MockWebServer();

private Service service;
private final Service service;
private final boolean streaming;

public WireConverterFactoryTest(@TestParameter boolean streaming) {
this.streaming = streaming;

@Before
public void setUp() {
Retrofit retrofit =
new Retrofit.Builder()
.baseUrl(server.url("/"))
.addConverterFactory(WireConverterFactory.create())
.addConverterFactory(
streaming ? WireConverterFactory.createStreaming() : WireConverterFactory.create())
.build();
service = retrofit.create(Service.class);
}
Expand All @@ -80,6 +93,36 @@ public void serializeAndDeserialize() throws IOException, InterruptedException {
assertThat(request.getHeader("Content-Type")).isEqualTo("application/x-protobuf");
}

@Test
public void serializeIsStreamed() throws IOException, InterruptedException {
assumeTrue(streaming);

Call<Void> call = service.postCrashing(new CrashingPhone("(519) 867-5309"));

final AtomicReference<Throwable> throwableRef = new AtomicReference<>();
final CountDownLatch latch = new CountDownLatch(1);

// If streaming were broken, the call to enqueue would throw the exception synchronously.
call.enqueue(
new Callback<Void>() {
@Override
public void onResponse(Call<Void> call, Response<Void> response) {
latch.countDown();
}

@Override
public void onFailure(Call<Void> call, Throwable t) {
throwableRef.set(t);
latch.countDown();
}
});
latch.await();

Throwable throwable = throwableRef.get();
assertThat(throwable).isInstanceOf(EOFException.class);
assertThat(throwable).hasMessageThat().isEqualTo("oops!");
}

@Test
public void deserializeEmpty() throws IOException {
server.enqueue(new MockResponse());
Expand Down
Loading