mirror of
https://github.com/TeamPiped/Piped-Backend.git
synced 2025-01-10 03:20:30 +05:30
Implement oidc
This commit is contained in:
parent
73b303ffe6
commit
604fa651fc
@ -42,6 +42,7 @@ dependencies {
|
|||||||
implementation 'io.sentry:sentry:6.23.0'
|
implementation 'io.sentry:sentry:6.23.0'
|
||||||
implementation 'rocks.kavin:reqwest4j:1.0.4'
|
implementation 'rocks.kavin:reqwest4j:1.0.4'
|
||||||
implementation 'io.minio:minio:8.5.3'
|
implementation 'io.minio:minio:8.5.3'
|
||||||
|
implementation 'com.nimbusds:oauth2-oidc-sdk:10.9.1'
|
||||||
}
|
}
|
||||||
|
|
||||||
shadowJar {
|
shadowJar {
|
||||||
|
@ -79,3 +79,9 @@ hibernate.connection.password:changeme
|
|||||||
# Frontend configuration
|
# Frontend configuration
|
||||||
#frontend.statusPageUrl:https://kavin.rocks
|
#frontend.statusPageUrl:https://kavin.rocks
|
||||||
#frontend.donationUrl:https://kavin.rocks
|
#frontend.donationUrl:https://kavin.rocks
|
||||||
|
|
||||||
|
# Oidc configuration
|
||||||
|
#oidc.provider.INSERT_HERE.name:INSERT_HERE
|
||||||
|
#oidc.provider.INSERT_HERE.clientId:INSERT_HERE
|
||||||
|
#oidc.provider.INSERT_HERE.clientSecret:INSERT_HERE
|
||||||
|
#oidc.provider.INSERT_HERE.authUrl:INSERT_HERE
|
||||||
|
@ -3,12 +3,14 @@ package me.kavin.piped.consts;
|
|||||||
import com.fasterxml.jackson.databind.DeserializationFeature;
|
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.fasterxml.jackson.databind.json.JsonMapper;
|
import com.fasterxml.jackson.databind.json.JsonMapper;
|
||||||
|
import com.fasterxml.jackson.databind.node.ArrayNode;
|
||||||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
import io.minio.MinioClient;
|
import io.minio.MinioClient;
|
||||||
import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap;
|
import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap;
|
||||||
import me.kavin.piped.utils.PageMixin;
|
import me.kavin.piped.utils.PageMixin;
|
||||||
import me.kavin.piped.utils.RequestUtils;
|
import me.kavin.piped.utils.RequestUtils;
|
||||||
|
import me.kavin.piped.utils.obj.OidcProvider;
|
||||||
import me.kavin.piped.utils.resp.ListLinkHandlerMixin;
|
import me.kavin.piped.utils.resp.ListLinkHandlerMixin;
|
||||||
import okhttp3.OkHttpClient;
|
import okhttp3.OkHttpClient;
|
||||||
import okhttp3.brotli.BrotliInterceptor;
|
import okhttp3.brotli.BrotliInterceptor;
|
||||||
@ -24,6 +26,7 @@ import java.io.File;
|
|||||||
import java.io.FileReader;
|
import java.io.FileReader;
|
||||||
import java.net.InetSocketAddress;
|
import java.net.InetSocketAddress;
|
||||||
import java.net.ProxySelector;
|
import java.net.ProxySelector;
|
||||||
|
import java.util.LinkedList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Properties;
|
import java.util.Properties;
|
||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
@ -99,6 +102,7 @@ public class Constants {
|
|||||||
public static final String YOUTUBE_COUNTRY;
|
public static final String YOUTUBE_COUNTRY;
|
||||||
|
|
||||||
public static final String VERSION;
|
public static final String VERSION;
|
||||||
|
public static final LinkedList<OidcProvider> OIDC_PROVIDERS;
|
||||||
|
|
||||||
public static final ObjectMapper mapper = JsonMapper.builder()
|
public static final ObjectMapper mapper = JsonMapper.builder()
|
||||||
.addMixIn(Page.class, PageMixin.class)
|
.addMixIn(Page.class, PageMixin.class)
|
||||||
@ -162,12 +166,34 @@ public class Constants {
|
|||||||
MATRIX_SERVER = getProperty(prop, "MATRIX_SERVER", "https://matrix-client.matrix.org");
|
MATRIX_SERVER = getProperty(prop, "MATRIX_SERVER", "https://matrix-client.matrix.org");
|
||||||
MATRIX_TOKEN = getProperty(prop, "MATRIX_TOKEN");
|
MATRIX_TOKEN = getProperty(prop, "MATRIX_TOKEN");
|
||||||
GEO_RESTRICTION_CHECKER_URL = getProperty(prop, "GEO_RESTRICTION_CHECKER_URL");
|
GEO_RESTRICTION_CHECKER_URL = getProperty(prop, "GEO_RESTRICTION_CHECKER_URL");
|
||||||
|
|
||||||
|
OIDC_PROVIDERS = new LinkedList<>();
|
||||||
|
ArrayNode providerNames = frontendProperties.putArray("oidcProviders");
|
||||||
prop.forEach((_key, _value) -> {
|
prop.forEach((_key, _value) -> {
|
||||||
String key = String.valueOf(_key), value = String.valueOf(_value);
|
String key = String.valueOf(_key), value = String.valueOf(_value);
|
||||||
if (key.startsWith("hibernate"))
|
if (key.startsWith("hibernate"))
|
||||||
hibernateProperties.put(key, value);
|
hibernateProperties.put(key, value);
|
||||||
else if (key.startsWith("frontend."))
|
else if (key.startsWith("frontend."))
|
||||||
frontendProperties.put(StringUtils.substringAfter(key, "frontend."), value);
|
frontendProperties.put(StringUtils.substringAfter(key, "frontend."), value);
|
||||||
|
else if (key.startsWith("oidc.provider")) {
|
||||||
|
String[] split = key.split("\\.");
|
||||||
|
if (split.length != 4 || !split[3].equals("name")) return;
|
||||||
|
|
||||||
|
try {
|
||||||
|
OIDC_PROVIDERS.add(new OidcProvider(
|
||||||
|
value,
|
||||||
|
getProperty(prop, "oidc.provider." + value + ".clientId"),
|
||||||
|
getProperty(prop, "oidc.provider." + value + ".clientSecret"),
|
||||||
|
getProperty(prop, "oidc.provider." + value + ".authUrl"),
|
||||||
|
getProperty(prop, "oidc.provider." + value + ".tokenUrl"),
|
||||||
|
getProperty(prop, "oidc.provider." + value + ".userinfoUrl")
|
||||||
|
));
|
||||||
|
} catch (Exception e) {
|
||||||
|
System.err.println("Error while getting properties for oidc provider '" + value + "'");
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
providerNames.add(value);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
frontendProperties.put("imageProxyUrl", IMAGE_PROXY_PART);
|
frontendProperties.put("imageProxyUrl", IMAGE_PROXY_PART);
|
||||||
frontendProperties.putArray("countries").addAll(
|
frontendProperties.putArray("countries").addAll(
|
||||||
|
@ -2,6 +2,9 @@ package me.kavin.piped.server;
|
|||||||
|
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
|
||||||
|
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
|
||||||
|
import com.nimbusds.openid.connect.sdk.claims.UserInfo;
|
||||||
import com.rometools.rome.feed.synd.SyndFeed;
|
import com.rometools.rome.feed.synd.SyndFeed;
|
||||||
import com.rometools.rome.io.SyndFeedInput;
|
import com.rometools.rome.io.SyndFeedInput;
|
||||||
import io.activej.config.Config;
|
import io.activej.config.Config;
|
||||||
@ -19,7 +22,9 @@ import me.kavin.piped.server.handlers.auth.FeedHandlers;
|
|||||||
import me.kavin.piped.server.handlers.auth.StorageHandlers;
|
import me.kavin.piped.server.handlers.auth.StorageHandlers;
|
||||||
import me.kavin.piped.server.handlers.auth.UserHandlers;
|
import me.kavin.piped.server.handlers.auth.UserHandlers;
|
||||||
import me.kavin.piped.utils.*;
|
import me.kavin.piped.utils.*;
|
||||||
|
import me.kavin.piped.utils.ErrorResponse;
|
||||||
import me.kavin.piped.utils.obj.MatrixHelper;
|
import me.kavin.piped.utils.obj.MatrixHelper;
|
||||||
|
import me.kavin.piped.utils.obj.OidcProvider;
|
||||||
import me.kavin.piped.utils.obj.federation.FederatedVideoInfo;
|
import me.kavin.piped.utils.obj.federation.FederatedVideoInfo;
|
||||||
import me.kavin.piped.utils.resp.*;
|
import me.kavin.piped.utils.resp.*;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@ -30,12 +35,18 @@ import org.jetbrains.annotations.NotNull;
|
|||||||
import org.schabi.newpipe.extractor.exceptions.ParsingException;
|
import org.schabi.newpipe.extractor.exceptions.ParsingException;
|
||||||
import org.schabi.newpipe.extractor.localization.DateWrapper;
|
import org.schabi.newpipe.extractor.localization.DateWrapper;
|
||||||
import org.xml.sax.InputSource;
|
import org.xml.sax.InputSource;
|
||||||
|
import com.nimbusds.oauth2.sdk.*;
|
||||||
|
import com.nimbusds.openid.connect.sdk.*;
|
||||||
|
import com.nimbusds.oauth2.sdk.id.*;
|
||||||
|
|
||||||
import java.io.ByteArrayInputStream;
|
import java.io.ByteArrayInputStream;
|
||||||
import java.net.InetSocketAddress;
|
import java.net.InetSocketAddress;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.util.LinkedList;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.concurrent.Executor;
|
import java.util.concurrent.Executor;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
import static io.activej.config.converter.ConfigConverters.ofInetSocketAddress;
|
import static io.activej.config.converter.ConfigConverters.ofInetSocketAddress;
|
||||||
import static io.activej.http.HttpHeaders.*;
|
import static io.activej.http.HttpHeaders.*;
|
||||||
@ -293,6 +304,88 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher {
|
|||||||
LoginRequest.class);
|
LoginRequest.class);
|
||||||
return getJsonResponse(UserHandlers.registerResponse(body.username, body.password),
|
return getJsonResponse(UserHandlers.registerResponse(body.username, body.password),
|
||||||
"private");
|
"private");
|
||||||
|
} catch (Exception e) {
|
||||||
|
return getErrorResponse(e, request.getPath());
|
||||||
|
}
|
||||||
|
})).map(GET, "/oidc/:provider/:function", AsyncServlet.ofBlocking(executor, request -> {
|
||||||
|
try {
|
||||||
|
String function = request.getPathParameter("function");
|
||||||
|
|
||||||
|
OidcProvider provider = findOidcProvider(request.getPathParameter("provider"), Constants.OIDC_PROVIDERS);
|
||||||
|
if(provider == null)
|
||||||
|
return HttpResponse.ofCode(500).withHtml("Can't find the provider on the server.");
|
||||||
|
|
||||||
|
URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback");
|
||||||
|
|
||||||
|
switch (function) {
|
||||||
|
case "login" -> {
|
||||||
|
|
||||||
|
State state = new State();
|
||||||
|
Nonce nonce = new Nonce();
|
||||||
|
|
||||||
|
AuthenticationRequest oidcRequest = new AuthenticationRequest.Builder(
|
||||||
|
new ResponseType("code"),
|
||||||
|
new Scope("openid"),
|
||||||
|
provider.clientID,
|
||||||
|
callback)
|
||||||
|
.endpointURI(provider.authUri)
|
||||||
|
.state(state)
|
||||||
|
.nonce(nonce)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return HttpResponse.redirect302(oidcRequest.toURI().toString());
|
||||||
|
}
|
||||||
|
case "callback" -> {
|
||||||
|
ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret);
|
||||||
|
|
||||||
|
AuthenticationResponse response = AuthenticationResponseParser.parse(
|
||||||
|
URI.create(request.getFullUrl())
|
||||||
|
);
|
||||||
|
|
||||||
|
if (response instanceof AuthenticationErrorResponse) {
|
||||||
|
// The OpenID provider returned an error
|
||||||
|
System.err.println(response.toErrorResponse().getErrorObject());
|
||||||
|
return HttpResponse.ofCode(500).withHtml("OpenID provider returned an error:\n\n" + response.toErrorResponse().getErrorObject().toString());
|
||||||
|
}
|
||||||
|
AuthenticationSuccessResponse sr = response.toSuccessResponse();
|
||||||
|
|
||||||
|
AuthorizationCode code = sr.getAuthorizationCode();
|
||||||
|
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(
|
||||||
|
code, callback
|
||||||
|
);
|
||||||
|
|
||||||
|
TokenRequest tr = new TokenRequest(provider.tokenUri, clientAuth, codeGrant);
|
||||||
|
TokenResponse tokenResponse = OIDCTokenResponseParser.parse(tr.toHTTPRequest().send());
|
||||||
|
|
||||||
|
if (! tokenResponse.indicatesSuccess()) {
|
||||||
|
TokenErrorResponse errorResponse = tokenResponse.toErrorResponse();
|
||||||
|
return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription());
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDCTokenResponse successResponse = (OIDCTokenResponse)tokenResponse.toSuccessResponse();
|
||||||
|
|
||||||
|
|
||||||
|
UserInfoRequest ur = new UserInfoRequest(provider.userinfoUri, successResponse.getOIDCTokens().getBearerAccessToken());
|
||||||
|
UserInfoResponse userInfoResponse = UserInfoResponse.parse(ur.toHTTPRequest().send());
|
||||||
|
|
||||||
|
if (! userInfoResponse.indicatesSuccess()) {
|
||||||
|
System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getCode());
|
||||||
|
System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getDescription());
|
||||||
|
return HttpResponse.ofCode(500).withHtml("Failed to query userInfo:\n\n" + userInfoResponse.toErrorResponse().getErrorObject().getDescription());
|
||||||
|
}
|
||||||
|
|
||||||
|
UserInfo userInfo = userInfoResponse.toSuccessResponse().getUserInfo();
|
||||||
|
|
||||||
|
String sessionId = UserHandlers.oidcCallbackResponse(provider.name, userInfo.getSubject().toString());
|
||||||
|
|
||||||
|
return HttpResponse.redirect302(Constants.FRONTEND_URL + "/login?session=" + sessionId);
|
||||||
|
}
|
||||||
|
default -> {
|
||||||
|
return HttpResponse.ofCode(500).withHtml("Invalid function `" + function + "`.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
return getErrorResponse(e, request.getPath());
|
return getErrorResponse(e, request.getPath());
|
||||||
}
|
}
|
||||||
@ -542,6 +635,14 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher {
|
|||||||
return new CustomServletDecorator(router);
|
return new CustomServletDecorator(router);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static OidcProvider findOidcProvider(String provider, LinkedList<OidcProvider> list){
|
||||||
|
for(int i = 0; i < list.size(); i++) {
|
||||||
|
OidcProvider curr = list.get(i);
|
||||||
|
if(curr == null || !curr.name.equals(provider)) continue;
|
||||||
|
return curr;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
private static String[] getArray(String s) {
|
private static String[] getArray(String s) {
|
||||||
|
|
||||||
if (s == null) {
|
if (s == null) {
|
||||||
|
@ -107,11 +107,36 @@ public class UserHandlers {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
public static String oidcCallbackResponse(String provider, String uid) {
|
||||||
|
try (Session s = DatabaseSessionFactory.createSession()) {
|
||||||
|
String dbName = provider + "-" + uid;
|
||||||
|
System.out.println(dbName); //TODO:
|
||||||
|
CriteriaBuilder cb = s.getCriteriaBuilder();
|
||||||
|
CriteriaQuery<User> cr = cb.createQuery(User.class);
|
||||||
|
Root<User> root = cr.from(User.class);
|
||||||
|
cr.select(root).where(root.get("username").in(
|
||||||
|
dbName
|
||||||
|
));
|
||||||
|
|
||||||
|
User dbuser = s.createQuery(cr).uniqueResult();
|
||||||
|
|
||||||
|
if (dbuser == null) {
|
||||||
|
User newuser = new User(dbName, "", Set.of());
|
||||||
|
|
||||||
|
var tr = s.beginTransaction();
|
||||||
|
s.persist(newuser);
|
||||||
|
tr.commit();
|
||||||
|
|
||||||
|
|
||||||
|
return newuser.getSessionId();
|
||||||
|
}
|
||||||
|
return dbuser.getSessionId();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
public static byte[] deleteUserResponse(String session, String pass) throws IOException {
|
public static byte[] deleteUserResponse(String session, String pass) throws IOException {
|
||||||
|
if (StringUtils.isBlank(session))
|
||||||
if (StringUtils.isBlank(session) || StringUtils.isBlank(pass))
|
ExceptionHandler.throwErrorResponse(new InvalidRequestResponse("session is a required parameter"));
|
||||||
ExceptionHandler.throwErrorResponse(new InvalidRequestResponse("session and password are required parameters"));
|
|
||||||
|
|
||||||
try (Session s = DatabaseSessionFactory.createSession()) {
|
try (Session s = DatabaseSessionFactory.createSession()) {
|
||||||
User user = DatabaseHelper.getUserFromSession(session);
|
User user = DatabaseHelper.getUserFromSession(session);
|
||||||
@ -121,6 +146,13 @@ public class UserHandlers {
|
|||||||
|
|
||||||
String hash = user.getPassword();
|
String hash = user.getPassword();
|
||||||
|
|
||||||
|
if (hash.equals("")) {
|
||||||
|
//TODO: Authorize against oidc provider before deletion
|
||||||
|
var tr = s.beginTransaction();
|
||||||
|
s.remove(user);
|
||||||
|
tr.commit();
|
||||||
|
return mapper.writeValueAsBytes(new DeleteUserResponse(user.getUsername()));
|
||||||
|
}
|
||||||
if (!hashMatch(hash, pass))
|
if (!hashMatch(hash, pass))
|
||||||
ExceptionHandler.throwErrorResponse(new IncorrectCredentialsResponse());
|
ExceptionHandler.throwErrorResponse(new IncorrectCredentialsResponse());
|
||||||
|
|
||||||
|
25
src/main/java/me/kavin/piped/utils/obj/OidcProvider.java
Normal file
25
src/main/java/me/kavin/piped/utils/obj/OidcProvider.java
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package me.kavin.piped.utils.obj;
|
||||||
|
|
||||||
|
import com.nimbusds.oauth2.sdk.auth.Secret;
|
||||||
|
import com.nimbusds.oauth2.sdk.id.ClientID;
|
||||||
|
|
||||||
|
import java.net.URI;
|
||||||
|
import java.net.URISyntaxException;
|
||||||
|
|
||||||
|
public class OidcProvider {
|
||||||
|
public String name;
|
||||||
|
public ClientID clientID;
|
||||||
|
public Secret clientSecret;
|
||||||
|
public URI authUri;
|
||||||
|
public URI tokenUri;
|
||||||
|
public URI userinfoUri;
|
||||||
|
|
||||||
|
public OidcProvider(String name, String clientID, String clientSecret, String authUri, String tokenUri, String userinfoUri) throws URISyntaxException {
|
||||||
|
this.name = name;
|
||||||
|
this.clientID = new ClientID(clientID);
|
||||||
|
this.clientSecret = new Secret(clientSecret);
|
||||||
|
this.authUri = new URI(authUri);
|
||||||
|
this.tokenUri = new URI(tokenUri);
|
||||||
|
this.userinfoUri = new URI(userinfoUri);
|
||||||
|
}
|
||||||
|
}
|
@ -20,7 +20,7 @@ public class User implements Serializable {
|
|||||||
@Column(name = "id")
|
@Column(name = "id")
|
||||||
private long id;
|
private long id;
|
||||||
|
|
||||||
@Column(name = "username", unique = true, length = 24)
|
@Column(name = "username", unique = true, length = 32)
|
||||||
private String username;
|
private String username;
|
||||||
|
|
||||||
@Column(name = "password", columnDefinition = "text")
|
@Column(name = "password", columnDefinition = "text")
|
||||||
|
Loading…
Reference in New Issue
Block a user