CollocatedRequestHandler.java
// Copyright (c) ZeroC, Inc.
package com.zeroc.Ice;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletionStage;
final class CollocatedRequestHandler implements RequestHandler {
/**
 * Thread pool work item that marks a collocated request as sent and, if the request is still
 * pending, dispatches it through the enclosing handler.
 *
 * <p>This is an inner (non-static) class on purpose: {@link #run()} calls the enclosing
 * handler's {@code sentAsync} and {@code dispatchAll} methods.
 */
private class InvokeAllAsync extends RunnableThreadPoolWorkItem {
    /**
     * @param outAsync the outgoing request to mark as sent
     * @param os the stream holding the marshaled request(s) handed to {@code dispatchAll}
     * @param requestId the request ID handed to {@code dispatchAll}
     * @param batchRequestNum the batch request count handed to {@code dispatchAll}
     */
    private InvokeAllAsync(
            OutgoingAsyncBase outAsync, OutputStream os, int requestId, int batchRequestNum) {
        _outAsync = outAsync;
        _os = os;
        _requestId = requestId;
        _batchRequestNum = batchRequestNum;
    }

    @Override
    public void run() {
        // sentAsync returns false when the request is no longer registered as waiting to be
        // sent (for example, it was canceled in the meantime); nothing is dispatched then.
        if (sentAsync(_outAsync)) {
            dispatchAll(_os, _requestId, _batchRequestNum);
        }
    }

