Implement oidc

This commit is contained in:
Jeidnx 2023-06-18 18:28:18 +02:00
parent 73b303ffe6
commit 604fa651fc
No known key found for this signature in database
GPG Key ID: 0E9E697B7E99DF39
7 changed files with 195 additions and 4 deletions

View File

@ -42,6 +42,7 @@ dependencies {
implementation 'io.sentry:sentry:6.23.0'
implementation 'rocks.kavin:reqwest4j:1.0.4'
implementation 'io.minio:minio:8.5.3'
implementation 'com.nimbusds:oauth2-oidc-sdk:10.9.1'
}
shadowJar {

View File

@ -79,3 +79,9 @@ hibernate.connection.password:changeme
# Frontend configuration
#frontend.statusPageUrl:https://kavin.rocks
#frontend.donationUrl:https://kavin.rocks
# Oidc configuration
#oidc.provider.INSERT_HERE.name:INSERT_HERE
#oidc.provider.INSERT_HERE.clientId:INSERT_HERE
#oidc.provider.INSERT_HERE.clientSecret:INSERT_HERE
#oidc.provider.INSERT_HERE.authUrl:INSERT_HERE

View File

@ -3,12 +3,14 @@ package me.kavin.piped.consts;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.minio.MinioClient;
import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap;
import me.kavin.piped.utils.PageMixin;
import me.kavin.piped.utils.RequestUtils;
import me.kavin.piped.utils.obj.OidcProvider;
import me.kavin.piped.utils.resp.ListLinkHandlerMixin;
import okhttp3.OkHttpClient;
import okhttp3.brotli.BrotliInterceptor;
@ -24,6 +26,7 @@ import java.io.File;
import java.io.FileReader;
import java.net.InetSocketAddress;
import java.net.ProxySelector;
import java.util.LinkedList;
import java.util.List;
import java.util.Properties;
import java.util.regex.Pattern;
@ -99,6 +102,7 @@ public class Constants {
public static final String YOUTUBE_COUNTRY;
public static final String VERSION;
public static final LinkedList<OidcProvider> OIDC_PROVIDERS;
public static final ObjectMapper mapper = JsonMapper.builder()
.addMixIn(Page.class, PageMixin.class)
@ -162,12 +166,34 @@ public class Constants {
MATRIX_SERVER = getProperty(prop, "MATRIX_SERVER", "https://matrix-client.matrix.org");
MATRIX_TOKEN = getProperty(prop, "MATRIX_TOKEN");
GEO_RESTRICTION_CHECKER_URL = getProperty(prop, "GEO_RESTRICTION_CHECKER_URL");
OIDC_PROVIDERS = new LinkedList<>();
ArrayNode providerNames = frontendProperties.putArray("oidcProviders");
prop.forEach((_key, _value) -> {
String key = String.valueOf(_key), value = String.valueOf(_value);
if (key.startsWith("hibernate"))
hibernateProperties.put(key, value);
else if (key.startsWith("frontend."))
frontendProperties.put(StringUtils.substringAfter(key, "frontend."), value);
else if (key.startsWith("oidc.provider")) {
String[] split = key.split("\\.");
if (split.length != 4 || !split[3].equals("name")) return;
try {
OIDC_PROVIDERS.add(new OidcProvider(
value,
getProperty(prop, "oidc.provider." + value + ".clientId"),
getProperty(prop, "oidc.provider." + value + ".clientSecret"),
getProperty(prop, "oidc.provider." + value + ".authUrl"),
getProperty(prop, "oidc.provider." + value + ".tokenUrl"),
getProperty(prop, "oidc.provider." + value + ".userinfoUrl")
));
} catch (Exception e) {
System.err.println("Error while getting properties for oidc provider '" + value + "'");
throw new RuntimeException(e);
}
providerNames.add(value);
}
});
frontendProperties.put("imageProxyUrl", IMAGE_PROXY_PART);
frontendProperties.putArray("countries").addAll(

View File

@ -2,6 +2,9 @@ package me.kavin.piped.server;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.openid.connect.sdk.claims.UserInfo;
import com.rometools.rome.feed.synd.SyndFeed;
import com.rometools.rome.io.SyndFeedInput;
import io.activej.config.Config;
@ -19,7 +22,9 @@ import me.kavin.piped.server.handlers.auth.FeedHandlers;
import me.kavin.piped.server.handlers.auth.StorageHandlers;
import me.kavin.piped.server.handlers.auth.UserHandlers;
import me.kavin.piped.utils.*;
import me.kavin.piped.utils.ErrorResponse;
import me.kavin.piped.utils.obj.MatrixHelper;
import me.kavin.piped.utils.obj.OidcProvider;
import me.kavin.piped.utils.obj.federation.FederatedVideoInfo;
import me.kavin.piped.utils.resp.*;
import org.apache.commons.lang3.StringUtils;
@ -30,12 +35,18 @@ import org.jetbrains.annotations.NotNull;
import org.schabi.newpipe.extractor.exceptions.ParsingException;
import org.schabi.newpipe.extractor.localization.DateWrapper;
import org.xml.sax.InputSource;
import com.nimbusds.oauth2.sdk.*;
import com.nimbusds.openid.connect.sdk.*;
import com.nimbusds.oauth2.sdk.id.*;
import java.io.ByteArrayInputStream;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.LinkedList;
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import static io.activej.config.converter.ConfigConverters.ofInetSocketAddress;
import static io.activej.http.HttpHeaders.*;
@ -293,6 +304,88 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher {
LoginRequest.class);
return getJsonResponse(UserHandlers.registerResponse(body.username, body.password),
"private");
} catch (Exception e) {
return getErrorResponse(e, request.getPath());
}
})).map(GET, "/oidc/:provider/:function", AsyncServlet.ofBlocking(executor, request -> {
try {
String function = request.getPathParameter("function");
OidcProvider provider = findOidcProvider(request.getPathParameter("provider"), Constants.OIDC_PROVIDERS);
if(provider == null)
return HttpResponse.ofCode(500).withHtml("Can't find the provider on the server.");
URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback");
switch (function) {
case "login" -> {
State state = new State();
Nonce nonce = new Nonce();
AuthenticationRequest oidcRequest = new AuthenticationRequest.Builder(
new ResponseType("code"),
new Scope("openid"),
provider.clientID,
callback)
.endpointURI(provider.authUri)
.state(state)
.nonce(nonce)
.build();
return HttpResponse.redirect302(oidcRequest.toURI().toString());
}
case "callback" -> {
ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret);
AuthenticationResponse response = AuthenticationResponseParser.parse(
URI.create(request.getFullUrl())
);
if (response instanceof AuthenticationErrorResponse) {
// The OpenID provider returned an error
System.err.println(response.toErrorResponse().getErrorObject());
return HttpResponse.ofCode(500).withHtml("OpenID provider returned an error:\n\n" + response.toErrorResponse().getErrorObject().toString());
}
AuthenticationSuccessResponse sr = response.toSuccessResponse();
AuthorizationCode code = sr.getAuthorizationCode();
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(
code, callback
);
TokenRequest tr = new TokenRequest(provider.tokenUri, clientAuth, codeGrant);
TokenResponse tokenResponse = OIDCTokenResponseParser.parse(tr.toHTTPRequest().send());
if (! tokenResponse.indicatesSuccess()) {
TokenErrorResponse errorResponse = tokenResponse.toErrorResponse();
return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription());
}
OIDCTokenResponse successResponse = (OIDCTokenResponse)tokenResponse.toSuccessResponse();
UserInfoRequest ur = new UserInfoRequest(provider.userinfoUri, successResponse.getOIDCTokens().getBearerAccessToken());
UserInfoResponse userInfoResponse = UserInfoResponse.parse(ur.toHTTPRequest().send());
if (! userInfoResponse.indicatesSuccess()) {
System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getCode());
System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getDescription());
return HttpResponse.ofCode(500).withHtml("Failed to query userInfo:\n\n" + userInfoResponse.toErrorResponse().getErrorObject().getDescription());
}
UserInfo userInfo = userInfoResponse.toSuccessResponse().getUserInfo();
String sessionId = UserHandlers.oidcCallbackResponse(provider.name, userInfo.getSubject().toString());
return HttpResponse.redirect302(Constants.FRONTEND_URL + "/login?session=" + sessionId);
}
default -> {
return HttpResponse.ofCode(500).withHtml("Invalid function `" + function + "`.");
}
}
} catch (Exception e) {
return getErrorResponse(e, request.getPath());
}
@ -542,6 +635,14 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher {
return new CustomServletDecorator(router);
}
private static OidcProvider findOidcProvider(String provider, LinkedList<OidcProvider> list){
for(int i = 0; i < list.size(); i++) {
OidcProvider curr = list.get(i);
if(curr == null || !curr.name.equals(provider)) continue;
return curr;
}
return null;
}
private static String[] getArray(String s) {
if (s == null) {

View File

@ -107,11 +107,36 @@ public class UserHandlers {
return null;
}
}
public static String oidcCallbackResponse(String provider, String uid) {
try (Session s = DatabaseSessionFactory.createSession()) {
String dbName = provider + "-" + uid;
System.out.println(dbName); //TODO:
CriteriaBuilder cb = s.getCriteriaBuilder();
CriteriaQuery<User> cr = cb.createQuery(User.class);
Root<User> root = cr.from(User.class);
cr.select(root).where(root.get("username").in(
dbName
));
User dbuser = s.createQuery(cr).uniqueResult();
if (dbuser == null) {
User newuser = new User(dbName, "", Set.of());
var tr = s.beginTransaction();
s.persist(newuser);
tr.commit();
return newuser.getSessionId();
}
return dbuser.getSessionId();
}
}
public static byte[] deleteUserResponse(String session, String pass) throws IOException {
if (StringUtils.isBlank(session) || StringUtils.isBlank(pass))
ExceptionHandler.throwErrorResponse(new InvalidRequestResponse("session and password are required parameters"));
if (StringUtils.isBlank(session))
ExceptionHandler.throwErrorResponse(new InvalidRequestResponse("session is a required parameter"));
try (Session s = DatabaseSessionFactory.createSession()) {
User user = DatabaseHelper.getUserFromSession(session);
@ -121,6 +146,13 @@ public class UserHandlers {
String hash = user.getPassword();
if (hash.equals("")) {
//TODO: Authorize against oidc provider before deletion
var tr = s.beginTransaction();
s.remove(user);
tr.commit();
return mapper.writeValueAsBytes(new DeleteUserResponse(user.getUsername()));
}
if (!hashMatch(hash, pass))
ExceptionHandler.throwErrorResponse(new IncorrectCredentialsResponse());

View File

@ -0,0 +1,25 @@
package me.kavin.piped.utils.obj;
import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.oauth2.sdk.id.ClientID;
import java.net.URI;
import java.net.URISyntaxException;
public class OidcProvider {
public String name;
public ClientID clientID;
public Secret clientSecret;
public URI authUri;
public URI tokenUri;
public URI userinfoUri;
public OidcProvider(String name, String clientID, String clientSecret, String authUri, String tokenUri, String userinfoUri) throws URISyntaxException {
this.name = name;
this.clientID = new ClientID(clientID);
this.clientSecret = new Secret(clientSecret);
this.authUri = new URI(authUri);
this.tokenUri = new URI(tokenUri);
this.userinfoUri = new URI(userinfoUri);
}
}

View File

@ -20,7 +20,7 @@ public class User implements Serializable {
@Column(name = "id")
private long id;
@Column(name = "username", unique = true, length = 24)
@Column(name = "username", unique = true, length = 32)
private String username;
@Column(name = "password", columnDefinition = "text")