mirror of
https://github.com/supabase/supabase.git
synced 2026-06-15 08:05:21 +08:00
When the Assistant requests confirmation to run a query or deploy an edge function, and the user sends a follow-up message instead of skipping or running it, the chat errors out. This change allows follow-up messages and treats them as "skips", which means the pending confirmation's message state is adjusted as part of sending the follow-up. It also uses `toModelOutput` to cleanse data based on permissions. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Enhanced tool approval workflow: pending approvals are now automatically resolved as denied when submitting new messages * Improved chat input state management with better handling of approval states * Customizable loading messages for tool operations * **Bug Fixes** * Fixed chat input availability during pending tool approval states * Improved tool execution feedback during approval workflows <!-- review_stack_entry_start --> [](https://app.coderabbit.ai/change-stack/supabase/supabase/pull/46052?utm_source=github_walkthrough&utm_medium=github&utm_campaign=change_stack) <!-- review_stack_entry_end --> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
159 lines
5.4 KiB
TypeScript
159 lines
5.4 KiB
TypeScript
import { isToolUIPart, type UIMessage } from 'ai'
|
|
import { toast } from 'sonner'
|
|
|
|
import { SAFE_FUNCTIONS } from './AiAssistant.constants'
|
|
import { authKeys } from '@/data/auth/keys'
|
|
import { databaseExtensionsKeys } from '@/data/database-extensions/keys'
|
|
import { databaseIndexesKeys } from '@/data/database-indexes/keys'
|
|
import { databasePoliciesKeys } from '@/data/database-policies/keys'
|
|
import { databaseTriggerKeys } from '@/data/database-triggers/keys'
|
|
import { databaseKeys } from '@/data/database/keys'
|
|
import { enumeratedTypesKeys } from '@/data/enumerated-types/keys'
|
|
import { handleError } from '@/data/fetchers'
|
|
import { tableKeys } from '@/data/tables/keys'
|
|
import { tryParseJson } from '@/lib/helpers'
|
|
import { ResponseError } from '@/types'
|
|
|
|
/** Mutation categories that `identifyQueryType` can recognise in a SQL string. */
export type MutationCategory = 'functions' | 'rls-policies'

/**
 * Very basic keyword-based classification of a SQL string into a mutation category.
 * [Joshen] This is just very basic identification, but it could possibly be extended.
 *
 * @param query Raw SQL text (any case, may span multiple lines).
 * @returns `'functions'` if the query contains `CREATE FUNCTION` or
 *   `CREATE OR REPLACE FUNCTION`, `'rls-policies'` if it contains `CREATE POLICY` or
 *   `ALTER POLICY`, otherwise `undefined`. Function detection takes precedence.
 */
export const identifyQueryType = (query: string): MutationCategory | undefined => {
  // Collapse every whitespace run (newlines, CRLF, tabs, repeated spaces) into a single
  // space, so keyword pairs split across lines or padded with extra whitespace still match.
  const formattedQuery = query.toLowerCase().replace(/\s+/g, ' ')

  if (
    formattedQuery.includes('create function') ||
    formattedQuery.includes('create or replace function')
  ) {
    return 'functions'
  }

  if (formattedQuery.includes('create policy') || formattedQuery.includes('alter policy')) {
    return 'rls-policies'
  }

  return undefined
}
|
|
|
|
/**
 * Checks for function calls that aren't in the safe list.
 *
 * Every `identifier(` / `identifier (` occurrence in the lower-cased query is treated as a
 * call site and compared verbatim (after trimming) against the entries of `SAFE_FUNCTIONS`.
 *
 * @deprecated [Joshen] Ideally we move away from this as this isn't a scalable way to deduce
 * @param query Raw SQL text.
 * @returns `true` if at least one call site is not present in `SAFE_FUNCTIONS`.
 */
