diff --git a/wren-ui/migrations/20240711021133_create_thread_response_explain_table.js b/wren-ui/migrations/20240711021133_create_thread_response_explain_table.js new file mode 100644 index 000000000..64b084395 --- /dev/null +++ b/wren-ui/migrations/20240711021133_create_thread_response_explain_table.js @@ -0,0 +1,33 @@ +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.up = function (knex) { + return knex.schema.createTable('thread_response_explain', (table) => { + table.increments('id').comment('ID'); + table + .integer('thread_response_id') + .comment('Reference to thread_response.id'); + table + .foreign('thread_response_id') + .references('thread_response.id') + .onDelete('CASCADE'); + + table.string('query_id').nullable(); + table.string('status').nullable(); + table.jsonb('detail').nullable(); + table.jsonb('error').nullable(); + table.jsonb('analysis').nullable(); + + // timestamps + table.timestamps(true, true); + }); +}; + +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.down = function (knex) { + return knex.schema.dropTable('thread_response_explain'); +}; diff --git a/wren-ui/migrations/20240711082655_update_thread_response_table.js b/wren-ui/migrations/20240711082655_update_thread_response_table.js new file mode 100644 index 000000000..8cf1fa33b --- /dev/null +++ b/wren-ui/migrations/20240711082655_update_thread_response_table.js @@ -0,0 +1,22 @@ +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.up = function (knex) { + return knex.schema.alterTable('thread_response', (table) => { + table + .jsonb('corrections') + .nullable() + .comment('the corrections of the previous thread response'); // [{type, id, correct}, ...] + }); +}; + +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.down = function (knex) { + return knex.schema.alterTable('thread_response', (table) => { + table.dropColumn('corrections'); + }); +}; diff --git a/wren-ui/migrations/20240718090506_data_migrate_thread_response_explain.js b/wren-ui/migrations/20240718090506_data_migrate_thread_response_explain.js new file mode 100644 index 000000000..ccb9a00c9 --- /dev/null +++ b/wren-ui/migrations/20240718090506_data_migrate_thread_response_explain.js @@ -0,0 +1,33 @@ +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.up = async function (knex) { + const threadResponses = await knex('thread_response').select('*'); + if (threadResponses.length === 0) { + return; + } + const explainData = threadResponses.map((threadResponse) => { + const error = { + code: 'OLD_VERSION', + message: + 'Question asked before v0.8.0. Click "Retry" to generate manually.', + }; + return { + thread_response_id: threadResponse.id, + status: 'FAILED', + error: process.env.DB_TYPE === 'pg' ? error : JSON.stringify(error), + }; + }); + + await knex('thread_response_explain').insert(explainData); +}; + +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.down = async function (knex) { + // remove all data + await knex('thread_response_explain').delete(); +}; diff --git a/wren-ui/package.json b/wren-ui/package.json index 33f1dc17b..e20f77c13 100644 --- a/wren-ui/package.json +++ b/wren-ui/package.json @@ -60,6 +60,7 @@ "@typescript-eslint/parser": "6.18.0", "ace-builds": "^1.32.3", "antd": "4.20.4", + "clsx": "^2.1.1", "dayjs": "^1.11.11", "duckdb": "^0.10.1", "duckdb-async": "^0.10.0", diff --git a/wren-ui/src/apollo/client/graphql/__types__.ts b/wren-ui/src/apollo/client/graphql/__types__.ts index 9310c0d43..3691806fa 100644 --- a/wren-ui/src/apollo/client/graphql/__types__.ts +++ b/wren-ui/src/apollo/client/graphql/__types__.ts @@ -71,6 +71,14 @@ export type ConnectionInfo = { username?: Maybe; }; +export type CorrectionDetail = { + __typename?: 'CorrectionDetail'; + correction: Scalars['String']; + id: Scalars['Int']; + referenceNum: Scalars['Int']; + type: ReferenceType; +}; + export type CreateCalculatedFieldInput = { expression: ExpressionName; lineage: Array; @@ -78,6 +86,11 @@ export type CreateCalculatedFieldInput = { name: Scalars['String']; }; +export type CreateCorrectedThreadResponseInput = { + corrections: Array; + responseId: Scalars['Int']; +}; + export type CreateModelInput = { fields: Array; primaryKey?: InputMaybe; @@ -104,6 +117,19 @@ export type CreateThreadInput = { viewId?: InputMaybe; }; +export type CreateThreadResponseCorrectionInput = { + correction: Scalars['String']; + id: Scalars['Int']; + reference: Scalars['String']; + referenceNum: Scalars['Int']; + stepIndex: Scalars['Int']; + type: ReferenceType; +}; + +export type CreateThreadResponseExplainWhereInput = { + responseId: Scalars['Int']; +}; + export type CreateThreadResponseInput = { question?: InputMaybe; sql?: InputMaybe; @@ -142,9 +168,19 @@ export enum DataSourceName { POSTGRES = 'POSTGRES' } +export type DetailReference = { + __typename?: 'DetailReference'; + referenceId?: Maybe; + sqlLocation?: Maybe; + sqlSnippet?: Maybe; + summary: Scalars['String']; + type: ReferenceType; +}; + export type DetailStep = { __typename?: 'DetailStep'; cteName?: Maybe; + references?: Maybe>>; sql: Scalars['String']; summary: Scalars['String']; }; @@ -324,6 +360,13 @@ export type Error = { stacktrace?: Maybe>>; }; +export enum ExplainTaskStatus { + FAILED = 'FAILED', + FINISHED = 'FINISHED', + GENERATING = 'GENERATING', + UNDERSTANDING = 'UNDERSTANDING' +} + export enum ExpressionName { ABS = 'ABS', AVG = 'AVG', @@ -399,10 +442,12 @@ export type Mutation = { cancelAskingTask: Scalars['Boolean']; createAskingTask: Task; createCalculatedField: Scalars['JSON']; + createCorrectedThreadResponse: ThreadResponse; createModel: Scalars['JSON']; createRelation: Scalars['JSON']; createThread: Thread; createThreadResponse: ThreadResponse; + createThreadResponseExplain: Scalars['JSON']; createView: ViewInfo; deleteCalculatedField: Scalars['Boolean']; deleteModel: Scalars['Boolean']; @@ -448,6 +493,12 @@ export type MutationCreateCalculatedFieldArgs = { }; +export type MutationCreateCorrectedThreadResponseArgs = { + data: CreateCorrectedThreadResponseInput; + threadId: Scalars['Int']; +}; + + export type MutationCreateModelArgs = { data: CreateModelInput; }; @@ -469,6 +520,11 @@ export type MutationCreateThreadResponseArgs = { }; +export type MutationCreateThreadResponseExplainArgs = { + where: CreateThreadResponseExplainWhereInput; +}; + + export type MutationCreateViewArgs = { data: CreateViewInput; }; @@ -704,6 +760,20 @@ export type RecommendRelations = { relations: Array>; }; +export type ReferenceSqlLocation = { + __typename?: 'ReferenceSQLLocation'; + column: Scalars['Int']; + line: Scalars['Int']; +}; + +export enum ReferenceType { + FIELD = 'FIELD', + FILTER = 'FILTER', + GROUP_BY = 'GROUP_BY', + QUERY_FROM = 'QUERY_FROM', + SORTING = 'SORTING' +} + export type Relation = { __typename?: 'Relation'; fromColumnId: Scalars['Int']; @@ -827,8 +897,10 @@ export type Thread = { export type ThreadResponse = { __typename?: 'ThreadResponse'; + corrections?: Maybe>; detail?: Maybe; error?: Maybe; + explain?: Maybe; id: Scalars['Int']; question: Scalars['String']; status: AskingTaskStatus; @@ -843,6 +915,13 @@ export type ThreadResponseDetail = { view?: Maybe; }; +export type ThreadResponseExplainInfo = { + __typename?: 'ThreadResponseExplainInfo'; + error?: Maybe; + queryId?: Maybe; + status?: Maybe; +}; + export type ThreadUniqueWhereInput = { id: Scalars['Int']; }; diff --git a/wren-ui/src/apollo/client/graphql/home.generated.ts b/wren-ui/src/apollo/client/graphql/home.generated.ts index 483a252a3..d9466c129 100644 --- a/wren-ui/src/apollo/client/graphql/home.generated.ts +++ b/wren-ui/src/apollo/client/graphql/home.generated.ts @@ -5,7 +5,7 @@ import * as Apollo from '@apollo/client'; const defaultOptions = {} as const; export type CommonErrorFragment = { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null }; -export type CommonResponseFragment = { __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null }; +export type CommonResponseFragment = { __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null, references?: Array<{ __typename?: 'DetailReference', referenceId?: number | null, summary: string, type: Types.ReferenceType, sqlSnippet?: string | null, sqlLocation?: { __typename?: 'ReferenceSQLLocation', column: number, line: number } | null } | null> | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null, explain?: { __typename?: 'ThreadResponseExplainInfo', queryId?: string | null, status?: Types.ExplainTaskStatus | null, error?: any | null } | null, corrections?: Array<{ __typename?: 'CorrectionDetail', id: number, type: Types.ReferenceType, referenceNum: number, correction: string }> | null }; export type SuggestedQuestionsQueryVariables = Types.Exact<{ [key: string]: never; }>; @@ -29,14 +29,14 @@ export type ThreadQueryVariables = Types.Exact<{ }>; -export type ThreadQuery = { __typename?: 'Query', thread: { __typename?: 'DetailedThread', id: number, sql: string, summary: string, responses: Array<{ __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null }> } }; +export type ThreadQuery = { __typename?: 'Query', thread: { __typename?: 'DetailedThread', id: number, sql: string, summary: string, responses: Array<{ __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null, references?: Array<{ __typename?: 'DetailReference', referenceId?: number | null, summary: string, type: Types.ReferenceType, sqlSnippet?: string | null, sqlLocation?: { __typename?: 'ReferenceSQLLocation', column: number, line: number } | null } | null> | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null, explain?: { __typename?: 'ThreadResponseExplainInfo', queryId?: string | null, status?: Types.ExplainTaskStatus | null, error?: any | null } | null, corrections?: Array<{ __typename?: 'CorrectionDetail', id: number, type: Types.ReferenceType, referenceNum: number, correction: string }> | null }> } }; export type ThreadResponseQueryVariables = Types.Exact<{ responseId: Types.Scalars['Int']; }>; -export type ThreadResponseQuery = { __typename?: 'Query', threadResponse: { __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null } }; +export type ThreadResponseQuery = { __typename?: 'Query', threadResponse: { __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null, references?: Array<{ __typename?: 'DetailReference', referenceId?: number | null, summary: string, type: Types.ReferenceType, sqlSnippet?: string | null, sqlLocation?: { __typename?: 'ReferenceSQLLocation', column: number, line: number } | null } | null> | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null, explain?: { __typename?: 'ThreadResponseExplainInfo', queryId?: string | null, status?: Types.ExplainTaskStatus | null, error?: any | null } | null, corrections?: Array<{ __typename?: 'CorrectionDetail', id: number, type: Types.ReferenceType, referenceNum: number, correction: string }> | null } }; export type CreateAskingTaskMutationVariables = Types.Exact<{ data: Types.AskingTaskInput; @@ -65,7 +65,14 @@ export type CreateThreadResponseMutationVariables = Types.Exact<{ }>; -export type CreateThreadResponseMutation = { __typename?: 'Mutation', createThreadResponse: { __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null } }; +export type CreateThreadResponseMutation = { __typename?: 'Mutation', createThreadResponse: { __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null, references?: Array<{ __typename?: 'DetailReference', referenceId?: number | null, summary: string, type: Types.ReferenceType, sqlSnippet?: string | null, sqlLocation?: { __typename?: 'ReferenceSQLLocation', column: number, line: number } | null } | null> | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null, explain?: { __typename?: 'ThreadResponseExplainInfo', queryId?: string | null, status?: Types.ExplainTaskStatus | null, error?: any | null } | null, corrections?: Array<{ __typename?: 'CorrectionDetail', id: number, type: Types.ReferenceType, referenceNum: number, correction: string }> | null } }; + +export type CreateThreadResponseExplainMutationVariables = Types.Exact<{ + where: Types.CreateThreadResponseExplainWhereInput; +}>; + + +export type CreateThreadResponseExplainMutation = { __typename?: 'Mutation', createThreadResponseExplain: any }; export type UpdateThreadMutationVariables = Types.Exact<{ where: Types.ThreadUniqueWhereInput; @@ -96,6 +103,14 @@ export type GetNativeSqlQueryVariables = Types.Exact<{ export type GetNativeSqlQuery = { __typename?: 'Query', nativeSql: string }; +export type CreateCorrectedThreadResponseMutationVariables = Types.Exact<{ + threadId: Types.Scalars['Int']; + data: Types.CreateCorrectedThreadResponseInput; +}>; + + +export type CreateCorrectedThreadResponseMutation = { __typename?: 'Mutation', createCorrectedThreadResponse: { __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null, references?: Array<{ __typename?: 'DetailReference', referenceId?: number | null, summary: string, type: Types.ReferenceType, sqlSnippet?: string | null, sqlLocation?: { __typename?: 'ReferenceSQLLocation', column: number, line: number } | null } | null> | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null, explain?: { __typename?: 'ThreadResponseExplainInfo', queryId?: string | null, status?: Types.ExplainTaskStatus | null, error?: any | null } | null, corrections?: Array<{ __typename?: 'CorrectionDetail', id: number, type: Types.ReferenceType, referenceNum: number, correction: string }> | null } }; + export const CommonErrorFragmentDoc = gql` fragment CommonError on Error { code @@ -117,6 +132,16 @@ export const CommonResponseFragmentDoc = gql` summary sql cteName + references { + referenceId + summary + type + sqlSnippet + sqlLocation { + column + line + } + } } view { id @@ -125,6 +150,17 @@ export const CommonResponseFragmentDoc = gql` displayName } } + explain { + queryId + status + error + } + corrections { + id + type + referenceNum + correction + } } `; export const SuggestedQuestionsDocument = gql` @@ -435,14 +471,12 @@ export const CreateThreadResponseDocument = gql` createThreadResponse(threadId: $threadId, data: $data) { ...CommonResponse error { - code - shortMessage - message - stacktrace + ...CommonError } } } - ${CommonResponseFragmentDoc}`; + ${CommonResponseFragmentDoc} +${CommonErrorFragmentDoc}`; export type CreateThreadResponseMutationFn = Apollo.MutationFunction; /** @@ -470,6 +504,37 @@ export function useCreateThreadResponseMutation(baseOptions?: Apollo.MutationHoo export type CreateThreadResponseMutationHookResult = ReturnType; export type CreateThreadResponseMutationResult = Apollo.MutationResult; export type CreateThreadResponseMutationOptions = Apollo.BaseMutationOptions; +export const CreateThreadResponseExplainDocument = gql` + mutation CreateThreadResponseExplain($where: CreateThreadResponseExplainWhereInput!) { + createThreadResponseExplain(where: $where) +} + `; +export type CreateThreadResponseExplainMutationFn = Apollo.MutationFunction; + +/** + * __useCreateThreadResponseExplainMutation__ + * + * To run a mutation, you first call `useCreateThreadResponseExplainMutation` within a React component and pass it any options that fit your needs. + * When your component renders, `useCreateThreadResponseExplainMutation` returns a tuple that includes: + * - A mutate function that you can call at any time to execute the mutation + * - An object with fields that represent the current status of the mutation's execution + * + * @param baseOptions options that will be passed into the mutation, supported options are listed on: https://www.apollographql.com/docs/react/api/react-hooks/#options-2; + * + * @example + * const [createThreadResponseExplainMutation, { data, loading, error }] = useCreateThreadResponseExplainMutation({ + * variables: { + * where: // value for 'where' + * }, + * }); + */ +export function useCreateThreadResponseExplainMutation(baseOptions?: Apollo.MutationHookOptions) { + const options = {...defaultOptions, ...baseOptions} + return Apollo.useMutation(CreateThreadResponseExplainDocument, options); + } +export type CreateThreadResponseExplainMutationHookResult = ReturnType; +export type CreateThreadResponseExplainMutationResult = Apollo.MutationResult; +export type CreateThreadResponseExplainMutationOptions = Apollo.BaseMutationOptions; export const UpdateThreadDocument = gql` mutation UpdateThread($where: ThreadUniqueWhereInput!, $data: UpdateThreadInput!) { updateThread(where: $where, data: $data) { @@ -600,4 +665,42 @@ export function useGetNativeSqlLazyQuery(baseOptions?: Apollo.LazyQueryHookOptio } export type GetNativeSqlQueryHookResult = ReturnType; export type GetNativeSqlLazyQueryHookResult = ReturnType; -export type GetNativeSqlQueryResult = Apollo.QueryResult; \ No newline at end of file +export type GetNativeSqlQueryResult = Apollo.QueryResult; +export const CreateCorrectedThreadResponseDocument = gql` + mutation CreateCorrectedThreadResponse($threadId: Int!, $data: CreateCorrectedThreadResponseInput!) { + createCorrectedThreadResponse(threadId: $threadId, data: $data) { + ...CommonResponse + error { + ...CommonError + } + } +} + ${CommonResponseFragmentDoc} +${CommonErrorFragmentDoc}`; +export type CreateCorrectedThreadResponseMutationFn = Apollo.MutationFunction; + +/** + * __useCreateCorrectedThreadResponseMutation__ + * + * To run a mutation, you first call `useCreateCorrectedThreadResponseMutation` within a React component and pass it any options that fit your needs. + * When your component renders, `useCreateCorrectedThreadResponseMutation` returns a tuple that includes: + * - A mutate function that you can call at any time to execute the mutation + * - An object with fields that represent the current status of the mutation's execution + * + * @param baseOptions options that will be passed into the mutation, supported options are listed on: https://www.apollographql.com/docs/react/api/react-hooks/#options-2; + * + * @example + * const [createCorrectedThreadResponseMutation, { data, loading, error }] = useCreateCorrectedThreadResponseMutation({ + * variables: { + * threadId: // value for 'threadId' + * data: // value for 'data' + * }, + * }); + */ +export function useCreateCorrectedThreadResponseMutation(baseOptions?: Apollo.MutationHookOptions) { + const options = {...defaultOptions, ...baseOptions} + return Apollo.useMutation(CreateCorrectedThreadResponseDocument, options); + } +export type CreateCorrectedThreadResponseMutationHookResult = ReturnType; +export type CreateCorrectedThreadResponseMutationResult = Apollo.MutationResult; +export type CreateCorrectedThreadResponseMutationOptions = Apollo.BaseMutationOptions; \ No newline at end of file diff --git a/wren-ui/src/apollo/client/graphql/home.ts b/wren-ui/src/apollo/client/graphql/home.ts index 70758cfb4..35d64e426 100644 --- a/wren-ui/src/apollo/client/graphql/home.ts +++ b/wren-ui/src/apollo/client/graphql/home.ts @@ -22,6 +22,16 @@ const COMMON_RESPONSE = gql` summary sql cteName + references { + referenceId + summary + type + sqlSnippet + sqlLocation { + column + line + } + } } view { id @@ -30,6 +40,17 @@ const COMMON_RESPONSE = gql` displayName } } + explain { + queryId + status + error + } + corrections { + id + type + referenceNum + correction + } } `; @@ -139,10 +160,7 @@ export const CREATE_THREAD_RESPONSE = gql` createThreadResponse(threadId: $threadId, data: $data) { ...CommonResponse error { - code - shortMessage - message - stacktrace + ...CommonError } } } @@ -150,6 +168,14 @@ export const CREATE_THREAD_RESPONSE = gql` ${COMMON_ERROR} `; +export const CREATE_THREAD_RESPONSE_EXPLAIN = gql` + mutation CreateThreadResponseExplain( + $where: CreateThreadResponseExplainWhereInput! + ) { + createThreadResponseExplain(where: $where) + } +`; + export const UPDATE_THREAD = gql` mutation UpdateThread( $where: ThreadUniqueWhereInput! @@ -180,3 +206,17 @@ export const GET_NATIVE_SQL = gql` nativeSql(responseId: $responseId) } `; + +export const CREATE_CORRECTED_THREAD_RESPONSE = gql` + mutation CreateCorrectedThreadResponse( + $threadId: Int! + $data: CreateCorrectedThreadResponseInput! + ) { + createCorrectedThreadResponse(threadId: $threadId, data: $data) { + ...CommonResponse + error { + ...CommonError + } + } + } +`; diff --git a/wren-ui/src/apollo/server/adaptors/ibisAdaptor.ts b/wren-ui/src/apollo/server/adaptors/ibisAdaptor.ts index 6dcb7802f..942bc6155 100644 --- a/wren-ui/src/apollo/server/adaptors/ibisAdaptor.ts +++ b/wren-ui/src/apollo/server/adaptors/ibisAdaptor.ts @@ -87,6 +87,65 @@ export interface IbisQueryOptions extends IbisBaseOptions { limit?: number; } +export interface IbisQueryResponse { + columns: string[]; + data: any[]; + dtypes: Record; +} + +export interface SelectItemAnalysis { + alias: string; + expression: string; + properties: Record; +} + +export enum RelationType { + TABLE = 'TABLE', + SUBQUERY = 'SUBQUERY', + INNER_JOIN = 'INNER_JOIN', + LEFT_JOIN = 'LEFT_JOIN', + RIGHT_JOIN = 'RIGHT_JOIN', + FULL_JOIN = 'FULL_JOIN', + CROSS_JOIN = 'CROSS_JOIN', + IMPLICIT_JOIN = 'IMPLICIT_JOIN', +} + +export interface RelationAnalysis { + type: RelationType; + alias?: string; + tableName?: string; + left?: RelationAnalysis; + right?: RelationAnalysis; + criteria?: string; + // exist when type = subquery + body?: RelationAnalysis[]; + properties?: Record; +} + +export enum FilterType { + EXPR = 'EXPR', + // Logical expression + AND = 'AND', + OR = 'OR', +} +export interface FilterAnalysis { + type: FilterType; + node?: string; + left?: FilterAnalysis; + right?: FilterAnalysis; +} +export interface SortAnalysis { + expression: string; + ordering: 'ASCENDING' | 'DESCENDING'; +} +export interface QueryAnalysis { + selectItems?: SelectItemAnalysis[]; + relation?: RelationAnalysis; + filter?: FilterAnalysis; + groupByKeys?: string[][]; + sortings?: SortAnalysis; +} + export interface IIbisAdaptor { query: ( query: string, @@ -109,19 +168,35 @@ export interface IIbisAdaptor { mdl: Manifest, parameters: Record, ) => Promise; -} -export interface IbisQueryResponse { - columns: string[]; - data: any[]; - dtypes: Record; + analysisSqls: (mdl: Manifest, sqls: string[]) => Promise; } export class IbisAdaptor implements IIbisAdaptor { private ibisServerBaseUrl: string; constructor({ ibisServerEndpoint }: { ibisServerEndpoint: string }) { - this.ibisServerBaseUrl = `${ibisServerEndpoint}/v2/connector`; + this.ibisServerBaseUrl = `${ibisServerEndpoint}/v2`; + } + public async analysisSqls(mdl: Manifest, sqls: string[]) { + try { + const manifestStr = Buffer.from(JSON.stringify(mdl)).toString('base64'); + const res: AxiosResponse = await axios({ + method: 'get', + url: `${this.ibisServerBaseUrl}/analysis/sqls`, + data: { + manifestStr, + sqls, + }, + }); + return res.data; + } catch (err) { + logger.debug(`Got error when analysis sqls: ${err.response.data}`); + throw Errors.create(Errors.GeneralErrorCodes.IBIS_SERVER_ERROR, { + customMessage: err.response.data, + originalError: err, + }); + } } public async query( @@ -138,7 +213,7 @@ export class IbisAdaptor implements IIbisAdaptor { }; try { const res = await axios.post( - `${this.ibisServerBaseUrl}/${dataSourceUrlMap[dataSource]}/query`, + `${this.ibisServerBaseUrl}/connector/${dataSourceUrlMap[dataSource]}/query`, body, { params: { @@ -173,7 +248,7 @@ export class IbisAdaptor implements IIbisAdaptor { logger.debug(`Dry run sql from ibis with body:`); try { await axios.post( - `${this.ibisServerBaseUrl}/${dataSourceUrlMap[dataSource]}/query?dryRun=true`, + `${this.ibisServerBaseUrl}/connector/${dataSourceUrlMap[dataSource]}/query?dryRun=true`, body, ); logger.debug(`Ibis server Dry run success`); @@ -199,7 +274,7 @@ export class IbisAdaptor implements IIbisAdaptor { try { logger.debug(`Getting tables from ibis`); const res: AxiosResponse = await axios.post( - `${this.ibisServerBaseUrl}/${dataSourceUrlMap[dataSource]}/metadata/tables`, + `${this.ibisServerBaseUrl}/connector/${dataSourceUrlMap[dataSource]}/metadata/tables`, body, ); return res.data; @@ -226,7 +301,7 @@ export class IbisAdaptor implements IIbisAdaptor { try { logger.debug(`Getting constraint from ibis`); const res: AxiosResponse = await axios.post( - `${this.ibisServerBaseUrl}/${dataSourceUrlMap[dataSource]}/metadata/constraints`, + `${this.ibisServerBaseUrl}/connector/${dataSourceUrlMap[dataSource]}/metadata/constraints`, body, ); return res.data; @@ -258,7 +333,7 @@ export class IbisAdaptor implements IIbisAdaptor { try { logger.debug(`Run validation rule "${validationRule}" with ibis`); await axios.post( - `${this.ibisServerBaseUrl}/${dataSourceUrlMap[dataSource]}/validate/${snakeCase(validationRule)}`, + `${this.ibisServerBaseUrl}/connector/${dataSourceUrlMap[dataSource]}/validate/${snakeCase(validationRule)}`, body, ); return { valid: true, message: null }; diff --git a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts index 507b11d92..0a9c7535a 100644 --- a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts +++ b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts @@ -55,6 +55,15 @@ export interface AsyncQueryResponse { queryId: string; } +export type ExplainResult = AIServiceResponse; + +export enum ExplainPipelineStatus { + UNDERSTANDING = 'UNDERSTANDING', + GENERATING = 'GENERATING', + FINISHED = 'FINISHED', + FAILED = 'FAILED', +} + export enum AskResultStatus { UNDERSTANDING = 'UNDERSTANDING', SEARCHING = 'SEARCHING', @@ -64,6 +73,13 @@ export enum AskResultStatus { STOPPED = 'STOPPED', } +export interface StepAnalysisResult { + sql: string; + summary: string; + cte_name?: string; + sql_analysis_results: any; +} + // if it's view, viewId will be returned as well. It means the candidate is originally saved in mdl as a view. // if it's llm, viewId will not be returned. It means the candidate is generated by AI service. export enum AskCandidateType { @@ -71,7 +87,23 @@ export enum AskCandidateType { LLM = 'LLM', } -export interface AskResponse { +// The enum key's name refer to schema.ts > ReferenceType +// It helps mapping AI service enum values: selectItems, relation, filter, sortings, groupByKeys +export enum ExplanationType { + FIELD = 'selectItems', + QUERY_FROM = 'relation', + FILTER = 'filter', + SORTING = 'sortings', + GROUP_BY = 'groupByKeys', +} + +// UI currently only support nl_expression +export enum ExpressionType { + SQL_EXPRESSION = 'sql_expression', + NL_EXPRESSION = 'nl_expression', +} + +export interface AIServiceResponse { status: S; response: R | null; error: WrenAIError | null; @@ -83,7 +115,7 @@ export interface AskDetailInput { summary: string; } -export type AskDetailResult = AskResponse< +export type AskDetailResult = AIServiceResponse< { description: string; steps: AskStep[]; @@ -91,7 +123,7 @@ export type AskDetailResult = AskResponse< AskResultStatus >; -export type AskResult = AskResponse< +export type AskResult = AIServiceResponse< Array<{ type: AskCandidateType; sql: string; @@ -101,7 +133,29 @@ export type AskResult = AskResponse< AskResultStatus >; -const getAISerciceError = (error: any) => { +export interface CorrectionObject { + type: T; + value: string; +} + +export interface AskCorrectionInput { + before: CorrectionObject; + after: CorrectionObject; +} + +export interface AskStepWithCorrectionsInput { + summary: string; + sql: string; + cte_name: string; + corrections: AskCorrectionInput[]; +} + +export interface RegenerateAskDetailInput { + description: string; + steps: AskStepWithCorrectionsInput[]; +} + +const getAIServiceError = (error: any) => { const { data } = error.response || {}; return data?.detail ? `${error.message}, detail: ${data.detail}` @@ -129,6 +183,15 @@ export interface IWrenAIAdaptor { */ generateAskDetail(input: AskDetailInput): Promise; getAskDetailResult(queryId: string): Promise; + explain( + question: string, + stepAnalysisResult: StepAnalysisResult[], + ): Promise; + getExplainResult(queryId: string): Promise; + regenerateAskDetail( + input: RegenerateAskDetailInput, + ): Promise; + getRegeneratedAskDetailResult(queryId: string): Promise; } export class WrenAIAdaptor implements IWrenAIAdaptor { @@ -148,11 +211,11 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { const res = await axios.post(`${this.wrenAIBaseEndpoint}/v1/asks`, { query: input.query, id: input.deployId, - history: this.transfromHistoryInput(input.history), + history: this.transformHistoryInput(input.history), }); return { queryId: res.data.query_id }; } catch (err: any) { - logger.debug(`Got error when asking wren AI: ${getAISerciceError(err)}`); + logger.debug(`Got error when asking wren AI: ${getAIServiceError(err)}`); throw err; } } @@ -164,7 +227,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { status: 'stopped', }); } catch (err: any) { - logger.debug(`Got error when canceling ask: ${getAISerciceError(err)}`); + logger.debug(`Got error when canceling ask: ${getAIServiceError(err)}`); throw err; } } @@ -178,7 +241,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { return this.transformAskResult(res.data); } catch (err: any) { logger.debug( - `Got error when getting ask result: ${getAISerciceError(err)}`, + `Got error when getting ask result: ${getAIServiceError(err)}`, ); // throw err; throw Errors.create(Errors.GeneralErrorCodes.INTERNAL_SERVER_ERROR, { @@ -187,6 +250,49 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { } } + public async explain( + question: string, + stepAnalysisResult: StepAnalysisResult[], + ): Promise { + try { + logger.info({ + question, + steps_with_analysis_results: stepAnalysisResult, + }); + const res = await axios.post( + `${this.wrenAIBaseEndpoint}/v1/sql-explanations`, + { + question, + steps_with_analysis_results: stepAnalysisResult, + }, + ); + return { queryId: res.data.query_id }; + } catch (err: any) { + logger.debug(`Got error when explaining: ${getAIServiceError(err)}`); + throw err; + } + } + + public async getExplainResult(queryId: string): Promise { + // make GET request /v1/sql-explanations/:query_id/result to get the result + try { + const res = await axios.get( + `${this.wrenAIBaseEndpoint}/v1/sql-explanations/${queryId}/result`, + ); + const { status, error } = this.transformStatusAndError(res.data); + return { + status: status as ExplainPipelineStatus, + response: res.data.response, + error, + }; + } catch (err: any) { + logger.debug( + `Got error when getting explain result: ${getAIServiceError(err)}`, + ); + throw err; + } + } + /** * After you choose a candidate, you can request AI service to generate the detail. */ @@ -202,7 +308,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { return { queryId: res.data.query_id }; } catch (err: any) { logger.debug( - `Got error when generating ask detail: ${getAISerciceError(err)}`, + `Got error when generating ask detail: ${getAIServiceError(err)}`, ); throw err; } @@ -217,7 +323,37 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { return this.transformAskDetailResult(res.data); } catch (err: any) { logger.debug( - `Got error when getting ask detail result: ${getAISerciceError(err)}`, + `Got error when getting ask detail result: ${getAIServiceError(err)}`, + ); + throw err; + } + } + + public async regenerateAskDetail(input: RegenerateAskDetailInput) { + try { + const res = await axios.post( + `${this.wrenAIBaseEndpoint}/v1/sql-regenerations`, + input, + ); + return { queryId: res.data.query_id }; + } catch (err: any) { + logger.debug( + `Got error when regenerating ask detail: ${getAIServiceError(err)}`, + ); + throw err; + } + } + + public async getRegeneratedAskDetailResult(queryId: string) { + // make GET request /v1/sql-regenerations/:query_id/result to get the result + try { + const res = await axios.get( + `${this.wrenAIBaseEndpoint}/v1/sql-regenerations/${queryId}/result`, + ); + return this.transformAskDetailResult(res.data); + } catch (err: any) { + logger.debug( + `Got error when getting regenerated ask detail result: ${getAIServiceError(err)}`, ); throw err; } @@ -286,9 +422,9 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { const res = await axios.get( `${this.wrenAIBaseEndpoint}/v1/semantics-preparations/${deployId}/status`, ); - if (res.data.error) { + if (res.data?.error?.message) { // passing AI response error string to catch block - throw new Error(res.data.error); + throw new Error(res.data.error.message); } return res.data?.status.toUpperCase() as WrenAISystemStatus; } catch (err: any) { @@ -309,7 +445,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { })); return { - status, + status: status as AskResultStatus, error, response: candidates, }; @@ -326,7 +462,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { })); return { - status, + status: status as AskResultStatus, error, response: { description: body?.response?.description, @@ -336,7 +472,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { } private transformStatusAndError(body: any): { - status: AskResultStatus; + status: AskResultStatus | ExplainPipelineStatus; error?: { code: Errors.GeneralErrorCodes; message: string; @@ -344,9 +480,11 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { } | null; } { // transform status to enum - const status = AskResultStatus[ - body?.status?.toUpperCase() - ] as AskResultStatus; + const status = + (AskResultStatus[body?.status?.toUpperCase()] as AskResultStatus) || + (ExplainPipelineStatus[ + body.status + ]?.toUpperCase() as ExplainPipelineStatus); if (!status) { throw new Error(`Unknown ask status: ${body?.status}`); @@ -380,7 +518,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { }; } - private transfromHistoryInput(history: AskHistory) { + private transformHistoryInput(history: AskHistory) { if (!history) { return null; } diff --git a/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts b/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts index 228ac9ee7..389bf0544 100644 --- a/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts +++ b/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts @@ -94,6 +94,9 @@ export interface IWrenEngineAdaptor { sql: string, options: WrenEngineDryRunOption, ): Promise; + + // analysis + analysisSql(sql: string, mdl: Manifest): Promise; } export class WrenEngineAdaptor implements IWrenEngineAdaptor { @@ -105,6 +108,7 @@ export class WrenEngineAdaptor implements IWrenEngineAdaptor { private dryPlanUrlPath = '/v1/mdl/dry-plan'; private dryRunUrlPath = '/v1/mdl/dry-run'; private validateUrlPath = '/v1/mdl/validate'; + private analysisUrlPath = '/v1/analysis/sql'; constructor({ wrenEngineEndpoint }: { wrenEngineEndpoint: string }) { this.wrenEngineBaseEndpoint = wrenEngineEndpoint; @@ -314,16 +318,20 @@ export class WrenEngineAdaptor implements IWrenEngineAdaptor { } } - private async getDeployStatus(): Promise { + public async analysisSql(sql: string, mdl: Manifest) { try { - const res = await axios.get( - `${this.wrenEngineBaseEndpoint}/v1/mdl/status`, + const url = new URL(this.analysisUrlPath, this.wrenEngineBaseEndpoint); + const headers = { + 'Content-Type': 'application/json', + }; + const res = await axios.post( + url.href, + { sql, manifest: mdl }, + { headers }, ); - return res.data as WrenEngineDeployStatusResponse; + return res.data; } catch (err: any) { - logger.debug( - `WrenEngine: Got error when getting deploy status: ${err.message}`, - ); + logger.debug(`Got error when analyzing sql: ${err.message}`); throw err; } } diff --git a/wren-ui/src/apollo/server/backgroundTrackers/explainBackgroundTracker.ts b/wren-ui/src/apollo/server/backgroundTrackers/explainBackgroundTracker.ts new file mode 100644 index 000000000..a37c38c44 --- /dev/null +++ b/wren-ui/src/apollo/server/backgroundTrackers/explainBackgroundTracker.ts @@ -0,0 +1,241 @@ +import { + ExplainPipelineStatus, + ExplanationType, + IWrenAIAdaptor, +} from '../adaptors/wrenAIAdaptor'; +import { + ExplainDetail, + IThreadResponseExplainRepository, + ThreadResponseExplain, +} from '../repositories/threadResponseExplainRepository'; +import { + IThreadResponseRepository, + ThreadResponseDetail, +} from '../repositories/threadResponseRepository'; +import { Telemetry } from '../telemetry/telemetry'; +import { GeneralErrorCodes } from '../utils/error'; +import { findAnalysisById, reverseEnum } from '../utils'; +import { BackgroundTracker } from './index'; +import { getLogger } from '@server/utils/logger'; +import { RelationType } from '../adaptors/ibisAdaptor'; + +const logger = getLogger('ExplainBackgroundTracker'); +logger.level = 'debug'; + +export class ThreadResponseExplainBackgroundTracker extends BackgroundTracker { + // tasks is a kv pair of task id and thread response + private wrenAIAdaptor: IWrenAIAdaptor; + private threadResponseRepository: IThreadResponseRepository; + private threadResponseExplainRepository: IThreadResponseExplainRepository; + + constructor({ + telemetry, + wrenAIAdaptor, + threadResponseRepository, + threadResponseExplainRepository, + }: { + telemetry: Telemetry; + wrenAIAdaptor: IWrenAIAdaptor; + threadResponseRepository: IThreadResponseRepository; + threadResponseExplainRepository: IThreadResponseExplainRepository; + }) { + super(); + this.telemetry = telemetry; + this.wrenAIAdaptor = wrenAIAdaptor; + this.threadResponseRepository = threadResponseRepository; + this.threadResponseExplainRepository = threadResponseExplainRepository; + this.intervalTime = 1000; + this.start(); + } + + protected start() { + logger.info('Explain Background tracker started'); + setInterval(() => { + const jobs = Object.values(this.tasks).map( + (threadResponseExplain) => async () => { + // check if same job is running + if (this.runningJobs.has(threadResponseExplain.id)) { + return; + } + + // mark the job as running + this.runningJobs.add(threadResponseExplain.id); + + // get the latest result from AI service + const result = await this.wrenAIAdaptor.getExplainResult( + threadResponseExplain.queryId, + ); + + // if status not changed, early return + if (threadResponseExplain.status === result.status) { + logger.debug( + `Explain job ${threadResponseExplain.id} status not changed, skipping`, + ); + this.runningJobs.delete(threadResponseExplain.id); + return; + } + + // update database + logger.debug( + `Explain job ${threadResponseExplain.id} status changed to "${result.status}", updating`, + ); + + const updatedExplain = + await this.threadResponseExplainRepository.updateOne( + threadResponseExplain.id, + { + status: result.status, + detail: result.response, + error: result.error, + }, + ); + this.tasks[threadResponseExplain.id] = updatedExplain; + + // remove the task from tracker if it is finalized + if (this.isFinalized(result.status)) { + logger.debug( + `Explain job ${threadResponseExplain.id} is finalized`, + ); + if (this.isSucceed(result.status)) { + try { + await this.mergeExplanationIntoThreadResponse( + threadResponseExplain.id, + ); + logger.debug(`Explain job ${threadResponseExplain.id} done`); + } catch (error: any) { + logger.error( + `Explain job ${threadResponseExplain.id} merge failed: ${error}`, + ); + await this.threadResponseExplainRepository.updateOne( + threadResponseExplain.id, + { + error: { + code: GeneralErrorCodes.MERGE_THREAD_RESPONSE_ERROR, + message: + typeof error === 'object' + ? JSON.stringify(error) + : error, + }, + }, + ); + } + } + delete this.tasks[threadResponseExplain.id]; + logger.debug(`Explain job ${threadResponseExplain.id} deleted`); + } + + // mark the job as finished + this.runningJobs.delete(threadResponseExplain.id); + return result; + }, + ); + + // run the jobs + Promise.allSettled(jobs.map((job) => job())).then((results) => { + // show reason of rejection + results.forEach((result, index) => { + if (result.status === 'rejected') { + this.telemetry.send_event('explain_job_failed', { + status: result.status, + reason: result.reason, + }); + logger.error(`Explain Job ${index} failed: ${result.reason}`); + } + }); + }); + }, this.intervalTime); + } + + public addTask(threadResponseExplain: ThreadResponseExplain) { + this.tasks[threadResponseExplain.id] = threadResponseExplain; + } + + public getTasks() { + return this.tasks; + } + + public isFinalized = (status: ExplainPipelineStatus) => { + const _status = status.toUpperCase(); + return ( + _status === ExplainPipelineStatus.FAILED || + _status === ExplainPipelineStatus.FINISHED + ); + }; + + public isSucceed = (status: ExplainPipelineStatus) => { + return status.toUpperCase() === ExplainPipelineStatus.FINISHED; + }; + + private async mergeExplanationIntoThreadResponse(threadResponseExplainId) { + const threadResponseExplain = + await this.threadResponseExplainRepository.findOneBy({ + id: threadResponseExplainId, + }); + const threadResponse = await this.threadResponseRepository.findOneBy({ + id: threadResponseExplain.threadResponseId, + }); + logger.debug( + `Start Merging explanation ${threadResponseExplain.id} into thread response ${threadResponse.id}`, + ); + // merge explain response to thread response + const detailWithExplanation = this.mergeExplanationToThreadResponseDetail( + threadResponse.detail, + threadResponseExplain.detail, + threadResponseExplain.analysis, + ); + logger.debug( + `Merge explanation ${threadResponseExplain.id} into thread response ${threadResponse.id} completed`, + ); + await this.threadResponseRepository.updateOne(threadResponse.id, { + detail: detailWithExplanation as any, + }); + logger.debug(`ThreadResponse ${threadResponse.id} detail updated`); + } + + // reorder explanation id and attach sql location to each explanation + // then attach the explanation to the thread response detail + private mergeExplanationToThreadResponseDetail( + detail: ThreadResponseDetail, + explanations: ExplainDetail[], + analyses: object, + ) { + const toReferenceType = reverseEnum(ExplanationType); + // reorder reference id + let id = 1; + const steps = Object.entries(detail.steps).map(([stepId, step]) => { + const analysesOfStep = analyses[stepId]; + const explanationOfStep = explanations[stepId]; + const references = explanationOfStep.map((explanation) => { + const analysis = findAnalysisById( + analysesOfStep, + Number(explanation.payload.id), + ); + // remove previous id + const payload = { ...explanation.payload }; + const sqlLocation = Object.values(RelationType).includes(analysis?.type) + ? analysis.criteria?.nodeLocation + : analysis?.nodeLocation; + delete payload.id; + return { + referenceId: id++, + type: toReferenceType[explanation.type], + sqlSnippet: + (explanation.payload as any).expression || + (explanation.payload as any).criteria || + (analysis as any).tableName, + summary: explanation.payload.explanation || '', + sqlLocation, + }; + }); + + return { + ...step, + references, + }; + }); + return { + ...detail, + steps, + }; + } +} diff --git a/wren-ui/src/apollo/server/backgroundTrackers/index.ts b/wren-ui/src/apollo/server/backgroundTrackers/index.ts new file mode 100644 index 000000000..d65de1bdb --- /dev/null +++ b/wren-ui/src/apollo/server/backgroundTrackers/index.ts @@ -0,0 +1,13 @@ +import { Telemetry } from '../telemetry/telemetry'; + +export abstract class BackgroundTracker { + protected tasks: Record = {}; + protected intervalTime: number = 1000; + protected runningJobs: Set = new Set(); + protected telemetry: Telemetry; + + protected abstract start(): void; + public abstract addTask(task: R): void; + public abstract getTasks(): Record; + public abstract isFinalized(statue: any): boolean; +} diff --git a/wren-ui/src/apollo/server/backgroundTrackers/threadResponseBackgroundTracker.ts b/wren-ui/src/apollo/server/backgroundTrackers/threadResponseBackgroundTracker.ts new file mode 100644 index 000000000..297822944 --- /dev/null +++ b/wren-ui/src/apollo/server/backgroundTrackers/threadResponseBackgroundTracker.ts @@ -0,0 +1,149 @@ +import { + AskDetailResult, + AskResultStatus, + IWrenAIAdaptor, +} from '../adaptors/wrenAIAdaptor'; +import { + IThreadResponseRepository, + ThreadResponse, +} from '../repositories/threadResponseRepository'; +import { Telemetry } from '../telemetry/telemetry'; +import { BackgroundTracker } from './index'; +import { getLogger } from '@server/utils/logger'; + +const logger = getLogger('ThreadResponseBackgroundTracker'); +logger.level = 'debug'; + +export class ThreadResponseBackgroundTracker extends BackgroundTracker { + // tasks is a kv pair of task id and thread response + private wrenAIAdaptor: IWrenAIAdaptor; + private threadResponseRepository: IThreadResponseRepository; + private isRegenerated: boolean = false; + + constructor({ + telemetry, + wrenAIAdaptor, + threadResponseRepository, + isRegenerated, + }: { + telemetry: Telemetry; + wrenAIAdaptor: IWrenAIAdaptor; + threadResponseRepository: IThreadResponseRepository; + isRegenerated?: boolean; + }) { + super(); + this.telemetry = telemetry; + this.wrenAIAdaptor = wrenAIAdaptor; + this.threadResponseRepository = threadResponseRepository; + this.isRegenerated = isRegenerated; + this.start(); + } + + public start() { + logger.info(`${this.strategy.name} background tracker started`); + setInterval(() => { + const jobs = Object.values(this.tasks).map( + (threadResponse) => async () => { + // check if same job is running + if (this.runningJobs.has(threadResponse.id)) { + return; + } + + // mark the job as running + this.runningJobs.add(threadResponse.id); + + // get the latest result from AI service + const result = await this.strategy.getAskDetailResult( + threadResponse.queryId, + ); + + // check if status change + if (threadResponse.status === result.status) { + // mark the job as finished + logger.debug( + `Job ${threadResponse.id} status not changed, finished`, + ); + this.runningJobs.delete(threadResponse.id); + return; + } + + // update database + logger.debug(`Job ${threadResponse.id} status changed, updating`); + await this.threadResponseRepository.updateOne(threadResponse.id, { + status: result.status, + detail: result.response, + error: result.error, + }); + + // remove the task from tracker if it is finalized + if (this.isFinalized(result.status)) { + this.strategy.sendTelemetry(threadResponse, result); + logger.debug(`Job ${threadResponse.id} is finalized, removing`); + delete this.tasks[threadResponse.id]; + } + + // mark the job as finished + this.runningJobs.delete(threadResponse.id); + }, + ); + + // run the jobs + Promise.allSettled(jobs.map((job) => job())).then((results) => { + // show reason of rejection + results.forEach((result, index) => { + if (result.status === 'rejected') { + logger.error(`Job ${index} failed: ${result.reason}`); + } + }); + }); + }, this.intervalTime); + } + + public addTask(threadResponse: ThreadResponse) { + this.tasks[threadResponse.id] = threadResponse; + } + + public getTasks() { + return this.tasks; + } + + private get strategy() { + // For thread response + const strategy = { + name: 'ThreadResponse', + getAskDetailResult: (queryId) => + this.wrenAIAdaptor.getAskDetailResult(queryId), + sendTelemetry: ( + threadResponse: ThreadResponse, + result: AskDetailResult, + ) => { + this.telemetry.send_event('question_answered', { + question: threadResponse.question, + result, + }); + }, + }; + // For regenerated thread response + if (this.isRegenerated) { + strategy.name = 'RegeneratedThreadResponse'; + strategy.getAskDetailResult = (queryId) => + this.wrenAIAdaptor.getRegeneratedAskDetailResult(queryId); + strategy.sendTelemetry = (threadResponse, result) => { + this.telemetry.send_event('regenerated_question_answered', { + question: threadResponse.question, + corrections: threadResponse.corrections, + result, + }); + }; + } + return strategy; + } + + public isFinalized = (status: AskResultStatus) => { + return ( + status === AskResultStatus.FAILED || + status === AskResultStatus.FINISHED || + status === AskResultStatus.STOPPED + ); + }; +} diff --git a/wren-ui/src/apollo/server/repositories/baseRepository.ts b/wren-ui/src/apollo/server/repositories/baseRepository.ts index 87453db13..96f23385c 100644 --- a/wren-ui/src/apollo/server/repositories/baseRepository.ts +++ b/wren-ui/src/apollo/server/repositories/baseRepository.ts @@ -1,9 +1,17 @@ import { Knex } from 'knex'; import { camelCase, isPlainObject, mapKeys, snakeCase } from 'lodash'; +export enum Order { + ASC = 'asc', + DESC = 'desc', +} +export interface OrderBy { + column: string; + order: Order; +} export interface IQueryOptions { tx?: Knex.Transaction; - order?: string; + orderBy?: OrderBy[]; limit?: number; } @@ -83,8 +91,8 @@ export class BaseRepository implements IBasicRepository { const query = executer(this.tableName).where( this.transformToDBData(filter), ); - if (queryOptions?.order) { - query.orderBy(queryOptions.order); + if (queryOptions?.orderBy?.length) { + query.orderBy(queryOptions.orderBy); } const result = await query; return result.map(this.transformFromDBData); @@ -93,8 +101,8 @@ export class BaseRepository implements IBasicRepository { public async findAll(queryOptions?: IQueryOptions) { const executer = queryOptions?.tx ? queryOptions.tx : this.knex; const query = executer(this.tableName); - if (queryOptions?.order) { - query.orderBy(queryOptions.order); + if (queryOptions?.orderBy?.length) { + query.orderBy(queryOptions.orderBy); } if (queryOptions?.limit) { query.limit(queryOptions.limit); diff --git a/wren-ui/src/apollo/server/repositories/projectRepository.ts b/wren-ui/src/apollo/server/repositories/projectRepository.ts index 4afb41ce7..c1d984eef 100644 --- a/wren-ui/src/apollo/server/repositories/projectRepository.ts +++ b/wren-ui/src/apollo/server/repositories/projectRepository.ts @@ -1,5 +1,5 @@ import { Knex } from 'knex'; -import { BaseRepository, IBasicRepository } from './baseRepository'; +import { BaseRepository, IBasicRepository, Order } from './baseRepository'; import { camelCase, isPlainObject, @@ -87,7 +87,7 @@ export class ProjectRepository public async getCurrentProject() { const projects = await this.findAll({ - order: 'id', + orderBy: [{ column: 'id', order: Order.ASC }], limit: 1, }); if (!projects.length) { diff --git a/wren-ui/src/apollo/server/repositories/threadResponseExplainRepository.ts b/wren-ui/src/apollo/server/repositories/threadResponseExplainRepository.ts new file mode 100644 index 000000000..784239f32 --- /dev/null +++ b/wren-ui/src/apollo/server/repositories/threadResponseExplainRepository.ts @@ -0,0 +1,193 @@ +import { Knex } from 'knex'; +import { + BaseRepository, + IBasicRepository, + IQueryOptions, +} from './baseRepository'; +import { camelCase, isPlainObject, mapKeys, mapValues } from 'lodash'; +import { ExplainPipelineStatus } from '../adaptors/wrenAIAdaptor'; +import { getConfig } from '../config'; +const config = getConfig(); +export interface DetailStep { + summary: string; + sql: string; + cteName: string; +} + +export enum ExplainType { + FILTER = 'filter', + GROUP_BY_KEY = 'groupByKeys', + RELATION = 'relation', + SELECT_ITEMS = 'selectItems', + SORTINGS = 'sortings', +} + +export interface ExprSource { + expression: string; + sourceDataset: string; +} + +export interface FilterPayload { + id?: number; + expression: string; + explanation: string; +} + +export interface GroupByPayload { + id?: number; + expression: string; + explanation: string; +} + +export interface RelationPayload { + id?: number; + type: string; + criteria?: string; + exprSources?: ExprSource[]; + tableName?: string; + explanation?: string; +} +export interface SelectItemsPayload { + id?: number; + alias: string; + expression: string; + isFunctionCallOrMathematicalOperation: boolean; + explanation: string; +} +export interface SortingPayload { + id?: number; + expression: string; + explanation: string; +} + +export type ExplainPayload = + | FilterPayload + | GroupByPayload + | RelationPayload + | SelectItemsPayload + | SortingPayload; + +export interface ExplainDetail { + type: ExplainType; + payload: ExplainPayload; +} + +export interface ThreadResponseExplain { + id: number; // ID + threadResponseId: number; // Reference to thread_response.id + queryId: string; // explain pipeline query ID + status: ExplainPipelineStatus; // explain pipeline status + detail: ExplainDetail[]; // explain detail + error: object; // explain error + analysis: object; // analysis result +} + +export interface IThreadResponseExplainRepository + extends IBasicRepository { + findAllByThread(threadId: number): Promise; +} + +export class ThreadResponseExplainRepository + extends BaseRepository + implements IThreadResponseExplainRepository +{ + constructor(knexPg: Knex) { + super({ knexPg, tableName: 'thread_response_explain' }); + } + public async findAllByThread( + threadId: number, + ): Promise { + if (config.dbType === 'pg') { + return this.knex('thread_response as tr') + .join( + this.knex(this.tableName) + .distinctOn('thread_response_id') + .select('id', 'thread_response_id', 'detail', 'error', 'created_at') + .orderBy([ + 'thread_response_id', + { column: 'created_at', order: 'desc' }, + ]) + .as('tre'), + 'tre.thread_response_id', + 'tr.id', + ) + .select('*') + .where('tr.thread_id', threadId) + .then((results) => results.map(this.transformFromDBData)); + } + return this.knex('thread_response as tr') + .join( + this.knex(this.tableName) + .select() + .whereIn( + ['thread_response_id', 'id'], + this.knex(this.tableName) + .select('thread_response_id') + .max('id') + .groupBy('thread_response_id'), + ) + .as('tre'), + 'tr.id', + 'tre.thread_response_id', + ) + .select('tre.*') + .where('tr.thread_id', threadId) + .then((results) => results.map(this.transformFromDBData)); + } + + public async createOne( + data: Partial, + queryOptions?: IQueryOptions, + ) { + const transformedData = { + ...data, + detail: data.detail ? JSON.stringify(data.detail) : null, + error: data.error ? JSON.stringify(data.error) : null, + analysis: data.analysis ? JSON.stringify(data.analysis) : null, + } as any; + const executer = queryOptions?.tx ? queryOptions.tx : this.knex; + const [result] = await executer(this.tableName) + .insert(this.transformToDBData(transformedData)) + .returning('*'); + return this.transformFromDBData(result); + } + + public async updateOne( + id: number, + data: Partial, + queryOptions?: IQueryOptions, + ) { + const transformedData = { + ...data, + detail: data.detail ? JSON.stringify(data.detail) : null, + error: data.error ? JSON.stringify(data.error) : null, + }; + const executer = queryOptions?.tx ? queryOptions.tx : this.knex; + const [result] = await executer(this.tableName) + .where({ id }) + .update(transformedData) + .returning('*'); + return this.transformFromDBData(result); + } + + protected override transformFromDBData = ( + data: any, + ): ThreadResponseExplain => { + if (!isPlainObject(data)) { + throw new Error('Unexpected dbdata'); + } + const camelCaseData = mapKeys(data, (_value, key) => camelCase(key)); + const formattedData = mapValues(camelCaseData, (value, key) => { + if (['error', 'detail', 'analysis'].includes(key)) { + // The value from Sqlite will be string type, while the value from PG is JSON object + if (typeof value === 'string') { + return value ? JSON.parse(value) : value; + } else { + return value; + } + } + return value; + }) as ThreadResponseExplain; + return formattedData; + }; +} diff --git a/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts b/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts index 47690d172..0dfa79d9a 100644 --- a/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts +++ b/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts @@ -4,13 +4,32 @@ import { IBasicRepository, IQueryOptions, } from './baseRepository'; -import { camelCase, isPlainObject, mapKeys, mapValues } from 'lodash'; +import { + camelCase, + snakeCase, + Dictionary, + isPlainObject, + mapKeys, + mapValues, +} from 'lodash'; import { AskResultStatus, WrenAIError } from '../adaptors/wrenAIAdaptor'; +import { ExplainType } from './threadResponseExplainRepository'; +export interface SQLLocation { + line: number; + column: number; +} +export interface ThreadResponseReference { + id: number; + type: ExplainType; + sqlSnippet: string; + sqlLocation: any; +} export interface DetailStep { summary: string; sql: string; cteName: string; + references?: ThreadResponseReference[]; } export interface ThreadResponseDetail { @@ -19,6 +38,12 @@ export interface ThreadResponseDetail { steps: Array; } +export interface PrevCorrection { + id: number; + type: string; + correction: string; +} + export interface ThreadResponse { id: number; // ID threadId: number; // Reference to thread.id @@ -28,6 +53,7 @@ export interface ThreadResponse { status: string; // Thread response status detail: ThreadResponseDetail; // Thread response detail error: object; // Thread response error + corrections: PrevCorrection[]; // Previous thread response corrections } export interface ThreadResponseWithThreadContext extends ThreadResponse { @@ -63,27 +89,9 @@ export class ThreadResponseRepository query.orderBy('created_at', 'desc').limit(limit); } - return (await query) - .map((res) => { - // turn object keys into camelCase - return mapKeys(res, (_, key) => camelCase(key)); - }) - .map((res) => { - // JSON.parse detail and error - const detail = - res.detail && typeof res.detail === 'string' - ? JSON.parse(res.detail) - : res.detail; - const error = - res.error && typeof res.error === 'string' - ? JSON.parse(res.error) - : res.error; - return { - ...res, - detail: detail || null, - error: error || null, - }; - }) as ThreadResponseWithThreadContext[]; + return (await query).map((res) => + this.transformFromDBData(res), + ) as ThreadResponseWithThreadContext[]; } public async updateOne( @@ -108,13 +116,31 @@ export class ThreadResponseRepository return this.transformFromDBData(result); } + protected override transformToDBData = (data: any) => { + if (!isPlainObject(data)) { + throw new Error('Unexpected dbdata'); + } + const formattedData = mapValues(data, (value, key) => { + if (['error', 'detail', 'corrections'].includes(key)) { + // The value from Sqlite will be string type, while the value from PG is JSON object + if (value) { + return typeof value === 'string' ? value : JSON.stringify(value); + } else { + return value; + } + } + return value; + }) as Dictionary; + return mapKeys(formattedData, (_value, key) => snakeCase(key)); + }; + protected override transformFromDBData = (data: any): ThreadResponse => { if (!isPlainObject(data)) { throw new Error('Unexpected dbdata'); } const camelCaseData = mapKeys(data, (_value, key) => camelCase(key)); const formattedData = mapValues(camelCaseData, (value, key) => { - if (['error', 'detail'].includes(key)) { + if (['error', 'detail', 'corrections'].includes(key)) { // The value from Sqlite will be string type, while the value from PG is JSON object if (typeof value === 'string') { return value ? JSON.parse(value) : value; diff --git a/wren-ui/src/apollo/server/resolvers.ts b/wren-ui/src/apollo/server/resolvers.ts index b5b20d278..e92844288 100644 --- a/wren-ui/src/apollo/server/resolvers.ts +++ b/wren-ui/src/apollo/server/resolvers.ts @@ -78,8 +78,12 @@ const resolvers = { updateThread: askingResolver.updateThread, deleteThread: askingResolver.deleteThread, createThreadResponse: askingResolver.createThreadResponse, + createCorrectedThreadResponse: askingResolver.createCorrectedThreadResponse, previewData: askingResolver.previewData, + // Explain + createThreadResponseExplain: askingResolver.createThreadResponseExplain, + // Views createView: modelResolver.createView, deleteView: modelResolver.deleteView, diff --git a/wren-ui/src/apollo/server/resolvers/askingResolver.ts b/wren-ui/src/apollo/server/resolvers/askingResolver.ts index d76ee928c..b195d9728 100644 --- a/wren-ui/src/apollo/server/resolvers/askingResolver.ts +++ b/wren-ui/src/apollo/server/resolvers/askingResolver.ts @@ -14,6 +14,7 @@ import { SampleDatasetName, getSampleAskQuestions, } from '../data'; +import { Order } from '../repositories'; const logger = getLogger('AskingResolver'); logger.level = 'debug'; @@ -54,8 +55,12 @@ export class AskingResolver { this.deleteThread = this.deleteThread.bind(this); this.listThreads = this.listThreads.bind(this); this.createThreadResponse = this.createThreadResponse.bind(this); + this.createCorrectedThreadResponse = + this.createCorrectedThreadResponse.bind(this); this.getResponse = this.getResponse.bind(this); this.getSuggestedQuestions = this.getSuggestedQuestions.bind(this); + this.createThreadResponseExplain = + this.createThreadResponseExplain.bind(this); } public async getSuggestedQuestions( @@ -164,6 +169,7 @@ export class AskingResolver { const askingService = ctx.askingService; const responses = await askingService.getResponsesWithThread(threadId); + const explains = await askingService.getExplainDetailsByThread(threadId); // reduce responses to group by thread id const thread = reduce( responses, @@ -174,7 +180,9 @@ export class AskingResolver { acc.summary = response.threadSummary; acc.responses = []; } - + const explain = explains.find( + (e) => e.threadResponseId === response.id, + ); acc.responses.push({ id: response.id, question: response.question, @@ -188,6 +196,12 @@ export class AskingResolver { status: response.status, detail: response.detail, error: response.error, + corrections: response.corrections, + explain: { + queryId: explain?.queryId || null, + status: explain?.status || null, + error: explain?.error || null, + }, }); return acc; @@ -257,14 +271,58 @@ export class AskingResolver { return response; } + public async createCorrectedThreadResponse( + _root: any, + args: { + threadId: number; + data: { + responseId: number; + corrections: { + id: number; + type: string; + referenceNum: number; + stepIndex: number; + reference: string; + correction: string; + }[]; + }; + }, + ctx: IContext, + ): Promise { + const { threadId, data } = args; + + const askingService = ctx.askingService; + const response = await askingService.createCorrectedThreadResponse( + threadId, + data, + ); + ctx.telemetry.send_event('regenerate_asked_question', {}); + return response; + } + public async getResponse( _root: any, args: { responseId: number }, ctx: IContext, - ): Promise { + ): Promise< + ThreadResponse & { + explain: { + queryId: string | null; + status: string | null; + error: object | null; + }; + } + > { const { responseId } = args; const askingService = ctx.askingService; const response = await askingService.getResponse(responseId); + const explain = await ctx.threadResponseExplainRepository.findAllBy( + { + threadResponseId: responseId, + }, + { orderBy: [{ column: 'created_at', order: Order.DESC }], limit: 1 }, + ); + const hasExplain = !!explain.length; // we added summary in version 0.3.0. // if summary is not available, we use description and question instead. @@ -272,6 +330,11 @@ export class AskingResolver { ...response, summary: response.summary || response.detail?.description || response.question, + explain: { + queryId: hasExplain ? explain[0].queryId : null, + status: hasExplain ? explain[0].status : null, + error: hasExplain ? explain[0].error : null, + }, }; } @@ -286,6 +349,16 @@ export class AskingResolver { return data; } + public async createThreadResponseExplain( + _root: any, + args: { where: { responseId: number } }, + ctx: IContext, + ) { + return await ctx.askingService.createThreadResponseExplain( + args.where.responseId, + ); + } + /** * Nested resolvers */ diff --git a/wren-ui/src/apollo/server/schema.ts b/wren-ui/src/apollo/server/schema.ts index 193f2c777..a9a139d18 100644 --- a/wren-ui/src/apollo/server/schema.ts +++ b/wren-ui/src/apollo/server/schema.ts @@ -486,11 +486,26 @@ export const typeDefs = gql` STOPPED } + enum ExplainTaskStatus { + UNDERSTANDING + GENERATING + FINISHED + FAILED + } + enum ResultCandidateType { VIEW # View type candidate is provided basd on a saved view LLM # LLM type candidate is created by LLM } + enum ReferenceType { + FIELD + QUERY_FROM + FILTER + SORTING + GROUP_BY + } + type ResultCandidate { type: ResultCandidateType! sql: String! @@ -519,6 +534,20 @@ export const typeDefs = gql` viewId: Int } + input CreateThreadResponseCorrectionInput { + id: Int! + referenceNum: Int! + stepIndex: Int! + type: ReferenceType! + reference: String! + correction: String! + } + + input CreateCorrectedThreadResponseInput { + responseId: Int! + corrections: [CreateThreadResponseCorrectionInput!]! + } + input ThreadUniqueWhereInput { id: Int! } @@ -536,10 +565,30 @@ export const typeDefs = gql` limit: Int } + type ReferenceSQLLocation { + line: Int! + column: Int! + } + + type DetailReference { + referenceId: Int + type: ReferenceType! + sqlSnippet: String + summary: String! + sqlLocation: ReferenceSQLLocation + } + type DetailStep { summary: String! sql: String! cteName: String + references: [DetailReference] + } + + type ThreadResponseExplainInfo { + queryId: String + status: ExplainTaskStatus + error: JSON } type ThreadResponseDetail { @@ -549,6 +598,13 @@ export const typeDefs = gql` steps: [DetailStep!]! } + type CorrectionDetail { + id: Int! + type: ReferenceType! + referenceNum: Int! + correction: String! + } + type ThreadResponse { id: Int! question: String! @@ -556,6 +612,8 @@ export const typeDefs = gql` status: AskingTaskStatus! detail: ThreadResponseDetail error: Error + corrections: [CorrectionDetail!] + explain: ThreadResponseExplainInfo } # Thread only consists of basic information of a thread @@ -656,6 +714,10 @@ export const typeDefs = gql` type: SchemaChangeType! } + input CreateThreadResponseExplainWhereInput { + responseId: Int! + } + # Query and Mutation type Query { # On Boarding Steps @@ -757,8 +819,17 @@ export const typeDefs = gql` threadId: Int! data: CreateThreadResponseInput! ): ThreadResponse! + createCorrectedThreadResponse( + threadId: Int! + data: CreateCorrectedThreadResponseInput! + ): ThreadResponse! previewData(where: PreviewDataInput!): JSON! + # Explain + createThreadResponseExplain( + where: CreateThreadResponseExplainWhereInput! + ): JSON! + # Settings resetCurrentProject: Boolean! updateDataSource(data: UpdateDataSourceInput!): DataSource! diff --git a/wren-ui/src/apollo/server/services/askingService.ts b/wren-ui/src/apollo/server/services/askingService.ts index 79d944bdc..39fc723b6 100644 --- a/wren-ui/src/apollo/server/services/askingService.ts +++ b/wren-ui/src/apollo/server/services/askingService.ts @@ -3,21 +3,32 @@ import { IWrenAIAdaptor, AskResultStatus, AskHistory, + ExpressionType, + ExplanationType, + ExplainPipelineStatus, + StepAnalysisResult, } from '@server/adaptors/wrenAIAdaptor'; import { IDeployService } from './deployService'; import { IProjectService } from './projectService'; import { IThreadRepository, Thread } from '../repositories/threadRepository'; +import { + IThreadResponseExplainRepository, + ThreadResponseExplain, +} from '../repositories/threadResponseExplainRepository'; import { IThreadResponseRepository, ThreadResponse, ThreadResponseWithThreadContext, } from '../repositories/threadResponseRepository'; -import { getLogger } from '@server/utils'; -import { isEmpty, isNil } from 'lodash'; +import { groupBy, isEmpty, isNil } from 'lodash'; +import { addAutoIncrementId, getLogger } from '@server/utils'; import { format } from 'sql-formatter'; import { Telemetry } from '../telemetry/telemetry'; import { IViewRepository, View } from '../repositories'; import { IQueryService, PreviewDataResponse } from './queryService'; +import { ThreadResponseBackgroundTracker } from '../backgroundTrackers/threadResponseBackgroundTracker'; +import { ThreadResponseExplainBackgroundTracker } from '../backgroundTrackers/explainBackgroundTracker'; +import { IIbisAdaptor } from '../adaptors/ibisAdaptor'; const logger = getLogger('AskingService'); logger.level = 'debug'; @@ -40,6 +51,20 @@ export interface AskingDetailTaskInput { viewId?: number; } +export interface CorrectionInput { + id: number; + type: string; + referenceNum: number; + stepIndex: number; + reference: string; + correction: string; +} + +export interface CorrectedDetailTaskInput { + responseId: number; + corrections: CorrectionInput[]; +} + export interface IAskingService { /** * Asking task. @@ -62,9 +87,14 @@ export interface IAskingService { threadId: number, input: AskingDetailTaskInput, ): Promise; + createCorrectedThreadResponse( + threadId: number, + input: CorrectedDetailTaskInput, + ): Promise; getResponsesWithThread( threadId: number, ): Promise; + getExplainDetailsByThread(threadId: number): Promise; getResponse(responseId: number): Promise; previewData( responseId: number, @@ -72,19 +102,11 @@ export interface IAskingService { limit?: number, ): Promise; deleteAllByProjectId(projectId: number): Promise; + createThreadResponseExplain( + threadResponseId: number, + ): Promise; } -/** - * utility function to check if the status is finalized - */ -const isFinalized = (status: AskResultStatus) => { - return ( - status === AskResultStatus.FAILED || - status === AskResultStatus.FINISHED || - status === AskResultStatus.STOPPED - ); -}; - /** * Given a list of steps, construct the SQL statement with CTEs * If stepIndex is provided, only construct the SQL from top to that step @@ -132,165 +154,125 @@ export const constructCteSql = ( return sql; }; -/** - * Background tracker to track the status of the asking detail task - */ -class BackgroundTracker { - // tasks is a kv pair of task id and thread response - private tasks: Record = {}; - private intervalTime: number; - private wrenAIAdaptor: IWrenAIAdaptor; - private threadResponseRepository: IThreadResponseRepository; - private runningJobs = new Set(); - private telemetry: Telemetry; - - constructor({ - telemetry, - wrenAIAdaptor, - threadResponseRepository, - }: { - telemetry: Telemetry; - wrenAIAdaptor: IWrenAIAdaptor; - threadResponseRepository: IThreadResponseRepository; - }) { - this.telemetry = telemetry; - this.wrenAIAdaptor = wrenAIAdaptor; - this.threadResponseRepository = threadResponseRepository; - this.intervalTime = 1000; - this.start(); - } - - public start() { - logger.info('Background tracker started'); - setInterval(() => { - const jobs = Object.values(this.tasks).map( - (threadResponse) => async () => { - // check if same job is running - if (this.runningJobs.has(threadResponse.id)) { - return; - } - - // mark the job as running - this.runningJobs.add(threadResponse.id); - - // get the latest result from AI service - const result = await this.wrenAIAdaptor.getAskDetailResult( - threadResponse.queryId, - ); - - // check if status change - if (threadResponse.status === result.status) { - // mark the job as finished - logger.debug( - `Job ${threadResponse.id} status not changed, finished`, - ); - this.runningJobs.delete(threadResponse.id); - return; - } - - // update database - logger.debug(`Job ${threadResponse.id} status changed, updating`); - await this.threadResponseRepository.updateOne(threadResponse.id, { - status: result.status, - detail: result.response, - error: result.error, - }); - - // remove the task from tracker if it is finalized - if (isFinalized(result.status)) { - this.telemetry.send_event('question_answered', { - question: threadResponse.question, - result, - }); - logger.debug(`Job ${threadResponse.id} is finalized, removing`); - delete this.tasks[threadResponse.id]; - } - - // mark the job as finished - this.runningJobs.delete(threadResponse.id); - }, - ); - - // run the jobs - Promise.allSettled(jobs.map((job) => job())).then((results) => { - // show reason of rejection - results.forEach((result, index) => { - if (result.status === 'rejected') { - logger.error(`Job ${index} failed: ${result.reason}`); - } - }); - }); - }, this.intervalTime); - } - - public addTask(threadResponse: ThreadResponse) { - this.tasks[threadResponse.id] = threadResponse; - } - - public getTasks() { - return this.tasks; - } -} - export class AskingService implements IAskingService { private wrenAIAdaptor: IWrenAIAdaptor; + private ibisAdaptor: IIbisAdaptor; private deployService: IDeployService; private projectService: IProjectService; private viewRepository: IViewRepository; private threadRepository: IThreadRepository; private threadResponseRepository: IThreadResponseRepository; - private backgroundTracker: BackgroundTracker; + private backgroundTracker: ThreadResponseBackgroundTracker; + private regeneratedBackgroundTracker: ThreadResponseBackgroundTracker; + private threadResponseExplainRepository: IThreadResponseExplainRepository; + private threadResponseBackgroundTracker: ThreadResponseBackgroundTracker; + private explainBackgroundTracker: ThreadResponseExplainBackgroundTracker; private queryService: IQueryService; private telemetry: Telemetry; constructor({ telemetry, wrenAIAdaptor, + ibisAdaptor, deployService, projectService, viewRepository, threadRepository, threadResponseRepository, + threadResponseExplainRepository, queryService, }: { telemetry: Telemetry; wrenAIAdaptor: IWrenAIAdaptor; + ibisAdaptor: IIbisAdaptor; deployService: IDeployService; projectService: IProjectService; viewRepository: IViewRepository; threadRepository: IThreadRepository; threadResponseRepository: IThreadResponseRepository; + threadResponseExplainRepository: IThreadResponseExplainRepository; queryService: IQueryService; }) { this.wrenAIAdaptor = wrenAIAdaptor; + this.ibisAdaptor = ibisAdaptor; this.deployService = deployService; this.projectService = projectService; this.viewRepository = viewRepository; this.threadRepository = threadRepository; this.threadResponseRepository = threadResponseRepository; + this.threadResponseExplainRepository = threadResponseExplainRepository; this.telemetry = telemetry; this.queryService = queryService; - this.backgroundTracker = new BackgroundTracker({ + this.threadResponseBackgroundTracker = new ThreadResponseBackgroundTracker({ + telemetry, + wrenAIAdaptor, + threadResponseRepository, + }); + this.regeneratedBackgroundTracker = new ThreadResponseBackgroundTracker({ + telemetry, + wrenAIAdaptor, + threadResponseRepository, + isRegenerated: true, + }); + this.explainBackgroundTracker = new ThreadResponseExplainBackgroundTracker({ telemetry, wrenAIAdaptor, threadResponseRepository, + threadResponseExplainRepository, }); } + public async getExplainDetailsByThread( + threadId: number, + ): Promise { + return await this.threadResponseExplainRepository.findAllByThread(threadId); + } public async initialize() { - // list thread responses from database - // filter status not finalized and put them into background tracker - const threadResponses = await this.threadResponseRepository.findAll(); - const unfininshedThreadResponses = threadResponses.filter( - (threadResponse) => - !isFinalized(threadResponse.status as AskResultStatus), - ); - logger.info( - `Initialization: adding unfininshed thread responses (total: ${unfininshedThreadResponses.length}) to background tracker`, - ); - for (const threadResponse of unfininshedThreadResponses) { - this.backgroundTracker.addTask(threadResponse); - } + const initializeThreadResponseBT = async () => { + // list thread responses from database + // filter status not finalized and put them into background tracker + const threadResponses = await this.threadResponseRepository.findAll(); + const unfinishedThreadResponses = threadResponses.filter( + (threadResponse) => + !this.threadResponseBackgroundTracker.isFinalized( + threadResponse.status as AskResultStatus, + ), + ); + logger.info( + `Initialization: adding unfinished thread responses (total: ${unfinishedThreadResponses.length}) to background tracker`, + ); + for (const threadResponse of unfinishedThreadResponses) { + if (threadResponse.corrections !== null) { + this.regeneratedBackgroundTracker.addTask(threadResponse); + continue; + } + this.threadResponseBackgroundTracker.addTask(threadResponse); + } + }; + + const initializeThreadResponseExplainBT = async () => { + // list thread responses from database + // filter status not finalized and put them into background tracker + const threadResponseExplains = + await this.threadResponseExplainRepository.findAll(); + const unfinishedThreadResponseExplains = threadResponseExplains.filter( + (threadResponseExplain) => + !this.explainBackgroundTracker.isFinalized( + threadResponseExplain.status as ExplainPipelineStatus, + ), + ); + logger.info( + `Initialization: adding unfinished explain job (total: ${unfinishedThreadResponseExplains.length}) to background tracker`, + ); + for (const threadResponseExplain of unfinishedThreadResponseExplains) { + this.explainBackgroundTracker.addTask(threadResponseExplain); + } + }; + + await Promise.all([ + initializeThreadResponseBT(), + initializeThreadResponseExplainBT(), + ]); } /** @@ -364,7 +346,7 @@ export class AskingService implements IAskingService { }); // 3. put the task into background tracker - this.backgroundTracker.addTask(threadResponse); + this.threadResponseBackgroundTracker.addTask(threadResponse); // return the task id return thread; @@ -435,12 +417,111 @@ export class AskingService implements IAskingService { }); // 3. put the task into background tracker - this.backgroundTracker.addTask(threadResponse); + this.threadResponseBackgroundTracker.addTask(threadResponse); // return the task id return threadResponse; } + public async createCorrectedThreadResponse( + threadId: number, + input: CorrectedDetailTaskInput, + ): Promise { + const thread = await this.threadRepository.findOneBy({ + id: threadId, + }); + + const baseThreadResponse = await this.threadResponseRepository.findOneBy({ + id: input.responseId, + }); + + if (!baseThreadResponse) { + throw new Error(`Thread response ${input.responseId} not found`); + } + + const correctionsMap = groupBy(input.corrections, 'stepIndex'); + const response = await this.wrenAIAdaptor.regenerateAskDetail({ + description: baseThreadResponse.detail.description, + steps: baseThreadResponse.detail.steps.map((step, index) => ({ + sql: step.sql, + summary: step.summary, + cte_name: step.cteName, + corrections: (correctionsMap[index] || []).map((item) => ({ + before: { + type: ExplanationType[item.type], + value: item.reference, + }, + after: { + // Only NL_EXPRESSION is supported for now + type: ExpressionType.NL_EXPRESSION, + value: item.correction, + }, + })), + })), + }); + + const threadResponse = await this.threadResponseRepository.createOne({ + threadId: thread.id, + queryId: response.queryId, + question: baseThreadResponse.question, + summary: baseThreadResponse.summary, + status: AskResultStatus.UNDERSTANDING, + corrections: input.corrections.map((item) => ({ + id: item.id, + type: item.type, + referenceNum: item.referenceNum, + correction: item.correction, + })), + }); + + // 3. put the task into background tracker + this.regeneratedBackgroundTracker.addTask(threadResponse); + + // return the task id + return threadResponse; + } + + public async createThreadResponseExplain(threadResponseId: number) { + const threadResponse = await this.threadResponseRepository.findOneBy({ + id: threadResponseId, + }); + if (!threadResponse || threadResponse.status != AskResultStatus.FINISHED) { + throw new Error( + `Can not create explain job for threadResponseId: ${threadResponseId} `, + ); + } + logger.debug('Getting thread response analysis'); + const analysisWithIds = + await this.getThreadResponseAnalysis(threadResponse); + + // compose analysis result with step for explain + const question = threadResponse.question; + const stepAnalysisResult = Object.entries(threadResponse.detail.steps).map( + ([idx, step]) => { + return { + sql: step.sql, + summary: step.summary, + cte_name: step.cteName, + sql_analysis_results: analysisWithIds[idx], + } as StepAnalysisResult; + }, + ); + + // create explain job + const { queryId } = await this.wrenAIAdaptor.explain( + question, + stepAnalysisResult, + ); + // create explain record and add to background tracker + const explain = await this.threadResponseExplainRepository.createOne({ + threadResponseId: threadResponseId, + queryId, + analysis: analysisWithIds, + }); + this.explainBackgroundTracker.addTask(explain); + return explain; + } + public async getResponsesWithThread(threadId: number) { return this.threadResponseRepository.getResponsesWithThread(threadId); } @@ -545,4 +626,18 @@ export class AskingService implements IAskingService { }, }); } + + private async getThreadResponseAnalysis(threadResponse: ThreadResponse) { + const project = await this.projectService.getCurrentProject(); + const deployment = await this.deployService.getLastDeployment(project.id); + const manifest = deployment.manifest; + const sqls = threadResponse.detail?.steps?.map((step, index) => { + const isLastStep = index === threadResponse.detail.steps.length - 1; + return isLastStep + ? format(constructCteSql(threadResponse.detail.steps)) + : format(step.sql); + }); + const analysis = await this.ibisAdaptor.analysisSqls(manifest, sqls); + return addAutoIncrementId(analysis); + } } diff --git a/wren-ui/src/apollo/server/types/context.ts b/wren-ui/src/apollo/server/types/context.ts index 8e0118701..864b9381e 100644 --- a/wren-ui/src/apollo/server/types/context.ts +++ b/wren-ui/src/apollo/server/types/context.ts @@ -18,6 +18,7 @@ import { IMDLService, IProjectService, } from '../services'; +import { IThreadResponseExplainRepository } from '../repositories/threadResponseExplainRepository'; export interface IContext { config: IConfig; @@ -44,4 +45,5 @@ export interface IContext { viewRepository: IViewRepository; deployRepository: IDeployLogRepository; schemaChangeRepository: ISchemaChangeRepository; + threadResponseExplainRepository: IThreadResponseExplainRepository; } diff --git a/wren-ui/src/apollo/server/utils/error.ts b/wren-ui/src/apollo/server/utils/error.ts index ad5c0e42b..585a0d440 100644 --- a/wren-ui/src/apollo/server/utils/error.ts +++ b/wren-ui/src/apollo/server/utils/error.ts @@ -8,6 +8,10 @@ export enum GeneralErrorCodes { NO_RELEVANT_DATA = 'NO_RELEVANT_DATA', NO_RELEVANT_SQL = 'NO_RELEVANT_SQL', + // Background tracker errors + MERGE_THREAD_RESPONSE_ERROR = 'MERGE_THREAD_RESPONSE_ERROR', + OLD_VERSION = 'OLD_VERSION', + // Exception error for AI service (e.g., network connection error) AI_SERVICE_UNDEFINED_ERROR = 'OTHERS', @@ -45,6 +49,12 @@ export const errorMessages = { [GeneralErrorCodes.NO_RELEVANT_SQL]: 'No relevant SQL found for the given query. Please check your query and try again.', + // Background tracker errors + [GeneralErrorCodes.MERGE_THREAD_RESPONSE_ERROR]: + 'Error occurred while merging thread response', + [GeneralErrorCodes.OLD_VERSION]: + 'Question asked before v0.8.0. Click "Retry" to generate manually', + // Connector errors [GeneralErrorCodes.CONNECTION_ERROR]: 'Can not connect to data source', // duckdb @@ -79,6 +89,9 @@ export const shortMessages = { [GeneralErrorCodes.MISLEADING_QUERY]: 'Misleading query', [GeneralErrorCodes.NO_RELEVANT_DATA]: 'No relevant data', [GeneralErrorCodes.NO_RELEVANT_SQL]: 'No relevant SQL', + [GeneralErrorCodes.MERGE_THREAD_RESPONSE_ERROR]: + 'Merge thread response error', + [GeneralErrorCodes.OLD_VERSION]: 'Unable to show references', [GeneralErrorCodes.CONNECTION_ERROR]: 'Failed to connect', [GeneralErrorCodes.IBIS_SERVER_ERROR]: 'Ibis server error', [GeneralErrorCodes.INIT_SQL_ERROR]: 'Invalid initializing SQL', diff --git a/wren-ui/src/apollo/server/utils/helper.ts b/wren-ui/src/apollo/server/utils/helper.ts index 4cb6bc2d0..6d790c70d 100644 --- a/wren-ui/src/apollo/server/utils/helper.ts +++ b/wren-ui/src/apollo/server/utils/helper.ts @@ -1,3 +1,5 @@ +import { invert } from 'lodash'; + /** * @function * @description Retrieve json without error @@ -9,3 +11,13 @@ export const parseJson = (data) => { return false; } }; + +export const reverseEnum = >( + enumObject: E, +) => + invert(enumObject) as { + [V in E[keyof E]]: Extract< + { [K in keyof E]: [K, E[K]] }[keyof E], + [any, V] + >[0]; + }; diff --git a/wren-ui/src/apollo/server/utils/index.ts b/wren-ui/src/apollo/server/utils/index.ts index 8db524806..eda5f34a6 100644 --- a/wren-ui/src/apollo/server/utils/index.ts +++ b/wren-ui/src/apollo/server/utils/index.ts @@ -6,3 +6,4 @@ export * from './docker'; export * from './model'; export * from './helper'; export * from './regex'; +export * from './services'; diff --git a/wren-ui/src/apollo/server/utils/services.ts b/wren-ui/src/apollo/server/utils/services.ts new file mode 100644 index 000000000..eb9555afc --- /dev/null +++ b/wren-ui/src/apollo/server/utils/services.ts @@ -0,0 +1,61 @@ +export function addAutoIncrementId(query: any, startId = 1): any { + // Add an auto-incrementing id to each object in the query + const addId = (query) => { + if (typeof query !== 'object' || query === null) { + return query; + } + + if (Array.isArray(query)) { + const newArr = []; + for (let ele of query) { + ele = addId(ele); + newArr.push(ele); + } + return newArr; + } + + const newObj = { ...query, id: id++ }; + + for (const key in newObj) { + if (newObj.hasOwnProperty(key) && typeof newObj[key] === 'object') { + if (key === 'properties' || key === 'nodeLocation') { + continue; + } + newObj[key] = addId(newObj[key]); + } + } + return newObj; + }; + let id = startId; + return addId(query); +} + +export function findAnalysisById(analysis: any, id: number) { + if ( + analysis && + analysis.hasOwnProperty('id') && + Number(analysis.id) === Number(id) + ) { + return analysis; + } + + if (Array.isArray(analysis)) { + for (const ele of analysis) { + const result = findAnalysisById(ele, id); + if (result) { + return result; + } + } + } + + for (const key in analysis) { + if (analysis.hasOwnProperty(key) && typeof analysis[key] == 'object') { + const result = findAnalysisById(analysis[key], id); + if (result) { + return result; + } + } + } + + return null; +} diff --git a/wren-ui/src/apollo/server/utils/tests/services.test.ts b/wren-ui/src/apollo/server/utils/tests/services.test.ts new file mode 100644 index 000000000..4e12c48b1 --- /dev/null +++ b/wren-ui/src/apollo/server/utils/tests/services.test.ts @@ -0,0 +1,254 @@ +import { addAutoIncrementId, findAnalysisById } from '../services'; + +describe('addAutoIncrementId', () => { + it('should add auto-incrementing ids to an array of objects', () => { + const query = [{ name: 'Item 1' }, { name: 'Item 2' }, { name: 'Item 3' }]; + + const result = addAutoIncrementId(query); + + expect(result).toEqual([ + { id: 1, name: 'Item 1' }, + { id: 2, name: 'Item 2' }, + { id: 3, name: 'Item 3' }, + ]); + }); + + it('should add auto-incrementing ids to nested objects', () => { + const query = { + name: 'Parent', + child: { + name: 'Child', + grandchild: { + name: 'Grandchild', + }, + }, + }; + + const result = addAutoIncrementId(query); + + expect(result).toEqual({ + id: 1, + name: 'Parent', + child: { + id: 2, + name: 'Child', + grandchild: { + id: 3, + name: 'Grandchild', + }, + }, + }); + }); + + it('should not modify the original query', () => { + const query = { name: 'Item' }; + + addAutoIncrementId(query); + + expect(query).toEqual({ name: 'Item' }); + }); + + it('should get expected result with query analysis result', () => { + const analysis = [ + { + filter: { + left: { + node: '(custkey = 1)', + type: 'EXPR', + }, + right: { + node: "(name = 'tom')", + type: 'EXPR', + }, + type: 'AND', + }, + groupByKeys: [['c.name']], + relation: { + criteria: 'ON (c.custkey = o.custkey)', + left: { + alias: 'c', + tableName: 'Customer', + type: 'TABLE', + }, + right: { + alias: 'o', + tableName: 'Orders', + type: 'TABLE', + }, + type: 'INNER_JOIN', + }, + selectItems: [ + { + alias: null, + expression: 'c.name', + properties: { + includeFunctionCall: 'false', + includeMathematicalOperation: 'false', + }, + }, + { + alias: null, + expression: 'count(*)', + properties: { + includeFunctionCall: 'true', + includeMathematicalOperation: 'false', + }, + }, + ], + sortings: [ + { + expression: 'c.name', + ordering: 'DESCENDING', + }, + ], + }, + ]; + const expected = [ + { + id: 1, + filter: { + id: 2, + left: { + id: 3, + node: '(custkey = 1)', + type: 'EXPR', + }, + right: { + id: 4, + node: "(name = 'tom')", + type: 'EXPR', + }, + type: 'AND', + }, + groupByKeys: [['c.name']], + relation: { + id: 5, + criteria: 'ON (c.custkey = o.custkey)', + left: { + id: 6, + alias: 'c', + tableName: 'Customer', + type: 'TABLE', + }, + right: { + id: 7, + alias: 'o', + tableName: 'Orders', + type: 'TABLE', + }, + type: 'INNER_JOIN', + }, + selectItems: [ + { + id: 8, + alias: null, + expression: 'c.name', + properties: { + includeFunctionCall: 'false', + includeMathematicalOperation: 'false', + }, + }, + { + id: 9, + alias: null, + expression: 'count(*)', + properties: { + includeFunctionCall: 'true', + includeMathematicalOperation: 'false', + }, + }, + ], + sortings: [ + { + id: 10, + expression: 'c.name', + ordering: 'DESCENDING', + }, + ], + }, + ]; + + const result = addAutoIncrementId(analysis); + expect(result).toEqual(expected); + }); + + it('should get expected result findAnalysisById', () => { + const analysisWithIds = [ + { + id: 1, + filter: { + id: 2, + left: { + id: 3, + node: '(custkey = 1)', + type: 'EXPR', + }, + right: { + id: 4, + node: "(name = 'tom')", + type: 'EXPR', + }, + type: 'AND', + }, + groupByKeys: [['c.name']], + relation: { + id: 5, + criteria: 'ON (c.custkey = o.custkey)', + left: { + id: 6, + alias: 'c', + tableName: 'Customer', + type: 'TABLE', + }, + right: { + id: 7, + alias: 'o', + tableName: 'Orders', + type: 'TABLE', + }, + type: 'INNER_JOIN', + }, + selectItems: [ + { + id: 8, + alias: null, + expression: 'c.name', + properties: { + includeFunctionCall: 'false', + includeMathematicalOperation: 'false', + }, + }, + { + id: 9, + alias: null, + expression: 'count(*)', + properties: { + includeFunctionCall: 'true', + includeMathematicalOperation: 'false', + }, + }, + ], + sortings: [ + { + id: 10, + expression: 'c.name', + ordering: 'DESCENDING', + }, + ], + }, + ]; + expect(findAnalysisById(analysisWithIds, 1)).toEqual(analysisWithIds[0]); + expect(findAnalysisById(analysisWithIds, 2)).toEqual( + analysisWithIds[0].filter, + ); + expect(findAnalysisById(analysisWithIds, 4)).toEqual( + analysisWithIds[0].filter.right, + ); + expect(findAnalysisById(analysisWithIds, 9)).toEqual( + analysisWithIds[0].selectItems[1], + ); + expect(findAnalysisById(analysisWithIds, 10)).toEqual( + analysisWithIds[0].sortings[0], + ); + }); +}); diff --git a/wren-ui/src/components/editor/CodeBlock.tsx b/wren-ui/src/components/editor/CodeBlock.tsx index 169fa0e2c..77bd588a4 100644 --- a/wren-ui/src/components/editor/CodeBlock.tsx +++ b/wren-ui/src/components/editor/CodeBlock.tsx @@ -1,4 +1,4 @@ -import { useEffect } from 'react'; +import React, { useEffect } from 'react'; import { Typography } from 'antd'; import styled from 'styled-components'; import '@/components/editor/AceEditor'; @@ -19,16 +19,19 @@ const Block = styled.div<{ inline?: boolean; maxHeight?: string }>` : `background: var(--gray-1); padding: 8px;`} .adm-code-wrap { + position: relative; + padding-bottom: 2px; ${(props) => (props.inline ? '' : 'overflow: auto;')} ${(props) => (props.maxHeight ? `max-height: ${props.maxHeight}px;` : ``)} } .adm-code-line { display: block; + height: 22px; &-number { user-select: none; display: inline-block; - min-width: 14px; + min-width: 17px; text-align: right; margin-right: 1em; color: var(--gray-6); @@ -55,6 +58,7 @@ interface Props { loading?: boolean; maxHeight?: string; showLineNumbers?: boolean; + highlightSlot?: React.ReactNode; } const addThemeStyleManually = (cssText) => { @@ -69,21 +73,37 @@ const addThemeStyleManually = (cssText) => { } }; -export default function CodeBlock(props: Props) { - const { code, copyable, maxHeight, inline, loading, showLineNumbers } = props; +export const getTokenizer = () => { const { ace } = window as any; const { Tokenizer } = ace.require('ace/tokenizer'); const { SqlHighlightRules } = ace.require(`ace/mode/sql_highlight_rules`); const rules = new SqlHighlightRules(); const tokenizer = new Tokenizer(rules.getRules()); + return (line) => { + return tokenizer.getLineTokens(line).tokens; + }; +}; + +export default function CodeBlock(props: Props) { + const { + code, + copyable, + maxHeight, + inline, + loading, + showLineNumbers, + highlightSlot, + } = props; useEffect(() => { + const { ace } = window as any; const { cssText } = ace.require('ace/theme/tomorrow'); addThemeStyleManually(cssText); }, []); + const tokenize = getTokenizer(); const lines = (code || '').split('\n').map((line, index) => { - const tokens = tokenizer.getLineTokens(line).tokens; + const tokens = tokenize(line); const children = tokens.map((token, index) => { const classNames = token.type.split('.').map((name) => `ace_${name}`); return ( @@ -92,7 +112,6 @@ export default function CodeBlock(props: Props) { ); }); - return ( {showLineNumbers && ( @@ -112,6 +131,7 @@ export default function CodeBlock(props: Props) {
{lines} + {highlightSlot} {copyable && {code}}
diff --git a/wren-ui/src/components/pages/home/promptThread/AnswerResult.tsx b/wren-ui/src/components/pages/home/promptThread/AnswerResult.tsx deleted file mode 100644 index 4d1910b9b..000000000 --- a/wren-ui/src/components/pages/home/promptThread/AnswerResult.tsx +++ /dev/null @@ -1,148 +0,0 @@ -import { useState } from 'react'; -import Link from 'next/link'; -import { Col, Button, Row, Skeleton, Typography } from 'antd'; -import styled from 'styled-components'; -import { Path } from '@/utils/enum'; -import CheckCircleFilled from '@ant-design/icons/CheckCircleFilled'; -import QuestionCircleOutlined from '@ant-design/icons/QuestionCircleOutlined'; -import SaveOutlined from '@ant-design/icons/SaveOutlined'; -import FileDoneOutlined from '@ant-design/icons/FileDoneOutlined'; -import StepContent from '@/components/pages/home/promptThread/StepContent'; - -const { Title, Text } = Typography; - -const StyledAnswer = styled(Typography)` - position: relative; - border: 1px var(--gray-4) solid; - border-radius: 4px; - - .adm-answer-title { - font-weight: 500; - position: absolute; - top: -13px; - left: 8px; - background: white; - } -`; - -const StyledQuestion = styled(Row)` - padding: 4px 8px; - border-radius: 4px; - color: var(--gray-6); - background-color: var(--gray-3); - margin-bottom: 8px; - font-size: 14px; - - &:hover { - background-color: var(--gray-4) !important; - cursor: pointer; - } -`; - -interface Props { - loading: boolean; - question: string; - description: string; - answerResultSteps: Array<{ - summary: string; - sql: string; - }>; - fullSql: string; - threadResponseId: number; - onOpenSaveAsViewModal: (data: { sql: string; responseId: number }) => void; - isLastThreadResponse: boolean; - onTriggerScrollToBottom: () => void; - summary: string; - view?: { - id: number; - displayName: string; - }; -} - -export default function AnswerResult(props: Props) { - const { - loading, - question, - description, - answerResultSteps, - fullSql, - threadResponseId, - isLastThreadResponse, - onOpenSaveAsViewModal, - onTriggerScrollToBottom, - summary, - view, - } = props; - - const isViewSaved = !!view; - - const [ellipsis, setEllipsis] = useState(true); - - return ( - - setEllipsis(!ellipsis)}> - - - Question: - - - - {question} - - - - - {summary} - - - - - Summary - -
{description}
- {(answerResultSteps || []).map((step, index) => ( - - ))} -
- {isViewSaved ? ( -
- - Generated from saved view{' '} - - {view.displayName} - -
- ) : ( - - )} -
- ); -} diff --git a/wren-ui/src/components/pages/home/promptThread/index.tsx b/wren-ui/src/components/pages/home/promptThread/index.tsx deleted file mode 100644 index d066ab8ed..000000000 --- a/wren-ui/src/components/pages/home/promptThread/index.tsx +++ /dev/null @@ -1,116 +0,0 @@ -import { useEffect, useRef } from 'react'; -import { Alert, Divider } from 'antd'; -import styled from 'styled-components'; -import AnswerResult from './AnswerResult'; -import { makeIterable } from '@/utils/iteration'; -import { - AskingTaskStatus, - DetailedThread, -} from '@/apollo/client/graphql/__types__'; - -interface Props { - data: DetailedThread; - onOpenSaveAsViewModal: (data: { sql: string; responseId: number }) => void; -} - -const StyledPromptThread = styled.div` - width: 768px; - margin-left: auto; - margin-right: auto; - - h4.ant-typography { - margin-top: 10px; - } - - .ant-typography pre { - border: none; - border-radius: 4px; - } - - button { - vertical-align: middle; - } -`; - -const AnswerResultTemplate = ({ - index, - id, - status, - question, - detail, - error, - onOpenSaveAsViewModal, - onTriggerScrollToBottom, - data, - summary, -}) => { - const lastResponseId = data[data.length - 1].id; - const isLastThreadResponse = id === lastResponseId; - - return ( -
- {index > 0 && } - {error ? ( - - ) : ( - - )} -
- ); -}; - -const AnswerResultIterator = makeIterable(AnswerResultTemplate); - -export default function PromptThread(props: Props) { - const { data, onOpenSaveAsViewModal } = props; - const divRef = useRef(null); - - const triggerScrollToBottom = () => { - const contentLayout = divRef.current.parentElement; - const lastChild = divRef.current.lastElementChild as HTMLElement; - const lastChildElement = lastChild.lastElementChild as HTMLElement; - - if ( - contentLayout.clientHeight < - lastChild.offsetTop + lastChild.clientHeight - ) { - contentLayout.scrollTo({ - top: lastChildElement.offsetTop, - behavior: 'smooth', - }); - } - }; - - useEffect(() => { - if (divRef.current && data?.responses.length > 0) { - triggerScrollToBottom(); - } - }, [divRef, data]); - - return ( - - - - ); -} diff --git a/wren-ui/src/components/pages/home/thread/AnswerResult.tsx b/wren-ui/src/components/pages/home/thread/AnswerResult.tsx new file mode 100644 index 000000000..7b3598d00 --- /dev/null +++ b/wren-ui/src/components/pages/home/thread/AnswerResult.tsx @@ -0,0 +1,245 @@ +import { useState } from 'react'; +import Link from 'next/link'; +import { + Col, + Button, + Row, + Skeleton, + Typography, + Divider, + Tag, + Alert, +} from 'antd'; +import styled from 'styled-components'; +import { Path } from '@/utils/enum'; +import CheckCircleFilled from '@ant-design/icons/CheckCircleFilled'; +import ShareAltOutlined from '@ant-design/icons/ShareAltOutlined'; +import QuestionCircleOutlined from '@ant-design/icons/QuestionCircleOutlined'; +import SaveOutlined from '@ant-design/icons/SaveOutlined'; +import FileDoneOutlined from '@ant-design/icons/FileDoneOutlined'; +import StepContent from '@/components/pages/home/thread/StepContent'; +import FeedbackLayout from '@/components/pages/home/thread/feedback'; +import { ThreadResponse } from '@/apollo/client/graphql/__types__'; +import { makeIterable } from '@/utils/iteration'; +import { getReferenceIcon } from '@/components/pages/home/thread/feedback/utils'; +import { getIsAskingFinished } from '@/hooks/useAskPrompt'; + +const { Title, Text } = Typography; + +const Wrapper = styled.div` + width: 100%; + flex-shrink: 0; + flex-grow: 1; +`; + +const StyledAnswer = styled(Typography)` + position: relative; + border: 1px var(--gray-4) solid; + border-radius: 4px; + + .adm-answer-title { + font-weight: 500; + position: absolute; + top: -13px; + left: 8px; + background: white; + } +`; + +const StyledQuestion = styled(Row)` + padding: 4px 8px; + border-radius: 4px; + color: var(--gray-6); + background-color: var(--gray-3); + margin-bottom: 8px; + font-size: 14px; + + &:hover { + background-color: var(--gray-4) !important; + cursor: pointer; + } +`; + +interface Props { + threadResponse: ThreadResponse; + isLastThreadResponse: boolean; + onOpenSaveAsViewModal: (data: { sql: string; responseId: number }) => void; + onTriggerScrollToBottom: () => void; + onSubmitReviewDrawer: (variables: any) => Promise; + onTriggerThreadResponseExplain: (variables: any) => Promise; +} + +const CorrectionTemplate = ({ referenceNum, type, correction }) => { + return ( +
+ + {getReferenceIcon(type)} + {referenceNum} + + {correction} +
+ ); +}; +const CorrectionIterator = makeIterable(CorrectionTemplate); +const RegenerateInformation = (props: ThreadResponse) => { + const { question, corrections } = props; + const [collapse, setCollapse] = useState(false); + const collapseText = collapse ? 'Hide' : 'Show'; + return ( +
+
+ +
Regenerated answer from
+ setCollapse(!collapse)} + > + {collapseText} feedbacks + +
+ + {question} + + {collapse && ( +
+ +
+ +
Feedbacks
+
+ +
+ )} +
+ ); +}; + +const QuestionInformation = (props) => { + const { question } = props; + const [ellipsis, setEllipsis] = useState(true); + return ( + setEllipsis(!ellipsis)}> + + + Question: + + + + {question} + + + + ); +}; + +export default function AnswerResult(props: Props) { + const { + threadResponse, + isLastThreadResponse, + onOpenSaveAsViewModal, + onTriggerScrollToBottom, + onSubmitReviewDrawer, + onTriggerThreadResponseExplain, + } = props; + + const { + id: responseId, + summary, + status, + corrections, + error, + } = threadResponse; + const { + view, + steps, + description, + sql: fullSql, + } = threadResponse?.detail || {}; + + const isViewSaved = !!view; + const isRegenerated = !!corrections; + const loading = !getIsAskingFinished(status); + + const Information = isRegenerated + ? RegenerateInformation + : QuestionInformation; + + return ( +
+ {error ? ( + + ) : ( + + + + + {summary} + + + } + bodySlot={ + + + + + Summary + +
{description}
+ {(steps || []).map((step, index) => ( + + ))} +
+ {isViewSaved ? ( +
+ + Generated from saved view{' '} + + {view.displayName} + +
+ ) : ( + + )} +
+ } + threadResponse={threadResponse} + onSubmitReviewDrawer={onSubmitReviewDrawer} + onTriggerThreadResponseExplain={onTriggerThreadResponseExplain} + /> +
+ )} +
+ ); +} diff --git a/wren-ui/src/components/pages/home/promptThread/CollapseContent.tsx b/wren-ui/src/components/pages/home/thread/CollapseContent.tsx similarity index 81% rename from wren-ui/src/components/pages/home/promptThread/CollapseContent.tsx rename to wren-ui/src/components/pages/home/thread/CollapseContent.tsx index 967ed8d6b..ef1991dc0 100644 --- a/wren-ui/src/components/pages/home/promptThread/CollapseContent.tsx +++ b/wren-ui/src/components/pages/home/thread/CollapseContent.tsx @@ -1,3 +1,4 @@ +import { useMemo, useState } from 'react'; import Image from 'next/image'; import dynamic from 'next/dynamic'; import { Button, Switch, Typography, Empty } from 'antd'; @@ -10,6 +11,8 @@ import PreviewData from '@/components/dataPreview/PreviewData'; import { PreviewDataMutationResult } from '@/apollo/client/graphql/home.generated'; import { DATA_SOURCE_OPTIONS } from '@/components/pages/setup/utils'; import { NativeSQLResult } from '@/hooks/useNativeSQL'; +import SQLHighlight from '@/components/pages/home/thread/feedback/SQLHighlight'; +import { useFeedbackContext } from './feedback'; const CodeBlock = dynamic(() => import('@/components/editor/CodeBlock'), { ssr: false, @@ -36,6 +39,7 @@ export interface Props { onCloseCollapse: () => void; onCopyFullSQL?: () => void; sql: string; + stepIndex: number; previewDataResult: PreviewDataMutationResult; attributes: { stepNumber: number; @@ -53,6 +57,7 @@ export default function CollapseContent(props: Props) { onCloseCollapse, onCopyFullSQL, sql, + stepIndex, previewDataResult, attributes, onChangeNativeSQL, @@ -61,11 +66,27 @@ export default function CollapseContent(props: Props) { const { hasNativeSQL, dataSourceType } = nativeSQLResult; const showNativeSQL = Boolean(attributes.isLastStep) && hasNativeSQL; + const [isNativeSQL, setIsNativeSQL] = useState(false); + const { references, sqlTargetReference, onHighlightToReferences } = + useFeedbackContext(); + const currentStepReferences = useMemo(() => { + return (references || []).filter((item) => item.stepIndex === stepIndex); + }, [references]); const sqls = nativeSQLResult.nativeSQLMode && nativeSQLResult.loading === false ? nativeSQLResult.data : sql; + const hasReferences = references && references.length > 0; + + const onSwitchChange = (checked: boolean) => { + setIsNativeSQL(checked); + onChangeNativeSQL(checked); + }; + + const onHighlightHover = (reference) => { + onHighlightToReferences && onHighlightToReferences(reference); + }; return ( <> @@ -95,7 +116,7 @@ export default function CollapseContent(props: Props) { unCheckedChildren={} className="mr-2" size="small" - onChange={onChangeNativeSQL} + onChange={onSwitchChange} loading={nativeSQLResult.loading} /> @@ -109,6 +130,17 @@ export default function CollapseContent(props: Props) { showLineNumbers maxHeight="300" loading={nativeSQLResult.loading} + highlightSlot={ + hasReferences && + !isNativeSQL && ( + + ) + } /> )} diff --git a/wren-ui/src/components/pages/home/promptThread/StepContent.tsx b/wren-ui/src/components/pages/home/thread/StepContent.tsx similarity index 96% rename from wren-ui/src/components/pages/home/promptThread/StepContent.tsx rename to wren-ui/src/components/pages/home/thread/StepContent.tsx index 905c86f75..6f38394ed 100644 --- a/wren-ui/src/components/pages/home/promptThread/StepContent.tsx +++ b/wren-ui/src/components/pages/home/thread/StepContent.tsx @@ -4,7 +4,7 @@ import FunctionOutlined from '@ant-design/icons/FunctionOutlined'; import { BinocularsIcon } from '@/utils/icons'; import CollapseContent, { Props as CollapseContentProps, -} from '@/components/pages/home/promptThread/CollapseContent'; +} from '@/components/pages/home/thread/CollapseContent'; import useAnswerStepContent from '@/hooks/useAnswerStepContent'; import { nextTick } from '@/utils/time'; @@ -106,8 +106,9 @@ export default function StepContent(props: Props) { )} + stepIndex={stepIndex} key={`collapse-${stepNumber}`} attributes={{ stepNumber, isLastStep }} /> diff --git a/wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx b/wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx new file mode 100644 index 000000000..7af3cb242 --- /dev/null +++ b/wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx @@ -0,0 +1,67 @@ +import clsx from 'clsx'; +import { useMemo } from 'react'; +import styled from 'styled-components'; +import { Button, Popconfirm } from 'antd'; +import { FileTextOutlined } from '@ant-design/icons'; +import { Reference } from './utils'; + +const StyledFeedbackSideFloat = styled.div` + position: relative; + + .feedbackSideFloat-title { + position: absolute; + top: -14px; + padding: 0 4px; + } +`; + +interface Props { + className?: string; + references: Reference[]; + onOpenReviewDrawer: () => void; + onResetAllCorrectionPrompts: () => void; +} + +export default function FeedbackSideFloat(props: Props) { + const { + className, + references, + onOpenReviewDrawer, + onResetAllCorrectionPrompts, + } = props; + + const changedReferences = useMemo(() => { + return (references || []).filter((item) => !!item.correctionPrompt); + }, [references]); + + if (changedReferences.length === 0) return null; + return ( + +
+ Pending feedbacks +
+
+ + + + +
+
+ ); +} diff --git a/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx b/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx new file mode 100644 index 000000000..42a0fab95 --- /dev/null +++ b/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx @@ -0,0 +1,339 @@ +import clsx from 'clsx'; +import { groupBy } from 'lodash'; +import { + useMemo, + useState, + forwardRef, + useImperativeHandle, + useRef, + useEffect, +} from 'react'; +import styled from 'styled-components'; +import { Tag, Typography, Button, Input, Alert } from 'antd'; +import { + EditOutlined, + CloseCircleFilled, + ReloadOutlined, + InfoCircleFilled, +} from '@ant-design/icons'; +import { QuoteIcon } from '@/utils/icons'; +import { makeIterable } from '@/utils/iteration'; +import { + REFERENCE_ORDERS, + Reference, + getReferenceIcon, + getReferenceName, +} from './utils'; +import { ERROR_CODES } from '@/utils/errorHandler'; + +const StyledReferenceSideFloat = styled.div` + position: relative; + + .referenceSideFloat__summary { + cursor: pointer; + padding: 2px 0; + &:hover, + &.isActive { + background-color: rgba(250, 219, 20, 0.3); + } + } + + .referenceSideFloat-title { + position: absolute; + top: -14px; + padding: 0 4px; + } +`; + +const StyledAlert = styled(Alert)` + padding: 8px 12px 12px; + .ant-alert-icon { + font-size: 14px; + margin-right: 8px; + margin-top: 4px; + } + .ant-alert-message { + font-size: 14px; + line-height: 14px; + margin-top: 4px; + margin-bottom: 8px; + } + .ant-alert-description { + font-size: 12px; + line-height: 16px; + color: var(--gray-8); + } +`; + +interface Props { + references: Reference[]; + error?: Record; + onSaveCorrectionPrompt: (id: string, value: string) => void; + onTriggerExplanation: () => void; + onHoverReference?: (reference: Reference) => void; +} + +const ReferenceSummaryTemplate = ({ + summary, + type, + referenceNum, + correctionPrompt, +}) => { + const isRevise = !!correctionPrompt; + return ( +
+ + {getReferenceIcon(type)} + {referenceNum} + + + {summary} + +
+ ); +}; + +const GroupReferenceTemplate = ({ + name, + type, + data, + index, + saveCorrectionPrompt, + hoverReference, +}) => { + if (!data.length) return null; + return ( +
0 })}> + + {getReferenceIcon(type)}{' '} + {name} + + +
+ ); +}; + +const ReferenceTemplate = ({ + saveCorrectionPrompt, + hoverReference, + ...reference +}) => { + const { type, summary, referenceId, referenceNum, correctionPrompt } = + reference; + const [isEdit, setIsEdit] = useState(false); + const [value, setValue] = useState(correctionPrompt); + const isRevise = !!correctionPrompt; + + const openEdit = () => { + setIsEdit(!isEdit); + }; + + const handleEdit = () => { + saveCorrectionPrompt(referenceId, value); + setIsEdit(false); + setValue(''); + }; + + return ( +
+
+ + {getReferenceIcon(type)} + {referenceNum} + +
+
+ + hoverReference(reference)} + onMouseLeave={() => hoverReference()} + > + {summary} + + + {isRevise ? ( + '(feedback suggested)' + ) : ( + + )} + + + {isEdit && ( +
+ + setValue(e.target.value)} + onPressEnter={handleEdit} + /> + + +
+ )} +
+
+ ); +}; + +const ReferenceSummaryIterator = makeIterable(ReferenceSummaryTemplate); +const GroupReferenceIterator = makeIterable(GroupReferenceTemplate); +const ReferenceIterator = makeIterable(ReferenceTemplate); + +const References = (props: Props & { targetReference?: Reference }) => { + const { + references, + onSaveCorrectionPrompt, + targetReference, + onHoverReference, + } = props; + const $scroller = useRef(null); + const referencesByGroup = groupBy(references, 'type'); + const resources = REFERENCE_ORDERS.map((type) => ({ + type, + name: getReferenceName(type), + data: referencesByGroup[type] || [], + })); + + useEffect(() => { + if ($scroller.current) { + const $element = $scroller.current; + const $targets = $element.querySelectorAll(`.isActive`); + $targets.forEach((target) => { + target.classList.remove('isActive'); + }); + + if (targetReference) { + const $target = $element.querySelector( + `.reference-${targetReference.referenceNum}`, + ); + $target.classList.add('isActive'); + $element.scrollTo({ top: $target.offsetTop - 100 }); + } + } + }, [targetReference]); + + const hoverReference = (reference?: Reference) => { + onHoverReference && onHoverReference(reference); + }; + + return ( +
+ +
+ ); +}; + +function ReferenceSideFloat(props: Props, ref) { + const { references, error, onTriggerExplanation, onHoverReference } = props; + const [collapse, setCollapse] = useState(false); + const [targetReference, setTargetReference] = useState( + null, + ); + + const triggerHighlight = (reference: Reference) => { + setCollapse(true); + setTargetReference(reference); + }; + + useImperativeHandle(ref, () => ({ + triggerHighlight, + })); + + const referencesSummary = useMemo( + () => + references.reduce((result, reference) => { + if (!result[reference.stepIndex]) { + result[reference.stepIndex] = reference; + } + return result; + }, []), + [collapse, references], + ); + + const handleCollapse = () => { + setCollapse(!collapse); + }; + + if (error) { + // If the thread response was created before the release of the Feedback Loop Feature, + // the explanation will be migrated with an error code OLD_VERSION. + // In this case, users will need to manually trigger the explanation. + const isOldVersion = error.code === ERROR_CODES.OLD_VERSION; + const icon = isOldVersion ? : ; + const type = isOldVersion ? 'info' : 'error'; + return ( + } + onClick={onTriggerExplanation} + > + Retry + + } + /> + ); + } else if (references.length === 0) return null; + return ( + +
+ References +
+ {collapse ? ( + + ) : ( + <> + + + + )} +
+ ); +} + +export default forwardRef(ReferenceSideFloat); diff --git a/wren-ui/src/components/pages/home/thread/feedback/ReviewDrawer.tsx b/wren-ui/src/components/pages/home/thread/feedback/ReviewDrawer.tsx new file mode 100644 index 000000000..ef8e5084e --- /dev/null +++ b/wren-ui/src/components/pages/home/thread/feedback/ReviewDrawer.tsx @@ -0,0 +1,191 @@ +import clsx from 'clsx'; +import { useCallback, useEffect, useMemo, useState } from 'react'; +import { + Button, + Drawer, + Space, + Typography, + Tag, + Input, + Popconfirm, +} from 'antd'; +import styled from 'styled-components'; +import { + EditOutlined, + DeleteOutlined, + FileTextOutlined, +} from '@ant-design/icons'; +import { DrawerAction } from '@/hooks/useDrawerAction'; +import { makeIterable } from '@/utils/iteration'; +import { getReferenceIcon, Reference } from './utils'; +import { CreateCorrectedThreadResponseInput } from '@/apollo/client/graphql/__types__'; + +type Props = DrawerAction & { + references: Reference[]; + threadResponseId: number; + onSaveCorrectionPrompt: (id: string, value: string) => void; + onRemoveCorrectionPrompt: (id: string) => void; + onResetAllCorrectionPrompts: () => void; +}; + +const StyledOriginal = styled.div` + cursor: pointer; + background-color: var(--gray-3); + &:hover { + background-color: var(--gray-4); + } +`; + +const ReviewTemplate = ({ + type, + summary, + referenceId, + referenceNum, + correctionPrompt, + saveCorrectionPrompt, + removeCorrectionPrompt, +}) => { + const [isEdit, setIsEdit] = useState(false); + const [isCollapse, setIsCollapse] = useState(false); + const isRevise = !!correctionPrompt; + + const openEdit = async () => { + setIsEdit(!isEdit); + }; + + const openDelete = () => { + removeCorrectionPrompt(referenceId); + }; + + const handleEdit = (event) => { + saveCorrectionPrompt(referenceId, event.target.value); + setIsEdit(false); + }; + + return ( +
+ setIsCollapse(true)} + > + + + Reference: + + + {summary} + + +
+
+ + {getReferenceIcon(type)} + {referenceNum} + +
+
+ + {isEdit ? ( + + ) : ( + correctionPrompt + )} + {!isEdit && ( + + + + + + + )} + +
+
+
+ ); +}; + +const ReviewIterator = makeIterable(ReviewTemplate); + +export default function ReviewDrawer(props: Props) { + const { + visible, + threadResponseId, + references, + onClose, + onSubmit, + onSaveCorrectionPrompt, + onRemoveCorrectionPrompt, + onResetAllCorrectionPrompts, + } = props; + + const changedReferences = useMemo(() => { + return (references || []).filter( + (reference) => !!reference.correctionPrompt, + ); + }, [references]); + + useEffect(() => { + if (changedReferences.length === 0) { + onClose(); + } + }, [changedReferences]); + + const submit = useCallback(async () => { + try { + const data: CreateCorrectedThreadResponseInput = { + responseId: threadResponseId, + corrections: changedReferences.map((reference) => ({ + id: reference.referenceId, + type: reference.type, + reference: reference.summary, + referenceNum: reference.referenceNum, + stepIndex: reference.stepIndex, + correction: reference.correctionPrompt, + })), + }; + await onSubmit(data); + onClose(); + onResetAllCorrectionPrompts(); + } catch (error) { + console.error(error); + } + }, [changedReferences]); + + return ( + + + + + } + > + + + ); +} diff --git a/wren-ui/src/components/pages/home/thread/feedback/SQLHighlight.tsx b/wren-ui/src/components/pages/home/thread/feedback/SQLHighlight.tsx new file mode 100644 index 000000000..6bf32e5ce --- /dev/null +++ b/wren-ui/src/components/pages/home/thread/feedback/SQLHighlight.tsx @@ -0,0 +1,215 @@ +import { useEffect, useMemo, useRef } from 'react'; +import clsx from 'clsx'; +import { Tag } from 'antd'; +import { groupBy } from 'lodash'; +import styled from 'styled-components'; +import { getReferenceIcon, Reference } from './utils'; +import { getTokenizer } from '@/components/editor/CodeBlock'; + +const SQLWrapper = styled.div` + position: absolute; + top: 0; + left: 28px; + right: 0; + z-index: 1; + font-size: 14px; + color: var(--gray-9); + margin: 0 3px; + + .sqlHighlight__line { + background-color: white; + height: 22px; + } + + .sqlHighlight__block { + position: relative; + &:hover, + &.isActive { + mark { + background-color: rgba(250, 219, 20, 0.3); + } + } + } + + mark { + cursor: pointer; + position: relative; + color: currentColor; + background-color: transparent; + border-bottom: 1px dashed var(--gray-5); + padding: 2px 0; + } + + .sqlHighlight__tags { + user-select: none; + padding: 0 4px; + + &:after { + content: ''; + vertical-align: middle; + } + + .ant-tag { + cursor: pointer; + margin-right: 0; + vertical-align: middle; + + .ant-tag { + margin-left: 4px; + } + } + } +`; + +interface Props { + sql: string; + references: Reference[]; + targetReference?: Reference; + onHighlightHover?: (reference: Reference) => void; +} + +const optimizedSnippet = (snippet: string) => { + // SQL analysis may add more spaces and add brackets to the sql, so we need to handle it. + return snippet + .replace(/\(/g, '\\(?') + .replace(/\)/g, '\\)?') + .replace(/\s/g, '\\s*'); +}; + +const createSnippetsRegex = (snippets: string[]) => { + return new RegExp(`(${snippets.join('|')})`, 'gi'); +}; + +const _printUnmatchedReferences = ( + references: Reference[], + referenceMatches, +) => { + // For debugging purpose + const matchesReferences = referenceMatches.flat(); + const unmatchedReferences = references.filter( + (reference) => + !matchesReferences.find((r) => r.referenceNum === reference.referenceNum), + ); + if (unmatchedReferences.length > 0) + console.log('Unmatched references:', unmatchedReferences); +}; + +export default function SQLHighlight(props: Props) { + const { sql, references, targetReference, onHighlightHover } = props; + const $wrapper = useRef(null); + + useEffect(() => { + if ($wrapper.current) { + const $element = $wrapper.current; + const $targets = $element.querySelectorAll(`.isActive`); + $targets.forEach((target) => { + target.classList.remove('isActive'); + }); + if (targetReference) { + const $target = $wrapper.current.querySelector( + `.reference-${targetReference.referenceNum}`, + ); + if (!$target) return; + $target.classList.add('isActive'); + } + } + }, [targetReference]); + + const sqlArray = useMemo(() => sql.split('\n'), [sql]); + const referenceGroups = useMemo(() => { + const filteredReferences = references + .filter((reference) => reference.sqlLocation) + .map((reference) => ({ + ...reference, + sqlSnippet: reference.sqlSnippet + ? optimizedSnippet(reference.sqlSnippet) + : reference.sqlSnippet, + })); + return groupBy( + filteredReferences, + (reference) => reference.sqlLocation.line, + ); + }, [references]); + + const hoverHighlight = (reference?: Reference) => { + onHighlightHover && onHighlightHover(reference); + }; + const highlights = []; + const referenceMatches = []; + const tokenize = getTokenizer(); + Object.keys(referenceGroups).forEach((line) => { + const lineIndex = Number(line) - 1; + const lineReferences = referenceGroups[line]; + const snippets = lineReferences.map((r) => r.sqlSnippet); + const regex = createSnippetsRegex(snippets); + const parts = sqlArray[lineIndex].split(regex); + + // Add to highlights if the part is matched + highlights[lineIndex] = parts.map((part, index) => { + const tokens = tokenize(part); + const tokenizedPart = tokens.map((token, tokenIndex) => { + const classNames = token.type.split('.').map((name) => `ace_${name}`); + return ( + + {token.value} + + ); + }); + if (regex.test(part)) { + const matchedReferences = lineReferences.filter((reference) => + new RegExp(reference.sqlSnippet, 'i').test(part), + ); + const tags = matchedReferences.map((reference) => { + return ( + + + {getReferenceIcon(reference.type)} + + {reference.referenceNum} + + ); + }); + // Record the matched references + referenceMatches.push(matchedReferences); + const reference = matchedReferences[0]; + return ( + hoverHighlight(reference)} + onMouseLeave={() => hoverHighlight()} + key={index} + > + {tokenizedPart} + {tags && {tags}} + + ); + } + return {tokenizedPart}; + }); + }); + + const content = sqlArray.map((line, index) => { + if (highlights[index]) { + return ( +
+ {highlights[index]} +
+ ); + } + return ( +
+ {line} +
+ ); + }); + + // For debugging purpose + // _printUnmatchedReferences(references, referenceMatches); + + return {content}; +} diff --git a/wren-ui/src/components/pages/home/thread/feedback/index.tsx b/wren-ui/src/components/pages/home/thread/feedback/index.tsx new file mode 100644 index 000000000..814f9a40f --- /dev/null +++ b/wren-ui/src/components/pages/home/thread/feedback/index.tsx @@ -0,0 +1,147 @@ +import { createContext, useContext, useMemo, useRef, useState } from 'react'; +import { sortBy } from 'lodash'; +import { Skeleton } from 'antd'; +import ReferenceSideFloat from '@/components/pages/home/thread/feedback/ReferenceSideFloat'; +import FeedbackSideFloat from '@/components/pages/home/thread/feedback/FeedbackSideFloat'; +import ReviewDrawer from '@/components/pages/home/thread/feedback/ReviewDrawer'; +import useDrawerAction from '@/hooks/useDrawerAction'; +import { ThreadResponse } from '@/apollo/client/graphql/__types__'; +import { Reference, REFERENCE_ORDERS } from './utils'; +import { getIsExplainFinished } from '@/hooks/useAskPrompt'; + +type ContextProps = { + references: Reference[]; + sqlTargetReference?: Reference; + onHighlightToReferences: (target?: Reference) => void; +} | null; + +export const FeedbackContext = createContext({ + references: [], + sqlTargetReference: null, + onHighlightToReferences: () => {}, +}); + +export const useFeedbackContext = () => { + return useContext(FeedbackContext); +}; + +interface Props { + headerSlot: React.ReactNode; + bodySlot: React.ReactNode; + threadResponse: ThreadResponse; + onSubmitReviewDrawer: (variables: any) => Promise; + onTriggerThreadResponseExplain: (variables: any) => Promise; +} + +export default function Feedback(props: Props) { + const { + headerSlot, + bodySlot, + threadResponse, + onSubmitReviewDrawer, + onTriggerThreadResponseExplain, + } = props; + + const reviewSideFloat = useRef(null); + const [sqlTargetReference, setSqlTargetReference] = + useState(); + const [correctionPrompts, setCorrectionPrompts] = useState({}); + const reviewDrawer = useDrawerAction(); + + const saveCorrectionPrompt = (id: string, value: string) => { + setCorrectionPrompts({ ...correctionPrompts, [id]: value }); + }; + + const removeCorrectionPrompt = (id: string) => { + setCorrectionPrompts({ ...correctionPrompts, [id]: undefined }); + }; + + const resetAllCorrectionPrompts = () => { + setCorrectionPrompts({}); + }; + + const triggerExplanation = () => { + onTriggerThreadResponseExplain({ responseId: threadResponse.id }); + }; + + const hoverReference = (reference?: Reference) => { + setSqlTargetReference(reference); + }; + + const loading = useMemo( + () => !getIsExplainFinished(threadResponse?.explain?.status), + [threadResponse?.explain?.status], + ); + const error = useMemo(() => { + return threadResponse?.explain?.error || null; + }, [threadResponse?.explain?.error]); + const references = useMemo(() => { + if (!threadResponse?.detail) return []; + const result = threadResponse.detail.steps.flatMap((step, index) => { + if (step.references === null) return []; + return step.references.map((reference) => ({ + ...reference, + stepIndex: index, + correctionPrompt: correctionPrompts[reference.referenceId], + })); + }); + // Generate reference number for each reference + return sortBy(result, (reference) => + REFERENCE_ORDERS.indexOf(reference.type), + ).map((reference, index) => ({ + referenceNum: index + 1, + ...reference, + })); + }, [threadResponse?.detail, correctionPrompts]); + + const contextValue = { + references, + sqlTargetReference, + onHighlightToReferences: (target) => { + if (reviewSideFloat.current) { + reviewSideFloat.current?.triggerHighlight(target); + } + }, + }; + + return ( + +
+ {headerSlot} +
+ +
+
+
+ {bodySlot} +
+ + + +
+
+ +
+ ); +} diff --git a/wren-ui/src/components/pages/home/thread/feedback/utils.tsx b/wren-ui/src/components/pages/home/thread/feedback/utils.tsx new file mode 100644 index 000000000..efc141d82 --- /dev/null +++ b/wren-ui/src/components/pages/home/thread/feedback/utils.tsx @@ -0,0 +1,48 @@ +import { + FilterOutlined, + SortAscendingOutlined, + GroupOutlined, +} from '@ant-design/icons'; +import { ColumnsIcon, ModelIcon } from '@/utils/icons'; +import { + DetailReference, + ReferenceType, +} from '@/apollo/client/graphql/__types__'; + +export type Reference = DetailReference & { + stepIndex: number; + referenceNum: number; + correctionPrompt?: string; +}; + +export const REFERENCE_ORDERS = [ + ReferenceType.FIELD, + ReferenceType.QUERY_FROM, + ReferenceType.FILTER, + ReferenceType.SORTING, + ReferenceType.GROUP_BY, +]; + +export const getReferenceName = (type: ReferenceType) => { + return ( + { + [ReferenceType.FIELD]: 'Fields', + [ReferenceType.QUERY_FROM]: 'Query from', + [ReferenceType.FILTER]: 'Filter', + [ReferenceType.SORTING]: 'Sorting', + [ReferenceType.GROUP_BY]: 'Group by', + }[type] || null + ); +}; + +export const getReferenceIcon = (type) => { + return ( + { + [ReferenceType.FIELD]: , + [ReferenceType.QUERY_FROM]: , + [ReferenceType.FILTER]: , + [ReferenceType.SORTING]: , + [ReferenceType.GROUP_BY]: , + }[type] || null + ); +}; diff --git a/wren-ui/src/components/pages/home/thread/index.tsx b/wren-ui/src/components/pages/home/thread/index.tsx new file mode 100644 index 000000000..c0c681084 --- /dev/null +++ b/wren-ui/src/components/pages/home/thread/index.tsx @@ -0,0 +1,118 @@ +import React, { useEffect, useRef } from 'react'; +import { Divider } from 'antd'; +import styled from 'styled-components'; +import AnswerResult from './AnswerResult'; +import { IterableComponent, makeIterable } from '@/utils/iteration'; +import { + DetailedThread, + ThreadResponse, +} from '@/apollo/client/graphql/__types__'; + +interface Props { + data: DetailedThread; + onOpenSaveAsViewModal: (data: { sql: string; responseId: number }) => void; + onSubmitReviewDrawer: (variables: any) => Promise; + onTriggerThreadResponseExplain: (variables: any) => Promise; +} + +const StyledThread = styled.div` + h4.ant-typography { + margin-top: 10px; + } + + .ant-typography pre { + border: none; + border-radius: 4px; + } + + button { + vertical-align: middle; + } +`; + +const StyledContainer = styled.div` + max-width: 1030px; +`; + +const AnswerResultTemplate: React.FC< + IterableComponent & { + onOpenSaveAsViewModal: (data: { sql: string; responseId: number }) => void; + onTriggerScrollToBottom: () => void; + onSubmitReviewDrawer: (variables: any) => Promise; + onTriggerThreadResponseExplain: (variables: any) => Promise; + } +> = ({ + onOpenSaveAsViewModal, + onTriggerScrollToBottom, + onSubmitReviewDrawer, + onTriggerThreadResponseExplain, + data, + index, + ...threadResponse +}) => { + const lastResponseId = data[data.length - 1].id; + const isLastThreadResponse = threadResponse.id === lastResponseId; + const { id } = threadResponse; + + return ( + + {index > 0 && } + + + ); +}; + +const AnswerResultIterator = makeIterable(AnswerResultTemplate); + +export default function Thread(props: Props) { + const { + data, + onOpenSaveAsViewModal, + onSubmitReviewDrawer, + onTriggerThreadResponseExplain, + } = props; + const divRef = useRef(null); + + const triggerScrollToBottom = () => { + const contentLayout = divRef.current?.parentElement; + if (!contentLayout) return; + const lastChild = divRef.current.lastElementChild as HTMLElement; + const lastChildElement = lastChild.lastElementChild as HTMLElement; + + if ( + contentLayout.clientHeight < + lastChild.offsetTop + lastChild.clientHeight + ) { + contentLayout.scrollTo({ + top: lastChildElement.offsetTop, + behavior: 'smooth', + }); + } + }; + + useEffect(() => { + if (divRef.current && data?.responses.length > 0) { + triggerScrollToBottom(); + } + }, [divRef, data]); + + return ( + + record.id} + onOpenSaveAsViewModal={onOpenSaveAsViewModal} + onTriggerScrollToBottom={triggerScrollToBottom} + onSubmitReviewDrawer={onSubmitReviewDrawer} + onTriggerThreadResponseExplain={onTriggerThreadResponseExplain} + /> + + ); +} diff --git a/wren-ui/src/hooks/useAskPrompt.tsx b/wren-ui/src/hooks/useAskPrompt.tsx index f6c2e6249..698a0c378 100644 --- a/wren-ui/src/hooks/useAskPrompt.tsx +++ b/wren-ui/src/hooks/useAskPrompt.tsx @@ -1,18 +1,44 @@ import { useEffect, useMemo } from 'react'; -import { AskingTaskStatus } from '@/apollo/client/graphql/__types__'; +import { + AskingTaskStatus, + ExplainTaskStatus, +} from '@/apollo/client/graphql/__types__'; import { useAskingTaskLazyQuery, useCancelAskingTaskMutation, useCreateAskingTaskMutation, } from '@/apollo/client/graphql/home.generated'; -export const getIsFinished = (status: AskingTaskStatus) => +export const getIsExplainFinished = (status: ExplainTaskStatus) => + [ExplainTaskStatus.FINISHED, ExplainTaskStatus.FAILED].includes(status); + +export const getIsAskingFinished = (status: AskingTaskStatus) => [ AskingTaskStatus.FINISHED, AskingTaskStatus.FAILED, AskingTaskStatus.STOPPED, ].includes(status); +export const checkExplainExisted = (explain?: { + queryId?: string; + status?: ExplainTaskStatus; +}) => { + // if the queryId is not empty, it means the question is explainable + return !!explain?.queryId ? explain.status : undefined; +}; + +export const getIsFinished = ( + askingStatus: AskingTaskStatus, + explainStatus?: ExplainTaskStatus, +) => { + const isAskingFinished = getIsAskingFinished(askingStatus); + if (explainStatus !== undefined) { + const isExplainFinished = getIsExplainFinished(explainStatus); + return isAskingFinished && isExplainFinished; + } + return isAskingFinished; +}; + export default function useAskPrompt(threadId?: number) { const [createAskingTask, createAskingTaskResult] = useCreateAskingTaskMutation(); diff --git a/wren-ui/src/pages/api/graphql.ts b/wren-ui/src/pages/api/graphql.ts index 830f5a8d6..b0ed3246a 100644 --- a/wren-ui/src/pages/api/graphql.ts +++ b/wren-ui/src/pages/api/graphql.ts @@ -33,6 +33,7 @@ import { DataSourceMetadataService, QueryService, } from '@/apollo/server/services'; +import { ThreadResponseExplainRepository } from '@/apollo/server/repositories/threadResponseExplainRepository'; const serverConfig = getConfig(); const logger = getLogger('APOLLO'); @@ -63,6 +64,9 @@ const bootstrapServer = async () => { const deployLogRepository = new DeployLogRepository(knex); const threadRepository = new ThreadRepository(knex); const threadResponseRepository = new ThreadResponseRepository(knex); + const threadResponseExplainRepository = new ThreadResponseExplainRepository( + knex, + ); const viewRepository = new ViewRepository(knex); const schemaChangeRepository = new SchemaChangeRepository(knex); @@ -115,11 +119,13 @@ const bootstrapServer = async () => { const askingService = new AskingService({ telemetry, wrenAIAdaptor, + ibisAdaptor, deployService, projectService, viewRepository, threadRepository, threadResponseRepository, + threadResponseExplainRepository, queryService, }); @@ -178,6 +184,7 @@ const bootstrapServer = async () => { viewRepository, deployRepository: deployLogRepository, schemaChangeRepository, + threadResponseExplainRepository, }), }); await apolloServer.start(); diff --git a/wren-ui/src/pages/home/[id].tsx b/wren-ui/src/pages/home/[id].tsx index 011fa7e6d..2d7f20fb2 100644 --- a/wren-ui/src/pages/home/[id].tsx +++ b/wren-ui/src/pages/home/[id].tsx @@ -1,21 +1,32 @@ import { useRouter } from 'next/router'; import { useParams } from 'next/navigation'; -import { useEffect, useMemo } from 'react'; +import { useCallback, useEffect, useMemo } from 'react'; import { message } from 'antd'; import { Path } from '@/utils/enum'; import useHomeSidebar from '@/hooks/useHomeSidebar'; import SiderLayout from '@/components/layouts/SiderLayout'; import Prompt from '@/components/pages/home/prompt'; import { + useCreateCorrectedThreadResponseMutation, + useCreateThreadResponseExplainMutation, useCreateThreadResponseMutation, useThreadQuery, useThreadResponseLazyQuery, } from '@/apollo/client/graphql/home.generated'; -import useAskPrompt, { getIsFinished } from '@/hooks/useAskPrompt'; +import useAskPrompt, { + getIsFinished, + checkExplainExisted, +} from '@/hooks/useAskPrompt'; import useModalAction from '@/hooks/useModalAction'; -import PromptThread from '@/components/pages/home/promptThread'; +import Thread from '@/components/pages/home/thread'; import SaveAsViewModal from '@/components/modals/SaveAsViewModal'; import { useCreateViewMutation } from '@/apollo/client/graphql/view.generated'; +import { + CreateCorrectedThreadResponseInput, + CreateThreadResponseExplainWhereInput, + CreateThreadResponseInput, + ThreadResponse, +} from '@/apollo/client/graphql/__types__'; export default function HomeThread() { const router = useRouter(); @@ -39,18 +50,24 @@ export default function HomeThread() { skip: threadId === null, onError: () => router.push(Path.Home), }); + const addThreadResponse = (nextResponse) => { + updateThreadQuery((prev) => { + return { + ...prev, + thread: { + ...prev.thread, + responses: [...prev.thread.responses, nextResponse], + }, + }; + }); + }; + const [createThreadResponseExplain] = useCreateThreadResponseExplainMutation({ + onError: (error) => console.error(error), + }); const [createThreadResponse] = useCreateThreadResponseMutation({ onCompleted(next) { const nextResponse = next.createThreadResponse; - updateThreadQuery((prev) => { - return { - ...prev, - thread: { - ...prev.thread, - responses: [...prev.thread.responses, nextResponse], - }, - }; - }); + addThreadResponse(nextResponse); }, }); const [fetchThreadResponse, threadResponseResult] = @@ -69,6 +86,13 @@ export default function HomeThread() { })); }, }); + const [createRegeneratedThreadResponse] = + useCreateCorrectedThreadResponseMutation({ + onCompleted(next) { + const nextResponse = next.createCorrectedThreadResponse; + addThreadResponse(nextResponse); + }, + }); const thread = useMemo(() => data?.thread || null, [data]); const threadResponse = useMemo( @@ -76,25 +100,63 @@ export default function HomeThread() { [threadResponseResult.data], ); const isFinished = useMemo( - () => getIsFinished(threadResponse?.status), + () => + getIsFinished( + threadResponse?.status, + checkExplainExisted(threadResponse?.explain), + ), [threadResponse], ); + const startThreadResponseExplanation = useCallback( + (threadResponse: ThreadResponse) => { + const isSuccessBreakdown = threadResponse?.error === null; + const isExplainable = + threadResponse?.explain && + threadResponse?.explain?.error === null && + threadResponse?.explain.queryId === null; + if (isSuccessBreakdown && isExplainable) { + createThreadResponseExplain({ + variables: { where: { responseId: threadResponse.id } }, + }).then(() => + fetchThreadResponse({ variables: { responseId: threadResponse.id } }), + ); + } + }, + [], + ); + useEffect(() => { - const unfinishedRespose = (thread?.responses || []).find( - (response) => !getIsFinished(response.status), + const unfinishedResposes = (thread?.responses || []).filter( + (response) => + !getIsFinished(response.status, checkExplainExisted(response?.explain)), ); + if (!!unfinishedResposes.length) { + for (const response of unfinishedResposes) { + fetchThreadResponse({ variables: { responseId: response.id } }); + } + // Explanation will be triggered by polling process + return; + } - if (unfinishedRespose) { - fetchThreadResponse({ variables: { responseId: unfinishedRespose.id } }); + // If all responses are finished, we need to check if there is any explanation that is not started + const unfinishedExplanations = (thread?.responses || []).filter( + (response) => response.explain?.queryId === null, + ); + for (const response of unfinishedExplanations) { + startThreadResponseExplanation(response); } }, [thread]); useEffect(() => { - if (isFinished) threadResponseResult.stopPolling(); - }, [isFinished]); + if (isFinished) { + threadResponseResult.stopPolling(); - const onSelect = async (payload) => { + startThreadResponseExplanation(threadResponse); + } + }, [isFinished, threadResponse]); + + const onSelect = async (payload: CreateThreadResponseInput) => { try { const response = await createThreadResponse({ variables: { threadId: thread.id, data: payload }, @@ -107,11 +169,43 @@ export default function HomeThread() { } }; + const onSubmitReviewDrawer = async ( + payload: CreateCorrectedThreadResponseInput, + ) => { + try { + const response = await createRegeneratedThreadResponse({ + variables: { threadId: thread.id, data: payload }, + }); + await fetchThreadResponse({ + variables: { + responseId: response.data.createCorrectedThreadResponse.id, + }, + }); + } catch (error) { + throw error; + } + }; + + const onTriggerThreadResponseExplain = async ( + payload: CreateThreadResponseExplainWhereInput, + ) => { + try { + await createThreadResponseExplain({ + variables: { where: payload }, + }); + await fetchThreadResponse({ variables: payload }); + } catch (error) { + console.error(error); + } + }; + return ( -