From caf5adb3047d56c24f16befeaf47347bd4bc36c6 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Mon, 23 Sep 2024 17:42:58 +0530 Subject: [PATCH] fix: oauth stats queries --- .../supertokens/storage/postgresql/Start.java | 23 ++++- .../postgresql/config/PostgreSQLConfig.java | 6 +- .../postgresql/queries/GeneralQueries.java | 15 +++- .../postgresql/queries/OAuthQueries.java | 86 +++++++++++++++++-- 4 files changed, 117 insertions(+), 13 deletions(-) diff --git a/src/main/java/io/supertokens/storage/postgresql/Start.java b/src/main/java/io/supertokens/storage/postgresql/Start.java index 214bee77..bdb2ebc4 100644 --- a/src/main/java/io/supertokens/storage/postgresql/Start.java +++ b/src/main/java/io/supertokens/storage/postgresql/Start.java @@ -122,7 +122,6 @@ public class Start private ResourceDistributor resourceDistributor = new ResourceDistributor(); private String processId; private HikariLoggingAppender appender; - private static final String APP_ID_KEY_NAME = "app_id"; private static final String ACCESS_TOKEN_SIGNING_KEY_NAME = "access_token_signing_key"; private static final String REFRESH_TOKEN_KEY_NAME = "refresh_token_key"; public static boolean isTesting = false; @@ -3140,15 +3139,33 @@ public boolean isRevoked(AppIdentifier appIdentifier, String[] targetTypes, Stri } } + @Override + public void addM2MToken(AppIdentifier appIdentifier, String clientId, long iat, long exp) + throws StorageQueryException { + try { + OAuthQueries.addM2MToken(this, appIdentifier, clientId, iat, exp); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public int countTotalNumberOfM2MTokensAlive(AppIdentifier appIdentifier) throws StorageQueryException { - return 0; // TODO + try { + return OAuthQueries.countTotalNumberOfM2MTokensAlive(this, appIdentifier); + } catch (SQLException e) { + throw new StorageQueryException(e); + } } @Override public int countTotalNumberOfM2MTokensCreatedSince(AppIdentifier appIdentifier, long since) throws StorageQueryException { - return 0; // TODO + try { + return OAuthQueries.countTotalNumberOfM2MTokensCreatedSince(this, appIdentifier, since); + } catch (SQLException e) { + throw new StorageQueryException(e); + } } @Override diff --git a/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java b/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java index 85032719..228719f5 100644 --- a/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java +++ b/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java @@ -439,7 +439,7 @@ public String getDashboardSessionsTable() { return addSchemaAndPrefixToTableName("dashboard_user_sessions"); } - public String getOAuthClientTable() { + public String getOAuthClientsTable() { return addSchemaAndPrefixToTableName("oauth_clients"); } @@ -447,6 +447,10 @@ public String getOAuthRevokeTable() { return addSchemaAndPrefixToTableName("oauth_revoke"); } + public String getOAuthM2MTokensTable() { + return addSchemaAndPrefixToTableName("oauth_m2m_tokens"); + } + public String getTotpUsersTable() { return addSchemaAndPrefixToTableName("totp_users"); } diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java index c86d8c52..8b71d6be 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java @@ -552,7 +552,7 @@ public static void createTablesIfNotExists(Start start, Connection con) throws S update(con, TOTPQueries.getQueryToCreateTenantIdIndexForUsedCodesTable(start), NO_OP_SETTER); } - if (!doesTableExists(start, con, Config.getConfig(start).getOAuthClientTable())) { + if (!doesTableExists(start, con, Config.getConfig(start).getOAuthClientsTable())) { getInstance(start).addState(CREATING_NEW_TABLE, null); update(start, OAuthQueries.getQueryToCreateOAuthClientTable(start), NO_OP_SETTER); } @@ -565,6 +565,15 @@ public static void createTablesIfNotExists(Start start, Connection con) throws S update(con, OAuthQueries.getQueryToCreateOAuthRevokeTimestampIndex(start), NO_OP_SETTER); } + if (!doesTableExists(start, con, Config.getConfig(start).getOAuthM2MTokensTable())) { + getInstance(start).addState(CREATING_NEW_TABLE, null); + update(start, OAuthQueries.getQueryToCreateOAuthM2MTokensTable(start), NO_OP_SETTER); + + // index + update(con, OAuthQueries.getQueryToCreateOAuthM2MTokenIatIndex(start), NO_OP_SETTER); + update(con, OAuthQueries.getQueryToCreateOAuthM2MTokenExpIndex(start), NO_OP_SETTER); + } + } catch (Exception e) { if (e.getMessage().contains("schema") && e.getMessage().contains("does not exist") && numberOfRetries < 1) { @@ -635,7 +644,9 @@ public static void deleteAllTables(Start start) throws SQLException, StorageQuer + getConfig(start).getUserRolesTable() + "," + getConfig(start).getDashboardUsersTable() + "," + getConfig(start).getDashboardSessionsTable() + "," - + getConfig(start).getOAuthClientTable() + "," + + getConfig(start).getOAuthClientsTable() + "," + + getConfig(start).getOAuthRevokeTable() + "," + + getConfig(start).getOAuthM2MTokensTable() + "," + getConfig(start).getTotpUsedCodesTable() + "," + getConfig(start).getTotpUserDevicesTable() + "," + getConfig(start).getTotpUsersTable(); diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java index 19c6715f..f2f723a5 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java @@ -17,7 +17,7 @@ public class OAuthQueries { public static String getQueryToCreateOAuthClientTable(Start start) { String schema = Config.getConfig(start).getTableSchema(); - String oAuth2ClientTable = Config.getConfig(start).getOAuthClientTable(); + String oAuth2ClientTable = Config.getConfig(start).getOAuthClientsTable(); // @formatter:off return "CREATE TABLE IF NOT EXISTS " + oAuth2ClientTable + " (" + "app_id VARCHAR(64) DEFAULT 'public'," @@ -56,9 +56,39 @@ public static String getQueryToCreateOAuthRevokeTimestampIndex(Start start) { + oAuth2ClientTable + "(timestamp DESC, app_id DESC);"; } + public static String getQueryToCreateOAuthM2MTokensTable(Start start) { + String schema = Config.getConfig(start).getTableSchema(); + String oAuth2ClientTable = Config.getConfig(start).getOAuthM2MTokensTable(); + // @formatter:off + return "CREATE TABLE IF NOT EXISTS " + oAuth2ClientTable + " (" + + "app_id VARCHAR(64) DEFAULT 'public'," + + "client_id VARCHAR(128) NOT NULL," + + "iat BIGINT NOT NULL," + + "exp BIGINT NOT NULL," + + "CONSTRAINT " + Utils.getConstraintName(schema, oAuth2ClientTable, "client_id", "pkey") + + " PRIMARY KEY (app_id, client_id, iat)," + + "CONSTRAINT " + Utils.getConstraintName(schema, oAuth2ClientTable, "app_id", "fkey") + + " FOREIGN KEY(app_id)" + + " REFERENCES " + Config.getConfig(start).getAppsTable() + "(app_id) ON DELETE CASCADE" + + ");"; + // @formatter:on + } + + public static String getQueryToCreateOAuthM2MTokenIatIndex(Start start) { + String oAuth2ClientTable = Config.getConfig(start).getOAuthM2MTokensTable(); + return "CREATE INDEX IF NOT EXISTS oauth_m2m_token_iat_index ON " + + oAuth2ClientTable + "(iat DESC, app_id DESC);"; + } + + public static String getQueryToCreateOAuthM2MTokenExpIndex(Start start) { + String oAuth2ClientTable = Config.getConfig(start).getOAuthM2MTokensTable(); + return "CREATE INDEX IF NOT EXISTS oauth_m2m_token_exp_index ON " + + oAuth2ClientTable + "(exp DESC, app_id DESC);"; + } + public static boolean isClientIdForAppId(Start start, String clientId, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { - String QUERY = "SELECT app_id FROM " + Config.getConfig(start).getOAuthClientTable() + + String QUERY = "SELECT app_id FROM " + Config.getConfig(start).getOAuthClientsTable() + " WHERE client_id = ? AND app_id = ?"; return execute(start, QUERY, pst -> { @@ -69,7 +99,7 @@ public static boolean isClientIdForAppId(Start start, String clientId, AppIdenti public static List listClientsForApp(Start start, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { - String QUERY = "SELECT client_id FROM " + Config.getConfig(start).getOAuthClientTable() + + String QUERY = "SELECT client_id FROM " + Config.getConfig(start).getOAuthClientsTable() + " WHERE app_id = ?"; return execute(start, QUERY, pst -> { pst.setString(1, appIdentifier.getAppId()); @@ -85,7 +115,7 @@ public static List listClientsForApp(Start start, AppIdentifier appIdent public static void insertClientIdForAppId(Start start, AppIdentifier appIdentifier, String clientId, boolean isClientCredentialsOnly) throws SQLException, StorageQueryException { - String INSERT = "INSERT INTO " + Config.getConfig(start).getOAuthClientTable() + String INSERT = "INSERT INTO " + Config.getConfig(start).getOAuthClientsTable() + "(app_id, client_id, is_client_credentials_only) VALUES(?, ?, ?) " + "ON CONFLICT (app_id, client_id) DO UPDATE SET is_client_credentials_only = ?"; update(start, INSERT, pst -> { @@ -98,7 +128,7 @@ public static void insertClientIdForAppId(Start start, AppIdentifier appIdentifi public static boolean deleteClientIdForAppId(Start start, String clientId, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { - String DELETE = "DELETE FROM " + Config.getConfig(start).getOAuthClientTable() + String DELETE = "DELETE FROM " + Config.getConfig(start).getOAuthClientsTable() + " WHERE app_id = ? AND client_id = ?"; int numberOfRow = update(start, DELETE, pst -> { pst.setString(1, appIdentifier.getAppId()); @@ -156,7 +186,7 @@ public static boolean isRevoked(Start start, AppIdentifier appIdentifier, String public static int countTotalNumberOfClientsForApp(Start start, AppIdentifier appIdentifier, boolean filterByClientCredentialsOnly) throws SQLException, StorageQueryException { if (filterByClientCredentialsOnly) { - String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthClientTable() + + String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthClientsTable() + " WHERE app_id = ? AND is_client_credentials_only = ?"; return execute(start, QUERY, pst -> { pst.setString(1, appIdentifier.getAppId()); @@ -168,7 +198,7 @@ public static int countTotalNumberOfClientsForApp(Start start, AppIdentifier app return 0; }); } else { - String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthClientTable() + + String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthClientsTable() + " WHERE app_id = ?"; return execute(start, QUERY, pst -> { pst.setString(1, appIdentifier.getAppId()); @@ -180,4 +210,46 @@ public static int countTotalNumberOfClientsForApp(Start start, AppIdentifier app }); } } + + public static int countTotalNumberOfM2MTokensAlive(Start start, AppIdentifier appIdentifier) + throws SQLException, StorageQueryException { + String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthM2MTokensTable() + + " WHERE app_id = ? AND exp > ?"; + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setLong(2, System.currentTimeMillis()); + }, result -> { + if (result.next()) { + return result.getInt("c"); + } + return 0; + }); + } + + public static int countTotalNumberOfM2MTokensCreatedSince(Start start, AppIdentifier appIdentifier, long since) + throws SQLException, StorageQueryException { + String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthM2MTokensTable() + + " WHERE app_id = ? AND iat >= ?"; + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setLong(2, since / 1000); + }, result -> { + if (result.next()) { + return result.getInt("c"); + } + return 0; + }); + } + + public static void addM2MToken(Start start, AppIdentifier appIdentifier, String clientId, long iat, long exp) + throws SQLException, StorageQueryException { + String QUERY = "INSERT INTO " + Config.getConfig(start).getOAuthM2MTokensTable() + + " (app_id, client_id, iat, exp) VALUES (?, ?, ?, ?)"; + update(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setString(2, clientId); + pst.setLong(3, iat); + pst.setLong(4, exp); + }); + } }