[grid] Implement deferred session removal in LocalSessionMap to prevent race conditions

Signed-off-by: Viet Nguyen Duc <[email protected]>
diff --git a/java/src/org/openqa/selenium/grid/sessionmap/local/LocalSessionMap.java b/java/src/org/openqa/selenium/grid/sessionmap/local/LocalSessionMap.java
index 4cdcbae..dd62dd6 100644
--- a/java/src/org/openqa/selenium/grid/sessionmap/local/LocalSessionMap.java
+++ b/java/src/org/openqa/selenium/grid/sessionmap/local/LocalSessionMap.java
@@ -17,9 +17,11 @@
 
 package org.openqa.selenium.grid.sessionmap.local;
 
+import static org.openqa.selenium.concurrent.ExecutorServices.shutdownGracefully;
 import static org.openqa.selenium.remote.RemoteTags.SESSION_ID;
 import static org.openqa.selenium.remote.RemoteTags.SESSION_ID_EVENT;
 
+import java.io.Closeable;
 import java.net.URI;
 import java.util.Collection;
 import java.util.Collections;
@@ -29,6 +31,9 @@
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
 import java.util.logging.Logger;
 import org.openqa.selenium.NoSuchSessionException;
 import org.openqa.selenium.events.Event;
@@ -48,19 +53,33 @@
 import org.openqa.selenium.remote.tracing.Span;
 import org.openqa.selenium.remote.tracing.Tracer;
 
-public class LocalSessionMap extends SessionMap {
+public class LocalSessionMap extends SessionMap implements Closeable {
 
   private static final Logger LOG = Logger.getLogger(LocalSessionMap.class.getName());
+  // Grace period to allow in-flight responses to complete before removing session
+  private static final long SESSION_REMOVAL_DELAY_MS = 3000;
 
   private final EventBus bus;
   private final IndexedSessionMap knownSessions = new IndexedSessionMap();
+  private final ScheduledExecutorService deferredRemovalExecutor;
 
   public LocalSessionMap(Tracer tracer, EventBus bus) {
     super(tracer);
 
     this.bus = Require.nonNull("Event bus", bus);
 
-    bus.addListener(SessionClosedEvent.listener(this::remove));
+    this.deferredRemovalExecutor =
+        Executors.newScheduledThreadPool(
+            Math.max(Runtime.getRuntime().availableProcessors() / 2, 3), // At least three threads
+            r -> {
+              Thread thread = new Thread(r);
+              thread.setDaemon(true);
+              thread.setName("LocalSessionMap - Deferred Removal");
+              return thread;
+            });
+
+    // Defer session removal to allow in-flight responses to complete
+    bus.addListener(SessionClosedEvent.listener(this::deferredRemove));
 
     bus.addListener(
         NodeRemovedEvent.listener(
@@ -121,6 +140,22 @@ public Session get(SessionId id) {
     return session;
   }
 
+  /**
+   * Schedules removal of a closed session after a grace period so in-flight responses (e.g. the
+   * Router's HandleSession) can still look up its URI. Removes it at once if close() already ran.
+   */
+  private void deferredRemove(SessionId id) {
+    Require.nonNull("Session ID", id);
+    LOG.fine(() -> String.format(
+        "Scheduling deferred removal of session %s in %d ms", id, SESSION_REMOVAL_DELAY_MS));
+    try {
+      deferredRemovalExecutor.schedule(
+          () -> remove(id), SESSION_REMOVAL_DELAY_MS, TimeUnit.MILLISECONDS);
+    } catch (java.util.concurrent.RejectedExecutionException e) {
+      remove(id); // close() shut the executor down, but this bus listener is still registered
+    }
+  }
+
   @Override
   public void remove(SessionId id) {
     Require.nonNull("Session ID", id);
@@ -143,6 +178,11 @@ public void remove(SessionId id) {
     }
   }
 
+  /** Shuts down the executor that runs deferred session removals. */
+  @Override public void close() {
+    shutdownGracefully("LocalSessionMap - Deferred Removal", this.deferredRemovalExecutor);
+  }
+
   private void batchRemoveByUri(URI externalUri, Class<? extends Event> eventClass) {
     Set<SessionId> sessionsToRemove = knownSessions.getSessionsByUri(externalUri);