Implement oidc

This commit is contained in:
Jeidnx 2023-06-18 18:28:18 +02:00
parent 73b303ffe6
commit 604fa651fc
No known key found for this signature in database
GPG Key ID: 0E9E697B7E99DF39
7 changed files with 195 additions and 4 deletions

View File

@ -42,6 +42,7 @@ dependencies {
implementation 'io.sentry:sentry:6.23.0' implementation 'io.sentry:sentry:6.23.0'
implementation 'rocks.kavin:reqwest4j:1.0.4' implementation 'rocks.kavin:reqwest4j:1.0.4'
implementation 'io.minio:minio:8.5.3' implementation 'io.minio:minio:8.5.3'
implementation 'com.nimbusds:oauth2-oidc-sdk:10.9.1'
} }
shadowJar { shadowJar {

View File

@ -79,3 +79,9 @@ hibernate.connection.password:changeme
# Frontend configuration # Frontend configuration
#frontend.statusPageUrl:https://kavin.rocks #frontend.statusPageUrl:https://kavin.rocks
#frontend.donationUrl:https://kavin.rocks #frontend.donationUrl:https://kavin.rocks
# Oidc configuration
#oidc.provider.INSERT_HERE.name:INSERT_HERE
#oidc.provider.INSERT_HERE.clientId:INSERT_HERE
#oidc.provider.INSERT_HERE.clientSecret:INSERT_HERE
#oidc.provider.INSERT_HERE.authUrl:INSERT_HERE

View File

@ -3,12 +3,14 @@ package me.kavin.piped.consts;
import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper; import com.fasterxml.jackson.databind.json.JsonMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import io.minio.MinioClient; import io.minio.MinioClient;
import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap;
import me.kavin.piped.utils.PageMixin; import me.kavin.piped.utils.PageMixin;
import me.kavin.piped.utils.RequestUtils; import me.kavin.piped.utils.RequestUtils;
import me.kavin.piped.utils.obj.OidcProvider;
import me.kavin.piped.utils.resp.ListLinkHandlerMixin; import me.kavin.piped.utils.resp.ListLinkHandlerMixin;
import okhttp3.OkHttpClient; import okhttp3.OkHttpClient;
import okhttp3.brotli.BrotliInterceptor; import okhttp3.brotli.BrotliInterceptor;
@ -24,6 +26,7 @@ import java.io.File;
import java.io.FileReader; import java.io.FileReader;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.ProxySelector; import java.net.ProxySelector;
import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Properties; import java.util.Properties;
import java.util.regex.Pattern; import java.util.regex.Pattern;
@ -99,6 +102,7 @@ public class Constants {
public static final String YOUTUBE_COUNTRY; public static final String YOUTUBE_COUNTRY;
public static final String VERSION; public static final String VERSION;
public static final LinkedList<OidcProvider> OIDC_PROVIDERS;
public static final ObjectMapper mapper = JsonMapper.builder() public static final ObjectMapper mapper = JsonMapper.builder()
.addMixIn(Page.class, PageMixin.class) .addMixIn(Page.class, PageMixin.class)
@ -162,12 +166,34 @@ public class Constants {
MATRIX_SERVER = getProperty(prop, "MATRIX_SERVER", "https://matrix-client.matrix.org"); MATRIX_SERVER = getProperty(prop, "MATRIX_SERVER", "https://matrix-client.matrix.org");
MATRIX_TOKEN = getProperty(prop, "MATRIX_TOKEN"); MATRIX_TOKEN = getProperty(prop, "MATRIX_TOKEN");
GEO_RESTRICTION_CHECKER_URL = getProperty(prop, "GEO_RESTRICTION_CHECKER_URL"); GEO_RESTRICTION_CHECKER_URL = getProperty(prop, "GEO_RESTRICTION_CHECKER_URL");
OIDC_PROVIDERS = new LinkedList<>();
ArrayNode providerNames = frontendProperties.putArray("oidcProviders");
prop.forEach((_key, _value) -> { prop.forEach((_key, _value) -> {
String key = String.valueOf(_key), value = String.valueOf(_value); String key = String.valueOf(_key), value = String.valueOf(_value);
if (key.startsWith("hibernate")) if (key.startsWith("hibernate"))
hibernateProperties.put(key, value); hibernateProperties.put(key, value);
else if (key.startsWith("frontend.")) else if (key.startsWith("frontend."))
frontendProperties.put(StringUtils.substringAfter(key, "frontend."), value); frontendProperties.put(StringUtils.substringAfter(key, "frontend."), value);
else if (key.startsWith("oidc.provider")) {
String[] split = key.split("\\.");
if (split.length != 4 || !split[3].equals("name")) return;
try {
OIDC_PROVIDERS.add(new OidcProvider(
value,
getProperty(prop, "oidc.provider." + value + ".clientId"),
getProperty(prop, "oidc.provider." + value + ".clientSecret"),
getProperty(prop, "oidc.provider." + value + ".authUrl"),
getProperty(prop, "oidc.provider." + value + ".tokenUrl"),
getProperty(prop, "oidc.provider." + value + ".userinfoUrl")
));
} catch (Exception e) {
System.err.println("Error while getting properties for oidc provider '" + value + "'");
throw new RuntimeException(e);
}
providerNames.add(value);
}
}); });
frontendProperties.put("imageProxyUrl", IMAGE_PROXY_PART); frontendProperties.put("imageProxyUrl", IMAGE_PROXY_PART);
frontendProperties.putArray("countries").addAll( frontendProperties.putArray("countries").addAll(

View File

@ -2,6 +2,9 @@ package me.kavin.piped.server;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.JsonNode;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.openid.connect.sdk.claims.UserInfo;
import com.rometools.rome.feed.synd.SyndFeed; import com.rometools.rome.feed.synd.SyndFeed;
import com.rometools.rome.io.SyndFeedInput; import com.rometools.rome.io.SyndFeedInput;
import io.activej.config.Config; import io.activej.config.Config;
@ -19,7 +22,9 @@ import me.kavin.piped.server.handlers.auth.FeedHandlers;
import me.kavin.piped.server.handlers.auth.StorageHandlers; import me.kavin.piped.server.handlers.auth.StorageHandlers;
import me.kavin.piped.server.handlers.auth.UserHandlers; import me.kavin.piped.server.handlers.auth.UserHandlers;
import me.kavin.piped.utils.*; import me.kavin.piped.utils.*;
import me.kavin.piped.utils.ErrorResponse;
import me.kavin.piped.utils.obj.MatrixHelper; import me.kavin.piped.utils.obj.MatrixHelper;
import me.kavin.piped.utils.obj.OidcProvider;
import me.kavin.piped.utils.obj.federation.FederatedVideoInfo; import me.kavin.piped.utils.obj.federation.FederatedVideoInfo;
import me.kavin.piped.utils.resp.*; import me.kavin.piped.utils.resp.*;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@ -30,12 +35,18 @@ import org.jetbrains.annotations.NotNull;
import org.schabi.newpipe.extractor.exceptions.ParsingException; import org.schabi.newpipe.extractor.exceptions.ParsingException;
import org.schabi.newpipe.extractor.localization.DateWrapper; import org.schabi.newpipe.extractor.localization.DateWrapper;
import org.xml.sax.InputSource; import org.xml.sax.InputSource;
import com.nimbusds.oauth2.sdk.*;
import com.nimbusds.openid.connect.sdk.*;
import com.nimbusds.oauth2.sdk.id.*;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.URI;
import java.util.LinkedList;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import static io.activej.config.converter.ConfigConverters.ofInetSocketAddress; import static io.activej.config.converter.ConfigConverters.ofInetSocketAddress;
import static io.activej.http.HttpHeaders.*; import static io.activej.http.HttpHeaders.*;
@ -293,6 +304,88 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher {
LoginRequest.class); LoginRequest.class);
return getJsonResponse(UserHandlers.registerResponse(body.username, body.password), return getJsonResponse(UserHandlers.registerResponse(body.username, body.password),
"private"); "private");
} catch (Exception e) {
return getErrorResponse(e, request.getPath());
}
})).map(GET, "/oidc/:provider/:function", AsyncServlet.ofBlocking(executor, request -> {
try {
String function = request.getPathParameter("function");
OidcProvider provider = findOidcProvider(request.getPathParameter("provider"), Constants.OIDC_PROVIDERS);
if(provider == null)
return HttpResponse.ofCode(500).withHtml("Can't find the provider on the server.");
URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback");
switch (function) {
case "login" -> {
State state = new State();
Nonce nonce = new Nonce();
AuthenticationRequest oidcRequest = new AuthenticationRequest.Builder(
new ResponseType("code"),
new Scope("openid"),
provider.clientID,
callback)
.endpointURI(provider.authUri)
.state(state)
.nonce(nonce)
.build();
return HttpResponse.redirect302(oidcRequest.toURI().toString());
}
case "callback" -> {
ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret);
AuthenticationResponse response = AuthenticationResponseParser.parse(
URI.create(request.getFullUrl())
);
if (response instanceof AuthenticationErrorResponse) {
// The OpenID provider returned an error
System.err.println(response.toErrorResponse().getErrorObject());
return HttpResponse.ofCode(500).withHtml("OpenID provider returned an error:\n\n" + response.toErrorResponse().getErrorObject().toString());
}
AuthenticationSuccessResponse sr = response.toSuccessResponse();
AuthorizationCode code = sr.getAuthorizationCode();
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(
code, callback
);
TokenRequest tr = new TokenRequest(provider.tokenUri, clientAuth, codeGrant);
TokenResponse tokenResponse = OIDCTokenResponseParser.parse(tr.toHTTPRequest().send());
if (! tokenResponse.indicatesSuccess()) {
TokenErrorResponse errorResponse = tokenResponse.toErrorResponse();
return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription());
}
OIDCTokenResponse successResponse = (OIDCTokenResponse)tokenResponse.toSuccessResponse();
UserInfoRequest ur = new UserInfoRequest(provider.userinfoUri, successResponse.getOIDCTokens().getBearerAccessToken());
UserInfoResponse userInfoResponse = UserInfoResponse.parse(ur.toHTTPRequest().send());
if (! userInfoResponse.indicatesSuccess()) {
System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getCode());
System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getDescription());
return HttpResponse.ofCode(500).withHtml("Failed to query userInfo:\n\n" + userInfoResponse.toErrorResponse().getErrorObject().getDescription());
}
UserInfo userInfo = userInfoResponse.toSuccessResponse().getUserInfo();
String sessionId = UserHandlers.oidcCallbackResponse(provider.name, userInfo.getSubject().toString());
return HttpResponse.redirect302(Constants.FRONTEND_URL + "/login?session=" + sessionId);
}
default -> {
return HttpResponse.ofCode(500).withHtml("Invalid function `" + function + "`.");
}
}
} catch (Exception e) { } catch (Exception e) {
return getErrorResponse(e, request.getPath()); return getErrorResponse(e, request.getPath());
} }
@ -542,6 +635,14 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher {
return new CustomServletDecorator(router); return new CustomServletDecorator(router);
} }
private static OidcProvider findOidcProvider(String provider, LinkedList<OidcProvider> list){
for(int i = 0; i < list.size(); i++) {
OidcProvider curr = list.get(i);
if(curr == null || !curr.name.equals(provider)) continue;
return curr;
}
return null;
}
private static String[] getArray(String s) { private static String[] getArray(String s) {
if (s == null) { if (s == null) {

View File

@ -107,11 +107,36 @@ public class UserHandlers {
return null; return null;
} }
} }
public static String oidcCallbackResponse(String provider, String uid) {
try (Session s = DatabaseSessionFactory.createSession()) {
String dbName = provider + "-" + uid;
System.out.println(dbName); //TODO:
CriteriaBuilder cb = s.getCriteriaBuilder();
CriteriaQuery<User> cr = cb.createQuery(User.class);
Root<User> root = cr.from(User.class);
cr.select(root).where(root.get("username").in(
dbName
));
User dbuser = s.createQuery(cr).uniqueResult();
if (dbuser == null) {
User newuser = new User(dbName, "", Set.of());
var tr = s.beginTransaction();
s.persist(newuser);
tr.commit();
return newuser.getSessionId();
}
return dbuser.getSessionId();
}
}
public static byte[] deleteUserResponse(String session, String pass) throws IOException { public static byte[] deleteUserResponse(String session, String pass) throws IOException {
if (StringUtils.isBlank(session))
if (StringUtils.isBlank(session) || StringUtils.isBlank(pass)) ExceptionHandler.throwErrorResponse(new InvalidRequestResponse("session is a required parameter"));
ExceptionHandler.throwErrorResponse(new InvalidRequestResponse("session and password are required parameters"));
try (Session s = DatabaseSessionFactory.createSession()) { try (Session s = DatabaseSessionFactory.createSession()) {
User user = DatabaseHelper.getUserFromSession(session); User user = DatabaseHelper.getUserFromSession(session);
@ -121,6 +146,13 @@ public class UserHandlers {
String hash = user.getPassword(); String hash = user.getPassword();
if (hash.equals("")) {
//TODO: Authorize against oidc provider before deletion
var tr = s.beginTransaction();
s.remove(user);
tr.commit();
return mapper.writeValueAsBytes(new DeleteUserResponse(user.getUsername()));
}
if (!hashMatch(hash, pass)) if (!hashMatch(hash, pass))
ExceptionHandler.throwErrorResponse(new IncorrectCredentialsResponse()); ExceptionHandler.throwErrorResponse(new IncorrectCredentialsResponse());

View File

@ -0,0 +1,25 @@
package me.kavin.piped.utils.obj;
import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.oauth2.sdk.id.ClientID;
import java.net.URI;
import java.net.URISyntaxException;
public class OidcProvider {
public String name;
public ClientID clientID;
public Secret clientSecret;
public URI authUri;
public URI tokenUri;
public URI userinfoUri;
public OidcProvider(String name, String clientID, String clientSecret, String authUri, String tokenUri, String userinfoUri) throws URISyntaxException {
this.name = name;
this.clientID = new ClientID(clientID);
this.clientSecret = new Secret(clientSecret);
this.authUri = new URI(authUri);
this.tokenUri = new URI(tokenUri);
this.userinfoUri = new URI(userinfoUri);
}
}

View File

@ -20,7 +20,7 @@ public class User implements Serializable {
@Column(name = "id") @Column(name = "id")
private long id; private long id;
@Column(name = "username", unique = true, length = 24) @Column(name = "username", unique = true, length = 32)
private String username; private String username;
@Column(name = "password", columnDefinition = "text") @Column(name = "password", columnDefinition = "text")