diff --git a/aws-lambda-java-events/README.md b/aws-lambda-java-events/README.md index 43c25d76a..c725fd87a 100644 --- a/aws-lambda-java-events/README.md +++ b/aws-lambda-java-events/README.md @@ -60,6 +60,23 @@ * `SQSBatchResponse` * `SQSEvent` +### API Gateway WebSocket connection context + +API Gateway manages the WebSocket connection. Lambda receives event payloads instead of a native Java WebSocket session. + +For `APIGatewayV2WebSocketEvent`, use the request context's connection metadata: + +```java +APIGatewayV2WebSocketEvent.WebSocketConnectionContext connection = event.getConnectionContext(); +if (connection != null) { + String connectionId = connection.getConnectionId(); + String endpoint = connection.getManagementApiEndpoint(); // https://{domainName}/{stage} + // endpoint is null when domainName/stage is missing or empty +} +``` + +You can pass this object through your handlers as a lightweight session-like context and use it with the API Gateway Management API to send messages. + ### Usage diff --git a/aws-lambda-java-events/RELEASE.CHANGELOG.md b/aws-lambda-java-events/RELEASE.CHANGELOG.md index a4bcd10a0..f102ad901 100644 --- a/aws-lambda-java-events/RELEASE.CHANGELOG.md +++ b/aws-lambda-java-events/RELEASE.CHANGELOG.md @@ -1,3 +1,6 @@ +### Unreleased +- Add `WebSocketConnectionContext` helper accessors on `APIGatewayV2WebSocketEvent` to expose connection metadata as a session-like object + ### June 17, 2025 `3.16.0`: - Add Schema metadata related attributes in KafkaEvent ([#548](https://github.com/aws/aws-lambda-java-libs/pull/548)) diff --git a/aws-lambda-java-events/src/main/java/com/amazonaws/services/lambda/runtime/events/APIGatewayV2WebSocketEvent.java b/aws-lambda-java-events/src/main/java/com/amazonaws/services/lambda/runtime/events/APIGatewayV2WebSocketEvent.java index cb6ffa991..bcbadbb5b 100644 --- a/aws-lambda-java-events/src/main/java/com/amazonaws/services/lambda/runtime/events/APIGatewayV2WebSocketEvent.java +++ b/aws-lambda-java-events/src/main/java/com/amazonaws/services/lambda/runtime/events/APIGatewayV2WebSocketEvent.java @@ -1,6 +1,7 @@ package com.amazonaws.services.lambda.runtime.events; import java.io.Serializable; +import java.beans.Transient; import java.util.List; import java.util.Map; import java.util.Objects; @@ -12,6 +13,81 @@ public class APIGatewayV2WebSocketEvent implements Serializable, Cloneable { private static final long serialVersionUID = 5695319264103347099L; + /** + * Represents the API Gateway-managed WebSocket connection for a single client. + * This is not a native server-side socket session. + */ + public static class WebSocketConnectionContext implements Serializable { + + private static final long serialVersionUID = 9166276112534784030L; + + private final String connectionId; + private final String domainName; + private final String stage; + + public WebSocketConnectionContext(String connectionId, String domainName, String stage) { + this.connectionId = connectionId; + this.domainName = domainName; + this.stage = stage; + } + + public String getConnectionId() { + return connectionId; + } + + public String getDomainName() { + return domainName; + } + + public String getStage() { + return stage; + } + + /** + * Builds the API Gateway Management API endpoint used for postToConnection calls. + */ + public String getManagementApiEndpoint() { + if (isNullOrEmpty(domainName) || isNullOrEmpty(stage)) { + return null; + } + return "https://" + domainName + "/" + stage; + } + + private static boolean isNullOrEmpty(String value) { + return value == null || value.isEmpty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + WebSocketConnectionContext that = (WebSocketConnectionContext) o; + + if (!Objects.equals(connectionId, that.connectionId)) return false; + if (!Objects.equals(domainName, that.domainName)) return false; + return Objects.equals(stage, that.stage); + } + + @Override + public int hashCode() { + int result = connectionId != null ? connectionId.hashCode() : 0; + result = 31 * result + (domainName != null ? domainName.hashCode() : 0); + result = 31 * result + (stage != null ? stage.hashCode() : 0); + return result; + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder("WebSocketConnectionContext{"); + sb.append("connectionId='").append(connectionId).append('\''); + sb.append(", domainName='").append(domainName).append('\''); + sb.append(", stage='").append(stage).append('\''); + sb.append('}'); + return sb.toString(); + } + } + public static class RequestIdentity implements Serializable, Cloneable { private static final long serialVersionUID = -3276649362684921217L; @@ -415,6 +491,14 @@ public void setStatus(String status) { this.status = status; } + @Transient + public WebSocketConnectionContext getConnectionContext() { + if (connectionId == null) { + return null; + } + return new WebSocketConnectionContext(connectionId, domainName, stage); + } + @Override public int hashCode() { int hash = 3; @@ -646,6 +730,14 @@ public void setRequestContext(RequestContext requestContext) { this.requestContext = requestContext; } + @Transient + public WebSocketConnectionContext getConnectionContext() { + if (requestContext == null) { + return null; + } + return requestContext.getConnectionContext(); + } + public String getBody() { return body; } diff --git a/aws-lambda-java-events/src/test/java/com/amazonaws/services/lambda/runtime/events/APIGatewayV2WebSocketEventTest.java b/aws-lambda-java-events/src/test/java/com/amazonaws/services/lambda/runtime/events/APIGatewayV2WebSocketEventTest.java new file mode 100644 index 000000000..729049496 --- /dev/null +++ b/aws-lambda-java-events/src/test/java/com/amazonaws/services/lambda/runtime/events/APIGatewayV2WebSocketEventTest.java @@ -0,0 +1,75 @@ +package com.amazonaws.services.lambda.runtime.events; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +class APIGatewayV2WebSocketEventTest { + + @Test + void requestContextBuildsConnectionContext() { + APIGatewayV2WebSocketEvent.RequestContext requestContext = new APIGatewayV2WebSocketEvent.RequestContext(); + requestContext.setConnectionId("conn-123"); + requestContext.setDomainName("abc.execute-api.us-east-1.amazonaws.com"); + requestContext.setStage("prod"); + + APIGatewayV2WebSocketEvent.WebSocketConnectionContext connectionContext = requestContext.getConnectionContext(); + + assertNotNull(connectionContext); + assertEquals("conn-123", connectionContext.getConnectionId()); + assertEquals("abc.execute-api.us-east-1.amazonaws.com", connectionContext.getDomainName()); + assertEquals("prod", connectionContext.getStage()); + assertEquals("https://abc.execute-api.us-east-1.amazonaws.com/prod", connectionContext.getManagementApiEndpoint()); + } + + @Test + void eventExposesConnectionContextFromRequestContext() { + APIGatewayV2WebSocketEvent.RequestContext requestContext = new APIGatewayV2WebSocketEvent.RequestContext(); + requestContext.setConnectionId("conn-456"); + requestContext.setDomainName("xyz.execute-api.us-east-1.amazonaws.com"); + requestContext.setStage("dev"); + + APIGatewayV2WebSocketEvent event = new APIGatewayV2WebSocketEvent(); + event.setRequestContext(requestContext); + + APIGatewayV2WebSocketEvent.WebSocketConnectionContext connectionContext = event.getConnectionContext(); + + assertNotNull(connectionContext); + assertEquals("conn-456", connectionContext.getConnectionId()); + assertEquals("https://xyz.execute-api.us-east-1.amazonaws.com/dev", connectionContext.getManagementApiEndpoint()); + } + + @Test + void connectionContextRequiresConnectionId() { + APIGatewayV2WebSocketEvent.RequestContext requestContext = new APIGatewayV2WebSocketEvent.RequestContext(); + requestContext.setDomainName("abc.execute-api.us-east-1.amazonaws.com"); + requestContext.setStage("prod"); + + assertNull(requestContext.getConnectionContext()); + + APIGatewayV2WebSocketEvent event = new APIGatewayV2WebSocketEvent(); + assertNull(event.getConnectionContext()); + } + + @Test + void managementEndpointRequiresDomainAndStage() { + APIGatewayV2WebSocketEvent.WebSocketConnectionContext connectionContext = + new APIGatewayV2WebSocketEvent.WebSocketConnectionContext("conn-789", null, "prod"); + + assertNull(connectionContext.getManagementApiEndpoint()); + } + + @Test + void managementEndpointRequiresNonEmptyDomainAndStage() { + APIGatewayV2WebSocketEvent.WebSocketConnectionContext emptyDomain = + new APIGatewayV2WebSocketEvent.WebSocketConnectionContext("conn-111", "", "prod"); + APIGatewayV2WebSocketEvent.WebSocketConnectionContext emptyStage = + new APIGatewayV2WebSocketEvent.WebSocketConnectionContext("conn-222", "abc.execute-api.us-east-1.amazonaws.com", ""); + + assertNull(emptyDomain.getManagementApiEndpoint()); + assertNull(emptyStage.getManagementApiEndpoint()); + } +} +