export const containsUnknownFunction = (query: string) => {
  const callSites = query.trim().toLowerCase().match(/\w+\s*\(/g) ?? []

  // NOTE(review): whitespace between the name and "(" survives `trim()`, so `count (` is
  // compared as-is — confirm SAFE_FUNCTIONS entries account for that.
  const isSafe = (callSite: string) => {
    const candidate = callSite.trim().toLowerCase()
    return SAFE_FUNCTIONS.some((safeFunc) => safeFunc === candidate)
  }

  // "Not every call is safe" is the same as "some call is unknown"; no calls => false.
  return !callSites.every(isSafe)
}
|
|
|
|
/**
 * @deprecated
 * [Joshen] This isn't really a scalable way to deduce this behaviour, we now have support
 * for a readonly connection string which we can use this to run queries, and is a much
 * clearer way to deduce if the query is read only or not
 *
 * Heuristically decides whether `query` is a read-only SELECT. It must start with `select`,
 * contain no write keyword as a standalone word, and call only functions that
 * `containsUnknownFunction` considers safe.
 *
 * @param query Raw SQL text.
 * @returns `true` only when all three checks pass; `false` otherwise.
 */
export const isReadOnlySelect = (query: string): boolean => {
  const normalizedQuery = query.trim().toLowerCase()

  // Check if it starts with SELECT
  if (!normalizedQuery.startsWith('select')) return false

  // Keywords that indicate a write, matched as whole words. Identifiers such as `created_at`,
  // `updated_at` or `is_deleted` therefore don't trip the check. Unlike a per-query allow-list,
  // the presence of such a column can no longer mask a real `CREATE`/`UPDATE` elsewhere in the
  // same query. `truncate`, `grant` and `revoke` are writes too. `into` catches
  // `SELECT ... INTO new_table`, which creates a table.
  const writeKeyword =
    /\b(?:insert|update|delete|alter|drop|create|replace|truncate|grant|revoke|into)\b/
  if (writeKeyword.test(normalizedQuery)) return false

  if (containsUnknownFunction(normalizedQuery)) return false

  return true
}
|
|
|
|
/**
 * Reports whether the conversation is blocked on a tool approval.
 *
 * @param messages Chat messages (only `role` and `parts` are read).
 * @returns `true` as soon as an assistant message is found with a tool part in the
 *   `'approval-requested'` state; `false` otherwise (including for an empty list).
 */
export const hasPendingToolApproval = (messages: Pick<UIMessage, 'role' | 'parts'>[]) => {
  for (const { role, parts } of messages) {
    // Only assistant messages can carry tool calls awaiting approval.
    if (role !== 'assistant') continue

    const awaitingApproval = parts?.some(
      (part) => isToolUIPart(part) && part.state === 'approval-requested'
    )
    if (awaitingApproval) return true
  }

  return false
}
|
|
|
|
/**
 * Marks every tool call that is still waiting for user approval as denied.
 *
 * Each assistant tool part in the `'approval-requested'` state is rewritten to
 * `'output-denied'`, keeping its approval id and recording `approved: false` plus a reason
 * stating that the user sent a follow-up message instead.
 *
 * @param messages The current chat messages (not mutated).
 * @returns A new array. Messages without a pending approval are returned by reference; only
 *   messages that actually change are copied.
 */
export const resolvePendingToolApprovalsAsDenied = (messages: UIMessage[]): UIMessage[] => {
  return messages.map((message) => {
    if (message.role !== 'assistant') return message

    const hasPendingApproval = message.parts?.some(
      (part) => isToolUIPart(part) && part.state === 'approval-requested'
    )
    // Keep the original reference when there is nothing to rewrite, so consumers that compare
    // messages by identity (e.g. memoised renderers) don't see a spurious change.
    if (!hasPendingApproval) return message

    const parts = message.parts?.map((part) => {
      if (!isToolUIPart(part) || part.state !== 'approval-requested') return part

      return {
        ...part,
        state: 'output-denied',
        approval: {
          id: part.approval.id,
          approved: false,
          reason: 'Skipped because the user sent a follow-up message.',
        },
      } as UIMessage['parts'][number]
    })

    return { ...message, parts } as UIMessage
  })
}
|
|
|
|
/**
 * Drops the first two path segments (plus the empty segment before the leading `/`) from a
 * pathname and returns the remainder, e.g. `/project/abc/database/tables` →
 * `database/tables`. Presumably the dropped prefix is `/project/<ref>` — confirm with callers.
 * Paths with two or fewer segments yield `''`.
 */
const getContextKey = (pathname: string) => pathname.split('/').slice(3).join('/')
|
|
|
|
/**
 * Looks up the query keys associated with the dashboard page at `pathname`.
 *
 * The page is identified by `getContextKey(pathname)` (e.g. `database/tables`). Only the
 * exact keys listed below are recognised; any other page yields `[]`.
 *
 * @param ref Project ref passed into each query-key factory.
 * @param pathname Current route pathname.
 * @param schema Schema used for the table and index keys (defaults to `'public'`).
 * @returns The list of query keys for that page, or `[]` if the page is not mapped.
 */
export const getContextualInvalidationKeys = ({
  ref,
  pathname,
  schema = 'public',
}: {
  ref: string
  pathname: string
  schema?: string
}) => {
  const key = getContextKey(pathname)

  const keysByContext = {
    'auth/users': [authKeys.usersInfinite(ref)],
    'auth/policies': [databasePoliciesKeys.list(ref)],
    'database/functions': [databaseKeys.databaseFunctions(ref)],
    'database/tables': [tableKeys.list(ref, schema, true), tableKeys.list(ref, schema, false)],
    'database/triggers': [databaseTriggerKeys.list(ref)],
    'database/types': [enumeratedTypesKeys.list(ref)],
    'database/extensions': [databaseExtensionsKeys.list(ref)],
    'database/indexes': [databaseIndexesKeys.list(ref, schema)],
  } as const

  // Only own entries count. Without this guard a path such as `/project/<ref>/constructor`
  // would resolve to an inherited `Object.prototype` member (a function) instead of `[]`.
  if (!Object.prototype.hasOwnProperty.call(keysByContext, key)) return []

  return keysByContext[key] ?? []
}
|
|
|
|
/**
 * Error handler for the Assistant chat: surfaces a failed request as a toast.
 *
 * The error message is first parsed as JSON. If it parses, the nested `error` field (when
 * present and truthy) or the parsed value itself is handed to `handleError`; otherwise the
 * original `error` is used. Whatever `handleError` throws is converted into a toast message.
 *
 * NOTE(review): a toast is only shown if `handleError` throws — presumably it always does;
 * confirm against its implementation.
 *
 * @param error The error reported by the chat.
 */
export const onErrorChat = (error: Error) => {
  const parsedError = error ? tryParseJson(error.message) : undefined

  try {
    handleError(parsedError?.error || parsedError || error)
  } catch (e: unknown) {
    // Narrow the unknown thrown value before reading from it.
    if (e instanceof ResponseError || e instanceof Error) {
      toast.error(e.message)
    } else if (typeof e === 'string') {
      toast.error(e)
    } else {
      toast.error('An unknown error occurred')
    }
  }
}
|