diff --git a/src/main/java/me/kavin/piped/server/ServerLauncher.java b/src/main/java/me/kavin/piped/server/ServerLauncher.java index a62b996..676fc44 100644 --- a/src/main/java/me/kavin/piped/server/ServerLauncher.java +++ b/src/main/java/me/kavin/piped/server/ServerLauncher.java @@ -2,8 +2,12 @@ package me.kavin.piped.server; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; +import com.nimbusds.oauth2.sdk.*; import com.nimbusds.oauth2.sdk.auth.ClientAuthentication; import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; +import com.nimbusds.oauth2.sdk.id.Identifier; +import com.nimbusds.oauth2.sdk.id.State; +import com.nimbusds.openid.connect.sdk.*; import com.nimbusds.openid.connect.sdk.claims.UserInfo; import com.rometools.rome.feed.synd.SyndFeed; import com.rometools.rome.io.SyndFeedInput; @@ -21,8 +25,8 @@ import me.kavin.piped.server.handlers.auth.AuthPlaylistHandlers; import me.kavin.piped.server.handlers.auth.FeedHandlers; import me.kavin.piped.server.handlers.auth.StorageHandlers; import me.kavin.piped.server.handlers.auth.UserHandlers; -import me.kavin.piped.utils.*; import me.kavin.piped.utils.ErrorResponse; +import me.kavin.piped.utils.*; import me.kavin.piped.utils.obj.MatrixHelper; import me.kavin.piped.utils.obj.OidcProvider; import me.kavin.piped.utils.obj.federation.FederatedVideoInfo; @@ -35,14 +39,10 @@ import org.jetbrains.annotations.NotNull; import org.schabi.newpipe.extractor.exceptions.ParsingException; import org.schabi.newpipe.extractor.localization.DateWrapper; import org.xml.sax.InputSource; -import com.nimbusds.oauth2.sdk.*; -import com.nimbusds.openid.connect.sdk.*; -import com.nimbusds.oauth2.sdk.id.*; import java.io.ByteArrayInputStream; import java.net.InetSocketAddress; import java.net.URI; -import java.util.List; import java.util.Objects; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; @@ -330,7 +330,7 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher { try { String function = request.getPathParameter("function"); OidcProvider provider = getOidcProvider(request.getPathParameter("provider")); - if(provider == null) + if (provider == null) return HttpResponse.ofCode(500).withHtml("Can't find the provider on the server."); URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback"); @@ -339,9 +339,10 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher { case "login" -> { String redirectUri = request.getQueryParameter("redirect"); - if (redirectUri == null || redirectUri.equals("")) { + if (StringUtils.isBlank(redirectUri)) { return HttpResponse.ofCode(400).withHtml("Missing redirect parameter"); } + State state = new State(new Identifier(24) + "." + redirectUri); Nonce nonce = new Nonce(); @@ -355,7 +356,7 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher { .nonce(nonce) .build(); - if(redirectUri.equals(Constants.FRONTEND_URL + "/login")) { + if (redirectUri.equals(Constants.FRONTEND_URL + "/login")) { return HttpResponse.redirect302(oidcRequest.toURI().toString()); } return HttpResponse.ok200().withHtml( @@ -388,18 +389,18 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher { TokenRequest tr = new TokenRequest(provider.tokenUri, clientAuth, codeGrant); TokenResponse tokenResponse = OIDCTokenResponseParser.parse(tr.toHTTPRequest().send()); - if (! tokenResponse.indicatesSuccess()) { + if (!tokenResponse.indicatesSuccess()) { TokenErrorResponse errorResponse = tokenResponse.toErrorResponse(); return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription()); } - OIDCTokenResponse successResponse = (OIDCTokenResponse)tokenResponse.toSuccessResponse(); + OIDCTokenResponse successResponse = (OIDCTokenResponse) tokenResponse.toSuccessResponse(); UserInfoRequest ur = new UserInfoRequest(provider.userinfoUri, successResponse.getOIDCTokens().getBearerAccessToken()); UserInfoResponse userInfoResponse = UserInfoResponse.parse(ur.toHTTPRequest().send()); - if (! userInfoResponse.indicatesSuccess()) { + if (!userInfoResponse.indicatesSuccess()) { System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getCode()); System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getDescription()); return HttpResponse.ofCode(500).withHtml("Failed to query userInfo:\n\n" + userInfoResponse.toErrorResponse().getErrorObject().getDescription()); @@ -666,14 +667,15 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher { return new CustomServletDecorator(router); } - private static OidcProvider getOidcProvider(String provider){ - for(int i = 0; i < Constants.OIDC_PROVIDERS.size(); i++) { + private static OidcProvider getOidcProvider(String provider) { + for (int i = 0; i < Constants.OIDC_PROVIDERS.size(); i++) { OidcProvider curr = Constants.OIDC_PROVIDERS.get(i); - if(curr == null || !curr.name.equals(provider)) continue; + if (curr == null || !curr.name.equals(provider)) continue; return curr; } return null; } + private static String[] getArray(String s) { if (s == null) { diff --git a/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java b/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java index 8788ce2..7059bc2 100644 --- a/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java +++ b/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java @@ -108,32 +108,34 @@ public class UserHandlers { return null; } } + public static String oidcCallbackResponse(String provider, String uid) { try (Session s = DatabaseSessionFactory.createSession()) { - String dbName = provider + "-" + uid; - CriteriaBuilder cb = s.getCriteriaBuilder(); - CriteriaQuery cr = cb.createQuery(User.class); - Root root = cr.from(User.class); - cr.select(root).where(root.get("username").in( - dbName - )); + String dbName = provider + "-" + uid; + CriteriaBuilder cb = s.getCriteriaBuilder(); + CriteriaQuery cr = cb.createQuery(User.class); + Root root = cr.from(User.class); + cr.select(root).where(root.get("username").in( + dbName + )); - User dbuser = s.createQuery(cr).uniqueResult(); + User dbuser = s.createQuery(cr).uniqueResult(); - if (dbuser == null) { - User newuser = new User(dbName, "", Set.of()); + if (dbuser == null) { + User newuser = new User(dbName, "", Set.of()); - var tr = s.beginTransaction(); - s.persist(newuser); - tr.commit(); + var tr = s.beginTransaction(); + s.persist(newuser); + tr.commit(); - return newuser.getSessionId(); - } - return dbuser.getSessionId(); + return newuser.getSessionId(); } + return dbuser.getSessionId(); + } } + public static byte[] deleteUserResponse(String session, String pass) throws IOException { if (StringUtils.isBlank(session)) ExceptionHandler.throwErrorResponse(new InvalidRequestResponse("session is a required parameter")); diff --git a/src/main/java/me/kavin/piped/utils/obj/OidcProvider.java b/src/main/java/me/kavin/piped/utils/obj/OidcProvider.java index 79216ea..aedce63 100644 --- a/src/main/java/me/kavin/piped/utils/obj/OidcProvider.java +++ b/src/main/java/me/kavin/piped/utils/obj/OidcProvider.java @@ -22,7 +22,7 @@ public class OidcProvider { this.authUri = new URI(authUri); this.tokenUri = new URI(tokenUri); this.userinfoUri = new URI(userinfoUri); - } catch(URISyntaxException e) { + } catch (URISyntaxException e) { System.err.println("Malformed URI for oidc provider '" + name + "' found."); System.exit(1); }