    private final OutgoingAsyncBase _outAsync;
    // Assigned once in the constructor; final like the other fields of this work item.
    private final OutputStream _os;
    private final int _requestId;
    private final int _batchRequestNum;
}
public CollocatedRequestHandler(Reference reference, ObjectAdapter adapter) {
_reference = reference;
_executor = reference.getInstance().initializationData().executor != null;
_adapter = adapter;
_response = _reference.isTwoway();
_logger =
_reference
.getInstance()
.initializationData()
.logger; // Cached for better performance.
_traceLevels = _reference.getInstance().traceLevels(); // Cached for better performance.
_requestId = 0;
}
/**
 * Sends the given proxy invocation through this handler by delegating to
 * {@code outAsync.invokeCollocated(this)}.
 *
 * @param outAsync the invocation to send
 * @return the value returned by {@code invokeCollocated}
 */
@Override
public int sendAsyncRequest(ProxyOutgoingAsyncBase outAsync) {
return outAsync.invokeCollocated(this);
}
/**
 * Cancels an outstanding collocated request and completes it with {@code ex}.
 *
 * <p>Runs while holding this handler's monitor and handles two cases:
 *
 * <ul>
 *   <li>The request is still registered as waiting to be sent: it is removed from both
 *       bookkeeping maps, completed with {@code ex}, and the adapter's direct count is
 *       decremented because the request will not be dispatched.
 *   <li>Otherwise, if the request is an {@code OutgoingAsync} that is still waiting for its
 *       response, it is removed from the response map and completed with {@code ex}.
 * </ul>
 *
 * <p>If the request is not found in either map, this method does nothing.
 *
 * @param outAsync the request being canceled
 * @param ex the exception used to complete the request
 */
@Override
public synchronized void asyncRequestCanceled(OutgoingAsyncBase outAsync, LocalException ex) {
    Integer requestId = _sendAsyncRequests.remove(outAsync);
    if (requestId != null) {
        if (requestId > 0) {
            _asyncRequests.remove(requestId);
        }
        if (outAsync.completed(ex)) {
            outAsync.invokeCompletedAsync();
        }
        _adapter.decDirectCount(); // dispatchAll won't be called, decrease the direct count.
        return;
    }

    if (outAsync instanceof OutgoingAsync) {
        // Remove through the iterator rather than through the map itself: structurally
        // modifying a HashMap while iterating over one of its views is only legal via
        // Iterator.remove. Removing from the values() view removes the whole mapping.
        var pending = _asyncRequests.values().iterator();
        while (pending.hasNext()) {
            if (pending.next() == outAsync) { // identity comparison, as requests are unique
                pending.remove();
                if (outAsync.completed(ex)) {
                    outAsync.invokeCompletedAsync();
                }
                return;
            }
        }
    }
}
/**
 * Always returns {@code null}: this handler does not provide a connection object.
 *
 * @return {@code null}
 */
@Override
public ConnectionI getConnection() {
return null;
}
/**
 * Registers a collocated request and arranges for it to be dispatched.
 *
 * <p>The request is handed to the adapter's thread pool unless the call is synchronous, expects
 * a response, and has no positive invocation timeout. In that remaining case it is run from the
 * calling thread: through the thread pool's {@code executeFromThisThread} when an executor is
 * configured, or by calling {@code dispatchAll} directly otherwise.
 *
 * @param outAsync the request to send
 * @param batchRequestNum the number of batched requests, passed on to the dispatch
 * @param sync whether the caller is making a synchronous invocation
 * @return {@code AsyncStatus.Queued}
 */
int invokeAsyncRequest(OutgoingAsyncBase outAsync, int batchRequestNum, boolean sync) {
    // Hold a direct count so the thread pool cannot be destroyed before dispatchAll runs.
    // This call also throws if the object adapter has been deactivated.
    _adapter.incDirectCount();

    int requestId = 0;
    try {
        synchronized (this) {
            outAsync.cancelable(this); // This will throw if the request is canceled

            // Only requests that expect a response get a non-zero ID and are tracked for it.
            if (_response) {
                requestId = ++_requestId;
                _asyncRequests.put(requestId, outAsync);
            }

            _sendAsyncRequests.put(outAsync, requestId);
        }

        outAsync.attachCollocatedObserver(_adapter, requestId);

        // The invocation timeout is only consulted for synchronous twoway calls.
        final boolean viaThreadPool =
                !sync
                        || !_response
                        || _reference.getInvocationTimeout().compareTo(Duration.ZERO) > 0;

        if (viaThreadPool) {
            var threadPool = _adapter.getThreadPool();
            threadPool.dispatch(
                    new InvokeAllAsync(outAsync, outAsync.getOs(), requestId, batchRequestNum));
        } else if (_executor) {
            var threadPool = _adapter.getThreadPool();
            threadPool.executeFromThisThread(
                    new InvokeAllAsync(outAsync, outAsync.getOs(), requestId, batchRequestNum));
        } else if (sentAsync(outAsync)) {
            // Optimization: directly call dispatchAll if there's no executor.
            dispatchAll(outAsync.getOs(), requestId, batchRequestNum);
        }
    } catch (Exception ex) {
        // Anything thrown synchronously means no dispatch will release the direct count.
        _adapter.decDirectCount();
        throw ex;
    }

    return AsyncStatus.Queued;
}
/**
 * Marks {@code outAsync} as sent, provided it is still registered as waiting to be sent.
 *
 * @param outAsync the request to mark as sent
 * @return {@code false} if the request is no longer registered (it timed out or was canceled),
 *     {@code true} otherwise
 */
private boolean sentAsync(final OutgoingAsyncBase outAsync) {
    final boolean notifySent;

    synchronized (this) {
        final boolean stillPending = _sendAsyncRequests.remove(outAsync) != null;
        if (!stillPending) {
            return false; // The request timed-out.
        }

        // sent() has to run while the monitor is held so that a concurrent cancellation
        // cannot call completed(ex) at the same time.
        notifySent = outAsync.sent();
    }

    // The sent callback itself is invoked after the monitor has been released.
    if (notifySent) {
        outAsync.invokeSent();
    }
    return true;
}
/**
 * Dispatches the request, or each request of a batch, marshaled in {@code os} to the adapter's
 * dispatch pipeline.
 *
 * <p>Each dispatch increments the adapter's direct count beforehand; the response path
 * ({@code sendResponse}) or the error path ({@code dispatchException}) decrements it again.
 * The {@code finally} clause performs one additional decrement when this method exits.
 *
 * @param os the stream holding the marshaled request(s)
 * @param requestId the request ID; 0 means no response is tracked for the request
 * @param requestCount the number of batched requests, or 0 for a single request
 */
private void dispatchAll(OutputStream os, int requestId, int requestCount) {
// When protocol tracing is enabled, patch values into the stream before tracing it: the
// stream size at offset 10 (presumably the header's message-size field - confirm against
// Protocol) and the request ID or batch count right after the header.
if (_traceLevels.protocol >= 1) {
fillInValue(os, 10, os.size());
if (requestId > 0) {
fillInValue(os, Protocol.headerSize, requestId);
} else if (requestCount > 0) {
fillInValue(os, Protocol.headerSize, requestCount);
}
TraceUtil.traceSend(os, _reference.getInstance(), null, _logger, _traceLevels);
}
// Read the request(s) back from the output stream's buffer. NOTE(review): the final 'false'
// argument is presumably the "adopt" flag (sendResponse passes 'true' with that label).
var is = new InputStream(_reference.getInstance(), os.getEncoding(), os.getBuffer(), false);
// Skip the batch-request header or the single-request header, whichever applies.
if (requestCount > 0) {
is.pos(Protocol.requestBatchHdr.length);
} else {
is.pos(Protocol.requestHdr.length);
}
// A non-batch invocation is dispatched exactly once.
int dispatchCount = requestCount > 0 ? requestCount : 1;
assert !_response || dispatchCount == 1;
Object dispatcher = _adapter.dispatchPipeline();
assert dispatcher != null;
try {
while (dispatchCount > 0) {
//
// Increase the direct count for the dispatch. We increase it again here for each
// dispatch. It's important for the direct count to be > 0 until the last collocated
// request response is sent to make sure the thread pool isn't destroyed before.
//
try {
_adapter.incDirectCount();
} catch (ObjectAdapterDestroyedException ex) {
// The adapter is gone: report the failure to the caller and stop dispatching.
handleException(ex, requestId, false);
break;
}
var request = new IncomingRequest(requestId, null, _adapter, is);
CompletionStage<OutgoingResponse> response = null;
try {
response = dispatcher.dispatch(request);
} catch (Throwable ex) { // UserException or an unchecked exception
// The dispatch failed synchronously: answer with a response built from the exception.
sendResponse(request.current.createOutgoingResponse(ex), requestId, false);
}
if (response != null) {
// Send the response once the dispatch completes; these calls pass amd = true.
response.whenComplete(
(result, exception) -> {
if (exception != null) {
sendResponse(
request.current.createOutgoingResponse(exception),
requestId,
true);
} else {
sendResponse(result, requestId, true);
}
// Any exception thrown by this closure is effectively ignored.
});
}
--dispatchCount;
}
is.clear();
} catch (LocalException ex) {
dispatchException(ex, requestId, false); // Fatal dispatch exception
} catch (RuntimeException | Error ex) {
// A runtime exception or an error was thrown outside of servant code (i.e., by Ice
// code). Note that this code does NOT send a response to the client.
var uex = new UnknownException(ex);
// Log the full stack trace of the unexpected failure.
var sw = new StringWriter();
var pw = new PrintWriter(sw);
ex.printStackTrace(pw);
pw.flush();
_logger.error(sw.toString());
dispatchException(uex, requestId, false);
} finally {
// NOTE(review): presumably balances the increment made by invokeAsyncRequest before this
// method was scheduled - confirm.
_adapter.decDirectCount();
}
}
/**
 * Delivers the response of a dispatch to the request waiting for it, then decrements the
 * adapter's direct count.
 *
 * <p>When this handler does not track responses, only the direct count is decremented.
 *
 * @param response the response produced by the dispatch
 * @param requestId the ID of the request the response belongs to
 * @param amd {@code true} when called from an asynchronous dispatch completion
 */
private void sendResponse(OutgoingResponse response, int requestId, boolean amd) {
    if (_response) {
        final OutputStream replyStream = response.outputStream;
        final OutgoingAsyncBase toComplete;

        synchronized (this) {
            if (_traceLevels.protocol >= 1) {
                fillInValue(replyStream, 10, replyStream.size());
            }

            // Adopt the OutputStream's buffer.
            var replyInput =
                    new InputStream(
                            _reference.getInstance(),
                            replyStream.getEncoding(),
                            replyStream.getBuffer(),
                            true); // adopt: true
            replyInput.pos(Protocol.replyHdr.length + 4);

            if (_traceLevels.protocol >= 1) {
                TraceUtil.traceRecv(replyInput, null, _logger, _traceLevels);
            }

            // The request may already be gone (e.g. canceled); completed() also decides
            // whether a completion callback has to be invoked at all.
            final OutgoingAsyncBase pending = _asyncRequests.remove(requestId);
            toComplete = (pending != null && pending.completed(replyInput)) ? pending : null;
        }

        if (toComplete != null) {
            // When called from an AMD dispatch, the completion callback is invoked
            // asynchronously since this might be called from the user code.
            if (amd) {
                toComplete.invokeCompletedAsync();
            } else {
                toComplete.invokeCompleted();
            }
        }
    }

    _adapter.decDirectCount();
}
/**
 * Reports {@code ex} to the request identified by {@code requestId} via
 * {@code handleException}, then decrements the adapter's direct count.
 *
 * @param ex the exception to report
 * @param requestId the ID of the affected request
 * @param amd forwarded to {@code handleException}
 */
private void dispatchException(LocalException ex, int requestId, boolean amd) {
handleException(ex, requestId, amd);
// NOTE(review): presumably releases the per-dispatch count taken in dispatchAll - confirm.
_adapter.decDirectCount();
}
/**
 * Completes the request identified by {@code requestId} with {@code ex}, if that request is
 * still waiting for a response.
 *
 * @param ex the exception to complete the request with
 * @param requestId the ID of the affected request; 0 identifies a oneway message
 * @param amd {@code true} when called from an asynchronous dispatch completion
 */
private void handleException(LocalException ex, int requestId, boolean amd) {
    // Ignore exception for oneway messages.
    if (requestId != 0) {
        final OutgoingAsyncBase toComplete;
        synchronized (this) {
            final OutgoingAsyncBase pending = _asyncRequests.remove(requestId);
            toComplete = (pending != null && pending.completed(ex)) ? pending : null;
        }

        if (toComplete != null) {
            // When called from an AMD dispatch, the completion callback is invoked
            // asynchronously since this might be called from the user code.
            if (amd) {
                toComplete.invokeCompletedAsync();
            } else {
                toComplete.invokeCompleted();
            }
        }
    }
}
/**
 * Overwrites the int stored at offset {@code pos} of {@code os} with {@code value}.
 *
 * @param os the stream to patch
 * @param pos the offset passed to {@code rewriteInt}
 * @param value the new value
 */
private static void fillInValue(OutputStream os, int pos, int value) {
os.rewriteInt(value, pos);
}
// The reference this handler was created for.
private final Reference _reference;
// True when the instance's initialization data has a non-null executor.
private final boolean _executor;
// True when the reference is twoway, i.e. requests get an ID and their responses are tracked.
private final boolean _response;
// The object adapter that dispatches the collocated requests.
private final ObjectAdapter _adapter;
// Logger and trace levels, cached from the instance for better performance.
private final Logger _logger;
private final TraceLevels _traceLevels;
// The last request ID handed out; incremented while holding this handler's monitor.
private int _requestId;
// A map of outstanding requests that can be canceled. A request can be canceled if it has an
// invocation timeout, or if we support interrupts. Each request maps to its request ID (0 when
// no response is expected) and is removed once it has been marked as sent or canceled.
private final Map<OutgoingAsyncBase, Integer> _sendAsyncRequests =
new HashMap<>();
// Requests waiting for a response, keyed by request ID.
private final Map<Integer, OutgoingAsyncBase> _asyncRequests =
new HashMap<>();
}