From 604fa651fcc14f893c96b8280928dca3fd9438e7 Mon Sep 17 00:00:00 2001 From: Jeidnx Date: Sun, 18 Jun 2023 18:28:18 +0200 Subject: [PATCH] Implement oidc --- build.gradle | 1 + config.properties | 6 ++ .../java/me/kavin/piped/consts/Constants.java | 26 +++++ .../me/kavin/piped/server/ServerLauncher.java | 101 ++++++++++++++++++ .../server/handlers/auth/UserHandlers.java | 38 ++++++- .../kavin/piped/utils/obj/OidcProvider.java | 25 +++++ .../me/kavin/piped/utils/obj/db/User.java | 2 +- 7 files changed, 195 insertions(+), 4 deletions(-) create mode 100644 src/main/java/me/kavin/piped/utils/obj/OidcProvider.java diff --git a/build.gradle b/build.gradle index ad13d64..a59c346 100644 --- a/build.gradle +++ b/build.gradle @@ -42,6 +42,7 @@ dependencies { implementation 'io.sentry:sentry:6.23.0' implementation 'rocks.kavin:reqwest4j:1.0.4' implementation 'io.minio:minio:8.5.3' + implementation 'com.nimbusds:oauth2-oidc-sdk:10.9.1' } shadowJar { diff --git a/config.properties b/config.properties index 5b147b8..bd6efcb 100644 --- a/config.properties +++ b/config.properties @@ -79,3 +79,9 @@ hibernate.connection.password:changeme # Frontend configuration #frontend.statusPageUrl:https://kavin.rocks #frontend.donationUrl:https://kavin.rocks + +# Oidc configuration +#oidc.provider.INSERT_HERE.name:INSERT_HERE +#oidc.provider.INSERT_HERE.clientId:INSERT_HERE +#oidc.provider.INSERT_HERE.clientSecret:INSERT_HERE +#oidc.provider.INSERT_HERE.authUrl:INSERT_HERE diff --git a/src/main/java/me/kavin/piped/consts/Constants.java b/src/main/java/me/kavin/piped/consts/Constants.java index 478b99e..ef2fd82 100644 --- a/src/main/java/me/kavin/piped/consts/Constants.java +++ b/src/main/java/me/kavin/piped/consts/Constants.java @@ -3,12 +3,14 @@ package me.kavin.piped.consts; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.json.JsonMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; import io.minio.MinioClient; import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; import me.kavin.piped.utils.PageMixin; import me.kavin.piped.utils.RequestUtils; +import me.kavin.piped.utils.obj.OidcProvider; import me.kavin.piped.utils.resp.ListLinkHandlerMixin; import okhttp3.OkHttpClient; import okhttp3.brotli.BrotliInterceptor; @@ -24,6 +26,7 @@ import java.io.File; import java.io.FileReader; import java.net.InetSocketAddress; import java.net.ProxySelector; +import java.util.LinkedList; import java.util.List; import java.util.Properties; import java.util.regex.Pattern; @@ -99,6 +102,7 @@ public class Constants { public static final String YOUTUBE_COUNTRY; public static final String VERSION; + public static final LinkedList OIDC_PROVIDERS; public static final ObjectMapper mapper = JsonMapper.builder() .addMixIn(Page.class, PageMixin.class) @@ -162,12 +166,34 @@ public class Constants { MATRIX_SERVER = getProperty(prop, "MATRIX_SERVER", "https://matrix-client.matrix.org"); MATRIX_TOKEN = getProperty(prop, "MATRIX_TOKEN"); GEO_RESTRICTION_CHECKER_URL = getProperty(prop, "GEO_RESTRICTION_CHECKER_URL"); + + OIDC_PROVIDERS = new LinkedList<>(); + ArrayNode providerNames = frontendProperties.putArray("oidcProviders"); prop.forEach((_key, _value) -> { String key = String.valueOf(_key), value = String.valueOf(_value); if (key.startsWith("hibernate")) hibernateProperties.put(key, value); else if (key.startsWith("frontend.")) frontendProperties.put(StringUtils.substringAfter(key, "frontend."), value); + else if (key.startsWith("oidc.provider")) { + String[] split = key.split("\\."); + if (split.length != 4 || !split[3].equals("name")) return; + + try { + OIDC_PROVIDERS.add(new OidcProvider( + value, + getProperty(prop, "oidc.provider." + value + ".clientId"), + getProperty(prop, "oidc.provider." + value + ".clientSecret"), + getProperty(prop, "oidc.provider." + value + ".authUrl"), + getProperty(prop, "oidc.provider." + value + ".tokenUrl"), + getProperty(prop, "oidc.provider." + value + ".userinfoUrl") + )); + } catch (Exception e) { + System.err.println("Error while getting properties for oidc provider '" + value + "'"); + throw new RuntimeException(e); + } + providerNames.add(value); + } }); frontendProperties.put("imageProxyUrl", IMAGE_PROXY_PART); frontendProperties.putArray("countries").addAll( diff --git a/src/main/java/me/kavin/piped/server/ServerLauncher.java b/src/main/java/me/kavin/piped/server/ServerLauncher.java index 44e83ad..6a860fc 100644 --- a/src/main/java/me/kavin/piped/server/ServerLauncher.java +++ b/src/main/java/me/kavin/piped/server/ServerLauncher.java @@ -2,6 +2,9 @@ package me.kavin.piped.server; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; +import com.nimbusds.oauth2.sdk.auth.ClientAuthentication; +import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; +import com.nimbusds.openid.connect.sdk.claims.UserInfo; import com.rometools.rome.feed.synd.SyndFeed; import com.rometools.rome.io.SyndFeedInput; import io.activej.config.Config; @@ -19,7 +22,9 @@ import me.kavin.piped.server.handlers.auth.FeedHandlers; import me.kavin.piped.server.handlers.auth.StorageHandlers; import me.kavin.piped.server.handlers.auth.UserHandlers; import me.kavin.piped.utils.*; +import me.kavin.piped.utils.ErrorResponse; import me.kavin.piped.utils.obj.MatrixHelper; +import me.kavin.piped.utils.obj.OidcProvider; import me.kavin.piped.utils.obj.federation.FederatedVideoInfo; import me.kavin.piped.utils.resp.*; import org.apache.commons.lang3.StringUtils; @@ -30,12 +35,18 @@ import org.jetbrains.annotations.NotNull; import org.schabi.newpipe.extractor.exceptions.ParsingException; import org.schabi.newpipe.extractor.localization.DateWrapper; import org.xml.sax.InputSource; +import com.nimbusds.oauth2.sdk.*; +import com.nimbusds.openid.connect.sdk.*; +import com.nimbusds.oauth2.sdk.id.*; import java.io.ByteArrayInputStream; import java.net.InetSocketAddress; +import java.net.URI; +import java.util.LinkedList; import java.util.Objects; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; +import java.util.regex.Pattern; import static io.activej.config.converter.ConfigConverters.ofInetSocketAddress; import static io.activej.http.HttpHeaders.*; @@ -293,6 +304,88 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher { LoginRequest.class); return getJsonResponse(UserHandlers.registerResponse(body.username, body.password), "private"); + } catch (Exception e) { + return getErrorResponse(e, request.getPath()); + } + })).map(GET, "/oidc/:provider/:function", AsyncServlet.ofBlocking(executor, request -> { + try { + String function = request.getPathParameter("function"); + + OidcProvider provider = findOidcProvider(request.getPathParameter("provider"), Constants.OIDC_PROVIDERS); + if(provider == null) + return HttpResponse.ofCode(500).withHtml("Can't find the provider on the server."); + + URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback"); + + switch (function) { + case "login" -> { + + State state = new State(); + Nonce nonce = new Nonce(); + + AuthenticationRequest oidcRequest = new AuthenticationRequest.Builder( + new ResponseType("code"), + new Scope("openid"), + provider.clientID, + callback) + .endpointURI(provider.authUri) + .state(state) + .nonce(nonce) + .build(); + + return HttpResponse.redirect302(oidcRequest.toURI().toString()); + } + case "callback" -> { + ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret); + + AuthenticationResponse response = AuthenticationResponseParser.parse( + URI.create(request.getFullUrl()) + ); + + if (response instanceof AuthenticationErrorResponse) { + // The OpenID provider returned an error + System.err.println(response.toErrorResponse().getErrorObject()); + return HttpResponse.ofCode(500).withHtml("OpenID provider returned an error:\n\n" + response.toErrorResponse().getErrorObject().toString()); + } + AuthenticationSuccessResponse sr = response.toSuccessResponse(); + + AuthorizationCode code = sr.getAuthorizationCode(); + AuthorizationGrant codeGrant = new AuthorizationCodeGrant( + code, callback + ); + + TokenRequest tr = new TokenRequest(provider.tokenUri, clientAuth, codeGrant); + TokenResponse tokenResponse = OIDCTokenResponseParser.parse(tr.toHTTPRequest().send()); + + if (! tokenResponse.indicatesSuccess()) { + TokenErrorResponse errorResponse = tokenResponse.toErrorResponse(); + return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription()); + } + + OIDCTokenResponse successResponse = (OIDCTokenResponse)tokenResponse.toSuccessResponse(); + + + UserInfoRequest ur = new UserInfoRequest(provider.userinfoUri, successResponse.getOIDCTokens().getBearerAccessToken()); + UserInfoResponse userInfoResponse = UserInfoResponse.parse(ur.toHTTPRequest().send()); + + if (! userInfoResponse.indicatesSuccess()) { + System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getCode()); + System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getDescription()); + return HttpResponse.ofCode(500).withHtml("Failed to query userInfo:\n\n" + userInfoResponse.toErrorResponse().getErrorObject().getDescription()); + } + + UserInfo userInfo = userInfoResponse.toSuccessResponse().getUserInfo(); + + String sessionId = UserHandlers.oidcCallbackResponse(provider.name, userInfo.getSubject().toString()); + + return HttpResponse.redirect302(Constants.FRONTEND_URL + "/login?session=" + sessionId); + } + default -> { + return HttpResponse.ofCode(500).withHtml("Invalid function `" + function + "`."); + } + } + + } catch (Exception e) { return getErrorResponse(e, request.getPath()); } @@ -542,6 +635,14 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher { return new CustomServletDecorator(router); } + private static OidcProvider findOidcProvider(String provider, LinkedList list){ + for(int i = 0; i < list.size(); i++) { + OidcProvider curr = list.get(i); + if(curr == null || !curr.name.equals(provider)) continue; + return curr; + } + return null; + } private static String[] getArray(String s) { if (s == null) { diff --git a/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java b/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java index ad36481..1c7c21c 100644 --- a/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java +++ b/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java @@ -107,11 +107,36 @@ public class UserHandlers { return null; } } + public static String oidcCallbackResponse(String provider, String uid) { + try (Session s = DatabaseSessionFactory.createSession()) { + String dbName = provider + "-" + uid; + System.out.println(dbName); //TODO: + CriteriaBuilder cb = s.getCriteriaBuilder(); + CriteriaQuery cr = cb.createQuery(User.class); + Root root = cr.from(User.class); + cr.select(root).where(root.get("username").in( + dbName + )); + User dbuser = s.createQuery(cr).uniqueResult(); + + if (dbuser == null) { + User newuser = new User(dbName, "", Set.of()); + + var tr = s.beginTransaction(); + s.persist(newuser); + tr.commit(); + + + return newuser.getSessionId(); + } + return dbuser.getSessionId(); + } + + } public static byte[] deleteUserResponse(String session, String pass) throws IOException { - - if (StringUtils.isBlank(session) || StringUtils.isBlank(pass)) - ExceptionHandler.throwErrorResponse(new InvalidRequestResponse("session and password are required parameters")); + if (StringUtils.isBlank(session)) + ExceptionHandler.throwErrorResponse(new InvalidRequestResponse("session is a required parameter")); try (Session s = DatabaseSessionFactory.createSession()) { User user = DatabaseHelper.getUserFromSession(session); @@ -121,6 +146,13 @@ public class UserHandlers { String hash = user.getPassword(); + if (hash.equals("")) { + //TODO: Authorize against oidc provider before deletion + var tr = s.beginTransaction(); + s.remove(user); + tr.commit(); + return mapper.writeValueAsBytes(new DeleteUserResponse(user.getUsername())); + } if (!hashMatch(hash, pass)) ExceptionHandler.throwErrorResponse(new IncorrectCredentialsResponse()); diff --git a/src/main/java/me/kavin/piped/utils/obj/OidcProvider.java b/src/main/java/me/kavin/piped/utils/obj/OidcProvider.java new file mode 100644 index 0000000..a869c2a --- /dev/null +++ b/src/main/java/me/kavin/piped/utils/obj/OidcProvider.java @@ -0,0 +1,25 @@ +package me.kavin.piped.utils.obj; + +import com.nimbusds.oauth2.sdk.auth.Secret; +import com.nimbusds.oauth2.sdk.id.ClientID; + +import java.net.URI; +import java.net.URISyntaxException; + +public class OidcProvider { + public String name; + public ClientID clientID; + public Secret clientSecret; + public URI authUri; + public URI tokenUri; + public URI userinfoUri; + + public OidcProvider(String name, String clientID, String clientSecret, String authUri, String tokenUri, String userinfoUri) throws URISyntaxException { + this.name = name; + this.clientID = new ClientID(clientID); + this.clientSecret = new Secret(clientSecret); + this.authUri = new URI(authUri); + this.tokenUri = new URI(tokenUri); + this.userinfoUri = new URI(userinfoUri); + } +} diff --git a/src/main/java/me/kavin/piped/utils/obj/db/User.java b/src/main/java/me/kavin/piped/utils/obj/db/User.java index fe3ceb4..1bf1b42 100644 --- a/src/main/java/me/kavin/piped/utils/obj/db/User.java +++ b/src/main/java/me/kavin/piped/utils/obj/db/User.java @@ -20,7 +20,7 @@ public class User implements Serializable { @Column(name = "id") private long id; - @Column(name = "username", unique = true, length = 24) + @Column(name = "username", unique = true, length = 32) private String username; @Column(name = "password", columnDefinition = "text")