Files
supabase/apps/studio/components/ui/AIAssistantPanel/AIAssistant.utils.ts
Saxon Fletcher 80da153450 Fix for AI Assistant query and deploy confirmation (#46052)
When the Assistant requests confirmation to run a query or deploy an edge
function, and the user sends a follow-up message instead of skipping or
running it, the chat errors out. This change allows follow-up messages and
treats them as "skips", which means the pending confirmation state is
adjusted as part of the follow-up. It also uses toModelOutput to sanitize
data based on permissions.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Release Notes

* **New Features**
  * Enhanced tool approval workflow: pending approvals are now
automatically resolved as denied when submitting new messages
  * Improved chat input state management with better handling of approval
states
  * Customizable loading messages for tool operations

* **Bug Fixes**
  * Fixed chat input availability during pending tool approval states
  * Improved tool execution feedback during approval workflows

<!-- review_stack_entry_start -->

[![Review Change
Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/supabase/supabase/pull/46052?utm_source=github_walkthrough&utm_medium=github&utm_campaign=change_stack)

<!-- review_stack_entry_end -->

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-05-20 09:09:28 +10:00

159 lines
5.4 KiB
TypeScript

import { isToolUIPart, type UIMessage } from 'ai'
import { toast } from 'sonner'
import { SAFE_FUNCTIONS } from './AiAssistant.constants'
import { authKeys } from '@/data/auth/keys'
import { databaseExtensionsKeys } from '@/data/database-extensions/keys'
import { databaseIndexesKeys } from '@/data/database-indexes/keys'
import { databasePoliciesKeys } from '@/data/database-policies/keys'
import { databaseTriggerKeys } from '@/data/database-triggers/keys'
import { databaseKeys } from '@/data/database/keys'
import { enumeratedTypesKeys } from '@/data/enumerated-types/keys'
import { handleError } from '@/data/fetchers'
import { tableKeys } from '@/data/tables/keys'
import { tryParseJson } from '@/lib/helpers'
import { ResponseError } from '@/types'
export type MutationCategory = 'functions' | 'rls-policies'

/**
 * Classifies a SQL statement into a mutation category by simple keyword matching.
 * [Joshen] This is just very basic identification, but it could be extended further if needed.
 *
 * @param query Raw SQL text.
 * @returns `'functions'` if the query contains `create function` or
 *   `create or replace function`, `'rls-policies'` if it contains `create policy` or
 *   `alter policy`, otherwise `undefined`.
 */
export const identifyQueryType = (query: string): MutationCategory | undefined => {
  // Collapse every whitespace run (newlines, `\r\n`, tabs, repeated spaces) into a single
  // space so multi-line or oddly indented SQL still matches the keyword phrases below.
  const formattedQuery = query.toLowerCase().replace(/\s+/g, ' ')

  if (
    formattedQuery.includes('create function') ||
    formattedQuery.includes('create or replace function')
  ) {
    return 'functions'
  }
  if (formattedQuery.includes('create policy') || formattedQuery.includes('alter policy')) {
    return 'rls-policies'
  }
  return undefined
}
/**
 * Reports whether `query` calls any function that is not in `SAFE_FUNCTIONS`.
 *
 * A "call" is any run of word characters followed by optional whitespace and `(`. Each match,
 * including its trailing `(`, is trimmed, lowercased and compared exactly against the safe list.
 *
 * @deprecated [Joshen] Ideally we move away from this as this isn't a scalable way to deduce
 */
export const containsUnknownFunction = (query: string) => {
  const callSites = query.trim().toLowerCase().match(/\w+\s*\(/g) ?? []

  const isSafe = (callSite: string) => {
    const candidate = callSite.trim().toLowerCase()
    return SAFE_FUNCTIONS.some((safeFunc) => safeFunc === candidate)
  }

  // "Not every call is safe" is the same as "some call is unknown".
  return !callSites.every(isSafe)
}
/**
 * Heuristically decides whether `query` is a read-only SELECT.
 *
 * The query must start with `select`, must not contain any write keyword as a whole word, and
 * must not call a function outside the safe list (see `containsUnknownFunction`).
 *
 * @deprecated
 * [Joshen] This isn't really a scalable way to reduce this behaviour, we now have support
 * for a readonly connection string which we can use this to run queries, and is a much
 * clearer way to deduce if the query is read only or not
 */
export const isReadOnlySelect = (query: string): boolean => {
  const normalizedQuery = query.trim().toLowerCase()

  if (!normalizedQuery.startsWith('select')) return false

  // Whole-word match, so identifiers such as `created_at`, `updated_at` or `last_update` are
  // not mistaken for write keywords. A real keyword elsewhere in the query is still caught,
  // e.g. `select created_at from t; create table x ()`. A substring allow-list would wave that
  // through just because `created` also appears.
  const writeOperationRegex = /\b(insert|update|delete|alter|drop|create|replace|truncate)\b/
  if (writeOperationRegex.test(normalizedQuery)) return false

  if (containsUnknownFunction(normalizedQuery)) return false

  return true
}
/**
 * Returns true when at least one assistant message still contains a tool part in the
 * `approval-requested` state, i.e. a tool call waiting on the user's decision.
 */
export const hasPendingToolApproval = (messages: Pick<UIMessage, 'role' | 'parts'>[]) =>
  messages.some(
    ({ role, parts }) =>
      role === 'assistant' &&
      (parts?.some((part) => isToolUIPart(part) && part.state === 'approval-requested') ?? false)
  )
/**
 * Returns a copy of `messages` in which every assistant tool part still in the
 * `approval-requested` state has been switched to `output-denied`. The switched part carries an
 * `approved: false` approval whose reason says the user sent a follow-up message.
 *
 * Non-assistant messages are returned unchanged. Assistant messages are always shallow-copied,
 * with their parts re-mapped.
 */
export const resolvePendingToolApprovalsAsDenied = (messages: UIMessage[]): UIMessage[] => {
  const denyIfPending = (part: UIMessage['parts'][number]): UIMessage['parts'][number] => {
    if (isToolUIPart(part) && part.state === 'approval-requested') {
      return {
        ...part,
        state: 'output-denied',
        approval: {
          id: part.approval.id,
          approved: false,
          reason: 'Skipped because the user sent a follow-up message.',
        },
      } as UIMessage['parts'][number]
    }
    return part
  }

  return messages.map((message) =>
    message.role === 'assistant'
      ? ({ ...message, parts: message.parts?.map(denyIfPending) } as UIMessage)
      : message
  )
}
/**
 * Drops the first three `/`-separated segments of `pathname` and rejoins the rest.
 * For example, `/project/<ref>/database/tables` yields `database/tables`. Paths with three or
 * fewer segments yield `''`.
 */
const getContextKey = (pathname: string) => pathname.split('/').slice(3).join('/')
/**
 * Looks up the query keys registered for the Studio page identified by `pathname`.
 * NOTE(review): presumably used to refresh page data after the Assistant runs a mutation;
 * confirm against callers.
 *
 * @param ref Project ref passed to each key factory.
 * @param pathname Current route. The lookup key is the part after the first three
 *   `/`-separated segments, e.g. `database/tables`.
 * @param schema Schema used for the table and index keys. Defaults to `'public'`.
 * @returns The keys registered for that page, or `[]` when none are registered.
 */
export const getContextualInvalidationKeys = ({
  ref,
  pathname,
  schema = 'public',
}: {
  ref: string
  pathname: string
  schema?: string
}) => {
  const key = getContextKey(pathname)
  const keysByContext = {
    'auth/users': [authKeys.usersInfinite(ref)],
    'auth/policies': [databasePoliciesKeys.list(ref)],
    'database/functions': [databaseKeys.databaseFunctions(ref)],
    'database/tables': [tableKeys.list(ref, schema, true), tableKeys.list(ref, schema, false)],
    'database/triggers': [databaseTriggerKeys.list(ref)],
    'database/types': [enumeratedTypesKeys.list(ref)],
    'database/extensions': [databaseExtensionsKeys.list(ref)],
    'database/indexes': [databaseIndexesKeys.list(ref, schema)],
  } as const

  // Only own keys count. Without this, a path segment such as `constructor`, `toString` or
  // `__proto__` would resolve to an inherited Object.prototype member (not an array) instead
  // of falling back to [].
  if (!Object.prototype.hasOwnProperty.call(keysByContext, key)) return []
  return keysByContext[key as keyof typeof keysByContext]
}
/**
 * Error handler for the chat.
 *
 * Tries to parse `error.message` as JSON, then passes the parsed payload's `error` field (or the
 * parsed payload, or the original error) to `handleError`. If `handleError` throws, the thrown
 * value is shown as a toast:
 * - `ResponseError` / `Error`: its `message`.
 * - string: the string itself.
 * - anything else: a generic fallback.
 */
export const onErrorChat = (error: Error) => {
  const parsedError = error ? tryParseJson(error.message) : undefined

  const describeThrown = (thrown: unknown) => {
    if (thrown instanceof ResponseError) return thrown.message
    if (thrown instanceof Error) return thrown.message
    if (typeof thrown === 'string') return thrown
    return 'An unknown error occurred'
  }

  try {
    handleError(parsedError?.error || parsedError || error)
  } catch (thrown: unknown) {
    toast.error(describeThrown(thrown))
  }
}