diff --git a/src/main/java/me/kavin/piped/Main.java b/src/main/java/me/kavin/piped/Main.java index ab710fa..2c34df9 100644 --- a/src/main/java/me/kavin/piped/Main.java +++ b/src/main/java/me/kavin/piped/Main.java @@ -9,6 +9,7 @@ import me.kavin.piped.server.ServerLauncher; import me.kavin.piped.utils.*; import me.kavin.piped.utils.matrix.SyncRunner; import me.kavin.piped.utils.obj.MatrixHelper; +import me.kavin.piped.utils.obj.db.OidcData; import me.kavin.piped.utils.obj.db.PlaylistVideo; import me.kavin.piped.utils.obj.db.PubSub; import me.kavin.piped.utils.obj.db.Video; @@ -253,5 +254,32 @@ public class Main { } }, 0, TimeUnit.MINUTES.toMillis(60)); + new Timer().scheduleAtFixedRate(new TimerTask() { + @Override + public void run() { + try (StatelessSession s = DatabaseSessionFactory.createStatelessSession()) { + + var cb = s.getCriteriaBuilder(); + var cd = cb.createCriteriaDelete(OidcData.class); + var root = cd.from(OidcData.class); + cd.where(cb.lessThan(root.get("start"), System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(3))); + + var tr = s.beginTransaction(); + + var query = s.createMutationQuery(cd); + + int affected = query.executeUpdate(); + + tr.commit(); + + if (affected > 0) { + System.out.printf("Cleanup: Removed %o orphaned oidc logins%n", affected); + } + } catch (Exception e) { + e.printStackTrace(); + } + } + }, 0, TimeUnit.MINUTES.toMillis(5)); + } } diff --git a/src/main/java/me/kavin/piped/server/ServerLauncher.java b/src/main/java/me/kavin/piped/server/ServerLauncher.java index d0a3206..f48d507 100644 --- a/src/main/java/me/kavin/piped/server/ServerLauncher.java +++ b/src/main/java/me/kavin/piped/server/ServerLauncher.java @@ -490,8 +490,9 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher { } })).map(GET, "/user/delete", AsyncServlet.ofBlocking(executor, request -> { try { - var session = request.getQueryParameter("session"); - return UserHandlers.oidcDeleteRequest(session); + String session = request.getQueryParameter("session"); + String redirect = request.getQueryParameter("redirect"); + return UserHandlers.oidcDeleteRequest(session, redirect); } catch (Exception e) { return getErrorResponse(e, request.getPath()); } diff --git a/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java b/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java index 8104ca0..b3471dc 100644 --- a/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java +++ b/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java @@ -177,6 +177,12 @@ public class UserHandlers { URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback"); AuthorizationCode code = authResponse.getAuthorizationCode(); + if (code == null) { + return HttpResponse.ofCode(400).withHtml( + "Your oidc provider sent an invalid code. Try again or contact your oidc admin" + ); + } + AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, callback, data.getOidVerifier()); ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret); @@ -241,12 +247,16 @@ public class UserHandlers { return HttpResponse.redirect302(data.data + "?session=" + sessionId); } - public static HttpResponse oidcDeleteRequest(String session) throws Exception { + public static HttpResponse oidcDeleteRequest(String session, String redirect) throws Exception { if (StringUtils.isBlank(session)) { return HttpResponse.ofCode(400).withHtml("session is a required parameter"); } + if (StringUtils.isBlank(redirect)) { + return HttpResponse.ofCode(400).withHtml("redirect is a required parameter"); + } + OidcProvider provider = null; try (Session s = DatabaseSessionFactory.createSession()) { @@ -282,7 +292,7 @@ public class UserHandlers { CodeVerifier pkceVerifier = new CodeVerifier(); URI callback = URI.create(String.format("%s/oidc/%s/delete", Constants.PUBLIC_URL, provider.name)); - OidcData data = new OidcData(session + "|" + Instant.now().getEpochSecond(), pkceVerifier); + OidcData data = new OidcData(session + "|" + redirect, pkceVerifier); String state = data.getState(); DatabaseHelper.setOidcData(data); @@ -297,7 +307,7 @@ public class UserHandlers { .nonce(data.getOidNonce()); if (provider.sendMaxAge) { - // This parameter is optional and the idp doesn't have to honor it. + // This parameter is optional and the idp doesn't have to honor it. oidcRequestBuilder.maxAge(0); } @@ -316,11 +326,18 @@ public class UserHandlers { ); } - long start = Long.parseLong(data.data.split("\\|")[1]); + String redirect = data.data.split("\\|")[1]; String session = data.data.split("\\|")[0]; URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/delete"); AuthorizationCode code = sr.getAuthorizationCode(); + + if (code == null) { + return HttpResponse.ofCode(400).withHtml( + "Your oidc provider sent an invalid code. Try again or contact your oidc admin" + ); + } + AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, callback, data.getOidVerifier()); ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret); @@ -355,7 +372,7 @@ public class UserHandlers { return HttpResponse.ofCode(400).withHtml("Couldn't get the `auth_time` claim from the provided id token"); } - if (authTime < start) { + if (authTime <= data.start) { return HttpResponse.ofCode(500).withHtml( "Your oidc provider didn't verify your identity. Please try again or contact your oidc admin." ); @@ -377,7 +394,7 @@ public class UserHandlers { tr.commit(); } - return HttpResponse.redirect302(Constants.FRONTEND_URL + "/preferences?deleted=" + session); + return HttpResponse.redirect302(redirect + "?deleted=true"); } public static byte[] deleteUserResponse(String session, String pass) throws IOException { diff --git a/src/main/java/me/kavin/piped/utils/DatabaseHelper.java b/src/main/java/me/kavin/piped/utils/DatabaseHelper.java index c4bb30a..b7b9b15 100644 --- a/src/main/java/me/kavin/piped/utils/DatabaseHelper.java +++ b/src/main/java/me/kavin/piped/utils/DatabaseHelper.java @@ -247,13 +247,24 @@ public class DatabaseHelper { } public static OidcData getOidcData(String state) { - try (StatelessSession s = DatabaseSessionFactory.createStatelessSession()) { + try (Session s = DatabaseSessionFactory.createSession()) { + CriteriaBuilder cb = s.getCriteriaBuilder(); CriteriaQuery cr = cb.createQuery(OidcData.class); Root root = cr.from(OidcData.class); cr.select(root).where(cb.equal(root.get("state"), state)); - return s.createQuery(cr).uniqueResult(); + OidcData data = s.createQuery(cr).uniqueResult(); + + if (data == null){ + return null; + } + + var tr = s.beginTransaction(); + s.remove(data); + tr.commit(); + + return data; } } } diff --git a/src/main/java/me/kavin/piped/utils/obj/db/OidcData.java b/src/main/java/me/kavin/piped/utils/obj/db/OidcData.java index 2e9b50c..cb28344 100644 --- a/src/main/java/me/kavin/piped/utils/obj/db/OidcData.java +++ b/src/main/java/me/kavin/piped/utils/obj/db/OidcData.java @@ -27,13 +27,13 @@ public class OidcData implements Serializable { public String state; @Column(name = "start") - public long auth_start; + public long start; public OidcData(String data, CodeVerifier pkceVerifier) { this.nonce = new Nonce().toString(); this.verifierSecret = pkceVerifier.getValue(); this.data = data; - this.auth_start = System.currentTimeMillis() / 1000L; + this.start = System.currentTimeMillis() / 1000L; this.state = getState(); }