Refactor oidc logic into UserHandlers

This commit is contained in:
Jeidnx 2023-10-26 13:24:27 +02:00
parent e7f2187b47
commit c1fde372a5
No known key found for this signature in database
GPG Key ID: 0E9E697B7E99DF39
2 changed files with 159 additions and 171 deletions

View File

@ -2,13 +2,6 @@ package me.kavin.piped.server;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.JsonNode;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.oauth2.sdk.*;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.openid.connect.sdk.*;
import com.nimbusds.openid.connect.sdk.claims.UserInfo;
import com.rometools.rome.feed.synd.SyndFeed; import com.rometools.rome.feed.synd.SyndFeed;
import com.rometools.rome.io.SyndFeedInput; import com.rometools.rome.io.SyndFeedInput;
import io.activej.config.Config; import io.activej.config.Config;
@ -44,11 +37,8 @@ import org.xml.sax.InputSource;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.URI; import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import static io.activej.config.converter.ConfigConverters.ofInetSocketAddress; import static io.activej.config.converter.ConfigConverters.ofInetSocketAddress;
import static io.activej.http.HttpHeaders.*; import static io.activej.http.HttpHeaders.*;
@ -61,7 +51,6 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher {
private static final HttpHeader FILE_NAME = HttpHeaders.of("x-file-name"); private static final HttpHeader FILE_NAME = HttpHeaders.of("x-file-name");
private static final HttpHeader LAST_ETAG = HttpHeaders.of("x-last-etag"); private static final HttpHeader LAST_ETAG = HttpHeaders.of("x-last-etag");
private static final Map<String, OidcData> PENDING_OIDC = new HashMap<>();
@Provides @Provides
Executor executor() { Executor executor() {
@ -291,137 +280,12 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher {
if (provider == null) if (provider == null)
return HttpResponse.ofCode(500).withHtml("Can't find the provider on the server"); return HttpResponse.ofCode(500).withHtml("Can't find the provider on the server");
URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback"); return switch (function) {
case "login" -> UserHandlers.oidcLoginResponse(provider, request.getQueryParameter("redirect"));
switch (function) { case "callback" -> UserHandlers.oidcCallbackResponse(provider, URI.create(request.getFullUrl()));
case "login" -> { case "delete" -> UserHandlers.oidcDeleteResponse(provider, URI.create(request.getFullUrl()));
String redirectUri = request.getQueryParameter("redirect"); default -> HttpResponse.ofCode(500).withHtml("Invalid function `" + function + "`");
};
if (StringUtils.isBlank(redirectUri)) {
return HttpResponse.ofCode(400).withHtml("redirect is a required parameter");
}
OidcData data = new OidcData(redirectUri);
String state = data.getState();
PENDING_OIDC.put(state, data);
AuthenticationRequest oidcRequest = new AuthenticationRequest.Builder(
new ResponseType("code"),
new Scope("openid"),
provider.clientID, callback).endpointURI(provider.authUri)
.state(new State(state)).nonce(data.nonce).build();
if (redirectUri.equals(Constants.FRONTEND_URL + "/login")) {
return HttpResponse.redirect302(oidcRequest.toURI().toString());
}
return HttpResponse.ok200().withHtml(
"<!DOCTYPE html><html style=\"color-scheme: dark light;\"><body>" +
"<h3>Warning:</h3> You are trying to give <pre style=\"font-size: 1.2rem;\">" +
redirectUri +
"</pre> access to your Piped account. If you wish to continue click " +
"<a style=\"text-decoration: underline;color: inherit;\"href=\"" +
oidcRequest.toURI().toString() +
"\">here</a></body></html>");
}
case "callback" -> {
ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret);
AuthenticationSuccessResponse sr = parseOidcUri(URI.create(request.getFullUrl()));
OidcData data = PENDING_OIDC.get(sr.getState().toString());
if (data == null) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent invalid state data. Try again or contact your oidc admin"
);
}
AuthorizationCode code = sr.getAuthorizationCode();
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, callback);
TokenRequest tr = new TokenRequest(provider.tokenUri, clientAuth, codeGrant);
OIDCTokenResponse tokenResponse = (OIDCTokenResponse) OIDCTokenResponseParser.parse(tr.toHTTPRequest().send());
if (!tokenResponse.indicatesSuccess()) {
TokenErrorResponse errorResponse = tokenResponse.toErrorResponse();
return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription());
}
OIDCTokenResponse successResponse = tokenResponse.toSuccessResponse();
if (data.isInvalidNonce((String) successResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getClaim("nonce"))) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent an invalid nonce. Try again or contact your oidc admin"
);
}
UserInfoRequest ur = new UserInfoRequest(provider.userinfoUri, successResponse.getOIDCTokens().getBearerAccessToken());
UserInfoResponse userInfoResponse = UserInfoResponse.parse(ur.toHTTPRequest().send());
if (!userInfoResponse.indicatesSuccess()) {
System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getCode());
System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getDescription());
return HttpResponse.ofCode(500).withHtml("Failed to query userInfo:\n\n" + userInfoResponse.toErrorResponse().getErrorObject().getDescription());
}
UserInfo userInfo = userInfoResponse.toSuccessResponse().getUserInfo();
String sessionId = UserHandlers.oidcCallbackResponse(provider.name, userInfo.getSubject().toString());
return HttpResponse.redirect302(data.data + "?session=" + sessionId);
}
case "delete" -> {
ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret);
AuthenticationSuccessResponse sr = parseOidcUri(URI.create(request.getFullUrl()));
OidcData data = UserHandlers.PENDING_OIDC.get(sr.getState().toString());
if (data == null) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent invalid state data. Try again or contact your oidc admin"
);
}
long start = Long.parseLong(data.data.split("\\|")[1]);
String session = data.data.split("\\|")[0];
AuthorizationCode code = sr.getAuthorizationCode();
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, new URI(Constants.PUBLIC_URL + request.getPath()));
TokenRequest tr = new TokenRequest(provider.tokenUri, clientAuth, codeGrant);
TokenResponse tokenResponse = OIDCTokenResponseParser.parse(tr.toHTTPRequest().send());
if (!tokenResponse.indicatesSuccess()) {
TokenErrorResponse errorResponse = tokenResponse.toErrorResponse();
return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription());
}
OIDCTokenResponse successResponse = (OIDCTokenResponse) tokenResponse.toSuccessResponse();
JWTClaimsSet claims = successResponse.getOIDCTokens().getIDToken().getJWTClaimsSet();
if (data.isInvalidNonce((String) claims.getClaim("nonce"))) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent an invalid nonce. Please try again or contact your oidc admin."
);
}
long authTime = (long) claims.getClaim("auth_time");
if (authTime < start) {
return HttpResponse.ofCode(500).withHtml(
"Your oidc provider didn't verify your identity. Please try again or contact your oidc admin."
);
}
return HttpResponse.redirect302(Constants.FRONTEND_URL + "/preferences?deleted=" + UserHandlers.deleteOidcUserResponse(session));
}
default -> {
return HttpResponse.ofCode(500).withHtml("Invalid function `" + function + "`");
}
}
} catch (Exception e) { } catch (Exception e) {
return getErrorResponse(e, request.getPath()); return getErrorResponse(e, request.getPath());
} }
@ -680,17 +544,6 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher {
return null; return null;
} }
private static AuthenticationSuccessResponse parseOidcUri(URI uri) throws Exception {
AuthenticationResponse response = AuthenticationResponseParser.parse(uri);
if (response instanceof AuthenticationErrorResponse) {
// The OpenID provider returned an error
System.err.println(response.toErrorResponse().getErrorObject());
throw new Exception(response.toErrorResponse().getErrorObject().toString());
}
return response.toSuccessResponse();
}
private static String[] getArray(String s) { private static String[] getArray(String s) {
if (s == null) { if (s == null) {

View File

@ -1,10 +1,14 @@
package me.kavin.piped.server.handlers.auth; package me.kavin.piped.server.handlers.auth;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.nimbusds.oauth2.sdk.ResponseType; import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.oauth2.sdk.Scope; import com.nimbusds.oauth2.sdk.*;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.id.State; import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.openid.connect.sdk.AuthenticationRequest; import com.nimbusds.openid.connect.sdk.*;
import com.nimbusds.openid.connect.sdk.claims.UserInfo;
import io.activej.http.HttpResponse;
import jakarta.persistence.criteria.CriteriaBuilder; import jakarta.persistence.criteria.CriteriaBuilder;
import jakarta.persistence.criteria.CriteriaQuery; import jakarta.persistence.criteria.CriteriaQuery;
import jakarta.persistence.criteria.Root; import jakarta.persistence.criteria.Root;
@ -120,7 +124,84 @@ public class UserHandlers {
} }
} }
public static String oidcCallbackResponse(String provider, String uid) { public static HttpResponse oidcLoginResponse(OidcProvider provider, String redirectUri) throws Exception{
if (StringUtils.isBlank(redirectUri)) {
return HttpResponse.ofCode(400).withHtml("redirect is a required parameter");
}
URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback");
OidcData data = new OidcData(redirectUri);
String state = data.getState();
PENDING_OIDC.put(state, data);
AuthenticationRequest oidcRequest = new AuthenticationRequest.Builder(
new ResponseType("code"),
new Scope("openid"),
provider.clientID, callback).endpointURI(provider.authUri)
.state(new State(state)).nonce(data.nonce).build();
if (redirectUri.equals(Constants.FRONTEND_URL + "/login")) {
return HttpResponse.redirect302(oidcRequest.toURI().toString());
}
return HttpResponse.ok200().withHtml(
"<!DOCTYPE html><html style=\"color-scheme: dark light;\"><body>" +
"<h3>Warning:</h3> You are trying to give <pre style=\"font-size: 1.2rem;\">" +
redirectUri +
"</pre> access to your Piped account. If you wish to continue click " +
"<a style=\"text-decoration: underline;color: inherit;\"href=\"" +
oidcRequest.toURI().toString() +
"\">here</a></body></html>");
}
public static HttpResponse oidcCallbackResponse(OidcProvider provider, URI requestUri) throws Exception {
ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret);
AuthenticationSuccessResponse sr = parseOidcUri(requestUri);
OidcData data = PENDING_OIDC.get(sr.getState().toString());
if (data == null) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent invalid state data. Try again or contact your oidc admin"
);
}
URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback");
AuthorizationCode code = sr.getAuthorizationCode();
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, callback);
TokenRequest tokenReq = new TokenRequest(provider.tokenUri, clientAuth, codeGrant);
OIDCTokenResponse tokenResponse = (OIDCTokenResponse) OIDCTokenResponseParser.parse(tokenReq.toHTTPRequest().send());
if (!tokenResponse.indicatesSuccess()) {
TokenErrorResponse errorResponse = tokenResponse.toErrorResponse();
return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription());
}
OIDCTokenResponse successResponse = tokenResponse.toSuccessResponse();
if (data.isInvalidNonce((String) successResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getClaim("nonce"))) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent an invalid nonce. Try again or contact your oidc admin"
);
}
UserInfoRequest ur = new UserInfoRequest(provider.userinfoUri, successResponse.getOIDCTokens().getBearerAccessToken());
UserInfoResponse userInfoResponse = UserInfoResponse.parse(ur.toHTTPRequest().send());
if (!userInfoResponse.indicatesSuccess()) {
System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getCode());
System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getDescription());
return HttpResponse.ofCode(500).withHtml(
"The userinfo endpoint returned an error. Please try again or contact your oidc admin\n\n" +
userInfoResponse.toErrorResponse().getErrorObject().getDescription());
}
UserInfo userInfo = userInfoResponse.toSuccessResponse().getUserInfo();
String uid = userInfo.getSubject().toString();
String sessionId;
try (Session s = DatabaseSessionFactory.createSession()) { try (Session s = DatabaseSessionFactory.createSession()) {
// TODO: Add oidc provider to database // TODO: Add oidc provider to database
String dbName = provider + "-" + uid; String dbName = provider + "-" + uid;
@ -141,11 +222,66 @@ public class UserHandlers {
tr.commit(); tr.commit();
return newuser.getSessionId(); sessionId = newuser.getSessionId();
} else sessionId = dbuser.getSessionId();
} }
return dbuser.getSessionId(); return HttpResponse.redirect302(data.data + "?session=" + sessionId);
} }
public static HttpResponse oidcDeleteResponse(OidcProvider provider, URI requestUri) throws Exception {
ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret);
AuthenticationSuccessResponse sr = parseOidcUri(requestUri);
OidcData data = UserHandlers.PENDING_OIDC.get(sr.getState().toString());
if (data == null) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent invalid state data. Try again or contact your oidc admin"
);
}
long start = Long.parseLong(data.data.split("\\|")[1]);
String session = data.data.split("\\|")[0];
URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/delete");
AuthorizationCode code = sr.getAuthorizationCode();
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, callback);
TokenRequest tokenRequest = new TokenRequest(provider.tokenUri, clientAuth, codeGrant);
TokenResponse tokenResponse = OIDCTokenResponseParser.parse(tokenRequest.toHTTPRequest().send());
if (!tokenResponse.indicatesSuccess()) {
TokenErrorResponse errorResponse = tokenResponse.toErrorResponse();
return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription());
}
OIDCTokenResponse successResponse = (OIDCTokenResponse) tokenResponse.toSuccessResponse();
JWTClaimsSet claims = successResponse.getOIDCTokens().getIDToken().getJWTClaimsSet();
if (data.isInvalidNonce((String) claims.getClaim("nonce"))) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent an invalid nonce. Please try again or contact your oidc admin."
);
}
long authTime = (long) claims.getClaim("auth_time");
if (authTime < start) {
return HttpResponse.ofCode(500).withHtml(
"Your oidc provider didn't verify your identity. Please try again or contact your oidc admin."
);
}
try (Session s = DatabaseSessionFactory.createSession()) {
var tr = s.beginTransaction();
s.remove(DatabaseHelper.getUserFromSession(session));
tr.commit();
}
return HttpResponse.redirect302(Constants.FRONTEND_URL + "/preferences?deleted=" + session);
} }
public static byte[] deleteUserResponse(String session, String pass) throws IOException { public static byte[] deleteUserResponse(String session, String pass) throws IOException {
@ -187,17 +323,6 @@ public class UserHandlers {
} }
} }
public static String deleteOidcUserResponse(String session) throws IOException {
try (Session s = DatabaseSessionFactory.createSession()) {
User user = DatabaseHelper.getUserFromSession(session);
var tr = s.beginTransaction();
s.remove(user);
tr.commit();
return user.getUsername();
}
}
public static byte[] logoutResponse(String session) throws JsonProcessingException { public static byte[] logoutResponse(String session) throws JsonProcessingException {
@ -217,4 +342,14 @@ public class UserHandlers {
return Constants.mapper.writeValueAsBytes(new AuthenticationFailureResponse()); return Constants.mapper.writeValueAsBytes(new AuthenticationFailureResponse());
} }
private static AuthenticationSuccessResponse parseOidcUri(URI uri) throws Exception {
AuthenticationResponse response = AuthenticationResponseParser.parse(uri);
if (response instanceof AuthenticationErrorResponse) {
System.err.println(response.toErrorResponse().getErrorObject());
throw new Exception(response.toErrorResponse().getErrorObject().toString());
}
return response.toSuccessResponse();
}
} }