feat(sso): implement SSO provider registration and update related components

- Refactored SSO registration logic in `register-oidc-dialog` and `register-saml-dialog` to use a new mutation method.
- Removed unused imports and error handling for registration failures.
- Added foreign key constraint for `organization_id` in the `sso_provider` table.
- Introduced new SSO schema and updated user relations to include SSO providers.
- Enhanced authentication flow to support SSO provider registration.
This commit is contained in:
Mauricio Siu
2026-01-31 04:43:47 -06:00
parent d22d96105c
commit d5de5b8ad7
11 changed files with 7348 additions and 50 deletions

View File

@@ -1,7 +1,7 @@
"use client";
import { zodResolver } from "@hookform/resolvers/zod";
import { Loader2, Plus, Trash2 } from "lucide-react";
import { Plus, Trash2 } from "lucide-react";
import { useState } from "react";
import type { FieldArrayPath } from "react-hook-form";
import { useFieldArray, useForm } from "react-hook-form";
@@ -27,7 +27,6 @@ import {
FormMessage,
} from "@/components/ui/form";
import { Input } from "@/components/ui/input";
import { authClient } from "@/lib/auth-client";
import { api } from "@/utils/api";
const DEFAULT_SCOPES = ["openid", "email", "profile"];
@@ -74,6 +73,7 @@ const formDefaultValues = {
export function RegisterOidcDialog({ children }: RegisterOidcDialogProps) {
const utils = api.useUtils();
const [open, setOpen] = useState(false);
const { mutateAsync, isLoading } = api.sso.register.useMutation();
const form = useForm<OidcProviderForm>({
resolver: zodResolver(oidcProviderSchema),
@@ -105,7 +105,7 @@ export function RegisterOidcDialog({ children }: RegisterOidcDialogProps) {
.map((d) => d.trim())
.filter(Boolean)
.join(",");
const { error } = await authClient.sso.register({
await mutateAsync({
providerId: data.providerId,
issuer: data.issuer,
domain,
@@ -125,11 +125,6 @@ export function RegisterOidcDialog({ children }: RegisterOidcDialogProps) {
},
});
if (error) {
toast.error(error.message ?? "Failed to register SSO provider");
return;
}
toast.success("OIDC provider registered successfully");
form.reset(formDefaultValues);
setOpen(false);
@@ -340,10 +335,7 @@ export function RegisterOidcDialog({ children }: RegisterOidcDialogProps) {
>
Cancel
</Button>
<Button type="submit" disabled={isSubmitting}>
{isSubmitting && (
<Loader2 className="mr-2 size-4 animate-spin" />
)}
<Button type="submit" isLoading={isLoading}>
Register provider
</Button>
</DialogFooter>

View File

@@ -1,7 +1,7 @@
"use client";
import { zodResolver } from "@hookform/resolvers/zod";
import { Loader2, Plus, Trash2 } from "lucide-react";
import { Plus, Trash2 } from "lucide-react";
import { useState } from "react";
import { type FieldArrayPath, useFieldArray, useForm } from "react-hook-form";
import { toast } from "sonner";
@@ -27,7 +27,6 @@ import {
} from "@/components/ui/form";
import { Input } from "@/components/ui/input";
import { Textarea } from "@/components/ui/textarea";
import { authClient } from "@/lib/auth-client";
import { api } from "@/utils/api";
const domainsArraySchema = z
@@ -80,6 +79,7 @@ const formDefaultValues: SamlProviderForm = {
export function RegisterSamlDialog({ children }: RegisterSamlDialogProps) {
const utils = api.useUtils();
const [open, setOpen] = useState(false);
const { mutateAsync, isLoading } = api.sso.register.useMutation();
const form = useForm<SamlProviderForm>({
resolver: zodResolver(samlProviderSchema),
@@ -99,7 +99,7 @@ export function RegisterSamlDialog({ children }: RegisterSamlDialogProps) {
.map((d) => d.trim())
.filter(Boolean)
.join(",");
const { error } = await authClient.sso.register({
await mutateAsync({
providerId: data.providerId,
issuer: data.issuer,
domain,
@@ -117,11 +117,6 @@ export function RegisterSamlDialog({ children }: RegisterSamlDialogProps) {
},
});
if (error) {
toast.error(error.message ?? "Failed to register SAML provider");
return;
}
toast.success("SAML provider registered successfully");
form.reset(formDefaultValues);
setOpen(false);
@@ -315,10 +310,7 @@ export function RegisterSamlDialog({ children }: RegisterSamlDialogProps) {
>
Cancel
</Button>
<Button type="submit" disabled={isSubmitting}>
{isSubmitting && (
<Loader2 className="mr-2 size-4 animate-spin" />
)}
<Button type="submit" isLoading={isLoading}>
Register provider
</Button>
</DialogFooter>

View File

@@ -0,0 +1 @@
ALTER TABLE "sso_provider" ADD CONSTRAINT "sso_provider_organization_id_organization_id_fk" FOREIGN KEY ("organization_id") REFERENCES "public"."organization"("id") ON DELETE cascade ON UPDATE no action;

File diff suppressed because it is too large Load Diff

View File

@@ -981,6 +981,13 @@
"when": 1769746948088,
"tag": "0139_smiling_havok",
"breakpoints": true
},
{
"idx": 140,
"version": "7",
"when": 1769854977685,
"tag": "0140_great_lightspeed",
"breakpoints": true
}
]
}

View File

@@ -1,5 +1,7 @@
import { IS_CLOUD } from "@dokploy/server/constants";
import { member, ssoProvider } from "@dokploy/server/db/schema";
import { ssoProviderBodySchema } from "@dokploy/server/db/schema/sso";
import { auth } from "@dokploy/server/lib/auth";
import { TRPCError } from "@trpc/server";
import { and, asc, eq } from "drizzle-orm";
import { z } from "zod";
@@ -10,6 +12,20 @@ import {
} from "@/server/api/trpc";
import { db } from "@/server/db";
function requestToHeaders(req: {
headers?: Record<string, string | string[] | undefined>;
}): Headers {
const headers = new Headers();
if (req?.headers) {
for (const [key, value] of Object.entries(req.headers)) {
if (value !== undefined && key.toLowerCase() !== "host") {
headers.set(key, Array.isArray(value) ? value.join(", ") : value);
}
}
}
return headers;
}
export const ssoRouter = createTRPCRouter({
showSignInWithSSO: publicProcedure.query(async () => {
if (IS_CLOUD) {
@@ -38,7 +54,7 @@ export const ssoRouter = createTRPCRouter({
}),
listProviders: enterpriseProcedure.query(async ({ ctx }) => {
const providers = await db.query.ssoProvider.findMany({
where: eq(ssoProvider.userId, ctx.user.id),
where: eq(ssoProvider.organizationId, ctx.session.activeOrganizationId),
columns: {
id: true,
providerId: true,
@@ -59,7 +75,7 @@ export const ssoRouter = createTRPCRouter({
.where(
and(
eq(ssoProvider.providerId, input.providerId),
eq(ssoProvider.userId, ctx.user.id),
eq(ssoProvider.organizationId, ctx.session.activeOrganizationId),
),
)
.returning({ id: ssoProvider.id });
@@ -72,6 +88,22 @@ export const ssoRouter = createTRPCRouter({
});
}
return { success: true };
}),
register: enterpriseProcedure
.input(ssoProviderBodySchema)
.mutation(async ({ ctx, input }) => {
const organizationId = ctx.session.activeOrganizationId;
const result = await auth.registerSSOProvider({
body: {
...input,
organizationId,
},
headers: requestToHeaders(ctx.req),
});
console.log(result);
return { success: true };
}),
});

View File

@@ -9,6 +9,7 @@ import {
import { nanoid } from "nanoid";
import { projects } from "./project";
import { server } from "./server";
import { ssoProvider } from "./sso";
import { user } from "./user";
export const account = pgTable("account", {
@@ -78,6 +79,7 @@ export const organizationRelations = relations(
servers: many(server),
projects: many(projects),
members: many(member),
ssoProviders: many(ssoProvider),
}),
);
@@ -203,21 +205,3 @@ export const apikeyRelations = relations(apikey, ({ one }) => ({
references: [user.id],
}),
}));
export const ssoProvider = pgTable("sso_provider", {
id: text("id").primaryKey(),
issuer: text("issuer").notNull(),
oidcConfig: text("oidc_config"),
samlConfig: text("saml_config"),
userId: text("user_id").references(() => user.id, { onDelete: "cascade" }),
providerId: text("provider_id").notNull().unique(),
organizationId: text("organization_id"),
domain: text("domain").notNull(),
});
export const ssoProviderRelations = relations(ssoProvider, ({ one }) => ({
user: one(user, {
fields: [ssoProvider.userId],
references: [user.id],
}),
}));

View File

@@ -32,6 +32,7 @@ export * from "./server";
export * from "./session";
export * from "./shared";
export * from "./ssh-key";
export * from "./sso";
export * from "./user";
export * from "./utils";
export * from "./volume-backups";

View File

@@ -0,0 +1,121 @@
import { relations } from "drizzle-orm";
import { pgTable, text } from "drizzle-orm/pg-core";
import { z } from "zod";
import { organization } from "./account";
import { user } from "./user";
export const ssoProvider = pgTable("sso_provider", {
id: text("id").primaryKey(),
issuer: text("issuer").notNull(),
oidcConfig: text("oidc_config"),
samlConfig: text("saml_config"),
providerId: text("provider_id").notNull().unique(),
userId: text("user_id").references(() => user.id, { onDelete: "cascade" }),
organizationId: text("organization_id").references(() => organization.id, {
onDelete: "cascade",
}),
domain: text("domain").notNull(),
});
export const ssoProviderRelations = relations(ssoProvider, ({ one }) => ({
organization: one(organization, {
fields: [ssoProvider.organizationId],
references: [organization.id],
}),
user: one(user, {
fields: [ssoProvider.userId],
references: [user.id],
}),
}));
export const ssoProviderBodySchema = z.object({
providerId: z.string({}),
issuer: z.string({}),
domain: z.string({}),
oidcConfig: z
.object({
clientId: z.string({}),
clientSecret: z.string({}),
authorizationEndpoint: z.string({}).optional(),
tokenEndpoint: z.string({}).optional(),
userInfoEndpoint: z.string({}).optional(),
tokenEndpointAuthentication: z
.enum(["client_secret_post", "client_secret_basic"])
.optional(),
jwksEndpoint: z.string({}).optional(),
discoveryEndpoint: z.string().optional(),
skipDiscovery: z.boolean().optional(),
scopes: z.array(z.string()).optional(),
pkce: z.boolean().default(true).optional(),
mapping: z
.object({
id: z.string({}),
email: z.string({}),
emailVerified: z.string({}).optional(),
name: z.string({}),
image: z.string({}).optional(),
extraFields: z.record(z.string(), z.any()).optional(),
})
.optional(),
})
.optional(),
samlConfig: z
.object({
entryPoint: z.string({}),
cert: z.string({}),
callbackUrl: z.string({}),
audience: z.string().optional(),
idpMetadata: z
.object({
metadata: z.string().optional(),
entityID: z.string().optional(),
cert: z.string().optional(),
privateKey: z.string().optional(),
privateKeyPass: z.string().optional(),
isAssertionEncrypted: z.boolean().optional(),
encPrivateKey: z.string().optional(),
encPrivateKeyPass: z.string().optional(),
singleSignOnService: z
.array(
z.object({
Binding: z.string(),
Location: z.string(),
}),
)
.optional(),
})
.optional(),
spMetadata: z.object({
metadata: z.string().optional(),
entityID: z.string().optional(),
binding: z.string().optional(),
privateKey: z.string().optional(),
privateKeyPass: z.string().optional(),
isAssertionEncrypted: z.boolean().optional(),
encPrivateKey: z.string().optional(),
encPrivateKeyPass: z.string().optional(),
}),
wantAssertionsSigned: z.boolean().optional(),
authnRequestsSigned: z.boolean().optional(),
signatureAlgorithm: z.string().optional(),
digestAlgorithm: z.string().optional(),
identifierFormat: z.string().optional(),
privateKey: z.string().optional(),
decryptionPvk: z.string().optional(),
additionalParams: z.record(z.string(), z.any()).optional(),
mapping: z
.object({
id: z.string({}),
email: z.string({}),
emailVerified: z.string({}).optional(),
name: z.string({}),
firstName: z.string({}).optional(),
lastName: z.string({}).optional(),
extraFields: z.record(z.string(), z.any()).optional(),
})
.optional(),
})
.optional(),
organizationId: z.string({}).optional(),
overrideUserInfo: z.boolean({}).default(false).optional(),
});

View File

@@ -10,10 +10,11 @@ import {
import { createInsertSchema } from "drizzle-zod";
import { nanoid } from "nanoid";
import { z } from "zod";
import { account, apikey, organization, ssoProvider } from "./account";
import { account, apikey, organization } from "./account";
import { backups } from "./backups";
import { projects } from "./project";
import { schedules } from "./schedule";
import { ssoProvider } from "./sso";
/**
* This is an example of how to use the multi-project schema feature of Drizzle ORM. Use the same
* database instance for multiple projects.
@@ -72,9 +73,9 @@ export const usersRelations = relations(user, ({ one, many }) => ({
references: [account.userId],
}),
organizations: many(organization),
ssoProviders: many(ssoProvider),
projects: many(projects),
apiKeys: many(apikey),
ssoProviders: many(ssoProvider),
backups: many(backups),
schedules: many(schedules),
}));

View File

@@ -24,6 +24,7 @@ export const { handler, api } = betterAuth({
provider: "pg",
schema: schema,
}),
disabledPaths: ["/sso/register"],
appName: "Dokploy",
socialProviders: {
github: {
@@ -55,6 +56,7 @@ export const { handler, api } = betterAuth({
? [
"http://localhost:3000",
"https://absolutely-handy-falcon.ngrok-free.app",
"https://dev-pee8hhc3qbjlqedb.us.auth0.com",
]
: []),
];
@@ -113,7 +115,7 @@ export const { handler, api } = betterAuth({
}
} else {
const isSSORequest = context?.path.includes("/sso/callback");
if (isSSORequest) {
if (!isSSORequest) {
return;
}
const isAdminPresent = await db.query.member.findFirst({
@@ -184,9 +186,7 @@ export const { handler, api } = betterAuth({
isDefault: true, // Mark first organization as default
});
});
}
if (isSSORequest) {
} else if (isSSORequest) {
const providerId = context?.params?.providerId;
if (!providerId) {
throw new APIError("BAD_REQUEST", {
@@ -310,6 +310,7 @@ export const { handler, api } = betterAuth({
export const auth = {
handler,
createApiKey: api.createApiKey,
registerSSOProvider: api.registerSSOProvider,
};
export const validateRequest = async (request: IncomingMessage) => {